
GSM8K LoRA Training

This example shows how to fine‑tune a math reasoning agent on the GSM8K dataset using LoRA in rLLM.
You will use the standard MathAgent with a single‑turn environment and enable LoRA via a few configuration flags.

Overview

With this example you will:

  1. Prepare the GSM8K dataset and register it with DatasetRegistry
  2. Train MathAgent on GSM8K using REINFORCE/GRPO‑style RL with LoRA adapters
  3. Configure LoRA hyperparameters (rank, target modules, alpha) via the training script

The training loop uses the standard VERL‑style backend (agent_ppo_trainer config) and simply adds LoRA settings to the model.


1. Dataset Preparation

First, preprocess GSM8K and register it in rLLM:

cd examples/gsm8k_lora
python prepare_gsm8k_data.py

This will:

  • Download the openai/gsm8k dataset (train + test)
  • Extract the final numeric answer from each solution using the #### <answer> pattern
  • Add the following fields to each example:
      • question: the original problem text
      • ground_truth: the extracted numeric answer
      • data_source: "gsm8k"
  • Register both splits under the name "gsm8k", so they can be loaded with:
      • DatasetRegistry.load_dataset("gsm8k", "train")
      • DatasetRegistry.load_dataset("gsm8k", "test")
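As a quick illustration of the extraction step, the sketch below applies the same #### regex to a made-up, GSM8K-style solution string and builds the three fields listed above. The sample text is invented for illustration; the helper mirrors extract_solution from prepare_gsm8k_data.py but reads group(1) directly and raises ValueError instead of asserting.

```python
import re

# Illustrative GSM8K-style solution text (invented for this example, not a real dataset row).
# GSM8K answers end with a line of the form "#### <number>".
sample_answer = (
    "Each box holds 12 pencils and there are 105 boxes.\n"
    "12 * 105 = <<12*105=1260>>1260\n"
    "#### 1,260"
)


def extract_solution(solution_str: str) -> str:
    # Same pattern as prepare_gsm8k_data.py: optional minus sign, then digits, dots, commas.
    match = re.search(r"#### (\-?[0-9\.\,]+)", solution_str)
    if match is None:
        raise ValueError("no '#### <answer>' line found")
    # group(1) is the captured number; thousands separators are stripped.
    return match.group(1).replace(",", "")


def preprocess(example: dict) -> dict:
    # Mirrors the fields that preprocess_fn adds to each example.
    return {
        "question": example["question"],
        "ground_truth": extract_solution(example["answer"]),
        "data_source": "gsm8k",
    }


record = preprocess({"question": "How many pencils are there in total?", "answer": sample_answer})
print(record)
# {'question': 'How many pencils are there in total?', 'ground_truth': '1260', 'data_source': 'gsm8k'}
```

Note that ground_truth is stored as a string, so a reward function has to parse it before comparing numerically.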

Dataset preparation logic:

examples/gsm8k_lora/prepare_gsm8k_data.py
import re

from datasets import load_dataset

from rllm.data.dataset import DatasetRegistry


# Adapted from verl/examples/data_preprocess/gsm8k.py
def extract_solution(solution_str):
    solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str)
    assert solution is not None
    final_solution = solution.group(0)
    final_solution = final_solution.split("#### ")[1].replace(",", "")
    return final_solution


def prepare_gsm8k_data():
    gsm8k_dataset = load_dataset("openai/gsm8k", "main")
    train_dataset = gsm8k_dataset["train"]
    test_dataset = gsm8k_dataset["test"]

    def preprocess_fn(example, idx):
        return {
            "question": example["question"],
            "ground_truth": extract_solution(example["answer"]),
            "data_source": "gsm8k",
        }

    train_dataset = train_dataset.map(preprocess_fn, with_indices=True)
    test_dataset = test_dataset.map(preprocess_fn, with_indices=True)

    train_dataset = DatasetRegistry.register_dataset("gsm8k", train_dataset, "train")
    test_dataset = DatasetRegistry.register_dataset("gsm8k", test_dataset, "test")
    return train_dataset, test_dataset


if __name__ == "__main__":
    train_dataset, test_dataset = prepare_gsm8k_data()
    print(train_dataset)
    print(test_dataset)

2. Training Script (LoRA + RL)

The main training entrypoint wraps MathAgent in a single‑turn environment with the built‑in math reward:

examples/gsm8k_lora/train_gsm8k_with_lora.py
import hydra

from rllm.agents.math_agent import MathAgent
from rllm.data.dataset import DatasetRegistry
from rllm.environments.base.single_turn_env import SingleTurnEnvironment
from rllm.rewards.reward_fn import math_reward_fn
from rllm.trainer.agent_trainer import AgentTrainer


@hydra.main(config_path="pkg://rllm.trainer.config", config_name="agent_ppo_trainer", version_base=None)
def main(config):
    train_dataset = DatasetRegistry.load_dataset("gsm8k", "train")
    test_dataset = DatasetRegistry.load_dataset("gsm8k", "test")

    env_args = {"reward_fn": math_reward_fn}

    trainer = AgentTrainer(
        agent_class=MathAgent,
        agent_args={},
        env_args=env_args,
        env_class=SingleTurnEnvironment,
        config=config,
        train_dataset=train_dataset,
        val_dataset=test_dataset,
    )
    trainer.train()


if __name__ == "__main__":
    main()

Key pieces:

  • Agent: MathAgent from rllm.agents.math_agent
  • Environment: SingleTurnEnvironment (one question → one answer)
  • Reward: math_reward_fn, which parses the model’s final answer and compares it with ground_truth
  • Datasets:
      • train_dataset = DatasetRegistry.load_dataset("gsm8k", "train")
      • val_dataset = DatasetRegistry.load_dataset("gsm8k", "test")
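To make the single-turn flow concrete, here is a toy stand-in for the reward step. It is not rLLM's math_reward_fn: toy_math_reward, extract_final_answer, and single_turn_episode are hypothetical names, and the parsing rule (last \boxed{...}, otherwise the last number in the text) is an assumption chosen only for illustration.

```python
import re
from typing import Callable, Optional


def extract_final_answer(response: str) -> Optional[str]:
    # Prefer the last \boxed{...}; otherwise fall back to the last number in the text.
    boxed = re.findall(r"\\boxed\{([^{}]*)\}", response)
    if boxed:
        candidate = boxed[-1]
    else:
        numbers = re.findall(r"-?\d[\d,]*(?:\.\d+)?", response)
        if not numbers:
            return None
        candidate = numbers[-1]
    return candidate.replace(",", "").replace("$", "").strip()


def toy_math_reward(response: str, ground_truth: str) -> float:
    # Binary reward: 1.0 if the parsed answer equals the ground truth numerically, else 0.0.
    answer = extract_final_answer(response)
    if answer is None:
        return 0.0
    try:
        return 1.0 if float(answer) == float(ground_truth) else 0.0
    except ValueError:
        return 0.0


def single_turn_episode(policy: Callable[[str], str], task: dict) -> float:
    # One question in, one answer out, one scalar reward: the episode ends after a single step.
    response = policy(task["question"])
    return toy_math_reward(response, task["ground_truth"])


task = {"question": "How many pencils are there in total?", "ground_truth": "1260"}
print(single_turn_episode(lambda q: "12 * 105 = 1260, so the answer is \\boxed{1,260}.", task))
print(single_turn_episode(lambda q: "I think it is 1250.", task))
```

Because the reward is computed once per response, the whole trajectory is a single step, which is why the training script sets rllm.agent.max_steps=1.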

LoRA is configured via the Hydra overrides in the shell script rather than inside the Python file.
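Each override is a dotted path into the nested agent_ppo_trainer config, so actor_rollout_ref.model.lora_rank=32 ends up at config.actor_rollout_ref.model.lora_rank inside main(). The sketch below shows that mapping with a simplified, hypothetical apply_overrides helper; it is not Hydra's parser, which additionally handles type inference, lists, quoting, and the +/~ prefixes.

```python
def apply_overrides(overrides: list[str]) -> dict:
    # Simplified illustration only: every value is kept as a string.
    config: dict = {}
    for item in overrides:
        key, value = item.split("=", 1)
        *parents, leaf = key.split(".")
        node = config
        for part in parents:
            # Walk (or create) one nested level per dotted component.
            node = node.setdefault(part, {})
        node[leaf] = value
    return config


cfg = apply_overrides([
    "actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct",
    "actor_rollout_ref.model.lora_rank=32",
    "actor_rollout_ref.model.lora_alpha=32",
    "actor_rollout_ref.model.target_modules=all-linear",
])
print(cfg["actor_rollout_ref"]["model"]["lora_rank"])  # prints 32
```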

2.1 Launch training with LoRA

Use the helper shell script to start training with LoRA enabled:

cd examples/gsm8k_lora
bash train_gsm8k_lora.sh

Training configuration:

examples/gsm8k_lora/train_gsm8k_lora.sh
set -x

export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:False"
export VLLM_USE_V1=1
export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
export VLLM_ENGINE_ITERATION_TIMEOUT_S=100000000000

MODEL_PATH=Qwen/Qwen2.5-3B-Instruct

python3 -m examples.gsm8k_lora.train_gsm8k_lora \
    algorithm.adv_estimator=grpo \
    data.train_batch_size=8 \
    data.val_batch_size=512 \
    data.max_prompt_length=512 \
    data.max_response_length=1024 \
    actor_rollout_ref.model.path=$MODEL_PATH \
    actor_rollout_ref.model.lora_rank=32 \
    actor_rollout_ref.model.lora_alpha=32 \
    actor_rollout_ref.model.target_modules=all-linear \
    actor_rollout_ref.actor.optim.lr=5e-6 \
    actor_rollout_ref.actor.strategy=fsdp2 \
    actor_rollout_ref.actor.loss_agg_mode=token-mean \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=8 \
    actor_rollout_ref.actor.use_dynamic_bsz=False \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=20000 \
    actor_rollout_ref.actor.use_kl_loss=False \
    actor_rollout_ref.actor.clip_ratio_high=0.2 \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.mode="async" \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
    actor_rollout_ref.rollout.enforce_eager=True \
    actor_rollout_ref.rollout.n=8 \
    actor_rollout_ref.rollout.temperature=0.7 \
    actor_rollout_ref.rollout.top_p=0.95 \
    actor_rollout_ref.rollout.val_kwargs.n=1 \
    actor_rollout_ref.rollout.val_kwargs.temperature=0.7 \
    actor_rollout_ref.rollout.val_kwargs.top_p=0.95 \
    actor_rollout_ref.ref.fsdp_config.param_offload=False \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
    actor_rollout_ref.actor.entropy_coeff=0 \
    algorithm.kl_ctrl.kl_coef=0.001 \
    rllm.mask_truncated_samples=False \
    trainer.critic_warmup=0 \
    trainer.logger=['console','wandb'] \
    trainer.project_name='rllm-experiment' \
    trainer.experiment_name='gsm8k-lora' \
    trainer.val_before_train=True \
    trainer.n_gpus_per_node=4 \
    trainer.nnodes=1 \
    trainer.save_freq=1000 \
    trainer.test_freq=10 \
    trainer.default_hdfs_dir=null \
    rllm.agent.max_steps=1 \
    rllm.stepwise_advantage.enable=False \
    trainer.total_epochs=100

Important LoRA‑related options in this script:

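  • actor_rollout_ref.model.path=$MODEL_PATH – base model whose weights stay frozen (e.g. Qwen/Qwen2.5-3B-Instruct)
  • actor_rollout_ref.model.lora_rank=32 – rank of the low-rank update matrices; a positive value enables LoRA
  • actor_rollout_ref.model.lora_alpha=32 – LoRA scaling factor; the low-rank update is scaled by lora_alpha / lora_rank (1.0 here)
  • actor_rollout_ref.model.target_modules=all-linear – apply LoRA to all linear layers (following the PEFT convention, the output head is excluded)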
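Assuming the standard LoRA formulation (as in PEFT), each targeted weight W stays frozen and the layer computes W·x + (lora_alpha / lora_rank)·B·(A·x), where A has lora_rank rows and B has lora_rank columns. The sketch below works through the scaling factor, the trainable-parameter count for one linear layer, and a pure-Python forward pass; the 2048 x 2048 layer size is a hypothetical example, not a claim about Qwen2.5's actual shapes.

```python
def lora_scaling(lora_alpha: float, lora_rank: int) -> float:
    # Standard LoRA multiplies the low-rank update B @ A by alpha / r.
    return lora_alpha / lora_rank


def lora_params_for_linear(d_in: int, d_out: int, rank: int) -> int:
    # A is (rank x d_in) and B is (d_out x rank); only these two matrices are trained.
    return rank * d_in + d_out * rank


def matvec(matrix, vector):
    # Plain-Python matrix-vector product; matrices are lists of rows.
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


def lora_forward(x, W, A, B, lora_alpha, lora_rank):
    base = matvec(W, x)              # frozen pretrained path: W x
    delta = matvec(B, matvec(A, x))  # trainable low-rank path: B (A x)
    scale = lora_scaling(lora_alpha, lora_rank)
    return [b + scale * d for b, d in zip(base, delta)]


# Hypothetical 2048 x 2048 projection with the script's lora_rank=32, lora_alpha=32.
trainable = lora_params_for_linear(2048, 2048, 32)
frozen = 2048 * 2048
print(lora_scaling(32, 32))  # 1.0
print(trainable, frozen)     # 131072 4194304
print(trainable / frozen)    # 0.03125

# Tiny rank-1 example. Standard LoRA initializes B to zero, so the output starts at W x.
W = [[1.0, 0.0], [0.0, 1.0]]
A = [[1.0, 1.0]]
B_init = [[0.0], [0.0]]
B_trained = [[1.0], [0.0]]
print(lora_forward([2.0, 3.0], W, A, B_init, lora_alpha=1, lora_rank=1))     # [2.0, 3.0]
print(lora_forward([2.0, 3.0], W, A, B_trained, lora_alpha=1, lora_rank=1))  # [7.0, 3.0]
```

With lora_rank=32 and lora_alpha=32 the scale is 1.0; the ablation in Section 3 (rank 16, alpha 16) keeps the same ratio, so only the rank of the update changes.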

Other notable settings:

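  • algorithm.adv_estimator=grpo – GRPO advantage estimation, which scores each response relative to the other actor_rollout_ref.rollout.n=8 responses for the same prompt instead of using a critic
  • data.max_prompt_length=512, data.max_response_length=1024 – prompt and response length limits, in tokens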
  • trainer.logger=['console','wandb'], trainer.project_name='rllm-experiment', trainer.experiment_name='gsm8k-lora'
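With algorithm.adv_estimator=grpo and actor_rollout_ref.rollout.n=8, each of the data.train_batch_size=8 prompts is sampled 8 times, giving 64 responses per training step. The sketch below shows the group-relative normalization at the core of GRPO for one prompt's group; grpo_advantages is a hypothetical helper, and the trainer's exact normalization (standard-deviation estimator, epsilon) may differ.

```python
from statistics import mean, stdev


def grpo_advantages(rewards: list[float], eps: float = 1e-6) -> list[float]:
    # All rewards belong to the same prompt's group (rollout.n responses).
    mu = mean(rewards)
    sigma = stdev(rewards) if len(rewards) > 1 else 0.0
    # Each response is judged against its siblings, so no learned critic is required.
    return [(r - mu) / (sigma + eps) for r in rewards]


# One prompt, rollout.n = 8 responses, binary correctness rewards: 3 correct, 5 wrong.
group = [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
for reward, adv in zip(group, grpo_advantages(group)):
    print(reward, round(adv, 3))
```

If every response in a group gets the same reward (all correct or all wrong), the advantages are all zero and that prompt contributes no policy-gradient signal for that step.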

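You can change any of these by editing the script. Because train_gsm8k_lora.sh does not forward extra command-line arguments to Python, arguments appended to bash train_gsm8k_lora.sh are ignored; to override values without editing the script, pass them directly to the Python entrypoint as in Section 3.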


3. Customizing LoRA and Training

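To experiment with different LoRA and training parameters, pass Hydra overrides directly to the Python entrypoint. Hydra only applies the overrides given on that command line, so any setting from train_gsm8k_lora.sh that you do not repeat (for example algorithm.adv_estimator=grpo or actor_rollout_ref.rollout.n=8) falls back to the agent_ppo_trainer default. For example: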

cd examples/gsm8k_lora
python3 -m examples.gsm8k_lora.train_gsm8k_lora \
    actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-Instruct \
    actor_rollout_ref.model.lora_rank=16 \
    actor_rollout_ref.model.lora_alpha=16 \
    data.train_batch_size=4 \
    actor_rollout_ref.actor.optim.lr=1e-5 \
    trainer.project_name='gsm8k-lora-ablation' \
    trainer.experiment_name='small-batch-lr-1e-5'
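For a small sweep, it can help to generate the commands programmatically. The sketch below is hypothetical: sweep_commands and BASE_OVERRIDES are illustrative names, the code only prints the commands, and BASE_OVERRIDES is a short subset that you would extend with the remaining overrides from train_gsm8k_lora.sh. It keeps lora_alpha equal to lora_rank so the LoRA scale stays at 1.0.

```python
import shlex

# Illustrative subset; copy the remaining overrides from train_gsm8k_lora.sh in practice.
BASE_OVERRIDES = [
    "algorithm.adv_estimator=grpo",
    "actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct",
    "actor_rollout_ref.model.target_modules=all-linear",
]


def sweep_commands(ranks: list[int]) -> list[list[str]]:
    commands = []
    for rank in ranks:
        commands.append([
            "python3", "-m", "examples.gsm8k_lora.train_gsm8k_lora",
            *BASE_OVERRIDES,
            f"actor_rollout_ref.model.lora_rank={rank}",
            f"actor_rollout_ref.model.lora_alpha={rank}",  # keep alpha / rank = 1.0
            f"trainer.experiment_name=gsm8k-lora-r{rank}",
        ])
    return commands


for cmd in sweep_commands([8, 16, 32]):
    # To launch a run, pass the list to subprocess.run(cmd, check=True).
    print(shlex.join(cmd))
```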

This GSM8K LoRA example shows that LoRA fine‑tuning is a configuration change on top of the standard rLLM RL training stack: the agent and environment code stay unchanged.