diff --git a/bindu/server/scheduler/base.py b/bindu/server/scheduler/base.py
index fd06e878..d0b14c1e 100644
--- a/bindu/server/scheduler/base.py
+++ b/bindu/server/scheduler/base.py
@@ -21,22 +21,18 @@ class Scheduler(ABC):
     @abstractmethod
     async def run_task(self, params: TaskSendParams) -> None:
         """Send a task to be executed by the worker."""
-        raise NotImplementedError("send_run_task is not implemented yet.")
 
     @abstractmethod
     async def cancel_task(self, params: TaskIdParams) -> None:
         """Cancel a task."""
-        raise NotImplementedError("send_cancel_task is not implemented yet.")
 
     @abstractmethod
     async def pause_task(self, params: TaskIdParams) -> None:
         """Pause a task."""
-        raise NotImplementedError("send_pause_task is not implemented yet.")
 
     @abstractmethod
     async def resume_task(self, params: TaskIdParams) -> None:
         """Resume a task."""
-        raise NotImplementedError("send_resume_task is not implemented yet.")
 
     @abstractmethod
     async def __aenter__(self) -> Self:
diff --git a/bindu/server/workers/base.py b/bindu/server/workers/base.py
index 1ff58950..c6023e20 100644
--- a/bindu/server/workers/base.py
+++ b/bindu/server/workers/base.py
@@ -215,25 +215,138 @@ def build_artifacts(self, result: Any) -> list[Artifact]:
         ...
 
     # -------------------------------------------------------------------------
-    # Future Operations (Not Yet Implemented)
+    # Pause/Resume Operations
     # -------------------------------------------------------------------------
 
     async def _handle_pause(self, params: TaskIdParams) -> None:
-        """Handle pause operation.
+        """Suspend a task and store a checkpoint in its metadata.
 
-        TODO: Implement task pause functionality
-        - Save current execution state
-        - Update task to 'suspended' state
-        - Release resources while preserving context
-        """
-        raise NotImplementedError("Pause operation not yet implemented")
+        Tasks that are missing or already completed, canceled or failed are
+        logged and left untouched. Any exception raised while pausing is
+        logged and the task is marked ``failed``.
+        """
+        from datetime import datetime, timezone
+
+        task_id = params["task_id"]
+        logger.info(f"Pausing task: {task_id}")
+
+        try:
+            existing_task = await self.storage.load_task(task_id)
+            if not existing_task:
+                logger.warning(f"Task {task_id} not found for pause operation")
+                return
+
+            current_state = existing_task.get("status", {}).get("state", "working")
+            if current_state in ("completed", "canceled", "failed"):
+                logger.warning(f"Cannot pause task {task_id} in state: {current_state}")
+                return
+
+            checkpoint_data = await self._create_task_checkpoint(task_id)
+
+            # Work on a copy so the loaded task is not mutated before the
+            # update is actually persisted.
+            metadata = dict(existing_task.get("metadata") or {})
+            metadata["_checkpoint"] = checkpoint_data
+            # Wall-clock UTC timestamp: anyio.current_time() is the event
+            # loop's monotonic clock and is meaningless once persisted.
+            metadata["_paused_at"] = datetime.now(timezone.utc).isoformat()
+
+            await self.storage.update_task(
+                task_id=task_id,
+                state="suspended",
+                metadata=metadata,
+            )
+            logger.info(f"Task {task_id} paused successfully")
+
+        except Exception as e:
+            logger.error(f"Failed to pause task {task_id}: {e}", exc_info=True)
+            await self.storage.update_task(task_id, state="failed")
 
     async def _handle_resume(self, params: TaskIdParams) -> None:
-        """Handle resume operation.
+        """Move a suspended task to ``resumed`` and drop its checkpoint.
 
-        TODO: Implement task resume functionality
-        - Restore execution state
-        - Update task to 'resumed' state
-        - Continue from last checkpoint
-        """
-        raise NotImplementedError("Resume operation not yet implemented")
+        Tasks that are missing or not in the ``suspended`` state are logged
+        and left untouched. Any exception raised while resuming is logged
+        and the task is marked ``failed``.
+        """
+        task_id = params["task_id"]
+        logger.info(f"Resuming task: {task_id}")
+
+        try:
+            existing_task = await self.storage.load_task(task_id)
+            if not existing_task:
+                logger.warning(f"Task {task_id} not found for resume operation")
+                return
+
+            current_state = existing_task.get("status", {}).get("state")
+            if current_state != "suspended":
+                logger.warning(
+                    f"Cannot resume task {task_id} in state: {current_state}. Task must be in 'suspended' state."
+                )
+                return
+
+            # Copy so the loaded task is not mutated before the update lands.
+            metadata = dict(existing_task.get("metadata") or {})
+            checkpoint_data = metadata.get("_checkpoint")
+
+            if checkpoint_data:
+                await self._restore_task_checkpoint(task_id, checkpoint_data)
+
+            metadata.pop("_checkpoint", None)
+            metadata.pop("_paused_at", None)
+
+            await self.storage.update_task(
+                task_id=task_id,
+                state="resumed",
+                metadata=metadata,
+            )
+            logger.info(f"Task {task_id} resumed successfully")
+
+        except Exception as e:
+            logger.error(f"Failed to resume task {task_id}: {e}", exc_info=True)
+            await self.storage.update_task(task_id, state="failed")
+
+    async def _create_task_checkpoint(self, task_id: Any) -> dict[str, Any]:
+        """Build a snapshot of a task's state, metadata and context.
+
+        Returns an empty dict when the task cannot be loaded.
+        """
+        task = await self.storage.load_task(task_id)
+        if not task:
+            return {}
+
+        # Snapshot a copy without the pause bookkeeping keys: storing the
+        # live metadata dict would make the checkpoint contain itself once
+        # it is written back under metadata["_checkpoint"].
+        task_metadata = {
+            key: value
+            for key, value in (task.get("metadata") or {}).items()
+            if key not in ("_checkpoint", "_paused_at")
+        }
+        # NOTE(review): accept both spellings of the context key; confirm
+        # which one the Task type actually uses.
+        context_id = task.get("context_id") or task.get("contextId")
+
+        checkpoint = {
+            "state": task.get("status", {}).get("state", "working"),
+            "metadata": task_metadata,
+            "context_id": context_id,
+        }
+
+        if context_id:
+            context = await self.storage.load_context(context_id)
+            if context:
+                checkpoint["context"] = context
+
+        return checkpoint
+
+    async def _restore_task_checkpoint(
+        self, task_id: Any, checkpoint: dict[str, Any]
+    ) -> None:
+        """Log the checkpoint being resumed from.
+
+        Placeholder: no execution state is restored here yet.
+        """
+        logger.debug(
+            f"Restored checkpoint for task {task_id}: {checkpoint.get('state')}"
+        )