用OT来解决人群计数问题
用了OT+count loss + TV loss
证明OT的泛化误差比density map和Bayesian Loss
OT
考虑两个分布 X = { x i ∣ x i ∈ R d } i = 1 n \mathcal{X}=\left\{\mathbf{x}_i \mid \mathbf{x}_i \in \mathbb{R}^d\right\}_{i=1}^n X={
xi∣xi∈Rd}i=1n, Y = { y j ∣ y j ∈ R d } j = 1 n \mathcal{Y}=\left\{\mathbf{y}_j \mid \mathbf{y}_j \in \mathbb{R}^d\right\}_{j=1}^n Y={
yj∣yj∈Rd}j=1n
考虑两个测度 μ , ν \boldsymbol{\mu},\boldsymbol{\nu} μ,ν, 1 n T μ = 1 n T ν = 1 \mathbf{1}_n^T \boldsymbol{\mu}=\mathbf{1}_n^T \boldsymbol{\nu}=1 1nTμ=1nTν=1
设代价 c : X × Y ↦ R + c: \mathcal{X} \times \mathcal{Y} \mapsto \mathbb{R}_{+} c:X×Y↦R+
代价矩阵 C i j = c ( x i , y j ) \mathbf{C}_{ij}=c\left(\mathbf{x}_i,\mathbf{y}_j\right) Cij=c(xi,yj)
传输矩阵: Γ = { γ ∈ R + n × n : γ 1 = μ , γ T 1 = ν } \Gamma=\left\{\boldsymbol{\gamma} \in \mathbb{R}_{+}^{n \times n}: \boldsymbol{\gamma} \mathbf{1}=\boldsymbol{\mu},\boldsymbol{\gamma}^T \mathbf{1}=\boldsymbol{\nu}\right\} Γ={
γ∈R+n×n:γ1=μ,γT1=ν}
OT:
W ( μ , ν ) = min γ ∈ Γ ⟨ C , γ ⟩ \mathcal{W}(\boldsymbol{\mu}, \boldsymbol{\nu})=\min _{\gamma \in \Gamma}\langle\mathbf{C}, \gamma\rangle W(μ,ν)=γ∈Γmin⟨C,γ⟩
W ( μ , ν ) = max α , β ∈ R n ⟨ α , μ ⟩ + ⟨ β , ν ⟩ s.t. α i + β j ≤ c ( x i , y j ) , ∀ i , j \begin{aligned} \mathcal{W}(\boldsymbol{\mu}, \boldsymbol{\nu}) & =\max _{\boldsymbol{\alpha}, \boldsymbol{\beta} \in \mathbb{R}^n}\langle\boldsymbol{\alpha}, \boldsymbol{\mu}\rangle+\langle\boldsymbol{\beta}, \boldsymbol{\nu}\rangle\\ &\quad \text { s.t. } \alpha_i+\beta_j \leq c\left(\mathbf{x}_i, \mathbf{y}_j\right), \forall i, j \end{aligned} W(μ,ν)=α,β∈Rnmax⟨α,μ⟩+⟨β,ν⟩ s.t. αi+βj≤c(xi,yj),∀i,j
DM-count
设预测的density map为 z ^ ∈ R + n \hat{\mathbf{z}}\in\mathbb{R}_+^n z^∈R+n
gt的density map为 z ∈ R + n \mathbf{z}\in\mathbb{R}_+^n z∈R+n
count loss
这里count loss的作用:因为OT算的归一化的density map,他没有数量信息
ℓ C ( z , z ^ ) = ∣ ∥ z ∥ 1 − ∥ z ^ ∥ 1 ∣ \ell_C(\mathbf{z}, \hat{\mathbf{z}})=\left|\| \mathbf{z}\|_1-\| \hat{\mathbf{z}} \|_1 \right| ℓC(z,z^)=∣∥z∥1−∥z^∥1∣
由于 z , z ^ ≥ 0 , \mathbf{z},\hat{\mathbf{z}}\ge 0, z,z^≥0,,可以用求和代替1范数
即
ℓ C ( z , z ^ ) = ∣ ∑ i = 1 n z i − ∑ i = 1 n z ^ i ∣ \ell_C(\mathbf{z}, \hat{\mathbf{z}})=\left|\sum_{i=1}^n \mathbf{z}_i-\sum_{i=1}^{n}\hat{\mathbf{z}}_i\right| ℓC(z,z^)=∣∑i=1nzi−∑i=1nz^i∣
OT loss
ℓ O T ( z , z ^ ) = W ( z ∥ z ∥ 1 , z ^ ∥ z ^ ∥ 1 ) = ⟨ α ∗ , z ∥ z ∥ 1 ⟩ + ⟨ β ∗ , z ^ ∥ z ^ ∥ 1 ⟩ \ell_{O T}(\mathbf{z}, \hat{\mathbf{z}})=\mathcal{W}\left(\frac{\mathbf{z}}{\|\mathbf{z}\|_1}, \frac{\hat{\mathbf{z}}}{\|\hat{\mathbf{z}}\|_1}\right)=\left\langle\boldsymbol{\alpha}^*, \frac{\mathbf{z}}{\|\mathbf{z}\|_1}\right\rangle+\left\langle\boldsymbol{\beta}^*, \frac{\hat{\mathbf{z}}}{\|\hat{\mathbf{z}}\|_1}\right\rangle ℓOT(z,z^)=W(∥z∥1z,∥z^∥1z^)=⟨α∗,∥z∥1z⟩+⟨β∗,∥z^∥1z^⟩
其中 α ∗ , β ∗ \boldsymbol{\alpha}^*,\boldsymbol{\beta}^* α∗,β∗为OT的对偶问题的最优解
代价矩阵用的是 c ( z ( i ) , z ^ ( j ) ) = ∥ z ( i ) − z ^ ( j ) ∥ 2 2 c(\mathbf{z}(i), \hat{\mathbf{z}}(j))=\|\mathbf{z}(i)-\hat{\mathbf{z}}(j)\|_2^2 c(z(i),z^(j))=∥z(i)−z^(j)∥22
∂ ℓ O T ( z , z ^ ) ∂ z ^ = β ∗ ∥ z ^ ∥ 1 − ⟨ β ∗ , z ^ ⟩ ∥ z ^ ∥ 1 2 \frac{\partial \ell_{O T}(\mathbf{z}, \hat{\mathbf{z}})}{\partial \hat{\mathbf{z}}}=\frac{\boldsymbol{\beta}^*}{\|\hat{\mathbf{z}}\|_1}-\frac{\left\langle\boldsymbol{\beta}^*, \hat{\mathbf{z}}\right\rangle}{\|\hat{\mathbf{z}}\|_1^2} ∂z^∂ℓOT(z,z^)=∥z^∥1β∗−∥z^∥12⟨β∗,z^⟩
要注意一个问题,代码里,它的OT loss是
ℓ O T ( z , z ^ ) = ⟨ ∂ ℓ O T ( z , z ^ ) ∂ z ^ , z ^ ⟩ \ell_{O T}(\mathbf{z}, \hat{\mathbf{z}})= \left\langle \frac{\partial \ell_{O T}(\mathbf{z}, \hat{\mathbf{z}})}{\partial \hat{\mathbf{z}}}, \hat{\mathbf{z}}\right\rangle ℓOT(z,z^)=⟨∂z^∂ℓOT(z,z^),z^⟩
https://github.com/cvlab-stonybrook/DM-Count/issues/29
求解OT,用的最原始的sinkhorn(没有log-domain
TV loss
这里主要是为了稳定结果
ℓ T V ( z , z ^ ) = ∥ z ∥ z ∥ 1 − z ^ ∥ z ^ ∥ 1 ∥ T V = 1 2 ∥ z ∥ z ∥ 1 − z ^ ∥ z ^ ∥ 1 ∥ 1 \ell_{T V}(\mathbf{z}, \hat{\mathbf{z}})=\left\|\frac{\mathbf{z}}{\|\mathbf{z}\|_1}-\frac{\hat{\mathbf{z}}}{\|\hat{\mathbf{z}}\|_1}\right\|_{T V}=\frac{1}{2}\left\|\frac{\mathbf{z}}{\|\mathbf{z}\|_1}-\frac{\hat{\mathbf{z}}}{\|\hat{\mathbf{z}}\|_1}\right\|_1 ℓTV(z,z^)= ∥z∥1z−∥z^∥1z^ TV=21 ∥z∥1z−∥z^∥1z^ 1
结果
在UCF-QNRF上
作者模型: mae 85.76006602669905, mse 150.3385868782564
我跑的:best_model_7.pth: mae 89.24010239104311, mse 155.59441664755747