Continual Learning Using a Kernel-Based Method Over Foundation Models
作者: Saleh Momeni, Sahisnu Mazumder, Bing Liu
分类: cs.LG, cs.AI, cs.CL, cs.CV
发布日期: 2024-12-20
🔗 代码/项目: GITHUB
💡 一句话要点
提出基于核方法的KLDA持续学习算法,有效应对灾难性遗忘和类间分离问题。
🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 持续学习 类别增量学习 灾难性遗忘 核方法 线性判别分析
📋 核心要点
- 类别增量学习中的灾难性遗忘和类间分离问题是持续学习的关键挑战,现有方法难以有效解决。
- KLDA算法利用预训练模型特征,通过RBF核和随机傅里叶特征增强表示,避免灾难性遗忘和类间分离。
- 实验表明,KLDA在文本和图像分类任务上显著优于现有基线,性能接近联合训练的上限。
📝 摘要(中文)
本文研究了具有挑战性的类别增量学习(CIL)持续学习(CL)场景。CIL面临两个关键挑战:灾难性遗忘(CF)和类间类别分离(ICS)。尽管提出了许多方法,但这些问题仍然存在。本文提出了一种新的CIL方法,称为核线性判别分析(KLDA),可以有效避免CF和ICS问题。它仅利用基础模型(FM)中学习到的强大特征。然而,直接使用这些特征并非最优。为了解决这个问题,KLDA结合了径向基函数(RBF)核及其随机傅里叶特征(RFF),以增强来自FM的特征表示,从而提高性能。当新任务到达时,KLDA仅计算任务中每个类别的均值,并基于核化特征更新所有已学习类别的共享协方差矩阵。分类使用线性判别分析执行。使用文本和图像分类数据集的经验评估表明,KLDA明显优于基线。值得注意的是,在不依赖重放数据的情况下,KLDA实现了与所有类别的联合训练相当的准确率,这被认为是CIL性能的上限。KLDA代码可在https://github.com/salehmomeni/klda 获得。
🔬 方法详解
问题定义:论文旨在解决类别增量学习(CIL)中的灾难性遗忘(CF)和类间类别分离(ICS)问题。现有方法在处理CIL时,往往难以在学习新任务的同时保持对旧任务的性能,或者无法有效区分不同任务的类别,导致性能下降。
核心思路:论文的核心思路是利用预训练的Foundation Model (FM) 提取的强大特征,并通过核方法(Kernel Method)增强这些特征的表达能力,从而更好地进行类别区分,并减少灾难性遗忘。通过使用核函数,可以将原始特征映射到更高维度的空间,从而更容易找到线性可分的决策边界。
技术框架:KLDA算法的整体流程如下:1. 利用预训练的FM提取特征。2. 使用RBF核和随机傅里叶特征(RFF)对提取的特征进行核化处理,增强特征表示。3. 对于每个新任务,计算每个类别的核化特征的均值。4. 更新所有已学习类别的共享协方差矩阵。5. 使用线性判别分析(LDA)进行分类。
关键创新:该方法最重要的创新点在于将核方法(特别是RBF核和RFF)引入到基于预训练模型的持续学习中。与直接使用预训练模型的特征相比,核方法能够更好地捕捉数据中的非线性关系,从而提高分类性能,并减少灾难性遗忘。此外,该方法仅需计算类别均值和更新共享协方差矩阵,计算效率较高。
关键设计:RBF核的选择是关键设计之一,其参数(例如gamma值)需要根据具体数据集进行调整。随机傅里叶特征(RFF)的使用是为了近似RBF核,从而降低计算复杂度。共享协方差矩阵的设计是为了利用所有已学习类别的信息,从而提高分类的鲁棒性。线性判别分析(LDA)作为分类器,其目标是找到最佳的线性判别方向,使得类内方差最小化,类间方差最大化。
🖼️ 关键图片
📊 实验亮点
实验结果表明,KLDA算法在文本和图像分类数据集上显著优于现有基线方法。在不使用任何重放数据的情况下,KLDA的性能可以与联合训练所有类别的方法相媲美,接近CIL性能的上限。这表明KLDA算法能够有效地避免灾难性遗忘,并保持较高的分类准确率。
🎯 应用场景
该研究成果可应用于需要持续学习新类别或新任务的场景,例如智能监控系统、自动驾驶、医疗诊断等。在这些场景中,模型需要不断适应新的数据分布和类别,而KLDA算法能够有效地解决灾难性遗忘问题,保证模型的性能。
📄 摘要(原文)
Continual learning (CL) learns a sequence of tasks incrementally. This paper studies the challenging CL setting of class-incremental learning (CIL). CIL has two key challenges: catastrophic forgetting (CF) and inter-task class separation (ICS). Despite numerous proposed methods, these issues remain persistent obstacles. This paper proposes a novel CIL method, called Kernel Linear Discriminant Analysis (KLDA), that can effectively avoid CF and ICS problems. It leverages only the powerful features learned in a foundation model (FM). However, directly using these features proves suboptimal. To address this, KLDA incorporates the Radial Basis Function (RBF) kernel and its Random Fourier Features (RFF) to enhance the feature representations from the FM, leading to improved performance. When a new task arrives, KLDA computes only the mean for each class in the task and updates a shared covariance matrix for all learned classes based on the kernelized features. Classification is performed using Linear Discriminant Analysis. Our empirical evaluation using text and image classification datasets demonstrates that KLDA significantly outperforms baselines. Remarkably, without relying on replay data, KLDA achieves accuracy comparable to joint training of all classes, which is considered the upper bound for CIL performance. The KLDA code is available at https://github.com/salehmomeni/klda.