Pytorch教程

【常用】pytorch在backward后保留中间变量的梯度

作者 : 老饼 发表日期 : 2023-11-26 08:19:09 更新日期 : 2024-06-24 15:27:23
本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com



pytorch用backward自动求梯度时,默认情况下是不保留中间变量的梯度的

本文讲解如何让pytorch在backward后保留中间变量的梯度,以方便实际使用所需




一、 pytorch在backward后保留中间变量的梯度  



本节展示pytorch在backward后中间变量的默认值,以及以何保留中间变量的梯度



01、pytorch如何保留中间变量的梯度

在backward后,默认只会保留叶子节点的梯度,

这是由于pytorch是动态运算图,为了节省内存,中间结果的梯度会被释放

默认情况下,对于非叶子节点(中间变量),在backward后是没有梯度的


👉backward后中间变量的梯度

下面展示一个例子,直接在backward后查看中间变量的梯度

示例如下:

import torch
x = torch.tensor([1,2],dtype=(float),requires_grad=True)   # 定义参数x
A = torch.tensor([2,3],dtype=(float))                      # 常量A,不需梯度
y = A@(x*x)                                                # 定义y
z = 5*y                                                    # 定义z 
z.backward()                                               # 向后传播
print('x.grad:',x.grad)                                    # 打印x的梯度
print('y.grad:',y.grad)                                    # 打印y的梯度

运行结果如下:

x.grad: tensor([20., 60.], dtype=torch.float64)
y.grad: None

可以看到,y作为中间变量,是没有梯度的


👉pytorch如何在backward后保留中间变量的梯度

如果希望pytorch在backward后保留中间变量的梯度,需要用retain_grad来指定保留梯度

示例如下:

import torch
x = torch.tensor([1,2],dtype=(float),requires_grad=True)   # 定义参数x
A = torch.tensor([2,3],dtype=(float))                      # 常量A,不需梯度
y = A@(x*x)                                                # 定义y
y.retain_grad()                                            # 指明保留y的梯度
z = 5*y                                                    # 定义z 
z.backward()                                               # 向后传播
print('x.grad:',x.grad)                                    # 打印x的梯度
print('y.grad:',y.grad)                                    # 打印y的梯度

运行结果如下:

x.grad: tensor([20., 60.], dtype=torch.float64)
y.grad: tensor(5., dtype=torch.float64)






好了,以上就是pytorch如何在backward后保留中间变量的梯度的方法了~






 End 





联系老饼