Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 65 additions & 35 deletions databend_sqlalchemy/databend_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

import sqlalchemy.engine.reflection
import sqlalchemy.types as sqltypes
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional, Union, List
from sqlalchemy import util as sa_util
from sqlalchemy.engine import reflection
from sqlalchemy.sql import (
Expand Down Expand Up @@ -703,7 +703,6 @@ def __init__(self, key_type, value_type):
super(MAP, self).__init__()



class DatabendDate(sqltypes.DATE):
__visit_name__ = "DATE"

Expand Down Expand Up @@ -857,7 +856,6 @@ class DatabendGeography(GEOGRAPHY):
}



# Column spec
colspecs = {
sqltypes.Interval: DatabendInterval,
Expand All @@ -872,6 +870,12 @@ class DatabendGeography(GEOGRAPHY):
class DatabendIdentifierPreparer(PGIdentifierPreparer):
    """Identifier preparer for Databend, derived from the PostgreSQL preparer.

    ``reserved_words`` holds the lower-cased entries of ``RESERVED_WORDS``.
    """

    reserved_words = {word.lower() for word in RESERVED_WORDS}

    def format_sequence(
        self, sequence, use_schema: bool = True
    ) -> str:
        """Return the formatted name of *sequence* without a schema prefix.

        ``use_schema`` is kept for signature compatibility but is ignored:
        the parent implementation is always called with ``use_schema=False``.
        """
        schema_free_name = super().format_sequence(sequence, use_schema=False)
        return schema_free_name


class DatabendCompiler(PGCompiler):
iscopyintotable: bool = False
Expand Down Expand Up @@ -1230,6 +1234,15 @@ def copy_into_table_results(self) -> list[dict]:
def copy_into_location_results(self) -> dict:
    """Return the value stored in ``self._copy_into_location_results``."""
    return self._copy_into_location_results

def fire_sequence(self, seq, type_):
    """Fetch the next value of *seq* by executing ``select nextval(<name>)``.

    The sequence name is rendered by ``self.identifier_preparer`` and the
    statement is run through ``self._execute_scalar`` together with *type_*.
    """
    sequence_name = self.identifier_preparer.format_sequence(seq)
    statement = f"select nextval({sequence_name})"
    return self._execute_scalar(statement, type_)


class DatabendTypeCompiler(compiler.GenericTypeCompiler):
def visit_ARRAY(self, type_, **kw):
Expand Down Expand Up @@ -1280,7 +1293,6 @@ def visit_GEOGRAPHY(self, type_, **kw):
return "GEOGRAPHY"



class DatabendDDLCompiler(compiler.DDLCompiler):
def visit_primary_key_constraint(self, constraint, **kw):
    """Render a primary key constraint as an empty string (no DDL text)."""
    # NOTE(review): presumably Databend does not accept PRIMARY KEY clauses — confirm.
    return ""
Expand Down Expand Up @@ -1394,6 +1406,7 @@ class DatabendDialect(default.DefaultDialect):
supports_empty_insert = False
supports_is_distinct_from = True
supports_multivalues_insert = True
supports_sequences = True

supports_statement_cache = False
supports_server_side_cursors = True
Expand Down Expand Up @@ -1478,24 +1491,24 @@ def _get_default_schema_name(self, connection):
def get_schema_names(self, connection, **kw):
    """Return the first column of every row produced by ``SHOW DATABASES``."""
    database_rows = connection.execute(text("SHOW DATABASES"))
    return [database_row[0] for database_row in database_rows]

def _get_table_columns(self, connection, table_name, schema):
    """Run ``DESC <schema>.<table>`` and return all resulting rows.

    When *schema* is ``None`` the dialect's ``default_schema_name`` is used.
    Both identifiers are quoted by the identifier preparer before they are
    placed in the statement.
    """
    preparer = self.identifier_preparer
    target_schema = self.default_schema_name if schema is None else schema
    quoted_table = preparer.quote_identifier(table_name)
    quoted_schema = preparer.quote_identifier(target_schema)

    statement = text(f"DESC {quoted_schema}.{quoted_table}")
    return connection.execute(statement).fetchall()

@reflection.cache
def has_table(self, connection, table_name, schema=None, **kw):
    """Return ``True`` when *table_name* is listed in ``information_schema.tables``.

    The block previously built this bound-parameter query, then overwrote it
    with an f-string ``EXISTS TABLE`` statement and executed twice (the second
    time passing a plain string together with a parameter dict).  Only the
    single parameterised lookup is kept.

    :param connection: connection used to run the lookup.
    :param table_name: name of the table to look for.
    :param schema: schema to search; ``self.default_schema_name`` when ``None``.
    :return: ``True`` if the lookup yields ``1``, otherwise ``False``.
    """
    table_name_query = """
        select case when exists(
            select table_name
            from information_schema.tables
            where table_schema = :schema_name
            and table_name = :table_name
        ) then 1 else 0 end
    """
    query = text(table_name_query).bindparams(
        bindparam("schema_name", type_=sqltypes.Unicode),
        bindparam("table_name", type_=sqltypes.Unicode),
    )
    if schema is None:
        schema = self.default_schema_name

    # Names travel as bound parameters, so no identifier quoting is required.
    r = connection.scalar(query, dict(schema_name=schema, table_name=table_name))
    return r == 1
Expand Down Expand Up @@ -1537,21 +1550,26 @@ def get_columns(self, connection, table_name, schema=None, **kw):
def get_view_definition(self, connection, view_name, schema=None, **kw):
    """Return the ``view_query`` text recorded for *view_name* in ``system.views``.

    The block previously carried two lookups back to back: a
    ``get_view_names`` membership check plus ``SHOW CREATE TABLE`` (with
    error detection based on ``"1025" in e.orig.message``), followed by the
    ``system.views`` query, which was only reachable after a non-1025
    ``DBAPIError``.  A single parameterised lookup is kept.

    :param connection: connection used to run the lookup.
    :param view_name: name of the view.
    :param schema: database holding the view; ``self.default_schema_name``
        when ``None``.
    :return: the value of the ``view_query`` column.
    :raises NoSuchTableError: when the lookup yields no (or an empty) value;
        the message is the quoted ``schema.view`` name.
    """
    if schema is None:
        schema = self.default_schema_name

    query = text(
        """
        select view_query
        from system.views
        where name = :view_name
        and database = :schema_name
        """
    ).bindparams(
        bindparam("view_name", type_=sqltypes.UnicodeText),
        bindparam("schema_name", type_=sqltypes.Unicode),
    )
    r = connection.scalar(
        query, dict(view_name=view_name, schema_name=schema)
    )
    if not r:
        raise NoSuchTableError(
            f"{self.identifier_preparer.quote_identifier(schema)}."
            f"{self.identifier_preparer.quote_identifier(view_name)}"
        )
    return r

def _get_column_type(self, column_type):
pattern = r"(?:Nullable)*(?:\()*(\w+)(?:\((.*?)\))?(?:\))*"
Expand Down Expand Up @@ -1621,7 +1639,6 @@ def get_temp_table_names(self, connection, schema=None, **kw):
result = connection.execute(query, dict(schema_name=schema))
return [row[0] for row in result]


@reflection.cache
def get_view_names(self, connection, schema=None, **kw):
view_name_query = """
Expand Down Expand Up @@ -1762,7 +1779,6 @@ def get_multi_table_comment(
schema='system',
).alias("a_tab_comments")


has_filter_names, params = self._prepare_filter_names(filter_names)
owner = schema or self.default_schema_name

Expand Down Expand Up @@ -1804,6 +1820,20 @@ def _check_unicode_description(self, connection):
# We decode everything as UTF-8
return True

@reflection.cache
def get_sequence_names(self, connection, schema: Optional[str] = None, **kw: Any) -> List[str]:
    """Return the first column of each row returned by ``show sequences``.

    *schema* is accepted but plays no part in the statement.
    """
    # N.B. sequences are not defined per schema/database
    sequence_rows = connection.execute(
        text(
            """
            show sequences
            """
        )
    )
    return [sequence_row[0] for sequence_row in sequence_rows]

def has_sequence(self, connection, sequence_name: str, schema: Optional[str] = None, **kw: Any) -> bool:
    """Report whether *sequence_name* is among ``self.get_sequence_names(...)``.

    *schema* and any extra keyword arguments are forwarded unchanged.
    """
    # N.B. sequences are not defined per schema/database
    known_sequences = self.get_sequence_names(connection, schema, **kw)
    return sequence_name in known_sequences


# Module-level alias for the dialect class.
# NOTE(review): presumably the name SQLAlchemy looks up when loading this module as a dialect — confirm.
dialect = DatabendDialect

Expand Down
109 changes: 100 additions & 9 deletions tests/test_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from sqlalchemy.testing.suite import LongNameBlowoutTest as _LongNameBlowoutTest
from sqlalchemy.testing.suite import QuotedNameArgumentTest as _QuotedNameArgumentTest
from sqlalchemy.testing.suite import JoinTest as _JoinTest
from sqlalchemy.testing.suite import HasSequenceTest as _HasSequenceTest

from sqlalchemy.testing.suite import ServerSideCursorsTest as _ServerSideCursorsTest

Expand All @@ -21,7 +22,7 @@
from sqlalchemy.testing.suite import IntegerTest as _IntegerTest

from sqlalchemy import types as sql_types
from sqlalchemy.testing import config
from sqlalchemy.testing import config, skip_test
from sqlalchemy import testing, Table, Column, Integer
from sqlalchemy.testing import eq_, fixtures, assertions

Expand All @@ -30,7 +31,8 @@
from packaging import version
import sqlalchemy
if version.parse(sqlalchemy.__version__) >= version.parse('2.0.0'):
from sqlalchemy.testing.suite import BizarroCharacterFKResolutionTest as _BizarroCharacterFKResolutionTest
if version.parse(sqlalchemy.__version__) < version.parse('2.0.42'):
from sqlalchemy.testing.suite import BizarroCharacterFKResolutionTest as _BizarroCharacterFKResolutionTest
from sqlalchemy.testing.suite import EnumTest as _EnumTest
else:
from sqlalchemy.testing.suite import ComponentReflectionTest as _ComponentReflectionTest
Expand All @@ -42,14 +44,36 @@ def test_get_indexes(self):
pass

class ComponentReflectionTestExtra(_ComponentReflectionTestExtra):
    """String-length reflection checks; both tests are skipped for Databend."""

    @testing.skip("databend")  # ToDo No length in Databend
    @testing.requires.table_reflection
    def test_varchar_reflection(self, connection, metadata):
        """Round-trip ``String(52)`` and check the reflected type and length."""
        reflected_types = self._type_round_trip(
            connection, metadata, sql_types.String(52)
        )
        typ = reflected_types[0]
        assert isinstance(typ, sql_types.String)
        eq_(typ.length, 52)

    @testing.skip("databend")  # ToDo No length in Databend
    @testing.requires.table_reflection
    @testing.combinations(
        sql_types.String,
        sql_types.VARCHAR,
        sql_types.CHAR,
        (sql_types.NVARCHAR, testing.requires.nvarchar_types),
        (sql_types.NCHAR, testing.requires.nvarchar_types),
        argnames="type_",
    )
    def test_string_length_reflection(self, connection, metadata, type_):
        """Round-trip ``type_(52)`` and check the reflected class and length."""
        typ = self._type_round_trip(connection, metadata, type_(52))[0]

        # VARCHAR is tested before CHAR, mirroring the order of the checks.
        if issubclass(type_, sql_types.VARCHAR):
            expected_class = sql_types.VARCHAR
        elif issubclass(type_, sql_types.CHAR):
            expected_class = sql_types.CHAR
        else:
            expected_class = sql_types.String
        assert isinstance(typ, expected_class)

        eq_(typ.length, 52)
        assert isinstance(typ.length, int)


class BooleanTest(_BooleanTest):
Expand Down Expand Up @@ -204,7 +228,7 @@ def test_get_indexes(self, name):
class JoinTest(_JoinTest):
    # NOTE(review): presumably gates the inherited suite on the "foreign_keys" requirement — confirm.
    __requires__ = ("foreign_keys",)

if version.parse(sqlalchemy.__version__) >= version.parse('2.0.0'):
if version.parse(sqlalchemy.__version__) >= version.parse('2.0.0') and version.parse(sqlalchemy.__version__) < version.parse('2.0.42'):
class BizarroCharacterFKResolutionTest(_BizarroCharacterFKResolutionTest):
__requires__ = ("foreign_keys",)

Expand Down Expand Up @@ -586,9 +610,6 @@ def test_geometry_write_and_read(self, connection):
eq_(result, ('{"type": "GeometryCollection", "geometries": [{"type": "Point", "coordinates": [10,20]},{"type": "LineString", "coordinates": [[10,20],[30,40]]},{"type": "Polygon", "coordinates": [[[10,20],[30,40],[50,60],[10,20]]]}]}'))





class GeographyTest(fixtures.TablesTest):

@classmethod
Expand Down Expand Up @@ -664,4 +685,74 @@ def test_geography_write_and_read(self, connection):
result = connection.execute(
select(geography_table.c.geography_data).where(geography_table.c.id == 7)
).scalar()
eq_(result, ('{"type": "GeometryCollection", "geometries": [{"type": "Point", "coordinates": [10,20]},{"type": "LineString", "coordinates": [[10,20],[30,40]]},{"type": "Polygon", "coordinates": [[[10,20],[30,40],[50,60],[10,20]]]}]}'))
eq_(result, ('{"type": "GeometryCollection", "geometries": [{"type": "Point", "coordinates": [10,20]},{"type": "LineString", "coordinates": [[10,20],[30,40]]},{"type": "Polygon", "coordinates": [[[10,20],[30,40],[50,60],[10,20]]]}]}'))


class HasSequenceTest(_HasSequenceTest):
    """Sequence reflection suite with Databend-specific fixtures.

    ``define_tables`` is overridden and the schema-dependent tests are
    marked as skipped for Databend.
    """

    # ToDo - overridden other_seq definition due to lack of sequence ddl support for nominvalue nomaxvalue
    @classmethod
    def define_tables(cls, metadata):
        """Register the sequences and the table used by the suite on *metadata*."""
        user_id_seq = Sequence("user_id_seq", metadata=metadata)
        normalize_sequence(config, user_id_seq)

        # ``nomaxvalue`` / ``nominvalue`` are deliberately not passed here.
        other_seq = Sequence("other_seq", metadata=metadata)
        normalize_sequence(config, other_seq)

        if testing.requires.schemas.enabled:
            # ToDo - a schema-qualified "user_id_seq" is omitted because
            # Databend does not allow schema on sequence
            schema_seq = Sequence(
                "schema_seq", schema=config.test_schema, metadata=metadata
            )
            normalize_sequence(config, schema_seq)

        Table("user_id_table", metadata, Column("id", Integer, primary_key=True))

    @testing.skip("databend")  # ToDo - requires definition of sequences with schema
    def test_has_sequence_remote_not_in_default(self, connection):
        """``schema_seq`` must not be reported without a schema argument."""
        inspector = inspect(connection)
        eq_(inspector.has_sequence("schema_seq"), False)

    @testing.skip("databend")  # ToDo - requires definition of sequences with schema
    def test_get_sequence_names(self, connection):
        """Both default sequences are listed and ``schema_seq`` is not."""
        expected = {"other_seq", "user_id_seq"}

        found = set(inspect(connection).get_sequence_names())
        is_true(found.intersection(expected) == expected)
        is_true("schema_seq" not in found)

    @testing.skip("databend")  # ToDo - requires definition of sequences with schema
    @testing.requires.schemas
    def test_get_sequence_names_no_sequence_schema(self, connection):
        """A schema without sequences yields an empty list."""
        inspector = inspect(connection)
        names = inspector.get_sequence_names(schema=config.test_schema_2)
        eq_(names, [])

    @testing.skip("databend")  # ToDo - requires definition of sequences with schema
    @testing.requires.schemas
    def test_get_sequence_names_sequences_schema(self, connection):
        """The test schema lists ``schema_seq`` and ``user_id_seq``."""
        inspector = inspect(connection)
        names = inspector.get_sequence_names(schema=config.test_schema)
        eq_(sorted(names), ["schema_seq", "user_id_seq"])