本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com
模块是pytorch中模型的重要组成部分,往往我们需要获取指定的模块进行相关操作
因此,本文详细讲解,在pytorch中如何获取模型中所有子模块或者指定的某个模块
本节先展示如何在pytorch中获取模型里的指定模块,下节再详述各个API
获取模块的相关函数
要获取模型里包含的模块对象,一般使用的API如下:
get_submodule :通过模块名称获取模块
children :获取模块的所有子模块,不带名称
named_children :获取模块的所有子模块,带名称
modules :获取模块包含的所有模块 ,不带名称
named_modules :获取模块包含的所有模块,带名称
总的来说,get_submodule获得单个模块,children获得子模块,modules获得所有模块
示例:如何获取pytorch模型的指定模块
下面展示如何综合利用各个获取模块的函数,来达到获取模块的子模块、孙模块、以及Sequential里的模块
具体示例代码如下:
from torch import nn
from collections import OrderedDict
# ---------定义模型-------------------
class subModel(nn.Module):
def __init__(self):
super(subModel, self).__init__()
self.A = nn.Linear(2,3)
self.B = nn.Linear(3,5)
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.L1 = nn.Linear(3,2)
self.L2 = subModel()
self.C3 = nn.Sequential(OrderedDict({
'Cov': nn.Conv2d(3,4, kernel_size=5,stride=1,padding=2),
'R': nn.ReLU(inplace=True)
}))
model = Model()
print('\n--------model:-------\n',model) # 打印模型结构
# ---------访问子模块-------------------
# 方法1:通过属性访问
L1 = model.L1
# 方法2:通过名字访问
L1 = dict(model.named_children())['L1']
L1 = model.get_submodule('L1')
# --------访问孙模块----------------
# 方法1:通过属性层层访问
A = model.L2.A
# 方法2:通过名字层层访问
L2 = dict(model.named_children())['L2']
A = dict(L2.named_children())['A']
# 方法3:通过全称名字访问(两种方式)
A = model.get_submodule('L2.A')
A = dict(model.named_modules())['L2.A']
# ------访问Sequential里的模块---------------------------
# 方法1:通过索引访问
Cov = model.C3[0]
# 方法2:通过名称访问
Cov = dict(model.C3.named_children())['Cov']
# 方法3:通过全称访问(两种方式)
Cov = model.get_submodule('C3.Cov')
Cov = dict(model.named_modules())['C3.Cov']
print('\n-------各个子模块的结果---------')
print('L1:' ,L1 )
print('A:' ,A )
print('Cov:' ,Cov )
运行结果如下:
可以看到,L1是子模块,A是孙模块,Cov是包含在 Sequential里的模块,都能成功地获取
本节具体讲解pytorch里各个与获取模块相关的API
get_submodule方法
get_submodule方法可以根据子模块的名字,来获取指定的模块
get_submodule的具体使用方法如下:
from torch import nn
from collections import OrderedDict
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.L1 = nn.Linear(3,2)
self.C2 = nn.Sequential(OrderedDict({
'Cov': nn.Conv2d(3,4, kernel_size=5,stride=1,padding=2),
'R': nn.ReLU(inplace=True)
}))
model = Model() # 初始化模型
Cov = model.get_submodule('C2.Cov') # 通过get_submodule获得指定模块
print(Cov) # 打印结果
运行结果如下:
可以看到,get_submodule成功获取到了C2下的Cov模块
children与named_children
children()和named_children()的作用如下:
👉children:model.children()获取model里的各个子模块,不带模块名字
👉named_children:model.named_children()获取model里的各个子模块,带模块名字
由于两者返回的都是迭代器,所以一般将children()转换回list对象,方便按索引读取各个子模块
一般将named_children()转换回dict对象,方便按名字读取各个子模块
children()和named_children()的具体使用示例如下:
from torch import nn
from collections import OrderedDict
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.L1 = nn.Linear(3,2)
self.C2 = nn.Sequential(OrderedDict({
'Cov': nn.Conv2d(3,4, kernel_size=5,stride=1,padding=2),
'R': nn.ReLU(inplace=True)
}))
model = Model() # 初始化模型
module_list = list(model.children()) # 通过children获得所有子模块,并转为列表
print('\n children()的返回结果:',module_list) # 打印children的返回结果
module_dict = dict(model.named_children()) # 通过named_children获得所有子模块,并转为字典
print('\n named_children()的返回结果:',module_dict) # 打印named_children的返回结果
运行结果如下:
可以看到,children的返回结果是不带名字的,而named_modules则带有模块名字
两者都返回了model所包含的两个子模块
modules与named_modules
modules()和named_modules()的作用如下:
👉modules:model.modules()获取model里的所有子孙模块,包括自身(不带模块名字)
👉named_modules:model.named_modules()获取model里的所有子孙模块,包括自身(带模块名字)
由于两者返回的都是迭代器,所以一般将modules()转换回list对象,方便按索引读取各个子模块
一般将named_modules()转换回dict对象,方便按名字读取各个子模块
modules()和named_modules()的具体使用示例如下:
from torch import nn
from collections import OrderedDict
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.L1 = nn.Linear(3,2)
self.C2 = nn.Sequential(OrderedDict({
'Cov': nn.Conv2d(3,4, kernel_size=5,stride=1,padding=2),
'R': nn.ReLU(inplace=True)
}))
model = Model() # 初始化模型
module_list = list(model.modules()) # 通过modules获得所有子孙模块,并转为列表
print('\n------modules()的返回结果:-----\n',module_list) # 打印modules的返回结果
module_dict = dict(model.named_modules()) # 通过named_modules获得所有子孙模块,并转为字典
print('\n------named_modules()的返回结果:------\n',module_dict) # 打印named_modules的返回结果
运行结果如下:
以named_modules的结果为例,可以看到,它共返回了五个模块:
第一个是model自身
第二、三个则是model的子模块L1、C2
第四、五个则是C2的子模块(即model的孙模块)C2.Cov和C2.R
关于如何获取Sequential里的模块
Sequential是一种特殊的Module,除了get_submodule、children和modules方法之外
Sequential对象还可以直接根据索引来访问它的子模块
具体示例如下:
from torch import nn
from collections import OrderedDict
model = nn.Sequential(OrderedDict({
'Cov': nn.Conv2d(3,4, kernel_size=5,stride=1,padding=2),
'R': nn.ReLU(inplace=True)
}))
Cov = model[0] # 通过索引获得子模块
print(Cov) # 打印结果
运行结果如下:
可以看到,Sequential对象只需要直接通过索引就能访问所包含的子模块
End