一、数据集简介
二、MNIST数据集介绍
三、CIFAR 10/100数据集介绍
四、tf.data.Dataset.from_tensor_slices()
五、shuffle()随机打散
六、map()数据预处理
七、实战
import tensorflow as tf import tensorflow.keras as keras import os os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' def prepare_mnist_features_and_labels(x,y): x = tf.cast(x, tf.float32) / 255.0 y = tf.cast(y, tf.int64) return x,y def mnist_dataset(): (x,y), (x_test,y_test) = keras.datasets.fashion_mnist.load_data() #numpy中的格式 y = tf.one_hot(y, depth=10) #[10k] ==> [10k,10]的tensor y_test = tf.one_hot(y_test, depth=10) ds = tf.data.Dataset.from_tensor_slices((x,y)) ds = ds.map(prepare_mnist_features_and_labels) #数据预处理,注意:tf.map中传进的参数 ds = ds.shuffle(60000).batch(100) #随机打散,读取一个batch的样本 ds_val = tf.data.Dataset.from_tensor_slices((x_test,y_test)) ds_val = ds_val.map(prepare_mnist_features_and_labels) ds_val = ds_val.shuffle(10000).batch(100) return ds, ds_val def main(): ds, ds_val = mnist_dataset() print("训练集信息如下:") iteration_ds = iter(ds) iter_ds = next(iteration_ds) print(iter_ds[0].shape, iter_ds[1].shape) print("测试集信息如下:") iteration_ds_val = iter(ds_val) iter_ds_val = next(iteration_ds_val) print(iter_ds_val[0].shape, iter_ds_val[1].shape) if __name__ == '__main__': main()
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:tensorflow(十七):数据的加载:map()、shuffle()、tf.data.Dataset.from_tensor_slices() - Python技术站