diff --git a/integration_test.go b/integration_test.go index e61f206..9c1a938 100644 --- a/integration_test.go +++ b/integration_test.go @@ -1234,3 +1234,112 @@ func TestClient_502CommandFallsBackToBackup(t *testing.T) { } } +// --- Keepalive tests --- + +// TestKeepalive_KeepsConnectionAlive verifies that the keepalive probe is sent +// and, when the server responds correctly, the connection remains alive and can +// serve subsequent real requests. +func TestKeepalive_KeepsConnectionAlive(t *testing.T) { + keepaliveSeen := make(chan struct{}, 1) + + conn := mockServer(t, func(s net.Conn) { + _, _ = s.Write([]byte("200 server ready\r\n")) + + buf := make([]byte, 256) + for { + n, err := s.Read(buf) + if err != nil { + return + } + cmd := string(buf[:n]) + switch { + case cmd == "DATE\r\n": + select { + case keepaliveSeen <- struct{}{}: + default: + } + _, _ = s.Write([]byte("111 20060102150405\r\n")) + case len(cmd) > 4 && cmd[:4] == "STAT": + _, _ = s.Write([]byte("223 1 exists\r\n")) + } + } + }) + + reqCh := make(chan *Request, 1) + nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, nil, nil) + if err != nil { + t.Fatalf("newNNTPConnectionFromConn() error = %v", err) + } + nc.keepaliveInterval = 20 * time.Millisecond + nc.keepaliveCommand = "DATE" + + go nc.Run() + t.Cleanup(func() { _ = nc.Close() }) + + // Wait for at least one keepalive probe. + select { + case <-keepaliveSeen: + case <-time.After(2 * time.Second): + t.Fatal("timeout: keepalive probe not sent") + } + + // Verify the connection is still alive by sending a real request. 
+ respCh := make(chan Response, 1) + reqCh <- &Request{ + Ctx: context.Background(), + Payload: []byte("STAT \r\n"), + RespCh: respCh, + } + select { + case resp := <-respCh: + if resp.Err != nil { + t.Fatalf("real request after keepalive: error = %v", resp.Err) + } + if resp.StatusCode != 223 { + t.Errorf("real request: StatusCode = %d, want 223", resp.StatusCode) + } + case <-time.After(2 * time.Second): + t.Fatal("timeout: real request after keepalive timed out") + } +} + +// TestKeepalive_DeadConnection verifies that when the server drops the connection +// in response to a keepalive probe, Run() returns (allowing runConnSlot to reconnect). +func TestKeepalive_DeadConnection(t *testing.T) { + conn := mockServer(t, func(s net.Conn) { + _, _ = s.Write([]byte("200 server ready\r\n")) + + buf := make([]byte, 256) + // Wait for the keepalive command, then close without responding. + for { + n, err := s.Read(buf) + if err != nil { + return + } + if string(buf[:n]) == "DATE\r\n" { + // Drop connection without responding. + _ = s.Close() + return + } + } + }) + + reqCh := make(chan *Request) + nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, nil, nil) + if err != nil { + t.Fatalf("newNNTPConnectionFromConn() error = %v", err) + } + nc.keepaliveInterval = 20 * time.Millisecond + nc.keepaliveCommand = "DATE" + + go nc.Run() + + // Run() should return once the keepalive detects the dead connection. + select { + case <-nc.Done(): + // Good: connection was detected dead and Run() returned. 
+ case <-time.After(2 * time.Second): + t.Fatal("timeout: Run() should have returned after keepalive detected dead connection") + } +} + diff --git a/nntp.go b/nntp.go index ffd5ec9..4581cb3 100644 --- a/nntp.go +++ b/nntp.go @@ -125,9 +125,11 @@ type NNTPConnection struct { Greeting NNTPResponse - firstReq *Request // bootstrap request from connection slot - idleTimeout time.Duration // 0 = no idle timeout - providerName string // set by runConnSlot; used for error context + firstReq *Request // bootstrap request from connection slot + idleTimeout time.Duration // 0 = no idle timeout + keepaliveInterval time.Duration // 0 = no keepalive + keepaliveCommand string // NNTP command for keepalive probe (e.g. "DATE") + providerName string // set by runConnSlot; used for error context stats *providerStats // nil for standalone connections @@ -272,6 +274,19 @@ func safeClose[T any](ch chan T) { close(ch) } +// keepaliveExpectedCode returns the expected NNTP status code for the given +// keepalive command: DATE→111, HELP→100, CAPABILITIES→101, default→111. +func keepaliveExpectedCode(cmd string) int { + switch cmd { + case "HELP": + return 100 + case "CAPABILITIES": + return 101 + default: + return 111 + } +} + func failRequest(ch chan Response, err error) { defer func() { _ = recover() }() select { @@ -458,7 +473,7 @@ func (g *connGate) snapshot() (maxSlots, running int) { // runConnSlot is the slot goroutine that manages the lifecycle of a single // connection: IDLE → CONNECTING → ACTIVE → (death/idle) → IDLE. 
-func runConnSlot(ctx context.Context, reqCh <-chan *Request, prioCh <-chan *Request, hotReqCh <-chan *Request, hotPrioCh <-chan *Request, factory ConnFactory, inflight int, auth Auth, idleTimeout time.Duration, gate *connGate, stats *providerStats, providerName string, wg *sync.WaitGroup) { +func runConnSlot(ctx context.Context, reqCh <-chan *Request, prioCh <-chan *Request, hotReqCh <-chan *Request, hotPrioCh <-chan *Request, factory ConnFactory, inflight int, auth Auth, idleTimeout time.Duration, keepaliveInterval time.Duration, keepaliveCommand string, gate *connGate, stats *providerStats, providerName string, wg *sync.WaitGroup) { defer wg.Done() // Shared read buffer persists across reconnections to avoid re-growing. @@ -549,6 +564,8 @@ func runConnSlot(ctx context.Context, reqCh <-chan *Request, prioCh <-chan *Requ nc.providerName = providerName nc.hotReqCh = hotReqCh nc.hotPrioCh = hotPrioCh + nc.keepaliveInterval = keepaliveInterval + nc.keepaliveCommand = keepaliveCommand gate.markRunning() nc.Run() // blocks until death or idle timeout gate.markNotRunning() @@ -713,6 +730,12 @@ mainLoop: defer idleTimer.Stop() } + // Set up keepalive timer (nil if no keepalive configured). + var keepaliveCh <-chan time.Time + if c.keepaliveInterval > 0 { + keepaliveCh = time.After(c.keepaliveInterval) + } + for { select { case <-c.ctx.Done(): @@ -749,6 +772,7 @@ mainLoop: // from nil channels block forever in select and are excluded. var req *Request var ok bool + var didKeepalive bool if c.prioCh != nil { // Try hot priority (non-blocking). select { @@ -771,6 +795,8 @@ mainLoop: <-c.inflightSem c.waitForInflightDrain() return + case <-keepaliveCh: + didKeepalive = true } } } @@ -784,8 +810,45 @@ mainLoop: <-c.inflightSem c.waitForInflightDrain() return + case <-keepaliveCh: + didKeepalive = true } } + + // Keepalive probe: send a lightweight command through the normal pipeline + // so readerLoop can match the response in FIFO order. 
+ // inflightSem is already held; readerLoop is responsible for releasing it. + if didKeepalive { + keepaliveCh = time.After(c.keepaliveInterval) // reset regardless of outcome + kaCh := make(chan Response, 1) + kaReq := &Request{ + Payload: []byte(c.keepaliveCommand + "\r\n"), + RespCh: kaCh, + Ctx: context.Background(), + } + if _, err := bw.Write(kaReq.Payload); err != nil { + _ = c.conn.Close() + c.failOutstanding() + return + } + if err := bw.Flush(); err != nil { + _ = c.conn.Close() + c.failOutstanding() + return + } + c.pending <- kaReq + select { + case resp := <-kaCh: + if resp.Err != nil || resp.StatusCode != keepaliveExpectedCode(c.keepaliveCommand) { + _ = c.conn.Close() + c.failOutstanding() + return + } + case <-c.ctx.Done(): + return + } + continue + } if !ok { <-c.inflightSem return @@ -1083,6 +1146,18 @@ type Provider struct { ThrottleRestore time.Duration // 0 defaults to 30s KeepAlive time.Duration // TCP keep-alive interval; 0 defaults to 30s; negative disables ReconnectDelay time.Duration // 0 disables auto-reconnect after 502; when set, re-adds provider after this delay + + // KeepaliveInterval, if non-zero, sends a lightweight NNTP command + // periodically when the connection is idle, to detect zombie connections + // before a real request arrives. Recommended: 30s–60s. + // Disabled when SkipPing is true and KeepaliveCommand is empty. + KeepaliveInterval time.Duration + + // KeepaliveCommand is the NNTP command sent as a keepalive probe. + // Defaults to "DATE" (response 111). Use "HELP" (response 100) or + // "CAPABILITIES" (response 101) for providers that do not support DATE. + // Ignored when KeepaliveInterval is 0. + KeepaliveCommand string } type providerGroup struct { @@ -1254,9 +1329,23 @@ func (c *Client) startProviderGroup(p Provider, index int) *providerGroup { pingCancel() } + // Resolve keepalive settings. If SkipPing is true and no explicit command + // is set, keepalive is disabled (we don't know which command the server supports). 
+ kaInterval := p.KeepaliveInterval + kaCmd := p.KeepaliveCommand + if kaInterval > 0 { + if kaCmd == "" { + if p.SkipPing { + kaInterval = 0 // disable: no safe probe command known + } else { + kaCmd = "DATE" + } + } + } + for range p.Connections { c.wg.Add(1) - go runConnSlot(gctx, g.reqCh, g.prioCh, g.hotReqCh, g.hotPrioCh, factory, inflight, p.Auth, p.IdleTimeout, gate, &g.stats, name, &c.wg) + go runConnSlot(gctx, g.reqCh, g.prioCh, g.hotReqCh, g.hotPrioCh, factory, inflight, p.Auth, p.IdleTimeout, kaInterval, kaCmd, gate, &g.stats, name, &c.wg) } return g