25 changes: 24 additions & 1 deletion docs/REPL.md
@@ -111,7 +111,7 @@ For the following functions, note that the function signatures and types provide

### Pipeline shape

A REPL pipeline must start with `read(...)`. You may use at most one `group_by(...)` and one `select(...)` per pipeline (in either order). Further stages—`head`, `tail`, `sample`, `schema`, `count`, or `write`—are added when needed (for example, `read("x.parquet") |> head(5)` skips `select` entirely). You cannot repeat `select` or `group_by` in the same pipeline. If `group_by(...)` appears, a matching `select(...)` is required.
A REPL pipeline must start with `read(...)`. You may use at most one `group_by(...)`, at most one `select(...)`, and at most **two** `filter("...")` stages; two filters are allowed only when one appears **before** `select(...)` and one **after**, so you can combine WHERE-like and HAVING-like predicates. These stages may otherwise appear in **any order**, subject to `group_by(...)` requiring a matching `select(...)`. Further stages—`head`, `tail`, `sample`, `schema`, `count`, or `write`—are added when needed (for example, `read("x.parquet") |> head(5)` skips `select` entirely). You cannot repeat `select` or `group_by`.
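
For instance, drawing on the examples later in this document, these are all valid shapes:

```flt
read("x.parquet") |> head(5)
read("x.parquet") |> select(:amount, :status) |> head(10)
read("x.parquet") |> filter("amount > 0") |> group_by(:country) |> select(:country, sum(:amount)) |> filter("sum(amount) > 100") |> head(10)
```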

### `read`

@@ -131,6 +131,29 @@ Reads a Parquet, Avro, ORC, CSV, or JSON file at the given `path`. If `file_type`
| `.csv` | CSV |
| `.json` | JSON |

### `filter`

```flt
filter(data: Data, sql: String) -> Data
```

`filter` takes a single string containing a SQL predicate fragment, parsed with Apache DataFusion. Its **placement relative to `select`** in the pipeline determines whether it filters the input rows or the result of the `select` step (which, when `group_by` is used, performs projection and aggregation in one logical step—similar to SQL **WHERE** vs **HAVING**):

- **`filter` before `select`** (when `select` is present): predicate on **source** columns, evaluated on each input row **before** projection or aggregation (WHERE-like when `group_by` is used).
- **`filter` after `select` without `group_by`**: predicate on **projected** columns only.
- **`filter` after `select` with `group_by`**: predicate on the **grouped/aggregated** result; use the output column names DataFusion produces for aggregates (commonly `sum(column_name)`, `avg(column_name)`, etc., matching the source column name).
- **Two `filter` stages**: the first (the one before `select` in pipeline order) runs on **input rows**; the second (after `select`) runs on the **result**—together they are analogous to **WHERE** then **HAVING** when `group_by` and aggregates are used.

```flt
read("input.parquet") |> filter("amount > 0") |> select(:amount, :status) |> head(10)
read("input.parquet") |> select(:amount, :status) |> filter("amount > 0 AND status = 'active'") |> head(10)
read("input.parquet") |> filter("amount > 0") |> group_by(:country) |> select(:country, sum(:amount)) |> head(10)
read("input.parquet") |> group_by(:country) |> select(:country, sum(:amount)) |> filter("sum(amount) > 100") |> head(10)
read("input.parquet") |> filter("status = 'active'") |> group_by(:country) |> select(:country, sum(:amount)) |> filter("sum(amount) > 100") |> head(10)
```

`filter` is only supported for inputs read through DataFusion (Parquet, Avro, CSV, JSON). It is **not** supported for ORC files in the REPL; convert the file to another format first.
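
Conversely, two `filter` stages on the same side of `select` are rejected by the validator, since one must be WHERE-like and one HAVING-like. A sketch of an invalid pipeline (hypothetical columns):

```flt
read("input.parquet") |> filter("amount > 0") |> filter("status = 'active'") |> select(:amount, :status) |> head(10)
```

Both predicates here precede `select`; combine them into a single `filter` with `AND`, or move one after `select(...)`.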

### `write`

```flt
53 changes: 51 additions & 2 deletions src/cli/repl/builder_bridge.rs
@@ -28,20 +28,47 @@ pub(crate) fn repl_stages_to_pipeline_builder(
builder.read(path);

let mut i = 1usize;
let mut select_idx: Option<usize> = None;
let mut group_keys: Option<Vec<ColumnSpec>> = None;
let mut select_columns: Option<Vec<SelectItem>> = None;
let mut filters: Vec<(usize, String)> = Vec::new();

for _ in 0..2 {
while i < body.len() {
match body.get(i) {
Some(ReplPipelineStage::Filter { sql }) => {
filters.push((i, sql.clone()));
i += 1;
}
Some(ReplPipelineStage::GroupBy { columns }) => {
group_keys = Some(columns.clone());
i += 1;
}
Some(ReplPipelineStage::Select { columns }) => {
if select_idx.is_none() {
select_idx = Some(i);
}
select_columns = Some(columns.clone());
i += 1;
}
_ => break,
Some(
ReplPipelineStage::Head { .. }
| ReplPipelineStage::Tail { .. }
| ReplPipelineStage::Sample { .. }
| ReplPipelineStage::Schema
| ReplPipelineStage::Count
| ReplPipelineStage::Write { .. },
) => break,
Some(ReplPipelineStage::Read { .. }) => {
return Err(crate::Error::InvalidReplPipeline(
"unexpected read(path) after start of pipeline".to_string(),
));
}
Some(ReplPipelineStage::Print) => {
return Err(crate::Error::InvalidReplPipeline(
"unexpected print() in pipeline body".to_string(),
));
}
None => break,
}
}

@@ -53,6 +80,28 @@ pub(crate) fn repl_stages_to_pipeline_builder(
builder.select_spec(spec);
}

match filters.len() {
0 => {}
1 => {
let (f, sql) = &filters[0];
if select_idx.is_some_and(|s| *f > s) {
builder.filter_after_select(sql);
} else {
builder.filter_before_select(sql);
}
}
2 => {
filters.sort_by_key(|(idx, _)| *idx);
builder.filter_before_select(&filters[0].1);
builder.filter_after_select(&filters[1].1);
}
_ => {
return Err(crate::Error::InvalidReplPipeline(
"at most two filter(...) stages are allowed in a pipeline".to_string(),
));
}
}

match body.get(i) {
Some(ReplPipelineStage::Head { n }) => {
builder.head(*n);
73 changes: 67 additions & 6 deletions src/cli/repl/plan.rs
@@ -85,6 +85,14 @@ pub(super) fn extract_path_from_args(func_name: &str, args: &[Expr]) -> crate::R
}
}

/// Extracts a single SQL predicate string (same shape as [`extract_path_from_args`]).
pub(super) fn extract_sql_predicate_from_args(
func_name: &str,
args: &[Expr],
) -> crate::Result<String> {
extract_path_from_args(func_name, args)
}

fn extract_one_column_spec(expr: &Expr) -> crate::Result<ColumnSpec> {
match expr {
Expr::Literal(Literal::Symbol(s)) => Ok(ColumnSpec::CaseInsensitive(s.clone())),
@@ -168,6 +176,10 @@ pub(super) fn plan_stage(expr: Expr) -> crate::Result<ReplPipelineStage> {
let path = extract_path_from_args("read", &args)?;
Ok(ReplPipelineStage::Read { path })
}
"filter" => {
let sql = extract_sql_predicate_from_args("filter", &args)?;
Ok(ReplPipelineStage::Filter { sql })
}
"group_by" => {
if args.is_empty() {
return Err(Error::UnsupportedFunctionCall(
@@ -294,8 +306,8 @@ fn validate_grouped_select(keys: &[ColumnSpec], items: &[SelectItem]) -> crate::
Ok(())
}

/// Validates that stages match `read` → optional `group_by` / `select` (either order) → optional slice or `schema`/`count` → optional `write`,
/// with optional trailing `print` only after head/tail/sample.
/// Validates that stages match `read` → optional `filter`/`group_by`/`select` stages (at most two `filter`s, at most one each of `group_by` and `select`, in any order) → optional slice or `schema`/`count` → optional `write`,
/// with optional trailing `print` only after head/tail/sample. A single `filter` applies before or after `select` according to stage order; two `filter` stages require `select(...)` strictly between them (WHERE-like, then HAVING-like when aggregating).
pub(super) fn validate_repl_pipeline_stages(stages: &[ReplPipelineStage]) -> crate::Result<()> {
if stages.is_empty() {
return Err(Error::InvalidReplPipeline("empty pipeline".to_string()));
@@ -313,11 +325,22 @@ pub(super) fn validate_repl_pipeline_stages(stages: &[ReplPipelineStage]) -> cra
}

let mut i = 1usize;
let mut filter_indices: Vec<usize> = Vec::new();
let mut select_idx: Option<usize> = None;
let mut group_by_cols: Option<&Vec<ColumnSpec>> = None;
let mut select_items: Option<&Vec<SelectItem>> = None;

for _ in 0..2 {
while i < body.len() {
match body.get(i) {
Some(ReplPipelineStage::Filter { .. }) => {
if filter_indices.len() >= 2 {
return Err(Error::InvalidReplPipeline(
"at most two filter(...) stages are allowed in a pipeline".to_string(),
));
}
filter_indices.push(i);
i += 1;
}
Some(ReplPipelineStage::GroupBy { columns }) => {
if group_by_cols.is_some() {
return Err(Error::InvalidReplPipeline(
@@ -333,10 +356,31 @@ pub(super) fn validate_repl_pipeline_stages(stages: &[ReplPipelineStage]) -> cra
"only one select(...) is allowed in a pipeline".to_string(),
));
}
if select_idx.is_none() {
select_idx = Some(i);
}
select_items = Some(columns);
i += 1;
}
_ => break,
Some(
ReplPipelineStage::Head { .. }
| ReplPipelineStage::Tail { .. }
| ReplPipelineStage::Sample { .. }
| ReplPipelineStage::Schema
| ReplPipelineStage::Count
| ReplPipelineStage::Write { .. },
) => break,
Some(ReplPipelineStage::Read { .. }) => {
return Err(Error::InvalidReplPipeline(
"unexpected read(path) after start of pipeline".to_string(),
));
}
Some(ReplPipelineStage::Print) => {
return Err(Error::InvalidReplPipeline(
"unexpected print() in pipeline body".to_string(),
));
}
None => break,
}
}

@@ -346,6 +390,23 @@ pub(super) fn validate_repl_pipeline_stages(stages: &[ReplPipelineStage]) -> cra
));
}

if filter_indices.len() == 2 {
let Some(si) = select_idx else {
return Err(Error::InvalidReplPipeline(
"two filter(...) stages require select(...) between them (one before and one after select(...))"
.to_string(),
));
};
let f0 = filter_indices[0].min(filter_indices[1]);
let f1 = filter_indices[0].max(filter_indices[1]);
if !(f0 < si && si < f1) {
return Err(Error::InvalidReplPipeline(
"two filter(...) stages must have select(...) strictly between them (one before select, one after)"
.to_string(),
));
}
}

if let Some(keys) = group_by_cols {
let items = select_items.expect("select when group_by");
validate_grouped_select(keys, items)?;
@@ -402,14 +463,14 @@ pub(super) fn validate_repl_pipeline_stages(stages: &[ReplPipelineStage]) -> cra
}
_ => {
return Err(Error::InvalidReplPipeline(
"invalid pipeline stage order (expected read, optional group_by and select, then head|tail|sample|schema|count, or write)".to_string(),
"invalid pipeline stage order (expected read, optional filter/group_by/select in any order, then head|tail|sample|schema|count, or write)".to_string(),
));
}
}

if i != body.len() {
return Err(Error::InvalidReplPipeline(
"invalid pipeline stage order (expected read, optional group_by/select, optional head|tail|sample|schema|count, optional write)".to_string(),
"invalid pipeline stage order (expected read, optional filter/group_by/select, optional head|tail|sample|schema|count, optional write)".to_string(),
));
}

35 changes: 27 additions & 8 deletions src/cli/repl/stage.rs
@@ -9,15 +9,33 @@ use crate::pipeline::SelectSpec;
/// A planned pipeline stage with validated, extracted arguments.
#[derive(Debug, PartialEq)]
pub enum ReplPipelineStage {
Read { path: String },
GroupBy { columns: Vec<ColumnSpec> },
Select { columns: Vec<SelectItem> },
Head { n: usize },
Tail { n: usize },
Sample { n: usize },
Read {
path: String,
},
/// SQL predicate (parsed via `parse_sql_expr`, applied via `filter`); runs before or after the `select` stage according to pipeline order (a post-`select` filter also runs post-aggregation when `group_by` is used).
Filter {
sql: String,
},
GroupBy {
columns: Vec<ColumnSpec>,
},
Select {
columns: Vec<SelectItem>,
},
Head {
n: usize,
},
Tail {
n: usize,
},
Sample {
n: usize,
},
Schema,
Count,
Write { path: String },
Write {
path: String,
},
Print,
}

@@ -68,7 +86,7 @@ impl ReplPipelineStage {
// Single-stage check: full pipeline uses `repl_pipeline_last_select_is_terminal`.
select_spec_from_items(columns).is_aggregate_only()
}
ReplPipelineStage::GroupBy { .. } => false,
ReplPipelineStage::GroupBy { .. } | ReplPipelineStage::Filter { .. } => false,
ReplPipelineStage::Read { .. } | ReplPipelineStage::Print => false,
}
}
@@ -100,6 +118,7 @@ impl fmt::Display for ReplPipelineStage {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ReplPipelineStage::Read { path } => write!(f, r#"read("{path}")"#),
ReplPipelineStage::Filter { sql } => write!(f, "filter({sql:?})"),
ReplPipelineStage::GroupBy { columns } => {
let cols: Vec<String> = columns.iter().map(format_column_spec).collect();
write!(f, "group_by({})", cols.join(", "))