tensor索引与numpy类似,支持冒号,和数字直接索引
import torch
a = torch.Tensor(2, 3, 4)
a
# 输出:
tensor([[[9.2755e-39, 1.0561e-38, 9.7347e-39, 1.1112e-38],
[1.0194e-38, 8.4490e-39, 1.0102e-38, 9.0919e-39],
[1.0102e-38, 8.9082e-39, 8.4489e-39, 1.0102e-38]],
[[1.0561e-38, 1.0286e-38, 1.0653e-38, 1.0469e-38],
[9.5510e-39, 9.9184e-39, 9.0000e-39, 1.0561e-38],
[1.0653e-38, 4.1327e-39, 8.9082e-39, 9.8265e-39]]])
# 冒号索引与数字索引
a[:1, :2, 1]
# 输出:
tensor([[1.0561e-38, 8.4490e-39]])
# 通过-1索引
a[-1]
# 输出:
tensor([[1.0561e-38, 1.0286e-38, 1.0653e-38, 1.0469e-38],
[9.5510e-39, 9.9184e-39, 9.0000e-39, 1.0561e-38],
[1.0653e-38, 4.1327e-39, 8.9082e-39, 9.8265e-39]])
...(三个点)索引
用于维度过多,且取中间多个维度所有数据的情况
# 生成多维数据
a = torch.rand(1,2,3,2,4,5)
a
# 输出:
tensor([[[[[[0.1954, 0.1918, 0.3053, 0.3649, 0.3637],
[0.8467, 0.0205, 0.2187, 0.8438, 0.1754],
[0.7076, 0.7047, 0.1852, 0.5374, 0.7024],
[0.5630, 0.4526, 0.0662, 0.9463, 0.9294]],
[[0.6917, 0.5505, 0.5770, 0.3819, 0.9541],
[0.8957, 0.2530, 0.4858, 0.1866, 0.2542],
[0.3745, 0.2125, 0.5537, 0.5642, 0.2284],
[0.2634, 0.1147, 0.1793, 0.0277, 0.9800]]],
...
[[[0.9949, 0.2210, 0.3365, 0.0852, 0.4387],
[0.6440, 0.6391, 0.9141, 0.2288, 0.6203],
[0.0474, 0.7894, 0.4362, 0.9752, 0.7546],
[0.1234, 0.0246, 0.1436, 0.0053, 0.3405]],
[[0.8174, 0.9021, 0.0420, 0.2045, 0.2140],
[0.4844, 0.6342, 0.2965, 0.9299, 0.2284],
[0.1420, 0.1834, 0.0581, 0.8467, 0.8987],
[0.8012, 0.1526, 0.4293, 0.3928, 0.5437]]]]]])
# 取第一维和最后一维的0索引数据,中间所有维度数据全部取出
a[0, ..., 0]
# 输出:
tensor([[[[0.1954, 0.8467, 0.7076, 0.5630],
[0.6917, 0.8957, 0.3745, 0.2634]],
[[0.4374, 0.0534, 0.6809, 0.7086],
[0.2231, 0.6680, 0.8643, 0.9057]],
[[0.8169, 0.0649, 0.5923, 0.3802],
[0.2562, 0.0095, 0.8557, 0.6828]]],
[[[0.1514, 0.3948, 0.6452, 0.6332],
[0.8872, 0.7304, 0.6853, 0.9814]],
[[0.5736, 0.5195, 0.9711, 0.5575],
[0.6778, 0.9334, 0.5647, 0.1006]],
[[0.9949, 0.6440, 0.0474, 0.1234],
[0.8174, 0.4844, 0.1420, 0.8012]]]])
# 上面等价于
a[0,:,:,:,:,0]
# 输出:
tensor([[[[0.1954, 0.8467, 0.7076, 0.5630],
[0.6917, 0.8957, 0.3745, 0.2634]],
[[0.4374, 0.0534, 0.6809, 0.7086],
[0.2231, 0.6680, 0.8643, 0.9057]],
[[0.8169, 0.0649, 0.5923, 0.3802],
[0.2562, 0.0095, 0.8557, 0.6828]]],
[[[0.1514, 0.3948, 0.6452, 0.6332],
[0.8872, 0.7304, 0.6853, 0.9814]],
[[0.5736, 0.5195, 0.9711, 0.5575],
[0.6778, 0.9334, 0.5647, 0.1006]],
[[0.9949, 0.6440, 0.0474, 0.1234],
[0.8174, 0.4844, 0.1420, 0.8012]]]])
可以看出,使用...可以节省操作。
masked_select
# 生成随机数据
a = torch.randn(3, 4)
a
# 输出:
tensor([[ 0.8710, 0.8862, -0.4620, -0.9985],
[ 0.4734, -0.7182, -0.1516, 0.0209],
[ 0.5089, -0.8130, -0.4519, -0.6190]])
# 大于0.5的数据返回True
mask = a.ge(0.5)
mask
# 输出:
tensor([[ True, True, False, False],
[False, False, False, False],
[ True, False, False, False]])
# 通过上面生成的bool数据,利用masked_select来选择大于0.5的数据
torch.masked_select(a, mask)
# 输出:
tensor([0.8710, 0.8862, 0.5089])
take
a
# 输出:
tensor([[ 0.8710, 0.8862, -0.4620, -0.9985],
[ 0.4734, -0.7182, -0.1516, 0.0209],
[ 0.5089, -0.8130, -0.4519, -0.6190]])
# 先将数据打平展开为一维,再选取展开后对应索引[0, 5, 8, 11]的数据
torch.take(a, torch.tensor([0, 5, 8, 11]))
# 输出:
tensor([ 0.8710, -0.7182, 0.5089, -0.6190])
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch tensor的索引与切片 - Python技术站