Kronecker Embeddings: Byte-Level Structured Token Representations for Parameter-Efficient Language Models

📄 arXiv: 2605.29459v1 📥 PDF

作者: Rohan Shravan

分类: cs.CL, cs.LG

发布日期: 2026-05-28

备注: 28 pages, 16 tables. Reference implementation: https://github.com/theschoolofai/kronecker-embeddings


💡 一句话要点

提出Kronecker嵌入,通过字节级结构化表示显著降低语言模型参数量。

🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 参数高效 语言模型 嵌入表示 字节级别 Kronecker嵌入

📋 核心要点

  1. 现有语言模型的嵌入层参数量巨大,成为模型规模扩展的瓶颈。
  2. Kronecker嵌入利用字节级信息,通过固定编码器和学习投影替代传统嵌入表。
  3. 实验表明,Kronecker嵌入在参数效率、拼写鲁棒性和训练稳定性方面优于BPE。

📝 摘要(中文)

大型语言模型将每个输入通过一个学习到的嵌入表(形状为|V| x d_model),这在最前沿的模型中消耗了数亿到数十亿的可训练参数。我们引入了Kronecker嵌入,这是一种确定性的字节级字符位置分解方法,它用一个固定的编码器和一个单一的学习投影来代替这个表,与标准的BPE分词器兼容,从而在最前沿的模型中消除了91-94%的输入端可训练参数。我们提供了五个贡献。首先,对六个语言模型(1.35亿-6710亿参数)的跨模型探针显示,训练后的输入嵌入对探针词的印刷变体的聚类程度远高于形态相关的词;Kronecker在嵌入层避免了这种聚类。其次,在nanoGPT GPT-2 1.24亿参数模型上,对FineWeb-Edu的25亿个token进行了三次种子对比实验,结果表明,Kronecker达到了比BPE绑定的基线低2.5 +- 0.2%的验证损失(差距为0.083 +- 0.007 nats,约9%的困惑度降低),并且达到BPE收敛损失所需的步骤减少了约1.43倍。第三,对110个干净/拼写错误对的拼写鲁棒性探针显示,Kronecker在55.5%的对上保留了top-1预测,而BPE为47.3%(+8.2个百分点),并且KL散度降低了7.6%,在11个类别中的10个类别中获胜或打平;生成探针显示,Kronecker通过生成回显字节新颖的字符串和拼写错误,而BPE则忘记了它们。第四,BPE嵌入范数在训练期间漂移,而Kronecker投影范数保持在1.0附近,这与稳定的表示目标一致。第五,一种即时运行时变体从一个4.5 MB的字节缓冲区而不是一个131,072词汇量的2.15 GB表中重建嵌入,步长时间开销为0.01-0.24%。字节级局部性有一个权衡:字节相似但语义上距离较远的对(compute/commute, nation/notion)会聚集在一起,从而将消歧转移到早期的注意力层。

🔬 方法详解

问题定义:现有大型语言模型的输入嵌入层通常采用one-hot编码后接一个大型嵌入表,该表的大小为词汇表大小乘以嵌入维度。随着词汇表的增大,嵌入表的参数量也急剧增加,成为模型参数量的重要组成部分,限制了模型的可扩展性。此外,传统的嵌入方式对词汇的形态和拼写变体缺乏鲁棒性,相似的词汇在嵌入空间中可能距离较远。

核心思路:Kronecker嵌入的核心思想是利用字节级别的结构化信息来表示token,从而避免直接学习一个庞大的嵌入表。它将token分解为字节序列,并利用一个固定的编码器和一个学习到的投影层来生成嵌入向量。这种方法可以显著减少参数量,并且能够更好地捕捉词汇的形态和拼写变体之间的关系。

技术框架:Kronecker嵌入的整体框架包括以下几个主要模块:1) 字节编码器:将输入的token分解为字节序列,并使用一个固定的编码器(例如,one-hot编码)将每个字节转换为向量表示。2) 位置编码:为每个字节添加位置编码,以区分token中不同位置的字节。3) 投影层:将字节编码和位置编码的组合投影到一个低维的嵌入空间中。这个投影层是唯一需要学习的参数。4) 嵌入向量:将投影后的向量作为token的嵌入向量。

关键创新:Kronecker嵌入的关键创新在于使用字节级别的结构化信息来表示token,从而避免了直接学习一个庞大的嵌入表。与传统的嵌入方法相比,Kronecker嵌入具有以下优点:1) 参数效率高:显著减少了嵌入层的参数量。2) 拼写鲁棒性强:能够更好地捕捉词汇的形态和拼写变体之间的关系。3) 训练稳定性好:嵌入向量的范数在训练过程中更加稳定。

关键设计:Kronecker嵌入的关键设计包括:1) 字节编码器的选择:可以使用不同的字节编码器,例如one-hot编码或学习到的字节嵌入。2) 位置编码的选择:可以使用不同的位置编码方法,例如正弦位置编码或学习到的位置嵌入。3) 投影层的结构:可以使用不同的投影层结构,例如线性层或多层感知机。4) 损失函数:可以使用不同的损失函数来训练投影层,例如交叉熵损失或对比损失。

📊 实验亮点

实验结果表明,Kronecker嵌入在nanoGPT GPT-2 124M模型上,验证损失比BPE基线降低了2.5%,达到BPE收敛损失所需的步骤减少了1.43倍。在拼写鲁棒性方面,Kronecker嵌入在55.5%的拼写错误对上保留了top-1预测,而BPE为47.3%,KL散度降低了7.6%。

🎯 应用场景

Kronecker嵌入可应用于各种自然语言处理任务,尤其适用于资源受限的场景,如移动设备或边缘计算。它可以降低模型大小,提高推理速度,并增强模型对拼写错误的鲁棒性。未来,Kronecker嵌入有望促进更高效、更可靠的语言模型部署。

📄 摘要(原文)

Large language models route every input through a learned embedding table of shape |V| x d_model, consuming hundreds of millions to billions of trainable parameters at frontier scale. We introduce Kronecker Embeddings, a deterministic byte-level character-position factorization that replaces this table with a fixed encoder and a single learned projection, compatible with standard BPE tokenizers, eliminating 91--94% of input-side trainable parameters at frontier scale. We provide five contributions. First, a cross-model probe across six LMs (135M-671B parameters) shows trained input embeddings cluster typographic variants of the probe word far more than morphological relatives; Kronecker escapes this clustering at the embedding layer. Second, a controlled three-seed comparison on nanoGPT GPT-2 124M over 2.5B tokens of FineWeb-Edu shows Kronecker reaching 2.5 +- 0.2% lower validation loss than the BPE-tied baseline (gap 0.083 +- 0.007 nats, ~9% lower perplexity), needing ~1.43x fewer steps to reach BPE's converged loss. Third, a spelling-robustness probe over 110 clean/typo pairs shows Kronecker preserves the top-1 prediction on 55.5% of pairs vs. 47.3% for BPE (+8.2 pp) and lowers KL by 7.6%, winning or tying in 10 of 11 categories; a generation probe shows Kronecker echoes byte-novel strings and typos through generation where BPE forgets them. Fourth, BPE embedding norm drifts during training while Kronecker projection norm stays near 1.0, consistent with a stable representational target. Fifth, an on-the-fly runtime variant reconstructs embeddings from a 4.5 MB byte buffer rather than a 2.15 GB table at vocabulary 131,072, with 0.01--0.24% step-time overhead. Byte-level locality has a tradeoff: byte-similar but semantically distant pairs (compute/commute, nation/notion) cluster together, shifting disambiguation to early attention layers.