PDFA Distillation via String Probability Queries

📄 arXiv: 2406.18328v2 📥 PDF

作者: Robert Baumgartner, Sicco Verwer

分类: cs.FL, cs.LG

发布日期: 2024-06-26 (更新: 2024-06-28)

备注: LearnAUT 2024


💡 一句话要点

提出基于字符串概率查询的PDFA蒸馏算法,用于从神经网络中提取可解释模型。

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)

关键词: 概率确定有限自动机 PDFA 知识蒸馏 可解释机器学习 字符串概率查询

📋 核心要点

  1. 神经网络语言模型缺乏可解释性,难以理解其决策过程,限制了其在安全敏感领域的应用。
  2. 提出一种基于字符串概率查询的PDFA蒸馏算法,从神经网络中提取紧凑且可解释的PDFA模型。
  3. 实验结果表明,该算法能够有效地从神经网络中蒸馏PDFA,并在公开数据集上取得了良好的性能。

📝 摘要(中文)

概率确定有限自动机(PDFA)是一种离散事件系统,用于建模语言上的条件概率:给定一个已观察到的token序列,它们返回下一个可能出现的token的概率。这类模型在可解释机器学习领域越来越受到关注,被用作训练为语言模型的神经网络的替代模型。本文提出了一种从神经网络中蒸馏PDFA的算法。该算法是L#算法的变体,能够从一种新型查询中学习PDFA,该查询通过查询字符串出现的概率来推断条件概率。我们在一个最新的公开数据集上,通过从一组训练好的神经网络中蒸馏PDFA,展示了该算法的有效性。

🔬 方法详解

问题定义:论文旨在解决从复杂的神经网络语言模型中提取可解释的概率确定有限自动机(PDFA)的问题。现有方法,如直接训练PDFA,可能难以达到与大型神经网络相当的性能。而直接使用神经网络进行解释又缺乏透明度,难以理解其内部逻辑。因此,需要一种方法能够将神经网络的知识转移到可解释的PDFA模型中,同时保持较高的预测精度。

核心思路:论文的核心思路是通过查询神经网络,获取字符串的概率信息,然后利用这些概率信息来学习PDFA。这种方法类似于知识蒸馏,将神经网络视为“教师模型”,PDFA视为“学生模型”。通过设计合适的查询策略,可以有效地从神经网络中提取有用的信息,并用于构建PDFA。

技术框架:该算法是L#算法的变体,主要包含以下几个阶段: 1. 初始化:初始化一个包含起始状态的PDFA。 2. 查询:使用字符串概率查询,向神经网络询问特定字符串的概率。 3. 学习:根据查询结果,更新PDFA的状态和转移概率。 4. 迭代:重复查询和学习过程,直到PDFA的性能达到预定的标准。

关键创新:该论文的关键创新在于提出了一种新型的查询方式,即字符串概率查询。传统的L#算法通常使用成员查询和等价查询,而该论文使用字符串概率查询,直接获取字符串出现的概率。这种查询方式更适合从神经网络中提取信息,因为神经网络可以直接输出字符串的概率。

关键设计:算法的关键设计包括: 1. 查询策略:如何选择要查询的字符串,以最大程度地提高学习效率。 2. 概率计算:如何根据查询结果,计算PDFA的状态转移概率。 3. 停止准则:如何判断PDFA的性能是否已经足够好,可以停止学习。

📊 实验亮点

该论文通过实验验证了所提出算法的有效性。实验结果表明,该算法能够从训练好的神经网络中成功地蒸馏出PDFA模型,并在公开数据集上取得了良好的性能。与直接训练的PDFA模型相比,蒸馏得到的PDFA模型具有更高的预测精度和更好的可解释性。

🎯 应用场景

该研究成果可应用于多个领域,例如自然语言处理、语音识别和生物信息学。通过将复杂的神经网络模型转化为可解释的PDFA模型,可以提高模型的可信度和透明度,从而更容易被用户理解和接受。此外,PDFA模型还可以用于生成文本、预测序列和进行异常检测。

📄 摘要(原文)

Probabilistic deterministic finite automata (PDFA) are discrete event systems modeling conditional probabilities over languages: Given an already seen sequence of tokens they return the probability of tokens of interest to appear next. These types of models have gained interest in the domain of explainable machine learning, where they are used as surrogate models for neural networks trained as language models. In this work we present an algorithm to distill PDFA from neural networks. Our algorithm is a derivative of the L# algorithm and capable of learning PDFA from a new type of query, in which the algorithm infers conditional probabilities from the probability of the queried string to occur. We show its effectiveness on a recent public dataset by distilling PDFA from a set of trained neural networks.