import torch
通过 help((torch.cat)) 可以查看 cat 的用法 cat(seq,dim,out=None) 其中 seq表示要连接的两个序列,以元组的形式给出,例如:seq=(a,b), a,b 为两个可以连接的序列 dim 表示以哪个维度连接,dim=0, 横向连接 dim=1,纵向连接
#实例: #dim=0 时: import torch n_data = torch.ones((100,2)) x0_data = torch.normal(2*n_data,1) y0_data = torch.zeros((100,1)) x1_data = torch.normal(-2*n_data,1) y1_data = torch.ones((100,1)) x_data = torch.cat((x0_data,x1_data),0).type(torch.FloatTensor) y_data = torch.cat((y0_data,y1_data),0).type(torch.LongTensor) print('x_data的形状:',x_data.shape) print("y_data的形状:",y_data.shape)
result: x_data的形状: torch.Size([200, 2]) y_data的形状: torch.Size([200, 1])
#实例: #dim=1 时: import torch n_data = torch.ones((100,2)) x0_data = torch.normal(2*n_data,1) y0_data = torch.zeros((100,1)) x1_data = torch.normal(-2*n_data,1) y1_data = torch.ones((100,1)) x_data = torch.cat((x0_data,x1_data),1).type(torch.FloatTensor) y_data = torch.cat((y0_data,y1_data),1).type(torch.LongTensor) print('x_data的形状:',x_data.shape) print("y_data的形状:",y_data.shape)
result: x_data的形状: torch.Size([100, 4]) y_data的形状: torch.Size([100, 2])
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Pytorch的torch.cat实例 - Python技术站