Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .python-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3.13
13 changes: 13 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
[project]
name = "m-schema"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.13"
dependencies = [
"llama-index>=0.12.42",
"numpy>=2.3.0",
"psycopg2>=2.9.10",
"pymysql>=1.1.1",
"sqlalchemy>=2.0.41",
]
31 changes: 26 additions & 5 deletions schema_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,22 @@ class SchemaEngine(SQLDatabase):
def __init__(self, engine: Engine, schema: Optional[str] = None, metadata: Optional[MetaData] = None,
ignore_tables: Optional[List[str]] = None, include_tables: Optional[List[str]] = None,
sample_rows_in_table_info: int = 3, indexes_in_table_info: bool = False,
custom_table_info: Optional[dict] = None, view_support: bool = False, max_string_length: int = 300,
custom_table_info: Optional[dict] = None,
view_support: bool = False, max_string_length: int = 300,
mschema: Optional[MSchema] = None, db_name: Optional[str] = ''):
super().__init__(engine, schema, metadata, ignore_tables, include_tables, sample_rows_in_table_info,
indexes_in_table_info, custom_table_info, view_support, max_string_length)

self._db_name = db_name
# Dictionary to store table names and their corresponding schema
self._tables_schemas: Dict[str, str] = {}
self._tables_schemas: Dict[str, str] = {}
# When no schema is given but db_name is provided, pick a default schema so that
# tables are not collected from every database/schema visible to the connection:
# MySQL uses db_name as the schema, PostgreSQL falls back to 'public'.
if schema is None and db_name:
if self._engine.dialect.name == 'mysql':
schema = db_name
elif self._engine.dialect.name == 'postgresql':
# For PostgreSQL, use 'public' as default schema
schema = 'public'

# If a schema is specified, filter by that schema and store that value for every table.
if schema:
Expand Down Expand Up @@ -85,10 +93,23 @@ def fectch_distinct_values(self, table_name: str, column_name: str, max_num: int
return values

def init_mschema(self):
print(f"Debug: Database dialect = {self._engine.dialect.name}")
print(f"Debug: DB name = {self._db_name}")
print(f"Debug: Available schemas = {self.get_schema_names()}")
print(f"Debug: Usable tables = {self._usable_tables}")
print(f"Debug: Tables schemas mapping = {self._tables_schemas}")

for table_name in self._usable_tables:
table_comment = self.get_table_comment(table_name)
table_comment = '' if table_comment is None else table_comment.strip()
table_with_schema = self._tables_schemas[table_name] + '.' + table_name
table_comment = '' if table_comment is None else table_comment.strip()
# Build the table identifier: omit the schema prefix when it is redundant
# (MySQL schema equal to db_name, or PostgreSQL's 'public' schema);
# otherwise qualify the table as '<schema>.<table>'.
schema_name = self._tables_schemas[table_name]
if self._engine.dialect.name == 'mysql' and schema_name == self._db_name:
table_with_schema = table_name
elif self._engine.dialect.name == 'postgresql' and schema_name == 'public':
table_with_schema = table_name
else:
table_with_schema = schema_name + '.' + table_name
self._mschema.add_table(table_with_schema, fields={}, comment=table_comment)
pks = self.get_pk_constraint(table_name)

Expand Down
Loading