SpikingSSMs: Learning Long Sequences with Sparse and Parallel Spiking State Space Models
作者: Shuaijie Shen, Chao Wang, Renzhuo Huang, Yan Zhong, Qinghai Guo, Zhichao Lu, Jianguo Zhang, Luziwei Leng
分类: cs.CL, cs.LG, cs.NE
发布日期: 2024-08-27 (更新: 2024-12-24)
💡 一句话要点
提出SpikingSSM,利用脉冲神经网络和状态空间模型进行高效长序列学习
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 脉冲神经网络 状态空间模型 长序列学习 稀疏计算 低功耗 并行计算 语言建模
📋 核心要点
- 脉冲神经网络在视觉任务上表现出色,但在长序列建模方面应用较少,未能充分利用其内在的时间动态特性。
- SpikingSSM结合了状态空间模型和脉冲神经元,通过分层集成和稀疏计算,提升了长序列建模能力。
- 通过轻量级替代动态网络解决并行计算问题,加速训练,并在长序列任务和语言建模上取得了显著成果。
📝 摘要(中文)
本文提出了一种用于长序列学习的脉冲状态空间模型(SpikingSSM),旨在结合脉冲神经网络(SNNs)的低功耗特性和状态空间模型(SSMs)的序列建模能力。受树突神经元结构的启发,该模型分层地将神经元动态与原始SSM块集成,同时实现稀疏突触计算。为了解决事件驱动的神经元动态与并行计算的冲突,提出了一种轻量级的替代动态网络,该网络能够准确预测重置后的膜电位,并与可学习的阈值兼容,从而显著提高训练速度。在长程竞技场基准测试中,SpikingSSM在实现平均90%网络稀疏性的同时,达到了与最先进的SSM相当的性能。在语言建模方面,该网络在WikiText-103数据集上显著超越了现有的脉冲大型语言模型(spikingLLMs),且模型大小仅为其三分之一,展示了其作为低计算成本LLM骨干架构的潜力。
🔬 方法详解
问题定义:现有脉冲神经网络(SNNs)在长序列建模任务中表现不佳,无法有效利用其时间动态特性。传统的SNN训练方法难以进行并行计算,限制了其在大规模序列数据上的应用。此外,如何在SNN中实现高效的稀疏计算,降低计算成本,也是一个挑战。
核心思路:本文的核心思路是将状态空间模型(SSMs)的序列建模能力与脉冲神经网络的低功耗特性相结合,构建SpikingSSM。通过模仿树突神经元的结构,分层地将神经元动态集成到SSM块中,实现稀疏突触计算。同时,设计轻量级的替代动态网络,解决SNN训练中的并行计算问题。
技术框架:SpikingSSM的整体架构包括:1) 原始SSM块,用于序列建模;2) 分层集成的脉冲神经元动态,模拟树突结构,实现稀疏计算;3) 轻量级替代动态网络,用于预测重置后的膜电位,支持并行训练。训练过程包括前向传播、反向传播和参数更新。
关键创新:最重要的技术创新点在于:1) 将脉冲神经元动态与SSM块分层集成,实现稀疏计算,降低功耗;2) 提出轻量级替代动态网络,解决了事件驱动的神经元动态与并行计算的冲突,显著加速了训练过程;3) 实现了可学习的阈值,提升了模型的灵活性和表达能力。
关键设计:关键设计包括:1) 树突神经元结构的模拟,通过分层集成实现稀疏连接;2) 轻量级替代动态网络的结构和训练方法,确保其能够准确预测重置后的膜电位;3) 可学习阈值的初始化和更新策略,保证模型的稳定性和收敛性。损失函数采用交叉熵损失或均方误差损失,根据具体任务进行选择。
🖼️ 关键图片
📊 实验亮点
SpikingSSM在长程竞技场基准测试中,实现了与最先进的SSM相当的性能,同时实现了平均90%的网络稀疏性。在WikiText-103数据集上,SpikingSSM显著超越了现有的脉冲大型语言模型(spikingLLMs),且模型大小仅为其三分之一,验证了其高效性和潜力。
🎯 应用场景
SpikingSSM具有广泛的应用前景,包括低功耗边缘计算设备上的长序列建模、自然语言处理、语音识别、时间序列预测等。其低计算成本的特性使其特别适用于资源受限的场景,例如移动设备和物联网设备。未来,SpikingSSM有望成为构建低功耗大型语言模型(LLMs)的关键技术。
📄 摘要(原文)
Known as low energy consumption networks, spiking neural networks (SNNs) have gained a lot of attention within the past decades. While SNNs are increasing competitive with artificial neural networks (ANNs) for vision tasks, they are rarely used for long sequence tasks, despite their intrinsic temporal dynamics. In this work, we develop spiking state space models (SpikingSSMs) for long sequence learning by leveraging on the sequence learning abilities of state space models (SSMs). Inspired by dendritic neuron structure, we hierarchically integrate neuronal dynamics with the original SSM block, meanwhile realizing sparse synaptic computation. Furthermore, to solve the conflict of event-driven neuronal dynamics with parallel computing, we propose a light-weight surrogate dynamic network which accurately predicts the after-reset membrane potential and compatible to learnable thresholds, enabling orders of acceleration in training speed compared with conventional iterative methods. On the long range arena benchmark task, SpikingSSM achieves competitive performance to state-of-the-art SSMs meanwhile realizing on average 90\% of network sparsity. On language modeling, our network significantly surpasses existing spiking large language models (spikingLLMs) on the WikiText-103 dataset with only a third of the model size, demonstrating its potential as backbone architecture for low computation cost LLMs.