Julia: Custom Backpropagation on Zygote

Zygote is a Julia package for automatic differentiation, and the @adjoint macro is an important part of its interface. With @adjoint, you can customize the backpropagation of a function.

Pullbacks

To understand @adjoint, you must first understand the lower-level function pullback. gradient is actually syntactic sugar for pullback.

julia> using Zygote

julia> y, back = Zygote.pullback(sin, 0.5)
(0.479425538604203, Zygote.var"#41#42"{Zygote.ZBack{ChainRules.var"#sin_pullback#1430"{Float64}}}(Zygote.ZBack{ChainRules.var"#sin_pullback#1430"{Float64}}(ChainRules.var"#sin_pullback#1430"{Float64}(0.8775825618903728))))

julia> y
0.479425538604203

pullback takes two arguments, sin and 0.5, which are the function to differentiate and the point at which to differentiate it. It returns two outputs: the value of the function, sin(0.5), and a pullback, which is the variable back in the code above. back performs the gradient computation for sin: it accepts a derivative and produces a new one. Mathematically, it implements the vector-Jacobian product. With $y = f(x)$ and the gradient $\frac{\partial l}{\partial x}$ written as $\bar{x}$, the pullback $\mathcal{B}_y$ is computed as follows:
$$\bar{x} = \frac{\partial l}{\partial x} = \frac{\partial l}{\partial y} \frac{\partial y}{\partial x} = \mathcal{B}_y(\bar{y})$$
More specifically, taking the code above as an example, the function is $y = \sin(x)$ and $\frac{\partial y}{\partial x} = \cos(x)$, so the pullback is $\bar{y}\cos(x)$, where $\bar{y} = \frac{\partial l}{\partial y}$. In other words, pullback(sin, x) is equivalent to dsin(x) = (sin(x), ȳ -> (ȳ * cos(x),)).
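
The dsin form is also the shape that @adjoint expects: a function value paired with a pullback closure. As a minimal sketch, assuming Zygote is installed, a custom rule could look like the following. The name mysin is hypothetical and only serves to avoid touching the built-in rule for sin.

```julia
using Zygote

# Hypothetical wrapper, so that the built-in rule for sin is left alone.
mysin(x) = sin(x)

# Custom adjoint: return the value and a pullback that maps ȳ
# to a tuple holding one gradient per input argument.
Zygote.@adjoint mysin(x) = mysin(x), ȳ -> (ȳ * cos(x),)

y, back = Zygote.pullback(mysin, 0.5)
x̄ = back(1.0)[1]                      # derivative of mysin at 0.5
g = Zygote.gradient(mysin, 0.5)[1]    # same derivative, obtained through gradient
```

Both x̄ and g should agree with cos(0.5), because the hand-written pullback multiplies the incoming ȳ by cos(x).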

gradient takes a function $l = f(x)$, assumes $\bar{l} = \frac{\partial l}{\partial l} = 1$, and feeds it into the pullback. In the sin example:

julia> dsin(x) = (sin(x), ȳ -> (ȳ * cos(x),))
dsin (generic function with 1 method)

julia> function gradsin(x)
           _, back = dsin(x)
           back(1)
       end
gradsin (generic function with 1 method)

julia> gradsin(0.5)
(0.8775825618903728,)

julia> cos(0.5)
0.8775825618903728

julia> back(1)  # back is the pullback returned by Zygote.pullback(sin, 0.5) earlier
(0.8775825618903728,)
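
The hand-written gradsin mirrors what gradient does. A short check against Zygote itself, assuming Zygote is installed, might look like this (the variable names g_pullback and g_gradient are only for illustration):

```julia
using Zygote

y, back = Zygote.pullback(sin, 0.5)

# Seed the pullback with l̄ = 1, as gradient does for a scalar output.
g_pullback = back(1.0)[1]

# Let gradient do the same work in one call.
g_gradient = Zygote.gradient(sin, 0.5)[1]
```

The two values should match each other and cos(0.5).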

My own understanding of why the extra factor $\frac{\partial l}{\partial y}$ is included: it implements the chain rule. For example, suppose the final loss is $l$ and the function is $y(x)$. To get the derivative of the loss $l$ with respect to the parameter $x$, $\frac{\partial l}{\partial x}$, the chain rule says to multiply the derivative of the loss with respect to $y$ by the derivative of $y$ with respect to $x$, that is, $\frac{\partial l}{\partial y} \frac{\partial y}{\partial x}$. The pullback of the function $y$ is therefore the derivative of the loss with respect to $y$ (denoted $\bar{y}$) multiplied by the derivative of $y$ with respect to $x$.
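
To make the chain rule concrete, here is a small sketch under the assumption that the loss is $l = y^2$ with $y = \sin(x)$. This particular loss is chosen only for illustration. The derivative of the outer part is computed by hand and pushed through the pullback of sin, then compared with differentiating the composed function directly.

```julia
using Zygote

x = 0.5

# Inner function y = sin(x): get its value and its pullback.
y, back = Zygote.pullback(sin, x)

# Outer loss l = y^2 (an illustrative choice), so ȳ = ∂l/∂y = 2y.
ȳ = 2y

# Chain rule: x̄ = ȳ * ∂y/∂x, computed by the pullback.
x̄ = back(ȳ)[1]

# Reference: differentiate the composed function in one go.
x̄_direct = Zygote.gradient(t -> sin(t)^2, x)[1]
```

Here x̄ and x̄_direct should agree, both being $2\sin(x)\cos(x)$ evaluated at 0.5.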

For the example above, the first result returned by pullback is the function value sin(0.5). The returned back is a function of $\frac{\partial l}{\partial y}$, which can be regarded as $\mathcal{B}(\frac{\partial l}{\partial y}) = \frac{\partial l}{\partial y}\cos(0.5)$. If the function $y = \sin(x)$ is itself the loss $l$, then $\frac{\partial l}{\partial y} = 1$, and back(1) gives the derivative at $x = 0.5$, that is, $\cos(0.5)$.

If $l = 0.5y = 0.5\sin(x)$, we get $\frac{\partial l}{\partial y} = 0.5$, so $\frac{\partial l}{\partial x} = \mathcal{B}(\frac{\partial l}{\partial y}) = \mathcal{B}(0.5)$.
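
This last case can be checked numerically. A minimal sketch, assuming Zygote is installed (dl_dx and dl_dx_direct are illustrative names):

```julia
using Zygote

# Pullback of y = sin(x) at x = 0.5.
_, back = Zygote.pullback(sin, 0.5)

# For l = 0.5y, ∂l/∂y = 0.5, so ∂l/∂x = B(0.5).
dl_dx = back(0.5)[1]

# Reference: differentiate l = 0.5 * sin(x) directly.
dl_dx_direct = Zygote.gradient(x -> 0.5 * sin(x), 0.5)[1]
```

Both values should equal 0.5 * cos(0.5).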


