diff --git a/shortcuts/mail/mail_watch.go b/shortcuts/mail/mail_watch.go index c5699427..bf368691 100644 --- a/shortcuts/mail/mail_watch.go +++ b/shortcuts/mail/mail_watch.go @@ -424,6 +424,9 @@ var MailWatch = common.Shortcut{ larkws.WithLogger(sdkLogger), ) + watchCtx, cancelWatch := context.WithCancel(ctx) + defer cancelWatch() + sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) go func() { @@ -433,6 +436,7 @@ var MailWatch = common.Shortcut{ } }() <-sigCh + signal.Stop(sigCh) info(fmt.Sprintf("\nShutting down... (received %d events)", eventCount)) info("Unsubscribing mailbox events...") if unsubErr := unsubscribe(); unsubErr != nil { @@ -440,19 +444,26 @@ var MailWatch = common.Shortcut{ } else { info("Mailbox unsubscribed.") } - signal.Stop(sigCh) - os.Exit(0) + cancelWatch() }() info("Connected. Waiting for mail events... (Ctrl+C to stop)") - if err := cli.Start(ctx); err != nil { - unsubscribe() //nolint:errcheck // best-effort cleanup - return output.ErrNetwork("WebSocket connection failed: %v", err) + if err := cli.Start(watchCtx); err != nil { + return handleMailWatchStartError(err, watchCtx, unsubscribe) } return nil }, } +func handleMailWatchStartError(err error, watchCtx context.Context, unsubscribe func() error) error { + if watchCtx.Err() != nil { + // Graceful shutdown via signal cancellation; not an error. + return nil + } + unsubscribe() //nolint:errcheck // best-effort cleanup + return output.ErrNetwork("WebSocket connection failed: %v", err) +} + // extractMailEventBody extracts the event body from the Lark event envelope. func extractMailEventBody(data map[string]interface{}) map[string]interface{} { // V2 envelope: { header: {...}, event: { mail_address, message_id, ... } } diff --git a/shortcuts/mail/mail_watch_test.go b/shortcuts/mail/mail_watch_test.go index 02476fbd..e27c9d14 100644 --- a/shortcuts/mail/mail_watch_test.go +++ b/shortcuts/mail/mail_watch_test.go @@ -539,6 +539,40 @@ func TestWrapWatchSubscribeErrorExitError(t *testing.T) { } } +func TestHandleMailWatchStartErrorGracefulShutdownSkipsCleanup(t *testing.T) { + watchCtx, cancel := context.WithCancel(context.Background()) + cancel() + + unsubscribeCalled := false + err := handleMailWatchStartError(assertErr("context canceled"), watchCtx, func() error { + unsubscribeCalled = true + return nil + }) + if err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if unsubscribeCalled { + t.Fatal("unsubscribe should not be called for graceful shutdown") + } +} + +func TestHandleMailWatchStartErrorNetworkFailureCleansUp(t *testing.T) { + unsubscribeCalled := false + err := handleMailWatchStartError(assertErr("boom"), context.Background(), func() error { + unsubscribeCalled = true + return nil + }) + if err == nil { + t.Fatal("expected error") + } + if !unsubscribeCalled { + t.Fatal("expected unsubscribe to be called for startup failure") + } + if !strings.Contains(err.Error(), "WebSocket connection failed: boom") { + t.Fatalf("unexpected error: %v", err) + } +} + // --- watchFetchFormat --- func TestWatchFetchFormat(t *testing.T) {