Avoiding spurious sharpness minimization broadens applicability of SAM
作者: Sidak Pal Singh, Hossein Mobahi, Atish Agarwala, Yann Dauphin
分类: cs.LG, cs.CL, stat.ML
发布日期: 2025-02-04
💡 一句话要点
Functional-SAM通过避免虚假锐度最小化,扩展SAM在NLP领域的适用性
🎯 匹配领域: 支柱一:机器人控制 (Robot Control) 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 锐度感知最小化 曲率正则化 自然语言处理 大型语言模型 泛化能力 函数统计量 虚假最小化 优化算法
📋 核心要点
- SAM在视觉任务中表现出色,但在NLP领域性能下降,原因在于其在NLP中主要正则化logit统计量而非函数几何结构。
- Functional-SAM通过修改函数统计量来正则化曲率,避免了logit操作带来的虚假最小化,从而提升了性能。
- 预处理SAM扰动可以进一步防止虚假最小化,与Functional-SAM结合使用效果更佳,并在多种模型规模上验证了有效性。
📝 摘要(中文)
锐度感知最小化(SAM)等曲率正则化技术在提升视觉任务的泛化能力方面表现出巨大潜力。然而,我们发现SAM在自然语言处理(NLP)等领域表现不佳,甚至会降低性能——即使计算预算翻倍。我们研究了跨领域的差异,发现SAM在NLP环境中主要受到logit统计量的正则化影响,而不是改善函数本身的几何结构。基于此,我们开发了一种名为Functional-SAM的替代算法,该算法仅通过修改神经网络实现的整体函数的统计量来正则化曲率,并避免通过logit操作进行虚假最小化。此外,我们认为预处理SAM扰动也能防止虚假最小化,并且与Functional-SAM结合使用时,可以进一步改进性能。在固定长度和Chinchilla风格的训练设置中,我们的算法在各种模型规模(包括数十亿参数规模)上,在相同训练步数下,均优于AdamW和SAM基线。总而言之,我们的工作强调了更精确地表征锐度对于将曲率正则化扩展到大型语言模型(LLM)的重要性。
🔬 方法详解
问题定义:论文旨在解决SAM算法在NLP领域泛化能力不足的问题。现有SAM算法在NLP任务中,容易陷入对logit统计量的虚假锐度最小化,而无法真正改善模型的泛化性能。这种现象导致SAM在NLP任务中表现不如AdamW等传统优化器,即使增加计算资源也难以提升效果。
核心思路:论文的核心思路是避免SAM算法对logit统计量的过度关注,转而关注函数本身的几何结构。通过正则化整体函数的统计量,而非直接操作logit,可以更有效地改善模型的泛化能力。此外,通过预处理SAM扰动,可以进一步防止虚假锐度最小化的发生。
技术框架:论文提出了Functional-SAM算法,其核心在于修改SAM的扰动方式。传统的SAM直接在参数空间添加扰动,而Functional-SAM则通过修改函数输出的统计量来间接影响参数。此外,论文还提出了预处理扰动的策略,以确保扰动方向更有利于泛化性能的提升。整体训练流程与SAM类似,但在计算扰动时采用了不同的方法。
关键创新:论文的关键创新在于提出了Functional-SAM算法,该算法通过正则化函数统计量而非logit统计量来避免虚假锐度最小化。此外,预处理扰动的策略也是一个重要的创新点,它可以进一步提升SAM算法的性能。与传统SAM相比,Functional-SAM更关注函数本身的几何结构,从而在NLP任务中表现更好。
关键设计:Functional-SAM的关键设计在于如何修改函数输出的统计量。具体实现细节未知,但核心思想是通过某种方式约束或正则化函数输出的均值、方差等统计特性,从而影响参数的更新方向。预处理扰动的具体方法也未知,但其目标是确保扰动方向与泛化性能的提升方向一致。损失函数和网络结构与原始SAM保持一致,主要区别在于扰动的计算方式。
🖼️ 关键图片
📊 实验亮点
实验结果表明,Functional-SAM在各种模型规模(包括数十亿参数规模)上,在固定长度和Chinchilla风格的训练设置中,在相同训练步数下,均优于AdamW和SAM基线。这表明Functional-SAM能够更有效地利用计算资源,提升模型的泛化能力。具体的性能提升幅度未知,但论文强调了其在多个NLP任务上的优越性。
🎯 应用场景
该研究成果可广泛应用于自然语言处理领域,尤其是在训练大型语言模型时。Functional-SAM能够提升模型的泛化能力,从而提高模型在各种NLP任务中的性能,例如文本分类、机器翻译、文本生成等。该方法还有助于降低模型对训练数据的过拟合程度,提高模型在实际应用中的鲁棒性。
📄 摘要(原文)
Curvature regularization techniques like Sharpness Aware Minimization (SAM) have shown great promise in improving generalization on vision tasks. However, we find that SAM performs poorly in domains like natural language processing (NLP), often degrading performance -- even with twice the compute budget. We investigate the discrepancy across domains and find that in the NLP setting, SAM is dominated by regularization of the logit statistics -- instead of improving the geometry of the function itself. We use this observation to develop an alternative algorithm we call Functional-SAM, which regularizes curvature only through modification of the statistics of the overall function implemented by the neural network, and avoids spurious minimization through logit manipulation. Furthermore, we argue that preconditioning the SAM perturbation also prevents spurious minimization, and when combined with Functional-SAM, it gives further improvements. Our proposed algorithms show improved performance over AdamW and SAM baselines when trained for an equal number of steps, in both fixed-length and Chinchilla-style training settings, at various model scales (including billion-parameter scale). On the whole, our work highlights the importance of more precise characterizations of sharpness in broadening the applicability of curvature regularization to large language models (LLMs).