本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com
pytorch为了更灵活的设置模型中的参数,为此提供了ParameterDict和ParameterList
ParameterDict和ParameterList主要用于替代普通的Dict和List在Module中进行使用
本文讲解pytorch中ParameterDict和ParameterList的作用以及展示一个例子用于说明它们的使用方式
本节讲解ParameterDict和ParameterList有什么作用,以及如何使用
ParameterDict与ParameterList有什么用
往往我们需要将模型的参数用字典或列表形式存起来,以方便在forward中使用
但是,直接将模型参数写在Dict或List里是不能被模型所捕捉的
但在一些使用场景中,我们就是需要将模型参数存到Dict或List里,
因此,pytorch提供了ParameterDict和ParameterList对象,
在需要使用Dict和List时,可以使用ParameterDict和ParameterList来替代,它们有类似字典/列表的功能
ParameterDict与ParameterList的使用示例
当参数需要存为List或Dict时,直接用List或Dict是不能被识别为模块的参数的,需要使用ParameterList和ParameterDict方法
ParameterList和ParameterDict的使用示例如下:
import torch
from torch import nn
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.layerList = [nn.Parameter(torch.randn(1,3))
,nn.Parameter(torch.randn(3,2))]
self.paramsList= nn.ParameterList([nn.Parameter(torch.randn(10, 10)) ,nn.Parameter(torch.randn(3, 2)) ])
self.paramsDict = nn.ParameterDict({
'left': nn.Parameter(torch.randn(5, 10)),
'right': nn.Parameter(torch.randn(5, 10))
})
print('\n模型所包含的参数:',Model().state_dict().keys())
运行结果如下:
可以看到,当参数包含在ParameterDict和ParameterList时,是可以被模型所识别的
End