FrozenLake Agent Example
This example shows how to train a reinforcement learning (RL) agent to play the FrozenLake game.
FrozenLake is a classic RL environment where:
- Agent navigates a frozen lake grid
- The goal is to reach the frisbee without falling into any holes
- An optional slippery surface adds stochasticity to the outcome of actions (the agent runner below disables it with `is_slippery: False`)
- Discrete action space: UP, DOWN, LEFT, RIGHT
Quick Start
Prepare Environment Data
Run FrozenLake Agent
Train Agent
Code Reference
Agent Runner
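Main script for evaluating the FrozenLake agent on the test split. It sends rollouts to a model server at `http://localhost:30000/v1` (serving `Qwen/Qwen3-4B`), so start that server before running the script: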
examples/frozenlake/run_frozenlake_agent.py
import asyncio
from transformers import AutoTokenizer
from rllm.agents.frozenlake_agent import FrozenLakeAgent
from rllm.data.dataset import DatasetRegistry
from rllm.engine.agent_execution_engine import AgentExecutionEngine
from rllm.environments.frozenlake.frozenlake import FrozenLakeEnv
from rllm.utils import compute_pass_at_k
def load_frozenlake_data():
    """Return the FrozenLake test split, preparing the datasets first if needed.

    If the ``frozenlake``/``test`` dataset is already registered with
    ``DatasetRegistry`` it is loaded directly. Otherwise
    ``prepare_frozenlake_data()`` is called to build the datasets, and only the
    test split it returns is used.

    Returns:
        The value of ``get_data()`` on the test dataset; this is what the
        caller passes to ``AgentExecutionEngine.execute_tasks``.
    """
    if DatasetRegistry.dataset_exists("frozenlake", "test"):
        test_dataset = DatasetRegistry.load_dataset("frozenlake", "test")
        return test_dataset.get_data()

    print("FrozenLake datasets not found. Preparing datasets...")
    # Imported lazily so the preparation helper (presumably a sibling module of
    # this script) is only needed when the datasets have not been registered yet.
    from prepare_frozenlake_data import prepare_frozenlake_data

    # prepare_frozenlake_data() returns (train, test); only test is evaluated here.
    _, test_dataset = prepare_frozenlake_data()
    return test_dataset.get_data()
if __name__ == "__main__":
import os
os.environ["TOKENIZERS_PARALLELISM"] = "true"
n_parallel_agents = 256
model_name = "Qwen/Qwen3-4B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
sampling_params = {"temperature": 0.6, "top_p": 0.95, "model": model_name}
agent_args = {
"max_steps": 10,
"use_accumulate_history": True,
}
env_args = {
"max_steps": 8,
"is_slippery": False,
}
engine = AgentExecutionEngine(
agent_class=FrozenLakeAgent,
env_class=FrozenLakeEnv,
agent_args=agent_args,
env_args=env_args,
engine_name="openai",
tokenizer=tokenizer,
sampling_params=sampling_params,
rollout_engine_args={
"base_url": "http://localhost:30000/v1",
"api_key": "None",
},
max_response_length=16384,
max_prompt_length=4096,
n_parallel_agents=n_parallel_agents,
)
tasks = load_frozenlake_data()
results = asyncio.run(engine.execute_tasks(tasks))
compute_pass_at_k(results)
Training Script
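Agent training entry point, configured with Hydra via the `agent_ppo_trainer` config. It loads the FrozenLake `train` and `test` datasets from the registry directly, so prepare them first: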
examples/frozenlake/train_frozenlake_agent.py
import hydra
from rllm.agents.frozenlake_agent import FrozenLakeAgent
from rllm.data import DatasetRegistry
from rllm.environments.frozenlake.frozenlake import FrozenLakeEnv
from rllm.trainer.agent_trainer import AgentTrainer
@hydra.main(config_path="pkg://rllm.trainer.config", config_name="agent_ppo_trainer", version_base=None)
def main(config):
    """Train a FrozenLake agent using the Hydra-composed ``agent_ppo_trainer`` config.

    Args:
        config: Configuration composed by Hydra from ``agent_ppo_trainer`` in
            the ``rllm.trainer.config`` package.
    """
    # Splits are loaded in order: "train" first, then "test", which serves as
    # the validation set during training.
    train_split, val_split = (
        DatasetRegistry.load_dataset("frozenlake", split) for split in ("train", "test")
    )

    AgentTrainer(
        agent_class=FrozenLakeAgent,
        env_class=FrozenLakeEnv,
        config=config,
        train_dataset=train_split,
        val_dataset=val_split,
    ).train()


if __name__ == "__main__":
    # The @hydra.main decorator supplies `config`, so no arguments are passed here.
    main()
For more details, see the FrozenLake README.