[论文阅读笔记18] DiffusionDet论文笔记与代码解读


扩散模型近期在图像生成领域很火, 没想到很快就被用在了检测上. 打算对这篇论文做一个笔记.

论文地址: 论文
代码: 代码


0. 扩散模型简述

首先介绍什么是扩散模型. 我们考虑生成任务, 即encoder-decoder形式的模型, encoder提取输入的抽象信息, 并尝试在decoder中恢复出来. 扩散模型就是这一类中的方法, 其灵感由热力学而来, 基本做法是在输入中逐步加噪, 并学会如何在噪声中恢复出输入. 假定加噪的过程为Markov过程.

扩散模型和GAN, VAE虽然同为生成式模型, 但其思想不同. GAN是将模型分为生成器与鉴别器两个部分, 生成器的目的是让鉴别器分不出她的输出并非来自于真实数据集合, 而鉴别器的目的是不要被生成器欺骗. 这种博弈的方式有的时候也会陷入一些困境(例如难以到达纳什均衡). VAE得到的潜在变量(latent variable)的维度是小于输入的, 而扩散模型的中间变量的维度与输入相同.

0.1 加噪的前向过程

假定原始数据服从分布 x 0 ∼ q ( x ) \mathbf{x}_0\sim q(\mathbf{x}) x0q(x), 现在我们逐步对其加噪, 加入的是高斯噪声. 对于每一步加噪, 我们希望将分布 q q q逐渐向高斯过程靠近, 也即让 q ( x t ∣ x t − 1 ) = N q(\mathbf{x}_t|\mathbf{x}_{t-1})=\mathcal{N} q(xtxt1)=N. 在每一步, 我们假定高斯分布的均值与过去的值 x t − 1 \mathbf{x}_{t-1} xt1有关, 而协方差为固定值(对角阵):

q ( x t ∣ x t − 1 ) = N ( x t ; 1 − β t x t − 1 , β t I ) q(\mathbf{x}_t|\mathbf{x}_{t-1})=\mathcal{N}(\mathbf{x}_t;\sqrt{1-\beta_t}\mathbf{x}_{t-1}, \beta_t \mathbf{I}) q(xtxt1)=N(xt;1βt xt1,βtI)

其中 β t \beta_t βt为一小正数, 在0~1之间.

注意我们假定加噪过程为Markov过程, 因此当前状态 x t \mathbf{x}_t xt只假定与上一状态 x t − 1 \mathbf{x}_{t-1} xt1有关.

因此, 当前时刻 x t \mathbf{x}_t xt是由前一时刻 x t − 1 \mathbf{x}_{t-1} xt1决定的正态分布, 其均值为 1 − β t x t − 1 \sqrt{1-\beta_t}\mathbf{x}_{t-1} 1βt xt1, 方差为 β t I \beta_t \mathbf{I} βtI. 为了表示 x t \mathbf{x}_t xt, 这里我们使用一下重参数化(Re-parametrization)技巧. 重参数化是说, 如果我们从一个高斯分布中取样, 也等效于从标准高斯分布中取样, 只不过是加上均值, 以及乘以标准差. 这是因为一个 x ∼ N ( μ , σ 2 ) x\sim\mathcal{N}(\mu, \sigma^2) xN(μ,σ2)的高斯分布可以等价于 μ + σ ϵ \mu+\sigma \epsilon μ+σϵ, 其中 ϵ ∼ N ( 0 , 1 ) \epsilon\sim\mathcal{N}(0,1) ϵN(0,1).

由高斯分布性质立即得到.

因此, x t \mathbf{x}_t xt可以表示为:

x t = 1 − β t x t − 1 + β t ϵ t − 1 \mathbf{x}_t=\sqrt{1-\beta_t} \mathbf{x}_{t-1}+\sqrt{\beta_t}\epsilon_{t-1} xt=1βt xt1+βt ϵt1

其中 ϵ t − 1 ∼ N ( 0 , I ) \epsilon_{t-1}\sim\mathcal{N}(0,I) ϵt1N(0,I). 为了表示方便, 令 1 − β t = α t \sqrt{1-\beta_t}=\sqrt{\alpha_t} 1βt =αt , 将上式递归展开, 我们有:

x t = α t ( α t − 1 x t − 2 + 1 − α t − 1 ϵ t − 2 ) + β t ϵ t − 1 \mathbf{x}_t=\sqrt{\alpha_t}(\sqrt{\alpha_{t-1}}\mathbf{x}_{t-2}+\sqrt{1-\alpha_{t-1}}\epsilon_{t-2} )+\sqrt{\beta_t}\epsilon_{t-1} xt=αt (αt1 xt2+1αt1 ϵt2)+βt ϵt1

我们注意到后面两项可以合并为一个新的高斯分布, 其均值为0, 方差为 1 − α t α t − 1 1-\alpha_t\alpha_{t-1} 1αtαt1, 按此规律展开, 我们得到:

x t = Π i α i x 0 + 1 − Π i α i ϵ ,    ϵ ∼ N ( 0 , I ) \mathbf{x}_t=\sqrt{\Pi_i\alpha_i}\mathbf{x}_0+\sqrt{1-\Pi_i\alpha_i}\epsilon, ~~\epsilon\sim \mathcal{N}(0,I) xt=Πiαi x0+1Πiαi ϵ,  ϵN(0,I)

所以, 我们可以直接从 x 0 \mathbf{x}_0 x0得到 x t \mathbf{x}_t xt的分布:

q ( x t ∣ x 0 ) = N ( x t ; Π i α i x 0 , 1 − Π i α i I ) q(\mathbf{x}_t|\mathbf{x}_0)=\mathcal{N}(\mathbf{x}_t;\sqrt{\Pi_i\alpha_i}\mathbf{x}_0, \sqrt{1-\Pi_i\alpha_i} \mathbf{I}) q(xtx0)=N(xt;Πiαi x0,1Πiαi I)

所以, 随着时间的增加, x t \mathbf{x}_t xt会越来越趋向于标准正态分布. 以上就是加噪的过程.

0.2 去噪的反向过程

我们假定, 在加噪的正向过程中最后的结果已经近似为标准高斯分布 x T ∼ N ( 0 , I ) \mathbf{x}_T \sim \mathcal{N}(0,\mathbf{I}) xTN(0,I). 我们现在希望从加噪后的高斯分布中恢复出原来的信号, 即, 通过逐步计算 q ( x t − 1 ∣ x t ) q(\mathbf{x}_{t-1}|\mathbf{x}_t) q(xt1xt)恢复. 然而, 如果这样计算的话, 需要从整个数据集中采样, 计算量非常大(有可能是因为类似于高斯混合模型的过程), 为此, 我们希望学习出一个模型 p θ p_\theta pθ来学习恢复过程中的条件概率:
p θ ( x t − 1 ∣ x t ) = N ( x t − 1 , μ θ ( x t , t ) , Σ θ ( x t , t ) ) p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)=\mathcal{N}(\mathbf{x}_{t-1},\mu_\theta(\mathbf{x}_t,t), \Sigma_\theta(\mathbf{x}_t,t)) pθ(xt1xt)=N(xt1,μθ(xt,t),Σθ(xt,t))

我们需要做的是让分布 p ( x t − 1 ∣ x t ) p(\mathbf{x}_{t-1}|\mathbf{x}_t) p(xt1xt)尽可能与 q ( x t − 1 ∣ x t ) q(\mathbf{x}_{t-1}|\mathbf{x}_t) q(xt1xt)接近.

我们很难计算 q ( x t − 1 ∣ x t ) q(\mathbf{x}_{t-1}|\mathbf{x}_t) q(xt1xt), 但可以考察以 x 0 为条件的以下概率 \mathbf{x}_0为条件的以下概率 x0为条件的以下概率:

q ( x t − 1 ∣ x t , x 0 ) q(\mathbf{x}_{t-1}|\mathbf{x}_t, \mathbf{x}_0) q(xt1xt,x0)

根据Bays公式, 有:

q ( x t − 1 ∣ x t , x 0 ) = q ( x t , x t − 1 , ∣ x 0 ) q ( x t ∣ x 0 ) = q ( x t ∣ x t − 1 , x 0 ) q ( x t − 1 ∣ x 0 ) q ( x t ∣ x 0 ) q(\mathbf{x}_{t-1}|\mathbf{x}_t, \mathbf{x}_0)=\frac{q(\mathbf{x}_t,\mathbf{x}_{t-1},|\mathbf{x}_0)}{q(\mathbf{x}_t|\mathbf{x}_0)}=q(\mathbf{x}_t|\mathbf{x}_{t-1}, \mathbf{x}_0)\frac{q(\mathbf{x}_{t-1}|\mathbf{x}_0)}{q(\mathbf{x}_t|\mathbf{x}_0)} q(xt1xt,x0)=q(xtx0)q(xt,xt1,x0)=q(xtxt1,x0)q(xtx0)q(xt1x0)

扩散过程为Markov过程, 因此有:

q ( x t − 1 ∣ x t , x 0 ) = q ( x t ∣ x t − 1 ) q ( x t − 1 ∣ x 0 ) q ( x t ∣ x 0 ) q(\mathbf{x}_{t-1}|\mathbf{x}_t, \mathbf{x}_0)=q(\mathbf{x}_t|\mathbf{x}_{t-1})\frac{q(\mathbf{x}_{t-1}|\mathbf{x}_0)}{q(\mathbf{x}_t|\mathbf{x}_0)} q(xt1xt,x0)=q(xtxt1)q(xtx0)q(xt1x0)

代入高斯分布表达式, 并凑出均值和方差(整理成 exp ⁡ { 1 2 σ 2 ( x t − μ ) 2 } \exp\{\frac{1}{2\sigma^2}(x_t-\mu)^2\} exp{ 2σ21(xtμ)2}的形式), 我们得到 q ( x t − 1 ∣ x t , x 0 ) q(\mathbf{x}_{t-1}|\mathbf{x}_t, \mathbf{x}_0) q(xt1xt,x0)均值为:

μ = α t ( 1 − Π i T − 1 α i ) 1 − Π i T α i x t + Π i T − 1 α i ( 1 − α t ) 1 − Π i T α i x 0 \mu=\frac{\sqrt{\alpha_t}(1-\Pi_i^{T-1}\alpha_i)}{1-\Pi_i^{T}\alpha_i}\mathbf{x}_{t}+\frac{\sqrt{\Pi_i^{T-1}\alpha_i}(1-\alpha_t)}{1-\Pi_i^{T}\alpha_i}\mathbf{x}_{0} μ=1ΠiTαiαt (1ΠiT1αi)xt+1ΠiTαiΠiT1αi (1αt)x0
根据前面的重参数化技巧, 有 x t = Π i α i x 0 + 1 − Π i α i ϵ t ,    ϵ t \mathbf{x}_t=\sqrt{\Pi_i\alpha_i}\mathbf{x}_0+\sqrt{1-\Pi_i\alpha_i}\epsilon_t, ~~\epsilon_t xt=Πiαi x0+1Πiαi ϵt,  ϵt为网络在这一步预测的高斯噪声, 代入上式得到:

μ = 1 α t ( x t − 1 − α t 1 − Π i T α i ϵ ) \mu=\frac{1}{\sqrt{\alpha_t}}(\mathbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-\Pi_i^{T}\alpha_i}}\epsilon) μ=αt 1(xt1ΠiTαi 1αtϵ)

方差为:

Σ = 1 − Π i T − 1 α i 1 − Π i T α i \Sigma=\frac{1-\Pi_i^{T-1}\alpha_i}{1-\Pi_i^{T}\alpha_i} Σ=1ΠiTαi1ΠiT1αi

所以

q ( x t − 1 ∣ x t , x 0 ) ∼ N ( μ , Σ ) q(\mathbf{x}_{t-1}|\mathbf{x}_t, \mathbf{x}_0) \sim \mathcal{N}(\mu, \Sigma) q(xt1xt,x0)N(μ,Σ)

所以, 逆扩散的过程为: 根据网络从 x t \mathbf{x}_t xt预测的噪声 ϵ t \epsilon_t ϵt计算出均值与方差, 进而得到 q ( x t − 1 ∣ x t , x 0 ) q(\mathbf{x}_{t-1}|\mathbf{x}_t, \mathbf{x}_0) q(xt1xt,x0), 作为 p θ ( x t − 1 ∣ x t ) p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t) pθ(xt1xt)的近似, 如此得到 x t − 1 \mathbf{x}_{t-1} xt1, 再根据 x t − 1 \mathbf{x}_{t-1} xt1预测出下一步的噪声 ϵ t − 1 \epsilon_{t-1} ϵt1, 如此往复, 如下图所示(图源知乎)

在这里插入图片描述

0.3 采样过程的加速

然而, 如果按照上述方式更新, 则采样速度非常慢. 一种方式是我们可以跨步采样, 也就是一共 T T T个恢复时长, 我们每隔 ⌈ T / S ⌉ \lceil T/S \rceil T/S步采样一次, 这样只需要采样 S S S次.

另一种方法是, 我们直接通过前向加噪过程的变形来计算当前的恢复过程. 前向加噪过程与原始输入 x 0 \mathbf{x}_0 x0的关系为:

x t = Π i α i x 0 + 1 − Π i α i ϵ ,    ϵ ∼ N ( 0 , I ) \mathbf{x}_t=\sqrt{\Pi_i\alpha_i}\mathbf{x}_0+\sqrt{1-\Pi_i\alpha_i}\epsilon, ~~\epsilon\sim \mathcal{N}(0,I) xt=Πiαi x0+1Πiαi ϵ,  ϵN(0,I)

为了表示方便, 下面以 α ˉ k \bar{\alpha}_k αˉk表示 Π i = 1 k α i \Pi_{i=1}^k\alpha_i Πi=1kαi.

在噪声恢复过程中, 我们以网络预测的噪声 ϵ t \epsilon_t ϵt估计加噪过程中加入的噪声, 即

x t = α ˉ t x 0 + 1 − α ˉ t ϵ t \mathbf{x}_t=\sqrt{\bar{\alpha}_t}\mathbf{x}_0+\sqrt{1-\bar{\alpha}_t}\epsilon_t xt=αˉt x0+1αˉt ϵt

翻转的过程用 x t \mathbf{x}_t xt估计 x t − 1 \mathbf{x}_{t-1} xt1, 将上式的 t t t换成 t − 1 t-1 t1, 有:

x t − 1 = α ˉ t − 1 x 0 + 1 − α ˉ t − 1 ϵ t − 1 \mathbf{x}_{t-1}=\sqrt{\bar{\alpha}_{t-1}}\mathbf{x}_0+\sqrt{1-\bar{\alpha}_{t-1}}\epsilon_{t-1} xt1=αˉt1 x0+1αˉt1 ϵt1

但我们在 x t \mathbf{x}_t xt时刻只能得到该时刻的噪声预测 ϵ t \epsilon_t ϵt, 因此对上式做恒等变换:

x t − 1 = α ˉ t − 1 x 0 + 1 − α ˉ t − 1 ϵ t − 1 = α ˉ t − 1 x 0 + 1 − α ˉ t − 1 − σ t 2 ϵ t + σ t ϵ \begin{aligned} \mathbf{x}_{t-1} &= \sqrt{\bar{\alpha}_{t-1}}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_{t-1}}\boldsymbol{\epsilon}_{t-1} \\ &= \sqrt{\bar{\alpha}_{t-1}}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_{t-1} - \sigma_t^2} \boldsymbol{\epsilon}_t + \sigma_t\boldsymbol{\epsilon} \\ \end{aligned} xt1=αˉt1 x0+1αˉt1 ϵt1=αˉt1 x0+1αˉt1σt2 ϵt+σtϵ

该式也可以理解为给采样增加不确定度 σ t ϵ \sigma_t\boldsymbol{\epsilon} σtϵ, 实际上DiffusionDet采样正是采用的这个公式.

所以

q σ ( x t − 1 ∣ x t , x 0 ) = N ( x t − 1 ; α ˉ t − 1 x 0 + 1 − α ˉ t − 1 − σ t 2 x t − α ˉ t x 0 1 − α ˉ t , σ t 2 I ) q_\sigma(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_{t-1}; \sqrt{\bar{\alpha}_{t-1}}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_{t-1} - \sigma_t^2} \frac{\mathbf{x}_t - \sqrt{\bar{\alpha}_t}\mathbf{x}_0}{\sqrt{1 - \bar{\alpha}_t}}, \sigma_t^2 \mathbf{I}) qσ(xt1xt,x0)=N(xt1;αˉt1 x0+1αˉt1σt2 1αˉt xtαˉt x0,σt2I)

对比形式 q ( x t − 1 ∣ x t , x 0 ) = N ( x t − 1 ; μ ~ ( x t , x 0 ) , β ~ t I ) q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_{t-1}; \tilde{\boldsymbol{\mu}}(\mathbf{x}_t, \mathbf{x}_0), \tilde{\beta}_t \mathbf{I}) q(xt1xt,x0)=N(xt1;μ~(xt,x0),β~tI), 得到

β ~ t = σ t 2 = 1 − α ˉ t − 1 1 − α ˉ t ⋅ β t = 1 − α ˉ t − 1 1 − α ˉ t ⋅ ( 1 − α t ) \tilde{\beta}_t = \sigma_t^2 = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \cdot \beta_t=\frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \cdot (1-\alpha_t) β~t=σt2=1αˉt1αˉt1βt=1αˉt1αˉt1(1αt)

在实践中, 令 σ t 2 = η ⋅ β ~ t \sigma_t^2 = \eta \cdot \tilde{\beta}_t σt2=ηβ~t来控制采样的随机程度. 当 η = 0 \eta = 0 η=0的时候, 表明采样过程是完全确定的(由网络预测的 ϵ t \boldsymbol{\epsilon}_t ϵt决定, 消除了另一个随机因子 ϵ \boldsymbol{\epsilon} ϵ的影响).

总结一下, 可以在翻转过程中进行如下步骤来提高速度:

  1. 在第 t t t步, 获取 α ˉ t , α ˉ t − 1 , α t \bar{\alpha}_t, \bar{\alpha}_{t-1}, \alpha_t αˉt,αˉt1,αt
  2. 获取网络预测的噪声 ϵ t \boldsymbol{\epsilon}_t ϵt
  3. 计算 σ t = η 1 − α ˉ t − 1 1 − α ˉ t ⋅ ( 1 − α t ) \sigma_t=\eta \sqrt{\frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \cdot (1-\alpha_t)} σt=η1αˉt1αˉt1(1αt)
  4. 计算 x t − 1 = α ˉ t − 1 x 0 + 1 − α ˉ t − 1 − σ t 2 ϵ t + σ t ϵ \mathbf{x}_{t-1} = \sqrt{\bar{\alpha}_{t-1}}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_{t-1} - \sigma_t^2} \boldsymbol{\epsilon}_t + \sigma_t\boldsymbol{\epsilon} \\ xt1=αˉt1 x0+1αˉt1σt2 ϵt+σtϵ
  5. 直至 t = 0 t=0 t=0

0.4 损失函数

我们得到 q ( x t − 1 ∣ x t , x 0 ) q(\textbf{x}_{t-1}|\textbf{x}_t, \textbf{x}_0) q(xt1xt,x0)的目的是: 让网络学习的 p p p q ( x t − 1 ∣ x t , x 0 ) q(\textbf{x}_{t-1}|\textbf{x}_t, \textbf{x}_0) q(xt1xt,x0)尽量接近, 或说用 q ( x t − 1 ∣ x t , x 0 ) q(\textbf{x}_{t-1}|\textbf{x}_t, \textbf{x}_0) q(xt1xt,x0)指导 p p p的训练(可认为最小化二者之间的散度).

具体损失函数再补充.

1. DiffusionDet

1.1 总览

DiffusionDet的思想非常直接, 既然目标检测是要准确地定位边界框的位置, 那么利用Diffusion Model的强大噪声恢复(学习)能力就可以优化检测的结果. 整体框架如下:

在这里插入图片描述
上图中, Image Encoder(为ResNet或Swin Transformer)提取图像的特征, 然后Detection Decoder接受噪声化的边界框, 并恢复边界框的初始值, 同时预测类别. 整体来说, 需要学习一个网络 f θ f_\theta fθ, 从 z T z_T zT中恢复出 z 0 z_0 z0, 其中 z z z为边界框. 损失函数即为恢复的值与初始值的差的2-范数:

L t r a i n = 1 2 ∣ ∣ f θ ( z t , t ) − z 0 ∣ ∣ 2 \mathcal{L}_{train}=\frac{1}{2}||f_\theta (z_t,t)-z_0||^2 Ltrain=21∣∣fθ(zt,t)z02

如上图所示, 为了减少计算量, Diffusion Model从原始图片提取的高级特征中学习. Image Encoder就是提取图像特征的, 作者采用了ResNet和SwinTransformer.

而Detection Decoder接受加噪的bbox和特征图, 并返回恢复的bbox.

1.2 训练过程

训练过程的每次迭代大致分为四步:

  1. Encoder提取特征
  2. 从标准高斯分布中采样给真值框加噪, 公式: x t = Π i α i x 0 + 1 − Π i α i ϵ ,    ϵ ∼ N ( 0 , I ) \mathbf{x}_t=\sqrt{\Pi_i\alpha_i}\mathbf{x}_0+\sqrt{1-\Pi_i\alpha_i}\epsilon, ~~\epsilon\sim \mathcal{N}(0,I) xt=Πiαi x0+1Πiαi ϵ,  ϵN(0,I)
  3. 将加噪后的真值框和特征图输入到要学习的encoder网络 f θ f_\theta fθ
  4. 计算loss

伪代码:

在这里插入图片描述

训练过程中有几个细节:

  1. 保证每次迭代输入到encoder中的框数目都相同. 作者通过尝试重复GT框, 以及concat随机大小的框或与图像大小相同的框, 发现还是concat随机大小的框效果最好
  2. 训练损失. 对于预测的框采用集合预测损失. 值得注意的是, 为每个gt框分配 k k k个预测框, 而 k k k个预测框的选取是利用指派问题进行分配(类似匈牙利算法).

1.3 推理过程

推理过程大致分为三步:

  1. Encoder提取特征
  2. 从标准高斯分布中产生边界框
  3. T T T 0 0 0, 将随机框, 特征和时间输入到decoder中, 逐步恢复出初始边界框. 恢复的过程具体是:
  1. 在第 t t t步, 获取 α ˉ t , α ˉ t − 1 , α t \bar{\alpha}_t, \bar{\alpha}_{t-1}, \alpha_t αˉt,αˉt1,αt
  2. 获取网络预测的噪声 ϵ t \boldsymbol{\epsilon}_t ϵt
  3. 计算 σ t = η 1 − α ˉ t − 1 1 − α ˉ t ⋅ ( 1 − α t ) \sigma_t=\eta \sqrt{\frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \cdot (1-\alpha_t)} σt=η1αˉt1αˉt1(1αt)
  4. 计算 x t − 1 = α ˉ t − 1 x 0 + 1 − α ˉ t − 1 − σ t 2 ϵ t + σ t ϵ \mathbf{x}_{t-1} = \sqrt{\bar{\alpha}_{t-1}}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_{t-1} - \sigma_t^2} \boldsymbol{\epsilon}_t + \sigma_t\boldsymbol{\epsilon} \\ xt1=αˉt1 x0+1αˉt1σt2 ϵt+σtϵ
  5. 直至 t = 0 t=0 t=0

伪代码:

在这里插入图片描述

在推理过程中值得注意的是bbox更新机制. 由于输入的是固定数量的随机框, 在训练阶段我们也是加入了随机框来使数目一样, 因此输出的有些是对应于GT的bbox, 有些则是随机的. 如果把随机的再一起喂到下一步, 作者说这样就破坏了原本的分布, 因此对于每一步预测的框, 将置信度过低的舍弃, 并以新的随机框补充.

2. 代码解读

首先看一下./diffusiondet/detector.py中的DiffusionDet类, 其是该论文的核心代码. 其中的forward函数:

def forward(self, batched_inputs, do_postprocess=True):
        images, images_whwh = self.preprocess_image(batched_inputs)  # 预处理 归一化&填充
        if isinstance(images, (list, torch.Tensor)):
            images = nested_tensor_from_tensor_list(images)

        # Feature Extraction.
        src = self.backbone(images.tensor)  # Encoder 提取各级特征
        features = list()
        for f in self.in_features:
            feature = src[f]
            features.append(feature)

        # Prepare Proposals.
        if not self.training:  # 如果是推理阶段
            results = self.ddim_sample(batched_inputs, features, images_whwh, images)  # 从T时刻至0时刻 逐步采样恢复
            return results

        if self.training:  # 训练阶段
            gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
            targets, x_boxes, noises, t = self.prepare_targets(gt_instances)  # prepare_targets: 对GT框逐步加噪
            t = t.squeeze(-1)
            x_boxes = x_boxes * images_whwh[:, None, :]

            outputs_class, outputs_coord = self.head(features, x_boxes, t, None)  # 经过RCNNhead 预测类别和bbox
            output = {
    
    'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}

            if self.deep_supervision:
                output['aux_outputs'] = [{
    
    'pred_logits': a, 'pred_boxes': b}
                                         for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]

            loss_dict = self.criterion(output, targets)  # 计算loss
            weight_dict = self.criterion.weight_dict
            for k in loss_dict.keys():
                if k in weight_dict:
                    loss_dict[k] *= weight_dict[k]
            return loss_dict

可以看到, 里面还有两个重点的self.prepare_targets(训练过程中的加噪)和self.ddim_sample(推理过程中的采样)

    def prepare_targets(self, targets):
        new_targets = []
        diffused_boxes = []
        noises = []
        ts = []
        for targets_per_image in targets:
            target = {
    
    }
            h, w = targets_per_image.image_size
            image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float, device=self.device)
            gt_classes = targets_per_image.gt_classes
            gt_boxes = targets_per_image.gt_boxes.tensor / image_size_xyxy
            gt_boxes = box_xyxy_to_cxcywh(gt_boxes)  # 以上预处理真值框
            d_boxes, d_noise, d_t = self.prepare_diffusion_concat(gt_boxes)  # 核心部分 计算加噪后的框
            diffused_boxes.append(d_boxes)
            noises.append(d_noise)
            ts.append(d_t)
            target["labels"] = gt_classes.to(self.device)
            target["boxes"] = gt_boxes.to(self.device)
            target["boxes_xyxy"] = targets_per_image.gt_boxes.tensor.to(self.device)
            target["image_size_xyxy"] = image_size_xyxy.to(self.device)
            image_size_xyxy_tgt = image_size_xyxy.unsqueeze(0).repeat(len(gt_boxes), 1)
            target["image_size_xyxy_tgt"] = image_size_xyxy_tgt.to(self.device)
            target["area"] = targets_per_image.gt_boxes.area().to(self.device)
            new_targets.append(target)  # target为蕴含大小、类别等信息的真值
        # 返回真值、加噪后的框、噪声和步长
        return new_targets, torch.stack(diffused_boxes), torch.stack(noises), torch.stack(ts)

其中的加噪过程在self.prepare_diffusion_concat(gt_boxes), 我们可以看到:

def prepare_diffusion_concat(self, gt_boxes):
        """
        :param gt_boxes: (cx, cy, w, h), normalized
        :param num_proposals:
        """
        t = torch.randint(0, self.num_timesteps, (1,), device=self.device).long()  # 确定随机步长
        noise = torch.randn(self.num_proposals, 4, device=self.device)  # 产生标准正态分布

        num_gt = gt_boxes.shape[0]  # gt框数目
        if not num_gt:  # generate fake gt boxes if empty gt boxes
            gt_boxes = torch.as_tensor([[0.5, 0.5, 1., 1.]], dtype=torch.float, device=self.device)
            num_gt = 1

        if num_gt < self.num_proposals:  # 如果gt框比预设的固定数目小 则随机再填充一些框
            box_placeholder = torch.randn(self.num_proposals - num_gt, 4,
                                          device=self.device) / 6. + 0.5  # 3sigma = 1/2 --> sigma: 1/6
            box_placeholder[:, 2:] = torch.clip(box_placeholder[:, 2:], min=1e-4)
            x_start = torch.cat((gt_boxes, box_placeholder), dim=0)
        elif num_gt > self.num_proposals:  # 如果比预设数目多 就随机抹掉一些GT框
            select_mask = [True] * self.num_proposals + [False] * (num_gt - self.num_proposals)
            random.shuffle(select_mask)
            x_start = gt_boxes[select_mask]
        else:
            x_start = gt_boxes

        x_start = (x_start * 2. - 1.) * self.scale

        # noise sample
        x = self.q_sample(x_start=x_start, t=t, noise=noise)  # 前向加噪过程

        x = torch.clamp(x, min=-1 * self.scale, max=self.scale)  # 限制范围
        x = ((x / self.scale) + 1) / 2.

        diff_boxes = box_cxcywh_to_xyxy(x)

        return diff_boxes, noise, t

最后再来看看推理阶段的self.ddim_sample函数:

@torch.no_grad()
    def ddim_sample(self, batched_inputs, backbone_feats, images_whwh, images, clip_denoised=True, do_postprocess=True):
        batch = images_whwh.shape[0]
        shape = (batch, self.num_proposals, 4)
        total_timesteps, sampling_timesteps, eta, objective = self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective

        # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
        times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1)  
        times = list(reversed(times.int().tolist()))  # 时间为倒序 从T到0
        time_pairs = list(zip(times[:-1], times[1:]))  # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]

        img = torch.randn(shape, device=self.device)  # 产生标准高斯分布bboxs

        ensemble_score, ensemble_label, ensemble_coord = [], [], []
        x_start = None
        for time, time_next in time_pairs:  # 相邻时间两步计算
            time_cond = torch.full((batch,), time, device=self.device, dtype=torch.long)
            self_cond = x_start if self.self_condition else None  

            # 预测的噪声、x_0和类别与坐标
            preds, outputs_class, outputs_coord = self.model_predictions(backbone_feats, images_whwh, img, time_cond,
                                                                         self_cond, clip_x_start=clip_denoised)
            pred_noise, x_start = preds.pred_noise, preds.pred_x_start  # 获取预测的噪声\epsilon_t 与 预测的初始状态x_0

            if self.box_renewal:  # filter  Box reneral机制 将置信度低的边界框用随机框替换
                score_per_image, box_per_image = outputs_class[-1][0], outputs_coord[-1][0]
                threshold = 0.5
                score_per_image = torch.sigmoid(score_per_image)
                value, _ = torch.max(score_per_image, -1, keepdim=False)
                keep_idx = value > threshold
                num_remain = torch.sum(keep_idx)

                pred_noise = pred_noise[:, keep_idx, :]
                x_start = x_start[:, keep_idx, :]
                img = img[:, keep_idx, :]
            if time_next < 0:
                img = x_start
                continue
            
            # 获取\alpha_i的连乘值
            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]

            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c = (1 - alpha_next - sigma ** 2).sqrt()

            noise = torch.randn_like(img)  # 标准高斯分布中采样

            # 公式: x_{t-1} = \sqrt{\bar{\alpha}_{t-1}} * x_0 + 
            # \sqrt(1 - \bar{\alpha}_{t-1}} - sigma^2} * \epsilon_t + 
            # \sigma * \epsilon
            img = x_start * alpha_next.sqrt() + \
                  c * pred_noise + \
                  sigma * noise  # 通过预测的噪声 计算恢复结果

            if self.box_renewal:  # filter
                # replenish with randn boxes
                img = torch.cat((img, torch.randn(1, self.num_proposals - num_remain, 4, device=img.device)), dim=1)
            if self.use_ensemble and self.sampling_timesteps > 1:
                box_pred_per_image, scores_per_image, labels_per_image = self.inference(outputs_class[-1],
                                                                                        outputs_coord[-1],
                                                                                        images.image_sizes)
                ensemble_score.append(scores_per_image)
                ensemble_label.append(labels_per_image)
                ensemble_coord.append(box_pred_per_image)

        if self.use_ensemble and self.sampling_timesteps > 1:
            box_pred_per_image = torch.cat(ensemble_coord, dim=0)
            scores_per_image = torch.cat(ensemble_score, dim=0)
            labels_per_image = torch.cat(ensemble_label, dim=0)
            if self.use_nms:
                keep = batched_nms(box_pred_per_image, scores_per_image, labels_per_image, 0.5)
                box_pred_per_image = box_pred_per_image[keep]
                scores_per_image = scores_per_image[keep]
                labels_per_image = labels_per_image[keep]

            result = Instances(images.image_sizes[0])
            result.pred_boxes = Boxes(box_pred_per_image)
            result.scores = scores_per_image
            result.pred_classes = labels_per_image
            results = [result]
        else:
            output = {
    
    'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
            box_cls = output["pred_logits"]
            box_pred = output["pred_boxes"]
            results = self.inference(box_cls, box_pred, images.image_sizes)
        if do_postprocess:  # 后处理
            processed_results = []
            for results_per_image, input_per_image, image_size in zip(results, batched_inputs, images.image_sizes):
                height = input_per_image.get("height", image_size[0])
                width = input_per_image.get("width", image_size[1])
                r = detector_postprocess(results_per_image, height, width)
                processed_results.append({
    
    "instances": r})
            return processed_results

猜你喜欢

转载自blog.csdn.net/wjpwjpwjp0831/article/details/127973262
今日推荐