From b79d58fc181dde3e8b82ce8f6e1367ffff09cd8c Mon Sep 17 00:00:00 2001 From: andrew Date: Fri, 4 Dec 2020 16:09:58 -0800 Subject: [PATCH] add title arg --- Python-packages/covidcast-py/covidcast/plotting.py | 9 +++++---- Python-packages/covidcast-py/tests/test_plotting.py | 9 +++++++-- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/Python-packages/covidcast-py/covidcast/plotting.py b/Python-packages/covidcast-py/covidcast/plotting.py index 6eb9ec32..1fc9d909 100644 --- a/Python-packages/covidcast-py/covidcast/plotting.py +++ b/Python-packages/covidcast-py/covidcast/plotting.py @@ -44,6 +44,7 @@ def plot(data: pd.DataFrame, plot_type: str = "choropleth", combine_megacounties: bool = True, ax: axes.Axes = None, + title: str = None, **kwargs: Any) -> axes.Axes: """Given the output data frame of :py:func:`covidcast.signal`, plot a choropleth or bubble map. @@ -81,8 +82,8 @@ def plot(data: pd.DataFrame, :param kwargs: Optional keyword arguments passed to ``GeoDataFrame.plot()``. :param plot_type: Type of plot to create. Either choropleth (default) or bubble map. :param ax: Optional matplotlib axis to plot on. - :return: Matplotlib axes object. - + :param title: Plot title. If not provided, will default to "source: signal, day" + :return: Matplotlib figure object. """ if plot_type not in {"choropleth", "bubble"}: raise ValueError("`plot_type` must be 'choropleth' or 'bubble'.") @@ -93,10 +94,10 @@ def plot(data: pd.DataFrame, day_data = data.loc[data.time_value == pd.to_datetime(day_to_plot), :] kwargs["vmax"] = kwargs.get("vmax", meta["mean_value"] + 3 * meta["stdev_value"]) kwargs["figsize"] = kwargs.get("figsize", (12.8, 9.6)) - ax = _plot_background_states(kwargs["figsize"]) if ax is None else ax ax.axis("off") - ax.set_title(f"{data_source}: {signal}, {day_to_plot.strftime('%Y-%m-%d')}") + ax.set_title( + f"{data_source}: {signal}, {day_to_plot.strftime('%Y-%m-%d')}" if title is None else title) if plot_type == "choropleth": _plot_choro(ax, day_data, combine_megacounties, **kwargs) else: diff --git a/Python-packages/covidcast-py/tests/test_plotting.py b/Python-packages/covidcast-py/tests/test_plotting.py index 01ac6e71..cd1e7e8c 100644 --- a/Python-packages/covidcast-py/tests/test_plotting.py +++ b/Python-packages/covidcast-py/tests/test_plotting.py @@ -32,10 +32,11 @@ def _convert_to_array(fig: matplotlib.figure.Figure) -> np.array: @patch("covidcast.plotting._signal_metadata") def test_plot(mock_metadata): mock_metadata.side_effect = [ + {"mean_value": 0.5330011, "stdev_value": 0.4683431}, # county metadata {"mean_value": 0.5330011, "stdev_value": 0.4683431}, {"mean_value": 0.5330011, "stdev_value": 0.4683431}, - {"mean_value": 0.5330011, "stdev_value": 0.4683431}, - {"mean_value": 0.5304083, "stdev_value": 0.235302}, + {"mean_value": 0.5304083, "stdev_value": 0.235302}, # state metadata + {"mean_value": 0.5705364, "stdev_value": 0.4348706}, # msa metadata {"mean_value": 0.5705364, "stdev_value": 0.4348706}, {"mean_value": 0.5705364, "stdev_value": 0.4348706}, ] @@ -92,6 +93,10 @@ def test_plot(mock_metadata): msa_bubble_fig = plt.gcf() assert np.allclose(_convert_to_array(msa_bubble_fig), expected["msa_bubble"], atol=2, rtol=0) + # test title + ax = plotting.plot(test_msa, title="test") + assert ax.title.get_text() == "test" + def test_get_geo_df(): test_input = pd.DataFrame({"geo_value": ["24510", "31169", "37000"],