diff --git a/cosmotech/coal/postgresql/runner.py b/cosmotech/coal/postgresql/runner.py
index 34497bcd..9236f3e6 100644
--- a/cosmotech/coal/postgresql/runner.py
+++ b/cosmotech/coal/postgresql/runner.py
@@ -49,19 +49,25 @@ def send_runner_metadata_to_postgresql(
                 CREATE TABLE IF NOT EXISTS {schema_table} (
                     id varchar(32) PRIMARY KEY,
                     name varchar(256),
-                    last_csm_run_id varchar(32),
+                    last_csm_run_id varchar(32) UNIQUE,
                     run_template_id varchar(32)
                 );
             """
             LOGGER.info(T("coal.services.postgresql.creating_table").format(schema_table=schema_table))
             curs.execute(sql_create_table)
             conn.commit()
+
+            runner_id = runner.get("id")
+            sql_delete_from_metadata_table = f"""
+                DELETE FROM {schema_table}
+                WHERE id = $1;
+            """
+            curs.execute(sql_delete_from_metadata_table, (runner_id,))
+            conn.commit()
+
             sql_upsert = f"""
                 INSERT INTO {schema_table} (id, name, last_csm_run_id, run_template_id)
-                VALUES ($1, $2, $3, $4)
-                ON CONFLICT (id)
-                DO
-                UPDATE SET name = EXCLUDED.name, last_csm_run_id = EXCLUDED.last_csm_run_id;
+                VALUES ($1, $2, $3, $4)
             """
             LOGGER.debug(runner)
             curs.execute(
diff --git a/cosmotech/coal/postgresql/store.py b/cosmotech/coal/postgresql/store.py
index 8db8fb72..95d6c422 100644
--- a/cosmotech/coal/postgresql/store.py
+++ b/cosmotech/coal/postgresql/store.py
@@ -115,7 +115,7 @@ def dump_store_to_postgresql_from_conf(
         )
         if fk_id and _psql.is_metadata_exists():
             metadata_table = f"{_psql.metadata_table_name}"
-            _psql.add_fk_constraint(table_name, "csm_run_id", metadata_table, "last_csm_run_id")
+            _psql.add_fk_constraint(target_table_name, "csm_run_id", metadata_table, "last_csm_run_id")
         total_rows += rows
         _up_time = perf_counter()
diff --git a/cosmotech/coal/postgresql/utils.py b/cosmotech/coal/postgresql/utils.py
index 4d863608..4ac078ea 100644
--- a/cosmotech/coal/postgresql/utils.py
+++ b/cosmotech/coal/postgresql/utils.py
@@ -155,12 +155,24 @@ def add_fk_constraint(
        to_table: str,
        to_col: str,
    ) -> None:
-        # Connect to PostgreSQL and remove runner metadata row
+        # Connect to PostgreSQL and add a foreign key constraint
        with dbapi.connect(self.full_uri, autocommit=True) as conn:
            with conn.cursor() as curs:
                sql_add_fk = f"""
-                    ALTER TABLE {self.db_schema}.{from_table}
-                    CONSTRAINT metadata FOREIGN KEY ({from_col}) REFERENCES {to_table}({to_col})
+                    DO $$
+                    BEGIN
+                        IF NOT EXISTS (
+                            SELECT 1
+                            FROM pg_constraint
+                            WHERE conname = 'metadata'
+                            AND conrelid = '{self.db_schema}.{from_table}'::regclass
+                        ) THEN
+                            ALTER TABLE {self.db_schema}.{from_table}
+                            ADD CONSTRAINT metadata FOREIGN KEY ({from_col})
+                            REFERENCES {self.db_schema}.{to_table}({to_col})
+                            ON DELETE CASCADE;
+                        END IF;
+                    END $$;
                """
                curs.execute(sql_add_fk)
                conn.commit()
diff --git a/cosmotech/coal/utils/input_collector.py b/cosmotech/coal/utils/input_collector.py
new file mode 100644
index 00000000..efd9ac25
--- /dev/null
+++ b/cosmotech/coal/utils/input_collector.py
@@ -0,0 +1,88 @@
+import json
+import os
+from pathlib import Path
+
+from cosmotech.coal.utils.configuration import ENVIRONMENT_CONFIGURATION as EC
+
+
+class InputCollector:
+    def __init__(self):
+        self.dataset_collector = DatasetCollector()
+        self.parameter_collector = ParameterCollector()
+
+    def fetch_dataset(self, dataset_name: str) -> Path:
+        return self.dataset_collector.fetch(dataset_name)
+
+    def fetch_parameter(self, param_name: str) -> Path | str:
+        return self.parameter_collector.fetch(param_name)
+
+    def fetch(self, name: str) -> Path | str:
+        try:
+            return self.fetch_parameter(name)
+        except (KeyError, FileNotFoundError):
+            return self.fetch_dataset(name)
+
+
+class DatasetCollector:
+    def __init__(self):
+        self.paths: dict[str, Path] = {}
+
+    def collect(self):
+        for dataset_id in os.listdir(EC.cosmotech.dataset_absolute_path):
+            for root, _dirs, files in os.walk(Path(EC.cosmotech.dataset_absolute_path) / dataset_id):
+                for dataset_name in files:
+                    path = Path(root) / dataset_name
+                    self.paths[dataset_name] = path
+
+    def fetch(self, dataset_name: str) -> Path:
+        # lazy collection to avoid unnecessary os.walk calls
+        if not self.paths:
+            self.collect()
+        if dataset_name in self.paths:
+            return self.paths[dataset_name]
+        raise FileNotFoundError(f"File for {dataset_name} not found in {EC.cosmotech.dataset_absolute_path}.")
+
+
+class ParameterCollector:
+    def __init__(self):
+        self.paths: dict[str, Path] = {}
+        self.parameters: dict[str, str] = {}
+
+    def read_parameters_json(self):
+        parameter_file = Path(EC.cosmotech.parameters_absolute_path) / "parameters.json"
+        if parameter_file.exists():
+            with open(parameter_file) as f:
+                parameters = json.load(f)
+            for parameter in parameters:
+                self.parameters[parameter["parameterId"]] = parameter["value"]
+
+    def collect(self):
+        for param_dir in os.listdir(EC.cosmotech.parameters_absolute_path):
+            for root, _dirs, files in os.walk(Path(EC.cosmotech.parameters_absolute_path) / param_dir):
+                for file_name in files:
+                    path = Path(root) / file_name
+                    param_name = path.parent.name
+                    self.paths[param_name] = path
+
+    def fetch_parameter(self, param_name: str) -> str:
+        # lazy load to avoid unnecessary json parsing
+        if not self.parameters:
+            self.read_parameters_json()
+        return self.parameters[param_name]
+
+    def fetch_file_path(self, param_name: str) -> Path:
+        # lazy collection to avoid unnecessary os.walk calls
+        if not self.paths:
+            self.collect()
+        if param_name in self.paths:
+            return self.paths[param_name]
+        raise FileNotFoundError(f"File for {param_name} not found in {EC.cosmotech.parameters_absolute_path}.")
+
+    def fetch(self, param_name: str) -> Path | str:
+        if param_name in self.parameters:
+            return self.parameters[param_name]
+        else:
+            return self.fetch_file_path(param_name)
+
+
+ENVIRONMENT_INPUT_COLLECTOR = InputCollector()
diff --git a/tests/unit/coal/test_postgresql/test_postgresql_runner.py b/tests/unit/coal/test_postgresql/test_postgresql_runner.py
index 870a1f89..33c3f4a8 100644
--- a/tests/unit/coal/test_postgresql/test_postgresql_runner.py
+++ b/tests/unit/coal/test_postgresql/test_postgresql_runner.py
@@ -72,14 +72,19 @@ def test_send_runner_metadata_to_postgresql(self, mock_connect, mock_postgres_ut
         mock_connect.assert_called_once_with("postgresql://user:password@localhost:5432/testdb", autocommit=True)
 
         # Check that SQL statements were executed
-        assert mock_cursor.execute.call_count == 2
+        assert mock_cursor.execute.call_count == 3
 
         # Verify the SQL statements (partially, since the exact SQL is complex)
         create_table_call = mock_cursor.execute.call_args_list[0]
         assert "CREATE TABLE IF NOT EXISTS" in create_table_call[0][0]
         assert "public.test_runnermetadata" in create_table_call[0][0]
 
-        upsert_call = mock_cursor.execute.call_args_list[1]
+        delete_call = mock_cursor.execute.call_args_list[1]
+        assert "DELETE FROM" in delete_call[0][0]
+        assert "public.test_runnermetadata" in delete_call[0][0]
+        assert delete_call[0][1] == ("test-runner-id",)
+
+        upsert_call = mock_cursor.execute.call_args_list[2]
         assert "INSERT INTO" in upsert_call[0][0]
         assert "public.test_runnermetadata" in upsert_call[0][0]
         assert upsert_call[0][1] == (
@@ -90,7 +95,7 @@ def test_send_runner_metadata_to_postgresql(self, mock_connect, mock_postgres_ut
         )
 
         # Check that commits were called
-        assert mock_conn.commit.call_count == 2
+        assert mock_conn.commit.call_count == 3
 
         # Verify the function returns the lastRunId
         assert result == "test-run-id"
diff --git a/tests/unit/coal/test_utils/test_utils_input_collector.py b/tests/unit/coal/test_utils/test_utils_input_collector.py
new file mode 100644
index 00000000..57db4941
--- /dev/null
+++ b/tests/unit/coal/test_utils/test_utils_input_collector.py
@@ -0,0 +1,243 @@
+# Copyright (C) - 2023 - 2025 - Cosmo Tech
+# This document and all information contained herein is the exclusive property -
+# including all intellectual property rights pertaining thereto - of Cosmo Tech.
+# Any use, reproduction, translation, broadcasting, transmission, distribution,
+# etc., to any person is prohibited unless it has been previously and
+# specifically authorized by written means by Cosmo Tech.
+
+import json
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from cosmotech.coal.utils.input_collector import (
+    DatasetCollector,
+    InputCollector,
+    ParameterCollector,
+)
+
+
+@pytest.fixture
+def mock_ec(tmp_path):
+    """Fixture that patches EC with a MagicMock whose dataset and parameters paths
+    both point to tmp_path. Override the attributes in individual tests as needed."""
+    ec = MagicMock()
+    ec.cosmotech.dataset_absolute_path = str(tmp_path)
+    ec.cosmotech.parameters_absolute_path = str(tmp_path)
+    with patch("cosmotech.coal.utils.input_collector.EC", ec):
+        yield ec
+
+
+class TestDatasetCollector:
+    def test_fetch_existing_file(self, tmp_path, mock_ec):
+        dataset_dir = tmp_path / "ds1"
+        dataset_dir.mkdir()
+        file = dataset_dir / "mydata.csv"
+        file.write_text("a,b\n1,2")
+
+        mock_ec.cosmotech.dataset_absolute_path = str(tmp_path)
+
+        collector = DatasetCollector()
+        result = collector.fetch("mydata.csv")
+
+        assert result == file
+
+    def test_fetch_triggers_lazy_collection(self, tmp_path, mock_ec):
+        dataset_dir = tmp_path / "ds1"
+        dataset_dir.mkdir()
+        (dataset_dir / "file.csv").write_text("")
+
+        mock_ec.cosmotech.dataset_absolute_path = str(tmp_path)
+
+        collector = DatasetCollector()
+        assert collector.paths == {}
+        collector.fetch("file.csv")
+        assert "file.csv" in collector.paths
+
+    def test_fetch_missing_file_raises(self, tmp_path, mock_ec):
+        dataset_dir = tmp_path / "ds1"
+        dataset_dir.mkdir()
+
+        mock_ec.cosmotech.dataset_absolute_path = str(tmp_path)
+
+        collector = DatasetCollector()
+        with pytest.raises(FileNotFoundError):
+            collector.fetch("nonexistent.csv")
+
+    def test_collect_indexes_nested_files(self, tmp_path, mock_ec):
+        sub = tmp_path / "ds1" / "sub"
+        sub.mkdir(parents=True)
+        f = sub / "nested.csv"
+        f.write_text("")
+
+        mock_ec.cosmotech.dataset_absolute_path = str(tmp_path)
+
+        collector = DatasetCollector()
+        collector.collect()
+
+        assert "nested.csv" in collector.paths
+        assert collector.paths["nested.csv"] == f
+
+
+class TestParameterCollector:
+    def test_init_starts_with_empty_dicts(self, tmp_path, mock_ec):
+        """parameters.json is no longer read at init; loading is now lazy."""
+        params = [{"parameterId": "alpha", "value": "42"}]
+        (tmp_path / "parameters.json").write_text(json.dumps(params))
+
+        collector = ParameterCollector()
+
+        assert collector.parameters == {}
+        assert collector.paths == {}
+
+    def test_read_parameters_json_populates_parameters(self, tmp_path, mock_ec):
+        params = [{"parameterId": "alpha", "value": "42"}, {"parameterId": "beta", "value": "hello"}]
+        (tmp_path / "parameters.json").write_text(json.dumps(params))
+
+        collector = ParameterCollector()
+        collector.read_parameters_json()
+
+        assert collector.parameters["alpha"] == "42"
+        assert collector.parameters["beta"] == "hello"
+
+    def test_read_parameters_json_no_file_keeps_empty(self, tmp_path, mock_ec):
+        collector = ParameterCollector()
+        collector.read_parameters_json()
+
+        assert collector.parameters == {}
+
+    def test_fetch_parameter_triggers_lazy_load(self, tmp_path, mock_ec):
+        params = [{"parameterId": "myparam", "value": "myvalue"}]
+        (tmp_path / "parameters.json").write_text(json.dumps(params))
+
+        collector = ParameterCollector()
+        assert collector.parameters == {}
+
+        result = collector.fetch_parameter("myparam")
+
+        assert result == "myvalue"
+        assert "myparam" in collector.parameters
+
+    def test_fetch_parameter_does_not_reload_if_already_loaded(self, tmp_path, mock_ec):
+        params = [{"parameterId": "key", "value": "first"}]
+        (tmp_path / "parameters.json").write_text(json.dumps(params))
+
+        collector = ParameterCollector()
+        collector.read_parameters_json()
+        # Mutate to confirm no second load overwrites it
+        collector.parameters["key"] = "modified"
+
+        result = collector.fetch_parameter("key")
+
+        assert result == "modified"
+
+    def test_fetch_parameter_raises_key_error_for_unknown(self, tmp_path, mock_ec):
+        (tmp_path / "parameters.json").write_text(json.dumps([]))
+
+        collector = ParameterCollector()
+        with pytest.raises(KeyError):
+            collector.fetch_parameter("nonexistent")
+
+    def test_fetch_file_path_returns_file(self, tmp_path, mock_ec):
+        param_dir = tmp_path / "myparam"
+        param_dir.mkdir()
+        f = param_dir / "data.csv"
+        f.write_text("")
+
+        collector = ParameterCollector()
+        result = collector.fetch_file_path("myparam")
+
+        assert result == f
+
+    def test_fetch_file_path_missing_raises(self, tmp_path, mock_ec):
+        collector = ParameterCollector()
+        with pytest.raises(FileNotFoundError):
+            collector.fetch_file_path("nonexistent")
+
+    def test_fetch_returns_value_if_already_in_parameters(self, tmp_path, mock_ec):
+        """fetch() checks self.parameters directly; no lazy load is triggered."""
+        collector = ParameterCollector()
+        collector.parameters["preloaded"] = "value"
+
+        result = collector.fetch("preloaded")
+
+        assert result == "value"
+
+    def test_fetch_falls_back_to_file_without_lazy_load(self, tmp_path, mock_ec):
+        """fetch() does NOT trigger read_parameters_json; it falls straight to fetch_file_path."""
+        param_dir = tmp_path / "myparam"
+        param_dir.mkdir()
+        f = param_dir / "data.csv"
+        f.write_text("")
+
+        # parameters.json exists but fetch() won't load it
+        (tmp_path / "parameters.json").write_text(json.dumps([{"parameterId": "myparam", "value": "json_val"}]))
+
+        collector = ParameterCollector()
+        result = collector.fetch("myparam")
+
+        assert result == f
+
+
+class TestInputCollector:
+    def test_fetch_parameter_returns_preloaded_value(self, tmp_path, mock_ec):
+        """InputCollector.fetch_parameter calls parameter_collector.fetch(),
+        which checks self.parameters directly without lazy-loading JSON."""
+        collector = InputCollector()
+        collector.parameter_collector.parameters["key"] = "val"
+
+        assert collector.fetch_parameter("key") == "val"
+
+    def test_fetch_parameter_falls_back_to_file(self, tmp_path, mock_ec):
+        param_dir = tmp_path / "myparam"
+        param_dir.mkdir()
+        f = param_dir / "data.csv"
+        f.write_text("")
+
+        collector = InputCollector()
+        result = collector.fetch_parameter("myparam")
+
+        assert result == f
+
+    def test_fetch_dataset_delegates_to_dataset_collector(self, tmp_path, mock_ec):
+        ds_dir = tmp_path / "ds" / "ds1"
+        ds_dir.mkdir(parents=True)
+        f = ds_dir / "myfile.csv"
+        f.write_text("")
+
+        mock_ec.cosmotech.dataset_absolute_path = str(tmp_path / "ds")
+
+        collector = InputCollector()
+        result = collector.fetch_dataset("myfile.csv")
+
+        assert result == f
+
+    def test_fetch_tries_parameter_first(self, tmp_path, mock_ec):
+        """fetch() tries fetch_parameter first: a preloaded parameter wins over a matching dataset file."""
+        ds_dir = tmp_path / "ds" / "ds1"
+        ds_dir.mkdir(parents=True)
+        f = ds_dir / "fallback.csv"
+        f.write_text("")
+
+        mock_ec.cosmotech.dataset_absolute_path = str(tmp_path / "ds")
+
+        collector = InputCollector()
+        # Pre-load a parameter so fetch_parameter returns it without touching files
+        collector.parameter_collector.parameters["fallback.csv"] = "param_value"
+
+        result = collector.fetch("fallback.csv")
+
+        assert result == "param_value"
+
+    def test_fetch_falls_back_to_dataset(self, tmp_path, mock_ec):
+        ds_dir = tmp_path / "ds" / "ds1"
+        ds_dir.mkdir(parents=True)
+        f = ds_dir / "fallback.csv"
+        f.write_text("")
+
+        mock_ec.cosmotech.dataset_absolute_path = str(tmp_path / "ds")
+
+        collector = InputCollector()
+        result = collector.fetch("fallback.csv")
+
+        assert result == f
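
Usage sketch (illustrative, not part of the patch): how the InputCollector API added
above could be called from model code. ENVIRONMENT_INPUT_COLLECTOR is the module-level
singleton defined in input_collector.py; the input names "max_iterations" and
"transport.csv" are hypothetical.

    from cosmotech.coal.utils.input_collector import ENVIRONMENT_INPUT_COLLECTOR

    # Load scalar values from parameters.json explicitly, then fetch one as a str.
    # Note that InputCollector.fetch_parameter goes through ParameterCollector.fetch,
    # which does not lazy-load the JSON on its own (see the tests above).
    ENVIRONMENT_INPUT_COLLECTOR.parameter_collector.read_parameters_json()
    max_iterations = ENVIRONMENT_INPUT_COLLECTOR.fetch_parameter("max_iterations")

    # Dataset files are resolved lazily by walking the dataset directory; a Path is returned.
    transport_file = ENVIRONMENT_INPUT_COLLECTOR.fetch_dataset("transport.csv")

    # fetch() tries the parameter side first and falls back to dataset files
    # on KeyError / FileNotFoundError.
    value_or_path = ENVIRONMENT_INPUT_COLLECTOR.fetch("transport.csv")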