- Overview
- Codebase structure
- Data Setup
- Artificial optogenetics framework -- if you're looking to use this code for your own work on a different data setup or task, skip to this section.
- Reproducing the transience paper
- Reproducing the induction heads paper
- Reproducing the coopetition paper
- Installation
- Contributors
This is the codebase for a sequence of work investigating the dynamics of in-context learning (ICL) in transformers, through empirical and mechanistic lenses. This project started with the finding that ICL is often transient, and then dove into mechanistic interpretations of dynamics. The codebase relies extensively on JAX and Equinox.
Our first paper, The Transient Nature of Emergent In-Context Learning in Transformers, demonstrates that emergent ICL may disappear when models are overtrained. We demonstrated this on setups previously shown to heavily incentivize ICL over alternative, in-weights learning (IWL) strategies. See data setup for more detail. This establishes emergent in-context learning as a dynamical phenomenon, as opposed to an asymptotic one, which motivated the subsequent works on mechanistic understandings of the dynamics of formation. This first work was primarily done with 12L (~893K params) transformers. See the corresponding section for more details.
Our second paper, What needs to go right for an induction head?, studies emergence dynamics of induction heads on synthetic data that requires ICL to solve the task. We pursued this direction (rather than directly targeting the dynamics of transience in setups where ICL or IWL can solve the task) since we felt that there was more to be understood on the nature of induction heads and the dynamics of their formation before we were ready to tackle transience. This work also helped us build out our mechanistic toolkit, which is discussed further in the artificial optogenetics framework section. See the corresponding section.
Our third paper, Strategy coopetition explains the emergence and transience of in-context learning, brings everything together with a dynamical understanding of why ICL emerges, if only to fade away later (due to competition, as suggested in the first paper). We find a surprising hybrid of ICL and IWL as the asymptotic mechanism, with cooperative interactions allowing ICL to emerge earlier in training. A key aspect of these analyses was using the tools from the second paper. We felt that this paper tied off the work quite nicely, really emphasizing key intuitions and dynamical phenomena that may help in understanding larger networks. See the corresponding section for more details.
As it covers the above works, the codebase has many parts. We highlight the key pieces and provide more detail in the subsections for reproducing each paper. Most files are also supplemented with comments to aid understanding.
If you're interested in extensions (or have questions about the code), feel free to reach out -- [email protected].
General codebase structure:
- main_utils.py: Contains basically all argparse functionality. Each parameter has a dedicated help string which can be used to understand it. Many of the options are unused/set to their default values for the papers, but we found them useful to play around with in earlier stages of the project to build intuition.
- main.py: Used to run training experiments. Contains training code (e.g., loss computation and train_step -- this is where the jitting happens). Uses the functions from main_utils.py to create the dataset, model, etc. and to run training + evals throughout training.
- samplers.py: JAX-based data sampling for our synthetic setup.
- models.py: Implements causal transformer models using our artificial optogenetics framework, allowing for easy recording and manipulation of intermediate activations.
- opto.py: Contains the options and implementations of the various optogenetic manipulations used in our work. Argparse arguments for optogenetic variations are added here. The general idea is intuitively similar to a visitor pattern: to add an optogenetic manipulation, all one has to do is modify opto.py. See Artificial optogenetics framework for more detail.
- coopetition_model_solver.py: A lightweight library for experimenting with toy models that we found to be surprisingly representative of larger networks' learning dynamics.
- visualize_runs.py: A heavyweight plotting library for motivated readers that want to dive deeper. Largely useful for visualizing near-arbitrary progress measures with near-arbitrary optogenetic manipulations at the dataset and individual data point levels.
Generally, the codebase makes use of a lot of functional programming, as is common with JAX codebases.
Part of why we used JAX is to ensure good random seed reproducibility (similar to other JAX-based transformer frameworks). To this end, we have a few different seeds (listed in main_utils.py, but repeated here):
- init_seed: Used to initialize the model (when training from scratch)
- train_seed: Used to generate training data (via samplers.py). Also used for things like dropout, if enabled (none of our experiments used these features)
- eval_seed: Used to generate eval data. We also have an option to directly read in pre-constructed eval data (load_eval_data). Also used for things like dropout, if enabled (none of our experiments used these features)
See lines ~347-348 in main.py for how the latter seeds get split into data and model.
When checkpoints are saved, we save the three relevant seeds (see line ~510 in main.py).
Note that our training process uses a seed at every step to generate that step's batch. This means that if the batch size changes, the exact sequence of examples seen by the model will also change. A batch size of 32 was used for all experiments. Tests with varying train_seed did not show much variance.
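As a rough illustration of this kind of per-step seeding (not the exact scheme in main.py; see lines ~347-348 there for the real key handling), a deterministic key can be derived for each training step and then split across the batch:

```python
import jax

# Sketch only: derive a deterministic key per training step from train_seed,
# then split it across the batch. The actual splitting in main.py may differ.
train_key = jax.random.PRNGKey(0)  # train_seed

def step_batch_keys(step, batch_size):
    step_key = jax.random.fold_in(train_key, step)   # one key per step
    return jax.random.split(step_key, batch_size)    # one key per example

# With this scheme, changing batch_size changes how examples are grouped into
# steps, so the exact training stream changes even for the same train_seed.
keys = step_batch_keys(step=0, batch_size=32)
```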
Setup (data: samplers.py, overall: main.py)
Our setup builds off the one introduced by Chan et al. (2022) in Data Distributional Properties Drive Emergent In-Context Learning in Transformers. We are grateful to the authors of that work for open-sourcing their code. We used their code for early experiments, but ended up creating our own repository tailored to our analyses. Our work also uses JAX, but relies on Equinox instead of Haiku. We found the PyTree formalism of Equinox easier to work with, especially for artificial optogenetics.
Our data generator assumes a set of classes. Each class can be composed of one or more exemplars. Sequences are composed of a context of exemplar-label pairs, followed by a query exemplar, for which the model needs to output a label. For sampling (samplers.py), we disentangle sampling class sequences for the context (get_constant_burst_seq_idxs) from sampling exemplars within each class (get_exemplar_inds). Though all our experiments tended to use just a single form of class sampling, we offer a way to mix samplers (get_mixed_seq_idxs), which could be used to reproduce the experiments with varying p(bursty) in Chan et al. (2022) or for other experiments. To support "ICL-only" sequences, we offer fewshot_relabel, which changes the class labels to be random across contexts (but consistent within a sequence). These sequences force the network to use ICL, which we found useful for our work studying What needs to go right for an induction head?. Finally, our data samplers work by sampling class and exemplar indices, and only index into the data matrix (of dimension # classes x # exemplars x input_dim) at the last step. Our process is also end-to-end JIT-able. See samplers.py for more detail.
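To make the index-then-gather pattern concrete, here is a heavily simplified sketch (the function and argument names are illustrative, not the actual samplers.py API):

```python
import jax
import jax.numpy as jnp

def sample_sequence(key, data, labels, context_len):
    # data: (num_classes, num_exemplars, input_dim); labels: (num_classes,)
    num_classes, num_exemplars, _ = data.shape
    k_class, k_exemplar = jax.random.split(key)
    # Sample class indices for the context plus the final query position.
    class_idx = jax.random.randint(k_class, (context_len + 1,), 0, num_classes)
    # Sample an exemplar index within each chosen class.
    exemplar_idx = jax.random.randint(k_exemplar, (context_len + 1,), 0, num_exemplars)
    # Only index into the data matrix at the very end, keeping everything JIT-able.
    exemplars = data[class_idx, exemplar_idx]   # (context_len + 1, input_dim)
    seq_labels = labels[class_idx]              # (context_len + 1,)
    return exemplars, seq_labels

# Batched, jitted sampling over a batch of keys (context_len is static).
batched_sample = jax.jit(jax.vmap(sample_sequence, in_axes=(0, None, None, None)),
                         static_argnums=3)
```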
main.py uses the argparse options from main_utils.py to construct the training and eval data iterators, the model, etc. It also contains the train and eval steps and conducts training. We support saving and loading from checkpoints at custom schedules (we found this useful to e.g., upsample checkpoints during a phase change). This is also where the JIT-ing happens (via eqx.filter_jit).
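For orientation, the training loop follows the usual Equinox pattern. Below is a minimal sketch of a filter-jitted train step, assuming an optax optimizer; it is simplified relative to the actual code in main.py:

```python
import equinox as eqx
import jax
import optax

@eqx.filter_jit
def train_step(model, opt_state, optimizer, exemplars, labels):
    def loss_fn(m):
        logits = jax.vmap(m)(exemplars)  # batched forward pass
        return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

    # Differentiate only the array leaves of the model PyTree.
    loss, grads = eqx.filter_value_and_grad(loss_fn)(model)
    updates, opt_state = optimizer.update(grads, opt_state, eqx.filter(model, eqx.is_array))
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss
```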
Artificial Optogenetics framework (models.py)
A key contribution of our work is the artificial optogenetics framework. This is mostly manifest in models.py, which implements a Transformer that contains all elements of the framework. We wrap it with SequenceClassifier for our specific exemplar-label sequences. All manipulations on top of the framework (for the experiments in our papers) are implemented in opto.py. For full documentation on this portion of the code, see artificial_optogenetics_guide.md. As always, feel free to reach out with questions or collaborations -- [email protected].
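To give a flavour of the idea (this is a toy illustration, not the framework's actual API; see artificial_optogenetics_guide.md and models.py for that), the core pattern is a forward pass that both records intermediate activations into a cache and optionally clamps them to supplied values:

```python
import equinox as eqx
import jax
import jax.numpy as jnp

class TinyBlock(eqx.Module):
    linear: eqx.nn.Linear

    def __init__(self, dim, key):
        self.linear = eqx.nn.Linear(dim, dim, key=key)

    def __call__(self, x, cache, clamp=None):
        pre = self.linear(x)
        if clamp is not None and "pre_act" in clamp:
            pre = clamp["pre_act"]   # manipulate: override the activation
        cache["pre_act"] = pre       # record: stash it for later analysis
        return jax.nn.relu(pre), cache

block = TinyBlock(dim=8, key=jax.random.PRNGKey(0))
y, cache = block(jnp.ones(8), {})                                        # record only
y_clamped, _ = block(jnp.ones(8), {}, clamp={"pre_act": jnp.zeros(8)})   # clamp
```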
Reproducing The Transient Nature of Emergent In-Context Learning in Transformers
Most of the runs in this paper were conducted with the original codebase from Chan et al. (2022), namely the runs using a jointly trained ResNet encoder (which covers most of the results). This codebase was used for the remaining runs: those with fixed LLaMa embedding vectors as exemplars (Section 4.3) and those with fixed Omniglot embeddings (Appendix C).
LLaMa embedding vectors (extracted from the open-source LLaMa 1 weights) were clustered with FAISS using the procedure described in the paper and then turned into h5 files (with dimensions # classes x # exemplars x input_dim, where input_dim varies based on the LLaMa source model size). An example sweep file operating on these h5 files is llama_sweep_example.py.
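As a rough sketch of the expected layout (the file name, dataset key, and dimensions below are placeholders, not the repo's actual values), packing embeddings into such an h5 file can look like:

```python
import h5py
import numpy as np

# Placeholder values: substitute the real clustered LLaMa embeddings and sizes.
num_classes, num_exemplars, input_dim = 1600, 1, 4096
embeddings = np.zeros((num_classes, num_exemplars, input_dim), dtype=np.float32)

with h5py.File("llama_embeddings_example.h5", "w") as f:
    # Layout: # classes x # exemplars x input_dim, as described above.
    f.create_dataset("data", data=embeddings)
```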
Omniglot embeddings were extracted using omni_features_extract.py and then experiments for Appendix C were run using a sweep file like fixed_omni_emb_sweep_example.py. This file may be of use to see how the evaluators were structured.
If citing these evaluators/experiments, please use:
@misc{singh2023transient,
title={The Transient Nature of Emergent In-Context Learning in Transformers},
author={Aaditya K. Singh and Stephanie C. Y. Chan and Ted Moskovitz and Erin Grant and Andrew M. Saxe and Felix Hill},
year={2023},
eprint={2311.08360},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
Reproducing What needs to go right for an induction head?
This codebase was largely made to support this paper. The dataset uses pre-processed omniglot features, similar to the transience paper. Namely, omni_features_extract.py was used to extract features. For simplicity, only 5 exemplars per class were processed (as the paper only uses 1 exemplar per class for training, and the remaining 4 for one of the test sets). These features were then re-ordered randomly to form the data file omniglot_resnet18_randomized_order_s0.h5, provided in the codebase. We directly provide this file to enable researchers without access to GPUs to quickly get started with the codebase -- all experiments for this paper can be run on a laptop!
To reproduce the figures of the paper, one should first run ih_paper_runs.sh, which contains all the relevant runs for the paper (for a given initialization seed). Then, one can use ih_paper_plots.ipynb to reproduce all figures from the paper.
Those files minimally reproduce the paper. To additionally obtain all appendix results, run ih_paper_appendix_runs.sh and ih_paper_appendix_plots.ipynb.
All of the above rely on ih_paper_plot_utils.py for some utils (e.g., a simplified forward function wrapper).
For the toy model of phase changes, we have a separate file simple_model_solver.py. This file is called in the scripts above to generate the corresponding figures. It is completely independent of the rest of the codebase, and may also be useful to those looking to further study toy models with clamping and/or progress measures.
We include notebook copies of ih_paper_plots.ipynb that plot results from runs we did on different initialization seeds in the folder ih_paper_additional_seeds. To actually run these notebooks, one would have to run ih_paper_runs.sh with other seeds, then move the notebook to the top-level folder and run it. Our intent with these notebooks is just to share additional results showing qualitative reproducibility of the observed phenomenon.
If citing this work, please use:
@misc{singh2024needs,
title={What needs to go right for an induction head? A mechanistic study of in-context learning circuits and their formation},
author={Aaditya K. Singh and Ted Moskovitz and Felix Hill and Stephanie C. Y. Chan and Andrew M. Saxe},
year={2024},
eprint={2404.07129},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
Reproducing Strategy coopetition explains the emergence and transience of in-context learning
This paper required significantly more compute than the previous paper, so we recommend using GPUs. The total compute is still significantly lower than for the first transience paper, given the smaller models used.
To reproduce this paper, one should first generate the exemplar embeddings using omni_features_extract.py. Then, one can run coopetition_paper_sweep.py, which uses submitit to parallelize jobs on a slurm cluster. The exact configuration may need to be changed depending on where you're running the jobs. Worst case, it should be relatively simple to run the jobs sequentially. Note that the final two runs are on 12L models, and will take about 2 days on an 80GB H100 GPU -- those runs are only used for Figure 6b, and so could be removed if this is too expensive. Once the sweep is done, one can use coopetition_paper_plots.ipynb to reproduce all figures from the main paper.
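For reference, the general submitit pattern looks roughly like the sketch below (the function name, configs, and resource parameters are illustrative; the actual sweep configuration lives in coopetition_paper_sweep.py):

```python
import submitit

def train_run(config):
    # Hypothetical wrapper: build argparse args from `config` and call the
    # training entry point in main.py.
    ...

configs = [{"seed": s} for s in range(3)]

executor = submitit.AutoExecutor(folder="submitit_logs")
executor.update_parameters(timeout_min=60 * 24, slurm_partition="gpu", gpus_per_node=1)
jobs = executor.map_array(train_run, configs)   # one slurm job per config

# Without a cluster, the same configs can simply be run sequentially:
# for cfg in configs:
#     train_run(cfg)
```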
Those files minimally reproduce the main text. To additionally obtain all appendix results, run coopetition_paper_appendix_sweep.py and coopetition_paper_appendix_plots.ipynb.
Building off simple_model_solver.py used for the induction heads paper, we made a new file coopetition_model_solver.py that can serve as a lightweight library for experimenting with minimal vector-learning setups. This file is meant to be modular and imported into a notebook. It provides a lot of functionality, such as automatic tracking of various losses, which we found easier to use than the original simple_model_solver.py.
We created a heavier-weight plotting library to further explore our models: visualize_runs.py. This may not be the easiest to use for newcomers, but we found it very helpful in looking at a wide variety of evolution plots through training (i.e., progress measures), with arbitrary optogenetic manipulations. It also offers functionality to decompose to the individual datapoint level, which we explored briefly (in the ICL-only setting), noting some interesting things such as loss on many individual points actually increasing before decreasing. We're excited to release this tooling in case it's helpful to other ICL researchers!
If citing this work, please use:
@misc{singh2025strategycoopetitionexplainsemergence,
title={Strategy Coopetition Explains the Emergence and Transience of In-Context Learning},
author={Aaditya K. Singh and Ted Moskovitz and Sara Dragutinovic and Felix Hill and Stephanie C. Y. Chan and Andrew M. Saxe},
year={2025},
eprint={2503.05631},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2503.05631},
}
Installation
See setup.md for installation instructions for various CUDA driver versions.
Contributors
The primary creator and maintainer of this code was Aaditya Singh. Ted Moskovitz and Sara Dragutinovic also contributed to various parts (model, data samplers, training). The code was based off an earlier transformer implementation by Erin Grant. A special thanks to Stephanie Chan, with whom all this work was done in close collaboration. The overall project was supervised by Andrew Saxe and Felix Hill, with Andrew Saxe also contributing some code for the first iteration of the tensor product toy model.