Pytorch教程

【拓展】ParameterDict和ParameterList的使用

作者 : 老饼 发表日期 : 2024-03-09 21:40:35 更新日期 : 2024-03-24 15:51:29
本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com



pytorch为了更灵活的设置模型中的参数,为此提供了ParameterDict和ParameterList

ParameterDict和ParameterList主要用于替代普通的Dict和List在Module中进行使用

本文讲解pytorch中ParameterDict和ParameterList的作用以及展示一个例子用于说明它们的使用方式





     01. ParameterDict和ParameterList    



本节讲解ParameterDict和ParameterList有什么作用,以及如何使用




     ParameterDict与ParameterList有什么用    


往往我们需要将模型的参数用字典或列表形式存起来,以方便在forward中使用
但是,直接将模型参数写在Dict或List里是不能被模型所捕捉的
 
但在一些使用场景中,我们就是需要将模型参数存到Dict或List里,
因此,pytorch提供了ParameterDict和ParameterList对象,
在需要使用Dict和List时,可以使用ParameterDict和ParameterList来替代,它们有类似字典/列表的功能




      ParameterDict与ParameterList的使用示例       


当参数需要存为List或Dict时,直接用List或Dict是不能被识别为模块的参数的,需要使用ParameterList和ParameterDict方法
ParameterList和ParameterDict的使用示例如下:
import torch
from   torch import nn
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.layerList = [nn.Parameter(torch.randn(1,3))
                          ,nn.Parameter(torch.randn(3,2))] 
        self.paramsList= nn.ParameterList([nn.Parameter(torch.randn(10, 10)) ,nn.Parameter(torch.randn(3, 2)) ])
        self.paramsDict = nn.ParameterDict({
                        'left': nn.Parameter(torch.randn(5, 10)),
                        'right': nn.Parameter(torch.randn(5, 10))
                })
print('\n模型所包含的参数:',Model().state_dict().keys())
运行结果如下:
   
可以看到,当参数包含在ParameterDict和ParameterList时,是可以被模型所识别的








 End 








联系老饼