PyTorch中的squeeze()和unsqueeze()解析与应用案例

PyTorch中的squeeze()和unsqueeze()解析与应用案例

在PyTorch中,squeeze()和unsqueeze()是两个非常有用的函数,可以用于改变张量的形状。本文将介绍这两个函数的用法,并提供两个示例说明。

1. squeeze()函数

squeeze()函数可以用于删除张量中维度为1的维度。以下是一个示例,展示如何使用squeeze()函数。

import torch

# 创建一个形状为(1, 3, 1, 2)的张量
x = torch.randn(1, 3, 1, 2)

# 使用squeeze()函数删除维度为1的维度
y = torch.squeeze(x)

# 打印y的形状
print(y.shape)

在上面的示例中,我们首先创建了一个形状为(1, 3, 1, 2)的张量x。然后,我们使用squeeze()函数删除维度为1的维度,并将结果保存在y中。最后,我们打印y的形状,发现它的形状为(3, 2)。

2. unsqueeze()函数

unsqueeze()函数可以用于在张量中插入一个新的维度。以下是一个示例,展示如何使用unsqueeze()函数。

import torch

# 创建一个形状为(3, 2)的张量
x = torch.randn(3, 2)

# 使用unsqueeze()函数在第0维插入一个新的维度
y = torch.unsqueeze(x, 0)

# 打印y的形状
print(y.shape)

在上面的示例中,我们首先创建了一个形状为(3, 2)的张量x。然后,我们使用unsqueeze()函数在第0维插入一个新的维度,并将结果保存在y中。最后,我们打印y的形状,发现它的形状为(1, 3, 2)。

3. 示例1:使用squeeze()函数删除维度为1的维度

以下是一个示例,展示如何使用squeeze()函数删除维度为1的维度。

import torch

# 创建一个形状为(1, 3, 1, 2)的张量
x = torch.randn(1, 3, 1, 2)

# 使用squeeze()函数删除维度为1的维度
y = torch.squeeze(x)

# 打印x和y的形状
print(x.shape)
print(y.shape)

在上面的示例中,我们首先创建了一个形状为(1, 3, 1, 2)的张量x。然后,我们使用squeeze()函数删除维度为1的维度,并将结果保存在y中。最后,我们打印x和y的形状,发现x的形状为(1, 3, 1, 2),而y的形状为(3, 2)。

4. 示例2:使用unsqueeze()函数在第0维插入一个新的维度

以下是一个示例,展示如何使用unsqueeze()函数在第0维插入一个新的维度。

import torch

# 创建一个形状为(3, 2)的张量
x = torch.randn(3, 2)

# 使用unsqueeze()函数在第0维插入一个新的维度
y = torch.unsqueeze(x, 0)

# 打印x和y的形状
print(x.shape)
print(y.shape)

在上面的示例中,我们首先创建了一个形状为(3, 2)的张量x。然后,我们使用unsqueeze()函数在第0维插入一个新的维度,并将结果保存在y中。最后,我们打印x和y的形状,发现x的形状为(3, 2),而y的形状为(1, 3, 2)。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:PyTorch中的squeeze()和unsqueeze()解析与应用案例 - Python技术站

(0)
上一篇 2023年5月15日
下一篇 2023年5月15日

相关文章

  • PyTorch笔记之scatter()函数的使用

    PyTorch笔记之scatter()函数的使用 在PyTorch中,scatter()函数可以用于将一个张量中的数据按照指定的索引分散到另一个张量中。本文将介绍scatter()函数的用法,并提供两个示例说明。 1. scatter()函数的用法 scatter()函数的语法如下: torch.scatter(input, dim, index, src)…

    PyTorch 2023年5月15日
    00
  • 训练一个图像分类器demo in PyTorch【学习笔记】

    【学习源】Tutorials > Deep Learning with PyTorch: A 60 Minute Blitz > Training a Classifier  本文相当于对上面链接教程中自认为有用部分进行的截取、翻译和再注释。便于日后复习、修正和补充。 边写边查资料的过程中猛然发现这居然有中文文档……不过中文文档也是志愿者翻译的,…

    2023年4月8日
    00
  • pytorch训练模型的一些坑

    1. 图像读取 opencv的python和c++读取的图像结果不一致,是因为python和c++采用的opencv版本不一样,从而使用的解码库不同,导致读取的结果不同。 详细内容参考:https://www.cnblogs.com/haiyang21/p/11655404.html 2. 图像变换 PIL和pytorch的图像resize操作,与openc…

    PyTorch 2023年4月8日
    00
  • pytorch transform 和 OpenCV及PIL转换

    img_path = “./data/img_37.jpg” # transforms.ToTensor() transform1 = transforms.Compose([ transforms.ToTensor(), # range [0, 255] -> [0.0,1.0] ] ) ## openCV img = cv2.imread(img_…

    PyTorch 2023年4月8日
    00
  • pytorch教程之Tensor的值及操作使用学习

    当涉及到深度学习框架时,PyTorch是一个非常流行的选择。在PyTorch中,Tensor是一个非常重要的概念,它是一个多维数组,可以用于存储和操作数据。在本教程中,我们将学习如何使用PyTorch中的Tensor,包括如何创建、访问和操作Tensor。 创建Tensor 在PyTorch中,我们可以使用torch.Tensor()函数来创建一个Tenso…

    PyTorch 2023年5月15日
    00
  • pytorch 图片处理.md

    本篇所有代码位置链接???? pytorch 图片处理,主要用到 torchvision 模块的 datasets 和 transforms。 例如:本地图片资源目录结构如下 ➜ torch_test tree animal_data animal_data ├── train │   ├── ants │   │   ├── 0013035.jpg │  …

    2023年4月8日
    00
  • pytorch children和modules

    参考1参考2官方论坛讨论 children: 只包括网络的第一级孩子,不包括孩子的孩子modules: 深度优先遍历,先输出孩子,再输出孩子的孩子,孩子的孩子的孩子。。。 children的用法:加载预训练模型 resnet = models.resnet50(pretrained=True) modules = list(resnet.children()…

    PyTorch 2023年4月8日
    00
  • 解决pytorch报错ImportError: numpy.core.multiarray failed to import

    最近在学pytorch,先用官网提供的conda命令安装了一下: conda install pytorch torchvision cudatoolkit=10.2 -c pytorch 然后按照官网的方法测试是否安装成功,能不能正常使用: import torch x = torch.rand(5, 3) print(x) 若能正常打印出x的值,则说明可…

    PyTorch 2023年4月7日
    00
合作推广
合作推广
分享本页
返回顶部