【3D生成与重建】SSDNeRF:单阶段Diffusion NeRF的三维生成和重建

系列文章目录

题目:Single-Stage Diffusion NeRF: A Unified Approach to 3D Generation and Reconstruction
论文:https://arxiv.org/pdf/2304.06714.pdf
任务:无条件3D生成(如从噪音中,生成不同的车等)、单视图3D生成
机构:Hansheng Chen,1,* Jiatao Gu,2 Anpei Chen, 同济、苹果、加利福尼亚大学
代码:https://github.com/Lakonik/SSDNeRF


摘要

  3D-aware image synthesis任务,包括场景生成和 image-based 的新视图合成。本文提出了SSDNeRF,使用扩散模型从不同对象的多视图图像中学习神经辐射场(NeRF)的可推广先验。先前的研究使用两阶段方法,依赖于Pretrained NeRF作为真实数据来训练扩散模型。相比之下,SSDNeRF作为单阶段、端到端的训练范式,联合优化NeRF的自动decode 和 latent Diffusion模型,实现同时三维重建和先验学习(甚至包括稀疏视图)。测试时,可以直接对扩散先验进行无条件生成,或将其与不可见物体的任意观测相结合,进行NeRF重建。在无条件生成和单/稀疏视图三维重建方面,SSDNeRF显示了与 task-specific 方法相当或更好的鲁棒结果。

一、前言

  伴随 神经渲染生成模型的发展,3D内容生成(如单/多视图3D重建和3D内容生成)的单一算法得以发展,但缺少全面的框架来连接多任务的技术。神经辐射场(NeRF)通过超分拟合求解逆渲染问题,在新视图合成中显示出令人印象深刻的结果(但只适用密集视图,难以推广到稀疏观测)。相比之下,许多稀疏视图三维重建方法[pixelNeRF,Mvsnerf,ViT NeRF] 依赖于前馈的 image-to-3D 编码器,但它们不能处理遮挡区域的模糊性,也不能生成清晰的图像。在无条件生成方面,3D感知生成对抗网络(GAN)[34,5,18,14]在单图像鉴别器的使用上受到了部分限制,这不能使交叉视图关系能够有效地从多视图数据中学习。

  之前的工作[如Score-based NeRF、Gaudi、triplane diffusion、Diffrf]中,类似的ldm已经应用于2D和3D生成,但它们通常需要两阶段训练,第一阶段排除Diffusion model 只预训练 VAE( 变分自动编码器) 或自动解码器。然而,在扩散nerf的情况下,我们认为由于逆渲染的不确定性,两阶段训练在latent code 中诱发噪声模式和伪影(特别是稀疏视图时),这阻止了 Diffusion model 有效地学习干净的潜在流形。为了解决这个问题,

  本文提出一个统一框架SSDNeRF(单阶段扩散NeRF),用三维潜在扩散模型(LDM)建模场景latent code 的生成先验,从多视图图像中学习可泛化的三维先验,用于处理各种三维任务(图1)。单阶段训练范式,使 Diffusion 和 NeRF 权重的端到端学习成为可能。这种方法将生成的bias和渲染的bias 一致地结合在一起,以提高整体性能,并允许对稀疏视图数据进行训练。此外,学习到的无条件扩散模型的三维先验可以用于任意观测的灵活测试场景采样。

在这里插入图片描述

主要贡献如下:

1.提出SSDNeRF,一种统一的 un-conditional 3D生成和 image-base 3D重建方法;
2.作为新的单阶段训练范式,SSDNeRF从大量对象的多视图图像中联合学习NeRF重建和Diffusion模型(即使每个场景只有三个视图)
3.提出了一种 引导微调采样方案(guidance-finetuning sampling scheme),利用学习到的扩散先验,测试时从任意数量的视图进行3D重建.

二、相关工作

2.1. 3D GANs

  通过将基于投影的渲染集成到生成器中,GAN 已经成功地用于三维生成。以前已经探索过各种3D表示(点云、长方体、球体[27]和体素);最近的NeRF 和带有体积渲染器的特征场[34,18,5],以及带有网格渲染器的可微表面[14]。上述方法都是用二维图像的Discrinimator 进行训练的,无法推理交叉视图关系,这使得它们严重依赖于三维一致性的模型 bias,不能有效地利用多视图数据来学习复杂和多样的几何图形。三维gan主要应用于无条件生成。虽然通过GAN inversion[12]可以完成图像的3D重建,但由于潜在表达能力有限,并不能保证真实度,如论文[Diffrf、RenderDiffusion]所述。

2.2 View-Conditioned 回归和生成

  稀疏三维重建可以通过从输入图像中回归出的新视图来解决,提出的各种架构[8,59,28,61]将图像编码为 volume features通过体渲染投影到监督的目标视图然而,它们不能推理模糊性,并产生多样化和有意义的内容,这往往会导致模糊的结果。相比之下, image-conditioned 的生成模型能更好地合成不同的内容3DiM [57]提出从 view-conditioned 的图像扩散模型中生成新的视图,但该模型缺乏三维一致性偏差。[Sparsefusion、Nerdi、Nerfdiff ] 将 image-conditioned 的二维扩散模型的先验提取成nerf,以加强三维约束。这些方法与我们的轨迹是平行的,因为它们建模了图像空间中的交叉视图关系,而我们的模型本质上是三维的

2.3 自动解码器和 Diffusion NeRF

  NeRF的单一场景拟合方案,可以推广到多场景,通过在所有场景中共享部分参数,将其余的作为单独的场景代码[7]。因此,多场景NeRF 可以被训练为自动解码器[35],共同学习 code
bank和 共享Decoder 的权重
。通过适当的架构,scene codes 可以被视为高斯先验的 latent,允许3D completion 甚至生成[24,48,38]。然而,就像3D GAN一样,这些 latent 并没有足够的表达力来忠实地重建详细的对象。[Gaudi、From data to functa、Rodin等] 改进了具有潜在扩散先验的普通自动解码器。DiffRF [32]在执行3D completion 之前利用DIffusion。这些方法在两阶段独立训练 Auto-Decoder 和 Diffusion models,受到局限。

2.4 . NeRF 作为 Auto-Decoder

  给定一组场景的二维images和对应相机参数,可以在三维空间拟合出场景的光场,表示为光学函数 yψ (r)(其中r 用于参数化世界空间中一条射线的端点和方向,ψ表示模型参数,y∈R3+ 表示接收到光线的RGB格式。NeRF 将光场表示为 沿光线通过三维体积的集成辐射(具体原理见我的blog:【三维重建】NeRF原理+代码讲解)

  NeRF还可以通过在所有场景[7]中共享部分模型参数来推广到多场景设置。给定多个场景{ yijgt,rijgt},其中yijgt,rijgt 是第 i 个场景的第 jRGB像素和射线,可以通过最小化L2渲染损失来优化每个场景码 {xi} 和共享参数 ψ

在这里插入图片描述
有了这个目标,模型被训练为一个自动解码器。场景码 {xi} 可以解释为latent code。在独立高斯分布的假设下,光学函数可以视为解码器的形式:
在这里插入图片描述

2.5 生成和重建中的挑战

  具有训练权重ψ 的 auto-decoder 可以通过解码从高斯先验中提取的 latent code 来进行无条件生成。然而,为了保证生成的连续性,需要一个低维的潜在空间和一个复杂的解码器,这增加了优化中真实重建任何视图的困难。

2.6 潜在扩散模型

  潜在扩散模型(LDM)在参数为ϕ的潜在空间中学习先验分布 pϕ(x),使更有表达能力的潜在表示(如二维图像网格)成为可能。在神经场生成方面,之前的工作[2,32,13,47]采用了两阶段的训练方案:首先训练自动解码器来获得每个场景的潜在码 xi然后将其作为真实数据来训练LDM。LDM将高斯噪声 ϵ∼N(0,I)注入到 xi 中,在经验噪声调度函数α(t),σ(t) 的扩散时间步长 t 处产生噪声码 xi(t) := α(t)xi(t)ϵ。然后,一个具有可训练权重 ϕ 的去噪网络去除 xi(t) 中的噪声,以预测一个去噪码 x ^ \hat{x} x^i。该网络常用简化的L2去噪损失来训练:
在这里插入图片描述

w(t) 是一个经验的时间相关的加权函数,和 x ^ \hat{x} x^ϕ(xi(t),t) 制定了时间分割去噪网络。

  1. 无条件/引导采样

  使用训练好的权重 ϕ,各种解释器(例如DDIM )从扩散先验 pψ(x) 中采样,递归去噪x(t),从随机高斯噪声x(t)开始,直到达到去噪状态x(0)采样过程可以由渲染损失对已知观测值的梯度来指导,允许在测试时从图像进行三维重建。

  1. 两阶段训练三维任务的局限性

  使用2D图像VAEs的LDM 通常分两个阶段的]进行训练;使用NeRF自动解码器训练LDM时:通过基于渲染的优化,来获得一个expressive 的latent code是欠确定,导致去噪网络的噪声(图2左上角);此外,从没有学习先验的稀疏视图中重建nerf是非常困难的(图2的左下角),这将训练限制在密集视图设置中。
在这里插入图片描述


三、本文方法

SSDNeRF,一个expressive的三平面NeRF自动解码器与三平面 latent diffusion 模型连接起来的框架。图3提供了该模型的概述。

在这里插入图片描述

3.1 单阶段扩散NeRF训练

   一个 auto-Decoder 可以看作是一种使用 lookup table Encoder的VAE,而不是典型的神经网络 Encoder。因此,训练目标可以以类似于VAEs的方式推导出。利用NeRF解码器
pψ({yj} | x, {rj}) 和扩散潜先验 pϕ(x)训练目标:最小化观测数据{ yijgt,rijgt}的负对数似然(NLL)的变分上界。本文通过忽略 latent code 中的不确定性(方差),得到了一个简化的训练损失:

在这里插入图片描述
其中,场景码 {xi} 、先验参数 ϕ 和解码器参数 ψ 在单个训练阶段中共同进行优化。这个损失包括公式1中的渲染损失 Lrend,以及一个以NLL形式存在的扩散先验项。仿照[Maximum likelihood training of score-based diffusion models, Score-based generative modeling in latent space 等]论文,我们用等式(2)中的近似上界 Ldiff (也被称为分数蒸馏)代替扩散NLL 。加入经验权重因子,最终的训练目标:

在这里插入图片描述

  单阶段训练,使用以上损失约束场景码 {xi},允许学好的先验完成看不见的部分,这对于稀疏视图数据的训练特别有益(expressive triplane codes 严重不确定)

渲染和先验权重的平衡

  渲染-先验的权重比 λrenddiff 是单阶段训练的关键。为保证泛化,设计了一个经验加权机制,其中扩散损失由场景码的 Frobenius 范数的指数移动平均(EMA)归一化,表示为:
λdiff := cdiff / EMA(||xi||2F) , cdiff 为固定尺度;
λrend := crend(1−e−0.1Nv)/Nv。 渲染权重由可见视图 Nv 的数量决定:基于Nv 的加权是对解码器 pψ 的校准,防止渲染损失根据射线数量线性缩放

与两阶段生成性神经场的比较

  之前的两阶段方法[Gaudi,Diffrf,3d neural field generation using triplane diffusion] 在训练第一阶段,忽略了前项 λdiffLdiff这可以看作是将 渲染-先验 的权重 λrenddiff 设置为无穷大,导致 biased和有噪声的场景码 xi。论文[3d neural field generation using triplane diffusion]通过在三平面场景代码上施加全变分(TV)正则化来强制进行平滑先验,部分地缓解了这一问题,类似于在潜在空间上的LDM约束(图2的中列)。 Control3Diff 提出对在单视图像上预训练的3D GAN生成的数据学习条件扩散模型。相比之下,我们的单阶段训练的目标是在促进端到端一致性之前直接纳入扩散。


3.2 图像引导下的采样和微调

  为了实现可推广的快速NeRF重建,并覆盖了单视到密集多视的重建,我们建议执行图像引导采样,同时考虑扩散先验和渲染似然,对采样码进行微调。根据[Video diffusion models]重建引导的采样方法,计算了近似的渲染梯度g,即一个噪声码x(t):
在这里插入图片描述
其中,(t)(t)) 是一个基于信噪比(SNR)的附加加权因子 (超参数ω为0.5或0.25)。引导梯度g与无条件分数预测相结合,表示为对去噪输出 x ^ \hat{x} x^ 的修正:

在这里插入图片描述
引导尺度为λgd。我们采用 预测-校正采样器[52],通过交替使用DDIM步骤和多个朗之万校正步骤来求解x(0)

  我们观察到,重建指导不能严格地执行忠实重建的渲染约束。为了解决这个问题,我们在等式4中重用,对采样的场景码 x 进行微调,同时冻结扩散和解码器参数:
在这里插入图片描述
其中,λ’diff 是测试时间的先验权值,它应该低于训练权值 λdiff(因为从训练数据集学习到的先验在转移到不同的测试数据集时不太可靠)。使用Adam来优化代码x以进行微调

与以往的NeRF微调方法的比较
  虽然用渲染损失进行微调在 view-conditioned 的NeRF回归方法[8,61]中很常见,但我们的微调方法在三维场景代码上使用扩散先验损失方面有所不同,这显著提高了对新视图的泛化,如5.3所示。

3.3 一些细节

  1. 先验梯度缓存

  三平面NeRF重建需要对每个场景码 xi 至少进行数百次优化迭代。公式(4)中单阶段损失中,扩散损失Ldiff 比原生NeRF渲染损失Lrend 需要更长的时间来验证,降低了整体效率。为了加速训练和微调,我们引入了一种技术称为先验梯度缓存:Prior Gradient Caching,缓存的反传梯度 xλdiffLdiff 重用在多个Adam步中,同时在每一步中刷新渲染梯度 xλrendLrend 。它允许更少的扩散渲染比。以下是一次算法的伪代码:

  1. 去噪的参数化和加权

  去噪模型 x ^ \hat{x} x^ϕ(x(t),t) 被实现为一个DDPM中的U-Net网络(共计122M参数)。其输入和输出分别是有噪声和去噪的三平面特征(三个平面的通道堆叠在一起)。对于测试的形式,我们采用[43:Progressive distillation for fast sampling of diffusion models]中的 v-参数化 v ^ \hat{v} v^ϕ(x(t),t),使 x ^ \hat{x} x^ = α(t)x(t)−σ(t) v ^ \hat{v} v^。关于等式(2)中扩散损失的加权函数w(t),LSGM [54]分别采用两种不同的机制来优化 latent xi 和扩散权重 ϕ ;我们发现使用NeRF自动解码器是不稳定的。相反,我们观察到在公式5中使用的基于信噪比的加权w(t) =(α(t)/σ(t)) 表现很好。


四、实验

41 数据集

  实验采用ShapeNet SRN [6,48]和Amazon Berkeley Objects(ABO)Tables[9]数据集。SRN数据集提供了两类的单对象场景,即汽车和椅子,汽车的 train/test 划分为2458/703,椅子为4612/1317。每个训练场景有50个来自一个球体的随机视图,每个测试场景有251个来自上半球的螺旋视图。ABO Tables数据集提供了1520/156个表场景的训练/测试划分,其中每个场景有来自上半球的91个视图。渲染分辨率为128×128。

4.2 无条件生成

  使用SRN Cars和ABO Tables 数据集对无条件生成进行评估。Cars数据集在生成尖锐和复杂的纹理方面提出了挑战,而Tables 数据集由不同的几何图形和真实的材料组成。模型在训练集的所有图像上进行1M次iter 训练。

  1. 验证方案和指标

  对于SRN Cars,按照Functa,我们从扩散模型中抽取704个场景,并使用测试集中固定的251个摄像机姿态渲染每个场景。对于ABO tabel,按照DiffRF采样了1000个场景,并使用10个随机摄像机渲染每个场景。生成质量度量标注:Frechet Inception Distance (FID)和 Kernel Inception Distance (KID)

  1. 与最新方法对比
    如表1所示,在SRN汽车上,一阶段SSDNeRF在KID(更适合小数据集)中明显优于EG3D。同时,它的FID明显优于Functa(使用了一个LDM,但具有低维的latent code)。在ABO tabels中,SSDNeRF的性能明显优于EG3D和DiffRF。

  2. 单阶段&两阶段
    在SRN Cars上,比较了单阶段训练,以及相同模型架构下的TV正则化两阶段训练,表1结果表明单阶段的优势。

在这里插入图片描述

4.3 稀疏视图NeRF重建

  Cars数据提出了恢复不同纹理的挑战, chair 数据需要精确地重建不同的形状。模型在训练集的所有图像上进行训练8000次迭代,我们发现更长的时间表会导致重建看不见的物体的性能下降。这种行为与插值的结果相一致

  1. 评估方案和指标

  使用PixelNeRF [59]的评价指标:给定测试场景采样的输入图像,通过guidance-finetuning 获得三平面场景码,对未知视角图像,评估新视图质量。图像质量指标:平均峰值信号调压比(PSNR)、结构相似度(SSIM)和学习感知图像补丁相似度(LPIPS)[60]。以及合成图像和真实图像之间的FID(如3DiM)。

  1. 与其他方法比较

  表2中:SSDNeRF有最好的LPIPS,感知保真度最好。3DiM生成高质量的图像(最好的FID),但对地面真相的保真度最低(PSNR);CodeNeRF报告了单视图汽车上最好的PSNR,但其有限的表达力导致输出模糊(图5)和LPIPS和FID;VisionNeRF在所有单视图指标上实现了平衡的性能,但可能难以在汽车看不见的一侧生成纹理细节(例如,图5中救护车的另一边)。此外,SSDNeRF在双视图重建方面具有明显的优势,在所有相关指标上取得最佳性能。

在这里插入图片描述

在这里插入图片描述

4.4 在稀疏视图数据上,训练SSDNeRF

本节在完整的SRN Cars训练集的稀疏视图子集上训练SSDNeRF,在每个场景中随机选取三个视图的固定集合请注意,与密集视图训练相比,由于整个训练数据集已经减少到其原始大小的6%,因此性能将会有预期的合理下降。

  1. 无条件生成

在训练中途,将三平面code 重置到它们的平均值。这有助于防止模型陷入一个过度拟合几何伪影的局部最小值。相应地将训练时间延长一倍。该模型实现了一个良好的FID的19.04±1.10和一个KID/10−3的8.28±0.60。结果如图7所示。

在这里插入图片描述

  1. 单视角重建

我们采用与5.3相同的训练策略。通过我们的指导-微调方法,该模型获得了0.106的LPIPS分,甚至优于表2中之前使用完整训练集的大多数方法。

  1. 与TV正则化的比较

图8 (b)显示了在训练过程中从三个视图中学习到的场景latent code 所代表的RGB图像和几何图形。相比之下,采用电视正则化的普通三平面自动解码器(图8 (a))往往不能从稀疏视图重建场景,导致严重的几何伪影。因此,以前在稀疏视图数据上训练具有表达延迟的两阶段模型是不可行的。

4.5 NeRF插值

     按照DDIM 的设置,对两个初始值 x(T)∼N(0,I) 进行采样,使用球面线性插值[46]( spherical linear interpolation)对其进行插值,然后使用确定性求解器得到插值样本。然而,正如[37,40]所指出的,标准的高斯扩散模型往往会导致非光滑的插值。在SSDNeRF中(结果如图9所示),我们发现早期停止稀疏视图重建训练的模型(a)产生合理的平滑过渡,而无条件生成训练的模型(b)产生不同但不连续的样本。这表明,早期停止保持了更平滑的先验,从而导致更好地泛化稀疏视图重建。

在这里插入图片描述

五、代码

5.1 无条件生成

这里选择配置文件:ssdnerf_cars_uncond.py,下载了对应权重:ssdnerf_cars_uncond_1m_emaonly.pth。配置文件的含义如下图所示:

ssdnerf_cars3v_uncond
   │      │      └── testing data: test unconditional generation
   │      └── training data: train on Cars dataset, using 3 views per scene
   └── training method: single-stage diffusion nerf training

stage2_cars_recons1v
   │     │      └── testing data: test 3D reconstruction from 1 view
   │     └── training data: train on Cars dataset, using all views per scene
   └── training method: stage 2 of two-stage training

提示:权重可能不是跟模型完全匹配,猜想可能因为在重建中还会训练;另外就是想引入一些网络的随机权重,增强泛化性。此外,大部分重要函数,都用C++做了编译(便于cuda并行计算),因此无法解析

1.生成噪音,作为输入

num_batches = len(data['scene_id'])                              # 6
noise = torch.randn((num_batches, *self.code_size), 
                     device=get_module_device(self))             # (8,3,6,128,128)

code_diff = code_diff.reshape(code.size(0), *self.code_reshape)  # (8,18,128,128

2.diffusion过程:
来自类 :class GaussianDiffusion(nn.Module) 中的ddim_sample:

sample_fn_name = 'ddim'

# 参数设置---------------------------------------------------------------------
x_t = noise
num_timesteps = self.test_cfg.get('num_timesteps', self.num_timesteps)  # 50
langevin_steps = self.test_cfg.get('langevin_steps', 0)                 # 0
langevin_t_range = self.test_cfg.get('langevin_t_range', [0, 1000])     # [0,1000]
timesteps = torch.arange(start=self.num_timesteps - 1, end=-1, 
                 step=-(self.num_timesteps / num_timesteps)).long().to(device)
# 从01000步中,随机选50个时间步

# 50次的逆扩散过程 (最重要!!)-------------------------------------------------
for step, t in enumerate(timesteps):                 # 0, 999
    t_prev = timesteps[step + 1]                     # 979
    x_t, x_0_pred = self.p_sample_ddim( x_t, t, t_prev, concat_cond=None, cfg=self.test_cfg, **kwargs)
    # 以上 self.p_sample_ddim过程,下面有详解

	eps_t_pred = (x_t - self.sqrt_alphas_bar[t] * x_0_pred) / self.sqrt_one_minus_alphas_bar[t]
	pred_sample_direction = np.sqrt(1 - alpha_bar_t_prev - tilde_beta_t * (eta ** 2)) * eps_t_pred
	x_prev = np.sqrt(alpha_bar_t_prev) * x_0_pred + pred_sample_direction      # (8,18,128,128)

    x_t = x_prev              # 更新x_t,再次循环

2.1 self.p_sample_ddim中,预定义的各种参数:扩散因子 β \beta β、α、 α ˉ \bar{α} αˉ

def prepare_diffusion_vars(self):

	## 1.扩散因子 beta :-----------------------------------------------------------
	scale = 1000 / diffusion_timesteps      # 一般是(1000步)
	beta_0 = scale * 0.0001
	beta_T = scale * 0.02
	self.betas = np.linspace(beta_0, beta_T, diffusion_timesteps, dtype=np.float64)
	
	## 2.alphas及其他参数 :--------------------------------------------------------
	self.alphas = 1.0 - self.betas
	self.alphas_bar = np.cumproduct(self.alphas, axis=0)
	self.alphas_bar_prev = np.append(1.0, self.alphas_bar[:-1])
	self.alphas_bar_next = np.append(self.alphas_bar[1:], 0.0)
	
	## 3. 计算扩散概率 q(x_t | x_0)--------------------------------------------
	self.sqrt_alphas_bar = np.sqrt(self.alphas_bar)
	self.sqrt_one_minus_alphas_bar = np.sqrt(1.0-self.alphas_bar)
	self.log_one_minus_alphas_bar = np.log(1.0-self.alphas_bar)
	self.sqrt_recip_alplas_bar = np.sqrt(1.0 / self.alphas_bar)
	self.sqrt_recipm1_alphas_bar = np.sqrt(1.0 / self.alphas_bar-1)
	
	##  4.计算后验概率 q(x_{
    
    t-1} | x_t, x_0)  -----------------------------------
	self.tilde_betas_t = self.betas * (1-self.alphas_bar_prev) / (1-self.alphas_bar)
	    
	## 5.clip log var for tilde_betas_0 = 0----------------------------------------
	self.log_tilde_betas_t_clipped = np.log(np.append(self.tilde_betas_t[1], self.tilde_betas_t[1:]))
	self.tilde_mu_t_coef1 = np.sqrt(self.alphas_bar_prev) / (1 - self.alphas_bar) * self.betas
	self.tilde_mu_t_coef2 = np.sqrt(self.alphas) * (1 - self.alphas_bar_prev) / (1 - self.alphas_bar)

2.2 self.p_sample_ddim中 逆扩散过程:标准Unet(带自注意力操作),前后维度不变

	alpha_bar_t_prev = self.alphas_bar[t_prev]  
	embedding = self.time_embedding(t)    # bs*[999] --> (bs,512)
	h, hs = x_t, []
	
	## 0.如果有condition,跟噪声x_t拼接
	if self.concat_cond_channels > 0:
	    h = torch.cat([h, concat_cond], dim=1)
	    
	## 1.下采样阶段 
	for block in self.in_blocks:
	    h = block(h, embedding)
	    hs.append(h)                           # (8,512,8,8)
	
	# 2.中间阶段
	h = self.mid_blocks(h, embedding)
	
	# 上采样
	for block in self.out_blocks:
	    h = block(torch.cat([h, hs.pop()], dim=1), embedding)    #   (8,128,128,128)
	denoising_output = self.out(h)                    

return denoising_output                                # (8,18,128,128)

2.3 self.p_sample_ddim中,得到预测的 x_0

denoising_output = self.denoising(x_t,t,concat_cond=None)   # 上步结果(8,18,128,128if self.denoising_mean_mode.upper() == 'V':
    x_0_pred = sqrt_alpha_bar_t * x_t - sqrt_one_minus_alpha_bar_t * denoising_output
    # 预测的x_0 是 x_t和denoising_output结果的线性组合

if clip_denoised:
    x_0_pred = x_0_pred.clamp(*clip_range)        # 按-22 截断,防止梯度消失/爆炸
  1. self.get_density:NeRF 预测密度
    位于 大类 class BaseNeRF(nn.Module)中:
density_grid = self.get_init_density_grid(num_scenes=8, device)

    tmp_grid = torch.full_like(density_grid, -1)     # (8,262144*[-1]


    # 生成64*64*64的网格:--------------------------------------------------------------------------
    if iter_density < 16:
        X = torch.arange(self.grid_size, dtype=torch.int32, device=device).split(S)   # 64
        Y = torch.arange(self.grid_size, dtype=torch.int32, device=device).split(S)
        Z = torch.arange(self.grid_size, dtype=torch.int32, device=device).split(S)

        # 只循环一次:
        for xs in X:
            for ys in Y:
                for zs in Z:
                    # 构建三维网格----------------------------------------------------------------------
                    xx, yy, zz = custom_meshgrid(xs, ys, zs)         # (64,64,64)
                    coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)],
                                       dim=-1)                       # [262144, 3]
                    indices = morton3D(coords).long()                # (262144): 打乱的索引
                    xyzs = (coords.float() - (self.grid_size - 1) / 2) * (2 * decoder.bound / self.grid_size)  # 以立方体中心点做归一化到(-1,1)
                    # 添加噪声--------------------------------------------------------------------------
                    half_voxel_width = decoder.bound / self.grid_size    # 1/64
                    xyzs += torch.rand_like(xyzs) * (2 * half_voxel_width) - half_voxel_width   
                    # 解码出密度值----------------------------------------------------------------------
                    sigmas = decoder.point_decode(xyzs , code)       # 下面有解析
                    tmp_grid[:, indices] = sigmas.clamp(max=torch.finfo(tmp_grid.dtype).max).to(tmp_grid.dtype)
            
                    # 求有效密度------------------------------------------------------------------------
		            valid_mask = (density_grid >= 0) & (tmp_grid >= 0)                                 # (8,262144)
		            density_grid[:] = torch.where(valid_mask, torch.maximum(density_grid * decay, tmp_grid), density_grid)   # (8,262144)
		            mean_density = torch.mean(density_grid.clamp(min=0))  # -1 regions are viewed as 0 density.       # 352.75
		            iter_density += 1
		
		            # 保存结果到 bitfield ----------------------------------------------------------------
		            density_thresh = min(mean_density, density_thresh)     # min(352.75, 0.1)
		            packbits(density_grid, density_thresh, density_bitfield)           
		            # 将 grid 与 thresh 作比较,每8个字节为一组,大于的存在 bitfield 中


def point_decode(self, xyzs, dirs, code, density_only=False):
    
    num_scenes, _, n_channels, h, w = code.size()        # 8,6,128,128

    # # 坐标上插值code特征:
    point_code = F.grid_sample(code.reshape(24,6,128,128),    
                               self.xyz_transform(xyzs),      # (24,1,262144,2) .取出其中的3个面坐标,并拼接
                               mode=self.interp_mode, padding_mode='border', 
                               align_corners=False).reshape(8,3,6, 262144)
    point_code = point_code.permute(0, 3, 2, 1).reshape( 2097152, 18)                         
    
    base_x = self.base_net(point_code)                 # linear(2097152, 18)-> (2097152, 64) 
    base_x_act = self.base_activation(base_x)
    sigmas = self.density_net(base_x_act).squeeze(-1)  # (2097152, 64) -> (2097152, 1) 
    rgb = None if density_only
  1. 渲染图像( decode)
image, depth = self.render(decoder, code, density_bitfield, h, w, test_intrinsics, test_poses, cfg=cfg)      

函数输入:# code:(8,3,6,128,128)bitfield:(8,32768) pose:(8,251,4,4) intrinsics:[131.250, 131.250, 64.0, 64.0]
以下是具体展开:

## 1. 得到射线起点和方向---------------------------------------------------
rays_o, rays_d = get_cam_rays(poses, intrinsics, h, w)  #(8,251,128,128,3) 拓展部分有解析

## 2.VolumeRenderer-----------------------------------------------------
outputs = decoder(rays_o, rays_d, code, density_bitfield, grid_size, dt_gamma=0, perturb=False)
   
    # 2.1查询每条射线跟 aabb(-1,-1,-1,1,1,1)交点---------------------------
    nears, fars = batch_near_far_from_aabb(rays_o, rays_d, self.aabb, self.min_near=0.2)
    #  nears、fars:(8, 4112384=251*128*128)

    # 2.2 渲染256次                

          step = 0
          while step < self.max_steps:
                # 射线上均匀采样空间点,插值得到对应密度 delta--------------------
                xyzs, dirs, deltas = march_rays(4112384, n_step, rays_alive, rays_t,
                                     rays_o, rays_d_single, self.bound=1, bitfield, 1, 
                                     grid_size, nears, fars,  align=128, perturb=False, dt_gamma=0, max_steps=256)

                sigmas, rgbs,= self.point_decode(xyzs, dirs, code) # xyzs(4112512, 3)  code(3,6,128,128)
                   
                    # point_decode展开如下:---------------------------------------
	                point_code_single = F.grid_sample(code_single, self.xyz_transform(xyzs_single),
	                                         mode='bilinear', padding_mode='border', align_corners=False ).squeeze(-2) # (3,6,4112512)


			        base_x = self.base_net(point_code)       # linear (4112512, 18) -> (4112512, 64) 
			        base_x_act = self.base_activation(base_x)
			        sigmas = self.density_net(base_x_act).squeeze(-1)     # (4112512, 64) -> (4112512, 1) 

		            # 渲染颜色(因为颜色需要位置信息+方向信息)----------------------------
		            sh_enc = sh_encode(dirs, self.degree=4)   # 三维坐标,位置编码至16维
		            color_in = self.base_activation(base_x + self.dir_net(sh_enc))  # linear+Silu->(4112512,64)
		            rgbs = self.color_net(color_in)    # linear:(64,3)
		            if self.sigmoid_saturation > 0:
		                rgbs = rgbs * (1 + self.sigmoid_saturation * 2) - self.sigmoid_saturation

                # 光线追踪和 图像生成:
                composite_rays(4112512, rays_t, sigmas, rgbs, deltas, depth, image, T_thresh=0.001)

composite_rays函数,通过对每条光线进行迭代计算,根据光线的传播、颜色值的累积和参数更新等步骤,最终得到复合后的图像信息。

最终得到(batchsize=8)三个结果:image(4112384,3)、weights_sum(4112384)、depth(4112384),经过reshape操作得到最终图像:

weights = torch.stack(outputs['weights_sum'], dim=0) 
rgbs = (torch.stack(outputs['image'], dim=0) + self.bg_color * (1 - weights.unsqueeze(-1))     # self.bg_color = 1
depth = torch.stack(outputs['depth'], dim=0) 

5.2 NeRF重建(多视角)训练: ssdnerf_cars_recons1v

配置文件为: ssdnerf_cars_recons1v.py. 每次选50个视角,进行重建

1.扩散过程损失(缓存的 code_list_很重要):

# 1.加载缓存cache--------------------------------------------------------------------------------------
code_list_, code_optimizers, density_grid, density_bitfield = self.load_cache(data)
      
      cache_list=[None, None, None, None, None, None, None, None]  # 初始化为bs个None
      # 循环bs次--------------------------------------------------------------
      for scene_state_single in cache_list:
          if scene_state_single is None:
              # 随机生成以下参数(torch.empty等)
              code_list_.append(self.get_init_code_(None, device))                   # (3,6,128,128)
              density_grid.append(self.get_init_density_grid(None, device))          # (262144)
              density_bitfield.append(self.get_init_density_bitfield(None, device))  #(32768)

# 2.读入图像------------------------------------------------------------------------------------------
concat_cond = None                     # 这里没有使用图像拼接条件
if 'cond_imgs' in data:
    cond_imgs = data['cond_imgs']  #  (8,50,128,128,3) 8个batch,50个视角的图
    cond_intrinsics = data['cond_intrinsics']  # [fx, fy, cx, cy] (bs,50,4)
    cond_poses = data['cond_poses']       # (bs,50,4,4)

# 3. 计算射线(根据pose和intrinc)-------------------------------------------------------------------
 cond_rays_o, cond_rays_d = get_cam_rays(cond_poses, cond_intrinsics, h, w)  # (bs,50,128,128,3)
 dt_gamma_scale = self.train_cfg.get('dt_gamma_scale', 0.0)        # 0.5
 dt_gamma = dt_gamma_scale / cond_intrinsics[..., :2].mean(dim=(-2, -1))    # [0.0038]*(8)

# 4.采样时间步 t------------------------------
t = self.sampler(num_batches).to(device)  # [605, 733, 138, 312, 997, 128, 178, 752]
# 5. 生成噪声
noise = torch.rand().to(device)    # (bs, 18, 128, 128)

# 6.根据时间t,采样噪声
x_t, mean, std = self.q_sample(x_0, t, noise)  # x_0就是上步的code_list_,经过缩放和激活

			    def q_sample(self, x_0, t, noise=None):

			        mean = var_to_tensor(self.sqrt_alphas_bar, t.cpu(), x_0.shape, device)  将输入变量sqrt_alphas转换为张量,并根据索引t提取相应的值    
			        std = var_to_tensor(self.sqrt_one_minus_alphas_bar, t.cpu(), x_0.shape, device)
			        return x_0 * mean + noise * std,  mean,  std

# 7.逆扩散过程-------------------------------------------------------------------------------------------
denoising_output = self.pred_x_0( x_t, t, grad_guide_fn=None, concat_cond=None, cfg=cfg, update_denoising_output=True)
		# 7.1 位置编码------------------------------------------------------------------------
		embedding = self.time_embedding(t)    # bs --> (bs,512) 正余弦编码+conv+silu。具体展开如下:
					embedding = self.blocks(self.sinusodial_embedding(t))
							    def sinusodial_embedding(timesteps, dim, max_period=10000):
							        half = dim // 2        # 128//2
							        freqs = torch.exp(-np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half)   
							        # freqs:64:[1.0, 0.8, 0.7, 0.6...0.0014]
							        args = timesteps[:, None].float() * freqs[None]
							        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
							
							        return embedding       # (bs,128)

		# 7.2 上下采样,用3D卷积和注意力qkv--------------------------------------------------------------
        h, hs = x_t, []                            # (bs,18,128,128)
        if self.concat_cond_channels > 0:
            h = torch.cat([h, concat_cond], dim=1)
        # forward downsample blocks
        for block in self.in_blocks:
            h = block(h, embedding)
            hs.append(h)                           # (bs,512,8,8)

        # forward middle blocks
        h = self.mid_blocks(h, embedding)

        # forward upsample blocks
        for block in self.out_blocks:
            h = block(torch.cat([h, hs.pop()], dim=1), embedding)    #   (bs,128,128,128)
        outputs = self.out(h)                      # (bs,18,128,128)

        return outputs

# 8.  x_0 是输入x_t与预测值denoising_output的加权差:----------------------------------------------------
if self.denoising_mean_mode.upper() == 'V':
    x_0_pred = sqrt_alpha_bar_t * x_t - sqrt_one_minus_alpha_bar_t * denoising_output

# 9.self.ddpm_loss-----------------------------------------------------------------------------------
class DDPMMSELossMod(DDPMLossMod):
    def mse_loss(pred, target):

    return F.mse_loss(pred, target, reduction='none')    # (bs,18,128,128) target就是v_t
    
loss.mean(dim=list(range(1, loss.ndim)=4))    # (bs)
loss_rescaled = loss * weight.to(timesteps.device)[timesteps] * self.weight_scale   # weight_scale=4

quartile = (timesteps / total_timesteps * 4).int()        # [2, 3, 0, 0, 3, 2, 0, 2]
for idx in range(4):
    loss_quartile = reduce_loss(loss[quartile == idx], reduction)   # mean
    log_vars[f'{prefix_name}_quartile_{idx}'] = loss_quartile.item()    # log_vars包含4个loss,求均值得到loss_diffusion
loss_diffusion.backward()

随后更新缓存梯度:

    if extra_scene_step > 0:
        prior_grad = [code_.grad.data.clone() for code_ in code_list_]        # (8,3,6,128,128)

2.inverse_code:NeRF的损失

# 10.NeRF渲染-----------------------------------------------------------------------------------
for inverse_step_id in range(n_inverse_steps):                  # 循环16if inverse_step_id % self.update_extra_interval == 0:       # 16 的整数倍
        # density_grid, density_bitfield 分别是0初始化的(bs,262144)(bs,32768);经过upgrade后,填充进了sigma
        # update_extra_state 是网格生成过程(同上述无条件生成的3.4节),得到coord与sigma
        self.update_extra_state(decoder, code, density_grid, density_bitfield, iter_density, density_thresh=cfg.get('density_thresh', 0.01))

    inds = raybatch_inds[inverse_step_id % num_raybatch]       # 从 0 到 第200份的射线
    rays_o, rays_d, target_rgbs = self.ray_sample(
        cond_rays_o, cond_rays_d, cond_imgs, n_inverse_rays, sample_inds=inds)   # 选出某份射线:(bs,4096,3)


	## 11.march_rays_train(cuda封装的函数)-------------------------------------------------------------
	xyzs_single, dirs_single, deltas_single, rays_single = march_rays_train(
					                    rays_o_single, rays_d_single, self.bound, density_bitfield_single, 1, grid_size_single, nears_single, fars_single, 
					                    perturb=perturb, align=128, force_all_rays=True, dt_gamma=dt_gamma_single.item(), max_steps=self.max_steps)
	nears, fars = batch_near_far_from_aabb(rays_o, rays_d, self.aabb, self.min_near)
	weights_sum_, depth_, image_ = composite_rays_train(sigmas, rgbs, deltas_, rays_, T_thresh)   # (8,4096)(8,4096)(8,4096,3)
	
	
	## 12.最终的2个损失
	pixel_loss = self.pixel_loss(out_rgbs, target_rgbs, **kwargs) * (scale * 3)     # F.mse_loss, 权重weight=20,求平均   ->7.73
	reg_loss = (code.abs() ** 2).mean()                       # code:特征的大小  -> 3.99e-11
	loss = pixel_loss + reg_loss
	
	## 13. 复制prior_grad的梯度到 code_ (在一个inverse_step中,code_梯度永远来自prior_grad)----------------
    if prior_grad is not None:
        if isinstance(code_, list):
            for code_single_, prior_grad_single in zip(code_, prior_grad):           # 循环8次
                code_single_.grad.copy_(prior_grad_single)      # (3,6,128,128)

    loss.backward()
  1. 更新cache
self.save_cache(
    code_list_, code_optimizers,
    density_grid, density_bitfield, data['scene_id'], data['scene_name'])

5.3 稀疏重建推理: ssdnerf_cars3v_recons1v

配置文件为: ssdnerf_cars3v_recons1v.py 每次选曲1个视角,进行重建3维场景

with torch.no_grad():
    cond_mode = self.test_cfg.get('cond_mode', 'guide')     # 'guide_optim'
    # 1.val_guide 得到三平面特征code, 以及空间网格特征
    code, density_grid, density_bitfield = self.val_guide(data, **kwargs)    # code:(bs,18,128,128)
            # 1.0 初始化三大变量--------------------------------------------------------------
            density_grid = torch.zeros((num_scenes, self.grid_size ** 3), device=device)        # (8,262144)
            density_bitfield = torch.zeros((num_scenes, self.grid_size ** 3 // 8), dtype=torch.uint8, device=device)    # (8,32768)
            noise = torch.randn((num_scenes, *self.code_size), device)          # (8,3,6,128,128)
            
            # 1.1 扩散75-----------------------------------------------------------------------
		    x_t = noise
		    num_timesteps = self.test_cfg.get('num_timesteps', self.num_timesteps)        # 75
	        langevin_t_range = self.test_cfg.get('langevin_t_range', [0, 1000])           # [0,1000]
		    timesteps = torch.arange(start= 75 - 1, end=-1, step=-(1000/ 75)    # 75[0,1000]的随机数

            for step, t in enumerate(timesteps):                                    # 逆扩散75步,每次更新x_t
                x_t, x_0_pred = self.p_sample_ddim( x_t, t, t_prev, concat_cond=None)

			          # 1.1.1 Unet 通过x_t与t,预测x_0:---------------------------------------------------------
				   	  denoising_output = self.denoising(x_t, t, concat_cond=concat_cond)      # (8,18,128,128)
					  x_0_pred = sqrt_alpha_bar_t * x_t - sqrt_one_minus_alpha_bar_t * denoising_output
					  self.update_extra_state(decoder, x_0_pred, density_grid, density_bitfield,
					                             0, density_thresh=self.test_cfg.get('density_thresh', 0.01))
			          # 1.1.2 利用 x_0_pred 三平面特征,特征空间解码,更新density_grid,density_bitfield-------------------
			          rays_o, rays_d, target_rgbs = self.ray_sample(cond_rays_o, cond_rays_d, 
			                                             cond_imgs, n_inverse_rays, sample_inds=inds)  # (bs,16384,3)
			          # 1.1.3 NeRF 射线解码-----------------------------------------------------------------
			          nears, fars = batch_near_far_from_aabb(rays_o, rays_d, self.aabb, self.min_near=0.2)
			          xyzs_single, dirs_single, deltas_single, rays_single = march_rays_train()
			          sigmas, rgbs, num_points = self.point_decode(xyzs, dirs, code)    # F_grid插值,fc层计算密度/颜色 xyz(bs,93440,3-> (724480)(3495936,3) 8个点数
			          weights_sum, depth, image = batch_composite_rays_train(sigmas, rgbs, deltas, rays, num_points, T_thresh)   # (8,16384)(8,16384)(8,16384,3)
			          # 1.1.4 计算loss--------------------------------------------------------------
			          loss = grad_guide_fn(x_0_pred)
			          pixel_loss = self.pixel_loss(out_rgbs, target_rgbs, **kwargs) * (scale * 3)
			          reg_loss = self.reg_loss(code, **kwargs)
			          loss = pixel_loss + reg_loss
			          grad = torch.autograd.grad(loss, x_t)[0]   # (8,18,128,128) 计算 loss对x_t的梯度
			          torch.set_grad_enabled(False)
			          x_0_pred.detach_()
			          x_0_pred -= grad * ((sqrt_one_minus_alpha_bar_t ** (2 - snr_weight_power * 2))
			              * (sqrt_alpha_bar_t ** (snr_weight_power * 2 - 1))* guidance_gain)   # guidance_gain=52484.8
			          eps_t_pred = (x_t - self.sqrt_alphas_bar[t] * x_0_pred) / self.sqrt_one_minus_alphas_bar[t]
			          pred_sample_direction = np.sqrt(1 - alpha_bar_t_prev - tilde_beta_t * (eta ** 2)) * eps_t_pred    # (8,18,128,128)
			          x_t = np.sqrt(alpha_bar_t_prev) * x_0_pred + pred_sample_direction                             # (8,18,128,128)

    # 2. 继续逆扩散,更新特征---------------------------------------------------------------------------
    code, density_grid, density_bitfield = self.val_optim(data, code_=self.code_activation.inverse(code).requires_grad_(True),
                                                          density_grid, density_bitfield=density_bitfield, **kwargs)
            # 2.1 根据pose,计算射线-----------------------------------------------------------------------
	        cond_imgs = data['cond_imgs']  # (num_scenes, num_imgs, h, w, 3)
	        cond_intrinsics = data['cond_intrinsics']  # (num_scenes, num_imgs, 4), in [fx, fy, cx, cy]
	        cond_poses = data['cond_poses']
	
	        cond_rays_o, cond_rays_d = get_cam_rays(cond_poses, cond_intrinsics, h, w)     # (bs,1,128,128,3)
         
            for inverse_step_id in range(n_inverse_steps):                             # 25个inverse步骤
                loss, log_vars = diffusion(self.code_diff_pr(code), return_loss=True, concat_cond=None)
                # 上面是diff_train,采样t,Unet去噪, 计算4个ddpm_mes损失
                loss.backward()
                
                prior_grad = code_.grad.data.clone()
                self.inverse_code(decoder, cond_imgs, cond_rays_o, cond_rays_d, dt_gamma=dt_gamma, cfg code_=code_, density_grid, density_bitfield, prior_grad=prior_grad)

                     # inverse_code包含以下步骤:
                     for inverse_step_id in range(n_inverse_steps):        # 循环4if inverse_step_id % self.update_extra_interval == 0:
                            self.update_extra_state(decoder, code, density_grid, density_bitfield,iter_density, 
                                                    density_thresh=cfg.get('density_thresh', 0.01))    # 生成网格,F特征插值、base_net更新sigma

	                     rays_o, rays_d, target_rgbs = self.ray_sample( cond_rays_o, cond_rays_d, cond_imgs, n_inverse_rays, sample_inds=inds)        # (bs, 16384, 3)
	                     out_rgbs, loss, loss_dict = self.loss( decoder, code, density_bitfield, target_rgbs, rays_o, rays_d)
	                     code_.grad.copy_(prior_grad)
	                     loss.backward()
 # 3.-------------------------------------------------------------------------------------------
 log_vars, pred_imgs = self.eval_and_viz(data, decoder, code, density_bitfield, viz_dir=viz_dir, cfg)
        image, depth = self.render( decoder, code, density_bitfield, h, w, test_intrinsics, test_poses, cfg=cfg)      #code:(8,3,6,128,128)bitfield:(8,32768) pose:(8,251,4,4)

下面是对第3部分:self.render的展开

# 输入数据的num_imgs有250张
test_intrinsics = data['test_intrinsics']  # (num_scenes, num_imgs, 4), in [fx, fy, cx, cy]
test_poses = data['test_poses']
test_imgs = data['test_imgs']  # (num_scenes, num_imgs, h, w, 3)

outputs = decode(batch_near_far_from_aabb、march_rays、self.point_decode等,循环self.max_steps=256)
        results = dict(
            weights_sum=weights_sum,          # bs个(4096000)
            depth=depth,                      # bs个(4096000)
            image=image)                      # bs个(40960003)

weights = torch.stack(outputs['weights_sum'], dim=0) if num_scenes > 1 else outputs['weights_sum'][0]    # 01
rgbs = (torch.stack(outputs['image'], dim=0) if num_scenes > 1 else outputs['image'][0]) \
      + self.bg_color * (1 - weights.unsqueeze(-1))
depth = torch.stack(outputs['depth'], dim=0) if num_scenes > 1 else outputs['depth'][0]
pred_imgs = image.permute(0, 1, 4, 2, 3).reshape(8 * 250, 3, h, w).clamp(min=0, max=1)

# 测试生成的全视角图像,和数据集真实图像的指标
test_psnr = eval_psnr(pred_imgs, target_imgs)   # bs*250个数
test_ssim = eval_ssim_skimage(pred_imgs, target_imgs, data_range=1)

# 最后是保存图像(与真实图像拼接,便于对比)在work_dir文件夹中,此处忽略。。。。。。。。。。。。。。

5.4 稀疏重建训练: ssdnerf_cars3v_recons1v

配置文件为: ssdnerf_cars3v_recons1v.py 每次选曲3个视角,进行重建训练

# 1.载入缓存.若无缓存,初始化(生成)code等三大变量:------------------------------------------------
code_list_, code_optimizers, density_grid, density_bitfield = self.load_cache(data)
code = self.code_activation(torch.stack(code_list_, dim=0), update_stats=True)

# 2.载入图像----------------------------------------------------------------------------
cond_imgs = data['cond_imgs']  # (num_scenes, num_imgs, h, w, 3)     (8,3,128,128,3)
cond_intrinsics = data['cond_intrinsics']  # (num_scenes, num_imgs, 4), in [fx, fy, cx, cy] (8,3,4)
cond_poses = data['cond_poses']       # (8,3,4,4)

num_scenes, num_imgs, h, w, _ = cond_imgs.size()
# (num_scenes, num_imgs, h, w, 3)
cond_rays_o, cond_rays_d = get_cam_rays(cond_poses, cond_intrinsics, h, w)  # (8,3,128,128,3)
dt_gamma_scale = self.train_cfg.get('dt_gamma_scale', 0.0)        # 0.5
# (num_scenes,)
dt_gamma = dt_gamma_scale / cond_intrinsics[..., :2].mean(dim=(-2, -1))    # [0.0038]*(8)

# 3. 扩散模型损失
loss_diffusion, log_vars = diffusion(self.code_diff_pr(code), concat_cond=None, return_loss=True,
                                     x_t_detach=x_t_detach, cfg=self.train_cfg)
                t = self.sampler(num_batches).to(device)              # [630, 757, 989, 452,...] bs个
                noise = torch.randn((num_batches, *image_shape))
                x_t, mean, std = self.q_sample(x_0, t, noise)         # 从x_0 采样到x_t 
                denoising_output = self.denoising(x_t, t, concat_cond=concat_cond)      # Unet逆扩散过程(8,18,128,128)

loss_diffusion.backward()                

# 4. 更新 先验梯度-------------------------------------------------------------------------------------
prior_grad = [code_.grad.data.clone() for code_ in code_list_]        # (8,3,6,128,128)

# 5.self.inverse_code:NeRF渲染:----------------------------------------------------------------------
raybatch_inds, num_raybatch = self.get_raybatch_inds(cond_imgs, n_inverse_rays)  # (8,3,128,12834096 -> 随机分成 12份的 (8,4096)
self.update_extra_state(decoder, code, density_grid, density_bitfield,
                        iter_density, density_thresh=cfg.get('density_thresh', 0.01))

nears, fars = batch_near_far_from_aabb(rays_o, rays_d, self.aabb, self.min_near)
xyzs_single, dirs_single, deltas_single, rays_single = march_rays_train( rays_o_single, rays_d_single, 
                                 self.bound, density_bitfield_single, 1, grid_size_single, nears_single, fars_single,  perturb=True)


sigmas, rgbs, num_points = self.point_decode(xyzs, dirs, code)    # F_grid插值,fc层计算密度/颜色 (8,438144,3-> (3495424)(3495424,3) 8个点数
            weights_sum, depth, image = batch_composite_rays_train(sigmas, rgbs, deltas, rays, num_points, T_thresh)   # (8,4096)(8,4096)(8,4096,3)

pixel_loss = self.pixel_loss(out_rgbs, target_rgbs, **kwargs) * (scale * 3)     # F.mse_loss, 权重weight=20,求平均   ->7.73
reg_loss = self.reg_loss(code, **kwargs)
loss = pixel_loss + reg_loss

for code_single_, prior_grad_single in zip(code_, prior_grad):           # 循环8次
    code_single_.grad.copy_(prior_grad_single)      # (3,6,128,128)
loss.backward()

# 6.更新 density_grid, density_bitfield------------------------------------------------------------
self.update_extra_state( decoder, code, density_grid, density_bitfield,
                0, density_thresh=self.train_cfg.get('density_thresh', 0.01))


# 7.loss_decoder = self.loss_decoder(decoder, code, density_bitfield, cond_rays_o, cond_rays_d,
                                     cond_imgs, dt_gamma, cfg=self.train_cfg) #-----------------------
    # decoder_loss作为第三个损失,与第二部损失相同。内容如下:
	nears, fars = batch_near_far_from_aabb(rays_o, rays_d, self.aabb, self.min_near)
	for i in range(batchsize):
		xyzs_single, dirs_single, deltas_single, rays_single = march_rays_train(
							                   rays_o_single, rays_d_single, self.bound, density_bitfield_single,
							                   1, grid_size_single, nears_single, fars_single,
							                   perturb=perturb, align=128, force_all_rays=True,
							                   dt_gamma_single.item(), max_steps=256)      # (4096,3->(435968,3)(435968,3)(435968,2)
	    sigmas, rgbs, num_points = self.point_decode(xyzs, dirs, code)    # F_grid插值,fc层计算密度/颜色 (8,447212,3-> (3495936)(3495936,3) 8个点数
	    weights_sum, depth, image = batch_composite_rays_train(sigmas, rgbs, deltas, rays, num_points, T_thresh)   # (8,4096)(8,4096)(8,4096,3)

# 8.更新 decoder 梯度------------------------------------------------------------------------------
for code_, prior_grad_single in zip(code_list_, prior_grad):
    code_.grad.copy_(prior_grad_single)
loss_decoder.backward()

# 9.缓存cache-------------------------------------------------------------------------------------
self.save_cache( code_list_, code_optimizers, density_grid, density_bitfield, data['scene_id'], data['scene_name'])

# 10.验证-----------------------------------------------------------------------------------------
train_psnr = eval_psnr(out_rgbs, target_rgbs)

*.cuda封装好的函数

lib/ops/raymarching/scr/raymarching.h 中包含了以下函数(具体定义在raymarching.cu文件)

#pragma once

#include <stdint.h>
#include <torch/torch.h>

void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars);
void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords);
void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices);
void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords);
void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield);

void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, at::Tensor noises);
void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor depth, at::Tensor image);
void composite_rays_train_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs);

void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor noises);
void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights_sum, at::Tensor depth, at::Tensor image);

局限性

     目前,该方法训练和测试过程中都依赖于Ground Truth 摄像机的参数。未来的工作可能会探索transform-invariant 模型。此外,随着训练时间的延长,扩散先验会变得不连续,从而影响泛化。虽然暂时使用了早期停止,但更好的网络设计或更大的训练数据集可能能够从根本上解决这个问题。


拓展

1.FID和KID、LPIPS指标

Frechet Inception Distance (FID) 是一种用于评估 GAN 生成的图像质量的指标。它基于生成的图像与真实图像之间的Fréchet距离,该距离衡量了两个图像分布之间的相似性。FID越低,表示生成的图像质量越好。

在这里插入图片描述

LPIPS:来自论文《The Unreasonable Effectiveness of Deep Features as a Perceptual Metric》
在这里插入图片描述
数值越小代表两张图像越相似。将两个输入送入神经网络F(可以为VGG、Alexnet、Squeezenet)中进行特征提取,对每个层的输出进行激活后归一化处理,然后经过w层权重点乘后计算L2距离。
在这里插入图片描述

Kernel Inception Distance (KID) 是另一种评估 GAN 生成的图像质量的指标。它基于生成的图像与真实图像之间的核函数距离,通过将特征向量映射到高维空间,并计算它们之间的核矩阵的Frechet距离来衡量真实图像和生成图像之间的差异。KID也是一种衡量生成图像质量的指标,与FID类似。

具体代码实现,见 https://blog.51cto.com/u_16175458/6906283

2. TV 正则化

TV正则化,全称是Total Variation Regularization。TV正则化通过最小化图像的梯度幅度来实现对图像的平滑处理,

具体来说,对于一个二维图像,TV正则化可以通过最小化图像的梯度幅度来实现平滑处理。从而抑制图像中的噪声和细节。使图像变得更加平滑。TV正则化通常会被应用于优化问题的正则化项中,以平衡数据拟合和平滑度之间的关系。

对于NeRF(Neural Radiance Fields)等主要用于三维重建的方法,TV正则化有助于改善重建结果的质量。这是因为在三维重建中,由于数据的稀疏性和噪声等因素,重建结果往往会包含不必要的细节和噪声

另外,TV正则化还可以帮助增强对深度学习模型的约束,有助于提高模型的泛化能力和抗噪声能力。因此,对于NeRF等三维重建任务,应用TV正则化有助于改善重建结果的质量,并提高模型的稳健性。

3.NeRF渲染(从相机内外参,到射线)

rays_o, rays_d = get_cam_rays(poses, intrinsics, h, w)
主要解析 get_cam_rays函数,位于 lib/core/utils/nerf_utils.py

directions = get_ray_directions(h, w, intrinsics, norm=False)  # (8, 251, h, w,3)

      def get_ray_directions():
	    batch_size = intrinsics.shape[:-1]
	    x = torch.linspace(0.5, w - 0.5, w, device=device)
	    y = torch.linspace(0.5, h - 0.5, h, device=device)
	    
	    # 相对坐标,除以焦距,就是射线方向----------------------------------------------
	    directions_xy = torch.stack(
	        [((x - intrinsics[..., 2:3]) / intrinsics[..., 0:1])[..., None, :].expand(*batch_size, h, w),
	         ((y - intrinsics[..., 3:4]) / intrinsics[..., 1:2])[..., :, None].expand(*batch_size, h, w)], dim=-1)
	    # (8,251,128,128, 2)
	    directions = F.pad(directions_xy, [0, 1], mode='constant', value=1.0)
	    
	    return directions


rays_o, rays_d = get_rays(directions, c2w, norm=True)

       def get_rays(directions, c2w, norm=False):
		   rays_d = directions @ c2w[..., None, :3, :3].transpose(-1, -2)   # (8,251,128,128,3)
		   rays_o = c2w[..., None, None, :3, 3].expand(rays_d.shape)        # (8,251,128,128,3)
		   if norm:
		       rays_d = F.normalize(rays_d, dim=-1)                         # 坐标归一化到 -11之间
		   return rays_o, rays_d

猜你喜欢

转载自blog.csdn.net/qq_45752541/article/details/134984758