本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com
pytorch的自动求导对于模型训练带来了极大的便利,但同时也有较多使用注意事项
例如pytorch如果自动求导时,中间变量不是tensor就会报错
本文通过正反实例的讲解,来提醒使用pytorch自动求导时必须注意所有变量都必须是tensor
本节展示pytorch求导时如果中间变量不是tensor时的报错,及修正后的示例
1. 所有变量必须是tensor变量
在进行复合函数求导时,所有变量都必须是tensor变量,包括中间变量
👉错误示例
如下的例子,会报错:
import torch
import math
x = torch.tensor([3],dtype=(float),requires_grad=True) # 生成一个tensor数据x
y = math.sin(x) # 根据x计算y
z = 3*y # 根据y计算z
z.backward() # 将y反向传播
print('z对x的梯度:',x.grad)
报错如下:AttributeError: 'float' object has no attribute 'backward'
👉正确示例
上述例子报错是因为y=math.sin(x)得到的y并非tensor变量,因此不能进行自动求导
必须使用torch.sin而不是math.sin,
修改后代码如下:
import torch
x = torch.tensor([3],dtype=(float),requires_grad=True) # 生成一个tensor数据x
y = torch.sin(x) # 根据x计算y
z = 3*y # 根据y计算z
z.backward() # 将y反向传播
print('z对x的梯度:',x.grad)
运行结果:
z对x的梯度: tensor([-2.9700], dtype=torch.float64)
好了,以上就是pytorch求导时必须注意变量必须是tensor的示例了~
End