diff --git a/csv_detective/detection/formats.py b/csv_detective/detection/formats.py index 616cc82..a453e6a 100755 --- a/csv_detective/detection/formats.py +++ b/csv_detective/detection/formats.py @@ -40,6 +40,13 @@ def detect_formats( if len(formats) == 0: return analysis, None + # filter parent-children to only active formats (tags may exclude some) + active_parent_children = { + parent: [c for c in children if c in formats] + for parent, children in fmtm._children.items() + if parent in formats + } + # Perform testing on fields if not in_chunks: # table is small enough to be tested in one go @@ -49,6 +56,7 @@ def detect_formats( limited_output=limited_output, skipna=skipna, verbose=verbose, + parent_children=active_parent_children, ) handle_empty_columns(scores_table_fields) res_categorical, _ = detect_categorical_variable( @@ -67,6 +75,7 @@ def detect_formats( limited_output=limited_output, skipna=skipna, verbose=verbose, + parent_children=active_parent_children, ) analysis["columns_fields"] = prepare_output_dict(scores_table_fields, limited_output) diff --git a/csv_detective/format.py b/csv_detective/format.py index 8080757..292c5fe 100755 --- a/csv_detective/format.py +++ b/csv_detective/format.py @@ -1,3 +1,4 @@ +from collections import defaultdict from typing import Any, Callable from csv_detective.parsing.text import header_score @@ -15,6 +16,7 @@ def __init__( tags: list[str] = [], mandatory_label: bool = False, python_type: str = "string", + parent: str | None = None, ) -> None: """ Instanciates a Format object. @@ -39,6 +41,7 @@ def __init__( self.tags: list[str] = tags self.mandatory_label: bool = mandatory_label self.python_type: str = python_type + self.parent: str | None = parent def is_valid_label(self, val: str) -> float: return header_score(val, self.labels) @@ -82,7 +85,7 @@ def __init__( _test_values=module._test_values, **{ attr: val - for attr in ["labels", "description", "tags", "mandatory_label", "python_type"] + for attr in ["labels", "description", "tags", "mandatory_label", "python_type", "parent"] if (val := getattr(module, attr, None)) } | { @@ -100,6 +103,25 @@ def __init__( ) for label in format_labels } + self._children: dict[str, list[str]] = defaultdict(list) + for name, fmt in self.formats.items(): + if fmt.parent is not None: + if fmt.parent not in self.formats: + raise ValueError( + f"Format '{name}' declares parent '{fmt.parent}' which does not exist" + ) + self._children[fmt.parent].append(name) + + def get_leaf_formats(self) -> dict[str, Format]: + return {name: fmt for name, fmt in self.formats.items() if name not in self._children} + + def get_ancestors(self, format_name: str) -> list[str]: + ancestors = [] + current = format_name + while (parent := self.formats[current].parent) is not None: + ancestors.append(parent) + current = parent + return ancestors def get_formats_from_tags(self, tags: list[str]) -> dict[str, Format]: return { diff --git a/csv_detective/formats/code_epci.py b/csv_detective/formats/code_epci.py index 11accb4..0b92f6f 100755 --- a/csv_detective/formats/code_epci.py +++ b/csv_detective/formats/code_epci.py @@ -1,6 +1,7 @@ from csv_detective.formats.siren import _is as is_siren proportion = 0.9 +parent = "siren" description = "French EPCI (group of communes) code, subgroup of SIREN" tags = ["fr", "geo"] mandatory_label = True diff --git a/csv_detective/formats/geojson.py b/csv_detective/formats/geojson.py index e131804..d7e0d24 100755 --- a/csv_detective/formats/geojson.py +++ b/csv_detective/formats/geojson.py @@ -1,6 +1,7 @@ import json proportion = 1 +parent = "json" description = "JSON object in the [GeoJSON](https://fr.wikipedia.org/wiki/GeoJSON) format" tags = ["geo"] python_type = "json" diff --git a/csv_detective/formats/latitude_l93.py b/csv_detective/formats/latitude_l93.py index e18010c..b47d4ba 100755 --- a/csv_detective/formats/latitude_l93.py +++ b/csv_detective/formats/latitude_l93.py @@ -5,6 +5,7 @@ from csv_detective.formats.latitude_wgs import SHARED_LATITUDE_LABELS proportion = 1 +parent = "float" description = "Latitude in the Lambert 93 format" tags = ["fr", "geo"] mandatory_label = True diff --git a/csv_detective/formats/latitude_wgs.py b/csv_detective/formats/latitude_wgs.py index 0df4a80..1008913 100755 --- a/csv_detective/formats/latitude_wgs.py +++ b/csv_detective/formats/latitude_wgs.py @@ -2,6 +2,7 @@ from csv_detective.formats.int import _is as is_int proportion = 1 +parent = "float" description = "Latitude in the WGS format" tags = ["geo"] mandatory_label = True diff --git a/csv_detective/formats/latitude_wgs_fr_metropole.py b/csv_detective/formats/latitude_wgs_fr_metropole.py index a68dced..a130f6d 100755 --- a/csv_detective/formats/latitude_wgs_fr_metropole.py +++ b/csv_detective/formats/latitude_wgs_fr_metropole.py @@ -1,6 +1,7 @@ from csv_detective.formats.latitude_wgs import _is as is_latitude, labels # noqa proportion = 1 +parent = "latitude_wgs" description = "Latitude within the French metropole bounds in the WGS format" tags = ["fr", "geo"] mandatory_label = True diff --git a/csv_detective/formats/longitude_l93.py b/csv_detective/formats/longitude_l93.py index f3ad521..101aed8 100755 --- a/csv_detective/formats/longitude_l93.py +++ b/csv_detective/formats/longitude_l93.py @@ -5,6 +5,7 @@ from csv_detective.formats.longitude_wgs import SHARED_LONGITUDE_LABELS proportion = 1 +parent = "float" description = "Longitude in the Lambert 93 format" tags = ["fr", "geo"] mandatory_label = True diff --git a/csv_detective/formats/longitude_wgs.py b/csv_detective/formats/longitude_wgs.py index 48f6a82..b2d262e 100755 --- a/csv_detective/formats/longitude_wgs.py +++ b/csv_detective/formats/longitude_wgs.py @@ -2,6 +2,7 @@ from csv_detective.formats.int import _is as is_int proportion = 1 +parent = "float" description = "Longitude in the WGS format" tags = ["geo"] mandatory_label = True diff --git a/csv_detective/formats/longitude_wgs_fr_metropole.py b/csv_detective/formats/longitude_wgs_fr_metropole.py index 883829d..9323c24 100755 --- a/csv_detective/formats/longitude_wgs_fr_metropole.py +++ b/csv_detective/formats/longitude_wgs_fr_metropole.py @@ -1,6 +1,7 @@ from csv_detective.formats.longitude_wgs import _is as is_longitude, labels # noqa proportion = 1 +parent = "longitude_wgs" description = "Longitude within the French metropole bounds in the WGS format" tags = ["fr", "geo"] mandatory_label = True diff --git a/csv_detective/parsing/columns.py b/csv_detective/parsing/columns.py index f29d370..1e37448 100755 --- a/csv_detective/parsing/columns.py +++ b/csv_detective/parsing/columns.py @@ -75,17 +75,25 @@ def test_col( limited_output: bool, skipna: bool = True, verbose: bool = False, + parent_children: dict[str, list[str]] | None = None, ): if verbose: start = time() logging.info("Testing columns to get formats") return_table = pd.DataFrame(columns=table.columns) - for idx, (label, format) in enumerate(formats.items()): + + if parent_children: + parent_names = set(parent_children.keys()) + leaf_formats = {k: v for k, v in formats.items() if k not in parent_names} + parent_formats = {k: v for k, v in formats.items() if k in parent_names} + else: + leaf_formats = formats + parent_formats = {} + + for idx, (label, format) in enumerate(leaf_formats.items()): if verbose: start_type = time() logging.info(f"\t- Starting with format '{label}'") - # improvement lead : put the longest tests behind and make them only if previous tests not satisfactory - # => the following needs to change, "apply" means all columns are tested for one type at once for col in table.columns: return_table.loc[label, col] = test_col_val( table[col], @@ -96,9 +104,38 @@ def test_col( ) if verbose: display_logs_depending_process_time( - f'\t> Done with format "{label}" in {round(time() - start_type, 3)}s ({idx + 1}/{len(formats)})', + f'\t> Done with format "{label}" in {round(time() - start_type, 3)}s ({idx + 1}/{len(leaf_formats)})', + time() - start_type, + ) + + for label, format in parent_formats.items(): + if verbose: + start_type = time() + logging.info(f"\t- Propagating to parent format '{label}'") + children = parent_children[label] + for col in table.columns: + child_scores = [ + return_table.loc[child, col] + for child in children + if child in return_table.index + ] + max_child = max(child_scores) if child_scores else 0 + if max_child > 0: + return_table.loc[label, col] = max_child + else: + return_table.loc[label, col] = test_col_val( + table[col], + format, + skipna=skipna, + limited_output=limited_output, + verbose=verbose, + ) + if verbose: + display_logs_depending_process_time( + f'\t> Done with parent format "{label}" in {round(time() - start_type, 3)}s', time() - start_type, ) + if verbose: display_logs_depending_process_time( f"Done testing columns in {round(time() - start, 3)}s", time() - start @@ -138,6 +175,7 @@ def test_col_chunks( limited_output: bool, skipna: bool = True, verbose: bool = False, + parent_children: dict[str, list[str]] | None = None, ) -> tuple[pd.DataFrame, dict, dict[str, pd.Series]]: def build_remaining_tests_per_col(return_table: pd.DataFrame) -> dict[str, list[str]]: # returns a dict with the table's columns as keys and the list of remaining format labels to apply @@ -156,7 +194,10 @@ def build_remaining_tests_per_col(return_table: pd.DataFrame) -> dict[str, list[ logging.info("Testing columns to get formats on chunks") # analysing the sample to get a first guess - return_table = test_col(table, formats, limited_output, skipna=skipna, verbose=verbose) + return_table = test_col( + table, formats, limited_output, skipna=skipna, verbose=verbose, + parent_children=parent_children, + ) # mandatory_label formats are zeroed out at the end if the label doesn't match, # so there's no point running the expensive field tests on those columns mandatory_label_skip: dict[str, set[str]] = { @@ -229,10 +270,12 @@ def build_remaining_tests_per_col(return_table: pd.DataFrame) -> dict[str, list[ if not any(remaining_tests for remaining_tests in remaining_tests_per_col.values()): # no more potential tests to do on any column, early stop break + parent_names = set(parent_children.keys()) if parent_children else set() for col, fmt_labels in remaining_tests_per_col.items(): - # testing each column with the tests that are still competing - # after previous batchs analyses + # testing each column with the leaf tests that are still competing for label in fmt_labels: + if label in parent_names: + continue batch_col_test = test_col_val( batch[col], formats[label], @@ -240,12 +283,37 @@ def build_remaining_tests_per_col(return_table: pd.DataFrame) -> dict[str, list[ skipna=skipna, ) return_table.loc[label, col] = ( - # if this batch's column tested 0 then test fails overall 0 if batch_col_test == 0 - # otherwise updating the score with weighted average else ((return_table.loc[label, col] * idx + batch_col_test) / (idx + 1)) ) + # propagate to parent formats + for parent_label in fmt_labels: + if parent_label not in parent_names: + continue + children = parent_children[parent_label] + child_scores = [ + return_table.loc[child, col] + for child in children + if child in return_table.index + ] + max_child = max(child_scores) if child_scores else 0 + if max_child > 0: + return_table.loc[parent_label, col] = max_child + else: + batch_col_test = test_col_val( + batch[col], + formats[parent_label], + limited_output=limited_output, + skipna=skipna, + ) + return_table.loc[parent_label, col] = ( + 0 + if batch_col_test == 0 + else ( + (return_table.loc[parent_label, col] * idx + batch_col_test) / (idx + 1) + ) + ) remaining_tests_per_col = build_remaining_tests_per_col(return_table) batch, batch_number = [], batch_number + 1 analysis["nb_duplicates"] = sum(row_hashes_count > 1) diff --git a/tests/test_fields.py b/tests/test_fields.py index 0afbede..8a9962c 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -160,6 +160,34 @@ def test_early_detection(args): mock_func.assert_not_called() +def test_parent_score_propagation(): + # latitude_wgs_fr_metropole values are valid floats, + # so float should get a propagated score without being tested directly + lat_values = fmtm.formats["latitude_wgs_fr_metropole"]._test_values[True] + table = pd.DataFrame({"lat": (lat_values * 100)[:100]}) + parent_children = { + parent: [c for c in children] + for parent, children in fmtm._children.items() + } + returned_table = col_test(table, fmtm.formats, limited_output=True, parent_children=parent_children) + assert returned_table.loc["latitude_wgs_fr_metropole", "lat"] > 0 + assert returned_table.loc["latitude_wgs", "lat"] > 0 + assert returned_table.loc["float", "lat"] > 0 + + +def test_parent_tested_when_children_fail(): + # float values that are NOT latitudes should still get float detected + table = pd.DataFrame({"val": ["999.99", "-500.5", "123456.789"] * 34}) + parent_children = { + parent: [c for c in children] + for parent, children in fmtm._children.items() + } + returned_table = col_test(table, fmtm.formats, limited_output=True, parent_children=parent_children) + assert returned_table.loc["float", "val"] > 0 + assert returned_table.loc["latitude_wgs", "val"] == 0 + assert returned_table.loc["latitude_wgs_fr_metropole", "val"] == 0 + + def test_all_proportion_1(): # building a table that uses only correct values for these formats, except on one row table = pd.DataFrame( diff --git a/tests/test_structure.py b/tests/test_structure.py index 8bf65b9..3358483 100755 --- a/tests/test_structure.py +++ b/tests/test_structure.py @@ -52,6 +52,36 @@ def test_get_from_tags(tags): assert tag in fmt.tags +def test_parent_references_valid(): + for name, fmt in fmtm.formats.items(): + if fmt.parent is not None: + assert fmt.parent in fmtm.formats, ( + f"Format '{name}' declares parent '{fmt.parent}' which does not exist" + ) + + +def test_true_subset_invariant(): + for name, fmt in fmtm.formats.items(): + if fmt.parent is None: + continue + parent = fmtm.formats[fmt.parent] + for val in fmt._test_values[True]: + assert parent.func(val), ( + f"'{val}' is valid for '{name}' but not for parent '{fmt.parent}'" + ) + + +def test_leaf_formats_excludes_parents(): + leaves = fmtm.get_leaf_formats() + for parent_name in fmtm._children: + assert parent_name not in leaves + + +def test_get_ancestors(): + ancestors = fmtm.get_ancestors("latitude_wgs_fr_metropole") + assert ancestors == ["latitude_wgs", "float"] + + @pytest.mark.parametrize( "func, max_pos_args", (