1、cat拼接

  • 功能:通过dim指定维度,在当前指定维度上直接拼接
  • 默认是dim=0
  • 指定的dim上,维度可以不相同,其他dim上维度必须相同,不然会报错。

1)拼接两个维度相同的数

a = torch.rand(2, 3, 2)
a
# 输出:
    tensor([[[0.6072, 0.6531],
             [0.2023, 0.2506],
             [0.0590, 0.3390]],

            [[0.3994, 0.0110],
             [0.3615, 0.3826],
             [0.3033, 0.3096]]])
    
b = torch.rand(2, 3, 2)  # 定义b与a大小相同
b
# 输出:
    tensor([[[0.6144, 0.4561],
             [0.9263, 0.0644],
             [0.2838, 0.3456]],

            [[0.1126, 0.5303],
             [0.8140, 0.5715],
             [0.7627, 0.5095]]])   

    
# dim选定合并的维度
torch.cat([a, b])  # 不指定dim时,默认是0
# 输出:
    tensor([[[0.6072, 0.6531],
             [0.2023, 0.2506],
             [0.0590, 0.3390]],

            [[0.3994, 0.0110],
             [0.3615, 0.3826],
             [0.3033, 0.3096]],

            [[0.6144, 0.4561],
             [0.9263, 0.0644],
             [0.2838, 0.3456]],

            [[0.1126, 0.5303],
             [0.8140, 0.5715],
             [0.7627, 0.5095]]])

# 选定合并的维度dim=0
torch.cat([a, b], dim=0)  # 指定dim=0,可以看到结果和上面的是一样的
# 输出:   
    tensor([[[0.6072, 0.6531],
             [0.2023, 0.2506],
             [0.0590, 0.3390]],

            [[0.3994, 0.0110],
             [0.3615, 0.3826],
             [0.3033, 0.3096]],

            [[0.6144, 0.4561],
             [0.9263, 0.0644],
             [0.2838, 0.3456]],

            [[0.1126, 0.5303],
             [0.8140, 0.5715],
             [0.7627, 0.5095]]])
    
# 选定合并的维度dim=1
torch.cat([a, b], dim=1)
# 输出:
    tensor([[[0.6072, 0.6531],
             [0.2023, 0.2506],
             [0.0590, 0.3390],
             [0.6144, 0.4561],
             [0.9263, 0.0644],
             [0.2838, 0.3456]],

            [[0.3994, 0.0110],
             [0.3615, 0.3826],
             [0.3033, 0.3096],
             [0.1126, 0.5303],
             [0.8140, 0.5715],
             [0.7627, 0.5095]]])    
    
# 选定合并的维度dim=2
torch.cat([a, b], dim=2)
# 输出:
    tensor([[[0.6072, 0.6531, 0.6144, 0.4561],
             [0.2023, 0.2506, 0.9263, 0.0644],
             [0.0590, 0.3390, 0.2838, 0.3456]],

            [[0.3994, 0.0110, 0.1126, 0.5303],
             [0.3615, 0.3826, 0.8140, 0.5715],
             [0.3033, 0.3096, 0.7627, 0.5095]]])    

2)拼接两个维度不同的数

结合上面维度相同的数对比,便于理解

a = torch.rand(2, 3, 2)
a
# 输出:
    tensor([[[0.6447, 0.9758],
             [0.0688, 0.9082],
             [0.0083, 0.0109]],

            [[0.5239, 0.1217],
             [0.9562, 0.6831],
             [0.8691, 0.2769]]])
    
b = torch.rand(2, 2, 2)
b
# 输出:   
tensor([[[0.3604, 0.7585],
         [0.7831, 0.0439]],

        [[0.2040, 0.5002],
         [0.8878, 0.5973]]])

# 不指定dim:
torch.cat([a, b])
# 因为dim默认是0,且a,b的dim[1]的大小不等(a是3, b是2),所以导致会报错
# 输出:   
    ---------------------------------------------------------------------------
    RuntimeError                              Traceback (most recent call last)
    <ipython-input-32-5484713fecdf> in <module>
    ----> 1 torch.cat([a, b])
          2 # 输出:

    RuntimeError: inv
        
# 可以看到,此时的ab,因为只有dim[1]不同,所以如果要用cat合并,只能在dim=1上合并       
torch.cat([a, b], dim=1)
# 输出:   
    tensor([[[0.6447, 0.9758],
             [0.0688, 0.9082],
             [0.0083, 0.0109],
             [0.3604, 0.7585],
             [0.7831, 0.0439]],

            [[0.5239, 0.1217],
             [0.9562, 0.6831],
             [0.8691, 0.2769],
             [0.2040, 0.5002],
             [0.8878, 0.5973]]])        

2.stack拼接

  • 与cat不同的是,stack是在拼接的同时,在指定dim处插入维度后拼接。
  • 可以理解为:stack是在指定维度处,分别为两个维度数据加上一层[]后,再进行拼接。
  • 对比cat会发现,cat的相同维度的两部分数据是在一个[]里面,而stack的两部分数据分别是在2个[]里面
  • stack拼接的两个数据,其所有维度必须相同
  • 默认dim=0
a = torch.rand(2, 5)
a
# 输出:
    tensor([[0.2214, 0.2666, 0.6486, 0.7050, 0.4259],
            [0.6929, 0.4945, 0.0631, 0.4546, 0.6918]])

b = torch.rand(2, 5)
b
# 输出:
    tensor([[0.7893, 0.4141, 0.2971, 0.6791, 0.9791],
            [0.4722, 0.7540, 0.5282, 0.0625, 0.0448]])
    
# 默认dim=0。将两个数据直接拼接
torch.stack([a, b])
# 输出:
    tensor([[[0.2214, 0.2666, 0.6486, 0.7050, 0.4259],
             [0.6929, 0.4945, 0.0631, 0.4546, 0.6918]],

            [[0.7893, 0.4141, 0.2971, 0.6791, 0.9791],
             [0.4722, 0.7540, 0.5282, 0.0625, 0.0448]]])    
    
# 指定dim=0
# 此处可以对比cat拼接,发现同样是dim=0,cat的数据在一个[]里面。此处是数据被分成了2段(在两个[]里面)
torch.stack([a, b], dim=0)  # 可以看到和上面默认的结果一致
# 输出:
    tensor([[[0.2214, 0.2666, 0.6486, 0.7050, 0.4259],
             [0.6929, 0.4945, 0.0631, 0.4546, 0.6918]],

            [[0.7893, 0.4141, 0.2971, 0.6791, 0.9791],
             [0.4722, 0.7540, 0.5282, 0.0625, 0.0448]]])    
    
# 指定dim=1。将数据在dim=1维度上拼接。
# 注意:结果后上面dim=0有区别。
torch.stack([a, b], dim=1)
# 输出:
    tensor([[[0.2214, 0.2666, 0.6486, 0.7050, 0.4259],
             [0.7893, 0.4141, 0.2971, 0.6791, 0.9791]],

            [[0.6929, 0.4945, 0.0631, 0.4546, 0.6918],
             [0.4722, 0.7540, 0.5282, 0.0625, 0.0448]]])    
    
# 指定dim=2。将数据在dim=2维度上拼接。
torch.stack([a, b], dim=2)
# 输出:
    tensor([[[0.2214, 0.7893],
             [0.2666, 0.4141],
             [0.6486, 0.2971],
             [0.7050, 0.6791],
             [0.4259, 0.9791]],

            [[0.6929, 0.4722],
             [0.4945, 0.7540],
             [0.0631, 0.5282],
             [0.4546, 0.0625],
             [0.6918, 0.0448]]])    

3、split拆分

  • 指定拆分dim
  • 给定拆分后的数据大小
a = torch.rand(4, 3, 2)
a
# 输出:
    tensor([[[0.5790, 0.6024],
             [0.4730, 0.0734],
             [0.2274, 0.7212]],

            [[0.7051, 0.1568],
             [0.5890, 0.1075],
             [0.7469, 0.0659]],

            [[0.7780, 0.5424],
             [0.4344, 0.8551],
             [0.6729, 0.7372]],

            [[0.1669, 0.8596],
             [0.9490, 0.8378],
             [0.7889, 0.2192]]])

    
# 默认情况下dim=0    
# 因为dim=0的大小是4,所以拆分为2 + 2 = 4,或者1+3=4.。。均可
a.split([2, 2])  
# 输出:
    (tensor([[[0.5790, 0.6024],
              [0.4730, 0.0734],
              [0.2274, 0.7212]],

             [[0.7051, 0.1568],
              [0.5890, 0.1075],
              [0.7469, 0.0659]]]),
     tensor([[[0.7780, 0.5424],
              [0.4344, 0.8551],
              [0.6729, 0.7372]],

             [[0.1669, 0.8596],
              [0.9490, 0.8378],
              [0.7889, 0.2192]]]))

    
# 因为dim=1的大小是3,所以拆分为2 + 1 = 3
a.split([2, 1], dim=1)
# 输出:
    (tensor([[[0.5790, 0.6024],
              [0.4730, 0.0734]],

             [[0.7051, 0.1568],
              [0.5890, 0.1075]],

             [[0.7780, 0.5424],
              [0.4344, 0.8551]],

             [[0.1669, 0.8596],
              [0.9490, 0.8378]]]),
     tensor([[[0.2274, 0.7212]],

             [[0.7469, 0.0659]],

             [[0.6729, 0.7372]],

             [[0.7889, 0.2192]]]))    


# 因为dim=2的大小是2,所以拆分为1 + 1 = 2
a.split([1, 1], dim=2)
# 输出:
(tensor([[[0.5790],
          [0.4730],
          [0.2274]],
 
         [[0.7051],
          [0.5890],
          [0.7469]],
 
         [[0.7780],
          [0.4344],
          [0.6729]],
 
         [[0.1669],
          [0.9490],
          [0.7889]]]),
 tensor([[[0.6024],
          [0.0734],
          [0.7212]],
 
         [[0.1568],
          [0.1075],
          [0.0659]],
 
         [[0.5424],
          [0.8551],
          [0.7372]],
 
         [[0.8596],
          [0.8378],
          [0.2192]]]))    

chunk拆分

  • chunk是在指定dim下给定,平均拆分的个数
  • 如果给定个数不能平均拆分当前维度,则会取比给定个数小的,能平均拆分数据的,最大的个数
  • dim默认是0
a
# 输出:
    tensor([[[0.5790, 0.6024],
             [0.4730, 0.0734],
             [0.2274, 0.7212]],

            [[0.7051, 0.1568],
             [0.5890, 0.1075],
             [0.7469, 0.0659]],

            [[0.7780, 0.5424],
             [0.4344, 0.8551],
             [0.6729, 0.7372]],

            [[0.1669, 0.8596],
             [0.9490, 0.8378],
             [0.7889, 0.2192]]])
    
# 默认dim=0
# 在dim=0上,将数据平均分成4份
a.chunk(4)
# 输出:
    (tensor([[[0.5790, 0.6024],
              [0.4730, 0.0734],
              [0.2274, 0.7212]]]),
     tensor([[[0.7051, 0.1568],
              [0.5890, 0.1075],
              [0.7469, 0.0659]]]),
     tensor([[[0.7780, 0.5424],
              [0.4344, 0.8551],
              [0.6729, 0.7372]]]),
     tensor([[[0.1669, 0.8596],
              [0.9490, 0.8378],
              [0.7889, 0.2192]]]))    
    
# 在dim=0上,将数据平均分成4份
# 因为4不能被3整除,且比3小,能把4整除的数是2。所以,虽然给定是3,其实得到的结果为2个部分。
a.chunk(3, dim=0)
# 输出:
    (tensor([[[0.5790, 0.6024],
              [0.4730, 0.0734],
              [0.2274, 0.7212]],

             [[0.7051, 0.1568],
              [0.5890, 0.1075],
              [0.7469, 0.0659]]]),
     tensor([[[0.7780, 0.5424],
              [0.4344, 0.8551],
              [0.6729, 0.7372]],

             [[0.1669, 0.8596],
              [0.9490, 0.8378],
              [0.7889, 0.2192]]]))    
    
# 在dim=1上,将数据平均分成3份
a.chunk(3, dim=1)
# 输出:
    (tensor([[[0.5790, 0.6024]],

             [[0.7051, 0.1568]],

             [[0.7780, 0.5424]],

             [[0.1669, 0.8596]]]),
     tensor([[[0.4730, 0.0734]],

             [[0.5890, 0.1075]],

             [[0.4344, 0.8551]],

             [[0.9490, 0.8378]]]),
     tensor([[[0.2274, 0.7212]],

             [[0.7469, 0.0659]],

             [[0.6729, 0.7372]],

             [[0.7889, 0.2192]]]))    
    
# 在dim=2上,将数据平均分成3份
a.chunk(2, dim=2)
# 输出:
    (tensor([[[0.5790],
              [0.4730],
              [0.2274]],

             [[0.7051],
              [0.5890],
              [0.7469]],

             [[0.7780],
              [0.4344],
              [0.6729]],

             [[0.1669],
              [0.9490],
              [0.7889]]]),
     tensor([[[0.6024],
              [0.0734],
              [0.7212]],

             [[0.1568],
              [0.1075],
              [0.0659]],

             [[0.5424],
              [0.8551],
              [0.7372]],

             [[0.8596],
              [0.8378],
              [0.2192]]]))