diff --git a/src/aws_durable_execution_sdk_python_testing/executor.py b/src/aws_durable_execution_sdk_python_testing/executor.py index 0fe3026..32abf7d 100644 --- a/src/aws_durable_execution_sdk_python_testing/executor.py +++ b/src/aws_durable_execution_sdk_python_testing/executor.py @@ -770,7 +770,9 @@ async def invoke() -> None: self._store.save(execution) response: DurableExecutionInvocationOutput = self._invoker.invoke( - execution.start_input.function_name, invocation_input + execution.start_input.function_name, + invocation_input, + execution.start_input.lambda_endpoint, ) # Reload execution after invocation in case it was completed via checkpoint diff --git a/src/aws_durable_execution_sdk_python_testing/invoker.py b/src/aws_durable_execution_sdk_python_testing/invoker.py index 9ee3eed..a363340 100644 --- a/src/aws_durable_execution_sdk_python_testing/invoker.py +++ b/src/aws_durable_execution_sdk_python_testing/invoker.py @@ -15,6 +15,7 @@ from aws_durable_execution_sdk_python_testing.exceptions import ( DurableFunctionsTestError, + ServiceException, ) from aws_durable_execution_sdk_python_testing.model import LambdaContext @@ -63,6 +64,7 @@ def invoke( self, function_name: str, input: DurableExecutionInvocationInput, + endpoint_url: str | None = None, ) -> DurableExecutionInvocationOutput: ... # pragma: no cover def update_endpoint( @@ -93,6 +95,7 @@ def invoke( self, function_name: str, # noqa: ARG002 input: DurableExecutionInvocationInput, + endpoint_url: str | None = None, # noqa: ARG002 ) -> DurableExecutionInvocationOutput: # TODO: reasses if function_name will be used in future input_with_client = DurableExecutionInvocationInputWithClient.from_durable_execution_invocation_input( @@ -140,19 +143,19 @@ def update_endpoint(self, endpoint_url: str, region_name: str) -> None: self._current_endpoint = endpoint_url def _get_client_for_execution( - self, durable_execution_arn: str, lambda_endpoint: str | None = None + self, + durable_execution_arn: str, + lambda_endpoint: str | None = None, + region_name: str | None = None, ) -> Any: """Get the appropriate client for this execution.""" # Use provided endpoint or fall back to cached endpoint for this execution if lambda_endpoint: - # Client should already exist from update_endpoint() call if lambda_endpoint not in self._endpoint_clients: - from aws_durable_execution_sdk_python_testing.exceptions import ( - ServiceException, - ) - - raise ServiceException( - f"Lambda endpoint {lambda_endpoint} not configured. update_endpoint() must be called first." + self._endpoint_clients[lambda_endpoint] = boto3.client( + "lambdainternal", + endpoint_url=lambda_endpoint, + region_name=region_name or "us-east-1", ) return self._endpoint_clients[lambda_endpoint] @@ -188,12 +191,14 @@ def invoke( self, function_name: str, input: DurableExecutionInvocationInput, + endpoint_url: str | None = None, ) -> DurableExecutionInvocationOutput: """Invoke AWS Lambda function and return durable execution result. Args: function_name: Name of the Lambda function to invoke input: Durable execution invocation input + endpoint_url: Lambda endpoint url Returns: DurableExecutionInvocationOutput: Result of the function execution @@ -214,7 +219,9 @@ def invoke( raise InvalidParameterValueException(msg) # Get the client for this execution - client = self._get_client_for_execution(input.durable_execution_arn) + client = self._get_client_for_execution( + input.durable_execution_arn, endpoint_url + ) try: # Invoke AWS Lambda function using standard invoke method diff --git a/tests/executor_test.py b/tests/executor_test.py index 0e37976..dad04f0 100644 --- a/tests/executor_test.py +++ b/tests/executor_test.py @@ -645,7 +645,9 @@ def test_invoke_handler_success( mock_invoker.create_invocation_input.assert_called_once_with( execution=mock_execution ) - mock_invoker.invoke.assert_called_once_with("test-function", mock_invocation_input) + mock_invoker.invoke.assert_called_once_with( + "test-function", mock_invocation_input, None + ) def test_invoke_handler_execution_already_complete(