Pytorch教程

【拓展】ModuleDict与ModuleList的使用

作者 : 老饼 发表日期 : 2024-03-11 00:26:55 更新日期 : 2024-03-24 15:52:42
本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com



pytorch为了更灵活的设置模型中的子模块,为此提供了ModuleDict与ModuleList

ModuleDict与ModuleList主要用于替代普通的Dict和List来存放Module字典及列表

本文讲解pytorch中ModuleDict与ModuleList的作用以及展示一个例子用于说明它们的使用方式





     01. ModuleDict与ModuleList     



本节讲解ModuleDict与ModuleList有什么作用,以及如何使用



     ModuleDict与ModuleList有什么用    


往往我们需要将子模块用字典或列表形式存起来,以方便在forward中使用,
但是,直接将子模块写在Dict或List里是不能被父模块作为子模块所捕捉的
这意味着,写在Dict和List里的子模块所定义的参数也不能被父模块捕捉到
 
但在一些使用场景中,我们就是需要将子模块存到Dict或List里,
因此,pytorch提供了ModuleDict与ModuleList对象,
在需要使用Dict和List时,可以使用ModuleDict与ModuleList来替代,它们有类似字典/列表的功能




     ModuleDict使用示例    


下面展示ModuleDict的使用方式,它可以将子模块存放到类似字典的对象中
# 本代码用于展示ModuleDict的使用方法
from   torch import nn
# -----------定义模型-----------
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        # 将两个线性模块存放在字典中,此时需要使用ModuleDict方法
        self.L = nn.ModuleDict({
                'L1':nn.Linear(2,3),
                'L2':nn.Linear(2,3),
        })
    def forward(self, x):
        # 定义模型的计算
        y1 = nn.functional.tanh(self.L['L1'](x)) 
        y = self.L['L2'](y1)
        return y
# -------打印模型参数-------------------
model = Model()                                            # 初始化模型
param_dict = dict(model.named_parameters())                # 从模型中提取出模型参数,并转为字典
for key in param_dict:                                     # 历遍所有参数名称
   print(key,':',param_dict[key].data)                     # 打印参数名称和数据
运行结果如下:
  
可以看到,存放在ModuleDict里的两个子模块L1,L2是可以被父类识别的





     ModuleList使用示例    


下面展示ModuleList的使用方式,它可以将子模块存放到类似列表的对象中
# 本代码用于展示ModuleList的使用方法
from   torch import nn
# -----------定义模型-----------
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        # 将两个线性模块存放在列表中,此时需要使用ModuleList方法
        self.L =  nn.ModuleList([nn.Linear(2,3),nn.Linear(2,3)])
    def forward(self, x):
        # 定义模型的计算
        y1 = nn.functional.tanh(self.L[0](x)) 
        y = self.L[1](y1)
        return y
# -------打印模型参数-------------------
model = Model()                                            # 初始化模型
param_dict = dict(model.named_parameters())                # 从模型中提取出模型参数,并转为字典
for key in param_dict:                                     # 历遍所有参数名称
   print(key,':',param_dict[key].data)                     # 打印参数名称和数据
运行结果如下:
  
可以看到,存放在ModuleList里的两个子模块L1,L2是可以被父类识别的







     02. 模块的注册-add_module      




作为ModuleDict与ModuleList的补充,本节讲解如何使用注册模块的方式来实现添加非直属module对象到父模块中




    注册模块的方法-add_module     


当定义的子模块不是父模块直属下的的module对象时,父模块是无法识别子模块的,
除了以上ModuleDict与ModuleList的方式外,也可以使用add_module强行注册模块,来使得模型识别模块
 也可以用register_module来注册模块,register_module和add_module是同一个函数的两个不同名称
 具体示例如下:
from   torch import nn
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.layerList = [ nn.Linear(3,2),nn.Linear(5,2)]    # 包含在列表里,不会识别
        self.add_module('L1',self.layerList[0])
print('\n模型结构:', Model())                                 # 打印模型结构  
运行结果如下:
  
可以看到,3*2的线性层使用add_module后会被父模块所识别,
而5*2的线性层包含在List中,却没有使用add_module,则不会被识别













 End 









联系老饼