diff --git a/exaflow/algorithms/federated/statistics/histogram.py b/exaflow/algorithms/federated/statistics/histogram.py index e4e6f86d3..487dea91a 100644 --- a/exaflow/algorithms/federated/statistics/histogram.py +++ b/exaflow/algorithms/federated/statistics/histogram.py @@ -85,7 +85,9 @@ def hist( ) def _value_counts(self, series: pd.Series, categories) -> List[int]: - counts = series.value_counts() + # Use a plain dict to avoid pandas Series.get positional fallback on + # integer-like keys (e.g. key=0 on index ["1"] returning first element). + counts = series.value_counts().to_dict() resolved = [] for cat in categories: count = counts.get(cat, 0) @@ -94,7 +96,7 @@ def _value_counts(self, series: pd.Series, categories) -> List[int]: resolved.append(int(count)) return resolved - def _value_counts_numeric_fallback(self, counts, cat: str) -> int: + def _value_counts_numeric_fallback(self, counts: Dict, cat: str) -> int: for caster in (float, int): try: coerced = caster(cat) diff --git a/tests/standalone_tests/federated_algorithms/statistics/test_histogram.py b/tests/standalone_tests/federated_algorithms/statistics/test_histogram.py index 665cfc9b3..6a054bd52 100644 --- a/tests/standalone_tests/federated_algorithms/statistics/test_histogram.py +++ b/tests/standalone_tests/federated_algorithms/statistics/test_histogram.py @@ -296,3 +296,31 @@ def test_federated_algorithm_with_multiple_workers(self, case): y_levels=y_levels, min_row_count=1, ) + + @pytest.mark.parametrize("n_workers", [1, 3]) + def test_numeric_like_enum_codes_are_counted_by_label(self, n_workers): + y_levels = ["0", "1", "9"] + df = pd.DataFrame( + { + "y": ["1", "1", "1", "1", "1", "1"], + "group": ["A", "A", "B", "B", "C", "C"], + } + ) + metadata = { + "y": { + "is_categorical": True, + "enumerations": {"0": "No", "1": "Yes", "9": "Unknown"}, + }, + "group": { + "is_categorical": True, + "enumerations": {"A": "A", "B": "B", "C": "C"}, + }, + } + self.run_comparison( + X=df, + y=np.zeros((df.shape[0],), dtype=float), + n_workers=n_workers, + metadata=metadata, + y_levels=y_levels, + min_row_count=1, + )