Compact Bayesian Neural Networks via pruned MCMC sampling

📄 arXiv: 2501.06962v1 📥 PDF

作者: Ratneel Deo, Scott Sisson, Jody M. Webster, Rohitash Chandra

分类: cs.LG, cs.AI

发布日期: 2025-01-12

备注: 22 pages, 11 figures


💡 一句话要点

提出基于剪枝MCMC采样的紧凑贝叶斯神经网络,提升模型泛化能力与可移植性。

🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 贝叶斯神经网络 MCMC采样 网络剪枝 模型压缩 不确定性量化

📋 核心要点

  1. 贝叶斯神经网络训练计算成本高昂,且参数量随网络深度和数据复杂度指数增长,冗余参数影响模型泛化能力。
  2. 论文提出一种基于剪枝的MCMC采样方法,通过对模型参数后验分布采样并剪枝低重要性权重,得到紧凑模型。
  3. 实验结果表明,该方法在保证模型泛化性能的同时,可有效减少网络规模,最高可达75%以上。

📝 摘要(中文)

贝叶斯神经网络(BNNs)在模型预测中提供了稳健的不确定性量化,但训练它们面临着巨大的计算挑战。这主要是由于使用马尔可夫链蒙特卡洛(MCMC)采样和变分推理算法对多模态后验分布进行采样的问题。此外,模型参数的数量随着隐藏层、神经元和数据集中特征的增加而呈指数级增长。通常,这些密集连接的参数中有很大一部分是冗余的,剪枝神经网络不仅提高了可移植性,而且还具有更好的泛化能力。在本研究中,我们通过利用MCMC采样与网络剪枝来获得紧凑的概率模型,从而解决了一些挑战,该模型已经删除了冗余参数。我们对模型参数(权重和偏差)的后验分布进行采样,并剪枝重要性低的权重,从而得到一个紧凑的模型。我们通过调整剪枝后的重采样,确保紧凑的BNN保留其通过后验分布估计不确定性的能力,同时保持模型训练和泛化性能的准确性。我们通过经验结果分析,在选定的回归和分类问题的基准数据集上评估了我们的MCMC剪枝策略的有效性。我们还考虑了两个珊瑚礁钻芯岩性分类数据集,以测试剪枝模型在复杂真实世界数据集中的鲁棒性。我们进一步研究了改进紧凑BNN是否可以保留任何性能损失。我们的结果表明,使用MCMC训练和剪枝BNN是可行的,同时保持了泛化性能,网络规模减少了75%以上。这为开发紧凑的BNN模型铺平了道路,这些模型为实际应用提供不确定性估计。

🔬 方法详解

问题定义:贝叶斯神经网络(BNNs)能够提供不确定性量化,但在实际应用中面临着计算量大和模型参数冗余的问题。现有的MCMC采样方法难以处理高维参数空间,且训练得到的BNN模型体积庞大,不利于部署和应用。因此,需要一种方法能够在降低模型复杂度的同时,保持BNN的不确定性估计能力和泛化性能。

核心思路:论文的核心思路是通过剪枝来降低BNN的复杂度,并结合MCMC采样来保证模型的不确定性估计能力。具体来说,首先使用MCMC采样得到模型参数的后验分布,然后根据权重的重要性对网络进行剪枝,移除冗余的连接。为了弥补剪枝可能带来的性能损失,论文还采用了后剪枝重采样策略,进一步优化模型参数。

技术框架:整体流程包括以下几个主要阶段: 1. MCMC采样:使用MCMC算法对BNN的权重和偏差进行采样,得到模型参数的后验分布。 2. 权重重要性评估:根据某种指标(例如权重的大小或梯度)评估每个权重的重要性。 3. 网络剪枝:根据权重的重要性,移除一部分不重要的连接,得到一个更紧凑的网络。 4. 后剪枝重采样:对剪枝后的网络进行重采样,进一步优化模型参数,弥补剪枝可能带来的性能损失。

关键创新:该方法的主要创新在于将MCMC采样和网络剪枝相结合,能够在降低模型复杂度的同时,保持BNN的不确定性估计能力。传统的剪枝方法通常是确定性的,无法提供不确定性估计。而该方法通过MCMC采样,能够对剪枝后的模型参数进行不确定性量化。此外,后剪枝重采样策略也是一个重要的创新点,能够有效弥补剪枝可能带来的性能损失。

关键设计: 1. MCMC采样算法:论文中使用的MCMC采样算法的具体选择未知,但需要选择一种能够有效处理高维参数空间的算法。 2. 权重重要性评估指标:论文中使用的权重重要性评估指标未知,但需要选择一种能够准确反映权重对模型性能影响的指标。 3. 剪枝比例:剪枝比例是一个重要的超参数,需要根据具体数据集进行调整,以在模型复杂度和性能之间取得平衡。 4. 后剪枝重采样策略:后剪枝重采样策略的具体实现方式未知,但需要设计一种能够有效优化模型参数的策略。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,该方法在多个基准数据集上取得了良好的效果。在保证模型泛化性能的前提下,网络规模最多可减少75%以上。此外,该方法在珊瑚礁钻芯岩性分类等复杂真实世界数据集上表现出良好的鲁棒性,验证了其在实际应用中的潜力。通过对紧凑BNN进行改进,可以保留任何性能损失。

🎯 应用场景

该研究成果可应用于对模型大小和计算资源有严格限制的场景,例如移动设备、嵌入式系统和边缘计算。通过减小模型体积,可以降低存储和计算成本,提高推理速度。此外,该方法提供的模型不确定性估计对于风险敏感型应用(如医疗诊断、自动驾驶)至关重要,可以帮助决策者更好地理解模型的预测结果并做出更明智的决策。

📄 摘要(原文)

Bayesian Neural Networks (BNNs) offer robust uncertainty quantification in model predictions, but training them presents a significant computational challenge. This is mainly due to the problem of sampling multimodal posterior distributions using Markov Chain Monte Carlo (MCMC) sampling and variational inference algorithms. Moreover, the number of model parameters scales exponentially with additional hidden layers, neurons, and features in the dataset. Typically, a significant portion of these densely connected parameters are redundant and pruning a neural network not only improves portability but also has the potential for better generalisation capabilities. In this study, we address some of the challenges by leveraging MCMC sampling with network pruning to obtain compact probabilistic models having removed redundant parameters. We sample the posterior distribution of model parameters (weights and biases) and prune weights with low importance, resulting in a compact model. We ensure that the compact BNN retains its ability to estimate uncertainty via the posterior distribution while retaining the model training and generalisation performance accuracy by adapting post-pruning resampling. We evaluate the effectiveness of our MCMC pruning strategy on selected benchmark datasets for regression and classification problems through empirical result analysis. We also consider two coral reef drill-core lithology classification datasets to test the robustness of the pruning model in complex real-world datasets. We further investigate if refining compact BNN can retain any loss of performance. Our results demonstrate the feasibility of training and pruning BNNs using MCMC whilst retaining generalisation performance with over 75% reduction in network size. This paves the way for developing compact BNN models that provide uncertainty estimates for real-world applications.