Optimized Local Updates in Federated Learning via Reinforcement Learning
作者: Ali Murad, Bo Hui, Wei-Shinn Ku
分类: cs.LG
发布日期: 2025-05-31
备注: This paper is accepted at IEEE IJCNN 2025
🔗 代码/项目: GITHUB
💡 一句话要点
提出基于强化学习的联邦学习局部更新优化方法,提升非独立同分布数据下的模型性能。
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)
关键词: 联邦学习 强化学习 非独立同分布数据 局部更新优化 分布式机器学习
📋 核心要点
- 联邦学习在非独立同分布数据下,模型聚合导致性能下降,过度本地训练并不能提升整体性能。
- 利用深度强化学习智能体,根据训练损失变化优化客户端本地训练数据量,避免过度共享信息。
- 实验证明,该方法在多个基准数据集和FL框架上,能够提升联邦学习客户端的性能。
📝 摘要(中文)
联邦学习(FL)是一种分布式框架,用于在大规模分布式数据上进行协作模型训练,在保护客户端数据隐私的同时实现更高的性能。然而,集中式服务器上的模型聚合在不同客户端存在非独立同分布(non-IID)数据时可能导致性能下降。我们注意到,在客户端本地训练过多数据并不利于所有客户端的整体性能。本文设计了一种新颖的框架,利用深度强化学习(DRL)智能体选择训练客户端模型所需的优化数据量,而无需过度共享信息给服务器。DRL智能体在不了解客户端性能的情况下,利用训练损失的变化作为奖励信号,学习优化改进客户端性能所需的训练数据量。具体而言,在每个聚合轮次之后,DRL算法将本地性能视为当前状态,并输出每个类别的优化权重,用于下一轮本地训练。通过这样做,智能体学习一种策略,在FL轮次期间创建本地训练数据集的优化分区。在FL之后,客户端利用整个本地训练数据集进一步增强其自身数据分布的性能,从而减轻聚合的非独立同分布影响。通过大量实验,我们证明了通过我们的算法训练FL客户端可以在多个基准数据集和FL框架上获得卓越的性能。我们的代码可在https://github.com/amuraddd/optimized_client_training.git 获取。
🔬 方法详解
问题定义:联邦学习在客户端数据非独立同分布的情况下,模型聚合容易导致全局模型性能下降。现有方法通常采用固定的本地训练策略,没有考虑到不同客户端数据分布的差异性,可能导致某些客户端过度训练,反而影响整体性能。因此,需要一种自适应的本地训练策略,能够根据客户端的数据特点和模型状态,动态调整训练数据量。
核心思路:论文的核心思路是利用深度强化学习(DRL)来优化每个客户端的本地训练数据量。通过将客户端的本地性能作为状态,训练损失的变化作为奖励信号,DRL智能体可以学习一种策略,选择最适合该客户端的训练数据子集。这样可以避免过度训练,并提高模型在非独立同分布数据下的泛化能力。
技术框架:整体框架包括联邦学习服务器和多个客户端。每个客户端都包含一个本地模型和一个DRL智能体。 1. 联邦学习服务器:负责模型聚合和分发。 2. 客户端: - 本地模型:参与联邦学习的模型。 - DRL智能体:根据本地模型性能,选择优化的训练数据子集。 - 训练过程:在每一轮联邦学习中,客户端首先使用DRL智能体选择训练数据子集,然后使用该子集训练本地模型,并将模型更新上传到服务器。服务器聚合所有客户端的模型更新,并将更新后的全局模型分发给客户端。
关键创新:该论文的关键创新在于将强化学习引入联邦学习的本地训练过程,实现了一种自适应的本地训练策略。与传统的固定训练策略相比,该方法能够根据客户端的数据特点和模型状态,动态调整训练数据量,从而提高模型在非独立同分布数据下的泛化能力。
关键设计: - 状态:客户端的本地模型性能,例如训练损失。 - 动作:每个类别的训练数据权重,用于选择训练数据子集。 - 奖励:训练损失的变化,反映了模型性能的提升。 - DRL算法:可以使用任何合适的DRL算法,例如DQN或Policy Gradient。 - 本地训练后处理:在联邦学习完成后,客户端使用整个本地数据集进行进一步训练,以缓解非独立同分布数据的影响。
🖼️ 关键图片
📊 实验亮点
实验结果表明,该方法在多个基准数据集(如MNIST、CIFAR-10)和FL框架上,均取得了优于传统联邦学习方法的性能。具体而言,该方法能够显著提高模型在非独立同分布数据下的准确率,并且在某些情况下,可以达到与集中式训练相近的性能。
🎯 应用场景
该研究成果可应用于各种联邦学习场景,尤其是在客户端数据非独立同分布的情况下,例如医疗健康、金融风控、智能推荐等领域。通过优化本地训练数据量,可以提高模型在异构数据上的泛化能力,提升联邦学习的实际应用价值,并促进更安全、高效的分布式机器学习。
📄 摘要(原文)
Federated Learning (FL) is a distributed framework for collaborative model training over large-scale distributed data, enabling higher performance while maintaining client data privacy. However, the nature of model aggregation at the centralized server can result in a performance drop in the presence of non-IID data across different clients. We remark that training a client locally on more data than necessary does not benefit the overall performance of all clients. In this paper, we devise a novel framework that leverages a Deep Reinforcement Learning (DRL) agent to select an optimized amount of data necessary to train a client model without oversharing information with the server. Starting without awareness of the client's performance, the DRL agent utilizes the change in training loss as a reward signal and learns to optimize the amount of training data necessary for improving the client's performance. Specifically, after each aggregation round, the DRL algorithm considers the local performance as the current state and outputs the optimized weights for each class, in the training data, to be used during the next round of local training. In doing so, the agent learns a policy that creates an optimized partition of the local training dataset during the FL rounds. After FL, the client utilizes the entire local training dataset to further enhance its performance on its own data distribution, mitigating the non-IID effects of aggregation. Through extensive experiments, we demonstrate that training FL clients through our algorithm results in superior performance on multiple benchmark datasets and FL frameworks. Our code is available at https://github.com/amuraddd/optimized_client_training.git.