本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com
本节展示两种在pytroch中对指定层的参数进行初始化的方法
pytorch对指定层的参数初始化-方法一(获取层)
在pytorch的模型初始化时,我们往往需要对某层(包括它的子层)采用特别的参数初始化方法,
例如对第2层我们希望把它全部初始化为0,那么怎么针对这一层进行特别的初始化呢
只需要将该层的特别地拿出来初始化即可
from torch import nn
import torch
from torch.nn import functional as F
# ---------模型定义-----------
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.L1 = nn.Linear(2,3) # 定义第一个线性层模块
self.L2 = nn.Linear(3,2) # 定义第二个线性层模块
self.L3 = nn.Linear(2,2) # 定义第三个线性层模块
def forward(self, x):
y = F.relu(self.L1(x)) # 第一层的计算
y = F.relu(self.L2(x)) # 第二层的计算
y = self.L3(y) # 第三层的计算
return y
model = Model() # 初始化模型
# ------历遍式初始化参数-------
param_dict = dict(model.named_parameters()) # 获取模型的参数字典
for key in param_dict: # 历遍每个参数,对其初始化
torch.nn.init.normal_(param_dict[key]) # 将参数初始化为正态分布
# -----对L2进行初始化,将参数初始化为0---------------------
param_dict_2 = dict(model.L2.named_parameters()) # 获取第二层的参数
for key in param_dict_2: # 历遍每个参数,对其初始化
torch.nn.init.zeros_(param_dict_2[key]) # 将参数初始化为0
print(model.state_dict()) # 打印结果
运行结果如下:
从结果可以看到,L2层的权重阈值都被初始化为0
pytorch对指定层的参数进行特别初始化-方法二(根据层名)
在参数初始化过程中,如果要对某些层需要采用特别的初始化方法,
也可以根据参数的名称来进行区别,因为不同层的参数名称的前缀不一样
具体示例如下:
from torch import nn
import torch
from torch.nn import functional as F
# ---------模型定义-----------
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.L1 = nn.Linear(2,3) # 定义第一个线性层模块
self.L2 = nn.Linear(3,2) # 定义第二个线性层模块
self.L3 = nn.Linear(2,2) # 定义第三个线性层模块
def forward(self, x):
y = F.relu(self.L1(x)) # 第一层的计算
y = F.relu(self.L2(x)) # 第二层的计算
y = self.L3(y) # 第三层的计算
return y
model = Model() # 初始化模型
# ------历遍式初始化参数-------
param_dict = dict(model.named_parameters()) # 获取模型的参数字典
for key in param_dict: # 历遍每个参数,对其初始化
layer_name = '.'.join(key.split(".")[0:1]) # 将参数第1个点之前的字符作为层名
if (layer_name=='L2'): # 如果是L2的参数
torch.nn.init.zeros_(param_dict[key]) # 将参数初始化为0
else: # 否则
torch.nn.init.normal_(param_dict[key]) # 将参数初始化为正态分布
print(model.state_dict()) # 打印结果
运行结果如下:
从结果可以看到,L2层的权重阈值都被初始化为0
End