diff --git a/corrai/sampling.py b/corrai/sampling.py index c8702c0..8fa8841 100644 --- a/corrai/sampling.py +++ b/corrai/sampling.py @@ -217,8 +217,10 @@ def plot_hist( unit: str = "", agg_method_kwarg: dict = None, reference_time_series: pd.Series = None, - bin_size: float = 1.0, + bins: int = 30, colors: str = "orange", + reference_value: int | float = None, + reference_label: str = "Reference", show_rug: bool = False, title: str = None, ): @@ -237,10 +239,16 @@ def plot_hist( Additional kwargs for aggregation. reference_time_series : Series, optional Reference time series. - bin_size : float, default=1.0 - Histogram bin size. + bins : int, default=30 + Histogram number of bins. colors : str, default="orange" Color of the histogram. + reference_value: int, float, optional + Add a vertical dashed red line at reference value. + May be used for comparison with an expected value + reference_label: str, optional + Label name for reference value line to be displayed in the legend. + Default is "Reference" show_rug : bool, default=False If True, display rug plot below histogram. title : str, optional @@ -263,11 +271,26 @@ def plot_hist( fig = ff.create_distplot( [res.squeeze().to_numpy()], [f"{method}_{indicator}"], - bin_size=bin_size, + bin_size=(res.max() - res.min()) / bins, colors=[colors], show_rug=show_rug, ) + if reference_value is not None: + counts, _ = np.histogram(res, bins=bins, density=True) + + fig.add_trace( + go.Scatter( + x=[reference_value, reference_value], + y=[0, max(counts)], + mode="lines", + line=dict(color="red", width=2, dash="dash"), + name=reference_label, + ) + ) + + # Make sure it spans the full y-axis range + fig.update_yaxes(range=[0, None]) # auto from 0 to max title = ( f"Sample distribution of {method} {indicator}" if title is None else title ) diff --git a/corrai/sensitivity.py b/corrai/sensitivity.py index 125eb68..485b141 100644 --- a/corrai/sensitivity.py +++ b/corrai/sensitivity.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from functools import wraps import datetime as dt import numpy as np @@ -7,6 +8,7 @@ from SALib.analyze import morris, sobol, fast, rbd_fast from corrai.base.parameter import Parameter from corrai.sampling import ( + Sample, SobolSampler, MorrisSampler, FASTSampler, @@ -84,6 +86,35 @@ def values(self): def results(self): return self.sampler.results + @wraps(Sample.plot_hist) + def plot_sample_hist( + self, + indicator: str, + method: str = "mean", + unit: str = "", + agg_method_kwarg: dict = None, + reference_time_series: pd.Series = None, + bins: int = 30, + colors: str = "orange", + reference_value: int | float = None, + reference_label: str = "Reference", + show_rug: bool = False, + title: str = None, + ): + return self.sampler.sample.plot_hist( + indicator=indicator, + method=method, + unit=unit, + agg_method_kwarg=agg_method_kwarg, + reference_time_series=reference_time_series, + bins=bins, + colors=colors, + reference_value=reference_value, + reference_label=reference_label, + show_rug=show_rug, + title=title, + ) + def plot_sample( self, indicator: str | None = None, diff --git a/tests/test_sampling.py b/tests/test_sampling.py index d0860b9..5eaa058 100644 --- a/tests/test_sampling.py +++ b/tests/test_sampling.py @@ -55,8 +55,9 @@ def test_plot_hist(self): indicator="res", method="mean", unit="J", - bin_size=0.5, + bins=10, colors="orange", + reference_value=70, show_rug=True, ) diff --git a/tests/test_sensitivity.py b/tests/test_sensitivity.py index 2bbbc7e..9141d79 100644 --- a/tests/test_sensitivity.py +++ b/tests/test_sensitivity.py @@ -50,6 +50,10 @@ def test_sanalysis_sobol_with_sobol_sampler(self): pd.Timestamp("2009-01-01 05:00:00"), ] + sobol_analysis.plot_sample_hist( + "res", bins=10, reference_value=10, reference_label="ref" + ) + np.testing.assert_almost_equal( res["2009-01-01 00:00:00"]["S1"], np.array([0.33080399, 0.44206835, 0.00946747]),