In-Context Learning with Representations: Contextual Generalization of Trained Transformers
作者: Tong Yang, Yu Huang, Yingbin Liang, Yuejie Chi
分类: cs.LG, cs.CL, cs.IT, math.OC, stat.ML
发布日期: 2024-08-19 (更新: 2024-09-25)
备注: Accepted by NeurIPS 2024
💡 一句话要点
研究Transformer在上下文学习中泛化能力,证明其可学习模板信息以推广到未见示例和任务。
🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 上下文学习 Transformer 泛化能力 非线性回归 模板学习
📋 核心要点
- 大型语言模型的上下文学习能力缺乏理论支撑,特别是Transformer能否泛化到prompt中未见过的示例。
- 论文通过非线性回归任务,分析Transformer的训练动态,证明其可学习任务模板函数以实现上下文泛化。
- 研究表明,单层多头Transformer的训练损失线性收敛到全局最小值,并有效地学习执行岭回归。
📝 摘要(中文)
上下文学习(ICL)是指预训练大型语言模型的一种显著能力,即在推理过程中给定少量示例即可学习新任务。然而,对ICL的理论理解在很大程度上尚未被探索,特别是Transformer是否可以通过训练来推广到prompt中未见过的示例,这将要求模型获得prompt的上下文知识以进行泛化。本文通过非线性回归任务的角度研究了Transformer通过梯度下降的训练动态。这里的上下文泛化可以通过学习每个任务的模板函数来实现,其中所有模板函数都位于具有$m$个基函数的线性空间中。我们分析了单层多头Transformer的训练动态,以在上下文环境中预测给定部分标记prompt的未标记输入,其中标签包含高斯噪声,并且每个prompt中的示例数量不足以确定模板。在温和的假设下,我们表明单层多头Transformer的训练损失线性收敛到全局最小值。此外,Transformer有效地学习对基函数执行岭回归。据我们所知,这项研究是第一个可证明的演示,即当prompt仅包含少量查询-答案对时,Transformer可以学习上下文(即模板)信息以推广到未见过的示例和任务。
🔬 方法详解
问题定义:论文旨在解决Transformer在上下文学习中,如何通过学习prompt中的少量示例,泛化到未见过的示例和任务的问题。现有方法缺乏对Transformer上下文学习能力的理论理解,特别是模型如何获取prompt的上下文知识以进行泛化。
核心思路:论文的核心思路是将上下文学习问题建模为非线性回归任务,并假设每个任务都存在一个模板函数,Transformer通过学习这些模板函数来实现上下文泛化。具体来说,所有模板函数都位于一个线性空间中,Transformer需要学习这个线性空间的基函数。
技术框架:论文分析了单层多头Transformer的训练动态。给定部分标记的prompt,Transformer需要预测未标记的输入。标签中包含高斯噪声,并且每个prompt中的示例数量不足以确定模板。论文研究了Transformer如何通过梯度下降学习这些模板函数。
关键创新:论文最重要的技术创新点在于,从理论上证明了Transformer可以通过学习prompt中的少量示例,学习到任务的模板信息,从而泛化到未见过的示例和任务。这是第一个可证明的关于Transformer上下文学习能力的理论结果。
关键设计:论文的关键设计包括:1) 将上下文学习问题建模为非线性回归任务;2) 假设存在一个线性空间的模板函数;3) 分析单层多头Transformer的训练动态;4) 证明Transformer可以学习执行岭回归。论文还对标签中的高斯噪声和每个prompt中的示例数量进行了假设,以保证理论结果的成立。
🖼️ 关键图片
📊 实验亮点
论文证明了单层多头Transformer的训练损失线性收敛到全局最小值,并且Transformer有效地学习对基函数执行岭回归。这是首次从理论上证明Transformer可以学习上下文信息,并推广到未见过的示例和任务,为理解Transformer的上下文学习能力提供了重要依据。
🎯 应用场景
该研究成果可应用于提升大型语言模型在小样本学习场景下的性能,例如在资源受限的环境中快速适应新任务。此外,该理论分析有助于理解和改进Transformer的架构设计,使其更有效地进行上下文学习,并为开发更强大的通用人工智能系统奠定基础。
📄 摘要(原文)
In-context learning (ICL) refers to a remarkable capability of pretrained large language models, which can learn a new task given a few examples during inference. However, theoretical understanding of ICL is largely under-explored, particularly whether transformers can be trained to generalize to unseen examples in a prompt, which will require the model to acquire contextual knowledge of the prompt for generalization. This paper investigates the training dynamics of transformers by gradient descent through the lens of non-linear regression tasks. The contextual generalization here can be attained via learning the template function for each task in-context, where all template functions lie in a linear space with $m$ basis functions. We analyze the training dynamics of one-layer multi-head transformers to in-contextly predict unlabeled inputs given partially labeled prompts, where the labels contain Gaussian noise and the number of examples in each prompt are not sufficient to determine the template. Under mild assumptions, we show that the training loss for a one-layer multi-head transformer converges linearly to a global minimum. Moreover, the transformer effectively learns to perform ridge regression over the basis functions. To our knowledge, this study is the first provable demonstration that transformers can learn contextual (i.e., template) information to generalize to both unseen examples and tasks when prompts contain only a small number of query-answer pairs.