PyTorch算子torch.arange在CPU/GPU/NPU中支持数据类型格式
torch.arange
是PyTorch库中用于创建一个具有一定规律的序列,即等差数列的函数。针对不同场景,torch.arange
也支持不同的数据类型格式,包括CPU、GPU和NPU。本文将详细介绍torch.arange
在不同设备上支持的数据类型格式。
支持的数据类型格式
CPU
在CPU上运行的torch.arange
支持以下数据类型格式:
-
torch.int: 默认数据类型,表示整型
-
torch.int64: 同样表示整型,但更精准
-
torch.float: 表示浮点数
-
torch.double: 同样表示浮点数,但更精准
-
torch.long: 与torch.int64相同,表示长整型
GPU
在GPU上运行的torch.arange
支持以下数据类型格式:
-
torch.float16: 表示半精度浮点数
-
torch.float32: 表示单精度浮点数,非常常用
-
torch.float64: 表示双精度浮点数
-
torch.int8: 表示带符号8位整数
-
torch.int16: 表示带符号16位整数
-
torch.int32: 表示带符号32位整数
-
torch.int64: 表示带符号64位整数,非常常用
NPU
在NPU上运行的torch.arange
支持以下数据类型格式:
-
torch.float16: 表示半精度浮点数
-
torch.float32: 表示单精度浮点数,非常常用
-
torch.int8: 表示带符号8位整数
-
torch.uint8: 表示无符号8位整数
示例说明
以下两个示例分别展示了在CPU和GPU上使用torch.arange
创建等差数列。
示例1
在CPU上运行torch.arange
创建0~19
的等差序列:
import torch
# 默认数据类型为torch.int
seq_cpu_int = torch.arange(0, 20)
# 数据类型指定为torch.long
seq_cpu_long = torch.arange(0, 20, dtype=torch.long)
# 输出结果
print(seq_cpu_int)
print(seq_cpu_long)
输出:
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19])
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19])
在CPU上创建的等差序列默认数据类型为torch.int
,可以通过指定dtype
参数的值为torch.long
来创建torch.long
类型的等差序列。
示例2
在GPU上运行torch.arange
创建0~9
的等差序列:
import torch
# 在GPU上运行
device = torch.device("cuda")
# 数据类型指定为torch.float32
seq_gpu_float32 = torch.arange(0, 10, dtype=torch.float32, device=device)
# 数据类型指定为torch.int64
seq_gpu_int64 = torch.arange(0, 10, dtype=torch.int64, device=device)
# 输出结果
print(seq_gpu_float32)
print(seq_gpu_int64)
输出:
tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
在GPU上创建等差序列时需要指定设备,可以通过torch.device
函数来指定设备地址。同时,也可以通过指定dtype
参数的值来创建不同类型的等差序列。例如,上述示例展示了如何创建torch.float32
和torch.int64
类型的等差序列。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch算子torch.arange在CPU GPU NPU中支持数据类型格式 - Python技术站