Implementing a GRL (Gradient Reversal Layer) in PyTorch

The goal of a GRL is this: during the forward pass, the computed result is unchanged, while during the backward pass, the gradient passed back to the preceding nodes is negated (reversed in sign). An example illustrates this best:

import torch

x = torch.tensor([1., 2., 3.], requires_grad=True)
y = torch.tensor([4., 5., 6.], requires_grad=True)

# Build a simple computation graph over x and y.
z = torch.pow(x, 2) + torch.pow(y, 2)
f = z + x + y
s = 6 * f.sum()

print(s)
s.backward()
print(x)
print(x.grad)

The result of this program is:

tensor(672., grad_fn=<MulBackward0>)
tensor([1., 2., 3.], requires_grad=True)
tensor([18., 30., 42.])

For each component of the input, the part of s that depends on x is:

f(x) = 6(x^{2} + x)

Its derivative with respect to x is therefore:

\frac{\mathrm{d} f}{\mathrm{d} x} = 12x + 6

So for the input x = [1, 2, 3], the corresponding gradient is [18, 30, 42].
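
As a sanity check (not in the original post), the same gradient can be computed with torch.autograd.grad, which returns gradients directly without mutating .grad; a minimal sketch:

import torch

x = torch.tensor([1., 2., 3.], requires_grad=True)
y = torch.tensor([4., 5., 6.], requires_grad=True)

s = 6 * (x**2 + y**2 + x + y).sum()

# Returns a tuple of gradients, one per requested input.
(grad_x,) = torch.autograd.grad(s, x)
print(grad_x)   # tensor([18., 30., 42.]), matching 12x + 6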

That is the normal gradient computation. So how do we reverse the gradient? It is very simple; look at the code below:

import torch
from torch.autograd import Function

x = torch.tensor([1., 2., 3.], requires_grad=True)
y = torch.tensor([4., 5., 6.], requires_grad=True)

z = torch.pow(x, 2) + torch.pow(y, 2)
f = z + x + y

class GRL(Function):
    @staticmethod
    def forward(ctx, input):
        # Identity in the forward pass; view_as returns a new tensor
        # so the output is not the same object as the input.
        return input.view_as(input)

    @staticmethod
    def backward(ctx, grad_output):
        # Negate the incoming gradient in the backward pass.
        return grad_output.neg()

s = 6 * f.sum()
s = GRL.apply(s)   # insert the gradient reversal layer

print(s)
s.backward()
print(x)
print(x.grad)

The running result is:

tensor(672., grad_fn=<GRLBackward>)
tensor([1., 2., 3.], requires_grad=True)
tensor([-18., -30., -42.])

Compared with the previous program, this one only adds a gradient reversal layer:

class GRL(Function):
    @staticmethod
    def forward(ctx, input):
        return input.view_as(input)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg()

The forward pass here is simply the identity, while the backward pass applies .neg(), which reverses the sign of the gradient. Note that because s is a scalar and s.backward() is called with no arguments, the initial grad_output arriving at the backward of this torch.autograd.Function is 1, so x.grad ends up as the exact negation of the previous result.
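
To make the last point concrete: for a scalar loss, calling backward() with no arguments is equivalent to seeding the backward pass with a gradient of 1. A small sketch (not from the original post):

import torch

x = torch.tensor([1., 2., 3.], requires_grad=True)
s = 6 * (x**2 + x).sum()

# For a scalar, backward() is equivalent to backward(torch.tensor(1.)).
s.backward(torch.tensor(1.))
print(x.grad)   # tensor([18., 30., 42.]), same as calling plain s.backward()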
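
In practice (for example in domain-adversarial training, where the GRL originates), the layer is usually wrapped in an nn.Module and given a scaling coefficient so the strength of the reversed gradient can be annealed during training. Below is a minimal sketch under that assumption; the names GradientReversal and coeff are illustrative, not from any library:

import torch
from torch import nn
from torch.autograd import Function

class _GRLFunction(Function):
    @staticmethod
    def forward(ctx, input, coeff):
        ctx.coeff = coeff
        return input.view_as(input)   # identity in the forward pass

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse and scale the gradient; coeff itself needs no gradient.
        return grad_output.neg() * ctx.coeff, None

class GradientReversal(nn.Module):
    def __init__(self, coeff=1.0):
        super().__init__()
        self.coeff = coeff

    def forward(self, x):
        return _GRLFunction.apply(x, self.coeff)

# Usage: features pass through unchanged, gradients come back reversed.
grl = GradientReversal(coeff=0.5)
x = torch.tensor([1., 2., 3.], requires_grad=True)
grl(x).sum().backward()
print(x.grad)   # tensor([-0.5000, -0.5000, -0.5000])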

Origin: blog.csdn.net/t20134297/article/details/107870906