Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion client/allocrunner/alloc_runner_hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ func (ar *allocRunner) initRunnerHooks(config *clientconfig.Config) error {
providerNamespace: alloc.ServiceProviderNamespace(),
serviceRegWrapper: ar.serviceRegWrapper,
hookResources: ar.hookResources,
restarter: ar,
networkStatus: ar,
logger: hookLogger,
shutdownDelayCtx: ar.shutdownDelayCtx,
Expand Down
4 changes: 0 additions & 4 deletions client/allocrunner/group_service_hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ type groupServiceHook struct {
group string
tg *structs.TaskGroup
namespace string
restarter serviceregistration.WorkloadRestarter
prerun bool
deregistered bool
networkStatus structs.NetworkStatus
Expand Down Expand Up @@ -63,7 +62,6 @@ type groupServiceHook struct {

type groupServiceHookConfig struct {
alloc *structs.Allocation
restarter serviceregistration.WorkloadRestarter
networkStatus structs.NetworkStatus
shutdownDelayCtx context.Context
logger hclog.Logger
Expand Down Expand Up @@ -92,7 +90,6 @@ func newGroupServiceHook(cfg groupServiceHookConfig) *groupServiceHook {
jobID: cfg.alloc.JobID,
group: cfg.alloc.TaskGroup,
namespace: cfg.alloc.Namespace,
restarter: cfg.restarter,
providerNamespace: cfg.providerNamespace,
delay: shutdownDelay,
networkStatus: cfg.networkStatus,
Expand Down Expand Up @@ -311,7 +308,6 @@ func (h *groupServiceHook) getWorkloadServicesLocked() *serviceregistration.Work
return &serviceregistration.WorkloadServices{
AllocInfo: info,
ProviderNamespace: h.providerNamespace,
Restarter: h.restarter,
Services: h.services,
Networks: h.networks,
NetworkStatus: netStatus,
Expand Down
11 changes: 0 additions & 11 deletions client/allocrunner/group_service_hook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"github.com/hashicorp/nomad/client/serviceregistration/wrapper"
cstructs "github.com/hashicorp/nomad/client/structs"
"github.com/hashicorp/nomad/client/taskenv"
agentconsul "github.com/hashicorp/nomad/command/agent/consul"
"github.com/hashicorp/nomad/helper/pointer"
"github.com/hashicorp/nomad/helper/testlog"
"github.com/hashicorp/nomad/nomad/mock"
Expand Down Expand Up @@ -46,7 +45,6 @@ func TestGroupServiceHook_NoGroupServices(t *testing.T) {
h := newGroupServiceHook(groupServiceHookConfig{
alloc: alloc,
serviceRegWrapper: regWrapper,
restarter: agentconsul.NoopRestarter(),
logger: logger,
hookResources: cstructs.NewAllocHookResources(),
})
Expand Down Expand Up @@ -88,7 +86,6 @@ func TestGroupServiceHook_ShutdownDelayUpdate(t *testing.T) {
h := newGroupServiceHook(groupServiceHookConfig{
alloc: alloc,
serviceRegWrapper: regWrapper,
restarter: agentconsul.NoopRestarter(),
logger: logger,
hookResources: cstructs.NewAllocHookResources(),
})
Expand Down Expand Up @@ -130,7 +127,6 @@ func TestGroupServiceHook_GroupServices(t *testing.T) {
h := newGroupServiceHook(groupServiceHookConfig{
alloc: alloc,
serviceRegWrapper: regWrapper,
restarter: agentconsul.NoopRestarter(),
logger: logger,
hookResources: cstructs.NewAllocHookResources(),
})
Expand Down Expand Up @@ -199,7 +195,6 @@ func TestGroupServiceHook_GroupServicesCheckUpdates(t *testing.T) {
h := newGroupServiceHook(groupServiceHookConfig{
alloc: alloc,
serviceRegWrapper: regWrapper,
restarter: agentconsul.NoopRestarter(),
logger: logger,
hookResources: resources,
})
Expand Down Expand Up @@ -254,7 +249,6 @@ func TestGroupServiceHook_GroupServices_Nomad(t *testing.T) {
h := newGroupServiceHook(groupServiceHookConfig{
alloc: alloc,
serviceRegWrapper: regWrapper,
restarter: agentconsul.NoopRestarter(),
logger: logger,
hookResources: cstructs.NewAllocHookResources(),
})
Expand Down Expand Up @@ -309,7 +303,6 @@ func TestGroupServiceHook_NoNetwork(t *testing.T) {
h := newGroupServiceHook(groupServiceHookConfig{
alloc: alloc,
serviceRegWrapper: regWrapper,
restarter: agentconsul.NoopRestarter(),
logger: logger,
hookResources: cstructs.NewAllocHookResources(),
})
Expand Down Expand Up @@ -357,7 +350,6 @@ func TestGroupServiceHook_getWorkloadServices(t *testing.T) {
h := newGroupServiceHook(groupServiceHookConfig{
alloc: alloc,
serviceRegWrapper: regWrapper,
restarter: agentconsul.NoopRestarter(),
logger: logger,
hookResources: cstructs.NewAllocHookResources(),
})
Expand Down Expand Up @@ -400,7 +392,6 @@ func TestGroupServiceHook_PreKill(t *testing.T) {
alloc: alloc,
serviceRegWrapper: regWrapper,
shutdownDelayCtx: shutDownCtx,
restarter: agentconsul.NoopRestarter(),
logger: logger,
hookResources: cstructs.NewAllocHookResources(),
})
Expand Down Expand Up @@ -448,7 +439,6 @@ func TestGroupServiceHook_PreKill(t *testing.T) {
h := newGroupServiceHook(groupServiceHookConfig{
alloc: alloc,
serviceRegWrapper: regWrapper,
restarter: agentconsul.NoopRestarter(),
logger: logger,
hookResources: cstructs.NewAllocHookResources(),
})
Expand Down Expand Up @@ -500,7 +490,6 @@ func TestGroupServiceHook_PreKill(t *testing.T) {
h := newGroupServiceHook(groupServiceHookConfig{
alloc: alloc,
serviceRegWrapper: regWrapper,
restarter: agentconsul.NoopRestarter(),
logger: logger,
hookResources: cstructs.NewAllocHookResources(),
})
Expand Down
135 changes: 135 additions & 0 deletions client/allocrunner/taskrunner/check_restart_hook.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
// Copyright IBM Corp. 2015, 2025
// SPDX-License-Identifier: BUSL-1.1

package taskrunner

import (
"context"
"fmt"

"github.com/hashicorp/nomad/client/allocrunner/interfaces"
"github.com/hashicorp/nomad/client/serviceregistration"
"github.com/hashicorp/nomad/client/serviceregistration/wrapper"
"github.com/hashicorp/nomad/client/taskenv"
"github.com/hashicorp/nomad/command/agent/consul"
"github.com/hashicorp/nomad/nomad/structs"
)

type hookCheck struct {
// providerType defines the backend service that is checking
// the service. This is currently either Nomad or Consul.
providerType string

// providerNS is the providers namespace that the service is
// registered. When a provider implements namespaces (i.e. Consul),
// Nomad runs a single check watcher per namespace.
providerNS string

// checkID is the ID of the check used to register a check_restart watch
checkID string

// check is the actual Nomad service check configuration
check *structs.ServiceCheck
}

// The checkRestartHook is responsible for registering/deregistering _both_ group and task
// check_start blocks with the appropriate CheckWatcher. This is a standalone hook and not part
// of the service hook because restarting checks is task specific, even though check_restart
// can be defined at the group level. Therefore this task will look at both TG and task services.
type checkRestartHook struct {
checks []*hookCheck
handler *wrapper.HandlerWrapper
wr serviceregistration.WorkloadRestarter
allocID string
taskName string
taskEnv *taskenv.TaskEnv

tgName string
tgServices []*structs.Service
taskServices []*structs.Service
}

func newCheckRestartHook(alloc *structs.Allocation, task *structs.Task, handler *wrapper.HandlerWrapper, restarter serviceregistration.WorkloadRestarter) *checkRestartHook {
tg := alloc.Job.LookupTaskGroup(alloc.TaskGroup)
return &checkRestartHook{
handler: handler,
allocID: alloc.ID,
taskName: task.Name,
tgName: tg.Name,
tgServices: tg.Services,
taskServices: task.Services,
wr: restarter,
}
}

func (h *checkRestartHook) Name() string {
return "check_restart"
}

func (h *checkRestartHook) Prestart(ctx context.Context, req *interfaces.TaskPrestartRequest, _ *interfaces.TaskPrestartResponse) error {
var checks []*hookCheck
for _, s := range taskenv.InterpolateServices(req.TaskEnv, h.tgServices) {
for _, c := range s.Checks {
if c.TriggersRestarts() && c.TaskName == h.taskName {
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

c.TaskName == h.taskName means we would only register that for group level checks, only the task specified in the service/check will be restarted. Also if the task is not specified the check won't even be registered.

checks = append(checks, &hookCheck{
providerType: s.Provider,
providerNS: s.Cluster,
check: c,
// TODO: does this work for Nomad?
checkID: checkID(h.allocID, h.tgName, fmt.Sprintf("group-%s", h.tgName), s.Provider, c, s),
})
}
}
}

for _, s := range taskenv.InterpolateServices(req.TaskEnv, h.taskServices) {
for _, c := range s.Checks {
if c.TriggersRestarts() {
checks = append(checks, &hookCheck{
providerType: s.Provider,
providerNS: s.Cluster,
check: c,
// TODO: does this work for Nomad?
checkID: checkID(h.allocID, h.tgName, h.taskName, s.Provider, c, s),
})
}
}
}
h.checks = checks

for _, c := range h.checks {
watcher := h.handler.CheckWatcher(c.providerType, c.providerNS)
watcher.Watch(c.checkID, c.check, h.wr)
}
return nil
}

func (h *checkRestartHook) Exited(ctx context.Context, req *interfaces.TaskExitedRequest, resp *interfaces.TaskExitedResponse) error {
for _, c := range h.checks {
watcher := h.handler.CheckWatcher(c.providerType, c.providerNS)
watcher.Unwatch(c.checkID)
}
return nil
}

func (h *checkRestartHook) Stop(ctx context.Context, req *interfaces.TaskStopRequest, resp *interfaces.TaskStopResponse) error {
for _, c := range h.checks {
watcher := h.handler.CheckWatcher(c.providerType, c.providerNS)
watcher.Unwatch(c.checkID)
}
return nil
}

// checkID returns a provider specific checkID for the workload. Unfortunately nomad and consul use different
// methods for creating checkID's so these are quite different. Consul distinguishes between group and task
// checkID's, but Nomad seems to just always use the task group name?
func checkID(allocID, tg, task, ptype string, check *structs.ServiceCheck, service *structs.Service) string {
switch ptype {
case "nomad":
return string(structs.NomadCheckID(allocID, tg, check))
case "consul":
return consul.MakeCheckID(serviceregistration.MakeAllocServiceID(allocID, task, service), check)
default:
return ""
}
}
123 changes: 123 additions & 0 deletions client/allocrunner/taskrunner/check_restart_hook_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
// Copyright IBM Corp. 2015, 2025
// SPDX-License-Identifier: BUSL-1.1

package taskrunner

import (
"testing"

"github.com/hashicorp/nomad/client/allocrunner/interfaces"
regMock "github.com/hashicorp/nomad/client/serviceregistration/mock"
"github.com/hashicorp/nomad/client/serviceregistration/wrapper"
"github.com/hashicorp/nomad/client/taskenv"
"github.com/hashicorp/nomad/command/agent/consul"
"github.com/hashicorp/nomad/helper/testlog"
"github.com/hashicorp/nomad/nomad/mock"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/shoenig/test/must"
tmock "github.com/stretchr/testify/mock"
)

func TestCheckRestartHook_Prestart(t *testing.T) {
logger := testlog.HCLogger(t)

alloc := mock.Alloc()
alloc.Job.Canonicalize()

handler := regMock.NewServiceRegistrationHandler(logger)

mockWatcher := &regMock.MockUniversalWatcher{}
mockWatcher.On("Watch", tmock.Anything, tmock.Anything, tmock.Anything)
handler.UniversalWatcher = mockWatcher

regWrap := wrapper.NewHandlerWrapper(logger, handler, handler)

service := &structs.Service{
Provider: "nomad",
Checks: []*structs.ServiceCheck{
{
TaskName: "web",
CheckRestart: &structs.CheckRestart{
Limit: 1,
},
},
},
}

// group level service
alloc.Job.LookupTaskGroup("web").Services = []*structs.Service{service}

// task level service
task := alloc.LookupTask("web")
task.Services = []*structs.Service{service}

testHook := newCheckRestartHook(alloc, task, regWrap, consul.NoopRestarter())

err := testHook.Prestart(
t.Context(),
&interfaces.TaskPrestartRequest{TaskEnv: taskenv.NewEmptyTaskEnv()},
&interfaces.TaskPrestartResponse{},
)
must.NoError(t, err)
must.True(t, mockWatcher.AssertNumberOfCalls(t, "Watch", 2))
}

func TestCheckRestartHook_Exited(t *testing.T) {
logger := testlog.HCLogger(t)

alloc := mock.Alloc()
alloc.Job.Canonicalize()

mockWatcher := &regMock.MockUniversalWatcher{}
mockWatcher.On("Unwatch", tmock.Anything)

handler := regMock.NewServiceRegistrationHandler(logger)
handler.UniversalWatcher = mockWatcher

regWrap := wrapper.NewHandlerWrapper(logger, handler, handler)

testHook := newCheckRestartHook(alloc, alloc.LookupTask("web"), regWrap, consul.NoopRestarter())
testHook.checks = []*hookCheck{
{
providerType: "nomad",
},
}

err := testHook.Exited(
t.Context(),
&interfaces.TaskExitedRequest{},
&interfaces.TaskExitedResponse{},
)
must.NoError(t, err)
must.True(t, mockWatcher.AssertNumberOfCalls(t, "Unwatch", 1))
}

func TestCheckRestartHook_Stop(t *testing.T) {
logger := testlog.HCLogger(t)

alloc := mock.Alloc()
alloc.Job.Canonicalize()

mockWatcher := &regMock.MockUniversalWatcher{}
mockWatcher.On("Unwatch", tmock.Anything)

handler := regMock.NewServiceRegistrationHandler(logger)
handler.UniversalWatcher = mockWatcher

regWrap := wrapper.NewHandlerWrapper(logger, handler, handler)

testHook := newCheckRestartHook(alloc, alloc.LookupTask("web"), regWrap, consul.NoopRestarter())
testHook.checks = []*hookCheck{
{
providerType: "nomad",
},
}

err := testHook.Stop(
t.Context(),
&interfaces.TaskStopRequest{},
&interfaces.TaskStopResponse{},
)
must.NoError(t, err)
must.True(t, mockWatcher.AssertNumberOfCalls(t, "Unwatch", 1))
}
Loading
Loading