Pytorch教程

【注意】pytorch求导注意-指数运算符

作者 : 老饼 发表日期 : 2023-11-24 12:38:25 更新日期 : 2024-01-19 08:30:29
本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com



pytorch的自动求导对于模型训练带来了极大的便利,但同时也有较多使用注意事项

例如pytorch如果自动求导时,在指数运算时使用"^"就会报错

本文通过正反实例的讲解,来提醒使用pytorch自动求导时指数运算需要使用"**"




一、pytorch求导注意事项-指数运算符需要使用**   



本文展示pytorch求导时使用指数运算符时的注意事项



01. 指数运算符需要使用**


在pytorch自动求导时,往往表达式中需要使用到指数运算,

但需要注意的是,需要使用“**”作为指数运算,如果使用"^"进行指数运算会出报错,


👉错误示例

pytorch自动求导时,使用"^"进行指数运算会出现报错

示例如下:

import torch 
x = torch.tensor([1],dtype=(float),requires_grad=True)   # 生成一个tensor数据x
y = x+x^2                                                # 计算y
y.backward()                                             # 将y反向传播
print('y对x的梯度:',x.grad)                              # 打印x的梯度  

报错如下:

RuntimeError: "bitwise_xor_cpu" not implemented for 'Double'


👉正确示例

上面的例子报错主要是因为指数运算符使用了“^”,

将上述例子的指数运算符改为"**"后,就能顺利求导,

示例如下:

import torch 
x = torch.tensor([1],dtype=(float),requires_grad=True)   # 生成一个tensor数据x
y = x+x**2                                                # 计算y
y.backward()                                             # 将y反向传播
print('y对x的梯度:',x.grad)                              # 打印x的梯度 

运行结果:

y对x的梯度: tensor([3.], dtype=torch.float64)






好了,以上就是pytorch自动求导时必须注意指数运算符使用**的示例了~







 End 




联系老饼