本站原创文章,转载请说明来自《老饼讲解-深度学习》www.bbbdata.com
在pytorch的应用中,往往需要找出tensor中符合条件的元素进行修改
本文讲解如何在找出tensor中符合条件的元素的索引,
并展示怎么利用索引对tensor的数据进一步进行修改
本节讲解pytorch找出tensor中符合条件的元素的索引的两种方法
1.查找tensor中符合条件的元素的索引-where方法
pytorch可以使用where函数找出tensor中符合条件的元素的索引
示例如下:
import torch
torch.manual_seed(99)
t = torch.arange(0, 9).view(3, 3) # 生成一个3*3的tensor
idx = torch.where(t>4) # 找出tensor中符合>4的元素的索引
print('t:',t) # 打印tensor
print('符合条件>4的索引idx:',idx) # 打印索引
运行结果如下:
t: tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
符合条件>4的索引idx: (tensor([1, 2, 2, 2]), tensor([2, 0, 1, 2]))
✍️解说:
上述的where函数返回的索引分为多个tensor,
idx[0]=[1,2,2,2]代表符合条件的元素的第0维索引,
idx[1]=[2,0,1,2]代表符合条件的元素的第1维索引,
即符合条件>4的元素索引为:[1,2],[2,0],[2,1],[2,2]
2.查找tensor中符合条件的元素的索引-argwhere方法
pytorch也可以使用argwhere函数找出tensor中符合条件的元素的索引
import torch
torch.manual_seed(99)
t = torch.arange(0, 9).view(3, 3) # 生成一个3*3的tensor
idx = torch.argwhere(t>4) # 找出tensor中符合>4的元素的索引
print('t:',t) # 打印tensor
print('符合条件>4的索引idx:',idx) # 打印索引
运行结果如下:
t: tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
符合条件>4的索引idx:
tensor([[1, 2],
[2, 0],
[2, 1],
[2, 2]])
本节讲解pytorch如何修改tensor中符合条件的元素的数据
1.根据条件修改tensor的值
往往需要修改矩阵中符合某个条件的数据,
那么在pytorch的tensor中可以借助where函数来实现这一效果
下面展示一个常用的条件索引的例子,用于学习和借鉴
如下述例子,将tensor中大于4的元素全部改成999:
import torch
torch.manual_seed(99)
t = torch.arange(0, 9).view(3, 3) # 生成一个3*3的tensor
print('修改前的t:',t) # 打印修改前的tensor
idx = torch.where(t>4) # 找出tensor中符合>4的元素的索引
print('符合条件的索引idx:',idx) # 打印索引
t[idx] = 999 # 根据索引修改tensor的值
print('修改后的t:',t) # 打印修改后的tensor
运行结果如下:
修改前的t: tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
符合条件的索引idx: (tensor([1, 2, 2, 2]), tensor([2, 0, 1, 2]))
修改后的t: tensor([[ 0, 1, 2],
[ 3, 4, 999],
[999, 999, 999]])
2.根据条件填充tensor的值
往往只是需要简单地根据条件填充tensor的值,
那么可以在where中加入填充条件,具体如下
示例如下:
import torch
torch.manual_seed(99)
t = torch.arange(0, 9).view(3, 3) # 生成一个3*3的tensor
x = torch.where(t>4,1,0) # t>4是条件,成立则填充为1,否则为0
print(t)
print(x)
运行结果如下
tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
tensor([[0, 0, 0],
[0, 0, 1],
[1, 1, 1]])
好了,以上就是pytorch中tensor的条件索引以及根据条件修改tensor的方法了~
End