Gradient inversion in TensorFlow

Original address: https://codeleading.com/article/27004717987/

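# Written for TensorFlow 1.x graph mode: tf.get_default_graph and
# gradient_override_map are graph-mode APIs.
import tensorflow as tf
from tensorflow.python.framework import ops


class FlipGradientBuilder(object):
    """Gradient reversal: identity in the forward pass, -l * grad in the backward pass."""

    def __init__(self):
        # Counter that gives every registered gradient a unique name;
        # ops.RegisterGradient raises an error if a name is registered twice.
        self.num_calls = 0

    def __call__(self, x, l=1.0):
        grad_name = "FlipGradient%d" % self.num_calls

        @ops.RegisterGradient(grad_name)
        def _flip_gradients(op, grad):
            # Negate the upstream gradient and scale it by l.
            return [tf.negative(grad) * l]

        g = tf.get_default_graph()
        # Identity ops created inside this block use the gradient registered
        # under grad_name instead of the built-in Identity gradient.
        with g.gradient_override_map({"Identity": grad_name}):
            y = tf.identity(x)

        self.num_calls += 1
        return y


flip_gradient = FlipGradientBuilder()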
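To see what the layer does numerically, here is a framework-free sketch in plain Python. It assumes a one-weight feature extractor h = w * x, followed by the reversal layer and a squared loss; the names GradientReversal and grad_wrt_w are hypothetical and only mimic the behavior of flip_gradient, they are not part of TensorFlow.

```python
class GradientReversal:
    """Hypothetical stand-in for flip_gradient: identity forward, -l * grad backward."""

    def __init__(self, l=1.0):
        self.l = l

    def forward(self, x):
        # Forward pass: return the input unchanged, like tf.identity(x).
        return x

    def backward(self, grad):
        # Backward pass: negate the upstream gradient and scale it by l,
        # like _flip_gradients.
        return -self.l * grad


def grad_wrt_w(w, x, target, grl=None):
    """dL/dw for h = w * x, y = grl(h) (or y = h), L = (y - target) ** 2."""
    h = w * x
    y = grl.forward(h) if grl is not None else h
    dL_dy = 2.0 * (y - target)
    # Chain rule through the optional reversal layer gives dL/dh.
    dL_dh = grl.backward(dL_dy) if grl is not None else dL_dy
    # dh/dw = x
    return dL_dh * x


print(grad_wrt_w(0.5, 2.0, 3.0))                           # -8.0
print(grad_wrt_w(0.5, 2.0, 3.0, GradientReversal(l=0.5)))  # 4.0
```

The forward value is untouched, but the weight in front of the layer receives a gradient of the opposite sign, scaled by l.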

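Notes on the code:
(1) The @ops.RegisterGradient(grad_name) decorator registers _flip_gradients(op, grad) as the gradient function named grad_name. That function takes the upstream gradient and returns it negated and scaled by l, which is what reverses the gradient of this layer.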

(2) gradient_override_map lets an op use a custom gradient function instead of its built-in one. Its argument is a dictionary whose keys are op types and whose values are registered gradient names: every op of a key's type created inside the with block computes its gradient with the function registered under the corresponding value. Here {"Identity": grad_name} makes tf.identity(x) a plain identity in the forward pass, while its backward pass uses _flip_gradients.
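The key/value roles of the override dictionary can be illustrated with a hypothetical lookup function in plain Python (gradient_for, DEFAULT_GRADIENTS and REGISTERED_GRADIENTS are illustrative names, not TensorFlow code). In real TensorFlow the mapping is recorded on ops when they are created inside the with block and consulted later during differentiation; this sketch collapses both steps into a single lookup.

```python
# Hypothetical, simplified model of gradient lookup with an override map
# (not TensorFlow's code).
DEFAULT_GRADIENTS = {
    "Identity": lambda grad: grad,  # built-in gradient of y = x passes grad through
}
REGISTERED_GRADIENTS = {
    "FlipGradient0": lambda grad: -grad,  # custom gradient registered by name
}


def gradient_for(op_type, grad, override_map=None):
    """Gradient of an op, honoring an {op_type: registered_name} override map."""
    override_map = override_map or {}
    if op_type in override_map:
        # The key names the op type to override; the value names the
        # registered gradient function to use instead.
        return REGISTERED_GRADIENTS[override_map[op_type]](grad)
    return DEFAULT_GRADIENTS[op_type](grad)


print(gradient_for("Identity", 2.0))                                 # 2.0
print(gradient_for("Identity", 2.0, {"Identity": "FlipGradient0"}))  # -2.0
```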

Conclusion: PyTorch is still more convenient to use.

Appendix: how to write red text:
<font color=red size=4>PyTorch is still more convenient to use</font>


Origin: blog.csdn.net/tailonh/article/details/111322367