diff --git a/CHANGELOG.md b/CHANGELOG.md index d69236e6..2546f10b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ # Changelog ## [Unreleased] +- Added support for Time and Time64 columns (available in ClickHouse server 25.6+). Closes [#390](https://github.com/xzkostyan/clickhouse-sqlalchemy/issues/390) ## [0.3.2] - 2024-06-12 ### Added diff --git a/clickhouse_sqlalchemy/alembic/comparators.py b/clickhouse_sqlalchemy/alembic/comparators.py index 6743bea9..492165d4 100644 --- a/clickhouse_sqlalchemy/alembic/comparators.py +++ b/clickhouse_sqlalchemy/alembic/comparators.py @@ -4,7 +4,6 @@ from alembic.autogenerate import comparators from alembic.autogenerate.compare import _compare_columns from alembic.operations.ops import ModifyTableOps -from alembic.util.sqla_compat import _reflect_table as _alembic_reflect_table from sqlalchemy import schema as sa_schema from sqlalchemy import text @@ -35,10 +34,9 @@ def _extract_to_table_name(create_table_query): def _reflect_table(inspector, table): - if alembic_version >= (1, 11, 0): - return _alembic_reflect_table(inspector, table) - else: - return _alembic_reflect_table(inspector, table, None) + # Use SQLAlchemy's standard reflection mechanism + table.clear() + inspector.reflect_table(table, None) @comparators.dispatch_for('schema', 'clickhouse') diff --git a/clickhouse_sqlalchemy/drivers/base.py b/clickhouse_sqlalchemy/drivers/base.py index 7e102df1..a3c4fe63 100644 --- a/clickhouse_sqlalchemy/drivers/base.py +++ b/clickhouse_sqlalchemy/drivers/base.py @@ -38,6 +38,8 @@ 'Date32': types.Date32, 'DateTime': types.DateTime, 'DateTime64': types.DateTime64, + 'Time': types.Time, + 'Time64': types.Time64, 'Float64': types.Float64, 'Float32': types.Float32, 'Decimal': types.Decimal, @@ -313,6 +315,9 @@ def _get_column_type(self, name, spec): elif spec.startswith('DateTime'): coltype = self.ischema_names['DateTime'] return coltype(*self._parse_detetime_params(spec)) + elif spec.startswith('Time64'): + coltype = self.ischema_names['Time64'] + return coltype(*self._parse_time64_params(spec)) else: try: return self.ischema_names[spec] @@ -345,6 +350,13 @@ def _parse_detetime_params(spec): return [] return [inner_spec] + @staticmethod + def _parse_time64_params(spec): + inner_spec = get_inner_spec(spec) + if not inner_spec: + return [] + return [int(inner_spec)] + @staticmethod def _parse_options(option_string): options = dict() diff --git a/clickhouse_sqlalchemy/drivers/compilers/typecompiler.py b/clickhouse_sqlalchemy/drivers/compilers/typecompiler.py index 0c5ab472..c708266d 100644 --- a/clickhouse_sqlalchemy/drivers/compilers/typecompiler.py +++ b/clickhouse_sqlalchemy/drivers/compilers/typecompiler.py @@ -179,3 +179,14 @@ def visit_simpleaggregatefunction(self, type_, **kw): return "SimpleAggregateFunction(%s, %s)" % ( agg_str, ", ".join(type_strings) ) + + def visit_time(self, type_, **kw): + return 'Time' + + def visit_time64(self, type_, **kw): + if type_.precision not in [3, 6, 9]: + raise ValueError( + "Invalid precision value. Expected one of [3, 6, 9]." + ) + + return f'Time64({type_.precision})' diff --git a/clickhouse_sqlalchemy/drivers/http/escaper.py b/clickhouse_sqlalchemy/drivers/http/escaper.py index 9c572942..8cbbdad0 100644 --- a/clickhouse_sqlalchemy/drivers/http/escaper.py +++ b/clickhouse_sqlalchemy/drivers/http/escaper.py @@ -1,4 +1,4 @@ -from datetime import date, datetime +from datetime import date, datetime, time from decimal import Decimal import enum import uuid @@ -47,6 +47,15 @@ def escape_datetime64(self, item): # XXX: shouldn't this be `toDateTime64(...)`? return self.escape_string(item.strftime('%Y-%m-%d %H:%M:%S.%f')) + def escape_time(self, item): + if item.microsecond: + value = item.strftime('%H:%M:%S.%f').rstrip('0') + if value[-1] == '.': + value = value[:-1] + else: + value = item.strftime('%H:%M:%S') + return self.escape_string(value) + def escape_decimal(self, item): return float(item) @@ -60,6 +69,8 @@ def escape_item(self, item): return self.escape_number(item) elif isinstance(item, datetime): return self.escape_datetime(item) + elif isinstance(item, time): + return self.escape_time(item) elif isinstance(item, date): return self.escape_date(item) elif isinstance(item, Decimal): diff --git a/clickhouse_sqlalchemy/drivers/http/transport.py b/clickhouse_sqlalchemy/drivers/http/transport.py index ad01ea56..64b5e1d6 100644 --- a/clickhouse_sqlalchemy/drivers/http/transport.py +++ b/clickhouse_sqlalchemy/drivers/http/transport.py @@ -1,6 +1,6 @@ import re -from datetime import datetime +from datetime import datetime, timedelta from decimal import Decimal from functools import partial @@ -35,6 +35,32 @@ def datetime_converter(x): return datetime.strptime(x, '%Y-%m-%d %H:%M:%S') +def time_converter(x): + if x is None: + return None + + time_part, _, fractional = x.partition('.') + dt = datetime.strptime(time_part, '%H:%M:%S') + + if fractional: + fractional = fractional.rstrip('0') + if fractional: + digits = len(fractional) + if digits <= 6: + microsecond = int(fractional.ljust(6, '0')) + else: + scale = 10 ** (digits - 6) + # datetime supports up to microsec precision-round here. + microsecond = (int(fractional) + scale // 2) // scale + if microsecond == 1000000: + dt += timedelta(seconds=1) + microsecond = 0 + if dt.day != 1: + dt -= timedelta(days=1) + dt = dt.replace(microsecond=microsecond) + return dt.time() + + def nullable_converter(subtype_str, x): if x is None: return None @@ -66,6 +92,8 @@ def nothing_converter(x): 'Date': date_converter, 'DateTime': datetime_converter, 'DateTime64': datetime_converter, + 'Time': time_converter, + 'Time64': time_converter, 'IPv4': IPv4Address, 'IPv6': IPv6Address, 'Nullable': nullable_converter, @@ -80,6 +108,8 @@ def _get_type(type_str): # sometimes type_str is DateTime64(x) if type_str.startswith('DateTime64'): return converters['DateTime64'] + if type_str.startswith('Time64'): + return converters['Time64'] if type_str.startswith('Decimal'): return converters['Decimal'] if type_str.startswith('Nullable('): diff --git a/clickhouse_sqlalchemy/types/__init__.py b/clickhouse_sqlalchemy/types/__init__.py index 67747f88..4f177d2e 100644 --- a/clickhouse_sqlalchemy/types/__init__.py +++ b/clickhouse_sqlalchemy/types/__init__.py @@ -37,6 +37,8 @@ 'Map', 'AggregateFunction', 'SimpleAggregateFunction', + 'Time', + 'Time64', ] from .common import String @@ -74,6 +76,8 @@ from .common import Map from .common import AggregateFunction from .common import SimpleAggregateFunction +from .common import Time +from .common import Time64 from .ip import IPv4 from .ip import IPv6 from .nested import Nested diff --git a/clickhouse_sqlalchemy/types/common.py b/clickhouse_sqlalchemy/types/common.py index 458a320c..2bbac43a 100644 --- a/clickhouse_sqlalchemy/types/common.py +++ b/clickhouse_sqlalchemy/types/common.py @@ -254,3 +254,15 @@ def __repr__(self) -> str: agg_str = f'sa.func.{self.agg_func}' return f"SimpleAggregateFunction({agg_str}, {', '.join(type_strs)})" + + +class Time(ClickHouseTypeEngine): + __visit_name__ = "time" + + +class Time64(ClickHouseTypeEngine): + __visit_name__ = "time64" + + def __init__(self, precision=3): + self.precision = precision + super().__init__() diff --git a/tests/types/test_json.py b/tests/types/test_json.py index b2d03858..f7fe6f58 100644 --- a/tests/types/test_json.py +++ b/tests/types/test_json.py @@ -37,6 +37,10 @@ class JSONTestCase(BaseTestCase): ) def test_select_insert(self): + # Native driver doesn't support JSON type yet + if self.session.bind.driver == "native": + self.skipTest("Native driver doesn't support JSON type yet") + data = {'k1': 1, 'k2': '2', 'k3': True} self.table.drop(bind=self.session.bind, if_exists=True) diff --git a/tests/types/test_time.py b/tests/types/test_time.py new file mode 100644 index 00000000..35370d52 --- /dev/null +++ b/tests/types/test_time.py @@ -0,0 +1,59 @@ +import datetime + +from sqlalchemy import Column, text +from sqlalchemy.sql.ddl import CreateTable + +from clickhouse_sqlalchemy import Table, engines, types +from tests.testcase import BaseTestCase, CompilationTestCase +from tests.util import with_native_and_http_sessions + + +class TimeCompilationTestCase(CompilationTestCase): + def test_create_table(self): + if self.server_version < (25, 6, 0): + self.skipTest("Time types require ClickHouse 25.6+") + table = Table( + "test", + CompilationTestCase.metadata(), + Column("x", types.Time, primary_key=True), + engines.Memory(), + ) + assert ( + self.compile(CreateTable(table)) + == "CREATE TABLE test (x Time) ENGINE = Memory" + ) + + +@with_native_and_http_sessions +class TimeRuntimeTestCase(BaseTestCase): + def test_select_insert(self): + if self.server_version < (25, 6, 0): + self.skipTest("Time types require ClickHouse 25.6+") + + # Native driver doesn't support Time type yet + if self.session.bind.driver == "native": + self.skipTest("Native driver doesn't support Time type yet") + + time_val = datetime.time(15, 20, 30) + table_name = "test_time_runtime" + + with self.session.bind.connect() as conn: + try: + conn.execute(text(f"DROP TABLE IF EXISTS {table_name}")) + conn.execute( + text( + f""" + CREATE TABLE {table_name} (x Time) ENGINE = Memory + SETTINGS enable_time_time64_type = 1 + """ + ) + ) + conn.execute( + text(f"INSERT INTO {table_name} (x) VALUES ('{time_val}')") + ) + result = conn.execute( + text(f"SELECT x FROM {table_name}") + ).scalar() + assert result == time_val + finally: + conn.execute(text(f"DROP TABLE IF EXISTS {table_name}")) diff --git a/tests/types/test_time64.py b/tests/types/test_time64.py new file mode 100644 index 00000000..272d4714 --- /dev/null +++ b/tests/types/test_time64.py @@ -0,0 +1,107 @@ +import datetime + +import pytest +from sqlalchemy import Column, text +from sqlalchemy.sql.ddl import CreateTable + +from clickhouse_sqlalchemy import Table, engines, types +from tests.testcase import BaseTestCase, CompilationTestCase +from tests.util import with_native_and_http_sessions + + +class Time64CompilationTestCase(CompilationTestCase): + def test_create_table(self): + if self.server_version < (25, 6, 0): + self.skipTest("Time types require ClickHouse 25.6+") + table = Table( + "test", + CompilationTestCase.metadata(), + Column("x", types.Time64, primary_key=True), + engines.Memory(), + ) + + self.assertEqual( + self.compile(CreateTable(table)), + "CREATE TABLE test (x Time64(3)) ENGINE = Memory", + ) + + +class Time64CompilationTestCasePrecision(CompilationTestCase): + def test_create_table_with_precision(self): + if self.server_version < (25, 6, 0): + self.skipTest("Time types require ClickHouse 25.6+") + table = Table( + "test", + CompilationTestCase.metadata(), + Column("x", types.Time64(6), primary_key=True), + engines.Memory(), + ) + + self.assertEqual( + self.compile(CreateTable(table)), + "CREATE TABLE test (x Time64(6)) ENGINE = Memory", + ) + + def test_create_table_with_bad_precision(self): + if self.server_version < (25, 6, 0): + self.skipTest("Time types require ClickHouse 25.6+") + table = Table( + "test", + CompilationTestCase.metadata(), + Column("x", types.Time64(7), primary_key=True), + engines.Memory(), + ) + + with pytest.raises(ValueError, match="Invalid precision value"): + self.compile(CreateTable(table)) + + def test_create_table_with_empty_precision_defaults_to_3(self): + if self.server_version < (25, 6, 0): + self.skipTest("Time types require ClickHouse 25.6+") + table = Table( + "test", + CompilationTestCase.metadata(), + Column("x", types.Time64, primary_key=True), + engines.Memory(), + ) + + self.assertEqual( + self.compile(CreateTable(table)), + "CREATE TABLE test (x Time64(3)) ENGINE = Memory", + ) + + +@with_native_and_http_sessions +class Time64TestCase(BaseTestCase): + def test_select_insert(self): + if self.server_version < (25, 6, 0): + self.skipTest("Time types require ClickHouse 25.6+") + + # Native driver doesn't support Time64 type yet + if self.session.bind.driver == "native": + self.skipTest("Native driver doesn't support Time64 type yet") + + time_val = datetime.time(15, 20, 30, 123000) + table_name = "test_time64_runtime" + + with self.session.bind.connect() as conn: + try: + conn.execute(text(f"DROP TABLE IF EXISTS {table_name}")) + conn.execute( + text( + f""" + CREATE TABLE {table_name} (x Time64(3)) ENGINE = Memory + SETTINGS enable_time_time64_type = 1 + """ + ) + ) + conn.execute( + text(f"INSERT INTO {table_name} (x) VALUES ('{time_val}')") + ) + result = conn.execute( + text(f"SELECT x FROM {table_name}") + ).scalar() + self.assertEqual(result, time_val) + + finally: + conn.execute(text(f"DROP TABLE IF EXISTS {table_name}"))