diff --git a/client/allocrunner/alloc_runner_hooks.go b/client/allocrunner/alloc_runner_hooks.go index abc2898b08b..c04d8f4b9c5 100644 --- a/client/allocrunner/alloc_runner_hooks.go +++ b/client/allocrunner/alloc_runner_hooks.go @@ -131,7 +131,6 @@ func (ar *allocRunner) initRunnerHooks(config *clientconfig.Config) error { providerNamespace: alloc.ServiceProviderNamespace(), serviceRegWrapper: ar.serviceRegWrapper, hookResources: ar.hookResources, - restarter: ar, networkStatus: ar, logger: hookLogger, shutdownDelayCtx: ar.shutdownDelayCtx, diff --git a/client/allocrunner/group_service_hook.go b/client/allocrunner/group_service_hook.go index 89b8e04ec93..8e281002d46 100644 --- a/client/allocrunner/group_service_hook.go +++ b/client/allocrunner/group_service_hook.go @@ -31,7 +31,6 @@ type groupServiceHook struct { group string tg *structs.TaskGroup namespace string - restarter serviceregistration.WorkloadRestarter prerun bool deregistered bool networkStatus structs.NetworkStatus @@ -63,7 +62,6 @@ type groupServiceHook struct { type groupServiceHookConfig struct { alloc *structs.Allocation - restarter serviceregistration.WorkloadRestarter networkStatus structs.NetworkStatus shutdownDelayCtx context.Context logger hclog.Logger @@ -92,7 +90,6 @@ func newGroupServiceHook(cfg groupServiceHookConfig) *groupServiceHook { jobID: cfg.alloc.JobID, group: cfg.alloc.TaskGroup, namespace: cfg.alloc.Namespace, - restarter: cfg.restarter, providerNamespace: cfg.providerNamespace, delay: shutdownDelay, networkStatus: cfg.networkStatus, @@ -311,7 +308,6 @@ func (h *groupServiceHook) getWorkloadServicesLocked() *serviceregistration.Work return &serviceregistration.WorkloadServices{ AllocInfo: info, ProviderNamespace: h.providerNamespace, - Restarter: h.restarter, Services: h.services, Networks: h.networks, NetworkStatus: netStatus, diff --git a/client/allocrunner/group_service_hook_test.go b/client/allocrunner/group_service_hook_test.go index 630bdf75a38..02a9373d6c6 100644 --- a/client/allocrunner/group_service_hook_test.go +++ b/client/allocrunner/group_service_hook_test.go @@ -14,7 +14,6 @@ import ( "github.com/hashicorp/nomad/client/serviceregistration/wrapper" cstructs "github.com/hashicorp/nomad/client/structs" "github.com/hashicorp/nomad/client/taskenv" - agentconsul "github.com/hashicorp/nomad/command/agent/consul" "github.com/hashicorp/nomad/helper/pointer" "github.com/hashicorp/nomad/helper/testlog" "github.com/hashicorp/nomad/nomad/mock" @@ -46,7 +45,6 @@ func TestGroupServiceHook_NoGroupServices(t *testing.T) { h := newGroupServiceHook(groupServiceHookConfig{ alloc: alloc, serviceRegWrapper: regWrapper, - restarter: agentconsul.NoopRestarter(), logger: logger, hookResources: cstructs.NewAllocHookResources(), }) @@ -88,7 +86,6 @@ func TestGroupServiceHook_ShutdownDelayUpdate(t *testing.T) { h := newGroupServiceHook(groupServiceHookConfig{ alloc: alloc, serviceRegWrapper: regWrapper, - restarter: agentconsul.NoopRestarter(), logger: logger, hookResources: cstructs.NewAllocHookResources(), }) @@ -130,7 +127,6 @@ func TestGroupServiceHook_GroupServices(t *testing.T) { h := newGroupServiceHook(groupServiceHookConfig{ alloc: alloc, serviceRegWrapper: regWrapper, - restarter: agentconsul.NoopRestarter(), logger: logger, hookResources: cstructs.NewAllocHookResources(), }) @@ -199,7 +195,6 @@ func TestGroupServiceHook_GroupServicesCheckUpdates(t *testing.T) { h := newGroupServiceHook(groupServiceHookConfig{ alloc: alloc, serviceRegWrapper: regWrapper, - restarter: agentconsul.NoopRestarter(), logger: 
logger, hookResources: resources, }) @@ -254,7 +249,6 @@ func TestGroupServiceHook_GroupServices_Nomad(t *testing.T) { h := newGroupServiceHook(groupServiceHookConfig{ alloc: alloc, serviceRegWrapper: regWrapper, - restarter: agentconsul.NoopRestarter(), logger: logger, hookResources: cstructs.NewAllocHookResources(), }) @@ -309,7 +303,6 @@ func TestGroupServiceHook_NoNetwork(t *testing.T) { h := newGroupServiceHook(groupServiceHookConfig{ alloc: alloc, serviceRegWrapper: regWrapper, - restarter: agentconsul.NoopRestarter(), logger: logger, hookResources: cstructs.NewAllocHookResources(), }) @@ -357,7 +350,6 @@ func TestGroupServiceHook_getWorkloadServices(t *testing.T) { h := newGroupServiceHook(groupServiceHookConfig{ alloc: alloc, serviceRegWrapper: regWrapper, - restarter: agentconsul.NoopRestarter(), logger: logger, hookResources: cstructs.NewAllocHookResources(), }) @@ -400,7 +392,6 @@ func TestGroupServiceHook_PreKill(t *testing.T) { alloc: alloc, serviceRegWrapper: regWrapper, shutdownDelayCtx: shutDownCtx, - restarter: agentconsul.NoopRestarter(), logger: logger, hookResources: cstructs.NewAllocHookResources(), }) @@ -448,7 +439,6 @@ func TestGroupServiceHook_PreKill(t *testing.T) { h := newGroupServiceHook(groupServiceHookConfig{ alloc: alloc, serviceRegWrapper: regWrapper, - restarter: agentconsul.NoopRestarter(), logger: logger, hookResources: cstructs.NewAllocHookResources(), }) @@ -500,7 +490,6 @@ func TestGroupServiceHook_PreKill(t *testing.T) { h := newGroupServiceHook(groupServiceHookConfig{ alloc: alloc, serviceRegWrapper: regWrapper, - restarter: agentconsul.NoopRestarter(), logger: logger, hookResources: cstructs.NewAllocHookResources(), }) diff --git a/client/allocrunner/taskrunner/check_restart_hook.go b/client/allocrunner/taskrunner/check_restart_hook.go new file mode 100644 index 00000000000..c6f64e18437 --- /dev/null +++ b/client/allocrunner/taskrunner/check_restart_hook.go @@ -0,0 +1,135 @@ +// Copyright IBM Corp. 2015, 2025 +// SPDX-License-Identifier: BUSL-1.1 + +package taskrunner + +import ( + "context" + "fmt" + + "github.com/hashicorp/nomad/client/allocrunner/interfaces" + "github.com/hashicorp/nomad/client/serviceregistration" + "github.com/hashicorp/nomad/client/serviceregistration/wrapper" + "github.com/hashicorp/nomad/client/taskenv" + "github.com/hashicorp/nomad/command/agent/consul" + "github.com/hashicorp/nomad/nomad/structs" +) + +type hookCheck struct { + // providerType identifies the backend provider performing the check. + // This is currently either Nomad or Consul. + providerType string + + // providerNS is the provider namespace the service is registered in. + // When a provider implements namespaces (e.g. Consul, where this is the + // cluster name), Nomad runs a single check watcher per namespace. + providerNS string + + // checkID is the ID of the check used to register a check_restart watch + checkID string + + // check is the actual Nomad service check configuration + check *structs.ServiceCheck +} + +// The checkRestartHook is responsible for registering/deregistering _both_ group and task +// check_restart blocks with the appropriate CheckWatcher. This is a standalone hook and not part +// of the service hook because restarts triggered by checks are task specific, even though check_restart +// can be defined at the group level. Therefore this hook looks at both task group and task services.
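+// In practice the lifecycle is: Prestart interpolates the group and task services, builds a +// provider-specific ID for each check that defines a check_restart block, and registers a watch +// with the matching CheckWatcher; Exited and Stop then unwatch every registered check.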
+type checkRestartHook struct { + checks []*hookCheck + handler *wrapper.HandlerWrapper + wr serviceregistration.WorkloadRestarter + allocID string + taskName string + taskEnv *taskenv.TaskEnv + + tgName string + tgServices []*structs.Service + taskServices []*structs.Service +} + +func newCheckRestartHook(alloc *structs.Allocation, task *structs.Task, handler *wrapper.HandlerWrapper, restarter serviceregistration.WorkloadRestarter) *checkRestartHook { + tg := alloc.Job.LookupTaskGroup(alloc.TaskGroup) + return &checkRestartHook{ + handler: handler, + allocID: alloc.ID, + taskName: task.Name, + tgName: tg.Name, + tgServices: tg.Services, + taskServices: task.Services, + wr: restarter, + } +} + +func (h *checkRestartHook) Name() string { + return "check_restart" +} + +func (h *checkRestartHook) Prestart(ctx context.Context, req *interfaces.TaskPrestartRequest, _ *interfaces.TaskPrestartResponse) error { + var checks []*hookCheck + for _, s := range taskenv.InterpolateServices(req.TaskEnv, h.tgServices) { + for _, c := range s.Checks { + if c.TriggersRestarts() && c.TaskName == h.taskName { + checks = append(checks, &hookCheck{ + providerType: s.Provider, + providerNS: s.Cluster, + check: c, + // TODO: does this work for Nomad? + checkID: checkID(h.allocID, h.tgName, fmt.Sprintf("group-%s", h.tgName), s.Provider, c, s), + }) + } + } + } + + for _, s := range taskenv.InterpolateServices(req.TaskEnv, h.taskServices) { + for _, c := range s.Checks { + if c.TriggersRestarts() { + checks = append(checks, &hookCheck{ + providerType: s.Provider, + providerNS: s.Cluster, + check: c, + // TODO: does this work for Nomad? + checkID: checkID(h.allocID, h.tgName, h.taskName, s.Provider, c, s), + }) + } + } + } + h.checks = checks + + for _, c := range h.checks { + watcher := h.handler.CheckWatcher(c.providerType, c.providerNS) + watcher.Watch(c.checkID, c.check, h.wr) + } + return nil +} + +func (h *checkRestartHook) Exited(ctx context.Context, req *interfaces.TaskExitedRequest, resp *interfaces.TaskExitedResponse) error { + for _, c := range h.checks { + watcher := h.handler.CheckWatcher(c.providerType, c.providerNS) + watcher.Unwatch(c.checkID) + } + return nil +} + +func (h *checkRestartHook) Stop(ctx context.Context, req *interfaces.TaskStopRequest, resp *interfaces.TaskStopResponse) error { + for _, c := range h.checks { + watcher := h.handler.CheckWatcher(c.providerType, c.providerNS) + watcher.Unwatch(c.checkID) + } + return nil +} + +// checkID returns a provider-specific checkID for the workload. Unfortunately Nomad and Consul use different +// methods for creating checkIDs, so these are quite different. Consul distinguishes between group and task +// checkIDs, but Nomad seems to always use the task group name. +func checkID(allocID, tg, task, ptype string, check *structs.ServiceCheck, service *structs.Service) string { + switch ptype { + case "nomad": + return string(structs.NomadCheckID(allocID, tg, check)) + case "consul": + return consul.MakeCheckID(serviceregistration.MakeAllocServiceID(allocID, task, service), check) + default: + return "" + } +} diff --git a/client/allocrunner/taskrunner/check_restart_hook_test.go b/client/allocrunner/taskrunner/check_restart_hook_test.go new file mode 100644 index 00000000000..04d71e36f59 --- /dev/null +++ b/client/allocrunner/taskrunner/check_restart_hook_test.go @@ -0,0 +1,123 @@ +// Copyright IBM Corp.
2015, 2025 +// SPDX-License-Identifier: BUSL-1.1 + +package taskrunner + +import ( + "testing" + + "github.com/hashicorp/nomad/client/allocrunner/interfaces" + regMock "github.com/hashicorp/nomad/client/serviceregistration/mock" + "github.com/hashicorp/nomad/client/serviceregistration/wrapper" + "github.com/hashicorp/nomad/client/taskenv" + "github.com/hashicorp/nomad/command/agent/consul" + "github.com/hashicorp/nomad/helper/testlog" + "github.com/hashicorp/nomad/nomad/mock" + "github.com/hashicorp/nomad/nomad/structs" + "github.com/shoenig/test/must" + tmock "github.com/stretchr/testify/mock" +) + +func TestCheckRestartHook_Prestart(t *testing.T) { + logger := testlog.HCLogger(t) + + alloc := mock.Alloc() + alloc.Job.Canonicalize() + + handler := regMock.NewServiceRegistrationHandler(logger) + + mockWatcher := &regMock.MockUniversalWatcher{} + mockWatcher.On("Watch", tmock.Anything, tmock.Anything, tmock.Anything) + handler.UniversalWatcher = mockWatcher + + regWrap := wrapper.NewHandlerWrapper(logger, handler, handler) + + service := &structs.Service{ + Provider: "nomad", + Checks: []*structs.ServiceCheck{ + { + TaskName: "web", + CheckRestart: &structs.CheckRestart{ + Limit: 1, + }, + }, + }, + } + + // group level service + alloc.Job.LookupTaskGroup("web").Services = []*structs.Service{service} + + // task level service + task := alloc.LookupTask("web") + task.Services = []*structs.Service{service} + + testHook := newCheckRestartHook(alloc, task, regWrap, consul.NoopRestarter()) + + err := testHook.Prestart( + t.Context(), + &interfaces.TaskPrestartRequest{TaskEnv: taskenv.NewEmptyTaskEnv()}, + &interfaces.TaskPrestartResponse{}, + ) + must.NoError(t, err) + must.True(t, mockWatcher.AssertNumberOfCalls(t, "Watch", 2)) +} + +func TestCheckRestartHook_Exited(t *testing.T) { + logger := testlog.HCLogger(t) + + alloc := mock.Alloc() + alloc.Job.Canonicalize() + + mockWatcher := &regMock.MockUniversalWatcher{} + mockWatcher.On("Unwatch", tmock.Anything) + + handler := regMock.NewServiceRegistrationHandler(logger) + handler.UniversalWatcher = mockWatcher + + regWrap := wrapper.NewHandlerWrapper(logger, handler, handler) + + testHook := newCheckRestartHook(alloc, alloc.LookupTask("web"), regWrap, consul.NoopRestarter()) + testHook.checks = []*hookCheck{ + { + providerType: "nomad", + }, + } + + err := testHook.Exited( + t.Context(), + &interfaces.TaskExitedRequest{}, + &interfaces.TaskExitedResponse{}, + ) + must.NoError(t, err) + must.True(t, mockWatcher.AssertNumberOfCalls(t, "Unwatch", 1)) +} + +func TestCheckRestartHook_Stop(t *testing.T) { + logger := testlog.HCLogger(t) + + alloc := mock.Alloc() + alloc.Job.Canonicalize() + + mockWatcher := &regMock.MockUniversalWatcher{} + mockWatcher.On("Unwatch", tmock.Anything) + + handler := regMock.NewServiceRegistrationHandler(logger) + handler.UniversalWatcher = mockWatcher + + regWrap := wrapper.NewHandlerWrapper(logger, handler, handler) + + testHook := newCheckRestartHook(alloc, alloc.LookupTask("web"), regWrap, consul.NoopRestarter()) + testHook.checks = []*hookCheck{ + { + providerType: "nomad", + }, + } + + err := testHook.Stop( + t.Context(), + &interfaces.TaskStopRequest{}, + &interfaces.TaskStopResponse{}, + ) + must.NoError(t, err) + must.True(t, mockWatcher.AssertNumberOfCalls(t, "Unwatch", 1)) +} diff --git a/client/allocrunner/taskrunner/connect_native_hook_test.go b/client/allocrunner/taskrunner/connect_native_hook_test.go index 2020b5d2db1..45b23c3e304 100644 --- a/client/allocrunner/taskrunner/connect_native_hook_test.go +++
b/client/allocrunner/taskrunner/connect_native_hook_test.go @@ -332,7 +332,7 @@ func TestTaskRunner_ConnectNativeHook_Ok(t *testing.T) { consulClient := agentconsul.NewServiceClient(consulAPIClient.Agent(), namespacesClient, logger, true) go consulClient.Run() defer consulClient.Shutdown() - require.NoError(t, consulClient.RegisterWorkload(agentconsul.BuildAllocServices(mock.Node(), alloc, agentconsul.NoopRestarter()))) + require.NoError(t, consulClient.RegisterWorkload(agentconsul.BuildAllocServices(mock.Node(), alloc))) // Run Connect Native hook h := newConnectNativeHook(newConnectNativeHookConfig(alloc, &config.ConsulConfig{ @@ -394,7 +394,7 @@ func TestTaskRunner_ConnectNativeHook_with_SI_token(t *testing.T) { consulClient := agentconsul.NewServiceClient(consulAPIClient.Agent(), namespacesClient, logger, true) go consulClient.Run() defer consulClient.Shutdown() - require.NoError(t, consulClient.RegisterWorkload(agentconsul.BuildAllocServices(mock.Node(), alloc, agentconsul.NoopRestarter()))) + require.NoError(t, consulClient.RegisterWorkload(agentconsul.BuildAllocServices(mock.Node(), alloc))) // Run Connect Native hook h := newConnectNativeHook(newConnectNativeHookConfig(alloc, &config.ConsulConfig{ @@ -467,7 +467,7 @@ func TestTaskRunner_ConnectNativeHook_shareTLS(t *testing.T) { consulClient := agentconsul.NewServiceClient(consulAPIClient.Agent(), namespacesClient, logger, true) go consulClient.Run() defer consulClient.Shutdown() - require.NoError(t, consulClient.RegisterWorkload(agentconsul.BuildAllocServices(mock.Node(), alloc, agentconsul.NoopRestarter()))) + require.NoError(t, consulClient.RegisterWorkload(agentconsul.BuildAllocServices(mock.Node(), alloc))) // Run Connect Native hook h := newConnectNativeHook(newConnectNativeHookConfig(alloc, &config.ConsulConfig{ @@ -583,7 +583,7 @@ func TestTaskRunner_ConnectNativeHook_shareTLS_override(t *testing.T) { consulClient := agentconsul.NewServiceClient(consulAPIClient.Agent(), namespacesClient, logger, true) go consulClient.Run() defer consulClient.Shutdown() - require.NoError(t, consulClient.RegisterWorkload(agentconsul.BuildAllocServices(mock.Node(), alloc, agentconsul.NoopRestarter()))) + require.NoError(t, consulClient.RegisterWorkload(agentconsul.BuildAllocServices(mock.Node(), alloc))) // Run Connect Native hook h := newConnectNativeHook(newConnectNativeHookConfig(alloc, &config.ConsulConfig{ diff --git a/client/allocrunner/taskrunner/envoy_bootstrap_hook_test.go b/client/allocrunner/taskrunner/envoy_bootstrap_hook_test.go index c0242216d4c..621189ad0e5 100644 --- a/client/allocrunner/taskrunner/envoy_bootstrap_hook_test.go +++ b/client/allocrunner/taskrunner/envoy_bootstrap_hook_test.go @@ -351,7 +351,7 @@ func TestEnvoyBootstrapHook_with_SI_token(t *testing.T) { serviceClient := agentconsul.NewServiceClient(consulAPIClient.Agent(), namespacesClient, logger, true) go serviceClient.Run() defer serviceClient.Shutdown() - must.NoError(t, serviceClient.RegisterWorkload(agentconsul.BuildAllocServices(mock.Node(), alloc, agentconsul.NoopRestarter()))) + must.NoError(t, serviceClient.RegisterWorkload(agentconsul.BuildAllocServices(mock.Node(), alloc))) // Run Connect bootstrap Hook h := newEnvoyBootstrapHook(newEnvoyBootstrapHookConfig(alloc, &config.ConsulConfig{ @@ -449,7 +449,7 @@ func TestEnvoyBootstrapHook_sidecar_ok(t *testing.T) { serviceClient := agentconsul.NewServiceClient(consulAPIClient.Agent(), namespacesClient, logger, true) go serviceClient.Run() defer serviceClient.Shutdown() - require.NoError(t, 
serviceClient.RegisterWorkload(agentconsul.BuildAllocServices(mock.Node(), alloc, agentconsul.NoopRestarter()))) + require.NoError(t, serviceClient.RegisterWorkload(agentconsul.BuildAllocServices(mock.Node(), alloc))) // Run Connect bootstrap Hook h := newEnvoyBootstrapHook(newEnvoyBootstrapHookConfig(alloc, &config.ConsulConfig{ @@ -512,7 +512,7 @@ func TestEnvoyBootstrapHook_gateway_ok(t *testing.T) { serviceClient := agentconsul.NewServiceClient(consulAPIClient.Agent(), namespacesClient, logger, true) go serviceClient.Run() defer serviceClient.Shutdown() - require.NoError(t, serviceClient.RegisterWorkload(agentconsul.BuildAllocServices(mock.Node(), alloc, agentconsul.NoopRestarter()))) + require.NoError(t, serviceClient.RegisterWorkload(agentconsul.BuildAllocServices(mock.Node(), alloc))) // Register Configuration Entry ceClient := consulAPIClient.ConfigEntries() diff --git a/client/allocrunner/taskrunner/lifecycle.go b/client/allocrunner/taskrunner/lifecycle.go index 100e6da81bc..7eeeb03b659 100644 --- a/client/allocrunner/taskrunner/lifecycle.go +++ b/client/allocrunner/taskrunner/lifecycle.go @@ -79,12 +79,13 @@ func (tr *TaskRunner) restartImpl(ctx context.Context, event *structs.TaskEvent, return ErrTaskNotRunning } + if err := tr.restartTracker.SetRestartTriggered(failure); err != nil { + return err + } + // Emit the event since it may take a long time to kill tr.EmitEvent(event) - // Tell the restart tracker that a restart triggered the exit - tr.restartTracker.SetRestartTriggered(failure) - // Signal a restart to unblock tasks that are in the "dead" state, but // don't block since the channel is buffered. Only one signal is enough to // notify the tr.Run() loop. diff --git a/client/allocrunner/taskrunner/restarts/restarts.go b/client/allocrunner/taskrunner/restarts/restarts.go index 2a8234aa697..1665ed95b3a 100644 --- a/client/allocrunner/taskrunner/restarts/restarts.go +++ b/client/allocrunner/taskrunner/restarts/restarts.go @@ -4,6 +4,7 @@ package restarts import ( + "errors" "fmt" "math/rand" "sync" @@ -106,15 +107,22 @@ func (r *RestartTracker) SetExitResult(res *drivers.ExitResult) *RestartTracker // restarted. Setting the failure to true restarts according to the restart // policy. When failure is false the task is restarted without considering the // restart policy. -func (r *RestartTracker) SetRestartTriggered(failure bool) *RestartTracker { +// Returns an error if the task has already been set to restart due to failure. +func (r *RestartTracker) SetRestartTriggered(failure bool) error { r.lock.Lock() defer r.lock.Unlock() + + // err if the task was already marked as restarting for failure + if r.failure && failure { + return errors.New("task failure restart already triggered") + } + if failure { r.failure = true } else { r.restartTriggered = true } - return r + return nil } // SetKilled is used to mark that the task has been killed. 
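Illustrative sketch (not part of the diff) of the new SetRestartTriggered contract, assuming rt is a *RestartTracker from NewRestartTracker, as exercised in the tests below:

    // Trigger a restart that counts against the restart policy.
    if err := rt.SetRestartTriggered(true); err != nil {
        // a failure-triggered restart is already pending; the new
        // "task failure restart already triggered" error is returned
        return err
    }
    // GetState then reports the restart decision, e.g. TaskRestarting
    // with a policy-derived delay.
    state, when := rt.GetState()
    _, _ = state, when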
diff --git a/client/allocrunner/taskrunner/restarts/restarts_test.go b/client/allocrunner/taskrunner/restarts/restarts_test.go index b11c6d1f481..43cdfab6e92 100644 --- a/client/allocrunner/taskrunner/restarts/restarts_test.go +++ b/client/allocrunner/taskrunner/restarts/restarts_test.go @@ -11,6 +11,7 @@ import ( "github.com/hashicorp/nomad/ci" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/plugins/drivers" + "github.com/shoenig/test/must" "github.com/stretchr/testify/require" ) @@ -139,7 +140,8 @@ func TestClient_RestartTracker_RestartTriggered(t *testing.T) { p := testPolicy(true, structs.RestartPolicyModeFail) p.Attempts = 0 rt := NewRestartTracker(p, structs.JobTypeService, nil) - if state, when := rt.SetRestartTriggered(false).GetState(); state != structs.TaskRestarting && when != 0 { + must.NoError(t, rt.SetRestartTriggered(false)) + if state, when := rt.GetState(); state != structs.TaskRestarting && when != 0 { t.Fatalf("expect restart immediately, got %v %v", state, when) } } @@ -149,10 +151,12 @@ func TestClient_RestartTracker_RestartTriggered_Failure(t *testing.T) { p := testPolicy(true, structs.RestartPolicyModeFail) p.Attempts = 1 rt := NewRestartTracker(p, structs.JobTypeService, nil) - if state, when := rt.SetRestartTriggered(true).GetState(); state != structs.TaskRestarting || when == 0 { + must.NoError(t, rt.SetRestartTriggered(true)) + if state, when := rt.GetState(); state != structs.TaskRestarting || when == 0 { t.Fatalf("expect restart got %v %v", state, when) } - if state, when := rt.SetRestartTriggered(true).GetState(); state != structs.TaskNotRestarting || when != 0 { + must.NoError(t, rt.SetRestartTriggered(true)) + if state, when := rt.GetState(); state != structs.TaskNotRestarting || when != 0 { t.Fatalf("expect failed got %v %v", state, when) } } diff --git a/client/allocrunner/taskrunner/service_hook.go b/client/allocrunner/taskrunner/service_hook.go index 7d21119fe28..1b63fbc5cf4 100644 --- a/client/allocrunner/taskrunner/service_hook.go +++ b/client/allocrunner/taskrunner/service_hook.go @@ -253,7 +253,6 @@ func (h *serviceHook) getWorkloadServices() *serviceregistration.WorkloadService return &serviceregistration.WorkloadServices{ AllocInfo: info, ProviderNamespace: h.providerNamespace, - Restarter: h.restarter, Services: interpolatedServices, DriverExec: h.driverExec, DriverNetwork: h.driverNet, diff --git a/client/allocrunner/taskrunner/task_runner.go b/client/allocrunner/taskrunner/task_runner.go index 9f1f8030691..9a9d2504ed1 100644 --- a/client/allocrunner/taskrunner/task_runner.go +++ b/client/allocrunner/taskrunner/task_runner.go @@ -699,7 +699,7 @@ MAIN: select { case <-tr.killCtx.Done(): // We can go through the normal should restart check since - // the restart tracker knowns it is killed + // the restart tracker knows it is killed result = tr.handleKill(resultCh) case <-tr.shutdownCtx.Done(): // TaskRunner was told to exit immediately diff --git a/client/allocrunner/taskrunner/task_runner_hooks.go b/client/allocrunner/taskrunner/task_runner_hooks.go index 34deef8bffe..aa8ccc1eb51 100644 --- a/client/allocrunner/taskrunner/task_runner_hooks.go +++ b/client/allocrunner/taskrunner/task_runner_hooks.go @@ -52,6 +52,7 @@ func (h *hookResources) getMounts() []*drivers.MountConfig { func (tr *TaskRunner) initHooks() { hookLogger := tr.logger.Named("task_hook") task := tr.Task() + alloc := tr.Alloc() tr.logmonHookConfig = newLogMonHookConfig(task.Name, task.LogConfig, tr.taskDir.LogDir) @@ -93,7 +94,6 @@ func (tr *TaskRunner) 
initHooks() { }, task.Secrets)) } - alloc := tr.Alloc() tr.runnerHooks = append(tr.runnerHooks, []interfaces.TaskHook{ newLogMonHook(tr, hookLogger), newDispatchHook(alloc, hookLogger), @@ -149,6 +149,8 @@ func (tr *TaskRunner) initHooks() { logger: hookLogger, })) + tr.runnerHooks = append(tr.runnerHooks, newCheckRestartHook(alloc, task, tr.serviceRegWrapper, tr)) + // If this is a Connect sidecar proxy (or a Connect Native) service, // add the sidsHook for requesting a Service Identity token (if ACLs). if task.UsesConnect() { diff --git a/client/allocrunner/taskrunner/task_runner_linux_test.go b/client/allocrunner/taskrunner/task_runner_linux_test.go index 1343e81d4de..5a93fe33f46 100644 --- a/client/allocrunner/taskrunner/task_runner_linux_test.go +++ b/client/allocrunner/taskrunner/task_runner_linux_test.go @@ -45,6 +45,7 @@ import ( ctestutil "github.com/hashicorp/nomad/client/testutil" "github.com/hashicorp/nomad/client/vaultclient" "github.com/hashicorp/nomad/client/widmgr" + "github.com/hashicorp/nomad/command/agent/consul" agentconsul "github.com/hashicorp/nomad/command/agent/consul" mockdriver "github.com/hashicorp/nomad/drivers/mock" "github.com/hashicorp/nomad/drivers/rawexec" @@ -1360,9 +1361,11 @@ func TestTaskRunner_CheckWatcher_Restart(t *testing.T) { consulServices := agentconsul.NewServiceClient(consulAgent, namespacesClient, conf.Logger, true) go consulServices.Run() defer consulServices.Shutdown() + sc := consul.NewServiceClientWrapper() + sc.AddClient("default", consulServices) - conf.ConsulServices = consulServices - conf.ServiceRegWrapper = wrapper.NewHandlerWrapper(conf.Logger, consulServices, nil) + conf.ConsulServices = sc + conf.ServiceRegWrapper = wrapper.NewHandlerWrapper(conf.Logger, sc, nil) tr, err := NewTaskRunner(conf) require.NoError(t, err) @@ -1967,8 +1970,10 @@ func TestTaskRunner_DriverNetwork(t *testing.T) { defer consulServices.Shutdown() go consulServices.Run() - conf.ConsulServices = consulServices - conf.ServiceRegWrapper = wrapper.NewHandlerWrapper(conf.Logger, consulServices, nil) + sc := consul.NewServiceClientWrapper() + sc.AddClient("default", consulServices) + conf.ConsulServices = sc + conf.ServiceRegWrapper = wrapper.NewHandlerWrapper(conf.Logger, sc, nil) tr, err := NewTaskRunner(conf) require.NoError(t, err) diff --git a/client/serviceregistration/mock/mock.go b/client/serviceregistration/mock/mock.go index 4c669cc7b6c..85db44a0ec2 100644 --- a/client/serviceregistration/mock/mock.go +++ b/client/serviceregistration/mock/mock.go @@ -30,6 +30,8 @@ type ServiceRegistrationHandler struct { // AllocRegistrationsFn allows injecting return values for the // AllocRegistrations function. AllocRegistrationsFn func(allocID string) (*serviceregistration.AllocRegistration, error) + + UniversalWatcher serviceregistration.CheckWatcher } // NewServiceRegistrationHandler returns a ready to use @@ -41,6 +43,10 @@ func NewServiceRegistrationHandler(log hclog.Logger) *ServiceRegistrationHandler } } +func (h *ServiceRegistrationHandler) CheckWatcher(_ string) serviceregistration.CheckWatcher { + return h.UniversalWatcher +} + func (h *ServiceRegistrationHandler) RegisterWorkload(services *serviceregistration.WorkloadServices) error { h.mu.Lock() defer h.mu.Unlock() diff --git a/client/serviceregistration/mock/universal_watcher.go b/client/serviceregistration/mock/universal_watcher.go new file mode 100644 index 00000000000..770e812932d --- /dev/null +++ b/client/serviceregistration/mock/universal_watcher.go @@ -0,0 +1,32 @@ +// Copyright IBM Corp. 
2015, 2025 +// SPDX-License-Identifier: BUSL-1.1 + +package mock + +import ( + "context" + + "github.com/hashicorp/nomad/client/serviceregistration" + "github.com/hashicorp/nomad/nomad/structs" + "github.com/stretchr/testify/mock" +) + +type MockUniversalWatcher struct { + mock.Mock +} + +func (m *MockUniversalWatcher) Run(ctx context.Context) { + m.Called() +} + +// Watch the given check. If the check status enters a failing state, the +// task associated with the check will be restarted according to its check_restart +// policy via wr. +func (m *MockUniversalWatcher) Watch(checkID string, check *structs.ServiceCheck, wr serviceregistration.WorkloadRestarter) { + m.Called(checkID, check, wr) +} + +// Unwatch will cause the CheckWatcher to no longer monitor the check of given checkID. +func (m *MockUniversalWatcher) Unwatch(checkID string) { + m.Called(checkID) +} diff --git a/client/serviceregistration/nsd/nsd.go b/client/serviceregistration/nsd/nsd.go index a16aa2cdaa1..b5a91c2a504 100644 --- a/client/serviceregistration/nsd/nsd.go +++ b/client/serviceregistration/nsd/nsd.go @@ -113,6 +113,10 @@ func NewServiceRegistrationHandler(log hclog.Logger, cfg *ServiceRegistrationHan // renewed. func (s *ServiceRegistrationHandler) SetNodeIdentityToken(token string) { s.nodeAuthToken.Store(token) } +func (s *ServiceRegistrationHandler) CheckWatcher(_ string) serviceregistration.CheckWatcher { + return s.checkWatcher +} + func (s *ServiceRegistrationHandler) RegisterWorkload(workload *serviceregistration.WorkloadServices) error { // Check whether we are enabled or not first. Hitting this likely means // there is a bug within the implicit constraint, or process using it, as @@ -143,18 +147,6 @@ func (s *ServiceRegistrationHandler) RegisterWorkload(workload *serviceregistrat return err } - // Service registrations look ok; startup check watchers as specified. The - // astute observer may notice the services are not actually registered yet - - // this is the same as the Consul flow so hopefully things just work out. - for _, service := range workload.Services { - for _, check := range service.Checks { - if check.TriggersRestarts() { - checkID := string(structs.NomadCheckID(workload.AllocInfo.AllocID, workload.AllocInfo.Group, check)) - s.checkWatcher.Watch(workload.AllocInfo.AllocID, workload.Name(), checkID, check, workload.Restarter) - } - } - } - args := structs.ServiceRegistrationUpsertRequest{ Services: registrations, WriteRequest: structs.WriteRequest{ @@ -194,16 +186,6 @@ func (s *ServiceRegistrationHandler) removeWorkload( // unblock wait group when we are done defer wg.Done() - // Stop check watcher - // - // todo(shoenig) - shouldn't we only unwatch checks for the given serviceSpec ? - for _, service := range workload.Services { - for _, check := range service.Checks { - checkID := string(structs.NomadCheckID(workload.AllocInfo.AllocID, workload.AllocInfo.Group, check)) - s.checkWatcher.Unwatch(checkID) - } - } - // Generate the consistent ID for this service, so we know what to remove. 
id := serviceregistration.MakeAllocServiceID(workload.AllocInfo.AllocID, workload.Name(), serviceSpec) diff --git a/client/serviceregistration/nsd/nsd_test.go b/client/serviceregistration/nsd/nsd_test.go index a89771d18c5..218ad0b5152 100644 --- a/client/serviceregistration/nsd/nsd_test.go +++ b/client/serviceregistration/nsd/nsd_test.go @@ -15,7 +15,6 @@ import ( "github.com/hashicorp/nomad/client/serviceregistration" "github.com/hashicorp/nomad/helper/testlog" "github.com/hashicorp/nomad/nomad/structs" - "github.com/shoenig/test" "github.com/shoenig/test/must" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -32,7 +31,7 @@ func (cw *mockCheckWatcher) Run(_ context.Context) { // Run runs async; just assume it ran } -func (cw *mockCheckWatcher) Watch(_, _, _ string, _ *structs.ServiceCheck, _ serviceregistration.WorkloadRestarter) { +func (cw *mockCheckWatcher) Watch(_ string, _ *structs.ServiceCheck, _ serviceregistration.WorkloadRestarter) { cw.lock.Lock() defer cw.lock.Unlock() cw.watchCalls++ @@ -44,45 +43,33 @@ func (cw *mockCheckWatcher) Unwatch(_ string) { cw.unWatchCalls++ } -func (cw *mockCheckWatcher) assert(t *testing.T, watchCalls, unWatchCalls int) { - cw.lock.Lock() - defer cw.lock.Unlock() - test.Eq(t, watchCalls, cw.watchCalls, test.Sprintf("expected %d Watch() calls but got %d", watchCalls, cw.watchCalls)) - test.Eq(t, unWatchCalls, cw.unWatchCalls, test.Sprintf("expected %d Unwatch() calls but got %d", unWatchCalls, cw.unWatchCalls)) -} - func TestServiceRegistrationHandler_RegisterWorkload(t *testing.T) { testCases := []struct { - name string - inputCfg *ServiceRegistrationHandlerCfg - inputWorkload *serviceregistration.WorkloadServices - expectedRPCs map[string]int - expectedError error - expWatch, expUnWatch int + name string + inputCfg *ServiceRegistrationHandlerCfg + inputWorkload *serviceregistration.WorkloadServices + expectedRPCs map[string]int + expectedError error }{ { name: "registration disabled", inputCfg: &ServiceRegistrationHandlerCfg{ Enabled: false, - CheckWatcher: new(mockCheckWatcher), + CheckWatcher: &mockCheckWatcher{}, }, inputWorkload: mockWorkload(), expectedRPCs: map[string]int{}, expectedError: errors.New(`service registration provider "nomad" not enabled`), - expWatch: 0, - expUnWatch: 0, }, { name: "registration enabled", inputCfg: &ServiceRegistrationHandlerCfg{ Enabled: true, - CheckWatcher: new(mockCheckWatcher), + CheckWatcher: &mockCheckWatcher{}, }, inputWorkload: mockWorkload(), expectedRPCs: map[string]int{structs.ServiceRegistrationUpsertRPCMethod: 1}, expectedError: nil, - expWatch: 1, - expUnWatch: 0, }, } @@ -102,20 +89,18 @@ func TestServiceRegistrationHandler_RegisterWorkload(t *testing.T) { actualErr := h.RegisterWorkload(tc.inputWorkload) require.Equal(t, tc.expectedError, actualErr) require.Equal(t, tc.expectedRPCs, mockRPC.calls()) - tc.inputCfg.CheckWatcher.(*mockCheckWatcher).assert(t, tc.expWatch, tc.expUnWatch) }) } } func TestServiceRegistrationHandler_RemoveWorkload(t *testing.T) { testCases := []struct { - name string - inputCfg *ServiceRegistrationHandlerCfg - inputWorkload *serviceregistration.WorkloadServices - returnedDeleteErr error - expectedRPCs map[string]int - expectedError error - expWatch, expUnWatch int + name string + inputCfg *ServiceRegistrationHandlerCfg + inputWorkload *serviceregistration.WorkloadServices + returnedDeleteErr error + expectedRPCs map[string]int + expectedError error }{ { name: "registration disabled multiple services", @@ -126,8 +111,6 @@ func 
TestServiceRegistrationHandler_RemoveWorkload(t *testing.T) { inputWorkload: mockWorkload(), expectedRPCs: map[string]int{structs.ServiceRegistrationDeleteByIDRPCMethod: 2}, expectedError: nil, - expWatch: 0, - expUnWatch: 2, // RemoveWorkload works regardless if provider is enabled }, { name: "registration enabled multiple services", @@ -138,8 +121,6 @@ func TestServiceRegistrationHandler_RemoveWorkload(t *testing.T) { inputWorkload: mockWorkload(), expectedRPCs: map[string]int{structs.ServiceRegistrationDeleteByIDRPCMethod: 2}, expectedError: nil, - expWatch: 0, - expUnWatch: 2, }, { name: "failed deregister", @@ -153,8 +134,6 @@ func TestServiceRegistrationHandler_RemoveWorkload(t *testing.T) { returnedDeleteErr: errors.New("unrecoverable error"), expectedRPCs: map[string]int{structs.ServiceRegistrationDeleteByIDRPCMethod: 4}, expectedError: nil, - expWatch: 0, - expUnWatch: 2, }, } @@ -174,7 +153,6 @@ func TestServiceRegistrationHandler_RemoveWorkload(t *testing.T) { h.RemoveWorkload(tc.inputWorkload) must.Eq(t, tc.expectedRPCs, mockRPC.calls()) - tc.inputCfg.CheckWatcher.(*mockCheckWatcher).assert(t, tc.expWatch, tc.expUnWatch) }) } } @@ -335,7 +313,6 @@ func TestServiceRegistrationHandler_UpdateWorkload(t *testing.T) { require.Eventually(t, func() bool { return assert.Equal(t, tc.expectedRPCs, mockRPC.calls()) }, 100*time.Millisecond, 10*time.Millisecond) - tc.inputCfg.CheckWatcher.(*mockCheckWatcher).assert(t, tc.expWatch, tc.expUnWatch) }) } @@ -677,7 +654,7 @@ func (mr *mockRPC) calls() map[string]int { } // RPC mocks the server RPCs, acting as though any request succeeds. -func (mr *mockRPC) RPC(method string, _, _ interface{}) error { +func (mr *mockRPC) RPC(method string, _, _ any) error { mr.l.Lock() defer mr.l.Unlock() diff --git a/client/serviceregistration/service_registration.go b/client/serviceregistration/service_registration.go index f2c91d0aad2..be1ac0cf57d 100644 --- a/client/serviceregistration/service_registration.go +++ b/client/serviceregistration/service_registration.go @@ -46,6 +46,13 @@ type Handler interface { // UpdateTTL is used to update the TTL of an individual service // registration check. UpdateTTL(id, namespace, output, status string) error + + // CheckWatcher returns the CheckWatcher for the service provider key + // + // Note: (mismith) this is awkward but removing the CheckWatcher from + // serviceReg is a decent sized lift due to Consul configuration complexity. + // Leaving this for a followup PR. + CheckWatcher(key string) CheckWatcher } type HandlerFunc func(string) Handler diff --git a/client/serviceregistration/watcher.go b/client/serviceregistration/watcher.go index a8bbb836acb..f0b2bcd7638 100644 --- a/client/serviceregistration/watcher.go +++ b/client/serviceregistration/watcher.go @@ -10,19 +10,12 @@ import ( "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-set/v3" - "github.com/hashicorp/nomad/helper" "github.com/hashicorp/nomad/nomad/structs" ) -// composite of allocID + taskName for uniqueness -type key string - type restarter struct { - allocID string - taskName string checkID string checkName string - taskKey key logger hclog.Logger task WorkloadRestarter @@ -142,7 +135,7 @@ type CheckWatcher interface { // Watch the given check. If the check status enters a failing state, the // task associated with the check will be restarted according to its check_restart // policy via wr. 
- Watch(allocID, taskName, checkID string, check *structs.ServiceCheck, wr WorkloadRestarter) + Watch(checkID string, check *structs.ServiceCheck, wr WorkloadRestarter) // Unwatch will cause the CheckWatcher to no longer monitor the check of given checkID. Unwatch(checkID string) @@ -179,24 +172,21 @@ func NewCheckWatcher(logger hclog.Logger, getter CheckStatusGetter) *UniversalCh } // Watch a check and restart its task if unhealthy. -func (w *UniversalCheckWatcher) Watch(allocID, taskName, checkID string, check *structs.ServiceCheck, wr WorkloadRestarter) { +func (w *UniversalCheckWatcher) Watch(checkID string, check *structs.ServiceCheck, wr WorkloadRestarter) { if !check.TriggersRestarts() { return // check_restart not set; no-op } c := &restarter{ - allocID: allocID, - taskName: taskName, checkID: checkID, checkName: check.Name, - taskKey: key(allocID + taskName), task: wr, interval: check.Interval, grace: check.CheckRestart.Grace, graceUntil: time.Now().Add(check.CheckRestart.Grace), timeLimit: check.Interval * time.Duration(check.CheckRestart.Limit-1), ignoreWarnings: check.CheckRestart.IgnoreWarnings, - logger: w.logger.With("alloc_id", allocID, "task", taskName, "check", check.Name), + logger: w.logger.With("check", check.Name), } select { @@ -225,26 +215,7 @@ func (w *UniversalCheckWatcher) Run(ctx context.Context) { // map of checkID to their restarter handle (contains only checks we are watching) watched := make(map[string]*restarter) - checkTimer, cleanupCheckTimer := helper.NewSafeTimer(0) - defer cleanupCheckTimer() - - stopCheckTimer := func() { // todo: refactor using that other pattern - checkTimer.Stop() - select { - case <-checkTimer.C: - default: - } - } - - // initialize with checkTimer disabled - stopCheckTimer() - for { - // disable polling if there are no checks - if len(watched) == 0 { - stopCheckTimer() - } - select { // caller cancelled us; goodbye case <-ctx.Done(): @@ -258,21 +229,11 @@ func (w *UniversalCheckWatcher) Run(ctx context.Context) { } watched[update.checkID] = update.restart - allocID := update.restart.allocID - taskName := update.restart.taskName checkName := update.restart.checkName - w.logger.Trace("now watching check", "alloc_i", allocID, "task", taskName, "check", checkName) - - // turn on the timer if we are now active - if len(watched) == 1 { - stopCheckTimer() - checkTimer.Reset(w.pollFrequency) - } + w.logger.Trace("now watching check", "check", checkName) - // poll time; refresh check statuses - case now := <-checkTimer.C: - w.interval(ctx, now, watched) - checkTimer.Reset(w.pollFrequency) + case <-time.After(w.pollFrequency): + w.interval(ctx, time.Now(), watched) } } } @@ -287,7 +248,7 @@ func (w *UniversalCheckWatcher) interval(ctx context.Context, now time.Time, wat w.failedPreviousInterval = false // keep track of tasks restarted this interval - restarts := set.New[key](len(statuses)) + restarts := set.New[string](len(statuses)) // iterate over status of all checks, and update the status of checks // we care about watching @@ -296,7 +257,7 @@ func (w *UniversalCheckWatcher) interval(ctx context.Context, now time.Time, wat return // short circuit; caller cancelled us } - if restarts.Contains(checkRestarter.taskKey) { + if restarts.Contains(checkID) { // skip; task is already being restarted delete(watched, checkID) continue @@ -314,16 +275,7 @@ func (w *UniversalCheckWatcher) interval(ctx context.Context, now time.Time, wat if checkRestarter.apply(ctx, now, status) { // check will be re-registered & re-watched on startup 
delete(watched, checkID) - restarts.Insert(checkRestarter.taskKey) - } - } - - // purge passing checks of tasks that are being restarted - if restarts.Size() > 0 { - for checkID, checkRestarter := range watched { - if restarts.Contains(checkRestarter.taskKey) { - delete(watched, checkID) - } + restarts.Insert(checkID) } } } diff --git a/client/serviceregistration/watcher_test.go b/client/serviceregistration/watcher_test.go index 53ddc502058..7b331bf94a9 100644 --- a/client/serviceregistration/watcher_test.go +++ b/client/serviceregistration/watcher_test.go @@ -36,20 +36,16 @@ type fakeWorkloadRestarter struct { // check to re-Watch on restarts check *structs.ServiceCheck - allocID string - taskName string checkName string lock sync.Mutex } // newFakeCheckRestart creates a new mock WorkloadRestarter. -func newFakeWorkloadRestarter(w *UniversalCheckWatcher, allocID, taskName, checkName string, c *structs.ServiceCheck) *fakeWorkloadRestarter { +func newFakeWorkloadRestarter(w *UniversalCheckWatcher, checkName string, c *structs.ServiceCheck) *fakeWorkloadRestarter { return &fakeWorkloadRestarter{ watcher: w, check: c, - allocID: allocID, - taskName: taskName, checkName: checkName, } } @@ -71,7 +67,7 @@ func (c *fakeWorkloadRestarter) Restart(_ context.Context, event *structs.TaskEv c.restarts = append(c.restarts, restart) // Re-Watch the check just like TaskRunner - c.watcher.Watch(c.allocID, c.taskName, c.checkName, c.check, c) + c.watcher.Watch(c.checkName, c.check, c) return nil } @@ -80,9 +76,10 @@ func (c *fakeWorkloadRestarter) String() string { c.lock.Lock() defer c.lock.Unlock() - s := fmt.Sprintf("%s %s %s restarts:\n", c.allocID, c.taskName, c.checkName) + s := fmt.Sprintf("%s restarts:\n", c.checkName) for _, r := range c.restarts { - s += fmt.Sprintf("%s - %s: %s (failure: %t)\n", r.timestamp, r.source, r.reason, r.failure) + restart := fmt.Sprintf("%s - %s: %s (failure: %t)\n", r.timestamp, r.source, r.reason, r.failure) + s = fmt.Sprintf("%s%s", s, restart) } return s } @@ -177,8 +174,8 @@ func TestCheckWatcher_SkipUnwatched(t *testing.T) { getter := new(fakeCheckStatusGetter) cw := NewCheckWatcher(logger, getter) - restarter1 := newFakeWorkloadRestarter(cw, "testalloc1", "testtask1", "testcheck1", check) - cw.Watch("testalloc1", "testtask1", "testcheck1", check, restarter1) + restarter1 := newFakeWorkloadRestarter(cw, "testcheck1", check) + cw.Watch("testcheck1", check, restarter1) // Check should have been dropped as it's not watched enqueued := len(cw.checkUpdateCh) @@ -197,14 +194,14 @@ func TestCheckWatcher_Healthy(t *testing.T) { getter.add("testcheck2", "passing", now) check1 := testCheck() - restarter1 := newFakeWorkloadRestarter(cw, "testalloc1", "testtask1", "testcheck1", check1) - cw.Watch("testalloc1", "testtask1", "testcheck1", check1, restarter1) + restarter1 := newFakeWorkloadRestarter(cw, "testcheck1", check1) + cw.Watch("testcheck1", check1, restarter1) check2 := testCheck() check2.CheckRestart.Limit = 1 check2.CheckRestart.Grace = 0 - restarter2 := newFakeWorkloadRestarter(cw, "testalloc2", "testtask2", "testcheck2", check2) - cw.Watch("testalloc2", "testtask2", "testcheck2", check2, restarter2) + restarter2 := newFakeWorkloadRestarter(cw, "testcheck2", check2) + cw.Watch("testcheck2", check2, restarter2) // Run ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) @@ -227,8 +224,8 @@ func TestCheckWatcher_Unhealthy(t *testing.T) { getter.add("testcheck1", "critical", now) check1 := testCheck() - restarter1 := 
newFakeWorkloadRestarter(cw, "testalloc1", "testtask1", "testcheck1", check1) - cw.Watch("testalloc1", "testtask1", "testcheck1", check1, restarter1) + restarter1 := newFakeWorkloadRestarter(cw, "testcheck1", check1) + cw.Watch("testcheck1", check1, restarter1) // Run ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) @@ -254,8 +251,8 @@ func TestCheckWatcher_HealthyWarning(t *testing.T) { check1.CheckRestart.Limit = 1 check1.CheckRestart.Grace = 0 check1.CheckRestart.IgnoreWarnings = true - restarter1 := newFakeWorkloadRestarter(cw, "testalloc1", "testtask1", "testcheck1", check1) - cw.Watch("testalloc1", "testtask1", "testcheck1", check1, restarter1) + restarter1 := newFakeWorkloadRestarter(cw, "testcheck1", check1) + cw.Watch("testcheck1", check1, restarter1) // Run ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) @@ -275,8 +272,8 @@ func TestCheckWatcher_Flapping(t *testing.T) { check1 := testCheck() check1.CheckRestart.Grace = 0 - restarter1 := newFakeWorkloadRestarter(cw, "testalloc1", "testtask1", "testcheck1", check1) - cw.Watch("testalloc1", "testtask1", "testcheck1", check1, restarter1) + restarter1 := newFakeWorkloadRestarter(cw, "testcheck1", check1) + cw.Watch("testcheck1", check1, restarter1) // Check flaps and is never failing for the full 200ms needed to restart now := time.Now() @@ -308,8 +305,8 @@ func TestCheckWatcher_Unwatch(t *testing.T) { check1 := testCheck() check1.CheckRestart.Limit = 1 check1.CheckRestart.Grace = 100 * time.Millisecond - restarter1 := newFakeWorkloadRestarter(cw, "testalloc1", "testtask1", "testcheck1", check1) - cw.Watch("testalloc1", "testtask1", "testcheck1", check1, restarter1) + restarter1 := newFakeWorkloadRestarter(cw, "testcheck1", check1) + cw.Watch("testcheck1", check1, restarter1) cw.Unwatch("testcheck1") ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond) @@ -339,20 +336,20 @@ func TestCheckWatcher_MultipleChecks(t *testing.T) { check1 := testCheck() check1.Name = "testcheck1" check1.CheckRestart.Limit = 1 - restarter1 := newFakeWorkloadRestarter(cw, "testalloc1", "testtask1", "testcheck1", check1) - cw.Watch("testalloc1", "testtask1", "testcheck1", check1, restarter1) + restarter1 := newFakeWorkloadRestarter(cw, "testcheck1", check1) + cw.Watch("testcheck1", check1, restarter1) check2 := testCheck() check2.Name = "testcheck2" check2.CheckRestart.Limit = 1 - restarter2 := newFakeWorkloadRestarter(cw, "testalloc1", "testtask1", "testcheck2", check2) - cw.Watch("testalloc1", "testtask1", "testcheck2", check2, restarter2) + restarter2 := newFakeWorkloadRestarter(cw, "testcheck2", check2) + cw.Watch("testcheck2", check2, restarter2) check3 := testCheck() check3.Name = "testcheck3" check3.CheckRestart.Limit = 1 - restarter3 := newFakeWorkloadRestarter(cw, "testalloc1", "testtask1", "testcheck3", check3) - cw.Watch("testalloc1", "testtask1", "testcheck3", check3, restarter3) + restarter3 := newFakeWorkloadRestarter(cw, "testcheck3", check3) + cw.Watch("testcheck3", check3, restarter3) // Run ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) @@ -362,13 +359,12 @@ // Ensure that checks 1 and 2 were restarted (two restarts total). Since // checks are in a map it's random which check triggers the restart // first.
- if n := len(restarter1.restarts) + len(restarter2.restarts); n != 1 { - t.Errorf("expected check 1 & 2 to be restarted 1 time but found %d\ncheck 1:\n%s\ncheck 2:%s", - n, restarter1, restarter2) + if n := len(restarter1.restarts) + len(restarter2.restarts); n != 2 { + t.Errorf("expected check 1 & 2 to be restarted 2 times but found %d", n) } if n := len(restarter3.restarts); n != 0 { - t.Errorf("expected check 3 to not be restarted but found %d:\n%s", n, restarter3) + t.Errorf("expected check 3 to not be restarted but found %d", n) } } @@ -386,11 +382,9 @@ func TestCheckWatcher_Deadlock(t *testing.T) { n := cap(cw.checkUpdateCh) + 1 checks := make([]*structs.ServiceCheck, n) restarters := make([]*fakeWorkloadRestarter, n) - for i := 0; i < n; i++ { + for i := range n { c := testCheck() r := newFakeWorkloadRestarter(cw, - fmt.Sprintf("alloc%d", i), - fmt.Sprintf("task%d", i), fmt.Sprintf("check%d", i), c, ) @@ -399,13 +393,12 @@ func TestCheckWatcher_Deadlock(t *testing.T) { } // Run - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() go cw.Run(ctx) // Watch for _, r := range restarters { - cw.Watch(r.allocID, r.taskName, r.checkName, r.check, r) + cw.Watch(r.checkName, r.check, r) } // Make them all fail diff --git a/client/serviceregistration/workload.go b/client/serviceregistration/workload.go index bd87f367346..a7f1aea95f9 100644 --- a/client/serviceregistration/workload.go +++ b/client/serviceregistration/workload.go @@ -22,10 +22,6 @@ type WorkloadServices struct { // registered, if the provider supports this functionality. ProviderNamespace string - // Restarter allows restarting the task or task group depending on the - // check_restart blocks. - Restarter WorkloadRestarter - // Services and checks to register for the task. Services []*structs.Service diff --git a/client/serviceregistration/wrapper/wrapper.go b/client/serviceregistration/wrapper/wrapper.go index b15be251fe6..0eb015d2d60 100644 --- a/client/serviceregistration/wrapper/wrapper.go +++ b/client/serviceregistration/wrapper/wrapper.go @@ -41,6 +41,17 @@ func NewHandlerWrapper( } } +func (h *HandlerWrapper) CheckWatcher(provider, key string) serviceregistration.CheckWatcher { + switch provider { + case structs.ServiceProviderNomad: + return h.nomadServiceProvider.CheckWatcher("") + case structs.ServiceProviderConsul: + return h.consulServiceProvider.CheckWatcher(key) + default: + return nil + } +} + // RegisterWorkload wraps the serviceregistration.Handler RegisterWorkload // function. It determines which backend provider to call and passes the // workload unless the provider is unknown, in which case an error will be diff --git a/command/agent/agent.go b/command/agent/agent.go index e53f3ffe739..1fe5d81536b 100644 --- a/command/agent/agent.go +++ b/command/agent/agent.go @@ -100,10 +100,6 @@ type Agent struct { // config entries for Connect gateways consulConfigEntriesFunc consul.ConfigAPIFunc - // consulACLs is Nomad's subset of Consul's ACL API Nomad uses. Used by - // server for legacy token workflow only, so only needs default Consul. - consulACLs consul.ACLsAPI - // client is the launched Nomad Client. Can be nil if the agent isn't // configured to run a client. 
client *client.Client @@ -1759,10 +1755,6 @@ func (a *Agent) setupConsuls(cfgs []*config.ConsulConfig) error { consulConfigEntries[cluster] = consulClient.ConfigEntries() if cluster == structs.ConsulDefaultCluster { - // Create Consul ACL client for managing tokens in the legacy - // workflow on the server - a.consulACLs = consulClient.ACL() - // Create Consul Catalog client for self service discovery. a.consulCatalog = consulClient.Catalog() } diff --git a/command/agent/consul/group_test.go b/command/agent/consul/group_test.go index 4b5b0d752a2..f54eb879b69 100644 --- a/command/agent/consul/group_test.go +++ b/command/agent/consul/group_test.go @@ -85,7 +85,7 @@ func TestConsul_Connect(t *testing.T) { }, } - require.NoError(t, serviceClient.RegisterWorkload(BuildAllocServices(mock.Node(), alloc, NoopRestarter()))) + require.NoError(t, serviceClient.RegisterWorkload(BuildAllocServices(mock.Node(), alloc))) require.Eventually(t, func() bool { services, err := consulClient.Agent().Services() diff --git a/command/agent/consul/int_test.go b/command/agent/consul/int_test.go index 3e9700b4ad3..3a3d1c27a32 100644 --- a/command/agent/consul/int_test.go +++ b/command/agent/consul/int_test.go @@ -153,6 +153,8 @@ func TestConsul_Integration(t *testing.T) { close(consulRan) }() + sc := consul.NewServiceClientWrapper() + sc.AddClient("default", serviceClient) // Create a closed channel to mock TaskCoordinator.startConditionForTask. // Closed channel indicates this task is not blocked on prestart hooks. closedCh := make(chan struct{}) @@ -162,7 +164,7 @@ func TestConsul_Integration(t *testing.T) { config := &taskrunner.Config{ Alloc: alloc, ClientConfig: conf, - ConsulServices: serviceClient, + ConsulServices: sc, Task: task, TaskDir: taskDir, Logger: logger, @@ -172,7 +174,7 @@ func TestConsul_Integration(t *testing.T) { DeviceManager: devicemanager.NoopMockManager(), DriverManager: drivermanager.TestDriverManager(t), StartConditionMetCh: closedCh, - ServiceRegWrapper: wrapper.NewHandlerWrapper(logger, serviceClient, regMock.NewServiceRegistrationHandler(logger)), + ServiceRegWrapper: wrapper.NewHandlerWrapper(logger, sc, regMock.NewServiceRegistrationHandler(logger)), Wranglers: proclib.MockWranglers(t), AllocHookResources: cstructs.NewAllocHookResources(), } diff --git a/command/agent/consul/service_client.go b/command/agent/consul/service_client.go index 6002b26e8f8..b549d7fcbdd 100644 --- a/command/agent/consul/service_client.go +++ b/command/agent/consul/service_client.go @@ -533,6 +533,10 @@ func (scw *ServiceClientWrapper) RegisterWorkload(workload *serviceregistration. return nil } +func (scw *ServiceClientWrapper) CheckWatcher(cluster string) serviceregistration.CheckWatcher { + return scw.serviceClients[cluster].checkWatcher +} + func (scw *ServiceClientWrapper) RemoveWorkload(workload *serviceregistration.WorkloadServices) { scw.lock.RLock() defer scw.lock.RUnlock() @@ -907,10 +911,8 @@ INIT: default: } } - backoff := c.retryInterval * time.Duration(failures) - if backoff > c.maxRetryInterval { - backoff = c.maxRetryInterval - } + + backoff := min(c.retryInterval*time.Duration(failures), c.maxRetryInterval) retryTimer.Reset(backoff) } else { if failures > 0 { @@ -1561,17 +1563,6 @@ func (c *ServiceClient) RegisterWorkload(workload *serviceregistration.WorkloadS c.commit(ops) - // Start watching checks. Done after service registrations are built - // since an error building them could leak watches. 
- for _, service := range workload.Services { - serviceID := serviceregistration.MakeAllocServiceID(workload.AllocInfo.AllocID, workload.Name(), service) - for _, check := range service.Checks { - if check.TriggersRestarts() { - checkID := MakeCheckID(serviceID, check) - c.checkWatcher.Watch(workload.AllocInfo.AllocID, workload.Name(), checkID, check, workload.Restarter) - } - } - } return nil } @@ -1602,11 +1593,6 @@ func (c *ServiceClient) UpdateWorkload(old, newWorkload *serviceregistration.Wor for _, check := range existingSvc.Checks { cid := MakeCheckID(existingID, check) ops.deregChecks = append(ops.deregChecks, cid) - - // Unwatch watched checks - if check.TriggersRestarts() { - c.checkWatcher.Unwatch(cid) - } } continue } @@ -1659,21 +1645,11 @@ func (c *ServiceClient) UpdateWorkload(old, newWorkload *serviceregistration.Wor sreg.CheckOnUpdate[registration.ID] = check.OnUpdate ops.regChecks = append(ops.regChecks, registration) } - - // Update all watched checks as CheckRestart fields aren't part of ID - if check.TriggersRestarts() { - c.checkWatcher.Watch(newWorkload.AllocInfo.AllocID, newWorkload.Name(), checkID, check, newWorkload.Restarter) - } } // Remove existing checks not in updated service - for cid, check := range existingChecks { + for cid := range existingChecks { ops.deregChecks = append(ops.deregChecks, cid) - - // Unwatch checks - if check.TriggersRestarts() { - c.checkWatcher.Unwatch(cid) - } } } @@ -1698,18 +1674,6 @@ func (c *ServiceClient) UpdateWorkload(old, newWorkload *serviceregistration.Wor c.addRegistrations(newWorkload.AllocInfo.AllocID, newWorkload.Name(), regs) c.commit(ops) - - // Start watching checks. Done after service registrations are built - // since an error building them could leak watches. - for serviceID, service := range newIDs { - for _, check := range service.Checks { - if check.TriggersRestarts() { - checkID := MakeCheckID(serviceID, check) - c.checkWatcher.Watch(newWorkload.AllocInfo.AllocID, newWorkload.Name(), checkID, check, newWorkload.Restarter) - } - } - } - return nil } @@ -1726,10 +1690,6 @@ func (c *ServiceClient) RemoveWorkload(workload *serviceregistration.WorkloadSer for _, check := range service.Checks { cid := MakeCheckID(id, check) ops.deregChecks = append(ops.deregChecks, cid) - - if check.TriggersRestarts() { - c.checkWatcher.Unwatch(cid) - } } } diff --git a/command/agent/consul/service_client_test.go b/command/agent/consul/service_client_test.go index c46398b1a86..5809b03776f 100644 --- a/command/agent/consul/service_client_test.go +++ b/command/agent/consul/service_client_test.go @@ -524,7 +524,6 @@ func TestServiceRegistration_CheckOnUpdate(t *testing.T) { AllocID: allocID, Task: "taskname", }, - Restarter: &restartRecorder{}, Services: []*structs.Service{ { Name: "taskname-service", diff --git a/command/agent/consul/structs.go b/command/agent/consul/structs.go index 19fb3765578..98376291f5c 100644 --- a/command/agent/consul/structs.go +++ b/command/agent/consul/structs.go @@ -12,7 +12,7 @@ import ( ) func BuildAllocServices( - node *structs.Node, alloc *structs.Allocation, restarter serviceregistration.WorkloadRestarter) *serviceregistration.WorkloadServices { + node *structs.Node, alloc *structs.Allocation) *serviceregistration.WorkloadServices { //TODO(schmichael) only support one network for now net := alloc.AllocatedResources.Shared.Networks[0] @@ -34,8 +34,6 @@ func BuildAllocServices( // Copy PortLabels from group network PortMap: net.PortLabels(), }, - - Restarter: restarter, DriverExec: nil, } diff --git 
a/command/agent/consul/unit_test.go b/command/agent/consul/unit_test.go index a3fdedc19a4..33b96941bca 100644 --- a/command/agent/consul/unit_test.go +++ b/command/agent/consul/unit_test.go @@ -36,7 +36,6 @@ func testWorkload() *serviceregistration.WorkloadServices { AllocID: uuid.Generate(), Task: "taskname", }, - Restarter: &restartRecorder{}, Services: []*structs.Service{ { Name: "taskname-service", diff --git a/nomad/structs/services.go b/nomad/structs/services.go index f74872ea6f0..a647d9714c3 100644 --- a/nomad/structs/services.go +++ b/nomad/structs/services.go @@ -241,7 +241,7 @@ func (sc *ServiceCheck) Canonicalize(serviceName, taskName string) { } // Set task name if not already set - if sc.TaskName == "" && taskName != "group" { + if sc.TaskName == "" { sc.TaskName = taskName } @@ -704,6 +704,7 @@ func (s *Service) Canonicalize(job, taskGroup, task, jobNamespace string) { s.TaggedAddresses = nil } + // TODO: mismithhisler: a task named group (yes, weird) will break this // Set the task name if not already set if s.TaskName == "" && task != "group" { s.TaskName = task