Pytorch教程

【示例】pytorch初始化参数-按参数名称

作者 : 老饼 发表日期 : 2024-03-14 19:36:33 更新日期 : 2024-03-19 11:57:04
本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com





      01. pytorch按参数名称对模型参数进行初始化   



本节展示在pytorch中,如何根据参数的名称对模型参数进行初始化的方法



      pytorch按参数名称初始化参数    


在对pytorch的模型参数进行初始化时,往往需要根据参数的名称进行不同的初始化方法
 例如模型的阈值使用一种初始化方法,模型的权重使用另一种初始化方法等等
要实现按参数名称进行初始化,
只需要将模型的每一个参数进行历遍,然后判断参数名称从而选择不同的初始化方法就可以
 具体示例如下:
from   torch import nn
import torch
# ---------模型定义-----------
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.C1 = nn.Conv2d(1,2, kernel_size=3,stride=1,padding=1) 
        self.F  = nn.Flatten()
        self.L2 = nn.Linear(2,3)                                 
    def forward(self, x):
        y = self.F(self.Cov(x))
        y = self.L2(y)             
        return y

# -------模型参数初始化----------------------
def init_param(model):
    param_dict = dict(model.named_parameters())       # 获取模型的参数字典
    for key in  param_dict:                           # 历遍每个参数,对其初始化
        param_name = key.split(".")[-1]               # 获取参数的尾缀作为名称
        if (param_name=='weight'):                    # 如果是权重
            torch.nn.init.normal_(param_dict[key])    # 则正态分布初始化
        elif (param_name=='bias'):                    # 如果是阈值
            torch.nn.init.zeros_(param_dict[key])     # 则初始化为0
            
# ------展示效果--------------------------
model = Model()                                       # 初始化模型
init_param(model)                                     # 初始化模型参数
print(model.state_dict())                             # 打印结果
运行结果如下:
  
可以看到,权重用正态分布初始化,而阈值则初始化为0










 End 








联系老饼