Pocket Foundation Models: Distilling TFMs into CPU-Ready Gradient-Boosted Trees
作者: Aditya Tanna, Nassim Bouarour, Mohamed Bouadi, Vinay kumar Sankarapu, Pratinav Seth
分类: cs.LG, cs.AI
发布日期: 2026-05-18
💡 一句话要点
提出一种知识蒸馏方法,将表格领域预训练模型压缩为CPU可用的梯度提升树,实现推理加速。
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 知识蒸馏 表格数据 预训练模型 梯度提升树 模型压缩
📋 核心要点
- 表格数据预训练模型推理速度慢,难以满足低延迟应用需求,例如欺诈检测。
- 采用知识蒸馏方法,将表格预训练模型(教师模型)的知识迁移到XGBoost或CatBoost(学生模型),使其能在CPU上快速推理。
- 实验表明,该方法在多个数据集上实现了显著的推理加速,同时保持了较高的模型性能,优于直接训练的CatBoost模型。
📝 摘要(中文)
欺诈评分器需要在2毫秒内给出答案,而最好的表格领域预训练模型(TFMs)在GPU上需要151-1275毫秒。本文通过离线将TFM提炼成在CPU上原生运行的XGBoost或CatBoost学生模型来弥合这一差距。核心障碍在于上下文学习(ICL)教师模型:它们在对自己训练集进行评分时会泄露标签,导致软目标坍缩为接近one-hot向量,没有留下可供提炼的类间结构。分层留一法(OOF)教师标签可以防止这种情况。在来自TALENT、OpenML-CC18、TabZilla和TabArena的153个分类数据集上,将TabICLv2提炼到XGBoost中,在CPU上以1.9毫秒的速度实现了0.882的宏平均AUC(教师AUC的96.5%),在师生对之间实现了38倍至860倍的加速,并且在经过调整的CatBoost基线上具有统计学上的显著优势(Wilcoxon p = 0.0008;51%的胜率)。四个进一步的发现:教师排名完全转移到学生排名;收益集中在低维数据上(<21个特征:比CatBoost高+0.011,而>21个特征:+0.001);多教师平均有助于MLP学生(+0.006,p = 0.003),但对树学生增加不到0.001;在高维任务中,如果教师本身落后于CatBoost,蒸馏会使情况变得更糟。完整的pipeline已作为TabTune库的一部分开源。
🔬 方法详解
问题定义:表格领域预训练模型(TFMs)在GPU上的推理速度较慢,无法满足对延迟有严格要求的实际应用,例如欺诈检测需要在毫秒级别完成推理。现有方法难以在保证性能的同时,实现模型在CPU上的快速部署。
核心思路:利用知识蒸馏技术,将复杂的TFMs的知识迁移到更轻量级的梯度提升树模型(如XGBoost或CatBoost)。通过训练学生模型来模仿教师模型的输出,从而在CPU上实现快速且准确的推理。特别地,针对上下文学习(ICL)教师模型,采用分层留一法(OOF)来避免标签泄露问题。
技术框架:该方法包含两个主要阶段:1) 教师模型(TFM)的训练或选择;2) 学生模型(XGBoost或CatBoost)的训练,使用教师模型的预测作为软标签。关键在于使用分层OOF来生成教师模型的预测,以避免标签泄露。最终,训练好的学生模型部署在CPU上进行推理。
关键创新:针对ICL教师模型,提出了分层OOF标签方法,有效防止了教师模型在训练集上进行预测时产生的标签泄露问题,保证了知识蒸馏的有效性。这使得学生模型能够学习到教师模型更丰富的类间结构信息。
关键设计:分层OOF的具体实现是,将训练数据分成多个fold,每个fold轮流作为验证集,其余fold作为训练集。教师模型在训练集上训练,并在验证集上进行预测。最终,每个样本都会得到一个由教师模型产生的预测值,作为学生模型的训练标签。学生模型使用这些软标签进行训练,损失函数通常采用交叉熵损失或均方误差损失。
🖼️ 关键图片
📊 实验亮点
实验结果表明,将TabICLv2蒸馏到XGBoost后,在CPU上以1.9毫秒的速度实现了0.882的宏平均AUC,达到了教师模型AUC的96.5%,同时实现了38倍至860倍的推理加速。与直接训练的CatBoost模型相比,该方法在统计上具有显著优势(Wilcoxon p = 0.0008),胜率为51%。
🎯 应用场景
该研究成果可广泛应用于对推理速度有较高要求的表格数据分类任务,例如金融风控、医疗诊断、推荐系统等。通过将复杂的预训练模型蒸馏成轻量级模型,可以在资源受限的设备上实现高性能的预测,降低部署成本,并加速模型迭代。
📄 摘要(原文)
A fraud scorer needs to answer in under 2 ms. The best tabular foundation models (TFMs) take 151-1,275 ms on GPU. We close this gap by distilling the TFM offline into an XGBoost or CatBoost student that runs natively on CPU. The central obstacle is specific to in-context learning (ICL) teachers: they leak labels when scoring their own training set, so the soft targets collapse to near-one-hot vectors with no inter-class structure left to distill. Stratified out-of-fold (OOF) teacher labeling prevents this. Across 153 classification datasets drawn from TALENT, OpenML-CC18, TabZilla, and TabArena, distilling TabICLv2 into XGBoost gives 0.882 macro-mean AUC (96.5% of teacher AUC) at 1.9 ms on CPU, a 38x to 860x speedup across teacher-student pairs with a statistically significant edge over a tuned CatBoost baseline (Wilcoxon p = 0.0008; 51% win rate). Four further findings: teacher rank transfers exactly to student rank; gains concentrate on low-dimensional data (< 21 features: +0.011 over CatBoost vs. >21 features: +0.001); multi-teacher averaging helps MLP students (+0.006, p = 0.003) but adds less than 0.001 for tree students; and on high-dimensional tasks where the teacher itself trails CatBoost, distillation makes things worse rather than better. The full pipeline is open-sourced as part of the TabTune library.