Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions pkg/protocol/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -505,16 +505,16 @@ func ParseRequest(b []byte) (*RequestHeader, Request, error) {
if err != nil {
return nil, nil, fmt.Errorf("read produce records: %w", err)
}
if flexible {
if err := reader.SkipTaggedFields(); err != nil {
return nil, nil, fmt.Errorf("skip partition tags: %w", err)
}
}
partitions = append(partitions, ProducePartition{
Partition: index,
Records: records,
})
}
if flexible {
if err := reader.SkipTaggedFields(); err != nil {
return nil, nil, fmt.Errorf("skip partition tags: %w", err)
}
}
if flexible {
if err := reader.SkipTaggedFields(); err != nil {
return nil, nil, fmt.Errorf("skip topic tags: %w", err)
Expand Down Expand Up @@ -1931,14 +1931,12 @@ func EncodeProduceRequest(header *RequestHeader, req *ProduceRequest, version in
} else {
w.BytesWithLength(part.Records)
}
if flexible {
w.WriteTaggedFields(0)
}
}
// Match the parser: two tagged-field blocks after the partition array.
// The Kafka protocol spec places per-partition tags inside the partition
// loop, but our parser (ParseRequest) reads them outside. Since the
// broker is always KafScale (not vanilla Kafka), this is intentional.
if flexible {
w.WriteTaggedFields(0)
w.WriteTaggedFields(0)
}
}
if flexible {
Expand Down
110 changes: 110 additions & 0 deletions pkg/protocol/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1033,3 +1033,113 @@ func TestEncodeFetchRequest_KmsgValidation(t *testing.T) {
t.Fatalf("fetch offset: got %d, want 42", kmsgReq.Topics[0].Partitions[0].FetchOffset)
}
}

// TestProduceMultiPartitionFranzCompat tests byte-level compatibility with
// franz-go for multi-partition produce requests in both directions:
// - franz-go encodes → KafScale parses
// - KafScale encodes → franz-go decodes
func TestProduceMultiPartitionFranzCompat(t *testing.T) {
t.Run("franz-encode-kafscale-parse", func(t *testing.T) {
req := kmsg.NewPtrProduceRequest()
req.Version = 9
req.Acks = -1
req.TimeoutMillis = 3000
topic := kmsg.NewProduceRequestTopic()
topic.Topic = "orders"
for _, pi := range []int32{0, 1, 2} {
part := kmsg.NewProduceRequestTopicPartition()
part.Partition = pi
part.Records = []byte{byte(pi + 1), byte(pi + 2)}
topic.Partitions = append(topic.Partitions, part)
}
req.Topics = append(req.Topics, topic)
body := req.AppendTo(nil)

w := newByteWriter(len(body) + 32)
w.Int16(APIKeyProduce)
w.Int16(9)
w.Int32(55)
clientID := "kgo"
w.NullableString(&clientID)
w.WriteTaggedFields(0)
w.write(body)

_, parsed, err := ParseRequest(w.Bytes())
if err != nil {
t.Fatalf("ParseRequest: %v", err)
}
got, ok := parsed.(*ProduceRequest)
if !ok {
t.Fatalf("expected *ProduceRequest, got %T", parsed)
}
if len(got.Topics) != 1 {
t.Fatalf("topic count: got %d want 1", len(got.Topics))
}
if len(got.Topics[0].Partitions) != 3 {
t.Fatalf("partition count: got %d want 3", len(got.Topics[0].Partitions))
}
for pi, part := range got.Topics[0].Partitions {
if part.Partition != int32(pi) {
t.Fatalf("part[%d] index: got %d want %d", pi, part.Partition, pi)
}
want := []byte{byte(pi + 1), byte(pi + 2)}
if string(part.Records) != string(want) {
t.Fatalf("part[%d] records: got %x want %x", pi, part.Records, want)
}
}
})

t.Run("kafscale-encode-franz-parse", func(t *testing.T) {
header := &RequestHeader{
APIKey: APIKeyProduce,
APIVersion: 9,
CorrelationID: 66,
ClientID: strPtr("test"),
}
req := &ProduceRequest{
Acks: -1,
TimeoutMs: 3000,
Topics: []ProduceTopic{
{
Name: "orders",
Partitions: []ProducePartition{
{Partition: 0, Records: []byte{1, 2}},
{Partition: 1, Records: []byte{3, 4}},
{Partition: 2, Records: []byte{5, 6}},
},
},
},
}
encoded, err := EncodeProduceRequest(header, req, 9)
if err != nil {
t.Fatalf("encode: %v", err)
}

_, reader, err := ParseRequestHeader(encoded)
if err != nil {
t.Fatalf("ParseRequestHeader: %v", err)
}
bodyStart := len(encoded) - reader.remaining()

kmsgReq := kmsg.NewPtrProduceRequest()
kmsgReq.Version = 9
if err := kmsgReq.ReadFrom(encoded[bodyStart:]); err != nil {
t.Fatalf("kmsg.ReadFrom: %v", err)
}
if len(kmsgReq.Topics) != 1 {
t.Fatalf("topic count: got %d want 1", len(kmsgReq.Topics))
}
if len(kmsgReq.Topics[0].Partitions) != 3 {
t.Fatalf("partition count: got %d want 3", len(kmsgReq.Topics[0].Partitions))
}
for pi, part := range kmsgReq.Topics[0].Partitions {
if part.Partition != int32(pi) {
t.Fatalf("part[%d] index: got %d want %d", pi, part.Partition, pi)
}
want := []byte{byte(pi*2 + 1), byte(pi*2 + 2)}
if string(part.Records) != string(want) {
t.Fatalf("part[%d] records: got %x want %x", pi, part.Records, want)
}
}
})
}
60 changes: 24 additions & 36 deletions pkg/protocol/response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1518,77 +1518,65 @@ func TestParseProduceResponseRoundTrip(t *testing.T) {
}

func TestEncodeProduceRequestRoundTrip(t *testing.T) {
header := &RequestHeader{
APIKey: APIKeyProduce,
APIVersion: 9,
CorrelationID: 77,
ClientID: strPtr("test-client"),
}
req := &ProduceRequest{
Acks: -1,
TimeoutMs: 5000,
Topics: []ProduceTopic{
{
Name: "orders",
Partitions: []ProducePartition{
{Partition: 0, Records: []byte{1, 2, 3, 4}},
{Partition: 1, Records: []byte{5, 6}},
{Partition: 0, Records: []byte{1, 2, 3}},
{Partition: 1, Records: []byte{4, 5}},
{Partition: 2, Records: []byte{6, 7, 8, 9}},
},
},
{
Name: "events",
Partitions: []ProducePartition{
{Partition: 0, Records: []byte{7, 8, 9}},
{Partition: 0, Records: []byte{10}},
{Partition: 3, Records: []byte{11, 12}},
},
},
},
}
for _, version := range []int16{3, 5, 7, 8, 9, 10} {
t.Run(fmt.Sprintf("v%d", version), func(t *testing.T) {
h := &RequestHeader{
header := &RequestHeader{
APIKey: APIKeyProduce,
APIVersion: version,
CorrelationID: header.CorrelationID,
ClientID: header.ClientID,
CorrelationID: 77,
ClientID: strPtr("test-client"),
}
encoded, err := EncodeProduceRequest(h, req, version)
encoded, err := EncodeProduceRequest(header, req, version)
if err != nil {
t.Fatalf("encode: %v", err)
}
parsedHeader, parsedReq, err := ParseRequest(encoded)
_, parsedReq, err := ParseRequest(encoded)
if err != nil {
t.Fatalf("parse: %v", err)
}
if parsedHeader.CorrelationID != h.CorrelationID {
t.Fatalf("correlation id: got %d want %d", parsedHeader.CorrelationID, h.CorrelationID)
}
produceReq, ok := parsedReq.(*ProduceRequest)
got, ok := parsedReq.(*ProduceRequest)
if !ok {
t.Fatalf("expected *ProduceRequest, got %T", parsedReq)
}
if produceReq.Acks != req.Acks {
t.Fatalf("acks: got %d want %d", produceReq.Acks, req.Acks)
if len(got.Topics) != len(req.Topics) {
t.Fatalf("topic count: got %d want %d", len(got.Topics), len(req.Topics))
}
if produceReq.TimeoutMs != req.TimeoutMs {
t.Fatalf("timeout: got %d want %d", produceReq.TimeoutMs, req.TimeoutMs)
}
if len(produceReq.Topics) != len(req.Topics) {
t.Fatalf("topic count: got %d want %d", len(produceReq.Topics), len(req.Topics))
}
for ti, topic := range produceReq.Topics {
if topic.Name != req.Topics[ti].Name {
t.Fatalf("topic[%d] name: got %q want %q", ti, topic.Name, req.Topics[ti].Name)
for ti, topic := range got.Topics {
want := req.Topics[ti]
if topic.Name != want.Name {
t.Fatalf("topic[%d] name: got %q want %q", ti, topic.Name, want.Name)
}
if len(topic.Partitions) != len(req.Topics[ti].Partitions) {
t.Fatalf("topic[%d] partition count: got %d want %d", ti, len(topic.Partitions), len(req.Topics[ti].Partitions))
if len(topic.Partitions) != len(want.Partitions) {
t.Fatalf("topic[%d] partition count: got %d want %d", ti, len(topic.Partitions), len(want.Partitions))
}
for pi, part := range topic.Partitions {
want := req.Topics[ti].Partitions[pi]
if part.Partition != want.Partition {
t.Fatalf("topic[%d].part[%d] index: got %d want %d", ti, pi, part.Partition, want.Partition)
wantPart := want.Partitions[pi]
if part.Partition != wantPart.Partition {
t.Fatalf("topic[%d].part[%d] index: got %d want %d", ti, pi, part.Partition, wantPart.Partition)
}
if string(part.Records) != string(want.Records) {
t.Fatalf("topic[%d].part[%d] records: got %v want %v", ti, pi, part.Records, want.Records)
if string(part.Records) != string(wantPart.Records) {
t.Fatalf("topic[%d].part[%d] records: got %x want %x", ti, pi, part.Records, wantPart.Records)
}
}
}
Expand Down