1. Concatenation with cat
- Function: joins tensors end-to-end along the dimension specified by dim.
- The default is dim=0.
- The tensors may differ in size along the specified dim, but their sizes on every other dim must match, otherwise an error is raised (a shape sketch follows this list).
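A minimal sketch of the shape rule (not from the original post; the values are random, only the shapes matter): concatenating along dim d sums the sizes at d and leaves every other dim unchanged.
import torch
a = torch.rand(2, 3, 2)
b = torch.rand(2, 3, 2)
print(torch.cat([a, b], dim=0).shape)  # torch.Size([4, 3, 2])
print(torch.cat([a, b], dim=1).shape)  # torch.Size([2, 6, 2])
print(torch.cat([a, b], dim=2).shape)  # torch.Size([2, 3, 4])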
1) Concatenating two tensors of the same shape
a = torch.rand(2, 3, 2)
a
# output:
tensor([[[0.6072, 0.6531],
         [0.2023, 0.2506],
         [0.0590, 0.3390]],

        [[0.3994, 0.0110],
         [0.3615, 0.3826],
         [0.3033, 0.3096]]])
b = torch.rand(2, 3, 2)  # define b with the same shape as a
b
# output:
tensor([[[0.6144, 0.4561],
         [0.9263, 0.0644],
         [0.2838, 0.3456]],

        [[0.1126, 0.5303],
         [0.8140, 0.5715],
         [0.7627, 0.5095]]])
# dim selects the dimension to concatenate along
torch.cat([a, b])  # when dim is not specified, it defaults to 0
# output:
tensor([[[0.6072, 0.6531],
         [0.2023, 0.2506],
         [0.0590, 0.3390]],

        [[0.3994, 0.0110],
         [0.3615, 0.3826],
         [0.3033, 0.3096]],

        [[0.6144, 0.4561],
         [0.9263, 0.0644],
         [0.2838, 0.3456]],

        [[0.1126, 0.5303],
         [0.8140, 0.5715],
         [0.7627, 0.5095]]])
# concatenate along dim=0
torch.cat([a, b], dim=0)  # with dim=0 given explicitly, the result is the same as above
# output:
tensor([[[0.6072, 0.6531],
         [0.2023, 0.2506],
         [0.0590, 0.3390]],

        [[0.3994, 0.0110],
         [0.3615, 0.3826],
         [0.3033, 0.3096]],

        [[0.6144, 0.4561],
         [0.9263, 0.0644],
         [0.2838, 0.3456]],

        [[0.1126, 0.5303],
         [0.8140, 0.5715],
         [0.7627, 0.5095]]])
# concatenate along dim=1
torch.cat([a, b], dim=1)
# output:
tensor([[[0.6072, 0.6531],
         [0.2023, 0.2506],
         [0.0590, 0.3390],
         [0.6144, 0.4561],
         [0.9263, 0.0644],
         [0.2838, 0.3456]],

        [[0.3994, 0.0110],
         [0.3615, 0.3826],
         [0.3033, 0.3096],
         [0.1126, 0.5303],
         [0.8140, 0.5715],
         [0.7627, 0.5095]]])
# concatenate along dim=2
torch.cat([a, b], dim=2)
# output:
tensor([[[0.6072, 0.6531, 0.6144, 0.4561],
         [0.2023, 0.2506, 0.9263, 0.0644],
         [0.0590, 0.3390, 0.2838, 0.3456]],

        [[0.3994, 0.0110, 0.1126, 0.5303],
         [0.3615, 0.3826, 0.8140, 0.5715],
         [0.3033, 0.3096, 0.7627, 0.5095]]])
2) Concatenating two tensors of different shapes
Comparing this with the same-shape case above makes the rule easier to understand.
a = torch.rand(2, 3, 2)
a
# output:
tensor([[[0.6447, 0.9758],
         [0.0688, 0.9082],
         [0.0083, 0.0109]],

        [[0.5239, 0.1217],
         [0.9562, 0.6831],
         [0.8691, 0.2769]]])
b = torch.rand(2, 2, 2)
b
# output:
tensor([[[0.3604, 0.7585],
         [0.7831, 0.0439]],

        [[0.2040, 0.5002],
         [0.8878, 0.5973]]])
# without specifying dim:
torch.cat([a, b])
# dim defaults to 0, but a and b differ in size along dim 1 (3 for a, 2 for b), so this raises an error
# output:
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-32-5484713fecdf> in <module>
----> 1 torch.cat([a, b])
      2 # output:

RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 3 and 2 in dimension 1
# a and b differ only along dim 1, so cat can only join them on dim=1
torch.cat([a, b], dim=1)
# output:
tensor([[[0.6447, 0.9758],
         [0.0688, 0.9082],
         [0.0083, 0.0109],
         [0.3604, 0.7585],
         [0.7831, 0.0439]],

        [[0.5239, 0.1217],
         [0.9562, 0.6831],
         [0.8691, 0.2769],
         [0.2040, 0.5002],
         [0.8878, 0.5973]]])
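To see the rule in one place, here is a small sketch (not from the original post) that tries every dim for these two shapes; only dim=1 succeeds, because dim 1 is the only place the sizes differ:
import torch
a = torch.rand(2, 3, 2)
b = torch.rand(2, 2, 2)
for d in range(3):
    try:
        print(d, torch.cat([a, b], dim=d).shape)
    except RuntimeError:
        print(d, "sizes mismatch on a non-cat dim, cannot concatenate")
# output: only d=1 succeeds, giving torch.Size([2, 5, 2])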
2. Concatenation with stack
- Unlike cat, stack inserts a new dimension at the given dim while joining the tensors (see the shape sketch right after this list).
- You can think of it as wrapping each tensor in an extra pair of brackets [] at the specified dimension before joining.
- Comparing with cat: with cat the two pieces of data along the same dimension sit inside one pair of [], while with stack the two pieces sit in two separate pairs of [].
- The two tensors passed to stack must have exactly the same shape.
- The default is dim=0.
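A minimal shape sketch (not from the original post; values random): stack creates a new dimension of size 2, the number of stacked tensors, at the given position.
import torch
a = torch.rand(2, 5)
b = torch.rand(2, 5)
print(torch.stack([a, b], dim=0).shape)  # torch.Size([2, 2, 5])
print(torch.stack([a, b], dim=1).shape)  # torch.Size([2, 2, 5])
print(torch.stack([a, b], dim=2).shape)  # torch.Size([2, 5, 2])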
a = torch.rand(2, 5)
a
# output:
tensor([[0.2214, 0.2666, 0.6486, 0.7050, 0.4259],
        [0.6929, 0.4945, 0.0631, 0.4546, 0.6918]])
b = torch.rand(2, 5)
b
# output:
tensor([[0.7893, 0.4141, 0.2971, 0.6791, 0.9791],
        [0.4722, 0.7540, 0.5282, 0.0625, 0.0448]])
# default dim=0: stack the two tensors directly
torch.stack([a, b])
# output:
tensor([[[0.2214, 0.2666, 0.6486, 0.7050, 0.4259],
         [0.6929, 0.4945, 0.0631, 0.4546, 0.6918]],

        [[0.7893, 0.4141, 0.2971, 0.6791, 0.9791],
         [0.4722, 0.7540, 0.5282, 0.0625, 0.0448]]])
# explicit dim=0
# compare with cat: for the same dim=0, cat keeps the data in one pair of [], while here the data is split into 2 blocks (two pairs of [])
torch.stack([a, b], dim=0)  # same result as the default above
# output:
tensor([[[0.2214, 0.2666, 0.6486, 0.7050, 0.4259],
         [0.6929, 0.4945, 0.0631, 0.4546, 0.6918]],

        [[0.7893, 0.4141, 0.2971, 0.6791, 0.9791],
         [0.4722, 0.7540, 0.5282, 0.0625, 0.0448]]])
# dim=1: stack along dimension 1
# note: the result differs from the dim=0 result above
torch.stack([a, b], dim=1)
# output:
tensor([[[0.2214, 0.2666, 0.6486, 0.7050, 0.4259],
         [0.7893, 0.4141, 0.2971, 0.6791, 0.9791]],

        [[0.6929, 0.4945, 0.0631, 0.4546, 0.6918],
         [0.4722, 0.7540, 0.5282, 0.0625, 0.0448]]])
# dim=2: stack along dimension 2
torch.stack([a, b], dim=2)
# output:
tensor([[[0.2214, 0.7893],
         [0.2666, 0.4141],
         [0.6486, 0.2971],
         [0.7050, 0.6791],
         [0.4259, 0.9791]],

        [[0.6929, 0.4722],
         [0.4945, 0.7540],
         [0.0631, 0.5282],
         [0.4546, 0.0625],
         [0.6918, 0.0448]]])
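Incidentally, stacking at dim d is equivalent to unsqueezing each tensor at d and then using cat. A quick check (a sketch, reusing the a and b defined above):
d = 1
print(torch.equal(torch.stack([a, b], dim=d),
                  torch.cat([a.unsqueeze(d), b.unsqueeze(d)], dim=d)))
# output:
True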
3. Splitting with split
- Specify the dim to split along.
- Give the sizes of the pieces after the split; they must sum to the size of the split dim (see the sketch just below).
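A minimal sketch of that rule (not from the original post; values random, only the shapes matter):
import torch
a = torch.rand(4, 3, 2)
p1, p2 = a.split([1, 3], dim=0)  # 1 + 3 = 4, the size of dim 0
print(p1.shape, p2.shape)
# output:
torch.Size([1, 3, 2]) torch.Size([3, 3, 2])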
a = torch.rand(4, 3, 2)
a
# output:
tensor([[[0.5790, 0.6024],
         [0.4730, 0.0734],
         [0.2274, 0.7212]],

        [[0.7051, 0.1568],
         [0.5890, 0.1075],
         [0.7469, 0.0659]],

        [[0.7780, 0.5424],
         [0.4344, 0.8551],
         [0.6729, 0.7372]],

        [[0.1669, 0.8596],
         [0.9490, 0.8378],
         [0.7889, 0.2192]]])
# dim defaults to 0
# dim=0 has size 4, so it can be split as 2 + 2 = 4, or 1 + 3 = 4, and so on
a.split([2, 2])
# output:
(tensor([[[0.5790, 0.6024],
          [0.4730, 0.0734],
          [0.2274, 0.7212]],

         [[0.7051, 0.1568],
          [0.5890, 0.1075],
          [0.7469, 0.0659]]]),
 tensor([[[0.7780, 0.5424],
          [0.4344, 0.8551],
          [0.6729, 0.7372]],

         [[0.1669, 0.8596],
          [0.9490, 0.8378],
          [0.7889, 0.2192]]]))
# dim=1 has size 3, so split it as 2 + 1 = 3
a.split([2, 1], dim=1)
# output:
(tensor([[[0.5790, 0.6024],
          [0.4730, 0.0734]],

         [[0.7051, 0.1568],
          [0.5890, 0.1075]],

         [[0.7780, 0.5424],
          [0.4344, 0.8551]],

         [[0.1669, 0.8596],
          [0.9490, 0.8378]]]),
 tensor([[[0.2274, 0.7212]],

         [[0.7469, 0.0659]],

         [[0.6729, 0.7372]],

         [[0.7889, 0.2192]]]))
# dim=2 has size 2, so split it as 1 + 1 = 2
a.split([1, 1], dim=2)
# output:
(tensor([[[0.5790],
          [0.4730],
          [0.2274]],

         [[0.7051],
          [0.5890],
          [0.7469]],

         [[0.7780],
          [0.4344],
          [0.6729]],

         [[0.1669],
          [0.9490],
          [0.7889]]]),
 tensor([[[0.6024],
          [0.0734],
          [0.7212]],

         [[0.1568],
          [0.1075],
          [0.0659]],

         [[0.5424],
          [0.8551],
          [0.7372]],

         [[0.8596],
          [0.8378],
          [0.2192]]]))
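split also accepts a single integer instead of a list; it is then the size of each piece along dim, and the last piece may be smaller. A small sketch (reusing the a above):
parts = a.split(3)  # dim=0 has size 4, so the pieces have sizes 3 and 1
print([p.shape for p in parts])
# output:
[torch.Size([3, 3, 2]), torch.Size([1, 3, 2])]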
4. Splitting with chunk
- chunk takes the number of pieces to split into along the given dim.
- If the size of that dim is not evenly divisible by the requested count, each piece gets ceil(size / count) elements, so fewer pieces than requested may come back (see the sketch just below).
- dim defaults to 0.
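A minimal sketch of that rounding behavior (not from the original post; shapes chosen for illustration):
import torch
print([c.shape[0] for c in torch.rand(5, 2).chunk(3)])  # [2, 2, 1] -- ceil(5/3) = 2 per piece
print([c.shape[0] for c in torch.rand(4, 2).chunk(3)])  # [2, 2]    -- only 2 pieces come back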
a
# output:
tensor([[[0.5790, 0.6024],
         [0.4730, 0.0734],
         [0.2274, 0.7212]],

        [[0.7051, 0.1568],
         [0.5890, 0.1075],
         [0.7469, 0.0659]],

        [[0.7780, 0.5424],
         [0.4344, 0.8551],
         [0.6729, 0.7372]],

        [[0.1669, 0.8596],
         [0.9490, 0.8378],
         [0.7889, 0.2192]]])
# dim defaults to 0
# split the data into 4 equal pieces along dim=0
a.chunk(4)
# output:
(tensor([[[0.5790, 0.6024],
          [0.4730, 0.0734],
          [0.2274, 0.7212]]]),
 tensor([[[0.7051, 0.1568],
          [0.5890, 0.1075],
          [0.7469, 0.0659]]]),
 tensor([[[0.7780, 0.5424],
          [0.4344, 0.8551],
          [0.6729, 0.7372]]]),
 tensor([[[0.1669, 0.8596],
          [0.9490, 0.8378],
          [0.7889, 0.2192]]]))
# ask for 3 pieces along dim=0
# 4 is not divisible by 3, so each piece gets ceil(4/3) = 2 elements; although we asked for 3 pieces, only 2 are returned
a.chunk(3, dim=0)
# output:
(tensor([[[0.5790, 0.6024],
          [0.4730, 0.0734],
          [0.2274, 0.7212]],

         [[0.7051, 0.1568],
          [0.5890, 0.1075],
          [0.7469, 0.0659]]]),
 tensor([[[0.7780, 0.5424],
          [0.4344, 0.8551],
          [0.6729, 0.7372]],

         [[0.1669, 0.8596],
          [0.9490, 0.8378],
          [0.7889, 0.2192]]]))
# split the data into 3 equal pieces along dim=1
a.chunk(3, dim=1)
# output:
(tensor([[[0.5790, 0.6024]],

         [[0.7051, 0.1568]],

         [[0.7780, 0.5424]],

         [[0.1669, 0.8596]]]),
 tensor([[[0.4730, 0.0734]],

         [[0.5890, 0.1075]],

         [[0.4344, 0.8551]],

         [[0.9490, 0.8378]]]),
 tensor([[[0.2274, 0.7212]],

         [[0.7469, 0.0659]],

         [[0.6729, 0.7372]],

         [[0.7889, 0.2192]]]))
# split the data into 2 equal pieces along dim=2
a.chunk(2, dim=2)
# output:
(tensor([[[0.5790],
          [0.4730],
          [0.2274]],

         [[0.7051],
          [0.5890],
          [0.7469]],

         [[0.7780],
          [0.4344],
          [0.6729]],

         [[0.1669],
          [0.9490],
          [0.7889]]]),
 tensor([[[0.6024],
          [0.0734],
          [0.7212]],

         [[0.1568],
          [0.1075],
          [0.0659]],

         [[0.5424],
          [0.8551],
          [0.7372]],

         [[0.8596],
          [0.8378],
          [0.2192]]]))
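Finally, split and chunk are the inverse of cat: concatenating the pieces back along the same dim reproduces the original tensor. A quick sanity check (a sketch, reusing the a above):
pieces = a.chunk(2, dim=2)
print(torch.equal(torch.cat(pieces, dim=2), a))
# output:
True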