diff --git a/cmd/dev.go b/cmd/dev.go index 61027ab9..fe43a8b3 100644 --- a/cmd/dev.go +++ b/cmd/dev.go @@ -16,7 +16,6 @@ import ( "github.com/agentuity/cli/internal/util" "github.com/agentuity/go-common/env" cstr "github.com/agentuity/go-common/string" - csys "github.com/agentuity/go-common/sys" "github.com/agentuity/go-common/tui" "github.com/bep/debounce" "github.com/spf13/cobra" @@ -103,7 +102,10 @@ Examples: } defer websocketConn.Close() - projectServerCmd, err := dev.CreateRunProjectCmd(ctx, log, theproject, websocketConn, dir, orgId, port) + processCtx := context.Background() + var pid int + + projectServerCmd, err := dev.CreateRunProjectCmd(processCtx, log, theproject, websocketConn, dir, orgId, port) if err != nil { errsystem.New(errsystem.ErrInvalidConfiguration, err, errsystem.WithContextMessage("Failed to run project")).ShowErrorAndExit() } @@ -139,7 +141,7 @@ Examples: build(false) isDeliberateRestart = true log.Debug("killing project server") - dev.KillProjectServer(projectServerCmd) + dev.KillProjectServer(log, projectServerCmd, pid) log.Debug("killing project server done") } @@ -160,6 +162,8 @@ Examples: errsystem.New(errsystem.ErrInvalidConfiguration, err, errsystem.WithContextMessage(fmt.Sprintf("Failed to start project: %s", err))).ShowErrorAndExit() } + pid = projectServerCmd.Process.Pid + websocketConn.StartReadingMessages(ctx, log, port) devUrl := websocketConn.WebURL(appUrl) @@ -168,16 +172,16 @@ Examples: go func() { for { - log.Trace("waiting for project server to exit") + log.Trace("waiting for project server to exit (pid: %d)", pid) if err := projectServerCmd.Wait(); err != nil { - log.Error("project server exited with error: %s", err) + log.Error("project server (pid: %d) exited with error: %s", pid, err) } if projectServerCmd.ProcessState != nil { - log.Debug("project server exited with code %d", projectServerCmd.ProcessState.ExitCode()) + log.Debug("project server (pid: %d) exited with code %d", pid, projectServerCmd.ProcessState.ExitCode()) } else { - log.Debug("project server exited") + log.Debug("project server (pid: %d) exited", pid) } - log.Debug("isDeliberateRestart: %t", isDeliberateRestart) + log.Debug("isDeliberateRestart: %t, pid: %d", isDeliberateRestart, pid) if !isDeliberateRestart { return } @@ -186,31 +190,32 @@ Examples: if isDeliberateRestart { isDeliberateRestart = false log.Trace("restarting project server") - projectServerCmd, err = dev.CreateRunProjectCmd(ctx, log, theproject, websocketConn, dir, orgId, port) + projectServerCmd, err = dev.CreateRunProjectCmd(processCtx, log, theproject, websocketConn, dir, orgId, port) if err != nil { errsystem.New(errsystem.ErrInvalidConfiguration, err, errsystem.WithContextMessage("Failed to run project")).ShowErrorAndExit() } if err := projectServerCmd.Start(); err != nil { errsystem.New(errsystem.ErrInvalidConfiguration, err, errsystem.WithContextMessage(fmt.Sprintf("Failed to start project: %s", err))).ShowErrorAndExit() } + pid = projectServerCmd.Process.Pid + log.Trace("restarted project server (pid: %d)", pid) } } }() + teardown := func() { + watcher.Close(log) + websocketConn.Close() + dev.KillProjectServer(log, projectServerCmd, pid) + } + select { case <-websocketConn.Done(): log.Info("live dev connection closed, shutting down") - dev.KillProjectServer(projectServerCmd) - watcher.Close(log) + teardown() case <-ctx.Done(): log.Info("context done, shutting down") - websocketConn.Close() - watcher.Close(log) - case <-csys.CreateShutdownChannel(): - log.Info("shutdown signal received, shutting down") - dev.KillProjectServer(projectServerCmd) - websocketConn.Close() - watcher.Close(log) + teardown() } }, } diff --git a/go.mod b/go.mod index 8bfa8338..db91bd94 100644 --- a/go.mod +++ b/go.mod @@ -133,7 +133,7 @@ require ( go.uber.org/multierr v1.9.0 // indirect golang.org/x/net v0.38.0 // indirect golang.org/x/sync v0.12.0 // indirect - golang.org/x/sys v0.31.0 // indirect + golang.org/x/sys v0.32.0 golang.org/x/text v0.23.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20250303144028-a0af3efb3deb // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250303144028-a0af3efb3deb // indirect diff --git a/go.sum b/go.sum index 7f09d02a..62ad6e40 100644 --- a/go.sum +++ b/go.sum @@ -333,8 +333,8 @@ golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= -golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= +golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y= golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g= diff --git a/internal/dev/dev.go b/internal/dev/dev.go index 92acd9d5..4bdc3879 100644 --- a/internal/dev/dev.go +++ b/internal/dev/dev.go @@ -7,7 +7,6 @@ import ( "os" "os/exec" "strconv" - "syscall" "time" "github.com/agentuity/cli/internal/project" @@ -15,8 +14,20 @@ import ( "github.com/agentuity/go-common/logger" ) -func KillProjectServer(projectServerCmd *exec.Cmd) { +func KillProjectServer(logger logger.Logger, projectServerCmd *exec.Cmd, pid int) { + if pid > 0 { + processes, err := getProcessTree(logger, pid) + if err != nil { + logger.Error("error getting process tree for parent (pid: %d): %s", pid, err) + } + for _, childPid := range processes { + logger.Debug("killing child process (pid: %d)", childPid) + kill(logger, childPid) + } + } if projectServerCmd == nil || projectServerCmd.ProcessState == nil || projectServerCmd.ProcessState.Exited() { + logger.Debug("project server already exited (pid: %d)", pid) + kill(logger, pid) return } ch := make(chan struct{}, 1) @@ -26,8 +37,10 @@ func KillProjectServer(projectServerCmd *exec.Cmd) { }() if projectServerCmd.Process != nil { - // Try SIGINT first (Ctrl+C equivalent) - projectServerCmd.Process.Signal(syscall.SIGINT) + logger.Debug("killing parent process %d", pid) + if err := terminateProcess(logger, projectServerCmd); err != nil { + logger.Error("error terminating project server: %s", err) + } } // Wait a bit longer for SIGTERM to take effect diff --git a/internal/dev/dev_unix.go b/internal/dev/dev_unix.go new file mode 100644 index 00000000..65e3faae --- /dev/null +++ b/internal/dev/dev_unix.go @@ -0,0 +1,100 @@ +//go:build !windows +// +build !windows + +package dev + +import ( + "bytes" + "fmt" + "os/exec" + "strconv" + "strings" + "syscall" + "time" + + "github.com/agentuity/go-common/logger" +) + +func terminateProcess(logger logger.Logger, cmd *exec.Cmd) error { + logger.Debug("terminateProcess: %s", cmd) + if cmd.Process != nil { + // Get the process group ID (negative PID) + pgid, err := syscall.Getpgid(cmd.Process.Pid) + if err != nil { + // If we can't get the process group, just kill the process directly + cmd.Process.Signal(syscall.SIGINT) + } else { + // Kill the entire process group + syscall.Kill(-pgid, syscall.SIGINT) + } + + // Wait a short time for graceful shutdown + done := make(chan error, 1) + go func() { + done <- cmd.Wait() + }() + + select { + case <-time.After(5 * time.Second): + // If process hasn't terminated, use SIGKILL on the process group + if err == nil { + // Kill the entire process group with SIGKILL + syscall.Kill(-pgid, syscall.SIGKILL) + } else { + // Fallback to just killing the process + cmd.Process.Signal(syscall.SIGKILL) + } + case <-done: + // Process terminated gracefully + } + } + return nil +} + +// getProcessTree returns a list of all descendant PIDs of the given parent PID. +func getProcessTree(logger logger.Logger, parentPID int) ([]int, error) { + logger.Debug("getting process tree for parent (pid: %d)", parentPID) + cmd := exec.Command("ps", "-eo", "pid,ppid") // works on both macOS and Linux + var out bytes.Buffer + cmd.Stdout = &out + if err := cmd.Run(); err != nil { + logger.Debug("failed to run ps: %s", err) + return nil, fmt.Errorf("failed to run ps: %w", err) + } + + lines := strings.Split(out.String(), "\n") + pidMap := make(map[int][]int) // PPID -> []PID + + for _, line := range lines[1:] { // skip header + fields := strings.Fields(line) + if len(fields) != 2 { + continue + } + + pid, err1 := strconv.Atoi(fields[0]) + ppid, err2 := strconv.Atoi(fields[1]) + if err1 != nil || err2 != nil { + continue + } + + pidMap[ppid] = append(pidMap[ppid], pid) + } + + // Recursively collect descendants + var collect func(int) + descendants := []int{} + collect = func(ppid int) { + for _, child := range pidMap[ppid] { + descendants = append(descendants, child) + collect(child) + } + } + collect(parentPID) + + return descendants, nil +} + +func kill(logger logger.Logger, pid int) error { + logger.Debug("killing process (pid: %d)", pid) + return syscall.Kill(pid, syscall.SIGTERM) +} diff --git a/internal/dev/dev_windows.go b/internal/dev/dev_windows.go new file mode 100644 index 00000000..5001cb61 --- /dev/null +++ b/internal/dev/dev_windows.go @@ -0,0 +1,157 @@ +//go:build windows +// +build windows + +package dev + +import ( + "fmt" + "os/exec" + "unsafe" + + "github.com/agentuity/go-common/logger" + "golang.org/x/sys/windows" +) + +// getProcessTree returns all descendant PIDs for a given parent PID +func getProcessTree(logger logger.Logger, pid int) ([]int, error) { + var pids []int + + // Create a snapshot of all processes + snapshot, err := windows.CreateToolhelp32Snapshot(windows.TH32CS_SNAPPROCESS, 0) + if err != nil { + return nil, fmt.Errorf("failed to create process snapshot: %v", err) + } + defer windows.CloseHandle(snapshot) + + // Initialize process entry + var processEntry windows.ProcessEntry32 + processEntry.Size = uint32(unsafe.Sizeof(processEntry)) + + // Get first process + err = windows.Process32First(snapshot, &processEntry) + if err != nil { + return nil, fmt.Errorf("failed to get first process: %v", err) + } + + // Create a map to track all processes and their parent-child relationships + processMap := make(map[uint32][]uint32) + processNames := make(map[uint32]string) + + // First pass: build the process tree + for { + parentID := processEntry.ParentProcessID + processID := processEntry.ProcessID + + // Convert the process name from UTF-16 to string + name := windows.UTF16ToString(processEntry.ExeFile[:]) + + // Skip the System process (PID 4) as it's a special case + if processID != 4 { + processMap[parentID] = append(processMap[parentID], processID) + processNames[processID] = name + } + + err = windows.Process32Next(snapshot, &processEntry) + if err != nil { + break + } + } + + // Function to recursively get all descendants + var getDescendants func(parentID uint32, depth int) + getDescendants = func(parentID uint32, depth int) { + children, exists := processMap[parentID] + if !exists { + return + } + + for _, childID := range children { + // Skip the System process (PID 4) and the parent process itself + if childID != 4 && childID != uint32(pid) { + pids = append(pids, int(childID)) + // Log the process tree structure + logger.Debug("Found process: %s (pid: %d, parent: %d, depth: %d)", + processNames[childID], childID, parentID, depth) + getDescendants(childID, depth+1) + } + } + } + + // Start the recursive search from the given PID + logger.Debug("Starting process tree search from PID: %d", pid) + getDescendants(uint32(pid), 0) + + return pids, nil +} + +// kill terminates a process by PID +func kill(logger logger.Logger, pid int) error { + // Open the process with terminate access + handle, err := windows.OpenProcess(windows.PROCESS_TERMINATE|windows.PROCESS_QUERY_INFORMATION, false, uint32(pid)) + if err != nil { + return fmt.Errorf("failed to open process: %v", err) + } + defer windows.CloseHandle(handle) + + // Get process name for logging + var name [windows.MAX_PATH]uint16 + var size uint32 = windows.MAX_PATH + err = windows.QueryFullProcessImageName(handle, 0, &name[0], &size) + processName := "unknown" + if err == nil { + processName = windows.UTF16ToString(name[:size]) + } + + logger.Debug("Killing process: %s (pid: %d)", processName, pid) + + // Terminate the process + err = windows.TerminateProcess(handle, 1) + if err != nil { + return fmt.Errorf("failed to terminate process: %v", err) + } + + return nil +} + +func terminateProcess(logger logger.Logger, cmd *exec.Cmd) error { + logger.Debug("terminateProcess: %s", cmd) + if cmd.Process != nil { + // Create a job object + job, err := windows.CreateJobObject(nil, nil) + if err != nil { + return fmt.Errorf("failed to create job object: %v", err) + } + defer windows.CloseHandle(job) + + // Configure the job object to terminate all processes when the job is terminated + info := windows.JOBOBJECT_EXTENDED_LIMIT_INFORMATION{ + BasicLimitInformation: windows.JOBOBJECT_BASIC_LIMIT_INFORMATION{ + LimitFlags: windows.JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE, + }, + } + + // Set the job object information + _, err = windows.SetInformationJobObject( + job, + windows.JobObjectExtendedLimitInformation, + uintptr(unsafe.Pointer(&info)), + uint32(unsafe.Sizeof(info)), + ) + if err != nil { + return fmt.Errorf("failed to set job object information: %v", err) + } + + // Assign the process to the job object + err = windows.AssignProcessToJobObject(job, windows.Handle(cmd.Process.Pid)) + if err != nil { + return fmt.Errorf("failed to assign process to job object: %v", err) + } + + // Terminate the job object, which will kill all processes in the job + err = windows.TerminateJobObject(job, 1) + if err != nil { + return fmt.Errorf("failed to terminate job object: %v", err) + } + } + return nil +} diff --git a/internal/dev/watcher.go b/internal/dev/watcher.go index 105e7b7a..26b7c481 100644 --- a/internal/dev/watcher.go +++ b/internal/dev/watcher.go @@ -16,6 +16,13 @@ type FileWatcher struct { dir string } +var ignorePatterns = []string{ + "__pycache__", + "__test__", + "node_modules", + ".pyc", +} + func NewWatcher(logger logger.Logger, dir string, patterns []string, callback func(string)) (*FileWatcher, error) { watcher, err := fsnotify.NewWatcher() if err != nil { @@ -33,6 +40,11 @@ func NewWatcher(logger logger.Logger, dir string, patterns []string, callback fu if err != nil { return err } + for _, ignorePattern := range ignorePatterns { + if strings.Contains(path, ignorePattern) { + return nil + } + } if fw.matchesPattern(logger, path) { logger.Trace("Adding path to watcher: %s", path) return watcher.Add(path)