Decision Mamba: Reinforcement Learning via Hybrid Selective Sequence Modeling

📄 arXiv: 2406.00079v1 📥 PDF

作者: Sili Huang, Jifeng Hu, Zhejian Yang, Liwei Yang, Tao Luo, Hechang Chen, Lichao Sun, Bo Yang

分类: cs.LG

发布日期: 2024-05-31

备注: arXiv admin note: text overlap with arXiv:2405.20692. arXiv admin note: text overlap with arXiv:2405.20692; text overlap with arXiv:2305.16554, arXiv:2210.14215 by other authors


💡 一句话要点

提出Decision Mamba-Hybrid,结合Transformer和Mamba优势,提升强化学习长时序决策效率。

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)

关键词: 强化学习 长时序建模 Mamba模型 Transformer模型 序列建模 决策Transformer 混合模型

📋 核心要点

  1. Transformer在强化学习中表现出色,但其注意力机制导致计算复杂度高,限制了其在长时序任务中的应用。
  2. Decision Mamba-Hybrid (DM-H)结合Mamba和Transformer的优势,利用Mamba生成子目标,引导Transformer进行高质量预测。
  3. 实验表明,DM-H在长短期任务中均达到SOTA,且长时任务的在线测试速度比Transformer快28倍。

📝 摘要(中文)

本文提出了一种基于混合选择性序列建模的强化学习方法,称为Decision Mamba (DM)。该方法旨在解决Transformer模型在强化学习中因注意力机制的二次计算复杂度而导致的计算成本过高的问题,尤其是在处理长时序任务时。首先,通过将Decision Transformer (DT)的骨干网络替换为Mamba模型来实现DM。然后,提出Decision Mamba-Hybrid (DM-H),它结合了Transformer的高质量预测能力和Mamba的长时记忆能力。具体来说,DM-H首先通过Mamba模型从长期记忆中生成高价值的子目标,然后使用这些子目标来提示Transformer,从而实现高质量的预测。实验结果表明,DM-H在D4RL、Grid World和Tmaze等长短期任务基准测试中均达到了最先进的水平。在效率方面,DM-H在长时任务中的在线测试速度比基于Transformer的基线快28倍。

🔬 方法详解

问题定义:现有基于Transformer的强化学习方法在处理长时序决策问题时,由于Transformer的注意力机制的二次方复杂度,计算成本非常高昂,限制了其应用范围。尤其是在需要长期记忆的任务中,Transformer的效率会显著下降。

核心思路:本文的核心思路是结合Mamba模型和Transformer模型的优势。Mamba模型以其高效处理长序列依赖关系的能力而闻名,而Transformer模型则擅长进行高质量的预测。通过将两者结合,可以既保证长时记忆能力,又实现准确的决策。

技术框架:DM-H的整体框架包含两个主要阶段。首先,利用Mamba模型从长期记忆中提取并生成高价值的子目标。这些子目标代表了对未来状态的期望或规划。然后,将这些子目标作为提示信息输入到Transformer模型中,引导Transformer进行策略预测。Transformer基于这些子目标,生成更准确、更有效的动作序列。

关键创新:DM-H的关键创新在于混合使用Mamba和Transformer。Mamba负责处理长时依赖关系和生成子目标,Transformer负责基于子目标进行高质量的策略预测。这种混合架构克服了Transformer在长序列处理上的局限性,同时保留了其强大的预测能力。

关键设计:DM-H的关键设计包括:1) 使用Mamba的Selective State Space Models (SSM) 结构来高效处理长序列;2) 设计合适的子目标表示方法,使其能够有效地引导Transformer的策略预测;3) 优化Mamba和Transformer之间的交互方式,确保信息能够顺畅地传递和利用。具体的损失函数和网络结构细节在论文中进行了详细描述(未知)。

📊 实验亮点

实验结果表明,DM-H在D4RL、Grid World和Tmaze等基准测试中均取得了state-of-the-art的性能。尤其是在长时任务中,DM-H的在线测试速度比基于Transformer的基线快28倍。这表明DM-H在保证性能的同时,显著提高了计算效率,使其更适用于实际应用。

🎯 应用场景

该研究成果可应用于需要长期规划和决策的机器人控制、游戏AI、自动驾驶等领域。通过提高强化学习算法在长时序任务中的效率和性能,可以使智能体更好地适应复杂环境,完成更具挑战性的任务。例如,在机器人导航中,机器人可以利用DM-H进行长期路径规划,避开障碍物并最终到达目标地点。

📄 摘要(原文)

Recent works have shown the remarkable superiority of transformer models in reinforcement learning (RL), where the decision-making problem is formulated as sequential generation. Transformer-based agents could emerge with self-improvement in online environments by providing task contexts, such as multiple trajectories, called in-context RL. However, due to the quadratic computation complexity of attention in transformers, current in-context RL methods suffer from huge computational costs as the task horizon increases. In contrast, the Mamba model is renowned for its efficient ability to process long-term dependencies, which provides an opportunity for in-context RL to solve tasks that require long-term memory. To this end, we first implement Decision Mamba (DM) by replacing the backbone of Decision Transformer (DT). Then, we propose a Decision Mamba-Hybrid (DM-H) with the merits of transformers and Mamba in high-quality prediction and long-term memory. Specifically, DM-H first generates high-value sub-goals from long-term memory through the Mamba model. Then, we use sub-goals to prompt the transformer, establishing high-quality predictions. Experimental results demonstrate that DM-H achieves state-of-the-art in long and short-term tasks, such as D4RL, Grid World, and Tmaze benchmarks. Regarding efficiency, the online testing of DM-H in the long-term task is 28$\times$ times faster than the transformer-based baselines.