今日、私は微分不可能な関数をカスタマイズするときに大きな落とし穴に遭遇しました。
まず、関数をカスタマイズする必要があります:sign_f
import torch
from torch.autograd import Function
import torch.nn as nn
class sign_f(Function):
@staticmethod
def forward(ctx, inputs):
output = inputs.new(inputs.size())
output[inputs >= 0.] = 1
output[inputs < 0.] = -1
ctx.save_for_backward(inputs)
return output
@staticmethod
def backward(ctx, grad_output):
input_, = ctx.saved_tensors
grad_output[input_>1.] = 0
grad_output[input_<-1.] = 0
return grad_output
その後、私は同じように、モジュール型としてそれをカプセル化する必要がnn.Conv2dモジュールカプセル化f.conv2dので、
import torch
from torch.autograd import Function
import torch.nn as nn
class sign_(nn.Module):
# 我需要的module
def __init__(self, *kargs, **kwargs):
super(sign_, self).__init__(*kargs, **kwargs)
def forward(self, inputs):
# 使用自定义函数
outs = sign_f(inputs)
return outs
class sign_f(Function):
@staticmethod
def forward(ctx, inputs):
output = inputs.new(inputs.size())
output[inputs >= 0.] = 1
output[inputs < 0.] = -1
ctx.save_for_backward(inputs)
return output
@staticmethod
def backward(ctx, grad_output):
input_, = ctx.saved_tensors
grad_output[input_>1.] = 0
grad_output[input_<-1.] = 0
return grad_output
結果は間違っています
TypeError: backward() missing 2 required positional arguments: 'ctx' and 'grad_output'
長い間試してみましたが、カスタム関数の後に適用されることがわかりました。詳細については、以下を参照してください
import torch
from torch.autograd import Function
import torch.nn as nn
class sign_(nn.Module):
def __init__(self, *kargs, **kwargs):
super(sign_, self).__init__(*kargs, **kwargs)
self.r = sign_f.apply ### <-----注意此处
def forward(self, inputs):
outs = self.r(inputs)
return outs
class sign_f(Function):
@staticmethod
def forward(ctx, inputs):
output = inputs.new(inputs.size())
output[inputs >= 0.] = 1
output[inputs < 0.] = -1
ctx.save_for_backward(inputs)
return output
@staticmethod
def backward(ctx, grad_output):
input_, = ctx.saved_tensors
grad_output[input_>1.] = 0
grad_output[input_<-1.] = 0
return grad_output
問題は解決しました