OT を使用して群衆カウントの問題を解決する
OT + カウント損失 + TV 損失を使用して、
OT の汎化誤差が密度マップおよびベイジアン損失よりも優れていることを証明しました。
OT
2 つの分布バツ={
×私は∣バツ私は∈Rd }i = 1ん,Y = { yj ∣ yj ∈ R d } j = 1 n \mathcal{Y}=\left\{\mathbf{y}_j \mid \mathbf{y}_j \in \mathbb{R}^d\right \}_{j=1}^nY={
yj∣yj∈Rd }j = 1ん
2 つのメジャーμ 、 ν \boldsymbol{\mu},\boldsymbol{\nu}を考えます。メートル、ν ,1 n T μ = 1 n T ν = 1 \mathbf{1}_n^T \ボール記号{\mu}=\mathbf{1}_n^T \ボール記号{\nu}=11nTメートル=1nTn=1
任生成c : X × Y ↦ R + c: \mathcal{X} \times \mathcal{Y} \mapsto \mathbb{R}_{+}c:バツ×Y↦R+
C ij = c ( xi , yj ) \mathbf{C}_{ij}=c\left(\mathbf{x}_i,\mathbf{y}_j\right) を構築します。Cイジ=c( ×私は、yj)
定義:Γ = { γ ∈ R + n × n : γ 1 = μ , γ T 1 = ν } \Gamma=\left\{\ball シンボル{\gamma} \in \mathbb{R}_{+} ^{n \times n}: \ボール記号{\gamma} \mathbf{1}=\ボール記号{\mu},\ボール記号{\gamma}^T \mathbf{1}=\ボール記号{\nu }\右\ } }C={
c∈R+n × n:c1 _=メートル、cT1 _=n }
OT:
W ( μ , ν ) = min γ ∈ Γ ⟨ C , γ ⟩ \mathcal{W}(\ballsymbol{\mu}, \ballsymbol{\nu})=\min _{\gamma \in \Gamma }\angle\mathbf{C}, \gamma\angleW ( μ ,n )=γ ∈ Γ分⟨C 、_c⟩ _
W ( μ , ν ) = max α , β ∈ R n ⟨ α , μ ⟩ + ⟨ β , ν ⟩ st α i + β j ≤ c ( xi , yj ) , ∀ i , j \begin{aligned} \ mathcal{W}(\boldsymbol{\mu}, \boldsymbol{\nu}) & =\max _{\boldsymbol{\alpha}, \boldsymbol{\beta} \in \mathbb{R}^n}\langle \boldsymbol{\alpha}, \boldsymbol{\mu}\rangle+\langle\boldsymbol{\beta}, \boldsymbol{\nu}\rangle\\ &\quad \text { st } \alpha_i+\beta_j \leq c\ left(\mathbf{x}_i, \mathbf{y}_j\right), \forall i, j \end{aligned}W ( μ ,n )=α、 β ∈ Rnマックス⟨ 、 _ん⟩+⟨ b、ん⟩ 駅 _私は+bj≤c( ×私は、yj)、∀i 、_j
DM数
予測された密度マップをz ^ ∈ R + n \hat{\mathbf{z}}\in\mathbb{R}_+^n とします。z^∈R+ん
gtz ∈ R + nの密度マップ\mathbf{z}\in\mathbb{R}_+^nz∈R+ん
カウントロス
ここでのカウントロスの役割: OT は正規化された密度マップを計算するため、数量情報はありません。
ℓ C ( z , z ^ ) = ∣ ∥ z ∥ 1 − ∥ z ^ ∥ 1 ∣ \ell_C(\mathbf{z}, \hat{\mathbf{z}})=\left|\| \mathbf{z}\|_1-\| \hat{\mathbf{z}} \|_1 \right|ℓC( z 、z^ )=∣ ∥ z ∥1−∥z^ ∥1∣次元z 、 z ^ ≥ 0 、 \mathbf{z},\hat{\mathbf{z}}\ge 0
、z 、z^≥0 , _
_
_ {z}})=\left|\sum _{i=1}^n \mathbf{z}_i-\sum _{i=1}^{n}\hat{\mathbf{z }}_i\right|ℓC( z 、z^ )=∣ ∑i = 1んz私は−∑i = 1んz^私は∣
OT損失
ℓ OT ( z , z ^ ) = W ( z ∥ z ∥ 1 , z ^ ∥ z ^ ∥ 1 ) = ⟨ α ∗ , z ∥ z ∥ 1 ⟩ + ⟨ β ∗ , z ^ ∥ z ^ ∥ 1 ⟩ \ ell_{OT}(\mathbf{z}, \hat{\mathbf{z}})=\mathcal{W}\left(\frac{\mathbf{z}}{\|\mathbf{z}\|_1 }, \frac{\hat{\mathbf{z}}}{\|\hat{\mathbf{z}}\|_1}\right)=\left\langle\boldsymbol{\alpha}^*, \frac {\mathbf{z}}{\|\mathbf{z}\|_1}\right\rangle+\left\langle\boldsymbol{\beta}^*, \frac{\hat{\mathbf{z}}}{ \|\hat{\mathbf{z}}\|_1}\right\rangleℓOT( z 、z^ )=W(∥ ∥ _1z、∥z^ ∥1z^)=⟨ _∗、∥ ∥ _1z⟩+⟨b _∗、∥z^ ∥1z^ここで
α ∗ , β ∗ \boldsymbol{\alpha }^*,\boldsymbol{\beta}^*ある∗、b∗ はOT の双対問題の最適解です。
コスト行列はc ( z ( i ) , z ^ ( j ) ) = ∥ z ( i ) − z ^ ( j ) ∥ 2 2 c(\mathbf {z} (i), \hat{\mathbf{z}}(j))=\|\mathbf{z}(i)-\hat{\mathbf{z}}(j)\|_2^2c ( z ( i ) 、z^ (j))=∥ z ( i )−z^ (j)∥22
∂ ℓ OT ( z , z ^ ) ∂ z ^ = β ∗ ∥ z ^ ∥ 1 − ⟨ β ∗ , z ^ ⟩ ∥ z ^ ∥ 1 2 \frac{\partial \ell_{OT}(\mathbf{z} , \hat{\mathbf{z}})}{\partial \hat{\mathbf{z}}}=\frac{\boldsymbol{\beta}^*}{\|\hat{\mathbf{z}} \|_1}-\frac{\left\langle\boldsymbol{\beta}^*, \hat{\mathbf{z}}\right\rangle}{\|\hat{\mathbf{z}}\|_1 ^2}∂z^∂ℓ _OT( z 、z^ )=∥z^ ∥1b∗−∥z^ ∥12⟨b _∗、z^ ⟩
注意すべき点の 1 つは、コード内の OT 損失は次のとおりであるということです。
ℓ OT ( z , z ^ ) = ⟨ ∂ ℓ OT ( z , z ^ ) ∂ z ^ , z ^ ⟩ \ell_{OT}(\mathbf{z}, \hat{\mathbf{z}})= \ left\langle \frac{\partial \ell_{OT}(\mathbf{z}, \hat{\mathbf{z}})}{\partial \hat{\mathbf{z}}}, \hat{\mathbf {z}}\right\rangleℓOT( z 、z^ )=⟨∂z^∂ℓ _OT( z 、z^ )、z^ ⟩
https://github.com/cvlab-stonybrook/DM-Count/issues/29
OT を解決するには、最も原始的なシンクホーン (ログドメインを使用しない) を使用します。
テレビの喪失
これは主に結果を安定させるためです
ℓ TV ( z , z ^ ) = ∥ z ∥ z ∥ 1 − z ^ ∥ z ^ ∥ 1 ∥ TV = 1 2 ∥ z ∥ z ∥ 1 − z ^ ∥ z ^ ∥ 1 ∥ 1 \ell_{TV}( \mathbf{z}, \hat{\mathbf{z}})=\left\|\frac{\mathbf{z}}{\|\mathbf{z}\|_1}-\frac{\hat{\ mathbf{z}}}{\|\hat{\mathbf{z}}\|_1}\right\|_{TV}=\frac{1}{2}\left\|\frac{\mathbf{z }}{\|\mathbf{z}\|_1}-\frac{\hat{\mathbf{z}}}{\|\hat{\mathbf{z}}\|_1}\right\|_1ℓテレビ_( z 、z^ )= ∥ ∥ _1z−∥z^ ∥1z^ テレビ_=21 ∥ ∥ _1z−∥z^ ∥1z^ 1
結果
UCF-QNRF の
作成者モデル: mae 85.76006602669905、mse 150.3385868782564
を実行しました: best_model_7.pth: mae 89.24010239104311、mse 155.59441664755747