diff --git a/docs/REPL.md b/docs/REPL.md index 9248d35..0d7b07c 100644 --- a/docs/REPL.md +++ b/docs/REPL.md @@ -111,7 +111,7 @@ For the following functions, note that the function signatures and types provide ### Pipeline shape -A REPL pipeline must start with `read(...)`. You may use at most one `group_by(...)` and one `select(...)` per pipeline (in either order). Further stages—`head`, `tail`, `sample`, `schema`, `count`, or `write`—are added when needed (for example, `read("x.parquet") |> head(5)` skips `select` entirely). You cannot repeat `select` or `group_by` in the same pipeline. If `group_by(...)` appears, a matching `select(...)` is required. +A REPL pipeline must start with `read(...)`. You may use at most **two** `filter("...")` stages (only when one appears **before** `select(...)` and one **after**, so you can combine WHERE-like and HAVING-like predicates), at most one `group_by(...)`, and one `select(...)`, in **any order** among those stages (subject to `group_by(...)` requiring a matching `select(...)`). Further stages—`head`, `tail`, `sample`, `schema`, `count`, or `write`—are added when needed (for example, `read("x.parquet") |> head(5)` skips `select` entirely). You cannot repeat `select` or `group_by`. ### `read` @@ -131,6 +131,29 @@ Reads a Parquet, Avro, ORC, CSV, or JSON file at the given `path`. If `file_type | `.csv` | CSV | | `.json` | JSON | +### `filter` + +```flt +filter(data: Data, sql: String) -> Data +``` + +`filter` takes a single string that is a SQL predicate fragment. It is parsed with Apache DataFusion. **Placement relative to `select` in the pipeline** fixes whether you filter input rows or the result of the `select` step (which, when `group_by` is used, includes aggregation in one logical step—similar to SQL **WHERE** vs **HAVING**): + +- **`filter` before `select`** (when `select` is present): predicate on **source** columns, evaluated on each input row **before** projection or aggregation (WHERE-like when `group_by` is used). +- **`filter` after `select` without `group_by`**: predicate on **projected** columns only. +- **`filter` after `select` with `group_by`**: predicate on the **grouped/aggregated** result; use the output column names DataFusion produces for aggregates (commonly `sum(column_name)`, `avg(column_name)`, etc., matching the source column name). +- **Two `filter` stages**: the first (by pipeline order before `select`) runs on **input rows**; the second (after `select`) runs on the **result**—together, analogous to **WHERE** then **HAVING** when `group_by` and aggregates are used. + +```flt +read("input.parquet") |> filter("amount > 0") |> select(:amount, :status) |> head(10) +read("input.parquet") |> select(:amount, :status) |> filter("amount > 0 AND status = 'active'") |> head(10) +read("input.parquet") |> filter("amount > 0") |> group_by(:country) |> select(:country, sum(:amount)) |> head(10) +read("input.parquet") |> group_by(:country) |> select(:country, sum(:amount)) |> filter("sum(amount) > 100") |> head(10) +read("input.parquet") |> filter("status = 'active'") |> group_by(:country) |> select(:country, sum(:amount)) |> filter("sum(amount) > 100") |> head(10) +``` + +`filter` is only supported for inputs read through DataFusion (Parquet, Avro, CSV, JSON). It is **not** supported for ORC files in the REPL; convert or use another format first. + ### `write` ```flt diff --git a/src/cli/repl/builder_bridge.rs b/src/cli/repl/builder_bridge.rs index 21d99ed..26af303 100644 --- a/src/cli/repl/builder_bridge.rs +++ b/src/cli/repl/builder_bridge.rs @@ -28,20 +28,47 @@ pub(crate) fn repl_stages_to_pipeline_builder( builder.read(path); let mut i = 1usize; + let mut select_idx: Option = None; let mut group_keys: Option> = None; let mut select_columns: Option> = None; + let mut filters: Vec<(usize, String)> = Vec::new(); - for _ in 0..2 { + while i < body.len() { match body.get(i) { + Some(ReplPipelineStage::Filter { sql }) => { + filters.push((i, sql.clone())); + i += 1; + } Some(ReplPipelineStage::GroupBy { columns }) => { group_keys = Some(columns.clone()); i += 1; } Some(ReplPipelineStage::Select { columns }) => { + if select_idx.is_none() { + select_idx = Some(i); + } select_columns = Some(columns.clone()); i += 1; } - _ => break, + Some( + ReplPipelineStage::Head { .. } + | ReplPipelineStage::Tail { .. } + | ReplPipelineStage::Sample { .. } + | ReplPipelineStage::Schema + | ReplPipelineStage::Count + | ReplPipelineStage::Write { .. }, + ) => break, + Some(ReplPipelineStage::Read { .. }) => { + return Err(crate::Error::InvalidReplPipeline( + "unexpected read(path) after start of pipeline".to_string(), + )); + } + Some(ReplPipelineStage::Print) => { + return Err(crate::Error::InvalidReplPipeline( + "unexpected print() in pipeline body".to_string(), + )); + } + None => break, } } @@ -53,6 +80,28 @@ pub(crate) fn repl_stages_to_pipeline_builder( builder.select_spec(spec); } + match filters.len() { + 0 => {} + 1 => { + let (f, sql) = &filters[0]; + if select_idx.is_some_and(|s| *f > s) { + builder.filter_after_select(sql); + } else { + builder.filter_before_select(sql); + } + } + 2 => { + filters.sort_by_key(|(idx, _)| *idx); + builder.filter_before_select(&filters[0].1); + builder.filter_after_select(&filters[1].1); + } + _ => { + return Err(crate::Error::InvalidReplPipeline( + "at most two filter(...) stages are allowed in a pipeline".to_string(), + )); + } + } + match body.get(i) { Some(ReplPipelineStage::Head { n }) => { builder.head(*n); diff --git a/src/cli/repl/plan.rs b/src/cli/repl/plan.rs index d012656..e4aa654 100644 --- a/src/cli/repl/plan.rs +++ b/src/cli/repl/plan.rs @@ -85,6 +85,14 @@ pub(super) fn extract_path_from_args(func_name: &str, args: &[Expr]) -> crate::R } } +/// Extracts a single SQL predicate string (same shape as [`extract_path_from_args`]). +pub(super) fn extract_sql_predicate_from_args( + func_name: &str, + args: &[Expr], +) -> crate::Result { + extract_path_from_args(func_name, args) +} + fn extract_one_column_spec(expr: &Expr) -> crate::Result { match expr { Expr::Literal(Literal::Symbol(s)) => Ok(ColumnSpec::CaseInsensitive(s.clone())), @@ -168,6 +176,10 @@ pub(super) fn plan_stage(expr: Expr) -> crate::Result { let path = extract_path_from_args("read", &args)?; Ok(ReplPipelineStage::Read { path }) } + "filter" => { + let sql = extract_sql_predicate_from_args("filter", &args)?; + Ok(ReplPipelineStage::Filter { sql }) + } "group_by" => { if args.is_empty() { return Err(Error::UnsupportedFunctionCall( @@ -294,8 +306,8 @@ fn validate_grouped_select(keys: &[ColumnSpec], items: &[SelectItem]) -> crate:: Ok(()) } -/// Validates that stages match `read` → optional `group_by` / `select` (either order) → optional slice or `schema`/`count` → optional `write`, -/// with optional trailing `print` only after head/tail/sample. +/// Validates that stages match `read` → optional permuted `filter` (at most two, straddling `select` if two) / `group_by` / `select` (at most one each) → optional slice or `schema`/`count` → optional `write`, +/// with optional trailing `print` only after head/tail/sample. A single `filter` maps to before or after `select` by stage order; two `filter` stages require `select(...)` strictly between them (WHERE-like then HAVING-like when aggregating). pub(super) fn validate_repl_pipeline_stages(stages: &[ReplPipelineStage]) -> crate::Result<()> { if stages.is_empty() { return Err(Error::InvalidReplPipeline("empty pipeline".to_string())); @@ -313,11 +325,22 @@ pub(super) fn validate_repl_pipeline_stages(stages: &[ReplPipelineStage]) -> cra } let mut i = 1usize; + let mut filter_indices: Vec = Vec::new(); + let mut select_idx: Option = None; let mut group_by_cols: Option<&Vec> = None; let mut select_items: Option<&Vec> = None; - for _ in 0..2 { + while i < body.len() { match body.get(i) { + Some(ReplPipelineStage::Filter { .. }) => { + if filter_indices.len() >= 2 { + return Err(Error::InvalidReplPipeline( + "at most two filter(...) stages are allowed in a pipeline".to_string(), + )); + } + filter_indices.push(i); + i += 1; + } Some(ReplPipelineStage::GroupBy { columns }) => { if group_by_cols.is_some() { return Err(Error::InvalidReplPipeline( @@ -333,10 +356,31 @@ pub(super) fn validate_repl_pipeline_stages(stages: &[ReplPipelineStage]) -> cra "only one select(...) is allowed in a pipeline".to_string(), )); } + if select_idx.is_none() { + select_idx = Some(i); + } select_items = Some(columns); i += 1; } - _ => break, + Some( + ReplPipelineStage::Head { .. } + | ReplPipelineStage::Tail { .. } + | ReplPipelineStage::Sample { .. } + | ReplPipelineStage::Schema + | ReplPipelineStage::Count + | ReplPipelineStage::Write { .. }, + ) => break, + Some(ReplPipelineStage::Read { .. }) => { + return Err(Error::InvalidReplPipeline( + "unexpected read(path) after start of pipeline".to_string(), + )); + } + Some(ReplPipelineStage::Print) => { + return Err(Error::InvalidReplPipeline( + "unexpected print() in pipeline body".to_string(), + )); + } + None => break, } } @@ -346,6 +390,23 @@ pub(super) fn validate_repl_pipeline_stages(stages: &[ReplPipelineStage]) -> cra )); } + if filter_indices.len() == 2 { + let Some(si) = select_idx else { + return Err(Error::InvalidReplPipeline( + "two filter(...) stages require select(...) between them (one before and one after select(...))" + .to_string(), + )); + }; + let f0 = filter_indices[0].min(filter_indices[1]); + let f1 = filter_indices[0].max(filter_indices[1]); + if !(f0 < si && si < f1) { + return Err(Error::InvalidReplPipeline( + "two filter(...) stages must have select(...) strictly between them (one before select, one after)" + .to_string(), + )); + } + } + if let Some(keys) = group_by_cols { let items = select_items.expect("select when group_by"); validate_grouped_select(keys, items)?; @@ -402,14 +463,14 @@ pub(super) fn validate_repl_pipeline_stages(stages: &[ReplPipelineStage]) -> cra } _ => { return Err(Error::InvalidReplPipeline( - "invalid pipeline stage order (expected read, optional group_by and select, then head|tail|sample|schema|count, or write)".to_string(), + "invalid pipeline stage order (expected read, optional filter/group_by/select in any order, then head|tail|sample|schema|count, or write)".to_string(), )); } } if i != body.len() { return Err(Error::InvalidReplPipeline( - "invalid pipeline stage order (expected read, optional group_by/select, optional head|tail|sample|schema|count, optional write)".to_string(), + "invalid pipeline stage order (expected read, optional filter/group_by/select, optional head|tail|sample|schema|count, optional write)".to_string(), )); } diff --git a/src/cli/repl/stage.rs b/src/cli/repl/stage.rs index 783d0b7..86b1cd7 100644 --- a/src/cli/repl/stage.rs +++ b/src/cli/repl/stage.rs @@ -9,15 +9,33 @@ use crate::pipeline::SelectSpec; /// A planned pipeline stage with validated, extracted arguments. #[derive(Debug, PartialEq)] pub enum ReplPipelineStage { - Read { path: String }, - GroupBy { columns: Vec }, - Select { columns: Vec }, - Head { n: usize }, - Tail { n: usize }, - Sample { n: usize }, + Read { + path: String, + }, + /// SQL predicate (`parse_sql_expr` + `filter`); before or after the `select` stage per pipeline order (after includes post-aggregate when `group_by` is used). + Filter { + sql: String, + }, + GroupBy { + columns: Vec, + }, + Select { + columns: Vec, + }, + Head { + n: usize, + }, + Tail { + n: usize, + }, + Sample { + n: usize, + }, Schema, Count, - Write { path: String }, + Write { + path: String, + }, Print, } @@ -68,7 +86,7 @@ impl ReplPipelineStage { // Single-stage check: full pipeline uses `repl_pipeline_last_select_is_terminal`. select_spec_from_items(columns).is_aggregate_only() } - ReplPipelineStage::GroupBy { .. } => false, + ReplPipelineStage::GroupBy { .. } | ReplPipelineStage::Filter { .. } => false, ReplPipelineStage::Read { .. } | ReplPipelineStage::Print => false, } } @@ -100,6 +118,7 @@ impl fmt::Display for ReplPipelineStage { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { ReplPipelineStage::Read { path } => write!(f, r#"read("{path}")"#), + ReplPipelineStage::Filter { sql } => write!(f, "filter({sql:?})"), ReplPipelineStage::GroupBy { columns } => { let cols: Vec = columns.iter().map(format_column_spec).collect(); write!(f, "group_by({})", cols.join(", ")) diff --git a/src/cli/repl/tests.rs b/src/cli/repl/tests.rs index d2d8219..7189582 100644 --- a/src/cli/repl/tests.rs +++ b/src/cli/repl/tests.rs @@ -22,6 +22,8 @@ use super::plan::validate_repl_pipeline_stages; use super::stage::ReplPipelineStage; use crate::Error; use crate::pipeline::DataFramePipeline; +use crate::pipeline::DisplaySlice; +use crate::pipeline::FilterSpec; use crate::pipeline::Pipeline; use crate::pipeline::SelectSpec; @@ -64,6 +66,18 @@ fn test_plan_stage_read() { ); } +#[test] +fn test_plan_stage_filter() { + let expr = parse(r#"filter("a > 1")"#); + let stage = plan_stage(expr).unwrap(); + assert_eq!( + stage, + ReplPipelineStage::Filter { + sql: "a > 1".to_string(), + } + ); +} + #[test] fn test_plan_stage_select() { let expr = Expr::FunctionCall( @@ -526,6 +540,230 @@ fn test_extract_path_from_args_write_bad_args() { // ── validate_repl_pipeline_stages / plan not implemented ────── +#[test] +fn test_validate_rejects_three_filters() { + let stages = vec![ + ReplPipelineStage::Read { + path: "a.parquet".into(), + }, + ReplPipelineStage::Filter { sql: "true".into() }, + ReplPipelineStage::Select { + columns: vec![SelectItem::Column(ColumnSpec::CaseInsensitive("x".into()))], + }, + ReplPipelineStage::Filter { + sql: "x > 0".into(), + }, + ReplPipelineStage::Filter { + sql: "y < 1".into(), + }, + ReplPipelineStage::Head { n: 1 }, + ]; + let err = validate_repl_pipeline_stages(&stages).unwrap_err(); + assert!(matches!(err, Error::InvalidReplPipeline(msg) if msg.contains("at most two filter"))); +} + +#[test] +fn test_validate_rejects_two_filters_both_after_select() { + let stages = vec![ + ReplPipelineStage::Read { + path: "a.parquet".into(), + }, + ReplPipelineStage::Select { + columns: vec![SelectItem::Column(ColumnSpec::CaseInsensitive("x".into()))], + }, + ReplPipelineStage::Filter { + sql: "x > 0".into(), + }, + ReplPipelineStage::Filter { + sql: "x < 10".into(), + }, + ReplPipelineStage::Head { n: 1 }, + ]; + let err = validate_repl_pipeline_stages(&stages).unwrap_err(); + assert!(matches!(err, Error::InvalidReplPipeline(msg) if msg.contains("strictly between"))); +} + +#[test] +fn test_validate_accepts_two_filters_straddling_select() { + let stages = vec![ + ReplPipelineStage::Read { + path: "fixtures/table.parquet".into(), + }, + ReplPipelineStage::Filter { + sql: "one > 0".into(), + }, + ReplPipelineStage::Select { + columns: vec![SelectItem::Column(ColumnSpec::CaseInsensitive( + "one".into(), + ))], + }, + ReplPipelineStage::Filter { + sql: "one < 1000".into(), + }, + ReplPipelineStage::Head { n: 5 }, + ]; + validate_repl_pipeline_stages(&stages).unwrap(); +} + +#[test] +fn test_validate_accepts_filter_without_select() { + let stages = vec![ + ReplPipelineStage::Read { + path: "a.parquet".into(), + }, + ReplPipelineStage::Filter { + sql: "id > 0".into(), + }, + ReplPipelineStage::Head { n: 5 }, + ]; + validate_repl_pipeline_stages(&stages).unwrap(); +} + +#[test] +fn test_validate_accepts_select_filter_head() { + let stages = vec![ + ReplPipelineStage::Read { + path: "a.parquet".into(), + }, + ReplPipelineStage::Select { + columns: vec![SelectItem::Column(ColumnSpec::CaseInsensitive( + "one".into(), + ))], + }, + ReplPipelineStage::Filter { sql: "true".into() }, + ReplPipelineStage::Head { n: 5 }, + ]; + validate_repl_pipeline_stages(&stages).unwrap(); +} + +#[test] +fn test_validate_accepts_filter_after_group_by_select() { + let stages = vec![ + ReplPipelineStage::Read { + path: "fixtures/table.parquet".into(), + }, + ReplPipelineStage::GroupBy { + columns: vec![ColumnSpec::CaseInsensitive("two".into())], + }, + ReplPipelineStage::Select { + columns: vec![ + SelectItem::Column(ColumnSpec::CaseInsensitive("two".into())), + SelectItem::Sum(ColumnSpec::CaseInsensitive("three".into())), + ], + }, + ReplPipelineStage::Filter { sql: "true".into() }, + ReplPipelineStage::Head { n: 3 }, + ]; + validate_repl_pipeline_stages(&stages).unwrap(); +} + +#[test] +fn test_validate_accepts_filter_group_by_select() { + let stages = vec![ + ReplPipelineStage::Read { + path: "fixtures/table.parquet".into(), + }, + ReplPipelineStage::Filter { sql: "true".into() }, + ReplPipelineStage::GroupBy { + columns: vec![ColumnSpec::CaseInsensitive("two".into())], + }, + ReplPipelineStage::Select { + columns: vec![ + SelectItem::Column(ColumnSpec::CaseInsensitive("two".into())), + SelectItem::Sum(ColumnSpec::CaseInsensitive("three".into())), + ], + }, + ReplPipelineStage::Head { n: 3 }, + ]; + validate_repl_pipeline_stages(&stages).unwrap(); +} + +#[test] +fn test_builder_bridge_post_aggregate_filter_runs_after_select() { + let stages = vec![ + ReplPipelineStage::Read { + path: "fixtures/table.parquet".into(), + }, + ReplPipelineStage::GroupBy { + columns: vec![ColumnSpec::CaseInsensitive("two".into())], + }, + ReplPipelineStage::Select { + columns: vec![ + SelectItem::Column(ColumnSpec::CaseInsensitive("two".into())), + SelectItem::Sum(ColumnSpec::CaseInsensitive("three".into())), + ], + }, + ReplPipelineStage::Filter { + sql: "sum(three) > 0".into(), + }, + ReplPipelineStage::Head { n: 3 }, + ]; + let mut builder = repl_stages_to_pipeline_builder(&stages).unwrap(); + let Pipeline::DataFrame(p) = builder.build().unwrap() else { + panic!("expected DataFrame pipeline"); + }; + assert_eq!(p.filter_before_select, None); + assert_eq!( + p.filter_after_select, + Some(FilterSpec::new("sum(three) > 0")) + ); +} + +#[test] +fn test_builder_bridge_where_and_having_filters() { + let stages = vec![ + ReplPipelineStage::Read { + path: "fixtures/table.parquet".into(), + }, + ReplPipelineStage::Filter { sql: "true".into() }, + ReplPipelineStage::GroupBy { + columns: vec![ColumnSpec::CaseInsensitive("two".into())], + }, + ReplPipelineStage::Select { + columns: vec![ + SelectItem::Column(ColumnSpec::CaseInsensitive("two".into())), + SelectItem::Sum(ColumnSpec::CaseInsensitive("three".into())), + ], + }, + ReplPipelineStage::Filter { + sql: "sum(three) > 0".into(), + }, + ReplPipelineStage::Head { n: 3 }, + ]; + let mut builder = repl_stages_to_pipeline_builder(&stages).unwrap(); + let Pipeline::DataFrame(p) = builder.build().unwrap() else { + panic!("expected DataFrame pipeline"); + }; + assert_eq!(p.filter_before_select, Some(FilterSpec::new("true"))); + assert_eq!( + p.filter_after_select, + Some(FilterSpec::new("sum(three) > 0")) + ); +} + +#[test] +fn test_builder_bridge_sets_filter_sql() { + let stages = vec![ + ReplPipelineStage::Read { + path: "fixtures/table.parquet".into(), + }, + ReplPipelineStage::Select { + columns: vec![SelectItem::Column(ColumnSpec::CaseInsensitive( + "one".into(), + ))], + }, + ReplPipelineStage::Filter { sql: "true".into() }, + ReplPipelineStage::Head { n: 2 }, + ]; + let mut builder = repl_stages_to_pipeline_builder(&stages).unwrap(); + let Pipeline::DataFrame(p) = builder.build().unwrap() else { + panic!("expected DataFrame pipeline"); + }; + assert_eq!(p.filter_before_select, None); + assert_eq!(p.filter_after_select, Some(FilterSpec::new("true"))); + assert_eq!(p.slice, Some(DisplaySlice::Head(2))); +} + #[test] fn test_validate_rejects_second_select() { let stages = vec![ diff --git a/src/errors.rs b/src/errors.rs index 826c827..09a3137 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -64,6 +64,8 @@ pub enum PipelinePlanningError { "Aggregates in select are not supported for ORC input; use Parquet, CSV, JSON, or Avro" )] AggregatesNotSupportedForOrc, + #[error("filter() is not supported for ORC input; use Parquet, CSV, JSON, or Avro")] + FilterNotSupportedForOrc, } /// Errors produced while running a pipeline (wrong format, consumed state, etc.). diff --git a/src/pipeline.rs b/src/pipeline.rs index f85597b..d478337 100644 --- a/src/pipeline.rs +++ b/src/pipeline.rs @@ -72,6 +72,7 @@ pub use sampling::sample_from_reader; pub use sampling::tail_batches; pub use spec::ColumnSpec; pub(crate) use spec::DisplaySlice; +pub use spec::FilterSpec; pub use spec::SelectItem; pub use spec::SelectSpec; pub use step::Producer; diff --git a/src/pipeline/builder.rs b/src/pipeline/builder.rs index bb6c39f..80bc54a 100644 --- a/src/pipeline/builder.rs +++ b/src/pipeline/builder.rs @@ -16,6 +16,7 @@ use crate::pipeline::record_batch::RecordBatchPipeline; use crate::pipeline::record_batch::RecordBatchSink; use crate::pipeline::spec::ColumnSpec; use crate::pipeline::spec::DisplaySlice; +use crate::pipeline::spec::FilterSpec; use crate::pipeline::spec::SelectItem; use crate::pipeline::spec::SelectSpec; use crate::resolve_file_type; @@ -23,6 +24,10 @@ use crate::resolve_file_type; /// Fluent builder for a [`Pipeline`] (file conversion or stdout display: head/tail/sample, schema, or row count). pub struct PipelineBuilder { read: Option, + /// REPL: SQL predicate applied before `select` (WHERE-like on input rows when aggregating). + filter_before_select: Option, + /// REPL: SQL predicate applied after `select` (post-projection or HAVING-like after `group_by` aggregates). + filter_after_select: Option, select: Option, head: Option, tail: Option, @@ -44,6 +49,8 @@ impl Default for PipelineBuilder { fn default() -> Self { Self { read: None, + filter_before_select: None, + filter_after_select: None, select: None, head: None, tail: None, @@ -76,6 +83,19 @@ impl PipelineBuilder { self.read = Some(path.to_string()); self } + + /// SQL predicate for DataFusion `parse_sql_expr` + `filter`, applied **before** the pipeline `select` step (input rows / WHERE-like). + pub fn filter_before_select(&mut self, sql: &str) -> &mut Self { + self.filter_before_select = Some(FilterSpec::new(sql)); + self + } + + /// SQL predicate applied **after** the pipeline `select` step (post-projection or HAVING-like when `group_by` + aggregates are present). + pub fn filter_after_select(&mut self, sql: &str) -> &mut Self { + self.filter_after_select = Some(FilterSpec::new(sql)); + self + } + /// Sets column selection as exact name matches. pub fn select(&mut self, columns: &[&str]) -> &mut Self { self.select = Some(SelectSpec { @@ -196,6 +216,11 @@ impl PipelineBuilder { } reject_orc_with_aggregates(input_file_type, &select)?; + reject_orc_with_filters( + input_file_type, + &self.filter_before_select, + &self.filter_after_select, + )?; let slice = slice_from_head_tail_sample(self.head, self.tail, self.sample); Ok(dispatch_pipeline( @@ -205,6 +230,8 @@ impl PipelineBuilder { slice, self.sparse, self.csv_has_header, + self.filter_before_select.clone(), + self.filter_after_select.clone(), UnifiedSink::Write { output_path: output_path.to_string(), output_file_type, @@ -234,6 +261,11 @@ impl PipelineBuilder { let csv_has_header = self.csv_has_header; reject_orc_with_aggregates(input_file_type, &select)?; + reject_orc_with_filters( + input_file_type, + &self.filter_before_select, + &self.filter_after_select, + )?; Ok(dispatch_pipeline( input_path, @@ -242,6 +274,8 @@ impl PipelineBuilder { None, sparse, csv_has_header, + self.filter_before_select.clone(), + self.filter_after_select.clone(), UnifiedSink::Schema { output_format, sparse, @@ -266,6 +300,11 @@ impl PipelineBuilder { let sparse = self.sparse; reject_orc_with_aggregates(input_file_type, &select)?; + reject_orc_with_filters( + input_file_type, + &self.filter_before_select, + &self.filter_after_select, + )?; Ok(dispatch_pipeline( input_path, @@ -274,6 +313,8 @@ impl PipelineBuilder { None, sparse, csv_has_header, + self.filter_before_select.clone(), + self.filter_after_select.clone(), UnifiedSink::Count, )) } @@ -307,6 +348,11 @@ impl PipelineBuilder { let sparse = self.sparse; reject_orc_with_aggregates(input_file_type, &select)?; + reject_orc_with_filters( + input_file_type, + &self.filter_before_select, + &self.filter_after_select, + )?; Ok(dispatch_pipeline( input_path, @@ -315,6 +361,8 @@ impl PipelineBuilder { Some(slice), sparse, csv_has_header, + self.filter_before_select.clone(), + self.filter_after_select.clone(), UnifiedSink::Display { output_format, csv_stdout_headers, @@ -341,6 +389,11 @@ impl PipelineBuilder { PipelinePlanningError::AggregatesNotSupportedForOrc, )); } + reject_orc_with_filters( + input_file_type, + &self.filter_before_select, + &self.filter_after_select, + )?; let output_format = self .display_output_format .unwrap_or(DisplayOutputFormat::Csv); @@ -352,6 +405,8 @@ impl PipelineBuilder { input_path, input_file_type, select, + filter_before_select: self.filter_before_select.clone(), + filter_after_select: self.filter_after_select.clone(), slice: None, csv_has_header, sparse, @@ -434,6 +489,7 @@ enum UnifiedSink { Count, } +#[allow(clippy::too_many_arguments)] fn dispatch_pipeline( input_path: String, input_file_type: FileType, @@ -441,6 +497,8 @@ fn dispatch_pipeline( slice: Option, sparse: bool, csv_has_header: Option, + filter_before_select: Option, + filter_after_select: Option, sink: UnifiedSink, ) -> Pipeline { if input_file_type == FileType::Orc { @@ -457,6 +515,8 @@ fn dispatch_pipeline( input_path, input_file_type, select, + filter_before_select, + filter_after_select, slice, csv_has_header, sparse, @@ -539,6 +599,21 @@ fn reject_orc_with_aggregates( Ok(()) } +fn reject_orc_with_filters( + input_file_type: FileType, + filter_before_select: &Option, + filter_after_select: &Option, +) -> Result<()> { + if input_file_type == FileType::Orc + && (filter_before_select.is_some() || filter_after_select.is_some()) + { + return Err(Error::PipelinePlanningError( + PipelinePlanningError::FilterNotSupportedForOrc, + )); + } + Ok(()) +} + fn slice_from_head_tail_sample( head: Option, tail: Option, diff --git a/src/pipeline/dataframe/execute.rs b/src/pipeline/dataframe/execute.rs index 8686cba..cda3f71 100644 --- a/src/pipeline/dataframe/execute.rs +++ b/src/pipeline/dataframe/execute.rs @@ -11,6 +11,7 @@ use crate::FileType; use crate::cli::DisplayOutputFormat; use crate::errors::PipelineExecutionError; use crate::pipeline::DisplaySlice; +use crate::pipeline::FilterSpec; use crate::pipeline::ProgressVecRecordBatchReader; use crate::pipeline::SelectSpec; use crate::pipeline::Step; @@ -50,6 +51,10 @@ pub struct DataFramePipeline { pub(crate) input_path: String, pub(crate) input_file_type: FileType, pub(crate) select: Option, + /// SQL predicate before `select` (`parse_sql_expr` + `filter`). + pub(crate) filter_before_select: Option, + /// SQL predicate after `select` (post-projection or post-aggregate). + pub(crate) filter_after_select: Option, pub(crate) slice: Option, pub(crate) csv_has_header: Option, pub(crate) sparse: bool, @@ -57,11 +62,13 @@ pub struct DataFramePipeline { } impl DataFramePipeline { - /// Read, optional column select, optional head/tail/sample, then [`DataFrameSink`]. + /// Read, optional column select, optional SQL filters before/after select, optional head/tail/sample, then [`DataFrameSink`]. pub fn execute(&mut self) -> crate::Result<()> { let input_path = self.input_path.clone(); let input_file_type = self.input_file_type; let select = self.select.clone(); + let filter_before_select = self.filter_before_select.clone(); + let filter_after_select = self.filter_after_select.clone(); let slice = self.slice; let csv_has_header = self.csv_has_header; let sparse = self.sparse; @@ -101,6 +108,8 @@ impl DataFramePipeline { sparse: schema_sparse, } => { let use_file_metadata_schema = select.is_none() + && filter_before_select.is_none() + && filter_after_select.is_none() && matches!( input_file_type, FileType::Parquet | FileType::Avro | FileType::Csv | FileType::Orc @@ -116,6 +125,8 @@ impl DataFramePipeline { input_path.clone(), input_file_type, select, + filter_before_select.clone(), + filter_after_select.clone(), None, csv_has_header, ) @@ -130,13 +141,18 @@ impl DataFramePipeline { Ok::<(), Error>(()) } DataFrameSink::Count => { - let total = if select.is_none() { + let total = if select.is_none() + && filter_before_select.is_none() + && filter_after_select.is_none() + { count_rows(&input_path, input_file_type, csv_has_header).await? } else { let mut source = dataframe_pipeline_prepare_source( input_path.clone(), input_file_type, select, + filter_before_select.clone(), + filter_after_select.clone(), None, csv_has_header, ) @@ -159,6 +175,8 @@ impl DataFramePipeline { input_path, input_file_type, select, + filter_before_select.clone(), + filter_after_select.clone(), slice, csv_has_header, ) @@ -206,6 +224,8 @@ impl DataFramePipeline { input_path, input_file_type, select, + filter_before_select.clone(), + filter_after_select.clone(), slice, csv_has_header, ) @@ -230,11 +250,13 @@ impl DataFramePipeline { } } -/// Read into a [`DataFrameSource`], apply optional column select, then optional head/tail/sample. +/// Read into a [`DataFrameSource`], apply optional SQL filters before/after `select`, then optional head/tail/sample. pub(crate) async fn dataframe_pipeline_prepare_source( input_path: String, input_file_type: FileType, select: Option, + filter_before_select: Option, + filter_after_select: Option, slice: Option, csv_has_header: Option, ) -> crate::Result { @@ -252,6 +274,8 @@ pub(crate) async fn dataframe_pipeline_prepare_source( df, &input_path, input_file_type, + filter_before_select.as_ref().map(FilterSpec::as_str), + filter_after_select.as_ref().map(FilterSpec::as_str), select.as_ref(), None, slice, diff --git a/src/pipeline/dataframe/transform.rs b/src/pipeline/dataframe/transform.rs index 1f65568..2f83229 100644 --- a/src/pipeline/dataframe/transform.rs +++ b/src/pipeline/dataframe/transform.rs @@ -278,23 +278,33 @@ pub(super) fn apply_select_spec_to_dataframe( Ok(df) } -/// Applies optional column selection, SQL-style row limit, and display slice to a loaded [`DataFrame`]. +/// Applies optional SQL filters before and after column selection, then SQL-style row limit and display slice. /// -/// Used by `LegacyDataFrameReader` and `dataframe_pipeline_prepare_source` so read → project → -/// cap/slice stays in one place. +/// `filter_before_select` runs on the raw frame; `filter_after_select` runs after `select` (post-projection or post-`group_by` aggregate when the spec includes grouping). +#[allow(clippy::too_many_arguments)] // Pipeline finalize bundles several optional stages; splitting would not simplify call sites. pub(super) async fn finalize_dataframe_source( mut df: DataFrame, input_path: &str, input_file_type: FileType, + filter_before_select: Option<&str>, + filter_after_select: Option<&str>, select: Option<&SelectSpec>, limit: Option, slice: Option, ) -> crate::Result { + if let Some(sql) = filter_before_select { + let expr = df.parse_sql_expr(sql)?; + df = df.filter(expr)?; + } if let Some(spec) = select && !spec.is_empty() { df = apply_select_spec_to_dataframe(df, spec)?; } + if let Some(sql) = filter_after_select { + let expr = df.parse_sql_expr(sql)?; + df = df.filter(expr)?; + } if let Some(n) = limit { df = dataframe_apply_head(df, n)?; } diff --git a/src/pipeline/spec.rs b/src/pipeline/spec.rs index 609c274..b671353 100644 --- a/src/pipeline/spec.rs +++ b/src/pipeline/spec.rs @@ -16,6 +16,34 @@ pub(crate) enum DisplaySlice { Sample(usize), } +/// SQL predicate string for DataFusion `parse_sql_expr` + `filter` (newtype over `String`). +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct FilterSpec(String); + +impl FilterSpec { + pub fn new(s: impl Into) -> Self { + Self(s.into()) + } + + pub fn as_str(&self) -> &str { + self.0.as_str() + } +} + +impl std::ops::Deref for FilterSpec { + type Target = str; + + fn deref(&self) -> &str { + self.as_str() + } +} + +impl AsRef for FilterSpec { + fn as_ref(&self) -> &str { + self.as_str() + } +} + /// How to match a column name: exact (case-sensitive) or case-insensitive. #[derive(Clone, Debug, PartialEq)] pub enum ColumnSpec { diff --git a/src/pipeline/tests.rs b/src/pipeline/tests.rs index 50ffdb2..47dcbaf 100644 --- a/src/pipeline/tests.rs +++ b/src/pipeline/tests.rs @@ -8,6 +8,7 @@ use crate::Error; use crate::FileType; use crate::pipeline::ColumnSpec; use crate::pipeline::DataframeParquetReader; +use crate::pipeline::FilterSpec; use crate::pipeline::SelectItem; use crate::pipeline::SelectSpec; use crate::pipeline::avro::DataframeAvroWriter; @@ -165,6 +166,110 @@ fn test_pipeline_builder_read_head_display_orc_uses_record_batch_pipeline() { assert!(matches!(p.sink, RecordBatchSink::Display { .. })); } +#[test] +fn test_pipeline_builder_filter_sql_rejected_for_orc() { + use crate::errors::PipelinePlanningError; + + let mut builder = PipelineBuilder::new(); + builder + .read("fixtures/userdata.orc") + .select(&["col"]) + .filter_after_select("col IS NOT NULL") + .head(3); + let err = match builder.build() { + Err(e) => e, + Ok(_) => panic!("ORC + filter should fail at plan time"), + }; + assert!(matches!( + err, + Error::PipelinePlanningError(PipelinePlanningError::FilterNotSupportedForOrc) + )); +} + +#[test] +fn test_pipeline_builder_read_filter_head_sets_filter_sql() { + let mut builder = PipelineBuilder::new(); + builder + .read("fixtures/table.parquet") + .select(&["one"]) + .filter_after_select("true") + .head(3); + let built = builder.build().expect("build display pipeline"); + let Pipeline::DataFrame(p) = built else { + panic!("expected DataFrame pipeline"); + }; + assert_eq!(p.filter_before_select, None); + assert_eq!(p.filter_after_select, Some(FilterSpec::new("true"))); + assert_eq!(p.slice, Some(DisplaySlice::Head(3))); +} + +#[test] +fn test_pipeline_builder_filter_before_select_sets_placement() { + let mut builder = PipelineBuilder::new(); + builder + .read("fixtures/table.parquet") + .filter_before_select("true") + .select(&["one"]) + .head(2); + let built = builder.build().expect("build display pipeline"); + let Pipeline::DataFrame(p) = built else { + panic!("expected DataFrame pipeline"); + }; + assert_eq!(p.filter_before_select, Some(FilterSpec::new("true"))); + assert_eq!(p.filter_after_select, None); +} + +#[test] +fn test_pipeline_builder_grouped_select_post_aggregate_filter_sets_flag() { + let mut builder = PipelineBuilder::new(); + builder + .read("fixtures/table.parquet") + .select_spec(SelectSpec { + columns: vec![ + SelectItem::Column(ColumnSpec::CaseInsensitive("two".into())), + SelectItem::Sum(ColumnSpec::CaseInsensitive("three".into())), + ], + group_by: Some(vec![ColumnSpec::CaseInsensitive("two".into())]), + }) + .filter_after_select("sum(three) > 0") + .head(5); + let built = builder.build().expect("build display pipeline"); + let Pipeline::DataFrame(p) = built else { + panic!("expected DataFrame pipeline"); + }; + assert_eq!(p.filter_before_select, None); + assert_eq!( + p.filter_after_select, + Some(FilterSpec::new("sum(three) > 0")) + ); +} + +#[test] +fn test_pipeline_builder_both_filters_before_and_after_select() { + let mut builder = PipelineBuilder::new(); + builder + .read("fixtures/table.parquet") + .filter_before_select("one > 0") + .select_spec(SelectSpec { + columns: vec![ + SelectItem::Column(ColumnSpec::CaseInsensitive("two".into())), + SelectItem::Sum(ColumnSpec::CaseInsensitive("three".into())), + ], + group_by: Some(vec![ColumnSpec::CaseInsensitive("two".into())]), + }) + .filter_after_select("sum(three) > 0") + .head(5); + let built = builder.build().expect("build display pipeline"); + let Pipeline::DataFrame(p) = built else { + panic!("expected DataFrame pipeline"); + }; + assert_eq!(p.filter_before_select, Some(FilterSpec::new("one > 0"))); + assert_eq!( + p.filter_after_select, + Some(FilterSpec::new("sum(three) > 0")) + ); +} + #[test] fn test_pipeline_builder_read_sample_display_parquet() { let mut builder = PipelineBuilder::new();