diff --git a/beanquery/compiler.py b/beanquery/compiler.py index 6e6c442e..0141ceaf 100644 --- a/beanquery/compiler.py +++ b/beanquery/compiler.py @@ -1,5 +1,6 @@ import collections.abc import importlib +import itertools import typing from decimal import Decimal @@ -28,6 +29,7 @@ EvalOr, EvalPivot, EvalQuery, + EvalUnion, EvalConstantSubquery1D, EvalRow, EvalTarget, @@ -99,8 +101,8 @@ def _compile(self, node: Optional[ast.Node]): @_compile.register def _select(self, node: ast.Select): self.stack.append(self.table) - - # Compile the FROM clause. + + # Compile the FROM clause c_from_expr = self._compile_from(node.from_clause) # Compile the targets. @@ -119,20 +121,56 @@ def _select(self, node: ast.Select): if c_from_expr is not None: c_where = c_from_expr if c_where is None else EvalAnd([c_from_expr, c_where]) - # Process the GROUP-BY clause. - new_targets, group_indexes, having_index = self._compile_group_by(node.group_by, c_targets) + # Process the GROUP BY clause. + # Returns, among others, `element_indexes`, which looks like + # list(dict(indexes, modifier), ...), where modifier relates to the + # used GROUP BY modifer (ROLLUP, CUBE, GROUPING SETS) or is empty + # if no keyword is used, indexes is a list of column indexes, or a + # list of lists if type == 'grouping sets' + new_targets, element_indexes, having_index = self._compile_group_by(node.group_by, c_targets) c_targets.extend(new_targets) # Process the ORDER-BY clause. new_targets, order_spec = self._compile_order_by(node.order_by, c_targets) c_targets.extend(new_targets) + + # For complex grouping, we have not just a list of column names, + # but grouping elements, which must be compiled separately further below + if node.group_by and node.group_by.elements: + if any(elem['modifier'] in ('rollup', 'cube', 'grouping sets') + for elem in element_indexes): + is_grouping = "complex" + else: + is_grouping = "simple" + else: + is_grouping = "none" + + if is_grouping in ['none', 'simple'] and element_indexes is not None: + # Element indexes might be != None if we are grouping implicitly + # (only aggregate functions or explicitly) + # + # For simple grouping, there is no ROLLUP/CUBE clause, therefore all + # elements follow the format {'indexes': [x], 'modifier': None} + assert all(e['modifier'] is None for e in element_indexes) + + # Ensure all elements in element_indexes[x]['indexes'] are of type int. + assert all(isinstance(idx, int) + for elem in element_indexes for idx in elem['indexes']) + + # Flatten element_indexes for regular GROUP BY + if element_indexes is not None: + group_indexes = [] + for elem in element_indexes: + group_indexes.extend(elem['indexes']) + else: + group_indexes = None # If this is an aggregate query (it groups, see list of indexes), check that # the set of non-aggregates match exactly the group indexes. This should # always be the case at this point, because we have added all the necessary # targets to the list of group-by expressions and should have resolved all # the indexes. - if group_indexes is not None: + if is_grouping != 'none': non_aggregate_indexes = {index for index, c_target in enumerate(c_targets) if not c_target.is_aggregate} if non_aggregate_indexes != set(group_indexes): @@ -142,15 +180,26 @@ def _select(self, node: ast.Select): 'all non-aggregates must be covered by GROUP-BY clause in aggregate query: ' 'the following targets are missing: {}'.format(','.join(missing_names))) - query = EvalQuery(self.table, - c_targets, - c_where, - group_indexes, - having_index, - order_spec, - node.limit, - node.distinct) + # Handle ROLLUP/CUBE/GROUPING SETS queries separately: + # Check if any grouping element has such a modifier. + if is_grouping == "complex": + + # Obtain the UnionEval node and the flattened list of group indexes + query, group_indexes = self._compile_grouping_sets(node, c_targets, c_where, element_indexes, having_index, order_spec) + + else: # grouping in ['none', 'simple' ] + + + query = EvalQuery(self.table, + c_targets, + c_where, + group_indexes, + having_index, + order_spec, + node.limit, + node.distinct) + pivots = self._compile_pivot_by(node.pivot_by, c_targets, group_indexes) if pivots: return EvalPivot(query, pivots) @@ -158,6 +207,86 @@ def _select(self, node: ast.Select): self.stack.pop() return query + + def _compile_grouping_sets(self, node: ast.Select, c_targets, c_where, element_indexes, having_index, order_spec): + """Compile a query with complex grouping (ROLLUP, CUBE, SETS) as a union. + Supports full and mixed grouping: + - ROLLUP: GROUP BY ROLLUP (a, b, c) → grouping sets: [(a,b,c), (a,b), (a), ()] + - ROLLUP: GROUP BY x, ROLLUP (a, b) → grouping sets: [(x,a,b), (x,a), (x)] + - CUBE: GROUP BY CUBE (a, b) → grouping sets: [(a,b), (a), (b), ()] + - CUBE: GROUP BY x, CUBE (a, b) → grouping sets: [(x,a,b), (x,a), (x,b), (x)] + - SETS: GROUP BY GROUPING SETS ((a, b), (a), ()) → grouping sets: [(a,b), (a), ()] + + This function handles standard-compliant mixed grouping by: + 1. Separating simple grouping columns from complex ones (ROLLUP, etc.). + 2. Generating grouping sets for each complex element. + 3. Combining the sets from complex elements using a cartesian product. + 4. Prepending the simple grouping columns to each resulting set. + + Args: + node: A Select AST node with one or more complex grouping elements. + + Returns: + A tuple of (EvalUnion, list of group indexes). The EvalUnion node + executes a query for each final grouping set. The list of group + indexes has the indexes of the unique columns used in the GROUP BY + clause. + """ + # Separate simple and complex grouping elements + # Why: Simple columns are treated as a prefix for all grouping sets, + # while complex elements (ROLLUP, CUBE, SETS) generate multiple sets + # that need to be combined. + simple_indexes = [] + complex_elements = [] + for elem in element_indexes: + if elem['modifier'] is None: + simple_indexes.extend(elem['indexes']) + else: + complex_elements.append(elem) + + # Generate and combine grouping sets from complex elements + # Why: We iterate through the complex elements, generate the grouping + # sets for each, and combine them using a cartesian product to handle + # mixed grouping constructs like `GROUP BY ROLLUP(...), CUBE(...)`. + combined_sets = [[]] # Start with an empty set for the initial product + for elem in complex_elements: + element_sets = _get_grouping_sets_for_element(elem) + combined_sets = _combine_grouping_sets(combined_sets, element_sets) + + # Prepend simple columns to all generated sets + # Why: The simple GROUP BY columns must be included in every grouping + # set generated by the complex elements. + final_grouping_sets = [simple_indexes + s for s in combined_sets] + + # Create a query for each final grouping set + # Why: The result of a complex grouping query is the UNION of the + # results of running a separate query for each individual grouping set. + queries = [ + EvalQuery(self.table, + c_targets, + c_where, + grouping_set, + having_index, + None, + None, + node.distinct) + for grouping_set in final_grouping_sets + ] + + # Wrap the individual queries in a UNION operator. + union = EvalUnion( + queries=queries, + rollup_sets=final_grouping_sets, + order_spec=order_spec, + limit=node.limit + ) + + # List of unique column indexes used in the GROUP BY clause + all_group_by_indexes = list(set(idx for s in final_grouping_sets for idx in s)) + + return union, all_group_by_indexes + + def _compile_from(self, node): if node is None: return None @@ -364,20 +493,38 @@ def _compile_group_by(self, group_by, c_targets): Returns: A tuple of new_targets: A list of new compiled target nodes. - group_indexes: If the query is an aggregate query, a list of integer - indexes to be used for processing grouping. Note that this list may be - empty (in the case of targets with only aggregates). On the other hand, - if this is not an aggregated query, this is set to None. So do - distinguish the empty list vs. None. + element_indexes: A list of dicts, one per grouping element: + [{'indexes': [int, ...], 'modifier': str or None}, ...] + Each dict represents one grouping element from the grammar. + 'modifier' can be None, 'rollup', 'cube', or 'grouping sets'. + Note: The 'type' property in the AST element is used to determine the modifier. + + Examples: + - Non-aggregate query: None + - Aggregate without GROUP BY: [] + - Regular GROUP BY account, year: + [{'indexes': [0], 'modifier': None}, {'indexes': [1], 'modifier': None}] + - Full ROLLUP: GROUP BY ROLLUP (account, year): + [{'indexes': [0, 1], 'modifier': 'rollup'}] + - Full CUBE: GROUP BY CUBE (account, year): + [{'indexes': [0, 1], 'modifier': 'cube'}] + - GROUPING SETS: GROUP BY GROUPING SETS ((account, year), (account), ()): + [{'indexes': [[0, 1], [0], []], 'modifier': 'grouping sets'}] + - Mixed grouping: GROUP BY region, ROLLUP (year, month): + [{'indexes': [2], 'modifier': None}, {'indexes': [0, 1], 'modifier': 'rollup'}] + - Implicit GROUP BY (when SUPPORT_IMPLICIT_GROUPBY=True): + [{'indexes': [0], 'modifier': None}, {'indexes': [2], 'modifier': None}] + + having_index: Index of HAVING expression in targets, or None. """ new_targets = c_targets[:] c_target_expressions = [c_target.c_expr for c_target in c_targets] - group_indexes = [] + element_indexes = [] having_index = None if group_by: - assert group_by.columns, "Internal error with GROUP-BY parsing" + assert group_by.elements, "Internal error with GROUP-BY parsing" # Compile group-by expressions and resolve them to their targets if # possible. A GROUP-BY column may be one of the following: @@ -389,7 +536,55 @@ def _compile_group_by(self, group_by, c_targets): # References by name are converted to indexes. New expressions are # inserted into the list of targets as invisible targets. targets_name_map = {target.name: index for index, target in enumerate(c_targets)} - for column in group_by.columns: + + # Initialize element structures + for elem in group_by.elements: + # Iterating over GROUP BY syntax elements, which are either a + # simple grouping column/expression, a ROLLUP (col1, ...) + # element, a CUBE (col1, ...) element, or a GROUPING SETS element. + if elem.type == 'rollup': + modifier = 'rollup' + elif elem.type == 'cube': + modifier = 'cube' + elif elem.type == 'grouping sets': + modifier = 'grouping sets' + else: + modifier = None + + # For GROUPING SETS, 'indexes' will be a list of lists + # For other modifiers, 'indexes' is a flat list + if modifier == 'grouping sets': + element_indexes.append({ + 'indexes': [[] for _ in elem.columns], + 'modifier': modifier + }) + else: + element_indexes.append({ + 'indexes': [], + 'modifier': modifier + }) + + # Collect all columns with their syntax element position + # For GROUPING SETS, also track which set within the element + columns_by_element = [] + for elem_idx, elem in enumerate(group_by.elements): + if elem.type in ('rollup', 'cube'): + columns = elem.columns + for column in columns: + columns_by_element.append((elem_idx, None, column)) + elif elem.type == 'grouping sets': + # For GROUPING SETS, track which set each column belongs to + for set_idx, grouping_set in enumerate(elem.columns): + for column in grouping_set.columns: + columns_by_element.append((elem_idx, set_idx, column)) + else: + # Simple grouping (no modifier) + assert elem.type == '' + columns_by_element.append((elem_idx, None, elem.columns)) + + # Compile all columns and add the indexes to element_indexes, + # to return the same structure as in the parsed GROUP BY clause. + for elem_idx, set_idx, column in columns_by_element: index = None # Process target references by index. @@ -428,7 +623,6 @@ def _compile_group_by(self, group_by, c_targets): c_target_expressions.append(c_expr) assert index is not None, "Internal error, could not index group-by reference." - group_indexes.append(index) # Check that the group-by column references a non-aggregate. c_expr = new_targets[index].c_expr @@ -438,6 +632,14 @@ def _compile_group_by(self, group_by, c_targets): # Check that the group-by column has a supported hashable type. if not issubclass(c_expr.dtype, collections.abc.Hashable): raise CompilationError(f'GROUP-BY a non-hashable type is not supported: "{column}"') + + # Add compiled index to the corresponding element + # For GROUPING SETS, append to the specific set's list + # For other modifiers, append to the flat list + if set_idx is not None: + element_indexes[elem_idx]['indexes'][set_idx].append(index) + else: + element_indexes[elem_idx]['indexes'].append(index) # Compile HAVING clause. if group_by.having is not None: @@ -455,25 +657,28 @@ def _compile_group_by(self, group_by, c_targets): # If the query is an aggregate query, check that all the targets are # aggregates. if all(aggregate_bools): - # FIXME: shold we really be checking for the empty - # list or is checking for a false value enough? - assert group_indexes == [] + # Empty element_indexes for aggregate query without GROUP BY + element_indexes = [] elif SUPPORT_IMPLICIT_GROUPBY: # If some of the targets aren't aggregates, automatically infer # that they are to be implicit group by targets. This makes for # a much more convenient syntax for our lightweight SQL, where # grouping is optional. - group_indexes = [ + implicit_indexes = [ index for index, c_target in enumerate(c_targets) if not c_target.is_aggregate] + # Wrap as individual elements + element_indexes = [ + {'indexes': [idx], 'modifier': None} for idx in implicit_indexes + ] else: raise CompilationError('aggregate query without a GROUP-BY should have only aggregates') else: - # This is not an aggregate query; don't set group_indexes to + # This is not an aggregate query; don't set element_indexes to # anything useful, we won't need it. - group_indexes = None + element_indexes = None - return new_targets[len(c_targets):], group_indexes, having_index + return new_targets[len(c_targets):], element_indexes, having_index @_compile.register def _column(self, node: ast.Column): @@ -909,5 +1114,67 @@ def is_aggregate(node): return bool(aggregates) + +def _combine_grouping_sets(list_of_sets1, list_of_sets2): + """Compute the cartesian product of two lists of grouping sets. + + >>> _combine_grouping_sets([['a'], ['b']], [['c'], ['d']]) + [['a', 'c'], ['a', 'd'], ['b', 'c'], ['b', 'd']] + >>> _combine_grouping_sets([['a']], [['b', 'c'], []]) + [['a', 'b', 'c'], ['a']] + >>> _combine_grouping_sets([], [['a']]) + [['a']] + >>> _combine_grouping_sets([['a']], []) + [['a']] + """ + # Why: This helper function is used to combine grouping sets from different + # grouping elements (e.g., `ROLLUP` and `CUBE`) by creating a + # cartesian product of their individual grouping sets. + import itertools + if not list_of_sets1: + return list_of_sets2 + if not list_of_sets2: + return list_of_sets1 + + return [s1 + s2 for s1, s2 in itertools.product(list_of_sets1, list_of_sets2)] + +def _get_grouping_sets_for_element(element): + """Generate grouping sets for a single GROUP BY element. + This function isolates the logic for generating grouping sets based on + the element's modifier (`rollup`, `cube`, `grouping sets`). + + Args: + element (dict): A dictionary representing a grouping element, which contains: + - 'modifier' (str): The type of grouping modifier ('rollup', 'cube', 'grouping sets', or None). + - 'indexes' (list): A list of integer indexes representing the grouping columns. + + Returns: + list: A list of lists, where each inner list represents a grouping set. + + Example: + >>> element = {'modifier': 'rollup', 'indexes': [0, 1]} + >>> _get_grouping_sets_for_element(element) + [[0, 1], [0], []] + """ + modifier = element['modifier'] + indexes = element['indexes'] + + if modifier == 'rollup': + # Hierarchical prefixes: e.g., (a,b) -> [(a,b), (a), ()] + return [indexes[:i] for i in range(len(indexes), -1, -1)] + elif modifier == 'cube': + # Power set: e.g., (a,b) -> [(a,b), (a), (b), ()] + sets = [] + for i in range(len(indexes), -1, -1): + for combo in itertools.combinations(indexes, i): + sets.append(list(combo)) + return sets + elif modifier == 'grouping sets': + # User-defined sets + return indexes + else: # Regular column + return [indexes] + + def compile(context, statement, parameters=None): return Compiler(context).compile(statement, parameters) diff --git a/beanquery/parser/ast.py b/beanquery/parser/ast.py index 0fc9e49e..de0feb61 100644 --- a/beanquery/parser/ast.py +++ b/beanquery/parser/ast.py @@ -148,9 +148,19 @@ class From(Node): # A GROUP BY clause. # # Attributes: -# columns: A list of group-by expressions, simple Column() or otherwise. +# elements: A list of grouping elements. See GroupByElement. # having: An expression tree for the optional HAVING clause, or None. -GroupBy = node('GroupBy', 'columns having') +GroupBy = node('GroupBy', 'elements having') + +# A GROUP BY grouping element. +# Attributes: +# columns: If type == '': ast.Column() or an integer column index. If +# type != '': A list of ast.Column() or integer column indexes. +# type: Distinguishes the grouping modes relating to the keywords +# ROLLUP, CUBE, GROUPING SETS. 'type' has the keyword in lower case. For +# simple grouping (no keyword), 'type' is the empty string. +# +GroupByElement = node('GroupByElement', 'columns type') # An ORDER BY clause. # diff --git a/beanquery/parser/bql.ebnf b/beanquery/parser/bql.ebnf index 1c078b63..bffbbda3 100644 --- a/beanquery/parser/bql.ebnf +++ b/beanquery/parser/bql.ebnf @@ -1,9 +1,9 @@ @@grammar :: BQL @@parseinfo :: True @@ignorecase :: True -@@keyword :: 'AND' 'AS' 'ASC' 'BY' 'DESC' 'DISTINCT' 'FALSE' 'FROM' - 'GROUP' 'HAVING' 'IN' 'IS' 'LIMIT' 'NOT' 'OR' 'ORDER' 'PIVOT' - 'SELECT' 'TRUE' 'WHERE' +@@keyword :: 'AND' 'AS' 'ASC' 'BY' 'CUBE' 'DESC' 'DISTINCT' 'FALSE' 'FROM' + 'GROUP' 'GROUPING' 'HAVING' 'IN' 'IS' 'LIMIT' 'NOT' 'OR' 'ORDER' 'PIVOT' + 'ROLLUP' 'SELECT' 'SETS' 'TRUE' 'WHERE' @@keyword :: 'CREATE' 'TABLE' 'USING' 'INSERT' 'INTO' @@keyword :: 'BALANCES' 'JOURNAL' 'PRINT' @@comments :: /(\/\*([^*]|[\r\n]|(\*+([^*\/]|[\r\n])))*\*+\/)/ @@ -55,8 +55,36 @@ table::Table = name:identifier ; + +# A GROUP BY clause supports multiple syntaxes: +# 1. Regular GROUP BY: GROUP BY col1, col2 [HAVING condition] +# 2. Full ROLLUP: GROUP BY ROLLUP (col1, col2) [HAVING condition] +# 3. Full CUBE: GROUP BY CUBE (col1, col2) [HAVING condition] +# 4. GROUPING SETS: GROUP BY GROUPING SETS ((col1, col2), (col1), ()) [HAVING condition] +# 5. Mixed grouping: GROUP BY col1, ROLLUP (col2, col3) [HAVING condition] +# +# Examples: +# - GROUP BY account, year +# - GROUP BY account, year HAVING SUM(position) > 100 +# - GROUP BY ROLLUP (account, year) +# - GROUP BY ROLLUP (account, year) HAVING SUM(position) > 100 +# - GROUP BY CUBE (account, year) +# - GROUP BY GROUPING SETS ((account, year), (account), ()) +# - GROUP BY region, ROLLUP (year, month) +# - GROUP BY region, CUBE (year, month) groupby::GroupBy - = columns:','.{ (integer | expression) }+ ['HAVING' having:expression] + = elements:','.{ grouping_element }+ ['HAVING' having:expression] + ; + +grouping_element::GroupByElement + = 'ROLLUP' '(' columns:','.{ (integer | expression) }+ ')' type:`rollup` + | 'CUBE' '(' columns:','.{ (integer | expression) }+ ')' type:`cube` + | 'GROUPING' 'SETS' '(' columns:','.{ grouping_set }+ ')' type:`grouping sets` + | columns: (integer | expression) type:`` + ; + +grouping_set + = '(' @:','.{ (integer | expression) }* ')' ; order::OrderBy diff --git a/beanquery/parser/parser.py b/beanquery/parser/parser.py index 2ec9ee84..e37ecfd5 100644 --- a/beanquery/parser/parser.py +++ b/beanquery/parser/parser.py @@ -29,11 +29,13 @@ 'AS', 'ASC', 'BY', + 'CUBE', 'DESC', 'DISTINCT', 'FALSE', 'FROM', 'GROUP', + 'GROUPING', 'HAVING', 'IN', 'IS', @@ -42,7 +44,9 @@ 'OR', 'ORDER', 'PIVOT', + 'ROLLUP', 'SELECT', + 'SETS', 'TRUE', 'WHERE', 'CREATE', @@ -340,6 +344,113 @@ def _table_(self): @tatsumasu('GroupBy') def _groupby_(self): + def sep0(): + self._token(',') + + def block1(): + self._grouping_element_() + self._positive_gather(block1, sep0) + self.name_last_node('elements') + with self._optional(): + self._token('HAVING') + self._expression_() + self.name_last_node('having') + self._define(['having'], []) + self._define(['elements', 'having'], []) + + @tatsumasu('GroupByElement') + def _grouping_element_(self): + with self._choice(): + with self._option(): + self._token('ROLLUP') + self._token('(') + + def sep0(): + self._token(',') + + def block1(): + with self._group(): + with self._choice(): + with self._option(): + self._integer_() + with self._option(): + self._expression_() + self._error( + 'expecting one of: ' + ' ' + ) + self._positive_gather(block1, sep0) + self.name_last_node('columns') + self._token(')') + self._constant('rollup') + self.name_last_node('type') + self._define(['columns', 'type'], []) + with self._option(): + self._token('CUBE') + self._token('(') + + def sep2(): + self._token(',') + + def block3(): + with self._group(): + with self._choice(): + with self._option(): + self._integer_() + with self._option(): + self._expression_() + self._error( + 'expecting one of: ' + ' ' + ) + self._positive_gather(block3, sep2) + self.name_last_node('columns') + self._token(')') + self._constant('cube') + self.name_last_node('type') + self._define(['columns', 'type'], []) + with self._option(): + self._token('GROUPING') + self._token('SETS') + self._token('(') + + def sep4(): + self._token(',') + + def block5(): + self._grouping_set_() + self._positive_gather(block5, sep4) + self.name_last_node('columns') + self._token(')') + self._constant('grouping sets') + self.name_last_node('type') + self._define(['columns', 'type'], []) + with self._option(): + with self._group(): + with self._choice(): + with self._option(): + self._integer_() + with self._option(): + self._expression_() + self._error( + 'expecting one of: ' + ' ' + ) + self.name_last_node('columns') + self._constant('') + self.name_last_node('type') + self._define(['columns', 'type'], []) + self._error( + 'expecting one of: ' + "'CUBE' 'GROUPING' 'ROLLUP' " + ' ' + '[0-9]+' + ) + + @tatsumasu() + def _grouping_set_(self): + self._token('(') + def sep0(): self._token(',') @@ -354,14 +465,9 @@ def block1(): 'expecting one of: ' ' ' ) - self._positive_gather(block1, sep0) - self.name_last_node('columns') - with self._optional(): - self._token('HAVING') - self._expression_() - self.name_last_node('having') - self._define(['having'], []) - self._define(['columns', 'having'], []) + self._gather(block1, sep0) + self.name_last_node('@') + self._token(')') @tatsumasu('OrderBy') def _order_(self): diff --git a/beanquery/parser_test.py b/beanquery/parser_test.py index 33f47fdb..77a301d7 100644 --- a/beanquery/parser_test.py +++ b/beanquery/parser_test.py @@ -309,52 +309,112 @@ def test_groupby_one(self): self.assertParse( "SELECT * GROUP BY a;", Select(ast.Asterisk(), - group_by=ast.GroupBy([ast.Column('a')], None))) + group_by=ast.GroupBy( + [ast.GroupByElement(ast.Column('a'), '')], None))) def test_groupby_many(self): + ge = ast.GroupByElement self.assertParse( "SELECT * GROUP BY a, b, c;", Select(ast.Asterisk(), - group_by=ast.GroupBy([ - ast.Column('a'), - ast.Column('b'), - ast.Column('c')], None))) + group_by=ast.GroupBy([ + ge(ast.Column('a'), ''), + ge(ast.Column('b'), ''), + ge(ast.Column('c'), '')], + None))) def test_groupby_expr(self): + ge = ast.GroupByElement self.assertParse( "SELECT * GROUP BY length(a) > 0, b;", Select(ast.Asterisk(), - group_by=ast.GroupBy([ - ast.Greater( - ast.Function('length', [ - ast.Column('a')]), - ast.Constant(0)), - ast.Column('b')], None))) + group_by=ast.GroupBy([ + ge( ast.Greater( + ast.Function('length', [ + ast.Column('a')]), + ast.Constant(0)), ''), + ge(ast.Column('b'), '')], + None))) def test_groupby_having(self): self.assertParse( "SELECT * GROUP BY a HAVING sum(x) = 0;", Select(ast.Asterisk(), - group_by=ast.GroupBy([ast.Column('a')], - ast.Equal( - ast.Function('sum', [ - ast.Column('x')]), - ast.Constant(0))))) + group_by=ast.GroupBy( + [ast.GroupByElement(ast.Column('a'), '')], + ast.Equal( + ast.Function('sum', [ + ast.Column('x')]), + ast.Constant(0))))) def test_groupby_numbers(self): self.assertParse( "SELECT * GROUP BY 1;", Select(ast.Asterisk(), - group_by=ast.GroupBy([1], None))) + group_by=ast.GroupBy([ast.GroupByElement(1, '')], None))) + ge = ast.GroupByElement self.assertParse( "SELECT * GROUP BY 2, 4, 5;", Select(ast.Asterisk(), - group_by=ast.GroupBy([2, 4, 5], None))) + group_by=ast.GroupBy( + [ge(2, ''), ge(4, ''), ge(5, '')], None))) def test_groupby_empty(self): with self.assertRaises(parser.ParseError): parser.parse("SELECT * GROUP BY;") + + def test_groupby_rollup(self): + """Test ROLLUP syntax in GROUP BY clause.""" + self.assertParse( + "SELECT * GROUP BY ROLLUP (account, year);", + Select(ast.Asterisk(), + group_by=ast.GroupBy([ + ast.GroupByElement( + [ast.Column('account'), ast.Column('year')], 'rollup') + ], None) + ) + ) + + def test_groupby_cube(self): + """Test CUBE syntax in GROUP BY clause.""" + self.assertParse( + "SELECT * GROUP BY CUBE (account, year);", + Select(ast.Asterisk(), + group_by=ast.GroupBy([ + ast.GroupByElement( + [ast.Column('account'), ast.Column('year')], 'cube') + ], None) + ) + ) + + def test_groupby_grouping_sets(self): + """Test GROUPING SETS syntax in GROUP BY clause.""" + self.assertParse( + "SELECT * GROUP BY GROUPING SETS ((account, year), (account), ());", + Select(ast.Asterisk(), + group_by=ast.GroupBy([ + ast.GroupByElement([ + [ast.Column('account'), ast.Column('year')], + [ast.Column('account')], + [] + ], 'grouping sets') + ], None) + ) + ) + + def test_groupby_mixed(self): + """Test mixed grouping elements in GROUP BY clause.""" + ge = ast.GroupByElement + self.assertParse( + "SELECT * GROUP BY region, ROLLUP (year, month);", + Select(ast.Asterisk(), + group_by=ast.GroupBy([ + ge(ast.Column('region'),''), + ge([ast.Column('year'), ast.Column('month')], 'rollup') + ], None) + ) + ) class TestSelectOrderBy(QueryParserTestBase): diff --git a/beanquery/query_compile.py b/beanquery/query_compile.py index eda87ec7..e6ae8cd4 100644 --- a/beanquery/query_compile.py +++ b/beanquery/query_compile.py @@ -13,10 +13,62 @@ import collections import dataclasses import datetime +import functools import itertools import re import operator + +@functools.total_ordering +class Sentinel: + """General sentinel value with configurable sort order and display string. + + Sentinels are special marker values that can be used in query results. + They have a configurable sort order relative to regular values: + - Negative sort_order: sorts before all non-sentinel values + - Positive sort_order: sorts after all non-sentinel values + - Zero sort_order: sorts with regular values (not recommended) + + Args: + sort_order: Integer determining sort position. Negative sorts before + regular values, positive sorts after. + display: String representation when rendered. + """ + + def __init__(self, sort_order, display): + self.sort_order = sort_order + self.display = display + + def __str__(self): + return self.display + + def __repr__(self): + return f"Sentinel({self.sort_order}, {self.display!r})" + + def __eq__(self, other): + if not isinstance(other, Sentinel): + return False + return self.sort_order == other.sort_order and self.display == other.display + + def __lt__(self, other): + if isinstance(other, Sentinel): + # Compare sentinels by their sort order + return self.sort_order < other.sort_order + else: + # Compare sentinel to non-sentinel value + # Negative sort_order: sentinel sorts before regular values + # Positive sort_order: sentinel sorts after regular values + return self.sort_order < 0 + + def __hash__(self): + return hash((self.sort_order, self.display)) + + +# Sentinel instances for various use cases +SENTINEL_ROLLUP_TOTAL = Sentinel(2, "(Total)") +SENTINEL_EARLIER = Sentinel(-1, "(earlier)") +SENTINEL_LATER = Sentinel(1, "(later)") + from decimal import Decimal from typing import List @@ -637,6 +689,84 @@ def __call__(self): return query_execute.execute_select(self) +@dataclasses.dataclass +class EvalUnion: + """Implement UNION of multiple queries for ROLLUP. + + This class executes multiple GROUP BY queries with different grouping sets + and unions their results to implement SQL ROLLUP functionality. + + Execution steps: + 1. Compute the union of all grouping sets to identify all group columns + 2. Execute each query with its specific grouping set + 3. For each result set, NULL out columns that are not in that grouping set + (this distinguishes subtotal rows from detail rows) + 4. Union all result sets together + 5. Apply ORDER BY to the combined results (with NULL sorting last) + 6. Apply LIMIT if specified + + Example for GROUP BY region, ROLLUP (year, month): + - Query 1: GROUP BY region, year, month (detail rows) + - Query 2: GROUP BY region, year (year subtotals, month=NULL) + - Query 3: GROUP BY region (region subtotals, year=NULL, month=NULL) + """ + + queries: list[EvalQuery] + rollup_sets: list[list[int]] # List of grouping sets, one per query + order_spec: list[tuple[int, ast.Ordering]] + limit: int + + @property + def columns(self): + # All queries have the same columns + return self.queries[0].columns + + def __call__(self): + # Execute all queries and union the results + all_rows = [] + columns = None + + # Compute union of all grouping sets to find all group columns + # This is needed to determine which columns to NULL at each level + full_group_indexes = list(set(idx for grouping_set in self.rollup_sets for idx in grouping_set)) + + for query, grouping_set in zip(self.queries, self.rollup_sets): + cols, rows = query() + if columns is None: + columns = cols + + # Mark subtotal columns that are not in this grouping set + # Columns in full_group_indexes but not in grouping_set get SENTINEL_ROLLUP_TOTAL sentinel + grouping_set_indexes = set(grouping_set) + subtotal_indexes = [idx for idx in full_group_indexes if idx not in grouping_set_indexes] + + # Replace values with SENTINEL_ROLLUP_TOTAL for non-grouped columns (subtotal rows) + if subtotal_indexes: + rows = [ + tuple(SENTINEL_ROLLUP_TOTAL if i in subtotal_indexes else val for i, val in enumerate(row)) + for row in rows + ] + + all_rows.extend(rows) + + # Apply ORDER BY if specified + if self.order_spec: + # Sort in reverse order to leverage Python's stable sort + for col_index, ordering in reversed(self.order_spec): + # SENTINEL_ROLLUP_TOTAL sorts last (after all regular values) + # Sentinel implements comparison operators to sort based on sort_order + all_rows.sort( + key=lambda row: row[col_index], + reverse=bool(ordering) + ) + + # Apply LIMIT if specified + if self.limit is not None: + all_rows = all_rows[:self.limit] + + return columns, all_rows + + @dataclasses.dataclass class EvalPivot: """Implement PIVOT BY clause.""" diff --git a/beanquery/query_compile_test.py b/beanquery/query_compile_test.py index 4418f45a..44fe9dcb 100644 --- a/beanquery/query_compile_test.py +++ b/beanquery/query_compile_test.py @@ -656,10 +656,13 @@ def test_journal_with_account_func_and_from(self): class TestTranslationBalance(CompileSelectBase): + _ge = ast.GroupByElement # to shorten test cases + group_by = ast.GroupBy([ - ast.Column('account'), - ast.Function('account_sortkey', [ - ast.Column(name='account')])], None) + _ge(ast.Column('account'), ''), + _ge(ast.Function('account_sortkey', [ + ast.Column(name='account')]), '') + ], None) order_by = [ast.OrderBy(ast.Function('account_sortkey', [ast.Column('account')]), ast.Ordering.ASC)] diff --git a/beanquery/query_env.py b/beanquery/query_env.py index 7606b193..fb621acb 100644 --- a/beanquery/query_env.py +++ b/beanquery/query_env.py @@ -28,6 +28,7 @@ from beanquery import query_compile from beanquery import types +from beanquery.query_compile import SENTINEL_EARLIER, SENTINEL_LATER class ColumnsRegistry(dict): diff --git a/beanquery/query_execute_test.py b/beanquery/query_execute_test.py index 0358d632..5ad0cdef 100644 --- a/beanquery/query_execute_test.py +++ b/beanquery/query_execute_test.py @@ -1111,6 +1111,126 @@ def test_aggregated_group_by_with_having(self): ('Expenses:Bar', D(2.0)), ('Expenses:Foo', D(1.0)), ]) + def test_rollup_basic(self): + """Test ROLLUP functionality with a simple example.""" + self.check_query( + """ + 2010-02-21 * "First" + Assets:Bank:Checking -1.00 USD + Expenses:Restaurant 1.00 USD + + 2010-02-23 * "Second" + Liabilities:Credit-Card -2.00 USD + Expenses:Restaurant 2.00 USD + """, + """ + SELECT account, sum(number) as amount + GROUP BY ROLLUP (account) + ORDER BY account; + """, + [ + ('account', str), + ('amount', Decimal), + ], + [ + ('Assets:Bank:Checking', D('-1.00')), + ('Expenses:Restaurant', D('3.00')), + ('Liabilities:Credit-Card', D('-2.00')), + (qc.Sentinel(2, '(Total)'), D('0.00')), # Total row + ]) + + def test_cube_basic(self): + """Test CUBE functionality with a simple example.""" + self.check_query( + """ + 2010-02-21 * "First" + Assets:Bank:Checking -1.00 USD + Expenses:Restaurant 1.00 USD + + 2010-02-23 * "Second" + Liabilities:Credit-Card -2.00 USD + Expenses:Restaurant 2.00 USD + """, + """ + SELECT account, sum(number) as amount + GROUP BY CUBE (account) + ORDER BY account; + """, + [ + ('account', str), + ('amount', Decimal), + ], + [ + ('Assets:Bank:Checking', D('-1.00')), + ('Expenses:Restaurant', D('3.00')), + ('Liabilities:Credit-Card', D('-2.00')), + (qc.Sentinel(2, '(Total)'), D('0.00')), # Total row + ]) + + def test_rollup_two_columns(self): + """Test ROLLUP with two columns.""" + self.check_query( + """ + 2010-02-21 * "First" + Assets:Bank:Checking -1.00 USD + Expenses:Restaurant 1.00 USD + + 2010-02-23 * "Second" + Liabilities:Credit-Card -2.00 USD + Expenses:Restaurant 2.00 USD + """, + """ + SELECT account, year(date) as year, sum(number) as amount + GROUP BY ROLLUP (account, year(date)) + ORDER BY account, year; + """, + [ + ('account', str), + ('year', int), + ('amount', Decimal), + ], + [ + ('Assets:Bank:Checking', 2010, D('-1.00')), + ('Assets:Bank:Checking', qc.Sentinel(2, '(Total)'), D('-1.00')), # Subtotal for account + ('Expenses:Restaurant', 2010, D('3.00')), + ('Expenses:Restaurant', qc.Sentinel(2, '(Total)'), D('3.00')), # Subtotal for account + ('Liabilities:Credit-Card', 2010, D('-2.00')), + ('Liabilities:Credit-Card', qc.Sentinel(2, '(Total)'), D('-2.00')), # Subtotal for account + (qc.Sentinel(2, '(Total)'), qc.Sentinel(2, '(Total)'), D('0.00')), # Grand total + ]) + + def test_cube_two_columns(self): + """Test CUBE with two columns.""" + self.check_query( + """ + 2010-02-21 * "First" + Assets:Bank:Checking -1.00 USD + Expenses:Restaurant 1.00 USD + + 2010-02-23 * "Second" + Liabilities:Credit-Card -2.00 USD + Expenses:Restaurant 2.00 USD + """, + """ + SELECT account, year(date) as year, sum(number) as amount + GROUP BY CUBE (account, year(date)) + ORDER BY account, year; + """, + [ + ('account', str), + ('year', int), + ('amount', Decimal), + ], + [ + ('Assets:Bank:Checking', 2010, D('-1.00')), + ('Assets:Bank:Checking', qc.Sentinel(2, '(Total)'), D('-1.00')), # Subtotal for account + ('Expenses:Restaurant', 2010, D('3.00')), + ('Expenses:Restaurant', qc.Sentinel(2, '(Total)'), D('3.00')), # Subtotal for account + ('Liabilities:Credit-Card', 2010, D('-2.00')), + ('Liabilities:Credit-Card', qc.Sentinel(2, '(Total)'), D('-2.00')), # Subtotal for account + (qc.Sentinel(2, '(Total)'), 2010, D('0.00')), # Subtotal for year + (qc.Sentinel(2, '(Total)'), qc.Sentinel(2, '(Total)'), D('0.00')), # Grand total + ]) class TestExecuteOptions(QueryBase): diff --git a/beanquery/query_render.py b/beanquery/query_render.py index 24d53618..0ca0b39f 100644 --- a/beanquery/query_render.py +++ b/beanquery/query_render.py @@ -16,6 +16,8 @@ from beancount.core import inventory from beancount.core import position +from beanquery.query_compile import Sentinel + class Align(enum.Enum): LEFT = 0 @@ -156,12 +158,19 @@ class SetRenderer(ColumnRenderer): def __init__(self, ctx): super().__init__(ctx) self.sep = ctx.listsep + self.ctx = ctx def update(self, value): - self.maxwidth = max(self.maxwidth, sum(len(x) + len(self.sep) for x in value) - len(self.sep)) + self.maxwidth = max(self.maxwidth, sum(len(str(x)) + len(self.sep) for x in value) - len(self.sep)) def format(self, value): - return self.sep.join(str(x) for x in sorted(value)) + """Format the value.""" + if not value: + return '' + # Get the appropriate renderer for the first item's type + item = next(iter(value)) + renderer = _get_renderer(type(item), self.ctx) + return self.sep.join(renderer.format(item) for item in sorted(value)) class DateRenderer(ColumnRenderer): @@ -449,10 +458,18 @@ def render_rows(rows, renderers, ctx): for row in rows: - # Render the row cells. Do not pass missing values to the - # renderers but substitute them with the appropriate - # placeholder string. - cells = [render.format(value) if value is not None else null for render, value in zip(renderers, row)] + # Render the row cells. Handle special cases: + # - Sentinel: Convert to string representation + # - None: Substitute with null placeholder + # - Regular values: Pass to renderer + cells = [] + for render, value in zip(renderers, row): + if isinstance(value, Sentinel): + cells.append(str(value)) + elif value is not None: + cells.append(render.format(value)) + else: + cells.append(null) if not any(isinstance(cell, list) for cell in cells): # No multi line cells. Yield the row.