将数据划分成若干批次的数据,可以使用tf.train或者tf.data.Dataset中的方法。

(1)划分方法

# 下面是,数据批次划分

    batch_size = 10
    # 将训练数据的特征和标签组合,使用from_tensor_slices将数据放入队列
    dataset = tfdata.Dataset.from_tensor_slices((features, labels))
    # 使用shuffle(),随机打乱数据集顺序,不用shuffle就是按顺序划分,buffer_size 参数应大于等于样本数
    # dataset = dataset.shuffle(buffer_size=num_examples)
    # batch把dataset按照batch_size分批次,得到一个list集合。默认drop_remainder=False时,保留不足批次的部分,如果是True,就是舍去。
    dataset = dataset.batch(batch_size)
    # dataset = dataset.batch(batch_size).repeat()  # repeat表示重复次数,默认是None,表示数据序列无限延续

 

# 输出

    # 输出所有batch的list集合。
    # print(list(dataset.as_numpy_iterator()))

    # 输出其中一个batch,两种方法,官方推荐way2!
    print("way1")
    data_iter = iter(dataset)
    for X, y in data_iter:
        print(X, y)
        break
    print("way2")
    for (batch_num, (X, y)) in enumerate(dataset):
        print((X, y))  # batch_num是批次号,标识符,也可以起其他名字
        break

 

(2)dataset.batch()方法说明

batch把dataset按照batch_size分批次,得到一个list集合。默认drop_remainder=False时,保留不足批次的部分,如果是True,就是舍去。
list(dataset.as_numpy_iterator())方法可以输出所有batch的list集合。
  def batch(self, batch_size, drop_remainder=False):
    """Combines consecutive elements of this dataset into batches.

    >>> dataset = tf.data.Dataset.range(8)
    >>> dataset = dataset.batch(3)
    >>> list(dataset.as_numpy_iterator()) #这个方法可以输出所有batch的list
    [array([0, 1, 2]), array([3, 4, 5]), array([6, 7])]

    >>> dataset = tf.data.Dataset.range(8)
    >>> dataset = dataset.batch(3, drop_remainder=True)
    >>> list(dataset.as_numpy_iterator())
    [array([0, 1, 2]), array([3, 4, 5])]

(3)dataset.repeat()方法说明

  def repeat(self, count=None):
    """Repeats this dataset so each original value is seen `count` times.

    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    >>> dataset = dataset.repeat(3)
    >>> list(dataset.as_numpy_iterator())
    [1, 2, 3, 1, 2, 3, 1, 2, 3]

    Note: If this dataset is a function of global state (e.g. a random number
    generator), then different repetitions may produce different elements.

    Args:
      count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
        number of times the dataset should be repeated. The default behavior (if
        `count` is `None` or `-1`) is for the dataset be repeated indefinitely.

    Returns:
      Dataset: A `Dataset`.
    """

 

2.tf.train

参考:https://www.cnblogs.com/jfl-xx/p/9945967.html