diff --git a/beanquery/compiler.py b/beanquery/compiler.py index 6e6c442e..21bbe794 100644 --- a/beanquery/compiler.py +++ b/beanquery/compiler.py @@ -15,20 +15,22 @@ from .query_compile import ( EvalAggregator, - EvalAnd, EvalAll, + EvalAnd, EvalAny, EvalCoalesce, EvalColumn, EvalConstant, + EvalConstantSubquery1D, EvalCreateTable, EvalGetItem, EvalGetter, + EvalHashJoin, EvalInsert, EvalOr, EvalPivot, + EvalProjection, EvalQuery, - EvalConstantSubquery1D, EvalRow, EvalTarget, FUNCTIONS, @@ -54,7 +56,8 @@ def __init__(self, message, node=None): class Compiler: def __init__(self, context): self.context = context - self.stack = [context.tables.get(None)] + self.stack = [] + self.columns = {} @property def table(self): @@ -98,7 +101,11 @@ def _compile(self, node: Optional[ast.Node]): @_compile.register def _select(self, node: ast.Select): - self.stack.append(self.table) + self.stack.append(self.context.tables.get(None)) + + # JOIN. + if isinstance(node.from_clause, ast.Join): + return self._compile_join(node) # Compile the FROM clause. c_from_expr = self._compile_from(node.from_clause) @@ -177,11 +184,12 @@ def _compile_from(self, node): # FROM expression. if isinstance(node, ast.From): # Check if the FROM expression is a column name belongin to the current table. - if isinstance(node.expression, ast.Column): - column = self.table.columns.get(node.expression.name) + if isinstance(node.expression, ast.Column) and len(node.expression.ids) == 1: + name = node.expression.ids[0].name + column = self.table.columns.get(name) if column is None: # When it is not, threat it as a table name. - table = self.context.tables.get(node.expression.name) + table = self.context.tables.get(name) if table is not None: self.table = table return None @@ -214,8 +222,7 @@ def _compile_targets(self, targets): # Bind the targets expressions to the execution context. if isinstance(targets, ast.Asterisk): # Insert the full list of available columns. - targets = [ast.Target(ast.Column(name), None) - for name in self.table.wildcard_columns] + targets = [ast.Target(ast.Column([ast.Name(name)]), None) for name in self.table.wildcard_columns] # Compile targets. c_targets = [] @@ -287,7 +294,7 @@ def _compile_order_by(self, order_by, c_targets): # simple Column expressions. If they refer to a target name, we # resolve them. if isinstance(column, ast.Column): - name = column.name + name = '.'.join(i.name for i in column.ids) index = targets_name_map.get(name, None) # Otherwise we compile the expression and add it to the list of @@ -337,7 +344,7 @@ def _compile_pivot_by(self, pivot_by, targets, group_indexes): continue # Process target references by name. - if isinstance(column, ast.Column): + if isinstance(column, ast.Name): index = names.get(column.name, None) if index is None: raise CompilationError(f'PIVOT BY column {column!r} is not in the targets list') @@ -403,7 +410,7 @@ def _compile_group_by(self, group_by, c_targets): # simple Column expressions. If they refer to a target name, we # resolve them. if isinstance(column, ast.Column): - name = column.name + name = '.'.join(i.name for i in column.ids) index = targets_name_map.get(name, None) # Otherwise we compile the expression and add it to the list of @@ -475,12 +482,121 @@ def _compile_group_by(self, group_by, c_targets): return new_targets[len(c_targets):], group_indexes, having_index + def _compile_join(self, node): + join = node.from_clause + + left = self.context.tables.get(join.left.name) + if left is None: + raise CompilationError(f'table "{join.left.name}" does not exist', join.left) + right = self.context.tables.get(join.right.name) + if right is None: + raise CompilationError(f'table "{join.right.name}" does not exist', join.right) + self.table = right + self.stack.append(left) + + if join.using is not None: + join.constraint = ast.Equal( + ast.Column([ast.Name(join.left.name) , ast.Name(join.using.name)]), + ast.Column([ast.Name(join.right.name) , ast.Name(join.using.name)]), + ) + + constraint = self._compile(join.constraint) + keycolnames = [col for t, col in self.columns.keys() if t == left.name] + targets = self._compile_targets(node.targets) + + left_column_names = [col for t, col in self.columns.keys() if t == left.name] + right_column_names = [col for t, col in self.columns.keys() if t == right.name] + + left_p = EvalProjection(left, [left[col] for col in left_column_names]) + right_p = EvalProjection(right, [right[col] for col in right_column_names]) + + from beanquery.tables import Table + + def itemgetter(item, datatype): + def func(row): + return row[item] + func.dtype = datatype + func.__qualname__ = func.__name__ = f'column[{item}, {datatype.__name__}]' + return func + + table = Table() + table.columns = {} + for i, column in enumerate(sorted(self.columns.items(), key=lambda x: x[0][0])): + key, col = column + tname, colname = key + table.columns[f'{tname}.{colname}'] = itemgetter(i, col.dtype) + table.columns[f'{colname}'] = itemgetter(i, col.dtype) + + self.stack = [table] + constraint = self._compile(join.constraint) + + left_columns = {col: itemgetter(i, left.columns[col].dtype) for i, col in enumerate(left_column_names)} + keycols = [left_columns[name] for name in keycolnames] + def keyfunc(lrow, keycols=keycols): + return tuple(keycol(lrow) for keycol in keycols) + + join = EvalHashJoin(left_p, right_p, constraint, keyfunc) + + targets = self._compile_targets(node.targets) + cols = [] + for t in targets: + expr = t.c_expr + expr.name = t.name + cols.append(expr) + + return EvalProjection(join, cols) + + def _resolve_column(self, node: ast.Column): + parts = node.ids[::-1] + + # FIXME!! + if len(parts) > 1: + colname = f'{parts[-1].name}.{parts[-2].name}' + for table in reversed(self.stack): + column = table.columns.get(colname) + if column is not None: + self.columns[(table.name, colname)] = column + return column, parts[:-2] + + colname = parts.pop().name + for table in reversed(self.stack): + column = table.columns.get(colname) + if column is not None: + self.columns[(table.name, colname)] = column + return column, parts + if parts: + # table.column + name = colname + colname = parts.pop().name + for table in reversed(self.stack): + if table.name == name: + column = table.columns.get(colname) + if column is not None: + self.columns[(table.name, colname)] = column + return column, parts + raise CompilationError(f'column "{colname}" not found in table "{table.name}"', node) + @_compile.register def _column(self, node: ast.Column): - column = self.table.columns.get(node.name) - if column is not None: - return column - raise CompilationError(f'column "{node.name}" not found in table "{self.table.name}"', node) + column, parts = self._resolve_column(node) + for part in parts: + column = self._resolve_attribute(column, part) + return column + + # operand = None + # if isinstance(node.operand, ast.Column): + # # This can be table.column or column.attribute. + # if node.operand.name in self.table.columns: + # operand = self._column(node.operand) + # elif f'{node.operand.name}.{node.name}' in self.table.columns: + # return self._column(ast.Column(f'{node.operand.name}.{node.name}')) + # else: + # for table in reversed(self.stack): + # if table.name == node.operand.name: + # column = table.columns.get(node.name) + # if column: + # self.columns.append((table.name, node.name)) + # return column @_compile.register def _or(self, node: ast.Or): @@ -567,25 +683,25 @@ def _function(self, node: ast.Function): # Replace ``meta(key)`` with ``meta[key]``. if node.fname == 'meta': key = node.operands[0] - node = ast.Function('getitem', [ast.Column('meta', parseinfo=node.parseinfo), key]) + node = ast.Function('getitem', [ast.Column([ast.Name('meta')], parseinfo=node.parseinfo), key]) return self._compile(node) # Replace ``entry_meta(key)`` with ``entry.meta[key]``. if node.fname == 'entry_meta': key = node.operands[0] - node = ast.Function('getitem', [ast.Attribute(ast.Column('entry', parseinfo=node.parseinfo), 'meta'), key]) + node = ast.Function('getitem', [ast.Attribute(ast.Column([ast.Name('entry')], parseinfo=node.parseinfo), 'meta'), key]) return self._compile(node) # Replace ``any_meta(key)`` with ``getitem(meta, key, entry.meta[key])``. if node.fname == 'any_meta': key = node.operands[0] - node = ast.Function('getitem', [ast.Column('meta', parseinfo=node.parseinfo), key, ast.Function('getitem', [ - ast.Attribute(ast.Column('entry', parseinfo=node.parseinfo), 'meta'), key])]) + node = ast.Function('getitem', [ast.Column([ast.Name('meta')], parseinfo=node.parseinfo), key, ast.Function('getitem', [ + ast.Attribute(ast.Column([ast.Name('entry')], parseinfo=node.parseinfo), 'meta'), key])]) return self._compile(node) # Replace ``has_account(regexp)`` with ``('(?i)' + regexp) ~? any (accounts)``. if node.fname == 'has_account': - node = ast.Any(ast.Add(ast.Constant('(?i)'), node.operands[0]), '?~', ast.Column('accounts')) + node = ast.Any(ast.Add(ast.Constant('(?i)'), node.operands[0]), '?~', ast.Column([ast.Name('accounts')])) return self._compile(node) function = function(self.context, operands) @@ -601,9 +717,7 @@ def _subscript(self, node: ast.Subscript): return EvalGetItem(operand, node.key) raise CompilationError('column type is not subscriptable', node) - @_compile.register - def _attribute(self, node: ast.Attribute): - operand = self._compile(node.operand) + def _resolve_attribute(self, operand, node): dtype = types.ALIASES.get(operand.dtype, operand.dtype) if issubclass(dtype, types.Structure): getter = dtype.columns.get(node.name) @@ -612,6 +726,11 @@ def _attribute(self, node: ast.Attribute): return EvalGetter(operand, getter, getter.dtype) raise CompilationError('column type is not structured', node) + @_compile.register + def _attribute(self, node: ast.Attribute): + operand = self._compile(node.operand) + return self._resolve_attribute(operand, node) + @_compile.register def _unaryop(self, node: ast.UnaryOp): operand = self._compile(node.operand) @@ -711,7 +830,7 @@ def _constant(self, node: ast.Constant): # in the current table. if isinstance(node.value, str) and node.text and node.text[0] == '"': if node.value in self.table.columns: - return self._column(ast.Column(node.value)) + return self._column(ast.Column([ast.Name(node.value)])) return EvalConstant(node.value) @_compile.register @@ -732,7 +851,7 @@ def _journal(self, node: ast.Journal): @_compile.register def _print(self, node: ast.Print): - self.table = self.context.tables.get('entries') + self.stack.append(self.context.tables.get('entries')) expr = self._compile_from(node.from_clause) targets = [EvalTarget(EvalRow(), 'ROW(*)', False)] return EvalQuery(self.table, targets, expr, None, None, None, None, False) @@ -854,7 +973,7 @@ def get_target_name(target): if target.name is not None: return target.name if isinstance(target.expression, ast.Column): - return target.expression.name + return '.'.join(i.name for i in target.expression.ids) return target.expression.text.strip() diff --git a/beanquery/parser/ast.py b/beanquery/parser/ast.py index 0fc9e49e..c0e27761 100644 --- a/beanquery/parser/ast.py +++ b/beanquery/parser/ast.py @@ -145,6 +145,8 @@ class From(Node): clear: Optional[bool] = None parseinfo: Any = dataclasses.field(default=None, compare=False, repr=False) +Join = node('Join', 'left right constraint using') + # A GROUP BY clause. # # Attributes: @@ -180,11 +182,14 @@ def __repr__(self): # name: The table name. Table = node('Table', 'name') +Name = node('Name', 'name') + # A reference to a column. # # Attributes: # name: A string, the name of the column to access. -Column = node('Column', 'name') +Column = node('Column', 'ids') + # A function call. # diff --git a/beanquery/parser/bql.ebnf b/beanquery/parser/bql.ebnf index 5dd3ec25..5e7de865 100644 --- a/beanquery/parser/bql.ebnf +++ b/beanquery/parser/bql.ebnf @@ -25,7 +25,7 @@ statement select::Select = 'SELECT' ['DISTINCT' distinct:`True`] targets:(','.{ target }+ | asterisk) - ['FROM' from_clause:(_table | subselect | from)] + ['FROM' from_clause:(_table | join | subselect | from)] ['WHERE' where_clause:expression] ['GROUP' 'BY' group_by:groupby] ['ORDER' 'BY' order_by:','.{order}+] @@ -45,6 +45,10 @@ from::From | expression:expression ['OPEN' 'ON' open:date] ['CLOSE' ('ON' close:date | {} close:`True`)] ['CLEAR' clear:`True`] ; +join::Join + = left:table 'JOIN' ~ right:table ('ON' constraint:expression | 'USING' using:name) + ; + _table::Table = | name:/#([a-zA-Z_][a-zA-Z0-9_]*)?/ @@ -68,7 +72,7 @@ ordering ; pivotby::PivotBy - = columns+:(integer | column) ',' columns+:(integer | column) + = columns+:(integer | name) ',' columns+:(integer | name) ; target::Target @@ -301,10 +305,14 @@ function::Function | fname:identifier '(' operands+:asterisk ')' ; -column::Column +name::Name = name:identifier ; +column::Column + = ids:'.'.{ name }+ + ; + literal = | date @@ -396,6 +404,6 @@ create_table::CreateTable insert::Insert = 'INSERT' 'INTO' ~ table:table - ['(' columns:','.{column} ')'] - 'VALUES' '(' values:','.{expression} ')' + ['(' columns:','.{ name } ')'] + 'VALUES' '(' values:','.{ expression } ')' ; diff --git a/beanquery/parser/parser.py b/beanquery/parser/parser.py index 74f92c17..8fe1dcb1 100644 --- a/beanquery/parser/parser.py +++ b/beanquery/parser/parser.py @@ -157,13 +157,15 @@ def block1(): with self._choice(): with self._option(): self.__table_() + with self._option(): + self._join_() with self._option(): self._subselect_() with self._option(): self._from_() self._error( 'expecting one of: ' - '<_table> ' + '<_table> ' ) self.name_last_node('from_clause') self._define(['from_clause'], []) @@ -316,6 +318,32 @@ def _from_(self): ' ' ) + @tatsumasu('Join') + def _join_(self): + self._table_() + self.name_last_node('left') + self._token('JOIN') + self._cut() + self._table_() + self.name_last_node('right') + with self._group(): + with self._choice(): + with self._option(): + self._token('ON') + self._expression_() + self.name_last_node('constraint') + self._define(['constraint'], []) + with self._option(): + self._token('USING') + self._name_() + self.name_last_node('using') + self._define(['using'], []) + self._error( + 'expecting one of: ' + "'ON' 'USING'" + ) + self._define(['constraint', 'left', 'right', 'using'], []) + @tatsumasu('Table') def __table_(self): with self._choice(): @@ -400,10 +428,10 @@ def _pivotby_(self): with self._option(): self._integer_() with self._option(): - self._column_() + self._name_() self._error( 'expecting one of: ' - ' ' + ' ' ) self.add_last_node_to_name('columns') self._token(',') @@ -412,10 +440,10 @@ def _pivotby_(self): with self._option(): self._integer_() with self._option(): - self._column_() + self._name_() self._error( 'expecting one of: ' - ' ' + ' ' ) self.add_last_node_to_name('columns') self._define( @@ -955,7 +983,8 @@ def _atom_(self): "'%(' '%s' 'SELECT' " ' ' ' ' - ' ' + '' ) @tatsumasu('Placeholder') @@ -1011,11 +1040,22 @@ def block1(): '' ) - @tatsumasu('Column') - def _column_(self): + @tatsumasu('Name') + def _name_(self): self._identifier_() self.name_last_node('name') + @tatsumasu('Column') + def _column_(self): + + def sep0(): + self._token('.') + + def block1(): + self._name_() + self._positive_gather(block1, sep0) + self.name_last_node('ids') + @tatsumasu() def _literal_(self): with self._choice(): @@ -1250,7 +1290,7 @@ def sep0(): self._token(',') def block1(): - self._column_() + self._name_() self._gather(block1, sep0) self.name_last_node('columns') self._token(')') diff --git a/beanquery/parser_test.py b/beanquery/parser_test.py index cbff0813..cda80ff0 100644 --- a/beanquery/parser_test.py +++ b/beanquery/parser_test.py @@ -59,28 +59,28 @@ def test_select(self): self.assertParse( "SELECT date;", Select([ - ast.Target(ast.Column('date'), None) + ast.Target(ast.Column([ast.Name('date')]), None) ])) self.assertParse( "SELECT date, account", Select([ - ast.Target(ast.Column('date'), None), - ast.Target(ast.Column('account'), None) + ast.Target(ast.Column([ast.Name('date')]), None), + ast.Target(ast.Column([ast.Name('account')]), None) ])) self.assertParse( "SELECT date as xdate;", Select([ - ast.Target(ast.Column('date'), 'xdate') + ast.Target(ast.Column([ast.Name('date')]), 'xdate') ])) self.assertParse( "SELECT date as x, account, position as y;", Select([ - ast.Target(ast.Column('date'), 'x'), - ast.Target(ast.Column('account'), None), - ast.Target(ast.Column('position'), 'y') + ast.Target(ast.Column([ast.Name('date')]), 'x'), + ast.Target(ast.Column([ast.Name('account')]), None), + ast.Target(ast.Column([ast.Name('position')]), 'y') ])) def test_literals(self): @@ -115,38 +115,38 @@ def test_literals(self): self.assertParseTarget("SELECT ('x', 'y', 'z');", ast.Constant(['x', 'y', 'z'])) # column - self.assertParseTarget("SELECT date;", ast.Column('date')) + self.assertParseTarget("SELECT date;", ast.Column([ast.Name('date')])) def test_expressions(self): # comparison operators - self.assertParseTarget("SELECT a = 42;", ast.Equal(ast.Column('a'), ast.Constant(42))) - self.assertParseTarget("SELECT a != 42;", ast.NotEqual(ast.Column('a'), ast.Constant(42))) - self.assertParseTarget("SELECT a > 42;", ast.Greater(ast.Column('a'), ast.Constant(42))) - self.assertParseTarget("SELECT a >= 42;", ast.GreaterEq(ast.Column('a'), ast.Constant(42))) - self.assertParseTarget("SELECT a < 42;", ast.Less(ast.Column('a'), ast.Constant(42))) - self.assertParseTarget("SELECT a <= 42;", ast.LessEq(ast.Column('a'), ast.Constant(42))) - self.assertParseTarget("SELECT a ~ 'abc';", ast.Match(ast.Column('a'), ast.Constant('abc'))) - self.assertParseTarget("SELECT not a;", ast.Not(ast.Column('a'))) - self.assertParseTarget("SELECT a IS NULL;", ast.IsNull(ast.Column('a'))) - self.assertParseTarget("SELECT a IS NOT NULL;", ast.IsNotNull(ast.Column('a'))) + self.assertParseTarget("SELECT a = 42;", ast.Equal(ast.Column([ast.Name('a')]), ast.Constant(42))) + self.assertParseTarget("SELECT a != 42;", ast.NotEqual(ast.Column([ast.Name('a')]), ast.Constant(42))) + self.assertParseTarget("SELECT a > 42;", ast.Greater(ast.Column([ast.Name('a')]), ast.Constant(42))) + self.assertParseTarget("SELECT a >= 42;", ast.GreaterEq(ast.Column([ast.Name('a')]), ast.Constant(42))) + self.assertParseTarget("SELECT a < 42;", ast.Less(ast.Column([ast.Name('a')]), ast.Constant(42))) + self.assertParseTarget("SELECT a <= 42;", ast.LessEq(ast.Column([ast.Name('a')]), ast.Constant(42))) + self.assertParseTarget("SELECT a ~ 'abc';", ast.Match(ast.Column([ast.Name('a')]), ast.Constant('abc'))) + self.assertParseTarget("SELECT not a;", ast.Not(ast.Column([ast.Name('a')]))) + self.assertParseTarget("SELECT a IS NULL;", ast.IsNull(ast.Column([ast.Name('a')]))) + self.assertParseTarget("SELECT a IS NOT NULL;", ast.IsNotNull(ast.Column([ast.Name('a')]))) # bool expressions - self.assertParseTarget("SELECT a AND b;", ast.And([ast.Column('a'), ast.Column('b')])) - self.assertParseTarget("SELECT a AND b AND c;", ast.And([ast.Column('a'), ast.Column('b'), ast.Column('c')])) - self.assertParseTarget("SELECT a OR b;", ast.Or([ast.Column('a'), ast.Column('b')])) - self.assertParseTarget("SELECT a OR b OR c;", ast.Or([ast.Column('a'), ast.Column('b'), ast.Column('c')])) - self.assertParseTarget("SELECT a AND b OR c;", ast.Or([ast.And([ast.Column('a'), ast.Column('b')]), ast.Column('c')])) - self.assertParseTarget("SELECT NOT a;", ast.Not(ast.Column('a'))) + self.assertParseTarget("SELECT a AND b;", ast.And([ast.Column([ast.Name('a')]), ast.Column([ast.Name('b')])])) + self.assertParseTarget("SELECT a AND b AND c;", ast.And([ast.Column([ast.Name('a')]), ast.Column([ast.Name('b')]), ast.Column([ast.Name('c')])])) + self.assertParseTarget("SELECT a OR b;", ast.Or([ast.Column([ast.Name('a')]), ast.Column([ast.Name('b')])])) + self.assertParseTarget("SELECT a OR b OR c;", ast.Or([ast.Column([ast.Name('a')]), ast.Column([ast.Name('b')]), ast.Column([ast.Name('c')])])) + self.assertParseTarget("SELECT a AND b OR c;", ast.Or([ast.And([ast.Column([ast.Name('a')]), ast.Column([ast.Name('b')])]), ast.Column([ast.Name('c')])])) + self.assertParseTarget("SELECT NOT a;", ast.Not(ast.Column([ast.Name('a')]))) # math expressions with identifiers - self.assertParseTarget("SELECT a * b;", ast.Mul(ast.Column('a'), ast.Column('b'))) - self.assertParseTarget("SELECT a / b;", ast.Div(ast.Column('a'), ast.Column('b'))) - self.assertParseTarget("SELECT a + b;", ast.Add(ast.Column('a'), ast.Column('b'))) - self.assertParseTarget("SELECT a+b;", ast.Add(ast.Column('a'), ast.Column('b'))) - self.assertParseTarget("SELECT a - b;", ast.Sub(ast.Column('a'), ast.Column('b'))) - self.assertParseTarget("SELECT a-b;", ast.Sub(ast.Column('a'), ast.Column('b'))) - self.assertParseTarget("SELECT +a;", ast.Column('a')) - self.assertParseTarget("SELECT -a;", ast.Neg(ast.Column('a'))) + self.assertParseTarget("SELECT a * b;", ast.Mul(ast.Column([ast.Name('a')]), ast.Column([ast.Name('b')]))) + self.assertParseTarget("SELECT a / b;", ast.Div(ast.Column([ast.Name('a')]), ast.Column([ast.Name('b')]))) + self.assertParseTarget("SELECT a + b;", ast.Add(ast.Column([ast.Name('a')]), ast.Column([ast.Name('b')]))) + self.assertParseTarget("SELECT a+b;", ast.Add(ast.Column([ast.Name('a')]), ast.Column([ast.Name('b')]))) + self.assertParseTarget("SELECT a - b;", ast.Sub(ast.Column([ast.Name('a')]), ast.Column([ast.Name('b')]))) + self.assertParseTarget("SELECT a-b;", ast.Sub(ast.Column([ast.Name('a')]), ast.Column([ast.Name('b')]))) + self.assertParseTarget("SELECT +a;", ast.Column([ast.Name('a')])) + self.assertParseTarget("SELECT -a;", ast.Neg(ast.Column([ast.Name('a')]))) # math expressions with numerals self.assertParseTarget("SELECT 2 * 3;", ast.Mul(ast.Constant(2), ast.Constant(3))) @@ -164,8 +164,8 @@ def test_expressions(self): # functions self.assertParseTarget("SELECT random();", ast.Function('random', [])) - self.assertParseTarget("SELECT min(a);", ast.Function('min', [ast.Column('a')])) - self.assertParseTarget("SELECT min(a, b);", ast.Function('min', [ast.Column('a'), ast.Column('b')])) + self.assertParseTarget("SELECT min(a);", ast.Function('min', [ast.Column([ast.Name('a')])])) + self.assertParseTarget("SELECT min(a, b);", ast.Function('min', [ast.Column([ast.Name('a')]), ast.Column([ast.Name('b')])])) self.assertParseTarget("SELECT count(*);", ast.Function('count', [ast.Asterisk()])) def test_non_associative(self): @@ -178,9 +178,9 @@ def test_complex_expressions(self): "SELECT NOT a = (b != (42 AND 17));", ast.Not( ast.Equal( - ast.Column('a'), + ast.Column([ast.Name('a')]), ast.NotEqual( - ast.Column('b'), + ast.Column([ast.Name('b')]), ast.And([ ast.Constant(42), ast.Constant(17)]))))) @@ -192,45 +192,45 @@ def test_operators_precedence(self): self.assertParseTarget( "SELECT a AND b OR c AND d;", - ast.Or([ast.And([ast.Column('a'), ast.Column('b')]), - ast.And([ast.Column('c'), ast.Column('d')])])) + ast.Or([ast.And([ast.Column([ast.Name('a')]), ast.Column([ast.Name('b')])]), + ast.And([ast.Column([ast.Name('c')]), ast.Column([ast.Name('d')])])])) self.assertParseTarget( "SELECT a = 2 AND b != 3;", - ast.And([ast.Equal(ast.Column('a'), ast.Constant(2)), - ast.NotEqual(ast.Column('b'), ast.Constant(3))])) + ast.And([ast.Equal(ast.Column([ast.Name('a')]), ast.Constant(2)), + ast.NotEqual(ast.Column([ast.Name('b')]), ast.Constant(3))])) self.assertParseTarget( "SELECT not a AND b;", - ast.And([ast.Not(ast.Column('a')), ast.Column('b')])) + ast.And([ast.Not(ast.Column([ast.Name('a')])), ast.Column([ast.Name('b')])])) self.assertParseTarget( "SELECT a + b AND c - d;", - ast.And([ast.Add(ast.Column('a'), ast.Column('b')), - ast.Sub(ast.Column('c'), ast.Column('d'))])) + ast.And([ast.Add(ast.Column([ast.Name('a')]), ast.Column([ast.Name('b')])), + ast.Sub(ast.Column([ast.Name('c')]), ast.Column([ast.Name('d')]))])) self.assertParseTarget( "SELECT a * b + c / d - 3;", ast.Sub( ast.Add( ast.Mul( - ast.Column(name='a'), - ast.Column(name='b')), - ast.Div(ast.Column(name='c'), - ast.Column(name='d'))), + ast.Column([ast.Name('a')]), + ast.Column([ast.Name('b')])), + ast.Div(ast.Column([ast.Name('c')]), + ast.Column([ast.Name('d')]))), ast.Constant(value=3))) self.assertParseTarget( "SELECT 'orange' IN tags AND 'bananas' IN tags;", ast.And([ - ast.In(ast.Constant('orange'), ast.Column('tags')), - ast.In(ast.Constant('bananas'), ast.Column('tags'))])) + ast.In(ast.Constant('orange'), ast.Column([ast.Name('tags')])), + ast.In(ast.Constant('bananas'), ast.Column([ast.Name('tags')]))])) class TestSelectFrom(QueryParserTestBase): def test_select_from(self): - expr = ast.Equal(ast.Column('d'), ast.And([ast.Function('max', [ast.Column('e')]), ast.Constant(17)])) + expr = ast.Equal(ast.Column([ast.Name('d')]), ast.And([ast.Function('max', [ast.Column([ast.Name('e')])]), ast.Constant(17)])) with self.assertRaises(parser.ParseError): parser.parse("SELECT a, b FROM;") @@ -279,12 +279,12 @@ def test_select_from(self): class TestSelectWhere(QueryParserTestBase): def test_where(self): - expr = ast.Equal(ast.Column('d'), ast.And([ast.Function('max', [ast.Column('e')]), ast.Constant(17)])) + expr = ast.Equal(ast.Column([ast.Name('d')]), ast.And([ast.Function('max', [ast.Column([ast.Name('e')])]), ast.Constant(17)])) self.assertParse( "SELECT a, b WHERE d = (max(e) and 17);", Select([ - ast.Target(ast.Column('a'), None), - ast.Target(ast.Column('b'), None) + ast.Target(ast.Column([ast.Name('a')]), None), + ast.Target(ast.Column([ast.Name('b')]), None) ], None, expr)) with self.assertRaises(parser.ParseError): @@ -294,12 +294,12 @@ def test_where(self): class TestSelectFromAndWhere(QueryParserTestBase): def test_from_and_where(self): - expr = ast.Equal(ast.Column('d'), ast.And([ast.Function('max', [ast.Column('e')]), ast.Constant(17)])) + expr = ast.Equal(ast.Column([ast.Name('d')]), ast.And([ast.Function('max', [ast.Column([ast.Name('e')])]), ast.Constant(17)])) self.assertParse( "SELECT a, b FROM d = (max(e) and 17) WHERE d = (max(e) and 17);", Select([ - ast.Target(ast.Column('a'), None), - ast.Target(ast.Column('b'), None) + ast.Target(ast.Column([ast.Name('a')]), None), + ast.Target(ast.Column([ast.Name('b')]), None) ], ast.From(expr, None, None, None), expr)) @@ -311,16 +311,16 @@ def test_from_select(self): SELECT * FROM date = 2014-05-02 ) WHERE c = 5 LIMIT 100;""", Select([ - ast.Target(ast.Column('a'), None), - ast.Target(ast.Column('b'), None)], + ast.Target(ast.Column([ast.Name('a')]), None), + ast.Target(ast.Column([ast.Name('b')]), None)], Select( ast.Asterisk(), ast.From( ast.Equal( - ast.Column('date'), + ast.Column([ast.Name('date')]), ast.Constant(datetime.date(2014, 5, 2))), None, None, None)), - ast.Equal(ast.Column('c'), ast.Constant(5)), + ast.Equal(ast.Column([ast.Name('c')]), ast.Constant(5)), limit=100)) @@ -330,16 +330,16 @@ def test_groupby_one(self): self.assertParse( "SELECT * GROUP BY a;", Select(ast.Asterisk(), - group_by=ast.GroupBy([ast.Column('a')], None))) + group_by=ast.GroupBy([ast.Column([ast.Name('a')])], None))) def test_groupby_many(self): self.assertParse( "SELECT * GROUP BY a, b, c;", Select(ast.Asterisk(), group_by=ast.GroupBy([ - ast.Column('a'), - ast.Column('b'), - ast.Column('c')], None))) + ast.Column([ast.Name('a')]), + ast.Column([ast.Name('b')]), + ast.Column([ast.Name('c')])], None))) def test_groupby_expr(self): self.assertParse( @@ -348,18 +348,18 @@ def test_groupby_expr(self): group_by=ast.GroupBy([ ast.Greater( ast.Function('length', [ - ast.Column('a')]), + ast.Column([ast.Name('a')])]), ast.Constant(0)), - ast.Column('b')], None))) + ast.Column([ast.Name('b')])], None))) def test_groupby_having(self): self.assertParse( "SELECT * GROUP BY a HAVING sum(x) = 0;", Select(ast.Asterisk(), - group_by=ast.GroupBy([ast.Column('a')], + group_by=ast.GroupBy([ast.Column([ast.Name('a')])], ast.Equal( ast.Function('sum', [ - ast.Column('x')]), + ast.Column([ast.Name('x')])]), ast.Constant(0))))) def test_groupby_numbers(self): @@ -385,39 +385,39 @@ def test_orderby_one(self): "SELECT * ORDER BY a;", Select(ast.Asterisk(), order_by=[ - ast.OrderBy(ast.Column('a'), ast.Ordering.ASC)])) + ast.OrderBy(ast.Column([ast.Name('a')]), ast.Ordering.ASC)])) def test_orderby_many(self): self.assertParse( "SELECT * ORDER BY a, b, c;", Select(ast.Asterisk(), order_by=[ - ast.OrderBy(ast.Column('a'), ast.Ordering.ASC), - ast.OrderBy(ast.Column('b'), ast.Ordering.ASC), - ast.OrderBy(ast.Column('c'), ast.Ordering.ASC)])) + ast.OrderBy(ast.Column([ast.Name('a')]), ast.Ordering.ASC), + ast.OrderBy(ast.Column([ast.Name('b')]), ast.Ordering.ASC), + ast.OrderBy(ast.Column([ast.Name('c')]), ast.Ordering.ASC)])) def test_orderby_asc(self): self.assertParse( "SELECT * ORDER BY a ASC;", Select(ast.Asterisk(), order_by=[ - ast.OrderBy(ast.Column('a'), ast.Ordering.ASC)])) + ast.OrderBy(ast.Column([ast.Name('a')]), ast.Ordering.ASC)])) def test_orderby_desc(self): self.assertParse( "SELECT * ORDER BY a DESC;", Select(ast.Asterisk(), order_by=[ - ast.OrderBy(ast.Column('a'), ast.Ordering.DESC)])) + ast.OrderBy(ast.Column([ast.Name('a')]), ast.Ordering.DESC)])) def test_orderby_many_asc_desc(self): self.assertParse( "SELECT * ORDER BY a ASC, b DESC, c;", Select(ast.Asterisk(), order_by=[ - ast.OrderBy(ast.Column('a'), ast.Ordering.ASC), - ast.OrderBy(ast.Column('b'), ast.Ordering.DESC), - ast.OrderBy(ast.Column('c'), ast.Ordering.ASC)])) + ast.OrderBy(ast.Column([ast.Name('a')]), ast.Ordering.ASC), + ast.OrderBy(ast.Column([ast.Name('b')]), ast.Ordering.DESC), + ast.OrderBy(ast.Column([ast.Name('c')]), ast.Ordering.ASC)])) def test_orderby_empty(self): with self.assertRaises(parser.ParseError): @@ -438,7 +438,7 @@ def test_pivotby(self): self.assertParse( "SELECT * PIVOT BY a, b", - Select(ast.Asterisk(), pivot_by=ast.PivotBy([ast.Column('a'), ast.Column('b')]))) + Select(ast.Asterisk(), pivot_by=ast.PivotBy([ast.Name('a'), ast.Name('b')]))) self.assertParse( "SELECT * PIVOT BY 1, 2", @@ -449,7 +449,7 @@ class TestSelectOptions(QueryParserTestBase): def test_distinct(self): self.assertParse( - "SELECT DISTINCT x;", Select([ast.Target(ast.Column('x'), None)], distinct=True)) + "SELECT DISTINCT x;", Select([ast.Target(ast.Column([ast.Name('x')]), None)], distinct=True)) def test_limit_present(self): self.assertParse( @@ -473,7 +473,7 @@ def test_balances_from(self): None, ast.From( ast.Equal( - ast.Column('date'), + ast.Column([ast.Name('date')]), ast.Constant(datetime.date(2014, 1, 1))), None, True, None), None)) @@ -484,7 +484,7 @@ def test_balances_from_with_transformer(self): ast.Balances('units', ast.From( ast.Equal( - ast.Column('date'), + ast.Column([ast.Name('date')]), ast.Constant(datetime.date(2014, 1, 1))), None, True, None), None)) @@ -495,7 +495,7 @@ def test_balances_from_with_transformer_simple(self): ast.Balances('units', None, ast.Equal( - ast.Column('date'), + ast.Column([ast.Name('date')]), ast.Constant(datetime.date(2014, 1, 1))))) @@ -527,7 +527,7 @@ def test_journal_from(self): ast.Journal(None, None, ast.From( ast.Equal( - ast.Column('date'), + ast.Column([ast.Name('date')]), ast.Constant(datetime.date(2014, 1, 1)) ), None, True, None))) @@ -544,7 +544,7 @@ def test_print_from(self): ast.Print( ast.From( ast.Equal( - ast.Column('date'), + ast.Column([ast.Name('date')]), ast.Constant(datetime.date(2014, 1, 1)) ), None, True, None))) @@ -555,8 +555,8 @@ def test_comments(self): self.assertParse( """SELECT first, /* comment */ second""", Select([ - ast.Target(ast.Column('first'), None), - ast.Target(ast.Column('second'), None) + ast.Target(ast.Column([ast.Name('first')]), None), + ast.Target(ast.Column([ast.Name('second')]), None) ])) self.assertParse( @@ -564,29 +564,29 @@ def test_comments(self): comment */ second;""", Select([ - ast.Target(ast.Column('first'), None), - ast.Target(ast.Column('second'), None), + ast.Target(ast.Column([ast.Name('first')]), None), + ast.Target(ast.Column([ast.Name('second')]), None), ])) self.assertParse( """SELECT first, /**/ second;""", Select([ - ast.Target(ast.Column('first'), None), - ast.Target(ast.Column('second'), None), + ast.Target(ast.Column([ast.Name('first')]), None), + ast.Target(ast.Column([ast.Name('second')]), None), ])) self.assertParse( """SELECT first, /* /* */ second;""", Select([ - ast.Target(ast.Column('first'), None), - ast.Target(ast.Column('second'), None), + ast.Target(ast.Column([ast.Name('first')]), None), + ast.Target(ast.Column([ast.Name('second')]), None), ])) self.assertParse( """SELECT first, /* ; */ second;""", Select([ - ast.Target(ast.Column('first'), None), - ast.Target(ast.Column('second'), None), + ast.Target(ast.Column([ast.Name('first')]), None), + ast.Target(ast.Column([ast.Name('second')]), None), ])) @@ -608,20 +608,26 @@ def test_tosexp(self): (target expression: (add left: (column - name: 'a') + ids: ( + (name + name: 'a'))) right: (constant value: 1)))) from-clause: (table name: 'test') where-clause: (greater left: (column - name: 'a') + ids: ( + (name + name: 'a'))) right: (constant value: 42)) order-by: ( (orderby column: (column - name: 'b') + ids: ( + (name + name: 'b'))) ordering: desc)))''')) def test_walk(self): diff --git a/beanquery/query_compile.py b/beanquery/query_compile.py index eda87ec7..d5424efb 100644 --- a/beanquery/query_compile.py +++ b/beanquery/query_compile.py @@ -18,7 +18,7 @@ import operator from decimal import Decimal -from typing import List +from typing import List, Callable from dateutil.relativedelta import relativedelta @@ -703,3 +703,54 @@ def __call__(self): values = tuple(value(None) for value in self.values) self.table.insert(values) return (), [] + + +@dataclasses.dataclass +class EvalProjection: + table: tables.Table + expressions: list[EvalNode] + + def __iter__(self): + expressions = self.expressions + for row in self.table: + yield tuple(expr(row) for expr in expressions) + + def __call__(self): + columns = tuple(cursor.Column(expr.name, expr.dtype) for expr in self.expressions) + return columns, list(iter(self)) + + +@dataclasses.dataclass +class EvalScanJoin: + left: object + right: object + condition: EvalNode + + def __iter__(self): + for lrow in self.left: + for rrow in self.right: + row = (*lrow, *rrow) + if self.condition(row): + yield row + + +@dataclasses.dataclass +class EvalHashJoin: + left: object + right: object + condition: EvalNode + keyfunc: Callable + + def __iter__(self): + cache = {} + for lrow in self.left: + key = self.keyfunc(lrow) + rrow = cache.get(key) + if rrow is not None: + yield (*lrow, *rrow) + continue + for rrow in self.right: + row = (*lrow, *rrow) + if self.condition(row): + cache[key] = rrow + yield row diff --git a/beanquery/query_compile_test.py b/beanquery/query_compile_test.py index 4418f45a..ef0b5184 100644 --- a/beanquery/query_compile_test.py +++ b/beanquery/query_compile_test.py @@ -16,8 +16,10 @@ from beanquery import query_env as qe from beanquery import parser from beanquery import tables +from beanquery import types from beanquery.parser import ast from beanquery.sources import test +from beanquery.sources.beancount import GetItemColumn class Table: @@ -52,7 +54,7 @@ def setUpClass(cls): # parser for expressions cls.parser = parser.BQLParser(start='expression') cls.compiler = compiler.Compiler(cls.context) - cls.compiler.table = cls.context.tables['test'] + cls.compiler.stack.append(cls.context.tables['test']) def compile(self, expr): expr = self.parser.parse(expr, semantics=parser.BQLSemantics()) @@ -574,22 +576,20 @@ def test_compile_order_by_aggregate(self): class TestTranslationJournal(CompileSelectBase): - maxDiff = 4096 - def test_journal(self): journal = parser.parse("JOURNAL;") select = compiler.transform_journal(journal) self.assertEqual(select, ast.Select([ - ast.Target(ast.Column('date'), None), - ast.Target(ast.Column('flag'), None), + ast.Target(ast.Column([ast.Name('date')]), None), + ast.Target(ast.Column([ast.Name('flag')]), None), ast.Target(ast.Function('maxwidth', [ - ast.Column('payee'), ast.Constant(48)]), None), + ast.Column([ast.Name('payee')]), ast.Constant(48)]), None), ast.Target(ast.Function('maxwidth', [ - ast.Column('narration'), ast.Constant(80)]), None), - ast.Target(ast.Column('account'), None), - ast.Target(ast.Column('position'), None), - ast.Target(ast.Column('balance'), None), + ast.Column([ast.Name('narration')]), ast.Constant(80)]), None), + ast.Target(ast.Column([ast.Name('account')]), None), + ast.Target(ast.Column([ast.Name('position')]), None), + ast.Target(ast.Column([ast.Name('balance')]), None), ], None, None, None, None, None, None, None)) @@ -597,79 +597,79 @@ def test_journal_with_account(self): journal = parser.parse("JOURNAL 'liabilities';") select = compiler.transform_journal(journal) self.assertEqual(select, ast.Select([ - ast.Target(ast.Column('date'), None), - ast.Target(ast.Column('flag'), None), + ast.Target(ast.Column([ast.Name('date')]), None), + ast.Target(ast.Column([ast.Name('flag')]), None), ast.Target(ast.Function('maxwidth', [ - ast.Column('payee'), + ast.Column([ast.Name('payee')]), ast.Constant(48)]), None), ast.Target(ast.Function('maxwidth', [ - ast.Column('narration'), + ast.Column([ast.Name('narration')]), ast.Constant(80)]), None), - ast.Target(ast.Column('account'), None), - ast.Target(ast.Column('position'), None), - ast.Target(ast.Column('balance'), None), + ast.Target(ast.Column([ast.Name('account')]), None), + ast.Target(ast.Column([ast.Name('position')]), None), + ast.Target(ast.Column([ast.Name('balance')]), None), ], None, - ast.Match(ast.Column('account'), ast.Constant('liabilities')), + ast.Match(ast.Column([ast.Name('account')]), ast.Constant('liabilities')), None, None, None, None, None)) def test_journal_with_account_and_from(self): journal = parser.parse("JOURNAL 'liabilities' FROM year = 2014;") select = compiler.transform_journal(journal) self.assertEqual(select, ast.Select([ - ast.Target(ast.Column('date'), None), - ast.Target(ast.Column('flag'), None), + ast.Target(ast.Column([ast.Name('date')]), None), + ast.Target(ast.Column([ast.Name('flag')]), None), ast.Target(ast.Function('maxwidth', [ - ast.Column('payee'), + ast.Column([ast.Name('payee')]), ast.Constant(48)]), None), ast.Target(ast.Function('maxwidth', [ - ast.Column('narration'), + ast.Column([ast.Name('narration')]), ast.Constant(80)]), None), - ast.Target(ast.Column('account'), None), - ast.Target(ast.Column('position'), None), - ast.Target(ast.Column('balance'), None), + ast.Target(ast.Column([ast.Name('account')]), None), + ast.Target(ast.Column([ast.Name('position')]), None), + ast.Target(ast.Column([ast.Name('balance')]), None), ], - ast.From(ast.Equal(ast.Column('year'), ast.Constant(2014)), None, None, None), - ast.Match(ast.Column('account'), ast.Constant('liabilities')), + ast.From(ast.Equal(ast.Column([ast.Name('year')]), ast.Constant(2014)), None, None, None), + ast.Match(ast.Column([ast.Name('account')]), ast.Constant('liabilities')), None, None, None, None, None)) def test_journal_with_account_func_and_from(self): journal = parser.parse("JOURNAL 'liabilities' AT cost FROM year = 2014;") select = compiler.transform_journal(journal) self.assertEqual(select, ast.Select([ - ast.Target(ast.Column('date'), None), - ast.Target(ast.Column('flag'), None), + ast.Target(ast.Column([ast.Name('date')]), None), + ast.Target(ast.Column([ast.Name('flag')]), None), ast.Target(ast.Function('maxwidth', [ - ast.Column('payee'), + ast.Column([ast.Name('payee')]), ast.Constant(48)]), None), ast.Target(ast.Function('maxwidth', [ - ast.Column('narration'), + ast.Column([ast.Name('narration')]), ast.Constant(80)]), None), - ast.Target(ast.Column('account'), None), - ast.Target(ast.Function('cost', [ast.Column('position')]), None), - ast.Target(ast.Function('cost', [ast.Column('balance')]), None), + ast.Target(ast.Column([ast.Name('account')]), None), + ast.Target(ast.Function('cost', [ast.Column([ast.Name('position')])]), None), + ast.Target(ast.Function('cost', [ast.Column([ast.Name('balance')])]), None), ], - ast.From(ast.Equal(ast.Column('year'), ast.Constant(2014)), None, None, None), - ast.Match(ast.Column('account'), ast.Constant('liabilities')), + ast.From(ast.Equal(ast.Column([ast.Name('year')]), ast.Constant(2014)), None, None, None), + ast.Match(ast.Column([ast.Name('account')]), ast.Constant('liabilities')), None, None, None, None, None)) class TestTranslationBalance(CompileSelectBase): group_by = ast.GroupBy([ - ast.Column('account'), + ast.Column([ast.Name('account')]), ast.Function('account_sortkey', [ - ast.Column(name='account')])], None) + ast.Column([ast.Name('account')])])], None) - order_by = [ast.OrderBy(ast.Function('account_sortkey', [ast.Column('account')]), ast.Ordering.ASC)] + order_by = [ast.OrderBy(ast.Function('account_sortkey', [ast.Column([ast.Name('account')])]), ast.Ordering.ASC)] def test_balance(self): balance = parser.parse("BALANCES;") select = compiler.transform_balances(balance) self.assertEqual(select, ast.Select([ - ast.Target(ast.Column('account'), None), + ast.Target(ast.Column([ast.Name('account')]), None), ast.Target(ast.Function('sum', [ - ast.Column('position') + ast.Column([ast.Name('position')]) ]), None), ], None, None, self.group_by, self.order_by, None, None, None)) @@ -678,10 +678,10 @@ def test_balance_with_units(self): balance = parser.parse("BALANCES AT cost;") select = compiler.transform_balances(balance) self.assertEqual(select, ast.Select([ - ast.Target(ast.Column('account'), None), + ast.Target(ast.Column([ast.Name('account')]), None), ast.Target(ast.Function('sum', [ ast.Function('cost', [ - ast.Column('position') + ast.Column([ast.Name('position')]) ]) ]), None) ], @@ -691,14 +691,14 @@ def test_balance_with_units_and_from(self): balance = parser.parse("BALANCES AT cost FROM year = 2014;") select = compiler.transform_balances(balance) self.assertEqual(select, ast.Select([ - ast.Target(ast.Column('account'), None), + ast.Target(ast.Column([ast.Name('account')]), None), ast.Target(ast.Function('sum', [ ast.Function('cost', [ - ast.Column('position') + ast.Column([ast.Name('position')]) ]) ]), None), ], - ast.From(ast.Equal(ast.Column('year'), ast.Constant(2014)), None, None, None), + ast.From(ast.Equal(ast.Column([ast.Name('year')]), ast.Constant(2014)), None, None, None), None, self.group_by, self.order_by, None, None, None)) def test_print(self): @@ -730,7 +730,6 @@ def setUpClass(cls): def compile(self, query, params): c = compiler.Compiler(self.context) - c.table = self.context.tables.get('') return c.compile(parser.parse(query), params) def test_named_parameters(self): @@ -850,3 +849,54 @@ def test_quoted_string_in_expression(self): # if the double quoted string is not a table name, it is a string literal ideed query = self.compile('''SELECT "a" + "b" FROM postings''') self.assertEqual(query.c_targets[0].c_expr.value, 'ab') + + +class TestColumns(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.conn = beanquery.connect('') + cls.conn.execute('''CREATE TABLE foo (x int, y int)''') + class Baz(types.Structure): + columns = { + 'one': GetItemColumn('one', int), + 'two': GetItemColumn('two', int), + } + class Table(tables.Table): + name = 'bar' + columns = { + 'p': GetItemColumn(0, 'int'), + 'q': GetItemColumn(1, Baz), + } + cls.conn.tables['bar'] = Table() + + def compile(self, query, params=None): + return self.conn.compile(self.conn.parse(query)) + + def test_column(self): + q = self.compile('''SELECT x FROM foo''') + self.assertEqual(q.c_targets[0].c_expr, self.conn.tables['foo'].columns['x']) + + def test_column_invalid(self): + with self.assertRaisesRegex(beanquery.CompilationError, 'column "nope" not found in table "foo"'): + self.compile('''SELECT nope FROM foo''') + + def test_table_column(self): + q = self.compile('''SELECT foo.x FROM foo''') + self.assertEqual(q.c_targets[0].c_expr, self.conn.tables['foo'].columns['x']) + + def test_table_column_invalid(self): + with self.assertRaisesRegex(beanquery.CompilationError, 'column "nope" not found in table "foo"'): + self.compile('''SELECT foo.nope FROM foo''') + + def test_column_field(self): + q = self.compile('''SELECT q.one FROM bar''') + self.assertEqual(q.c_targets[0].c_expr.operand, self.conn.tables['bar'].columns['q']) + + def test_table_column_field(self): + q = self.compile('''SELECT bar.q.one FROM bar''') + self.assertEqual(q.c_targets[0].c_expr.operand, self.conn.tables['bar'].columns['q']) + + def test_table_column_field_invalid(self): + with self.assertRaisesRegex(beanquery.CompilationError, 'structured type has no attribute "nope"'): + self.compile('''SELECT bar.q.nope FROM bar''') diff --git a/beanquery/query_execute_test.py b/beanquery/query_execute_test.py index 0358d632..89440ce5 100644 --- a/beanquery/query_execute_test.py +++ b/beanquery/query_execute_test.py @@ -506,7 +506,7 @@ class TestFilterEntries(CommonInputBase, QueryBase): def compile(self, query): # use the ``entries`` table as default table for queries c = compiler.Compiler(self.ctx) - c.table = self.ctx.tables.get('entries') + self.ctx.tables[None] = self.ctx.tables.get('entries') return c.compile(self.ctx.parse(query)) @staticmethod @@ -1881,3 +1881,29 @@ def test_csv_source(self, filename): self.assertEqual(names, ['id', 'name', 'check', 'date', 'value']) types = [column.dtype for column in conn.tables['test'].columns.values()] self.assertEqual(types, [int, str, bool, datetime.date, Decimal]) + + +class TestJoin(unittest.TestCase): + + @classmethod + def setUpClass(cls): + conn = beanquery.connect('') + conn.execute('''CREATE TABLE foo (x int, a str)''') + conn.execute('''CREATE TABLE qux (x int, b str)''') + conn.execute('''INSERT INTO foo (x, a) VALUES (1, 'one')''') + conn.execute('''INSERT INTO foo (x, a) VALUES (2, 'two')''') + conn.execute('''INSERT INTO foo (x, a) VALUES (2, 'two')''') + conn.execute('''INSERT INTO foo (x, a) VALUES (2, 'two')''') + conn.execute('''INSERT INTO foo (x, a) VALUES (3, 'three')''') + conn.execute('''INSERT INTO qux (x, b) VALUES (1, 'uno')''') + conn.execute('''INSERT INTO qux (x, b) VALUES (2, 'due')''') + conn.execute('''INSERT INTO qux (x, b) VALUES (3, 'tre')''') + cls.conn = conn + + def test_join_on(self): + curs = self.conn.execute('''SELECT a + '=' + b FROM foo JOIN qux ON foo.x = qux.x + 1''') + self.assertEqual(curs.fetchall(), [('two=uno',), ('two=uno',), ('two=uno',), ('three=due',)]) + + def test_join_using(self): + curs = self.conn.execute('''SELECT a + '=' + b FROM foo JOIN qux USING x''') + self.assertEqual(curs.fetchall(), [('one=uno',), ('two=due',), ('two=due',), ('two=due',), ('three=tre',)]) diff --git a/pyproject.toml b/pyproject.toml index cc35d267..874a775c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,7 +63,7 @@ exclude_also = [ ] [tool.ruff] -line-length = 128 +line-length = 256 target-version = 'py38' [tool.ruff.lint]