66import time
77from concurrent .futures import ThreadPoolExecutor , as_completed
88from datetime import datetime
9- from typing import Any , Dict , Iterable , List , Optional , Tuple
9+ from functools import wraps
10+ from typing import Any , Callable , Dict , Iterable , List , Optional , Tuple
1011from uuid import UUID
1112
1213import requests
@@ -780,14 +781,19 @@ def introspect_primary_key(
780781 """
781782 primary_index_dict = inspector .get_pk_constraint (relation_name , schema_name )
782783
783- # MySQL at least can have unnamed primary keys. The returned dict will have 'name' -> None.
784- # Sigh.
785- pkey_name = primary_index_dict .get ('name' ) or '(unnamed primary key)'
784+ # Athena dialect returns ... an empty _list_ instead of a dict, contrary to what
785+ # https://docs.sqlalchemy.org/en/14/core/reflection.html#sqlalchemy.engine.reflection.Inspector.get_pk_constraint
786+ # specifies for the return result from inspector.get_pk_constraint().
787+ if isinstance (primary_index_dict , dict ):
788+ # MySQL at least can have unnamed primary keys. The returned dict will have 'name' -> None.
789+ # Sigh.
790+ pkey_name = primary_index_dict .get ('name' ) or '(unnamed primary key)'
786791
787- if primary_index_dict ['constrained_columns' ]:
788- return pkey_name , primary_index_dict ['constrained_columns' ]
789- else :
790- return None , []
792+ if primary_index_dict ['constrained_columns' ]:
793+ return pkey_name , primary_index_dict ['constrained_columns' ]
794+
795+ # No primary key to be returned.
796+ return None , []
791797
792798 def introspect_columns (
793799 self , inspector : SchemaStrippingInspector , schema_name : str , relation_name : str
@@ -1197,6 +1203,30 @@ def run_meta_command(
11971203 instance .do_run (invoker , args )
11981204
11991205
1206+ def handle_not_implemented (default : Any = None , default_factory : Callable [[], Any ] = None ):
1207+ """Decorator to catch NotImplementedError, return either default constant or
1208+ whatever default_factory() returns."""
1209+ assert default or default_factory , 'must provide one of default or default_factory'
1210+ assert not (
1211+ default and default_factory
1212+ ), 'only provide one of either default or default_factory'
1213+
1214+ def wrapper (func ):
1215+ @wraps (func )
1216+ def wrapped (* args , ** kwargs ):
1217+ try :
1218+ return func (* args , ** kwargs )
1219+ except NotImplementedError :
1220+ if default_factory :
1221+ return default_factory ()
1222+ else :
1223+ return default
1224+
1225+ return wrapped
1226+
1227+ return wrapper
1228+
1229+
12001230class SchemaStrippingInspector :
12011231 """Proxy implementation that removes 'schema.' prefixing from results of underlying
12021232 get_table_names() and get_view_names(). BigQuery dialect inspector seems to include
@@ -1218,6 +1248,7 @@ def get_schema_names(self) -> List[str]:
12181248 def get_columns (self , relation_name : str , schema : Optional [str ] = None ) -> List [dict ]:
12191249 return self .underlying_inspector .get_columns (relation_name , schema = schema )
12201250
1251+ @handle_not_implemented ('(unobtainable)' )
12211252 def get_view_definition (self , view_name : str , schema : Optional [str ] = None ) -> str :
12221253 return self .underlying_inspector .get_view_definition (view_name , schema = schema )
12231254
@@ -1227,20 +1258,16 @@ def get_pk_constraint(self, table_name: str, schema: Optional[str] = None) -> di
12271258 def get_foreign_keys (self , table_name : str , schema : Optional [str ] = None ) -> List [dict ]:
12281259 return self .underlying_inspector .get_foreign_keys (table_name , schema = schema )
12291260
1261+ @handle_not_implemented (default_factory = list )
12301262 def get_check_constraints (self , table_name : str , schema : Optional [str ] = None ) -> List [dict ]:
1231- try :
1232- return self .underlying_inspector .get_check_constraints (table_name , schema = schema )
1233- except NotImplementedError :
1234- return []
1263+ return self .underlying_inspector .get_check_constraints (table_name , schema = schema )
12351264
12361265 def get_indexes (self , table_name : str , schema : Optional [str ] = None ) -> List [dict ]:
12371266 return self .underlying_inspector .get_indexes (table_name , schema = schema )
12381267
1268+ @handle_not_implemented (default_factory = list )
12391269 def get_unique_constraints (self , table_name : str , schema : Optional [str ] = None ) -> List [dict ]:
1240- try :
1241- return self .underlying_inspector .get_unique_constraints (table_name , schema = schema )
1242- except NotImplementedError :
1243- return []
1270+ return self .underlying_inspector .get_unique_constraints (table_name , schema = schema )
12441271
12451272 # Now the value-adding filtering methods.
12461273 def get_table_names (self , schema : Optional [str ] = None ) -> List [str ]:
0 commit comments