Clin-JEPA: A Multi-Phase Co-Training Framework for Joint-Embedding Predictive Pretraining on EHR Patient Trajectories
作者: Yixuan Yang, Mehak Arora, Ryan Zhang, Baraa Abed, Junseob Kim, Tilendra Choudhary, Md Hassanuzzaman, Kevin Zhu, Ayman Ali, Chengkun Yang, Alasdair Edward Gent, Victor Moas, Rishikesan Kamaleswaran
分类: cs.LG, cs.AI, q-bio.QM
发布日期: 2026-05-11
备注: 17 pages, 4 figures, 8 tables. Code: https://github.com/YeungYathin/Clin-JEPA
💡 一句话要点
提出Clin-JEPA多阶段协同训练框架,实现电子健康记录(EHR)患者轨迹的联合嵌入预测预训练。
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)
关键词: 电子健康记录 联合嵌入预测 课程学习 时序预测 表示学习 临床决策支持
📋 核心要点
- 现有JEPA框架在EHR应用中面临预测器与编码器脱节的问题,导致无法有效利用滚动信号,且直接协同训练极易引发表示坍缩与漂移。
- 提出五阶段课程学习策略,通过预测器预热、EMA目标对齐及硬同步等机制,实现了Qwen3-8B编码器与轨迹预测器的稳定协同训练。
- 实验证明Clin-JEPA在48小时轨迹预测中漂移收敛(-15.7%),并在多项临床风险预测任务中显著超越强基线模型,AUROC提升达0.041。
📝 摘要(中文)
本文提出了Clin-JEPA,一种用于电子健康记录(EHR)患者轨迹联合嵌入预测(JEPA)的五阶段协同训练框架。尽管JEPA架构在机器人和视觉领域表现出色,但将其扩展至EHR数据以实现单一骨干网络同时进行轨迹预测和下游风险评估仍具挑战。现有方法在预训练后丢弃预测器或将其固定,导致编码器无法感知推理时的滚动信号。Clin-JEPA通过五阶段课程学习(预测器预热、联合精调、EMA目标对齐、硬同步及预测器终结)解决了表示坍缩和在线/目标漂移问题。在MIMIC-IV ICU数据集上的实验表明,该框架在48小时轨迹预测中实现了稳定的漂移收敛,并显著提升了多任务风险预测性能,优于现有的表格和序列基线模型。
🔬 方法详解
问题定义:论文旨在解决EHR数据中长时序患者轨迹预测与下游风险分类任务的统一建模问题。现有JEPA范式在预训练后往往忽略预测器,或在固定编码器上训练预测器,导致编码器无法针对推理时的滚动预测信号进行优化,而直接协同训练又会导致模型不稳定及表示坍缩。
核心思路:引入一种多阶段课程学习框架,通过分阶段约束编码器与预测器的协同演化,确保潜在空间表示的临床判别性,同时通过EMA(指数移动平均)和硬同步机制缓解在线目标漂移,使模型能够进行稳定的自回归轨迹预测。
技术框架:整体架构基于Qwen3-8B作为编码器,配合92M参数的潜在轨迹预测器。训练流程分为五个阶段:预测器预热、联合精调、EMA目标对齐、硬同步以及预测器终结,逐步增强模型对时序演化的建模能力。
关键创新:提出了针对JEPA的五阶段预训练课程,通过分阶段解耦与重组训练目标,成功克服了联合训练中的不稳定性,使单一骨干网络能够同时胜任生成式轨迹预测与判别式风险评估任务。
关键设计:采用EMA目标对齐以稳定目标表示,通过硬同步机制强制编码器与预测器在后期对齐,并利用潜在空间中的L1距离作为滚动漂移的度量指标,确保模型在长达48小时的预测窗口内保持收敛。
🖼️ 关键图片
📊 实验亮点
Clin-JEPA在MIMIC-IV数据集上表现优异:在48小时轨迹预测中,其潜在空间漂移收敛至-15.7%,而基线模型则出现显著发散(最高达+4951%)。在临床判别性方面,恶化患者在潜在空间中的位移是稳定患者的4.83倍。在多任务风险预测中,该模型在ICareFM EEP上达到0.851 AUROC,在8项二分类任务中平均AUROC提升0.041。
🎯 应用场景
该研究主要应用于重症监护(ICU)环境下的患者病情监测与风险预警。通过对患者电子健康记录的深度建模,Clin-JEPA可作为通用的临床决策支持系统骨干,在无需针对特定任务进行微调的情况下,实现对患者病情恶化、死亡风险等多种临床指标的实时预测,具有极高的临床应用价值。
📄 摘要(原文)
We present Clin-JEPA, a multi-phase co-training framework for joint-embedding predictive (JEPA) pretraining on EHR patient trajectories. JEPA architectures have enabled latent-space planning in robotics and high-quality representation learning in vision, but extending the paradigm to EHR data -- to obtain a single backbone that simultaneously forecasts patient trajectories and serves diverse downstream risk-prediction tasks without per-task fine-tuning -- remains an open challenge. Existing JEPA frameworks either discard the predictor after pretraining (I-JEPA, V-JEPA) or train it on a frozen pretrained encoder (V-JEPA 2-AC), leaving the encoder unaware of the rollout signal that the retained predictor must use at inference; co-training the encoder and predictor under a shared JEPA prediction objective would supply this grounding, but naïve co-training is unstable, with representation collapse and online/target drift causing autoregressive rollout to diverge. Clin-JEPA's five-phase pretraining curriculum -- predictor warmup, joint refinement, EMA target alignment, hard sync, and predictor finalization -- addresses each failure mode by phase, stably co-training a Qwen3-8B-based encoder and a 92M-parameter latent trajectory predictor. On MIMIC-IV ICU data, three independent evaluations support the framework: (1) latent $\ell_1$ rollout drift uniquely converges ($-$15.7%) over 48-hour horizons while baselines and ablations diverge (+3% to +4951%); (2) the encoder learns a clinically discriminative latent geometry (deteriorating-patient cohorts displace 4.83$\times$ further than stable patients in latent space, vs $\leq$2.62$\times$ for baseline encoders); (3) a single backbone outperforms strong tabular and sequence baselines on multi-task downstream evaluation. Clin-JEPA achieves mean AUROC 0.851 on ICareFM EEP and 0.883 on 8 binary risk tasks (+0.038 and +0.041 vs baseline average).