Wave-PDE Nets: Trainable Wave-Equation Layers as an Alternative to Attention

📄 arXiv: 2510.04304v1 📥 PDF

作者: Harshil Vejendla

分类: cs.LG, cs.CL

发布日期: 2025-10-05

备注: PRICAI 2025 Oral, 9 pages, 3 figures


💡 一句话要点

提出Wave-PDE Nets,以可训练波动方程层替代注意力机制,提升计算效率。

🎯 匹配领域: 支柱八:物理动画 (Physics-based Animation)

关键词: 波动方程 神经网络 注意力机制 计算效率 物理建模 辛积分 谱方法

📋 核心要点

  1. 现有Transformer模型计算复杂度高,内存占用大,限制了其在资源受限场景的应用。
  2. Wave-PDE Nets通过模拟波动方程进行信息传播,利用可训练的速度和阻尼参数,实现全局交互。
  3. 实验表明,Wave-PDE Nets在语言和视觉任务上达到或超过Transformer性能,同时显著降低计算时间和内存占用。

📝 摘要(中文)

本文提出Wave-PDE Nets,一种新型神经网络架构,其基本操作是对二阶波动方程进行可微分的模拟。每一层通过具有可训练空间速度c(x)和阻尼γ(x)的介质传播其隐藏状态,将其视为连续场。基于FFT的辛谱求解器以O(n log n)的时间复杂度实现这种传播。这种振荡的全局机制为注意力和一阶状态空间模型提供了一种强大的替代方案。我们证明了单个Wave-PDE层是一个通用逼近器。在语言和视觉基准测试中,Wave-PDE Nets的性能与Transformer相当或超过Transformer,同时表现出卓越的实际效率,减少了高达30%的实际运行时间和25%的峰值内存。消融研究证实了辛积分和谱拉普拉斯算子对于稳定性和性能的关键作用。对学习到的物理参数的可视化揭示了该模型学习了信息传播的直观策略。这些结果将Wave-PDE Nets定位为一种具有强大物理归纳偏置的计算高效且鲁棒的架构。

🔬 方法详解

问题定义:Transformer模型在处理长序列时,计算复杂度呈平方增长,并且需要大量的内存。这限制了它们在计算资源有限的场景中的应用。此外,Transformer的注意力机制缺乏明确的物理意义,难以解释其信息传播方式。

核心思路:Wave-PDE Nets的核心思想是将神经网络层视为波动方程的模拟。通过在可训练的介质中传播隐藏状态,模型可以学习到全局的信息交互模式。这种基于物理的建模方式提供了一种替代注意力机制的有效方法,并且具有更好的计算效率。

技术框架:Wave-PDE Nets的每一层都模拟一个二阶波动方程。输入首先被视为一个连续场,然后在具有可训练空间速度c(x)和阻尼γ(x)的介质中传播。这种传播使用基于FFT的辛谱求解器实现,其时间复杂度为O(n log n)。模型的整体架构由多个Wave-PDE层堆叠而成,每一层都学习不同的信息传播模式。

关键创新:Wave-PDE Nets的关键创新在于使用波动方程来建模神经网络层的信息传播。这种方法与传统的注意力机制和状态空间模型不同,它提供了一种全局的、振荡的信息交互方式。此外,通过训练波动方程的参数(速度和阻尼),模型可以学习到适应特定任务的信息传播策略。

关键设计:Wave-PDE Nets使用辛谱求解器来保证数值模拟的稳定性。谱拉普拉斯算子用于计算波动方程的空间导数。模型使用可训练的空间速度c(x)和阻尼γ(x)来控制信息传播的速度和衰减。损失函数根据具体任务而定,例如交叉熵损失或均方误差损失。

📊 实验亮点

Wave-PDE Nets在语言和视觉基准测试中表现出与Transformer相当或超过Transformer的性能。在某些任务上,Wave-PDE Nets可以将实际运行时间减少高达30%,并将峰值内存减少25%。消融研究表明,辛积分和谱拉普拉斯算子对于稳定性和性能至关重要。可视化结果表明,模型学习了信息传播的直观策略。

🎯 应用场景

Wave-PDE Nets具有广泛的应用前景,包括自然语言处理、计算机视觉和语音识别等领域。由于其计算效率高和内存占用小,它特别适合于资源受限的场景,例如移动设备和嵌入式系统。此外,Wave-PDE Nets的物理可解释性使其在科学计算和物理建模等领域也具有潜在的应用价值。

📄 摘要(原文)

We introduce Wave-PDE Nets, a neural architecture whose elementary operation is a differentiable simulation of the second-order wave equation. Each layer propagates its hidden state as a continuous field through a medium with trainable spatial velocity c(x) and damping γ(x). A symplectic spectral solver based on FFTs realises this propagation in O(nlog n) time. This oscillatory, global mechanism provides a powerful alternative to attention and first-order state-space models. We prove that a single Wave-PDE layer is a universal approximator. On language and vision benchmarks, Wave-PDE Nets match or exceed Transformer performance while demonstrating superior practical efficiency, reducing wall-clock time by up to 30% and peak memory by 25%. Ablation studies confirm the critical role of symplectic integration and a spectral Laplacian for stability and performance. Visualizations of the learned physical parameters reveal that the model learns intuitive strategies for information propagation. These results position Wave-PDE Nets as a computationally efficient and robust architecture with a strong physical inductive bias.