Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#### New Features

- Added support for `DataFrame.pipe`.
- Added `artifact_repository` support to `udtf_configs` in `session.read.dbapi()`, enabling users to specify a custom artifact repository (e.g. PyPI) for packages used by the internal UDTF during distributed ingestion.

#### Bug Fixes
Expand Down
1 change: 1 addition & 0 deletions docs/source/snowpark/dataframe.rst
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ DataFrame
DataFrame.natural_join
DataFrame.orderBy
DataFrame.order_by
DataFrame.pipe
DataFrame.pivot
DataFrame.print_schema
DataFrame.printSchema
Expand Down
47 changes: 47 additions & 0 deletions src/snowflake/snowpark/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
Optional,
Set,
Tuple,
TypeVar,
Union,
overload,
)
Expand Down Expand Up @@ -243,10 +244,23 @@
else:
from collections.abc import Iterable

# Python 3.9 needs to use typing_extensions.ParamSpec and typing_extensions.Concatenate
# Python 3.10+ can use typing.ParamSpec and typing.Concatenate because they are available in the standard library
if sys.version_info < (3, 10):
from typing_extensions import Concatenate, ParamSpec
else:
from typing import Concatenate, ParamSpec


if TYPE_CHECKING:
import modin.pandas # pragma: no cover
from table import Table # pragma: no cover


T = TypeVar("T")
P = ParamSpec("P")


_logger = getLogger(__name__)

_ONE_MILLION = 1000000
Expand Down Expand Up @@ -7099,6 +7113,39 @@ def print_schema(self, level: Optional[int] = None) -> None:
# naturalJoin = natural_join
# withColumns = with_columns

def pipe(
    self,
    function: Callable[Concatenate["DataFrame", P], T],
    *args: P.args,
    **kwargs: P.kwargs,
) -> T:
    """Applies a function to the DataFrame and returns the result.

    This enables chaining user-defined transformations:
    ``df.pipe(f, a).pipe(g, b)`` reads left-to-right instead of the
    nested ``g(f(df, a), b)``.

    Args:
        function: A callable whose first argument is this DataFrame.
        *args: Additional positional arguments to pass to ``function``.
        **kwargs: Additional keyword arguments to pass to ``function``.

    Returns:
        Whatever ``function`` returns when applied to this DataFrame.

    Example::

        >>> df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
        >>> def test_function(df: DataFrame, col: str, threshold: float = 0):
        ...     return df.filter(df[col] > threshold)
        >>> result = df.pipe(test_function, "a", threshold=1)
        >>> result.show()
        -------------
        |"A"  |"B"  |
        -------------
        |3    |4    |
        -------------
        <BLANKLINE>
    """
    # Simply delegate: pipe adds no behavior beyond argument forwarding.
    return function(self, *args, **kwargs)


def map(
dataframe: DataFrame,
Expand Down
22 changes: 22 additions & 0 deletions tests/unit/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,28 @@ def test_dataFrame_printSchema(capfd, mock_server_connection):
)


def test_dataframe_pipe(session):
    df: DataFrame = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])

    # pipe() with an ordinary function must forward positional and
    # keyword arguments and return exactly what the function returns.
    def transform(frame: DataFrame, col: str, threshold: float = 0.0):
        filtered = frame.filter(frame[col] > threshold)
        return filtered.collect(), filtered.count()

    piped_result = df.pipe(transform, "a", threshold=1)
    direct_result = transform(df, "a", 1)
    assert piped_result == direct_result

    # pipe() must accept a lambda as well.
    count_rows = lambda x: int(x.count())
    assert df.pipe(count_rows) == count_rows(df)


def test_session():
fake_session = mock.create_autospec(Session, _session_id=123456)
fake_session._analyzer = mock.Mock()
Expand Down
Loading