👉 See Ablation Study: What Drives HRM Performance?
This is an implementation of the Hierarchical Reasoning Model (HRM) proposed by Guan Wang et al. I built it for educational purposes, with a few minor simplifications and extensions (see Modifications to the Original Work).
The architecture is inspired by hierarchical, multi-timescale processing in the human brain. It uses two connected recurrent modules running at different frequencies:
- H module (slower): handles abstract planning
- L module (faster): handles low-level computations
Both are based on self-attention. Together, this is an attempt to model reasoning in latent space.
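At a high level, the two-timescale interaction can be sketched as a nested loop, with the fast L module taking several steps per single step of the slow H module. This is only a minimal illustration: linear layers stand in for the attention-based modules, and all names and cycle counts are placeholders, not the repository's API.

```python
import torch
import torch.nn as nn

class TwoTimescaleCore(nn.Module):
    """Sketch of the H/L interaction: the fast L module runs several
    steps per single update of the slow H module."""
    def __init__(self, d_model: int = 64, h_cycles: int = 2, l_cycles: int = 2):
        super().__init__()
        self.h_cycles, self.l_cycles = h_cycles, l_cycles
        # Linear stand-ins for the self-attention-based reasoning modules.
        self.h_module = nn.Linear(2 * d_model, d_model)
        self.l_module = nn.Linear(2 * d_model, d_model)

    def forward(self, x, z_h, z_l):
        for _ in range(self.h_cycles):
            for _ in range(self.l_cycles):
                # L (fast): updated every inner step, sees input + current plan.
                z_l = torch.tanh(self.l_module(torch.cat([x + z_h, z_l], dim=-1)))
            # H (slow): updated once per outer cycle, sees L's latest result.
            z_h = torch.tanh(self.h_module(torch.cat([z_h, z_l], dim=-1)))
        return z_h, z_l

core = TwoTimescaleCore()
x = torch.randn(1, 10, 64)
z_h, z_l = core(x, torch.zeros(1, 10, 64), torch.zeros(1, 10, 64))
```

In the real model, only the last H state feeds the output projection.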
The model is applied to a pathfinding task: given an N×N board with obstacles, find the shortest path from START to END. The animation below shows actual inference steps as the model incrementally discovers the path.
10x10 board | 20x20 board |
---|---|
Legend: . = Floor, # = Wall, S = Start point, E = End point, * = Path
Dependencies:
python3 -m pip install torch pillow
Train a model:
python3 boardpath.py --mode train
Run inference on a random board (also saves an animated GIF of the steps):
python3 boardpath.py --mode inference
To adjust the task, model, or training setup, edit `get_config()` and `get_train_config()` in `boardpath.py`. For example:
- Board size & obstacle density: `board_size`, `wall_prob`
- Embedding dimensionality (representation of each board cell): `d_model`

All parameters are documented in `hrm/hrm.py`.
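As a rough illustration, a config function in this style might look like the sketch below. Only `board_size`, `wall_prob`, and `d_model` are named in this README, so treat everything else (including the exact return type) as an assumption; the real function lives in `boardpath.py`.

```python
# Hypothetical sketch of a get_config()-style function; only the three
# field names below appear in the README, the dict shape is assumed.
def get_config():
    return {
        "board_size": 20,  # N for the N×N board
        "wall_prob": 0.3,  # probability that a cell is an obstacle
        "d_model": 256,    # embedding dimensionality per board cell
    }

cfg = get_config()
```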
- Fixed "think time" (that is, fixed # of segments, both in train & inference) instead of Q-learning / ACT halting
- PyTorch SDPA instead of FlashAttention (same results, but slower)
- RoPE and learned positional encodings can be used together
- Standard initialization instead of truncated LeCun normal
- Learnable (or fixed) initial H and L states
- `nn.RMSNorm` with learnable scale (`elementwise_affine=True`)
- Slight differences in where `nn.Dropout` is applied
- `HRM` = `InputEmbedding` + two `ReasoningModule` instances (H and L) + linear projection (on last H state)
- `InputEmbedding` = `nn.Embedding` for tokens + optional absolute positional `nn.Embedding` (added to token embeddings)
- `ReasoningModule` = stack of `HRMBlock` instances
- `HRMBlock` has two sublayers:
  - Attention: `SDPAttention` → `Dropout` → residual → `RMSNorm` (post-norm)
  - MLP: `SwiGLU` → `Dropout` → residual → `RMSNorm`
- `SDPAttention` = scaled dot-product attention + linear projection, with optional RoPE
Recent discussion around the Hierarchical Reasoning Model (HRM) asks whether the new two-timescale H/L architecture is really what drives its performance.
To explore this, I ran a simple set of ablations on a board pathfinding task (20×20 boards, wall probability = 0.3). Each variant was trained on 2000 training boards and validated on 500 boards, for 40 epochs. Models were parameter-matched: H/L variants used d_model=256 (~6.29M parameters), single-module variants used d_model=360 (~6.23M parameters).
This is a small study on a relatively simple task, so the results should be taken as illustrative rather than definitive.
- Architecture (variant):
  - H-only (detached): single `ReasoningModule` unrolled for H×L steps, hidden state detached.
  - H-only (bptt): same as above, but with full backpropagation-through-time (BPTT).
  - H/L: full HRM with separate H and L modules.
- Training segments (train_seg): number of refinement segments during training.
- Inference segments (infer_seg): number of refinement segments during evaluation.
- Cycles: number of inner iterations (H×L).
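The difference between the detached and BPTT single-module variants comes down to whether gradients flow through the unrolled hidden state. A schematic sketch, using a generic `step` stand-in rather than the repository's API: cutting the state before each step means only the final step receives gradient.

```python
import torch
import torch.nn as nn

step = nn.Linear(8, 8)  # stand-in for one reasoning-module iteration
x = torch.randn(4, 8)

def unroll(n_steps: int, detach: bool):
    z = torch.zeros(4, 8)
    for _ in range(n_steps):
        if detach:
            z = z.detach()  # detached: cut gradient flow through time
        z = torch.tanh(step(z) + x)
    return z

z_bptt = unroll(8, detach=False)     # gradient flows through all 8 steps
z_detached = unroll(8, detach=True)  # gradient only through the last step
```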
In the table:
- board acc = accuracy of predicting the entire board correctly (last 5 epochs average).
- acc4x = the same metric, but with 4× more inference segments (extra test-time refinement).
- Gap = acc4x – board acc, showing how much accuracy improves with additional inference steps.
Variant | train_seg | infer_seg | H×L cycles | Board acc | acc4x | Gap |
---|---|---|---|---|---|---|
H/L | 2 | 2 | 2×2 (4) | 0.390 | 0.382 | -0.008 |
H/L | 4 | 2 | 2×2 (4) | 0.320 | 0.425 | +0.104 |
H/L | 2 | 2 | 2×4 (8) | 0.403 | 0.444 | +0.041 |
H/L | 4 | 2 | 2×4 (8) | 0.447 | 0.552 | +0.105 |
H/L | 2 | 2 | 4×2 (8) | 0.458 | 0.481 | +0.023 |
H/L | 4 | 2 | 4×2 (8) | 0.523 | 0.545 | +0.022 |
H-only (bptt) | 2 | 2 | 4 | 0.226 | 0.272 | +0.046 |
H-only (bptt) | 4 | 2 | 4 | 0.574 | 0.625 | +0.052 |
H-only (bptt) | 2 | 2 | 8 | 0.376 | 0.394 | +0.018 |
H-only (detached) | 2 | 2 | 4 | 0.158 | 0.182 | +0.024 |
H-only (detached) | 4 | 2 | 4 | 0.347 | 0.436 | +0.089 |
H-only (detached) | 2 | 2 | 8 | 0.222 | 0.226 | +0.005 |
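Averaging the Gap column of the table above by train_seg makes the segment effect concrete. This is a quick back-of-the-envelope check on the reported numbers, not part of the training code.

```python
# Gap values copied from the ablation table, grouped by train_seg.
gaps = {
    2: [-0.008, 0.041, 0.023, 0.046, 0.018, 0.024, 0.005],
    4: [0.104, 0.105, 0.022, 0.052, 0.089],
}
# Mean refinement gap per group: ≈ 0.021 for train_seg=2, ≈ 0.074 for train_seg=4.
mean_gap = {k: sum(v) / len(v) for k, v in gaps.items()}
```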
Board accuracy (last 5 epochs average):
acc by variant | 👉 acc by train_seg | acc by cycles |
---|---|---|
Refinement gap (last 5 epochs average):
gap by variant | 👉 gap by train_seg | gap by cycles |
---|---|---|
- Segments are the main driver. This applies to both accuracy and refinement ability.
- Architecture has little influence. H/L and single-module BPTT perform similarly; any differences are minor compared to the impact of segments.
- Cycles increase accuracy but not refinement. More cycles raise board accuracy somewhat, but do not improve refinement ability.
That said, one notable aspect of H/L is that it achieves high performance and refinement ability with detachment (no BPTT) and segment training, potentially reducing training cost — something not explored in detail here.
These findings are consistent with the ARC Prize team’s analysis (blog, slides), which also concluded that outer-loop refinement is the main driver of performance, not the H/L split.
When trained with more segments, the model reaches higher accuracy and refines its predictions better when given extra inference steps.
With 2 training segments, board accuracy improves but acc and acc4x remain nearly identical. With 4 training segments, accuracy is higher and a clear gap opens up between acc and acc4x, showing the model learns to refine further at inference.
Train Segments = 2 | Train Segments = 4 |
---|---|
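Schematically, segment-based inference is an outer refinement loop over a recurrent core: each segment refines the latent state, and decoding after each one shows the current best guess. The sketch below uses an `nn.GRUCell` as a stand-in for one HRM segment; all names and sizes are placeholders.

```python
import torch
import torch.nn as nn

d_model, n_cells = 32, 25            # toy 5×5 board, flattened to 25 cells
model = nn.GRUCell(d_model, d_model)  # stand-in for one HRM segment
head = nn.Linear(d_model, 3)          # per-cell logits, e.g. floor/wall/path

x = torch.randn(n_cells, d_model)     # embedded board
z = torch.zeros(n_cells, d_model)     # initial latent state

predictions = []
for segment in range(4):              # more segments = more test-time refinement
    z = model(x, z)                   # one segment refines the latent state
    predictions.append(head(z).argmax(dim=-1))  # decode the current guess
```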
The refinement process is visible in how solutions emerge: early steps make broad strokes, while later steps progressively add smaller corrections until the full path is resolved.
20x20 board refinement | 30x30 board refinement |
---|---|