diff --git a/docs/examples.rst b/docs/examples.rst index 8d7cbf3c..686285ac 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -5,4 +5,5 @@ Examples :maxdepth: 3 examples/basic_example - examples/authentication \ No newline at end of file + examples/authentication + examples/coordination \ No newline at end of file diff --git a/docs/examples/coordination.rst b/docs/examples/coordination.rst new file mode 100644 index 00000000..aa5d7695 --- /dev/null +++ b/docs/examples/coordination.rst @@ -0,0 +1,99 @@ +Coordination Service +=================== + +.. warning:: + Coordination Service API is experimental and may contain bugs and may change in future releases. + +All examples in this section are parts of `coordination example `_. + + +Create node +----------- + +.. code-block:: python + + driver.coordination_client.create_node("/local/my_node") + +Create node with config +----------------------- + +.. code-block:: python + + from ydb import NodeConfig, ConsistencyMode, RateLimiterCountersMode + + config = NodeConfig( + attach_consistency_mode=ConsistencyMode.STRICT, + read_consistency_mode=ConsistencyMode.STRICT, + rate_limiter_counters_mode=RateLimiterCountersMode.AGGREGATED, + self_check_period_millis=1000, + session_grace_period_millis=10000 + ) + + driver.coordination_client.create_node("/local/my_node", config) + +Describe node +------------- + +.. code-block:: python + + config = driver.coordination_client.describe_node("/local/my_node") + +Alter node +---------- + +.. code-block:: python + + new_config = NodeConfig( + attach_consistency_mode=ConsistencyMode.RELAXED, + read_consistency_mode=ConsistencyMode.RELAXED, + rate_limiter_counters_mode=RateLimiterCountersMode.DETAILED, + self_check_period_millis=2000, + session_grace_period_millis=15000 + ) + driver.coordination_client.alter_node("/local/my_node", new_config) + +Delete node +----------- + +.. code-block:: python + + driver.coordination_client.delete_node("/local/my_node") + +Create session +-------------- + +.. code-block:: python + + with driver.coordination_client.session("/local/my_node") as session: + pass + +Use semaphore manually +---------------------- + +.. code-block:: python + + with driver.coordination_client.session("/local/my_node") as session: + semaphore = session.semaphore("my_semaphore", limit=2) # limit is optional, default is 1 + semaphore.acquire(count=2) # count is optional, default is 1 + try: + pass + finally: + semaphore.release() + +Use semaphore with context manager +---------------------------------- + +.. code-block:: python + + with driver.coordination_client.session("/local/my_node") as session: + with session.semaphore("my_semaphore"): + pass + +Async usage +----------- + +.. code-block:: python + + async with driver.coordination_client.session("/local/my_node") as session: + async with session.semaphore("my_semaphore"): + pass diff --git a/examples/coordination/__main__.py b/examples/coordination/__main__.py new file mode 100644 index 00000000..fa7a4983 --- /dev/null +++ b/examples/coordination/__main__.py @@ -0,0 +1,40 @@ +import argparse +import asyncio +from example import run as run_sync +from example_async import run as run_async +import logging + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + formatter_class=argparse.RawDescriptionHelpFormatter, + description="""\033[92mYDB coordination example.\x1b[0m\n""", + ) + parser.add_argument("-e", "--endpoint", help="Endpoint url to use", default="grpc://localhost:2136") + parser.add_argument("-d", "--database", help="Name of the database to use", default="/local") + parser.add_argument("-v", "--verbose", default=False, action="store_true") + parser.add_argument("-m", "--mode", default="sync", help="Mode of example: sync or async") + + args = parser.parse_args() + + if args.verbose: + logger = logging.getLogger("ydb") + logger.setLevel(logging.INFO) + logger.addHandler(logging.StreamHandler()) + + if args.mode == "sync": + print("Running sync example") + run_sync( + args.endpoint, + args.database, + ) + elif args.mode == "async": + print("Running async example") + asyncio.run( + run_async( + args.endpoint, + args.database, + ) + ) + else: + raise ValueError(f"Unsupported mode: {args.mode}, use one of sync|async") diff --git a/examples/coordination/example.py b/examples/coordination/example.py new file mode 100644 index 00000000..0b817ff3 --- /dev/null +++ b/examples/coordination/example.py @@ -0,0 +1,57 @@ +import time +import threading +import ydb + +NODE_PATH = "/local/node_name1" +SEMAPHORE_NAME = "semaphore" + + +def linear_workload(client, text): + session = client.session(NODE_PATH) + semaphore = session.semaphore(SEMAPHORE_NAME) + for i in range(3): + semaphore.acquire() + for j in range(3): + print(f"{text} iteration {i}-{j}") + time.sleep(0.1) + semaphore.release() + time.sleep(0.05) + session.close() + + +def context_manager_workload(client, text): + with client.session(NODE_PATH) as session: + for i in range(3): + with session.semaphore(SEMAPHORE_NAME): + for j in range(3): + print(f"{text} iteration {i}-{j}") + time.sleep(0.1) + time.sleep(0.05) + + +def run(endpoint, database): + with ydb.Driver( + endpoint=endpoint, + database=database, + credentials=ydb.credentials_from_env_variables(), + root_certificates=ydb.load_ydb_root_certificate(), + ) as driver: + driver.wait(timeout=5, fail_fast=True) + + driver.coordination_client.create_node(NODE_PATH) + + threads = [] + + for i in range(4): + worker_name = f"worker {i+1}" + if i < 2: + thread = threading.Thread(target=linear_workload, args=(driver.coordination_client, worker_name)) + else: + thread = threading.Thread( + target=context_manager_workload, args=(driver.coordination_client, worker_name) + ) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() diff --git a/examples/coordination/example_async.py b/examples/coordination/example_async.py new file mode 100644 index 00000000..1990c6dd --- /dev/null +++ b/examples/coordination/example_async.py @@ -0,0 +1,47 @@ +import asyncio +import ydb + +NODE_PATH = "/local/node_name1" +SEMAPHORE_NAME = "semaphore" + + +async def linear_workload(client, text): + session = client.session(NODE_PATH) + semaphore = session.semaphore(SEMAPHORE_NAME) + for i in range(3): + await semaphore.acquire() + for j in range(3): + print(f"{text} iteration {i}-{j}") + await asyncio.sleep(0.1) + await semaphore.release() + await asyncio.sleep(0.05) + await session.close() + + +async def context_manager_workload(client, text): + async with client.session(NODE_PATH) as session: + for i in range(3): + async with session.semaphore(SEMAPHORE_NAME): + for j in range(3): + print(f"{text} iteration {i}-{j}") + await asyncio.sleep(0.1) + await asyncio.sleep(0.05) + + +async def run(endpoint, database): + async with ydb.aio.Driver( + endpoint=endpoint, + database=database, + credentials=ydb.credentials_from_env_variables(), + root_certificates=ydb.load_ydb_root_certificate(), + ) as driver: + await driver.wait(timeout=5, fail_fast=True) + + await driver.coordination_client.create_node(NODE_PATH) + + await asyncio.gather( + linear_workload(driver.coordination_client, "worker 1"), + linear_workload(driver.coordination_client, "worker 2"), + context_manager_workload(driver.coordination_client, "worker 3"), + context_manager_workload(driver.coordination_client, "worker 4"), + ) diff --git a/tests/coordination/test_coordination_client.py b/tests/coordination/test_coordination_client.py index 460c1139..58014452 100644 --- a/tests/coordination/test_coordination_client.py +++ b/tests/coordination/test_coordination_client.py @@ -1,7 +1,11 @@ +import asyncio +import threading + import pytest import ydb from ydb.aio.coordination import CoordinationClient as AioCoordinationClient +from ydb import StatusCode from ydb.coordination import ( NodeConfig, @@ -11,27 +15,63 @@ ) -class TestCoordination: - def test_coordination_node_lifecycle(self, driver_sync: ydb.Driver): - client = CoordinationClient(driver_sync) - node_path = "/local/test_node_lifecycle" +@pytest.fixture +def sync_coordination_node(driver_sync): + client = CoordinationClient(driver_sync) + node_path = "/local/test_node" - try: - client.delete_node(node_path) - except ydb.SchemeError: - pass + try: + client.delete_node(node_path) + except ydb.SchemeError: + pass - with pytest.raises(ydb.SchemeError): - client.describe_node(node_path) + config = NodeConfig( + session_grace_period_millis=1000, + attach_consistency_mode=ConsistencyMode.STRICT, + read_consistency_mode=ConsistencyMode.STRICT, + rate_limiter_counters_mode=RateLimiterCountersMode.UNSET, + self_check_period_millis=0, + ) + client.create_node(node_path, config) - initial_config = NodeConfig( - session_grace_period_millis=1000, - attach_consistency_mode=ConsistencyMode.STRICT, - read_consistency_mode=ConsistencyMode.STRICT, - rate_limiter_counters_mode=RateLimiterCountersMode.UNSET, - self_check_period_millis=0, - ) - client.create_node(node_path, initial_config) + yield client, node_path, config + + try: + client.delete_node(node_path) + except ydb.SchemeError: + pass + + +@pytest.fixture +async def async_coordination_node(aio_connection): + client = AioCoordinationClient(aio_connection) + node_path = "/local/test_node" + + try: + await client.delete_node(node_path) + except ydb.SchemeError: + pass + + config = NodeConfig( + session_grace_period_millis=1000, + attach_consistency_mode=ConsistencyMode.STRICT, + read_consistency_mode=ConsistencyMode.STRICT, + rate_limiter_counters_mode=RateLimiterCountersMode.UNSET, + self_check_period_millis=0, + ) + await client.create_node(node_path, config) + + yield client, node_path, config + + try: + await client.delete_node(node_path) + except ydb.SchemeError: + pass + + +class TestCoordination: + def test_coordination_node_lifecycle(self, sync_coordination_node): + client, node_path, initial_config = sync_coordination_node node_conf = client.describe_node(node_path) assert node_conf == initial_config @@ -53,26 +93,8 @@ def test_coordination_node_lifecycle(self, driver_sync: ydb.Driver): with pytest.raises(ydb.SchemeError): client.describe_node(node_path) - async def test_coordination_node_lifecycle_async(self, aio_connection): - client = AioCoordinationClient(aio_connection) - node_path = "/local/test_node_lifecycle" - - try: - await client.delete_node(node_path) - except ydb.SchemeError: - pass - - with pytest.raises(ydb.SchemeError): - await client.describe_node(node_path) - - initial_config = NodeConfig( - session_grace_period_millis=1000, - attach_consistency_mode=ConsistencyMode.STRICT, - read_consistency_mode=ConsistencyMode.STRICT, - rate_limiter_counters_mode=RateLimiterCountersMode.UNSET, - self_check_period_millis=0, - ) - await client.create_node(node_path, initial_config) + async def test_coordination_node_lifecycle_async(self, async_coordination_node): + client, node_path, initial_config = async_coordination_node node_conf = await client.describe_node(node_path) assert node_conf == initial_config @@ -93,3 +115,121 @@ async def test_coordination_node_lifecycle_async(self, aio_connection): with pytest.raises(ydb.SchemeError): await client.describe_node(node_path) + + async def test_coordination_lock_describe_full_async(self, async_coordination_node): + client, node_path, _ = async_coordination_node + + async with client.session(node_path) as node: + lock = node.semaphore("test_lock") + + desc = await lock.describe() + assert desc.status == StatusCode.NOT_FOUND + + async with lock: + pass + + desc = await lock.describe() + assert desc.data == b"" + + await lock.update(new_data=b"world") + + desc2 = await lock.describe() + assert desc2.data == b"world" + + def test_coordination_lock_describe_full(self, sync_coordination_node): + client, node_path, _ = sync_coordination_node + + with client.session(node_path) as node: + lock = node.semaphore("test_lock") + + desc = lock.describe() + assert desc.status == StatusCode.NOT_FOUND + + with lock: + pass + + desc = lock.describe() + assert desc.data == b"" + + lock.update(new_data=b"world") + + desc2 = lock.describe() + assert desc2.data == b"world" + + async def test_coordination_lock_racing_async(self, async_coordination_node): + client, node_path, _ = async_coordination_node + timeout = 5 + + async with client.session(node_path) as node: + lock2_started = asyncio.Event() + lock2_acquired = asyncio.Event() + lock2_release = asyncio.Event() + + async def second_lock_task(): + lock2_started.set() + async with node.semaphore("test_lock"): + lock2_acquired.set() + await lock2_release.wait() + + async with node.semaphore("test_lock"): + t2 = asyncio.create_task(second_lock_task()) + await asyncio.wait_for(lock2_started.wait(), timeout=timeout) + + await asyncio.wait_for(lock2_acquired.wait(), timeout=timeout) + lock2_release.set() + await asyncio.wait_for(t2, timeout=timeout) + + def test_coordination_lock_racing(self, sync_coordination_node): + client, node_path, _ = sync_coordination_node + timeout = 5 + + with client.session(node_path) as node: + lock2_started = threading.Event() + lock2_acquired = threading.Event() + lock2_release = threading.Event() + + def second_lock_task(): + lock2_started.set() + with node.semaphore("test_lock"): + lock2_acquired.set() + lock2_release.wait(timeout) + + with node.semaphore("test_lock"): + t2 = threading.Thread(target=second_lock_task) + t2.start() + + assert lock2_started.wait(timeout) + + assert lock2_acquired.wait(timeout) + lock2_release.set() + t2.join(timeout) + + async def test_coordination_reconnect_async(self, async_coordination_node): + client, node_path, _ = async_coordination_node + + async with client.session(node_path) as node: + lock = node.semaphore("test_lock") + + async with lock: + pass + + await node._reconnector._stream.close() + + async with lock: + pass + + async def test_same_lock_cannot_be_acquired_twice(self, async_coordination_node): + client, node_path, _ = async_coordination_node + + async with client.session(node_path) as node: + lock1 = node.semaphore("lock1") + lock1_1 = node.semaphore("lock1") + + await lock1.acquire() + + acquire_task = asyncio.create_task(lock1_1.acquire()) + + assert not acquire_task.done() + + await lock1.release() + await asyncio.wait_for(acquire_task, timeout=5) diff --git a/ydb/_apis.py b/ydb/_apis.py index 97f64b90..595550b2 100644 --- a/ydb/_apis.py +++ b/ydb/_apis.py @@ -143,9 +143,9 @@ class QueryService(object): class CoordinationService(object): Stub = ydb_coordination_v1_pb2_grpc.CoordinationServiceStub - - Session = "Session" CreateNode = "CreateNode" AlterNode = "AlterNode" DropNode = "DropNode" DescribeNode = "DescribeNode" + SessionRequest = "SessionRequest" + Session = "Session" diff --git a/ydb/_grpc/grpcwrapper/common_utils.py b/ydb/_grpc/grpcwrapper/common_utils.py index 0fb960d6..cf91b9c9 100644 --- a/ydb/_grpc/grpcwrapper/common_utils.py +++ b/ydb/_grpc/grpcwrapper/common_utils.py @@ -220,7 +220,7 @@ async def _start_sync_driver(self, driver: Driver, stub, method): self._stream_call = stream_call self.from_server_grpc = SyncToAsyncIterator(stream_call.__iter__(), self._wait_executor) - async def receive(self, timeout: Optional[int] = None) -> Any: + async def receive(self, timeout: Optional[int] = None, is_coordination_calls=False) -> Any: # todo handle grpc exceptions and convert it to internal exceptions try: if timeout is None: @@ -235,7 +235,8 @@ async def get_response(): except (grpc.RpcError, grpc.aio.AioRpcError) as e: raise connection._rpc_error_handler(self._connection_state, e) - issues._process_response(grpc_message) + if not is_coordination_calls: + issues._process_response(grpc_message) if self._connection_state != "has_received_messages": self._connection_state = "has_received_messages" diff --git a/ydb/_grpc/grpcwrapper/ydb_coordination.py b/ydb/_grpc/grpcwrapper/ydb_coordination.py index 176e4e02..8794b570 100644 --- a/ydb/_grpc/grpcwrapper/ydb_coordination.py +++ b/ydb/_grpc/grpcwrapper/ydb_coordination.py @@ -16,7 +16,7 @@ class CreateNodeRequest(IToProto): path: str config: typing.Optional[NodeConfig] - def to_proto(self) -> ydb_coordination_pb2.CreateNodeRequest: + def to_proto(self) -> "ydb_coordination_pb2.CreateNodeRequest": cfg_proto = self.config.to_proto() if self.config else None return ydb_coordination_pb2.CreateNodeRequest( path=self.path, @@ -29,7 +29,7 @@ class AlterNodeRequest(IToProto): path: str config: NodeConfig - def to_proto(self) -> ydb_coordination_pb2.AlterNodeRequest: + def to_proto(self) -> "ydb_coordination_pb2.AlterNodeRequest": cfg_proto = self.config.to_proto() if self.config else None return ydb_coordination_pb2.AlterNodeRequest( path=self.path, @@ -41,7 +41,7 @@ def to_proto(self) -> ydb_coordination_pb2.AlterNodeRequest: class DescribeNodeRequest(IToProto): path: str - def to_proto(self) -> ydb_coordination_pb2.DescribeNodeRequest: + def to_proto(self) -> "ydb_coordination_pb2.DescribeNodeRequest": return ydb_coordination_pb2.DescribeNodeRequest( path=self.path, ) @@ -51,7 +51,174 @@ def to_proto(self) -> ydb_coordination_pb2.DescribeNodeRequest: class DropNodeRequest(IToProto): path: str - def to_proto(self) -> ydb_coordination_pb2.DropNodeRequest: + def to_proto(self) -> "ydb_coordination_pb2.DropNodeRequest": return ydb_coordination_pb2.DropNodeRequest( path=self.path, ) + + +@dataclass +class SessionStart(IToProto): + path: str + timeout_millis: int + description: str = "" + session_id: int = 0 + seq_no: int = 0 + protection_key: bytes = b"" + + def to_proto(self) -> "ydb_coordination_pb2.SessionRequest": + return ydb_coordination_pb2.SessionRequest( + session_start=ydb_coordination_pb2.SessionRequest.SessionStart( + path=self.path, + session_id=self.session_id, + timeout_millis=self.timeout_millis, + description=self.description, + seq_no=self.seq_no, + protection_key=self.protection_key, + ) + ) + + +@dataclass +class SessionStop(IToProto): + def to_proto(self) -> "ydb_coordination_pb2.SessionRequest": + return ydb_coordination_pb2.SessionRequest(session_stop=ydb_coordination_pb2.SessionRequest.SessionStop()) + + +@dataclass +class Ping(IToProto): + opaque: int = 0 + + def to_proto(self) -> "ydb_coordination_pb2.SessionRequest": + return ydb_coordination_pb2.SessionRequest( + ping=ydb_coordination_pb2.SessionRequest.PingPong(opaque=self.opaque) + ) + + +@dataclass +class CreateSemaphore(IToProto): + name: str + req_id: int + limit: int + data: bytes = b"" + + def to_proto(self) -> "ydb_coordination_pb2.SessionRequest": + return ydb_coordination_pb2.SessionRequest( + create_semaphore=ydb_coordination_pb2.SessionRequest.CreateSemaphore( + req_id=self.req_id, name=self.name, limit=self.limit, data=self.data + ) + ) + + +@dataclass +class AcquireSemaphore(IToProto): + name: str + req_id: int + count: int = 1 + timeout_millis: int = 0 + data: bytes = b"" + ephemeral: bool = False + + def to_proto(self) -> "ydb_coordination_pb2.SessionRequest": + return ydb_coordination_pb2.SessionRequest( + acquire_semaphore=ydb_coordination_pb2.SessionRequest.AcquireSemaphore( + req_id=self.req_id, + name=self.name, + timeout_millis=self.timeout_millis, + count=self.count, + data=self.data, + ephemeral=self.ephemeral, + ) + ) + + +@dataclass +class ReleaseSemaphore(IToProto): + name: str + req_id: int + + def to_proto(self) -> "ydb_coordination_pb2.SessionRequest": + return ydb_coordination_pb2.SessionRequest( + release_semaphore=ydb_coordination_pb2.SessionRequest.ReleaseSemaphore(req_id=self.req_id, name=self.name) + ) + + +@dataclass +class DescribeSemaphore(IToProto): + include_owners: bool + include_waiters: bool + name: str + req_id: int + watch_data: bool + watch_owners: bool + + def to_proto(self) -> "ydb_coordination_pb2.SessionRequest": + return ydb_coordination_pb2.SessionRequest( + describe_semaphore=ydb_coordination_pb2.SessionRequest.DescribeSemaphore( + include_owners=self.include_owners, + include_waiters=self.include_waiters, + name=self.name, + req_id=self.req_id, + watch_data=self.watch_data, + watch_owners=self.watch_owners, + ) + ) + + +@dataclass +class UpdateSemaphore(IToProto): + name: str + req_id: int + data: bytes + + def to_proto(self) -> "ydb_coordination_pb2.SessionRequest": + return ydb_coordination_pb2.SessionRequest( + update_semaphore=ydb_coordination_pb2.SessionRequest.UpdateSemaphore( + req_id=self.req_id, name=self.name, data=self.data + ) + ) + + +@dataclass +class DeleteSemaphore(IToProto): + name: str + req_id: int + force: bool = False + + def to_proto(self) -> "ydb_coordination_pb2.SessionRequest": + return ydb_coordination_pb2.SessionRequest( + delete_semaphore=ydb_coordination_pb2.SessionRequest.DeleteSemaphore( + req_id=self.req_id, name=self.name, force=self.force + ) + ) + + +@dataclass +class FromServer: + raw: "ydb_coordination_pb2.SessionResponse" + + @staticmethod + def from_proto(resp: "ydb_coordination_pb2.SessionResponse") -> "FromServer": + return FromServer(raw=resp) + + def __getattr__(self, name: str): + return getattr(self.raw, name) + + @property + def session_started(self) -> typing.Optional["ydb_coordination_pb2.SessionResponse.SessionStarted"]: + s = self.raw.session_started + return s if s.session_id else None + + @property + def opaque(self) -> typing.Optional[int]: + if self.raw.HasField("ping"): + return self.raw.ping.opaque + return None + + @property + def acquire_semaphore_result(self): + return self.raw.acquire_semaphore_result if self.raw.HasField("acquire_semaphore_result") else None + + @property + def create_semaphore_result(self): + return self.raw.create_semaphore_result if self.raw.HasField("create_semaphore_result") else None diff --git a/ydb/_grpc/grpcwrapper/ydb_coordination_public_types.py b/ydb/_grpc/grpcwrapper/ydb_coordination_public_types.py index a3580974..1112cd4b 100644 --- a/ydb/_grpc/grpcwrapper/ydb_coordination_public_types.py +++ b/ydb/_grpc/grpcwrapper/ydb_coordination_public_types.py @@ -2,7 +2,6 @@ from enum import IntEnum import typing - if typing.TYPE_CHECKING: from ..v4.protos import ydb_coordination_pb2 else: @@ -55,3 +54,60 @@ def from_proto(msg: ydb_coordination_pb2.DescribeNodeResponse) -> "NodeConfig": result = ydb_coordination_pb2.DescribeNodeResult() msg.operation.result.Unpack(result) return NodeConfig.from_proto(result.config) + + +@dataclass +class AcquireSemaphoreResult: + req_id: int + acquired: bool + status: int + + @staticmethod + def from_proto(msg: ydb_coordination_pb2.SessionResponse.AcquireSemaphoreResult) -> "AcquireSemaphoreResult": + return AcquireSemaphoreResult( + req_id=msg.req_id, + acquired=msg.acquired, + status=msg.status, + ) + + +@dataclass +class CreateSemaphoreResult: + req_id: int + status: int + + @staticmethod + def from_proto(msg: ydb_coordination_pb2.SessionResponse.CreateSemaphoreResult) -> "CreateSemaphoreResult": + return CreateSemaphoreResult( + req_id=msg.req_id, + status=msg.status, + ) + + +@dataclass +class DescribeLockResult: + req_id: int + status: int + watch_added: bool + count: int + data: bytes + ephemeral: bool + limit: int + name: str + owners: list + waiters: list + + @staticmethod + def from_proto(msg: ydb_coordination_pb2.SessionResponse.DescribeSemaphoreResult) -> "DescribeLockResult": + return DescribeLockResult( + req_id=msg.req_id, + status=msg.status, + watch_added=msg.watch_added, + count=msg.semaphore_description.count, + data=msg.semaphore_description.data, + ephemeral=msg.semaphore_description.ephemeral, + limit=msg.semaphore_description.limit, + name=msg.semaphore_description.name, + owners=msg.semaphore_description.owners, + waiters=msg.semaphore_description.waiters, + ) diff --git a/ydb/aio/__init__.py b/ydb/aio/__init__.py index 4e4192a8..9747666f 100644 --- a/ydb/aio/__init__.py +++ b/ydb/aio/__init__.py @@ -1,5 +1,4 @@ from .driver import Driver # noqa from .table import SessionPool, retry_operation # noqa from .query import QuerySessionPool, QuerySession, QueryTxContext # noqa - -# from .coordination import CoordinationClient # noqa +from .coordination import CoordinationClient # noqa diff --git a/ydb/aio/coordination/client.py b/ydb/aio/coordination/client.py index b36b8950..2efff035 100644 --- a/ydb/aio/coordination/client.py +++ b/ydb/aio/coordination/client.py @@ -8,32 +8,43 @@ ) from ..._grpc.grpcwrapper.ydb_coordination_public_types import NodeConfig from ...coordination.base import BaseCoordinationClient +from .session import CoordinationSession class CoordinationClient(BaseCoordinationClient): async def create_node(self, path: str, config: Optional[NodeConfig] = None, settings=None): + self._log_experimental_api() + return await self._call_create( CreateNodeRequest(path=path, config=config).to_proto(), settings=settings, ) async def describe_node(self, path: str, settings=None) -> NodeConfig: + self._log_experimental_api() + return await self._call_describe( DescribeNodeRequest(path=path).to_proto(), settings=settings, ) async def alter_node(self, path: str, new_config: NodeConfig, settings=None): + self._log_experimental_api() + return await self._call_alter( AlterNodeRequest(path=path, config=new_config).to_proto(), settings=settings, ) async def delete_node(self, path: str, settings=None): + self._log_experimental_api() + return await self._call_delete( DropNodeRequest(path=path).to_proto(), settings=settings, ) - async def lock(self): - raise NotImplementedError("Will be implemented in future release") + def session(self, path: str) -> CoordinationSession: + self._log_experimental_api() + + return CoordinationSession(self._driver, path) diff --git a/ydb/aio/coordination/reconnector.py b/ydb/aio/coordination/reconnector.py new file mode 100644 index 00000000..676ed5b1 --- /dev/null +++ b/ydb/aio/coordination/reconnector.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +import asyncio +import logging +from typing import Dict + +from ... import issues +from ..._grpc.grpcwrapper.common_utils import IToProto +from ..._grpc.grpcwrapper.ydb_coordination import FromServer +from .stream import CoordinationStream + +logger = logging.getLogger(__name__) + + +class CoordinationReconnector: + def __init__(self, driver, node_path: str, timeout_millis: int = 30000): + self._driver = driver + self._node_path = node_path + self._timeout_millis = timeout_millis + self._wait_timeout = timeout_millis / 1000 + + self._stream = None + self._session_id = None + + self._pending_futures: Dict[int, asyncio.Future] = {} + self._pending_requests: Dict[int, IToProto] = {} + + self._send_lock = asyncio.Lock() + self._connection_task = None + self._closed = False + + async def stop(self): + self._closed = True + + if self._connection_task: + self._connection_task.cancel() + try: + await self._connection_task + except asyncio.CancelledError: + pass + + if self._stream: + await self._stream.close() + + for fut in self._pending_futures.values(): + if not fut.done(): + fut.set_exception(asyncio.CancelledError()) + + self._pending_futures.clear() + self._pending_requests.clear() + + async def send_and_wait(self, req: IToProto): + if self._closed: + raise issues.Error("Reconnector closed") + + if self._connection_task is None: + self._connection_task = asyncio.create_task(self._connection_loop()) + + while not self._stream or self._stream._closed: + await asyncio.sleep(0) + + req_id = getattr(req, "req_id") + loop = asyncio.get_running_loop() + fut = loop.create_future() + + self._pending_futures[req_id] = fut + self._pending_requests[req_id] = req + + async with self._send_lock: + await self._stream.send(req) + + return await asyncio.wait_for(fut, self._wait_timeout) + + async def _connection_loop(self): + while not self._closed: + try: + stream = CoordinationStream(self._driver) + await stream.start_session( + self._node_path, + self._timeout_millis, + session_id=self._session_id, + ) + + self._stream = stream + self._session_id = stream.session_id + + for req in self._pending_requests.values(): + await stream.send(req) + + await self._dispatch_loop(stream) + + except asyncio.CancelledError: + return + except Exception as exc: + logger.debug("Coordination stream error: %r", exc) + finally: + if self._stream: + await self._stream.close() + self._stream = None + + async def _dispatch_loop(self, stream): + while not self._closed and self._stream is stream: + resp = await stream.receive(self._wait_timeout) + if not resp: + continue + + fs = FromServer.from_proto(resp) + payload = next( + ( + getattr(fs, name) + for name in ( + "acquire_semaphore_result", + "release_semaphore_result", + "describe_semaphore_result", + "create_semaphore_result", + "update_semaphore_result", + "delete_semaphore_result", + ) + if fs.raw.HasField(name) + ), + None, + ) + + if not payload: + continue + + fut = self._pending_futures.pop(payload.req_id, None) + self._pending_requests.pop(payload.req_id, None) + + if fut and not fut.done(): + fut.set_result(payload) diff --git a/ydb/aio/coordination/semaphore.py b/ydb/aio/coordination/semaphore.py new file mode 100644 index 00000000..3723cd50 --- /dev/null +++ b/ydb/aio/coordination/semaphore.py @@ -0,0 +1,98 @@ +from ... import StatusCode, issues + +from ..._grpc.grpcwrapper.ydb_coordination import ( + AcquireSemaphore, + ReleaseSemaphore, + UpdateSemaphore, + DescribeSemaphore, + CreateSemaphore, +) +from ..._grpc.grpcwrapper.ydb_coordination_public_types import ( + DescribeLockResult, +) + + +class CoordinationSemaphore: + def __init__(self, session, name: str, limit: int): + self._session = session + self._name = name + + self._limit = limit + self._timeout_millis = session._timeout_millis + + async def acquire(self, count: int = 1): + await self._create_if_not_exists() + resp = await self._try_acquire(count) + + if resp.status != StatusCode.SUCCESS: + raise issues.Error(f"Failed to acquire lock {self._name}: {resp.status}") + + return self + + async def release(self): + req = ReleaseSemaphore( + req_id=await self._session.next_req_id(), + name=self._name, + ) + try: + await self._session._reconnector.send_and_wait(req) + except Exception: + pass + + async def describe(self) -> DescribeLockResult: + req = DescribeSemaphore( + req_id=await self._session.next_req_id(), + name=self._name, + include_owners=True, + include_waiters=True, + watch_data=False, + watch_owners=False, + ) + resp = await self._session._reconnector.send_and_wait(req) + return DescribeLockResult.from_proto(resp) + + async def update(self, new_data: bytes) -> None: + req = UpdateSemaphore( + req_id=await self._session.next_req_id(), + name=self._name, + data=new_data, + ) + resp = await self._session._reconnector.send_and_wait(req) + + if resp.status != StatusCode.SUCCESS: + raise issues.Error(f"Failed to update lock {self._name}: {resp.status}") + + async def close(self): + await self.release() + + async def __aenter__(self): + await self.acquire() + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.release() + + async def _try_acquire(self, count: int): + req = AcquireSemaphore( + req_id=await self._session.next_req_id(), + name=self._name, + count=count, + ephemeral=False, + timeout_millis=self._timeout_millis, + ) + return await self._session._reconnector.send_and_wait(req) + + async def _create_if_not_exists(self): + req = CreateSemaphore( + req_id=await self._session.next_req_id(), + name=self._name, + limit=self._limit, + data=b"", + ) + resp = await self._session._reconnector.send_and_wait(req) + + if resp.status not in ( + StatusCode.SUCCESS, + StatusCode.ALREADY_EXISTS, + ): + raise issues.Error(f"Failed to create lock {self._name}: {resp.status}") diff --git a/ydb/aio/coordination/session.py b/ydb/aio/coordination/session.py new file mode 100644 index 00000000..1d6ae4fd --- /dev/null +++ b/ydb/aio/coordination/session.py @@ -0,0 +1,43 @@ +import asyncio + +from .reconnector import CoordinationReconnector +from .semaphore import CoordinationSemaphore + + +class CoordinationSession: + def __init__(self, driver, path: str, timeout_millis: int = 30000): + self._driver = driver + self._path = path + self._timeout_millis = timeout_millis + + self._reconnector = CoordinationReconnector( + driver=driver, + node_path=path, + timeout_millis=timeout_millis, + ) + + self._req_id = 0 + self._req_id_lock = asyncio.Lock() + self._closed = False + + async def next_req_id(self) -> int: + async with self._req_id_lock: + self._req_id += 1 + return self._req_id + + def semaphore(self, name: str, limit: int = 1) -> CoordinationSemaphore: + if self._closed: + raise RuntimeError("CoordinationSession is closed") + return CoordinationSemaphore(self, name, limit) + + async def close(self): + if self._closed: + return + self._closed = True + await self._reconnector.stop() + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.close() diff --git a/ydb/aio/coordination/stream.py b/ydb/aio/coordination/stream.py new file mode 100644 index 00000000..a04280e6 --- /dev/null +++ b/ydb/aio/coordination/stream.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +import asyncio +import logging +from typing import Optional + +from ... import issues, _apis +from ..._grpc.grpcwrapper.common_utils import IToProto, GrpcWrapperAsyncIO +from ..._grpc.grpcwrapper.ydb_coordination import ( + FromServer, + SessionStart, + Ping, +) + +logger = logging.getLogger(__name__) + + +class CoordinationStream: + def __init__(self, driver): + self._driver = driver + self._stream = GrpcWrapperAsyncIO(FromServer.from_proto) + + self._incoming = asyncio.Queue() + self._reader_task: Optional[asyncio.Task] = None + + self._closed = False + self.session_id: Optional[int] = None + + async def start_session( + self, + path: str, + timeout_millis: int, + session_id: Optional[int] = None, + ): + await self._stream.start( + self._driver, + _apis.CoordinationService.Stub, + _apis.CoordinationService.Session, + ) + + self._stream.write( + SessionStart( + path=path, + timeout_millis=timeout_millis, + session_id=int(session_id) if session_id is not None else 0, + ) + ) + + while True: + resp = await self._stream.receive( + timeout=3, + is_coordination_calls=True, + ) + if resp is None: + continue + + fs = FromServer.from_proto(resp) + if fs.session_started: + self.session_id = int(fs.session_started.session_id) + break + + self._reader_task = asyncio.create_task(self._reader_loop()) + + async def _reader_loop(self): + try: + while True: + resp = await self._stream.receive( + timeout=3, + is_coordination_calls=True, + ) + if resp is None: + continue + + fs = FromServer.from_proto(resp) + + if fs.opaque: + try: + self._stream.write(Ping(fs.opaque)) + except Exception: + break + continue + + await self._incoming.put(resp) + + except asyncio.CancelledError: + pass + except Exception as exc: + logger.debug("CoordinationStream reader stopped: %r", exc) + finally: + self._closed = True + await self._incoming.put(None) + + try: + await self._stream.close() + except Exception: + pass + + async def send(self, req: IToProto): + if self._closed: + raise issues.Error("Coordination stream closed") + self._stream.write(req) + + async def receive(self, timeout: Optional[float] = None): + if self._closed: + raise issues.Error("Coordination stream closed") + + if timeout is None: + return await self._incoming.get() + + return await asyncio.wait_for(self._incoming.get(), timeout) + + async def close(self): + if self._closed: + return + + self._closed = True + + if self._reader_task: + self._reader_task.cancel() + try: + await self._reader_task + except asyncio.CancelledError: + pass + self._reader_task = None + + try: + await self._stream.close() + except Exception: + pass diff --git a/ydb/aio/driver.py b/ydb/aio/driver.py index 267997fb..0e95e46a 100644 --- a/ydb/aio/driver.py +++ b/ydb/aio/driver.py @@ -41,6 +41,7 @@ def __init__( **kwargs ): from .. import topic # local import for prevent cycle import error + from . import coordination # local import for prevent cycle import error config = get_config( driver_config, @@ -59,6 +60,7 @@ def __init__( self.scheme_client = scheme.SchemeClient(self) self.table_client = table.TableClient(self, config.table_client_settings) self.topic_client = topic.TopicClientAsyncIO(self, config.topic_client_settings) + self.coordination_client = coordination.CoordinationClient(self) async def stop(self, timeout=10): await self.table_client._stop_pool_if_needed(timeout=timeout) diff --git a/ydb/coordination/__init__.py b/ydb/coordination/__init__.py index fd994c56..b50bfa61 100644 --- a/ydb/coordination/__init__.py +++ b/ydb/coordination/__init__.py @@ -4,13 +4,18 @@ "ConsistencyMode", "RateLimiterCountersMode", "DescribeResult", + "CreateSemaphoreResult", + "DescribeLockResult", ] from .client import CoordinationClient + from .._grpc.grpcwrapper.ydb_coordination_public_types import ( NodeConfig, ConsistencyMode, RateLimiterCountersMode, DescribeResult, + CreateSemaphoreResult, + DescribeLockResult, ) diff --git a/ydb/coordination/base.py b/ydb/coordination/base.py index 0be7cb8f..4dadaf04 100644 --- a/ydb/coordination/base.py +++ b/ydb/coordination/base.py @@ -2,7 +2,6 @@ from .._grpc.grpcwrapper.ydb_coordination_public_types import NodeConfig, DescribeResult import logging - logger = logging.getLogger(__name__) @@ -25,8 +24,8 @@ def wrapper_alter_node(rpc_state, response_pb): class BaseCoordinationClient: def __init__(self, driver): - logger.warning("Experimental API: interface may change in future releases.") self._driver = driver + self._user_warned = False def _call_create(self, request, settings=None): return self._driver( @@ -63,3 +62,10 @@ def _call_delete(self, request, settings=None): wrap_result=wrapper_delete_node, settings=settings, ) + + def _log_experimental_api(self): + if not self._user_warned: + logger.warning( + "Coordination Service API is experimental, may contain bugs and may change in future releases." + ) + self._user_warned = True diff --git a/ydb/coordination/client.py b/ydb/coordination/client.py index 549528d9..f22488d6 100644 --- a/ydb/coordination/client.py +++ b/ydb/coordination/client.py @@ -1,3 +1,4 @@ +import logging from typing import Optional from .._grpc.grpcwrapper.ydb_coordination import ( @@ -8,32 +9,45 @@ ) from .._grpc.grpcwrapper.ydb_coordination_public_types import NodeConfig from .base import BaseCoordinationClient +from .session import CoordinationSession + +logger = logging.getLogger(__name__) class CoordinationClient(BaseCoordinationClient): - def create_node(self, path: str, config: Optional[NodeConfig], settings=None): + def create_node(self, path: str, config: Optional[NodeConfig] = None, settings=None): + self._log_experimental_api() + return self._call_create( CreateNodeRequest(path=path, config=config).to_proto(), settings=settings, ) def describe_node(self, path: str, settings=None) -> NodeConfig: + self._log_experimental_api() + return self._call_describe( DescribeNodeRequest(path=path).to_proto(), settings=settings, ) def alter_node(self, path: str, new_config: NodeConfig, settings=None): + self._log_experimental_api() + return self._call_alter( AlterNodeRequest(path=path, config=new_config).to_proto(), settings=settings, ) def delete_node(self, path: str, settings=None): + self._log_experimental_api() + return self._call_delete( DropNodeRequest(path=path).to_proto(), settings=settings, ) - def lock(self): - raise NotImplementedError("Will be implemented in future release") + def session(self, path: str) -> CoordinationSession: + self._log_experimental_api() + + return CoordinationSession(self, path) diff --git a/ydb/coordination/semaphore.py b/ydb/coordination/semaphore.py new file mode 100644 index 00000000..10e53e2f --- /dev/null +++ b/ydb/coordination/semaphore.py @@ -0,0 +1,73 @@ +from typing import Optional + +from .. import issues +from .._topic_common.common import _get_shared_event_loop, CallFromSyncToAsync +from ..aio.coordination.semaphore import CoordinationSemaphore as CoordinationSemaphoreAio +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .session import CoordinationSession + + +class CoordinationSemaphore: + def __init__(self, session: "CoordinationSession", name: str, limit: int = 1): + self._session = session + self._name = name + self._limit = limit + self._closed = False + self._caller = CallFromSyncToAsync(_get_shared_event_loop()) + self._async_semaphore: CoordinationSemaphoreAio = self._session._async_session.semaphore(name, limit) + + def _check_closed(self): + if self._closed: + raise issues.Error(f"CoordinationSemaphore {self._name} already closed") + + def __enter__(self): + self.acquire() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + try: + self.release() + except Exception: + pass + + def acquire(self, count: int = 1, timeout: Optional[float] = None): + self._check_closed() + return self._caller.safe_call_with_result( + self._async_semaphore.acquire(count), + timeout, + ) + + def release(self, timeout: Optional[float] = None): + if self._closed: + return + return self._caller.safe_call_with_result( + self._async_semaphore.release(), + timeout, + ) + + def describe(self, timeout: Optional[float] = None): + self._check_closed() + return self._caller.safe_call_with_result( + self._async_semaphore.describe(), + timeout, + ) + + def update(self, new_data: bytes, timeout: Optional[float] = None): + self._check_closed() + return self._caller.safe_call_with_result( + self._async_semaphore.update(new_data), + timeout, + ) + + def close(self, timeout: Optional[float] = None): + if self._closed: + return + try: + self._caller.safe_call_with_result( + self._async_semaphore.release(), + timeout, + ) + finally: + self._closed = True diff --git a/ydb/coordination/session.py b/ydb/coordination/session.py new file mode 100644 index 00000000..111625ca --- /dev/null +++ b/ydb/coordination/session.py @@ -0,0 +1,42 @@ +from .._topic_common.common import _get_shared_event_loop, CallFromSyncToAsync +from ..aio.coordination.session import CoordinationSession as CoordinationSessionAio +from .semaphore import CoordinationSemaphore + + +class CoordinationSession: + def __init__(self, client, path: str, timeout_sec: float = 5): + self._client = client + self._path = path + self._timeout_sec = timeout_sec + + self._caller = CallFromSyncToAsync(_get_shared_event_loop()) + self._closed = False + + async def _make_session() -> CoordinationSessionAio: + return CoordinationSessionAio( + client._driver, + path, + ) + + self._async_session: CoordinationSessionAio = self._caller.safe_call_with_result( + _make_session(), + self._timeout_sec, + ) + + def semaphore(self, name: str, limit: int = 1): + return CoordinationSemaphore(self, name, limit) + + def close(self): + if self._closed: + return + self._caller.safe_call_with_result( + self._async_session.close(), + self._timeout_sec, + ) + self._closed = True + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + self.close() diff --git a/ydb/driver.py b/ydb/driver.py index 5c9822ed..cd17c0a2 100644 --- a/ydb/driver.py +++ b/ydb/driver.py @@ -272,6 +272,7 @@ def __init__( :param credentials: A credentials. If not specifed credentials constructed by default. """ from . import topic # local import for prevent cycle import error + from . import coordination # local import for prevent cycle import error driver_config = get_config( driver_config, @@ -289,6 +290,7 @@ def __init__( self.scheme_client = scheme.SchemeClient(self) self.table_client = table.TableClient(self, driver_config.table_client_settings) self.topic_client = topic.TopicClient(self, driver_config.topic_client_settings) + self.coordination_client = coordination.CoordinationClient(self) def stop(self, timeout=10): self.table_client._stop_pool_if_needed(timeout=timeout)