pytorch 实现 GRL Gradient Reversal Layer

在GRL中,要实现的目标是:在前向传导的时候,运算结果不变化,在梯度传导的时候,传递给前面的叶子节点的梯度变为原来的相反方向。举个例子最好说明了:

import torch
from torch.autograd  import  Function

x = torch.tensor([1.,2.,3.],requires_grad=True)
y = torch.tensor([4.,5.,6.],requires_grad=True)

z = torch.pow(x,2) + torch.pow(y,2)
f = z + x + y
s =6* f.sum()

print(s)
s.backward()
print(x)
print(x.grad)

这个程序的运行结果是:

tensor(672., grad_fn=<MulBackward0>)
tensor([1., 2., 3.], requires_grad=True)
tensor([18., 30., 42.])

这个运算过程对于tensor中的每个维度上的运算为:

f(x)=(x^{2}+x)*6

那么对于x的导数为:

\frac{\mathrm{d} f}{\mathrm{d} x} = 12x+6

所以当输入x=[1,2,3]时,对应的梯度为:[18,30,42]

因此这个是正常的梯度求导过程,但是如何进行梯度翻转呢?很简单,看下方的代码:

import torch
from torch.autograd  import  Function

x = torch.tensor([1.,2.,3.],requires_grad=True)
y = torch.tensor([4.,5.,6.],requires_grad=True)

z = torch.pow(x,2) + torch.pow(y,2)
f = z + x + y

class GRL(Function):
    def forward(self,input):
        return input
    def backward(self,grad_output):
        grad_input = grad_output.neg()
        return grad_input


Grl = GRL()

s =6* f.sum()
s = Grl(s)

print(s)
s.backward()
print(x)
print(x.grad)

运行结果为:

tensor(672., grad_fn=<GRL>)
tensor([1., 2., 3.], requires_grad=True)
tensor([-18., -30., -42.])

这个程序相对于上一个程序,只是差在加了一个梯度翻转层:

class GRL(Function):
    def forward(self,input):
        return input
    def backward(self,grad_output):
        grad_input = grad_output.neg()
        return grad_input

这个部分的forward没有进行任何操作,backward里面做了.neg()操作,相当于进行了梯度的翻转。在torch.autograd 中的FUnction 的backward部分,在不做任何操作的情况下,这里的grad_output的默认值是1.

猜你喜欢

转载自blog.csdn.net/t20134297/article/details/107870906