This repository contains the code and resources for our paper, which has been accepted to ICLR 2025.
Cite as:
@inproceedings{li2025chunkdistilled,
  title={Chunk-Distilled Language Modeling},
  author={Yanhong Li and Karen Livescu and Jiawei Zhou},
  booktitle={The Thirteenth International Conference on Learning Representations},
  year={2025},
  url={https://openreview.net/forum?id=nrvoWOWcyg}
}
Before running any scripts, please set up the necessary environment.
- Activate your Conda environment. We assume you have Conda installed and have created an environment.
- Install dependencies. The required Python packages are listed in `requirements.txt`. Install them using pip: `pip install -r requirements.txt`

Note on PyTorch & CUDA: The `requirements.txt` file will install a default version of PyTorch. If you need a specific version compatible with your CUDA toolkit for GPU acceleration, we strongly recommend installing it separately by following the official instructions on the PyTorch website.
1. Preprocess the Dataset
Script: preprocess_dataset.py
Convert raw data into `.json` files (train, validation, test) with a standardized format.
2. Extract Token Probabilities & Hidden States
Script: extract_token_probs_and_memmap.py
Reads your preprocessed .json data, runs a language model to get token probabilities and hidden states, and (optionally) constructs a memory-mapped file (.dat / .h5) for efficient retrieval.
3. Build Token Trie Metadata from Token Probabilities
Script: save_token_trie_info.py
- Takes the token probabilities (e.g. `token_probs_*.json`) from step 2.
- Identifies tokens above a probability threshold.
- Splits them into phrases/chunks.
- Builds and saves the initial trie metadata (`.json` files), which will be used in the next step.
4. Build Datastore (Token Trie) from Hidden States
Script: make_token_trie.py
- Reads the output from step 3 (`tok_pos_{id}.json` files, etc.).
- Loads the large `.h5` of hidden states from step 2.
- Gathers relevant hidden states for each token’s position.
- Stores them in HDF5-based sub-files (`{token_id}.h5`) in a specified directory, forming your final retrieval datastore.
5. Run Inference with the Datastore
Script: eval.py
Use the specialized inference script (eval.py) to tap into the datastore you built. This script:
- performs nearest-neighbor retrieval on the token trie, pulling in relevant chunks of hidden states when their similarity passes a chosen threshold;
- automatically reverts to standard language model generation if no suitable match is found.
1. Preprocessing the Dataset
Run the following command to preprocess the dataset:
python preprocess_dataset.py \
--dataset wikitext-103-v1 \
--cache_dir path/to/huggingface/cache \
--output_dir path/to/preprocessed_data
After this command completes, you should have the following directory structure:
path/to/preprocessed_data/
├── train.json
├── validation.json
└── test.json
Currently supported datasets:
- `wikitext-103-v1`
- `med_instruction`
- `pile-of-law-federal_register`
If you want to add your own dataset, follow these steps:
- Implement a custom function in `preprocess_dataset.py` (a minimal sketch follows this list) that:
  - Loads and splits your data into train, validation (optional), and test sets (optional).
  - Saves them in the standard JSON format (shown below).
- Save the data in the following JSON structure: `[{"text": "Example passage 1..."}, {"text": "Example passage 2..."}]`
- Update the CLI argument check in `preprocess_dataset.py` to call your custom function when `--dataset` matches the name of your new dataset.
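Below is a minimal, hedged sketch of what such a custom function could look like. The function name, the raw-data loading, and the split ratios are illustrative assumptions; only the output format (one JSON list of `{"text": ...}` records per split) follows this README.

```python
# Hedged sketch of a custom preprocessing function for preprocess_dataset.py.
# Everything except the output format is an assumption for illustration.
import json
import os

def preprocess_my_dataset(raw_path: str, output_dir: str) -> None:
    # Load your raw data however it is stored; here we assume one passage per line.
    with open(raw_path, encoding="utf-8") as f:
        passages = [line.strip() for line in f if line.strip()]

    # Split into train / validation / test (validation and test are optional).
    n = len(passages)
    splits = {
        "train": passages[: int(0.8 * n)],
        "validation": passages[int(0.8 * n): int(0.9 * n)],
        "test": passages[int(0.9 * n):],
    }

    # Save each split in the standard format: a JSON list of {"text": ...} objects.
    os.makedirs(output_dir, exist_ok=True)
    for split, texts in splits.items():
        with open(os.path.join(output_dir, f"{split}.json"), "w", encoding="utf-8") as f:
            json.dump([{"text": t} for t in texts], f, ensure_ascii=False, indent=2)
```

You would then wire this function into the `--dataset` check in `preprocess_dataset.py`, as described in the last step above.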
You may choose not to use preprocess_dataset.py at all, as long as your data is saved in the format below:
path/to/preprocessed_data/
├── train.json
├── validation.json
└── test.json
Each JSON file should look like this:
[
{"text": "Example passage 1..."},
{"text": "Example passage 2..."}
]

2. Extracting Token Probabilities, Hidden States, and (Optionally) Building the Memmap
python extract_token_probs_and_memmap.py \
--model_path gpt2-xl-conversational \
--data_dir path/to/preprocessed_data \
--output_dir path/to/output \
--save_hidden_states \
--partition train \
--construct_memmap True \
    --cache_dir path/to/huggingface/cache

Parameters:
- `--model_path`: A Hugging Face model path/identifier (e.g., `"gpt2"`, `"EleutherAI/gpt-neo-2.7B"`, or a local directory).
- `--data_dir`: Directory with `train.json`, `validation.json`, and/or `test.json` from the previous step.
- `--output_dir`: Where to store output. Inside, you’ll see subdirectories for each data split (`train`, `validation`, `test`) with JSON files of token probabilities and (optionally) `.npy` hidden states.
- `--save_hidden_states`: If set, saves the final hidden states for each chunk as `.npy`.
- `--partition`: Which data split to process (`train`, `validation`, `test`). If you want to process all three, simply run the script multiple times or adjust accordingly.
- `--construct_memmap`: If `True`, after extracting hidden states the script automatically assembles them into a large memory-mapped `.dat` file (and also an `.h5` dataset).
- `--cache_dir`: (Optional) Path for caching/downloading model weights.
- `--chunk_size`, `--stride`, `--min_length`: Control how passages are split into overlapping chunks.
Once complete, you will find for each split (e.g. train):
path/to/output/train/
├── token_probs_0.json
├── token_probs_1.json
├── ...
├── hidden_states_0.npy
├── hidden_states_1.npy
└── ...
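For intuition about what these files contain, here is a minimal, hedged sketch of how next-token probabilities and final-layer hidden states can be extracted with a Hugging Face causal LM. The variable names, single-passage setup, and output handling are illustrative, not the internals of `extract_token_probs_and_memmap.py`.

```python
# Illustrative sketch only: obtain per-token probabilities and final-layer
# hidden states from a causal LM. Chunking, batching, and file output are omitted.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # e.g. the value passed to --model_path
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, output_hidden_states=True)
model.eval()

inputs = tokenizer("Example passage 1...", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Probability the model assigns to each observed token, given its left context.
logits = outputs.logits[0, :-1]                       # (seq_len - 1, vocab_size)
next_ids = inputs["input_ids"][0, 1:]                 # the tokens actually observed
probs = torch.softmax(logits, dim=-1)
token_probs = probs[torch.arange(len(next_ids)), next_ids]   # (seq_len - 1,)

# Final-layer hidden state for every token position.
hidden_states = outputs.hidden_states[-1][0]          # (seq_len, hidden_dim)
```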
When `--construct_memmap` is used in conjunction with `--save_hidden_states`, the script automatically:
- Converts each chunk’s `.npy` hidden states into a `.dat` file.
- Merges all `.dat` files into a single large memory-mapped file (`train.dat`), plus an `.h5` copy (`train.h5`).
- Saves index maps (`train_index_to_seqlen.json`, `train_index_to_acc_index.json`) for quick lookups in retrieval tasks.
[IMPORTANT] If you only need token probabilities (and no hidden states), omit --save_hidden_states and --construct_memmap. If you want hidden states but not the memmap, omit --construct_memmap.
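To see what the index maps buy you, here is a hedged sketch of reading one chunk's hidden states back out of the merged `train.h5`. It assumes all token vectors are concatenated along the first axis, that `train_index_to_acc_index.json` maps a chunk id to its starting row, and that `train_index_to_seqlen.json` maps the same id to its length; check the repository code for the authoritative layout and dataset key.

```python
# Hedged sketch: slice one chunk's hidden states out of the merged train.h5
# using the two index maps. The storage layout and dataset key are assumptions.
import json
import h5py

with open("path/to/output/train_index_to_acc_index.json") as f:
    index_to_acc_index = json.load(f)
with open("path/to/output/train_index_to_seqlen.json") as f:
    index_to_seqlen = json.load(f)

chunk_id = "0"                        # JSON keys are typically strings
start = index_to_acc_index[chunk_id]  # assumed: starting row of this chunk
length = index_to_seqlen[chunk_id]    # assumed: number of token positions it spans

with h5py.File("path/to/output/train.h5", "r") as h5:
    dataset = h5[list(h5.keys())[0]]              # assumed: one flat dataset inside
    chunk_hidden = dataset[start:start + length]  # (length, hidden_dim)
```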
3. Build Token Trie Metadata from Token Probabilities
Script: save_token_trie_info.py
In this step, you create the token trie metadata from the extracted token probabilities. Specifically:
- Inputs:
  - A directory of `token_probs_*.json` files, each containing probability data for some chunk of text (output from Step 2).
  - A tokenizer name/path to ensure consistent tokenization (if needed for final ID assignment).
  - Various CLI arguments that control:
    - `dataset_name`: A label used in file naming.
    - `partition`: The data split you’re preparing (e.g. `train`, `validation`, `test`).
    - `tok_prob_threshold`: Probability threshold for “significant” tokens.
- Core Operations:
  - Reads all `token_probs_*.json` in your `--token_prob_dir`.
  - Optionally skips the first 64 tokens if you aren’t processing a very short dataset (like PII).
  - Identifies tokens above the `--tok_prob_threshold`.
  - Groups adjacent high-probability tokens into “phrases” or chunks (see the sketch after this list).
  - Subdivides longer phrases into 2-token sub-phrases and merges them in a Trie data structure.
- Outputs:
  - A set of JSON files describing the phrases, which the Trie uses for further processing. These files get saved under `--output_dir` with suffixes like `_token_prob_list.json`, `_all_phrases_{threshold}.json`, `_all_filtered_tok_probs_{threshold}.json`.
  - A Trie structure is built and serialized into additional files (e.g. `tok_pos_{token_id}.json`) for each token that meets the threshold criteria. These are used in the next step.
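The phrase-grouping step can be pictured with the following hedged sketch: positions whose token probability clears the threshold are merged into contiguous phrases. The function name and data layout are illustrative assumptions, not the script's internals.

```python
# Illustrative sketch of the grouping idea in save_token_trie_info.py:
# merge runs of adjacent high-probability tokens into "phrases".
from typing import List, Tuple

def group_high_prob_phrases(token_ids: List[int],
                            token_probs: List[float],
                            threshold: float = 0.3) -> List[Tuple[int, List[int]]]:
    """Return (start_position, phrase_token_ids) for each run of consecutive
    tokens whose probability is >= threshold."""
    phrases, start, current = [], None, []
    for pos, (tok, prob) in enumerate(zip(token_ids, token_probs)):
        if prob >= threshold:
            if start is None:
                start = pos
            current.append(tok)
        elif current:
            phrases.append((start, current))
            start, current = None, []
    if current:
        phrases.append((start, current))
    return phrases

# With threshold 0.3, positions 1-3 form one phrase and position 6 another.
print(group_high_prob_phrases(
    token_ids=[11, 22, 33, 44, 55, 66, 77],
    token_probs=[0.1, 0.5, 0.9, 0.4, 0.05, 0.2, 0.8]))
# -> [(1, [22, 33, 44]), (6, [77])]
```

An example invocation of the actual script: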
python save_token_trie_info.py \
--token_prob_dir /path/to/output/train \
--output_dir /path/to/trie_info \
--model_name gpt2 \
--dataset_name wikitext-103 \
--model_ds gpt2 \
--partition train \
    --tok_prob_threshold 0.3

After running, you should see files in `/path/to/trie_info` like:
/path/to/trie_info/
├── gpt2_wikitext-103_train_token_prob_list.json
├── gpt2_wikitext-103_train_all_phrases_0.3.json
├── gpt2_wikitext-103_train_all_filtered_tok_probs_0.3.json
└── token_trie_wikitext-103_DSgpt2_Hgpt2/train_0.3/
├── 123.json
├── tok_pos_123.json
├── 456.json
├── tok_pos_456.json
...
4. Build Datastore (Token Trie) from Hidden States
Script: make_token_trie.py
In this step, you populate the trie datastore with the actual hidden states from Step 2. Each token ID (determined in Step 3) is mapped to its corresponding hidden states for retrieval.
- Inputs:
  - The directory containing `tok_pos_{token_id}.json` files (produced in Step 3).
  - The `.h5` file (`train.h5`, `validation.h5`, or `test.h5`) of merged hidden states from Step 2.
  - Two JSON maps: `index_to_acc_index.json` and `index_to_seqlen.json` (also created during the memmap step) that let you quickly look up each chunk’s offset in the `.h5`.
- Core Operations:
  - Iterates over each `tok_pos_{token_id}.json` file.
  - For each token ID and each depth in the trie, collects the (chunk_id, tok_pos) pairs.
  - Uses `index_to_acc_index` to find where in the `.h5` those hidden states reside.
  - Saves the relevant hidden states into an `.h5` file dedicated to that token ID (e.g. `{token_id}.h5`).
- Outputs:
  - For each token ID, a file `{token_id}.h5` in the same directory, containing the hidden states relevant to that token.
  - Together, these `.h5` files form your final datastore (the token trie + hidden states). A hedged sketch of the gather step follows this list.
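Building on the index-map lookup sketched in Step 2, the gather step for a single token id could look roughly like this. The JSON structure of `tok_pos_{token_id}.json`, the key types, and the dataset names are assumptions for illustration, not the exact code of `make_token_trie.py`.

```python
# Hedged sketch of the per-token gather step: collect the hidden state of every
# (chunk_id, tok_pos) occurrence of one token and write it to {token_id}.h5.
import json
import h5py
import numpy as np

token_id = 123
trie_dir = "/path/to/trie_info/token_trie_wikitext-103_DSgpt2_Hgpt2/train_0.3"

with open(f"{trie_dir}/tok_pos_{token_id}.json") as f:
    occurrences = json.load(f)            # assumed: a list of [chunk_id, tok_pos] pairs
with open("/path/to/output/train_index_to_acc_index.json") as f:
    index_to_acc_index = json.load(f)

vectors = []
with h5py.File("/path/to/output/train.h5", "r") as h5:
    dataset = h5[list(h5.keys())[0]]      # assumed: one flat dataset of token vectors
    for chunk_id, tok_pos in occurrences:
        row = index_to_acc_index[str(chunk_id)] + tok_pos   # global row of this token
        vectors.append(dataset[row])

with h5py.File(f"{trie_dir}/{token_id}.h5", "w") as out:
    out.create_dataset("hidden_states", data=np.stack(vectors))
```

An example invocation of the actual script: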
python make_token_trie.py \
--save_dir /path/to/trie_info/token_trie_wikitext-103_DSgpt2_Hgpt2/train_0.3 \
--dataset_name wikitext-103 \
--model_ds gpt2 \
--model_gen gpt2 \
--partition train \
--tok_prob_threshold 0.3 \
--idx_to_seqlen_path /path/to/output/train_index_to_seqlen.json \
--index_to_acc_index_path /path/to/output/train_index_to_acc_index.json \
--all_hidden_states_path /path/to/output/train.h5 \
    --max_workers 16

Parallelism
- By default, `make_token_trie.py` runs with `--max_workers 32`. This can speed up handling of large numbers of tokens.
- If you split your data into multiple runs, you can use the `--k` and `--k_interval` options to process only a subset of token IDs.
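For readers unfamiliar with the pattern, this is roughly what worker-based parallelism over token ids looks like; the worker body is a placeholder for the per-token gather sketched above, and the exact semantics of `--k`/`--k_interval` live in the script itself.

```python
# Hedged sketch of the --max_workers parallelism pattern: one worker process
# per token id. The worker function is a placeholder, not the repo's code.
from concurrent.futures import ProcessPoolExecutor

def build_one_token_file(token_id: int) -> int:
    # ...gather this token's hidden states and write {token_id}.h5 here...
    return token_id

if __name__ == "__main__":
    token_ids = [123, 456]   # in practice: every id with a tok_pos_{token_id}.json file
    with ProcessPoolExecutor(max_workers=16) as pool:   # mirrors --max_workers 16
        for finished in pool.map(build_one_token_file, token_ids):
            print(f"finished token {finished}")
```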
Upon completion, each token (token_id) in the trie now has a dedicated HDF5 file containing its relevant embeddings/hidden states. This enables efficient nearest-neighbor lookups during the inference step.
5. Run Inference with the Datastore
Script: eval.py
This is our specialized inference script. It uses the constructed datastore to perform nearest-neighbor retrieval for chunk-level next-token suggestions based on hidden-state similarity. If the similarity is below a threshold, it falls back to standard LM generation.
python eval.py \
--model_path gpt2 \
--cache_dir /path/to/huggingface/cache \
--token_trie_dir /path/to/trie_info/token_trie_wikitext-103_DSgpt2_Hgpt2/train_0.3 \
--dataset wikitext-103 \
--partition test \
--tok_prob_threshold 0.3 \
--model_ds gpt2 \
--model_gen gpt2 \
--prompt_file /path/to/prompts.json \
--save_dir /path/to/save/results \
--similarity_threshold 0.5 \
    --max_new_tokens 50

Key Arguments:
- `--model_path`: Hugging Face model path/identifier for the LM used in inference.
- `--cache_dir`: Directory where models and pickled tries are cached.
- `--token_trie_dir`: Root directory for token trie files from Steps 3 & 4.
- `--dataset` and `--partition`: Must match how the trie was built.
- `--tok_prob_threshold`: Probability threshold used for building the trie.
- `--model_ds`, `--model_gen`: The model used to build the datastore (`model_ds`) and the model used for generation (`model_gen`). Often the same, but they can differ if you built a specialized datastore.
- `--prompt_file`: JSON file with one or more prompts to feed into the model.
- `--save_dir`: Output directory. For each prompt, the script saves a `.json` with the generated text and timing info.
- `--similarity_threshold`: Cosine similarity threshold for deciding if we should retrieve a chunk from the datastore.
- `--max_new_tokens`: Number of tokens to generate beyond the prompt length.
By default, the script:
- Loads the language model and tokenizer.
- (Optionally) flattens any `.h5` trie files into `.pkl` for efficient retrieval (unless `--no_reprocessing` is set).
- For each prompt, attempts chunk-level retrieval. If the chunk retrieval is successful (similarity ≥ threshold), it appends an entire chunk at once. Otherwise, it defaults to single-token LM generation.
- Saves the resulting generation in `prompt_{i}.json` (one file per prompt).
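To make the retrieve-or-fall-back decision concrete, here is a hedged sketch of a single decoding step. The datastore access, chunk format, and function names are illustrative assumptions, not the actual logic in `eval.py`.

```python
# Hedged sketch of one step of chunk-level retrieval with LM fallback:
# if the best cosine similarity clears the threshold, append the stored chunk;
# otherwise append one token from standard LM generation (greedy here).
import torch
import torch.nn.functional as F

def decode_step(hidden_state: torch.Tensor,        # (hidden_dim,)
                next_token_logits: torch.Tensor,   # (vocab_size,)
                datastore_vectors: torch.Tensor,   # (n_entries, hidden_dim)
                datastore_chunks: list,            # n_entries lists of token ids
                similarity_threshold: float = 0.5) -> list:
    """Return the list of token ids to append at this step."""
    if datastore_vectors.numel() > 0:
        sims = F.cosine_similarity(hidden_state.unsqueeze(0), datastore_vectors, dim=-1)
        best_sim, best_idx = sims.max(dim=0)
        if best_sim.item() >= similarity_threshold:
            # Retrieval hit: emit the entire stored chunk at once.
            return list(datastore_chunks[best_idx.item()])
    # Fallback: single-token LM generation.
    return [int(next_token_logits.argmax().item())]

# Toy call with random tensors, just to show the shapes involved.
print(decode_step(torch.randn(768), torch.randn(50257),
                  torch.randn(4, 768),
                  [[464, 2068], [318, 257], [1049, 1110], [13]]))
```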
If your prompts.json looks like:
[
"Example prompt 1...",
"Example prompt 2..."
]

Then `eval.py` generates completions for each prompt and writes them to:
path/to/save/results/
├── prompt_0.json
└── prompt_1.json
Each .json file contains both the generation details and timing information.
If you want to compare with pure language model generation (no retrieval), add `--gen_baseline_only`. Internally, it sets `similarity_threshold=1`, effectively disabling retrieval.
Please open an issue on this GitHub repo for any troubleshooting or feature requests.