Zygote is a Julia package that implements automatic differentiation, and the @adjoint macro is an important part of the Zygote interface. With @adjoint you can customize how a function is backpropagated.
Pullbacks
To understand @adjoint, you must first understand the lower-level function pullback. gradient is actually just syntactic sugar for pullback.
julia> using Zygote
julia> y, back = Zygote.pullback(sin, 0.5)
(0.479425538604203, Zygote.var"#41#42"{Zygote.ZBack{ChainRules.var"#sin_pullback#1430"{Float64}}}(Zygote.ZBack{ChainRules.var"#sin_pullback#1430"{Float64}}(ChainRules.var"#sin_pullback#1430"{Float64}(0.8775825618903728))))
julia> y
0.479425538604203
pullback takes two arguments, sin and 0.5, which are the function to differentiate and the point at which to differentiate it. It returns two outputs: the value of the function, sin(0.5), and a pullback, which is the variable back in the code above. back performs the gradient computation for sin: it accepts a derivative (the sensitivity of the output) and produces a new one (the sensitivity of the input). Mathematically, it implements the vector-Jacobian product. With $y=f(x)$ and the gradient $\frac{\partial l}{\partial x}$ written as $\bar{x}$, the pullback $\mathcal{B}_y$ is computed as follows:
$$\bar{x}=\frac{\partial l}{\partial x}=\frac{\partial l}{\partial y}\frac{\partial y}{\partial x}=\mathcal{B}_{y}(\bar{y})$$
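To make the vector-Jacobian product concrete, here is a minimal sketch in plain Julia (no Zygote required). The function dsquare and the sample vectors are hypothetical, chosen only for illustration: it is a hand-written pullback for the elementwise square, whose Jacobian is diagonal, so the result can be compared against an explicit $J^\top \bar{y}$.

```julia
using LinearAlgebra

# Hypothetical hand-written pullback for f(x) = x .^ 2 (elementwise square).
# The Jacobian is diagonal with entries 2x, so the VJP reduces to ȳ .* 2 .* x.
dsquare(x) = (x .^ 2, ȳ -> (ȳ .* 2 .* x,))

x = [1.0, 2.0, 3.0]
ȳ = [1.0, 0.5, 0.0]          # sensitivity of the loss with respect to y

y, back = dsquare(x)
x_bar, = back(ȳ)             # the pullback returns a 1-tuple, one entry per input

# The same quantity computed from the explicit Jacobian.
J = Diagonal(2 .* x)
vjp_explicit = J' * ȳ

println(x_bar == vjp_explicit)
```

The pullback never builds the Jacobian; it only applies it to $\bar{y}$, which is why this form scales to functions with many inputs.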
More specifically, taking the sin example above, the function is $y=\sin(x)$ and $\frac{\partial y}{\partial x}=\cos(x)$, so the pullback is $\bar{y}\cos(x)$, where $\bar{y}=\frac{\partial l}{\partial y}$. In other words, pullback(sin, x) is equivalent to dsin(x) = (sin(x), ȳ -> (ȳ * cos(x),)).
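One way to gain confidence in a hand-written pullback like dsin is to compare it with a numerical derivative. The sketch below is plain Julia; the step size h and the tolerance are assumptions picked for this example, not values from Zygote.

```julia
# Hand-written pullback for sin, as described in the text.
dsin(x) = (sin(x), ȳ -> (ȳ * cos(x),))

x0 = 0.5
y, back = dsin(x0)
x_bar, = back(1.0)           # seed ȳ = 1 to get dy/dx

# Central finite difference as an independent check (h is an assumed step size).
h = 1e-6
fd = (sin(x0 + h) - sin(x0 - h)) / (2h)

println(isapprox(x_bar, fd; atol = 1e-8))
```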
gradient treats the function as $l=f(x)$, assumes $\bar{l}=\frac{\partial l}{\partial l}=1$, and feeds that value into the pullback. In the sin example:
julia> dsin(x) = (sin(x), ȳ -> (ȳ * cos(x),))
dsin (generic function with 1 method)
julia> function gradsin(x)
_, back = dsin(x)
back(1)
end
gradsin (generic function with 1 method)
julia> gradsin(0.5)
(0.8775825618903728,)
julia> cos(0.5)
0.8775825618903728
julia> back(1)
(0.8775825618903728,)
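The gradsin pattern generalizes to any function that comes with a pullback rule. The helper seeded_gradient below is a hypothetical name and a simplified sketch of the idea, not Zygote's actual implementation of gradient; dexp is an extra hand-written rule added only to show the helper is not specific to sin.

```julia
# Hand-written pullback rules in the same style as dsin.
dsin(x) = (sin(x), ȳ -> (ȳ * cos(x),))
dexp(x) = (exp(x), ȳ -> (ȳ * exp(x),))

# Hypothetical helper: run the forward pass, then seed the pullback with 1,
# i.e. assume the output is the loss itself.
function seeded_gradient(pullback_rule, x)
    y, back = pullback_rule(x)
    return back(one(y))
end

println(seeded_gradient(dsin, 0.5))   # a 1-tuple holding cos(0.5)
println(seeded_gradient(dexp, 0.0))   # a 1-tuple holding exp(0.0)
```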
My own understanding of why the extra factor $\frac{\partial l}{\partial y}$ is included: it implements the chain rule. For example, suppose the final loss is $l$ and the function is $y(x)$. To obtain the derivative of the loss $l$ with respect to the parameter $x$, $\frac{\partial l}{\partial x}$, the chain rule says to multiply the derivative of the loss with respect to $y$ by the derivative of $y$ with respect to $x$, that is, $\frac{\partial l}{\partial y}\frac{\partial y}{\partial x}$. The pullback of $y$ is exactly this product: the derivative of the loss with respect to $y$ (denoted $\bar{y}$) multiplied by the derivative of $y$ with respect to $x$.
For the example above, the first value returned by pullback is the function value $y=\sin(0.5)$. If $y=\sin(x)$ is itself taken as the loss $l$, then $\bar{y}=1$ and the gradient at $x=0.5$ is $\cos(0.5)$. The returned back is a function of $\frac{\partial l}{\partial y}$, which can be written as $\mathcal{B}(\frac{\partial l}{\partial y})=\frac{\partial l}{\partial y}\cos(0.5)$.
If $l=0.5y=0.5\sin(x)$, we get $\frac{\partial l}{\partial y}=0.5$, so $\frac{\partial l}{\partial x}=\mathcal{B}(\frac{\partial l}{\partial y})=\mathcal{B}(0.5)$.
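The case $l=0.5\sin(x)$ can be sketched by chaining two hand-written pullbacks in plain Julia. The names dscale and dloss are hypothetical and exist only for this illustration; the point is that the pullback of the outer function produces $\bar{y}=0.5$, which is then fed into the pullback of sin.

```julia
# Pullback rules for the two steps y = sin(x) and l = 0.5 * y.
dsin(x)   = (sin(x), ȳ -> (ȳ * cos(x),))
dscale(y) = (0.5 * y, l_bar -> (l_bar * 0.5,))

# Hypothetical composed rule: forward pass left to right,
# pullbacks applied in reverse order (the chain rule).
function dloss(x)
    y, back_y = dsin(x)
    l, back_l = dscale(y)
    back(l_bar) = back_y(back_l(l_bar)[1])
    return l, back
end

l, back = dloss(0.5)
grad_x, = back(1.0)          # equals B(0.5) = 0.5 * cos(0.5)

println(grad_x)
```

Seeding with 1.0 gives roughly 0.4388, which is half of the cos(0.5) value obtained when sin itself was treated as the loss.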
reference: