Home Learn Blog Game
Learn Papers Monkey Learning Tess: A Scalable Temporally And Spatially Loca Learning Rule For Spiking Neural Networks

Course Structure

2502.01837v1.pdf 2502.01837V1 Main

TESS: A Scalable Temporally and Spatially Local Learning Rule for Spiking Neural Networks

时间: 2025-02

链接:https://arxiv.org/abs/2502.01837


Abstract

SNN推理十分高效,但是训练成本很大,依赖于 BPTT 算法。这一方法复杂度较高,同时时间空间也并不local,训练过程不太有生物合理性。

本文提出的 TESS 方法,是一种Three-factor Learning Rules,特点为:
- 时间空间都local
- 内存和计算开销与时间步长 $T$ 无关,仅随神经元数量线性增长。复杂度降低。
- 主要涉及的东西是资格迹、STDP等等。


Introduction

背景与挑战

  • 边缘智能: 需要在本地设备上进行实时模型适应(On-device learning),保护隐私并减少延迟。
  • SNN 的困境: SNN 利用稀疏的脉冲事件进行计算,推理很节能。但训练很难。
    • BPTT (Backpropagation Through Time): 目前最强的训练算法。它需要将误差通过时间轴和层级反向传播。
    • 复杂度: 时间复杂度 $O(TLn^2)$,内存复杂度 $O(TLn)$。这意味着输入序列越长,训练越慢、越占内存。

现有解决方案的局限

  1. 空间局部方法 (Spatial Locality): 如反馈对齐 (FA, DFA),试图避免层间反向传播,但深层网络收敛差。
  2. 时间局部方法 (Temporal Locality): 如 e-prop, OSTL。使用“资格迹”解决时间依赖,但通常仍需要全局误差信号(即空间上不局部),内存复杂度仍较高 ($O(Ln^2)$)。
  3. S-TLLR: 作者之前的这类工作虽然降低了内存 ($O(Ln)$),但仍依赖层间反向传播。

TESS 的突破

TESS 实现了全局部(Fully Local):
* 时间局部: 使用资格迹处理历史信息。
* 空间局部: 使用局部生成的学习信号(基于固定随机投影),无需跨层传误差。
* Insight: TESS 就像是让每个神经元自己根据“刚才发生了什么”(资格迹)和“现在的局部目标是什么”(局部信号)来更新权重,而不需要等待整个网络运行完并传回反馈。


Background

A. LIF model

老生常谈了,简单写一点吧

$$u_i^{(l)}[t] = \gamma (u_i^{(l)}[t-1] - v_{th}o_i^{(l)}[t-1]) + \sum_j W_{ij}^{(l)} o_j^{(l-1)}[t]$$
$$o_i^{(l)}[t] = \Theta(u_i^{(l)}[t] - v_{th})$$
* $u$: 膜电位 (Membrane Potential)。
* $o$: 输出脉冲 (0 或 1)。
* $\gamma$: 泄漏因子 (Leak factor)。
* $\Theta$: 阶跃函数 (Heaviside step function)。

SNN 训练的核心难点在于 $\Theta$ 函数不可导(导数为0或无穷大)。因此在计算梯度时,通常使用 Surrogate Gradient (代理梯度),即在反向传播时用一个平滑函数(如 Sigmoid 或分段线性函数)的导数来近似 $\Theta$ 的导数。

B. Gradient-based optimization for SNNs

这一部分回顾了一下SNN是如何计算梯度的,使用的是经典的BPTT方法。

$$\frac{dL}{dW^{(l)}} = \sum_{t=1}^T \frac{\partial L}{\partial u^{(l)}[t]} \frac{\partial u^{(l)}[t]}{\partial W^{(l)}}$$
然后指出 $\frac{\partial L}{\partial u^{(l)}[t]}$ 这一项包含了对未来的依赖(时间反向传播)和对后层的依赖(空间反向传播),认为它不是时间空间local的。

C. Three-factor Learning Rules

这一部分还是值得深入研究的,先简单写一下。

也就是所谓三因子学习规则,是受生物启发的学习规则,认为权重的更新与一下三个东西有关:
- pre-synaptic activity
- post-synaptic activity
- modulatory signal(我个人认为这个对应大脑的弥散调制系统)

有几个核心概念:

资格迹 $\boldsymbol{e}_{ij}^{(l)}[t]$,大概的动力学规则是
$$\boldsymbol{e}_{ij}^{(l)}[t] = \beta \boldsymbol{e}_{ij}^{(l)}[t-1] + f(\boldsymbol{o}_i^{(l)}[t])g(\boldsymbol{o}_j^{(l-1)}[t])$$

然后认为权重的更新和全局的signal与资格迹共同决定:
$$\Delta \boldsymbol{W}_{ij} = \sum_{t} \boldsymbol{m}_{i}[t] \boldsymbol{e}_{ij}^{(l)}[t]$$

Proposed Method - A Scalable Fully Local Learning Rule

本文所提出的方法

a) Temporal Credit Assignment with Eligibility Traces

作者说$\boldsymbol{e}_{ij}^{(l)}[t] = \beta \boldsymbol{e}_{ij}^{(l)}[t-1] + f(\boldsymbol{o}_i^{(l)}[t])g(\boldsymbol{o}_j^{(l-1)}[t])$这个资格迹公式的复杂度太高了,有 $O(n^2)$。作者认为把 $\beta$ 设置为0,可以降低复杂度到 $O(n)$。

TESS 定义了两个资格迹分量:
1. 因果迹 (Causal Trace): 表示前后两个神经元有因果关系。
2. 非因果迹 (Non-causal Trace): 表示前后两个神经元没有因果关系。

1. 因果资格迹 (Pre-synaptic based)

  • 资格迹公式里的 $f(\cdot)$ 定义为膜电位 $u^{(l)}[t]$ 的辅助激活函数 $\Psi(\cdot)$。
  • 资格迹公式里的 $g(\cdot)$ 定义为输入脉冲的低通滤波器:$\sum_{t'=0}^t \lambda_{pre}^{t-t'} o^{(l-1)}[t']$。

  • 递归变量 $q^{(l)}[t]$: 为了前向计算,引入递归变量 $q$ 来追踪突触前活动:
    $$q^{(l)}[t] = \lambda_{pre}q^{(l)}[t - 1] + o^{(l-1)}[t] \quad (5)$$

  • 因果资格迹公式:
    $$\boldsymbol{e}_{pre}^{(l)}[t] = \alpha_{pre}\Psi(\boldsymbol{u}^{(l)}[t]) \otimes q^{(l)}[t] \quad (6)$$
    其中 $\alpha_{pre}$ 控制幅度(实验中设为 1)。

2. 非因果资格迹 (Post-synaptic based)

  • 递归变量 $h^{(l)}[t]$: 引入变量 $h$ 来追踪突触后活动历史:
    $$\boldsymbol{h}^{(l)}[t] = \lambda_{post}\boldsymbol{h}^{(l)}[t - 1] + \Psi(\boldsymbol{u}^{(l)}[t - 1]) \quad (7)$$
  • 非因果资格迹公式:
    $$\boldsymbol{e}_{post}^{(l)}[t] = \alpha_{post}\boldsymbol{h}^{(l)}[t] \otimes \boldsymbol{o}^{(l-1)}[t] \quad (8)$$
    其中 $\alpha_{post}$ 决定非因果项的包含方式(+1: 正向包含, -1: 负向包含, 0: 排除)。

b) Spatial Credit Assignment with Locally Generated Learning Signals

简单来说就是:把每一层的脉冲输出投影到任务子空间,然后根据真实标签来计算误差信号。也就是所谓的学习信号$m^{(l)}[t]$。

我觉得这样子干很奇怪,但是你就说是不是没有误差传播吧。(还不如random alignment呢)

$$m^{(l)}[t] = B^{(l)\top} (f(B^{(l)}o^{(l)}[t]) - y^*) \quad (9)$$

$f(\cdot)$: 任务相关的函数(分类任务用 Softmax,回归任务用 Identity)。

矩阵 $B^{(l)}$ 的设计: 这是一个固定的二进制矩阵,其列对应于方波函数 (Square Wave Functions)。

Insight: 这种设计有助于通过分配不同的空间频率来同步层内神经元的活动。方波简单且硬件友好。

优势: 这种局部生成方式消除了层间反向传播,将计算复杂度从 $O(n^2)$ 降低到 $O(Cn)$。

c) Weight Updates

权重更新由学习信号 $m^{(l)}[t]$ 调制资格迹来计算。

  • 因果项更新:
    $$\Delta W_{pre}^{(l)}[t] = (m^{(l)}[t] \odot \alpha_{pre}\Psi(u^{(l)}[t])) \otimes q^{(l)}[t] \quad (10)$$
  • 非因果项更新:
    $$\Delta W_{post}^{(l)}[t] = (m^{(l)}[t] \odot \alpha_{post}h^{(l)}[t]) \otimes o^{(l-1)}[t] \quad (11)$$
  • 总权重更新:
    $$\Delta W^{(l)}[t] = \Delta W_{pre}^{(l)}[t] + \Delta W_{post}^{(l)}[t] \quad (12)$$

A. Algorithm Implementation

TESS 是迭代运行的。对于每一层 $l$ 的伪代码 (Algorithm 1) 如下:

  1. 初始化 $u, h, q$ 为 0。
  2. 对于每个时间步 $t = 1, \dots, T$:
    • 根据 Eq (7) 更新 $h^{(l)}[t]$ (历史后突触迹)。
    • 根据 Eq (1) 和 (2) 更新 $u^{(l)}[t]$ 和 $o^{(l)}[t]$ (LIF 神经元前向传播)。
    • 根据 Eq (5) 更新 $q^{(l)}[t]$ (历史前突触迹)。
    • 如果 $t \ge t_l$ (开始生成信号的时间):
      • 根据 Eq (9) 计算局部学习信号 $m^{(l)}[t]$。
      • 根据 Eq (10)-(12) 计算权重更新 $\Delta W^{(l)}[t]$。
  3. 最后更新权重: $W^{(l)} = W^{(l)} + \eta \sum \Delta W^{(l)}[t]$。

B. Computational and Memory Cost

这是 TESS 的核心优势所在。

1) Memory Requirements

不考虑神经元状态变量(如膜电位),只考虑梯度计算相关的内存。
* BPTT: 需要存储所有时间步的激活值以进行反向传播。
$$Mem_{BPTT} = T \sum_{l=0}^L n^{(l)} \quad (13)$$
* TESS: 只需要存储当前的递归变量 $q$ 和 $h$,不需要存储历史。
$$Mem_{TESS} = 2 \sum_{l=0}^L n^{(l)} \quad (14)$$
结论: TESS 的内存复杂度是 $O(Ln)$,与时间步 $T$ 无关,且随神经元数量线性增长(而非 $n^2$)。

2) Computational Requirements

比较乘加运算 (MAC) 的数量。
* BPTT: 需要反向传播误差通过权重矩阵 $W$ ($n \times n$)。
$$MAC_{BPTT} = T \sum_{l=1}^L n^{(l)} \times n^{(l-1)} \quad (15)$$
* TESS: 只需要通过投影矩阵 $B$ ($C \times n$) 计算局部信号。
$$MAC_{TESS} = (T - t_l) \sum_{l=1}^L 2 \times n^{(l)} \times C \quad (17)$$
结论: 由于类别数 $C \ll n$,TESS 的计算量大幅降低,复杂度从 $O(n^2)$ 降为 $O(Cn)$。

C. Comparison with other local learning rules

方法 内存复杂度 时间复杂度 时间局部性 空间局部性
BPTT $O(TLn)$ $O(TLn^2)$ ✗ ✗
e-prop / OSTL $O(Ln^2)$ $O(Ln^2)$ ✓ ✗
S-TLLR $O(Ln)$ $O(Ln^2)$ ✓ ✗
TESS (Ours) $O(Ln)$ $O(LCn)$ ✓ ✓

TESS 是唯一同时实现了线性内存复杂度 ($O(Ln)$) 和 线性时间复杂度 ($O(LCn)$) 且保持全时空局部性的方法。


Experimental Evaluation

实验设置

  • 数据集: CIFAR10, CIFAR100, IBM DVS Gesture, CIFAR10-DVS。
  • 模型: VGG-9 架构。
  • 训练: 200 epochs, Adam 优化器。

结果分析

  1. 消融实验 (Ablation Study):

    • 探究 $\alpha_{post}$ (非因果项) 的作用。结果显示包含正向的非因果项 ($\alpha_{post}=+1$) 通常能提高精度 (0.2% - 1.8%)。
  2. 性能对比 (Performance):

    • IBM DVS Gesture: TESS (98.56%) 甚至超过了 BPTT baseline (97.95%)。
    • CIFAR10: TESS (92.55%) 与 BPTT (92.55%) 持平。
    • CIFAR100: TESS (70.00%) 略优于 BPTT (69.28%)。
    • CIFAR10-DVS: TESS (75.00%) 略低于 BPTT (76.40%),差距约 1.4%。
  3. 资源消耗:

    • MACs (计算量): 相比 BPTT 减少了 205倍 到 661倍。
    • Memory (内存): 相比 BPTT 减少了 3倍 到 10倍。

Conclusions

TESS 提出了一种极其高效的 SNN 训练方法。
1. 它打破了 BPTT 对时间和空间的双重依赖。
2. 它利用局部生成的误差信号和资格迹,实现了 $O(Ln)$ 的线性内存复杂度和极低的计算复杂度。
3. 在保证极低开销的同时,精度几乎没有损失(在某些数据集上甚至更优)。

这种方法非常适合部署在神经形态芯片(Neuromorphic Hardware)或边缘嵌入式设备上,因为它不需要缓存大量的历史数据,也不需要复杂的全局同步。

Previous

© 2025 Ze Rui Liu.