diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala
index fd89be2368af..fdf399d589ca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala
@@ -42,6 +42,8 @@ case class CollectMetricsExec(
     acc
   }
 
+  private var collectedMetricsRow: Row = _
+
   private[sql] def accumulatorId: Long = {
     accumulator.id
   }
@@ -56,7 +58,14 @@ case class CollectMetricsExec(
       .asInstanceOf[InternalRow => Row]
   }
 
-  def collectedMetrics: Row = toRowConverter(accumulator.value)
+  def collectedMetrics: Row = {
+    accumulator.synchronized {
+      if (collectedMetricsRow == null) {
+        collectedMetricsRow = toRowConverter(accumulator.value)
+      }
+      collectedMetricsRow
+    }
+  }
 
   override def output: Seq[Attribute] = child.output
 
@@ -65,7 +74,10 @@ case class CollectMetricsExec(
   override def outputOrdering: Seq[SortOrder] = child.outputOrdering
 
   override def resetMetrics(): Unit = {
-    accumulator.reset()
+    accumulator.synchronized {
+      accumulator.reset()
+      collectedMetricsRow = null
+    }
     super.resetMetrics()
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 27d6eec46b69..9e96ba4948aa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution
 import java.io.{BufferedWriter, OutputStreamWriter}
 import java.util.UUID
 import java.util.concurrent.atomic.AtomicLong
-import javax.annotation.concurrent.GuardedBy
 
 import scala.util.control.NonFatal
 
@@ -298,13 +297,8 @@ class QueryExecution(
    */
  def toRdd: RDD[InternalRow] = lazyToRdd.get
 
-  private val observedMetricsLock = new Object
-
   /** Get the metrics observed during the execution of the query plan. */
-  @GuardedBy("observedMetricsLock")
-  def observedMetrics: Map[String, Row] = observedMetricsLock.synchronized {
-    CollectMetricsExec.collect(executedPlan)
-  }
+  def observedMetrics: Map[String, Row] = CollectMetricsExec.collect(executedPlan)
 
   protected def preparations: Seq[Rule[SparkPlan]] = {
     QueryExecution.preparations(sparkSession,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 941fd2205424..abdf07fd3da8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -19,9 +19,12 @@ package org.apache.spark.sql
 
 import java.io.{Externalizable, ObjectInput, ObjectOutput}
 import java.sql.{Date, Timestamp}
+import java.util.concurrent.Executors
 
 import scala.collection.immutable.HashSet
 import scala.collection.mutable
+import scala.concurrent.{ExecutionContext, Future}
+import scala.concurrent.duration.Duration
 import scala.jdk.CollectionConverters._
 import scala.reflect.ClassTag
 import scala.util.Random
@@ -44,7 +47,7 @@ import org.apache.spark.sql.catalyst.expressions.{CodegenObjectFactoryMode, Expr
 import org.apache.spark.sql.catalyst.plans.JoinType
 import org.apache.spark.sql.catalyst.trees.DataFrameQueryContext
 import org.apache.spark.sql.catalyst.util.sideBySide
-import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SQLExecution}
+import org.apache.spark.sql.execution.{CollectMetricsExec, LogicalRDD, RDDScanExec, SQLExecution}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
@@ -55,6 +58,7 @@ import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.ArrayImplicits._
+import org.apache.spark.util.SparkThreadUtils
 
 case class TestDataPoint(x: Int, y: Double, s: String, t: TestDataPoint2)
 case class TestDataPoint2(x: Int, s: String)
@@ -1075,6 +1079,28 @@ class DatasetSuite extends QueryTest
     assert(namedObservation2.get === expected2)
   }
 
+  test("SPARK-54353: concurrent CollectMetricsExec.collect()") {
+    val df = spark
+      .range(10)
+      .observe(Observation("result"), map(lit("count"), count(lit(1))))
+    df.collect()
+    val threadPool = Executors.newFixedThreadPool(2)
+    val executionContext = ExecutionContext.fromExecutorService(threadPool)
+    try {
+      Seq(
+        Future { CollectMetricsExec.collect(df.queryExecution.executedPlan) }(executionContext),
+        Future { CollectMetricsExec.collect(df.queryExecution.executedPlan) }(executionContext)
+      ).foreach { future =>
+        val result = SparkThreadUtils.awaitResult(future, Duration.Inf)
+        assert(result.size === 1)
+        assert(result.get("result") === Some(Row(Map("count" -> 10))))
+      }
+    } finally {
+      executionContext.shutdown()
+      threadPool.shutdown()
+    }
+  }
+
   test("sample with replacement") {
     val n = 100
     val data = sparkContext.parallelize(1 to n, 2).toDS()