PyTorch-->torch.max()的用法

 

 

 

 PyTorch-->torch.max()的用法

 

 

   _, predited = torch.max(outputs,1)   # 此处表示返回一个元组中有两个值,但是对第一个不感兴趣

  • 返回的元组的第一个元素是image data,即是最大的值;第二个元素是label,即是最大的值对应的索引。由于我们只需要label(最大值的索引),所以有 _ , predicted这样的赋值语句,表示忽略第一个返回值,把它赋值给 _,即舍弃它。
  • 第2个参数1,是 the dimension to reduce,而不是去这个dimension上面找最大的值。

     上述的a是一个4 * 4的TENSOR,所以dim=1指的是消除列这个维度,如何理解它的含义呢? 

  如果将上面的示例代码中的参数 keepdim=True加上,即torch.max(a,1,keepdim=True),会发现返回的结果的第一个元素,即表示最大的值的那部分,是一个size=4*1的Tensor,也就是其实它是按照行来找最大值,所以得到的结果是4行;因为只找一个最大值,所以是1列,整个的size就是 4行 1 列。参数dim=1,相当于调用了 squeeze(1)这个操作,最后就得到结果是一个size为4的vector。

PyTorch-->torch.max()的用法

 

 

注:如果dim=0,则返回每列的最大值。

所以一定不要混淆!这里的dim是指的 the dimension to reduce!并不是在这个dimension上去返回最大值!!!

用 torch.argmax()这个函数似乎更直观,更好理解一些。