Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@

All notable changes to this project will be documented in this file.

## Unreleased

### Added

- Client operation correlation logging: `FunctionInvocationId` is now propagated via HTTP headers to the host for client operations, enabling correlation with host logs.

## 1.0.0b6

- [Create timer](https://github.com/Azure/azure-functions-durable-python/issues/35) functionality available
Expand Down
9 changes: 8 additions & 1 deletion azure/durable_functions/decorators/durable_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,14 @@ async def df_client_middleware(*args, **kwargs):
# construct rich object from it,
# and assign parameter to that rich object
starter = kwargs[parameter_name]
client = client_constructor(starter)

# Try to extract the function invocation ID from the context for correlation
function_invocation_id = None
context = kwargs.get('context')
if context is not None and hasattr(context, 'invocation_id'):
function_invocation_id = context.invocation_id

client = client_constructor(starter, function_invocation_id)
kwargs[parameter_name] = client

# Invoke user code with rich DF Client binding
Expand Down
44 changes: 36 additions & 8 deletions azure/durable_functions/models/DurableOrchestrationClient.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,16 @@ class DurableOrchestrationClient:
orchestration instances.
"""

def __init__(self, context: str):
def __init__(self, context: str, function_invocation_id: Optional[str] = None):
"""Initialize a DurableOrchestrationClient.

Parameters
----------
context : str
The JSON-encoded client binding context.
function_invocation_id : Optional[str]
The function invocation ID for correlation with host-side logs.
"""
self.task_hub_name: str
self._uniqueWebHookOrigins: List[str]
self._event_name_placeholder: str = "{eventName}"
Expand All @@ -39,6 +48,7 @@ def __init__(self, context: str):
self._show_history_query_key: str = "showHistory"
self._show_history_output_query_key: str = "showHistoryOutput"
self._show_input_query_key: str = "showInput"
self._function_invocation_id: Optional[str] = function_invocation_id
self._orchestration_bindings: DurableOrchestrationBindings = \
DurableOrchestrationBindings.from_json(context)
self._post_async_request = post_async_request
Expand Down Expand Up @@ -84,7 +94,8 @@ async def start_new(self,
request_url,
self._get_json_input(client_input),
trace_parent,
trace_state)
trace_state,
self._function_invocation_id)

status_code: int = response[0]
if status_code <= 202 and response[1]:
Expand All @@ -100,6 +111,7 @@ async def start_new(self,
ex_message: Any = response[1]
raise Exception(ex_message)


def create_check_status_response(
self, request: func.HttpRequest, instance_id: str) -> func.HttpResponse:
"""Create a HttpResponse that contains useful information for \
Expand Down Expand Up @@ -256,7 +268,10 @@ async def raise_event(
request_url = self._get_raise_event_url(
instance_id, event_name, task_hub_name, connection_name)

response = await self._post_async_request(request_url, json.dumps(event_data))
response = await self._post_async_request(
request_url,
json.dumps(event_data),
function_invocation_id=self._function_invocation_id)

switch_statement = {
202: lambda: None,
Expand Down Expand Up @@ -445,7 +460,10 @@ async def terminate(self, instance_id: str, reason: str) -> None:
"""
request_url = f"{self._orchestration_bindings.rpc_base_url}instances/{instance_id}/" \
f"terminate?reason={quote(reason)}"
response = await self._post_async_request(request_url, None)
response = await self._post_async_request(
request_url,
None,
function_invocation_id=self._function_invocation_id)
switch_statement = {
202: lambda: None, # instance in progress
410: lambda: None, # instance failed or terminated
Expand Down Expand Up @@ -564,7 +582,8 @@ async def signal_entity(self, entityId: EntityId, operation_name: str,
request_url,
json.dumps(operation_input) if operation_input else None,
trace_parent,
trace_state)
trace_state,
self._function_invocation_id)

switch_statement = {
202: lambda: None # signal accepted
Expand Down Expand Up @@ -714,7 +733,10 @@ async def rewind(self,
raise Exception("The Python SDK only supports RPC endpoints."
+ "Please remove the `localRpcEnabled` setting from host.json")

response = await self._post_async_request(request_url, None)
response = await self._post_async_request(
request_url,
None,
function_invocation_id=self._function_invocation_id)
status: int = response[0]
ex_msg: str = ""
if status == 200 or status == 202:
Expand Down Expand Up @@ -753,7 +775,10 @@ async def suspend(self, instance_id: str, reason: str) -> None:
"""
request_url = f"{self._orchestration_bindings.rpc_base_url}instances/{instance_id}/" \
f"suspend?reason={quote(reason)}"
response = await self._post_async_request(request_url, None)
response = await self._post_async_request(
request_url,
None,
function_invocation_id=self._function_invocation_id)
switch_statement = {
202: lambda: None, # instance is suspended
410: lambda: None, # instance completed
Expand Down Expand Up @@ -788,7 +813,10 @@ async def resume(self, instance_id: str, reason: str) -> None:
"""
request_url = f"{self._orchestration_bindings.rpc_base_url}instances/{instance_id}/" \
f"resume?reason={quote(reason)}"
response = await self._post_async_request(request_url, None)
response = await self._post_async_request(
request_url,
None,
function_invocation_id=self._function_invocation_id)
switch_statement = {
202: lambda: None, # instance is resumed
410: lambda: None, # instance completed
Expand Down
29 changes: 23 additions & 6 deletions azure/durable_functions/models/utils/http_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from typing import Any, List, Union
from typing import Any, List, Union, Optional

import aiohttp


async def post_async_request(url: str,
data: Any = None,
trace_parent: str = None,
trace_state: str = None) -> List[Union[int, Any]]:
trace_state: str = None,
function_invocation_id: str = None) -> List[Union[int, Any]]:
"""Post request with the data provided to the url provided.

Parameters
Expand All @@ -19,6 +20,8 @@ async def post_async_request(url: str,
traceparent header to send with the request
trace_state: str
tracestate header to send with the request
function_invocation_id: str
function invocation ID header to send for correlation

Returns
-------
Expand All @@ -31,6 +34,8 @@ async def post_async_request(url: str,
headers["traceparent"] = trace_parent
if trace_state:
headers["tracestate"] = trace_state
if function_invocation_id:
headers["X-Azure-Functions-InvocationId"] = function_invocation_id
async with session.post(url, json=data, headers=headers) as response:
# We disable aiohttp's input type validation
# as the server may respond with alternative
Expand All @@ -40,41 +45,53 @@ async def post_async_request(url: str,
return [response.status, data]


async def get_async_request(url: str) -> List[Any]:
async def get_async_request(url: str,
function_invocation_id: str = None) -> List[Any]:
"""Get the data from the url provided.

Parameters
----------
url: str
url to get the data from
function_invocation_id: str
function invocation ID header to send for correlation

Returns
-------
[int, Any]
Tuple with the Response status code and the data returned from the request
"""
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
headers = {}
if function_invocation_id:
headers["X-Azure-Functions-InvocationId"] = function_invocation_id
async with session.get(url, headers=headers) as response:
data = await response.json(content_type=None)
if data is None:
data = ""
return [response.status, data]


async def delete_async_request(url: str) -> List[Union[int, Any]]:
async def delete_async_request(url: str,
function_invocation_id: str = None) -> List[Union[int, Any]]:
"""Delete the data from the url provided.

Parameters
----------
url: str
url to delete the data from
function_invocation_id: str
function invocation ID header to send for correlation

Returns
-------
[int, Any]
Tuple with the Response status code and the data returned from the request
"""
async with aiohttp.ClientSession() as session:
async with session.delete(url) as response:
headers = {}
if function_invocation_id:
headers["X-Azure-Functions-InvocationId"] = function_invocation_id
async with session.delete(url, headers=headers) as response:
data = await response.json(content_type=None)
return [response.status, data]
70 changes: 70 additions & 0 deletions tests/models/test_DurableOrchestrationClient.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,3 +739,73 @@ async def test_post_500_resume(binding_string):

with pytest.raises(Exception):
await client.resume(TEST_INSTANCE_ID, raw_reason)


# Tests for function_invocation_id parameter
def test_client_stores_function_invocation_id(binding_string):
"""Test that the client stores the function_invocation_id parameter."""
invocation_id = "test-invocation-123"
client = DurableOrchestrationClient(binding_string, function_invocation_id=invocation_id)
assert client._function_invocation_id == invocation_id


def test_client_stores_none_when_no_invocation_id(binding_string):
"""Test that the client stores None when no invocation ID is provided."""
client = DurableOrchestrationClient(binding_string)
assert client._function_invocation_id is None


class MockRequestWithInvocationId:
"""Mock request class that verifies function_invocation_id is passed."""

def __init__(self, expected_url: str, response: [int, any], expected_invocation_id: str = None):
self._expected_url = expected_url
self._response = response
self._expected_invocation_id = expected_invocation_id
self._received_invocation_id = None

@property
def received_invocation_id(self):
return self._received_invocation_id

async def post(self, url: str, data: Any = None, trace_parent: str = None,
trace_state: str = None, function_invocation_id: str = None):
assert url == self._expected_url
self._received_invocation_id = function_invocation_id
if self._expected_invocation_id is not None:
assert function_invocation_id == self._expected_invocation_id
return self._response


@pytest.mark.asyncio
async def test_start_new_passes_invocation_id(binding_string):
"""Test that start_new passes the function_invocation_id to the HTTP request."""
invocation_id = "test-invocation-456"
function_name = "MyOrchestrator"

mock_request = MockRequestWithInvocationId(
expected_url=f"{RPC_BASE_URL}orchestrators/{function_name}",
response=[202, {"id": TEST_INSTANCE_ID}],
expected_invocation_id=invocation_id)

client = DurableOrchestrationClient(binding_string, function_invocation_id=invocation_id)
client._post_async_request = mock_request.post

await client.start_new(function_name)
assert mock_request.received_invocation_id == invocation_id


@pytest.mark.asyncio
async def test_start_new_passes_none_when_no_invocation_id(binding_string):
"""Test that start_new passes None when no invocation ID is provided."""
function_name = "MyOrchestrator"

mock_request = MockRequestWithInvocationId(
expected_url=f"{RPC_BASE_URL}orchestrators/{function_name}",
response=[202, {"id": TEST_INSTANCE_ID}])

client = DurableOrchestrationClient(binding_string)
client._post_async_request = mock_request.post

await client.start_new(function_name)
assert mock_request.received_invocation_id is None
Loading