Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 93 additions & 1 deletion conformance/client_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import sys
import traceback
from collections.abc import AsyncGenerator
from types import TracebackType
from typing import Any

from connect.connect import StreamRequest, UnaryRequest
Expand All @@ -25,6 +26,93 @@
logger = logging.getLogger("conformance.runner")


class URLLockManager:
"""A manager for handling locks on URLs to ensure sequential execution of tasks with the same URL.

This class maintains a set of URLs that are currently in use and provides locks to ensure
that tasks operating on the same URL are executed sequentially. This prevents race conditions
when multiple tasks are trying to access the same URL simultaneously.
"""

def __init__(self) -> None:
"""Initialize the URL lock manager with an empty set of URLs and a lock."""
self._urls_in_use: set[str] = set()
self._lock = asyncio.Lock()
self._url_locks: dict[str, asyncio.Lock] = {}

async def acquire(self, url: str) -> None:
"""Acquire a lock for the specified URL.

If the URL is already in use, this method will block until the URL is available.

Args:
url (str): The URL to acquire a lock for.

"""
async with self._lock:
if url not in self._url_locks:
self._url_locks[url] = asyncio.Lock()

await self._url_locks[url].acquire()

async with self._lock:
self._urls_in_use.add(url)

def release(self, url: str) -> None:
"""Release the lock for the specified URL.

Args:
url (str): The URL to release the lock for.

"""
if url in self._urls_in_use:
self._urls_in_use.remove(url)
self._url_locks[url].release()

async def __call__(self, url: str) -> "URLLockContext":
"""Create a context manager for the specified URL.

Args:
url (str): The URL to acquire a lock for.

Returns:
A context manager that acquires and releases the lock for the URL.

"""
return URLLockContext(self, url)


class URLLockContext:
"""A context manager for acquiring and releasing URL locks."""

def __init__(self, manager: URLLockManager, url: str):
"""Initialize the context manager with a URL lock manager and URL.

Args:
manager (URLLockManager): The URL lock manager to use.
url (str): The URL to acquire a lock for.

"""
self._manager = manager
self._url = url

async def __aenter__(self) -> None:
"""Acquire the lock for the URL."""
await self._manager.acquire(self._url)

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
"""Release the lock for the URL."""
self._manager.release(self._url)


url_lock_manager = URLLockManager()


def read_request() -> client_compat_pb2.ClientCompatRequest | None:
"""Read a serialized `ClientCompatRequest` message from standard input.

Expand Down Expand Up @@ -186,6 +274,7 @@ async def handle_message(msg: client_compat_pb2.ClientCompatRequest) -> client_c
for making requests.
- Compression (e.g., gzip) is applied if specified in the request.
- Headers and trailers are converted to protobuf-compatible formats.
- Uses a URL lock manager to ensure that tasks with the same URL are executed sequentially.

"""
reqs = unpack_requests(msg.request_messages)
Expand Down Expand Up @@ -214,7 +303,10 @@ async def handle_message(msg: client_compat_pb2.ClientCompatRequest) -> client_c

url = f"{proto}://{msg.host}:{msg.port}"

async with AsyncClientSession(http1=http1, http2=http2, ssl_context=ssl_context) as session:
async with (
await url_lock_manager(url),
AsyncClientSession(http1=http1, http2=http2, ssl_context=ssl_context) as session,
):
payloads = []
try:
options = ClientOptions()
Expand Down
Loading