diff --git a/Dockerfile b/Dockerfile index d8d7d5a..90a7e8a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -5,7 +5,7 @@ RUN apk --update add openjdk8-jre gcc musl-dev bash ENV JAVA_HOME /usr/ # Hadoop -ENV HADOOP_VERSION 2.7.2 +ENV HADOOP_VERSION 3.3.3 ENV HADOOP_HOME /usr/hadoop-$HADOOP_VERSION ENV HADOOP_CONF_DIR=$HADOOP_HOME/etc/hadoop ENV PATH $PATH:$HADOOP_HOME/bin @@ -14,7 +14,7 @@ RUN wget "http://archive.apache.org/dist/hadoop/common/hadoop-$HADOOP_VERSION/ha && rm "hadoop-$HADOOP_VERSION.tar.gz" # Spark -ENV SPARK_VERSION 2.4.8 +ENV SPARK_VERSION 3.3.3 ENV SPARK_PACKAGE spark-$SPARK_VERSION ENV SPARK_HOME /usr/$SPARK_PACKAGE-bin-without-hadoop ENV PYSPARK_PYTHON python diff --git a/requirements.in b/requirements.in index 506c151..aa6e850 100644 --- a/requirements.in +++ b/requirements.in @@ -1,4 +1,4 @@ black==20.8b1 -pyspark==2.4.7 +pyspark[connect]==3.4.0 pytest-testdox==2.0.1 -pytest==6.1.1 \ No newline at end of file +pytest==7.3.1 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 654b48a..5829e86 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,50 +1,81 @@ # -# This file is autogenerated by pip-compile -# To update, run: +# This file is autogenerated by pip-compile with Python 3.10 +# by the following command: # # pip-compile requirements.in # appdirs==1.4.4 # via black -attrs==20.2.0 - # via pytest black==20.8b1 # via -r requirements.in click==7.1.2 # via black +exceptiongroup==1.2.0 + # via pytest +googleapis-common-protos==1.62.0 + # via + # grpcio-status + # pyspark +grpcio==1.60.0 + # via + # grpcio-status + # pyspark +grpcio-status==1.60.0 + # via pyspark iniconfig==1.0.1 # via pytest mypy-extensions==0.4.3 # via black +numpy==1.26.2 + # via + # pandas + # pyarrow + # pyspark packaging==20.4 # via pytest +pandas==2.1.4 + # via pyspark pathspec==0.8.0 # via black pluggy==0.13.1 # via pytest -py4j==0.10.7 +protobuf==4.25.1 + # via + # googleapis-common-protos + # grpcio-status +py4j==0.10.9.7 + # via pyspark 
+pyarrow==14.0.1 # via pyspark -py==1.10.0 - # via pytest pyparsing==2.4.7 # via packaging -pyspark==2.4.7 - # via -r requirements.in -pytest-testdox==2.0.1 - # via -r requirements.in -pytest==6.1.1 +pyspark[connect]==3.4.0 + # via + # -r requirements.in + # pyspark +pytest==7.3.1 # via # -r requirements.in # pytest-testdox +pytest-testdox==2.0.1 + # via -r requirements.in +python-dateutil==2.8.2 + # via pandas +pytz==2023.3.post1 + # via pandas regex==2020.10.28 # via black six==1.15.0 - # via packaging -toml==0.10.1 # via - # black - # pytest + # packaging + # python-dateutil +toml==0.10.1 + # via black +tomli==2.0.1 + # via pytest typed-ast==1.4.1 # via black typing-extensions==3.7.4.3 # via black +tzdata==2023.3 + # via pandas diff --git a/src/pyspark_test.py b/src/pyspark_test.py index f2c6777..dbcb56c 100644 --- a/src/pyspark_test.py +++ b/src/pyspark_test.py @@ -2,14 +2,36 @@ import pyspark +try: + from pyspark.sql.connect.dataframe import DataFrame as CDF -def _check_isinstance(left: Any, right: Any, cls): - assert isinstance( - left, cls - ), f"Left expected type {cls}, found {type(left)} instead" - assert isinstance( - right, cls - ), f"Right expected type {cls}, found {type(right)} instead" + has_connect_deps = True +except ImportError: + has_connect_deps = False + + +def _check_isinstance_df(left: Any, right: Any): + types_to_test = [pyspark.sql.DataFrame] + msg_string = "" + # If Spark Connect dependencies are not available, the input is not going to be a Spark Connect + # DataFrame so we can safely skip the validation. 
+    if has_connect_deps:
+        types_to_test.append(CDF)
+        msg_string = f" or {CDF}"
+
+    left_good = any(map(lambda x: isinstance(left, x), types_to_test))
+    right_good = any(map(lambda x: isinstance(right, x), types_to_test))
+    assert (
+        left_good
+    ), f"Left expected type {pyspark.sql.DataFrame}{msg_string}, found {type(left)} instead"
+    assert (
+        right_good
+    ), f"Right expected type {pyspark.sql.DataFrame}{msg_string}, found {type(right)} instead"
+
+    # Check that both sides are of the same DataFrame type.
+    assert type(left) == type(
+        right
+    ), f"Left and right DataFrames are not of the same type: {type(left)} != {type(right)}"
 
 
 def _check_columns(
@@ -39,7 +61,8 @@ def _check_schema(
 
 
 def _check_df_content(
-    left_df: pyspark.sql.DataFrame, right_df: pyspark.sql.DataFrame,
+    left_df: pyspark.sql.DataFrame,
+    right_df: pyspark.sql.DataFrame,
 ):
     left_df_list = left_df.collect()
     right_df_list = right_df.collect()
@@ -88,7 +111,7 @@ def assert_pyspark_df_equal(
     """
 
     # Check if
-    _check_isinstance(left_df, right_df, pyspark.sql.DataFrame)
+    _check_isinstance_df(left_df, right_df)
 
     # Check Column Names
     if check_column_names:
diff --git a/tests/unit_test/test_assert_pyspark_df_equal.py b/tests/unit_test/test_assert_pyspark_df_equal.py
index 90cd2d7..5d5901d 100644
--- a/tests/unit_test/test_assert_pyspark_df_equal.py
+++ b/tests/unit_test/test_assert_pyspark_df_equal.py
@@ -12,6 +12,7 @@
 )
 
 from src.pyspark_test import assert_pyspark_df_equal
+from src.pyspark_test import _check_isinstance_df
 
 
 class TestAssertPysparkDfEqual:
@@ -68,7 +69,7 @@ def test_assert_pyspark_df_equal_one_is_not_pyspark_df(
 
         right_df = "Demo"
         with pytest.raises(
             AssertionError,
-            match="Right expected type , found instead",
+            match="Right expected type or .*?, found instead",
         ):
             assert_pyspark_df_equal(left_df, right_df)
@@ -324,3 +325,18 @@ def test_assert_pyspark_df_equal_different_row_count(
             match="Number of rows are not same.\n \n Actual Rows: 2\n Expected Rows: 3",
         ):
             assert_pyspark_df_equal(left_df, right_df)
+
+    def test_instance_checks_for_spark_connect(
+        self, spark_session: pyspark.sql.SparkSession
+    ):
+        from pyspark.sql.connect.dataframe import DataFrame as CDF
+        left_df = spark_session.range(1)
+        right_df = spark_session.range(1)
+        _check_isinstance_df(left_df, right_df)
+
+        left_df = CDF.withPlan(None, None)
+        right_df = CDF.withPlan(None, None)
+        _check_isinstance_df(left_df, right_df)
+
+        with pytest.raises(AssertionError):
+            _check_isinstance_df(spark_session.range(1), right_df)