diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index 98da33a429ec..93fb64f485f6 100755
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -211,6 +211,8 @@ private[deploy] class Worker(
 
   private var registerMasterFutures: Array[JFuture[_]] = null
   private var registrationRetryTimer: Option[JScheduledFuture[_]] = None
+  private var heartbeatTask: Option[JScheduledFuture[_]] = None
+  private var workDirCleanupTask: Option[JScheduledFuture[_]] = None
 
   // A thread pool for registering with masters. Because registering with a master is a blocking
   // action, this thread pool must be able to create "masterRpcAddresses.size" threads at the same
@@ -492,16 +494,25 @@ private[deploy] class Worker(
       logInfo(log"Successfully registered with master ${MDC(MASTER_URL, preferredMasterAddress)}")
       registered = true
       changeMaster(masterRef, masterWebUiUrl, masterAddress)
-      forwardMessageScheduler.scheduleAtFixedRate(
-        () => Utils.tryLogNonFatalError { self.send(SendHeartbeat) },
-        0, HEARTBEAT_MILLIS, TimeUnit.MILLISECONDS)
-      if (CLEANUP_ENABLED) {
+
+      // Only schedule heartbeat task if not already scheduled. The existing task will
+      // continue running through reconnections, and the SendHeartbeat handler already
+      // checks the 'connected' flag before sending heartbeats to master.
+      if (heartbeatTask.isEmpty) {
+        heartbeatTask = Some(forwardMessageScheduler.scheduleAtFixedRate(
+          () => Utils.tryLogNonFatalError {
+            self.send(SendHeartbeat)
+          },
+          0, HEARTBEAT_MILLIS, TimeUnit.MILLISECONDS))
+      }
+      // Only schedule work directory cleanup task if not already scheduled
+      if (CLEANUP_ENABLED && workDirCleanupTask.isEmpty) {
         logInfo(
           log"Worker cleanup enabled; old application directories will be deleted in: " +
             log"${MDC(PATH, workDir)}")
-        forwardMessageScheduler.scheduleAtFixedRate(
+        workDirCleanupTask = Some(forwardMessageScheduler.scheduleAtFixedRate(
           () => Utils.tryLogNonFatalError { self.send(WorkDirCleanup) },
-          CLEANUP_INTERVAL_MILLIS, CLEANUP_INTERVAL_MILLIS, TimeUnit.MILLISECONDS)
+          CLEANUP_INTERVAL_MILLIS, CLEANUP_INTERVAL_MILLIS, TimeUnit.MILLISECONDS))
       }
 
       val execs = executors.values.map { e =>
@@ -852,6 +863,10 @@ private[deploy] class Worker(
     cleanupThreadExecutor.shutdownNow()
     metricsSystem.report()
     cancelLastRegistrationRetry()
+    heartbeatTask.foreach(_.cancel(true))
+    heartbeatTask = None
+    workDirCleanupTask.foreach(_.cancel(true))
+    workDirCleanupTask = None
     forwardMessageScheduler.shutdownNow()
     registerMasterThreadPool.shutdownNow()
     executors.values.foreach(_.kill())
diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala
index ff5d314d1688..f9a0efce8870 100644
--- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.deploy.worker
 
 import java.io.{File, IOException}
+import java.util.concurrent.{ScheduledFuture => JScheduledFuture}
 import java.util.concurrent.atomic.AtomicBoolean
 import java.util.function.Supplier
 
@@ -37,7 +38,7 @@ import org.scalatest.matchers.should.Matchers._
 import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
 import org.apache.spark.TestUtils.{createTempJsonFile, createTempScriptWithExpectedOutput}
 import org.apache.spark.deploy.{Command, ExecutorState, ExternalShuffleService}
-import org.apache.spark.deploy.DeployMessages.{DriverStateChanged, ExecutorStateChanged, WorkDirCleanup}
+import org.apache.spark.deploy.DeployMessages.{DriverStateChanged, ExecutorStateChanged, RegisteredWorker, WorkDirCleanup}
 import org.apache.spark.deploy.master.DriverState
 import org.apache.spark.internal.config
 import org.apache.spark.internal.config.SHUFFLE_SERVICE_DB_BACKEND
@@ -46,7 +47,7 @@ import org.apache.spark.network.shuffledb.DBBackend
 import org.apache.spark.resource.{ResourceAllocation, ResourceInformation}
 import org.apache.spark.resource.ResourceUtils._
 import org.apache.spark.resource.TestResourceIDs.{WORKER_FPGA_ID, WORKER_GPU_ID}
-import org.apache.spark.rpc.{RpcAddress, RpcEnv}
+import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv}
 import org.apache.spark.util.Utils
 
 class WorkerSuite extends SparkFunSuite with Matchers with BeforeAndAfter with PrivateMethodTester {
@@ -405,4 +406,41 @@ class WorkerSuite extends SparkFunSuite with Matchers with BeforeAndAfter with P
     }.getMessage
     assert(m.contains("Whitespace is not allowed"))
   }
+
+  test("SPARK-54312: heartbeat task and workdir cleanup task should only be scheduled once " +
+    "across multiple registrations") {
+    val worker = spy(makeWorker())
+    val masterWebUiUrl = "https://1.2.3.4:8080"
+    val masterAddress = RpcAddress("1.2.3.4", 1234)
+    val masterRef = mock(classOf[RpcEndpointRef])
+    when(masterRef.address).thenReturn(masterAddress)
+
+    def getHeartbeatTask(worker: Worker): Option[JScheduledFuture[_]] = {
+      val _heartbeatTask =
+        PrivateMethod[Option[JScheduledFuture[_]]](Symbol("heartbeatTask"))
+      worker.invokePrivate(_heartbeatTask())
+    }
+
+    def getWorkDirCleanupTask(worker: Worker): Option[JScheduledFuture[_]] = {
+      val _workDirCleanupTask =
+        PrivateMethod[Option[JScheduledFuture[_]]](Symbol("workDirCleanupTask"))
+      worker.invokePrivate(_workDirCleanupTask())
+    }
+
+    // Tasks should not be scheduled yet before registration
+    assert(getHeartbeatTask(worker).isEmpty && getWorkDirCleanupTask(worker).isEmpty)
+
+    val msg = RegisteredWorker(masterRef, masterWebUiUrl, masterAddress, duplicate = false)
+    // Simulate first registration - this should schedule both tasks
+    worker.receive(msg)
+    val heartbeatTask = getHeartbeatTask(worker)
+    val workDirCleanupTask = getWorkDirCleanupTask(worker)
+    assert(heartbeatTask.isDefined && workDirCleanupTask.isDefined)
+
+    // Simulate disconnection and re-registration
+    worker.receive(msg)
+    // After re-registration, the task references should be the same (not rescheduled)
+    assert(getHeartbeatTask(worker) == heartbeatTask)
+    assert(getWorkDirCleanupTask(worker) == workDirCleanupTask)
+  }
 }
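
Both hunks apply the same idempotent-scheduling idiom: keep the ScheduledFuture handle, schedule only while the handle is empty, and cancel and clear the handle on stop. A minimal self-contained Scala sketch of that idiom, using only java.util.concurrent; the names ScheduleOnceDemo, onRegistered, and onStop are illustrative and not part of the Spark codebase:

import java.util.concurrent.{Executors, ScheduledFuture, TimeUnit}

object ScheduleOnceDemo {
  private val scheduler = Executors.newSingleThreadScheduledExecutor()
  // Handle to the periodic task; None until the first registration schedules it.
  private var heartbeatTask: Option[ScheduledFuture[_]] = None

  // Safe to call on every (re-)registration: schedules the task only once,
  // so reconnections do not stack duplicate periodic tasks.
  def onRegistered(periodMillis: Long)(sendHeartbeat: => Unit): Unit = synchronized {
    if (heartbeatTask.isEmpty) {
      val runnable = new Runnable { def run(): Unit = sendHeartbeat }
      heartbeatTask = Some(
        scheduler.scheduleAtFixedRate(runnable, 0, periodMillis, TimeUnit.MILLISECONDS))
    }
  }

  // Called once on shutdown: cancel the task and drop the handle.
  def onStop(): Unit = synchronized {
    heartbeatTask.foreach(_.cancel(true))
    heartbeatTask = None
    scheduler.shutdownNow()
  }

  def main(args: Array[String]): Unit = {
    onRegistered(100)(println("heartbeat"))
    onRegistered(100)(println("heartbeat")) // no-op: already scheduled
    Thread.sleep(350)
    onStop()
  }
}

As in the patch, cancel(true) interrupts an in-flight run, and resetting the Option to None would keep the isEmpty guard meaningful if the owning component were ever restarted.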