You can remove GPT2's LayerNorm by fine-tuning
作者: Stefan Heimersheim
分类: cs.CL, cs.LG
发布日期: 2024-09-06 (更新: 2024-11-17)
备注: Presented at the Attributing Model Behavior at Scale (ATTRIB) and Interpretable AI: Past, Present, and Future workshops at NeurIPS 2024
🔗 代码/项目: GITHUB | HUGGINGFACE
💡 一句话要点
通过微调去除GPT2的LayerNorm层,简化模型并保持性能
🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)
关键词: GPT2 LayerNorm 微调 机制可解释性 Transformer模型
📋 核心要点
- GPT风格Transformer模型中的LayerNorm层阻碍了机制可解释性研究,其非线性特性使得残差流的解释和模型分解为电路变得困难。
- 该论文通过微调预训练的GPT2-small模型,成功移除了LayerNorm层,核心思想是在训练过程中使模型适应没有LayerNorm的结构。
- 实验结果表明,移除LayerNorm后的模型在多个数据集上与原始模型性能相当,证明了LayerNorm在推理时可能不是必需的。
📝 摘要(中文)
本文研究表明,可以通过在部分训练数据(500M tokens)上微调预训练的GPT2-small模型来移除LayerNorm(LN)层。实验证明,在OpenWebText和ThePile数据集上,无LN的模型与原始模型取得了相似的性能(交叉熵损失仅下降0.05),在Hellaswag基准测试中也取得了相近的准确率(下降0.5%)。该研究不仅为机制可解释性研究提供了一个简化的模型,也证明了LN层在Transformer模型推理时可能并非至关重要。代码和微调后的模型已开源。
🔬 方法详解
问题定义:现有的GPT系列Transformer模型中,LayerNorm层虽然对稳定训练至关重要,但其非线性特性给模型的机制可解释性带来了挑战。研究者难以理解残差流,也难以将模型分解为更小的电路单元。因此,如何移除LayerNorm层,同时保持模型性能,是一个重要的研究问题。
核心思路:该论文的核心思路是通过微调,使模型适应没有LayerNorm层的结构。作者认为,虽然LayerNorm在预训练阶段对稳定训练至关重要,但在推理阶段,模型可能已经学习到了一种内在的表示,使得它可以不需要LayerNorm也能正常工作。因此,通过微调,模型可以重新调整其内部参数,以适应没有LayerNorm的环境。
技术框架:该研究的技术框架非常简单:首先,使用预训练的GPT2-small模型作为起点。然后,移除模型中的所有LayerNorm层。最后,在一个较小的数据集(500M tokens)上对修改后的模型进行微调。微调的目标是最小化交叉熵损失,使得模型能够尽可能地恢复其原始性能。
关键创新:该论文的关键创新在于证明了LayerNorm层在Transformer模型中并非绝对必要。虽然LayerNorm在训练过程中起到了稳定作用,但在推理时,模型可以通过微调来适应没有LayerNorm的环境。这为简化模型结构,提高可解释性提供了新的思路。
关键设计:该研究的关键设计在于微调数据集的大小和微调的轮数。作者选择了500M tokens的数据集,并进行了多次实验,以找到最佳的微调策略。此外,作者还使用了标准的交叉熵损失函数作为微调的目标函数。
📊 实验亮点
实验结果表明,通过微调移除LayerNorm层后的GPT2-small模型,在OpenWebText和ThePile数据集上,交叉熵损失仅下降0.05,性能与原始模型基本持平。在Hellaswag基准测试中,准确率仅下降0.5%。这些结果表明,即使没有LayerNorm层,模型依然能够保持良好的性能。
🎯 应用场景
该研究成果可应用于对大型语言模型进行机制可解释性分析,简化模型结构,便于研究人员理解模型的内部运作机制。此外,该方法也为设计更高效、更易于理解的Transformer模型提供了新的思路,可能促进轻量级模型的发展,并降低模型部署的计算成本。
📄 摘要(原文)
The LayerNorm (LN) layer in GPT-style transformer models has long been a hindrance to mechanistic interpretability. LN is a crucial component required to stabilize the training of large language models, and LN or the similar RMSNorm have been used in practically all large language models based on the transformer architecture. The non-linear nature of the LN layers is a hindrance for mechanistic interpretability as it hinders interpretation of the residual stream, and makes it difficult to decompose the model into circuits. Some researchers have gone so far as to name "reasons interpretability researchers hate layer norm." In this paper we show that it is possible to remove the LN layers from a pre-trained GPT2-small model by fine-tuning on a fraction (500M tokens) of the training data. We demonstrate that this LN-free model achieves similar performance to the original model on the OpenWebText and ThePile datasets (-0.05 cross-entropy loss), and the Hellaswag benchmark (-0.5% accuracy). We provide our implementation at https://github.com/ApolloResearch/gpt2_noLN, and fine-tuned GPT2-small models at https://huggingface.co/apollo-research/gpt2_noLN. Our work not only provides a simplified model for mechanistic interpretability research, but also provides evidence that the LN layers, at inference time, do not play a crucial role in transformer models.