diff --git a/build.gradle.kts b/build.gradle.kts index d7becb436..86072e0f9 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -42,6 +42,11 @@ dependencies { nmcpAggregation(project(":isthmus")) } +// Ensure all Spark variants are published before aggregation +tasks.named("publishAggregationToCentralPortal") { + dependsOn(":spark:publishAllVariantsToCentralPortal") +} + allprojects { repositories { mavenCentral() } diff --git a/examples/substrait-spark/build.gradle.kts b/examples/substrait-spark/build.gradle.kts index f0a08a806..b363b92e7 100644 --- a/examples/substrait-spark/build.gradle.kts +++ b/examples/substrait-spark/build.gradle.kts @@ -10,17 +10,32 @@ repositories { mavenCentral() } +// Get the Spark variant property - determines which spark subproject to use +val sparkVariantProp = findProperty("sparkVariant")?.toString() ?: "spark40_2.13" + +// Map variants to their subproject paths and versions +val variantConfig = + mapOf( + "spark34_2.12" to Triple(":spark:spark-3.4_2.12", "3.4.4", "2.12"), + "spark35_2.12" to Triple(":spark:spark-3.5_2.12", "3.5.4", "2.12"), + "spark40_2.13" to Triple(":spark:spark-4.0_2.13", "4.0.2", "2.13"), + ) + +val (sparkProject, sparkVersion, scalaBinary) = + variantConfig[sparkVariantProp] ?: variantConfig["spark40_2.13"]!! + dependencies { - implementation(project(":spark")) + // Depend on the specific spark variant subproject + implementation(project(sparkProject)) // For a real Spark application, these would not be required since they would be in the Spark - // server classpath - runtimeOnly(libs.spark.core) - runtimeOnly(libs.spark.hive) + // server classpath. Use direct Maven coordinates to match the spark module's variant. + runtimeOnly("org.apache.spark:spark-core_${scalaBinary}:${sparkVersion}") + runtimeOnly("org.apache.spark:spark-hive_${scalaBinary}:${sparkVersion}") } tasks.jar { - dependsOn(":spark:jar", ":core:jar", ":core:shadowJar") + dependsOn("$sparkProject:jar", ":core:jar", ":core:shadowJar") isZip64 = true exclude("META-INF/*.RSA") diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/SparkConsumeSubstrait.java b/examples/substrait-spark/src/main/java/io/substrait/examples/SparkConsumeSubstrait.java index 26c15274f..5cd39b7f1 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/SparkConsumeSubstrait.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/SparkConsumeSubstrait.java @@ -9,9 +9,9 @@ import java.io.IOException; import java.nio.file.Files; import java.nio.file.Paths; -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.classic.Dataset; +import org.apache.spark.sql.classic.SparkSession; /** Minimal Spark application */ public class SparkConsumeSubstrait implements App.Action { diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/SparkHelper.java b/examples/substrait-spark/src/main/java/io/substrait/examples/SparkHelper.java index 270b760ca..ce623eff8 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/SparkHelper.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/SparkHelper.java @@ -1,6 +1,6 @@ package io.substrait.examples; -import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.classic.SparkSession; /** Collection of helper fns */ public final class SparkHelper { diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/SparkSQL.java 
b/examples/substrait-spark/src/main/java/io/substrait/examples/SparkSQL.java index ddb544f00..5488013e8 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/SparkSQL.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/SparkSQL.java @@ -58,7 +58,7 @@ public void run(String arg) { Dataset result = spark.sql(sqlQuery); result.show(); - LogicalPlan logical = result.logicalPlan(); + LogicalPlan logical = result.queryExecution().logical(); System.out.println(logical); LogicalPlan optimised = result.queryExecution().optimizedPlan(); diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 3472e10af..5fafafb92 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -17,12 +17,17 @@ nmcp = "1.4.4" picocli = "4.7.7" protobuf-plugin = "0.9.6" protobuf = "3.25.8" -scala-library = "2.13.18" -scalatest = "3.2.19" -scalatestplus-junit5 = "3.2.19.0" +scala-2-12 = "2.12.20" +scala-2-13 = "2.13.18" +scalatest-2-12 = "3.2.19" +scalatest-2-13 = "3.2.19" +scalatestplus-junit5-2-12 = "3.2.19.0" +scalatestplus-junit5-2-13 = "3.2.19.0" shadow = "9.3.1" slf4j = "2.0.17" -spark = "3.4.4" +spark-3-4 = "3.4.4" +spark-3-5 = "3.5.4" +spark-4-0 = "4.0.2" spotless = "8.2.1" validator = "3.0.0" @@ -59,15 +64,26 @@ picocli-codegen = { module = "info.picocli:picocli-codegen", version.ref = "pico protobuf-java = { module = "com.google.protobuf:protobuf-java", version.ref = "protobuf" } protobuf-java-util = { module = "com.google.protobuf:protobuf-java-util", version.ref = "protobuf" } protoc = { module = "com.google.protobuf:protoc", version.ref = "protobuf" } -scala-library = { module = "org.scala-lang:scala-library", version.ref = "scala-library" } -scalatest = { module = "org.scalatest:scalatest_2.13", version.ref = "scalatest" } -scalatestplus-junit5 = { module = "org.scalatestplus:junit-5-13_2.13", version.ref = "scalatestplus-junit5" } +scala-library-2-12 = { module = "org.scala-lang:scala-library", version.ref = "scala-2-12" } +scala-library-2-13 = { module = "org.scala-lang:scala-library", version.ref = "scala-2-13" } +scalatest-2-12 = { module = "org.scalatest:scalatest_2.12", version.ref = "scalatest-2-12" } +scalatest-2-13 = { module = "org.scalatest:scalatest_2.13", version.ref = "scalatest-2-13" } +scalatestplus-junit5-2-12 = { module = "org.scalatestplus:junit-5-12_2.12", version.ref = "scalatestplus-junit5-2-12" } +scalatestplus-junit5-2-13 = { module = "org.scalatestplus:junit-5-13_2.13", version.ref = "scalatestplus-junit5-2-13" } slf4j-api = { module = "org.slf4j:slf4j-api", version.ref = "slf4j" } slf4j-jdk14 = { module = "org.slf4j:slf4j-jdk14", version.ref = "slf4j" } -spark-catalyst = { module = "org.apache.spark:spark-catalyst_2.13", version.ref = "spark" } -spark-core = { module = "org.apache.spark:spark-core_2.13", version.ref = "spark" } -spark-hive = { module = "org.apache.spark:spark-hive_2.13", version.ref = "spark" } -spark-sql = { module = "org.apache.spark:spark-sql_2.13", version.ref = "spark" } +spark-catalyst-3-4-2-12 = { module = "org.apache.spark:spark-catalyst_2.12", version.ref = "spark-3-4" } +spark-core-3-4-2-12 = { module = "org.apache.spark:spark-core_2.12", version.ref = "spark-3-4" } +spark-hive-3-4-2-12 = { module = "org.apache.spark:spark-hive_2.12", version.ref = "spark-3-4" } +spark-sql-3-4-2-12 = { module = "org.apache.spark:spark-sql_2.12", version.ref = "spark-3-4" } +spark-catalyst-3-5-2-12 = { module = "org.apache.spark:spark-catalyst_2.12", version.ref = "spark-3-5" } +spark-core-3-5-2-12 = { 
module = "org.apache.spark:spark-core_2.12", version.ref = "spark-3-5" } +spark-hive-3-5-2-12 = { module = "org.apache.spark:spark-hive_2.12", version.ref = "spark-3-5" } +spark-sql-3-5-2-12 = { module = "org.apache.spark:spark-sql_2.12", version.ref = "spark-3-5" } +spark-catalyst-4-0-2-13 = { module = "org.apache.spark:spark-catalyst_2.13", version.ref = "spark-4-0" } +spark-core-4-0-2-13 = { module = "org.apache.spark:spark-core_2.13", version.ref = "spark-4-0" } +spark-hive-4-0-2-13 = { module = "org.apache.spark:spark-hive_2.13", version.ref = "spark-4-0" } +spark-sql-4-0-2-13 = { module = "org.apache.spark:spark-sql_2.13", version.ref = "spark-4-0" } [bundles] jackson = [ "jackson-databind", "jackson-annotations", "jackson-datatype-jdk8", "jackson-dataformat-yaml" ] diff --git a/settings.gradle.kts b/settings.gradle.kts index 500c43c3c..f55cc05fc 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -8,6 +8,9 @@ include( "isthmus", "isthmus-cli", "spark", + "spark:spark-3.4_2.12", + "spark:spark-3.5_2.12", + "spark:spark-4.0_2.13", "examples:substrait-spark", "examples:isthmus-api", ) diff --git a/spark/.scalafmt.conf b/spark/.scalafmt.conf index 3606e3439..cdcc43e23 100644 --- a/spark/.scalafmt.conf +++ b/spark/.scalafmt.conf @@ -1,4 +1,4 @@ -runner.dialect = scala212 +runner.dialect = scala213 # Version is required to make sure IntelliJ picks the right version version = 3.8.1 diff --git a/spark/README-MULTI-VARIANT.md b/spark/README-MULTI-VARIANT.md new file mode 100644 index 000000000..ea11dc8e3 --- /dev/null +++ b/spark/README-MULTI-VARIANT.md @@ -0,0 +1,241 @@ +# Multi-Variant Spark/Scala Build System + +This document describes how to build and publish multiple Spark/Scala variants of the substrait-spark module. + +## Supported Variants + +The substrait-spark module supports three build variants: + +| Variant | Spark Version | Scala Version | Classifier | Subproject | +|---------|---------------|---------------|------------|------------| +| Spark 3.4 | 3.4.4 | 2.12.20 | `spark34_2.12` | `:spark:spark-3.4_2.12` | +| Spark 3.5 | 3.5.4 | 2.12.20 | `spark35_2.12` | `:spark:spark-3.5_2.12` | +| Spark 4.0 | 4.0.2 | 2.13.18 | `spark40_2.13` | `:spark:spark-4.0_2.13` | + +## Architecture + +The build system uses **Gradle subprojects** for each variant. + +### Project Structure + +``` +spark/ +├── build.gradle.kts # Orchestrator project +├── src/ # Shared source code +│ ├── main/ +│ │ ├── scala/ # Common code for all versions +│ │ ├── spark-3.4/ # Spark 3.4 specific implementations +│ │ ├── spark-3.5/ # Spark 3.5 specific implementations +│ │ └── spark-4.0/ # Spark 4.0 specific implementations +│ └── test/ +│ ├── scala/ # Common test code +│ ├── spark-3.4/ # Spark 3.4 specific tests +│ ├── spark-3.5/ # Spark 3.5 specific tests +│ └── spark-4.0/ # Spark 4.0 specific tests +├── spark-3.4_2.12/ +│ └── build.gradle.kts # Spark 3.4 variant build +├── spark-3.5_2.12/ +│ └── build.gradle.kts # Spark 3.5 variant build +└── spark-4.0_2.13/ + └── build.gradle.kts # Spark 4.0 variant build +``` + +Each subproject references the shared source code in `../src/` using Gradle's source set configuration. 
+ +## Building Variants + +### Build a Specific Variant + +Build a specific variant using its subproject path: + +```bash +# Build Spark 3.4 with Scala 2.12 +./gradlew :spark:spark-3.4_2.12:build + +# Build Spark 3.5 with Scala 2.12 +./gradlew :spark:spark-3.5_2.12:build + +# Build Spark 4.0 with Scala 2.13 +./gradlew :spark:spark-4.0_2.13:build +``` + +### Build All Variants + +To build all variants: + +```bash +./gradlew :spark:build +``` + +## Publishing Variants + +### Publish to Local Maven Repository + +Publish a specific variant: + +```bash +# Publish Spark 3.4 with Scala 2.12 +./gradlew :spark:spark-3.4_2.12:publishToMavenLocal + +# Publish Spark 3.5 with Scala 2.12 +./gradlew :spark:spark-3.5_2.12:publishToMavenLocal + +# Publish Spark 4.0 with Scala 2.13 +./gradlew :spark:spark-4.0_2.13:publishToMavenLocal +``` + +### Publish All Variants + +To publish all variants to your local Maven repository: + +```bash +./gradlew :spark:publishAllVariants +``` + +Published artifacts will be available at: +``` +~/.m2/repository/io/substrait/{classifier}/{version}/ +``` + +For example: +- `~/.m2/repository/io/substrait/spark34_2.12/0.78.0/` +- `~/.m2/repository/io/substrait/spark35_2.12/0.78.0/` +- `~/.m2/repository/io/substrait/spark40_2.13/0.78.0/` + +### Publish to Maven Central Portal + +Publish all variants to Maven Central: + +```bash +./gradlew :spark:publishAllVariantsToCentralPortal +``` + +Or publish a specific variant: + +```bash +./gradlew :spark:spark-4.0_2.13:publishMaven-publishPublicationToNmcpRepository +``` + +## Using Published Artifacts + +### Maven + +Add the appropriate variant as a dependency in your `pom.xml`: + +```xml +<!-- Spark 3.4 with Scala 2.12 --> +<dependency> +  <groupId>io.substrait</groupId> +  <artifactId>spark34_2.12</artifactId> +  <version>0.80.0</version> +</dependency> + +<!-- Spark 3.5 with Scala 2.12 --> +<dependency> +  <groupId>io.substrait</groupId> +  <artifactId>spark35_2.12</artifactId> +  <version>0.80.0</version> +</dependency> + +<!-- Spark 4.0 with Scala 2.13 --> +<dependency> +  <groupId>io.substrait</groupId> +  <artifactId>spark40_2.13</artifactId> +  <version>0.80.0</version> +</dependency> +``` + +### Gradle + +Add the appropriate variant as a dependency in your `build.gradle.kts`: + +```kotlin +dependencies { + // Spark 3.4 with Scala 2.12 + implementation("io.substrait:spark34_2.12:0.80.0") + + // Spark 3.5 with Scala 2.12 + implementation("io.substrait:spark35_2.12:0.80.0") + + // Spark 4.0 with Scala 2.13 + implementation("io.substrait:spark40_2.13:0.80.0") +} +``` + +## Development Workflow + +### Adding Support for a New Spark Version + +1. **Create a new subproject directory**: + ```bash + mkdir -p spark/spark-4.1_2.13 + ``` + +2. **Copy and modify a build.gradle.kts** from an existing variant: + ```bash + cp spark/spark-4.0_2.13/build.gradle.kts spark/spark-4.1_2.13/ + ``` + +3. **Update the variant configuration** in the new `build.gradle.kts`: + ```kotlin + val sparkVersion = "4.1.0" + val scalaVersion = "2.13.18" + val sparkMajorMinor = "4.1" + val scalaBinary = "2.13" + val classifier = "spark41_2.13" + ``` + +4. **Add the subproject** to `settings.gradle.kts`: + ```kotlin + include( + // ... existing projects + "spark:spark-4.1_2.13", + ) + ``` + +5. **Update the orchestrator** in `spark/build.gradle.kts`: + ```kotlin + tasks.register("buildAllVariants") { + dependsOn( + // ... existing variants + ":spark:spark-4.1_2.13:build" + ) + } + ``` + +6. **Create version-specific source directory**: + ```bash + mkdir -p spark/src/main/spark-4.1 + mkdir -p spark/src/test/spark-4.1 + ``` + +7. **Add version-specific implementations** for classes with API differences + +8.
**Test the new variant**: + ```bash + ./gradlew :spark:spark-4.1_2.13:build + ``` + +### Testing Changes Across All Variants + +When making changes to common code, test all variants: + +```bash +# Quick compilation test for all variants +./gradlew :spark:spark-3.4_2.12:compileScala +./gradlew :spark:spark-3.5_2.12:compileScala +./gradlew :spark:spark-4.0_2.13:compileScala + +# Or run full build for all variants +./gradlew :spark:buildAllVariants +``` + +### Cleaning Build Artifacts + +```bash +# Clean a specific variant +./gradlew :spark:spark-4.0_2.13:clean + +# Clean all variants +./gradlew :spark:clean +``` diff --git a/spark/build.gradle.kts b/spark/build.gradle.kts index 5ae1e3f05..53ea6a14d 100644 --- a/spark/build.gradle.kts +++ b/spark/build.gradle.kts @@ -1,149 +1,53 @@ -plugins { - `maven-publish` - signing - id("java-library") - id("scala") - id("idea") - alias(libs.plugins.spotless) - alias(libs.plugins.nmcp) - id("substrait.java-conventions") -} - -val stagingRepositoryUrl = uri(layout.buildDirectory.dir("staging-deploy")) - -publishing { - publications { - create("maven-publish") { - from(components["java"]) - - pom { - name.set("Substrait Java") - description.set( - "Create a well-defined, cross-language specification for data compute operations" - ) - url.set("https://github.com/substrait-io/substrait-java") - licenses { - license { - name.set("The Apache License, Version 2.0") - url.set("http://www.apache.org/licenses/LICENSE-2.0.txt") - } - } - developers { - developer { - id = "vbarua" - name = "Victor Barua" - } - } - scm { - connection.set("scm:git:git://github.com:substrait-io/substrait-java.git") - developerConnection.set("scm:git:ssh://github.com:substrait-io/substrait-java") - url.set("https://github.com/substrait-io/substrait-java/") - } - } - } - } - repositories { - maven { - name = "local" - val releasesRepoUrl = layout.buildDirectory.dir("repos/releases") - val snapshotsRepoUrl = layout.buildDirectory.dir("repos/snapshots") - url = uri(if (version.toString().endsWith("SNAPSHOT")) snapshotsRepoUrl else releasesRepoUrl) - } - } -} - -signing { - setRequired({ - gradle.taskGraph.hasTask(":${project.name}:publishMaven-publishPublicationToNmcpRepository") - }) - val signingKeyId = - System.getenv("SIGNING_KEY_ID").takeUnless { it.isNullOrEmpty() } - ?: extra["SIGNING_KEY_ID"].toString() - val signingPassword = - System.getenv("SIGNING_PASSWORD").takeUnless { it.isNullOrEmpty() } - ?: extra["SIGNING_PASSWORD"].toString() - val signingKey = - System.getenv("SIGNING_KEY").takeUnless { it.isNullOrEmpty() } - ?: extra["SIGNING_KEY"].toString() - useInMemoryPgpKeys(signingKeyId, signingKey, signingPassword) - sign(publishing.publications["maven-publish"]) -} - -configurations.all { - if (name.startsWith("incrementalScalaAnalysis")) { - setExtendsFrom(emptyList()) - } -} - -java { - toolchain { languageVersion = JavaLanguageVersion.of(17) } - withJavadocJar() - withSourcesJar() -} - -tasks.withType() { - scalaCompileOptions.additionalParameters = listOf("-release:17", "-Xfatal-warnings") -} +// Orchestrator project for building all Spark variants +// The actual variant builds are in subprojects: spark-3.4_2.12, spark-3.5_2.12, spark-4.0_2.13 -var SPARKBUNDLE_VERSION = properties.get("sparkbundle.version") - -sourceSets { - main { scala { setSrcDirs(listOf("src/main/scala", "src/main/spark-3.4")) } } - test { scala { setSrcDirs(listOf("src/test/scala", "src/test/spark-3.2", "src/main/scala")) } } -} - -dependencies { - api(project(":core")) - 
implementation(libs.scala.library) - api(libs.spark.core) - api(libs.spark.sql) - implementation(libs.spark.hive) - implementation(libs.spark.catalyst) - implementation(libs.slf4j.api) - implementation(platform(libs.jackson.bom)) - implementation(libs.bundles.jackson) - implementation(libs.json.schema.validator) - - testImplementation(libs.scalatest) - testImplementation(platform(libs.junit.bom)) - testRuntimeOnly(libs.junit.platform.engine) - testRuntimeOnly(libs.junit.platform.launcher) - testRuntimeOnly(libs.scalatestplus.junit5) - - testImplementation(variantOf(libs.spark.core) { classifier("tests") }) - testImplementation(variantOf(libs.spark.sql) { classifier("tests") }) - testImplementation(variantOf(libs.spark.catalyst) { classifier("tests") }) -} - -spotless { - scala { - scalafmt().configFile(".scalafmt.conf") - toggleOffOn() - } -} - -tasks.register("dialect") { - dependsOn(":core:shadowJar") - classpath = java.sourceSets["main"].runtimeClasspath - mainClass = "io.substrait.spark.utils.DialectGenerator" - args = listOf("spark_dialect.yaml") -} - -tasks { - jar { - manifest { - from("../core/build/generated/sources/manifest/META-INF/MANIFEST.MF") - attributes("Implementation-Title" to "substrait-spark") - } - } - - test { - dependsOn(":core:shadowJar") - useJUnitPlatform { includeEngines("scalatest") } - jvmArgs( - "--add-exports=java.base/sun.nio.ch=ALL-UNNAMED", - "--add-opens=java.base/java.net=ALL-UNNAMED", - ) - environment("SPARK_LOCAL_IP", "127.0.0.1") - } +plugins { + id("base") // Provides lifecycle tasks like clean, build, assemble +} + +// Aggregate task to build all variants +tasks.register("buildAllVariants") { + group = "build" + description = "Builds all Spark/Scala variants" + dependsOn( + ":spark:spark-3.4_2.12:build", + ":spark:spark-3.5_2.12:build", + ":spark:spark-4.0_2.13:build" + ) +} + +// Aggregate task to publish all variants to Maven Local +tasks.register("publishAllVariants") { + group = "publishing" + description = "Publishes all Spark/Scala variants to Maven Local" + dependsOn( + ":spark:spark-3.4_2.12:publishToMavenLocal", + ":spark:spark-3.5_2.12:publishToMavenLocal", + ":spark:spark-4.0_2.13:publishToMavenLocal" + ) +} + +// Aggregate task to publish all variants to Central Portal +tasks.register("publishAllVariantsToCentralPortal") { + group = "publishing" + description = "Publishes all Spark/Scala variants to Maven Central Portal" + dependsOn( + ":spark:spark-3.4_2.12:publishMaven-publishPublicationToNmcpRepository", + ":spark:spark-3.5_2.12:publishMaven-publishPublicationToNmcpRepository", + ":spark:spark-4.0_2.13:publishMaven-publishPublicationToNmcpRepository" + ) +} + +// Make the default build task build all variants +tasks.named("build") { + dependsOn("buildAllVariants") +} + +// Make clean task clean all variants +tasks.named("clean") { + dependsOn( + ":spark:spark-3.4_2.12:clean", + ":spark:spark-3.5_2.12:clean", + ":spark:spark-4.0_2.13:clean" + ) } diff --git a/spark/spark-3.4_2.12/build.gradle.kts b/spark/spark-3.4_2.12/build.gradle.kts new file mode 100644 index 000000000..563448457 --- /dev/null +++ b/spark/spark-3.4_2.12/build.gradle.kts @@ -0,0 +1,188 @@ +plugins { + `maven-publish` + signing + id("java-library") + id("scala") + id("idea") + alias(libs.plugins.spotless) + alias(libs.plugins.nmcp) + id("substrait.java-conventions") +} + +// Spark 3.4 with Scala 2.12 variant configuration +val sparkVersion = "3.4.4" +val scalaVersion = "2.12.20" +val sparkMajorMinor = "3.4" +val scalaBinary = "2.12" +val classifier = "spark34_2.12" + 
+publishing { + publications { + create("maven-publish") { + from(components["java"]) + + artifactId = classifier + + pom { + name.set("Substrait Spark $classifier") + description.set( + "Substrait integration for Apache Spark $sparkMajorMinor with Scala $scalaBinary" + ) + url.set("https://github.com/substrait-io/substrait-java") + + properties.put("spark.version", sparkVersion) + properties.put("scala.version", scalaVersion) + properties.put("scala.binary.version", scalaBinary) + + licenses { + license { + name.set("The Apache License, Version 2.0") + url.set("http://www.apache.org/licenses/LICENSE-2.0.txt") + } + } + scm { + connection.set("scm:git:git://github.com:substrait-io/substrait-java.git") + developerConnection.set("scm:git:ssh://github.com:substrait-io/substrait-java") + url.set("https://github.com/substrait-io/substrait-java/") + } + } + } + } + repositories { + maven { + name = "local" + val releasesRepoUrl = layout.buildDirectory.dir("repos/releases") + val snapshotsRepoUrl = layout.buildDirectory.dir("repos/snapshots") + url = uri(if (version.toString().endsWith("SNAPSHOT")) snapshotsRepoUrl else releasesRepoUrl) + } + } +} + +signing { + setRequired({ + gradle.taskGraph.hasTask(":${project.name}:publishMaven-publishPublicationToNmcpRepository") + }) + val signingKeyId = + System.getenv("SIGNING_KEY_ID").takeUnless { it.isNullOrEmpty() } + ?: extra["SIGNING_KEY_ID"].toString() + val signingPassword = + System.getenv("SIGNING_PASSWORD").takeUnless { it.isNullOrEmpty() } + ?: extra["SIGNING_PASSWORD"].toString() + val signingKey = + System.getenv("SIGNING_KEY").takeUnless { it.isNullOrEmpty() } + ?: extra["SIGNING_KEY"].toString() + useInMemoryPgpKeys(signingKeyId, signingKey, signingPassword) + sign(publishing.publications["maven-publish"]) +} + +configurations.all { + if (name.startsWith("incrementalScalaAnalysis")) { + setExtendsFrom(emptyList()) + } +} + +java { + toolchain { languageVersion = JavaLanguageVersion.of(17) } + withJavadocJar() + withSourcesJar() +} + +tasks.withType() { + scalaCompileOptions.additionalParameters = listOf("-release:17", "-Xfatal-warnings") +} + +// Configure source sets to reference parent's src directory +sourceSets { + main { + scala { setSrcDirs(listOf("../src/main/scala", "../src/main/spark-$sparkMajorMinor")) } + resources { setSrcDirs(listOf("../src/main/resources")) } + } + test { + scala { + setSrcDirs( + listOf("../src/test/scala", "../src/test/spark-$sparkMajorMinor", "../src/main/scala") + ) + } + resources { setSrcDirs(listOf("../src/test/resources")) } + } +} + +dependencies { + api(project(":core")) + + // Scala dependencies + implementation("org.scala-lang:scala-library:$scalaVersion") + testImplementation("org.scalatest:scalatest_$scalaBinary:3.2.19") + testRuntimeOnly("org.scalatestplus:junit-5-12_$scalaBinary:3.2.19.0") + + // Spark dependencies + api("org.apache.spark:spark-core_$scalaBinary:$sparkVersion") + api("org.apache.spark:spark-sql_$scalaBinary:$sparkVersion") + implementation("org.apache.spark:spark-hive_$scalaBinary:$sparkVersion") + implementation("org.apache.spark:spark-catalyst_$scalaBinary:$sparkVersion") + testImplementation("org.apache.spark:spark-core_$scalaBinary:$sparkVersion:tests") + testImplementation("org.apache.spark:spark-sql_$scalaBinary:$sparkVersion:tests") + testImplementation("org.apache.spark:spark-catalyst_$scalaBinary:$sparkVersion:tests") + + // Common dependencies + implementation(libs.slf4j.api) + implementation(platform(libs.jackson.bom)) + implementation(libs.bundles.jackson) + 
implementation(libs.json.schema.validator) + + testImplementation(platform(libs.junit.bom)) + testRuntimeOnly(libs.junit.platform.engine) + testRuntimeOnly(libs.junit.platform.launcher) +} + +// Spotless is disabled for subprojects since source files are in parent directory +// The parent spark project handles formatting +spotless { isEnforceCheck = false } + +tasks { + // Ensure shadowJar runs before compilation, but skip core tests + named("compileJava") { + dependsOn(":core:shadowJar") + mustRunAfter(":core:compileJava") + } + named("compileScala") { + dependsOn(":core:shadowJar") + mustRunAfter(":core:compileJava") + } + + // Explicitly skip core test compilation for this build + gradle.taskGraph.whenReady { + allTasks + .filter { it.path.startsWith(":core:") && it.name.contains("Test") } + .forEach { it.enabled = false } + } + + jar { + manifest { + from("../../core/build/generated/sources/manifest/META-INF/MANIFEST.MF") + attributes("Implementation-Title" to "substrait-spark-$classifier") + } + } + + test { + dependsOn(":core:shadowJar") + useJUnitPlatform { includeEngines("scalatest") } + + // Set system properties for variant identification + systemProperty("spark.version", sparkVersion) + systemProperty("scala.version", scalaVersion) + systemProperty("variant.classifier", classifier) + + jvmArgs( + "--add-exports=java.base/sun.nio.ch=ALL-UNNAMED", + "--add-opens=java.base/java.net=ALL-UNNAMED", + ) + environment("SPARK_LOCAL_IP", "127.0.0.1") + + // Separate test reports per variant + reports { + html.outputLocation.set(layout.buildDirectory.dir("reports/tests/$classifier")) + junitXml.outputLocation.set(layout.buildDirectory.dir("test-results/$classifier")) + } + } +} diff --git a/spark/spark-3.5_2.12/build.gradle.kts b/spark/spark-3.5_2.12/build.gradle.kts new file mode 100644 index 000000000..c42de9af5 --- /dev/null +++ b/spark/spark-3.5_2.12/build.gradle.kts @@ -0,0 +1,176 @@ +plugins { + `maven-publish` + signing + id("java-library") + id("scala") + id("idea") + alias(libs.plugins.spotless) + alias(libs.plugins.nmcp) + id("substrait.java-conventions") +} + +// Spark 3.5 with Scala 2.12 variant configuration +val sparkVersion = "3.5.4" +val scalaVersion = "2.12.20" +val sparkMajorMinor = "3.5" +val scalaBinary = "2.12" +val classifier = "spark35_2.12" + +publishing { + publications { + create("maven-publish") { + from(components["java"]) + + artifactId = classifier + + pom { + name.set("Substrait Spark $classifier") + description.set( + "Substrait integration for Apache Spark $sparkMajorMinor with Scala $scalaBinary" + ) + url.set("https://github.com/substrait-io/substrait-java") + + properties.put("spark.version", sparkVersion) + properties.put("scala.version", scalaVersion) + properties.put("scala.binary.version", scalaBinary) + + licenses { + license { + name.set("The Apache License, Version 2.0") + url.set("http://www.apache.org/licenses/LICENSE-2.0.txt") + } + } + scm { + connection.set("scm:git:git://github.com:substrait-io/substrait-java.git") + developerConnection.set("scm:git:ssh://github.com:substrait-io/substrait-java") + url.set("https://github.com/substrait-io/substrait-java/") + } + } + } + } + repositories { + maven { + name = "local" + val releasesRepoUrl = layout.buildDirectory.dir("repos/releases") + val snapshotsRepoUrl = layout.buildDirectory.dir("repos/snapshots") + url = uri(if (version.toString().endsWith("SNAPSHOT")) snapshotsRepoUrl else releasesRepoUrl) + } + } +} + +signing { + setRequired({ + 
gradle.taskGraph.hasTask(":${project.name}:publishMaven-publishPublicationToNmcpRepository") + }) + val signingKeyId = + System.getenv("SIGNING_KEY_ID").takeUnless { it.isNullOrEmpty() } + ?: extra["SIGNING_KEY_ID"].toString() + val signingPassword = + System.getenv("SIGNING_PASSWORD").takeUnless { it.isNullOrEmpty() } + ?: extra["SIGNING_PASSWORD"].toString() + val signingKey = + System.getenv("SIGNING_KEY").takeUnless { it.isNullOrEmpty() } + ?: extra["SIGNING_KEY"].toString() + useInMemoryPgpKeys(signingKeyId, signingKey, signingPassword) + sign(publishing.publications["maven-publish"]) +} + +configurations.all { + if (name.startsWith("incrementalScalaAnalysis")) { + setExtendsFrom(emptyList()) + } +} + +java { + toolchain { languageVersion = JavaLanguageVersion.of(17) } + withJavadocJar() + withSourcesJar() +} + +tasks.withType() { + scalaCompileOptions.additionalParameters = listOf("-release:17", "-Xfatal-warnings") +} + +// Configure source sets to reference parent's src directory +sourceSets { + main { + scala { setSrcDirs(listOf("../src/main/scala", "../src/main/spark-$sparkMajorMinor")) } + resources { setSrcDirs(listOf("../src/main/resources")) } + } + test { + scala { + setSrcDirs( + listOf("../src/test/scala", "../src/test/spark-$sparkMajorMinor", "../src/main/scala") + ) + } + resources { setSrcDirs(listOf("../src/test/resources")) } + } +} + +dependencies { + // Use runtimeElements configuration to avoid test classpath resolution + api(project(":core", configuration = "runtimeElements")) + + // Scala dependencies + implementation("org.scala-lang:scala-library:$scalaVersion") + testImplementation("org.scalatest:scalatest_$scalaBinary:3.2.19") + testRuntimeOnly("org.scalatestplus:junit-5-12_$scalaBinary:3.2.19.0") + + // Spark dependencies + api("org.apache.spark:spark-core_$scalaBinary:$sparkVersion") + api("org.apache.spark:spark-sql_$scalaBinary:$sparkVersion") + implementation("org.apache.spark:spark-hive_$scalaBinary:$sparkVersion") + implementation("org.apache.spark:spark-catalyst_$scalaBinary:$sparkVersion") + testImplementation("org.apache.spark:spark-core_$scalaBinary:$sparkVersion:tests") + testImplementation("org.apache.spark:spark-sql_$scalaBinary:$sparkVersion:tests") + testImplementation("org.apache.spark:spark-catalyst_$scalaBinary:$sparkVersion:tests") + + // Common dependencies + implementation(libs.slf4j.api) + implementation(platform(libs.jackson.bom)) + implementation(libs.bundles.jackson) + implementation(libs.json.schema.validator) + + testImplementation(platform(libs.junit.bom)) + testRuntimeOnly(libs.junit.platform.engine) + testRuntimeOnly(libs.junit.platform.launcher) +} + +// Spotless is disabled for subprojects since source files are in parent directory +// The parent spark project handles formatting +spotless { isEnforceCheck = false } + +tasks { + // Ensure shadowJar runs before compilation + named("compileJava") { dependsOn(":core:shadowJar") } + named("compileScala") { dependsOn(":core:shadowJar") } + + jar { + manifest { + from("../../core/build/generated/sources/manifest/META-INF/MANIFEST.MF") + attributes("Implementation-Title" to "substrait-spark-$classifier") + } + } + + test { + dependsOn(":core:shadowJar") + useJUnitPlatform { includeEngines("scalatest") } + + // Set system properties for variant identification + systemProperty("spark.version", sparkVersion) + systemProperty("scala.version", scalaVersion) + systemProperty("variant.classifier", classifier) + + jvmArgs( + "--add-exports=java.base/sun.nio.ch=ALL-UNNAMED", + 
"--add-opens=java.base/java.net=ALL-UNNAMED", + ) + environment("SPARK_LOCAL_IP", "127.0.0.1") + + // Separate test reports per variant + reports { + html.outputLocation.set(layout.buildDirectory.dir("reports/tests/$classifier")) + junitXml.outputLocation.set(layout.buildDirectory.dir("test-results/$classifier")) + } + } +} diff --git a/spark/spark-4.0_2.13/build.gradle.kts b/spark/spark-4.0_2.13/build.gradle.kts new file mode 100644 index 000000000..2057492e7 --- /dev/null +++ b/spark/spark-4.0_2.13/build.gradle.kts @@ -0,0 +1,183 @@ +plugins { + `maven-publish` + signing + id("java-library") + id("scala") + id("idea") + alias(libs.plugins.spotless) + alias(libs.plugins.nmcp) + id("substrait.java-conventions") +} + +// Spark 4.0 with Scala 2.13 variant configuration +val sparkVersion = "4.0.2" +val scalaVersion = "2.13.18" +val sparkMajorMinor = "4.0" +val scalaBinary = "2.13" +val classifier = "spark40_2.13" + +publishing { + publications { + create("maven-publish") { + from(components["java"]) + + artifactId = classifier + + pom { + name.set("Substrait Spark $classifier") + description.set( + "Substrait integration for Apache Spark $sparkMajorMinor with Scala $scalaBinary" + ) + url.set("https://github.com/substrait-io/substrait-java") + + properties.put("spark.version", sparkVersion) + properties.put("scala.version", scalaVersion) + properties.put("scala.binary.version", scalaBinary) + + licenses { + license { + name.set("The Apache License, Version 2.0") + url.set("http://www.apache.org/licenses/LICENSE-2.0.txt") + } + } + scm { + connection.set("scm:git:git://github.com:substrait-io/substrait-java.git") + developerConnection.set("scm:git:ssh://github.com:substrait-io/substrait-java") + url.set("https://github.com/substrait-io/substrait-java/") + } + } + } + } + repositories { + maven { + name = "local" + val releasesRepoUrl = layout.buildDirectory.dir("repos/releases") + val snapshotsRepoUrl = layout.buildDirectory.dir("repos/snapshots") + url = uri(if (version.toString().endsWith("SNAPSHOT")) snapshotsRepoUrl else releasesRepoUrl) + } + } +} + +signing { + setRequired({ + gradle.taskGraph.hasTask(":${project.name}:publishMaven-publishPublicationToNmcpRepository") + }) + val signingKeyId = + System.getenv("SIGNING_KEY_ID").takeUnless { it.isNullOrEmpty() } + ?: extra["SIGNING_KEY_ID"].toString() + val signingPassword = + System.getenv("SIGNING_PASSWORD").takeUnless { it.isNullOrEmpty() } + ?: extra["SIGNING_PASSWORD"].toString() + val signingKey = + System.getenv("SIGNING_KEY").takeUnless { it.isNullOrEmpty() } + ?: extra["SIGNING_KEY"].toString() + useInMemoryPgpKeys(signingKeyId, signingKey, signingPassword) + sign(publishing.publications["maven-publish"]) +} + +configurations.all { + if (name.startsWith("incrementalScalaAnalysis")) { + setExtendsFrom(emptyList()) + } +} + +java { + toolchain { languageVersion = JavaLanguageVersion.of(17) } + withJavadocJar() + withSourcesJar() +} + +tasks.withType() { + scalaCompileOptions.additionalParameters = listOf("-release:17", "-Xfatal-warnings") +} + +// Configure source sets to reference parent's src directory +sourceSets { + main { + scala { setSrcDirs(listOf("../src/main/scala", "../src/main/spark-$sparkMajorMinor")) } + resources { setSrcDirs(listOf("../src/main/resources")) } + } + test { + scala { + setSrcDirs( + listOf("../src/test/scala", "../src/test/spark-$sparkMajorMinor", "../src/main/scala") + ) + } + resources { setSrcDirs(listOf("../src/test/resources")) } + } +} + +dependencies { + // Use runtimeElements configuration to 
avoid test classpath resolution + api(project(":core", configuration = "runtimeElements")) + + // Scala dependencies + implementation("org.scala-lang:scala-library:$scalaVersion") + testImplementation("org.scalatest:scalatest_$scalaBinary:3.2.19") + testRuntimeOnly("org.scalatestplus:junit-5-13_$scalaBinary:3.2.19.0") + + // Spark dependencies + api("org.apache.spark:spark-core_$scalaBinary:$sparkVersion") + api("org.apache.spark:spark-sql_$scalaBinary:$sparkVersion") + implementation("org.apache.spark:spark-hive_$scalaBinary:$sparkVersion") + implementation("org.apache.spark:spark-catalyst_$scalaBinary:$sparkVersion") + testImplementation("org.apache.spark:spark-core_$scalaBinary:$sparkVersion:tests") + testImplementation("org.apache.spark:spark-sql_$scalaBinary:$sparkVersion:tests") + testImplementation("org.apache.spark:spark-catalyst_$scalaBinary:$sparkVersion:tests") + + // Common dependencies + implementation(libs.slf4j.api) + implementation(platform(libs.jackson.bom)) + implementation(libs.bundles.jackson) + implementation(libs.json.schema.validator) + + testImplementation(platform(libs.junit.bom)) + testRuntimeOnly(libs.junit.platform.engine) + testRuntimeOnly(libs.junit.platform.launcher) +} + +// Spotless is disabled for subprojects since source files are in parent directory +// The parent spark project handles formatting +spotless { isEnforceCheck = false } + +tasks.register("dialect") { + dependsOn(":core:shadowJar") + classpath = java.sourceSets["main"].runtimeClasspath + mainClass = "io.substrait.spark.utils.DialectGenerator" + args = listOf("../spark_dialect.yaml") +} + +tasks { + // Ensure shadowJar runs before compilation + named("compileJava") { dependsOn(":core:shadowJar") } + named("compileScala") { dependsOn(":core:shadowJar") } + + jar { + manifest { + from("../../core/build/generated/sources/manifest/META-INF/MANIFEST.MF") + attributes("Implementation-Title" to "substrait-spark-$classifier") + } + } + + test { + dependsOn(":core:shadowJar") + useJUnitPlatform { includeEngines("scalatest") } + + // Set system properties for variant identification + systemProperty("spark.version", sparkVersion) + systemProperty("scala.version", scalaVersion) + systemProperty("variant.classifier", classifier) + + jvmArgs( + "--add-exports=java.base/sun.nio.ch=ALL-UNNAMED", + "--add-opens=java.base/java.net=ALL-UNNAMED", + ) + environment("SPARK_LOCAL_IP", "127.0.0.1") + + // Separate test reports per variant + reports { + html.outputLocation.set(layout.buildDirectory.dir("reports/tests/$classifier")) + junitXml.outputLocation.set(layout.buildDirectory.dir("test-results/$classifier")) + } + } +} diff --git a/spark/spark_dialect.yaml b/spark/spark_dialect.yaml index 45f4c3c5d..79678251a 100644 --- a/spark/spark_dialect.yaml +++ b/spark/spark_dialect.yaml @@ -127,169 +127,168 @@ dependencies: comparison: "extension:io.substrait:functions_comparison" logarithmic: "extension:io.substrait:functions_logarithmic" datetime: "extension:io.substrait:functions_datetime" - rounding_decimal: "extension:io.substrait:functions_rounding_decimal" - string: "extension:io.substrait:functions_string" arithmetic: "extension:io.substrait:functions_arithmetic" aggregate_generic: "extension:io.substrait:functions_aggregate_generic" boolean: "extension:io.substrait:functions_boolean" + aggregate_approx: "extension:io.substrait:functions_aggregate_approx" + rounding_decimal: "extension:io.substrait:functions_rounding_decimal" + string: "extension:io.substrait:functions_string" spark: "extension:substrait:spark" 
arithmetic_decimal: "extension:io.substrait:functions_arithmetic_decimal" - aggregate_approx: "extension:io.substrait:functions_aggregate_approx" supported_scalar_functions: -- source: "arithmetic_decimal" - name: "add" +- source: "arithmetic" + name: "abs" system_metadata: - name: "+" - notation: "INFIX" + name: "abs" + notation: "FUNCTION" supported_impls: - - "dec_dec" + - "fp32" + - "fp64" + - "i16" + - "i32" + - "i64" + - "i8" - source: "arithmetic" - name: "add" + name: "acos" system_metadata: - name: "+" - notation: "INFIX" + name: "ACOS" + notation: "FUNCTION" supported_impls: - - "i8_i8" - - "i16_i16" - - "i32_i32" - - "i64_i64" - - "fp32_fp32" - - "fp64_fp64" -- source: "arithmetic_decimal" - name: "subtract" + - "fp64" +- source: "arithmetic" + name: "acosh" system_metadata: - name: "-" - notation: "INFIX" + name: "ACOSH" + notation: "FUNCTION" supported_impls: - - "dec_dec" + - "fp64" - source: "arithmetic" - name: "subtract" + name: "add" system_metadata: - name: "-" + name: "+" notation: "INFIX" supported_impls: - - "i8_i8" + - "fp32_fp32" + - "fp64_fp64" - "i16_i16" - "i32_i32" - "i64_i64" - - "fp32_fp32" - - "fp64_fp64" -- source: "arithmetic_decimal" - name: "multiply" - system_metadata: - name: "*" - notation: "INFIX" - supported_impls: - - "dec_dec" + - "i8_i8" - source: "arithmetic" - name: "multiply" + name: "asin" system_metadata: - name: "*" - notation: "INFIX" + name: "ASIN" + notation: "FUNCTION" supported_impls: - - "i8_i8" - - "i16_i16" - - "i32_i32" - - "i64_i64" - - "fp32_fp32" - - "fp64_fp64" -- source: "arithmetic_decimal" - name: "divide" + - "fp64" +- source: "arithmetic" + name: "asinh" system_metadata: - name: "/" - notation: "INFIX" + name: "ASINH" + notation: "FUNCTION" supported_impls: - - "dec_dec" + - "fp64" - source: "arithmetic" - name: "divide" + name: "atan" system_metadata: - name: "/" - notation: "INFIX" + name: "ATAN" + notation: "FUNCTION" supported_impls: - - "fp64_fp64" -- source: "arithmetic_decimal" - name: "abs" + - "fp64" +- source: "arithmetic" + name: "atan2" system_metadata: - name: "abs" + name: "ATAN2" notation: "FUNCTION" supported_impls: - - "dec" + - "fp64_fp64" - source: "arithmetic" - name: "abs" + name: "atanh" system_metadata: - name: "abs" + name: "ATANH" notation: "FUNCTION" supported_impls: - - "i8" - - "i16" - - "i32" - - "i64" - - "fp32" - "fp64" -- source: "arithmetic_decimal" - name: "modulus" +- source: "arithmetic" + name: "bitwise_and" system_metadata: - name: "%" + name: "&" notation: "INFIX" supported_impls: - - "dec_dec" + - "i16_i16" + - "i32_i32" + - "i64_i64" + - "i8_i8" - source: "arithmetic" - name: "modulus" + name: "bitwise_or" system_metadata: - name: "%" + name: "|" notation: "INFIX" supported_impls: - - "i8_i8" - "i16_i16" - "i32_i32" - "i64_i64" -- source: "rounding" - name: "round" + - "i8_i8" +- source: "arithmetic" + name: "bitwise_xor" system_metadata: - name: "round" - notation: "FUNCTION" + name: "^" + notation: "INFIX" supported_impls: - - "i8_i32" - - "i16_i32" + - "i16_i16" - "i32_i32" - - "i64_i32" - - "fp32_i32" - - "fp64_i32" -- source: "rounding_decimal" - name: "round" + - "i64_i64" + - "i8_i8" +- source: "arithmetic" + name: "cos" system_metadata: - name: "round" + name: "COS" notation: "FUNCTION" supported_impls: - - "dec_i32" -- source: "rounding" - name: "floor" + - "fp64" +- source: "arithmetic" + name: "cosh" system_metadata: - name: "FLOOR" + name: "COSH" notation: "FUNCTION" supported_impls: - "fp64" -- source: "rounding_decimal" - name: "floor" +- source: "arithmetic" + name: 
"divide" system_metadata: - name: "FLOOR" - notation: "FUNCTION" + name: "/" + notation: "INFIX" supported_impls: - - "dec" -- source: "rounding" - name: "ceil" + - "fp64_fp64" +- source: "arithmetic" + name: "exp" system_metadata: - name: "CEIL" + name: "EXP" notation: "FUNCTION" supported_impls: - "fp64" -- source: "rounding_decimal" - name: "ceil" +- source: "arithmetic" + name: "modulus" system_metadata: - name: "CEIL" - notation: "FUNCTION" + name: "%" + notation: "INFIX" supported_impls: - - "dec" + - "i16_i16" + - "i32_i32" + - "i64_i64" + - "i8_i8" +- source: "arithmetic" + name: "multiply" + system_metadata: + name: "*" + notation: "INFIX" + supported_impls: + - "fp32_fp32" + - "fp64_fp64" + - "i16_i16" + - "i32_i32" + - "i64_i64" + - "i8_i8" - source: "arithmetic" name: "power" system_metadata: @@ -298,19 +297,29 @@ supported_scalar_functions: supported_impls: - "fp64_fp64" - source: "arithmetic" - name: "exp" + name: "shift_left" system_metadata: - name: "EXP" + name: "shiftleft" notation: "FUNCTION" supported_impls: - - "fp64" + - "i32_i32" + - "i64_i32" - source: "arithmetic" - name: "sqrt" + name: "shift_right" system_metadata: - name: "SQRT" + name: "shiftright" notation: "FUNCTION" supported_impls: - - "fp64" + - "i32_i32" + - "i64_i32" +- source: "arithmetic" + name: "shift_right_unsigned" + system_metadata: + name: "shiftrightunsigned" + notation: "FUNCTION" + supported_impls: + - "i32_i32" + - "i64_i32" - source: "arithmetic" name: "sin" system_metadata: @@ -319,123 +328,136 @@ supported_scalar_functions: supported_impls: - "fp64" - source: "arithmetic" - name: "cos" + name: "sinh" system_metadata: - name: "COS" + name: "SINH" notation: "FUNCTION" supported_impls: - "fp64" - source: "arithmetic" - name: "tan" + name: "sqrt" system_metadata: - name: "TAN" + name: "SQRT" notation: "FUNCTION" supported_impls: - "fp64" - source: "arithmetic" - name: "asin" + name: "subtract" system_metadata: - name: "ASIN" - notation: "FUNCTION" + name: "-" + notation: "INFIX" supported_impls: - - "fp64" + - "fp32_fp32" + - "fp64_fp64" + - "i16_i16" + - "i32_i32" + - "i64_i64" + - "i8_i8" - source: "arithmetic" - name: "acos" + name: "tan" system_metadata: - name: "ACOS" + name: "TAN" notation: "FUNCTION" supported_impls: - "fp64" - source: "arithmetic" - name: "atan" + name: "tanh" system_metadata: - name: "ATAN" + name: "TANH" notation: "FUNCTION" supported_impls: - "fp64" -- source: "arithmetic" - name: "atan2" +- source: "arithmetic_decimal" + name: "abs" system_metadata: - name: "ATAN2" + name: "abs" notation: "FUNCTION" supported_impls: - - "fp64_fp64" -- source: "arithmetic" - name: "sinh" + - "dec" +- source: "arithmetic_decimal" + name: "add" system_metadata: - name: "SINH" - notation: "FUNCTION" + name: "+" + notation: "INFIX" supported_impls: - - "fp64" -- source: "arithmetic" - name: "cosh" + - "dec_dec" +- source: "arithmetic_decimal" + name: "divide" system_metadata: - name: "COSH" - notation: "FUNCTION" + name: "/" + notation: "INFIX" supported_impls: - - "fp64" -- source: "arithmetic" - name: "tanh" + - "dec_dec" +- source: "arithmetic_decimal" + name: "modulus" system_metadata: - name: "TANH" - notation: "FUNCTION" + name: "%" + notation: "INFIX" supported_impls: - - "fp64" -- source: "arithmetic" - name: "asinh" + - "dec_dec" +- source: "arithmetic_decimal" + name: "multiply" system_metadata: - name: "ASINH" - notation: "FUNCTION" + name: "*" + notation: "INFIX" supported_impls: - - "fp64" -- source: "arithmetic" - name: "acosh" + - "dec_dec" +- source: "arithmetic_decimal" 
+ name: "subtract" system_metadata: - name: "ACOSH" - notation: "FUNCTION" + name: "-" + notation: "INFIX" supported_impls: - - "fp64" -- source: "arithmetic" - name: "atanh" + - "dec_dec" +- source: "boolean" + name: "not" system_metadata: - name: "ATANH" + name: "not" notation: "FUNCTION" supported_impls: - - "fp64" -- source: "logarithmic" - name: "ln" + - "bool" +- source: "comparison" + name: "equal" system_metadata: - name: "ln" - notation: "FUNCTION" + name: "=" + notation: "INFIX" supported_impls: - - "fp64" -- source: "logarithmic" - name: "log10" + - "any_any" +- source: "comparison" + name: "gt" system_metadata: - name: "LOG10" - notation: "FUNCTION" + name: ">" + notation: "INFIX" supported_impls: - - "fp64" -- source: "boolean" - name: "not" + - "any_any" +- source: "comparison" + name: "gte" system_metadata: - name: "not" - notation: "FUNCTION" + name: ">=" + notation: "INFIX" supported_impls: - - "bool" -- source: "datetime" - name: "lt" + - "any_any" +- source: "comparison" + name: "is_not_distinct_from" system_metadata: - name: "<" + name: "<=>" notation: "INFIX" supported_impls: - - "ts_ts" - - "pts_pts" - - "tstz_tstz" - - "ptstz_ptstz" - - "date_date" - - "iday_iday" - - "iyear_iyear" + - "any_any" +- source: "comparison" + name: "is_not_null" + system_metadata: + name: "isnotnull" + notation: "FUNCTION" + supported_impls: + - "any" +- source: "comparison" + name: "is_null" + system_metadata: + name: "isnull" + notation: "FUNCTION" + supported_impls: + - "any" - source: "comparison" name: "lt" system_metadata: @@ -443,19 +465,6 @@ supported_scalar_functions: notation: "INFIX" supported_impls: - "any_any" -- source: "datetime" - name: "lte" - system_metadata: - name: "<=" - notation: "INFIX" - supported_impls: - - "ts_ts" - - "pts_pts" - - "tstz_tstz" - - "ptstz_ptstz" - - "date_date" - - "iday_iday" - - "iyear_iyear" - source: "comparison" name: "lte" system_metadata: @@ -469,229 +478,286 @@ supported_scalar_functions: name: ">" notation: "INFIX" supported_impls: - - "ts_ts" - - "pts_pts" - - "tstz_tstz" - - "ptstz_ptstz" - "date_date" - "iday_iday" - "iyear_iyear" -- source: "comparison" - name: "gt" - system_metadata: - name: ">" - notation: "INFIX" - supported_impls: - - "any_any" + - "pts_pts" + - "ptstz_ptstz" + - "ts_ts" + - "tstz_tstz" - source: "datetime" name: "gte" system_metadata: name: ">=" notation: "INFIX" supported_impls: - - "ts_ts" + - "date_date" + - "iday_iday" + - "iyear_iyear" - "pts_pts" - - "tstz_tstz" - "ptstz_ptstz" + - "ts_ts" + - "tstz_tstz" +- source: "datetime" + name: "lt" + system_metadata: + name: "<" + notation: "INFIX" + supported_impls: - "date_date" - "iday_iday" - "iyear_iyear" -- source: "comparison" - name: "gte" + - "pts_pts" + - "ptstz_ptstz" + - "ts_ts" + - "tstz_tstz" +- source: "datetime" + name: "lte" system_metadata: - name: ">=" + name: "<=" notation: "INFIX" supported_impls: - - "any_any" -- source: "comparison" - name: "equal" + - "date_date" + - "iday_iday" + - "iyear_iyear" + - "pts_pts" + - "ptstz_ptstz" + - "ts_ts" + - "tstz_tstz" +- source: "logarithmic" + name: "ln" system_metadata: - name: "=" - notation: "INFIX" + name: "ln" + notation: "FUNCTION" supported_impls: - - "any_any" -- source: "comparison" - name: "is_not_distinct_from" + - "fp64" +- source: "logarithmic" + name: "log10" system_metadata: - name: "<=>" - notation: "INFIX" + name: "LOG10" + notation: "FUNCTION" supported_impls: - - "any_any" -- source: "comparison" - name: "is_null" + - "fp64" +- source: "rounding" + name: "ceil" system_metadata: - name: 
"isnull" + name: "CEIL" notation: "FUNCTION" supported_impls: - - "any" -- source: "comparison" - name: "is_not_null" + - "fp64" +- source: "rounding" + name: "floor" system_metadata: - name: "isnotnull" + name: "FLOOR" notation: "FUNCTION" supported_impls: - - "any" + - "fp64" +- source: "rounding" + name: "round" + system_metadata: + name: "round" + notation: "FUNCTION" + supported_impls: + - "fp32_i32" + - "fp64_i32" + - "i16_i32" + - "i32_i32" + - "i64_i32" + - "i8_i32" +- source: "rounding_decimal" + name: "ceil" + system_metadata: + name: "CEIL" + notation: "FUNCTION" + supported_impls: + - "dec" +- source: "rounding_decimal" + name: "floor" + system_metadata: + name: "FLOOR" + notation: "FUNCTION" + supported_impls: + - "dec" +- source: "rounding_decimal" + name: "round" + system_metadata: + name: "round" + notation: "FUNCTION" + supported_impls: + - "dec_i32" +- source: "spark" + name: "add" + system_metadata: + name: "date_add" + notation: "FUNCTION" + supported_impls: + - "date_i32" - source: "string" - name: "ends_with" + name: "contains" system_metadata: - name: "endswith" + name: "contains" notation: "FUNCTION" supported_impls: - - "vchar_vchar" - - "vchar_str" - - "vchar_fchar" + - "fchar_fchar" + - "fchar_str" + - "fchar_vchar" + - "str_fchar" - "str_str" - "str_vchar" - - "str_fchar" + - "vchar_fchar" + - "vchar_str" + - "vchar_vchar" +- source: "string" + name: "ends_with" + system_metadata: + name: "endswith" + notation: "FUNCTION" + supported_impls: - "fchar_fchar" - "fchar_str" - "fchar_vchar" + - "str_fchar" + - "str_str" + - "str_vchar" + - "vchar_fchar" + - "vchar_str" + - "vchar_vchar" - source: "string" name: "like" system_metadata: name: "like" notation: "FUNCTION" supported_impls: - - "vchar_vchar" - "str_str" + - "vchar_vchar" - source: "string" - name: "contains" + name: "lower" system_metadata: - name: "contains" + name: "lower" notation: "FUNCTION" supported_impls: - - "vchar_vchar" - - "vchar_str" - - "vchar_fchar" - - "str_str" - - "str_vchar" - - "str_fchar" - - "fchar_fchar" - - "fchar_str" - - "fchar_vchar" + - "fchar" + - "str" + - "vchar" +- source: "string" + name: "lpad" + system_metadata: + name: "lpad" + notation: "FUNCTION" + supported_impls: + - "str_i32_str" + - "vchar_i32_vchar" +- source: "string" + name: "rpad" + system_metadata: + name: "rpad" + notation: "FUNCTION" + supported_impls: + - "str_i32_str" + - "vchar_i32_vchar" - source: "string" name: "starts_with" system_metadata: name: "startswith" notation: "FUNCTION" supported_impls: - - "vchar_vchar" - - "vchar_str" - - "vchar_fchar" - - "str_str" - - "str_vchar" - - "str_fchar" - "fchar_fchar" - "fchar_str" - "fchar_vchar" + - "str_fchar" + - "str_str" + - "str_vchar" + - "vchar_fchar" + - "vchar_str" + - "vchar_vchar" - source: "string" name: "substring" system_metadata: name: "substring" notation: "FUNCTION" supported_impls: - - "vchar_i32_i32" - - "str_i32_i32" - "fchar_i32_i32" + - "str_i32_i32" + - "vchar_i32_i32" - source: "string" name: "upper" system_metadata: name: "upper" notation: "FUNCTION" supported_impls: - - "str" - - "vchar" - "fchar" -- source: "string" - name: "lower" - system_metadata: - name: "lower" - notation: "FUNCTION" - supported_impls: - "str" - "vchar" - - "fchar" -- source: "arithmetic" - name: "shift_left" +supported_aggregate_functions: +- source: "aggregate_approx" + name: "approx_count_distinct" system_metadata: - name: "shiftleft" + name: "approx_count_distinct" notation: "FUNCTION" supported_impls: - - "i32_i32" - - "i64_i32" -- source: "arithmetic" - 
name: "shift_right" + - "any" +- source: "aggregate_generic" + name: "any_value" system_metadata: - name: "shiftright" + name: "first" notation: "FUNCTION" supported_impls: - - "i32_i32" - - "i64_i32" + - "any" - source: "arithmetic" - name: "shift_right_unsigned" + name: "avg" system_metadata: - name: "shiftrightunsigned" + name: "avg" notation: "FUNCTION" supported_impls: - - "i32_i32" - - "i64_i32" -- source: "arithmetic" - name: "bitwise_and" - system_metadata: - name: "&" - notation: "INFIX" - supported_impls: - - "i8_i8" - - "i16_i16" - - "i32_i32" - - "i64_i64" + - "fp32" + - "fp64" + - "i16" + - "i32" + - "i64" + - "i8" - source: "arithmetic" - name: "bitwise_or" + name: "max" system_metadata: - name: "|" - notation: "INFIX" + name: "max" + notation: "FUNCTION" supported_impls: - - "i8_i8" - - "i16_i16" - - "i32_i32" - - "i64_i64" + - "fp32" + - "fp64" + - "i16" + - "i32" + - "i64" + - "i8" - source: "arithmetic" - name: "bitwise_xor" - system_metadata: - name: "^" - notation: "INFIX" - supported_impls: - - "i8_i8" - - "i16_i16" - - "i32_i32" - - "i64_i64" -- source: "spark" - name: "add" + name: "min" system_metadata: - name: "date_add" + name: "min" notation: "FUNCTION" supported_impls: - - "date_i32" -supported_aggregate_functions: -- source: "arithmetic_decimal" - name: "sum" + - "fp32" + - "fp64" + - "i16" + - "i32" + - "i64" + - "i8" +- source: "arithmetic" + name: "std_dev" system_metadata: - name: "sum" + name: "stddev_samp" notation: "FUNCTION" supported_impls: - - "dec" + - "fp64" - source: "arithmetic" name: "sum" system_metadata: name: "sum" notation: "FUNCTION" supported_impls: - - "i8" + - "fp32" + - "fp64" - "i16" - "i32" - "i64" - - "fp32" - - "fp64" + - "i8" - source: "arithmetic_decimal" name: "avg" system_metadata: @@ -699,18 +765,13 @@ supported_aggregate_functions: notation: "FUNCTION" supported_impls: - "dec" -- source: "arithmetic" - name: "avg" +- source: "arithmetic_decimal" + name: "max" system_metadata: - name: "avg" + name: "max" notation: "FUNCTION" supported_impls: - - "i8" - - "i16" - - "i32" - - "i64" - - "fp32" - - "fp64" + - "dec" - source: "arithmetic_decimal" name: "min" system_metadata: @@ -718,122 +779,77 @@ supported_aggregate_functions: notation: "FUNCTION" supported_impls: - "dec" -- source: "arithmetic" - name: "min" +- source: "arithmetic_decimal" + name: "sum" system_metadata: - name: "min" + name: "sum" notation: "FUNCTION" supported_impls: - - "i8" - - "i16" - - "i32" - - "i64" - - "fp32" - - "fp64" + - "dec" - source: "datetime" - name: "min" + name: "max" system_metadata: - name: "min" + name: "max" notation: "FUNCTION" supported_impls: - "date" + - "iday" + - "iyear" + - "pts" + - "ptstz" - "time" - "ts" - - "pts" - "tstz" - - "ptstz" - - "iday" - - "iyear" -- source: "arithmetic_decimal" - name: "max" - system_metadata: - name: "max" - notation: "FUNCTION" - supported_impls: - - "dec" -- source: "arithmetic" - name: "max" - system_metadata: - name: "max" - notation: "FUNCTION" - supported_impls: - - "i8" - - "i16" - - "i32" - - "i64" - - "fp32" - - "fp64" - source: "datetime" - name: "max" + name: "min" system_metadata: - name: "max" + name: "min" notation: "FUNCTION" supported_impls: - "date" + - "iday" + - "iyear" + - "pts" + - "ptstz" - "time" - "ts" - - "pts" - "tstz" - - "ptstz" - - "iday" - - "iyear" -- source: "aggregate_generic" - name: "any_value" - system_metadata: - name: "first" - notation: "FUNCTION" - supported_impls: - - "any" -- source: "aggregate_approx" - name: "approx_count_distinct" - system_metadata: - name: 
"approx_count_distinct" - notation: "FUNCTION" - supported_impls: - - "any" -- source: "arithmetic" - name: "std_dev" - system_metadata: - name: "stddev_samp" - notation: "FUNCTION" - supported_impls: - - "fp64" supported_window_functions: - source: "arithmetic" - name: "row_number" + name: "cume_dist" system_metadata: - name: "row_number" + name: "cume_dist" notation: "FUNCTION" supported_impls: - "" - source: "arithmetic" - name: "rank" + name: "dense_rank" system_metadata: - name: "rank" + name: "dense_rank" notation: "FUNCTION" supported_impls: - "" - source: "arithmetic" - name: "dense_rank" + name: "lag" system_metadata: - name: "dense_rank" + name: "lag" notation: "FUNCTION" supported_impls: - - "" + - "any_i32_any" - source: "arithmetic" - name: "percent_rank" + name: "lead" system_metadata: - name: "percent_rank" + name: "lead" notation: "FUNCTION" supported_impls: - - "" + - "any_i32_any" - source: "arithmetic" - name: "cume_dist" + name: "nth_value" system_metadata: - name: "cume_dist" + name: "nth_value" notation: "FUNCTION" supported_impls: - - "" + - "any_i32" - source: "arithmetic" name: "ntile" system_metadata: @@ -842,23 +858,23 @@ supported_window_functions: supported_impls: - "i32" - source: "arithmetic" - name: "lead" + name: "percent_rank" system_metadata: - name: "lead" + name: "percent_rank" notation: "FUNCTION" supported_impls: - - "any_i32_any" + - "" - source: "arithmetic" - name: "lag" + name: "rank" system_metadata: - name: "lag" + name: "rank" notation: "FUNCTION" supported_impls: - - "any_i32_any" + - "" - source: "arithmetic" - name: "nth_value" + name: "row_number" system_metadata: - name: "nth_value" + name: "row_number" notation: "FUNCTION" supported_impls: - - "any_i32" + - "" diff --git a/spark/src/main/scala/io/substrait/spark/compat/SparkCompat.scala b/spark/src/main/scala/io/substrait/spark/compat/SparkCompat.scala new file mode 100644 index 000000000..da1681345 --- /dev/null +++ b/spark/src/main/scala/io/substrait/spark/compat/SparkCompat.scala @@ -0,0 +1,73 @@ +package io.substrait.spark.compat + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, NamedExpression, ScalarSubquery} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} + +/** + * Compatibility layer for Spark version-specific APIs. Implementations are provided in + * variant-specific source directories. 
+ */ +trait SparkCompat { + + /** Create a ScalarSubquery with version-appropriate constructor */ + def createScalarSubquery(plan: LogicalPlan): ScalarSubquery + + /** Create an Aggregate with version-appropriate constructor */ + def createAggregate( + groupingExpressions: Seq[Expression], + aggregateExpressions: Seq[NamedExpression], + child: LogicalPlan + ): Aggregate + + /** Create a LogicalRelation with version-appropriate constructor */ + def createLogicalRelation( + relation: HadoopFsRelation, + output: Seq[AttributeReference], + catalogTable: Option[org.apache.spark.sql.catalyst.catalog.CatalogTable], + isStreaming: Boolean + ): LogicalRelation + + /** Create a ListQuery with version-appropriate constructor */ + def createListQuery( + plan: LogicalPlan, + output: Seq[Attribute]): org.apache.spark.sql.catalyst.expressions.ListQuery + + /** Get SparkSession instance (returns AnyRef to work across versions) */ + def getOrCreateSparkSession(): AnyRef + + /** Create QueryExecution with version-appropriate SparkSession type */ + def createQueryExecution( + spark: AnyRef, + plan: LogicalPlan): org.apache.spark.sql.execution.QueryExecution + + /** Get config value from SparkSession */ + def getConf(spark: AnyRef, key: String): String + + /** Create InMemoryFileIndex with version-appropriate SparkSession type */ + def createInMemoryFileIndex( + spark: AnyRef, + paths: Seq[org.apache.hadoop.fs.Path], + parameters: Map[String, String], + userSpecifiedSchema: Option[org.apache.spark.sql.types.StructType] + ): org.apache.spark.sql.execution.datasources.InMemoryFileIndex + + def createHadoopFsRelation( + spark: AnyRef, + location: org.apache.spark.sql.execution.datasources.InMemoryFileIndex, + partitionSchema: org.apache.spark.sql.types.StructType, + dataSchema: org.apache.spark.sql.types.StructType, + bucketSpec: Option[org.apache.spark.sql.catalyst.catalog.BucketSpec], + fileFormat: org.apache.spark.sql.execution.datasources.FileFormat, + options: Map[String, String] + ): org.apache.spark.sql.execution.datasources.HadoopFsRelation +} + +object SparkCompat { + + /** + * Get the version-specific implementation. This will be resolved at compile time based on the + * variant being built. 
+ */ + lazy val instance: SparkCompat = new SparkCompatImpl() +} diff --git a/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala b/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala index 526763377..76ec9c2fc 100644 --- a/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala +++ b/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala @@ -158,6 +158,8 @@ class FunctionMappings { s[Substring]("substring"), s[Upper]("upper"), s[Lower]("lower"), + s[StringLPad]("lpad"), + s[StringRPad]("rpad"), s[Concat]("concat"), s[Coalesce]("coalesce"), s[ShiftLeft]("shift_left"), diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala b/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala index 18d3b3543..e1c40921e 100644 --- a/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala +++ b/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala @@ -17,6 +17,7 @@ package io.substrait.spark.expression import io.substrait.spark.{DefaultExpressionVisitor, HasOutputStack, SparkExtension, ToSparkType} +import io.substrait.spark.compat.SparkCompat import io.substrait.spark.logical.ToLogicalPlan import io.substrait.spark.utils.Util @@ -230,7 +231,7 @@ class ToSparkExpression( relConverter => { val plan = rel.accept(relConverter, context) require(plan.resolved) - val result = ScalarSubquery(plan) + val result = SparkCompat.instance.createScalarSubquery(plan) SparkTypeUtil.sameType(result.dataType, dataType) result }) @@ -248,7 +249,7 @@ class ToSparkExpression( override def visit(expr: SExpression.InPredicate, context: EmptyVisitationContext): Expression = { val needles = expr.needles().asScala.map(e => e.accept(this, context)).toSeq val haystack = expr.haystack().accept(toLogicalPlan.get, context) - new InSubquery(needles, ListQuery(haystack, childOutputs = haystack.output)) { + new InSubquery(needles, SparkCompat.instance.createListQuery(haystack, haystack.output)) { override def nullable: Boolean = expr.getType.nullable() } } diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitExpression.scala b/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitExpression.scala index a3903cba3..af31f2201 100644 --- a/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitExpression.scala +++ b/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitExpression.scala @@ -17,6 +17,7 @@ package io.substrait.spark.expression import io.substrait.spark.{HasOutputStack, ToSubstraitType} +import io.substrait.spark.compat.SparkCompat import io.substrait.spark.utils.Util import org.apache.spark.sql.catalyst.expressions._ @@ -52,38 +53,35 @@ abstract class ToSubstraitExpression extends HasOutputStack[Seq[Attribute]] { // its original form to make it substrait friendly. 
// https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala#L76 def unapply(e: Expression): Option[ScalarSubquery] = e match { - case GetStructField( - ScalarSubquery( - Project(Seq(Alias(CreateNamedStruct(args), "mergedValue")), child), - _, - _, - _, - _, - _), - ordinal, - _) => - val (name, value) = - args.grouped(2).map { case Seq(name, value) => (name, value) }.toArray.apply(ordinal) - - child match { - case Aggregate(groupingExpressions, aggregateExpressions, child) - if aggregateExpressions.forall(e => e.isInstanceOf[Alias]) => - val used = value match { - case ref: AttributeReference => ref.exprId.id - case _ => throw new UnsupportedOperationException(s"Cannot convert expression: $e") + case GetStructField(subquery: ScalarSubquery, ordinal, _) => + // Extract the project and child from the subquery + subquery.plan match { + case Project(Seq(Alias(CreateNamedStruct(args), "mergedValue")), child) => + val (name, value) = + args.grouped(2).map { case Seq(name, value) => (name, value) }.toArray.apply(ordinal) + + child match { + case agg: Aggregate if agg.aggregateExpressions.forall(e => e.isInstanceOf[Alias]) => + val used = value match { + case ref: AttributeReference => ref.exprId.id + case _ => + throw new UnsupportedOperationException(s"Cannot convert expression: $e") + } + val filteredAggExprs = agg.aggregateExpressions.filter(ae => used == ae.exprId.id) + Some( + SparkCompat.instance.createScalarSubquery( + SparkCompat.instance + .createAggregate(agg.groupingExpressions, filteredAggExprs, agg.child) + ) + ) + case _ => + Some( + SparkCompat.instance.createScalarSubquery( + Project(Seq(Alias(value, name.toString())()), child) + ) + ) } - val filteredAggExprs = aggregateExpressions.filter(ae => used == ae.exprId.id) - Some( - ScalarSubquery( - Aggregate(groupingExpressions, filteredAggExprs, child) - ) - ) - case _ => - Some( - ScalarSubquery( - Project(Seq(Alias(value, name.toString())()), child) - ) - ) + case _ => None } case _ => None } diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala index 3d48d2aa2..bf36ba8af 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala @@ -17,9 +17,10 @@ package io.substrait.spark.logical import io.substrait.spark.{DefaultRelVisitor, FileHolder, SparkExtension, ToSparkType, ToSubstraitType} +import io.substrait.spark.compat.SparkCompat import io.substrait.spark.expression._ -import org.apache.spark.sql.{SaveMode, SparkSession} +import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{caseSensitiveResolution, MultiInstanceRelation, UnresolvedRelation} import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} @@ -28,9 +29,8 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, RightOuter} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.toPrettySQL -import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, CreateTableCommand, DataWritingCommand, DropTableCommand, LeafRunnableCommand} 
-import org.apache.spark.sql.execution.datasources.{FileFormat => SparkFileFormat, HadoopFsRelation, InMemoryFileIndex, InsertIntoHadoopFsRelationCommand, LogicalRelation, V1Writes} +import org.apache.spark.sql.execution.datasources.{FileFormat => SparkFileFormat, InsertIntoHadoopFsRelationCommand, V1Writes} import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat @@ -41,7 +41,6 @@ import org.apache.spark.sql.types.{DataType, IntegerType, StructField, StructTyp import io.substrait.`type`.{NamedStruct, StringTypeVisitor, Type} import io.substrait.{expression => exp} import io.substrait.expression.{Expression => SExpression} -import io.substrait.expression.Expression.NestedStruct import io.substrait.plan.Plan import io.substrait.relation import io.substrait.relation.{ExtensionWrite, LocalFiles, NamedDdl, NamedWrite} @@ -64,7 +63,7 @@ import scala.jdk.CollectionConverters._ * RelVisitor to convert Substrait Rel plan to [[LogicalPlan]]. Unsupported Rel node will call * visitFallback and throw UnsupportedOperationException. */ -class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate()) +class ToLogicalPlan(val spark: AnyRef = SparkCompat.instance.getOrCreateSparkSession()) extends DefaultRelVisitor[LogicalPlan] { private val expressionConverter = @@ -142,7 +141,7 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate()) val outputs = groupBy.map(toNamedExpression) val aggregateExpressions = aggregate.getMeasures.asScala.map(fromMeasure).map(toNamedExpression).toSeq - Aggregate(groupBy, outputs ++ aggregateExpressions, child) + SparkCompat.instance.createAggregate(groupBy, outputs ++ aggregateExpressions, child) } } @@ -414,19 +413,22 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate()) throw new UnsupportedOperationException(s"All files must have the same format") } val (format, options) = convertFileFormat(formats.head) - new LogicalRelation( - relation = HadoopFsRelation( - location = new InMemoryFileIndex( - spark, - localFiles.getItems.asScala.map(i => new Path(i.getPath.get())).toSeq, - Map(), - Some(schema)), - partitionSchema = new StructType(), - dataSchema = schema, - bucketSpec = None, - fileFormat = format, - options = options - )(spark), + val location = SparkCompat.instance.createInMemoryFileIndex( + spark, + localFiles.getItems.asScala.map(i => new Path(i.getPath.get())).toSeq, + Map(), + Some(schema)) + val hadoopFsRelation = SparkCompat.instance.createHadoopFsRelation( + spark, + location, + new StructType(), + schema, + None, + format, + options + ) + SparkCompat.instance.createLogicalRelation( + relation = hadoopFsRelation, output = output, catalogTable = None, isStreaming = false @@ -460,10 +462,11 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate()) override def visit(write: NamedWrite, context: EmptyVisitationContext): LogicalPlan = { val child = write.getInput.accept(this, context) val table = catalogTable(write.getNames.asScala.toSeq) - val isHive = spark.conf.get(StaticSQLConf.CATALOG_IMPLEMENTATION.key) match { - case "hive" => true - case _ => false - } + val isHive = + SparkCompat.instance.getConf(spark, StaticSQLConf.CATALOG_IMPLEMENTATION.key) match { + case "hive" => true + case _ => false + } write.getOperation match { case WriteOp.CTAS => withChild(child) { @@ -570,7 +573,7 @@ class ToLogicalPlan(spark: 
SparkSession = SparkSession.builder().getOrCreate()) s"NamedWrite requires up to three names ([[catalog,] database,] table): $names") } - val loc = spark.conf.get(StaticSQLConf.WAREHOUSE_PATH.key) + val loc = SparkCompat.instance.getConf(spark, StaticSQLConf.WAREHOUSE_PATH.key) val storage = CatalogStorageFormat( locationUri = Some(URI.create(f"$loc/$table")), inputFormat = Some("org.apache.hadoop.mapred.TextInputFormat"), @@ -611,7 +614,7 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate()) } } private def resolve(plan: LogicalPlan): LogicalPlan = { - val qe = new QueryExecution(spark, plan) + val qe = SparkCompat.instance.createQueryExecution(spark, plan) qe.analyzed match { case SubqueryAlias(_, child) => child case other => other @@ -666,7 +669,7 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate()) // This is helps a bit with round-trip testing and plan readability case project: Project => Project(renameAndCastExprs(project.projectList), project.child) case aggregate: Aggregate => - Aggregate( + SparkCompat.instance.createAggregate( aggregate.groupingExpressions, renameAndCastExprs(aggregate.aggregateExpressions), aggregate.child) diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala index df211d8d2..d55e4c395 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala @@ -17,6 +17,7 @@ package io.substrait.spark.logical import io.substrait.spark.{FileHolder, SparkExtension, ToSubstraitType} +import io.substrait.spark.compat.WindowGroupLimitCase import io.substrait.spark.expression._ import io.substrait.spark.utils.Util @@ -37,7 +38,7 @@ import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation, V2SessionCatalog} import org.apache.spark.sql.hive.execution.{CreateHiveTableAsSelectCommand, InsertIntoHiveTable} -import org.apache.spark.sql.types.{NullType, StructType} +import org.apache.spark.sql.types.{NullType, StructField, StructType} import io.substrait.`type`.{NamedStruct, Type} import io.substrait.{proto, relation} @@ -83,6 +84,7 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging { case p: LeafNode => convertReadOperator(p) case s: SubqueryAlias => visit(s.child) case v: View => visit(v.child) + case WindowGroupLimitCase(child) => visit(child) case other => t(other) } @@ -582,7 +586,9 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging { case CreateTable(ResolvedIdentifier(c: V2SessionCatalog, id), tableSchema, _, _, _) if id.namespace().length > 0 => val names = Seq(c.name(), id.namespace()(0), id.name()) - convertCreateTable(names, tableSchema) + val schema = StructType( + tableSchema.map(col => StructField(col.name, col.dataType, col.nullable))) + convertCreateTable(names, schema) case DropTable(ResolvedIdentifier(c: V2SessionCatalog, id), ifExists, _) if id.namespace().length > 0 => val names = Seq(c.name(), id.namespace()(0), id.name()) diff --git a/spark/src/main/scala/io/substrait/spark/utils/DialectGenerator.scala 
b/spark/src/main/scala/io/substrait/spark/utils/DialectGenerator.scala index 4fcf06fd9..2a527e6d1 100644 --- a/spark/src/main/scala/io/substrait/spark/utils/DialectGenerator.scala +++ b/spark/src/main/scala/io/substrait/spark/utils/DialectGenerator.scala @@ -40,7 +40,7 @@ case class SupportedFunction( supported_impls: Seq[String]) class DialectGenerator { - val schemaPath = "../substrait/text/dialect_schema.yaml" + val schemaPath = "../../substrait/text/dialect_schema.yaml" private val sourceURNs = Map( "extension:io.substrait:functions_aggregate_approx" -> "aggregate_approx", @@ -75,9 +75,10 @@ class DialectGenerator { expressions, relations, sourceURNs.map(_.swap), - scalars, - aggregates, - windows) + scalars.sortBy(f => (f.source, f.name)), + aggregates.sortBy(f => (f.source, f.name)), + windows.sortBy(f => (f.source, f.name)) + ) } def generateYaml(): String = { @@ -229,7 +230,7 @@ class DialectGenerator { .groupBy(_._1) // group by URN .filter(_._1 != "FAILED") .view - .mapValues(_.map(_._2)) + .map { case (k, v) => (k, v.map(_._2)) } .toMap case _ => println(s"NO INPUT TYPES") @@ -241,7 +242,7 @@ class DialectGenerator { sourceURNs.getOrElse(urn, ""), sig.name, FunctionMetadata(sqlName, notation), - sigs) + sigs.sorted) }.toSeq } diff --git a/spark/src/main/spark-3.4/io/substrait/spark/compat/SparkCompatImpl.scala b/spark/src/main/spark-3.4/io/substrait/spark/compat/SparkCompatImpl.scala new file mode 100644 index 000000000..833e71091 --- /dev/null +++ b/spark/src/main/spark-3.4/io/substrait/spark/compat/SparkCompatImpl.scala @@ -0,0 +1,94 @@ +package io.substrait.spark.compat + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, NamedExpression, ScalarSubquery} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} + +import io.substrait.relation + +class SparkCompatImpl extends SparkCompat { + + override def createScalarSubquery(plan: LogicalPlan): ScalarSubquery = { + // Spark 3.4 requires exprId parameter + ScalarSubquery(plan, exprId = NamedExpression.newExprId) + } + + override def createAggregate( + groupingExpressions: Seq[Expression], + aggregateExpressions: Seq[NamedExpression], + child: LogicalPlan + ): Aggregate = { + // Spark 3.4 uses 3-parameter constructor + Aggregate(groupingExpressions, aggregateExpressions, child) + } + + override def createLogicalRelation( + relation: HadoopFsRelation, + output: Seq[AttributeReference], + catalogTable: Option[org.apache.spark.sql.catalyst.catalog.CatalogTable], + isStreaming: Boolean + ): LogicalRelation = { + // Spark 3.4 uses 4-parameter constructor (no stream parameter) + new LogicalRelation(relation, output, catalogTable, isStreaming) + } + + override def createListQuery( + plan: LogicalPlan, + output: Seq[Attribute]): org.apache.spark.sql.catalyst.expressions.ListQuery = { + // Spark 3.4 doesn't have numCols parameter + org.apache.spark.sql.catalyst.expressions.ListQuery(plan, childOutputs = output) + } + + override def getOrCreateSparkSession(): AnyRef = { + org.apache.spark.sql.SparkSession.builder().getOrCreate() + } + + override def createQueryExecution( + spark: AnyRef, + plan: LogicalPlan): org.apache.spark.sql.execution.QueryExecution = { + new org.apache.spark.sql.execution.QueryExecution( + spark.asInstanceOf[org.apache.spark.sql.SparkSession], + plan) + } + + override def getConf(spark: AnyRef, key: String): String = { + 
spark.asInstanceOf[org.apache.spark.sql.SparkSession].conf.get(key) + } + + override def createInMemoryFileIndex( + spark: AnyRef, + paths: Seq[org.apache.hadoop.fs.Path], + parameters: Map[String, String], + userSpecifiedSchema: Option[org.apache.spark.sql.types.StructType] + ): org.apache.spark.sql.execution.datasources.InMemoryFileIndex = { + new org.apache.spark.sql.execution.datasources.InMemoryFileIndex( + spark.asInstanceOf[org.apache.spark.sql.SparkSession], + paths, + parameters, + userSpecifiedSchema + ) + } + + override def createHadoopFsRelation( + spark: AnyRef, + location: org.apache.spark.sql.execution.datasources.InMemoryFileIndex, + partitionSchema: org.apache.spark.sql.types.StructType, + dataSchema: org.apache.spark.sql.types.StructType, + bucketSpec: Option[org.apache.spark.sql.catalyst.catalog.BucketSpec], + fileFormat: org.apache.spark.sql.execution.datasources.FileFormat, + options: Map[String, String] + ): org.apache.spark.sql.execution.datasources.HadoopFsRelation = { + org.apache.spark.sql.execution.datasources.HadoopFsRelation( + location, + partitionSchema, + dataSchema, + bucketSpec, + fileFormat, + options + )(spark.asInstanceOf[org.apache.spark.sql.SparkSession]) + } +} + +object WindowGroupLimitCase { + def unapply(l: LogicalPlan): Option[LogicalPlan] = None +} diff --git a/spark/src/main/spark-3.4/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala b/spark/src/main/spark-3.4/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala index ec3ee78e8..d5ae2f793 100644 --- a/spark/src/main/spark-3.4/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala +++ b/spark/src/main/spark-3.4/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala @@ -24,7 +24,7 @@ import io.substrait.relation.Rel class AbstractLogicalPlanVisitor extends LogicalPlanVisitor[relation.Rel] { protected def t(p: LogicalPlan): relation.Rel = - throw new UnsupportedOperationException(s"Unable to convert the expression ${p.nodeName}") + throw new UnsupportedOperationException(s"Unable to convert the LogicalPlan ${p.nodeName}") override def visitDistinct(p: Distinct): relation.Rel = t(p) diff --git a/spark/src/main/spark-3.5/io/substrait/spark/compat/SparkCompatImpl.scala b/spark/src/main/spark-3.5/io/substrait/spark/compat/SparkCompatImpl.scala new file mode 100644 index 000000000..75898e599 --- /dev/null +++ b/spark/src/main/spark-3.5/io/substrait/spark/compat/SparkCompatImpl.scala @@ -0,0 +1,97 @@ +package io.substrait.spark.compat + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, NamedExpression, ScalarSubquery} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, WindowGroupLimit} +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} + +import io.substrait.relation + +class SparkCompatImpl extends SparkCompat { + + override def createScalarSubquery(plan: LogicalPlan): ScalarSubquery = { + // Spark 3.5 requires exprId parameter + ScalarSubquery(plan, exprId = NamedExpression.newExprId) + } + + override def createAggregate( + groupingExpressions: Seq[Expression], + aggregateExpressions: Seq[NamedExpression], + child: LogicalPlan + ): Aggregate = { + // Spark 3.5 uses 3-parameter constructor + Aggregate(groupingExpressions, aggregateExpressions, child) + } + + override def createLogicalRelation( + relation: HadoopFsRelation, + output: Seq[AttributeReference], + catalogTable: Option[org.apache.spark.sql.catalyst.catalog.CatalogTable], + isStreaming: Boolean + ): 
LogicalRelation = { + // Spark 3.5 uses 4-parameter constructor (no stream parameter) + new LogicalRelation(relation, output, catalogTable, isStreaming) + } + + override def createListQuery( + plan: LogicalPlan, + output: Seq[Attribute]): org.apache.spark.sql.catalyst.expressions.ListQuery = { + // Spark 3.5 has numCols parameter + org.apache.spark.sql.catalyst.expressions.ListQuery(plan, numCols = output.length) + } + + override def getOrCreateSparkSession(): AnyRef = { + org.apache.spark.sql.SparkSession.builder().getOrCreate() + } + + override def createQueryExecution( + spark: AnyRef, + plan: LogicalPlan): org.apache.spark.sql.execution.QueryExecution = { + new org.apache.spark.sql.execution.QueryExecution( + spark.asInstanceOf[org.apache.spark.sql.SparkSession], + plan) + } + + override def getConf(spark: AnyRef, key: String): String = { + spark.asInstanceOf[org.apache.spark.sql.SparkSession].conf.get(key) + } + + override def createInMemoryFileIndex( + spark: AnyRef, + paths: Seq[org.apache.hadoop.fs.Path], + parameters: Map[String, String], + userSpecifiedSchema: Option[org.apache.spark.sql.types.StructType] + ): org.apache.spark.sql.execution.datasources.InMemoryFileIndex = { + new org.apache.spark.sql.execution.datasources.InMemoryFileIndex( + spark.asInstanceOf[org.apache.spark.sql.SparkSession], + paths, + parameters, + userSpecifiedSchema + ) + } + + override def createHadoopFsRelation( + spark: AnyRef, + location: org.apache.spark.sql.execution.datasources.InMemoryFileIndex, + partitionSchema: org.apache.spark.sql.types.StructType, + dataSchema: org.apache.spark.sql.types.StructType, + bucketSpec: Option[org.apache.spark.sql.catalyst.catalog.BucketSpec], + fileFormat: org.apache.spark.sql.execution.datasources.FileFormat, + options: Map[String, String] + ): org.apache.spark.sql.execution.datasources.HadoopFsRelation = { + org.apache.spark.sql.execution.datasources.HadoopFsRelation( + location, + partitionSchema, + dataSchema, + bucketSpec, + fileFormat, + options + )(spark.asInstanceOf[org.apache.spark.sql.SparkSession]) + } +} + +object WindowGroupLimitCase { + def unapply(l: LogicalPlan): Option[LogicalPlan] = l match { + case w: WindowGroupLimit => Some(w.child) + case _ => None + } +} diff --git a/spark/src/main/spark-3.3/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala b/spark/src/main/spark-3.5/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala similarity index 96% rename from spark/src/main/spark-3.3/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala rename to spark/src/main/spark-3.5/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala index 345cb215f..d5ae2f793 100644 --- a/spark/src/main/spark-3.3/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala +++ b/spark/src/main/spark-3.5/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala @@ -24,7 +24,7 @@ import io.substrait.relation.Rel class AbstractLogicalPlanVisitor extends LogicalPlanVisitor[relation.Rel] { protected def t(p: LogicalPlan): relation.Rel = - throw new UnsupportedOperationException(s"Unable to convert the expression ${p.nodeName}") + throw new UnsupportedOperationException(s"Unable to convert the LogicalPlan ${p.nodeName}") override def visitDistinct(p: Distinct): relation.Rel = t(p) @@ -70,5 +70,7 @@ class AbstractLogicalPlanVisitor extends LogicalPlanVisitor[relation.Rel] { override def visitWithCTE(p: WithCTE): Rel = t(p) + override def visitOffset(p: Offset): Rel = t(p) + override def visitRebalancePartitions(p: RebalancePartitions): Rel = t(p) } diff 
--git a/spark/src/main/spark-4.0/io/substrait/spark/compat/SparkCompatImpl.scala b/spark/src/main/spark-4.0/io/substrait/spark/compat/SparkCompatImpl.scala new file mode 100644 index 000000000..bd725c6a7 --- /dev/null +++ b/spark/src/main/spark-4.0/io/substrait/spark/compat/SparkCompatImpl.scala @@ -0,0 +1,95 @@ +package io.substrait.spark.compat + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, NamedExpression, ScalarSubquery} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, WindowGroupLimit} +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} + +class SparkCompatImpl extends SparkCompat { + + override def createScalarSubquery(plan: LogicalPlan): ScalarSubquery = { + // Spark 4.0 simplified constructor - no exprId needed + ScalarSubquery(plan) + } + + override def createAggregate( + groupingExpressions: Seq[Expression], + aggregateExpressions: Seq[NamedExpression], + child: LogicalPlan + ): Aggregate = { + // Spark 4.0 simplified constructor - no aggregateAttributes needed + Aggregate(groupingExpressions, aggregateExpressions, child) + } + + override def createLogicalRelation( + relation: HadoopFsRelation, + output: Seq[AttributeReference], + catalogTable: Option[org.apache.spark.sql.catalyst.catalog.CatalogTable], + isStreaming: Boolean + ): LogicalRelation = { + // Spark 4.0 requires stream parameter (None for non-streaming) + new LogicalRelation(relation, output, catalogTable, isStreaming, None) + } + + override def createListQuery( + plan: LogicalPlan, + output: Seq[Attribute]): org.apache.spark.sql.catalyst.expressions.ListQuery = { + // Spark 4.0 has numCols parameter + org.apache.spark.sql.catalyst.expressions.ListQuery(plan, numCols = output.length) + } + + override def getOrCreateSparkSession(): AnyRef = { + org.apache.spark.sql.classic.SparkSession.builder().getOrCreate() + } + + override def createQueryExecution( + spark: AnyRef, + plan: LogicalPlan): org.apache.spark.sql.execution.QueryExecution = { + new org.apache.spark.sql.execution.QueryExecution( + spark.asInstanceOf[org.apache.spark.sql.classic.SparkSession], + plan) + } + + override def getConf(spark: AnyRef, key: String): String = { + spark.asInstanceOf[org.apache.spark.sql.classic.SparkSession].conf.get(key) + } + + override def createInMemoryFileIndex( + spark: AnyRef, + paths: Seq[org.apache.hadoop.fs.Path], + parameters: Map[String, String], + userSpecifiedSchema: Option[org.apache.spark.sql.types.StructType] + ): org.apache.spark.sql.execution.datasources.InMemoryFileIndex = { + new org.apache.spark.sql.execution.datasources.InMemoryFileIndex( + spark.asInstanceOf[org.apache.spark.sql.classic.SparkSession], + paths, + parameters, + userSpecifiedSchema + ) + } + + override def createHadoopFsRelation( + spark: AnyRef, + location: org.apache.spark.sql.execution.datasources.InMemoryFileIndex, + partitionSchema: org.apache.spark.sql.types.StructType, + dataSchema: org.apache.spark.sql.types.StructType, + bucketSpec: Option[org.apache.spark.sql.catalyst.catalog.BucketSpec], + fileFormat: org.apache.spark.sql.execution.datasources.FileFormat, + options: Map[String, String] + ): org.apache.spark.sql.execution.datasources.HadoopFsRelation = { + org.apache.spark.sql.execution.datasources.HadoopFsRelation( + location, + partitionSchema, + dataSchema, + bucketSpec, + fileFormat, + options + )(spark.asInstanceOf[org.apache.spark.sql.classic.SparkSession]) + } +} + +object WindowGroupLimitCase { + def unapply(l: 
LogicalPlan): Option[LogicalPlan] = l match { + case w: WindowGroupLimit => Some(w.child) + case _ => None + } +} diff --git a/spark/src/main/spark-3.2/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala b/spark/src/main/spark-4.0/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala similarity index 95% rename from spark/src/main/spark-3.2/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala rename to spark/src/main/spark-4.0/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala index 836a087f1..d5ae2f793 100644 --- a/spark/src/main/spark-3.2/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala +++ b/spark/src/main/spark-4.0/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala @@ -69,4 +69,8 @@ class AbstractLogicalPlanVisitor extends LogicalPlanVisitor[relation.Rel] { override def visitSort(sort: Sort): Rel = t(sort) override def visitWithCTE(p: WithCTE): Rel = t(p) + + override def visitOffset(p: Offset): Rel = t(p) + + override def visitRebalancePartitions(p: RebalancePartitions): Rel = t(p) } diff --git a/spark/src/test/scala/io/substrait/spark/DialectSuite.scala b/spark/src/test/scala/io/substrait/spark/DialectSuite.scala index d186e819e..92e8a5502 100644 --- a/spark/src/test/scala/io/substrait/spark/DialectSuite.scala +++ b/spark/src/test/scala/io/substrait/spark/DialectSuite.scala @@ -16,7 +16,7 @@ import java.io.{File, FileInputStream} import scala.io.Source class DialectSuite extends SparkFunSuite with SharedSparkSession with SubstraitPlanTestBase { - private val dialectPath = "spark_dialect.yaml" + private val dialectPath = "../spark_dialect.yaml" override def beforeAll(): Unit = { super.beforeAll() diff --git a/spark/src/test/scala/io/substrait/spark/LocalFiles.scala b/spark/src/test/scala/io/substrait/spark/LocalFiles.scala index aa674555d..651b0214e 100644 --- a/spark/src/test/scala/io/substrait/spark/LocalFiles.scala +++ b/spark/src/test/scala/io/substrait/spark/LocalFiles.scala @@ -19,8 +19,9 @@ package io.substrait.spark import io.substrait.spark.logical.{ToLogicalPlan, ToSubstraitRel} -import org.apache.spark.sql.{Dataset, DatasetUtil, Encoders, Row} +import org.apache.spark.sql.{Dataset, Encoders, Row} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.classic.DatasetUtil import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} @@ -32,6 +33,8 @@ import io.substrait.relation.ProtoRelConverter import java.nio.file.Paths class LocalFiles extends SharedSparkSession { + private val testResourcesPath = Paths.get("../src/test/resources/") + override def beforeAll(): Unit = { super.beforeAll() sparkContext.setLogLevel("WARN") @@ -77,7 +80,7 @@ class LocalFiles extends SharedSparkSession { val table = spark.read .option("header", true) .option("inferSchema", true) - .csv(Paths.get("src/test/resources/dataset-a.csv").toAbsolutePath.toString) + .csv(testResourcesPath.resolve("dataset-a.csv").toAbsolutePath.toString) assertRoundTripData(table) } @@ -87,7 +90,7 @@ class LocalFiles extends SharedSparkSession { .option("header", true) .option("inferSchema", true) .option("nullValue", "seven") - .csv(Paths.get("src/test/resources/dataset-a.csv").toAbsolutePath.toString) + .csv(testResourcesPath.resolve("dataset-a.csv").toAbsolutePath.toString) val result = assertRoundTripData(table) val id = result.filter("isnull(VALUE)").head().get(0) @@ -104,7 +107,7 @@ class LocalFiles extends 
SharedSparkSession { .schema(schema) .option("delimiter", "|") .option("quote", "'") - .csv(Paths.get("src/test/resources/dataset-a.txt").toAbsolutePath.toString) + .csv(testResourcesPath.resolve("dataset-a.txt").toAbsolutePath.toString) assertRoundTripData(table) } @@ -113,21 +116,21 @@ class LocalFiles extends SharedSparkSession { val table = spark.read .option("header", true) .option("inferSchema", true) - .csv(Paths.get("src/test/resources/csv/").toAbsolutePath.toString) + .csv(testResourcesPath.resolve("csv/").toAbsolutePath.toString) assertRoundTripData(table) } test("Read parquet file") { val table = spark.read - .parquet(Paths.get("src/test/resources/dataset-a.parquet").toAbsolutePath.toString) + .parquet(testResourcesPath.resolve("dataset-a.parquet").toAbsolutePath.toString) assertRoundTripData(table) } test("Read orc file") { val table = spark.read - .orc(Paths.get("src/test/resources/dataset-a.orc").toAbsolutePath.toString) + .orc(testResourcesPath.resolve("dataset-a.orc").toAbsolutePath.toString) assertRoundTripData(table) } @@ -136,10 +139,10 @@ class LocalFiles extends SharedSparkSession { val csv = spark.read .option("header", true) .option("inferSchema", true) - .csv(Paths.get("src/test/resources/dataset-a.csv").toAbsolutePath.toString) + .csv(testResourcesPath.resolve("dataset-a.csv").toAbsolutePath.toString) val orc = spark.read - .orc(Paths.get("src/test/resources/dataset-a.orc").toAbsolutePath.toString) + .orc(testResourcesPath.resolve("dataset-a.orc").toAbsolutePath.toString) .withColumnRenamed("ID", "ID_B") .withColumnRenamed("VALUE", "VALUE_B"); @@ -154,7 +157,7 @@ class LocalFiles extends SharedSparkSession { val csv = spark.read .option("header", true) .option("inferSchema", true) - .csv(Paths.get("src/test/resources/dataset-a.csv").toAbsolutePath.toString) + .csv(testResourcesPath.resolve("dataset-a.csv").toAbsolutePath.toString) csv.createOrReplaceTempView("csv") val data = spark.sql(""" diff --git a/spark/src/test/scala/io/substrait/spark/TypesAndLiteralsSuite.scala b/spark/src/test/scala/io/substrait/spark/TypesAndLiteralsSuite.scala index 400a53aa6..2bde15ce6 100644 --- a/spark/src/test/scala/io/substrait/spark/TypesAndLiteralsSuite.scala +++ b/spark/src/test/scala/io/substrait/spark/TypesAndLiteralsSuite.scala @@ -141,7 +141,8 @@ class TypesAndLiteralsSuite extends SparkFunSuite { val originalValues = l.value.asInstanceOf[MapData].valueArray().toArray[UTF8String](StringType) val sparkValues = sparkLiteral.value.asInstanceOf[MapData].valueArray().toArray[UTF8String](StringType) - assert(originalValues.sorted.sameElements(sparkValues.sorted)) + + assert(originalValues.toSet == sparkValues.toSet) } test(s"test named struct") { diff --git a/spark/src/test/spark-3.2/org/apache/spark/sql/TPCBase.scala b/spark/src/test/spark-3.2/org/apache/spark/sql/TPCBase.scala deleted file mode 100644 index c2c0beacb..000000000 --- a/spark/src/test/spark-3.2/org/apache/spark/sql/TPCBase.scala +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql - -import org.apache.spark.SparkConf -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SharedSparkSession - -trait TPCBase extends SharedSparkSession { - - protected def injectStats: Boolean = false - - override protected def sparkConf: SparkConf = { - if (injectStats) { - super.sparkConf - .set(SQLConf.MAX_TO_STRING_FIELDS, Int.MaxValue) - .set(SQLConf.CBO_ENABLED, true) - .set(SQLConf.PLAN_STATS_ENABLED, true) - .set(SQLConf.JOIN_REORDER_ENABLED, true) - } else { - super.sparkConf.set(SQLConf.MAX_TO_STRING_FIELDS, Int.MaxValue) - } - } - - override def beforeAll(): Unit = { - super.beforeAll() - createTables() - } - - override def afterAll(): Unit = { - dropTables() - super.afterAll() - } - - protected def createTables(): Unit - - protected def dropTables(): Unit -} diff --git a/spark/src/test/spark-3.2/org/apache/spark/sql/TPCHBase.scala b/spark/src/test/spark-3.2/org/apache/spark/sql/TPCHBase.scala deleted file mode 100644 index c3247c5ce..000000000 --- a/spark/src/test/spark-3.2/org/apache/spark/sql/TPCHBase.scala +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.spark.sql - -import org.apache.spark.sql.catalyst.TableIdentifier - -trait TPCHBase extends TPCBase { - - override def createTables(): Unit = { - tpchCreateTable.values.foreach(sql => spark.sql(sql)) - } - - override def dropTables(): Unit = { - tpchCreateTable.keys.foreach { - tableName => spark.sessionState.catalog.dropTable(TableIdentifier(tableName), true, true) - } - } - - val tpchCreateTable = Map( - "orders" -> - """ - |CREATE TABLE `orders` ( - |`o_orderkey` BIGINT, `o_custkey` BIGINT, `o_orderstatus` STRING, - |`o_totalprice` DECIMAL(10,0), `o_orderdate` DATE, `o_orderpriority` STRING, - |`o_clerk` STRING, `o_shippriority` INT, `o_comment` STRING) - |USING parquet - """.stripMargin, - "nation" -> - """ - |CREATE TABLE `nation` ( - |`n_nationkey` BIGINT, `n_name` STRING, `n_regionkey` BIGINT, `n_comment` STRING) - |USING parquet - """.stripMargin, - "region" -> - """ - |CREATE TABLE `region` ( - |`r_regionkey` BIGINT, `r_name` STRING, `r_comment` STRING) - |USING parquet - """.stripMargin, - "part" -> - """ - |CREATE TABLE `part` (`p_partkey` BIGINT, `p_name` STRING, `p_mfgr` STRING, - |`p_brand` STRING, `p_type` STRING, `p_size` INT, `p_container` STRING, - |`p_retailprice` DECIMAL(10,0), `p_comment` STRING) - |USING parquet - """.stripMargin, - "partsupp" -> - """ - |CREATE TABLE `partsupp` (`ps_partkey` BIGINT, `ps_suppkey` BIGINT, - |`ps_availqty` INT, `ps_supplycost` DECIMAL(10,0), `ps_comment` STRING) - |USING parquet - """.stripMargin, - "customer" -> - """ - |CREATE TABLE `customer` (`c_custkey` BIGINT, `c_name` STRING, `c_address` STRING, - |`c_nationkey` BIGINT, `c_phone` STRING, `c_acctbal` DECIMAL(10,0), - |`c_mktsegment` STRING, `c_comment` STRING) - |USING parquet - """.stripMargin, - "supplier" -> - """ - |CREATE TABLE `supplier` (`s_suppkey` BIGINT, `s_name` STRING, `s_address` STRING, - |`s_nationkey` BIGINT, `s_phone` STRING, `s_acctbal` DECIMAL(10,0), `s_comment` STRING) - |USING parquet - """.stripMargin, - "lineitem" -> - """ - |CREATE TABLE `lineitem` (`l_orderkey` BIGINT, `l_partkey` BIGINT, `l_suppkey` BIGINT, - |`l_linenumber` INT, `l_quantity` DECIMAL(10,0), `l_extendedprice` DECIMAL(10,0), - |`l_discount` DECIMAL(10,0), `l_tax` DECIMAL(10,0), `l_returnflag` STRING, - |`l_linestatus` STRING, `l_shipdate` DATE, `l_commitdate` DATE, `l_receiptdate` DATE, - |`l_shipinstruct` STRING, `l_shipmode` STRING, `l_comment` STRING) - |USING parquet - """.stripMargin - ) - - val tpchQueries = Seq( - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15", - "q16", - "q17", - "q18", - "q19", - "q20", - "q21", - "q22") -} diff --git a/spark/src/test/spark-3.2/org/apache/spark/sql/DatasetUtil.scala b/spark/src/test/spark-3.4/org/apache/spark/sql/DatasetUtil.scala similarity index 100% rename from spark/src/test/spark-3.2/org/apache/spark/sql/DatasetUtil.scala rename to spark/src/test/spark-3.4/org/apache/spark/sql/DatasetUtil.scala diff --git a/spark/src/test/spark-3.4/org/apache/spark/sql/classic/DatasetUtil.scala b/spark/src/test/spark-3.4/org/apache/spark/sql/classic/DatasetUtil.scala new file mode 100644 index 000000000..43805b2ba --- /dev/null +++ b/spark/src/test/spark-3.4/org/apache/spark/sql/classic/DatasetUtil.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.classic + +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +/** Compatibility wrapper for Spark 3.4 to provide classic package API */ +object DatasetUtil { + def fromLogicalPlan(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame = { + org.apache.spark.sql.DatasetUtil.fromLogicalPlan(sparkSession, logicalPlan) + } +} diff --git a/spark/src/test/spark-3.5/org/apache/spark/sql/DatasetUtil.scala b/spark/src/test/spark-3.5/org/apache/spark/sql/DatasetUtil.scala new file mode 100644 index 000000000..794719acc --- /dev/null +++ b/spark/src/test/spark-3.5/org/apache/spark/sql/DatasetUtil.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +object DatasetUtil { + def fromLogicalPlan(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame = { + Dataset.ofRows(sparkSession, logicalPlan) + } +} diff --git a/spark/src/test/spark-3.5/org/apache/spark/sql/classic/DatasetUtil.scala b/spark/src/test/spark-3.5/org/apache/spark/sql/classic/DatasetUtil.scala new file mode 100644 index 000000000..12405c068 --- /dev/null +++ b/spark/src/test/spark-3.5/org/apache/spark/sql/classic/DatasetUtil.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.classic + +import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +/** Compatibility wrapper for Spark 3.5 to provide classic package API */ +object DatasetUtil { + def fromLogicalPlan(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame = { + Dataset.ofRows(sparkSession, logicalPlan) + } +} diff --git a/spark/src/test/spark-4.0/org/apache/spark/sql/classic/DatasetUtil.scala b/spark/src/test/spark-4.0/org/apache/spark/sql/classic/DatasetUtil.scala new file mode 100644 index 000000000..0f5a0b81e --- /dev/null +++ b/spark/src/test/spark-4.0/org/apache/spark/sql/classic/DatasetUtil.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.classic + +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +object DatasetUtil { + def fromLogicalPlan(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame = { + Dataset.ofRows(sparkSession, logicalPlan) + } +}
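As a reading aid for the new compat layer, here is a minimal sketch of how calling code is expected to go through SparkCompat instead of version-specific constructors. It assumes a SparkSession has already been started (for example by SharedSparkSession in the test suites); the object name, the warehouse lookup, and the analyze helper are illustrative only and not part of this change.

import io.substrait.spark.compat.SparkCompat
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

object CompatUsageSketch {
  // Every version-sensitive constructor or accessor goes through the single entry point.
  private val compat = SparkCompat.instance

  // The session is handled as AnyRef so the same calling code compiles against
  // org.apache.spark.sql.SparkSession (3.4/3.5) and org.apache.spark.sql.classic.SparkSession (4.0).
  private val spark: AnyRef = compat.getOrCreateSparkSession()

  // Configuration reads no longer touch SparkSession.conf directly.
  def warehousePath: String = compat.getConf(spark, "spark.sql.warehouse.dir")

  // Analysis is delegated to a variant-appropriate QueryExecution, as ToLogicalPlan.resolve now does.
  def analyze(plan: LogicalPlan): LogicalPlan =
    compat.createQueryExecution(spark, plan).analyzed
}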
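In the same spirit, a sketch of how the per-variant DatasetUtil shims keep the test sources uniform: every variant exposes org.apache.spark.sql.classic.DatasetUtil.fromLogicalPlan, which ultimately forwards to Dataset.ofRows for that variant's session type. The suite and test names below are hypothetical and assume the usual SharedSparkSession setup.

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.classic.DatasetUtil
import org.apache.spark.sql.test.SharedSparkSession

class DatasetUtilSketch extends SharedSparkSession {
  test("materialise a LogicalPlan the same way on every variant") {
    // Any LogicalPlan will do; here we take the optimized plan of a tiny query.
    val plan: LogicalPlan = spark.range(3).toDF("id").queryExecution.optimizedPlan

    // fromLogicalPlan resolves to the shim that matches the Spark variant being built.
    val df = DatasetUtil.fromLogicalPlan(spark, plan)
    assert(df.count() == 3)
  }
}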