diff --git a/beanquery/compiler.py b/beanquery/compiler.py index 6e6c442e..a7f7d4af 100644 --- a/beanquery/compiler.py +++ b/beanquery/compiler.py @@ -764,21 +764,24 @@ def _insert(self, node: ast.Insert): impl = getattr(table, 'insert', None) if impl is None: raise CompilationError(f'table "{node.table.name}" does not support insertion', node.table) - if len(node.values) != len(node.columns): - raise CompilationError( - f'column names and values mismatch: ' - f'expected {len(node.columns)} but {len(node.values)} values were supplied', node) - values = [EvalConstant(None)] * len(table.columns) columns = {name: i for i, name in enumerate(table.columns.keys())} - for column, value in zip(node.columns, node.values): - index = columns.get(column.name) - if index is None: - raise CompilationError(f'column "{column.name}" not found in table "{node.table.name}"', column) - expr = self._compile(value) - if not expr.dtype == table.columns.get(column.name).dtype: - raise CompilationError(f'expression has wrong type for column "{column.name}"', value) - values[index] = expr - return EvalInsert(table, values) + rows = [] + for row in node.values: + if len(row) != len(node.columns): + raise CompilationError( + f'column names and values mismatch: ' + f'expected {len(node.columns)} but {len(row)} values were supplied', node) + values = [EvalConstant(None)] * len(table.columns) + for column, value in zip(node.columns, row): + index = columns.get(column.name) + if index is None: + raise CompilationError(f'column "{column.name}" not found in table "{node.table.name}"', column) + expr = self._compile(value) + if not expr.dtype == table.columns.get(column.name).dtype: + raise CompilationError(f'expression has wrong type for column "{column.name}"', value) + values[index] = expr + rows.append(values) + return EvalInsert(table, rows) def transform_journal(journal): diff --git a/beanquery/parser/bql.ebnf b/beanquery/parser/bql.ebnf index 1c078b63..4c655165 100644 --- a/beanquery/parser/bql.ebnf +++ b/beanquery/parser/bql.ebnf @@ -397,5 +397,5 @@ create_table::CreateTable insert::Insert = 'INSERT' 'INTO' ~ table:table ['(' columns:','.{column} ')'] - 'VALUES' '(' values:','.{expression} ')' + 'VALUES' ','.{ '(' values+:','.{expression}+ ')' } ; diff --git a/beanquery/parser/parser.py b/beanquery/parser/parser.py index 2ec9ee84..35bf3f56 100644 --- a/beanquery/parser/parser.py +++ b/beanquery/parser/parser.py @@ -1256,17 +1256,30 @@ def block1(): self._token(')') self._define(['columns'], []) self._token('VALUES') - self._token('(') def sep2(): self._token(',') def block3(): - self._expression_() + self._token('(') + + def sep4(): + self._token(',') + + def block5(): + self._expression_() + self._positive_gather(block5, sep4) + self.add_last_node_to_name('values') + self._token(')') + self._define( + [], + ['values'], + ) self._gather(block3, sep2) - self.name_last_node('values') - self._token(')') - self._define(['columns', 'table', 'values'], []) + self._define( + ['columns', 'table'], + ['values'], + ) def main(filename, **kwargs): diff --git a/beanquery/query_compile.py b/beanquery/query_compile.py index eda87ec7..e83ed380 100644 --- a/beanquery/query_compile.py +++ b/beanquery/query_compile.py @@ -18,7 +18,7 @@ import operator from decimal import Decimal -from typing import List +from typing import List, Sequence from dateutil.relativedelta import relativedelta @@ -697,9 +697,9 @@ def __call__(self): @dataclasses.dataclass class EvalInsert: table: tables.Table - values: list[EvalNode] + rows: Sequence[Sequence[EvalNode]] def __call__(self): - values = tuple(value(None) for value in self.values) - self.table.insert(values) + for row in self.rows: + self.table.insert(tuple(value(None) for value in row)) return (), [] diff --git a/beanquery/query_execute_test.py b/beanquery/query_execute_test.py index 0358d632..f95efadb 100644 --- a/beanquery/query_execute_test.py +++ b/beanquery/query_execute_test.py @@ -1818,6 +1818,12 @@ def test_insert_placeholders(self): self.assertEqual(self.conn.tables['abcd'].data[0], values) self.assertEqual(curs.fetchall(), []) + def test_insert_many(self): + curs = self.conn.execute('''INSERT INTO abcd (a) VALUES (1), (2), (3), (4)''') + values = [row[0] for row in self.conn.tables['abcd'].data] + self.assertEqual(values, [1, 2, 3, 4]) + self.assertEqual(curs.fetchall(), []) + class TestCSVTable(unittest.TestCase): diff --git a/pyproject.toml b/pyproject.toml index 2778f7ee..ff959e54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,6 +84,7 @@ ignore = [ 'PLW2901', 'RUF012', 'RUF023', # unsorted-dunder-slots + 'RUF059', # unused-unpacked-variable 'UP007', 'UP032', ]