Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions agentlightning/litagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,50 @@ def validation_rollout(self, task: TaskInput, rollout_id: str, resources: NamedR
"""
return self.training_rollout(task, rollout_id, resources)

def training_rollout_batch(
    self,
    tasks: List[TaskInput],
    rollout_ids: List[str],
    resources_list: List[NamedResources],
) -> List[RolloutRawResult]:
    """Executes training rollouts for a batch of tasks.

    By default, this method iterates over ``tasks`` and delegates to
    :meth:`training_rollout` for each corresponding ``rollout_id`` with its
    matching resources. Rollouts run one after another, in input order.

    Args:
        tasks: A list of task inputs.
        rollout_ids: A list of rollout identifiers matching ``tasks``.
        resources_list: A list of named resources for each rollout.

    Returns:
        A list containing the result of each individual rollout, in the
        same order as ``tasks``.

    Raises:
        ValueError: If ``tasks``, ``rollout_ids`` and ``resources_list`` do
            not all have the same length.
    """
    # zip() stops at the shortest input, which would silently drop rollouts
    # when the caller passes misaligned lists; fail loudly instead.
    if not (len(tasks) == len(rollout_ids) == len(resources_list)):
        raise ValueError(
            "tasks, rollout_ids and resources_list must have the same length, "
            f"got {len(tasks)}, {len(rollout_ids)} and {len(resources_list)}."
        )
    return [self.training_rollout(t, rid, res) for t, rid, res in zip(tasks, rollout_ids, resources_list)]

def validation_rollout_batch(
    self,
    tasks: List[TaskInput],
    rollout_ids: List[str],
    resources_list: List[NamedResources],
) -> List[RolloutRawResult]:
    """Executes validation rollouts for a batch of tasks.

    By default, this method delegates to :meth:`validation_rollout` for each
    task, which itself falls back to :meth:`training_rollout` unless
    overridden. Rollouts run one after another, in input order.

    Args:
        tasks: A list of task inputs.
        rollout_ids: A list of rollout identifiers matching ``tasks``.
        resources_list: A list of named resources for each rollout.

    Returns:
        A list containing the result of each individual validation rollout,
        in the same order as ``tasks``.

    Raises:
        ValueError: If ``tasks``, ``rollout_ids`` and ``resources_list`` do
            not all have the same length.
    """
    # zip() stops at the shortest input, which would silently drop rollouts
    # when the caller passes misaligned lists; fail loudly instead.
    if not (len(tasks) == len(rollout_ids) == len(resources_list)):
        raise ValueError(
            "tasks, rollout_ids and resources_list must have the same length, "
            f"got {len(tasks)}, {len(rollout_ids)} and {len(resources_list)}."
        )
    return [self.validation_rollout(t, rid, res) for t, rid, res in zip(tasks, rollout_ids, resources_list)]

async def training_rollout_async(
self, task: TaskInput, rollout_id: str, resources: NamedResources
) -> RolloutRawResult:
Expand Down Expand Up @@ -176,3 +220,51 @@ async def validation_rollout_async(
The result of the asynchronous validation rollout.
"""
return await self.training_rollout_async(task, rollout_id, resources)

async def training_rollout_batch_async(
    self,
    tasks: List[TaskInput],
    rollout_ids: List[str],
    resources_list: List[NamedResources],
) -> List[RolloutRawResult]:
    """Asynchronous version of :meth:`training_rollout_batch`.

    By default, this method awaits :meth:`training_rollout_async` for each
    item in ``tasks`` sequentially with its matching resources; each rollout
    is awaited to completion before the next one starts.

    Args:
        tasks: A list of task inputs.
        rollout_ids: A list of rollout identifiers matching ``tasks``.
        resources_list: A list of named resources for each rollout.

    Returns:
        A list containing the result of each individual rollout, in the
        same order as ``tasks``.

    Raises:
        ValueError: If ``tasks``, ``rollout_ids`` and ``resources_list`` do
            not all have the same length.
    """
    # zip() stops at the shortest input, which would silently drop rollouts
    # when the caller passes misaligned lists; fail loudly instead.
    if not (len(tasks) == len(rollout_ids) == len(resources_list)):
        raise ValueError(
            "tasks, rollout_ids and resources_list must have the same length, "
            f"got {len(tasks)}, {len(rollout_ids)} and {len(resources_list)}."
        )
    return [
        await self.training_rollout_async(t, rid, res) for t, rid, res in zip(tasks, rollout_ids, resources_list)
    ]

async def validation_rollout_batch_async(
    self,
    tasks: List[TaskInput],
    rollout_ids: List[str],
    resources_list: List[NamedResources],
) -> List[RolloutRawResult]:
    """Asynchronous version of :meth:`validation_rollout_batch`.

    By default, this method awaits :meth:`validation_rollout_async` for each
    task, one after another. Since :meth:`validation_rollout_async` redirects
    to :meth:`training_rollout_async` unless overridden, this batch method
    will also use :meth:`training_rollout_async` by default.

    Args:
        tasks: A list of task inputs.
        rollout_ids: A list of rollout identifiers matching ``tasks``.
        resources_list: A list of named resources for each rollout.

    Returns:
        A list containing the result of each individual validation rollout,
        in the same order as ``tasks``.

    Raises:
        ValueError: If ``tasks``, ``rollout_ids`` and ``resources_list`` do
            not all have the same length.
    """
    # zip() stops at the shortest input, which would silently drop rollouts
    # when the caller passes misaligned lists; fail loudly instead.
    if not (len(tasks) == len(rollout_ids) == len(resources_list)):
        raise ValueError(
            "tasks, rollout_ids and resources_list must have the same length, "
            f"got {len(tasks)}, {len(rollout_ids)} and {len(resources_list)}."
        )
    return [
        await self.validation_rollout_async(t, rid, res) for t, rid, res in zip(tasks, rollout_ids, resources_list)
    ]
228 changes: 140 additions & 88 deletions agentlightning/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
import time
from contextlib import nullcontext
from typing import List, Optional, Union, Dict, Any
from collections import defaultdict

import agentops

from opentelemetry.sdk.trace import ReadableSpan
from .client import AgentLightningClient
from .litagent import LitAgent
from .types import Rollout, Task, Triplet, RolloutRawResult
from .types import Rollout, Task, Triplet, RolloutRawResult, TaskInput, NamedResources
from .types import ParallelWorkerBase
from .tracer.base import BaseTracer
from .tracer import TripletExporter
Expand Down Expand Up @@ -43,6 +44,7 @@ def __init__(
triplet_exporter: TripletExporter,
worker_id: Optional[int] = None,
max_tasks: Optional[int] = None,
batch_size: int = 1,
):
super().__init__()
self.agent = agent
Expand All @@ -53,6 +55,7 @@ def __init__(
# Worker-specific attributes
self.worker_id = worker_id
self.max_tasks = max_tasks
self.batch_size = batch_size

def _log_prefix(self, rollout_id: Optional[str] = None) -> str:
"""Generates a standardized log prefix for the current worker."""
Expand Down Expand Up @@ -134,118 +137,167 @@ def _to_rollout_object(
return result.model_copy(update=result_dict)
return Rollout(**result_dict)

def run(self) -> int:
    """Poll up to ``batch_size`` tasks and execute their rollouts synchronously.

    Tasks are polled one at a time until ``batch_size`` tasks have been
    collected or the poll returns ``None``. Resources are then fetched per
    task; a task whose resources cannot be fetched is logged and skipped
    (no rollout is posted for it). The remaining tasks are grouped by mode
    and dispatched to the agent's ``training_rollout_batch`` (mode
    ``"train"``) or ``validation_rollout_batch`` (any other mode). Finally
    one rollout is posted per dispatched task.

    If a batch call raises, returns a number of results different from the
    number of tasks it was given, or a result cannot be converted, an empty
    ``Rollout`` carrying only the rollout id is posted for the affected
    tasks, so a single failure does not abort the rest of the batch.

    Returns:
        The number of tasks dispatched to the agent. ``0`` when the poll
        returned no task or no task had usable resources.
    """
    self.agent.set_runner(self)  # Ensure the agent has a reference to this runner

    tasks: List[Task] = []
    for _ in range(self.batch_size):
        task = self.client.poll_next_task()
        if task is None:
            break
        tasks.append(task)

    if not tasks:
        logger.info(f"{self._log_prefix()} Poll returned no task. Exiting.")
        return 0

    inputs: List[TaskInput] = []
    rollout_ids: List[str] = []
    resources_list: List[NamedResources] = []
    modes: List[Optional[str]] = []

    for task in tasks:
        if task.resources_id:
            resources_update = self.client.get_resources_by_id(task.resources_id)
        else:
            logger.debug(f"{self._log_prefix(task.rollout_id)} No 'resources_id'. Fetching latest resources.")
            resources_update = self.client.get_latest_resources()
        if not resources_update:
            logger.error(f"{self._log_prefix(task.rollout_id)} Failed to fetch resources. Skipping.")
            continue
        inputs.append(task.input)
        rollout_ids.append(task.rollout_id)
        resources_list.append(resources_update.resources)
        modes.append(task.mode)

    if not inputs:
        return 0

    # Group positions by mode so each agent batch method sees only its own tasks,
    # while `results` keeps the original ordering.
    mode_groups: Dict[Optional[str], List[int]] = defaultdict(list)
    for idx, mode in enumerate(modes):
        mode_groups[mode].append(idx)

    results: List[RolloutRawResult] = [None] * len(inputs)
    # Positions whose batch call failed. Tracked separately because the stored
    # result alone cannot tell a failure apart from an agent returning None.
    failed = set()

    # NOTE(review): the single-task implementation this replaced ran each rollout
    # inside `self.tracer.trace_context(name=f"rollout_{rollout_id}")`. No trace
    # context is opened here; confirm whether `_to_rollout_object` relies on one.
    for mode, indices in mode_groups.items():
        batch_method = self.agent.training_rollout_batch if mode == "train" else self.agent.validation_rollout_batch
        start_time = time.time()
        try:
            sub_results = batch_method(
                [inputs[i] for i in indices],
                [rollout_ids[i] for i in indices],
                [resources_list[i] for i in indices],
            )
            # A short result list would otherwise be silently truncated by zip().
            if len(sub_results) != len(indices):
                raise ValueError(f"Batch rollout returned {len(sub_results)} results for {len(indices)} tasks.")
        except Exception:
            logger.exception(f"{self._log_prefix()} Exception during rollout batch (mode={mode!r}).")
            failed.update(indices)
            continue
        for idx, res in zip(indices, sub_results):
            results[idx] = res
        logger.info(
            f"{self._log_prefix()} Completed batch of {len(indices)} rollout(s) (mode={mode!r}) in "
            f"{time.time() - start_time:.2f}s."
        )

    for idx, rid in enumerate(rollout_ids):
        rollout_obj = Rollout(rollout_id=rid)  # Default empty rollout
        if idx not in failed:
            try:
                rollout_obj = self._to_rollout_object(results[idx], rid)
            except Exception:
                logger.exception(f"{self._log_prefix(rid)} Exception while converting rollout result.")
        try:
            self.client.post_rollout(rollout_obj)
        except Exception:
            logger.exception(f"{self._log_prefix(rid)} Exception while posting rollout.")

    return len(rollout_ids)

def iter(self) -> int:
    """Executes the synchronous polling and rollout loop.

    Calls :meth:`run` repeatedly, adding the number of tasks it reports to a
    running total, until ``run`` reports nothing processed or the total
    reaches ``max_tasks`` (when set). The limit is only checked between
    calls to ``run``, so the final total can exceed ``max_tasks`` when a
    single call processes more than one task.

    Returns:
        The total number of tasks processed by this loop.
    """
    num_tasks_processed = 0
    logger.info(f"{self._log_prefix()} Started sync rollouts (max: {self.max_tasks or 'unlimited'}).")

    while self.max_tasks is None or num_tasks_processed < self.max_tasks:
        processed = self.run()
        if not processed:
            break
        previous_total = num_tasks_processed
        num_tasks_processed += processed

        # Log after the first successful call and whenever the total crosses a
        # multiple of 10. Comparing the tens bucket instead of `% 10 == 0` keeps
        # progress visible when the total advances by more than one per call.
        if previous_total == 0 or previous_total // 10 != num_tasks_processed // 10:
            logger.info(f"{self._log_prefix()} Progress: {num_tasks_processed}/{self.max_tasks or 'unlimited'}")

    logger.info(f"{self._log_prefix()} Finished sync rollouts. Processed {num_tasks_processed} tasks.")
    return num_tasks_processed

async def run_async(self) -> int:
    """Poll up to ``batch_size`` tasks and execute their rollouts asynchronously.

    Tasks are polled one at a time until ``batch_size`` tasks have been
    collected or the poll returns ``None``. Resources are then fetched per
    task; a task whose resources cannot be fetched is logged and skipped
    (no rollout is posted for it). The remaining tasks are grouped by mode
    and dispatched to the agent's ``training_rollout_batch_async`` (mode
    ``"train"``) or ``validation_rollout_batch_async`` (any other mode).
    Finally one rollout is posted per dispatched task.

    If a batch call raises, returns a number of results different from the
    number of tasks it was given, or a result cannot be converted, an empty
    ``Rollout`` carrying only the rollout id is posted for the affected
    tasks, so a single failure does not abort the rest of the batch.

    Returns:
        The number of tasks dispatched to the agent. ``0`` when the poll
        returned no task or no task had usable resources.
    """
    self.agent.set_runner(self)  # Ensure the agent has a reference to this runner

    tasks: List[Task] = []
    for _ in range(self.batch_size):
        task = await self.client.poll_next_task_async()
        if task is None:
            break
        tasks.append(task)

    if not tasks:
        logger.info(f"{self._log_prefix()} Poll returned no task. Exiting.")
        return 0

    inputs: List[TaskInput] = []
    rollout_ids: List[str] = []
    resources_list: List[NamedResources] = []
    modes: List[Optional[str]] = []

    for task in tasks:
        if task.resources_id:
            resources_update = await self.client.get_resources_by_id_async(task.resources_id)
        else:
            logger.debug(f"{self._log_prefix(task.rollout_id)} No 'resources_id'. Fetching latest resources.")
            resources_update = await self.client.get_latest_resources_async()
        if not resources_update:
            logger.error(f"{self._log_prefix(task.rollout_id)} Failed to fetch resources. Skipping.")
            continue
        inputs.append(task.input)
        rollout_ids.append(task.rollout_id)
        resources_list.append(resources_update.resources)
        modes.append(task.mode)

    if not inputs:
        return 0

    # Group positions by mode so each agent batch method sees only its own tasks,
    # while `results` keeps the original ordering.
    mode_groups: Dict[Optional[str], List[int]] = defaultdict(list)
    for idx, mode in enumerate(modes):
        mode_groups[mode].append(idx)

    results: List[RolloutRawResult] = [None] * len(inputs)
    # Positions whose batch call failed. Tracked separately because the stored
    # result alone cannot tell a failure apart from an agent returning None.
    failed = set()

    # NOTE(review): the single-task implementation this replaced ran each rollout
    # inside `self.tracer.trace_context(name=f"rollout_{rollout_id}")`. No trace
    # context is opened here; confirm whether `_to_rollout_object` relies on one.
    for mode, indices in mode_groups.items():
        batch_method = (
            self.agent.training_rollout_batch_async if mode == "train" else self.agent.validation_rollout_batch_async
        )
        start_time = time.time()
        try:
            sub_results = await batch_method(
                [inputs[i] for i in indices],
                [rollout_ids[i] for i in indices],
                [resources_list[i] for i in indices],
            )
            # A short result list would otherwise be silently truncated by zip().
            if len(sub_results) != len(indices):
                raise ValueError(f"Batch rollout returned {len(sub_results)} results for {len(indices)} tasks.")
        except Exception:
            logger.exception(f"{self._log_prefix()} Exception during rollout batch (mode={mode!r}).")
            failed.update(indices)
            continue
        for idx, res in zip(indices, sub_results):
            results[idx] = res
        logger.info(
            f"{self._log_prefix()} Completed batch of {len(indices)} rollout(s) (mode={mode!r}) in "
            f"{time.time() - start_time:.2f}s."
        )

    for idx, rid in enumerate(rollout_ids):
        rollout_obj = Rollout(rollout_id=rid)  # Default empty rollout
        if idx not in failed:
            try:
                rollout_obj = self._to_rollout_object(results[idx], rid)
            except Exception:
                logger.exception(f"{self._log_prefix(rid)} Exception while converting rollout result.")
        try:
            await self.client.post_rollout_async(rollout_obj)
        except Exception:
            logger.exception(f"{self._log_prefix(rid)} Exception while posting rollout.")

    return len(rollout_ids)

async def iter_async(self) -> int:
"""Executes the asynchronous polling and rollout loop."""
num_tasks_processed = 0
logger.info(f"{self._log_prefix()} Started async rollouts (max: {self.max_tasks or 'unlimited'}).")

while self.max_tasks is None or num_tasks_processed < self.max_tasks:
if await self.run_async():
num_tasks_processed += 1
processed = await self.run_async()
if processed:
num_tasks_processed += processed
else:
break

if num_tasks_processed % 10 == 0 or num_tasks_processed == 1:
logger.info(f"{self._log_prefix()} Progress: {num_tasks_processed}/{self.max_tasks or 'unlimited'}")
Expand Down
Loading