Recent advances in large language models (LLMs) have demonstrated impressive reasoning capabilities. However, reasoning confined to the internal parametric space limits LLMs' access to real-time information and understanding of the physical world. To overcome this constraint, we introduce SituatedThinker, a novel framework that enables LLMs to ground their reasoning in real-world contexts through situated thinking, which adaptively combines internal knowledge and external information via predefined interfaces. By utilizing reinforcement learning, SituatedThinker incentivizes deliberate interaction with the real world to acquire information and feedback, allowing LLMs to surpass their knowledge boundaries and enhance their reasoning. Experimental results demonstrate significant performance improvements on multi-hop question-answering and mathematical reasoning benchmarks. Furthermore, SituatedThinker achieves strong performance on unseen tasks, such as KBQA, TableQA, and text-based games, showcasing a generalizable real-world grounded reasoning capability.
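Concretely, a situated-thinking rollout interleaves LLM generation with interface calls: when the model emits a query wrapped in an interface marker, the framework executes the call (e.g., retrieval or a code sandbox), appends the result to the context, and resumes generation. The sketch below illustrates this loop; the marker format and function names are hypothetical, not the exact ones defined in `src`.

```python
import re

# Hypothetical marker format; the actual special tokens are defined by the framework.
CALL_PATTERN = re.compile(r"<(retrieval|code)>(.*?)</\1>", re.DOTALL)

def situated_generate(llm_generate, execute_interface, prompt, max_calls=8):
    """Interleave generation with interface calls (illustrative sketch).

    llm_generate:      fn(text) -> continuation, stopping after a closing marker
    execute_interface: fn(name, payload) -> result string (retrieved passages,
                       sandboxed code output, etc.)
    """
    text = prompt
    for _ in range(max_calls):
        continuation = llm_generate(text)
        text += continuation
        match = CALL_PATTERN.search(continuation)
        if match is None:  # no interface call: the model has finished reasoning
            return text
        # Execute the call against the real world and append the feedback.
        result = execute_interface(match.group(1), match.group(2))
        text += f"<result>{result}</result>"
    return text
```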
- 🐍 Python 3.10
Run the following commands to set up the environment and install dependencies:
```bash
conda create -n situated-thinker python=3.10
conda activate situated-thinker
cd src
pip install -e .
pip install flash-attn==2.7.4.post1 --no-build-isolation
```

We use the Wikipedia 2018 dump as the corpus for the retrieval interface.
First, deploy the embedding model to generate embeddings for both text and queries:
```bash
scripts/deploy/deploy_llm.sh [your_embedding_model_name] [num_gpu]
```

Update the following entries in `src/runtime.env`:

```
RETRIEVAL_EMBED_MODEL=your_embedding_model_name
RETRIEVAL_EMBED_KEY=EMPTY
RETRIEVAL_EMBED_DIM=your_embedding_model_dim
RETRIEVAL_EMBED_URL=your_embedding_model_url
```
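The `RETRIEVAL_EMBED_KEY=EMPTY` entry suggests the deployed embedding model speaks an OpenAI-compatible API (as served by vLLM, for example). If so, you can sanity-check the deployment roughly as follows; the model name and URL are placeholders for your `src/runtime.env` values:

```python
from openai import OpenAI

# Placeholders: use the values you set in src/runtime.env.
client = OpenAI(base_url="http://your_embedding_model_url/v1", api_key="EMPTY")
resp = client.embeddings.create(
    model="your_embedding_model_name",
    input=["What is situated thinking?"],
)
print(len(resp.data[0].embedding))  # should equal RETRIEVAL_EMBED_DIM
```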
Then, generate the FAISS index for the corpus:

```bash
scripts/build_faiss_index.sh
```

The index and the corresponding corpus will be saved in `cache/faiss_index/wikipedia18`.
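For intuition, the effect of the indexing step is roughly the following minimal sketch, assuming a flat inner-product index over normalized passage embeddings; the actual index type and on-disk layout are whatever `scripts/build_faiss_index.sh` defines:

```python
import faiss
import numpy as np

dim = 768  # must match RETRIEVAL_EMBED_DIM
# Stand-in for the embeddings computed for each corpus passage.
passages = np.random.rand(10_000, dim).astype("float32")
faiss.normalize_L2(passages)  # normalized inner product equals cosine similarity

index = faiss.IndexFlatIP(dim)
index.add(passages)
faiss.write_index(index, "index.faiss")  # hypothetical file name

# Retrieval: embed the query the same way, then take the top-k passages.
query = np.random.rand(1, dim).astype("float32")
faiss.normalize_L2(query)
scores, ids = index.search(query, 5)
print(ids[0])  # indices of the 5 most similar passages
```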
Deploy the retrieval interface on a machine with at least one A100-80G GPU or equivalent:
```bash
scripts/deploy/deploy_retrieval_interface.sh
```

Obtain the URL of the deployed retrieval interface; it is needed as `RETRIEVAL_URL` for training and evaluation.
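Reasoning rollouts query this interface over HTTP. The exact route and payload schema are defined by the deploy script; the sketch below assumes a simple JSON POST with a hypothetical `/retrieve` route and fields:

```python
import requests

retrieval_url = "http://your_retrieval_host:port"  # the URL obtained above

resp = requests.post(
    f"{retrieval_url}/retrieve",  # hypothetical route
    json={"query": "Who wrote The Left Hand of Darkness?", "top_k": 5},
)
resp.raise_for_status()
print(resp.json())  # expected: the top-k retrieved Wikipedia passages
```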
Refer to SandboxFusion to deploy the sandbox service. Then, set your sandbox service URL in `src/runtime.env`:
```
SANDBOX_FUSION_ENDPOINT=your_sandbox_url
```
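You can smoke-test the sandbox service via SandboxFusion's `run_code` route (field names follow the SandboxFusion docs; adjust if your deployment differs):

```python
import requests

sandbox_url = "http://your_sandbox_url"  # value of SANDBOX_FUSION_ENDPOINT

resp = requests.post(
    f"{sandbox_url}/run_code",
    json={"code": "print(1 + 1)", "language": "python"},
)
resp.raise_for_status()
print(resp.json())  # a successful run should report stdout "2\n"
```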
Run the following script to generate training data for GRPO:

```bash
scripts/build_grpo_data.sh
```

The data will be saved in `cache/data/grpo`.
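For context, GRPO samples a group of $G$ rollouts per question and uses group-normalized rewards as advantages, avoiding a separate value model. In the standard formulation (our training code may apply a variant), the advantage of rollout $i$ with reward $r_i$ is:

$$\hat{A}_i = \frac{r_i - \operatorname{mean}\left(\{r_1, \ldots, r_G\}\right)}{\operatorname{std}\left(\{r_1, \ldots, r_G\}\right)}$$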
On the master machine, start the Ray cluster:
```bash
ray start --head --port=8266
```

On the other machines, connect to the master node:
```bash
ray start --address=[master_machine_ip]:8266
```
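Optionally, verify that every node has joined before launching training; a small check using Ray's Python API:

```python
import ray

ray.init(address="auto")  # connect to the running cluster
alive = [node for node in ray.nodes() if node["Alive"]]
print(f"{len(alive)} node(s) alive in the Ray cluster")
```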
On the master machine, run the following script to initiate training:

```bash
export WANDB_KEY=your_wandb_key
export SWANLAB_API_KEY=your_swanlab_key
export RETRIEVAL_URL=your_retrieval_url
scripts/run_grpo_multinode.sh [llm_name] [num_node] [num_gpu_per_node] [tp_for_rollout] [gpu_memory_utilization_for_rollout]
```
To evaluate a checkpoint, set the retrieval interface URL and run the evaluation script:

```bash
export RETRIEVAL_URL=your_retrieval_url
scripts/evaluate.sh [dataset_name] [checkpoint_path] [output_path] [num_gpu] [tp_size] [num_generation] [temperature]
```

Here, `dataset_name` can be one of `hotpotqa`, `2wiki`, `musique`, `bamboogle`, `aime24`, `aime25`, `math500`, `medqa`, `gpqa`, `webqsp`, `wtq`, and `textworld`. `checkpoint_path` is the path to the checkpoint saved by veRL, `output_path` is the save path for evaluation details, `num_generation` is the number of generations per question (if `num_generation` > 1, the evaluation metric is averaged over the generations), and `temperature` is the sampling temperature.
If you find our work helpful, please consider citing our paper.