From c00bbdcbaa0e8d6ab5093fc3b26e4f40f5a29d9f Mon Sep 17 00:00:00 2001 From: Xuebin Su Date: Fri, 27 Oct 2023 05:34:29 -0400 Subject: [PATCH 1/6] Support PL/Container --- greenplumpython/func.py | 125 +++++++++++++++++++++++--------------- greenplumpython/type.py | 31 ++++++---- tests/test_plcontainer.py | 23 +++++++ tests/test_type.py | 6 +- 4 files changed, 120 insertions(+), 65 deletions(-) create mode 100644 tests/test_plcontainer.py diff --git a/greenplumpython/func.py b/greenplumpython/func.py index f4676f5e..f071733d 100644 --- a/greenplumpython/func.py +++ b/greenplumpython/func.py @@ -18,8 +18,9 @@ from greenplumpython.db import Database from greenplumpython.expr import Expr, _serialize_to_expr from greenplumpython.group import DataFrameGroupingSet -from greenplumpython.type import _serialize_to_type +from greenplumpython.type import _serialize_to_type_name, _defined_types +import psycopg2 class FunctionExpr(Expr): """ @@ -111,52 +112,68 @@ def apply( if grouping_col_names is not None and len(grouping_col_names) != 0 else None ) - unexpanded_dataframe = DataFrame( - " ".join( + try: + return_annotation = inspect.signature(self._function._wrapped_func).return_annotation # type: ignore reportUnknownArgumentType + _serialize_to_type_name(return_annotation, db=db, for_return=True) + return DataFrame( + f""" + SELECT * FROM plcontainer_apply(TABLE( + SELECT * {from_clause}), '{self._function._qualified_name_str}', 4096) AS + {_defined_types[return_annotation.__args__[0]]._serialize(db=db)} + """, + db=db, + parents=parents, + ) + except psycopg2.errors.InternalError_: + unexpanded_dataframe = DataFrame( + " ".join( + [ + f"SELECT {_serialize_to_expr(self, db=db)} {'AS ' + column_name if column_name is not None else ''}", + ("," + ",".join(grouping_cols)) if (grouping_cols is not None) else "", + from_clause, + group_by_clause, + ] + ), + db=db, + parents=parents, + ) + # We use 2 `DataFrame`s here because on GPDB 6X and PostgreSQL <= 9.6, a + # function returning records that contains more than one attributes + # will be called multiple times if we do + # ```sql + # SELECT (func(a, b)).* FROM t; + # ``` + # which might cause performance issue. To workaround we need to do + # ```sql + # WITH func_call AS ( + # SELECT func(a, b) AS result FROM t + # ) + # SELECT (result).* FROM func_call; + # ``` + rebased_grouping_cols = ( [ - f"SELECT {_serialize_to_expr(self, db=db)} {'AS ' + column_name if column_name is not None else ''}", - ("," + ",".join(grouping_cols)) if (grouping_cols is not None) else "", - from_clause, - group_by_clause, + _serialize_to_expr(unexpanded_dataframe[name], db=db) + for name in grouping_col_names ] - ), - db=db, - parents=parents, - ) - # We use 2 `DataFrame`s here because on GPDB 6X and PostgreSQL <= 9.6, a - # function returning records that contains more than one attributes - # will be called multiple times if we do - # ```sql - # SELECT (func(a, b)).* FROM t; - # ``` - # which might cause performance issue. To workaround we need to do - # ```sql - # WITH func_call AS ( - # SELECT func(a, b) AS result FROM t - # ) - # SELECT (result).* FROM func_call; - # ``` - rebased_grouping_cols = ( - [_serialize_to_expr(unexpanded_dataframe[name], db=db) for name in grouping_col_names] - if grouping_col_names is not None - else None - ) - result_cols = ( - _serialize_to_expr(unexpanded_dataframe["*"], db=db) - if not expand - else _serialize_to_expr(unexpanded_dataframe[column_name]["*"], db=db) - # `len(rebased_grouping_cols) == 0` means `GROUP BY GROUPING SETS (())` - if rebased_grouping_cols is None or len(rebased_grouping_cols) == 0 - else f"({unexpanded_dataframe._name}).*" - if not expand - else f"{','.join(rebased_grouping_cols)}, {_serialize_to_expr(unexpanded_dataframe[column_name]['*'], db=db)}" - ) + if grouping_col_names is not None + else None + ) + result_cols = ( + _serialize_to_expr(unexpanded_dataframe["*"], db=db) + if not expand + else _serialize_to_expr(unexpanded_dataframe[column_name]["*"], db=db) + # `len(rebased_grouping_cols) == 0` means `GROUP BY GROUPING SETS (())` + if rebased_grouping_cols is None or len(rebased_grouping_cols) == 0 + else f"({unexpanded_dataframe._name}).*" + if not expand + else f"{','.join(rebased_grouping_cols)}, {_serialize_to_expr(unexpanded_dataframe[column_name]['*'], db=db)}" + ) - return DataFrame( - f"SELECT {result_cols} FROM {unexpanded_dataframe._name}", - db=db, - parents=[unexpanded_dataframe], - ) + return DataFrame( + f"SELECT {result_cols} FROM {unexpanded_dataframe._name}", + db=db, + parents=[unexpanded_dataframe], + ) @property def _function(self) -> "_AbstractFunction": @@ -272,12 +289,14 @@ def __init__( name: Optional[str] = None, schema: Optional[str] = None, language_handler: Literal["plpython3u"] = "plpython3u", + runtime: Optional[str] = None ) -> None: # noqa D107 super().__init__(wrapped_func, name, schema) self._created_in_dbs: Optional[Set[Database]] = set() if wrapped_func is not None else None self._wrapped_func = wrapped_func self._language_handler = language_handler + self._runtime = runtime def unwrap(self) -> Callable[..., Any]: """Get the wrapped Python function in the database function.""" @@ -302,14 +321,18 @@ def _serialize(self, db: Database) -> str: func_sig = inspect.signature(self._wrapped_func) func_args = ",".join( [ - f'"{param.name}" {_serialize_to_type(param.annotation, db=db)}' + f'"{param.name}" {_serialize_to_type_name(param.annotation, db=db)}' for param in func_sig.parameters.values() ] ) func_arg_names = ",".join( [f"{param.name}={param.name}" for param in func_sig.parameters.values()] ) - return_type = _serialize_to_type(func_sig.return_annotation, db=db, for_return=True) + return_type = ( + _serialize_to_type_name(func_sig.return_annotation, db=db, for_return=True) + if self._language_handler != "plcontainer" + else "SETOF record" + ) func_pickled: bytes = dill.dumps(self._wrapped_func) _, func_name = self._qualified_name # Modify the AST of the wrapped function to minify dependency: (1-3) @@ -335,6 +358,7 @@ def _serialize(self, db: Database) -> str: f"CREATE FUNCTION {self._qualified_name_str} ({func_args}) " f"RETURNS {return_type} " f"AS $$\n" + f"# container: {self._runtime}\n" f"try:\n" f" return GD['{func_ast.name}']({func_arg_names})\n" f"except KeyError:\n" @@ -461,7 +485,7 @@ def _create_in_db(self, db: Database) -> None: state_param = next(param_list) args_string = ",".join( [ - f"{param.name} {_serialize_to_type(param.annotation, db=db)}" + f"{param.name} {_serialize_to_type_name(param.annotation, db=db)}" for param in param_list ] ) @@ -470,7 +494,7 @@ def _create_in_db(self, db: Database) -> None: ( f"CREATE AGGREGATE {self._qualified_name_str} ({args_string}) (\n" f" SFUNC = {self.transition_function._qualified_name_str},\n" - f" STYPE = {_serialize_to_type(state_param.annotation, db=db)}\n" + f" STYPE = {_serialize_to_type_name(state_param.annotation, db=db)}\n" f");\n" ), has_results=False, @@ -547,6 +571,7 @@ def aggregate_function(name: str, schema: Optional[str] = None) -> AggregateFunc def create_function( wrapped_func: Optional[Callable[..., Any]] = None, language_handler: Literal["plpython3u"] = "plpython3u", + runtime: Optional[str] = None ) -> NormalFunction: """ Create a :class:`~func.NormalFunction` from the given Python function. @@ -610,8 +635,8 @@ def create_function( """ # If user needs extra parameters when creating a function if wrapped_func is None: - return functools.partial(create_function, language_handler=language_handler) - return NormalFunction(wrapped_func=wrapped_func, language_handler=language_handler) + return functools.partial(create_function, language_handler=language_handler, runtime=runtime) + return NormalFunction(wrapped_func=wrapped_func, language_handler=language_handler, runtime=runtime) # FIXME: Add test cases for optional parameters diff --git a/greenplumpython/type.py b/greenplumpython/type.py index 891ff6c1..33b15a0d 100644 --- a/greenplumpython/type.py +++ b/greenplumpython/type.py @@ -94,6 +94,17 @@ def __init__( if self._modifier is not None: self._qualified_name_str += f"({self._modifier})" + def _serialize(self, db: Database) -> str: + if self._annotation is None: + raise Exception("No type annotation to serialize") + members = get_type_hints(self._annotation) + if len(members) == 0: + raise Exception(f"Failed to get annotations for type {self._annotation}") + members_str = ",\n".join( + [f"{name} {_serialize_to_type_name(type_t, db)}" for name, type_t in members.items()] + ) + return f"({members_str})" + # -- Creation of a composite type in Greenplum corresponding to the class_type given def _create_in_db(self, db: Database): # noqa: D400 @@ -115,14 +126,9 @@ def _create_in_db(self, db: Database): self._annotation, type ), "Only composite data types can be created in database." schema = "pg_temp" - members = get_type_hints(self._annotation) - if len(members) == 0: - raise Exception(f"Failed to get annotations for type {self._annotation}") - att_type_str = ",\n".join( - [f"{name} {_serialize_to_type(type_t, db)}" for name, type_t in members.items()] - ) + db._execute( - f'CREATE TYPE "{schema}"."{self._name}" AS (\n' f"{att_type_str}\n" f");", + f'CREATE TYPE "{schema}"."{self._name}" AS {self._serialize(db=db)};', has_results=False, ) self._created_in_dbs.add(db) @@ -178,7 +184,7 @@ def type_(name: str, schema: Optional[str] = None, modifier: Optional[int] = Non return DataType(name, schema=schema, modifier=modifier) -def _serialize_to_type( +def _serialize_to_type_name( annotation: Union[DataType, type], db: Database, for_return: bool = False, @@ -204,10 +210,10 @@ def _serialize_to_type( if annotation.__origin__ == list or annotation.__origin__ == List: args: Tuple[type, ...] = annotation.__args__ if for_return: - return f"SETOF {_serialize_to_type(args[0], db)}" # type: ignore - if args[0] in _defined_types: - return f"{_serialize_to_type(args[0], db)}[]" # type: ignore - raise NotImplementedError() + return f"SETOF {_serialize_to_type_name(args[0], db)}" # type: ignore + else: + return f"{_serialize_to_type_name(args[0], db)}[]" # type: ignore + raise NotImplementedError("Only list is supported as generic data type") else: if isinstance(annotation, DataType): return annotation._qualified_name_str @@ -216,4 +222,5 @@ def _serialize_to_type( type_name = "type_" + uuid4().hex _defined_types[annotation] = DataType(name=type_name, annotation=annotation) _defined_types[annotation]._create_in_db(db) + print(_defined_types) return _defined_types[annotation]._qualified_name_str diff --git a/tests/test_plcontainer.py b/tests/test_plcontainer.py new file mode 100644 index 00000000..d386d5e1 --- /dev/null +++ b/tests/test_plcontainer.py @@ -0,0 +1,23 @@ +from dataclasses import dataclass +import greenplumpython as gp + +from tests import db + + +def test_simple_func(db: gp.Database): + @dataclass + class Int: + i: int + + @gp.create_function(language_handler="plcontainer", runtime="plc_python_example") + def add_one(x: list[Int]) -> list[Int]: + return [{"i": arg["i"] + 1} for arg in x] + + assert ( + len( + list( + db.create_dataframe(columns={"i": range(10)}).apply(lambda _: add_one(), expand=True) + ) + ) + == 10 + ) diff --git a/tests/test_type.py b/tests/test_type.py index c7fc2ed6..4dd12944 100644 --- a/tests/test_type.py +++ b/tests/test_type.py @@ -3,7 +3,7 @@ import pytest import greenplumpython as gp -from greenplumpython.type import _serialize_to_type +from greenplumpython.type import _serialize_to_type_name from tests import db @@ -76,7 +76,7 @@ class Person: _first_name: str _last_name: str - type_name = _serialize_to_type(Person, db=db) + type_name = _serialize_to_type_name(Person, db=db) assert isinstance(type_name, str) @@ -88,5 +88,5 @@ def __init__(self, _first_name: str, _last_name: str) -> None: self._last_name = _last_name with pytest.raises(Exception) as exc_info: - _serialize_to_type(Person, db=db) + _serialize_to_type_name(Person, db=db) assert "Failed to get annotations" in str(exc_info.value) From bbe77b3536be59b832a9f9626fba6f23dbf04092 Mon Sep 17 00:00:00 2001 From: Xuebin Su Date: Sun, 29 Oct 2023 21:34:32 -0400 Subject: [PATCH 2/6] Use plcontainer only when asked --- greenplumpython/func.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/greenplumpython/func.py b/greenplumpython/func.py index f071733d..997d30b4 100644 --- a/greenplumpython/func.py +++ b/greenplumpython/func.py @@ -20,7 +20,6 @@ from greenplumpython.group import DataFrameGroupingSet from greenplumpython.type import _serialize_to_type_name, _defined_types -import psycopg2 class FunctionExpr(Expr): """ @@ -112,7 +111,7 @@ def apply( if grouping_col_names is not None and len(grouping_col_names) != 0 else None ) - try: + if self._function._language_handler == "plcontainer": return_annotation = inspect.signature(self._function._wrapped_func).return_annotation # type: ignore reportUnknownArgumentType _serialize_to_type_name(return_annotation, db=db, for_return=True) return DataFrame( @@ -124,7 +123,7 @@ def apply( db=db, parents=parents, ) - except psycopg2.errors.InternalError_: + else: unexpanded_dataframe = DataFrame( " ".join( [ From 93cee7a108a9c9c43af2e1920697dd3b4c4058fd Mon Sep 17 00:00:00 2001 From: Xuebin Su Date: Sun, 29 Oct 2023 21:43:26 -0400 Subject: [PATCH 3/6] Fix errors --- greenplumpython/func.py | 17 ++++++++++++----- tests/test_plcontainer.py | 4 +++- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/greenplumpython/func.py b/greenplumpython/func.py index 997d30b4..89af78a4 100644 --- a/greenplumpython/func.py +++ b/greenplumpython/func.py @@ -111,7 +111,10 @@ def apply( if grouping_col_names is not None and len(grouping_col_names) != 0 else None ) - if self._function._language_handler == "plcontainer": + if ( + isinstance(self._function, NormalFunction) + and self._function._language_handler == "plcontainer" + ): return_annotation = inspect.signature(self._function._wrapped_func).return_annotation # type: ignore reportUnknownArgumentType _serialize_to_type_name(return_annotation, db=db, for_return=True) return DataFrame( @@ -288,7 +291,7 @@ def __init__( name: Optional[str] = None, schema: Optional[str] = None, language_handler: Literal["plpython3u"] = "plpython3u", - runtime: Optional[str] = None + runtime: Optional[str] = None, ) -> None: # noqa D107 super().__init__(wrapped_func, name, schema) @@ -570,7 +573,7 @@ def aggregate_function(name: str, schema: Optional[str] = None) -> AggregateFunc def create_function( wrapped_func: Optional[Callable[..., Any]] = None, language_handler: Literal["plpython3u"] = "plpython3u", - runtime: Optional[str] = None + runtime: Optional[str] = None, ) -> NormalFunction: """ Create a :class:`~func.NormalFunction` from the given Python function. @@ -634,8 +637,12 @@ def create_function( """ # If user needs extra parameters when creating a function if wrapped_func is None: - return functools.partial(create_function, language_handler=language_handler, runtime=runtime) - return NormalFunction(wrapped_func=wrapped_func, language_handler=language_handler, runtime=runtime) + return functools.partial( + create_function, language_handler=language_handler, runtime=runtime + ) + return NormalFunction( + wrapped_func=wrapped_func, language_handler=language_handler, runtime=runtime + ) # FIXME: Add test cases for optional parameters diff --git a/tests/test_plcontainer.py b/tests/test_plcontainer.py index d386d5e1..8ff6ccb6 100644 --- a/tests/test_plcontainer.py +++ b/tests/test_plcontainer.py @@ -16,7 +16,9 @@ def add_one(x: list[Int]) -> list[Int]: assert ( len( list( - db.create_dataframe(columns={"i": range(10)}).apply(lambda _: add_one(), expand=True) + db.create_dataframe(columns={"i": range(10)}).apply( + lambda _: add_one(), expand=True + ) ) ) == 10 From bd35ca8763240217c4fd3a0227ab67c55e26868b Mon Sep 17 00:00:00 2001 From: Xuebin Su Date: Sun, 29 Oct 2023 21:47:11 -0400 Subject: [PATCH 4/6] Sort imports --- greenplumpython/func.py | 2 +- tests/test_plcontainer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/greenplumpython/func.py b/greenplumpython/func.py index 89af78a4..c92082b7 100644 --- a/greenplumpython/func.py +++ b/greenplumpython/func.py @@ -18,7 +18,7 @@ from greenplumpython.db import Database from greenplumpython.expr import Expr, _serialize_to_expr from greenplumpython.group import DataFrameGroupingSet -from greenplumpython.type import _serialize_to_type_name, _defined_types +from greenplumpython.type import _defined_types, _serialize_to_type_name class FunctionExpr(Expr): diff --git a/tests/test_plcontainer.py b/tests/test_plcontainer.py index 8ff6ccb6..6b16f051 100644 --- a/tests/test_plcontainer.py +++ b/tests/test_plcontainer.py @@ -1,6 +1,6 @@ from dataclasses import dataclass -import greenplumpython as gp +import greenplumpython as gp from tests import db From b5cda8e80ccfc9b83dc29064a60eda9995319920 Mon Sep 17 00:00:00 2001 From: Xuebin Su Date: Sun, 29 Oct 2023 21:48:53 -0400 Subject: [PATCH 5/6] Remove print() --- greenplumpython/type.py | 1 - 1 file changed, 1 deletion(-) diff --git a/greenplumpython/type.py b/greenplumpython/type.py index 33b15a0d..338cd599 100644 --- a/greenplumpython/type.py +++ b/greenplumpython/type.py @@ -222,5 +222,4 @@ def _serialize_to_type_name( type_name = "type_" + uuid4().hex _defined_types[annotation] = DataType(name=type_name, annotation=annotation) _defined_types[annotation]._create_in_db(db) - print(_defined_types) return _defined_types[annotation]._qualified_name_str From 67268fb5b5213671c26d42cdf41607607af389ef Mon Sep 17 00:00:00 2001 From: Ruxue Zeng <36695415+ruxuez@users.noreply.github.com> Date: Wed, 27 Dec 2023 10:04:26 +0100 Subject: [PATCH 6/6] Allow pass column to plcontainer_apply (#229) This PR allows to pass specific columns when using plcontainer_apply. --- greenplumpython/func.py | 11 ++++++++- tests/test_plcontainer.py | 47 ++++++++++++++++++++++++++++++++------- 2 files changed, 49 insertions(+), 9 deletions(-) diff --git a/greenplumpython/func.py b/greenplumpython/func.py index c92082b7..dff81877 100644 --- a/greenplumpython/func.py +++ b/greenplumpython/func.py @@ -117,10 +117,18 @@ def apply( ): return_annotation = inspect.signature(self._function._wrapped_func).return_annotation # type: ignore reportUnknownArgumentType _serialize_to_type_name(return_annotation, db=db, for_return=True) + input_args = self._args + if len(input_args) == 0: + raise Exception("No input data specified, please specify a DataFrame or Columns") + input_clause = ( + "*" + if (len(input_args) == 1 and isinstance(input_args[0], DataFrame)) + else ",".join([arg._serialize(db=db) for arg in input_args]) + ) return DataFrame( f""" SELECT * FROM plcontainer_apply(TABLE( - SELECT * {from_clause}), '{self._function._qualified_name_str}', 4096) AS + SELECT {input_clause} {from_clause}), '{self._function._qualified_name_str}', 4096) AS {_defined_types[return_annotation.__args__[0]]._serialize(db=db)} """, db=db, @@ -370,6 +378,7 @@ def _serialize(self, db: Database) -> str: f" import sys as {sys_lib_name}\n" f" if {sysconfig_lib_name}.get_python_version() != '{python_version}':\n" f" raise ModuleNotFoundError\n" + f" {sys_lib_name}.modules['plpy']=plpy\n" f" setattr({sys_lib_name}.modules['plpy'], '_SD', SD)\n" f" GD['{func_ast.name}'] = {pickle_lib_name}.loads({func_pickled})\n" f" except ModuleNotFoundError:\n" diff --git a/tests/test_plcontainer.py b/tests/test_plcontainer.py index 6b16f051..1a1b39e4 100644 --- a/tests/test_plcontainer.py +++ b/tests/test_plcontainer.py @@ -1,25 +1,56 @@ from dataclasses import dataclass +import pytest + import greenplumpython as gp from tests import db -def test_simple_func(db: gp.Database): - @dataclass - class Int: - i: int +@dataclass +class Int: + i: int + + +@dataclass +class Pair: + i: int + j: int + + +@pytest.fixture +def t(db: gp.Database): + rows = [(i, i) for i in range(10)] + return db.create_dataframe(rows=rows, column_names=["a", "b"]) + + +@gp.create_function(language_handler="plcontainer", runtime="plc_python_example") +def add_one(x: list[Int]) -> list[Int]: + return [{"i": arg["i"] + 1} for arg in x] - @gp.create_function(language_handler="plcontainer", runtime="plc_python_example") - def add_one(x: list[Int]) -> list[Int]: - return [{"i": arg["i"] + 1} for arg in x] +def test_simple_func(db: gp.Database): assert ( len( list( db.create_dataframe(columns={"i": range(10)}).apply( - lambda _: add_one(), expand=True + lambda t: add_one(t), expand=True ) ) ) == 10 ) + + +def test_func_no_input(db: gp.Database): + + with pytest.raises(Exception) as exc_info: # no input data for func raises Exception + db.create_dataframe(columns={"i": range(10)}).apply(lambda _: add_one(), expand=True) + assert "No input data specified, please specify a DataFrame or Columns" in str(exc_info.value) + + +def test_func_column(db: gp.Database, t: gp.DataFrame): + @gp.create_function(language_handler="plcontainer", runtime="plc_python_example") + def add(x: list[Pair]) -> list[Int]: + return [{"i": arg["i"] + arg["j"]} for arg in x] + + assert len(list(t.apply(lambda t: add(t["a"], t["b"]), expand=True))) == 10