本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com
本节展示在pytorch中,如何用一个模型去初始化另一个模型的方法
pytorch用一个模型初始化另一个模型
在pytorch的实际使用中,我们往往需要将一个模型的参数拷贝给另一个模型,
此时只需用state_dict将旧模型的参数提取出来,再用load_state_dict赋给新模型就可以
具体示例如下:
from torch import nn
from torch.nn import functional as F
# ---------模型定义-----------
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.L1 = nn.Linear(2,3) # 定义第一个线性层模块
self.L2 = nn.Linear(3,2) # 定义第二个线性层模块
def forward(self, x):
y = F.relu(self.L1(x)) # 第一层的计算
y = self.L2(y) # 第二层的计算
return y
model = Model()
new_model = Model()
state_dict = model.state_dict() # 提取旧模型的参数字典
new_model.load_state_dict(state_dict) # 将旧模型的参数赋给新模型
print('\n-----旧模型的参数列表:----\n',model.state_dict()) # 打印旧模型的参数
print('\n-----新模型的参数列表:----\n',new_model.state_dict()) # 打印新模型的参数
运行结果如下:
可以看到,新旧模型的参数是一样的
End