简介
notMNIST数据集 是于2011公布的,可以认为是MNIST数据集地一个加强版本。数据集包含了从A到J十个字母,由large与small两个子集组成。其中samll数据集是经过手工清理的,包含19k个图片,误分类率越为0.5%,large数据集是未经过手工清理的,包含500k张图片,误分类率约为6.5%。
作者推荐在large数据集上训练网络,在small数据集上测试网络。可以将large数据集分为5/6和1/6,使用5/6做training,1/6做validation。
在该网站上网友做的正确率较高的再97%到98%,我自己使用resnet最高达到了98.04%。接下来就说一下我做的步骤。
分类
数据预处理
一步要解决的是数据集的加载。原始数据集是一些很小地图片,一个一个地从磁盘中加载无疑会拖慢模型训练的速度。最好的方式就是将所有数据都加载到内存中。因此,可以将数据加载到内存中,并将标准化之后的数据以二进制文件使用pickle
保存到磁盘。这样,每次从磁盘中读取数据可以直接读取二进制文件,否则每次读取数据集中地图片都会耗时很久。
import os, cv2, pickle
import numpy as np
rootdir = 'D:/DataSet/notMNIST/notMNIST_large'
classlist = os.listdir(rootdir)
imgLabels = []
imgNames = []
for classes in classlist:
imgFolder = os.path.join(rootdir, classes)
imgnames = os.listdir(imgFolder)
imgLabels.extend([idxName[classes]] * len(imgnames))
imgNames.extend([os.path.join(imgFolder, img) for img in imgnames])
imgs = np.zeros((len(imgLabels), 28, 28), np.float)
idx = 0
print('loading training data......')
for imgname in imgNames:
try:
img = cv2.imread(imgname, 0).astype(np.float) / 255.0
imgs[idx, :, :] = img
idx += 1
except AttributeError:
np.delete(imgs, idx, axis=0)
print('loading training data finished, %d samples' % imgs.shape[0])
train_mean, train_std = np.mean(imgs), np.std(imgs)
print('%.6f, %6f', train_mean, train_std)
imgs = (imgs - train_mean) / train_std
data = {'images': imgs, 'labels': imgLabels}
with open('D:/DataSet/notMNIST/trainset', 'wb') as f:
pickle.dump(data, f)
print('train set finished')
rootdir = 'D:/DataSet/notMNIST/notMNIST_small'
classlist = os.listdir(rootdir)
imgLabels = []
imgNames = []
for classes in classlist:
imgFolder = os.path.join(rootdir, classes)
imgnames = os.listdir(imgFolder)
imgLabels.extend([idxName[classes]] * len(imgnames))
imgNames.extend([os.path.join(imgFolder, img) for img in imgnames])
imgs = np.zeros((len(imgLabels), 28, 28), np.float)
idx = 0
print('loading test data......')
for imgname in imgNames:
try:
img = cv2.imread(imgname, 0).astype(np.float) / 255.0
imgs[idx, :, :] = img
idx += 1
except AttributeError:
np.delete(imgs, idx, axis=0)
print('loading test data finished. % d samples' % imgs.shape[0])
train_mean, train_std = np.mean(imgs), np.std(imgs)
imgs = (imgs - train_mean) / train_std
data = {'images': imgs, 'labels': imgLabels}
with open('D:/DataSet/notMNIST/testset', 'wb') as f:
pickle.dump(data, f)
print('test set finished')
使用try
语句地原因是,在读取过程中可能出现一些错误。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:notMNIST 数据集pyTorch分类 - Python技术站