Pytorch教程

【示例】pytorch用一个模型初始化另一个模型

作者 : 老饼 发表日期 : 2024-03-14 04:17:19 更新日期 : 2024-03-19 12:36:59
本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com





     01. pytorch用一个模型初始化另一个模型    




本节展示在pytorch中,如何用一个模型去初始化另一个模型的方法




      pytorch用一个模型初始化另一个模型      


在pytorch的实际使用中,我们往往需要将一个模型的参数拷贝给另一个模型,
此时只需用state_dict将旧模型的参数提取出来,再用load_state_dict赋给新模型就可以
具体示例如下:
from   torch import nn
from torch.nn import functional as F
# ---------模型定义-----------
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.L1 = nn.Linear(2,3)                                  # 定义第一个线性层模块
        self.L2 = nn.Linear(3,2)                                  # 定义第二个线性层模块
                                                                 
    def forward(self, x):                                        
        y = F.relu(self.L1(x))                                    # 第一层的计算
        y = self.L2(y)                                            # 第二层的计算
        return y                                                 
model = Model()                                                  
new_model = Model()                                              
state_dict = model.state_dict()                                   # 提取旧模型的参数字典
new_model.load_state_dict(state_dict)                             # 将旧模型的参数赋给新模型
print('\n-----旧模型的参数列表:----\n',model.state_dict())       # 打印旧模型的参数
print('\n-----新模型的参数列表:----\n',new_model.state_dict())   # 打印新模型的参数
运行结果如下:
  
可以看到,新旧模型的参数是一样的










 End 








联系老饼