Pytorch教程

【示例】pytorch模型-简单删除层

作者 : 老饼 发表日期 : 2024-03-12 05:44:03 更新日期 : 2024-03-24 17:19:47
本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com




在pytorch模型应用中,我们往往需要对构建好的模型进行微小修改,

本节讲解pyroch中一些简单场景时如何往原有的模型中再删除一层




     01. pytorch模型-简单删除层     




本节展示pytorch中,如果删除层后不会影响forward时,应该怎么删除层




     pytorch模型-删除某层(与forward不冲突)     


在pytorch中删除模型的层是一种较为复杂的场景,因为删除层之后,可能会导致forward流程出现错误
但如果模型删除某层并不影响原有模型forward的代码逻辑,则是较为简单的,
这种场景下,删除该层只需找到该层对象,将其删除即可,不需再作任何处理
 具体示例如下:
from   torch import nn

# --------------定义模型------------------------------
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.Cov = nn.Sequential(
            nn.Conv2d(3,4, kernel_size=3,stride=1,padding=1),
            nn.Conv2d(4,4, kernel_size=3,stride=1,padding=1),       # 删除该层不影响模型的forward
            nn.ReLU(inplace=True),  
            nn.AvgPool2d(kernel_size=2,stride=2),
            nn.Flatten()
            )
        self.FC =  nn.Linear(784,5)

    def forward(self, x):
        y = self.Cov(x)           
        y = self.FC(y)             
        return y

# -----------打印模型结构-------------------
model = ConvNet()                                                       # 初始化模型
print('\n原模型结构:',model)                                            # 打印模型结构

# -------在原有模型的基础上,删除一层----------
del model.Cov[1]                                                        # 删除cov里的第2个卷积层
print('\n修改后的模型结构:',model)                                      # 打印修改后的模型结构
运行结果如下
  
可以看到,已经删除了Cov里的第2个卷积层






    02. pytorch模型-删除某些层(与forward冲突)   




本节展示pytorch中,如果删除层后会影响forward时,应该怎么删除层




     pytorch模型-删除某些层(与forward冲突)     


如果模型删除某层会与原有的forward代码逻辑冲突,那么直接删除该层会导致forward报错
因此,此时需要重写forward逻辑,可以参考《》
这里我们列举一种简单的情况:
如果删除的层的父模块的forward只是简单的将每个子模块进行前馈,
那么我们只需要删除目标层后,将剩余所有子模块重新封装成Sequential就可以
 具体示例如下:
from   torch import nn
import torch
from collections import OrderedDict
# --------------定义模型------------------------------
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.L1 =  nn.Linear(3,5)
        self.R1 = nn.ReLU(inplace=True)
        
        self.L2 =  nn.Linear(5,5)                                 # 如果删除该层,forward会因为找不到self.L2而报错
        self.R2 = nn.ReLU(inplace=True)
        
        self.L3 =  nn.Linear(5,5)
        self.R3 = nn.ReLU(inplace=True)

    def forward(self, x):
        y = self.R1(self.L1(x))     
        y = self.R2(self.L2(y))     
        y = self.R3(self.L2(y))              
        return y

# -----------打印模型结构-------------------
model = ConvNet()                                                 # 初始化模型
print('\n原模型结构:',model)                                      # 打印模型结构

# -------在原有模型的基础上,删除一些层,并重新封装模型----------
del model.L2                                                      # 删除L2层
del model.R2                                                      # 删除R2层
model = nn.Sequential(OrderedDict(dict(model.named_children())))  # 将模型剩余层按Sequential重新封装成模型
# model = nn.Sequential(*list(model.children()))                  # 也可以用这种方式,但会丢失模块名称
print('\n修改后的模型结构:',model)                                # 打印修改后的模型结构
model(torch.tensor([1.,2.,3.]))                                   # 修改后模型仍然可以正常执行
运行结果如下:
  
 可以看到,L2和R2都被删除了







     pytorch模型-保留某些层(与forward冲突)     


删除某些层,从另一个角度来说,也可以认为是保留某些层
特别是当删除的层比较多的时候,往往更倾向于以"保留一些层"的方式来实现
此时,只需要从原模型中抽取出需要保留的层,重新生成模型就可以了
 例如,要删除模型第2层之后的层,那么只需要抽出模型的前2层来重新生成模型就可以了
 具体实现示例如下:
from   torch import nn
import torch
# --------------定义模型------------------------------
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.L1 =  nn.Linear(3,5)
        self.R1 = nn.ReLU(inplace=True)
        
        self.L2 =  nn.Linear(5,5)                     # 如果删除该层,forward会因为找不到self.L2而报错
        self.R2 = nn.ReLU(inplace=True)
        
        self.L3 =  nn.Linear(5,5)
        self.R3 = nn.ReLU(inplace=True)

    def forward(self, x):
        y = self.R1(self.L1(x))     
        y = self.R2(self.L2(y))     
        y = self.R3(self.L2(y))              
        return y

# -----------打印模型结构-------------------
model = ConvNet()                                     # 初始化模型
print('\n原模型结构:',model)                          # 打印模型结构

# -------只抽取模型的前两层作为新模型----------
model = nn.Sequential(*list(model.children())[0:2])   # 也可以用这种方式,但会丢失模块名称
print('\n修改后的模型结构:',model)                    # 打印修改后的模型结构
model(torch.tensor([1.,2.,3.]))                       # 修改后模型仍然可以正常执行
运行结果如下:
  
可以看到,修改后的模型只保留了前两层












 End 








联系老饼