在 PyTorch 中,可以使用 .dtype
和 .size()
两个函数来查看数据类型和大小。下面是具体的步骤:
查看数据类型
可以使用 .dtype
函数来查看 Tensor 的数据类型,具体步骤如下:
- 通过加载PyTorch库和创建一个张量,如下代码所示:
import torch
x = torch.ones(2, 3)
这里创建了一个大小为 $2 \times 3$ 的张量 x,并将所有元素初始化为 1。
- 调用
.dtype
函数来查看数据类型,代码如下:
print(x.dtype)
输出结果为:
torch.float32
这表示张量 x 中的元素数据类型为 32 位浮点数。
此外,PyTorch 还支持其他数据类型,例如 int、long、float等等。具体可以参考官方文档中的数据类型章节。
查看大小
可以使用 .size()
函数来查看张量的形状大小,具体步骤如下:
- 继续使用上面的代码,调用
.size()
函数来查看张量的形状大小,代码如下:
print(x.size())
输出结果为:
torch.Size([2, 3])
这表示张量 x 的大小为 $2 \times 3$。
- 另外,也可以使用
.shape
属性来查看大小,代码如下:
print(x.shape)
输出结果和 .size()
函数结果相同。
综上所述,可以使用 .dtype
和 .size()
来方便地查看 Tensor 的数据类型和大小。下面再举一个例子:
import torch
x = torch.rand(5, 3) # 创建一个大小为 5x3 的随机张量
print(x.dtype) # 输出张量的数据类型,一般为 float32
print(x.size()) # 输出张量的大小,一般形如 torch.Size([5, 3])
输出结果如下:
torch.float32
torch.Size([5, 3])
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch 如何查看数据类型和大小 - Python技术站