Skip to content

Solver-Judge Workflow Example

This example demonstrates a solver-judge workflow using rLLM's AgentWorkflowEngine. The workflow generates multiple candidate solutions to countdown problems and uses a judge to select the best one, showcasing multi-agent coordination.

Overview

The solver-judge workflow demonstrates:

  • How to implement custom workflows extending the Workflow base class
  • Multi-agent coordination between solver and judge agents
  • Parallel solution generation and quality assessment
  • Integration with rLLM's workflow engine

Quick Start

Setup Countdown Data

First, prepare the countdown dataset:

cd examples/solver_judge
python prepare_countdown_data.py

Run Solver-Judge Workflow

Execute the workflow on countdown problems:

python run_solver_judge_flow.py

Train with Solver-Judge Workflow

Train an agent using the solver-judge workflow:

bash train_solver_judge_flow.sh

Code Reference

Solver-Judge Workflow Implementation

The core workflow that coordinates solver and judge agents:

examples/solver_judge/solver_judge_flow.py
import asyncio
import re

from rllm.agents.agent import Episode, Step, Trajectory
from rllm.engine import ModelOutput, RolloutEngine
from rllm.rewards.reward_fn import RewardFunction
from rllm.workflows.workflow import Workflow


class Solver:
    """Produces candidate solutions for a problem by querying a rollout engine."""

    def __init__(self, rollout_engine: RolloutEngine, **kwargs):
        """Store the engine used for generation; extra keyword arguments are accepted and ignored."""
        self.rollout_engine = rollout_engine

    async def generate_solution(self, problem: str) -> Trajectory:
        """Ask the model for one solution and wrap the exchange in a single-step trajectory.

        The step's ``action`` is the parsed ``<answer>...</answer>`` string (or the
        "No solution found" placeholder), ``thought`` is the model's reasoning, and
        ``chat_completions`` holds the user prompt followed by the assistant reply.
        """
        messages = [{"role": "user", "content": f"{problem}. Output the final answer within <answer>...</answer>"}]
        output: ModelOutput = await self.rollout_engine.get_model_response(messages)
        return Trajectory(
            name="solver",
            steps=[
                Step(
                    chat_completions=messages + [{"role": "assistant", "content": output.content, "reasoning": output.reasoning}],
                    thought=output.reasoning,
                    action=self._parse_solver_response(output.content),
                    model_output=output,
                )
            ],
        )

    async def generate_solutions(self, problem: str, n_solutions: int = 2) -> list[Trajectory]:
        """Generate ``n_solutions`` independent solutions concurrently, returned in task order."""
        tasks = [asyncio.create_task(self.generate_solution(problem)) for _ in range(n_solutions)]
        return await asyncio.gather(*tasks)

    def _parse_solver_response(self, response: str) -> str:
        """Extract the first ``<answer>...</answer>`` span from ``response``.

        Returns the answer re-wrapped in lowercase ``<answer>`` tags with surrounding
        whitespace stripped, or ``"No solution found"`` when no tagged answer exists.
        """
        # A missing or empty response (``None`` or ``""``) cannot contain an answer;
        # short-circuit so ``re.search`` is never handed a non-string.
        if not response:
            return "No solution found"

        answer_match = re.search(r"<answer>(.*?)</answer>", response, re.IGNORECASE | re.DOTALL)
        if answer_match is None:
            return "No solution found"
        return f"<answer>{answer_match.group(1).strip()}</answer>"


class Judge:
    """Selects one of several candidate solutions by querying a rollout engine."""

    def __init__(self, rollout_engine: RolloutEngine, **kwargs):
        """Store the engine used for judging; extra keyword arguments are accepted and ignored."""
        self.rollout_engine = rollout_engine

    async def judge_solutions(self, problem: str, solutions: list[str]) -> Trajectory:
        """Ask the model to pick a solution and wrap the exchange in a single-step trajectory.

        The step's ``action`` is the selected solution string, or ``""`` when the
        judge's reply does not name a valid solution index.
        """
        messages = [{"role": "user", "content": self._create_judge_prompt(problem, solutions)}]
        output: ModelOutput = await self.rollout_engine.get_model_response(messages)
        return Trajectory(
            name="judge",
            steps=[
                Step(
                    chat_completions=messages + [{"role": "assistant", "content": output.content, "reasoning": output.reasoning}],
                    thought=output.reasoning,
                    action=self._parse_judge_response(output.content, solutions),
                    model_output=output,
                )
            ],
        )

    def _parse_judge_response(self, response: str, solutions: list[str]) -> str:
        """Map the judge's ``<answer>N</answer>`` reply to the N-th solution (1-based).

        Returns ``""`` when the response is missing, has no tagged answer, the answer
        is not an integer, or the index falls outside ``1..len(solutions)``.
        """
        if not response:
            return ""

        answer_match = re.search(r"<answer>(.*?)</answer>", response, re.IGNORECASE | re.DOTALL)
        if answer_match is None:
            return ""

        try:
            solution_index = int(answer_match.group(1).strip())
        except ValueError:
            return ""

        # Explicit range check: relying on IndexError alone would let 0 or negative
        # indices wrap around via Python's negative indexing and silently pick the
        # wrong solution.
        if 1 <= solution_index <= len(solutions):
            return solutions[solution_index - 1]
        return ""

    def _create_judge_prompt(self, problem: str, solutions: list[str]) -> str:
        """Create a prompt for the judge to evaluate solutions.

        Solutions are listed with 1-based indices, matching what
        ``_parse_judge_response`` expects back.
        """
        header = f"""You are an expert verifier. Given a countdown problem and multiple solution attempts, select a correct solution.
Problem:
{problem}
Solutions to evaluate:
"""
        listing = "".join(f"\nSolution {i}:\n{solution}\n" for i, solution in enumerate(solutions, 1))
        footer = """
A correct solution must satisfy the following criteria:
1. The solution uses only the given numbers.
2. Each number is used exactly once.
3. Only basic arithmetic operations (+, -, *, /) are used.
4. The calculation results in the target number.
5. The final answer is clearly marked within <answer>...</answer> tags.
Output the index of your selected solution within <answer>...</answer> tags, e.g., <answer>1</answer> for the first solution, <answer>2</answer> for the second solution, etc. If multiple solutions are correct, output the index of the first correct solution."""
        return header + listing + footer


class SolverJudgeWorkflow(Workflow):
    """Workflow in which a solver proposes several answers and a judge selects one.

    Each run yields an ``Episode`` containing one trajectory per solver attempt plus
    one judge trajectory. Every trajectory's single step carries the reward returned
    by ``reward_function`` for that step's action.
    """

    def __init__(self, rollout_engine: RolloutEngine, n_solutions: int = 2, reward_function: RewardFunction = None, **kwargs):
        """Set up the solver and judge on a shared rollout engine.

        Args:
            rollout_engine: Engine used for both solver and judge generations.
            n_solutions: Number of candidate solutions the solver generates per task.
            reward_function: Callable invoked as ``reward_function(task, action)``; its
                result must expose ``reward`` and ``is_correct``. It defaults to ``None``
                but is called unconditionally in ``run``, so it must be supplied.
            **kwargs: Forwarded to the ``Workflow`` base class.
        """
        super().__init__(rollout_engine, **kwargs)
        self.n_solutions = n_solutions
        self.reward_function = reward_function
        self.solver = Solver(rollout_engine)
        self.judge = Judge(rollout_engine)

    async def run(self, task: dict, uid: str, **kwargs) -> Episode:
        """Solve ``task["question"]`` with the solver, then let the judge pick an answer.

        Returns an ``Episode`` whose ``is_correct`` reflects the judge-selected solution
        and whose metrics report the mean solver reward (``solver_acc``) and the judge's
        correctness as 0/1 (``judge_acc``).
        """
        self.reset(task, uid)
        problem = task["question"]

        # Step 1: Solver generates multiple solutions in parallel
        solver_trajectories = await self.solver.generate_solutions(problem, self.n_solutions)

        # Assign rewards to solver trajectories
        solutions = []
        for traj in solver_trajectories:
            solution = traj.steps[0].action
            solutions.append(solution)
            traj.steps[0].reward = self.reward_function(task, solution).reward

        # Step 2: Judge selects the best solution
        judge_trajectory = await self.judge.judge_solutions(problem, solutions)
        selected_solution = judge_trajectory.steps[0].action

        # Evaluate the selected solution
        reward_result = self.reward_function(task, selected_solution)
        judge_trajectory.steps[0].reward = reward_result.reward
        is_correct = reward_result.is_correct

        # Compute metrics. Guard the mean so that zero solver trajectories
        # (n_solutions == 0) reports 0.0 instead of raising ZeroDivisionError.
        if solver_trajectories:
            solver_acc = sum(traj.steps[0].reward for traj in solver_trajectories) / len(solver_trajectories)
        else:
            solver_acc = 0.0
        judge_acc = int(is_correct)

        # Step 3: Return episode with multiple trajectories
        return Episode(
            id=uid,
            task=task,
            trajectories=[*solver_trajectories, judge_trajectory],
            is_correct=is_correct,
            metrics={"solver_acc": solver_acc, "judge_acc": judge_acc},
        )

Workflow Runner

Main script for running the solver-judge workflow:

examples/solver_judge/run_solver_judge_flow.py
import asyncio
import json
import os

# Import countdown-specific modules
import sys
from copy import deepcopy

from solver_judge_flow import SolverJudgeWorkflow
from transformers import AutoTokenizer

from rllm.data.dataset import DatasetRegistry
from rllm.engine import OpenAIEngine
from rllm.engine.agent_workflow_engine import AgentWorkflowEngine
from rllm.rewards.countdown_reward import countdown_reward_fn

# Add the sibling ``countdown`` example directory to the module search path.
# This runs after the module-level imports above, so it only affects imports
# performed later at runtime (presumably the lazy ``prepare_countdown_data``
# import inside ``load_data`` -- verify where that module actually lives).
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "countdown"))


def load_data(n=1):
    """Return countdown test tasks, with each task repeated ``n`` times.

    The "countdown"/"test" split is read from ``DatasetRegistry``. If it is not
    registered, the dataset is prepared on the fly via ``prepare_countdown_data``.
    Every repetition is an independent deep copy of the processed task dict.
    """
    dataset = DatasetRegistry.load_dataset("countdown", "test")
    if dataset is None:
        print("Dataset not found, preparing dataset...")
        from prepare_countdown_data import prepare_countdown_data

        # Only the second element of the returned 4-tuple is used as the dataset here.
        _, dataset, _, _ = prepare_countdown_data()

    tasks = []
    for position, example in enumerate(dataset):
        task = process_countdown_fn(example, position)
        tasks.extend(deepcopy(task) for _ in range(n))
    return tasks


def process_countdown_fn(example, idx):
    """Build a workflow task dict from one raw countdown example.

    Reads ``question``, ``target`` and ``nums`` from ``example`` and returns a dict
    that carries them both at the top level and inside ``ground_truth`` (where the
    numbers are stored under the key ``numbers``), together with ``idx`` and a
    fixed ``data_source`` of ``"countdown"``.
    """
    prompt_text = example["question"]
    goal = example["target"]
    operands = example["nums"]

    return {
        "question": prompt_text,
        # Ground truth in the layout expected by countdown_reward_fn
        "ground_truth": {"target": goal, "numbers": operands},
        "idx": idx,
        "data_source": "countdown",
        "target": goal,
        "nums": operands,
    }


def evaluate_results(results):
    """Print accuracy statistics for a collection of finished episodes.

    Episodes are grouped by ``episode.task["question"]``. Three lines are printed:
    the number of distinct problems, "Pass@1" (correct episodes divided by all
    episodes), and "Pass@k" (fraction of problems with at least one correct
    episode), where k is the largest number of episodes seen for any one problem.
    """
    from collections import defaultdict

    solved_by_problem = defaultdict(int)
    attempts_by_problem = defaultdict(int)

    # Tally attempts and successes per problem, using the episode-level
    # is_correct flag set by the workflow.
    for episode in results:
        key = episode.task["question"]
        solved_by_problem[key] += int(episode.is_correct)
        attempts_by_problem[key] += 1

    n_problems = len(solved_by_problem)
    k = max(attempts_by_problem.values(), default=1)

    if n_problems == 0:
        pass_at_1 = 0.0
        pass_at_k = 0.0
    else:
        total_correct = sum(solved_by_problem.values())
        total_attempts = sum(attempts_by_problem.values())
        pass_at_1 = total_correct / total_attempts
        problems_solved = sum(1 for hits in solved_by_problem.values() if hits > 0)
        pass_at_k = problems_solved / n_problems

    print("Total unique problems:", n_problems)
    print("Average Pass@1 Accuracy:", pass_at_1)
    print(f"Average Pass@{k} Accuracy:", pass_at_k)


if __name__ == "__main__":
    # Script entry point: run the solver-judge workflow over the countdown test
    # tasks against a locally served model, print metrics, and dump all episodes
    # to a JSON file under ``logs/``.
    os.environ["TOKENIZERS_PARALLELISM"] = "true"

    # Configuration
    model_name = "Qwen/Qwen3-0.6B"
    n_solutions = 2  # Number of solutions to generate per problem
    n_parallel_tasks = 128

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Engine that sends generation requests to the endpoint at ``base_url``.
    sampling_params = {"temperature": 0.6, "top_p": 0.95}
    rollout_engine = OpenAIEngine(
        model=model_name,
        tokenizer=tokenizer,
        max_prompt_length=2048,
        max_response_length=1024,
        base_url="http://localhost:30000/v1",
        api_key="None",
        sampling_params=sampling_params,
    )

    # Arguments passed through to SolverJudgeWorkflow.
    workflow_args = {
        "n_solutions": n_solutions,
        "reward_function": countdown_reward_fn,
    }
    engine = AgentWorkflowEngine(
        workflow_cls=SolverJudgeWorkflow,
        workflow_args=workflow_args,
        rollout_engine=rollout_engine,
        config=None,
        n_parallel_tasks=n_parallel_tasks,
        retry_limit=1,
    )

    # Load countdown tasks
    tasks = load_data(n=1)
    print(f"Loaded {len(tasks)} countdown tasks")

    episodes = asyncio.run(engine.execute_tasks(tasks))

    # Evaluate results (rewards are already assigned in the workflow)
    print("Evaluating results...")
    evaluate_results(episodes)

    # Save results
    os.makedirs("logs", exist_ok=True)
    with open("logs/solver_judge_countdown.json", "w") as out_file:
        json.dump([episode.to_dict() for episode in episodes], out_file, indent=4)

    print("\nResults saved to logs/solver_judge_countdown.json")

Training Script

Training configuration using the solver-judge workflow:

examples/solver_judge/train_solver_judge_flow.py
import hydra

from examples.solver_judge.solver_judge_flow import SolverJudgeWorkflow
from rllm.data.dataset import DatasetRegistry
from rllm.rewards.countdown_reward import countdown_reward_fn
from rllm.trainer.agent_trainer import AgentTrainer


@hydra.main(config_path="pkg://rllm.trainer.config", config_name="agent_ppo_trainer", version_base=None)
def main(config):
    """Train an agent with the solver-judge workflow on the countdown dataset.

    The "train" split is used for training and the "test" split for validation.
    ``config`` is supplied by Hydra from the ``agent_ppo_trainer`` configuration.
    """
    train_split = DatasetRegistry.load_dataset("countdown", "train")
    eval_split = DatasetRegistry.load_dataset("countdown", "test")

    # Arguments passed through to SolverJudgeWorkflow.
    workflow_args = {
        "n_solutions": 2,
        "reward_function": countdown_reward_fn,
    }

    AgentTrainer(
        workflow_class=SolverJudgeWorkflow,
        workflow_args=workflow_args,
        config=config,
        train_dataset=train_split,
        val_dataset=eval_split,
    ).train()


# Start training only when executed as a script. ``main`` is called without
# arguments because the ``hydra.main`` decorator supplies ``config``.
if __name__ == "__main__":
    main()

How It Works

  1. Solver Phase: Generate multiple candidate solutions in parallel
  2. Judge Phase: Evaluate solutions and select the best one
  3. Episode Completion: Score the judge's selected solution to determine overall correctness, and record solver and judge accuracy as episode metrics

The workflow uses the countdown dataset where agents must use given numbers and basic arithmetic operations to reach a target number.