In this post I analyze the puzzle from puzzles.baulab.info. The task is to reverse-engineer a 9,600-parameter attention-only transformer trained to count unique tokens. See this Colab notebook for the corresponding code.

Task

Given a sequence of 10 tokens drawn from a vocabulary of 10 symbols, predict the number of distinct symbols. The released model hits 100% accuracy on every count from 1 through 10 on a held-out test set. Reverse-engineer the algorithm it learned.
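To make the target concrete, here is a minimal sketch of the labelling function. The token layout (BOS, then the 10 symbols, then ANS) and the helper names are assumptions for illustration; the puzzle statement above only fixes the vocabulary and the sequence length.

```python
SYMBOLS = "abcdefghij"  # the 10 input symbols


def count_unique(sequence):
    """Ground-truth label: number of distinct symbols in the sequence."""
    return len(set(sequence))


def make_example(sequence):
    """Hypothetical layout: BOS, the 10 symbols, ANS; target is a count token."""
    assert len(sequence) == 10 and all(s in SYMBOLS for s in sequence)
    tokens = ["BOS", *sequence, "ANS"]
    target = f"#{count_unique(sequence)}"
    return tokens, target


tokens, target = make_example("abbacadaba")
print(tokens, target)  # the sequence contains a, b, c, d, so the target is "#4"
```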

   
Input symbols: 10 (rendered as a–j)
Sequence length: 10 input tokens
Layers: 2
Heads per layer: 4
d_model: 32
Parameters: 9,600
Vocab: 22 tokens (10 input symbols + BOS + ANS + 10 count tokens #1–#10)
Positional embeddings: none
Causal mask: yes

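Attention-only: no MLPs, no LayerNorm, no positional embeddings. In the analysis below I number the layers from zero, so "L0" is the first attention layer and "L1" is the second.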

token_embedding
    → attention layer 1 (4 heads, causal mask) + residual
    → attention layer 2 (4 heads, causal mask) + residual
    → linear unembed → logits
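The architecture above can be sketched in a few lines of NumPy. This is a sketch under stated assumptions, not the released model: the weights are random, and I assume no bias terms, separate embedding and unembedding matrices, and the usual 1/sqrt(d_head) score scaling. Under those assumptions the weight count comes to 704 + 8,192 + 704 = 9,600, which matches the stated parameter total.

```python
import numpy as np

VOCAB, D_MODEL, N_LAYERS, N_HEADS = 22, 32, 2, 4
D_HEAD = D_MODEL // N_HEADS


def init_params(seed=0):
    """Random weights with the stated shapes (assumed: no biases, untied unembed)."""
    rng = np.random.default_rng(seed)
    layers = []
    for _ in range(N_LAYERS):
        layers.append({
            "W_Q": rng.normal(0.0, 0.1, (N_HEADS, D_MODEL, D_HEAD)),
            "W_K": rng.normal(0.0, 0.1, (N_HEADS, D_MODEL, D_HEAD)),
            "W_V": rng.normal(0.0, 0.1, (N_HEADS, D_MODEL, D_HEAD)),
            "W_O": rng.normal(0.0, 0.1, (N_HEADS, D_HEAD, D_MODEL)),
        })
    return {
        "embed": rng.normal(0.0, 0.1, (VOCAB, D_MODEL)),
        "layers": layers,
        "unembed": rng.normal(0.0, 0.1, (D_MODEL, VOCAB)),
    }


def count_params(params):
    n = params["embed"].size + params["unembed"].size
    for layer in params["layers"]:
        n += sum(w.size for w in layer.values())
    return n


def attention_layer(x, layer):
    """Sum of per-head outputs; position i may only attend to positions <= i."""
    seq_len = x.shape[0]
    causal = np.tril(np.ones((seq_len, seq_len), dtype=bool))
    out = np.zeros_like(x)
    for h in range(N_HEADS):
        q, k, v = x @ layer["W_Q"][h], x @ layer["W_K"][h], x @ layer["W_V"][h]
        scores = np.where(causal, q @ k.T / np.sqrt(D_HEAD), -np.inf)
        scores = scores - scores.max(axis=-1, keepdims=True)  # stable softmax
        attn = np.exp(scores)
        attn = attn / attn.sum(axis=-1, keepdims=True)
        out += attn @ v @ layer["W_O"][h]
    return out


def forward(token_ids, params):
    x = params["embed"][token_ids]        # token embedding, no positions
    for layer in params["layers"]:
        x = x + attention_layer(x, layer)  # residual connection
    return x @ params["unembed"]           # logits over the 22-token vocab


params = init_params()
logits = forward(np.arange(12), params)   # 12 arbitrary token ids
print(count_params(params), logits.shape)
```

Because there are no positional embeddings, the only position information available to this network comes from the causal mask.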

Submission

I claim that the model partitions the 10-symbol vocabulary into four groups, each handled by a dedicated mechanism.

Token group | Handled by | Mechanism summary
b | L1 head 0 | Self-attends in L0 → distinctive residual → detected directly
f | L1 head 1 | Mostly self-attends in L0 → invisible to anchor circuit → detected directly
a, h, d, g | L1 head 2 | L0 heads write anchor OVs into subsequent positions → fingerprint read at last position
c, e, i, j | L1 head 3 | L0 heads write ceij tokens to ANS → destructive interference with query → BOS attenuation
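At the level of the algorithm (not the weights), the partition amounts to the following decomposition of the count. The function names are illustrative; the group assignments are the ones in the table.

```python
# Token groups and the layer 1 head claimed to handle each of them.
GROUPS = {
    "L1 head 0": set("b"),
    "L1 head 1": set("f"),
    "L1 head 2": set("ahdg"),
    "L1 head 3": set("ceij"),
}


def count_by_group(sequence):
    """Number of distinct symbols from each group that appear in the sequence."""
    present = set(sequence)
    return {head: len(present & group) for head, group in GROUPS.items()}


def total_count(sequence):
    """Summing the four per-group counts recovers the number of unique symbols."""
    return sum(count_by_group(sequence).values())


seq = "abbfccgjaa"
print(count_by_group(seq), total_count(seq))
```

Since the four groups are disjoint and together cover all 10 symbols, the sum always equals `len(set(sequence))`.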

Below is some evidence that the algorithm is split into these four groups, and an explanation of each group’s mechanism in detail.

The Four-Way Partition

The first piece of evidence is that each layer 1 head is causally necessary for exactly its own token group. I test this by taking subsets of the input space containing tokens belonging to different groups. Mean-ablating a head drops the count accuracy to 0% when its corresponding tokens are present, but leaves accuracy at 100% when they are absent. Interestingly, the prediction drops by approximately the number of unique tokens from that group present in the input, highlighting the additive nature of the counting procedure in layer 1.
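For readers unfamiliar with mean ablation, here is a generic sketch of the operation on an array of per-head outputs. The array layout `(batch, n_heads, d_model)` and the choice of reference batch for the mean are assumptions; the text above does not specify which inputs the mean is taken over.

```python
import numpy as np


def mean_ablate(head_outputs, head, reference=None):
    """Replace one head's output with its mean over a reference batch.

    head_outputs: array of shape (batch, n_heads, d_model), each head's
    write to the residual stream at the position of interest.
    reference: batch used to compute the mean (defaults to head_outputs).
    """
    reference = head_outputs if reference is None else reference
    ablated = head_outputs.copy()
    ablated[:, head, :] = reference[:, head, :].mean(axis=0)
    return ablated


rng = np.random.default_rng(0)
outputs = rng.normal(size=(8, 4, 32))   # stand-in activations, not model data
ablated = mean_ablate(outputs, head=2)
```

The ablated head then carries no input-specific information, while the other three heads are left untouched.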

Mechanism 1: Counting b

For this mechanism I claim that b has strongly negative key scores in all four layer 0 heads, causing it to always self-attend rather than accumulate information from other tokens. This gives b a unique residual after layer 0. Layer 1 head 0 then attends directly to b if present (else falls back to BOS), and the difference in value outputs between these two cases produces a clean +1 rightward shift in the count logits.

Key results:

  • b attends only to itself or BOS in Layer 0 (1.000); other tokens never attend to b (0.000)
  • L1 head 0 attends to b if present, else BOS (1.000 both ways)
  • OV(b) − OV(BOS) is consistently rightward, encoding +1 to the count
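The "attend to b if present, else BOS" behaviour is what a softmax produces when one key scores far above a fallback key and everything else scores far below both. The scores below are made up for illustration and are not measured from the model.

```python
import numpy as np


def softmax(scores):
    z = np.asarray(scores, dtype=float)
    e = np.exp(z - z.max())
    return e / e.sum()


# Hypothetical pre-softmax scores for a "detect b, else BOS" head.
TOY_SCORE = {"BOS": 5.0, "b": 20.0}
OTHER = -10.0  # every other token is effectively invisible


def attention_by_token(tokens):
    """Total attention mass per token type, as seen from the final position."""
    weights = softmax([TOY_SCORE.get(t, OTHER) for t in tokens])
    mass = {}
    for tok, w in zip(tokens, weights):
        mass[tok] = mass.get(tok, 0.0) + float(w)
    return mass


with_b = attention_by_token(["BOS", "a", "b", "c", "b"])
without_b = attention_by_token(["BOS", "a", "c", "d", "e"])
print(round(with_b["b"], 3), round(without_b["BOS"], 3))
```

Note that attention weights sum to one, so in this toy two copies of b share the same total mass that a single b would receive. If every b position carries the same value vector, the head's output is the same whether b appears once or several times, which is the behaviour a unique-token counter needs.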

Mechanism 2: Counting f

For this mechanism I claim that f mostly self-attends in Layer 0 (62% self-attention) but, unlike b, it also reads some anchor content. Crucially, however, f has strongly negative key scores in layer 1 head 2, so it is still invisible to the anchor counting circuit (0.000 attention from head 2). Instead, layer 1 head 1 detects f directly: it attends to f if present, else falls back to {g,d,h}, and the difference in value outputs again encodes +1 to the count.

Key results:

  • f mostly self-attends in Layer 0 (62%)
  • L1 head 2 never attends to f (0.000), so f is invisible to the anchor circuit
  • L1 head 1 attends to f if present, else g/d/h
  • OV(f) − OV(g/d/h) projects rightward, encoding exactly +1
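It may help to see how a single fixed vector can add +1 to the prediction regardless of the current count. The toy below is my own illustration, not the model's measured logits: I assume logits of the form n·c − c²/2 over the count tokens c = 1..10, which peak at c = n, and take "rightward" to mean a direction whose logit grows linearly with c.

```python
import numpy as np

COUNTS = np.arange(1, 11)  # count tokens #1..#10


def toy_logits(n):
    """Toy logits n*c - c^2/2, whose argmax over c is the count n."""
    return n * COUNTS - COUNTS**2 / 2


# A fixed "rightward" direction: the logit for #c increases linearly with c.
RAMP = COUNTS.astype(float)


def predicted(logits):
    return int(COUNTS[np.argmax(logits)])


print(predicted(toy_logits(3)), predicted(toy_logits(3) + RAMP))  # 3 4
```

In this toy, adding the ramp turns `toy_logits(n)` into `toy_logits(n + 1)` exactly, so independent heads that each add some multiple of the same direction combine additively. Whether the real count logits have this shape is something the notebook would have to confirm.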

Mechanism 3: Counting a, h, d, g

Each layer 0 head is dedicated to one anchor token (head 0→a, head 1→h, head 2→d, head 3→g). When an anchor is present, every subsequent non-b/f/ANS token attends to it with probability 1.000, causing the anchor’s OV vector to be written into all subsequent residuals. This creates an accumulated fingerprint encoding which anchors appeared. Layer 1 head 2 assigns higher key scores to later positions, which carry richer fingerprints, and its value circuit produces a rightward logit shift proportional to how many unique anchors were present.

Key results:

  • Every non-b/f token after an anchor attends to it: 1.000 for all four heads
  • Forcing L1 head 2 to attend to the last non-b/f position gives 1.000 accuracy; forcing it to the first position gives 0.002. The last position always contains the full fingerprint.
  • Head 2 logit contribution scales with anchor count (going from 0 to 4 anchors produces a monotonically increasing rightward shift)
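The fingerprint idea can be simulated symbolically, with one bit per layer 0 head standing in for that head's anchor OV write. This is a simplified sketch: it assumes an anchor's own position also carries its bit (the causal mask allows self-attention), and it ignores the actual vector geometry.

```python
ANCHOR_HEAD = {"a": 0, "h": 1, "d": 2, "g": 3}  # anchor -> L0 head, as claimed above
SKIPPED = {"b", "f"}  # tokens that stay outside the anchor circuit


def fingerprints(sequence):
    """Per-position fingerprint: which anchors have appeared so far.

    Returns one 4-tuple of bits per position, or None at b/f positions.
    """
    seen = [0, 0, 0, 0]
    out = []
    for tok in sequence:
        if tok in ANCHOR_HEAD:
            seen[ANCHOR_HEAD[tok]] = 1  # a repeated anchor sets the same bit once
        out.append(None if tok in SKIPPED else tuple(seen))
    return out


def anchor_count_at(sequence, which):
    """Anchors visible at the first or last non-b/f position (0 if none exists)."""
    prints = [p for p in fingerprints(sequence) if p is not None]
    if not prints:
        return 0
    return sum(prints[0] if which == "first" else prints[-1])


seq = "cabhfaadbf"
print(anchor_count_at(seq, "first"), anchor_count_at(seq, "last"))  # 0 3
```

Reading the first position misses every anchor that appears later, which is consistent with the accuracy collapsing when head 2 is forced to attend there.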

Mechanism 4: Counting c, e, i, j

Since mechanism 3 already uses each layer 0 head to propagate one anchor token, I claim that the same heads handle the remaining tokens by writing to the ANS position instead, which changes the query in layer 1. Each layer 0 head writes one designated secondary token to the ANS position (head 0→i, head 1→j, head 2→c, head 3→e). These OV vectors accumulate in the ANS residual, which becomes the query for all layer 1 heads. Each ceij token’s OV contribution points approximately opposite to the baseline ANS query (~0.98 cosine similarity with −q_base), causing destructive interference that progressively shrinks the query magnitude. This reduces ANS’s attention to BOS in Layer 1 head 3. Since BOS provides a strong low-count baseline signal, attenuating it raises the predicted count by approximately +1 per unique ceij token present.

Key results:

  • Each L0 head writes one ceij token to ANS: head 0→i, head 1→j, head 2→c, head 3→e. Forcing this assignment gives 1.000 accuracy
  • Each ceij token’s query shift is ~0.98 anti-parallel to the baseline query (destructive interference)
  • q·k_BOS drops monotonically with ceij count: 19.2 → 11.6 → 8.9 → 6.3 → 3.6
  • BOS logit contribution shrinks with each additional ceij token, reducing the low-count bias
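The geometry of the destructive interference can be illustrated with synthetic vectors. Everything here is constructed, not taken from the model: the baseline query is random, the BOS key is set equal to it for simplicity, and each shift is built as −0.2·q_base plus a small orthogonal component chosen so that its cosine with −q_base is about 0.98. The toy shows the direction of the effect only and does not reproduce the measured q·k_BOS values above.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 32
q_base = rng.normal(size=d)   # synthetic baseline ANS query
k_bos = q_base.copy()         # toy choice: BOS key aligned with that query


def unit_orthogonal_to(v, rng):
    """Random unit vector with its component along v projected out."""
    r = rng.normal(size=v.shape)
    r = r - (r @ v) / (v @ v) * v
    return r / np.linalg.norm(r)


def make_shift(q, rng, along=0.2, across=0.04):
    """Query shift that mostly cancels q, plus a small off-axis part."""
    return -along * q + across * np.linalg.norm(q) * unit_orthogonal_to(q, rng)


def cosine(u, v):
    return float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v)))


shifts = [make_shift(q_base, rng) for _ in range(4)]  # one per ceij token
scores = [float((q_base + sum(shifts[:n], np.zeros(d))) @ k_bos) for n in range(5)]

print([round(cosine(s, -q_base), 3) for s in shifts])
print([round(s, 2) for s in scores])  # q·k_BOS for 0..4 ceij tokens present
```

Each additional shift removes a fixed fraction of the query's projection onto the BOS key, so the BOS score falls step by step as more distinct ceij tokens are written into ANS.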

Final Summary

This model solves unique token counting through a clean four-way partition of the vocabulary, with each subset handled by a dedicated layer 1 head using information prepared by Layer 0.

The key insight is that Layer 0 serves a different role for each group:

  • b, f: isolated from the rest of the circuit by self-attention, leaving distinctive residuals for direct detection
  • a, h, d, g: each head propagates its anchor’s presence forward through the sequence via OV writes, accumulating a fingerprint at later positions
  • c, e, i, j: each head writes its designated secondary token into the ANS residual, which then alters the query for layer 1

Layer 1 then reads off these signals independently:

  • Heads 0 and 1: binary detectors for b and f, contributing exactly +1 when present
  • Head 2: reads the anchor fingerprint from the last non-b/f position, contributing +1 per unique anchor
  • Head 3: its baseline query points toward BOS (low-count bias); each ceij token destructively interferes with this query, attenuating BOS attention and contributing +1 per unique ceij token

The four contributions combine additively to produce the correct total count, achieving 100% accuracy across all counts 1–10. The partition covers all 10 symbols, and no symbol is counted by more than one head.