diff --git a/dapr_agents/agents/configs.py b/dapr_agents/agents/configs.py
index 348be554..6becbbaa 100644
--- a/dapr_agents/agents/configs.py
+++ b/dapr_agents/agents/configs.py
@@ -265,3 +265,23 @@ class AgentExecutionConfig:
     # TODO: add stop_at_tokens
     max_iterations: int = 10
     tool_choice: Optional[str] = "auto"
+
+
+@dataclass
+class WorkflowRetryPolicy:
+    """
+    Configuration for durable retry policies in workflows.
+
+    Attributes:
+        max_attempts: Maximum number of attempts per activity; must be at least 1.
+        initial_backoff_seconds: Initial backoff interval in seconds.
+        max_backoff_seconds: Maximum backoff interval in seconds.
+        backoff_multiplier: Multiplier for exponential backoff.
+        retry_timeout: Optional total timeout for all retries in seconds.
+    """
+
+    max_attempts: int = 1
+    initial_backoff_seconds: int = 5
+    max_backoff_seconds: int = 30
+    backoff_multiplier: float = 1.5
+    retry_timeout: Optional[int] = None
diff --git a/dapr_agents/agents/durable.py b/dapr_agents/agents/durable.py
index 617ba1bb..3dfc11d2 100644
--- a/dapr_agents/agents/durable.py
+++ b/dapr_agents/agents/durable.py
@@ -1,8 +1,10 @@
 from __future__ import annotations
 
+from datetime import timedelta
 import json
 import logging
 from typing import Any, Dict, Iterable, List, Optional
+from os import getenv
 
 import dapr.ext.workflow as wf
 
@@ -14,6 +16,7 @@
     AgentRegistryConfig,
     AgentStateConfig,
     WorkflowGrpcOptions,
+    WorkflowRetryPolicy,
 )
 from dapr_agents.agents.prompting import AgentProfileConfig
 from dapr_agents.agents.schemas import (
@@ -26,7 +29,6 @@
 from dapr_agents.types import (
     AgentError,
     LLMChatResponse,
-    ToolExecutionRecord,
     ToolMessage,
     UserMessage,
 )
@@ -76,6 +78,7 @@ def __init__(
         agent_metadata: Optional[Dict[str, Any]] = None,
         workflow_grpc: Optional[WorkflowGrpcOptions] = None,
         runtime: Optional[wf.WorkflowRuntime] = None,
+        retry_policy: Optional[WorkflowRetryPolicy] = None,
     ) -> None:
         """
         Initialize behavior, infrastructure, and workflow runtime.
@@ -104,6 +107,7 @@ def __init__(
             agent_metadata: Extra metadata to publish to the registry.
             workflow_grpc: Optional gRPC overrides for the workflow runtime channel.
             runtime: Optional pre-existing workflow runtime to attach to.
+            retry_policy: Durable retry policy; defaults to WorkflowRetryPolicy().
         """
         super().__init__(
             pubsub=pubsub,
@@ -132,6 +136,32 @@ def __init__(
         self._registered = False
         self._started = False
 
+        # Build the default per instance rather than sharing one mutable object.
+        if retry_policy is None:
+            retry_policy = WorkflowRetryPolicy()
+
+        # DAPR_API_MAX_RETRIES, when set to a valid integer, overrides max_attempts.
+        try:
+            retries = int(getenv("DAPR_API_MAX_RETRIES", ""))
+        except ValueError:
+            retries = retry_policy.max_attempts
+
+        if retries < 1:
+            raise ValueError("max_attempts or DAPR_API_MAX_RETRIES must be at least 1.")
+
+        # A falsy retry_timeout (None or 0) is forwarded as None.
+        self._retry_policy: wf.RetryPolicy = wf.RetryPolicy(
+            max_number_of_attempts=retries,
+            first_retry_interval=timedelta(
+                seconds=retry_policy.initial_backoff_seconds
+            ),
+            max_retry_interval=timedelta(seconds=retry_policy.max_backoff_seconds),
+            backoff_coefficient=retry_policy.backoff_multiplier,
+            retry_timeout=timedelta(seconds=retry_policy.retry_timeout)
+            if retry_policy.retry_timeout
+            else None,
+        )
+
     # ------------------------------------------------------------------
     # Runtime accessors
     # ------------------------------------------------------------------
@@ -203,6 +233,7 @@ def agent_workflow(self, ctx: wf.DaprWorkflowContext, message: dict):
                 "start_time": ctx.current_utc_datetime.isoformat(),
                 "trace_context": otel_span_context,
             },
+            retry_policy=self._retry_policy,
         )
 
         final_message: Dict[str, Any] = {}
@@ -226,6 +257,7 @@ def agent_workflow(self, ctx: wf.DaprWorkflowContext, message: dict):
                         "instance_id": ctx.instance_id,
                         "time": ctx.current_utc_datetime.isoformat(),
                     },
+                    retry_policy=self._retry_policy,
                 )
 
                 tool_calls = assistant_response.get("tool_calls") or []
@@ -246,6 +278,7 @@ def agent_workflow(self, ctx: wf.DaprWorkflowContext, message: dict):
                                 "time": ctx.current_utc_datetime.isoformat(),
                                 "order": idx,
                             },
+                            retry_policy=self._retry_policy,
                         )
                         for idx, tc in enumerate(tool_calls)
                     ]
@@ -257,6 +290,7 @@ def agent_workflow(self, ctx: wf.DaprWorkflowContext, message: dict):
                             "tool_results": tool_results,
                             "instance_id": ctx.instance_id,
                         },
+                        retry_policy=self._retry_policy,
                     )
 
                     task = None  # prepare for next turn
@@ -298,6 +332,7 @@ def agent_workflow(self, ctx: wf.DaprWorkflowContext, message: dict):
             yield ctx.call_activity(
                 self.broadcast_message_to_agents,
                 input={"message": final_message},
+                retry_policy=self._retry_policy,
             )
 
         # Optionally send a direct response back to the trigger origin.
@@ -309,6 +344,7 @@ def agent_workflow(self, ctx: wf.DaprWorkflowContext, message: dict):
                     "target_agent": source,
                     "target_instance_id": trigger_instance_id,
                 },
+                retry_policy=self._retry_policy,
             )
 
         # Finalize the workflow entry in durable state.
@@ -320,6 +356,7 @@ def agent_workflow(self, ctx: wf.DaprWorkflowContext, message: dict):
                 "end_time": ctx.current_utc_datetime.isoformat(),
                 "triggering_workflow_instance_id": trigger_instance_id,
             },
+            retry_policy=self._retry_policy,
         )
 
         if not ctx.is_replaying:
diff --git a/tests/agents/durableagent/test_durable_agent.py b/tests/agents/durableagent/test_durable_agent.py
index 21832d23..e735d3b5 100644
--- a/tests/agents/durableagent/test_durable_agent.py
+++ b/tests/agents/durableagent/test_durable_agent.py
@@ -2,7 +2,9 @@
 # Right now we have to do a bunch of patching at the class-level instead of patching at the instance-level.
 # In future, we should do dependency injection instead of patching at the class-level to make it easier to test.
 # This applies to all areas in this file where we have with patch.object()...
+from datetime import timedelta
 import os
+from typing import Optional
 from unittest.mock import AsyncMock, Mock, patch, MagicMock
 
 import pytest
@@ -15,6 +17,7 @@
     AgentRegistryConfig,
     AgentMemoryConfig,
     AgentExecutionConfig,
+    WorkflowRetryPolicy,
 )
 from dapr_agents.agents.schemas import (
     AgentWorkflowMessage,
@@ -43,6 +46,24 @@ def patch_dapr_check(monkeypatch):
     mock_runtime = Mock(spec=wf.WorkflowRuntime)
     monkeypatch.setattr(wf, "WorkflowRuntime", lambda: mock_runtime)
 
+    class MockRetryPolicy:
+        """Stand-in for wf.RetryPolicy that stores its arguments as attributes."""
+        def __init__(
+            self,
+            max_number_of_attempts=1,
+            first_retry_interval=timedelta(seconds=1),
+            max_retry_interval=timedelta(seconds=60),
+            backoff_coefficient=2.0,
+            retry_timeout: Optional[timedelta] = None,
+        ):
+            self.max_number_of_attempts = max_number_of_attempts
+            self.first_retry_interval = first_retry_interval
+            self.max_retry_interval = max_retry_interval
+            self.backoff_coefficient = backoff_coefficient
+            self.retry_timeout = retry_timeout
+
+    monkeypatch.setattr(wf, "RetryPolicy", MockRetryPolicy)
+
     # Return the mock runtime for tests that need it
     yield mock_runtime
 
@@ -990,3 +1011,225 @@ def test_durable_agent_state_initialization(self, basic_durable_agent):
         assert isinstance(validated_state, AgentWorkflowState)
         assert "instances" in basic_durable_agent.state
         assert basic_durable_agent.state["instances"] == {}
+
+    def test_durable_agent_retry_policy_initialization(self, mock_llm, monkeypatch):
+        """Test that DurableAgent correctly initializes with retry policy parameters."""
+        monkeypatch.delenv("DAPR_API_MAX_RETRIES", raising=False)
+
+        agent = DurableAgent(
+            name="RetryTestAgent",
+            role="Retry Test Assistant",
+            llm=mock_llm,
+            pubsub=AgentPubSubConfig(
+                pubsub_name="testpubsub",
+                agent_topic="RetryTestAgent",
+            ),
+            retry_policy=WorkflowRetryPolicy(
+                max_attempts=5,
+                initial_backoff_seconds=10,
+                max_backoff_seconds=60,
+                backoff_multiplier=2.0,
+                retry_timeout=300,
+            ),
+        )
+
+        assert agent._retry_policy is not None
+        assert agent._retry_policy.max_number_of_attempts == 5
+        assert agent._retry_policy.first_retry_interval.total_seconds() == 10
+        assert agent._retry_policy.max_retry_interval.total_seconds() == 60
+        assert agent._retry_policy.backoff_coefficient == 2.0
+        assert agent._retry_policy.retry_timeout.total_seconds() == 300
+
+    def test_durable_agent_retry_policy_defaults(self, mock_llm, monkeypatch):
+        """Test that DurableAgent uses correct default retry values."""
+        monkeypatch.delenv("DAPR_API_MAX_RETRIES", raising=False)
+
+        agent = DurableAgent(
+            name="RetryDefaultAgent",
+            role="Retry Default Assistant",
+            llm=mock_llm,
+            pubsub=AgentPubSubConfig(
+                pubsub_name="testpubsub",
+                agent_topic="RetryDefaultAgent",
+            ),
+        )
+
+        assert agent._retry_policy is not None
+        assert agent._retry_policy.max_number_of_attempts == 1
+        assert agent._retry_policy.first_retry_interval.total_seconds() == 5
+        assert agent._retry_policy.max_retry_interval.total_seconds() == 30
+        assert agent._retry_policy.backoff_coefficient == 1.5
+        assert agent._retry_policy.retry_timeout is None
+
+    def test_durable_agent_retry_policy_env_override(self, mock_llm, monkeypatch):
+        """Test that DAPR_API_MAX_RETRIES environment variable overrides max_attempts."""
+        monkeypatch.setenv("DAPR_API_MAX_RETRIES", "10")
+
+        agent = DurableAgent(
+            name="RetryEnvAgent",
+            role="Retry Env Assistant",
+            llm=mock_llm,
+            pubsub=AgentPubSubConfig(
+                pubsub_name="testpubsub",
+                agent_topic="RetryEnvAgent",
+            ),
+            retry_policy=WorkflowRetryPolicy(max_attempts=3),
+        )
+
+        # Should use env var value over max_attempts
+        assert agent._retry_policy.max_number_of_attempts == 10
+
+    def test_durable_agent_retry_policy_invalid_env(self, mock_llm, monkeypatch):
+        """Test that invalid DAPR_API_MAX_RETRIES falls back to max_attempts."""
+        monkeypatch.setenv("DAPR_API_MAX_RETRIES", "invalid")
+
+        agent = DurableAgent(
+            name="RetryInvalidEnvAgent",
+            role="Retry Invalid Env Assistant",
+            llm=mock_llm,
+            pubsub=AgentPubSubConfig(
+                pubsub_name="testpubsub",
+                agent_topic="RetryInvalidEnvAgent",
+            ),
+            retry_policy=WorkflowRetryPolicy(max_attempts=3),
+        )
+
+        # Should fall back to max_attempts since env var is invalid
+        assert agent._retry_policy.max_number_of_attempts == 3
+
+    def test_durable_agent_retry_policy_min_attempts_validation(
+        self, mock_llm, monkeypatch
+    ):
+        """Test that max_attempts cannot be less than 1."""
+        monkeypatch.delenv("DAPR_API_MAX_RETRIES", raising=False)
+
+        with pytest.raises(
+            ValueError, match="max_attempts or DAPR_API_MAX_RETRIES must be at least 1."
+        ):
+            DurableAgent(
+                name="RetryZeroAgent",
+                role="Retry Zero Assistant",
+                llm=mock_llm,
+                pubsub=AgentPubSubConfig(
+                    pubsub_name="testpubsub",
+                    agent_topic="RetryZeroAgent",
+                ),
+                retry_policy=WorkflowRetryPolicy(max_attempts=0),
+            )
+
+    def test_agent_workflow_applies_retry_policy(
+        self, basic_durable_agent, mock_workflow_context
+    ):
+        """Test that agent_workflow applies retry policy to activity calls."""
+        message = {
+            "task": "Test task with retries",
+            "workflow_instance_id": "parent-instance-123",
+        }
+
+        call_activity_calls = []
+
+        def track_call_activity(activity, **kwargs):
+            """Record each call_activity invocation and return a canned result."""
+            call_activity_calls.append(
+                {
+                    "activity": activity,
+                    "input": kwargs.get("input"),
+                    "retry_policy": kwargs.get("retry_policy"),
+                }
+            )
+
+            if hasattr(activity, "__name__"):
+                activity_name = activity.__name__
+            elif hasattr(activity, "__func__"):
+                activity_name = activity.__func__.__name__
+            else:
+                activity_name = str(activity)
+
+            if activity_name == "call_llm":
+                return {
+                    "content": "Test response",
+                    "tool_calls": [
+                        {
+                            "id": "call_test_123",
+                            "type": "function",
+                            "function": {
+                                "name": "test_tool",
+                                "arguments": '{"arg": "value"}',
+                            },
+                        }
+                    ],
+                    "role": "assistant",
+                }
+            elif activity_name == "run_tool":
+                return {
+                    "tool_call_id": "call_test_123",
+                    "content": "tool result",
+                    "role": "tool",
+                    "name": "test_tool",
+                }
+            elif activity_name in [
+                "record_initial_entry",
+                "finalize_workflow",
+                "save_tool_results",
+            ]:
+                return None
+
+        mock_workflow_context.instance_id = "test-instance-123"
+        mock_workflow_context.call_activity = Mock(side_effect=track_call_activity)
+
+        # Set up minimal state
+        entry = AgentWorkflowEntry(
+            input_value="Test task with retries",
+            source=None,
+            triggering_workflow_instance_id="parent-instance-123",
+            workflow_instance_id="test-instance-123",
+            workflow_name="AgenticWorkflow",
+            status="RUNNING",
+            messages=[],
+            tool_history=[],
+        )
+        basic_durable_agent._state_model.instances["test-instance-123"] = entry
+
+        # Run the workflow generator
+        workflow_gen = basic_durable_agent.agent_workflow(
+            mock_workflow_context, message
+        )
+
+        # Step through the generator, sending results back
+        result = None
+        try:
+            while True:
+                result = workflow_gen.send(result)
+        except StopIteration as e:
+            result = e.value
+
+        # Verify that retry_policy was passed to critical activities
+        assert (
+            len(call_activity_calls) >= 5
+        ), f"Expected at least 5 activity calls, got {len(call_activity_calls)}"
+
+        # All activities should have retry_policy parameter
+        for call in call_activity_calls:
+            assert (
+                call["retry_policy"] is not None
+            ), f"Missing retry_policy in call: {call}"
+            assert (
+                call["retry_policy"] == basic_durable_agent._retry_policy
+            ), f"Expected retry_policy {basic_durable_agent._retry_policy}, got {call['retry_policy']}"
+
+        # Verify the key activities were called
+        activity_names = [
+            getattr(call["activity"], "__name__", str(call["activity"]))
+            for call in call_activity_calls
+        ]
+        assert (
+            "record_initial_entry" in activity_names
+        ), f"Missing record_initial_entry in {activity_names}"
+        assert "call_llm" in activity_names, f"Missing call_llm in {activity_names}"
+        assert "run_tool" in activity_names, f"Missing run_tool in {activity_names}"
+        assert (
+            "save_tool_results" in activity_names
+        ), f"Missing save_tool_results in {activity_names}"
+        assert (
+            "finalize_workflow" in activity_names
+        ), f"Missing finalize_workflow in {activity_names}"