From 78c7f5b31136217458369c6c5379a5d3dce27c20 Mon Sep 17 00:00:00 2001 From: Rares Polenciuc Date: Mon, 24 Nov 2025 13:28:01 +0000 Subject: [PATCH] feat: add per-execution lambda endpoint support - Add lambda_endpoint field to StartDurableExecutionInput - Cache clients by endpoint to avoid race conditions - Maintain backward compatibility --- .../executor.py | 1 + .../invoker.py | 121 +++++++++++++----- .../model.py | 4 + 3 files changed, 91 insertions(+), 35 deletions(-) diff --git a/src/aws_durable_execution_sdk_python_testing/executor.py b/src/aws_durable_execution_sdk_python_testing/executor.py index e707366..d06888f 100644 --- a/src/aws_durable_execution_sdk_python_testing/executor.py +++ b/src/aws_durable_execution_sdk_python_testing/executor.py @@ -106,6 +106,7 @@ def start_execution( trace_fields=input.trace_fields, tenant_id=input.tenant_id, input=input.input, + lambda_endpoint=input.lambda_endpoint, ) execution = Execution.new(input=input) diff --git a/src/aws_durable_execution_sdk_python_testing/invoker.py b/src/aws_durable_execution_sdk_python_testing/invoker.py index 7ae00ee..9ee3eed 100644 --- a/src/aws_durable_execution_sdk_python_testing/invoker.py +++ b/src/aws_durable_execution_sdk_python_testing/invoker.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +from threading import Lock from typing import TYPE_CHECKING, Any, Protocol import boto3 # type: ignore @@ -108,21 +109,68 @@ def update_endpoint(self, endpoint_url: str, region_name: str) -> None: class LambdaInvoker(Invoker): def __init__(self, lambda_client: Any) -> None: self.lambda_client = lambda_client + # Maps execution_arn -> endpoint for that execution + # Maps endpoint -> client to reuse clients across executions + self._execution_endpoints: dict[str, str] = {} + self._endpoint_clients: dict[str, Any] = {} + self._current_endpoint: str = "" # Track current endpoint for new executions + self._lock = Lock() @staticmethod def create(endpoint_url: str, region_name: str) -> LambdaInvoker: """Create with the boto lambda client.""" - return LambdaInvoker( + invoker = LambdaInvoker( boto3.client( "lambdainternal", endpoint_url=endpoint_url, region_name=region_name ) ) + invoker._current_endpoint = endpoint_url + invoker._endpoint_clients[endpoint_url] = invoker.lambda_client + return invoker def update_endpoint(self, endpoint_url: str, region_name: str) -> None: """Update the Lambda client endpoint.""" - self.lambda_client = boto3.client( - "lambdainternal", endpoint_url=endpoint_url, region_name=region_name - ) + # Cache client by endpoint to reuse across executions + with self._lock: + if endpoint_url not in self._endpoint_clients: + self._endpoint_clients[endpoint_url] = boto3.client( + "lambdainternal", endpoint_url=endpoint_url, region_name=region_name + ) + self.lambda_client = self._endpoint_clients[endpoint_url] + self._current_endpoint = endpoint_url + + def _get_client_for_execution( + self, durable_execution_arn: str, lambda_endpoint: str | None = None + ) -> Any: + """Get the appropriate client for this execution.""" + # Use provided endpoint or fall back to cached endpoint for this execution + if lambda_endpoint: + # Client should already exist from update_endpoint() call + if lambda_endpoint not in self._endpoint_clients: + from aws_durable_execution_sdk_python_testing.exceptions import ( + ServiceException, + ) + + raise ServiceException( + f"Lambda endpoint {lambda_endpoint} not configured. update_endpoint() must be called first." + ) + return self._endpoint_clients[lambda_endpoint] + + # Fallback to cached endpoint + if durable_execution_arn not in self._execution_endpoints: + with self._lock: + if durable_execution_arn not in self._execution_endpoints: + self._execution_endpoints[durable_execution_arn] = ( + self._current_endpoint + ) + + endpoint = self._execution_endpoints[durable_execution_arn] + + # If no endpoint configured, fall back to default client + if not endpoint: + return self.lambda_client + + return self._endpoint_clients[endpoint] def create_invocation_input( self, execution: Execution @@ -165,9 +213,12 @@ def invoke( msg = "Function name is required" raise InvalidParameterValueException(msg) + # Get the client for this execution + client = self._get_client_for_execution(input.durable_execution_arn) + try: # Invoke AWS Lambda function using standard invoke method - response = self.lambda_client.invoke( + response = client.invoke( FunctionName=function_name, InvocationType="RequestResponse", # Synchronous invocation Payload=json.dumps(input.to_dict(), default=str), @@ -192,49 +243,49 @@ def invoke( # Convert to DurableExecutionInvocationOutput return DurableExecutionInvocationOutput.from_dict(response_dict) - except self.lambda_client.exceptions.ResourceNotFoundException as e: + except client.exceptions.ResourceNotFoundException as e: msg = f"Function not found: {function_name}" raise ResourceNotFoundException(msg) from e - except self.lambda_client.exceptions.InvalidParameterValueException as e: + except client.exceptions.InvalidParameterValueException as e: msg = f"Invalid parameter: {e}" raise InvalidParameterValueException(msg) from e except ( - self.lambda_client.exceptions.TooManyRequestsException, - self.lambda_client.exceptions.ServiceException, - self.lambda_client.exceptions.ResourceConflictException, - self.lambda_client.exceptions.InvalidRequestContentException, - self.lambda_client.exceptions.RequestTooLargeException, - self.lambda_client.exceptions.UnsupportedMediaTypeException, - self.lambda_client.exceptions.InvalidRuntimeException, - self.lambda_client.exceptions.InvalidZipFileException, - self.lambda_client.exceptions.ResourceNotReadyException, - self.lambda_client.exceptions.SnapStartTimeoutException, - self.lambda_client.exceptions.SnapStartNotReadyException, - self.lambda_client.exceptions.SnapStartException, - self.lambda_client.exceptions.RecursiveInvocationException, + client.exceptions.TooManyRequestsException, + client.exceptions.ServiceException, + client.exceptions.ResourceConflictException, + client.exceptions.InvalidRequestContentException, + client.exceptions.RequestTooLargeException, + client.exceptions.UnsupportedMediaTypeException, + client.exceptions.InvalidRuntimeException, + client.exceptions.InvalidZipFileException, + client.exceptions.ResourceNotReadyException, + client.exceptions.SnapStartTimeoutException, + client.exceptions.SnapStartNotReadyException, + client.exceptions.SnapStartException, + client.exceptions.RecursiveInvocationException, ) as e: msg = f"Lambda invocation failed: {e}" raise DurableFunctionsTestError(msg) from e except ( - self.lambda_client.exceptions.InvalidSecurityGroupIDException, - self.lambda_client.exceptions.EC2ThrottledException, - self.lambda_client.exceptions.EFSMountConnectivityException, - self.lambda_client.exceptions.SubnetIPAddressLimitReachedException, - self.lambda_client.exceptions.EC2UnexpectedException, - self.lambda_client.exceptions.InvalidSubnetIDException, - self.lambda_client.exceptions.EC2AccessDeniedException, - self.lambda_client.exceptions.EFSIOException, - self.lambda_client.exceptions.ENILimitReachedException, - self.lambda_client.exceptions.EFSMountTimeoutException, - self.lambda_client.exceptions.EFSMountFailureException, + client.exceptions.InvalidSecurityGroupIDException, + client.exceptions.EC2ThrottledException, + client.exceptions.EFSMountConnectivityException, + client.exceptions.SubnetIPAddressLimitReachedException, + client.exceptions.EC2UnexpectedException, + client.exceptions.InvalidSubnetIDException, + client.exceptions.EC2AccessDeniedException, + client.exceptions.EFSIOException, + client.exceptions.ENILimitReachedException, + client.exceptions.EFSMountTimeoutException, + client.exceptions.EFSMountFailureException, ) as e: msg = f"Lambda infrastructure error: {e}" raise DurableFunctionsTestError(msg) from e except ( - self.lambda_client.exceptions.KMSAccessDeniedException, - self.lambda_client.exceptions.KMSDisabledException, - self.lambda_client.exceptions.KMSNotFoundException, - self.lambda_client.exceptions.KMSInvalidStateException, + client.exceptions.KMSAccessDeniedException, + client.exceptions.KMSDisabledException, + client.exceptions.KMSNotFoundException, + client.exceptions.KMSInvalidStateException, ) as e: msg = f"Lambda KMS error: {e}" raise DurableFunctionsTestError(msg) from e diff --git a/src/aws_durable_execution_sdk_python_testing/model.py b/src/aws_durable_execution_sdk_python_testing/model.py index a440ef9..c678c8b 100644 --- a/src/aws_durable_execution_sdk_python_testing/model.py +++ b/src/aws_durable_execution_sdk_python_testing/model.py @@ -117,6 +117,7 @@ class StartDurableExecutionInput: trace_fields: dict | None = None tenant_id: str | None = None input: str | None = None + lambda_endpoint: str | None = None # Endpoint for this specific execution @classmethod def from_dict(cls, data: dict) -> StartDurableExecutionInput: @@ -146,6 +147,7 @@ def from_dict(cls, data: dict) -> StartDurableExecutionInput: trace_fields=data.get("TraceFields"), tenant_id=data.get("TenantId"), input=data.get("Input"), + lambda_endpoint=data.get("LambdaEndpoint", None), ) def to_dict(self) -> dict[str, Any]: @@ -165,6 +167,8 @@ def to_dict(self) -> dict[str, Any]: result["TenantId"] = self.tenant_id if self.input is not None: result["Input"] = self.input + if self.lambda_endpoint is not None: + result["LambdaEndpoint"] = self.lambda_endpoint return result def get_normalized_input(self):