Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions pytrendy/detect_trends.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .io.plot_pytrendy import plot_pytrendy
from .io.results_pytrendy import PyTrendyResults

def detect_trends(df:pd.DataFrame, date_col:str, value_col: str, plot=True, method_params:dict=None) -> PyTrendyResults:
def detect_trends(df:pd.DataFrame, date_col:str, value_col: str, plot=True, method_params:dict=None, plot_params:dict=None) -> PyTrendyResults:
"""
This is the main function that runs trend detection end-to-end.

Expand Down Expand Up @@ -38,9 +38,12 @@ def detect_trends(df:pd.DataFrame, date_col:str, value_col: str, plot=True, meth
Defaults to `True`.
method_params (dict, optional):
Optional parameters to customize detection heuristics. Supported keys:

- **is_abrupt_padded** (`bool`): Whether to pad abrupt transitions between segments. Defaults to `False`.
- **abrupt_padding** (`int`): Number of days to pad around abrupt transitions. Only referenced when `is_abrupt_padded` is `True`. Defaults to `28`.
plot_params (dict, optional):
Dictionary of plotting parameters to pass to `plot_pytrendy`. Supported keys:
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Specify in comment to pass to plt.subplots() in plot_pytrendy. In case user is interested in trying other args as long as it's applicable to subplots call.

- **figsize** (`tuple`): The figure size for the plot. Defaults to (20,5)


Returns:
PyTrendyResults:
Expand All @@ -66,7 +69,7 @@ def detect_trends(df:pd.DataFrame, date_col:str, value_col: str, plot=True, meth
segments = get_segments(df)
segments = refine_segments(df, value_col, segments, method_params)
segments = analyse_segments(df, value_col, segments)
if plot: plot_pytrendy(df, value_col, segments)
if plot: plot_pytrendy(df, value_col, segments, plot_params=plot_params)

results = PyTrendyResults(segments)
return results
9 changes: 7 additions & 2 deletions pytrendy/io/plot_pytrendy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import matplotlib.dates as mdates
import matplotlib.patches as mpatches

def plot_pytrendy(df: pd.DataFrame, value_col: str, segments_enhanced: list[dict], suppress_show: bool = False) -> plt.Figure:
def plot_pytrendy(df: pd.DataFrame, value_col: str, segments_enhanced: list[dict], suppress_show: bool = False, plot_params: dict = None) -> plt.Figure:
"""
Visualizes detected trend segments over the original time series signal.

Expand All @@ -21,6 +21,8 @@ def plot_pytrendy(df: pd.DataFrame, value_col: str, segments_enhanced: list[dict
List of segment dictionaries containing keys like `'start'`, `'end'`, `'direction'`, `'trend_class'`, and `'change_rank'`.
suppress_show (bool, optional):
If True, suppresses the automatic display of the plot with plt.show(). Defaults to False.
plot_params (dict, optional):
Dictionary of plotting parameters. Currently supports 'figsize' (tuple). Defaults to None.

Returns:
matplotlib.figure.Figure:
Expand All @@ -35,7 +37,10 @@ def plot_pytrendy(df: pd.DataFrame, value_col: str, segments_enhanced: list[dict
'Noise': 'lightgray',
}

fig, ax = plt.subplots(figsize=(20, 5))
plot_params = plot_params or {}
figsize = plot_params.get("figsize", (20, 5))

fig, ax = plt.subplots(figsize=figsize)

# Plot the value line
ax.plot(df.index, df[value_col], color='black', lw=1)
Expand Down
35 changes: 35 additions & 0 deletions tests/tests_plotting/custom/test_custom_plot_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import pytest
import pandas as pd
import pytrendy as pt
from pytrendy.io.plot_pytrendy import plot_pytrendy

class TestCustomPlotParams:
"""Test custom plot parameters for plot visualization."""

def _prepare_and_plot(self, df, value_col, segments, **kwargs):
"""Helper to prepare dataframe and create plot."""
df['date'] = pd.to_datetime(df['date'])
df = df.set_index('date')[[value_col]]
return plot_pytrendy(df, value_col, segments, suppress_show=True, **kwargs)

@pytest.mark.core
@pytest.mark.plot
@pytest.mark.mpl_image_compare(baseline_dir='./', filename='test_custom_plot_params_figsize.png', style='default', remove_text=True)
def test_custom_plot_params_figsize(self):
"""Test custom figsize in plot parameters."""
df = pt.load_data('series_synthetic')
results = pt.detect_trends(
df,
date_col='date',
value_col='gradual',
plot=False,
method_params=dict(is_abrupt_padded=False)
)

plot_params = {
'figsize': (16, 8),
'title': "Custom Plot Title"
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can remove title it turns out, it can't be passed to subplots().

}

fig = self._prepare_and_plot(df, 'gradual', results.segments, plot_params=plot_params)
return fig
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading