転送元:https://blog.csdn.net/winycg/article/details/100813519
テンソル複製は、さまざまなニーズを達成するためにclone()
機能とdetach()
機能を使用できます。
クローン
clone()関数は同一のテンソルを返すことができ、新しいテンソルは新しいメモリを開きますが、それでも計算グラフに残ります。
clone
この操作は、データメモリを共有せずに勾配バックトラッキングをサポートするため、ニューラルネットワーク内のユニットを再利用する必要があるシナリオで一般的に使用されます。
デタッチ
detach()関数は、同一のテンソルを返すことができます。新しいテンソルが開き、古いテンソルとメモリを共有します。新しいテンソルは計算グラフから分離され、勾配計算は含まれません。さらに、一部のインプレース操作(resize_ / resize_as_ / set_ / transpose_などのインプレース)は、いずれかで実行するとエラーを引き起こす可能性があります。detach
操作は共有データメモリの計算グラフから分離されているため、微分を追跡する必要なしにテンソル値のみがニューラルネットワークで使用されるシナリオで一般的に使用されます。
使用状況分析
#操作 | 新規/共有メモリ | まだ計算グラフにあります |
tensor.clone() | 新着 | はい |
tensor.detach() | 共有 | 番号 |
tensor.clone()。detach() | 新着 | 番号 |
いくつかの例は次のとおりです。
最初にパッケージをインポートし、ランダムシードを修正します
import torch
torch.manual_seed(0)
1. tensor require_grad = clone()後はTrue、tensor require_grad = detach()後はFalseですが、clone()後は勾配がテンソルに流れません。
x= torch.tensor([1., 2., 3.], requires_grad=True)
clone_x = x.clone()
detach_x = x.detach()
clone_detach_x = x.clone().detach()
f = torch.nn.Linear(3, 1)
y = f(x)
y.backward()
print(x.grad)
print(clone_x.requires_grad)
print(clone_x.grad)
print(detach_x.requires_grad)
print(clone_detach_x.requires_grad)
出力:
tensor([-0.0043, 0.3097, -0.4752])
True
None
False
False
2.計算グラフに含まれるテンソルをclone()の後のテンソルに変更します。このとき、勾配は元のテンソルにのみ流れます。
x= torch.tensor([1., 2., 3.], requires_grad=True)
clone_x = x.clone()
detach_x = x.detach()
clone_detach_x = x.detach().clone()
f = torch.nn.Linear(3, 1)
y = f(clone_x)
y.backward()
print(x.grad)
print(clone_x.grad)
print(detach_x.requires_grad)
print(clone_detach_x.requires_grad)
出力:
tensor([-0.0043, 0.3097, -0.4752])
None
False
False
3.元のテンソルをrequire_grad = Falseに設定し、clone()後の勾配を.requires_grad _()に設定し、clone()後のテンソルを計算グラフの計算に参加させ、勾配をclone()の前にテンソルに送信します。
x= torch.tensor([1., 2., 3.], requires_grad=False)
clone_x = x.clone().requires_grad_()
detach_x = x.detach()
clone_detach_x = x.detach().clone()
f = torch.nn.Linear(3, 1)
y = f(clone_x)
y.backward()
print(x.grad)
print(clone_x.grad)
print(detach_x.requires_grad)
print(clone_detach_x.requires_grad)
出力:
tensor([-0.0043, 0.3097, -0.4752])
None
False
False
4. detach()後のテンソルは元のテンソルとメモリを共有しているため、計算グラフで元のテンソルが更新された後、detach()のテンソル値も変更されています。
x = torch.tensor([1., 2., 3.], requires_grad=True)
f = torch.nn.Linear(3, 1)
w = f.weight.detach()
print(f.weight)
print(w)
y = f(x)
y.backward()
optimizer = torch.optim.SGD(f.parameters(), 0.1)
optimizer.step()
print(f.weight)
print(w)
出力:
Parameter containing:
tensor([[-0.0043, 0.3097, -0.4752]], requires_grad=True)
tensor([[-0.0043, 0.3097, -0.4752]])
Parameter containing:
tensor([[-0.1043, 0.1097, -0.7752]], requires_grad=True)
tensor([[-0.1043, 0.1097, -0.7752]])