diff --git a/pytrendy/detect_trends.py b/pytrendy/detect_trends.py index c0a7f35..f8f94ac 100644 --- a/pytrendy/detect_trends.py +++ b/pytrendy/detect_trends.py @@ -8,7 +8,7 @@ from .io.plot_pytrendy import plot_pytrendy from .io.results_pytrendy import PyTrendyResults -def detect_trends(df:pd.DataFrame, date_col:str, value_col: str, plot=True, method_params:dict=None) -> PyTrendyResults: +def detect_trends(df:pd.DataFrame, date_col:str, value_col: str, plot=True, method_params:dict=None, plot_params:dict=None) -> PyTrendyResults: """ This is the main function that runs trend detection end-to-end. @@ -38,9 +38,12 @@ def detect_trends(df:pd.DataFrame, date_col:str, value_col: str, plot=True, meth Defaults to `True`. method_params (dict, optional): Optional parameters to customize detection heuristics. Supported keys: - - **is_abrupt_padded** (`bool`): Whether to pad abrupt transitions between segments. Defaults to `False`. - **abrupt_padding** (`int`): Number of days to pad around abrupt transitions. Only referenced when `is_abrupt_padded` is `True`. Defaults to `28`. + plot_params (dict, optional): + Dictionary of plotting parameters to pass to `plot_pytrendy`. Supported keys: + - **figsize** (`tuple`): The figure size for the plot. Defaults to (20,5) + Returns: PyTrendyResults: @@ -66,7 +69,7 @@ def detect_trends(df:pd.DataFrame, date_col:str, value_col: str, plot=True, meth segments = get_segments(df) segments = refine_segments(df, value_col, segments, method_params) segments = analyse_segments(df, value_col, segments) - if plot: plot_pytrendy(df, value_col, segments) + if plot: plot_pytrendy(df, value_col, segments, plot_params=plot_params) results = PyTrendyResults(segments) return results \ No newline at end of file diff --git a/pytrendy/io/plot_pytrendy.py b/pytrendy/io/plot_pytrendy.py index 17d31b6..48c98ef 100644 --- a/pytrendy/io/plot_pytrendy.py +++ b/pytrendy/io/plot_pytrendy.py @@ -5,7 +5,7 @@ import matplotlib.dates as mdates import matplotlib.patches as mpatches -def plot_pytrendy(df: pd.DataFrame, value_col: str, segments_enhanced: list[dict], suppress_show: bool = False) -> plt.Figure: +def plot_pytrendy(df: pd.DataFrame, value_col: str, segments_enhanced: list[dict], suppress_show: bool = False, plot_params: dict = None) -> plt.Figure: """ Visualizes detected trend segments over the original time series signal. @@ -21,6 +21,8 @@ def plot_pytrendy(df: pd.DataFrame, value_col: str, segments_enhanced: list[dict List of segment dictionaries containing keys like `'start'`, `'end'`, `'direction'`, `'trend_class'`, and `'change_rank'`. suppress_show (bool, optional): If True, suppresses the automatic display of the plot with plt.show(). Defaults to False. + plot_params (dict, optional): + Dictionary of plotting parameters. Currently supports 'figsize' (tuple). Defaults to None. Returns: matplotlib.figure.Figure: @@ -35,7 +37,10 @@ def plot_pytrendy(df: pd.DataFrame, value_col: str, segments_enhanced: list[dict 'Noise': 'lightgray', } - fig, ax = plt.subplots(figsize=(20, 5)) + plot_params = plot_params or {} + figsize = plot_params.get("figsize", (20, 5)) + + fig, ax = plt.subplots(figsize=figsize) # Plot the value line ax.plot(df.index, df[value_col], color='black', lw=1) diff --git a/tests/tests_plotting/custom/test_custom_plot_params.py b/tests/tests_plotting/custom/test_custom_plot_params.py new file mode 100644 index 0000000..8703c64 --- /dev/null +++ b/tests/tests_plotting/custom/test_custom_plot_params.py @@ -0,0 +1,35 @@ +import pytest +import pandas as pd +import pytrendy as pt +from pytrendy.io.plot_pytrendy import plot_pytrendy + +class TestCustomPlotParams: + """Test custom plot parameters for plot visualization.""" + + def _prepare_and_plot(self, df, value_col, segments, **kwargs): + """Helper to prepare dataframe and create plot.""" + df['date'] = pd.to_datetime(df['date']) + df = df.set_index('date')[[value_col]] + return plot_pytrendy(df, value_col, segments, suppress_show=True, **kwargs) + + @pytest.mark.core + @pytest.mark.plot + @pytest.mark.mpl_image_compare(baseline_dir='./', filename='test_custom_plot_params_figsize.png', style='default', remove_text=True) + def test_custom_plot_params_figsize(self): + """Test custom figsize in plot parameters.""" + df = pt.load_data('series_synthetic') + results = pt.detect_trends( + df, + date_col='date', + value_col='gradual', + plot=False, + method_params=dict(is_abrupt_padded=False) + ) + + plot_params = { + 'figsize': (16, 8), + 'title': "Custom Plot Title" + } + + fig = self._prepare_and_plot(df, 'gradual', results.segments, plot_params=plot_params) + return fig diff --git a/tests/tests_plotting/custom/test_custom_plot_params_figsize.png b/tests/tests_plotting/custom/test_custom_plot_params_figsize.png new file mode 100644 index 0000000..acc555c Binary files /dev/null and b/tests/tests_plotting/custom/test_custom_plot_params_figsize.png differ