diff --git a/pkg/protocol/request.go b/pkg/protocol/request.go index 6bc4d66..445744a 100644 --- a/pkg/protocol/request.go +++ b/pkg/protocol/request.go @@ -505,16 +505,16 @@ func ParseRequest(b []byte) (*RequestHeader, Request, error) { if err != nil { return nil, nil, fmt.Errorf("read produce records: %w", err) } + if flexible { + if err := reader.SkipTaggedFields(); err != nil { + return nil, nil, fmt.Errorf("skip partition tags: %w", err) + } + } partitions = append(partitions, ProducePartition{ Partition: index, Records: records, }) } - if flexible { - if err := reader.SkipTaggedFields(); err != nil { - return nil, nil, fmt.Errorf("skip partition tags: %w", err) - } - } if flexible { if err := reader.SkipTaggedFields(); err != nil { return nil, nil, fmt.Errorf("skip topic tags: %w", err) @@ -1931,14 +1931,12 @@ func EncodeProduceRequest(header *RequestHeader, req *ProduceRequest, version in } else { w.BytesWithLength(part.Records) } + if flexible { + w.WriteTaggedFields(0) + } } - // Match the parser: two tagged-field blocks after the partition array. - // The Kafka protocol spec places per-partition tags inside the partition - // loop, but our parser (ParseRequest) reads them outside. Since the - // broker is always KafScale (not vanilla Kafka), this is intentional. if flexible { w.WriteTaggedFields(0) - w.WriteTaggedFields(0) } } if flexible { diff --git a/pkg/protocol/request_test.go b/pkg/protocol/request_test.go index 1323741..d18afab 100644 --- a/pkg/protocol/request_test.go +++ b/pkg/protocol/request_test.go @@ -1033,3 +1033,113 @@ func TestEncodeFetchRequest_KmsgValidation(t *testing.T) { t.Fatalf("fetch offset: got %d, want 42", kmsgReq.Topics[0].Partitions[0].FetchOffset) } } + +// TestProduceMultiPartitionFranzCompat tests byte-level compatibility with +// franz-go for multi-partition produce requests in both directions: +// - franz-go encodes → KafScale parses +// - KafScale encodes → franz-go decodes +func TestProduceMultiPartitionFranzCompat(t *testing.T) { + t.Run("franz-encode-kafscale-parse", func(t *testing.T) { + req := kmsg.NewPtrProduceRequest() + req.Version = 9 + req.Acks = -1 + req.TimeoutMillis = 3000 + topic := kmsg.NewProduceRequestTopic() + topic.Topic = "orders" + for _, pi := range []int32{0, 1, 2} { + part := kmsg.NewProduceRequestTopicPartition() + part.Partition = pi + part.Records = []byte{byte(pi + 1), byte(pi + 2)} + topic.Partitions = append(topic.Partitions, part) + } + req.Topics = append(req.Topics, topic) + body := req.AppendTo(nil) + + w := newByteWriter(len(body) + 32) + w.Int16(APIKeyProduce) + w.Int16(9) + w.Int32(55) + clientID := "kgo" + w.NullableString(&clientID) + w.WriteTaggedFields(0) + w.write(body) + + _, parsed, err := ParseRequest(w.Bytes()) + if err != nil { + t.Fatalf("ParseRequest: %v", err) + } + got, ok := parsed.(*ProduceRequest) + if !ok { + t.Fatalf("expected *ProduceRequest, got %T", parsed) + } + if len(got.Topics) != 1 { + t.Fatalf("topic count: got %d want 1", len(got.Topics)) + } + if len(got.Topics[0].Partitions) != 3 { + t.Fatalf("partition count: got %d want 3", len(got.Topics[0].Partitions)) + } + for pi, part := range got.Topics[0].Partitions { + if part.Partition != int32(pi) { + t.Fatalf("part[%d] index: got %d want %d", pi, part.Partition, pi) + } + want := []byte{byte(pi + 1), byte(pi + 2)} + if string(part.Records) != string(want) { + t.Fatalf("part[%d] records: got %x want %x", pi, part.Records, want) + } + } + }) + + t.Run("kafscale-encode-franz-parse", func(t *testing.T) { + header := &RequestHeader{ + APIKey: APIKeyProduce, + APIVersion: 9, + CorrelationID: 66, + ClientID: strPtr("test"), + } + req := &ProduceRequest{ + Acks: -1, + TimeoutMs: 3000, + Topics: []ProduceTopic{ + { + Name: "orders", + Partitions: []ProducePartition{ + {Partition: 0, Records: []byte{1, 2}}, + {Partition: 1, Records: []byte{3, 4}}, + {Partition: 2, Records: []byte{5, 6}}, + }, + }, + }, + } + encoded, err := EncodeProduceRequest(header, req, 9) + if err != nil { + t.Fatalf("encode: %v", err) + } + + _, reader, err := ParseRequestHeader(encoded) + if err != nil { + t.Fatalf("ParseRequestHeader: %v", err) + } + bodyStart := len(encoded) - reader.remaining() + + kmsgReq := kmsg.NewPtrProduceRequest() + kmsgReq.Version = 9 + if err := kmsgReq.ReadFrom(encoded[bodyStart:]); err != nil { + t.Fatalf("kmsg.ReadFrom: %v", err) + } + if len(kmsgReq.Topics) != 1 { + t.Fatalf("topic count: got %d want 1", len(kmsgReq.Topics)) + } + if len(kmsgReq.Topics[0].Partitions) != 3 { + t.Fatalf("partition count: got %d want 3", len(kmsgReq.Topics[0].Partitions)) + } + for pi, part := range kmsgReq.Topics[0].Partitions { + if part.Partition != int32(pi) { + t.Fatalf("part[%d] index: got %d want %d", pi, part.Partition, pi) + } + want := []byte{byte(pi*2 + 1), byte(pi*2 + 2)} + if string(part.Records) != string(want) { + t.Fatalf("part[%d] records: got %x want %x", pi, part.Records, want) + } + } + }) +} diff --git a/pkg/protocol/response_test.go b/pkg/protocol/response_test.go index f0e340d..89a63ff 100644 --- a/pkg/protocol/response_test.go +++ b/pkg/protocol/response_test.go @@ -1518,12 +1518,6 @@ func TestParseProduceResponseRoundTrip(t *testing.T) { } func TestEncodeProduceRequestRoundTrip(t *testing.T) { - header := &RequestHeader{ - APIKey: APIKeyProduce, - APIVersion: 9, - CorrelationID: 77, - ClientID: strPtr("test-client"), - } req := &ProduceRequest{ Acks: -1, TimeoutMs: 5000, @@ -1531,64 +1525,58 @@ func TestEncodeProduceRequestRoundTrip(t *testing.T) { { Name: "orders", Partitions: []ProducePartition{ - {Partition: 0, Records: []byte{1, 2, 3, 4}}, - {Partition: 1, Records: []byte{5, 6}}, + {Partition: 0, Records: []byte{1, 2, 3}}, + {Partition: 1, Records: []byte{4, 5}}, + {Partition: 2, Records: []byte{6, 7, 8, 9}}, }, }, { Name: "events", Partitions: []ProducePartition{ - {Partition: 0, Records: []byte{7, 8, 9}}, + {Partition: 0, Records: []byte{10}}, + {Partition: 3, Records: []byte{11, 12}}, }, }, }, } for _, version := range []int16{3, 5, 7, 8, 9, 10} { t.Run(fmt.Sprintf("v%d", version), func(t *testing.T) { - h := &RequestHeader{ + header := &RequestHeader{ APIKey: APIKeyProduce, APIVersion: version, - CorrelationID: header.CorrelationID, - ClientID: header.ClientID, + CorrelationID: 77, + ClientID: strPtr("test-client"), } - encoded, err := EncodeProduceRequest(h, req, version) + encoded, err := EncodeProduceRequest(header, req, version) if err != nil { t.Fatalf("encode: %v", err) } - parsedHeader, parsedReq, err := ParseRequest(encoded) + _, parsedReq, err := ParseRequest(encoded) if err != nil { t.Fatalf("parse: %v", err) } - if parsedHeader.CorrelationID != h.CorrelationID { - t.Fatalf("correlation id: got %d want %d", parsedHeader.CorrelationID, h.CorrelationID) - } - produceReq, ok := parsedReq.(*ProduceRequest) + got, ok := parsedReq.(*ProduceRequest) if !ok { t.Fatalf("expected *ProduceRequest, got %T", parsedReq) } - if produceReq.Acks != req.Acks { - t.Fatalf("acks: got %d want %d", produceReq.Acks, req.Acks) + if len(got.Topics) != len(req.Topics) { + t.Fatalf("topic count: got %d want %d", len(got.Topics), len(req.Topics)) } - if produceReq.TimeoutMs != req.TimeoutMs { - t.Fatalf("timeout: got %d want %d", produceReq.TimeoutMs, req.TimeoutMs) - } - if len(produceReq.Topics) != len(req.Topics) { - t.Fatalf("topic count: got %d want %d", len(produceReq.Topics), len(req.Topics)) - } - for ti, topic := range produceReq.Topics { - if topic.Name != req.Topics[ti].Name { - t.Fatalf("topic[%d] name: got %q want %q", ti, topic.Name, req.Topics[ti].Name) + for ti, topic := range got.Topics { + want := req.Topics[ti] + if topic.Name != want.Name { + t.Fatalf("topic[%d] name: got %q want %q", ti, topic.Name, want.Name) } - if len(topic.Partitions) != len(req.Topics[ti].Partitions) { - t.Fatalf("topic[%d] partition count: got %d want %d", ti, len(topic.Partitions), len(req.Topics[ti].Partitions)) + if len(topic.Partitions) != len(want.Partitions) { + t.Fatalf("topic[%d] partition count: got %d want %d", ti, len(topic.Partitions), len(want.Partitions)) } for pi, part := range topic.Partitions { - want := req.Topics[ti].Partitions[pi] - if part.Partition != want.Partition { - t.Fatalf("topic[%d].part[%d] index: got %d want %d", ti, pi, part.Partition, want.Partition) + wantPart := want.Partitions[pi] + if part.Partition != wantPart.Partition { + t.Fatalf("topic[%d].part[%d] index: got %d want %d", ti, pi, part.Partition, wantPart.Partition) } - if string(part.Records) != string(want.Records) { - t.Fatalf("topic[%d].part[%d] records: got %v want %v", ti, pi, part.Records, want.Records) + if string(part.Records) != string(wantPart.Records) { + t.Fatalf("topic[%d].part[%d] records: got %x want %x", ti, pi, part.Records, wantPart.Records) } } }