SituatedThinker: Grounding LLM Reasoning with Real-World through Situated Thinking

(Figure: SituatedThinker overview)

📋 Introduction

Recent advances in large language models (LLMs) demonstrate impressive reasoning capabilities. However, reasoning confined to the internal parametric space limits LLMs' access to real-time information and understanding of the physical world. To overcome this constraint, we introduce SituatedThinker, a novel framework that enables LLMs to ground their reasoning in real-world contexts through situated thinking, which adaptively combines internal knowledge and external information via predefined interfaces. By utilizing reinforcement learning, SituatedThinker incentivizes deliberate interaction with the real world to acquire information and feedback, allowing LLMs to surpass their knowledge boundaries and enhance their reasoning. Experimental results demonstrate significant performance improvements on multi-hop question-answering and mathematical reasoning benchmarks. Furthermore, SituatedThinker performs strongly on unseen tasks such as KBQA, TableQA, and text-based games, showcasing a generalizable, real-world-grounded reasoning capability.

📦 Dependencies

  • 🐍 Python 3.12 (matching the conda command below)

⚙️ Installation

Run the following commands to set up the environment and install dependencies:

conda create -n situated-thinker python=3.12
conda activate situated-thinker
cd src
pip install -e .
pip install flash-attn==2.7.4.post1 --no-build-isolation

🚀 Quick Start

🧠 GRPO Training

🔍 Prepare Retrieval Interface

🗂️ Index Corpus

We utilize the Wikipedia 2018 dump as the corpus for the retrieval interface.

First, deploy the embedding model to generate embeddings for both text and queries:

scripts/deploy/deploy_llm.sh [your_embedding_model_name] [num_gpu]

Update the following entries in src/runtime.env:

RETRIEVAL_EMBED_MODEL=your_embedding_model_name
RETRIEVAL_EMBED_KEY=EMPTY
RETRIEVAL_EMBED_DIM=your_embedding_model_dim
RETRIEVAL_EMBED_URL=your_embedding_model_url
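These entries are simple KEY=VALUE pairs. If you want to inspect or reuse them outside the provided scripts, a dotenv-style file like this can be loaded with the standard library alone; the helper below is a minimal sketch (the project presumably has its own loader, and `load_runtime_env` is a hypothetical name):

```python
import os

def load_runtime_env(path="src/runtime.env"):
    """Parse simple KEY=VALUE lines into os.environ.

    Blank lines and '#' comments are skipped; existing environment
    variables are not overwritten. Hypothetical helper, not the
    repository's own loader.
    """
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#") or "=" not in line:
                continue
            key, _, value = line.partition("=")
            os.environ.setdefault(key.strip(), value.strip())
```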

Then, generate the FAISS index for the corpus:

scripts/build_faiss_index.sh

The index and corresponding corpus will be saved in cache/faiss_index/wikipedia18.
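Conceptually, the FAISS index answers "which corpus passages have embeddings most similar to the query embedding". The brute-force NumPy sketch below illustrates the same inner-product top-k search that a flat FAISS index performs (this is an illustration only; the build script uses FAISS itself, and the function name here is hypothetical):

```python
import numpy as np

def top_k_inner_product(query, corpus_embeddings, k=3):
    """Brute-force stand-in for a flat inner-product FAISS search.

    Returns the indices of the k corpus vectors with the highest
    dot product against the query vector.
    """
    scores = corpus_embeddings @ query          # one score per passage
    return np.argsort(-scores)[:k]              # indices of top-k scores
```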

🚀 Deploy Retrieval Interface

Deploy the retrieval interface on a machine with at least one A100-80G GPU or equivalent:

scripts/deploy/deploy_retrieval_interface.sh

Obtain the URL of the deployed retrieval interface.
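The deployed interface is then reachable over HTTP. The sketch below shows one way a client could query it with only the standard library; the endpoint path, JSON field names (`query`, `top_k`), and response format are assumptions, so check them against the deployed service:

```python
import json
import urllib.request

def build_request(url, query, top_k=5):
    """Build an HTTP request for the retrieval interface.

    The URL, payload fields, and headers here are assumptions about
    the deployed service, not its documented API.
    """
    payload = json.dumps({"query": query, "top_k": top_k}).encode()
    return urllib.request.Request(
        url, data=payload, headers={"Content-Type": "application/json"}
    )

def retrieve(url, query, top_k=5):
    """Send the query and decode the JSON response."""
    with urllib.request.urlopen(build_request(url, query, top_k)) as resp:
        return json.loads(resp.read())
```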

🧪 Prepare Code Execution Interface

🛡️ Deploy Sandbox

Refer to SandboxFusion to deploy the sandbox service. Then, set your sandbox service URL in src/runtime.env:

SANDBOX_FUSION_ENDPOINT=your_sandbox_url
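SandboxFusion exposes a `run_code` endpoint that accepts the code and its language. The sketch below shows the shape of such a call; the exact path and field names should be verified against the SandboxFusion documentation for your deployed version:

```python
import json
import urllib.request

def sandbox_payload(code, language="python"):
    # Request body for a SandboxFusion-style run_code call;
    # verify field names against your deployment's docs.
    return {"code": code, "language": language}

def run_in_sandbox(endpoint, code, language="python"):
    """POST code to the sandbox and return the decoded JSON result."""
    data = json.dumps(sandbox_payload(code, language)).encode()
    req = urllib.request.Request(
        endpoint.rstrip("/") + "/run_code",
        data=data,
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read())
```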

📄 Prepare Training Data

Run the following script to generate training data for GRPO:

scripts/build_grpo_data.sh

The data will be saved in cache/data/grpo.

🏋️‍♂️ Training

⚡ Start Ray Cluster

On the master machine, start the Ray cluster:

ray start --head --port=8266

On other machines, connect to the master node:

ray start --address=[master_machine_ip]:8266

🎯 Start Training

On the master machine, run the following script to initiate training:

export WANDB_KEY=your_wandb_key
export SWANLAB_API_KEY=your_swanlab_key
export RETRIEVAL_URL=your_retrieval_url

scripts/run_grpo_multinode.sh [llm_name] [num_node] [num_gpu_per_node] [tp_for_rollout] [gpu_memory_utilization_for_rollout]

🧪 Evaluation

export RETRIEVAL_URL=your_retrieval_url 

scripts/evaluate.sh [dataset_name] [checkpoint_path] [output_path] [num_gpu] [tp_size] [num_generation] [temperature]

Here, dataset_name can be one of hotpotqa, 2wiki, musique, bamboogle, aime24, aime25, math500, medqa, gpqa, webqsp, wtq, and textworld; checkpoint_path is the path to the checkpoint saved by veRL; output_path is where the evaluation details are saved; num_generation is the number of generations per question (if num_generation > 1, the evaluation metric is averaged across generations); and temperature is the sampling temperature.
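The averaging behavior for num_generation > 1 can be sketched as below; the per-generation scoring shown here is a simple exact match, which is an assumption, since the repository's actual metric may differ per dataset:

```python
def exact_match(prediction, answer):
    # Hypothetical per-generation score (case-insensitive exact match);
    # the repo's actual metric may differ per dataset.
    return float(prediction.strip().lower() == answer.strip().lower())

def averaged_metric(generations, answer):
    """Average the per-generation score, as done when num_generation > 1."""
    scores = [exact_match(g, answer) for g in generations]
    return sum(scores) / len(scores)
```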

📖 Citation

If you find our work helpful, please consider citing our paper.

About

[Preprint 2025] SituatedThinker: Grounding LLM Reasoning with Real-World through Situated Thinking
