Checking the gradient of an intermediate variable in PyTorch

To save memory, PyTorch retains gradient values only for the leaf nodes (leaf variables) of the computation graph during back-propagation. As developers, however, we sometimes want to probe the gradient of some intermediate variable to check whether our implementation is correct, and this is where the tensor's register_hook interface comes in. The short sample further below is adapted from an answer by a PyTorch developer, with slight modifications by the author to match the syntax of the latest PyTorch release (v1.2.0).
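Before the hook-based sample, it may help to see the default behavior in isolation (a minimal sketch; the variable names are illustrative, not from the original post):

import torch

x = torch.randn(1, requires_grad=True)  # leaf variable
y = 3 * x                               # intermediate (non-leaf) variable
z = y * y
z.backward()

print(x.grad)  # populated: dz/dx = 18 * x
print(y.grad)  # None: gradients of intermediate variables are freed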

import torch

grads = {}

# Build a hook that stores the incoming gradient under the given name.
def save_grad(name):
    def hook(grad):
        grads[name] = grad
    return hook

x = torch.randn(1, requires_grad=True)
y = 3 * x
z = y * y

# Register the gradient-saving hook on the intermediate variable;
# its gradient will be stored under the name 'y'.
y.register_hook(save_grad('y'))

# Backpropagate.
z.backward()

# Inspect the gradient of y.
print(grads['y'])

One example output (the exact value depends on the random x; since z = y*y, the stored gradient is dz/dy = 2y = 6x):

tensor([-1.5435])
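Incidentally, recent PyTorch versions also offer tensor.retain_grad() as a more direct alternative when you only want to inspect the value (a minimal sketch, not from the original post):

import torch

x = torch.randn(1, requires_grad=True)
y = 3 * x
y.retain_grad()  # ask autograd to keep the gradient of this non-leaf tensor
z = y * y
z.backward()

print(y.grad)    # the same value the hook above would capture: dz/dy = 2 * y

One practical difference: register_hook returns a handle whose remove() method detaches the hook, so hooks are well suited to temporary probing, while retain_grad marks the tensor for good.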

Source: www.cnblogs.com/SivilTaram/p/pytorch_intermediate_variable_gradient.html