将数据划分成若干批次的数据,可以使用tf.train或者tf.data.Dataset中的方法。
(1)划分方法
# 下面是,数据批次划分 batch_size = 10 # 将训练数据的特征和标签组合,使用from_tensor_slices将数据放入队列 dataset = tfdata.Dataset.from_tensor_slices((features, labels)) # 使用shuffle(),随机打乱数据集顺序,不用shuffle就是按顺序划分,buffer_size 参数应大于等于样本数 # dataset = dataset.shuffle(buffer_size=num_examples) # batch把dataset按照batch_size分批次,得到一个list集合。默认drop_remainder=False时,保留不足批次的部分,如果是True,就是舍去。 dataset = dataset.batch(batch_size) # dataset = dataset.batch(batch_size).repeat() # repeat表示重复次数,默认是None,表示数据序列无限延续
# 输出 # 输出所有batch的list集合。 # print(list(dataset.as_numpy_iterator())) # 输出其中一个batch,两种方法,官方推荐way2! print("way1") data_iter = iter(dataset) for X, y in data_iter: print(X, y) break print("way2") for (batch_num, (X, y)) in enumerate(dataset): print((X, y)) # batch_num是批次号,标识符,也可以起其他名字 break
(2)dataset.batch()方法说明
batch把dataset按照batch_size分批次,得到一个list集合。默认drop_remainder=False时,保留不足批次的部分,如果是True,就是舍去。
用list(dataset.as_numpy_iterator())方法可以输出所有batch的list集合。
def batch(self, batch_size, drop_remainder=False): """Combines consecutive elements of this dataset into batches. >>> dataset = tf.data.Dataset.range(8) >>> dataset = dataset.batch(3) >>> list(dataset.as_numpy_iterator()) #这个方法可以输出所有batch的list [array([0, 1, 2]), array([3, 4, 5]), array([6, 7])] >>> dataset = tf.data.Dataset.range(8) >>> dataset = dataset.batch(3, drop_remainder=True) >>> list(dataset.as_numpy_iterator()) [array([0, 1, 2]), array([3, 4, 5])]
(3)dataset.repeat()方法说明
def repeat(self, count=None): """Repeats this dataset so each original value is seen `count` times. >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) >>> dataset = dataset.repeat(3) >>> list(dataset.as_numpy_iterator()) [1, 2, 3, 1, 2, 3, 1, 2, 3] Note: If this dataset is a function of global state (e.g. a random number generator), then different repetitions may produce different elements. Args: count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the number of times the dataset should be repeated. The default behavior (if `count` is `None` or `-1`) is for the dataset be repeated indefinitely. Returns: Dataset: A `Dataset`. """
2.tf.train
参考:https://www.cnblogs.com/jfl-xx/p/9945967.html
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:TensorFlow2.0——划分数据集 - Python技术站