超分辨图像无限生成!Diffusion Transformer 任意分辨率上采样

作者丨科技猛兽

编辑丨极市平台

本文目录

1 Inf-DiT:Diffusion Transformer 任意分辨率上采样

(来自清华大学,唐杰团队)

1 Inf-DiT 论文解读

1.1 超高分辨率图像生成问题的挑战:GPU 显存需求

1.2 单向块注意力机制

1.3 O(N) 显存消耗的推理过程

1.4 Inf-DiT 架构

1.5 全局和局部一致性

1.6 实验结果

太长不看版

扩散模型在图像生成方面表现出了很显著的性能。然而对于生成超高分辨率的图像 (比如 4096 ×4096) 而言,由于其 Memory 也会二次方增加,因此生成的图像的分辨率通常限制在 1024×1024。在这项工作中。作者提出了一种单向块注意力机制,可以在推理过程中自适应地调整显存开销并处理全局依赖关系。在这个模块的基础上,作者使用 DiT 的架构,并逐渐执行上采样,最终开发了一个无限的超分辨率模型 Inf-DiT,能够对各种形状和分辨率的图像进行上采样。综合实验表明,Inf-DiT 在生成超高分辨率图像方面取得了 SOTA 性能。与常用的 UNet 结构相比,Inf-DiT 在生成 4096×4096 图像时可以节省超过5倍显存。

图1:基于 SDXL、DALL-E 3 和真实图像,选择出的 Inf-DiT 超高分辨率上采样示例

本文做了哪些具体的工作

提出了单向块注意力机制 (Unidirectional Block Attention,UniBA) 算法,在推理过程中将最小显存消耗从 降低到 , 其中 表示边长。该机制还能够通过调整并行生成的块数量、在显存和时间开销之间进行权衡来适应各种显存限制。基于这些方法,训练了一个图像上采样扩散模型 Inf-DiT,这是一个 700M 的模型,能够对不同分辨率的和形状图像进行上采样。Inf-DiT 在机器 (HPDV2 和 DIV2K 数据集) 和人工评估中都实现了最先进的性能。设计了多种技术来进一步增强局部和全局一致性,并为灵活的文本控制提供 Zero-Shot 的能力。

1Inf-DiT:Diffusion Transformer 任意分辨率上采样

论文名称:Inf-DiT: Upsampling Any-Resolution Image with Memory-Efficient Diffusion Transformer (Arxiv 2024.03)

论文地址:

https://arxiv.org/pdf/2405.04312

项目地址:

https://github.com/THUDM/Inf-DiT

1.1 超高分辨率图像生成问题的挑战:GPU 显存需求

近年来,扩散模型取得了快速发展,显着推动了图像生成和编辑领域的发展。尽管取得了进步,但仍然存在一个关键的限制:现有图像扩散模型生成的图像的分辨率通常被限制在 1024×1024 像素或更低,这对生成超高分辨率图像提出了重大挑战,这在包括复杂的设计项目、广告和海报和墙壁纸的创建等各种实际应用中是必不可少的。

生成高分辨率的常用方法是 Cascaded Generation,它首先生成低分辨率图像,然后应用多个上采样模型逐步提高图像的分辨率。这种方法将高分辨率图像的生成分解为多个子任务。基于前一阶段产生的结果,后期的模型只需要执行局部的生成。在级联结构的基础上,DALL-E2[1]和 Imagen[2]都可以有效地生成 1024×1024 分辨率的图像。

上采样到更高分辨率的图像的最大挑战是关于 GPU 显存需求。例如,如果使用广泛采用的 U-Net 架构,例如 SDXL[3]进行图像推理 (见下图2),可以观察到显存消耗随着分辨率的增加而急剧增加。具体来说,如果生成 4096×4096 分辨率的图像,其包含超过 16 亿个像素需要超过 80GB 的显存,超过了标准 RTX 4090 或 A100 显卡的容量。此外,用于高分辨率图像生成的训练模型的过程加剧了这些需求,因为它需要额外的显存来存储梯度、优化器状态等。LDM[4]通过利用变分自动编码器 (Variational Autoencoder,VAE) 压缩图像并在更小的 Latent Space 中生成图像来减少显存消耗。然而,过高的压缩比会大大降低生成的质量,对显存消耗的减少造成了严重的限制。

图2:本文模型和 SDXL 架构之间不同分辨率的推理期间显存使用的比较

1.2 单向块注意力机制

作者观察到生成超高分辨率图像的关键障碍是显存限制。随着图像的分辨率的增加,网络中相应的 hidden states 的大小呈二次方的复杂度扩展。例如,1层中形状为 2048×2048×1280 的单个 hidden state 需要 20GB 的显存,这使得很难生成非常大的图像。如何避免将整个图像的 hidden state 存储在内存中成为关键的问题。

本文的方法单向块注意力 (Unidirectional Block Attention, UniBA) 如下图3所示。对于每个层,每个 Block 直接依赖于3个一阶相邻的 Block:顶部的 Block、左侧和左上角的 Block。例如,如果采用 Diffusion Transformer (DiT) 架构,Block 之间的依赖关系是注意力操作,每个 Block 的 Query 向量与4个 Block 的 Key,Value 向量交互:位于其左上角和自身的3个 Blocks,如图3所示。

图3:左侧:单向块注意力。在我们的实现中,每个 Block 直接取决于每一层的3个 Blocks:左上角的块、左侧和顶部的 Block;右侧:Inf-DiT 的推理过程。Inf-DiT 根据内存大小每次生成 n×n 个 Block。在这个过程中,只有后续块所依赖的块的 KV-cache 存储在内存中

Transformer 中的 UniBA 过程可以表述为:

1.3 O(N) 显存消耗的推理过程

1.4 Inf-DiT 架构

如下图4所示是 Inf-DiT 架构,它基于 DiT[5]。与基于卷积的结构 (如 U-Net[6]) 相比,DiT 仅利用注意力作为 Patch 之间的交互机制,可以方便地实现 UniBA。为了适应 UniBA,提高上采样的性能,作者做了如下几个修改和优化。

图4:Inf-DiT 架构

模型输入

位置编码

1.5 全局和局部一致性

使用 CLIP Image Embedding 针对全局一致性

低分辨率 (LR) 图像中的全局语义信息,如艺术风格和物体材料,在上采样过程中起着至关重要的作用。然而,与文生图像模型相比,上采样模型还有一个额外的任务:理解和分析 LR 图像的语义信息,大大增加了模型的负担。在没有文本数据进行训练时尤其具有挑战性,因为高分辨率图像很少具有高质量的配对文本,这使得模型的这些方面变得困难。

使用 Nearby LR Cross Attention 针对局部一致性

尽管将 LR 图像与噪声输入 Concat 起来已经为模型学习 LR 和 HR 图像之间的局部对应关系提供了良好的归纳偏差,但仍然可能存在连续性的问题。原因是,对于给定的 LR Block,有几种上采样的可能性,这需要与附近的几个 LR Block 一起分析以选择一种解决方案。假设上采样仅基于其左侧的 LR Block 执行,它可能会选择一个与右侧和下方 LR Block 冲突的 HR 生成解决方案。然后,当将 LR Block 上采样到右侧时,如果模型认为符合其对应的 LR Block 比与左侧的 Block 连续更重要,则会生成一个与先前块不连续的 HR Block。一个简单的解决方案是将整个 LR 图像输入到每个 Block,但当 LR 图像的分辨率也很大时,它的成本太高。

为了解决这个问题,作者引入了 Nearby LR Cross-Attention。在第一层中,每个 Block 对周围的 3×3 LR Block 进行 Cross-Attention,以捕获附近的 LR 信息。实验结果表明,这种方法显着减少了生成不连续图像的概率。值得注意的是,这个操作不会改变我们的推理过程,因为在生成之前知道整个 LR 图像。

1.6 实验结果

训练细节

本文的数据集包括 LAION-5B[9]的一个子集,分辨率高于 1024×1024,美学得分高于 5 的 100000 来自互联网的分辨率墙纸。在训练过程中,作者使用 512×512 分辨率的固定大小的 Image crop。由于上采样只能使用局部信息进行,因此在推理过程中可以直接用于更高的分辨率,这对于大多数生成模型来说并不容易。

数据准备

由于扩散模型生成的图像通常包含残余噪声和各种细节不准确,因此增强上采样模型的鲁棒性以解决这些问题变得至关重要。作者采用类似于 Real-ESRGAN[10]的方法对训练数据中的低分辨率输入图像执行各种退化。

在处理分辨率高于 512 的图像时,有两种替代方法:一种是直接执行随机裁剪,另一种是在执行随机裁剪之前将较短的边调整为 512。虽然直接裁剪方法在高分辨率图像中保留了高频特征,但调整大小后裁剪方法避免了频繁裁剪单个颜色背景的区域,不利于模型的收敛。因此在实践中,作者从这两种处理方法中随机选择裁剪训练图像。

作者将 Block Size 设置为 128,Patch Size 设置为 4,即每张图片被分成 4×4 Blocks,每个 Block 被分成 32×32 Patches。作者使用 EDM[11]框架训练,并将上采样设置为4倍。由于上采样任务更关注图像的高频细节,我们将训练噪声分布的均值和标准差调整为 -1.0 和 1.4。为了解决训练期间的溢出问题,作者采用了 BF16 格式。采用的 CLIP 模型是在 Datacomp 数据集[12]上预训练的 ViT-L/16。由于 CLIP 只能处理分辨率为 224×224 的图像,作者首先将 LR 图像的大小调整为 224×224,然后将它们输入到 CLIP 中。

机器评测

作者对 Inf-DiT 与超高分辨率图像生成任务的最新方法进行了定量比较,Baseline 包含两类高分辨率生成方法:

1) 直接高分辨率图像生成,包括 SDXL、MultiDiffusion[13]、ScaleCrafter[14]。

2) 基于超分辨率技术的高分辨率图像生成,包括 BSRGAN[15]、DemoFusion[16]。

使用 FID[17]来评估超高分辨率生成的质量。为了进一步验证我们模型的超分辨率能力,作者还在经典的超分辨率任务上将其与著名的超分辨率模型进行了基准测试。

超高分辨率生成结果

作者使用 HPDV2 的测试集进行评估。它包含 3200 个 Prompt,分为4类:”Animation”, “Concept-art”, “Painting”, 和 “Photo”。这允许对各种域和样式的模型生成能力进行全面的评估。作者在 2048 和 4096 两个分辨率上面进行测试。对于基于超分辨率的模型,作者首先使用 SDXL 生成 1024×1024 分辨率的图像并在没有文本的情况下对其进行上采样。作者使用 BSRGAN 的 2 倍和 4 倍版本分别生成 2048×2048 和 4096×4096 分辨率的图片。虽然 Inf-DiT 是在上采样 4× 的设置下训练的,但作者发现它可以在较低的上采样倍数下很好地泛化。对于 2048×2048 分辨率,作者直接将 LR 图像的大小从 1024×1024 调整到 2048×2048,并将其与噪声输入拼接起来。

如下图5所示的实验结果显示,本文模型在平均得分上超过了所有竞争对手。这展示了本文模型生成高分辨率细节和全局信息的能力。唯一的例外是 4096×4096 分辨率上的 FID 指标,略微落后于 BSRGAN。本文模型可以应用于所有生成模型,不仅仅是 SDXL。

图5:HPDV2 数据集上超高分辨率生成方法的定量比较结果

图6:2048×2048 分辨率下对不同方法的定性比较

图7:不同方法在 4096×4096 分辨率下的定性比较

超分辨率实验结果

除了生成高分辨率图像的能力外,Inf-DiT 也可以用作经典的超分辨率模型。作者对 DIV2k 验证集进行评估,该数据集包含不同场景下多个真实世界的高分辨率图像。作者将图像退化固定为 4× 下采样的双三次插值。在与固定分辨率模型 LDM 和 StableSR 进行比较之前,作者从高分辨率图像中心裁剪特定的小块作为 ground truth。在整个过程中,作者使用感知 (FID, FIDcrop) 和保真度 (PSNR, SSIM) 指标来确保详细和全面的评估。

实验结果如图8所示,本文模型在所有指标上实现了最先进的性能。这意味着,作为超分辨率模型,Inf-DiT 不仅擅长在任意尺度上执行超分辨率,还擅长在恢复与原始图像非常相似的结果的同时最佳地保留全局和局部信息。

图8:DIV2K 数据集与最先进的超分辨率方法的比较结果

人类评测结果

为了从人类的角度更准确地反映其生成质量,作者进行了人工评估。作者比较了4个模型,对每个模型随机选择十个比较集,每个比较集包含来自四个模型的输出,最终总共有 40 个数据。为了确保公平,作者在每个比较集中随机化模型输出序列。人类评估者被要求根据3个标准评估模型:细节真实性、全局连贯性和与原始低分辨率输入的一致性。每个评估者平均接收 20 组图像。在每个集合中,评估者需要根据3个标准将四个模型生成的图像从最高排名排名最低。

最终收集了 3,600 次比较。如图9所示,本文模型在所有3个标准上都优于其他3种方法。特别值得注意的是,其他3个模型在3个评估标准中至少有一个的排名相对较低,而 Inf-DiT 在所有3个标准上都取得了最高分:细节真实性、全局连贯性和与低分辨率输入的一致性。这表明本文模型是唯一能够同时在高分辨率生成和超分辨率任务中表现出色的模型。

图9:人类评估结果

迭代上采样

由于本文的模型可以对任意分辨率的图像进行上采样,因此测试模型是否可以迭代地对自身生成的图像进行上采样是很自然的想法。在这项研究中,作者尝试通过3次迭代上采样从 32×32 分辨率图像上采样 64 倍之后生成 2048×2048 分辨率图像。图 10 展示了两种样本。在第1个样本中,模型在上采样3个阶段后成功地生成了高分辨率图像。它在不同的分辨率上采样中生成不同频率的细节:人脸的轮廓、眼球的形状和个人睫毛。然而,模型很难纠正早期阶段产生的不准确,从而导致错误的积累。在第2个样本中,作者演示了这个问题的一个示例。我们将此问题留给未来的工作。

图10:迭代上采样结果。上:Inf-DiT 可以多次上采样自己生成的图像,并在相应分辨率下生成不同频率的细节;下:在低分辨率 128×128 时未能准确生成之后,后续很难纠正错误

参考

^abHigh-resolution image synthesis with latent diffusion models^Photorealistic textto-image diffusion models with deep language understanding^Sdxl: Improving latent diffusion models for high-resolution image synthesis^High-resolution image synthesis with latent diffusion models^Scalable Diffusion Models with Transformers^U-Net: Convolutional Networks for Biomedical Image Segmentation^RoFormer: Enhanced Transformer with Rotary Position Embedding^Learning Transferable Visual Models From Natural Language Supervision^LAION-5B: An open large-scale dataset for training next generation image-text models^Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data^Elucidating the Design Space of Diffusion-Based Generative Models^https://doi.org/10.5281/zenodo.5143773^Multidiffusion: Fusing diffusion paths for controlled image generation^Scalecrafter: Tuning-free higher-resolution visual generation with diffusion models^Designing a practical degradation model for deep blind image super-resolution^Demofusion: Democratising high-resolution image generation with no $$$^GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium

    THE END
    喜欢就支持一下吧
    点赞7 分享
    评论 抢沙发
    头像
    欢迎您留下宝贵的见解!
    提交
    头像

    昵称

    取消
    昵称表情代码图片

      暂无评论内容