From 6561e7014c6eee6d81f98b2b19161fb30f0eec5d Mon Sep 17 00:00:00 2001 From: andrew Date: Fri, 4 Dec 2020 15:03:54 -0800 Subject: [PATCH 1/7] add axes argument and return axes --- .../covidcast-py/covidcast/plotting.py | 26 +++++++++++-------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/Python-packages/covidcast-py/covidcast/plotting.py b/Python-packages/covidcast-py/covidcast/plotting.py index 2eb0e78b..f161da80 100644 --- a/Python-packages/covidcast-py/covidcast/plotting.py +++ b/Python-packages/covidcast-py/covidcast/plotting.py @@ -43,7 +43,8 @@ def plot(data: pd.DataFrame, time_value: date = None, plot_type: str = "choropleth", combine_megacounties: bool = True, - **kwargs: Any) -> figure.Figure: + ax = None, + **kwargs: Any) -> axes.Axes: """Given the output data frame of :py:func:`covidcast.signal`, plot a choropleth or bubble map. Projections used for plotting: @@ -79,7 +80,8 @@ def plot(data: pd.DataFrame, Defaults to `True`. :param kwargs: Optional keyword arguments passed to ``GeoDataFrame.plot()``. :param plot_type: Type of plot to create. Either choropleth (default) or bubble map. - :return: Matplotlib figure object. + :param ax: Optional matplotlib axis to plot on. + :return: Matplotlib axes object. """ if plot_type not in {"choropleth", "bubble"}: @@ -92,26 +94,27 @@ def plot(data: pd.DataFrame, kwargs["vmax"] = kwargs.get("vmax", meta["mean_value"] + 3 * meta["stdev_value"]) kwargs["figsize"] = kwargs.get("figsize", (12.8, 9.6)) - fig, ax = _plot_background_states(kwargs["figsize"]) + ax = _plot_background_states(kwargs["figsize"]) if ax is None else ax + ax.axis("off") ax.set_title(f"{data_source}: {signal}, {day_to_plot.strftime('%Y-%m-%d')}") if plot_type == "choropleth": _plot_choro(ax, day_data, combine_megacounties, **kwargs) else: _plot_bubble(ax, day_data, geo_type, **kwargs) - return fig + return ax def plot_choropleth(data: pd.DataFrame, time_value: date = None, combine_megacounties: bool = True, - **kwargs: Any) -> figure.Figure: + **kwargs: Any) -> axes.Axes: """Plot choropleths for a signal. This method is deprecated and has been generalized to plot(). :param data: Data frame of signal values, as returned from :py:func:`covidcast.signal`. :param time_value: If multiple days of data are present in ``data``, map only values from this day. Defaults to plotting the most recent day of data in ``data``. :param kwargs: Optional keyword arguments passed to ``GeoDataFrame.plot()``. - :return: Matplotlib figure object. + :return: Matplotlib axes object. """ warnings.warn("Function `plot_choropleth` is deprecated. Use `plot()` instead.") return plot(data, time_value, "choropleth", combine_megacounties, **kwargs) @@ -286,21 +289,22 @@ def _plot_bubble(ax: axes.Axes, data: gpd.GeoDataFrame, geo_type: str, **kwargs: ax.legend(frameon=False, ncol=8, loc="lower center", bbox_to_anchor=(0.5, -0.1)) -def _plot_background_states(figsize: tuple) -> tuple: +def _plot_background_states(figsize: tuple, ax=None) -> axes.Axes: """Plot US states in light grey as the background for other plots. :param figsize: Dimensions of plot. - :return: Matplotlib figure and axes. + :param ax: Optional matplotlib axis to plot on. + :return: Matplotlib axes. """ - fig, ax = plt.subplots(1, figsize=figsize) - ax.axis("off") + if ax is None: + fig, ax = plt.subplots(1, figsize=figsize) state_shapefile_path = pkg_resources.resource_filename(__name__, SHAPEFILE_PATHS["state"]) state = gpd.read_file(state_shapefile_path) for state in _project_and_transform(state, "STATEFP"): state.plot(color="0.9", ax=ax, edgecolor="0.8", linewidth=0.5) ax.set_xlim(plt.xlim()) ax.set_ylim(plt.ylim()) - return fig, ax + return ax def _project_and_transform(data: gpd.GeoDataFrame, From 9c146e98c9f259f5f0dfdfbb6887768c2f93cd26 Mon Sep 17 00:00:00 2001 From: andrew Date: Fri, 4 Dec 2020 15:28:47 -0800 Subject: [PATCH 2/7] fix tests --- .../covidcast-py/tests/test_plotting.py | 30 +++++++++++-------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/Python-packages/covidcast-py/tests/test_plotting.py b/Python-packages/covidcast-py/tests/test_plotting.py index 8d458cfb..01ac6e71 100644 --- a/Python-packages/covidcast-py/tests/test_plotting.py +++ b/Python-packages/covidcast-py/tests/test_plotting.py @@ -3,6 +3,7 @@ from unittest.mock import patch import matplotlib +from matplotlib import pyplot as plt import platform import geopandas as gpd import numpy as np @@ -49,21 +50,22 @@ def test_plot(mock_metadata): test_county["value"] = test_county.value.astype("float") # w/o megacounties - no_mega_fig1 = plotting.plot(test_county, - time_value=date(2020, 8, 4), - combine_megacounties=False) + plotting.plot(test_county, time_value=date(2020, 8, 4), combine_megacounties=False) + no_mega_fig1 = plt.gcf() # give margin of +-2 for floating point errors and weird variations (1 isn't consistent) assert np.allclose(_convert_to_array(no_mega_fig1), expected["no_mega_1"], atol=2, rtol=0) - no_mega_fig2 = plotting.plot_choropleth(test_county, - cmap="viridis", - figsize=(5, 5), - edgecolor="0.8", - combine_megacounties=False) + plotting.plot_choropleth(test_county, + cmap="viridis", + figsize=(5, 5), + edgecolor="0.8", + combine_megacounties=False) + no_mega_fig2 = plt.gcf() assert np.allclose(_convert_to_array(no_mega_fig2), expected["no_mega_2"], atol=2, rtol=0) # w/ megacounties - mega_fig = plotting.plot_choropleth(test_county, time_value=date(2020, 8, 4)) + plotting.plot_choropleth(test_county, time_value=date(2020, 8, 4)) + mega_fig = plt.gcf() # give margin of +-2 for floating point errors and weird variations (1 isn't consistent) assert np.allclose(_convert_to_array(mega_fig), expected["mega"], atol=2, rtol=0) @@ -72,7 +74,8 @@ def test_plot(mock_metadata): os.path.join(CURRENT_PATH, "reference_data/test_input_state_signal.csv"), dtype=str) test_state["time_value"] = test_state.time_value.astype("datetime64[D]") test_state["value"] = test_state.value.astype("float") - state_fig = plotting.plot(test_state) + plotting.plot(test_state) + state_fig = plt.gcf() assert np.allclose(_convert_to_array(state_fig), expected["state"], atol=2, rtol=0) # test MSA @@ -80,12 +83,13 @@ def test_plot(mock_metadata): os.path.join(CURRENT_PATH, "reference_data/test_input_msa_signal.csv"), dtype=str) test_msa["time_value"] = test_msa.time_value.astype("datetime64[D]") test_msa["value"] = test_msa.value.astype("float") - msa_fig = plotting.plot(test_msa) + plotting.plot(test_msa) + msa_fig = plt.gcf() assert np.allclose(_convert_to_array(msa_fig), expected["msa"], atol=2, rtol=0) # test bubble - msa_bubble_fig = plotting.plot(test_msa, plot_type="bubble") - from matplotlib import pyplot as plt + plotting.plot(test_msa, plot_type="bubble") + msa_bubble_fig = plt.gcf() assert np.allclose(_convert_to_array(msa_bubble_fig), expected["msa_bubble"], atol=2, rtol=0) From 516fc303d9ba1cf1cb16ba4b1b13e73580bba4d9 Mon Sep 17 00:00:00 2001 From: andrew Date: Fri, 4 Dec 2020 15:35:22 -0800 Subject: [PATCH 3/7] Add type annotation for ax args --- Python-packages/covidcast-py/covidcast/plotting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Python-packages/covidcast-py/covidcast/plotting.py b/Python-packages/covidcast-py/covidcast/plotting.py index f161da80..6eb9ec32 100644 --- a/Python-packages/covidcast-py/covidcast/plotting.py +++ b/Python-packages/covidcast-py/covidcast/plotting.py @@ -43,7 +43,7 @@ def plot(data: pd.DataFrame, time_value: date = None, plot_type: str = "choropleth", combine_megacounties: bool = True, - ax = None, + ax: axes.Axes = None, **kwargs: Any) -> axes.Axes: """Given the output data frame of :py:func:`covidcast.signal`, plot a choropleth or bubble map. @@ -289,7 +289,7 @@ def _plot_bubble(ax: axes.Axes, data: gpd.GeoDataFrame, geo_type: str, **kwargs: ax.legend(frameon=False, ncol=8, loc="lower center", bbox_to_anchor=(0.5, -0.1)) -def _plot_background_states(figsize: tuple, ax=None) -> axes.Axes: +def _plot_background_states(figsize: tuple, ax: axes.Axes = None) -> axes.Axes: """Plot US states in light grey as the background for other plots. :param figsize: Dimensions of plot. From 0877ef640b769c3a4fb2ea9bd74455730d8526af Mon Sep 17 00:00:00 2001 From: andrew Date: Mon, 7 Dec 2020 09:00:03 -0800 Subject: [PATCH 4/7] Fix docstrings --- Python-packages/covidcast-py/covidcast/plotting.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/Python-packages/covidcast-py/covidcast/plotting.py b/Python-packages/covidcast-py/covidcast/plotting.py index 6eb9ec32..a8efaa8a 100644 --- a/Python-packages/covidcast-py/covidcast/plotting.py +++ b/Python-packages/covidcast-py/covidcast/plotting.py @@ -72,6 +72,9 @@ def plot(data: pd.DataFrame, bubble but have the region displayed in white, and values above the mean + 3 std dev are binned into the highest bubble. Bubbles are scaled by area. + A Matplotlib Axes object can be provided to plot the maps onto an existing figure. Otherwise, + a new Axes object will be created and returned. + :param data: Data frame of signal values, as returned from :py:func:`covidcast.signal`. :param time_value: If multiple days of data are present in ``data``, map only values from this day. Defaults to plotting the most recent day of data in ``data``. @@ -94,7 +97,7 @@ def plot(data: pd.DataFrame, kwargs["vmax"] = kwargs.get("vmax", meta["mean_value"] + 3 * meta["stdev_value"]) kwargs["figsize"] = kwargs.get("figsize", (12.8, 9.6)) - ax = _plot_background_states(kwargs["figsize"]) if ax is None else ax + ax = _plot_background_states(kwargs["figsize"]) if ax is None else _plot_background_states(ax) ax.axis("off") ax.set_title(f"{data_source}: {signal}, {day_to_plot.strftime('%Y-%m-%d')}") if plot_type == "choropleth": @@ -110,6 +113,9 @@ def plot_choropleth(data: pd.DataFrame, **kwargs: Any) -> axes.Axes: """Plot choropleths for a signal. This method is deprecated and has been generalized to plot(). + .. deprecated:: 0.1.1 + Use ``plot()`` instead. + :param data: Data frame of signal values, as returned from :py:func:`covidcast.signal`. :param time_value: If multiple days of data are present in ``data``, map only values from this day. Defaults to plotting the most recent day of data in ``data``. @@ -289,11 +295,11 @@ def _plot_bubble(ax: axes.Axes, data: gpd.GeoDataFrame, geo_type: str, **kwargs: ax.legend(frameon=False, ncol=8, loc="lower center", bbox_to_anchor=(0.5, -0.1)) -def _plot_background_states(figsize: tuple, ax: axes.Axes = None) -> axes.Axes: +def _plot_background_states(ax: axes.Axes = None, figsize: tuple = (12.8, 9.6)) -> axes.Axes: """Plot US states in light grey as the background for other plots. - :param figsize: Dimensions of plot. :param ax: Optional matplotlib axis to plot on. + :param figsize: Dimensions of plot. Ignored if ax is provided. :return: Matplotlib axes. """ if ax is None: From 96fdf31c76184b203daad93ac8e2cba0d11118fc Mon Sep 17 00:00:00 2001 From: andrew Date: Mon, 7 Dec 2020 09:23:28 -0800 Subject: [PATCH 5/7] fix argument order --- Python-packages/covidcast-py/covidcast/plotting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Python-packages/covidcast-py/covidcast/plotting.py b/Python-packages/covidcast-py/covidcast/plotting.py index a8efaa8a..2cf9aa25 100644 --- a/Python-packages/covidcast-py/covidcast/plotting.py +++ b/Python-packages/covidcast-py/covidcast/plotting.py @@ -295,11 +295,11 @@ def _plot_bubble(ax: axes.Axes, data: gpd.GeoDataFrame, geo_type: str, **kwargs: ax.legend(frameon=False, ncol=8, loc="lower center", bbox_to_anchor=(0.5, -0.1)) -def _plot_background_states(ax: axes.Axes = None, figsize: tuple = (12.8, 9.6)) -> axes.Axes: +def _plot_background_states(figsize: tuple = (12.8, 9.6),ax: axes.Axes = None) -> axes.Axes: """Plot US states in light grey as the background for other plots. - :param ax: Optional matplotlib axis to plot on. :param figsize: Dimensions of plot. Ignored if ax is provided. + :param ax: Optional matplotlib axis to plot on. :return: Matplotlib axes. """ if ax is None: From 4097b6402b8c43bb82b5ea13c9a85629f8f35e54 Mon Sep 17 00:00:00 2001 From: andrew Date: Mon, 7 Dec 2020 09:30:17 -0800 Subject: [PATCH 6/7] Fix keyword arg --- Python-packages/covidcast-py/covidcast/plotting.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Python-packages/covidcast-py/covidcast/plotting.py b/Python-packages/covidcast-py/covidcast/plotting.py index 2cf9aa25..7009b411 100644 --- a/Python-packages/covidcast-py/covidcast/plotting.py +++ b/Python-packages/covidcast-py/covidcast/plotting.py @@ -97,7 +97,8 @@ def plot(data: pd.DataFrame, kwargs["vmax"] = kwargs.get("vmax", meta["mean_value"] + 3 * meta["stdev_value"]) kwargs["figsize"] = kwargs.get("figsize", (12.8, 9.6)) - ax = _plot_background_states(kwargs["figsize"]) if ax is None else _plot_background_states(ax) + ax = _plot_background_states(kwargs["figsize"]) if ax is None \ + else _plot_background_states(ax=ax) ax.axis("off") ax.set_title(f"{data_source}: {signal}, {day_to_plot.strftime('%Y-%m-%d')}") if plot_type == "choropleth": From edd9b81eec57bf6d861ff405a12e02726b3ffa88 Mon Sep 17 00:00:00 2001 From: andrew Date: Mon, 7 Dec 2020 09:59:22 -0800 Subject: [PATCH 7/7] fix limit issue --- Python-packages/covidcast-py/covidcast/plotting.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Python-packages/covidcast-py/covidcast/plotting.py b/Python-packages/covidcast-py/covidcast/plotting.py index 7009b411..859630bb 100644 --- a/Python-packages/covidcast-py/covidcast/plotting.py +++ b/Python-packages/covidcast-py/covidcast/plotting.py @@ -296,7 +296,7 @@ def _plot_bubble(ax: axes.Axes, data: gpd.GeoDataFrame, geo_type: str, **kwargs: ax.legend(frameon=False, ncol=8, loc="lower center", bbox_to_anchor=(0.5, -0.1)) -def _plot_background_states(figsize: tuple = (12.8, 9.6),ax: axes.Axes = None) -> axes.Axes: +def _plot_background_states(figsize: tuple = (12.8, 9.6), ax: axes.Axes = None) -> axes.Axes: """Plot US states in light grey as the background for other plots. :param figsize: Dimensions of plot. Ignored if ax is provided. @@ -309,8 +309,8 @@ def _plot_background_states(figsize: tuple = (12.8, 9.6),ax: axes.Axes = None) - state = gpd.read_file(state_shapefile_path) for state in _project_and_transform(state, "STATEFP"): state.plot(color="0.9", ax=ax, edgecolor="0.8", linewidth=0.5) - ax.set_xlim(plt.xlim()) - ax.set_ylim(plt.ylim()) + ax.set_xlim(ax.get_xlim()) + ax.set_ylim(ax.get_ylim()) return ax