22import pytest
33from unittest import skipIf
44from sqlalchemy import create_engine , select , insert , Column , MetaData , Table
5- from sqlalchemy .orm import declarative_base , Session
5+ from sqlalchemy .orm import Session
66from sqlalchemy .types import SMALLINT , Integer , BOOLEAN , String , DECIMAL , Date
7+ from sqlalchemy .engine import Engine
8+
9+ from typing import Tuple
10+
11+ try :
12+ from sqlalchemy .orm import declarative_base
13+ except ImportError :
14+ from sqlalchemy .ext .declarative import declarative_base
715
816
917USER_AGENT_TOKEN = "PySQL e2e Tests"
1018
1119
12- @pytest .fixture
13- def db_engine ():
20+ def sqlalchemy_1_3 ():
21+ import sqlalchemy
22+
23+ return sqlalchemy .__version__ .startswith ("1.3" )
24+
25+
26+ def version_agnostic_select (object_to_select , * args , ** kwargs ):
27+ """
28+ SQLAlchemy==1.3.x requires arguments to select() to be a Python list
29+
30+ https://docs.sqlalchemy.org/en/20/changelog/migration_14.html#orm-query-is-internally-unified-with-select-update-delete-2-0-style-execution-available
31+ """
32+
33+ if sqlalchemy_1_3 ():
34+ return select ([object_to_select ], * args , ** kwargs )
35+ else :
36+ return select (object_to_select , * args , ** kwargs )
37+
38+
39+ def version_agnostic_connect_arguments (catalog = None , schema = None ) -> Tuple [str , dict ]:
1440
1541 HOST = os .environ .get ("host" )
1642 HTTP_PATH = os .environ .get ("http_path" )
1743 ACCESS_TOKEN = os .environ .get ("access_token" )
18- CATALOG = os .environ .get ("catalog" )
19- SCHEMA = os .environ .get ("schema" )
44+ CATALOG = catalog or os .environ .get ("catalog" )
45+ SCHEMA = schema or os .environ .get ("schema" )
46+
47+ ua_connect_args = {"_user_agent_entry" : USER_AGENT_TOKEN }
48+
49+ if sqlalchemy_1_3 ():
50+ conn_string = f"databricks://token:{ ACCESS_TOKEN } @{ HOST } "
51+ connect_args = {
52+ ** ua_connect_args ,
53+ "http_path" : HTTP_PATH ,
54+ "server_hostname" : HOST ,
55+ "catalog" : CATALOG ,
56+ "schema" : SCHEMA ,
57+ }
58+
59+ return conn_string , connect_args
60+ else :
61+ return (
62+ f"databricks://token:{ ACCESS_TOKEN } @{ HOST } ?http_path={ HTTP_PATH } &catalog={ CATALOG } &schema={ SCHEMA } " ,
63+ ua_connect_args ,
64+ )
65+
66+
67+ @pytest .fixture
68+ def db_engine () -> Engine :
69+ conn_string , connect_args = version_agnostic_connect_arguments ()
70+ return create_engine (conn_string , connect_args = connect_args )
2071
21- connect_args = {"_user_agent_entry" : USER_AGENT_TOKEN }
2272
23- engine = create_engine (
24- f"databricks://token:{ ACCESS_TOKEN } @{ HOST } ?http_path={ HTTP_PATH } &catalog={ CATALOG } &schema={ SCHEMA } " ,
25- connect_args = connect_args ,
73+ @pytest .fixture
74+ def samples_engine () -> Engine :
75+
76+ conn_string , connect_args = version_agnostic_connect_arguments (
77+ catalog = "samples" , schema = "nyctaxi"
2678 )
27- return engine
79+ return create_engine ( conn_string , connect_args = connect_args )
2880
2981
3082@pytest .fixture ()
@@ -62,6 +114,7 @@ def test_connect_args(db_engine):
62114 assert expected in user_agent
63115
64116
117+ @pytest .mark .skipif (sqlalchemy_1_3 (), reason = "Pandas requires SQLAlchemy >= 1.4" )
65118def test_pandas_upload (db_engine , metadata_obj ):
66119
67120 import pandas as pd
@@ -86,7 +139,7 @@ def test_pandas_upload(db_engine, metadata_obj):
86139 db_engine .execute ("DROP TABLE mock_data" )
87140
88141
89- def test_create_table_not_null (db_engine , metadata_obj ):
142+ def test_create_table_not_null (db_engine , metadata_obj : MetaData ):
90143
91144 table_name = "PySQLTest_{}" .format (datetime .datetime .utcnow ().strftime ("%s" ))
92145
@@ -95,7 +148,7 @@ def test_create_table_not_null(db_engine, metadata_obj):
95148 metadata_obj ,
96149 Column ("name" , String (255 )),
97150 Column ("episodes" , Integer ),
98- Column ("some_bool" , BOOLEAN , nullable = False ),
151+ Column ("some_bool" , BOOLEAN ( create_constraint = False ) , nullable = False ),
99152 )
100153
101154 metadata_obj .create_all ()
@@ -135,7 +188,7 @@ def test_bulk_insert_with_core(db_engine, metadata_obj, session):
135188 metadata_obj .create_all ()
136189 db_engine .execute (insert (SampleTable ).values (rows ))
137190
138- rows = db_engine .execute (select (SampleTable )).fetchall ()
191+ rows = db_engine .execute (version_agnostic_select (SampleTable )).fetchall ()
139192
140193 assert len (rows ) == num_to_insert
141194
@@ -148,7 +201,7 @@ def test_create_insert_drop_table_core(base, db_engine, metadata_obj: MetaData):
148201 metadata_obj ,
149202 Column ("name" , String (255 )),
150203 Column ("episodes" , Integer ),
151- Column ("some_bool" , BOOLEAN ),
204+ Column ("some_bool" , BOOLEAN ( create_constraint = False ) ),
152205 Column ("dollars" , DECIMAL (10 , 2 )),
153206 )
154207
@@ -161,7 +214,7 @@ def test_create_insert_drop_table_core(base, db_engine, metadata_obj: MetaData):
161214 with db_engine .connect () as conn :
162215 conn .execute (insert_stmt )
163216
164- select_stmt = select (SampleTable )
217+ select_stmt = version_agnostic_select (SampleTable )
165218 resp = db_engine .execute (select_stmt )
166219
167220 result = resp .fetchall ()
@@ -187,7 +240,7 @@ class SampleObject(base):
187240
188241 name = Column (String (255 ), primary_key = True )
189242 episodes = Column (Integer )
190- some_bool = Column (BOOLEAN )
243+ some_bool = Column (BOOLEAN ( create_constraint = False ) )
191244
192245 base .metadata .create_all ()
193246
@@ -197,11 +250,15 @@ class SampleObject(base):
197250 session .add (sample_object_2 )
198251 session .commit ()
199252
200- stmt = select (SampleObject ).where (
253+ stmt = version_agnostic_select (SampleObject ).where (
201254 SampleObject .name .in_ (["Bim Adewunmi" , "Miki Meek" ])
202255 )
203256
204- output = [i for i in session .scalars (stmt )]
257+ if sqlalchemy_1_3 ():
258+ output = [i for i in session .execute (stmt )]
259+ else :
260+ output = [i for i in session .scalars (stmt )]
261+
205262 assert len (output ) == 2
206263
207264 base .metadata .drop_all ()
@@ -215,7 +272,7 @@ def test_dialect_type_mappings(base, db_engine, metadata_obj: MetaData):
215272 metadata_obj ,
216273 Column ("string_example" , String (255 )),
217274 Column ("integer_example" , Integer ),
218- Column ("boolean_example" , BOOLEAN ),
275+ Column ("boolean_example" , BOOLEAN ( create_constraint = False ) ),
219276 Column ("decimal_example" , DECIMAL (10 , 2 )),
220277 Column ("date_example" , Date ),
221278 )
@@ -239,7 +296,7 @@ def test_dialect_type_mappings(base, db_engine, metadata_obj: MetaData):
239296 with db_engine .connect () as conn :
240297 conn .execute (insert_stmt )
241298
242- select_stmt = select (SampleTable )
299+ select_stmt = version_agnostic_select (SampleTable )
243300 resp = db_engine .execute (select_stmt )
244301
245302 result = resp .fetchall ()
@@ -252,3 +309,34 @@ def test_dialect_type_mappings(base, db_engine, metadata_obj: MetaData):
252309 assert this_row ["date_example" ] == date_example
253310
254311 metadata_obj .drop_all ()
312+
313+
314+ def test_inspector_smoke_test (samples_engine : Engine ):
315+ """It does not appear that 3L namespace is supported here"""
316+
317+ from sqlalchemy .engine .reflection import Inspector
318+
319+ schema , table = "nyctaxi" , "trips"
320+
321+ try :
322+ inspector = Inspector .from_engine (samples_engine )
323+ except Exception as e :
324+ assert False , f"Could not build inspector: { e } "
325+
326+ # Expect six columns
327+ columns = inspector .get_columns (table , schema = schema )
328+
329+ # Expect zero views, but the method should return
330+ views = inspector .get_view_names (schema = schema )
331+
332+ assert (
333+ len (columns ) == 6
334+ ), "Dialect did not find the expected number of columns in samples.nyctaxi.trips"
335+ assert len (views ) == 0 , "Views could not be fetched"
336+
337+
338+ def test_get_table_names_smoke_test (samples_engine : Engine ):
339+
340+ with samples_engine .connect () as conn :
341+ _names = samples_engine .table_names (schema = "nyctaxi" , connection = conn )
342+ _names is not None , "get_table_names did not succeed"
0 commit comments