diff --git a/open_mastr/mastr.py b/open_mastr/mastr.py index be617eb8..5221e75b 100644 --- a/open_mastr/mastr.py +++ b/open_mastr/mastr.py @@ -1,14 +1,20 @@ import os -from sqlalchemy import inspect, create_engine +from pathlib import Path +from sqlalchemy import inspect, create_engine, Engine, Table +from sqlalchemy.orm import DeclarativeBase +from typing import Literal, Optional, Type, TypeVar, Union +from collections.abc import Mapping # import xml dependencies from open_mastr.xml_download.utils_download_bulk import ( + download_documentation, download_xml_Mastr, delete_xml_files_not_from_given_date, ) from open_mastr.xml_download.utils_write_to_database import ( write_mastr_xml_to_database, ) +from open_mastr.utils.xsd_tables import MastrTableDescription, read_mastr_table_descriptions_from_xsd from open_mastr.utils.helpers import ( validate_parameter_format_for_download_method, @@ -34,6 +40,7 @@ setup_logger, ) import open_mastr.utils.orm as orm +from open_mastr.utils.sqlalchemy_tables import make_sqlalchemy_model_from_mastr_table_description # constants from open_mastr.utils.constants import TECHNOLOGIES, ADDITIONAL_TABLES @@ -41,6 +48,10 @@ # setup logger log = setup_logger() +# TODO: Repeating Type[DeclarativeBase_T] in function signatures is strange. There must be a better option. 
+DeclarativeBase_T = TypeVar("DeclarativeBase_T", bound=DeclarativeBase) +FALLBACK_DOCS_PATH = Path(__file__).parent / "resources" / "Dokumentation-MaStR-Gesamtdatenexport-20251227-Fallback.zip" + class Mastr: """ @@ -71,21 +82,23 @@ class Mastr: """ - def __init__(self, engine="sqlite", connect_to_translated_db=False) -> None: + def __init__( + self, + engine: Union[Engine, Literal["sqlite"]] = "sqlite", + mastr_table_to_db_table_name: Optional[dict[str, str]] = None, + output_dir: Optional[Union[str, Path]] = None, + home_dir: Optional[Union[str, Path]] = None, + ) -> None: validate_parameter_format_for_mastr_init(engine) - self.output_dir = get_output_dir() - self.home_directory = get_project_home_dir() + self.output_dir = output_dir or get_output_dir() + self.home_directory = home_dir or get_project_home_dir() + self._sqlite_folder_path = os.path.join(self.output_dir, "data", "sqlite") + os.makedirs(self._sqlite_folder_path, exist_ok=True) - self.is_translated = connect_to_translated_db - if connect_to_translated_db: - self.engine = create_translated_database_engine( - engine, self._sqlite_folder_path - ) - else: - self.engine = create_database_engine(engine, self._sqlite_folder_path) + self.engine = create_database_engine(engine, self._sqlite_folder_path) log.info( "\n==================================================\n" @@ -97,7 +110,39 @@ def __init__(self, engine="sqlite", connect_to_translated_db=False) -> None: "'pip install --upgrade open-mastr'\n" ) - orm.Base.metadata.create_all(self.engine) + def generate_data_model( + self, + data: Optional[list[str]] = None, + catalog_value_as_str: bool = True, + base: Optional[Type[DeclarativeBase_T]] = None, + ) -> dict[str, Type[DeclarativeBase_T]]: + data = transform_data_parameter(data) + + docs_folder_path = os.path.join(self.output_dir, "data", "docs_download") + os.makedirs(docs_folder_path, exist_ok=True) + zipped_docs_file_path = os.path.join( + docs_folder_path, + "Dokumentation MaStR 
Gesamtdatenexport.zip" + ) + try: + download_documentation(zipped_docs_file_path) + return _download_docs_and_generate_data_model( + zipped_docs_file_path=zipped_docs_file_path, + data=data, + catalog_value_as_str=catalog_value_as_str, + base=base, + ) + except Exception as e: + log.exception( + f"Encountered {e} when downloading or processing MaStR documentation." + f" Falling back to stored docs at {FALLBACK_DOCS_PATH}" + ) + return _download_docs_and_generate_data_model( + zipped_docs_file_path=FALLBACK_DOCS_PATH, + data=data, + catalog_value_as_str=catalog_value_as_str, + base=base + ) def download( self, @@ -106,6 +151,8 @@ def download( date=None, bulk_cleansing=True, keep_old_downloads: bool = False, + mastr_table_to_db_table: Optional[Mapping[str, Table]] = None, + alter_database_tables: bool = True, **kwargs, ) -> None: """ @@ -165,13 +212,6 @@ def download( keep_old_downloads: bool If set to True, prior downloaded MaStR zip files will be kept. """ - - if self.is_translated: - raise TypeError( - "You are currently connected to a translated database.\n" - "A translated database cannot be further processed." - ) - if method == "API": log.warning( "Downloading the whole registry via the MaStR SOAP-API is deprecated. 
" @@ -181,6 +221,20 @@ def download( log.warning("Attention: method='API' changed to method='bulk'.") method = "bulk" + if not mastr_table_to_db_table: + mastr_table_to_db_model = self.generate_data_model( + data=data, + catalog_value_as_str=bulk_cleansing, + ) + mastr_table_to_db_table = { + mastr_table: db_model.__table__ + for mastr_table, db_model in mastr_table_to_db_model.items() + } + log.info("Ensuring database tables for MaStR are present: Dropping old tables if existing and creating new ones.") + for db_table in mastr_table_to_db_table.values(): + db_table.drop(self.engine, checkfirst=True) + db_table.create(self.engine) + validate_parameter_format_for_download_method( method=method, data=data, @@ -192,21 +246,20 @@ def download( date = transform_date_parameter(self, date, **kwargs) - # Find the name of the zipped xml folder bulk_download_date = parse_date_string(date) xml_folder_path = os.path.join(self.output_dir, "data", "xml_download") os.makedirs(xml_folder_path, exist_ok=True) zipped_xml_file_path = os.path.join( xml_folder_path, - f"Gesamtdatenexport_{bulk_download_date}.zip", + f"Gesamtdatenexport_{bulk_download_date.strftime('%Y%m%d')}.zip", ) delete_zip_file_if_corrupted(zipped_xml_file_path) if not keep_old_downloads: - delete_xml_files_not_from_given_date(zipped_xml_file_path, xml_folder_path) + delete_xml_files_not_from_given_date(zipped_xml_file_path, xml_folder_path) - download_xml_Mastr(zipped_xml_file_path, date, data, xml_folder_path) + download_xml_Mastr(zipped_xml_file_path, bulk_download_date, data, xml_folder_path) log.info( "\nWould you like to speed up the creation of your MaStR database?\n" @@ -217,7 +270,6 @@ def download( delete_zip_file_if_corrupted(zipped_xml_file_path) delete_xml_files_not_from_given_date(zipped_xml_file_path, xml_folder_path) - print( "\nWould you like to speed up the creation of your MaStR database?\n" "Try our new parallelized processing by setting os.environ['USE_RECOMMENDED_NUMBER_OF_PROCESSES'] = True " 
@@ -230,146 +282,40 @@ def download( data=data, bulk_cleansing=bulk_cleansing, bulk_download_date=bulk_download_date, + mastr_table_to_db_table=mastr_table_to_db_table, + alter_database_tables=alter_database_tables, ) def to_csv( self, tables: list = None, chunksize: int = 500000, limit: int = None ) -> None: - """ - Save the database as csv files along with the metadata file. - If 'tables=None' all possible tables will be exported. - - Parameters - ------------ - tables: None or list - For exporting selected tables choose from: - ["wind", "solar", "biomass", "hydro", "gsgk", "combustion", "nuclear", "storage", - "balancing_area", "electricity_consumer", "gas_consumer", "gas_producer", - "gas_storage", "gas_storage_extended", - "grid_connections", "grids", "market_actors", "market_roles", - "locations_extended", "permit", "deleted_units", "storage_units"] - chunksize: int - Defines the chunksize of the tables export. - Default value is 500.000 rows to include in each chunk. - limit: None or int - Limits the number of exported data rows. - """ - - if self.is_translated: - raise TypeError( - "You are currently connected to a translated database.\n" - "A translated database cannot be used for the csv export." - ) - - log.info("Starting csv-export") - - data_path = get_data_version_dir() - - create_data_dir() - - # Validate and parse tables parameter - validate_parameter_data(method="csv_export", data=tables) - data = transform_data_parameter( - method="bulk", data=tables, api_data_types=None, api_location_types=None + pass + # TODO: Think about this. 
+ + +def _download_docs_and_generate_data_model( + zipped_docs_file_path: Path, + data: list[str], + catalog_value_as_str: bool = True, + base: Optional[Type[DeclarativeBase_T]] = None, +): + if base is None: + + class MastrBase(DeclarativeBase): + pass + + base = MastrBase + + mastr_table_descriptions = read_mastr_table_descriptions_from_xsd( + zipped_docs_file_path=zipped_docs_file_path, data=data + ) + mastr_table_to_db_model: dict[str, DeclarativeBase_T] = {} + for mastr_table_description in mastr_table_descriptions: + sqlalchemy_model = make_sqlalchemy_model_from_mastr_table_description( + table_description=mastr_table_description, + catalog_value_as_str=catalog_value_as_str, + base=base ) + mastr_table_to_db_model[mastr_table_description.table_name] = sqlalchemy_model - # Determine tables to export - technologies_to_export = [] - additional_tables_to_export = [] - for table in data: - if table in TECHNOLOGIES: - technologies_to_export.append(table) - elif table in ADDITIONAL_TABLES: - additional_tables_to_export.append(table) - else: - additional_tables_to_export.extend( - data_to_include_tables([table], mapping="export_db_tables") - ) - - if technologies_to_export: - log.info(f"Technology tables: {technologies_to_export}") - if additional_tables_to_export: - log.info(f"Additional tables: {additional_tables_to_export}") - - log.info(f"Tables are saved to: {data_path}") - - reverse_fill_basic_units(technology=technologies_to_export, engine=self.engine) - - # Export technologies to csv - for tech in technologies_to_export: - db_query_to_csv( - db_query=create_db_query(tech=tech, limit=limit, engine=self.engine), - data_table=tech, - chunksize=chunksize, - ) - # Export additional tables to csv - for addit_table in additional_tables_to_export: - db_query_to_csv( - db_query=create_db_query( - additional_table=addit_table, limit=limit, engine=self.engine - ), - data_table=addit_table, - chunksize=chunksize, - ) - - # FIXME: Currently metadata is only created for 
technology data, Fix in #386 - # Configure and save data package metadata file along with data - # save_metadata(data=technologies_to_export, engine=self.engine) - - def translate(self) -> None: - """ - A database can be translated only once. - - Deletes translated versions of the currently connected database. - - Translates currently connected database,renames it with '-translated' - suffix and updates self.engine's path accordingly. - - !!! example - ```python - - from open_mastr import Mastr - import pandas as pd - - db = Mastr() - db.download(data='biomass') - db.translate() - - df = pd.read_sql(sql='biomass_extended', con=db.engine) - print(df.head(10)) - ``` - - """ - - if "sqlite" not in self.engine.dialect.name: - raise ValueError("engine has to be of type 'sqlite'") - if self.is_translated: - raise TypeError("The currently connected database is already translated.") - - inspector = inspect(self.engine) - old_path = r"{}".format(self.engine.url.database) - new_path = old_path[:-3] + "-translated.db" - - if os.path.exists(new_path): - try: - os.remove(new_path) - except Exception as e: - log.error( - f"An error occurred while removing old translated database: {e}" - ) - - log.info("Replacing previous version of the translated database...") - - for table in inspector.get_table_names(): - rename_table(table, inspector.get_columns(table), self.engine) - - self.engine.dispose() - - try: - os.rename(old_path, new_path) - log.info(f"Database '{old_path}' changed to '{new_path}'") - except Exception as e: - log.error(f"An error occurred while renaming database: {e}") - - self.engine = create_engine(f"sqlite:///{new_path}") - self.is_translated = True + return mastr_table_to_db_model diff --git a/open_mastr/resources/Dokumentation-MaStR-Gesamtdatenexport-20251227-Fallback.zip b/open_mastr/resources/Dokumentation-MaStR-Gesamtdatenexport-20251227-Fallback.zip new file mode 100644 index 00000000..242b39ca Binary files /dev/null and 
b/open_mastr/resources/Dokumentation-MaStR-Gesamtdatenexport-20251227-Fallback.zip differ diff --git a/open_mastr/utils/helpers.py b/open_mastr/utils/helpers.py index 1e8b1365..544dc879 100644 --- a/open_mastr/utils/helpers.py +++ b/open_mastr/utils/helpers.py @@ -1,9 +1,11 @@ import os import json from contextlib import contextmanager -from datetime import date +import datetime from warnings import warn +from typing import Literal, Union from zipfile import BadZipfile, ZipFile +from zoneinfo import ZoneInfo import dateutil import sqlalchemy @@ -33,6 +35,8 @@ TRANSLATIONS, ) +MASTR_TIMEZONE = ZoneInfo("Europe/Berlin") + def chunks(lst, n): """Yield successive n-sized chunks from lst. @@ -58,11 +62,14 @@ def create_database_engine(engine, sqlite_db_path) -> sqlalchemy.engine.Engine: return engine -def parse_date_string(bulk_date_string: str) -> str: +def parse_date_string(bulk_date_string: str) -> datetime.date: if bulk_date_string == "today": - return date.today().strftime("%Y%m%d") + dt = datetime.datetime.now(tz=MASTR_TIMEZONE) else: - return parse(bulk_date_string).strftime("%Y%m%d") + dt = parse(bulk_date_string) + if dt.tzinfo: + dt = dt.astimezone(MASTR_TIMEZONE) + return dt.date() def validate_parameter_format_for_mastr_init(engine) -> None: @@ -158,7 +165,7 @@ def transform_data_parameter(data, **kwargs): return data -def transform_date_parameter(self, date, **kwargs): +def transform_date_parameter(self, date: Union[datetime.date, Literal["today"]], **kwargs) -> Union[datetime.date, Literal["today"]]: date = kwargs.get("bulk_date", date) date = "today" if date is None else date if date == "existing": diff --git a/open_mastr/utils/sqlalchemy_tables.py b/open_mastr/utils/sqlalchemy_tables.py new file mode 100644 index 00000000..4f671deb --- /dev/null +++ b/open_mastr/utils/sqlalchemy_tables.py @@ -0,0 +1,164 @@ +import datetime +from dataclasses import dataclass +from typing import Any, Union, Type, TypeVar +from sqlalchemy import Column, Integer, String, 
Float, Boolean, Date, DateTime +from sqlalchemy.orm import DeclarativeBase, mapped_column, Mapped + +import xmlschema +from xmlschema.validators.simple_types import XsdAtomicBuiltin, XsdAtomicRestriction +from open_mastr.utils.xsd_tables import MastrColumnType, MastrTableDescription + + +# Potential hierarchy +# Id -> MastrNummer -> EinheitMastrNummer +# -> EegMastrNummer -> KwkMastrNummer -> GenMastrNummer +# -> MarktakteurMastrNummer -> NetzanschlusspunktMastrNummer +MASTR_TABLE_NAME_TO_PRIMARY_KEY_COLUMNS = { + "AnlagenEegBiomasse": {"EegMastrNummer"}, + "AnlagenEegGeothermieGrubengasDruckentspannung": {"EegMastrNummer"}, + "AnlagenEegSolar": {"EegMastrNummer"}, + "AnlagenEegSpeicher": {"EegMastrNummer"}, + "AnlagenEegWasser": {"EegMastrNummer"}, + "AnlagenEegWind": {"EegMastrNummer"}, + "AnlagenGasSpeicher": {"MastrNummer"}, + "AnlagenKwk": {"KwkMastrNummer"}, + "AnlagenStromSpeicher": {"MastrNummer"}, + "Bilanzierungsgebiete": {"Id"}, + "EinheitenAenderungNetzbetreiberzuordnungen": {"EinheitMastrNummer"}, # TODO: May not be a primary key on its own. Check this. 
+ "EinheitenBiomasse": {"EinheitMastrNummer"}, + "EinheitenGasErzeuger": {"EinheitMastrNummer"}, + "EinheitenGasSpeicher": {"EinheitMastrNummer"}, + "EinheitenGasverbraucher": {"EinheitMastrNummer"}, + "EinheitenGenehmigung": {"GenMastrNummer"}, + "EinheitenGeothermieGrubengasDruckentspannung": {"EinheitMastrNummer"}, + "EinheitenKernkraft": {"EinheitMastrNummer"}, + "EinheitenSolar": {"EinheitMastrNummer"}, + "EinheitenStromSpeicher": {"EinheitMastrNummer"}, + "EinheitenStromVerbraucher": {"EinheitMastrNummer"}, + "Einheitentypen": {"Id"}, + "EinheitenVerbrennung": {"EinheitMastrNummer"}, + "EinheitenWasser": {"EinheitMastrNummer"}, + "EinheitenWind": {"EinheitMastrNummer"}, + "Ertuechtigungen": {"Id"}, + "GeloeschteUndDeaktivierteEinheiten": {"EinheitMastrNummer"}, + "GeloeschteUndDeaktivierteMarktakteure": {"MarktakteurMastrNummer"}, + "Katalogkategorien": {"Id"}, + "Katalogwerte": {"Id"}, + "Lokationen": {"MastrNummer"}, + "Lokationstypen": {"Id"}, + "MarktakteureUndRollen": {"MarktakteurMastrNummer"}, + "Marktakteure": {"MastrNummer"}, + "Marktfunktionen": {"Id"}, + "Marktrollen": {"Id"}, + "Netzanschlusspunkte": {"NetzanschlusspunktMastrNummer"}, + "Netze": {"MastrNummer"}, +} + + +class ParentAllTables(object): + DatenQuelle: Mapped[str] = mapped_column(String) + DatumDownload: Mapped[datetime.date] = mapped_column(Date) + + +DeclarativeBase_T = TypeVar("DeclarativeBase_T", bound=DeclarativeBase) + + +def make_sqlalchemy_model_from_mastr_table_description( + table_description: MastrTableDescription, + catalog_value_as_str: bool, + base: Type[DeclarativeBase_T], + mixins: tuple[type, ...] 
= (ParentAllTables,), +) -> Type[DeclarativeBase_T]: + return _make_sqlalchemy_model( + class_name=table_description.instance_name, + table_name=table_description.table_name, + column_name_to_column_type={ + column.name: _get_sqlalchemy_type_for_mastr_column_type( + mastr_column_type=column.type, + catalog_value_as_str=catalog_value_as_str, + ) + for column in table_description.columns + }, + primary_key_columns=MASTR_TABLE_NAME_TO_PRIMARY_KEY_COLUMNS[table_description.table_name], + base=base, + mixins=(ParentAllTables,) + ) + + +def _make_sqlalchemy_model( + class_name: str, + table_name: str, + column_name_to_column_type: dict[str, Any], + primary_key_columns: set[str], + base: Type[DeclarativeBase_T], + mixins: tuple[type, ...] = tuple(), +) -> Type[DeclarativeBase_T]: # TODO: Is there a way to say that the returned model is a sub-type of DeclarativeBase_T? + namespace = { + "__tablename__": table_name, + "__annotations__": {}, + } + + for column_name, column_type in column_name_to_column_type.items(): + kwargs = {"primary_key": True} if column_name in primary_key_columns else {"nullable": True} + namespace[column_name] = mapped_column(column_type, **kwargs) + + bases = (base,) + mixins + return type(class_name, bases, namespace) + + +_MASTR_COLUMN_TYPE_TO_SQLALCHEMY_TYPE = { + MastrColumnType.STRING: String, + MastrColumnType.INTEGER: Integer, + MastrColumnType.FLOAT: Float, + MastrColumnType.DATE: Date, + MastrColumnType.DATETIME: DateTime(timezone=True), + MastrColumnType.BOOLEAN: Boolean, +} + + +# We're creating special column types for the catalog columns here so that +# we can identify the catalog columns later when processing the XML files. 
+class CatalogInteger(Integer): + pass + + +class CatalogString(String): + pass + + +def _get_sqlalchemy_type_for_mastr_column_type( + mastr_column_type: MastrColumnType, catalog_value_as_str: bool, +) -> Union[Type[String], Type[Integer], Type[Float], Type[Date], Type[DateTime], Type[Boolean]]: + if mastr_column_type is MastrColumnType.CATALOG_VALUE: + return CatalogString if catalog_value_as_str else CatalogInteger + return _MASTR_COLUMN_TYPE_TO_SQLALCHEMY_TYPE[mastr_column_type] + + + +# TODO: Remove this or make it useful for outsiders. +if __name__ == "__main__": + import os + import sys + from sqlalchemy import create_engine + import traceback + import xmlschema + + print("Parsing XSD files") + xsd_path = sys.argv[1] + for xsd_path in sys.argv[1:]: + schema = xmlschema.XMLSchema(xsd_path) + try: + table_description = MastrTableDescription.from_xml_schema(schema) + except ValueError: + traceback.print_exc() + print("Failed for ", xsd_path) + sys.exit(1) + + model = make_sqlalchemy_model_from_mastr_table_description( + table_description=table_description, + ) + + db_path = os.path.join(os.getcwd(), "test.db") + print(f"Creating SQLite database at {db_path}") + engine = create_engine(f"sqlite:///{db_path}") + Base.metadata.create_all(engine) diff --git a/open_mastr/utils/xsd_tables.py b/open_mastr/utils/xsd_tables.py new file mode 100644 index 00000000..57bdf0b0 --- /dev/null +++ b/open_mastr/utils/xsd_tables.py @@ -0,0 +1,144 @@ +import os +import re +from enum import auto, Enum +from dataclasses import dataclass +from pathlib import Path +from typing import Optional, Union +from zipfile import ZipFile, ZipInfo +import xmlschema +from xmlschema.validators.simple_types import XsdAtomicBuiltin, XsdAtomicRestriction + +from open_mastr.utils.helpers import data_to_include_tables + +_XML_SCHEMA_PREFIX = "{http://www.w3.org/2001/XMLSchema}" + + +# TODO: Should we really mess with the original column names? 
+# The BNetzA "choice" to sometimes write MaStR and sometimes Mastr is certainly confusing,
+# but are we the ones who should change that?
+# Also TODO: Should we also apply the more opinionated normalization/renaming that is currently stored in orm.py?
+# E.g. "VerknuepfteEinheitenMaStRNummern" -> "VerknuepfteEinheiten", "NetzanschlusspunkteMaStRNummern" -> "Netzanschlusspunkte", etc.
+def normalize_column_name(original_mastr_column_name: str) -> str:
+    # BNetzA sometimes has MaStR, other times Mastr. We normalize that.
+    # Also, in case the column names in the XSD contain äöüß, we replace them. This is probably a BNetzA oversight, but has happened at least once.
+    return original_mastr_column_name.replace("MaStR", "Mastr").replace("ä", "ae").replace("ö", "oe").replace("ü", "ue").replace("ß", "ss").strip()
+
+
+class MastrColumnType(Enum):
+    STRING = auto()
+    INTEGER = auto()
+    FLOAT = auto()
+    DATE = auto()
+    DATETIME = auto()
+    BOOLEAN = auto()
+    CATALOG_VALUE = auto()
+
+    @classmethod
+    def from_xsd_type(cls, xsd_type: Union[XsdAtomicBuiltin, XsdAtomicRestriction]) -> "MastrColumnType":
+        xsd_type_to_mastr_column_type = {
+            f"{_XML_SCHEMA_PREFIX}string": cls.STRING,
+            f"{_XML_SCHEMA_PREFIX}decimal": cls.INTEGER,
+            f"{_XML_SCHEMA_PREFIX}int": cls.INTEGER,
+            f"{_XML_SCHEMA_PREFIX}short": cls.INTEGER,
+            f"{_XML_SCHEMA_PREFIX}byte": cls.INTEGER,
+            f"{_XML_SCHEMA_PREFIX}float": cls.FLOAT,
+            f"{_XML_SCHEMA_PREFIX}date": cls.DATE,
+            f"{_XML_SCHEMA_PREFIX}dateTime": cls.DATETIME,
+        }
+        if xsd_type.is_restriction():
+            if enumeration := xsd_type.enumeration:
+                if set(xsd_type.enumeration) == {0, 1}:
+                    return cls.BOOLEAN
+                else:
+                    return cls.CATALOG_VALUE
+            # Ertuechtigungen.xsd has some normal types defined as restrictions for some reason.
+            # We cope with that by extracting the primitive type it's restricted to. 
+ inner_xsd_type = xsd_type.primitive_type + if mastr_column_type := xsd_type_to_mastr_column_type.get(inner_xsd_type.name): + return mastr_column_type + + if mastr_column_type := xsd_type_to_mastr_column_type.get(xsd_type.name): + return mastr_column_type + + raise ValueError(f"Could not determine MastrColumnType from XSD type {xsd_type!r}") + + +@dataclass(frozen=True) +class MastrColumnDescription: + name: str + type: MastrColumnType + + @classmethod + def from_xsd_element(cls, xsd_element: xmlschema.XsdElement) -> "MastrColumnDescription": + name = normalize_column_name(xsd_element.name) + return cls( + name=name, + type=MastrColumnType.from_xsd_type(xsd_element.type) + ) + + +@dataclass(frozen=True) +class MastrTableDescription: + table_name: str + instance_name: str + columns: tuple[MastrColumnDescription] + + @classmethod + def from_xml_schema(cls, schema: xmlschema.XMLSchema) -> "MastrTableDescription": + if len(schema.root_elements) != 1: + raise ValueError( + "XML schema must have exactly one root element," + f" but has {len(schema.root_elements)} ({schema.root_elements!r})" + ) + root = schema.root_elements[0] + + try: + main_element = root.content.content[0] + column_elements = main_element.content.content + except (AttributeError, IndexError, TypeError) as e: + raise ValueError(f"Could not find columns in XML schema {schema!r}") from e + + columns = tuple( + MastrColumnDescription.from_xsd_element(element) + for element in column_elements + ) + + return cls( + table_name=root.name, + instance_name=main_element.name, + columns=columns, + ) + + +def read_mastr_table_descriptions_from_xsd( + zipped_docs_file_path: Union[Path, str], data: list[str] +) -> set[MastrTableDescription]: + print(data) + include_tables = set(data_to_include_tables(data, mapping="write_xml")) + + mastr_table_descriptions = set() + with ZipFile(zipped_docs_file_path, "r") as docs_z: + xsd_zip_entry = _find_xsd_zip_entry(docs_z) + with ZipFile(docs_z.open(xsd_zip_entry)) as xsd_z: + 
for entry in xsd_z.filelist: + if entry.is_dir() or not entry.filename.endswith(".xsd"): + continue + + normalized_name = os.path.basename(entry.filename).removesuffix(".xsd").lower() + if normalized_name in include_tables: + with xsd_z.open(entry) as xsd_file: + mastr_table_description = MastrTableDescription.from_xml_schema(xmlschema.XMLSchema(xsd_file)) + mastr_table_descriptions.add(mastr_table_description) + + return mastr_table_descriptions + + +def _find_xsd_zip_entry(docs_zip_file: ZipFile) -> ZipInfo: + desired_filename = "xsd.zip" + for entry in docs_zip_file.filelist: + if os.path.basename(entry.filename) == desired_filename: + return entry + raise RuntimeError( + f"Did not find XSD files in the form of {desired_filename!r} in the documentation" + f" ZIP file {docs_zip_file.filename!r}" + ) diff --git a/open_mastr/xml_download/colums_to_replace.py b/open_mastr/xml_download/colums_to_replace.py index 421ac44c..fde8d5e7 100644 --- a/open_mastr/xml_download/colums_to_replace.py +++ b/open_mastr/xml_download/colums_to_replace.py @@ -1,6 +1,5 @@ -# system catalog is the mapping for the entries within the two columns -# Marktfunktionen und Lokationstyp (entry 1 is mapped to Stromnetzbetreiber -# in the column Marktfunktionen) +# system catalog is the mapping for the entries within the columns +# Marktfunktion, Lokationtyp and Einheittyp # The values for the system catalog can be found in the pdf of the bulk download # documentation: https://www.marktstammdatenregister.de/MaStR/Datendownload @@ -38,98 +37,3 @@ 12: "Gasspeichereinheit", }, } - -# columns to replace lists all columns where the entries have -# to be replaced according to the tables katalogwerte and katalogeinträge -# from the bulk download of the MaStR - -columns_replace_list = [ - # anlageneegsolar - "AnlageBetriebsstatus", - # anlageneegspeicher - # anlagenstromspeicher - # einheitensolar - "Land", - "Bundesland", - "EinheitSystemstatus", - "EinheitBetriebsstatus", - "Energietraeger", - 
"Einspeisungsart", - "GemeinsamerWechselrichterMitSpeicher", - "Lage", - "Leistungsbegrenzung", - "Hauptausrichtung", - "HauptausrichtungNeigungswinkel", - "Nutzungsbereich", - "Nebenausrichtung", - "NebenausrichtungNeigungswinkel", - "ArtDerFlaecheIds", - # einheitenstromspeicher - "AcDcKoppelung", - "Batterietechnologie", - "Technologie", - "Pumpspeichertechnologie", - "Einsatzort", - # geloeschteunddeaktivierteEinheiten - # geloeschteunddeaktivierteMarktAkteure - "MarktakteurStatus", - # lokationen - # marktakteure - "Personenart", - "Rechtsform", - "HauptwirtdschaftszweigAbteilung", - "HauptwirtdschaftszweigGruppe", - "HauptwirtdschaftszweigAbschnitt", - "Registergericht", - "LandAnZustelladresse", - # netzanschlusspunkte - "Gasqualitaet", - "Spannungsebene", - # anlageneegbiomasse - # anlageneeggeosolarthermiegrubenklaerschlammdruckentspannung - # anlageneegwasser - # anlageneegwind - # anlagengasspeicher - # anlagenkwk - # bilanzierungsgebiete - # einheitenaenderungnetzbetreiberzuordnungen - "ArtDerAenderung", - # einheitenbiomasse - "Hauptbrennstoff", - "Biomasseart", - # einheitengaserzeuger - # einheitengasspeicher - "Speicherart", - # einheitengasverbraucher - # einheitengenehmigung - "Art", - # einheitengeosolarthermiegrubenklaerschlammdruckentspannung - # einheitenkernkraft - # einheitenstromverbraucher - "ArtAbschaltbareLast", - # einheitentypen - # einheitenverbrennung - "WeitererHauptbrennstoff", - "WeitereBrennstoffe", - "ArtDerStilllegung", - # einheitenwasser - "ArtDesZuflusses", - "ArtDerWasserkraftanlage", - # marktrollen - # netze - "Sparte", - # einheitenwind - "Lage", - "Hersteller", - "Seelage", - "ClusterNordsee", - "ClusterOstsee", - # various tables - "NetzbetreiberpruefungStatus", - "WindAnLandOderAufSee", - "TechnologieFlugwindenergieanlage", - "Flughoehe", - "Flugradius", - "ArtDerSolaranlage", - "SpeicherAmGleichenOrt", -] diff --git a/open_mastr/xml_download/parse.py b/open_mastr/xml_download/parse.py new file mode 100644 index 
00000000..6573d1ee --- /dev/null +++ b/open_mastr/xml_download/parse.py @@ -0,0 +1,144 @@ +import xmlschema +from xmlschema.validators import XsdComplexType, XsdSimpleType, XsdElement +from typing import Dict, List, Optional + +# ---------------------------------------------- +# 1. Mapping XSD builtin types → SQLAlchemy types +# ---------------------------------------------- +XSD_TO_SQLA = { + "string": "String", + "integer": "Integer", + "int": "Integer", + "short": "Integer", + "long": "BigInteger", + "decimal": "Float", + "float": "Float", + "double": "Float", + "boolean": "Boolean", + "date": "Date", + "dateTime": "DateTime", + "time": "Time", +} + + +def map_xsd_type(xsd_type: XsdSimpleType) -> str: + """Map XSD builtin type to SQLAlchemy column type.""" + if xsd_type.is_simple() and xsd_type.primitive_type: + name = xsd_type.primitive_type.local_name + return XSD_TO_SQLA.get(name, "String") # default fallback + return "String" + + +# ---------------------------------------------- +# 2. Main model generation +# ---------------------------------------------- +def generate_sqlalchemy_models(xsd_file: str) -> str: + schema = xmlschema.XMLSchema(xsd_file) + output = [] + + output.append("from sqlalchemy import Column, Integer, String, Float, Boolean, Date, DateTime, BigInteger, ForeignKey") + output.append("from sqlalchemy.orm import declarative_base, relationship") + output.append("\nBase = declarative_base()\n") + + processed_types = {} + + # Iterate over all global elements (entry points) + for element_name, element in schema.elements.items(): + output.append(generate_class_from_element(element, processed_types)) + + return "\n".join(output) + + +# ---------------------------------------------- +# 3. 
Generate a class for an element +# ---------------------------------------------- +def generate_class_from_element( + element: XsdElement, + processed_types: Dict[str, str] +) -> str: + """Generate a SQLAlchemy class for the top-level element.""" + cls_name = to_class_name(element.name) + + # If it is a complexType element → + if isinstance(element.type, XsdComplexType): + return generate_class_from_complex_type(cls_name, element.type, processed_types) + + return f"# Skipped simple element {element.name}\n" + + +# ---------------------------------------------- +# 4. Generate class for a complex type +# ---------------------------------------------- +def generate_class_from_complex_type( + cls_name: str, + complex_type: XsdComplexType, + processed_types: Dict[str, str] +) -> str: + + if cls_name in processed_types: + return "" # already generated + + processed_types[cls_name] = cls_name + + lines = [] + lines.append(f"class {cls_name}(Base):") + lines.append(f" __tablename__ = '{camel_to_snake(cls_name)}'") + lines.append(" id = Column(Integer, primary_key=True)\n") + + # Iterate through child elements (sequence, choice, etc.) 
+ for child in complex_type.content.iter_elements(): + + child_name = child.name + col_name = camel_to_snake(child_name) + + if isinstance(child.type, XsdComplexType): + # Nested complex type → child table with relationship + child_class_name = to_class_name(child_name) + lines.append( + f" {col_name}_id = Column(Integer, ForeignKey('{camel_to_snake(child_class_name)}.id'))" + ) + lines.append( + f" {col_name} = relationship('{child_class_name}')" + ) + # Generate nested class too + nested = generate_class_from_complex_type(child_class_name, child.type, processed_types) + lines.append("\n" + nested) + + else: + # Simple child element + sqlalchemy_type = map_xsd_type(child.type) + + nullable = "True" if child.min_occurs == 0 else "False" + lines.append( + f" {col_name} = Column({sqlalchemy_type}, nullable={nullable})" + ) + + lines.append("") + return "\n".join(lines) + + +# ---------------------------------------------- +# 5. Helpers +# ---------------------------------------------- +def to_class_name(name: str) -> str: + return "".join(part.capitalize() for part in name.split("_")) + + +def camel_to_snake(name: str) -> str: + out = "" + for i, ch in enumerate(name): + if ch.isupper() and i > 0: + out += "_" + out += ch.lower() + return out + + +# ---------------------------------------------- +# 6. 
Run example +# ---------------------------------------------- +if __name__ == "__main__": + import sys + xsd_path = sys.argv[1] + models = generate_sqlalchemy_models(xsd_path) + print(models) + diff --git a/open_mastr/xml_download/schema.py b/open_mastr/xml_download/schema.py new file mode 100644 index 00000000..2b223d7d --- /dev/null +++ b/open_mastr/xml_download/schema.py @@ -0,0 +1,49 @@ +from pathlib import Path +import glob +from xml.etree import ElementTree + +import xmlschema + + +def check_if_files_valid_under_schema(xsd_file, xml_files): + schema = xmlschema.XMLSchema(xsd_file) + for xml_file in xml_files: + xml_resource = xmlschema.XMLResource(xml_file, lazy=True) + errors = schema.iter_errors(xml_resource) + error_count = 0 + for error in errors: + error_count += 1 + breakpoint() + print(" -", error) + if error_count == 0: + print(f"{xml_file}\tValid.") + + +def check_if_files_valid_under_schema_et(xsd_file, xml_files): + schema = xmlschema.XMLSchema(xsd_file) + for xml_file in xml_files: + xt = ElementTree.parse(xml_file) + errors = schema.iter_errors(xt) + error_count = 0 + for error in errors: + error_count += 1 + breakpoint() + print(" -", error) + if error_count == 0: + print(f"{xml_file}\tValid.") + + + +def main(): + xsd_root = Path("/home/gorgor/.open-MaStR/data/xml_download/Dokumentation MaStR Gesamtdatenexport/xsd") + xml_root = Path("/home/gorgor/.open-MaStR/data/xml_download/Gesamtdatenexport_20251129") + xsd_file = xsd_root / "EinheitenWind.xsd" + xml_files = [xml_root / basename for basename in glob.glob("EinheitenWind*.xml", root_dir=xml_root)] + xml_files =["/home/gorgor/.open-MaStR/data/xml_download/EinheitenWind_formatted.xml"] + print(xsd_file) + print(xml_files) + check_if_files_valid_under_schema_et(xsd_file=xsd_file, xml_files=xml_files) + + +if __name__ == "__main__": + main() diff --git a/open_mastr/xml_download/utils_cleansing_bulk.py b/open_mastr/xml_download/utils_cleansing_bulk.py index b48a50f1..b38da277 100644 --- 
a/open_mastr/xml_download/utils_cleansing_bulk.py +++ b/open_mastr/xml_download/utils_cleansing_bulk.py @@ -1,23 +1,28 @@ import pandas as pd import numpy as np +from collections.abc import Collection +from zipfile import ZipFile + from open_mastr.xml_download.colums_to_replace import ( system_catalog, - columns_replace_list, ) -from zipfile import ZipFile -def cleanse_bulk_data(df: pd.DataFrame, zipped_xml_file_path: str) -> pd.DataFrame: - df = replace_ids_with_names(df, system_catalog) - # Katalogeintraege: int -> string value +def cleanse_bulk_data( + df: pd.DataFrame, + catalog_columns: Collection[str], + zipped_xml_file_path: str, +) -> pd.DataFrame: + df = replace_system_catalog_ids(df, system_catalog) + catalog_columns = set(catalog_columns) - system_catalog.keys() df = replace_mastr_katalogeintraege( - zipped_xml_file_path=zipped_xml_file_path, df=df + zipped_xml_file_path=zipped_xml_file_path, df=df, catalog_columns=catalog_columns, ) return df -def replace_ids_with_names(df: pd.DataFrame, system_catalog: dict) -> pd.DataFrame: - """Replaces ids with names according to the system catalog. This is +def replace_system_catalog_ids(df: pd.DataFrame, system_catalog: dict[int, str]) -> pd.DataFrame: + """Replaces IDs with names according to the system catalog. This is necessary since the data from the bulk download encodes columns with IDs instead of the actual values.""" for column_name, name_mapping_dictionary in system_catalog.items(): @@ -27,16 +32,18 @@ def replace_ids_with_names(df: pd.DataFrame, system_catalog: dict) -> pd.DataFra def replace_mastr_katalogeintraege( - zipped_xml_file_path: str, df: pd.DataFrame, + catalog_columns: Collection[str], + zipped_xml_file_path: str, ) -> pd.DataFrame: """Replaces the IDs from the mastr database by its mapped string values from - the table katalogwerte""" + the table Katalogwerte""" + # TODO: Create Katalogwerte dict once for whole download, not once per processed file. 
katalogwerte = create_katalogwerte_from_bulk_download(zipped_xml_file_path) for column_name in df.columns: - if column_name in columns_replace_list: + if column_name in catalog_columns: if df[column_name].dtype == "O": - # Handle comma seperated strings from catalog values + # Handle comma-separated strings from catalog values df[column_name] = ( df[column_name] .str.split(",", expand=True) diff --git a/open_mastr/xml_download/utils_download_bulk.py b/open_mastr/xml_download/utils_download_bulk.py index a8d37ae3..52af4422 100644 --- a/open_mastr/xml_download/utils_download_bulk.py +++ b/open_mastr/xml_download/utils_download_bulk.py @@ -1,3 +1,5 @@ +import datetime +import math import os import shutil import time @@ -24,7 +26,7 @@ def gen_version( - when: time.struct_time = time.localtime(), use_version: str = "current" + when: datetime.date, use_version: str = "current" ) -> str: """ Generates the current version. @@ -53,13 +55,13 @@ def gen_version( 2024-31-12 = version 24.2 """ - year = when.tm_year + year = when.year release = 1 - if when.tm_mon < 4 or (when.tm_mon == 4 and when.tm_mday == 1): + if when.month < 4 or (when.month == 4 and when.day == 1): year = year - 1 release = 2 - elif when.tm_mon > 10 or (when.tm_mon == 10 and when.tm_mday > 1): + elif when.month > 10 or (when.month == 10 and when.day > 1): release = 2 # Change to MaStR version number that was used before @@ -84,7 +86,7 @@ def gen_version( return f"{year}.{release}" -def gen_url(when: time.struct_time = time.localtime(), use_version="current") -> str: +def gen_url(when: datetime.date, use_version="current") -> str: """Generates the download URL for the specified date. Note that not all dates are archived on the website. @@ -110,13 +112,13 @@ def gen_url(when: time.struct_time = time.localtime(), use_version="current") -> Defaults to "current". 
""" version = gen_version(when, use_version) - date = time.strftime("%Y%m%d", when) + date = when.strftime("%Y%m%d") return f"https://download.marktstammdatenregister.de/Gesamtdatenexport_{date}_{version}.zip" def download_xml_Mastr( - save_path: str, bulk_date_string: str, bulk_data_list: list, xml_folder_path: str + save_path: str, bulk_date: datetime.date, bulk_data_list: list, xml_folder_path: str ) -> None: """Downloads the zipped MaStR. @@ -124,7 +126,7 @@ def download_xml_Mastr( ----------- save_path: str Full file path where the downloaded MaStR zip file will be saved. - bulk_date_string: str + bulk_date_string: datetime.date Date for which the file should be downloaded. bulk_data_list: list List of tables/technologis to be downloaded. @@ -134,9 +136,7 @@ def download_xml_Mastr( log.info("Starting the Download from marktstammdatenregister.de.") - # TODO this should take bulk_date_string - now = time.localtime() - url = gen_url(now) + url = gen_url(bulk_date) time_a = time.perf_counter() r = requests.get(url, stream=True, headers={"User-Agent": USER_AGENT}) @@ -144,19 +144,17 @@ def download_xml_Mastr( log.warning( "Download file was not found. Assuming that the new file was not published yet and retrying with yesterday." ) - now = time.localtime( - time.mktime(now) - (24 * 60 * 60) - ) # subtract 1 day from the date - url = gen_url(now) + bulk_date -= datetime.timedelta(days=1) + url = gen_url(bulk_date) r = requests.get(url, stream=True, headers={"User-Agent": USER_AGENT}) if r.status_code == 404: - url = gen_url(now, use_version="before") # Use lower MaStR Version + url = gen_url(bulk_date, use_version="before") # Use lower MaStR Version log.warning( f"Download file was not found. 
Assuming that the version of MaStR has changed and retrying with download link: {url}" ) r = requests.get(url, stream=True, headers={"User-Agent": USER_AGENT}) if r.status_code == 404: - url = gen_url(now, use_version="after") # Use higher MaStR Version + url = gen_url(bulk_date, use_version="after") # Use higher MaStR Version log.warning( f"Download file was not found. Assuming that the version of MaStR has changed and retrying with download link: {url}" ) @@ -321,6 +319,7 @@ def full_download_without_unzip_http( "Warning: The servers from MaStR restrict the download speed." " You may want to download it another time." ) + # TODO: Explain this number total_length = int(23000) with ( open(save_path, "wb") as zfile, @@ -339,3 +338,29 @@ def full_download_without_unzip_http( else: # remove warning bar.set_postfix_str(s="") + + +def download_documentation(save_path: str) -> None: + """Downloads the zipped MaStR. + + Parameters + ----------- + save_path: str + Full file path where the downloaded MaStR documentation zip file will be saved. + """ + log.info("Starting the MaStR documentation download from marktstammdatenregister.de.") + url = "https://www.marktstammdatenregister.de/MaStRHilfe/files/gesamtdatenexport/Dokumentation%20MaStR%20Gesamtdatenexport.zip" + + time_a = time.perf_counter() + r = requests.get(url, headers={"User-Agent": USER_AGENT}) + + r.raise_for_status() + with open(save_path, "wb") as zfile: + zfile.write(r.content) + + time_b = time.perf_counter() + log.info( + f"MaStR documentation download is finished. It took {round(time_b - time_a)} seconds." 
+ ) + log.info(f"MaStR was successfully downloaded to {save_path!r}.") + diff --git a/open_mastr/xml_download/utils_write_to_database.py b/open_mastr/xml_download/utils_write_to_database.py index e71abc18..d45945cd 100644 --- a/open_mastr/xml_download/utils_write_to_database.py +++ b/open_mastr/xml_download/utils_write_to_database.py @@ -1,8 +1,10 @@ import os +from collections.abc import Collection, Mapping from concurrent.futures import ProcessPoolExecutor, wait from io import StringIO from multiprocessing import cpu_count from shutil import Error +from typing import Type, TypeVar from zipfile import ZipFile import re @@ -10,17 +12,22 @@ import numpy as np import pandas as pd import sqlalchemy -from sqlalchemy import select, create_engine, inspect +from sqlalchemy import Column, Engine, Table, delete, select, create_engine, inspect +from sqlalchemy.orm import DeclarativeBase from sqlalchemy.sql import text from sqlalchemy.sql.sqltypes import Date, DateTime from open_mastr.utils.config import setup_logger from open_mastr.utils.helpers import data_to_include_tables from open_mastr.utils.orm import tablename_mapping +from open_mastr.utils.xsd_tables import normalize_column_name +from open_mastr.utils.sqlalchemy_tables import CatalogInteger, CatalogString from open_mastr.xml_download.utils_cleansing_bulk import cleanse_bulk_data log = setup_logger() +DeclarativeBase_T = TypeVar("DeclarativeBase_T", bound=DeclarativeBase) + def write_mastr_xml_to_database( engine: sqlalchemy.engine.Engine, @@ -28,12 +35,15 @@ def write_mastr_xml_to_database( data: list, bulk_cleansing: bool, bulk_download_date: str, + mastr_table_to_db_table: Mapping[str, Table], + alter_database_tables: bool, ) -> None: """Write the Mastr in xml format into a database defined by the engine parameter.""" log.info("Starting bulk download...") include_tables = data_to_include_tables(data, mapping="write_xml") threads_data = [] + lower_mastr_table_to_db_table = {table_name.lower(): db_table for 
table_name, db_table in mastr_table_to_db_table.items()} with ZipFile(zipped_xml_file_path, "r") as f: files_list = correct_ordering_of_filelist(f.namelist()) @@ -44,17 +54,21 @@ def write_mastr_xml_to_database( if not is_table_relevant(xml_table_name, include_tables): continue - sql_table_name = extract_sql_table_name(xml_table_name) + db_table = lower_mastr_table_to_db_table.get(xml_table_name) + if db_table is None: + log.warning(f"Skipping MaStR file {file_name!r} because no database table was found for {xml_table_name=}") + continue + threads_data.append( ( file_name, - xml_table_name, - sql_table_name, + db_table, str(engine.url), engine.url.password, zipped_xml_file_path, bulk_download_date, bulk_cleansing, + alter_database_tables, ) ) @@ -100,13 +114,13 @@ def get_number_of_processes(): def process_xml_file( file_name: str, - xml_table_name: str, - sql_table_name: str, + db_table: Table, connection_url: str, password: str, zipped_xml_file_path: str, bulk_download_date: str, bulk_cleansing: bool, + alter_database_tables: bool, ) -> None: """Process a single xml file and write it to the database.""" try: @@ -122,26 +136,88 @@ def process_xml_file( with ZipFile(zipped_xml_file_path, "r") as f: log.info(f"Processing file '{file_name}'...") if is_first_file(file_name): - log.info(f"Creating table '{sql_table_name}'...") - create_database_table(engine, xml_table_name) + delete_all_existing_rows(db_table=db_table, engine=engine) df = read_xml_file(f, file_name) df = process_table_before_insertion( - df, - xml_table_name, - zipped_xml_file_path, - bulk_download_date, - bulk_cleansing, + df=df, + db_table=db_table, + zipped_xml_file_path=zipped_xml_file_path, + bulk_download_date=bulk_download_date, + bulk_cleansing=bulk_cleansing, + ) + df = check_for_column_mismatch_and_try_to_solve_it( + df=df, + db_table=db_table, + engine=engine, + alter_database_tables=alter_database_tables, ) if engine.dialect.name == "sqlite": - add_table_to_sqlite_database(df, 
xml_table_name, sql_table_name, engine) + add_table_to_sqlite_database( + df=df, + db_table=db_table, + engine=engine, + ) else: add_table_to_non_sqlite_database( - df, xml_table_name, sql_table_name, engine + df=df, + db_table=db_table, + engine=engine, ) except Exception as e: log.error(f"Error processing file '{file_name}': '{e}'") +def delete_all_existing_rows(db_table: Table, engine: Engine) -> None: + with engine.begin() as con: + con.execute(delete(db_table)) + + +def check_for_column_mismatch_and_try_to_solve_it( + df: pd.DataFrame, + db_table: Table, + engine: Engine, + alter_database_tables: bool, +) -> pd.DataFrame: + df_column_names = set(df.columns) + db_column_names = {column.name for column in db_table.columns} + + if additional_db_column_names := db_column_names - df_column_names: + # Many columns are optional and it's perfectly normal to have and XML file / a dataframe that doesn't have + # a column that is present in the database. So this is only worth a debug message. + log.debug( + f"Database table {db_table.name} has some columns that weren't found in the XML file." + f" Proceeding and trying to insert anyway. Additional DB columns:" + f" {', '.join(additional_db_column_names)}" + ) + + if additional_df_column_names := df_column_names - db_column_names: + if alter_database_tables: + log.warning( + f"XML file has some columns that aren't present in the database table {db_table.name}." + f" Trying to add the columns to the table. Additional XML columns:" + f" {', '.join(additional_df_column_names)}" + ) + # TODO: What if we can add some columns and not others? We should then return the columns for which we succeeded. + try: + add_missing_columns_to_table( + engine=engine, + db_table=db_table, + missing_columns=additional_df_column_names, + ) + except Exception: + log.exception("Could not add at least some columns to the database. 
Ignoring the columns from the XML file instead.") + df = df.drop(columns=additional_df_column_names) + else: + log.warning( + f"XML file has some columns that aren't present in the database table {db_table.name}." + f" Ignoring those columns since you asked not to alter tables. Additional XML columns:" + f" {', '.join(additional_df_column_names)}" + ) + df = df.drop(columns=additional_df_column_names) + + return df + + def create_efficient_engine(connection_url: str) -> sqlalchemy.engine.Engine: """Create an efficient engine for the SQLite database.""" is_sqlite = connection_url.startswith("sqlite://") @@ -254,44 +330,35 @@ def is_first_file(file_name: str) -> bool: def cast_date_columns_to_datetime( - xml_table_name: str, df: pd.DataFrame + db_table: Table, df: pd.DataFrame ) -> pd.DataFrame: - sqlalchemy_columnlist = tablename_mapping[xml_table_name][ - "__class__" - ].__table__.columns.items() - for column in sqlalchemy_columnlist: - column_name = column[0] - if is_date_column(column, df): + for column in db_table.columns: + if is_date_column(column) and column.name in df.columns: # Convert column to datetime64, invalid string -> NaT - df[column_name] = pd.to_datetime(df[column_name], errors="coerce") + df[column.name] = pd.to_datetime(df[column.name], errors="coerce") return df -def cast_date_columns_to_string(xml_table_name: str, df: pd.DataFrame) -> pd.DataFrame: - column_list = tablename_mapping[xml_table_name][ - "__class__" - ].__table__.columns.items() - for column in column_list: - column_name = column[0] - - if not (column[0] in df.columns and is_date_column(column, df)): +def cast_date_columns_to_string(db_table: Table, df: pd.DataFrame) -> pd.DataFrame: + for column in db_table.columns: + if not is_date_column(column) or column.name not in df.columns: continue - df[column_name] = pd.to_datetime(df[column_name], errors="coerce") + df[column.name] = pd.to_datetime(df[column.name], errors="coerce") - if type(column[1].type) is Date: - df[column_name] = ( 
- df[column_name].dt.strftime("%Y-%m-%d").replace("NaT", None) + if type(column.type) is Date: + df[column.name] = ( + df[column.name].dt.strftime("%Y-%m-%d").replace("NaT", None) ) - elif type(column[1].type) is DateTime: - df[column_name] = ( - df[column_name].dt.strftime("%Y-%m-%d %H:%M:%S.%f").replace("NaT", None) + elif type(column.type) is DateTime: + df[column.name] = ( + df[column.name].dt.strftime("%Y-%m-%d %H:%M:%S.%f").replace("NaT", None) ) return df -def is_date_column(column, df: pd.DataFrame) -> bool: - return type(column[1].type) in [Date, DateTime] and column[0] in df.columns +def is_date_column(column: Column) -> bool: + return type(column.type) in [Date, DateTime] def correct_ordering_of_filelist(files_list: list) -> list: @@ -329,46 +396,27 @@ def read_xml_file(f: ZipFile, file_name: str) -> pd.DataFrame: return handle_xml_syntax_error(xml_file.read().decode("utf-16"), error) -def change_column_names_to_orm_format( - df: pd.DataFrame, xml_table_name: str -) -> pd.DataFrame: - if tablename_mapping[xml_table_name]["replace_column_names"]: - df.rename( - columns=tablename_mapping[xml_table_name]["replace_column_names"], - inplace=True, - ) - return df - - def add_table_to_non_sqlite_database( df: pd.DataFrame, - xml_table_name: str, - sql_table_name: str, + db_table: Table, engine: sqlalchemy.engine.Engine, ) -> None: # get a dictionary for the data types - table_columns_list = list( - tablename_mapping[xml_table_name]["__class__"].__table__.columns - ) dtypes_for_writing_sql = { column.name: column.type - for column in table_columns_list + for column in db_table.columns if column.name in df.columns } # Convert date and datetime columns into the datatype datetime. 
- df = cast_date_columns_to_datetime(xml_table_name, df) - - add_missing_columns_to_table( - engine, xml_table_name, column_list=df.columns.tolist() - ) + df = cast_date_columns_to_datetime(db_table, df) for _ in range(10000): try: with engine.connect() as con: with con.begin(): df.to_sql( - sql_table_name, + db_table.name, con=con, index=False, if_exists="append", @@ -382,7 +430,7 @@ def add_table_to_non_sqlite_database( except sqlalchemy.exc.IntegrityError: # error resulting from Unique constraint failed df = write_single_entries_until_not_unique_comes_up( - df, xml_table_name, engine + df, db_table, engine ) @@ -419,7 +467,7 @@ def add_zero_as_first_character_for_too_short_string(df: pd.DataFrame) -> pd.Dat def write_single_entries_until_not_unique_comes_up( - df: pd.DataFrame, xml_table_name: str, engine: sqlalchemy.engine.Engine + df: pd.DataFrame, db_table: Table, engine: sqlalchemy.engine.Engine ) -> pd.DataFrame: """ Remove from dataframe these rows, which are already existing in the database table @@ -433,15 +481,14 @@ def write_single_entries_until_not_unique_comes_up( ------- Filtered dataframe """ + # TODO: Check if we need to support composite primary keys for the MaStR changes table. + # Because this here assumes single-column primary keys. 
+ primary_key = next(c for c in db_table.columns if c.primary_key) - table = tablename_mapping[xml_table_name]["__class__"].__table__ - primary_key = next(c for c in table.columns if c.primary_key) - - with engine.connect() as con: - with con.begin(): - key_list = ( - pd.read_sql(sql=select(primary_key), con=con).values.squeeze().tolist() - ) + with engine.begin() as con: + key_list = ( + pd.read_sql(sql=select(primary_key), con=con).values.squeeze().tolist() + ) len_df_before = len(df) df = df.drop_duplicates( @@ -460,8 +507,8 @@ def write_single_entries_until_not_unique_comes_up( def add_missing_columns_to_table( engine: sqlalchemy.engine.Engine, - xml_table_name: str, - column_list: list, + db_table: Table, + missing_columns: Collection[str], ) -> None: """ Some files introduce new columns for existing tables. @@ -477,36 +524,27 @@ def add_missing_columns_to_table( ------- """ - log = setup_logger() - - # get the columns name from the existing database - inspector = sqlalchemy.inspect(engine) - table_name = tablename_mapping[xml_table_name]["__class__"].__table__.name - columns = inspector.get_columns(table_name) - column_names_from_database = [column["name"] for column in columns] - - missing_columns = set(column_list) - set(column_names_from_database) - + table_name = db_table.name for column_name in missing_columns: - if not column_exists(engine, table_name, column_name): - alter_query = 'ALTER TABLE %s ADD "%s" VARCHAR NULL;' % ( - table_name, - column_name, - ) - try: - with engine.connect().execution_options(autocommit=True) as con: - with con.begin(): - con.execute( - text(alter_query).execution_options(autocommit=True) - ) - except sqlalchemy.exc.OperationalError as err: - # If the column already exists, we can ignore the error. 
- if "duplicate column name" not in str(err): - raise err - log.info( - "From the downloaded xml files following new attribute was " - f"introduced: {table_name}.{column_name}" - ) + alter_query = 'ALTER TABLE %s ADD "%s" VARCHAR NULL;' % ( + table_name, + column_name, + ) + try: + with engine.connect().execution_options(autocommit=True) as con: + with con.begin(): + con.execute( + text(alter_query).execution_options(autocommit=True) + ) + except sqlalchemy.exc.OperationalError as err: + # If the column already exists, we can ignore the error. + if "duplicate column name" not in str(err): + raise err + log.info( + f"Added the following columns to database table {table_name}:" + f" {', '.join(missing_columns)}" + ) + def delete_wrong_xml_entry(err: Error, df: pd.DataFrame) -> pd.DataFrame: @@ -562,57 +600,67 @@ def find_nearest_brackets(xml_string: str, position: int) -> tuple[int, int]: def process_table_before_insertion( df: pd.DataFrame, - xml_table_name: str, + db_table: Table, zipped_xml_file_path: str, bulk_download_date: str, bulk_cleansing: bool, ) -> pd.DataFrame: df = add_zero_as_first_character_for_too_short_string(df) - df = change_column_names_to_orm_format(df, xml_table_name) # Add Column that refers to the source of the data df["DatenQuelle"] = "bulk" df["DatumDownload"] = bulk_download_date + df = normalize_column_names_in_df(df) + if bulk_cleansing: - df = cleanse_bulk_data(df, zipped_xml_file_path) + catalog_columns = { + column.name + for column in db_table.columns + # TODO: Is it okay to rely so heavily on the SQLALchemy model to decide how to process the table? 
+ if isinstance(column.type, (CatalogInteger, CatalogString)) + } + df = cleanse_bulk_data( + df=df, catalog_columns=catalog_columns, zipped_xml_file_path=zipped_xml_file_path + ) return df +def normalize_column_names_in_df(df: pd.DataFrame) -> pd.DataFrame: + return df.rename(columns={column_name: normalize_column_name(column_name) for column_name in df.columns}) + + def add_table_to_sqlite_database( df: pd.DataFrame, - xml_table_name: str, - sql_table_name: str, + db_table: Table, engine: sqlalchemy.engine.Engine, ) -> None: column_list = df.columns.tolist() - add_missing_columns_to_table(engine, xml_table_name, column_list) # Convert NaNs to None. df = df.where(pd.notnull(df), None) # Convert date columns to strings. Dates are not supported directly by SQLite. - df = cast_date_columns_to_string(xml_table_name, df) + df = cast_date_columns_to_string(db_table, df) # Create SQL statement for bulk insert. ON CONFLICT DO NOTHING prevents duplicates. - insert_stmt = f"INSERT INTO {sql_table_name} ({','.join(column_list)}) VALUES ({','.join(['?' for _ in column_list])}) ON CONFLICT DO NOTHING" + insert_stmt = f"INSERT INTO {db_table.name} ({','.join(column_list)}) VALUES ({','.join(['?' for _ in column_list])}) ON CONFLICT DO NOTHING" for _ in range(10000): try: - with engine.connect() as con: - with con.begin(): - con.connection.executemany(insert_stmt, df.to_numpy()) - break + with engine.begin() as con: + con.connection.executemany(insert_stmt, df.to_numpy()) + break except sqlalchemy.exc.DataError as err: delete_wrong_xml_entry(err, df) except sqlalchemy.exc.IntegrityError: # error resulting from Unique constraint failed df = write_single_entries_until_not_unique_comes_up( - df, xml_table_name, engine + df, db_table, engine ) - except: + except Exception: # If any unexpected error occurs, we'll switch back to the non-SQLite method. 
- add_table_to_non_sqlite_database(df, xml_table_name, sql_table_name, engine) + add_table_to_non_sqlite_database(df, db_table, engine) break diff --git a/open_mastr/xml_download/xsd_to_table.py b/open_mastr/xml_download/xsd_to_table.py new file mode 100644 index 00000000..49fb63f7 --- /dev/null +++ b/open_mastr/xml_download/xsd_to_table.py @@ -0,0 +1,61 @@ + + +class SqlalchemyMastrModelMaker: + MASTR_COLUMN_TYPE_TO_SQLALCHEMY_TYPE = { + MastrColumnType.STRING: String, + MastrColumnType.INTEGER: Integer, + MastrColumnType.FLOAT: Float, + MastrColumnType.DATE: Date, + MastrColumnType.DATETIME: DateTime(timezone=True), + MastrColumnType.BOOLEAN: Boolean, + MastrColumnType.CATALOG_VALUE: Integer, + } + + @classmethod + def make_sqlalchemy_mastr_model( + cls, + table: MastrTableDescription, + primary_key_columns: set[str], + base: DeclarativeBase, + mixins: tuple[type, ...] = tuple(), + ): + namespace = { + "__tablename__": table.table_name, + "__annotations__": {}, + } + + for col in table.columns: + sa_type = cls.MASTR_COLUMN_TYPE_TO_SQLALCHEMY_TYPE[col.type] + kwargs = {"primary_key": True} if col.name in primary_key_columns else {"nullable": True} + namespace[col.name] = mapped_column(sa_type, **kwargs) + + bases = (base,) + mixins + return type(table.instance_name, bases, namespace) + + +class Base(DeclarativeBase): + pass + + +class ParentAllTables(object): + DatenQuelle = Column(String) + DatumDownload = Column(Date) + + +def generate_sqlalchemy_models(xsd_file: str) -> str: + schema = xmlschema.XMLSchema(xsd_file) + table = MastrTableDescription.from_xml_schema(schema) + + model = SqlalchemyMastrModelMaker.make_sqlalchemy_mastr_model( + table=table, + primary_key_columns={"EinheitMastrNummer"}, + base=Base, + mixins=(ParentAllTables,) + ) + + +if __name__ == "__main__": + import sys + xsd_path = sys.argv[1] + generate_sqlalchemy_models(xsd_path) + diff --git a/pyproject.toml b/pyproject.toml index 5871bfbe..205d7672 100644 --- a/pyproject.toml +++ 
b/pyproject.toml @@ -16,6 +16,7 @@ dependencies = [ "keyring", "pyyaml", "xmltodict", + "xmlschema", ] requires-python = ">=3.9, <4" diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 5f8cfa81..f635b387 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -1,6 +1,7 @@ import pytest import os from os.path import expanduser +from pathlib import Path import itertools import random @@ -26,21 +27,10 @@ ) -# Check if db is empty -_db_exists = False -_db_folder_path = os.path.join( - expanduser("~"), ".open-MaStR", "data", "sqlite" -) # FIXME: use path in tmpdir when implemented -if os.path.isdir(_db_folder_path): - for entry in os.scandir(path=_db_folder_path): - _db_path = os.path.join(_db_folder_path, "open-mastr.db") - if os.path.getsize(_db_path) > 1000000: # empty db = 327.7kB < 1 MB - _db_exists = True - - @pytest.fixture -def db(): - return Mastr() +def mastr(tmp_path: Path): + output_dir = tmp_path / "output_dir" + return Mastr(output_dir=output_dir) def test_Mastr_validate_working_parameter(): @@ -119,8 +109,8 @@ def test_Mastr_validate_not_working_parameter(): ) -def test_validate_parameter_format_for_mastr_init(db): - engine_list_working = ["sqlite", db.engine] +def test_validate_parameter_format_for_mastr_init(mastr): + engine_list_working = ["sqlite", mastr.engine] engine_list_failing = ["HI", 12] for engine in engine_list_working: diff --git a/tests/test_mastr.py b/tests/test_mastr.py index ce7cd6fa..c2355f49 100644 --- a/tests/test_mastr.py +++ b/tests/test_mastr.py @@ -1,4 +1,5 @@ import shutil +from pathlib import Path from open_mastr.mastr import Mastr import os @@ -14,97 +15,58 @@ _xml_folder_path = os.path.join(expanduser("~"), ".open-MaStR", "data", "xml_download") if os.path.isdir(_xml_folder_path): for entry in os.scandir(path=_xml_folder_path): - if "Gesamtdatenexport" in entry.name: + if "Gesamtdatenexport" in entry.name and entry.name.endswith(".zip"): _xml_file_exists = True @pytest.fixture(scope="module") -def 
zipped_xml_file_path(): +def zipped_xml_file_path() -> str: zipped_xml_file_path = None for entry in os.scandir(path=_xml_folder_path): - if "Gesamtdatenexport" in entry.name: + if "Gesamtdatenexport" in entry.name and entry.name.endswith(".zip"): zipped_xml_file_path = os.path.join(_xml_folder_path, entry.name) return zipped_xml_file_path @pytest.fixture -def db_path(): - return os.path.join( - os.path.expanduser("~"), ".open-MaStR", "data", "sqlite", "mastr-test.db" - ) +def mastr(tmp_path: Path) -> Mastr: + output_dir = tmp_path / "output_dir" + return Mastr(output_dir=output_dir) -@pytest.fixture -def db(db_path): - return Mastr(engine=sqlalchemy.create_engine(f"sqlite:///{db_path}")) - - -@pytest.fixture -def db_translated(db_path): - engine = sqlalchemy.create_engine(f"sqlite:///{db_path}") - db_api = Mastr(engine=engine) - - db_api.download(date="existing", data=["wind", "hydro", "biomass", "combustion"]) - db_api.translate() - - return db_api - - -def test_Mastr_init(db): +def test_mastr_init(mastr: Mastr) -> None: # test if folder structure exists - assert os.path.exists(db.home_directory) - assert os.path.exists(db._sqlite_folder_path) + assert os.path.exists(mastr.home_directory) + assert os.path.exists(mastr._sqlite_folder_path) # test if engine and connection were created - assert type(db.engine) == sqlalchemy.engine.Engine - - -@pytest.mark.skipif( - not _xml_file_exists, reason="The zipped xml file could not be found." 
-) -def test_Mastr_translate(db_translated, db_path): - # test if database was renamed correctly - transl_path = db_path[:-3] + "-translated.db" - assert os.path.exists(transl_path) - - # test if columns got translated - inspector = sqlalchemy.inspect(db_translated.engine) - table_names = inspector.get_table_names() - - for table in table_names: - for column in inspector.get_columns(table): - column = column["name"] - assert column in TRANSLATIONS.values() or column not in TRANSLATIONS.keys() - - # test if new translated version replaces previous one - db_translated.engine.dispose() - engine = sqlalchemy.create_engine(f"sqlite:///{db_path}") - db_empty = Mastr(engine=engine) - db_empty.translate() - - for table in table_names: - assert pd.read_sql(sql=table, con=db_empty.engine).shape[0] == 0 + assert type(mastr.engine) == sqlalchemy.engine.Engine @pytest.mark.dependency(name="bulk_downloaded") -def test_mastr_download(db): - db.download(data="wind") - df_wind = pd.read_sql("wind_extended", con=db.engine) +def test_mastr_download(mastr: Mastr) -> None: + mastr.download(data="wind") + df_wind = pd.read_sql("EinheitenWind", con=mastr.engine) + assert len(df_wind) > 10000 + + mastr.download(data="biomass") + df_biomass = pd.read_sql("EinheitenBiomasse", con=mastr.engine) assert len(df_wind) > 10000 + assert len(df_biomass) > 10000 - db.download(data="biomass") - df_biomass = pd.read_sql("biomass_extended", con=db.engine) + mastr.download(data=["wind", "nuclear"]) + df_biomass = pd.read_sql("EinheitenBiomasse", con=mastr.engine) assert len(df_wind) > 10000 assert len(df_biomass) > 10000 @pytest.mark.dependency(depends=["bulk_downloaded"]) -def test_mastr_download_keep_old_files(db, zipped_xml_file_path): +def test_mastr_download_keep_old_files(mastr: Mastr, zipped_xml_file_path: str) -> None: file_today = zipped_xml_file_path yesterday = (date.today() - timedelta(days=1)).strftime("%Y%m%d") file_old = re.sub(r"\d{8}", yesterday, os.path.basename(file_today)) file_old = 
os.path.join(os.path.dirname(zipped_xml_file_path), file_old) shutil.copy(file_today, file_old) - db.download(data="gsgk", keep_old_files=True) + mastr.download(data="gsgk", keep_old_files=True) assert os.path.exists(file_old) diff --git a/tests/xml_download/test_utils_cleansing_bulk.py b/tests/xml_download/test_utils_cleansing_bulk.py index 9a29ad76..f3c01abc 100644 --- a/tests/xml_download/test_utils_cleansing_bulk.py +++ b/tests/xml_download/test_utils_cleansing_bulk.py @@ -7,6 +7,7 @@ import pytest from open_mastr.xml_download.utils_cleansing_bulk import ( + cleanse_bulk_data, create_katalogwerte_from_bulk_download, replace_mastr_katalogeintraege, ) @@ -16,7 +17,7 @@ _xml_folder_path = os.path.join(expanduser("~"), ".open-MaStR", "data", "xml_download") if os.path.isdir(_xml_folder_path): for entry in os.scandir(path=_xml_folder_path): - if "Gesamtdatenexport" in entry.name: + if "Gesamtdatenexport" in entry.name and entry.name.endswith(".zip"): _xml_file_exists = True _sqlite_folder_path = os.path.join(expanduser("~"), ".open-MaStR", "data", "sqlite") @@ -42,12 +43,40 @@ def con(): def zipped_xml_file_path(): zipped_xml_file_path = None for entry in os.scandir(path=_xml_folder_path): - if "Gesamtdatenexport" in entry.name: + if "Gesamtdatenexport" in entry.name and entry.name.endswith(".zip"): zipped_xml_file_path = os.path.join(_xml_folder_path, entry.name) return zipped_xml_file_path + +@pytest.mark.skipif( + not _xml_file_exists, reason="The zipped xml file could not be found." 
+) +def test_cleanse_bulk_data(zipped_xml_file_path): + df_raw = pd.DataFrame( + { + "ID": [0, 1, 2], + "Bundesland": [335, 335, 336], + "Einheittyp": [1, 8, 5], + } + ) + df_replaced = pd.DataFrame( + { + "ID": [0, 1, 2], + "Bundesland": ["Bayern", "Bayern", "Bremen"], + "Einheittyp": ["Solareinheit", "Stromspeichereinheit", "Geothermie"], + } + ) + + pd.testing.assert_frame_equal( + cleanse_bulk_data( + df=df_raw, zipped_xml_file_path=zipped_xml_file_path, catalog_columns={"Bundesland", "Einheittyp"}, + ), + df_replaced, + ) + + @pytest.mark.skipif( not _xml_file_exists, reason="The zipped xml file could not be found." ) @@ -57,7 +86,10 @@ def test_replace_mastr_katalogeintraege(zipped_xml_file_path): {"ID": [0, 1, 2], "Bundesland": ["Bayern", "Bayern", "Bremen"]} ) pd.testing.assert_frame_equal( - df_replaced, replace_mastr_katalogeintraege(zipped_xml_file_path, df_raw) + replace_mastr_katalogeintraege( + zipped_xml_file_path=zipped_xml_file_path, df=df_raw, catalog_columns={"Bundesland", "Einheittyp"}, + ), + df_replaced, ) diff --git a/tests/xml_download/test_utils_download_bulk.py b/tests/xml_download/test_utils_download_bulk.py index 8f650933..4557dbe8 100644 --- a/tests/xml_download/test_utils_download_bulk.py +++ b/tests/xml_download/test_utils_download_bulk.py @@ -1,3 +1,4 @@ +from datetime import date import time from open_mastr.xml_download.utils_download_bulk import ( gen_url, @@ -8,7 +9,7 @@ def test_gen_url(): - when = time.strptime("2024-01-01", "%Y-%m-%d") + when = date(2024, 1, 1) url = gen_url(when) assert type(url) == str assert ( @@ -16,7 +17,7 @@ def test_gen_url(): == "https://download.marktstammdatenregister.de/Gesamtdatenexport_20240101_23.2.zip" ) - when = time.strptime("2024-04-01", "%Y-%m-%d") + when = date(2024, 4, 1) url = gen_url(when) assert type(url) == str assert ( @@ -24,7 +25,7 @@ def test_gen_url(): == "https://download.marktstammdatenregister.de/Gesamtdatenexport_20240401_23.2.zip" ) - when = time.strptime("2024-04-02", 
"%Y-%m-%d") + when = date(2024, 4, 2) url = gen_url(when) assert type(url) == str assert ( @@ -32,7 +33,7 @@ def test_gen_url(): == "https://download.marktstammdatenregister.de/Gesamtdatenexport_20240402_24.1.zip" ) - when = time.strptime("2024-10-01", "%Y-%m-%d") + when = date(2024, 10, 1) url = gen_url(when) assert type(url) == str assert ( @@ -40,7 +41,7 @@ def test_gen_url(): == "https://download.marktstammdatenregister.de/Gesamtdatenexport_20241001_24.1.zip" ) - when = time.strptime("2024-10-02", "%Y-%m-%d") + when = date(2024, 10, 2) url = gen_url(when) assert type(url) == str assert ( @@ -48,7 +49,7 @@ def test_gen_url(): == "https://download.marktstammdatenregister.de/Gesamtdatenexport_20241002_24.2.zip" ) - when = time.strptime("2024-12-31", "%Y-%m-%d") + when = date(2024, 12, 31) url = gen_url(when) assert type(url) == str assert ( @@ -58,7 +59,7 @@ def test_gen_url(): # Tests for use_version parameter - when = time.strptime("2024-12-31", "%Y-%m-%d") + when = date(2024, 12, 31) url = gen_url(when, use_version="before") assert type(url) == str assert ( @@ -66,7 +67,7 @@ def test_gen_url(): == "https://download.marktstammdatenregister.de/Gesamtdatenexport_20241231_24.1.zip" ) - when = time.strptime("2024-12-31", "%Y-%m-%d") + when = date(2024, 12, 31) url = gen_url(when, use_version="after") assert type(url) == str assert ( @@ -74,7 +75,7 @@ def test_gen_url(): == "https://download.marktstammdatenregister.de/Gesamtdatenexport_20241231_25.1.zip" ) - when = time.strptime("2024-04-02", "%Y-%m-%d") + when = date(2024, 4, 2) url = gen_url(when, use_version="before") assert type(url) == str assert ( @@ -82,7 +83,7 @@ def test_gen_url(): == "https://download.marktstammdatenregister.de/Gesamtdatenexport_20240402_23.2.zip" ) - when = time.strptime("2024-04-02", "%Y-%m-%d") + when = date(2024, 4, 2) url = gen_url(when, use_version="after") assert type(url) == str assert ( diff --git a/tests/xml_download/test_utils_write_to_database.py 
b/tests/xml_download/test_utils_write_to_database.py index bf54f16d..3b26ef42 100644 --- a/tests/xml_download/test_utils_write_to_database.py +++ b/tests/xml_download/test_utils_write_to_database.py @@ -8,7 +8,20 @@ import numpy as np import pandas as pd import pytest -from sqlalchemy import create_engine, inspect +from sqlalchemy import ( + Boolean, + Column, + create_engine, + Date, + DateTime, + Double, + inspect, + Integer, + MetaData, + String, + Table, +) + from sqlalchemy.sql import text from open_mastr.utils import orm @@ -17,7 +30,6 @@ add_missing_columns_to_table, add_zero_as_first_character_for_too_short_string, cast_date_columns_to_string, - change_column_names_to_orm_format, correct_ordering_of_filelist, create_database_table, extract_sql_table_name, @@ -37,7 +49,7 @@ _xml_folder_path = os.path.join(expanduser("~"), ".open-MaStR", "data", "xml_download") if os.path.isdir(_xml_folder_path): for entry in os.scandir(path=_xml_folder_path): - if "Gesamtdatenexport" in entry.name: + if "Gesamtdatenexport" in entry.name and entry.name.endswith(".zip"): _xml_file_exists = True @@ -51,9 +63,11 @@ def capture_wrap(): @pytest.fixture(scope="module") def zipped_xml_file_path(): + # TODO: Remove this + return "/home/gorgor/.open-MaStR/data/Gesamtdatenexport_20251228.zip" zipped_xml_file_path = None for entry in os.scandir(path=_xml_folder_path): - if "Gesamtdatenexport" in entry.name: + if "Gesamtdatenexport" in entry.name and entry.name.endswith(".zip"): zipped_xml_file_path = os.path.join(_xml_folder_path, entry.name) return zipped_xml_file_path @@ -97,16 +111,6 @@ def test_is_table_relevant(): assert is_table_relevant("netzanschlusspunkte", include_tables) is False -def test_create_database_table(engine_testdb): - orm.Base.metadata.create_all(engine_testdb) - xml_table_name = "einheitenkernkraft" - sql_table_name = "nuclear_extended" - - create_database_table(engine_testdb, xml_table_name) - - assert inspect(engine_testdb).has_table(sql_table_name) is True - - 
def test_is_first_file(): assert is_first_file("EinheitenKernkraft.xml") is True assert is_first_file("EinheitenKernkraft_1.xml") is True @@ -114,9 +118,16 @@ def test_is_first_file(): def test_cast_date_columns_to_string(): + table = Table( + "anlageneegwasser", + MetaData(), + Column("EegMastrNummer", String, primary_key=True), + Column("Registrierungsdatum", Date), + Column("DatumLetzteAktualisierung", DateTime), + ) initial_df = pd.DataFrame( { - "EegMastrNummer": [1, 2, 3], + "EegMastrNummer": ["1", "2", "3"], "Registrierungsdatum": [ datetime(2024, 3, 11).date(), datetime(1999, 2, 1).date(), @@ -131,7 +142,7 @@ def test_cast_date_columns_to_string(): ) expected_df = pd.DataFrame( { - "EegMastrNummer": [1, 2, 3], + "EegMastrNummer": ["1", "2", "3"], "Registrierungsdatum": ["2024-03-11", "1999-02-01", np.nan], "DatumLetzteAktualisierung": [ "2022-03-22 00:00:00.000000", @@ -142,32 +153,14 @@ def test_cast_date_columns_to_string(): ) pd.testing.assert_frame_equal( - expected_df, cast_date_columns_to_string("anlageneegwasser", initial_df) + expected_df, cast_date_columns_to_string(table, initial_df) ) def test_is_date_column(): - columns = RetrofitUnits.__table__.columns.items() - df = pd.DataFrame( - { - "Id": [1], - "DatumLetzteAktualisierung": [datetime(2022, 3, 22)], - "WiederinbetriebnahmeDatum": [datetime(2024, 3, 11).date()], - } - ) - - date_column = list(filter(lambda col: col[0] == "Id", columns))[0] - assert is_date_column(date_column, df) is False - - datetime_column = list( - filter(lambda col: col[0] == "DatumLetzteAktualisierung", columns) - )[0] - assert is_date_column(datetime_column, df) is True - - date_column = list( - filter(lambda col: col[0] == "WiederinbetriebnahmeDatum", columns) - )[0] - assert is_date_column(date_column, df) is True + assert is_date_column(Column("Id", Integer, primary_key=True)) is False + assert is_date_column(Column("DatumLetzteAktualisierung", DateTime)) is True + assert 
is_date_column(Column("WiederinbetriebnahmeDatum", Date)) is True def test_correct_ordering_of_filelist(): @@ -226,15 +219,6 @@ def test_read_xml_file(zipped_xml_file_path): assert df.shape[0] > 0 - # Since the file is from the latest download, its content can vary over time. To make sure that the table is - # correctly created, we check that all of its columns are associated are included in our mapping. - for column in df.columns: - if column in tablename_mapping[file_name.lower()]["replace_column_names"]: - column = tablename_mapping[file_name.lower()]["replace_column_names"][ - column - ] - assert column in ElectricityConsumer.__table__.columns.keys() - def test_add_zero_as_first_character_for_too_short_string(): # Prepare @@ -251,6 +235,8 @@ def test_add_zero_as_first_character_for_too_short_string(): pd.testing.assert_frame_equal(df_edited, df_correct) +# TODO: Do we want to keep this kind of renaming? +@pytest.mark.skip def test_change_column_names_to_orm_format(): initial_df = pd.DataFrame( { @@ -307,12 +293,17 @@ def test_process_table_before_insertion(zipped_xml_file_path): def test_add_missing_columns_to_table(engine_testdb): + table = Table( + "einheitengasverbraucher", + MetaData(), + Column("EinheitMastrNummer", String, primary_key=True), + Column("DatumLetzteAktualisierung", DateTime), + ) + # We must recreate the table to be sure that the new column is not present. + table.drop(engine_testdb, checkfirst=True) + table.create(engine_testdb) with engine_testdb.connect() as con: with con.begin(): - # We must recreate the table to be sure that the new colum is not present. 
- con.execute(text("DROP TABLE IF EXISTS gas_consumer")) - create_database_table(engine_testdb, "einheitengasverbraucher") - initial_data_in_db = pd.DataFrame( { "EinheitMastrNummer": ["id1"], @@ -320,11 +311,11 @@ def test_add_missing_columns_to_table(engine_testdb): } ) initial_data_in_db.to_sql( - "gas_consumer", con=con, if_exists="append", index=False + table.name, con=con, if_exists="append", index=False ) add_missing_columns_to_table( - engine_testdb, "einheitengasverbraucher", ["NewColumn"] + engine_testdb, table, ["NewColumn"] ) expected_df = pd.DataFrame( @@ -336,7 +327,7 @@ def test_add_missing_columns_to_table(engine_testdb): ) with engine_testdb.connect() as con: with con.begin(): - actual_df = pd.read_sql_table("gas_consumer", con=con) + actual_df = pd.read_sql_table(table.name, con=con) # The actual_df will contain more columns than the expected_df, so we can't use assert_frame_equal. assert expected_df.index.isin(actual_df.index).all() @@ -346,13 +337,28 @@ def test_add_missing_columns_to_table(engine_testdb): [add_table_to_sqlite_database, add_table_to_non_sqlite_database], ) def test_add_table_to_sqlite_database(engine_testdb, add_table_to_database_function): - with engine_testdb.connect() as con: - with con.begin(): - # We must recreate the table to be sure that no other data is present. 
- con.execute(text("DROP TABLE IF EXISTS gsgk_eeg")) - create_database_table( - engine_testdb, "anlageneeggeothermiegrubengasdruckentspannung" - ) + table = Table( + "anlageneeggeothermiegrubengasdruckentspannung", + MetaData(), + Column("EegMastrNummer", String, primary_key=True), + Column("InstallierteLeistung", Double), + Column("AnlageBetriebsstatus", String), + Column("Registrierungsdatum", Date), + Column("Meldedatum", DateTime), + Column("DatumLetzteAktualisierung", DateTime), + Column("EegInbetriebnahmedatum", DateTime), + Column("VerknuepfteEinheit", String), + Column("AnlagenschluesselEeg", String), + Column("AusschreibungZuschlag", Boolean), + Column("AnlagenkennzifferAnlagenregister", String), + Column("AnlagenkennzifferAnlagenregister_nv", String), + Column("Netzbetreiberzuordnungen", String), + Column("DatenQuelle", String), + Column("DatumDownload", DateTime), + ) + # We must recreate the table to be sure that no other data is present. + table.drop(engine_testdb, checkfirst=True) + table.create(engine_testdb) df = pd.DataFrame( { @@ -369,10 +375,10 @@ def test_add_table_to_sqlite_database(engine_testdb, add_table_to_database_funct ) expected_df = pd.DataFrame( { + "EegMastrNummer": ["id1", "id2"], "InstallierteLeistung": [1.0, 100.4], "AnlageBetriebsstatus": [None, None], "Registrierungsdatum": [datetime(2022, 2, 2), datetime(2024, 3, 20)], - "EegMastrNummer": ["id1", "id2"], "Meldedatum": [np.datetime64("NaT"), np.datetime64("NaT")], "DatumLetzteAktualisierung": [ datetime(2022, 12, 2, 10, 10, 10, 300), @@ -391,12 +397,12 @@ def test_add_table_to_sqlite_database(engine_testdb, add_table_to_database_funct ) add_table_to_database_function( - df, "anlageneeggeothermiegrubengasdruckentspannung", "gsgk_eeg", engine_testdb + df, table, engine_testdb ) with engine_testdb.connect() as con: with con.begin(): pd.testing.assert_frame_equal( - expected_df, pd.read_sql_table("gsgk_eeg", con=con) + expected_df, pd.read_sql_table(table.name, con=con) )