diff --git a/pkg/avro/schema.go b/pkg/avro/schema.go index b1ac8552..76040314 100644 --- a/pkg/avro/schema.go +++ b/pkg/avro/schema.go @@ -21,10 +21,14 @@ type cachedCodec struct { type SchemaCache struct { client *schemaregistry.Client - mu sync.RWMutex - codecsBySchemaID map[int]*cachedCodec + mu sync.RWMutex + codecsBySchemaID map[int]*cachedCodec } +// avroMagicByte is the first byte of an Avro-encoded message, used to identify +// Avro data format in the schema registry protocol. +const avroMagicByte = 0x00 + type transport struct { underlyingTransport http.RoundTripper encodedCredentials string @@ -110,7 +114,7 @@ func (c *SchemaCache) getCodecForSchemaID(schemaID int) (codec *goavro.Codec, er // DecodeMessage returns a text representation of an Avro-encoded message. func (c *SchemaCache) DecodeMessage(b []byte) (message []byte, err error) { // Ensure avro header is present with the magic start-byte. - if len(b) < 5 || b[0] != 0x00 { + if len(b) < 5 || b[0] != avroMagicByte { // The message does not contain Avro-encoded data return b, nil } diff --git a/pkg/config/confluent_cloud.go b/pkg/config/confluent_cloud.go index 0514f5c8..531f376b 100644 --- a/pkg/config/confluent_cloud.go +++ b/pkg/config/confluent_cloud.go @@ -15,17 +15,15 @@ import ( var defaultCcloudSubpath = filepath.Join(".ccloud", "config") func TryFindCcloudConfigFile() (string, error) { - homedir, err := homedir.Dir() + home, err := homedir.Dir() if err != nil { - return "", err } - absoluteDefaultPath := filepath.Join(homedir, defaultCcloudSubpath) + absoluteDefaultPath := filepath.Join(home, defaultCcloudSubpath) _, err = os.Stat(absoluteDefaultPath) if err == nil { - return absoluteDefaultPath, nil } return "", os.ErrNotExist @@ -66,7 +64,7 @@ func ParseConfluentCloudConfig(path string) (username, password, broker string, } if !jaasOk { - return "", "", "", errors.New("Could not parse sasl.jaas.config from ccloud") + return "", "", "", errors.New("could not parse sasl.jaas.config from ccloud") } broker = p["bootstrap.servers"] diff --git a/pkg/proto/proto.go b/pkg/proto/proto.go index ca157401..5ce72cab 100644 --- a/pkg/proto/proto.go +++ b/pkg/proto/proto.go @@ -16,7 +16,7 @@ type DescriptorRegistry struct { } func NewDescriptorRegistry(importPaths []string, exclusions []string) (*DescriptorRegistry, error) { - p := &protoparse.Parser{ + parser := &protoparse.Parser{ ImportPaths: importPaths, } @@ -41,23 +41,24 @@ func NewDescriptorRegistry(importPaths []string, exclusions []string) (*Descript return nil, err } + exclusionSet := make(map[string]struct{}, len(exclusions)) + for _, exclusion := range exclusions { + exclusionSet[exclusion] = struct{}{} + } + var deduped []string + seen := make(map[string]struct{}) for _, i := range resolved { - - var exclusionFound bool - for _, exclusion := range exclusions { - if strings.HasPrefix(i, exclusion) { - exclusionFound = true - break - } + if _, excluded := exclusionSet[i]; excluded { + continue } - - if !exclusionFound { + if _, ok := seen[i]; !ok { + seen[i] = struct{}{} deduped = append(deduped, i) } } - descs, err := p.ParseFiles(deduped...) + descs, err := parser.ParseFiles(deduped...) if err != nil { return nil, err } diff --git a/pkg/streams/decoder.go b/pkg/streams/decoder.go index 8a86f356..3d465a3a 100644 --- a/pkg/streams/decoder.go +++ b/pkg/streams/decoder.go @@ -3,16 +3,11 @@ package streams import ( "encoding/binary" "errors" - "math" ) var errInvalidArrayLength = errors.New("invalid array length") var errInvalidByteSliceLength = errors.New("invalid byteslice length") - -//var errInvalidByteSliceLengthType = errors.New("invalid byteslice length type") var errInvalidStringLength = errors.New("invalid string length") - -//var errInvalidSubsetSize = errors.New("invalid subset size") var errVarintOverflow = errors.New("varint overflow") var errInvalidBool = errors.New("invalid bool") @@ -122,11 +117,12 @@ func (rd *realDecoder) getArrayLength() (int, error) { } tmp := int(int32(binary.BigEndian.Uint32(rd.raw[rd.off:]))) rd.off += 4 + if tmp < 0 { + return -1, errInvalidArrayLength + } if tmp > rd.remaining() { rd.off = len(rd.raw) return -1, ErrInsufficientData - } else if tmp > 2*math.MaxUint16 { - return -1, errInvalidArrayLength } return tmp, nil } diff --git a/pkg/streams/subscription_info.go b/pkg/streams/subscription_info.go index 0ed50b98..6bc2a797 100644 --- a/pkg/streams/subscription_info.go +++ b/pkg/streams/subscription_info.go @@ -25,6 +25,9 @@ func (s *SubscriptionInfo) Decode(pd PacketDecoder) (err error) { if err != nil { return err } + if numPrevs < 0 { + return errInvalidArrayLength + } for i := 0; i < int(numPrevs); i++ { t := TaskID{} @@ -46,6 +49,9 @@ func (s *SubscriptionInfo) Decode(pd PacketDecoder) (err error) { if err != nil { return err } + if numStandby < 0 { + return errInvalidArrayLength + } for i := 0; i < int(numStandby); i++ { t := TaskID{}