$f$-Trajectory Balance: A Loss Family for Tuning GFlowNets, Generative Models, and LLMs with Off- and On-Policy Data

📄 arXiv: 2605.15417v1 📥 PDF

作者: Jake Fawkes, Jason Hartford

分类: cs.LG, cs.AI

发布日期: 2026-05-14

备注: Published at ICML 2026


💡 一句话要点

提出f-Trajectory Balance损失族,用于优化GFlowNets、生成模型和LLM

🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 生成模型 GFlowNets f-散度 损失函数 Off-policy学习

📋 核心要点

  1. 现有生成模型训练方法在off-policy数据上表现不佳,难以兼顾模式覆盖和训练稳定性。
  2. 论文核心思想是构建一个与f-散度相关的损失函数族,在on-policy时等价于f-散度,off-policy时保持有效性。
  3. 实验表明,该损失函数族在多种生成模型任务中,均能有效提升模型性能,并保持on-policy和off-policy的一致性。

📝 摘要(中文)

本文提出了一种新的损失函数族,称为$f$-Trajectory Balance,用于训练生成模型,包括GFlowNets、变分推断模型和大型语言模型(LLM)。该损失族基于目标和模型对数概率之间的均方误差,并将其扩展到整个$f$-散度家族。在on-policy评估时,该损失的梯度对应于相应的$f$-散度,而在off-policy评估时,它仍然是一个有效的损失,具有相同的全局最小值。这种等价性允许设计新的替代损失函数,用于调整各种生成模型,这些模型继承了相应$f$-散度的属性,例如更好的模式覆盖,同时适用于off-policy数据。该损失在一系列任务中进行了测试,包括合成示例、用于分子发现的SynFlowNets和异步LLM调整,结果表明模型在on-policy和off-policy数据上都保持了其预测的属性。

🔬 方法详解

问题定义:现有生成模型,如GFlowNets和LLM,在训练时面临着on-policy数据利用率低和off-policy数据训练不稳定的问题。传统的KL散度损失在off-policy情况下可能失效,导致模型训练崩溃或模式坍塌。因此,需要一种既能利用off-policy数据,又能保证训练稳定性和模式覆盖的损失函数。

核心思路:论文的核心思路是将目标和模型对数概率之间的均方误差推广到$f$-散度家族。通过构建一个与特定$f$-散度相关的损失函数,使得在on-policy情况下,损失函数的梯度等价于该$f$-散度的梯度。而在off-policy情况下,该损失函数仍然有效,并具有相同的全局最小值。这样,就可以利用off-policy数据进行训练,同时继承$f$-散度的良好性质,如模式覆盖。

技术框架:该方法的核心是构建一个$f$-Trajectory Balance损失函数族。对于每个$f$-散度,都存在一个对应的损失函数,其形式为目标和模型对数概率的函数。在训练过程中,可以使用on-policy或off-policy数据来计算损失函数的梯度,并更新模型参数。整体流程包括:1) 选择一个合适的$f$-散度;2) 构建对应的$f$-Trajectory Balance损失函数;3) 使用on-policy或off-policy数据计算损失函数的梯度;4) 更新模型参数。

关键创新:该方法最重要的创新点在于将目标和模型对数概率之间的均方误差推广到整个$f$-散度家族,从而构建了一个新的损失函数族。与传统的KL散度损失相比,该损失函数族具有更好的模式覆盖能力,并且适用于off-policy数据。此外,该方法还揭示了translation invariant loss functions on the target and model log probabilities, and $f$-divergences之间的一一对应关系。

关键设计:关键设计在于如何构建与特定$f$-散度对应的$f$-Trajectory Balance损失函数。论文中给出了一个通用的构建方法,即通过选择一个合适的函数$f$,并将其代入一个特定的公式中,即可得到对应的损失函数。具体的参数设置取决于所选择的$f$-散度。例如,当选择KL散度时,对应的损失函数就是传统的KL散度损失。网络结构的选择取决于具体的生成模型。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,该方法在合成数据集、SynFlowNets分子发现和LLM异步调优等任务上均取得了显著的性能提升。例如,在分子发现任务中,使用该方法训练的SynFlowNets能够生成更多样化的分子结构,并具有更高的目标属性。在LLM异步调优任务中,该方法能够有效地利用离线数据,加速模型的收敛速度,并提升模型的生成质量。

🎯 应用场景

该研究成果可广泛应用于生成模型的训练和优化,包括GFlowNets、变分推断模型和大型语言模型。尤其在需要利用大量离线数据或进行异步训练的场景下,该方法具有重要价值。例如,可以用于分子发现、文本生成、图像生成等领域,提升生成模型的性能和泛化能力。

📄 摘要(原文)

In GFlowNets and variational inference, it has been shown that the mean square error between target and model log probabilities is an effective, low variance, surrogate loss for training generative models. This loss has the property that when evaluated \emph{on-policy} its gradients correspond to those of the KL divergence, while \emph{off-policy} it remains a valid loss with the same global minimizer. In this work, we demonstrate that this construction can be extended to the whole family of $f$-divergences, leading to a family of losses whose on-policy gradients are that of the corresponding $f$-divergence, but retain the same global minimizer off-policy. Specifically, we show that the on-policy gradients lead to a one to one correspondence between translation invariant loss functions on the target and model log probabilities, and $f$-divergences. This equivalence allows us to design new surrogate loss functions for tuning a wide class of generative models that inherit the properties of the corresponding $f$-divergence, such as being more mode covering, whilst being applicable to off-policy data. We apply our losses on a range of tasks, including classic synthetic examples, SynFlowNets for molecule discovery, and asynchronous large language model (LLM) tuning, demonstrating that our models retain their predicted properties on- and off-policy in a wide class of generative models.