https://github.com/dfdazac/wassdistance/tree/master
Prerequisites
Computational Optimal Transport (study notes)
Reading up to the coordinate ascent on the entropic dual is enough.
Primal problem (entropy-regularized):
$$\mathrm{L}_{\mathbf{C}}^{\varepsilon}(\mathbf{a}, \mathbf{b}) \stackrel{\text{def.}}{=} \min_{\mathbf{P} \in \mathbf{U}(\mathbf{a}, \mathbf{b})}\langle\mathbf{P}, \mathbf{C}\rangle-\varepsilon \mathbf{H}(\mathbf{P}),\qquad \mathbf{H}(\mathbf{P})=-\sum_{i,j}P_{i,j}\left(\log P_{i,j}-1\right)$$
$$\mathbf{U}(\mathbf{a}, \mathbf{b}) \stackrel{\text{def.}}{=}\left\{\mathbf{P} \in \mathbb{R}_{+}^{n \times m}: \mathbf{P} \mathbf{1}_m=\mathbf{a} \quad \text{and} \quad \mathbf{P}^{\mathrm{T}} \mathbf{1}_n=\mathbf{b}\right\}$$
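To make the two definitions concrete, here is a minimal NumPy sketch (NumPy is swapped in for illustration; the linked repository uses PyTorch) that checks membership in $\mathbf{U}(\mathbf{a},\mathbf{b})$ and evaluates the primal objective. It assumes the entropy convention $\mathbf{H}(\mathbf{P})=-\sum_{i,j}P_{i,j}(\log P_{i,j}-1)$ with $0\log 0=0$; the names `entropy`, `in_coupling_set` and `primal_objective` are hypothetical.

```python
import numpy as np


def entropy(P):
    # H(P) = -sum_ij P_ij (log P_ij - 1); zero entries contribute 0 (0 log 0 = 0).
    P = np.asarray(P, dtype=float)
    pos = P[P > 0]
    return float(-np.sum(pos * (np.log(pos) - 1.0)))


def in_coupling_set(P, a, b, tol=1e-8):
    # P is in U(a, b) iff P >= 0, P 1_m = a (row sums) and P^T 1_n = b (column sums).
    P = np.asarray(P, dtype=float)
    return (
        bool(np.all(P >= 0))
        and np.allclose(P.sum(axis=1), a, atol=tol)
        and np.allclose(P.sum(axis=0), b, atol=tol)
    )


def primal_objective(P, C, eps):
    # <P, C> - eps * H(P): transport cost minus an entropy bonus.
    return float(np.sum(P * C) - eps * entropy(P))


a = np.array([0.5, 0.5])
b = np.array([0.25, 0.75])
C = np.array([[0.0, 1.0], [1.0, 0.0]])
P = np.outer(a, b)  # independent coupling a b^T
print(in_coupling_set(P, a, b))  # True
print(primal_objective(P, C, eps=0.1))
```

The independent coupling $\mathbf{a}\mathbf{b}^{\mathrm T}$ is a convenient feasible point: its row sums are $\mathbf{a}$ and its column sums are $\mathbf{b}$ because both vectors sum to 1.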
Dual:
$$\mathrm{L}_{\mathbf{C}}^{\varepsilon}(\mathbf{a}, \mathbf{b})=\max_{\mathbf{f} \in \mathbb{R}^n, \mathbf{g} \in \mathbb{R}^m}\langle\mathbf{f}, \mathbf{a}\rangle+\langle\mathbf{g}, \mathbf{b}\rangle-\varepsilon\left\langle e^{\mathbf{f} / \varepsilon}, \mathbf{K} e^{\mathbf{g} / \varepsilon}\right\rangle$$
With the change of variables
$$(\mathbf{u}, \mathbf{v})=\left(e^{\mathbf{f} / \varepsilon}, e^{\mathbf{g} / \varepsilon}\right)$$
the optimal coupling is recovered from the optimal potentials as
$$\mathbf{P}=\operatorname{diag}(\mathbf{u})\,\mathbf{K}\,\operatorname{diag}(\mathbf{v}),\qquad \mathbf{K}=\exp\left(-\frac{\mathbf{C}}{\varepsilon}\right)$$
where the exponential is applied elementwise.
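A small sketch of the dual side, again in NumPy with hypothetical names (`gibbs_kernel`, `coupling_from_potentials`, `dual_objective`): it forms $\mathbf{K}$, evaluates the dual objective for arbitrary potentials, and builds $\operatorname{diag}(\mathbf{u})\mathbf{K}\operatorname{diag}(\mathbf{v})$ by broadcasting rather than with explicit diagonal matrices. For any $(\mathbf{f},\mathbf{g})$, the term $\langle e^{\mathbf{f}/\varepsilon},\mathbf{K}e^{\mathbf{g}/\varepsilon}\rangle$ is simply the total mass of that matrix.

```python
import numpy as np


def gibbs_kernel(C, eps):
    # K = exp(-C / eps), applied elementwise.
    return np.exp(-C / eps)


def coupling_from_potentials(f, g, C, eps):
    # P = diag(u) K diag(v) with u = exp(f / eps), v = exp(g / eps).
    # Broadcasting u[:, None] * K * v[None, :] gives the same matrix
    # without materializing the diagonal matrices.
    u = np.exp(f / eps)
    v = np.exp(g / eps)
    return u[:, None] * gibbs_kernel(C, eps) * v[None, :]


def dual_objective(f, g, a, b, C, eps):
    # <f, a> + <g, b> - eps * <exp(f / eps), K exp(g / eps)>
    u = np.exp(f / eps)
    v = np.exp(g / eps)
    return float(f @ a + g @ b - eps * (u @ (gibbs_kernel(C, eps) @ v)))


rng = np.random.default_rng(0)
n, m, eps = 3, 4, 0.5
C = rng.random((n, m))
a = np.full(n, 1.0 / n)
b = np.full(m, 1.0 / m)
f, g = rng.normal(size=n), rng.normal(size=m)
P = coupling_from_potentials(f, g, C, eps)
u, v = np.exp(f / eps), np.exp(g / eps)
print(np.allclose(P, np.diag(u) @ gibbs_kernel(C, eps) @ np.diag(v)))  # True
print(np.isclose(dual_objective(f, g, a, b, C, eps), f @ a + g @ b - eps * P.sum()))  # True
```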
Coordinate ascent (maximize the dual over $\mathbf{f}$ with $\mathbf{g}$ fixed, then over $\mathbf{g}$ with $\mathbf{f}$ fixed):
$$\begin{aligned} \mathbf{f}^{(\ell+1)} &= \varepsilon \log \mathbf{a}-\varepsilon \log \left(\mathbf{K} e^{\mathbf{g}^{(\ell)} / \varepsilon}\right), \\ \mathbf{g}^{(\ell+1)} &= \varepsilon \log \mathbf{b}-\varepsilon \log \left(\mathbf{K}^{\mathrm{T}} e^{\mathbf{f}^{(\ell+1)} / \varepsilon}\right). \end{aligned}$$
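The updates above can be run as written. This naive NumPy sketch (hypothetical function `sinkhorn_dual`, zero initial potentials, a fixed iteration count instead of a convergence test) forms $\mathbf{K}$ explicitly. After each $\mathbf{g}$-update the column sums of $\mathbf{P}$ equal $\mathbf{b}$ exactly, while the row sums converge to $\mathbf{a}$.

```python
import numpy as np


def sinkhorn_dual(a, b, C, eps, n_iters=500):
    # Coordinate ascent on the entropic dual, exactly as in the two updates:
    #   f <- eps log a - eps log(K exp(g / eps))
    #   g <- eps log b - eps log(K^T exp(f / eps))
    # K is formed explicitly, so this version is only safe for moderate eps.
    K = np.exp(-C / eps)
    f = np.zeros_like(a)
    g = np.zeros_like(b)
    for _ in range(n_iters):
        f = eps * np.log(a) - eps * np.log(K @ np.exp(g / eps))
        g = eps * np.log(b) - eps * np.log(K.T @ np.exp(f / eps))
    # Recover P = diag(u) K diag(v) with u = exp(f / eps), v = exp(g / eps).
    P = np.exp(f / eps)[:, None] * K * np.exp(g / eps)[None, :]
    return f, g, P


rng = np.random.default_rng(0)
C = rng.random((4, 5))
a = np.full(4, 1 / 4)
b = np.full(5, 1 / 5)
f, g, P = sinkhorn_dual(a, b, C, eps=0.5)
print(P.sum(axis=1), P.sum(axis=0))
```

For small $\varepsilon$, `np.exp(-C / eps)` can underflow to 0 and `np.exp(g / eps)` can overflow, which is one reason to rewrite the updates with `logsumexp`.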
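The code computes these updates in a different but equivalent form that only needs `logsumexp`, so $\mathbf{K}=e^{-\mathbf{C}/\varepsilon}$ need not be formed explicitly (its entries underflow to 0 when $\varepsilon$ is small). Consider $\mathbf{C}\in\mathbb{R}^{n\times m}$, $\mathbf{f}\in\mathbb{R}^n$, $\mathbf{g}\in\mathbb{R}^m$. The derivations below multiply and divide inside the log by $e^{f_i/\varepsilon}$ (respectively $e^{g_j/\varepsilon}$):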
$$\begin{aligned}
&\log \left(\mathbf{K} e^{\mathbf{g} / \varepsilon}\right)\\
=&\log\left(\left[\sum_{j}e^{-\frac{C_{i,j}-g_j}{\varepsilon}}\right]_i\right)\\
=&\log\left(\left[\sum_{j}e^{-\frac{C_{i,j}-g_j}{\varepsilon}}e^{\frac{f_i}{\varepsilon}}e^{-\frac{f_i}{\varepsilon}}\right]_i\right)\\
=&\log\left(\left[\sum_{j}e^{-\frac{C_{i,j}-f_i-g_j}{\varepsilon}}\right]_i\odot e^{-\frac{\mathbf{f}}{\varepsilon}}\right)\\
=&\log\left(\left[\sum_{j}e^{-\frac{C_{i,j}-f_i-g_j}{\varepsilon}}\right]_i\right)-\frac{\mathbf{f}}{\varepsilon}\\
=&\operatorname{logsumexp}\left(-\frac{\mathbf{C}-\mathbf{f}\mathbf{1}_m^{\mathrm{T}}-\mathbf{1}_n\mathbf{g}^{\mathrm{T}}}{\varepsilon},\ \text{dim}=-1\right)-\frac{\mathbf{f}}{\varepsilon}
\end{aligned}$$
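In the last step, subtracting the vectors from the matrix relies on broadcasting: $\mathbf{f}$ is treated as a column ($n\times 1$) and $\mathbf{g}$ as a row ($1\times m$), so entry $(i,j)$ is $C_{i,j}-f_i-g_j$, and `dim=-1` sums over $j$.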
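A quick numerical check of this identity, as a sketch: NumPy's `axis` plays the role of `dim`, `logsumexp` is a hand-written stable version (subtract the maximum before exponentiating), and `modified_cost` and `log_K_exp_g` are hypothetical names. The left-hand side does not involve $\mathbf{f}$ at all, so the identity should hold for any choice of $\mathbf{f}$.

```python
import numpy as np


def logsumexp(x, axis=-1):
    # Stable log(sum(exp(x))) along `axis`: shift by the max before exponentiating.
    x_max = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(x_max, axis=axis) + np.log(np.sum(np.exp(x - x_max), axis=axis))


def modified_cost(C, f, g, eps):
    # -(C - f 1_m^T - 1_n g^T) / eps, built by broadcasting:
    # f[:, None] has shape (n, 1) and g[None, :] has shape (1, m).
    return -(C - f[:, None] - g[None, :]) / eps


def log_K_exp_g(C, f, g, eps):
    # log(K exp(g / eps)) = logsumexp(M, dim=-1) - f / eps
    return logsumexp(modified_cost(C, f, g, eps), axis=-1) - f / eps


rng = np.random.default_rng(0)
n, m, eps = 3, 5, 0.5
C = rng.random((n, m))
f, g = rng.normal(size=n), rng.normal(size=m)
direct = np.log(np.exp(-C / eps) @ np.exp(g / eps))  # naive formula with explicit K
print(np.allclose(log_K_exp_g(C, f, g, eps), direct))  # True
```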
$$\begin{aligned}
&\log \left(\mathbf{K}^{\mathrm{T}} e^{\mathbf{f} / \varepsilon}\right)\\
=&\log\left(\left[\sum_{i}e^{-\frac{C_{i,j}-f_i}{\varepsilon}}\right]_j\right)\\
=&\log\left(\left[\sum_{i}e^{-\frac{C_{i,j}-f_i}{\varepsilon}}e^{\frac{g_j}{\varepsilon}}e^{-\frac{g_j}{\varepsilon}}\right]_j\right)\\
=&\log\left(\left[\sum_{i}e^{-\frac{C_{i,j}-f_i-g_j}{\varepsilon}}\right]_j\odot e^{-\frac{\mathbf{g}}{\varepsilon}}\right)\\
=&\log\left(\left[\sum_{i}e^{-\frac{C_{i,j}-f_i-g_j}{\varepsilon}}\right]_j\right)-\frac{\mathbf{g}}{\varepsilon}\\
=&\operatorname{logsumexp}\left(-\frac{\mathbf{C}-\mathbf{f}\mathbf{1}_m^{\mathrm{T}}-\mathbf{1}_n\mathbf{g}^{\mathrm{T}}}{\varepsilon},\ \text{dim}=-2\right)-\frac{\mathbf{g}}{\varepsilon}\\
=&\operatorname{logsumexp}\left(-\frac{\left(\mathbf{C}-\mathbf{f}\mathbf{1}_m^{\mathrm{T}}-\mathbf{1}_n\mathbf{g}^{\mathrm{T}}\right)^{\mathrm{T}}}{\varepsilon},\ \text{dim}=-1\right)-\frac{\mathbf{g}}{\varepsilon}
\end{aligned}$$
Here the sum runs over $i$, so the reduction is along `dim=-2`, or equivalently along `dim=-1` after transposing.
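Plugging the two identities into the coordinate-ascent updates, with the current potentials used inside the `logsumexp`, gives
$$\begin{aligned}
\mathbf{f}^{(\ell+1)}&=\mathbf{f}^{(\ell)}+\varepsilon\left(\log\mathbf{a}-\operatorname{logsumexp}\left(\mathbf{M}(\mathbf{f}^{(\ell)},\mathbf{g}^{(\ell)}),\ \text{dim}=-1\right)\right),\\
\mathbf{g}^{(\ell+1)}&=\mathbf{g}^{(\ell)}+\varepsilon\left(\log\mathbf{b}-\operatorname{logsumexp}\left(\mathbf{M}(\mathbf{f}^{(\ell+1)},\mathbf{g}^{(\ell)})^{\mathrm{T}},\ \text{dim}=-1\right)\right),
\end{aligned}$$
with $\mathbf{M}(\mathbf{f},\mathbf{g})_{i,j}=-(C_{i,j}-f_i-g_j)/\varepsilon$. The NumPy sketch below (hypothetical names, fixed iteration count, not the linked repository's code) implements this log-domain loop; $\mathbf{P}$ is recovered as $\exp(\mathbf{M})$, since $u_iK_{i,j}v_j=e^{(f_i+g_j-C_{i,j})/\varepsilon}$.

```python
import numpy as np


def logsumexp(x, axis=-1):
    # Stable log(sum(exp(x))) along `axis`.
    x_max = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(x_max, axis=axis) + np.log(np.sum(np.exp(x - x_max), axis=axis))


def modified_cost(C, f, g, eps):
    # M_ij = -(C_ij - f_i - g_j) / eps, via broadcasting (f as a column, g as a row).
    return -(C - f[:, None] - g[None, :]) / eps


def sinkhorn_log(a, b, C, eps, n_iters=500):
    # Log-domain coordinate ascent: K = exp(-C / eps) is never formed.
    f = np.zeros_like(a)
    g = np.zeros_like(b)
    for _ in range(n_iters):
        # f <- f + eps * (log a - logsumexp(M, dim=-1))
        f = f + eps * (np.log(a) - logsumexp(modified_cost(C, f, g, eps), axis=-1))
        # g <- g + eps * (log b - logsumexp(M^T, dim=-1)), using the freshly updated f
        g = g + eps * (np.log(b) - logsumexp(modified_cost(C, f, g, eps).T, axis=-1))
    P = np.exp(modified_cost(C, f, g, eps))  # P_ij = exp((f_i + g_j - C_ij) / eps)
    return f, g, P


rng = np.random.default_rng(0)
C = rng.random((4, 6))
a = np.full(4, 1 / 4)
b = np.full(6, 1 / 6)
f, g, P = sinkhorn_log(a, b, C, eps=0.05)
print(np.abs(P.sum(axis=1) - a).max(), np.abs(P.sum(axis=0) - b).max())
```

Because every exponent goes through the max-shifted `logsumexp`, the potentials are intended to stay finite for small $\varepsilon$ (e.g. around `1e-3`), where the naive loop's exponentials can underflow or overflow; convergence at such small $\varepsilon$ may still take many iterations.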