diff --git a/azure/durable_functions/models/utils/http_utils.py b/azure/durable_functions/models/utils/http_utils.py index eaa3a07d..f93bc8df 100644 --- a/azure/durable_functions/models/utils/http_utils.py +++ b/azure/durable_functions/models/utils/http_utils.py @@ -1,8 +1,80 @@ -from typing import Any, List, Union +from typing import Any, List, Union, Optional +import asyncio import aiohttp +# Global session and lock for thread-safe initialization +_client_session: Optional[aiohttp.ClientSession] = None +_session_lock: asyncio.Lock = asyncio.Lock() + + +async def _get_session() -> aiohttp.ClientSession: + """Get or create the shared ClientSession. + + Returns + ------- + aiohttp.ClientSession + The shared client session with configured timeout and connection pooling. + """ + global _client_session + + # Double-check locking pattern for async + if _client_session is None or _client_session.closed: + async with _session_lock: + # Check again after acquiring lock + if _client_session is None or _client_session.closed: + # Configure timeout optimized for localhost IPC + timeout = aiohttp.ClientTimeout( + total=240, # 4-minute total timeout for slow operations + sock_connect=10, # Fast connection over localhost + sock_read=None # Covered by total timeout + ) + + # Configure TCP connector optimized for localhost IPC + connector = aiohttp.TCPConnector( + limit=30, # Maximum connections for single host + limit_per_host=30, # Maximum connections per host + enable_cleanup_closed=True # Enable cleanup of closed connections + ) + + _client_session = aiohttp.ClientSession( + timeout=timeout, + connector=connector + ) + + return _client_session + + +async def _handle_request_error(): + """Handle connection errors by closing and resetting the session. + + This handles cases where the remote host process recycles. + """ + global _client_session + async with _session_lock: + if _client_session is not None and not _client_session.closed: + try: + await _client_session.close() + finally: + _client_session = None + + +async def _close_session() -> None: + """Close the shared ClientSession if it exists. + + This function should be called during worker shutdown. + """ + global _client_session + + async with _session_lock: + if _client_session is not None and not _client_session.closed: + try: + await _client_session.close() + finally: + _client_session = None + + async def post_async_request(url: str, data: Any = None, trace_parent: str = None, @@ -25,12 +97,14 @@ async def post_async_request(url: str, [int, Any] Tuple with the Response status code and the data returned from the request """ - async with aiohttp.ClientSession() as session: - headers = {} - if trace_parent: - headers["traceparent"] = trace_parent - if trace_state: - headers["tracestate"] = trace_state + session = await _get_session() + headers = {} + if trace_parent: + headers["traceparent"] = trace_parent + if trace_state: + headers["tracestate"] = trace_state + + try: async with session.post(url, json=data, headers=headers) as response: # We disable aiohttp's input type validation # as the server may respond with alternative @@ -38,6 +112,10 @@ async def post_async_request(url: str, # More here: https://docs.aiohttp.org/en/stable/client_advanced.html data = await response.json(content_type=None) return [response.status, data] + except (aiohttp.ClientError, asyncio.TimeoutError): + # On connection errors, close and recreate session for next request + await _handle_request_error() + raise async def get_async_request(url: str) -> List[Any]: @@ -53,12 +131,18 @@ async def get_async_request(url: str) -> List[Any]: [int, Any] Tuple with the Response status code and the data returned from the request """ - async with aiohttp.ClientSession() as session: + session = await _get_session() + + try: async with session.get(url) as response: data = await response.json(content_type=None) if data is None: data = "" return [response.status, data] + except (aiohttp.ClientError, asyncio.TimeoutError): + # On connection errors, close and recreate session for next request + await _handle_request_error() + raise async def delete_async_request(url: str) -> List[Union[int, Any]]: @@ -74,7 +158,13 @@ async def delete_async_request(url: str) -> List[Union[int, Any]]: [int, Any] Tuple with the Response status code and the data returned from the request """ - async with aiohttp.ClientSession() as session: + session = await _get_session() + + try: async with session.delete(url) as response: data = await response.json(content_type=None) return [response.status, data] + except (aiohttp.ClientError, asyncio.TimeoutError): + # On connection errors, close and recreate session for next request + await _handle_request_error() + raise diff --git a/tests/utils/test_http_utils.py b/tests/utils/test_http_utils.py new file mode 100644 index 00000000..3b3b165a --- /dev/null +++ b/tests/utils/test_http_utils.py @@ -0,0 +1,287 @@ +"""Tests for http_utils module to verify ClientSession reuse.""" +import pytest +from unittest.mock import AsyncMock, patch, Mock +from azure.durable_functions.models.utils import http_utils + + +@pytest.mark.asyncio +async def test_session_is_reused_across_requests(): + """Test that the same session is reused for multiple requests.""" + # Reset the session to start fresh + http_utils._client_session = None + + # Make first request to create session + with patch('aiohttp.ClientSession') as mock_session_class: + mock_session = Mock() + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.json = AsyncMock(return_value={"result": "success"}) + + # Create a proper async context manager + mock_post_context = AsyncMock() + mock_post_context.__aenter__.return_value = mock_response + mock_post_context.__aexit__.return_value = None + mock_session.post.return_value = mock_post_context + mock_session.closed = False + mock_session_class.return_value = mock_session + + # First request + await http_utils.post_async_request("http://test.com", + {"data": "test1"}) + + # Verify session was created once + assert mock_session_class.call_count == 1 + first_session = http_utils._client_session + + # Second request - should reuse same session + await http_utils.post_async_request("http://test.com", + {"data": "test2"}) + + # Verify session was NOT created again + assert mock_session_class.call_count == 1 + assert http_utils._client_session is first_session + + +@pytest.mark.asyncio +async def test_session_recreated_after_close(): + """Test that a new session is created if the previous one was closed.""" + # Reset the session + http_utils._client_session = None + + with patch('aiohttp.ClientSession') as mock_session_class: + mock_session1 = Mock() + mock_session1.closed = False + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.json = AsyncMock(return_value={"result": "success"}) + + mock_post_context = AsyncMock() + mock_post_context.__aenter__.return_value = mock_response + mock_post_context.__aexit__.return_value = None + mock_session1.post.return_value = mock_post_context + + mock_session2 = Mock() + mock_session2.closed = False + mock_session2.post.return_value = mock_post_context + + mock_session_class.side_effect = [mock_session1, mock_session2] + + # First request creates session + await http_utils.post_async_request("http://test.com", + {"data": "test1"}) + assert mock_session_class.call_count == 1 + + # Simulate session being closed + mock_session1.closed = True + + # Second request should create new session + await http_utils.post_async_request("http://test.com", + {"data": "test2"}) + assert mock_session_class.call_count == 2 + + +@pytest.mark.asyncio +async def test_session_closed_on_connection_error(): + """Test that session is closed and reset on connection errors.""" + # Reset the session + http_utils._client_session = None + + with patch('aiohttp.ClientSession') as mock_session_class: + mock_session = Mock() + mock_session.closed = False + mock_session.close = AsyncMock() + + # First request succeeds + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.json = AsyncMock(return_value={"result": "success"}) + + mock_post_context_success = AsyncMock() + mock_post_context_success.__aenter__.return_value = mock_response + mock_post_context_success.__aexit__.return_value = None + + mock_session.post.return_value = mock_post_context_success + mock_session_class.return_value = mock_session + + await http_utils.post_async_request("http://test.com", + {"data": "test1"}) + assert http_utils._client_session is not None + + # Second request raises connection error + from aiohttp import ClientError + mock_post_context_error = AsyncMock() + mock_post_context_error.__aenter__.side_effect = \ + ClientError("Connection failed") + mock_session.post.return_value = mock_post_context_error + + with pytest.raises(ClientError): + await http_utils.post_async_request("http://test.com", + {"data": "test2"}) + + # Verify close was called + mock_session.close.assert_called_once() + + +@pytest.mark.asyncio +async def test_get_request_uses_shared_session(): + """Test that GET requests use the shared session.""" + # Reset the session + http_utils._client_session = None + + with patch('aiohttp.ClientSession') as mock_session_class: + mock_session = Mock() + mock_session.closed = False + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.json = AsyncMock(return_value={"result": "data"}) + + mock_get_context = AsyncMock() + mock_get_context.__aenter__.return_value = mock_response + mock_get_context.__aexit__.return_value = None + mock_session.get.return_value = mock_get_context + mock_session_class.return_value = mock_session + + # Make GET request + await http_utils.get_async_request("http://test.com") + + # Make another GET request + await http_utils.get_async_request("http://test.com") + + # Verify session was created only once + assert mock_session_class.call_count == 1 + + +@pytest.mark.asyncio +async def test_delete_request_uses_shared_session(): + """Test that DELETE requests use the shared session.""" + # Reset the session + http_utils._client_session = None + + with patch('aiohttp.ClientSession') as mock_session_class: + mock_session = Mock() + mock_session.closed = False + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.json = AsyncMock(return_value={"result": "deleted"}) + + mock_delete_context = AsyncMock() + mock_delete_context.__aenter__.return_value = mock_response + mock_delete_context.__aexit__.return_value = None + mock_session.delete.return_value = mock_delete_context + mock_session_class.return_value = mock_session + + # Make DELETE request + await http_utils.delete_async_request("http://test.com") + + # Make another DELETE request + await http_utils.delete_async_request("http://test.com") + + # Verify session was created only once + assert mock_session_class.call_count == 1 + + +@pytest.mark.asyncio +async def test_session_configured_with_timeouts(): + """Test that session is configured with appropriate timeouts.""" + # Reset the session + http_utils._client_session = None + + with patch('aiohttp.ClientSession') as mock_session_class, \ + patch('aiohttp.ClientTimeout') as mock_timeout_class, \ + patch('aiohttp.TCPConnector') as mock_connector_class: + + mock_session = Mock() + mock_session.closed = False + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.json = AsyncMock(return_value={"result": "success"}) + + mock_post_context = AsyncMock() + mock_post_context.__aenter__.return_value = mock_response + mock_post_context.__aexit__.return_value = None + mock_session.post.return_value = mock_post_context + mock_session_class.return_value = mock_session + + await http_utils.post_async_request("http://test.com", + {"data": "test"}) + + # Verify timeout was configured for localhost IPC + mock_timeout_class.assert_called_once() + timeout_call = mock_timeout_class.call_args + assert timeout_call.kwargs['total'] == 240 + assert timeout_call.kwargs['sock_connect'] == 10 + assert timeout_call.kwargs['sock_read'] is None + + # Verify connector was configured for localhost IPC + mock_connector_class.assert_called_once() + connector_call = mock_connector_class.call_args + assert connector_call.kwargs['limit'] == 30 + assert connector_call.kwargs['limit_per_host'] == 30 + + +@pytest.mark.asyncio +async def test_close_session(): + """Test the _close_session function.""" + # Reset and create a session + http_utils._client_session = None + + with patch('aiohttp.ClientSession') as mock_session_class: + mock_session = Mock() + mock_session.closed = False + mock_session.close = AsyncMock() + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.json = AsyncMock(return_value={"result": "success"}) + + mock_post_context = AsyncMock() + mock_post_context.__aenter__.return_value = mock_response + mock_post_context.__aexit__.return_value = None + mock_session.post.return_value = mock_post_context + mock_session_class.return_value = mock_session + + # Create session + await http_utils.post_async_request("http://test.com", + {"data": "test"}) + assert http_utils._client_session is not None + + # Close session + await http_utils._close_session() + + # Verify close was called and session is None + mock_session.close.assert_called_once() + assert http_utils._client_session is None + + +@pytest.mark.asyncio +async def test_trace_headers_are_passed(): + """Test that trace headers are properly passed in requests.""" + # Reset the session + http_utils._client_session = None + + with patch('aiohttp.ClientSession') as mock_session_class: + mock_session = Mock() + mock_session.closed = False + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.json = AsyncMock(return_value={"result": "success"}) + + mock_post_context = AsyncMock() + mock_post_context.__aenter__.return_value = mock_response + mock_post_context.__aexit__.return_value = None + mock_session.post.return_value = mock_post_context + mock_session_class.return_value = mock_session + + trace_parent = "00-trace-id-parent" + trace_state = "state=value" + + await http_utils.post_async_request( + "http://test.com", + {"data": "test"}, + trace_parent=trace_parent, + trace_state=trace_state + ) + + # Verify headers were passed + call_args = mock_session.post.call_args + assert call_args.kwargs['headers']['traceparent'] == trace_parent + assert call_args.kwargs['headers']['tracestate'] == trace_state