diff --git a/io.go b/io.go index 4629016..4f1ca4e 100644 --- a/io.go +++ b/io.go @@ -62,15 +62,15 @@ func (r *mhReader) ReadByte() (byte, error) { func (r *mhReader) ReadMultihash() (Multihash, error) { code, err := binary.ReadUvarint(r) if err != nil { - return nil, err + return Nil, err } length, err := binary.ReadUvarint(r) if err != nil { - return nil, err + return Nil, err } if length > math.MaxInt32 { - return nil, errors.New("digest too long, supporting only <= 2^31-1") + return Nil, errors.New("digest too long, supporting only <= 2^31-1") } pre := make([]byte, 2*binary.MaxVarintLen64) @@ -83,7 +83,7 @@ func (r *mhReader) ReadMultihash() (Multihash, error) { copy(buf, pre[:n]) if _, err := io.ReadFull(r.r, buf[n:]); err != nil { - return nil, err + return Nil, err } return Cast(buf) @@ -98,6 +98,6 @@ func (w *mhWriter) Write(buf []byte) (n int, err error) { } func (w *mhWriter) WriteMultihash(m Multihash) error { - _, err := w.w.Write([]byte(m)) + _, err := w.w.Write(m.Bytes()) return err } diff --git a/io_test.go b/io_test.go index d9ac4e5..d3add0d 100644 --- a/io_test.go +++ b/io_test.go @@ -25,16 +25,16 @@ func TestEvilReader(t *testing.T) { if err != nil { t.Fatal(err) } - r := NewReader(&evilReader{emptyHash}) + r := NewReader(&evilReader{emptyHash.Bytes()}) h, err := r.ReadMultihash() if err != nil { t.Fatal(err) } - if !bytes.Equal(h, []byte(emptyHash)) { + if h != emptyHash { t.Fatal(err) } h, err = r.ReadMultihash() - if len([]byte(h)) > 0 || err != io.EOF { + if h != Nil || err != io.EOF { t.Fatal("expected end of file") } } @@ -49,7 +49,7 @@ func TestReader(t *testing.T) { t.Fatal(err) } - buf.Write([]byte(m)) + buf.Write(m.Bytes()) } r := NewReader(&buf) @@ -66,7 +66,7 @@ func TestReader(t *testing.T) { continue } - if !bytes.Equal(h, h2) { + if h != h2 { t.Error("h and h2 should be equal") } } @@ -89,13 +89,13 @@ func TestWriter(t *testing.T) { continue } - buf2 := make([]byte, len(m)) + buf2 := make([]byte, len(m.Bytes())) if _, err := io.ReadFull(&buf, buf2); err != nil { t.Error(err) continue } - if !bytes.Equal(m, buf2) { + if m.Binary() != string(buf2) { t.Error("m and buf2 should be equal") } } diff --git a/multihash.go b/multihash.go index 92b4fda..7522ff8 100644 --- a/multihash.go +++ b/multihash.go @@ -25,6 +25,9 @@ var ( ErrVarintTooLong = errors.New("uvarint: varint too big (max 64bit)") ) +// Nil represents an empty multihash value. +var Nil = Multihash{} + // ErrInconsistentLen is returned when a decoded multihash has an inconsistent length type ErrInconsistentLen struct { dm *DecodedMultihash @@ -168,18 +171,37 @@ type DecodedMultihash struct { // Multihash is byte slice with the following form: // . // See the spec for more information. -type Multihash []byte +type Multihash struct { + s string +} // HexString returns the hex-encoded representation of a multihash. -func (m *Multihash) HexString() string { - return hex.EncodeToString([]byte(*m)) +func (m Multihash) HexString() string { + return hex.EncodeToString([]byte(m.s)) } // String is an alias to HexString(). -func (m *Multihash) String() string { +func (m Multihash) String() string { return m.HexString() } +// Bytes returns the multihash as a byte slice. +func (m Multihash) Bytes() []byte { + return []byte(m.s) +} + +// String is an alias to HexString(). +func (m Multihash) IsNil() bool { + return m.s == "" +} + +// Binary returns the multihash as a binary string. +// +// Unlike `bytes`, this doesn't allocate. +func (m Multihash) Binary() string { + return m.s +} + // FromHexString parses a hex-encoded multihash. func FromHexString(s string) (Multihash, error) { b, err := hex.DecodeString(s) @@ -192,7 +214,7 @@ func FromHexString(s string) (Multihash, error) { // B58String returns the B58-encoded representation of a multihash. func (m Multihash) B58String() string { - return b58.Encode([]byte(m)) + return b58.Encode([]byte(m.s)) } // FromB58String parses a B58-encoded multihash. @@ -217,7 +239,7 @@ func Cast(buf []byte) (Multihash, error) { return Multihash{}, ErrUnknownCode } - return Multihash(buf), nil + return Multihash{s: string(buf)}, nil } // Decode parses multihash bytes into a DecodedMultihash. diff --git a/multihash/main.go b/multihash/main.go index 3f874c3..5144d18 100644 --- a/multihash/main.go +++ b/multihash/main.go @@ -119,7 +119,7 @@ func main() { inp, err := getInput() checkErr(err) - if checkMh != nil { + if checkMh.IsNil() { err = opts.Check(inp, checkMh) checkErr(err) if !quiet { diff --git a/multihash_test.go b/multihash_test.go index 9305e1a..7d7b1c8 100644 --- a/multihash_test.go +++ b/multihash_test.go @@ -53,7 +53,7 @@ var testCases = []TestCase{ func (tc TestCase) Multihash() (Multihash, error) { ob, err := hex.DecodeString(tc.hex) if err != nil { - return nil, err + return Nil, err } pre := make([]byte, 2*binary.MaxVarintLen64) @@ -107,7 +107,7 @@ func TestEncode(t *testing.T) { if err != nil { t.Error(err) } - if !bytes.Equal(h, nb) { + if h.Binary() != string(nb) { t.Error("Multihash func mismatch.") } } @@ -257,7 +257,7 @@ func TestHex(t *testing.T) { continue } - if !bytes.Equal(mh, nb) { + if mh.Binary() != string(nb) { t.Error("FromHexString failed", nb, mh) continue } diff --git a/opts/coding.go b/opts/coding.go index 0696102..1ea8d3a 100644 --- a/opts/coding.go +++ b/opts/coding.go @@ -14,26 +14,38 @@ func Decode(encoding, digest string) (mh.Multihash, error) { case "raw": return mh.Cast([]byte(digest)) case "hex": - return hex.DecodeString(digest) + bts, err := hex.DecodeString(digest) + if err != nil { + return mh.Nil, err + } + return mh.Cast(bts) case "base58": - return base58.Decode(digest) + bts, err := base58.Decode(digest) + if err != nil { + return mh.Nil, err + } + return mh.Cast(bts) case "base64": - return base64.StdEncoding.DecodeString(digest) + bts, err := base64.StdEncoding.DecodeString(digest) + if err != nil { + return mh.Nil, err + } + return mh.Cast(bts) default: - return nil, fmt.Errorf("unknown encoding: %s", encoding) + return mh.Nil, fmt.Errorf("unknown encoding: %s", encoding) } } func Encode(encoding string, hash mh.Multihash) (string, error) { switch encoding { case "raw": - return string(hash), nil + return hash.Binary(), nil case "hex": - return hex.EncodeToString(hash), nil + return hex.EncodeToString(hash.Bytes()), nil case "base58": - return base58.Encode(hash), nil + return base58.Encode(hash.Bytes()), nil case "base64": - return base64.StdEncoding.EncodeToString(hash), nil + return base64.StdEncoding.EncodeToString(hash.Bytes()), nil default: return "", fmt.Errorf("unknown encoding: %s", encoding) } diff --git a/opts/opts.go b/opts/opts.go index c905c9b..047b787 100644 --- a/opts/opts.go +++ b/opts/opts.go @@ -3,7 +3,6 @@ package opts import ( - "bytes" "errors" "flag" "fmt" @@ -113,7 +112,7 @@ func (o *Options) Check(r io.Reader, h1 mh.Multihash) error { return err } - if !bytes.Equal(h1, h2) { + if h1 != h2 { return fmt.Errorf("computed checksum did not match") } @@ -124,7 +123,7 @@ func (o *Options) Check(r io.Reader, h1 mh.Multihash) error { func (o *Options) Multihash(r io.Reader) (mh.Multihash, error) { b, err := ioutil.ReadAll(r) if err != nil { - return nil, err + return mh.Nil, err } return mh.Sum(b, o.AlgorithmCode, o.Length) diff --git a/sum.go b/sum.go index af3f079..8f58ce3 100644 --- a/sum.go +++ b/sum.go @@ -44,7 +44,7 @@ func Sum(data []byte, code uint64, length int) (Multihash, error) { out := blake2s.Sum256(data) d = out[:] default: - return nil, fmt.Errorf("unsupported length for blake2s: %d", olen) + return Nil, fmt.Errorf("unsupported length for blake2s: %d", olen) } case isBlake2b(code): olen := uint8(code - BLAKE2B_MIN + 1) @@ -93,7 +93,11 @@ func Sum(data []byte, code uint64, length int) (Multihash, error) { if length >= 0 { d = d[:length] } - return Encode(d, code) + bts, err := Encode(d, code) + if err != nil { + return Nil, err + } + return Multihash{string(bts)}, nil } func isBlake2s(code uint64) bool { diff --git a/sum_test.go b/sum_test.go index 9260626..c9a5621 100644 --- a/sum_test.go +++ b/sum_test.go @@ -1,7 +1,6 @@ package multihash import ( - "bytes" "encoding/hex" "fmt" "runtime" @@ -72,9 +71,9 @@ func TestSum(t *testing.T) { continue } - if !bytes.Equal(m1, m2) { + if m1 != m2 { t.Error(tc.code, Codes[tc.code], "sum failed.", m1, m2) - t.Error(hex.EncodeToString(m2)) + t.Error(hex.EncodeToString(m2.Bytes())) } s1 := m1.HexString() @@ -86,7 +85,7 @@ func TestSum(t *testing.T) { m3, err := FromB58String(s2) if err != nil { t.Error("failed to decode b58") - } else if !bytes.Equal(m3, m1) { + } else if m3 != m1 { t.Error("b58 failing bytes") } else if s2 != m3.B58String() { t.Error("b58 failing string")