参考:pytorch torchvision transform官方文档

Pytorch学习--编程实战:猫和狗二分类

深度学习框架PyTorch一书的学习-第五章-常用工具模块

# coding:utf8
import os
from PIL import Image
from torch.utils import data
import numpy as np
from torchvision import transforms as T


class DogCat(data.Dataset):

    def __init__(self, root, transforms=None, train=True, test=False):
        ''''''
        '''
        主要目标: 获取所有图片的地址,并根据训练,验证,测试划分数据
        '''
        self.test = test
        imgs = [os.path.join(root, img) for img in os.listdir(root)]

        # test1: data/test1/8973.jpg
        # train: data/train/cat.10004.jpg
        if self.test:
            imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2].split('/')[-1]))
        else:
            imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2]))

        imgs_num = len(imgs)

        # shuffle imgs 打乱图片顺序
        np.random.seed(100)
        imgs = np.random.permutation(imgs)

        #训练集和验证集比例为7:3
        if self.test:
            self.imgs = imgs
        elif train:
            self.imgs = imgs[:int(0.7 * imgs_num)]  #训练集
        else:
            self.imgs = imgs[int(0.7 * imgs_num):]  #验证集

        if transforms is None:
            normalize = T.Normalize(mean=[0.485, 0.456, 0.406], #数据归一化处理
                                    std=[0.229, 0.224, 0.225])

            if self.test or not train:  #测试集+验证集
                self.transforms = T.Compose([   #用来管理各个transform
                    T.Scale(224),#将输入的`PIL.Image`重新改变大小成给定的`size`,`size`是最小边的边长。
                    举个例子,如果原图的`height>width`,那么改变大小后的图片大小是
                    `(size*height/width, size)`。 T.CenterCrop(224), #以输入图像img的中心作为中心点进行指定size的crop操作 T.ToTensor(), #在做数据归一化之前必须要把PIL Image转成Tensor normalize #数据归一化处理 ]) else: #训练集 self.transforms = T.Compose([ T.Scale(256), T.RandomSizedCrop(224),#先将给定的PIL.Image随机切,然后再resize成给定的size大小。 T.RandomHorizontalFlip(),#随机水平翻转给定的PIL.Image,概率为0.5。即:
            一半的概率翻转,一半的概率不翻转。 T.ToTensor(), normalize ]) def __getitem__(self, index):
'''''' ''' 一次返回一张图片的数据,并为训练集和验证集打标签 ''' img_path = self.imgs[index] if self.test: #测试集 label = int(self.imgs[index].split('.')[-2].split('/')[-1]) else: #验证集 训练集 label = 1 if 'dog' in img_path.split('/')[-1] else 0 data = Image.open(img_path) data = self.transforms(data) return data, label def __len__(self): return len(self.imgs)