Pytorch教程

【示例】pytorch初始化参数-对指定层初始化

作者 : 老饼 发表日期 : 2024-03-14 06:12:40 更新日期 : 2024-03-19 12:12:48
本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com





      01. pytorch对模型指定层的参数进行初始化   




本节展示两种在pytroch中对指定层的参数进行初始化的方法




      pytorch对指定层的参数初始化-方法一(获取层)      


在pytorch的模型初始化时,我们往往需要对某层(包括它的子层)采用特别的参数初始化方法,
例如对第2层我们希望把它全部初始化为0,那么怎么针对这一层进行特别的初始化呢
只需要将该层的特别地拿出来初始化即可
from   torch import nn
import torch
from torch.nn import functional as F
# ---------模型定义-----------
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.L1 = nn.Linear(2,3)                  # 定义第一个线性层模块
        self.L2 = nn.Linear(3,2)                  # 定义第二个线性层模块
        self.L3 = nn.Linear(2,2)                  # 定义第三个线性层模块
                                                  
    def forward(self, x):                         
        y = F.relu(self.L1(x))                    # 第一层的计算
        y = F.relu(self.L2(x))                    # 第二层的计算
        y = self.L3(y)                            # 第三层的计算
        return y                                  
model = Model()                                   # 初始化模型

# ------历遍式初始化参数-------
param_dict = dict(model.named_parameters())       # 获取模型的参数字典
for key in  param_dict:                           # 历遍每个参数,对其初始化
    torch.nn.init.normal_(param_dict[key])        # 将参数初始化为正态分布
    
# -----对L2进行初始化,将参数初始化为0---------------------
param_dict_2 = dict(model.L2.named_parameters())  # 获取第二层的参数
for key in  param_dict_2:                         # 历遍每个参数,对其初始化
    torch.nn.init.zeros_(param_dict_2[key])       # 将参数初始化为0
print(model.state_dict())                         # 打印结果
运行结果如下:
  
 从结果可以看到,L2层的权重阈值都被初始化为0






     pytorch对指定层的参数进行特别初始化-方法二(根据层名)      


在参数初始化过程中,如果要对某些层需要采用特别的初始化方法,
也可以根据参数的名称来进行区别,因为不同层的参数名称的前缀不一样
具体示例如下:
from   torch import nn
import torch
from torch.nn import functional as F
# ---------模型定义-----------
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.L1 = nn.Linear(2,3)                  # 定义第一个线性层模块
        self.L2 = nn.Linear(3,2)                  # 定义第二个线性层模块
        self.L3 = nn.Linear(2,2)                  # 定义第三个线性层模块
                                                  
    def forward(self, x):                         
        y = F.relu(self.L1(x))                    # 第一层的计算
        y = F.relu(self.L2(x))                    # 第二层的计算
        y = self.L3(y)                            # 第三层的计算
        return y                                  
model = Model()                                   # 初始化模型

# ------历遍式初始化参数-------
param_dict = dict(model.named_parameters())       # 获取模型的参数字典
for key in  param_dict:                           # 历遍每个参数,对其初始化
    layer_name = '.'.join(key.split(".")[0:1])    # 将参数第1个点之前的字符作为层名
    if (layer_name=='L2'):                        # 如果是L2的参数
        torch.nn.init.zeros_(param_dict[key])     # 将参数初始化为0
    else:                                         # 否则
        torch.nn.init.normal_(param_dict[key])    # 将参数初始化为正态分布
print(model.state_dict())                         # 打印结果
运行结果如下:
 
 
 从结果可以看到,L2层的权重阈值都被初始化为0










 End 









联系老饼