本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com
pytorch的自动求导对于模型训练带来了极大的便利,但同时也有较多使用注意事项
例如pytorch如果自动求导时,在指数运算时使用"^"就会报错
本文通过正反实例的讲解,来提醒使用pytorch自动求导时指数运算需要使用"**"
本文展示pytorch求导时使用指数运算符时的注意事项
在pytorch自动求导时,往往表达式中需要使用到指数运算,
但需要注意的是,需要使用“**”作为指数运算,如果使用"^"进行指数运算会出报错,
👉错误示例
pytorch自动求导时,使用"^"进行指数运算会出现报错
示例如下:
import torch
x = torch.tensor([1],dtype=(float),requires_grad=True) # 生成一个tensor数据x
y = x+x^2 # 计算y
y.backward() # 将y反向传播
print('y对x的梯度:',x.grad) # 打印x的梯度
报错如下:
RuntimeError: "bitwise_xor_cpu" not implemented for 'Double'
👉正确示例
上面的例子报错主要是因为指数运算符使用了“^”,
将上述例子的指数运算符改为"**"后,就能顺利求导,
示例如下:
import torch
x = torch.tensor([1],dtype=(float),requires_grad=True) # 生成一个tensor数据x
y = x+x**2 # 计算y
y.backward() # 将y反向传播
print('y对x的梯度:',x.grad) # 打印x的梯度
运行结果:
y对x的梯度: tensor([3.], dtype=torch.float64)
好了,以上就是pytorch自动求导时必须注意指数运算符使用**的示例了~
End