diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 4e2707a488e38..44c06a53d1749 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -2651,6 +2651,8 @@ class SparkConnectPlanner( Some(transformWriteOperation(command.getWriteOperation)) case proto.Command.CommandTypeCase.WRITE_OPERATION_V2 => Some(transformWriteOperationV2(command.getWriteOperationV2)) + case proto.Command.CommandTypeCase.MERGE_INTO_TABLE_COMMAND => + Some(transformMergeIntoTableCommand(command.getMergeIntoTableCommand)) case _ => None } @@ -2700,8 +2702,6 @@ class SparkConnectPlanner( handleCheckpointCommand(command.getCheckpointCommand, responseObserver) case proto.Command.CommandTypeCase.REMOVE_CACHED_REMOTE_RELATION_COMMAND => handleRemoveCachedRemoteRelationCommand(command.getRemoveCachedRemoteRelationCommand) - case proto.Command.CommandTypeCase.MERGE_INTO_TABLE_COMMAND => - handleMergeIntoTableCommand(command.getMergeIntoTableCommand) case proto.Command.CommandTypeCase.ML_COMMAND => handleMlCommand(command.getMlCommand, responseObserver) case proto.Command.CommandTypeCase.PIPELINE_COMMAND => @@ -3759,7 +3759,8 @@ class SparkConnectPlanner( executeHolder.eventsManager.postFinished() } - private def handleMergeIntoTableCommand(cmd: proto.MergeIntoTableCommand): Unit = { + private def transformMergeIntoTableCommand(cmd: proto.MergeIntoTableCommand)( + tracker: QueryPlanningTracker): LogicalPlan = { def transformActions(actions: java.util.List[proto.Expression]): Seq[MergeAction] = actions.asScala.map(transformExpression).map(_.asInstanceOf[MergeAction]).toSeq @@ -3767,7 +3768,7 @@ class SparkConnectPlanner( val notMatchedActions = transformActions(cmd.getNotMatchedActionsList) val notMatchedBySourceActions = transformActions(cmd.getNotMatchedBySourceActionsList) - val sourceDs = Dataset.ofRows(session, transformRelation(cmd.getSourceTablePlan)) + val sourceDs = Dataset.ofRows(session, transformRelation(cmd.getSourceTablePlan), tracker) val mergeInto = sourceDs .mergeInto(cmd.getTargetTableName, Column(transformExpression(cmd.getMergeCondition))) .asInstanceOf[MergeIntoWriter[Row]] @@ -3777,8 +3778,7 @@ class SparkConnectPlanner( if (cmd.getWithSchemaEvolution) { mergeInto.withSchemaEvolution() } - mergeInto.merge() - executeHolder.eventsManager.postFinished() + mergeInto.mergeCommand() } private val emptyLocalRelation = LocalRelation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/MergeIntoWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/MergeIntoWriter.scala index 0269b15061c97..e3c872658c86a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/MergeIntoWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/MergeIntoWriter.scala @@ -57,13 +57,18 @@ class MergeIntoWriter[T] private[sql](table: String, ds: Dataset[T], on: Column) /** @inheritdoc */ def merge(): Unit = { + val qe = sparkSession.sessionState.executePlan(mergeCommand()) + qe.assertCommandExecuted() + } + + private[sql] def mergeCommand(): LogicalPlan = { if (matchedActions.isEmpty && notMatchedActions.isEmpty && notMatchedBySourceActions.isEmpty) { throw new SparkRuntimeException( errorClass = "NO_MERGE_ACTION_SPECIFIED", messageParameters = Map.empty) } - val merge = MergeIntoTable( + MergeIntoTable( UnresolvedRelation(tableName).requireWritePrivileges(MergeIntoTable.getWritePrivileges( matchedActions, notMatchedActions, notMatchedBySourceActions)), logicalPlan, @@ -72,8 +77,6 @@ class MergeIntoWriter[T] private[sql](table: String, ds: Dataset[T], on: Column) notMatchedActions.toSeq, notMatchedBySourceActions.toSeq, schemaEvolutionEnabled) - val qe = sparkSession.sessionState.executePlan(merge) - qe.assertCommandExecuted() } override protected[sql] def insertAll(condition: Option[Column]): this.type = {