Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func TestNNTPConnection_Greeting(t *testing.T) {
})

reqCh := make(chan *Request)
nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, nil, nil)
nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, "", nil, nil)
if err != nil {
t.Fatalf("newNNTPConnectionFromConn() error = %v", err)
}
Expand All @@ -39,7 +39,7 @@ func TestNNTPConnection_GreetingReject(t *testing.T) {
})

reqCh := make(chan *Request)
_, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, nil, nil)
_, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, "", nil, nil)
if err == nil {
t.Fatal("expected error for 502 greeting")
}
Expand Down Expand Up @@ -75,7 +75,7 @@ func TestNNTPConnection_Auth(t *testing.T) {
nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{
Username: "testuser",
Password: "testpass",
}, nil, nil)
}, "", nil, nil)
if err != nil {
t.Fatalf("auth error = %v", err)
}
Expand All @@ -100,7 +100,7 @@ func TestNNTPConnection_AuthReject(t *testing.T) {
_, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{
Username: "testuser",
Password: "wrongpass",
}, nil, nil)
}, "", nil, nil)
if err == nil {
t.Fatal("expected auth rejection error")
}
Expand All @@ -120,7 +120,7 @@ func TestNNTPConnection_RunSingleRequest(t *testing.T) {
})

reqCh := make(chan *Request, 1)
nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, nil, nil)
nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, "", nil, nil)
if err != nil {
t.Fatalf("connection error = %v", err)
}
Expand Down Expand Up @@ -162,7 +162,7 @@ func TestNNTPConnection_RunBodyRequest(t *testing.T) {
})

reqCh := make(chan *Request, 1)
nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, nil, nil)
nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, "", nil, nil)
if err != nil {
t.Fatalf("connection error = %v", err)
}
Expand Down Expand Up @@ -224,7 +224,7 @@ func TestNNTPConnection_RunPipelined(t *testing.T) {
})

reqCh := make(chan *Request, 3)
nc, err := newNNTPConnectionFromConn(context.Background(), conn, 3, reqCh, nil, Auth{}, nil, nil)
nc, err := newNNTPConnectionFromConn(context.Background(), conn, 3, reqCh, nil, Auth{}, "", nil, nil)
if err != nil {
t.Fatalf("connection error = %v", err)
}
Expand Down Expand Up @@ -269,7 +269,7 @@ func TestNNTPConnection_CancelledRequest(t *testing.T) {
})

reqCh := make(chan *Request, 1)
nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, nil, nil)
nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, "", nil, nil)
if err != nil {
t.Fatalf("connection error = %v", err)
}
Expand Down Expand Up @@ -311,7 +311,7 @@ func TestNNTPConnection_IdleTimeout(t *testing.T) {
})

reqCh := make(chan *Request)
nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, nil, nil)
nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, "", nil, nil)
if err != nil {
t.Fatalf("connection error = %v", err)
}
Expand Down Expand Up @@ -871,7 +871,7 @@ func TestReadOneResponse(t *testing.T) {
})

reqCh := make(chan *Request)
nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, nil, nil)
nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, "", nil, nil)
if err != nil {
t.Fatalf("connection error = %v", err)
}
Expand Down Expand Up @@ -1266,7 +1266,7 @@ func TestKeepalive_KeepsConnectionAlive(t *testing.T) {
})

reqCh := make(chan *Request, 1)
nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, nil, nil)
nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, "", nil, nil)
if err != nil {
t.Fatalf("newNNTPConnectionFromConn() error = %v", err)
}
Expand Down Expand Up @@ -1325,7 +1325,7 @@ func TestKeepalive_DeadConnection(t *testing.T) {
})

reqCh := make(chan *Request)
nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, nil, nil)
nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, "", nil, nil)
if err != nil {
t.Fatalf("newNNTPConnectionFromConn() error = %v", err)
}
Expand Down
37 changes: 21 additions & 16 deletions nntp.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ type NNTPConnection struct {
keepaliveInterval time.Duration // 0 = no keepalive
keepaliveCommand string // NNTP command for keepalive probe (e.g. "DATE")
providerName string // set by runConnSlot; used for error context
userAgent string

stats *providerStats // nil for standalone connections

Expand Down Expand Up @@ -165,7 +166,7 @@ func newNetConn(ctx context.Context, addr string, tlsConfig *tls.Config, keepAli
return conn, nil
}

func newNNTPConnectionFromConn(ctx context.Context, conn net.Conn, inflightLimit int, reqCh <-chan *Request, prioCh <-chan *Request, auth Auth, sharedBuf *readBuffer, stats *providerStats) (*NNTPConnection, error) {
func newNNTPConnectionFromConn(ctx context.Context, conn net.Conn, inflightLimit int, reqCh <-chan *Request, prioCh <-chan *Request, auth Auth, userAgent string, sharedBuf *readBuffer, stats *providerStats) (*NNTPConnection, error) {
if ctx == nil {
ctx = context.Background()
}
Expand All @@ -180,16 +181,17 @@ func newNNTPConnectionFromConn(ctx context.Context, conn net.Conn, inflightLimit
}

c := &NNTPConnection{
conn: conn,
ctx: cctx,
cancel: cancel,
reqCh: reqCh,
prioCh: prioCh,
pending: make(chan *Request, inflightLimit),
inflightSem: make(chan struct{}, inflightLimit),
rb: rb,
stats: stats,
done: make(chan struct{}),
conn: conn,
ctx: cctx,
cancel: cancel,
reqCh: reqCh,
prioCh: prioCh,
pending: make(chan *Request, inflightLimit),
inflightSem: make(chan struct{}, inflightLimit),
rb: rb,
stats: stats,
done: make(chan struct{}),
userAgent: userAgent,
}

// Server greeting is sent immediately upon connect.
Expand All @@ -216,13 +218,13 @@ func newNNTPConnectionFromConn(ctx context.Context, conn net.Conn, inflightLimit
return c, nil
}

func NewNNTPConnection(ctx context.Context, addr string, tlsConfig *tls.Config, inflightLimit int, reqCh <-chan *Request, auth Auth) (*NNTPConnection, error) {
func NewNNTPConnection(ctx context.Context, addr string, tlsConfig *tls.Config, inflightLimit int, reqCh <-chan *Request, auth Auth, userAgent string) (*NNTPConnection, error) {
conn, err := newNetConn(ctx, addr, tlsConfig, 0)
if err != nil {
return nil, err
}

c, err := newNNTPConnectionFromConn(ctx, conn, inflightLimit, reqCh, nil, auth, nil, nil)
c, err := newNNTPConnectionFromConn(ctx, conn, inflightLimit, reqCh, nil, auth, userAgent, nil, nil)
if err != nil {
_ = conn.Close()
return nil, err
Expand Down Expand Up @@ -473,7 +475,7 @@ func (g *connGate) snapshot() (maxSlots, running int) {

// runConnSlot is the slot goroutine that manages the lifecycle of a single
// connection: IDLE → CONNECTING → ACTIVE → (death/idle) → IDLE.
func runConnSlot(ctx context.Context, reqCh <-chan *Request, prioCh <-chan *Request, hotReqCh <-chan *Request, hotPrioCh <-chan *Request, factory ConnFactory, inflight int, auth Auth, idleTimeout time.Duration, keepaliveInterval time.Duration, keepaliveCommand string, gate *connGate, stats *providerStats, providerName string, wg *sync.WaitGroup) {
func runConnSlot(ctx context.Context, reqCh <-chan *Request, prioCh <-chan *Request, hotReqCh <-chan *Request, hotPrioCh <-chan *Request, factory ConnFactory, inflight int, auth Auth, userAgent string, idleTimeout time.Duration, keepaliveInterval time.Duration, keepaliveCommand string, gate *connGate, stats *providerStats, providerName string, wg *sync.WaitGroup) {
defer wg.Done()

// Shared read buffer persists across reconnections to avoid re-growing.
Expand Down Expand Up @@ -533,7 +535,7 @@ func runConnSlot(ctx context.Context, reqCh <-chan *Request, prioCh <-chan *Requ
continue
}

nc, err := newNNTPConnectionFromConn(ctx, conn, inflight, reqCh, prioCh, auth, &sharedBuf, stats)
nc, err := newNNTPConnectionFromConn(ctx, conn, inflight, reqCh, prioCh, auth, userAgent, &sharedBuf, stats)
if err != nil {
_ = conn.Close()
failRequest(firstReq.RespCh, fmt.Errorf("%s: %w", providerName, err))
Expand Down Expand Up @@ -1158,6 +1160,9 @@ type Provider struct {
// "CAPABILITIES" (response 101) for providers that do not support DATE.
// Ignored when KeepaliveInterval is 0.
KeepaliveCommand string

// UserAgent identifies this client to the NNTP server. Empty string disables it.
UserAgent string
}

type providerGroup struct {
Expand Down Expand Up @@ -1345,7 +1350,7 @@ func (c *Client) startProviderGroup(p Provider, index int) *providerGroup {

for range p.Connections {
c.wg.Add(1)
go runConnSlot(gctx, g.reqCh, g.prioCh, g.hotReqCh, g.hotPrioCh, factory, inflight, p.Auth, p.IdleTimeout, kaInterval, kaCmd, gate, &g.stats, name, &c.wg)
go runConnSlot(gctx, g.reqCh, g.prioCh, g.hotReqCh, g.hotPrioCh, factory, inflight, p.Auth, p.UserAgent, p.IdleTimeout, kaInterval, kaCmd, gate, &g.stats, name, &c.wg)
}

return g
Expand Down
51 changes: 51 additions & 0 deletions nntp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1345,3 +1345,54 @@ func TestDynamicWeights_LargeN(t *testing.T) {
t.Errorf("provider 0 percentage = %.2f%%, want ~83.33%%", pct0)
}
}

func TestUserAgent_StoredOnConnection(t *testing.T) {
srv, cli := net.Pipe()
defer func() { _ = srv.Close() }()
defer func() { _ = cli.Close() }()

go func() {
_, _ = srv.Write([]byte("200 server ready\r\n"))
buf := make([]byte, 256)
for {
if _, err := srv.Read(buf); err != nil {
return
}
}
}()

reqCh := make(chan *Request)
nc, err := newNNTPConnectionFromConn(context.Background(), cli, 1, reqCh, nil, Auth{}, "TestAgent/1.0", nil, nil)
if err != nil {
t.Fatalf("newNNTPConnectionFromConn() error = %v", err)
}

if nc.userAgent != "TestAgent/1.0" {
t.Errorf("userAgent = %q, want %q", nc.userAgent, "TestAgent/1.0")
}
}

func TestUserAgent_EmptyIsAccepted(t *testing.T) {
srv, cli := net.Pipe()
defer func() { _ = srv.Close() }()
defer func() { _ = cli.Close() }()

go func() {
_, _ = srv.Write([]byte("200 server ready\r\n"))
buf := make([]byte, 256)
for {
if _, err := srv.Read(buf); err != nil {
return
}
}
}()

reqCh := make(chan *Request)
nc, err := newNNTPConnectionFromConn(context.Background(), cli, 1, reqCh, nil, Auth{}, "", nil, nil)
if err != nil {
t.Fatalf("newNNTPConnectionFromConn() error = %v", err)
}
if nc.userAgent != "" {
t.Errorf("userAgent = %q, want empty", nc.userAgent)
}
}
4 changes: 2 additions & 2 deletions post_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ func TestNNTPConnection_PostTwoPhase(t *testing.T) {
})

reqCh := make(chan *Request, 1)
nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, nil, nil)
nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, "", nil, nil)
if err != nil {
t.Fatalf("connection error = %v", err)
}
Expand Down Expand Up @@ -521,7 +521,7 @@ func TestNNTPConnection_PostRejected(t *testing.T) {
})

reqCh := make(chan *Request, 1)
nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, nil, nil)
nc, err := newNNTPConnectionFromConn(context.Background(), conn, 1, reqCh, nil, Auth{}, "", nil, nil)
if err != nil {
t.Fatalf("connection error = %v", err)
}
Expand Down
Loading