PyTorch 是一个广泛使用的深度学习框架,实现了大量的深度学习算法和模型,作为一个深度学习从业者,经常需要对图像处理进行相关处理,如将图像从 HWC(height、width、channel)格式转化为 CHW(channel、height、width)格式。下面将提供两种方法实现 HWC 转 CHW。
方法一: 使用 permute() 函数
PyTorch 提供了 permute() 函数来交换 tensor 的维度顺序,使用该函数可以完成 HWC 到 CHW 格式的转换。下面是示例代码:
import torch
# 定义一个 4D 的 tensor
img_hwc = torch.rand(100, 200, 3)
# 将 tensor 转换为 CHW 格式
img_chw = img_hwc.permute(2, 0, 1)
# 打印 tensor 的形状
print("img_hwc.shape = ", img_hwc.shape)
print("img_chw.shape = ", img_chw.shape)
运行上述代码后,输出结果如下:
img_hwc.shape = (100, 200, 3)
img_chw.shape = torch.Size([3, 100, 200])
在上述代码中,我们首先定义一个形状为 100x200x3 的 4D tensor,然后使用 permute() 函数将通道维度从最后一个维度移动到第一个维度,从而将 tensor 转换为 CHW 格式。
方法二: 使用 transpose() 函数
除了 permute() 函数外,PyTorch 还提供了 transpose() 函数来交换 tensor 的维度顺序,也可以完成 HWC 到 CHW 格式的转换。下面是示例代码:
import torch
# 定义一个 4D 的 tensor
img_hwc = torch.rand(100, 200, 3)
# 将 tensor 转换为 CHW 格式
img_chw = img_hwc.transpose(1, 2).transpose(0, 1)
# 打印 tensor 的形状
print("img_hwc.shape = ", img_hwc.shape)
print("img_chw.shape = ", img_chw.shape)
运行上述代码后,输出结果如下:
img_hwc.shape = (100, 200, 3)
img_chw.shape = torch.Size([3, 100, 200])
在上述代码中,我们首先定义一个形状为 100x200x3 的 4D tensor,然后使用 transpose() 函数先交换高度和宽度维度,再交换通道和高度维度,从而将 tensor 转换为 CHW 格式。
总结:
总体来说,实现 HWC 到 CHW 格式的转换有多种方法,在 PyTorch 中,可以使用 permute() 函数或 transpose() 函数来实现。permute() 函数更为灵活,可以同时交换多个维度的位置;而 transpose() 函数则比较适合对单个或多个维度进行交换。需要根据具体的场景来选择不同的方法。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch 如何实现HWC转CHW - Python技术站