From c606045a10353dd7f30ddbc56636cd76da9efda3 Mon Sep 17 00:00:00 2001 From: Nico Duldhardt Date: Sun, 1 Mar 2026 13:15:01 +0100 Subject: [PATCH 1/2] feat: route fetch requests to partition-owning brokers Split fetch requests by owning broker, forward concurrently, and merge responses. Retry partitions rejected with NOT_LEADER_OR_FOLLOWER up to 3 times. For v12+ requests that use topic IDs instead of names, resolve IDs via a metadata-refreshed cache and use a collision-safe key to prevent unresolved topics from merging silently. Adds EncodeFetchRequest and ParseFetchResponse codecs with round-trip and kmsg validation tests. --- cmd/proxy/main.go | 352 +++++++++++++++++++++++++++++++++- cmd/proxy/main_test.go | 298 ++++++++++++++++++++++++++++ pkg/protocol/request.go | 94 +++++++++ pkg/protocol/request_test.go | 184 ++++++++++++++++++ pkg/protocol/response.go | 196 +++++++++++++++++++ pkg/protocol/response_test.go | 154 +++++++++++++++ 6 files changed, 1276 insertions(+), 2 deletions(-) diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index 1d55d57..34b88c7 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -58,6 +58,8 @@ type proxy struct { groupRouter *metadata.GroupRouter brokerAddrMu sync.RWMutex brokerAddrs map[string]string // brokerID -> "host:port" + topicNamesMu sync.RWMutex + topicNames map[[16]byte]string // topicID -> topic name backendRetries int backendBackoff time.Duration } @@ -112,6 +114,7 @@ func main() { cacheTTL: cacheTTL, apiVersions: generateProxyApiVersions(), brokerAddrs: make(map[string]string), + topicNames: make(map[[16]byte]string), backendRetries: backendRetries, backendBackoff: backendBackoff, } @@ -470,6 +473,18 @@ func (p *proxy) handleConnection(ctx context.Context, conn net.Conn) { return } continue + case protocol.APIKeyFetch: + resp, err := p.handleFetchRouting(ctx, header, frame.Payload, pool) + if err != nil { + p.logger.Warn("fetch routing failed", "error", err) + p.respondBackendError(conn, header, frame.Payload) + 
return + } + if err := protocol.WriteFrame(conn, resp); err != nil { + p.logger.Warn("write fetch response failed", "error", err) + return + } + continue case protocol.APIKeyJoinGroup, protocol.APIKeySyncGroup, protocol.APIKeyHeartbeat, @@ -1505,8 +1520,9 @@ func (p *proxy) updateBrokerAddrs(brokers []protocol.MetadataBroker) { } // refreshBrokerAddrs queries metadata solely to update the broker ID -> address -// mapping. Used when static backends are configured (so currentBackends returns -// early) but partition-aware routing still needs broker ID resolution. +// mapping and the topic ID -> name mapping. Used when static backends are +// configured (so currentBackends returns early) but partition-aware routing +// still needs broker ID resolution and topic ID resolution. func (p *proxy) refreshBrokerAddrs(ctx context.Context) { if p.store == nil { return @@ -1516,6 +1532,29 @@ func (p *proxy) refreshBrokerAddrs(ctx context.Context) { return } p.updateBrokerAddrs(meta.Brokers) + p.updateTopicNames(meta.Topics) +} + +// updateTopicNames rebuilds the topic ID -> name mapping from metadata. +func (p *proxy) updateTopicNames(topics []protocol.MetadataTopic) { + names := make(map[[16]byte]string, len(topics)) + var zeroID [16]byte + for _, topic := range topics { + if topic.TopicID != zeroID && topic.Name != "" { + names[topic.TopicID] = topic.Name + } + } + p.topicNamesMu.Lock() + p.topicNames = names + p.topicNamesMu.Unlock() +} + +// resolveTopicID returns the topic name for a given topic ID, or "" if unknown. 
+func (p *proxy) resolveTopicID(id [16]byte) string { + p.topicNamesMu.RLock() + name := p.topicNames[id] + p.topicNamesMu.RUnlock() + return name } func (p *proxy) forwardToBackend(ctx context.Context, conn net.Conn, backendAddr string, payload []byte) ([]byte, error) { @@ -1614,3 +1653,312 @@ func (p *proxy) extractGroupID(apiKey int16, payload []byte) string { return "" } } + +// handleFetchRouting routes fetch requests to the broker(s) that own the +// requested partitions. Like produce routing, the request is split by owning +// broker, forwarded concurrently, and responses are merged. On +// NOT_LEADER_OR_FOLLOWER, failed partitions are retried on a different broker. +func (p *proxy) handleFetchRouting(ctx context.Context, header *protocol.RequestHeader, payload []byte, pool *connPool) ([]byte, error) { + _, req, err := protocol.ParseRequest(payload) + if err != nil { + return p.forwardFetchRaw(ctx, payload, pool) + } + fetchReq, ok := req.(*protocol.FetchRequest) + if !ok || len(fetchReq.Topics) == 0 { + return p.forwardFetchRaw(ctx, payload, pool) + } + + // Resolve topic names for v12+ requests that use topic IDs. + p.resolveFetchTopicNames(fetchReq) + + groups := p.groupFetchPartitionsByBroker(fetchReq, nil) + return p.forwardFetch(ctx, header, fetchReq, payload, groups, pool) +} + +// forwardFetchRaw forwards an unparseable fetch payload to any backend. +func (p *proxy) forwardFetchRaw(ctx context.Context, payload []byte, pool *connPool) ([]byte, error) { + conn, addr, err := p.connectForAddr(ctx, "", nil, pool) + if err != nil { + return nil, err + } + resp, err := p.forwardToBackend(ctx, conn, addr, payload) + if err != nil { + conn.Close() + return nil, err + } + pool.Return(addr, conn) + return resp, nil +} + +// resolveFetchTopicNames fills in topic names from topic IDs for v12+ fetch +// requests. The partition router uses topic names, so we need to resolve IDs. 
+func (p *proxy) resolveFetchTopicNames(req *protocol.FetchRequest) { + var zeroID [16]byte + for i := range req.Topics { + if req.Topics[i].Name == "" && req.Topics[i].TopicID != zeroID { + req.Topics[i].Name = p.resolveTopicID(req.Topics[i].TopicID) + } + } +} + +// fetchTopicKey returns a deduplication key for a fetch topic. It uses the +// topic name when available, falling back to the hex-encoded topic ID. This +// prevents multiple unresolved topics (all with name "") from colliding. +func fetchTopicKey(name string, id [16]byte) string { + if name != "" { + return name + } + return fmt.Sprintf("id:%x", id) +} + +// groupFetchPartitionsByBroker groups topic-partitions by the owning broker's +// address. If include is non-nil, only partitions present in the include map +// are grouped (keyed by fetchTopicKey). Partitions with no known owner are +// grouped under "" for round-robin fallback. +func (p *proxy) groupFetchPartitionsByBroker(req *protocol.FetchRequest, include map[string]map[int32]bool) map[string]*protocol.FetchRequest { + groups := make(map[string]*protocol.FetchRequest) + topicIndices := make(map[string]map[string]int) // addr -> topicKey -> index in subReq.Topics + + for _, topic := range req.Topics { + topicName := topic.Name + key := fetchTopicKey(topicName, topic.TopicID) + var includeParts map[int32]bool + if include != nil { + includeParts = include[key] + if len(includeParts) == 0 { + continue + } + } + for _, part := range topic.Partitions { + if includeParts != nil && !includeParts[part.Partition] { + continue + } + addr := "" + if p.router != nil && topicName != "" { + if ownerID := p.router.LookupOwner(topicName, part.Partition); ownerID != "" { + addr = p.brokerIDToAddr(ownerID) + } + } + subReq, ok := groups[addr] + if !ok { + subReq = &protocol.FetchRequest{ + ReplicaID: req.ReplicaID, + MaxWaitMs: req.MaxWaitMs, + MinBytes: req.MinBytes, + MaxBytes: req.MaxBytes, + IsolationLevel: req.IsolationLevel, + SessionID: req.SessionID, + 
SessionEpoch: req.SessionEpoch, + } + groups[addr] = subReq + topicIndices[addr] = make(map[string]int) + } + idx, ok := topicIndices[addr][key] + if !ok { + idx = len(subReq.Topics) + subReq.Topics = append(subReq.Topics, protocol.FetchTopicRequest{ + Name: topic.Name, + TopicID: topic.TopicID, + }) + topicIndices[addr][key] = idx + } + subReq.Topics[idx].Partitions = append(subReq.Topics[idx].Partitions, part) + } + } + return groups +} + +type fetchFanOutResult struct { + subReq *protocol.FetchRequest + subResp *protocol.FetchResponse + conn net.Conn + target string + err error +} + +// forwardFetch splits a fetch request by broker, forwards each sub-request +// concurrently, and merges the responses. If any partitions are rejected with +// NOT_LEADER_OR_FOLLOWER, those partitions are retried on a different broker. +func (p *proxy) forwardFetch(ctx context.Context, header *protocol.RequestHeader, fullReq *protocol.FetchRequest, originalPayload []byte, groups map[string]*protocol.FetchRequest, pool *connPool) ([]byte, error) { + const maxRetries = 3 + + merged := &protocol.FetchResponse{ + CorrelationID: header.CorrelationID, + SessionID: fullReq.SessionID, + } + + // failedPartitions is keyed by fetchTopicKey (topic name or hex topic ID) + // to avoid collisions when multiple v12+ topics have unresolved names. + var failedPartitions map[string]map[int32]bool + for attempt := 0; attempt < maxRetries; attempt++ { + failedPartitions = nil + // Scope triedBackends per attempt so that retries can revisit brokers + // from earlier attempts. Without this, with N brokers all N get excluded + // after the first attempt and subsequent retries always fail to connect. 
+ triedBackends := make(map[string]bool) + subResults := p.fanOutFetch(ctx, header, groups, originalPayload, triedBackends, pool) + + for _, r := range subResults { + if r.err != nil { + p.logger.Warn("fetch forward failed", "target", r.target, "error", r.err) + addFetchErrorForAllPartitions(merged, r.subReq, protocol.REQUEST_TIMED_OUT) + continue + } + if r.conn != nil { + pool.Return(r.target, r.conn) + } + if r.subResp.ErrorCode != 0 { + merged.ErrorCode = r.subResp.ErrorCode + } + for _, topic := range r.subResp.Topics { + for _, part := range topic.Partitions { + if part.ErrorCode == protocol.NOT_LEADER_OR_FOLLOWER { + topicName := topic.Name + if topicName == "" { + topicName = p.resolveTopicID(topic.TopicID) + } + key := fetchTopicKey(topicName, topic.TopicID) + if failedPartitions == nil { + failedPartitions = make(map[string]map[int32]bool) + } + if failedPartitions[key] == nil { + failedPartitions[key] = make(map[int32]bool) + } + failedPartitions[key][part.Partition] = true + if p.router != nil && topicName != "" { + p.router.Invalidate(topicName, part.Partition) + } + } else { + tr := findOrAddFetchTopicResponse(merged, topic.Name, topic.TopicID) + tr.Partitions = append(tr.Partitions, part) + } + } + } + if r.subResp.ThrottleMs > merged.ThrottleMs { + merged.ThrottleMs = r.subResp.ThrottleMs + } + } + + if len(failedPartitions) == 0 { + return protocol.EncodeFetchResponse(merged, header.APIVersion) + } + + groups = p.groupFetchPartitionsByBroker(fullReq, failedPartitions) + originalPayload = nil + if len(groups) == 0 { + break + } + p.logger.Debug("retrying NOT_LEADER fetch partitions", "attempt", attempt+1, "partitions", len(failedPartitions)) + } + + // Fill remaining failed partitions with errors. 
+ for _, topic := range fullReq.Topics { + key := fetchTopicKey(topic.Name, topic.TopicID) + failedParts, ok := failedPartitions[key] + if !ok { + continue + } + tr := findOrAddFetchTopicResponse(merged, topic.Name, topic.TopicID) + for _, part := range topic.Partitions { + if failedParts[part.Partition] { + tr.Partitions = append(tr.Partitions, protocol.FetchPartitionResponse{ + Partition: part.Partition, + ErrorCode: protocol.NOT_LEADER_OR_FOLLOWER, + }) + } + } + } + return protocol.EncodeFetchResponse(merged, header.APIVersion) +} + +// fanOutFetch borrows connections and forwards fetch sub-requests concurrently. +func (p *proxy) fanOutFetch(ctx context.Context, header *protocol.RequestHeader, groups map[string]*protocol.FetchRequest, originalPayload []byte, triedBackends map[string]bool, pool *connPool) []fetchFanOutResult { + type workItem struct { + subReq *protocol.FetchRequest + conn net.Conn + target string + payload []byte + } + work := make([]workItem, 0, len(groups)) + var connectErrors []fetchFanOutResult + + canUseOriginal := originalPayload != nil && len(groups) == 1 + for addr, subReq := range groups { + conn, targetAddr, err := p.connectForAddr(ctx, addr, triedBackends, pool) + if err != nil { + connectErrors = append(connectErrors, fetchFanOutResult{subReq: subReq, target: addr, err: err}) + continue + } + triedBackends[targetAddr] = true + + var payload []byte + if canUseOriginal { + payload = originalPayload + } else { + encoded, encErr := protocol.EncodeFetchRequest(header, subReq, header.APIVersion) + if encErr != nil { + conn.Close() + connectErrors = append(connectErrors, fetchFanOutResult{subReq: subReq, target: targetAddr, err: encErr}) + continue + } + payload = encoded + } + work = append(work, workItem{subReq: subReq, conn: conn, target: targetAddr, payload: payload}) + } + + results := make([]fetchFanOutResult, len(work)) + var wg sync.WaitGroup + for i := range work { + i := i + w := work[i] + wg.Add(1) + go func() { + defer 
wg.Done() + respBytes, err := p.forwardToBackend(ctx, w.conn, w.target, w.payload) + if err != nil { + w.conn.Close() + results[i] = fetchFanOutResult{subReq: w.subReq, target: w.target, err: err} + return + } + subResp, parseErr := protocol.ParseFetchResponse(respBytes, header.APIVersion) + if parseErr != nil { + w.conn.Close() + results[i] = fetchFanOutResult{subReq: w.subReq, target: w.target, err: parseErr} + return + } + results[i] = fetchFanOutResult{subReq: w.subReq, subResp: subResp, conn: w.conn, target: w.target} + }() + } + wg.Wait() + + return append(connectErrors, results...) +} + +func findOrAddFetchTopicResponse(resp *protocol.FetchResponse, name string, topicID [16]byte) *protocol.FetchTopicResponse { + var zeroID [16]byte + for i := range resp.Topics { + if topicID != zeroID { + if resp.Topics[i].TopicID == topicID { + return &resp.Topics[i] + } + } else { + if resp.Topics[i].Name == name { + return &resp.Topics[i] + } + } + } + resp.Topics = append(resp.Topics, protocol.FetchTopicResponse{Name: name, TopicID: topicID}) + return &resp.Topics[len(resp.Topics)-1] +} + +func addFetchErrorForAllPartitions(resp *protocol.FetchResponse, req *protocol.FetchRequest, errorCode int16) { + for _, topic := range req.Topics { + tr := findOrAddFetchTopicResponse(resp, topic.Name, topic.TopicID) + for _, part := range topic.Partitions { + tr.Partitions = append(tr.Partitions, protocol.FetchPartitionResponse{ + Partition: part.Partition, + ErrorCode: errorCode, + }) + } + } +} diff --git a/cmd/proxy/main_test.go b/cmd/proxy/main_test.go index 6f8300c..97a1cd2 100644 --- a/cmd/proxy/main_test.go +++ b/cmd/proxy/main_test.go @@ -749,3 +749,301 @@ func TestExtractGroupID(t *testing.T) { }) } } + +// --- Fetch routing tests --- + +func makeFetchRequest(topics map[string][]int32) *protocol.FetchRequest { + req := &protocol.FetchRequest{ + ReplicaID: -1, + MaxWaitMs: 500, + MinBytes: 1, + MaxBytes: 1048576, + SessionID: 0, + SessionEpoch: -1, + } + for name, parts := 
range topics { + topic := protocol.FetchTopicRequest{Name: name} + for _, p := range parts { + topic.Partitions = append(topic.Partitions, protocol.FetchPartitionRequest{ + Partition: p, + FetchOffset: 0, + MaxBytes: 1048576, + }) + } + req.Topics = append(req.Topics, topic) + } + return req +} + +func countFetchPartitions(req *protocol.FetchRequest) int { + n := 0 + for _, t := range req.Topics { + n += len(t.Partitions) + } + return n +} + +func fetchMapKeys(m map[string]*protocol.FetchRequest) []string { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + return keys +} + +func TestGroupFetchPartitionsByBrokerNoRouter(t *testing.T) { + p := &proxy{} + req := makeFetchRequest(map[string][]int32{ + "orders": {0, 1, 2}, + "events": {0}, + }) + groups := p.groupFetchPartitionsByBroker(req, nil) + if len(groups) != 1 { + t.Fatalf("expected 1 group (all round-robin), got %d", len(groups)) + } + rr, ok := groups[""] + if !ok { + t.Fatalf("expected round-robin group (key=\"\"), got keys: %v", fetchMapKeys(groups)) + } + if countFetchPartitions(rr) != 4 { + t.Fatalf("expected 4 total partitions, got %d", countFetchPartitions(rr)) + } + if rr.MaxWaitMs != 500 || rr.MaxBytes != 1048576 { + t.Fatalf("sub-request should preserve settings: got maxWait=%d maxBytes=%d", rr.MaxWaitMs, rr.MaxBytes) + } +} + +func TestGroupFetchPartitionsByBrokerNoRouterMultipleTopics(t *testing.T) { + p := &proxy{} + req := makeFetchRequest(map[string][]int32{ + "orders": {0, 1}, + "events": {0, 1, 2}, + }) + groups := p.groupFetchPartitionsByBroker(req, nil) + if len(groups) != 1 { + t.Fatalf("expected 1 group, got %d", len(groups)) + } + rr := groups[""] + if rr == nil { + t.Fatalf("expected round-robin group") + } + if countFetchPartitions(rr) != 5 { + t.Fatalf("expected 5 partitions, got %d", countFetchPartitions(rr)) + } + topicNames := make(map[string]int) + for _, topic := range rr.Topics { + topicNames[topic.Name] = len(topic.Partitions) + } + if 
topicNames["orders"] != 2 || topicNames["events"] != 3 { + t.Fatalf("unexpected topic grouping: %v", topicNames) + } +} + +func TestGroupFetchPartitionsByBrokerFiltersCorrectly(t *testing.T) { + p := &proxy{} + req := makeFetchRequest(map[string][]int32{ + "orders": {0, 1, 2}, + "events": {0, 1}, + }) + include := map[string]map[int32]bool{ + "orders": {1: true}, + "events": {0: true}, + } + groups := p.groupFetchPartitionsByBroker(req, include) + if len(groups) != 1 { + t.Fatalf("expected 1 group (no router), got %d", len(groups)) + } + rr := groups[""] + if rr == nil { + t.Fatalf("missing round-robin group") + } + if countFetchPartitions(rr) != 2 { + t.Fatalf("expected 2 filtered partitions, got %d", countFetchPartitions(rr)) + } +} + +func TestFindOrAddFetchTopicResponse(t *testing.T) { + resp := &protocol.FetchResponse{} + topicID := [16]byte{1, 2, 3} + + tr := findOrAddFetchTopicResponse(resp, "orders", topicID) + tr.Partitions = append(tr.Partitions, protocol.FetchPartitionResponse{Partition: 0}) + + // Same topic should return the existing entry. + tr2 := findOrAddFetchTopicResponse(resp, "orders", topicID) + if len(tr2.Partitions) != 1 { + t.Fatalf("expected 1 partition in existing topic, got %d", len(tr2.Partitions)) + } + + // Different topic. + tr3 := findOrAddFetchTopicResponse(resp, "events", [16]byte{4, 5, 6}) + if len(tr3.Partitions) != 0 { + t.Fatalf("expected 0 partitions in new topic, got %d", len(tr3.Partitions)) + } + if len(resp.Topics) != 2 { + t.Fatalf("expected 2 topics in response, got %d", len(resp.Topics)) + } + + // v12+: same topicID but different name should match on topicID alone. + tr4 := findOrAddFetchTopicResponse(resp, "", topicID) + tr4.Partitions = append(tr4.Partitions, protocol.FetchPartitionResponse{Partition: 1}) + if len(resp.Topics) != 2 { + t.Fatalf("expected 2 topics after topicID-only match, got %d", len(resp.Topics)) + } + // The entry found by topicID should now have 2 partitions (0 from before, 1 just added). 
+ tr5 := findOrAddFetchTopicResponse(resp, "orders", topicID) + if len(tr5.Partitions) != 2 { + t.Fatalf("expected 2 partitions after topicID-only merge, got %d", len(tr5.Partitions)) + } + + // Name-only match (zero topicID) should work for pre-v12 topics. + tr6 := findOrAddFetchTopicResponse(resp, "logs", [16]byte{}) + tr6.Partitions = append(tr6.Partitions, protocol.FetchPartitionResponse{Partition: 0}) + tr7 := findOrAddFetchTopicResponse(resp, "logs", [16]byte{}) + if len(tr7.Partitions) != 1 { + t.Fatalf("expected 1 partition for name-only topic, got %d", len(tr7.Partitions)) + } + if len(resp.Topics) != 3 { + t.Fatalf("expected 3 topics total, got %d", len(resp.Topics)) + } +} + +func TestAddFetchErrorForAllPartitions(t *testing.T) { + resp := &protocol.FetchResponse{} + req := makeFetchRequest(map[string][]int32{ + "orders": {0, 1}, + "events": {0}, + }) + addFetchErrorForAllPartitions(resp, req, protocol.REQUEST_TIMED_OUT) + + if len(resp.Topics) != 2 { + t.Fatalf("expected 2 topics, got %d", len(resp.Topics)) + } + total := 0 + for _, topic := range resp.Topics { + for _, part := range topic.Partitions { + if part.ErrorCode != protocol.REQUEST_TIMED_OUT { + t.Fatalf("expected error %d, got %d", protocol.REQUEST_TIMED_OUT, part.ErrorCode) + } + total++ + } + } + if total != 3 { + t.Fatalf("expected 3 partition errors, got %d", total) + } +} + +func TestUpdateTopicNames(t *testing.T) { + p := &proxy{topicNames: make(map[[16]byte]string)} + topicID1 := [16]byte{1, 2, 3} + topicID2 := [16]byte{4, 5, 6} + topics := []protocol.MetadataTopic{ + {Name: "orders", TopicID: topicID1}, + {Name: "events", TopicID: topicID2}, + {Name: "", TopicID: [16]byte{}}, // should be skipped + } + p.updateTopicNames(topics) + + if got := p.resolveTopicID(topicID1); got != "orders" { + t.Fatalf("resolveTopicID(1): got %q, want %q", got, "orders") + } + if got := p.resolveTopicID(topicID2); got != "events" { + t.Fatalf("resolveTopicID(2): got %q, want %q", got, "events") + } + if 
got := p.resolveTopicID([16]byte{9, 9, 9}); got != "" { + t.Fatalf("resolveTopicID(unknown): got %q, want %q", got, "") + } +} + +func TestGroupFetchPartitionsByBrokerUnresolvedTopicIDs(t *testing.T) { + // When multiple topics have unresolved names (empty string) but different + // topic IDs, they must not be merged into a single FetchTopicRequest. + idA := [16]byte{1, 2, 3} + idB := [16]byte{4, 5, 6} + p := &proxy{} + req := &protocol.FetchRequest{ + ReplicaID: -1, + MaxWaitMs: 500, + MinBytes: 1, + MaxBytes: 1048576, + SessionEpoch: -1, + Topics: []protocol.FetchTopicRequest{ + {TopicID: idA, Partitions: []protocol.FetchPartitionRequest{{Partition: 0, MaxBytes: 1048576}}}, + {TopicID: idB, Partitions: []protocol.FetchPartitionRequest{{Partition: 0, MaxBytes: 1048576}}}, + }, + } + groups := p.groupFetchPartitionsByBroker(req, nil) + rr := groups[""] + if rr == nil { + t.Fatal("expected round-robin group") + } + if len(rr.Topics) != 2 { + t.Fatalf("expected 2 topics (separate entries for different IDs), got %d", len(rr.Topics)) + } + if rr.Topics[0].TopicID != idA || rr.Topics[1].TopicID != idB { + t.Fatalf("topic IDs not preserved: got %x and %x", rr.Topics[0].TopicID, rr.Topics[1].TopicID) + } +} + +func TestGroupFetchPartitionsByBrokerUnresolvedFilter(t *testing.T) { + // Verify that the include filter works correctly with fetchTopicKey for + // unresolved topic IDs. + idA := [16]byte{1, 2, 3} + idB := [16]byte{4, 5, 6} + p := &proxy{} + req := &protocol.FetchRequest{ + ReplicaID: -1, + MaxWaitMs: 500, + MinBytes: 1, + MaxBytes: 1048576, + SessionEpoch: -1, + Topics: []protocol.FetchTopicRequest{ + {TopicID: idA, Partitions: []protocol.FetchPartitionRequest{ + {Partition: 0, MaxBytes: 1048576}, + {Partition: 1, MaxBytes: 1048576}, + }}, + {TopicID: idB, Partitions: []protocol.FetchPartitionRequest{ + {Partition: 0, MaxBytes: 1048576}, + }}, + }, + } + // Only retry partition 1 of topic A. 
+ include := map[string]map[int32]bool{ + fetchTopicKey("", idA): {1: true}, + } + groups := p.groupFetchPartitionsByBroker(req, include) + rr := groups[""] + if rr == nil { + t.Fatal("expected round-robin group") + } + if len(rr.Topics) != 1 { + t.Fatalf("expected 1 topic after filter, got %d", len(rr.Topics)) + } + if rr.Topics[0].TopicID != idA { + t.Fatalf("expected topic A, got %x", rr.Topics[0].TopicID) + } + if len(rr.Topics[0].Partitions) != 1 || rr.Topics[0].Partitions[0].Partition != 1 { + t.Fatalf("expected only partition 1, got %v", rr.Topics[0].Partitions) + } +} + +func TestResolveFetchTopicNames(t *testing.T) { + topicID := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + p := &proxy{ + topicNames: map[[16]byte]string{topicID: "orders"}, + } + req := &protocol.FetchRequest{ + Topics: []protocol.FetchTopicRequest{ + {TopicID: topicID}, // name not set, should be resolved + {Name: "events"}, // already has name, should be left alone + }, + } + p.resolveFetchTopicNames(req) + + if req.Topics[0].Name != "orders" { + t.Fatalf("topic[0] name: got %q, want %q", req.Topics[0].Name, "orders") + } + if req.Topics[1].Name != "events" { + t.Fatalf("topic[1] name: got %q, want %q", req.Topics[1].Name, "events") + } +} diff --git a/pkg/protocol/request.go b/pkg/protocol/request.go index af8580a..8e51d4a 100644 --- a/pkg/protocol/request.go +++ b/pkg/protocol/request.go @@ -1790,6 +1790,100 @@ func ParseRequest(b []byte) (*RequestHeader, Request, error) { return header, req, nil } +// EncodeFetchRequest serializes a RequestHeader + FetchRequest into wire-format +// bytes suitable for WriteFrame. The encoding mirrors what ParseRequest expects. 
+func EncodeFetchRequest(header *RequestHeader, req *FetchRequest, version int16) ([]byte, error) { + w := newByteWriter(256) + flexible := isFlexibleRequest(APIKeyFetch, version) + + w.Int16(header.APIKey) + w.Int16(header.APIVersion) + w.Int32(header.CorrelationID) + w.NullableString(header.ClientID) + if flexible { + w.WriteTaggedFields(0) + } + + w.Int32(req.ReplicaID) + w.Int32(req.MaxWaitMs) + w.Int32(req.MinBytes) + if version >= 3 { + w.Int32(req.MaxBytes) + } + if version >= 4 { + w.Int8(req.IsolationLevel) + } + if version >= 7 { + w.Int32(req.SessionID) + w.Int32(req.SessionEpoch) + } + + if flexible { + w.CompactArrayLen(len(req.Topics)) + } else { + w.Int32(int32(len(req.Topics))) + } + for _, topic := range req.Topics { + if version >= 12 { + w.UUID(topic.TopicID) + } else { + if flexible { + w.CompactString(topic.Name) + } else { + w.String(topic.Name) + } + } + if flexible { + w.CompactArrayLen(len(topic.Partitions)) + } else { + w.Int32(int32(len(topic.Partitions))) + } + for _, part := range topic.Partitions { + w.Int32(part.Partition) + if version >= 9 { + w.Int32(-1) // leader epoch (unknown) + } + w.Int64(part.FetchOffset) + if version >= 12 { + w.Int32(-1) // last fetched epoch (unknown) + } + if version >= 5 { + w.Int64(-1) // log start offset + } + w.Int32(part.MaxBytes) + if flexible { + w.WriteTaggedFields(0) + } + } + if flexible { + w.WriteTaggedFields(0) + } + } + + // Forgotten topics (empty) + if version >= 7 { + if flexible { + w.CompactArrayLen(0) + } else { + w.Int32(0) + } + } + + // Rack ID (v11+) + if version >= 11 { + if flexible { + w.CompactString("") + } else { + w.String("") + } + } + + if flexible { + w.WriteTaggedFields(0) + } + return w.Bytes(), nil +} + // EncodeProduceRequest serializes a RequestHeader + ProduceRequest into wire-format // bytes suitable for WriteFrame. The encoding mirrors what ParseRequest expects. 
func EncodeProduceRequest(header *RequestHeader, req *ProduceRequest, version int16) ([]byte, error) { diff --git a/pkg/protocol/request_test.go b/pkg/protocol/request_test.go index 54efc23..1323741 100644 --- a/pkg/protocol/request_test.go +++ b/pkg/protocol/request_test.go @@ -849,3 +849,187 @@ func TestParseFetchRequest(t *testing.T) { t.Fatalf("unexpected fetch data: %#v", fetchReq.Topics) } } + +func TestEncodeFetchRequest_RoundTrip(t *testing.T) { + tests := []struct { + name string + version int16 + req *FetchRequest + topicID [16]byte + }{ + { + name: "v11 name-based", + version: 11, + req: &FetchRequest{ + ReplicaID: -1, + MaxWaitMs: 500, + MinBytes: 1, + MaxBytes: 1048576, + IsolationLevel: 0, + SessionID: 0, + SessionEpoch: -1, + Topics: []FetchTopicRequest{ + { + Name: "orders", + Partitions: []FetchPartitionRequest{ + {Partition: 0, FetchOffset: 10, MaxBytes: 1048576}, + {Partition: 1, FetchOffset: 20, MaxBytes: 1048576}, + }, + }, + { + Name: "events", + Partitions: []FetchPartitionRequest{ + {Partition: 0, FetchOffset: 0, MaxBytes: 524288}, + }, + }, + }, + }, + }, + { + name: "v13 topic-id-based", + version: 13, + topicID: [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + req: &FetchRequest{ + ReplicaID: -1, + MaxWaitMs: 500, + MinBytes: 1, + MaxBytes: 1048576, + IsolationLevel: 1, + SessionID: 42, + SessionEpoch: 3, + Topics: []FetchTopicRequest{ + { + TopicID: [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + Partitions: []FetchPartitionRequest{ + {Partition: 0, FetchOffset: 100, MaxBytes: 1048576}, + }, + }, + }, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + header := &RequestHeader{ + APIKey: APIKeyFetch, + APIVersion: tc.version, + CorrelationID: 42, + ClientID: strPtr("test-client"), + } + encoded, err := EncodeFetchRequest(header, tc.req, tc.version) + if err != nil { + t.Fatalf("EncodeFetchRequest: %v", err) + } + + parsedHeader, parsedReq, err := ParseRequest(encoded) 
+ if err != nil { + t.Fatalf("ParseRequest: %v", err) + } + if parsedHeader.APIKey != APIKeyFetch { + t.Fatalf("expected APIKeyFetch, got %d", parsedHeader.APIKey) + } + if parsedHeader.CorrelationID != 42 { + t.Fatalf("expected correlation 42, got %d", parsedHeader.CorrelationID) + } + + fetchReq, ok := parsedReq.(*FetchRequest) + if !ok { + t.Fatalf("expected *FetchRequest, got %T", parsedReq) + } + if fetchReq.MaxWaitMs != tc.req.MaxWaitMs { + t.Fatalf("MaxWaitMs: got %d, want %d", fetchReq.MaxWaitMs, tc.req.MaxWaitMs) + } + if fetchReq.SessionID != tc.req.SessionID { + t.Fatalf("SessionID: got %d, want %d", fetchReq.SessionID, tc.req.SessionID) + } + if len(fetchReq.Topics) != len(tc.req.Topics) { + t.Fatalf("topic count: got %d, want %d", len(fetchReq.Topics), len(tc.req.Topics)) + } + for ti, topic := range fetchReq.Topics { + wantTopic := tc.req.Topics[ti] + if tc.version >= 12 { + if topic.TopicID != wantTopic.TopicID { + t.Fatalf("topic[%d] ID mismatch", ti) + } + } else { + if topic.Name != wantTopic.Name { + t.Fatalf("topic[%d] name: got %q, want %q", ti, topic.Name, wantTopic.Name) + } + } + if len(topic.Partitions) != len(wantTopic.Partitions) { + t.Fatalf("topic[%d] partition count: got %d, want %d", ti, len(topic.Partitions), len(wantTopic.Partitions)) + } + for pi, part := range topic.Partitions { + wantPart := wantTopic.Partitions[pi] + if part.Partition != wantPart.Partition { + t.Fatalf("topic[%d] part[%d] id: got %d, want %d", ti, pi, part.Partition, wantPart.Partition) + } + if part.FetchOffset != wantPart.FetchOffset { + t.Fatalf("topic[%d] part[%d] offset: got %d, want %d", ti, pi, part.FetchOffset, wantPart.FetchOffset) + } + if part.MaxBytes != wantPart.MaxBytes { + t.Fatalf("topic[%d] part[%d] maxBytes: got %d, want %d", ti, pi, part.MaxBytes, wantPart.MaxBytes) + } + } + } + }) + } +} + +func TestEncodeFetchRequest_KmsgValidation(t *testing.T) { + // Encode a v13 request and validate it parses with franz-go's kmsg. 
+ header := &RequestHeader{ + APIKey: APIKeyFetch, + APIVersion: 13, + CorrelationID: 99, + ClientID: strPtr("kmsg-test"), + } + topicID := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + req := &FetchRequest{ + ReplicaID: -1, + MaxWaitMs: 500, + MinBytes: 1, + MaxBytes: 1048576, + IsolationLevel: 0, + SessionID: 0, + SessionEpoch: -1, + Topics: []FetchTopicRequest{ + { + TopicID: topicID, + Partitions: []FetchPartitionRequest{ + {Partition: 0, FetchOffset: 42, MaxBytes: 1048576}, + }, + }, + }, + } + encoded, err := EncodeFetchRequest(header, req, 13) + if err != nil { + t.Fatalf("EncodeFetchRequest: %v", err) + } + + // Use ParseRequestHeader to find where the body starts (same as the real code). + _, reader, err := ParseRequestHeader(encoded) + if err != nil { + t.Fatalf("ParseRequestHeader: %v", err) + } + bodyStart := len(encoded) - reader.remaining() + + kmsgReq := kmsg.NewPtrFetchRequest() + kmsgReq.Version = 13 + if err := kmsgReq.ReadFrom(encoded[bodyStart:]); err != nil { + t.Fatalf("kmsg.ReadFrom: %v", err) + } + if len(kmsgReq.Topics) != 1 { + t.Fatalf("expected 1 topic, got %d", len(kmsgReq.Topics)) + } + if kmsgReq.Topics[0].TopicID != topicID { + t.Fatalf("topic ID mismatch") + } + if len(kmsgReq.Topics[0].Partitions) != 1 { + t.Fatalf("expected 1 partition, got %d", len(kmsgReq.Topics[0].Partitions)) + } + if kmsgReq.Topics[0].Partitions[0].FetchOffset != 42 { + t.Fatalf("fetch offset: got %d, want 42", kmsgReq.Topics[0].Partitions[0].FetchOffset) + } +} diff --git a/pkg/protocol/response.go b/pkg/protocol/response.go index fccbb44..93c2201 100644 --- a/pkg/protocol/response.go +++ b/pkg/protocol/response.go @@ -825,6 +825,202 @@ func EncodeFetchResponse(resp *FetchResponse, version int16) ([]byte, error) { return w.Bytes(), nil } +// ParseFetchResponse decodes a fetch response from wire-format bytes. +// This is the inverse of EncodeFetchResponse. 
+func ParseFetchResponse(payload []byte, version int16) (*FetchResponse, error) { + if version < 1 || version > 13 { + return nil, fmt.Errorf("fetch response version %d not supported", version) + } + r := newByteReader(payload) + flexible := version >= 12 + + corrID, err := r.Int32() + if err != nil { + return nil, fmt.Errorf("read correlation id: %w", err) + } + if flexible { + if err := r.SkipTaggedFields(); err != nil { + return nil, fmt.Errorf("skip response header tags: %w", err) + } + } + + throttleMs, err := r.Int32() + if err != nil { + return nil, fmt.Errorf("read throttle ms: %w", err) + } + + var errorCode int16 + var sessionID int32 + if version >= 7 { + errorCode, err = r.Int16() + if err != nil { + return nil, fmt.Errorf("read error code: %w", err) + } + sessionID, err = r.Int32() + if err != nil { + return nil, fmt.Errorf("read session id: %w", err) + } + } + + var topicCount int32 + if flexible { + topicCount, err = compactArrayLenNonNull(r) + } else { + topicCount, err = r.Int32() + } + if err != nil { + return nil, fmt.Errorf("read topic count: %w", err) + } + + topics := make([]FetchTopicResponse, 0, topicCount) + for i := int32(0); i < topicCount; i++ { + var ( + name string + topicID [16]byte + ) + if flexible { + topicID, err = r.UUID() + if err != nil { + return nil, fmt.Errorf("read topic id: %w", err) + } + } else { + name, err = r.String() + if err != nil { + return nil, fmt.Errorf("read topic name: %w", err) + } + } + + var partCount int32 + if flexible { + partCount, err = compactArrayLenNonNull(r) + } else { + partCount, err = r.Int32() + } + if err != nil { + return nil, fmt.Errorf("read partition count: %w", err) + } + + partitions := make([]FetchPartitionResponse, 0, partCount) + for j := int32(0); j < partCount; j++ { + partIdx, err := r.Int32() + if err != nil { + return nil, fmt.Errorf("read partition index: %w", err) + } + ec, err := r.Int16() + if err != nil { + return nil, fmt.Errorf("read partition error code: %w", err) + } + 
highWatermark, err := r.Int64() + if err != nil { + return nil, fmt.Errorf("read high watermark: %w", err) + } + var lastStableOffset, logStartOffset int64 + if version >= 4 { + lastStableOffset, err = r.Int64() + if err != nil { + return nil, fmt.Errorf("read last stable offset: %w", err) + } + } + if version >= 5 { + logStartOffset, err = r.Int64() + if err != nil { + return nil, fmt.Errorf("read log start offset: %w", err) + } + } + + var abortedTransactions []FetchAbortedTransaction + if version >= 4 { + var abortedCount int32 + // Nullable: brokers may return null (no aborted transactions). + if flexible { + abortedCount, err = r.CompactArrayLen() + } else { + abortedCount, err = r.Int32() + } + if err != nil { + return nil, fmt.Errorf("read aborted count: %w", err) + } + if abortedCount > 0 { + abortedTransactions = make([]FetchAbortedTransaction, 0, abortedCount) + for k := int32(0); k < abortedCount; k++ { + producerID, err := r.Int64() + if err != nil { + return nil, fmt.Errorf("read aborted producer id: %w", err) + } + firstOffset, err := r.Int64() + if err != nil { + return nil, fmt.Errorf("read aborted first offset: %w", err) + } + abortedTransactions = append(abortedTransactions, FetchAbortedTransaction{ + ProducerID: producerID, + FirstOffset: firstOffset, + }) + } + } + } + + var preferredReadReplica int32 + if version >= 11 { + preferredReadReplica, err = r.Int32() + if err != nil { + return nil, fmt.Errorf("read preferred read replica: %w", err) + } + } + + var recordSet []byte + if flexible { + recordSet, err = r.CompactBytes() + } else { + recordSet, err = r.Bytes() + } + if err != nil { + return nil, fmt.Errorf("read record set: %w", err) + } + + if flexible { + if err := r.SkipTaggedFields(); err != nil { + return nil, fmt.Errorf("skip partition tags: %w", err) + } + } + + partitions = append(partitions, FetchPartitionResponse{ + Partition: partIdx, + ErrorCode: ec, + HighWatermark: highWatermark, + LastStableOffset: lastStableOffset, + 
LogStartOffset: logStartOffset, + PreferredReadReplica: preferredReadReplica, + RecordSet: recordSet, + AbortedTransactions: abortedTransactions, + }) + } + + if flexible { + if err := r.SkipTaggedFields(); err != nil { + return nil, fmt.Errorf("skip topic tags: %w", err) + } + } + + topics = append(topics, FetchTopicResponse{ + Name: name, + TopicID: topicID, + Partitions: partitions, + }) + } + + if flexible { + _ = r.SkipTaggedFields() + } + + return &FetchResponse{ + CorrelationID: corrID, + ThrottleMs: throttleMs, + ErrorCode: errorCode, + SessionID: sessionID, + Topics: topics, + }, nil +} + func EncodeCreateTopicsResponse(resp *CreateTopicsResponse, version int16) ([]byte, error) { if version < 0 || version > 2 { return nil, fmt.Errorf("create topics response version %d not supported", version) diff --git a/pkg/protocol/response_test.go b/pkg/protocol/response_test.go index 476d171..f0e340d 100644 --- a/pkg/protocol/response_test.go +++ b/pkg/protocol/response_test.go @@ -1843,6 +1843,160 @@ func TestGroupResponseErrorCode_RoundTrip(t *testing.T) { } } +func TestParseFetchResponse_RoundTrip(t *testing.T) { + tests := []struct { + name string + version int16 + resp *FetchResponse + }{ + { + name: "v11 name-based", + version: 11, + resp: &FetchResponse{ + CorrelationID: 7, + ThrottleMs: 0, + ErrorCode: NONE, + SessionID: 0, + Topics: []FetchTopicResponse{ + { + Name: "orders", + Partitions: []FetchPartitionResponse{ + { + Partition: 0, + ErrorCode: NONE, + HighWatermark: 100, + LastStableOffset: 100, + LogStartOffset: 0, + PreferredReadReplica: -1, + RecordSet: []byte("test-records"), + }, + { + Partition: 1, + ErrorCode: NOT_LEADER_OR_FOLLOWER, + HighWatermark: 0, + LastStableOffset: 0, + LogStartOffset: 0, + PreferredReadReplica: -1, + RecordSet: []byte{}, + }, + }, + }, + }, + }, + }, + { + name: "v13 topic-id-based", + version: 13, + resp: &FetchResponse{ + CorrelationID: 11, + ThrottleMs: 5, + ErrorCode: NONE, + SessionID: 42, + Topics: 
[]FetchTopicResponse{ + { + TopicID: [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + Partitions: []FetchPartitionResponse{ + { + Partition: 0, + ErrorCode: NONE, + HighWatermark: 50, + LastStableOffset: 50, + LogStartOffset: 0, + PreferredReadReplica: -1, + RecordSet: []byte("hello"), + }, + }, + }, + }, + }, + }, + { + name: "v11 multiple topics", + version: 11, + resp: &FetchResponse{ + CorrelationID: 99, + ThrottleMs: 0, + ErrorCode: NONE, + SessionID: 0, + Topics: []FetchTopicResponse{ + { + Name: "orders", + Partitions: []FetchPartitionResponse{ + {Partition: 0, ErrorCode: NONE, HighWatermark: 10, LastStableOffset: 10, PreferredReadReplica: -1, RecordSet: []byte("a")}, + }, + }, + { + Name: "events", + Partitions: []FetchPartitionResponse{ + {Partition: 0, ErrorCode: NONE, HighWatermark: 20, LastStableOffset: 20, PreferredReadReplica: -1, RecordSet: []byte("b")}, + }, + }, + }, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + encoded, err := EncodeFetchResponse(tc.resp, tc.version) + if err != nil { + t.Fatalf("EncodeFetchResponse: %v", err) + } + + parsed, err := ParseFetchResponse(encoded, tc.version) + if err != nil { + t.Fatalf("ParseFetchResponse: %v", err) + } + + if parsed.CorrelationID != tc.resp.CorrelationID { + t.Fatalf("CorrelationID: got %d, want %d", parsed.CorrelationID, tc.resp.CorrelationID) + } + if parsed.ThrottleMs != tc.resp.ThrottleMs { + t.Fatalf("ThrottleMs: got %d, want %d", parsed.ThrottleMs, tc.resp.ThrottleMs) + } + if parsed.ErrorCode != tc.resp.ErrorCode { + t.Fatalf("ErrorCode: got %d, want %d", parsed.ErrorCode, tc.resp.ErrorCode) + } + if parsed.SessionID != tc.resp.SessionID { + t.Fatalf("SessionID: got %d, want %d", parsed.SessionID, tc.resp.SessionID) + } + if len(parsed.Topics) != len(tc.resp.Topics) { + t.Fatalf("topic count: got %d, want %d", len(parsed.Topics), len(tc.resp.Topics)) + } + for ti, topic := range parsed.Topics { + wantTopic := tc.resp.Topics[ti] + if 
tc.version >= 12 { + if topic.TopicID != wantTopic.TopicID { + t.Fatalf("topic[%d] ID mismatch", ti) + } + } else { + if topic.Name != wantTopic.Name { + t.Fatalf("topic[%d] name: got %q, want %q", ti, topic.Name, wantTopic.Name) + } + } + if len(topic.Partitions) != len(wantTopic.Partitions) { + t.Fatalf("topic[%d] partition count: got %d, want %d", ti, len(topic.Partitions), len(wantTopic.Partitions)) + } + for pi, part := range topic.Partitions { + wantPart := wantTopic.Partitions[pi] + if part.Partition != wantPart.Partition { + t.Fatalf("topic[%d] part[%d]: got %d, want %d", ti, pi, part.Partition, wantPart.Partition) + } + if part.ErrorCode != wantPart.ErrorCode { + t.Fatalf("topic[%d] part[%d] error: got %d, want %d", ti, pi, part.ErrorCode, wantPart.ErrorCode) + } + if part.HighWatermark != wantPart.HighWatermark { + t.Fatalf("topic[%d] part[%d] HW: got %d, want %d", ti, pi, part.HighWatermark, wantPart.HighWatermark) + } + if string(part.RecordSet) != string(wantPart.RecordSet) { + t.Fatalf("topic[%d] part[%d] records: got %q, want %q", ti, pi, part.RecordSet, wantPart.RecordSet) + } + } + } + }) + } +} + func TestGroupResponseErrorCode_Truncated(t *testing.T) { // A truncated response should return ok=false. _, ok := GroupResponseErrorCode(APIKeyJoinGroup, 2, []byte{0, 0, 0, 1}) From fa6938f2ac9d8b1cc088abe650f7a78261fe5d45 Mon Sep 17 00:00:00 2001 From: Nico Duldhardt Date: Sun, 1 Mar 2026 13:55:52 +0100 Subject: [PATCH 2/2] fix: replace periodic metadata polling with on-demand cache refresh Remove the 3-second polling ticker and refresh broker/topic caches on demand when a lookup misses. Concurrent misses are coalesced via singleflight to avoid thundering herd metadata fetches. Readiness probe now checks cached state first (fast path), falling back to a live metadata fetch only when the cache TTL expires. Static backends are always ready. Clean up comments per AGENTS.md: remove low-value comments, condense verbose doc comments. 
--- cmd/proxy/main.go | 191 +++++++++++++++++---------------------- cmd/proxy/main_test.go | 32 ++++--- pkg/protocol/request.go | 3 +- pkg/protocol/response.go | 3 +- 4 files changed, 104 insertions(+), 125 deletions(-) diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index 34b88c7..a8decb2 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -33,6 +33,7 @@ import ( "github.com/KafScale/platform/pkg/metadata" "github.com/KafScale/platform/pkg/protocol" + "golang.org/x/sync/singleflight" ) const ( @@ -60,6 +61,7 @@ type proxy struct { brokerAddrs map[string]string // brokerID -> "host:port" topicNamesMu sync.RWMutex topicNames map[[16]byte]string // topicID -> topic name + metaFlight singleflight.Group backendRetries int backendBackoff time.Duration } @@ -140,7 +142,7 @@ func main() { p.touchHealthy() p.setReady(true) } - p.startBackendRefresh(ctx, backendBackoff) + p.initMetadataCache(ctx) if healthAddr != "" { p.startHealthServer(ctx, healthAddr) } @@ -273,7 +275,7 @@ func (p *proxy) setReady(ready bool) { } func (p *proxy) isReady() bool { - return atomic.LoadUint32(&p.ready) == 1 && p.cacheFresh() + return atomic.LoadUint32(&p.ready) == 1 } func (p *proxy) setCachedBackends(backends []string) { @@ -311,56 +313,33 @@ func (p *proxy) cacheFresh() bool { return time.Since(time.Unix(0, last)) <= p.cacheTTL } -func (p *proxy) startBackendRefresh(ctx context.Context, backoff time.Duration) { - if p.store == nil { - return +// checkReady uses cached state when fresh, falling back to a live metadata +// fetch only when the cache TTL has expired (e.g. no traffic for >60s). +func (p *proxy) checkReady(ctx context.Context) bool { + if len(p.backends) > 0 { + return true } - if backoff <= 0 { - backoff = 500 * time.Millisecond + if p.cacheFresh() { + return true } - // Eagerly populate broker ID -> address mapping before accepting connections. 
- p.refreshBrokerAddrs(ctx) - ticker := time.NewTicker(3 * time.Second) - go func() { - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - if len(p.backends) > 0 { - // Static backends: only refresh broker ID mapping. - p.refreshBrokerAddrs(ctx) - continue - } - _, err := p.refreshBackends(ctx) - if err != nil { - if !p.cacheFresh() { - p.setReady(false) - } - time.Sleep(backoff) - } - } - } - }() + if p.store == nil { + return false + } + backends, err := p.currentBackends(ctx) + return err == nil && len(backends) > 0 } -func (p *proxy) refreshBackends(ctx context.Context) ([]string, error) { - backends, err := p.currentBackends(ctx) - if err != nil { - return nil, err - } - if len(backends) > 0 { - p.touchHealthy() - p.setReady(true) +func (p *proxy) initMetadataCache(ctx context.Context) { + if p.store == nil { + return } - return backends, nil + p.refreshMetadataCache(ctx) } func (p *proxy) startHealthServer(ctx context.Context, addr string) { mux := http.NewServeMux() - mux.HandleFunc("/readyz", func(w http.ResponseWriter, _ *http.Request) { - if p.isReady() || (len(p.cachedBackendsSnapshot()) > 0 && p.cacheFresh()) { + mux.HandleFunc("/readyz", func(w http.ResponseWriter, r *http.Request) { + if p.checkReady(r.Context()) { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("ready\n")) return @@ -602,7 +581,7 @@ func (p *proxy) handleProduceRouting(ctx context.Context, header *protocol.Reque return nil, nil } - groups := p.groupPartitionsByBroker(produceReq, nil) + groups := p.groupPartitionsByBroker(ctx, produceReq, nil) return p.forwardProduce(ctx, header, produceReq, payload, groups, pool) } @@ -626,7 +605,7 @@ func (p *proxy) forwardProduceRaw(ctx context.Context, payload []byte, pool *con // response. Used for acks=0 produces where the Kafka protocol specifies no // server response. 
func (p *proxy) fireAndForgetProduce(ctx context.Context, header *protocol.RequestHeader, req *protocol.ProduceRequest, originalPayload []byte, pool *connPool) { - groups := p.groupPartitionsByBroker(req, nil) + groups := p.groupPartitionsByBroker(ctx, req, nil) for addr, subReq := range groups { var payload []byte @@ -658,7 +637,7 @@ func (p *proxy) fireAndForgetProduce(ctx context.Context, header *protocol.Reque // groupPartitionsByBroker groups topic-partitions by the owning broker's address. // If include is non-nil, only partitions present in the include map are grouped. // Partitions with no known owner are grouped under "" for round-robin fallback. -func (p *proxy) groupPartitionsByBroker(req *protocol.ProduceRequest, include map[string]map[int32]bool) map[string]*protocol.ProduceRequest { +func (p *proxy) groupPartitionsByBroker(ctx context.Context, req *protocol.ProduceRequest, include map[string]map[int32]bool) map[string]*protocol.ProduceRequest { groups := make(map[string]*protocol.ProduceRequest) topicIndices := make(map[string]map[string]int) // addr -> topic name -> index in subReq.Topics @@ -677,7 +656,7 @@ func (p *proxy) groupPartitionsByBroker(req *protocol.ProduceRequest, include ma addr := "" if p.router != nil { if ownerID := p.router.LookupOwner(topic.Name, part.Partition); ownerID != "" { - addr = p.brokerIDToAddr(ownerID) + addr = p.brokerIDToAddr(ctx, ownerID) } } subReq, ok := groups[addr] @@ -759,7 +738,7 @@ func (p *proxy) forwardProduce(ctx context.Context, header *protocol.RequestHead return protocol.EncodeProduceResponse(merged, header.APIVersion) } - groups = p.groupPartitionsByBroker(fullReq, failedPartitions) + groups = p.groupPartitionsByBroker(ctx, fullReq, failedPartitions) originalPayload = nil // force re-encoding on retry if len(groups) == 0 { break @@ -896,10 +875,19 @@ func addErrorForAllPartitions(resp *protocol.ProduceResponse, req *protocol.Prod } } -func (p *proxy) brokerIDToAddr(brokerID string) string { +// 
brokerIDToAddr resolves broker ID to address. Triggers a metadata fetch on +// cache miss. +func (p *proxy) brokerIDToAddr(ctx context.Context, brokerID string) string { p.brokerAddrMu.RLock() addr := p.brokerAddrs[brokerID] p.brokerAddrMu.RUnlock() + if addr != "" { + return addr + } + p.refreshMetadataCache(ctx) + p.brokerAddrMu.RLock() + addr = p.brokerAddrs[brokerID] + p.brokerAddrMu.RUnlock() return addr } @@ -1502,6 +1490,7 @@ func (p *proxy) currentBackends(ctx context.Context) ([]string, error) { p.setReady(true) } p.updateBrokerAddrs(meta.Brokers) + p.updateTopicNames(meta.Topics) return addrs, nil } @@ -1519,23 +1508,24 @@ func (p *proxy) updateBrokerAddrs(brokers []protocol.MetadataBroker) { p.brokerAddrMu.Unlock() } -// refreshBrokerAddrs queries metadata solely to update the broker ID -> address -// mapping and the topic ID -> name mapping. Used when static backends are -// configured (so currentBackends returns early) but partition-aware routing -// still needs broker ID resolution and topic ID resolution. -func (p *proxy) refreshBrokerAddrs(ctx context.Context) { +// refreshMetadataCache updates broker address and topic name caches from +// metadata. Concurrent calls are coalesced via singleflight. +func (p *proxy) refreshMetadataCache(ctx context.Context) { if p.store == nil { return } - meta, err := p.store.Metadata(ctx, nil) - if err != nil { - return - } - p.updateBrokerAddrs(meta.Brokers) - p.updateTopicNames(meta.Topics) + p.metaFlight.Do("refresh", func() (interface{}, error) { + meta, err := p.store.Metadata(ctx, nil) + if err != nil { + return nil, err + } + p.updateBrokerAddrs(meta.Brokers) + p.updateTopicNames(meta.Topics) + p.touchHealthy() + return nil, nil + }) } -// updateTopicNames rebuilds the topic ID -> name mapping from metadata. 
func (p *proxy) updateTopicNames(topics []protocol.MetadataTopic) { names := make(map[[16]byte]string, len(topics)) var zeroID [16]byte @@ -1549,11 +1539,18 @@ func (p *proxy) updateTopicNames(topics []protocol.MetadataTopic) { p.topicNamesMu.Unlock() } -// resolveTopicID returns the topic name for a given topic ID, or "" if unknown. -func (p *proxy) resolveTopicID(id [16]byte) string { +// resolveTopicID maps topic UUID to name. Triggers a metadata fetch on cache miss. +func (p *proxy) resolveTopicID(ctx context.Context, id [16]byte) string { p.topicNamesMu.RLock() name := p.topicNames[id] p.topicNamesMu.RUnlock() + if name != "" { + return name + } + p.refreshMetadataCache(ctx) + p.topicNamesMu.RLock() + name = p.topicNames[id] + p.topicNamesMu.RUnlock() return name } @@ -1568,18 +1565,12 @@ func (p *proxy) forwardToBackend(ctx context.Context, conn net.Conn, backendAddr return frame.Payload, nil } -// handleGroupRouting routes group-related requests to the broker that owns the -// group coordination lease. If no owner is cached, or the owner returns -// NOT_COORDINATOR, the request is retried on a different broker. -// -// DescribeGroups requests are forwarded once without retry since different -// groups may live on different brokers. The broker returns per-group -// NOT_COORDINATOR errors that the Kafka client handles natively. +// handleGroupRouting forwards group requests to the coordination lease owner, +// retrying on NOT_COORDINATOR. DescribeGroups is forwarded once without retry +// since it may span multiple groups on different brokers. func (p *proxy) handleGroupRouting(ctx context.Context, header *protocol.RequestHeader, payload []byte, pool *connPool) ([]byte, error) { groupID := p.extractGroupID(header.APIKey, payload) - // DescribeGroups with multiple groups cannot be reliably split/retried at - // the proxy level. Forward once and let the client handle per-group errors. 
maxAttempts := 3 if header.APIKey == protocol.APIKeyDescribeGroups { maxAttempts = 1 @@ -1591,7 +1582,7 @@ func (p *proxy) handleGroupRouting(ctx context.Context, header *protocol.Request targetAddr := "" if p.groupRouter != nil && groupID != "" { if ownerID := p.groupRouter.LookupOwner(groupID); ownerID != "" { - targetAddr = p.brokerIDToAddr(ownerID) + targetAddr = p.brokerIDToAddr(ctx, ownerID) } } @@ -1624,8 +1615,6 @@ func (p *proxy) handleGroupRouting(ctx context.Context, header *protocol.Request return nil, fmt.Errorf("group request for %q failed after %d attempts", groupID, maxAttempts) } -// extractGroupID parses the request payload to extract the group ID for routing. -// Returns "" if the group ID cannot be determined. func (p *proxy) extractGroupID(apiKey int16, payload []byte) string { _, req, err := protocol.ParseRequest(payload) if err != nil { @@ -1654,10 +1643,8 @@ func (p *proxy) extractGroupID(apiKey int16, payload []byte) string { } } -// handleFetchRouting routes fetch requests to the broker(s) that own the -// requested partitions. Like produce routing, the request is split by owning -// broker, forwarded concurrently, and responses are merged. On -// NOT_LEADER_OR_FOLLOWER, failed partitions are retried on a different broker. +// handleFetchRouting splits fetch requests by partition owner, fans out, and +// merges responses. Retries NOT_LEADER_OR_FOLLOWER on a different broker. func (p *proxy) handleFetchRouting(ctx context.Context, header *protocol.RequestHeader, payload []byte, pool *connPool) ([]byte, error) { _, req, err := protocol.ParseRequest(payload) if err != nil { @@ -1668,10 +1655,9 @@ func (p *proxy) handleFetchRouting(ctx context.Context, header *protocol.Request return p.forwardFetchRaw(ctx, payload, pool) } - // Resolve topic names for v12+ requests that use topic IDs. 
// fetchTopicKey returns a collision-safe map key for a fetch topic: the topic
// name when known, otherwise "id:" plus the hex-encoded topic UUID. This
// keeps multiple unresolved v12+ topics (all with name "") distinct.
func fetchTopicKey(name string, id [16]byte) string {
	if name == "" {
		return fmt.Sprintf("id:%x", id)
	}
	return name
}
-func (p *proxy) groupFetchPartitionsByBroker(req *protocol.FetchRequest, include map[string]map[int32]bool) map[string]*protocol.FetchRequest { +// groupFetchPartitionsByBroker groups partitions by owning broker. If include +// is non-nil, only listed partitions are grouped. Unknown owners go under "" +// for round-robin. +func (p *proxy) groupFetchPartitionsByBroker(ctx context.Context, req *protocol.FetchRequest, include map[string]map[int32]bool) map[string]*protocol.FetchRequest { groups := make(map[string]*protocol.FetchRequest) topicIndices := make(map[string]map[string]int) // addr -> topicKey -> index in subReq.Topics @@ -1736,7 +1720,7 @@ func (p *proxy) groupFetchPartitionsByBroker(req *protocol.FetchRequest, include addr := "" if p.router != nil && topicName != "" { if ownerID := p.router.LookupOwner(topicName, part.Partition); ownerID != "" { - addr = p.brokerIDToAddr(ownerID) + addr = p.brokerIDToAddr(ctx, ownerID) } } subReq, ok := groups[addr] @@ -1776,9 +1760,8 @@ type fetchFanOutResult struct { err error } -// forwardFetch splits a fetch request by broker, forwards each sub-request -// concurrently, and merges the responses. If any partitions are rejected with -// NOT_LEADER_OR_FOLLOWER, those partitions are retried on a different broker. +// forwardFetch fans out sub-requests, merges responses, and retries +// NOT_LEADER_OR_FOLLOWER partitions on a different broker. func (p *proxy) forwardFetch(ctx context.Context, header *protocol.RequestHeader, fullReq *protocol.FetchRequest, originalPayload []byte, groups map[string]*protocol.FetchRequest, pool *connPool) ([]byte, error) { const maxRetries = 3 @@ -1787,14 +1770,11 @@ func (p *proxy) forwardFetch(ctx context.Context, header *protocol.RequestHeader SessionID: fullReq.SessionID, } - // failedPartitions is keyed by fetchTopicKey (topic name or hex topic ID) - // to avoid collisions when multiple v12+ topics have unresolved names. 
+ // Keyed by fetchTopicKey to avoid collisions among unresolved v12+ topics. var failedPartitions map[string]map[int32]bool for attempt := 0; attempt < maxRetries; attempt++ { failedPartitions = nil - // Scope triedBackends per attempt so that retries can revisit brokers - // from earlier attempts. Without this, with N brokers all N get excluded - // after the first attempt and subsequent retries always fail to connect. + // Reset per attempt so retries can revisit brokers from earlier attempts. triedBackends := make(map[string]bool) subResults := p.fanOutFetch(ctx, header, groups, originalPayload, triedBackends, pool) @@ -1815,7 +1795,7 @@ func (p *proxy) forwardFetch(ctx context.Context, header *protocol.RequestHeader if part.ErrorCode == protocol.NOT_LEADER_OR_FOLLOWER { topicName := topic.Name if topicName == "" { - topicName = p.resolveTopicID(topic.TopicID) + topicName = p.resolveTopicID(ctx, topic.TopicID) } key := fetchTopicKey(topicName, topic.TopicID) if failedPartitions == nil { @@ -1843,7 +1823,7 @@ func (p *proxy) forwardFetch(ctx context.Context, header *protocol.RequestHeader return protocol.EncodeFetchResponse(merged, header.APIVersion) } - groups = p.groupFetchPartitionsByBroker(fullReq, failedPartitions) + groups = p.groupFetchPartitionsByBroker(ctx, fullReq, failedPartitions) originalPayload = nil if len(groups) == 0 { break @@ -1851,7 +1831,6 @@ func (p *proxy) forwardFetch(ctx context.Context, header *protocol.RequestHeader p.logger.Debug("retrying NOT_LEADER fetch partitions", "attempt", attempt+1, "partitions", len(failedPartitions)) } - // Fill remaining failed partitions with errors. 
for _, topic := range fullReq.Topics { key := fetchTopicKey(topic.Name, topic.TopicID) failedParts, ok := failedPartitions[key] diff --git a/cmd/proxy/main_test.go b/cmd/proxy/main_test.go index 97a1cd2..5039f1f 100644 --- a/cmd/proxy/main_test.go +++ b/cmd/proxy/main_test.go @@ -391,7 +391,7 @@ func TestGroupPartitionsByBrokerNoRouter(t *testing.T) { "orders": {0, 1, 2}, "events": {0}, }) - groups := p.groupPartitionsByBroker(req, nil) + groups := p.groupPartitionsByBroker(context.Background(), req, nil) if len(groups) != 1 { t.Fatalf("expected 1 group (all round-robin), got %d", len(groups)) } @@ -417,7 +417,7 @@ func TestGroupPartitionsByBrokerNoRouterMultipleTopics(t *testing.T) { "orders": {0, 1}, "events": {0, 1, 2}, }) - groups := p.groupPartitionsByBroker(req, nil) + groups := p.groupPartitionsByBroker(context.Background(), req, nil) if len(groups) != 1 { t.Fatalf("expected 1 group, got %d", len(groups)) } @@ -447,7 +447,7 @@ func TestGroupPartitionsByBrokerFiltersCorrectly(t *testing.T) { "orders": {1: true}, "events": {0: true}, } - groups := p.groupPartitionsByBroker(req, include) + groups := p.groupPartitionsByBroker(context.Background(), req, include) if len(groups) != 1 { t.Fatalf("expected 1 group (no router), got %d", len(groups)) } @@ -519,13 +519,14 @@ func TestUpdateBrokerAddrs(t *testing.T) { } p.updateBrokerAddrs(brokers) - if got := p.brokerIDToAddr("1"); got != "broker1:9092" { + ctx := context.Background() + if got := p.brokerIDToAddr(ctx, "1"); got != "broker1:9092" { t.Fatalf("broker 1: got %q, want %q", got, "broker1:9092") } - if got := p.brokerIDToAddr("2"); got != "broker2:9093" { + if got := p.brokerIDToAddr(ctx, "2"); got != "broker2:9093" { t.Fatalf("broker 2: got %q, want %q", got, "broker2:9093") } - if got := p.brokerIDToAddr("3"); got != "" { + if got := p.brokerIDToAddr(ctx, "3"); got != "" { t.Fatalf("broker 3 (empty host): got %q, want %q", got, "") } } @@ -797,7 +798,7 @@ func TestGroupFetchPartitionsByBrokerNoRouter(t 
*testing.T) { "orders": {0, 1, 2}, "events": {0}, }) - groups := p.groupFetchPartitionsByBroker(req, nil) + groups := p.groupFetchPartitionsByBroker(context.Background(), req, nil) if len(groups) != 1 { t.Fatalf("expected 1 group (all round-robin), got %d", len(groups)) } @@ -819,7 +820,7 @@ func TestGroupFetchPartitionsByBrokerNoRouterMultipleTopics(t *testing.T) { "orders": {0, 1}, "events": {0, 1, 2}, }) - groups := p.groupFetchPartitionsByBroker(req, nil) + groups := p.groupFetchPartitionsByBroker(context.Background(), req, nil) if len(groups) != 1 { t.Fatalf("expected 1 group, got %d", len(groups)) } @@ -849,7 +850,7 @@ func TestGroupFetchPartitionsByBrokerFiltersCorrectly(t *testing.T) { "orders": {1: true}, "events": {0: true}, } - groups := p.groupFetchPartitionsByBroker(req, include) + groups := p.groupFetchPartitionsByBroker(context.Background(), req, include) if len(groups) != 1 { t.Fatalf("expected 1 group (no router), got %d", len(groups)) } @@ -944,13 +945,14 @@ func TestUpdateTopicNames(t *testing.T) { } p.updateTopicNames(topics) - if got := p.resolveTopicID(topicID1); got != "orders" { + ctx := context.Background() + if got := p.resolveTopicID(ctx, topicID1); got != "orders" { t.Fatalf("resolveTopicID(1): got %q, want %q", got, "orders") } - if got := p.resolveTopicID(topicID2); got != "events" { + if got := p.resolveTopicID(ctx, topicID2); got != "events" { t.Fatalf("resolveTopicID(2): got %q, want %q", got, "events") } - if got := p.resolveTopicID([16]byte{9, 9, 9}); got != "" { + if got := p.resolveTopicID(ctx, [16]byte{9, 9, 9}); got != "" { t.Fatalf("resolveTopicID(unknown): got %q, want %q", got, "") } } @@ -972,7 +974,7 @@ func TestGroupFetchPartitionsByBrokerUnresolvedTopicIDs(t *testing.T) { {TopicID: idB, Partitions: []protocol.FetchPartitionRequest{{Partition: 0, MaxBytes: 1048576}}}, }, } - groups := p.groupFetchPartitionsByBroker(req, nil) + groups := p.groupFetchPartitionsByBroker(context.Background(), req, nil) rr := groups[""] if rr 
== nil { t.Fatal("expected round-robin group") @@ -1011,7 +1013,7 @@ func TestGroupFetchPartitionsByBrokerUnresolvedFilter(t *testing.T) { include := map[string]map[int32]bool{ fetchTopicKey("", idA): {1: true}, } - groups := p.groupFetchPartitionsByBroker(req, include) + groups := p.groupFetchPartitionsByBroker(context.Background(), req, include) rr := groups[""] if rr == nil { t.Fatal("expected round-robin group") @@ -1038,7 +1040,7 @@ func TestResolveFetchTopicNames(t *testing.T) { {Name: "events"}, // already has name, should be left alone }, } - p.resolveFetchTopicNames(req) + p.resolveFetchTopicNames(context.Background(), req) if req.Topics[0].Name != "orders" { t.Fatalf("topic[0] name: got %q, want %q", req.Topics[0].Name, "orders") diff --git a/pkg/protocol/request.go b/pkg/protocol/request.go index 8e51d4a..6bc4d66 100644 --- a/pkg/protocol/request.go +++ b/pkg/protocol/request.go @@ -1790,8 +1790,7 @@ func ParseRequest(b []byte) (*RequestHeader, Request, error) { return header, req, nil } -// EncodeFetchRequest serializes a RequestHeader + FetchRequest into wire-format -// bytes suitable for WriteFrame. The encoding mirrors what ParseRequest expects. +// EncodeFetchRequest encodes a fetch request. Mirrors ParseRequest's fetch case. func EncodeFetchRequest(header *RequestHeader, req *FetchRequest, version int16) ([]byte, error) { w := newByteWriter(256) flexible := isFlexibleRequest(APIKeyFetch, version) diff --git a/pkg/protocol/response.go b/pkg/protocol/response.go index 93c2201..8f3c27e 100644 --- a/pkg/protocol/response.go +++ b/pkg/protocol/response.go @@ -825,8 +825,7 @@ func EncodeFetchResponse(resp *FetchResponse, version int16) ([]byte, error) { return w.Bytes(), nil } -// ParseFetchResponse decodes a fetch response from wire-format bytes. -// This is the inverse of EncodeFetchResponse. +// ParseFetchResponse decodes a fetch response. Inverse of EncodeFetchResponse. 
func ParseFetchResponse(payload []byte, version int16) (*FetchResponse, error) { if version < 1 || version > 13 { return nil, fmt.Errorf("fetch response version %d not supported", version)