From e1832313750e244b6973e08935d3b87b131822ad Mon Sep 17 00:00:00 2001 From: Mark Sagi-Kazar Date: Sat, 12 Jul 2025 08:36:50 +0200 Subject: [PATCH 1/3] feat: add generic string hook Signed-off-by: Mark Sagi-Kazar refactor: improve parse func type safety Signed-off-by: Mark Sagi-Kazar feat: add with parse func string to conversion Signed-off-by: Mark Sagi-Kazar refactor: move functions to the end of the file Signed-off-by: Mark Sagi-Kazar refactor: rename some code Signed-off-by: Mark Sagi-Kazar refactor: move some code Signed-off-by: Mark Sagi-Kazar refactor: move some code Signed-off-by: Mark Sagi-Kazar feat: add parse functions Signed-off-by: Mark Sagi-Kazar refactor: unexport parsers Signed-off-by: Mark Sagi-Kazar refactor: stricter type constraint for string hook Signed-off-by: Mark Sagi-Kazar refactor: changes Signed-off-by: Mark Sagi-Kazar test: remove tests Signed-off-by: Mark Sagi-Kazar refactor: use StringParserHookFunc in StringTo functions Signed-off-by: Mark Sagi-Kazar doc: StringTo doc comment Signed-off-by: Mark Sagi-Kazar --- decode_hooks.go | 282 +++--------------------------------- decode_hooks_string.go | 255 ++++++++++++++++++++++++++++++++ decode_hooks_string_test.go | 1 + 3 files changed, 278 insertions(+), 260 deletions(-) create mode 100644 decode_hooks_string.go create mode 100644 decode_hooks_string_test.go diff --git a/decode_hooks.go b/decode_hooks.go index a852a0a0..a6691fb6 100644 --- a/decode_hooks.go +++ b/decode_hooks.go @@ -3,10 +3,6 @@ package mapstructure import ( "encoding" "errors" - "fmt" - "net" - "net/netip" - "net/url" "reflect" "strconv" "strings" @@ -187,23 +183,7 @@ func StringToWeakSliceHookFunc(sep string) DecodeHookFunc { // StringToTimeDurationHookFunc returns a DecodeHookFunc that converts // strings to time.Duration. func StringToTimeDurationHookFunc() DecodeHookFunc { - return func( - f reflect.Type, - t reflect.Type, - data any, - ) (any, error) { - if f.Kind() != reflect.String { - return data, nil - } - if t != reflect.TypeOf(time.Duration(5)) { - return data, nil - } - - // Convert it by parsing - d, err := time.ParseDuration(data.(string)) - - return d, wrapTimeParseDurationError(err) - } + return StringParserHookFunc(parseDuration) } // StringToTimeLocationHookFunc returns a DecodeHookFunc that converts @@ -229,69 +209,19 @@ func StringToTimeLocationHookFunc() DecodeHookFunc { // StringToURLHookFunc returns a DecodeHookFunc that converts // strings to *url.URL. func StringToURLHookFunc() DecodeHookFunc { - return func( - f reflect.Type, - t reflect.Type, - data any, - ) (any, error) { - if f.Kind() != reflect.String { - return data, nil - } - if t != reflect.TypeOf(&url.URL{}) { - return data, nil - } - - // Convert it by parsing - u, err := url.Parse(data.(string)) - - return u, wrapUrlError(err) - } + return StringParserHookFunc(parseURL) } // StringToIPHookFunc returns a DecodeHookFunc that converts // strings to net.IP func StringToIPHookFunc() DecodeHookFunc { - return func( - f reflect.Type, - t reflect.Type, - data any, - ) (any, error) { - if f.Kind() != reflect.String { - return data, nil - } - if t != reflect.TypeOf(net.IP{}) { - return data, nil - } - - // Convert it by parsing - ip := net.ParseIP(data.(string)) - if ip == nil { - return net.IP{}, fmt.Errorf("failed parsing ip") - } - - return ip, nil - } + return StringParserHookFunc(parseIP) } // StringToIPNetHookFunc returns a DecodeHookFunc that converts // strings to net.IPNet func StringToIPNetHookFunc() DecodeHookFunc { - return func( - f reflect.Type, - t reflect.Type, - data any, - ) (any, error) { - if f.Kind() != reflect.String { - return data, nil - } - if t != reflect.TypeOf(net.IPNet{}) { - return data, nil - } - - // Convert it by parsing - _, net, err := net.ParseCIDR(data.(string)) - return net, wrapNetParseError(err) - } + return StringParserHookFunc(parseIPNet) } // StringToTimeHookFunc returns a DecodeHookFunc that converts @@ -402,67 +332,19 @@ func TextUnmarshallerHookFunc() DecodeHookFuncType { // StringToNetIPAddrHookFunc returns a DecodeHookFunc that converts // strings to netip.Addr. func StringToNetIPAddrHookFunc() DecodeHookFunc { - return func( - f reflect.Type, - t reflect.Type, - data any, - ) (any, error) { - if f.Kind() != reflect.String { - return data, nil - } - if t != reflect.TypeOf(netip.Addr{}) { - return data, nil - } - - // Convert it by parsing - addr, err := netip.ParseAddr(data.(string)) - - return addr, wrapNetIPParseAddrError(err) - } + return StringParserHookFunc(parseNetipAddr) } // StringToNetIPAddrPortHookFunc returns a DecodeHookFunc that converts // strings to netip.AddrPort. func StringToNetIPAddrPortHookFunc() DecodeHookFunc { - return func( - f reflect.Type, - t reflect.Type, - data any, - ) (any, error) { - if f.Kind() != reflect.String { - return data, nil - } - if t != reflect.TypeOf(netip.AddrPort{}) { - return data, nil - } - - // Convert it by parsing - addrPort, err := netip.ParseAddrPort(data.(string)) - - return addrPort, wrapNetIPParseAddrPortError(err) - } + return StringParserHookFunc(parseNetipAddrPort) } // StringToNetIPPrefixHookFunc returns a DecodeHookFunc that converts // strings to netip.Prefix. func StringToNetIPPrefixHookFunc() DecodeHookFunc { - return func( - f reflect.Type, - t reflect.Type, - data any, - ) (any, error) { - if f.Kind() != reflect.String { - return data, nil - } - if t != reflect.TypeOf(netip.Prefix{}) { - return data, nil - } - - // Convert it by parsing - prefix, err := netip.ParsePrefix(data.(string)) - - return prefix, wrapNetIPParsePrefixError(err) - } + return StringParserHookFunc(parseNetipPrefix) } // StringToBasicTypeHookFunc returns a DecodeHookFunc that converts @@ -494,183 +376,79 @@ func StringToBasicTypeHookFunc() DecodeHookFunc { // StringToInt8HookFunc returns a DecodeHookFunc that converts // strings to int8. func StringToInt8HookFunc() DecodeHookFunc { - return func(f reflect.Type, t reflect.Type, data any) (any, error) { - if f.Kind() != reflect.String || t.Kind() != reflect.Int8 { - return data, nil - } - - // Convert it by parsing - i64, err := strconv.ParseInt(data.(string), 0, 8) - return int8(i64), wrapStrconvNumError(err) - } + return StringParserHookFunc(parseInt8) } // StringToUint8HookFunc returns a DecodeHookFunc that converts // strings to uint8. func StringToUint8HookFunc() DecodeHookFunc { - return func(f reflect.Type, t reflect.Type, data any) (any, error) { - if f.Kind() != reflect.String || t.Kind() != reflect.Uint8 { - return data, nil - } - - // Convert it by parsing - u64, err := strconv.ParseUint(data.(string), 0, 8) - return uint8(u64), wrapStrconvNumError(err) - } + return StringParserHookFunc(parseUint8) } // StringToInt16HookFunc returns a DecodeHookFunc that converts // strings to int16. func StringToInt16HookFunc() DecodeHookFunc { - return func(f reflect.Type, t reflect.Type, data any) (any, error) { - if f.Kind() != reflect.String || t.Kind() != reflect.Int16 { - return data, nil - } - - // Convert it by parsing - i64, err := strconv.ParseInt(data.(string), 0, 16) - return int16(i64), wrapStrconvNumError(err) - } + return StringParserHookFunc(parseInt16) } // StringToUint16HookFunc returns a DecodeHookFunc that converts // strings to uint16. func StringToUint16HookFunc() DecodeHookFunc { - return func(f reflect.Type, t reflect.Type, data any) (any, error) { - if f.Kind() != reflect.String || t.Kind() != reflect.Uint16 { - return data, nil - } - - // Convert it by parsing - u64, err := strconv.ParseUint(data.(string), 0, 16) - return uint16(u64), wrapStrconvNumError(err) - } + return StringParserHookFunc(parseUint16) } // StringToInt32HookFunc returns a DecodeHookFunc that converts // strings to int32. func StringToInt32HookFunc() DecodeHookFunc { - return func(f reflect.Type, t reflect.Type, data any) (any, error) { - if f.Kind() != reflect.String || t.Kind() != reflect.Int32 { - return data, nil - } - - // Convert it by parsing - i64, err := strconv.ParseInt(data.(string), 0, 32) - return int32(i64), wrapStrconvNumError(err) - } + return StringParserHookFunc(parseInt32) } // StringToUint32HookFunc returns a DecodeHookFunc that converts // strings to uint32. func StringToUint32HookFunc() DecodeHookFunc { - return func(f reflect.Type, t reflect.Type, data any) (any, error) { - if f.Kind() != reflect.String || t.Kind() != reflect.Uint32 { - return data, nil - } - - // Convert it by parsing - u64, err := strconv.ParseUint(data.(string), 0, 32) - return uint32(u64), wrapStrconvNumError(err) - } + return StringParserHookFunc(parseUint32) } // StringToInt64HookFunc returns a DecodeHookFunc that converts // strings to int64. func StringToInt64HookFunc() DecodeHookFunc { - return func(f reflect.Type, t reflect.Type, data any) (any, error) { - if f.Kind() != reflect.String || t.Kind() != reflect.Int64 { - return data, nil - } - - // Convert it by parsing - i64, err := strconv.ParseInt(data.(string), 0, 64) - return int64(i64), wrapStrconvNumError(err) - } + return StringParserHookFunc(parseInt64) } // StringToUint64HookFunc returns a DecodeHookFunc that converts // strings to uint64. func StringToUint64HookFunc() DecodeHookFunc { - return func(f reflect.Type, t reflect.Type, data any) (any, error) { - if f.Kind() != reflect.String || t.Kind() != reflect.Uint64 { - return data, nil - } - - // Convert it by parsing - u64, err := strconv.ParseUint(data.(string), 0, 64) - return uint64(u64), wrapStrconvNumError(err) - } + return StringParserHookFunc(parseUint64) } // StringToIntHookFunc returns a DecodeHookFunc that converts // strings to int. func StringToIntHookFunc() DecodeHookFunc { - return func(f reflect.Type, t reflect.Type, data any) (any, error) { - if f.Kind() != reflect.String || t.Kind() != reflect.Int { - return data, nil - } - - // Convert it by parsing - i64, err := strconv.ParseInt(data.(string), 0, 0) - return int(i64), wrapStrconvNumError(err) - } + return StringParserHookFunc(parseInt) } // StringToUintHookFunc returns a DecodeHookFunc that converts // strings to uint. func StringToUintHookFunc() DecodeHookFunc { - return func(f reflect.Type, t reflect.Type, data any) (any, error) { - if f.Kind() != reflect.String || t.Kind() != reflect.Uint { - return data, nil - } - - // Convert it by parsing - u64, err := strconv.ParseUint(data.(string), 0, 0) - return uint(u64), wrapStrconvNumError(err) - } + return StringParserHookFunc(parseUint) } // StringToFloat32HookFunc returns a DecodeHookFunc that converts // strings to float32. func StringToFloat32HookFunc() DecodeHookFunc { - return func(f reflect.Type, t reflect.Type, data any) (any, error) { - if f.Kind() != reflect.String || t.Kind() != reflect.Float32 { - return data, nil - } - - // Convert it by parsing - f64, err := strconv.ParseFloat(data.(string), 32) - return float32(f64), wrapStrconvNumError(err) - } + return StringParserHookFunc(parseFloat32) } // StringToFloat64HookFunc returns a DecodeHookFunc that converts // strings to float64. func StringToFloat64HookFunc() DecodeHookFunc { - return func(f reflect.Type, t reflect.Type, data any) (any, error) { - if f.Kind() != reflect.String || t.Kind() != reflect.Float64 { - return data, nil - } - - // Convert it by parsing - f64, err := strconv.ParseFloat(data.(string), 64) - return f64, wrapStrconvNumError(err) - } + return StringParserHookFunc(parseFloat64) } // StringToBoolHookFunc returns a DecodeHookFunc that converts // strings to bool. func StringToBoolHookFunc() DecodeHookFunc { - return func(f reflect.Type, t reflect.Type, data any) (any, error) { - if f.Kind() != reflect.String || t.Kind() != reflect.Bool { - return data, nil - } - - // Convert it by parsing - b, err := strconv.ParseBool(data.(string)) - return b, wrapStrconvNumError(err) - } + return StringParserHookFunc(parseBool) } // StringToByteHookFunc returns a DecodeHookFunc that converts @@ -688,27 +466,11 @@ func StringToRuneHookFunc() DecodeHookFunc { // StringToComplex64HookFunc returns a DecodeHookFunc that converts // strings to complex64. func StringToComplex64HookFunc() DecodeHookFunc { - return func(f reflect.Type, t reflect.Type, data any) (any, error) { - if f.Kind() != reflect.String || t.Kind() != reflect.Complex64 { - return data, nil - } - - // Convert it by parsing - c128, err := strconv.ParseComplex(data.(string), 64) - return complex64(c128), wrapStrconvNumError(err) - } + return StringParserHookFunc(parseComplex64) } // StringToComplex128HookFunc returns a DecodeHookFunc that converts // strings to complex128. func StringToComplex128HookFunc() DecodeHookFunc { - return func(f reflect.Type, t reflect.Type, data any) (any, error) { - if f.Kind() != reflect.String || t.Kind() != reflect.Complex128 { - return data, nil - } - - // Convert it by parsing - c128, err := strconv.ParseComplex(data.(string), 128) - return c128, wrapStrconvNumError(err) - } + return StringParserHookFunc(parseComplex128) } diff --git a/decode_hooks_string.go b/decode_hooks_string.go new file mode 100644 index 00000000..5dffb79e --- /dev/null +++ b/decode_hooks_string.go @@ -0,0 +1,255 @@ +package mapstructure + +import ( + "fmt" + "net" + "net/netip" + "net/url" + "reflect" + "strconv" + "time" +) + +// PrimitiveStringConvertible defines the constraint for primitive types that can be converted from strings. +type PrimitiveStringConvertible interface { + ~int8 | ~uint8 | ~int16 | ~uint16 | ~int32 | ~uint32 | ~int64 | ~uint64 | + ~int | ~uint | ~float32 | ~float64 | ~bool | ~complex64 | ~complex128 +} + +// ComplexStringConvertible defines the constraint for complex types that can be converted from strings. +type ComplexStringConvertible interface { + time.Duration | *url.URL | net.IP | *net.IPNet | netip.Addr | netip.AddrPort | netip.Prefix +} + +// StringConvertible defines the constraint for all types that can be converted from strings. +type StringConvertible interface { + PrimitiveStringConvertible | ComplexStringConvertible +} + +// StringToHookFuncWithParser creates a DecodeHookFunc that converts strings to type T +// using the provided parseFunc allowing for custom parsing logic. +// +// Unlike [StringToHookFunc], this function supports tilde types (~int8, ~uint8, etc.) +// which allows it to work with custom type aliases at compile time: +// +// type MyInt int32 +// customParser := func(s string) (MyInt, error) { +// val, err := strconv.ParseInt(s, 0, 32) +// return MyInt(val), err +// } +// hook := StringParserHookFunc(customParser) +func StringParserHookFunc[T StringConvertible](parseFunc func(string) (T, error)) DecodeHookFunc { + var zero T + expectedType := reflect.TypeOf(zero) + + return func(f reflect.Type, t reflect.Type, data any) (any, error) { + if f.Kind() != reflect.String { + return data, nil + } + + // Type checking with special case for net.IPNet + if expectedType == reflect.TypeOf((*net.IPNet)(nil)) { + expectedType = reflect.TypeOf(net.IPNet{}) + } + + if t != expectedType { + return data, nil + } + + return parseFunc(data.(string)) + } +} + +// ExactPrimitiveStringConvertible defines the constraint for primitive types that can be converted from strings. +type ExactPrimitiveStringConvertible interface { + int8 | uint8 | int16 | uint16 | int32 | uint32 | int64 | uint64 | + int | uint | float32 | float64 | bool | complex64 | complex128 +} + +// ExactStringConvertible defines the constraint for exact types (no tilde) that can be converted from strings. +// This is used by StringToHookFunc to prevent type alias compilation issues. +type ExactStringConvertible interface { + ExactPrimitiveStringConvertible | ComplexStringConvertible +} + +// StringToHookFunc is a generic decode hook for converting strings. +func StringToHookFunc[T ExactStringConvertible]() DecodeHookFunc { + return StringParserHookFunc(getParseFunc[T]()) +} + +// getParseFunc returns the appropriate parsing function for the given type T. +// This function encapsulates the type switch logic that determines which parser to use. +func getParseFunc[T ExactStringConvertible]() func(string) (T, error) { + var zero T + + switch any(zero).(type) { + case int8: + return genericParseWrapper[T](parseInt8) + case uint8: + return genericParseWrapper[T](parseUint8) + case int16: + return genericParseWrapper[T](parseInt16) + case uint16: + return genericParseWrapper[T](parseUint16) + case int32: + return genericParseWrapper[T](parseInt32) + case uint32: + return genericParseWrapper[T](parseUint32) + case int64: + return genericParseWrapper[T](parseInt64) + case uint64: + return genericParseWrapper[T](parseUint64) + case int: + return genericParseWrapper[T](parseInt) + case uint: + return genericParseWrapper[T](parseUint) + case float32: + return genericParseWrapper[T](parseFloat32) + case float64: + return genericParseWrapper[T](parseFloat64) + case bool: + return genericParseWrapper[T](parseBool) + case complex64: + return genericParseWrapper[T](parseComplex64) + case complex128: + return genericParseWrapper[T](parseComplex128) + case time.Duration: + return genericParseWrapper[T](parseDuration) + case *url.URL: + return genericParseWrapper[T](parseURL) + case net.IP: + return genericParseWrapper[T](parseIP) + case *net.IPNet: + return genericParseWrapper[T](parseIPNet) + case netip.Addr: + return genericParseWrapper[T](parseNetipAddr) + case netip.AddrPort: + return genericParseWrapper[T](parseNetipAddrPort) + case netip.Prefix: + return genericParseWrapper[T](parseNetipPrefix) + default: + // This should never happen due to the type constraint + panic("unsupported type for string conversion") + } +} + +// genericParseWrapper creates a generic wrapper for the specific parse functions +func genericParseWrapper[T StringConvertible, U any](parseFunc func(string) (U, error)) func(string) (T, error) { + return func(str string) (T, error) { + val, err := parseFunc(str) + return any(val).(T), err + } +} + +func parseInt8(str string) (int8, error) { + v, err := strconv.ParseInt(str, 0, 8) + return int8(v), wrapStrconvNumError(err) +} + +func parseUint8(str string) (uint8, error) { + v, err := strconv.ParseUint(str, 0, 8) + return uint8(v), wrapStrconvNumError(err) +} + +func parseInt16(str string) (int16, error) { + v, err := strconv.ParseInt(str, 0, 16) + return int16(v), wrapStrconvNumError(err) +} + +func parseUint16(str string) (uint16, error) { + v, err := strconv.ParseUint(str, 0, 16) + return uint16(v), wrapStrconvNumError(err) +} + +func parseInt32(str string) (int32, error) { + v, err := strconv.ParseInt(str, 0, 32) + return int32(v), wrapStrconvNumError(err) +} + +func parseUint32(str string) (uint32, error) { + v, err := strconv.ParseUint(str, 0, 32) + return uint32(v), wrapStrconvNumError(err) +} + +func parseInt64(str string) (int64, error) { + v, err := strconv.ParseInt(str, 0, 64) + return int64(v), wrapStrconvNumError(err) +} + +func parseUint64(str string) (uint64, error) { + v, err := strconv.ParseUint(str, 0, 64) + return uint64(v), wrapStrconvNumError(err) +} + +func parseInt(str string) (int, error) { + v, err := strconv.ParseInt(str, 0, 0) + return int(v), wrapStrconvNumError(err) +} + +func parseUint(str string) (uint, error) { + v, err := strconv.ParseUint(str, 0, 0) + return uint(v), wrapStrconvNumError(err) +} + +func parseFloat32(str string) (float32, error) { + v, err := strconv.ParseFloat(str, 32) + return float32(v), wrapStrconvNumError(err) +} + +func parseFloat64(str string) (float64, error) { + v, err := strconv.ParseFloat(str, 64) + return v, wrapStrconvNumError(err) +} + +func parseBool(str string) (bool, error) { + v, err := strconv.ParseBool(str) + return v, wrapStrconvNumError(err) +} + +func parseComplex64(str string) (complex64, error) { + v, err := strconv.ParseComplex(str, 64) + return complex64(v), wrapStrconvNumError(err) +} + +func parseComplex128(str string) (complex128, error) { + v, err := strconv.ParseComplex(str, 128) + return v, wrapStrconvNumError(err) +} + +func parseDuration(str string) (time.Duration, error) { + v, err := time.ParseDuration(str) + return v, wrapTimeParseDurationError(err) +} + +func parseURL(str string) (*url.URL, error) { + v, err := url.Parse(str) + return v, wrapUrlError(err) +} + +func parseIP(str string) (net.IP, error) { + v := net.ParseIP(str) + if v == nil { + return net.IP{}, fmt.Errorf("failed parsing ip") + } + return v, nil +} + +func parseIPNet(str string) (*net.IPNet, error) { + _, v, err := net.ParseCIDR(str) + return v, wrapNetParseError(err) +} + +func parseNetipAddr(str string) (netip.Addr, error) { + v, err := netip.ParseAddr(str) + return v, wrapNetIPParseAddrError(err) +} + +func parseNetipAddrPort(str string) (netip.AddrPort, error) { + v, err := netip.ParseAddrPort(str) + return v, wrapNetIPParseAddrPortError(err) +} + +func parseNetipPrefix(str string) (netip.Prefix, error) { + v, err := netip.ParsePrefix(str) + return v, wrapNetIPParsePrefixError(err) +} diff --git a/decode_hooks_string_test.go b/decode_hooks_string_test.go new file mode 100644 index 00000000..0d28ec1f --- /dev/null +++ b/decode_hooks_string_test.go @@ -0,0 +1 @@ +package mapstructure From 6968bf818083085c1efcd6fa81a22603b4ffb233 Mon Sep 17 00:00:00 2001 From: Mark Sagi-Kazar Date: Sat, 12 Jul 2025 11:49:39 +0200 Subject: [PATCH 2/3] test: add test to StringTo and StringParser hooks Signed-off-by: Mark Sagi-Kazar --- decode_hooks_string_test.go | 823 ++++++++++++++++++++++++++++++++++++ 1 file changed, 823 insertions(+) diff --git a/decode_hooks_string_test.go b/decode_hooks_string_test.go index 0d28ec1f..36d259da 100644 --- a/decode_hooks_string_test.go +++ b/decode_hooks_string_test.go @@ -1 +1,824 @@ package mapstructure + +import ( + "net" + "net/netip" + "net/url" + "reflect" + "strconv" + "strings" + "testing" + "time" +) + +func TestStringParserHookFunc(t *testing.T) { + t.Run("CustomInt32Parser", func(t *testing.T) { + customParser := func(s string) (int32, error) { + // Custom parser that multiplies by 2 + val, err := strconv.ParseInt(s, 10, 32) + if err != nil { + return 0, err + } + return int32(val * 2), nil + } + + hook := StringParserHookFunc(customParser) + + strValue := reflect.ValueOf("21") + int32Value := reflect.ValueOf(int32(0)) + + result, err := DecodeHookExec(hook, strValue, int32Value) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expected := int32(42) + if result != expected { + t.Fatalf("expected %v, got %v", expected, result) + } + }) + + t.Run("CustomStringToURL", func(t *testing.T) { + customParser := func(s string) (*url.URL, error) { + // Add https:// prefix if not present + if !strings.HasPrefix(s, "http://") && !strings.HasPrefix(s, "https://") { + s = "https://" + s + } + return url.Parse(s) + } + + hook := StringParserHookFunc(customParser) + + strValue := reflect.ValueOf("example.com") + urlValue := reflect.ValueOf(&url.URL{}) + + result, err := DecodeHookExec(hook, strValue, urlValue) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expected := &url.URL{Scheme: "https", Host: "example.com"} + if !reflect.DeepEqual(result, expected) { + t.Fatalf("expected %v, got %v", expected, result) + } + }) + + t.Run("NonStringSource", func(t *testing.T) { + hook := StringParserHookFunc(func(s string) (int32, error) { + val, err := strconv.ParseInt(s, 10, 32) + return int32(val), err + }) + + intValue := reflect.ValueOf(42) + int32Value := reflect.ValueOf(int32(0)) + + result, err := DecodeHookExec(hook, intValue, int32Value) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Should return original data unchanged + if result != 42 { + t.Fatalf("expected %v, got %v", 42, result) + } + }) + + t.Run("WrongTargetType", func(t *testing.T) { + hook := StringParserHookFunc(func(s string) (int32, error) { + val, err := strconv.ParseInt(s, 10, 32) + return int32(val), err + }) + + strValue := reflect.ValueOf("42") + int64Value := reflect.ValueOf(int64(0)) + + result, err := DecodeHookExec(hook, strValue, int64Value) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Should return original data unchanged + if result != "42" { + t.Fatalf("expected %v, got %v", "42", result) + } + }) + + t.Run("ParseError", func(t *testing.T) { + hook := StringParserHookFunc(func(s string) (int32, error) { + val, err := strconv.ParseInt(s, 10, 32) + return int32(val), err + }) + + strValue := reflect.ValueOf("not-a-number") + int32Value := reflect.ValueOf(int32(0)) + + _, err := DecodeHookExec(hook, strValue, int32Value) + if err == nil { + t.Fatal("expected error but got none") + } + }) + + t.Run("IPNetSpecialCase", func(t *testing.T) { + hook := StringParserHookFunc(func(s string) (*net.IPNet, error) { + _, ipnet, err := net.ParseCIDR(s) + return ipnet, err + }) + + strValue := reflect.ValueOf("192.168.1.0/24") + ipnetValue := reflect.ValueOf(net.IPNet{}) + + result, err := DecodeHookExec(hook, strValue, ipnetValue) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expectedIPNet := &net.IPNet{ + IP: net.IPv4(192, 168, 1, 0), + Mask: net.CIDRMask(24, 32), + } + + resultIPNet, ok := result.(*net.IPNet) + if !ok { + t.Fatalf("expected *net.IPNet, got %T", result) + } + + if !resultIPNet.IP.Equal(expectedIPNet.IP) || !reflect.DeepEqual(resultIPNet.Mask, expectedIPNet.Mask) { + t.Fatalf("expected %v, got %v", expectedIPNet, resultIPNet) + } + }) +} + +func TestStringToHookFunc(t *testing.T) { + t.Run("Int32", func(t *testing.T) { + hook := StringToHookFunc[int32]() + + int32Value := reflect.ValueOf(int32(0)) + + cases := []struct { + input string + expected int32 + hasError bool + }{ + {"42", 42, false}, + {"-42", -42, false}, + {"0", 0, false}, + {"0x2a", 42, false}, + {"052", 42, false}, + {"0b101010", 42, false}, + {"2147483647", 2147483647, false}, + {"-2147483648", -2147483648, false}, + {"2147483648", 0, true}, // overflow + {"-2147483649", 0, true}, // underflow + {"42.5", 0, true}, // float + {"not-a-number", 0, true}, + } + + for i, tc := range cases { + inputValue := reflect.ValueOf(tc.input) + result, err := DecodeHookExec(hook, inputValue, int32Value) + + if tc.hasError { + if err == nil { + t.Fatalf("case %d: expected error but got none", i) + } + continue + } + + if err != nil { + t.Fatalf("case %d: unexpected error: %v", i, err) + } + + if result != tc.expected { + t.Fatalf("case %d: expected %v, got %v", i, tc.expected, result) + } + } + }) + + t.Run("Float64", func(t *testing.T) { + hook := StringToHookFunc[float64]() + + float64Value := reflect.ValueOf(float64(0)) + + cases := []struct { + input string + expected float64 + hasError bool + }{ + {"42.5", 42.5, false}, + {"-42.5", -42.5, false}, + {"0", 0, false}, + {"0.0", 0.0, false}, + {"3.14159", 3.14159, false}, + {"1e10", 1e10, false}, + {"1.5e-10", 1.5e-10, false}, + {"not-a-number", 0, true}, + } + + for i, tc := range cases { + inputValue := reflect.ValueOf(tc.input) + result, err := DecodeHookExec(hook, inputValue, float64Value) + + if tc.hasError { + if err == nil { + t.Fatalf("case %d: expected error but got none", i) + } + continue + } + + if err != nil { + t.Fatalf("case %d: unexpected error: %v", i, err) + } + + if result != tc.expected { + t.Fatalf("case %d: expected %v, got %v", i, tc.expected, result) + } + } + }) + + t.Run("Bool", func(t *testing.T) { + hook := StringToHookFunc[bool]() + + boolValue := reflect.ValueOf(false) + + cases := []struct { + input string + expected bool + hasError bool + }{ + {"true", true, false}, + {"false", false, false}, + {"1", true, false}, + {"0", false, false}, + {"t", true, false}, + {"f", false, false}, + {"T", true, false}, + {"F", false, false}, + {"TRUE", true, false}, + {"FALSE", false, false}, + {"True", true, false}, + {"False", false, false}, + {"yes", false, true}, + {"no", false, true}, + {"invalid", false, true}, + } + + for i, tc := range cases { + inputValue := reflect.ValueOf(tc.input) + result, err := DecodeHookExec(hook, inputValue, boolValue) + + if tc.hasError { + if err == nil { + t.Fatalf("case %d: expected error but got none", i) + } + continue + } + + if err != nil { + t.Fatalf("case %d: unexpected error: %v", i, err) + } + + if result != tc.expected { + t.Fatalf("case %d: expected %v, got %v", i, tc.expected, result) + } + } + }) + + t.Run("Duration", func(t *testing.T) { + hook := StringToHookFunc[time.Duration]() + + durationValue := reflect.ValueOf(time.Duration(0)) + + cases := []struct { + input string + expected time.Duration + hasError bool + }{ + {"1h", time.Hour, false}, + {"30m", 30 * time.Minute, false}, + {"45s", 45 * time.Second, false}, + {"1h30m45s", time.Hour + 30*time.Minute + 45*time.Second, false}, + {"1000ms", time.Second, false}, + {"1000000us", time.Second, false}, + {"1000000000ns", time.Second, false}, + {"0", 0, false}, + {"invalid", 0, true}, + {"1", 0, true}, // missing unit + } + + for i, tc := range cases { + inputValue := reflect.ValueOf(tc.input) + result, err := DecodeHookExec(hook, inputValue, durationValue) + + if tc.hasError { + if err == nil { + t.Fatalf("case %d: expected error but got none", i) + } + continue + } + + if err != nil { + t.Fatalf("case %d: unexpected error: %v", i, err) + } + + if result != tc.expected { + t.Fatalf("case %d: expected %v, got %v", i, tc.expected, result) + } + } + }) + + t.Run("URL", func(t *testing.T) { + hook := StringToHookFunc[*url.URL]() + + urlValue := reflect.ValueOf(&url.URL{}) + + cases := []struct { + input string + expected *url.URL + hasError bool + }{ + { + "https://example.com", + &url.URL{Scheme: "https", Host: "example.com"}, + false, + }, + { + "http://example.com:8080/path?query=value", + &url.URL{ + Scheme: "http", + Host: "example.com:8080", + Path: "/path", + RawQuery: "query=value", + }, + false, + }, + { + "ftp://user:pass@example.com/file.txt", + &url.URL{ + Scheme: "ftp", + User: url.UserPassword("user", "pass"), + Host: "example.com", + Path: "/file.txt", + }, + false, + }, + { + "example.com", // relative URL + &url.URL{Path: "example.com"}, + false, + }, + } + + for i, tc := range cases { + inputValue := reflect.ValueOf(tc.input) + result, err := DecodeHookExec(hook, inputValue, urlValue) + + if tc.hasError { + if err == nil { + t.Fatalf("case %d: expected error but got none", i) + } + continue + } + + if err != nil { + t.Fatalf("case %d: unexpected error: %v", i, err) + } + + if !reflect.DeepEqual(result, tc.expected) { + t.Fatalf("case %d: expected %v, got %v", i, tc.expected, result) + } + } + }) + + t.Run("NetIP", func(t *testing.T) { + hook := StringToHookFunc[net.IP]() + + ipValue := reflect.ValueOf(net.IP{}) + + cases := []struct { + input string + expected net.IP + hasError bool + }{ + {"192.168.1.1", net.IPv4(192, 168, 1, 1), false}, + {"::1", net.IPv6loopback, false}, + {"2001:db8::1", net.ParseIP("2001:db8::1"), false}, + {"invalid-ip", net.IP{}, true}, + {"", net.IP{}, true}, + } + + for i, tc := range cases { + inputValue := reflect.ValueOf(tc.input) + result, err := DecodeHookExec(hook, inputValue, ipValue) + + if tc.hasError { + if err == nil { + t.Fatalf("case %d: expected error but got none", i) + } + continue + } + + if err != nil { + t.Fatalf("case %d: unexpected error: %v", i, err) + } + + if !reflect.DeepEqual(result, tc.expected) { + t.Fatalf("case %d: expected %v, got %v", i, tc.expected, result) + } + } + }) + + t.Run("NetIPNet", func(t *testing.T) { + hook := StringToHookFunc[*net.IPNet]() + + ipnetValue := reflect.ValueOf(net.IPNet{}) + + cases := []struct { + input string + hasError bool + }{ + {"192.168.1.0/24", false}, + {"10.0.0.0/8", false}, + {"2001:db8::/32", false}, + {"192.168.1.1", true}, // single IP, not CIDR + {"192.168.1.0/33", true}, // invalid mask + {"invalid", true}, + } + + for i, tc := range cases { + inputValue := reflect.ValueOf(tc.input) + result, err := DecodeHookExec(hook, inputValue, ipnetValue) + + if tc.hasError { + if err == nil { + t.Fatalf("case %d: expected error but got none", i) + } + continue + } + + if err != nil { + t.Fatalf("case %d: unexpected error: %v", i, err) + } + + // Verify it's a valid IPNet + if result == nil { + t.Fatalf("case %d: expected non-nil result", i) + } + + ipnet, ok := result.(*net.IPNet) + if !ok { + t.Fatalf("case %d: expected *net.IPNet, got %T", i, result) + } + + if ipnet.IP == nil || ipnet.Mask == nil { + t.Fatalf("case %d: invalid IPNet: %v", i, ipnet) + } + } + }) + + t.Run("NetipAddr", func(t *testing.T) { + hook := StringToHookFunc[netip.Addr]() + + addrValue := reflect.ValueOf(netip.Addr{}) + + cases := []struct { + input string + hasError bool + }{ + {"192.168.1.1", false}, + {"::1", false}, + {"2001:db8::1", false}, + {"invalid-ip", true}, + {"", true}, + } + + for i, tc := range cases { + inputValue := reflect.ValueOf(tc.input) + result, err := DecodeHookExec(hook, inputValue, addrValue) + + if tc.hasError { + if err == nil { + t.Fatalf("case %d: expected error but got none", i) + } + continue + } + + if err != nil { + t.Fatalf("case %d: unexpected error: %v", i, err) + } + + // Verify it's a valid netip.Addr + addr, ok := result.(netip.Addr) + if !ok { + t.Fatalf("case %d: expected netip.Addr, got %T", i, result) + } + + if !addr.IsValid() { + t.Fatalf("case %d: invalid netip.Addr: %v", i, addr) + } + } + }) + + t.Run("NetipAddrPort", func(t *testing.T) { + hook := StringToHookFunc[netip.AddrPort]() + + addrPortValue := reflect.ValueOf(netip.AddrPort{}) + + cases := []struct { + input string + hasError bool + }{ + {"192.168.1.1:8080", false}, + {"[::1]:8080", false}, + {"[2001:db8::1]:443", false}, + {"192.168.1.1", true}, // missing port + {"192.168.1.1:99999", true}, // invalid port + {"invalid", true}, + } + + for i, tc := range cases { + inputValue := reflect.ValueOf(tc.input) + result, err := DecodeHookExec(hook, inputValue, addrPortValue) + + if tc.hasError { + if err == nil { + t.Fatalf("case %d: expected error but got none", i) + } + continue + } + + if err != nil { + t.Fatalf("case %d: unexpected error: %v", i, err) + } + + // Verify it's a valid netip.AddrPort + addrPort, ok := result.(netip.AddrPort) + if !ok { + t.Fatalf("case %d: expected netip.AddrPort, got %T", i, result) + } + + if !addrPort.IsValid() { + t.Fatalf("case %d: invalid netip.AddrPort: %v", i, addrPort) + } + } + }) + + t.Run("NetipPrefix", func(t *testing.T) { + hook := StringToHookFunc[netip.Prefix]() + + prefixValue := reflect.ValueOf(netip.Prefix{}) + + cases := []struct { + input string + hasError bool + }{ + {"192.168.1.0/24", false}, + {"10.0.0.0/8", false}, + {"2001:db8::/32", false}, + {"192.168.1.1/32", false}, + {"192.168.1.0/33", true}, // invalid mask for IPv4 + {"2001:db8::/129", true}, // invalid mask for IPv6 + {"192.168.1.1", true}, // missing prefix length + {"invalid", true}, + } + + for i, tc := range cases { + inputValue := reflect.ValueOf(tc.input) + result, err := DecodeHookExec(hook, inputValue, prefixValue) + + if tc.hasError { + if err == nil { + t.Fatalf("case %d: expected error but got none", i) + } + continue + } + + if err != nil { + t.Fatalf("case %d: unexpected error: %v", i, err) + } + + // Verify it's a valid netip.Prefix + prefix, ok := result.(netip.Prefix) + if !ok { + t.Fatalf("case %d: expected netip.Prefix, got %T", i, result) + } + + if !prefix.IsValid() { + t.Fatalf("case %d: invalid netip.Prefix: %v", i, prefix) + } + } + }) + + t.Run("Complex64", func(t *testing.T) { + hook := StringToHookFunc[complex64]() + + complex64Value := reflect.ValueOf(complex64(0)) + + cases := []struct { + input string + expected complex64 + hasError bool + }{ + {"1+2i", complex64(1 + 2i), false}, + {"3-4i", complex64(3 - 4i), false}, + {"5", complex64(5 + 0i), false}, + {"0", complex64(0 + 0i), false}, + {"-1", complex64(-1 + 0i), false}, + {"0+1i", complex64(0 + 1i), false}, + {"0-1i", complex64(0 - 1i), false}, + {"invalid", complex64(0), true}, + {"1+", complex64(0), true}, + } + + for i, tc := range cases { + inputValue := reflect.ValueOf(tc.input) + result, err := DecodeHookExec(hook, inputValue, complex64Value) + + if tc.hasError { + if err == nil { + t.Fatalf("case %d: expected error but got none", i) + } + continue + } + + if err != nil { + t.Fatalf("case %d: unexpected error: %v", i, err) + } + + if result != tc.expected { + t.Fatalf("case %d: expected %v, got %v", i, tc.expected, result) + } + } + }) + + t.Run("NonStringSource", func(t *testing.T) { + hook := StringToHookFunc[int32]() + + intValue := reflect.ValueOf(42) + int32Value := reflect.ValueOf(int32(0)) + + result, err := DecodeHookExec(hook, intValue, int32Value) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Should return original data unchanged + if result != 42 { + t.Fatalf("expected %v, got %v", 42, result) + } + }) + + t.Run("WrongTargetType", func(t *testing.T) { + hook := StringToHookFunc[int32]() + + strValue := reflect.ValueOf("42") + int64Value := reflect.ValueOf(int64(0)) + + result, err := DecodeHookExec(hook, strValue, int64Value) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Should return original data unchanged + if result != "42" { + t.Fatalf("expected %v, got %v", "42", result) + } + }) +} + +func TestStringParserHookFuncWithTypeAlias(t *testing.T) { + // Test with type alias to ensure tilde types work correctly + type MyInt int32 + + customParser := func(s string) (MyInt, error) { + val, err := strconv.ParseInt(s, 10, 32) + if err != nil { + return 0, err + } + return MyInt(val), nil + } + + hook := StringParserHookFunc(customParser) + + strValue := reflect.ValueOf("42") + myIntValue := reflect.ValueOf(MyInt(0)) + + result, err := DecodeHookExec(hook, strValue, myIntValue) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expected := MyInt(42) + if result != expected { + t.Fatalf("expected %v, got %v", expected, result) + } +} + +func TestStringToHookFuncEdgeCases(t *testing.T) { + t.Run("UintOverflow", func(t *testing.T) { + hook := StringToHookFunc[uint8]() + uintValue := reflect.ValueOf(uint8(0)) + + cases := []struct { + input string + hasError bool + }{ + {"0", false}, + {"255", false}, + {"256", true}, // overflow + {"-1", true}, // negative + } + + for i, tc := range cases { + inputValue := reflect.ValueOf(tc.input) + _, err := DecodeHookExec(hook, inputValue, uintValue) + + if tc.hasError && err == nil { + t.Fatalf("case %d: expected error but got none", i) + } + if !tc.hasError && err != nil { + t.Fatalf("case %d: unexpected error: %v", i, err) + } + } + }) + + t.Run("EmptyStringHandling", func(t *testing.T) { + t.Run("Int", func(t *testing.T) { + hook := StringToHookFunc[int]() + intValue := reflect.ValueOf(int(0)) + + inputValue := reflect.ValueOf("") + _, err := DecodeHookExec(hook, inputValue, intValue) + if err == nil { + t.Fatal("expected error for empty string") + } + }) + + t.Run("Bool", func(t *testing.T) { + hook := StringToHookFunc[bool]() + boolValue := reflect.ValueOf(false) + + inputValue := reflect.ValueOf("") + _, err := DecodeHookExec(hook, inputValue, boolValue) + if err == nil { + t.Fatal("expected error for empty string") + } + }) + + t.Run("URL", func(t *testing.T) { + hook := StringToHookFunc[*url.URL]() + urlValue := reflect.ValueOf(&url.URL{}) + + inputValue := reflect.ValueOf("") + result, err := DecodeHookExec(hook, inputValue, urlValue) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Empty string should parse to empty URL + expected := &url.URL{} + if !reflect.DeepEqual(result, expected) { + t.Fatalf("expected %v, got %v", expected, result) + } + }) + }) + + t.Run("AllNumericTypes", func(t *testing.T) { + // Test all supported numeric types work correctly + testCases := []struct { + name string + hookFunc DecodeHookFunc + target reflect.Value + input string + expected interface{} + }{ + {"int8", StringToHookFunc[int8](), reflect.ValueOf(int8(0)), "42", int8(42)}, + {"uint8", StringToHookFunc[uint8](), reflect.ValueOf(uint8(0)), "42", uint8(42)}, + {"int16", StringToHookFunc[int16](), reflect.ValueOf(int16(0)), "42", int16(42)}, + {"uint16", StringToHookFunc[uint16](), reflect.ValueOf(uint16(0)), "42", uint16(42)}, + {"int32", StringToHookFunc[int32](), reflect.ValueOf(int32(0)), "42", int32(42)}, + {"uint32", StringToHookFunc[uint32](), reflect.ValueOf(uint32(0)), "42", uint32(42)}, + {"int64", StringToHookFunc[int64](), reflect.ValueOf(int64(0)), "42", int64(42)}, + {"uint64", StringToHookFunc[uint64](), reflect.ValueOf(uint64(0)), "42", uint64(42)}, + {"int", StringToHookFunc[int](), reflect.ValueOf(int(0)), "42", int(42)}, + {"uint", StringToHookFunc[uint](), reflect.ValueOf(uint(0)), "42", uint(42)}, + {"float32", StringToHookFunc[float32](), reflect.ValueOf(float32(0)), "42.5", float32(42.5)}, + {"float64", StringToHookFunc[float64](), reflect.ValueOf(float64(0)), "42.5", float64(42.5)}, + {"complex64", StringToHookFunc[complex64](), reflect.ValueOf(complex64(0)), "1+2i", complex64(1 + 2i)}, + {"complex128", StringToHookFunc[complex128](), reflect.ValueOf(complex128(0)), "1+2i", complex128(1 + 2i)}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + inputValue := reflect.ValueOf(tc.input) + result, err := DecodeHookExec(tc.hookFunc, inputValue, tc.target) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result != tc.expected { + t.Fatalf("expected %v (%T), got %v (%T)", tc.expected, tc.expected, result, result) + } + }) + } + }) +} From 1ad4e4e1763f3d85b30ef1a997a907c3ff327508 Mon Sep 17 00:00:00 2001 From: Mark Sagi-Kazar Date: Tue, 9 Dec 2025 18:35:29 +0100 Subject: [PATCH 3/3] chore: fix lint violation Signed-off-by: Mark Sagi-Kazar --- decode_hooks_string_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/decode_hooks_string_test.go b/decode_hooks_string_test.go index 36d259da..198f322d 100644 --- a/decode_hooks_string_test.go +++ b/decode_hooks_string_test.go @@ -789,7 +789,7 @@ func TestStringToHookFuncEdgeCases(t *testing.T) { hookFunc DecodeHookFunc target reflect.Value input string - expected interface{} + expected any }{ {"int8", StringToHookFunc[int8](), reflect.ValueOf(int8(0)), "42", int8(42)}, {"uint8", StringToHookFunc[uint8](), reflect.ValueOf(uint8(0)), "42", uint8(42)},