Feel free to star the repo or cite the paper if you find it interesting:

```bibtex
@article{fu2025tah,
  title={Think-at-Hard: Selective Latent Iterations to Improve Reasoning Language Models},
  author={Tianyu Fu and Yichen You and Zekai Chen and Guohao Dai and Huazhong Yang and Yu Wang},
  journal={arXiv preprint arXiv:2510.08577},
  year={2025},
}
```
- [2025/11] We released the TaH-plus-1.7B checkpoint. The model is finetuned from Qwen3-1.7B-Base on 100K samples from the OpenR1 dataset and is capable of QA, math, and coding.
- [2025/11] Our paper was featured as the #2 Paper of the Day on Hugging Face Daily Papers.
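If you want the released checkpoint locally before running the commands below, a download along these lines should work. This is not part of the documented workflow; it only assumes the checkpoint is hosted on the Hugging Face Hub under the `nics-efc/TaH-plus-1.7B` id used in the evaluation example later in this README.

```python
# Hedged sketch: fetch the TaH-plus-1.7B checkpoint from the Hugging Face Hub.
# Assumes the repo id "nics-efc/TaH-plus-1.7B" taken from the evaluation command below.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(repo_id="nics-efc/TaH-plus-1.7B")
print(f"Checkpoint downloaded to: {local_dir}")
```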
Create a new environment:

```bash
conda create -n tah python=3.10
conda activate tah
```

Install the package:

```bash
pip install -e .
```

For training and evaluation, install additional dependencies:

```bash
pip install -e ".[training,evaluation]"
```

For code generation evaluation, also install `evalplus`.

Run the inference example:

```bash
python script/playground/inference_example.py
```

This script demonstrates TaH's selective latent iteration mechanism, with color-coded output showing the iteration count for each token.
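To make the mechanism concrete, here is a toy sketch of the control flow described above: most tokens take a single forward pass, while tokens flagged as hard get extra latent iterations whose hidden state is folded back into the input. It is only an illustration; every name in it (`base_forward`, `decide_hard`, ...) is made up and does not correspond to the `tah` package API.

```python
# Toy sketch of selective latent iteration (NOT the tah package implementation).
# Hard tokens get extra latent passes; easy tokens keep the usual single pass.
MAX_ITER = 2  # mirrors max_iter: 2 in the training configs

def base_forward(token, latent=0.0):
    """Stand-in for one decoder pass; `latent` mimics an additive hidden-state update."""
    return hash((token, round(latent, 3))) % 100 / 100.0

def decide_hard(hidden):
    """Stand-in for the iteration decider (fixed labels in Step 1, an MLP in Step 2)."""
    return hidden > 0.7  # arbitrary threshold for this toy example

def generate(tokens):
    iters_used = []
    for tok in tokens:
        hidden = base_forward(tok)
        n_iter = 1
        # Re-run the "decoder" only while the token is judged hard.
        while n_iter < MAX_ITER and decide_hard(hidden):
            hidden = base_forward(tok, latent=hidden)  # additive-style input update
            n_iter += 1
        iters_used.append(n_iter)
    return iters_used

print(generate(["The", "answer", "is", "42"]))  # iteration count per token
```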
To evaluate a TaH model on a benchmark:

```bash
python script/evaluation/eval.py \
    --eval_config ./script/recipes/qwen3_1.7/eval_tah.yaml \
    --model_path nics-efc/TaH-plus-1.7B \
    --dataset_name gsm8k \
    --backend tah \
    --job_nums 8 \
    --tp_size_per_job 1
```

Key parameters:

- `--eval_config`: Path to the evaluation config file
- `--model_path`: Path to the model
- `--dataset_name`: Dataset name (supports gsm8k, math500, aime24, etc.; detailed configs can be found in `tah/evaluate/eval_configs/dataset_configs.json`)
- `--backend`: Inference backend (`tah` for TaH)
- `--job_nums`: Number of parallel jobs
- `--tp_size_per_job`: Tensor parallel size per job
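If you just want to see which datasets are configured, a quick check like the one below works; it assumes `dataset_configs.json` is a JSON object keyed by dataset name, which may not match the actual schema.

```python
# Hedged sketch: list the dataset names available for evaluation.
# Assumes dataset_configs.json is a JSON object keyed by dataset name.
import json

with open("tah/evaluate/eval_configs/dataset_configs.json") as f:
    dataset_configs = json.load(f)

print(sorted(dataset_configs))  # e.g., aime24, gsm8k, math500, ...
```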
To evaluate the standard (non-TaH) baseline:

```bash
python script/evaluation/eval.py \
    --eval_config ./script/recipes/qwen3_1.7/eval_base.yaml \
    --model_path nics-efc/Standard-1.7B \
    --dataset_name gsm8k \
    --backend hf \
    --job_nums 8 \
    --tp_size_per_job 1
```

The parameters are the same as for TaH evaluation, but with `--backend hf` or `--backend sglang` instead of `--backend tah`.
Training a TaH model consists of three stages:
1. Prepare training data

Use a reference model to generate hard token labels for the training and validation data:
```bash
### step 0
python script/preparation/label.py \
    --num_gpu 8 \
    --dataset_path ./data/initial_data/openr1-math/train.jsonl \
    --test_model_list Qwen/Qwen3-1.7B \
    --output_path ./data/processed_data/openr1-math/1_7/train \
    --max_input_length 10000

python script/preparation/label.py \
    --num_gpu 8 \
    --dataset_path ./data/initial_data/openr1-math/eval.jsonl \
    --test_model_list Qwen/Qwen3-1.7B \
    --output_path ./data/processed_data/openr1-math/1_7/eval \
    --max_input_length 10000
```
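For intuition, hard-token labels of this kind are typically derived by running the reference model teacher-forced over the ground-truth text and marking positions where its greedy prediction disagrees with the actual next token. The snippet below is only a rough sketch of that idea using a standard Hugging Face causal LM; the real logic and output format live in `script/preparation/label.py` and may differ.

```python
# Hedged sketch: mark "hard" tokens as positions where a reference model's greedy
# prediction mismatches the ground-truth next token (NOT the actual label.py logic).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen3-1.7B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto").eval()

text = "The sum of the first 10 positive integers is 55."
ids = tokenizer(text, return_tensors="pt").input_ids

with torch.no_grad():
    logits = model(ids).logits  # [1, seq_len, vocab]

# Prediction at position t is compared against the true token at position t + 1.
pred = logits[:, :-1, :].argmax(dim=-1)
target = ids[:, 1:]
mismatch = (pred != target).long()[0].tolist()  # 1 = hard token, 0 = easy token
print(mismatch)
```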
2. (Optional) Prepare pruned model

For the TaH version, prune one layer from the base model to match the parameter count of the standard baseline (skip this step for the TaH+ version):

```bash
### step 0
python script/preparation/prune.py \
    --model Qwen/Qwen3-1.7B-Base \
    --dataset ./data/processed_data/openr1-math/1_7/eval \
    --output ./model/qwen3_1.7_base_pruned \
    --num_prune 1
```
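Dropping a decoder layer from a Hugging Face transformer is conceptually simple; the sketch below shows one common way to do it (removing the last layer and updating the config). It is only an illustration under that assumption; `script/preparation/prune.py` may choose which layer to remove differently, for example based on the calibration data passed via `--dataset`.

```python
# Hedged sketch: remove one decoder layer from a Qwen3-style HF model
# (NOT necessarily how prune.py selects the layer to drop).
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-1.7B-Base", torch_dtype="auto")

num_prune = 1
layers = model.model.layers      # nn.ModuleList of decoder layers
del layers[-num_prune:]          # here: simply drop the last layer(s)
model.config.num_hidden_layers = len(layers)

model.save_pretrained("./model/qwen3_1.7_base_pruned")
```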
3. Train the model

Training proceeds in two steps. Step 1 uses fixed iteration labels:

```bash
### step 1
python -m accelerate.commands.launch \
    --config_file ./script/recipes/accelerate_configs/zero2.yaml \
    --num_processes 8 \
    ./script/train/SFT_TaH.py \
    --config ./script/recipes/qwen3_1.7/sft_tah_step1.yaml
```

Key configurations in Step 1 (`sft_tah_step1.yaml`):

- `max_iter: 2`: Maximum number of iterations
- `iter_decider: "FixedLabelIterDecider"`: Use fixed labels to decide iterations
- `iter_label_generator: "FixedIterLabelGenerator"`: Generate labels from the mismatch field in the data
- `input_updater: "AdditiveUpdater"`: Use the additive updater for input updates
- `adapter: "lora"`: Use a LoRA adapter for the deeper iterations
- `train_loss: "NextTokenPredLoss"`: Next-token prediction loss
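As a rough illustration of what `FixedLabelIterDecider` does with the mismatch labels produced in the data-preparation step (assumed semantics; the actual class may differ), hard tokens get up to `max_iter` latent passes while easy tokens keep a single pass:

```python
# Hedged sketch of the fixed-label policy: mismatch == 1 (hard) -> extra iterations,
# mismatch == 0 (easy) -> the usual single pass. Not the tah package implementation.
def fixed_label_iterations(mismatch_labels, max_iter=2):
    return [max_iter if m else 1 for m in mismatch_labels]

print(fixed_label_iterations([0, 0, 1, 0, 1]))  # -> [1, 1, 2, 1, 2]
```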
Step 2 trains the iteration decider:

```bash
### step 2
python -m accelerate.commands.launch \
    --config_file ./script/recipes/accelerate_configs/zero2.yaml \
    --num_processes 8 \
    ./script/train/SFT_TaH.py \
    --config ./script/recipes/qwen3_1.7/sft_tah_step2.yaml
```

Key configurations in Step 2 (`sft_tah_step2.yaml`):

- `tah_model_path`: Load the model trained in Step 1
- `iter_decider: "MLPIterDecider"`: Use an MLP decider to automatically determine iterations
- `train_loss: "IterDeciderLoss"`: Iteration decider loss function
- `freeze_component: [model.simple_base_model]`: Freeze the model backbone
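For intuition about what Step 2 learns: an MLP iteration decider can be as simple as a small head over the per-token hidden state that predicts whether a token needs another latent pass, trained against the hard-token labels while the backbone stays frozen. The sketch below is a guess at that shape; the real `MLPIterDecider` and `IterDeciderLoss` in the repo may be structured differently.

```python
# Hedged sketch: a minimal MLP iteration decider trained with a BCE-style loss
# against hard-token labels, with the backbone frozen (NOT the repo's classes).
import torch
import torch.nn as nn

hidden_size = 2048  # assumed hidden size, for illustration only

decider = nn.Sequential(
    nn.Linear(hidden_size, 256),
    nn.GELU(),
    nn.Linear(256, 1),  # logit: "this token needs another latent iteration"
)

# Dummy batch: per-token hidden states from a frozen backbone + mismatch labels.
hidden_states = torch.randn(4, 16, hidden_size)      # [batch, seq, hidden]
hard_labels = torch.randint(0, 2, (4, 16)).float()   # 1 = hard token

logits = decider(hidden_states).squeeze(-1)          # [batch, seq]
loss = nn.functional.binary_cross_entropy_with_logits(logits, hard_labels)
loss.backward()  # only the decider's parameters receive gradients here
print(float(loss))
```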
After two-stage training, the model can automatically decide when to perform latent reasoning iterations.
```
TaH/
├── tah/                          # Core package
│   ├── model/                    # Core model components
│   ├── train/                    # Training components
│   ├── evaluate/                 # Evaluation utilities
│   └── utils/                    # General utilities
├── bash/                         # Bash scripts for training and evaluation
├── script/                       # Execution scripts
│   ├── analysis/                 # Analysis scripts
│   ├── evaluation/               # Evaluation scripts
│   ├── preparation/              # Preparation for training
│   │   ├── label.py              # Data labeling (generate mismatch labels)
│   │   └── prune.py              # Model pruning
│   ├── playground/               # Some examples
│   └── recipes/                  # Configuration files
│       ├── qwen3_0.6/            # Qwen3-0.6B-Base configs
│       ├── qwen3_1.7/            # Qwen3-1.7B-Base configs
│       └── accelerate_configs/   # Distributed training configs
└── pyproject.toml                # Project configuration
```
- Support more inference backends (e.g., SGLang)
- Optimize iteration decision strategies
- Integrate TaH with online distillation or RL
- Support training for larger models