diff --git a/lightweight_mmm/plot.py b/lightweight_mmm/plot.py index cd1c74f..61b2b55 100644 --- a/lightweight_mmm/plot.py +++ b/lightweight_mmm/plot.py @@ -890,7 +890,8 @@ def plot_pre_post_budget_allocation_comparison( optimal_buget_allocation: jnp.ndarray, previous_budget_allocation: jnp.ndarray, channel_names: Optional[Sequence[Any]] = None, - figure_size: Tuple[int, int] = (20, 10) + figure_size: Tuple[int, int] = (20, 10), + save_path: Optional[str] = None ) -> matplotlib.figure.Figure: """Plots a barcharts to compare pre & post budget allocation. @@ -905,6 +906,7 @@ def plot_pre_post_budget_allocation_comparison( budget allocation proportion. channel_names: Names of media channels to be added to plot. figure_size: size of the plot. + save_path: Path to save the plotted figure. Returns: Barplots of budget allocation across media channels pre & post optimization. @@ -1004,6 +1006,11 @@ def plot_pre_post_budget_allocation_comparison( textcoords="offset points") plt.tight_layout() + + # Save the plot if save_path is provided + if save_path: + fig.savefig(save_path, bbox_inches="tight") + plt.close() return fig @@ -1014,6 +1021,7 @@ def plot_media_baseline_contribution_area_plot( channel_names: Optional[Sequence[Any]] = None, fig_size: Optional[Tuple[int, int]] = (20, 7), legend_outside: Optional[bool] = False, + save_path: Optional[str] = None ) -> matplotlib.figure.Figure: """Plots an area chart to visualize weekly media & baseline contribution. @@ -1023,6 +1031,7 @@ def plot_media_baseline_contribution_area_plot( channel_names: Names of media channels. fig_size: Size of the figure to plot as used by matplotlib. legend_outside: Put the legend outside of the chart, center-right. + save_path: Path to save the plotted figure. Returns: Stacked area chart of weekly baseline & media contribution. @@ -1072,6 +1081,11 @@ def plot_media_baseline_contribution_area_plot( for tick in ax.get_xticklabels(): tick.set_rotation(45) + + # Save the plot if save_path is provided + if save_path: + fig.savefig(save_path, bbox_inches="tight") + plt.close() return fig