@
index索引
torch会自动从左向右索引
例子:
a = torch.randn(4,3,28,28)
表示类似一个CNN 的图片的输入数据,4表示这个batch一共有4张照片,而3表示图片的通道数为3(RGB),(28,28)表示图片的大小
基本索引
索引1:表示第零张图片的shape
print(a[0].shape)
#torch.Size([3,28,28])
索引2:第零张图片的第零个通道的size
print(a[0,0].shape)
#torch.Size([28,28])
索引3:表示第零张图片的第零个通道的第二行第四列的像素点的值
print(a[0,0,2,4])
#tensor(0.8082)
连续选取
⭐索引4:连续取两张图片(取第0张以及第一张图片,不包括第二张
)
print(a[:2].shape
#torch.Size([2,3,28,28])
#由于是两张图片,所以第一维变为2
⭐索引5:前两张图片上的第一个通道上的数据(所以通道数变为了1)
print(a[:2,:1,:,:].shape)
print(a[:2,:1].shape)
#torch.Size(2,1,28,28)
⭐索引6:从后面取(-1表示最后一个,从最后一个取到最后,也就是一个通道)
print(a[:2,-1:,:,:].shape)
#torch.Size(2,1,28,28)
规则间隔索引
⭐索引7:在图片的矩阵进行隔行与隔列索引 0:28:2表示从0到28(不包括28),间隔数为2
print(a[:,:,0:28:2,0:28:2].shape)
print(a[:,:,::2,::2].shape)
#torch.Size([4,3,14,14])
索引总结
start : end : step
:
都取
x:
从x取到最后 :x
从开始取到x x:y
从x取到y
x:y:z
从x到y每隔z个点采样一次
不规则间隔索引
使用index_select()函数
第一个参数表示你对哪个维度进行操作;第二个参数是index(必须是tensor类型
):对第0张与第2张图片进行操作
a.index_select(0,torch.tensor([0,2])).shape
#【2,3,28,28】
同理:选择了两个通道
a.index_select(1,torch.tensor([1,2])).shape
#【4,2,28,28】
同理:只取8行
a.index_select(2,torch.arange(8)).shape
#【4,2,8,28】
任意多的维度索引
使用符号:...
例子:
a[...].shape
#[4,3,28,28]
a[0,...].shape
#[3,28,28]
a[0,1,...].shape
#[4,28,28]
a[...,2].shape
#[4,3,28,2]
使用掩码来索引
函数:.masked_select()
会将筛选出来的元素打平(因为无法维护原来的shape)
x = torch.randn(2,3)
print(x)
tensor([[-1.3081, -0.5651, -0.9843],
[ 1.0051, -0.3829, 0.6300]])
mask = x.ge(0.5)#大于等于0.5的元素
print(mask)
tensor([[False, False, False],
[ True, False, True]])
z = torch.masked_select(x,mask)
print(z)
tensor([1.0051, 0.6300])
打平后的索引
例子:使用take函数:是将输入的tensor打平之后进行index的选择
src = torch.tensor([[4,3,5],[6,7,8]])
torch.take(src,torch.tensor([0,2,8]))
#tensor([4,5,8])
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch索引与切片 - Python技术站