Is Mamba Compatible with Trajectory Optimization in Offline Reinforcement Learning?

📄 arXiv: 2405.12094v2 📥 PDF

作者: Yang Dai, Oubo Ma, Longfei Zhang, Xingxing Liang, Shengchao Hu, Mengzhu Wang, Shouling Ji, Jincai Huang, Li Shen

分类: cs.LG

发布日期: 2024-05-20 (更新: 2024-10-27)

备注: 23 pages, 11 figures


💡 一句话要点

提出Decision Mamba (DeMa),在离线强化学习轨迹优化中实现更优性能和参数效率。

🎯 匹配领域: 支柱一:机器人控制 (Robot Control) 支柱二:RL算法与架构 (RL & Architecture)

关键词: 离线强化学习 轨迹优化 Mamba模型 序列建模 决策Transformer

📋 核心要点

  1. Transformer在离线强化学习轨迹优化中表现优异,但参数量大,限制了其在资源受限设备上的应用。
  2. 论文提出Decision Mamba (DeMa),利用Mamba的线性时间序列建模能力,减少参数量,提升计算效率。
  3. 实验表明,DeMa在Atari和MuJoCo上超越Decision Transformer,同时显著减少了参数使用。

📝 摘要(中文)

基于Transformer的轨迹优化方法在离线强化学习(offline RL)中表现出色,但参数量大、可扩展性有限,这在计算资源受限的序列决策场景(如机器人和无人机)中尤为关键。Mamba作为一种新型线性时间序列模型,在长序列上提供了与Transformer相当的性能,同时显著减少了参数。本文旨在通过全面的实验,从数据结构和关键组件的角度,探索Mamba在离线RL中轨迹优化的潜力,并提出了Decision Mamba (DeMa)。研究发现,长序列会带来显著的计算负担,但对性能提升没有贡献,因为DeMa对序列的关注度呈近似指数递减。因此,我们引入了一种类似于Transformer的DeMa,而非RNN式的DeMa。对于DeMa的组件,我们发现隐藏注意力机制是其成功的关键因素,它可以与其他残差结构很好地协同工作,并且不需要位置嵌入。大量的评估表明,我们专门设计的DeMa与轨迹优化兼容,并且超越了以前的方法,在Atari上以更少的30%参数超越了Decision Transformer (DT),在MuJoCo上仅用DT四分之一的参数就超越了DT。

🔬 方法详解

问题定义:现有基于Transformer的离线强化学习轨迹优化方法,如Decision Transformer (DT),虽然性能优越,但模型参数量巨大,计算复杂度高,难以部署在计算资源受限的设备上,例如机器人和无人机。因此,如何在保证性能的同时,降低模型参数量和计算复杂度,是亟待解决的问题。

核心思路:论文的核心思路是利用Mamba模型替代Transformer模型,进行轨迹优化。Mamba模型具有线性时间复杂度,能够显著减少长序列建模的计算负担,同时保持与Transformer相当的性能。通过对Mamba模型的结构进行调整和优化,使其更适合轨迹优化任务,从而实现性能和效率的平衡。

技术框架:论文提出的Decision Mamba (DeMa) 框架主要包含以下几个部分:1) 状态、动作和奖励序列的嵌入层;2) 多个Mamba块组成的序列建模层;3) 策略预测头和价值预测头。DeMa接收离线数据集中的轨迹序列作为输入,通过Mamba块进行序列建模,然后通过策略预测头预测动作,通过价值预测头预测状态价值。整个框架采用端到端的方式进行训练。

关键创新:论文最重要的技术创新点在于将Mamba模型引入到离线强化学习的轨迹优化任务中,并针对该任务对Mamba模型进行了改进。具体来说,论文发现长序列对性能提升的贡献有限,因此采用了一种类似于Transformer的结构,而非RNN式的结构。此外,论文还发现隐藏注意力机制是Mamba成功的关键,并且可以与其他残差结构很好地协同工作,不需要位置嵌入。

关键设计:DeMa的关键设计包括:1) 采用Transformer-like的结构,避免RNN-like结构带来的计算负担;2) 强调隐藏注意力机制的重要性,并将其与其他残差结构结合;3) 去除位置嵌入,进一步减少参数量;4) 针对轨迹优化任务,对Mamba块的参数进行了精细调整,例如状态空间模型的维度、选择机制的参数等。损失函数采用策略梯度损失和价值函数损失的加权和。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,DeMa在Atari游戏上,使用比Decision Transformer (DT) 少30%的参数,取得了更高的性能。在MuJoCo连续控制任务上,DeMa仅使用DT四分之一的参数,就超越了DT的性能。这些结果充分证明了DeMa在离线强化学习轨迹优化中的有效性和优越性。

🎯 应用场景

该研究成果可应用于资源受限的序列决策场景,例如机器人控制、无人机导航、自动驾驶等。通过降低模型参数量和计算复杂度,DeMa能够使这些设备在有限的计算资源下实现更高效的决策和控制,从而提高其智能化水平和应用范围。此外,该研究也为其他序列建模任务提供了新的思路和方法。

📄 摘要(原文)

Transformer-based trajectory optimization methods have demonstrated exceptional performance in offline Reinforcement Learning (offline RL). Yet, it poses challenges due to substantial parameter size and limited scalability, which is particularly critical in sequential decision-making scenarios where resources are constrained such as in robots and drones with limited computational power. Mamba, a promising new linear-time sequence model, offers performance on par with transformers while delivering substantially fewer parameters on long sequences. As it remains unclear whether Mamba is compatible with trajectory optimization, this work aims to conduct comprehensive experiments to explore the potential of Decision Mamba (dubbed DeMa) in offline RL from the aspect of data structures and essential components with the following insights: (1) Long sequences impose a significant computational burden without contributing to performance improvements since DeMa's focus on sequences diminishes approximately exponentially. Consequently, we introduce a Transformer-like DeMa as opposed to an RNN-like DeMa. (2) For the components of DeMa, we identify the hidden attention mechanism as a critical factor in its success, which can also work well with other residual structures and does not require position embedding. Extensive evaluations demonstrate that our specially designed DeMa is compatible with trajectory optimization and surpasses previous methods, outperforming Decision Transformer (DT) with higher performance while using 30\% fewer parameters in Atari, and exceeding DT with only a quarter of the parameters in MuJoCo.