中間変数勾配pytorchをチェック

pytorchプロセスのみ図葉ノード(リーフ変数)の計算に伝播するために反転され、メモリを節約する勾配値(勾配)を保持します。しかし、開発者のために、時々私達は私達の実装が間違っていることを確認するために、勾配一定の中間変数(中間変数)をプローブしたい、このプロセスは、のテンソル使用する必要がありますregister_hookインターフェイスを。以下のサンプルコードのシンプルな作品は、主からのコードの回答pytorch開発 pytorch構文(V1.2.0)の最新バージョンに合わせて、よりそれを作るために、著者若干の修正。

grads = {}

def save_grad(name):
    def hook(grad):
        grads[name] = grad
    return hook

x = torch.randn(1, requires_grad=True)
y = 3*x
z = y * y

# 为中间变量注册梯度保存接口,存储梯度时名字为 y。
y.register_hook(save_grad('y'))

# 反向传播 
z.backward()

# 查看 y 的梯度值
print(grads['y'])

一つの例の出力は次のようになります。

tensor([-1.5435])

おすすめ

転載: www.cnblogs.com/SivilTaram/p/pytorch_intermediate_variable_gradient.html
おすすめ