Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions corrai/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,10 @@ def plot_hist(
unit: str = "",
agg_method_kwarg: dict = None,
reference_time_series: pd.Series = None,
bin_size: float = 1.0,
bins: int = 30,
colors: str = "orange",
reference_value: int | float = None,
reference_label: str = "Reference",
show_rug: bool = False,
title: str = None,
):
Expand All @@ -237,10 +239,16 @@ def plot_hist(
Additional kwargs for aggregation.
reference_time_series : Series, optional
Reference time series.
bin_size : float, default=1.0
Histogram bin size.
bins : int, default=30
Histogram number of bins.
colors : str, default="orange"
Color of the histogram.
reference_value: int, float, optional
Add a vertical dashed red line at reference value.
May be used for comparison with an expected value
reference_label: str, optional
Label name for reference value line to be displayed in the legend.
Default is "Reference"
show_rug : bool, default=False
If True, display rug plot below histogram.
title : str, optional
Expand All @@ -263,11 +271,26 @@ def plot_hist(
fig = ff.create_distplot(
[res.squeeze().to_numpy()],
[f"{method}_{indicator}"],
bin_size=bin_size,
bin_size=(res.max() - res.min()) / bins,
colors=[colors],
show_rug=show_rug,
)

if reference_value is not None:
counts, _ = np.histogram(res, bins=bins, density=True)

fig.add_trace(
go.Scatter(
x=[reference_value, reference_value],
y=[0, max(counts)],
mode="lines",
line=dict(color="red", width=2, dash="dash"),
name=reference_label,
)
)

# Make sure it spans the full y-axis range
fig.update_yaxes(range=[0, None]) # auto from 0 to max
title = (
f"Sample distribution of {method} {indicator}" if title is None else title
)
Expand Down
31 changes: 31 additions & 0 deletions corrai/sensitivity.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from functools import wraps

import datetime as dt
import numpy as np
Expand All @@ -7,6 +8,7 @@
from SALib.analyze import morris, sobol, fast, rbd_fast
from corrai.base.parameter import Parameter
from corrai.sampling import (
Sample,
SobolSampler,
MorrisSampler,
FASTSampler,
Expand Down Expand Up @@ -84,6 +86,35 @@ def values(self):
def results(self):
return self.sampler.results

@wraps(Sample.plot_hist)
def plot_sample_hist(
self,
indicator: str,
method: str = "mean",
unit: str = "",
agg_method_kwarg: dict = None,
reference_time_series: pd.Series = None,
bins: int = 30,
colors: str = "orange",
reference_value: int | float = None,
reference_label: str = "Reference",
show_rug: bool = False,
title: str = None,
):
return self.sampler.sample.plot_hist(
indicator=indicator,
method=method,
unit=unit,
agg_method_kwarg=agg_method_kwarg,
reference_time_series=reference_time_series,
bins=bins,
colors=colors,
reference_value=reference_value,
reference_label=reference_label,
show_rug=show_rug,
title=title,
)

def plot_sample(
self,
indicator: str | None = None,
Expand Down
3 changes: 2 additions & 1 deletion tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,9 @@ def test_plot_hist(self):
indicator="res",
method="mean",
unit="J",
bin_size=0.5,
bins=10,
colors="orange",
reference_value=70,
show_rug=True,
)

Expand Down
4 changes: 4 additions & 0 deletions tests/test_sensitivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ def test_sanalysis_sobol_with_sobol_sampler(self):
pd.Timestamp("2009-01-01 05:00:00"),
]

sobol_analysis.plot_sample_hist(
"res", bins=10, reference_value=10, reference_label="ref"
)

np.testing.assert_almost_equal(
res["2009-01-01 00:00:00"]["S1"],
np.array([0.33080399, 0.44206835, 0.00946747]),
Expand Down