MPDiT: Multi-Patch Global-to-Local Transformer Architecture For Efficient Flow Matching and Diffusion Model

📄 arXiv: 2603.26357v1 📥 PDF

作者: Quan Dao, Dimitris Metaxas

分类: cs.CV

发布日期: 2026-03-27

备注: Accepted at CVPR 2026

🔗 代码/项目: GITHUB


💡 一句话要点

提出MPDiT多尺度Transformer架构,用于高效Flow Matching和扩散模型,显著降低计算成本。

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)

关键词: 扩散模型 Flow Matching Transformer 多尺度patch 图像生成

📋 核心要点

  1. Diffusion Transformer (DiT) 虽然性能优异,但其同构设计导致训练过程计算量大。
  2. MPDiT采用多尺度patch策略,早期关注全局上下文,后期关注局部细节,实现计算效率提升。
  3. 改进的时间和类别嵌入加速了训练收敛,ImageNet实验验证了MPDiT的有效性。

📝 摘要(中文)

本文提出了一种多尺度Transformer架构MPDiT,用于扩散模型和Flow Matching模型。与传统的Diffusion Transformer (DiT) 相比,MPDiT采用多尺度patch策略,在早期模块使用较大的patch以捕获全局上下文,而在后期模块使用较小的patch以细化局部细节。这种分层设计能够在保证生成性能的同时,将计算成本降低高达50%。此外,本文还提出了改进的时间和类别嵌入设计,以加速训练收敛。在ImageNet数据集上的大量实验证明了该架构选择的有效性。代码已开源。

🔬 方法详解

问题定义:现有的Diffusion Transformer (DiT) 架构在扩散模型和Flow Matching模型中表现出色,但其同构设计,即每个block处理相同数量的patchified tokens,导致训练过程中计算量巨大,效率较低。尤其是在处理高分辨率图像时,计算负担更加明显。

核心思路:MPDiT的核心思路是引入多尺度patch处理机制,模仿人类视觉系统对图像的感知方式,即先关注全局信息,再关注局部细节。通过在Transformer的不同层级采用不同大小的patch,实现计算效率和性能的平衡。

技术框架:MPDiT的整体架构是一个分层的Transformer结构。早期层使用较大的patch size,以减少token数量,从而降低计算复杂度,并捕获图像的全局上下文信息。随着网络层数的加深,patch size逐渐减小,token数量增加,从而能够更精细地处理局部细节。此外,还包括改进的时间和类别嵌入模块,用于更好地引导扩散过程。

关键创新:MPDiT的关键创新在于多尺度patch处理策略。与传统的DiT架构相比,MPDiT不再采用统一的patch size,而是根据网络层级动态调整patch size,从而实现了计算效率和生成性能的优化。这种分层处理方式更符合图像生成的内在逻辑。

关键设计:MPDiT的关键设计包括:1) 多尺度patch size的选择策略,即如何确定每一层使用的patch size;2) 改进的时间和类别嵌入方式,具体实现细节未知;3) Transformer block的具体结构,可能包括注意力机制、前馈网络等,具体实现细节未知。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,MPDiT在ImageNet数据集上取得了良好的生成性能,同时将计算成本降低了高达50%。这意味着在相同的计算资源下,MPDiT能够训练更大规模的模型,或者在更短的时间内完成训练。具体的FID分数或其他性能指标未知,但计算效率的提升是显著的。

🎯 应用场景

MPDiT架构可广泛应用于图像生成、图像编辑、视频生成等领域。其高效的计算特性使其更适用于资源受限的场景,例如移动设备或边缘计算平台。未来,该架构有望推动扩散模型在更多实际应用中的落地,例如艺术创作、内容生成、数据增强等。

📄 摘要(原文)

Transformer architectures, particularly Diffusion Transformers (DiTs), have become widely used in diffusion and flow-matching models due to their strong performance compared to convolutional UNets. However, the isotropic design of DiTs processes the same number of patchified tokens in every block, leading to relatively heavy computation during training process. In this work, we introduce a multi-patch transformer design in which early blocks operate on larger patches to capture coarse global context, while later blocks use smaller patches to refine local details. This hierarchical design could reduces computational cost by up to 50\% in GFLOPs while achieving good generative performance. In addition, we also propose improved designs for time and class embeddings that accelerate training convergence. Extensive experiments on the ImageNet dataset demonstrate the effectiveness of our architectural choices. Code is released at \url{https://github.com/quandao10/MPDiT}