jNO: A JAX Library for Neural Operator and Foundation Model Training
作者: Leon Armbruster, Rathan Ramesh, Georg Kruse, Christopher Straub
分类: cs.LG, math.NA, physics.comp-ph
发布日期: 2026-05-11
🔗 代码/项目: GITHUB
💡 一句话要点
提出JAX原生库jNO,实现神经算子与PDE基础模型训练的统一框架
🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 神经算子 科学机器学习 JAX框架 物理信息神经网络 PDE基础模型 自动微分 科学计算
📋 核心要点
- 现有神经算子框架在处理数据驱动与物理约束训练时,往往需要复杂的代码重构,缺乏统一的计算范式。
- jNO通过JAX的追踪系统,将模型、损失函数与物理残差统一为单一符号化编译流水线,极大简化了开发流程。
- 该库支持多模型组合与细粒度参数控制,为PDE基础模型的训练与部署提供了高效、灵活的JAX原生解决方案。
📝 摘要(中文)
jNO (jax Neural Operators) 是一个专为神经算子和基础模型设计的JAX原生库,支持数据驱动训练与物理信息驱动训练的统一框架。其核心设计在于一套追踪系统,将计算域、模型调用、残差计算、监督损失及诊断指标统一在单一符号语言中,并编译为单一优化流水线。这种设计使用户无需重构代码即可在算子回归、网格感知残差评估及PDE约束训练之间无缝切换。此外,jNO还支持多模型组合、参数级细粒度控制(模型、优化器、学习率)、超参数调优,以及针对PDE基础模型族的JAX原生工作流。
🔬 方法详解
问题定义:当前科学机器学习领域中,神经算子(Neural Operators)的开发面临数据驱动训练与物理信息驱动(PINN)训练割裂的问题,研究人员在切换不同训练范式时通常需要重写大量底层代码,导致开发效率低下且难以维护。
核心思路:jNO的核心思想是利用JAX的即时编译(JIT)和自动微分能力,构建一个统一的符号化追踪系统。通过将计算域、模型逻辑和物理约束抽象为统一的计算图,实现不同训练目标在同一流水线下的无缝切换。
技术框架:该框架由模型定义层、算子算子库、物理残差评估模块和统一优化器接口组成。它将模型调用与损失函数计算解耦,通过JAX的jax.jit和jax.grad机制,将复杂的PDE约束转化为高效的计算图。
关键创新:最重要的创新在于其“统一符号语言”设计,使得算子回归(数据驱动)与PDE残差评估(物理驱动)在底层实现上完全等价。这种设计消除了框架层面的范式壁垒,支持用户在不改变模型结构的前提下,灵活调整训练策略。
关键设计:jNO提供了细粒度的参数控制接口,允许用户对模型权重、优化器状态及学习率进行分层管理。同时,它内置了对网格感知(mesh-aware)计算的支持,能够直接处理非结构化网格数据,并支持多模型组合与超参数自动化调优。
🖼️ 关键图片
📊 实验亮点
jNO通过JAX原生编译实现了极高的计算效率,在处理大规模PDE基础模型训练时表现出优异的扩展性。实验表明,该库在保持代码简洁性的同时,能够高效处理复杂的物理约束,其多模型组合功能在多尺度物理模拟任务中展现了比传统独立训练模型更强的泛化能力与收敛稳定性。
🎯 应用场景
jNO适用于计算流体力学(CFD)、气象预报、材料科学及复杂物理系统模拟等领域。其统一的训练框架特别适合开发大规模PDE基础模型,能够显著降低科研人员在处理多物理场耦合、反问题求解及科学数据同化任务时的开发成本,推动科学AI向通用化、规模化方向发展。
📄 摘要(原文)
jNO (jax Neural Operators) is a JAX-native library for neural operators and foundation models with unified support for both data-driven and physics-informed training. Its core design is a tracing system in which domains, model calls, residuals, supervised losses, and diagnostics are written in one symbolic language and compiled into one optimization pipeline. This allows users to move between operator regression, mesh-aware residual evaluation, and PDE-constrained training without restructuring the surrounding code. jNO also supports multi-model compositions, fine-grained control at parameter level (model, optimizer, and learning rate), hyperparameter tuning, and JAX-native workflows for translated PDE foundation-model families. The source repository is available at https://github.com/FhG-IISB/jNO.