本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com
pytorch为了更灵活的设置模型中的子模块,为此提供了ModuleDict与ModuleList
ModuleDict与ModuleList主要用于替代普通的Dict和List来存放Module字典及列表
本文讲解pytorch中ModuleDict与ModuleList的作用以及展示一个例子用于说明它们的使用方式
本节讲解ModuleDict与ModuleList有什么作用,以及如何使用
ModuleDict与ModuleList有什么用
往往我们需要将子模块用字典或列表形式存起来,以方便在forward中使用,
但是,直接将子模块写在Dict或List里是不能被父模块作为子模块所捕捉的
这意味着,写在Dict和List里的子模块所定义的参数也不能被父模块捕捉到
但在一些使用场景中,我们就是需要将子模块存到Dict或List里,
因此,pytorch提供了ModuleDict与ModuleList对象,
在需要使用Dict和List时,可以使用ModuleDict与ModuleList来替代,它们有类似字典/列表的功能
ModuleDict使用示例
下面展示ModuleDict的使用方式,它可以将子模块存放到类似字典的对象中
# 本代码用于展示ModuleDict的使用方法
from torch import nn
# -----------定义模型-----------
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
# 将两个线性模块存放在字典中,此时需要使用ModuleDict方法
self.L = nn.ModuleDict({
'L1':nn.Linear(2,3),
'L2':nn.Linear(2,3),
})
def forward(self, x):
# 定义模型的计算
y1 = nn.functional.tanh(self.L['L1'](x))
y = self.L['L2'](y1)
return y
# -------打印模型参数-------------------
model = Model() # 初始化模型
param_dict = dict(model.named_parameters()) # 从模型中提取出模型参数,并转为字典
for key in param_dict: # 历遍所有参数名称
print(key,':',param_dict[key].data) # 打印参数名称和数据
运行结果如下:
可以看到,存放在ModuleDict里的两个子模块L1,L2是可以被父类识别的
ModuleList使用示例
下面展示ModuleList的使用方式,它可以将子模块存放到类似列表的对象中
# 本代码用于展示ModuleList的使用方法
from torch import nn
# -----------定义模型-----------
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
# 将两个线性模块存放在列表中,此时需要使用ModuleList方法
self.L = nn.ModuleList([nn.Linear(2,3),nn.Linear(2,3)])
def forward(self, x):
# 定义模型的计算
y1 = nn.functional.tanh(self.L[0](x))
y = self.L[1](y1)
return y
# -------打印模型参数-------------------
model = Model() # 初始化模型
param_dict = dict(model.named_parameters()) # 从模型中提取出模型参数,并转为字典
for key in param_dict: # 历遍所有参数名称
print(key,':',param_dict[key].data) # 打印参数名称和数据
运行结果如下:
可以看到,存放在ModuleList里的两个子模块L1,L2是可以被父类识别的
作为ModuleDict与ModuleList的补充,本节讲解如何使用注册模块的方式来实现添加非直属module对象到父模块中
注册模块的方法-add_module
当定义的子模块不是父模块直属下的的module对象时,父模块是无法识别子模块的,
除了以上ModuleDict与ModuleList的方式外,也可以使用add_module强行注册模块,来使得模型识别模块
也可以用register_module来注册模块,register_module和add_module是同一个函数的两个不同名称
具体示例如下:
from torch import nn
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.layerList = [ nn.Linear(3,2),nn.Linear(5,2)] # 包含在列表里,不会识别
self.add_module('L1',self.layerList[0])
print('\n模型结构:', Model()) # 打印模型结构
运行结果如下:
可以看到,3*2的线性层使用add_module后会被父模块所识别,
而5*2的线性层包含在List中,却没有使用add_module,则不会被识别
End