Skip to content

Workflow

The Workflow class provides the core interface and functionality that all workflows inherit from.

rllm.workflows.workflow

Workflow

Bases: ABC

Source code in rllm/workflows/workflow.py
class Workflow(ABC):
    """Abstract base class for workflows.

    Subclasses implement :meth:`run`. :meth:`run_with_termination_handling` wraps
    ``run`` with a timeout plus termination/error handling and turns the produced
    trajectories into a post-processed :class:`Episode`.

    Agents (``BaseAgent``) and environments (``BaseEnv``) stored as public
    attributes (names not starting with ``_``) are discovered through ``dir(self)``
    by :meth:`reset` and :meth:`collect_trajectories`; :meth:`is_multithread_safe`
    inspects every attribute, private ones included.
    """

    def __init__(self, rollout_engine: RolloutEngine, executor: ThreadPoolExecutor, timeout=1e6, gamma=0.0, reward_bonus_coeff=0.0, **kwargs):
        """Initialize the Workflow.

        Args:
            rollout_engine: The rollout engine to use.
            executor: The executor used by :meth:`run_in_executor`.
            timeout: Timeout in seconds for one :meth:`run` call, passed to
                ``asyncio.wait_for``. Fractional values are preserved; ``None``
                disables the timeout.
            gamma: Discount factor used by :meth:`adjust_step_rewards` to turn step
                rewards into Monte Carlo returns. ``0.0`` (default) leaves rewards unchanged.
            reward_bonus_coeff: Reward-shaping coefficient used by
                :meth:`adjust_step_rewards`. ``0.0`` (default) disables shaping.
            **kwargs: Additional keyword arguments (accepted and ignored by the base class).
        """
        self.rollout_engine = rollout_engine
        self.executor = executor
        # Keep sub-second precision: int() would truncate e.g. 0.5 to 0, i.e. an immediate timeout.
        self.timeout = None if timeout is None else float(timeout)
        self.gamma = gamma
        self.reward_bonus_coeff = reward_bonus_coeff

        # Normally set by reset(); initialized here so postprocess_episode() works even without it.
        self.uid: str | None = None
        self.task: dict | None = None

        self._completed_trajectories: list[Trajectory] = []

    @abstractmethod
    async def run(self, task: dict, uid: str, **kwargs) -> Episode | None:
        """Execute the workflow on a single task.

        Args:
            task: The task to execute.
            uid: The unique identifier for the task.
            **kwargs: Additional keyword arguments.

        Returns:
            Episode | None: An already post-processed episode, which
            :meth:`run_with_termination_handling` returns as-is. Any other value
            (e.g. ``None``) makes the wrapper collect and post-process the
            workflow's trajectories itself.
        """
        pass

    async def run_with_termination_handling(self, task: dict, uid: str, **kwargs) -> Episode:
        """Wrap :meth:`run` with timeout, termination-event and error handling.

        Args:
            task: The task to execute.
            uid: The unique identifier for the task.
            **kwargs: Additional keyword arguments forwarded to :meth:`run`.

        Returns:
            Episode: The episode returned by ``run`` if it returned one. Otherwise the
            collected trajectories are post-processed with one of these termination reasons:
            ``UNKNOWN`` if ``run`` returned normally; ``TIMEOUT`` if ``self.timeout`` elapsed;
            the reason carried by a raised ``TerminationEvent``; or ``ERROR`` for any other
            exception, with its details stored under ``episode.info["error"]``.
        """
        try:
            coro = self.run(task, uid, **kwargs)
            output = await asyncio.wait_for(coro, timeout=self.timeout)
            if isinstance(output, Episode):
                return output  # we assume it's already postprocessed
            return self.postprocess_episode(self.collect_trajectories(), TerminationReason.UNKNOWN)
        except asyncio.TimeoutError:
            return self.postprocess_episode(self.collect_trajectories(), TerminationReason.TIMEOUT)
        except TerminationEvent as e:
            return self.postprocess_episode(self.collect_trajectories(), e.reason)
        except Exception as e:
            import traceback

            error_details = {"error_message": str(e), "error_type": type(e).__name__, "traceback": traceback.format_exc()}
            return self.postprocess_episode(self.collect_trajectories(), TerminationReason.ERROR, error=error_details)

    def commit(self, name: str | None = None, agent: BaseAgent | None = None, trajectory: Trajectory | None = None, reset: bool = False) -> None:
        """Commit a deep copy of a trajectory for training.

        Exactly one of ``agent`` or ``trajectory`` must be given. Trajectories without
        steps are not committed.

        Args:
            name: If truthy, assigned to the trajectory's ``name``. This applies to the
                given or agent-owned trajectory itself, not just the committed copy.
            agent: The agent whose current ``trajectory`` is committed.
            trajectory: The trajectory to commit.
            reset: Whether to call ``agent.reset()`` afterwards (only used with ``agent``).

        Raises:
            ValueError: If neither or both of ``agent`` and ``trajectory`` are given.
        """
        # Explicit checks instead of `assert`, which is stripped when Python runs with -O.
        if agent is None and trajectory is None:
            raise ValueError("Either agent or trajectory must be provided to workflow.commit")
        if agent is not None and trajectory is not None:
            raise ValueError("Only one of agent or trajectory can be provided to workflow.commit")

        traj = agent.trajectory if agent is not None else trajectory
        if name:
            traj.name = name
        if traj.steps:
            self._completed_trajectories.append(deepcopy(traj))

        if agent is not None and reset:
            agent.reset()

    def collect_trajectories(self) -> Episode:
        """Collect the workflow's trajectories into a new episode.

        The episode contains deep copies of all committed trajectories, plus deep copies
        of the current trajectory of every public ``BaseAgent`` attribute. An agent
        trajectory is skipped if it has no steps or its ``uid`` matches a committed one.

        Returns:
            Episode: A new episode holding the collected trajectories.
        """
        episode = Episode()

        # Copy the committed trajectories: postprocess_episode() mutates trajectories in place
        # (pops steps, rewrites rewards), and sharing them would make a second collect/postprocess
        # apply those mutations twice to the workflow's own state.
        episode.trajectories.extend(deepcopy(self._completed_trajectories))

        # Track completed trajectory uids
        completed_trajectory_uids = {trajectory.uid for trajectory in self._completed_trajectories}

        # Add trajectories from agents that aren't already in completed trajectories
        for attr_name in dir(self):
            if attr_name.startswith("_"):
                continue
            attr_value = getattr(self, attr_name)
            if isinstance(attr_value, BaseAgent) and hasattr(attr_value, "trajectory") and getattr(attr_value.trajectory, "uid", None) not in completed_trajectory_uids and len(attr_value.trajectory.steps) > 0:
                episode.trajectories.append(deepcopy(attr_value.trajectory))

        return episode

    def compute_trajectory_reward(self, trajectory: Trajectory) -> None:
        """Compute the trajectory-level reward.

        Default: ``trajectory.reward`` is set to the sum of the step rewards
        (0.0 when there are no steps).

        Args:
            trajectory: The trajectory to compute the reward for (modified in place).
        """
        trajectory.reward = np.sum([d.reward for d in trajectory.steps])

    def adjust_step_rewards(self, trajectory: Trajectory) -> None:
        """Adjust the step-level rewards in place via optional shaping and discounting.

        Since ``self.reward_bonus_coeff`` and ``self.gamma`` default to 0.0, no
        adjustments are made by default.

        Args:
            trajectory: The trajectory to adjust the rewards for.
        """
        # reward shaping, computed from the raw (pre-shaping) rewards:
        # s[i].reward = s[i].reward + bonus * (s[i].reward - s[i-1].reward) for i > 0
        if self.reward_bonus_coeff > 0.0:
            raw_rewards = [step.reward for step in trajectory.steps]
            for i in range(1, len(trajectory.steps)):
                trajectory.steps[i].reward += self.reward_bonus_coeff * (raw_rewards[i] - raw_rewards[i - 1])

        # Replace each (possibly shaped) step reward with its Monte Carlo return, iterating backwards:
        # G_t = r_t + γ * r_{t+1} + γ² * r_{t+2} + ... + γ^{T-t} * r_T
        if self.gamma > 0.0:
            G = 0.0
            for step in reversed(trajectory.steps):
                G = step.reward + self.gamma * G
                step.reward = G  # Replace the reward with MC return

    def assign_episode_correctness(self, episode: Episode) -> None:
        """Assign an episode-level correctness flag.

        Default: ``episode.is_correct`` is True if the sum of the trajectory rewards
        is strictly positive.

        Args:
            episode: The episode to assign the correctness flag to.
        """
        total_reward = 0
        for trajectory in episode.trajectories:
            total_reward += trajectory.reward
        episode.is_correct = total_reward > 0

    def collect_metrics(self, episode: Episode) -> None:
        """Collect metrics from the episode.

        Default: for each trajectory name, ``episode.metrics["<name>_acc"]`` is set to the
        mean reward of the trajectories with that name. Any previous metrics are replaced.

        Args:
            episode: The episode to collect metrics from.
        """
        metrics = defaultdict(list)
        for traj in episode.trajectories:
            name = traj.name
            metrics[name].append(traj.reward)
        episode.metrics = {f"{k}_acc": float(np.mean(v)) for k, v in metrics.items()}

    def postprocess_episode(self, episode: Episode, termination_reason: TerminationReason | None = None, error: dict | None = None) -> Episode:
        """Post-process a collected episode in place.

        Args:
            episode: The episode to postprocess; it and its trajectories are modified in place.
            termination_reason: The termination reason for the episode; a falsy value is
                recorded as ``TerminationReason.UNKNOWN``.
            error: Optional error details, stored under ``episode.info["error"]``.

        Returns:
            Episode: The same ``episode`` object, post-processed.
        """
        # 1. assign a task id and task
        episode.id = getattr(self, "uid", None)
        episode.task = getattr(self, "task", None)

        for trajectory in episode.trajectories:
            # depending on the termination reason, a trajectory may end with an extra step with empty
            # chat_completions, i.e., if termination happened between agent.update_from_env() and
            # agent.update_from_model(); drop it
            if trajectory.steps and not trajectory.steps[-1].chat_completions:
                trajectory.steps.pop()

            # 2. compute trajectory-level rewards (from the unadjusted step rewards)
            self.compute_trajectory_reward(trajectory)

            # 3. adjust the step level rewards (e.g., reward shaping or discounting)
            if len(trajectory.steps) > 1:
                self.adjust_step_rewards(trajectory)

        # 4. assign an episode-level correctness flag
        self.assign_episode_correctness(episode)

        # 5. collect additional workflow metrics
        # by default, we report the acc of each agent using the traj reward
        self.collect_metrics(episode)

        # 6. store error details if provided
        if error is not None:
            episode.info["error"] = error

        # 7. assign a termination reason
        episode.termination_reason = termination_reason or TerminationReason.UNKNOWN

        return episode

    def reset(self, task: dict | None = None, uid: str | None = None) -> None:
        """Reset the workflow for a new task.

        Stores ``task``/``uid`` and clears committed trajectories. It then resets every
        public ``BaseAgent`` attribute, setting its trajectory's task, and afterwards every
        public ``BaseEnv`` attribute with ``reset(task=task)``.

        Args:
            task: The task to reset the workflow to.
            uid: The unique identifier for the task.
        """
        # set the uid and task
        self.uid = uid
        self.task = task
        self._completed_trajectories = []

        # reset agents (public attributes that are BaseAgent instances)
        for attr_name in dir(self):
            if attr_name.startswith("_"):
                continue
            attr_value = getattr(self, attr_name)
            if isinstance(attr_value, BaseAgent) and hasattr(attr_value, "reset"):
                attr_value.reset()
                attr_value.trajectory.task = task

        # reset environments (public attributes that are BaseEnv instances)
        for attr_name in dir(self):
            if attr_name.startswith("_"):
                continue
            attr_value = getattr(self, attr_name)
            if isinstance(attr_value, BaseEnv) and hasattr(attr_value, "reset"):
                attr_value.reset(task=task)

    def is_multithread_safe(self) -> bool:
        """Check if the workflow is multithread safe.

        Every attribute (private ones included) that is a ``BaseEnv`` must report
        itself as multithread safe.

        Returns:
            bool: True if the workflow is multithread safe, False otherwise.
        """
        for attr_name in dir(self):
            attr_value = getattr(self, attr_name)
            if isinstance(attr_value, BaseEnv) and not attr_value.is_multithread_safe():
                return False
        return True

    async def run_in_executor(self, fn, *args, **kwargs):
        """Run a function in the workflow's thread pool executor and await its result.

        Args:
            fn: The function to run.
            *args: The arguments to pass to the function.
            **kwargs: The keyword arguments to pass to the function.

        Returns:
            The return value of ``fn(*args, **kwargs)``.
        """
        # Inside a coroutine, get_running_loop() is the recommended way to obtain the loop.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self.executor, partial(fn, *args, **kwargs))

__init__

__init__(rollout_engine: RolloutEngine, executor: ThreadPoolExecutor, timeout=1000000.0, gamma=0.0, reward_bonus_coeff=0.0, **kwargs)

Initialize the Workflow.

Parameters:

Name Type Description Default
rollout_engine RolloutEngine

The rollout engine to use.

required
executor ThreadPoolExecutor

The executor to use.

required
timeout

The timeout for the workflow.

1000000.0
gamma

The discount factor for the workflow.

0.0
reward_bonus_coeff

The reward bonus coefficient for the workflow.

0.0
**kwargs

Additional keyword arguments.

{}
Source code in rllm/workflows/workflow.py
def __init__(self, rollout_engine: RolloutEngine, executor: ThreadPoolExecutor, timeout=1e6, gamma=0.0, reward_bonus_coeff=0.0, **kwargs):
    """Initialize the Workflow.

    Args:
        rollout_engine: The rollout engine to use.
        executor: The executor used by ``run_in_executor``.
        timeout: Timeout in seconds for one ``run`` call, passed to
            ``asyncio.wait_for``. Fractional values are preserved; ``None``
            disables the timeout.
        gamma: Discount factor used by ``adjust_step_rewards``. ``0.0`` (default)
            leaves step rewards unchanged.
        reward_bonus_coeff: Reward-shaping coefficient used by
            ``adjust_step_rewards``. ``0.0`` (default) disables shaping.
        **kwargs: Additional keyword arguments (accepted and ignored by the base class).
    """
    self.rollout_engine = rollout_engine
    self.executor = executor
    # Keep sub-second precision: int() would truncate e.g. 0.5 to 0, i.e. an immediate timeout.
    self.timeout = None if timeout is None else float(timeout)
    self.gamma = gamma
    self.reward_bonus_coeff = reward_bonus_coeff

    # Normally set by reset(); initialized here so postprocess_episode() works even without it.
    self.uid: str | None = None
    self.task: dict | None = None

    self._completed_trajectories: list[Trajectory] = []

run abstractmethod async

run(task: dict, uid: str, **kwargs) -> Episode | None

Execute the workflow on a single task

Parameters:

Name Type Description Default
task dict

The task to execute.

required
uid str

The unique identifier for the task.

required
**kwargs

Additional keyword arguments.

{}

Returns:

Name Type Description
Episode Episode | None

The episode generated by the workflow.

Source code in rllm/workflows/workflow.py
@abstractmethod
async def run(self, task: dict, uid: str, **kwargs) -> Episode | None:
    """Run the workflow on one task; concrete workflows must override this.

    Args:
        task: The task to execute.
        uid: The unique identifier for the task.
        **kwargs: Additional keyword arguments.

    Returns:
        Episode | None: An already post-processed episode, which
        ``run_with_termination_handling`` returns as-is. Any other value
        (e.g. ``None``) makes that wrapper collect and post-process the
        workflow's trajectories itself.
    """
    ...

run_with_termination_handling async

run_with_termination_handling(task: dict, uid: str, **kwargs) -> Episode

Wrapper method around workflow.run that handles termination events, errors, timeouts, and post-processing.

Parameters:

Name Type Description Default
task dict

The task to execute.

required
uid str

The unique identifier for the task.

required
**kwargs

Additional keyword arguments.

{}
Source code in rllm/workflows/workflow.py
async def run_with_termination_handling(self, task: dict, uid: str, **kwargs) -> Episode:
    """Execute ``run`` under ``self.timeout`` and always hand back an Episode.

    Args:
        task: The task to execute.
        uid: The unique identifier for the task.
        **kwargs: Additional keyword arguments forwarded to ``run``.

    Returns:
        Episode: The episode ``run`` returned, if any. Otherwise the collected trajectories
        are post-processed with one of these termination reasons: ``UNKNOWN`` on a normal
        return; ``TIMEOUT``; the reason carried by a ``TerminationEvent``; or ``ERROR`` plus
        error details for any other exception.
    """
    try:
        result = await asyncio.wait_for(self.run(task, uid, **kwargs), timeout=self.timeout)
        if isinstance(result, Episode):
            # run() built its own episode; it is assumed to be post-processed already.
            return result
        return self.postprocess_episode(self.collect_trajectories(), TerminationReason.UNKNOWN)
    except asyncio.TimeoutError:
        # self.timeout elapsed before run() finished.
        return self.postprocess_episode(self.collect_trajectories(), TerminationReason.TIMEOUT)
    except TerminationEvent as event:
        # run() signalled an explicit termination; keep the reason it carries.
        return self.postprocess_episode(self.collect_trajectories(), event.reason)
    except Exception as exc:
        import traceback

        details = {"error_message": str(exc), "error_type": type(exc).__name__, "traceback": traceback.format_exc()}
        return self.postprocess_episode(self.collect_trajectories(), TerminationReason.ERROR, error=details)

commit

commit(name: str | None = None, agent: BaseAgent | None = None, trajectory: Trajectory | None = None, reset: bool = False) -> None

Commit a trajectory for training.

Parameters:

Name Type Description Default
name str | None

The name of the trajectory.

None
agent BaseAgent | None

The agent that generated the trajectory.

None
trajectory Trajectory | None

The trajectory to commit.

None
reset bool

Whether to reset the agent.

False
Source code in rllm/workflows/workflow.py
def commit(self, name: str | None = None, agent: BaseAgent | None = None, trajectory: Trajectory | None = None, reset: bool = False) -> None:
    """Commit a deep copy of a trajectory for training.

    Exactly one of ``agent`` or ``trajectory`` must be given. Trajectories without
    steps are not committed.

    Args:
        name: If truthy, assigned to the trajectory's ``name``. This applies to the
            given or agent-owned trajectory itself, not just the committed copy.
        agent: The agent whose current ``trajectory`` is committed.
        trajectory: The trajectory to commit.
        reset: Whether to call ``agent.reset()`` afterwards (only used with ``agent``).

    Raises:
        ValueError: If neither or both of ``agent`` and ``trajectory`` are given.
    """
    # Explicit checks instead of `assert`, which is stripped when Python runs with -O.
    if agent is None and trajectory is None:
        raise ValueError("Either agent or trajectory must be provided to workflow.commit")
    if agent is not None and trajectory is not None:
        raise ValueError("Only one of agent or trajectory can be provided to workflow.commit")

    traj = agent.trajectory if agent is not None else trajectory
    if name:
        traj.name = name
    if traj.steps:
        self._completed_trajectories.append(deepcopy(traj))

    if agent is not None and reset:
        agent.reset()

collect_trajectories

collect_trajectories() -> Episode

Collect the trajectories from the workflow

Returns:

Name Type Description
Episode Episode

The episode generated by the workflow.

Source code in rllm/workflows/workflow.py
def collect_trajectories(self) -> Episode:
    """Collect the workflow's trajectories into a new episode.

    The episode contains deep copies of all committed trajectories, plus deep copies
    of the current trajectory of every public ``BaseAgent`` attribute. An agent
    trajectory is skipped if it has no steps or its ``uid`` matches a committed one.

    Returns:
        Episode: A new episode holding the collected trajectories.
    """
    episode = Episode()

    # Copy the committed trajectories: postprocess_episode() mutates trajectories in place
    # (pops steps, rewrites rewards), and sharing them would make a second collect/postprocess
    # apply those mutations twice to the workflow's own state.
    episode.trajectories.extend(deepcopy(self._completed_trajectories))

    # Track completed trajectory uids
    completed_trajectory_uids = {trajectory.uid for trajectory in self._completed_trajectories}

    # Add trajectories from agents that aren't already in completed trajectories
    for attr_name in dir(self):
        if attr_name.startswith("_"):
            continue
        attr_value = getattr(self, attr_name)
        if isinstance(attr_value, BaseAgent) and hasattr(attr_value, "trajectory") and getattr(attr_value.trajectory, "uid", None) not in completed_trajectory_uids and len(attr_value.trajectory.steps) > 0:
            episode.trajectories.append(deepcopy(attr_value.trajectory))

    return episode

compute_trajectory_reward

compute_trajectory_reward(trajectory: Trajectory) -> None

Compute the trajectory-level reward. Default: sum the step rewards

Parameters:

Name Type Description Default
trajectory Trajectory

The trajectory to compute the reward for.

required
Source code in rllm/workflows/workflow.py
def compute_trajectory_reward(self, trajectory: Trajectory) -> None:
    """Set the trajectory-level reward.

    The default implementation stores the sum of the per-step rewards in
    ``trajectory.reward`` (0.0 for a trajectory without steps). Override this to
    aggregate differently.

    Args:
        trajectory: The trajectory whose ``reward`` attribute is set in place.
    """
    step_rewards = [step.reward for step in trajectory.steps]
    trajectory.reward = np.sum(step_rewards)

adjust_step_rewards

adjust_step_rewards(trajectory: Trajectory) -> None

Adjust the step-level rewards. Supports reward shaping and discounting; since self.reward_bonus_coeff and self.gamma default to 0.0, no adjustments are made by default.

Parameters:

Name Type Description Default
trajectory Trajectory

The trajectory to adjust the rewards for.

required
Source code in rllm/workflows/workflow.py
def adjust_step_rewards(self, trajectory: Trajectory) -> None:
    """Reshape and/or discount the per-step rewards of ``trajectory`` in place.

    Both adjustments are opt-in: shaping runs only when ``self.reward_bonus_coeff``
    is positive, discounting only when ``self.gamma`` is positive. With the 0.0
    defaults, the rewards are left untouched.

    Args:
        trajectory: The trajectory whose step rewards are rewritten.
    """
    bonus = self.reward_bonus_coeff
    if bonus > 0.0:
        # Shaping bonus from the raw (pre-shaping) rewards of consecutive steps:
        # r_i += bonus * (raw_i - raw_{i-1}) for every i >= 1.
        unshaped = [step.reward for step in trajectory.steps]
        for idx, (prev_reward, cur_reward) in enumerate(zip(unshaped, unshaped[1:]), start=1):
            trajectory.steps[idx].reward += bonus * (cur_reward - prev_reward)

    discount = self.gamma
    if discount > 0.0:
        # Walk backwards, replacing each reward with its Monte Carlo return:
        # G_t = r_t + discount * G_{t+1}.
        running_return = 0.0
        for step in reversed(trajectory.steps):
            running_return = step.reward + discount * running_return
            step.reward = running_return

assign_episode_correctness

assign_episode_correctness(episode: Episode) -> None

Assign an episode-level correctness flag. Default: True if the sum of the trajectory rewards is strictly positive.

Parameters:

Name Type Description Default
episode Episode

The episode to assign the correctness flag to.

required
Source code in rllm/workflows/workflow.py
def assign_episode_correctness(self, episode: Episode) -> None:
    """Flag the episode as correct when its trajectories earned positive reward overall.

    Sets ``episode.is_correct`` to whether the summed ``reward`` of all trajectories
    is strictly greater than zero. An episode without trajectories is not correct.

    Args:
        episode: The episode whose ``is_correct`` flag is set.
    """
    combined_reward = sum(traj.reward for traj in episode.trajectories)
    episode.is_correct = combined_reward > 0

collect_metrics

collect_metrics(episode: Episode) -> None

Collect metrics from the episode.

Parameters:

Name Type Description Default
episode Episode

The episode to collect metrics from.

required
Source code in rllm/workflows/workflow.py
def collect_metrics(self, episode: Episode) -> None:
    """Record the mean reward per trajectory name as ``<name>_acc`` metrics.

    Trajectories are grouped by ``name`` in first-seen order. ``episode.metrics``
    is replaced by a dict mapping ``f"{name}_acc"`` to the mean reward of that group.

    Args:
        episode: The episode whose ``metrics`` attribute is overwritten.
    """
    rewards_by_name = {}
    for traj in episode.trajectories:
        rewards_by_name.setdefault(traj.name, []).append(traj.reward)

    summary = {}
    for traj_name, rewards in rewards_by_name.items():
        summary[f"{traj_name}_acc"] = float(np.mean(rewards))
    episode.metrics = summary

postprocess_episode

postprocess_episode(episode: Episode, termination_reason: TerminationReason = None, error: dict = None) -> Episode

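Post-process a collected episode: assign its id and task, compute trajectory and step rewards, correctness, metrics, and the termination reason.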

Parameters:

Name Type Description Default
episode Episode

The episode to postprocess.

required
termination_reason TerminationReason

The termination reason for the episode.

None
error dict

The error details for the episode.

None
Source code in rllm/workflows/workflow.py
def postprocess_episode(self, episode: Episode, termination_reason: TerminationReason | None = None, error: dict | None = None) -> Episode:
    """Post-process a collected episode in place.

    Steps: attach the workflow's uid/task, drop a trailing step with empty
    ``chat_completions``, compute trajectory rewards, adjust step rewards, assign
    episode correctness, collect metrics, record error details, and set the
    termination reason.

    Args:
        episode: The episode to postprocess; it and its trajectories are modified in place.
        termination_reason: The termination reason for the episode; a falsy value is
            recorded as ``TerminationReason.UNKNOWN``.
        error: Optional error details, stored under ``episode.info["error"]``.

    Returns:
        Episode: The same ``episode`` object, post-processed.
    """
    # 1. assign a task id and task (set by reset(); fall back to None if it was never called)
    episode.id = getattr(self, "uid", None)
    episode.task = getattr(self, "task", None)

    for trajectory in episode.trajectories:
        # depending on the termination reason, a trajectory may end with an extra step with empty
        # chat_completions, i.e., if termination happened between agent.update_from_env() and
        # agent.update_from_model(); drop it
        if trajectory.steps and not trajectory.steps[-1].chat_completions:
            trajectory.steps.pop()

        # 2. compute trajectory-level rewards (from the unadjusted step rewards)
        self.compute_trajectory_reward(trajectory)

        # 3. adjust the step level rewards (e.g., reward shaping or discounting)
        if len(trajectory.steps) > 1:
            self.adjust_step_rewards(trajectory)

    # 4. assign an episode-level correctness flag
    self.assign_episode_correctness(episode)

    # 5. collect additional workflow metrics
    # by default, we report the acc of each agent using the traj reward
    self.collect_metrics(episode)

    # 6. store error details if provided
    if error is not None:
        episode.info["error"] = error

    # 7. assign a termination reason
    episode.termination_reason = termination_reason or TerminationReason.UNKNOWN

    return episode

reset

reset(task: dict | None = None, uid: str | None = None) -> None

Reset the workflow

Parameters:

Name Type Description Default
task dict | None

The task to reset the workflow to.

None
uid str | None

The unique identifier for the task.

None
Source code in rllm/workflows/workflow.py
def reset(self, task: dict | None = None, uid: str | None = None) -> None:
    """Prepare the workflow for a new task.

    Stores ``task``/``uid`` and discards committed trajectories. Then every public
    ``BaseAgent`` attribute is reset and its trajectory's ``task`` is set. After that,
    every public ``BaseEnv`` attribute is reset with ``reset(task=task)``.

    Args:
        task: The task to reset the workflow to.
        uid: The unique identifier for the task.
    """
    self.uid = uid
    self.task = task
    self._completed_trajectories = []

    def public_members():
        # Re-scans dir(self) on every call and fetches each public attribute lazily.
        for member_name in dir(self):
            if not member_name.startswith("_"):
                yield getattr(self, member_name)

    # Pass 1: agents.
    for member in public_members():
        if isinstance(member, BaseAgent) and hasattr(member, "reset"):
            member.reset()
            member.trajectory.task = task

    # Pass 2: environments, only after all agents have been reset.
    for member in public_members():
        if isinstance(member, BaseEnv) and hasattr(member, "reset"):
            member.reset(task=task)

is_multithread_safe

is_multithread_safe() -> bool

Check if the workflow is multithread safe

Returns:

Name Type Description
bool bool

True if the workflow is multithread safe, False otherwise.

Source code in rllm/workflows/workflow.py
def is_multithread_safe(self) -> bool:
    """Report whether every environment held by the workflow is multithread safe.

    Every attribute listed by ``dir(self)`` is inspected, private and dunder names
    included. The first ``BaseEnv`` whose own ``is_multithread_safe()`` is falsy
    makes the workflow unsafe.

    Returns:
        bool: True if the workflow is multithread safe, False otherwise.
    """
    attribute_values = (getattr(self, name) for name in dir(self))
    return not any(isinstance(value, BaseEnv) and not value.is_multithread_safe() for value in attribute_values)

run_in_executor async

run_in_executor(fn, *args, **kwargs)

Run a function in a separate thread pool executor.

Parameters:

Name Type Description Default
fn

The function to run.

required
*args

The arguments to pass to the function.

()
**kwargs

The keyword arguments to pass to the function.

{}
Source code in rllm/workflows/workflow.py
async def run_in_executor(self, fn, *args, **kwargs):
    """Run a function in the workflow's thread pool executor and await its result.

    Args:
        fn: The function to run.
        *args: The arguments to pass to the function.
        **kwargs: The keyword arguments to pass to the function.

    Returns:
        The return value of ``fn(*args, **kwargs)``.
    """
    # Inside a coroutine, get_running_loop() is the recommended way to obtain the loop.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(self.executor, partial(fn, *args, **kwargs))