Skip to content
74 changes: 37 additions & 37 deletions corrai/base/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,7 @@ def get_parameters_intervals(self):
def get_list_parameter_value_pairs(
self, idx: int | list[int] | np.ndarray | slice = None
):
idx = slice(None) if idx is None else idx

if isinstance(idx, int) or (
isinstance(idx, list) and all(isinstance(x, bool) for x in idx)
):
idx = np.array(idx)

selected_values = self.values[idx]
selected_values = self[idx]["values"]

if selected_values.ndim == 1:
selected_values = selected_values[np.newaxis, :]
Expand All @@ -92,6 +85,13 @@ def get_list_parameter_value_pairs(
for row in selected_values
]

def get_dimension_less_values(
self, idx: int | list[int] | np.ndarray | slice = slice(None)
):
values = self[idx]["values"]
intervals = self.get_parameters_intervals()
return (values - intervals[:, 0]) / (intervals[:, 1] - intervals[:, 0])

def add_samples(self, values: np.ndarray, results: list[pd.DataFrame] = None):
n_samples, n_params = values.shape
assert n_params == len(self.parameters), "Mismatch in number of parameters"
Expand All @@ -106,6 +106,34 @@ def add_samples(self, values: np.ndarray, results: list[pd.DataFrame] = None):

self.results = pd.concat([self.results, new_results], ignore_index=True)

def plot(
self,
indicator: str | None = None,
reference_timeseries: pd.Series | None = None,
title: str | None = None,
y_label: str | None = None,
x_label: str | None = None,
alpha: float = 0.5,
show_legends: bool = False,
round_ndigits: int = 2,
) -> go.Figure:
if self.results is None:
raise ValueError("No results available to plot. Run a simulation first.")

return plot_sample(
results=self.results,
indicator=indicator,
reference_timeseries=reference_timeseries,
title=title,
y_label=y_label,
x_label=x_label,
alpha=alpha,
show_legends=show_legends,
parameter_values=self.values,
parameter_names=[p.name for p in self.parameters],
round_ndigits=round_ndigits,
)


class Sampler(ABC):
def __init__(
Expand Down Expand Up @@ -203,35 +231,6 @@ def get_aggregate_time_series(
prefix,
)

def plot_sample(
self,
indicator: str | None = None,
reference_timeseries: pd.Series | None = None,
title: str | None = None,
y_label: str | None = None,
x_label: str | None = None,
alpha: float = 0.5,
show_legends: bool = False,
round_ndigits: int = 2,
) -> go.Figure:
if self.results is None:
raise ValueError("No results available to plot. Run a simulation first.")

return plot_sample(
results=self.results,
indicator=indicator,
reference_timeseries=reference_timeseries,
title=title,
y_label=y_label,
x_label=x_label,
alpha=alpha,
show_legends=show_legends,
parameter_values=self.values,
parameter_names=[p.name for p in self.parameters],
round_ndigits=round_ndigits,
)


class RealSampler(Sampler, ABC):
def __init__(
self,
Expand Down Expand Up @@ -261,6 +260,7 @@ def __init__(
self, parameters: list[Parameter], model: Model, simulation_options: dict = None
):
super().__init__(parameters, model, simulation_options)
self._dimless_values = None

def add_sample(
self,
Expand Down
Loading
Loading