Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 56 additions & 8 deletions deep.go → copy.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,31 @@ import (
"reflect"
)

// Copier is an interface that types can implement to provide their own
// custom deep copy logic. The type T in Copy() (T, error) must be the
// same concrete type as the receiver that implements this interface.
type Copier[T any] interface {
Copy() (T, error)
}

// Copy creates a deep copy of src. It returns the copy and a nil error in case
// of success and the zero value for the type and a non-nil error on failure.
func Copy[T any](src T) (T, error) {
return copy(src, false)
return copyInternal(src, false)
}

// CopySkipUnsupported creates a deep copy of src. It returns the copy and a nil
// errorin case of success and the zero value for the type and a non-nil error
// error in case of success and the zero value for the type and a non-nil error
// on failure. Unsupported types are skipped (the copy will have the zero value
// for the type) instead of returning an error.
func CopySkipUnsupported[T any](src T) (T, error) {
return copy(src, true)
return copyInternal(src, true)
}

// MustCopy creates a deep copy of src. It returns the copy on success or panics
// in case of any failure.
func MustCopy[T any](src T) T {
dst, err := copy(src, false)
dst, err := copyInternal(src, false)
if err != nil {
panic(err)
}
Expand All @@ -36,17 +43,58 @@ type pointersMapKey struct {
}
type pointersMap map[pointersMapKey]reflect.Value

func copy[T any](src T, skipUnsupported bool) (T, error) {
func copyInternal[T any](src T, skipUnsupported bool) (T, error) {
v := reflect.ValueOf(src)

// We might have a zero value, so we check for this here otherwise
// calling interface below will panic.
if v.Kind() == reflect.Invalid {
// If src is the zero value for its type (e.g. an uninitialized interface,
// or if T is 'any' and src is its zero value), v will be invalid.
if !v.IsValid() {
// This amounts to returning the zero value for T.
var t T
return t, nil
}

// Attempt to use Copier interface if src is suitable:
// - A value type (struct, int, etc.)
// - A non-nil pointer type
// - A non-nil interface type
// This logic avoids trying to call Copy() on a nil receiver if T itself
// is a pointer or interface type that is nil.
attemptCopier := false
srcKind := v.Kind()
if srcKind != reflect.Interface && srcKind != reflect.Ptr {
attemptCopier = true
} else {
// Pointers or interface types are candidates only if they are not nil
if !v.IsNil() {
attemptCopier = true
}
}

if attemptCopier {
srcType := v.Type()

// If T is an interface or pointer type, converting src to 'any' is generally
// non-allocating for src's underlying data.
if srcKind == reflect.Interface || srcKind == reflect.Ptr {
if copier, ok := any(src).(Copier[T]); ok {
return copier.Copy()
}
} else {
// T is a value type (e.g. struct, array, basic type).
// The any(src) conversion might allocate.
// Check Implements first to avoid this allocation if T doesn't implement Copier[T].
copierInterfaceType := reflect.TypeOf((*Copier[T])(nil)).Elem()
if srcType.Implements(copierInterfaceType) {
// T implements Copier[T]. Now the type assertion (and potential allocation)
// is justified as we expect to call the custom method.
if copier, ok := any(src).(Copier[T]); ok {
return copier.Copy()
}
}
}
}

dst, err := recursiveCopy(v, make(pointersMap),
skipUnsupported)
if err != nil {
Expand Down
117 changes: 117 additions & 0 deletions deep_test.go → copy_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package deep

import (
"fmt"
"reflect"
"testing"
"unsafe"
Expand Down Expand Up @@ -355,3 +356,119 @@ func TestTrickyMemberPointer(t *testing.T) {

doCopyAndCheck(t, bar, false)
}

type CustomTypeForCopier struct {
Value int
F func() // Normally unsupported if non-nil.
}

var (
customTypeCopyCalled bool
customTypeCopyErrored bool
customPtrTypeCopyCalled bool
)

func (ct CustomTypeForCopier) Copy() (CustomTypeForCopier, error) {
customTypeCopyCalled = true
if ct.F != nil && ct.Value == -1 { // Special case to return error
customTypeCopyErrored = true
return CustomTypeForCopier{}, fmt.Errorf("custom copy error for F")
}
// Example custom logic: double value, share function pointer
return CustomTypeForCopier{Value: ct.Value * 2, F: ct.F}, nil
}

func TestCopy_CustomCopier_ValueReceiver(t *testing.T) {
customTypeCopyCalled = false
customTypeCopyErrored = false
src := CustomTypeForCopier{Value: 10, F: func() {}}

dst, err := Copy(src)

if err != nil {
t.Fatalf("Copy failed for CustomCopier: %v", err)
}
if !customTypeCopyCalled {
t.Errorf("Custom Copier method was not called")
}
if customTypeCopyErrored {
t.Errorf("Custom Copier method unexpectedly errored")
}
if dst.Value != 20 { // As per custom logic
t.Errorf("Expected dst.Value to be 20, got %d", dst.Value)
}
if reflect.ValueOf(dst.F).Pointer() != reflect.ValueOf(src.F).Pointer() {
t.Errorf("Expected func to be shallow copied (shared) by custom copier")
}
}

func TestCopy_CustomCopier_ErrorCase(t *testing.T) {
customTypeCopyCalled = false
customTypeCopyErrored = false
// Trigger error condition in custom copier
src := CustomTypeForCopier{Value: -1, F: func() {}}

_, err := Copy(src)

if err == nil {
t.Fatalf("Expected error from custom copier, got nil")
}
if !customTypeCopyCalled {
t.Errorf("Custom Copier method was not called (for error case)")
}
if !customTypeCopyErrored {
t.Errorf("Custom Copier method did not flag error internally")
}
expectedErrorMsg := "custom copy error for F"
if err.Error() != expectedErrorMsg {
t.Errorf("Expected error message '%s', got '%s'", expectedErrorMsg, err.Error())
}
}

type CustomPtrTypeForCopier struct {
Value int
}

func (cpt *CustomPtrTypeForCopier) Copy() (*CustomPtrTypeForCopier, error) {
customPtrTypeCopyCalled = true
if cpt == nil {
// This case should ideally not be hit if the main Copy function guards against it.
return nil, fmt.Errorf("custom Copy() called on nil CustomPtrTypeForCopier receiver")
}
return &CustomPtrTypeForCopier{Value: cpt.Value * 3}, nil
}

func TestCopy_CustomCopier_PointerReceiver(t *testing.T) {
customPtrTypeCopyCalled = false
src := &CustomPtrTypeForCopier{Value: 5} // T is *CustomPtrTypeForCopier

dst, err := Copy(src)

if err != nil {
t.Fatalf("Copy failed for CustomCopier with pointer receiver: %v", err)
}
if !customPtrTypeCopyCalled {
t.Errorf("Custom Copier method (ptr receiver) was not called")
}
if dst.Value != 15 {
t.Errorf("Expected dst.Value to be 15, got %d", dst.Value)
}
if dst == src {
t.Errorf("Expected a new pointer from custom copier, got the same pointer")
}

// Test that a nil pointer of a type that implements Copier still results in a nil copy
// and does not call the custom Copy method.
customPtrTypeCopyCalled = false
var nilSrc *CustomPtrTypeForCopier
dstNil, errNil := Copy(nilSrc)
if errNil != nil {
t.Fatalf("Copy failed for nil CustomPtrTypeForCopier: %v", errNil)
}
if customPtrTypeCopyCalled {
t.Errorf("Custom Copier method (ptr receiver) was called for nil input, but should not have been")
}
if dstNil != nil {
t.Errorf("Expected nil for copied nil pointer of custom type, got %v", dstNil)
}
}