Search Agent Example
This example follows the set up from Search-R1 to train a agent that can perform interleaved reasoning and search to answer multi-hop QA questions.
Overview
The search examples demonstrate: - How to use rLLM's ToolAgent and ToolEnvironment - How to write custom tools in rLLM
Quick Start
Setup Search Data
First, prepare your search data:
Run Search Agent
Execute the search agent:
Train Search Agent
Train your own search agent:
Code Reference
Search Agent Runner
Main script for running search operations:
examples/search/run_search_agent.py
import asyncio
import os
from dotenv import load_dotenv
from local_retrieval_tool import LocalRetrievalTool
from transformers import AutoTokenizer
from rllm.agents.system_prompts import SEARCH_SYSTEM_PROMPT
from rllm.agents.tool_agent import ToolAgent
from rllm.data.dataset import DatasetRegistry
from rllm.engine.agent_execution_engine import AgentExecutionEngine
from rllm.environments.tools.tool_env import ToolEnvironment
from rllm.rewards.reward_fn import search_reward_fn
from rllm.utils import save_trajectories
def load_search_data(train_size=3000, test_size=100):
"""
Load search data, preparing it if not already available.
Returns the test dataset data for evaluation.
"""
test_dataset = DatasetRegistry.load_dataset("hotpotqa", "test")
if test_dataset is None:
print("Dataset not found, preparing search dataset...")
from prepare_hotpotqa_data import prepare_hotpotqa_data
_, test_dataset = prepare_hotpotqa_data(train_size=train_size, test_size=test_size)
return test_dataset.get_data()
if __name__ == "__main__":
os.environ["TOKENIZERS_PARALLELISM"] = "true"
if "RETRIEVAL_SERVER_URL" not in os.environ:
os.environ["RETRIEVAL_SERVER_URL"] = "http://127.0.0.1:8000"
load_dotenv()
n_parallel_agents = 64
model_name = "Qwen/Qwen3-4B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
sampling_params = {"temperature": 0.6, "top_p": 0.95, "model": model_name}
tool_map = {"local_search": LocalRetrievalTool}
engine = AgentExecutionEngine(
agent_class=ToolAgent,
agent_args={"tool_map": tool_map, "system_prompt": SEARCH_SYSTEM_PROMPT, "parser_name": "qwen"},
env_class=ToolEnvironment,
env_args={"tool_map": tool_map, "reward_fn": search_reward_fn},
rollout_engine=None,
engine_name="openai",
tokenizer=tokenizer,
sampling_params=sampling_params,
rollout_engine_args={
"base_url": "http://localhost:30000/v1",
"api_key": "None",
},
max_response_length=16384,
max_prompt_length=4096,
config=None,
n_parallel_agents=n_parallel_agents,
)
tasks = load_search_data()
results = asyncio.run(engine.execute_tasks(tasks))
save_trajectories(results, filename="search_trajectories.pt")
Training Script
Search agent training configuration:
examples/search/train_search_agent.py
import hydra
from rllm.agents.system_prompts import SEARCH_SYSTEM_PROMPT
from rllm.agents.tool_agent import ToolAgent
from rllm.data import DatasetRegistry
from rllm.environments.tools.tool_env import ToolEnvironment
from rllm.rewards.reward_fn import search_reward_fn
from rllm.trainer.agent_trainer import AgentTrainer
from .local_retrieval_tool import LocalRetrievalTool
@hydra.main(config_path="pkg://rllm.trainer.config", config_name="agent_ppo_trainer", version_base=None)
def main(config):
train_dataset = DatasetRegistry.load_dataset("hotpotqa", "train")
val_dataset = DatasetRegistry.load_dataset("hotpotqa", "test")
tool_map = {"local_search": LocalRetrievalTool}
env_args = {
"max_steps": 20,
"tool_map": tool_map,
"reward_fn": search_reward_fn,
}
agent_args = {"system_prompt": SEARCH_SYSTEM_PROMPT, "tool_map": tool_map, "parser_name": "qwen"}
# Use the registry-based approach (comment out the other approach)
trainer = AgentTrainer(
agent_class=ToolAgent,
env_class=ToolEnvironment,
config=config,
train_dataset=train_dataset,
val_dataset=val_dataset,
agent_args=agent_args,
env_args=env_args,
)
trainer.train()
if __name__ == "__main__":
main()
For detailed setup instructions, see the README in the search example directory.