本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com
在使用pytorch的模型时,我们往往需要获取指定的参数对象进行相关操作,例如对参数进行初始化
因此,本文详细讲解,在pytorch中如何获取模型的所有参数对象,以及如何获取指定的参数对象等等
本节讲解pytorch获取模型参数对象的相关API
pytorch如何获取模型里的参数对象
模块(模型)使用Parameter类来记录模块所包含的参数,Parameter类就只是一个tensor,当成tensor来操作就行
要获取模型里包含的参数对象,一般使用的API如下:
get_parameter :通过参数名称获取参数
parameters :获取模块包含的所有参数对象 ,不带名称
named_parameters :获取模块包含的所有参数对象,带名称
总的来说,get_parameter获得单个参数对象,parameters和named_parameters获得所有参数对象
get_parameter、parameters、named_parameters的使用示例
get_parameter、parameters、named_parameters的使用示例如下:
from torch import nn
from collections import OrderedDict
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.L1 = nn.Linear(3,2)
self.L2 = nn.Sequential(OrderedDict({
'L2': nn.Linear(2,2),
'R': nn.ReLU(inplace=True)
}))
# 通过get_parameter获取模型参数
model = Model() # 初始化模型
param = model.get_parameter('L1.weight') # 知道参数名时,使用get_parameter可直接获得参数对象
print('\n通过get_parameter获得参数L1.weight如下:\n', param.data)
运行结果如下:
可以看到,在获取指定的参数,只需使用 get_parameter即可
parameters、named_parameters的使用示例
parameters、named_parameters的使用示例如下:
from torch import nn
from collections import OrderedDict
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.L1 = nn.Linear(3,2)
self.L2 = nn.Sequential(OrderedDict({
'L2': nn.Linear(2,2),
'R': nn.ReLU(inplace=True)
}))
# 通过parameters与named_parameters获取模型所有参数
model = Model() # 初始化模型
param_list = list(model.parameters()) # 获得所有参数对象(不带参数名),并转为列表
param_dict = dict(model.named_parameters()) # 获得所有参数对象(带参数名),并转为列表
print('\n-----通过parameters获得所有参数列表:-----\n',param_list)
print('\n-----通过named_parameters获得所有参数字典:-----\n',param_dict)
运行结果如下:
可以看到,parameters和named_parameters都能得到模块所有的参数
唯一不同的是,named_parameters获取的参数带有参数名称
本节展示如何按模型的每个子模块获取各个子模块的参数对象
pytorch按子模块获取模型中的参数
如果需要按模块获取模型中的各个模块的参数,那么只需先获得各个子模块,再获取各个子模块的参数就可以了
具体示例如下
from torch import nn
from collections import OrderedDict
# 定义模型
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.L1 = nn.Linear(3,2)
self.L2 = nn.Sequential(OrderedDict({
'L2': nn.Linear(2,2),
'R': nn.ReLU(inplace=True)
}))
model = Model() # 初始化模型
module_dict = dict(model.named_children()) # 提取出所有子模块
for key in module_dict: # 逐模块提取参数
cur_param = dict( module_dict[key].named_parameters()) # 提取当前模块的参数
print('\n----模块'+key+'的参数如下:-----\n',cur_param) # 打印参数
运行结果如下:
可以看到,已经分别独立提取出各个模块的参数
state_dict函数也可以返回模型的参数数据,所以往往也会通过state_dict来查看参数
虽然state_dict返回的并非参数对象本身,但由于较为常用,所以在本节加以介绍
state_dict的使用
state_dict函数也可以获取模型的所有参数,但获取的不是模型的参数对象,而仅仅是参数的数据,
即state_dict返回的是tensor,而不是Parameter,修改state_dict返回的数据并不会影响模型的参数
state_dict的具体使用示例如下:
from torch import nn
from collections import OrderedDict
# 定义模型
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.L1 = nn.Linear(3,2)
self.L2 = nn.Sequential(OrderedDict({
'L2': nn.Linear(2,2),
'R': nn.ReLU(inplace=True)
}))
model = Model() # 初始化模型
state_dict = model.state_dict() # 获取模型的参数字典
print('\n---state_dict:----\n',state_dict) # 打印参数字典
运行结果如下:
可以看到,state_dict以OrderedDict对象返回了所有参数的数据,它存储的是tensor,而不是Parameter
End