
Commit 029dfff

[SPARK-54353][SQL] Make CollectMetricsExec.collectedMetrics thread-safe
1 parent: fc49dbd

3 files changed (+47 −10 lines)


sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala

Lines changed: 19 additions & 2 deletions
```diff
@@ -16,6 +16,8 @@
  */
 package org.apache.spark.sql.execution
 
+import javax.annotation.concurrent.GuardedBy
+
 import org.apache.spark.{InternalAccumulator, TaskContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Row
@@ -42,6 +44,11 @@ case class CollectMetricsExec(
     acc
   }
 
+  private val collectedMetricsLock = new Object()
+
+  @GuardedBy("collectedMetricsLock")
+  private var collectedMetricsRow: Row = _
+
   private[sql] def accumulatorId: Long = {
     accumulator.id
   }
@@ -56,7 +63,14 @@ case class CollectMetricsExec(
       .asInstanceOf[InternalRow => Row]
   }
 
-  def collectedMetrics: Row = toRowConverter(accumulator.value)
+  def collectedMetrics: Row = {
+    collectedMetricsLock.synchronized {
+      if (collectedMetricsRow == null) {
+        collectedMetricsRow = toRowConverter(accumulator.value)
+      }
+      collectedMetricsRow
+    }
+  }
 
   override def output: Seq[Attribute] = child.output
 
@@ -65,7 +79,10 @@
   override def outputOrdering: Seq[SortOrder] = child.outputOrdering
 
   override def resetMetrics(): Unit = {
-    accumulator.reset()
+    collectedMetricsLock.synchronized {
+      accumulator.reset()
+      collectedMetricsRow = null
+    }
     super.resetMetrics()
   }
```
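The change above replaces the direct accumulator-to-Row conversion with a lock-guarded, lazily cached result: concurrent readers of collectedMetrics now see a single consistent Row, and resetMetrics clears the accumulator and the cache under the same lock. A minimal standalone sketch of that idiom (the CachedResult name and its methods are illustrative, not part of the patch):

```scala
// Lock-guarded lazy cache: compute once under a lock, reuse until reset.
// Illustrative only; CachedResult is not a Spark class.
class CachedResult[T](compute: () => T) {
  private val lock = new Object()
  private var cached: Option[T] = None   // guarded by `lock`

  def get: T = lock.synchronized {
    if (cached.isEmpty) {
      cached = Some(compute())           // materialize once; later readers reuse it
    }
    cached.get
  }

  def reset(): Unit = lock.synchronized {
    cached = None                        // next get() recomputes from the source
  }
}
```

The patch itself uses a nullable Row field rather than an Option, but the structure is the same: every read and the reset path go through the same monitor, so a reader can never observe a half-built conversion.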

sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala

Lines changed: 1 addition & 7 deletions
```diff
@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution
 import java.io.{BufferedWriter, OutputStreamWriter}
 import java.util.UUID
 import java.util.concurrent.atomic.AtomicLong
-import javax.annotation.concurrent.GuardedBy
 
 import scala.util.control.NonFatal
 
@@ -298,13 +297,8 @@
    */
   def toRdd: RDD[InternalRow] = lazyToRdd.get
 
-  private val observedMetricsLock = new Object
-
   /** Get the metrics observed during the execution of the query plan. */
-  @GuardedBy("observedMetricsLock")
-  def observedMetrics: Map[String, Row] = observedMetricsLock.synchronized {
-    CollectMetricsExec.collect(executedPlan)
-  }
+  def observedMetrics: Map[String, Row] = CollectMetricsExec.collect(executedPlan)
 
   protected def preparations: Seq[Rule[SparkPlan]] = {
     QueryExecution.preparations(sparkSession,
```
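With the conversion synchronized inside CollectMetricsExec, QueryExecution.observedMetrics no longer needs its own lock. For context, observed metrics are usually consumed from user code through the Observation API, which ultimately reads the rows that CollectMetricsExec collects from the executed plan; a rough usage sketch, assuming a local SparkSession (not part of the patch):

```scala
import org.apache.spark.sql.{Observation, SparkSession}
import org.apache.spark.sql.functions._

// Sketch only: a local session and a single observed aggregate.
val spark = SparkSession.builder().master("local[2]").appName("observe-sketch").getOrCreate()

val observation = Observation("result")
val df = spark.range(10).observe(observation, count(lit(1)).as("rows"))
df.collect()                 // running an action populates the observed metrics

// Observation.get blocks until the metrics are available; the values come from
// the Row(s) gathered by CollectMetricsExec from the executed plan.
println(observation.get)     // e.g. Map("rows" -> 10)

spark.stop()
```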

sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

Lines changed: 27 additions & 1 deletion
```diff
@@ -19,9 +19,12 @@ package org.apache.spark.sql
 
 import java.io.{Externalizable, ObjectInput, ObjectOutput}
 import java.sql.{Date, Timestamp}
+import java.util.concurrent.Executors
 
 import scala.collection.immutable.HashSet
 import scala.collection.mutable
+import scala.concurrent.{ExecutionContext, Future}
+import scala.concurrent.duration.Duration
 import scala.jdk.CollectionConverters._
 import scala.reflect.ClassTag
 import scala.util.Random
@@ -44,7 +47,7 @@ import org.apache.spark.sql.catalyst.expressions.{CodegenObjectFactoryMode, Expr
 import org.apache.spark.sql.catalyst.plans.JoinType
 import org.apache.spark.sql.catalyst.trees.DataFrameQueryContext
 import org.apache.spark.sql.catalyst.util.sideBySide
-import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SQLExecution}
+import org.apache.spark.sql.execution.{CollectMetricsExec, LogicalRDD, RDDScanExec, SQLExecution}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
@@ -55,6 +58,7 @@ import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.ArrayImplicits._
+import org.apache.spark.util.SparkThreadUtils
 
 case class TestDataPoint(x: Int, y: Double, s: String, t: TestDataPoint2)
 case class TestDataPoint2(x: Int, s: String)
@@ -1075,6 +1079,28 @@ class DatasetSuite extends QueryTest
     assert(namedObservation2.get === expected2)
   }
 
+  test("SPARK-54353: concurrent CollectMetricsExec.collect()") {
+    val df = spark
+      .range(10)
+      .observe(Observation("result"), map(lit("count"), count(lit(1))))
+    df.collect()
+    val threadPool = Executors.newFixedThreadPool(2)
+    val executionContext = ExecutionContext.fromExecutorService(threadPool)
+    try {
+      Seq(
+        Future { CollectMetricsExec.collect(df.queryExecution.executedPlan) }(executionContext),
+        Future { CollectMetricsExec.collect(df.queryExecution.executedPlan) }(executionContext)
+      ).foreach { future =>
+        val result = SparkThreadUtils.awaitResult(future, Duration.Inf)
+        assert(result.size === 1)
+        assert(result.get("result") === Some(Row(Map("count" -> 10))))
+      }
+    } finally {
+      executionContext.shutdown()
+      threadPool.shutdown()
+    }
+  }
+
   test("sample with replacement") {
     val n = 100
     val data = sparkContext.parallelize(1 to n, 2).toDS()
```
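The new test races two futures that both call CollectMetricsExec.collect on the same executed plan, which is exactly the access pattern the lock protects. The same harness generalizes to racing any number of readers against a shared, lazily computed value; a small sketch under those assumptions (raceReaders is an illustrative helper, not part of the test suite):

```scala
import java.util.concurrent.Executors
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration

// Illustrative helper: run `read` concurrently on `readers` threads and
// return every result so the caller can check that they all agree.
def raceReaders[T](readers: Int)(read: => T): Seq[T] = {
  val pool = Executors.newFixedThreadPool(readers)
  implicit val ec: ExecutionContext = ExecutionContext.fromExecutorService(pool)
  try {
    val futures = Seq.fill(readers)(Future(read))
    futures.map(f => Await.result(f, Duration.Inf))
  } finally {
    pool.shutdown()
  }
}

// Example shape (sharedCache is hypothetical):
// val results = raceReaders(4)(sharedCache.get)
// assert(results.distinct.size == 1)   // all readers saw the same value
```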
