1 import os 2 import imageio 3 from imageio import imread 4 import torch 5 6 # batch_size = 3 7 # batch = torch.zeros(batch_size, 3, 256, 256, dtype=torch.uint8) 8 # batch.shape #torch.Size([3, 3, 256, 256])[B,C,H,W] 9 10 data_dir = 'C:/Users/Dell/Pictures/' 11 filenames = [name for name in os.listdir(data_dir) 12 if os.path.splitext(name)[-1] == '.png'] #选择指定目录下的.png图片 13 14 for i, filename in enumerate(filenames): 15 img_arr = imageio.imread(os.path.join(data_dir, filename)) #imread读入为H*W*C 16 img_t = torch.from_numpy(img_arr) 17 img_t = img_t.permute(2, 0, 1) #交换维度 18 img_t = img_t[:3] #只保留前3个通道 19 batch[i] = img_t #指第i个维度上的所有数据
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Pytorch从一个输入目录中加载所有的PNG图像,并将它们存储在张量中 - Python技术站