diff --git a/databend_sqlalchemy/databend_dialect.py b/databend_sqlalchemy/databend_dialect.py index ca3bbac..05ea602 100644 --- a/databend_sqlalchemy/databend_dialect.py +++ b/databend_sqlalchemy/databend_dialect.py @@ -31,7 +31,7 @@ import sqlalchemy.engine.reflection import sqlalchemy.types as sqltypes -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Union, List from sqlalchemy import util as sa_util from sqlalchemy.engine import reflection from sqlalchemy.sql import ( @@ -703,7 +703,6 @@ def __init__(self, key_type, value_type): super(MAP, self).__init__() - class DatabendDate(sqltypes.DATE): __visit_name__ = "DATE" @@ -857,7 +856,6 @@ class DatabendGeography(GEOGRAPHY): } - # Column spec colspecs = { sqltypes.Interval: DatabendInterval, @@ -872,6 +870,12 @@ class DatabendGeography(GEOGRAPHY): class DatabendIdentifierPreparer(PGIdentifierPreparer): reserved_words = {r.lower() for r in RESERVED_WORDS} + # overridden to exclude schema from sequence + def format_sequence( + self, sequence, use_schema: bool = True + ) -> str: + return super().format_sequence(sequence, use_schema=False) + class DatabendCompiler(PGCompiler): iscopyintotable: bool = False @@ -1230,6 +1234,15 @@ def copy_into_table_results(self) -> list[dict]: def copy_into_location_results(self) -> dict: return self._copy_into_location_results + def fire_sequence(self, seq, type_): + return self._execute_scalar( + ( + "select nextval(%s)" + % self.identifier_preparer.format_sequence(seq) + ), + type_, + ) + class DatabendTypeCompiler(compiler.GenericTypeCompiler): def visit_ARRAY(self, type_, **kw): @@ -1280,7 +1293,6 @@ def visit_GEOGRAPHY(self, type_, **kw): return "GEOGRAPHY" - class DatabendDDLCompiler(compiler.DDLCompiler): def visit_primary_key_constraint(self, constraint, **kw): return "" @@ -1394,6 +1406,7 @@ class DatabendDialect(default.DefaultDialect): supports_empty_insert = False supports_is_distinct_from = True supports_multivalues_insert = True + supports_sequences = True supports_statement_cache = False supports_server_side_cursors = True @@ -1478,24 +1491,24 @@ def _get_default_schema_name(self, connection): def get_schema_names(self, connection, **kw): return [row[0] for row in connection.execute(text("SHOW DATABASES"))] - def _get_table_columns(self, connection, table_name, schema): - if schema is None: - schema = self.default_schema_name - quote_table_name = self.identifier_preparer.quote_identifier(table_name) - quote_schema = self.identifier_preparer.quote_identifier(schema) - - return connection.execute( - text(f"DESC {quote_schema}.{quote_table_name}") - ).fetchall() - @reflection.cache def has_table(self, connection, table_name, schema=None, **kw): + table_name_query = """ + select case when exists( + select table_name + from information_schema.tables + where table_schema = :schema_name + and table_name = :table_name + ) then 1 else 0 end + """ + query = text(table_name_query).bindparams( + bindparam("schema_name", type_=sqltypes.Unicode), + bindparam("table_name", type_=sqltypes.Unicode), + ) if schema is None: schema = self.default_schema_name - quote_table_name = self.identifier_preparer.quote_identifier(table_name) - quote_schema = self.identifier_preparer.quote_identifier(schema) - query = f"""EXISTS TABLE {quote_schema}.{quote_table_name}""" - r = connection.scalar(text(query)) + + r = connection.scalar(query, dict(schema_name=schema, table_name=table_name)) if r == 1: return True return False @@ -1537,21 +1550,26 @@ def get_columns(self, connection, table_name, schema=None, **kw): def get_view_definition(self, connection, view_name, schema=None, **kw): if schema is None: schema = self.default_schema_name - quote_schema = self.identifier_preparer.quote_identifier(schema) - quote_view_name = self.identifier_preparer.quote_identifier(view_name) - full_view_name = f"{quote_schema}.{quote_view_name}" - - # ToDo : perhaps can be removed if we get `SHOW CREATE VIEW` - if view_name not in self.get_view_names(connection, schema): - raise NoSuchTableError(full_view_name) - - query = f"""SHOW CREATE TABLE {full_view_name}""" - try: - view_def = connection.execute(text(query)).first() - return view_def[1] - except DBAPIError as e: - if "1025" in e.orig.message: # ToDo: The errors need parsing properly - raise NoSuchTableError(full_view_name) from e + query = text( + """ + select view_query + from system.views + where name = :view_name + and database = :schema_name + """ + ).bindparams( + bindparam("view_name", type_=sqltypes.UnicodeText), + bindparam("schema_name", type_=sqltypes.Unicode), + ) + r = connection.scalar( + query, dict(view_name=view_name, schema_name=schema) + ) + if not r: + raise NoSuchTableError( + f"{self.identifier_preparer.quote_identifier(schema)}." + f"{self.identifier_preparer.quote_identifier(view_name)}" + ) + return r def _get_column_type(self, column_type): pattern = r"(?:Nullable)*(?:\()*(\w+)(?:\((.*?)\))?(?:\))*" @@ -1621,7 +1639,6 @@ def get_temp_table_names(self, connection, schema=None, **kw): result = connection.execute(query, dict(schema_name=schema)) return [row[0] for row in result] - @reflection.cache def get_view_names(self, connection, schema=None, **kw): view_name_query = """ @@ -1762,7 +1779,6 @@ def get_multi_table_comment( schema='system', ).alias("a_tab_comments") - has_filter_names, params = self._prepare_filter_names(filter_names) owner = schema or self.default_schema_name @@ -1804,6 +1820,20 @@ def _check_unicode_description(self, connection): # We decode everything as UTF-8 return True + @reflection.cache + def get_sequence_names(self, connection, schema: Optional[str] = None, **kw: Any) -> List[str]: + # N.B. sequences are not defined per schema/database + sequence_query = """ + show sequences + """ + query = text(sequence_query) + result = connection.execute(query) + return [row[0] for row in result] + + def has_sequence(self, connection, sequence_name: str, schema: Optional[str] = None, **kw: Any) -> bool: + # N.B. sequences are not defined per schema/database + return sequence_name in self.get_sequence_names(connection, schema, **kw) + dialect = DatabendDialect diff --git a/tests/test_sqlalchemy.py b/tests/test_sqlalchemy.py index 64a04c2..080965a 100644 --- a/tests/test_sqlalchemy.py +++ b/tests/test_sqlalchemy.py @@ -13,6 +13,7 @@ from sqlalchemy.testing.suite import LongNameBlowoutTest as _LongNameBlowoutTest from sqlalchemy.testing.suite import QuotedNameArgumentTest as _QuotedNameArgumentTest from sqlalchemy.testing.suite import JoinTest as _JoinTest +from sqlalchemy.testing.suite import HasSequenceTest as _HasSequenceTest from sqlalchemy.testing.suite import ServerSideCursorsTest as _ServerSideCursorsTest @@ -21,7 +22,7 @@ from sqlalchemy.testing.suite import IntegerTest as _IntegerTest from sqlalchemy import types as sql_types -from sqlalchemy.testing import config +from sqlalchemy.testing import config, skip_test from sqlalchemy import testing, Table, Column, Integer from sqlalchemy.testing import eq_, fixtures, assertions @@ -30,7 +31,8 @@ from packaging import version import sqlalchemy if version.parse(sqlalchemy.__version__) >= version.parse('2.0.0'): - from sqlalchemy.testing.suite import BizarroCharacterFKResolutionTest as _BizarroCharacterFKResolutionTest + if version.parse(sqlalchemy.__version__) < version.parse('2.0.42'): + from sqlalchemy.testing.suite import BizarroCharacterFKResolutionTest as _BizarroCharacterFKResolutionTest from sqlalchemy.testing.suite import EnumTest as _EnumTest else: from sqlalchemy.testing.suite import ComponentReflectionTest as _ComponentReflectionTest @@ -42,14 +44,36 @@ def test_get_indexes(self): pass class ComponentReflectionTestExtra(_ComponentReflectionTestExtra): - + @testing.skip("databend") #ToDo No length in Databend @testing.requires.table_reflection def test_varchar_reflection(self, connection, metadata): typ = self._type_round_trip( connection, metadata, sql_types.String(52) )[0] assert isinstance(typ, sql_types.String) - # eq_(typ.length, 52) # No length in Databend + eq_(typ.length, 52) + + @testing.skip("databend") # ToDo No length in Databend + @testing.requires.table_reflection + @testing.combinations( + sql_types.String, + sql_types.VARCHAR, + sql_types.CHAR, + (sql_types.NVARCHAR, testing.requires.nvarchar_types), + (sql_types.NCHAR, testing.requires.nvarchar_types), + argnames="type_", + ) + def test_string_length_reflection(self, connection, metadata, type_): + typ = self._type_round_trip(connection, metadata, type_(52))[0] + if issubclass(type_, sql_types.VARCHAR): + assert isinstance(typ, sql_types.VARCHAR) + elif issubclass(type_, sql_types.CHAR): + assert isinstance(typ, sql_types.CHAR) + else: + assert isinstance(typ, sql_types.String) + + eq_(typ.length, 52) + assert isinstance(typ.length, int) class BooleanTest(_BooleanTest): @@ -204,7 +228,7 @@ def test_get_indexes(self, name): class JoinTest(_JoinTest): __requires__ = ("foreign_keys",) -if version.parse(sqlalchemy.__version__) >= version.parse('2.0.0'): +if version.parse(sqlalchemy.__version__) >= version.parse('2.0.0') and version.parse(sqlalchemy.__version__) < version.parse('2.0.42'): class BizarroCharacterFKResolutionTest(_BizarroCharacterFKResolutionTest): __requires__ = ("foreign_keys",) @@ -586,9 +610,6 @@ def test_geometry_write_and_read(self, connection): eq_(result, ('{"type": "GeometryCollection", "geometries": [{"type": "Point", "coordinates": [10,20]},{"type": "LineString", "coordinates": [[10,20],[30,40]]},{"type": "Polygon", "coordinates": [[[10,20],[30,40],[50,60],[10,20]]]}]}')) - - - class GeographyTest(fixtures.TablesTest): @classmethod @@ -664,4 +685,74 @@ def test_geography_write_and_read(self, connection): result = connection.execute( select(geography_table.c.geography_data).where(geography_table.c.id == 7) ).scalar() - eq_(result, ('{"type": "GeometryCollection", "geometries": [{"type": "Point", "coordinates": [10,20]},{"type": "LineString", "coordinates": [[10,20],[30,40]]},{"type": "Polygon", "coordinates": [[[10,20],[30,40],[50,60],[10,20]]]}]}')) \ No newline at end of file + eq_(result, ('{"type": "GeometryCollection", "geometries": [{"type": "Point", "coordinates": [10,20]},{"type": "LineString", "coordinates": [[10,20],[30,40]]},{"type": "Polygon", "coordinates": [[[10,20],[30,40],[50,60],[10,20]]]}]}')) + + +class HasSequenceTest(_HasSequenceTest): + + # ToDo - overridden other_seq definition due to lack of sequence ddl support for nominvalue nomaxvalue + @classmethod + def define_tables(cls, metadata): + normalize_sequence(config, Sequence("user_id_seq", metadata=metadata)) + normalize_sequence( + config, + Sequence( + "other_seq", + metadata=metadata, + # nomaxvalue=True, + # nominvalue=True, + ), + ) + if testing.requires.schemas.enabled: + #ToDo - omitted because Databend does not allow schema on sequence + # normalize_sequence( + # config, + # Sequence( + # "user_id_seq", schema=config.test_schema, metadata=metadata + # ), + # ) + normalize_sequence( + config, + Sequence( + "schema_seq", schema=config.test_schema, metadata=metadata + ), + ) + Table( + "user_id_table", + metadata, + Column("id", Integer, primary_key=True), + ) + + @testing.skip("databend") # ToDo - requires definition of sequences with schema + def test_has_sequence_remote_not_in_default(self, connection): + eq_(inspect(connection).has_sequence("schema_seq"), False) + + @testing.skip("databend") # ToDo - requires definition of sequences with schema + def test_get_sequence_names(self, connection): + exp = {"other_seq", "user_id_seq"} + + res = set(inspect(connection).get_sequence_names()) + is_true(res.intersection(exp) == exp) + is_true("schema_seq" not in res) + + @testing.skip("databend") # ToDo - requires definition of sequences with schema + @testing.requires.schemas + def test_get_sequence_names_no_sequence_schema(self, connection): + eq_( + inspect(connection).get_sequence_names( + schema=config.test_schema_2 + ), + [], + ) + + @testing.skip("databend") # ToDo - requires definition of sequences with schema + @testing.requires.schemas + def test_get_sequence_names_sequences_schema(self, connection): + eq_( + sorted( + inspect(connection).get_sequence_names( + schema=config.test_schema + ) + ), + ["schema_seq", "user_id_seq"], + )