Skip to content

Search Agent Example

This example follows the setup from Search-R1 to train an agent that can perform interleaved reasoning and search to answer multi-hop QA questions.

Overview

The search examples demonstrate:

- How to use rLLM's ToolAgent and ToolEnvironment
- How to write custom tools in rLLM

Quick Start

Setup Search Data

First, prepare your search data:

cd examples/search
python prepare_search_data.py

Run Search Agent

Execute the search agent:

python run_search_agent.py

Train Search Agent

Train your own search agent:

bash train_search_agent.sh

Code Reference

Search Agent Runner

Main script for running the search agent on the HotpotQA test split:

examples/search/run_search_agent.py
import asyncio
import os

from dotenv import load_dotenv
from local_retrieval_tool import LocalRetrievalTool
from transformers import AutoTokenizer

from rllm.agents.system_prompts import SEARCH_SYSTEM_PROMPT
from rllm.agents.tool_agent import ToolAgent
from rllm.data.dataset import DatasetRegistry
from rllm.engine.agent_execution_engine import AgentExecutionEngine
from rllm.environments.tools.tool_env import ToolEnvironment
from rllm.rewards.reward_fn import search_reward_fn
from rllm.utils import save_trajectories


def load_search_data(train_size=3000, test_size=100):
    """Return the data of the HotpotQA test split used for evaluation.

    The split is looked up in ``DatasetRegistry`` first. When it is not
    registered yet, it is built via ``prepare_hotpotqa_data`` and the
    freshly prepared test split is used instead.

    Args:
        train_size: Number of training examples to prepare; only used when
            the dataset has to be built.
        test_size: Number of test examples to prepare; only used when the
            dataset has to be built.

    Returns:
        The result of ``get_data()`` on the test split.
    """
    dataset = DatasetRegistry.load_dataset("hotpotqa", "test")
    if dataset is not None:
        return dataset.get_data()

    print("Dataset not found, preparing search dataset...")
    # Imported only on this path, where the dataset actually has to be built.
    from prepare_hotpotqa_data import prepare_hotpotqa_data

    _train_split, dataset = prepare_hotpotqa_data(train_size=train_size, test_size=test_size)
    return dataset.get_data()


if __name__ == "__main__":
    # Load .env BEFORE applying defaults: load_dotenv() does not override
    # variables that are already set, so applying the defaults first would
    # silently ignore a RETRIEVAL_SERVER_URL configured in .env.
    load_dotenv()

    # Forced on unconditionally (neither the shell nor .env can change it).
    os.environ["TOKENIZERS_PARALLELISM"] = "true"
    # Fall back to a local retrieval server only if nothing configured one.
    os.environ.setdefault("RETRIEVAL_SERVER_URL", "http://127.0.0.1:8000")

    n_parallel_agents = 64

    model_name = "Qwen/Qwen3-4B"

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    sampling_params = {"temperature": 0.6, "top_p": 0.95, "model": model_name}

    # The same tool map is handed to both the agent and the environment.
    tool_map = {"local_search": LocalRetrievalTool}

    engine = AgentExecutionEngine(
        agent_class=ToolAgent,
        agent_args={"tool_map": tool_map, "system_prompt": SEARCH_SYSTEM_PROMPT, "parser_name": "qwen"},
        env_class=ToolEnvironment,
        env_args={"tool_map": tool_map, "reward_fn": search_reward_fn},
        rollout_engine=None,
        engine_name="openai",
        tokenizer=tokenizer,
        sampling_params=sampling_params,
        # NOTE(review): assumes an OpenAI-compatible server is already serving
        # `model_name` at this URL; "None" is a placeholder API key string.
        rollout_engine_args={
            "base_url": "http://localhost:30000/v1",
            "api_key": "None",
        },
        max_response_length=16384,
        max_prompt_length=4096,
        config=None,
        n_parallel_agents=n_parallel_agents,
    )

    tasks = load_search_data()

    results = asyncio.run(engine.execute_tasks(tasks))

    save_trajectories(results, filename="search_trajectories.pt")

Training Script

Training script for the search agent:

examples/search/train_search_agent.py
import hydra

from rllm.agents.system_prompts import SEARCH_SYSTEM_PROMPT
from rllm.agents.tool_agent import ToolAgent
from rllm.data import DatasetRegistry
from rllm.environments.tools.tool_env import ToolEnvironment
from rllm.rewards.reward_fn import search_reward_fn
from rllm.trainer.agent_trainer import AgentTrainer

from .local_retrieval_tool import LocalRetrievalTool


@hydra.main(config_path="pkg://rllm.trainer.config", config_name="agent_ppo_trainer", version_base=None)
def main(config):
    """Train a ToolAgent on HotpotQA using the local retrieval tool.

    Args:
        config: Hydra config composed from ``agent_ppo_trainer`` in the
            ``rllm.trainer.config`` package, plus any command-line overrides.

    Raises:
        RuntimeError: If the HotpotQA train split is not registered in
            ``DatasetRegistry`` (the search data has not been prepared yet).
    """
    train_dataset = DatasetRegistry.load_dataset("hotpotqa", "train")
    val_dataset = DatasetRegistry.load_dataset("hotpotqa", "test")
    if train_dataset is None:
        # load_dataset() returns None for an unregistered split; fail fast with
        # an actionable message instead of passing None into the trainer.
        raise RuntimeError(
            "HotpotQA train split not found in DatasetRegistry; "
            "run the search data preparation script in examples/search first."
        )

    # One tool map shared by the agent and the environment.
    tool_map = {"local_search": LocalRetrievalTool}

    env_args = {
        "max_steps": 20,
        "tool_map": tool_map,
        "reward_fn": search_reward_fn,
    }

    agent_args = {"system_prompt": SEARCH_SYSTEM_PROMPT, "tool_map": tool_map, "parser_name": "qwen"}

    trainer = AgentTrainer(
        agent_class=ToolAgent,
        env_class=ToolEnvironment,
        config=config,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        agent_args=agent_args,
        env_args=env_args,
    )

    trainer.train()


if __name__ == "__main__":
    # Hydra builds `config` from the config files and CLI overrides.
    main()

For detailed setup instructions, see the README in the search example directory.