ReLU Backpropagation

The ReluGradOp class below is the declaration of ReLU's gradient kernel, as it appears in TensorFlow's relu_op.h:

template <typename Device, typename T>
class ReluGradOp : public BinaryElementWiseOp<T, ReluGradOp<Device, T>> {
 public:
  using BinaryElementWiseOp<T, ReluGradOp<Device, T>>::BinaryElementWiseOp;

  void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
                         const Tensor& a, Tensor* output);

  // INPUTS:
  //   g (gradients): backpropagated gradients
  //   a (inputs): either the inputs that were passed to ReluOp(), or its
  //               outputs (using either one yields the same result here).
  // OUTPUT:
  //   gradients to backprop
  template <int NDIMS>
  void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
               Tensor* output) {
    OperateNoTemplate(context, g, a, output);
  }
};
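The header above only declares OperateNoTemplate. In the TensorFlow sources, its definition checks that g and a have the same shape and then hands the flattened tensors to a device-specific ReluGrad functor. The following is a lightly simplified sketch of that functor, adapted from tensorflow/core/kernels/relu_op_functor.h; treat it as illustrative rather than the exact current source:

// Sketch of the ReluGrad functor (assumes TensorFlow internal headers;
// TTypes comes from tensorflow/core/framework/tensor_types.h).
template <typename Device, typename T>
struct ReluGrad {
  // gradients: the backpropagated gradients (g above).
  // features:  the forward-pass inputs or outputs of the Relu (a above).
  // backprops: the gradients to backpropagate (output above).
  void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
                  typename TTypes<T>::ConstTensor features,
                  typename TTypes<T>::Tensor backprops) {
    // Keep the gradient where the forward value was strictly positive and
    // zero it elsewhere: (features > 0) yields a boolean tensor that is
    // cast back to T and used as an element-wise mask.
    backprops.device(d) =
        gradients * (features > static_cast<T>(0)).template cast<T>();
  }
};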

The ReLU gradient op works with three tensors: two inputs and one output.

The first input, g, holds the gradients backpropagated from the layers above.

The second input, a, is what was fed to ReluOp() in the forward pass; per the comment, the forward outputs work just as well, since they are positive in exactly the same positions as the inputs.

The output holds the gradients to backpropagate further.

Element-wise, the rule is: if the forward-pass input is greater than zero, the output equals the incoming gradient; otherwise the output is 0.
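To make the rule concrete, here is a minimal, self-contained sketch of the same element-wise computation as a plain function (a reference implementation invented for illustration, not TensorFlow code; the name ReluGradReference is hypothetical):

#include <cstddef>
#include <vector>

// Element-wise ReLU gradient: pass the incoming gradient through where
// the forward-pass input was positive, emit 0 everywhere else.
std::vector<float> ReluGradReference(const std::vector<float>& g,    // incoming gradients
                                     const std::vector<float>& a) {  // forward-pass inputs
  std::vector<float> out(g.size());
  for (std::size_t i = 0; i < g.size(); ++i) {
    out[i] = (a[i] > 0.0f) ? g[i] : 0.0f;
  }
  return out;
}

For example, with a = {-1, 2, 0, 3} and g = {10, 10, 10, 10}, the result is {0, 10, 0, 10}: the gradient flows only through the positions where the forward input was positive.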
