From fc8d0d47af9fc1f186f5f986d0e7a98d3dc5c86a Mon Sep 17 00:00:00 2001 From: Moritz Lell Date: Fri, 7 Nov 2025 10:42:09 +0100 Subject: [PATCH 1/2] Support type matching between parameterized and unparameterized types Allow functions/operators registered with unparameterized types (e.g., set) to match operands with parameterized generic types (e.g., typing.Set[str]). --- beanquery/compiler.py | 11 +++--- beanquery/sources/beancount.py | 10 +++--- beanquery/types.py | 65 +++++++++++++++++++++++++++++----- 3 files changed, 67 insertions(+), 19 deletions(-) diff --git a/beanquery/compiler.py b/beanquery/compiler.py index 6e6c442e..bba59982 100644 --- a/beanquery/compiler.py +++ b/beanquery/compiler.py @@ -525,15 +525,14 @@ def _all(self, node): left = self._compile(node.left) - # lookup operator implementaton and check typing + # Lookup operator implementation and check typing. op = self._OPERATORS[node.op] - for func in OPERATORS[op]: - if func.__intypes__ == [right_element_dtype, left.dtype]: - break - else: + func = types.operator_lookup(OPERATORS[op], [right_element_dtype, left.dtype]) + + if func is None: raise CompilationError( f'operator "{op.__name__.lower()}(' - f'{left.dtype.__name__}, {right_element_dtype.__name__})" not supported', node) + f'{types.name(left.dtype)}, {types.name(right_element_dtype)})" not supported', node) # need to instantiate the operaotr implementation to get to the underlying function operator = func(None, None).operator diff --git a/beanquery/sources/beancount.py b/beanquery/sources/beancount.py index e388b6c6..fd9cfe5d 100644 --- a/beanquery/sources/beancount.py +++ b/beanquery/sources/beancount.py @@ -357,12 +357,12 @@ def description(entry): return None return ' | '.join(filter(None, [entry.payee, entry.narration])) - @columns.register(set) + @columns.register(typing.Set[str]) def tags(entry): """The set of tags of the transaction.""" return getattr(entry, 'tags', None) - @columns.register(set) + @columns.register(typing.Set[str]) def links(entry): """The set of links of the transaction.""" return getattr(entry, 'links', None) @@ -493,12 +493,12 @@ def description(context): """A combination of the payee + narration for the transaction of this posting.""" return ' | '.join(filter(None, [context.entry.payee, context.entry.narration])) - @columns.register(set) + @columns.register(typing.Set[str]) def tags(context): """The set of tags of the parent transaction for this posting.""" return context.entry.tags - @columns.register(set) + @columns.register(typing.Set[str]) def links(context): """The set of links of the parent transaction for this posting.""" return context.entry.links @@ -513,7 +513,7 @@ def account(context): """The account of the posting.""" return context.posting.account - @columns.register(set) + @columns.register(typing.Set[str]) def other_accounts(context): """The list of other accounts in the transaction, excluding that of this posting.""" return sorted({posting.account for posting in context.entry.postings if posting is not context.posting}) diff --git a/beanquery/types.py b/beanquery/types.py index 5391c8cf..462476d5 100644 --- a/beanquery/types.py +++ b/beanquery/types.py @@ -42,15 +42,43 @@ def __init_subclass__(cls): def _bases(t): + """Return the type hierarchy for a given type, excluding ``object``. + + This function extracts the Method Resolution Order (MRO) for a type, + which includes the type itself and all its base classes. The ``object`` + type is excluded from the hierarchy (except when the type IS ``object``) + because BQL uses ``object`` to represent untyped values, not as a universal + base type. This prevents functions registered for untyped values from + matching all typed values. + + For generic types like ``typing.Set[str]``, the origin type (``set``) is + extracted and its MRO is returned, allowing parameterized types to match + functions/operators registered with unparameterized types. + + Args: + t: A type object, which can be a plain type (str, int, set) or a + generic type (typing.Set[str], typing.List[int]). + + Returns: + A tuple of types representing the type hierarchy, with ``object`` excluded + unless the input type is ``object`` itself or ``NoneType``. + """ if t is NoneType: return (object,) + + # Handle generic types like typing.Set[str], typing.List[int], etc. + # Extract the origin type (e.g., set from Set[str]) and include it + # in the bases so that functions registered with unparameterized types + # (e.g., @function([set], ...)) can match parameterized types (Set[str]) + origin = typing.get_origin(t) + if origin is not None: + origin_bases = origin.__mro__ + if len(origin_bases) > 1 and origin_bases[-1] is object: + return origin_bases[:-1] + return origin_bases + bases = t.__mro__ if len(bases) > 1 and bases[-1] is object: - # All types that are not ``object`` have more than one class - # in their ``__mro__``. BQL uses ``object`` for untypes - # values. Do not return ``object`` as base for strict types, - # to avoid functions taking untyped onjects to accept all - # values. return bases[:-1] return bases @@ -59,9 +87,11 @@ def function_lookup(functions, name, operands): """Lookup a BQL function implementation. Args: - functions: The functions registry to interrogate. - name: The function name. - operands: Function operands. + functions: A dict mapping function names (str) to lists of function + implementations. Each implementation has an __intypes__ attribute + specifying the expected operand types. + name: The function name (str). + operands: Function operands, each with a .dtype attribute. Returns: A EvalNode (or subclass) instance or None if the function was not found. @@ -73,6 +103,25 @@ def function_lookup(functions, name, operands): return None +def operator_lookup(operators, operand_types): + """Lookup an operator implementation by matching operand types. + + Args: + operators: A list of operator implementations. Each implementation has + an __intypes__ attribute specifying the expected operand types as a + list (e.g., [str, str] for a binary operator on strings). + operand_types: Sequence of types for the operands (e.g., [str, str]). + + Returns: + An operator implementation or None if not found. + """ + for signature in itertools.product(*(_bases(t) for t in operand_types)): + for op in operators: + if op.__intypes__ == list(signature): + return op + return None + + # Map types to their BQL name. Used to find the name of the type cast funtion. MAP = { bool: 'bool', From f8ac6294c2365c09df380dbfa2aadb7273e64199 Mon Sep 17 00:00:00 2001 From: Moritz Lell Date: Fri, 7 Nov 2025 11:38:52 +0100 Subject: [PATCH 2/2] Allow any()/all() on LHS to allow ANY(set) ~ pattern There were two matching operators: * subject ~ pattern (case-insensitive) * pattern ?~ subject (case-sensitive) This is somewhat confusing, but allows for this: SELECT * WHERE "Assets:" ?~ ANY(accounts) The ANY operator was only allowed on the RHS. This commit extends it, so it can be used on either side, allowing for SELECT WHERE ANY(accounts) ~ "Assets:" --- beanquery/compiler.py | 65 ++++++++++++++++++------- beanquery/parser/ast.py | 4 +- beanquery/parser/bql.ebnf | 8 +++- beanquery/parser/parser.py | 84 ++++++++++++++++++++++++--------- beanquery/query_compile.py | 44 ++++++++++++----- beanquery/query_execute_test.py | 8 ++-- 6 files changed, 155 insertions(+), 58 deletions(-) diff --git a/beanquery/compiler.py b/beanquery/compiler.py index bba59982..ced692d4 100644 --- a/beanquery/compiler.py +++ b/beanquery/compiler.py @@ -506,39 +506,70 @@ def _and(self, node: ast.And): @_compile.register(ast.All) @_compile.register(ast.Any) def _all(self, node): + # This parses a node of the form + # All(left, op, right, side), which arises from the syntax: + # + # ALL( ... ) (side == 'lhs') + # ALL( ... ) (side == 'rhs') + # + # Example: ANY(accounts) = "Assets:Checking" + # + left = self._compile(node.left) right = self._compile(node.right) + + if node.side == 'lhs': + collection = left + collection_node = node.left + value = right + elif node.side == 'rhs': + collection = right + collection_node = node.right + value = left + + if isinstance(collection, EvalQuery): + if len(collection.columns) != 1: + raise CompilationError('subquery has too many columns', collection_node) + collection = EvalConstantSubquery1D(collection) + + collection_dtype = typing.get_origin(collection.dtype) or collection.dtype + value_dtype = typing.get_origin(value.dtype) or value.dtype - if isinstance(right, EvalQuery): - if len(right.columns) != 1: - raise CompilationError('subquery has too many columns', node.right) - right = EvalConstantSubquery1D(right) + if collection_dtype not in {list, set, EvalConstantSubquery1D}: + raise CompilationError( + f'ANY/ALL requires a collection (list, set, or subquery), got {types.name(collection.dtype)}', + node) - right_dtype = typing.get_origin(right.dtype) or right.dtype - if right_dtype not in {list, set}: - raise CompilationError(f'not a list or set but {right_dtype}', node.right) - args = typing.get_args(right.dtype) + collection_dtype = typing.get_origin(collection.dtype) or collection.dtype + if collection_dtype not in {list, set}: + raise CompilationError(f'not a list or set but {collection_dtype}', collection_node) + args = typing.get_args(collection.dtype) if args: assert len(args) == 1 - right_element_dtype = args[0] + collection_element_dtype = args[0] else: - right_element_dtype = object - - left = self._compile(node.left) + collection_element_dtype = object # Lookup operator implementation and check typing. op = self._OPERATORS[node.op] - func = types.operator_lookup(OPERATORS[op], [right_element_dtype, left.dtype]) + if node.side == 'rhs': + # value op ANY(collection) -> operator(value, element) + func = types.operator_lookup(OPERATORS[op], [value.dtype, collection_element_dtype]) + left_type, right_type = value.dtype, collection_element_dtype + else: # node.side == 'lhs' + # ANY(collection) op value -> operator(element, value) + func = types.operator_lookup(OPERATORS[op], [collection_element_dtype, value.dtype]) + left_type, right_type = collection_element_dtype, value.dtype if func is None: raise CompilationError( f'operator "{op.__name__.lower()}(' - f'{types.name(left.dtype)}, {types.name(right_element_dtype)})" not supported', node) + f'{types.name(left_type)}, {types.name(right_type)})" not supported', node) # need to instantiate the operaotr implementation to get to the underlying function operator = func(None, None).operator cls = EvalAll if type(node) is ast.All else EvalAny - return cls(operator, left, right) + return cls(operator, collection, value, node.side) @_compile.register def _function(self, node: ast.Function): @@ -582,9 +613,9 @@ def _function(self, node: ast.Function): ast.Attribute(ast.Column('entry', parseinfo=node.parseinfo), 'meta'), key])]) return self._compile(node) - # Replace ``has_account(regexp)`` with ``('(?i)' + regexp) ~? any (accounts)``. + # Replace ``has_account(regexp)`` with ``any (accounts) ~ ('(?i)' + regexp) ``. if node.fname == 'has_account': - node = ast.Any(ast.Add(ast.Constant('(?i)'), node.operands[0]), '?~', ast.Column('accounts')) + node = ast.Any(ast.Column('accounts'), '~', ast.Add(ast.Constant('(?i)'), node.operands[0]), side = 'lhs') return self._compile(node) function = function(self.context, operands) diff --git a/beanquery/parser/ast.py b/beanquery/parser/ast.py index 0fc9e49e..638baf0e 100644 --- a/beanquery/parser/ast.py +++ b/beanquery/parser/ast.py @@ -314,8 +314,8 @@ class Sub(BinaryOp): __slots__ = () -Any = node('Any', 'left op right') -All = node('All', 'left op right') +Any = node('Any', 'left op right side') +All = node('All', 'left op right side') CreateTable = node('CreateTable', 'name columns using query') diff --git a/beanquery/parser/bql.ebnf b/beanquery/parser/bql.ebnf index 1c078b63..19aec514 100644 --- a/beanquery/parser/bql.ebnf +++ b/beanquery/parser/bql.ebnf @@ -133,11 +133,15 @@ comparison ; any::Any - = left:sum op:op 'any' '(' right:expression ')' + = + | left:sum op:op 'any' '(' right:expression ')' side:`rhs` + | 'any' '(' left:expression ')' op:op right:sum side:`lhs` ; all::All - = left:sum op:op 'all' '(' right:expression ')' + = + | left:sum op:op 'all' '(' right:expression ')' side:`rhs` + | 'all' '(' left:expression ')' op:op right:sum side:`lhs` ; op diff --git a/beanquery/parser/parser.py b/beanquery/parser/parser.py index 2ec9ee84..bb48f9f3 100644 --- a/beanquery/parser/parser.py +++ b/beanquery/parser/parser.py @@ -577,7 +577,7 @@ def _comparison_(self): self._sum_() self._error( 'expecting one of: ' - ' ' + "'all' 'any' " ' ' ' ' ' ' @@ -586,29 +586,71 @@ def _comparison_(self): @tatsumasu('Any') @nomemo def _any_(self): - self._sum_() - self.name_last_node('left') - self._op_() - self.name_last_node('op') - self._token('any') - self._token('(') - self._expression_() - self.name_last_node('right') - self._token(')') - self._define(['left', 'op', 'right'], []) + with self._choice(): + with self._option(): + self._sum_() + self.name_last_node('left') + self._op_() + self.name_last_node('op') + self._token('any') + self._token('(') + self._expression_() + self.name_last_node('right') + self._token(')') + self._constant('rhs') + self.name_last_node('side') + self._define(['left', 'op', 'right', 'side'], []) + with self._option(): + self._token('any') + self._token('(') + self._expression_() + self.name_last_node('left') + self._token(')') + self._op_() + self.name_last_node('op') + self._sum_() + self.name_last_node('right') + self._constant('lhs') + self.name_last_node('side') + self._define(['left', 'op', 'right', 'side'], []) + self._error( + 'expecting one of: ' + "'any' " + ) @tatsumasu('All') def _all_(self): - self._sum_() - self.name_last_node('left') - self._op_() - self.name_last_node('op') - self._token('all') - self._token('(') - self._expression_() - self.name_last_node('right') - self._token(')') - self._define(['left', 'op', 'right'], []) + with self._choice(): + with self._option(): + self._sum_() + self.name_last_node('left') + self._op_() + self.name_last_node('op') + self._token('all') + self._token('(') + self._expression_() + self.name_last_node('right') + self._token(')') + self._constant('rhs') + self.name_last_node('side') + self._define(['left', 'op', 'right', 'side'], []) + with self._option(): + self._token('all') + self._token('(') + self._expression_() + self.name_last_node('left') + self._token(')') + self._op_() + self.name_last_node('op') + self._sum_() + self.name_last_node('right') + self._constant('lhs') + self.name_last_node('side') + self._define(['left', 'op', 'right', 'side'], []) + self._error( + 'expecting one of: ' + "'all' " + ) @tatsumasu() def _op_(self): diff --git a/beanquery/query_compile.py b/beanquery/query_compile.py index eda87ec7..193d9a70 100644 --- a/beanquery/query_compile.py +++ b/beanquery/query_compile.py @@ -448,13 +448,14 @@ def __call__(self, context): class EvalAny(EvalNode): - __slots__ = ('op', 'left', 'right') + __slots__ = ('op', 'left', 'right', 'collection_side') - def __init__(self, op, left, right): + def __init__(self, op, left, right, collection_side): super().__init__(bool) self.op = op self.left = left self.right = right + self.collection_side = collection_side def __call__(self, row): left = self.left(row) @@ -463,26 +464,45 @@ def __call__(self, row): right = self.right(row) if right is None: return None - return any(self.op(left, x) for x in right) + + if self.collection_side == 'right': + # Original form: value op ANY(collection) + return any(self.op(left, x) for x in right) + else: # collection_side == 'left' + # New form: ANY(collection) op value + return any(self.op(x, right) for x in left) class EvalAll(EvalNode): - __slots__ = ('op', 'left', 'right') + __slots__ = ('op', 'collection', 'value', 'side') - def __init__(self, op, left, right): + def __init__(self, op, collection, value, side): + """ + Either: ALL() (side == 'lhs') + Or: ALL() (side == 'rhs') + """ super().__init__(bool) self.op = op - self.left = left - self.right = right + self.collection = collection + self.value = value + if side not in ['lhs', 'rhs']: + raise ValueError('EvalAll: Parameter "side" must be one of "lhs", "rhs"') + self.side = side def __call__(self, row): - left = self.left(row) - if left is None: + collection = self.collection(row) + if collection is None: return None - right = self.right(row) - if right is None: + value = self.value(row) + if value is None: return None - return all(self.op(left, x) for x in right) + + if self.side == 'rhs': + # Syntax form: value op ALL(collection) + return all(self.op(value, x) for x in collection) + else: # side == 'lhs' + # Syntax form: ALL(collection) op value + return all(self.op(x, value) for x in collection) class EvalRow(EvalNode): diff --git a/beanquery/query_execute_test.py b/beanquery/query_execute_test.py index 0358d632..9bc0b75c 100644 --- a/beanquery/query_execute_test.py +++ b/beanquery/query_execute_test.py @@ -1667,11 +1667,11 @@ def test_in_accounts_transactions(self): ] ) - def test_mathes_any_accounts_transactions(self): + def test_matches_any_accounts_transactions(self): self.check_query(self.data, """ SELECT date, narration FROM #transactions - WHERE ':Two' ?~ ANY(accounts) + WHERE ANY(accounts) ~ ':Two' """, (('date', datetime.date), ('narration', str)), [ @@ -1679,11 +1679,11 @@ def test_mathes_any_accounts_transactions(self): ] ) - def test_mathes_all_accounts_transactions(self): + def test_matches_all_accounts_transactions(self): self.check_query(self.data, """ SELECT date, narration FROM #transactions - WHERE '(?i):two|:cash' ?~ ALL(accounts) + WHERE ALL(accounts) ~ '(?i):two|:cash' """, (('date', datetime.date), ('narration', str)), [