Skip to content

Commit 1ac6e5b

Browse files
committed
Fixups for latest SQLAlchemy
1 parent fa94e1d commit 1ac6e5b

File tree

2 files changed

+61
-33
lines changed

2 files changed

+61
-33
lines changed

databend_sqlalchemy/databend_dialect.py

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1491,24 +1491,24 @@ def _get_default_schema_name(self, connection):
14911491
def get_schema_names(self, connection, **kw):
14921492
return [row[0] for row in connection.execute(text("SHOW DATABASES"))]
14931493

1494-
def _get_table_columns(self, connection, table_name, schema):
1495-
if schema is None:
1496-
schema = self.default_schema_name
1497-
quote_table_name = self.identifier_preparer.quote_identifier(table_name)
1498-
quote_schema = self.identifier_preparer.quote_identifier(schema)
1499-
1500-
return connection.execute(
1501-
text(f"DESC {quote_schema}.{quote_table_name}")
1502-
).fetchall()
1503-
15041494
@reflection.cache
15051495
def has_table(self, connection, table_name, schema=None, **kw):
1496+
table_name_query = """
1497+
select case when exists(
1498+
select table_name
1499+
from information_schema.tables
1500+
where table_schema = :schema_name
1501+
and table_name = :table_name
1502+
) then 1 else 0 end
1503+
"""
1504+
query = text(table_name_query).bindparams(
1505+
bindparam("schema_name", type_=sqltypes.Unicode),
1506+
bindparam("table_name", type_=sqltypes.Unicode),
1507+
)
15061508
if schema is None:
15071509
schema = self.default_schema_name
1508-
quote_table_name = self.identifier_preparer.quote_identifier(table_name)
1509-
quote_schema = self.identifier_preparer.quote_identifier(schema)
1510-
query = f"""EXISTS TABLE {quote_schema}.{quote_table_name}"""
1511-
r = connection.scalar(text(query))
1510+
1511+
r = connection.scalar(query, dict(schema_name=schema, table_name=table_name))
15121512
if r == 1:
15131513
return True
15141514
return False
@@ -1550,21 +1550,26 @@ def get_columns(self, connection, table_name, schema=None, **kw):
15501550
def get_view_definition(self, connection, view_name, schema=None, **kw):
15511551
if schema is None:
15521552
schema = self.default_schema_name
1553-
quote_schema = self.identifier_preparer.quote_identifier(schema)
1554-
quote_view_name = self.identifier_preparer.quote_identifier(view_name)
1555-
full_view_name = f"{quote_schema}.{quote_view_name}"
1556-
1557-
# ToDo : perhaps can be removed if we get `SHOW CREATE VIEW`
1558-
if view_name not in self.get_view_names(connection, schema):
1559-
raise NoSuchTableError(full_view_name)
1560-
1561-
query = f"""SHOW CREATE TABLE {full_view_name}"""
1562-
try:
1563-
view_def = connection.execute(text(query)).first()
1564-
return view_def[1]
1565-
except DBAPIError as e:
1566-
if "1025" in e.orig.message: # ToDo: The errors need parsing properly
1567-
raise NoSuchTableError(full_view_name) from e
1553+
query = text(
1554+
"""
1555+
select view_query
1556+
from system.views
1557+
where name = :view_name
1558+
and database = :schema_name
1559+
"""
1560+
).bindparams(
1561+
bindparam("view_name", type_=sqltypes.UnicodeText),
1562+
bindparam("schema_name", type_=sqltypes.Unicode),
1563+
)
1564+
r = connection.scalar(
1565+
query, dict(view_name=view_name, schema_name=schema)
1566+
)
1567+
if not r:
1568+
raise NoSuchTableError(
1569+
f"{self.identifier_preparer.quote_identifier(schema)}."
1570+
f"{self.identifier_preparer.quote_identifier(view_name)}"
1571+
)
1572+
return r
15681573

15691574
def _get_column_type(self, column_type):
15701575
pattern = r"(?:Nullable)*(?:\()*(\w+)(?:\((.*?)\))?(?:\))*"

tests/test_sqlalchemy.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@
3131
from packaging import version
3232
import sqlalchemy
3333
if version.parse(sqlalchemy.__version__) >= version.parse('2.0.0'):
34-
from sqlalchemy.testing.suite import BizarroCharacterFKResolutionTest as _BizarroCharacterFKResolutionTest
34+
if version.parse(sqlalchemy.__version__) < version.parse('2.0.42'):
35+
from sqlalchemy.testing.suite import BizarroCharacterFKResolutionTest as _BizarroCharacterFKResolutionTest
3536
from sqlalchemy.testing.suite import EnumTest as _EnumTest
3637
else:
3738
from sqlalchemy.testing.suite import ComponentReflectionTest as _ComponentReflectionTest
@@ -43,14 +44,36 @@ def test_get_indexes(self):
4344
pass
4445

4546
class ComponentReflectionTestExtra(_ComponentReflectionTestExtra):
46-
47+
@testing.skip("databend") #ToDo No length in Databend
4748
@testing.requires.table_reflection
4849
def test_varchar_reflection(self, connection, metadata):
4950
typ = self._type_round_trip(
5051
connection, metadata, sql_types.String(52)
5152
)[0]
5253
assert isinstance(typ, sql_types.String)
53-
# eq_(typ.length, 52) # No length in Databend
54+
eq_(typ.length, 52)
55+
56+
@testing.skip("databend") # ToDo No length in Databend
57+
@testing.requires.table_reflection
58+
@testing.combinations(
59+
sql_types.String,
60+
sql_types.VARCHAR,
61+
sql_types.CHAR,
62+
(sql_types.NVARCHAR, testing.requires.nvarchar_types),
63+
(sql_types.NCHAR, testing.requires.nvarchar_types),
64+
argnames="type_",
65+
)
66+
def test_string_length_reflection(self, connection, metadata, type_):
67+
typ = self._type_round_trip(connection, metadata, type_(52))[0]
68+
if issubclass(type_, sql_types.VARCHAR):
69+
assert isinstance(typ, sql_types.VARCHAR)
70+
elif issubclass(type_, sql_types.CHAR):
71+
assert isinstance(typ, sql_types.CHAR)
72+
else:
73+
assert isinstance(typ, sql_types.String)
74+
75+
eq_(typ.length, 52)
76+
assert isinstance(typ.length, int)
5477

5578

5679
class BooleanTest(_BooleanTest):
@@ -205,7 +228,7 @@ def test_get_indexes(self, name):
205228
class JoinTest(_JoinTest):
206229
__requires__ = ("foreign_keys",)
207230

208-
if version.parse(sqlalchemy.__version__) >= version.parse('2.0.0'):
231+
if version.parse(sqlalchemy.__version__) >= version.parse('2.0.0') and version.parse(sqlalchemy.__version__) < version.parse('2.0.42'):
209232
class BizarroCharacterFKResolutionTest(_BizarroCharacterFKResolutionTest):
210233
__requires__ = ("foreign_keys",)
211234

0 commit comments

Comments
 (0)