GrAInS: Gradient-based Attribution for Inference-Time Steering of LLMs and VLMs

📄 arXiv: 2507.18043v1 📥 PDF

作者: Duy Nguyen, Archiki Prasad, Elias Stengel-Eskin, Mohit Bansal

分类: cs.CL, cs.AI, cs.CV

发布日期: 2025-07-24

备注: 21 pages. Code: https://github.com/duykhuongnguyen/GrAInS


💡 一句话要点

GrAInS:利用梯度归因实现LLM和VLM的推理时引导

🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 推理时引导 梯度归因 大型语言模型 视觉语言模型 可解释性 模型控制 多模态学习

📋 核心要点

  1. 现有推理时引导方法依赖全局干预向量,忽略token因果影响,且未充分利用梯度信息,尤其在多模态场景下。
  2. GrAInS通过对比梯度归因识别关键token,构建定向引导向量,在推理时调整激活,实现细粒度模型行为控制。
  3. 实验表明,GrAInS在多个基准测试中优于微调和现有引导方法,提升了准确率、降低了幻觉率并提高了对齐胜率。

📝 摘要(中文)

推理时引导方法为大型语言模型(LLM)和视觉语言模型(VLM)提供了一种轻量级的替代方案,它通过在测试时修改内部激活而无需更新模型权重。然而,现有方法大多依赖于固定的全局干预向量,忽略了单个输入token的因果影响,并且未能利用来自模型logits的信息梯度,尤其是在视觉和文本输入贡献不均的多模态环境中。为了解决这些局限性,我们提出了一种推理时引导方法GrAInS,它可以在纯语言模型和视觉语言模型以及任务中运行。GrAInS使用基于积分梯度的对比梯度归因来识别top-k个最具影响力的token,这些token基于它们对首选输出与非首选输出的贡献进行正向和负向归因。然后,这些token用于构建定向引导向量,以捕获从不良行为到理想行为的语义转变。在推理过程中,GrAInS在transformer层调整隐藏激活,并由token级别的归因信号引导,并对激活进行归一化以保持表示尺度。这实现了对模型行为的细粒度、可解释和模块化控制,而无需重新训练或辅助监督。实验表明,GrAInS始终优于微调和现有的引导基线:使用Llama-3.1-8B在TruthfulQA上实现了13.22%的准确率提升,使用LLaVA-1.6-7B将MMHal-Bench上的幻觉率从0.624降低到0.514,并在SPA-VL上将对齐胜率提高了8.11%,同时保持了模型的流畅性和通用能力。

🔬 方法详解

问题定义:现有推理时引导方法在控制LLM和VLM行为时存在局限性,主要体现在:1) 使用固定的全局干预向量,缺乏对输入token的细粒度控制;2) 忽略了不同token对模型输出的因果影响;3) 未能充分利用模型logits的梯度信息,尤其是在多模态场景下,视觉和文本输入的贡献不均衡。这些痛点导致模型行为控制不够精确,可解释性差,且难以适应复杂的多模态任务。

核心思路:GrAInS的核心思路是利用对比梯度归因来识别对模型输出影响最大的token,并基于这些token构建定向引导向量,从而在推理时对模型的内部激活进行调整。通过关注token级别的因果关系,并利用梯度信息,GrAInS能够实现更精细、更可控的模型行为引导。这种方法无需重新训练模型,即可在推理阶段灵活地调整模型行为。

技术框架:GrAInS的整体框架包括以下几个主要阶段:1) 梯度归因:使用积分梯度方法计算每个输入token对模型输出的贡献度,区分正向和负向影响;2) 关键Token选择:根据归因分数选择top-k个最具影响力的token;3) 引导向量构建:基于选定的token构建定向引导向量,该向量代表从不期望行为到期望行为的语义转变;4) 激活调整:在transformer层的隐藏状态中,根据引导向量调整激活值,并进行归一化,以保持表示尺度。

关键创新:GrAInS的关键创新在于:1) 对比梯度归因:使用对比的方式,区分token对期望输出和非期望输出的贡献,从而更准确地识别关键token;2) token级别引导:将引导操作细化到token级别,实现更精细的模型行为控制;3) 激活归一化:在调整激活值后进行归一化,避免激活值过大或过小,保持模型的稳定性和泛化能力。与现有方法相比,GrAInS能够更有效地利用梯度信息,实现更精确、更可解释的模型行为引导。

关键设计:GrAInS的关键设计包括:1) 积分梯度参数:积分梯度的路径积分步数需要根据具体任务进行调整,以获得最佳的归因效果;2) Top-k选择:top-k值的选择需要平衡计算复杂度和引导效果,通常选择对模型输出影响最大的少量token;3) 激活归一化方法:可以使用L2归一化或其他归一化方法,以保持激活值的尺度不变;4) 引导向量的权重:可以根据token的归因分数对引导向量进行加权,以增强关键token的影响。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

GrAInS在多个任务上取得了显著的性能提升。在TruthfulQA上,使用Llama-3.1-8B模型,GrAInS实现了13.22%的准确率提升。在MMHal-Bench上,使用LLaVA-1.6-7B模型,GrAInS将幻觉率从0.624降低到0.514。在SPA-VL上,GrAInS将对齐胜率提高了8.11%。这些结果表明,GrAInS能够有效地提高模型的准确性、可靠性和安全性。

🎯 应用场景

GrAInS可应用于多种场景,例如:提高LLM的事实性和可靠性,减少幻觉;控制VLM生成图像描述的风格和内容;增强模型在对话系统中的安全性,避免生成有害或不当内容;在医疗领域,引导模型生成更准确的诊断报告。该研究具有重要的实际价值,有助于提升LLM和VLM在各个领域的应用效果和安全性。

📄 摘要(原文)

Inference-time steering methods offer a lightweight alternative to fine-tuning large language models (LLMs) and vision-language models (VLMs) by modifying internal activations at test time without updating model weights. However, most existing approaches rely on fixed, global intervention vectors, overlook the causal influence of individual input tokens, and fail to leverage informative gradients from the model's logits, particularly in multimodal settings where visual and textual inputs contribute unevenly. To address these limitations, we introduce GrAInS, an inference-time steering approach that operates across both language-only and vision-language models and tasks. GrAInS uses contrastive, gradient-based attribution via Integrated Gradients to identify the top-k most influential tokens, both positively and negatively attributed based on their contribution to preferred versus dispreferred outputs. These tokens are then used to construct directional steering vectors that capture semantic shifts from undesirable to desirable behavior. During inference, GrAInS adjusts hidden activations at transformer layers guided by token-level attribution signals, and normalizes activations to preserve representational scale. This enables fine-grained, interpretable, and modular control over model behavior, without retraining or auxiliary supervision. Empirically, GrAInS consistently outperforms both fine-tuning and existing steering baselines: it achieves a 13.22% accuracy gain on TruthfulQA using Llama-3.1-8B, reduces hallucination rates on MMHal-Bench from 0.624 to 0.514 with LLaVA-1.6-7B, and improves alignment win rates on SPA-VL by 8.11%, all while preserving the model's fluency and general capabilities.