A Proximal Operator for Inducing 2:4-Sparsity
作者: Jonas M Kübler, Yu-Xiang Wang, Shoham Sabach, Navid Ansari, Matthäus Kleindessner, Kailash Budhathoki, Volkan Cevher, George Karypis
分类: cs.LG
发布日期: 2025-01-29
期刊: Transactions on Machine Learning Research, 2835-8856, 2025 (https://openreview.net/forum?id=AsFbXRIe4q)
💡 一句话要点
提出一种诱导2:4稀疏性的近端算子,提升大语言模型剪枝性能。
🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 大语言模型 模型剪枝 稀疏性 近端算子 正则化 硬件加速 2:4稀疏
📋 核心要点
- 现有AI加速器对2:4稀疏矩阵乘法高效,但模型精度常降低,是当前面临的核心问题。
- 利用特征局部相关性设计正则化项,通过优化稀疏掩码提升模型剪枝后的精度。
- 实验表明,该方法在大型语言模型剪枝上表现出色,部分模型优于现有算法。
📝 摘要(中文)
本文提出了一种新的正则化方法,旨在利用特征的局部相关性,为训练模型找到更好的稀疏掩码,从而提高2:4稀疏矩阵乘法的效率。该方法通过最小化正则化项和一个局部平方损失函数,推导出相应的近端算子,并证明了其在2:4稀疏情况下的高效解。在优化掩码后,使用掩码梯度更新进一步最小化局部平方损失。实验结果表明,该方法在玩具问题上有效,并成功应用于剪枝高达70B参数的大型语言模型。在13B参数的模型上,该方法优于现有最佳算法,而在70B模型上,性能与现有最佳算法相当。
🔬 方法详解
问题定义:论文旨在解决大语言模型剪枝过程中,如何在保证模型精度的前提下,有效利用硬件加速器对2:4稀疏矩阵乘法的支持。现有方法在实现2:4稀疏时,通常会导致模型精度显著下降,无法充分发挥硬件优势。
核心思路:论文的核心思路是设计一个正则化项,该正则化项能够利用特征的局部相关性,引导模型学习到更优的稀疏掩码。通过优化该正则化项,可以在保证模型精度的同时,实现高效的2:4稀疏。
技术框架:该方法主要包含以下几个阶段:1) 设计一个基于特征局部相关性的正则化项。2) 推导出该正则化项与局部平方损失函数的近端算子,并证明其在2:4稀疏情况下的高效解。3) 使用近端算子优化稀疏掩码。4) 使用掩码梯度更新进一步最小化局部平方损失。
关键创新:该方法最重要的技术创新点在于提出了一个能够有效利用特征局部相关性的正则化项,并推导出了相应的近端算子。与现有方法相比,该方法能够找到更好的稀疏掩码,从而在保证模型精度的前提下,实现更高的稀疏率。
关键设计:论文的关键设计包括:正则化项的具体形式(未知,需要查阅论文原文),局部平方损失函数的选择,以及近端算子的具体推导过程(未知,需要查阅论文原文)。此外,掩码梯度更新的具体实现方式也是一个关键的设计细节(未知,需要查阅论文原文)。
🖼️ 关键图片
📊 实验亮点
实验结果表明,该方法在高达70B参数的大型语言模型上成功实现了剪枝。在13B参数的模型上,该方法优于现有最佳算法,而在70B模型上,性能与现有最佳算法相当。这些结果表明,该方法在大型语言模型剪枝方面具有显著的优势。
🎯 应用场景
该研究成果可广泛应用于大语言模型的压缩与加速,尤其是在资源受限的边缘设备上部署大型模型。通过提高模型剪枝的效率和精度,可以降低模型存储空间和计算复杂度,从而实现更快的推理速度和更低的功耗。这对于推动人工智能在移动设备、物联网设备等领域的应用具有重要意义。
📄 摘要(原文)
Recent hardware advancements in AI Accelerators and GPUs allow to efficiently compute sparse matrix multiplications, especially when 2 out of 4 consecutive weights are set to zero. However, this so-called 2:4 sparsity usually comes at a decreased accuracy of the model. We derive a regularizer that exploits the local correlation of features to find better sparsity masks in trained models. We minimize the regularizer jointly with a local squared loss by deriving the proximal operator for which we show that it has an efficient solution in the 2:4-sparse case. After optimizing the mask, we use maskedgradient updates to further minimize the local squared loss. We illustrate our method on toy problems and apply it to pruning entire large language models up to 70B parameters. On models up to 13B we improve over previous state of the art algorithms, whilst on 70B models we match their performance.