Pytorch教程

【示例】pytorch获取模型里的模块

作者 : 老饼 发表日期 : 2024-03-12 22:38:47 更新日期 : 2024-10-28 15:48:45
本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com



模块是pytorch中模型的重要组成部分,往往我们需要获取指定的模块进行相关操作

因此,本文详细讲解,在pytorch中如何获取模型中所有子模块或者指定的某个模块





     01. pytorch获取模型里的指定模块     




本节先展示如何在pytorch中获取模型里的指定模块,下节再详述各个API




       获取模块的相关函数    


要获取模型里包含的模块对象,一般使用的API如下:
 get_submodule   :通过模块名称获取模块                               
 children               :获取模块的所有子模块,不带名称              
 named_children  :获取模块的所有子模块,带名称                  
 modules              :获取模块包含的所有模块 ,不带名称          
 named_modules :获取模块包含的所有模块,带名称              
总的来说,get_submodule获得单个模块,children获得子模块,modules获得所有模块




      示例:如何获取pytorch模型的指定模块    


下面展示如何综合利用各个获取模块的函数,来达到获取模块的子模块、孙模块、以及Sequential里的模块
 具体示例代码如下:
from   torch import nn
from collections import OrderedDict

# ---------定义模型-------------------
class subModel(nn.Module):
    def __init__(self):
        super(subModel, self).__init__()
        self.A = nn.Linear(2,3)     
        self.B = nn.Linear(3,5)     
        
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.L1 = nn.Linear(3,2)   
        self.L2 = subModel()
        self.C3 = nn.Sequential(OrderedDict({
            'Cov': nn.Conv2d(3,4, kernel_size=5,stride=1,padding=2),
            'R': nn.ReLU(inplace=True)
            }))
model = Model()                                
print('\n--------model:-------\n',model)         # 打印模型结构  
# ---------访问子模块-------------------
# 方法1:通过属性访问
L1 = model.L1          
                       
# 方法2:通过名字访问
L1 = dict(model.named_children())['L1']
L1 = model.get_submodule('L1')

# --------访问孙模块----------------
# 方法1:通过属性层层访问
A  = model.L2.A             

# 方法2:通过名字层层访问            
L2 = dict(model.named_children())['L2']       
A  = dict(L2.named_children())['A']

# 方法3:通过全称名字访问(两种方式)
A = model.get_submodule('L2.A')
A = dict(model.named_modules())['L2.A']

# ------访问Sequential里的模块---------------------------
# 方法1:通过索引访问  
Cov = model.C3[0]

# 方法2:通过名称访问  
Cov = dict(model.C3.named_children())['Cov']

# 方法3:通过全称访问(两种方式)  
Cov = model.get_submodule('C3.Cov')
Cov = dict(model.named_modules())['C3.Cov']

print('\n-------各个子模块的结果---------')
print('L1:'   ,L1   )
print('A:'    ,A    )
print('Cov:'  ,Cov  )
运行结果如下:
   
可以看到,L1是子模块,A是孙模块,Cov是包含在 Sequential里的模块,都能成功地获取 








     02. pytorch各个获取模块的API-详述     




本节具体讲解pytorch里各个与获取模块相关的API




     get_submodule方法      


get_submodule方法可以根据子模块的名字,来获取指定的模块
 get_submodule的具体使用方法如下:
from   torch import nn
from collections import OrderedDict
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.L1 = nn.Linear(3,2)   
        self.C2 = nn.Sequential(OrderedDict({
            'Cov': nn.Conv2d(3,4, kernel_size=5,stride=1,padding=2),
            'R': nn.ReLU(inplace=True)
            }))
model = Model()                              # 初始化模型
Cov = model.get_submodule('C2.Cov')          # 通过get_submodule获得指定模块 
print(Cov)                                   # 打印结果
运行结果如下:
  
可以看到,get_submodule成功获取到了C2下的Cov模块






      children与named_children     


children()和named_children()的作用如下:
👉children:model.children()获取model里的各个子模块,不带模块名字                       
👉named_children:model.named_children()获取model里的各个子模块,带模块名字
  由于两者返回的都是迭代器,所以一般将children()转换回list对象,方便按索引读取各个子模块
一般将named_children()转换回dict对象,方便按名字读取各个子模块

children()和named_children()的具体使用示例如下:
from   torch import nn
from collections import OrderedDict
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.L1 = nn.Linear(3,2)   
        self.C2 = nn.Sequential(OrderedDict({
            'Cov': nn.Conv2d(3,4, kernel_size=5,stride=1,padding=2),
            'R': nn.ReLU(inplace=True)
            }))
model = Model()                                       # 初始化模型
module_list = list(model.children())                  # 通过children获得所有子模块,并转为列表
print('\n children()的返回结果:',module_list)          # 打印children的返回结果
module_dict = dict(model.named_children())           # 通过named_children获得所有子模块,并转为字典
print('\n named_children()的返回结果:',module_dict)    # 打印named_children的返回结果
运行结果如下: 
  
可以看到,children的返回结果是不带名字的,而named_modules则带有模块名字
两者都返回了model所包含的两个子模块






      modules与named_modules    


modules()和named_modules()的作用如下:
👉modules:model.modules()获取model里的所有子孙模块,包括自身(不带模块名字)                 
      👉named_modules:model.named_modules()获取model里的所有子孙模块,包括自身(带模块名字)
 由于两者返回的都是迭代器,所以一般将modules()转换回list对象,方便按索引读取各个子模块
一般将named_modules()转换回dict对象,方便按名字读取各个子模块

modules()和named_modules()的具体使用示例如下:
from   torch import nn
from collections import OrderedDict
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.L1 = nn.Linear(3,2)   
        self.C2 = nn.Sequential(OrderedDict({
            'Cov': nn.Conv2d(3,4, kernel_size=5,stride=1,padding=2),
            'R': nn.ReLU(inplace=True)
            }))
model = Model()                                                   # 初始化模型
module_list = list(model.modules())                               # 通过modules获得所有子孙模块,并转为列表
print('\n------modules()的返回结果:-----\n',module_list)           # 打印modules的返回结果
module_dict = dict(model.named_modules())                         # 通过named_modules获得所有子孙模块,并转为字典
print('\n------named_modules()的返回结果:------\n',module_dict)    # 打印named_modules的返回结果
 运行结果如下:
  
以named_modules的结果为例,可以看到,它共返回了五个模块:
第一个是model自身                                                                  
第二、三个则是model的子模块L1、C2                                      
第四、五个则是C2的子模块(即model的孙模块)C2.Cov和C2.R    





      关于如何获取Sequential里的模块     


Sequential是一种特殊的Module,除了get_submodule、children和modules方法之外 
Sequential对象还可以直接根据索引来访问它的子模块
具体示例如下:
from   torch import nn
from collections import OrderedDict
model = nn.Sequential(OrderedDict({
    'Cov': nn.Conv2d(3,4, kernel_size=5,stride=1,padding=2),
    'R': nn.ReLU(inplace=True)
    }))
Cov = model[0]             # 通过索引获得子模块
print(Cov)                 # 打印结果
运行结果如下:
  
可以看到,Sequential对象只需要直接通过索引就能访问所包含的子模块











 End 









联系老饼