|
1 | | -import re |
2 | | -from typing import Any, List, Optional, Dict, Union, Collection, Iterable, Tuple |
| 1 | +from typing import Any, List, Optional, Dict, Union |
3 | 2 |
|
4 | 3 | import databricks.sqlalchemy._ddl as dialect_ddl_impl |
5 | 4 | import databricks.sqlalchemy._types as dialect_type_impl |
|
11 | 10 | build_pk_dict, |
12 | 11 | get_fk_strings_from_dte_output, |
13 | 12 | get_pk_strings_from_dte_output, |
| 13 | + get_comment_from_dte_output, |
14 | 14 | parse_column_info_from_tgetcolumnsresponse, |
15 | 15 | ) |
16 | 16 |
|
17 | 17 | import sqlalchemy |
18 | 18 | from sqlalchemy import DDL, event |
19 | 19 | from sqlalchemy.engine import Connection, Engine, default, reflection |
20 | | -from sqlalchemy.engine.reflection import ObjectKind |
21 | 20 | from sqlalchemy.engine.interfaces import ( |
22 | 21 | ReflectedForeignKeyConstraint, |
23 | 22 | ReflectedPrimaryKeyConstraint, |
24 | 23 | ReflectedColumn, |
25 | | - TableKey, |
| 24 | + ReflectedTableComment, |
26 | 25 | ) |
| 26 | +from sqlalchemy.engine.reflection import ReflectionDefaults |
27 | 27 | from sqlalchemy.exc import DatabaseError, SQLAlchemyError |
28 | 28 |
|
29 | 29 | try: |
@@ -285,7 +285,7 @@ def get_table_names(self, connection: Connection, schema=None, **kwargs): |
285 | 285 | views_result = self.get_view_names(connection=connection, schema=schema) |
286 | 286 |
|
287 | 287 | # In Databricks, SHOW TABLES FROM <schema> returns both tables and views. |
288 | | - # Potential optimisation: rewrite this to instead query informtation_schema |
| 288 | + # Potential optimisation: rewrite this to instead query information_schema |
289 | 289 | tables_minus_views = [ |
290 | 290 | row.tableName for row in tables_result if row.tableName not in views_result |
291 | 291 | ] |
@@ -328,7 +328,7 @@ def get_materialized_view_names( |
328 | 328 | def get_temp_view_names( |
329 | 329 | self, connection: Connection, schema: Optional[str] = None, **kw: Any |
330 | 330 | ) -> List[str]: |
331 | | - """A wrapper around get_view_names taht fetches only the names of temporary views""" |
| 331 | + """A wrapper around get_view_names that fetches only the names of temporary views""" |
332 | 332 | return self.get_view_names(connection, schema, only_temp=True) |
333 | 333 |
|
334 | 334 | def do_rollback(self, dbapi_connection): |
@@ -375,6 +375,30 @@ def get_schema_names(self, connection, **kw): |
375 | 375 | schema_list = [row[0] for row in result] |
376 | 376 | return schema_list |
377 | 377 |
|
| 378 | + @reflection.cache |
| 379 | + def get_table_comment( |
| 380 | + self, |
| 381 | + connection: Connection, |
| 382 | + table_name: str, |
| 383 | + schema: Optional[str] = None, |
| 384 | + **kw: Any, |
| 385 | + ) -> ReflectedTableComment: |
| 386 | + result = self._describe_table_extended( |
| 387 | + connection=connection, |
| 388 | + table_name=table_name, |
| 389 | + schema_name=schema, |
| 390 | + ) |
| 391 | + |
| 392 | + if result is None: |
| 393 | + return ReflectionDefaults.table_comment() |
| 394 | + |
| 395 | + comment = get_comment_from_dte_output(result) |
| 396 | + |
| 397 | + if comment: |
| 398 | + return dict(text=comment) |
| 399 | + else: |
| 400 | + return ReflectionDefaults.table_comment() |
| 401 | + |
378 | 402 |
|
379 | 403 | @event.listens_for(Engine, "do_connect") |
380 | 404 | def receive_do_connect(dialect, conn_rec, cargs, cparams): |
|
0 commit comments