猫狗分类,练手级代码。与手写数字识别相比,主要修改了两处:一是输出全连接层,将输出维度由 10(十个数字)改成 2(猫、狗二分类);二是数据集的处理,因为 PyTorch 没有为该数据集提供内置的加载函数,所以图片需要自己读取和预处理。

数据要用opencv处理,归一化。

数据集目录结构:

    data
    ├── train
    │   ├── Cat
    │   └── Dog
    └── test
        ├── Cat
        └── Dog

get_data.py

import os
import cv2
import time
from torchvision import transforms
import torch
# Preprocessing applied to every loaded image: ToTensor turns an HxWxC uint8
# array into a CxHxW float tensor scaled to [0, 1]; Normalize with mean 0.5 and
# std 0.5 per channel then maps those values to [-1, 1].
trans=transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((.5,.5,.5),(.5,.5,.5))
    ]
)
# Root directory that contains the train/ and test/ sub-directories.
DATA_PATH = './data/'
# Side length, in pixels, that every image is resized to (square).
PIC_SIZE = 32


def _report_progress(clock):
    """Print the elapsed time if more than 20 seconds passed since the last report."""
    now = time.time()
    if now - clock['last'] > 20:
        clock['last'] = now
        print('Take %d seconds' % (now - clock['start']))


def _read_resized(path):
    """Return the image at ``path`` resized to PIC_SIZE x PIC_SIZE, or None if unusable."""
    image = cv2.imread(path)
    if image is None:
        # cv2.imread reports an unreadable or corrupt file by returning None.
        return None
    try:
        return cv2.resize(image, (PIC_SIZE, PIC_SIZE))
    except cv2.error:
        # Only OpenCV failures mark the file as invalid; KeyboardInterrupt and
        # other unrelated exceptions must propagate instead of deleting data.
        return None


def _load_labelled_images(dir_path, label, samples, clock):
    """Append every usable image in ``dir_path`` to ``samples`` as ``[tensor, label]``.

    Args:
        dir_path: Directory holding the images of one class.
        label: Integer class label stored next to each image tensor.
        samples: List that receives the ``[tensor, label]`` pairs (mutated).
        clock: Dict with ``'start'`` and ``'last'`` timestamps, shared by all
            directories so progress is reported for the whole load.

    Returns:
        The number of files that could not be decoded and were deleted.
    """
    removed = 0
    for name in os.listdir(dir_path):
        path = os.path.join(dir_path, name)
        # Sub-directories are not images, and os.remove() would fail on them.
        if os.path.isfile(path):
            image = _read_resized(path)
            if image is None:
                # Same clean-up as before: undecodable files are deleted from disk.
                os.remove(path)
                removed += 1
            else:
                samples.append([trans(image), label])
        _report_progress(clock)
    return removed


def get_files():
    """Load the cat/dog images below ``DATA_PATH`` as normalised tensors.

    Images are read from ``train/Cat``, ``train/Dog``, ``test/Cat`` and
    ``test/Dog`` (in that order), resized to ``PIC_SIZE`` x ``PIC_SIZE`` and
    converted with the module-level ``trans`` pipeline. Cats get label 0 and
    dogs label 1. Files that OpenCV cannot decode are deleted from disk.

    Returns:
        A ``(train_data, test_data)`` tuple; each element is a list of
        ``[tensor, label]`` pairs.
    """
    print('now,loading data.due to The amount of data is huge,you have to wait minutes')
    started = time.time()
    clock = {'start': started, 'last': started}

    train_data = []
    test_data = []
    # (split directory, class directory, label, destination list), in load order.
    sources = (
        ('train', 'Cat', 0, train_data),
        ('train', 'Dog', 1, train_data),
        ('test', 'Cat', 0, test_data),
        ('test', 'Dog', 1, test_data),
    )

    removed = 0
    for split, class_name, label, samples in sources:
        class_dir = os.path.join(DATA_PATH, split, class_name)
        removed += _load_labelled_images(class_dir, label, samples, clock)

    if removed:
        # Deleting user files silently is surprising; report it once.
        print('removed %d unreadable image file(s)' % removed)

    print('have loaded the data:\nThere are %d train_data\nThere are %d test_data' %(len(train_data), len(test_data)))
    print('-----------------------------------------------------------------------------')

    return train_data,test_data

if __name__ == '__main__':
    # Build both splits once and cache them on disk for the training script.
    dataset = get_files()
    torch.save(dataset, "data.pyd")

运行 get_data.py 会将处理好的数据集写入 data.pyd。

然后运行下面的 dogVScat.py 进行训练和测试。

dogVScat.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

# Learning rate passed to the SGD optimizer.
LR = 0.01
# Momentum passed to the SGD optimizer.
MOM = 0.5
# Number of train-then-test passes run by the main block.
EPOCHES=100
# Number of samples stacked into one batch by train() and test().
BATCHSIZE=50

class Net(nn.Module):
    """Small three-stage CNN that scores 3-channel images for two classes.

    Each stage is a 3x3 convolution, a 2x2 max-pool and a ReLU. For 32x32
    inputs the third stage yields 10 channels of 2x2 values, i.e. the 40
    features expected by the final linear layer. The forward pass returns
    log-probabilities of shape ``(batch, 2)``.
    """

    def __init__(self):
        super().__init__()

        # Layers are created in the original order so that seeded
        # initialisation produces the same parameters.
        self.conv1 = nn.Conv2d(3, 10, 3)
        self.conv2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=3)
        self.conv3 = nn.Conv2d(in_channels=20, out_channels=10, kernel_size=3)

        self.mp = nn.MaxPool2d(kernel_size=2)
        self.fc = nn.Linear(in_features=40, out_features=2)

    def forward(self, x):
        """Return per-class log-probabilities for the image batch ``x``."""
        batch = x.size(0)

        # conv -> max-pool -> ReLU, applied once per convolution layer.
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(self.mp(conv(x)))

        # Flatten everything except the batch dimension for the linear layer.
        flat = x.view(batch, -1)
        scores = self.fc(flat)
        return F.log_softmax(scores, dim=1)

def train():
    """Run one optimisation pass over the module-level ``train_data``.

    The samples are consumed in their current order in batches of
    ``BATCHSIZE``; the final batch may be smaller so that no sample is left
    untrained when the dataset size is not a multiple of ``BATCHSIZE``.
    Uses the module-level ``model`` and ``optimizer`` and returns nothing.
    """
    model.train()

    for start in range(0, len(train_data), BATCHSIZE):
        batch = train_data[start:start + BATCHSIZE]

        # Turn the list of [tensor, label] pairs into one image tensor and
        # one integer label tensor.
        images = torch.stack([image for image, _ in batch])
        labels = torch.tensor([label for _, label in batch], dtype=torch.long)

        optimizer.zero_grad()
        # The model outputs log-probabilities, so NLL loss is the matching criterion.
        loss = F.nll_loss(model(images), labels)
        loss.backward()
        optimizer.step()

def test(epoch):
    """Print the accuracy of the module-level ``model`` on ``test_data``.

    Every sample is evaluated, including a final batch smaller than
    ``BATCHSIZE``, so the printed percentage matches the ``len(test_data)``
    denominator. Evaluation runs without gradient tracking, and the model's
    training/eval mode is restored afterwards.

    Args:
        epoch: Epoch number, used only in the printed message.
    """
    total = len(test_data)
    if total == 0:
        # Avoid dividing by zero when there is nothing to evaluate.
        print('correct of epoch {} is n/a (test_data is empty)'.format(epoch))
        return

    correct = 0
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for start in range(0, total, BATCHSIZE):
                batch = test_data[start:start + BATCHSIZE]

                images = torch.stack([image for image, _ in batch])
                labels = torch.tensor([label for _, label in batch], dtype=torch.long)

                output = model(images)
                # Index of the larger log-probability is the predicted class.
                pred = torch.max(output, 1)[1]
                correct += pred.eq(labels).sum().item()
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)

    print('correct of epoch {} is {:.2f}%'.format(epoch,correct/total*100))

if __name__ == '__main__':
    # Load the cached splits written by get_data.py and shuffle the training
    # samples once, in place, before any epoch runs.
    train_data, test_data = torch.load("data.pyd")
    np.random.shuffle(train_data)

    # train() and test() read these module-level names directly.
    model = Net()
    optimizer = optim.SGD(params=model.parameters(), lr=LR, momentum=MOM)

    # One training pass followed by one evaluation per epoch.
    for epoch in range(EPOCHES):
        train()
        test(epoch)

训练结果:

correct of epoch 0 is 52.33%
correct of epoch 1 is 54.84%
correct of epoch 2 is 55.95%
correct of epoch 3 is 56.59%
correct of epoch 4 is 57.57%
correct of epoch 5 is 60.50%
correct of epoch 6 is 62.18%
correct of epoch 7 is 63.81%
correct of epoch 8 is 64.46%
correct of epoch 9 is 65.24%
correct of epoch 10 is 65.93%
correct of epoch 11 is 66.55%
correct of epoch 12 is 67.47%
correct of epoch 13 is 68.45%
correct of epoch 14 is 69.00%
correct of epoch 15 is 69.62%
correct of epoch 16 is 69.99%
correct of epoch 17 is 70.58%
correct of epoch 18 is 71.10%
correct of epoch 19 is 71.42%
correct of epoch 20 is 71.87%
correct of epoch 21 is 72.31%
correct of epoch 22 is 72.36%
correct of epoch 23 is 72.76%
correct of epoch 24 is 73.01%
correct of epoch 25 is 73.32%
correct of epoch 26 is 73.36%
correct of epoch 27 is 73.51%
correct of epoch 28 is 73.17%
correct of epoch 29 is 73.38%
correct of epoch 30 is 73.50%
correct of epoch 31 is 73.73%
correct of epoch 32 is 73.93%
correct of epoch 33 is 74.15%
correct of epoch 34 is 74.11%
correct of epoch 35 is 74.22%
correct of epoch 36 is 74.26%
correct of epoch 37 is 74.07%
correct of epoch 38 is 74.12%
correct of epoch 39 is 74.35%
correct of epoch 40 is 74.38%
correct of epoch 41 is 74.44%
correct of epoch 42 is 74.17%
correct of epoch 43 is 74.19%
correct of epoch 44 is 74.30%
correct of epoch 45 is 74.61%
correct of epoch 46 is 74.64%
correct of epoch 47 is 74.54%
correct of epoch 48 is 74.58%
correct of epoch 49 is 74.59%
correct of epoch 50 is 74.59%
correct of epoch 51 is 74.53%
correct of epoch 52 is 74.45%
correct of epoch 53 is 74.43%
correct of epoch 54 is 74.43%
correct of epoch 55 is 74.41%
correct of epoch 56 is 74.42%
correct of epoch 57 is 74.52%
correct of epoch 58 is 74.48%
correct of epoch 59 is 74.34%
correct of epoch 60 is 74.21%
correct of epoch 61 is 74.16%
correct of epoch 62 is 74.15%
correct of epoch 63 is 74.25%
correct of epoch 64 is 74.11%
correct of epoch 65 is 73.95%
correct of epoch 66 is 73.85%
correct of epoch 67 is 73.99%
correct of epoch 68 is 74.15%
correct of epoch 69 is 74.05%
correct of epoch 70 is 74.05%
correct of epoch 71 is 74.34%
correct of epoch 72 is 74.21%
correct of epoch 73 is 74.14%
correct of epoch 74 is 73.98%
correct of epoch 75 is 73.87%
correct of epoch 76 is 73.88%
correct of epoch 77 is 73.85%
correct of epoch 78 is 73.84%
correct of epoch 79 is 73.84%
correct of epoch 80 is 73.65%
correct of epoch 81 is 73.66%
correct of epoch 82 is 73.43%
correct of epoch 83 is 73.36%
correct of epoch 84 is 73.30%
correct of epoch 85 is 73.12%
correct of epoch 86 is 73.20%
correct of epoch 87 is 73.22%
correct of epoch 88 is 73.13%
correct of epoch 89 is 73.16%
correct of epoch 90 is 73.17%
correct of epoch 91 is 72.99%
correct of epoch 92 is 73.09%
correct of epoch 93 is 73.02%
correct of epoch 94 is 72.80%
correct of epoch 95 is 72.98%
correct of epoch 96 is 72.73%
correct of epoch 97 is 72.80%
correct of epoch 98 is 72.76%
correct of epoch 99 is 72.68%

最高准确率为 74.64%(第 46 个 epoch),此后测试准确率缓慢回落。