本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com
本节展示两种按层的类型对模型参数进行初始化的方法
pytorch初始化参数-按层类型(方法一)
在对pytorch的模型参数进行初始化时,往往需要按层的类型进行不同的初始化方法
例如线性层使用一种初始化方法,卷积层使用一种初始化方法等等
按层的类型进行初始化,只需要将模型的每一个模块都进行历遍,然后根据模块的层类型进行初始化就可以
具体示例如下:
from torch import nn
import torch
# ---------模型定义-----------
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.C1 = nn.Conv2d(1,2, kernel_size=3,stride=1,padding=1)
self.F = nn.Flatten()
self.L2 = nn.Linear(2,3)
def forward(self, x):
y = self.F(self.C1(x))
y = self.L2(y)
return y
# -------模型参数初始化----------------------
def init_param(model):
# 对模型的每一个模块进行历遍,如果是卷积层或线性层就初始化参数
for m in model.modules():
if isinstance(m,nn.Conv2d): # 如果是卷积层
torch.nn.init.normal_(m.weight.data,0.1) # 卷积层权重的初始化方法
torch.nn.init.constant_(m.bias.data,1.) # 卷积层阈值的初始化方法
elif isinstance(m,nn.Linear): # 如果是线性层
torch.nn.init.normal_(m.weight.data,0.1) # 线性层权重的初始化方法
torch.nn.init.zeros_(m.bias.data) # 线性层阈值的初始化方法
# ------展示效果--------------------------
model = Model() # 初始化模型
init_param(model) # 初始化模型参数
print(model.state_dict()) # 打印结果
运行结果如下:
pytorch初始化参数-按层类型(方法二)
第二种方法是只写模块的初始化函数,然后使用apply函数将模型的每一个模块都使用该函数进行初始化,
具体示例如下:
from torch import nn
import torch
# ---------模型定义-----------
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.C1 = nn.Conv2d(1,2, kernel_size=3,stride=1,padding=1)
self.F = nn.Flatten()
self.L2 = nn.Linear(2,3)
def forward(self, x):
y = self.F(self.Cov(x))
y = self.L2(y)
return y
# -------模块的参数初始化方法----------------------
def init_module_param(m):
if isinstance(m,nn.Conv2d): # 如果是卷积层
torch.nn.init.normal_(m.weight.data,0.1) # 卷积层权重的初始化方法
torch.nn.init.constant_(m.bias.data,1.) # 卷积层阈值的初始化方法
elif isinstance(m,nn.Linear): # 如果是线性层
torch.nn.init.normal_(m.weight.data,0.1) # 线性层权重的初始化方法
torch.nn.init.zeros_(m.bias.data) # 线性层阈值的初始化方法
# ------展示效果--------------------------
model = Model() # 初始化模型
model.apply(init_module_param) # 将model的每一个module都应用init_module_param进行初始化
print(model.state_dict()) # 打印结果
运行结果如下:
特别说明
上述初始化参数的方法中,使用到了model.modules()来获取模型的所有层,
需要注意的是,model.modules()获取到的是所有的层,包括根节点层、叶子层等等,
示例如下:
上述模型的model.modules()包含了5个元素,如下
所以在model.modules()的历遍中,并非只历遍最终的叶子层,而是历遍所有的层
因为,我们在历遍过程中,需要注意,初始化方法只需针对属于叶子层的类型就可以,
而对于非叶子层类型,我们需要直接忽略,不要再进行处理
End