B'MOJO: Hybrid State Space Realizations of Foundation Models with Eidetic and Fading Memory
作者: Luca Zancato, Arjun Seshadri, Yonatan Dukler, Aditya Golatkar, Yantao Shen, Benjamin Bowman, Matthew Trager, Alessandro Achille, Stefano Soatto
分类: cs.LG, cs.CL, cs.NE
发布日期: 2024-07-08
💡 一句话要点
B'MOJO:融合显式和隐式记忆的混合状态空间模型,提升长序列建模能力
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 状态空间模型 长序列建模 显式记忆 隐式记忆 转导推理 语言建模 随机实现理论
📋 核心要点
- 现有架构在处理长序列时,要么依赖有限上下文的显式记忆,要么依赖无限范围的衰减记忆,难以兼顾效率和长期依赖。
- B'MOJO利用随机实现理论,设计了一种可组合模块,能够无缝融合显式和隐式记忆,灵活调节两者的比例。
- 实验表明,B'MOJO在转导推理任务上优于现有SSM和混合模型,并在长序列语言建模上表现出卓越的性能。
📝 摘要(中文)
本文提出了一种新的架构族B'MOJO,旨在支持转导推理,允许记忆增长到有限但事先未知的边界,同时有效利用有限的推理资源。现有架构要么在有限范围内显式地表示数据(如Transformer中的“上下文”),要么在无限范围内隐式地衰减(如状态空间模型SSM)。最近的混合架构结合了显式和隐式记忆,但存在局限性,无法让设计者或学习过程无缝地调节两者,也无法扩展显式记忆的范围。我们利用随机实现理论的思想,开发了一种名为B'MOJO的模型类,以在基本的可组合模块中无缝地结合显式和隐式记忆。该架构可以实现访问“上下文”中的短期显式记忆、“权重”中的永久结构记忆、“状态”中的衰减记忆以及“存储”中的长期显式记忆的模型,通过原生结合从异步更新的记忆中检索信息。我们证明了Transformer、Mamba等现有SSM以及Jamba等混合架构是B'MOJO的特例,并描述了一个基本的开源实现,可以在硬件中高效地堆叠和扩展。我们在联想回忆等转导推理任务上测试了B'MOJO,其性能优于现有的SSM和混合模型。作为基线,我们在普通语言建模上测试了B'MOJO,其困惑度与参数量相似的Transformer和SSM相当,最高可达14亿参数,同时训练速度提高了10%。最后,我们表明B'MOJO调节显式和隐式记忆的能力可以在更长的序列上实现更好的推理,测试长度高达32K tokens,是训练期间最长序列长度的四倍。
🔬 方法详解
问题定义:现有Transformer和SSM模型在处理长序列时存在局限性。Transformer的上下文窗口有限,无法捕捉长期依赖关系。SSM虽然具有无限上下文,但其记忆是衰减的,难以精确地记住历史信息。混合架构试图结合两者的优点,但难以灵活地调节显式和隐式记忆之间的平衡,也无法扩展显式记忆的范围。
核心思路:B'MOJO的核心思路是利用随机实现理论,将显式(eidetic)和隐式(fading)记忆无缝地集成到一个统一的框架中。通过这种方式,模型可以根据任务的需求,灵活地调节两种记忆的比例,从而更好地捕捉序列中的长期依赖关系和关键信息。这种设计允许模型在“上下文”中访问短期显式记忆,在“权重”中访问永久结构记忆,在“状态”中访问衰减记忆,并在“存储”中访问长期显式记忆。
技术框架:B'MOJO的整体架构基于可组合的模块,每个模块都包含显式和隐式记忆组件。这些模块可以堆叠起来,形成一个深层网络。模型可以通过异步更新的外部存储器进行检索,从而实现长期记忆。该框架允许模型访问四种类型的记忆:短期显式记忆(in-context)、永久结构记忆(in-weights)、衰减记忆(in-state)和长期显式记忆(in-storage)。
关键创新:B'MOJO的关键创新在于其能够无缝地结合显式和隐式记忆,并允许设计者或学习过程灵活地调节两者。与现有方法相比,B'MOJO能够更好地捕捉序列中的长期依赖关系和关键信息。此外,B'MOJO还能够通过异步更新的外部存储器进行检索,从而实现长期记忆。
关键设计:B'MOJO的具体实现细节(如参数设置、损失函数、网络结构等)在论文中没有详细描述,标记为待开源。但从整体架构来看,关键设计在于如何有效地融合显式和隐式记忆,以及如何实现异步更新的外部存储器检索。这些设计细节将直接影响模型的性能和效率。
🖼️ 关键图片
📊 实验亮点
B'MOJO在联想回忆等转导推理任务上优于现有的SSM和混合模型。在普通语言建模任务中,B'MOJO的困惑度与参数量相似的Transformer和SSM相当,最高可达14亿参数,同时训练速度提高了10%。此外,B'MOJO在长度高达32K tokens的长序列上实现了更好的推理,是训练期间最长序列长度的四倍。
🎯 应用场景
B'MOJO具有广泛的应用前景,包括自然语言处理、时间序列分析、视频理解等领域。它可以用于构建更强大的语言模型、预测模型和智能系统,例如,在对话系统中,B'MOJO可以更好地记住用户的历史对话,从而提供更个性化的服务。在金融领域,它可以用于更准确地预测股票价格走势。在医疗领域,它可以用于分析患者的病历,从而辅助医生进行诊断。
📄 摘要(原文)
We describe a family of architectures to support transductive inference by allowing memory to grow to a finite but a-priori unknown bound while making efficient use of finite resources for inference. Current architectures use such resources to represent data either eidetically over a finite span ("context" in Transformers), or fading over an infinite span (in State Space Models, or SSMs). Recent hybrid architectures have combined eidetic and fading memory, but with limitations that do not allow the designer or the learning process to seamlessly modulate the two, nor to extend the eidetic memory span. We leverage ideas from Stochastic Realization Theory to develop a class of models called B'MOJO to seamlessly combine eidetic and fading memory within an elementary composable module. The overall architecture can be used to implement models that can access short-term eidetic memory "in-context," permanent structural memory "in-weights," fading memory "in-state," and long-term eidetic memory "in-storage" by natively incorporating retrieval from an asynchronously updated memory. We show that Transformers, existing SSMs such as Mamba, and hybrid architectures such as Jamba are special cases of B'MOJO and describe a basic implementation, to be open sourced, that can be stacked and scaled efficiently in hardware. We test B'MOJO on transductive inference tasks, such as associative recall, where it outperforms existing SSMs and Hybrid models; as a baseline, we test ordinary language modeling where B'MOJO achieves perplexity comparable to similarly-sized Transformers and SSMs up to 1.4B parameters, while being up to 10% faster to train. Finally, we show that B'MOJO's ability to modulate eidetic and fading memory results in better inference on longer sequences tested up to 32K tokens, four-fold the length of the longest sequences seen during training.