diff --git a/integration_test.go b/integration_test.go index 9c1a938..f0887bd 100644 --- a/integration_test.go +++ b/integration_test.go @@ -24,7 +24,7 @@ func TestNNTPConnection_Greeting(t *testing.T) { }) reqCh := make(chan *Request) - nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, nil, nil) + nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, "", nil, nil) if err != nil { t.Fatalf("newNNTPConnectionFromConn() error = %v", err) } @@ -39,7 +39,7 @@ func TestNNTPConnection_GreetingReject(t *testing.T) { }) reqCh := make(chan *Request) - _, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, nil, nil) + _, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, "", nil, nil) if err == nil { t.Fatal("expected error for 502 greeting") } @@ -75,7 +75,7 @@ func TestNNTPConnection_Auth(t *testing.T) { nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{ Username: "testuser", Password: "testpass", - }, nil, nil) + }, "", nil, nil) if err != nil { t.Fatalf("auth error = %v", err) } @@ -100,7 +100,7 @@ func TestNNTPConnection_AuthReject(t *testing.T) { _, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{ Username: "testuser", Password: "wrongpass", - }, nil, nil) + }, "", nil, nil) if err == nil { t.Fatal("expected auth rejection error") } @@ -120,7 +120,7 @@ func TestNNTPConnection_RunSingleRequest(t *testing.T) { }) reqCh := make(chan *Request, 1) - nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, nil, nil) + nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, "", nil, nil) if err != nil { t.Fatalf("connection error = %v", err) } @@ -162,7 +162,7 @@ func TestNNTPConnection_RunBodyRequest(t *testing.T) { }) reqCh := make(chan *Request, 1) - nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, nil, nil) + nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, "", nil, nil) if err != nil { t.Fatalf("connection error = %v", err) } @@ -224,7 +224,7 @@ func TestNNTPConnection_RunPipelined(t *testing.T) { }) reqCh := make(chan *Request, 3) - nc, err := newNNTPConnectionFromConn(context.Background(), conn, 3, reqCh, nil, Auth{}, nil, nil) + nc, err := newNNTPConnectionFromConn(context.Background(), conn, 3, reqCh, nil, Auth{}, "", nil, nil) if err != nil { t.Fatalf("connection error = %v", err) } @@ -269,7 +269,7 @@ func TestNNTPConnection_CancelledRequest(t *testing.T) { }) reqCh := make(chan *Request, 1) - nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, nil, nil) + nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, "", nil, nil) if err != nil { t.Fatalf("connection error = %v", err) } @@ -311,7 +311,7 @@ func TestNNTPConnection_IdleTimeout(t *testing.T) { }) reqCh := make(chan *Request) - nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, nil, nil) + nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, "", nil, nil) if err != nil { t.Fatalf("connection error = %v", err) } @@ -871,7 +871,7 @@ func TestReadOneResponse(t *testing.T) { }) reqCh := make(chan *Request) - nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, nil, nil) + nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, "", nil, nil) if err != nil { t.Fatalf("connection error = %v", err) } @@ -1266,7 +1266,7 @@ func TestKeepalive_KeepsConnectionAlive(t *testing.T) { }) reqCh := make(chan *Request, 1) - nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, nil, nil) + nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, "", nil, nil) if err != nil { t.Fatalf("newNNTPConnectionFromConn() error = %v", err) } @@ -1325,7 +1325,7 @@ func TestKeepalive_DeadConnection(t *testing.T) { }) reqCh := make(chan *Request) - nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, nil, nil) + nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, "", nil, nil) if err != nil { t.Fatalf("newNNTPConnectionFromConn() error = %v", err) } diff --git a/nntp.go b/nntp.go index 4581cb3..304c10c 100644 --- a/nntp.go +++ b/nntp.go @@ -130,6 +130,7 @@ type NNTPConnection struct { keepaliveInterval time.Duration // 0 = no keepalive keepaliveCommand string // NNTP command for keepalive probe (e.g. "DATE") providerName string // set by runConnSlot; used for error context + userAgent string stats *providerStats // nil for standalone connections @@ -165,7 +166,7 @@ func newNetConn(ctx context.Context, addr string, tlsConfig *tls.Config, keepAli return conn, nil } -func newNNTPConnectionFromConn(ctx context.Context, conn net.Conn, inflightLimit int, reqCh <-chan *Request, prioCh <-chan *Request, auth Auth, sharedBuf *readBuffer, stats *providerStats) (*NNTPConnection, error) { +func newNNTPConnectionFromConn(ctx context.Context, conn net.Conn, inflightLimit int, reqCh <-chan *Request, prioCh <-chan *Request, auth Auth, userAgent string, sharedBuf *readBuffer, stats *providerStats) (*NNTPConnection, error) { if ctx == nil { ctx = context.Background() } @@ -180,16 +181,17 @@ func newNNTPConnectionFromConn(ctx context.Context, conn net.Conn, inflightLimit } c := &NNTPConnection{ - conn: conn, - ctx: cctx, - cancel: cancel, - reqCh: reqCh, - prioCh: prioCh, - pending: make(chan *Request, inflightLimit), - inflightSem: make(chan struct{}, inflightLimit), - rb: rb, - stats: stats, - done: make(chan struct{}), + conn: conn, + ctx: cctx, + cancel: cancel, + reqCh: reqCh, + prioCh: prioCh, + pending: make(chan *Request, inflightLimit), + inflightSem: make(chan struct{}, inflightLimit), + rb: rb, + stats: stats, + done: make(chan struct{}), + userAgent: userAgent, } // Server greeting is sent immediately upon connect. @@ -216,13 +218,13 @@ func newNNTPConnectionFromConn(ctx context.Context, conn net.Conn, inflightLimit return c, nil } -func NewNNTPConnection(ctx context.Context, addr string, tlsConfig *tls.Config, inflightLimit int, reqCh <-chan *Request, auth Auth) (*NNTPConnection, error) { +func NewNNTPConnection(ctx context.Context, addr string, tlsConfig *tls.Config, inflightLimit int, reqCh <-chan *Request, auth Auth, userAgent string) (*NNTPConnection, error) { conn, err := newNetConn(ctx, addr, tlsConfig, 0) if err != nil { return nil, err } - c, err := newNNTPConnectionFromConn(ctx, conn, inflightLimit, reqCh, nil, auth, nil, nil) + c, err := newNNTPConnectionFromConn(ctx, conn, inflightLimit, reqCh, nil, auth, userAgent, nil, nil) if err != nil { _ = conn.Close() return nil, err @@ -473,7 +475,7 @@ func (g *connGate) snapshot() (maxSlots, running int) { // runConnSlot is the slot goroutine that manages the lifecycle of a single // connection: IDLE → CONNECTING → ACTIVE → (death/idle) → IDLE. -func runConnSlot(ctx context.Context, reqCh <-chan *Request, prioCh <-chan *Request, hotReqCh <-chan *Request, hotPrioCh <-chan *Request, factory ConnFactory, inflight int, auth Auth, idleTimeout time.Duration, keepaliveInterval time.Duration, keepaliveCommand string, gate *connGate, stats *providerStats, providerName string, wg *sync.WaitGroup) { +func runConnSlot(ctx context.Context, reqCh <-chan *Request, prioCh <-chan *Request, hotReqCh <-chan *Request, hotPrioCh <-chan *Request, factory ConnFactory, inflight int, auth Auth, userAgent string, idleTimeout time.Duration, keepaliveInterval time.Duration, keepaliveCommand string, gate *connGate, stats *providerStats, providerName string, wg *sync.WaitGroup) { defer wg.Done() // Shared read buffer persists across reconnections to avoid re-growing. @@ -533,7 +535,7 @@ func runConnSlot(ctx context.Context, reqCh <-chan *Request, prioCh <-chan *Requ continue } - nc, err := newNNTPConnectionFromConn(ctx, conn, inflight, reqCh, prioCh, auth, &sharedBuf, stats) + nc, err := newNNTPConnectionFromConn(ctx, conn, inflight, reqCh, prioCh, auth, userAgent, &sharedBuf, stats) if err != nil { _ = conn.Close() failRequest(firstReq.RespCh, fmt.Errorf("%s: %w", providerName, err)) @@ -1158,6 +1160,9 @@ type Provider struct { // "CAPABILITIES" (response 101) for providers that do not support DATE. // Ignored when KeepaliveInterval is 0. KeepaliveCommand string + + // UserAgent identifies this client to the NNTP server. Empty string disables it. + UserAgent string } type providerGroup struct { @@ -1345,7 +1350,7 @@ func (c *Client) startProviderGroup(p Provider, index int) *providerGroup { for range p.Connections { c.wg.Add(1) - go runConnSlot(gctx, g.reqCh, g.prioCh, g.hotReqCh, g.hotPrioCh, factory, inflight, p.Auth, p.IdleTimeout, kaInterval, kaCmd, gate, &g.stats, name, &c.wg) + go runConnSlot(gctx, g.reqCh, g.prioCh, g.hotReqCh, g.hotPrioCh, factory, inflight, p.Auth, p.UserAgent, p.IdleTimeout, kaInterval, kaCmd, gate, &g.stats, name, &c.wg) } return g diff --git a/nntp_test.go b/nntp_test.go index 52e5568..2a96bcc 100644 --- a/nntp_test.go +++ b/nntp_test.go @@ -1345,3 +1345,54 @@ func TestDynamicWeights_LargeN(t *testing.T) { t.Errorf("provider 0 percentage = %.2f%%, want ~83.33%%", pct0) } } + +func TestUserAgent_StoredOnConnection(t *testing.T) { + srv, cli := net.Pipe() + defer func() { _ = srv.Close() }() + defer func() { _ = cli.Close() }() + + go func() { + _, _ = srv.Write([]byte("200 server ready\r\n")) + buf := make([]byte, 256) + for { + if _, err := srv.Read(buf); err != nil { + return + } + } + }() + + reqCh := make(chan *Request) + nc, err := newNNTPConnectionFromConn(context.Background(), cli, 1, reqCh, nil, Auth{}, "TestAgent/1.0", nil, nil) + if err != nil { + t.Fatalf("newNNTPConnectionFromConn() error = %v", err) + } + + if nc.userAgent != "TestAgent/1.0" { + t.Errorf("userAgent = %q, want %q", nc.userAgent, "TestAgent/1.0") + } +} + +func TestUserAgent_EmptyIsAccepted(t *testing.T) { + srv, cli := net.Pipe() + defer func() { _ = srv.Close() }() + defer func() { _ = cli.Close() }() + + go func() { + _, _ = srv.Write([]byte("200 server ready\r\n")) + buf := make([]byte, 256) + for { + if _, err := srv.Read(buf); err != nil { + return + } + } + }() + + reqCh := make(chan *Request) + nc, err := newNNTPConnectionFromConn(context.Background(), cli, 1, reqCh, nil, Auth{}, "", nil, nil) + if err != nil { + t.Fatalf("newNNTPConnectionFromConn() error = %v", err) + } + if nc.userAgent != "" { + t.Errorf("userAgent = %q, want empty", nc.userAgent) + } +} diff --git a/post_test.go b/post_test.go index bf2664e..feec605 100644 --- a/post_test.go +++ b/post_test.go @@ -322,7 +322,7 @@ func TestNNTPConnection_PostTwoPhase(t *testing.T) { }) reqCh := make(chan *Request, 1) - nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, nil, nil) + nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, "", nil, nil) if err != nil { t.Fatalf("connection error = %v", err) } @@ -521,7 +521,7 @@ func TestNNTPConnection_PostRejected(t *testing.T) { }) reqCh := make(chan *Request, 1) - nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, nil, nil) + nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, "", nil, nil) if err != nil { t.Fatalf("connection error = %v", err) }