Pytorch教程

【示例】pytorch模型-简单修改层

作者 : 老饼 发表日期 : 2024-03-12 06:24:26 更新日期 : 2024-03-24 17:21:00
本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com




在pytorch模型应用中,我们往往需要对构建好的模型的某层进行修改,

本节讲解pyroch中一些简单场景时如何修改原有模型的中某一层的具体操作方法





     01.pytorch模型-简单修改层    



本节讲解pytorch中如何修改模型的某一层



      pytorch简单修改模型的某层     


假设我们已经有一个模型,我们需要修改模型的某一层(即某模块),只需直接在原模型对象上进行修改就可以
例如,原模型有Cov层与FC层,原模型的FC层输出为1*5,
现在我们希望将其改为1*3,则只需直接定义一个FC来替换原模型的FC层就可以
具体代码示例如下:
import torch
from   torch import nn

# --------------原模型的结构与输出------------------------------
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.Cov = nn.Sequential(
            nn.Conv2d(3,4, kernel_size=5,stride=1,padding=2),
            nn.ReLU(inplace=True),  
            nn.AvgPool2d(kernel_size=2,stride=2),
            nn.Flatten()
            )
        self.FC =  nn.Linear(784,5)

    def forward(self, x):
        y = self.Cov(x)           
        y = self.FC(y)             
        return y

#----------原模型的输出结果-------------------------
model = ConvNet()                                     # 初始化模型
img   = torch.randn(1, 3, 28, 28)                     # 输入图像
print('原模型的输出:',model(img).data)                # 模型输出

# -----在原有模型的基础上,替换全连接层,并打印结果---
model.FC = nn.Linear(784,3)                           # 修改模型的全连接层
print('修改后的模型输出:',model(img).data)            # 打印修改后的模型输出
运行结果如下:
  
可以看到,模块的输出由原来的5个输出改为了3个输出






      pytorch修改在Sequential里的层     


如果需要修改的层包含在Sequential里,则需要根据所在层的位置来索引出层对象,然后进行修改
仍然以上述模型为例,如果我们要修改卷积层,它处于Cov对象的第0个位置,即修改model.Cov[0]就可以
具体示例代码如下:
from   torch import nn

# --------------定义模型------------------------------
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.Cov = nn.Sequential(
            nn.Conv2d(3,4, kernel_size=5,stride=1,padding=2),
            nn.ReLU(inplace=True),  
            nn.AvgPool2d(kernel_size=2,stride=2),
            nn.Flatten()
            )
        self.FC =  nn.Linear(784,5)

    def forward(self, x):
        y = self.Cov(x)           
        y = self.FC(y)             
        return y
    
# -----------打印模型结构-------------------
model = ConvNet()                                                       # 初始化模型
print('\n原模型结构:',model)                                            # 打印模型结构

# -------在原有模型的基础上,修改Sequential里的卷积层---------
model.Cov[0] = nn.Conv2d(3,4, kernel_size=3,stride=1,padding=1)         # 修改模型的卷积层
print('\n修改后的模型结构:',model)                                      # 打印修改后的模型结构
 运行结果如下:
  
可以看到,已经成功修改模型Cov里的卷积层












 End 







联系老饼