class AgentExecutionEngine:
    """Run agent/environment rollouts asynchronously against a rollout engine.

    The engine owns ``n_parallel_agents`` agent/environment slots. Model calls
    are awaited on the event loop, while blocking environment calls (``reset``,
    ``step``, ``close``, ``compute_final_reward``) are dispatched to a thread
    pool (``self.executor``).

    Two rollout backends are supported, selected by ``engine_name``:
    ``"openai"`` (``OpenAIEngine``) and ``"verl"`` (``VerlEngine``).

    Termination reasons assigned by ``run_agent_trajectory_async``:
    ``PROMPT_TRUNCATION``, ``ENV_TIMEOUT``, ``TRUNCATION``, ``TIMEOUT``,
    ``ENV_DONE`` and ``MAX_STEPS``.
    """

    # Wall-clock cap (seconds) applied to every single attempt made by
    # ``run_agent_trajectory_with_retry``.
    _ATTEMPT_TIMEOUT_S = 7200

    def __init__(
        self,
        engine_name="openai",
        tokenizer=None,
        rollout_engine=None,
        chat_parser=None,
        n_parallel_agents=1,
        trajectory_timeout=None,
        gamma=0.2,
        api_retries=3,
        retry_limit=3,
        max_steps=5,
        max_response_length=8192,
        max_prompt_length=1024,
        config=None,
        agent_class=None,
        env_class=None,
        agent_args=None,
        rollout_engine_args=None,
        env_args=None,
        max_workers=64,
        enforce_max_prompt_length=False,  # If enabled, applies max_prompt check per step
        overlong_filter=False,  # Filter for overlong trajectories (i.e. TRUNCATION, MAX_STEPS, TIMEOUT)
        **kwargs,
    ):
        """Configure the engine and build the rollout backend.

        Args:
            engine_name: ``"openai"`` or ``"verl"``.
            tokenizer: Tokenizer handed to the chat parser and rollout engine.
            rollout_engine: Passed to ``VerlEngine`` as ``rollout_manager``
                (only read when ``engine_name == "verl"``).
            chat_parser: Optional parser; when ``None`` one is obtained from
                ``ChatTemplateParser.get_parser``.
            n_parallel_agents: Number of agent/environment slots.
            trajectory_timeout: Per-trajectory time budget in seconds; a falsy
                value is replaced by ``int(1e9)``.
            gamma: Discount passed to ``compute_mc_return``.
            api_retries: Forwarded to ``OpenAIEngine``.
            retry_limit: Attempts made by ``run_agent_trajectory_with_retry``.
            max_steps: Maximum agent/environment steps per trajectory.
            max_response_length: Token budget for the accumulated response.
            max_prompt_length: Token limit for the prompt.
            config: Configuration object; for ``"verl"`` it must expose
                ``config.rllm.disable_thinking``.
            agent_class: Class instantiated per task in ``execute_tasks``.
            env_class: Class whose ``from_dict`` builds environments in
                ``execute_tasks``; must report ``is_multithread_safe()``.
            agent_args: Keyword arguments for ``agent_class``.
            rollout_engine_args: Extra keyword arguments for ``OpenAIEngine``.
            env_args: Entries merged over each task dict before
                ``env_class.from_dict``.
            max_workers: Size of the environment thread pool.
            enforce_max_prompt_length: Check the prompt length at every step
                instead of tracking one cumulative response budget.
            overlong_filter: Zero the response mask of trajectories ending in
                ``TRUNCATION``, ``MAX_STEPS`` or ``TIMEOUT``.
            **kwargs: ``disable_thinking`` and ``sampling_params`` are read.
        """
        rollout_engine_args = {} if rollout_engine_args is None else rollout_engine_args
        disable_thinking = kwargs.get("disable_thinking", False)

        self.config = config
        self.tokenizer = tokenizer
        self.engine_name = engine_name
        self.n_parallel_agents = n_parallel_agents
        self.overlong_filter = overlong_filter

        # For interaction
        self.gamma = gamma
        self.retry_limit = retry_limit
        self.max_steps = max_steps
        self.max_response_length = max_response_length
        self.max_prompt_length = max_prompt_length
        self.enforce_max_prompt_length = enforce_max_prompt_length

        self.agent_class = agent_class
        self.agent_args = {} if agent_args is None else agent_args
        self.env_class = env_class
        self.env_args = {} if env_args is None else env_args
        self.agents = [None] * n_parallel_agents
        self.envs = [None] * n_parallel_agents

        # A missing/zero timeout means "effectively unlimited".
        self.trajectory_timeout = trajectory_timeout or int(1e9)

        if env_class is not None:
            assert env_class.is_multithread_safe(), "Environment must be multithread safe for async engine"

        if chat_parser is None:
            self.chat_parser = ChatTemplateParser.get_parser(self.tokenizer, disable_thinking=disable_thinking)
        else:
            self.chat_parser = chat_parser

        self.rollout_engine_args = rollout_engine_args
        self.sampling_params = kwargs.get("sampling_params", {})  # for openai api requests

        assert self.engine_name in ["openai", "verl"], "Currently only openai and verl are supported as rollout engine"
        if self.engine_name == "openai":
            from rllm.engine.rollout.openai_engine import OpenAIEngine

            self.rollout_engine = OpenAIEngine(
                **rollout_engine_args,
                api_retries=api_retries,
                tokenizer=self.tokenizer,
                max_prompt_length=self.max_prompt_length,
                max_response_length=self.max_response_length,
                disable_thinking=disable_thinking,
            )
        elif self.engine_name == "verl":
            from rllm.engine.rollout.verl_engine import VerlEngine

            self.rollout_engine = VerlEngine(
                config=self.config,
                rollout_manager=rollout_engine,
                tokenizer=self.tokenizer,
                disable_thinking=self.config.rllm.disable_thinking,
            )

        # Thread pool for environment interactions (i.e. step, reset, close).
        # The size is remembered so the pool can be rebuilt after a shutdown.
        self._max_workers = max_workers
        self.executor = ThreadPoolExecutor(max_workers=max_workers)

    def _ensure_executor(self):
        """Return a usable thread pool, creating one if none is attached."""
        if getattr(self, "executor", None) is None:
            self.executor = ThreadPoolExecutor(max_workers=getattr(self, "_max_workers", 64))
        return self.executor

    async def get_model_response(self, prompt, application_id, **kwargs) -> str:
        """Query the rollout engine and return the generated text.

        ``self.sampling_params`` is copied and overridden by ``kwargs`` before
        being forwarded. For the ``"verl"`` engine the ``meta_info`` entry is
        removed from the sampling parameters and its ``validate`` flag is
        forwarded instead.

        Args:
            prompt: The input prompt to send to the model.
            application_id: Unique identifier for the application.
            **kwargs: Per-call overrides of the sampling parameters.

        Returns:
            The ``text`` attribute of the engine's output.

        Raises:
            NotImplementedError: If the engine type is not supported.
        """
        sampling_params = self.sampling_params.copy()
        sampling_params.update(kwargs)

        if self.engine_name == "openai":
            output = await self.rollout_engine.get_model_response(prompt, application_id=application_id, enforce_max_prompt_length=False, **sampling_params)
        elif self.engine_name == "verl":
            # Tolerate an explicit ``meta_info=None`` as "no meta info".
            meta_data = sampling_params.pop("meta_info", None) or {}
            validate = meta_data.get("validate", False)
            output = await self.rollout_engine.get_model_response(prompt, application_id=application_id, validate=validate, enforce_max_prompt_length=False, **sampling_params)
        else:
            raise NotImplementedError(f"Engine type '{self.engine_name}' not supported")
        return output.text

    def update_envs_and_agents(self, envs, agents):
        """Install a new batch of environments and agents.

        Each environment is tagged with ``env.idx`` (its position in the
        batch) and ``n_parallel_agents`` is set to the batch size.

        Args:
            envs: List of environments to use.
            agents: List of agents to use; must have the same length as ``envs``.
        """
        assert len(agents) == len(envs), f"Number of agents must equal to number of environments but received, {len(agents)} and {len(envs)}"
        self.envs = envs
        # For keeping track of the environment index in the batch.
        for position, env in enumerate(envs):
            env.idx = position
        self.agents = agents
        self.n_parallel_agents = len(envs)

    async def run_agent_trajectory_async(self, idx, application_id, seed=0, mode="Text", **kwargs):
        """Run the trajectory of the agent/environment pair in slot ``idx``.

        Args:
            idx: Slot index into ``self.agents`` / ``self.envs``.
            application_id: Identifier forwarded to the rollout engine.
            seed: Accepted for interface compatibility; not used here.
            mode: Shape of the result: ``"Text"`` (the agent's trajectory),
                ``"Token"`` (dict of tensors, reward and timing metrics),
                ``"Conversation"`` (the agent's chat completions) or
                ``"Step"`` (dict of prompt/response pairs and MC returns).
            **kwargs: Sampling overrides forwarded to ``get_model_response``;
                ``max_tokens`` is overwritten at every step.

        Returns:
            The result for ``mode``; ``None`` for an unrecognised mode.

        Raises:
            Exception: If the initial prompt already exceeds
                ``max_prompt_length``.
        """
        agent = self.agents[idx]
        env = self.envs[idx]

        termination_reason = None
        response_token_len = 0
        response_tokens = []
        response_masks = []
        total_time = 0.0
        reward_time = None
        llm_time = 0.0
        env_time = 0.0
        reward = 0.0
        # Prompt/response pairs, returned in "Step" mode.
        episode_steps = []

        # Reset environment with the task using the executor.
        loop = asyncio.get_running_loop()
        observation, info = await loop.run_in_executor(self.executor, env.reset)
        info["max_steps"] = self.max_steps

        # Reset the agent, then seed its state with the first observation.
        agent.reset()
        agent.update_from_env(
            observation=observation,  # Raw observation from environment
            reward=0.0,
            done=False,
            info=info,
        )

        messages = agent.chat_completions
        prompt_tokens, _ = convert_messages_to_tokens_and_masks(messages, tokenizer=self.tokenizer, parser=self.chat_parser, contains_first_msg=True, contains_generation_msg=True)
        prompt_token_len = len(prompt_tokens)
        # Note, this should never happen!
        if prompt_token_len > self.max_prompt_length:
            agent.reset()
            raise Exception(f"Trajectory {idx}: initial prompt length {prompt_token_len} already exceeded max_prompt_length {self.max_prompt_length}, retrying")

        for step_idx in range(self.max_steps):
            # Snapshot the conversation that is sent to the model.
            prompt_messages = agent.chat_completions.copy()
            prompt_str = None

            if not self.enforce_max_prompt_length:
                # One cumulative budget: only the remaining tokens may be generated.
                max_tokens = self.max_response_length - response_token_len
            else:
                # Per-step budget; overlong prompts end the trajectory instead.
                max_tokens = self.max_response_length
                prompt_str = self.chat_parser.parse(prompt_messages, add_generation_prompt=True, is_first_msg=True)
                prompt_len = len(self.tokenizer.encode(prompt_str, add_special_tokens=False))
                if prompt_len > self.max_prompt_length:
                    termination_reason = "PROMPT_TRUNCATION"
                    break

            kwargs["max_tokens"] = max_tokens

            start_time = time.perf_counter()
            response = await self.get_model_response(prompt_messages, application_id, **kwargs)
            delta_time = time.perf_counter() - start_time
            llm_time += delta_time
            total_time += delta_time

            # Record the step; reuse the rendered prompt when it already exists.
            if prompt_str is None:
                prompt_str = self.chat_parser.parse(prompt_messages, add_generation_prompt=True, is_first_msg=True)
            episode_steps.append({"prompt": prompt_str, "response": response})

            # Update agent with model response and extract the raw action.
            action = agent.update_from_model(response).action

            # Take step in environment using the executor, bounded by the
            # time remaining in the trajectory budget.
            start_time = time.perf_counter()
            try:
                next_observation, reward, done, info = await asyncio.wait_for(loop.run_in_executor(self.executor, env.step, action), timeout=(self.trajectory_timeout - total_time))
            except asyncio.TimeoutError:
                termination_reason = "ENV_TIMEOUT"
                if step_idx == 0:
                    colorful_print(f"Warning: Trajectory {idx} completed due to: {termination_reason} before able to perform 1 complete action. This might cause unexpected behavior. Consider increasing trajectory timeout limit.\n", "red")
                reward = 0
                cur_step = agent.get_current_state()
                done = True
                cur_step.done = done
                break
            delta_time = time.perf_counter() - start_time
            env_time += delta_time
            total_time += delta_time

            info["max_steps"] = self.max_steps
            info["cur_tokens"] = response_token_len

            # Update agent internal state.
            agent.update_from_env(
                observation=next_observation,
                reward=reward,
                done=done,
                info=info,
            )

            cur_step = agent.get_current_state()
            cur_step.reward = reward
            cur_step.done = done
            cur_step.info.update(info)

            chat_completions_messages = agent.chat_completions
            assistant_message, env_messages = get_recent_assistant_user_messages(chat_completions_messages)

            # Check and convert to tokens if necessary
            assert assistant_message is not None or mode != "Token", "Assistant messages is none when accumulating token trajectories which should be conversations. This should not happen."
            assert env_messages is not None or mode != "Token", "Environment messages is none when accumulating token trajectories which should be conversations. This should not happen."
            assistant_msg_tokens, assistant_msg_masks = [], []
            env_msg_tokens, env_msg_masks = [], []
            if assistant_message:
                assistant_msg_tokens, assistant_msg_masks = convert_messages_to_tokens_and_masks([assistant_message], tokenizer=self.tokenizer, parser=self.chat_parser, contains_first_msg=False, contains_generation_msg=False)
            if env_messages:
                env_msg_tokens, env_msg_masks = convert_messages_to_tokens_and_masks(env_messages, tokenizer=self.tokenizer, parser=self.chat_parser, contains_first_msg=False, contains_generation_msg=True)

            # Update response token length
            response_token_len += len(assistant_msg_tokens) + len(env_msg_tokens)

            # Reached maximum number of tokens for the trajectory
            if not self.enforce_max_prompt_length and response_token_len >= self.max_response_length:
                # Zero when the budget is hit exactly, negative when exceeded.
                truncation_length = self.max_response_length - response_token_len
                step_tokens = assistant_msg_tokens + env_msg_tokens
                step_masks = assistant_msg_masks + env_msg_masks
                if truncation_length < 0:
                    # Drop the overflow from the tail of this step.
                    step_tokens = step_tokens[:truncation_length]
                    step_masks = step_masks[:truncation_length]
                response_tokens.extend(step_tokens)
                response_masks.extend(step_masks)

                cur_step = agent.get_current_state()
                # Only zero the reward when the assistant tokens alone overflow.
                if response_token_len - len(env_msg_tokens) > self.max_response_length:
                    cur_step.reward = 0.0
                cur_step.done = True
                termination_reason = "TRUNCATION"
                break

            # Update the token version of trajectory
            response_tokens.extend(assistant_msg_tokens)
            response_masks.extend(assistant_msg_masks)

            if total_time >= self.trajectory_timeout:
                termination_reason = "TIMEOUT"
                cur_step = agent.get_current_state()
                done = True
                cur_step.done = done
                break

            # Check if episode is done
            if done:
                termination_reason = "ENV_DONE"
                break

            # The environment's reply only joins the response when another step follows.
            response_tokens.extend(env_msg_tokens)
            response_masks.extend(env_msg_masks)

            if step_idx == self.max_steps - 1:
                termination_reason = "MAX_STEPS"

        masked_out = False
        if self.overlong_filter and termination_reason in ("TRUNCATION", "MAX_STEPS", "TIMEOUT"):
            # Mask out the entire response of overlong trajectories.
            response_masks = [0] * len(response_masks)
            masked_out = True

        if hasattr(env, "compute_final_reward") and not masked_out:
            cur_step = agent.get_current_state()
            start_time = time.perf_counter()
            reward = await loop.run_in_executor(self.executor, env.compute_final_reward)
            reward_time = time.perf_counter() - start_time
            cur_step.reward = reward

        # Closing environment using the executor.
        await loop.run_in_executor(self.executor, env.close)

        if termination_reason:
            color = "green" if reward > 0 else "yellow"
            colorful_print(
                f"Trajectory {idx} completed due to: {termination_reason}. Reward is {reward}. \n",
                color,
            )
            if masked_out:
                colorful_print(f"Trajectory {idx} is masked out due to overlong filter.", "red")

        trajectory: Trajectory = agent.trajectory
        # Aggregate final trajectory statistics
        compute_trajectory_reward(trajectory)
        compute_mc_return(trajectory, gamma=self.gamma)

        if mode == "Text":
            return trajectory
        elif mode == "Token":
            token_result = {
                "prompt_tokens": torch.tensor(prompt_tokens, dtype=torch.long),
                "response_tokens": torch.tensor(response_tokens, dtype=torch.long),
                "response_masks": torch.tensor(response_masks, dtype=torch.long),
                "trajectory_reward": trajectory.reward,
                "idx": env.idx,
                "chat_completions": agent.chat_completions,
                "metrics": {
                    # Total number of steps taken in the trajectory
                    "steps": len(trajectory.steps),
                    # Time to calculate reward
                    "reward_time": reward_time,
                    # Total time spent in environment execution (env.step)
                    "env_time": env_time,
                    # Time to calculate response tokens
                    "llm_time": llm_time,
                    # Total time spent in the trajectory
                    "total_time": total_time,
                },
            }
            return token_result
        elif mode == "Conversation":
            return agent.chat_completions
        elif mode == "Step":
            steps_result = {
                "steps": episode_steps,
                "trajectory_reward": trajectory.reward,
                "idx": env.idx,
                "mc_returns": [step.mc_return for step in trajectory.steps][: len(episode_steps)],
            }
            return steps_result
        # Unrecognised modes yield no result.
        return None

    async def run_agent_trajectory_with_retry(self, idx, application_id, seed=0, mode="Text", **kwargs):
        """Run one trajectory, retrying up to ``self.retry_limit`` times.

        Each attempt is capped at ``_ATTEMPT_TIMEOUT_S`` seconds. The
        traceback of every failed attempt is printed.

        Returns:
            The result of the first successful attempt.

        Raises:
            Exception: When no attempt succeeds; the last failure (if any) is
                attached as ``__cause__``.
        """
        last_error = None
        for _ in range(self.retry_limit):
            try:
                return await asyncio.wait_for(self.run_agent_trajectory_async(idx, application_id=application_id, seed=seed, mode=mode, **kwargs), timeout=self._ATTEMPT_TIMEOUT_S)
            except Exception as exc:
                traceback.print_exc()
                last_error = exc
        raise Exception(f"Trajectory {idx} cannot complete. Please check the log message") from last_error

    async def trajectory_generator(self, reset_seed=0, timing_raw=None, mode="Text", **kwargs):
        """Yield the result of every slot's trajectory as each one completes.

        All slots in ``self.envs`` are launched concurrently on a fresh thread
        pool sized to ``n_parallel_agents``. When the generator finishes,
        fails or is closed early, unfinished trajectories are cancelled, the
        ``"verl"`` engine is put back to sleep and the pool is released.

        Args:
            reset_seed: Forwarded as ``seed`` to every trajectory.
            timing_raw: Accepted for interface compatibility; not used here.
            mode: Result mode, see ``run_agent_trajectory_async``.
            **kwargs: Forwarded to every trajectory.

        Yields:
            Trajectory results in completion order.
        """
        assert all(env is not None and isinstance(env, BaseEnv) for env in self.envs), "All environments must be inheriting from BaseEnv"
        assert all(env.is_multithread_safe() for env in self.envs), "All environments must be multithread safe for async engine"  # type: ignore

        # Release the previous pool before replacing it so its threads do not leak.
        stale_executor = getattr(self, "executor", None)
        if stale_executor is not None:
            stale_executor.shutdown(wait=False)
        self.executor = ThreadPoolExecutor(max_workers=self.n_parallel_agents)

        if self.engine_name == "verl":
            self.rollout_engine.wake_up()

        async def launch_one_trajectory_task(env_idx: int):
            try:
                return await self.run_agent_trajectory_with_retry(
                    idx=env_idx,
                    application_id=str(uuid.uuid4()),
                    seed=reset_seed,
                    mode=mode,
                    **kwargs,
                )
            except Exception:
                traceback.print_exc()
                raise

        total = len(self.envs)
        pending = [asyncio.ensure_future(launch_one_trajectory_task(i)) for i in range(total)]
        try:
            for tasks_completed, next_done in enumerate(asyncio.as_completed(pending), start=1):
                result = await next_done
                colorful_print(f"Number of Trajectories {tasks_completed}/{total} completed", "cyan")
                yield result
        finally:
            # No-op on the normal path; on failure or early close it stops orphaned work.
            for task in pending:
                if not task.done():
                    task.cancel()
            if self.engine_name == "verl":
                self.rollout_engine.sleep()
            # Detach the pool so later calls never see a shut-down executor.
            finished_executor, self.executor = self.executor, None
            if finished_executor is not None:
                finished_executor.shutdown(wait=False, cancel_futures=True)

    async def execute_tasks(self, tasks: list[dict]):
        """Run one trajectory per task, at most ``n_parallel_agents`` at a time.

        For every task a fresh environment (``env_class.from_dict`` on the
        task merged with ``env_args``) and a fresh agent (``agent_class`` with
        ``agent_args``) are placed in a free slot, the trajectory is run in
        the default ``"Text"`` mode, and the task is attached to the result
        as ``res.task``.

        Args:
            tasks: List of task dicts to process.

        Returns:
            A list of trajectories, one for each task, in the order of ``tasks``.
        """
        # The pool may have been released by ``shutdown`` or ``trajectory_generator``.
        self._ensure_executor()

        max_concurrent = self.n_parallel_agents
        # Free slot indices; the bounded queue is what limits concurrency.
        index_queue: asyncio.Queue[int] = asyncio.Queue(maxsize=max_concurrent)
        for slot in range(max_concurrent):
            index_queue.put_nowait(slot)

        completed = 0
        total = len(tasks)

        async def run_in_free_slot(task_id, task):
            nonlocal completed
            # Wait for an available slot.
            index = await index_queue.get()
            try:
                self.envs[index] = self.env_class.from_dict({**task, **self.env_args})
                self.agents[index] = self.agent_class(**self.agent_args)
                assert self.agents[index] is not None and isinstance(self.agents[index], BaseAgent), "Agent is not initalized or not inheriting from BaseAgent"
                self.agents[index].trajectory.task = task  # type: ignore
                res = await self.run_agent_trajectory_async(index, application_id=task_id)
                res.task = task
                completed += 1
                colorful_print(f"Progress: {completed}/{total} trajectories completed", "cyan")
                return res
            finally:
                # Hand the slot back; this never blocks because the slot was taken from this queue.
                index_queue.put_nowait(index)

        # ``gather`` returns results in submission order, i.e. the order of ``tasks``.
        results = await asyncio.gather(*[run_in_free_slot(task_id, task) for task_id, task in enumerate(tasks)])
        return list(results)

    def shutdown(self):
        """Shut down the environment thread pool, if any. Safe to call repeatedly."""
        executor = getattr(self, "executor", None)
        if executor is not None:
            executor.shutdown()
            self.executor = None