Pytorch教程

【示例】pytorch模型-修改forward

作者 : 老饼 发表日期 : 2024-03-12 04:41:27 更新日期 : 2024-03-20 12:15:14
本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com






    01. pytorch模型-修改forward    




本节展示如何使用模型继续的方法来修改模型,以及模型的forward方法




     如何修改pytorch模型的forward    


往往在模型的一些复杂的修改中,涉及到模型forward逻辑的修改,这时就需要用到类的继承,
即重新定义一个模型,并把旧模型的所有模块、属性、方法继承下来,然后再继续修改
使用类继承的方法,不仅可以修改forward,同时也可以更灵活的增删模块,可以较好地应对模型各种复杂的修改





      示例一:用模型继承来修改模型结构与forward     


当前需求:现有一个旧模型,它最后一个线性层为2个输出,而现在要其最后一层修改为3个输出,
处理方法:先将旧模型所有模块继续下来,然后再定义出我们最后的输出层,                               
最后重写forward的逻辑,将我们新定义的输出层作为最后一层的计算逻辑 
 具体示例如下:
from   torch import nn
import torch
from torch.nn import functional as F
# -----------原模型定义-----------
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()            
        self.L1 = nn.Linear(2,3)                 # 定义第一个线性层模块
        self.L2 = nn.Linear(3,2)                 # 定义第二个线性层模块
                                                 
    def forward(self, x):                        
        y = F.relu(self.L1(x))                   # 第一层的计算
        y = self.L2(y)                           # 第二层的计算
        return y

# -----新模型的定义(在旧模型上进行改造)-----
class NewModel(Model):
    def __init__(self):
        super(NewModel, self).__init__()
        self.L2_new =  nn.Linear(3,3)            # 定义新的模块
    def forward(self,x):
        # 重写forward的逻辑
        y = F.relu(self.L1(x))                   # 第一层沿用旧模块的计算
        y = self.L2_new(y)                       # 第二层按新的模块计算
        return y 

# ------打印新旧模型的计算结果------------
x = torch.tensor([1.,1.])

# 打印旧模型的计算结果
model     = Model()                
print('旧模型的输出:',model(x))

# 打印新模型的计算结果
new_model = NewModel()                  
print('新模型的输出',new_model(x))
运行结果如下:
  
可以看到,旧模型只有2个输出,而新模型有3个输出
在本代码中,不仅可以加入新的模块,还可以自由地编写forward的逻辑,
因此,使用该方法,可以非常灵活地处理各种修改需求






      示例二:用模型继承来修改模型结构与forward(并继承参数初始化方法)     


场景与需求 :原模型中已经带有初始化参数的方法,现在新模型不但要继续旧模型的结构,并且继承旧模型初始化的参数
解决方案    :只需要在新模型初始化时,先完成旧模型参数的初始化,然后再对新模型进行修改就可以                          
 示例如下所示:
from   torch import nn
import torch
from torch.nn import functional as F
# -----------旧模型的定义-----------
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.L1 = nn.Linear(2,3)                                 # 定义第一个线性层模块
        self.L2 = nn.Linear(3,2)                                 # 定义第二个线性层模块
                                                                 
    def forward(self, x):                                        
        y = F.relu(self.L1(x))                                   # 第一层的计算
        y = self.L2(y)                                           # 第二层的计算
        return y                                                 
                                                                 
    def initParams(self):                                        # 初始化参数
        torch.nn.init.ones_(self.L1.get_parameter('weight'))     # 初始化L1的权重
        torch.nn.init.zeros_(self.L1.get_parameter('bias'))      # 初始化L1的阈值
        torch.nn.init.ones_(self.L2.get_parameter('weight'))     # 初始化L2的权重
        torch.nn.init.zeros_(self.L2.get_parameter('bias'))      # 初始化L2的阈值
                                                                 
# ---------新模型的定义(在旧模型上进行改造)------                
class NewModel(Model):                                           
    def __init__(self,init_old=False):                           
        # ------继承旧模型的初始化----------------               
        super(NewModel,self).__init__()                          
        if(init_old==True):                                      # 如果沿用旧模型的参数
            super(NewModel,self).initParams()                    # 关键代码:先触发旧模型的参数初始化
                                                                 
        # -----新增的模块------------                            
        self.L2_new =  nn.Linear(3,3)                            # 定义新的模块
                                                                 
    def forward(self,x):                                         # 重写forward的逻辑
        y = F.relu(self.L1(x))                                   # 第一层沿用旧模块的计算
        y = self.L2_new(y)                                       # 第二层按新的模块计算
        return y                                                 
                                                                 
    def initParams(self):                                        
        torch.nn.init.ones_(self.L2_new.get_parameter('weight')) # 初始化L2_new的权重
        torch.nn.init.zeros_(self.L2_new.get_parameter('bias'))  # 初始化L2_new的阈值

# ------打印新旧模型的计算结果-------------------
# 打印旧模型的计算结果
x = torch.tensor([1.,1.])
model = Model()
model.initParams()             
print('\n旧模型的输出:',model(x))

# 打印新模型的计算结果
new_model = NewModel(init_old=True)
new_model.initParams()         
print('新模型的输出',new_model(x))
运行结果如下: 
 
可以看到,新模型成功继承了旧模型的初始化参数












 End 







联系老饼