diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml deleted file mode 100644 index d69d58d..0000000 --- a/.github/workflows/pypi.yml +++ /dev/null @@ -1,24 +0,0 @@ -name: pypi -on: - release: - types: - - created -jobs: - publish: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: '3.10' - - name: Install and configure Poetry - run: | - pip install -U pip poetry - poetry config virtualenvs.create false - - name: Build dists - run: make build - - name: Pypi Publish - uses: pypa/gh-action-pypi-publish@master - with: - user: __token__ - password: ${{ secrets.pypi_password }} diff --git a/asynch/connection.py b/asynch/connection.py index 27383dc..b301e0a 100644 --- a/asynch/connection.py +++ b/asynch/connection.py @@ -20,11 +20,14 @@ def __init__( cursor_cls=Cursor, echo: bool = False, stack_track: bool = False, + pre_ping: bool = True, **kwargs, ): if dsn: config = parse_dsn(dsn) - self._connection = ProtoConnection(**config, stack_track=stack_track, **kwargs) + self._connection = ProtoConnection( + **config, stack_track=stack_track, pre_ping=pre_ping, **kwargs + ) user = config.get("user", None) or user password = config.get("password", None) or password host = config.get("host", None) or host @@ -178,32 +181,25 @@ async def ping(self) -> None: msg = f"Ping has failed for {self}" raise ConnectionError(msg) - async def _refresh(self) -> None: - """Refresh the connection. + async def is_live(self) -> bool: + """Checks if the connection is live. - Attempting to ping and if failed, - then trying to connect again. - If the reconnection does not work, - an Exception is propagated. + Attempts to ping and returns True if successful. :raises ConnectionError: 1. refreshing created, i.e., not opened connection 2. refreshing already closed connection - :return: None + :return: True if the connection is alive, otherwise False. 
""" - - if self.status == ConnectionStatus.created: - msg = f"the {self} is not opened to be refreshed" - raise ConnectionError(msg) - if self.status == ConnectionStatus.closed: - msg = f"the {self} is already closed" - raise ConnectionError(msg) + if self.status == ConnectionStatus.created or self.status == ConnectionStatus.closed: + return False try: await self.ping() + return True except ConnectionError: - await self.connect() + return False async def rollback(self): raise NotSupportedError diff --git a/asynch/cursors.py b/asynch/cursors.py index 9b041cd..b28a290 100644 --- a/asynch/cursors.py +++ b/asynch/cursors.py @@ -248,6 +248,9 @@ def _prepare(self, context=None): "query_id": self._query_id, } + if "columnar" in execution_options: + execute_kwargs["columnar"] = execution_options.get("columnar", False) + return execute, execute_kwargs def __aiter__(self): diff --git a/asynch/pool.py b/asynch/pool.py index 7aa1f38..056e6c7 100644 --- a/asynch/pool.py +++ b/asynch/pool.py @@ -2,7 +2,7 @@ import logging from collections import deque from collections.abc import AsyncIterator -from contextlib import asynccontextmanager, suppress +from contextlib import asynccontextmanager from typing import Optional from asynch.connection import Connection @@ -132,21 +132,19 @@ def maxsize(self) -> int: def minsize(self) -> int: return self._minsize - async def _create_connection(self) -> None: + async def _create_connection(self) -> Connection: if self._pool_size == self._maxsize: raise AsynchPoolError(f"{self} is already full") if self._pool_size > self._maxsize: raise AsynchPoolError(f"{self} is overburden") - conn = Connection(**self._connection_kwargs) + conn = Connection(pre_ping=False, **self._connection_kwargs) await conn.connect() + return conn - try: - await conn.ping() - self._free_connections.append(conn) - except ConnectionError as e: - msg = f"failed to create a {conn} for {self}" - raise AsynchPoolError(msg) from e + async def _create_and_release_connection(self) 
-> None: + conn = await self._create_connection() + self._free_connections.append(conn) def _pop_connection(self) -> Connection: if not self._free_connections: @@ -156,8 +154,8 @@ def _pop_connection(self) -> Connection: async def _get_fresh_connection(self) -> Optional[Connection]: while self._free_connections: conn = self._pop_connection() - with suppress(ConnectionError): - await conn._refresh() + logger.debug(f"Testing connection {conn}") + if await conn.is_live(): return conn return None @@ -166,8 +164,8 @@ async def _acquire_connection(self) -> Connection: self._acquired_connections.append(conn) return conn - await self._create_connection() - conn = self._pop_connection() + logger.debug("No free connection in pool. Creating new connection.") + conn = await self._create_connection() self._acquired_connections.append(conn) return conn @@ -175,13 +173,8 @@ async def _release_connection(self, conn: Connection) -> None: if conn not in self._acquired_connections: raise AsynchPoolError(f"the connection {conn} does not belong to {self}") + logger.debug(f"Releasing connection {conn}") self._acquired_connections.remove(conn) - try: - await conn._refresh() - except ConnectionError as e: - msg = f"the {conn} is invalidated" - raise AsynchPoolError(msg) from e - self._free_connections.append(conn) async def _init_connections(self, n: int, *, strict: bool = False) -> None: @@ -199,7 +192,7 @@ async def _init_connections(self, n: int, *, strict: bool = False) -> None: # it is possible that the `_create_connection` may not create `n` connections tasks: list[asyncio.Task] = [ - asyncio.create_task(self._create_connection()) for _ in range(n) + asyncio.create_task(self._create_and_release_connection()) for _ in range(n) ] # that is why possible exceptions from the `_create_connection` are also gathered if strict and any( @@ -226,10 +219,15 @@ async def connection(self) -> AsyncIterator[Connection]: :return: a free connection from the pool :rtype: Connection """ + logger.debug( 
+ f"Acquiring connection from Pool ({len(self._free_connections)} free connections, {len(self._acquired_connections)} acquired connections)" + ) async with self._sem: async with self._lock: conn = await self._acquire_connection() + logger.debug(f"Acquired connection {conn}") + try: yield conn finally: diff --git a/asynch/proto/columns/arraycolumn.py b/asynch/proto/columns/arraycolumn.py index 42bfeae..94b6c67 100644 --- a/asynch/proto/columns/arraycolumn.py +++ b/asynch/proto/columns/arraycolumn.py @@ -1,5 +1,5 @@ +from collections import deque from itertools import chain -from queue import Queue from struct import Struct from .base import Column @@ -82,15 +82,14 @@ async def _write_sizes( self, value, ): - q = Queue() - q.put((self, value, 0)) + q = deque([(self, value, 0)]) cur_depth = 0 offset = 0 nulls_map = [] - while not q.empty(): - column, value, depth = q.get_nowait() + while q: + column, value, depth = q.popleft() if cur_depth != depth: cur_depth = depth @@ -112,7 +111,7 @@ async def _write_sizes( nested_column = column.nested_column if isinstance(nested_column, ArrayColumn): for x in value: - q.put((nested_column, x, cur_depth + 1)) + q.append((nested_column, x, cur_depth + 1)) nulls_map.append(None if x is None else False) async def _write_data( @@ -176,8 +175,7 @@ async def _read( self, size, ): - q = Queue() - q.put((self, size, 0)) + q = deque([(self, size, 0)]) slices_series = [] @@ -196,8 +194,8 @@ async def _read( nested_column = self.nested_column # Read and store info about slices. 
- while not q.empty(): - column, size, depth = q.get_nowait() + while q: + column, size, depth = q.popleft() nested_column = column.nested_column @@ -220,7 +218,7 @@ async def _read( for _i in range(size): offset = await self.size_unpack() nested_column_size = offset - q.put((nested_column, offset - prev_offset, cur_depth + 1)) + q.append((nested_column, offset - prev_offset, cur_depth + 1)) slices.append((prev_offset, offset)) prev_offset = offset diff --git a/asynch/proto/connection.py b/asynch/proto/connection.py index e9b467b..8bcdcb1 100644 --- a/asynch/proto/connection.py +++ b/asynch/proto/connection.py @@ -8,6 +8,7 @@ from urllib.parse import urlparse from asynch.errors import ( + OperationalError, PartiallyConsumedQueryError, ServerException, UnexpectedPacketFromServerError, @@ -94,6 +95,7 @@ def __init__( # nosec:B107 alt_hosts: str = None, stack_track=False, settings_is_important=False, + pre_ping: bool = True, **kwargs, ): self.stack_track = stack_track @@ -119,6 +121,7 @@ def __init__( # nosec:B107 self._lock = asyncio.Lock() self.secure_socket = secure self.verify = verify + self.pre_ping = pre_ping ssl_options = {} if ssl_version is not None: @@ -320,10 +323,12 @@ async def ping(self) -> bool: msg = self.unexpected_packet_message("Pong", packet_type) raise UnexpectedPacketFromServerError(msg) return True - except AttributeError: - logger.debug("The connection %s is not open", self) + except OperationalError as e: + logger.info("The connection %s is not open", self, exc_info=e) + except AttributeError as e: + logger.info("The connection %s is not open", self, exc_info=e) except IndexError as e: - logger.debug( + logger.info( "Ping package smaller than expected or empty. 
" "There may be connection or network problems - " "we believe that the connection is incorrect.", @@ -334,7 +339,7 @@ async def ping(self) -> bool: # because this is a connection loss case if isinstance(e, RuntimeError) and "TCPTransport closed=True" not in str(e): raise e - logger.debug("Socket closed", exc_info=e) + logger.info("Socket closed", exc_info=e) return False async def receive_data(self, raw=False): @@ -442,6 +447,7 @@ async def _receive_packet(self): elif packet_type == ServerPacket.EXCEPTION: packet.exception = await self.receive_exception() + self.is_query_executing = False elif packet.type == ServerPacket.PROGRESS: packet.progress = await self.receive_progress() @@ -577,9 +583,10 @@ async def disconnect(self): async def connect(self): if self.connected: await self.disconnect() - logger.debug("Connecting. Database: %s. User: %s", self.database, self.user) for host, port in self.hosts: - logger.debug("Connecting to %s:%s", host, port) + logger.debug( + "Connecting to %s:%s Database: %s. 
User: %s", host, port, self.database, self.user + ) return await self._init_connection(host, port) async def execute( @@ -758,9 +765,10 @@ async def force_connect(self): if not self.connected: await self.connect() - elif not await self.ping(): - logger.info("Connection was closed, reconnecting.") - await self.connect() + elif self.pre_ping: + if not await self.ping(): + logger.info("Connection was closed, reconnecting.") + await self.connect() async def process_ordinary_query( self, @@ -820,30 +828,49 @@ async def process_insert_query( await self.send_query(query_without_data, query_id=query_id) await self.send_external_tables(external_tables, types_check=types_check) - sample_block = await self.receive_sample_block() + sample_block = await self._receive_sample_block() if sample_block: rv = await self.send_data( sample_block, data, types_check=types_check, columnar=columnar ) - packet = await self._receive_packet() - if packet.exception: - raise packet.exception + await self._receive_end_of_stream() + return rv - async def receive_sample_block(self): + async def _receive_sample_block(self): while True: packet = await self._receive_packet() if packet.type == ServerPacket.DATA: return packet.block - elif packet.type == ServerPacket.EXCEPTION: raise packet.exception elif packet.type == ServerPacket.LOG: - self.log_block(packet.block) + pass elif packet.type == ServerPacket.TABLE_COLUMNS: pass + else: + message = self.unexpected_packet_message( + "Data, Exception or TableColumns", packet.type + ) + raise UnexpectedPacketFromServerError(message) + async def _receive_end_of_stream(self): + while True: + packet = await self._receive_packet() + + if packet.type == ServerPacket.END_OF_STREAM: + return + elif packet.type == ServerPacket.EXCEPTION: + raise packet.exception + elif packet.type == ServerPacket.LOG: + pass + elif packet.type == ServerPacket.PROFILE_INFO: + pass + elif packet.type == ServerPacket.PROFILE_EVENTS: + pass + elif packet.type == ServerPacket.PROGRESS: 
+ pass else: message = self.unexpected_packet_message( "Data, Exception or TableColumns", packet.type diff --git a/asynch/proto/constants.py b/asynch/proto/constants.py index 24232fb..f147a1a 100644 --- a/asynch/proto/constants.py +++ b/asynch/proto/constants.py @@ -31,6 +31,8 @@ DBMS_MIN_PROTOCOL_VERSION_WITH_INITIAL_QUERY_START_TIME = 54449 DBMS_MIN_PROTOCOL_VERSION_WITH_INCREMENTAL_PROFILE_EVENTS = 54451 DBMS_MIN_REVISION_WITH_PARALLEL_REPLICAS = 54453 +DBMS_MIN_REVISION_WITH_CUSTOM_SERIALIZATION = 54454 +DBMS_MIN_PROTOCOL_VERSION_WITH_PROFILE_EVENTS_IN_INSERT = 54456 # Timeouts DBMS_DEFAULT_CONNECT_TIMEOUT_SEC = 10 @@ -46,7 +48,7 @@ CLIENT_VERSION_MAJOR = 20 CLIENT_VERSION_MINOR = 10 CLIENT_VERSION_PATCH = 2 -CLIENT_REVISION = 54453 +CLIENT_REVISION = 54456 BUFFER_SIZE = 1048576 diff --git a/asynch/proto/context.py b/asynch/proto/context.py index ff70323..5a2f0ae 100644 --- a/asynch/proto/context.py +++ b/asynch/proto/context.py @@ -1,5 +1,6 @@ from typing import TYPE_CHECKING, Optional +from asynch.errors import ServerException from asynch.proto.result import QueryInfo if TYPE_CHECKING: @@ -55,7 +56,9 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_val, exc_tb): if exc_type: - if issubclass(exc_type, (Exception, KeyboardInterrupt)): + if not issubclass(exc_type, ServerException) and issubclass( + exc_type, (Exception, KeyboardInterrupt) + ): await self._connection.disconnect() raise exc_val self._connection.track_current_database(self._query) diff --git a/asynch/proto/models/enums.py b/asynch/proto/models/enums.py index 7c9cf8c..957a8a7 100644 --- a/asynch/proto/models/enums.py +++ b/asynch/proto/models/enums.py @@ -12,6 +12,9 @@ class ConnectionStatus(str, Enum): opened = "opened" closed = "closed" + def __str__(self): + return self.value + class CursorStatus(str, Enum): ready = "ready" @@ -19,6 +22,9 @@ class CursorStatus(str, Enum): finished = "finished" closed = "closed" + def __str__(self): + return self.value + class 
PoolStatus(str, Enum): created = "created" @@ -29,3 +35,6 @@ class PoolStatus(str, Enum): class ClickhouseScheme(str, Enum): clickhouse = "clickhouse" clickhouses = "clickhouses" + + def __str__(self): + return self.value diff --git a/asynch/proto/streams/block.py b/asynch/proto/streams/block.py index af21411..11dddf5 100644 --- a/asynch/proto/streams/block.py +++ b/asynch/proto/streams/block.py @@ -1,3 +1,4 @@ +from asynch.errors import ClickHouseException from asynch.proto import constants from asynch.proto.block import BaseBlock, BlockInfo, ColumnOrientedBlock from asynch.proto.columns import read_column, write_column @@ -30,6 +31,9 @@ async def write(self, block: BaseBlock): await self.writer.write_str( col_type, ) + if revision >= constants.DBMS_MIN_REVISION_WITH_CUSTOM_SERIALIZATION: + # We always write the data without custom (e.g. sparse) serialization, so the flag is 0. + await self.writer.write_uint8(0) if n_columns: try: @@ -78,6 +82,14 @@ async def read(self): names.append(column_name) types.append(column_type) + if revision >= constants.DBMS_MIN_REVISION_WITH_CUSTOM_SERIALIZATION: + # Only with revision 54465 will the server actually send custom serialization. + has_custom_serialization = bool(await self.reader.read_uint8()) + if has_custom_serialization: + raise ClickHouseException( + f"Custom serialization for column {column_name} not supported." 
+ ) + if n_rows: column = await read_column( self.reader, diff --git a/asynch/proto/streams/buffered.py b/asynch/proto/streams/buffered.py index e435773..79d5f3f 100644 --- a/asynch/proto/streams/buffered.py +++ b/asynch/proto/streams/buffered.py @@ -61,6 +61,9 @@ async def write_fixed_strings(self, data, length): async def close(self) -> None: if not self.writer: return + if self.writer.is_closing(): + return + self.writer.close() await self.writer.wait_closed() diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 6c1dddc..dd87e63 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -1,5 +1,6 @@ import pytest +from asynch import Connection from asynch.errors import ServerException from asynch.pool import Pool @@ -11,3 +12,16 @@ async def test_database_exists(config): async with conn.cursor() as cursor: with pytest.raises(ServerException): await cursor.execute("create database test") + + +@pytest.mark.asyncio +async def test_connection_does_not_close_after_exception(): + async with Connection() as conn: + async with conn.cursor() as cur: + with pytest.raises(ServerException): + await cur.execute("foo") + + assert conn._connection.connected is True + assert conn.opened is True + + await cur.execute("select 1") diff --git a/tests/test_proto/test_proto_connection.py b/tests/test_proto/test_proto_connection.py index c620af8..6201987 100644 --- a/tests/test_proto/test_proto_connection.py +++ b/tests/test_proto/test_proto_connection.py @@ -103,6 +103,17 @@ async def test_execute_with_missing_arg(proto_conn: ProtoConnection): await proto_conn.execute(query, args={"foo": 1}) +@pytest.mark.asyncio +async def test_large_insert(proto_conn: ProtoConnection): + data = [(1,)] * 10_000 + async with create_table(proto_conn, "a Int64"): + await proto_conn.execute( + "INSERT INTO test.test (a) VALUES", data, settings={"insert_block_size": 1000} + ) + rv = await proto_conn.execute("SELECT * FROM test.test") + assert rv == data + + 
@asynccontextmanager async def create_table(connection, spec): await connection.execute("DROP TABLE IF EXISTS test.test") diff --git a/tests/test_reconnection.py b/tests/test_reconnection.py index be47abf..e8d8d14 100644 --- a/tests/test_reconnection.py +++ b/tests/test_reconnection.py @@ -40,7 +40,7 @@ async def proxy(request): @pytest.fixture() async def proxy_pool(proxy): - async with Pool(minsize=1, maxsize=1, dsn=CONNECTION_DSN.replace("9000", "9001")) as pool: + async with Pool(minsize=1, maxsize=2, dsn=CONNECTION_DSN.replace("9000", "9001")) as pool: yield pool @@ -66,6 +66,30 @@ async def test_close_disconnected_connection(proxy_pool): await asyncio.sleep(TIMEOUT * 2) +@pytest.mark.asyncio +async def test_connection_reuse(proxy_pool): + async def execute_sleep(): + async with proxy_pool.connection() as c: + async with c.cursor() as cursor: + await cursor.execute("SELECT sleep(0.1)") + + await asyncio.gather(execute_sleep(), execute_sleep()) + + # There are two live connections in the pool. + assert proxy_pool.free_connections == 2 + + logger.info(f"Killing {proxy_pool._free_connections[0]}") + await proxy_pool._free_connections[0]._connection.writer.close() + + async with proxy_pool.connection() as c: + async with c.cursor() as cursor: + await cursor.execute("SELECT 1") + + # The first connection was not live anymore and was closed. The second connection was reused. + # There is now only one connection in the pool. + assert proxy_pool.free_connections == 1 + + async def reader_to_writer(name: str, graceful: bool, reader: StreamReader, writer: StreamWriter): while True: try: