diff --git a/deep.go b/copy.go similarity index 73% rename from deep.go rename to copy.go index 82ffe18..672bcf4 100644 --- a/deep.go +++ b/copy.go @@ -5,24 +5,31 @@ import ( "reflect" ) +// Copier is an interface that types can implement to provide their own +// custom deep copy logic. The type T in Copy() (T, error) must be the +// same concrete type as the receiver that implements this interface. +type Copier[T any] interface { + Copy() (T, error) +} + // Copy creates a deep copy of src. It returns the copy and a nil error in case // of success and the zero value for the type and a non-nil error on failure. func Copy[T any](src T) (T, error) { - return copy(src, false) + return copyInternal(src, false) } // CopySkipUnsupported creates a deep copy of src. It returns the copy and a nil -// errorin case of success and the zero value for the type and a non-nil error +// error in case of success and the zero value for the type and a non-nil error // on failure. Unsupported types are skipped (the copy will have the zero value // for the type) instead of returning an error. func CopySkipUnsupported[T any](src T) (T, error) { - return copy(src, true) + return copyInternal(src, true) } // MustCopy creates a deep copy of src. It returns the copy on success or panics // in case of any failure. func MustCopy[T any](src T) T { - dst, err := copy(src, false) + dst, err := copyInternal(src, false) if err != nil { panic(err) } @@ -36,17 +43,58 @@ type pointersMapKey struct { } type pointersMap map[pointersMapKey]reflect.Value -func copy[T any](src T, skipUnsupported bool) (T, error) { +func copyInternal[T any](src T, skipUnsupported bool) (T, error) { v := reflect.ValueOf(src) - // We might have a zero value, so we check for this here otherwise - // calling interface below will panic. - if v.Kind() == reflect.Invalid { + // If src is the zero value for its type (e.g. an uninitialized interface, + // or if T is 'any' and src is its zero value), v will be invalid. + if !v.IsValid() { // This amounts to returning the zero value for T. var t T return t, nil } + // Attempt to use Copier interface if src is suitable: + // - A value type (struct, int, etc.) + // - A non-nil pointer type + // - A non-nil interface type + // This logic avoids trying to call Copy() on a nil receiver if T itself + // is a pointer or interface type that is nil. + attemptCopier := false + srcKind := v.Kind() + if srcKind != reflect.Interface && srcKind != reflect.Ptr { + attemptCopier = true + } else { + // Pointers or interface types are candidates only if they are not nil + if !v.IsNil() { + attemptCopier = true + } + } + + if attemptCopier { + srcType := v.Type() + + // If T is an interface or pointer type, converting src to 'any' is generally + // non-allocating for src's underlying data. + if srcKind == reflect.Interface || srcKind == reflect.Ptr { + if copier, ok := any(src).(Copier[T]); ok { + return copier.Copy() + } + } else { + // T is a value type (e.g. struct, array, basic type). + // The any(src) conversion might allocate. + // Check Implements first to avoid this allocation if T doesn't implement Copier[T]. + copierInterfaceType := reflect.TypeOf((*Copier[T])(nil)).Elem() + if srcType.Implements(copierInterfaceType) { + // T implements Copier[T]. Now the type assertion (and potential allocation) + // is justified as we expect to call the custom method. + if copier, ok := any(src).(Copier[T]); ok { + return copier.Copy() + } + } + } + } + dst, err := recursiveCopy(v, make(pointersMap), skipUnsupported) if err != nil { diff --git a/deep_test.go b/copy_test.go similarity index 65% rename from deep_test.go rename to copy_test.go index 64d553d..0990621 100644 --- a/deep_test.go +++ b/copy_test.go @@ -1,6 +1,7 @@ package deep import ( + "fmt" "reflect" "testing" "unsafe" @@ -355,3 +356,119 @@ func TestTrickyMemberPointer(t *testing.T) { doCopyAndCheck(t, bar, false) } + +type CustomTypeForCopier struct { + Value int + F func() // Normally unsupported if non-nil. +} + +var ( + customTypeCopyCalled bool + customTypeCopyErrored bool + customPtrTypeCopyCalled bool +) + +func (ct CustomTypeForCopier) Copy() (CustomTypeForCopier, error) { + customTypeCopyCalled = true + if ct.F != nil && ct.Value == -1 { // Special case to return error + customTypeCopyErrored = true + return CustomTypeForCopier{}, fmt.Errorf("custom copy error for F") + } + // Example custom logic: double value, share function pointer + return CustomTypeForCopier{Value: ct.Value * 2, F: ct.F}, nil +} + +func TestCopy_CustomCopier_ValueReceiver(t *testing.T) { + customTypeCopyCalled = false + customTypeCopyErrored = false + src := CustomTypeForCopier{Value: 10, F: func() {}} + + dst, err := Copy(src) + + if err != nil { + t.Fatalf("Copy failed for CustomCopier: %v", err) + } + if !customTypeCopyCalled { + t.Errorf("Custom Copier method was not called") + } + if customTypeCopyErrored { + t.Errorf("Custom Copier method unexpectedly errored") + } + if dst.Value != 20 { // As per custom logic + t.Errorf("Expected dst.Value to be 20, got %d", dst.Value) + } + if reflect.ValueOf(dst.F).Pointer() != reflect.ValueOf(src.F).Pointer() { + t.Errorf("Expected func to be shallow copied (shared) by custom copier") + } +} + +func TestCopy_CustomCopier_ErrorCase(t *testing.T) { + customTypeCopyCalled = false + customTypeCopyErrored = false + // Trigger error condition in custom copier + src := CustomTypeForCopier{Value: -1, F: func() {}} + + _, err := Copy(src) + + if err == nil { + t.Fatalf("Expected error from custom copier, got nil") + } + if !customTypeCopyCalled { + t.Errorf("Custom Copier method was not called (for error case)") + } + if !customTypeCopyErrored { + t.Errorf("Custom Copier method did not flag error internally") + } + expectedErrorMsg := "custom copy error for F" + if err.Error() != expectedErrorMsg { + t.Errorf("Expected error message '%s', got '%s'", expectedErrorMsg, err.Error()) + } +} + +type CustomPtrTypeForCopier struct { + Value int +} + +func (cpt *CustomPtrTypeForCopier) Copy() (*CustomPtrTypeForCopier, error) { + customPtrTypeCopyCalled = true + if cpt == nil { + // This case should ideally not be hit if the main Copy function guards against it. + return nil, fmt.Errorf("custom Copy() called on nil CustomPtrTypeForCopier receiver") + } + return &CustomPtrTypeForCopier{Value: cpt.Value * 3}, nil +} + +func TestCopy_CustomCopier_PointerReceiver(t *testing.T) { + customPtrTypeCopyCalled = false + src := &CustomPtrTypeForCopier{Value: 5} // T is *CustomPtrTypeForCopier + + dst, err := Copy(src) + + if err != nil { + t.Fatalf("Copy failed for CustomCopier with pointer receiver: %v", err) + } + if !customPtrTypeCopyCalled { + t.Errorf("Custom Copier method (ptr receiver) was not called") + } + if dst.Value != 15 { + t.Errorf("Expected dst.Value to be 15, got %d", dst.Value) + } + if dst == src { + t.Errorf("Expected a new pointer from custom copier, got the same pointer") + } + + // Test that a nil pointer of a type that implements Copier still results in a nil copy + // and does not call the custom Copy method. + customPtrTypeCopyCalled = false + var nilSrc *CustomPtrTypeForCopier + dstNil, errNil := Copy(nilSrc) + if errNil != nil { + t.Fatalf("Copy failed for nil CustomPtrTypeForCopier: %v", errNil) + } + if customPtrTypeCopyCalled { + t.Errorf("Custom Copier method (ptr receiver) was called for nil input, but should not have been") + } + if dstNil != nil { + t.Errorf("Expected nil for copied nil pointer of custom type, got %v", dstNil) + } +}