From 0b9185a633cee2f3d8cf7b605470a31e8328cc75 Mon Sep 17 00:00:00 2001 From: Alistair Israel Date: Mon, 6 Apr 2026 21:21:38 -0400 Subject: [PATCH 1/2] feat(repl): add filter() with order vs select and require filter before group_by Made-with: Cursor --- docs/REPL.md | 18 +++- src/cli/repl/builder_bridge.rs | 40 ++++++++- src/cli/repl/plan.rs | 62 +++++++++++-- src/cli/repl/stage.rs | 35 ++++++-- src/cli/repl/tests.rs | 133 ++++++++++++++++++++++++++++ src/errors.rs | 2 + src/pipeline/builder.rs | 48 ++++++++++ src/pipeline/dataframe/execute.rs | 25 +++++- src/pipeline/dataframe/transform.rs | 20 ++++- src/pipeline/tests.rs | 56 ++++++++++++ 10 files changed, 416 insertions(+), 23 deletions(-) diff --git a/docs/REPL.md b/docs/REPL.md index 9248d35..2802edb 100644 --- a/docs/REPL.md +++ b/docs/REPL.md @@ -111,7 +111,7 @@ For the following functions, note that the function signatures and types provide ### Pipeline shape -A REPL pipeline must start with `read(...)`. You may use at most one `group_by(...)` and one `select(...)` per pipeline (in either order). Further stages—`head`, `tail`, `sample`, `schema`, `count`, or `write`—are added when needed (for example, `read("x.parquet") |> head(5)` skips `select` entirely). You cannot repeat `select` or `group_by` in the same pipeline. If `group_by(...)` appears, a matching `select(...)` is required. +A REPL pipeline must start with `read(...)`. You may use at most one `filter("...")`, one `group_by(...)`, and one `select(...)` per pipeline, in **any order**, except that when both `filter` and `group_by` are used, **`filter` must come before `group_by`** (for example `read(...) |> filter("...") |> group_by(...) |> select(...)`). Further stages—`head`, `tail`, `sample`, `schema`, `count`, or `write`—are added when needed (for example, `read("x.parquet") |> head(5)` skips `select` entirely). You cannot repeat `select`, `group_by`, or `filter` in the same pipeline. If `group_by(...)` appears, a matching `select(...)` is required. ### `read` @@ -131,6 +131,22 @@ Reads a Parquet, Avro, ORC, CSV, or JSON file at the given `path`. If `file_type | `.csv` | CSV | | `.json` | JSON | +### `filter` + +```flt +filter(data: Data, sql: String) -> Data +``` + +`filter` takes a single string that is a SQL predicate fragment (as in a `WHERE` clause). It is parsed with Apache DataFusion. Whether the predicate applies to **source** columns or **post-`select`** columns depends on where `filter` appears relative to `select` in the pipeline: `filter` **before** `select` filters raw rows; `filter` **after** `select` filters the projected or aggregated result (column names must match that step’s schema). + +```flt +read("input.parquet") |> filter("amount > 0") |> select(:amount, :status) |> head(10) +read("input.parquet") |> select(:amount, :status) |> filter("amount > 0 AND status = 'active'") |> head(10) +read("input.parquet") |> filter("amount > 0") |> group_by(:country) |> select(:country, sum(:amount)) |> head(10) +``` + +`filter` is only supported for inputs read through DataFusion (Parquet, Avro, CSV, JSON). It is **not** supported for ORC files in the REPL; convert or use another format first. + ### `write` ```flt diff --git a/src/cli/repl/builder_bridge.rs b/src/cli/repl/builder_bridge.rs index 21d99ed..2852623 100644 --- a/src/cli/repl/builder_bridge.rs +++ b/src/cli/repl/builder_bridge.rs @@ -28,20 +28,47 @@ pub(crate) fn repl_stages_to_pipeline_builder( builder.read(path); let mut i = 1usize; + let mut filter_idx: Option = None; + let mut select_idx: Option = None; let mut group_keys: Option> = None; let mut select_columns: Option> = None; + let mut filter_sql: Option = None; - for _ in 0..2 { + while i < body.len() { match body.get(i) { + Some(ReplPipelineStage::Filter { sql }) => { + filter_idx = Some(i); + filter_sql = Some(sql.clone()); + i += 1; + } Some(ReplPipelineStage::GroupBy { columns }) => { group_keys = Some(columns.clone()); i += 1; } Some(ReplPipelineStage::Select { columns }) => { + select_idx = Some(i); select_columns = Some(columns.clone()); i += 1; } - _ => break, + Some( + ReplPipelineStage::Head { .. } + | ReplPipelineStage::Tail { .. } + | ReplPipelineStage::Sample { .. } + | ReplPipelineStage::Schema + | ReplPipelineStage::Count + | ReplPipelineStage::Write { .. }, + ) => break, + Some(ReplPipelineStage::Read { .. }) => { + return Err(crate::Error::InvalidReplPipeline( + "unexpected read(path) after start of pipeline".to_string(), + )); + } + Some(ReplPipelineStage::Print) => { + return Err(crate::Error::InvalidReplPipeline( + "unexpected print() in pipeline body".to_string(), + )); + } + None => break, } } @@ -53,6 +80,15 @@ pub(crate) fn repl_stages_to_pipeline_builder( builder.select_spec(spec); } + if let Some(sql) = filter_sql { + let runs_after = match (filter_idx, select_idx) { + (Some(f), Some(s)) => f > s, + _ => false, + }; + builder.filter_sql(&sql); + builder.filter_runs_after_select(runs_after); + } + match body.get(i) { Some(ReplPipelineStage::Head { n }) => { builder.head(*n); diff --git a/src/cli/repl/plan.rs b/src/cli/repl/plan.rs index d012656..ff4a201 100644 --- a/src/cli/repl/plan.rs +++ b/src/cli/repl/plan.rs @@ -85,6 +85,14 @@ pub(super) fn extract_path_from_args(func_name: &str, args: &[Expr]) -> crate::R } } +/// Extracts a single SQL predicate string (same shape as [`extract_path_from_args`]). +pub(super) fn extract_sql_predicate_from_args( + func_name: &str, + args: &[Expr], +) -> crate::Result { + extract_path_from_args(func_name, args) +} + fn extract_one_column_spec(expr: &Expr) -> crate::Result { match expr { Expr::Literal(Literal::Symbol(s)) => Ok(ColumnSpec::CaseInsensitive(s.clone())), @@ -168,6 +176,10 @@ pub(super) fn plan_stage(expr: Expr) -> crate::Result { let path = extract_path_from_args("read", &args)?; Ok(ReplPipelineStage::Read { path }) } + "filter" => { + let sql = extract_sql_predicate_from_args("filter", &args)?; + Ok(ReplPipelineStage::Filter { sql }) + } "group_by" => { if args.is_empty() { return Err(Error::UnsupportedFunctionCall( @@ -294,8 +306,8 @@ fn validate_grouped_select(keys: &[ColumnSpec], items: &[SelectItem]) -> crate:: Ok(()) } -/// Validates that stages match `read` → optional `group_by` / `select` (either order) → optional slice or `schema`/`count` → optional `write`, -/// with optional trailing `print` only after head/tail/sample. +/// Validates that stages match `read` → optional permuted `filter` / `group_by` / `select` (at most one each) → optional slice or `schema`/`count` → optional `write`, +/// with optional trailing `print` only after head/tail/sample. When both `filter` and `group_by` are present, `filter` must come first in the pipeline. pub(super) fn validate_repl_pipeline_stages(stages: &[ReplPipelineStage]) -> crate::Result<()> { if stages.is_empty() { return Err(Error::InvalidReplPipeline("empty pipeline".to_string())); @@ -313,17 +325,29 @@ pub(super) fn validate_repl_pipeline_stages(stages: &[ReplPipelineStage]) -> cra } let mut i = 1usize; + let mut filter_idx: Option = None; + let mut group_idx: Option = None; let mut group_by_cols: Option<&Vec> = None; let mut select_items: Option<&Vec> = None; - for _ in 0..2 { + while i < body.len() { match body.get(i) { + Some(ReplPipelineStage::Filter { .. }) => { + if filter_idx.is_some() { + return Err(Error::InvalidReplPipeline( + "only one filter(...) is allowed in a pipeline".to_string(), + )); + } + filter_idx = Some(i); + i += 1; + } Some(ReplPipelineStage::GroupBy { columns }) => { if group_by_cols.is_some() { return Err(Error::InvalidReplPipeline( "only one group_by(...) is allowed in a pipeline".to_string(), )); } + group_idx = Some(i); group_by_cols = Some(columns); i += 1; } @@ -336,7 +360,25 @@ pub(super) fn validate_repl_pipeline_stages(stages: &[ReplPipelineStage]) -> cra select_items = Some(columns); i += 1; } - _ => break, + Some( + ReplPipelineStage::Head { .. } + | ReplPipelineStage::Tail { .. } + | ReplPipelineStage::Sample { .. } + | ReplPipelineStage::Schema + | ReplPipelineStage::Count + | ReplPipelineStage::Write { .. }, + ) => break, + Some(ReplPipelineStage::Read { .. }) => { + return Err(Error::InvalidReplPipeline( + "unexpected read(path) after start of pipeline".to_string(), + )); + } + Some(ReplPipelineStage::Print) => { + return Err(Error::InvalidReplPipeline( + "unexpected print() in pipeline body".to_string(), + )); + } + None => break, } } @@ -346,6 +388,14 @@ pub(super) fn validate_repl_pipeline_stages(stages: &[ReplPipelineStage]) -> cra )); } + if let (Some(fi), Some(gi)) = (filter_idx, group_idx) + && fi >= gi + { + return Err(Error::InvalidReplPipeline( + "filter(...) must appear before group_by(...)".to_string(), + )); + } + if let Some(keys) = group_by_cols { let items = select_items.expect("select when group_by"); validate_grouped_select(keys, items)?; @@ -402,14 +452,14 @@ pub(super) fn validate_repl_pipeline_stages(stages: &[ReplPipelineStage]) -> cra } _ => { return Err(Error::InvalidReplPipeline( - "invalid pipeline stage order (expected read, optional group_by and select, then head|tail|sample|schema|count, or write)".to_string(), + "invalid pipeline stage order (expected read, optional filter, group_by, and select in any order with filter before group_by when both are used, then head|tail|sample|schema|count, or write)".to_string(), )); } } if i != body.len() { return Err(Error::InvalidReplPipeline( - "invalid pipeline stage order (expected read, optional group_by/select, optional head|tail|sample|schema|count, optional write)".to_string(), + "invalid pipeline stage order (expected read, optional filter/group_by/select, optional head|tail|sample|schema|count, optional write)".to_string(), )); } diff --git a/src/cli/repl/stage.rs b/src/cli/repl/stage.rs index 783d0b7..9079393 100644 --- a/src/cli/repl/stage.rs +++ b/src/cli/repl/stage.rs @@ -9,15 +9,33 @@ use crate::pipeline::SelectSpec; /// A planned pipeline stage with validated, extracted arguments. #[derive(Debug, PartialEq)] pub enum ReplPipelineStage { - Read { path: String }, - GroupBy { columns: Vec }, - Select { columns: Vec }, - Head { n: usize }, - Tail { n: usize }, - Sample { n: usize }, + Read { + path: String, + }, + /// SQL predicate (`parse_sql_expr` + `filter`); runs before or after `select` per pipeline order. + Filter { + sql: String, + }, + GroupBy { + columns: Vec, + }, + Select { + columns: Vec, + }, + Head { + n: usize, + }, + Tail { + n: usize, + }, + Sample { + n: usize, + }, Schema, Count, - Write { path: String }, + Write { + path: String, + }, Print, } @@ -68,7 +86,7 @@ impl ReplPipelineStage { // Single-stage check: full pipeline uses `repl_pipeline_last_select_is_terminal`. select_spec_from_items(columns).is_aggregate_only() } - ReplPipelineStage::GroupBy { .. } => false, + ReplPipelineStage::GroupBy { .. } | ReplPipelineStage::Filter { .. } => false, ReplPipelineStage::Read { .. } | ReplPipelineStage::Print => false, } } @@ -100,6 +118,7 @@ impl fmt::Display for ReplPipelineStage { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { ReplPipelineStage::Read { path } => write!(f, r#"read("{path}")"#), + ReplPipelineStage::Filter { sql } => write!(f, "filter({sql:?})"), ReplPipelineStage::GroupBy { columns } => { let cols: Vec = columns.iter().map(format_column_spec).collect(); write!(f, "group_by({})", cols.join(", ")) diff --git a/src/cli/repl/tests.rs b/src/cli/repl/tests.rs index d2d8219..7be65f8 100644 --- a/src/cli/repl/tests.rs +++ b/src/cli/repl/tests.rs @@ -22,6 +22,7 @@ use super::plan::validate_repl_pipeline_stages; use super::stage::ReplPipelineStage; use crate::Error; use crate::pipeline::DataFramePipeline; +use crate::pipeline::DisplaySlice; use crate::pipeline::Pipeline; use crate::pipeline::SelectSpec; @@ -64,6 +65,18 @@ fn test_plan_stage_read() { ); } +#[test] +fn test_plan_stage_filter() { + let expr = parse(r#"filter("a > 1")"#); + let stage = plan_stage(expr).unwrap(); + assert_eq!( + stage, + ReplPipelineStage::Filter { + sql: "a > 1".to_string(), + } + ); +} + #[test] fn test_plan_stage_select() { let expr = Expr::FunctionCall( @@ -526,6 +539,126 @@ fn test_extract_path_from_args_write_bad_args() { // ── validate_repl_pipeline_stages / plan not implemented ────── +#[test] +fn test_validate_rejects_duplicate_filter() { + let stages = vec![ + ReplPipelineStage::Read { + path: "a.parquet".into(), + }, + ReplPipelineStage::Select { + columns: vec![SelectItem::Column(ColumnSpec::CaseInsensitive("x".into()))], + }, + ReplPipelineStage::Filter { + sql: "x > 0".into(), + }, + ReplPipelineStage::Filter { + sql: "y < 1".into(), + }, + ReplPipelineStage::Head { n: 1 }, + ]; + let err = validate_repl_pipeline_stages(&stages).unwrap_err(); + assert!(matches!(err, Error::InvalidReplPipeline(msg) if msg.contains("only one filter"))); +} + +#[test] +fn test_validate_accepts_filter_without_select() { + let stages = vec![ + ReplPipelineStage::Read { + path: "a.parquet".into(), + }, + ReplPipelineStage::Filter { + sql: "id > 0".into(), + }, + ReplPipelineStage::Head { n: 5 }, + ]; + validate_repl_pipeline_stages(&stages).unwrap(); +} + +#[test] +fn test_validate_accepts_select_filter_head() { + let stages = vec![ + ReplPipelineStage::Read { + path: "a.parquet".into(), + }, + ReplPipelineStage::Select { + columns: vec![SelectItem::Column(ColumnSpec::CaseInsensitive( + "one".into(), + ))], + }, + ReplPipelineStage::Filter { sql: "true".into() }, + ReplPipelineStage::Head { n: 5 }, + ]; + validate_repl_pipeline_stages(&stages).unwrap(); +} + +#[test] +fn test_validate_rejects_filter_after_group_by() { + let stages = vec![ + ReplPipelineStage::Read { + path: "fixtures/table.parquet".into(), + }, + ReplPipelineStage::GroupBy { + columns: vec![ColumnSpec::CaseInsensitive("two".into())], + }, + ReplPipelineStage::Select { + columns: vec![ + SelectItem::Column(ColumnSpec::CaseInsensitive("two".into())), + SelectItem::Sum(ColumnSpec::CaseInsensitive("three".into())), + ], + }, + ReplPipelineStage::Filter { sql: "true".into() }, + ReplPipelineStage::Head { n: 3 }, + ]; + let err = validate_repl_pipeline_stages(&stages).unwrap_err(); + assert!( + matches!(err, Error::InvalidReplPipeline(msg) if msg.contains("filter(...) must appear before group_by")) + ); +} + +#[test] +fn test_validate_accepts_filter_group_by_select() { + let stages = vec![ + ReplPipelineStage::Read { + path: "fixtures/table.parquet".into(), + }, + ReplPipelineStage::Filter { sql: "true".into() }, + ReplPipelineStage::GroupBy { + columns: vec![ColumnSpec::CaseInsensitive("two".into())], + }, + ReplPipelineStage::Select { + columns: vec![ + SelectItem::Column(ColumnSpec::CaseInsensitive("two".into())), + SelectItem::Sum(ColumnSpec::CaseInsensitive("three".into())), + ], + }, + ReplPipelineStage::Head { n: 3 }, + ]; + validate_repl_pipeline_stages(&stages).unwrap(); +} + +#[test] +fn test_builder_bridge_sets_filter_sql() { + let stages = vec![ + ReplPipelineStage::Read { + path: "fixtures/table.parquet".into(), + }, + ReplPipelineStage::Select { + columns: vec![SelectItem::Column(ColumnSpec::CaseInsensitive( + "one".into(), + ))], + }, + ReplPipelineStage::Filter { sql: "true".into() }, + ReplPipelineStage::Head { n: 2 }, + ]; + let mut builder = repl_stages_to_pipeline_builder(&stages).unwrap(); + let Pipeline::DataFrame(p) = builder.build().unwrap() else { + panic!("expected DataFrame pipeline"); + }; + assert_eq!(p.filter_sql.as_deref(), Some("true")); + assert!(p.filter_runs_after_select); + assert_eq!(p.slice, Some(DisplaySlice::Head(2))); +} + #[test] fn test_validate_rejects_second_select() { let stages = vec![ diff --git a/src/errors.rs b/src/errors.rs index 826c827..09a3137 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -64,6 +64,8 @@ pub enum PipelinePlanningError { "Aggregates in select are not supported for ORC input; use Parquet, CSV, JSON, or Avro" )] AggregatesNotSupportedForOrc, + #[error("filter() is not supported for ORC input; use Parquet, CSV, JSON, or Avro")] + FilterNotSupportedForOrc, } /// Errors produced while running a pipeline (wrong format, consumed state, etc.). diff --git a/src/pipeline/builder.rs b/src/pipeline/builder.rs index bb6c39f..1f42f9c 100644 --- a/src/pipeline/builder.rs +++ b/src/pipeline/builder.rs @@ -23,6 +23,10 @@ use crate::resolve_file_type; /// Fluent builder for a [`Pipeline`] (file conversion or stdout display: head/tail/sample, schema, or row count). pub struct PipelineBuilder { read: Option, + /// REPL: SQL predicate for DataFusion `parse_sql_expr` + `filter`. + filter_sql: Option, + /// When true, apply `filter_sql` after `select`; when false, before (or on raw data if no select). + filter_runs_after_select: bool, select: Option, head: Option, tail: Option, @@ -44,6 +48,8 @@ impl Default for PipelineBuilder { fn default() -> Self { Self { read: None, + filter_sql: None, + filter_runs_after_select: false, select: None, head: None, tail: None, @@ -76,6 +82,19 @@ impl PipelineBuilder { self.read = Some(path.to_string()); self } + + /// SQL `WHERE`-style predicate for DataFusion (REPL `filter("...")`). + pub fn filter_sql(&mut self, sql: &str) -> &mut Self { + self.filter_sql = Some(sql.to_string()); + self + } + + /// When set with [`filter_sql`](Self::filter_sql), apply the filter after `select` if true, before if false. + pub fn filter_runs_after_select(&mut self, v: bool) -> &mut Self { + self.filter_runs_after_select = v; + self + } + /// Sets column selection as exact name matches. pub fn select(&mut self, columns: &[&str]) -> &mut Self { self.select = Some(SelectSpec { @@ -196,6 +215,7 @@ impl PipelineBuilder { } reject_orc_with_aggregates(input_file_type, &select)?; + reject_orc_with_filter(input_file_type, &self.filter_sql)?; let slice = slice_from_head_tail_sample(self.head, self.tail, self.sample); Ok(dispatch_pipeline( @@ -205,6 +225,8 @@ impl PipelineBuilder { slice, self.sparse, self.csv_has_header, + self.filter_sql.clone(), + self.filter_runs_after_select, UnifiedSink::Write { output_path: output_path.to_string(), output_file_type, @@ -234,6 +256,7 @@ impl PipelineBuilder { let csv_has_header = self.csv_has_header; reject_orc_with_aggregates(input_file_type, &select)?; + reject_orc_with_filter(input_file_type, &self.filter_sql)?; Ok(dispatch_pipeline( input_path, @@ -242,6 +265,8 @@ impl PipelineBuilder { None, sparse, csv_has_header, + self.filter_sql.clone(), + self.filter_runs_after_select, UnifiedSink::Schema { output_format, sparse, @@ -266,6 +291,7 @@ impl PipelineBuilder { let sparse = self.sparse; reject_orc_with_aggregates(input_file_type, &select)?; + reject_orc_with_filter(input_file_type, &self.filter_sql)?; Ok(dispatch_pipeline( input_path, @@ -274,6 +300,8 @@ impl PipelineBuilder { None, sparse, csv_has_header, + self.filter_sql.clone(), + self.filter_runs_after_select, UnifiedSink::Count, )) } @@ -307,6 +335,7 @@ impl PipelineBuilder { let sparse = self.sparse; reject_orc_with_aggregates(input_file_type, &select)?; + reject_orc_with_filter(input_file_type, &self.filter_sql)?; Ok(dispatch_pipeline( input_path, @@ -315,6 +344,8 @@ impl PipelineBuilder { Some(slice), sparse, csv_has_header, + self.filter_sql.clone(), + self.filter_runs_after_select, UnifiedSink::Display { output_format, csv_stdout_headers, @@ -341,6 +372,7 @@ impl PipelineBuilder { PipelinePlanningError::AggregatesNotSupportedForOrc, )); } + reject_orc_with_filter(input_file_type, &self.filter_sql)?; let output_format = self .display_output_format .unwrap_or(DisplayOutputFormat::Csv); @@ -352,6 +384,8 @@ impl PipelineBuilder { input_path, input_file_type, select, + filter_sql: self.filter_sql.clone(), + filter_runs_after_select: self.filter_runs_after_select, slice: None, csv_has_header, sparse, @@ -434,6 +468,7 @@ enum UnifiedSink { Count, } +#[allow(clippy::too_many_arguments)] fn dispatch_pipeline( input_path: String, input_file_type: FileType, @@ -441,6 +476,8 @@ fn dispatch_pipeline( slice: Option, sparse: bool, csv_has_header: Option, + filter_sql: Option, + filter_runs_after_select: bool, sink: UnifiedSink, ) -> Pipeline { if input_file_type == FileType::Orc { @@ -457,6 +494,8 @@ fn dispatch_pipeline( input_path, input_file_type, select, + filter_sql, + filter_runs_after_select, slice, csv_has_header, sparse, @@ -539,6 +578,15 @@ fn reject_orc_with_aggregates( Ok(()) } +fn reject_orc_with_filter(input_file_type: FileType, filter_sql: &Option) -> Result<()> { + if input_file_type == FileType::Orc && filter_sql.is_some() { + return Err(Error::PipelinePlanningError( + PipelinePlanningError::FilterNotSupportedForOrc, + )); + } + Ok(()) +} + fn slice_from_head_tail_sample( head: Option, tail: Option, diff --git a/src/pipeline/dataframe/execute.rs b/src/pipeline/dataframe/execute.rs index 8686cba..e15de78 100644 --- a/src/pipeline/dataframe/execute.rs +++ b/src/pipeline/dataframe/execute.rs @@ -50,6 +50,10 @@ pub struct DataFramePipeline { pub(crate) input_path: String, pub(crate) input_file_type: FileType, pub(crate) select: Option, + /// SQL predicate: `parse_sql_expr` + `filter` (placement vs `select` via [`filter_runs_after_select`](DataFramePipeline::filter_runs_after_select)). + pub(crate) filter_sql: Option, + /// When true, run filter after `select`; when false, before `select` (REPL order). + pub(crate) filter_runs_after_select: bool, pub(crate) slice: Option, pub(crate) csv_has_header: Option, pub(crate) sparse: bool, @@ -57,11 +61,13 @@ pub struct DataFramePipeline { } impl DataFramePipeline { - /// Read, optional column select, optional head/tail/sample, then [`DataFrameSink`]. + /// Read, optional column select, optional SQL filter, optional head/tail/sample, then [`DataFrameSink`]. pub fn execute(&mut self) -> crate::Result<()> { let input_path = self.input_path.clone(); let input_file_type = self.input_file_type; let select = self.select.clone(); + let filter_sql = self.filter_sql.clone(); + let filter_runs_after_select = self.filter_runs_after_select; let slice = self.slice; let csv_has_header = self.csv_has_header; let sparse = self.sparse; @@ -101,6 +107,7 @@ impl DataFramePipeline { sparse: schema_sparse, } => { let use_file_metadata_schema = select.is_none() + && filter_sql.is_none() && matches!( input_file_type, FileType::Parquet | FileType::Avro | FileType::Csv | FileType::Orc @@ -116,6 +123,8 @@ impl DataFramePipeline { input_path.clone(), input_file_type, select, + filter_sql.clone(), + filter_runs_after_select, None, csv_has_header, ) @@ -130,13 +139,15 @@ impl DataFramePipeline { Ok::<(), Error>(()) } DataFrameSink::Count => { - let total = if select.is_none() { + let total = if select.is_none() && filter_sql.is_none() { count_rows(&input_path, input_file_type, csv_has_header).await? } else { let mut source = dataframe_pipeline_prepare_source( input_path.clone(), input_file_type, select, + filter_sql.clone(), + filter_runs_after_select, None, csv_has_header, ) @@ -159,6 +170,8 @@ impl DataFramePipeline { input_path, input_file_type, select, + filter_sql.clone(), + filter_runs_after_select, slice, csv_has_header, ) @@ -206,6 +219,8 @@ impl DataFramePipeline { input_path, input_file_type, select, + filter_sql.clone(), + filter_runs_after_select, slice, csv_has_header, ) @@ -230,11 +245,13 @@ impl DataFramePipeline { } } -/// Read into a [`DataFrameSource`], apply optional column select, then optional head/tail/sample. +/// Read into a [`DataFrameSource`], apply optional SQL filter and column select per `filter_runs_after_select`, then optional head/tail/sample. pub(crate) async fn dataframe_pipeline_prepare_source( input_path: String, input_file_type: FileType, select: Option, + filter_sql: Option, + filter_runs_after_select: bool, slice: Option, csv_has_header: Option, ) -> crate::Result { @@ -252,6 +269,8 @@ pub(crate) async fn dataframe_pipeline_prepare_source( df, &input_path, input_file_type, + filter_sql.as_deref(), + filter_runs_after_select, select.as_ref(), None, slice, diff --git a/src/pipeline/dataframe/transform.rs b/src/pipeline/dataframe/transform.rs index 1f65568..0140e14 100644 --- a/src/pipeline/dataframe/transform.rs +++ b/src/pipeline/dataframe/transform.rs @@ -278,23 +278,37 @@ pub(super) fn apply_select_spec_to_dataframe( Ok(df) } -/// Applies optional column selection, SQL-style row limit, and display slice to a loaded [`DataFrame`]. +/// Applies optional SQL filter and column selection (order controlled by `filter_runs_after_select`), then SQL-style row limit and display slice. /// -/// Used by `LegacyDataFrameReader` and `dataframe_pipeline_prepare_source` so read → project → -/// cap/slice stays in one place. +/// When `filter_runs_after_select` is false, `filter_sql` runs on the raw frame before `select`; when true, after `select` (post-projection / post-aggregate). +#[allow(clippy::too_many_arguments)] // Pipeline finalize bundles several optional stages; splitting would not simplify call sites. pub(super) async fn finalize_dataframe_source( mut df: DataFrame, input_path: &str, input_file_type: FileType, + filter_sql: Option<&str>, + filter_runs_after_select: bool, select: Option<&SelectSpec>, limit: Option, slice: Option, ) -> crate::Result { + if let Some(sql) = filter_sql + && !filter_runs_after_select + { + let expr = df.parse_sql_expr(sql)?; + df = df.filter(expr)?; + } if let Some(spec) = select && !spec.is_empty() { df = apply_select_spec_to_dataframe(df, spec)?; } + if let Some(sql) = filter_sql + && filter_runs_after_select + { + let expr = df.parse_sql_expr(sql)?; + df = df.filter(expr)?; + } if let Some(n) = limit { df = dataframe_apply_head(df, n)?; } diff --git a/src/pipeline/tests.rs b/src/pipeline/tests.rs index 50ffdb2..84eb497 100644 --- a/src/pipeline/tests.rs +++ b/src/pipeline/tests.rs @@ -165,6 +165,62 @@ fn test_pipeline_builder_read_head_display_orc_uses_record_batch_pipeline() { assert!(matches!(p.sink, RecordBatchSink::Display { .. })); } +#[test] +fn test_pipeline_builder_filter_sql_rejected_for_orc() { + use crate::errors::PipelinePlanningError; + + let mut builder = PipelineBuilder::new(); + builder + .read("fixtures/userdata.orc") + .select(&["col"]) + .filter_sql("col IS NOT NULL") + .filter_runs_after_select(true) + .head(3); + let err = match builder.build() { + Err(e) => e, + Ok(_) => panic!("ORC + filter should fail at plan time"), + }; + assert!(matches!( + err, + Error::PipelinePlanningError(PipelinePlanningError::FilterNotSupportedForOrc) + )); +} + +#[test] +fn test_pipeline_builder_read_filter_head_sets_filter_sql() { + let mut builder = PipelineBuilder::new(); + builder + .read("fixtures/table.parquet") + .select(&["one"]) + .filter_sql("true") + .filter_runs_after_select(true) + .head(3); + let built = builder.build().expect("build display pipeline"); + let Pipeline::DataFrame(p) = built else { + panic!("expected DataFrame pipeline"); + }; + assert_eq!(p.filter_sql.as_deref(), Some("true")); + assert!(p.filter_runs_after_select); + assert_eq!(p.slice, Some(DisplaySlice::Head(3))); +} + +#[test] +fn test_pipeline_builder_filter_before_select_sets_placement() { + let mut builder = PipelineBuilder::new(); + builder + .read("fixtures/table.parquet") + .filter_sql("true") + .filter_runs_after_select(false) + .select(&["one"]) + .head(2); + let built = builder.build().expect("build display pipeline"); + let Pipeline::DataFrame(p) = built else { + panic!("expected DataFrame pipeline"); + }; + assert_eq!(p.filter_sql.as_deref(), Some("true")); + assert!(!p.filter_runs_after_select); +} + #[test] fn test_pipeline_builder_read_sample_display_parquet() { let mut builder = PipelineBuilder::new(); From 607430b1dd87e45ab6bf89f61a0e4c816ac9731f Mon Sep 17 00:00:00 2001 From: Alistair Israel Date: Tue, 7 Apr 2026 21:04:10 -0400 Subject: [PATCH 2/2] feat(pipeline): dual WHERE/HAVING filters via FilterSpec and REPL straddling select Made-with: Cursor --- docs/REPL.md | 11 ++- src/cli/repl/builder_bridge.rs | 37 ++++++--- src/cli/repl/plan.rs | 41 ++++++---- src/cli/repl/stage.rs | 2 +- src/cli/repl/tests.rs | 123 ++++++++++++++++++++++++++-- src/pipeline.rs | 1 + src/pipeline/builder.rs | 93 +++++++++++++-------- src/pipeline/dataframe/execute.rs | 49 ++++++----- src/pipeline/dataframe/transform.rs | 16 ++-- src/pipeline/spec.rs | 28 +++++++ src/pipeline/tests.rs | 69 +++++++++++++--- 11 files changed, 356 insertions(+), 114 deletions(-) diff --git a/docs/REPL.md b/docs/REPL.md index 2802edb..0d7b07c 100644 --- a/docs/REPL.md +++ b/docs/REPL.md @@ -111,7 +111,7 @@ For the following functions, note that the function signatures and types provide ### Pipeline shape -A REPL pipeline must start with `read(...)`. You may use at most one `filter("...")`, one `group_by(...)`, and one `select(...)` per pipeline, in **any order**, except that when both `filter` and `group_by` are used, **`filter` must come before `group_by`** (for example `read(...) |> filter("...") |> group_by(...) |> select(...)`). Further stages—`head`, `tail`, `sample`, `schema`, `count`, or `write`—are added when needed (for example, `read("x.parquet") |> head(5)` skips `select` entirely). You cannot repeat `select`, `group_by`, or `filter` in the same pipeline. If `group_by(...)` appears, a matching `select(...)` is required. +A REPL pipeline must start with `read(...)`. You may use at most **two** `filter("...")` stages (only when one appears **before** `select(...)` and one **after**, so you can combine WHERE-like and HAVING-like predicates), at most one `group_by(...)`, and one `select(...)`, in **any order** among those stages (subject to `group_by(...)` requiring a matching `select(...)`). Further stages—`head`, `tail`, `sample`, `schema`, `count`, or `write`—are added when needed (for example, `read("x.parquet") |> head(5)` skips `select` entirely). You cannot repeat `select` or `group_by`. ### `read` @@ -137,12 +137,19 @@ Reads a Parquet, Avro, ORC, CSV, or JSON file at the given `path`. If `file_type filter(data: Data, sql: String) -> Data ``` -`filter` takes a single string that is a SQL predicate fragment (as in a `WHERE` clause). It is parsed with Apache DataFusion. Whether the predicate applies to **source** columns or **post-`select`** columns depends on where `filter` appears relative to `select` in the pipeline: `filter` **before** `select` filters raw rows; `filter` **after** `select` filters the projected or aggregated result (column names must match that step’s schema). +`filter` takes a single string that is a SQL predicate fragment. It is parsed with Apache DataFusion. **Placement relative to `select` in the pipeline** fixes whether you filter input rows or the result of the `select` step (which, when `group_by` is used, includes aggregation in one logical step—similar to SQL **WHERE** vs **HAVING**): + +- **`filter` before `select`** (when `select` is present): predicate on **source** columns, evaluated on each input row **before** projection or aggregation (WHERE-like when `group_by` is used). +- **`filter` after `select` without `group_by`**: predicate on **projected** columns only. +- **`filter` after `select` with `group_by`**: predicate on the **grouped/aggregated** result; use the output column names DataFusion produces for aggregates (commonly `sum(column_name)`, `avg(column_name)`, etc., matching the source column name). +- **Two `filter` stages**: the first (by pipeline order before `select`) runs on **input rows**; the second (after `select`) runs on the **result**—together, analogous to **WHERE** then **HAVING** when `group_by` and aggregates are used. ```flt read("input.parquet") |> filter("amount > 0") |> select(:amount, :status) |> head(10) read("input.parquet") |> select(:amount, :status) |> filter("amount > 0 AND status = 'active'") |> head(10) read("input.parquet") |> filter("amount > 0") |> group_by(:country) |> select(:country, sum(:amount)) |> head(10) +read("input.parquet") |> group_by(:country) |> select(:country, sum(:amount)) |> filter("sum(amount) > 100") |> head(10) +read("input.parquet") |> filter("status = 'active'") |> group_by(:country) |> select(:country, sum(:amount)) |> filter("sum(amount) > 100") |> head(10) ``` `filter` is only supported for inputs read through DataFusion (Parquet, Avro, CSV, JSON). It is **not** supported for ORC files in the REPL; convert or use another format first. diff --git a/src/cli/repl/builder_bridge.rs b/src/cli/repl/builder_bridge.rs index 2852623..26af303 100644 --- a/src/cli/repl/builder_bridge.rs +++ b/src/cli/repl/builder_bridge.rs @@ -28,17 +28,15 @@ pub(crate) fn repl_stages_to_pipeline_builder( builder.read(path); let mut i = 1usize; - let mut filter_idx: Option = None; let mut select_idx: Option = None; let mut group_keys: Option> = None; let mut select_columns: Option> = None; - let mut filter_sql: Option = None; + let mut filters: Vec<(usize, String)> = Vec::new(); while i < body.len() { match body.get(i) { Some(ReplPipelineStage::Filter { sql }) => { - filter_idx = Some(i); - filter_sql = Some(sql.clone()); + filters.push((i, sql.clone())); i += 1; } Some(ReplPipelineStage::GroupBy { columns }) => { @@ -46,7 +44,9 @@ pub(crate) fn repl_stages_to_pipeline_builder( i += 1; } Some(ReplPipelineStage::Select { columns }) => { - select_idx = Some(i); + if select_idx.is_none() { + select_idx = Some(i); + } select_columns = Some(columns.clone()); i += 1; } @@ -80,13 +80,26 @@ pub(crate) fn repl_stages_to_pipeline_builder( builder.select_spec(spec); } - if let Some(sql) = filter_sql { - let runs_after = match (filter_idx, select_idx) { - (Some(f), Some(s)) => f > s, - _ => false, - }; - builder.filter_sql(&sql); - builder.filter_runs_after_select(runs_after); + match filters.len() { + 0 => {} + 1 => { + let (f, sql) = &filters[0]; + if select_idx.is_some_and(|s| *f > s) { + builder.filter_after_select(sql); + } else { + builder.filter_before_select(sql); + } + } + 2 => { + filters.sort_by_key(|(idx, _)| *idx); + builder.filter_before_select(&filters[0].1); + builder.filter_after_select(&filters[1].1); + } + _ => { + return Err(crate::Error::InvalidReplPipeline( + "at most two filter(...) stages are allowed in a pipeline".to_string(), + )); + } } match body.get(i) { diff --git a/src/cli/repl/plan.rs b/src/cli/repl/plan.rs index ff4a201..e4aa654 100644 --- a/src/cli/repl/plan.rs +++ b/src/cli/repl/plan.rs @@ -306,8 +306,8 @@ fn validate_grouped_select(keys: &[ColumnSpec], items: &[SelectItem]) -> crate:: Ok(()) } -/// Validates that stages match `read` → optional permuted `filter` / `group_by` / `select` (at most one each) → optional slice or `schema`/`count` → optional `write`, -/// with optional trailing `print` only after head/tail/sample. When both `filter` and `group_by` are present, `filter` must come first in the pipeline. +/// Validates that stages match `read` → optional permuted `filter` (at most two, straddling `select` if two) / `group_by` / `select` (at most one each) → optional slice or `schema`/`count` → optional `write`, +/// with optional trailing `print` only after head/tail/sample. A single `filter` maps to before or after `select` by stage order; two `filter` stages require `select(...)` strictly between them (WHERE-like then HAVING-like when aggregating). pub(super) fn validate_repl_pipeline_stages(stages: &[ReplPipelineStage]) -> crate::Result<()> { if stages.is_empty() { return Err(Error::InvalidReplPipeline("empty pipeline".to_string())); @@ -325,20 +325,20 @@ pub(super) fn validate_repl_pipeline_stages(stages: &[ReplPipelineStage]) -> cra } let mut i = 1usize; - let mut filter_idx: Option = None; - let mut group_idx: Option = None; + let mut filter_indices: Vec = Vec::new(); + let mut select_idx: Option = None; let mut group_by_cols: Option<&Vec> = None; let mut select_items: Option<&Vec> = None; while i < body.len() { match body.get(i) { Some(ReplPipelineStage::Filter { .. }) => { - if filter_idx.is_some() { + if filter_indices.len() >= 2 { return Err(Error::InvalidReplPipeline( - "only one filter(...) is allowed in a pipeline".to_string(), + "at most two filter(...) stages are allowed in a pipeline".to_string(), )); } - filter_idx = Some(i); + filter_indices.push(i); i += 1; } Some(ReplPipelineStage::GroupBy { columns }) => { @@ -347,7 +347,6 @@ pub(super) fn validate_repl_pipeline_stages(stages: &[ReplPipelineStage]) -> cra "only one group_by(...) is allowed in a pipeline".to_string(), )); } - group_idx = Some(i); group_by_cols = Some(columns); i += 1; } @@ -357,6 +356,9 @@ pub(super) fn validate_repl_pipeline_stages(stages: &[ReplPipelineStage]) -> cra "only one select(...) is allowed in a pipeline".to_string(), )); } + if select_idx.is_none() { + select_idx = Some(i); + } select_items = Some(columns); i += 1; } @@ -388,12 +390,21 @@ pub(super) fn validate_repl_pipeline_stages(stages: &[ReplPipelineStage]) -> cra )); } - if let (Some(fi), Some(gi)) = (filter_idx, group_idx) - && fi >= gi - { - return Err(Error::InvalidReplPipeline( - "filter(...) must appear before group_by(...)".to_string(), - )); + if filter_indices.len() == 2 { + let Some(si) = select_idx else { + return Err(Error::InvalidReplPipeline( + "two filter(...) stages require select(...) between them (one before and one after select(...))" + .to_string(), + )); + }; + let f0 = filter_indices[0].min(filter_indices[1]); + let f1 = filter_indices[0].max(filter_indices[1]); + if !(f0 < si && si < f1) { + return Err(Error::InvalidReplPipeline( + "two filter(...) stages must have select(...) strictly between them (one before select, one after)" + .to_string(), + )); + } } if let Some(keys) = group_by_cols { @@ -452,7 +463,7 @@ pub(super) fn validate_repl_pipeline_stages(stages: &[ReplPipelineStage]) -> cra } _ => { return Err(Error::InvalidReplPipeline( - "invalid pipeline stage order (expected read, optional filter, group_by, and select in any order with filter before group_by when both are used, then head|tail|sample|schema|count, or write)".to_string(), + "invalid pipeline stage order (expected read, optional filter/group_by/select in any order, then head|tail|sample|schema|count, or write)".to_string(), )); } } diff --git a/src/cli/repl/stage.rs b/src/cli/repl/stage.rs index 9079393..86b1cd7 100644 --- a/src/cli/repl/stage.rs +++ b/src/cli/repl/stage.rs @@ -12,7 +12,7 @@ pub enum ReplPipelineStage { Read { path: String, }, - /// SQL predicate (`parse_sql_expr` + `filter`); runs before or after `select` per pipeline order. + /// SQL predicate (`parse_sql_expr` + `filter`); before or after the `select` stage per pipeline order (after includes post-aggregate when `group_by` is used). Filter { sql: String, }, diff --git a/src/cli/repl/tests.rs b/src/cli/repl/tests.rs index 7be65f8..7189582 100644 --- a/src/cli/repl/tests.rs +++ b/src/cli/repl/tests.rs @@ -23,6 +23,7 @@ use super::stage::ReplPipelineStage; use crate::Error; use crate::pipeline::DataFramePipeline; use crate::pipeline::DisplaySlice; +use crate::pipeline::FilterSpec; use crate::pipeline::Pipeline; use crate::pipeline::SelectSpec; @@ -540,11 +541,12 @@ fn test_extract_path_from_args_write_bad_args() { // ── validate_repl_pipeline_stages / plan not implemented ────── #[test] -fn test_validate_rejects_duplicate_filter() { +fn test_validate_rejects_three_filters() { let stages = vec![ ReplPipelineStage::Read { path: "a.parquet".into(), }, + ReplPipelineStage::Filter { sql: "true".into() }, ReplPipelineStage::Select { columns: vec![SelectItem::Column(ColumnSpec::CaseInsensitive("x".into()))], }, @@ -557,7 +559,50 @@ fn test_validate_rejects_duplicate_filter() { ReplPipelineStage::Head { n: 1 }, ]; let err = validate_repl_pipeline_stages(&stages).unwrap_err(); - assert!(matches!(err, Error::InvalidReplPipeline(msg) if msg.contains("only one filter"))); + assert!(matches!(err, Error::InvalidReplPipeline(msg) if msg.contains("at most two filter"))); +} + +#[test] +fn test_validate_rejects_two_filters_both_after_select() { + let stages = vec![ + ReplPipelineStage::Read { + path: "a.parquet".into(), + }, + ReplPipelineStage::Select { + columns: vec![SelectItem::Column(ColumnSpec::CaseInsensitive("x".into()))], + }, + ReplPipelineStage::Filter { + sql: "x > 0".into(), + }, + ReplPipelineStage::Filter { + sql: "x < 10".into(), + }, + ReplPipelineStage::Head { n: 1 }, + ]; + let err = validate_repl_pipeline_stages(&stages).unwrap_err(); + assert!(matches!(err, Error::InvalidReplPipeline(msg) if msg.contains("strictly between"))); +} + +#[test] +fn test_validate_accepts_two_filters_straddling_select() { + let stages = vec![ + ReplPipelineStage::Read { + path: "fixtures/table.parquet".into(), + }, + ReplPipelineStage::Filter { + sql: "one > 0".into(), + }, + ReplPipelineStage::Select { + columns: vec![SelectItem::Column(ColumnSpec::CaseInsensitive( + "one".into(), + ))], + }, + ReplPipelineStage::Filter { + sql: "one < 1000".into(), + }, + ReplPipelineStage::Head { n: 5 }, + ]; + validate_repl_pipeline_stages(&stages).unwrap(); } #[test] @@ -592,7 +637,7 @@ fn test_validate_accepts_select_filter_head() { } #[test] -fn test_validate_rejects_filter_after_group_by() { +fn test_validate_accepts_filter_after_group_by_select() { let stages = vec![ ReplPipelineStage::Read { path: "fixtures/table.parquet".into(), @@ -609,10 +654,7 @@ fn test_validate_rejects_filter_after_group_by() { ReplPipelineStage::Filter { sql: "true".into() }, ReplPipelineStage::Head { n: 3 }, ]; - let err = validate_repl_pipeline_stages(&stages).unwrap_err(); - assert!( - matches!(err, Error::InvalidReplPipeline(msg) if msg.contains("filter(...) must appear before group_by")) - ); + validate_repl_pipeline_stages(&stages).unwrap(); } #[test] @@ -636,6 +678,69 @@ fn test_validate_accepts_filter_group_by_select() { validate_repl_pipeline_stages(&stages).unwrap(); } +#[test] +fn test_builder_bridge_post_aggregate_filter_runs_after_select() { + let stages = vec![ + ReplPipelineStage::Read { + path: "fixtures/table.parquet".into(), + }, + ReplPipelineStage::GroupBy { + columns: vec![ColumnSpec::CaseInsensitive("two".into())], + }, + ReplPipelineStage::Select { + columns: vec![ + SelectItem::Column(ColumnSpec::CaseInsensitive("two".into())), + SelectItem::Sum(ColumnSpec::CaseInsensitive("three".into())), + ], + }, + ReplPipelineStage::Filter { + sql: "sum(three) > 0".into(), + }, + ReplPipelineStage::Head { n: 3 }, + ]; + let mut builder = repl_stages_to_pipeline_builder(&stages).unwrap(); + let Pipeline::DataFrame(p) = builder.build().unwrap() else { + panic!("expected DataFrame pipeline"); + }; + assert_eq!(p.filter_before_select, None); + assert_eq!( + p.filter_after_select, + Some(FilterSpec::new("sum(three) > 0")) + ); +} + +#[test] +fn test_builder_bridge_where_and_having_filters() { + let stages = vec![ + ReplPipelineStage::Read { + path: "fixtures/table.parquet".into(), + }, + ReplPipelineStage::Filter { sql: "true".into() }, + ReplPipelineStage::GroupBy { + columns: vec![ColumnSpec::CaseInsensitive("two".into())], + }, + ReplPipelineStage::Select { + columns: vec![ + SelectItem::Column(ColumnSpec::CaseInsensitive("two".into())), + SelectItem::Sum(ColumnSpec::CaseInsensitive("three".into())), + ], + }, + ReplPipelineStage::Filter { + sql: "sum(three) > 0".into(), + }, + ReplPipelineStage::Head { n: 3 }, + ]; + let mut builder = repl_stages_to_pipeline_builder(&stages).unwrap(); + let Pipeline::DataFrame(p) = builder.build().unwrap() else { + panic!("expected DataFrame pipeline"); + }; + assert_eq!(p.filter_before_select, Some(FilterSpec::new("true"))); + assert_eq!( + p.filter_after_select, + Some(FilterSpec::new("sum(three) > 0")) + ); +} + #[test] fn test_builder_bridge_sets_filter_sql() { let stages = vec![ @@ -654,8 +759,8 @@ fn test_builder_bridge_sets_filter_sql() { let Pipeline::DataFrame(p) = builder.build().unwrap() else { panic!("expected DataFrame pipeline"); }; - assert_eq!(p.filter_sql.as_deref(), Some("true")); - assert!(p.filter_runs_after_select); + assert_eq!(p.filter_before_select, None); + assert_eq!(p.filter_after_select, Some(FilterSpec::new("true"))); assert_eq!(p.slice, Some(DisplaySlice::Head(2))); } diff --git a/src/pipeline.rs b/src/pipeline.rs index f85597b..d478337 100644 --- a/src/pipeline.rs +++ b/src/pipeline.rs @@ -72,6 +72,7 @@ pub use sampling::sample_from_reader; pub use sampling::tail_batches; pub use spec::ColumnSpec; pub(crate) use spec::DisplaySlice; +pub use spec::FilterSpec; pub use spec::SelectItem; pub use spec::SelectSpec; pub use step::Producer; diff --git a/src/pipeline/builder.rs b/src/pipeline/builder.rs index 1f42f9c..80bc54a 100644 --- a/src/pipeline/builder.rs +++ b/src/pipeline/builder.rs @@ -16,6 +16,7 @@ use crate::pipeline::record_batch::RecordBatchPipeline; use crate::pipeline::record_batch::RecordBatchSink; use crate::pipeline::spec::ColumnSpec; use crate::pipeline::spec::DisplaySlice; +use crate::pipeline::spec::FilterSpec; use crate::pipeline::spec::SelectItem; use crate::pipeline::spec::SelectSpec; use crate::resolve_file_type; @@ -23,10 +24,10 @@ use crate::resolve_file_type; /// Fluent builder for a [`Pipeline`] (file conversion or stdout display: head/tail/sample, schema, or row count). pub struct PipelineBuilder { read: Option, - /// REPL: SQL predicate for DataFusion `parse_sql_expr` + `filter`. - filter_sql: Option, - /// When true, apply `filter_sql` after `select`; when false, before (or on raw data if no select). - filter_runs_after_select: bool, + /// REPL: SQL predicate applied before `select` (WHERE-like on input rows when aggregating). + filter_before_select: Option, + /// REPL: SQL predicate applied after `select` (post-projection or HAVING-like after `group_by` aggregates). + filter_after_select: Option, select: Option, head: Option, tail: Option, @@ -48,8 +49,8 @@ impl Default for PipelineBuilder { fn default() -> Self { Self { read: None, - filter_sql: None, - filter_runs_after_select: false, + filter_before_select: None, + filter_after_select: None, select: None, head: None, tail: None, @@ -83,15 +84,15 @@ impl PipelineBuilder { self } - /// SQL `WHERE`-style predicate for DataFusion (REPL `filter("...")`). - pub fn filter_sql(&mut self, sql: &str) -> &mut Self { - self.filter_sql = Some(sql.to_string()); + /// SQL predicate for DataFusion `parse_sql_expr` + `filter`, applied **before** the pipeline `select` step (input rows / WHERE-like). + pub fn filter_before_select(&mut self, sql: &str) -> &mut Self { + self.filter_before_select = Some(FilterSpec::new(sql)); self } - /// When set with [`filter_sql`](Self::filter_sql), apply the filter after `select` if true, before if false. - pub fn filter_runs_after_select(&mut self, v: bool) -> &mut Self { - self.filter_runs_after_select = v; + /// SQL predicate applied **after** the pipeline `select` step (post-projection or HAVING-like when `group_by` + aggregates are present). + pub fn filter_after_select(&mut self, sql: &str) -> &mut Self { + self.filter_after_select = Some(FilterSpec::new(sql)); self } @@ -215,7 +216,11 @@ impl PipelineBuilder { } reject_orc_with_aggregates(input_file_type, &select)?; - reject_orc_with_filter(input_file_type, &self.filter_sql)?; + reject_orc_with_filters( + input_file_type, + &self.filter_before_select, + &self.filter_after_select, + )?; let slice = slice_from_head_tail_sample(self.head, self.tail, self.sample); Ok(dispatch_pipeline( @@ -225,8 +230,8 @@ impl PipelineBuilder { slice, self.sparse, self.csv_has_header, - self.filter_sql.clone(), - self.filter_runs_after_select, + self.filter_before_select.clone(), + self.filter_after_select.clone(), UnifiedSink::Write { output_path: output_path.to_string(), output_file_type, @@ -256,7 +261,11 @@ impl PipelineBuilder { let csv_has_header = self.csv_has_header; reject_orc_with_aggregates(input_file_type, &select)?; - reject_orc_with_filter(input_file_type, &self.filter_sql)?; + reject_orc_with_filters( + input_file_type, + &self.filter_before_select, + &self.filter_after_select, + )?; Ok(dispatch_pipeline( input_path, @@ -265,8 +274,8 @@ impl PipelineBuilder { None, sparse, csv_has_header, - self.filter_sql.clone(), - self.filter_runs_after_select, + self.filter_before_select.clone(), + self.filter_after_select.clone(), UnifiedSink::Schema { output_format, sparse, @@ -291,7 +300,11 @@ impl PipelineBuilder { let sparse = self.sparse; reject_orc_with_aggregates(input_file_type, &select)?; - reject_orc_with_filter(input_file_type, &self.filter_sql)?; + reject_orc_with_filters( + input_file_type, + &self.filter_before_select, + &self.filter_after_select, + )?; Ok(dispatch_pipeline( input_path, @@ -300,8 +313,8 @@ impl PipelineBuilder { None, sparse, csv_has_header, - self.filter_sql.clone(), - self.filter_runs_after_select, + self.filter_before_select.clone(), + self.filter_after_select.clone(), UnifiedSink::Count, )) } @@ -335,7 +348,11 @@ impl PipelineBuilder { let sparse = self.sparse; reject_orc_with_aggregates(input_file_type, &select)?; - reject_orc_with_filter(input_file_type, &self.filter_sql)?; + reject_orc_with_filters( + input_file_type, + &self.filter_before_select, + &self.filter_after_select, + )?; Ok(dispatch_pipeline( input_path, @@ -344,8 +361,8 @@ impl PipelineBuilder { Some(slice), sparse, csv_has_header, - self.filter_sql.clone(), - self.filter_runs_after_select, + self.filter_before_select.clone(), + self.filter_after_select.clone(), UnifiedSink::Display { output_format, csv_stdout_headers, @@ -372,7 +389,11 @@ impl PipelineBuilder { PipelinePlanningError::AggregatesNotSupportedForOrc, )); } - reject_orc_with_filter(input_file_type, &self.filter_sql)?; + reject_orc_with_filters( + input_file_type, + &self.filter_before_select, + &self.filter_after_select, + )?; let output_format = self .display_output_format .unwrap_or(DisplayOutputFormat::Csv); @@ -384,8 +405,8 @@ impl PipelineBuilder { input_path, input_file_type, select, - filter_sql: self.filter_sql.clone(), - filter_runs_after_select: self.filter_runs_after_select, + filter_before_select: self.filter_before_select.clone(), + filter_after_select: self.filter_after_select.clone(), slice: None, csv_has_header, sparse, @@ -476,8 +497,8 @@ fn dispatch_pipeline( slice: Option, sparse: bool, csv_has_header: Option, - filter_sql: Option, - filter_runs_after_select: bool, + filter_before_select: Option, + filter_after_select: Option, sink: UnifiedSink, ) -> Pipeline { if input_file_type == FileType::Orc { @@ -494,8 +515,8 @@ fn dispatch_pipeline( input_path, input_file_type, select, - filter_sql, - filter_runs_after_select, + filter_before_select, + filter_after_select, slice, csv_has_header, sparse, @@ -578,8 +599,14 @@ fn reject_orc_with_aggregates( Ok(()) } -fn reject_orc_with_filter(input_file_type: FileType, filter_sql: &Option) -> Result<()> { - if input_file_type == FileType::Orc && filter_sql.is_some() { +fn reject_orc_with_filters( + input_file_type: FileType, + filter_before_select: &Option, + filter_after_select: &Option, +) -> Result<()> { + if input_file_type == FileType::Orc + && (filter_before_select.is_some() || filter_after_select.is_some()) + { return Err(Error::PipelinePlanningError( PipelinePlanningError::FilterNotSupportedForOrc, )); diff --git a/src/pipeline/dataframe/execute.rs b/src/pipeline/dataframe/execute.rs index e15de78..cda3f71 100644 --- a/src/pipeline/dataframe/execute.rs +++ b/src/pipeline/dataframe/execute.rs @@ -11,6 +11,7 @@ use crate::FileType; use crate::cli::DisplayOutputFormat; use crate::errors::PipelineExecutionError; use crate::pipeline::DisplaySlice; +use crate::pipeline::FilterSpec; use crate::pipeline::ProgressVecRecordBatchReader; use crate::pipeline::SelectSpec; use crate::pipeline::Step; @@ -50,10 +51,10 @@ pub struct DataFramePipeline { pub(crate) input_path: String, pub(crate) input_file_type: FileType, pub(crate) select: Option, - /// SQL predicate: `parse_sql_expr` + `filter` (placement vs `select` via [`filter_runs_after_select`](DataFramePipeline::filter_runs_after_select)). - pub(crate) filter_sql: Option, - /// When true, run filter after `select`; when false, before `select` (REPL order). - pub(crate) filter_runs_after_select: bool, + /// SQL predicate before `select` (`parse_sql_expr` + `filter`). + pub(crate) filter_before_select: Option, + /// SQL predicate after `select` (post-projection or post-aggregate). + pub(crate) filter_after_select: Option, pub(crate) slice: Option, pub(crate) csv_has_header: Option, pub(crate) sparse: bool, @@ -61,13 +62,13 @@ pub struct DataFramePipeline { } impl DataFramePipeline { - /// Read, optional column select, optional SQL filter, optional head/tail/sample, then [`DataFrameSink`]. + /// Read, optional column select, optional SQL filters before/after select, optional head/tail/sample, then [`DataFrameSink`]. pub fn execute(&mut self) -> crate::Result<()> { let input_path = self.input_path.clone(); let input_file_type = self.input_file_type; let select = self.select.clone(); - let filter_sql = self.filter_sql.clone(); - let filter_runs_after_select = self.filter_runs_after_select; + let filter_before_select = self.filter_before_select.clone(); + let filter_after_select = self.filter_after_select.clone(); let slice = self.slice; let csv_has_header = self.csv_has_header; let sparse = self.sparse; @@ -107,7 +108,8 @@ impl DataFramePipeline { sparse: schema_sparse, } => { let use_file_metadata_schema = select.is_none() - && filter_sql.is_none() + && filter_before_select.is_none() + && filter_after_select.is_none() && matches!( input_file_type, FileType::Parquet | FileType::Avro | FileType::Csv | FileType::Orc @@ -123,8 +125,8 @@ impl DataFramePipeline { input_path.clone(), input_file_type, select, - filter_sql.clone(), - filter_runs_after_select, + filter_before_select.clone(), + filter_after_select.clone(), None, csv_has_header, ) @@ -139,15 +141,18 @@ impl DataFramePipeline { Ok::<(), Error>(()) } DataFrameSink::Count => { - let total = if select.is_none() && filter_sql.is_none() { + let total = if select.is_none() + && filter_before_select.is_none() + && filter_after_select.is_none() + { count_rows(&input_path, input_file_type, csv_has_header).await? } else { let mut source = dataframe_pipeline_prepare_source( input_path.clone(), input_file_type, select, - filter_sql.clone(), - filter_runs_after_select, + filter_before_select.clone(), + filter_after_select.clone(), None, csv_has_header, ) @@ -170,8 +175,8 @@ impl DataFramePipeline { input_path, input_file_type, select, - filter_sql.clone(), - filter_runs_after_select, + filter_before_select.clone(), + filter_after_select.clone(), slice, csv_has_header, ) @@ -219,8 +224,8 @@ impl DataFramePipeline { input_path, input_file_type, select, - filter_sql.clone(), - filter_runs_after_select, + filter_before_select.clone(), + filter_after_select.clone(), slice, csv_has_header, ) @@ -245,13 +250,13 @@ impl DataFramePipeline { } } -/// Read into a [`DataFrameSource`], apply optional SQL filter and column select per `filter_runs_after_select`, then optional head/tail/sample. +/// Read into a [`DataFrameSource`], apply optional SQL filters before/after `select`, then optional head/tail/sample. pub(crate) async fn dataframe_pipeline_prepare_source( input_path: String, input_file_type: FileType, select: Option, - filter_sql: Option, - filter_runs_after_select: bool, + filter_before_select: Option, + filter_after_select: Option, slice: Option, csv_has_header: Option, ) -> crate::Result { @@ -269,8 +274,8 @@ pub(crate) async fn dataframe_pipeline_prepare_source( df, &input_path, input_file_type, - filter_sql.as_deref(), - filter_runs_after_select, + filter_before_select.as_ref().map(FilterSpec::as_str), + filter_after_select.as_ref().map(FilterSpec::as_str), select.as_ref(), None, slice, diff --git a/src/pipeline/dataframe/transform.rs b/src/pipeline/dataframe/transform.rs index 0140e14..2f83229 100644 --- a/src/pipeline/dataframe/transform.rs +++ b/src/pipeline/dataframe/transform.rs @@ -278,23 +278,21 @@ pub(super) fn apply_select_spec_to_dataframe( Ok(df) } -/// Applies optional SQL filter and column selection (order controlled by `filter_runs_after_select`), then SQL-style row limit and display slice. +/// Applies optional SQL filters before and after column selection, then SQL-style row limit and display slice. /// -/// When `filter_runs_after_select` is false, `filter_sql` runs on the raw frame before `select`; when true, after `select` (post-projection / post-aggregate). +/// `filter_before_select` runs on the raw frame; `filter_after_select` runs after `select` (post-projection or post-`group_by` aggregate when the spec includes grouping). #[allow(clippy::too_many_arguments)] // Pipeline finalize bundles several optional stages; splitting would not simplify call sites. pub(super) async fn finalize_dataframe_source( mut df: DataFrame, input_path: &str, input_file_type: FileType, - filter_sql: Option<&str>, - filter_runs_after_select: bool, + filter_before_select: Option<&str>, + filter_after_select: Option<&str>, select: Option<&SelectSpec>, limit: Option, slice: Option, ) -> crate::Result { - if let Some(sql) = filter_sql - && !filter_runs_after_select - { + if let Some(sql) = filter_before_select { let expr = df.parse_sql_expr(sql)?; df = df.filter(expr)?; } @@ -303,9 +301,7 @@ pub(super) async fn finalize_dataframe_source( { df = apply_select_spec_to_dataframe(df, spec)?; } - if let Some(sql) = filter_sql - && filter_runs_after_select - { + if let Some(sql) = filter_after_select { let expr = df.parse_sql_expr(sql)?; df = df.filter(expr)?; } diff --git a/src/pipeline/spec.rs b/src/pipeline/spec.rs index 609c274..b671353 100644 --- a/src/pipeline/spec.rs +++ b/src/pipeline/spec.rs @@ -16,6 +16,34 @@ pub(crate) enum DisplaySlice { Sample(usize), } +/// SQL predicate string for DataFusion `parse_sql_expr` + `filter` (newtype over `String`). +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct FilterSpec(String); + +impl FilterSpec { + pub fn new(s: impl Into) -> Self { + Self(s.into()) + } + + pub fn as_str(&self) -> &str { + self.0.as_str() + } +} + +impl std::ops::Deref for FilterSpec { + type Target = str; + + fn deref(&self) -> &str { + self.as_str() + } +} + +impl AsRef for FilterSpec { + fn as_ref(&self) -> &str { + self.as_str() + } +} + /// How to match a column name: exact (case-sensitive) or case-insensitive. #[derive(Clone, Debug, PartialEq)] pub enum ColumnSpec { diff --git a/src/pipeline/tests.rs b/src/pipeline/tests.rs index 84eb497..47dcbaf 100644 --- a/src/pipeline/tests.rs +++ b/src/pipeline/tests.rs @@ -8,6 +8,7 @@ use crate::Error; use crate::FileType; use crate::pipeline::ColumnSpec; use crate::pipeline::DataframeParquetReader; +use crate::pipeline::FilterSpec; use crate::pipeline::SelectItem; use crate::pipeline::SelectSpec; use crate::pipeline::avro::DataframeAvroWriter; @@ -173,8 +174,7 @@ fn test_pipeline_builder_filter_sql_rejected_for_orc() { builder .read("fixtures/userdata.orc") .select(&["col"]) - .filter_sql("col IS NOT NULL") - .filter_runs_after_select(true) + .filter_after_select("col IS NOT NULL") .head(3); let err = match builder.build() { Err(e) => e, @@ -192,15 +192,14 @@ fn test_pipeline_builder_read_filter_head_sets_filter_sql() { builder .read("fixtures/table.parquet") .select(&["one"]) - .filter_sql("true") - .filter_runs_after_select(true) + .filter_after_select("true") .head(3); let built = builder.build().expect("build display pipeline"); let Pipeline::DataFrame(p) = built else { panic!("expected DataFrame pipeline"); }; - assert_eq!(p.filter_sql.as_deref(), Some("true")); - assert!(p.filter_runs_after_select); + assert_eq!(p.filter_before_select, None); + assert_eq!(p.filter_after_select, Some(FilterSpec::new("true"))); assert_eq!(p.slice, Some(DisplaySlice::Head(3))); } @@ -209,16 +208,66 @@ fn test_pipeline_builder_filter_before_select_sets_placement() { let mut builder = PipelineBuilder::new(); builder .read("fixtures/table.parquet") - .filter_sql("true") - .filter_runs_after_select(false) + .filter_before_select("true") .select(&["one"]) .head(2); let built = builder.build().expect("build display pipeline"); let Pipeline::DataFrame(p) = built else { panic!("expected DataFrame pipeline"); }; - assert_eq!(p.filter_sql.as_deref(), Some("true")); - assert!(!p.filter_runs_after_select); + assert_eq!(p.filter_before_select, Some(FilterSpec::new("true"))); + assert_eq!(p.filter_after_select, None); +} + +#[test] +fn test_pipeline_builder_grouped_select_post_aggregate_filter_sets_flag() { + let mut builder = PipelineBuilder::new(); + builder + .read("fixtures/table.parquet") + .select_spec(SelectSpec { + columns: vec![ + SelectItem::Column(ColumnSpec::CaseInsensitive("two".into())), + SelectItem::Sum(ColumnSpec::CaseInsensitive("three".into())), + ], + group_by: Some(vec![ColumnSpec::CaseInsensitive("two".into())]), + }) + .filter_after_select("sum(three) > 0") + .head(5); + let built = builder.build().expect("build display pipeline"); + let Pipeline::DataFrame(p) = built else { + panic!("expected DataFrame pipeline"); + }; + assert_eq!(p.filter_before_select, None); + assert_eq!( + p.filter_after_select, + Some(FilterSpec::new("sum(three) > 0")) + ); +} + +#[test] +fn test_pipeline_builder_both_filters_before_and_after_select() { + let mut builder = PipelineBuilder::new(); + builder + .read("fixtures/table.parquet") + .filter_before_select("one > 0") + .select_spec(SelectSpec { + columns: vec![ + SelectItem::Column(ColumnSpec::CaseInsensitive("two".into())), + SelectItem::Sum(ColumnSpec::CaseInsensitive("three".into())), + ], + group_by: Some(vec![ColumnSpec::CaseInsensitive("two".into())]), + }) + .filter_after_select("sum(three) > 0") + .head(5); + let built = builder.build().expect("build display pipeline"); + let Pipeline::DataFrame(p) = built else { + panic!("expected DataFrame pipeline"); + }; + assert_eq!(p.filter_before_select, Some(FilterSpec::new("one > 0"))); + assert_eq!( + p.filter_after_select, + Some(FilterSpec::new("sum(three) > 0")) + ); } #[test]