dim=0,按行求平均值,返回的形状是(1,列数)

dim=1,按列求平均值,返回的形状是(行数,1)

1 x = torch.randn(2, 2, 2)
2 x
1 tensor([[[-0.7596, -0.4972],
2          [ 0.3271, -0.0415]],
3 
4         [[ 1.0684, -1.1522],
5          [ 0.5555,  0.6117]]])
1 x.mean(-3)
1 tensor([[ 0.1544, -0.8247],
2         [ 0.4413,  0.2851]])
1 x.mean(-3).shape
1 torch.Size([2, 2])