From 5475d87039a715d8d7c6c4e05f9841451f6222a8 Mon Sep 17 00:00:00 2001 From: johnfolly Date: Wed, 8 Apr 2026 17:50:56 +0200 Subject: [PATCH 01/12] fix: update foreign key constraints and ensure uniqueness in runner metadata --- cosmotech/coal/postgresql/runner.py | 2 +- cosmotech/coal/postgresql/store.py | 2 +- cosmotech/coal/postgresql/utils.py | 17 ++++++++++++++--- 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/cosmotech/coal/postgresql/runner.py b/cosmotech/coal/postgresql/runner.py index 34497bcd..503a24dd 100644 --- a/cosmotech/coal/postgresql/runner.py +++ b/cosmotech/coal/postgresql/runner.py @@ -49,7 +49,7 @@ def send_runner_metadata_to_postgresql( CREATE TABLE IF NOT EXISTS {schema_table} ( id varchar(32) PRIMARY KEY, name varchar(256), - last_csm_run_id varchar(32), + last_csm_run_id varchar(32) UNIQUE, run_template_id varchar(32) ); """ diff --git a/cosmotech/coal/postgresql/store.py b/cosmotech/coal/postgresql/store.py index 8db8fb72..95d6c422 100644 --- a/cosmotech/coal/postgresql/store.py +++ b/cosmotech/coal/postgresql/store.py @@ -115,7 +115,7 @@ def dump_store_to_postgresql_from_conf( ) if fk_id and _psql.is_metadata_exists(): metadata_table = f"{_psql.metadata_table_name}" - _psql.add_fk_constraint(table_name, "csm_run_id", metadata_table, "last_csm_run_id") + _psql.add_fk_constraint(target_table_name, "csm_run_id", metadata_table, "last_csm_run_id") total_rows += rows _up_time = perf_counter() diff --git a/cosmotech/coal/postgresql/utils.py b/cosmotech/coal/postgresql/utils.py index 4d863608..80b27c86 100644 --- a/cosmotech/coal/postgresql/utils.py +++ b/cosmotech/coal/postgresql/utils.py @@ -155,12 +155,23 @@ def add_fk_constraint( to_table: str, to_col: str, ) -> None: - # Connect to PostgreSQL and remove runner metadata row + # Connect to PostgreSQL and add a foreign key constraint with dbapi.connect(self.full_uri, autocommit=True) as conn: with conn.cursor() as curs: sql_add_fk = f""" - ALTER TABLE 
{self.db_schema}.{from_table} - CONSTRAINT metadata FOREIGN KEY ({from_col}) REFERENCES {to_table}({to_col}) + DO $$ + BEGIN + IF NOT EXISTS ( + SELECT 1 + FROM pg_constraint + WHERE conname = 'metadata' + AND conrelid = '{self.db_schema}.{from_table}'::regclass + ) THEN + ALTER TABLE {self.db_schema}.{from_table} + ADD CONSTRAINT metadata FOREIGN KEY ({from_col}) + REFERENCES {self.db_schema}.{to_table}({to_col}); + END IF; + END $$; """ curs.execute(sql_add_fk) conn.commit() From 57591eae31264d261ba6ec6432d8c3ee371ba10b Mon Sep 17 00:00:00 2001 From: johnfolly Date: Thu, 9 Apr 2026 09:54:15 +0200 Subject: [PATCH 02/12] fix: add metadata cleanup trigger to prevent foreign key violations --- cosmotech/coal/postgresql/runner.py | 10 ++++++++++ cosmotech/coal/postgresql/utils.py | 3 ++- .../coal/test_postgresql/test_postgresql_runner.py | 11 ++++++++--- 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/cosmotech/coal/postgresql/runner.py b/cosmotech/coal/postgresql/runner.py index 503a24dd..3649beb8 100644 --- a/cosmotech/coal/postgresql/runner.py +++ b/cosmotech/coal/postgresql/runner.py @@ -56,6 +56,16 @@ def send_runner_metadata_to_postgresql( LOGGER.info(T("coal.services.postgresql.creating_table").format(schema_table=schema_table)) curs.execute(sql_create_table) conn.commit() + + last_run_id = runner.get("lastRunInfo").get("lastRunId") + if last_run_id: + sql_delete_from_metatable = f""" + DELETE FROM {schema_table} + WHERE last_csm_run_id= $1; + """ + curs.execute(sql_delete_from_metatable, (last_run_id,)) + conn.commit() + sql_upsert = f""" INSERT INTO {schema_table} (id, name, last_csm_run_id, run_template_id) VALUES ($1, $2, $3, $4) diff --git a/cosmotech/coal/postgresql/utils.py b/cosmotech/coal/postgresql/utils.py index 80b27c86..4ac078ea 100644 --- a/cosmotech/coal/postgresql/utils.py +++ b/cosmotech/coal/postgresql/utils.py @@ -169,7 +169,8 @@ def add_fk_constraint( ) THEN ALTER TABLE {self.db_schema}.{from_table} ADD CONSTRAINT metadata 
FOREIGN KEY ({from_col}) - REFERENCES {self.db_schema}.{to_table}({to_col}); + REFERENCES {self.db_schema}.{to_table}({to_col}) + ON DELETE CASCADE; END IF; END $$; """ diff --git a/tests/unit/coal/test_postgresql/test_postgresql_runner.py b/tests/unit/coal/test_postgresql/test_postgresql_runner.py index 870a1f89..83789a07 100644 --- a/tests/unit/coal/test_postgresql/test_postgresql_runner.py +++ b/tests/unit/coal/test_postgresql/test_postgresql_runner.py @@ -72,14 +72,19 @@ def test_send_runner_metadata_to_postgresql(self, mock_connect, mock_postgres_ut mock_connect.assert_called_once_with("postgresql://user:password@localhost:5432/testdb", autocommit=True) # Check that SQL statements were executed - assert mock_cursor.execute.call_count == 2 + assert mock_cursor.execute.call_count == 3 # Verify the SQL statements (partially, since the exact SQL is complex) create_table_call = mock_cursor.execute.call_args_list[0] assert "CREATE TABLE IF NOT EXISTS" in create_table_call[0][0] assert "public.test_runnermetadata" in create_table_call[0][0] - upsert_call = mock_cursor.execute.call_args_list[1] + delete_call = mock_cursor.execute.call_args_list[1] + assert "DELETE FROM" in delete_call[0][0] + assert "public.test_runnermetadata" in delete_call[0][0] + assert delete_call[0][1] == ("test-run-id",) + + upsert_call = mock_cursor.execute.call_args_list[2] assert "INSERT INTO" in upsert_call[0][0] assert "public.test_runnermetadata" in upsert_call[0][0] assert upsert_call[0][1] == ( @@ -90,7 +95,7 @@ def test_send_runner_metadata_to_postgresql(self, mock_connect, mock_postgres_ut ) # Check that commits were called - assert mock_conn.commit.call_count == 2 + assert mock_conn.commit.call_count == 3 # Verify the function returns the lastRunId assert result == "test-run-id" From 85a3c38b35f47b40488e9dd4eaeefdcf04179aa0 Mon Sep 17 00:00:00 2001 From: johnfolly Date: Thu, 9 Apr 2026 12:32:04 +0200 Subject: [PATCH 03/12] fix: update foreign key reference in runner metadata deletion 
logic --- cosmotech/coal/postgresql/runner.py | 15 +++++++-------- .../test_postgresql/test_postgresql_runner.py | 2 +- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/cosmotech/coal/postgresql/runner.py b/cosmotech/coal/postgresql/runner.py index 3649beb8..21c454e5 100644 --- a/cosmotech/coal/postgresql/runner.py +++ b/cosmotech/coal/postgresql/runner.py @@ -57,14 +57,13 @@ def send_runner_metadata_to_postgresql( curs.execute(sql_create_table) conn.commit() - last_run_id = runner.get("lastRunInfo").get("lastRunId") - if last_run_id: - sql_delete_from_metatable = f""" - DELETE FROM {schema_table} - WHERE last_csm_run_id= $1; - """ - curs.execute(sql_delete_from_metatable, (last_run_id,)) - conn.commit() + runner_id = runner.get("id") + sql_delete_from_metatable = f""" + DELETE FROM {schema_table} + WHERE id= $1; + """ + curs.execute(sql_delete_from_metatable, (runner_id,)) + conn.commit() sql_upsert = f""" INSERT INTO {schema_table} (id, name, last_csm_run_id, run_template_id) diff --git a/tests/unit/coal/test_postgresql/test_postgresql_runner.py b/tests/unit/coal/test_postgresql/test_postgresql_runner.py index 83789a07..33c3f4a8 100644 --- a/tests/unit/coal/test_postgresql/test_postgresql_runner.py +++ b/tests/unit/coal/test_postgresql/test_postgresql_runner.py @@ -82,7 +82,7 @@ def test_send_runner_metadata_to_postgresql(self, mock_connect, mock_postgres_ut delete_call = mock_cursor.execute.call_args_list[1] assert "DELETE FROM" in delete_call[0][0] assert "public.test_runnermetadata" in delete_call[0][0] - assert delete_call[0][1] == ("test-run-id",) + assert delete_call[0][1] == ("test-runner-id",) upsert_call = mock_cursor.execute.call_args_list[2] assert "INSERT INTO" in upsert_call[0][0] From 7a13c1f2fb5c4bfab8bcc8e87448959382976d8e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laurent=20Al=C3=A9p=C3=A9e?= Date: Thu, 9 Apr 2026 16:31:23 +0200 Subject: [PATCH 04/12] chore: remove unecessary code check on conflict is not necessary anymore as we remove 
the row beforehand --- cosmotech/coal/postgresql/runner.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/cosmotech/coal/postgresql/runner.py b/cosmotech/coal/postgresql/runner.py index 21c454e5..9236f3e6 100644 --- a/cosmotech/coal/postgresql/runner.py +++ b/cosmotech/coal/postgresql/runner.py @@ -67,10 +67,7 @@ def send_runner_metadata_to_postgresql( sql_upsert = f""" INSERT INTO {schema_table} (id, name, last_csm_run_id, run_template_id) - VALUES ($1, $2, $3, $4) - ON CONFLICT (id) - DO - UPDATE SET name = EXCLUDED.name, last_csm_run_id = EXCLUDED.last_csm_run_id; + VALUES ($1, $2, $3, $4) """ LOGGER.debug(runner) curs.execute( From 902123bcd8371eadfc749a3ab1cc2ac078899c99 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laurent=20Al=C3=A9p=C3=A9e?= Date: Thu, 9 Apr 2026 16:36:21 +0200 Subject: [PATCH 05/12] add input_collector class this aims to simplify input (datasets and parameters) retrival after a run_load_data --- cosmotech/coal/utils/input_collector.py | 86 +++++++++++++++++++++++++ 1 file changed, 86 insertions(+) create mode 100644 cosmotech/coal/utils/input_collector.py diff --git a/cosmotech/coal/utils/input_collector.py b/cosmotech/coal/utils/input_collector.py new file mode 100644 index 00000000..d4fd9e53 --- /dev/null +++ b/cosmotech/coal/utils/input_collector.py @@ -0,0 +1,86 @@ +import json +import os +from pathlib import Path + +from cosmotech.coal.utils.configuration import ENVIRONMENT_CONFIGURATION as EC + + +class InputCollector: + def __init__(self): + self.dataset_collector = DatasetCollector() + self.parameter_collector = ParameterCollector() + + def fetch_dataset(self, dataset_name: str) -> Path: + return self.dataset_collector.fetch(dataset_name) + + def fetch_parameter(self, param_name: str) -> Path: + return self.parameter_collector.fetch(param_name) + + def fetch(self, name: str) -> Path: + try: + return self.fetch_parameter(name) + except (KeyError, FileNotFoundError): + return self.fetch_dataset(name) + + +class 
DatasetCollector: + def __init__(self): + self.paths: dict[str, Path] = {} + + def collect(self): + for dataset_id in os.listdir(EC.cosmotech.dataset_absolute_path): + for r, d, f in os.walk(Path(EC.cosmotech.dataset_absolute_path) / dataset_id): + print(f"dataset: {r} {d} {f}") + for dataset_name in f: + path = Path(r) / dataset_name + self.paths[dataset_name] = path + + def fetch(self, dataset_name: str) -> Path: + # lazy collection to avoid unnecessary os.walk calls + if not self.paths: + self.collect() + if dataset_name in self.paths: + return self.paths[dataset_name] + raise FileNotFoundError(f"File for {dataset_name} not found in {EC.cosmotech.dataset_absolute_path}.") + + +class ParameterCollector: + def __init__(self): + self.paths: dict[str, Path] = {} + self.parameters: dict[str, str] = {} + parameter_file = Path(EC.cosmotech.parameters_absolute_path) / "parameters.json" + if parameter_file.exists(): + with open(parameter_file) as f: + parameters = json.load(f) + print(f"{parameters=}") + for parameter in parameters: + self.parameters[parameter["parameterId"]] = parameter["value"] + + def collect(self): + for dataset_id in os.listdir(EC.cosmotech.dataset_absolute_path): + for r, d, f in os.walk(Path(EC.cosmotech.parameters_absolute_path) / dataset_id): + print(f"parameter: {r} {d} {f}") + for file_name in f: + path = Path(r) / file_name + param_name = path.parent.name + self.paths[param_name] = path + + def fetch_parameter(self, param_name: str) -> Path: + return self.parameters[param_name] + + def fetch_file_path(self, param_name: str) -> Path: + # lazy collection to avoid unnecessary os.walk calls + if not self.paths: + self.collect() + if param_name in self.paths: + return self.paths[param_name] + raise FileNotFoundError(f"File for {param_name} not found in {EC.cosmotech.parameters_absolute_path}.") + + def fetch(self, param_name: str) -> Path: + if param_name in self.parameters: + return self.parameters[param_name] + else: + return 
self.fetch_file_path(param_name) + + +ENVIRONMENT_INPUT_COLLECTOR = InputCollector() From a47b99543e48817054155351909c5ef6efcd4354 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laurent=20Al=C3=A9p=C3=A9e?= Date: Fri, 10 Apr 2026 10:56:07 +0200 Subject: [PATCH 06/12] add input collector input collector is a scraping tools taht search in dataset and parameter folder and give easy access to this data --- cosmotech/coal/utils/input_collector.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/cosmotech/coal/utils/input_collector.py b/cosmotech/coal/utils/input_collector.py index d4fd9e53..f0fbf112 100644 --- a/cosmotech/coal/utils/input_collector.py +++ b/cosmotech/coal/utils/input_collector.py @@ -30,7 +30,6 @@ def __init__(self): def collect(self): for dataset_id in os.listdir(EC.cosmotech.dataset_absolute_path): for r, d, f in os.walk(Path(EC.cosmotech.dataset_absolute_path) / dataset_id): - print(f"dataset: {r} {d} {f}") for dataset_name in f: path = Path(r) / dataset_name self.paths[dataset_name] = path @@ -52,14 +51,12 @@ def __init__(self): if parameter_file.exists(): with open(parameter_file) as f: parameters = json.load(f) - print(f"{parameters=}") for parameter in parameters: self.parameters[parameter["parameterId"]] = parameter["value"] def collect(self): - for dataset_id in os.listdir(EC.cosmotech.dataset_absolute_path): + for dataset_id in os.listdir(EC.cosmotech.parameters_absolute_path): for r, d, f in os.walk(Path(EC.cosmotech.parameters_absolute_path) / dataset_id): - print(f"parameter: {r} {d} {f}") for file_name in f: path = Path(r) / file_name param_name = path.parent.name From 2ca933ba3e156a04652e6f66b63b07b485fb5a88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laurent=20Al=C3=A9p=C3=A9e?= Date: Fri, 10 Apr 2026 17:58:22 +0200 Subject: [PATCH 07/12] add test for input_collector AI generated but reviewed by me --- cosmotech/coal/utils/input_collector.py | 5 + .../test_utils/test_utils_input_collector.py | 243 ++++++++++++++++++ 2 files 
changed, 248 insertions(+) create mode 100644 tests/unit/coal/test_utils/test_utils_input_collector.py diff --git a/cosmotech/coal/utils/input_collector.py b/cosmotech/coal/utils/input_collector.py index f0fbf112..efd9ac25 100644 --- a/cosmotech/coal/utils/input_collector.py +++ b/cosmotech/coal/utils/input_collector.py @@ -47,6 +47,8 @@ class ParameterCollector: def __init__(self): self.paths: dict[str, Path] = {} self.parameters: dict[str, str] = {} + + def read_parameters_json(self): parameter_file = Path(EC.cosmotech.parameters_absolute_path) / "parameters.json" if parameter_file.exists(): with open(parameter_file) as f: @@ -63,6 +65,9 @@ def collect(self): self.paths[param_name] = path def fetch_parameter(self, param_name: str) -> Path: + # lazy collection to avoid unnecessary json loading + if not self.parameters: + self.read_parameters_json() return self.parameters[param_name] def fetch_file_path(self, param_name: str) -> Path: diff --git a/tests/unit/coal/test_utils/test_utils_input_collector.py b/tests/unit/coal/test_utils/test_utils_input_collector.py new file mode 100644 index 00000000..57db4941 --- /dev/null +++ b/tests/unit/coal/test_utils/test_utils_input_collector.py @@ -0,0 +1,243 @@ +# Copyright (C) - 2023 - 2025 - Cosmo Tech +# This document and all information contained herein is the exclusive property - +# including all intellectual property rights pertaining thereto - of Cosmo Tech. +# Any use, reproduction, translation, broadcasting, transmission, distribution, +# etc., to any person is prohibited unless it has been previously and +# specifically authorized by written means by Cosmo Tech. + +import json +from unittest.mock import MagicMock, patch + +import pytest + +from cosmotech.coal.utils.input_collector import ( + DatasetCollector, + InputCollector, + ParameterCollector, +) + + +@pytest.fixture +def mock_ec(tmp_path): + """Fixture that patches EC with a MagicMock whose dataset and parameters paths + both point to tmp_path. 
Override the attributes in individual tests as needed.""" + ec = MagicMock() + ec.cosmotech.dataset_absolute_path = str(tmp_path) + ec.cosmotech.parameters_absolute_path = str(tmp_path) + with patch("cosmotech.coal.utils.input_collector.EC", ec): + yield ec + + +class TestDatasetCollector: + def test_fetch_existing_file(self, tmp_path, mock_ec): + dataset_dir = tmp_path / "ds1" + dataset_dir.mkdir() + file = dataset_dir / "mydata.csv" + file.write_text("a,b\n1,2") + + mock_ec.cosmotech.dataset_absolute_path = str(tmp_path) + + collector = DatasetCollector() + result = collector.fetch("mydata.csv") + + assert result == file + + def test_fetch_triggers_lazy_collection(self, tmp_path, mock_ec): + dataset_dir = tmp_path / "ds1" + dataset_dir.mkdir() + (dataset_dir / "file.csv").write_text("") + + mock_ec.cosmotech.dataset_absolute_path = str(tmp_path) + + collector = DatasetCollector() + assert collector.paths == {} + collector.fetch("file.csv") + assert "file.csv" in collector.paths + + def test_fetch_missing_file_raises(self, tmp_path, mock_ec): + dataset_dir = tmp_path / "ds1" + dataset_dir.mkdir() + + mock_ec.cosmotech.dataset_absolute_path = str(tmp_path) + + collector = DatasetCollector() + with pytest.raises(FileNotFoundError): + collector.fetch("nonexistent.csv") + + def test_collect_indexes_nested_files(self, tmp_path, mock_ec): + sub = tmp_path / "ds1" / "sub" + sub.mkdir(parents=True) + f = sub / "nested.csv" + f.write_text("") + + mock_ec.cosmotech.dataset_absolute_path = str(tmp_path) + + collector = DatasetCollector() + collector.collect() + + assert "nested.csv" in collector.paths + assert collector.paths["nested.csv"] == f + + +class TestParameterCollector: + def test_init_starts_with_empty_dicts(self, tmp_path, mock_ec): + """parameters.json is no longer read at init — loading is now lazy.""" + params = [{"parameterId": "alpha", "value": "42"}] + (tmp_path / "parameters.json").write_text(json.dumps(params)) + + collector = ParameterCollector() + + 
assert collector.parameters == {} + assert collector.paths == {} + + def test_read_parameters_json_populates_parameters(self, tmp_path, mock_ec): + params = [{"parameterId": "alpha", "value": "42"}, {"parameterId": "beta", "value": "hello"}] + (tmp_path / "parameters.json").write_text(json.dumps(params)) + + collector = ParameterCollector() + collector.read_parameters_json() + + assert collector.parameters["alpha"] == "42" + assert collector.parameters["beta"] == "hello" + + def test_read_parameters_json_no_file_keeps_empty(self, tmp_path, mock_ec): + collector = ParameterCollector() + collector.read_parameters_json() + + assert collector.parameters == {} + + def test_fetch_parameter_triggers_lazy_load(self, tmp_path, mock_ec): + params = [{"parameterId": "myparam", "value": "myvalue"}] + (tmp_path / "parameters.json").write_text(json.dumps(params)) + + collector = ParameterCollector() + assert collector.parameters == {} + + result = collector.fetch_parameter("myparam") + + assert result == "myvalue" + assert "myparam" in collector.parameters + + def test_fetch_parameter_does_not_reload_if_already_loaded(self, tmp_path, mock_ec): + params = [{"parameterId": "key", "value": "first"}] + (tmp_path / "parameters.json").write_text(json.dumps(params)) + + collector = ParameterCollector() + collector.read_parameters_json() + # Mutate to confirm no second load overwrites it + collector.parameters["key"] = "modified" + + result = collector.fetch_parameter("key") + + assert result == "modified" + + def test_fetch_parameter_raises_key_error_for_unknown(self, tmp_path, mock_ec): + (tmp_path / "parameters.json").write_text(json.dumps([])) + + collector = ParameterCollector() + with pytest.raises(KeyError): + collector.fetch_parameter("nonexistent") + + def test_fetch_file_path_returns_file(self, tmp_path, mock_ec): + param_dir = tmp_path / "myparam" + param_dir.mkdir() + f = param_dir / "data.csv" + f.write_text("") + + collector = ParameterCollector() + result = 
collector.fetch_file_path("myparam") + + assert result == f + + def test_fetch_file_path_missing_raises(self, tmp_path, mock_ec): + collector = ParameterCollector() + with pytest.raises(FileNotFoundError): + collector.fetch_file_path("nonexistent") + + def test_fetch_returns_value_if_already_in_parameters(self, tmp_path, mock_ec): + """fetch() checks self.parameters directly — no lazy load triggered.""" + collector = ParameterCollector() + collector.parameters["preloaded"] = "value" + + result = collector.fetch("preloaded") + + assert result == "value" + + def test_fetch_falls_back_to_file_without_lazy_load(self, tmp_path, mock_ec): + """fetch() does NOT trigger read_parameters_json — falls straight to fetch_file_path.""" + param_dir = tmp_path / "myparam" + param_dir.mkdir() + f = param_dir / "data.csv" + f.write_text("") + + # parameters.json exists but fetch() won't load it + (tmp_path / "parameters.json").write_text(json.dumps([{"parameterId": "myparam", "value": "json_val"}])) + + collector = ParameterCollector() + result = collector.fetch("myparam") + + assert result == f + + +class TestInputCollector: + def test_fetch_parameter_returns_preloaded_value(self, tmp_path, mock_ec): + """InputCollector.fetch_parameter calls parameter_collector.fetch(), + which checks self.parameters directly without lazy-loading JSON.""" + collector = InputCollector() + collector.parameter_collector.parameters["key"] = "val" + + assert collector.fetch_parameter("key") == "val" + + def test_fetch_parameter_falls_back_to_file(self, tmp_path, mock_ec): + param_dir = tmp_path / "myparam" + param_dir.mkdir() + f = param_dir / "data.csv" + f.write_text("") + + collector = InputCollector() + result = collector.fetch_parameter("myparam") + + assert result == f + + def test_fetch_dataset_delegates_to_dataset_collector(self, tmp_path, mock_ec): + ds_dir = tmp_path / "ds" / "ds1" + ds_dir.mkdir(parents=True) + f = ds_dir / "myfile.csv" + f.write_text("") + + 
mock_ec.cosmotech.dataset_absolute_path = str(tmp_path / "ds") + + collector = InputCollector() + result = collector.fetch_dataset("myfile.csv") + + assert result == f + + def test_fetch_tries_parameter_first(self, tmp_path, mock_ec): + """fetch() catches KeyError/FileNotFoundError from fetch_parameter and falls back to dataset.""" + ds_dir = tmp_path / "ds" / "ds1" + ds_dir.mkdir(parents=True) + f = ds_dir / "fallback.csv" + f.write_text("") + + mock_ec.cosmotech.dataset_absolute_path = str(tmp_path / "ds") + + collector = InputCollector() + # Pre-load a parameter so fetch_parameter returns it without touching files + collector.parameter_collector.parameters["fallback.csv"] = "param_value" + + result = collector.fetch("fallback.csv") + + assert result == "param_value" + + def test_fetch_falls_back_to_dataset(self, tmp_path, mock_ec): + ds_dir = tmp_path / "ds" / "ds1" + ds_dir.mkdir(parents=True) + f = ds_dir / "fallback.csv" + f.write_text("") + + mock_ec.cosmotech.dataset_absolute_path = str(tmp_path / "ds") + + collector = InputCollector() + result = collector.fetch("fallback.csv") + + assert result == f From e1238728d101760d37c2e138a7d46ddf96a2b06a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laurent=20Al=C3=A9p=C3=A9e?= Date: Fri, 10 Apr 2026 18:02:53 +0200 Subject: [PATCH 08/12] add new parquet to store functions and cmd Made with AI assist --- cosmotech/coal/store/parquet.py | 45 ++++ .../commands/store/load_parquet_folder.py | 49 +++++ cosmotech/csm_data/commands/store/store.py | 2 + .../commands/store/load_parquet_folder.yml | 5 + .../coal/test_store/test_store_parquet.py | 200 ++++++++++++++++++ 5 files changed, 301 insertions(+) create mode 100644 cosmotech/coal/store/parquet.py create mode 100644 cosmotech/csm_data/commands/store/load_parquet_folder.py create mode 100644 cosmotech/translation/csm_data/en-US/csm_data/commands/store/load_parquet_folder.yml create mode 100644 tests/unit/coal/test_store/test_store_parquet.py diff --git 
a/cosmotech/coal/store/parquet.py b/cosmotech/coal/store/parquet.py new file mode 100644 index 00000000..76f1283d --- /dev/null +++ b/cosmotech/coal/store/parquet.py @@ -0,0 +1,45 @@ +# Copyright (C) - 2023 - 2025 - Cosmo Tech +# This document and all information contained herein is the exclusive property - +# including all intellectual property rights pertaining thereto - of Cosmo Tech. +# Any use, reproduction, translation, broadcasting, transmission, distribution, +# etc., to any person is prohibited unless it has been previously and +# specifically authorized by written means by Cosmo Tech. + +import pathlib + +import pyarrow as pa +import pyarrow.parquet as pq + +from cosmotech.coal.store.store import Store + + +def store_parquet_file( + table_name: str, + parquet_path: pathlib.Path, + replace_existsing_file: bool = False, + store=Store(), +): + if not parquet_path.exists(): + raise FileNotFoundError(f"File {parquet_path} does not exists") + + data: pa.Table = pq.ParquetFile(parquet_path).read() + _c = data.column_names + data = data.rename_columns([Store.sanitize_column(_column) for _column in _c]) + + store.add_table(table_name=table_name, data=data, replace=replace_existsing_file) + + +def convert_store_table_to_parquet( + table_name: str, + parquet_path: pathlib.Path, + replace_existsing_file: bool = False, + store=Store(), +): + if parquet_path.name.endswith(".parquet") and parquet_path.exists() and not replace_existsing_file: + raise FileExistsError(f"File {parquet_path} already exists") + if not parquet_path.name.endswith(".parquet"): + parquet_path = parquet_path / f"{table_name}.parquet" + folder = parquet_path.parent + folder.mkdir(parents=True, exist_ok=True) + + pq.write_table(store.get_table(table_name), parquet_path) diff --git a/cosmotech/csm_data/commands/store/load_parquet_folder.py b/cosmotech/csm_data/commands/store/load_parquet_folder.py new file mode 100644 index 00000000..08cb19ab --- /dev/null +++ 
b/cosmotech/csm_data/commands/store/load_parquet_folder.py @@ -0,0 +1,49 @@ +# Copyright (C) - 2023 - 2025 - Cosmo Tech +# This document and all information contained herein is the exclusive property - +# including all intellectual property rights pertaining thereto - of Cosmo Tech. +# Any use, reproduction, translation, broadcasting, transmission, distribution, +# etc., to any person is prohibited unless it has been previously and +# specifically authorized by written means by Cosmo Tech. +from cosmotech.orchestrator.utils.translate import T + +from cosmotech.csm_data.utils.click import click +from cosmotech.csm_data.utils.decorators import translate_help, web_help + + +@click.command() +@web_help("csm-data/store/load-parquet-folder") +@translate_help("csm_data.commands.store.load_parquet_folder.description") +@click.option( + "--store-folder", + envvar="CSM_PARAMETERS_ABSOLUTE_PATH", + help=T("csm_data.commands.store.load_parquet_folder.parameters.store_folder"), + metavar="PATH", + type=str, + show_envvar=True, + required=True, +) +@click.option( + "--parquet-folder", + envvar="CSM_OUTPUT_ABSOLUTE_PATH", + help=T("csm_data.commands.store.load_parquet_folder.parameters.parquet_folder"), + metavar="PATH", + type=str, + show_envvar=True, + required=True, +) +def load_parquet_folder(store_folder, parquet_folder): + # Import the modules and functions at the start of the command + import pathlib + + from cosmotech.coal.store.parquet import store_parquet_file + from cosmotech.coal.store.store import Store + from cosmotech.coal.utils.configuration import Configuration + from cosmotech.coal.utils.logger import LOGGER + + _conf = Configuration() + + _conf.coal.store = store_folder + + for parquet_path in pathlib.Path(parquet_folder).glob("*.parquet"): + LOGGER.info(T("coal.services.azure_storage.found_file").format(file=parquet_path.name)) + store_parquet_file(parquet_path.name[:-8], parquet_path, store=Store(False, _conf)) diff --git 
a/cosmotech/csm_data/commands/store/store.py b/cosmotech/csm_data/commands/store/store.py index 47be8957..e17b2623 100644 --- a/cosmotech/csm_data/commands/store/store.py +++ b/cosmotech/csm_data/commands/store/store.py @@ -14,6 +14,7 @@ from cosmotech.csm_data.commands.store.load_from_singlestore import ( load_from_singlestore_command, ) +from cosmotech.csm_data.commands.store.load_parquet_folder import load_parquet_folder from cosmotech.csm_data.commands.store.output import output from cosmotech.csm_data.commands.store.reset import reset from cosmotech.csm_data.utils.click import click @@ -30,6 +31,7 @@ def store(): store.add_command(reset, "reset") store.add_command(list_tables, "list-tables") store.add_command(load_csv_folder, "load-csv-folder") +store.add_command(load_parquet_folder, "load-parquet-folder") store.add_command(load_from_singlestore_command, "load-from-singlestore") store.add_command(dump_to_postgresql, "dump-to-postgresql") store.add_command(dump_to_s3, "dump-to-s3") diff --git a/cosmotech/translation/csm_data/en-US/csm_data/commands/store/load_parquet_folder.yml b/cosmotech/translation/csm_data/en-US/csm_data/commands/store/load_parquet_folder.yml new file mode 100644 index 00000000..4e869074 --- /dev/null +++ b/cosmotech/translation/csm_data/en-US/csm_data/commands/store/load_parquet_folder.yml @@ -0,0 +1,5 @@ +description: | + Running this command will find all parquet files in the given folder and put them in the store +parameters: + store_folder: The folder containing the store files + parquet_folder: The folder containing the parquet files to store diff --git a/tests/unit/coal/test_store/test_store_parquet.py b/tests/unit/coal/test_store/test_store_parquet.py new file mode 100644 index 00000000..e003fd5b --- /dev/null +++ b/tests/unit/coal/test_store/test_store_parquet.py @@ -0,0 +1,200 @@ +# Copyright (C) - 2023 - 2025 - Cosmo Tech +# This document and all information contained herein is the exclusive property - +# including all 
intellectual property rights pertaining thereto - of Cosmo Tech. +# Any use, reproduction, translation, broadcasting, transmission, distribution, +# etc., to any person is prohibited unless it has been previously and +# specifically authorized by written means by Cosmo Tech. + +import pathlib +from unittest.mock import MagicMock, patch + +import pyarrow as pa +import pytest + +from cosmotech.coal.store.parquet import ( + convert_store_table_to_parquet, + store_parquet_file, +) +from cosmotech.coal.store.store import Store + + +class TestParquetFunctions: + """Tests for top-level functions in the parquet module.""" + + @patch("pyarrow.parquet.ParquetFile") + @patch("pathlib.Path.exists") + def test_store_parquet_file_success(self, mock_exists, mock_parquet_file): + """Test the store_parquet_file function with a valid Parquet file.""" + # Arrange + table_name = "test_table" + parquet_path = pathlib.Path("/path/to/test.parquet") + mock_exists.return_value = True + + # Mock Parquet data + mock_data = pa.Table.from_arrays([pa.array([1, 2, 3]), pa.array(["a", "b", "c"])], names=["id", "name"]) + mock_parquet_file.return_value.read.return_value = mock_data + + # Mock store + mock_store = MagicMock(spec=Store) + + # Act + store_parquet_file(table_name, parquet_path, False, mock_store) + + # Assert + mock_exists.assert_called_once_with() + mock_parquet_file.assert_called_once_with(parquet_path) + mock_parquet_file.return_value.read.assert_called_once_with() + mock_store.add_table.assert_called_once() + args, kwargs = mock_store.add_table.call_args + assert kwargs["table_name"] == table_name + assert kwargs["replace"] is False + + @patch("pathlib.Path.exists") + def test_store_parquet_file_not_found(self, mock_exists): + """Test the store_parquet_file function with a non-existent Parquet file.""" + # Arrange + table_name = "test_table" + parquet_path = pathlib.Path("/path/to/nonexistent.parquet") + mock_exists.return_value = False + + # Mock store + mock_store = 
MagicMock(spec=Store) + + # Act & Assert + with pytest.raises(FileNotFoundError): + store_parquet_file(table_name, parquet_path, False, mock_store) + + mock_exists.assert_called_once_with() + + @patch("pyarrow.parquet.ParquetFile") + @patch("pathlib.Path.exists") + def test_store_parquet_file_with_column_sanitization(self, mock_exists, mock_parquet_file): + """Test the store_parquet_file function with column sanitization.""" + # Arrange + table_name = "test_table" + parquet_path = pathlib.Path("/path/to/test.parquet") + mock_exists.return_value = True + + # Mock Parquet data with columns that need sanitization + mock_data = pa.Table.from_arrays( + [pa.array([1, 2, 3]), pa.array(["a", "b", "c"])], names=["id with space", "name-with-dash"] + ) + mock_parquet_file.return_value.read.return_value = mock_data + + # Mock store and sanitize_column + mock_store = MagicMock(spec=Store) + Store.sanitize_column = MagicMock(side_effect=lambda x: x.replace(" ", "_")) + + # Act + store_parquet_file(table_name, parquet_path, False, mock_store) + + # Assert + assert Store.sanitize_column.call_count == 2 + Store.sanitize_column.assert_any_call("id with space") + Store.sanitize_column.assert_any_call("name-with-dash") + mock_store.add_table.assert_called_once() + + @patch("pyarrow.parquet.ParquetFile") + @patch("pathlib.Path.exists") + def test_store_parquet_file_with_replace(self, mock_exists, mock_parquet_file): + """Test the store_parquet_file function with replace_existing_file=True.""" + # Arrange + table_name = "test_table" + parquet_path = pathlib.Path("/path/to/test.parquet") + mock_exists.return_value = True + + mock_data = pa.Table.from_arrays([pa.array([1, 2, 3])], names=["id"]) + mock_parquet_file.return_value.read.return_value = mock_data + + mock_store = MagicMock(spec=Store) + + # Act + store_parquet_file(table_name, parquet_path, True, mock_store) + + # Assert + args, kwargs = mock_store.add_table.call_args + assert kwargs["replace"] is True + + 
@patch("pyarrow.parquet.write_table") + @patch("pathlib.Path.exists") + def test_convert_store_table_to_parquet_success(self, mock_exists, mock_write_table): + """Test the convert_store_table_to_parquet function with a valid table.""" + # Arrange + table_name = "test_table" + parquet_path = pathlib.Path("/path/to/output.parquet") + mock_exists.return_value = False + + mock_store = MagicMock(spec=Store) + mock_table = pa.Table.from_arrays([pa.array([1, 2, 3]), pa.array(["a", "b", "c"])], names=["id", "name"]) + mock_store.get_table.return_value = mock_table + + with patch.object(pathlib.Path, "mkdir") as mock_mkdir: + # Act + convert_store_table_to_parquet(table_name, parquet_path, False, mock_store) + + # Assert + mock_store.get_table.assert_called_once_with(table_name) + mock_mkdir.assert_called_once_with(parents=True, exist_ok=True) + mock_write_table.assert_called_once_with(mock_table, parquet_path) + + @patch("pathlib.Path.exists") + def test_convert_store_table_to_parquet_file_exists(self, mock_exists): + """Test the convert_store_table_to_parquet function when the output file already exists.""" + # Arrange + table_name = "test_table" + parquet_path = pathlib.Path("/path/to/output.parquet") + mock_exists.return_value = True + + mock_store = MagicMock(spec=Store) + + # Act & Assert + with pytest.raises(FileExistsError): + convert_store_table_to_parquet(table_name, parquet_path, False, mock_store) + + mock_exists.assert_called_once_with() + mock_store.get_table.assert_not_called() + + @patch("pyarrow.parquet.write_table") + @patch("pathlib.Path.exists") + def test_convert_store_table_to_parquet_replace_existing(self, mock_exists, mock_write_table): + """Test the convert_store_table_to_parquet function with replace_existing_file=True.""" + # Arrange + table_name = "test_table" + parquet_path = pathlib.Path("/path/to/output.parquet") + mock_exists.return_value = True + + mock_store = MagicMock(spec=Store) + mock_table = pa.Table.from_arrays([pa.array([1, 2, 3]), 
pa.array(["a", "b", "c"])], names=["id", "name"]) + mock_store.get_table.return_value = mock_table + + with patch.object(pathlib.Path, "mkdir") as mock_mkdir: + # Act + convert_store_table_to_parquet(table_name, parquet_path, True, mock_store) + + # Assert + mock_store.get_table.assert_called_once_with(table_name) + mock_mkdir.assert_called_once_with(parents=True, exist_ok=True) + mock_write_table.assert_called_once_with(mock_table, parquet_path) + + @patch("pyarrow.parquet.write_table") + @patch("pathlib.Path.exists") + def test_convert_store_table_to_parquet_directory_path(self, mock_exists, mock_write_table): + """Test the convert_store_table_to_parquet function with a directory path.""" + # Arrange + table_name = "test_table" + parquet_path = pathlib.Path("/path/to/directory") # Not ending with .parquet + mock_exists.return_value = False + + mock_store = MagicMock(spec=Store) + mock_table = pa.Table.from_arrays([pa.array([1, 2, 3]), pa.array(["a", "b", "c"])], names=["id", "name"]) + mock_store.get_table.return_value = mock_table + + with patch.object(pathlib.Path, "mkdir") as mock_mkdir: + # Act + convert_store_table_to_parquet(table_name, parquet_path, False, mock_store) + + # Assert + mock_store.get_table.assert_called_once_with(table_name) + mock_mkdir.assert_called_once_with(parents=True, exist_ok=True) + expected_path = parquet_path / f"{table_name}.parquet" + mock_write_table.assert_called_once_with(mock_table, expected_path) From 8faa63485579826475dd111a5af562644b1b884d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laurent=20Al=C3=A9p=C3=A9e?= Date: Wed, 15 Apr 2026 11:44:31 +0200 Subject: [PATCH 09/12] fix fk contraint management --- cosmotech/coal/postgresql/utils.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/cosmotech/coal/postgresql/utils.py b/cosmotech/coal/postgresql/utils.py index 4ac078ea..a95e8d60 100644 --- a/cosmotech/coal/postgresql/utils.py +++ b/cosmotech/coal/postgresql/utils.py @@ -156,24 
+156,19 @@ def add_fk_constraint( to_col: str, ) -> None: # Connect to PostgreSQL and add a foreign key constraint - with dbapi.connect(self.full_uri, autocommit=True) as conn: + with dbapi.connect(self.full_uri, autocommit=False) as conn: with conn.cursor() as curs: + sql_drop_fk = f""" + ALTER TABLE {self.db_schema}.{from_table} + DROP CONSTRAINT IF EXISTS metadata; + """ sql_add_fk = f""" - DO $$ - BEGIN - IF NOT EXISTS ( - SELECT 1 - FROM pg_constraint - WHERE conname = 'metadata' - AND conrelid = '{self.db_schema}.{from_table}'::regclass - ) THEN - ALTER TABLE {self.db_schema}.{from_table} - ADD CONSTRAINT metadata FOREIGN KEY ({from_col}) - REFERENCES {self.db_schema}.{to_table}({to_col}) - ON DELETE CASCADE; - END IF; - END $$; + ALTER TABLE {self.db_schema}.{from_table} + ADD CONSTRAINT metadata FOREIGN KEY ({from_col}) + REFERENCES {self.db_schema}.{to_table}({to_col}) + ON DELETE CASCADE; """ + curs.execute(sql_drop_fk) curs.execute(sql_add_fk) conn.commit() From b07ff60c26deb08fea849d5d96f2b005ef7787f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laurent=20Al=C3=A9p=C3=A9e?= Date: Wed, 15 Apr 2026 11:45:50 +0200 Subject: [PATCH 10/12] add output and tmp cosmo env ver to configuration --- cosmotech/coal/utils/configuration.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cosmotech/coal/utils/configuration.py b/cosmotech/coal/utils/configuration.py index 55a6d400..3bd88761 100644 --- a/cosmotech/coal/utils/configuration.py +++ b/cosmotech/coal/utils/configuration.py @@ -76,6 +76,8 @@ class Configuration(Dotdict): "api": {"url": "CSM_API_URL", "scope": "CSM_API_SCOPE"}, "dataset_absolute_path": "CSM_DATASET_ABSOLUTE_PATH", "parameters_absolute_path": "CSM_PARAMETERS_ABSOLUTE_PATH", + "output_absolute_path": "CSM_OUTPUT_ABSOLUTE_PATH", + "tmp_absolute_path": "CSM_TEMP_ABSOLUTE_PATH", "organization_id": "CSM_ORGANIZATION_ID", "workspace_id": "CSM_WORKSPACE_ID", "runner_id": "CSM_RUNNER_ID", From 4677e33468a014ac7ff8a7604faf389d2747472e Mon Sep 17 00:00:00 
2001 From: =?UTF-8?q?Laurent=20Al=C3=A9p=C3=A9e?= Date: Wed, 15 Apr 2026 11:53:58 +0200 Subject: [PATCH 11/12] add workspace collector change default workspace path to "" --- cosmotech/coal/utils/input_collector.py | 41 +++++++++++++++++-- .../csm_data/commands/api/wsf_load_file.py | 2 +- .../test_utils/test_utils_input_collector.py | 2 +- 3 files changed, 39 insertions(+), 6 deletions(-) diff --git a/cosmotech/coal/utils/input_collector.py b/cosmotech/coal/utils/input_collector.py index efd9ac25..25684960 100644 --- a/cosmotech/coal/utils/input_collector.py +++ b/cosmotech/coal/utils/input_collector.py @@ -3,12 +3,14 @@ from pathlib import Path from cosmotech.coal.utils.configuration import ENVIRONMENT_CONFIGURATION as EC +from cosmotech.coal.utils.logger import LOGGER class InputCollector: def __init__(self): self.dataset_collector = DatasetCollector() self.parameter_collector = ParameterCollector() + self.workspace_collector = WorkspaceCollector() def fetch_dataset(self, dataset_name: str) -> Path: return self.dataset_collector.fetch(dataset_name) @@ -16,11 +18,19 @@ def fetch_dataset(self, dataset_name: str) -> Path: def fetch_parameter(self, param_name: str) -> Path: return self.parameter_collector.fetch(param_name) + def fetch_workspace_file(self, file_name: str) -> Path: + return self.workspace_collector.fetch(file_name) + def fetch(self, name: str) -> Path: try: return self.fetch_parameter(name) except (KeyError, FileNotFoundError): - return self.fetch_dataset(name) + LOGGER.debug(f"Parameter {name} not found, trying workspace files.") + try: + return self.fetch_workspace_file(name) + except FileNotFoundError: + LOGGER.debug(f"Workspace file {name} not found, trying dataset files.") + return self.fetch_dataset(name) class DatasetCollector: @@ -33,6 +43,7 @@ def collect(self): for dataset_name in f: path = Path(r) / dataset_name self.paths[dataset_name] = path + self.paths[path.stem] = path def fetch(self, dataset_name: str) -> Path: # lazy collection to 
avoid unnecessary os.walk calls @@ -43,6 +54,27 @@ def fetch(self, dataset_name: str) -> Path: raise FileNotFoundError(f"File for {dataset_name} not found in {EC.cosmotech.dataset_absolute_path}.") +class WorkspaceCollector: + def __init__(self): + self.paths: dict[str, Path] = {} + + def collect(self): + workspace_path = Path(EC.cosmotech.dataset_absolute_path) / "workspace_files" + if workspace_path.exists(): + for r, d, f in os.walk(workspace_path): + for file_name in f: + path = Path(r) / file_name + self.paths[file_name] = path + self.paths[path.stem] = path + + def fetch(self, file_name: str) -> Path: + if not self.paths: + self.collect() + if file_name in self.paths: + return self.paths[file_name] + raise FileNotFoundError(f"File {file_name} not found in workspace_files.") + + class ParameterCollector: def __init__(self): self.paths: dict[str, Path] = {} @@ -63,6 +95,7 @@ def collect(self): path = Path(r) / file_name param_name = path.parent.name self.paths[param_name] = path + self.paths[path.stem] = path def fetch_parameter(self, param_name: str) -> Path: # lazy collection to avoid unnecessary json loading @@ -79,9 +112,9 @@ def fetch_file_path(self, param_name: str) -> Path: raise FileNotFoundError(f"File for {param_name} not found in {EC.cosmotech.parameters_absolute_path}.") def fetch(self, param_name: str) -> Path: - if param_name in self.parameters: - return self.parameters[param_name] - else: + try: + return self.fetch_parameter(param_name) + except KeyError: return self.fetch_file_path(param_name) diff --git a/cosmotech/csm_data/commands/api/wsf_load_file.py b/cosmotech/csm_data/commands/api/wsf_load_file.py index 1f84d1e1..2f26eed0 100644 --- a/cosmotech/csm_data/commands/api/wsf_load_file.py +++ b/cosmotech/csm_data/commands/api/wsf_load_file.py @@ -35,7 +35,7 @@ "--workspace-path", help=T("csm_data.commands.api.wsf_load_file.parameters.workspace_path"), metavar="PATH", - default="/", + default="", type=str, ) @click.option( diff --git 
a/tests/unit/coal/test_utils/test_utils_input_collector.py b/tests/unit/coal/test_utils/test_utils_input_collector.py index 57db4941..12116c9e 100644 --- a/tests/unit/coal/test_utils/test_utils_input_collector.py +++ b/tests/unit/coal/test_utils/test_utils_input_collector.py @@ -176,7 +176,7 @@ def test_fetch_falls_back_to_file_without_lazy_load(self, tmp_path, mock_ec): collector = ParameterCollector() result = collector.fetch("myparam") - assert result == f + assert result == "json_val" class TestInputCollector: From a8e49d07ecf6cc1ef74ea84677f1520f34ecea6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laurent=20Al=C3=A9p=C3=A9e?= Date: Wed, 15 Apr 2026 19:35:27 +0200 Subject: [PATCH 12/12] change dataset parts upload, part_name are file_name without extension --- cosmotech/coal/cosmotech_api/apis/dataset.py | 6 +++--- .../unit/coal/test_cosmotech_api/test_apis/test_dataset.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/cosmotech/coal/cosmotech_api/apis/dataset.py b/cosmotech/coal/cosmotech_api/apis/dataset.py index a137b2f4..680f0bca 100644 --- a/cosmotech/coal/cosmotech_api/apis/dataset.py +++ b/cosmotech/coal/cosmotech_api/apis/dataset.py @@ -79,7 +79,7 @@ def _download_part(self, dataset_id, dataset_part, destination): ) @staticmethod - def path_to_parts(_path, part_type) -> list[tuple[str, Path, DatasetPartTypeEnum]]: + def path_to_parts(_path, part_type) -> list[tuple[str, str, Path, DatasetPartTypeEnum]]: if (_path := Path(_path)).is_dir(): return list((str(_p.relative_to(_path)), _p, part_type) for _p in _path.rglob("*") if _p.is_file()) return list(((_path.name, _path, part_type),)) @@ -118,7 +118,7 @@ def upload_dataset( additional_data=additional_data, parts=list( DatasetPartCreateRequest( - name=_p_name, + name=Path(_p_name).stem, description=_p_name, sourceName=_p_name, type=_type, @@ -195,7 +195,7 @@ def upload_dataset_parts( # Create new part part_request = DatasetPartCreateRequest( - name=_p_name, + name=Path(_p_name).stem,
description=_p_name, sourceName=_p_name, type=_type, diff --git a/tests/unit/coal/test_cosmotech_api/test_apis/test_dataset.py b/tests/unit/coal/test_cosmotech_api/test_apis/test_dataset.py index 8ac14f48..9d493e16 100644 --- a/tests/unit/coal/test_cosmotech_api/test_apis/test_dataset.py +++ b/tests/unit/coal/test_cosmotech_api/test_apis/test_dataset.py @@ -543,13 +543,13 @@ def test_update_dataset_mixed_files(self, mock_cosmotech_config, mock_api_client assert len(args_list) == 2 # check first call used to create csv part dpcr = args_list[0].kwargs.get("dataset_part_create_request") - assert dpcr.name == "data.csv" + assert dpcr.name == "data" assert dpcr.source_name == "data.csv" assert dpcr.description == "data.csv" assert dpcr.type == DatasetPartTypeEnum.FILE # check second call used to create db part dpcr = args_list[1].kwargs.get("dataset_part_create_request") - assert dpcr.name == "data.db" + assert dpcr.name == "data" assert dpcr.source_name == "data.db" assert dpcr.description == "data.db" assert dpcr.type == DatasetPartTypeEnum.DB