Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 158 additions & 51 deletions diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,43 +8,81 @@ import (

type sbuf []string

func (s *sbuf) Write(b []byte) (int, error) {
*s = append(*s, string(b))
return len(b), nil
func (p *sbuf) Printf(format string, a ...interface{}) {
s := fmt.Sprintf(format, a...)
*p = append(*p, s)
}

// Diff returns a slice where each element describes
// a difference between a and b.
func Diff(a, b interface{}) (desc []string) {
Fdiff((*sbuf)(&desc), a, b)
Pdiff((*sbuf)(&desc), a, b)
return desc
}

// wprintfer calls Fprintf on w for each Printf call
// with a trailing newline.
type wprintfer struct{ w io.Writer }

func (p *wprintfer) Printf(format string, a ...interface{}) {
fmt.Fprintf(p.w, format+"\n", a...)
}

// Fdiff writes to w a description of the differences between a and b.
func Fdiff(w io.Writer, a, b interface{}) {
diffWriter{w: w}.diff(reflect.ValueOf(a), reflect.ValueOf(b))
Pdiff(&wprintfer{w}, a, b)
}

type Printfer interface {
Printf(format string, a ...interface{})
}

// Pdiff prints to p a description of the differences between a and b.
// It calls Printf once for each difference, with no trailing newline.
// The standard library log.Logger is a Printfer.
func Pdiff(p Printfer, a, b interface{}) {
diffPrinter{w: p}.diff(reflect.ValueOf(a), reflect.ValueOf(b))
}

type Logfer interface {
Logf(format string, a ...interface{})
}

type diffWriter struct {
w io.Writer
// logprintfer calls Fprintf on w for each Printf call
// with a trailing newline.
type logprintfer struct{ l Logfer }

func (p *logprintfer) Printf(format string, a ...interface{}) {
p.l.Logf(format, a...)
}

// Ldiff prints to l a description of the differences between a and b.
// It calls Logf once for each difference, with no trailing newline.
// The standard library testing.T and testing.B are Logfers.
func Ldiff(l Logfer, a, b interface{}) {
Pdiff(&logprintfer{l}, a, b)
}

type diffPrinter struct {
w Printfer
l string // label
}

func (w diffWriter) printf(f string, a ...interface{}) {
func (w diffPrinter) printf(f string, a ...interface{}) {
var l string
if w.l != "" {
l = w.l + ": "
}
fmt.Fprintf(w.w, l+f, a...)
w.w.Printf(l+f, a...)
}

func (w diffWriter) diff(av, bv reflect.Value) {
func (w diffPrinter) diff(av, bv reflect.Value) {
if !av.IsValid() && bv.IsValid() {
w.printf("nil != %#v", bv.Interface())
w.printf("nil != %# v", formatter{v: bv, quote: true})
return
}
if av.IsValid() && !bv.IsValid() {
w.printf("%#v != nil", av.Interface())
w.printf("%# v != nil", formatter{v: av, quote: true})
return
}
if !av.IsValid() && !bv.IsValid() {
Expand All @@ -58,34 +96,61 @@ func (w diffWriter) diff(av, bv reflect.Value) {
return
}

// numeric types, including bool
if at.Kind() < reflect.Array {
a, b := av.Interface(), bv.Interface()
if a != b {
w.printf("%#v != %#v", a, b)
switch kind := at.Kind(); kind {
case reflect.Bool:
if a, b := av.Bool(), bv.Bool(); a != b {
w.printf("%v != %v", a, b)
}
return
}

switch at.Kind() {
case reflect.String:
a, b := av.Interface(), bv.Interface()
if a != b {
w.printf("%q != %q", a, b)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if a, b := av.Int(), bv.Int(); a != b {
w.printf("%d != %d", a, b)
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
if a, b := av.Uint(), bv.Uint(); a != b {
w.printf("%d != %d", a, b)
}
case reflect.Float32, reflect.Float64:
if a, b := av.Float(), bv.Float(); a != b {
w.printf("%v != %v", a, b)
}
case reflect.Complex64, reflect.Complex128:
if a, b := av.Complex(), bv.Complex(); a != b {
w.printf("%v != %v", a, b)
}
case reflect.Array:
n := av.Len()
for i := 0; i < n; i++ {
w.relabel(fmt.Sprintf("[%d]", i)).diff(av.Index(i), bv.Index(i))
}
case reflect.Chan, reflect.Func, reflect.UnsafePointer:
if a, b := av.Pointer(), bv.Pointer(); a != b {
w.printf("%#x != %#x", a, b)
}
case reflect.Interface:
w.diff(av.Elem(), bv.Elem())
case reflect.Map:
ak, both, bk := keyDiff(av.MapKeys(), bv.MapKeys())
for _, k := range ak {
w := w.relabel(fmt.Sprintf("[%#v]", k))
w.printf("%q != (missing)", av.MapIndex(k))
}
for _, k := range both {
w := w.relabel(fmt.Sprintf("[%#v]", k))
w.diff(av.MapIndex(k), bv.MapIndex(k))
}
for _, k := range bk {
w := w.relabel(fmt.Sprintf("[%#v]", k))
w.printf("(missing) != %q", bv.MapIndex(k))
}
case reflect.Ptr:
switch {
case av.IsNil() && !bv.IsNil():
w.printf("nil != %v", bv.Interface())
w.printf("nil != %# v", formatter{v: bv, quote: true})
case !av.IsNil() && bv.IsNil():
w.printf("%v != nil", av.Interface())
w.printf("%# v != nil", formatter{v: av, quote: true})
case !av.IsNil() && !bv.IsNil():
w.diff(av.Elem(), bv.Elem())
}
case reflect.Struct:
for i := 0; i < av.NumField(); i++ {
w.relabel(at.Field(i).Name).diff(av.Field(i), bv.Field(i))
}
case reflect.Slice:
lenA := av.Len()
lenB := bv.Len()
Expand All @@ -96,30 +161,20 @@ func (w diffWriter) diff(av, bv reflect.Value) {
for i := 0; i < lenA; i++ {
w.relabel(fmt.Sprintf("[%d]", i)).diff(av.Index(i), bv.Index(i))
}
case reflect.Map:
ak, both, bk := keyDiff(av.MapKeys(), bv.MapKeys())
for _, k := range ak {
w := w.relabel(fmt.Sprintf("[%#v]", k.Interface()))
w.printf("%q != (missing)", av.MapIndex(k))
}
for _, k := range both {
w := w.relabel(fmt.Sprintf("[%#v]", k.Interface()))
w.diff(av.MapIndex(k), bv.MapIndex(k))
case reflect.String:
if a, b := av.String(), bv.String(); a != b {
w.printf("%q != %q", a, b)
}
for _, k := range bk {
w := w.relabel(fmt.Sprintf("[%#v]", k.Interface()))
w.printf("(missing) != %q", bv.MapIndex(k))
case reflect.Struct:
for i := 0; i < av.NumField(); i++ {
w.relabel(at.Field(i).Name).diff(av.Field(i), bv.Field(i))
}
case reflect.Interface:
w.diff(reflect.ValueOf(av.Interface()), reflect.ValueOf(bv.Interface()))
default:
if !reflect.DeepEqual(av.Interface(), bv.Interface()) {
w.printf("%# v != %# v", Formatter(av.Interface()), Formatter(bv.Interface()))
}
panic("unknown reflect Kind: " + kind.String())
}
}

func (d diffWriter) relabel(name string) (d1 diffWriter) {
func (d diffPrinter) relabel(name string) (d1 diffPrinter) {
d1 = d
if d.l != "" && name[0] != '[' {
d1.l += "."
Expand All @@ -128,11 +183,63 @@ func (d diffWriter) relabel(name string) (d1 diffWriter) {
return d1
}

// keyEqual compares a and b for equality.
// Both a and b must be valid map keys.
func keyEqual(av, bv reflect.Value) bool {
if !av.IsValid() && !bv.IsValid() {
return true
}
if !av.IsValid() || !bv.IsValid() || av.Type() != bv.Type() {
return false
}
switch kind := av.Kind(); kind {
case reflect.Bool:
a, b := av.Bool(), bv.Bool()
return a == b
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
a, b := av.Int(), bv.Int()
return a == b
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
a, b := av.Uint(), bv.Uint()
return a == b
case reflect.Float32, reflect.Float64:
a, b := av.Float(), bv.Float()
return a == b
case reflect.Complex64, reflect.Complex128:
a, b := av.Complex(), bv.Complex()
return a == b
case reflect.Array:
for i := 0; i < av.Len(); i++ {
if !keyEqual(av.Index(i), bv.Index(i)) {
return false
}
}
return true
case reflect.Chan, reflect.UnsafePointer, reflect.Ptr:
a, b := av.Pointer(), bv.Pointer()
return a == b
case reflect.Interface:
return keyEqual(av.Elem(), bv.Elem())
case reflect.String:
a, b := av.String(), bv.String()
return a == b
case reflect.Struct:
for i := 0; i < av.NumField(); i++ {
if !keyEqual(av.Field(i), bv.Field(i)) {
return false
}
}
return true
default:
panic("invalid map key type " + av.Type().String())
}
}

func keyDiff(a, b []reflect.Value) (ak, both, bk []reflect.Value) {
for _, av := range a {
inBoth := false
for _, bv := range b {
if reflect.DeepEqual(av.Interface(), bv.Interface()) {
if keyEqual(av, bv) {
inBoth = true
both = append(both, av)
break
Expand All @@ -145,7 +252,7 @@ func keyDiff(a, b []reflect.Value) (ak, both, bk []reflect.Value) {
for _, bv := range b {
inBoth := false
for _, av := range a {
if reflect.DeepEqual(av.Interface(), bv.Interface()) {
if keyEqual(av, bv) {
inBoth = true
break
}
Expand Down
Loading