diff --git a/corrai/sampling.py b/corrai/sampling.py index 791d665..f81733b 100644 --- a/corrai/sampling.py +++ b/corrai/sampling.py @@ -114,26 +114,30 @@ class Sample: """ parameters: list[Parameter] - values: np.ndarray = field(init=False) + values: pd.DataFrame = field(init=False) results: pd.Series = field(default_factory=lambda: pd.Series(dtype=object)) def __post_init__(self): - self.values = np.empty((0, len(self.parameters))) + self.values = pd.DataFrame(columns=[par.name for par in self.parameters]) def __len__(self): return self.values.shape[0] def __getitem__(self, idx): if isinstance(idx, (int, slice, list, np.ndarray)): - return {"values": self.values[idx], "results": self.results[idx]} + return {"values": self.values.loc[idx, :], "results": self.results.loc[idx]} raise TypeError(f"Unsupported index type: {type(idx)}") def __setitem__(self, idx, item: dict): if "values" in item: - self.values[idx] = item["values"] + self.values.loc[idx, :] = item["values"] if "results" in item: if isinstance(idx, int): self.results.at[idx] = item["results"] + elif isinstance(idx, slice): + self.results.loc[idx] = pd.Series( + item["results"], index=self.results.loc[idx].index + ) else: self.results.iloc[idx] = pd.Series( item["results"], index=self.results.index[idx] @@ -146,6 +150,9 @@ def _validate(self): self.values ), f"Mismatch: {len(self.values)} values vs {len(self.results)} results" + if not self.values.index.equals(self.results.index): + raise ValueError("Mismatch between values and results indices") + def get_pending_index(self) -> np.ndarray: """ Identify which samples have not yet been simulated. @@ -201,12 +208,12 @@ def get_list_parameter_value_pairs( """ selected_values = self[idx]["values"] - if selected_values.ndim == 1: - selected_values = selected_values[np.newaxis, :] + if isinstance(selected_values, pd.Series): + selected_values = selected_values.to_frame().T return [ [(par, val) for par, val in zip(self.parameters, row)] - for row in selected_values + for row in selected_values.values ] def get_dimension_less_values( @@ -250,7 +257,8 @@ def add_samples(self, values: np.ndarray, results: list[pd.DataFrame] = None): n_samples, n_params = values.shape assert n_params == len(self.parameters), "Mismatch in number of parameters" - self.values = np.vstack([self.values, values]) + new_df = pd.DataFrame(values, columns=self.values.columns) + self.values = pd.concat([self.values, new_df], ignore_index=True) if results is None: new_results = pd.Series([pd.DataFrame()] * n_samples, dtype=object) @@ -539,7 +547,7 @@ def _legend_for(i: int) -> str: if not show_legends: return "Simulations" parameter_names = [par.name for par in self.parameters] - vals = self.values[i, :] + vals = self.values.loc[i, :].values return ", ".join( f"{n}: {round(v, round_ndigits)}" for n, v in zip(parameter_names, vals) ) diff --git a/corrai/sensitivity.py b/corrai/sensitivity.py index 1049dca..beeea05 100644 --- a/corrai/sensitivity.py +++ b/corrai/sensitivity.py @@ -172,7 +172,7 @@ def analyze( ) if self.x_needed: - analyse_kwargs["X"] = self.sampler.sample.get_dimension_less_values() + analyse_kwargs["X"] = self.sampler.sample.get_dimension_less_values().values analyse_kwargs["problem"] = self.sampler.get_salib_problem() diff --git a/tests/test_sampling.py b/tests/test_sampling.py index 1a4a095..42a06b9 100644 --- a/tests/test_sampling.py +++ b/tests/test_sampling.py @@ -58,7 +58,17 @@ def test_sample_functions(self): ) assert sample.get_pending_index().tolist() == [True, False] - assert sample.values.tolist() == [[1.0, 0.9, 10.0], [3.0, 0.85, 20.0]] + pd.testing.assert_frame_equal( + sample.values, + pd.DataFrame( + { + "param_1": {0: 1.0, 1: 3.0}, + "param_2": {0: 0.9, 1: 0.85}, + "param_3": {0: 10.0, 1: 20.0}, + } + ), + ) + assert sample.get_parameters_intervals().tolist() == [ [0.0, 10.0], [0.8, 1.2], @@ -83,12 +93,22 @@ def test_sample_functions(self): "values": np.array([9.9, 1.1, 88]), "results": pd.DataFrame({"res": [123]}, index=[pd.Timestamp("2009-01-01")]), } - np.testing.assert_allclose(sample.values[0], [9.9, 1.1, 88]) + pd.testing.assert_series_equal( + sample.values.loc[0], + pd.Series({"param_1": 9.9, "param_2": 1.1, "param_3": 88.0}, name=0), + ) assert not sample.results.iloc[0].empty dimless_val = sample.get_dimension_less_values() - np.testing.assert_allclose( - dimless_val, np.array([[0.99, 0.75, 0.88], [0.3, 0.125, 0.2]]) + pd.testing.assert_frame_equal( + dimless_val, + pd.DataFrame( + { + "param_1": {0: 0.99, 1: 0.3}, + "param_2": {0: 0.7500000000000003, 1: 0.12499999999999986}, + "param_3": {0: 0.88, 1: 0.2}, + } + ), ) pd.testing.assert_frame_equal( @@ -302,9 +322,9 @@ def test_lhs_sampler(self): sampler.simulate_pending() expected = { - 0: [[85.75934698790918]], - 1: [[38.08478803524709]], - 2: [[61.67268698504139]], + 0: np.array([[85.75934698790918]]), + 1: np.array([[38.08478803524709]]), + 2: np.array([[61.67268698504139]]), } for k, arr in sampler.results.items(): @@ -328,7 +348,7 @@ def test_lhs_sampler(self): sampler.add_sample(3, rng=42, simulate=False) sampler.simulate_at(slice(4, 7)) - assert [df.empty for df in sampler.results[-3:].values] == [False, True, True] + assert [df.empty for df in sampler.results[-3:].values] == [False, False, True] sampler.add_sample(3, rng=42, simulate=False) sampler.simulate_at(slice(10, None))