16 changes: 14 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala
@@ -42,6 +42,8 @@ case class CollectMetricsExec(
     acc
   }
 
+  private var collectedMetricsRow: Row = _
+
   private[sql] def accumulatorId: Long = {
     accumulator.id
   }
@@ -56,7 +58,14 @@ case class CollectMetricsExec(
       .asInstanceOf[InternalRow => Row]
   }
 
-  def collectedMetrics: Row = toRowConverter(accumulator.value)
+  def collectedMetrics: Row = {
+    accumulator.synchronized {
+      if (collectedMetricsRow == null) {
+        collectedMetricsRow = toRowConverter(accumulator.value)
+      }
+      collectedMetricsRow
+    }
+  }
 
   override def output: Seq[Attribute] = child.output
 
@@ -65,7 +74,10 @@ case class CollectMetricsExec(
   override def outputOrdering: Seq[SortOrder] = child.outputOrdering
 
   override def resetMetrics(): Unit = {
-    accumulator.reset()
+    accumulator.synchronized {
+      accumulator.reset()
+      collectedMetricsRow = null
+    }
     super.resetMetrics()
   }
 
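The CollectMetricsExec change above caches the converted metrics Row and guards both the cache and accumulator.reset() with the accumulator's own monitor, so concurrent callers of collectedMetrics reuse a single conversion and never race with a reset. Below is a minimal standalone sketch of that lazily-cached, monitor-guarded pattern; Holder, compute, and HolderDemo are hypothetical stand-ins for illustration, not Spark APIs.

import java.util.concurrent.atomic.AtomicInteger

// Minimal sketch of the caching pattern used in the diff above:
// compute a value at most once, under the same lock that guards reset().
// `Holder`, `compute`, and `HolderDemo` are hypothetical, not Spark APIs.
final class Holder[T <: AnyRef](compute: () => T) {
  private var cached: T = _

  // Readers convert lazily; later calls reuse the cached value.
  def get(): T = this.synchronized {
    if (cached == null) {
      cached = compute()
    }
    cached
  }

  // Clearing the cache happens under the same lock, so a concurrent get()
  // observes either the previously cached value or a freshly computed one.
  def reset(): Unit = this.synchronized {
    cached = null.asInstanceOf[T]
  }
}

object HolderDemo {
  def main(args: Array[String]): Unit = {
    val conversions = new AtomicInteger(0)
    val holder = new Holder[String](() => { conversions.incrementAndGet(); "converted" })
    val threads = (1 to 4).map(_ => new Thread(() => { holder.get(); () }))
    threads.foreach(_.start())
    threads.foreach(_.join())
    println(s"conversions = ${conversions.get()}")  // prints 1: one conversion shared by all threads
  }
}

In the diff the lock is the accumulator itself, which lets collectedMetrics and resetMetrics() share one critical section and makes the separate observedMetricsLock in QueryExecution redundant, as the next file shows.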
8 changes: 1 addition & 7 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution
 import java.io.{BufferedWriter, OutputStreamWriter}
 import java.util.UUID
 import java.util.concurrent.atomic.AtomicLong
-import javax.annotation.concurrent.GuardedBy
 
 import scala.util.control.NonFatal
 
@@ -298,13 +297,8 @@ class QueryExecution(
    */
   def toRdd: RDD[InternalRow] = lazyToRdd.get
 
-  private val observedMetricsLock = new Object
-
   /** Get the metrics observed during the execution of the query plan. */
-  @GuardedBy("observedMetricsLock")
-  def observedMetrics: Map[String, Row] = observedMetricsLock.synchronized {
-    CollectMetricsExec.collect(executedPlan)
-  }
+  def observedMetrics: Map[String, Row] = CollectMetricsExec.collect(executedPlan)
 
   protected def preparations: Seq[Rule[SparkPlan]] = {
     QueryExecution.preparations(sparkSession,
28 changes: 27 additions & 1 deletion sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -19,9 +19,12 @@ package org.apache.spark.sql
 
 import java.io.{Externalizable, ObjectInput, ObjectOutput}
 import java.sql.{Date, Timestamp}
+import java.util.concurrent.Executors
 
 import scala.collection.immutable.HashSet
 import scala.collection.mutable
+import scala.concurrent.{ExecutionContext, Future}
+import scala.concurrent.duration.Duration
 import scala.jdk.CollectionConverters._
 import scala.reflect.ClassTag
 import scala.util.Random
@@ -44,7 +47,7 @@ import org.apache.spark.sql.catalyst.expressions.{CodegenObjectFactoryMode, Expr
 import org.apache.spark.sql.catalyst.plans.JoinType
 import org.apache.spark.sql.catalyst.trees.DataFrameQueryContext
 import org.apache.spark.sql.catalyst.util.sideBySide
-import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SQLExecution}
+import org.apache.spark.sql.execution.{CollectMetricsExec, LogicalRDD, RDDScanExec, SQLExecution}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
@@ -55,6 +58,7 @@ import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.ArrayImplicits._
+import org.apache.spark.util.SparkThreadUtils
 
 case class TestDataPoint(x: Int, y: Double, s: String, t: TestDataPoint2)
 case class TestDataPoint2(x: Int, s: String)
@@ -1075,6 +1079,28 @@ class DatasetSuite extends QueryTest
     assert(namedObservation2.get === expected2)
   }
 
+  test("SPARK-54353: concurrent CollectMetricsExec.collect()") {
+    val df = spark
+      .range(10)
+      .observe(Observation("result"), map(lit("count"), count(lit(1))))
+    df.collect()
+    val threadPool = Executors.newFixedThreadPool(2)
+    val executionContext = ExecutionContext.fromExecutorService(threadPool)
+    try {
+      Seq(
+        Future { CollectMetricsExec.collect(df.queryExecution.executedPlan) }(executionContext),
+        Future { CollectMetricsExec.collect(df.queryExecution.executedPlan) }(executionContext)
+      ).foreach { future =>
+        val result = SparkThreadUtils.awaitResult(future, Duration.Inf)
+        assert(result.size === 1)
+        assert(result.get("result") === Some(Row(Map("count" -> 10))))
+      }
+    } finally {
+      executionContext.shutdown()
+      threadPool.shutdown()
+    }
+  }
+
   test("sample with replacement") {
     val n = 100
     val data = sparkContext.parallelize(1 to n, 2).toDS()