下面我来为你详细讲解如何获取 PyTorch Tensor 的维度信息。
第一种方法:使用 PyTorch 内置的方法
PyTorch 中的 Tensor 对象有一个 size()
方法,可以用来获取 Tensor 的维度信息。具体用法如下:
import torch
x = torch.randn(3, 4, 5) # 创建一个 3x4x5 大小的 Tensor
print(x.size()) # 打印 Tensor 的维度信息,输出为 torch.Size([3, 4, 5])
size()
方法返回的是一个 torch.Size
对象,可以通过索引获取维度信息,例如:
print(x.size()[0]) # 打印第一维的大小,输出为 3
print(x.size()[1]) # 打印第二维的大小,输出为 4
print(x.size()[2]) # 打印第三维的大小,输出为 5
第二种方法:使用 Numpy 数组的方法
PyTorch Tensor 支持将其转换成 NumPy 数组(numpy()
方法),NumPy 数组可以方便地使用 NumPy 提供的方法获取其维度信息。具体用法如下:
import torch
import numpy as np
x = torch.randn(3, 4, 5) # 创建一个 3x4x5 大小的 Tensor
y = x.numpy() # 将 Tensor 转成 NumPy 数组
print(np.shape(y)) # 打印 NumPy 数组的维度信息,输出为 (3, 4, 5)
numpy()
方法将 Tensor 转换成 NumPy 数组后,可以使用 np.shape()
方法获取其维度信息。
以上就是两种获取 PyTorch Tensor 维度信息的方法。如果你还有其他问题,请随时问我。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch 获取tensor维度信息示例 - Python技术站