Skip to content

Commit fe0cf8c

Browse files
authored
feat: allow DataFrame.filter to accept SQL strings\ (#1276)
1 parent 6b16285 commit fe0cf8c

File tree

3 files changed

+49
-12
lines changed

3 files changed

+49
-12
lines changed

docs/source/user-guide/dataframe/index.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,9 @@ DataFusion's DataFrame API offers a wide range of operations:
9595
# Select with expressions
9696
df = df.select(column("a") + column("b"), column("a") - column("b"))
9797
98-
# Filter rows
98+
# Filter rows (expressions or SQL strings)
9999
df = df.filter(column("age") > literal(25))
100+
df = df.filter("age > 25")
100101
101102
# Add computed columns
102103
df = df.with_column("full_name", column("first_name") + literal(" ") + column("last_name"))

python/datafusion/dataframe.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -466,31 +466,37 @@ def drop(self, *columns: str) -> DataFrame:
466466

467467
return DataFrame(self.df.drop(*normalized_columns))
468468

469-
def filter(self, *predicates: Expr) -> DataFrame:
469+
def filter(self, *predicates: Expr | str) -> DataFrame:
470470
"""Return a DataFrame for which ``predicate`` evaluates to ``True``.
471471
472472
Rows for which ``predicate`` evaluates to ``False`` or ``None`` are filtered
473473
out. If more than one predicate is provided, these predicates will be
474-
combined as a logical AND. Each ``predicate`` must be an
474+
combined as a logical AND. Each ``predicate`` can be an
475475
:class:`~datafusion.expr.Expr` created using helper functions such as
476-
:func:`datafusion.col` or :func:`datafusion.lit`.
477-
If more complex logic is required, see the logical operations in
478-
:py:mod:`~datafusion.functions`.
476+
:func:`datafusion.col` or :func:`datafusion.lit`, or a SQL expression string
477+
that will be parsed against the DataFrame schema. If more complex logic is
478+
required, see the logical operations in :py:mod:`~datafusion.functions`.
479479
480480
Example::
481481
482482
from datafusion import col, lit
483483
df.filter(col("a") > lit(1))
484+
df.filter("a > 1")
484485
485486
Args:
486-
predicates: Predicate expression(s) to filter the DataFrame.
487+
predicates: Predicate expression(s) or SQL strings to filter the DataFrame.
487488
488489
Returns:
489490
DataFrame after filtering.
490491
"""
491492
df = self.df
492-
for p in predicates:
493-
df = df.filter(ensure_expr(p))
493+
for predicate in predicates:
494+
expr = (
495+
self.parse_sql_expr(predicate)
496+
if isinstance(predicate, str)
497+
else predicate
498+
)
499+
df = df.filter(ensure_expr(expr))
494500
return DataFrame(df)
495501

496502
def parse_sql_expr(self, expr: str) -> Expr:

python/tests/test_dataframe.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,29 @@ def test_filter(df):
306306
assert result.column(2) == pa.array([5])
307307

308308

309+
def test_filter_string_predicates(df):
310+
df_str = df.filter("a > 2")
311+
result = df_str.collect()[0]
312+
313+
assert result.column(0) == pa.array([3])
314+
assert result.column(1) == pa.array([6])
315+
assert result.column(2) == pa.array([8])
316+
317+
df_mixed = df.filter("a > 1", column("b") != literal(6))
318+
result_mixed = df_mixed.collect()[0]
319+
320+
assert result_mixed.column(0) == pa.array([2])
321+
assert result_mixed.column(1) == pa.array([5])
322+
assert result_mixed.column(2) == pa.array([5])
323+
324+
df_strings = df.filter("a > 1", "b < 6")
325+
result_strings = df_strings.collect()[0]
326+
327+
assert result_strings.column(0) == pa.array([2])
328+
assert result_strings.column(1) == pa.array([5])
329+
assert result_strings.column(2) == pa.array([5])
330+
331+
309332
def test_parse_sql_expr(df):
310333
plan1 = df.filter(df.parse_sql_expr("a > 2")).logical_plan()
311334
plan2 = df.filter(column("a") > literal(2)).logical_plan()
@@ -388,9 +411,16 @@ def test_aggregate_tuple_aggs(df):
388411
assert result_tuple == result_list
389412

390413

391-
def test_filter_string_unsupported(df):
392-
with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)):
393-
df.filter("a > 1")
414+
def test_filter_string_equivalent(df):
415+
df1 = df.filter("a > 1").to_pydict()
416+
df2 = df.filter(column("a") > literal(1)).to_pydict()
417+
assert df1 == df2
418+
419+
420+
def test_filter_string_invalid(df):
421+
with pytest.raises(Exception) as excinfo:
422+
df.filter("this is not valid sql").collect()
423+
assert "Expected Expr" not in str(excinfo.value)
394424

395425

396426
def test_drop(df):

0 commit comments

Comments
 (0)