diff --git a/README.md b/README.md index 4d30a19..6354c0c 100644 --- a/README.md +++ b/README.md @@ -445,6 +445,13 @@ Columns appear in the output in the order they are listed, so `select` can also > read("data.parquet") |> select(:email, :id) |> write("reordered.csv") ``` +With optional `group_by(...)`, you can use aggregates in `select`: `sum`, `avg`, `min`, `max`, `count` (non-null values in a column), and `count_distinct` (distinct non-null values). A `select` of only aggregates (no `group_by`) summarizes the whole table. + +```text +> read("data.parquet") |> select(count(:id)) +> read("data.parquet") |> group_by(:country) |> select(:country, count_distinct(:user_id)) +``` + #### `head(n)` Take the first _n_ rows. diff --git a/docs/REPL.md b/docs/REPL.md index 68b905a..9248d35 100644 --- a/docs/REPL.md +++ b/docs/REPL.md @@ -168,29 +168,35 @@ If the column name is specified as a `Symbol` (`:name`) or a bare identifier (`n If the column name is specified as a `String` (`"_one"`), then `select()` performs an exact, case-sensitive match. -The same rules apply to column arguments inside `sum(...)`, `avg(...)`, `min(...)`, and `max(...)`. +The same rules apply to column arguments inside `sum(...)`, `avg(...)`, `min(...)`, `max(...)`, `count(...)`, and `count_distinct(...)`. -#### Global aggregates (`sum` / `avg` / `min` / `max` without `group_by`) +#### Global aggregates (`sum` / `avg` / `min` / `max` / `count` / `count_distinct` without `group_by`) -With no `group_by()`, you may use `select()` with only aggregate arguments (`sum(...)`, `avg(...)`, `min(...)`, and/or `max(...)`). That aggregates the whole table (for example, one row of totals/averages/minima/maxima). +With no `group_by()`, you may use `select()` with only aggregate arguments (`sum(...)`, `avg(...)`, `min(...)`, `max(...)`, `count(...)`, and/or `count_distinct(...)`). That aggregates the whole table (for example, one row of totals, averages, extrema, or counts). 
+ +- `count(:col)` counts non-null values in `:col`. +- `count_distinct(:col)` counts distinct non-null values in `:col`. ```flt read("input.parquet") |> select(sum(:quantity)) read("input.parquet") |> select(avg(:amount)) read("input.parquet") |> select(min(:amount), max(:amount)) +read("input.parquet") |> select(count(:id)) +read("input.parquet") |> select(count_distinct(:user_id)) ``` Mixing plain column projections with aggregates in the same `select()` without `group_by()` is not allowed—use `group_by()` for the key columns first, or use only aggregates for a global summary. #### Grouped aggregates (`group_by` + `select`) -`group_by(:key1, ...)` is a separate pipeline step. Every column in `group_by` _MUST_ be included as a plain column in `select()`. Any other selected column must use an aggregate (`sum()`, `avg()`, `min()`, or `max()`). +`group_by(:key1, ...)` is a separate pipeline step. Every column in `group_by` _MUST_ be included as a plain column in `select()`. Any other selected column must use an aggregate (`sum()`, `avg()`, `min()`, `max()`, `count()`, or `count_distinct()`). ```flt read("input.parquet") |> group_by(:country_code) |> select(:country_code, avg(:amount)) +read("input.parquet") |> group_by(:country_code) |> select(:country_code, count(:order_id)) ``` -`select(:country_code, avg(:amount)) |> group_by(:country_code)` is equivalent to the form above (and you can use `sum()`, `min()`, or `max()` instead of `avg()` where appropriate). +`select(:country_code, avg(:amount)) |> group_by(:country_code)` is equivalent to the form above (and you can use `sum()`, `min()`, `max()`, `count()`, or `count_distinct()` instead of `avg()` where appropriate). 
If `group_by()` is present but `select()` lists only key columns (no aggregates), the statement is still valid (distinct group keys); the REPL prints this warning to stderr: @@ -218,8 +224,10 @@ read("input.parquet") |> head(5) # first 5 rows head("input.parquet", 5) # equivalent to the previous statement ``` -### Metadata functions: `schema` and `count` +### Metadata functions: `schema` and `count()` + +`schema()` prints the table schema. It takes no arguments. -`schema` prints the table schema. +`count()` (as a pipeline stage after `read`, e.g. `read("file.parquet") |> count()`) prints the total row count. It takes no arguments. -`count` prints the row count. Neither takes arguments. +Do not confuse that with **`count(:column)` inside `select()`**, which is an aggregate: it counts non-null values in that column (globally or per group when used with `group_by`). Use `count_distinct(:column)` for distinct non-null values. diff --git a/features/repl/aggregates.feature b/features/repl/aggregates.feature index 89d153b..2331f0d 100644 --- a/features/repl/aggregates.feature +++ b/features/repl/aggregates.feature @@ -123,3 +123,65 @@ Feature: Aggregate Functions Then the command should succeed And the output should contain "1,20" And the output should contain "2,5" + + Scenario: Count non-null values + Given a Parquet file with the following data: + ``` + item_id,quantity + 1,11 + 2,22 + 3,33 + ``` + When the REPL is ran and the user types: + ``` + read("$TEMPDIR/input.parquet") |> select(count(:quantity)) + ``` + Then the command should succeed + And the output should contain "3" + + Scenario: Count distinct values + Given a Parquet file with the following data: + ``` + item_id,region + 1,US + 2,US + 3,EU + ``` + When the REPL is ran and the user types: + ``` + read("$TEMPDIR/input.parquet") |> select(count_distinct(:region)) + ``` + Then the command should succeed + And the output should contain "2" + + Scenario: Group by with count + Given a Parquet file with the 
following data: + ``` + item_id,quantity + 1,10 + 1,20 + 2,5 + ``` + When the REPL is ran and the user types: + ``` + read("$TEMPDIR/input.parquet") |> group_by(:item_id) |> select(:item_id, count(:quantity)) + ``` + Then the command should succeed + And the output should contain "1,2" + And the output should contain "2,1" + + Scenario: Group by with count distinct + Given a Parquet file with the following data: + ``` + item_id,region + 1,US + 1,US + 2,EU + ``` + When the REPL is ran and the user types: + ``` + read("$TEMPDIR/input.parquet") |> group_by(:item_id) |> select(:item_id, count_distinct(:region)) + ``` + Then the command should succeed + And the output should contain "1,1" + And the output should contain "2,1" diff --git a/src/cli/repl/plan.rs b/src/cli/repl/plan.rs index 2d977ec..d012656 100644 --- a/src/cli/repl/plan.rs +++ b/src/cli/repl/plan.rs @@ -67,8 +67,10 @@ fn select_args_are_all_aggregates(args: &[Expr]) -> bool { matches!( e, Expr::FunctionCall(n, a) - if matches!(n.to_string().as_str(), "sum" | "avg" | "min" | "max") - && a.len() == 1 + if matches!( + n.to_string().as_str(), + "sum" | "avg" | "min" | "max" | "count" | "count_distinct" + ) && a.len() == 1 ) }) } @@ -100,28 +102,34 @@ fn select_aggregate_item(name: &str, col: ColumnSpec) -> SelectItem { "avg" => SelectItem::Avg(col), "min" => SelectItem::Min(col), "max" => SelectItem::Max(col), - _ => unreachable!("select_aggregate_item only called for sum, avg, min, max"), + "count" => SelectItem::Count(col), + "count_distinct" => SelectItem::CountDistinct(col), + _ => unreachable!( + "select_aggregate_item only called for sum, avg, min, max, count, or count_distinct" + ), } } -/// Extracts select items: column refs or `sum(column)` / `avg(column)` / `min(column)` / `max(column)`. +/// Extracts select items: column refs or `sum(column)` / `avg(column)` / `min(column)` / `max(column)` / +/// `count(column)` / `count_distinct(column)`. 
pub(super) fn extract_select_items(args: &[Expr]) -> crate::Result<Vec<SelectItem>> { - const SELECT_AGG_EXPECTED: &str = - "select expects column names, sum(column), avg(column), min(column), or max(column)"; + const SELECT_AGG_EXPECTED: &str = "select expects column names, sum(column), avg(column), min(column), max(column), count(column), or count_distinct(column)"; args.iter() .map(|expr| match expr { Expr::FunctionCall(name, inner) => { let name_str = name.to_string(); match name_str.as_str() { - "sum" | "avg" | "min" | "max" => match inner.as_slice() { - [one] => Ok(select_aggregate_item( - name_str.as_str(), - extract_one_column_spec(one)?, - )), - _ => Err(Error::UnsupportedFunctionCall(format!( - "{name_str}() expects exactly one column argument" - ))), - }, + "sum" | "avg" | "min" | "max" | "count" | "count_distinct" => { + match inner.as_slice() { + [one] => Ok(select_aggregate_item( + name_str.as_str(), + extract_one_column_spec(one)?, + )), + _ => Err(Error::UnsupportedFunctionCall(format!( + "{name_str}() expects exactly one column argument" + ))), + } + } _ => Err(Error::UnsupportedFunctionCall(format!( "{SELECT_AGG_EXPECTED}, got {expr:?}" ))), @@ -270,12 +278,17 @@ fn validate_grouped_select(keys: &[ColumnSpec], items: &[SelectItem]) -> crate:: SelectItem::Column(c) => { if !keys.iter().any(|k| k == c) { return Err(Error::InvalidReplPipeline( - "select with group_by: non-key columns must use an aggregate (sum, avg, min, or max), not plain columns" + "select with group_by: non-key columns must use an aggregate (sum, avg, min, max, count, or count_distinct), not plain columns" .to_string(), )); } } - SelectItem::Sum(_) | SelectItem::Avg(_) | SelectItem::Min(_) | SelectItem::Max(_) => {} + SelectItem::Sum(_) + | SelectItem::Avg(_) + | SelectItem::Min(_) + | SelectItem::Max(_) + | SelectItem::Count(_) + | SelectItem::CountDistinct(_) => {} } } Ok(()) diff --git a/src/cli/repl/stage.rs b/src/cli/repl/stage.rs index 8733069..783d0b7 100644 --- a/src/cli/repl/stage.rs +++ 
b/src/cli/repl/stage.rs @@ -113,6 +113,10 @@ impl fmt::Display for ReplPipelineStage { SelectItem::Avg(c) => format!("avg({})", format_column_spec(c)), SelectItem::Min(c) => format!("min({})", format_column_spec(c)), SelectItem::Max(c) => format!("max({})", format_column_spec(c)), + SelectItem::Count(c) => format!("count({})", format_column_spec(c)), + SelectItem::CountDistinct(c) => { + format!("count_distinct({})", format_column_spec(c)) + } }) .collect::<Vec<_>>(); write!(f, "select({})", cols.join(", ")) diff --git a/src/cli/repl/tests.rs b/src/cli/repl/tests.rs index b1ea0d6..d2d8219 100644 --- a/src/cli/repl/tests.rs +++ b/src/cli/repl/tests.rs @@ -113,6 +113,14 @@ fn test_is_statement_complete_select_min_only() { assert!(is_statement_complete(&exprs)); } +#[test] +fn test_is_statement_complete_select_count_only() { + let expr = parse("select(count(:quantity))"); + assert!(is_statement_complete(std::slice::from_ref(&expr))); + let exprs = pipe_exprs("select(count_distinct(:id))"); + assert!(is_statement_complete(&exprs)); +} + #[test] fn test_is_statement_complete_select_then_group_by() { let exprs = pipe_exprs(r#"read("f.parquet") |> select(:id, sum(:qty)) |> group_by(:id)"#); @@ -126,7 +134,12 @@ fn test_plan_stage_select_aggregates() { ("select(sum(:quantity))", SelectItem::Sum(qty.clone())), ("select(avg(:quantity))", SelectItem::Avg(qty.clone())), ("select(min(:quantity))", SelectItem::Min(qty.clone())), - ("select(max(:quantity))", SelectItem::Max(qty)), + ("select(max(:quantity))", SelectItem::Max(qty.clone())), + ("select(count(:quantity))", SelectItem::Count(qty.clone())), + ( + "select(count_distinct(:quantity))", + SelectItem::CountDistinct(qty), + ), ]; for (input, expected_col) in cases { let expr = parse(input); @@ -381,7 +394,9 @@ fn test_extract_select_items_aggregates() { ("sum", SelectItem::Sum(qty.clone())), ("avg", SelectItem::Avg(qty.clone())), ("min", SelectItem::Min(qty.clone())), - ("max", 
SelectItem::Max(qty.clone())), + ("count", SelectItem::Count(qty.clone())), + ("count_distinct", SelectItem::CountDistinct(qty)), ]; for (fn_name, expected) in cases { let args = vec![Expr::FunctionCall( @@ -552,7 +567,9 @@ fn test_validate_accepts_read_aggregate_select_only() { SelectItem::Sum(q.clone()), SelectItem::Avg(q.clone()), SelectItem::Min(q.clone()), - SelectItem::Max(q), + SelectItem::Max(q.clone()), + SelectItem::Count(q.clone()), + SelectItem::CountDistinct(q), ]; for item in aggregates { let stages = vec![ @@ -842,7 +859,9 @@ fn test_terminal_stage_classification() { SelectItem::Sum(col_x.clone()), SelectItem::Avg(col_x.clone()), SelectItem::Min(col_x.clone()), - SelectItem::Max(col_x), + SelectItem::Max(col_x.clone()), + SelectItem::Count(col_x.clone()), + SelectItem::CountDistinct(col_x), ] { assert!( ReplPipelineStage::Select { diff --git a/src/pipeline/builder.rs b/src/pipeline/builder.rs index 4d1526b..bb6c39f 100644 --- a/src/pipeline/builder.rs +++ b/src/pipeline/builder.rs @@ -322,7 +322,7 @@ impl PipelineBuilder { )) } - /// Display pipeline: global aggregate `select(sum(...), avg(...), min(...), max(...), ...)` with full result (one row) to stdout. + /// Display pipeline: global aggregate `select(sum(...), avg(...), min(...), max(...), count(...), count_distinct(...), ...)` with full result (one row) to stdout. 
fn build_aggregate_display_pipeline( &self, input_path: &str, diff --git a/src/pipeline/dataframe/transform.rs b/src/pipeline/dataframe/transform.rs index 380a647..1f65568 100644 --- a/src/pipeline/dataframe/transform.rs +++ b/src/pipeline/dataframe/transform.rs @@ -7,6 +7,8 @@ use arrow::array::RecordBatchReader; use arrow::record_batch::RecordBatch; use datafusion::execution::context::SessionContext; use datafusion::functions_aggregate::expr_fn::avg; +use datafusion::functions_aggregate::expr_fn::count; +use datafusion::functions_aggregate::expr_fn::count_distinct; use datafusion::functions_aggregate::expr_fn::max; use datafusion::functions_aggregate::expr_fn::min; use datafusion::functions_aggregate::expr_fn::sum; @@ -161,7 +163,7 @@ pub(super) fn apply_select_spec_to_dataframe( SelectItem::Column(c) => { if !column_spec_in_group_keys(c, group_by_keys) { return Err(Error::GenericError( - "select with group_by: non-key columns must use an aggregate (sum, avg, min, or max), not plain columns" + "select with group_by: non-key columns must use an aggregate (sum, avg, min, max, count, or count_distinct), not plain columns" .to_string(), )); } @@ -169,7 +171,9 @@ pub(super) fn apply_select_spec_to_dataframe( SelectItem::Sum(_) | SelectItem::Avg(_) | SelectItem::Min(_) - | SelectItem::Max(_) => {} + | SelectItem::Max(_) + | SelectItem::Count(_) + | SelectItem::CountDistinct(_) => {} } } @@ -198,6 +202,14 @@ pub(super) fn apply_select_spec_to_dataframe( let name = cs.resolve(arrow_schema)?; aggs.push(max(col(name.as_str()))); } + SelectItem::Count(cs) => { + let name = cs.resolve(arrow_schema)?; + aggs.push(count(col(name.as_str()))); + } + SelectItem::CountDistinct(cs) => { + let name = cs.resolve(arrow_schema)?; + aggs.push(count_distinct(col(name.as_str()))); + } SelectItem::Column(_) => {} } } @@ -246,6 +258,14 @@ pub(super) fn apply_select_spec_to_dataframe( let name = cs.resolve(arrow_schema)?; aggs.push(max(col(name.as_str()))); } + SelectItem::Count(cs) => { + 
let name = cs.resolve(arrow_schema)?; + aggs.push(count(col(name.as_str()))); + } + SelectItem::CountDistinct(cs) => { + let name = cs.resolve(arrow_schema)?; + aggs.push(count_distinct(col(name.as_str()))); + } SelectItem::Column(_) => {} } } diff --git a/src/pipeline/spec.rs b/src/pipeline/spec.rs index 1269e0e..609c274 100644 --- a/src/pipeline/spec.rs +++ b/src/pipeline/spec.rs @@ -38,6 +38,10 @@ pub enum SelectItem { Min(ColumnSpec), /// Global maximum over one column (REPL `max(:col)`). Max(ColumnSpec), + /// Count of non-null values in one column (REPL `count(:col)`). + Count(ColumnSpec), + /// Count of distinct non-null values in one column (REPL `count_distinct(:col)`). + CountDistinct(ColumnSpec), } /// Macro to build a [`SelectSpec`] from homogeneous column forms: @@ -98,7 +102,12 @@ impl SelectItem { pub fn is_aggregate(&self) -> bool { matches!( self, - SelectItem::Sum(_) | SelectItem::Avg(_) | SelectItem::Min(_) | SelectItem::Max(_) + SelectItem::Sum(_) + | SelectItem::Avg(_) + | SelectItem::Min(_) + | SelectItem::Max(_) + | SelectItem::Count(_) + | SelectItem::CountDistinct(_) ) } } @@ -173,7 +182,9 @@ impl SelectSpec { SelectItem::Sum(_) | SelectItem::Avg(_) | SelectItem::Min(_) - | SelectItem::Max(_) => Err(Error::PipelinePlanningError( + | SelectItem::Max(_) + | SelectItem::Count(_) + | SelectItem::CountDistinct(_) => Err(Error::PipelinePlanningError( PipelinePlanningError::AggregatesInProjectionSelect, )), }) @@ -213,6 +224,13 @@ mod tests { assert!(max_item.is_aggregate()); } + #[test] + fn test_select_item_count_aggregates_are_aggregate() { + let c = ColumnSpec::CaseInsensitive("x".into()); + assert!(SelectItem::Count(c.clone()).is_aggregate()); + assert!(SelectItem::CountDistinct(c).is_aggregate()); + } + fn schema_with_columns(names: &[&str]) -> Schema { let fields: Vec<Field> = names .iter()