From bdf3913a7011fe0ad3745920ac10118b265264c6 Mon Sep 17 00:00:00 2001 From: BaptisteDE Date: Fri, 5 Sep 2025 18:46:38 +0200 Subject: [PATCH 1/3] =?UTF-8?q?=F0=9F=9A=A7=20Sample=20values=20numpy=20to?= =?UTF-8?q?=20DataFrame.=20Tests=20are=20failing?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- corrai/sampling.py | 20 +++++++++++--------- tests/test_sampling.py | 16 +++++++++++----- 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/corrai/sampling.py b/corrai/sampling.py index 791d665..12ee2ec 100644 --- a/corrai/sampling.py +++ b/corrai/sampling.py @@ -4,6 +4,7 @@ import numpy as np import pandas as pd +import parso.pgen2.grammar_parser import plotly.graph_objects as go import datetime as dt @@ -114,23 +115,23 @@ class Sample: """ parameters: list[Parameter] - values: np.ndarray = field(init=False) + values: pd.DataFrame = field(init=False) results: pd.Series = field(default_factory=lambda: pd.Series(dtype=object)) def __post_init__(self): - self.values = np.empty((0, len(self.parameters))) + self.values = pd.DataFrame(columns=[par.name for par in self.parameters]) def __len__(self): return self.values.shape[0] def __getitem__(self, idx): if isinstance(idx, (int, slice, list, np.ndarray)): - return {"values": self.values[idx], "results": self.results[idx]} + return {"values": self.values.loc[idx, :], "results": self.results[idx]} raise TypeError(f"Unsupported index type: {type(idx)}") def __setitem__(self, idx, item: dict): if "values" in item: - self.values[idx] = item["values"] + self.values.loc[idx, :] = item["values"] if "results" in item: if isinstance(idx, int): self.results.at[idx] = item["results"] @@ -201,12 +202,12 @@ def get_list_parameter_value_pairs( """ selected_values = self[idx]["values"] - if selected_values.ndim == 1: - selected_values = selected_values[np.newaxis, :] + if isinstance(selected_values, pd.Series): + selected_values = selected_values.to_frame() return [ [(par, val) for par, val in zip(self.parameters, row)] - for row in selected_values + for row in selected_values.values ] def get_dimension_less_values( @@ -250,7 +251,8 @@ def add_samples(self, values: np.ndarray, results: list[pd.DataFrame] = None): n_samples, n_params = values.shape assert n_params == len(self.parameters), "Mismatch in number of parameters" - self.values = np.vstack([self.values, values]) + new_df = pd.DataFrame(values, columns=self.values.columns) + self.values = pd.concat([self.values, new_df], ignore_index=True) if results is None: new_results = pd.Series([pd.DataFrame()] * n_samples, dtype=object) @@ -539,7 +541,7 @@ def _legend_for(i: int) -> str: if not show_legends: return "Simulations" parameter_names = [par.name for par in self.parameters] - vals = self.values[i, :] + vals = self.values.loc[i, :].values return ", ".join( f"{n}: {round(v, round_ndigits)}" for n, v in zip(parameter_names, vals) ) diff --git a/tests/test_sampling.py b/tests/test_sampling.py index 1a4a095..ffe59f2 100644 --- a/tests/test_sampling.py +++ b/tests/test_sampling.py @@ -58,7 +58,11 @@ def test_sample_functions(self): ) assert sample.get_pending_index().tolist() == [True, False] - assert sample.values.tolist() == [[1.0, 0.9, 10.0], [3.0, 0.85, 20.0]] + pd.testing.assert_frame_equal(sample.values, pd.DataFrame({'param_1': {0: 1.0, 1: 3.0}, + 'param_2': {0: 0.9, 1: 0.85}, + 'param_3': {0: 10.0, 1: 20.0}})) + + assert sample.get_parameters_intervals().tolist() == [ [0.0, 10.0], [0.8, 1.2], @@ -83,13 +87,15 @@ def test_sample_functions(self): "values": np.array([9.9, 1.1, 88]), "results": pd.DataFrame({"res": [123]}, index=[pd.Timestamp("2009-01-01")]), } - np.testing.assert_allclose(sample.values[0], [9.9, 1.1, 88]) + pd.testing.assert_series_equal(sample.values.loc[0], pd.Series({'param_1': 9.9, 'param_2': 1.1, 'param_3': 88.0}, name=0)) assert not sample.results.iloc[0].empty dimless_val = sample.get_dimension_less_values() - np.testing.assert_allclose( - dimless_val, np.array([[0.99, 0.75, 0.88], [0.3, 0.125, 0.2]]) - ) + pd.testing.assert_frame_equal( + dimless_val, pd.DataFrame({'param_1': {0: 0.99, 1: 0.3}, + 'param_2': {0: 0.7500000000000003, 1: 0.12499999999999986}, + 'param_3': {0: 0.88, 1: 0.2}})) + pd.testing.assert_frame_equal( sample.get_aggregated_time_series("res"), From 9d479ae34b666d3d0cb7b6e531b096779d6bae85 Mon Sep 17 00:00:00 2001 From: BaptisteDE Date: Mon, 8 Sep 2025 11:29:09 +0200 Subject: [PATCH 2/3] =?UTF-8?q?=E2=9C=85=20fixed=20Sampling=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- corrai/sampling.py | 12 +++++++++--- tests/test_sampling.py | 40 +++++++++++++++++++++++++++------------- 2 files changed, 36 insertions(+), 16 deletions(-) diff --git a/corrai/sampling.py b/corrai/sampling.py index 12ee2ec..f81733b 100644 --- a/corrai/sampling.py +++ b/corrai/sampling.py @@ -4,7 +4,6 @@ import numpy as np import pandas as pd -import parso.pgen2.grammar_parser import plotly.graph_objects as go import datetime as dt @@ -126,7 +125,7 @@ def __len__(self): def __getitem__(self, idx): if isinstance(idx, (int, slice, list, np.ndarray)): - return {"values": self.values.loc[idx, :], "results": self.results[idx]} + return {"values": self.values.loc[idx, :], "results": self.results.loc[idx]} raise TypeError(f"Unsupported index type: {type(idx)}") def __setitem__(self, idx, item: dict): @@ -135,6 +134,10 @@ def __setitem__(self, idx, item: dict): if "results" in item: if isinstance(idx, int): self.results.at[idx] = item["results"] + elif isinstance(idx, slice): + self.results.loc[idx] = pd.Series( + item["results"], index=self.results.loc[idx].index + ) else: self.results.iloc[idx] = pd.Series( item["results"], index=self.results.index[idx] @@ -147,6 +150,9 @@ def _validate(self): self.values ), f"Mismatch: {len(self.values)} values vs {len(self.results)} results" + if not self.values.index.equals(self.results.index): + raise ValueError("Mismatch between values and results indices") + def get_pending_index(self) -> np.ndarray: """ Identify which samples have not yet been simulated. @@ -203,7 +209,7 @@ def get_list_parameter_value_pairs( selected_values = self[idx]["values"] if isinstance(selected_values, pd.Series): - selected_values = selected_values.to_frame() + selected_values = selected_values.to_frame().T return [ [(par, val) for par, val in zip(self.parameters, row)] diff --git a/tests/test_sampling.py b/tests/test_sampling.py index ffe59f2..42a06b9 100644 --- a/tests/test_sampling.py +++ b/tests/test_sampling.py @@ -58,10 +58,16 @@ def test_sample_functions(self): ) assert sample.get_pending_index().tolist() == [True, False] - pd.testing.assert_frame_equal(sample.values, pd.DataFrame({'param_1': {0: 1.0, 1: 3.0}, - 'param_2': {0: 0.9, 1: 0.85}, - 'param_3': {0: 10.0, 1: 20.0}})) - + pd.testing.assert_frame_equal( + sample.values, + pd.DataFrame( + { + "param_1": {0: 1.0, 1: 3.0}, + "param_2": {0: 0.9, 1: 0.85}, + "param_3": {0: 10.0, 1: 20.0}, + } + ), + ) assert sample.get_parameters_intervals().tolist() == [ [0.0, 10.0], @@ -87,15 +93,23 @@ def test_sample_functions(self): "values": np.array([9.9, 1.1, 88]), "results": pd.DataFrame({"res": [123]}, index=[pd.Timestamp("2009-01-01")]), } - pd.testing.assert_series_equal(sample.values.loc[0], pd.Series({'param_1': 9.9, 'param_2': 1.1, 'param_3': 88.0}, name=0)) + pd.testing.assert_series_equal( + sample.values.loc[0], + pd.Series({"param_1": 9.9, "param_2": 1.1, "param_3": 88.0}, name=0), + ) assert not sample.results.iloc[0].empty dimless_val = sample.get_dimension_less_values() pd.testing.assert_frame_equal( - dimless_val, pd.DataFrame({'param_1': {0: 0.99, 1: 0.3}, - 'param_2': {0: 0.7500000000000003, 1: 0.12499999999999986}, - 'param_3': {0: 0.88, 1: 0.2}})) - + dimless_val, + pd.DataFrame( + { + "param_1": {0: 0.99, 1: 0.3}, + "param_2": {0: 0.7500000000000003, 1: 0.12499999999999986}, + "param_3": {0: 0.88, 1: 0.2}, + } + ), + ) pd.testing.assert_frame_equal( sample.get_aggregated_time_series("res"), @@ -308,9 +322,9 @@ def test_lhs_sampler(self): sampler.simulate_pending() expected = { - 0: [[85.75934698790918]], - 1: [[38.08478803524709]], - 2: [[61.67268698504139]], + 0: np.array([[85.75934698790918]]), + 1: np.array([[38.08478803524709]]), + 2: np.array([[61.67268698504139]]), } for k, arr in sampler.results.items(): @@ -334,7 +348,7 @@ def test_lhs_sampler(self): sampler.add_sample(3, rng=42, simulate=False) sampler.simulate_at(slice(4, 7)) - assert [df.empty for df in sampler.results[-3:].values] == [False, True, True] + assert [df.empty for df in sampler.results[-3:].values] == [False, False, True] sampler.add_sample(3, rng=42, simulate=False) sampler.simulate_at(slice(10, None)) From 99c6963ed81535f84ae5801602b7297c6521fc4b Mon Sep 17 00:00:00 2001 From: BaptisteDE Date: Mon, 8 Sep 2025 11:44:24 +0200 Subject: [PATCH 3/3] =?UTF-8?q?=E2=9A=A1=EF=B8=8Fcompatible=20with=20Sampl?= =?UTF-8?q?e.values=20being=20df?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- corrai/sensitivity.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/corrai/sensitivity.py b/corrai/sensitivity.py index 1049dca..beeea05 100644 --- a/corrai/sensitivity.py +++ b/corrai/sensitivity.py @@ -172,7 +172,7 @@ def analyze( ) if self.x_needed: - analyse_kwargs["X"] = self.sampler.sample.get_dimension_less_values() + analyse_kwargs["X"] = self.sampler.sample.get_dimension_less_values().values analyse_kwargs["problem"] = self.sampler.get_salib_problem()