本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com
本节讲解pytorch中是怎么对模型的参数进行初始化的
如何对pytorch模型的参数进行初始化
pytorch对模型的参数初始化是面向每一个参数的,
所以只要拿出每一个参数对象,然后对其初始化就可以
对于单个参数的初始化方法,可以采取如下两种方式
👉1. 直接赋值
👉2. 采用pytorch提供的初始化函数进行初始化
由于pytorch已经提供了足够多的常用初始化函数,使用起来较为方便
所以一般来说,都是采用pytorch提供的初始化函数进行初始化
关于pytorch提供的初始化参数
pytorch提供了各种各样常用的参数初始化函数,例如将参数初始化为常数、正态分布等等
pytorch提供的相关方法如下:
常数初始化
初始化为指定常数 :torch.nn.init.constant_(tensor, val)
初始化全为1的常数:torch.nn.init.ones_(tensor)
初始化全为0的常数:torch.nn.init.zeros_(tensor)
初始化为单位矩阵 :torch.nn.init.eye_(tensor) :
一般的随机概率分布初始化方法
均匀分布:torch.nn.init.uniform_(tensor, a=0.0, b=1.0, generator=None)
正态分布:torch.nn.init.normal_(tensor, mean=0.0, std=1.0, generator=None)
截断正态分布:torch.nn.init.trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0)
稀疏正态分布初始化:torch.nn.init.sparse_(tensor, sparsity, std=0.01)
Xavier初始化、凯明初始化与正交初始化
Xavier均匀分布:torch.nn.init.xavier_uniform_(tensor, gain=1.0, generator=None)
Xavier正态分布:torch.nn.init.xavier_normal_(tensor, gain=1.0, generator=None)
凯明均匀分布 :torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
凯明正态分布 :torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
正交初始化 :torch.nn.init.orthogonal_(tensor, gain=1, generator=None)
pytorch对相关初始化API的说明文档: 《pytorch的官方初始化API说明》
笔者对相关初始化API的使用说明:《pytorch官方提供的各种参数初始化方法》
本节展示一个示例,用于说明在pytorch是如何对模型参数进行初始化的
pytorch初始化参数示例
下面我们展示一个对pytorch模型里的参数进行初始化的例子
在本例子里,我们使用了自行直接赋值的方法以及采用pytorch提供的初始化函数进行赋值的方法
一个初始化的示例如下:
from torch import nn
import torch
from torch.nn import functional as F
# ---------模型定义-----------
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.L1 = nn.Linear(2,3) # 定义第一个线性层模块
self.L2 = nn.Linear(3,2) # 定义第二个线性层模块
def forward(self, x):
y = F.relu(self.L1(x)) # 第一层的计算
y = self.L2(y) # 第二层的计算
return y
model = Model() # 初始化模型
# ------历遍式初始化参数-------
param_dict = dict(model.named_parameters()) # 获取模型的参数字典
for key in param_dict: # 历遍每个参数,对其初始化
torch.nn.init.normal_(param_dict[key]) # 这里展示用pytorch提供的初始化函数(正态分布)对参数初始化
# -----单个参数的初始化---------------------
param_dict['L1.weight'].data = torch.ones(3,2) # 这里展示直接赋值的方式对参数初始化
print(model.state_dict()) # 打印结果
End