PyTorch中clone()、detach()及相关扩展详解
本文将详细讲解 PyTorch 中的 clone()
和 detach()
两个重要的函数,以及它们的相关扩展。
clone()
clone()
是一个非常常用的 PyTorch 函数,它用于创建张量的深度复制。具体来说,clone()
会创建一个与源张量拥有相同数据和属性的张量,但是二者之间只是值不同,互相之间不会影响。使用 clone()
可以避免深拷贝所带来的性能开销。
以下是 clone()
函数的使用示例:
import torch
x = torch.rand(2, 3)
y = x.clone() # 创建x的一个深拷贝y
y[0, 0] = 0.5 # 修改y的元素
print(x)
print(y) # y的第一个元素修改了,但x不受影响
输出结果为:
tensor([[0.5965, 0.5683, 0.3921],
[0.5231, 0.2374, 0.5873]])
tensor([[0.5000, 0.5683, 0.3921],
[0.5231, 0.2374, 0.5873]])
需要注意的是,clone()
是深拷贝操作,生成的新张量是独立的,如果源张量发生修改,不会影响到拷贝的新张量。但是因为是深拷贝操作,如果源张量包含大量数据,调用 clone()
会开辟一份完全相同的内存空间,因此需要谨慎使用。
detach()
detach()
是用于从计算图中分离出张量的函数,使用它可以将一个需要计算梯度的张量转化为不需要计算梯度的张量。通常情况下,我们通过训练优化器进行模型训练时,需要先将梯度清零,如果不通过 detach()
将需要计算梯度的张量数据分离出来,会使内存溢出,因为梯度一直在计算中。
下面是一个简单的示例,介绍了如何使用 detach()
函数:
import torch
x = torch.rand(2, 3)
y = torch.ones(2, 3, requires_grad=True)
optimizer = torch.optim.SGD([y], lr=0.1)
loss = (x + y).sum()
loss.backward()
optimizer.step()
print(y)
y = y.detach() # 将y从计算图中分离出来
y[0, 0] = 0.5 # 修改y的元素
print(y) # y的第一个元素修改了,但不会影响之前计算的梯度
输出结果为:
tensor([[0.8215, 1.3643, 1.2602],
[1.1125, 1.4065, 1.1799]], requires_grad=True)
tensor([[0.5000, 1.3643, 1.2602],
[1.1125, 1.4065, 1.1799]])
需要注意的是,detach()
函数不改变张量本身的值,只是返回一个与它有相同值的新张量,但是这个新张量不再参与计算图,因此无法被梯度更新。此外,detach()
函数不影响之前的计算梯度,它只是使其对张量本身没有影响。
相关扩展
除了 clone()
和 detach()
外,PyTorch 还提供了一些相关的函数,具体如下:
1. data
可以通过 data
获取张量对象的数据部分,但是需要注意的是,这个操作不会自动开启梯度,即不会记录任何与 data
相关的操作到计算图中。因此,当应用 PyTorch 构建深度学习模型时,应该避免使用 data
,通常使用 detach()
代替。
以下是使用 data
的一个示例:
import torch
x = torch.rand(2, 3, requires_grad=True)
y = x.data # 使用x的data属性来获取数据
y[0, 0] = 0.5 # 修改y的元素
print(x)
print(y)
输出结果为:
tensor([[0.5000, 0.6529, 0.3843],
[0.2868, 0.0422, 0.2184]], requires_grad=True)
tensor([[0.5000, 0.6529, 0.3843],
[0.2868, 0.0422, 0.2184]])
2. detach_
除了 detach()
函数外,PyTorch 还提供了一个原地(inplace)版本的分离函数 detach_()
,用于直接修改原张量,而不是创建一个新张量。
以下是使用 detach_()
的一个示例:
import torch
x = torch.rand(2, 3, requires_grad=True)
y = torch.ones(2, 3, requires_grad=True)
optimizer = torch.optim.SGD([y], lr=0.1)
loss = (x + y).sum()
loss.backward()
optimizer.step()
print(y)
y.detach_() # 使用detach_()将y分离出来
y[0, 0] = 0.5 # 修改y的元素
print(y)
输出结果为:
tensor([[0.3181, 1.0280, 1.5879],
[1.2681, 1.3820, 1.8077]], requires_grad=True)
tensor([[0.5000, 1.0280, 1.5879],
[1.2681, 1.3820, 1.8077]], requires_grad=True)
需要注意的是,detach_()
函数会原地(inplace)操作直接修改原张量,不会创建新的张量,并且它不会返回任何值,所以不能直接赋值给其他变量。此外,detach_()
与 detach()
函数的区别在于前者是原地操作而后者是创建新的张量。
综上所述,掌握了 PyTorch 中 clone()
和 detach()
两个函数的使用,以及相关的扩展函数,可以更加有效地使用 PyTorch 进行深度学习模型的开发。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:PyTorch中clone()、detach()及相关扩展详解 - Python技术站