diff --git a/diff.go b/diff.go index 8fe8e24..6aa7f74 100644 --- a/diff.go +++ b/diff.go @@ -8,43 +8,81 @@ import ( type sbuf []string -func (s *sbuf) Write(b []byte) (int, error) { - *s = append(*s, string(b)) - return len(b), nil +func (p *sbuf) Printf(format string, a ...interface{}) { + s := fmt.Sprintf(format, a...) + *p = append(*p, s) } // Diff returns a slice where each element describes // a difference between a and b. func Diff(a, b interface{}) (desc []string) { - Fdiff((*sbuf)(&desc), a, b) + Pdiff((*sbuf)(&desc), a, b) return desc } +// wprintfer calls Fprintf on w for each Printf call +// with a trailing newline. +type wprintfer struct{ w io.Writer } + +func (p *wprintfer) Printf(format string, a ...interface{}) { + fmt.Fprintf(p.w, format+"\n", a...) +} + // Fdiff writes to w a description of the differences between a and b. func Fdiff(w io.Writer, a, b interface{}) { - diffWriter{w: w}.diff(reflect.ValueOf(a), reflect.ValueOf(b)) + Pdiff(&wprintfer{w}, a, b) +} + +type Printfer interface { + Printf(format string, a ...interface{}) +} + +// Pdiff prints to p a description of the differences between a and b. +// It calls Printf once for each difference, with no trailing newline. +// The standard library log.Logger is a Printfer. +func Pdiff(p Printfer, a, b interface{}) { + diffPrinter{w: p}.diff(reflect.ValueOf(a), reflect.ValueOf(b)) +} + +type Logfer interface { + Logf(format string, a ...interface{}) } -type diffWriter struct { - w io.Writer +// logprintfer calls Fprintf on w for each Printf call +// with a trailing newline. +type logprintfer struct{ l Logfer } + +func (p *logprintfer) Printf(format string, a ...interface{}) { + p.l.Logf(format, a...) +} + +// Ldiff prints to l a description of the differences between a and b. +// It calls Logf once for each difference, with no trailing newline. +// The standard library testing.T and testing.B are Logfers. +func Ldiff(l Logfer, a, b interface{}) { + Pdiff(&logprintfer{l}, a, b) +} + +type diffPrinter struct { + w Printfer l string // label } -func (w diffWriter) printf(f string, a ...interface{}) { +func (w diffPrinter) printf(f string, a ...interface{}) { var l string if w.l != "" { l = w.l + ": " } - fmt.Fprintf(w.w, l+f, a...) + w.w.Printf(l+f, a...) } -func (w diffWriter) diff(av, bv reflect.Value) { +func (w diffPrinter) diff(av, bv reflect.Value) { if !av.IsValid() && bv.IsValid() { - w.printf("nil != %#v", bv.Interface()) + w.printf("nil != %# v", formatter{v: bv, quote: true}) return } if av.IsValid() && !bv.IsValid() { - w.printf("%#v != nil", av.Interface()) + w.printf("%# v != nil", formatter{v: av, quote: true}) return } if !av.IsValid() && !bv.IsValid() { @@ -58,34 +96,61 @@ func (w diffWriter) diff(av, bv reflect.Value) { return } - // numeric types, including bool - if at.Kind() < reflect.Array { - a, b := av.Interface(), bv.Interface() - if a != b { - w.printf("%#v != %#v", a, b) + switch kind := at.Kind(); kind { + case reflect.Bool: + if a, b := av.Bool(), bv.Bool(); a != b { + w.printf("%v != %v", a, b) } - return - } - - switch at.Kind() { - case reflect.String: - a, b := av.Interface(), bv.Interface() - if a != b { - w.printf("%q != %q", a, b) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if a, b := av.Int(), bv.Int(); a != b { + w.printf("%d != %d", a, b) + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + if a, b := av.Uint(), bv.Uint(); a != b { + w.printf("%d != %d", a, b) + } + case reflect.Float32, reflect.Float64: + if a, b := av.Float(), bv.Float(); a != b { + w.printf("%v != %v", a, b) + } + case reflect.Complex64, reflect.Complex128: + if a, b := av.Complex(), bv.Complex(); a != b { + w.printf("%v != %v", a, b) + } + case reflect.Array: + n := av.Len() + for i := 0; i < n; i++ { + w.relabel(fmt.Sprintf("[%d]", i)).diff(av.Index(i), bv.Index(i)) + } + case reflect.Chan, reflect.Func, reflect.UnsafePointer: + if a, b := av.Pointer(), bv.Pointer(); a != b { + w.printf("%#x != %#x", a, b) + } + case reflect.Interface: + w.diff(av.Elem(), bv.Elem()) + case reflect.Map: + ak, both, bk := keyDiff(av.MapKeys(), bv.MapKeys()) + for _, k := range ak { + w := w.relabel(fmt.Sprintf("[%#v]", k)) + w.printf("%q != (missing)", av.MapIndex(k)) + } + for _, k := range both { + w := w.relabel(fmt.Sprintf("[%#v]", k)) + w.diff(av.MapIndex(k), bv.MapIndex(k)) + } + for _, k := range bk { + w := w.relabel(fmt.Sprintf("[%#v]", k)) + w.printf("(missing) != %q", bv.MapIndex(k)) } case reflect.Ptr: switch { case av.IsNil() && !bv.IsNil(): - w.printf("nil != %v", bv.Interface()) + w.printf("nil != %# v", formatter{v: bv, quote: true}) case !av.IsNil() && bv.IsNil(): - w.printf("%v != nil", av.Interface()) + w.printf("%# v != nil", formatter{v: av, quote: true}) case !av.IsNil() && !bv.IsNil(): w.diff(av.Elem(), bv.Elem()) } - case reflect.Struct: - for i := 0; i < av.NumField(); i++ { - w.relabel(at.Field(i).Name).diff(av.Field(i), bv.Field(i)) - } case reflect.Slice: lenA := av.Len() lenB := bv.Len() @@ -96,30 +161,20 @@ func (w diffWriter) diff(av, bv reflect.Value) { for i := 0; i < lenA; i++ { w.relabel(fmt.Sprintf("[%d]", i)).diff(av.Index(i), bv.Index(i)) } - case reflect.Map: - ak, both, bk := keyDiff(av.MapKeys(), bv.MapKeys()) - for _, k := range ak { - w := w.relabel(fmt.Sprintf("[%#v]", k.Interface())) - w.printf("%q != (missing)", av.MapIndex(k)) - } - for _, k := range both { - w := w.relabel(fmt.Sprintf("[%#v]", k.Interface())) - w.diff(av.MapIndex(k), bv.MapIndex(k)) + case reflect.String: + if a, b := av.String(), bv.String(); a != b { + w.printf("%q != %q", a, b) } - for _, k := range bk { - w := w.relabel(fmt.Sprintf("[%#v]", k.Interface())) - w.printf("(missing) != %q", bv.MapIndex(k)) + case reflect.Struct: + for i := 0; i < av.NumField(); i++ { + w.relabel(at.Field(i).Name).diff(av.Field(i), bv.Field(i)) } - case reflect.Interface: - w.diff(reflect.ValueOf(av.Interface()), reflect.ValueOf(bv.Interface())) default: - if !reflect.DeepEqual(av.Interface(), bv.Interface()) { - w.printf("%# v != %# v", Formatter(av.Interface()), Formatter(bv.Interface())) - } + panic("unknown reflect Kind: " + kind.String()) } } -func (d diffWriter) relabel(name string) (d1 diffWriter) { +func (d diffPrinter) relabel(name string) (d1 diffPrinter) { d1 = d if d.l != "" && name[0] != '[' { d1.l += "." @@ -128,11 +183,63 @@ func (d diffWriter) relabel(name string) (d1 diffWriter) { return d1 } +// keyEqual compares a and b for equality. +// Both a and b must be valid map keys. +func keyEqual(av, bv reflect.Value) bool { + if !av.IsValid() && !bv.IsValid() { + return true + } + if !av.IsValid() || !bv.IsValid() || av.Type() != bv.Type() { + return false + } + switch kind := av.Kind(); kind { + case reflect.Bool: + a, b := av.Bool(), bv.Bool() + return a == b + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + a, b := av.Int(), bv.Int() + return a == b + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + a, b := av.Uint(), bv.Uint() + return a == b + case reflect.Float32, reflect.Float64: + a, b := av.Float(), bv.Float() + return a == b + case reflect.Complex64, reflect.Complex128: + a, b := av.Complex(), bv.Complex() + return a == b + case reflect.Array: + for i := 0; i < av.Len(); i++ { + if !keyEqual(av.Index(i), bv.Index(i)) { + return false + } + } + return true + case reflect.Chan, reflect.UnsafePointer, reflect.Ptr: + a, b := av.Pointer(), bv.Pointer() + return a == b + case reflect.Interface: + return keyEqual(av.Elem(), bv.Elem()) + case reflect.String: + a, b := av.String(), bv.String() + return a == b + case reflect.Struct: + for i := 0; i < av.NumField(); i++ { + if !keyEqual(av.Field(i), bv.Field(i)) { + return false + } + } + return true + default: + panic("invalid map key type " + av.Type().String()) + } +} + func keyDiff(a, b []reflect.Value) (ak, both, bk []reflect.Value) { for _, av := range a { inBoth := false for _, bv := range b { - if reflect.DeepEqual(av.Interface(), bv.Interface()) { + if keyEqual(av, bv) { inBoth = true both = append(both, av) break @@ -145,7 +252,7 @@ func keyDiff(a, b []reflect.Value) (ak, both, bk []reflect.Value) { for _, bv := range b { inBoth := false for _, av := range a { - if reflect.DeepEqual(av.Interface(), bv.Interface()) { + if keyEqual(av, bv) { inBoth = true break } diff --git a/diff_test.go b/diff_test.go index 3c388f1..a951e4b 100644 --- a/diff_test.go +++ b/diff_test.go @@ -1,7 +1,18 @@ package pretty import ( + "bytes" + "fmt" + "log" + "reflect" "testing" + "unsafe" +) + +var ( + _ Logfer = (*testing.T)(nil) + _ Logfer = (*testing.B)(nil) + _ Printfer = (*log.Logger)(nil) ) type difftest struct { @@ -17,6 +28,20 @@ type S struct { C []int } +type ( + N struct{ N int } + E interface{} +) + +var ( + c0 = make(chan int) + c1 = make(chan int) + f0 = func() {} + f1 = func() {} + i0 = 0 + i1 = 1 +) + var diffs = []difftest{ {a: nil, b: nil}, {a: S{A: 1}, b: S{A: 1}}, @@ -28,12 +53,79 @@ var diffs = []difftest{ {S{}, S{A: 1}, []string{`A: 0 != 1`}}, {new(S), &S{A: 1}, []string{`A: 0 != 1`}}, {S{S: new(S)}, S{S: &S{A: 1}}, []string{`S.A: 0 != 1`}}, - {S{}, S{I: 0}, []string{`I: nil != 0`}}, + {S{}, S{I: 0}, []string{`I: nil != int(0)`}}, {S{I: 1}, S{I: "x"}, []string{`I: int != string`}}, {S{}, S{C: []int{1}}, []string{`C: []int[0] != []int[1]`}}, {S{C: []int{}}, S{C: []int{1}}, []string{`C: []int[0] != []int[1]`}}, {S{C: []int{1, 2, 3}}, S{C: []int{1, 2, 4}}, []string{`C[2]: 3 != 4`}}, - {S{}, S{A: 1, S: new(S)}, []string{`A: 0 != 1`, `S: nil != &{0 []}`}}, + {S{}, S{A: 1, S: new(S)}, []string{`A: 0 != 1`, `S: nil != &pretty.S{}`}}, + + // unexported fields of every reflect.Kind (both equal and unequal) + {struct{ x bool }{false}, struct{ x bool }{false}, nil}, + {struct{ x bool }{false}, struct{ x bool }{true}, []string{`x: false != true`}}, + {struct{ x int }{0}, struct{ x int }{0}, nil}, + {struct{ x int }{0}, struct{ x int }{1}, []string{`x: 0 != 1`}}, + {struct{ x int8 }{0}, struct{ x int8 }{0}, nil}, + {struct{ x int8 }{0}, struct{ x int8 }{1}, []string{`x: 0 != 1`}}, + {struct{ x int16 }{0}, struct{ x int16 }{0}, nil}, + {struct{ x int16 }{0}, struct{ x int16 }{1}, []string{`x: 0 != 1`}}, + {struct{ x int32 }{0}, struct{ x int32 }{0}, nil}, + {struct{ x int32 }{0}, struct{ x int32 }{1}, []string{`x: 0 != 1`}}, + {struct{ x int64 }{0}, struct{ x int64 }{0}, nil}, + {struct{ x int64 }{0}, struct{ x int64 }{1}, []string{`x: 0 != 1`}}, + {struct{ x uint }{0}, struct{ x uint }{0}, nil}, + {struct{ x uint }{0}, struct{ x uint }{1}, []string{`x: 0 != 1`}}, + {struct{ x uint8 }{0}, struct{ x uint8 }{0}, nil}, + {struct{ x uint8 }{0}, struct{ x uint8 }{1}, []string{`x: 0 != 1`}}, + {struct{ x uint16 }{0}, struct{ x uint16 }{0}, nil}, + {struct{ x uint16 }{0}, struct{ x uint16 }{1}, []string{`x: 0 != 1`}}, + {struct{ x uint32 }{0}, struct{ x uint32 }{0}, nil}, + {struct{ x uint32 }{0}, struct{ x uint32 }{1}, []string{`x: 0 != 1`}}, + {struct{ x uint64 }{0}, struct{ x uint64 }{0}, nil}, + {struct{ x uint64 }{0}, struct{ x uint64 }{1}, []string{`x: 0 != 1`}}, + {struct{ x uintptr }{0}, struct{ x uintptr }{0}, nil}, + {struct{ x uintptr }{0}, struct{ x uintptr }{1}, []string{`x: 0 != 1`}}, + {struct{ x float32 }{0}, struct{ x float32 }{0}, nil}, + {struct{ x float32 }{0}, struct{ x float32 }{1}, []string{`x: 0 != 1`}}, + {struct{ x float64 }{0}, struct{ x float64 }{0}, nil}, + {struct{ x float64 }{0}, struct{ x float64 }{1}, []string{`x: 0 != 1`}}, + {struct{ x complex64 }{0}, struct{ x complex64 }{0}, nil}, + {struct{ x complex64 }{0}, struct{ x complex64 }{1}, []string{`x: (0+0i) != (1+0i)`}}, + {struct{ x complex128 }{0}, struct{ x complex128 }{0}, nil}, + {struct{ x complex128 }{0}, struct{ x complex128 }{1}, []string{`x: (0+0i) != (1+0i)`}}, + {struct{ x [1]int }{[1]int{0}}, struct{ x [1]int }{[1]int{0}}, nil}, + {struct{ x [1]int }{[1]int{0}}, struct{ x [1]int }{[1]int{1}}, []string{`x[0]: 0 != 1`}}, + {struct{ x chan int }{c0}, struct{ x chan int }{c0}, nil}, + {struct{ x chan int }{c0}, struct{ x chan int }{c1}, []string{fmt.Sprintf("x: %p != %p", c0, c1)}}, + {struct{ x func() }{f0}, struct{ x func() }{f0}, nil}, + {struct{ x func() }{f0}, struct{ x func() }{f1}, []string{fmt.Sprintf("x: %p != %p", f0, f1)}}, + {struct{ x interface{} }{0}, struct{ x interface{} }{0}, nil}, + {struct{ x interface{} }{0}, struct{ x interface{} }{1}, []string{`x: 0 != 1`}}, + {struct{ x interface{} }{0}, struct{ x interface{} }{""}, []string{`x: int != string`}}, + {struct{ x interface{} }{0}, struct{ x interface{} }{nil}, []string{`x: int(0) != nil`}}, + {struct{ x interface{} }{nil}, struct{ x interface{} }{0}, []string{`x: nil != int(0)`}}, + {struct{ x map[int]int }{map[int]int{0: 0}}, struct{ x map[int]int }{map[int]int{0: 0}}, nil}, + {struct{ x map[int]int }{map[int]int{0: 0}}, struct{ x map[int]int }{map[int]int{0: 1}}, []string{`x[0]: 0 != 1`}}, + {struct{ x *int }{new(int)}, struct{ x *int }{new(int)}, nil}, + {struct{ x *int }{&i0}, struct{ x *int }{&i1}, []string{`x: 0 != 1`}}, + {struct{ x *int }{nil}, struct{ x *int }{&i0}, []string{`x: nil != &int(0)`}}, + {struct{ x *int }{&i0}, struct{ x *int }{nil}, []string{`x: &int(0) != nil`}}, + {struct{ x []int }{[]int{0}}, struct{ x []int }{[]int{0}}, nil}, + {struct{ x []int }{[]int{0}}, struct{ x []int }{[]int{1}}, []string{`x[0]: 0 != 1`}}, + {struct{ x string }{"a"}, struct{ x string }{"a"}, nil}, + {struct{ x string }{"a"}, struct{ x string }{"b"}, []string{`x: "a" != "b"`}}, + {struct{ x N }{N{0}}, struct{ x N }{N{0}}, nil}, + {struct{ x N }{N{0}}, struct{ x N }{N{1}}, []string{`x.N: 0 != 1`}}, + { + struct{ x unsafe.Pointer }{unsafe.Pointer(uintptr(0))}, + struct{ x unsafe.Pointer }{unsafe.Pointer(uintptr(0))}, + nil, + }, + { + struct{ x unsafe.Pointer }{unsafe.Pointer(uintptr(0))}, + struct{ x unsafe.Pointer }{unsafe.Pointer(uintptr(1))}, + []string{`x: 0x0 != 0x1`}, + }, } func TestDiff(t *testing.T) { @@ -54,6 +146,53 @@ func TestDiff(t *testing.T) { } } +func TestKeyEqual(t *testing.T) { + var emptyInterfaceZero interface{} = 0 + + cases := []interface{}{ + new(bool), + new(int), + new(int8), + new(int16), + new(int32), + new(int64), + new(uint), + new(uint8), + new(uint16), + new(uint32), + new(uint64), + new(uintptr), + new(float32), + new(float64), + new(complex64), + new(complex128), + new([1]int), + new(chan int), + new(unsafe.Pointer), + new(interface{}), + &emptyInterfaceZero, + new(*int), + new(string), + new(struct{ int }), + } + + for _, test := range cases { + rv := reflect.ValueOf(test).Elem() + if !keyEqual(rv, rv) { + t.Errorf("keyEqual(%s, %s) = false want true", rv.Type(), rv.Type()) + } + } +} + +func TestFdiff(t *testing.T) { + var buf bytes.Buffer + Fdiff(&buf, 0, 1) + want := "0 != 1\n" + if got := buf.String(); got != want { + t.Errorf("Fdiff(0, 1) = %q want %q", got, want) + } +} + func diffdiff(t *testing.T, got, exp []string) { minus(t, "unexpected:", got, exp) minus(t, "missing:", exp, got) diff --git a/formatter.go b/formatter.go index 8dacda2..a317d7b 100644 --- a/formatter.go +++ b/formatter.go @@ -10,12 +10,8 @@ import ( "github.com/kr/text" ) -const ( - limit = 50 -) - type formatter struct { - x interface{} + v reflect.Value force bool quote bool } @@ -30,11 +26,11 @@ type formatter struct { // format x according to the usual rules of package fmt. // In particular, if x satisfies fmt.Formatter, then x.Format will be called. func Formatter(x interface{}) (f fmt.Formatter) { - return formatter{x: x, quote: true} + return formatter{v: reflect.ValueOf(x), quote: true} } func (fo formatter) String() string { - return fmt.Sprint(fo.x) // unwrap it + return fmt.Sprint(fo.v.Interface()) // unwrap it } func (fo formatter) passThrough(f fmt.State, c rune) { @@ -51,14 +47,14 @@ func (fo formatter) passThrough(f fmt.State, c rune) { s += fmt.Sprintf(".%d", p) } s += string(c) - fmt.Fprintf(f, s, fo.x) + fmt.Fprintf(f, s, fo.v.Interface()) } func (fo formatter) Format(f fmt.State, c rune) { if fo.force || c == 'v' && f.Flag('#') && f.Flag(' ') { w := tabwriter.NewWriter(f, 4, 4, 1, ' ', 0) p := &printer{tw: w, Writer: w, visited: make(map[visit]int)} - p.printValue(reflect.ValueOf(fo.x), true, fo.quote) + p.printValue(fo.v, true, fo.quote) w.Flush() return } @@ -319,11 +315,6 @@ func (p *printer) fmtString(s string, quote bool) { io.WriteString(p, s) } -func tryDeepEqual(a, b interface{}) bool { - defer func() { recover() }() - return reflect.DeepEqual(a, b) -} - func writeByte(w io.Writer, b byte) { w.Write([]byte{b}) } diff --git a/formatter_test.go b/formatter_test.go index 5f3204e..c8c0b51 100644 --- a/formatter_test.go +++ b/formatter_test.go @@ -13,6 +13,11 @@ type test struct { s string } +type passtest struct { + v interface{} + f, s string +} + type LongStructTypeName struct { longFieldName interface{} otherLongFieldName interface{} @@ -33,8 +38,30 @@ func (f F) Format(s fmt.State, c rune) { fmt.Fprintf(s, "F(%d)", int(f)) } +type Stringer struct { i int } + +func (s *Stringer) String() string { return "foo" } + var long = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" +var passthrough = []passtest{ + {1, "%d", "1"}, + {"a", "%s", "a"}, + {&Stringer{}, "%s", "foo"}, +} + +func TestPassthrough(t *testing.T) { + for _, tt := range passthrough { + s := fmt.Sprintf(tt.f, Formatter(tt.v)) + if tt.s != s { + t.Errorf("expected %q", tt.s) + t.Errorf("got %q", s) + t.Errorf("expraw\n%s", tt.s) + t.Errorf("gotraw\n%s", s) + } + } +} + var gosyntax = []test{ {nil, `nil`}, {"", `""`}, diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..1e29533 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module "github.com/kr/pretty" + +require "github.com/kr/text" v0.1.0 diff --git a/pretty.go b/pretty.go index d3df868..49423ec 100644 --- a/pretty.go +++ b/pretty.go @@ -11,6 +11,7 @@ import ( "fmt" "io" "log" + "reflect" ) // Errorf is a convenience wrapper for fmt.Errorf. @@ -81,6 +82,15 @@ func Println(a ...interface{}) (n int, errno error) { return fmt.Println(wrap(a, true)...) } +// Sprint is a convenience wrapper for fmt.Sprintf. +// +// Calling Sprint(x, y) is equivalent to +// fmt.Sprint(Formatter(x), Formatter(y)), but each operand is +// formatted with "%# v". +func Sprint(a ...interface{}) string { + return fmt.Sprint(wrap(a, true)...) +} + // Sprintf is a convenience wrapper for fmt.Sprintf. // // Calling Sprintf(f, x, y) is equivalent to @@ -92,7 +102,7 @@ func Sprintf(format string, a ...interface{}) string { func wrap(a []interface{}, force bool) []interface{} { w := make([]interface{}, len(a)) for i, x := range a { - w[i] = formatter{x: x, force: force} + w[i] = formatter{v: reflect.ValueOf(x), force: force} } return w }