RL Training with Tinker

This example shows how to train a solver‑judge RL workflow with the Tinker backend in rLLM, using Tinker's hosted GPU service.

Overview

With this example you will:

  1. Train a solver‑judge workflow for the Countdown task using the Tinker backend
  2. Reuse the unified Tinker RL config (tinker_rl_trainer.yaml) for workflow training

Under the hood, rLLM uses Tinker in three roles:

  • Rollout backend: sampling and logprob computation happen on Tinker's GPU service
  • Policy trainer: LoRA adapters are optimized remotely via Tinker training clients
  • Checkpoint manager: checkpoints are stored and resumed via Tinker model IDs

Setup

Install dependencies

uv pip install -e .[tinker] --torch-backend=cpu

Configure Tinker authentication

Set your Tinker API key:

export TINKER_API_KEY=your_api_key_here

You can obtain an API key from the Tinker console.
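
Before launching a long run, it can help to confirm that the key is actually visible to Python. The following is a minimal sketch; the helper name tinker_key_status is hypothetical and not part of rLLM or Tinker, and it only inspects the environment variable without contacting the service.

```python
import os


def tinker_key_status(env=None):
    """Return 'missing', 'placeholder', or 'set' for TINKER_API_KEY."""
    env = os.environ if env is None else env
    key = env.get("TINKER_API_KEY", "").strip()
    if not key:
        return "missing"
    if key == "your_api_key_here":
        # The placeholder from the export line was never replaced.
        return "placeholder"
    return "set"


print("TINKER_API_KEY:", tinker_key_status())
```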

Shared Tinker RL config

This example uses the unified RL config in rLLM:

rllm/trainer/config/tinker_rl_trainer.yaml
# Tinker Backend Configuration for rLLM
# This config is used when training agents with Tinker backend
# Unified configuration supporting both workflow and agent training modes
# Default settings match tinker_cookbook.recipes.math_rl for MATH dataset

# Tinker-specific settings
tinker_base_url: null  # Tinker service URL (null for local)

# Model Configuration
model:
  name: "Qwen/Qwen3-8B"  # Default model for MATH dataset
  lora_rank: 32
  train_unembed: true  # Train LoRA on output embedding layer (set to false for Fireworks compatibility)
  train_attn: true     # Train LoRA on attention layers
  train_mlp: true      # Train LoRA on MLP layers

# Training Configuration
training:
  group_size: 16  # Number of rollouts per prompt (for GRPO)
  learning_rate: 2e-5  # 2e-5 for MATH dataset
  beta1: 0.9
  beta2: 0.95
  eps: 1e-8
  max_length: 32768
  num_minibatches: 1

# Sampling Configuration
sampling:
  temperature: 0.6
  top_p: 0.95

# Algorithm Configuration (compatible with verl)
algorithm:
  adv_estimator: grpo  # REINFORCE, GRPO
  gamma: 1.0
  lam: 0.95
  norm_adv_by_std_in_grpo: false  # math_rl doesn't normalize by std
  grouping_level: 'trajectory'  # Options: 'trajectory' or 'step'

# Workflow Configuration (for workflow training mode)
workflow:
  n_parallel_tasks: 256
  retry_limit: 3 

# Agent Configuration (for agent training mode)
agent:
  max_steps: 1  # Single-turn vs multi-turn
  agent_args: {}

# Environment Configuration (for agent training mode)
env:
  env_args: {}

# Data Configuration
data:
  train_files: null
  val_files: null
  max_prompt_length: 2048
  max_response_length: 2048
  train_batch_size: 64
  val_batch_size: 32

# Trainer Configuration
trainer:
  total_epochs: 10
  logger: ['console']  # Options: 'console', 'wandb', 'tensorboard'
  project_name: 'rllm-tinker'
  experiment_name: 'default'
  test_freq: 5
  save_freq: 20
  val_before_train: true
  default_local_dir: '/tmp/rllm-tinker-checkpoints'
  resume_from_tinker_id: null  # Tinker model ID path to resume from (e.g., tinker://uuid/weights/000060)

# Hydra configuration
hydra:
  run:
    dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}

Key options you may want to tune:

  • model.name: base model to fine‑tune (e.g. Qwen/Qwen3-8B, Qwen/Qwen3-30B-A3B)
  • model.lora_rank: LoRA rank
  • training.group_size: number of trajectories per prompt (GRPO group size)
  • data.max_prompt_length / data.max_response_length: context and generation lengths
  • trainer.total_epochs, trainer.logger, trainer.project_name, trainer.experiment_name

You can override any of these from the command line using Hydra syntax (see below).
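
To make the override semantics concrete, here is a rough sketch of how dotted key=value overrides map onto the nested config. It is not Hydra's parser (Hydra has its own override grammar, for example for booleans and nulls); apply_overrides is a hypothetical helper written only for illustration.

```python
import ast
import copy


def apply_overrides(config, overrides):
    """Return a copy of `config` with dotted `key=value` overrides applied."""
    result = copy.deepcopy(config)
    for item in overrides:
        dotted_key, _, raw_value = item.partition("=")
        try:
            # Numbers and lists are parsed as Python literals.
            value = ast.literal_eval(raw_value)
        except (ValueError, SyntaxError):
            # Anything else (e.g. Qwen/Qwen3-8B) is kept as a plain string.
            value = raw_value
        node = result
        *parents, leaf = dotted_key.split(".")
        for name in parents:
            node = node.setdefault(name, {})
        node[leaf] = value
    return result


defaults = {
    "model": {"name": "Qwen/Qwen3-8B", "lora_rank": 32},
    "training": {"group_size": 16},
}
merged = apply_overrides(defaults, ["model.lora_rank=16", "training.group_size=8"])
print(merged)
```

The defaults dictionary is left untouched, which mirrors how the YAML file stays unchanged while each run sees its own merged config.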


Solver‑Judge RL Training with Tinker

This example trains a multi‑agent solver‑judge workflow on the Countdown task using the same Tinker RL backend.

2.1 Prepare Countdown dataset

First download and register the Countdown dataset:

cd examples/countdown
python prepare_countdown_data.py

This will:

  • Load Jiayi-Pan/Countdown-Tasks-3to4 from HuggingFace
  • Convert each example into a math‑style word problem
  • Register multiple splits (train, test, stage2_train, stage3_train) under the countdown key

Dataset preparation:

examples/countdown/prepare_countdown_data.py
import random

from datasets import load_dataset

from rllm.data.dataset import DatasetRegistry


def prepare_countdown_data():
    """
    Prepare the countdown task dataset from HuggingFace.
    Take 1024 examples as test set, remaining as training set.
    Also create stage 2 and stage 3 training sets with 50k examples each.
    """
    # Load the countdown dataset
    dataset = load_dataset("Jiayi-Pan/Countdown-Tasks-3to4", split="train")

    # Split dataset: 1024 examples for test, rest for training
    test_size = 1024
    total_size = len(dataset)

    # Create train/test split
    test_dataset = dataset.select(range(test_size))
    train_dataset = dataset.select(range(test_size, total_size))

    def preprocess_fn(example, idx):
        """
        Convert countdown task format to math problem format.
        Example: target=98, nums=[44, 19, 35] becomes a math word problem.
        """
        target = example["target"]
        nums = example["nums"]

        # Format as a math problem
        nums_str = ", ".join(map(str, nums))
        question = f"Using the numbers {nums_str}, find a way to reach the target number {target}. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your step-by-step calculation and output the final answer within <answer>...</answer>, for example <answer> (1 + 2) / 3 </answer>."

        return {
            "question": question,
            "ground_truth": str(target),
            "data_source": "countdown",
            "target": target,
            "nums": nums,
        }

    # Apply preprocessing
    train_dataset = train_dataset.map(preprocess_fn, with_indices=True)
    test_dataset = test_dataset.map(preprocess_fn, with_indices=True)

    # Create stage 2 and stage 3 training datasets
    train_size = len(train_dataset)
    stage_size = 50000

    # Ensure we have enough data for both stages
    if train_size < 2 * stage_size:
        print(f"Warning: Training set has only {train_size} examples, but need {2 * stage_size} for both stages")
        stage_size = min(stage_size, train_size // 2)

    # Shuffle and select indices for stage 2 and stage 3
    all_indices = list(range(train_size))
    random.shuffle(all_indices)

    stage2_indices = all_indices[:stage_size]
    stage3_indices = all_indices[stage_size : 2 * stage_size]

    # Create stage datasets
    stage2_dataset = train_dataset.select(stage2_indices)
    stage3_dataset = train_dataset.select(stage3_indices)

    # Register datasets
    train_dataset = DatasetRegistry.register_dataset("countdown", train_dataset, "train")
    test_dataset = DatasetRegistry.register_dataset("countdown", test_dataset, "test")
    stage2_dataset = DatasetRegistry.register_dataset("countdown", stage2_dataset, "stage2_train")
    stage3_dataset = DatasetRegistry.register_dataset("countdown", stage3_dataset, "stage3_train")

    print(f"Train dataset size: {len(train_dataset)}")
    print(f"Test dataset size: {len(test_dataset)}")
    print(f"Stage 2 train dataset size: {len(stage2_dataset)}")
    print(f"Stage 3 train dataset size: {len(stage3_dataset)}")

    return train_dataset, test_dataset, stage2_dataset, stage3_dataset


if __name__ == "__main__":
    train_dataset, test_dataset, stage2_dataset, stage3_dataset = prepare_countdown_data()
    print("Train dataset path:", train_dataset.get_data_path())
    print("Test dataset path:", test_dataset.get_data_path())
    print("Stage 2 train dataset path:", stage2_dataset.get_data_path())
    print("Stage 3 train dataset path:", stage3_dataset.get_data_path())

    # Print a sample
    print("\nSample train example:")
    print(train_dataset[0])
    print("\nSample stage 2 train example:")
    print(stage2_dataset[0])
    print("\nSample stage 3 train example:")
    print(stage3_dataset[0])
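
The stage 2 / stage 3 selection in the script can be tried in isolation on plain indices. The sketch below mirrors the same shrink-shuffle-slice logic; the function name split_stage_indices and the seed argument are assumptions added here for reproducibility (the script above shuffles with the global, unseeded random module).

```python
import random


def split_stage_indices(train_size, stage_size=50000, seed=0):
    """Pick two disjoint index sets of equal size from range(train_size)."""
    # Shrink the stage size when there is not enough data for both stages.
    if train_size < 2 * stage_size:
        stage_size = min(stage_size, train_size // 2)

    indices = list(range(train_size))
    # A dedicated, seeded generator keeps the split reproducible.
    random.Random(seed).shuffle(indices)

    stage2 = indices[:stage_size]
    stage3 = indices[stage_size : 2 * stage_size]
    return stage2, stage3


stage2, stage3 = split_stage_indices(train_size=10, stage_size=8)
print(len(stage2), len(stage3))
```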

2.2 Solver‑Judge workflow with Tinker backend

The Tinker RL training entrypoint for the solver‑judge workflow is:

examples/solver_judge_tinker/train_solver_judge_flow_tinker.py
import hydra

from examples.solver_judge.solver_judge_flow import SolverJudgeWorkflow
from rllm.data.dataset import DatasetRegistry
from rllm.rewards.countdown_reward import countdown_reward_fn
from rllm.trainer import AgentTrainer


@hydra.main(config_path="pkg://rllm.trainer.config", config_name="tinker_rl_trainer", version_base=None)
def main(config):
    train_dataset = DatasetRegistry.load_dataset("countdown", "train")
    test_dataset = DatasetRegistry.load_dataset("countdown", "test")

    trainer = AgentTrainer(
        workflow_class=SolverJudgeWorkflow,
        workflow_args={
            "n_solutions": 2,
            "reward_function": countdown_reward_fn,
        },
        config=config,
        train_dataset=train_dataset,
        val_dataset=test_dataset,
        backend="tinker",
    )
    trainer.train()


if __name__ == "__main__":
    main()

It uses:

  • SolverJudgeWorkflow from examples.solver_judge.solver_judge_flow
  • countdown_reward_fn as the reward function
  • AgentTrainer with backend="tinker" and workflow_class=SolverJudgeWorkflow
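
To give a feel for what a Countdown reward has to check, here is a hypothetical, simplified scorer. It is not the implementation of rllm.rewards.countdown_reward.countdown_reward_fn, whose signature and scoring rules are not shown in this page; the name toy_countdown_reward, the 0.0/1.0 scores, and the "each number at most once" rule (taken from the prompt text in the dataset script) are assumptions.

```python
import ast
import operator
import re
from collections import Counter

_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
}


def _evaluate(node, used):
    """Evaluate a parsed arithmetic expression, recording every number used."""
    if (
        isinstance(node, ast.Constant)
        and isinstance(node.value, (int, float))
        and not isinstance(node.value, bool)
    ):
        used.append(node.value)
        return node.value
    if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
        left = _evaluate(node.left, used)
        right = _evaluate(node.right, used)
        return _OPS[type(node.op)](left, right)
    raise ValueError("unsupported expression")


def toy_countdown_reward(response, target, nums):
    """Return 1.0 if the <answer> expression reaches `target` using `nums`."""
    match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
    if match is None:
        return 0.0
    used = []
    try:
        tree = ast.parse(match.group(1).strip(), mode="eval")
        value = _evaluate(tree.body, used)
    except (SyntaxError, ValueError, ZeroDivisionError):
        return 0.0
    # Reject numbers that are not available or are used more often than given.
    if Counter(used) - Counter(nums):
        return 0.0
    return 1.0 if abs(value - target) < 1e-6 else 0.0


print(toy_countdown_reward("So <answer> 44 + 19 + 35 </answer>", 98, [44, 19, 35]))
```

Parsing with ast instead of calling eval keeps the sketch restricted to the four arithmetic operators named in the prompt.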

2.3 Train solver‑judge workflow with Tinker

Run the provided shell script:

cd examples/solver_judge_tinker
bash train_solver_judge_flow_tinker.sh

This will:

  • Fine‑tune Qwen/Qwen3-4B-Instruct-2507 with LoRA (rank 32)
  • Train with GRPO using trajectory‑level grouping (algorithm.grouping_level=trajectory)
  • Use normalized advantages for stability (algorithm.norm_adv_by_std_in_grpo=true)
  • Log training metrics to Weights & Biases
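
The effect of algorithm.norm_adv_by_std_in_grpo can be illustrated on a single group of rewards. This is a simplified sketch of group-relative advantages, not rLLM's actual estimator: the function name, the use of the population standard deviation, and the eps value are assumptions.

```python
from statistics import fmean, pstdev


def group_advantages(rewards, normalize_by_std, eps=1e-6):
    """Center rewards within one group; optionally divide by the group std."""
    mean = fmean(rewards)
    centered = [r - mean for r in rewards]
    if not normalize_by_std:
        return centered
    std = pstdev(rewards)  # population std of the group (an assumption)
    return [c / (std + eps) for c in centered]


# One prompt with training.group_size=4 rollouts and binary rewards.
rewards = [1.0, 0.0, 0.0, 1.0]
print(group_advantages(rewards, normalize_by_std=False))
print(group_advantages(rewards, normalize_by_std=True))
```

Without normalization the advantages are just the centered rewards; with it, groups whose rewards differ only slightly are scaled up to a comparable magnitude. A group in which every rollout receives the same reward yields all-zero advantages in both modes.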

Shell configuration:

examples/solver_judge_tinker/train_solver_judge_flow_tinker.sh
set -x

MODEL_PATH=Qwen/Qwen3-4B-Instruct-2507

python -m examples.solver_judge_tinker.train_solver_judge_flow_tinker \
    model.name=$MODEL_PATH \
    model.lora_rank=32 \
    training.group_size=4 \
    training.learning_rate=4e-5 \
    sampling.temperature=1.0 \
    sampling.top_p=1.0 \
    algorithm.adv_estimator=grpo \
    algorithm.norm_adv_by_std_in_grpo=true \
    algorithm.grouping_level=trajectory \
    data.max_prompt_length=2048 \
    data.max_response_length=1024 \
    data.train_batch_size=64 \
    data.val_batch_size=512 \
    trainer.total_epochs=100 \
    trainer.logger=['wandb'] \
    trainer.project_name='solver-judge-workflow' \
    trainer.experiment_name='countdown-solver-judge-tinker-norm-by-std' \
    trainer.val_before_train=False \
    trainer.test_freq=10 \
    trainer.save_freq=20 \
    trainer.default_local_dir='/tmp/countdown-solver-judge-tinker-norm-by-std'

You can customize training via CLI overrides, e.g.:

cd examples/solver_judge_tinker
python -m examples.solver_judge_tinker.train_solver_judge_flow_tinker \
    model.name=Qwen/Qwen3-8B \
    model.lora_rank=16 \
    training.group_size=8 \
    data.train_batch_size=32 \
    trainer.total_epochs=20 \
    trainer.logger=['console','wandb'] \
    trainer.project_name='solver-judge-tinker' \
    trainer.experiment_name='countdown-grpo-qwen3-8b'

2.4 Run the workflow with Tinker rollout engine (optional)

For interactive evaluation (no training step), you can run the Countdown solver‑judge workflow directly using Tinker for sampling:

examples/solver_judge_tinker/run_solver_judge_flow_tinker.py
import asyncio
import json
import os

import sys
from copy import deepcopy

import tinker
from solver_judge_flow import SolverJudgeWorkflow
from transformers import AutoTokenizer

from rllm.data.dataset import DatasetRegistry
from rllm.engine.agent_workflow_engine import AgentWorkflowEngine
from rllm.engine.rollout.tinker_engine import TinkerEngine
from rllm.rewards.countdown_reward import countdown_reward_fn

# Make prepare_countdown_data importable from the sibling countdown example
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "countdown"))


def load_data(n=1):
    """Load countdown data using the Dataset interface."""
    dataset = DatasetRegistry.load_dataset("countdown", "test")
    if dataset is None:
        print("Dataset not found, preparing dataset...")
        from prepare_countdown_data import prepare_countdown_data

        _, dataset, _, _ = prepare_countdown_data()

    data = []
    for idx, example in enumerate(dataset):
        processed = process_countdown_fn(example, idx)
        for i in range(n):
            data.append(deepcopy(processed))
    return data


def process_countdown_fn(example, idx):
    """Process countdown example into the expected format."""
    question = example["question"]
    target = example["target"]
    nums = example["nums"]

    # Create ground truth in the format expected by countdown_reward_fn
    ground_truth = {"target": target, "numbers": nums}

    task = {"question": question, "ground_truth": ground_truth, "idx": idx, "data_source": "countdown", "target": target, "nums": nums}
    return task


def evaluate_results(results):
    """Evaluate the results and compute pass@k metrics."""
    from collections import defaultdict

    # Create a map to store correct answers per problem
    problem_correct_map = defaultdict(int)
    problem_total_map = defaultdict(int)

    # Count correct answers for each problem
    for episode in results:
        problem = episode.task["question"]

        # Use the episode-level is_correct flag set by the workflow
        is_correct = episode.is_correct

        problem_correct_map[problem] += int(is_correct)
        problem_total_map[problem] += 1

    # Calculate pass@1 and pass@k
    k = max(problem_total_map.values()) if problem_total_map else 1
    total_problems = len(problem_correct_map)

    if total_problems > 0:
        pass_at_1 = sum(problem_correct_map.values()) / sum(problem_total_map.values())
        pass_at_k = sum(1 for problem, correct in problem_correct_map.items() if correct > 0) / total_problems
    else:
        pass_at_1 = 0.0
        pass_at_k = 0.0

    print("Total unique problems:", total_problems)
    print("Average Pass@1 Accuracy:", pass_at_1)
    print(f"Average Pass@{k} Accuracy:", pass_at_k)


if __name__ == "__main__":
    # os is already imported at module level
    os.environ["TOKENIZERS_PARALLELISM"] = "true"

    # Configuration
    n_parallel_tasks = 4
    n_solutions = 2  # Number of solutions to generate per problem

    model_name = "Qwen/Qwen3-8B"
    service_client = tinker.ServiceClient(base_url=None)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    rollout_engine = TinkerEngine(
        base_url=None,
        model_name=model_name,
        tokenizer=tokenizer,
        service_client=service_client,
        max_prompt_length=2048,
        max_response_length=1024,
        sampling_params={"temperature": 0.6, "top_p": 0.95},
    )
    training_client = service_client.create_lora_training_client(
        base_model=model_name,
        rank=4,
    )
    sampler_future = training_client.save_weights_for_sampler(name="000000")
    sampler_result = sampler_future.result()
    sampling_client = training_client.create_sampling_client(sampler_result.path)

    rollout_engine.set_sampling_client(sampling_client)

    engine = AgentWorkflowEngine(
        workflow_cls=SolverJudgeWorkflow,
        workflow_args={
            "n_solutions": n_solutions,
            "reward_function": countdown_reward_fn,
        },
        rollout_engine=rollout_engine,
        config=None,
        n_parallel_tasks=n_parallel_tasks,
        retry_limit=1,
    )

    # Load countdown tasks
    tasks = load_data(n=1)
    print(f"Loaded {len(tasks)} countdown tasks")
    tasks = tasks[:4]

    results = asyncio.run(engine.execute_tasks(tasks))

    # Print one episode as a quick sanity check
    print(results[1])

    # Evaluate results (rewards are already assigned in the workflow)
    print("Evaluating results...")
    evaluate_results(results)

    # Save results
    os.makedirs("logs", exist_ok=True)
    with open("logs/solver_judge_countdown.json", "w") as f:
        json.dump([episode.to_dict() for episode in results], f, indent=4)

    print("\nResults saved to logs/solver_judge_countdown.json")

This script:

  • Builds a TinkerEngine for rollouts
  • Wraps it with AgentWorkflowEngine using SolverJudgeWorkflow
  • Executes Countdown tasks and computes pass@1 / pass@k metrics
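
The metric logic in evaluate_results can be exercised without Tinker by feeding it plain (question, is_correct) pairs. The sketch below follows the same counting scheme; the function name pass_metrics and the tuple input format are assumptions made for this standalone version.

```python
from collections import defaultdict


def pass_metrics(outcomes):
    """Compute (pass@1, pass@k) from an iterable of (question, is_correct)."""
    correct = defaultdict(int)
    total = defaultdict(int)
    for question, is_correct in outcomes:
        correct[question] += int(is_correct)
        total[question] += 1
    if not total:
        return 0.0, 0.0
    # pass@1: fraction of all attempts that were correct.
    pass_at_1 = sum(correct.values()) / sum(total.values())
    # pass@k: fraction of problems with at least one correct attempt.
    pass_at_k = sum(1 for c in correct.values() if c > 0) / len(total)
    return pass_at_1, pass_at_k


outcomes = [("a", True), ("a", False), ("b", False), ("b", False)]
print(pass_metrics(outcomes))
```

Here k is the number of attempts per problem, so with load_data(n=1) as in the script above, pass@1 and pass@k coincide.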

Monitoring and Checkpoints

For the solver‑judge example:

  • Logging:
      • Set trainer.logger=['console','wandb'] to enable Weights & Biases
      • Use trainer.project_name / trainer.experiment_name to organize runs
  • Checkpoints:
      • Local paths are controlled by trainer.default_local_dir
      • You can resume from a Tinker checkpoint via trainer.resume_from_tinker_id='tinker://<uuid>/weights/<checkpoint_name>'
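
A typo in the resume path is easy to make, so a quick format check before launching can save a failed run. The sketch below assumes the tinker://<uuid>/<kind>/<checkpoint_name> layout inferred from the example in the config comment; the helper parse_tinker_path and its returned field names are hypothetical, and Tinker itself may accept other path shapes.

```python
from urllib.parse import urlparse


def parse_tinker_path(path):
    """Split a tinker:// checkpoint path into its three components."""
    parsed = urlparse(path)
    parts = parsed.path.strip("/").split("/")
    if parsed.scheme != "tinker" or not parsed.netloc or len(parts) != 2 or not all(parts):
        raise ValueError(f"unexpected Tinker checkpoint path: {path!r}")
    kind, checkpoint_name = parts
    return {"run_id": parsed.netloc, "kind": kind, "checkpoint": checkpoint_name}


print(parse_tinker_path("tinker://uuid/weights/000060"))
```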

This gives you an end‑to‑end RL training pipeline in which rollouts and gradient updates run on Tinker's managed GPU service and checkpoints are stored there, while rLLM handles datasets, workflows, and trainer orchestration.