Meta-Learning Bidirectional Update Rules (BLUR)
论文: 2104.04657v2
来源: ICML 2021
作者: Mark Sandler, Max Vladymyrov, et al. (Google Research)
1. 摘要与简介
本文提出了一种新型的广义神经网络框架,其中神经元和突触维护多个状态。作者展示了经典的基于梯度的反向传播(Backpropagation)可以看作是双状态网络的一个特例:一个状态用于激活值(activations),另一个状态用于梯度(gradients),其更新规则源自链式法则。
在此基础上,作者提出了 BLUR (Bidirectional Learned Update Rules) 框架。在这个框架中:
- 网络既没有显式的梯度概念,也从未接收过梯度。
- 突触和神经元使用由共享的低维“基因组(genome)”参数化的双向 Hebb 风格更新规则进行更新。
- 这些基因组可以通过元学习(Meta-learning)从头开始学习,使用常规优化技术或进化策略(如 CMA-ES)。
结果表明,这种学习到的更新规则可以泛化到未见过的任务,并且在一些标准计算机视觉和合成任务上,训练速度比基于梯度下降的优化器更快。
2. 方法:学习新型神经网络
2.1 基于神经元状态的梯度下降广义化
为了定义可能的更新规则空间,作者首先将经典的人工神经网络(ANN)和反向传播(BP)重新表述为一个多状态系统。
经典 ANN 前向传播:
$$h_j = \sigma\left(\sum_{i \in I(j)} w_{ij} h_i\right)$$
其中 $h_j$ 是神经元 $j$ 的激活值。
经典 BP 反向传播:
$$\frac{\partial L}{\partial h_i} = \sum_{j \in J(i)} w_{ij} \frac{\partial L}{\partial h_j} h'_j$$
其中 $h'_j$ 是激活函数的导数。
SGD 权重更新:
$$w_{ij} \leftarrow w_{ij} - \tilde{\eta} \frac{\partial L}{\partial h_j} h'_j h_i$$
这具有 Hebbian 学习规则的形式(突触前激活 $h_i$ 与突触后信号 $\frac{\partial L}{\partial h_j}h'_j$ 的乘积)。
双状态网络重构:
作者引入每个神经元的两个状态 $\mathbf{a}_i = (a_i^{(1)}, a_i^{(2)})$。
- $a_i^{(1)}$: 前向信号(激活值)。
- $a_i^{(2)}$: 反向信号(梯度/误差)。
通过定义常数矩阵 $\nu, \mu, \tilde{\nu}, \tilde{\mu}$ 和广义激活函数 $\phi$,BP 可以重写为:
- 前向传播:
$$a_j^c \leftarrow \phi^c \left( \sum_{i \in I(j),d} w_{ij} a_i^d \nu^{cd} \right)$$ - 反向传播:
$$a_i^{(2)} \leftarrow a_i^{(2)} \sum_{j \in J(i),d} w_{ij} a_j^d \mu^d$$ - 权重更新:
$$w_{ij} \leftarrow w_{ij} - \tilde{\eta} \sum_{c,d} a_j^c \tilde{\mu}^c a_i^d \tilde{\nu}^d$$
其中 $c, d \in \{1, 2\}$ 代表状态索引。这表明传统的梯度反向传播只是通用双状态网络的一个特例,其规则由一组固定的低维矩阵控制。
2.2 多状态双向更新规则 (BLUR)
受上述双状态解释的启发,作者提出了 BLUR 框架。BLUR 具有以下特点:
- 使用多通道非对称突触。
- 前向和后向路径使用相同的更新机制。
- 允许每个神经元的不同通道之间进行信息混合。
BLUR 的核心方程:
-
前向传播 (Forward pass):
$$a_{j}^{c} \leftarrow \sigma \left(fa_{j}^{c} + \eta \sum_{i \in I(j),d} w_{ij}^{c} \nu^{cd} a_{i}^{d}\right) \quad \text{(Eq. 5)}$$- $f, \eta$: 神经元的遗忘门和更新门(标量)。
- $\nu^{cd}$: 前向神经元变换矩阵,控制状态间的混合。
-
反向传播 (Backward pass):
$$a_{i}^{c} \leftarrow \sigma \left(fa_{i}^{c} + \eta \sum_{j \in J(i),d} w_{ji}^{c} \mu^{cd} a_{j}^{d}\right) \quad \text{(Eq. 6)}$$- 注意这里使用的是 $w_{ji}$,即反馈连接的权重,可以是独立于前向权重的(非对称)。
- $\mu^{cd}$: 后向神经元变换矩阵。
-
权重更新 (Weights update):
$$w_{ij}^c \leftarrow \tilde{f} w_{ij}^c + \tilde{\eta} \sum_{e,d} a_i^e \tilde{\nu}^{ec} \cdot \tilde{\mu}^{cd} a_j^d \quad \text{(Eq. 7)}$$- $\tilde{f}, \tilde{\eta}$: 突触的遗忘门和更新门。
- $\tilde{\nu}, \tilde{\mu}$: 突触变换矩阵,混合突触前 ($a_i$) 和突触后 ($a_j$) 的活性。
- 这是一个广义的 Hebbian 更新规则。
基因组 (Genome):
由矩阵 $\{f, \tilde{f}, \nu, \tilde{\nu}, \mu, \tilde{\mu}, \eta, \tilde{\eta}\}$ 组成的集合称为基因组。这些是元参数(Meta-parameters),在元训练(Meta-training)期间进行优化,而在具体任务的训练(Inner-loop)中保持固定。
主要区别于 BP:
- 加性更新: 前向和后向传播都使用加性更新(BP 在反向传播中使用乘性更新)。
- 非对称突触: $w_{ij} \neq w_{ji}$,更符合生物学特性。
- 无显式损失函数: 网络不需要显式的损失函数概念。Ground truth 可以直接作为信号输入到最后一层(例如更新第二个状态),网络会自动反向传播这个信号。
2.3 稳定性机制
为了防止权重无限制增长(Hebb 规则的常见问题),作者引入了两种机制:
-
激活归一化 (Activation Normalization):
类似于 Batch Normalization,对前向和后向激活进行归一化。这有助于训练更深的网络并提高泛化能力。 -
Oja 规则 (Oja's Rule):
为了实现权重饱和,作者推导了一种修正的 Oja 规则作为抑制项:
$$(\Delta w_{ij}^c)^{\text{Oja}} = -(\tilde{f} - 1)w_{ij}^c \sum_r (w_{rj}^c)^2 - \tilde{\eta} w_{ij}^c \sum_{r,e,d} w_{rj}^c a_r^e \tilde{\nu}^{ec} \cdot \tilde{\mu}^{cd} a_j^d$$
在实验中,通常只使用第一项(权重衰减项)并配以可学习的乘数即可。
3. 这种更新规则是梯度下降吗?
作者通过数值实验验证了 BLUR 学习到的更新规则 通常不等同于 任何已知损失函数的梯度下降(Gradient Descent)。
如果等同于梯度下降,则权重更新 $\Delta w$ 必须满足对称性条件:
$$\frac{\partial \Delta w_{ij}^c}{\partial w_{mn}^d} = \frac{\partial \Delta w_{mn}^d}{\partial w_{ij}^c}$$
数值计算表明,BLUR 学习到的规则并不满足这一条件(见原论文 Figure 2)。此外,作者还探讨了黎曼度量下的广义梯度下降,也未能找到等价性。这表明 BLUR 探索了一个比传统梯度下降更广泛的优化算法空间。
4. 实验结果
4.1 简单函数元学习
- 任务: 学习 XOR, two-moon 等 2D 函数。
- 结果: 学习到的基因组不仅能解决训练任务,还能泛化到未见过的任务(如 MNIST)。
4.2 泛化能力
- 设置: 在 MNIST 的裁剪/缩放版本上进行元训练。
- 结果:
- 基因组可以泛化到全尺寸 MNIST、Fashion MNIST 和 E-MNIST。
- 甚至可以泛化到布尔任务(out-of-domain)。
- 课程学习(Curriculum training)有助于学习更长的反向传播步数(unroll steps)。
4.3 与 SGD 对比
- 在 MNIST 上,BLUR 在训练初期通常比 SGD 收敛得更快(见原论文 Figure 5, 6)。
- 尽管 SGD 最终可能会达到略高的精度,但 BLUR 展示了极强的小样本/快速适应能力。
4.4 进化策略 (CMA-ES)
- 使用 CMA-ES 替代 SGD 进行元训练,以优化不可微的目标(如准确率)。
- 结果: CMA-ES 能够找到有效的基因组,且能处理更长的 unroll 步数(因为不需要通过时间反向传播梯度)。
4.5 跨架构泛化
- 发现: 在较深网络(如 4 层)上训练的基因组可以很好地泛化到较浅的网络(如 1-3 层)。
- 反之不然: 在浅层网络上训练的基因组难以泛化到深层网络。
5. 结论
- BLUR 提供了一个通用的协议,用于定义神经网络中的节点更新。
- 梯度下降只是该协议中可能的一种“基因组”。
- 可以通过元学习找到全新的、不依赖梯度的更新规则。
- 这些规则在某些任务上比传统神经网络训练更快,且具有良好的泛化能力。
- 该框架为生物学上更合理的学习机制(如非对称突触、无显式梯度)提供了计算基础。