Skip to content

Commit f649402

Browse files
authored
Merge pull request #314 from CasperGN/fix/221-retrypolicy
Fix/221 retrypolicy
2 parents e32cb47 + e54cbe9 commit f649402

File tree

3 files changed

+296
-2
lines changed

3 files changed

+296
-2
lines changed

dapr_agents/agents/configs.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,17 @@
22

33
import re
44
from dataclasses import dataclass, field
5-
from typing import Any, Callable, Dict, List, MutableMapping, Optional, Sequence, Type
5+
from typing import (
6+
Any,
7+
Callable,
8+
Dict,
9+
List,
10+
MutableMapping,
11+
Optional,
12+
Sequence,
13+
Type,
14+
Union,
15+
)
616

717
from pydantic import BaseModel
818

@@ -265,3 +275,23 @@ class AgentExecutionConfig:
265275
# TODO: add stop_at_tokens
266276
max_iterations: int = 10
267277
tool_choice: Optional[str] = "auto"
278+
279+
280+
@dataclass
class WorkflowRetryPolicy:
    """
    Configuration for durable retry policies in workflows.

    Every field accepts ``None``; how a ``None`` value is interpreted is up
    to the code that consumes this configuration.

    Attributes:
        max_attempts: Maximum number of attempts. Not range-checked here
            because the consumer may override it (for example from the
            ``DAPR_API_MAX_RETRIES`` environment variable) and performs its
            own ``>= 1`` check on the effective value.
        initial_backoff_seconds: Initial backoff interval in seconds.
        max_backoff_seconds: Maximum backoff interval in seconds.
        backoff_multiplier: Multiplier for exponential backoff.
        retry_timeout: Optional total timeout for all retries in seconds.

    Raises:
        ValueError: If a backoff interval or the retry timeout is negative,
            or if the backoff multiplier is smaller than 1.
    """

    max_attempts: Optional[int] = 1
    initial_backoff_seconds: Optional[int] = 5
    max_backoff_seconds: Optional[int] = 30
    backoff_multiplier: Optional[float] = 1.5
    retry_timeout: Optional[int] = None

    def __post_init__(self) -> None:
        """Reject values that can never form a valid retry policy."""
        # Fail at configuration time instead of later, when the values are
        # converted to timedeltas and handed to the workflow runtime.
        durations = {
            "initial_backoff_seconds": self.initial_backoff_seconds,
            "max_backoff_seconds": self.max_backoff_seconds,
            "retry_timeout": self.retry_timeout,
        }
        for name, value in durations.items():
            if value is not None and value < 0:
                raise ValueError(f"{name} must be >= 0, got {value!r}.")

        # A multiplier below 1 would shrink the delay between attempts.
        if self.backoff_multiplier is not None and self.backoff_multiplier < 1:
            raise ValueError(
                f"backoff_multiplier must be >= 1, got {self.backoff_multiplier!r}."
            )

dapr_agents/agents/durable.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from __future__ import annotations
22

3+
from datetime import timedelta
34
import json
45
import logging
56
from typing import Any, Dict, Iterable, List, Optional
7+
from os import getenv
68

79
import dapr.ext.workflow as wf
810

@@ -14,6 +16,7 @@
1416
AgentRegistryConfig,
1517
AgentStateConfig,
1618
WorkflowGrpcOptions,
19+
WorkflowRetryPolicy,
1720
)
1821
from dapr_agents.agents.prompting import AgentProfileConfig
1922
from dapr_agents.agents.schemas import (
@@ -26,7 +29,6 @@
2629
from dapr_agents.types import (
2730
AgentError,
2831
LLMChatResponse,
29-
ToolExecutionRecord,
3032
ToolMessage,
3133
UserMessage,
3234
)
@@ -76,6 +78,7 @@ def __init__(
7678
agent_metadata: Optional[Dict[str, Any]] = None,
7779
workflow_grpc: Optional[WorkflowGrpcOptions] = None,
7880
runtime: Optional[wf.WorkflowRuntime] = None,
81+
retry_policy: WorkflowRetryPolicy = WorkflowRetryPolicy(),
7982
) -> None:
8083
"""
8184
Initialize behavior, infrastructure, and workflow runtime.
@@ -104,6 +107,7 @@ def __init__(
104107
agent_metadata: Extra metadata to publish to the registry.
105108
workflow_grpc: Optional gRPC overrides for the workflow runtime channel.
106109
runtime: Optional pre-existing workflow runtime to attach to.
110+
retry_policy: Durable retry policy configuration.
107111
"""
108112
super().__init__(
109113
pubsub=pubsub,
@@ -132,6 +136,28 @@ def __init__(
132136
self._registered = False
133137
self._started = False
134138

139+
try:
140+
retries = int(getenv("DAPR_API_MAX_RETRIES", ""))
141+
except ValueError:
142+
retries = retry_policy.max_attempts
143+
144+
if retries < 1:
145+
raise (
146+
ValueError("max_attempts or DAPR_API_MAX_RETRIES must be at least 1.")
147+
)
148+
149+
self._retry_policy: wf.RetryPolicy = wf.RetryPolicy(
150+
max_number_of_attempts=retries,
151+
first_retry_interval=timedelta(
152+
seconds=retry_policy.initial_backoff_seconds
153+
),
154+
max_retry_interval=timedelta(seconds=retry_policy.max_backoff_seconds),
155+
backoff_coefficient=retry_policy.backoff_multiplier,
156+
retry_timeout=timedelta(seconds=retry_policy.retry_timeout)
157+
if retry_policy.retry_timeout
158+
else None,
159+
)
160+
135161
# ------------------------------------------------------------------
136162
# Runtime accessors
137163
# ------------------------------------------------------------------
@@ -203,6 +229,7 @@ def agent_workflow(self, ctx: wf.DaprWorkflowContext, message: dict):
203229
"start_time": ctx.current_utc_datetime.isoformat(),
204230
"trace_context": otel_span_context,
205231
},
232+
retry_policy=self._retry_policy,
206233
)
207234

208235
final_message: Dict[str, Any] = {}
@@ -226,6 +253,7 @@ def agent_workflow(self, ctx: wf.DaprWorkflowContext, message: dict):
226253
"instance_id": ctx.instance_id,
227254
"time": ctx.current_utc_datetime.isoformat(),
228255
},
256+
retry_policy=self._retry_policy,
229257
)
230258

231259
tool_calls = assistant_response.get("tool_calls") or []
@@ -246,6 +274,7 @@ def agent_workflow(self, ctx: wf.DaprWorkflowContext, message: dict):
246274
"time": ctx.current_utc_datetime.isoformat(),
247275
"order": idx,
248276
},
277+
retry_policy=self._retry_policy,
249278
)
250279
for idx, tc in enumerate(tool_calls)
251280
]
@@ -257,6 +286,7 @@ def agent_workflow(self, ctx: wf.DaprWorkflowContext, message: dict):
257286
"tool_results": tool_results,
258287
"instance_id": ctx.instance_id,
259288
},
289+
retry_policy=self._retry_policy,
260290
)
261291

262292
task = None # prepare for next turn
@@ -298,6 +328,7 @@ def agent_workflow(self, ctx: wf.DaprWorkflowContext, message: dict):
298328
yield ctx.call_activity(
299329
self.broadcast_message_to_agents,
300330
input={"message": final_message},
331+
retry_policy=self._retry_policy,
301332
)
302333

303334
# Optionally send a direct response back to the trigger origin.
@@ -309,6 +340,7 @@ def agent_workflow(self, ctx: wf.DaprWorkflowContext, message: dict):
309340
"target_agent": source,
310341
"target_instance_id": trigger_instance_id,
311342
},
343+
retry_policy=self._retry_policy,
312344
)
313345

314346
# Finalize the workflow entry in durable state.
@@ -320,6 +352,7 @@ def agent_workflow(self, ctx: wf.DaprWorkflowContext, message: dict):
320352
"end_time": ctx.current_utc_datetime.isoformat(),
321353
"triggering_workflow_instance_id": trigger_instance_id,
322354
},
355+
retry_policy=self._retry_policy,
323356
)
324357

325358
if not ctx.is_replaying:

0 commit comments

Comments
 (0)