GRLでは、達成すべき目標は次のとおりです。順方向伝導中は計算結果は変化せず、勾配伝導中は前のリーフノードに渡された勾配が元の反対方向になります。例は最もよく説明します:
import torch
from torch.autograd import Function
x = torch.tensor([1.,2.,3.],requires_grad=True)
y = torch.tensor([4.,5.,6.],requires_grad=True)
z = torch.pow(x,2) + torch.pow(y,2)
f = z + x + y
s =6* f.sum()
print(s)
s.backward()
print(x)
print(x.grad)
このプログラムの結果は次のとおりです。
tensor(672., grad_fn=<MulBackward0>)
tensor([1., 2., 3.], requires_grad=True)
tensor([18., 30., 42.])
テンソルの各次元の計算プロセスは次のとおりです。
xの導関数は次のとおりです。
したがって、入力x = [1,2,3]の場合、対応する勾配は次のようになります。[18,30,42]
これは通常の勾配導出プロセスですが、勾配フリップを実行するにはどうすればよいですか?非常に単純です。以下のコードを見てください。
import torch
from torch.autograd import Function
x = torch.tensor([1.,2.,3.],requires_grad=True)
y = torch.tensor([4.,5.,6.],requires_grad=True)
z = torch.pow(x,2) + torch.pow(y,2)
f = z + x + y
class GRL(Function):
def forward(self,input):
return input
def backward(self,grad_output):
grad_input = grad_output.neg()
return grad_input
Grl = GRL()
s =6* f.sum()
s = Grl(s)
print(s)
s.backward()
print(x)
print(x.grad)
実行結果は次のとおりです。
tensor(672., grad_fn=<GRL>)
tensor([1., 2., 3.], requires_grad=True)
tensor([-18., -30., -42.])
前のプログラムと比較して、このプログラムはグラデーションフリップレイヤーのみを追加します。
class GRL(Function):
def forward(self,input):
return input
def backward(self,grad_output):
grad_input = grad_output.neg()
return grad_input
この部分の順方向は操作を実行せず、.neg()操作は逆方向で実行されます。これは、勾配の反転に相当します。torch.autogradのFUnctionの後方部分では、何も操作を行わずに、ここでのgrad_outputのデフォルト値は1です。