Skip to content

Commit 78c7f5b

Browse files
author
Rares Polenciuc
committed
feat: add per-execution lambda endpoint support
- Add lambda_endpoint field to StartDurableExecutionInput - Cache clients by endpoint to avoid race conditions - Maintain backward compatibility
1 parent 97f3db2 commit 78c7f5b

File tree

3 files changed

+91
-35
lines changed

3 files changed

+91
-35
lines changed

src/aws_durable_execution_sdk_python_testing/executor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def start_execution(
106106
trace_fields=input.trace_fields,
107107
tenant_id=input.tenant_id,
108108
input=input.input,
109+
lambda_endpoint=input.lambda_endpoint,
109110
)
110111

111112
execution = Execution.new(input=input)

src/aws_durable_execution_sdk_python_testing/invoker.py

Lines changed: 86 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import json
4+
from threading import Lock
45
from typing import TYPE_CHECKING, Any, Protocol
56

67
import boto3 # type: ignore
@@ -108,21 +109,68 @@ def update_endpoint(self, endpoint_url: str, region_name: str) -> None:
108109
class LambdaInvoker(Invoker):
109110
def __init__(self, lambda_client: Any) -> None:
110111
self.lambda_client = lambda_client
112+
# Maps execution_arn -> endpoint for that execution
113+
# Maps endpoint -> client to reuse clients across executions
114+
self._execution_endpoints: dict[str, str] = {}
115+
self._endpoint_clients: dict[str, Any] = {}
116+
self._current_endpoint: str = "" # Track current endpoint for new executions
117+
self._lock = Lock()
111118

112119
@staticmethod
113120
def create(endpoint_url: str, region_name: str) -> LambdaInvoker:
114121
"""Create with the boto lambda client."""
115-
return LambdaInvoker(
122+
invoker = LambdaInvoker(
116123
boto3.client(
117124
"lambdainternal", endpoint_url=endpoint_url, region_name=region_name
118125
)
119126
)
127+
invoker._current_endpoint = endpoint_url
128+
invoker._endpoint_clients[endpoint_url] = invoker.lambda_client
129+
return invoker
120130

121131
def update_endpoint(self, endpoint_url: str, region_name: str) -> None:
122132
"""Update the Lambda client endpoint."""
123-
self.lambda_client = boto3.client(
124-
"lambdainternal", endpoint_url=endpoint_url, region_name=region_name
125-
)
133+
# Cache client by endpoint to reuse across executions
134+
with self._lock:
135+
if endpoint_url not in self._endpoint_clients:
136+
self._endpoint_clients[endpoint_url] = boto3.client(
137+
"lambdainternal", endpoint_url=endpoint_url, region_name=region_name
138+
)
139+
self.lambda_client = self._endpoint_clients[endpoint_url]
140+
self._current_endpoint = endpoint_url
141+
142+
def _get_client_for_execution(
143+
self, durable_execution_arn: str, lambda_endpoint: str | None = None
144+
) -> Any:
145+
"""Get the appropriate client for this execution."""
146+
# Use provided endpoint or fall back to cached endpoint for this execution
147+
if lambda_endpoint:
148+
# Client should already exist from update_endpoint() call
149+
if lambda_endpoint not in self._endpoint_clients:
150+
from aws_durable_execution_sdk_python_testing.exceptions import (
151+
ServiceException,
152+
)
153+
154+
raise ServiceException(
155+
f"Lambda endpoint {lambda_endpoint} not configured. update_endpoint() must be called first."
156+
)
157+
return self._endpoint_clients[lambda_endpoint]
158+
159+
# Fallback to cached endpoint
160+
if durable_execution_arn not in self._execution_endpoints:
161+
with self._lock:
162+
if durable_execution_arn not in self._execution_endpoints:
163+
self._execution_endpoints[durable_execution_arn] = (
164+
self._current_endpoint
165+
)
166+
167+
endpoint = self._execution_endpoints[durable_execution_arn]
168+
169+
# If no endpoint configured, fall back to default client
170+
if not endpoint:
171+
return self.lambda_client
172+
173+
return self._endpoint_clients[endpoint]
126174

127175
def create_invocation_input(
128176
self, execution: Execution
@@ -165,9 +213,12 @@ def invoke(
165213
msg = "Function name is required"
166214
raise InvalidParameterValueException(msg)
167215

216+
# Get the client for this execution
217+
client = self._get_client_for_execution(input.durable_execution_arn)
218+
168219
try:
169220
# Invoke AWS Lambda function using standard invoke method
170-
response = self.lambda_client.invoke(
221+
response = client.invoke(
171222
FunctionName=function_name,
172223
InvocationType="RequestResponse", # Synchronous invocation
173224
Payload=json.dumps(input.to_dict(), default=str),
@@ -192,49 +243,49 @@ def invoke(
192243
# Convert to DurableExecutionInvocationOutput
193244
return DurableExecutionInvocationOutput.from_dict(response_dict)
194245

195-
except self.lambda_client.exceptions.ResourceNotFoundException as e:
246+
except client.exceptions.ResourceNotFoundException as e:
196247
msg = f"Function not found: {function_name}"
197248
raise ResourceNotFoundException(msg) from e
198-
except self.lambda_client.exceptions.InvalidParameterValueException as e:
249+
except client.exceptions.InvalidParameterValueException as e:
199250
msg = f"Invalid parameter: {e}"
200251
raise InvalidParameterValueException(msg) from e
201252
except (
202-
self.lambda_client.exceptions.TooManyRequestsException,
203-
self.lambda_client.exceptions.ServiceException,
204-
self.lambda_client.exceptions.ResourceConflictException,
205-
self.lambda_client.exceptions.InvalidRequestContentException,
206-
self.lambda_client.exceptions.RequestTooLargeException,
207-
self.lambda_client.exceptions.UnsupportedMediaTypeException,
208-
self.lambda_client.exceptions.InvalidRuntimeException,
209-
self.lambda_client.exceptions.InvalidZipFileException,
210-
self.lambda_client.exceptions.ResourceNotReadyException,
211-
self.lambda_client.exceptions.SnapStartTimeoutException,
212-
self.lambda_client.exceptions.SnapStartNotReadyException,
213-
self.lambda_client.exceptions.SnapStartException,
214-
self.lambda_client.exceptions.RecursiveInvocationException,
253+
client.exceptions.TooManyRequestsException,
254+
client.exceptions.ServiceException,
255+
client.exceptions.ResourceConflictException,
256+
client.exceptions.InvalidRequestContentException,
257+
client.exceptions.RequestTooLargeException,
258+
client.exceptions.UnsupportedMediaTypeException,
259+
client.exceptions.InvalidRuntimeException,
260+
client.exceptions.InvalidZipFileException,
261+
client.exceptions.ResourceNotReadyException,
262+
client.exceptions.SnapStartTimeoutException,
263+
client.exceptions.SnapStartNotReadyException,
264+
client.exceptions.SnapStartException,
265+
client.exceptions.RecursiveInvocationException,
215266
) as e:
216267
msg = f"Lambda invocation failed: {e}"
217268
raise DurableFunctionsTestError(msg) from e
218269
except (
219-
self.lambda_client.exceptions.InvalidSecurityGroupIDException,
220-
self.lambda_client.exceptions.EC2ThrottledException,
221-
self.lambda_client.exceptions.EFSMountConnectivityException,
222-
self.lambda_client.exceptions.SubnetIPAddressLimitReachedException,
223-
self.lambda_client.exceptions.EC2UnexpectedException,
224-
self.lambda_client.exceptions.InvalidSubnetIDException,
225-
self.lambda_client.exceptions.EC2AccessDeniedException,
226-
self.lambda_client.exceptions.EFSIOException,
227-
self.lambda_client.exceptions.ENILimitReachedException,
228-
self.lambda_client.exceptions.EFSMountTimeoutException,
229-
self.lambda_client.exceptions.EFSMountFailureException,
270+
client.exceptions.InvalidSecurityGroupIDException,
271+
client.exceptions.EC2ThrottledException,
272+
client.exceptions.EFSMountConnectivityException,
273+
client.exceptions.SubnetIPAddressLimitReachedException,
274+
client.exceptions.EC2UnexpectedException,
275+
client.exceptions.InvalidSubnetIDException,
276+
client.exceptions.EC2AccessDeniedException,
277+
client.exceptions.EFSIOException,
278+
client.exceptions.ENILimitReachedException,
279+
client.exceptions.EFSMountTimeoutException,
280+
client.exceptions.EFSMountFailureException,
230281
) as e:
231282
msg = f"Lambda infrastructure error: {e}"
232283
raise DurableFunctionsTestError(msg) from e
233284
except (
234-
self.lambda_client.exceptions.KMSAccessDeniedException,
235-
self.lambda_client.exceptions.KMSDisabledException,
236-
self.lambda_client.exceptions.KMSNotFoundException,
237-
self.lambda_client.exceptions.KMSInvalidStateException,
285+
client.exceptions.KMSAccessDeniedException,
286+
client.exceptions.KMSDisabledException,
287+
client.exceptions.KMSNotFoundException,
288+
client.exceptions.KMSInvalidStateException,
238289
) as e:
239290
msg = f"Lambda KMS error: {e}"
240291
raise DurableFunctionsTestError(msg) from e

src/aws_durable_execution_sdk_python_testing/model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ class StartDurableExecutionInput:
117117
trace_fields: dict | None = None
118118
tenant_id: str | None = None
119119
input: str | None = None
120+
lambda_endpoint: str | None = None # Endpoint for this specific execution
120121

121122
@classmethod
122123
def from_dict(cls, data: dict) -> StartDurableExecutionInput:
@@ -146,6 +147,7 @@ def from_dict(cls, data: dict) -> StartDurableExecutionInput:
146147
trace_fields=data.get("TraceFields"),
147148
tenant_id=data.get("TenantId"),
148149
input=data.get("Input"),
150+
lambda_endpoint=data.get("LambdaEndpoint", None),
149151
)
150152

151153
def to_dict(self) -> dict[str, Any]:
@@ -165,6 +167,8 @@ def to_dict(self) -> dict[str, Any]:
165167
result["TenantId"] = self.tenant_id
166168
if self.input is not None:
167169
result["Input"] = self.input
170+
if self.lambda_endpoint is not None:
171+
result["LambdaEndpoint"] = self.lambda_endpoint
168172
return result
169173

170174
def get_normalized_input(self):

0 commit comments

Comments
 (0)