本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com
pytorch用backward自动求梯度时,默认情况下是不保留中间变量的梯度的
本文讲解如何让pytorch在backward后保留中间变量的梯度,以方便实际使用所需
本节展示pytorch在backward后中间变量的默认值,以及以何保留中间变量的梯度
在backward后,默认只会保留叶子节点的梯度,
这是由于pytorch是动态运算图,为了节省内存,中间结果的梯度会被释放
默认情况下,对于非叶子节点(中间变量),在backward后是没有梯度的
👉backward后中间变量的梯度
下面展示一个例子,直接在backward后查看中间变量的梯度
示例如下:
import torch
x = torch.tensor([1,2],dtype=(float),requires_grad=True) # 定义参数x
A = torch.tensor([2,3],dtype=(float)) # 常量A,不需梯度
y = A@(x*x) # 定义y
z = 5*y # 定义z
z.backward() # 向后传播
print('x.grad:',x.grad) # 打印x的梯度
print('y.grad:',y.grad) # 打印y的梯度
运行结果如下:
x.grad: tensor([20., 60.], dtype=torch.float64)
y.grad: None
可以看到,y作为中间变量,是没有梯度的
👉pytorch如何在backward后保留中间变量的梯度
如果希望pytorch在backward后保留中间变量的梯度,需要用retain_grad来指定保留梯度
示例如下:
import torch
x = torch.tensor([1,2],dtype=(float),requires_grad=True) # 定义参数x
A = torch.tensor([2,3],dtype=(float)) # 常量A,不需梯度
y = A@(x*x) # 定义y
y.retain_grad() # 指明保留y的梯度
z = 5*y # 定义z
z.backward() # 向后传播
print('x.grad:',x.grad) # 打印x的梯度
print('y.grad:',y.grad) # 打印y的梯度
运行结果如下:
x.grad: tensor([20., 60.], dtype=torch.float64)
y.grad: tensor(5., dtype=torch.float64)
好了,以上就是pytorch如何在backward后保留中间变量的梯度的方法了~
End