Towards Generalizable Trajectory Prediction Using Dual-Level Representation Learning And Adaptive Prompting

📄 arXiv: 2501.04815v1 📥 PDF

作者: Kaouther Messaoud, Matthieu Cord, Alexandre Alahi

分类: cs.CV

发布日期: 2025-01-08


💡 一句话要点

提出PerReg+,利用双层表征学习和自适应Prompt提升轨迹预测的泛化性。

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 轨迹预测 泛化性 双层表征学习 自适应Prompt Tuning 多模态预测 自动驾驶 Perceiver

📋 核心要点

  1. 现有轨迹预测模型泛化性差,难以处理复杂交互,且架构复杂,针对特定数据集定制。
  2. PerReg+通过双层表征学习捕获全局上下文和细粒度细节,并利用自适应Prompt Tuning提升泛化性。
  3. PerReg+在nuScenes、Argoverse 2和WOMD上取得了SOTA性能,预训练模型在小数据集上误差降低6.8%。

📝 摘要(中文)

现有的车辆轨迹预测模型在泛化性、预测不确定性和处理复杂交互方面存在困难。这通常是由于为特定数据集定制的复杂架构以及低效的多模态处理等限制造成的。我们提出了PerReg+,一种新颖的轨迹预测框架,它引入了:(1)通过自蒸馏(SD)和掩码重建(MR)的双层表征学习,捕获全局上下文和细粒度细节。此外,我们通过查询丢弃从掩码输入重建分段轨迹和车道线段的方法,能够有效地利用上下文信息并提高泛化性;(2)使用基于寄存器的查询和预训练来增强多模态性,无需聚类和抑制;以及(3)微调期间的自适应Prompt Tuning,冻结主架构并优化少量Prompt以实现高效适应。PerReg+在nuScenes、Argoverse 2和Waymo Open Motion Dataset (WOMD)上取得了新的state-of-the-art性能。值得注意的是,我们的预训练模型在较小的数据集上将误差降低了6.8%,多数据集训练增强了泛化性。在跨域测试中,与非预训练变体相比,PerReg+将B-FDE降低了11.8%。

🔬 方法详解

问题定义:车辆轨迹预测旨在预测车辆未来一段时间内的运动轨迹。现有方法通常依赖于针对特定数据集设计的复杂架构,泛化能力较差,难以适应新的场景和数据集。此外,现有方法在处理多模态预测和复杂交互时效率较低,需要复杂的聚类和抑制策略。

核心思路:PerReg+的核心思路是利用双层表征学习来同时捕获全局上下文和细粒度细节,并通过自适应Prompt Tuning来实现高效的领域适应。通过预训练模型,可以学习到通用的运动模式,从而提高在不同数据集上的泛化能力。使用基于寄存器的查询来处理多模态预测,避免了复杂的聚类和抑制过程。

技术框架:PerReg+的整体架构基于Perceiver模型,并引入了Register queries。该框架包含以下主要模块:1) 双层表征学习模块,通过自蒸馏和掩码重建来学习全局上下文和细粒度细节;2) 基于寄存器的查询模块,用于处理多模态预测;3) 自适应Prompt Tuning模块,用于在微调阶段优化少量Prompt,实现高效的领域适应。

关键创新:PerReg+的关键创新点在于:1) 提出了双层表征学习方法,能够同时捕获全局上下文和细粒度细节,提高了模型的表达能力;2) 使用基于寄存器的查询来处理多模态预测,避免了复杂的聚类和抑制过程;3) 引入了自适应Prompt Tuning,能够在微调阶段高效地适应新的领域。

关键设计:在双层表征学习中,使用了自蒸馏和掩码重建两种方法。自蒸馏通过将教师模型的知识传递给学生模型来提高模型的性能。掩码重建通过从掩码输入中重建轨迹和车道线段来增强模型对上下文信息的利用。在自适应Prompt Tuning中,冻结了主架构,只优化少量Prompt参数,从而实现了高效的领域适应。具体的损失函数包括轨迹预测损失、自蒸馏损失和掩码重建损失。

📊 实验亮点

PerReg+在nuScenes、Argoverse 2和WOMD数据集上取得了SOTA性能。在较小的数据集上,预训练模型将误差降低了6.8%。在跨域测试中,与非预训练变体相比,PerReg+将B-FDE降低了11.8%,显著提升了模型的泛化能力。

🎯 应用场景

PerReg+可应用于自动驾驶、高级驾驶辅助系统(ADAS)、智能交通管理等领域。通过提高轨迹预测的准确性和泛化性,可以提升自动驾驶系统的安全性、可靠性和适应性,减少交通事故的发生,并优化交通流量。

📄 摘要(原文)

Existing vehicle trajectory prediction models struggle with generalizability, prediction uncertainties, and handling complex interactions. It is often due to limitations like complex architectures customized for a specific dataset and inefficient multimodal handling. We propose Perceiver with Register queries (PerReg+), a novel trajectory prediction framework that introduces: (1) Dual-Level Representation Learning via Self-Distillation (SD) and Masked Reconstruction (MR), capturing global context and fine-grained details. Additionally, our approach of reconstructing segmentlevel trajectories and lane segments from masked inputs with query drop, enables effective use of contextual information and improves generalization; (2) Enhanced Multimodality using register-based queries and pretraining, eliminating the need for clustering and suppression; and (3) Adaptive Prompt Tuning during fine-tuning, freezing the main architecture and optimizing a small number of prompts for efficient adaptation. PerReg+ sets a new state-of-the-art performance on nuScenes [1], Argoverse 2 [2], and Waymo Open Motion Dataset (WOMD) [3]. Remarkable, our pretrained model reduces the error by 6.8% on smaller datasets, and multi-dataset training enhances generalization. In cross-domain tests, PerReg+ reduces B-FDE by 11.8% compared to its non-pretrained variant.