diff --git a/doc/source/whatsnew/v3.0.0.rst b/doc/source/whatsnew/v3.0.0.rst index eb938a7140e29..79ced506ce1ab 100644 --- a/doc/source/whatsnew/v3.0.0.rst +++ b/doc/source/whatsnew/v3.0.0.rst @@ -156,6 +156,8 @@ All warnings for upcoming changes in pandas will have the base class :class:`pan Other enhancements ^^^^^^^^^^^^^^^^^^ +- :class:`pandas.NamedAgg` now forwards any ``*args`` and ``**kwargs`` + to calls of ``aggfunc`` (:issue:`58283`) - :func:`pandas.merge` propagates the ``attrs`` attribute to the result if all inputs have identical ``attrs``, as has so far already been the case for :func:`pandas.concat`. diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index d279594617235..c4a8049a307ac 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -16,7 +16,7 @@ TYPE_CHECKING, Any, Literal, - NamedTuple, + Self, TypeAlias, TypeVar, cast, @@ -113,11 +113,11 @@ @set_module("pandas") -class NamedAgg(NamedTuple): +class NamedAgg(tuple): """ Helper for column specific aggregation with control over output column names. - Subclass of typing.NamedTuple. + Subclass of tuple. Parameters ---------- @@ -126,6 +126,8 @@ class NamedAgg(NamedTuple): aggfunc : function or str Function to apply to the provided column. If string, the name of a built-in pandas function. + *args, **kwargs : + Optional positional and keyword arguments passed to ``aggfunc``. See Also -------- @@ -133,19 +135,62 @@ class NamedAgg(NamedTuple): Examples -------- - >>> df = pd.DataFrame({"key": [1, 1, 2], "a": [-1, 0, 1], 1: [10, 11, 12]}) + >>> df = pd.DataFrame({"key": [1, 1, 2], "a": [-1, 0, 1], "b": [10, 11, 12]}) >>> agg_a = pd.NamedAgg(column="a", aggfunc="min") - >>> agg_1 = pd.NamedAgg(column=1, aggfunc=lambda x: np.mean(x)) - >>> df.groupby("key").agg(result_a=agg_a, result_1=agg_1) - result_a result_1 + >>> agg_b = pd.NamedAgg(column="b", aggfunc=lambda x: x.mean()) + >>> df.groupby("key").agg(result_a=agg_a, result_b=agg_b) + result_a result_b key 1 -1 10.5 2 1 12.0 + + >>> def n_between(ser, low, high, **kwargs): + ... return ser.between(low, high, **kwargs).sum() + + >>> agg_between = pd.NamedAgg("a", n_between, 0, 1) + >>> df.groupby("key").agg(count_between=agg_between) + count_between + key + 1 1 + 2 1 + + >>> agg_between_kw = pd.NamedAgg("a", n_between, 0, 1, inclusive="both") + >>> df.groupby("key").agg(count_between_kw=agg_between_kw) + count_between_kw + key + 1 1 + 2 1 """ column: Hashable aggfunc: AggScalar + __slots__ = () + + def __new__( + cls, + column: Hashable, + aggfunc: Callable[..., Any] | str, + *args: Any, + **kwargs: Any, + ) -> Self: + if ( + callable(aggfunc) + and not getattr(aggfunc, "_is_wrapped", False) + and (args or kwargs) + ): + original_func = aggfunc + + def wrapped(*call_args: Any, **call_kwargs: Any) -> Any: + series = call_args[0] + final_args = call_args[1:] + args + final_kwargs = {**kwargs, **call_kwargs} + return original_func(series, *final_args, **final_kwargs) + + wrapped._is_wrapped = True # type: ignore[attr-defined] + aggfunc = wrapped + return super().__new__(cls, (column, aggfunc)) + @set_module("pandas.api.typing") class SeriesGroupBy(GroupBy[Series]): diff --git a/pandas/tests/groupby/aggregate/test_aggregate.py b/pandas/tests/groupby/aggregate/test_aggregate.py index c968587c469d1..5fb3666b4cdb3 100644 --- a/pandas/tests/groupby/aggregate/test_aggregate.py +++ b/pandas/tests/groupby/aggregate/test_aggregate.py @@ -866,6 +866,57 @@ def test_agg_namedtuple(self): expected = df.groupby("A").agg(b=("B", "sum"), c=("B", "count")) tm.assert_frame_equal(result, expected) + def n_between(self, ser, low, high, **kwargs): + return ser.between(low, high, **kwargs).sum() + + def test_namedagg_args(self): + df = DataFrame({"A": [0, 0, 1, 1], "B": [-1, 0, 1, 2]}) + + result = df.groupby("A").agg( + count_between=pd.NamedAgg("B", self.n_between, 0, 1) + ) + expected = DataFrame({"count_between": [1, 1]}, index=Index([0, 1], name="A")) + tm.assert_frame_equal(result, expected) + + def test_namedagg_kwargs(self): + df = DataFrame({"A": [0, 0, 1, 1], "B": [-1, 0, 1, 2]}) + + result = df.groupby("A").agg( + count_between_kw=pd.NamedAgg("B", self.n_between, 0, 1, inclusive="both") + ) + expected = DataFrame( + {"count_between_kw": [1, 1]}, index=Index([0, 1], name="A") + ) + tm.assert_frame_equal(result, expected) + + def test_namedagg_args_and_kwargs(self): + df = DataFrame({"A": [0, 0, 1, 1], "B": [-1, 0, 1, 2]}) + + result = df.groupby("A").agg( + count_between_mix=pd.NamedAgg( + "B", self.n_between, 0, 1, inclusive="neither" + ) + ) + expected = DataFrame( + {"count_between_mix": [0, 0]}, index=Index([0, 1], name="A") + ) + tm.assert_frame_equal(result, expected) + + def test_multiple_named_agg_with_args_and_kwargs(self): + df = DataFrame({"A": [0, 1, 2, 3], "B": [1, 2, 3, 4]}) + + result = df.groupby("A").agg( + n_between01=pd.NamedAgg("B", self.n_between, 0, 1), + n_between13=pd.NamedAgg("B", self.n_between, 1, 3), + n_between02=pd.NamedAgg("B", self.n_between, 0, 2), + ) + expected = df.groupby("A").agg( + n_between01=("B", lambda x: x.between(0, 1).sum()), + n_between13=("B", lambda x: x.between(0, 3).sum()), + n_between02=("B", lambda x: x.between(0, 2).sum()), + ) + tm.assert_frame_equal(result, expected) + def test_mangled(self): df = DataFrame({"A": [0, 1], "B": [1, 2], "C": [3, 4]}) result = df.groupby("A").agg(b=("B", lambda x: 0), c=("C", lambda x: 1))