当我们使用PyTorch时,经常会遇到需要“切断计算图”的情况,同时需要保留某些tensor的值。两个常用的方法就是 detach()
和 data
,但它们具有一些区别。
detach()
和data
的基本作用
detach()
: 用于将一个tensor从计算图上分离出来,并返回一个新的不与计算图相连接的tensor。使用detach()
可以阻止梯度反向传播算法对该tensor的追踪、更新,使其在计算图中断,即成为叶子节点。data
: 用于返回一个新的tensor,这个新的tensor和原始的tensor有相同的数值,但是没有梯度信息。即使在不需要计算梯度时,这个新的tensor仍然可能被加入到计算图中。
detach()
和data
的区别
- 不同点1:
detach()
返回的tensor不再与计算图相连接,而data
返回的新tensor可能仍然会出现在计算图中; - 不同点2:因为
detach()
返回的新tensor是一个新的tensor,它在内存中有新的地址,所以如果对其进行修改,不会影响原来的tensor的值;而data
返回的新tensor在内存中和原来的tensor可能共享一块内存,具体是否共享要根据具体实现而定,如果共享的话,修改新的tensor会改变原来的tensor的值; - 不同点3:
detach()
可以直接作用于具有requires_grad
=True的tensor,而data
只能作用于非叶子节点的tensor。
示例
下面通过两条示例说明 detach()
和 data
的区别。
示例1
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc = nn.Linear(2, 1)
def forward(self, x):
x = self.fc(x)
x = x.detach()
return x
model = Net()
x = torch.tensor([[1., 2.], [3., 4.]])
y = model(x)
print(y.requires_grad) # False
在示例1中,我们定义了一个简单的神经网络,它只有一个全连接层。我们在forward过程中使用了detach()
方法,将计算图从计算结果中断开,得到一个不需要梯度的新的tensor。在这个例子中,我们检查y
的requires_grad
属性,确认它已被成功设置为False
示例2
import torch
a = torch.tensor([1., 2.], requires_grad=True)
b = a.data.clone().detach()
c = a.data.clone()
print(b.requires_grad) # False
print(torch.all(torch.eq(b, c))) # True
a[0] = 100
print(a) # tensor([100., 2.], requires_grad=True)
print(b) # tensor([1., 2.])
print(c) # tensor([1., 2.])
在示例2中,我们定义了一个张量a,并将其设置为需要计算梯度,然后使用data
方法得到一个新的tensor b,和一个新的tensor c。我们检查b和c的requires_grad
属性,确认b已被成功设置为False,而c的属性仍然为True。接着我们更改a的值,然后打印出a,b,c的值。可以看到,因为新创的tensor b不共享内存,所以在a被修改时,tensor b的值不变。而新创的tensor c共享内存,所以在a被修改时,tensor c的值也发生了变化。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:PyTorch中 tensor.detach() 和 tensor.data 的区别解析 - Python技术站