本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com
本节展示如何使用模型继续的方法来修改模型,以及模型的forward方法
如何修改pytorch模型的forward
往往在模型的一些复杂的修改中,涉及到模型forward逻辑的修改,这时就需要用到类的继承,
即重新定义一个模型,并把旧模型的所有模块、属性、方法继承下来,然后再继续修改
使用类继承的方法,不仅可以修改forward,同时也可以更灵活的增删模块,可以较好地应对模型各种复杂的修改
示例一:用模型继承来修改模型结构与forward
当前需求:现有一个旧模型,它最后一个线性层为2个输出,而现在要其最后一层修改为3个输出,
处理方法:先将旧模型所有模块继续下来,然后再定义出我们最后的输出层,
最后重写forward的逻辑,将我们新定义的输出层作为最后一层的计算逻辑
具体示例如下:
from torch import nn
import torch
from torch.nn import functional as F
# -----------原模型定义-----------
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.L1 = nn.Linear(2,3) # 定义第一个线性层模块
self.L2 = nn.Linear(3,2) # 定义第二个线性层模块
def forward(self, x):
y = F.relu(self.L1(x)) # 第一层的计算
y = self.L2(y) # 第二层的计算
return y
# -----新模型的定义(在旧模型上进行改造)-----
class NewModel(Model):
def __init__(self):
super(NewModel, self).__init__()
self.L2_new = nn.Linear(3,3) # 定义新的模块
def forward(self,x):
# 重写forward的逻辑
y = F.relu(self.L1(x)) # 第一层沿用旧模块的计算
y = self.L2_new(y) # 第二层按新的模块计算
return y
# ------打印新旧模型的计算结果------------
x = torch.tensor([1.,1.])
# 打印旧模型的计算结果
model = Model()
print('旧模型的输出:',model(x))
# 打印新模型的计算结果
new_model = NewModel()
print('新模型的输出',new_model(x))
运行结果如下:
可以看到,旧模型只有2个输出,而新模型有3个输出
在本代码中,不仅可以加入新的模块,还可以自由地编写forward的逻辑,
因此,使用该方法,可以非常灵活地处理各种修改需求
示例二:用模型继承来修改模型结构与forward(并继承参数初始化方法)
场景与需求 :原模型中已经带有初始化参数的方法,现在新模型不但要继续旧模型的结构,并且继承旧模型初始化的参数
解决方案 :只需要在新模型初始化时,先完成旧模型参数的初始化,然后再对新模型进行修改就可以
示例如下所示:
from torch import nn
import torch
from torch.nn import functional as F
# -----------旧模型的定义-----------
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.L1 = nn.Linear(2,3) # 定义第一个线性层模块
self.L2 = nn.Linear(3,2) # 定义第二个线性层模块
def forward(self, x):
y = F.relu(self.L1(x)) # 第一层的计算
y = self.L2(y) # 第二层的计算
return y
def initParams(self): # 初始化参数
torch.nn.init.ones_(self.L1.get_parameter('weight')) # 初始化L1的权重
torch.nn.init.zeros_(self.L1.get_parameter('bias')) # 初始化L1的阈值
torch.nn.init.ones_(self.L2.get_parameter('weight')) # 初始化L2的权重
torch.nn.init.zeros_(self.L2.get_parameter('bias')) # 初始化L2的阈值
# ---------新模型的定义(在旧模型上进行改造)------
class NewModel(Model):
def __init__(self,init_old=False):
# ------继承旧模型的初始化----------------
super(NewModel,self).__init__()
if(init_old==True): # 如果沿用旧模型的参数
super(NewModel,self).initParams() # 关键代码:先触发旧模型的参数初始化
# -----新增的模块------------
self.L2_new = nn.Linear(3,3) # 定义新的模块
def forward(self,x): # 重写forward的逻辑
y = F.relu(self.L1(x)) # 第一层沿用旧模块的计算
y = self.L2_new(y) # 第二层按新的模块计算
return y
def initParams(self):
torch.nn.init.ones_(self.L2_new.get_parameter('weight')) # 初始化L2_new的权重
torch.nn.init.zeros_(self.L2_new.get_parameter('bias')) # 初始化L2_new的阈值
# ------打印新旧模型的计算结果-------------------
# 打印旧模型的计算结果
x = torch.tensor([1.,1.])
model = Model()
model.initParams()
print('\n旧模型的输出:',model(x))
# 打印新模型的计算结果
new_model = NewModel(init_old=True)
new_model.initParams()
print('新模型的输出',new_model(x))
运行结果如下:
可以看到,新模型成功继承了旧模型的初始化参数
End