DiffusionDet Source Code Reading (1)

This post is intended only for readers who have already read the full paper.

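The code in this post is excerpted from the DiffusionDet implementation in mmdet. That code is still in the Development stage, so what is shown here may differ slightly from a later stable release. This should not matter much, since we only look at the most important parts.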

In DDPM, the diffusion (forward) process has a parameter $\beta$:

$$q(z_t \mid z_{t-1}) := \mathcal{N}\left(z_t; \sqrt{1-\beta_t}\, z_{t-1}, \beta_t \mathbf{I}\right)$$

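This is a single noising step: $z_{t-1}$ is first scaled by $\sqrt{1-\beta_t}$ and then Gaussian noise with variance $\beta_t$ is added, which gives $z_t$.
The amount of noise added at each step is controlled by $\beta_t$, which is set by hand rather than learned. Because every step is Gaussian, $z_t$ can also be obtained from $z_0$ in a single step: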

$$q(z_t \mid z_0) := \mathcal{N}\left(z_t; \sqrt{\bar{\alpha}_t}\, z_0, (1-\bar{\alpha}_t)\mathbf{I}\right)$$

that is:

$$z_t = \sqrt{\bar{\alpha}_t}\, z_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon, \quad \text{where } \epsilon \sim \mathcal{N}(0, \mathbf{I})$$
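To make the link between the one-step process and the closed form concrete, here is a small sketch (an illustration, not DiffusionDet code; the names `q_sample_closed_form` and `moments_by_iteration` are hypothetical). It samples $z_t$ directly from $z_0$, and checks that applying $z_t = \sqrt{1-\beta_t}\, z_{t-1} + \sqrt{\beta_t}\,\epsilon$ step by step gives the same mean coefficient $\sqrt{\bar{\alpha}_t}$ and variance $1-\bar{\alpha}_t$, with $\bar{\alpha}_t$ the cumulative product of $1-\beta_s$. The linear schedule is assumed only for the check.

```python
import torch


def q_sample_closed_form(z0, alpha_bar_t, noise=None):
    """Draw z_t ~ q(z_t | z_0) = N(sqrt(alpha_bar_t) * z_0, (1 - alpha_bar_t) * I)."""
    if noise is None:
        noise = torch.randn_like(z0)
    return alpha_bar_t.sqrt() * z0 + (1 - alpha_bar_t).sqrt() * noise


def moments_by_iteration(betas):
    """Track the mean coefficient and variance of z_t given z_0 while applying
    z_t = sqrt(1 - beta_t) * z_{t-1} + sqrt(beta_t) * eps one step at a time."""
    coef = torch.ones((), dtype=betas.dtype)   # z_0 = 1 * z_0 + no noise
    var = torch.zeros((), dtype=betas.dtype)
    coefs, variances = [], []
    for beta in betas:
        coef = torch.sqrt(1 - beta) * coef     # the scalings multiply up
        var = (1 - beta) * var + beta          # old noise is scaled, new noise is added
        coefs.append(coef)
        variances.append(var)
    return torch.stack(coefs), torch.stack(variances)


# Linear schedule from the DDPM paper, used here only for the check
betas = torch.linspace(1e-4, 0.02, 1000, dtype=torch.float64)
alphas_cumprod = torch.cumprod(1 - betas, dim=0)
coefs, variances = moments_by_iteration(betas)
print(torch.allclose(coefs, alphas_cumprod.sqrt()))    # True
print(torch.allclose(variances, 1 - alphas_cumprod))   # True

z0 = torch.randn(4, dtype=torch.float64)
zt = q_sample_closed_form(z0, alphas_cumprod[-1])
```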

Given $z_0$, $q(z_t \mid z_0)$ is therefore also a Gaussian, where:

$$\alpha_t = 1 - \beta_t, \qquad \bar{\alpha}_t = \prod_{s=0}^{t} \alpha_s$$
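The closed form follows from merging consecutive steps. A sketch for two steps, using $\alpha_t = 1-\beta_t$ and independent $\epsilon_{t-1}, \epsilon_t \sim \mathcal{N}(0, \mathbf{I})$:

```latex
\begin{aligned}
z_t &= \sqrt{\alpha_t}\, z_{t-1} + \sqrt{1-\alpha_t}\,\epsilon_t \\
    &= \sqrt{\alpha_t \alpha_{t-1}}\, z_{t-2} + \sqrt{\alpha_t(1-\alpha_{t-1})}\,\epsilon_{t-1} + \sqrt{1-\alpha_t}\,\epsilon_t \\
    &= \sqrt{\alpha_t \alpha_{t-1}}\, z_{t-2} + \sqrt{1-\alpha_t \alpha_{t-1}}\,\bar{\epsilon}, \qquad \bar{\epsilon} \sim \mathcal{N}(0, \mathbf{I})
\end{aligned}
```

The last line holds because the two noise terms are independent Gaussians whose variances add: $\alpha_t(1-\alpha_{t-1}) + (1-\alpha_t) = 1-\alpha_t\alpha_{t-1}$. Repeating this down to $z_0$ gives $z_t = \sqrt{\bar{\alpha}_t}\, z_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon$.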

When $\bar{\alpha}_t$ approaches 0, $z_t$ can be regarded as standard Gaussian noise. In DiffusionDet, $\beta_{1:T}$ is a sequence of increasing values between 0 and 1. The code that generates $\beta$ is shown below; here we take $T = 1000$, i.e. 1000 diffusion steps in total.

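import math
import torch

def cosine_beta_schedule(timesteps, s=0.008):
    """Cosine schedule as proposed in
    https://openreview.net/forum?id=-NEXDKk8gZ."""
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
    # alpha_bar follows a squared cosine of t / T, shifted by the offset s
    alphas_cumprod = torch.cos(
        ((x / timesteps) + s) / (1 + s) * math.pi * 0.5)**2
    # normalize so that alpha_bar at t = 0 is exactly 1
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    # beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    # clip so beta_t never reaches 1 near t = T
    return torch.clip(betas, 0, 0.999)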
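A quick check of the schedule with $T = 1000$ (the function is repeated so the snippet runs on its own). It shows the properties described above: the betas stay within $[0, 0.999]$ and increase, while $\bar{\alpha}_T$, recomputed as the cumulative product of $1-\beta_t$, ends up close to 0.

```python
import math
import torch


def cosine_beta_schedule(timesteps, s=0.008):
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)


betas = cosine_beta_schedule(1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

print(betas.shape)                       # torch.Size([1000])
print(betas[0].item(), betas[-1].item())  # small at the start, clipped at the end
print(alphas_cumprod[-1].item())         # tiny, so z_T is almost pure noise
```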

Curves of $\cos(x)$ and $\cos^2(x)$: the red line is the former and the blue line the latter. The two share the same zero at $(\frac{\pi}{2}, 0)$.

[Figure: $\cos(x)$ (red) and $\cos^2(x)$ (blue)]
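For reference, the quantities computed by `cosine_beta_schedule` can be written as follows (the symbol $f$ is introduced here for readability; the formulas are read directly off the code):

```latex
f(t) = \cos^2\left(\frac{t/T + s}{1+s}\cdot\frac{\pi}{2}\right), \qquad
\bar{\alpha}_t = \frac{f(t)}{f(0)}, \qquad
\beta_t = \min\left(1 - \frac{\bar{\alpha}_t}{\bar{\alpha}_{t-1}},\ 0.999\right)
```

At $t = T$ the argument of $\cos^2$ is exactly $\pi/2$, the shared zero in the figure, so $\bar{\alpha}_T \approx 0$ and $z_T$ is essentially pure noise. The small offset $s = 0.008$ keeps $\beta_t$ from being too small near $t = 0$.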

This is the curve of $\beta$:

[Figure: the $\beta_t$ schedule]

Next comes the code that computes $\alpha$, $\bar{\alpha}$ and the related quantities above:

    def _build_diffusion(self):
        betas = cosine_beta_schedule(self.timesteps)
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)

        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod',
                             torch.sqrt(1. - alphas_cumprod))
        self.register_buffer('log_one_minus_alphas_cumprod',
                             torch.log(1. - alphas_cumprod))
        self.register_buffer('sqrt_recip_alphas_cumprod',
                             torch.sqrt(1. / alphas_cumprod))
        self.register_buffer('sqrt_recipm1_alphas_cumprod',
                             torch.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        # equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (
            1. - alphas_cumprod)
        self.register_buffer('posterior_variance', posterior_variance)

        # log calculation clipped because the posterior variance is 0 at
        # the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped',
                             torch.log(posterior_variance.clamp(min=1e-20)))
        self.register_buffer(
            'posterior_mean_coef1',
            betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        self.register_buffer('posterior_mean_coef2',
                             (1. - alphas_cumprod_prev) * torch.sqrt(alphas) /
                             (1. - alphas_cumprod))
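`_build_diffusion` stores every precomputed table with `register_buffer` instead of as a plain attribute or an `nn.Parameter`. A minimal sketch of what that buys (the class `TinyDiffusion` and its toy schedule are hypothetical, not the mmdet class): buffers are not trained, but they are saved in `state_dict()` and are converted together with the module by `.to()`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyDiffusion(nn.Module):
    """Hypothetical module holding a few of the same precomputed tables."""

    def __init__(self, timesteps=4):
        super().__init__()
        betas = torch.linspace(0.1, 0.4, timesteps, dtype=torch.float64)  # toy schedule
        alphas_cumprod = torch.cumprod(1. - betas, dim=0)
        # Buffers belong to the module's state but are not returned by parameters()
        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev',
                             F.pad(alphas_cumprod[:-1], (1, 0), value=1.))


m = TinyDiffusion()
print(list(m.state_dict().keys()))  # ['betas', 'alphas_cumprod', 'alphas_cumprod_prev']
print(len(list(m.parameters())))    # 0: nothing here is trained
m = m.to(torch.float32)             # buffers are converted along with the module
print(m.betas.dtype)                # torch.float32
```

Because the tables live in buffers, moving the detector to a GPU moves them too, so indexing them with timesteps during training needs no extra device handling.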

These lines compute and register $\beta_t$, $\bar{\alpha}_t$ and $\bar{\alpha}_{t-1}$ (the latter two derived via $\alpha_t = 1 - \beta_t$); each has length $T$.

        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)

        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
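`F.pad(alphas_cumprod[:-1], (1, 0), value=1.)` drops the last element and prepends a 1, so `alphas_cumprod_prev[t]` holds $\bar{\alpha}_{t-1}$, with the value before the first step taken as 1. A small check with toy numbers:

```python
import torch
import torch.nn.functional as F

alphas = torch.tensor([0.9, 0.8, 0.5])
alphas_cumprod = torch.cumprod(alphas, dim=0)                       # about [0.9, 0.72, 0.36]
# (1, 0): pad one element on the left of the last dim, none on the right
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)  # about [1.0, 0.9, 0.72]
print(alphas_cumprod_prev)
```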

$$q(z_t \mid z_{t-1}) := \mathcal{N}\left(z_t; \sqrt{1-\beta_t}\, z_{t-1}, \beta_t \mathbf{I}\right)$$

Next, $\sqrt{\bar{\alpha}_t}$, $\sqrt{1-\bar{\alpha}_t}$, $\log(1-\bar{\alpha}_t)$, $\frac{1}{\sqrt{\bar{\alpha}_t}}$ and $\sqrt{\frac{1}{\bar{\alpha}_t} - 1}$ are computed:

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod',
                             torch.sqrt(1. - alphas_cumprod))
        self.register_buffer('log_one_minus_alphas_cumprod',
                             torch.log(1. - alphas_cumprod))
        self.register_buffer('sqrt_recip_alphas_cumprod',
                             torch.sqrt(1. / alphas_cumprod))
        self.register_buffer('sqrt_recipm1_alphas_cumprod',
                             torch.sqrt(1. / alphas_cumprod - 1))
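A sketch of how such tables are typically consumed in DDPM-style code (the helper names `extract`, `q_sample` and `predict_start_from_noise`, the linear schedule and the tensor shapes are illustrative assumptions, not claimed to match the mmdet source). $\sqrt{\bar{\alpha}_t}$ and $\sqrt{1-\bar{\alpha}_t}$ add noise in closed form; $\frac{1}{\sqrt{\bar{\alpha}_t}}$ and $\sqrt{\frac{1}{\bar{\alpha}_t} - 1}$ invert that step, since $z_0 = \frac{1}{\sqrt{\bar{\alpha}_t}} z_t - \sqrt{\frac{1}{\bar{\alpha}_t} - 1}\,\epsilon$.

```python
import torch


def extract(table, t, x_shape):
    """Gather table[t] for a batch of timesteps and reshape it to broadcast over x."""
    out = table.gather(-1, t)
    return out.reshape(t.shape[0], *((1,) * (len(x_shape) - 1)))


def q_sample(x_start, t, noise, sqrt_ac, sqrt_1m_ac):
    # z_t = sqrt(alpha_bar_t) * z_0 + sqrt(1 - alpha_bar_t) * eps
    return (extract(sqrt_ac, t, x_start.shape) * x_start
            + extract(sqrt_1m_ac, t, x_start.shape) * noise)


def predict_start_from_noise(x_t, t, noise, sqrt_recip_ac, sqrt_recipm1_ac):
    # z_0 = z_t / sqrt(alpha_bar_t) - sqrt(1 / alpha_bar_t - 1) * eps
    return (extract(sqrt_recip_ac, t, x_t.shape) * x_t
            - extract(sqrt_recipm1_ac, t, x_t.shape) * noise)


betas = torch.linspace(1e-4, 0.02, 1000, dtype=torch.float64)  # assumed schedule
ac = torch.cumprod(1 - betas, dim=0)
sqrt_ac, sqrt_1m_ac = ac.sqrt(), (1 - ac).sqrt()
sqrt_recip_ac, sqrt_recipm1_ac = (1 / ac).sqrt(), (1 / ac - 1).sqrt()

x0 = torch.rand(2, 8, 4, dtype=torch.float64)  # batch of 2, 8 boxes, 4 coords (arbitrary)
t = torch.tensor([10, 900])                    # one timestep per sample
noise = torch.randn_like(x0)
xt = q_sample(x0, t, noise, sqrt_ac, sqrt_1m_ac)
x0_rec = predict_start_from_noise(xt, t, noise, sqrt_recip_ac, sqrt_recipm1_ac)
print(torch.allclose(x0_rec, x0))  # True
```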

In DDPM, the posterior $q(z_{t-1} \mid z_t, z_0)$, i.e. the reverse step conditioned on $z_0$, is tractable and is also a Gaussian:

$$q(z_{t-1} \mid z_t, z_0) = \mathcal{N}\left(z_{t-1}; \tilde{\mu}_t(z_t, z_0), \tilde{\beta}_t \mathbf{I}\right)$$

Working out the expression gives:

$$\tilde{\mu}_t(z_t, z_0) = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1-\bar{\alpha}_t}\, z_0 + \frac{\sqrt{\alpha_t}\,(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\, z_t$$

$$\tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\,\beta_t$$
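The simplification is a product of two Gaussians in $z_{t-1}$ (Bayes' rule up to a normalizing constant, using the Markov property $q(z_t \mid z_{t-1}, z_0) = q(z_t \mid z_{t-1})$). Written per coordinate, since all covariances are isotropic:

```latex
\begin{aligned}
q(z_{t-1} \mid z_t, z_0) &\propto q(z_t \mid z_{t-1})\, q(z_{t-1} \mid z_0) \\
&\propto \exp\!\left(-\frac{1}{2}\left[\frac{(z_t - \sqrt{\alpha_t}\, z_{t-1})^2}{\beta_t} + \frac{(z_{t-1} - \sqrt{\bar{\alpha}_{t-1}}\, z_0)^2}{1-\bar{\alpha}_{t-1}}\right]\right) \\
\frac{1}{\tilde{\beta}_t} &= \frac{\alpha_t}{\beta_t} + \frac{1}{1-\bar{\alpha}_{t-1}} = \frac{1-\bar{\alpha}_t}{\beta_t\,(1-\bar{\alpha}_{t-1})} \\
\tilde{\mu}_t &= \tilde{\beta}_t\left(\frac{\sqrt{\alpha_t}}{\beta_t}\, z_t + \frac{\sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_{t-1}}\, z_0\right)
\end{aligned}
```

Collecting the $z_{t-1}^2$ terms gives the precision $1/\tilde{\beta}_t$ (the second equality uses $\alpha_t + \beta_t = 1$), and the linear terms give $\tilde{\mu}_t$; expanding $\tilde{\beta}_t$ in the last line yields the two coefficients above. The first form of $1/\tilde{\beta}_t$ is the alternative expression quoted in the code comment.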

The following lines compute these coefficients:

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        # equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (
            1. - alphas_cumprod)
        self.register_buffer('posterior_variance', posterior_variance)

        # log calculation clipped because the posterior variance is 0 at
        # the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped',
                             torch.log(posterior_variance.clamp(min=1e-20)))
        self.register_buffer(
            'posterior_mean_coef1',
            betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        self.register_buffer('posterior_mean_coef2',
                             (1. - alphas_cumprod_prev) * torch.sqrt(alphas) /
                             (1. - alphas_cumprod))
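A numerical sanity check of these tables, under an assumed linear schedule (`q_posterior` is a hypothetical helper modelled on common DDPM code, not taken from mmdet). The variance matches the alternative form in the code comment, and when $z_t$ is the noise-free $\sqrt{\bar{\alpha}_t}\, z_0$ the posterior mean reduces to $\sqrt{\bar{\alpha}_{t-1}}\, z_0$, because $\sqrt{\alpha_t}\sqrt{\bar{\alpha}_t} = \alpha_t\sqrt{\bar{\alpha}_{t-1}}$ and $\beta_t + \alpha_t - \bar{\alpha}_t = 1 - \bar{\alpha}_t$.

```python
import torch
import torch.nn.functional as F

betas = torch.linspace(1e-4, 0.02, 1000, dtype=torch.float64)  # assumed schedule
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)

# Same formulas as in _build_diffusion
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)


def q_posterior(x_start, x_t, t):
    """Mean and variance of q(z_{t-1} | z_t, z_0) for one integer timestep t."""
    mean = posterior_mean_coef1[t] * x_start + posterior_mean_coef2[t] * x_t
    return mean, posterior_variance[t]


# Alternative form from the code comment (at t = 0: 1 / inf = 0, matching the table)
alt_variance = 1. / (1. / (1. - alphas_cumprod_prev) + alphas / betas)
print(torch.allclose(posterior_variance, alt_variance))  # True

x0 = torch.randn(8, 4, dtype=torch.float64)
t = 500
mean, var = q_posterior(x0, alphas_cumprod[t].sqrt() * x0, t)
print(torch.allclose(mean, alphas_cumprod_prev[t].sqrt() * x0))  # True
```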

That is the whole of _build_diffusion. Some of the log terms are probably used later when computing the loss.
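For comparison, in DDPM-style ancestral sampling (e.g. the "fixed small" variance option of the reference DDPM code) the clipped log-variance scales the noise when drawing $z_{t-1}$ from the posterior. The `clamp(min=1e-20)` matters because $\tilde{\beta}_t$ is 0 at the first step, and its log would be $-\infty$. A minimal sketch (`posterior_sample` and the toy variances are hypothetical; whether DiffusionDet's own sampler uses this buffer is not shown in this post):

```python
import torch


def posterior_sample(mean, log_variance, t, noise=None):
    """Draw z_{t-1} = mean + exp(0.5 * log_var) * eps, adding no noise at t == 0."""
    if noise is None:
        noise = torch.randn_like(mean)
    std = torch.exp(0.5 * log_variance)
    nonzero = 0.0 if t == 0 else 1.0  # the final step returns the mean itself
    return mean + nonzero * std * noise


posterior_variance = torch.tensor([0.0, 0.01, 0.04], dtype=torch.float64)  # toy values
log_var_clipped = torch.log(posterior_variance.clamp(min=1e-20))
print(log_var_clipped)  # finite everywhere; log(0) would have been -inf

mean = torch.zeros(3, dtype=torch.float64)
noise = torch.ones(3, dtype=torch.float64)
print(posterior_sample(mean, log_var_clipped[2], t=2, noise=noise))  # about sqrt(0.04) = 0.2
```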


Reposted from blog.csdn.net/HaoZiHuang/article/details/130342149