Growing Neural Cellular Automata
原文:https://distill.pub/2020/growing-ca/
原文开头有一个交互式的界面展示成果,强烈建议去试一试。
引言
这篇文章开头就用非常引人入胜的一篇引言(至少把我吸引住了)去讲述了Cellular Automata的重要性。作者是从生物学的角度入手的,一个生物都是由最开始的egg细胞发育而来,一个细胞里的基因完全编码了,这个细胞应该怎样分化分裂,变成一个完整的个体。这个过程和我们设计的cellular automata是一样的,可以说,这篇论文的目标就是设计出一个我们自己的基因,可以从一个初始状态繁衍成任何的最终状态。
Cellular Automata有三个典型特点:
- 局部性:每个格点只看小邻域(如 3x3);
- 共享规则:所有格点使用同一条更新规则;
- 涌现行为:全局形态由局部交互产生。
但经典 CA 的更新规则通常是“手写的离散规则”,不易针对某个目标形态进行设计。GNCA 的核心就是把“规则”变成可学习的连续函数:
- 用一个小神经网络 $f_\theta$ 作为每个细胞共享的更新函数;
- 让系统运行 $T$ 步后输出一张图(RGBA);
- 用损失函数表达“我想长成什么样”;
- 对 $\theta$ 做梯度下降。
方法
状态表示:每个细胞用一个 16 维的向量表示
在一个二维网格上,每个格点(细胞)有一个 16 维连续状态向量:
$$ \mathbf{S}_t \in \mathbb{R}^{H\times W\times 16}. $$
其中通道语义:
- 前 3 个通道:可视颜色 $\text{RGB}$;
- 第 4 个通道:$\alpha$(透明度/前景掩码),目标图案前景像素 $\alpha=1$,背景 $\alpha=0$;
- 剩余通道:hidden channels(没有预定义含义,由规则自发学会携带“化学浓度/电位/信号”等内部信息)。
“生命”与空细胞:alpha 作为 alive masking
文中把 $\alpha$ 赋予了“是否活着”的语义,这个应该是借助了图灵那个原始的生命游戏去设计的:
- 若某细胞的自身或邻居存在“成熟细胞”($\alpha>0.1$),则该区域被认为是“活”的;
- 否则被认为是“空/死”,并且其 所有通道在每步都会被显式清零,避免空细胞携带隐藏状态参与计算。
用公式表达(对应原文伪代码):
设 $A_t(x,y) = \mathbf{S}_t(x,y,3)$ 为 alpha 通道(0-based 下标 3),定义
$$ \text{alive}_t(x,y)=\mathbf{1}\Big(\max_{(i,j)\in\mathcal{N}_{3\times3}(x,y)} A_t(i,j) > 0.1\Big), $$
然后做掩码:
$$ \mathbf{S}_t(x,y,:) \leftarrow \text{alive}_t(x,y)\cdot \mathbf{S}_t(x,y,:). $$
直觉:这相当于把“生命边界”做成一个可控的计算区域,让生长从种子向外扩展,但又不会让远处空白格点无意义地积累隐状态。
更新规则
GNCA 的单步更新不是直接“看 3x3 状态然后丢进网络”,而是拆成几段,带有很强的生物类比。
Perception:固定 Sobel 梯度感知(从 16 维变 48 维)
每个细胞先做“感知”:对每个通道做 Sobel 卷积估计 $x,y$ 方向偏导。
Sobel 核(原文):
$$ K_x= \begin{bmatrix} -1&0&+1\\ -2&0&+2\\ -1&0&+1 \end{bmatrix}, \qquad K_y = K_x^\top. $$
对整张状态图做卷积(按通道独立):
$$ \mathbf{G}^x_t = K_x * \mathbf{S}_t,\qquad \mathbf{G}^y_t = K_y * \mathbf{S}_t, $$
然后把“自身状态 + 两个梯度”拼起来形成感知向量:
$$ \mathbf{P}_t = \text{concat}(\mathbf{S}_t,\mathbf{G}^x_t,\mathbf{G}^y_t)\in\mathbb{R}^{H\times W\times 48}. $$
原文特别强调这里用固定的 Sobel(而不是学习一个 3x3 kernel),动机是:真实细胞常依赖梯度作为导航信号。
梯度“体现在哪里”?
梯度是被作为额外输入特征拼进了每个细胞的感知向量里。
把每个通道拆开写会更直观。令第 $c$ 个通道($c\in\{1,\dots,16\}$)的标量场为
$$ S_t^{(c)}(x,y)=\mathbf{S}_t(x,y,c). $$
那么 Sobel 卷积得到的“梯度特征”就是
$$ G_t^{x,(c)}(x,y)=\sum_{i=-1}^{1}\sum_{j=-1}^{1}K_x(i,j)\,S_t^{(c)}(x+i,y+j), $$
$$ G_t^{y,(c)}(x,y)=\sum_{i=-1}^{1}\sum_{j=-1}^{1}K_y(i,j)\,S_t^{(c)}(x+i,y+j). $$
最终给到更新网络 $f_\theta$ 的、单个细胞 $(x,y)$ 的输入向量就是把三类量按通道拼起来(每个通道 3 个数:值、x梯度、y梯度):
$$ \mathbf{p}_t(x,y) = \Big[ \underbrace{S_t^{(1)}(x,y),\dots,S_t^{(16)}(x,y)}_{\text{自身状态 16 维}}, \underbrace{G_t^{x,(1)}(x,y),\dots,G_t^{x,(16)}(x,y)}_{x\text{方向梯度 16 维}}, \underbrace{G_t^{y,(1)}(x,y),\dots,G_t^{y,(16)}(x,y)}_{y\text{方向梯度 16 维}} \Big] \in\mathbb{R}^{48}. $$
所以“梯度体现在哪里?”——就在 $\mathbf{p}_t(x,y)$ 的后 32 维里。
也可以把它理解为:每个细胞不只是“看到邻居是什么”,而是“感到周围信号沿 $x/y$ 的变化趋势”。这使得 $f_\theta$ 很容易学到诸如:
- 目标边界附近梯度大 → 该生长/填充/修复;
- 内部区域梯度小 → 该稳定/抑制生长;
- 梯度方向指示“往哪里扩张/收缩”。
Update rule:一个很小的“逐细胞 MLP/1x1 Conv”(48→128→16)
对每个细胞的 48 维感知向量,应用同一个小网络 $f_\theta$ 输出增量更新(残差风格):
$$ \Delta \mathbf{s}_t(x,y) = f_\theta(\mathbf{p}_t(x,y)) \in \mathbb{R}^{16}, $$
结构(原文伪代码):
$$ 48 \xrightarrow{\text{dense / 1x1 conv}} 128 \xrightarrow{\text{ReLU}} 128 \xrightarrow{\text{dense / 1x1 conv}} 16. $$
并且有两个非常关键的实现选择:
- 输出是增量:下一步是 $\mathbf{s}+\Delta\mathbf{s}$,这让规则更像“连续动力系统”的离散化;
- 最后一层权重零初始化:使得初始规则接近“什么都不做”,训练更稳定(避免一开始就把状态炸飞)。
Stochastic/asynchronous update:异步更新(训练时概率 0.5)
传统 CA 同步更新所有细胞,相当于假设有一个“全局时钟”。文中为了更像自组织系统,引入随机异步更新:
- 每个细胞独立地以某概率执行更新;
- 其余时候“等待”,等价于对更新向量做 per-cell dropout。
训练时使用的概率是 0.5(原文明确写出并给出伪代码):
$$ \mathbf{M}_t(x,y) \sim \text{Bernoulli}(0.5), \qquad \mathbf{S}'_t = \mathbf{S}_t + \mathbf{M}_t\odot\Delta\mathbf{S}_t. $$
注意这里 $\mathbf{M}_t$ 是每个格点一个标量(广播到 16 个通道),$\odot$ 表示逐元素乘。
Alive masking:把“死区”强制清零
最后应用第 2.2 节定义的 alive mask:
$$ \mathbf{S}_{t+1} = \text{aliveMask}(\mathbf{S}'_t). $$
总结
把上述合并,可写成:
$$ \mathbf{P}_t = \text{perceive}(\mathbf{S}_t),\quad \Delta\mathbf{S}_t=f_\theta(\mathbf{P}_t),\quad \mathbf{S}'_t=\mathbf{S}_t+\mathbf{M}_t\odot\Delta\mathbf{S}_t,\quad \mathbf{S}_{t+1}=\text{aliveMask}(\mathbf{S}'_t). $$
训练设置
初始化
原文给出一致设定:初始网格全 0,仅中心一个种子细胞激活:
- RGB 设为 0(因为背景是白的,避免种子一开始就显眼);
- alpha + hidden(通道 3 及之后)设为 1.0。
即:
$$ \mathbf{S}_0 = 0,\qquad \mathbf{S}_0(H/2,W/2, 3:)=1. $$
训练展开步数:随机取 $[64,96]$
每次训练迭代随机采样 CA 运行步数 $T\sim U\{64,\dots,96\}$(原文给出区间),意图是:
- 不只在某个固定步数“恰好长对”,而是要在一个区间内都稳定;
- 这为后续的“持久性/吸引子”埋下约束。
损失:对 RGBA 做像素级 L2
训练目标图案给出 RGBA(RGB + alpha),在最后一步对比模型当前 RGBA 与目标图像:
$$ \mathcal{L} = \sum_{x,y}\left\|\text{RGBA}(\mathbf{S}_T(x,y)) - \text{RGBA}^*(x,y)\right\|_2^2. $$
训练稳定性:梯度归一化
原文提到在训练后期观察到不稳定(损失突然跳变),并用对每个变量的梯度做 L2 归一化来缓解(效果类似 weight normalization 的某些稳定作用)。
Experiment 1:Learning to Grow
目标
从单个 seed 出发,在 $T\in[64,96]$ 步后生成目标 emoji 图案。
现象:长期行为差异很大
即便训练损失都能降下来,长时间滚动(超过训练步数)时会出现明显分叉:
- 有的会死亡(图案消失);
- 有的会失控生长(不停扩张);
- 偶尔才出现“差不多稳定”的模型。
这引出核心问题:我们并没有强制目标图案成为系统吸引子;只是在“从 seed 到目标”的轨迹上拟合得很好,但目标附近的动力学可能是不稳定的。
Experiment 2:What persists, exists(把目标变成吸引子:sample pool)
动力系统视角
把整张网格看成一个高维动力系统:
- 状态是 $\mathbf{S}$(维度巨大);
- 更新规则由 $f_\theta$ 决定;
- 训练是在“调动力学”,希望目标形态是吸引子(从很多邻近状态都会回到它)。
一种直接办法是“拉长时间、周期性施加损失”,但这会带来巨大的 BPTT 内存开销。
Sample pool 策略(原文伪代码给了完整超参)
作者提出更省内存的办法:维护一个状态池,让训练不断从“历史末态”继续出发,相当于用很多不同起点逼迫系统在目标附近形成吸引子。
关键超参/流程(原文给出数字):
- pool 大小:1024
- 每次取 batch:32
- 为避免“灾难性遗忘”,每个 batch 都会把 最高损失样本替换成原始 seed
训练循环的本质逻辑可以概括为:
- 从 pool 采样一批当前状态作为起点;
- 运行 CA 若干步、算损失、对 $\theta$ 做更新;
- 把这批运行后的状态写回 pool(作为新的起点分布)。
直觉:训练早期,pool 里充满各种“半成品/错误态”;模型被迫学会从这些坏状态“救回来”。训练后期,pool 的分布逐渐集中到“接近目标”的区域,训练就变成“把目标附近雕刻成稳定吸引子”。
更具体地说:pool 里的“状态”是什么?batch 的“起点”怎么来?
pool 存的不是图片,也不是单个细胞,而是一整个网格的完整状态。
- 一个 pool entry 记为 $\mathbf{S}\in\mathbb{R}^{H\times W\times 16}$;
- 它可以是“刚开始的 seed”,也可以是“已经长到一半的半成品”,也可以是“接近完成的形态”;
- 当我们说“用它做起点”,意思就是:把它当作本次 rollout 的初始状态 $\mathbf{S}_0$。
为了把流程写得像“可执行算法”,我们先把 CA 的单步更新(包含 perceive、update、stochastic、alive mask)整体记成一个算子:
$$ \mathbf{S}_{t+1}=\Phi_\theta(\mathbf{S}_t;\,\mathbf{M}_t), $$
其中 $\mathbf{M}_t$ 是随机异步更新的掩码(训练时 fire rate=0.5)。
再定义 rollout(展开 $T$ 步):
$$ \text{Rollout}_\theta(\mathbf{S}_0,T)=\mathbf{S}_T. $$
那么一次训练迭代(有 pool 的版本)其实就是下面这套“采样—展开—算loss—回写”的循环。
Pool training:一步一步的训练循环(比原文更“操作化”)
下面这段是“概念上等价”的伪代码,但我把你关心的 batch=32 到底怎么“合成一个” 写得更贴近真实实现(重点是张量形状与 stack/gather/scatter 的关系):
# 常量(原文给了典型值)
N = 1024 # pool size
B = 32 # batch size
T ~ Uniform{64..96}
seed = zeros(H, W, 16)
seed[H//2, W//2, 3:] = 1.0
pool = [seed for _ in range(N)]
for it in range(num_iterations):
# 1) 从 pool 里抽一批“起点状态”的索引(长度为 B=32)
idx = random_choice(range(N), size=B) # idx.shape == (B,)
# 2) 把 32 个 (H,W,16) 的 entry “堆叠”成一个 batch 张量
# 这就是你问的“合成一个”:多出来的第一维就是 batch 维度
# batch0.shape == (B, H, W, 16)
batch0 = stack([pool[i] for i in idx], axis=0)
# 2) 为了避免遗忘:强制让 batch 里至少有一个从 seed 开始的样本
# 原文做法:把“最差的那个”替换成 seed(更能清理 pool 里的坏状态)
# 这里的“最差”通常指:用当前模型做一次短 rollout 后的 loss 最大
# 实现上一般是:Rollout 接受 (B,H,W,16) → 输出 (B,H,W,16),loss 输出 (B,)
batchT0 = Rollout_theta(batch0, T) # (B,H,W,16)
losses0 = loss(batchT0, target) # (B,)
worst = argmax(losses0) # 一个标量下标 0..B-1
batch0[worst] = seed # 用 seed 覆盖这一个样本(形状匹配 (H,W,16))
# 3) 正式训练:对每个起点展开 T 步,算 loss,对 θ 反向传播(BPTT)并更新
batchT = Rollout_theta(batch0, T) # (B,H,W,16)
losses = loss(batchT, target) # (B,)
L = mean(losses) # 标量
theta = theta - lr * grad_theta(L)
# 4) 把“这次跑到的末态”写回 pool(关键!)
# 注意:idx[k] 指向 pool 里的哪个槽位;batchT[k] 是 batch 中第 k 个样本的末态
for k in range(B):
pool[idx[k]] = batchT[k] # 每次写回一个 (H,W,16)
一句话总结 batch=32 怎么“合成一个”:就是把 32 个形状为 $(H,W,16)$ 的网格状态,沿着新加的 batch 维度 stack 成一个 $(32,H,W,16)$ 的四维张量,然后 rollout/loss 都在这第一维上并行计算。
这段里最关键的是第 4 步:把末态写回 pool。
如果你不写回 pool,而是每次都从 seed 开始,那你训练到的只是一条“从 seed 到目标的轨迹”(Experiment 1 的味道更浓),目标未必会成为吸引子。
写回 pool 的效果是:
- 训练初期:pool 里出现大量“乱七八糟/半成品”的状态;
- 训练被迫处理“从这些状态出发如何回到目标”,等价于在更大范围里塑造吸引盆;
- 训练后期:pool 会被末态不断替换成“更接近目标”的状态,模型进一步学会“保持 + 微调”,从而更稳定。
为什么要“把最高 loss 的样本替换成 seed”?
直觉上有两个作用:
- 防遗忘:永远保留“从头生长”的能力(否则模型可能只会修修补补,不会从单细胞重新长出来)。
- 清理 pool:替换最差样本能把 pool 里的“灾难态”逐渐冲刷掉,让训练更稳定(原文也提到这点)。
Experiment 3:Learning to regenerate(通过“训练时损伤”扩大吸引域)
不训练也可能出现“自发再生”
作者展示:Experiment 2 的一些模型在被切掉半边或挖洞后,会出现一定的再生趋势(尤其某些图案更明显),但并不稳定一致。
原因:如果目标只是一个小吸引盆,那么严重损伤可能把系统推出吸引域;它就可能失控、生长爆炸、过度稳定不响应、甚至自毁。
训练中引入损伤(pool-sampled damage)
策略:在每次训练迭代中,对从 pool 采样的一部分状态进行随机损伤(用随机圆形区域擦除为 0),然后要求系统仍能回到目标。
原文给出一个具体做法(用于说明训练批次如何构造):
- 从 pool 采样 8 个状态;
- 用 seed 替换其中最高损失的样本;
- 把其中损失最低的 3 个样本做随机圆形擦除(置零);
- 训练并把输出写回 pool。
结果:被“见过损伤”的模型展现出更强、且能泛化到未见过的损伤类型(比如矩形切割)的再生能力。
从动力系统语言说:损伤训练在扩大目标的basin of attraction(吸引域/吸引盆)。
更具体:damage 是怎么加的?为什么是“伤最好的几个”?
Experiment 3 的“加损伤训练”可以直接理解为:我们故意把一部分起点 $\mathbf{S}_0$ 变成 $\tilde{\mathbf{S}}_0$(被擦除了一块),然后依旧要求 rollout 后回到目标。
一个概念等价的流程是(和原文描述一致):
# 从 pool 抽 8 个起点(原文演示用 8)
idx = random_choice(range(N), size=8)
batch0 = [pool[i] for i in idx]
# 用当前模型估计每个起点“离目标多远”
losses0 = [loss(Rollout_theta(s0, T), target) for s0 in batch0]
# 最高 loss 的那个换成 seed(从头生长能力 + 清理坏状态)
worst = argmax(losses0)
batch0[worst] = seed
# 选 3 个最低 loss(已经长得最好/最接近目标的)来做损伤
best3 = argsort(losses0)[:3]
for k in best3:
batch0[k] = erase_random_circle(batch0[k]) # 把一个随机圆内的所有通道置 0
# 然后像 Experiment 2 一样 rollout、训练、写回 pool
batchT = [Rollout_theta(s0, T) for s0 in batch0]
L = mean([loss(sT, target) for sT in batchT])
theta = theta - lr * grad_theta(L)
for k in range(8):
pool[idx[k]] = batchT[k]
为什么更倾向“伤最好的几个(最低 loss)”?
- 训练初期本来就很差,再伤只会让信号更乱,容易学不到东西;
- “最好的几个”已经处在目标吸引盆附近,你对它做中等程度的扰动,相当于在吸引盆周围采样更多点,能更有效地学到“回到目标”的动力学;
- 这正对应“扩大吸引域”:让更多被扰动的状态也能回到同一个目标形态。
Experiment 4:Rotating the perceptive field(不重训就旋转生成图案)
这部分很有意思:因为感知阶段用的是梯度方向(Sobel),如果我们把“传感器方向”旋转,相当于在坐标系中改变了细胞感知的方向。
作者用一个二维旋转矩阵对 Sobel 核做线性组合:
$$ \begin{bmatrix} K_x\\ K_y \end{bmatrix} = \begin{bmatrix} \cos\theta & -\sin\theta\\ \sin\theta & \cos\theta \end{bmatrix} * \begin{bmatrix} \text{Sobel}_x\\ \text{Sobel}_y \end{bmatrix}. $$
用旋转后的 $K_x,K_y$ 重新计算梯度感知,模型就能在不重新训练的情况下,生成(相当程度上)旋转后的目标图案。
补一句更“机制化”的解释:更新网络 $f_\theta$ 学到的是“给定(自身值 + 梯度)该怎么更新”。当你把梯度测量的坐标轴整体旋转,相当于把它看到的“方向信息”也旋转了,于是它会沿旋转后的方向去执行同样的生长/修复策略,从而出现旋转后的形态。
这体现出 GNCA 规则对“底层像素格点细节”有一定鲁棒性:即便像素旋转在离散网格上不是完美的连续旋转,系统仍能自组织出合理结果。
超参数
- 状态维度:16
- 可见通道:RGB(0~2)+ alpha(3)
- alive 判定:对 alpha 做 3x3 max-pool,阈值 0.1;若邻域无成熟细胞则整细胞状态清零
- 感知:固定 Sobel $K_x,K_y$;拼接 state + grad_x + grad_y → 48 维
- 更新网络:48→128(ReLU)→16;最后一层权重零初始化;输出为增量 $\Delta s$
- 异步更新(训练时):每格点以 0.5 概率执行更新(等价 per-cell dropout)
- 训练展开步数:每次采样 $T\in[64,96]$
- 损失:像素级 L2,比较 RGBA 与目标
- pool 训练:pool=1024;batch=32;将 batch 中最高损失样本替换为 seed;输出写回 pool
- 再生训练(示例批构造):从 pool 采样 8;最高损失替换为 seed;对损失最低的 3 个做随机圆形擦除置零;训练后写回 pool
局限与可延伸方向
- 表达能力 vs 可解释性:规则是神经网络,隐藏通道的语义难以直接解释;但这也正是“可自组织”的代价。
- 目标仍然是像素级:损失直接对 RGBA,适合图案生成;若要更抽象的形态目标(拓扑、骨架、器官功能)需要更高级的损失或约束。
- 尺度与泛化:文中图案尺寸/步数是固定范围,尽管出现了一定的鲁棒泛化(比如旋转、未见损伤类型),但严格的跨尺度泛化仍不是自动保证的。