CAST: Continuous and Differentiable Semi-Structured Sparsity-Aware Training for Large Language Models

📄 arXiv: 2509.25996v1 📥 PDF

作者: Weiyu Huang, Yuezhou Hu, Jun Zhu, Jianfei Chen

分类: cs.LG, cs.CL

发布日期: 2025-09-30

备注: Submitted to IEEE TPAMI


💡 一句话要点

CAST:面向大语言模型的连续可微半结构化稀疏训练框架

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 大语言模型 稀疏训练 模型压缩 知识蒸馏 半结构化稀疏 自适应优化 硬件加速

📋 核心要点

  1. 现有稀疏训练方法通常分离稀疏模式和权重的优化,导致训练效率低下,难以充分利用硬件加速。
  2. CAST框架通过连续可微的方式,实现稀疏模式和权重的联合优化,逐步将模型转化为目标稀疏格式。
  3. 实验表明,CAST在多种模型上显著提升了困惑度和零样本准确率,且仅需少量训练资源,具有实际应用价值。

📝 摘要(中文)

本文提出了一种名为连续自适应稀疏训练器(CAST)的框架,用于对半结构化(或“N:M”)稀疏模型进行完全连续且可微的稀疏感知训练。与以往分别优化稀疏模式和权重的方法不同,CAST 实现了训练期间的无缝联合优化,同时逐步将模型转换为所需的稀疏格式。CAST 引入了三个关键组件:1) AdamS,一种稀疏感知优化器,利用自适应 L1 衰减来促进所有参数的均匀稀疏化;2) 权重缩放,旨在减轻衰减引起的幅度减小,同时保留所需的稀疏模式;3) 知识蒸馏,使用密集模型作为自教师来提高训练效率。在 125M 到 13B 参数的多个模型系列中,我们在 2:4 稀疏模式下评估了 CAST。结果表明,与之前的最先进方法相比,在困惑度和零样本准确率方面都有显著提高,且仅需极少的训练资源。值得注意的是,在 LLaMA2-7B 上,我们的 2:4 稀疏模型仅使用原始预训练 tokens 的 2%,就实现了可忽略不计的 0.09 困惑度增加和 0.36% 的零样本准确率提升。此外,我们建立了一个准确而稳健的经验缩放定律,以预测在充足的训练资源下稀疏模型的性能。最后,我们通过在量化和微调场景下评估我们的稀疏模型,证明了它们的实际适用性。

🔬 方法详解

问题定义:现有的大语言模型稀疏化训练方法,通常将稀疏模式的确定和权重的训练分开进行,这导致优化过程次优,无法充分利用硬件对特定稀疏模式(如N:M稀疏)的加速能力。此外,训练过程可能不稳定,需要大量的调参工作。

核心思路:CAST的核心思路是设计一个完全连续且可微的稀疏训练框架,使得稀疏模式和权重可以联合优化。通过引入自适应L1衰减、权重缩放和知识蒸馏等技术,CAST能够平滑地将模型转化为目标稀疏结构,同时保持模型的性能。

技术框架:CAST框架包含三个主要模块:1) AdamS:一种稀疏感知优化器,通过自适应L1衰减来促进参数的均匀稀疏化。2) 权重缩放:用于补偿L1衰减导致的权重幅度减小,保持稀疏模式的稳定性。3) 知识蒸馏:利用密集模型作为自教师,指导稀疏模型的训练,提高训练效率和模型性能。整个训练过程是端到端的,可以联合优化所有参数。

关键创新:CAST的关键创新在于其完全连续可微的特性,这使得稀疏模式和权重的优化可以同时进行,避免了传统方法中分离优化带来的问题。此外,AdamS优化器和权重缩放模块的设计,有效地解决了稀疏训练中的一些常见问题,如参数分布不均和权重幅度减小。

关键设计:AdamS优化器使用自适应的L1衰减系数,该系数根据参数的幅度动态调整,以实现更均匀的稀疏化。权重缩放模块通过对权重进行缩放,补偿L1衰减带来的幅度损失,保持稀疏模式的稳定性。知识蒸馏使用密集模型的输出作为软标签,指导稀疏模型的训练,损失函数包括交叉熵损失和蒸馏损失。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

在LLaMA2-7B模型上,使用CAST训练的2:4稀疏模型仅使用2%的原始预训练tokens,就达到了与密集模型几乎相同的困惑度(仅增加0.09)和更高的零样本准确率(提升0.36%)。此外,该论文还建立了准确的经验缩放定律,可以预测在不同训练资源下稀疏模型的性能。

🎯 应用场景

CAST框架可应用于大语言模型的压缩和加速,尤其是在资源受限的场景下,如移动设备和边缘计算。通过将大模型转化为稀疏模型,可以显著降低模型的存储空间和计算复杂度,从而实现更高效的推理。此外,该方法还可以与其他模型压缩技术(如量化)结合使用,进一步提高模型的效率。

📄 摘要(原文)

Sparsity-aware training is an effective approach for transforming large language models (LLMs) into hardware-friendly sparse patterns, thereby reducing latency and memory consumption during inference. In this paper, we propose Continuous Adaptive Sparse Trainer (CAST), a fully continuous and differentiable sparsity-aware training framework for semi-structured (or "N:M") sparse models. Unlike previous approaches that optimize sparsity patterns and weights separately, CAST enables seamless joint optimization during training, while progressively transforming the model into the desired sparsity format. Specifically, CAST introduces three key components: 1) AdamS, a sparsity-aware optimizer that leverages adaptive L1 decay to promote uniform sparsification across all parameters; 2) Weight Scaling, a module designed to mitigate the magnitude reduction caused by decay while preserving desired sparsity patterns; 3) Knowledge Distillation, which employs the dense model as a self-teacher to enhance training efficiency. We evaluate CAST under 2:4 sparsity patterns across multiple model families, ranging from 125M to 13B parameters. Our results demonstrate significant improvements over previous state-of-the-art methods in both perplexity and zero-shot accuracy with minimal training resources. Notably, on LLaMA2-7B, our 2:4 sparse model achieves a negligible perplexity increase of 0.09 and a 0.36% gain in zero-shot accuracy compared to the dense model using only 2% of the original pretraining tokens. Additionally, we establish an accurate and robust empirical scaling law to predict sparse model performance given adequate training resources. Finally, we demonstrate the practical applicability of our sparse models by evaluating them under quantization and fine-tuning scenarios.