import tensorflow as tf import matplotlib.pyplot as plt from tensorflow import keras fashion_mnist = keras.datasets.fashion_mnist (train_X, train_y), (test_X,test_y) = fashion_mnist.load_data() valid_X, train_X = train_X[:1000], train_X[1000:] valid_y, train_y = train_y[:1000], train_y[1000:] plt.figure() row = 3 col = 3 class_name = ['T-shirt', 'Trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot'] for r in range(row): for c in range(col): index = col*r + c + 1 plt.subplot(row,col,index) plt.imshow(train_X[index], cmap='binary') plt.axis("off") plt.title(class_name[train_y[index]]) plt.show()
load_data可以自动划分为训练集和测试集,不过验证集需要自己划分。
注意plt.subplot的用法
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:读取keras中的fashion_mnist数据集并查看 - Python技术站