diff --git a/README.md b/README.md index 377f6f6..f79e58c 100644 --- a/README.md +++ b/README.md @@ -8,8 +8,8 @@ choice は多人数ビデオ会議を実現する WebRTC SFU サーバーです ### 主な機能 -- Simulcast 対応(low/mid/high の3レイヤー) -- レイヤー選択によるクライアントへの適応的配信 +- Simulcast 対応(low/mid/high の 3 レイヤー) +- 帯域幅ベースの自動レイヤー切り替え - キーフレームベースのレイヤー切り替え - JSON-RPC 2.0 シグナリング @@ -82,6 +82,8 @@ pkg/sfu/ ├── layer.go # Layer - 品質レイヤー(low/mid/high) ├── receiver.go # LayerReceiver - RTP パケットの受信 ├── downtrack.go # DownTrack - レイヤー選択と RTP 送信 +├── bandwidth.go # BandwidthController - 帯域幅ベースのレイヤー自動選択 +├── twcc.go # BandwidthEstimator - 帯域幅推定 ├── rtp.go # RTP ユーティリティ(キーフレーム検出など) ├── signaling.go # JSON-RPC シグナリングハンドラー └── transport.go # WebSocket 接続ラッパー(スレッドセーフ) @@ -101,8 +103,10 @@ pkg/sfu/ | **TrackReceiver** | 1つのトラックの複数レイヤー(low/mid/high)を管理 | | **Layer** | 品質レイヤーを表す。LayerReceiver を保持 | | **LayerReceiver** | リモートトラックから RTP パケットを受信 | -| **DownTrack** | サブスクライバーに RTP パケットを送信。レイヤー選択を担当 | -| **LayerSelector** | 現在のレイヤーと目標レイヤーを管理。キーフレームでレイヤー切り替え | +| **DownTrack** | サブスクライバーに RTP パケットを送信。レイヤー選択を担当 | +| **LayerSelector** | 現在のレイヤーと目標レイヤーを管理。キーフレームでレイヤー切り替え | +| **BandwidthController** | 帯域幅に基づいて各トラックのレイヤーを自動選択 | +| **BandwidthEstimator** | 送信バイト数とパケットロス率から帯域幅を推定 | ## Simulcast とレイヤー選択 @@ -123,6 +127,112 @@ choice は Simulcast に対応しており、パブリッシャーから複数 | mid | 2 | 中品質 | | low | 1 | 低品質 | +## 帯域幅ベースの自動レイヤー切り替え + +Subscriber ごとに BandwidthController が動作し、帯域幅に応じて自動的にレイヤーを切り替えます。 + +### 動作フロー + +```text +DownTrack ─── 送信バイト数を記録 + │ + ▼ +Subscriber.statsLoop (1秒ごと) + │ + ▼ +BandwidthController.UpdateBitrate() ─── 帯域幅を更新 + │ + ▼ +BandwidthController.recalculateAllocations() (500msごと) + │ + ▼ +onLayerChange コールバック + │ + ▼ +Subscriber.SetLayer() → DownTrack.SetTargetLayer() +``` + +### レイヤー選択の閾値 + +| 帯域幅予算 | 選択レイヤー | +| ------------ | ------------ | +| ≥ 2.5 Mbps | high | +| ≥ 500 Kbps | mid | +| < 500 Kbps | low | + +### 帯域幅調整(パケットロス率に基づく) + +| ロス率 | 調整 | +| -------- | -------------- | +| > 10% | 50% に削減 | +| > 2% | 85% に削減 | +| < 1% | 5% 増加 | + +## 高度な輻輳制御 (GCC アルゴリズム) + +choice は Google Congestion Control (GCC) アルゴリズムを実装した高度な輻輳制御機能を備えています。 + +### アーキテクチャ + +```text +クライアント (RTCP TWCC フィードバック) + │ + ▼ +DownTrack.readRTCP() + │ + ├─ ReceiverReport → パケットロス率 + │ + └─ TransportLayerCC → TWCCReceiver.ProcessTWCCFeedback() + │ + ▼ + DelayBasedDetector (GCC) + │ + ├─ Trendline Filter (遅延勾配検出) + ├─ Adaptive Threshold (適応的閾値) + └─ Hysteresis (ヒステリシス制御) + │ + ▼ + 帯域幅推定値 (delay-based) + │ + ▼ + BandwidthEstimator + │ + ├─ Loss-based estimate (ロスベース) + └─ Delay-based estimate (遅延ベース) + │ + ▼ + min(loss-based, delay-based) + │ + ▼ + BandwidthController + │ + ▼ + LayerSelector → レイヤー自動選択 +``` + +### GCC アルゴリズムの概要 + +1. **Trendline Filter**: パケット間到着時間の変動を指数移動平均でフィルタリング +2. **Adaptive Threshold**: ノイズ分散に基づいて閾値を動的に調整(誤検出防止) +3. **Hysteresis**: 状態変化に複数サンプルを要求(急激な変化を抑制) + +### 輻輳状態 + +| 状態 | 条件 | アクション | +| --- | --- | --- | +| Overusing | 遅延勾配 > 閾値 | 帯域幅を 85% に削減 | +| Normal | -閾値 < 遅延勾配 < 閾値 | 現状維持 | +| Underusing | 遅延勾配 < -閾値 | 帯域幅を 5% 増加 | + +### コンポーネント + +| コンポーネント | 説明 | +| --- | --- | +| **TWCCReceiver** | TWCC フィードバックを受信し、遅延ベースの帯域幅を推定 | +| **DelayBasedDetector** | GCC の遅延検出アルゴリズムを実装 | +| **BandwidthEstimator** | ロスベースと遅延ベースの推定を統合 | +| **BandwidthController** | 帯域幅に基づいてレイヤーを自動選択 | + ## シグナリングプロトコル WebSocket 上で JSON-RPC 2.0 を使用。 diff --git a/pkg/sfu/bandwidth.go b/pkg/sfu/bandwidth.go index dae2402..17c1d75 100644 --- a/pkg/sfu/bandwidth.go +++ b/pkg/sfu/bandwidth.go @@ -8,11 +8,12 @@ import ( // LayerAllocation represents the target layer allocation for a subscriber type LayerAllocation struct { - TrackID string - TargetLayer string - CurrentLayer string - MaxLayer string - Paused bool + TrackID string + TargetLayer string + CurrentLayer string + MaxLayer string + Paused bool + ManualOverrideUntil time.Time // Time until manual override is active (auto control disabled) } // BandwidthController manages bandwidth allocation across subscribers @@ -107,7 +108,11 @@ func (bc *BandwidthController) SetMaxLayer(trackID, maxLayer string) { } } +// ManualOverrideDuration is how long manual layer selection disables auto control +const ManualOverrideDuration = 5 * time.Second + // RequestLayer requests a specific layer (manual override) +// This disables auto control for ManualOverrideDuration func (bc *BandwidthController) RequestLayer(trackID, layer string) { bc.mu.Lock() defer bc.mu.Unlock() @@ -118,6 +123,7 @@ func (bc *BandwidthController) RequestLayer(trackID, layer string) { layer = alloc.MaxLayer } alloc.TargetLayer = layer + alloc.ManualOverrideUntil = time.Now().Add(ManualOverrideDuration) } } @@ -132,11 +138,33 @@ func (bc *BandwidthController) GetTargetLayer(trackID string) string { return LayerHigh } -// UpdateBitrate updates the bandwidth estimate +// UpdateBitrate updates the bandwidth estimate (for backwards compatibility) func (bc *BandwidthController) UpdateBitrate(receivedBytes uint64, duration time.Duration, lossRate float64) { bc.estimator.Update(receivedBytes, duration, lossRate) } +// UpdateBitrateWithDelay updates the bandwidth estimate with delay-based estimation +func (bc *BandwidthController) UpdateBitrateWithDelay(receivedBytes uint64, duration time.Duration, lossRate float64, delayEstimate uint64) { + oldEstimate := bc.estimator.GetEstimate() + + // Update loss-based estimate + bc.estimator.Update(receivedBytes, duration, lossRate) + + // Set delay-based estimate from TWCCReceiver + if delayEstimate > 0 { + bc.estimator.SetDelayBasedEstimate(delayEstimate) + } + + newEstimate := bc.estimator.GetEstimate() + if newEstimate != oldEstimate { + slog.Debug("[BandwidthController] Bitrate updated", + slog.Uint64("from", oldEstimate), + slog.Uint64("to", newEstimate), + slog.Float64("lossRate", lossRate), + slog.Uint64("delayEstimate", delayEstimate)) + } +} + // onBitrateUpdate handles bitrate updates from the estimator func (bc *BandwidthController) onBitrateUpdate(bitrate uint64) { bc.mu.Lock() @@ -163,12 +191,19 @@ func (bc *BandwidthController) recalculateAllocations() { // Calculate per-track budget perTrackBudget := bc.availableBitrate / uint64(numTracks) + now := time.Now() + // Allocate layers based on budget for trackID, alloc := range bc.allocations { if alloc.Paused { continue } + // Skip auto control if manual override is active + if now.Before(alloc.ManualOverrideUntil) { + continue + } + newLayer := bc.selectLayerForBudget(perTrackBudget, alloc.MaxLayer) if newLayer != alloc.TargetLayer { @@ -236,14 +271,16 @@ func (bc *BandwidthController) Close() { // LayerSelector handles layer selection for a single subscriber type LayerSelector struct { - trackID string - currentLayer string - targetLayer string - pendingSwitch bool - lastSwitchTime time.Time - switchCooldown time.Duration - onSwitch func(layer string) - mu sync.RWMutex + trackID string + currentLayer string + targetLayer string + pendingSwitch bool + lastSwitchTime time.Time + switchCooldown time.Duration + onSwitch func(layer string) + lastKeyframeReqest time.Time + keyframeInterval time.Duration + mu sync.RWMutex } // NewLayerSelector creates a new layer selector @@ -252,10 +289,11 @@ func NewLayerSelector(trackID string, initialLayer string) *LayerSelector { initialLayer = LayerHigh } return &LayerSelector{ - trackID: trackID, - currentLayer: initialLayer, - targetLayer: initialLayer, - switchCooldown: 2 * time.Second, // Minimum time between switches + trackID: trackID, + currentLayer: initialLayer, + targetLayer: initialLayer, + switchCooldown: 2 * time.Second, // Minimum time between switches + keyframeInterval: 500 * time.Millisecond, // Retry keyframe request interval } } @@ -351,3 +389,23 @@ func (ls *LayerSelector) ForceSwitch(layer string) { slog.String("to", ls.currentLayer), ) } + +// NeedsKeyframeRequest returns true if a keyframe request should be sent. +// This is used for retrying keyframe requests when switching layers. +func (ls *LayerSelector) NeedsKeyframeRequest() bool { + ls.mu.RLock() + defer ls.mu.RUnlock() + + if !ls.pendingSwitch || ls.currentLayer == ls.targetLayer { + return false + } + + return time.Since(ls.lastKeyframeReqest) >= ls.keyframeInterval +} + +// MarkKeyframeRequested records that a keyframe request was sent. +func (ls *LayerSelector) MarkKeyframeRequested() { + ls.mu.Lock() + defer ls.mu.Unlock() + ls.lastKeyframeReqest = time.Now() +} diff --git a/pkg/sfu/downtrack.go b/pkg/sfu/downtrack.go index e79e63f..2569a08 100644 --- a/pkg/sfu/downtrack.go +++ b/pkg/sfu/downtrack.go @@ -6,6 +6,7 @@ import ( "sync/atomic" "time" + "github.com/pion/rtcp" "github.com/pion/rtp" "github.com/pion/webrtc/v4" ) @@ -21,6 +22,16 @@ type DownTrack struct { codec string closed atomic.Bool mu sync.RWMutex + + // Stats for bandwidth estimation + bytesSent uint64 + packetsSent uint64 + lastFractionLost uint8 + lastStatsTime time.Time + statsReportMu sync.Mutex + + // Advanced congestion control + twccReceiver *TWCCReceiver } // NewDownTrack creates a new downtrack. @@ -56,6 +67,8 @@ func NewDownTrack(subscriber *Subscriber, trackReceiver *TrackReceiver, codec we sequencer: newRTPSequencer(), selector: NewLayerSelector(trackReceiver.TrackID(), initialLayer), codec: codec.MimeType, + lastStatsTime: time.Now(), + twccReceiver: NewTWCCReceiver(DefaultTWCCConfig()), } // Set up layer switch callback @@ -69,18 +82,54 @@ func NewDownTrack(subscriber *Subscriber, trackReceiver *TrackReceiver, codec we return dt, nil } -// readRTCP reads RTCP packets from the sender. +// readRTCP reads RTCP packets from the sender and extracts loss information. func (d *DownTrack) readRTCP() { for { if d.closed.Load() { return } - if _, _, err := d.sender.ReadRTCP(); err != nil { + + packets, _, err := d.sender.ReadRTCP() + if err != nil { return } + + for _, pkt := range packets { + switch p := pkt.(type) { + case *rtcp.ReceiverReport: + d.handleReceiverReport(p) + case *rtcp.TransportLayerCC: + d.handleTWCCFeedback(p) + } + } } } +// handleReceiverReport processes RTCP Receiver Report to extract loss information. +func (d *DownTrack) handleReceiverReport(rr *rtcp.ReceiverReport) { + for _, report := range rr.Reports { + d.statsReportMu.Lock() + d.lastFractionLost = report.FractionLost + d.statsReportMu.Unlock() + + if report.FractionLost > 0 { + lossPercent := float64(report.FractionLost) / 256.0 * 100.0 + slog.Debug("[DownTrack] Receiver report", + slog.String("trackID", d.trackReceiver.TrackID()), + slog.Float64("lossPercent", lossPercent), + slog.Uint64("totalLost", uint64(report.TotalLost)), + ) + } + } +} + +// handleTWCCFeedback processes Transport Wide Congestion Control feedback. +func (d *DownTrack) handleTWCCFeedback(twcc *rtcp.TransportLayerCC) { + // Use TWCCReceiver for advanced congestion control + // This processes the feedback and updates delay-based bandwidth estimate + d.twccReceiver.ProcessTWCCFeedback(twcc) +} + // requestInitialKeyframe requests keyframes with retry. func (d *DownTrack) requestInitialKeyframe() { time.Sleep(100 * time.Millisecond) @@ -128,6 +177,9 @@ func (d *DownTrack) WriteRTP(packet *rtp.Packet, fromLayer string) error { currentLayer := d.tryLayerSwitch(packet, fromLayer) + // Retry keyframe request if needed + d.retryKeyframeRequestIfNeeded() + if !d.shouldForwardPacket(packet, fromLayer, currentLayer) { return nil } @@ -135,7 +187,40 @@ func (d *DownTrack) WriteRTP(packet *rtp.Packet, fromLayer string) error { ssrc := uint32(d.sender.GetParameters().Encodings[0].SSRC) rewritten := d.sequencer.Rewrite(packet, ssrc) - return d.track.WriteRTP(rewritten) + if err := d.track.WriteRTP(rewritten); err != nil { + return err + } + + // Track bytes and packets sent for bandwidth estimation + d.statsReportMu.Lock() + d.bytesSent += uint64(len(packet.Payload) + 12) // payload + RTP header + d.packetsSent++ + d.statsReportMu.Unlock() + + return nil +} + +// retryKeyframeRequestIfNeeded sends a keyframe request if needed during layer switch. +func (d *DownTrack) retryKeyframeRequestIfNeeded() { + if !d.selector.NeedsKeyframeRequest() { + return + } + + targetLayer := d.selector.GetTargetLayer() + d.selector.MarkKeyframeRequested() + + // Request keyframe asynchronously to avoid blocking + go func() { + layer, ok := d.trackReceiver.GetLayer(targetLayer) + if !ok { + return + } + slog.Debug("[DownTrack] Retrying keyframe request", + slog.String("layer", targetLayer), + slog.String("trackID", d.trackReceiver.TrackID()), + ) + layer.Receiver().SendPLI() + }() } // tryLayerSwitch attempts to switch layers if conditions are met. @@ -175,7 +260,15 @@ func (d *DownTrack) tryLayerSwitch(packet *rtp.Packet, fromLayer string) string // shouldForwardPacket determines if the packet should be forwarded. // Also handles fallback layer switching when current layer is unavailable. func (d *DownTrack) shouldForwardPacket(packet *rtp.Packet, fromLayer, currentLayer string) bool { + // If current layer is active, forward packets from current layer if d.isCurrentLayerActive(currentLayer) { + // During layer switch, also accept packets from current layer + // to avoid black screen while waiting for keyframe from target layer + if d.selector.NeedsSwitch() { + // Accept both current and target layer packets during transition + targetLayer := d.selector.GetTargetLayer() + return fromLayer == currentLayer || fromLayer == targetLayer + } return fromLayer == currentLayer } @@ -230,6 +323,37 @@ func (d *DownTrack) TrackReceiver() *TrackReceiver { return d.trackReceiver } +// TrackID returns the track ID. +func (d *DownTrack) TrackID() string { + return d.trackReceiver.TrackID() +} + +// GetStats returns stats since last call and resets counters. +// Returns bytes sent, duration, loss rate, and delay-based bitrate estimate. +func (d *DownTrack) GetStats() (bytesSent uint64, duration time.Duration, lossRate float64, delayEstimate uint64) { + d.statsReportMu.Lock() + defer d.statsReportMu.Unlock() + + now := time.Now() + duration = now.Sub(d.lastStatsTime) + bytesSent = d.bytesSent + + // Calculate loss rate from RTCP Receiver Report (FractionLost is 0-255) + // FractionLost represents the fraction of packets lost since last report + lossRate = float64(d.lastFractionLost) / 256.0 + + // Get delay-based estimate from TWCCReceiver + delayEstimate = d.twccReceiver.GetDelayEstimate() + + // Reset counters + d.bytesSent = 0 + d.packetsSent = 0 + d.lastFractionLost = 0 + d.lastStatsTime = now + + return bytesSent, duration, lossRate, delayEstimate +} + // Close closes the downtrack. func (d *DownTrack) Close() error { if d.closed.Swap(true) { @@ -239,5 +363,10 @@ func (d *DownTrack) Close() error { d.mu.Lock() defer d.mu.Unlock() + // Close TWCCReceiver + if d.twccReceiver != nil { + d.twccReceiver.Close() + } + return nil } diff --git a/pkg/sfu/subscriber.go b/pkg/sfu/subscriber.go index 945eea3..2431d1c 100644 --- a/pkg/sfu/subscriber.go +++ b/pkg/sfu/subscriber.go @@ -3,19 +3,21 @@ package sfu import ( "log/slog" "sync" + "time" "github.com/pion/webrtc/v4" ) // Subscriber handles the subscribing (downstream) connection to a client. type Subscriber struct { - peer *Peer - pc *webrtc.PeerConnection - downTracks map[string]*DownTrack - routers map[*Router]struct{} - dataChannel *webrtc.DataChannel - mu sync.RWMutex - closed bool + peer *Peer + pc *webrtc.PeerConnection + downTracks map[string]*DownTrack + routers map[*Router]struct{} + dataChannel *webrtc.DataChannel + bandwidthController *BandwidthController + mu sync.RWMutex + closed bool // Negotiation state negotiating bool @@ -29,13 +31,21 @@ func newSubscriber(peer *Peer) (*Subscriber, error) { return nil, err } + bc := NewBandwidthController(DefaultTWCCConfig()) + s := &Subscriber{ - peer: peer, - pc: pc, - downTracks: make(map[string]*DownTrack), - routers: make(map[*Router]struct{}), + peer: peer, + pc: pc, + downTracks: make(map[string]*DownTrack), + routers: make(map[*Router]struct{}), + bandwidthController: bc, } + // Set up bandwidth controller callback to automatically adjust layers + bc.OnLayerChange(func(trackID, layer string) { + s.SetLayer(trackID, layer) + }) + // Create data channel for sending messages to subscriber dc, err := pc.CreateDataChannel("data", nil) if err != nil { @@ -64,9 +74,76 @@ func newSubscriber(peer *Peer) (*Subscriber, error) { } }) + // Start bandwidth controller + bc.Start() + + // Start stats collection loop for bandwidth estimation + go s.statsLoop() + return s, nil } +// statsLoop periodically collects stats from downtracks and updates bandwidth controller. +func (s *Subscriber) statsLoop() { + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + s.collectStats() + default: + s.mu.RLock() + closed := s.closed + s.mu.RUnlock() + if closed { + return + } + time.Sleep(100 * time.Millisecond) + } + } +} + +// collectStats collects stats from all downtracks and updates bandwidth controller. +func (s *Subscriber) collectStats() { + s.mu.RLock() + if s.closed { + s.mu.RUnlock() + return + } + + downTracks := make([]*DownTrack, 0, len(s.downTracks)) + for _, dt := range s.downTracks { + downTracks = append(downTracks, dt) + } + s.mu.RUnlock() + + var totalBytes uint64 + var totalDuration time.Duration + var maxLossRate float64 + var minDelayEstimate uint64 // Use minimum for conservative estimation + + for _, dt := range downTracks { + bytes, duration, lossRate, delayEstimate := dt.GetStats() + totalBytes += bytes + if duration > totalDuration { + totalDuration = duration + } + // Use maximum loss rate among all tracks + if lossRate > maxLossRate { + maxLossRate = lossRate + } + // Use minimum delay estimate (most conservative) + if delayEstimate > 0 && (minDelayEstimate == 0 || delayEstimate < minDelayEstimate) { + minDelayEstimate = delayEstimate + } + } + + if totalDuration > 0 && s.bandwidthController != nil { + s.bandwidthController.UpdateBitrateWithDelay(totalBytes, totalDuration, maxLossRate, minDelayEstimate) + } +} + // PeerConnection returns the underlying WebRTC peer connection. func (s *Subscriber) PeerConnection() *webrtc.PeerConnection { return s.pc @@ -118,6 +195,12 @@ func (s *Subscriber) AddDownTrack(track *TrackReceiver) error { s.downTracks[trackID] = dt + // Register track with bandwidth controller for automatic layer selection + // Only for video tracks (audio doesn't have multiple layers) + if track.Kind() == webrtc.RTPCodecTypeVideo { + s.bandwidthController.AddTrack(trackID, dt.GetCurrentLayer()) + } + slog.Info("[Subscriber] Added downtrack", "trackID", trackID) return nil } @@ -140,6 +223,12 @@ func (s *Subscriber) SetLayer(trackID, layer string) { } dt.SetTargetLayer(layer) + + // Update BandwidthController allocation to keep state in sync + // This allows automatic adjustment to continue from this layer + if s.bandwidthController != nil { + s.bandwidthController.RequestLayer(trackID, layer) + } } // GetLayer returns the current and target layer for a track. @@ -284,5 +373,9 @@ func (s *Subscriber) Close() error { } } + if s.bandwidthController != nil { + s.bandwidthController.Close() + } + return s.pc.Close() } diff --git a/pkg/sfu/twcc.go b/pkg/sfu/twcc.go index e8d2c03..9acae98 100644 --- a/pkg/sfu/twcc.go +++ b/pkg/sfu/twcc.go @@ -1,10 +1,11 @@ package sfu import ( + "log/slog" + "math" "sync" "time" - "github.com/pion/interceptor/pkg/cc" "github.com/pion/rtcp" ) @@ -23,78 +24,42 @@ type TWCCConfig struct { // DefaultTWCCConfig returns the default TWCC configuration func DefaultTWCCConfig() TWCCConfig { return TWCCConfig{ - InitialBitrate: 1_000_000, // 1 Mbps - MinBitrate: 100_000, // 100 Kbps + InitialBitrate: 3_000_000, // 3 Mbps (start high, let it adapt down if needed) + MinBitrate: 150_000, // 150 Kbps (enough for low layer) MaxBitrate: 5_000_000, // 5 Mbps FeedbackInterval: 100 * time.Millisecond, } } -// PacketInfo contains information about a received packet for TWCC -type PacketInfo struct { - SequenceNumber uint16 - ArrivalTime time.Time - Size int -} +// DelayGradientState represents the state of the delay gradient detector +type DelayGradientState int + +const ( + DelayStateNormal DelayGradientState = iota + DelayStateOverusing + DelayStateUnderusing +) // TWCCReceiver receives TWCC feedback and estimates bandwidth type TWCCReceiver struct { config TWCCConfig - packets map[uint16]*PacketInfo estimatedBitrate uint64 lossRate float64 - rtt time.Duration - onBitrateChange func(bitrate uint64) mu sync.RWMutex closed bool - closeCh chan struct{} + + // Delay-based estimation (GCC algorithm) + delayDetector *DelayBasedDetector + lastDelayEstimate uint64 } // NewTWCCReceiver creates a new TWCC receiver func NewTWCCReceiver(config TWCCConfig) *TWCCReceiver { return &TWCCReceiver{ - config: config, - packets: make(map[uint16]*PacketInfo), - estimatedBitrate: config.InitialBitrate, - closeCh: make(chan struct{}), - } -} - -// OnBitrateChange sets the callback for bitrate changes -func (t *TWCCReceiver) OnBitrateChange(cb func(bitrate uint64)) { - t.mu.Lock() - defer t.mu.Unlock() - t.onBitrateChange = cb -} - -// RecordPacket records a received packet -func (t *TWCCReceiver) RecordPacket(seqNum uint16, size int) { - t.mu.Lock() - defer t.mu.Unlock() - - if t.closed { - return - } - - t.packets[seqNum] = &PacketInfo{ - SequenceNumber: seqNum, - ArrivalTime: time.Now(), - Size: size, - } - - // Clean up old packets (keep last 1000) - if len(t.packets) > 1000 { - t.cleanupOldPackets() - } -} - -// cleanupOldPackets removes old packet records -func (t *TWCCReceiver) cleanupOldPackets() { - threshold := time.Now().Add(-5 * time.Second) - for seq, pkt := range t.packets { - if pkt.ArrivalTime.Before(threshold) { - delete(t.packets, seq) - } + config: config, + estimatedBitrate: config.InitialBitrate, + delayDetector: NewDelayBasedDetector(config), + lastDelayEstimate: config.InitialBitrate, } } @@ -105,163 +70,136 @@ func (t *TWCCReceiver) GetEstimatedBitrate() uint64 { return t.estimatedBitrate } -// GetLossRate returns the current packet loss rate -func (t *TWCCReceiver) GetLossRate() float64 { - t.mu.RLock() - defer t.mu.RUnlock() - return t.lossRate -} - -// GetRTT returns the current RTT estimate -func (t *TWCCReceiver) GetRTT() time.Duration { - t.mu.RLock() - defer t.mu.RUnlock() - return t.rtt -} - -// Close closes the TWCC receiver -func (t *TWCCReceiver) Close() { +// ProcessTWCCFeedback processes TWCC feedback and updates bandwidth estimate +func (t *TWCCReceiver) ProcessTWCCFeedback(twcc *rtcp.TransportLayerCC) { t.mu.Lock() defer t.mu.Unlock() if t.closed { return } - t.closed = true - close(t.closeCh) -} - -// TWCCSender sends TWCC feedback -type TWCCSender struct { - config TWCCConfig - referenceTime time.Time - packets []*PacketInfo - feedbackCount uint8 - onFeedback func([]rtcp.Packet) - mu sync.Mutex - closed bool - closeCh chan struct{} -} - -// NewTWCCSender creates a new TWCC sender -func NewTWCCSender(config TWCCConfig) *TWCCSender { - return &TWCCSender{ - config: config, - referenceTime: time.Now(), - packets: make([]*PacketInfo, 0, 256), - closeCh: make(chan struct{}), - } -} - -// OnFeedback sets the callback for sending feedback -func (t *TWCCSender) OnFeedback(cb func([]rtcp.Packet)) { - t.mu.Lock() - defer t.mu.Unlock() - t.onFeedback = cb -} - -// RecordPacket records a sent packet -func (t *TWCCSender) RecordPacket(seqNum uint16, size int) { - t.mu.Lock() - defer t.mu.Unlock() - if t.closed { + // Extract arrival times and calculate inter-arrival deltas + arrivalDeltas := t.extractArrivalDeltas(twcc) + if len(arrivalDeltas) == 0 { return } - t.packets = append(t.packets, &PacketInfo{ - SequenceNumber: seqNum, - ArrivalTime: time.Now(), - Size: size, - }) -} - -// Start starts the feedback loop -func (t *TWCCSender) Start() { - go t.feedbackLoop() -} - -// feedbackLoop periodically sends TWCC feedback -func (t *TWCCSender) feedbackLoop() { - ticker := time.NewTicker(t.config.FeedbackInterval) - defer ticker.Stop() - - for { - select { - case <-t.closeCh: - return - case <-ticker.C: - t.sendFeedback() - } + // Calculate packet loss from feedback + received, lost := t.countPacketStatus(twcc) + if received+lost > 0 { + t.lossRate = float64(lost) / float64(received+lost) } -} -// sendFeedback generates and sends TWCC feedback -func (t *TWCCSender) sendFeedback() { - t.mu.Lock() - if t.closed || len(t.packets) == 0 { - t.mu.Unlock() - return + // Update delay-based detector with arrival deltas + state := t.delayDetector.Update(arrivalDeltas) + + // Log current state for debugging + slog.Debug("[TWCCReceiver] TWCC state", + slog.Int("state", int(state)), + slog.Float64("trendlineSlope", t.delayDetector.GetTrendlineSlope()), + slog.Float64("threshold", t.delayDetector.GetThreshold()), + slog.Float64("lossRate", t.lossRate), + slog.Uint64("currentBitrate", t.estimatedBitrate)) + + // Adjust bitrate based on delay state + oldBitrate := t.estimatedBitrate + switch state { + case DelayStateOverusing: + // Reduce bitrate when detecting overuse + t.estimatedBitrate = uint64(float64(t.estimatedBitrate) * 0.85) + slog.Info("[TWCCReceiver] Overuse detected, reducing bitrate", + slog.Uint64("from", oldBitrate), + slog.Uint64("to", t.estimatedBitrate), + slog.Float64("slope", t.delayDetector.GetTrendlineSlope())) + case DelayStateUnderusing: + // Can increase bitrate + t.estimatedBitrate = uint64(float64(t.estimatedBitrate) * 1.05) + slog.Debug("[TWCCReceiver] Underuse detected, increasing bitrate", + slog.Uint64("from", oldBitrate), + slog.Uint64("to", t.estimatedBitrate)) } - packets := t.packets - t.packets = make([]*PacketInfo, 0, 256) - callback := t.onFeedback - t.feedbackCount++ - t.mu.Unlock() - - if callback == nil { - return + // Apply loss-based adjustment on top (only if significant loss) + if t.lossRate > 0.1 { + t.estimatedBitrate = uint64(float64(t.estimatedBitrate) * 0.5) + slog.Info("[TWCCReceiver] High loss rate, reducing bitrate", + slog.Float64("lossRate", t.lossRate), + slog.Uint64("bitrate", t.estimatedBitrate)) + } else if t.lossRate > 0.02 { + t.estimatedBitrate = uint64(float64(t.estimatedBitrate) * 0.85) + slog.Debug("[TWCCReceiver] Moderate loss rate, reducing bitrate", + slog.Float64("lossRate", t.lossRate), + slog.Uint64("bitrate", t.estimatedBitrate)) } - // Build TWCC feedback packet - feedback := t.buildFeedback(packets) - if feedback != nil { - callback([]rtcp.Packet{feedback}) - } -} + // Clamp to configured bounds + t.estimatedBitrate = clampBitrate(t.estimatedBitrate, t.config.MinBitrate, t.config.MaxBitrate) + t.lastDelayEstimate = t.estimatedBitrate +} + +// extractArrivalDeltas extracts inter-arrival time deltas from TWCC feedback +func (t *TWCCReceiver) extractArrivalDeltas(twcc *rtcp.TransportLayerCC) []time.Duration { + deltas := make([]time.Duration, 0, len(twcc.RecvDeltas)) + + // RecvDeltas contain the receive delta for each packet + // We calculate the inter-arrival jitter by looking at variations + for _, recvDelta := range twcc.RecvDeltas { + // RecvDelta.Delta is the inter-packet arrival time + // For SmallDelta: in 250us units (max ~63.75ms) + // For LargeDelta: in 250us units but can be negative + var deltaUs int64 + switch recvDelta.Type { + case rtcp.TypeTCCPacketReceivedSmallDelta: + deltaUs = recvDelta.Delta // Already in 250us units + case rtcp.TypeTCCPacketReceivedLargeDelta: + deltaUs = recvDelta.Delta // Can be negative + default: + continue // Packet not received + } -// buildFeedback creates a TWCC feedback packet -func (t *TWCCSender) buildFeedback(packets []*PacketInfo) rtcp.Packet { - if len(packets) == 0 { - return nil + // Convert to microseconds (Delta is in 250us units) + arrivalDelta := time.Duration(deltaUs*250) * time.Microsecond + deltas = append(deltas, arrivalDelta) } - // Find base sequence number - baseSeq := packets[0].SequenceNumber - for _, p := range packets { - if p.SequenceNumber < baseSeq { - baseSeq = p.SequenceNumber + return deltas +} + +// countPacketStatus counts received and lost packets from TWCC feedback +func (t *TWCCReceiver) countPacketStatus(twcc *rtcp.TransportLayerCC) (received, lost uint64) { + for _, chunk := range twcc.PacketChunks { + switch c := chunk.(type) { + case *rtcp.RunLengthChunk: + if c.PacketStatusSymbol == rtcp.TypeTCCPacketReceivedSmallDelta || + c.PacketStatusSymbol == rtcp.TypeTCCPacketReceivedLargeDelta { + received += uint64(c.RunLength) + } else if c.PacketStatusSymbol == rtcp.TypeTCCPacketNotReceived { + lost += uint64(c.RunLength) + } + case *rtcp.StatusVectorChunk: + for _, symbol := range c.SymbolList { + if symbol == rtcp.TypeTCCPacketReceivedSmallDelta || + symbol == rtcp.TypeTCCPacketReceivedLargeDelta { + received++ + } else if symbol == rtcp.TypeTCCPacketNotReceived { + lost++ + } + } } } + return received, lost +} - // Build packet status chunks - recvDeltas := make([]*rtcp.RecvDelta, 0, len(packets)) - for _, p := range packets { - delta := p.ArrivalTime.Sub(t.referenceTime) - recvDeltas = append(recvDeltas, &rtcp.RecvDelta{ - Type: rtcp.TypeTCCPacketReceivedSmallDelta, - Delta: delta.Microseconds() * 250, // 250us units - }) - } - - return &rtcp.TransportLayerCC{ - Header: rtcp.Header{ - Count: rtcp.FormatTCC, - Type: rtcp.TypeTransportSpecificFeedback, - Length: 0, // Will be calculated - }, - MediaSSRC: 0, // Set by caller - BaseSequenceNumber: baseSeq, - PacketStatusCount: uint16(len(packets)), - ReferenceTime: uint32(t.referenceTime.UnixNano() / 64000), // 64ms units - FbPktCount: t.feedbackCount, - RecvDeltas: recvDeltas, - } +// GetDelayEstimate returns the delay-based bandwidth estimate +func (t *TWCCReceiver) GetDelayEstimate() uint64 { + t.mu.RLock() + defer t.mu.RUnlock() + return t.lastDelayEstimate } -// Close closes the TWCC sender -func (t *TWCCSender) Close() { +// Close closes the TWCC receiver +func (t *TWCCReceiver) Close() { t.mu.Lock() defer t.mu.Unlock() @@ -269,7 +207,6 @@ func (t *TWCCSender) Close() { return } t.closed = true - close(t.closeCh) } // BandwidthEstimator estimates bandwidth using various signals @@ -315,12 +252,13 @@ func (b *BandwidthEstimator) Update(receivedBytes uint64, duration time.Duration now := time.Now() - // Calculate instantaneous bitrate + // Calculate instantaneous bitrate (for reference only) if duration > 0 { - instantBitrate := uint64(float64(receivedBytes*8) / duration.Seconds()) + // Don't use instantaneous bitrate directly - it reflects what we're sending, + // not what the network can handle. Instead, adjust based on loss rate only. - // Apply loss-based adjustment - b.lossBasedEstimate = b.calculateLossBasedEstimate(instantBitrate, lossRate) + // Apply loss-based adjustment to current estimate + b.lossBasedEstimate = b.calculateLossBasedEstimate(b.estimatedBitrate, lossRate) // Combine estimates using weighted average b.estimatedBitrate = b.combineEstimates() @@ -363,18 +301,20 @@ func (b *BandwidthEstimator) calculateLossBasedEstimate(currentBitrate uint64, l // combineEstimates combines delay-based and loss-based estimates func (b *BandwidthEstimator) combineEstimates() uint64 { - // Use the minimum of the two estimates for safety + // If we have both estimates, use a weighted combination + // Prefer the loss-based estimate as it's more reliable initially if b.delayBasedEstimate > 0 && b.lossBasedEstimate > 0 { - if b.delayBasedEstimate < b.lossBasedEstimate { - return b.delayBasedEstimate - } - return b.lossBasedEstimate + // Use 70% loss-based, 30% delay-based for stability + // This prevents the delay-based estimate from causing rapid drops + combined := uint64(float64(b.lossBasedEstimate)*0.7 + float64(b.delayBasedEstimate)*0.3) + return combined } if b.lossBasedEstimate > 0 { return b.lossBasedEstimate } + // No estimates yet, keep current return b.estimatedBitrate } @@ -403,35 +343,133 @@ func clampBitrate(bitrate, min, max uint64) uint64 { return bitrate } -// CongestionController implements congestion control using pion's cc package -type CongestionController struct { - estimator cc.BandwidthEstimator - config TWCCConfig - mu sync.RWMutex +// DelayBasedDetector implements the delay-based congestion detection (GCC algorithm) +// It uses the Trendline filter to detect network congestion based on inter-arrival delays +type DelayBasedDetector struct { + // Trendline filter state (exponential moving average) + trendlineSlope float64 + smoothedDelay float64 + varNoise float64 // Noise variance estimate + threshold float64 // Adaptive threshold + lastState DelayGradientState + overuseCounter int + underuseCounter int + stateHoldDuration int // Number of samples to hold before state change + + mu sync.Mutex +} + +// NewDelayBasedDetector creates a new delay-based detector +func NewDelayBasedDetector(_ TWCCConfig) *DelayBasedDetector { + return &DelayBasedDetector{ + threshold: 25.0, // Initial threshold in ms + varNoise: 100.0, // Initial noise variance + stateHoldDuration: 5, + } } -// NewCongestionController creates a new congestion controller -func NewCongestionController(config TWCCConfig) *CongestionController { - return &CongestionController{ - config: config, +// Update processes new arrival deltas and returns the current congestion state +func (d *DelayBasedDetector) Update(deltas []time.Duration) DelayGradientState { + d.mu.Lock() + defer d.mu.Unlock() + + if len(deltas) < 2 { + return d.lastState + } + + // Calculate delay gradient as the variation in inter-arrival times + // Positive gradient = packets arriving later than expected = congestion + // We look at the difference between consecutive deltas + var totalGradient float64 + var count int + for i := 1; i < len(deltas); i++ { + // Gradient = (current delta - previous delta) + // Positive means increasing delay (congestion) + gradient := float64(deltas[i].Microseconds()-deltas[i-1].Microseconds()) / 1000.0 // Convert to ms + totalGradient += gradient + count++ + } + + if count == 0 { + return d.lastState } + + avgGradient := totalGradient / float64(count) + + // Update smoothed delay using exponential moving average + // Use a slower alpha to be less reactive + alpha := 0.95 // Smoothing factor (more smoothing) + d.smoothedDelay = alpha*d.smoothedDelay + (1-alpha)*avgGradient + + // Update trendline slope (simplified linear regression) + // Use slower adaptation + d.trendlineSlope = 0.95*d.trendlineSlope + 0.05*avgGradient + + // Update noise variance estimate (Kalman-like update) + residual := math.Abs(avgGradient - d.smoothedDelay) + d.varNoise = 0.98*d.varNoise + 0.02*residual*residual + + // Update adaptive threshold based on noise variance + // Higher noise -> higher threshold to avoid false positives + // Minimum threshold of 25ms, maximum of 600ms + d.threshold = math.Max(25.0, math.Min(600.0, 25.0+math.Sqrt(d.varNoise)*3)) + + // Detect state based on trendline slope vs threshold + newState := d.detectState() + + // Apply hysteresis to avoid rapid state changes + d.lastState = d.applyHysteresis(newState) + + return d.lastState } -// SetEstimator sets the bandwidth estimator from pion interceptor -func (c *CongestionController) SetEstimator(estimator cc.BandwidthEstimator) { - c.mu.Lock() - defer c.mu.Unlock() - c.estimator = estimator +// detectState determines the current state based on trendline slope +func (d *DelayBasedDetector) detectState() DelayGradientState { + if d.trendlineSlope > d.threshold { + return DelayStateOverusing + } else if d.trendlineSlope < -d.threshold { + return DelayStateUnderusing + } + return DelayStateNormal } -// GetTargetBitrate returns the target bitrate from the congestion controller -func (c *CongestionController) GetTargetBitrate() int { - c.mu.RLock() - defer c.mu.RUnlock() +// applyHysteresis applies hysteresis to state transitions +func (d *DelayBasedDetector) applyHysteresis(newState DelayGradientState) DelayGradientState { + switch newState { + case DelayStateOverusing: + d.overuseCounter++ + d.underuseCounter = 0 + if d.overuseCounter >= d.stateHoldDuration { + return DelayStateOverusing + } + case DelayStateUnderusing: + d.underuseCounter++ + d.overuseCounter = 0 + if d.underuseCounter >= d.stateHoldDuration { + return DelayStateUnderusing + } + default: + d.overuseCounter = 0 + d.underuseCounter = 0 + } - if c.estimator == nil { - return int(c.config.InitialBitrate) + // Return previous state if not enough samples to confirm change + if d.lastState == DelayStateNormal { + return DelayStateNormal } + return d.lastState +} + +// GetThreshold returns the current adaptive threshold +func (d *DelayBasedDetector) GetThreshold() float64 { + d.mu.Lock() + defer d.mu.Unlock() + return d.threshold +} - return c.estimator.GetTargetBitrate() +// GetTrendlineSlope returns the current trendline slope +func (d *DelayBasedDetector) GetTrendlineSlope() float64 { + d.mu.Lock() + defer d.mu.Unlock() + return d.trendlineSlope }