本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com
在使用pytorch的tensor时,有时需要修改tensor是否需要梯度属性
本文讲解tensor作为叶子节点变量和非叶子节点变量时分别应该如何修改是否需要梯度
在pytorch中修改tensor是否需要梯度,对于叶子节点与非叶子节点的方法是不一样的
下面展示tensor作为叶子节点变量和非叶子节点变量时如何修改是否需要梯度
👉叶子节点变量的修改
叶子节点可以使用requires_grad_直接修改tensors是否需要梯度
示例如下:
import torch
x = torch.tensor([1,2],dtype=(float),requires_grad=True) # 定义参数x
y = x**2
x.requires_grad_(False)
print('x是否带有梯度:',x.requires_grad)
运行结果如下:
x是否带有梯度: False
👉非叶子节点变量的修改
非叶子节点(例如y)不能直接修改,需要用detach将y从图中脱离,再用requires_grad_进行修改
示例如下:
import torch
x = torch.tensor([1],dtype=(float),requires_grad=True) # 定义参数x
y = x**2 # y由x运算后得到
y = y.detach() # 将y从运算图中脱离
y.requires_grad_(True) # 修改y是否需要梯度
print('y是否带有梯度:',y.requires_grad) # 打印结果
运行结果如下:
y是否带有梯度: True
好了,以上就是如何在pytorch中修改tensor是否需要梯度的方法了~
End