本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com
在pytorch模型应用中,我们往往需要对构建好的模型的某层进行修改,
本节讲解pyroch中一些简单场景时如何修改原有模型的中某一层的具体操作方法
本节讲解pytorch中如何修改模型的某一层
pytorch简单修改模型的某层
假设我们已经有一个模型,我们需要修改模型的某一层(即某模块),只需直接在原模型对象上进行修改就可以
例如,原模型有Cov层与FC层,原模型的FC层输出为1*5,
现在我们希望将其改为1*3,则只需直接定义一个FC来替换原模型的FC层就可以
具体代码示例如下:
import torch
from torch import nn
# --------------原模型的结构与输出------------------------------
class ConvNet(nn.Module):
def __init__(self):
super(ConvNet, self).__init__()
self.Cov = nn.Sequential(
nn.Conv2d(3,4, kernel_size=5,stride=1,padding=2),
nn.ReLU(inplace=True),
nn.AvgPool2d(kernel_size=2,stride=2),
nn.Flatten()
)
self.FC = nn.Linear(784,5)
def forward(self, x):
y = self.Cov(x)
y = self.FC(y)
return y
#----------原模型的输出结果-------------------------
model = ConvNet() # 初始化模型
img = torch.randn(1, 3, 28, 28) # 输入图像
print('原模型的输出:',model(img).data) # 模型输出
# -----在原有模型的基础上,替换全连接层,并打印结果---
model.FC = nn.Linear(784,3) # 修改模型的全连接层
print('修改后的模型输出:',model(img).data) # 打印修改后的模型输出
运行结果如下:
可以看到,模块的输出由原来的5个输出改为了3个输出
pytorch修改在Sequential里的层
如果需要修改的层包含在Sequential里,则需要根据所在层的位置来索引出层对象,然后进行修改
仍然以上述模型为例,如果我们要修改卷积层,它处于Cov对象的第0个位置,即修改model.Cov[0]就可以
具体示例代码如下:
from torch import nn
# --------------定义模型------------------------------
class ConvNet(nn.Module):
def __init__(self):
super(ConvNet, self).__init__()
self.Cov = nn.Sequential(
nn.Conv2d(3,4, kernel_size=5,stride=1,padding=2),
nn.ReLU(inplace=True),
nn.AvgPool2d(kernel_size=2,stride=2),
nn.Flatten()
)
self.FC = nn.Linear(784,5)
def forward(self, x):
y = self.Cov(x)
y = self.FC(y)
return y
# -----------打印模型结构-------------------
model = ConvNet() # 初始化模型
print('\n原模型结构:',model) # 打印模型结构
# -------在原有模型的基础上,修改Sequential里的卷积层---------
model.Cov[0] = nn.Conv2d(3,4, kernel_size=3,stride=1,padding=1) # 修改模型的卷积层
print('\n修改后的模型结构:',model) # 打印修改后的模型结构
运行结果如下:
可以看到,已经成功修改模型Cov里的卷积层
End