Grokking as Manifold Selection Under Constraint
A transformer trained on (a + b) mod 113 memorizes first, then suddenly generalizes. This is grokking. The baseline takes 5,760 epochs to get there. We add one thing: a small probe that asks "can you predict your own hidden state from your input?" The mismatch becomes a regularizer. Nothing else changes. Same optimizer, same architecture, same seed.
With the full stencil [-1,0,+1], the model groks at 1,920 epochs. That is 3.00x faster. The interesting part is not the speedup. It is what happens when you take the constraint apart.
A stencil ablation isolates three temporal components. Functional coherence alone (k=0) gives 2.55x. Backward admissibility alone (k=-1) gives only 1.52x. Combining backward and forward without coherence [-1,+1] gives 3.08x, the fastest pair. Adding k=-1 to k=0 gives 2.88x. The full stencil gives 3.00x. The constraint is not just regularization. It selects a manifold where the MLP's computation is geometrically consistent across time. Generalization lives on that manifold. The probe finds it faster than weight decay alone.
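The speedup figures are ratios of epochs-to-grok against the baseline. A minimal sketch of that arithmetic, using only the two epoch counts quoted above; the function names are illustrative, not from the page's code.

```python
BASELINE_EPOCHS = 5760  # baseline epochs-to-grok quoted above


def speedup(epochs_to_grok: int, baseline: int = BASELINE_EPOCHS) -> float:
    """Ratio of baseline epochs-to-grok to this run's epochs-to-grok."""
    return baseline / epochs_to_grok


def implied_epochs(factor: float, baseline: int = BASELINE_EPOCHS) -> float:
    """Invert a reported speedup factor into an approximate epoch count."""
    return baseline / factor


print(f"{speedup(1920):.2f}x")  # full stencil [-1,0,+1] -> 3.00x
```

Inverting the other reported factors with implied_epochs gives approximate epoch counts only, since the factors are rounded to two decimals.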
architecture
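import torch.nn as nn

class Transformer(nn.Module):
    # Embed, PosEmbed, Attention, MLP, Unembed, and ReLU are assumed to be defined elsewhere.
    def __init__(self):
        super().__init__()
        self.embed     = Embed(114, 128)
        self.pos_embed = PosEmbed(3, 128)
        self.attn      = Attention(128, 4, 32)   # d_model, n_heads, d_head
        self.mlp       = MLP(128, 512, ReLU)     # d_model, d_mlp, activation
        self.unembed   = Unembed(128, 113)

    def forward(self, tokens):  # [batch, 3]
        x = self.embed(tokens) + self.pos_embed(tokens)
        x = x + self.attn(x)    # residual connection around attention
        x = x + self.mlp(x)     # residual connection around the MLP
        return self.unembed(x)[:, -1, :]  # logits at the last position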
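The listing above depends on modules defined elsewhere. For a version that runs standalone, here is a minimal PyTorch sketch with the same shapes. It is a stand-in, not the page's implementation: the class name TinyTransformer is hypothetical, it uses nn.Embedding and nn.Linear with default biases, omits LayerNorm, and applies a causal mask because the setup describes a decoder-only model. Using token id 113 for "=" is also an assumption.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyTransformer(nn.Module):
    """Hypothetical stand-in: 1 layer, d_model=128, 4 heads of 32, d_mlp=512."""

    def __init__(self, vocab=114, n_ctx=3, d_model=128, n_heads=4, d_mlp=512, n_out=113):
        super().__init__()
        self.n_heads = n_heads
        self.embed = nn.Embedding(vocab, d_model)
        self.pos_embed = nn.Embedding(n_ctx, d_model)
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.attn_out = nn.Linear(d_model, d_model)
        self.fc1 = nn.Linear(d_model, d_mlp)
        self.fc2 = nn.Linear(d_mlp, d_model)
        self.unembed = nn.Linear(d_model, n_out)

    def attn(self, x):
        batch, seq, d_model = x.shape
        # Split into heads: [batch, n_heads, seq, d_head]
        q, k, v = (
            t.reshape(batch, seq, self.n_heads, -1).transpose(1, 2)
            for t in self.qkv(x).chunk(3, dim=-1)
        )
        scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
        # Causal mask: position i may not attend to positions after i.
        future = torch.triu(torch.ones(seq, seq, dtype=torch.bool, device=x.device), diagonal=1)
        scores = scores.masked_fill(future, float("-inf"))
        out = scores.softmax(dim=-1) @ v
        return self.attn_out(out.transpose(1, 2).reshape(batch, seq, d_model))

    def forward(self, tokens):  # tokens: [batch, 3]
        positions = torch.arange(tokens.shape[1], device=tokens.device)
        x = self.embed(tokens) + self.pos_embed(positions)
        x = x + self.attn(x)
        mlp_in = x
        mlp_hidden = F.relu(self.fc1(mlp_in))  # the activations a probe would target
        x = x + self.fc2(mlp_hidden)
        return self.unembed(x)[:, -1, :]  # logits at the final position


model = TinyTransformer()
tokens = torch.tensor([[5, 7, 113], [100, 50, 113]])  # [a, b, "="], "=" assumed to be id 113
print(model(tokens).shape)  # torch.Size([2, 113])
```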
setup
TASK ........... (a + b) mod p
PRIME p ........ 113
ARCH ........... 1-layer decoder-only transformer
d_model ........ 128      n_heads ... 4
d_mlp .......... 512      activation  ReLU

OPTIMIZER ...... AdamW    lr ........ 1e-3
WEIGHT DECAY ... 1.0      betas ..... (0.9, 0.98)
BATCH .......... full     train ..... 30%
SEED ........... 999      epochs .... 25,000
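The task and split in the table can be sketched in plain Python. How the page shuffles and seeds is not stated, so this uses Python's random module with seed 999 as an assumption and will not reproduce the page's exact split. The [a, b, "="] token layout with "=" as id 113 is also assumed, chosen to match the 114-entry embedding.

```python
import random

P = 113           # modulus from the setup table
TRAIN_FRAC = 0.3  # 30% of all pairs used for training
SEED = 999


def make_split(p=P, train_frac=TRAIN_FRAC, seed=SEED):
    """Enumerate every (a, b) pair, shuffle once, and cut at train_frac."""
    # Each example is ((a, b, "=" token id), label); "=" is assumed to be id p.
    examples = [((a, b, p), (a + b) % p) for a in range(p) for b in range(p)]
    random.Random(seed).shuffle(examples)
    n_train = int(train_frac * len(examples))
    return examples[:n_train], examples[n_train:]


train, test = make_split()
print(len(train), len(test))  # 3830 8939
```

With full-batch training, all 3,830 training examples go through the model in a single step each epoch.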
A live reproduction of Nanda et al. (2023), "Progress Measures for Grokking via Mechanistic Interpretability" (ICLR Spotlight). Real weights, real gradients, real grokking. After the phase transition, Fourier analysis reveals that the model has learned a discrete Fourier transform: integers become rotations on a circle, composed via trig identities. Grokking is detected when both train and test accuracy reach 99% (Nanda et al. report >99.95% post-grok). Training then continues until accuracy settles into a stable basin at 99% or above.
Nanda et al. (2023). Progress Measures for Grokking via Mechanistic Interpretability. ICLR Spotlight.
Power et al. (2022). Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets. arXiv:2201.02177.
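The "rotations composed via trig identities" description can be checked numerically. The sketch below is hand-written, not extracted from trained weights: it represents a and b as angles w*a and w*b with w = 2*pi*k/p, composes them with the angle-addition identities, and scores each candidate c by cos(w(a + b - c)), which peaks at c = (a + b) mod p. The frequencies in FREQS are arbitrary illustrative values, an assumption rather than something read off the model.

```python
import math

P = 113
FREQS = (14, 35, 41)  # arbitrary illustrative frequencies, not read from a trained model


def fourier_add(a: int, b: int, p: int = P, freqs=FREQS) -> int:
    """Compute (a + b) mod p using only rotations and trig identities."""
    logits = [0.0] * p
    for k in freqs:
        w = 2 * math.pi * k / p
        # Angle-addition identities compose the two rotations.
        cos_ab = math.cos(w * a) * math.cos(w * b) - math.sin(w * a) * math.sin(w * b)
        sin_ab = math.sin(w * a) * math.cos(w * b) + math.cos(w * a) * math.sin(w * b)
        for c in range(p):
            # cos(w(a+b-c)) = cos(w(a+b))cos(wc) + sin(w(a+b))sin(wc)
            logits[c] += cos_ab * math.cos(w * c) + sin_ab * math.sin(w * c)
    return max(range(p), key=logits.__getitem__)


print(fourier_add(100, 50), (100 + 50) % P)  # 37 37
```

Because p is prime, every term is strictly below 1 unless c equals (a + b) mod p, so the correct residue always has the largest score.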
Baseline Reproduction
Standard training without introspection constraints. Grokking emerges from weight decay alone.
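The grokking epoch of a run can be read off its accuracy curves with the 99% criterion stated earlier. A minimal sketch; the function name and the synthetic curves are illustrative, not logged data.

```python
from typing import Optional, Sequence

THRESHOLD = 0.99  # both train and test accuracy must reach this


def grok_epoch(train_acc: Sequence[float], test_acc: Sequence[float],
               threshold: float = THRESHOLD) -> Optional[int]:
    """Return the first epoch index where both accuracies reach the threshold."""
    for epoch, (tr, te) in enumerate(zip(train_acc, test_acc)):
        if tr >= threshold and te >= threshold:
            return epoch
    return None


# Synthetic curves for illustration only: train saturates early, test lags.
train_curve = [0.50, 1.00, 1.00, 1.00, 1.00]
test_curve = [0.01, 0.02, 0.30, 0.995, 1.00]
print(grok_epoch(train_curve, test_curve))  # 3
```

The gap between the epoch where train accuracy saturates and the epoch returned here is the memorize-then-generalize delay that defines grokking.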
stencil ablation
A small probe is added alongside the MLP. Each epoch, the probe learns to predict mlp_hidden from mlp_in, and the resulting mismatch then regularizes the model. The probe is discarded after training. Each offset subset is tested by adding the corresponding stencil to a run. Expect an earlier jump in test accuracy and possibly a loss spike during the transition.
L_total = L_task + 0.1 * ||f(mlp_in) - mlp_hidden||^2
temporal offsets
k = -1  backward admissibility   mlp_in(t)   -> mlp_hidden(t-1)   "does the current input predict the previous hidden state?"
k = 0   functional coherence     mlp_in(t)   -> mlp_hidden(t)     "does the current input predict the current hidden state?"
k = +1  forward lookahead        mlp_in(t-1) -> mlp_hidden(t)     "does the previous input predict the current hidden state?"
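The offset table can be written as a lookup from k to the pair of epochs whose activations are matched. A sketch under the reading that t indexes training epochs; OFFSET_SHIFTS and stencil_pairs are hypothetical names.

```python
from typing import Dict, Iterable, Tuple

# Offset k -> (input epoch shift, target epoch shift), as listed above.
OFFSET_SHIFTS: Dict[int, Tuple[int, int]] = {
    -1: (0, -1),  # mlp_in(t)   -> mlp_hidden(t-1)
    0: (0, 0),    # mlp_in(t)   -> mlp_hidden(t)
    +1: (-1, 0),  # mlp_in(t-1) -> mlp_hidden(t)
}


def stencil_pairs(t: int, stencil: Iterable[int]) -> Dict[int, Tuple[int, int]]:
    """Map each offset in the stencil to its (input epoch, target epoch) pair."""
    pairs = {}
    for k in stencil:
        d_in, d_out = OFFSET_SHIFTS[k]
        pairs[k] = (t + d_in, t + d_out)
    return pairs


print(stencil_pairs(10, [-1, +1]))  # {-1: (10, 9), 1: (9, 10)}
```

No pair ever refers to an epoch later than t, which is why the forward offset is written with the input shifted back rather than the target shifted forward.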
mechanism
MLP path: mlp_in (R^128) -> fc1 + ReLU -> mlp_hidden (R^512). Probe: 128 -> 64 -> ReLU -> 512, trained with a separate optimizer and discarded after training.
two-phase training
Phase 1: model frozen. Forward pass, detach activations, update the probe only. Phase 2: probe frozen. Fresh forward pass, MSE regularizer added to the task loss; gradients update the model, not the probe. This is a constraint, not distillation.
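The two phases can be sketched in PyTorch for the k = 0 offset. Everything outside the description above is an assumption: ToyModel is a hypothetical stand-in that only exposes mlp_in and mlp_hidden, the data is random, the probe uses plain Adam, F.mse_loss (a mean) stands in for the squared norm in L_total, and the regularizer's target is left attached so gradients reach the model through both mlp_in and mlp_hidden.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
LAMBDA = 0.1  # regularizer weight from L_total


class ToyModel(nn.Module):
    """Hypothetical stand-in exposing mlp_in and mlp_hidden like the MLP path."""

    def __init__(self, d_model=128, d_mlp=512, n_out=113):
        super().__init__()
        self.pre = nn.Linear(d_model, d_model)  # stands in for embed + attention
        self.fc1 = nn.Linear(d_model, d_mlp)
        self.fc2 = nn.Linear(d_mlp, n_out)

    def forward(self, x):
        mlp_in = self.pre(x)
        mlp_hidden = F.relu(self.fc1(mlp_in))
        return self.fc2(mlp_hidden), mlp_in, mlp_hidden


model = ToyModel()
probe = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 512))
model_opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1.0, betas=(0.9, 0.98))
probe_opt = torch.optim.Adam(probe.parameters(), lr=1e-3)  # probe settings are assumed

x = torch.randn(32, 128)          # random stand-in inputs
y = torch.randint(0, 113, (32,))  # random stand-in labels


def train_epoch() -> float:
    # Phase 1: model frozen. Activations carry no graph; only the probe learns.
    with torch.no_grad():
        _, mlp_in, mlp_hidden = model(x)
    probe_loss = F.mse_loss(probe(mlp_in), mlp_hidden)
    probe_opt.zero_grad()
    probe_loss.backward()
    probe_opt.step()

    # Phase 2: probe frozen. Fresh forward pass; only the model optimizer steps.
    logits, mlp_in, mlp_hidden = model(x)
    total = F.cross_entropy(logits, y) + LAMBDA * F.mse_loss(probe(mlp_in), mlp_hidden)
    model_opt.zero_grad()
    total.backward()
    model_opt.step()  # probe_opt.step() is not called, so probe weights stay fixed
    return total.item()


print(f"L_total after one epoch: {train_epoch():.3f}")
```

Freezing here means "not stepped": the probe still receives gradients in phase 2, but they are cleared by probe_opt.zero_grad() at the start of the next phase 1 and never applied.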
buffer and leak safety
A circular buffer stores past activations so probes can reference nearby epochs. No future information leaks: the buffer only stores what training has already computed.
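A circular buffer with this property can be sketched with collections.deque. ActivationBuffer is a hypothetical helper, and the capacity of 2 is an assumption sized for offsets of at most one epoch; strings stand in for detached activation tensors.

```python
from collections import deque


class ActivationBuffer:
    """Fixed-size ring of (epoch, mlp_in, mlp_hidden) snapshots. Hypothetical helper."""

    def __init__(self, capacity: int = 2):
        self._ring = deque(maxlen=capacity)  # oldest entry is dropped automatically

    def push(self, epoch, mlp_in, mlp_hidden):
        self._ring.append((epoch, mlp_in, mlp_hidden))

    def get(self, epoch):
        """Return the snapshot for an epoch, or None if it is not stored."""
        for stored_epoch, mlp_in, mlp_hidden in self._ring:
            if stored_epoch == epoch:
                return mlp_in, mlp_hidden
        return None


buf = ActivationBuffer(capacity=2)
for epoch in range(4):
    buf.push(epoch, f"in{epoch}", f"hid{epoch}")  # strings stand in for detached tensors

print(buf.get(2))  # ('in2', 'hid2')
print(buf.get(0))  # None (evicted)
print(buf.get(4))  # None (not computed yet)
```

A lookup for an epoch that has not run yet returns None because nothing is written to the buffer until that epoch's forward pass has happened.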
wavetable display