本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com
本节展示在pytorch中,如何根据参数的名称对模型参数进行初始化的方法
pytorch按参数名称初始化参数
在对pytorch的模型参数进行初始化时,往往需要根据参数的名称进行不同的初始化方法
例如模型的阈值使用一种初始化方法,模型的权重使用另一种初始化方法等等
要实现按参数名称进行初始化,
只需要将模型的每一个参数进行历遍,然后判断参数名称从而选择不同的初始化方法就可以
具体示例如下:
from torch import nn
import torch
# ---------模型定义-----------
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.C1 = nn.Conv2d(1,2, kernel_size=3,stride=1,padding=1)
self.F = nn.Flatten()
self.L2 = nn.Linear(2,3)
def forward(self, x):
y = self.F(self.Cov(x))
y = self.L2(y)
return y
# -------模型参数初始化----------------------
def init_param(model):
param_dict = dict(model.named_parameters()) # 获取模型的参数字典
for key in param_dict: # 历遍每个参数,对其初始化
param_name = key.split(".")[-1] # 获取参数的尾缀作为名称
if (param_name=='weight'): # 如果是权重
torch.nn.init.normal_(param_dict[key]) # 则正态分布初始化
elif (param_name=='bias'): # 如果是阈值
torch.nn.init.zeros_(param_dict[key]) # 则初始化为0
# ------展示效果--------------------------
model = Model() # 初始化模型
init_param(model) # 初始化模型参数
print(model.state_dict()) # 打印结果
运行结果如下:
可以看到,权重用正态分布初始化,而阈值则初始化为0
End