Reinforcement Learning-based Token Pruning in Vision Transformers: A Markov Game Approach

📄 arXiv: 2503.23459v1 📥 PDF

作者: Chenglong Lu, Shen Liang, Xuewei Wang, Wei Wang

分类: cs.CV

发布日期: 2025-03-30

备注: Accepted by IEEE International Conference on Multimedia & Expo (ICME) 2025

🔗 代码/项目: GITHUB


💡 一句话要点

提出基于强化学习的ViT Token剪枝方法,提升推理速度

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)

关键词: Vision Transformer Token剪枝 强化学习 多智能体 马尔可夫博弈

📋 核心要点

  1. ViT计算量随token数量平方增长,现有手工剪枝策略缺乏自适应性,且忽略了层间的序列关系。
  2. 将token剪枝建模为马尔可夫博弈,利用多智能体强化学习,为每个token学习个性化剪枝策略。
  3. 在ImageNet-1k上,该方法在精度损失仅0.4%的情况下,推理速度提升高达44%。

📝 摘要(中文)

Vision Transformer (ViT) 的计算成本与 token 数量呈平方关系,因此需要有效的 token 剪枝策略。现有策略大多是手工设计的,缺乏对不同输入的适应性,并且未能考虑跨多层 token 剪枝的序列特性。本文首次(据我们所知)利用强化学习 (RL) 来数据自适应地学习剪枝策略。我们将 token 剪枝建模为一个序列决策问题,并将其建模为马尔可夫博弈,利用多智能体近端策略优化 (MAPPO),其中每个智能体为单个 token 做出个性化的剪枝决策。我们还开发了奖励函数,使这些智能体能够同时协作和竞争,以平衡效率和准确性。在著名的 ImageNet-1k 数据集上,我们的方法将推理速度提高了高达 44%,而精度仅下降了 0.4%。源代码可在 https://github.com/daashuai/rl4evit 获得。

🔬 方法详解

问题定义:论文旨在解决Vision Transformer (ViT) 中由于token数量过多导致的计算成本过高的问题。现有的token剪枝策略通常是手工设计的,缺乏对不同输入数据的自适应性,并且没有充分利用token剪枝在不同网络层之间的序列依赖关系。这些痛点限制了ViT在资源受限设备上的部署和应用。

核心思路:论文的核心思路是将token剪枝过程建模为一个序列决策问题,并利用强化学习来学习一个数据自适应的剪枝策略。具体来说,将token剪枝过程视为一个马尔可夫博弈,其中每个token对应一个智能体,智能体通过与环境交互学习最优的剪枝策略。这种方法能够考虑到不同token的重要性以及层间的依赖关系,从而实现更高效的剪枝。

技术框架:整体框架包括以下几个主要模块:1) ViT模型作为环境;2) 多个智能体,每个智能体负责一个token的剪枝决策;3) 多智能体近端策略优化 (MAPPO) 算法,用于训练智能体的策略;4) 奖励函数,用于指导智能体学习。流程如下:首先,ViT模型接收输入图像,然后每个token对应的智能体根据当前状态(例如,token的特征)做出剪枝决策。根据决策结果,ViT模型进行前向传播,并计算奖励。最后,利用MAPPO算法更新智能体的策略。

关键创新:最重要的技术创新点在于将token剪枝问题建模为一个马尔可夫博弈,并利用多智能体强化学习来解决。与现有方法相比,该方法能够数据自适应地学习剪枝策略,并考虑到token之间的相互影响以及层间的序列依赖关系。此外,论文还设计了奖励函数,鼓励智能体在效率和准确性之间取得平衡。

关键设计:论文的关键设计包括:1) 使用MAPPO算法来训练智能体,该算法能够有效地处理多智能体环境中的协作和竞争关系;2) 设计了奖励函数,包括准确率奖励、剪枝率惩罚等,用于指导智能体学习;3) 针对不同的ViT架构,调整了MAPPO算法的超参数,以获得最佳性能。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,在ImageNet-1k数据集上,该方法在精度损失仅为0.4%的情况下,将ViT模型的推理速度提高了高达44%。与现有的手工设计的剪枝策略相比,该方法能够实现更好的效率和准确性平衡。这些结果表明,基于强化学习的token剪枝方法具有很大的潜力。

🎯 应用场景

该研究成果可应用于各种需要高效ViT模型部署的场景,例如移动设备上的图像识别、视频分析、自动驾驶等。通过降低计算成本,该方法可以使ViT模型在资源受限的环境中运行,从而扩展了ViT的应用范围。未来,该方法还可以与其他模型压缩技术相结合,进一步提高模型的效率。

📄 摘要(原文)

Vision Transformers (ViTs) have computational costs scaling quadratically with the number of tokens, calling for effective token pruning policies. Most existing policies are handcrafted, lacking adaptivity to varying inputs. Moreover, they fail to consider the sequential nature of token pruning across multiple layers. In this work, for the first time (as far as we know), we exploit Reinforcement Learning (RL) to data-adaptively learn a pruning policy. Formulating token pruning as a sequential decision-making problem, we model it as a Markov Game and utilize Multi-Agent Proximal Policy Optimization (MAPPO) where each agent makes an individualized pruning decision for a single token. We also develop reward functions that enable simultaneous collaboration and competition of these agents to balance efficiency and accuracy. On the well-known ImageNet-1k dataset, our method improves the inference speed by up to 44% while incurring only a negligible accuracy drop of 0.4%. The source code is available at https://github.com/daashuai/rl4evit.