Revisiting Scalable Hessian Diagonal Approximations for Applications in Reinforcement Learning
作者: Mohamed Elsayed, Homayoon Farrahi, Felix Dangel, A. Rupam Mahmood
分类: cs.LG, cs.AI
发布日期: 2024-06-05 (更新: 2024-07-03)
备注: Published in the Proceedings of the 41st International Conference on Machine Learning (ICML 2024). Code is available at https://github.com/mohmdelsayed/HesScale. arXiv admin note: substantial text overlap with arXiv:2210.11639
💡 一句话要点
HesScale:一种高效可扩展的Hessian对角近似方法,提升强化学习性能
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)
关键词: Hessian近似 二阶优化 强化学习 步长缩放 计算效率
📋 核心要点
- 现有Hessian对角近似方法计算成本高昂,限制了其在实际问题中的应用。
- 论文提出HesScale,一种基于BL89的改进方法,计算成本与梯度相当,且性能更优。
- 实验表明,HesScale在强化学习任务中,优化速度更快,稳定性更高,具有良好的应用前景。
📝 摘要(中文)
二阶信息在许多应用中都很有价值,但计算起来极具挑战性。一些工作致力于计算或近似Hessian对角线,但即使这种简化也比计算梯度引入了显著的额外成本。在缺乏高效的Hessian对角线精确计算方案的情况下,我们重新审视了Becker和LeCun(1989, BL89)提出的早期近似方案,该方案的成本与梯度相似,但似乎被社区忽略了。我们引入了HesScale,它是BL89的改进版本,增加了可忽略不计的额外计算量。在小型网络上,我们发现这种改进比所有替代方案都具有更高的质量,即使是那些具有理论保证(如无偏性)的方案,而且计算成本要低得多。我们在使用小型网络的强化学习问题中使用了这一发现,并在二阶优化和步长参数缩放中展示了HesScale。在我们的实验中,HesScale比现有方法优化得更快,并通过步长缩放提高了稳定性。这些发现对于未来在更大的模型中扩展二阶方法是有希望的。
🔬 方法详解
问题定义:论文旨在解决计算Hessian矩阵对角线元素的高昂计算成本问题。现有方法,即使是近似方法,也往往引入显著的额外计算负担,限制了二阶优化方法在实际问题,特别是强化学习中的应用。BL89方法虽然计算高效,但精度有待提高。
核心思路:论文的核心思路是改进BL89方法,使其在计算成本几乎不变的情况下,显著提高Hessian对角线近似的精度。通过引入简单的缩放因子,HesScale能够更准确地估计Hessian对角线,从而改善二阶优化算法的性能。
技术框架:HesScale方法建立在BL89方法的基础上。BL89方法通过计算梯度与输入向量的乘积来近似Hessian对角线。HesScale在BL89的基础上,增加了一个缩放步骤,该步骤的计算量可以忽略不计。整体流程包括:1)计算梯度;2)应用BL89方法进行初步近似;3)应用HesScale缩放因子进行校正。
关键创新:HesScale的关键创新在于引入了高效的缩放因子,该因子能够显著提高Hessian对角线近似的精度,而几乎不增加计算成本。与现有方法相比,HesScale在精度和效率之间取得了更好的平衡。
关键设计:HesScale的关键设计在于缩放因子的选择。论文中具体缩放因子的计算公式未知,但强调了其计算复杂度极低,可以忽略不计。此外,论文在强化学习任务中,将HesScale应用于二阶优化算法和步长参数缩放,以提高训练速度和稳定性。
🖼️ 关键图片
📊 实验亮点
实验结果表明,HesScale在小型网络上优于其他Hessian对角近似方法,即使是那些具有理论保证的方法。在强化学习任务中,HesScale优化速度快于现有方法,并通过步长缩放提高了稳定性。具体性能提升数据未知,但论文强调了HesScale在实际应用中的优势。
🎯 应用场景
该研究成果可广泛应用于需要二阶优化方法的机器学习领域,尤其是在计算资源受限或模型规模较大的情况下。例如,可以应用于深度强化学习、自然语言处理和计算机视觉等领域,加速模型训练,提高模型性能。HesScale的低计算成本使其有望成为大规模模型训练的实用工具。
📄 摘要(原文)
Second-order information is valuable for many applications but challenging to compute. Several works focus on computing or approximating Hessian diagonals, but even this simplification introduces significant additional costs compared to computing a gradient. In the absence of efficient exact computation schemes for Hessian diagonals, we revisit an early approximation scheme proposed by Becker and LeCun (1989, BL89), which has a cost similar to gradients and appears to have been overlooked by the community. We introduce HesScale, an improvement over BL89, which adds negligible extra computation. On small networks, we find that this improvement is of higher quality than all alternatives, even those with theoretical guarantees, such as unbiasedness, while being much cheaper to compute. We use this insight in reinforcement learning problems where small networks are used and demonstrate HesScale in second-order optimization and scaling the step-size parameter. In our experiments, HesScale optimizes faster than existing methods and improves stability through step-size scaling. These findings are promising for scaling second-order methods in larger models in the future.