grisun0 Independent Research December 2025
A PyTorch implementation for studying grokking phenomena in neural networks trained on matrix multiplication tasks, featuring adaptive regularization and zero-shot transfer capabilities.
MatrixGrokker is a research framework designed to investigate how neural networks learn matrix multiplication operations, with particular focus on grokking behavior - the phenomenon where networks suddenly transition from memorization to generalization. The implementation includes sophisticated monitoring of learning dynamics through local complexity and superposition metrics, coupled with an adaptive thermal regularization engine.
- Grokking Analysis: Comprehensive tracking of the transition from memorization to generalization
- Thermal Engine: Adaptive weight decay regulation based on local complexity and superposition metrics
- Zero-Shot Transfer: Ability to transfer learned representations to larger matrix sizes without additional training
- Comprehensive Metrics: Real-time monitoring of loss, accuracy, local complexity, and superposition measures
- Checkpoint System: Robust saving and resuming capabilities for long training runs
- Modular Architecture: Extensible design for experimenting with different network configurations
git clone https://github.com/grisuno/matrixgrokker
cd matrixgrokker
pip install torch numpy

Run the complete experiment with default configuration:
from app import run_full_experiment
grokker, metrics, transfer_results = run_full_experiment()

This will:
- Train a neural network on 2×2 matrix multiplication
- Monitor grokking behavior through local complexity and superposition metrics
- Apply adaptive thermal regularization
- Test zero-shot transfer to 4×4 and 8×8 matrices

Configuration

The system is configured through the Config class:
from app import Config
config = Config()
config.MATRIX_SIZE = 2 # Base matrix size for training
config.HIDDEN_DIM = 256 # Hidden layer dimensions
config.NUM_LAYERS = 3 # Network depth
config.TRAIN_EPOCHS = 1000 # Training duration
config.BATCH_SIZE = 128 # Batch size
config.LEARNING_RATE = 0.001 # Learning rate
config.WEIGHT_DECAY = 0.01 # Base weight decay

Core Components

- MLPModel: Multi-layer perceptron with configurable depth and activation functions (sketched below)
- Supports weight expansion for transfer learning
- Implements forward hooks for activation analysis
- Provides weight matrix extraction for superposition analysis
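For orientation, here is a minimal sketch of what such a model might look like in PyTorch. The class name, hook wiring, and helper methods are illustrative assumptions, not the repository's actual API:

import torch
import torch.nn as nn

# Illustrative sketch only; the real MLPModel in app.py may differ.
class MLPSketch(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, num_layers):
        super().__init__()
        layers = [nn.Linear(in_dim, hidden_dim), nn.ReLU()]
        for _ in range(num_layers - 2):
            layers += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU()]
        layers.append(nn.Linear(hidden_dim, out_dim))
        self.net = nn.Sequential(*layers)
        self.activations = {}
        # Forward hooks capture hidden activations for complexity analysis.
        for i, m in enumerate(self.net):
            if isinstance(m, nn.ReLU):
                m.register_forward_hook(self._make_hook(i))

    def _make_hook(self, idx):
        def hook(module, inputs, output):
            self.activations[idx] = output.detach()
        return hook

    def forward(self, x):
        return self.net(x)

    def weight_matrices(self):
        # Linear-layer weights, e.g. for SVD-based superposition analysis.
        return [m.weight.detach() for m in self.net if isinstance(m, nn.Linear)]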
- MatrixMultiplicationDataset: Synthetic data generation (sketched below)
- Generates random matrix pairs within specified ranges
- Computes exact multiplication results
- Flattens matrices for network consumption
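A minimal sketch of such a dataset, assuming uniform sampling in [-1, 1]; the class name and value ranges here are assumptions, as the actual ranges are configurable:

import torch
from torch.utils.data import Dataset

# Illustrative sketch; class name and sampling range are assumptions.
class MatMulDatasetSketch(Dataset):
    def __init__(self, num_samples, n=2, low=-1.0, high=1.0):
        a = torch.empty(num_samples, n, n).uniform_(low, high)
        b = torch.empty(num_samples, n, n).uniform_(low, high)
        c = torch.bmm(a, b)  # exact products as regression targets
        self.x = torch.cat([a.flatten(1), b.flatten(1)], dim=1)  # (N, 2*n*n)
        self.y = c.flatten(1)                                    # (N, n*n)

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]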
- LocalComplexity: Measures neural representation diversity
- Computes pairwise activation similarities
- Quantifies the complexity of learned representations
- Values range from 0 (simple) to 1 (complex)
- Superposition: Analyzes weight matrix structure
- Performs singular value decomposition on weight matrices
- Measures the degree of weight superposition
- Values range from 0 (no superposition) to 1 (high superposition); a sketch of both metrics follows below
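The repository's exact formulas are not reproduced here; the following is one plausible formulation of both metrics that matches the stated [0, 1] ranges (pairwise cosine similarity for local complexity, normalized spectral entropy for superposition):

import torch
import torch.nn.functional as F

def local_complexity(acts):
    # acts: (batch, features). One plausible formulation: how dissimilar
    # activation patterns are across the batch. 0 = identical, 1 = diverse.
    normed = F.normalize(acts, dim=1)
    cos = normed @ normed.T                    # pairwise cosine similarities
    n = cos.shape[0]
    off_diag = cos - torch.eye(n, device=cos.device)  # drop self-similarity
    mean_sim = off_diag.abs().sum() / (n * (n - 1))
    return 1.0 - mean_sim

def superposition(weight):
    # One plausible SVD-based measure: spectral entropy of the singular
    # values, normalized to [0, 1]. A flat spectrum means many comparably
    # strong directions packed into the weights, i.e. high superposition.
    s = torch.linalg.svdvals(weight)
    p = s / s.sum()
    entropy = -(p * torch.log(p + 1e-12)).sum()
    return entropy / torch.log(torch.tensor(float(len(s))))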
- ThermalEngine: Adaptive regularization system (sketched below)
- Adjusts weight decay based on local complexity and superposition
- Targets optimal learning dynamics
- Provides thermal progress indicators
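A minimal sketch of such a controller, assuming a simple proportional rule with hypothetical target values; the repository's actual update rule may differ:

# Illustrative controller; targets and gain are hypothetical.
class ThermalEngineSketch:
    def __init__(self, base_decay=0.01, target_lc=0.5, target_sp=0.5, gain=0.5):
        self.base_decay = base_decay
        self.target_lc = target_lc
        self.target_sp = target_sp
        self.gain = gain

    def step(self, lc, sp):
        # Raise weight decay when complexity/superposition exceed their
        # targets (to "cool" the network), lower it otherwise.
        error = (lc - self.target_lc) + (sp - self.target_sp)
        return max(self.base_decay * (1.0 + self.gain * error), 0.0)

    def progress(self, lc, sp):
        # Thermal progress: 1.0 when both metrics sit exactly on target.
        return 1.0 - 0.5 * (abs(lc - self.target_lc) + abs(sp - self.target_sp))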
Training Pipeline

The training process wraps a monitoring system around the standard optimization loop (a sketch follows this list):

1. Forward Pass: Compute predictions and loss
2. Backward Pass: Update network parameters
3. Metrics Computation: Calculate local complexity and superposition every N steps
4. Thermal Adjustment: Modify weight decay based on the current metrics
5. Checkpointing: Save model state and metrics periodically
6. Validation: Evaluate on a held-out validation set
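Wiring the sketches above together, a stripped-down version of such a loop might look as follows (checkpointing and validation omitted for brevity; function and attribute names are the illustrative ones from the earlier sketches, not the repository's API):

import torch

def train_sketch(model, loader, epochs, lr, engine, metrics_every=100):
    opt = torch.optim.AdamW(model.parameters(), lr=lr,
                            weight_decay=engine.base_decay)
    loss_fn = torch.nn.MSELoss()
    step = 0
    for epoch in range(epochs):
        for x, y in loader:
            pred = model(x)                      # 1. forward pass
            loss = loss_fn(pred, y)
            opt.zero_grad()
            loss.backward()                      # 2. backward pass
            opt.step()
            if step % metrics_every == 0:        # 3. metrics every N steps
                acts = next(iter(model.activations.values()))
                lc = local_complexity(acts).item()
                sp = superposition(model.weight_matrices()[0]).item()
                decay = engine.step(lc, sp)      # 4. thermal adjustment
                for group in opt.param_groups:
                    group["weight_decay"] = decay
            step += 1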
Transfer Learning

The zero-shot transfer mechanism allows the network to generalize to larger matrices (a sketch of the expansion step follows this list):

1. Weight Expansion: Increase network capacity to match the larger input/output dimensions
2. Zero-Shot Evaluation: Test on larger matrices without any additional training
3. Performance Analysis: Measure accuracy and learning dynamics on the transfer tasks
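One way to implement the weight-expansion step is zero-padding: copy a trained layer into the top-left block of a larger one, so the network computes the same function on the original subspace. This is a hedged sketch; the repository's expansion scheme may differ:

import torch
import torch.nn as nn

def expand_linear(old, new_in, new_out):
    # Embed the trained weights in the top-left block of a larger layer;
    # the padded rows/columns start at zero so prior behavior is preserved.
    new = nn.Linear(new_in, new_out)
    with torch.no_grad():
        new.weight.zero_()
        new.bias.zero_()
        new.weight[:old.out_features, :old.in_features] = old.weight
        new.bias[:old.out_features] = old.bias
    return new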
Metrics and Monitoring

The system tracks comprehensive metrics throughout training.

Performance Metrics:
- Training and validation loss (MSE)
- Training and validation accuracy (fraction of entries within a 0.1 absolute threshold; see the sketch after this list)
- Iterations per second for throughput monitoring

Learning Dynamics:
- Local Complexity (LC): Measures representation diversity
- Superposition (SP): Analyzes weight matrix structure
- Thermal Progress: Combined measure of how close the learning dynamics are to their targets

Regularization:
- Adaptive weight decay values
- Learning rate scheduling
- Thermal engine status
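The accuracy metric can be computed element-wise; a minimal sketch of the 0.1-threshold check (the function name is illustrative):

import torch

def threshold_accuracy(pred, target, tol=0.1):
    # Fraction of predicted matrix entries within tol of the exact product.
    return (pred - target).abs().le(tol).float().mean().item()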
Repository: https://github.com/grisuno/strass_strassen
DOI: https://doi.org/10.5281/zenodo.18263654
Reproduction:
git clone https://github.com/grisuno/strass_strassen
cd strass_strassen
pip install -r requirements.txt
python app.py

Related repositories:
- Ancestor: https://github.com/grisuno/SWAN-Phoenix-Rising
- Core Framework: https://github.com/grisuno/agi
- Parity Cassette: https://github.com/grisuno/algebra-de-grok
- Wave Cassette: https://github.com/grisuno/1d_wave_equation_grokker
- Kepler Cassette: https://github.com/grisuno/kepler_orbit_grokker
- Pendulum Cassette: https://github.com/grisuno/chaotic_pendulum_grokked
- Cyclotron Cassette: https://github.com/grisuno/supertopo3
- MatMul 2x2 Cassette: https://github.com/grisuno/matrixgrokker
- HPU Hamiltonian Cassette: https://github.com/grisuno/HPU-Core