import torch
import torch.utils.data as Data
torch.manual_seed(1) # reproducible
# BATCH_SIZE = 5
BATCH_SIZE = 8 # 每次使用8个数据同时传入网路
x = torch.linspace(1, 10, 10) # this is x data (torch tensor)
y = torch.linspace(10, 1, 10) # this is y data (torch tensor)
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
dataset=torch_dataset, # torch TensorDataset format
batch_size=BATCH_SIZE, # mini batch size
shuffle=False, # 设置不随机打乱数据 random shuffle for training
num_workers=2, # 使用两个进程提取数据,subprocesses for loading data
)
def show_batch():
for epoch in range(3): # 全部的数据使用3遍,train entire dataset 3 times
for step, (batch_x, batch_y) in enumerate(loader): # for each training step
# train your data...
print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
batch_x.numpy(), '| batch y: ', batch_y.numpy())
if __name__ == '__main__':
show_batch()
BATCH_SIZE = 8 , 所有数据利用三次
Epoch: 0 | Step: 0 | batch x: [1. 2. 3. 4. 5. 6. 7. 8.] | batch y: [10. 9. 8. 7. 6. 5. 4. 3.]
Epoch: 0 | Step: 1 | batch x: [ 9. 10.] | batch y: [2. 1.]
Epoch: 1 | Step: 0 | batch x: [1. 2. 3. 4. 5. 6. 7. 8.] | batch y: [10. 9. 8. 7. 6. 5. 4. 3.]
Epoch: 1 | Step: 1 | batch x: [ 9. 10.] | batch y: [2. 1.]
Epoch: 2 | Step: 0 | batch x: [1. 2. 3. 4. 5. 6. 7. 8.] | batch y: [10. 9. 8. 7. 6. 5. 4. 3.]
Epoch: 2 | Step: 1 | batch x: [ 9. 10.] | batch y: [2. 1.]
END
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch 6 batch_train 批训练 - Python技术站