Self-Distillation Improves DNA Sequence Inference
作者: Tong Yu, Lei Cheng, Ruslan Khalitov, Erland Brandser Olsson, Zhirong Yang
分类: cs.LG
发布日期: 2024-05-14
🔗 代码/项目: GITHUB
💡 一句话要点
提出基于自蒸馏的DNA序列推断模型,提升下游任务预测精度
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)
关键词: DNA序列推断 自蒸馏学习 对比学习 基因组学 深度学习
📋 核心要点
- 现有基因组学自监督预训练方法侧重于单个序列的掩码语言建模,忽略了多个序列的统计信息。
- 论文提出一种包含学生和教师子网络的深度神经网络模型,通过自蒸馏学习序列的上下文信息和分布数据。
- 实验结果表明,该方法在20个下游推断任务中显著提高了预测性能,验证了其有效性。
📝 摘要(中文)
自监督预训练(SSP)已被认为是提高各种下游任务预测精度的一种方法。然而,它在DNA序列上的效果仍然受到一定的限制。这种限制主要源于基因组学中现有的大多数SSP方法侧重于单个序列的掩码语言建模,而忽略了编码多个序列的统计信息的关键方面。为了克服这一挑战,我们引入了一种创新的深度神经网络模型,该模型结合了“学生”和“教师”子网络之间的协同学习。在该模型中,学生子网络采用核苷酸的掩码学习,并通过指数移动平均方法逐步调整其参数以适应教师子网络。同时,两个子网络都参与对比学习,从输入序列的两个增强表示中获得见解。这种自蒸馏过程使我们的模型能够有效地吸收来自单个序列的上下文信息和序列群体的分布数据。我们使用人类参考基因组进行了初步预训练,然后将其应用于20个下游推断任务,验证了我们的方法。这些实验的经验结果表明,我们的新方法显著提高了大多数这些任务的推断性能。我们的代码可在https://github.com/wiedersehne/FinDNA获取。
🔬 方法详解
问题定义:现有基于自监督预训练的DNA序列推断方法,主要集中于对单个序列进行掩码语言建模,忽略了序列之间的统计信息,导致模型无法充分学习DNA序列的特征表示。这限制了模型在下游任务中的表现。
核心思路:论文的核心思路是利用自蒸馏学习框架,通过学生网络和教师网络之间的协同学习,同时学习单个序列的上下文信息和多个序列的分布信息。学生网络通过掩码学习进行训练,并逐步向教师网络学习,教师网络则通过指数移动平均更新参数,从而实现知识的传递和模型的提升。同时,对比学习被用于增强模型对序列不同表示的理解。
技术框架:整体框架包含两个主要的子网络:学生网络和教师网络。首先,输入DNA序列经过数据增强,生成两个不同的表示。然后,学生网络对其中一个表示进行掩码学习,并利用对比学习学习两个表示之间的关系。教师网络则利用指数移动平均更新参数,并为学生网络提供学习目标。最终,两个网络共同学习,提升模型性能。
关键创新:最重要的技术创新点在于将自蒸馏学习框架引入到DNA序列推断任务中,并结合掩码学习和对比学习,从而能够同时学习序列的上下文信息和分布信息。这种方法不同于以往只关注单个序列建模的方法,能够更全面地学习DNA序列的特征表示。
关键设计:关键设计包括:1) 学生网络和教师网络的网络结构选择,例如可以使用Transformer等结构;2) 掩码学习的比例和策略;3) 对比学习的损失函数选择,例如可以使用InfoNCE损失;4) 指数移动平均的衰减系数设置;5) 数据增强方法选择,例如可以使用随机裁剪、突变等方法。
📊 实验亮点
该方法在20个下游推断任务中进行了验证,实验结果表明,该方法显著提高了预测性能。具体来说,该方法在大多数任务上都取得了优于现有方法的性能,证明了其有效性。论文开源了代码,方便其他研究者复现和应用。
🎯 应用场景
该研究成果可应用于基因组学、生物信息学等领域,例如基因功能预测、疾病风险评估、药物靶点发现等。通过提升DNA序列推断的准确性,可以为相关研究提供更可靠的基础,加速生物医学研究的进展,并最终改善人类健康。
📄 摘要(原文)
Self-supervised pretraining (SSP) has been recognized as a method to enhance prediction accuracy in various downstream tasks. However, its efficacy for DNA sequences remains somewhat constrained. This limitation stems primarily from the fact that most existing SSP approaches in genomics focus on masked language modeling of individual sequences, neglecting the crucial aspect of encoding statistics across multiple sequences. To overcome this challenge, we introduce an innovative deep neural network model, which incorporates collaborative learning between a
student' and ateacher' subnetwork. In this model, the student subnetwork employs masked learning on nucleotides and progressively adapts its parameters to the teacher subnetwork through an exponential moving average approach. Concurrently, both subnetworks engage in contrastive learning, deriving insights from two augmented representations of the input sequences. This self-distillation process enables our model to effectively assimilate both contextual information from individual sequences and distributional data across the sequence population. We validated our approach with preliminary pretraining using the human reference genome, followed by applying it to 20 downstream inference tasks. The empirical results from these experiments demonstrate that our novel method significantly boosts inference performance across the majority of these tasks. Our code is available at https://github.com/wiedersehne/FinDNA.