11from __future__ import annotations
22
33import json
4+ from threading import Lock
45from typing import TYPE_CHECKING , Any , Protocol
56
67import boto3 # type: ignore
@@ -108,21 +109,68 @@ def update_endpoint(self, endpoint_url: str, region_name: str) -> None:
108109class LambdaInvoker (Invoker ):
109110 def __init__ (self , lambda_client : Any ) -> None :
110111 self .lambda_client = lambda_client
112+ # Maps execution_arn -> endpoint for that execution
113+ # Maps endpoint -> client to reuse clients across executions
114+ self ._execution_endpoints : dict [str , str ] = {}
115+ self ._endpoint_clients : dict [str , Any ] = {}
116+ self ._current_endpoint : str = "" # Track current endpoint for new executions
117+ self ._lock = Lock ()
111118
112119 @staticmethod
113120 def create (endpoint_url : str , region_name : str ) -> LambdaInvoker :
114121 """Create with the boto lambda client."""
115- return LambdaInvoker (
122+ invoker = LambdaInvoker (
116123 boto3 .client (
117124 "lambdainternal" , endpoint_url = endpoint_url , region_name = region_name
118125 )
119126 )
127+ invoker ._current_endpoint = endpoint_url
128+ invoker ._endpoint_clients [endpoint_url ] = invoker .lambda_client
129+ return invoker
120130
121131 def update_endpoint (self , endpoint_url : str , region_name : str ) -> None :
122132 """Update the Lambda client endpoint."""
123- self .lambda_client = boto3 .client (
124- "lambdainternal" , endpoint_url = endpoint_url , region_name = region_name
125- )
133+ # Cache client by endpoint to reuse across executions
134+ with self ._lock :
135+ if endpoint_url not in self ._endpoint_clients :
136+ self ._endpoint_clients [endpoint_url ] = boto3 .client (
137+ "lambdainternal" , endpoint_url = endpoint_url , region_name = region_name
138+ )
139+ self .lambda_client = self ._endpoint_clients [endpoint_url ]
140+ self ._current_endpoint = endpoint_url
141+
142+ def _get_client_for_execution (
143+ self , durable_execution_arn : str , lambda_endpoint : str | None = None
144+ ) -> Any :
145+ """Get the appropriate client for this execution."""
146+ # Use provided endpoint or fall back to cached endpoint for this execution
147+ if lambda_endpoint :
148+ # Client should already exist from update_endpoint() call
149+ if lambda_endpoint not in self ._endpoint_clients :
150+ from aws_durable_execution_sdk_python_testing .exceptions import (
151+ ServiceException ,
152+ )
153+
154+ raise ServiceException (
155+ f"Lambda endpoint { lambda_endpoint } not configured. update_endpoint() must be called first."
156+ )
157+ return self ._endpoint_clients [lambda_endpoint ]
158+
159+ # Fallback to cached endpoint
160+ if durable_execution_arn not in self ._execution_endpoints :
161+ with self ._lock :
162+ if durable_execution_arn not in self ._execution_endpoints :
163+ self ._execution_endpoints [durable_execution_arn ] = (
164+ self ._current_endpoint
165+ )
166+
167+ endpoint = self ._execution_endpoints [durable_execution_arn ]
168+
169+ # If no endpoint configured, fall back to default client
170+ if not endpoint :
171+ return self .lambda_client
172+
173+ return self ._endpoint_clients [endpoint ]
126174
127175 def create_invocation_input (
128176 self , execution : Execution
@@ -165,9 +213,12 @@ def invoke(
165213 msg = "Function name is required"
166214 raise InvalidParameterValueException (msg )
167215
216+ # Get the client for this execution
217+ client = self ._get_client_for_execution (input .durable_execution_arn )
218+
168219 try :
169220 # Invoke AWS Lambda function using standard invoke method
170- response = self . lambda_client .invoke (
221+ response = client .invoke (
171222 FunctionName = function_name ,
172223 InvocationType = "RequestResponse" , # Synchronous invocation
173224 Payload = json .dumps (input .to_dict (), default = str ),
@@ -192,49 +243,49 @@ def invoke(
192243 # Convert to DurableExecutionInvocationOutput
193244 return DurableExecutionInvocationOutput .from_dict (response_dict )
194245
195- except self . lambda_client .exceptions .ResourceNotFoundException as e :
246+ except client .exceptions .ResourceNotFoundException as e :
196247 msg = f"Function not found: { function_name } "
197248 raise ResourceNotFoundException (msg ) from e
198- except self . lambda_client .exceptions .InvalidParameterValueException as e :
249+ except client .exceptions .InvalidParameterValueException as e :
199250 msg = f"Invalid parameter: { e } "
200251 raise InvalidParameterValueException (msg ) from e
201252 except (
202- self . lambda_client .exceptions .TooManyRequestsException ,
203- self . lambda_client .exceptions .ServiceException ,
204- self . lambda_client .exceptions .ResourceConflictException ,
205- self . lambda_client .exceptions .InvalidRequestContentException ,
206- self . lambda_client .exceptions .RequestTooLargeException ,
207- self . lambda_client .exceptions .UnsupportedMediaTypeException ,
208- self . lambda_client .exceptions .InvalidRuntimeException ,
209- self . lambda_client .exceptions .InvalidZipFileException ,
210- self . lambda_client .exceptions .ResourceNotReadyException ,
211- self . lambda_client .exceptions .SnapStartTimeoutException ,
212- self . lambda_client .exceptions .SnapStartNotReadyException ,
213- self . lambda_client .exceptions .SnapStartException ,
214- self . lambda_client .exceptions .RecursiveInvocationException ,
253+ client .exceptions .TooManyRequestsException ,
254+ client .exceptions .ServiceException ,
255+ client .exceptions .ResourceConflictException ,
256+ client .exceptions .InvalidRequestContentException ,
257+ client .exceptions .RequestTooLargeException ,
258+ client .exceptions .UnsupportedMediaTypeException ,
259+ client .exceptions .InvalidRuntimeException ,
260+ client .exceptions .InvalidZipFileException ,
261+ client .exceptions .ResourceNotReadyException ,
262+ client .exceptions .SnapStartTimeoutException ,
263+ client .exceptions .SnapStartNotReadyException ,
264+ client .exceptions .SnapStartException ,
265+ client .exceptions .RecursiveInvocationException ,
215266 ) as e :
216267 msg = f"Lambda invocation failed: { e } "
217268 raise DurableFunctionsTestError (msg ) from e
218269 except (
219- self . lambda_client .exceptions .InvalidSecurityGroupIDException ,
220- self . lambda_client .exceptions .EC2ThrottledException ,
221- self . lambda_client .exceptions .EFSMountConnectivityException ,
222- self . lambda_client .exceptions .SubnetIPAddressLimitReachedException ,
223- self . lambda_client .exceptions .EC2UnexpectedException ,
224- self . lambda_client .exceptions .InvalidSubnetIDException ,
225- self . lambda_client .exceptions .EC2AccessDeniedException ,
226- self . lambda_client .exceptions .EFSIOException ,
227- self . lambda_client .exceptions .ENILimitReachedException ,
228- self . lambda_client .exceptions .EFSMountTimeoutException ,
229- self . lambda_client .exceptions .EFSMountFailureException ,
270+ client .exceptions .InvalidSecurityGroupIDException ,
271+ client .exceptions .EC2ThrottledException ,
272+ client .exceptions .EFSMountConnectivityException ,
273+ client .exceptions .SubnetIPAddressLimitReachedException ,
274+ client .exceptions .EC2UnexpectedException ,
275+ client .exceptions .InvalidSubnetIDException ,
276+ client .exceptions .EC2AccessDeniedException ,
277+ client .exceptions .EFSIOException ,
278+ client .exceptions .ENILimitReachedException ,
279+ client .exceptions .EFSMountTimeoutException ,
280+ client .exceptions .EFSMountFailureException ,
230281 ) as e :
231282 msg = f"Lambda infrastructure error: { e } "
232283 raise DurableFunctionsTestError (msg ) from e
233284 except (
234- self . lambda_client .exceptions .KMSAccessDeniedException ,
235- self . lambda_client .exceptions .KMSDisabledException ,
236- self . lambda_client .exceptions .KMSNotFoundException ,
237- self . lambda_client .exceptions .KMSInvalidStateException ,
285+ client .exceptions .KMSAccessDeniedException ,
286+ client .exceptions .KMSDisabledException ,
287+ client .exceptions .KMSNotFoundException ,
288+ client .exceptions .KMSInvalidStateException ,
238289 ) as e :
239290 msg = f"Lambda KMS error: { e } "
240291 raise DurableFunctionsTestError (msg ) from e
0 commit comments