diff --git a/setup.py b/setup.py
index 645af58..908996e 100644
--- a/setup.py
+++ b/setup.py
@@ -11,7 +11,7 @@
     url='https://github.com/Khan/tinyquery',
     keywords=['bigquery'],
     packages=['tinyquery'],
-    install_requires=['arrow==0.12.1', 'ply==3.10'],
+    install_requires=['arrow==0.12.1', 'ply==3.10', 'six==1.11.0'],
     classifiers=[
         'License :: OSI Approved :: MIT License',
         'Programming Language :: Python :: 2',
diff --git a/tinyquery/api_client.py b/tinyquery/api_client.py
index 0a6a436..62ce4f7 100644
--- a/tinyquery/api_client.py
+++ b/tinyquery/api_client.py
@@ -4,9 +4,13 @@
 This can be used in place of the value returned by apiclient.discovery.build().
 """
+from __future__ import absolute_import
+
 
 import functools
 import json
 
+import six
+
 
 class TinyQueryApiClient(object):
     def __init__(self, tq_service):
@@ -36,7 +40,9 @@ def __init__(self, func, args, kwargs):
         self.args = args
         self.kwargs = kwargs
 
-    def execute(self):
+    def execute(self, **kwargs):
+        # Accept (and ignore) keyword arguments such as num_retries for
+        # compatibility with the real API client's execute() signature.
         return self.func(*self.args, **self.kwargs)
 
 
@@ -153,7 +157,7 @@ def insert(self, projectId, body):
             create_disposition, write_disposition)
         else:
             assert False, 'Unknown job type: {}'.format(
-                body['configuration'].keys())
+                list(body['configuration'].keys()))
 
     @staticmethod
     def _get_config_table(config, key):
@@ -225,16 +229,16 @@ def schema_from_table(table):
     """Given a tinyquery.Table, build an API-compatible schema."""
     return {'fields': [
         {'name': name, 'type': col.type}
-        for name, col in table.columns.iteritems()
+        for name, col in table.columns.items()
     ]}
 
 
 def rows_from_table(table):
     """Given a tinyquery.Table, build an API-compatible rows object."""
     result_rows = []
-    for i in xrange(table.num_rows):
+    for i in six.moves.xrange(table.num_rows):
         field_values = [{'v': str(col.values[i])}
-                        for col in table.columns.itervalues()]
+                        for col in table.columns.values()]
         result_rows.append({
             'f': field_values
         })
diff --git a/tinyquery/api_client_test.py b/tinyquery/api_client_test.py
index 541e63b..96617f8 100644
--- a/tinyquery/api_client_test.py
+++ b/tinyquery/api_client_test.py
@@ -1,8 +1,10 @@
+from __future__ import absolute_import
+
 import unittest
 
-import api_client
-import tq_types
-import tinyquery
+from tinyquery import api_client
+from tinyquery import tq_types
+from tinyquery import tinyquery
 
 
 class ApiClientTest(unittest.TestCase):
@@ -144,7 +146,7 @@ def test_table_copy(self):
             }
         ).execute()
 
-        for _ in xrange(5):
+        for _ in range(5):
             self.tq_service.jobs().insert(
                 projectId='test_project',
                 body={
diff --git a/tinyquery/compiler.py b/tinyquery/compiler.py
index c95ec6f..9ab0685 100644
--- a/tinyquery/compiler.py
+++ b/tinyquery/compiler.py
@@ -4,19 +4,20 @@
 -Validate that the expression is well-typed.
 -Resolve all select fields to their aliases and types.
""" +from __future__ import absolute_import + import collections import itertools -import parser -import runtime -import tq_ast -import typed_ast -import type_context -import tq_types - +import six -class CompileError(Exception): - pass +from tinyquery import exceptions +from tinyquery import parser +from tinyquery import runtime +from tinyquery import tq_ast +from tinyquery import typed_ast +from tinyquery import type_context +from tinyquery import tq_types def compile_text(text, tables_by_name): @@ -93,7 +94,7 @@ def expand_select_fields(self, select_fields, table_expr): """ table_ctx = table_expr.type_ctx star_select_fields = [] - for table_name, col_name in table_ctx.columns.iterkeys(): + for table_name, col_name in table_ctx.columns: if table_name is not None: col_ref = table_name + '.' + col_name else: @@ -112,9 +113,9 @@ def expand_select_fields(self, select_fields, table_expr): elif (field.expr and isinstance(field.expr, tq_ast.ColumnId) and field.expr.name.endswith('.*')): prefix = field.expr.name[:-len('.*')] - record_star_fields = filter( - lambda f: f.alias.startswith(prefix), - star_select_fields) + record_star_fields = [f + for f in star_select_fields + if f.alias.startswith(prefix)] result_fields.extend(record_star_fields) else: result_fields.append(field) @@ -213,7 +214,7 @@ def compile_table_expr(self, table_expr): return method(table_expr) def compile_table_expr_TableId(self, table_expr): - import tinyquery # TODO(colin): fix circular import + from tinyquery import tinyquery # TODO(colin): fix circular import table = self.tables_by_name[table_expr.name] if isinstance(table, tinyquery.Table): return self.compile_table_ref(table_expr, table) @@ -225,7 +226,7 @@ def compile_table_expr_TableId(self, table_expr): def compile_table_ref(self, table_expr, table): alias = table_expr.alias or table_expr.name columns = collections.OrderedDict([ - (name, column.type) for name, column in table.columns.iteritems() + (name, column.type) for name, column in table.columns.items() ]) type_ctx = type_context.TypeContext.from_table_and_columns( alias, columns, None) @@ -265,8 +266,7 @@ def compile_table_expr_Join(self, table_expr): [table_expr.base], (join_part.table_expr for join_part in table_expr.join_parts) ) - compiled_result = map(self.compile_joined_table, - table_expressions) + compiled_result = [self.compile_joined_table(x) for x in table_expressions] compiled_table_exprs, compiled_aliases = zip(*compiled_result) type_contexts = [compiled_table.type_ctx for compiled_table in compiled_table_exprs] @@ -280,9 +280,11 @@ def compile_table_expr_Join(self, table_expr): type_contexts) return typed_ast.Join( base=compiled_table_exprs[0], - tables=zip(compiled_table_exprs[1:], - (join_part.join_type - for join_part in table_expr.join_parts)), + # wrapping in list() for python 3 support (shouldn't be a large number + # of items so performance impact should be minimal) + tables=list(zip(compiled_table_exprs[1:], + (join_part.join_type + for join_part in table_expr.join_parts))), conditions=result_fields, type_ctx=result_type_ctx) @@ -294,7 +296,7 @@ def compile_joined_table(self, table_expr): elif isinstance(table_expr, tq_ast.TableId): alias = table_expr.name else: - raise CompileError('Table expression must have an alias name.') + raise exceptions.CompileError('Table expression must have an alias name.') result_ctx = compiled_table.type_ctx.context_with_full_alias(alias) compiled_table = compiled_table.with_type_ctx(result_ctx) return compiled_table, alias @@ -366,7 +368,7 @@ def 
compile_join_field(expr, join_type): left_column_id)] # Fall through to the error case if the aliases are the # same for both sides. - raise CompileError('JOIN conditions must consist of an AND of = ' + raise exceptions.CompileError('JOIN conditions must consist of an AND of = ' 'comparisons between two field on distinct ' 'tables. Got expression %s' % expr) return [compile_join_field(expr, join_type) @@ -434,7 +436,7 @@ def compile_groups(self, groups, select_fields, aliases, table_ctx): def compile_select_field(self, expr, alias, within_clause, type_ctx): if within_clause is not None and within_clause != 'RECORD' and ( expr.args[0].name.split('.')[0] != within_clause): - raise CompileError('WITHIN clause syntax error') + raise exceptions.CompileError('WITHIN clause syntax error') else: compiled_expr = self.compile_expr(expr, type_ctx) return typed_ast.SelectField(compiled_expr, alias, within_clause) @@ -467,7 +469,7 @@ def compile_Literal(self, expr, type_ctx): return typed_ast.Literal(expr.value, tq_types.INT) if isinstance(expr.value, float): return typed_ast.Literal(expr.value, tq_types.FLOAT) - elif isinstance(expr.value, basestring): + elif isinstance(expr.value, six.string_types): return typed_ast.Literal(expr.value, tq_types.STRING) elif expr.value is None: return typed_ast.Literal(expr.value, tq_types.NONETYPE) @@ -485,7 +487,7 @@ def compile_UnaryOperator(self, expr, type_ctx): try: result_type = func.check_types(compiled_val.type) except TypeError: - raise CompileError('Invalid type for operator {}: {}'.format( + raise exceptions.CompileError('Invalid type for operator {}: {}'.format( expr.operator, [compiled_val.type])) return typed_ast.FunctionCall(func, [compiled_val], result_type) @@ -501,7 +503,7 @@ def compile_BinaryOperator(self, expr, type_ctx): result_type = func.check_types(compiled_left.type, compiled_right.type) except TypeError: - raise CompileError('Invalid types for operator {}: {}'.format( + raise exceptions.CompileError('Invalid types for operator {}: {}'.format( expr.operator, [arg.type for arg in [compiled_left, compiled_right]])) @@ -516,7 +518,7 @@ def compile_FunctionCall(self, expr, type_ctx): # that the evaluator knows to change the context. 
if self.is_innermost_aggregate(expr): if type_ctx.aggregate_context is None: - raise CompileError('Unexpected aggregate function.') + raise exceptions.CompileError('Unexpected aggregate function.') sub_expr_ctx = type_ctx.aggregate_context ast_type = typed_ast.AggregateFunctionCall else: @@ -530,7 +532,7 @@ def compile_FunctionCall(self, expr, type_ctx): result_type = func.check_types( *(arg.type for arg in compiled_args)) except TypeError: - raise CompileError('Invalid types for function {}: {}'.format( + raise exceptions.CompileError('Invalid types for function {}: {}'.format( expr.name, [arg.type for arg in compiled_args])) return ast_type(func, compiled_args, result_type) @@ -557,7 +559,7 @@ def get_aliases(cls, select_field_list): for alias in proposed_aliases: if alias is not None: if alias in used_aliases: - raise CompileError( + raise exceptions.CompileError( 'Ambiguous column name {}.'.format(alias)) used_aliases.add(alias) diff --git a/tinyquery/compiler_test.py b/tinyquery/compiler_test.py index f05bb8c..7aeb52c 100644 --- a/tinyquery/compiler_test.py +++ b/tinyquery/compiler_test.py @@ -1,18 +1,21 @@ # TODO(colin): fix these lint errors (http://pep8.readthedocs.io/en/release-1.7.x/intro.html#error-codes) # pep8-disable:E122 +from __future__ import absolute_import + import collections import datetime import unittest -import compiler -import context -import runtime -import tinyquery -import tq_ast -import tq_modes -import tq_types -import type_context -import typed_ast +from tinyquery import exceptions +from tinyquery import compiler +from tinyquery import context +from tinyquery import runtime +from tinyquery import tinyquery +from tinyquery import tq_ast +from tinyquery import tq_modes +from tinyquery import tq_types +from tinyquery import type_context +from tinyquery import typed_ast class CompilerTest(unittest.TestCase): @@ -124,7 +127,7 @@ def assert_compiled_select(self, text, expected_ast): self.assertEqual(expected_ast, ast) def assert_compile_error(self, text): - self.assertRaises(compiler.CompileError, compiler.compile_text, + self.assertRaises(exceptions.CompileError, compiler.compile_text, text, self.tables_by_name) def make_type_context(self, table_column_type_triples, @@ -178,7 +181,7 @@ def test_unary_operator(self): ) def test_mistyped_unary_operator(self): - with self.assertRaises(compiler.CompileError) as context: + with self.assertRaises(exceptions.CompileError) as context: compiler.compile_text('SELECT -strings FROM rainbow_table', self.tables_by_name) self.assertTrue('Invalid type for operator' in str(context.exception)) @@ -187,12 +190,12 @@ def test_strange_arithmetic(self): try: compiler.compile_text('SELECT times + ints + floats + bools FROM ' 'rainbow_table', self.tables_by_name) - except compiler.CompileError: + except exceptions.CompileError: self.fail('Compiler exception on arithmetic across all numeric ' 'types.') def test_mistyped_binary_operator(self): - with self.assertRaises(compiler.CompileError) as context: + with self.assertRaises(exceptions.CompileError) as context: compiler.compile_text('SELECT ints CONTAINS strings FROM ' 'rainbow_table', self.tables_by_name) @@ -241,7 +244,7 @@ def test_function_calls(self): ) def test_mistyped_function_call(self): - with self.assertRaises(compiler.CompileError) as context: + with self.assertRaises(exceptions.CompileError) as context: compiler.compile_text('SELECT SUM(strings) FROM rainbow_table', self.tables_by_name) self.assertTrue('Invalid types for function' in str(context.exception)) @@ -1001,7 
+1004,7 @@ def test_within_clause(self): self.make_type_context([])))) def test_within_clause_error(self): - with self.assertRaises(compiler.CompileError) as context: + with self.assertRaises(exceptions.CompileError) as context: compiler.compile_text( 'SELECT r1.s, COUNT(r1.s) WITHIN r2 AS ' 'num_s_in_r1 FROM record_table', diff --git a/tinyquery/context.py b/tinyquery/context.py index 9a3bae5..1b4cec1 100644 --- a/tinyquery/context.py +++ b/tinyquery/context.py @@ -2,13 +2,16 @@ It is the basic container for intermediate data when evaluating a query. """ +from __future__ import absolute_import import collections import itertools import logging -import repeated_util -import tq_modes +import six + +from tinyquery import repeated_util +from tinyquery import tq_modes class Context(object): @@ -25,7 +28,7 @@ class Context(object): """ def __init__(self, num_rows, columns, aggregate_context): assert isinstance(columns, collections.OrderedDict) - for (table_name, col_name), column in columns.iteritems(): + for (table_name, col_name), column in columns.items(): assert len(column.values) == num_rows, ( 'Column %s had %s rows, expected %s.' % ( (table_name, col_name), len(column.values), num_rows)) @@ -69,22 +72,22 @@ def context_from_table(table, type_context): The order of the columns in the type context must match the order of the columns in the table. """ - any_column = table.columns.itervalues().next() + any_column = table.columns[next(iter(table.columns))] new_columns = collections.OrderedDict([ (column_name, column) - for (column_name, column) in zip(type_context.columns.iterkeys(), - table.columns.itervalues()) + for (column_name, column) in zip(type_context.columns, + table.columns.values()) ]) return Context(len(any_column.values), new_columns, None) def context_with_overlayed_type_context(context, type_context): """Given a context, use the given type context for all column names.""" - any_column = context.columns.itervalues().next() + any_column = context.columns[next(iter(context.columns))] new_columns = collections.OrderedDict([ (column_name, column) - for (column_name, column) in zip(type_context.columns.iterkeys(), - context.columns.itervalues()) + for (column_name, column) in zip(type_context.columns, + context.columns.values()) ]) return Context(len(any_column.values), new_columns, None) @@ -94,7 +97,7 @@ def empty_context_from_type_context(type_context): result_columns = collections.OrderedDict( # TODO(Samantha): Fix this. Mode is not always nullable (col_name, Column(type=col_type, mode=tq_modes.NULLABLE, values=[])) - for col_name, col_type in type_context.columns.iteritems() + for col_name, col_type in type_context.columns.items() ) return Context(0, result_columns, None) @@ -115,11 +118,10 @@ def mask_context(context, mask): # behavior as function evaluation on repeated fields. Fix. 
if mask.mode == tq_modes.REPEATED: num_rows = len( - filter( - None, - (len(filter(None, row)) for row in mask.values))) + [r for r in (any(row) for row in mask.values) if r] + ) new_columns = collections.OrderedDict() - for col_name, col in context.columns.iteritems(): + for col_name, col in context.columns.items(): if col.mode == tq_modes.REPEATED: allowable = True new_values = [] @@ -182,18 +184,18 @@ def mask_context(context, mask): values=new_values) else: orig_column_values = [ - col.values for col in context.columns.itervalues()] + col.values for col in context.columns.values()] mask_values = mask.values - num_rows = len(filter(None, mask.values)) + num_rows = len([v for v in mask.values if v]) new_values = [ Column( type=col.type, mode=col.mode, values=list(itertools.compress(values, mask_values))) - for col, values in zip(context.columns.itervalues(), + for col, values in zip(context.columns.values(), orig_column_values)] new_columns = collections.OrderedDict([ - (name, col) for name, col in zip(context.columns.iterkeys(), + (name, col) for name, col in zip(context.columns, new_values)]) return Context( @@ -208,7 +210,7 @@ def empty_context_from_template(context): num_rows=0, columns=collections.OrderedDict( (name, empty_column_from_template(column)) - for name, column in context.columns.iteritems() + for name, column in context.columns.items() ), aggregate_context=None) @@ -224,7 +226,7 @@ def append_row_to_context(src_context, index, dest_context): The schemas of the two contexts must match. """ dest_context.num_rows += 1 - for name, column in dest_context.columns.iteritems(): + for name, column in dest_context.columns.items(): column.values.append(src_context.columns[name].values[index]) @@ -241,9 +243,9 @@ def append_partial_context_to_context(src_context, dest_context): # Ignore fully-qualified names for this operation. short_named_src_column_values = { col_name: column.values - for (_, col_name), column in src_context.columns.iteritems()} + for (_, col_name), column in src_context.columns.items()} - for (_, col_name), dest_column in dest_context.columns.iteritems(): + for (_, col_name), dest_column in dest_context.columns.items(): src_column_values = short_named_src_column_values.get(col_name) if src_column_values is None: dest_column.values.extend([None] * src_context.num_rows) @@ -258,7 +260,7 @@ def append_context_to_context(src_context, dest_context): account. 
""" dest_context.num_rows += src_context.num_rows - for dest_column_key, dest_column in dest_context.columns.iteritems(): + for dest_column_key, dest_column in dest_context.columns.items(): src_column = src_context.columns.get(dest_column_key) if src_column is None: dest_column.values.extend([None] * src_context.num_rows) @@ -272,7 +274,7 @@ def row_context_from_context(src_context, index): columns = collections.OrderedDict( (col_name, Column(type=col.type, mode=col.mode, values=[col.values[index]])) - for col_name, col in src_context.columns.iteritems() + for col_name, col in src_context.columns.items() ) return Context(1, columns, None) @@ -282,15 +284,15 @@ def cross_join_contexts(context1, context2): assert context2.aggregate_context is None result_columns = collections.OrderedDict( [(col_name, Column(type=col.type, mode=col.mode, values=[])) - for col_name, col in context1.columns.iteritems()] + + for col_name, col in context1.columns.items()] + [(col_name, Column(type=col.type, mode=col.mode, values=[])) - for col_name, col in context2.columns.iteritems()]) + for col_name, col in context2.columns.items()]) - for index1 in xrange(context1.num_rows): - for index2 in xrange(context2.num_rows): - for col_name, column in context1.columns.iteritems(): + for index1 in six.moves.xrange(context1.num_rows): + for index2 in six.moves.xrange(context2.num_rows): + for col_name, column in context1.columns.items(): result_columns[col_name].values.append(column.values[index1]) - for col_name, column in context2.columns.iteritems(): + for col_name, column in context2.columns.items(): result_columns[col_name].values.append(column.values[index2]) return Context(context1.num_rows * context2.num_rows, result_columns, None) @@ -304,5 +306,5 @@ def truncate_context(context, limit): return context.num_rows = limit - for column in context.columns.itervalues(): + for column in context.columns.values(): column.values[limit:] = [] diff --git a/tinyquery/evaluator.py b/tinyquery/evaluator.py index 6d803fc..35a110e 100644 --- a/tinyquery/evaluator.py +++ b/tinyquery/evaluator.py @@ -1,12 +1,16 @@ # TODO(colin): fix these lint errors (http://pep8.readthedocs.io/en/release-1.7.x/intro.html#error-codes) # pep8-disable:E115,E128 +from __future__ import absolute_import + import collections -import context -import tq_ast -import tq_modes -import typed_ast -import tq_types +import six + +from tinyquery import context +from tinyquery import tq_ast +from tinyquery import tq_modes +from tinyquery import typed_ast +from tinyquery import tq_types class Evaluator(object): @@ -50,7 +54,8 @@ def evaluate_select(self, select_ast): if select_ast.orderings is not None: result = self.evaluate_orderings(select_context, result, - select_ast.orderings) + select_ast.orderings, + select_ast.select_fields) if select_ast.limit is not None: context.truncate_context(result, select_ast.limit) @@ -105,7 +110,7 @@ def evaluate_groups(self, select_fields, group_set, select_context): # TODO: Seems pretty ugly and wasteful to use a whole context as a # group key. 
- for i in xrange(select_context.num_rows): + for i in six.moves.xrange(select_context.num_rows): key = self.get_group_key( field_groups, alias_group_list, select_context, alias_group_result_context, i) @@ -119,7 +124,7 @@ def evaluate_groups(self, select_fields, group_set, select_context): result_context = self.empty_context_from_select_fields(select_fields) result_col_names = [field.alias for field in select_fields] - for context_key, group_context in group_contexts.iteritems(): + for context_key, group_context in group_contexts.items(): group_eval_context = context.Context( 1, context_key.columns, group_context) group_aggregate_result_context = self.evaluate_select_fields( @@ -131,7 +136,7 @@ def evaluate_groups(self, select_fields, group_set, select_context): return result_context def evaluate_orderings(self, overall_context, select_context, - ordering_col): + ordering_col, select_fields): """ Evaluate a context and order it by a list of given columns. @@ -145,31 +150,44 @@ def evaluate_orderings(self, overall_context, select_context, is_ascending which is a boolean for the order in which the column has to be arranged (True for ascending and False for descending). + select_fields: A list of select fields that can be used to map + aliases back to the overall context Returns: A context with the results. """ + select_aliases = collections.OrderedDict( + (select_field.alias, (select_field.expr.table, select_field.expr.column)) + for select_field in select_fields + ) + assert select_context.aggregate_context is None all_values = [] sort_by_indexes = collections.OrderedDict() - for ((_, column_name), column) in overall_context.columns.iteritems(): + for ((_, column_name), column) in overall_context.columns.items(): all_values.append(column.values) for order_by_column in ordering_col: - for count, ((_, column_name), column) in enumerate( - overall_context.columns.iteritems()): - if order_by_column.column_id.name == column_name: + order_column_name = order_by_column.column_id.name + + for count, (column_identifier, column) in enumerate( + overall_context.columns.items()): + if ( + '%s.%s' % column_identifier == order_column_name + or select_aliases.get(order_column_name) == column_identifier + or order_column_name not in select_aliases and order_column_name == column_identifier[1] + ): sort_by_indexes[count] = order_by_column.is_ascending break reversed_sort_by_indexes = collections.OrderedDict( reversed(list(sort_by_indexes.items()))) - t_all_values = map(list, zip(*all_values)) - for index, is_ascending in reversed_sort_by_indexes.iteritems(): + t_all_values = [list(z) for z in zip(*all_values)] + for index, is_ascending in reversed_sort_by_indexes.items(): t_all_values.sort(key=lambda x: (x[index]), reverse=not is_ascending) - ordered_values = map(list, zip(*t_all_values)) + ordered_values = [list(z) for z in zip(*t_all_values)] # If we started evaluating an ordering over 0 rows, # all_values was originally [[], [], [], ...], i.e. 
the empty list for # each column, but now ordered_values is just the empty list, since @@ -181,9 +199,15 @@ def evaluate_orderings(self, overall_context, select_context, ordered_values = all_values for key in select_context.columns: - for count, (_, overall_key) in enumerate(overall_context.columns): + for count, overall_column_identifier in enumerate(overall_context.columns): overall_context_loop_break = False - if overall_key == key[1]: + if ( + key == overall_column_identifier + or not key[0] and ( + key[1] == '%s.%s' % overall_column_identifier + or select_aliases.get(key[1]) == overall_column_identifier + ) + ): select_context.columns[key] = context.Column( type=select_context.columns[key].type, mode=select_context.columns[key].mode, @@ -279,8 +303,8 @@ def evaluate_within(self, select_fields, group_set, ctx, ctx_with_primary_key = context.empty_context_from_template(ctx) context.append_context_to_context(ctx, ctx_with_primary_key) - (table_name, _), _ = ctx_with_primary_key.columns.items()[0] - row_nums = range(1, ctx_with_primary_key.num_rows + 1) + table_name = next(iter(ctx_with_primary_key.columns)) + row_nums = list(six.moves.xrange(1, ctx_with_primary_key.num_rows + 1)) row_nums_col = context.Column( type=tq_types.INT, mode=tq_modes.NULLABLE, values=row_nums) ctx_with_primary_key.columns[(table_name, @@ -364,7 +388,7 @@ def eval_table_TableUnion(self, table_expr): def eval_table_Join(self, table_expr): base_context = self.evaluate_table_expr(table_expr.base) rhs_tables, join_types = zip(*table_expr.tables) - other_contexts = map(self.evaluate_table_expr, rhs_tables) + other_contexts = [self.evaluate_table_expr(x) for x in rhs_tables] lhs_context = base_context @@ -382,7 +406,7 @@ def eval_table_Join(self, table_expr): lhs_key_refs = [cond.column1 for cond in conditions] rhs_key_refs = [cond.column2 for cond in conditions] rhs_key_contexts = {} - for i in xrange(rhs_context.num_rows): + for i in six.moves.xrange(rhs_context.num_rows): rhs_key = self.get_join_key(rhs_context, rhs_key_refs, i) if rhs_key not in rhs_key_contexts: rhs_key_contexts[rhs_key] = ( @@ -395,7 +419,7 @@ def eval_table_Join(self, table_expr): context.empty_context_from_template(lhs_context), context.empty_context_from_template(rhs_context)) - for i in xrange(lhs_context.num_rows): + for i in six.moves.xrange(lhs_context.num_rows): lhs_key = self.get_join_key(lhs_context, lhs_key_refs, i) lhs_row_context = context.row_context_from_context( lhs_context, i) @@ -468,7 +492,7 @@ def evaluate_AggregateFunctionCall(self, func_call, context): return func_call.func.evaluate(context.num_rows, *arg_results) def evaluate_Literal(self, literal, context_object): - values = [literal.value for _ in xrange(context_object.num_rows)] + values = [literal.value for _ in six.moves.xrange(context_object.num_rows)] return context.Column(type=literal.type, mode=tq_modes.NULLABLE, values=values) diff --git a/tinyquery/evaluator_test.py b/tinyquery/evaluator_test.py index 6713f78..0f177b9 100644 --- a/tinyquery/evaluator_test.py +++ b/tinyquery/evaluator_test.py @@ -1,15 +1,17 @@ # TODO(colin): fix these lint errors (http://pep8.readthedocs.io/en/release-1.7.x/intro.html#error-codes) # pep8-disable:E122,E127,E128 +from __future__ import absolute_import + import collections import contextlib import datetime import mock import unittest -import context -import tinyquery -import tq_modes -import tq_types +from tinyquery import context +from tinyquery import tinyquery +from tinyquery import tq_modes +from tinyquery import tq_types # 
TODO(Samantha): Not all modes are nullable. @@ -659,6 +661,16 @@ def test_order_no_rows(self): self.make_context([ ('str', tq_types.STRING, [])])) + def test_order_aggregate(self): + self.skipTest("Ordering by an aggregate field is not yet supported") + # TODO: this is not yet supported + self.assert_query_result( + 'SELECT val1, MAX(val2) as m FROM test_table GROUP BY val1 ORDER BY m', + self.make_context([ + ('val1', tq_types.INT, [1, 2, 8, 4]), + ('m', tq_types.INT, [2, 4, 6, 8]), + ])) + def test_select_multiple_tables(self): self.assert_query_result( 'SELECT val1, val2, val3 FROM test_table, test_table_2', @@ -802,6 +814,53 @@ def test_repeated_select_from_join(self): collections.OrderedDict([((None, 'i'), expected_column)]), None)) + def test_join_ordering(self): + # Using aliases + self.assert_query_result( + 'SELECT t1.val1 as v1, t1.val2 as v2, t3.foo as foo, t3.bar as bar FROM test_table t1' + ' JOIN test_table_3 t3 ON t1.val1 = t3.foo ORDER BY v2, bar', + self.make_context([ + ('v1', tq_types.INT, [1, 1, 1, 1, 2, 4]), + ('v2', tq_types.INT, [1, 1, 2, 2, 6, 8]), + ('foo', tq_types.INT, [1, 1, 1, 1, 2, 4]), + ('bar', tq_types.INT, [1, 2, 1, 2, 7, 3]), + ])) + # Not using aliases + self.assert_query_result( + 'SELECT t1.val1, t1.val2, t3.foo, t3.bar FROM test_table t1' + ' JOIN test_table_3 t3 ON t1.val1 = t3.foo ORDER BY val2, bar', + self.make_context([ + ('t1.val1', tq_types.INT, [1, 1, 1, 1, 2, 4]), + ('t1.val2', tq_types.INT, [1, 1, 2, 2, 6, 8]), + ('t3.foo', tq_types.INT, [1, 1, 1, 1, 2, 4]), + ('t3.bar', tq_types.INT, [1, 2, 1, 2, 7, 3]), + ])) + + def test_join_ordering_duplicate_column_names(self): + self.assert_query_result( + 'SELECT t1.val1 as v1, t2.val2 as v2 FROM test_table t1' + ' JOIN test_table_2 t2 ON t1.val1 = t2.val3' + ' ORDER BY v2', + self.make_context([ + ('v1', tq_types.INT, [8]), + ('v2', tq_types.INT, [7]), + ])) + + def test_order_without_select(self): + self.assert_query_result( + 'SELECT val1 FROM test_table ORDER BY val2', + self.make_context([ + ('val1', tq_types.INT, [1, 1, 8, 2, 4]) + ])) + self.assert_query_result( + 'SELECT t1.val1, t1.val2, t3.foo FROM test_table t1' + ' JOIN test_table_3 t3 ON t1.val1 = t3.foo ORDER BY val2, bar', + self.make_context([ + ('t1.val1', tq_types.INT, [1, 1, 1, 1, 2, 4]), + ('t1.val2', tq_types.INT, [1, 1, 2, 2, 6, 8]), + ('t3.foo', tq_types.INT, [1, 1, 1, 1, 2, 4]), + ])) + def test_null_test(self): self.assert_query_result( 'SELECT foo IS NULL, foo IS NOT NULL FROM null_table', @@ -831,7 +890,7 @@ def test_hash(self): 'SELECT HASH(floats) FROM rainbow_table', self.make_context([ ('f0_', tq_types.INT, - map(hash, [1.41, 2.72, float('infinity')]))])) + [hash(x) for x in [1.41, 2.72, float('infinity')]])])) def test_null_hash(self): self.assert_query_result( @@ -945,28 +1004,28 @@ def test_null_count_distinct(self): 'SELECT COUNT(DISTINCT val1) FROM some_nulls_table', self.make_context([('f0_', tq_types.INT, [2])])) - def test_group_concat_unquoted(self): + def test_string_agg(self): self.assert_query_result( - 'SELECT GROUP_CONCAT_UNQUOTED(str) FROM string_table', + 'SELECT STRING_AGG(str) FROM string_table', self.make_context([ ('f0_', tq_types.STRING, ['hello,world']) ])) self.assert_query_result( - 'SELECT GROUP_CONCAT_UNQUOTED(children.name) FROM record_table_2', + 'SELECT STRING_AGG(children.name) FROM record_table_2', self.make_context([ ('f0_', tq_types.STRING, ['Jane,John,Earl,Sam,Kit']) ])) - def test_null_group_concat_unquoted(self): + def test_null_string_agg(self): self.assert_query_result( - 'SELECT 
GROUP_CONCAT_UNQUOTED(str) FROM string_table_with_null', + 'SELECT STRING_AGG(str) FROM string_table_with_null', self.make_context([ ('f0_', tq_types.STRING, ['hello,world']) ])) - def test_group_concat_unquoted_separator(self): + def test_string_agg_separator(self): self.assert_query_result( - 'SELECT GROUP_CONCAT_UNQUOTED(str, \' || \') FROM string_table', + 'SELECT STRING_AGG(str, \' || \') FROM string_table', self.make_context([ ('f0_', tq_types.STRING, ['hello || world']) ])) @@ -1248,6 +1307,11 @@ def test_other_timestamp_functions(self): 'SELECT STRFTIME_UTC_USEC(1274259481071200, "%Y-%m-%d")', self.make_context([ ('f0_', tq_types.STRING, ['2010-05-19'])])) + + self.assert_query_result( + 'SELECT FORMAT_TIMESTAMP("%Y-%m-%d", TIMESTAMP("2015-01-02 00:00:00"))', + self.make_context([ + ('f0_', tq_types.STRING, ['2015-01-02'])])) self.assert_query_result( 'SELECT USEC_TO_TIMESTAMP(1349053323000000)', @@ -1280,6 +1344,11 @@ def test_other_timestamp_functions(self): self.make_context([ ('f0_', tq_types.INT, [1262304000000000])])) + def test_replace(self): + self.assert_query_result( + "SELECT REPLACE(str, 'o', 'e') FROM string_table_with_null", + self.make_context([('f0_', tq_types.STRING, ["helle", "werld", None])])) + def test_first(self): # Test over the equivalent of a GROUP BY self.assert_query_result( @@ -1296,6 +1365,22 @@ def test_first(self): ]) ) + def test_last(self): + # Test over the equivalent of a GROUP BY + self.assert_query_result( + 'SELECT LAST(val1) FROM test_table', + self.make_context([ + ('f0_', tq_types.INT, [2]) + ]) + ) + # Test over something repeated + self.assert_query_result( + 'SELECT LAST(QUANTILES(val1, 3)) FROM test_table', + self.make_context([ + ('f0_', tq_types.INT, [8]) + ]) + ) + # TODO(colin): test behavior on empty list in both cases def test_left(self): diff --git a/tinyquery/exceptions.py b/tinyquery/exceptions.py new file mode 100644 index 0000000..412cc61 --- /dev/null +++ b/tinyquery/exceptions.py @@ -0,0 +1,5 @@ +from __future__ import absolute_import + + +class CompileError(Exception): + pass diff --git a/tinyquery/lexer.py b/tinyquery/lexer.py index 61cccd3..39eda47 100644 --- a/tinyquery/lexer.py +++ b/tinyquery/lexer.py @@ -1,4 +1,5 @@ """The lexer turns a query string into a stream of tokens.""" +from __future__ import absolute_import from ply import lex @@ -63,7 +64,7 @@ 'FLOAT', 'ID', 'STRING' -] + reserved_words.values() +] + list(reserved_words.values()) # wrapping with list() to support python 3 t_PLUS = r'\+' diff --git a/tinyquery/lexer_test.py b/tinyquery/lexer_test.py index 3d1354c..3b959a4 100644 --- a/tinyquery/lexer_test.py +++ b/tinyquery/lexer_test.py @@ -1,6 +1,8 @@ +from __future__ import absolute_import + import unittest -import lexer +from tinyquery import lexer plus = ('PLUS', '+') diff --git a/tinyquery/parser.py b/tinyquery/parser.py index 7949617..867d9d0 100644 --- a/tinyquery/parser.py +++ b/tinyquery/parser.py @@ -1,10 +1,12 @@ """The parser turns a stream of tokens into an AST.""" +from __future__ import absolute_import + import os from ply import yacc -import tq_ast -import lexer +from tinyquery import tq_ast +from tinyquery import lexer tokens = lexer.tokens @@ -484,6 +486,6 @@ def parse_text(text): if should_rebuild_parser: parser = yacc.yacc() else: - import parsetab + from tinyquery import parsetab parser = yacc.yacc(debug=0, write_tables=0, tabmodule=parsetab) return parser.parse(text, lexer=lexer.get_lexer()) diff --git a/tinyquery/parser_test.py b/tinyquery/parser_test.py index 370d7c0..139107f 100644 
--- a/tinyquery/parser_test.py +++ b/tinyquery/parser_test.py @@ -1,7 +1,9 @@ +from __future__ import absolute_import + import unittest -import tq_ast -import parser +from tinyquery import tq_ast +from tinyquery import parser def literal(value): diff --git a/tinyquery/repeated_util.py b/tinyquery/repeated_util.py index d88ca97..f715ce8 100644 --- a/tinyquery/repeated_util.py +++ b/tinyquery/repeated_util.py @@ -5,8 +5,9 @@ These functions allow us to flatten into non-repeated columns to apply various operations and then unflatten back into repeated columns afterwards. """ +from __future__ import absolute_import -import tq_modes +from tinyquery import tq_modes def rebuild_column_values(repetitions, values, result): @@ -80,7 +81,8 @@ def flatten_column_values(repeated_column_indices, column_values): values. The list for each column will not contain nested lists. """ - rows = zip(*column_values) + # wrapping in list for python 3 support + rows = list(zip(*column_values)) repetition_counts = [ max(max(len(row[idx]) for idx in repeated_column_indices), 1) for row in rows diff --git a/tinyquery/runtime.py b/tinyquery/runtime.py index b8c92a5..5051ebb 100644 --- a/tinyquery/runtime.py +++ b/tinyquery/runtime.py @@ -1,4 +1,6 @@ """Implementation of the standard built-in functions.""" +from __future__ import absolute_import + import abc import datetime import functools @@ -9,12 +11,13 @@ import time import arrow +import six -import compiler -import context -import repeated_util -import tq_types -import tq_modes +from tinyquery import exceptions +from tinyquery import context +from tinyquery import repeated_util +from tinyquery import tq_types +from tinyquery import tq_modes def pass_through_none(fn): @@ -158,9 +161,8 @@ def check_types(self, type1, type2): return tq_types.INT def _evaluate(self, num_rows, column1, column2): - values = map(lambda (x, y): - None if None in (x, y) else self.func(x, y), - zip(column1.values, column2.values)) + values = [None if None in (x, y) else self.func(x, y) + for x, y in zip(column1.values, column2.values)] # TODO(Samantha): Code smell incoming t = self.check_types(column1.type, column2.type) return context.Column(type=t, mode=tq_modes.NULLABLE, values=values) @@ -210,8 +212,8 @@ def _evaluate(self, num_rows, column1, column2): if other_column.type == tq_types.STRING: # Convert that string to datetime if we can. try: - converted = map(lambda x: arrow.get(x).to('UTC').naive, - other_column.values) + converted = [arrow.get(x).to('UTC').naive + for x in other_column.values] except: raise TypeError('Invalid comparison on timestamp, ' 'expected numeric type or ISO8601 ' @@ -219,10 +221,8 @@ def _evaluate(self, num_rows, column1, column2): elif other_column.type in tq_types.NUMERIC_TYPE_SET: # Cast that numeric to a float accounting for microseconds and # then to a datetime. 
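Note: the branch the comment above describes coerces BigQuery-style microseconds-since-epoch to naive UTC datetimes via arrow. A standalone sketch of the same conversion (the sample value is USEC_TO_TIMESTAMP's test input from later in this patch):

    import arrow
    usec = 1349053323000000  # microseconds since the Unix epoch
    # arrow.get() accepts a unix timestamp in seconds (int or float),
    # so scale by 1e6 first; .naive drops the UTC tzinfo.
    dt = arrow.get(usec / 1e6).to('UTC').naive
    print(dt)  # 2012-10-01 01:02:03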
-            converted = map(
-                pass_through_none(
-                    lambda x: arrow.get(float(x) / 1E6).to('UTC').naive),
-                other_column.values)
+            convert = pass_through_none(lambda x: arrow.get(float(x) / 1E6).to('UTC').naive)
+            converted = [convert(x) for x in other_column.values]
 
         else:
             # No other way to compare a timestamp with anything other than
@@ -237,9 +237,8 @@ def _evaluate(self, num_rows, column1, column2):
                                           mode=other_column.mode,
                                           values=converted)
 
-        values = map(lambda (x, y):
-                     None if None in (x, y) else self.func(x, y),
-                     zip(column1.values, column2.values))
+        values = [None if None in (x, y) else self.func(x, y)
+                  for x, y in zip(column1.values, column2.values)]
         return context.Column(type=tq_types.BOOL, mode=tq_modes.NULLABLE,
                               values=values)
 
@@ -254,9 +253,8 @@ def check_types(self, type1, type2):
         return tq_types.BOOL
 
     def _evaluate(self, num_rows, column1, column2):
-        values = map(lambda (x, y):
-                     None if None in (x, y) else self.func(x, y),
-                     zip(column1.values, column2.values))
+        values = [None if None in (x, y) else self.func(x, y)
+                  for x, y in zip(column1.values, column2.values)]
         return context.Column(type=tq_types.BOOL, mode=tq_modes.NULLABLE,
                               values=values)
 
@@ -271,7 +269,7 @@ def check_types(self, arg):
         return tq_types.INT
 
     def _evaluate(self, num_rows, column):
-        values = map(self.func, column.values)
+        values = [self.func(x) for x in column.values]
         return context.Column(type=tq_types.INT, mode=tq_modes.NULLABLE,
                               values=values)
 
@@ -284,7 +282,7 @@ def check_types(self, arg):
         return tq_types.BOOL
 
     def _evaluate(self, num_rows, column):
-        values = map(self.func, column.values)
+        values = [self.func(x) for x in column.values]
        return context.Column(type=tq_types.BOOL, mode=tq_modes.NULLABLE,
                               values=values)
 
@@ -302,7 +300,7 @@ def check_types(self, arg):
         return tq_types.FLOAT
 
     def _evaluate(self, num_rows, column):
-        values = map(self.func, column.values)
+        values = [self.func(x) for x in column.values]
         return context.Column(type=tq_types.FLOAT, mode=tq_modes.NULLABLE,
                               values=values)
 
@@ -341,8 +339,8 @@ def check_types(self, arg1, arg2):
 
     def _evaluate(self, num_rows, column1, column2):
         t = self.check_types(column1.type, column2.type)
-        values = map(lambda (x, y): x if x is not None else y,
-                     zip(column1.values, column2.values))
+        values = [x if x is not None else y
+                  for x, y in zip(column1.values, column2.values)]
         return context.Column(type=t, mode=tq_modes.NULLABLE, values=values)
 
 
@@ -362,11 +360,11 @@ def _evaluate(self, num_rows, *cols):
         rows = zip(*[col.values for col in cols])
 
         def first_nonnull(row):
-            result = filter(lambda x: x is not None, row)
-            if result:
-                return result[0]
+            for x in row:
+                if x is not None:
+                    return x
             return None
-        values = map(first_nonnull, rows)
+        values = [first_nonnull(r) for r in rows]
         return context.Column(type=result_type, mode=tq_modes.NULLABLE,
                               values=values)
 
@@ -377,7 +375,8 @@ def check_types(self, arg):
 
     def _evaluate(self, num_rows, column):
         # TODO: Use CityHash.
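A caveat worth recording alongside this TODO: BigQuery's HASH is CityHash-based, while this shim falls back to Python's built-in hash(). On Python 3, string hashing is additionally randomized per process unless PYTHONHASHSEED is pinned, so HASH over STRING columns is not stable across runs; numeric hashes remain deterministic, which is why the hash test in evaluator_test uses floats. Illustration:

    # Python 3: differs between interpreter runs unless PYTHONHASHSEED is set
    print(hash('hello'))
    # Deterministic across runs, like the float values in test_hash
    print(hash(1.41))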
-        values = map(pass_through_none(hash), column.values)
+        hash_fn = pass_through_none(hash)
+        values = [hash_fn(x) for x in column.values]
         return context.Column(type=tq_types.INT, mode=tq_modes.NULLABLE,
                               values=values)
 
@@ -389,8 +388,8 @@ def check_types(self, arg):
         return tq_types.FLOAT
 
     def _evaluate(self, num_rows, column):
-        values = map(pass_through_none(math.floor),
-                     column.values)
+        floor = pass_through_none(math.floor)
+        values = [floor(x) for x in column.values]
         return context.Column(type=tq_types.FLOAT, mode=tq_modes.NULLABLE,
                               values=values)
 
@@ -412,7 +411,7 @@ def string_converter(arg):
             converter = string_converter
         elif column.type == tq_types.TIMESTAMP:
             return timestamp_to_usec.evaluate(num_rows, column)
-        values = map(converter, column.values)
+        values = [converter(x) for x in column.values]
         return context.Column(type=tq_types.INT, mode=tq_modes.NULLABLE,
                               values=values)
 
@@ -422,7 +421,7 @@ def check_types(self):
         return tq_types.FLOAT
 
     def _evaluate(self, num_rows):
-        values = [random.random() for _ in xrange(num_rows)]
+        values = [random.random() for _ in six.moves.xrange(num_rows)]
         # TODO(Samantha): Should this be required?
         return context.Column(type=tq_types.FLOAT, mode=tq_modes.NULLABLE,
                               values=values)
@@ -558,6 +557,23 @@ def _evaluate(self, num_rows, column):
                               values=values)
 
 
+class LastFunction(AggregateFunction):
+    def check_types(self, rep_list_type):
+        return rep_list_type
+
+    def _evaluate(self, num_rows, column):
+        # An empty column has no last value; return NULL rather than
+        # letting column.values[-1] raise an IndexError.
+        if len(column.values) == 0:
+            values = [None]
+        elif column.mode == tq_modes.REPEATED:
+            values = [repeated_row[-1] if len(repeated_row) > 0 else None
+                      for repeated_row in column.values]
+        else:
+            values = [column.values[-1]]
+        return context.Column(type=column.type, mode=tq_modes.NULLABLE,
+                              values=values)
+
+
 class NoArgFunction(ScalarFunction):
     def __init__(self, func, return_type=tq_types.INT):
         self.func = func
@@ -568,7 +585,7 @@ def check_types(self):
         return self.type
 
     def _evaluate(self, num_rows):
         return context.Column(type=self.type, mode=tq_modes.NULLABLE,
-                              values=[self.func() for _ in xrange(num_rows)])
+                              values=[self.func() for _ in six.moves.xrange(num_rows)])
 
 
 class InFunction(ScalarFunction):
@@ -578,8 +595,7 @@ def check_types(self, arg1, *arg_types):
     def _evaluate(self, num_rows, arg1, *other_args):
         values = [val1 in val_list
                   for val1, val_list in zip(arg1.values,
-                                            zip(*(map(lambda x: x.values,
-                                                      other_args))))]
+                                            zip(*[x.values for x in other_args]))]
         return context.Column(type=tq_types.BOOL, mode=tq_modes.NULLABLE,
                               values=values)
 
@@ -591,8 +607,8 @@ def check_types(self, *arg_types):
         return tq_types.STRING
 
     def _evaluate(self, num_rows, *columns):
-        values = map(lambda strs: None if None in strs else ''.join(strs),
-                     zip(*map(lambda x: x.values, columns)))
+        values = [None if None in strs else ''.join(strs)
+                  for strs in zip(*[x.values for x in columns])]
         return context.Column(tq_types.STRING, tq_modes.NULLABLE,
                               values=values)
 
@@ -602,7 +618,8 @@ def check_types(self, arg_type):
         return tq_types.STRING
 
     def _evaluate(self, num_rows, column):
-        values = map(pass_through_none(str), column.values)
+        pass_through_none_str = pass_through_none(str)
+        values = [pass_through_none_str(x) for x in column.values]
         return context.Column(type=tq_types.STRING, mode=tq_modes.NULLABLE,
                               values=values)
 
@@ -617,8 +634,7 @@ def check_types(self, arg):
     def _evaluate(self, num_rows, column):
         return context.Column(type=self.check_types(column.type),
                               mode=tq_modes.NULLABLE,
-                              values=[self.func(filter(lambda x: x is not None,
-                                                       column.values))])
+                              values=[self.func([x
for x in column.values if x is not None])]) class SumFunction(AggregateFunction): @@ -677,7 +693,7 @@ def _evaluate(self, num_rows, column): values=[len(set(values) - set([None]))]) -class GroupConcatUnquotedFunction(AggregateFunction): +class StringAggFunction(AggregateFunction): def check_types(self, *arg_types): return tq_types.STRING @@ -727,9 +743,9 @@ def _evaluate(self, num_rows, column, num_quantiles_list): # quantile, so we need one more set of brackets than you might expect. values = [[ sorted_args[ - min(len(sorted_args) * i / (num_quantiles - 1), + min(len(sorted_args) * i // (num_quantiles - 1), len(sorted_args) - 1) - ] for i in xrange(num_quantiles) + ] for i in six.moves.xrange(num_quantiles) ]] return context.Column(type=tq_types.INT, mode=tq_modes.REPEATED, values=values) @@ -743,9 +759,8 @@ def check_types(self, type1, type2): def _evaluate(self, num_rows, column1, column2): if len(column1.values) == len(column2.values): - values = map(lambda (v1, v2): None if None in (v1, v2) else - v2 in v1, - zip(column1.values, column2.values)) + values = [None if None in (v1, v2) else v2 in v1 + for v1, v2 in zip(column1.values, column2.values)] return context.Column(type=tq_types.BOOL, mode=tq_modes.NULLABLE, values=values) @@ -768,13 +783,12 @@ def _evaluate(self, num_rows, column): # epoch here, whereas arrow wants a unix timestamp, with possible # decimal part representing microseconds. converter = lambda ts: float(ts) / 1E6 + convert_fn = pass_through_none( + # arrow.get parses ISO8601 strings and int/float unix + # timestamps without a format parameter + lambda ts: arrow.get(converter(ts)).to('UTC').naive) try: - values = map( - pass_through_none( - # arrow.get parses ISO8601 strings and int/float unix - # timestamps without a format parameter - lambda ts: arrow.get(converter(ts)).to('UTC').naive), - column.values) + values = [convert_fn(x) for x in column.values] except: raise TypeError( 'TIMESTAMP requires an ISO8601 string or unix timestamp in ' @@ -794,7 +808,7 @@ def check_types(self, type1): return self.type def _evaluate(self, num_rows, column1): - values = map(self.extractor, column1.values) + values = [self.extractor(x) for x in column1.values] return context.Column(type=self.type, mode=tq_modes.NULLABLE, values=values) @@ -826,20 +840,19 @@ def adder(ts): year = ts.year + (ts.month - 1 + num_intervals) // 12 month = 1 + (ts.month - 1 + num_intervals) % 12 return ts.replace(year=year, month=month) - values = map(adder, timestamps.values) + values = [adder(x) for x in timestamps.values] elif interval_type == 'YEAR': - values = map( - pass_through_none( - lambda ts: ts.replace(year=(ts.year + num_intervals))), - timestamps.values) + convert_fn = pass_through_none( + lambda ts: ts.replace(year=(ts.year + num_intervals))) + values = [convert_fn(x) for x in timestamps.values] else: # All of the other valid options for bigquery are also valid # keyword arguments to datetime.timedelta, when lowercased and # pluralized. 
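To make the comment above concrete, a SECOND/MINUTE/HOUR/DAY interval maps directly onto a datetime.timedelta keyword argument (values here are illustrative):

    import datetime
    interval_type, num_intervals = 'DAY', 3
    delta = datetime.timedelta(**{interval_type.lower() + 's': num_intervals})
    # datetime.timedelta(days=3); MONTH and YEAR have no timedelta
    # equivalent, hence the special cases handled above.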
python_interval_name = interval_type.lower() + 's' delta = datetime.timedelta(**{python_interval_name: num_intervals}) - values = map(pass_through_none(lambda ts: ts + delta), - timestamps.values) + convert_fn = pass_through_none(lambda ts: ts + delta) + values = [convert_fn(x) for x in timestamps.values] return context.Column(type=tq_types.TIMESTAMP, mode=tq_modes.NULLABLE, values=values) @@ -853,9 +866,8 @@ def check_types(self, type1, type2): return tq_types.INT def _evaluate(self, num_rows, lhs_ts, rhs_ts): - values = map(lambda (lhs, rhs): None if None in (lhs, rhs) else - int(round((lhs - rhs).total_seconds() / 24 / 3600)), - zip(lhs_ts.values, rhs_ts.values)) + values = [None if None in (lhs, rhs) else int(round((lhs - rhs).total_seconds() / 24 / 3600)) + for lhs, rhs in zip(lhs_ts.values, rhs_ts.values)] return context.Column(type=tq_types.INT, mode=tq_modes.NULLABLE, values=values) @@ -911,7 +923,7 @@ def _year_truncate(self, ts): def _evaluate(self, num_rows, timestamps): truncate_fn = pass_through_none( getattr(self, '_%s_truncate' % self.interval)) - values = map(truncate_fn, timestamps.values) + values = [truncate_fn(x) for x in timestamps.values] return context.Column(type=tq_types.TIMESTAMP, mode=tq_modes.NULLABLE, values=values) @@ -938,11 +950,10 @@ def _evaluate(self, num_rows, unix_timestamps, weekdays): timestamps = TimestampFunction().evaluate(num_rows, unix_timestamps) truncated = TimestampShiftFunction('day').evaluate( num_rows, timestamps) - values = map( - pass_through_none( + convert = pass_through_none( lambda ts: ts + datetime.timedelta( - days=(weekday - self._weekday_from_ts(ts)))), - truncated.values) + days=(weekday - self._weekday_from_ts(ts)))) + values = [convert(x) for x in truncated.values] ts_result = context.Column( type=tq_types.TIMESTAMP, mode=tq_modes.NULLABLE, values=values) return timestamp_to_usec.evaluate(num_rows, ts_result) @@ -964,9 +975,25 @@ def check_types(self, type1, type2): def _evaluate(self, num_rows, unix_timestamps, formats): format_str = _ensure_literal(formats.values) timestamps = TimestampFunction().evaluate(num_rows, unix_timestamps) - values = map( - pass_through_none(lambda ts: ts.strftime(format_str)), - timestamps.values) + convert = pass_through_none(lambda ts: ts.strftime(format_str)) + values = [convert(x) for x in timestamps.values] + return context.Column(type=tq_types.STRING, mode=tq_modes.NULLABLE, + values=values) + + +class FormatTimestampFunction(ScalarFunction): + def check_types(self, type1, type2): + if not (type2 in tq_types.DATETIME_TYPE_SET and + type1 == tq_types.STRING): + raise TypeError('Expected a string and a date, got %s.' % ( + [type1, type2])) + return tq_types.STRING + + def _evaluate(self, num_rows, formats, unix_timestamps): + format_str = _ensure_literal(formats.values) + timestamps = TimestampFunction().evaluate(num_rows, unix_timestamps) + convert = pass_through_none(lambda ts: ts.strftime(format_str)) + values = [convert(x) for x in timestamps.values] return context.Column(type=tq_types.STRING, mode=tq_modes.NULLABLE, values=values) @@ -992,7 +1019,7 @@ def apply(*args): # is usually to return NULL if any arguments are NULL. 
if any(arg is None for arg in args): return None - return reduce(self.reducer, args) + return functools.reduce(self.reducer, args) values = [apply(*vals) for vals in zip(*[col.values for col in columns])] @@ -1001,6 +1028,21 @@ def apply(*args): values=values) +class ReplaceFunction(ScalarFunction): + def check_types(self, *arg_types): + if any(arg_type != tq_types.STRING for arg_type in arg_types): + raise TypeError('REPLACE only takes string arguments.') + return tq_types.STRING + + def _evaluate(self, num_rows, values, old, new): + values = [value.replace(old, new) if value is not None else None + for value, old, new in zip(values.values, + old.values, + new.values)] + return context.Column(tq_types.STRING, tq_modes.NULLABLE, + values=values) + + class JSONExtractFunction(ScalarFunction): """Extract from a JSON string based on a JSONPath expression. @@ -1038,9 +1080,11 @@ def _parse_property_name(self, json_path): raise ValueError( 'Invalid json path expression. Cannot end in ".".') prop_name_plus = json_path[1:] - next_separator_positions = filter( - lambda pos: pos != -1, - [prop_name_plus.find('.'), prop_name_plus.find('[')]) + next_separator_positions = [ + pos + for pos in [prop_name_plus.find('.'), prop_name_plus.find('[')] + if pos != -1 + ] if next_separator_positions: end_idx = min(next_separator_positions) @@ -1101,9 +1145,9 @@ def _extract_by_json_path(self, parsed_json_expr, json_path): def _evaluate(self, num_rows, json_expressions, json_paths): json_path = _ensure_literal(json_paths.values) - parsed_json = map( - pass_through_none(json.loads), - json_expressions.values) + json_load = pass_through_none(json.loads) + parsed_json = [json_load(x) + for x in json_expressions.values] if not json_path.startswith('$'): raise ValueError( 'Invalid json path expression. 
Must start with $.') @@ -1223,6 +1267,7 @@ def _evaluate(self, num_rows, json_expressions, json_paths): lambda dt: int(dt.strftime('%j'), 10), return_type=tq_types.INT), TimestampFunction()), + 'format_timestamp': FormatTimestampFunction(), 'format_utc_usec': Compose( TimestampExtractFunction( lambda dt: dt.strftime('%Y-%m-%d %H:%M:%S.%f'), @@ -1308,6 +1353,7 @@ def _evaluate(self, num_rows, json_expressions, json_paths): lambda dt: dt.year, return_type=tq_types.INT), TimestampFunction()), + 'replace': ReplaceFunction(), 'json_extract': JSONExtractFunction(), 'json_extract_scalar': JSONExtractFunction(scalar=True), } @@ -1320,10 +1366,12 @@ def _evaluate(self, num_rows, json_expressions, json_paths): 'count': CountFunction(), 'avg': AvgFunction(), 'count_distinct': CountDistinctFunction(), - 'group_concat_unquoted': GroupConcatUnquotedFunction(), + 'string_agg': StringAggFunction(), + 'group_concat_unquoted': StringAggFunction(), 'stddev_samp': StddevSampFunction(), 'quantiles': QuantilesFunction(), - 'first': FirstFunction() + 'first': FirstFunction(), + 'last': LastFunction(), } @@ -1345,7 +1393,7 @@ def get_func(name): elif name in _AGGREGATE_FUNCTIONS: return _AGGREGATE_FUNCTIONS[name] else: - raise compiler.CompileError('Unknown function: {}'.format(name)) + raise exceptions.CompileError('Unknown function: {}'.format(name)) def is_aggregate_func(name): diff --git a/tinyquery/tinyquery.py b/tinyquery/tinyquery.py index f5ebf46..8d411b7 100644 --- a/tinyquery/tinyquery.py +++ b/tinyquery/tinyquery.py @@ -1,12 +1,16 @@ """Implementation of the TinyQuery service.""" +from __future__ import absolute_import + import collections import json -import compiler -import context -import evaluator -import tq_modes -import tq_types +import six + +from tinyquery import compiler +from tinyquery import context +from tinyquery import evaluator +from tinyquery import tq_modes +from tinyquery import tq_types class TinyQueryError(Exception): @@ -35,7 +39,7 @@ def load_table_from_csv(self, table_name, raw_schema, filename): 'Expected {} tokens on line {}, but got {}'.format( len(result_table.columns), line, len(tokens))) for token, column in zip(tokens, - result_table.columns.itervalues()): + result_table.columns.values()): # Run a casting function over the value we are given. # CSV doesn't have a null value, so the string 'null' is # used as the null value. @@ -89,9 +93,9 @@ def run_cast_function(key, mode, value): if value is None: return None elif mode == tq_modes.REPEATED: - return map(cast_function, value) + return [cast_function(x) for x in value] else: - if isinstance(value, str): + if isinstance(value, six.binary_type): return cast_function(value.decode('utf-8')) else: return cast_function(value) @@ -143,7 +147,7 @@ def flatten_row(output, row, schema, prefix='', ever_repeated=False): return output def process_row(row): - for (key, value) in row.iteritems(): + for (key, value) in row.items(): mode = result_table.columns[key].mode token = run_cast_function(key, mode, value) if not tq_modes.check_mode(token, mode): @@ -157,6 +161,7 @@ def process_row(row): row = json.loads(line) flattened_row = flatten_row({}, row, fake_raw_schema) process_row(flattened_row) + result_table.num_rows += 1 self.load_table_or_view(result_table) @@ -199,7 +204,7 @@ def get_all_tables(self): def get_table_names_for_dataset(self, dataset): # TODO(alan): Improve this to use a more first-class dataset structure. 
return [full_table[len(dataset + '.'):] - for full_table in self.tables_by_name.iterkeys() + for full_table in self.tables_by_name if full_table.startswith(dataset + '.')] def get_all_table_info_in_dataset(self, project_id, dataset): @@ -226,7 +231,7 @@ def get_table_info(self, project, dataset, table_name): table = self.tables_by_name[dataset + '.' + table_name] schema_fields = [] # TODO(colin): record fields should appear grouped. - for col_name, column in table.columns.iteritems(): + for col_name, column in table.columns.items(): schema_fields.append({ 'name': col_name, 'type': column.type, @@ -293,7 +298,7 @@ def run_query_job(self, project_id, query, dest_dataset, dest_table_name, def table_from_context(table_name, ctx): return Table(table_name, ctx.num_rows, collections.OrderedDict( (col_name, column) - for (_, col_name), column in ctx.columns.iteritems() + for (_, col_name), column in ctx.columns.items() )) def run_copy_job(self, project_id, src_dataset, src_table_name, @@ -337,7 +342,7 @@ def load_empty_table_from_template(self, table_name, template_table): # TODO(Samantha): This shouldn't just be nullable. (col_name, context.Column(type=col.type, mode=tq_modes.NULLABLE, values=[])) - for col_name, col in template_table.columns.iteritems() + for col_name, col in template_table.columns.items() ) table = Table(table_name, 0, columns) self.load_table_or_view(table) @@ -345,13 +350,13 @@ def load_empty_table_from_template(self, table_name, template_table): @staticmethod def clear_table(table): table.num_rows = 0 - for column in table.columns.itervalues(): + for column in table.columns.values(): column.values[:] = [] @staticmethod def append_to_table(src_table, dest_table): dest_table.num_rows += src_table.num_rows - for col_name, column in dest_table.columns.iteritems(): + for col_name, column in dest_table.columns.items(): if col_name in src_table.columns: column.values.extend(src_table.columns[col_name].values) else: @@ -378,8 +383,8 @@ class Table(object): """ def __init__(self, name, num_rows, columns): assert isinstance(columns, collections.OrderedDict) - for col_name, column in columns.iteritems(): - assert isinstance(col_name, basestring) + for col_name, column in columns.items(): + assert isinstance(col_name, six.string_types) assert len(column.values) == num_rows, ( 'Column %s had %s rows, expected %s.' 
                    col_name, len(column.values), num_rows))
diff --git a/tinyquery/tinyquery_test.py b/tinyquery/tinyquery_test.py
index 40e5f74..2ea9e77 100644
--- a/tinyquery/tinyquery_test.py
+++ b/tinyquery/tinyquery_test.py
@@ -1,7 +1,9 @@
+from __future__ import absolute_import
+
 import json
 import unittest
 
-import tinyquery
+from tinyquery import tinyquery
 
 
 class TinyQueryTest(unittest.TestCase):
@@ -66,6 +68,7 @@ def test_make_empty_table(self):
         table = tinyquery.TinyQuery.make_empty_table(
             'test_table', self.record_schema)
         self.assertIn('r.r2.d2', table.columns)
+        self.assertEqual(table.num_rows, 0)
 
     def test_load_table_from_newline_delimited_json(self):
         record_json = json.dumps({
@@ -86,6 +89,7 @@ def test_load_table_from_newline_delimited_json(self):
         table = tq.tables_by_name['test_table']
         self.assertIn('r.r2.d2', table.columns)
         self.assertIn(3, table.columns['r.r2.d2'].values)
+        self.assertEqual(table.num_rows, 1)
 
     def test_load_json_with_null_records(self):
         record_json = json.dumps({
@@ -135,3 +139,21 @@ def test_load_json_with_repeated_records(self):
                          ['a', 'b', 'c', 'd', 'e'])
         self.assertEqual(table.columns['r.inner_repeated'].values[0],
                          ['l', 'm', 'n'])
+
+    def test_load_table_multiple_rows_count(self):
+        record_json = json.dumps({
+            'i': 1,
+            'r': {
+                's': 'hello!',
+                'r2': {
+                    'd2': 3,
+                },
+            },
+        })
+        tq = tinyquery.TinyQuery()
+        tq.load_table_from_newline_delimited_json(
+            'test_table',
+            json.dumps(self.record_schema['fields']),
+            [record_json, record_json, record_json, record_json])
+        table = tq.tables_by_name['test_table']
+        self.assertEqual(table.num_rows, 4)
diff --git a/tinyquery/tq_ast.py b/tinyquery/tq_ast.py
index e8c362f..0ebd502 100644
--- a/tinyquery/tq_ast.py
+++ b/tinyquery/tq_ast.py
@@ -3,6 +3,7 @@
 This AST format is desinged to be easy to parse into. See typed_ast for the
 AST format that is used during the evaluation step.
 """
+from __future__ import absolute_import
 
 import collections
 
diff --git a/tinyquery/tq_modes.py b/tinyquery/tq_modes.py
index 91f293c..d025f8e 100644
--- a/tinyquery/tq_modes.py
+++ b/tinyquery/tq_modes.py
@@ -1,5 +1,6 @@
 """ Defines the valid modes. Currently we just use strings to identify them.
 """
+from __future__ import absolute_import
 
 NULLABLE = "NULLABLE"
 REQUIRED = "REQUIRED"
diff --git a/tinyquery/tq_types.py b/tinyquery/tq_types.py
index 572778b..41b3dfb 100644
--- a/tinyquery/tq_types.py
+++ b/tinyquery/tq_types.py
@@ -1,6 +1,9 @@
 """Defines the valid types. Currently we just uses strings to identify them.
 """
+from __future__ import absolute_import
+
 import arrow
+import six
 
 
 # TODO(Samantha): Structs.
@@ -20,11 +23,11 @@
     INT: int,
     FLOAT: float,
     BOOL: bool,
-    STRING: unicode,
+    STRING: six.text_type,
     TIMESTAMP: lambda val: arrow.get(val).to('UTC').naive,
     NONETYPE: lambda _: None,
     'null': lambda _: None
 }
 
 DATETIME_TYPE_SET = set([INT, STRING, TIMESTAMP])
-TYPE_TYPE = basestring
+TYPE_TYPE = six.string_types
diff --git a/tinyquery/type_context.py b/tinyquery/type_context.py
index 9c2a7e8..37a3b79 100644
--- a/tinyquery/type_context.py
+++ b/tinyquery/type_context.py
@@ -1,9 +1,13 @@
+from __future__ import absolute_import
+
 import collections
 import re
 
-import compiler
-import tq_types
-import typed_ast
+import six
+
+from tinyquery import exceptions
+from tinyquery import tq_types
+from tinyquery import typed_ast
 
 
 # TODO(Samantha): Should checking modes go here?
@@ -41,7 +45,7 @@ def from_table_and_columns(cls, table_name, columns_without_table,
             collections.OrderedDict(
                 ((table_name, column_name), col_type)
                 for column_name, col_type
-                in columns_without_table.iteritems()),
+                in columns_without_table.items()),
             implicit_column_context, aggregate_context)
 
     @staticmethod
@@ -54,10 +58,10 @@ def assert_type(value, expected_type):
     def from_full_columns(cls, full_columns, implicit_column_context=None,
                           aggregate_context=None):
         """Given just the columns field, fill in alias information."""
-        for (table_name, col_name), col_type in full_columns.iteritems():
+        for (table_name, col_name), col_type in full_columns.items():
             if table_name is not None:
-                cls.assert_type(table_name, basestring)
-            cls.assert_type(col_name, basestring)
+                cls.assert_type(table_name, six.string_types)
+            cls.assert_type(col_name, six.string_types)
             cls.assert_type(col_type, tq_types.TYPE_TYPE)
 
         aliases = {}
@@ -88,12 +92,12 @@ def union_contexts(cls, contexts):
         for context in contexts:
             assert context.aggregate_context is None
 
-            for (_, column_name), col_type in context.columns.iteritems():
+            for (_, column_name), col_type in context.columns.items():
                 full_column = (None, column_name)
                 if full_column in result_columns:
                     if result_columns[full_column] == col_type:
                         continue
-                    raise compiler.CompileError(
+                    raise exceptions.CompileError(
                         'Incompatible types when performing union on field '
                         '{}: {} vs. {}'.format(full_column,
                                                result_columns[full_column],
@@ -134,12 +138,12 @@ def column_ref_for_name(self, name):
         if len(possible_results) == 1:
             return possible_results[0]
         elif len(possible_results) > 1:
-            raise compiler.CompileError('Ambiguous field: {}'.format(name))
+            raise exceptions.CompileError('Ambiguous field: {}'.format(name))
         else:
             if self.implicit_column_context is not None:
                 return self.implicit_column_context.column_ref_for_name(name)
             else:
-                raise compiler.CompileError('Field not found: {}'.format(name))
+                raise exceptions.CompileError('Field not found: {}'.format(name))
 
     def context_with_subquery_alias(self, subquery_alias):
         """Handle the case where a subquery has an alias.
@@ -153,7 +157,7 @@ def context_with_subquery_alias(self, subquery_alias):
             collections.OrderedDict(
                 ((subquery_alias, col_name), col_type)
                 for (_, col_name), col_type
-                in self.implicit_column_context.columns.iteritems()
+                in self.implicit_column_context.columns.items()
             )
         )
         return TypeContext(self.columns, self.aliases, self.ambig_aliases,
@@ -163,7 +167,7 @@ def context_with_full_alias(self, alias):
         assert self.aggregate_context is None
         new_columns = collections.OrderedDict(
             ((alias, col_name), col_type)
-            for (_, col_name), col_type in self.columns.iteritems()
+            for (_, col_name), col_type in self.columns.items()
         )
         if self.implicit_column_context:
             new_implicit_column_context = (
diff --git a/tinyquery/typed_ast.py b/tinyquery/typed_ast.py
index dbc0f43..e9a4bca 100644
--- a/tinyquery/typed_ast.py
+++ b/tinyquery/typed_ast.py
@@ -1,8 +1,9 @@
 """A set of AST classes with types and aliases filled in."""
+from __future__ import absolute_import
 import collections
 
-import type_context
-import tq_modes
+
+from tinyquery import tq_modes
 
 
 class Select(collections.namedtuple(
@@ -74,6 +75,7 @@ def __init__(self, *_, **__):
 class NoTable(collections.namedtuple('NoTable', []), TableExpression):
     @property
     def type_ctx(self):
+        from tinyquery import type_context  # To avoid circular import
         return type_context.TypeContext.from_full_columns(
             collections.OrderedDict())