Layer-Adaptive State Pruning for Deep State Space Models
作者: Minseon Gwak, Seongrok Moon, Joohwan Ko, PooGyeon Park
分类: cs.LG, eess.SY
发布日期: 2024-11-05 (更新: 2025-01-31)
备注: NeurIPS 2024, Added missing arXiv information for one reference
🔗 代码/项目: GITHUB
💡 一句话要点
提出层自适应状态剪枝方法以优化深度状态空间模型
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)
关键词: 状态空间模型 剪枝方法 深度学习 能量优化 多输入多输出
📋 核心要点
- 现有深度状态空间模型在高状态维度下面临计算成本高的问题,缺乏有效的状态维度优化方法。
- 本文提出的层自适应状态剪枝(LAST)方法,通过评估层级能量和子系统的$ ext{H}_{ ext{∞}}$范数,实现了状态的有效剪枝。
- 实验结果表明,LAST方法在多个基准测试中优化了SSMs,剪去33%的状态后,准确率仅下降0.52%。
📝 摘要(中文)
由于缺乏状态维度优化方法,深度状态空间模型(SSMs)在高状态维度下牺牲了模型容量、训练搜索空间或稳定性,以减轻计算成本。本文提出了一种结构化的剪枝方法——层自适应状态剪枝(LAST),通过扩展单一系统的模态截断,减少每层的状态维度,从而最小化模型级输出能量损失。LAST评分通过子系统的$ ext{H}_{ ext{∞}}$范数和层级能量归一化进行评估,作为全局剪枝标准,实现跨层状态比较和层自适应剪枝。在多个序列基准测试中,LAST优化了先前的SSMs,揭示了其状态空间的冗余性和可压缩性。值得注意的是,我们展示了在多输入多输出SSMs中,平均剪去33%的状态仍能保持性能,准确率损失仅为0.52%,且无需重新训练。
🔬 方法详解
问题定义:本文旨在解决深度状态空间模型(SSMs)在高状态维度下的计算成本问题。现有方法往往牺牲模型容量或稳定性,缺乏有效的状态维度优化手段。
核心思路:提出层自适应状态剪枝(LAST)方法,通过对每层状态进行剪枝,最小化模型级输出能量损失,从而优化模型性能。该方法通过扩展模态截断技术,能够实现跨层状态的比较与剪枝。
技术框架:LAST方法的整体架构包括状态评分、剪枝决策和模型优化三个主要模块。首先,计算每层状态的评分,然后根据评分进行剪枝,最后优化剩余状态以提升模型性能。
关键创新:LAST的主要创新在于引入层级能量归一化和$ ext{H}_{ ext{∞}}$范数评估,作为全局剪枝标准。这一方法使得状态剪枝不仅限于单层,而是实现了跨层的自适应剪枝。
关键设计:在LAST中,关键参数包括层级能量的归一化方式和剪枝阈值的设定。此外,损失函数设计上注重保持模型的输出能量,确保剪枝后的模型性能不受显著影响。
🖼️ 关键图片
📊 实验亮点
实验结果显示,使用LAST方法在多输入多输出SSMs中,平均剪去33%的状态后,模型的准确率仅下降0.52%。这一结果表明,LAST在保持模型性能的同时,有效减少了计算资源的消耗,展现了其优越性。
🎯 应用场景
该研究的潜在应用领域包括时间序列预测、控制系统和信号处理等。通过优化深度状态空间模型,LAST方法能够在保持性能的同时显著降低计算成本,具有广泛的实际价值和应用前景。未来,该方法可能推动更高效的模型设计和应用于实时系统。
📄 摘要(原文)
Due to the lack of state dimension optimization methods, deep state space models (SSMs) have sacrificed model capacity, training search space, or stability to alleviate computational costs caused by high state dimensions. In this work, we provide a structured pruning method for SSMs, Layer-Adaptive STate pruning (LAST), which reduces the state dimension of each layer in minimizing model-level output energy loss by extending modal truncation for a single system. LAST scores are evaluated using the $\mathcal{H}_{\infty}$ norms of subsystems and layer-wise energy normalization. The scores serve as global pruning criteria, enabling cross-layer comparison of states and layer-adaptive pruning. Across various sequence benchmarks, LAST optimizes previous SSMs, revealing the redundancy and compressibility of their state spaces. Notably, we demonstrate that, on average, pruning 33% of states still maintains performance with 0.52% accuracy loss in multi-input multi-output SSMs without retraining. Code is available at https://github.com/msgwak/LAST.