Pytorch教程

【注意】pytorch求导注意-变量必须是tensor

作者 : 老饼 发表日期 : 2023-11-30 04:05:58 更新日期 : 2024-06-24 15:28:49
本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com



pytorch的自动求导对于模型训练带来了极大的便利,但同时也有较多使用注意事项

例如pytorch如果自动求导时,中间变量不是tensor就会报错

本文通过正反实例的讲解,来提醒使用pytorch自动求导时必须注意所有变量都必须是tensor




一. pytorch自动求导时的注意事项-变量必须是tensor   



本节展示pytorch求导时如果中间变量不是tensor时的报错,及修正后的示例



1. 所有变量必须是tensor变量

在进行复合函数求导时,所有变量都必须是tensor变量,包括中间变量


👉错误示例

如下的例子,会报错:

import torch
import math 
x = torch.tensor([3],dtype=(float),requires_grad=True)   # 生成一个tensor数据x
y = math.sin(x)                                                  # 根据x计算y
z = 3*y                                                  # 根据y计算z
z.backward()                                             # 将y反向传播
print('z对x的梯度:',x.grad)  

报错如下:AttributeError: 'float' object has no attribute 'backward'


👉正确示例

上述例子报错是因为y=math.sin(x)得到的y并非tensor变量,因此不能进行自动求导

必须使用torch.sin而不是math.sin,

修改后代码如下:

import torch
x = torch.tensor([3],dtype=(float),requires_grad=True)   # 生成一个tensor数据x
y = torch.sin(x)                                                  # 根据x计算y
z = 3*y                                                  # 根据y计算z
z.backward()                                             # 将y反向传播
print('z对x的梯度:',x.grad)         

运行结果:

z对x的梯度: tensor([-2.9700], dtype=torch.float64)






好了,以上就是pytorch求导时必须注意变量必须是tensor的示例了~







 End 






联系老饼