diff --git a/CHANGELOG.md b/CHANGELOG.md index 959915450f..8174551ce4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,13 @@ # Release History +## 1.51.0 (TBD) + +### Snowpark Python API Updates + +#### New Features + +- Added support for `DataFrame.pipe`. + ## 1.50.0 (2026-04-23) ### Snowpark Python API Updates diff --git a/docs/source/snowpark/dataframe.rst b/docs/source/snowpark/dataframe.rst index a62bab7bc2..81a34a9839 100644 --- a/docs/source/snowpark/dataframe.rst +++ b/docs/source/snowpark/dataframe.rst @@ -67,6 +67,7 @@ DataFrame DataFrame.natural_join DataFrame.orderBy DataFrame.order_by + DataFrame.pipe DataFrame.pivot DataFrame.print_schema DataFrame.printSchema diff --git a/src/snowflake/snowpark/dataframe.py b/src/snowflake/snowpark/dataframe.py index be9065df05..71ac5be1dc 100644 --- a/src/snowflake/snowpark/dataframe.py +++ b/src/snowflake/snowpark/dataframe.py @@ -24,6 +24,7 @@ Optional, Set, Tuple, + TypeVar, Union, overload, ) @@ -243,10 +244,23 @@ else: from collections.abc import Iterable +# Python 3.9 needs to use typing_extensions.ParamSpec and typing_extensions.Concatenate +# Python 3.10+ can use typing.ParamSpec and typing.Concatenate because they are available in the standard library +if sys.version_info < (3, 10): + from typing_extensions import Concatenate, ParamSpec +else: + from typing import Concatenate, ParamSpec + + if TYPE_CHECKING: import modin.pandas # pragma: no cover from table import Table # pragma: no cover + +T = TypeVar("T") +P = ParamSpec("P") + + _logger = getLogger(__name__) _ONE_MILLION = 1000000 @@ -7099,6 +7113,39 @@ def print_schema(self, level: Optional[int] = None) -> None: # naturalJoin = natural_join # withColumns = with_columns + def pipe( + self, + function: Callable[Concatenate["DataFrame", P], T], + *args: P.args, + **kwargs: P.kwargs, + ) -> T: + """Applies a function to the DataFrame and returns the result. + + Args: + function: A callable that takes this DataFrame as its first argument. 
+ *args: Additional positional arguments to pass to ``function``. + **kwargs: Additional keyword arguments to pass to ``function``. + + Returns: + The result of applying the function to the DataFrame. + + Example:: + + >>> df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) + >>> def test_function(df: DataFrame, col: str, threshold: float = 0): + ... df = df.filter(df[col] > threshold) + ... return df + >>> result = df.pipe(test_function, "a", threshold=1) + >>> result.show() + ------------- + |"A" |"B" | + ------------- + |3 |4 | + ------------- + + """ + return function(self, *args, **kwargs) + def map( dataframe: DataFrame, diff --git a/tests/unit/test_dataframe.py b/tests/unit/test_dataframe.py index 4d3e94f5f2..aa34d75999 100644 --- a/tests/unit/test_dataframe.py +++ b/tests/unit/test_dataframe.py @@ -409,6 +409,28 @@ def test_dataFrame_printSchema(capfd, mock_server_connection): ) +def test_dataframe_pipe(session): + df: DataFrame = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) + + # test normal function + def test_function(df: DataFrame, col: str, threshold: float = 0.0): + df = df.filter(df[col] > threshold) + return df.collect(), df.count() + + result, expected_result = df.pipe(test_function, "a", threshold=1), test_function( + df, "a", 1 + ) + + assert result == expected_result + + # test lambda function + result, expected_result = df.pipe(lambda x: int(x.count())), ( + lambda x: int(x.count()) + )(df) + + assert result == expected_result + + def test_session(): fake_session = mock.create_autospec(Session, _session_id=123456) fake_session._analyzer = mock.Mock()