Distribution Matching for Crowd Counting: Reading Notes

Uses optimal transport (OT) to tackle the crowd counting problem.

The method combines an OT loss, a count loss, and a TV loss.
The paper also proves that the generalization error bound of its OT-based loss is tighter than those of the Gaussian-smoothed density map loss and the Bayesian Loss.

OT

Consider two point sets $\mathcal{X}=\left\{\mathbf{x}_i \mid \mathbf{x}_i \in \mathbb{R}^d\right\}_{i=1}^n$ and $\mathcal{Y}=\left\{\mathbf{y}_j \mid \mathbf{y}_j \in \mathbb{R}^d\right\}_{j=1}^n$,
and two probability measures $\boldsymbol{\mu},\boldsymbol{\nu}$ on them with $\mathbf{1}_n^T \boldsymbol{\mu}=\mathbf{1}_n^T \boldsymbol{\nu}=1$.

Let the cost function be $c: \mathcal{X} \times \mathcal{Y} \mapsto \mathbb{R}_{+}$,
with cost matrix $\mathbf{C}_{ij}=c\left(\mathbf{x}_i,\mathbf{y}_j\right)$.
The set of transport plans is $\Gamma=\left\{\boldsymbol{\gamma} \in \mathbb{R}_{+}^{n \times n}: \boldsymbol{\gamma} \mathbf{1}=\boldsymbol{\mu},\ \boldsymbol{\gamma}^T \mathbf{1}=\boldsymbol{\nu}\right\}$.

OT (primal form):

$$\mathcal{W}(\boldsymbol{\mu}, \boldsymbol{\nu})=\min_{\boldsymbol{\gamma} \in \Gamma}\langle\mathbf{C}, \boldsymbol{\gamma}\rangle$$

Dual form:

$$\mathcal{W}(\boldsymbol{\mu}, \boldsymbol{\nu})=\max_{\boldsymbol{\alpha}, \boldsymbol{\beta} \in \mathbb{R}^n}\langle\boldsymbol{\alpha}, \boldsymbol{\mu}\rangle+\langle\boldsymbol{\beta}, \boldsymbol{\nu}\rangle \quad \text{s.t. } \alpha_i+\beta_j \leq c\left(\mathbf{x}_i, \mathbf{y}_j\right),\ \forall i, j$$
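
To make the primal problem concrete, here is a minimal sketch that solves the discrete OT linear program with SciPy's `linprog`; the function `exact_ot` and the toy inputs are illustrative assumptions, not anything from the paper or the DM-Count repo.

```python
# Minimal sketch: discrete OT as a linear program (illustrative, not from the DM-Count repo).
import numpy as np
from scipy.optimize import linprog

def exact_ot(mu, nu, C):
    """mu, nu: probability vectors of length n; C: (n, n) cost matrix."""
    n = len(mu)
    # Flatten gamma row-major: gamma[i, j] -> x[i * n + j]
    A_row = np.zeros((n, n * n))          # row-sum constraints: gamma @ 1 = mu
    for i in range(n):
        A_row[i, i * n:(i + 1) * n] = 1.0
    A_col = np.zeros((n, n * n))          # column-sum constraints: gamma.T @ 1 = nu
    for j in range(n):
        A_col[j, j::n] = 1.0
    res = linprog(C.ravel(), A_eq=np.vstack([A_row, A_col]),
                  b_eq=np.concatenate([mu, nu]), bounds=(0, None), method="highs")
    return res.fun, res.x.reshape(n, n)   # optimal cost <C, gamma*> and the transport plan

# Toy usage: 0.25 of the mass must move from point 1 to point 2 at unit cost -> W = 0.25
mu = np.array([0.5, 0.5])
nu = np.array([0.25, 0.75])
C = np.array([[0.0, 1.0], [1.0, 0.0]])
print(exact_ot(mu, nu, C)[0])             # 0.25
```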

DM-Count

Let the predicted density map be $\hat{\mathbf{z}}\in\mathbb{R}_+^n$
and the ground-truth density map be $\mathbf{z}\in\mathbb{R}_+^n$.

count loss

Why the count loss is needed: the OT loss is computed between normalized density maps, so it carries no information about the total count.

$$\ell_C(\mathbf{z}, \hat{\mathbf{z}})=\left| \|\mathbf{z}\|_1-\|\hat{\mathbf{z}}\|_1 \right|$$

Since $\mathbf{z},\hat{\mathbf{z}}\ge 0$, the 1-norms are just sums:

$$\ell_C(\mathbf{z}, \hat{\mathbf{z}})=\left|\sum_{i=1}^n \mathbf{z}_i-\sum_{i=1}^{n}\hat{\mathbf{z}}_i\right|$$
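
A minimal PyTorch sketch of this loss (illustrative names, assuming non-negative density maps; not the repo's exact code):

```python
# Count loss: absolute difference between predicted and ground-truth total counts.
import torch

def count_loss(z, z_hat):
    """z, z_hat: non-negative density maps of any (matching) shape."""
    return torch.abs(z.sum() - z_hat.sum())

# Toy usage: ground truth has 2 people, prediction sums to 2.4 -> loss 0.4
z = torch.tensor([0.0, 1.0, 1.0])
z_hat = torch.tensor([0.3, 0.9, 1.2])
print(count_loss(z, z_hat))  # tensor(0.4000)
```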

OT loss

$$\ell_{OT}(\mathbf{z}, \hat{\mathbf{z}})=\mathcal{W}\left(\frac{\mathbf{z}}{\|\mathbf{z}\|_1}, \frac{\hat{\mathbf{z}}}{\|\hat{\mathbf{z}}\|_1}\right)=\left\langle\boldsymbol{\alpha}^*, \frac{\mathbf{z}}{\|\mathbf{z}\|_1}\right\rangle+\left\langle\boldsymbol{\beta}^*, \frac{\hat{\mathbf{z}}}{\|\hat{\mathbf{z}}\|_1}\right\rangle$$

where $\boldsymbol{\alpha}^*,\boldsymbol{\beta}^*$ are the optimal solutions of the OT dual problem.
The transport cost is $c(\mathbf{z}(i), \hat{\mathbf{z}}(j))=\|\mathbf{z}(i)-\hat{\mathbf{z}}(j)\|_2^2$, i.e. the squared Euclidean distance between the 2D coordinates of pixel locations $i$ and $j$.

$$\frac{\partial \ell_{OT}(\mathbf{z}, \hat{\mathbf{z}})}{\partial \hat{\mathbf{z}}}=\frac{\boldsymbol{\beta}^*}{\|\hat{\mathbf{z}}\|_1}-\frac{\left\langle\boldsymbol{\beta}^*, \hat{\mathbf{z}}\right\rangle}{\|\hat{\mathbf{z}}\|_1^2}$$
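
A small sketch of these two ingredients, assuming the cost is the squared distance between pixel coordinates as described above; `quadratic_cost_matrix`, `ot_loss_grad`, and the argument names are illustrative, not the repo's API:

```python
# Sketch: quadratic transport cost between pixel locations, and the OT-loss gradient
# w.r.t. the (unnormalized) prediction given the optimal dual variable beta*.
import torch

def quadratic_cost_matrix(h, w):
    """C[i, j] = squared Euclidean distance between the i-th and j-th pixel coordinates."""
    ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    coords = torch.stack([ys.flatten(), xs.flatten()], dim=1).float()  # (n, 2), n = h * w
    return torch.cdist(coords, coords, p=2) ** 2                       # (n, n)

def ot_loss_grad(beta_star, z_hat):
    """Gradient formula above; z_hat is the flattened, non-negative predicted density."""
    s = z_hat.sum()                                                    # equals ||z_hat||_1
    return beta_star / s - torch.dot(beta_star, z_hat) / s ** 2        # scalar term broadcasts
```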

One thing to note: in the code, the OT loss is actually computed as

$$\ell_{OT}(\mathbf{z}, \hat{\mathbf{z}})= \left\langle \frac{\partial \ell_{OT}(\mathbf{z}, \hat{\mathbf{z}})}{\partial \hat{\mathbf{z}}}, \hat{\mathbf{z}}\right\rangle$$

https://github.com/cvlab-stonybrook/DM-Count/issues/29
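
One way to read this trick: if the gradient factor is treated as a constant (detached from the autograd graph), the inner product above is a surrogate whose autograd gradient with respect to $\hat{\mathbf{z}}$ is exactly the analytic gradient. A hedged sketch of that interpretation (not the repo's verbatim code):

```python
# Surrogate OT loss: its value is not the OT cost, but its gradient w.r.t. z_hat
# equals the analytic gradient above (my interpretation, not the repo's exact code).
import torch

def ot_loss_surrogate(beta_star, z_hat):
    with torch.no_grad():                  # treat the gradient factor as a constant
        s = z_hat.sum()
        grad = beta_star / s - torch.dot(beta_star, z_hat) / s ** 2
    return torch.dot(grad, z_hat)          # d(surrogate)/d(z_hat) == grad
```

Substituting the gradient formula literally shows this inner product is zero in exact arithmetic, so the logged value of this term is not the OT cost; writing the loss this way only serves to hand the analytic gradient back through autograd.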

To solve the OT problem, the code uses the plain (vanilla) Sinkhorn algorithm, without log-domain stabilization.
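
For reference, a minimal vanilla Sinkhorn iteration for entropy-regularized OT looks like the sketch below (illustrative, not the repo's implementation; `reg` and the iteration count are arbitrary choices):

```python
# Vanilla Sinkhorn (no log-domain stabilization): exp(-C / reg) can underflow for
# small reg or large costs, which is exactly what a log-domain variant avoids.
import torch

def sinkhorn(mu, nu, C, reg=10.0, n_iters=100):
    """mu, nu: probability vectors (n,); C: (n, n) cost matrix; reg: entropic regularization."""
    K = torch.exp(-C / reg)                 # Gibbs kernel
    v = torch.ones_like(nu)
    for _ in range(n_iters):
        u = mu / (K @ v)                    # enforce the mu marginal
        v = nu / (K.T @ u)                  # enforce the nu marginal
    plan = u[:, None] * K * v[None, :]      # transport plan diag(u) K diag(v)
    beta = reg * torch.log(v)               # dual potential for the nu side (up to a constant)
    return plan, beta
```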

TV loss

This term is mainly there to stabilize the result (the paper notes the OT loss approximates low-density regions poorly, and the TV loss compensates for that).

$$\ell_{TV}(\mathbf{z}, \hat{\mathbf{z}})=\left\|\frac{\mathbf{z}}{\|\mathbf{z}\|_1}-\frac{\hat{\mathbf{z}}}{\|\hat{\mathbf{z}}\|_1}\right\|_{TV}=\frac{1}{2}\left\|\frac{\mathbf{z}}{\|\mathbf{z}\|_1}-\frac{\hat{\mathbf{z}}}{\|\hat{\mathbf{z}}\|_1}\right\|_1$$
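
A minimal PyTorch sketch of this term (illustrative; the small `eps` guard for empty maps is my own addition, not necessarily in the paper or repo):

```python
# TV loss: half the L1 distance between the two normalized density maps.
import torch

def tv_loss(z, z_hat, eps=1e-8):
    """z, z_hat: non-negative density maps of the same shape."""
    z_norm = z / (z.sum() + eps)             # normalize to unit mass
    z_hat_norm = z_hat / (z_hat.sum() + eps)
    return 0.5 * torch.abs(z_norm - z_hat_norm).sum()
```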

Results


On UCF-QNRF:
Author's released model: MAE 85.76, MSE 150.34
My run (best_model_7.pth): MAE 89.24, MSE 155.59


Reposted from blog.csdn.net/qq_39942341/article/details/131807168