Pytorch教程

【示例】pytorch获取模型里的参数

作者 : 老饼 发表日期 : 2024-03-13 15:04:21 更新日期 : 2024-03-24 16:16:30
本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com




在使用pytorch的模型时,我们往往需要获取指定的参数对象进行相关操作,例如对参数进行初始化

因此,本文详细讲解,在pytorch中如何获取模型的所有参数对象,以及如何获取指定的参数对象等等





    01. pytorch获取模型中参数对象的API    



本节讲解pytorch获取模型参数对象的相关API



      pytorch如何获取模型里的参数对象     


模块(模型)使用Parameter类来记录模块所包含的参数,Parameter类就只是一个tensor,当成tensor来操作就行
要获取模型里包含的参数对象,一般使用的API如下:
 get_parameter         :通过参数名称获取参数                                      
 parameters              :获取模块包含的所有参数对象 ,不带名称          
 named_parameters :获取模块包含的所有参数对象,带名称               
总的来说,get_parameter获得单个参数对象,parameters和named_parameters获得所有参数对象





             get_parameter、parameters、named_parameters的使用示例           


get_parameter、parameters、named_parameters的使用示例如下:
from   torch import nn
from collections import OrderedDict
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.L1 = nn.Linear(3,2)   
        self.L2 = nn.Sequential(OrderedDict({
            'L2': nn.Linear(2,2),
            'R': nn.ReLU(inplace=True)
            }))
# 通过get_parameter获取模型参数
model = Model()                                   # 初始化模型
param = model.get_parameter('L1.weight')          # 知道参数名时,使用get_parameter可直接获得参数对象
print('\n通过get_parameter获得参数L1.weight如下:\n', param.data)
运行结果如下:
 
可以看到,在获取指定的参数,只需使用 get_parameter即可





          parameters、named_parameters的使用示例           


parameters、named_parameters的使用示例如下:
from   torch import nn
from collections import OrderedDict
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.L1 = nn.Linear(3,2)   
        self.L2 = nn.Sequential(OrderedDict({
            'L2': nn.Linear(2,2),
            'R': nn.ReLU(inplace=True)
            }))
# 通过parameters与named_parameters获取模型所有参数
model = Model()                                   # 初始化模型
param_list = list(model.parameters())             # 获得所有参数对象(不带参数名),并转为列表
param_dict = dict(model.named_parameters())       # 获得所有参数对象(带参数名),并转为列表
print('\n-----通过parameters获得所有参数列表:-----\n',param_list)
print('\n-----通过named_parameters获得所有参数字典:-----\n',param_dict)
运行结果如下:
  
可以看到,parameters和named_parameters都能得到模块所有的参数
唯一不同的是,named_parameters获取的参数带有参数名称






    02. pytorch如何分模块获取模型的参数对象    



本节展示如何按模型的每个子模块获取各个子模块的参数对象




     pytorch按子模块获取模型中的参数     


 如果需要按模块获取模型中的各个模块的参数,那么只需先获得各个子模块,再获取各个子模块的参数就可以了
 具体示例如下
from   torch import nn
from collections import OrderedDict
# 定义模型
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.L1 = nn.Linear(3,2)   
        self.L2 = nn.Sequential(OrderedDict({
            'L2': nn.Linear(2,2),
            'R': nn.ReLU(inplace=True)
            }))
model = Model()                                            # 初始化模型
module_dict = dict(model.named_children())                 # 提取出所有子模块
for key in module_dict:                                    # 逐模块提取参数
    cur_param = dict( module_dict[key].named_parameters()) # 提取当前模块的参数
    print('\n----模块'+key+'的参数如下:-----\n',cur_param)   # 打印参数
运行结果如下:
  
可以看到,已经分别独立提取出各个模块的参数







     03. 关于state_dict     



state_dict函数也可以返回模型的参数数据,所以往往也会通过state_dict来查看参数

虽然state_dict返回的并非参数对象本身,但由于较为常用,所以在本节加以介绍



     state_dict的使用    


state_dict函数也可以获取模型的所有参数,但获取的不是模型的参数对象,而仅仅是参数的数据,
即state_dict返回的是tensor,而不是Parameter,修改state_dict返回的数据并不会影响模型的参数
 
 state_dict的具体使用示例如下:
from   torch import nn
from collections import OrderedDict
# 定义模型
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.L1 = nn.Linear(3,2)   
        self.L2 = nn.Sequential(OrderedDict({
            'L2': nn.Linear(2,2),
            'R': nn.ReLU(inplace=True)
            }))
model = Model()                                            # 初始化模型
state_dict = model.state_dict()                            # 获取模型的参数字典
print('\n---state_dict:----\n',state_dict)                 # 打印参数字典
运行结果如下:
  
可以看到,state_dict以OrderedDict对象返回了所有参数的数据,它存储的是tensor,而不是Parameter









 End 








联系老饼