Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions csv_detective/detection/formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ def detect_formats(
if len(formats) == 0:
return analysis, None

# filter parent-children to only active formats (tags may exclude some)
active_parent_children = {
parent: [c for c in children if c in formats]
for parent, children in fmtm._children.items()
if parent in formats
}

# Perform testing on fields
if not in_chunks:
# table is small enough to be tested in one go
Expand All @@ -49,6 +56,7 @@ def detect_formats(
limited_output=limited_output,
skipna=skipna,
verbose=verbose,
parent_children=active_parent_children,
)
handle_empty_columns(scores_table_fields)
res_categorical, _ = detect_categorical_variable(
Expand All @@ -67,6 +75,7 @@ def detect_formats(
limited_output=limited_output,
skipna=skipna,
verbose=verbose,
parent_children=active_parent_children,
)
analysis["columns_fields"] = prepare_output_dict(scores_table_fields, limited_output)

Expand Down
24 changes: 23 additions & 1 deletion csv_detective/format.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import defaultdict
from typing import Any, Callable

from csv_detective.parsing.text import header_score
Expand All @@ -15,6 +16,7 @@ def __init__(
tags: list[str] = [],
mandatory_label: bool = False,
python_type: str = "string",
parent: str | None = None,
) -> None:
"""
Instantiates a Format object.
Expand All @@ -39,6 +41,7 @@ def __init__(
self.tags: list[str] = tags
self.mandatory_label: bool = mandatory_label
self.python_type: str = python_type
self.parent: str | None = parent

def is_valid_label(self, val: str) -> float:
return header_score(val, self.labels)
Expand Down Expand Up @@ -82,7 +85,7 @@ def __init__(
_test_values=module._test_values,
**{
attr: val
for attr in ["labels", "description", "tags", "mandatory_label", "python_type"]
for attr in ["labels", "description", "tags", "mandatory_label", "python_type", "parent"]
if (val := getattr(module, attr, None))
}
| {
Expand All @@ -100,6 +103,25 @@ def __init__(
)
for label in format_labels
}
self._children: dict[str, list[str]] = defaultdict(list)
for name, fmt in self.formats.items():
if fmt.parent is not None:
if fmt.parent not in self.formats:
raise ValueError(
f"Format '{name}' declares parent '{fmt.parent}' which does not exist"
)
self._children[fmt.parent].append(name)

def get_leaf_formats(self) -> dict[str, Format]:
    """
    Return the formats whose name is not a key of `self._children`,
    i.e. the formats that no other format declares as its parent.
    """
    leaves = {}
    for fmt_name, fmt in self.formats.items():
        # membership test only: does not create a key in the defaultdict
        if fmt_name in self._children:
            continue
        leaves[fmt_name] = fmt
    return leaves

def get_ancestors(self, format_name: str) -> list[str]:
    """
    Return the names of all ancestors of `format_name`, nearest parent first.

    The chain is built by following the `parent` attribute of each format
    in `self.formats` until a format without parent is reached.

    Raises:
        KeyError: if `format_name` (or a declared parent) is not in `self.formats`.
        ValueError: if the parent declarations form a cycle.
    """
    ancestors: list[str] = []
    visited = {format_name}
    current = format_name
    while (parent := self.formats[current].parent) is not None:
        if parent in visited:
            # without this guard a circular declaration would loop forever
            raise ValueError(
                f"Cycle detected in the parent chain of format '{format_name}' at '{parent}'"
            )
        visited.add(parent)
        ancestors.append(parent)
        current = parent
    return ancestors

def get_formats_from_tags(self, tags: list[str]) -> dict[str, Format]:
return {
Expand Down
1 change: 1 addition & 0 deletions csv_detective/formats/code_epci.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from csv_detective.formats.siren import _is as is_siren

proportion = 0.9
parent = "siren"
description = "French EPCI (group of communes) code, subgroup of SIREN"
tags = ["fr", "geo"]
mandatory_label = True
Expand Down
1 change: 1 addition & 0 deletions csv_detective/formats/geojson.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json

proportion = 1
parent = "json"
description = "JSON object in the [GeoJSON](https://fr.wikipedia.org/wiki/GeoJSON) format"
tags = ["geo"]
python_type = "json"
Expand Down
1 change: 1 addition & 0 deletions csv_detective/formats/latitude_l93.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from csv_detective.formats.latitude_wgs import SHARED_LATITUDE_LABELS

proportion = 1
parent = "float"
description = "Latitude in the Lambert 93 format"
tags = ["fr", "geo"]
mandatory_label = True
Expand Down
1 change: 1 addition & 0 deletions csv_detective/formats/latitude_wgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from csv_detective.formats.int import _is as is_int

proportion = 1
parent = "float"
description = "Latitude in the WGS format"
tags = ["geo"]
mandatory_label = True
Expand Down
1 change: 1 addition & 0 deletions csv_detective/formats/latitude_wgs_fr_metropole.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from csv_detective.formats.latitude_wgs import _is as is_latitude, labels # noqa

proportion = 1
parent = "latitude_wgs"
description = "Latitude within the French metropole bounds in the WGS format"
tags = ["fr", "geo"]
mandatory_label = True
Expand Down
1 change: 1 addition & 0 deletions csv_detective/formats/longitude_l93.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from csv_detective.formats.longitude_wgs import SHARED_LONGITUDE_LABELS

proportion = 1
parent = "float"
description = "Longitude in the Lambert 93 format"
tags = ["fr", "geo"]
mandatory_label = True
Expand Down
1 change: 1 addition & 0 deletions csv_detective/formats/longitude_wgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from csv_detective.formats.int import _is as is_int

proportion = 1
parent = "float"
description = "Longitude in the WGS format"
tags = ["geo"]
mandatory_label = True
Expand Down
1 change: 1 addition & 0 deletions csv_detective/formats/longitude_wgs_fr_metropole.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from csv_detective.formats.longitude_wgs import _is as is_longitude, labels # noqa

proportion = 1
parent = "longitude_wgs"
description = "Longitude within the French metropole bounds in the WGS format"
tags = ["fr", "geo"]
mandatory_label = True
Expand Down
86 changes: 77 additions & 9 deletions csv_detective/parsing/columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,25 @@ def test_col(
limited_output: bool,
skipna: bool = True,
verbose: bool = False,
parent_children: dict[str, list[str]] | None = None,
):
if verbose:
start = time()
logging.info("Testing columns to get formats")
return_table = pd.DataFrame(columns=table.columns)
for idx, (label, format) in enumerate(formats.items()):

if parent_children:
parent_names = set(parent_children.keys())
leaf_formats = {k: v for k, v in formats.items() if k not in parent_names}
parent_formats = {k: v for k, v in formats.items() if k in parent_names}
else:
leaf_formats = formats
parent_formats = {}

for idx, (label, format) in enumerate(leaf_formats.items()):
if verbose:
start_type = time()
logging.info(f"\t- Starting with format '{label}'")
# improvement lead : put the longest tests behind and make them only if previous tests not satisfactory
# => the following needs to change, "apply" means all columns are tested for one type at once
for col in table.columns:
return_table.loc[label, col] = test_col_val(
table[col],
Expand All @@ -96,9 +104,38 @@ def test_col(
)
if verbose:
display_logs_depending_process_time(
f'\t> Done with format "{label}" in {round(time() - start_type, 3)}s ({idx + 1}/{len(formats)})',
f'\t> Done with format "{label}" in {round(time() - start_type, 3)}s ({idx + 1}/{len(leaf_formats)})',
time() - start_type,
)

for label, format in parent_formats.items():
if verbose:
start_type = time()
logging.info(f"\t- Propagating to parent format '{label}'")
children = parent_children[label]
for col in table.columns:
child_scores = [
return_table.loc[child, col]
for child in children
if child in return_table.index
]
max_child = max(child_scores) if child_scores else 0
if max_child > 0:
return_table.loc[label, col] = max_child
else:
return_table.loc[label, col] = test_col_val(
table[col],
format,
skipna=skipna,
limited_output=limited_output,
verbose=verbose,
)
if verbose:
display_logs_depending_process_time(
f'\t> Done with parent format "{label}" in {round(time() - start_type, 3)}s',
time() - start_type,
)

if verbose:
display_logs_depending_process_time(
f"Done testing columns in {round(time() - start, 3)}s", time() - start
Expand Down Expand Up @@ -138,6 +175,7 @@ def test_col_chunks(
limited_output: bool,
skipna: bool = True,
verbose: bool = False,
parent_children: dict[str, list[str]] | None = None,
) -> tuple[pd.DataFrame, dict, dict[str, pd.Series]]:
def build_remaining_tests_per_col(return_table: pd.DataFrame) -> dict[str, list[str]]:
# returns a dict with the table's columns as keys and the list of remaining format labels to apply
Expand All @@ -156,7 +194,10 @@ def build_remaining_tests_per_col(return_table: pd.DataFrame) -> dict[str, list[
logging.info("Testing columns to get formats on chunks")

# analysing the sample to get a first guess
return_table = test_col(table, formats, limited_output, skipna=skipna, verbose=verbose)
return_table = test_col(
table, formats, limited_output, skipna=skipna, verbose=verbose,
parent_children=parent_children,
)
# mandatory_label formats are zeroed out at the end if the label doesn't match,
# so there's no point running the expensive field tests on those columns
mandatory_label_skip: dict[str, set[str]] = {
Expand Down Expand Up @@ -229,23 +270,50 @@ def build_remaining_tests_per_col(return_table: pd.DataFrame) -> dict[str, list[
if not any(remaining_tests for remaining_tests in remaining_tests_per_col.values()):
# no more potential tests to do on any column, early stop
break
parent_names = set(parent_children.keys()) if parent_children else set()
for col, fmt_labels in remaining_tests_per_col.items():
# testing each column with the tests that are still competing
# after previous batchs analyses
# testing each column with the leaf tests that are still competing
for label in fmt_labels:
if label in parent_names:
continue
batch_col_test = test_col_val(
batch[col],
formats[label],
limited_output=limited_output,
skipna=skipna,
)
return_table.loc[label, col] = (
# if this batch's column tested 0 then test fails overall
0
if batch_col_test == 0
# otherwise updating the score with weighted average
else ((return_table.loc[label, col] * idx + batch_col_test) / (idx + 1))
)
# propagate to parent formats
for parent_label in fmt_labels:
if parent_label not in parent_names:
continue
children = parent_children[parent_label]
child_scores = [
return_table.loc[child, col]
for child in children
if child in return_table.index
]
max_child = max(child_scores) if child_scores else 0
if max_child > 0:
return_table.loc[parent_label, col] = max_child
else:
batch_col_test = test_col_val(
batch[col],
formats[parent_label],
limited_output=limited_output,
skipna=skipna,
)
return_table.loc[parent_label, col] = (
0
if batch_col_test == 0
else (
(return_table.loc[parent_label, col] * idx + batch_col_test) / (idx + 1)
)
)
remaining_tests_per_col = build_remaining_tests_per_col(return_table)
batch, batch_number = [], batch_number + 1
analysis["nb_duplicates"] = sum(row_hashes_count > 1)
Expand Down
28 changes: 28 additions & 0 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,34 @@ def test_early_detection(args):
mock_func.assert_not_called()


def test_parent_score_propagation():
    """Valid metropolitan latitudes must score on the format and on each of its ancestors."""
    # values valid for latitude_wgs_fr_metropole are also valid floats,
    # so latitude_wgs and float are both expected to end up with a positive score
    valid_latitudes = fmtm.formats["latitude_wgs_fr_metropole"]._test_values[True]
    column = (valid_latitudes * 100)[:100]
    table = pd.DataFrame({"lat": column})
    parent_children = {parent: list(children) for parent, children in fmtm._children.items()}
    returned_table = col_test(
        table,
        fmtm.formats,
        limited_output=True,
        parent_children=parent_children,
    )
    for label in ("latitude_wgs_fr_metropole", "latitude_wgs", "float"):
        assert returned_table.loc[label, "lat"] > 0


def test_parent_tested_when_children_fail():
    """Floats that are not latitudes must still be scored as float."""
    raw_values = ["999.99", "-500.5", "123456.789"] * 34
    table = pd.DataFrame({"val": raw_values})
    parent_children = {parent: list(children) for parent, children in fmtm._children.items()}
    returned_table = col_test(
        table,
        fmtm.formats,
        limited_output=True,
        parent_children=parent_children,
    )
    # the parent keeps a positive score...
    assert returned_table.loc["float", "val"] > 0
    # ...while both latitude formats score zero on these values
    for child in ("latitude_wgs", "latitude_wgs_fr_metropole"):
        assert returned_table.loc[child, "val"] == 0


def test_all_proportion_1():
# building a table that uses only correct values for these formats, except on one row
table = pd.DataFrame(
Expand Down
30 changes: 30 additions & 0 deletions tests/test_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,36 @@ def test_get_from_tags(tags):
assert tag in fmt.tags


def test_parent_references_valid():
    """Every parent declared by a format must itself be a registered format."""
    with_parent = [(name, fmt) for name, fmt in fmtm.formats.items() if fmt.parent is not None]
    for name, fmt in with_parent:
        parent_is_known = fmt.parent in fmtm.formats
        assert parent_is_known, (
            f"Format '{name}' declares parent '{fmt.parent}' which does not exist"
        )


def test_true_subset_invariant():
    """Each valid test value of a child format must also pass its parent's `func`."""
    child_formats = [(name, fmt) for name, fmt in fmtm.formats.items() if fmt.parent is not None]
    for name, fmt in child_formats:
        parent = fmtm.formats[fmt.parent]
        valid_values = fmt._test_values[True]
        for val in valid_values:
            assert parent.func(val), (
                f"'{val}' is valid for '{name}' but not for parent '{fmt.parent}'"
            )


def test_leaf_formats_excludes_parents():
    """No format registered as a parent may appear among the leaf formats."""
    leaves = fmtm.get_leaf_formats()
    assert all(parent_name not in leaves for parent_name in fmtm._children)


def test_get_ancestors():
    """Ancestors are returned nearest parent first."""
    expected = ["latitude_wgs", "float"]
    assert fmtm.get_ancestors("latitude_wgs_fr_metropole") == expected


@pytest.mark.parametrize(
"func, max_pos_args",
(
Expand Down