Expansion Span: Combining Fading Memory and Retrieval in Hybrid State Space Models

📄 arXiv: 2412.13328v2 📥 PDF

作者: Elvis Nunez, Luca Zancato, Benjamin Bowman, Aditya Golatkar, Wei Xia, Stefano Soatto

分类: cs.CL, cs.LG

发布日期: 2024-12-17 (更新: 2025-05-25)


💡 一句话要点

提出Span-Expanded Attention,扩展混合状态空间模型的记忆范围,提升长序列建模能力。

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)

关键词: 状态空间模型 长序列建模 Attention机制 记忆扩展 混合模型

📋 核心要点

  1. 现有混合SSM模型在处理长序列时,无法有效回忆远距离信息,且Attention机制的记忆范围受限。
  2. 提出Span-Expanded Attention (SE-Attn),通过动态检索并聚合远距离相关token,扩展模型的有效记忆范围。
  3. 结合HyLoRA微调方法,SE-Attn在长序列任务上表现出优越的性能,且计算成本低于LongLoRA等替代方案。

📝 摘要(中文)

状态空间模型(SSM)的“状态”代表其记忆,该记忆随时间以指数形式衰减。相比之下,基于Attention的模型在有限的跨度(上下文大小)上具有“过目不忘”的记忆。混合架构结合了状态空间层和Attention层,但仍然无法回忆遥远的过去,并且只能“过目不忘”地访问最近的token。与当前结合SSM和Attention层的方法不同,我们允许基于相关性而不是时间顺序来分配状态。这样,对于每一组新的查询token,我们的模型都可以“过目不忘”地访问来自当前混合SSM的Attention跨度之外的token,而无需额外的硬件资源。我们引入了一种方法,通过“保留”一部分Attention上下文用于从过去检索的token,来扩展混合状态的记忆跨度,从而扩展整体状态的“过目不忘”的记忆跨度。我们将这部分保留的token称为“扩展跨度”,并将检索和聚合它的机制称为“Span-Expanded Attention”(SE-Attn)。为了使混合模型适应使用SE-Attn,我们提出了一种新颖的微调方法,该方法将LoRA扩展到混合模型(HyLoRA),并允许在长token跨度上进行高效的适应。我们表明,SE-Attn使我们能够有效地在比用于预训练的序列长达8倍的token序列上调整预训练的混合模型。我们表明,当应用于具有长程依赖性的自然语言基准(如PG-19、RULER)和其他常见的自然语言下游任务时,带有SE-Attn的HyLoRA比LongLoRA等替代方案更便宜且性能更好。

🔬 方法详解

问题定义:现有混合状态空间模型在处理长序列时面临记忆衰减和Attention范围限制的问题。传统的SSM虽然具有无限的理论记忆范围,但实际应用中信息会随时间指数衰减。而Attention机制虽然能精确记忆,但受限于上下文窗口大小,无法有效捕捉长距离依赖关系。因此,如何有效地结合两者的优势,扩展模型的记忆范围,是本论文要解决的核心问题。

核心思路:论文的核心思路是引入“扩展跨度”(Expansion Span)的概念,允许模型在Attention机制中动态地检索并聚合来自远距离历史的信息。通过这种方式,模型可以“选择性”地记住重要的历史信息,而不是简单地依赖于最近的上下文。这种基于相关性的记忆方式,能够更有效地捕捉长序列中的依赖关系。

技术框架:整体框架是在混合SSM模型的基础上,引入Span-Expanded Attention (SE-Attn)模块。SE-Attn模块负责从整个历史序列中检索与当前查询相关的token,并将这些token添加到Attention的上下文中。为了高效地训练SE-Attn,论文还提出了HyLoRA微调方法,该方法将LoRA技术扩展到混合模型,允许在长序列上进行高效的参数调整。

关键创新:最重要的创新点在于Span-Expanded Attention机制,它打破了传统Attention机制的上下文窗口限制,允许模型动态地访问和利用远距离的历史信息。与现有方法相比,SE-Attn不需要额外的硬件资源,并且能够更有效地捕捉长序列中的依赖关系。

关键设计:SE-Attn的关键设计包括:1) 如何选择需要检索的token(检索策略,例如基于相似度);2) 如何将检索到的token与当前的上下文进行融合(聚合策略,例如加权平均);3) 如何高效地训练SE-Attn(HyLoRA微调方法)。HyLoRA通过冻结大部分预训练参数,只微调少量参数,从而降低了训练成本,并避免了灾难性遗忘。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,SE-Attn结合HyLoRA微调方法,在PG-19和RULER等长序列基准测试中,性能优于LongLoRA等替代方案,且计算成本更低。具体而言,该方法能够在比预训练序列长8倍的序列上进行有效微调,并显著提升模型在长距离依赖关系任务上的性能。

🎯 应用场景

该研究成果可广泛应用于需要处理长序列数据的自然语言处理任务,例如长文本摘要、文档理解、机器翻译等。通过扩展模型的记忆范围,可以提高模型在这些任务上的性能,并更好地理解长距离依赖关系。此外,该方法还可以应用于其他领域,例如时间序列分析、视频理解等。

📄 摘要(原文)

The "state" of State Space Models (SSMs) represents their memory, which fades exponentially over an unbounded span. By contrast, Attention-based models have "eidetic" (i.e., verbatim, or photographic) memory over a finite span (context size). Hybrid architectures combine State Space layers with Attention, but still cannot recall the distant past and can access only the most recent tokens eidetically. Unlike current methods of combining SSM and Attention layers, we allow the state to be allocated based on relevancy rather than recency. In this way, for every new set of query tokens, our models can "eidetically" access tokens from beyond the Attention span of current Hybrid SSMs without requiring extra hardware resources. We introduce a method to expand the memory span of the hybrid state by "reserving" a fraction of the Attention context for tokens retrieved from arbitrarily distant in the past, thus expanding the eidetic memory span of the overall state. We call this reserved fraction of tokens the "expansion span," and the mechanism to retrieve and aggregate it "Span-Expanded Attention" (SE-Attn). To adapt Hybrid models to using SE-Attn, we propose a novel fine-tuning method that extends LoRA to Hybrid models (HyLoRA) and allows efficient adaptation on long spans of tokens. We show that SE-Attn enables us to efficiently adapt pre-trained Hybrid models on sequences of tokens up to 8 times longer than the ones used for pre-training. We show that HyLoRA with SE-Attn is cheaper and more performant than alternatives like LongLoRA when applied to Hybrid models on natural language benchmarks with long-range dependencies, such as PG-19, RULER, and other common natural language downstream tasks.