Study notes: Approximating Wasserstein distances with PyTorch

https://github.com/dfdazac/wassdistance/tree/master

Prerequisites

Computational Optimal Transport (study notes)
It is enough to read as far as the coordinate ascent on the entropic dual.

$$
\mathrm{L}_{\mathbf{C}}^{\varepsilon}(\mathbf{a}, \mathbf{b}) \stackrel{\text{def.}}{=} \min_{\mathbf{P} \in \mathbf{U}(\mathbf{a}, \mathbf{b})} \langle\mathbf{P}, \mathbf{C}\rangle - \varepsilon \mathbf{H}(\mathbf{P})
$$

$$
\mathbf{U}(\mathbf{a}, \mathbf{b}) \stackrel{\text{def.}}{=} \left\{\mathbf{P} \in \mathbb{R}_{+}^{n \times m} : \mathbf{P} \mathbf{1}_m = \mathbf{a} \quad \text{and} \quad \mathbf{P}^{\mathrm{T}} \mathbf{1}_n = \mathbf{b}\right\}
$$

Dual

$$
\mathrm{L}_{\mathbf{C}}^{\varepsilon}(\mathbf{a}, \mathbf{b}) = \max_{\mathbf{f} \in \mathbb{R}^n,\, \mathbf{g} \in \mathbb{R}^m} \langle\mathbf{f}, \mathbf{a}\rangle + \langle\mathbf{g}, \mathbf{b}\rangle - \varepsilon \left\langle e^{\mathbf{f}/\varepsilon}, \mathbf{K} e^{\mathbf{g}/\varepsilon}\right\rangle
$$

With the change of variables

$$
(\mathbf{u}, \mathbf{v}) = \left(e^{\mathbf{f}/\varepsilon}, e^{\mathbf{g}/\varepsilon}\right),
$$

the optimal plan takes the form

$$
\mathbf{P} = \operatorname{diag}(\mathbf{u})\, \mathbf{K}\, \operatorname{diag}(\mathbf{v}), \quad \mathbf{K} = \exp\left(-\frac{\mathbf{C}}{\varepsilon}\right),
$$

where the exponential is taken elementwise.

Coordinate ascent

$$
\begin{aligned}
\mathbf{f}^{(\ell+1)} &= \varepsilon \log \mathbf{a} - \varepsilon \log\left(\mathbf{K} e^{\mathbf{g}^{(\ell)}/\varepsilon}\right), \\
\mathbf{g}^{(\ell+1)} &= \varepsilon \log \mathbf{b} - \varepsilon \log\left(\mathbf{K}^{\mathrm{T}} e^{\mathbf{f}^{(\ell+1)}/\varepsilon}\right).
\end{aligned}
$$
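The two updates above can be tried out directly. The following is a minimal sketch, not the repository's code: the function name `dual_coordinate_ascent`, the uniform marginals, the random cost matrix, and $\varepsilon = 0.5$ are all illustrative assumptions. It forms $\mathbf{K}$ explicitly, so it is only safe when $\mathbf{C}/\varepsilon$ is moderate. The entropy is assumed to be $\mathbf{H}(\mathbf{P}) = -\sum_{i,j} P_{i,j}(\log P_{i,j} - 1)$, the convention used in Computational Optimal Transport, which these notes do not restate.

```python
import torch

def dual_coordinate_ascent(a, b, C, eps=0.5, n_iters=100):
    """Alternate the closed-form f and g updates, forming K explicitly."""
    K = torch.exp(-C / eps)            # Gibbs kernel, shape (n, m)
    f = torch.zeros_like(a)            # dual potential paired with a, shape (n,)
    g = torch.zeros_like(b)            # dual potential paired with b, shape (m,)
    for _ in range(n_iters):
        f = eps * torch.log(a) - eps * torch.log(K @ torch.exp(g / eps))
        g = eps * torch.log(b) - eps * torch.log(K.T @ torch.exp(f / eps))
    return f, g, K

torch.manual_seed(0)
n, m, eps = 4, 5, 0.5
a = torch.full((n,), 1.0 / n, dtype=torch.float64)   # strictly positive marginals (assumption)
b = torch.full((m,), 1.0 / m, dtype=torch.float64)
C = torch.rand(n, m, dtype=torch.float64)            # illustrative cost in [0, 1)

f, g, K = dual_coordinate_ascent(a, b, C, eps=eps)
u, v = torch.exp(f / eps), torch.exp(g / eps)
P = u[:, None] * K * v[None, :]                      # P = diag(u) K diag(v)

# Distance of P from the constraint set U(a, b)
row_err = (P.sum(dim=1) - a).abs().max().item()
col_err = (P.sum(dim=0) - b).abs().max().item()

# Primal value <P, C> - eps * H(P), with H(P) = -sum P (log P - 1) (assumed convention)
primal = ((P * C).sum() + eps * (P * (torch.log(P) - 1)).sum()).item()
# Dual value <f, a> + <g, b> - eps * <u, K v>
dual = ((f * a).sum() + (g * b).sum() - eps * (u * (K @ v)).sum()).item()
print(row_err, col_err, abs(primal - dual))
```

At convergence both marginal errors should be close to zero and the primal and dual values should agree, which is a convenient sanity check on the formulas above.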

The code implements these updates in a slightly different form.

Consider $\mathbf{C}\in\mathbb{R}^{n\times m}$, $\mathbf{f}\in\mathbb{R}^n$, $\mathbf{g}\in\mathbb{R}^m$.

$$
\begin{aligned}
&\log\left(\mathbf{K} e^{\mathbf{g}/\varepsilon}\right)\\
=&\log\left(\left[\sum_{j}e^{-\frac{C_{i,j}-g_j}{\varepsilon}}\right]_i\right)\\
=&\log\left(\left[\sum_{j}e^{-\frac{C_{i,j}-g_j}{\varepsilon}}e^{\frac{f_i}{\varepsilon}}e^{-\frac{f_i}{\varepsilon}}\right]_i\right)\\
=&\log\left(\left[\sum_{j}e^{-\frac{C_{i,j}-f_i-g_j}{\varepsilon}}\right]_i\odot e^{-\frac{\mathbf{f}}{\varepsilon}}\right)\\
=&\log\left(\left[\sum_{j}e^{-\frac{C_{i,j}-f_i-g_j}{\varepsilon}}\right]_i\right)-\frac{\mathbf{f}}{\varepsilon}\\
=&\operatorname{logsumexp}\left(-\frac{\mathbf{C}-\mathbf{f}\mathbf{1}_m^{\mathrm{T}}-\mathbf{1}_n\mathbf{g}^{\mathrm{T}}}{\varepsilon},\ \mathrm{dim}=-1\right)-\frac{\mathbf{f}}{\varepsilon}
\end{aligned}
$$

In the last step the $n\times m$ matrix with entries $C_{i,j}-f_i-g_j$ is never assembled from explicit outer products. It comes from broadcasting: $\mathbf{f}$ is added as an $n\times 1$ column and $\mathbf{g}$ as a $1\times m$ row, and the reduction over $\mathrm{dim}=-1$ sums over $j$.
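The identity just derived can be checked numerically. The sketch below is an illustration on arbitrary random inputs, not code taken from the repository. It compares a direct evaluation of $\log\left(\mathbf{K} e^{\mathbf{g}/\varepsilon}\right)$ with the logsumexp form, using `unsqueeze` to obtain the broadcasting described above.

```python
import torch

torch.manual_seed(0)
n, m, eps = 3, 4, 0.5                      # illustrative sizes and regularisation
C = torch.rand(n, m, dtype=torch.float64)
f = torch.randn(n, dtype=torch.float64)
g = torch.randn(m, dtype=torch.float64)

# Direct evaluation of log(K e^{g/eps}), shape (n,)
K = torch.exp(-C / eps)
direct = torch.log(K @ torch.exp(g / eps))

# Broadcasting: f becomes an (n, 1) column, g a (1, m) row,
# so M[i, j] = -(C[i, j] - f[i] - g[j]) / eps
M = (-C + f.unsqueeze(-1) + g.unsqueeze(-2)) / eps
via_logsumexp = torch.logsumexp(M, dim=-1) - f / eps

max_diff = (direct - via_logsumexp).abs().max().item()
print(max_diff)
```

The two results should agree up to floating-point rounding. The practical difference is that the logsumexp form never exponentiates $-\mathbf{C}/\varepsilon$ on its own, which matters when $\varepsilon$ is small.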

$$
\begin{aligned}
&\log\left(\mathbf{K}^{\mathrm{T}} e^{\mathbf{f}/\varepsilon}\right)\\
=&\log\left(\left[\sum_{i}e^{-\frac{C_{i,j}-f_i}{\varepsilon}}\right]_j\right)\\
=&\log\left(\left[\sum_{i}e^{-\frac{C_{i,j}-f_i}{\varepsilon}}e^{\frac{g_j}{\varepsilon}}e^{-\frac{g_j}{\varepsilon}}\right]_j\right)\\
=&\log\left(\left[\sum_{i}e^{-\frac{C_{i,j}-f_i-g_j}{\varepsilon}}\right]_j\odot e^{-\frac{\mathbf{g}}{\varepsilon}}\right)\\
=&\log\left(\left[\sum_{i}e^{-\frac{C_{i,j}-f_i-g_j}{\varepsilon}}\right]_j\right)-\frac{\mathbf{g}}{\varepsilon}\\
=&\operatorname{logsumexp}\left(-\frac{\mathbf{C}-\mathbf{f}\mathbf{1}_m^{\mathrm{T}}-\mathbf{1}_n\mathbf{g}^{\mathrm{T}}}{\varepsilon},\ \mathrm{dim}=-2\right)-\frac{\mathbf{g}}{\varepsilon}\\
=&\operatorname{logsumexp}\left(-\frac{\left(\mathbf{C}-\mathbf{f}\mathbf{1}_m^{\mathrm{T}}-\mathbf{1}_n\mathbf{g}^{\mathrm{T}}\right)^{\mathrm{T}}}{\varepsilon},\ \mathrm{dim}=-1\right)-\frac{\mathbf{g}}{\varepsilon}
\end{aligned}
$$
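Substituting the two identities into the coordinate-ascent updates gives $\mathbf{f} \leftarrow \varepsilon\left(\log\mathbf{a} - \operatorname{logsumexp}(\mathbf{M}, \mathrm{dim}=-1)\right) + \mathbf{f}$ with $M_{i,j} = -(C_{i,j}-f_i-g_j)/\varepsilon$, and the analogous update for $\mathbf{g}$ on the transposed matrix. The following is a minimal sketch of the resulting loop, written in the spirit of the linked repository but not copied from it. The function name `log_domain_sinkhorn`, the iteration count, and the test data are assumptions, and $\mathbf{a}$, $\mathbf{b}$ are assumed strictly positive so that their logarithms are finite.

```python
import torch

def log_domain_sinkhorn(a, b, C, eps, n_iters=100):
    """Coordinate ascent on (f, g) using only logsumexp, never forming K."""
    f = torch.zeros_like(a)
    g = torch.zeros_like(b)

    def M(f, g):
        # M[i, j] = -(C[i, j] - f[i] - g[j]) / eps, built by broadcasting
        return (-C + f.unsqueeze(-1) + g.unsqueeze(-2)) / eps

    for _ in range(n_iters):
        # eps*log(a) - eps*log(K e^{g/eps}), with the log written as logsumexp(M) - f/eps
        f = eps * (torch.log(a) - torch.logsumexp(M(f, g), dim=-1)) + f
        # Same for g; the transpose lets the reduction run over dim=-1 again
        g = eps * (torch.log(b) - torch.logsumexp(M(f, g).transpose(-2, -1), dim=-1)) + g

    P = torch.exp(M(f, g))      # entrywise equal to diag(u) K diag(v)
    return P, (P * C).sum()     # plan and transport cost <P, C>

torch.manual_seed(0)
n, m = 4, 5
a = torch.full((n,), 1.0 / n, dtype=torch.float64)
b = torch.full((m,), 1.0 / m, dtype=torch.float64)
C = torch.rand(n, m, dtype=torch.float64)    # illustrative cost in [0, 1)

P, cost = log_domain_sinkhorn(a, b, C, eps=0.5)
row_err = (P.sum(dim=1) - a).abs().max().item()
col_err = (P.sum(dim=0) - b).abs().max().item()

# With a very small eps, exp(-C/eps) can underflow to zero for the larger costs,
# while the log-domain loop still produces finite values.
P_small, _ = log_domain_sinkhorn(a, b, C, eps=1e-3)
all_finite = bool(torch.isfinite(P_small).all())
print(cost.item(), row_err, col_err, all_finite)
```

Keeping the iteration in terms of $\mathbf{f}$ and $\mathbf{g}$ rather than $\mathbf{u}$ and $\mathbf{v}$ is what makes the small-$\varepsilon$ case workable: every exponential is taken inside a logsumexp, which subtracts the maximum before exponentiating.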

Reposted from blog.csdn.net/qq_39942341/article/details/131751760