template <typename Device, typename T>
class ReluGradOp : public BinaryElementWiseOp<T, ReluGradOp<Device, T>> {
 public:
  using BinaryElementWiseOp<T, ReluGradOp<Device, T>>::BinaryElementWiseOp;

  void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
                         const Tensor& a, Tensor* output);

  // INPUTS:
  //   g (gradients): backpropagated gradients
  //   a (inputs): either the inputs that were passed to ReluOp(), or its
  //               outputs (using either one yields the same result here).
  // OUTPUT:
  //   gradients to backprop
  template <int NDIMS>
  void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
               Tensor* output) {
    OperateNoTemplate(context, g, a, output);
  }
};
The ReLU backward op takes three Tensor parameters.
The first, g, carries the gradients backpropagated from downstream.
The second, a, is the input that was passed to the forward ReluOp (per the header comment, the forward output works equally well, since both are positive at exactly the same positions).
The third, output, is not really an input but an out-parameter: it receives the gradients to backprop.
Elementwise, if the forward-pass input is greater than zero, the output equals the incoming gradient;
otherwise the output is 0, as sketched below.
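
A minimal, standalone sketch of that elementwise rule, written in plain C++ over raw arrays. This is not TensorFlow's actual Eigen-based kernel, and the function name ReluGradSketch is made up for illustration:

#include <cstddef>

// Elementwise ReLU gradient: where the forward input a[i] was positive,
// the incoming gradient g[i] passes through unchanged; everywhere else
// the gradient is zeroed, because ReLU is flat (outputs 0) there.
void ReluGradSketch(const float* g, const float* a, float* output,
                    std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) {
    output[i] = (a[i] > 0.0f) ? g[i] : 0.0f;
  }
}

The test here uses the forward input a, but the forward output would give the same mask, since relu(x) > 0 exactly when x > 0; that is what the "using either one yields the same result here" comment in the header means.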