Grokking as Manifold Selection Under Constraint
A transformer trained on (a + b) mod 113 memorizes first, then suddenly generalizes.
This is grokking. The baseline takes 5,760 epochs to get there. We add one thing: a small
probe that asks "can you predict your own hidden state from your input?" The mismatch
becomes a regularizer. Nothing else changes. Same optimizer, same architecture, same seed.
With the full stencil [-1,0,+1], the model groks at 1,920 epochs. That is 3.00x faster.
The interesting part is not the speedup. It is what happens when you take the constraint apart.
A stencil ablation isolates three temporal components. Functional coherence alone (k=0)
gives 2.55x. Backward admissibility alone (k=-1) gives only 1.52x. Combining backward
and forward without coherence [-1,+1] gives 3.08x, the fastest pair. Adding k=-1 to k=0
gives 2.88x. The full stencil gives 3.00x. The constraint is not just regularization. It
selects a manifold where the MLP's computation is geometrically consistent across time.
Generalization lives on that manifold. The probe finds it faster than weight decay alone.
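The speedups quoted above are ratios of epochs-to-grok against the 5,760-epoch baseline. As a quick sanity check, the sketch below recomputes the 3.00x figure and back-computes the epoch counts implied by the other ratios. Only 5,760 and 1,920 are stated in the text; the other epoch counts are approximate, derived from two-decimal speedups, and the helper names are illustrative.

```python
BASELINE_EPOCHS = 5760  # baseline epochs-to-grok, as stated in the text


def speedup(grok_epochs, baseline_epochs=BASELINE_EPOCHS):
    """Speedup factor relative to the baseline run."""
    return baseline_epochs / grok_epochs


def implied_epochs(speedup_factor, baseline_epochs=BASELINE_EPOCHS):
    """Approximate epochs-to-grok implied by a reported speedup."""
    return baseline_epochs / speedup_factor


# Speedups reported in the text, keyed by stencil offsets.
reported = {
    "[0]": 2.55,
    "[-1]": 1.52,
    "[-1,+1]": 3.08,
    "[-1,0]": 2.88,
    "[-1,0,+1]": 3.00,
}

for stencil, factor in reported.items():
    # Back-computed values are estimates, not measured epoch counts.
    print(f"{stencil:>10}  {factor:.2f}x  ~{implied_epochs(factor):.0f} epochs")

fastest = max(reported, key=reported.get)
```

Here `speedup(1920)` returns 3.0, matching the full-stencil figure, and `fastest` is `"[-1,+1]"`.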
architecture
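import torch.nn as nn

# Embed, PosEmbed, Attention, MLP, and Unembed are custom modules defined elsewhere.
class Transformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = Embed(114, 128)        # vocab size, d_model
        self.pos_embed = PosEmbed(3, 128)   # sequence length, d_model
        self.attn = Attention(128, 4, 32)   # d_model, n_heads, head size
        self.mlp = MLP(128, 512, ReLU)      # d_model, d_mlp, activation
        self.unembed = Unembed(128, 113)    # d_model, number of output classes

    def forward(self, tokens):  # tokens: [batch, 3]
        x = self.embed(tokens) + self.pos_embed(tokens)
        x = x + self.attn(x)    # residual connection around attention
        x = x + self.mlp(x)     # residual connection around the MLP
        return self.unembed(x)[:, -1, :]  # logits at the final position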
setup
TASK ........... (a + b) mod p
PRIME p ........ 113
ARCH ........... 1-layer decoder-only transformer
d_model ........ 128
n_heads ........ 4
d_mlp .......... 512
activation ..... ReLU
OPTIMIZER ...... AdamW
lr ............. 1e-3
WEIGHT DECAY ... 1.0
betas .......... (0.9, 0.98)
BATCH .......... full
train .......... 30%
SEED ........... 999
epochs ......... 25,000
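To make the task and split concrete, here is a minimal sketch of how the dataset could be built from the settings above. The prime, the 30% train fraction, and the seed come from the setup; the "=" token id, the `[a, b, =]` input layout, and the shuffle-then-slice split are assumptions, and the actual split may differ.

```python
import random

P = 113                # prime modulus, from the setup
EQ_TOKEN = P           # assumption: "=" gets id 113, giving a 114-token vocabulary
TRAIN_FRACTION = 0.30  # from the setup
SEED = 999             # from the setup


def encode(a, b):
    """Return (input tokens, label) for one example of (a + b) mod P."""
    return [a, b, EQ_TOKEN], (a + b) % P


# Enumerate every ordered pair, then shuffle deterministically.
pairs = [(a, b) for a in range(P) for b in range(P)]
random.Random(SEED).shuffle(pairs)

n_train = int(TRAIN_FRACTION * len(pairs))
train = [encode(a, b) for a, b in pairs[:n_train]]
test = [encode(a, b) for a, b in pairs[n_train:]]

print(len(pairs), len(train), len(test))
```

With 113 * 113 = 12,769 pairs, this split gives 3,830 training examples and 8,939 held-out examples. Full-batch training means every optimizer step sees all of `train` at once.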
A live reproduction of Nanda et al. (2023), "Progress Measures for
Grokking via Mechanistic Interpretability" (ICLR Spotlight). Real weights, real gradients, real grokking.
After the phase transition, Fourier analysis reveals the model discovers the
discrete Fourier transform: integers become rotations on a circle, composed via trig identities.
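The trig-identity mechanism can be illustrated without a trained model. The sketch below represents each integer as a rotation at a few frequencies, composes the rotations for a and b with the angle-addition identities, and scores each candidate c by cos(w(a + b - c)). The frequencies are hypothetical placeholders, not the ones a trained network selects, and `predict` is an illustrative name.

```python
import math

P = 113
FREQS = [1, 5, 17]  # hypothetical frequencies, chosen only for illustration


def predict(a, b, p=P, freqs=FREQS):
    """Return the c in [0, p) that maximizes the summed cosine score."""
    scores = []
    for c in range(p):
        score = 0.0
        for k in freqs:
            w = 2 * math.pi * k / p
            # Angle-addition identities: compose the rotations for a and b.
            cos_ab = math.cos(w * a) * math.cos(w * b) - math.sin(w * a) * math.sin(w * b)
            sin_ab = math.sin(w * a) * math.cos(w * b) + math.cos(w * a) * math.sin(w * b)
            # cos(w(a+b))cos(wc) + sin(w(a+b))sin(wc) = cos(w(a+b-c)),
            # which reaches 1 only when c == (a + b) mod p (p is prime).
            score += cos_ab * math.cos(w * c) + sin_ab * math.sin(w * c)
        scores.append(score)
    return max(range(p), key=scores.__getitem__)


print(predict(100, 50), (100 + 50) % P)
```

Because the score at the correct residue equals the number of frequencies and every other residue scores strictly lower, the argmax recovers (a + b) mod p.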
Grokking is detected when both train and test accuracy reach 99% (Nanda et al. report >99.95% post-grok). Training continues past detection until accuracy settles into a stable basin at 99%.
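The detection rule can be written as a small helper. This is a sketch of the stated criterion only (first logged epoch where both accuracies reach 99%); the function name, the history format, and the toy numbers are illustrative and are not taken from the runs described here.

```python
def grok_epoch(history, threshold=0.99):
    """Return the first epoch where train and test accuracy both reach threshold.

    history: iterable of (epoch, train_acc, test_acc) in increasing epoch order.
    Returns None if the threshold is never reached.
    """
    for epoch, train_acc, test_acc in history:
        if train_acc >= threshold and test_acc >= threshold:
            return epoch
    return None


# Made-up history: memorization first (train high, test low), then generalization.
toy_history = [
    (0, 0.01, 0.01),
    (100, 1.00, 0.02),
    (200, 1.00, 0.40),
    (300, 1.00, 0.995),
    (400, 1.00, 1.00),
]
print(grok_epoch(toy_history))
```

Requiring both accuracies keeps the memorization phase, where train accuracy is already near 100%, from counting as grokking.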
Nanda et al. (2023). Progress Measures for Grokking. ICLR Spotlight.
Power et al. (2022). Grokking. arXiv:2201.02177.
Baseline Reproduction
Standard training without introspection constraints. Grokking emerges from weight decay alone.
stencil ablation
A small probe is added alongside the MLP. Each epoch the probe learns to predict
mlp_hidden from mlp_in, then the mismatch regularizes the model. The probe is
discarded after training. Click "run" or "add stencil" to test each offset subset.
Expect: earlier test accuracy jump, possible loss spike during transition.
L_total = L_task + 0.1 * ||f(mlp_in) - mlp_hidden||^2
temporal offsets
k = -1   backward admissibility   mlp_in(t) -> mlp_hidden(t-1)
         "does the current input predict the previous hidden state?"
k = 0    functional coherence     mlp_in(t) -> mlp_hidden(t)
         "does the current input predict the current hidden state?"
k = +1   forward lookahead        mlp_in(t-1) -> mlp_hidden(t)
         "does the previous input predict the current hidden state?"
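One way to read the three offsets is as a rule for which stored epochs supply the probe's input and target. The sketch below encodes that pairing, treating t as the current epoch index. The names `STENCIL_PAIRS` and `probe_pairs`, and the choice to skip pairs that would need an epoch before 0, are assumptions made for illustration.

```python
# For each offset k, map the current epoch t to (input epoch, hidden-state epoch).
STENCIL_PAIRS = {
    -1: lambda t: (t, t - 1),  # backward admissibility
    0: lambda t: (t, t),       # functional coherence
    +1: lambda t: (t - 1, t),  # forward lookahead
}


def probe_pairs(t, stencil):
    """List (k, input_epoch, hidden_epoch) for every usable offset at epoch t."""
    pairs = []
    for k in stencil:
        in_epoch, hid_epoch = STENCIL_PAIRS[k](t)
        # Skip pairs that would need an epoch before training started.
        if min(in_epoch, hid_epoch) >= 0:
            pairs.append((k, in_epoch, hid_epoch))
    return pairs


print(probe_pairs(5, [-1, 0, +1]))
print(probe_pairs(0, [-1, 0, +1]))
```

No pair ever references an epoch later than t, which is why the forward term is written as mlp_in(t-1) -> mlp_hidden(t) rather than mlp_in(t) -> mlp_hidden(t+1).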
mechanism
MLP path: mlp_in (128) -> fc1 + ReLU -> mlp_hidden (R^512).
Probe: 128 -> 64 -> ReLU -> 512. Separate optimizer, discarded after training.
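To put "small probe" in numbers, the sketch below counts parameters for the probe and for the MLP's first layer. It assumes both are plain affine layers with biases (as in `nn.Linear`); the text does not say whether biases are used, so treat the totals as estimates under that assumption.

```python
def linear_params(n_in, n_out, bias=True):
    """Parameter count of one affine layer: weights plus optional biases."""
    return n_in * n_out + (n_out if bias else 0)


# Probe: 128 -> 64 -> ReLU -> 512 (ReLU has no parameters).
probe_params = linear_params(128, 64) + linear_params(64, 512)

# MLP first layer: mlp_in (128) -> mlp_hidden (512).
fc1_params = linear_params(128, 512)

print(probe_params, fc1_params)  # 41536 66048
```

The 64-unit bottleneck means the probe has fewer parameters than the layer whose output it predicts, so it cannot simply copy fc1 weight for weight.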
two-phase training
Phase 1: model frozen. forward pass, detach activations, update probe only.
Phase 2: probe frozen. fresh forward pass, MSE regularizer added to task loss,
gradients update model not probe.
This is a constraint, not distillation.
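The two phases can be sketched in PyTorch as follows. This is a simplified illustration, not the page's actual training code: `TinyModel` stands in for the transformer and takes mlp_in directly as input, the probe optimizer and its learning rate are assumptions, and only the k=0 term is shown. The 0.1 weight, the MSE penalty, and the probe shape follow the text.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)


class TinyModel(nn.Module):
    """Stand-in for the MLP path: mlp_in -> fc1 + ReLU -> mlp_hidden -> logits."""

    def __init__(self, d_model=128, d_mlp=512, n_out=113):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_mlp)
        self.fc2 = nn.Linear(d_mlp, n_out)

    def forward(self, mlp_in):
        mlp_hidden = F.relu(self.fc1(mlp_in))
        return self.fc2(mlp_hidden), mlp_hidden


model = TinyModel()
probe = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 512))
model_opt = torch.optim.AdamW(model.parameters(), lr=1e-3,
                              weight_decay=1.0, betas=(0.9, 0.98))
probe_opt = torch.optim.Adam(probe.parameters(), lr=1e-3)  # assumed settings


def phase1(mlp_in):
    """Model frozen: detach activations and update the probe only."""
    with torch.no_grad():
        _, hidden = model(mlp_in)
    loss = F.mse_loss(probe(mlp_in), hidden)
    probe_opt.zero_grad()
    loss.backward()
    probe_opt.step()
    return loss.item()


def phase2(mlp_in, labels, lam=0.1):
    """Probe frozen: fresh forward pass, penalty added to the task loss."""
    for p in probe.parameters():
        p.requires_grad_(False)
    logits, hidden = model(mlp_in)
    penalty = F.mse_loss(probe(mlp_in), hidden)
    total = F.cross_entropy(logits, labels) + lam * penalty
    model_opt.zero_grad()
    total.backward()  # gradients reach the model only
    model_opt.step()
    for p in probe.parameters():
        p.requires_grad_(True)
    return total.item()
```

In this stand-in the probe's input is raw data, so the penalty's gradient reaches the model only through mlp_hidden. In the full transformer mlp_in is itself produced by the model, and whether the probe's input is detached in phase 2 is a design choice the text leaves open.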
buffer and leak safety
A circular buffer stores past activations so probes can reference nearby epochs.
No future information leaks: the buffer only stores activations that training has already computed.
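A fixed-capacity buffer of this kind can be sketched with `collections.deque`. The class name, the capacity of 3, and the lookup-by-epoch interface are assumptions for illustration; the point is that only epochs already pushed can be retrieved.

```python
from collections import deque


class ActivationBuffer:
    """Keeps activations for the most recent epochs, evicting the oldest."""

    def __init__(self, capacity=3):
        self._items = deque(maxlen=capacity)

    def push(self, epoch, mlp_in, mlp_hidden):
        # Called after an epoch's forward pass, so entries are always past data.
        self._items.append((epoch, mlp_in, mlp_hidden))

    def get(self, epoch):
        """Return (mlp_in, mlp_hidden) for a stored epoch, else None."""
        for stored_epoch, mlp_in, mlp_hidden in self._items:
            if stored_epoch == epoch:
                return mlp_in, mlp_hidden
        return None


buf = ActivationBuffer(capacity=3)
for epoch in range(5):
    # Strings stand in for activation tensors in this toy example.
    buf.push(epoch, f"in_{epoch}", f"hidden_{epoch}")

print(buf.get(4))  # most recent epoch is available
print(buf.get(1))  # evicted: older than the buffer capacity
print(buf.get(5))  # not yet computed, so nothing to leak
```

A request for an epoch that has not been pushed returns `None`, so a probe can never be trained against activations from the future.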
wavetable display