946 changes: 443 additions & 503 deletions mipdb/commands.py

Large diffs are not rendered by default.

7 changes: 5 additions & 2 deletions mipdb/databases/__init__.py → mipdb/credentials.py
@@ -1,17 +1,20 @@
import os

from typing import Dict, Any
import toml

CONFIG_PATH = "/home/config.toml"


def credentials_from_config():
def credentials_from_config() -> Dict[str, Any]:
"""Return a dict of credentials from *config.toml* or sensible fall‑backs."""

try:
return toml.load(os.getenv("CONFIG_PATH", CONFIG_PATH))
except FileNotFoundError:
return {
"DB_IP": "",
"DB_PORT": "",
"MONETDB_ENABLED": False,
"MONETDB_ADMIN_USERNAME": "",
"MONETDB_LOCAL_USERNAME": "",
"MONETDB_LOCAL_PASSWORD": "",
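As a quick illustration of how the relocated helper can be consumed: a minimal sketch, not part of this PR; the key names come from the fallback dict above, and the CONFIG_PATH override mirrors the os.getenv call.

# Hypothetical call site for mipdb.credentials.credentials_from_config (sketch only).
import os

from mipdb.credentials import credentials_from_config

os.environ["CONFIG_PATH"] = "/etc/mipdb/config.toml"  # optional override; path is an assumption

creds = credentials_from_config()
if creds.get("MONETDB_ENABLED"):
    print(f"MonetDB expected at {creds['DB_IP']}:{creds['DB_PORT']}")
else:
    print("MonetDB disabled or config.toml missing; empty fallback credentials in use")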
2 changes: 1 addition & 1 deletion mipdb/exceptions.py
@@ -3,7 +3,7 @@
from functools import wraps
from contextlib import contextmanager

sys.tracebacklimit = 0
# sys.tracebacklimit = 0


class DataBaseError(Exception):
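For reference, the line being commented out controls how much of the stack Python prints for unhandled exceptions; a minimal sketch of that behavior, using only the standard library.

# Illustration only (not part of this PR): with sys.tracebacklimit = 0 an unhandled
# exception prints just "ValueError: example" with no stack frames; removing the
# assignment, as the diff does, restores the full traceback.
import sys

sys.tracebacklimit = 0

def boom():
    raise ValueError("example")

boom()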
7 changes: 7 additions & 0 deletions mipdb/logger.py
@@ -0,0 +1,7 @@
import logging
import sys

logging.basicConfig(
stream=sys.stderr, level=logging.INFO, format="%(levelname)s: %(message)s"
)
LOGGER = logging.getLogger(__name__)
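A short usage sketch for the new shared logger; the messages are illustrative only, and the output format follows the basicConfig call above.

# Hypothetical call sites for mipdb.logger.LOGGER (sketch only).
from mipdb.logger import LOGGER

LOGGER.info("importing dataset CSV")      # -> "INFO: importing dataset CSV" on stderr
LOGGER.warning("dataset already exists")  # -> "WARNING: dataset already exists" on stderr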
Empty file added mipdb/monetdb/__init__.py
Empty file.
File renamed without changes.
@@ -6,11 +6,11 @@
import sqlalchemy as sql
from sqlalchemy import MetaData

from mipdb.credentials import credentials_from_config
from mipdb.data_frame import DATASET_COLUMN_NAME
from mipdb.databases import credentials_from_config
from mipdb.dataelements import CommonDataElement
from mipdb.exceptions import UserInputError
from mipdb.schema import Schema
from mipdb.monetdb.schema import Schema

RECORDS_PER_COPY = 100000

@@ -158,8 +158,8 @@ def __init__(self, dataframe_sql_type_per_column, db):
self._table.create(bind=db.get_executor())

def create(self, db):
db.execute(f'DROP TABLE IF EXISTS "{self.table.name}"')
self.table.create(bind=db.get_executor())
db.execute(f'DROP TABLE IF EXISTS "{self.table.name}"')
self.table.create(bind=db.get_executor())

def validate_csv(self, csv_path, cdes_with_min_max, cdes_with_enumerations, db):
validated_datasets = []
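The excerpt keeps RECORDS_PER_COPY = 100000, which by its name sizes batched copies of CSV data; a minimal sketch of chunked reading consistent with that constant, assuming pandas and a caller-supplied insert callback (neither the helper name nor the callback appears in this diff).

# Sketch only: chunked CSV loading sized by RECORDS_PER_COPY; copy_csv_in_batches
# and insert_batch are assumptions, not part of the PR.
import pandas as pd

RECORDS_PER_COPY = 100000

def copy_csv_in_batches(csv_path, insert_batch):
    """Read csv_path in RECORDS_PER_COPY-row chunks and hand each chunk to insert_batch."""
    for chunk in pd.read_csv(csv_path, chunksize=RECORDS_PER_COPY, dtype=object):
        insert_batch(chunk.to_dict(orient="records"))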
File renamed without changes.
Empty file added mipdb/sqlite/__init__.py
Empty file.
115 changes: 70 additions & 45 deletions mipdb/databases/sqlite.py → mipdb/sqlite/sqlite.py
@@ -84,7 +84,9 @@ def execute_fetchall(self, query: str, *args, **kwargs) -> List[Any]:
result = conn.execute(text(query), *args, **kwargs)
return result.fetchall()

def insert_values_to_table(self, table: sql.Table, values: List[Dict[str, Any]]) -> None:
def insert_values_to_table(
self, table: sql.Table, values: List[Dict[str, Any]]
) -> None:
session = self.Session()
try:
session.execute(table.insert(), values)
@@ -96,8 +98,9 @@ def get_data_model_status(self, data_model_id: int) -> Any:
session = self.Session()
try:
result = session.execute(
select(Base.metadata.tables["data_models"].c.status)
.where(Base.metadata.tables["data_models"].c.data_model_id == data_model_id)
select(Base.metadata.tables["data_models"].c.status).where(
Base.metadata.tables["data_models"].c.data_model_id == data_model_id
)
).scalar_one_or_none()
finally:
session.close()
@@ -108,7 +111,9 @@ def update_data_model_status(self, status: str, data_model_id: int) -> None:
try:
session.execute(
sql.update(Base.metadata.tables["data_models"])
.where(Base.metadata.tables["data_models"].c.data_model_id == data_model_id)
.where(
Base.metadata.tables["data_models"].c.data_model_id == data_model_id
)
.values(status=status)
)
session.commit()
@@ -119,8 +124,9 @@ def get_dataset_status(self, dataset_id: int) -> Any:
session = self.Session()
try:
result = session.execute(
select(Base.metadata.tables["datasets"].c.status)
.where(Base.metadata.tables["datasets"].c.dataset_id == dataset_id)
select(Base.metadata.tables["datasets"].c.status).where(
Base.metadata.tables["datasets"].c.dataset_id == dataset_id
)
).scalar_one_or_none()
finally:
session.close()
@@ -158,7 +164,9 @@ def get_dataset(self, dataset_id: int, columns: List[str]) -> Any:
try:
cols = [getattr(Base.metadata.tables["datasets"].c, col) for col in columns]
result = session.execute(
select(*cols).where(Base.metadata.tables["datasets"].c.dataset_id == dataset_id)
select(*cols).where(
Base.metadata.tables["datasets"].c.dataset_id == dataset_id
)
).one_or_none()
finally:
session.close()
@@ -167,9 +175,13 @@ def get_dataset(self, dataset_id: int, columns: List[str]) -> Any:
def get_data_model(self, data_model_id: int, columns: List[str]) -> Any:
session = self.Session()
try:
cols = [getattr(Base.metadata.tables["data_models"].c, col) for col in columns]
cols = [
getattr(Base.metadata.tables["data_models"].c, col) for col in columns
]
result = session.execute(
select(*cols).where(Base.metadata.tables["data_models"].c.data_model_id == data_model_id)
select(*cols).where(
Base.metadata.tables["data_models"].c.data_model_id == data_model_id
)
).one_or_none()
finally:
session.close()
@@ -180,10 +192,10 @@ def get_data_model(self, data_model_id: int, columns: List[str]) -> Any:
# ...

def get_values(
self,
table: sql.Table,
columns: List[str] | None = None,
where_conditions: Dict[str, Any] | None = None,
self,
table: sql.Table,
columns: List[str] | None = None,
where_conditions: Dict[str, Any] | None = None,
) -> List[sql.Row]:
"""Return rows (SQLAlchemy Row objects) respecting an optional WHERE."""
stmt = select(
@@ -199,7 +211,9 @@ def get_values(
def get_data_models(self, columns: List[str]) -> List[Dict[str, Any]]:
session = self.Session()
try:
cols = [getattr(Base.metadata.tables["data_models"].c, col) for col in columns]
cols = [
getattr(Base.metadata.tables["data_models"].c, col) for col in columns
]
rows = session.execute(select(*cols)).all()
finally:
session.close()
@@ -208,13 +222,12 @@ def get_data_models(self, columns: List[str]) -> List[Dict[str, Any]]:
def get_dataset_count_by_data_model_id(self) -> List[Dict[str, Any]]:
session = self.Session()
try:
stmt = (
select(
Base.metadata.tables["datasets"].c.data_model_id,
func.count(Base.metadata.tables["datasets"].c.data_model_id).label("count"),
)
.group_by(Base.metadata.tables["datasets"].c.data_model_id)
)
stmt = select(
Base.metadata.tables["datasets"].c.data_model_id,
func.count(Base.metadata.tables["datasets"].c.data_model_id).label(
"count"
),
).group_by(Base.metadata.tables["datasets"].c.data_model_id)
rows = session.execute(stmt).all()
finally:
session.close()
@@ -223,7 +236,9 @@ def get_dataset_count_by_data_model_id(self) -> List[Dict[str, Any]]:
def get_row_count(self, table_name: str) -> int:
session = self.Session()
try:
count = session.execute(select(func.count()).select_from(text(table_name))).scalar_one()
count = session.execute(
select(func.count()).select_from(text(table_name))
).scalar_one()
finally:
session.close()
return count
@@ -253,8 +268,9 @@ def get_dataset_properties(self, dataset_id: int) -> Any:
session = self.Session()
try:
result = session.execute(
select(Base.metadata.tables["datasets"].c.properties)
.where(Base.metadata.tables["datasets"].c.dataset_id == dataset_id)
select(Base.metadata.tables["datasets"].c.properties).where(
Base.metadata.tables["datasets"].c.dataset_id == dataset_id
)
).scalar_one_or_none()
finally:
session.close()
@@ -264,26 +280,33 @@ def get_data_model_properties(self, data_model_id: int) -> Any:
session = self.Session()
try:
result = session.execute(
select(Base.metadata.tables["data_models"].c.properties)
.where(Base.metadata.tables["data_models"].c.data_model_id == data_model_id)
select(Base.metadata.tables["data_models"].c.properties).where(
Base.metadata.tables["data_models"].c.data_model_id == data_model_id
)
).scalar_one_or_none()
finally:
session.close()
return result or {}

def set_data_model_properties(self, properties: Dict[str, Any], data_model_id: int) -> None:
def set_data_model_properties(
self, properties: Dict[str, Any], data_model_id: int
) -> None:
session = self.Session()
try:
session.execute(
sql.update(Base.metadata.tables["data_models"])
.where(Base.metadata.tables["data_models"].c.data_model_id == data_model_id)
.where(
Base.metadata.tables["data_models"].c.data_model_id == data_model_id
)
.values(properties=properties)
)
session.commit()
finally:
session.close()

def set_dataset_properties(self, properties: Dict[str, Any], dataset_id: int) -> None:
def set_dataset_properties(
self, properties: Dict[str, Any], dataset_id: int
) -> None:
session = self.Session()
try:
session.execute(
@@ -298,20 +321,21 @@ def set_dataset_properties(self, properties: Dict[str, Any], dataset_id: int) ->
def get_data_model_id(self, code: str, version: str) -> int:
session = self.Session()
try:
stmt = (
select(Base.metadata.tables["data_models"].c.data_model_id)
.where(
Base.metadata.tables["data_models"].c.code == code,
Base.metadata.tables["data_models"].c.version == version,
)
stmt = select(Base.metadata.tables["data_models"].c.data_model_id).where(
Base.metadata.tables["data_models"].c.code == code,
Base.metadata.tables["data_models"].c.version == version,
)
data_model_id = session.execute(stmt).scalar_one_or_none()
except MultipleResultsFound:
raise DataBaseError(f"Got more than one data_model ids for code={code} and version={version}.")
raise DataBaseError(
f"Got more than one data_model ids for code={code} and version={version}."
)
finally:
session.close()
if not data_model_id:
raise DataBaseError(f"Data_models table doesn't have a record with code={code}, version={version}")
raise DataBaseError(
f"Data_models table doesn't have a record with code={code}, version={version}"
)
return data_model_id

def get_max_data_model_id(self) -> int:
Expand All @@ -337,20 +361,21 @@ def get_max_dataset_id(self) -> int:
def get_dataset_id(self, code: str, data_model_id: int) -> int:
session = self.Session()
try:
stmt = (
select(Base.metadata.tables["datasets"].c.dataset_id)
.where(
Base.metadata.tables["datasets"].c.code == code,
Base.metadata.tables["datasets"].c.data_model_id == data_model_id,
)
stmt = select(Base.metadata.tables["datasets"].c.dataset_id).where(
Base.metadata.tables["datasets"].c.code == code,
Base.metadata.tables["datasets"].c.data_model_id == data_model_id,
)
dataset_id = session.execute(stmt).scalar_one_or_none()
except MultipleResultsFound:
raise DataBaseError(f"Got more than one dataset ids for code={code} and data_model_id={data_model_id}.")
raise DataBaseError(
f"Got more than one dataset ids for code={code} and data_model_id={data_model_id}."
)
finally:
session.close()
if not dataset_id:
raise DataBaseError(f"Datasets table doesn't have a record with code={code}, data_model_id={data_model_id}")
raise DataBaseError(
f"Datasets table doesn't have a record with code={code}, data_model_id={data_model_id}"
)
return dataset_id

def table_exists(self, table: sql.Table) -> bool:
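Several helpers above accept an optional where_conditions mapping and build a SELECT from it; a self-contained sketch of that pattern against an in-memory SQLite engine, with an illustrative table layout rather than the project's real schema, and with the condition-combining logic assumed since the body of get_values is truncated in this excerpt.

# Sketch of the where_conditions pattern used by get_values above (illustrative schema).
import sqlalchemy as sql
from sqlalchemy import select

engine = sql.create_engine("sqlite:///:memory:")
metadata = sql.MetaData()
datasets = sql.Table(
    "datasets",
    metadata,
    sql.Column("dataset_id", sql.Integer, primary_key=True),
    sql.Column("code", sql.String),
    sql.Column("data_model_id", sql.Integer),
)
metadata.create_all(engine)

where_conditions = {"data_model_id": 1}
stmt = select(datasets.c.code)
for col, value in where_conditions.items():
    stmt = stmt.where(datasets.c[col] == value)  # one .where() per condition, AND-combined

with engine.connect() as conn:
    conn.execute(datasets.insert(), [{"code": "demo", "data_model_id": 1}])
    print(conn.execute(stmt).fetchall())  # [('demo',)]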
@@ -3,13 +3,13 @@

import sqlalchemy as sql
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Integer, String, JSON, select
from sqlalchemy import Integer, String, JSON, select

from typing import List

from mipdb.dataelements import CommonDataElement
from mipdb.exceptions import DataBaseError
from mipdb.databases.sqlite import DataModel, Dataset
from mipdb.sqlite.sqlite import DataModel, Dataset

METADATA_TABLE = "variables_metadata"
PRIMARYDATA_TABLE = "primary_data"
@@ -77,6 +77,7 @@ def get_column_distinct(self, column, db):
def drop(self, db):
db.drop_table(self._table)


class DataModelTable(Table):
def __init__(self):
self._table = DataModel.__table__
@@ -111,7 +112,7 @@ def delete_data_model(self, code, version, db):
db.delete_from(self._table, where_conditions={"code": code, "version": version})

def get_next_data_model_id(self, db) -> int:
res = db.execute_fetchall('SELECT MAX(data_model_id) FROM data_models;')
res = db.execute_fetchall("SELECT MAX(data_model_id) FROM data_models;")
max_id = res[0][0] if res and res[0][0] is not None else 0
return max_id + 1

@@ -123,14 +124,14 @@ def __init__(self):
def get_datasets(self, db, columns: list = None):
return db.get_values(table=self._table, columns=columns, where_conditions={})

def get_dataset_codes(self, db, columns: List[str] = None, data_model_id: int = None) -> List[str]:
cols = columns or ['code']
def get_dataset_codes(
self, db, columns: List[str] = None, data_model_id: int = None
) -> List[str]:
cols = columns or ["code"]
stmt = select(*[self.table.c[col] for col in cols])
if data_model_id is not None:
stmt = stmt.where(self.table.c.data_model_id == data_model_id)
compiled_sql = str(
stmt.compile(compile_kwargs={"literal_binds": True})
)
compiled_sql = str(stmt.compile(compile_kwargs={"literal_binds": True}))
rows = db.execute_fetchall(compiled_sql)
codes: List[str] = []
for row in rows:
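get_dataset_codes above compiles its SELECT with literal_binds before handing the string to execute_fetchall; a self-contained sketch of what that compilation step produces, using an illustrative table rather than the project's schema.

# Sketch of the literal_binds compilation used by get_dataset_codes (illustrative schema).
import sqlalchemy as sql
from sqlalchemy import select

metadata = sql.MetaData()
datasets = sql.Table(
    "datasets",
    metadata,
    sql.Column("code", sql.String),
    sql.Column("data_model_id", sql.Integer),
)

stmt = select(datasets.c.code).where(datasets.c.data_model_id == 1)
compiled_sql = str(stmt.compile(compile_kwargs={"literal_binds": True}))
print(compiled_sql)  # SELECT statement with the literal 1 inlined into the WHERE clause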