Pytorch教程

【示例】pytorch初始化参数-按层类型

作者 : 老饼 发表日期 : 2024-03-14 06:11:50 更新日期 : 2024-03-19 11:53:04
本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com




      01. pytorch初始化参数-按层类型   



本节展示两种按层的类型对模型参数进行初始化的方法



      pytorch初始化参数-按层类型(方法一)     


在对pytorch的模型参数进行初始化时,往往需要按层的类型进行不同的初始化方法
 例如线性层使用一种初始化方法,卷积层使用一种初始化方法等等
按层的类型进行初始化,只需要将模型的每一个模块都进行历遍,然后根据模块的层类型进行初始化就可以
 具体示例如下:
from   torch import nn
import torch
# ---------模型定义-----------
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.C1 = nn.Conv2d(1,2, kernel_size=3,stride=1,padding=1) 
        self.F  = nn.Flatten()
        self.L2 = nn.Linear(2,3)                                 
    def forward(self, x):
        y = self.F(self.C1(x))
        y = self.L2(y)             
        return y

# -------模型参数初始化----------------------
def init_param(model):
    # 对模型的每一个模块进行历遍,如果是卷积层或线性层就初始化参数
    for m in model.modules():
        if isinstance(m,nn.Conv2d):                   # 如果是卷积层
            torch.nn.init.normal_(m.weight.data,0.1)  # 卷积层权重的初始化方法
            torch.nn.init.constant_(m.bias.data,1.)   # 卷积层阈值的初始化方法
        elif isinstance(m,nn.Linear):                 # 如果是线性层
            torch.nn.init.normal_(m.weight.data,0.1)  # 线性层权重的初始化方法
            torch.nn.init.zeros_(m.bias.data)         # 线性层阈值的初始化方法
            
# ------展示效果--------------------------
model = Model()                                       # 初始化模型
init_param(model)                                     # 初始化模型参数
print(model.state_dict())                             # 打印结果
运行结果如下:
  





      pytorch初始化参数-按层类型(方法二)     


第二种方法是只写模块的初始化函数,然后使用apply函数将模型的每一个模块都使用该函数进行初始化,
具体示例如下:
from   torch import nn
import torch
# ---------模型定义-----------
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.C1 = nn.Conv2d(1,2, kernel_size=3,stride=1,padding=1) 
        self.F  = nn.Flatten()
        self.L2 = nn.Linear(2,3)                                 
    def forward(self, x):
        y = self.F(self.Cov(x))
        y = self.L2(y)             
        return y

# -------模块的参数初始化方法----------------------
def init_module_param(m):
    if isinstance(m,nn.Conv2d):                   # 如果是卷积层
        torch.nn.init.normal_(m.weight.data,0.1)  # 卷积层权重的初始化方法
        torch.nn.init.constant_(m.bias.data,1.)   # 卷积层阈值的初始化方法
    elif isinstance(m,nn.Linear):                 # 如果是线性层
        torch.nn.init.normal_(m.weight.data,0.1)  # 线性层权重的初始化方法
        torch.nn.init.zeros_(m.bias.data)         # 线性层阈值的初始化方法
            
# ------展示效果--------------------------
model = Model()                                   # 初始化模型
model.apply(init_module_param)                    # 将model的每一个module都应用init_module_param进行初始化
print(model.state_dict())                         # 打印结果
运行结果如下:
 





     特别说明    


上述初始化参数的方法中,使用到了model.modules()来获取模型的所有层,
需要注意的是,model.modules()获取到的是所有的层,包括根节点层、叶子层等等,
 示例如下:
 
上述模型的model.modules()包含了5个元素,如下
  
所以在model.modules()的历遍中,并非只历遍最终的叶子层,而是历遍所有的层
因为,我们在历遍过程中,需要注意,初始化方法只需针对属于叶子层的类型就可以,
而对于非叶子层类型,我们需要直接忽略,不要再进行处理









 End 






联系老饼