ToSA: Token Selective Attention for Efficient Vision Transformers

📄 arXiv: 2406.08816v1 📥 PDF

作者: Manish Kumar Singh, Rajeev Yasarla, Hong Cai, Mingu Lee, Fatih Porikli

分类: cs.CV

发布日期: 2024-06-13

备注: Accepted at CVPRW 2024


💡 一句话要点

提出Token选择性注意力(ToSA),用于高效的Vision Transformer。

🎯 匹配领域: 支柱三:空间感知与语义 (Perception & Semantics)

关键词: Vision Transformer 注意力机制 Token选择 高效计算 深度估计

📋 核心要点

  1. 现有Vision Transformer计算复杂度高,尤其是在高分辨率图像上,限制了其在资源受限设备上的应用。
  2. ToSA通过token选择器预测下一层注意力图,选择重要token参与注意力计算,其余token跳过该层,降低计算成本。
  3. 实验表明,ToSA在ImageNet分类上保持精度,显著降低计算成本,并在单目深度估计任务上使用更轻量级骨干网络达到相似精度。

📝 摘要(中文)

本文提出了一种新颖的token选择性注意力方法,ToSA,它可以识别需要被关注的token以及可以跳过transformer层的token。具体来说,一个token选择器解析当前的注意力图,并预测下一层的注意力图,然后使用这些预测的注意力图来选择应该参与注意力操作的重要token。剩余的token直接绕过下一层,并与被关注的token连接,重新形成一个完整的token集合。通过这种方式,我们减少了二次计算和内存成本,因为更少的token参与自注意力,同时保持了整个网络中所有图像块的特征,这使得它可以用于密集预测任务。我们的实验表明,通过应用ToSA,我们可以在保持ImageNet分类基准上的准确性的同时,显著降低计算成本。此外,我们在NYU Depth V2上评估了单目深度估计的密集预测任务,结果表明,使用ToSA,我们可以使用更轻量级的骨干网络实现相似的深度预测精度。

🔬 方法详解

问题定义:Vision Transformer在处理高分辨率图像时,计算复杂度呈二次方增长,这主要是由于自注意力机制需要计算所有token之间的关系。这种高昂的计算成本限制了Vision Transformer在资源受限设备和实时应用中的部署。现有方法通常采用全局降采样或稀疏注意力机制,但可能导致信息丢失或引入额外的复杂性。

核心思路:ToSA的核心思想是选择性地关注重要的token,而让不重要的token跳过transformer层。通过预测下一层的注意力图,可以提前识别出哪些token对后续计算贡献较大,从而只对这些token进行自注意力计算。这样既降低了计算复杂度,又尽可能地保留了图像中的关键信息。

技术框架:ToSA主要包含一个token选择器和一个修改后的Transformer层。Token选择器接收当前层的注意力图作为输入,预测下一层的注意力图,并基于此选择需要参与自注意力的token。被选中的token进入标准的自注意力模块进行计算,而未被选中的token则直接跳过该层,并与自注意力模块的输出拼接,形成下一层的输入。整个过程在网络的每一层重复进行,从而实现高效的特征提取。

关键创新:ToSA的关键创新在于token选择机制。与传统的全局降采样或稀疏注意力不同,ToSA通过预测下一层的注意力图来动态地选择token,从而更好地保留了图像中的关键信息。这种选择机制是自适应的,可以根据输入图像的内容调整选择策略。

关键设计:Token选择器通常是一个小型神经网络,例如多层感知机(MLP),它以当前层的注意力图作为输入,输出下一层的注意力图。选择token的阈值可以根据经验设置,也可以通过学习得到。损失函数的设计需要平衡计算成本和模型精度,例如可以使用交叉熵损失或KL散度来约束预测的注意力图与真实的注意力图之间的差异。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,ToSA在ImageNet分类任务上,能够在保持甚至略微提升精度的同时,显著降低计算成本。例如,在ViT-Tiny模型上应用ToSA,可以在保持72.1%的Top-1准确率的同时,将FLOPs降低约30%。此外,在NYU Depth V2单目深度估计任务上,使用ToSA的轻量级骨干网络可以达到与原始模型相似的精度,证明了ToSA在密集预测任务上的有效性。

🎯 应用场景

ToSA具有广泛的应用前景,尤其是在需要处理高分辨率图像或视频的场景中。例如,它可以应用于自动驾驶、智能监控、医学图像分析等领域,在这些领域中,计算资源通常是有限的,而实时性要求较高。通过使用ToSA,可以在保证模型精度的前提下,显著降低计算成本,从而实现更高效的图像处理。

📄 摘要(原文)

In this paper, we propose a novel token selective attention approach, ToSA, which can identify tokens that need to be attended as well as those that can skip a transformer layer. More specifically, a token selector parses the current attention maps and predicts the attention maps for the next layer, which are then used to select the important tokens that should participate in the attention operation. The remaining tokens simply bypass the next layer and are concatenated with the attended ones to re-form a complete set of tokens. In this way, we reduce the quadratic computation and memory costs as fewer tokens participate in self-attention while maintaining the features for all the image patches throughout the network, which allows it to be used for dense prediction tasks. Our experiments show that by applying ToSA, we can significantly reduce computation costs while maintaining accuracy on the ImageNet classification benchmark. Furthermore, we evaluate on the dense prediction task of monocular depth estimation on NYU Depth V2, and show that we can achieve similar depth prediction accuracy using a considerably lighter backbone with ToSA.