diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/CatalogPlugin.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/CatalogPlugin.java
index 23f3acc7230f..e5c0fa34ded4 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/CatalogPlugin.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/CatalogPlugin.java
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.connector.catalog;
 
+import java.io.Closeable;
+
 import org.apache.spark.annotation.Evolving;
 import org.apache.spark.sql.internal.SQLConf;
 import org.apache.spark.sql.util.CaseInsensitiveStringMap;
@@ -42,7 +44,7 @@
  * @since 3.0.0
  */
 @Evolving
-public interface CatalogPlugin {
+public interface CatalogPlugin extends Closeable {
   /**
    * Called to initialize configuration.
    *
@@ -74,4 +76,7 @@ public interface CatalogPlugin {
   default String[] defaultNamespace() {
     return new String[0];
   }
+
+  @Override
+  default void close() {}
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala
index b52091afc133..cdff146b33ee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala
@@ -265,6 +265,7 @@ object ResolvedIdentifier {
 object FakeSystemCatalog extends CatalogPlugin {
   override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = {}
   override def name(): String = "system"
+  override def close(): Unit = {}
 }
 
 /**
@@ -273,4 +274,5 @@ object FakeSystemCatalog extends CatalogPlugin {
 object FakeLocalCatalog extends CatalogPlugin {
   override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = {}
   override def name(): String = "local"
+  override def close(): Unit = {}
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala
index 9b8584604d32..1ee1421c36a2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.connector.catalog
 
+import java.io.Closeable
+
 import scala.collection.mutable
 
 import org.apache.spark.internal.Logging
@@ -40,7 +42,7 @@ import org.apache.spark.sql.internal.SQLConf
 private[sql]
 class CatalogManager(
     defaultSessionCatalog: CatalogPlugin,
-    val v1SessionCatalog: SessionCatalog) extends SQLConfHelper with Logging {
+    val v1SessionCatalog: SessionCatalog) extends SQLConfHelper with Logging with Closeable {
   import CatalogManager.SESSION_CATALOG_NAME
   import CatalogV2Util._
@@ -57,6 +59,21 @@ class CatalogManager(
     }
   }
 
+  override def close(): Unit = synchronized {
+    val allCatalogs = (catalogs.values.toSet + defaultSessionCatalog).toSeq
+    allCatalogs.foreach {
+      case c: Closeable =>
+        try {
+          c.close()
+        } catch {
+          case e: Throwable =>
+            logWarning(s"Failed to close catalog of class ${c.getClass.getName}", e)
+        }
+      case _ =>
+    }
+    catalogs.clear()
+  }
+
   def isCatalogRegistered(name: String): Boolean = {
     try {
       catalog(name)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogManagerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogManagerSuite.scala
index fc78eef0ff1b..47566a7e7534 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogManagerSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogManagerSuite.scala
@@ -127,6 +127,18 @@ class CatalogManagerSuite extends SparkFunSuite with SQLHelper {
       }
     }
   }
+
+  test("CatalogManager.close should close all closeable catalogs") {
+    val catalogManager = new CatalogManager(FakeV2SessionCatalog, createSessionCatalog())
+    withSQLConf("spark.sql.catalog.dummy" -> classOf[DummyCatalog].getName,
+      "spark.sql.catalog.closeable" -> classOf[CloseableCatalog].getName) {
+      catalogManager.setCurrentCatalog("dummy")
+      val closeable = catalogManager.catalog("closeable").asInstanceOf[CloseableCatalog]
+      assert(!closeable.isClosed)
+      catalogManager.close()
+      assert(closeable.isClosed)
+    }
+  }
 }
 
 class DummyCatalog extends CatalogPlugin {
@@ -136,4 +148,13 @@
   private var _name: String = null
   override def name(): String = _name
   override def defaultNamespace(): Array[String] = Array("a", "b")
+  override def close(): Unit = {}
+}
+
+class CloseableCatalog extends DummyCatalog with java.io.Closeable {
+  private var closed = false
+  override def close(): Unit = {
+    closed = true
+  }
+  def isClosed: Boolean = closed
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/LookupCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/LookupCatalogSuite.scala
index 49e119b56bc8..66698c620295 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/LookupCatalogSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/LookupCatalogSuite.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
 private case class DummyCatalogPlugin(override val name: String) extends CatalogPlugin {
   override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = ()
+  override def close(): Unit = {}
 }
 
 class LookupCatalogSuite extends SparkFunSuite with LookupCatalog with Inside {
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
index f3128ce50840..a347ad6dc936 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
@@ -380,6 +380,7 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
 
     mlCache.clear()
 
+    session.sessionState.close()
     session.cleanupPythonWorkerLogs()
 
     eventManager.postClosed()
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala
index 91fe395f520d..141b85bda821 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala
@@ -137,13 +137,16 @@
   }
 
   def sendResultComplete(): Unit = {
-    responseObserver
-      .asInstanceOf[ExecuteResponseObserver[ExecutePlanResponse]]
-      .onNextComplete(
-        ExecutePlanResponse
-          .newBuilder()
-          .setResultComplete(ExecutePlanResponse.ResultComplete.newBuilder().build())
-          .build())
+    responseObserver match {
+      case obs: ExecuteResponseObserver[ExecutePlanResponse] =>
+        obs.onNextComplete(
+          ExecutePlanResponse
+            .newBuilder()
+            .setResultComplete(ExecutePlanResponse.ResultComplete.newBuilder().build())
+            .build())
+      case _ =>
+        responseObserver.onCompleted()
+    }
   }
 
   // QueryStartedEvent is sent to client along with WriteStreamOperationStartResult
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectTestUtils.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectTestUtils.scala
index d06c93cc1cad..de8163978860 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectTestUtils.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectTestUtils.scala
@@ -31,6 +31,9 @@
       sessionId = UUID.randomUUID().toString,
       session = session)
     SparkConnectService.sessionManager.putSessionForTesting(ret)
+    if (session != null) {
+      ret.initializeSession()
+    }
     ret
   }
 }
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
index 1b747705e9ad..9b5d8dee8963 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
@@ -40,10 +40,24 @@ import org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper.RunnerCl
 import org.apache.spark.sql.pipelines.graph.{DataflowGraph, PipelineUpdateContextImpl}
 import org.apache.spark.sql.pipelines.logging.PipelineEvent
 import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.util.ArrayImplicits._
 
 class SparkConnectSessionHolderSuite extends SharedSparkSession {
 
+  test("SessionHolder.close should close catalogs") {
+    val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+    val catalogName = "my_closeable_catalog"
+    sessionHolder.session.conf
+      .set(s"spark.sql.catalog.$catalogName", classOf[CloseableCatalog].getName)
+
+    val catalog = sessionHolder.session.sessionState.catalogManager.catalog(catalogName)
+    val closeableCatalog = catalog.asInstanceOf[CloseableCatalog]
+    assert(!closeableCatalog.isClosed)
+    sessionHolder.close()
+    assert(closeableCatalog.isClosed)
+  }
+
   test("DataFrame cache: Successful put and get") {
     val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
     import sessionHolder.session.implicits._
@@ -484,3 +498,21 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
     assertPlanCache(sessionHolder, Some(expected))
   }
 }
+
+class CloseableCatalog
+    extends org.apache.spark.sql.connector.catalog.CatalogPlugin
+    with java.io.Closeable {
+  private var _name: String = _
+  private var closed = false
+
+  override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = {
+    _name = name
+  }
+
+  override def name(): String = _name
+  override def close(): Unit = {
+    closed = true
+  }
+
+  def isClosed: Boolean = closed
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala
index f11ddbc51d33..5ab74aebe89f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala
@@ -55,6 +55,8 @@ class V2SessionCatalog(catalog: SessionCatalog)
   // This class is instantiated by Spark, so `initialize` method will not be called.
   override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = {}
 
+  override def close(): Unit = {}
+
   override def capabilities(): util.Set[TableCatalogCapability] = {
     Set(
       TableCatalogCapability.SUPPORT_COLUMN_DEFAULT_VALUE
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index 440148989ffb..08dfa64409d9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.internal
 
-import java.io.File
+import java.io.{Closeable, File}
 import java.net.URI
 
 import org.apache.hadoop.conf.Configuration
@@ -91,7 +91,7 @@ private[sql] class SessionState(
     val columnarRules: Seq[ColumnarRule],
     val adaptiveRulesHolder: AdaptiveRulesHolder,
     val planNormalizationRules: Seq[Rule[LogicalPlan]],
-    val artifactManagerBuilder: () => ArtifactManager) {
+    val artifactManagerBuilder: () => ArtifactManager) extends Closeable {
 
   // The following fields are lazy to avoid creating the Hive client when creating SessionState.
   lazy val catalog: SessionCatalog = catalogBuilder()
@@ -110,6 +110,10 @@ private[sql] class SessionState(
 
   def catalogManager: CatalogManager = analyzer.catalogManager
 
+  override def close(): Unit = {
+    catalogManager.close()
+  }
+
   def newHadoopConf(): Configuration = SessionState.newHadoopConf(
     sharedState.sparkContext.hadoopConfiguration, conf)
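
For illustration only, not part of the diff above: with CatalogPlugin extending Closeable and CatalogManager.close() invoking close() on every registered catalog, a custom catalog that holds external resources can release them when the owning session shuts down (SessionHolder.close() -> SessionState.close() -> CatalogManager.close()). The sketch below is a hypothetical plugin: the class name, the "url" option, and the JDBC connection are assumptions, but the overridden methods are the ones defined by CatalogPlugin.

// Hypothetical third-party catalog that keeps a JDBC connection open for
// metadata lookups and relies on the new close() lifecycle hook.
import java.sql.{Connection, DriverManager}

import org.apache.spark.sql.connector.catalog.CatalogPlugin
import org.apache.spark.sql.util.CaseInsensitiveStringMap

class JdbcBackedCatalog extends CatalogPlugin {
  private var _name: String = _
  private var connection: Connection = _

  override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = {
    _name = name
    // "url" is an assumed option name used only for this sketch.
    connection = DriverManager.getConnection(options.get("url"))
  }

  override def name(): String = _name

  // Overrides the no-op default added to CatalogPlugin in this patch.
  override def close(): Unit = {
    if (connection != null) {
      connection.close()
      connection = null
    }
  }
}

Registered like any other catalog (for example, spark.sql.catalog.my_jdbc set to the plugin's fully qualified class name), its close() is now reached whenever the owning CatalogManager is closed.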