From b7fdaee81e3edbd42a3187a0fae764e74e95f18e Mon Sep 17 00:00:00 2001 From: Moritz Lell Date: Thu, 6 Nov 2025 22:48:14 +0100 Subject: [PATCH 01/10] Implement GROUP BY ROLLUP (col1, col2, ...) --- beanquery/compiler.py | 192 +++++++++++++++++++++++++++++++------ beanquery/parser/ast.py | 7 +- beanquery/parser/bql.ebnf | 21 +++- beanquery/parser/parser.py | 64 ++++++++++--- beanquery/query_compile.py | 78 +++++++++++++++ 5 files changed, 319 insertions(+), 43 deletions(-) diff --git a/beanquery/compiler.py b/beanquery/compiler.py index 6e6c442e..e887deab 100644 --- a/beanquery/compiler.py +++ b/beanquery/compiler.py @@ -96,11 +96,15 @@ def _compile(self, node: Optional[ast.Node]): return None raise NotImplementedError - @_compile.register - def _select(self, node: ast.Select): - self.stack.append(self.table) - - # Compile the FROM clause. + def _compile_select_base(self, node: ast.Select): + """Compile common parts of SELECT: FROM, targets, WHERE, GROUP BY, ORDER BY. + + Args: + node: A Select AST node. + Returns: + Tuple of (c_targets, c_where, element_indexes, having_index, order_spec) + """ + # Compile FROM clause c_from_expr = self._compile_from(node.from_clause) # Compile the targets. @@ -108,10 +112,6 @@ def _select(self, node: ast.Select): # Bind the WHERE expression to the execution environment. c_where = self._compile(node.where_clause) - - # Check that the FROM clause does not contain aggregates. This - # should never trigger if the compilation environment does not - # contain any aggregate. if c_where is not None and is_aggregate(c_where): raise CompilationError('aggregates are not allowed in WHERE clause') @@ -119,13 +119,50 @@ def _select(self, node: ast.Select): if c_from_expr is not None: c_where = c_from_expr if c_where is None else EvalAnd([c_from_expr, c_where]) - # Process the GROUP-BY clause. - new_targets, group_indexes, having_index = self._compile_group_by(node.group_by, c_targets) + # Process the GROUP BY clause. + new_targets, element_indexes, having_index = self._compile_group_by(node.group_by, c_targets) c_targets.extend(new_targets) # Process the ORDER-BY clause. new_targets, order_spec = self._compile_order_by(node.order_by, c_targets) c_targets.extend(new_targets) + + return c_targets, c_where, element_indexes, having_index, order_spec + + @_compile.register + def _select(self, node: ast.Select): + self.stack.append(self.table) + + # Handle ROLLUP queries separately + # Check if any grouping element has rollup=True + if node.group_by and node.group_by.elements: + has_rollup = any( + elem.get('rollup') for elem in node.group_by.elements + ) + if has_rollup: + result = self._compile_rollup(node) + self.stack.pop() + return result + + # Compile common SELECT parts + c_targets, c_where, element_indexes, having_index, order_spec = self._compile_select_base(node) + + # There is no ROLLUP clause there, therefore all elements follow the + # format {'indexes': [x], 'rollup': None} + assert ( + element_indexes is None + or all(e['rollup'] is None for e in element_indexes) + ) + assert ( + element_indexes is None + or all(len(e['indexes']) == 1 for e in element_indexes) + ) + + # Flatten element_indexes for regular GROUP BY + if element_indexes is not None: + group_indexes = [elem['indexes'][0] for elem in element_indexes] + else: + group_indexes = None # If this is an aggregate query (it groups, see list of indexes), check that # the set of non-aggregates match exactly the group indexes. 
This should @@ -158,6 +195,65 @@ def _select(self, node: ast.Select): self.stack.pop() return query + + def _compile_rollup(self, node: ast.Select): + """Compile a ROLLUP query as a union of grouping sets. + + Supports both full ROLLUP and mixed grouping: + - GROUP BY ROLLUP (a, b, c) → grouping sets: [(a,b,c), (a,b), (a), ()] + - GROUP BY x, ROLLUP (a, b) → grouping sets: [(x,a,b), (x,a), (x)] + + Args: + node: A Select AST node with at least one rollup element in group_by. + Returns: + An EvalUnion that executes multiple queries for each grouping set. + """ + from .query_compile import EvalUnion + + # Compile common SELECT parts + c_targets, c_where, element_indexes, having_index, order_spec = self._compile_select_base(node) + + # Separate regular columns from ROLLUP columns using element structure + regular_indexes = [] + rollup_indexes = [] + + for elem in element_indexes: + if elem['rollup']: + # ROLLUP element + rollup_indexes.extend(elem['indexes']) + else: + # Regular element + regular_indexes.extend(elem['indexes']) + + # Generate hierarchical grouping sets for ROLLUP columns + # For ROLLUP (a, b, c), generate: [(a,b,c), (a,b), (a), ()] + rollup_sets = [] + for i in range(len(rollup_indexes), -1, -1): + # Combine regular columns (always present) with rollup columns (hierarchical) + grouping_set = regular_indexes + rollup_indexes[:i] + rollup_sets.append(grouping_set) + + # Create a query for each grouping set + queries = [ + EvalQuery(self.table, + c_targets, + c_where, + grouping_set, + having_index, + None, + None, + node.distinct) + for grouping_set in rollup_sets + ] + + # Create union of all grouping set queries + return EvalUnion( + queries=queries, + rollup_sets=rollup_sets, + order_spec=order_spec, + limit=node.limit + ) + def _compile_from(self, node): if node is None: return None @@ -364,20 +460,32 @@ def _compile_group_by(self, group_by, c_targets): Returns: A tuple of new_targets: A list of new compiled target nodes. - group_indexes: If the query is an aggregate query, a list of integer - indexes to be used for processing grouping. Note that this list may be - empty (in the case of targets with only aggregates). On the other hand, - if this is not an aggregated query, this is set to None. So do - distinguish the empty list vs. None. + element_indexes: A list of dicts, one per grouping element: + [{'indexes': [int, ...], 'rollup': bool}, ...] + Each dict represents one grouping element from the grammar. + + Examples: + - Non-aggregate query: None + - Aggregate without GROUP BY: [] + - Regular GROUP BY account, year: + [{'indexes': [0], 'rollup': None}, {'indexes': [1], 'rollup': None}] + - Full ROLLUP: GROUP BY ROLLUP (account, year): + [{'indexes': [0, 1], 'rollup': True}] + - Mixed grouping: GROUP BY region, ROLLUP (year, month): + [{'indexes': [2], 'rollup': None}, {'indexes': [0, 1], 'rollup': True}] + - Implicit GROUP BY (when SUPPORT_IMPLICIT_GROUPBY=True): + [{'indexes': [0], 'rollup': None}, {'indexes': [2], 'rollup': None}] + + having_index: Index of HAVING expression in targets, or None. """ new_targets = c_targets[:] c_target_expressions = [c_target.c_expr for c_target in c_targets] - group_indexes = [] + element_indexes = [] having_index = None if group_by: - assert group_by.columns, "Internal error with GROUP-BY parsing" + assert group_by.elements, "Internal error with GROUP-BY parsing" # Compile group-by expressions and resolve them to their targets if # possible. 
A GROUP-BY column may be one of the following: @@ -389,7 +497,32 @@ def _compile_group_by(self, group_by, c_targets): # References by name are converted to indexes. New expressions are # inserted into the list of targets as invisible targets. targets_name_map = {target.name: index for index, target in enumerate(c_targets)} - for column in group_by.columns: + + # Initialize element structures + for elem in group_by.elements: + # Iterating over GROUP BY syntax elements, which are either a + # simple grouping column/expression, or a ROLLUP (col1, ...) + # element. + rollup_value = elem.get('rollup') + element_indexes.append({ + 'indexes': [], + 'rollup': rollup_value if rollup_value else None + }) + + # Collect all columns with their syntax element position + columns_by_element = [] + for elem_idx, elem in enumerate(group_by.elements): + if elem.get('rollup'): + columns = elem['columns'] + else: + columns = [elem['column']] + + for column in columns: + columns_by_element.append((elem_idx, column)) + + # Compile all columns and add the indexes to element_indexes, + # to return the same structure as in the parsed GROUP BY clause. + for elem_idx, column in columns_by_element: index = None # Process target references by index. @@ -428,7 +561,6 @@ def _compile_group_by(self, group_by, c_targets): c_target_expressions.append(c_expr) assert index is not None, "Internal error, could not index group-by reference." - group_indexes.append(index) # Check that the group-by column references a non-aggregate. c_expr = new_targets[index].c_expr @@ -438,6 +570,9 @@ def _compile_group_by(self, group_by, c_targets): # Check that the group-by column has a supported hashable type. if not issubclass(c_expr.dtype, collections.abc.Hashable): raise CompilationError(f'GROUP-BY a non-hashable type is not supported: "{column}"') + + # Add compiled index to the corresponding element + element_indexes[elem_idx]['indexes'].append(index) # Compile HAVING clause. if group_by.having is not None: @@ -455,25 +590,28 @@ def _compile_group_by(self, group_by, c_targets): # If the query is an aggregate query, check that all the targets are # aggregates. if all(aggregate_bools): - # FIXME: shold we really be checking for the empty - # list or is checking for a false value enough? - assert group_indexes == [] + # Empty element_indexes for aggregate query without GROUP BY + element_indexes = [] elif SUPPORT_IMPLICIT_GROUPBY: # If some of the targets aren't aggregates, automatically infer # that they are to be implicit group by targets. This makes for # a much more convenient syntax for our lightweight SQL, where # grouping is optional. - group_indexes = [ + implicit_indexes = [ index for index, c_target in enumerate(c_targets) if not c_target.is_aggregate] + # Wrap as individual elements + element_indexes = [ + {'indexes': [idx], 'rollup': None} for idx in implicit_indexes + ] else: raise CompilationError('aggregate query without a GROUP-BY should have only aggregates') else: - # This is not an aggregate query; don't set group_indexes to + # This is not an aggregate query; don't set element_indexes to # anything useful, we won't need it. 
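
# --- Illustration (standalone sketch, not part of the diff): the grouping-set
# expansion that _compile_rollup performs above, assuming plain integer column
# indexes. Regular GROUP BY columns stay in every set, while the ROLLUP columns
# are dropped one suffix at a time.
def rollup_grouping_sets(regular_indexes, rollup_indexes):
    return [regular_indexes + rollup_indexes[:i]
            for i in range(len(rollup_indexes), -1, -1)]

# GROUP BY x, ROLLUP (a, b)  with x -> 0, a -> 1, b -> 2:
assert rollup_grouping_sets([0], [1, 2]) == [[0, 1, 2], [0, 1], [0]]
# GROUP BY ROLLUP (a, b, c)  with a -> 0, b -> 1, c -> 2:
assert rollup_grouping_sets([], [0, 1, 2]) == [[0, 1, 2], [0, 1], [0], []]
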
- group_indexes = None + element_indexes = None - return new_targets[len(c_targets):], group_indexes, having_index + return new_targets[len(c_targets):], element_indexes, having_index @_compile.register def _column(self, node: ast.Column): diff --git a/beanquery/parser/ast.py b/beanquery/parser/ast.py index 0fc9e49e..05a62350 100644 --- a/beanquery/parser/ast.py +++ b/beanquery/parser/ast.py @@ -148,9 +148,12 @@ class From(Node): # A GROUP BY clause. # # Attributes: -# columns: A list of group-by expressions, simple Column() or otherwise. +# elements: A list of grouping elements. Each element is a dict with: +# - 'column': A single column expression (for regular grouping), or +# - 'columns': A list of columns (for ROLLUP grouping) +# - 'rollup': True if this element is a ROLLUP, None otherwise # having: An expression tree for the optional HAVING clause, or None. -GroupBy = node('GroupBy', 'columns having') +GroupBy = node('GroupBy', 'elements having') # An ORDER BY clause. # diff --git a/beanquery/parser/bql.ebnf b/beanquery/parser/bql.ebnf index 1c078b63..2369d79a 100644 --- a/beanquery/parser/bql.ebnf +++ b/beanquery/parser/bql.ebnf @@ -3,7 +3,7 @@ @@ignorecase :: True @@keyword :: 'AND' 'AS' 'ASC' 'BY' 'DESC' 'DISTINCT' 'FALSE' 'FROM' 'GROUP' 'HAVING' 'IN' 'IS' 'LIMIT' 'NOT' 'OR' 'ORDER' 'PIVOT' - 'SELECT' 'TRUE' 'WHERE' + 'ROLLUP' 'SELECT' 'TRUE' 'WHERE' @@keyword :: 'CREATE' 'TABLE' 'USING' 'INSERT' 'INTO' @@keyword :: 'BALANCES' 'JOURNAL' 'PRINT' @@comments :: /(\/\*([^*]|[\r\n]|(\*+([^*\/]|[\r\n])))*\*+\/)/ @@ -55,8 +55,25 @@ table::Table = name:identifier ; + +# A GROUP BY clause supports multiple syntaxes: +# 1. Regular GROUP BY: GROUP BY col1, col2 [HAVING condition] +# 2. Full ROLLUP: GROUP BY ROLLUP (col1, col2) [HAVING condition] +# 3. 
Mixed grouping: GROUP BY col1, ROLLUP (col2, col3) [HAVING condition] +# +# Examples: +# - GROUP BY account, year +# - GROUP BY account, year HAVING SUM(position) > 100 +# - GROUP BY ROLLUP (account, year) +# - GROUP BY ROLLUP (account, year) HAVING SUM(position) > 100 +# - GROUP BY region, ROLLUP (year, month) groupby::GroupBy - = columns:','.{ (integer | expression) }+ ['HAVING' having:expression] + = elements:','.{ grouping_element }+ ['HAVING' having:expression] + ; + +grouping_element + = 'ROLLUP' '(' columns:','.{ (integer | expression) }+ ')' rollup:`True` + | column:(integer | expression) rollup:`` ; order::OrderBy diff --git a/beanquery/parser/parser.py b/beanquery/parser/parser.py index 2ec9ee84..274b60e6 100644 --- a/beanquery/parser/parser.py +++ b/beanquery/parser/parser.py @@ -42,6 +42,7 @@ 'OR', 'ORDER', 'PIVOT', + 'ROLLUP', 'SELECT', 'TRUE', 'WHERE', @@ -344,24 +345,63 @@ def sep0(): self._token(',') def block1(): - with self._group(): - with self._choice(): - with self._option(): - self._integer_() - with self._option(): - self._expression_() - self._error( - 'expecting one of: ' - ' ' - ) + self._grouping_element_() self._positive_gather(block1, sep0) - self.name_last_node('columns') + self.name_last_node('elements') with self._optional(): self._token('HAVING') self._expression_() self.name_last_node('having') self._define(['having'], []) - self._define(['columns', 'having'], []) + self._define(['elements', 'having'], []) + + @tatsumasu() + def _grouping_element_(self): + with self._choice(): + with self._option(): + self._token('ROLLUP') + self._token('(') + + def sep0(): + self._token(',') + + def block1(): + with self._group(): + with self._choice(): + with self._option(): + self._integer_() + with self._option(): + self._expression_() + self._error( + 'expecting one of: ' + ' ' + ) + self._positive_gather(block1, sep0) + self.name_last_node('columns') + self._token(')') + self._constant(True) + self.name_last_node('rollup') + self._define(['columns', 'rollup'], []) + with self._option(): + with self._group(): + with self._choice(): + with self._option(): + self._integer_() + with self._option(): + self._expression_() + self._error( + 'expecting one of: ' + ' ' + ) + self.name_last_node('column') + self._constant('') + self.name_last_node('rollup') + self._define(['column', 'rollup'], []) + self._error( + 'expecting one of: ' + "'ROLLUP' " + ' [0-9]+' + ) @tatsumasu('OrderBy') def _order_(self): diff --git a/beanquery/query_compile.py b/beanquery/query_compile.py index eda87ec7..2d861099 100644 --- a/beanquery/query_compile.py +++ b/beanquery/query_compile.py @@ -637,6 +637,84 @@ def __call__(self): return query_execute.execute_select(self) +@dataclasses.dataclass +class EvalUnion: + """Implement UNION of multiple queries for ROLLUP. + + This class executes multiple GROUP BY queries with different grouping sets + and unions their results to implement SQL ROLLUP functionality. + + Execution steps: + 1. Compute the union of all grouping sets to identify all group columns + 2. Execute each query with its specific grouping set + 3. For each result set, NULL out columns that are not in that grouping set + (this distinguishes subtotal rows from detail rows) + 4. Union all result sets together + 5. Apply ORDER BY to the combined results (with NULL sorting last) + 6. 
Apply LIMIT if specified + + Example for GROUP BY region, ROLLUP (year, month): + - Query 1: GROUP BY region, year, month (detail rows) + - Query 2: GROUP BY region, year (year subtotals, month=NULL) + - Query 3: GROUP BY region (region subtotals, year=NULL, month=NULL) + """ + + queries: list[EvalQuery] + rollup_sets: list[list[int]] # List of grouping sets, one per query + order_spec: list[tuple[int, ast.Ordering]] + limit: int + + @property + def columns(self): + # All queries have the same columns + return self.queries[0].columns + + def __call__(self): + # Execute all queries and union the results + all_rows = [] + columns = None + + # Compute union of all grouping sets to find all group columns + # This is needed to determine which columns to NULL at each level + full_group_indexes = list(set(idx for grouping_set in self.rollup_sets for idx in grouping_set)) + + for query, grouping_set in zip(self.queries, self.rollup_sets): + cols, rows = query() + if columns is None: + columns = cols + + # NULL out columns that are not in this grouping set + # Columns in full_group_indexes but not in grouping_set should be NULL + grouping_set_indexes = set(grouping_set) + null_indexes = [idx for idx in full_group_indexes if idx not in grouping_set_indexes] + + # Replace values with None for non-grouped columns + if null_indexes: + rows = [ + tuple(None if i in null_indexes else val for i, val in enumerate(row)) + for row in rows + ] + + all_rows.extend(rows) + + # Apply ORDER BY if specified + if self.order_spec: + # Sort in reverse order to leverage Python's stable sort + for col_index, ordering in reversed(self.order_spec): + # NULL sorts last: (0, val) for non-NULL, (1, None) for NULL + # This is because NULL denotes the total row in ROLLUP + all_rows.sort( + key=lambda row: (0, row[col_index]) if row[col_index] is not None else (1, None), + reverse=bool(ordering) + ) + + # Apply LIMIT if specified + if self.limit is not None: + all_rows = all_rows[:self.limit] + + return columns, all_rows + + @dataclasses.dataclass class EvalPivot: """Implement PIVOT BY clause.""" From fa02a5cddfcabf9d1ddd2373c9d45cc6d934ea69 Mon Sep 17 00:00:00 2001 From: Moritz Lell Date: Thu, 13 Nov 2025 11:31:19 +0100 Subject: [PATCH 02/10] Add a "Sentinel" object to provide row headers for summarizations, like (Total), (earlier), (other) --- beanquery/query_compile.py | 50 ++++++++++++++++++++++++++++++++++++++ beanquery/query_env.py | 1 + beanquery/query_render.py | 18 +++++++++++--- 3 files changed, 65 insertions(+), 4 deletions(-) diff --git a/beanquery/query_compile.py b/beanquery/query_compile.py index 2d861099..83a5eaee 100644 --- a/beanquery/query_compile.py +++ b/beanquery/query_compile.py @@ -17,6 +17,56 @@ import re import operator + +@functools.total_ordering +class Sentinel: + """General sentinel value with configurable sort order and display string. + + Sentinels are special marker values that can be used in query results. + They have a configurable sort order relative to regular values: + - Negative sort_order: sorts before all non-sentinel values + - Positive sort_order: sorts after all non-sentinel values + - Zero sort_order: sorts with regular values (not recommended) + + Args: + sort_order: Integer determining sort position. Negative sorts before + regular values, positive sorts after. + display: String representation when rendered. 
+ """ + + def __init__(self, sort_order, display): + self.sort_order = sort_order + self.display = display + + def __str__(self): + return self.display + + def __repr__(self): + return f"Sentinel({self.sort_order}, {self.display!r})" + + def __eq__(self, other): + if not isinstance(other, Sentinel): + return False + return self.sort_order == other.sort_order and self.display == other.display + + def __lt__(self, other): + if isinstance(other, Sentinel): + # Compare sentinels by their sort order + return self.sort_order < other.sort_order + else: + # Compare sentinel to non-sentinel value + # Negative sort_order: sentinel sorts before regular values + # Positive sort_order: sentinel sorts after regular values + return self.sort_order < 0 + + def __hash__(self): + return hash((self.sort_order, self.display)) + + +# Sentinel instances for various use cases +SENTINEL_EARLIER = Sentinel(-1, "(earlier)") +SENTINEL_LATER = Sentinel(1, "(later)") + from decimal import Decimal from typing import List diff --git a/beanquery/query_env.py b/beanquery/query_env.py index 7606b193..fb621acb 100644 --- a/beanquery/query_env.py +++ b/beanquery/query_env.py @@ -28,6 +28,7 @@ from beanquery import query_compile from beanquery import types +from beanquery.query_compile import SENTINEL_EARLIER, SENTINEL_LATER class ColumnsRegistry(dict): diff --git a/beanquery/query_render.py b/beanquery/query_render.py index 24d53618..4bf52bbf 100644 --- a/beanquery/query_render.py +++ b/beanquery/query_render.py @@ -16,6 +16,8 @@ from beancount.core import inventory from beancount.core import position +from beanquery.query_compile import Sentinel + class Align(enum.Enum): LEFT = 0 @@ -449,10 +451,18 @@ def render_rows(rows, renderers, ctx): for row in rows: - # Render the row cells. Do not pass missing values to the - # renderers but substitute them with the appropriate - # placeholder string. - cells = [render.format(value) if value is not None else null for render, value in zip(renderers, row)] + # Render the row cells. Handle special cases: + # - Sentinel: Convert to string representation + # - None: Substitute with null placeholder + # - Regular values: Pass to renderer + cells = [] + for render, value in zip(renderers, row): + if isinstance(value, Sentinel): + cells.append(str(value)) + elif value is not None: + cells.append(render.format(value)) + else: + cells.append(null) if not any(isinstance(cell, list) for cell in cells): # No multi line cells. Yield the row. From 1681d548073676e3e80eb2debe48b8ef877fcbf8 Mon Sep 17 00:00:00 2001 From: Moritz Lell Date: Wed, 10 Dec 2025 11:52:08 +0100 Subject: [PATCH 03/10] Deal with multiple data types (e.g. 
Sentinel objects) when rendering columns --- beanquery/query_render.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/beanquery/query_render.py b/beanquery/query_render.py index 4bf52bbf..0ca0b39f 100644 --- a/beanquery/query_render.py +++ b/beanquery/query_render.py @@ -158,12 +158,19 @@ class SetRenderer(ColumnRenderer): def __init__(self, ctx): super().__init__(ctx) self.sep = ctx.listsep + self.ctx = ctx def update(self, value): - self.maxwidth = max(self.maxwidth, sum(len(x) + len(self.sep) for x in value) - len(self.sep)) + self.maxwidth = max(self.maxwidth, sum(len(str(x)) + len(self.sep) for x in value) - len(self.sep)) def format(self, value): - return self.sep.join(str(x) for x in sorted(value)) + """Format the value.""" + if not value: + return '' + # Get the appropriate renderer for the first item's type + item = next(iter(value)) + renderer = _get_renderer(type(item), self.ctx) + return self.sep.join(renderer.format(item) for item in sorted(value)) class DateRenderer(ColumnRenderer): From 299c797973a7b3a2df5d0ab099610dd0808965f8 Mon Sep 17 00:00:00 2001 From: Moritz Lell Date: Thu, 6 Nov 2025 23:02:18 +0100 Subject: [PATCH 04/10] Support PIVOT BY after GROUP BY ROLLUP by using a sentinel instead of NULL for subtotals --- beanquery/compiler.py | 13 ++++++++++++- beanquery/query_compile.py | 20 +++++++++++--------- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/beanquery/compiler.py b/beanquery/compiler.py index e887deab..080b82cb 100644 --- a/beanquery/compiler.py +++ b/beanquery/compiler.py @@ -247,12 +247,23 @@ def _compile_rollup(self, node: ast.Select): ] # Create union of all grouping set queries - return EvalUnion( + union = EvalUnion( queries=queries, rollup_sets=rollup_sets, order_spec=order_spec, limit=node.limit ) + + # Flatten element_indexes for PIVOT BY compilation + group_indexes = regular_indexes + rollup_indexes + + # Handle PIVOT BY if present + pivots = self._compile_pivot_by(node.pivot_by, c_targets, group_indexes) + if pivots: + from .query_compile import EvalPivot + return EvalPivot(union, pivots) + + return union def _compile_from(self, node): if node is None: diff --git a/beanquery/query_compile.py b/beanquery/query_compile.py index 83a5eaee..5d71735c 100644 --- a/beanquery/query_compile.py +++ b/beanquery/query_compile.py @@ -13,6 +13,7 @@ import collections import dataclasses import datetime +import functools import itertools import re import operator @@ -66,6 +67,7 @@ def __hash__(self): # Sentinel instances for various use cases SENTINEL_EARLIER = Sentinel(-1, "(earlier)") SENTINEL_LATER = Sentinel(1, "(later)") +ROLLUP_TOTAL = Sentinel(2, "(Total)") from decimal import Decimal from typing import List @@ -733,15 +735,15 @@ def __call__(self): if columns is None: columns = cols - # NULL out columns that are not in this grouping set - # Columns in full_group_indexes but not in grouping_set should be NULL + # Mark subtotal columns that are not in this grouping set + # Columns in full_group_indexes but not in grouping_set get ROLLUP_TOTAL sentinel grouping_set_indexes = set(grouping_set) - null_indexes = [idx for idx in full_group_indexes if idx not in grouping_set_indexes] + subtotal_indexes = [idx for idx in full_group_indexes if idx not in grouping_set_indexes] - # Replace values with None for non-grouped columns - if null_indexes: + # Replace values with ROLLUP_TOTAL for non-grouped columns (subtotal rows) + if subtotal_indexes: rows = [ - tuple(None if i in null_indexes else val for i, val in 
enumerate(row)) + tuple(ROLLUP_TOTAL if i in subtotal_indexes else val for i, val in enumerate(row)) for row in rows ] @@ -751,10 +753,10 @@ def __call__(self): if self.order_spec: # Sort in reverse order to leverage Python's stable sort for col_index, ordering in reversed(self.order_spec): - # NULL sorts last: (0, val) for non-NULL, (1, None) for NULL - # This is because NULL denotes the total row in ROLLUP + # ROLLUP_TOTAL sorts last (after all regular values) + # RollupTotal implements comparison operators to sort after everything all_rows.sort( - key=lambda row: (0, row[col_index]) if row[col_index] is not None else (1, None), + key=lambda row: row[col_index], reverse=bool(ordering) ) From 0b88056b8d4b525e76c52491bfd5fa93810cec4b Mon Sep 17 00:00:00 2001 From: Moritz Lell Date: Thu, 13 Nov 2025 11:31:19 +0100 Subject: [PATCH 05/10] Rename Rollup Sentinel for better consistency --- beanquery/query_compile.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/beanquery/query_compile.py b/beanquery/query_compile.py index 5d71735c..e6ae8cd4 100644 --- a/beanquery/query_compile.py +++ b/beanquery/query_compile.py @@ -65,9 +65,9 @@ def __hash__(self): # Sentinel instances for various use cases +SENTINEL_ROLLUP_TOTAL = Sentinel(2, "(Total)") SENTINEL_EARLIER = Sentinel(-1, "(earlier)") SENTINEL_LATER = Sentinel(1, "(later)") -ROLLUP_TOTAL = Sentinel(2, "(Total)") from decimal import Decimal from typing import List @@ -736,14 +736,14 @@ def __call__(self): columns = cols # Mark subtotal columns that are not in this grouping set - # Columns in full_group_indexes but not in grouping_set get ROLLUP_TOTAL sentinel + # Columns in full_group_indexes but not in grouping_set get SENTINEL_ROLLUP_TOTAL sentinel grouping_set_indexes = set(grouping_set) subtotal_indexes = [idx for idx in full_group_indexes if idx not in grouping_set_indexes] - # Replace values with ROLLUP_TOTAL for non-grouped columns (subtotal rows) + # Replace values with SENTINEL_ROLLUP_TOTAL for non-grouped columns (subtotal rows) if subtotal_indexes: rows = [ - tuple(ROLLUP_TOTAL if i in subtotal_indexes else val for i, val in enumerate(row)) + tuple(SENTINEL_ROLLUP_TOTAL if i in subtotal_indexes else val for i, val in enumerate(row)) for row in rows ] @@ -753,8 +753,8 @@ def __call__(self): if self.order_spec: # Sort in reverse order to leverage Python's stable sort for col_index, ordering in reversed(self.order_spec): - # ROLLUP_TOTAL sorts last (after all regular values) - # RollupTotal implements comparison operators to sort after everything + # SENTINEL_ROLLUP_TOTAL sorts last (after all regular values) + # Sentinel implements comparison operators to sort based on sort_order all_rows.sort( key=lambda row: row[col_index], reverse=bool(ordering) From 681381029b1c610c0004b8919235f62f6037e459 Mon Sep 17 00:00:00 2001 From: Moritz Lell Date: Fri, 7 Nov 2025 14:19:03 +0100 Subject: [PATCH 06/10] Implement GROUP BY CUBE / GROUPING SETS --- beanquery/compiler.py | 187 ++++++++++++++++++++++++++----------- beanquery/parser/bql.ebnf | 17 +++- beanquery/parser/parser.py | 88 ++++++++++++++++- 3 files changed, 232 insertions(+), 60 deletions(-) diff --git a/beanquery/compiler.py b/beanquery/compiler.py index 080b82cb..46b9e275 100644 --- a/beanquery/compiler.py +++ b/beanquery/compiler.py @@ -1,5 +1,6 @@ import collections.abc import importlib +import itertools import typing from decimal import Decimal @@ -28,6 +29,7 @@ EvalOr, EvalPivot, EvalQuery, + EvalUnion, EvalConstantSubquery1D, EvalRow, EvalTarget, @@ 
-133,29 +135,39 @@ def _compile_select_base(self, node: ast.Select): def _select(self, node: ast.Select): self.stack.append(self.table) - # Handle ROLLUP queries separately - # Check if any grouping element has rollup=True + # Handle ROLLUP/CUBE/GROUPING SETS queries separately + # Check if any grouping element has rollup=True, cube=True, or sets=True if node.group_by and node.group_by.elements: has_rollup = any( elem.get('rollup') for elem in node.group_by.elements ) + has_cube = any( + elem.get('cube') for elem in node.group_by.elements + ) + has_sets = any( + elem.get('sets') for elem in node.group_by.elements + ) if has_rollup: - result = self._compile_rollup(node) + result = self._compile_grouping_sets(node, 'rollup') + self.stack.pop() + return result + if has_cube: + result = self._compile_grouping_sets(node, 'cube') + self.stack.pop() + return result + if has_sets: + result = self._compile_grouping_sets(node, 'sets') self.stack.pop() return result # Compile common SELECT parts c_targets, c_where, element_indexes, having_index, order_spec = self._compile_select_base(node) - # There is no ROLLUP clause there, therefore all elements follow the - # format {'indexes': [x], 'rollup': None} + # There is no ROLLUP/CUBE clause, therefore all elements follow the + # format {'indexes': [x], 'modifier': None} assert ( element_indexes is None - or all(e['rollup'] is None for e in element_indexes) - ) - assert ( - element_indexes is None - or all(len(e['indexes']) == 1 for e in element_indexes) + or all(e['modifier'] is None for e in element_indexes) ) # Flatten element_indexes for regular GROUP BY @@ -196,42 +208,73 @@ def _select(self, node: ast.Select): return query - def _compile_rollup(self, node: ast.Select): - """Compile a ROLLUP query as a union of grouping sets. + def _compile_grouping_sets(self, node: ast.Select, grouping_type): + """Compile ROLLUP/CUBE/GROUPING SETS query as a union of grouping sets. + + Supports full and mixed grouping: + - ROLLUP: GROUP BY ROLLUP (a, b, c) → grouping sets: [(a,b,c), (a,b), (a), ()] + - ROLLUP: GROUP BY x, ROLLUP (a, b) → grouping sets: [(x,a,b), (x,a), (x)] + - CUBE: GROUP BY CUBE (a, b) → grouping sets: [(a,b), (a), (b), ()] + - CUBE: GROUP BY x, CUBE (a, b) → grouping sets: [(x,a,b), (x,a), (x,b), (x)] + - SETS: GROUP BY GROUPING SETS ((a, b), (a), ()) → grouping sets: [(a,b), (a), ()] - Supports both full ROLLUP and mixed grouping: - - GROUP BY ROLLUP (a, b, c) → grouping sets: [(a,b,c), (a,b), (a), ()] - - GROUP BY x, ROLLUP (a, b) → grouping sets: [(x,a,b), (x,a), (x)] + ROLLUP generates hierarchical grouping sets (prefixes). + CUBE generates all possible combinations (power set) of the columns. + SETS uses explicitly specified grouping sets. Args: - node: A Select AST node with at least one rollup element in group_by. + node: A Select AST node with at least one rollup/cube/sets element in group_by. + grouping_type: Either 'rollup', 'cube', or 'sets'. Returns: An EvalUnion that executes multiple queries for each grouping set. """ - from .query_compile import EvalUnion + from .query_compile import EvalUnion, EvalPivot + import itertools # Compile common SELECT parts c_targets, c_where, element_indexes, having_index, order_spec = self._compile_select_base(node) - # Separate regular columns from ROLLUP columns using element structure + # Separate regular columns from ROLLUP/CUBE/SETS columns using element + # structure. Columns are represented by their numerical indexes as + # determined by _compile_select_base(). 
regular_indexes = [] - rollup_indexes = [] + rollup_indexes = [] # Will be a flat list for ROLLUP/CUBE, list of lists for SETS for elem in element_indexes: - if elem['rollup']: - # ROLLUP element - rollup_indexes.extend(elem['indexes']) + if elem['modifier'] == grouping_type: + # For GROUPING SETS, 'indexes' is a list of lists + # For ROLLUP/CUBE, 'indexes' is a flat list + rollup_indexes = elem['indexes'] else: # Regular element regular_indexes.extend(elem['indexes']) - # Generate hierarchical grouping sets for ROLLUP columns - # For ROLLUP (a, b, c), generate: [(a,b,c), (a,b), (a), ()] - rollup_sets = [] - for i in range(len(rollup_indexes), -1, -1): - # Combine regular columns (always present) with rollup columns (hierarchical) - grouping_set = regular_indexes + rollup_indexes[:i] - rollup_sets.append(grouping_set) + # Generate grouping sets based on type + if grouping_type == 'rollup': + # Hierarchical: [(a,b,c), (a,b), (a), ()] + grouping_sets = [] + for i in range(len(rollup_indexes), -1, -1): + # Combine regular columns (always present) with special columns (hierarchical) + grouping_set = regular_indexes + rollup_indexes[:i] + grouping_sets.append(grouping_set) + elif grouping_type == 'cube': + # All combinations: power set + # For CUBE (a, b, c), generate all 2^3 = 8 combinations: + # [(a,b,c), (a,b), (a,c), (a), (b,c), (b), (c), ()] + grouping_sets = [] + for r in range(len(rollup_indexes), -1, -1): + for combo in itertools.combinations(rollup_indexes, r): + # Combine regular columns (always present) with special column combination + grouping_set = regular_indexes + list(combo) + grouping_sets.append(grouping_set) + else: # sets + # Explicit grouping sets specified by user + # rollup_indexes is already a list of lists + grouping_sets = [] + for set_indexes in rollup_indexes: + # Combine regular columns with this grouping set + grouping_set = regular_indexes + set_indexes + grouping_sets.append(grouping_set) # Create a query for each grouping set queries = [ @@ -243,24 +286,29 @@ def _compile_rollup(self, node: ast.Select): None, None, node.distinct) - for grouping_set in rollup_sets + for grouping_set in grouping_sets ] # Create union of all grouping set queries union = EvalUnion( queries=queries, - rollup_sets=rollup_sets, + rollup_sets=grouping_sets, order_spec=order_spec, limit=node.limit ) # Flatten element_indexes for PIVOT BY compilation - group_indexes = regular_indexes + rollup_indexes + # For GROUPING SETS, rollup_indexes is a list of lists, so flatten it + if grouping_type == 'sets': + # Flatten the list of lists and get unique indexes + flat_indexes = list(set(idx for set_indexes in rollup_indexes for idx in set_indexes)) + group_indexes = regular_indexes + flat_indexes + else: + group_indexes = regular_indexes + rollup_indexes # Handle PIVOT BY if present pivots = self._compile_pivot_by(node.pivot_by, c_targets, group_indexes) if pivots: - from .query_compile import EvalPivot return EvalPivot(union, pivots) return union @@ -472,20 +520,25 @@ def _compile_group_by(self, group_by, c_targets): A tuple of new_targets: A list of new compiled target nodes. element_indexes: A list of dicts, one per grouping element: - [{'indexes': [int, ...], 'rollup': bool}, ...] + [{'indexes': [int, ...], 'modifier': str or None}, ...] Each dict represents one grouping element from the grammar. + 'modifier' can be None, 'rollup', 'cube', or 'sets'. 
Examples: - Non-aggregate query: None - Aggregate without GROUP BY: [] - Regular GROUP BY account, year: - [{'indexes': [0], 'rollup': None}, {'indexes': [1], 'rollup': None}] + [{'indexes': [0], 'modifier': None}, {'indexes': [1], 'modifier': None}] - Full ROLLUP: GROUP BY ROLLUP (account, year): - [{'indexes': [0, 1], 'rollup': True}] + [{'indexes': [0, 1], 'modifier': 'rollup', 'grouping_sets': None}] + - Full CUBE: GROUP BY CUBE (account, year): + [{'indexes': [0, 1], 'modifier': 'cube', 'grouping_sets': None}] + - GROUPING SETS: GROUP BY GROUPING SETS ((account, year), (account), ()): + [{'indexes': [[0, 1], [0], []], 'modifier': 'sets', 'grouping_sets': [...]}] - Mixed grouping: GROUP BY region, ROLLUP (year, month): - [{'indexes': [2], 'rollup': None}, {'indexes': [0, 1], 'rollup': True}] + [{'indexes': [2], 'modifier': None, 'grouping_sets': None}, {'indexes': [0, 1], 'modifier': 'rollup', 'grouping_sets': None}] - Implicit GROUP BY (when SUPPORT_IMPLICIT_GROUPBY=True): - [{'indexes': [0], 'rollup': None}, {'indexes': [2], 'rollup': None}] + [{'indexes': [0], 'modifier': None, 'grouping_sets': None}, {'indexes': [2], 'modifier': None, 'grouping_sets': None}] having_index: Index of HAVING expression in targets, or None. """ @@ -512,28 +565,51 @@ def _compile_group_by(self, group_by, c_targets): # Initialize element structures for elem in group_by.elements: # Iterating over GROUP BY syntax elements, which are either a - # simple grouping column/expression, or a ROLLUP (col1, ...) - # element. - rollup_value = elem.get('rollup') - element_indexes.append({ - 'indexes': [], - 'rollup': rollup_value if rollup_value else None - }) + # simple grouping column/expression, a ROLLUP (col1, ...) + # element, a CUBE (col1, ...) element, or a GROUPING SETS element. + if elem.get('rollup'): + modifier = 'rollup' + elif elem.get('cube'): + modifier = 'cube' + elif elem.get('sets'): + modifier = 'sets' + else: + modifier = None + + # For GROUPING SETS, 'indexes' will be a list of lists + # For other modifiers, 'indexes' is a flat list + if modifier == 'sets': + element_indexes.append({ + 'indexes': [[] for _ in elem['grouping_sets']], + 'modifier': modifier, + 'grouping_sets': elem.get('grouping_sets') + }) + else: + element_indexes.append({ + 'indexes': [], + 'modifier': modifier, + 'grouping_sets': None + }) # Collect all columns with their syntax element position + # For GROUPING SETS, also track which set within the element columns_by_element = [] for elem_idx, elem in enumerate(group_by.elements): - if elem.get('rollup'): + if elem.get('rollup') or elem.get('cube'): columns = elem['columns'] + for column in columns: + columns_by_element.append((elem_idx, None, column)) + elif elem.get('sets'): + # For GROUPING SETS, track which set each column belongs to + for set_idx, grouping_set in enumerate(elem['grouping_sets']): + for column in grouping_set['columns']: + columns_by_element.append((elem_idx, set_idx, column)) else: - columns = [elem['column']] - - for column in columns: - columns_by_element.append((elem_idx, column)) + columns_by_element.append((elem_idx, None, elem['column'])) # Compile all columns and add the indexes to element_indexes, # to return the same structure as in the parsed GROUP BY clause. - for elem_idx, column in columns_by_element: + for elem_idx, set_idx, column in columns_by_element: index = None # Process target references by index. 
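
# --- Illustration (standalone sketch, not part of the diff): the CUBE and
# GROUPING SETS expansions described above, assuming integer column indexes.
# CUBE yields the full power set of its columns; GROUPING SETS uses the
# explicitly listed sets; regular columns are prepended to every set.
import itertools

def cube_grouping_sets(regular_indexes, cube_indexes):
    return [regular_indexes + list(combo)
            for r in range(len(cube_indexes), -1, -1)
            for combo in itertools.combinations(cube_indexes, r)]

def explicit_grouping_sets(regular_indexes, listed_sets):
    return [regular_indexes + s for s in listed_sets]

# GROUP BY x, CUBE (a, b)  with x -> 0, a -> 1, b -> 2:
assert cube_grouping_sets([0], [1, 2]) == [[0, 1, 2], [0, 1], [0, 2], [0]]
# GROUP BY GROUPING SETS ((a, b), (a), ())  with a -> 0, b -> 1:
assert explicit_grouping_sets([], [[0, 1], [0], []]) == [[0, 1], [0], []]
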
@@ -583,7 +659,12 @@ def _compile_group_by(self, group_by, c_targets): raise CompilationError(f'GROUP-BY a non-hashable type is not supported: "{column}"') # Add compiled index to the corresponding element - element_indexes[elem_idx]['indexes'].append(index) + # For GROUPING SETS, append to the specific set's list + # For other modifiers, append to the flat list + if set_idx is not None: + element_indexes[elem_idx]['indexes'][set_idx].append(index) + else: + element_indexes[elem_idx]['indexes'].append(index) # Compile HAVING clause. if group_by.having is not None: @@ -613,7 +694,7 @@ def _compile_group_by(self, group_by, c_targets): if not c_target.is_aggregate] # Wrap as individual elements element_indexes = [ - {'indexes': [idx], 'rollup': None} for idx in implicit_indexes + {'indexes': [idx], 'modifier': None} for idx in implicit_indexes ] else: raise CompilationError('aggregate query without a GROUP-BY should have only aggregates') diff --git a/beanquery/parser/bql.ebnf b/beanquery/parser/bql.ebnf index 2369d79a..0fc9ac18 100644 --- a/beanquery/parser/bql.ebnf +++ b/beanquery/parser/bql.ebnf @@ -59,21 +59,32 @@ table::Table # A GROUP BY clause supports multiple syntaxes: # 1. Regular GROUP BY: GROUP BY col1, col2 [HAVING condition] # 2. Full ROLLUP: GROUP BY ROLLUP (col1, col2) [HAVING condition] -# 3. Mixed grouping: GROUP BY col1, ROLLUP (col2, col3) [HAVING condition] +# 3. Full CUBE: GROUP BY CUBE (col1, col2) [HAVING condition] +# 4. GROUPING SETS: GROUP BY GROUPING SETS ((col1, col2), (col1), ()) [HAVING condition] +# 5. Mixed grouping: GROUP BY col1, ROLLUP (col2, col3) [HAVING condition] # # Examples: # - GROUP BY account, year # - GROUP BY account, year HAVING SUM(position) > 100 # - GROUP BY ROLLUP (account, year) # - GROUP BY ROLLUP (account, year) HAVING SUM(position) > 100 +# - GROUP BY CUBE (account, year) +# - GROUP BY GROUPING SETS ((account, year), (account), ()) # - GROUP BY region, ROLLUP (year, month) +# - GROUP BY region, CUBE (year, month) groupby::GroupBy = elements:','.{ grouping_element }+ ['HAVING' having:expression] ; grouping_element - = 'ROLLUP' '(' columns:','.{ (integer | expression) }+ ')' rollup:`True` - | column:(integer | expression) rollup:`` + = 'ROLLUP' '(' columns:','.{ (integer | expression) }+ ')' rollup:`True` cube:`` sets:`` + | 'CUBE' '(' columns:','.{ (integer | expression) }+ ')' rollup:`` cube:`True` sets:`` + | 'GROUPING' 'SETS' '(' grouping_sets:','.{ grouping_set }+ ')' rollup:`` cube:`` sets:`True` + | column:(integer | expression) rollup:`` cube:`` sets:`` + ; + +grouping_set + = '(' columns:','.{ (integer | expression) }* ')' ; order::OrderBy diff --git a/beanquery/parser/parser.py b/beanquery/parser/parser.py index 274b60e6..2478e59e 100644 --- a/beanquery/parser/parser.py +++ b/beanquery/parser/parser.py @@ -381,7 +381,59 @@ def block1(): self._token(')') self._constant(True) self.name_last_node('rollup') - self._define(['columns', 'rollup'], []) + self._constant('') + self.name_last_node('cube') + self._constant('') + self.name_last_node('sets') + self._define(['columns', 'cube', 'rollup', 'sets'], []) + with self._option(): + self._token('CUBE') + self._token('(') + + def sep2(): + self._token(',') + + def block3(): + with self._group(): + with self._choice(): + with self._option(): + self._integer_() + with self._option(): + self._expression_() + self._error( + 'expecting one of: ' + ' ' + ) + self._positive_gather(block3, sep2) + self.name_last_node('columns') + self._token(')') + self._constant('') + 
self.name_last_node('rollup') + self._constant(True) + self.name_last_node('cube') + self._constant('') + self.name_last_node('sets') + self._define(['columns', 'cube', 'rollup', 'sets'], []) + with self._option(): + self._token('GROUPING') + self._token('SETS') + self._token('(') + + def sep4(): + self._token(',') + + def block5(): + self._grouping_set_() + self._positive_gather(block5, sep4) + self.name_last_node('grouping_sets') + self._token(')') + self._constant('') + self.name_last_node('rollup') + self._constant('') + self.name_last_node('cube') + self._constant(True) + self.name_last_node('sets') + self._define(['cube', 'grouping_sets', 'rollup', 'sets'], []) with self._option(): with self._group(): with self._choice(): @@ -396,13 +448,41 @@ def block1(): self.name_last_node('column') self._constant('') self.name_last_node('rollup') - self._define(['column', 'rollup'], []) + self._constant('') + self.name_last_node('cube') + self._constant('') + self.name_last_node('sets') + self._define(['column', 'cube', 'rollup', 'sets'], []) self._error( 'expecting one of: ' - "'ROLLUP' " - ' [0-9]+' + "'CUBE' 'GROUPING' 'ROLLUP' " + ' ' + '[0-9]+' ) + @tatsumasu() + def _grouping_set_(self): + self._token('(') + + def sep0(): + self._token(',') + + def block1(): + with self._group(): + with self._choice(): + with self._option(): + self._integer_() + with self._option(): + self._expression_() + self._error( + 'expecting one of: ' + ' ' + ) + self._gather(block1, sep0) + self.name_last_node('columns') + self._token(')') + self._define(['columns'], []) + @tatsumasu('OrderBy') def _order_(self): with self._group(): From 352efcfefa14d55350fb6016bc25663d7e21116a Mon Sep 17 00:00:00 2001 From: Moritz Lell Date: Mon, 10 Nov 2025 16:03:38 +0100 Subject: [PATCH 07/10] Do cartesian product when combining ROLLUP, CUBE, GROUPING SET. Also, refactor _select(), remove _compile_select_base() to move closer to the original logic before ROLLUP, CUBE, etc. --- beanquery/compiler.py | 299 +++++++++++++++++++++++------------------- 1 file changed, 164 insertions(+), 135 deletions(-) diff --git a/beanquery/compiler.py b/beanquery/compiler.py index 46b9e275..94b636be 100644 --- a/beanquery/compiler.py +++ b/beanquery/compiler.py @@ -98,15 +98,11 @@ def _compile(self, node: Optional[ast.Node]): return None raise NotImplementedError - def _compile_select_base(self, node: ast.Select): - """Compile common parts of SELECT: FROM, targets, WHERE, GROUP BY, ORDER BY. + @_compile.register + def _select(self, node: ast.Select): + self.stack.append(self.table) - Args: - node: A Select AST node. - Returns: - Tuple of (c_targets, c_where, element_indexes, having_index, order_spec) - """ - # Compile FROM clause + # Compile the FROM clause c_from_expr = self._compile_from(node.from_clause) # Compile the targets. @@ -114,6 +110,10 @@ def _compile_select_base(self, node: ast.Select): # Bind the WHERE expression to the execution environment. c_where = self._compile(node.where_clause) + + # Check that the FROM clause does not contain aggregates. This + # should never trigger if the compilation environment does not + # contain any aggregate. 
if c_where is not None and is_aggregate(c_where): raise CompilationError('aggregates are not allowed in WHERE clause') @@ -129,50 +129,32 @@ def _compile_select_base(self, node: ast.Select): new_targets, order_spec = self._compile_order_by(node.order_by, c_targets) c_targets.extend(new_targets) - return c_targets, c_where, element_indexes, having_index, order_spec - - @_compile.register - def _select(self, node: ast.Select): - self.stack.append(self.table) - - # Handle ROLLUP/CUBE/GROUPING SETS queries separately - # Check if any grouping element has rollup=True, cube=True, or sets=True if node.group_by and node.group_by.elements: - has_rollup = any( - elem.get('rollup') for elem in node.group_by.elements - ) - has_cube = any( - elem.get('cube') for elem in node.group_by.elements - ) - has_sets = any( - elem.get('sets') for elem in node.group_by.elements - ) - if has_rollup: - result = self._compile_grouping_sets(node, 'rollup') - self.stack.pop() - return result - if has_cube: - result = self._compile_grouping_sets(node, 'cube') - self.stack.pop() - return result - if has_sets: - result = self._compile_grouping_sets(node, 'sets') - self.stack.pop() - return result + if any(elem.get('rollup') or elem.get('cube') or elem.get('sets') + for elem in node.group_by.elements): + is_grouping = "complex" + else: + is_grouping = "simple" + else: + is_grouping = "none" - # Compile common SELECT parts - c_targets, c_where, element_indexes, having_index, order_spec = self._compile_select_base(node) - - # There is no ROLLUP/CUBE clause, therefore all elements follow the - # format {'indexes': [x], 'modifier': None} - assert ( - element_indexes is None - or all(e['modifier'] is None for e in element_indexes) - ) + if is_grouping in ['none', 'simple'] and element_indexes is not None: + # Element indexes might be != None if we are grouping implicitly + # (only aggregate functions or explicitly) + # + # For simple grouping, there is no ROLLUP/CUBE clause, therefore all + # elements follow the format {'indexes': [x], 'modifier': None} + assert all(e['modifier'] is None for e in element_indexes) + + # Ensure all elements in element_indexes[x]['indexes'] are of type int. + assert all(isinstance(idx, int) + for elem in element_indexes for idx in elem['indexes']) # Flatten element_indexes for regular GROUP BY if element_indexes is not None: - group_indexes = [elem['indexes'][0] for elem in element_indexes] + group_indexes = set() + for elem in element_indexes: + group_indexes.update(elem['indexes']) else: group_indexes = None @@ -181,7 +163,7 @@ def _select(self, node: ast.Select): # always be the case at this point, because we have added all the necessary # targets to the list of group-by expressions and should have resolved all # the indexes. - if group_indexes is not None: + if is_grouping != 'none': non_aggregate_indexes = {index for index, c_target in enumerate(c_targets) if not c_target.is_aggregate} if non_aggregate_indexes != set(group_indexes): @@ -191,15 +173,26 @@ def _select(self, node: ast.Select): 'all non-aggregates must be covered by GROUP-BY clause in aggregate query: ' 'the following targets are missing: {}'.format(','.join(missing_names))) - query = EvalQuery(self.table, - c_targets, - c_where, - group_indexes, - having_index, - order_spec, - node.limit, - node.distinct) + # Handle ROLLUP/CUBE/GROUPING SETS queries separately: + # Check if any grouping element has such a modifier. 
+ if is_grouping == "complex": + + # Obtain the UnionEval node and the flattened list of group indexes + query, group_indexes = self._compile_grouping_sets(node, c_targets, c_where, element_indexes, having_index, order_spec) + + else: # grouping in ['none', 'simple' ] + + + query = EvalQuery(self.table, + c_targets, + c_where, + group_indexes, + having_index, + order_spec, + node.limit, + node.distinct) + pivots = self._compile_pivot_by(node.pivot_by, c_targets, group_indexes) if pivots: return EvalPivot(query, pivots) @@ -208,75 +201,59 @@ def _select(self, node: ast.Select): return query - def _compile_grouping_sets(self, node: ast.Select, grouping_type): - """Compile ROLLUP/CUBE/GROUPING SETS query as a union of grouping sets. - - Supports full and mixed grouping: + def _compile_grouping_sets(self, node: ast.Select, c_targets, c_where, element_indexes, having_index, order_spec): + """Compile a query with complex grouping (ROLLUP, CUBE, SETS) as a union. + Supports full and mixed grouping: - ROLLUP: GROUP BY ROLLUP (a, b, c) → grouping sets: [(a,b,c), (a,b), (a), ()] - ROLLUP: GROUP BY x, ROLLUP (a, b) → grouping sets: [(x,a,b), (x,a), (x)] - CUBE: GROUP BY CUBE (a, b) → grouping sets: [(a,b), (a), (b), ()] - CUBE: GROUP BY x, CUBE (a, b) → grouping sets: [(x,a,b), (x,a), (x,b), (x)] - SETS: GROUP BY GROUPING SETS ((a, b), (a), ()) → grouping sets: [(a,b), (a), ()] - - ROLLUP generates hierarchical grouping sets (prefixes). - CUBE generates all possible combinations (power set) of the columns. - SETS uses explicitly specified grouping sets. - + + This function handles standard-compliant mixed grouping by: + 1. Separating simple grouping columns from complex ones (ROLLUP, etc.). + 2. Generating grouping sets for each complex element. + 3. Combining the sets from complex elements using a cartesian product. + 4. Prepending the simple grouping columns to each resulting set. + Args: - node: A Select AST node with at least one rollup/cube/sets element in group_by. - grouping_type: Either 'rollup', 'cube', or 'sets'. + node: A Select AST node with one or more complex grouping elements. + Returns: - An EvalUnion that executes multiple queries for each grouping set. + A tuple of (EvalUnion, list of group indexes). The EvalUnion node + executes a query for each final grouping set. The list of group + indexes has the indexes of the unique columns used in the GROUP BY + clause. """ - from .query_compile import EvalUnion, EvalPivot - import itertools - - # Compile common SELECT parts - c_targets, c_where, element_indexes, having_index, order_spec = self._compile_select_base(node) - - # Separate regular columns from ROLLUP/CUBE/SETS columns using element - # structure. Columns are represented by their numerical indexes as - # determined by _compile_select_base(). - regular_indexes = [] - rollup_indexes = [] # Will be a flat list for ROLLUP/CUBE, list of lists for SETS - + # Separate simple and complex grouping elements + # Why: Simple columns are treated as a prefix for all grouping sets, + # while complex elements (ROLLUP, CUBE, SETS) generate multiple sets + # that need to be combined. 
+ simple_indexes = [] + complex_elements = [] for elem in element_indexes: - if elem['modifier'] == grouping_type: - # For GROUPING SETS, 'indexes' is a list of lists - # For ROLLUP/CUBE, 'indexes' is a flat list - rollup_indexes = elem['indexes'] + if elem['modifier'] is None: + simple_indexes.extend(elem['indexes']) else: - # Regular element - regular_indexes.extend(elem['indexes']) - - # Generate grouping sets based on type - if grouping_type == 'rollup': - # Hierarchical: [(a,b,c), (a,b), (a), ()] - grouping_sets = [] - for i in range(len(rollup_indexes), -1, -1): - # Combine regular columns (always present) with special columns (hierarchical) - grouping_set = regular_indexes + rollup_indexes[:i] - grouping_sets.append(grouping_set) - elif grouping_type == 'cube': - # All combinations: power set - # For CUBE (a, b, c), generate all 2^3 = 8 combinations: - # [(a,b,c), (a,b), (a,c), (a), (b,c), (b), (c), ()] - grouping_sets = [] - for r in range(len(rollup_indexes), -1, -1): - for combo in itertools.combinations(rollup_indexes, r): - # Combine regular columns (always present) with special column combination - grouping_set = regular_indexes + list(combo) - grouping_sets.append(grouping_set) - else: # sets - # Explicit grouping sets specified by user - # rollup_indexes is already a list of lists - grouping_sets = [] - for set_indexes in rollup_indexes: - # Combine regular columns with this grouping set - grouping_set = regular_indexes + set_indexes - grouping_sets.append(grouping_set) - - # Create a query for each grouping set + complex_elements.append(elem) + + # Generate and combine grouping sets from complex elements + # Why: We iterate through the complex elements, generate the grouping + # sets for each, and combine them using a cartesian product to handle + # mixed grouping constructs like `GROUP BY ROLLUP(...), CUBE(...)`. + combined_sets = [[]] # Start with an empty set for the initial product + for elem in complex_elements: + element_sets = _get_grouping_sets_for_element(elem) + combined_sets = _combine_grouping_sets(combined_sets, element_sets) + + # Prepend simple columns to all generated sets + # Why: The simple GROUP BY columns must be included in every grouping + # set generated by the complex elements. + final_grouping_sets = [simple_indexes + s for s in combined_sets] + + # Create a query for each final grouping set + # Why: The result of a complex grouping query is the UNION of the + # results of running a separate query for each individual grouping set. queries = [ EvalQuery(self.table, c_targets, @@ -286,32 +263,22 @@ def _compile_grouping_sets(self, node: ast.Select, grouping_type): None, None, node.distinct) - for grouping_set in grouping_sets + for grouping_set in final_grouping_sets ] - - # Create union of all grouping set queries + + # Wrap the individual queries in a UNION operator. 
union = EvalUnion( queries=queries, - rollup_sets=grouping_sets, + rollup_sets=final_grouping_sets, order_spec=order_spec, limit=node.limit ) - - # Flatten element_indexes for PIVOT BY compilation - # For GROUPING SETS, rollup_indexes is a list of lists, so flatten it - if grouping_type == 'sets': - # Flatten the list of lists and get unique indexes - flat_indexes = list(set(idx for set_indexes in rollup_indexes for idx in set_indexes)) - group_indexes = regular_indexes + flat_indexes - else: - group_indexes = regular_indexes + rollup_indexes - - # Handle PIVOT BY if present - pivots = self._compile_pivot_by(node.pivot_by, c_targets, group_indexes) - if pivots: - return EvalPivot(union, pivots) - - return union + + # List of unique column indexes used in the GROUP BY clause + all_group_by_indexes = list(set(idx for s in final_grouping_sets for idx in s)) + + return union, all_group_by_indexes + def _compile_from(self, node): if node is None: @@ -1139,5 +1106,67 @@ def is_aggregate(node): return bool(aggregates) + +def _combine_grouping_sets(list_of_sets1, list_of_sets2): + """Compute the cartesian product of two lists of grouping sets. + + >>> _combine_grouping_sets([['a'], ['b']], [['c'], ['d']]) + [['a', 'c'], ['a', 'd'], ['b', 'c'], ['b', 'd']] + >>> _combine_grouping_sets([['a']], [['b', 'c'], []]) + [['a', 'b', 'c'], ['a']] + >>> _combine_grouping_sets([], [['a']]) + [['a']] + >>> _combine_grouping_sets([['a']], []) + [['a']] + """ + # Why: This helper function is used to combine grouping sets from different + # grouping elements (e.g., `ROLLUP` and `CUBE`) by creating a + # cartesian product of their individual grouping sets. + import itertools + if not list_of_sets1: + return list_of_sets2 + if not list_of_sets2: + return list_of_sets1 + + return [s1 + s2 for s1, s2 in itertools.product(list_of_sets1, list_of_sets2)] + +def _get_grouping_sets_for_element(element): + """Generate grouping sets for a single GROUP BY element. + This function isolates the logic for generating grouping sets based on + the element's modifier (`rollup`, `cube`, `sets`). + + Args: + element (dict): A dictionary representing a grouping element, which contains: + - 'modifier' (str): The type of grouping modifier ('rollup', 'cube', 'sets', or None). + - 'indexes' (list): A list of integer indexes representing the grouping columns. + + Returns: + list: A list of lists, where each inner list represents a grouping set. 
+ + Example: + >>> element = {'modifier': 'rollup', 'indexes': [0, 1]} + >>> _get_grouping_sets_for_element(element) + [[0, 1], [0], []] + """ + modifier = element['modifier'] + indexes = element['indexes'] + + if modifier == 'rollup': + # Hierarchical prefixes: e.g., (a,b) -> [(a,b), (a), ()] + return [indexes[:i] for i in range(len(indexes), -1, -1)] + elif modifier == 'cube': + # Power set: e.g., (a,b) -> [(a,b), (a), (b), ()] + sets = [] + for i in range(len(indexes), -1, -1): + for combo in itertools.combinations(indexes, i): + sets.append(list(combo)) + return sets + elif modifier == 'sets': + # User-defined sets + return indexes + else: # Regular column + return [indexes] + + def compile(context, statement, parameters=None): return Compiler(context).compile(statement, parameters) From c90f2645914f503b7319b46c08e6062191fe347a Mon Sep 17 00:00:00 2001 From: Moritz Lell Date: Mon, 8 Dec 2025 11:58:09 +0100 Subject: [PATCH 08/10] Replace separate rollup/cube/sets flags with unified 'type' field in AST --- beanquery/compiler.py | 21 +++++++++++---------- beanquery/parser/bql.ebnf | 8 ++++---- beanquery/parser/parser.py | 38 +++++++++++--------------------------- 3 files changed, 26 insertions(+), 41 deletions(-) diff --git a/beanquery/compiler.py b/beanquery/compiler.py index 94b636be..e17c6f9a 100644 --- a/beanquery/compiler.py +++ b/beanquery/compiler.py @@ -130,7 +130,7 @@ def _select(self, node: ast.Select): c_targets.extend(new_targets) if node.group_by and node.group_by.elements: - if any(elem.get('rollup') or elem.get('cube') or elem.get('sets') + if any(elem.get('type') in ('rollup', 'cube', 'sets') for elem in node.group_by.elements): is_grouping = "complex" else: @@ -152,9 +152,9 @@ def _select(self, node: ast.Select): # Flatten element_indexes for regular GROUP BY if element_indexes is not None: - group_indexes = set() + group_indexes = [] for elem in element_indexes: - group_indexes.update(elem['indexes']) + group_indexes.extend(elem['indexes']) else: group_indexes = None @@ -490,6 +490,7 @@ def _compile_group_by(self, group_by, c_targets): [{'indexes': [int, ...], 'modifier': str or None}, ...] Each dict represents one grouping element from the grammar. 'modifier' can be None, 'rollup', 'cube', or 'sets'. + Note: The 'type' field in the AST element is used to determine the modifier. Examples: - Non-aggregate query: None @@ -531,14 +532,14 @@ def _compile_group_by(self, group_by, c_targets): # Initialize element structures for elem in group_by.elements: - # Iterating over GROUP BY syntax elements, which are either a - # simple grouping column/expression, a ROLLUP (col1, ...) + # Iterating over GROUP BY syntax elements, which are either a + # simple grouping column/expression, a ROLLUP (col1, ...) # element, a CUBE (col1, ...) element, or a GROUPING SETS element. 
- if elem.get('rollup'): + if elem.get('type') == 'rollup': modifier = 'rollup' - elif elem.get('cube'): + elif elem.get('type') == 'cube': modifier = 'cube' - elif elem.get('sets'): + elif elem.get('type') == 'sets': modifier = 'sets' else: modifier = None @@ -562,11 +563,11 @@ def _compile_group_by(self, group_by, c_targets): # For GROUPING SETS, also track which set within the element columns_by_element = [] for elem_idx, elem in enumerate(group_by.elements): - if elem.get('rollup') or elem.get('cube'): + if elem.get('type') in ('rollup', 'cube'): columns = elem['columns'] for column in columns: columns_by_element.append((elem_idx, None, column)) - elif elem.get('sets'): + elif elem.get('type') == 'sets': # For GROUPING SETS, track which set each column belongs to for set_idx, grouping_set in enumerate(elem['grouping_sets']): for column in grouping_set['columns']: diff --git a/beanquery/parser/bql.ebnf b/beanquery/parser/bql.ebnf index 0fc9ac18..b915451d 100644 --- a/beanquery/parser/bql.ebnf +++ b/beanquery/parser/bql.ebnf @@ -77,10 +77,10 @@ groupby::GroupBy ; grouping_element - = 'ROLLUP' '(' columns:','.{ (integer | expression) }+ ')' rollup:`True` cube:`` sets:`` - | 'CUBE' '(' columns:','.{ (integer | expression) }+ ')' rollup:`` cube:`True` sets:`` - | 'GROUPING' 'SETS' '(' grouping_sets:','.{ grouping_set }+ ')' rollup:`` cube:`` sets:`True` - | column:(integer | expression) rollup:`` cube:`` sets:`` + = 'ROLLUP' '(' columns:','.{ (integer | expression) }+ ')' type:`rollup` + | 'CUBE' '(' columns:','.{ (integer | expression) }+ ')' type:`cube` + | 'GROUPING' 'SETS' '(' grouping_sets:','.{ grouping_set }+ ')' type:`sets` + | column:(integer | expression) type:`` ; grouping_set diff --git a/beanquery/parser/parser.py b/beanquery/parser/parser.py index 2478e59e..699dac8a 100644 --- a/beanquery/parser/parser.py +++ b/beanquery/parser/parser.py @@ -379,13 +379,9 @@ def block1(): self._positive_gather(block1, sep0) self.name_last_node('columns') self._token(')') - self._constant(True) - self.name_last_node('rollup') - self._constant('') - self.name_last_node('cube') - self._constant('') - self.name_last_node('sets') - self._define(['columns', 'cube', 'rollup', 'sets'], []) + self._constant('rollup') + self.name_last_node('type') + self._define(['columns', 'type'], []) with self._option(): self._token('CUBE') self._token('(') @@ -407,13 +403,9 @@ def block3(): self._positive_gather(block3, sep2) self.name_last_node('columns') self._token(')') - self._constant('') - self.name_last_node('rollup') - self._constant(True) - self.name_last_node('cube') - self._constant('') - self.name_last_node('sets') - self._define(['columns', 'cube', 'rollup', 'sets'], []) + self._constant('cube') + self.name_last_node('type') + self._define(['columns', 'type'], []) with self._option(): self._token('GROUPING') self._token('SETS') @@ -427,13 +419,9 @@ def block5(): self._positive_gather(block5, sep4) self.name_last_node('grouping_sets') self._token(')') - self._constant('') - self.name_last_node('rollup') - self._constant('') - self.name_last_node('cube') - self._constant(True) - self.name_last_node('sets') - self._define(['cube', 'grouping_sets', 'rollup', 'sets'], []) + self._constant('sets') + self.name_last_node('type') + self._define(['grouping_sets', 'type'], []) with self._option(): with self._group(): with self._choice(): @@ -447,12 +435,8 @@ def block5(): ) self.name_last_node('column') self._constant('') - self.name_last_node('rollup') - self._constant('') - self.name_last_node('cube') - 
self._constant('') - self.name_last_node('sets') - self._define(['column', 'cube', 'rollup', 'sets'], []) + self.name_last_node('type') + self._define(['column', 'type'], []) self._error( 'expecting one of: ' "'CUBE' 'GROUPING' 'ROLLUP' " From 7acdab6a151cbfa38ca1830decbc9b977bca0821 Mon Sep 17 00:00:00 2001 From: Moritz Lell Date: Mon, 8 Dec 2025 13:35:14 +0100 Subject: [PATCH 09/10] Add tests for GROUP BY CUBE / ROLLUP --- beanquery/parser_test.py | 53 ++++++++++++++ beanquery/query_execute_test.py | 120 ++++++++++++++++++++++++++++++++ 2 files changed, 173 insertions(+) diff --git a/beanquery/parser_test.py b/beanquery/parser_test.py index 33f47fdb..d00e2c54 100644 --- a/beanquery/parser_test.py +++ b/beanquery/parser_test.py @@ -355,6 +355,59 @@ def test_groupby_numbers(self): def test_groupby_empty(self): with self.assertRaises(parser.ParseError): parser.parse("SELECT * GROUP BY;") + + def test_groupby_rollup(self): + """Test ROLLUP syntax in GROUP BY clause.""" + self.assertParse( + "SELECT * GROUP BY ROLLUP (account, year);", + Select(ast.Asterisk(), + group_by=ast.GroupBy([ + {'columns': [ast.Column('account'), ast.Column('year')], 'type': 'rollup'} + ], None))) + + def test_groupby_cube(self): + """Test CUBE syntax in GROUP BY clause.""" + self.assertParse( + "SELECT * GROUP BY CUBE (account, year);", + Select(ast.Asterisk(), + group_by=ast.GroupBy([ + {'columns': [ast.Column('account'), ast.Column('year')], 'type': 'cube'} + ], None))) + + def test_groupby_grouping_sets(self): + """Test GROUPING SETS syntax in GROUP BY clause.""" + self.assertParse( + "SELECT * GROUP BY GROUPING SETS ((account, year), (account), ());", + Select(ast.Asterisk(), + group_by=ast.GroupBy([ + {'grouping_sets': [ + [ast.Column('account'), ast.Column('year')], + [ast.Column('account')], + [] + ], 'type': 'sets'} + ], None))) + + def test_groupby_mixed(self): + """Test mixed grouping elements in GROUP BY clause.""" + self.assertParse( + "SELECT * GROUP BY region, ROLLUP (year, month);", + Select(ast.Asterisk(), + group_by=ast.GroupBy([ + {'column': ast.Column('region'), 'type': ''}, + {'columns': [ast.Column('year'), ast.Column('month')], 'type': 'rollup'} + ], None))) + + def test_groupby_rollup_with_having(self): + """Test ROLLUP syntax with HAVING clause.""" + self.assertParse( + "SELECT * GROUP BY ROLLUP (account, year) HAVING sum(position) > 100;", + Select(ast.Asterisk(), + group_by=ast.GroupBy([ + {'columns': [ast.Column('account'), ast.Column('year')], 'type': 'rollup'} + ], + ast.Greater( + ast.Function('sum', [ast.Column('position')]), + ast.Constant(100))))) class TestSelectOrderBy(QueryParserTestBase): diff --git a/beanquery/query_execute_test.py b/beanquery/query_execute_test.py index 0358d632..5ad0cdef 100644 --- a/beanquery/query_execute_test.py +++ b/beanquery/query_execute_test.py @@ -1111,6 +1111,126 @@ def test_aggregated_group_by_with_having(self): ('Expenses:Bar', D(2.0)), ('Expenses:Foo', D(1.0)), ]) + def test_rollup_basic(self): + """Test ROLLUP functionality with a simple example.""" + self.check_query( + """ + 2010-02-21 * "First" + Assets:Bank:Checking -1.00 USD + Expenses:Restaurant 1.00 USD + + 2010-02-23 * "Second" + Liabilities:Credit-Card -2.00 USD + Expenses:Restaurant 2.00 USD + """, + """ + SELECT account, sum(number) as amount + GROUP BY ROLLUP (account) + ORDER BY account; + """, + [ + ('account', str), + ('amount', Decimal), + ], + [ + ('Assets:Bank:Checking', D('-1.00')), + ('Expenses:Restaurant', D('3.00')), + ('Liabilities:Credit-Card', D('-2.00')), + 
(qc.Sentinel(2, '(Total)'), D('0.00')), # Total row + ]) + + def test_cube_basic(self): + """Test CUBE functionality with a simple example.""" + self.check_query( + """ + 2010-02-21 * "First" + Assets:Bank:Checking -1.00 USD + Expenses:Restaurant 1.00 USD + + 2010-02-23 * "Second" + Liabilities:Credit-Card -2.00 USD + Expenses:Restaurant 2.00 USD + """, + """ + SELECT account, sum(number) as amount + GROUP BY CUBE (account) + ORDER BY account; + """, + [ + ('account', str), + ('amount', Decimal), + ], + [ + ('Assets:Bank:Checking', D('-1.00')), + ('Expenses:Restaurant', D('3.00')), + ('Liabilities:Credit-Card', D('-2.00')), + (qc.Sentinel(2, '(Total)'), D('0.00')), # Total row + ]) + + def test_rollup_two_columns(self): + """Test ROLLUP with two columns.""" + self.check_query( + """ + 2010-02-21 * "First" + Assets:Bank:Checking -1.00 USD + Expenses:Restaurant 1.00 USD + + 2010-02-23 * "Second" + Liabilities:Credit-Card -2.00 USD + Expenses:Restaurant 2.00 USD + """, + """ + SELECT account, year(date) as year, sum(number) as amount + GROUP BY ROLLUP (account, year(date)) + ORDER BY account, year; + """, + [ + ('account', str), + ('year', int), + ('amount', Decimal), + ], + [ + ('Assets:Bank:Checking', 2010, D('-1.00')), + ('Assets:Bank:Checking', qc.Sentinel(2, '(Total)'), D('-1.00')), # Subtotal for account + ('Expenses:Restaurant', 2010, D('3.00')), + ('Expenses:Restaurant', qc.Sentinel(2, '(Total)'), D('3.00')), # Subtotal for account + ('Liabilities:Credit-Card', 2010, D('-2.00')), + ('Liabilities:Credit-Card', qc.Sentinel(2, '(Total)'), D('-2.00')), # Subtotal for account + (qc.Sentinel(2, '(Total)'), qc.Sentinel(2, '(Total)'), D('0.00')), # Grand total + ]) + + def test_cube_two_columns(self): + """Test CUBE with two columns.""" + self.check_query( + """ + 2010-02-21 * "First" + Assets:Bank:Checking -1.00 USD + Expenses:Restaurant 1.00 USD + + 2010-02-23 * "Second" + Liabilities:Credit-Card -2.00 USD + Expenses:Restaurant 2.00 USD + """, + """ + SELECT account, year(date) as year, sum(number) as amount + GROUP BY CUBE (account, year(date)) + ORDER BY account, year; + """, + [ + ('account', str), + ('year', int), + ('amount', Decimal), + ], + [ + ('Assets:Bank:Checking', 2010, D('-1.00')), + ('Assets:Bank:Checking', qc.Sentinel(2, '(Total)'), D('-1.00')), # Subtotal for account + ('Expenses:Restaurant', 2010, D('3.00')), + ('Expenses:Restaurant', qc.Sentinel(2, '(Total)'), D('3.00')), # Subtotal for account + ('Liabilities:Credit-Card', 2010, D('-2.00')), + ('Liabilities:Credit-Card', qc.Sentinel(2, '(Total)'), D('-2.00')), # Subtotal for account + (qc.Sentinel(2, '(Total)'), 2010, D('0.00')), # Subtotal for year + (qc.Sentinel(2, '(Total)'), qc.Sentinel(2, '(Total)'), D('0.00')), # Grand total + ]) class TestExecuteOptions(QueryBase): From 16f7ceed96f91abe18b43caaa916e25aa621f6f8 Mon Sep 17 00:00:00 2001 From: Moritz Lell Date: Tue, 9 Dec 2025 18:53:22 +0100 Subject: [PATCH 10/10] Refactor BQL grammar to standardize GROUP BY modifier handling - In AST: Change 'sets' modifier to 'grouping sets' for consistency - Update AST structure to use GroupByElement for consistency with other language constructs - Update tests to reflect new structure --- beanquery/compiler.py | 63 +++++++++++--------- beanquery/parser/ast.py | 15 +++-- beanquery/parser/bql.ebnf | 14 ++--- beanquery/parser/parser.py | 18 +++--- beanquery/parser_test.py | 101 +++++++++++++++++--------------- beanquery/query_compile_test.py | 9 ++- 6 files changed, 123 insertions(+), 97 deletions(-) diff --git 
a/beanquery/compiler.py b/beanquery/compiler.py
index e17c6f9a..0141ceaf 100644
--- a/beanquery/compiler.py
+++ b/beanquery/compiler.py
@@ -122,6 +122,11 @@ def _select(self, node: ast.Select):
             c_where = c_from_expr if c_where is None else EvalAnd([c_from_expr, c_where])
 
         # Process the GROUP BY clause.
+        # Returns, among others, `element_indexes`: a list of dicts of the
+        # form {'indexes': ..., 'modifier': ...}, one per grouping element.
+        # 'modifier' names the GROUP BY modifier used ('rollup', 'cube',
+        # 'grouping sets') or is None if no keyword is used; 'indexes' is a
+        # list of column indexes, or a list of lists for 'grouping sets'.
         new_targets, element_indexes, having_index = self._compile_group_by(node.group_by, c_targets)
         c_targets.extend(new_targets)
 
@@ -129,9 +134,11 @@ def _select(self, node: ast.Select):
         new_targets, order_spec = self._compile_order_by(node.order_by, c_targets)
         c_targets.extend(new_targets)
 
+        # For complex grouping we have not just a list of column names but
+        # grouping elements, which are compiled separately further below.
         if node.group_by and node.group_by.elements:
-            if any(elem.get('type') in ('rollup', 'cube', 'sets')
-                   for elem in node.group_by.elements):
+            if any(elem['modifier'] in ('rollup', 'cube', 'grouping sets')
+                   for elem in element_indexes):
                 is_grouping = "complex"
             else:
                 is_grouping = "simple"
@@ -489,8 +496,8 @@ def _compile_group_by(self, group_by, c_targets):
           element_indexes: A list of dicts, one per grouping element:
             [{'indexes': [int, ...], 'modifier': str or None}, ...]
             Each dict represents one grouping element from the grammar.
-            'modifier' can be None, 'rollup', 'cube', or 'sets'.
-            Note: The 'type' field in the AST element is used to determine the modifier.
+            'modifier' can be None, 'rollup', 'cube', or 'grouping sets'.
+            Note: The 'type' property in the AST element is used to determine the modifier.
 
           Examples:
             - Non-aggregate query: None
@@ -498,15 +505,15 @@ def _compile_group_by(self, group_by, c_targets):
             - Regular GROUP BY account, year:
               [{'indexes': [0], 'modifier': None}, {'indexes': [1], 'modifier': None}]
             - Full ROLLUP: GROUP BY ROLLUP (account, year):
-              [{'indexes': [0, 1], 'modifier': 'rollup', 'grouping_sets': None}]
+              [{'indexes': [0, 1], 'modifier': 'rollup'}]
             - Full CUBE: GROUP BY CUBE (account, year):
-              [{'indexes': [0, 1], 'modifier': 'cube', 'grouping_sets': None}]
+              [{'indexes': [0, 1], 'modifier': 'cube'}]
             - GROUPING SETS: GROUP BY GROUPING SETS ((account, year), (account), ()):
-              [{'indexes': [[0, 1], [0], []], 'modifier': 'sets', 'grouping_sets': [...]}]
+              [{'indexes': [[0, 1], [0], []], 'modifier': 'grouping sets'}]
             - Mixed grouping: GROUP BY region, ROLLUP (year, month):
-              [{'indexes': [2], 'modifier': None, 'grouping_sets': None}, {'indexes': [0, 1], 'modifier': 'rollup', 'grouping_sets': None}]
+              [{'indexes': [2], 'modifier': None}, {'indexes': [0, 1], 'modifier': 'rollup'}]
             - Implicit GROUP BY (when SUPPORT_IMPLICIT_GROUPBY=True):
-              [{'indexes': [0], 'modifier': None, 'grouping_sets': None}, {'indexes': [2], 'modifier': None, 'grouping_sets': None}]
+              [{'indexes': [0], 'modifier': None}, {'indexes': [2], 'modifier': None}]
 
           having_index: Index of HAVING expression in targets, or None.
         """
@@ -535,45 +542,45 @@ def _compile_group_by(self, group_by, c_targets):
             # Iterating over GROUP BY syntax elements, which are either a
             # simple grouping column/expression, a ROLLUP (col1, ...)
             # element, a CUBE (col1, ...) element, or a GROUPING SETS element.
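+            # Illustrative shapes (using the GroupByElement node introduced
+            # in parser/ast.py by this patch; expressions abbreviated):
+            #   account                     -> GroupByElement(Column('account'), '')
+            #   ROLLUP (a, b)               -> GroupByElement([a, b], 'rollup')
+            #   CUBE (a, b)                 -> GroupByElement([a, b], 'cube')
+            #   GROUPING SETS ((a, b), ())  -> GroupByElement([[a, b], []], 'grouping sets')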
- if elem.get('type') == 'rollup': + if elem.type == 'rollup': modifier = 'rollup' - elif elem.get('type') == 'cube': + elif elem.type == 'cube': modifier = 'cube' - elif elem.get('type') == 'sets': - modifier = 'sets' + elif elem.type == 'grouping sets': + modifier = 'grouping sets' else: modifier = None # For GROUPING SETS, 'indexes' will be a list of lists # For other modifiers, 'indexes' is a flat list - if modifier == 'sets': + if modifier == 'grouping sets': element_indexes.append({ - 'indexes': [[] for _ in elem['grouping_sets']], - 'modifier': modifier, - 'grouping_sets': elem.get('grouping_sets') + 'indexes': [[] for _ in elem.columns], + 'modifier': modifier }) else: element_indexes.append({ 'indexes': [], - 'modifier': modifier, - 'grouping_sets': None + 'modifier': modifier }) # Collect all columns with their syntax element position # For GROUPING SETS, also track which set within the element columns_by_element = [] for elem_idx, elem in enumerate(group_by.elements): - if elem.get('type') in ('rollup', 'cube'): - columns = elem['columns'] + if elem.type in ('rollup', 'cube'): + columns = elem.columns for column in columns: columns_by_element.append((elem_idx, None, column)) - elif elem.get('type') == 'sets': + elif elem.type == 'grouping sets': # For GROUPING SETS, track which set each column belongs to - for set_idx, grouping_set in enumerate(elem['grouping_sets']): - for column in grouping_set['columns']: + for set_idx, grouping_set in enumerate(elem.columns): + for column in grouping_set.columns: columns_by_element.append((elem_idx, set_idx, column)) else: - columns_by_element.append((elem_idx, None, elem['column'])) + # Simple grouping (no modifier) + assert elem.type == '' + columns_by_element.append((elem_idx, None, elem.columns)) # Compile all columns and add the indexes to element_indexes, # to return the same structure as in the parsed GROUP BY clause. @@ -1134,11 +1141,11 @@ def _combine_grouping_sets(list_of_sets1, list_of_sets2): def _get_grouping_sets_for_element(element): """Generate grouping sets for a single GROUP BY element. This function isolates the logic for generating grouping sets based on - the element's modifier (`rollup`, `cube`, `sets`). + the element's modifier (`rollup`, `cube`, `grouping sets`). Args: element (dict): A dictionary representing a grouping element, which contains: - - 'modifier' (str): The type of grouping modifier ('rollup', 'cube', 'sets', or None). + - 'modifier' (str): The type of grouping modifier ('rollup', 'cube', 'grouping sets', or None). - 'indexes' (list): A list of integer indexes representing the grouping columns. Returns: @@ -1162,7 +1169,7 @@ def _get_grouping_sets_for_element(element): for combo in itertools.combinations(indexes, i): sets.append(list(combo)) return sets - elif modifier == 'sets': + elif modifier == 'grouping sets': # User-defined sets return indexes else: # Regular column diff --git a/beanquery/parser/ast.py b/beanquery/parser/ast.py index 05a62350..de0feb61 100644 --- a/beanquery/parser/ast.py +++ b/beanquery/parser/ast.py @@ -148,13 +148,20 @@ class From(Node): # A GROUP BY clause. # # Attributes: -# elements: A list of grouping elements. Each element is a dict with: -# - 'column': A single column expression (for regular grouping), or -# - 'columns': A list of columns (for ROLLUP grouping) -# - 'rollup': True if this element is a ROLLUP, None otherwise +# elements: A list of grouping elements. See GroupByElement. # having: An expression tree for the optional HAVING clause, or None. 
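+#
+# Example (abbreviated; see parser_test.py in this patch):
+#   "GROUP BY region, ROLLUP (year, month)" parses to
+#   GroupBy([GroupByElement(Column('region'), ''),
+#            GroupByElement([Column('year'), Column('month')], 'rollup')], None)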
GroupBy = node('GroupBy', 'elements having') +# A GROUP BY grouping element. +# Attributes: +# columns: If type == '': ast.Column() or an integer column index. If +# type != '': A list of ast.Column() or integer column indexes. +# type: Distinguishes the grouping modes relating to the keywords +# ROLLUP, CUBE, GROUPING SETS. 'type' has the keyword in lower case. For +# simple grouping (no keyword), 'type' is the empty string. +# +GroupByElement = node('GroupByElement', 'columns type') + # An ORDER BY clause. # # Attributes: diff --git a/beanquery/parser/bql.ebnf b/beanquery/parser/bql.ebnf index b915451d..bffbbda3 100644 --- a/beanquery/parser/bql.ebnf +++ b/beanquery/parser/bql.ebnf @@ -1,9 +1,9 @@ @@grammar :: BQL @@parseinfo :: True @@ignorecase :: True -@@keyword :: 'AND' 'AS' 'ASC' 'BY' 'DESC' 'DISTINCT' 'FALSE' 'FROM' - 'GROUP' 'HAVING' 'IN' 'IS' 'LIMIT' 'NOT' 'OR' 'ORDER' 'PIVOT' - 'ROLLUP' 'SELECT' 'TRUE' 'WHERE' +@@keyword :: 'AND' 'AS' 'ASC' 'BY' 'CUBE' 'DESC' 'DISTINCT' 'FALSE' 'FROM' + 'GROUP' 'GROUPING' 'HAVING' 'IN' 'IS' 'LIMIT' 'NOT' 'OR' 'ORDER' 'PIVOT' + 'ROLLUP' 'SELECT' 'SETS' 'TRUE' 'WHERE' @@keyword :: 'CREATE' 'TABLE' 'USING' 'INSERT' 'INTO' @@keyword :: 'BALANCES' 'JOURNAL' 'PRINT' @@comments :: /(\/\*([^*]|[\r\n]|(\*+([^*\/]|[\r\n])))*\*+\/)/ @@ -76,15 +76,15 @@ groupby::GroupBy = elements:','.{ grouping_element }+ ['HAVING' having:expression] ; -grouping_element +grouping_element::GroupByElement = 'ROLLUP' '(' columns:','.{ (integer | expression) }+ ')' type:`rollup` | 'CUBE' '(' columns:','.{ (integer | expression) }+ ')' type:`cube` - | 'GROUPING' 'SETS' '(' grouping_sets:','.{ grouping_set }+ ')' type:`sets` - | column:(integer | expression) type:`` + | 'GROUPING' 'SETS' '(' columns:','.{ grouping_set }+ ')' type:`grouping sets` + | columns: (integer | expression) type:`` ; grouping_set - = '(' columns:','.{ (integer | expression) }* ')' + = '(' @:','.{ (integer | expression) }* ')' ; order::OrderBy diff --git a/beanquery/parser/parser.py b/beanquery/parser/parser.py index 699dac8a..e37ecfd5 100644 --- a/beanquery/parser/parser.py +++ b/beanquery/parser/parser.py @@ -29,11 +29,13 @@ 'AS', 'ASC', 'BY', + 'CUBE', 'DESC', 'DISTINCT', 'FALSE', 'FROM', 'GROUP', + 'GROUPING', 'HAVING', 'IN', 'IS', @@ -44,6 +46,7 @@ 'PIVOT', 'ROLLUP', 'SELECT', + 'SETS', 'TRUE', 'WHERE', 'CREATE', @@ -355,7 +358,7 @@ def block1(): self._define(['having'], []) self._define(['elements', 'having'], []) - @tatsumasu() + @tatsumasu('GroupByElement') def _grouping_element_(self): with self._choice(): with self._option(): @@ -417,11 +420,11 @@ def sep4(): def block5(): self._grouping_set_() self._positive_gather(block5, sep4) - self.name_last_node('grouping_sets') + self.name_last_node('columns') self._token(')') - self._constant('sets') + self._constant('grouping sets') self.name_last_node('type') - self._define(['grouping_sets', 'type'], []) + self._define(['columns', 'type'], []) with self._option(): with self._group(): with self._choice(): @@ -433,10 +436,10 @@ def block5(): 'expecting one of: ' ' ' ) - self.name_last_node('column') + self.name_last_node('columns') self._constant('') self.name_last_node('type') - self._define(['column', 'type'], []) + self._define(['columns', 'type'], []) self._error( 'expecting one of: ' "'CUBE' 'GROUPING' 'ROLLUP' " @@ -463,9 +466,8 @@ def block1(): ' ' ) self._gather(block1, sep0) - self.name_last_node('columns') + self.name_last_node('@') self._token(')') - self._define(['columns'], []) @tatsumasu('OrderBy') def _order_(self): diff --git 
a/beanquery/parser_test.py b/beanquery/parser_test.py index d00e2c54..77a301d7 100644 --- a/beanquery/parser_test.py +++ b/beanquery/parser_test.py @@ -309,48 +309,56 @@ def test_groupby_one(self): self.assertParse( "SELECT * GROUP BY a;", Select(ast.Asterisk(), - group_by=ast.GroupBy([ast.Column('a')], None))) + group_by=ast.GroupBy( + [ast.GroupByElement(ast.Column('a'), '')], None))) def test_groupby_many(self): + ge = ast.GroupByElement self.assertParse( "SELECT * GROUP BY a, b, c;", Select(ast.Asterisk(), - group_by=ast.GroupBy([ - ast.Column('a'), - ast.Column('b'), - ast.Column('c')], None))) + group_by=ast.GroupBy([ + ge(ast.Column('a'), ''), + ge(ast.Column('b'), ''), + ge(ast.Column('c'), '')], + None))) def test_groupby_expr(self): + ge = ast.GroupByElement self.assertParse( "SELECT * GROUP BY length(a) > 0, b;", Select(ast.Asterisk(), - group_by=ast.GroupBy([ - ast.Greater( - ast.Function('length', [ - ast.Column('a')]), - ast.Constant(0)), - ast.Column('b')], None))) + group_by=ast.GroupBy([ + ge( ast.Greater( + ast.Function('length', [ + ast.Column('a')]), + ast.Constant(0)), ''), + ge(ast.Column('b'), '')], + None))) def test_groupby_having(self): self.assertParse( "SELECT * GROUP BY a HAVING sum(x) = 0;", Select(ast.Asterisk(), - group_by=ast.GroupBy([ast.Column('a')], - ast.Equal( - ast.Function('sum', [ - ast.Column('x')]), - ast.Constant(0))))) + group_by=ast.GroupBy( + [ast.GroupByElement(ast.Column('a'), '')], + ast.Equal( + ast.Function('sum', [ + ast.Column('x')]), + ast.Constant(0))))) def test_groupby_numbers(self): self.assertParse( "SELECT * GROUP BY 1;", Select(ast.Asterisk(), - group_by=ast.GroupBy([1], None))) + group_by=ast.GroupBy([ast.GroupByElement(1, '')], None))) + ge = ast.GroupByElement self.assertParse( "SELECT * GROUP BY 2, 4, 5;", Select(ast.Asterisk(), - group_by=ast.GroupBy([2, 4, 5], None))) + group_by=ast.GroupBy( + [ge(2, ''), ge(4, ''), ge(5, '')], None))) def test_groupby_empty(self): with self.assertRaises(parser.ParseError): @@ -361,53 +369,52 @@ def test_groupby_rollup(self): self.assertParse( "SELECT * GROUP BY ROLLUP (account, year);", Select(ast.Asterisk(), - group_by=ast.GroupBy([ - {'columns': [ast.Column('account'), ast.Column('year')], 'type': 'rollup'} - ], None))) + group_by=ast.GroupBy([ + ast.GroupByElement( + [ast.Column('account'), ast.Column('year')], 'rollup') + ], None) + ) + ) def test_groupby_cube(self): """Test CUBE syntax in GROUP BY clause.""" self.assertParse( "SELECT * GROUP BY CUBE (account, year);", Select(ast.Asterisk(), - group_by=ast.GroupBy([ - {'columns': [ast.Column('account'), ast.Column('year')], 'type': 'cube'} - ], None))) + group_by=ast.GroupBy([ + ast.GroupByElement( + [ast.Column('account'), ast.Column('year')], 'cube') + ], None) + ) + ) def test_groupby_grouping_sets(self): """Test GROUPING SETS syntax in GROUP BY clause.""" self.assertParse( "SELECT * GROUP BY GROUPING SETS ((account, year), (account), ());", Select(ast.Asterisk(), - group_by=ast.GroupBy([ - {'grouping_sets': [ - [ast.Column('account'), ast.Column('year')], - [ast.Column('account')], - [] - ], 'type': 'sets'} - ], None))) + group_by=ast.GroupBy([ + ast.GroupByElement([ + [ast.Column('account'), ast.Column('year')], + [ast.Column('account')], + [] + ], 'grouping sets') + ], None) + ) + ) def test_groupby_mixed(self): """Test mixed grouping elements in GROUP BY clause.""" + ge = ast.GroupByElement self.assertParse( "SELECT * GROUP BY region, ROLLUP (year, month);", Select(ast.Asterisk(), - group_by=ast.GroupBy([ - {'column': 
ast.Column('region'), 'type': ''}, - {'columns': [ast.Column('year'), ast.Column('month')], 'type': 'rollup'} - ], None))) - - def test_groupby_rollup_with_having(self): - """Test ROLLUP syntax with HAVING clause.""" - self.assertParse( - "SELECT * GROUP BY ROLLUP (account, year) HAVING sum(position) > 100;", - Select(ast.Asterisk(), - group_by=ast.GroupBy([ - {'columns': [ast.Column('account'), ast.Column('year')], 'type': 'rollup'} - ], - ast.Greater( - ast.Function('sum', [ast.Column('position')]), - ast.Constant(100))))) + group_by=ast.GroupBy([ + ge(ast.Column('region'),''), + ge([ast.Column('year'), ast.Column('month')], 'rollup') + ], None) + ) + ) class TestSelectOrderBy(QueryParserTestBase): diff --git a/beanquery/query_compile_test.py b/beanquery/query_compile_test.py index 4418f45a..44fe9dcb 100644 --- a/beanquery/query_compile_test.py +++ b/beanquery/query_compile_test.py @@ -656,10 +656,13 @@ def test_journal_with_account_func_and_from(self): class TestTranslationBalance(CompileSelectBase): + _ge = ast.GroupByElement # to shorten test cases + group_by = ast.GroupBy([ - ast.Column('account'), - ast.Function('account_sortkey', [ - ast.Column(name='account')])], None) + _ge(ast.Column('account'), ''), + _ge(ast.Function('account_sortkey', [ + ast.Column(name='account')]), '') + ], None) order_by = [ast.OrderBy(ast.Function('account_sortkey', [ast.Column('account')]), ast.Ordering.ASC)]
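
To illustrate the grouping-set enumeration implemented by this series, here is a
minimal standalone sketch (plain Python, independent of the beanquery code; all
names below are illustrative, and only the set arithmetic mirrors the compiler
helpers _get_grouping_sets_for_element and _combine_grouping_sets):

    from itertools import combinations, product


    def expand(modifier, indexes):
        """One GROUP BY element -> the grouping sets it contributes."""
        if modifier == 'rollup':
            # Hierarchical prefixes: (a, b) -> [a, b], [a], []
            return [indexes[:i] for i in range(len(indexes), -1, -1)]
        if modifier == 'cube':
            # Power set: (a, b) -> [a, b], [a], [b], []
            return [list(c) for r in range(len(indexes), -1, -1)
                    for c in combinations(indexes, r)]
        if modifier == 'grouping sets':
            return indexes  # already a list of lists
        return [indexes]    # simple column(s): a single grouping set


    def enumerate_grouping_sets(elements):
        """elements: list of (modifier, indexes) pairs in GROUP BY order."""
        # Simple columns are present in every grouping set.
        simple = [i for modifier, idx in elements if modifier is None for i in idx]
        # Complex elements are combined via a cartesian product.
        combined = [[]]
        for modifier, idx in elements:
            if modifier is not None:
                combined = [a + b for a, b in product(combined, expand(modifier, idx))]
        return [simple + s for s in combined]


    # GROUP BY x, ROLLUP (a, b)  with x -> 0, a -> 1, b -> 2
    print(enumerate_grouping_sets([(None, [0]), ('rollup', [1, 2])]))
    # [[0, 1, 2], [0, 1], [0]]

    # GROUP BY CUBE (a, b)  with a -> 0, b -> 1
    print(enumerate_grouping_sets([('cube', [0, 1])]))
    # [[0, 1], [0], [1], []]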