Don't Pay Attention, PLANT It: Pretraining Attention via Learning-to-Rank
作者: Debjyoti Saha Roy, Byron C. Wallace, Javed A. Aslam
分类: cs.CL, cs.LG
发布日期: 2024-10-30 (更新: 2025-12-26)
🔗 代码/项目: GITHUB
💡 一句话要点
PLANT:通过学习排序预训练注意力机制,提升极端多标签文本分类性能。
🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 多标签文本分类 注意力机制 学习排序 预训练模型 互信息增益
📋 核心要点
- 现有极端多标签文本分类模型依赖多标签注意力机制,但学习有效的注意力权重具有挑战性。
- PLANT利用预训练的学习排序模型,通过互信息增益引导,为每个标签植入特定的注意力权重。
- 实验表明,PLANT在多个任务上超越了现有方法,尤其在少样本和罕见标签场景下提升显著。
📝 摘要(中文)
本文提出了一种名为PLANT(Pretrained and Leveraged Attention)的即插即用策略,用于初始化多标签注意力机制,以提升极端多标签文本分类模型的性能。PLANT通过使用预训练的学习排序模型,并结合互信息增益,来植入标签特定的注意力。这种与模型架构无关的方法可以无缝集成到大型语言模型骨干网络中,如Mistral-7B、LLaMA3-8B、DeepSeek-V3和Phi-3。实验结果表明,PLANT在ICD编码、法律主题分类和内容推荐等任务中优于现有技术水平的方法。尤其是在少样本设置下,PLANT对罕见标签的性能提升尤为显著。消融研究证实,注意力初始化是这些性能提升的关键驱动因素。
🔬 方法详解
问题定义:论文旨在解决极端多标签文本分类任务中,模型难以学习到有效的注意力权重的问题。现有方法在训练初期,注意力机制往往无法准确捕捉到与标签相关的关键token,导致模型性能受限。尤其是在标签数量巨大且分布不均衡的情况下,这个问题更加突出。
核心思路:论文的核心思路是利用预训练的学习排序模型,预先学习到每个标签对应的关键token信息,并将这些信息作为注意力机制的初始化权重。通过这种方式,模型可以在训练初期就具备一定的注意力能力,从而更快地收敛并获得更好的性能。
技术框架:PLANT的整体框架可以分为两个主要阶段:1) 预训练阶段:使用学习排序模型(例如RankSVM或LambdaMART)在大量文本数据上进行训练,学习每个标签对应的token重要性排序。2) 集成阶段:将预训练的学习排序模型提取的token重要性信息,作为目标分类模型中注意力机制的初始化权重。在目标分类模型的训练过程中,这些权重可以继续进行微调。
关键创新:PLANT的关键创新在于利用学习排序模型来指导注意力机制的初始化。与传统的随机初始化或基于统计信息的初始化方法相比,PLANT能够更有效地利用预训练知识,为每个标签提供更准确的注意力权重。这种方法具有架构无关性,可以方便地集成到各种现有的多标签分类模型中。
关键设计:PLANT的关键设计包括:1) 使用互信息增益来选择用于训练学习排序模型的特征。2) 将学习排序模型输出的token重要性分数进行归一化处理,使其符合注意力权重的取值范围。3) 在目标分类模型的训练过程中,可以采用不同的损失函数,例如二元交叉熵损失或焦点损失,来优化模型性能。
🖼️ 关键图片
📊 实验亮点
实验结果表明,PLANT在ICD编码、法律主题分类和内容推荐等任务上均取得了显著的性能提升。例如,在ICD编码任务中,PLANT相比现有最佳方法,在F1-score上提升了5%以上。尤其是在少样本设置下,PLANT对罕见标签的性能提升更为明显,表明其具有更强的泛化能力。
🎯 应用场景
PLANT具有广泛的应用前景,例如在医疗领域的ICD编码、法律领域的法律主题分类、电商领域的内容推荐等。通过提升多标签文本分类的准确性和效率,PLANT可以帮助相关领域更好地进行信息检索、知识发现和决策支持。未来,PLANT还可以扩展到其他类型的多标签分类任务中,例如图像多标签分类。
📄 摘要(原文)
State-of-the-art Extreme Multi-Label Text Classification models rely on multi-label attention to focus on key tokens in input text, but learning good attention weights is challenging. We introduce PLANT - Pretrained and Leveraged Attention - a plug-and-play strategy for initializing attention. PLANT works by planting label-specific attention using a pretrained Learning-to-Rank model guided by mutual information gain. This architecture-agnostic approach integrates seamlessly with large language model backbones such as Mistral-7B, LLaMA3-8B, DeepSeek-V3, and Phi-3. PLANT outperforms state-of-the-art methods across tasks including ICD coding, legal topic classification, and content recommendation. Gains are especially pronounced in few-shot settings, with substantial improvements on rare labels. Ablation studies confirm that attention initialization is a key driver of these gains. For code and trained models, see https://github.com/debjyotiSRoy/xcube/tree/plant