async def test_coordination_lock_full_lifecycle(self, aio_connection):
    """Drive an async coordination lock through create/describe/update/acquire/delete.

    A fresh coordination node is created, a semaphore named ``test_lock`` with
    limit 1 is created on it, and two contenders take the lock one after the
    other. Every lock object opens its own session stream, so all of them are
    tracked and closed in ``finally`` to avoid leaking background tasks.
    """
    client = aio.CoordinationClient(aio_connection)
    node_path = "/local/test_lock_full_lifecycle"

    # Start from a clean slate; the node may be left over from a previous run.
    try:
        await client.delete_node(node_path)
    except ydb.SchemeError:
        pass

    await client.create_node(
        node_path,
        NodeConfig(
            session_grace_period_millis=1000,
            attach_consistency_mode=ConsistencyMode.STRICT,
            read_consistency_mode=ConsistencyMode.STRICT,
            rate_limiter_counters_mode=RateLimiterCountersMode.UNSET,
            self_check_period_millis=0,
        ),
    )

    opened_locks = []

    def new_lock():
        # Remember every handle so its session can be shut down at the end.
        handle = client.lock("test_lock", node_path)
        opened_locks.append(handle)
        return handle

    lock = new_lock()
    try:
        create_resp: CreateSemaphoreResult = await lock.create(init_limit=1, init_data=b"init-data")
        assert create_resp.status == StatusCode.SUCCESS

        describe_resp: DescribeLockResult = await lock.describe()
        assert describe_resp.status == StatusCode.SUCCESS
        assert describe_resp.name == "test_lock"
        assert describe_resp.data == b"init-data"
        assert describe_resp.count == 0
        assert describe_resp.ephemeral is False
        assert list(describe_resp.owners) == []
        assert list(describe_resp.waiters) == []

        update_resp = await lock.update(new_data=b"updated-data")
        assert update_resp.status == StatusCode.SUCCESS

        describe_resp2: DescribeLockResult = await lock.describe()
        assert describe_resp2.status == StatusCode.SUCCESS
        assert describe_resp2.name == "test_lock"
        assert describe_resp2.data == b"updated-data"
        assert describe_resp2.count == 0
        assert describe_resp2.ephemeral is False
        assert list(describe_resp2.owners) == []
        assert list(describe_resp2.waiters) == []

        lock2_started = asyncio.Event()
        lock2_acquired = asyncio.Event()

        async def second_lock_task():
            lock2_started.set()
            async with new_lock():
                lock2_acquired.set()
                await asyncio.sleep(0.5)

        async with new_lock() as lock1:
            resp: DescribeLockResult = await lock1.describe()
            assert resp.status == StatusCode.SUCCESS
            assert resp.name == "test_lock"
            assert resp.data == b"updated-data"
            assert resp.count == 1
            assert resp.ephemeral is False
            assert len(list(resp.owners)) == 1
            assert list(resp.waiters) == []

            t2 = asyncio.create_task(second_lock_task())
            await lock2_started.wait()

            await asyncio.sleep(0.5)
            # The semaphore limit is 1, so the second contender must still be waiting.
            assert not lock2_acquired.is_set()

        await asyncio.wait_for(lock2_acquired.wait(), timeout=5)
        await asyncio.wait_for(t2, timeout=5)

        async with new_lock() as lock3:
            resp3: DescribeLockResult = await lock3.describe()
            assert resp3.status == StatusCode.SUCCESS
            assert resp3.count == 1

        delete_resp = await lock.delete()
        assert delete_resp.status == StatusCode.SUCCESS

        describe_after_delete: DescribeLockResult = await lock.describe()
        assert describe_after_delete.status == StatusCode.NOT_FOUND
    finally:
        for handle in opened_locks:
            try:
                await handle.close()
            except Exception:
                # Cleanup must not mask the real test outcome; keep a trace instead.
                logger.exception("Failed to close coordination lock during test cleanup")


def test_coordination_lock_full_lifecycle_sync(self, driver_sync):
    """Synchronous twin of the async lifecycle test, using a second thread as contender.

    Uses its own node path so it does not depend on whatever the async test
    left behind, and closes every lock object it created.
    """
    client = CoordinationClient(driver_sync)
    node_path = "/local/test_lock_full_lifecycle_sync"

    try:
        client.delete_node(node_path)
    except ydb.SchemeError:
        pass

    client.create_node(
        node_path,
        NodeConfig(
            session_grace_period_millis=1000,
            attach_consistency_mode=ConsistencyMode.STRICT,
            read_consistency_mode=ConsistencyMode.STRICT,
            rate_limiter_counters_mode=RateLimiterCountersMode.UNSET,
            self_check_period_millis=0,
        ),
    )

    opened_locks = []

    def new_lock():
        handle = client.lock("test_lock", node_path)
        opened_locks.append(handle)
        return handle

    lock = new_lock()
    try:
        create_resp: CreateSemaphoreResult = lock.create(init_limit=1, init_data=b"init-data")
        assert create_resp.status == StatusCode.SUCCESS

        describe_resp: DescribeLockResult = lock.describe()
        assert describe_resp.status == StatusCode.SUCCESS
        assert describe_resp.data == b"init-data"

        update_resp = lock.update(new_data=b"updated-data")
        assert update_resp.status == StatusCode.SUCCESS
        assert lock.describe().data == b"updated-data"

        lock2_ready = threading.Event()
        lock2_acquired = threading.Event()
        thread_exc = {"err": None}

        def second_lock_task():
            try:
                lock2_ready.set()
                with new_lock():
                    lock2_acquired.set()
                    logger.info("Second thread acquired lock")
            except Exception as e:
                logger.exception("second_lock_task failed")
                thread_exc["err"] = e

        t2 = threading.Thread(target=second_lock_task)

        with new_lock() as lock1:
            resp = lock1.describe()
            assert resp.status == StatusCode.SUCCESS
            assert resp.count == 1

            t2.start()
            started = lock2_ready.wait(timeout=2.0)
            assert started, "Second thread did not signal readiness to acquire lock"

            # While the first holder is inside the critical section the second must wait.
            assert not lock2_acquired.wait(timeout=0.5), "Second thread acquired a held lock"

        acquired = lock2_acquired.wait(timeout=10.0)
        t2.join(timeout=5.0)

        if not acquired:
            if thread_exc["err"]:
                raise AssertionError(f"Second thread raised exception: {thread_exc['err']!r}") from thread_exc["err"]
            else:
                raise AssertionError("Second thread did not acquire the lock in time. Check logs for details.")

        assert not t2.is_alive(), "Second thread did not finish after acquiring lock"

        with new_lock() as lock3:
            resp3: DescribeLockResult = lock3.describe()
            assert resp3.status == StatusCode.SUCCESS
            assert resp3.count == 1

        delete_resp = lock.delete()
        assert delete_resp.status == StatusCode.SUCCESS
        time.sleep(0.1)
        describe_after_delete: DescribeLockResult = lock.describe()
        assert describe_after_delete.status == StatusCode.NOT_FOUND
    finally:
        for handle in opened_locks:
            try:
                handle.close()
            except Exception:
                logger.exception("Failed to close coordination lock during test cleanup")
@dataclass
class SessionStart(IToProto):
    """Fields of the ``session_start`` variant of a coordination ``SessionRequest``."""

    path: str
    timeout_millis: int
    description: str = ""
    session_id: int = 0
    seq_no: int = 0
    protection_key: bytes = b""

    def to_proto(self) -> "ydb_coordination_pb2.SessionRequest":
        """Return a ``SessionRequest`` whose ``session_start`` carries these fields."""
        payload = ydb_coordination_pb2.SessionRequest.SessionStart(
            path=self.path,
            session_id=self.session_id,
            timeout_millis=self.timeout_millis,
            description=self.description,
            seq_no=self.seq_no,
            protection_key=self.protection_key,
        )
        return ydb_coordination_pb2.SessionRequest(session_start=payload)


@dataclass
class SessionStop(IToProto):
    """Field-less ``session_stop`` variant of a ``SessionRequest``."""

    def to_proto(self) -> "ydb_coordination_pb2.SessionRequest":
        """Return a ``SessionRequest`` with an empty ``session_stop`` message."""
        payload = ydb_coordination_pb2.SessionRequest.SessionStop()
        return ydb_coordination_pb2.SessionRequest(session_stop=payload)


@dataclass
class Ping(IToProto):
    """``ping`` variant of a ``SessionRequest`` carrying an opaque number."""

    opaque: int = 0

    def to_proto(self) -> "ydb_coordination_pb2.SessionRequest":
        """Return a ``SessionRequest`` whose ``ping`` is a ``PingPong`` with ``opaque``."""
        payload = ydb_coordination_pb2.SessionRequest.PingPong(opaque=self.opaque)
        return ydb_coordination_pb2.SessionRequest(ping=payload)


@dataclass
class CreateSemaphore(IToProto):
    """``create_semaphore`` request: semaphore name, limit and initial data."""

    name: str
    req_id: int
    limit: int
    data: bytes = b""

    def to_proto(self) -> "ydb_coordination_pb2.SessionRequest":
        """Return a ``SessionRequest`` with ``create_semaphore`` filled in."""
        payload = ydb_coordination_pb2.SessionRequest.CreateSemaphore(
            req_id=self.req_id,
            name=self.name,
            limit=self.limit,
            data=self.data,
        )
        return ydb_coordination_pb2.SessionRequest(create_semaphore=payload)


@dataclass
class AcquireSemaphore(IToProto):
    """``acquire_semaphore`` request for ``count`` units of the named semaphore."""

    name: str
    req_id: int
    count: int = 1
    timeout_millis: int = 0
    data: bytes = b""
    ephemeral: bool = False

    def to_proto(self) -> "ydb_coordination_pb2.SessionRequest":
        """Return a ``SessionRequest`` with ``acquire_semaphore`` filled in."""
        payload = ydb_coordination_pb2.SessionRequest.AcquireSemaphore(
            req_id=self.req_id,
            name=self.name,
            timeout_millis=self.timeout_millis,
            count=self.count,
            data=self.data,
            ephemeral=self.ephemeral,
        )
        return ydb_coordination_pb2.SessionRequest(acquire_semaphore=payload)


@dataclass
class ReleaseSemaphore(IToProto):
    """``release_semaphore`` request for the named semaphore."""

    name: str
    req_id: int

    def to_proto(self) -> "ydb_coordination_pb2.SessionRequest":
        """Return a ``SessionRequest`` with ``release_semaphore`` filled in."""
        payload = ydb_coordination_pb2.SessionRequest.ReleaseSemaphore(req_id=self.req_id, name=self.name)
        return ydb_coordination_pb2.SessionRequest(release_semaphore=payload)


@dataclass
class DescribeSemaphore(IToProto):
    """``describe_semaphore`` request with owner/waiter inclusion and watch flags."""

    include_owners: bool
    include_waiters: bool
    name: str
    req_id: int
    watch_data: bool
    watch_owners: bool

    def to_proto(self) -> "ydb_coordination_pb2.SessionRequest":
        """Return a ``SessionRequest`` with ``describe_semaphore`` filled in."""
        payload = ydb_coordination_pb2.SessionRequest.DescribeSemaphore(
            include_owners=self.include_owners,
            include_waiters=self.include_waiters,
            name=self.name,
            req_id=self.req_id,
            watch_data=self.watch_data,
            watch_owners=self.watch_owners,
        )
        return ydb_coordination_pb2.SessionRequest(describe_semaphore=payload)


@dataclass
class UpdateSemaphore(IToProto):
    """``update_semaphore`` request replacing the data attached to a semaphore."""

    name: str
    req_id: int
    data: bytes

    def to_proto(self) -> "ydb_coordination_pb2.SessionRequest":
        """Return a ``SessionRequest`` with ``update_semaphore`` filled in."""
        payload = ydb_coordination_pb2.SessionRequest.UpdateSemaphore(
            req_id=self.req_id,
            name=self.name,
            data=self.data,
        )
        return ydb_coordination_pb2.SessionRequest(update_semaphore=payload)


@dataclass
class DeleteSemaphore(IToProto):
    """``delete_semaphore`` request; ``force`` is forwarded unchanged."""

    name: str
    req_id: int
    force: bool = False

    def to_proto(self) -> "ydb_coordination_pb2.SessionRequest":
        """Return a ``SessionRequest`` with ``delete_semaphore`` filled in."""
        payload = ydb_coordination_pb2.SessionRequest.DeleteSemaphore(
            req_id=self.req_id,
            name=self.name,
            force=self.force,
        )
        return ydb_coordination_pb2.SessionRequest(delete_semaphore=payload)
@dataclass
class FromServer:
    """Thin wrapper around a coordination ``SessionResponse``.

    Attributes not defined here are forwarded to the wrapped message. The
    ``*_result`` accessors return ``None`` when the corresponding field is not
    set, so callers can tell "absent" from "present".
    """

    raw: "ydb_coordination_pb2.SessionResponse"

    @staticmethod
    def from_proto(resp: "ydb_coordination_pb2.SessionResponse") -> "FromServer":
        """Wrap ``resp``; an object that is already a ``FromServer`` is returned as is."""
        if isinstance(resp, FromServer):
            # Avoid stacking wrappers when a caller converts an already converted message.
            return resp
        return FromServer(raw=resp)

    def __getattr__(self, name: str):
        # ``__getattr__`` only runs when normal lookup fails. If ``raw`` itself is
        # missing (instance built without ``__init__``, e.g. by copy/pickle),
        # touching ``self.raw`` below would re-enter this method forever.
        if name == "raw":
            raise AttributeError(name)
        return getattr(self.raw, name)

    def _field_or_none(self, field: str):
        """Return ``raw.<field>`` if the message has it set, otherwise ``None``."""
        raw = self.raw
        return getattr(raw, field) if raw.HasField(field) else None

    @property
    def session_started(self) -> typing.Optional["ydb_coordination_pb2.SessionResponse.SessionStarted"]:
        """The ``session_started`` message, or ``None`` when its ``session_id`` is falsy."""
        s = self.raw.session_started
        return s if s.session_id else None

    @property
    def opaque(self) -> typing.Optional[int]:
        """Opaque number of a ``ping`` message, or ``None`` when this is not a ping."""
        if self.raw.HasField("ping"):
            return self.raw.ping.opaque
        return None

    @property
    def acquire_semaphore_result(self):
        """``acquire_semaphore_result`` if set, else ``None``."""
        return self._field_or_none("acquire_semaphore_result")

    @property
    def create_semaphore_result(self):
        """``create_semaphore_result`` if set, else ``None``."""
        return self._field_or_none("create_semaphore_result")

    @property
    def describe_semaphore_result(self):
        """``describe_semaphore_result`` if set, else ``None``."""
        return self._field_or_none("describe_semaphore_result")

    @property
    def update_semaphore_result(self):
        """``update_semaphore_result`` if set, else ``None``."""
        return self._field_or_none("update_semaphore_result")

    @property
    def delete_semaphore_result(self):
        """``delete_semaphore_result`` if set, else ``None``."""
        return self._field_or_none("delete_semaphore_result")

    @property
    def release_semaphore_result(self):
        """``release_semaphore_result`` if set, else ``None``."""
        return self._field_or_none("release_semaphore_result")
@dataclass
class AcquireSemaphoreResult:
    """Plain-Python copy of a ``SessionResponse.AcquireSemaphoreResult``."""

    req_id: int
    acquired: bool
    status: int

    @staticmethod
    def from_proto(msg: "ydb_coordination_pb2.SessionResponse.AcquireSemaphoreResult") -> "AcquireSemaphoreResult":
        """Copy ``req_id``, ``acquired`` and ``status`` out of ``msg``."""
        return AcquireSemaphoreResult(
            req_id=msg.req_id,
            acquired=msg.acquired,
            status=msg.status,
        )


@dataclass
class CreateSemaphoreResult:
    """Plain-Python copy of a ``SessionResponse.CreateSemaphoreResult``."""

    req_id: int
    status: int

    @staticmethod
    def from_proto(msg: "ydb_coordination_pb2.SessionResponse.CreateSemaphoreResult") -> "CreateSemaphoreResult":
        """Copy ``req_id`` and ``status`` out of ``msg``."""
        return CreateSemaphoreResult(
            req_id=msg.req_id,
            status=msg.status,
        )


@dataclass
class DescribeLockResult:
    """Flattened view of a ``SessionResponse.DescribeSemaphoreResult``.

    The fields of ``msg.semaphore_description`` are lifted to the top level.
    ``owners`` and ``waiters`` are copied into real lists so they match their
    annotation and can be compared with ordinary lists.
    """

    req_id: int
    status: int
    watch_added: bool
    count: int
    data: bytes
    ephemeral: bool
    limit: int
    name: str
    owners: list
    waiters: list

    @staticmethod
    def from_proto(msg: "ydb_coordination_pb2.SessionResponse.DescribeSemaphoreResult") -> "DescribeLockResult":
        """Build the result from ``msg`` and its nested ``semaphore_description``."""
        description = msg.semaphore_description
        return DescribeLockResult(
            req_id=msg.req_id,
            status=msg.status,
            watch_added=msg.watch_added,
            count=description.count,
            data=description.data,
            ephemeral=description.ephemeral,
            limit=description.limit,
            name=description.name,
            # Snapshot the repeated fields instead of keeping the message's containers.
            owners=list(description.owners),
            waiters=list(description.waiters),
        )
00000000..f6d48237 --- /dev/null +++ b/ydb/aio/coordination/__init__.py @@ -0,0 +1,5 @@ +__all__ = [ + "CoordinationClient", +] + +from .client import CoordinationClient diff --git a/ydb/aio/coordination_client.py b/ydb/aio/coordination/client.py similarity index 82% rename from ydb/aio/coordination_client.py rename to ydb/aio/coordination/client.py index 9aab6785..90d86df4 100644 --- a/ydb/aio/coordination_client.py +++ b/ydb/aio/coordination/client.py @@ -9,8 +9,14 @@ from ydb._grpc.grpcwrapper.ydb_coordination_public_types import NodeConfig from ydb.coordination.base_coordination_client import BaseCoordinationClient +from ydb.aio.coordination.lock import CoordinationLock + class CoordinationClient(BaseCoordinationClient): + def __init__(self, driver): + super().__init__(driver) + self._driver = driver + async def create_node(self, path: str, config: Optional[NodeConfig] = None, settings=None): return await self._call_create( CreateNodeRequest(path=path, config=config).to_proto(), @@ -35,5 +41,5 @@ async def delete_node(self, path: str, settings=None): settings=settings, ) - async def lock(self): - raise NotImplementedError("Will be implemented in future release") + def lock(self, lock_name: str, node_path: str): + return CoordinationLock(self, lock_name, node_path=node_path) diff --git a/ydb/aio/coordination/lock.py b/ydb/aio/coordination/lock.py new file mode 100644 index 00000000..bd0a5584 --- /dev/null +++ b/ydb/aio/coordination/lock.py @@ -0,0 +1,207 @@ +import asyncio +from typing import Optional + +from ydb import issues, StatusCode +from ydb._grpc.grpcwrapper.ydb_coordination import ( + AcquireSemaphore, + ReleaseSemaphore, + UpdateSemaphore, + DescribeSemaphore, + CreateSemaphore, + DeleteSemaphore, + FromServer, +) +from ydb._grpc.grpcwrapper.ydb_coordination_public_types import CreateSemaphoreResult, DescribeLockResult +from ydb.aio.coordination.stream import CoordinationStream +from ydb.aio.coordination.reconnector import CoordinationReconnector + + 
class CoordinationLock:
    """Asynchronous lock backed by a semaphore on a YDB coordination node.

    Each instance owns a ``CoordinationReconnector`` and, once connected, a
    ``CoordinationStream``. It can be used as ``async with lock:`` or through
    ``acquire()``/``release()``; ``create``/``describe``/``update``/``delete``
    manage the underlying semaphore. Call ``close()`` to stop the session.
    """

    # request kind -> (result field on the server message, noun used in the timeout error)
    _RESULT_FIELDS = {
        "acquire": ("acquire_semaphore_result", "acquisition"),
        "describe": ("describe_semaphore_result", "describe"),
        "create": ("create_semaphore_result", "create"),
        "update": ("update_semaphore_result", "update"),
        "delete": ("delete_semaphore_result", "delete"),
    }

    def __init__(
        self,
        client,
        name: str,
        node_path: Optional[str] = None,
    ):
        """Remember the target semaphore and prepare (but do not start) the session.

        Args:
            client: coordination client; its ``_driver`` is used for the stream.
            name: semaphore name on the coordination node.
            node_path: path of the coordination node; required before first use.
        """
        self._client = client
        self._driver = client._driver
        self._name = name
        self._node_path = node_path

        self._req_id: Optional[int] = None  # req_id of the acquire currently held or pending
        self._count: int = 1
        self._timeout_millis: int = 30000
        self._next_req_id: int = 1

        self._request_queue: asyncio.Queue = asyncio.Queue()
        self._stream: Optional[CoordinationStream] = None

        self._reconnector = CoordinationReconnector(
            driver=self._driver,
            request_queue=self._request_queue,
            node_path=self._node_path,
            timeout_millis=self._timeout_millis,
        )

        self._wait_timeout: float = self._timeout_millis / 1000.0

    def next_req_id(self) -> int:
        """Return the next request id (1, 2, 3, ...) for this lock."""
        r = self._next_req_id
        self._next_req_id += 1
        return r

    async def send(self, req):
        """Write ``req`` to the session stream.

        Raises:
            issues.Error: if the stream has not been started yet.
        """
        if self._stream is None:
            raise issues.Error("Stream is not started yet")
        await self._stream.send(req)

    async def _ensure_session(self):
        """Start the reconnector and wait for a stream unless one is already live."""
        if self._stream is not None and self._stream.session_id is not None:
            return

        if not self._node_path:
            raise issues.Error("node_path is not set for CoordinationLock")

        self._reconnector.start()
        await self._reconnector.wait_ready()

        self._stream = self._reconnector.get_stream()

    async def _wait_for_response(self, req_id: int, *, kind: str):
        """Wait for the ``<kind>`` result whose ``req_id`` matches.

        Unrelated messages are skipped. The wait is bounded by one overall
        deadline of ``_wait_timeout`` seconds, not by a fresh timeout per message.

        Raises:
            issues.Error: if no matching result arrives before the deadline.
        """
        field, action = self._RESULT_FIELDS.get(kind, (None, "operation"))
        loop = asyncio.get_running_loop()
        deadline = loop.time() + self._wait_timeout

        while True:
            remaining = deadline - loop.time()
            if remaining <= 0:
                break

            try:
                resp = await self._stream.receive(remaining)
            except asyncio.TimeoutError:
                break

            if resp is None:
                # The stream reports a receive timeout by returning None.
                break

            if field is None:
                continue

            result = getattr(FromServer.from_proto(resp), field, None)
            if result is not None and result.req_id == req_id:
                return result

        raise issues.Error(f"Timeout waiting for lock {self._name} {action}")

    async def __aenter__(self):
        """Acquire the semaphore and return ``self``.

        Raises:
            issues.Error: on timeout, or when the server answers without
                granting the semaphore.
        """
        await self._ensure_session()

        req_id = self.next_req_id()
        # Recorded before sending so release()/close() can undo a pending acquire.
        self._req_id = req_id

        req = AcquireSemaphore(
            req_id=req_id,
            name=self._name,
            count=self._count,
            ephemeral=False,
            timeout_millis=self._timeout_millis,
        )

        await self.send(req)
        result = await self._wait_for_response(req_id, kind="acquire")

        if not result.acquired:
            # The server answered but did not grant the lock: entering the
            # protected section here would break mutual exclusion.
            self._req_id = None
            raise issues.Error(f"Failed to acquire lock {self._name}: status={result.status}")

        return self

    async def __aexit__(self, exc_type, exc, tb):
        """Send a release for the current acquire (best effort) and forget it."""
        if self._req_id is not None:
            try:
                req = ReleaseSemaphore(
                    req_id=self._req_id,
                    name=self._name,
                )
                await self.send(req)
            except issues.Error:
                # Best effort: the stream may already be closed or never started.
                pass

        self._req_id = None

    async def acquire(self):
        """Explicit form of ``async with``: acquire and return ``self``."""
        return await self.__aenter__()

    async def release(self):
        """Explicit form of leaving ``async with``."""
        await self.__aexit__(None, None, None)

    async def create(self, init_limit, init_data):
        """Create the semaphore and return a ``CreateSemaphoreResult``."""
        await self._ensure_session()

        req_id = self.next_req_id()
        req = CreateSemaphore(req_id=req_id, name=self._name, limit=init_limit, data=init_data)

        await self.send(req)

        resp = await self._wait_for_response(req_id, kind="create")
        return CreateSemaphoreResult.from_proto(resp)

    async def delete(self):
        """Delete the semaphore and return the server's delete result message."""
        await self._ensure_session()
        req_id = self.next_req_id()
        req = DeleteSemaphore(req_id=req_id, name=self._name)
        await self.send(req)
        resp = await self._wait_for_response(req_id, kind="delete")
        return resp

    async def describe(self):
        """Describe the semaphore, owners and waiters included, as ``DescribeLockResult``."""
        await self._ensure_session()

        req_id = self.next_req_id()

        req = DescribeSemaphore(
            req_id=req_id,
            name=self._name,
            include_owners=True,
            include_waiters=True,
            watch_data=False,
            watch_owners=False,
        )

        await self.send(req)

        resp = await self._wait_for_response(req_id, kind="describe")
        return DescribeLockResult.from_proto(resp)

    async def update(self, new_data):
        """Replace the semaphore data and return the server's update result message."""
        await self._ensure_session()

        req_id = self.next_req_id()
        req = UpdateSemaphore(req_id=req_id, name=self._name, data=new_data)

        await self.send(req)

        resp = await self._wait_for_response(req_id, kind="update")
        return resp

    async def close(self, flush: bool = True):
        """Release a held lock (best effort), stop the reconnector and reset state.

        After ``close()`` the node path is cleared, so the instance cannot open
        a new session.
        """
        try:
            if self._req_id is not None:
                req = ReleaseSemaphore(req_id=self._req_id, name=self._name)
                if self._stream is not None:
                    await self.send(req)
        except issues.Error:
            pass

        if self._reconnector is not None:
            await self._reconnector.stop(flush)

        self._stream = None
        self._req_id = None
        self._node_path = None
class CoordinationReconnector:
    """Owns the background task that opens a ``CoordinationStream`` session.

    ``start()`` spawns the connection task, ``wait_ready()`` blocks until the
    session is established or the attempt failed, ``get_stream()`` hands out
    the live stream and ``stop()`` tears everything down.
    """

    def __init__(
        self,
        driver,
        request_queue: asyncio.Queue,
        node_path: str,
        timeout_millis: int,
    ):
        """Store connection parameters; nothing is started here."""
        self._driver = driver
        self._request_queue = request_queue
        self._node_path = node_path
        self._timeout_millis = timeout_millis

        self._task: Optional[asyncio.Task] = None
        self._stream: Optional[CoordinationStream] = None

        # Set when a connection attempt has finished, successfully or not.
        self._ready: Optional[asyncio.Event] = None
        self._stopped = False
        self._first_error: Optional[Exception] = None

    def start(self):
        """Spawn the connection task unless stopped or one is already running."""
        if self._stopped:
            return

        if self._ready is None:
            self._ready = asyncio.Event()

        if self._task is not None and not self._task.done():
            return

        # Fresh attempt: forget the previous outcome so waiters block until this
        # attempt has finished instead of seeing a stale "ready" or a stale error.
        self._first_error = None
        self._ready.clear()
        self._task = asyncio.create_task(self._connection_loop())

    async def stop(self, flush: bool):
        """Cancel the connection task and close the stream.

        ``flush`` is accepted but not used by this implementation.
        """
        self._stopped = True

        if self._task:
            self._task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await self._task
            self._task = None

        if self._stream:
            with contextlib.suppress(Exception):
                await self._stream.close()
            self._stream = None

        if self._ready:
            self._ready.clear()

    async def wait_ready(self):
        """Wait until the session is up; re-raise the connection error if it failed.

        Raises:
            RuntimeError: if ``start()`` has not created the ready event yet.
            Exception: whatever the connection attempt failed with.
        """
        if self._first_error:
            raise self._first_error
        if not self._ready:
            raise RuntimeError("Reconnector not started")
        await self._ready.wait()
        if self._first_error:
            raise self._first_error

    def get_stream(self) -> "CoordinationStream":
        """Return the connected stream.

        Raises:
            RuntimeError: if there is no stream or it has no session id.
        """
        if self._stream is None or self._stream.session_id is None:
            raise RuntimeError("Coordination stream is not ready")
        return self._stream

    async def _connection_loop(self):
        """Open one session, signal readiness, then watch the stream's background tasks."""
        if self._stopped:
            return

        try:
            self._stream = CoordinationStream(self._driver)
            await self._stream.start_session(self._node_path, self._timeout_millis)
            if self._ready:
                self._ready.set()

            if self._stream._background_tasks:
                done, _pending = await asyncio.wait(
                    self._stream._background_tasks,
                    return_when=asyncio.FIRST_EXCEPTION,
                )

                for finished in done:
                    if finished.cancelled():
                        continue
                    exc = finished.exception()
                    if exc:
                        raise exc

        except Exception as exc:
            self._first_error = exc
            if self._stream:
                with contextlib.suppress(Exception):
                    await self._stream.close()
                self._stream = None
            if self._ready:
                # Wake anyone blocked in wait_ready(); it re-checks _first_error
                # after waking and raises it. Clearing here would leave them
                # blocked forever.
                self._ready.set()
class CoordinationStream:
    """One coordination session over a bidirectional gRPC stream.

    ``start_session`` opens the stream and waits for the server's
    ``session_started``; afterwards a background reader answers pings and
    queues every other message for ``receive()``.
    """

    def __init__(self, driver: "ydb.aio.Driver"):
        """Prepare an unstarted stream bound to ``driver``."""
        self._driver = driver
        self._stream: GrpcWrapperAsyncIO = GrpcWrapperAsyncIO(FromServer.from_proto)
        self._background_tasks: Set[asyncio.Task] = set()
        self._incoming_queue: asyncio.Queue = asyncio.Queue()
        self._closed = False
        self._started = False
        self.session_id: Optional[int] = None

    async def start_session(self, path: str, timeout_millis: int):
        """Open the gRPC stream, send ``SessionStart`` and wait for the session id.

        Raises:
            issues.Error: if already started, if no message arrives within 3
                seconds, or if receiving fails for another reason.
        """
        if self._started:
            raise issues.Error("CoordinationStream already started")

        self._started = True
        await self._stream.start(self._driver, _apis.CoordinationService.Stub, _apis.CoordinationService.Session)

        self._stream.write(SessionStart(path=path, timeout_millis=timeout_millis))

        while True:
            try:
                resp = await self._stream.receive(timeout=3)
                started = resp.session_started
                if started:
                    # Keep the integer id, matching the Optional[int] annotation.
                    self.session_id = started.session_id
                    break
            except asyncio.TimeoutError as e:
                raise issues.Error("Timeout waiting for SessionStart response") from e
            except Exception as e:
                raise issues.Error(f"Failed to start session: {e}") from e

        task = asyncio.get_running_loop().create_task(self._reader_loop())
        self._background_tasks.add(task)

    async def _reader_loop(self):
        """Read server messages until closed or cancelled.

        Pings are answered immediately; everything else goes to the incoming queue.
        """
        while True:
            try:
                # No timeout: an idle session is normal and must not end the
                # reader. close() stops this loop by cancelling the task.
                resp = await self._stream.receive()
                if self._closed:
                    break

                fs = FromServer.from_proto(resp)
                opaque = fs.opaque
                if opaque is not None:
                    # NOTE(review): the reply reuses the ``ping`` request variant;
                    # confirm against the protocol whether a ``pong`` is expected.
                    try:
                        self._stream.write(Ping(opaque))
                    except Exception as exc:
                        raise issues.Error("Failed to write Ping") from exc
                else:
                    await self._incoming_queue.put(resp)
            except asyncio.CancelledError:
                break

    async def send(self, req: "IToProto"):
        """Write ``req`` to the stream.

        Raises:
            issues.Error: if the stream has been closed.
        """
        if self._closed:
            raise issues.Error("Stream closed")
        self._stream.write(req)

    async def receive(self, timeout: Optional[float] = None):
        """Return the next queued server message, or ``None`` if ``timeout`` elapses.

        Raises:
            issues.Error: if the stream has been closed.
        """
        if self._closed:
            raise issues.Error("Stream closed")

        try:
            if timeout is not None:
                return await asyncio.wait_for(self._incoming_queue.get(), timeout)
            else:
                return await self._incoming_queue.get()
        except asyncio.TimeoutError:
            return None

    async def close(self):
        """Cancel background tasks, close the gRPC stream and drop the session id.

        Calling it again is a no-op.
        """
        if self._closed:
            return
        self._closed = True

        for task in list(self._background_tasks):
            task.cancel()

        with contextlib.suppress(asyncio.CancelledError):
            await asyncio.gather(*self._background_tasks, return_exceptions=True)

        self._background_tasks.clear()

        if self._stream:
            self._stream.close()
            self._stream = None

        self.session_id = None
class CoordinationLockSync:
    """Blocking facade over the async ``CoordinationLock``.

    Every operation is run on the shared event loop through
    ``CallFromSyncToAsync`` with a timeout (default ``_timeout_sec`` seconds).
    Usable as ``with lock:``; call ``close()`` to shut the session down.
    """

    def __init__(
        self,
        client,
        name: str,
        node_path: Optional[str] = None,
    ):
        """Create the underlying async lock on the shared event loop."""
        self._closed = False
        self._name = name
        self._caller = CallFromSyncToAsync(_get_shared_event_loop())
        self._timeout_sec = 5

        async def _make_lock():
            # Constructed inside the loop so its asyncio primitives are created there.
            return CoordinationLock(
                client=client,
                name=self._name,
                node_path=node_path,
            )

        self._async_lock: CoordinationLock = self._caller.safe_call_with_result(_make_lock(), self._timeout_sec)

    def _check_closed(self):
        """Raise ``issues.Error`` if ``close()`` has already been called."""
        if self._closed:
            raise issues.Error(f"CoordinationLockSync {self._name} already closed")

    def _call(self, coro, timeout: Optional[float]):
        """Run ``coro`` on the shared loop; a falsy ``timeout`` means the default."""
        return self._caller.safe_call_with_result(coro, timeout or self._timeout_sec)

    def __enter__(self):
        self.acquire(timeout=self._timeout_sec)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            self.release(timeout=self._timeout_sec)
        except Exception:
            # Releasing is best effort and must not replace an exception from
            # the ``with`` body, but a failure should still be visible.
            import logging

            logging.getLogger(__name__).warning(
                "Failed to release coordination lock %s", self._name, exc_info=True
            )

    def acquire(self, timeout: Optional[float] = None):
        """Acquire the lock; returns what the async ``acquire()`` returns."""
        self._check_closed()
        return self._call(self._async_lock.acquire(), timeout)

    def release(self, timeout: Optional[float] = None):
        """Release the lock; does nothing once closed."""
        if self._closed:
            return
        return self._call(self._async_lock.release(), timeout)

    def create(self, init_limit: int, init_data: bytes, timeout: Optional[float] = None):
        """Create the semaphore with the given limit and data."""
        self._check_closed()
        return self._call(self._async_lock.create(init_limit, init_data), timeout)

    def delete(self, timeout: Optional[float] = None):
        """Delete the semaphore."""
        self._check_closed()
        return self._call(self._async_lock.delete(), timeout)

    def describe(self, timeout: Optional[float] = None):
        """Describe the semaphore."""
        self._check_closed()
        return self._call(self._async_lock.describe(), timeout)

    def update(self, new_data: bytes, timeout: Optional[float] = None):
        """Replace the data attached to the semaphore."""
        self._check_closed()
        return self._call(self._async_lock.update(new_data), timeout)

    def close(self, timeout: Optional[float] = None):
        """Release the lock and close the async lock; idempotent.

        The async lock is closed and this object is marked closed even when the
        release step fails; the failure is then propagated to the caller.
        """
        if self._closed:
            return

        try:
            self._call(self._async_lock.release(), timeout)
        finally:
            try:
                self._call(self._async_lock.close(True), timeout)
            finally:
                self._closed = True