RL Training with Tinker
This example shows how to train a solver‑judge RL workflow with the Tinker backend in rLLM, using Tinker's hosted GPU service.
Overview
With this example you will:
- Train a solver‑judge workflow for the Countdown task using the Tinker backend
- Reuse the unified Tinker RL config (tinker_rl_trainer.yaml) for workflow training
Under the hood, rLLM integrates with Tinker as:
- Rollout backend: sampling and logprob computation happen on Tinker's GPU service
- Policy trainer: LoRA adapters are optimized remotely via Tinker training clients
- Checkpoint manager: checkpoints are stored and resumed via Tinker model IDs
Setup
Install dependencies
Configure Tinker authentication
Set your Tinker API key:
You can obtain an API key from the Tinker console.
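Before launching a long run, it can help to confirm that the key is visible to the Python process. The check below is a minimal sketch, not part of rLLM or the Tinker SDK: the helper `has_api_key` is hypothetical, and the variable name `TINKER_API_KEY` is an assumption about how the Tinker client reads its credentials, so confirm it against the Tinker documentation.

```python
import os

# Assumption: the Tinker SDK reads its key from this environment variable.
API_KEY_VAR = "TINKER_API_KEY"


def has_api_key(env=None, var=API_KEY_VAR):
    """Return True if `var` is set to a non-empty value in `env`."""
    env = os.environ if env is None else env
    return bool(env.get(var, "").strip())


if __name__ == "__main__":
    if has_api_key():
        print(f"{API_KEY_VAR} is set.")
    else:
        print(f"{API_KEY_VAR} is not set; export it before launching training.")
```

Running this in the same shell session you train from catches a missing or empty key before any remote call is made.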
Shared Tinker RL config
This example uses the unified RL config in rLLM:
# Tinker Backend Configuration for rLLM
# This config is used when training agents with Tinker backend
# Unified configuration supporting both workflow and agent training modes
# Default settings match tinker_cookbook.recipes.math_rl for MATH dataset
# Tinker-specific settings
tinker_base_url: null  # Tinker service URL (null for local)
# Model Configuration
model:
  name: "Qwen/Qwen3-8B"  # Default model for MATH dataset
  lora_rank: 32
  train_unembed: true  # Train LoRA on output embedding layer (set to false for Fireworks compatibility)
  train_attn: true  # Train LoRA on attention layers
  train_mlp: true  # Train LoRA on MLP layers
# Training Configuration
training:
  group_size: 16  # Number of rollouts per prompt (for GRPO)
  learning_rate: 2e-5  # 2e-5 for MATH dataset
  beta1: 0.9
  beta2: 0.95
  eps: 1e-8
  max_length: 32768
  num_minibatches: 1
# Sampling Configuration
sampling:
  temperature: 0.6
  top_p: 0.95
# Algorithm Configuration (compatible with verl)
algorithm:
  adv_estimator: grpo  # REINFORCE, GRPO
  gamma: 1.0
  lam: 0.95
  norm_adv_by_std_in_grpo: false  # math_rl doesn't normalize by std
  grouping_level: 'trajectory'  # Options: 'trajectory' or 'step'
# Workflow Configuration (for workflow training mode)
workflow:
  n_parallel_tasks: 256
  retry_limit: 3
# Agent Configuration (for agent training mode)
agent:
  max_steps: 1  # Single-turn vs multi-turn
  agent_args: {}
# Environment Configuration (for agent training mode)
env:
  env_args: {}
# Data Configuration
data:
  train_files: null
  val_files: null
  max_prompt_length: 2048
  max_response_length: 2048
  train_batch_size: 64
  val_batch_size: 32
# Trainer Configuration
trainer:
  total_epochs: 10
  logger: ['console']  # Options: 'console', 'wandb', 'tensorboard'
  project_name: 'rllm-tinker'
  experiment_name: 'default'
  test_freq: 5
  save_freq: 20
  val_before_train: true
  default_local_dir: '/tmp/rllm-tinker-checkpoints'
  resume_from_tinker_id: null  # Tinker model ID path to resume from (e.g., tinker://uuid/weights/000060)
# Hydra configuration
hydra:
  run:
    dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
Key options you may want to tune:
- model.name: base model to fine‑tune (e.g. Qwen/Qwen3-8B, Qwen/Qwen3-30B-A3B)
- model.lora_rank: LoRA rank
- training.group_size: number of trajectories per prompt (GRPO group size)
- data.max_prompt_length / data.max_response_length: context and generation lengths
- trainer.total_epochs, trainer.logger, trainer.project_name, trainer.experiment_name
You can override any of these from the command line using Hydra syntax (see below).
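To make the override mechanism concrete, the sketch below shows how a dotted `key=value` pair maps onto the nested structure of the config. It is a toy illustration, not Hydra's parser: `parse_value` and `apply_override` are hypothetical helpers, and Hydra's real override grammar additionally handles booleans, nulls, interpolations, and key addition or deletion.

```python
import ast
from copy import deepcopy


def parse_value(text):
    """Turn '16' into 16 and leave non-literal text such as model names as strings."""
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError):
        return text


def apply_override(config, override):
    """Return a copy of `config` with one `dotted.key=value` override applied."""
    dotted_key, _, raw_value = override.partition("=")
    updated = deepcopy(config)
    node = updated
    *parents, leaf = dotted_key.split(".")
    for name in parents:
        # Walk down the nested dicts, creating a level if it is missing.
        node = node.setdefault(name, {})
    node[leaf] = parse_value(raw_value)
    return updated


# A small excerpt of the defaults shown in the config above.
defaults = {
    "model": {"name": "Qwen/Qwen3-8B", "lora_rank": 32},
    "training": {"group_size": 16},
}
config = defaults
for item in ["model.lora_rank=16", "training.group_size=8"]:
    config = apply_override(config, item)
print(config)
```

Each override touches exactly one leaf, so every key that is not overridden keeps its default from the YAML file.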
Solver‑Judge RL Training with Tinker
This example trains a multi‑agent solver‑judge workflow on the Countdown task using the same Tinker RL backend.
2.1 Prepare Countdown dataset
First download and register the Countdown dataset:
This will:
- Load Jiayi-Pan/Countdown-Tasks-3to4 from HuggingFace
- Convert each example into a math‑style word problem
- Register multiple splits (train, test, stage2_train, stage3_train) under the countdown key
Dataset preparation:
import random
from datasets import load_dataset
from rllm.data.dataset import DatasetRegistry
def prepare_countdown_data():
    """
    Prepare the countdown task dataset from HuggingFace.
    Take 1024 examples as test set, remaining as training set.
    Also create stage 2 and stage 3 training sets with 50k examples each.
    """
    # Load the countdown dataset
    dataset = load_dataset("Jiayi-Pan/Countdown-Tasks-3to4", split="train")
    # Split dataset: 1024 examples for test, rest for training
    test_size = 1024
    total_size = len(dataset)
    # Create train/test split
    test_dataset = dataset.select(range(test_size))
    train_dataset = dataset.select(range(test_size, total_size))
    def preprocess_fn(example, idx):
        """
        Convert countdown task format to math problem format.
        Example: target=98, nums=[44, 19, 35] becomes a math word problem.
        """
        target = example["target"]
        nums = example["nums"]
        # Format as a math problem
        nums_str = ", ".join(map(str, nums))
        question = f"Using the numbers {nums_str}, find a way to reach the target number {target}. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your step-by-step calculation and output the final answer within <answer>...</answer>, for example <answer> (1 + 2) / 3 </answer>."
        return {
            "question": question,
            "ground_truth": str(target),
            "data_source": "countdown",
            "target": target,
            "nums": nums,
        }
    # Apply preprocessing
    train_dataset = train_dataset.map(preprocess_fn, with_indices=True)
    test_dataset = test_dataset.map(preprocess_fn, with_indices=True)
    # Create stage 2 and stage 3 training datasets
    train_size = len(train_dataset)
    stage_size = 50000
    # Ensure we have enough data for both stages
    if train_size < 2 * stage_size:
        print(f"Warning: Training set has only {train_size} examples, but need {2 * stage_size} for both stages")
        stage_size = min(stage_size, train_size // 2)
    # Shuffle and select indices for stage 2 and stage 3
    all_indices = list(range(train_size))
    random.shuffle(all_indices)
    stage2_indices = all_indices[:stage_size]
    stage3_indices = all_indices[stage_size : 2 * stage_size]
    # Create stage datasets
    stage2_dataset = train_dataset.select(stage2_indices)
    stage3_dataset = train_dataset.select(stage3_indices)
    # Register datasets
    train_dataset = DatasetRegistry.register_dataset("countdown", train_dataset, "train")
    test_dataset = DatasetRegistry.register_dataset("countdown", test_dataset, "test")
    stage2_dataset = DatasetRegistry.register_dataset("countdown", stage2_dataset, "stage2_train")
    stage3_dataset = DatasetRegistry.register_dataset("countdown", stage3_dataset, "stage3_train")
    print(f"Train dataset size: {len(train_dataset)}")
    print(f"Test dataset size: {len(test_dataset)}")
    print(f"Stage 2 train dataset size: {len(stage2_dataset)}")
    print(f"Stage 3 train dataset size: {len(stage3_dataset)}")
    return train_dataset, test_dataset, stage2_dataset, stage3_dataset
if __name__ == "__main__":
    train_dataset, test_dataset, stage2_dataset, stage3_dataset = prepare_countdown_data()
    print("Train dataset path:", train_dataset.get_data_path())
    print("Test dataset path:", test_dataset.get_data_path())
    print("Stage 2 train dataset path:", stage2_dataset.get_data_path())
    print("Stage 3 train dataset path:", stage3_dataset.get_data_path())
    # Print a sample
    print("\nSample train example:")
    print(train_dataset[0])
    print("\nSample stage 2 train example:")
    print(stage2_dataset[0])
    print("\nSample stage 3 train example:")
    print(stage3_dataset[0])
2.2 Solver‑Judge workflow with Tinker backend
The Tinker RL training entrypoint for the solver‑judge workflow is:
import hydra
from examples.solver_judge.solver_judge_flow import SolverJudgeWorkflow
from rllm.data.dataset import DatasetRegistry
from rllm.rewards.countdown_reward import countdown_reward_fn
from rllm.trainer import AgentTrainer
@hydra.main(config_path="pkg://rllm.trainer.config", config_name="tinker_rl_trainer", version_base=None)
def main(config):
    train_dataset = DatasetRegistry.load_dataset("countdown", "train")
    test_dataset = DatasetRegistry.load_dataset("countdown", "test")
    trainer = AgentTrainer(
        workflow_class=SolverJudgeWorkflow,
        workflow_args={
            "n_solutions": 2,
            "reward_function": countdown_reward_fn,
        },
        config=config,
        train_dataset=train_dataset,
        val_dataset=test_dataset,
        backend="tinker",
    )
    trainer.train()
if __name__ == "__main__":
    main()
It uses:
- SolverJudgeWorkflow from examples.solver_judge.solver_judge_flow
- countdown_reward_fn as the reward function
- AgentTrainer with backend="tinker" and workflow_class=SolverJudgeWorkflow
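To give a feel for what a Countdown reward has to verify, here is a minimal sketch of a checker for the `<answer>...</answer>` format requested in the prompt. It is not the implementation of `countdown_reward_fn` in rLLM: the name `toy_countdown_reward`, the binary 0/1 scoring, and the requirement that every given number is used exactly once are all assumptions made for illustration.

```python
import ast
import operator
import re
from collections import Counter

# Only the four basic arithmetic operators are allowed.
_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
}


def _evaluate(node, used):
    """Evaluate a parsed expression, recording every integer literal in `used`."""
    if isinstance(node, ast.Constant) and type(node.value) is int:
        used.append(node.value)
        return node.value
    if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
        left = _evaluate(node.left, used)
        right = _evaluate(node.right, used)
        return _OPS[type(node.op)](left, right)
    raise ValueError("unsupported expression")


def toy_countdown_reward(response, target, nums):
    """Return 1.0 if the last <answer> expression reaches `target` using each number once."""
    answers = re.findall(r"<answer>(.*?)</answer>", response, flags=re.DOTALL)
    if not answers:
        return 0.0
    used = []
    try:
        tree = ast.parse(answers[-1].strip(), mode="eval")
        value = _evaluate(tree.body, used)
    except (SyntaxError, ValueError, ZeroDivisionError):
        return 0.0
    # Assumption: every provided number must appear exactly once.
    if Counter(used) != Counter(nums):
        return 0.0
    return 1.0 if abs(value - target) < 1e-6 else 0.0


print(toy_countdown_reward("So <answer> 44 + 19 + 35 </answer>", 98, [44, 19, 35]))
```

Parsing with `ast` instead of calling `eval` keeps the check restricted to integer literals and the four operators, so arbitrary model output is never executed.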
2.3 Train solver‑judge workflow with Tinker
Run the provided shell script:
This will:
- Fine‑tune Qwen/Qwen3-4B-Instruct-2507 with LoRA (rank 32)
- Train with GRPO using trajectory‑level grouping (algorithm.grouping_level=trajectory)
- Use normalized advantages for stability (algorithm.norm_adv_by_std_in_grpo=true)
- Log training metrics to Weights & Biases
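The effect of `algorithm.norm_adv_by_std_in_grpo` can be illustrated with a small sketch of group-relative advantages: each rollout's reward is compared with the mean reward of its group, and the result is optionally divided by the group's standard deviation. This is a simplified illustration rather than rLLM's code; the function name `group_advantages`, the use of the population standard deviation, and the `eps` value are assumptions.

```python
from statistics import mean, pstdev


def group_advantages(rewards, normalize_by_std, eps=1e-6):
    """Group-relative advantages for the rollouts sampled from one prompt."""
    baseline = mean(rewards)
    centered = [r - baseline for r in rewards]
    if not normalize_by_std:
        return centered
    # eps keeps the division defined when all rewards in the group are equal.
    scale = pstdev(rewards) + eps
    return [c / scale for c in centered]


rewards = [1.0, 0.0, 0.0, 0.0]  # one group with training.group_size=4
print(group_advantages(rewards, normalize_by_std=False))
print(group_advantages(rewards, normalize_by_std=True))
```

With normalization enabled, groups with very different reward spreads contribute advantages on a comparable scale; a group in which every rollout receives the same reward yields all-zero advantages in either mode.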
Shell configuration:
set -x
MODEL_PATH=Qwen/Qwen3-4B-Instruct-2507
python -m examples.solver_judge_tinker.train_solver_judge_flow_tinker \
    model.name=$MODEL_PATH \
    model.lora_rank=32 \
    training.group_size=4 \
    training.learning_rate=4e-5 \
    sampling.temperature=1.0 \
    sampling.top_p=1.0 \
    algorithm.adv_estimator=grpo \
    algorithm.norm_adv_by_std_in_grpo=true \
    algorithm.grouping_level=trajectory \
    data.max_prompt_length=2048 \
    data.max_response_length=1024 \
    data.train_batch_size=64 \
    data.val_batch_size=512 \
    trainer.total_epochs=100 \
    trainer.logger=['wandb'] \
    trainer.project_name='solver-judge-workflow' \
    trainer.experiment_name='countdown-solver-judge-tinker-norm-by-std' \
    trainer.val_before_train=False \
    trainer.test_freq=10 \
    trainer.save_freq=20 \
    trainer.default_local_dir='/tmp/countdown-solver-judge-tinker-norm-by-std'
You can customize training via CLI overrides, e.g.:
# Run from the rLLM repository root so that the examples package resolves
python -m examples.solver_judge_tinker.train_solver_judge_flow_tinker \
    model.name=Qwen/Qwen3-8B \
    model.lora_rank=16 \
    training.group_size=8 \
    data.train_batch_size=32 \
    trainer.total_epochs=20 \
    trainer.logger=['console','wandb'] \
    trainer.project_name='solver-judge-tinker' \
    trainer.experiment_name='countdown-grpo-qwen3-8b'
2.4 Run the workflow with Tinker rollout engine (optional)
For interactive evaluation (no training step), you can run the Countdown solver‑judge workflow directly using Tinker for sampling:
import asyncio
import json
import os
import sys
from copy import deepcopy
import tinker
from solver_judge_flow import SolverJudgeWorkflow
from transformers import AutoTokenizer
from rllm.data.dataset import DatasetRegistry
from rllm.engine.agent_workflow_engine import AgentWorkflowEngine
from rllm.engine.rollout.tinker_engine import TinkerEngine
from rllm.rewards.countdown_reward import countdown_reward_fn
# Make the countdown example importable (needed for prepare_countdown_data)
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "countdown"))
def load_data(n=1):
    """Load countdown data using the Dataset interface."""
    dataset = DatasetRegistry.load_dataset("countdown", "test")
    if dataset is None:
        print("Dataset not found, preparing dataset...")
        from prepare_countdown_data import prepare_countdown_data
        _, dataset, _, _ = prepare_countdown_data()
    data = []
    for idx, example in enumerate(dataset):
        processed = process_countdown_fn(example, idx)
        # Repeat each task n times so that pass@k can be computed
        for i in range(n):
            data.append(deepcopy(processed))
    return data
def process_countdown_fn(example, idx):
    """Process countdown example into the expected format."""
    question = example["question"]
    target = example["target"]
    nums = example["nums"]
    # Create ground truth in the format expected by countdown_reward_fn
    ground_truth = {"target": target, "numbers": nums}
    task = {"question": question, "ground_truth": ground_truth, "idx": idx, "data_source": "countdown", "target": target, "nums": nums}
    return task
def evaluate_results(results):
    """Evaluate the results and compute pass@k metrics."""
    from collections import defaultdict
    # Create a map to store correct answers per problem
    problem_correct_map = defaultdict(int)
    problem_total_map = defaultdict(int)
    # Count correct answers for each problem
    for episode in results:
        problem = episode.task["question"]
        # Use the episode-level is_correct flag set by the workflow
        is_correct = episode.is_correct
        problem_correct_map[problem] += int(is_correct)
        problem_total_map[problem] += 1
    # Calculate pass@1 and pass@k
    k = max(problem_total_map.values()) if problem_total_map else 1
    total_problems = len(problem_correct_map)
    if total_problems > 0:
        pass_at_1 = sum(problem_correct_map.values()) / sum(problem_total_map.values())
        pass_at_k = sum(1 for problem, correct in problem_correct_map.items() if correct > 0) / total_problems
    else:
        pass_at_1 = 0.0
        pass_at_k = 0.0
    print("Total unique problems:", total_problems)
    print("Average Pass@1 Accuracy:", pass_at_1)
    print(f"Average Pass@{k} Accuracy:", pass_at_k)
if __name__ == "__main__":
    os.environ["TOKENIZERS_PARALLELISM"] = "true"
    # Configuration
    n_parallel_tasks = 4
    n_solutions = 2  # Number of solutions to generate per problem
    model_name = "Qwen/Qwen3-8B"
    service_client = tinker.ServiceClient(base_url=None)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    rollout_engine = TinkerEngine(
        base_url=None,
        model_name=model_name,
        tokenizer=tokenizer,
        service_client=service_client,
        max_prompt_length=2048,
        max_response_length=1024,
        sampling_params={"temperature": 0.6, "top_p": 0.95},
    )
    # Create a fresh LoRA adapter and export its weights so they can be sampled from
    training_client = service_client.create_lora_training_client(
        base_model=model_name,
        rank=4,
    )
    sampler_future = training_client.save_weights_for_sampler(name="000000")
    sampler_result = sampler_future.result()
    sampling_client = training_client.create_sampling_client(sampler_result.path)
    rollout_engine.set_sampling_client(sampling_client)
    engine = AgentWorkflowEngine(
        workflow_cls=SolverJudgeWorkflow,
        workflow_args={
            "n_solutions": n_solutions,
            "reward_function": countdown_reward_fn,
        },
        rollout_engine=rollout_engine,
        config=None,
        n_parallel_tasks=n_parallel_tasks,
        retry_limit=1,
    )
    # Load countdown tasks
    tasks = load_data(n=1)
    print(f"Loaded {len(tasks)} countdown tasks")
    tasks = tasks[:4]
    results = asyncio.run(engine.execute_tasks(tasks))
    # import pdb; pdb.set_trace()  # Uncomment to inspect results interactively
    print(results[1])
    # Evaluate results (rewards are already assigned in the workflow)
    print("Evaluating results...")
    evaluate_results(results)
    # Save results
    os.makedirs("logs", exist_ok=True)
    with open("logs/solver_judge_countdown.json", "w") as f:
        json.dump([episode.to_dict() for episode in results], f, indent=4)
    print("\nResults saved to logs/solver_judge_countdown.json")
This script:
- Builds a TinkerEngine for rollouts
- Wraps it with AgentWorkflowEngine using SolverJudgeWorkflow
- Executes Countdown tasks and computes pass@1 / pass@k metrics
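The two metrics can be checked by hand on a toy input. The sketch below follows the same counting logic as `evaluate_results` but takes plain `(problem_id, is_correct)` pairs instead of episode objects; the function name `pass_metrics` and the sample data are made up for illustration.

```python
from collections import defaultdict


def pass_metrics(outcomes):
    """Compute (pass@1, pass@k) from (problem_id, is_correct) pairs."""
    correct = defaultdict(int)
    total = defaultdict(int)
    for problem_id, is_correct in outcomes:
        correct[problem_id] += int(is_correct)
        total[problem_id] += 1
    if not total:
        return 0.0, 0.0
    # pass@1: fraction of all attempts that are correct.
    pass_at_1 = sum(correct.values()) / sum(total.values())
    # pass@k: fraction of problems with at least one correct attempt.
    pass_at_k = sum(1 for c in correct.values() if c > 0) / len(total)
    return pass_at_1, pass_at_k


# Two problems, two attempts each; only one attempt on problem "a" succeeds.
outcomes = [("a", True), ("a", False), ("b", False), ("b", False)]
print(pass_metrics(outcomes))  # (0.25, 0.5)
```

One of four attempts is correct, so pass@1 is 0.25, while one of two problems has at least one correct attempt, so pass@k is 0.5. With `load_data(n=1)`, as in the script, each problem has a single attempt and the two metrics coincide.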
Monitoring and Checkpoints
For the solver‑judge example:
- Logging:
  - Set trainer.logger=['console','wandb'] to enable Weights & Biases
  - Use trainer.project_name / trainer.experiment_name to organize runs
- Checkpoints:
  - Local paths are controlled by trainer.default_local_dir
  - You can resume from a Tinker checkpoint via trainer.resume_from_tinker_id='tinker://<uuid>/weights/<checkpoint_name>'
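A mistyped checkpoint path is easier to catch before a run starts than after. The helper below is a hypothetical sketch that validates only the `tinker://<uuid>/weights/<checkpoint_name>` shape shown above; it is not part of rLLM or the Tinker SDK, and Tinker may accept other path layouts that this check would reject.

```python
from urllib.parse import urlsplit


def split_tinker_path(path):
    """Split 'tinker://<uuid>/weights/<checkpoint_name>' into (uuid, checkpoint_name)."""
    parts = urlsplit(path)
    segments = parts.path.strip("/").split("/")
    well_formed = (
        parts.scheme == "tinker"
        and bool(parts.netloc)
        and len(segments) == 2
        and segments[0] == "weights"
        and bool(segments[1])
    )
    if not well_formed:
        raise ValueError(f"not a Tinker weights path: {path!r}")
    return parts.netloc, segments[1]


print(split_tinker_path("tinker://uuid/weights/000060"))  # ('uuid', '000060')
```

Calling it on the value you intend to pass as `trainer.resume_from_tinker_id` raises a `ValueError` for paths that do not match this shape.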
This gives you an end‑to‑end RL training pipeline where rollouts, gradients, and checkpoints all run on Tinker's managed GPU service, while rLLM handles datasets, workflows, and trainer orchestration.