Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 0 additions & 65 deletions pkg/transport/proxy/manager.go

This file was deleted.

119 changes: 50 additions & 69 deletions pkg/workloads/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import (
"github.com/stacklok/toolhive/pkg/secrets"
"github.com/stacklok/toolhive/pkg/state"
"github.com/stacklok/toolhive/pkg/transport"
"github.com/stacklok/toolhive/pkg/transport/proxy"
"github.com/stacklok/toolhive/pkg/workloads/statuses"
"github.com/stacklok/toolhive/pkg/workloads/types"
)
Expand Down Expand Up @@ -575,16 +574,39 @@ func (d *defaultManager) getWorkloadContainer(ctx context.Context, name string)
return &container, nil
}

// stopProcess stops the proxy process associated with the named workload.
//
// It asks the status manager for the proxy PID, asks the process package to
// kill that PID, and then removes the PID file for name. Every failure is
// logged rather than returned, so the call is best-effort.
//
// Parameters:
//   - ctx: passed through to the status manager's PID lookup.
//   - name: workload base name used for the PID lookup and the PID file
//     removal; an empty name is logged as a warning and nothing else is done.
func (d *defaultManager) stopProcess(ctx context.Context, name string) {
	if name == "" {
		logger.Warnf("Warning: Could not find base container name in labels")
		return
	}

	// Without a PID there is nothing to kill. The underlying error is included
	// in the log so that lookup failures other than a missing PID are visible.
	pid, err := d.statuses.GetWorkloadPID(ctx, name)
	if err != nil {
		logger.Errorf("No PID file found for %s, proxy may not be running in detached mode: %v", name, err)
		return
	}

	// PID found, try to kill the process.
	logger.Infof("Stopping proxy process (PID: %d)...", pid)
	if err := process.KillProcess(pid); err != nil {
		logger.Warnf("Warning: Failed to kill proxy process: %v", err)
	} else {
		logger.Info("Proxy process stopped")
	}

	// The PID file is removed whether or not the kill succeeded.
	if err := process.RemovePIDFile(name); err != nil {
		logger.Warnf("Warning: Failed to remove PID file: %v", err)
	}
}

// stopProxyIfNeeded stops the proxy process if the workload has a base name.
// name is only used in log messages; baseName is the identifier handed to the
// stop and PID-reset calls. When baseName is empty, only the initial log line
// is emitted and nothing is stopped.
func (d *defaultManager) stopProxyIfNeeded(ctx context.Context, name, baseName string) {
logger.Infof("Removing proxy process for %s...", name)
if baseName != "" {
proxy.StopProcess(baseName)
// TODO: refactor the StopProcess function to stop dealing explicitly with PID files.
// Note that this is not a blocker for k8s since this code path is not called there.
if err := d.statuses.ResetWorkloadPID(ctx, baseName); err != nil {
// A failed PID reset is logged and otherwise ignored.
logger.Warnf("Warning: Failed to reset workload %s PID: %v", name, err)
}
d.stopProcess(ctx, baseName)
}
}

Expand Down Expand Up @@ -798,32 +820,40 @@ func (d *defaultManager) restartContainerWorkload(ctx context.Context, name stri
workloadName = name
}

// Get workload state information using the original name
workloadState, err := d.getWorkloadState(ctx, name)
if err != nil {
// Get workload status using the status manager
workload, err := d.statuses.GetWorkload(ctx, name)
if err != nil && !errors.Is(err, rt.ErrWorkloadNotFound) {
return err
}

// Check if already running - use container name for this check
if d.isWorkloadAlreadyRunning(containerName, workloadState) {
// Check if already running - compare status to WorkloadStatusRunning
if err == nil && workload.Status == rt.WorkloadStatusRunning {
logger.Infof("Container %s is already running", containerName)
return nil
}

// Load runner configuration from state
mcpRunner, err := d.loadRunnerFromState(ctx, workloadState.BaseName)
mcpRunner, err := d.loadRunnerFromState(ctx, workloadName)
if err != nil {
return fmt.Errorf("failed to load state for %s: %v", workloadState.BaseName, err)
return fmt.Errorf("failed to load state for %s: %v", workloadName, err)
}

// Set workload status to starting - use the workload name for status operations
if err := d.statuses.SetWorkloadStatus(ctx, workloadName, rt.WorkloadStatusStarting, ""); err != nil {
logger.Warnf("Failed to set workload %s status to starting: %v", workloadName, err)
}
logger.Infof("Loaded configuration from state for %s", workloadState.BaseName)
logger.Infof("Loaded configuration from state for %s", workloadName)

// Stop container if running but proxy is not - use the container name for runtime operations
if err := d.stopContainerIfNeeded(ctx, containerName, workloadName, workloadState); err != nil {
return err
// Stop container if needed - since workload is not in running status, check if container needs stopping
if container.IsRunning() {
logger.Infof("Container %s is running but workload is not in running state. Stopping container...", containerName)
if err := d.runtime.StopWorkload(ctx, containerName); err != nil {
if statusErr := d.statuses.SetWorkloadStatus(ctx, workloadName, rt.WorkloadStatusError, ""); statusErr != nil {
logger.Warnf("Failed to set workload %s status to error: %v", workloadName, statusErr)
}
return fmt.Errorf("failed to stop container %s: %v", containerName, err)
}
logger.Infof("Container %s stopped", containerName)
}

// Start the workload with background context to avoid timeout cancellation
Expand All @@ -840,34 +870,6 @@ type workloadState struct {
ProxyRunning bool
}

// getWorkloadState builds a workloadState snapshot for the named workload.
//
// The runtime is asked for the workload's container:
//   - found: Running comes from the container and BaseName from its labels;
//   - rt.ErrWorkloadNotFound: two warnings are logged, name itself is used as
//     the base name, and the workload is recorded as not running;
//   - any other error: nil and a wrapped error are returned.
//
// In both non-error cases ProxyRunning is filled from proxy.IsRunning for the
// resolved base name.
func (d *defaultManager) getWorkloadState(ctx context.Context, name string) (*workloadState, error) {
	var (
		baseName string
		running  bool
	)

	info, lookupErr := d.runtime.GetWorkloadInfo(ctx, name)
	switch {
	case lookupErr == nil:
		// Container found: take its running flag and labelled base name.
		running = info.IsRunning()
		baseName = labels.GetContainerBaseName(info.Labels)
	case errors.Is(lookupErr, rt.ErrWorkloadNotFound):
		// No container: fall back to treating the given name as the base name.
		logger.Warnf("Warning: Failed to find container: %v", lookupErr)
		logger.Warnf("Trying to find state with name %s directly...", name)
		baseName = name
	default:
		return nil, fmt.Errorf("failed to find workload %s: %v", name, lookupErr)
	}

	return &workloadState{
		BaseName:     baseName,
		Running:      running,
		ProxyRunning: proxy.IsRunning(baseName),
	}, nil
}

// getRemoteWorkloadState retrieves the current state of a remote workload
func (d *defaultManager) getRemoteWorkloadState(ctx context.Context, name, baseName string) *workloadState {
workloadSt := &workloadState{
Expand All @@ -884,9 +886,6 @@ func (d *defaultManager) getRemoteWorkloadState(ctx context.Context, name, baseN
workloadSt.Running = workload.Status == rt.WorkloadStatusRunning
}

// Check if the detached process is actually running
workloadSt.ProxyRunning = proxy.IsRunning(baseName)

return workloadSt
}

Expand All @@ -899,25 +898,6 @@ func (*defaultManager) isWorkloadAlreadyRunning(name string, workloadSt *workloa
return false
}

// stopContainerIfNeeded stops the container through the runtime when
// workloadSt reports it as running; otherwise it does nothing.
//
// containerName is used for the runtime stop call and log messages, while
// workloadName is used for the status update. If the stop fails, the workload
// status is set to rt.WorkloadStatusError (a failure of that update is only
// logged) and an error describing the stop failure is returned.
func (d *defaultManager) stopContainerIfNeeded(
	ctx context.Context, containerName, workloadName string, workloadSt *workloadState,
) error {
	if running := workloadSt.Running; !running {
		return nil
	}

	logger.Infof("Container %s is running but proxy is not. Stopping container...", containerName)

	stopErr := d.runtime.StopWorkload(ctx, containerName)
	if stopErr == nil {
		logger.Infof("Container %s stopped", containerName)
		return nil
	}

	// Record the failure on the workload before reporting it to the caller.
	statusErr := d.statuses.SetWorkloadStatus(ctx, workloadName, rt.WorkloadStatusError, "")
	if statusErr != nil {
		logger.Warnf("Failed to set workload %s status to error: %v", workloadName, statusErr)
	}
	return fmt.Errorf("failed to stop container %s: %v", containerName, stopErr)
}

// startWorkload starts the workload in either foreground or background mode
func (d *defaultManager) startWorkload(ctx context.Context, name string, mcpRunner *runner.Runner, foreground bool) error {
logger.Infof("Starting tooling server %s...", name)
Expand Down Expand Up @@ -1024,8 +1004,9 @@ func (d *defaultManager) stopSingleContainerWorkload(ctx context.Context, worklo
if labels.IsAuxiliaryWorkload(workload.Labels) {
logger.Debugf("Skipping proxy stop for auxiliary workload %s", name)
} else {
proxy.StopProcess(name)
d.stopProcess(ctx, name)
}

// TODO: refactor the StopProcess function to stop dealing explicitly with PID files.
// Note that this is not a blocker for k8s since this code path is not called there.
if err := d.statuses.ResetWorkloadPID(ctx, name); err != nil {
Expand Down
4 changes: 4 additions & 0 deletions pkg/workloads/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1252,6 +1252,8 @@ func TestDefaultManager_updateSingleWorkload(t *testing.T) {
State: "running",
Labels: map[string]string{"toolhive-basename": "test-workload"},
}, nil)
// Mock GetWorkloadPID call from stopProcess
sm.EXPECT().GetWorkloadPID(gomock.Any(), "test-workload").Return(1234, nil)
rt.EXPECT().StopWorkload(gomock.Any(), "test-workload").Return(nil)
sm.EXPECT().ResetWorkloadPID(gomock.Any(), "test-workload").Return(nil)

Expand Down Expand Up @@ -1288,6 +1290,8 @@ func TestDefaultManager_updateSingleWorkload(t *testing.T) {
State: "running",
Labels: map[string]string{"toolhive-basename": "test-workload"},
}, nil)
// Mock GetWorkloadPID call from stopProcess
sm.EXPECT().GetWorkloadPID(gomock.Any(), "test-workload").Return(1234, nil)
rt.EXPECT().StopWorkload(gomock.Any(), "test-workload").Return(nil)
sm.EXPECT().ResetWorkloadPID(gomock.Any(), "test-workload").Return(nil)

Expand Down
Loading
Loading