diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 6e37a578..f22ed637 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.15.0 +current_version = 0.16.0 parse = (?P\d+)\.(?P\d+)\.(?P\d+)((?P(a|na))+(?P\d+))? serialize = {major}.{minor}.{patch}{release}{build} diff --git a/.github/workflows/ci-develop.yml b/.github/workflows/ci-develop.yml index 6a7a6457..5794d864 100644 --- a/.github/workflows/ci-develop.yml +++ b/.github/workflows/ci-develop.yml @@ -13,6 +13,7 @@ jobs: runs-on: ${{ matrix.os }} if: ${{ !github.event.pull_request.draft }} strategy: + max-parallel: 1 matrix: os: [macos-latest, ubuntu-latest, windows-latest] python-version: ['3.10', '3.11', '3.12'] diff --git a/.github/workflows/ci-production.yml b/.github/workflows/ci-production.yml index 16065860..ce18c2e9 100644 --- a/.github/workflows/ci-production.yml +++ b/.github/workflows/ci-production.yml @@ -13,6 +13,7 @@ jobs: runs-on: ${{ matrix.os }} if: ${{ !github.event.pull_request.draft }} strategy: + max-parallel: 1 matrix: os: [macos-latest, ubuntu-latest, windows-latest] python-version: ['3.10', '3.11', '3.12'] @@ -32,7 +33,7 @@ jobs: - name: create package run: python -m build --sdist - name: import open-mastr - run: python -m pip install ./dist/open_mastr-0.15.0.tar.gz + run: python -m pip install ./dist/open_mastr-0.16.0.tar.gz - name: Create credentials file env: MASTR_TOKEN: ${{ secrets.MASTR_TOKEN }} diff --git a/.github/workflows/test-pypi-publish.yml b/.github/workflows/test-pypi-publish.yml index 5fae8627..83c19222 100644 --- a/.github/workflows/test-pypi-publish.yml +++ b/.github/workflows/test-pypi-publish.yml @@ -13,8 +13,6 @@ jobs: environment: pypi-publish steps: - uses: actions/checkout@v4 - with: - ref: release - name: Set up Python 3.10 uses: actions/setup-python@v3 with: diff --git a/CHANGELOG.md b/CHANGELOG.md index 1dbe5e5b..a6382314 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,9 +6,25 @@ For each version important 
additions, changes and removals are listed here. The format is inspired from [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and the versioning aims to respect [Semantic Versioning](http://semver.org/spec/v2.0.0.html). -## [v0.XX.X] unreleased - 202X-XX-XX +## [v0.16.0] Partial downloads with open-MaStR PartialPumpkinPull - 2025-11-26 ### Added +- Add partial bulk download + [#652](https://github.com/OpenEnergyPlatform/open-MaStR/pull/652) ### Changed +- Updates the system_catalog dict with missing Einheittyp values + [#653](https://github.com/OpenEnergyPlatform/open-MaStR/pull/653) +- Fix package publication workflow + [#636](https://github.com/OpenEnergyPlatform/open-MaStR/pull/636) +- Change print statement about data cleansing + [#650](https://github.com/OpenEnergyPlatform/open-MaStR/pull/650) +- Improve logging + [#666](https://github.com/OpenEnergyPlatform/open-MaStR/pull/666) +- Several improvements in bulk download: Support retaining old zip bulk files; + Prevent zip file deletion on full download; Add technology checks to full + bulk download + [#667](https://github.com/OpenEnergyPlatform/open-MaStR/pull/667) +- Limit number of parallel CI jobs + [#669](https://github.com/OpenEnergyPlatform/open-MaStR/pull/669) ### Removed @@ -35,7 +51,7 @@ and the versioning aims to respect [Semantic Versioning](http://semver.org/spec/ [#621](https://github.com/OpenEnergyPlatform/open-MaStR/pull/621) ### Removed - Moved old code artefacts from `scripts` folder to paper specific - [repository](https://github.com/FlorianK13/verify-marktstammdaten) + [repository](https://github.com/FlorianK13/verify-marktstammdaten) [#561](https://github.com/OpenEnergyPlatform/open-MaStR/pull/561) - Remove old dependencies and broken README links [#619](https://github.com/OpenEnergyPlatform/open-MaStR/pull/619) diff --git a/CITATION.cff b/CITATION.cff index d2fe6752..d496ecf2 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -34,10 +34,14 @@ authors: given-names: "Alexandra-Andreea" 
alias: "@AlexandraImbrisca" affiliation: "Technical University of Munich" + - family-names: 'Krämer' + given-names: "Kevin" + alias: "pt-kkraemer" + affiliation: "ProjectTogether gGmbH" title: "open-MaStR" type: software license: AGPL-3.0 -version: 0.15.0 +version: 0.16.0 doi: -date-released: 2025-04-19 +date-released: 2025-11-26 url: "https://github.com/OpenEnergyPlatform/open-MaStR/" diff --git a/README.rst b/README.rst index 9bbc4dbb..80250d3a 100644 --- a/README.rst +++ b/README.rst @@ -108,7 +108,6 @@ These projects already use open-mastr: - `Wasserstoffatlas `_ - `EE-Status App `_ - `Digiplan Anhalt `_ -- `Data Quality Assessment of the MaStR `_ - `EmPowerPlan `_ - `Goal100 Monitor `_ @@ -119,7 +118,7 @@ changes in a `Pull Request `_. + - The `bundesAPI/Marktstammdaten-API `_ is another implementation to access data via an official API. Collaboration @@ -146,7 +145,7 @@ Data .. |badge_license| image:: https://img.shields.io/github/license/OpenEnergyPlatform/open-MaStR - :target: LICENSE.txt + :target: LICENSE.md :alt: License .. |badge_rtd| image:: https://readthedocs.org/projects/open-mastr/badge/?style=flat diff --git a/docs/advanced.md b/docs/advanced.md index ec4ffd39..e9471f1a 100644 --- a/docs/advanced.md +++ b/docs/advanced.md @@ -63,7 +63,7 @@ The project home directory is structured as follows (files and folders below `da File names are defined here. * `logging.yml`
Logging configuration. For changing the log level to increase or decrease details of log - messages, edit the level of the handlers. + messages, edit the level of the handlers. See below for details on logging. * **data** * `dataversion-`
Contains exported data as csv files from method [`to_csv`][open_mastr.Mastr.to_csv] @@ -83,6 +83,19 @@ The project home directory is structured as follows (files and folders below `da For the download via the API, logs are stored in a single file in `/$HOME//.open-MaStR/logs/open_mastr.log`. New logging messages are appended. It is recommended to delete the log file from time to time because of its required disk space. +By default, the log level is set to `INFO`. You can increase or decrease the verbosity by either changing `logging.yml` (see above) +or adjusting it manually in your code. E.g. to enable `DEBUG` messages in `open_mastr.log` you can use the following snippet: + +```python + + import logging + from open_mastr import Mastr + + # Increase to DEBUG to show more details in open_mastr.log + # Must be called after importing open_mastr to have the open-MaStR logger imported + logging.getLogger("open-MaStR").setLevel(logging.DEBUG) +``` + ### Data @@ -148,8 +161,11 @@ If needed, the tables in the database can be obtained as csv files. Those files === "Disadvantages" * No single tables or entries can be downloaded - * Download takes long time + * Download takes long time (you can use the partial download though, see [Getting Started](getting_started.md#bulk-download)) +**Note**: By default, existing zip files in `$HOME/.open-MaStR/data/xml_download` are deleted when a new file is +downloaded. You can change this behavior by setting `keep_old_downloads`=True in +[`Mastr.download()`][open_mastr.Mastr.download]. ## SOAP API download diff --git a/docs/getting_started.md b/docs/getting_started.md index 5a3dd671..891efbbe 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -35,7 +35,16 @@ db = Mastr() db.download() ``` -When a `Mastr` object is initialized, a sqlite database is created in `$HOME/.open-MaStR/data/sqlite`. With the function `Mastr.download()`, the **whole MaStR is downloaded** in the zipped xml file format. 
It is then read into the sqlite database and simple data cleansing functions are started. +When a `Mastr` object is initialized, a sqlite database is created in `$HOME/.open-MaStR/data/sqlite`. With the function [`Mastr.download()`][open_mastr.Mastr.download], the **whole MaStR is downloaded** in the zipped xml file format. It is then read into the sqlite database and simple data cleansing functions are started. + +If you are interested in a specific part of the dataset, you can specify this by using the `data` parameter: + +```python +from open_mastr import Mastr + +db = Mastr() +db.download(data=["wind","hydro"]) +``` More detailed information can be found in the section [bulk download](advanced.md#bulk-download). diff --git a/environment.yml b/environment.yml index 104130fe..66f2bd03 100644 --- a/environment.yml +++ b/environment.yml @@ -3,4 +3,4 @@ channels: - conda-forge - defaults dependencies: - - python=3.10 + - python=3.11 diff --git a/open_mastr/mastr.py b/open_mastr/mastr.py index 646a50b1..1be39e2f 100644 --- a/open_mastr/mastr.py +++ b/open_mastr/mastr.py @@ -2,7 +2,10 @@ from sqlalchemy import inspect, create_engine # import xml dependencies -from open_mastr.xml_download.utils_download_bulk import download_xml_Mastr +from open_mastr.xml_download.utils_download_bulk import ( + download_xml_Mastr, + delete_xml_files_not_from_given_date, +) from open_mastr.xml_download.utils_write_to_database import ( write_mastr_xml_to_database, ) @@ -23,6 +26,10 @@ create_db_query, db_query_to_csv, reverse_fill_basic_units, + delete_zip_file_if_corrupted, + create_database_engine, + rename_table, + create_translated_database_engine, ) from open_mastr.utils.config import ( create_data_dir, @@ -33,13 +40,6 @@ ) import open_mastr.utils.orm as orm -# import initialize_database dependencies -from open_mastr.utils.helpers import ( - create_database_engine, - rename_table, - create_translated_database_engine, -) - # constants from open_mastr.utils.constants import 
TECHNOLOGIES, ADDITIONAL_TABLES @@ -92,7 +92,10 @@ def __init__(self, engine="sqlite", connect_to_translated_db=False) -> None: else: self.engine = create_database_engine(engine, self._sqlite_folder_path) - print( + log.info( + "\n==================================================\n" + "---------> open-MaStR started <---------\n" + "==================================================\n" f"Data will be written to the following database: {self.engine.url}\n" "If you run into problems, try to " "delete the database and update the package by running " @@ -107,6 +110,7 @@ def download( data=None, date=None, bulk_cleansing=True, + keep_old_downloads: bool = False, api_processes=None, api_limit=50, api_chunksize=1000, @@ -126,8 +130,8 @@ def download( from marktstammdatenregister.de, (see :ref:`Configuration `). Default to 'bulk'. data : str or list or None, optional - Determines which types of data are written to the database. If None, all data is - used. If it is a list, possible entries are listed below with respect to the download method. Missing categories are + Determines which data is partially downloaded from the bulk download and written to the database. If None, all data is downloaded and written to the database. + If it is a list, possible entries are listed below with respect to the download method. Missing categories are being developed. If only one data is of interest, this can be given as a string. Default to None, where all data is included. | Data | Bulk | API | @@ -157,7 +161,7 @@ def download( |-----------------------|------|------| | "today" | latest files are downloaded from marktstammdatenregister.de | - | | "20230101" | If file from this date exists locally, it is used. 
Otherwise it throws an error (You can only receive todays data from the server) | - | - | "existing" | Use latest downloaded zipped xml files, throws an error if the bulk download folder is empty | - | + | "existing" | Deprecated since 0.16, see [#616](https://github.com/OpenEnergyPlatform/open-MaStR/issues/616#issuecomment-3089377062) | - | | "latest" | - | Retrieve data that is newer than the newest data already in the table | | datetime.datetime(2020, 11, 27) | - | Retrieve data that is newer than this time stamp | | None | set date="today" | set date="latest" | @@ -168,6 +172,8 @@ def download( In its original format, many entries in the MaStR are encoded with IDs. Columns like `state` or `fueltype` do not contain entries such as "Hessen" or "Braunkohle", but instead only contain IDs. Cleansing replaces these IDs with their corresponding original entries. + keep_old_downloads: bool + If set to True, prior downloaded MaStR zip files will be kept. api_processes : int or None or "max", optional Number of parallel processes used to download additional data. Defaults to `None`. 
If set to "max", the maximum number of possible processes @@ -233,12 +239,20 @@ def download( xml_folder_path, f"Gesamtdatenexport_{bulk_download_date}.zip", ) - download_xml_Mastr(zipped_xml_file_path, date, xml_folder_path) - print( - f"\nWould you like to speed up the bulk download?\n" - f"Try our new parallelized processing by setting os.environ['USE_RECOMMENDED_NUMBER_OF_PROCESSES'] = True " - f"or configure your own number of processes via os.environ['NUMBER_OF_PROCESSES'] = your_number\n" + delete_zip_file_if_corrupted(zipped_xml_file_path) + if not keep_old_downloads: + delete_xml_files_not_from_given_date( + zipped_xml_file_path, + xml_folder_path, + ) + + download_xml_Mastr(zipped_xml_file_path, date, data, xml_folder_path) + + log.info( + "\nWould you like to speed up the creation of your MaStR database?\n" + "Try our new parallelized processing by setting os.environ['USE_RECOMMENDED_NUMBER_OF_PROCESSES'] = True " + "or configure your own number of processes via os.environ['NUMBER_OF_PROCESSES'] = your_number\n" ) write_mastr_xml_to_database( @@ -255,8 +269,8 @@ def download( # Set api_processes to None in order to avoid the malfunctioning usage if api_processes: api_processes = None - print( - "Warning: The implementation of parallel processes " + log.warning( + "The implementation of parallel processes " "is currently under construction. Please let " "the argument api_processes at the default value None." 
) @@ -425,9 +439,11 @@ def translate(self) -> None: try: os.remove(new_path) except Exception as e: - print(f"An error occurred: {e}") + log.error( + f"An error occurred while removing old translated database: {e}" + ) - print("Replacing previous version of the translated database...") + log.info("Replacing previous version of the translated database...") for table in inspector.get_table_names(): rename_table(table, inspector.get_columns(table), self.engine) @@ -436,9 +452,9 @@ def translate(self) -> None: try: os.rename(old_path, new_path) - print(f"Database '{old_path}' changed to '{new_path}'") + log.info(f"Database '{old_path}' changed to '{new_path}'") except Exception as e: - print(f"An error occurred: {e}") + log.error(f"An error occurred while renaming database: {e}") self.engine = create_engine(f"sqlite:///{new_path}") self.is_translated = True diff --git a/open_mastr/soap_api/metadata/description.py b/open_mastr/soap_api/metadata/description.py index a4986959..728aec23 100644 --- a/open_mastr/soap_api/metadata/description.py +++ b/open_mastr/soap_api/metadata/description.py @@ -1,10 +1,13 @@ from io import BytesIO +import logging import re from urllib.request import urlopen from zipfile import ZipFile import xmltodict from collections import OrderedDict +log = logging.getLogger(__name__) + class DataDescription(object): """ @@ -150,9 +153,11 @@ def functions_data_documentation(self): fcn["sequence"]["element"]["@type"].split(":")[1] ]["sequence"]["element"] else: - print(type(fcn["sequence"])) - print(fcn["sequence"]) - raise ValueError + log.error(f"Unexpected sequence type: {type(fcn['sequence'])}") + log.error(f"Sequence content: {fcn['sequence']}") + raise ValueError( + f"Unexpected sequence structure in function metadata" + ) # Add data for inherited columns from base types if "@base" in fcn: diff --git a/open_mastr/utils/config/logging.yml b/open_mastr/utils/config/logging.yml index 64a5ac75..c1b4c29b 100644 --- a/open_mastr/utils/config/logging.yml 
+++ b/open_mastr/utils/config/logging.yml @@ -4,6 +4,8 @@ disable_existing_loggers: False formatters: standard: format: "%(asctime)s [%(levelname)s] %(message)s" + debug: + format: "%(asctime)s [%(levelname)s] %(name)s:%(funcName)s:%(lineno)d - %(message)s" handlers: console: @@ -12,14 +14,13 @@ handlers: class: "logging.StreamHandler" stream: "ext://sys.stdout" file: - class: "logging.FileHandler" level: "DEBUG" - formatter: "standard" + formatter: "debug" + class: "logging.FileHandler" mode: "a" -root: - level: "DEBUG" - loggers: open-MaStR: + level: "INFO" handlers: ["console", "file"] + propagate: no diff --git a/open_mastr/utils/helpers.py b/open_mastr/utils/helpers.py index ad4f4dd8..9ac2492b 100644 --- a/open_mastr/utils/helpers.py +++ b/open_mastr/utils/helpers.py @@ -4,6 +4,7 @@ from contextlib import contextmanager from datetime import date, datetime from warnings import warn +from zipfile import BadZipfile, ZipFile import dateutil import sqlalchemy @@ -123,7 +124,7 @@ def validate_parameter_format_for_download_method( def validate_parameter_method(method) -> None: if method not in ["bulk", "API"]: - raise ValueError("parameter method has to be either 'bulk' or 'API'.") + raise ValueError("parameter method has to be either 'bulk', or 'API'.") def validate_parameter_api_location_types(api_location_types) -> None: @@ -244,18 +245,16 @@ def raise_warning_for_invalid_parameter_combinations( ) if method == "bulk" and ( - ( - any( - parameter is not None - for parameter in [ - api_processes, - api_data_types, - api_location_types, - ] - ) - or api_limit != 50 - or api_chunksize != 1000 + any( + parameter is not None + for parameter in [ + api_processes, + api_data_types, + api_location_types, + ] ) + or api_limit != 50 + or api_chunksize != 1000 ): warn( "For method = 'bulk', API related parameters (with prefix api_) are ignored." 
@@ -302,20 +301,17 @@ def transform_date_parameter(self, method, date, **kwargs): date = kwargs.get("bulk_date", date) date = "today" if date is None else date if date == "existing": - existing_files_list = os.listdir( - os.path.join(self.output_dir, "data", "xml_download") + log.warning( + """ + The date parameter 'existing' is deprecated and will be removed in the future. + The date parameter is set to `today`. + + If this change causes problems for you, please comment in this issue on github: + https://github.com/OpenEnergyPlatform/open-MaStR/issues/616#issuecomment-3089377062 + + """ ) - if not existing_files_list: - date = "today" - print( - "By choosing `date`='existing' you want to use an existing " - "xml download." - "However no xml_files were downloaded yet. The parameter `date` is" - "therefore set to 'today'." - ) - # we assume that there is only one file in the folder which is the - # zipped xml folder - date = existing_files_list[0].split("_")[1].split(".")[0] + date = "today" elif method == "API": date = kwargs.get("api_date", date) @@ -347,37 +343,32 @@ def print_api_settings( api_processes, api_location_types, ): - print( + log.info( f"Downloading with soap_API.\n\n -- API settings -- \nunits after date: " f"{date}\nunit download limit per data: " f"{api_limit}\nparallel_processes: {api_processes}\nchunksize: " f"{api_chunksize}\ndata_api: {data}" ) if "permit" in harmonisation_log: - print( - f"data_types: {api_data_types}" "\033[31m", + log.warning( + f"data_types: {api_data_types} - " "Attention, 'permit_data' was automatically set in api_data_types, " - "as you defined 'permit' in parameter data_api.", - "\033[m", + "as you defined 'permit' in parameter data_api." ) else: - print(f"data_types: {api_data_types}") + log.info(f"data_types: {api_data_types}") if "location" in harmonisation_log: - print( - "location_types:", - "\033[31m", - "Attention, 'location' is in parameter data. 
location_types are set to", - "\033[m", - f"{api_location_types}" - "\n If you want to change location_types, please remove 'location' " + log.warning( + f"location_types: {api_location_types} - " + "Attention, 'location' is in parameter data. location_types are set accordingly. " + "If you want to change location_types, please remove 'location' " "from data_api and specify api_location_types." - "\n ------------------ \n", ) else: - print( + log.info( f"location_types: {api_location_types}", "\n ------------------ \n", ) @@ -480,9 +471,7 @@ def create_db_query( unit_type_map_reversed = reverse_unit_type_map() with session_scope(engine=engine) as session: - if tech: - # Select orm tables for specified additional_data. orm_tables = { f"{dat}": getattr(orm, ORM_MAP[tech].get(dat, "KeyNotAvailable"), None) @@ -553,7 +542,6 @@ def create_db_query( return query_tech if additional_table: - orm_table = getattr(orm, ORM_MAP[additional_table], None) query_additional_tables = Query(orm_table, session=session) @@ -741,7 +729,6 @@ def db_query_to_csv(db_query, data_table: str, chunksize: int) -> None: chunk_df[col] = chunk_df[col].str.replace("\r", "") if not chunk_df.empty: - if chunk_number == 0: chunk_df.to_csv( csv_file, @@ -810,3 +797,16 @@ def create_translated_database_engine(engine, folder_path) -> sqlalchemy.engine. ) return create_engine(f"sqlite:///{db_path}") + + +def delete_zip_file_if_corrupted(save_path: str): + """ + Check if existing zip file is corrupted and if yes, delete it, if no, zipfile exists. 
+ """ + if os.path.exists(save_path): + try: + with ZipFile(save_path) as _: + pass + except BadZipfile: + log.info(f"Bad Zip file is deleted: {save_path}") + os.remove(save_path) diff --git a/open_mastr/utils/unzip_http.py b/open_mastr/utils/unzip_http.py new file mode 100644 index 00000000..0674e130 --- /dev/null +++ b/open_mastr/utils/unzip_http.py @@ -0,0 +1,448 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2022 Saul Pwanson +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# Originally from +# https://github.com/saulpw/unzip-http +# Adjusted for our use case + +""" +usage: unzip_http [-h] [-l] [-f] [-o] url [files ...] + +Extract individual files from .zip files over http without downloading the +entire archive. HTTP server must send `Accept-Ranges: bytes` and +`Content-Length` in headers. + +positional arguments: + url URL of the remote zip file + files Files to extract. If no filenames given, displays .zip + contents (filenames and sizes). 
Each filename can be a + wildcard glob. + +options: + -h, --help show this help message and exit + -l, --list List files in the remote zip file + -f, --full-filepaths Recreate folder structure from zip file when extracting + (instead of extracting the files to the current + directory) + -o, --stdout Write files to stdout (if multiple files: concatenate + them to stdout, in zipfile order) +""" + +import sys +import os +import io +import math +import time +import zlib +import struct +import fnmatch +import pathlib +import urllib.parse +import zipfile +import logging + +log = logging.getLogger(__name__) + +__version__ = "0.6" + + +def error(s): + raise Exception(s) + + +def warning(s): + log.warning(s) + + +def get_bits(val: int, *args): + "Generate bitfields (one for each arg) from LSB to MSB." + for n in args: + x = val & (2**n - 1) + val >>= n + yield x + + +class RemoteZipInfo: + def __init__( + self, + filename: str = "", + date_time: int = 0, + header_offset: int = 0, + compress_type: int = 0, + compress_size: int = 0, + file_size: int = 0, + ): + self.filename = filename + self.header_offset = header_offset + self.compress_type = compress_type + self.compress_size = compress_size + self.file_size = file_size + + sec, mins, hour, day, mon, year = get_bits(date_time, 5, 6, 5, 5, 4, 7) + self.date_time = (year + 1980, mon, day, hour, mins, sec) + + def is_dir(self): + return self.filename.endswith("/") + + def parse_extra(self, extra): + i = 0 + while i < len(extra): + fieldid, fieldsz = struct.unpack_from("= 0: + ( + magic, + eocd_sz, + create_ver, + min_ver, + disk_num, + disk_start, + disk_num_records, + total_num_records, + cdir_bytes, + cdir_start, + ) = struct.unpack_from(self.fmt_eocd64, resp.data, offset=i) + else: + i = resp.data.rfind(self.magic_eocd) + if i >= 0: + ( + magic, + disk_num, + disk_start, + disk_num_records, + total_num_records, + cdir_bytes, + cdir_start, + comment_len, + ) = struct.unpack_from(self.fmt_eocd, resp.data, offset=i) + + if 
cdir_start < 0 or cdir_start >= self.zip_size: + error("cannot find central directory") + + if self.zip_size <= 65536: + filehdr_index = cdir_start + else: + filehdr_index = 65536 - (self.zip_size - cdir_start) + + if filehdr_index < 0: + resp = self.get_range(cdir_start, self.zip_size - cdir_start) + filehdr_index = 0 + + cdir_end = filehdr_index + cdir_bytes + while filehdr_index < cdir_end: + sizeof_cdirentry = struct.calcsize(self.fmt_cdirentry) + + ( + magic, + ver, + ver_needed, + flags, + method, + date_time, + crc, + complen, + uncomplen, + fnlen, + extralen, + commentlen, + disknum_start, + internal_attr, + external_attr, + local_header_ofs, + ) = struct.unpack_from(self.fmt_cdirentry, resp.data, offset=filehdr_index) + + filehdr_index += sizeof_cdirentry + + filename = resp.data[filehdr_index : filehdr_index + fnlen] + filehdr_index += fnlen + + extra = resp.data[filehdr_index : filehdr_index + extralen] + filehdr_index += extralen + + # comment = resp.data[filehdr_index:filehdr_index+commentlen] + filehdr_index += commentlen + + rzi = RemoteZipInfo( + filename.decode(), + date_time, + local_header_ofs, + method, + complen, + uncomplen, + ) + + rzi.parse_extra(extra) + yield rzi + + def extract(self, member, path=None, pwd=None): + if pwd: + raise NotImplementedError("Passwords not supported yet") + + path = path or pathlib.Path(".") + + outpath = path / member + os.makedirs(outpath.parent, exist_ok=True) + with self.open(member) as fpin: + with open(path / member, mode="wb") as fpout: + while True: + r = fpin.read(65536) + if not r: + break + fpout.write(r) + + def extractzip(self, member, path=None, pwd=None): + if pwd: + raise NotImplementedError("Passwords not supported yet") + + path = path or pathlib.Path(".") + outpath = path + os.makedirs(outpath.parent, exist_ok=True) + with self.open(member) as fpin: + with zipfile.ZipFile(outpath, "a", zipfile.ZIP_DEFLATED) as zout: + with zout.open(member, "w") as fpout: + while True: + r = fpin.read(65536) + 
if not r: + break + fpout.write(r) + + def extractall(self, path=None, members=None, pwd=None): + for fn in members or self.namelist(): + self.extract(fn, path, pwd=pwd) + + def get_range(self, start, n): + return self.http.request( + "GET", + self.url, + headers={"Range": f"bytes={start}-{start+n-1}"}, + preload_content=False, + ) + + def matching_files(self, *globs): + for f in self.files.values(): + if any(fnmatch.fnmatch(f.filename, g) for g in globs): + yield f + + def open(self, fn): + if isinstance(fn, str): + f = list(self.matching_files(fn)) + if not f: + error(f"no files matching {fn}") + f = f[0] + else: + f = fn + + sizeof_localhdr = struct.calcsize(self.fmt_localhdr) + r = self.get_range(f.header_offset, sizeof_localhdr) + localhdr = struct.unpack_from(self.fmt_localhdr, r.data) + ( + magic, + ver, + flags, + method, + dos_datetime, + _, + _, + uncomplen, + fnlen, + extralen, + ) = localhdr + if method == 0: # none + return self.get_range( + f.header_offset + sizeof_localhdr + fnlen + extralen, f.compress_size + ) + elif method == 8: # DEFLATE + resp = self.get_range( + f.header_offset + sizeof_localhdr + fnlen + extralen, f.compress_size + ) + return io.BufferedReader(RemoteZipStream(resp, f)) + else: + error(f"unknown compression method {method}") + + def open_text(self, fn): + return io.TextIOWrapper(self.open(fn)) + + +class RemoteZipStream(io.RawIOBase): + def __init__(self, fp, info): + super().__init__() + self.raw = fp + self._decompressor = zlib.decompressobj(-15) + self._buffer = bytes() + + def readable(self): + return True + + def readinto(self, b): + r = self.read(len(b)) + b[: len(r)] = r + return len(r) + + def read(self, n): + while n > len(self._buffer): + r = self.raw.read(2**18) + if not r: + self._buffer += self._decompressor.flush() + break + self._buffer += self._decompressor.decompress(r) + + ret = self._buffer[:n] + self._buffer = self._buffer[n:] + + return ret + + +### script start + + +class StreamProgress: + def 
__init__(self, fp, name="", total=0): + self.name = name + self.fp = fp + self.total = total + self.start_time = time.time() + self.last_update = 0 + self.amtread = 0 + + def read(self, n): + r = self.fp.read(n) + self.amtread += len(r) + now = time.time() + if now - self.last_update > 0.1: + self.last_update = now + + elapsed_s = now - self.start_time + sys.stderr.write( + f"\r{elapsed_s:.0f}s {self.amtread/10**6:.02f}/{self.total/10**6:.02f}MB ({self.amtread/10**6/elapsed_s:.02f} MB/s) {self.name}" + ) + + if not r: + sys.stderr.write("\n") + + return r + + +def list_files(rzf): + def safelog(x): + return 1 if x == 0 else math.ceil(math.log10(x)) + + digits_compr = max(safelog(f.compress_size) for f in rzf.infolist()) + digits_plain = max(safelog(f.file_size) for f in rzf.infolist()) + fmtstr = f"%{digits_compr}d -> %{digits_plain}d\t%s" + for f in rzf.infolist(): + log.info(fmtstr % (f.compress_size, f.file_size, f.filename)) + + +def extract_one(outfile, rzf, f, ofname): + log.info(f"Extracting {f.filename} to {ofname}...") + + fp = StreamProgress(rzf.open(f), name=f.filename, total=f.compress_size) + while r := fp.read(2**18): + outfile.write(r) + + +def download_file(f, rzf, args): + if not any(fnmatch.fnmatch(f.filename, g) for g in args.files): + return + + if args.stdout: + extract_one(sys.stdout.buffer, rzf, f, "stdout") + else: + path = pathlib.Path(f.filename) + if args.full_filepaths: + path.parent.mkdir(parents=True, exist_ok=True) + else: + path = path.name + + with open(str(path), "wb") as of: + extract_one(of, rzf, f, str(path)) diff --git a/open_mastr/xml_download/colums_to_replace.py b/open_mastr/xml_download/colums_to_replace.py index 8e6ead17..421ac44c 100644 --- a/open_mastr/xml_download/colums_to_replace.py +++ b/open_mastr/xml_download/colums_to_replace.py @@ -23,6 +23,20 @@ 3: "Gaserzeugungslokation", 4: "Gasverbrauchslokation", }, + "Einheittyp": { + 1: "Solareinheit", + 2: "Windeinheit", + 3: "Biomasse", + 4: "Wasser", + 5: "Geothermie", 
+ 6: "Verbrennung", + 7: "Kernenergie", + 8: "Stromspeichereinheit", + 9: "Stromverbrauchseinheit", + 10: "Gasverbrauchseinheit", + 11: "Gaserzeugungseinheit", + 12: "Gasspeichereinheit", + }, } # columns to replace lists all columns where the entries have diff --git a/open_mastr/xml_download/utils_download_bulk.py b/open_mastr/xml_download/utils_download_bulk.py index 02e69e84..a8d37ae3 100644 --- a/open_mastr/xml_download/utils_download_bulk.py +++ b/open_mastr/xml_download/utils_download_bulk.py @@ -2,7 +2,8 @@ import shutil import time from importlib.metadata import PackageNotFoundError, version -from zipfile import BadZipfile, ZipFile +from zipfile import ZipFile +from pathlib import Path import numpy as np import requests @@ -10,6 +11,8 @@ # setup logger from open_mastr.utils.config import setup_logger +from open_mastr.utils.constants import BULK_INCLUDE_TABLES_MAP, BULK_DATA +from open_mastr.utils import unzip_http try: USER_AGENT = ( @@ -113,44 +116,25 @@ def gen_url(when: time.struct_time = time.localtime(), use_version="current") -> def download_xml_Mastr( - save_path: str, bulk_date_string: str, xml_folder_path: str + save_path: str, bulk_date_string: str, bulk_data_list: list, xml_folder_path: str ) -> None: """Downloads the zipped MaStR. Parameters ----------- save_path: str - The path where the downloaded MaStR zipped folder will be saved. + Full file path where the downloaded MaStR zip file will be saved. + bulk_date_string: str + Date for which the file should be downloaded. + bulk_data_list: list + List of tables/technologis to be downloaded. + xml_folder_path: str + Path where the downloaded MaStR zip file will be saved. """ - if os.path.exists(save_path): - try: - _ = ZipFile(save_path) - except BadZipfile: - log.info(f"Bad Zip file is deleted: {save_path}") - os.remove(save_path) - else: - print("MaStR already downloaded.") - return None - - if bulk_date_string != "today": - raise OSError( - "There exists no file for given date. 
MaStR can only be downloaded " - "from the website if today's date is given." - ) - shutil.rmtree(xml_folder_path, ignore_errors=True) - os.makedirs(xml_folder_path, exist_ok=True) - - print_message = ( - "Download has started, this can take several minutes." - "The download bar is only a rough estimate." - ) - warning_message = ( - "Warning: The servers from MaStR restrict the download speed." - " You may want to download it another time." - ) - print(print_message) + log.info("Starting the Download from marktstammdatenregister.de.") + # TODO this should take bulk_date_string now = time.localtime() url = gen_url(now) @@ -182,10 +166,165 @@ def download_xml_Mastr( log.error("Could not download file: download URL not found") return - total_length = int(18000 * 1024 * 1024) + if bulk_data_list == BULK_DATA: + full_download_without_unzip_http(save_path, r, bulk_data_list) + else: + try: + partial_download_with_unzip_http(save_path, url, bulk_data_list) + except Exception as e: + log.warning(f"Partial download failed, fallback to full download: {e}") + full_download_without_unzip_http(save_path, r, bulk_data_list) + + time_b = time.perf_counter() + log.info( + f"Download is finished. It took {int(np.around(time_b - time_a))} seconds." 
+ ) + log.info(f"MaStR was successfully downloaded to {xml_folder_path}.") + + +def check_download_completeness( + save_path: str, bulk_data_list: list +) -> tuple[list, bool]: + """Checks if an existing download contains the xml-files corresponding to the bulk_data_list.""" + with ZipFile(save_path, "r") as zip_ref: + existing_files = [ + zip_name.lower().split("_")[0].split(".")[0] + for zip_name in zip_ref.namelist() + ] + + missing_data_set = set() + for bulk_data_name in bulk_data_list: + for bulk_file_name in BULK_INCLUDE_TABLES_MAP[bulk_data_name]: + if bulk_file_name not in existing_files: + missing_data_set.add(bulk_data_name) + + is_katalogwerte_existing = False + if "katalogwerte" in existing_files: + is_katalogwerte_existing = True + return list(missing_data_set), is_katalogwerte_existing + + +def delete_xml_files_not_from_given_date( + save_path: str, + xml_folder_path: str, +) -> None: + """ + Delete xml files that are not corresponding to the given date. + Assumes that the xml folder only contains one zipfile. + + Parameters + ---------- + save_path: str + Full file path where the downloaded MaStR zip file will be saved. + xml_folder_path: str + Path where the downloaded MaStR zip file will be saved. + """ + if os.path.exists(save_path): + return + else: + shutil.rmtree(xml_folder_path) + os.makedirs(xml_folder_path) + + +def partial_download_with_unzip_http(save_path: str, url: str, bulk_data_list: list): + """ + + Parameters + ---------- + save_path: str + Full file path where the downloaded MaStR zip file will be saved. + url: str + URL path to bulk file. + bulk_data_list: list + List of tables/technologies to be downloaded. 
+ + Returns + ------- + None + """ + is_katalogwerte_existing = False + if os.path.exists(save_path): + bulk_data_list, is_katalogwerte_existing = check_download_completeness( + save_path, bulk_data_list + ) + if bool(bulk_data_list): + log.info( + f"MaStR file already present but missing the following data: {bulk_data_list}" + ) + else: + log.info(f"MaStR file already present: {save_path}") + return None + + remote_zip_file = unzip_http.RemoteZipFile(url) + remote_zip_names = [ + remote_zip_name.lower().split("_")[0].split(".")[0] + for remote_zip_name in remote_zip_file.namelist() + ] + + remote_index_list = [] + download_files_list = [] + for bulk_data_name in bulk_data_list: + # Example: ['wind','solar'] + for bulk_file_name in BULK_INCLUDE_TABLES_MAP[bulk_data_name]: + # Example: From "wind" we get ["anlageneegwind", "einheitenwind"], and from "solar" we get ["anlageneegsolar", "einheitensolar"] + # and we have to find the corresponding index in the remote_zip_file list in order to fetch the correct file + remote_index_list = [ + remote_index + for remote_index, remote_zip_name in enumerate(remote_zip_names) + if remote_zip_name == bulk_file_name + ] + # for remote_index in tqdm(remote_index_list): + for remote_index in remote_index_list: + # Example: remote_zip_file.namelist()[remote_index] corresponds to e.g. 'AnlagenEegSolar_1.xml' + download_files_list.append(remote_zip_file.namelist()[remote_index]) + + for zipfile_name in tqdm(download_files_list, unit=" file"): + remote_zip_file.extractzip(zipfile_name, path=Path(save_path)) + + if not is_katalogwerte_existing: + remote_zip_file.extractzip("Katalogwerte.xml", path=Path(save_path)) + + +def full_download_without_unzip_http( + save_path: str, + r: requests.models.Response, + bulk_data_list: list, +) -> None: + """ + + Parameters + ---------- + save_path: str + Full file path where the downloaded MaStR zip file will be saved. + r: requests.models.Response + Response from making a request to MaStR. 
+ bulk_data_list: list + List of tables/technologies to be downloaded. + + Returns + ------- + None + """ + if os.path.exists(save_path): + bulk_data_list, is_katalogwerte_existing = check_download_completeness( + save_path, bulk_data_list + ) + if bool(bulk_data_list): + print( + f"MaStR file already present but missing the following data: {bulk_data_list}" + ) + else: + print(f"MaStR file already present: {save_path}") + return None + + warning_message = ( + "Warning: The servers from MaStR restrict the download speed." + " You may want to download it another time." + ) + total_length = int(23000) with ( open(save_path, "wb") as zfile, - tqdm(desc=save_path, total=(total_length / 1024 / 1024), unit="") as bar, + tqdm(desc=save_path, total=total_length, unit="") as bar, ): for chunk in r.iter_content(chunk_size=1024 * 1024): # chunk size of 1024 * 1024 needs 9min 11 sek = 551sek @@ -200,6 +339,3 @@ def download_xml_Mastr( else: # remove warning bar.set_postfix_str(s="") - time_b = time.perf_counter() - print(f"Download is finished. 
It took {int(np.around(time_b - time_a))} seconds.") - print(f"MaStR was successfully downloaded to {xml_folder_path}.") diff --git a/open_mastr/xml_download/utils_write_to_database.py b/open_mastr/xml_download/utils_write_to_database.py index 4b220909..e71abc18 100644 --- a/open_mastr/xml_download/utils_write_to_database.py +++ b/open_mastr/xml_download/utils_write_to_database.py @@ -19,6 +19,8 @@ from open_mastr.utils.orm import tablename_mapping from open_mastr.xml_download.utils_cleansing_bulk import cleanse_bulk_data +log = setup_logger() + def write_mastr_xml_to_database( engine: sqlalchemy.engine.Engine, @@ -28,7 +30,7 @@ def write_mastr_xml_to_database( bulk_download_date: str, ) -> None: """Write the Mastr in xml format into a database defined by the engine parameter.""" - print("Starting bulk download and data cleansing...") + log.info("Starting bulk download...") include_tables = data_to_include_tables(data, mapping="write_xml") threads_data = [] @@ -71,7 +73,7 @@ def write_mastr_xml_to_database( for item in interleaved_files: process_xml_file(*item) - print("Bulk download and data cleansing were successful.") + log.info("Bulk download was successful.") def get_number_of_processes(): @@ -82,11 +84,11 @@ def get_number_of_processes(): try: number_of_processes = int(os.environ.get("NUMBER_OF_PROCESSES")) except ValueError: - print("Warning: Invalid value for NUMBER_OF_PROCESSES. Fallback to 1.") + log.warning("Invalid value for NUMBER_OF_PROCESSES. Fallback to 1.") return 1 if number_of_processes >= cpu_count(): - print( - f"Warning: Your system supports {cpu_count()} CPUs. Using " + log.warning( + f"Your system supports {cpu_count()} CPUs. Using " f"more processes than available CPUs may cause excessive " f"context-switching overhead." ) @@ -118,9 +120,9 @@ def process_xml_file( # The connection url obfuscates the password. We must replace the masked password with the actual password. 
engine = create_efficient_engine(connection_url) with ZipFile(zipped_xml_file_path, "r") as f: - print(f"Processing file '{file_name}'...") + log.info(f"Processing file '{file_name}'...") if is_first_file(file_name): - print(f"Creating table '{sql_table_name}'...") + log.info(f"Creating table '{sql_table_name}'...") create_database_table(engine, xml_table_name) df = read_xml_file(f, file_name) df = process_table_before_insertion( @@ -137,7 +139,7 @@ def process_xml_file( df, xml_table_name, sql_table_name, engine ) except Exception as e: - print(f"Error processing file '{file_name}': '{e}'") + log.error(f"Error processing file '{file_name}': '{e}'") def create_efficient_engine(connection_url: str) -> sqlalchemy.engine.Engine: @@ -224,7 +226,7 @@ def is_table_relevant(xml_table_name: str, include_tables: list) -> bool: tablename_mapping[xml_table_name]["__class__"] is not None ) except KeyError: - print( + log.warning( f"Table '{xml_table_name}' is not supported by your open-mastr version and " f"will be skipped." 
) @@ -451,7 +453,7 @@ def write_single_entries_until_not_unique_comes_up( labels=key_list, errors="ignore" ) # drop primary keys that already exist in the table df = df.reset_index() - print(f"{len_df_before - len(df)} entries already existed in the database.") + log.warning(f"{len_df_before - len(df)} entries already existed in the database.") return df @@ -509,7 +511,7 @@ def add_missing_columns_to_table( def delete_wrong_xml_entry(err: Error, df: pd.DataFrame) -> pd.DataFrame: delete_entry = str(err).split("«")[0].split("»")[1] - print(f"The entry {delete_entry} was deleted due to its false data type.") + log.warning(f"The entry {delete_entry} was deleted due to its false data type.") return df.replace(delete_entry, np.nan) @@ -548,7 +550,7 @@ def find_nearest_brackets(xml_string: str, position: int) -> tuple[int, int]: row_with_error[: left_bracket + 1] + row_with_error[right_bracket:] ) try: - print("One invalid xml expression was deleted.") + log.warning("One invalid xml expression was deleted.") df = pd.read_xml(StringIO("\n".join(data))) return df except lxml.etree.XMLSyntaxError as e: diff --git a/pyproject.toml b/pyproject.toml index a4fcb367..5871bfbe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "open_mastr" -version = "0.15.0" +version = "0.16.0" dependencies = [ "pandas>=2.2.2", "numpy", @@ -79,4 +79,4 @@ open_mastr = [ include = ["open_mastr", "open_mastr.soap_api", "open_mastr.soap_api.metadata", "open_mastr.utils", "open_mastr.utils.config", "open_mastr.xml_download"] # package names should match these glob patterns (["*"] by default) # from setup.py - not yet included in here -# download_url="https://github.com/OpenEnergyPlatform/open-MaStR/archive""/refs/tags/v0.15.0.tar.gz", +# download_url="https://github.com/OpenEnergyPlatform/open-MaStR/archive""/refs/tags/v0.16.0.tar.gz", diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 4a19f4fb..7779a9c8 100644 --- 
a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -7,6 +7,7 @@ from datetime import datetime import pandas as pd from open_mastr import Mastr +from zipfile import ZipFile from open_mastr.utils import orm from open_mastr.utils.constants import ( @@ -25,6 +26,7 @@ create_db_query, db_query_to_csv, reverse_unit_type_map, + delete_zip_file_if_corrupted, ) @@ -398,6 +400,18 @@ def test_db_query_to_csv(tmpdir, engine): os.rmdir(get_data_version_dir()) +def test_delete_zip_file_if_corrupted(): + test_zip_path = os.path.join("tests", "test.zip") + with ZipFile(test_zip_path, "w") as zf: + zf.writestr(os.path.join("tests", "file.txt"), "Hello, world!") + with open(test_zip_path, "wb+") as f: + f.seek(10) + f.write(b"\xff\xff\xff\xff") + + delete_zip_file_if_corrupted(test_zip_path) + assert not os.path.exists(test_zip_path) + + def test_save_metadata(): # FIXME: implement in #386 pass diff --git a/tests/test_mastr.py b/tests/test_mastr.py index 9fe8883b..16f7c1f6 100644 --- a/tests/test_mastr.py +++ b/tests/test_mastr.py @@ -1,10 +1,14 @@ +import shutil + from open_mastr.mastr import Mastr import os +import re import sqlalchemy import pytest from os.path import expanduser import pandas as pd from open_mastr.utils.constants import TRANSLATIONS +from datetime import date, timedelta _xml_file_exists = False _xml_folder_path = os.path.join(expanduser("~"), ".open-MaStR", "data", "xml_download") @@ -14,9 +18,14 @@ _xml_file_exists = True -@pytest.fixture -def db(): - return Mastr() +@pytest.fixture(scope="module") +def zipped_xml_file_path(): + zipped_xml_file_path = None + for entry in os.scandir(path=_xml_folder_path): + if "Gesamtdatenexport" in entry.name: + zipped_xml_file_path = os.path.join(_xml_folder_path, entry.name) + + return zipped_xml_file_path @pytest.fixture @@ -26,6 +35,11 @@ def db_path(): ) +@pytest.fixture +def db(db_path): + return Mastr(engine=sqlalchemy.create_engine(f"sqlite:///{db_path}")) + + @pytest.fixture def db_translated(db_path): engine = 
sqlalchemy.create_engine(f"sqlite:///{db_path}") @@ -71,3 +85,27 @@ def test_Mastr_translate(db_translated, db_path): for table in table_names: assert pd.read_sql(sql=table, con=db_empty.engine).shape[0] == 0 + + +@pytest.mark.dependency(name="bulk_downloaded") +def test_mastr_download(db): + db.download(data="wind") + df_wind = pd.read_sql("wind_extended", con=db.engine) + assert len(df_wind) > 10000 + + db.download(data="biomass") + df_biomass = pd.read_sql("biomass_extended", con=db.engine) + assert len(df_wind) > 10000 + assert len(df_biomass) > 10000 + + +@pytest.mark.dependency(depends=["bulk_downloaded"]) +def test_mastr_download_keep_old_files(db, zipped_xml_file_path): + file_today = zipped_xml_file_path + yesterday = (date.today() - timedelta(days=1)).strftime("%Y%m%d") + file_old = re.sub(r"\d{8}", yesterday, os.path.basename(file_today)) + file_old = os.path.join(os.path.dirname(zipped_xml_file_path), file_old) + shutil.copy(file_today, file_old) + db.download(data="gsgk", keep_old_files=True) + + assert os.path.exists(file_old) diff --git a/tests/xml_download/test_utils_download_bulk.py b/tests/xml_download/test_utils_download_bulk.py index e1f60bb0..8f650933 100644 --- a/tests/xml_download/test_utils_download_bulk.py +++ b/tests/xml_download/test_utils_download_bulk.py @@ -1,5 +1,10 @@ import time -from open_mastr.xml_download.utils_download_bulk import gen_url +from open_mastr.xml_download.utils_download_bulk import ( + gen_url, + delete_xml_files_not_from_given_date, +) +import os +import shutil def test_gen_url(): @@ -84,3 +89,27 @@ def test_gen_url(): url == "https://download.marktstammdatenregister.de/Gesamtdatenexport_20240402_24.2.zip" ) + + +def test_delete_xml_files_not_from_given_date(): + xml_folder_path = os.path.join("tests", "test_utils_download") + expected_file = os.path.join(xml_folder_path, "20250102.txt") + os.makedirs(xml_folder_path) + + # Case where expected file exists + open(expected_file, "w").close() + 
delete_xml_files_not_from_given_date( + save_path=expected_file, xml_folder_path=xml_folder_path + ) + assert os.path.exists(expected_file) + os.remove(expected_file) + + # Case where old date is deleted + path_old_file = os.path.join(xml_folder_path, "20250101.txt") + open(path_old_file, "w").close() + delete_xml_files_not_from_given_date( + save_path=expected_file, xml_folder_path=xml_folder_path + ) + assert not os.path.exists(path_old_file) + # clean up test folder + shutil.rmtree(xml_folder_path)