diff --git a/CHANGELOG.md b/CHANGELOG.md index 2144672a..3216b2c8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,12 @@ All notable changes to this project will be documented in this file. +## Unreleased + +### Added + +- Client operation correlation logging: `FunctionInvocationId` is now propagated via HTTP headers to the host for client operations, enabling correlation with host logs. + ## 1.0.0b6 - [Create timer](https://github.com/Azure/azure-functions-durable-python/issues/35) functionality available diff --git a/azure/durable_functions/decorators/durable_app.py b/azure/durable_functions/decorators/durable_app.py index 219b7dc8..2b74a045 100644 --- a/azure/durable_functions/decorators/durable_app.py +++ b/azure/durable_functions/decorators/durable_app.py @@ -195,7 +195,14 @@ async def df_client_middleware(*args, **kwargs): # construct rich object from it, # and assign parameter to that rich object starter = kwargs[parameter_name] - client = client_constructor(starter) + + # Try to extract the function invocation ID from the context for correlation + function_invocation_id = None + context = kwargs.get('context') + if context is not None and hasattr(context, 'invocation_id'): + function_invocation_id = context.invocation_id + + client = client_constructor(starter, function_invocation_id) kwargs[parameter_name] = client # Invoke user code with rich DF Client binding diff --git a/azure/durable_functions/models/DurableOrchestrationClient.py b/azure/durable_functions/models/DurableOrchestrationClient.py index cc746cd2..fa0f0978 100644 --- a/azure/durable_functions/models/DurableOrchestrationClient.py +++ b/azure/durable_functions/models/DurableOrchestrationClient.py @@ -26,7 +26,16 @@ class DurableOrchestrationClient: orchestration instances. """ - def __init__(self, context: str): + def __init__(self, context: str, function_invocation_id: Optional[str] = None): + """Initialize a DurableOrchestrationClient. + + Parameters + ---------- + context : str + The JSON-encoded client binding context. + function_invocation_id : Optional[str] + The function invocation ID for correlation with host-side logs. + """ self.task_hub_name: str self._uniqueWebHookOrigins: List[str] self._event_name_placeholder: str = "{eventName}" @@ -39,6 +48,7 @@ def __init__(self, context: str): self._show_history_query_key: str = "showHistory" self._show_history_output_query_key: str = "showHistoryOutput" self._show_input_query_key: str = "showInput" + self._function_invocation_id: Optional[str] = function_invocation_id self._orchestration_bindings: DurableOrchestrationBindings = \ DurableOrchestrationBindings.from_json(context) self._post_async_request = post_async_request @@ -84,7 +94,8 @@ async def start_new(self, request_url, self._get_json_input(client_input), trace_parent, - trace_state) + trace_state, + self._function_invocation_id) status_code: int = response[0] if status_code <= 202 and response[1]: @@ -100,6 +111,7 @@ async def start_new(self, ex_message: Any = response[1] raise Exception(ex_message) + def create_check_status_response( self, request: func.HttpRequest, instance_id: str) -> func.HttpResponse: """Create a HttpResponse that contains useful information for \ @@ -256,7 +268,10 @@ async def raise_event( request_url = self._get_raise_event_url( instance_id, event_name, task_hub_name, connection_name) - response = await self._post_async_request(request_url, json.dumps(event_data)) + response = await self._post_async_request( + request_url, + json.dumps(event_data), + function_invocation_id=self._function_invocation_id) switch_statement = { 202: lambda: None, @@ -445,7 +460,10 @@ async def terminate(self, instance_id: str, reason: str) -> None: """ request_url = f"{self._orchestration_bindings.rpc_base_url}instances/{instance_id}/" \ f"terminate?reason={quote(reason)}" - response = await self._post_async_request(request_url, None) + response = await self._post_async_request( + request_url, + None, + function_invocation_id=self._function_invocation_id) switch_statement = { 202: lambda: None, # instance in progress 410: lambda: None, # instance failed or terminated @@ -564,7 +582,8 @@ async def signal_entity(self, entityId: EntityId, operation_name: str, request_url, json.dumps(operation_input) if operation_input else None, trace_parent, - trace_state) + trace_state, + self._function_invocation_id) switch_statement = { 202: lambda: None # signal accepted @@ -714,7 +733,10 @@ async def rewind(self, raise Exception("The Python SDK only supports RPC endpoints." + "Please remove the `localRpcEnabled` setting from host.json") - response = await self._post_async_request(request_url, None) + response = await self._post_async_request( + request_url, + None, + function_invocation_id=self._function_invocation_id) status: int = response[0] ex_msg: str = "" if status == 200 or status == 202: @@ -753,7 +775,10 @@ async def suspend(self, instance_id: str, reason: str) -> None: """ request_url = f"{self._orchestration_bindings.rpc_base_url}instances/{instance_id}/" \ f"suspend?reason={quote(reason)}" - response = await self._post_async_request(request_url, None) + response = await self._post_async_request( + request_url, + None, + function_invocation_id=self._function_invocation_id) switch_statement = { 202: lambda: None, # instance is suspended 410: lambda: None, # instance completed @@ -788,7 +813,10 @@ async def resume(self, instance_id: str, reason: str) -> None: """ request_url = f"{self._orchestration_bindings.rpc_base_url}instances/{instance_id}/" \ f"resume?reason={quote(reason)}" - response = await self._post_async_request(request_url, None) + response = await self._post_async_request( + request_url, + None, + function_invocation_id=self._function_invocation_id) switch_statement = { 202: lambda: None, # instance is resumed 410: lambda: None, # instance completed diff --git a/azure/durable_functions/models/utils/http_utils.py b/azure/durable_functions/models/utils/http_utils.py index eaa3a07d..cfc62401 100644 --- a/azure/durable_functions/models/utils/http_utils.py +++ b/azure/durable_functions/models/utils/http_utils.py @@ -1,4 +1,4 @@ -from typing import Any, List, Union +from typing import Any, List, Union, Optional import aiohttp @@ -6,7 +6,8 @@ async def post_async_request(url: str, data: Any = None, trace_parent: str = None, - trace_state: str = None) -> List[Union[int, Any]]: + trace_state: str = None, + function_invocation_id: str = None) -> List[Union[int, Any]]: """Post request with the data provided to the url provided. Parameters @@ -19,6 +20,8 @@ async def post_async_request(url: str, traceparent header to send with the request trace_state: str tracestate header to send with the request + function_invocation_id: str + function invocation ID header to send for correlation Returns ------- @@ -31,6 +34,8 @@ async def post_async_request(url: str, headers["traceparent"] = trace_parent if trace_state: headers["tracestate"] = trace_state + if function_invocation_id: + headers["X-Azure-Functions-InvocationId"] = function_invocation_id async with session.post(url, json=data, headers=headers) as response: # We disable aiohttp's input type validation # as the server may respond with alternative @@ -40,13 +45,16 @@ async def post_async_request(url: str, return [response.status, data] -async def get_async_request(url: str) -> List[Any]: +async def get_async_request(url: str, + function_invocation_id: str = None) -> List[Any]: """Get the data from the url provided. Parameters ---------- url: str url to get the data from + function_invocation_id: str + function invocation ID header to send for correlation Returns ------- @@ -54,20 +62,26 @@ async def get_async_request(url: str) -> List[Any]: Tuple with the Response status code and the data returned from the request """ async with aiohttp.ClientSession() as session: - async with session.get(url) as response: + headers = {} + if function_invocation_id: + headers["X-Azure-Functions-InvocationId"] = function_invocation_id + async with session.get(url, headers=headers) as response: data = await response.json(content_type=None) if data is None: data = "" return [response.status, data] -async def delete_async_request(url: str) -> List[Union[int, Any]]: +async def delete_async_request(url: str, + function_invocation_id: str = None) -> List[Union[int, Any]]: """Delete the data from the url provided. Parameters ---------- url: str url to delete the data from + function_invocation_id: str + function invocation ID header to send for correlation Returns ------- @@ -75,6 +89,9 @@ async def delete_async_request(url: str) -> List[Union[int, Any]]: Tuple with the Response status code and the data returned from the request """ async with aiohttp.ClientSession() as session: - async with session.delete(url) as response: + headers = {} + if function_invocation_id: + headers["X-Azure-Functions-InvocationId"] = function_invocation_id + async with session.delete(url, headers=headers) as response: data = await response.json(content_type=None) return [response.status, data] diff --git a/tests/models/test_DurableOrchestrationClient.py b/tests/models/test_DurableOrchestrationClient.py index 7707a63a..c12f577b 100644 --- a/tests/models/test_DurableOrchestrationClient.py +++ b/tests/models/test_DurableOrchestrationClient.py @@ -739,3 +739,73 @@ async def test_post_500_resume(binding_string): with pytest.raises(Exception): await client.resume(TEST_INSTANCE_ID, raw_reason) + + +# Tests for function_invocation_id parameter +def test_client_stores_function_invocation_id(binding_string): + """Test that the client stores the function_invocation_id parameter.""" + invocation_id = "test-invocation-123" + client = DurableOrchestrationClient(binding_string, function_invocation_id=invocation_id) + assert client._function_invocation_id == invocation_id + + +def test_client_stores_none_when_no_invocation_id(binding_string): + """Test that the client stores None when no invocation ID is provided.""" + client = DurableOrchestrationClient(binding_string) + assert client._function_invocation_id is None + + +class MockRequestWithInvocationId: + """Mock request class that verifies function_invocation_id is passed.""" + + def __init__(self, expected_url: str, response: [int, any], expected_invocation_id: str = None): + self._expected_url = expected_url + self._response = response + self._expected_invocation_id = expected_invocation_id + self._received_invocation_id = None + + @property + def received_invocation_id(self): + return self._received_invocation_id + + async def post(self, url: str, data: Any = None, trace_parent: str = None, + trace_state: str = None, function_invocation_id: str = None): + assert url == self._expected_url + self._received_invocation_id = function_invocation_id + if self._expected_invocation_id is not None: + assert function_invocation_id == self._expected_invocation_id + return self._response + + +@pytest.mark.asyncio +async def test_start_new_passes_invocation_id(binding_string): + """Test that start_new passes the function_invocation_id to the HTTP request.""" + invocation_id = "test-invocation-456" + function_name = "MyOrchestrator" + + mock_request = MockRequestWithInvocationId( + expected_url=f"{RPC_BASE_URL}orchestrators/{function_name}", + response=[202, {"id": TEST_INSTANCE_ID}], + expected_invocation_id=invocation_id) + + client = DurableOrchestrationClient(binding_string, function_invocation_id=invocation_id) + client._post_async_request = mock_request.post + + await client.start_new(function_name) + assert mock_request.received_invocation_id == invocation_id + + +@pytest.mark.asyncio +async def test_start_new_passes_none_when_no_invocation_id(binding_string): + """Test that start_new passes None when no invocation ID is provided.""" + function_name = "MyOrchestrator" + + mock_request = MockRequestWithInvocationId( + expected_url=f"{RPC_BASE_URL}orchestrators/{function_name}", + response=[202, {"id": TEST_INSTANCE_ID}]) + + client = DurableOrchestrationClient(binding_string) + client._post_async_request = mock_request.post + + await client.start_new(function_name) + assert mock_request.received_invocation_id is None