Merged
7 changes: 7 additions & 0 deletions README.md
@@ -445,6 +445,13 @@ Columns appear in the output in the order they are listed, so `select` can also
> read("data.parquet") |> select(:email, :id) |> write("reordered.csv")
```

With optional `group_by(...)`, you can use aggregates in `select`: `sum`, `avg`, `min`, `max`, `count` (non-null values in a column), and `count_distinct` (distinct non-null values). A `select` of only aggregates (no `group_by`) summarizes the whole table.

```text
> read("data.parquet") |> select(count(:id))
> read("data.parquet") |> group_by(:country) |> select(:country, count_distinct(:user_id))
```
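The null-handling of the two counting aggregates can be sketched in plain Rust (hypothetical helper names, not this project's API): `count` skips nulls, and `count_distinct` additionally de-duplicates.

```rust
use std::collections::HashSet;
use std::hash::Hash;

// Count of non-null values, as in count(:col).
fn count_non_null<T>(col: &[Option<T>]) -> usize {
    col.iter().filter(|v| v.is_some()).count()
}

// Count of distinct non-null values, as in count_distinct(:col).
fn count_distinct_non_null<T: Eq + Hash>(col: &[Option<T>]) -> usize {
    col.iter().flatten().collect::<HashSet<_>>().len()
}

fn main() {
    let user_id = [Some(1), Some(1), None, Some(2)];
    assert_eq!(count_non_null(&user_id), 3);
    assert_eq!(count_distinct_non_null(&user_id), 2);
}
```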

#### `head(n)`

Take the first _n_ rows.
24 changes: 16 additions & 8 deletions docs/REPL.md
@@ -168,29 +168,35 @@ If the column name is specified as a `Symbol` (`:name`) or a bare identifier (`n

If the column name is specified as a `String` (`"_one"`), then `select()` performs an exact, case-sensitive match.

The same rules apply to column arguments inside `sum(...)`, `avg(...)`, `min(...)`, and `max(...)`.
The same rules apply to column arguments inside `sum(...)`, `avg(...)`, `min(...)`, `max(...)`, `count(...)`, and `count_distinct(...)`.

#### Global aggregates (`sum` / `avg` / `min` / `max` without `group_by`)
#### Global aggregates (`sum` / `avg` / `min` / `max` / `count` / `count_distinct` without `group_by`)

With no `group_by()`, you may use `select()` with only aggregate arguments (`sum(...)`, `avg(...)`, `min(...)`, and/or `max(...)`). That aggregates the whole table (for example, one row of totals/averages/minima/maxima).
With no `group_by()`, you may use `select()` with only aggregate arguments (`sum(...)`, `avg(...)`, `min(...)`, `max(...)`, `count(...)`, and/or `count_distinct(...)`). That aggregates the whole table (for example, one row of totals, averages, extrema, or counts).

- `count(:col)` counts non-null values in `:col`.
- `count_distinct(:col)` counts distinct non-null values in `:col`.

```flt
read("input.parquet") |> select(sum(:quantity))
read("input.parquet") |> select(avg(:amount))
read("input.parquet") |> select(min(:amount), max(:amount))
read("input.parquet") |> select(count(:id))
read("input.parquet") |> select(count_distinct(:user_id))
```

Mixing plain column projections with aggregates in the same `select()` without `group_by()` is not allowed—use `group_by()` for the key columns first, or use only aggregates for a global summary.
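That restriction can be sketched as a standalone check (hypothetical types; the project's actual validation lives in `src/cli/repl/plan.rs`): without `group_by`, a `select` must be all plain columns or all aggregates.

```rust
#[derive(Debug, PartialEq)]
enum Item {
    Column(&'static str),
    // sum/avg/min/max/count/count_distinct over one column
    Aggregate(&'static str),
}

// Reject a select that mixes plain columns and aggregates when no
// group_by step is present.
fn validate_ungrouped(items: &[Item]) -> Result<(), String> {
    let aggs = items
        .iter()
        .filter(|i| matches!(i, Item::Aggregate(_)))
        .count();
    if aggs == 0 || aggs == items.len() {
        Ok(())
    } else {
        Err("cannot mix plain columns with aggregates without group_by".into())
    }
}

fn main() {
    assert!(validate_ungrouped(&[Item::Aggregate("qty"), Item::Aggregate("amt")]).is_ok());
    assert!(validate_ungrouped(&[Item::Column("id"), Item::Column("email")]).is_ok());
    assert!(validate_ungrouped(&[Item::Column("id"), Item::Aggregate("qty")]).is_err());
}
```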

#### Grouped aggregates (`group_by` + `select`)

`group_by(:key1, ...)` is a separate pipeline step. Every column in `group_by` _MUST_ be included as a plain column in `select()`. Any other selected column must use an aggregate (`sum()`, `avg()`, `min()`, or `max()`).
`group_by(:key1, ...)` is a separate pipeline step. Every column in `group_by` _MUST_ be included as a plain column in `select()`. Any other selected column must use an aggregate (`sum()`, `avg()`, `min()`, `max()`, `count()`, or `count_distinct()`).

```flt
read("input.parquet") |> group_by(:country_code) |> select(:country_code, avg(:amount))
read("input.parquet") |> group_by(:country_code) |> select(:country_code, count(:order_id))
```

`select(:country_code, avg(:amount)) |> group_by(:country_code)` is equivalent to the form above (and you can use `sum()`, `min()`, or `max()` instead of `avg()` where appropriate).
`select(:country_code, avg(:amount)) |> group_by(:country_code)` is equivalent to the form above (and you can use `sum()`, `min()`, `max()`, `count()`, or `count_distinct()` instead of `avg()` where appropriate).

If `group_by()` is present but `select()` lists only key columns (no aggregates), the statement is still valid (distinct group keys); the REPL prints this warning to stderr:

@@ -218,8 +224,10 @@ read("input.parquet") |> head(5) # first 5 rows
head("input.parquet", 5) # equivalent to the previous statement
```

### Metadata functions: `schema` and `count`
### Metadata functions: `schema` and `count()`

`schema()` prints the table schema. It takes no arguments.

`schema` prints the table schema.
`count()` (as a pipeline stage after `read`, e.g. `read("file.parquet") |> count()`) prints the total row count. It takes no arguments.

`count` prints the row count. Neither takes arguments.
Do not confuse that with **`count(:column)` inside `select()`**, which is an aggregate: it counts non-null values in that column (globally or per group when used with `group_by`). Use `count_distinct(:column)` for distinct non-null values.
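The distinction is purely one of arity, which can be sketched as follows (hypothetical names; the project's actual parsing lives in `src/cli/repl/plan.rs`): zero-argument `count()` is a pipeline stage, one-argument `count(:col)` is an aggregate.

```rust
#[derive(Debug, PartialEq)]
enum CountUse {
    Stage,             // read(...) |> count()
    Aggregate(String), // select(count(:col))
}

// Classify a `count` call by its argument count.
fn classify_count(args: &[&str]) -> Result<CountUse, String> {
    match args {
        [] => Ok(CountUse::Stage),
        [col] => Ok(CountUse::Aggregate((*col).to_string())),
        _ => Err("count expects zero or one argument".into()),
    }
}

fn main() {
    assert_eq!(classify_count(&[]), Ok(CountUse::Stage));
    assert_eq!(classify_count(&[":id"]), Ok(CountUse::Aggregate(":id".into())));
    assert!(classify_count(&[":a", ":b"]).is_err());
}
```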
62 changes: 62 additions & 0 deletions features/repl/aggregates.feature
@@ -123,3 +123,65 @@ Feature: Aggregate Functions
Then the command should succeed
And the output should contain "1,20"
And the output should contain "2,5"

Scenario: Count non-null values
Given a Parquet file with the following data:
```
item_id,quantity
1,11
2,22
3,33
```
When the REPL is ran and the user types:
```
read("$TEMPDIR/input.parquet") |> select(count(:quantity))
```
Then the command should succeed
And the output should contain "3"

Scenario: Count distinct values
Given a Parquet file with the following data:
```
item_id,region
1,US
2,US
3,EU
```
When the REPL is ran and the user types:
```
read("$TEMPDIR/input.parquet") |> select(count_distinct(:region))
```
Then the command should succeed
And the output should contain "2"

Scenario: Group by with count
Given a Parquet file with the following data:
```
item_id,quantity
1,10
1,20
2,5
```
When the REPL is ran and the user types:
```
read("$TEMPDIR/input.parquet") |> group_by(:item_id) |> select(:item_id, count(:quantity))
```
Then the command should succeed
And the output should contain "1,2"
And the output should contain "2,1"

Scenario: Group by with count distinct
Given a Parquet file with the following data:
```
item_id,region
1,US
1,US
2,EU
```
When the REPL is ran and the user types:
```
read("$TEMPDIR/input.parquet") |> group_by(:item_id) |> select(:item_id, count_distinct(:region))
```
Then the command should succeed
And the output should contain "1,1"
And the output should contain "2,1"
47 changes: 30 additions & 17 deletions src/cli/repl/plan.rs
@@ -67,8 +67,10 @@ fn select_args_are_all_aggregates(args: &[Expr]) -> bool {
matches!(
e,
Expr::FunctionCall(n, a)
if matches!(n.to_string().as_str(), "sum" | "avg" | "min" | "max")
&& a.len() == 1
if matches!(
n.to_string().as_str(),
"sum" | "avg" | "min" | "max" | "count" | "count_distinct"
) && a.len() == 1
)
})
}
@@ -100,28 +102,34 @@ fn select_aggregate_item(name: &str, col: ColumnSpec) -> SelectItem {
"avg" => SelectItem::Avg(col),
"min" => SelectItem::Min(col),
"max" => SelectItem::Max(col),
_ => unreachable!("select_aggregate_item only called for sum, avg, min, max"),
"count" => SelectItem::Count(col),
"count_distinct" => SelectItem::CountDistinct(col),
_ => unreachable!(
"select_aggregate_item only called for sum, avg, min, max, count, or count_distinct"
),
}
}

/// Extracts select items: column refs or `sum(column)` / `avg(column)` / `min(column)` / `max(column)`.
/// Extracts select items: column refs or `sum(column)` / `avg(column)` / `min(column)` / `max(column)` /
/// `count(column)` / `count_distinct(column)`.
pub(super) fn extract_select_items(args: &[Expr]) -> crate::Result<Vec<SelectItem>> {
const SELECT_AGG_EXPECTED: &str =
"select expects column names, sum(column), avg(column), min(column), or max(column)";
const SELECT_AGG_EXPECTED: &str = "select expects column names, sum(column), avg(column), min(column), max(column), count(column), or count_distinct(column)";
args.iter()
.map(|expr| match expr {
Expr::FunctionCall(name, inner) => {
let name_str = name.to_string();
match name_str.as_str() {
"sum" | "avg" | "min" | "max" => match inner.as_slice() {
[one] => Ok(select_aggregate_item(
name_str.as_str(),
extract_one_column_spec(one)?,
)),
_ => Err(Error::UnsupportedFunctionCall(format!(
"{name_str}() expects exactly one column argument"
))),
},
"sum" | "avg" | "min" | "max" | "count" | "count_distinct" => {
match inner.as_slice() {
[one] => Ok(select_aggregate_item(
name_str.as_str(),
extract_one_column_spec(one)?,
)),
_ => Err(Error::UnsupportedFunctionCall(format!(
"{name_str}() expects exactly one column argument"
))),
}
}
_ => Err(Error::UnsupportedFunctionCall(format!(
"{SELECT_AGG_EXPECTED}, got {expr:?}"
))),
@@ -270,12 +278,17 @@ fn validate_grouped_select(keys: &[ColumnSpec], items: &[SelectItem]) -> crate::
SelectItem::Column(c) => {
if !keys.iter().any(|k| k == c) {
return Err(Error::InvalidReplPipeline(
"select with group_by: non-key columns must use an aggregate (sum, avg, min, or max), not plain columns"
"select with group_by: non-key columns must use an aggregate (sum, avg, min, max, count, or count_distinct), not plain columns"
.to_string(),
));
}
}
SelectItem::Sum(_) | SelectItem::Avg(_) | SelectItem::Min(_) | SelectItem::Max(_) => {}
SelectItem::Sum(_)
| SelectItem::Avg(_)
| SelectItem::Min(_)
| SelectItem::Max(_)
| SelectItem::Count(_)
| SelectItem::CountDistinct(_) => {}
}
}
Ok(())
4 changes: 4 additions & 0 deletions src/cli/repl/stage.rs
@@ -113,6 +113,10 @@ impl fmt::Display for ReplPipelineStage {
SelectItem::Avg(c) => format!("avg({})", format_column_spec(c)),
SelectItem::Min(c) => format!("min({})", format_column_spec(c)),
SelectItem::Max(c) => format!("max({})", format_column_spec(c)),
SelectItem::Count(c) => format!("count({})", format_column_spec(c)),
SelectItem::CountDistinct(c) => {
format!("count_distinct({})", format_column_spec(c))
}
})
.collect::<Vec<_>>();
write!(f, "select({})", cols.join(", "))
27 changes: 23 additions & 4 deletions src/cli/repl/tests.rs
@@ -113,6 +113,14 @@ fn test_is_statement_complete_select_min_only() {
assert!(is_statement_complete(&exprs));
}

#[test]
fn test_is_statement_complete_select_count_only() {
let expr = parse("select(count(:quantity))");
assert!(is_statement_complete(std::slice::from_ref(&expr)));
let exprs = pipe_exprs("select(count_distinct(:id))");
assert!(is_statement_complete(&exprs));
}

#[test]
fn test_is_statement_complete_select_then_group_by() {
let exprs = pipe_exprs(r#"read("f.parquet") |> select(:id, sum(:qty)) |> group_by(:id)"#);
@@ -126,7 +134,12 @@ fn test_plan_stage_select_aggregates() {
("select(sum(:quantity))", SelectItem::Sum(qty.clone())),
("select(avg(:quantity))", SelectItem::Avg(qty.clone())),
("select(min(:quantity))", SelectItem::Min(qty.clone())),
("select(max(:quantity))", SelectItem::Max(qty)),
("select(max(:quantity))", SelectItem::Max(qty.clone())),
("select(count(:quantity))", SelectItem::Count(qty.clone())),
(
"select(count_distinct(:quantity))",
SelectItem::CountDistinct(qty),
),
];
for (input, expected_col) in cases {
let expr = parse(input);
@@ -381,7 +394,9 @@ fn test_extract_select_items_aggregates() {
("sum", SelectItem::Sum(qty.clone())),
("avg", SelectItem::Avg(qty.clone())),
("min", SelectItem::Min(qty.clone())),
("max", SelectItem::Max(qty)),
("max", SelectItem::Max(qty.clone())),
("count", SelectItem::Count(qty.clone())),
("count_distinct", SelectItem::CountDistinct(qty)),
];
for (fn_name, expected) in cases {
let args = vec![Expr::FunctionCall(
@@ -552,7 +567,9 @@ fn test_validate_accepts_read_aggregate_select_only() {
SelectItem::Sum(q.clone()),
SelectItem::Avg(q.clone()),
SelectItem::Min(q.clone()),
SelectItem::Max(q),
SelectItem::Max(q.clone()),
SelectItem::Count(q.clone()),
SelectItem::CountDistinct(q),
];
for item in aggregates {
let stages = vec![
@@ -842,7 +859,9 @@ fn test_terminal_stage_classification() {
SelectItem::Sum(col_x.clone()),
SelectItem::Avg(col_x.clone()),
SelectItem::Min(col_x.clone()),
SelectItem::Max(col_x),
SelectItem::Max(col_x.clone()),
SelectItem::Count(col_x.clone()),
SelectItem::CountDistinct(col_x),
] {
assert!(
ReplPipelineStage::Select {
2 changes: 1 addition & 1 deletion src/pipeline/builder.rs
@@ -322,7 +322,7 @@ impl PipelineBuilder {
))
}

/// Display pipeline: global aggregate `select(sum(...), avg(...), min(...), max(...), ...)` with full result (one row) to stdout.
/// Display pipeline: global aggregate `select(sum(...), avg(...), min(...), max(...), count(...), count_distinct(...), ...)` with full result (one row) to stdout.
fn build_aggregate_display_pipeline(
&self,
input_path: &str,
24 changes: 22 additions & 2 deletions src/pipeline/dataframe/transform.rs
@@ -7,6 +7,8 @@ use arrow::array::RecordBatchReader;
use arrow::record_batch::RecordBatch;
use datafusion::execution::context::SessionContext;
use datafusion::functions_aggregate::expr_fn::avg;
use datafusion::functions_aggregate::expr_fn::count;
use datafusion::functions_aggregate::expr_fn::count_distinct;
use datafusion::functions_aggregate::expr_fn::max;
use datafusion::functions_aggregate::expr_fn::min;
use datafusion::functions_aggregate::expr_fn::sum;
@@ -161,15 +163,17 @@ pub(super) fn apply_select_spec_to_dataframe(
SelectItem::Column(c) => {
if !column_spec_in_group_keys(c, group_by_keys) {
return Err(Error::GenericError(
"select with group_by: non-key columns must use an aggregate (sum, avg, min, or max), not plain columns"
"select with group_by: non-key columns must use an aggregate (sum, avg, min, max, count, or count_distinct), not plain columns"
.to_string(),
));
}
}
SelectItem::Sum(_)
| SelectItem::Avg(_)
| SelectItem::Min(_)
| SelectItem::Max(_) => {}
| SelectItem::Max(_)
| SelectItem::Count(_)
| SelectItem::CountDistinct(_) => {}
}
}

@@ -198,6 +202,14 @@ pub(super) fn apply_select_spec_to_dataframe(
let name = cs.resolve(arrow_schema)?;
aggs.push(max(col(name.as_str())));
}
SelectItem::Count(cs) => {
let name = cs.resolve(arrow_schema)?;
aggs.push(count(col(name.as_str())));
}
SelectItem::CountDistinct(cs) => {
let name = cs.resolve(arrow_schema)?;
aggs.push(count_distinct(col(name.as_str())));
}
SelectItem::Column(_) => {}
}
}
@@ -246,6 +258,14 @@ pub(super) fn apply_select_spec_to_dataframe(
let name = cs.resolve(arrow_schema)?;
aggs.push(max(col(name.as_str())));
}
SelectItem::Count(cs) => {
let name = cs.resolve(arrow_schema)?;
aggs.push(count(col(name.as_str())));
}
SelectItem::CountDistinct(cs) => {
let name = cs.resolve(arrow_schema)?;
aggs.push(count_distinct(col(name.as_str())));
}
SelectItem::Column(_) => {}
}
}