diff --git a/mapstructure.go b/mapstructure.go index 60311d6b..52c58575 100644 --- a/mapstructure.go +++ b/mapstructure.go @@ -218,6 +218,11 @@ type DecodeHookFuncKind func(reflect.Kind, reflect.Kind, interface{}) (interface // values. type DecodeHookFuncValue func(from reflect.Value, to reflect.Value) (interface{}, error) +// ValidateHookFunc is the callback function that can be used for +// post-decoding transformations. See "ValidateHook" in the DecoderConfig +// struct. +type ValidateHookFunc func(reflect.Value) error + // DecoderConfig is the configuration that is used to create a new decoder // and allows customization of various aspects of decoding. type DecoderConfig struct { @@ -232,6 +237,16 @@ type DecoderConfig struct { // If an error is returned, the entire decode will fail with that error. DecodeHook DecodeHookFunc + // ValidateHook, if set, will be called after decoding is complete. + // This is useful for types that need a finalization step or validation. + // The ValidateHook is called for every map and value in the input. This + // means that if a struct has embedded fields with squash tags the post + // decode hook is called only once with all of the input data, not once + // for each embedded struct. + // + // If an error is returned, the entire decode will fail with that error. + ValidateHook ValidateHookFunc + // If ErrorUnused is true, then it is an error for there to exist // keys in the original map that were unused in the decoding process // (extra keys). @@ -572,6 +587,13 @@ func (d *Decoder) decode(name string, input interface{}, outVal reflect.Value) e d.config.Metadata.Keys = append(d.config.Metadata.Keys, name) } + // If we have a post-decode hook, then we call it now. + if d.config.ValidateHook != nil { + if err := d.config.ValidateHook(outVal); err != nil { + return fmt.Errorf("error validating '%s': %w", name, err) + } + } + return err } diff --git a/mapstructure_test.go b/mapstructure_test.go index 248afaa8..1cf9f3e2 100644 --- a/mapstructure_test.go +++ b/mapstructure_test.go @@ -1334,6 +1334,51 @@ func TestDecode_DecodeHookType(t *testing.T) { } } +func TestDecode_ValidateHook(t *testing.T) { + t.Parallel() + + input1 := map[string]interface{}{ + "Vint": 42, + } + input2 := map[string]interface{}{ + "Vint": 43, + } + + validateHook := func(val reflect.Value) error { + iface, ok := val.Interface().(Basic) + if !ok { + return nil + } + + if iface.Vint != 42 { + return errors.New("vint should not be 42") + } + + return nil + } + + var result Basic + config := &DecoderConfig{ + ValidateHook: validateHook, + Result: &result, + } + + decoder, err := NewDecoder(config) + if err != nil { + t.Fatalf("err: %s", err) + } + + err = decoder.Decode(input1) + if err != nil { + t.Fatalf("got an err: %s", err) + } + + err = decoder.Decode(input2) + if err == nil { + t.Fatal("expected an error") + } +} + func TestDecode_Nil(t *testing.T) { t.Parallel()