pytorchプロセスのみ図葉ノード(リーフ変数)の計算に伝播するために反転され、メモリを節約する勾配値(勾配)を保持します。しかし、開発者のために、時々私達は私達の実装が間違っていることを確認するために、勾配一定の中間変数(中間変数)をプローブしたい、このプロセスは、のテンソル使用する必要がありますregister_hook
インターフェイスを。以下のサンプルコードのシンプルな作品は、主からのコードの回答pytorch開発 pytorch構文(V1.2.0)の最新バージョンに合わせて、よりそれを作るために、著者若干の修正。
grads = {}
def save_grad(name):
def hook(grad):
grads[name] = grad
return hook
x = torch.randn(1, requires_grad=True)
y = 3*x
z = y * y
# 为中间变量注册梯度保存接口,存储梯度时名字为 y。
y.register_hook(save_grad('y'))
# 反向传播
z.backward()
# 查看 y 的梯度值
print(grads['y'])
一つの例の出力は次のようになります。
tensor([-1.5435])