diff --git a/errors.go b/errors.go index d2ed6bb..8e6f03c 100644 --- a/errors.go +++ b/errors.go @@ -43,6 +43,7 @@ var ( ErrServiceUnavailable = &Error{Code: 502, Message: "service unavailable"} ErrCRCMismatch = errors.New("nntp: yEnc CRC mismatch") ErrProtocolDesync = errors.New("nntp: protocol desync: expected status line, got binary data") + ErrQuotaExceeded = errors.New("nntp: download quota exceeded") ) // toError maps an NNTP status code to a sentinel error, or returns nil for success codes. diff --git a/metrics.go b/metrics.go index 7c1e53f..3396963 100644 --- a/metrics.go +++ b/metrics.go @@ -19,6 +19,11 @@ type providerStats struct { Missing atomic.Int64 // 430/423 responses Errors atomic.Int64 // network errors, bad status codes Ping PingResult // result of initial DATE ping + + // Quota tracking. quotaBytes is set once at group init (0 = unlimited). + quotaBytes int64 + quotaUsed atomic.Int64 // bytes consumed in the current quota period + quotaExceeded atomic.Bool // cached flag: set when quotaUsed >= quotaBytes; cleared on period reset } // ProviderStats is a public snapshot of one provider's metrics. @@ -31,6 +36,12 @@ type ProviderStats struct { ActiveConnections int // currently running connections MaxConnections int // configured connection slots Ping PingResult + + // Quota fields. QuotaBytes is 0 when no quota is configured. + QuotaBytes int64 // configured limit per period (0 = unlimited) + QuotaUsed int64 // bytes consumed in the current period + QuotaResetAt time.Time // when the quota period resets; zero if no period + QuotaExceeded bool // true when QuotaUsed >= QuotaBytes > 0 } // ClientStats is an aggregate snapshot of all provider metrics. diff --git a/nntp.go b/nntp.go index c1654ab..8e0b6ec 100644 --- a/nntp.go +++ b/nntp.go @@ -1050,7 +1050,13 @@ func (c *NNTPConnection) readerLoop() { } if c.stats != nil { - c.stats.BytesConsumed.Add(int64(decoder.BytesConsumed)) + n := int64(decoder.BytesConsumed) + c.stats.BytesConsumed.Add(n) + if c.stats.quotaBytes > 0 { + if c.stats.quotaUsed.Add(n) >= c.stats.quotaBytes { + c.stats.quotaExceeded.Store(true) + } + } if resp.Err != nil { c.stats.Errors.Add(1) } else if decoder.StatusCode == 430 || decoder.StatusCode == 423 { @@ -1163,6 +1169,15 @@ type Provider struct { // UserAgent identifies this client to the NNTP server. Empty string disables it. UserAgent string + + // QuotaBytes is the maximum number of bytes that may be downloaded from this + // provider per QuotaPeriod. 0 means unlimited. + QuotaBytes int64 + + // QuotaPeriod is the rolling window after which the quota counter resets. + // 0 means the quota never resets (lifetime cap). + // Typical value: 30 * 24 * time.Hour (≈ monthly) + QuotaPeriod time.Duration } type providerGroup struct { @@ -1178,6 +1193,37 @@ type providerGroup struct { stats providerStats cancel context.CancelFunc // cancels this group's slot goroutines p Provider // original config; used for auto-reconnect + + // Quota period configuration. quotaBytes/quotaUsed/quotaExceeded live in + // stats so that NNTPConnection can update them via its *providerStats pointer. + quotaPeriod time.Duration // 0 = no auto-reset + quotaResetAt atomic.Int64 // Unix nanoseconds of next reset; 0 = never +} + +// isQuotaExceeded reports whether this provider has consumed its download quota +// for the current period. +// +// Fast path (quota not exceeded): single atomic.Bool load (~1 ns). +// Slow path (flag set, period elapsed): resets counters and returns false. +// The time.Now() call is deferred until the cached flag is actually set. +func (g *providerGroup) isQuotaExceeded() bool { + if g.stats.quotaBytes <= 0 { + return false // unlimited + } + if !g.stats.quotaExceeded.Load() { + return false // fast path: quota not yet hit + } + // Flag is set. If a reset period is configured, check whether it has elapsed. + if g.quotaPeriod > 0 { + resetAt := g.quotaResetAt.Load() + if resetAt > 0 && time.Now().UnixNano() >= resetAt { + g.stats.quotaUsed.Store(0) + g.stats.quotaExceeded.Store(false) + g.quotaResetAt.Store(time.Now().Add(g.quotaPeriod).UnixNano()) + return false + } + } + return true } type Client struct { @@ -1314,17 +1360,22 @@ func (c *Client) startProviderGroup(p Provider, index int) *providerGroup { gctx, gcancel := context.WithCancel(c.ctx) g := &providerGroup{ - name: name, - host: p.Host, - maxConns: p.Connections, - ctx: gctx, - reqCh: make(chan *Request, p.Connections), - prioCh: make(chan *Request, p.Connections), - hotReqCh: make(chan *Request), - hotPrioCh: make(chan *Request), - gate: gate, - cancel: gcancel, - p: p, + name: name, + host: p.Host, + maxConns: p.Connections, + ctx: gctx, + reqCh: make(chan *Request, p.Connections), + prioCh: make(chan *Request, p.Connections), + hotReqCh: make(chan *Request), + hotPrioCh: make(chan *Request), + gate: gate, + cancel: gcancel, + p: p, + quotaPeriod: p.QuotaPeriod, + } + g.stats.quotaBytes = p.QuotaBytes + if p.QuotaBytes > 0 && p.QuotaPeriod > 0 { + g.quotaResetAt.Store(time.Now().Add(p.QuotaPeriod).UnixNano()) } // Ping with a short timeout so we don't block forever. @@ -1562,26 +1613,36 @@ func (c *Client) doSendWithRetry(ctx context.Context, payload []byte, bodyWriter var start int switch c.dispatch { case DispatchFIFO: - // Priority order: first provider with available capacity, - // falling back to provider 0 if all are saturated. + // Priority order: first provider with available capacity and within quota, + // falling back to provider 0 if all are saturated or exceeded. for i, g := range mains { - if g.gate.available.Load() > 0 { + if g.gate.available.Load() > 0 && !g.isQuotaExceeded() { start = i break } } default: // DispatchRoundRobin // Dynamic weighted round-robin: each provider's weight equals - // its available capacity (allowed - held). + // its available capacity (allowed - held). Quota-exceeded providers + // get weight 0 so they are never selected during normal dispatch. cumWeights := make([]int, n) totalW := 0 for i, g := range mains { - avail := max(1, int(g.gate.available.Load())) + avail := 0 + if !g.isQuotaExceeded() { + avail = max(1, int(g.gate.available.Load())) + } totalW += avail cumWeights[i] = totalW } - slot := int(c.nextIdx.Add(1) % uint64(totalW)) - start = sort.SearchInts(cumWeights, slot+1) + if totalW == 0 { + // All providers are quota-exceeded; start at 0 and let the main + // loop below return ErrQuotaExceeded for each. + start = 0 + } else { + slot := int(c.nextIdx.Add(1) % uint64(totalW)) + start = sort.SearchInts(cumWeights, slot+1) + } } for attempt := range n { @@ -1590,6 +1651,10 @@ func (c *Client) doSendWithRetry(ctx context.Context, payload []byte, bodyWriter if hostSkipped(g.host, &skipHosts, skipCount) { continue } + if g.isQuotaExceeded() { + lastErr = fmt.Errorf("%s: %w", g.name, ErrQuotaExceeded) + continue + } resp, ok, cancelled := tryGroup(g) if cancelled { err := ctx.Err() @@ -1639,6 +1704,10 @@ func (c *Client) doSendWithRetry(ctx context.Context, payload []byte, bodyWriter if hostSkipped(g.host, &skipHosts, skipCount) { continue } + if g.isQuotaExceeded() { + lastErr = fmt.Errorf("%s: %w", g.name, ErrQuotaExceeded) + continue + } resp, ok, cancelled := tryGroup(g) if cancelled { err := ctx.Err() @@ -1695,6 +1764,7 @@ func (c *Client) Stats() ClientStats { consumed := g.stats.BytesConsumed.Load() totalBytes += consumed maxSlots, running := g.gate.snapshot() + quotaUsed := g.stats.quotaUsed.Load() ps := ProviderStats{ Name: g.name, BytesConsumed: consumed, @@ -1703,6 +1773,15 @@ func (c *Client) Stats() ClientStats { ActiveConnections: running, MaxConnections: maxSlots, Ping: g.stats.Ping, + QuotaBytes: g.stats.quotaBytes, + QuotaUsed: quotaUsed, + QuotaExceeded: g.stats.quotaBytes > 0 && quotaUsed >= g.stats.quotaBytes, + } + if g.stats.quotaBytes > 0 && g.quotaPeriod > 0 { + resetAt := g.quotaResetAt.Load() + if resetAt > 0 { + ps.QuotaResetAt = time.Unix(0, resetAt) + } } if secs > 0 { ps.AvgSpeed = float64(consumed) / secs diff --git a/nntp_test.go b/nntp_test.go index 2a96bcc..429b3f1 100644 --- a/nntp_test.go +++ b/nntp_test.go @@ -1396,3 +1396,243 @@ func TestUserAgent_EmptyIsAccepted(t *testing.T) { t.Errorf("userAgent = %q, want empty", nc.userAgent) } } + +// --- Quota tests --- + +func TestProviderGroup_isQuotaExceeded_Unlimited(t *testing.T) { + g := &providerGroup{} + // quotaBytes == 0 → always false regardless of usage + g.stats.quotaUsed.Store(1_000_000) + g.stats.quotaExceeded.Store(true) + if g.isQuotaExceeded() { + t.Error("isQuotaExceeded() = true with no quota configured, want false") + } +} + +func TestProviderGroup_isQuotaExceeded_NotYetHit(t *testing.T) { + g := &providerGroup{} + g.stats.quotaBytes = 100 + // Flag not set → fast path returns false without checking time. + if g.isQuotaExceeded() { + t.Error("isQuotaExceeded() = true before quota hit, want false") + } +} + +func TestProviderGroup_isQuotaExceeded_FlagSet_NoReset(t *testing.T) { + g := &providerGroup{} // quotaPeriod == 0 → no auto-reset + g.stats.quotaBytes = 100 + g.stats.quotaUsed.Store(100) + g.stats.quotaExceeded.Store(true) + + if !g.isQuotaExceeded() { + t.Error("isQuotaExceeded() = false after quota hit with no period, want true") + } +} + +func TestProviderGroup_isQuotaExceeded_PeriodElapsed_Resets(t *testing.T) { + g := &providerGroup{quotaPeriod: time.Hour} + g.stats.quotaBytes = 50 + g.stats.quotaUsed.Store(50) + g.stats.quotaExceeded.Store(true) + // Set reset deadline in the past so the period is considered elapsed. + g.quotaResetAt.Store(time.Now().Add(-time.Second).UnixNano()) + + if g.isQuotaExceeded() { + t.Error("isQuotaExceeded() = true after period elapsed, want false (reset should have fired)") + } + if g.stats.quotaUsed.Load() != 0 { + t.Errorf("quotaUsed after reset = %d, want 0", g.stats.quotaUsed.Load()) + } + if g.stats.quotaExceeded.Load() { + t.Error("quotaExceeded flag should be cleared after reset") + } + if g.quotaResetAt.Load() <= time.Now().UnixNano() { + t.Error("quotaResetAt should be scheduled in the future after reset") + } +} + +func TestProviderGroup_isQuotaExceeded_PeriodNotYetElapsed(t *testing.T) { + g := &providerGroup{quotaPeriod: time.Hour} + g.stats.quotaBytes = 50 + g.stats.quotaUsed.Store(50) + g.stats.quotaExceeded.Store(true) + // Reset deadline is in the future — should stay exceeded. + g.quotaResetAt.Store(time.Now().Add(time.Hour).UnixNano()) + + if !g.isQuotaExceeded() { + t.Error("isQuotaExceeded() = false before period elapsed, want true") + } +} + +func TestProviderStats_QuotaAccounting(t *testing.T) { + var stats providerStats + stats.quotaBytes = 100 + + // Add 60 bytes — not yet exceeded. + stats.quotaUsed.Add(60) + if stats.quotaExceeded.Load() { + t.Error("quotaExceeded should be false before threshold") + } + + // Simulate readLoop: add 40 more bytes → crosses threshold. + if stats.quotaUsed.Add(40) >= stats.quotaBytes { + stats.quotaExceeded.Store(true) + } + if !stats.quotaExceeded.Load() { + t.Error("quotaExceeded should be true after crossing threshold") + } + if stats.quotaUsed.Load() != 100 { + t.Errorf("quotaUsed = %d, want 100", stats.quotaUsed.Load()) + } +} + +func TestClient_Stats_QuotaFields(t *testing.T) { + factory := func(ctx context.Context) (net.Conn, error) { + client, server := net.Pipe() + go func() { + defer func() { _ = server.Close() }() + _, _ = server.Write([]byte("200 server ready\r\n")) + buf := make([]byte, 4096) + for { + n, err := server.Read(buf) + if err != nil { + return + } + cmd := string(buf[:n]) + if len(cmd) >= 4 && cmd[:4] == "DATE" { + _, _ = server.Write([]byte("111 20240315120000\r\n")) + } + } + }() + return client, nil + } + + c, err := NewClient(context.Background(), []Provider{ + { + Factory: factory, + Connections: 1, + QuotaBytes: 1_000_000, + QuotaPeriod: 30 * 24 * time.Hour, + SkipPing: true, + }, + }) + if err != nil { + t.Fatalf("NewClient() error = %v", err) + } + defer func() { _ = c.Close() }() + + stats := c.Stats() + if len(stats.Providers) == 0 { + t.Fatal("no providers in stats") + } + ps := stats.Providers[0] + if ps.QuotaBytes != 1_000_000 { + t.Errorf("QuotaBytes = %d, want 1_000_000", ps.QuotaBytes) + } + if ps.QuotaUsed != 0 { + t.Errorf("QuotaUsed = %d, want 0", ps.QuotaUsed) + } + if ps.QuotaExceeded { + t.Error("QuotaExceeded should be false initially") + } + if ps.QuotaResetAt.IsZero() { + t.Error("QuotaResetAt should be set when QuotaPeriod > 0") + } + if ps.QuotaResetAt.Before(time.Now()) { + t.Error("QuotaResetAt should be in the future") + } +} + +func TestClient_QuotaExceeded_FallsThrough(t *testing.T) { + // Two providers: first has quota exceeded, second responds normally. + // The request must be served by the second provider. + makeFactory := func(statusLine string) ConnFactory { + return func(ctx context.Context) (net.Conn, error) { + client, server := net.Pipe() + go func() { + defer func() { _ = server.Close() }() + _, _ = server.Write([]byte("200 server ready\r\n")) + buf := make([]byte, 4096) + for { + n, err := server.Read(buf) + if err != nil { + return + } + cmd := string(buf[:n]) + if len(cmd) >= 4 && cmd[:4] == "DATE" { + _, _ = server.Write([]byte("111 20240315120000\r\n")) + } else { + _, _ = server.Write([]byte(statusLine + "\r\n")) + } + } + }() + return client, nil + } + } + + c, err := NewClient(context.Background(), []Provider{ + {Factory: makeFactory("223 "), Connections: 1, SkipPing: true, QuotaBytes: 1}, + {Factory: makeFactory("223 "), Connections: 1, SkipPing: true}, + }) + if err != nil { + t.Fatalf("NewClient() error = %v", err) + } + defer func() { _ = c.Close() }() + + // Manually mark the first provider as quota-exceeded. + mains := *c.mainGroups.Load() + mains[0].stats.quotaExceeded.Store(true) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + resp := <-c.Send(ctx, []byte("STAT \r\n"), nil) + if resp.Err != nil { + t.Fatalf("Send() error = %v, want nil (second provider should serve)", resp.Err) + } + if resp.StatusCode != 223 { + t.Errorf("StatusCode = %d, want 223", resp.StatusCode) + } +} + +func TestClient_AllQuotaExceeded_ReturnsError(t *testing.T) { + makeFactory := func() ConnFactory { + return func(ctx context.Context) (net.Conn, error) { + client, server := net.Pipe() + go func() { + defer func() { _ = server.Close() }() + _, _ = server.Write([]byte("200 server ready\r\n")) + buf := make([]byte, 4096) + for { + if _, err := server.Read(buf); err != nil { + return + } + } + }() + return client, nil + } + } + + c, err := NewClient(context.Background(), []Provider{ + {Factory: makeFactory(), Connections: 1, SkipPing: true, QuotaBytes: 1}, + }) + if err != nil { + t.Fatalf("NewClient() error = %v", err) + } + defer func() { _ = c.Close() }() + + // Mark the only provider as quota-exceeded. + mains := *c.mainGroups.Load() + mains[0].stats.quotaExceeded.Store(true) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + resp := <-c.Send(ctx, []byte("STAT \r\n"), nil) + if resp.Err == nil { + t.Fatal("Send() should return an error when all providers are quota-exceeded") + } + if !errors.Is(resp.Err, ErrQuotaExceeded) { + t.Errorf("error = %v, want errors.Is(ErrQuotaExceeded)", resp.Err) + } +}