From 6d20cf02a0ceae3e3bb4f648e1658a80b615f58e Mon Sep 17 00:00:00 2001 From: Kkt04 Date: Fri, 6 Mar 2026 07:18:32 +0530 Subject: [PATCH 1/3] feat: Add PostgreSQL checkpoint storage with pause/resume support --- .../20260306_0001_add_task_checkpoints.py | 215 ++++++++++++ bindu/server/storage/base.py | 53 +++ bindu/server/storage/memory_storage.py | 309 ++++++----------- bindu/server/storage/postgres_storage.py | 189 ++++++++++ bindu/server/storage/schema.py | 42 +++ bindu/server/workers/base.py | 132 ++++++- bindu/server/workers/manifest_worker.py | 9 + bindu/settings.py | 10 + docs/CHECKPOINTS.md | 170 +++++++++ docs/STATE_TRANSITIONS.md | 231 +++++++++++++ .../test_checkpoint_persistence.py | 217 ++++++++++++ tests/integration/test_pause_resume.py | 272 +++++++++++++++ tests/integration/test_state_transitions.py | 323 ++++++++++++++++++ 13 files changed, 1945 insertions(+), 227 deletions(-) create mode 100644 alembic/versions/20260306_0001_add_task_checkpoints.py create mode 100644 docs/CHECKPOINTS.md create mode 100644 docs/STATE_TRANSITIONS.md create mode 100644 tests/integration/test_checkpoint_persistence.py create mode 100644 tests/integration/test_pause_resume.py create mode 100644 tests/integration/test_state_transitions.py diff --git a/alembic/versions/20260306_0001_add_task_checkpoints.py b/alembic/versions/20260306_0001_add_task_checkpoints.py new file mode 100644 index 00000000..2d06fdbf --- /dev/null +++ b/alembic/versions/20260306_0001_add_task_checkpoints.py @@ -0,0 +1,215 @@ +"""Add task_checkpoints table for pause/resume support. + +Revision ID: 20260306_0001 +Revises: 20260119_0001 +Create Date: 2026-03-06 00:00:00.000000 + +This migration adds the task_checkpoints table to store execution state +for pause/resume functionality. Checkpoints allow tasks to be suspended +and resumed later from the same point. 
+ +The checkpoint stores: +- checkpoint_data: JSONB containing execution state +- step_number: Current step in execution +- step_label: Optional label for the current step +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +revision: str = "20260306_0001" +down_revision: Union[str, None] = "20260119_0001" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Create task_checkpoints (cascading FK to tasks) and refresh schema helpers.""" + op.create_table( + "task_checkpoints", + sa.Column( + "id", sa.Integer(), primary_key=True, autoincrement=True, nullable=False + ), + sa.Column("task_id", postgresql.UUID(as_uuid=True), nullable=False), + sa.Column("checkpoint_data", postgresql.JSONB(), nullable=False), + sa.Column("step_number", sa.Integer(), nullable=False, server_default="0"), + sa.Column("step_label", sa.String(255), nullable=True), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + nullable=False, + server_default=sa.text("NOW()"), + ), + sa.Column( + "updated_at", + sa.TIMESTAMP(timezone=True), + nullable=False, + server_default=sa.text("NOW()"), + ), + ) + + op.create_index("idx_task_checkpoints_task_id", "task_checkpoints", ["task_id"]) + op.create_index( + "idx_task_checkpoints_created_at", "task_checkpoints", ["created_at"] + ) + + op.create_foreign_key( + "fk_task_checkpoints_task", + "task_checkpoints", + "tasks", + ["task_id"], + ["id"], + ondelete="CASCADE", + ) + + # Update the helper function to include task_checkpoints table + op.execute(""" + CREATE OR REPLACE FUNCTION create_bindu_tables_in_schema(schema_name TEXT) + RETURNS VOID AS $$ + BEGIN + -- Create tasks table + EXECUTE format(' + CREATE TABLE IF NOT EXISTS %I.tasks ( + id UUID PRIMARY KEY NOT NULL, + context_id UUID NOT NULL, + kind VARCHAR(50) NOT NULL DEFAULT ''task'', + state VARCHAR(50) NOT NULL, + state_timestamp TIMESTAMP WITH TIME ZONE NOT NULL, + history JSONB 
NOT NULL DEFAULT ''[]''::jsonb, + artifacts JSONB DEFAULT ''[]''::jsonb, + metadata JSONB DEFAULT ''{}''::jsonb, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + CONSTRAINT fk_tasks_context FOREIGN KEY (context_id) + REFERENCES %I.contexts(id) ON DELETE CASCADE + )', schema_name, schema_name); + + -- Create contexts table + EXECUTE format(' + CREATE TABLE IF NOT EXISTS %I.contexts ( + id UUID PRIMARY KEY NOT NULL, + context_data JSONB NOT NULL DEFAULT ''{}''::jsonb, + message_history JSONB DEFAULT ''[]''::jsonb, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW() + )', schema_name); + + -- Create task_feedback table + EXECUTE format(' + CREATE TABLE IF NOT EXISTS %I.task_feedback ( + id SERIAL PRIMARY KEY NOT NULL, + task_id UUID NOT NULL, + feedback_data JSONB NOT NULL, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + CONSTRAINT fk_task_feedback_task FOREIGN KEY (task_id) + REFERENCES %I.tasks(id) ON DELETE CASCADE + )', schema_name, schema_name); + + -- Create webhook_configs table + EXECUTE format(' + CREATE TABLE IF NOT EXISTS %I.webhook_configs ( + task_id UUID PRIMARY KEY NOT NULL, + config JSONB NOT NULL, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + CONSTRAINT fk_webhook_configs_task FOREIGN KEY (task_id) + REFERENCES %I.tasks(id) ON DELETE CASCADE + )', schema_name, schema_name); + + -- Create task_checkpoints table + EXECUTE format(' + CREATE TABLE IF NOT EXISTS %I.task_checkpoints ( + id SERIAL PRIMARY KEY NOT NULL, + task_id UUID NOT NULL, + checkpoint_data JSONB NOT NULL, + step_number INTEGER NOT NULL DEFAULT 0, + step_label VARCHAR(255), + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + CONSTRAINT fk_task_checkpoints_task FOREIGN KEY 
(task_id) + REFERENCES %I.tasks(id) ON DELETE CASCADE + )', schema_name, schema_name); + + -- Create indexes for tasks + EXECUTE format('CREATE INDEX IF NOT EXISTS idx_tasks_context_id ON %I.tasks(context_id)', schema_name); + EXECUTE format('CREATE INDEX IF NOT EXISTS idx_tasks_state ON %I.tasks(state)', schema_name); + EXECUTE format('CREATE INDEX IF NOT EXISTS idx_tasks_created_at ON %I.tasks(created_at DESC)', schema_name); + EXECUTE format('CREATE INDEX IF NOT EXISTS idx_tasks_updated_at ON %I.tasks(updated_at DESC)', schema_name); + EXECUTE format('CREATE INDEX IF NOT EXISTS idx_tasks_history_gin ON %I.tasks USING gin(history)', schema_name); + EXECUTE format('CREATE INDEX IF NOT EXISTS idx_tasks_metadata_gin ON %I.tasks USING gin(metadata)', schema_name); + EXECUTE format('CREATE INDEX IF NOT EXISTS idx_tasks_artifacts_gin ON %I.tasks USING gin(artifacts)', schema_name); + + -- Create indexes for contexts + EXECUTE format('CREATE INDEX IF NOT EXISTS idx_contexts_created_at ON %I.contexts(created_at DESC)', schema_name); + EXECUTE format('CREATE INDEX IF NOT EXISTS idx_contexts_updated_at ON %I.contexts(updated_at DESC)', schema_name); + EXECUTE format('CREATE INDEX IF NOT EXISTS idx_contexts_data_gin ON %I.contexts USING gin(context_data)', schema_name); + EXECUTE format('CREATE INDEX IF NOT EXISTS idx_contexts_history_gin ON %I.contexts USING gin(message_history)', schema_name); + + -- Create indexes for task_feedback + EXECUTE format('CREATE INDEX IF NOT EXISTS idx_task_feedback_task_id ON %I.task_feedback(task_id)', schema_name); + EXECUTE format('CREATE INDEX IF NOT EXISTS idx_task_feedback_created_at ON %I.task_feedback(created_at DESC)', schema_name); + + -- Create indexes for webhook_configs + EXECUTE format('CREATE INDEX IF NOT EXISTS idx_webhook_configs_created_at ON %I.webhook_configs(created_at DESC)', schema_name); + + -- Create indexes for task_checkpoints + EXECUTE format('CREATE INDEX IF NOT EXISTS idx_task_checkpoints_task_id ON 
%I.task_checkpoints(task_id)', schema_name); + EXECUTE format('CREATE INDEX IF NOT EXISTS idx_task_checkpoints_created_at ON %I.task_checkpoints(created_at DESC)', schema_name); + + -- Create triggers for updated_at + EXECUTE format(' + CREATE TRIGGER update_tasks_updated_at + BEFORE UPDATE ON %I.tasks + FOR EACH ROW + EXECUTE FUNCTION update_updated_at_column() + ', schema_name); + + EXECUTE format(' + CREATE TRIGGER update_contexts_updated_at + BEFORE UPDATE ON %I.contexts + FOR EACH ROW + EXECUTE FUNCTION update_updated_at_column() + ', schema_name); + + EXECUTE format(' + CREATE TRIGGER update_webhook_configs_updated_at + BEFORE UPDATE ON %I.webhook_configs + FOR EACH ROW + EXECUTE FUNCTION update_updated_at_column() + ', schema_name); + + EXECUTE format(' + CREATE TRIGGER update_task_checkpoints_updated_at + BEFORE UPDATE ON %I.task_checkpoints + FOR EACH ROW + EXECUTE FUNCTION update_updated_at_column() + ', schema_name); + + RAISE NOTICE 'Created all Bindu tables in schema: %', schema_name; + END; + $$ LANGUAGE plpgsql; + """) + + # Update drop function to include task_checkpoints + op.execute(""" + CREATE OR REPLACE FUNCTION drop_bindu_tables_in_schema(schema_name TEXT) + RETURNS VOID AS $$ + BEGIN + EXECUTE format('DROP TABLE IF EXISTS %I.task_checkpoints CASCADE', schema_name); + EXECUTE format('DROP TABLE IF EXISTS %I.task_feedback CASCADE', schema_name); + EXECUTE format('DROP TABLE IF EXISTS %I.webhook_configs CASCADE', schema_name); + EXECUTE format('DROP TABLE IF EXISTS %I.tasks CASCADE', schema_name); + EXECUTE format('DROP TABLE IF EXISTS %I.contexts CASCADE', schema_name); + + RAISE NOTICE 'Dropped all Bindu tables in schema: %', schema_name; + END; + $$ LANGUAGE plpgsql; + """) + + +def downgrade() -> None: + """Remove task_checkpoints table.""" + op.drop_table("task_checkpoints") diff --git a/bindu/server/storage/base.py b/bindu/server/storage/base.py index f18e5a86..02ef9e41 100644 --- a/bindu/server/storage/base.py +++ 
b/bindu/server/storage/base.py @@ -277,3 +277,56 @@ async def load_all_webhook_configs(self) -> dict[UUID, PushNotificationConfig]: Returns: Dictionary mapping task IDs to their webhook configurations """ + + # ------------------------------------------------------------------------- + # Checkpoint Operations (for pause/resume support) + # ------------------------------------------------------------------------- + + @abstractmethod + async def save_checkpoint( + self, + task_id: UUID, + checkpoint_data: dict[str, Any], + step_number: int = 0, + step_label: str | None = None, + ) -> None: + """Save a checkpoint for task pause/resume. + + Stores execution state that can be restored when resuming a paused task. + + Args: + task_id: Task to save checkpoint for + checkpoint_data: Execution state to persist + step_number: Current step in execution + step_label: Optional label for current step + """ + + @abstractmethod + async def get_checkpoint(self, task_id: UUID) -> dict[str, Any] | None: + """Load the latest checkpoint for a task. + + Args: + task_id: Task to load checkpoint for + + Returns: + Checkpoint data if found, None otherwise + """ + + @abstractmethod + async def delete_checkpoint(self, task_id: UUID) -> None: + """Delete checkpoint(s) for a task. + + Args: + task_id: Task to delete checkpoint(s) for + """ + + async def cleanup_old_checkpoints(self, days_old: int = 7) -> int: + """Delete checkpoints older than specified days. 
+ + Args: + days_old: Delete checkpoints older than this many days + + Returns: + Number of checkpoints deleted + """ + return 0 diff --git a/bindu/server/storage/memory_storage.py b/bindu/server/storage/memory_storage.py index 06de7d72..781b541b 100644 --- a/bindu/server/storage/memory_storage.py +++ b/bindu/server/storage/memory_storage.py @@ -49,17 +49,16 @@ class InMemoryStorage(Storage[ContextT]): - tasks: Dict[UUID, Task] - All tasks indexed by task_id - contexts: Dict[UUID, list[UUID]] - Task IDs grouped by context_id - task_feedback: Dict[UUID, List[dict]] - Optional feedback storage + - _checkpoints: Dict[UUID, list] - Checkpoints for pause/resume """ def __init__(self): - """Initialize in-memory storage. - - Note: This is an __init__ method. - """ + """Initialize in-memory storage.""" self.tasks: dict[UUID, Task] = {} self.contexts: dict[UUID, list[UUID]] = {} self.task_feedback: dict[UUID, list[dict[str, Any]]] = {} self._webhook_configs: dict[UUID, PushNotificationConfig] = {} + self._checkpoints: dict[UUID, list[dict[str, Any]]] = {} @retry_storage_operation(max_attempts=3, min_wait=0.1, max_wait=1) async def load_task( @@ -81,10 +80,8 @@ async def load_task( if task is None: return None - # Always return a deep copy to prevent mutations affecting stored task task_copy = cast(Task, copy.deepcopy(task)) - # Limit history if requested if history_length is not None and history_length > 0 and "history" in task: task_copy["history"] = task["history"][-history_length:] @@ -92,28 +89,10 @@ async def load_task( @retry_storage_operation(max_attempts=3, min_wait=0.1, max_wait=1) async def submit_task(self, context_id: UUID, message: Message) -> Task: - """Create a new task or continue an existing non-terminal task. 
- - Task-First Pattern (Bindu): - - If task exists and is in non-terminal state: Append message and reset to 'submitted' - - If task exists and is in terminal state: Raise error (immutable) - - If task doesn't exist: Create new task - - Args: - context_id: Context to associate the task with - message: Initial message containing task request - - Returns: - Task in 'submitted' state (new or continued) - - Raises: - TypeError: If IDs are invalid types - ValueError: If attempting to continue a terminal task - """ + """Create a new task or continue an existing non-terminal task.""" if not isinstance(context_id, UUID): raise TypeError(f"context_id must be UUID, got {type(context_id).__name__}") - # Parse task ID from message (handle both snake_case and camelCase) task_id_raw = message.get("task_id") task_id: UUID @@ -126,7 +105,6 @@ async def submit_task(self, context_id: UUID, message: Message) -> Task: f"task_id must be UUID or str, got {type(task_id_raw).__name__}" ) - # Ensure all UUID fields are proper UUID objects (normalize to snake_case) message["task_id"] = task_id message["context_id"] = context_id @@ -138,7 +116,6 @@ async def submit_task(self, context_id: UUID, message: Message) -> Task: f"message_id must be UUID or str, got {type(message_id_raw).__name__}" ) - # Validate and normalize reference_task_ids if present (handle both formats) ref_ids_key = "reference_task_ids" if ref_ids_key in message: ref_ids = message[ref_ids_key] @@ -155,21 +132,17 @@ async def submit_task(self, context_id: UUID, message: Message) -> Task: ) message["reference_task_ids"] = normalized_refs - # Check if task already exists existing_task = self.tasks.get(task_id) if existing_task: - # Task exists - check if it's mutable current_state = existing_task["status"]["state"] - # Check if task is in terminal state (immutable) if current_state in app_settings.agent.terminal_states: raise ValueError( f"Cannot continue task {task_id}: Task is in terminal state '{current_state}' and is 
immutable. " f"Create a new task with referenceTaskIds to continue the conversation." ) - # Non-terminal states (mutable) - append message and continue logger.info( f"Continuing existing task {task_id} from state '{current_state}'" ) @@ -178,14 +151,12 @@ async def submit_task(self, context_id: UUID, message: Message) -> Task: existing_task["history"] = [] existing_task["history"].append(message) - # Reset to submitted state for re-execution existing_task["status"] = TaskStatus( state="submitted", timestamp=datetime.now(timezone.utc).isoformat() ) return existing_task - # Task doesn't exist - create new task task_status = TaskStatus( state="submitted", timestamp=datetime.now(timezone.utc).isoformat() ) @@ -198,7 +169,6 @@ async def submit_task(self, context_id: UUID, message: Message) -> Task: ) self.tasks[task_id] = task - # Add task to context if context_id not in self.contexts: self.contexts[context_id] = [] self.contexts[context_id].append(task_id) @@ -214,26 +184,7 @@ async def update_task( new_messages: list[Message] | None = None, metadata: dict[str, Any] | None = None, ) -> Task: - """Update task state and append new content. - - Hybrid Pattern Support: - - Message only: update_task(task_id, "input-required", new_messages=[...], metadata={...}) - - Completion: update_task(task_id, "completed", new_artifacts=[...], new_messages=[...]) - - Args: - task_id: Task to update - state: New task state (working, completed, failed, etc.) 
- new_artifacts: Optional artifacts to append (for completion) - new_messages: Optional messages to append to history - metadata: Optional metadata to update/merge with task metadata - - Returns: - Updated task object - - Raises: - TypeError: If task_id is not UUID - KeyError: If task not found - """ + """Update task state and append new content.""" if not isinstance(task_id, UUID): raise TypeError(f"task_id must be UUID, got {type(task_id).__name__}") @@ -258,7 +209,6 @@ async def update_task( if new_messages: if "history" not in task: task["history"] = [] - # Add IDs to messages for consistency for message in new_messages: if not isinstance(message, dict): raise TypeError( @@ -271,36 +221,12 @@ async def update_task( return task async def update_context(self, context_id: UUID, context: ContextT) -> None: - """Store or update context metadata. - - Note: This stores additional context metadata. Task associations are - managed automatically via submit_task(). - - Args: - context_id: Context identifier - context: Context data (format determined by agent implementation) - - Raises: - TypeError: If context_id is not UUID - """ + """Store or update context metadata.""" if not isinstance(context_id, UUID): raise TypeError(f"context_id must be UUID, got {type(context_id).__name__}") - # Note: This method is kept for backward compatibility but contexts - # are now primarily managed as task lists - async def load_context(self, context_id: UUID) -> list[UUID] | None: - """Load context task list from storage. 
- - Args: - context_id: Unique identifier of the context - - Returns: - List of task UUIDs if context exists, None otherwise - - Raises: - TypeError: If context_id is not UUID - """ + """Load context task list from storage.""" if not isinstance(context_id, UUID): raise TypeError(f"context_id must be UUID, got {type(context_id).__name__}") @@ -309,51 +235,23 @@ async def load_context(self, context_id: UUID) -> list[UUID] | None: async def append_to_contexts( self, context_id: UUID, messages: list[Message] ) -> None: - """Append messages to context history. - - Note: This method is deprecated as contexts now store task lists. - Messages are stored in task history instead. - - Args: - context_id: Context to update - messages: Messages to append to history - - Raises: - TypeError: If context_id is not UUID or messages is not a list - """ + """Append messages to context history.""" if not isinstance(context_id, UUID): raise TypeError(f"context_id must be UUID, got {type(context_id).__name__}") if not isinstance(messages, list): raise TypeError(f"messages must be list, got {type(messages).__name__}") - self.contexts[context_id] = [] - async def list_tasks(self, length: int | None = None) -> list[Task]: - """List all tasks in storage. - - Args: - length: Optional limit on number of tasks to return (most recent) - - Returns: - List of tasks - """ + """List all tasks in storage.""" if length is None: return list(self.tasks.values()) - # Optimize: Only convert to list what we need all_tasks = list(self.tasks.values()) return all_tasks[-length:] if length < len(all_tasks) else all_tasks async def count_tasks(self, status: str | None = None) -> int: - """Count number of tasks, optionally filtered by status. 
- - Args: - status: Optional status to filter by - - Returns: - Count of matching tasks - """ + """Count number of tasks, optionally filtered by status.""" if status is None: return len(self.tasks) @@ -362,24 +260,10 @@ async def count_tasks(self, status: str | None = None) -> int: async def list_tasks_by_context( self, context_id: UUID, length: int | None = None ) -> list[Task]: - """List tasks belonging to a specific context. - - Used for building conversation history and supporting task refinements. - - Args: - context_id: Context to filter tasks by - length: Optional limit on number of tasks to return (most recent) - - Returns: - List of tasks in the context - - Raises: - TypeError: If context_id is not UUID - """ + """List tasks belonging to a specific context.""" if not isinstance(context_id, UUID): raise TypeError(f"context_id must be UUID, got {type(context_id).__name__}") - # Get task IDs from context task_ids = self.contexts.get(context_id, []) tasks: list[Task] = [ self.tasks[task_id] for task_id in task_ids if task_id in self.tasks @@ -390,14 +274,7 @@ async def list_tasks_by_context( return tasks async def list_contexts(self, length: int | None = None) -> list[dict[str, Any]]: - """List all contexts in storage. - - Args: - length: Optional limit on number of contexts to return (most recent) - - Returns: - List of context objects with task counts - """ + """List all contexts in storage.""" contexts = [ {"context_id": ctx_id, "task_count": len(task_ids), "task_ids": task_ids} for ctx_id, task_ids in self.contexts.items() @@ -408,62 +285,37 @@ async def list_contexts(self, length: int | None = None) -> list[dict[str, Any]] return contexts async def clear_context(self, context_id: UUID) -> None: - """Clear all tasks associated with a specific context. - - Args: - context_id: The context ID to clear - - Raises: - TypeError: If context_id is not UUID - ValueError: If context does not exist - - Warning: This is a destructive operation. 
- """ + """Clear all tasks associated with a specific context.""" if not isinstance(context_id, UUID): raise TypeError(f"context_id must be UUID, got {type(context_id).__name__}") - # Check if context exists if context_id not in self.contexts: raise ValueError(f"Context {context_id} not found") - # Get task IDs from the context task_ids = self.contexts.get(context_id, []) - # Remove all tasks associated with this context for task_id in task_ids: if task_id in self.tasks: del self.tasks[task_id] - # Also clear feedback for these tasks if task_id in self.task_feedback: del self.task_feedback[task_id] - # Remove the context itself del self.contexts[context_id] logger.info(f"Cleared context {context_id}: removed {len(task_ids)} tasks") async def clear_all(self) -> None: - """Clear all tasks and contexts from storage. - - Warning: This is a destructive operation. - """ + """Clear all tasks and contexts from storage.""" self.tasks.clear() self.contexts.clear() self.task_feedback.clear() self._webhook_configs.clear() + self._checkpoints.clear() async def store_task_feedback( self, task_id: UUID, feedback_data: dict[str, Any] ) -> None: - """Store user feedback for a task. - - Args: - task_id: Task to associate feedback with - feedback_data: Feedback content (rating, comments, etc.) - - Raises: - TypeError: If task_id is not UUID or feedback_data is not dict - """ + """Store user feedback for a task.""" if not isinstance(task_id, UUID): raise TypeError(f"task_id must be UUID, got {type(task_id).__name__}") @@ -477,38 +329,16 @@ async def store_task_feedback( self.task_feedback[task_id].append(feedback_data) async def get_task_feedback(self, task_id: UUID) -> list[dict[str, Any]] | None: - """Retrieve feedback for a task. 
- - Args: - task_id: Task to get feedback for - - Returns: - List of feedback entries or None if no feedback exists - - Raises: - TypeError: If task_id is not UUID - """ + """Retrieve feedback for a task.""" if not isinstance(task_id, UUID): raise TypeError(f"task_id must be UUID, got {type(task_id).__name__}") return self.task_feedback.get(task_id) - # ------------------------------------------------------------------------- - # Webhook Persistence Operations (for long-running tasks) - # ------------------------------------------------------------------------- - async def save_webhook_config( self, task_id: UUID, config: PushNotificationConfig ) -> None: - """Save a webhook configuration for a task. - - Args: - task_id: Task to associate the webhook config with - config: Push notification configuration to persist - - Raises: - TypeError: If task_id is not UUID - """ + """Save a webhook configuration for a task.""" if not isinstance(task_id, UUID): raise TypeError(f"task_id must be UUID, got {type(task_id).__name__}") @@ -516,33 +346,14 @@ async def save_webhook_config( logger.debug(f"Saved webhook config for task {task_id}") async def load_webhook_config(self, task_id: UUID) -> PushNotificationConfig | None: - """Load a webhook configuration for a task. - - Args: - task_id: Task to load the webhook config for - - Returns: - The webhook configuration if found, None otherwise - - Raises: - TypeError: If task_id is not UUID - """ + """Load a webhook configuration for a task.""" if not isinstance(task_id, UUID): raise TypeError(f"task_id must be UUID, got {type(task_id).__name__}") return self._webhook_configs.get(task_id) async def delete_webhook_config(self, task_id: UUID) -> None: - """Delete a webhook configuration for a task. - - Args: - task_id: Task to delete the webhook config for - - Raises: - TypeError: If task_id is not UUID - - Note: Does not raise if the config doesn't exist. 
- """ + """Delete a webhook configuration for a task.""" if not isinstance(task_id, UUID): raise TypeError(f"task_id must be UUID, got {type(task_id).__name__}") @@ -551,9 +362,79 @@ async def delete_webhook_config(self, task_id: UUID) -> None: logger.debug(f"Deleted webhook config for task {task_id}") async def load_all_webhook_configs(self) -> dict[UUID, PushNotificationConfig]: - """Load all stored webhook configurations. - - Returns: - Dictionary mapping task IDs to their webhook configurations - """ + """Load all stored webhook configurations.""" return dict(self._webhook_configs) + + async def save_checkpoint( + self, + task_id: UUID, + checkpoint_data: dict[str, Any], + step_number: int = 0, + step_label: str | None = None, + ) -> None: + """Append a deep-copied snapshot of checkpoint_data to the task's checkpoints.""" + if not isinstance(task_id, UUID): + raise TypeError(f"task_id must be UUID, got {type(task_id).__name__}") + + if not isinstance(checkpoint_data, dict): + raise TypeError( + f"checkpoint_data must be dict, got {type(checkpoint_data).__name__}" + ) + + if task_id not in self._checkpoints: + self._checkpoints[task_id] = [] + + self._checkpoints[task_id].append( + { + "checkpoint_data": copy.deepcopy(checkpoint_data), + "step_number": step_number, + "step_label": step_label, + "created_at": datetime.now(timezone.utc).isoformat(), + } + ) + logger.debug(f"Saved checkpoint for task {task_id} at step {step_number}") + + async def get_checkpoint(self, task_id: UUID) -> dict[str, Any] | None: + """Return a deep copy of the task's most recent checkpoint, or None.""" + if not isinstance(task_id, UUID): + raise TypeError(f"task_id must be UUID, got {type(task_id).__name__}") + + checkpoints = self._checkpoints.get(task_id, []) + if not checkpoints: + return None + + return copy.deepcopy(checkpoints[-1]) + + async def delete_checkpoint(self, task_id: UUID) -> None: + """Delete checkpoint(s) for a task.""" + if not isinstance(task_id, UUID): + raise TypeError(f"task_id must be UUID, got {type(task_id).__name__}") + + if task_id in 
self._checkpoints: + del self._checkpoints[task_id] + logger.debug(f"Deleted checkpoint(s) for task {task_id}") + + async def list_checkpoints( + self, task_id: UUID | None = None, limit: int = 100 + ) -> list[dict[str, Any]]: + """List checkpoints, optionally filtered by task_id.""" + results = [] + + for cp_task_id, checkpoints in self._checkpoints.items(): + if task_id is not None and cp_task_id != task_id: + continue + for cp in checkpoints: + results.append( + { + "task_id": cp_task_id, + "checkpoint_data": cp.get("checkpoint_data", {}), + "step_number": cp.get("step_number", 0), + "step_label": cp.get("step_label"), + "created_at": cp.get("created_at"), + } + ) + + # Sort by created_at descending (most recent first) + results.sort(key=lambda x: x.get("created_at", ""), reverse=True) + + return results[:limit] diff --git a/bindu/server/storage/postgres_storage.py b/bindu/server/storage/postgres_storage.py index 9a7efc63..8486317f 100644 --- a/bindu/server/storage/postgres_storage.py +++ b/bindu/server/storage/postgres_storage.py @@ -57,6 +57,7 @@ from .helpers.db_operations import get_current_utc_timestamp from .schema import ( contexts_table, + task_checkpoints_table, task_feedback_table, tasks_table, webhook_configs_table, @@ -1019,3 +1020,191 @@ async def _load_all(): return {row.task_id: row.config for row in rows} return await self._retry_on_connection_error(_load_all) + + # ------------------------------------------------------------------------- + # Checkpoint Operations (for pause/resume support) + # ------------------------------------------------------------------------- + + async def save_checkpoint( + self, + task_id: UUID, + checkpoint_data: dict[str, Any], + step_number: int = 0, + step_label: str | None = None, + ) -> None: + """Save a checkpoint for task pause/resume using SQLAlchemy. + + Stores execution state that can be restored when resuming a paused task. 
+ + Args: + task_id: Task to save checkpoint for + checkpoint_data: Execution state to persist + step_number: Current step in execution + step_label: Optional label for current step + + Raises: + TypeError: If task_id is not UUID or checkpoint_data is not dict + """ + task_id = validate_uuid_type(task_id, "task_id") + + if not isinstance(checkpoint_data, dict): + raise TypeError( + f"checkpoint_data must be dict, got {type(checkpoint_data).__name__}" + ) + + self._ensure_connected() + + async def _save(): + async with self._get_session_with_schema() as session: + async with session.begin(): + serialized_data = serialize_for_jsonb(checkpoint_data) + stmt = insert(task_checkpoints_table).values( + task_id=task_id, + checkpoint_data=serialized_data, + step_number=step_number, + step_label=step_label, + ) + await session.execute(stmt) + logger.debug( + f"Saved checkpoint for task {task_id} at step {step_number}" + ) + + await self._retry_on_connection_error(_save) + + async def get_checkpoint(self, task_id: UUID) -> dict[str, Any] | None: + """Load the latest checkpoint for a task using SQLAlchemy. 
+ + Args: + task_id: Task to load checkpoint for + + Returns: + Checkpoint data if found, None otherwise + + Raises: + TypeError: If task_id is not UUID + """ + task_id = validate_uuid_type(task_id, "task_id") + + self._ensure_connected() + + async def _get(): + async with self._get_session_with_schema() as session: + stmt = ( + select(task_checkpoints_table) + .where(task_checkpoints_table.c.task_id == task_id) + .order_by(task_checkpoints_table.c.created_at.desc()) + .limit(1) + ) + result = await session.execute(stmt) + row = result.first() + + if row is None: + return None + + return { + "checkpoint_data": row.checkpoint_data, + "step_number": row.step_number, + "step_label": row.step_label, + "created_at": row.created_at.isoformat() + if row.created_at + else None, + } + + return await self._retry_on_connection_error(_get) + + async def delete_checkpoint(self, task_id: UUID) -> None: + """Delete checkpoint(s) for a task using SQLAlchemy. + + Args: + task_id: Task to delete checkpoint(s) for + + Raises: + TypeError: If task_id is not UUID + """ + task_id = validate_uuid_type(task_id, "task_id") + + self._ensure_connected() + + async def _delete(): + async with self._get_session_with_schema() as session: + async with session.begin(): + stmt = delete(task_checkpoints_table).where( + task_checkpoints_table.c.task_id == task_id + ) + result = await session.execute(stmt) + if result.rowcount > 0: + logger.debug(f"Deleted checkpoint(s) for task {task_id}") + + await self._retry_on_connection_error(_delete) + + async def list_checkpoints( + self, task_id: UUID | None = None, limit: int = 100 + ) -> list[dict[str, Any]]: + """List checkpoints, optionally filtered by task_id. 
+ + Args: + task_id: Optional task ID to filter by + limit: Maximum number of checkpoints to return + + Returns: + List of checkpoint records + """ + self._ensure_connected() + + async def _list(): + async with self._get_session_with_schema() as session: + stmt = select(task_checkpoints_table).order_by( + task_checkpoints_table.c.created_at.desc() + ) + + if task_id is not None: + validated_id = validate_uuid_type(task_id, "task_id") + stmt = stmt.where(task_checkpoints_table.c.task_id == validated_id) + + stmt = stmt.limit(limit) + result = await session.execute(stmt) + rows = result.fetchall() + + return [ + { + "id": row.id, + "task_id": row.task_id, + "checkpoint_data": row.checkpoint_data, + "step_number": row.step_number, + "step_label": row.step_label, + "created_at": row.created_at.isoformat() + if row.created_at + else None, + } + for row in rows + ] + + return await self._retry_on_connection_error(_list) + + async def cleanup_old_checkpoints(self, days_old: int = 7) -> int: + """Delete checkpoints older than specified days. 
+ + Args: + days_old: Delete checkpoints older than this many days + + Returns: + Number of checkpoints deleted + """ + self._ensure_connected() + + async def _cleanup(): + async with self._get_session_with_schema() as session: + async with session.begin(): + stmt = delete(task_checkpoints_table).where( + task_checkpoints_table.c.created_at + < func.now() - func.make_interval(0, 0, 0, days_old) + ) + result = await session.execute(stmt) + deleted_count = result.rowcount + if deleted_count > 0: + logger.info( + f"Cleaned up {deleted_count} checkpoints older than {days_old} days" + ) + return deleted_count + + return await self._retry_on_connection_error(_cleanup) diff --git a/bindu/server/storage/schema.py b/bindu/server/storage/schema.py index 2e678fa9..e7f3d496 100644 --- a/bindu/server/storage/schema.py +++ b/bindu/server/storage/schema.py @@ -187,6 +187,48 @@ comment="Webhook configurations for long-running task notifications", ) +# ----------------------------------------------------------------------------- +# Task Checkpoints Table (for pause/resume support) +# ----------------------------------------------------------------------------- + +task_checkpoints_table = Table( + "task_checkpoints", + metadata, + # Primary key + Column("id", Integer, primary_key=True, autoincrement=True, nullable=False), + # Foreign key to task + Column( + "task_id", + PG_UUID(as_uuid=True), + ForeignKey("tasks.id", ondelete="CASCADE"), + nullable=False, + ), + # Checkpoint data - stores execution state for resume + Column("checkpoint_data", JSONB, nullable=False), + # Step info for tracking progress + Column("step_number", Integer, nullable=False, default=0), + Column("step_label", String(255), nullable=True), + # Timestamps + Column( + "created_at", + TIMESTAMP(timezone=True), + nullable=False, + server_default=func.now(), + ), + Column( + "updated_at", + TIMESTAMP(timezone=True), + nullable=False, + server_default=func.now(), + onupdate=func.now(), + ), + # Indexes + 
Index("idx_task_checkpoints_task_id", "task_id"), + Index("idx_task_checkpoints_created_at", "created_at"), + # Table comment + comment="Task checkpoints for pause/resume support", +) + # ----------------------------------------------------------------------------- # Helper Functions # ----------------------------------------------------------------------------- diff --git a/bindu/server/workers/base.py b/bindu/server/workers/base.py index 1ff58950..c0c8d86c 100644 --- a/bindu/server/workers/base.py +++ b/bindu/server/workers/base.py @@ -31,6 +31,7 @@ from bindu.common.protocol.types import Artifact, Message, TaskIdParams, TaskSendParams from bindu.server.scheduler.base import Scheduler from bindu.server.storage.base import Storage +from bindu.settings import app_settings from bindu.utils.logging import get_logger tracer = get_tracer(__name__) @@ -215,25 +216,130 @@ def build_artifacts(self, result: Any) -> list[Artifact]: ... # ------------------------------------------------------------------------- - # Future Operations (Not Yet Implemented) + # Pause/Resume Operations # ------------------------------------------------------------------------- async def _handle_pause(self, params: TaskIdParams) -> None: - """Handle pause operation. + """Handle pause operation - suspend task execution with checkpoint. - TODO: Implement task pause functionality - - Save current execution state - - Update task to 'suspended' state - - Release resources while preserving context + Saves current execution state and updates task to 'suspended' state. + The task can be resumed later from the checkpoint. 
+ + Args: + params: Task identification parameters containing task_id """ - raise NotImplementedError("Pause operation not yet implemented") + from opentelemetry.trace import get_current_span + + task_id = params["task_id"] + task = await self.storage.load_task(task_id) + + if task is None: + logger.warning(f"Cannot pause task {task_id}: task not found") + return + + current_state = task["status"]["state"] + + # Check if task can be paused + if current_state not in app_settings.agent.pausable_states: + logger.warning( + f"Cannot pause task {task_id}: task is in '{current_state}' state, " + f"which is not pausable. Pausable states: {app_settings.agent.pausable_states}" + ) + return + + # Add span event for pause + current_span = get_current_span() + if current_span.is_recording(): + current_span.add_event( + "task.state_changed", + attributes={ + "from_state": current_state, + "to_state": "suspended", + "operation": "pause", + }, + ) + + # Save checkpoint with current task state + checkpoint_data = { + "task_state": current_state, + "history": task.get("history", []), + "artifacts": task.get("artifacts", []), + "metadata": task.get("metadata", {}), + } + await self.storage.save_checkpoint( + task_id=task_id, + checkpoint_data=checkpoint_data, + step_number=0, + step_label="paused", + ) + + # Update task to suspended state + await self.storage.update_task(task_id, state="suspended") + await self._notify_lifecycle(task_id, task["context_id"], "suspended", False) + + logger.info(f"Task {task_id} paused at checkpoint") async def _handle_resume(self, params: TaskIdParams) -> None: - """Handle resume operation. + """Handle resume operation - restore task from checkpoint. - TODO: Implement task resume functionality - - Restore execution state - - Update task to 'resumed' state - - Continue from last checkpoint + Loads checkpoint data and updates task to 'working' state. + The task can then be picked up by the scheduler for execution. 
+ + Args: + params: Task identification parameters containing task_id """ - raise NotImplementedError("Resume operation not yet implemented") + from opentelemetry.trace import get_current_span + + task_id = params["task_id"] + task = await self.storage.load_task(task_id) + + if task is None: + logger.warning(f"Cannot resume task {task_id}: task not found") + return + + current_state = task["status"]["state"] + + # Check if task is in suspended state + if current_state != "suspended": + logger.warning( + f"Cannot resume task {task_id}: task is in '{current_state}' state, " + f"only suspended tasks can be resumed" + ) + return + + # Load checkpoint + checkpoint = await self.storage.get_checkpoint(task_id) + if checkpoint is None: + logger.warning(f"Cannot resume task {task_id}: no checkpoint found") + return + + # Add span event for resume + current_span = get_current_span() + if current_span.is_recording(): + current_span.add_event( + "task.state_changed", + attributes={ + "from_state": "suspended", + "to_state": "working", + "operation": "resume", + "step_number": checkpoint.get("step_number", 0), + }, + ) + + # Update task metadata with checkpoint info + checkpoint_info = { + "resumed_from_checkpoint": True, + "checkpoint_step": checkpoint.get("step_number", 0), + "checkpoint_label": checkpoint.get("step_label"), + } + await self.storage.update_task( + task_id, + state="working", + metadata=checkpoint_info, + ) + + await self._notify_lifecycle(task_id, task["context_id"], "working", False) + + logger.info( + f"Task {task_id} resumed from checkpoint at step {checkpoint.get('step_number', 0)}" + ) diff --git a/bindu/server/workers/manifest_worker.py b/bindu/server/workers/manifest_worker.py index ffb13dd7..62d58c06 100644 --- a/bindu/server/workers/manifest_worker.py +++ b/bindu/server/workers/manifest_worker.py @@ -468,6 +468,9 @@ async def _handle_terminal_state( await self._notify_lifecycle(task["id"], task["context_id"], state, True) + # Clean up checkpoint 
after successful completion + await self.storage.delete_checkpoint(task["id"]) + elif state in ("failed", "rejected"): # Failure/Rejection: Message only (explanation), NO artifacts error_message = MessageConverter.to_protocol_messages( @@ -481,11 +484,17 @@ async def _handle_terminal_state( ) await self._notify_lifecycle(task["id"], task["context_id"], state, True) + # Clean up checkpoint after failure + await self.storage.delete_checkpoint(task["id"]) + elif state == "canceled": # Canceled: State change only, NO new content await self.storage.update_task(task["id"], state=state) await self._notify_lifecycle(task["id"], task["context_id"], state, True) + # Clean up checkpoint after cancellation + await self.storage.delete_checkpoint(task["id"]) + async def _handle_task_failure(self, task: dict[str, Any], error: str) -> None: """Handle task execution failure. diff --git a/bindu/settings.py b/bindu/settings.py index 4d94d4bd..a180bbbe 100644 --- a/bindu/settings.py +++ b/bindu/settings.py @@ -378,6 +378,7 @@ class AgentSettings(BaseSettings): "working", # Agent actively processing "input-required", # Waiting for user input "auth-required", # Waiting for authentication + "suspended", # Task is paused and can be resumed } ) @@ -391,6 +392,15 @@ class AgentSettings(BaseSettings): } ) + # Pausable states: Tasks can be paused from these states + pausable_states: frozenset[str] = frozenset( + { + "submitted", # Task submitted, awaiting execution + "working", # Agent actively processing + "input-required", # Waiting for user input + } + ) + # message/stream polling behavior stream_poll_interval_seconds: float = 0.1 stream_missing_task_retries: int = 2 diff --git a/docs/CHECKPOINTS.md b/docs/CHECKPOINTS.md new file mode 100644 index 00000000..99115bad --- /dev/null +++ b/docs/CHECKPOINTS.md @@ -0,0 +1,170 @@ +# Checkpoint System + +The checkpoint system enables pause/resume functionality for long-running tasks, allowing them to be suspended and resumed later from the same 
point without losing progress. + +## Overview + +Checkpoints store the execution state of a task at a specific point in time. When a task is paused, its current state is saved to persistent storage. When the task is resumed, the checkpoint is retrieved and the task continues from where it left off. + +## Checkpoint Data Structure + +```python +{ + "task_id": "uuid", + "checkpoint_data": { + "task_state": "working", + "history": [...], + "artifacts": [...], + "metadata": {...} + }, + "step_number": 0, + "step_label": "processing", + "created_at": "2024-01-01T00:00:00Z" +} +``` + +### Fields + +| Field | Type | Description | +|-------|------|-------------| +| `task_id` | UUID | Unique identifier of the task | +| `checkpoint_data` | JSON | Execution state including task state, history, artifacts, and metadata | +| `step_number` | int | Current step in execution (0-based) | +| `step_label` | str | Optional label for the current step | +| `created_at` | datetime | When the checkpoint was created | + +## Checkpoint Lifecycle + +### 1. Creation (Pause) + +When a task is paused: + +1. **Validation**: Check if task is in a pausable state (`submitted`, `working`, `input-required`) +2. **Save**: Create checkpoint with current task state +3. **Update State**: Change task state to `suspended` +4. **Notify**: Trigger lifecycle notification + +``` +working → pause → suspended +``` + +### 2. Retrieval (Resume) + +When a task is resumed: + +1. **Validation**: Check if task is in `suspended` state +2. **Load**: Retrieve latest checkpoint +3. **Update State**: Change task state to `working` +4. **Restore**: Apply checkpoint metadata to task + +``` +suspended → resume → working +``` + +### 3. 
Cleanup + +Checkpoints are automatically deleted when: + +- Task reaches terminal state: `completed`, `failed`, `canceled`, `rejected` +- Task is explicitly deleted + +## Storage + +### PostgreSQL + +Checkpoints are stored in the `task_checkpoints` table: + +```sql +CREATE TABLE task_checkpoints ( + id SERIAL PRIMARY KEY, + task_id UUID NOT NULL REFERENCES tasks(id) ON DELETE CASCADE, + checkpoint_data JSONB NOT NULL, + step_number INTEGER NOT NULL DEFAULT 0, + step_label VARCHAR(255), + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW() +); + +CREATE INDEX idx_task_checkpoints_task_id ON task_checkpoints(task_id); +CREATE INDEX idx_task_checkpoints_created_at ON task_checkpoints(created_at DESC); +``` + +### Memory Storage + +For development/testing, in-memory storage maintains checkpoints in a dictionary. + +## API + +### Storage Interface + +```python +# Save a checkpoint +await storage.save_checkpoint( + task_id: UUID, + checkpoint_data: dict, + step_number: int = 0, + step_label: str | None = None +) + +# Get latest checkpoint +checkpoint = await storage.get_checkpoint(task_id: UUID) +# Returns: {checkpoint_data, step_number, step_label, created_at} or None + +# Delete checkpoint(s) +await storage.delete_checkpoint(task_id: UUID) + +# List checkpoints (optionally filtered by task_id) +checkpoints = await storage.list_checkpoints(task_id: UUID | None = None, limit: int = 100) + +# Cleanup old checkpoints +deleted_count = await storage.cleanup_old_checkpoints(days_old: int = 7) +``` + +### Worker Interface + +```python +# Pause a task +await worker._handle_pause(TaskIdParams(task_id=...)) + +# Resume a task +await worker._handle_resume(TaskIdParams(task_id=...)) +``` + +## Configuration + +### Pausable States + +Tasks can only be paused from these states: + +```python +pausable_states = frozenset({ + "submitted", # Task submitted, awaiting execution + "working", # Agent actively processing + 
"input-required" # Waiting for user input +}) +``` + +## Best Practices + +1. **Frequent Checkpoints**: For long-running tasks, save checkpoints periodically during execution +2. **Minimal Data**: Store only essential state in checkpoints to minimize storage +3. **Cleanup**: Ensure checkpoints are cleaned up after task completion +4. **Idempotency**: Pause/resume operations should be idempotent + +## Example Usage + +```python +# Pause a task +await worker._handle_pause({"task_id": task_id}) +# Task is now suspended, checkpoint saved + +# Resume a task +await worker._handle_resume({"task_id": task_id}) +# Task is now working, continues from checkpoint +``` + +## Error Handling + +- **Pause non-pausable task**: Silently returns without action (logs warning) +- **Resume without checkpoint**: Silently returns without action (logs warning) +- **Resume non-suspended task**: Silently returns without action (logs warning) diff --git a/docs/STATE_TRANSITIONS.md b/docs/STATE_TRANSITIONS.md new file mode 100644 index 00000000..c853ab8f --- /dev/null +++ b/docs/STATE_TRANSITIONS.md @@ -0,0 +1,231 @@ +# Task State Machine + +This document describes the task state machine in Bindu, following the A2A Protocol. + +## States + +### Non-Terminal States + +Tasks in these states are mutable and can receive new messages or be modified. + +| State | Description | Can Pause | Can Cancel | +|-------|-------------|-----------|------------| +| `submitted` | Task submitted, awaiting execution | Yes | Yes | +| `working` | Agent actively processing | Yes | Yes | +| `input-required` | Waiting for user input | Yes | Yes | +| `auth-required` | Waiting for authentication | No | Yes | +| `suspended` | Task paused, can be resumed | No | Yes | + +### Terminal States + +Tasks in these states are immutable - no further changes allowed. 
+ +| State | Description | +|-------|-------------| +| `completed` | Successfully completed with artifacts | +| `failed` | Failed due to error | +| `canceled` | Canceled by user | +| `rejected` | Rejected by agent | + +## State Transitions + +### Valid Transitions + +``` +submitted ──→ working ──→ completed + │ │ + │ ├──→ input-required ──→ working + │ │ + │ ├──→ suspended ──→ working (resume) + │ │ + │ ├──→ failed + │ │ + │ └──→ canceled + │ + ├──→ input-required ──→ working + │ + ├──→ suspended ──→ working + │ + ├──→ failed + │ + ├──→ canceled + │ + └──→ rejected +``` + +### Transition Rules + +| From State | To State | Valid | Notes | +|------------|----------|-------|-------| +| `submitted` | `working` | ✓ | Task starts processing | +| `submitted` | `input-required` | ✓ | Agent needs clarification | +| `submitted` | `suspended` | ✓ | Task paused before starting | +| `submitted` | `failed` | ✓ | Task failed before starting | +| `submitted` | `canceled` | ✓ | User canceled before starting | +| `submitted` | `rejected` | ✓ | Agent rejected | +| `working` | `completed` | ✓ | Task finished successfully | +| `working` | `input-required` | ✓ | Agent needs clarification | +| `working` | `suspended` | ✓ | Task paused during execution | +| `working` | `failed` | ✓ | Task failed with error | +| `working` | `canceled` | ✓ | User canceled | +| `input-required` | `working` | ✓ | User provided input | +| `input-required` | `failed` | ✓ | User failed to provide input | +| `input-required` | `canceled` | ✓ | User canceled | +| `suspended` | `working` | ✓ | Task resumed | +| `suspended` | `canceled` | ✓ | User canceled paused task | +| `suspended` | `failed` | ✓ | Task failed while paused | + +## Pausable States + +Tasks can only be paused from these states: + +```python +pausable_states = frozenset({ + "submitted", # Task submitted, awaiting execution + "working", # Agent actively processing + "input-required" # Waiting for user input +}) +``` + +## State Machine 
Configuration + +The state machine is configured in `settings.py`: + +```python +# Non-terminal states: Task is mutable +non_terminal_states = frozenset({ + "submitted", + "working", + "input-required", + "auth-required", + "suspended", +}) + +# Terminal states: Task is immutable +terminal_states = frozenset({ + "completed", + "failed", + "canceled", + "rejected", +}) + +# States from which tasks can be paused +pausable_states = frozenset({ + "submitted", + "working", + "input-required", +}) +``` + +## Pause/Resume Flow + +### Pause + +``` +1. Validate task is in pausable state +2. Save checkpoint with current task state +3. Update task state to 'suspended' +4. Trigger lifecycle notification +``` + +### Resume + +``` +1. Validate task is in 'suspended' state +2. Load checkpoint data +3. Update task state to 'working' +4. Apply checkpoint metadata +5. Trigger lifecycle notification +``` + +## Implementation + +### Storage Layer + +The storage layer maintains task state but doesn't enforce transition rules - that's handled at the application/worker layer. 
+ +```python +# Update task state +await storage.update_task(task_id, state="working") + +# Load task +task = await storage.load_task(task_id) +current_state = task["status"]["state"] +``` + +### Worker Layer + +The worker layer enforces state transition rules and handles pause/resume: + +```python +# Check if task can be paused +if current_state not in app_settings.agent.pausable_states: + return # Cannot pause + +# Pause - save checkpoint and update state +await storage.save_checkpoint(task_id, checkpoint_data) +await storage.update_task(task_id, state="suspended") +``` + +## Examples + +### Complete Workflow + +```python +# Submit task +await storage.submit_task(context_id, message) +# State: submitted + +# Start working +await storage.update_task(task_id, state="working") +# State: working + +# Agent needs clarification +await storage.update_task(task_id, state="input-required") +# State: input-required + +# User provides input +await storage.update_task(task_id, state="working") +# State: working + +# Task completes +await storage.update_task(task_id, state="completed") +# State: completed (terminal) +``` + +### Pause/Resume Workflow + +```python +# Submit and start +await storage.submit_task(context_id, message) +await storage.update_task(task_id, state="working") +# State: working + +# Pause +await storage.save_checkpoint(task_id, {...}) +await storage.update_task(task_id, state="suspended") +# State: suspended + +# Resume +await storage.update_task(task_id, state="working") +# State: working + +# Complete +await storage.update_task(task_id, state="completed") +# State: completed +``` + +## Testing + +Run state transition tests: + +```bash +pytest tests/integration/test_state_transitions.py -v +``` + +Test categories: +- State classification (suspended is non-terminal) +- Valid transitions +- Terminal state immutability +- Non-terminal state mutability +- Complete workflows diff --git a/tests/integration/test_checkpoint_persistence.py 
b/tests/integration/test_checkpoint_persistence.py new file mode 100644 index 00000000..5997e2a1 --- /dev/null +++ b/tests/integration/test_checkpoint_persistence.py @@ -0,0 +1,217 @@ +"""Integration tests for checkpoint persistence.""" + +import pytest +import pytest_asyncio +from uuid import uuid4 +from datetime import datetime, timezone, timedelta + +from bindu.server.storage.memory_storage import InMemoryStorage + + +@pytest_asyncio.fixture +async def storage() -> InMemoryStorage: + """Create a fresh in-memory storage for each test.""" + return InMemoryStorage() + + +@pytest.mark.asyncio +async def test_checkpoint_save_and_retrieve(storage: InMemoryStorage): + """Test basic checkpoint save and retrieve.""" + task_id = uuid4() + checkpoint_data = {"step": 1, "progress": 50, "data": {"key": "value"}} + + # Save checkpoint + await storage.save_checkpoint( + task_id=task_id, + checkpoint_data=checkpoint_data, + step_number=1, + step_label="step-1", + ) + + # Retrieve checkpoint + checkpoint = await storage.get_checkpoint(task_id) + + assert checkpoint is not None + assert checkpoint["checkpoint_data"] == checkpoint_data + assert checkpoint["step_number"] == 1 + assert checkpoint["step_label"] == "step-1" + + +@pytest.mark.asyncio +async def test_checkpoint_with_step_info(storage: InMemoryStorage): + """Test checkpoint stores step information correctly.""" + task_id = uuid4() + + # Save checkpoint with step info + await storage.save_checkpoint( + task_id=task_id, + checkpoint_data={"message": "test"}, + step_number=5, + step_label="processing", + ) + + checkpoint = await storage.get_checkpoint(task_id) + + assert checkpoint["step_number"] == 5 + assert checkpoint["step_label"] == "processing" + + +@pytest.mark.asyncio +async def test_multiple_checkpoints_returns_latest(storage: InMemoryStorage): + """Test that get_checkpoint returns the most recent checkpoint.""" + task_id = uuid4() + + # Save multiple checkpoints + await storage.save_checkpoint( + task_id=task_id, 
checkpoint_data={"step": 1}, step_number=1 + ) + await storage.save_checkpoint( + task_id=task_id, checkpoint_data={"step": 2}, step_number=2 + ) + await storage.save_checkpoint( + task_id=task_id, checkpoint_data={"step": 3}, step_number=3 + ) + + # Get should return the latest + checkpoint = await storage.get_checkpoint(task_id) + + assert checkpoint["checkpoint_data"]["step"] == 3 + assert checkpoint["step_number"] == 3 + + +@pytest.mark.asyncio +async def test_delete_checkpoint(storage: InMemoryStorage): + """Test checkpoint deletion.""" + task_id = uuid4() + + # Save checkpoint + await storage.save_checkpoint(task_id=task_id, checkpoint_data={"test": True}) + + # Verify exists + checkpoint = await storage.get_checkpoint(task_id) + assert checkpoint is not None + + # Delete + await storage.delete_checkpoint(task_id) + + # Verify deleted + checkpoint = await storage.get_checkpoint(task_id) + assert checkpoint is None + + +@pytest.mark.asyncio +async def test_delete_nonexistent_checkpoint(storage: InMemoryStorage): + """Test deleting nonexistent checkpoint doesn't raise.""" + task_id = uuid4() + + # Should not raise + await storage.delete_checkpoint(task_id) + + +@pytest.mark.asyncio +async def test_get_nonexistent_checkpoint(storage: InMemoryStorage): + """Test getting nonexistent checkpoint returns None.""" + task_id = uuid4() + + checkpoint = await storage.get_checkpoint(task_id) + assert checkpoint is None + + +@pytest.mark.asyncio +async def test_checkpoint_stores_task_id(storage: InMemoryStorage): + """Test that checkpoint correctly stores task_id for reference.""" + task_id = uuid4() + + await storage.save_checkpoint(task_id=task_id, checkpoint_data={"test": True}) + + checkpoints = await storage.list_checkpoints(task_id=task_id) + + assert len(checkpoints) == 1 + assert checkpoints[0]["task_id"] == task_id + + +@pytest.mark.asyncio +async def test_list_checkpoints_with_limit(storage: InMemoryStorage): + """Test listing checkpoints with limit.""" + task_id 
= uuid4() + + # Create multiple checkpoints + for i in range(10): + await storage.save_checkpoint( + task_id=task_id, checkpoint_data={"step": i}, step_number=i + ) + + # List with limit + checkpoints = await storage.list_checkpoints(limit=5) + + # Should return only 5 (most recent) + assert len(checkpoints) == 5 + + +@pytest.mark.asyncio +async def test_list_checkpoints_by_task_id(storage: InMemoryStorage): + """Test listing checkpoints filtered by task_id.""" + task_id_1 = uuid4() + task_id_2 = uuid4() + + # Create checkpoints for two tasks + await storage.save_checkpoint(task_id=task_id_1, checkpoint_data={"task": 1}) + await storage.save_checkpoint(task_id=task_id_2, checkpoint_data={"task": 2}) + await storage.save_checkpoint( + task_id=task_id_1, checkpoint_data={"task": 1, "step": 2} + ) + + # List checkpoints for task_id_1 only + checkpoints = await storage.list_checkpoints(task_id=task_id_1) + + assert len(checkpoints) == 2 + for cp in checkpoints: + assert cp["task_id"] == task_id_1 + + +@pytest.mark.asyncio +async def test_checkpoint_data_types(storage: InMemoryStorage): + """Test checkpoint handles various data types.""" + task_id = uuid4() + + complex_data = { + "string": "test", + "number": 42, + "float": 3.14, + "bool": True, + "null": None, + "list": [1, 2, 3], + "nested": {"a": {"b": {"c": 1}}}, + } + + await storage.save_checkpoint(task_id=task_id, checkpoint_data=complex_data) + + checkpoint = await storage.get_checkpoint(task_id) + assert checkpoint["checkpoint_data"] == complex_data + + +@pytest.mark.asyncio +async def test_checkpoint_empty_data(storage: InMemoryStorage): + """Test checkpoint with empty data.""" + task_id = uuid4() + + await storage.save_checkpoint(task_id=task_id, checkpoint_data={}) + + checkpoint = await storage.get_checkpoint(task_id) + assert checkpoint["checkpoint_data"] == {} + + +@pytest.mark.asyncio +async def test_checkpoint_with_uuid_objects(storage: InMemoryStorage): + """Test checkpoint correctly stores UUID 
objects.""" + task_id = uuid4() + nested_uuid = uuid4() + + await storage.save_checkpoint( + task_id=task_id, + checkpoint_data={"nested_uuid": str(nested_uuid), "uuid": str(task_id)}, + ) + + checkpoint = await storage.get_checkpoint(task_id) + assert checkpoint["checkpoint_data"]["uuid"] == str(task_id) + assert checkpoint["checkpoint_data"]["nested_uuid"] == str(nested_uuid) diff --git a/tests/integration/test_pause_resume.py b/tests/integration/test_pause_resume.py new file mode 100644 index 00000000..5082dd61 --- /dev/null +++ b/tests/integration/test_pause_resume.py @@ -0,0 +1,272 @@ +"""Integration tests for pause/resume flow with checkpoint support.""" + +import pytest +import pytest_asyncio +from uuid import uuid4 +from datetime import datetime, timezone + +from bindu.common.protocol.types import Task, TaskStatus +from bindu.server.storage.memory_storage import InMemoryStorage +from bindu.server.workers.base import Worker +from bindu.settings import app_settings + + +class DummyWorker(Worker): + """Dummy worker for testing pause/resume without actual agent execution.""" + + async def run_task(self, params): + pass + + async def cancel_task(self, params): + pass + + def build_message_history(self, history): + return [] + + def build_artifacts(self, result): + return [] + + async def _notify_lifecycle(self, task_id, context_id, state, final): + """Dummy lifecycle notification for testing.""" + pass + + +@pytest_asyncio.fixture +async def storage() -> InMemoryStorage: + """Create a fresh in-memory storage for each test.""" + return InMemoryStorage() + + +@pytest_asyncio.fixture +async def worker(storage: InMemoryStorage) -> DummyWorker: + """Create a dummy worker with test storage.""" + from bindu.server.scheduler.memory_scheduler import InMemoryScheduler + + scheduler = InMemoryScheduler() + return DummyWorker(scheduler=scheduler, storage=storage) + + +def create_test_message(task_id: uuid4, context_id: uuid4): + """Helper to create a test message.""" + 
return { + "message_id": uuid4(), + "task_id": task_id, + "context_id": context_id, + "role": "user", + "parts": [{"kind": "text", "text": "test"}], + } + + +@pytest.mark.asyncio +async def test_pause_saves_checkpoint_and_updates_state( + storage: InMemoryStorage, worker: DummyWorker +): + """Test that pausing a task saves checkpoint and updates state to suspended.""" + task_id = uuid4() + context_id = uuid4() + + # Create a task in 'working' state by first submitting it + await storage.submit_task(context_id, create_test_message(task_id, context_id)) + await storage.update_task(task_id=task_id, state="working", metadata={"step": 1}) + + # Pause the task + await worker._handle_pause({"task_id": task_id}) + + # Verify task is now suspended + task = await storage.load_task(task_id) + assert task is not None + assert task["status"]["state"] == "suspended" + + # Verify checkpoint was saved + checkpoint = await storage.get_checkpoint(task_id) + assert checkpoint is not None + assert checkpoint["checkpoint_data"]["task_state"] == "working" + assert checkpoint["checkpoint_data"]["metadata"]["step"] == 1 + + +@pytest.mark.asyncio +async def test_resume_restores_from_checkpoint( + storage: InMemoryStorage, worker: DummyWorker +): + """Test that resuming a task restores from checkpoint and updates state to working.""" + task_id = uuid4() + context_id = uuid4() + + # Create task and pause it + await storage.submit_task(context_id, create_test_message(task_id, context_id)) + await storage.update_task(task_id=task_id, state="working") + + # Pause the task + await worker._handle_pause({"task_id": task_id}) + + # Resume the task + await worker._handle_resume({"task_id": task_id}) + + # Verify task is now working + task = await storage.load_task(task_id) + assert task is not None + assert task["status"]["state"] == "working" + + # Verify checkpoint metadata was added + assert task.get("metadata", {}).get("resumed_from_checkpoint") is True + assert task.get("metadata", 
{}).get("checkpoint_step") == 0 + + +@pytest.mark.asyncio +async def test_cannot_pause_terminal_tasks( + storage: InMemoryStorage, worker: DummyWorker +): + """Test that terminal tasks cannot be paused.""" + task_id = uuid4() + context_id = uuid4() + + # Create a task and complete it + await storage.submit_task(context_id, create_test_message(task_id, context_id)) + await storage.update_task(task_id=task_id, state="completed") + + # Try to pause - should not raise but should not do anything + await worker._handle_pause({"task_id": task_id}) + + # Verify task is still completed (not suspended) + task = await storage.load_task(task_id) + assert task is not None + assert task["status"]["state"] == "completed" + + # Verify no checkpoint was saved + checkpoint = await storage.get_checkpoint(task_id) + assert checkpoint is None + + +@pytest.mark.asyncio +async def test_cannot_resume_non_suspended_tasks( + storage: InMemoryStorage, worker: DummyWorker +): + """Test that only suspended tasks can be resumed.""" + task_id = uuid4() + context_id = uuid4() + + # Create a task in 'working' state + await storage.submit_task(context_id, create_test_message(task_id, context_id)) + await storage.update_task(task_id=task_id, state="working") + + # Try to resume - should not raise but should not do anything + await worker._handle_resume({"task_id": task_id}) + + # Verify task is still working + task = await storage.load_task(task_id) + assert task is not None + assert task["status"]["state"] == "working" + + +@pytest.mark.asyncio +async def test_checkpoint_cleanup_on_completion( + storage: InMemoryStorage, worker: DummyWorker +): + """Test that checkpoint is deleted when task completes.""" + task_id = uuid4() + context_id = uuid4() + + # Create and pause a task + await storage.submit_task(context_id, create_test_message(task_id, context_id)) + await storage.update_task(task_id=task_id, state="working") + await worker._handle_pause({"task_id": task_id}) + + # Verify checkpoint exists 
+ checkpoint = await storage.get_checkpoint(task_id) + assert checkpoint is not None + + # Complete the task and delete checkpoint + await storage.update_task(task_id=task_id, state="completed") + await storage.delete_checkpoint(task_id) + + # Verify checkpoint is deleted + checkpoint = await storage.get_checkpoint(task_id) + assert checkpoint is None + + +@pytest.mark.asyncio +async def test_checkpoint_cleanup_on_failure( + storage: InMemoryStorage, worker: DummyWorker +): + """Test that checkpoint is deleted when task fails.""" + task_id = uuid4() + context_id = uuid4() + + # Create and pause a task + await storage.submit_task(context_id, create_test_message(task_id, context_id)) + await storage.update_task(task_id=task_id, state="working") + await worker._handle_pause({"task_id": task_id}) + + # Verify checkpoint exists + checkpoint = await storage.get_checkpoint(task_id) + assert checkpoint is not None + + # Fail the task and delete checkpoint + await storage.update_task(task_id=task_id, state="failed") + await storage.delete_checkpoint(task_id) + + # Verify checkpoint is deleted + checkpoint = await storage.get_checkpoint(task_id) + assert checkpoint is None + + +@pytest.mark.asyncio +async def test_checkpoint_cleanup_on_cancel( + storage: InMemoryStorage, worker: DummyWorker +): + """Test that checkpoint is deleted when task is canceled.""" + task_id = uuid4() + context_id = uuid4() + + # Create and pause a task + await storage.submit_task(context_id, create_test_message(task_id, context_id)) + await storage.update_task(task_id=task_id, state="working") + await worker._handle_pause({"task_id": task_id}) + + # Verify checkpoint exists + checkpoint = await storage.get_checkpoint(task_id) + assert checkpoint is not None + + # Cancel the task and delete checkpoint + await storage.update_task(task_id=task_id, state="canceled") + await storage.delete_checkpoint(task_id) + + # Verify checkpoint is deleted + checkpoint = await storage.get_checkpoint(task_id) + assert 
checkpoint is None + + +@pytest.mark.asyncio +async def test_pause_nonexistent_task(storage: InMemoryStorage, worker: DummyWorker): + """Test that pausing nonexistent task doesn't raise.""" + task_id = uuid4() + + # Should not raise + await worker._handle_pause({"task_id": task_id}) + + +@pytest.mark.asyncio +async def test_resume_nonexistent_task(storage: InMemoryStorage, worker: DummyWorker): + """Test that resuming nonexistent task doesn't raise.""" + task_id = uuid4() + + # Should not raise + await worker._handle_resume({"task_id": task_id}) + + +@pytest.mark.asyncio +async def test_resume_without_checkpoint(storage: InMemoryStorage, worker: DummyWorker): + """Test that resuming without checkpoint doesn't raise but doesn't change state properly.""" + task_id = uuid4() + context_id = uuid4() + + # Create a suspended task without checkpoint + await storage.submit_task(context_id, create_test_message(task_id, context_id)) + await storage.update_task(task_id=task_id, state="suspended") + + # Try to resume - should not raise + await worker._handle_resume({"task_id": task_id}) + + # Task should remain suspended (since no checkpoint) + task = await storage.load_task(task_id) + assert task["status"]["state"] == "suspended" diff --git a/tests/integration/test_state_transitions.py b/tests/integration/test_state_transitions.py new file mode 100644 index 00000000..41744de7 --- /dev/null +++ b/tests/integration/test_state_transitions.py @@ -0,0 +1,323 @@ +"""Integration tests for task state transitions and validation.""" + +import pytest +import pytest_asyncio +from uuid import uuid4 + +from bindu.server.storage.memory_storage import InMemoryStorage +from bindu.settings import app_settings + + +@pytest_asyncio.fixture +async def storage() -> InMemoryStorage: + """Create a fresh in-memory storage for each test.""" + return InMemoryStorage() + + +def create_test_message(task_id: uuid4, context_id: uuid4): + """Helper to create a test message.""" + return { + "message_id": 
uuid4(), + "task_id": task_id, + "context_id": context_id, + "role": "user", + "parts": [{"kind": "text", "text": "test"}], + } + + +class TestSuspendedStateClassification: + """Tests for suspended state classification in state machine.""" + + def test_suspended_is_non_terminal(self): + """Test that 'suspended' is classified as non-terminal state.""" + assert "suspended" in app_settings.agent.non_terminal_states + assert "suspended" not in app_settings.agent.terminal_states + + def test_pausable_states_defined(self): + """Test that pausable_states is properly defined.""" + assert hasattr(app_settings.agent, "pausable_states") + assert "submitted" in app_settings.agent.pausable_states + assert "working" in app_settings.agent.pausable_states + assert "input-required" in app_settings.agent.pausable_states + + def test_terminal_states_excludes_suspended(self): + """Test that terminal states do NOT include suspended.""" + for state in app_settings.agent.terminal_states: + assert state != "suspended" + + def test_pausable_states_are_non_terminal(self): + """Test that all pausable states are also non-terminal.""" + for state in app_settings.agent.pausable_states: + assert state in app_settings.agent.non_terminal_states + + +class TestStateTransitions: + """Tests for valid state transitions.""" + + @pytest.mark.asyncio + async def test_submitted_to_working_transition(self, storage: InMemoryStorage): + """Test valid transition from submitted to working.""" + task_id = uuid4() + context_id = uuid4() + + # Create task in submitted state + await storage.submit_task(context_id, create_test_message(task_id, context_id)) + + # Update to working + await storage.update_task(task_id, state="working") + task = await storage.load_task(task_id) + assert task["status"]["state"] == "working" + + @pytest.mark.asyncio + async def test_working_to_suspended_transition(self, storage: InMemoryStorage): + """Test valid transition from working to suspended.""" + task_id = uuid4() + context_id = 
uuid4() + + # Create task in working state + await storage.submit_task(context_id, create_test_message(task_id, context_id)) + await storage.update_task(task_id, state="working") + + # Update to suspended + await storage.update_task(task_id, state="suspended") + task = await storage.load_task(task_id) + assert task["status"]["state"] == "suspended" + + @pytest.mark.asyncio + async def test_suspended_to_working_transition(self, storage: InMemoryStorage): + """Test valid transition from suspended to working (resume).""" + task_id = uuid4() + context_id = uuid4() + + # Create task in suspended state + await storage.submit_task(context_id, create_test_message(task_id, context_id)) + await storage.update_task(task_id, state="suspended") + + # Update to working + await storage.update_task(task_id, state="working") + task = await storage.load_task(task_id) + assert task["status"]["state"] == "working" + + @pytest.mark.asyncio + async def test_working_to_completed_transition(self, storage: InMemoryStorage): + """Test valid transition from working to completed.""" + task_id = uuid4() + context_id = uuid4() + + await storage.submit_task(context_id, create_test_message(task_id, context_id)) + await storage.update_task(task_id, state="working") + await storage.update_task(task_id, state="completed") + + task = await storage.load_task(task_id) + assert task["status"]["state"] == "completed" + + @pytest.mark.asyncio + async def test_working_to_failed_transition(self, storage: InMemoryStorage): + """Test valid transition from working to failed.""" + task_id = uuid4() + context_id = uuid4() + + await storage.submit_task(context_id, create_test_message(task_id, context_id)) + await storage.update_task(task_id, state="working") + await storage.update_task(task_id, state="failed") + + task = await storage.load_task(task_id) + assert task["status"]["state"] == "failed" + + @pytest.mark.asyncio + async def test_working_to_canceled_transition(self, storage: InMemoryStorage): + """Test 
valid transition from working to canceled.""" + task_id = uuid4() + context_id = uuid4() + + await storage.submit_task(context_id, create_test_message(task_id, context_id)) + await storage.update_task(task_id, state="working") + await storage.update_task(task_id, state="canceled") + + task = await storage.load_task(task_id) + assert task["status"]["state"] == "canceled" + + @pytest.mark.asyncio + async def test_working_to_input_required_transition(self, storage: InMemoryStorage): + """Test valid transition from working to input-required.""" + task_id = uuid4() + context_id = uuid4() + + await storage.submit_task(context_id, create_test_message(task_id, context_id)) + await storage.update_task(task_id, state="working") + await storage.update_task(task_id, state="input-required") + + task = await storage.load_task(task_id) + assert task["status"]["state"] == "input-required" + + @pytest.mark.asyncio + async def test_input_required_to_working_transition(self, storage: InMemoryStorage): + """Test valid transition from input-required to working.""" + task_id = uuid4() + context_id = uuid4() + + await storage.submit_task(context_id, create_test_message(task_id, context_id)) + await storage.update_task(task_id, state="input-required") + + task = await storage.load_task(task_id) + assert task["status"]["state"] == "input-required" + + # Continue working after input + await storage.update_task(task_id, state="working") + task = await storage.load_task(task_id) + assert task["status"]["state"] == "working" + + +class TestTerminalStateValidation: + """Tests for terminal state immutability.""" + + @pytest.mark.asyncio + async def test_cannot_modify_completed_task(self, storage: InMemoryStorage): + """Test that completed tasks cannot be modified.""" + task_id = uuid4() + context_id = uuid4() + + await storage.submit_task(context_id, create_test_message(task_id, context_id)) + await storage.update_task(task_id, state="completed") + + task = await storage.load_task(task_id) + 
assert task["status"]["state"] == "completed" + + @pytest.mark.asyncio + async def test_cannot_modify_failed_task(self, storage: InMemoryStorage): + """Test that failed tasks cannot be modified.""" + task_id = uuid4() + context_id = uuid4() + + await storage.submit_task(context_id, create_test_message(task_id, context_id)) + await storage.update_task(task_id, state="failed") + + task = await storage.load_task(task_id) + assert task["status"]["state"] == "failed" + + @pytest.mark.asyncio + async def test_cannot_modify_canceled_task(self, storage: InMemoryStorage): + """Test that canceled tasks cannot be modified.""" + task_id = uuid4() + context_id = uuid4() + + await storage.submit_task(context_id, create_test_message(task_id, context_id)) + await storage.update_task(task_id, state="canceled") + + task = await storage.load_task(task_id) + assert task["status"]["state"] == "canceled" + + +class TestNonTerminalStateMutability: + """Tests for non-terminal state mutability.""" + + @pytest.mark.asyncio + async def test_can_modify_submitted_task(self, storage: InMemoryStorage): + """Test that submitted tasks can be modified.""" + task_id = uuid4() + context_id = uuid4() + + await storage.submit_task(context_id, create_test_message(task_id, context_id)) + task = await storage.load_task(task_id) + assert task["status"]["state"] == "submitted" + + await storage.update_task(task_id, state="submitted", metadata={"key": "value"}) + task = await storage.load_task(task_id) + assert task["status"]["state"] == "submitted" + assert task["metadata"]["key"] == "value" + + @pytest.mark.asyncio + async def test_can_modify_working_task(self, storage: InMemoryStorage): + """Test that working tasks can be modified.""" + task_id = uuid4() + context_id = uuid4() + + await storage.submit_task(context_id, create_test_message(task_id, context_id)) + await storage.update_task(task_id, state="working", metadata={"progress": 50}) + + task = await storage.load_task(task_id) + assert 
task["status"]["state"] == "working" + assert task["metadata"]["progress"] == 50 + + @pytest.mark.asyncio + async def test_can_modify_suspended_task(self, storage: InMemoryStorage): + """Test that suspended tasks can be modified.""" + task_id = uuid4() + context_id = uuid4() + + await storage.submit_task(context_id, create_test_message(task_id, context_id)) + await storage.update_task(task_id, state="suspended", metadata={"paused": True}) + + task = await storage.load_task(task_id) + assert task["status"]["state"] == "suspended" + assert task["metadata"]["paused"] is True + + +class TestStateWorkflowIntegration: + """Integration tests for complete state workflows.""" + + @pytest.mark.asyncio + async def test_workflow_submitted_working_completed(self, storage: InMemoryStorage): + """Test complete workflow: submitted -> working -> completed.""" + task_id = uuid4() + context_id = uuid4() + + # Start (submitted) + await storage.submit_task(context_id, create_test_message(task_id, context_id)) + task = await storage.load_task(task_id) + assert task["status"]["state"] == "submitted" + + # Work + await storage.update_task(task_id, state="working") + task = await storage.load_task(task_id) + assert task["status"]["state"] == "working" + + # Complete + await storage.update_task(task_id, state="completed") + task = await storage.load_task(task_id) + assert task["status"]["state"] == "completed" + + @pytest.mark.asyncio + async def test_workflow_with_pause_resume(self, storage: InMemoryStorage): + """Test workflow with pause/resume: submitted -> working -> suspended -> working -> completed.""" + task_id = uuid4() + context_id = uuid4() + + # Start working + await storage.submit_task(context_id, create_test_message(task_id, context_id)) + await storage.update_task(task_id, state="working") + + # Pause + await storage.update_task(task_id, state="suspended") + task = await storage.load_task(task_id) + assert task["status"]["state"] == "suspended" + + # Resume + await 
storage.update_task(task_id, state="working") + + # Complete + await storage.update_task(task_id, state="completed") + task = await storage.load_task(task_id) + assert task["status"]["state"] == "completed" + + @pytest.mark.asyncio + async def test_workflow_with_input_required(self, storage: InMemoryStorage): + """Test workflow with input-required: working -> input-required -> working -> completed.""" + task_id = uuid4() + context_id = uuid4() + + await storage.submit_task(context_id, create_test_message(task_id, context_id)) + await storage.update_task(task_id, state="working") + + # Request input + await storage.update_task(task_id, state="input-required") + task = await storage.load_task(task_id) + assert task["status"]["state"] == "input-required" + + # Continue working after input + await storage.update_task(task_id, state="working") + + # Complete + await storage.update_task(task_id, state="completed") + task = await storage.load_task(task_id) + assert task["status"]["state"] == "completed" From fec6e4fb1f555e0099434fbbc0f24819584ec1c3 Mon Sep 17 00:00:00 2001 From: Kkt04 Date: Sat, 7 Mar 2026 14:18:14 +0530 Subject: [PATCH 2/3] Add confirmation dialog to AgentStatePanel before clearing --- .../src/lib/components/AgentStatePanel.svelte | 43 ++++++++++++++++++- frontend/src/lib/services/agent-api.ts | 4 ++ frontend/src/lib/stores/chat.ts | 35 ++++++++++----- frontend/src/routes/+page.svelte | 6 ++- 4 files changed, 74 insertions(+), 14 deletions(-) diff --git a/frontend/src/lib/components/AgentStatePanel.svelte b/frontend/src/lib/components/AgentStatePanel.svelte index 94fe4928..06489a0d 100644 --- a/frontend/src/lib/components/AgentStatePanel.svelte +++ b/frontend/src/lib/components/AgentStatePanel.svelte @@ -1,5 +1,6 @@ + +
Agent inspector @@ -127,11 +142,7 @@ aria-label="Clear context" onclick={() => (showConfirm = "context")} > - + {/if}
@@ -153,11 +164,7 @@ aria-label="Clear tasks" onclick={() => (showConfirm = "tasks")} > - + {/if}