Train Multiple Agent Roles Within a Single LLM via Reinforcement Learning.
MATPO allows planner and worker agents to coexist within a single LLM and be trained via RL, achieving an 18.38% relative improvement over single-agent baselines on GAIA-text, FRAMES, and WebWalker-QA.
- [2025-Oct-31] Enabled LoRA support for MATPO training
- [2025-Oct-31] Quick-start guide based on the basic PyTorch Docker image released
- [2025-Oct-08] MATPO-Qwen3-14B checkpoints and rollouts released
- [2025-Oct-08] Code and training scripts released
- [2025-Oct-06] arXiv paper released
MATPO (Multi-Agent Tool-Integrated Policy Optimization) is a novel reinforcement learning framework that enables training multiple specialized agent roles (planner and worker agents) within a single large language model.
Current single-agent approaches for multi-turn tool-integrated planning face critical limitations:
- Context Length Bottleneck: Tool responses (e.g., web scraping) consume excessive tokens, making long-range planning prohibitive
- Noisy Tool Responses: Raw tool responses interfere with the model's attention and planning capabilities
MATPO introduces a multi-agent-in-one-model architecture where:
- A planner-agent orchestrates high-level planning and delegates subtasks
- Worker-agents handle specific browsing and search tasks with isolated contexts
- Both roles are trained within a single LLM using role-specific prompts via reinforcement learning
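To make the multi-agent-in-one-model idea concrete, here is a minimal sketch of how the two roles could share one policy while keeping separate contexts. The system prompt strings, the `Conversation` class, and `start_conversation()` are hypothetical illustrations, not MATPO's actual prompts or data structures.

```python
from dataclasses import dataclass, field

# Hypothetical role-specific system prompts (placeholders, not MATPO's real prompts).
SYSTEM_PROMPTS = {
    "planner": "You are the planner. Break the task into subtasks and delegate them to workers.",
    "worker": "You are a browsing worker. Solve the given subtask with search tools, then summarize.",
}


@dataclass
class Conversation:
    """One agent's context window. Every worker call starts a fresh one."""
    role: str
    messages: list = field(default_factory=list)


def start_conversation(role: str, user_content: str) -> Conversation:
    """Open a new context for `role`, seeded with that role's system prompt."""
    if role not in SYSTEM_PROMPTS:
        raise ValueError(f"unknown role: {role!r}")
    conv = Conversation(role=role)
    conv.messages.append({"role": "system", "content": SYSTEM_PROMPTS[role]})
    conv.messages.append({"role": "user", "content": user_content})
    return conv


# Both conversations would be sent to the SAME model weights; only the system prompt differs.
planner = start_conversation("planner", "Which university did the author of the cited paper attend?")
worker = start_conversation("worker", "Find the author of the cited paper.")

# The worker context holds only its own system prompt and subtask, none of the planner's history.
print(len(planner.messages), len(worker.messages))
```

Keeping the worker context separate is what lets long tool responses stay out of the planner's window, while training a single set of weights for both roles.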
- Multi-Agent-in-One-Model: Train planner and worker agents within a single LLM using role-specific system prompts
- Principled Credit Assignment: Extends GRPO with theoretically grounded reward distribution across planner and worker rollouts
- Easy Integration: Built on top of veRL, compatible with existing RL training frameworks
- Robust Training: More stable learning curves compared to single-agent approaches, especially with noisy tool responses
- Infrastructure-Efficient: No need to deploy separate models or additional rollout engines
MATPO employs a hierarchical multi-agent framework where a single LLM serves multiple roles:
User Query → Planner Agent → Subtask 1 → Worker Agent → Result 1
                           → Subtask 2 → Worker Agent → Result 2
                           → ...
                           → Final Answer
Comparison of the rollout trajectories of single-agent GRPO (top) and multi-agent MATPO (bottom).
- Planner Agent:
  - Receives user query with planner-specific system prompt
  - Generates high-level plan and decomposes it into subtasks
  - Delegates subtasks to worker agents
  - Synthesizes worker responses into final answer
- Worker Agent:
  - Receives subtask with worker-specific system prompt
  - Performs multi-turn tool-integrated planning (search, scrape, analyze)
  - Returns summarized result to planner
  - Maintains isolated context to prevent token overflow
- Credit Assignment:
  - Final answer accuracy determines the reward
  - Reward is normalized across all planner-worker rollout groups
  - Gradient flows to both planner actions and worker actions proportionally
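The credit-assignment scheme can be illustrated with a minimal sketch: GRPO-style group normalization over the planner rollouts of one query, with each worker rollout inheriting the advantage of the planner rollout that spawned it (the behavior described in the FAQ). The data layout and function names are hypothetical; the paper's exact objective (e.g., token-level loss weighting or the choice of standard deviation) may differ.

```python
import statistics


def grpo_advantages(rewards, eps=1e-6):
    """Group-relative advantage of each rollout: (r - mean) / (std + eps)."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)  # population std; some implementations use the sample std
    return [(r - mean) / (std + eps) for r in rewards]


def broadcast_to_workers(planner_rollouts):
    """Give every worker segment the advantage of the planner rollout that spawned it."""
    advantages = grpo_advantages([p["reward"] for p in planner_rollouts])
    segments = []
    for rollout, adv in zip(planner_rollouts, advantages):
        segments.append((rollout["id"], "planner", adv))
        for worker_id in rollout["workers"]:
            segments.append((worker_id, "worker", adv))
    return segments


# One query with four planner rollouts (the training setup uses 8);
# rewards come from final-answer accuracy of the planner.
group = [
    {"id": "p0", "reward": 1.0, "workers": ["p0-w0", "p0-w1"]},
    {"id": "p1", "reward": 0.0, "workers": ["p1-w0"]},
    {"id": "p2", "reward": 1.0, "workers": []},
    {"id": "p3", "reward": 0.0, "workers": ["p3-w0", "p3-w1", "p3-w2"]},
]
for seg_id, role, adv in broadcast_to_workers(group):
    print(f"{seg_id:6s} {role:7s} {adv:+.3f}")
```

Because workers share their parent's advantage, a worker rollout is reinforced exactly when the overall trajectory it contributed to beat the group average.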
Visualization of MATPO implementation.
This section provides a quick-start guide based on the basic PyTorch Docker image.
Prerequisites:
- One node with 8 x NVIDIA A800 80GB GPUs (1 x (8 x 80G-A800)), which is sufficient to train a Qwen3-4B model.
- An NVIDIA GPU driver.
- Docker with the NVIDIA Container Toolkit.
Step 1: Clone the repository.
cd YOUR_WORKING_DIR
git clone https://github.com/mzf666/MATPO.git
cd MATPO
Step 2: Download the training and testing datasets to the data directory. The preprocessed datasets can be downloaded here.
Step 3: Launch the PyTorch docker container. Make sure the mounted directories are writable.
docker pull pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel
docker run -it \
--gpus all \
--shm-size 16gb \
--name matpo \
-v YOUR_WORKING_DIR/MATPO:/workspace/MATPO:rw \
-v YOUR_WORKING_DIR/models:/workspace/models \
-w /workspace/MATPO \
pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel \
/bin/bash
Now you are in the PyTorch Docker container.
Step 4: Install the Python dependencies inside the Docker container.
# Execute the following commands inside the docker container
# Create a new conda environment for MATPO
conda create -n matpo python==3.10 -y
conda init bash
source /opt/conda/etc/profile.d/conda.sh
conda activate matpo
# Install the python dependencies
# It is highly recommended to execute the commands in the install.sh script one by one.
source /workspace/MATPO/examples/sglang_multiturn/install.sh
Step 5: Set up Node.js for Serper API support inside the Docker container.
MCP (Model Context Protocol) requires Node.js to run MCP servers; Node.js 18+ is recommended for compatibility with MCP tools. Configure the Node.js paths and the HTTP/HTTPS proxies (if necessary) in the examples/sglang_multiturn/run_in_docker/launch.sh script.
# Install Node.js
apt-get update && apt-get install -y wget xz-utils git
cd /workspace
wget https://nodejs.org/dist/v24.2.0/node-v24.2.0-linux-x64.tar.xz
cd -
NODEJS_HOME=/workspace/nodejs
mkdir -p $NODEJS_HOME
tar -xf /workspace/node-v24.2.0-linux-x64.tar.xz -C $NODEJS_HOME
# Configure Node.js environment variables (the tarball extracts into node-v24.2.0-linux-x64/)
export PATH=$NODEJS_HOME/node-v24.2.0-linux-x64/bin:$PATH
export NODE_SHARED=$NODEJS_HOME/node-shared/node_modules
export PATH=$NODE_SHARED/.bin:$PATH
# Verify Node.js installation
node --version
npm --version
# Install Serper MCP Server
mkdir -p $NODEJS_HOME/node-shared
cd $NODEJS_HOME/node-shared
npm init -y
npm install serper-search-scrape-mcp-server
# Back to MATPO repository
cd /workspace/MATPO
Step 6: Test the environment setup and run single-node training. Train a Qwen3-4B model with MATPO on the MuSiQue dataset and evaluate on the GAIA-text datasets.
Remember to adjust the directory paths in examples/sglang_multiturn/launch.sh accordingly, e.g. YOUR_NODEJS_HOME=/workspace/nodejs/node-v24.2.0-linux-x64 and YOUR_NODE_SHARED=/workspace/nodejs/node-shared/node_modules.
#!/bin/bash
# Execute the following commands inside the docker container
# Tested on 1 x (8 x 80G-A800) nodes
source /opt/conda/etc/profile.d/conda.sh
export SERPER_API_KEY="YOUR_SERPER_API_KEY"
export OPENAI_API_KEY="YOUR_OPENAI_API_KEY"
export WANDB_API_KEY="YOUR_WANDB_API_KEY"
export SINGLENODE=true
export RAY_DEBUG=legacy
export HYDRA_FULL_ERROR=1
conda activate matpo
cd /workspace/MATPO
bash ./examples/sglang_multiturn/launch.sh \
examples/sglang_multiturn/qwen3-4b_musique_single_agent.sh
# bash ./examples/sglang_multiturn/run_in_docker/launch.sh ./examples/sglang_multiturn/run_in_docker/qwen3-4b_musique_single_agent.sh
If you encounter any issues during the environment setup, refer to examples/sglang_multiturn/pip_list_reference.txt for the expected list of Python dependencies and check the installation process step by step.
Step 7: Train a Qwen3-14B model with MATPO on the MuSiQue dataset and evaluate on the GAIA-text datasets using computation platforms with multiple GPU nodes. Remember to adjust the directory paths in examples/sglang_multiturn/launch.sh accordingly.
# tested on 16 x (8 x 80G-A800) nodes
export SERPER_API_KEY="YOUR_SERPER_API_KEY" && \
export OPENAI_API_KEY="YOUR_OPENAI_API_KEY" && \
export WANDB_API_KEY="YOUR_WANDB_API_KEY" && \
export SINGLENODE=true && \
export RAY_DEBUG=legacy && \
export HYDRA_FULL_ERROR=1 && \
source YOUR_CONDA_PATH activate matpo && \
cd YOUR_PROJECT_PATH && \
bash examples/sglang_multiturn/launch.sh \
examples/sglang_multiturn/qwen3-14b_musique_MATPO.sh
Evaluate a trained MATPO / single-agent model checkpoint:
# Tested on 2 x (8 x 80G-A800) nodes
export SERPER_API_KEY="YOUR_SERPER_API_KEY" && \
export OPENAI_API_KEY="YOUR_OPENAI_API_KEY" && \
export WANDB_API_KEY="YOUR_WANDB_API_KEY" && \
export SINGLENODE=true && \
export RAY_DEBUG=legacy && \
export HYDRA_FULL_ERROR=1 && \
source YOUR_CONDA_PATH activate matpo && \
cd YOUR_PROJECT_PATH && \
bash examples/sglang_multiturn/launch.sh \
examples/sglang_multiturn/eval_MATPO.sh
# # To evaluate a trained single-agent GRPO model checkpoint:
# bash examples/sglang_multiturn/launch.sh \
# examples/sglang_multiturn/eval_single_agent.sh
We have enabled LoRA support for MATPO training. Please refer to MATPO_LORA_README.md for more details. An example command for LoRA-enabled MATPO training on Qwen3-4B is provided below:
#!/bin/bash
# Execute the following commands inside the docker container
# Tested on 1 x (8 x 80G-A800) nodes
source /opt/conda/etc/profile.d/conda.sh
export SERPER_API_KEY="YOUR_SERPER_API_KEY"
export OPENAI_API_KEY="YOUR_OPENAI_API_KEY"
export WANDB_API_KEY="YOUR_WANDB_API_KEY"
export SINGLENODE=true
export RAY_DEBUG=legacy
export HYDRA_FULL_ERROR=1
conda activate matpo
cd /workspace/MATPO
bash ./examples/sglang_multiturn/launch.sh \
examples/sglang_multiturn/qwen3-4b_musique_MATPO_lora.sh
MATPO consistently outperforms single-agent GRPO baselines across all benchmarks:
| Method | GAIA-text | WebWalkerQA | FRAMES | Relative Average Improvement |
|---|---|---|---|---|
| Single-Agent GRPO | 32.16% | 30.14% | 56.22% | - |
| MATPO (Ours) | 42.60% | 33.00% | 63.64% | +18.38% |
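The "Relative Average Improvement" column appears to be the mean of the per-benchmark relative gains over the single-agent baseline. Under that assumption, the following snippet reproduces the reported +18.38% from the table values.

```python
# Accuracies (%) copied from the table above.
baseline = {"GAIA-text": 32.16, "WebWalkerQA": 30.14, "FRAMES": 56.22}
matpo = {"GAIA-text": 42.60, "WebWalkerQA": 33.00, "FRAMES": 63.64}

# Relative gain per benchmark, then the unweighted mean across benchmarks.
relative = {k: (matpo[k] - baseline[k]) / baseline[k] for k in baseline}
average = sum(relative.values()) / len(relative)

for name, gain in relative.items():
    print(f"{name}: {gain:+.2%}")
print(f"Relative average improvement: {average:+.2%}")
```

Note that this averages relative gains, so the large GAIA-text gain weighs more than it would in an average of absolute point differences.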
- Base Model: Qwen3-14B-base
- Training Dataset: Filtered MuSiQue dataset.
- Training Steps: 180 steps
- Rollouts per Query: 8 (for group normalization)
- Reward Function: 0.9 × accuracy + 0.1 × tool_format_reward
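The reward above could be wired up roughly as follows. This is a sketch under stated assumptions: exact-match scoring stands in for the actual accuracy judge, and the Qwen-style `<tool_call>` JSON format and the "every call must parse" rule for `tool_format_reward` are hypothetical, not taken from the MATPO code.

```python
import json
import re

# Assumed tool-call format: JSON objects wrapped in <tool_call>...</tool_call> tags.
TOOL_CALL_RE = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)


def normalize(text: str) -> str:
    """Lowercase and collapse punctuation/whitespace for a simple exact-match check."""
    return " ".join(re.sub(r"[^a-z0-9]+", " ", text.lower()).split())


def accuracy_reward(prediction: str, ground_truth: str) -> float:
    # Stand-in for the real accuracy judge, which may be model-based.
    return 1.0 if normalize(prediction) == normalize(ground_truth) else 0.0


def tool_format_reward(response: str) -> float:
    """1.0 if the response contains tool calls and all are well-formed (assumed rule), else 0.0."""
    calls = TOOL_CALL_RE.findall(response)
    if not calls:
        return 0.0
    for raw in calls:
        try:
            call = json.loads(raw)
        except json.JSONDecodeError:
            return 0.0
        if not isinstance(call, dict) or not all(k in call for k in ("name", "arguments")):
            return 0.0
    return 1.0


def matpo_style_reward(response: str, prediction: str, ground_truth: str) -> float:
    """0.9 x accuracy + 0.1 x tool_format_reward, as listed in the training setup."""
    return 0.9 * accuracy_reward(prediction, ground_truth) + 0.1 * tool_format_reward(response)


good_call = '<tool_call>{"name": "search_and_browse", "arguments": {"subtask": "find X"}}</tool_call>'
print(matpo_style_reward(good_call, "Paris", "paris"))  # correct answer, valid format
print(matpo_style_reward(good_call, "Lyon", "paris"))  # wrong answer, valid format
print(matpo_style_reward("<tool_call>{oops</tool_call>", "Paris", "paris"))  # bad format
```

The small format weight rewards well-formed tool use without letting it dominate the accuracy signal.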
We release the trained Qwen3-14B-base model checkpoints at the 180th training step of both single-agent GRPO and MATPO.
The associated model rollouts across various training steps can be found here.
- More Stable Training: MATPO exhibits more stable learning curves and avoids catastrophic performance drops observed in single-agent training
- Robustness to Noise: Multi-agent decomposition effectively isolates noisy tool responses, preventing them from interfering with high-level planning
- Better Credit Assignment: Principled reward distribution across planner and worker rollouts leads to more effective learning
Based on our experiments, we recommend:
- Final Summary: Final summaries from worker agents are critical for clean planner-worker interfaces
- Query Recap: Recapping original user query in worker prompt significantly improves performance
- URL Blocking: Remember to block HuggingFace search results to avoid data leakage
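The three recommendations can be sketched in a few lines. The blocked-domain list, the `link` field of search results, and the prompt wording are all assumptions for illustration; adapt them to your search tool's actual result schema and to the prompts in your training data.

```python
from urllib.parse import urlparse

# Assumed list of domains to block; extend it with wherever your benchmarks are hosted.
BLOCKED_DOMAINS = ("huggingface.co", "hf.co")


def is_blocked(url: str, blocked=BLOCKED_DOMAINS) -> bool:
    """True if the URL's host is a blocked domain or one of its subdomains."""
    host = (urlparse(url).hostname or "").lower()
    return any(host == d or host.endswith("." + d) for d in blocked)


def filter_search_results(results: list) -> list:
    """Drop hits whose 'link' points at a blocked domain (the result schema is hypothetical)."""
    return [r for r in results if not is_blocked(r.get("link", ""))]


def worker_user_message(subtask: str, original_query: str) -> str:
    """Worker prompt that recaps the user's query and asks for a final summary."""
    return (
        f"Original user question (for context only): {original_query}\n"
        f"Your subtask: {subtask}\n"
        "When you are done, end with a concise final summary of your findings."
    )


hits = [
    {"link": "https://huggingface.co/datasets/some-benchmark"},
    {"link": "https://en.wikipedia.org/wiki/Web_search_engine"},
    {"link": "https://cdn-lfs.huggingface.co/some/file"},
]
print([h["link"] for h in filter_search_results(hits)])
print(worker_user_message("Find the release year of X.", "Who created X, and when?"))
```

Matching on the parsed hostname (rather than a substring of the whole URL) blocks subdomains without accidentally blocking look-alike domains such as `nothuggingface.co`.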
To add new tools or worker-agents, you need to:
- Update the agent-tool configuration file: modify the examples/sglang_multiturn/config/tool_config/mcp_tool_config_full_agent.yaml agent-tool configuration file following the instructions in the following subsections.
- Update the system prompt of your training data: update the system prompt of your training data accordingly, so that the planner-agent receives the instructions to call the new tools or new worker-agents correctly. A script that builds training data with updated system prompts from a new agent-tool configuration file is provided in examples/data_preprocess/update_system_prompt.py.
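Before regenerating the system prompts, it can help to sanity-check the configuration. The sketch below works on the YAML already loaded into a Python dict (e.g., with PyYAML's `yaml.safe_load`) and assumes the structure shown in the examples that follow: agents list tools by their `config.server_name`, and each tool whose `tool_schema.type` is `"agent"` has a matching `agent_type` entry. `check_agent_tool_config()` is a hypothetical helper, not part of the repository.

```python
def check_agent_tool_config(cfg: dict) -> list:
    """Return a list of problems found in an agent-tool config (structure assumed from the examples)."""
    problems = []
    tools = {t["config"]["server_name"]: t for t in cfg.get("tools", [])}
    agents = {a["agent_type"]: a for a in cfg.get("agents", [])}
    if "main_agent" not in agents:
        problems.append("no main_agent (planner) entry")
    # Every tool an agent may call must be defined in the tools list.
    for agent_type, agent in agents.items():
        for name in agent.get("tools", []):
            if name not in tools:
                problems.append(f"{agent_type} references undefined tool {name!r}")
    # Every agent-type tool needs an agents entry describing what that worker can call.
    for name, tool in tools.items():
        if tool.get("tool_schema", {}).get("type") == "agent" and name not in agents:
            problems.append(f"agent-type tool {name!r} has no matching agents entry")
    return problems


# Trimmed-down version of the worker-agent example, with the python_agent entry forgotten.
config = {
    "tools": [
        {"config": {"server_name": "search_and_scrape_webpage"}, "tool_schema": {"type": "mcp"}},
        {"config": {"server_name": "python_interpreter"}, "tool_schema": {"type": "mcp"}},
        {"config": {"server_name": "browsing_agent"}, "tool_schema": {"type": "agent"}},
        {"config": {"server_name": "python_agent"}, "tool_schema": {"type": "agent"}},
    ],
    "agents": [
        {"agent_type": "main_agent", "tools": ["browsing_agent", "python_agent"]},
        {"agent_type": "browsing_agent", "tools": ["search_and_scrape_webpage"]},
    ],
}
print(check_agent_tool_config(config))
```

Running such a check before regenerating the training data catches a worker that the planner can call but that has no tools of its own.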
In the following subsections, we present some examples of how to add new tools or worker-agents by modifying the agent-tool configuration file.
You can add a new MCP tool to the planner-agent by adding a new entry to the tools list. Before adding the new tool, you need to wrap your tool server as a standard MCP server. A tutorial on building new MCP tools is provided here.
In this example, we add a new MCP tool for calling a Python interpreter.
tools:
- class_name: "verl.tools.mcp_tool.MCPTool"
config:
command: "serper-mcp"
args: []
env: ["SERPER_API_KEY", "HTTPS_PROXY"]
server_name: "search_and_scrape_webpage"
tool_schema:
type: "mcp"
function: {}
# A newly added Python interpreter MCP tool.
- class_name: "verl.tools.mcp_tool.MCPTool"
config:
command: "python3"
args: ["verl/tools/python_server.py"]
env: ["E2B_API_KEY", "HTTPS_PROXY"]
server_name: "python_interpreter"
tool_schema:
type: "mcp"
function: {}
- class_name: "verl.tools.mcp_tool.MCPTool"
config:
command: ""
server_name: "browsing_agent"
tool_schema:
type: "agent"
function: {}
schema:
name: "search_and_browse"
description: "This tool is an agent that performs the subtask of searching and browsing the web for specific missing information and generating the desired answer. The subtask should be clearly defined, include relevant background, and focus on factual gaps. It does not perform vague or speculative subtasks. \nArgs: \n\tsubtask: the subtask to be performed. \nReturns: \n\tthe result of the subtask."
parameters:
properties:
subtask:
title: "Subtask"
type: "string"
required:
- "subtask"
title: "search_and_browseArguments"
type: "object"
server_name: "browsing_agent"
agents:
- agent_type: "main_agent"
tools:
- browsing_agent
- python_interpreter # Now, the planner-agent can directly call a Python interpreter tool, or it can delegate a searching-and-browsing subtask to the browsing-agent.
- agent_type: "browsing_agent"
tools:
- search_and_scrape_webpage
You can define and add a new worker-agent by adding a new entry to the tools and agents lists. In this example, we add a new worker-agent for tackling programming subtasks using a Python interpreter.
tools:
- class_name: "verl.tools.mcp_tool.MCPTool"
config:
command: "serper-mcp"
args: []
env: ["SERPER_API_KEY", "HTTPS_PROXY"]
server_name: "search_and_scrape_webpage"
tool_schema:
type: "mcp"
function: {}
# A newly added Python interpreter MCP tool.
- class_name: "verl.tools.mcp_tool.MCPTool"
config:
command: "python3"
args: ["verl/tools/python_server.py"]
env: ["E2B_API_KEY", "HTTPS_PROXY"]
server_name: "python_interpreter"
tool_schema:
type: "mcp"
function: {}
- class_name: "verl.tools.mcp_tool.MCPTool"
config:
command: ""
server_name: "browsing_agent"
tool_schema:
type: "agent"
function: {}
schema:
name: "search_and_browse"
description: "This tool is an agent that performs the subtask of searching and browsing the web for specific missing information and generating the desired answer. The subtask should be clearly defined, include relevant background, and focus on factual gaps. It does not perform vague or speculative subtasks. \nArgs: \n\tsubtask: the subtask to be performed. \nReturns: \n\tthe result of the subtask."
parameters:
properties:
subtask:
title: "Subtask"
type: "string"
required:
- "subtask"
title: "search_and_browseArguments"
type: "object"
server_name: "browsing_agent"
# A newly added Python coding agent.
- class_name: "verl.tools.mcp_tool.MCPTool"
config:
command: ""
server_name: "python_agent"
tool_schema:
type: "agent"
function: {}
schema:
name: "coding_in_python"
description: "This tool is an agent that addresses subtasks that can be solved via Python programs. The subtask should be clearly defined, include relevant background, and focus on factual gaps. It does not perform vague or speculative subtasks. \nArgs: \n\tsubtask: the subtask to be performed. \nReturns: \n\tthe result of the subtask."
parameters:
properties:
subtask:
title: "Subtask"
type: "string"
required:
- "subtask"
title: "coding_in_pythonArguments"
type: "object"
server_name: "python_agent"
agents:
- agent_type: "main_agent"
tools:
- browsing_agent
- python_agent # The main-agent can delegate a coding subtask to the python-agent.
- agent_type: "browsing_agent"
tools:
- search_and_scrape_webpage
# A newly added Python coding agent.
- agent_type: "python_agent"
tools:
- python_interpreter # The python-agent can directly call a Python interpreter tool to solve the coding subtask from the main-agent.
In this repository, the multi-agent rollout orchestration is implemented in the _async_rollout_a_request() function in verl/workers/rollout/sglang_rollout/sglang_rollout.py. The workflow is sketched as follows:
- At each intermediate step of the planner-agent rollout:
  - Once an MCP tool call is parsed from the LLM response, the _async_rollout_a_request() function is called to determine which agent to call and to return the MCP tool-call response.
  - Once a worker-agent call is parsed from the LLM response, the _async_rollout_a_request() function creates a new AsyncRolloutRequest object for this worker-agent call and passes it to _async_rollout_a_request() again to trigger a new agentic rollout.
- After the planner-agent rollout terminates (e.g., when no further tool calls are parsed), all the worker-agent rollouts are collected, serialized, and packed with the planner-agent rollout.
- The packed rollout is passed to verl/trainer/ppo/ray_trainer.py to proceed with MATPO training.
You can customize your own rollout orchestration by modifying the logic and scheduling in the _async_rollout_a_request() function.
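The recursive pattern described above can be illustrated with a toy, self-contained sketch. Everything here is hypothetical and only mirrors the shape of the workflow: `fake_llm()` returns scripted replies instead of calling SGLang, `call_mcp_tool()` is a placeholder, the Qwen-style `<tool_call>` JSON format is assumed, and `rollout()` is not verl's `AsyncRolloutRequest` API.

```python
import asyncio
import json
import re

TOOL_CALL_RE = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)
AGENT_TOOLS = {"browsing_agent"}  # hypothetical: tool names that spawn a worker-agent rollout


async def fake_llm(role, messages):
    """Stand-in for the shared policy model; returns scripted replies."""
    turn = sum(m["role"] == "assistant" for m in messages)
    if role == "planner":
        if turn == 0:
            return '<tool_call>{"name": "browsing_agent", "arguments": {"subtask": "find X"}}</tool_call>'
        return "Final answer: 42"
    if turn == 0:
        return '<tool_call>{"name": "search", "arguments": {"query": "X"}}</tool_call>'
    return "Summary: X is 42"


async def call_mcp_tool(name, arguments):
    return f"[{name} result for {arguments}]"  # placeholder for a real MCP server call


async def rollout(role, user_content, max_turns=4):
    """Run one agent's multi-turn rollout; returns (its messages, its worker trajectories)."""
    messages = [{"role": "system", "content": f"<{role} system prompt>"},
                {"role": "user", "content": user_content}]
    workers = []
    for _ in range(max_turns):
        reply = await fake_llm(role, messages)
        messages.append({"role": "assistant", "content": reply})
        calls = [json.loads(c) for c in TOOL_CALL_RE.findall(reply)]
        if not calls:
            break  # no further tool calls: this rollout terminates
        for call in calls:
            if call["name"] in AGENT_TOOLS:
                # Worker-agent call: recurse into a fresh rollout with an isolated context.
                # (Workers in this toy never call agents themselves, so their nested list is dropped.)
                worker_msgs, _ = await rollout("worker", call["arguments"]["subtask"], max_turns)
                workers.append(worker_msgs)
                result = worker_msgs[-1]["content"]  # the worker's final summary goes back to the planner
            else:
                result = await call_mcp_tool(call["name"], call["arguments"])
            messages.append({"role": "tool", "content": result})
    return messages, workers


planner_traj, worker_trajs = asyncio.run(rollout("planner", "What is X?"))
packed = {"planner": planner_traj, "workers": worker_trajs}  # what would be handed to the trainer
print(packed["planner"][-1]["content"], len(packed["workers"]))
```

Only the worker's final summary enters the planner's context, while the full worker trajectory is still kept in the packed sample so it can receive gradient during training.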
If you find MATPO helpful in your research, please consider citing our paper:
@misc{mo2025multiagenttoolintegratedpolicyoptimization,
title={Multi-Agent Tool-Integrated Policy Optimization},
author={Zhanfeng Mo and Xingxuan Li and Yuntao Chen and Lidong Bing},
year={2025},
eprint={2510.04678},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2510.04678},
}
We would like to thank:
- VolcEngine for developing and open-sourcing veRL, the RL training framework that powers MATPO
- Alibaba Cloud for the Qwen3 model series
- Serper for the Google Search API that enables web search capabilities
- The authors of the GAIA, WebWalkerQA, FRAMES, and MuSiQue datasets
- The open-source community for valuable feedback and contributions
Q: What's the difference between MATPO and traditional multi-agent systems?
MATPO uses a single LLM to play multiple agent roles via different system prompts, rather than deploying separate models. This offers:
- Lower infrastructure complexity
- Better parameter efficiency
- Easier deployment and maintenance
- Compatibility with existing RL frameworks
Q: Can I use MATPO with models other than Qwen3?
Yes! MATPO is model-agnostic. You can use any decoder-only LLM that supports tool calling and multi-turn conversations. We've tested with Qwen3-14B-base, but models like Llama 3, Mistral, or other reasoning-capable LLMs should work.
Q: How many GPUs do I need for training?
For Qwen3-14B-base, we recommend:
- Training: 16 x (8 x A100/A800 GPUs (80GB))
- Inference: 2 x (8 x A100/A800 GPUs (80GB))
Q: How does MATPO handle credit assignment?
MATPO extends GRPO with principled credit assignment:
- The planner's final answer determines the accuracy reward
- This reward is normalized across all rollouts in a group
- Gradients flow proportionally to both planner and worker actions
- Worker agents receive the same advantage value as their parent planner rollout
See our paper for more details.
Q: Can I use MATPO for tasks other than web search?
Absolutely! While our paper focuses on web search, MATPO's framework is general. You can extend it to:
- Code generation with execution feedback
- Scientific reasoning with calculator tools
- Data analysis with pandas/SQL tools
- Any multi-turn task with verifiable rewards
Q: How stable is MATPO training compared to single-agent RL?
MATPO is significantly more stable. Our experiments show:
- Single-agent GRPO often suffers catastrophic drops after step 120
- MATPO maintains steady improvement throughout training
- Multi-agent structure isolates noisy tool responses, preventing interference
See Figure 4 in our paper for training curves.
Q: Do I need to block HuggingFace URLs during training?
For research integrity, yes, especially if your evaluation benchmarks are hosted on HuggingFace. This prevents models from "cheating" by finding ground-truth answers online.
For production systems with no data leakage concerns, this is optional.
Star ⭐ this repository if you find it helpful!