tensorflow 实现自定义梯度反向传播

以sign函数为例:

sign函数可以对数值进行二值化,但在梯度反向传播是不好处理,一般采用一个近似函数的梯度作为代替,如上图的Htanh。在[-1,1]直接梯度为1,其他为0。

#使用修饰器,建立梯度反向传播函数。其中op.input包含输入值、输出值,grad包含上层传来的梯度
@tf.RegisterGradient("QuantizeGrad")
def sign_grad(op, grad):
    input = op.inputs[0]
    cond = (input>=-1)&(input<=1)
    zeros = tf.zeros_like(grad)
    return tf.where(cond, grad, zeros)

#使用with上下文管理器覆盖原始的sign梯度函数
def binary(input):
    x = input
    with tf.get_default_graph().gradient_override_map({"Sign":'QuantizeGrad'}):
        x = tf.sign(x)
    return x

#使用
x = binary(x)

更详细教程

猜你喜欢

转载自blog.csdn.net/qq_16234613/article/details/82937867
今日推荐