diff --git a/CHANGES.rst b/CHANGES.rst index d2cc076f..e31878e0 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -24,3 +24,7 @@ Version 0.1 (unreleased) behavior, the query should be written:: SELECT date, narration ORDER BY date DESC, narration DESC + +- Output names defined with ``SELECT ... AS`` can now be used in the + ``WHERE`` and ``HAVING`` clauses in addition to the ``GROUP BY`` and + ``ORDER BY`` clauses where they were already supported. diff --git a/beanquery/query_compile.py b/beanquery/query_compile.py index f5fafad5..46a7f867 100644 --- a/beanquery/query_compile.py +++ b/beanquery/query_compile.py @@ -341,20 +341,27 @@ class CompilationEnvironment: # The name of the context. context_name = None - # Maps of names to evaluators for columns and functions. - columns = None - functions = None + # Maps of names to evaluators for output names, columns, and functions. + names = {} + columns = {} + functions = {} def get_column(self, name): """Return a column accessor for the given named column. Args: name: A string, the name of the column to access. """ - try: - return self.columns[name]() - except KeyError as exc: - raise CompilationError("Invalid column name '{}' in {} context.".format( - name, self.context_name)) from exc + expr = self.names.get(name) + if expr is not None: + # Expression evaluatoes may keep state (for example + # aggregate functions) thus we need to return a copy. + return copy.copy(expr) + + column = self.columns.get(name) + if column is not None: + return column() + + raise CompilationError(f'Unknown column "{name}" in {self.context_name}') def get_function(self, name, operands): """Return a function accessor for the given named function. @@ -492,7 +499,7 @@ def is_hashable_type(node): return not issubclass(node.dtype, inventory.Inventory) -def find_unique_name(name, allocated_set): +def unique_name(name, allocated_set): """Come up with a unique name for 'name' amongst 'allocated_set'. Args: @@ -525,9 +532,15 @@ def compile_targets(targets, environ): Args: targets: A list of target expressions from the parser. environ: A compilation context for the targets. + Returns: - A list of compiled target expressions with resolved names. + A tuple containing list of compiled expressions and a dictionary + mapping explicit output names assigned with the AS keyword to + compiled extpressions. + """ + names = {} + # Bind the targets expressions to the execution context. if isinstance(targets, query_parser.Wildcard): # Insert the full list of available columns. @@ -539,11 +552,19 @@ def compile_targets(targets, environ): target_names = set() for target in targets: c_expr = compile_expression(target.expression, environ) - target_name = find_unique_name( - target.name or query_parser.get_expression_name(target.expression), - target_names) - target_names.add(target_name) - c_targets.append(EvalTarget(c_expr, target_name, is_aggregate(c_expr))) + if target.name: + # The target as an explicit output name: make sure that it + # does not collied with any other output name. + name = target.name + if name in target_names: + raise CompilationError(f'Duplicate output name "{name}" in SELECT list') + # Keep track of explicit output names. + names[name] = c_expr + else: + # Otherwise generate an unique output name. + name = unique_name(query_parser.get_expression_name(target.expression), target_names) + target_names.add(name) + c_targets.append(EvalTarget(c_expr, name, is_aggregate(c_expr))) columns, aggregates = get_columns_and_aggregates(c_expr) @@ -559,7 +580,7 @@ def compile_targets(targets, environ): raise CompilationError( "Aggregates of aggregates are not allowed") - return c_targets + return c_targets, names def compile_group_by(group_by, c_targets, environ): @@ -843,7 +864,9 @@ def compile_select(select, targets_environ, postings_environ, entries_environ): c_from = compile_from(select.from_clause, entries_environ) # Compile the targets. - c_targets = compile_targets(select.targets, targets_environ) + c_targets, output_names = compile_targets(select.targets, targets_environ) + targets_environ.names = output_names + postings_environ.names = output_names # Bind the WHERE expression to the execution environment. c_where = compile_expression(select.where_clause, postings_environ) diff --git a/beanquery/query_compile_test.py b/beanquery/query_compile_test.py index d25312a4..665bb087 100644 --- a/beanquery/query_compile_test.py +++ b/beanquery/query_compile_test.py @@ -189,12 +189,11 @@ def test_compile_EvalSub(self): class TestCompileMisc(unittest.TestCase): - def test_find_unique_names(self): - self.assertEqual('date', qc.find_unique_name('date', {})) - self.assertEqual('date', qc.find_unique_name('date', {'account', 'number'})) - self.assertEqual('date_1', qc.find_unique_name('date', {'date', 'number'})) - self.assertEqual('date_2', - qc.find_unique_name('date', {'date', 'date_1', 'date_3'})) + def test_unique_name(self): + self.assertEqual('date', qc.unique_name('date', {})) + self.assertEqual('date', qc.unique_name('date', {'account', 'number'})) + self.assertEqual('date_1', qc.unique_name('date', {'date', 'number'})) + self.assertEqual('date_2', qc.unique_name('date', {'date', 'date_1', 'date_3'})) class CompileSelectBase(unittest.TestCase): @@ -349,10 +348,9 @@ def test_compile_targets_wildcard(self): for target in query.c_targets)) def test_compile_targets_named(self): - # Test the wildcard expansion. - query = self.compile("SELECT length(account), account as a, date;") + query = self.compile("SELECT length(account) AS l, account AS a, date;") self.assertEqual( - [qc.EvalTarget(qe.F('length', str)([qe.AccountColumn()]), 'length_account', False), + [qc.EvalTarget(qe.F('length', str)([qe.AccountColumn()]), 'l', False), qc.EvalTarget(qe.AccountColumn(), 'a', False), qc.EvalTarget(qe.DateColumn(), 'date', False)], query.c_targets) @@ -610,6 +608,19 @@ def test_compile_order_by_aggregate(self): self.assertEqual([(1, False)], query.order_spec) +class TestCompileSelectNamed(CompileSelectBase): + + def test_compile_select_where_name(self): + query = self.compile(""" + SELECT date AS d WHERE d = 2022-03-30; + """) + + def test_compile_select_having_name(self): + query = self.compile(""" + SELECT sum(position) AS s GROUP BY year HAVING not empty(s); + """) + + class TestTranslationJournal(CompileSelectBase): maxDiff = 4096 diff --git a/beanquery/query_execute_test.py b/beanquery/query_execute_test.py index bf27e5ca..ff57c5a0 100644 --- a/beanquery/query_execute_test.py +++ b/beanquery/query_execute_test.py @@ -1079,5 +1079,44 @@ def test_flatten(self): ]) +class TestOutputNames(QueryBase): + + data = """ + 2020-01-01 open Assets:Bank + 2020-01-01 open Assets:Receivable + 2020-01-01 open Income:Sponsorship + + 2020-03-01 * "Sponsorship from A" + invoice: "A01" + Assets:Receivable 100.00 USD + Income:Sponsorship -100.00 USD + + 2020-03-01 * "Sponsorship from B" + invoice: "B01" + Assets:Receivable 30.00 USD + Income:Sponsorship -30.00 USD + + 2020-03-10 * "Payment from A" + invoice: "A01" + Assets:Bank 100.00 USD + Assets:Receivable -100.00 USD + """ + + def test_output_names(self): + self.check_query(self.data, """ + SELECT + entry_meta('invoice') AS invoice, + sum(position) AS balance + WHERE + root(account, 2) = 'Assets:Receivable' + GROUP BY + invoice + HAVING + not empty(balance); + """, + [('invoice', object), ('balance', inventory.Inventory)], + [('B01', inventory.from_string("30.00 USD"))]) + + if __name__ == '__main__': unittest.main()