先放一张表,可以看成是二维数组
行(列)索引 | 索引0 | 索引1 | 索引2 | 索引3 |
---|---|---|---|---|
索引0 | 0 | 1 | 2 | 3 |
索引1 | 4 | 5 | 6 | 7 |
索引2 | 8 | 9 | 10 | 11 |
索引3 | 12 | 13 | 14 | 15 |
看一下下面例子代码:
针对0维(输出为行形式)
>>> import torch as t
>>> a = t.arange(0,16).view(4,4)
>>> a
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
#选取对角线的元素
>>> index = t.LongTensor([[0,1,2,3]])
>>> a.gather(0,index)
tensor([[ 0, 5, 10, 15]])
如何理解结果呢?其实很简单,就是a.gather(0,index)中第一个0已经表明输出结果是行形式(0维),如果第一个是1说明输出结果是列形式(1维),然后按照index = tensor([[0, 1, 2, 3]])
顺序作用在行上索引依次为0,1,2,3:
- a[0][0] = 0
- a[1][1] = 5
- a[2][2] = 10
- a[3][3] = 15
针对0维
# 选取反对角线上的元素,注意与上面的不同
>>> index2 = t.LongTensor([[3,2,1,0]])
>>> a.gather(0,index2)
tensor([[12, 9, 6, 3]])
如何理解结果呢?同理,按照index = tensor([[3, 2, 1, 0]])
顺序作用在行上索引依次为3,2,1,0:
- a[3][0] = 12
- a[2][1] = 9
- a[1][2] = 6
- a[0][3] = 3
针对1维(输出为列形式)
选取对角线的元素
>>> index3 = t.LongTensor([[0,1,2,3]]).t()
>>> a.gather(1,index3)
tensor([[ 0],
[ 5],
[10],
[15]])
如何理解结果呢?同理,按照index = tensor([[0, 1, 2, 3]])
顺序作用在列上索引依次为0,1,2,3:
- a[0][0] = 0
- a[1][1] = 5
- a[2][2] = 10
- a[3][3] = 15
针对1维
选取反对角线上的元素
>>> index4 = t.LongTensor([[3,2,1,0]]).t()
>>> a.gather(1,index4)
tensor([[ 3],
[ 6],
[ 9],
[12]])
如何理解结果呢?同理,按照index = tensor([[3, 2, 1, 0]])
顺序作用在列上索引依次为3,2,1,0:
- a[0][3] = 3
- a[1][2] = 6
- a[2][1] = 9
- a[3][0] = 12
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Pytorch的gather用法理解 - Python技术站