diff --git a/cmd/kaf/consume.go b/cmd/kaf/consume.go index 3a37b329..5934e5d3 100644 --- a/cmd/kaf/consume.go +++ b/cmd/kaf/consume.go @@ -13,8 +13,7 @@ import ( "github.com/Shopify/sarama" "github.com/birdayz/kaf/pkg/avro" - "github.com/birdayz/kaf/pkg/proto" - "github.com/golang/protobuf/jsonpb" + "github.com/birdayz/kaf/pkg/codec" prettyjson "github.com/hokaccha/go-prettyjson" "github.com/spf13/cobra" "github.com/vmihailenco/msgpack/v5" @@ -37,7 +36,11 @@ var ( limitMessagesFlag int64 - reg *proto.DescriptorRegistry + reg *codec.DescriptorRegistry + + protoDecoder *codec.ProtoCodec + protoKeyDecoder *codec.ProtoCodec + avroDecoder *codec.AvroCodec ) func init() { @@ -95,6 +98,11 @@ var consumeCmd = &cobra.Command{ topic := args[0] client := getClientFromConfig(cfg) + schemaCache = getSchemaCache() + avroDecoder = codec.NewAvroCodec(-1, false, schemaCache) + protoDecoder = codec.NewProtoCodec(protoType, reg) + protoKeyDecoder = codec.NewProtoCodec(keyProtoType, reg) + switch offsetFlag { case "oldest": offset = sarama.OffsetOldest @@ -169,8 +177,6 @@ func withoutConsumerGroup(ctx context.Context, client sarama.Client, topic strin partitions = flagPartitions } - schemaCache = getSchemaCache() - wg := sync.WaitGroup{} mu := sync.Mutex{} // Synchronizes stderr and stdout. for _, partition := range partitions { @@ -230,24 +236,24 @@ func handleMessage(msg *sarama.ConsumerMessage, mu *sync.Mutex) { var err error if protoType != "" { - dataToDisplay, err = protoDecode(reg, msg.Value, protoType) + dataToDisplay, err = protoDecoder.Decode(msg.Value) if err != nil { - fmt.Fprintf(&stderr, "failed to decode proto. falling back to binary outputla. Error: %v\n", err) + fmt.Fprintf(&stderr, "failed to decode proto. falling back to binary output. Error: %v\n", err) } } else { - dataToDisplay, err = avroDecode(msg.Value) + dataToDisplay, err = avroDecoder.Decode(msg.Value) if err != nil { fmt.Fprintf(&stderr, "could not decode Avro data: %v\n", err) } } if keyProtoType != "" { - keyToDisplay, err = protoDecode(reg, msg.Key, keyProtoType) + keyToDisplay, err = protoKeyDecoder.Decode(msg.Key) if err != nil { - fmt.Fprintf(&stderr, "failed to decode proto key. falling back to binary outputla. Error: %v\n", err) + fmt.Fprintf(&stderr, "failed to decode proto key. falling back to binary output. Error: %v\n", err) } } else { - keyToDisplay, err = avroDecode(msg.Key) + keyToDisplay, err = avroDecoder.Decode(msg.Key) if err != nil { fmt.Fprintf(&stderr, "could not decode Avro data: %v\n", err) } @@ -314,36 +320,6 @@ func handleMessage(msg *sarama.ConsumerMessage, mu *sync.Mutex) { } -// proto to JSON -func protoDecode(reg *proto.DescriptorRegistry, b []byte, _type string) ([]byte, error) { - dynamicMessage := reg.MessageForType(_type) - if dynamicMessage == nil { - return b, nil - } - - err := dynamicMessage.Unmarshal(b) - if err != nil { - return nil, err - } - - var m jsonpb.Marshaler - var w bytes.Buffer - - err = m.Marshal(&w, dynamicMessage) - if err != nil { - return nil, err - } - return w.Bytes(), nil - -} - -func avroDecode(b []byte) ([]byte, error) { - if schemaCache != nil { - return schemaCache.DecodeMessage(b) - } - return b, nil -} - func formatKey(key []byte) []byte { if b, err := keyfmt.Format(key); err == nil { return b diff --git a/cmd/kaf/kaf.go b/cmd/kaf/kaf.go index ec43e576..9e9a2a79 100644 --- a/cmd/kaf/kaf.go +++ b/cmd/kaf/kaf.go @@ -15,8 +15,8 @@ import ( "github.com/spf13/cobra" "github.com/birdayz/kaf/pkg/avro" + "github.com/birdayz/kaf/pkg/codec" "github.com/birdayz/kaf/pkg/config" - "github.com/birdayz/kaf/pkg/proto" ) var cfgFile string @@ -174,7 +174,7 @@ func init() { var setupProtoDescriptorRegistry = func(cmd *cobra.Command, args []string) { if protoType != "" { - r, err := proto.NewDescriptorRegistry(protoFiles, protoExclude) + r, err := codec.NewDescriptorRegistry(protoFiles, protoExclude) if err != nil { errorExit("Failed to load protobuf files: %v\n", err) } diff --git a/cmd/kaf/produce.go b/cmd/kaf/produce.go index 3bec3587..66ee2bbe 100644 --- a/cmd/kaf/produce.go +++ b/cmd/kaf/produce.go @@ -15,8 +15,8 @@ import ( "github.com/Masterminds/sprig" "github.com/Shopify/sarama" + "github.com/birdayz/kaf/pkg/codec" "github.com/birdayz/kaf/pkg/partitioner" - pb "github.com/golang/protobuf/proto" "github.com/spf13/cobra" ) @@ -32,6 +32,7 @@ var ( inputModeFlag string avroSchemaID int avroKeySchemaID int + avroStrictFlag bool templateFlag bool ) @@ -54,6 +55,7 @@ func init() { produceCmd.Flags().IntVarP(&avroSchemaID, "avro-schema-id", "", -1, "Value schema id for avro messsage encoding") produceCmd.Flags().IntVarP(&avroKeySchemaID, "avro-key-schema-id", "", -1, "Key schema id for avro messsage encoding") + produceCmd.Flags().BoolVar(&avroStrictFlag, "avro-strict", false, "Uses strict version of the input json to parse unions") produceCmd.Flags().StringVarP(&inputModeFlag, "input-mode", "", "line", "Scanning input mode: [line|full]") produceCmd.Flags().IntVarP(&bufferSizeFlag, "line-length-limit", "", 0, "line length limit in line input mode") @@ -87,6 +89,26 @@ func readFull(reader io.Reader, out chan []byte) { close(out) } +func valueEncoder() codec.Encoder { + if protoType != "" { + return codec.NewProtoCodec(protoType, reg) + } else if avroSchemaID != -1 { + return codec.NewAvroCodec(avroSchemaID, avroStrictFlag, schemaCache) + } else { + return &codec.BypassCodec{} + } +} + +func keyEncoder() codec.Encoder { + if keyProtoType != "" { + return codec.NewProtoCodec(keyProtoType, reg) + } else if avroKeySchemaID != -1 { + return codec.NewAvroCodec(avroKeySchemaID, avroStrictFlag, schemaCache) + } else { + return &codec.BypassCodec{} + } +} + var produceCmd = &cobra.Command{ Use: "produce TOPIC", Short: "Produce record. Reads data from stdin.", @@ -128,6 +150,9 @@ var produceCmd = &cobra.Command{ go readLines(inReader, out) } + valueEncoder := valueEncoder() + keyEncoder := keyEncoder() + var key sarama.Encoder if rawKeyFlag { keyBytes, err := base64.RawStdEncoding.DecodeString(keyFlag) @@ -136,31 +161,11 @@ var produceCmd = &cobra.Command{ } key = sarama.ByteEncoder(keyBytes) } else { - key = sarama.StringEncoder(keyFlag) - } - if keyProtoType != "" { - if dynamicMessage := reg.MessageForType(keyProtoType); dynamicMessage != nil { - err = dynamicMessage.UnmarshalJSON([]byte(keyFlag)) - if err != nil { - errorExit("Failed to parse input JSON as proto type %v: %v", protoType, err) - } - - pb, err := pb.Marshal(dynamicMessage) - if err != nil { - errorExit("Failed to marshal proto: %v", err) - } - - key = sarama.ByteEncoder(pb) - } else { - errorExit("Failed to load key proto type") - } - - } else if avroKeySchemaID != -1 { - avroKey, err := schemaCache.EncodeMessage(avroKeySchemaID, []byte(keyFlag)) + encodedKey, err := keyEncoder.Encode([]byte(keyFlag)) if err != nil { - errorExit("Failed to encode avro key", err) + errorExit("Error encoding key: %v", err) } - key = sarama.ByteEncoder(avroKey) + key = sarama.ByteEncoder(encodedKey) } var headers []sarama.RecordHeader @@ -175,30 +180,6 @@ var produceCmd = &cobra.Command{ } for data := range out { - if protoType != "" { - if dynamicMessage := reg.MessageForType(protoType); dynamicMessage != nil { - err = dynamicMessage.UnmarshalJSON(data) - if err != nil { - errorExit("Failed to parse input JSON as proto type %v: %v", protoType, err) - } - - pb, err := pb.Marshal(dynamicMessage) - if err != nil { - errorExit("Failed to marshal proto: %v", err) - } - - data = pb - } else { - errorExit("Failed to load payload proto type") - } - } else if avroSchemaID != -1 { - avro, err := schemaCache.EncodeMessage(avroSchemaID, data) - if err != nil { - errorExit("Failed to encode avro value", err) - } - data = avro - } - var ts time.Time t, err := time.Parse(time.RFC3339, timestampFlag) if err != nil { @@ -230,6 +211,11 @@ var produceCmd = &cobra.Command{ input = buf.Bytes() } + input, err = valueEncoder.Encode(input) + if err != nil { + errorExit("Error encoding value: %v", err) + } + msg := &sarama.ProducerMessage{ Topic: args[0], Key: key, diff --git a/cmd/kaf/query.go b/cmd/kaf/query.go index d907095d..2ef170e3 100644 --- a/cmd/kaf/query.go +++ b/cmd/kaf/query.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/Shopify/sarama" + "github.com/birdayz/kaf/pkg/codec" "github.com/spf13/cobra" ) @@ -46,6 +47,8 @@ var queryCmd = &cobra.Command{ } schemaCache = getSchemaCache() + protoDecoder := codec.NewProtoCodec(protoType, reg) + protoKeyDecoder := codec.NewProtoCodec(keyProtoType, reg) wg := sync.WaitGroup{} @@ -72,7 +75,7 @@ var queryCmd = &cobra.Command{ var keyTextRaw string var valueTextRaw string if protoType != "" { - d, err := protoDecode(reg, msg.Value, protoType) + d, err := protoDecoder.Decode(msg.Value) if err != nil { fmt.Println("Failed proto decode") } @@ -82,7 +85,7 @@ var queryCmd = &cobra.Command{ } if keyProtoType != "" { - d, err := protoDecode(reg, msg.Key, keyProtoType) + d, err := protoKeyDecoder.Decode(msg.Key) if err != nil { fmt.Println("Failed proto decode") } diff --git a/pkg/avro/schema.go b/pkg/avro/schema.go index 7b02214e..00e93f32 100644 --- a/pkg/avro/schema.go +++ b/pkg/avro/schema.go @@ -1,7 +1,6 @@ package avro import ( - "encoding/binary" "sync" schemaregistry "github.com/Landoop/schema-registry" @@ -38,7 +37,7 @@ func NewSchemaCache(url string) (*SchemaCache, error) { } // getCodecForSchemaID returns a goavro codec for transforming data. -func (c *SchemaCache) getCodecForSchemaID(schemaID int) (codec *goavro.Codec, err error) { +func (c *SchemaCache) GetCodecForSchemaID(schemaID int, strict bool) (codec *goavro.Codec, err error) { c.mu.RLock() cc, ok := c.codecsBySchemaID[schemaID] c.mu.RUnlock() @@ -74,67 +73,14 @@ func (c *SchemaCache) getCodecForSchemaID(schemaID int) (codec *goavro.Codec, er return nil, err } - codec, err = goavro.NewCodec(schema) - if err != nil { - return nil, err - } - - return codec, nil -} - -// DecodeMessage returns a text representation of an Avro-encoded message. -func (c *SchemaCache) DecodeMessage(b []byte) (message []byte, err error) { - // Ensure avro header is present with the magic start-byte. - if len(b) < 5 || b[0] != 0x00 { - // The message does not contain Avro-encoded data - return b, nil + if strict { + codec, err = goavro.NewCodec(schema) + } else { + codec, err = goavro.NewCodecForStandardJSON(schema) } - - // Schema ID is stored in the 4 bytes following the magic byte. - schemaID := binary.BigEndian.Uint32(b[1:5]) - codec, err := c.getCodecForSchemaID(int(schemaID)) - if err != nil { - return b, err - } - - // Convert binary Avro data back to native Go form - native, _, err := codec.NativeFromBinary(b[5:]) - if err != nil { - return b, err - } - - // Convert native Go form to textual Avro data - message, err = codec.TextualFromNative(nil, native) - if err != nil { - return b, err - } - - return message, nil -} - -// EncodeMessage returns a binary representation of an Avro-encoded message. -func (c *SchemaCache) EncodeMessage(schemaID int, json []byte) (message []byte, err error) { - codec, err := c.getCodecForSchemaID(schemaID) - if err != nil { - return nil, err - } - - // Creates a header with an initial zero byte and - // the schema id encoded as a big endian uint32 - buf := make([]byte, 5) - binary.BigEndian.PutUint32(buf[1:5], uint32(schemaID)) - - // Convert textual json data to native Go form - native, _, err := codec.NativeFromTextual(json) if err != nil { return nil, err } - // Convert native Go form to binary Avro data - message, err = codec.BinaryFromNative(buf, native) - if err != nil { - return nil, err - } - - return message, nil + return codec, nil } diff --git a/pkg/codec/avro.go b/pkg/codec/avro.go new file mode 100644 index 00000000..f2ddfde6 --- /dev/null +++ b/pkg/codec/avro.go @@ -0,0 +1,77 @@ +package codec + +import ( + "encoding/binary" + "encoding/json" + + "github.com/birdayz/kaf/pkg/avro" +) + +// AvroCodec implements the Encoder/Decoder interfaces for +// avro formats +type AvroCodec struct { + encodeSchemaID int + strict bool + schemaCache *avro.SchemaCache +} + +func NewAvroCodec(schemaID int, strict bool, cache *avro.SchemaCache) *AvroCodec { + return &AvroCodec{schemaID, strict, cache} +} + +// Encode returns a binary representation of an Avro-encoded message. +func (a *AvroCodec) Encode(in json.RawMessage) ([]byte, error) { + codec, err := a.schemaCache.GetCodecForSchemaID(a.encodeSchemaID, a.strict) + if err != nil { + return nil, err + } + + // Creates a header with an initial zero byte and + // the schema id encoded as a big endian uint32 + buf := make([]byte, 5) + binary.BigEndian.PutUint32(buf[1:5], uint32(a.encodeSchemaID)) + + // Convert textual json data to native Go form + native, _, err := codec.NativeFromTextual(in) + if err != nil { + return nil, err + } + + // Convert native Go form to binary Avro data + message, err := codec.BinaryFromNative(buf, native) + if err != nil { + return nil, err + } + + return message, nil +} + +// Decode returns a text representation of an Avro-encoded message. +func (a *AvroCodec) Decode(in []byte) (json.RawMessage, error) { + // Ensure avro header is present with the magic start-byte. + if len(in) < 5 || in[0] != 0x00 { + // The message does not contain Avro-encoded data + return in, nil + } + + // Schema ID is stored in the 4 bytes following the magic byte. + schemaID := binary.BigEndian.Uint32(in[1:5]) + codec, err := a.schemaCache.GetCodecForSchemaID(int(schemaID), a.strict) + if err != nil { + return in, err + } + + // Convert binary Avro data back to native Go form + native, _, err := codec.NativeFromBinary(in[5:]) + if err != nil { + return in, err + } + + // Convert native Go form to textual Avro data + message, err := codec.TextualFromNative(nil, native) + if err != nil { + return in, err + } + + return message, nil +} diff --git a/pkg/codec/codec.go b/pkg/codec/codec.go new file mode 100644 index 00000000..9e0c8c77 --- /dev/null +++ b/pkg/codec/codec.go @@ -0,0 +1,27 @@ +package codec + +import "encoding/json" + +// Encoder converts from json representation +// to bytes in the specified format +type Encoder interface { + // Encode json to binary format + Encode(in json.RawMessage) ([]byte, error) +} + +// Decoder converts from binary representation to json +type Decoder interface { + // Decode binary to json form + Decode(in []byte) (json.RawMessage, error) +} + +// BypassCodec is a no-op implementation of Encoder and Decoder +type BypassCodec struct{} + +func (BypassCodec) Encode(in json.RawMessage) ([]byte, error) { + return in, nil +} + +func (BypassCodec) Decode(in json.RawMessage) ([]byte, error) { + return in, nil +} diff --git a/pkg/proto/proto.go b/pkg/codec/proto.go similarity index 53% rename from pkg/proto/proto.go rename to pkg/codec/proto.go index ca157401..bab9b89f 100644 --- a/pkg/proto/proto.go +++ b/pkg/codec/proto.go @@ -1,11 +1,15 @@ -package proto +package codec import ( + "bytes" + "encoding/json" + "fmt" "os" "path/filepath" - "strings" + "github.com/golang/protobuf/jsonpb" + "github.com/golang/protobuf/proto" "github.com/jhump/protoreflect/desc" "github.com/jhump/protoreflect/desc/protoparse" "github.com/jhump/protoreflect/dynamic" @@ -73,3 +77,53 @@ func (d *DescriptorRegistry) MessageForType(_type string) *dynamic.Message { } return nil } + +// ProtoCodec implements the Encoder/Decoder interfaces +// for protobuf messages +type ProtoCodec struct { + registry *DescriptorRegistry + protoType string +} + +func NewProtoCodec(protoType string, registry *DescriptorRegistry) *ProtoCodec { + return &ProtoCodec{registry, protoType} +} + +func (p *ProtoCodec) Encode(in json.RawMessage) ([]byte, error) { + if dynamicMessage := p.registry.MessageForType(p.protoType); dynamicMessage != nil { + err := dynamicMessage.UnmarshalJSON(in) + if err != nil { + return nil, fmt.Errorf("failed to parse input JSON as proto type %v: %v", p.protoType, err) + } + + pb, err := proto.Marshal(dynamicMessage) + if err != nil { + return nil, fmt.Errorf("failed to marshal proto: %v", err) + } + + return pb, nil + } else { + return nil, fmt.Errorf("failed to load payload proto type: %v", p.protoType) + } +} + +func (p *ProtoCodec) Decode(in []byte) (json.RawMessage, error) { + dynamicMessage := p.registry.MessageForType(p.protoType) + if dynamicMessage == nil { + return in, nil + } + + err := dynamicMessage.Unmarshal(in) + if err != nil { + return nil, err + } + + var m jsonpb.Marshaler + var w bytes.Buffer + + err = m.Marshal(&w, dynamicMessage) + if err != nil { + return nil, err + } + return w.Bytes(), nil +} diff --git a/pkg/codec/user.avro b/pkg/codec/user.avro new file mode 100644 index 00000000..67505e5c --- /dev/null +++ b/pkg/codec/user.avro @@ -0,0 +1,8 @@ +{ + "name": "user", + "type": "record", + "fields": [ + { "name": "id", "type": "int" }, + { "name": "name", "type": "string" } + ] +} diff --git a/pkg/codec/user.proto b/pkg/codec/user.proto new file mode 100644 index 00000000..28e6d491 --- /dev/null +++ b/pkg/codec/user.proto @@ -0,0 +1,10 @@ +syntax = "proto3"; + +message Id { + int32 id = 1; +} + +message User { + int32 id = 1; + string name = 2; +}