From 3dd0d70910d0f427d435c84ec1277037114c5950 Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Fri, 15 Aug 2025 10:57:11 +0800 Subject: [PATCH] Support resource lists and runner batching --- agentlightning/litagent.py | 92 +++++++++++++++ agentlightning/runner.py | 228 +++++++++++++++++++++++-------------- agentlightning/server.py | 27 +++++ agentlightning/trainer.py | 3 + tests/test_client.py | 20 ++++ tests/test_litagent.py | 43 +++++++ 6 files changed, 325 insertions(+), 88 deletions(-) create mode 100644 tests/test_litagent.py diff --git a/agentlightning/litagent.py b/agentlightning/litagent.py index 4ff733e37..6f4e3a748 100644 --- a/agentlightning/litagent.py +++ b/agentlightning/litagent.py @@ -139,6 +139,50 @@ def validation_rollout(self, task: TaskInput, rollout_id: str, resources: NamedR """ return self.training_rollout(task, rollout_id, resources) + def training_rollout_batch( + self, + tasks: List[TaskInput], + rollout_ids: List[str], + resources_list: List[NamedResources], + ) -> List[RolloutRawResult]: + """Executes training rollouts for a batch of tasks. + + By default, this method iterates over ``tasks`` and delegates to + :meth:`training_rollout` for each corresponding ``rollout_id`` with its + matching resources. + + Args: + tasks: A list of task inputs. + rollout_ids: A list of rollout identifiers matching ``tasks``. + resources_list: A list of named resources for each rollout. + + Returns: + A list containing the result of each individual rollout. + """ + return [self.training_rollout(t, rid, res) for t, rid, res in zip(tasks, rollout_ids, resources_list)] + + def validation_rollout_batch( + self, + tasks: List[TaskInput], + rollout_ids: List[str], + resources_list: List[NamedResources], + ) -> List[RolloutRawResult]: + """Executes validation rollouts for a batch of tasks. + + By default, this method delegates to :meth:`validation_rollout` for each + task, which itself falls back to :meth:`training_rollout` unless + overridden. + + Args: + tasks: A list of task inputs. + rollout_ids: A list of rollout identifiers matching ``tasks``. + resources_list: A list of named resources for each rollout. + + Returns: + A list containing the result of each individual validation rollout. + """ + return [self.validation_rollout(t, rid, res) for t, rid, res in zip(tasks, rollout_ids, resources_list)] + async def training_rollout_async( self, task: TaskInput, rollout_id: str, resources: NamedResources ) -> RolloutRawResult: @@ -176,3 +220,51 @@ async def validation_rollout_async( The result of the asynchronous validation rollout. """ return await self.training_rollout_async(task, rollout_id, resources) + + async def training_rollout_batch_async( + self, + tasks: List[TaskInput], + rollout_ids: List[str], + resources_list: List[NamedResources], + ) -> List[RolloutRawResult]: + """Asynchronous version of :meth:`training_rollout_batch`. + + By default, this method awaits :meth:`training_rollout_async` for each + item in ``tasks`` sequentially with its matching resources. + + Args: + tasks: A list of task inputs. + rollout_ids: A list of rollout identifiers matching ``tasks``. + resources_list: A list of named resources for each rollout. + + Returns: + A list containing the result of each individual rollout. + """ + return [ + await self.training_rollout_async(t, rid, res) for t, rid, res in zip(tasks, rollout_ids, resources_list) + ] + + async def validation_rollout_batch_async( + self, + tasks: List[TaskInput], + rollout_ids: List[str], + resources_list: List[NamedResources], + ) -> List[RolloutRawResult]: + """Asynchronous version of :meth:`validation_rollout_batch`. + + By default, this method awaits :meth:`validation_rollout_async` for each + task. Since :meth:`validation_rollout_async` redirects to + :meth:`training_rollout_async` unless overridden, this batch method will + also use :meth:`training_rollout_async` by default. + + Args: + tasks: A list of task inputs. + rollout_ids: A list of rollout identifiers matching ``tasks``. + resources_list: A list of named resources for each rollout. + + Returns: + A list containing the result of each individual validation rollout. + """ + return [ + await self.validation_rollout_async(t, rid, res) for t, rid, res in zip(tasks, rollout_ids, resources_list) + ] diff --git a/agentlightning/runner.py b/agentlightning/runner.py index 7ec141411..4d103aba9 100644 --- a/agentlightning/runner.py +++ b/agentlightning/runner.py @@ -5,13 +5,14 @@ import time from contextlib import nullcontext from typing import List, Optional, Union, Dict, Any +from collections import defaultdict import agentops from opentelemetry.sdk.trace import ReadableSpan from .client import AgentLightningClient from .litagent import LitAgent -from .types import Rollout, Task, Triplet, RolloutRawResult +from .types import Rollout, Task, Triplet, RolloutRawResult, TaskInput, NamedResources from .types import ParallelWorkerBase from .tracer.base import BaseTracer from .tracer import TripletExporter @@ -43,6 +44,7 @@ def __init__( triplet_exporter: TripletExporter, worker_id: Optional[int] = None, max_tasks: Optional[int] = None, + batch_size: int = 1, ): super().__init__() self.agent = agent @@ -53,6 +55,7 @@ def __init__( # Worker-specific attributes self.worker_id = worker_id self.max_tasks = max_tasks + self.batch_size = batch_size def _log_prefix(self, rollout_id: Optional[str] = None) -> str: """Generates a standardized log prefix for the current worker.""" @@ -134,50 +137,71 @@ def _to_rollout_object( return result.model_copy(update=result_dict) return Rollout(**result_dict) - def run(self) -> bool: - """Poll the task and rollout once synchronously.""" - self.agent.set_runner(self) # Ensure the agent has a reference to this runner + def run(self) -> int: + """Poll tasks and execute rollouts synchronously. - task = self.client.poll_next_task() - if task is None: + Returns the number of tasks processed.""" + self.agent.set_runner(self) + + tasks: List[Task] = [] + for _ in range(self.batch_size): + task = self.client.poll_next_task() + if task is None: + break + tasks.append(task) + + if not tasks: logger.info(f"{self._log_prefix()} Poll returned no task. Exiting.") - return False - rollout_id = task.rollout_id - - resources_id = task.resources_id - resources_update = None - if resources_id: - resources_update = self.client.get_resources_by_id(resources_id) - else: - logger.debug(f"{self._log_prefix(rollout_id)} No 'resources_id'. Fetching latest resources.") - resources_update = self.client.get_latest_resources() - if not resources_update: - logger.error(f"{self._log_prefix(rollout_id)} Failed to fetch resources. Skipping.") - return False - - rollout_obj = Rollout(rollout_id=task.rollout_id) # Default empty rollout - - try: - with self.tracer.trace_context(name=f"rollout_{rollout_id}"): - start_time = time.time() - rollout_method = self.agent.training_rollout if task.mode == "train" else self.agent.validation_rollout - # Pass the task input, not the whole task object - result = rollout_method(task.input, task.rollout_id, resources_update.resources) - rollout_obj = self._to_rollout_object(result, task.rollout_id) - end_time = time.time() - logger.info( - f"{self._log_prefix(rollout_id)} Completed in " - f"{end_time - start_time:.2f}s. Triplet length: " - f"{len(rollout_obj.triplets) if rollout_obj.triplets is not None else 'N/A'}. " - f"Reward: {rollout_obj.final_reward}" - ) - - except Exception: - logger.exception(f"{self._log_prefix(rollout_id)} Exception during rollout.") - finally: - self.client.post_rollout(rollout_obj) - - return True + return 0 + + inputs: List[TaskInput] = [] + rollout_ids: List[str] = [] + resources_list: List[NamedResources] = [] + modes: List[Optional[str]] = [] + + for task in tasks: + resources_id = task.resources_id + resources_update = None + if resources_id: + resources_update = self.client.get_resources_by_id(resources_id) + else: + logger.debug(f"{self._log_prefix(task.rollout_id)} No 'resources_id'. Fetching latest resources.") + resources_update = self.client.get_latest_resources() + if not resources_update: + logger.error(f"{self._log_prefix(task.rollout_id)} Failed to fetch resources. Skipping.") + continue + inputs.append(task.input) + rollout_ids.append(task.rollout_id) + resources_list.append(resources_update.resources) + modes.append(task.mode) + + if not inputs: + return 0 + + results: List[RolloutRawResult] = [None] * len(inputs) + mode_groups: Dict[Optional[str], List[int]] = defaultdict(list) + for idx, mode in enumerate(modes): + mode_groups[mode].append(idx) + + for mode, indices in mode_groups.items(): + sub_tasks = [inputs[i] for i in indices] + sub_ids = [rollout_ids[i] for i in indices] + sub_res = [resources_list[i] for i in indices] + if mode == "train": + sub_results = self.agent.training_rollout_batch(sub_tasks, sub_ids, sub_res) + else: + sub_results = self.agent.validation_rollout_batch(sub_tasks, sub_ids, sub_res) + for idx, res in zip(indices, sub_results): + results[idx] = res + + for rid, res in zip(rollout_ids, results): + rollout_obj = self._to_rollout_object(res, rid) + try: + self.client.post_rollout(rollout_obj) + except Exception: + logger.exception(f"{self._log_prefix(rid)} Exception during rollout.") + + return len(rollout_ids) def iter(self) -> int: """Executes the synchronous polling and rollout loop.""" @@ -185,8 +209,11 @@ def iter(self) -> int: logger.info(f"{self._log_prefix()} Started sync rollouts (max: {self.max_tasks or 'unlimited'}).") while self.max_tasks is None or num_tasks_processed < self.max_tasks: - if self.run(): - num_tasks_processed += 1 + processed = self.run() + if processed: + num_tasks_processed += processed + else: + break if num_tasks_processed % 10 == 0 or num_tasks_processed == 1: logger.info(f"{self._log_prefix()} Progress: {num_tasks_processed}/{self.max_tasks or 'unlimited'}") @@ -194,49 +221,71 @@ def iter(self) -> int: logger.info(f"{self._log_prefix()} Finished sync rollouts. Processed {num_tasks_processed} tasks.") return num_tasks_processed - async def run_async(self) -> bool: - """Poll the task and rollout once.""" - self.agent.set_runner(self) # Ensure the agent has a reference to this runner + async def run_async(self) -> int: + """Poll tasks and execute rollouts asynchronously. - task = await self.client.poll_next_task_async() - if task is None: + Returns the number of tasks processed.""" + self.agent.set_runner(self) + + tasks: List[Task] = [] + for _ in range(self.batch_size): + task = await self.client.poll_next_task_async() + if task is None: + break + tasks.append(task) + + if not tasks: logger.info(f"{self._log_prefix()} Poll returned no task. Exiting.") - return False - rollout_id = task.rollout_id - - resources_id = task.resources_id - resources_update = None - if resources_id: - resources_update = await self.client.get_resources_by_id_async(resources_id) - else: - logger.debug(f"{self._log_prefix(rollout_id)} No 'resources_id'. Fetching latest resources.") - resources_update = await self.client.get_latest_resources_async() - if not resources_update: - logger.error(f"{self._log_prefix(rollout_id)} Failed to fetch resources. Skipping.") - return False - - rollout_obj = Rollout(rollout_id=task.rollout_id) # Default empty rollout - - try: - with self.tracer.trace_context(name=f"rollout_{rollout_id}"): - start_time = time.time() - rollout_method = ( - self.agent.training_rollout_async if task.mode == "train" else self.agent.validation_rollout_async - ) - # Pass the task input, not the whole task object - result = await rollout_method(task.input, task.rollout_id, resources_update.resources) - rollout_obj = self._to_rollout_object(result, task.rollout_id) - end_time = time.time() - logger.info( - f"{self._log_prefix(rollout_id)} Completed in " - f"{end_time - start_time:.2f}s. Reward: {rollout_obj.final_reward}" - ) - except Exception: - logger.exception(f"{self._log_prefix(rollout_id)} Exception during rollout.") - finally: - await self.client.post_rollout_async(rollout_obj) - - return True + return 0 + + inputs: List[TaskInput] = [] + rollout_ids: List[str] = [] + resources_list: List[NamedResources] = [] + modes: List[Optional[str]] = [] + + for task in tasks: + resources_id = task.resources_id + resources_update = None + if resources_id: + resources_update = await self.client.get_resources_by_id_async(resources_id) + else: + logger.debug(f"{self._log_prefix(task.rollout_id)} No 'resources_id'. Fetching latest resources.") + resources_update = await self.client.get_latest_resources_async() + if not resources_update: + logger.error(f"{self._log_prefix(task.rollout_id)} Failed to fetch resources. Skipping.") + continue + inputs.append(task.input) + rollout_ids.append(task.rollout_id) + resources_list.append(resources_update.resources) + modes.append(task.mode) + + if not inputs: + return 0 + + results: List[RolloutRawResult] = [None] * len(inputs) + mode_groups: Dict[Optional[str], List[int]] = defaultdict(list) + for idx, mode in enumerate(modes): + mode_groups[mode].append(idx) + + for mode, indices in mode_groups.items(): + sub_tasks = [inputs[i] for i in indices] + sub_ids = [rollout_ids[i] for i in indices] + sub_res = [resources_list[i] for i in indices] + if mode == "train": + sub_results = await self.agent.training_rollout_batch_async(sub_tasks, sub_ids, sub_res) + else: + sub_results = await self.agent.validation_rollout_batch_async(sub_tasks, sub_ids, sub_res) + for idx, res in zip(indices, sub_results): + results[idx] = res + + for rid, res in zip(rollout_ids, results): + rollout_obj = self._to_rollout_object(res, rid) + try: + await self.client.post_rollout_async(rollout_obj) + except Exception: + logger.exception(f"{self._log_prefix(rid)} Exception during rollout.") + + return len(rollout_ids) async def iter_async(self) -> int: """Executes the asynchronous polling and rollout loop.""" @@ -244,8 +293,11 @@ async def iter_async(self) -> int: logger.info(f"{self._log_prefix()} Started async rollouts (max: {self.max_tasks or 'unlimited'}).") while self.max_tasks is None or num_tasks_processed < self.max_tasks: - if await self.run_async(): - num_tasks_processed += 1 + processed = await self.run_async() + if processed: + num_tasks_processed += processed + else: + break if num_tasks_processed % 10 == 0 or num_tasks_processed == 1: logger.info(f"{self._log_prefix()} Progress: {num_tasks_processed}/{self.max_tasks or 'unlimited'}") diff --git a/agentlightning/server.py b/agentlightning/server.py index 82f6ccba8..71eb6523d 100644 --- a/agentlightning/server.py +++ b/agentlightning/server.py @@ -65,6 +65,21 @@ async def add_task( logger.info(f"Task queued: {rollout_id} (mode: {mode}, resources_id: {resources_id})") return rollout_id + async def add_tasks( + self, + samples: List[Any], + mode: Literal["train", "val", "test"] | None = None, + resources_id: str | None = None, + metadata_list: Optional[List[Dict[str, Any]]] = None, + ) -> List[str]: + """Adds multiple tasks to the queue and returns their rollout IDs.""" + rollout_ids = [] + for idx, sample in enumerate(samples): + metadata = metadata_list[idx] if metadata_list and idx < len(metadata_list) else None + rollout_id = await self.add_task(sample, mode=mode, resources_id=resources_id, metadata=metadata) + rollout_ids.append(rollout_id) + return rollout_ids + async def get_next_task(self) -> Optional[Task]: """ Retrieves the next task from the queue without blocking. @@ -312,6 +327,18 @@ async def queue_task( raise RuntimeError("Store not initialized. The server may not be running.") return await self._store.add_task(sample, mode=mode, resources_id=resources_id, metadata=metadata) + async def queue_tasks( + self, + samples: List[Any], + mode: Literal["train", "val", "test"] | None = None, + resources_id: str | None = None, + metadata_list: Optional[List[Dict[str, Any]]] = None, + ) -> List[str]: + """Adds multiple tasks to the queue for a client to process.""" + if not self._store: + raise RuntimeError("Store not initialized. The server may not be running.") + return await self._store.add_tasks(samples, mode=mode, resources_id=resources_id, metadata_list=metadata_list) + async def update_resources(self, resources: NamedResources) -> str: """ Updates the resources, creating a new version and setting it as the latest. diff --git a/agentlightning/trainer.py b/agentlightning/trainer.py index bc9338824..c8555906c 100644 --- a/agentlightning/trainer.py +++ b/agentlightning/trainer.py @@ -50,6 +50,7 @@ def __init__( n_workers: int = 1, max_tasks: Optional[int] = None, daemon: bool = True, + batch_size: int = 1, tracer: Union[BaseTracer, str, dict, None] = None, triplet_exporter: Union[TripletExporter, dict, None] = None, ): @@ -58,6 +59,7 @@ def __init__( self.max_tasks = max_tasks self.daemon = daemon self.dev = dev + self.batch_size = batch_size self._client: AgentLightningClient | None = None # Will be initialized in fit method self.tracer = self._make_tracer(tracer) @@ -181,6 +183,7 @@ def _worker_main_loop(self, agent: LitAgent, worker_id: int, is_async: bool): triplet_exporter=self.triplet_exporter, max_tasks=self.max_tasks, worker_id=worker_id, + batch_size=self.batch_size, ) loop.init_worker(worker_id) if is_async: diff --git a/tests/test_client.py b/tests/test_client.py index e401a6bba..64ace4843 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -372,3 +372,23 @@ async def test_local_client_async_methods(sample_resources: NamedResources): assert result["rollout_id"] == "async_rollout" assert len(client.rollouts) == 1 assert client.rollouts[0].final_reward == 0.8 + + +@pytest.mark.asyncio +async def test_queue_tasks(server_setup: Dict[str, Any]): + """Ensure multiple tasks can be queued and retrieved.""" + server: AgentLightningServer = server_setup["server"] + http_client: AsyncClient = server_setup["http_client"] + endpoint: str = server_setup["endpoint"] + client = AgentLightningClient(endpoint=endpoint, poll_interval=0.1) + + samples = [{"prompt": "1"}, {"prompt": "2"}] + rollout_ids = await server.queue_tasks(samples, mode="train") + assert len(rollout_ids) == 2 + + task_a = await client.poll_next_task_async() + task_b = await client.poll_next_task_async() + assert {task_a.rollout_id, task_b.rollout_id} == set(rollout_ids) + + resp = await http_client.get("/task") + assert resp.json()["is_available"] is False diff --git a/tests/test_litagent.py b/tests/test_litagent.py new file mode 100644 index 000000000..6a2ca1cb4 --- /dev/null +++ b/tests/test_litagent.py @@ -0,0 +1,43 @@ +import pytest + +from agentlightning.litagent import LitAgent + + +def test_rollout_batch_defaults_to_single_rollout(): + class DummyAgent(LitAgent): + def __init__(self): + super().__init__() + self.calls = [] + + def training_rollout(self, task, rollout_id, resources): + self.calls.append((task, rollout_id)) + return 1.0 + + agent = DummyAgent() + results = agent.training_rollout_batch([1, 2], ["r1", "r2"], [{}, {}]) + assert results == [1.0, 1.0] + assert agent.calls == [(1, "r1"), (2, "r2")] + + val_results = agent.validation_rollout_batch([3], ["r3"], [{}]) + assert val_results == [1.0] + assert agent.calls == [(1, "r1"), (2, "r2"), (3, "r3")] + + +@pytest.mark.asyncio +async def test_rollout_batch_async_defaults_to_single_rollout(): + class DummyAsyncAgent(LitAgent): + def __init__(self): + super().__init__() + self.calls = [] + + async def training_rollout_async(self, task, rollout_id, resources): + self.calls.append((task, rollout_id)) + return 1.0 + + agent = DummyAsyncAgent() + results = await agent.training_rollout_batch_async([1, 2], ["r1", "r2"], [{}, {}]) + assert results == [1.0, 1.0] + + val_results = await agent.validation_rollout_batch_async([3], ["r3"], [{}]) + assert val_results == [1.0] + assert agent.calls == [(1, "r1"), (2, "r2"), (3, "r3")]