diff --git a/go/tdh2/tdh2hybridCCP/README.md b/go/tdh2/tdh2hybridCCP/README.md new file mode 100644 index 0000000..1f0fe17 --- /dev/null +++ b/go/tdh2/tdh2hybridCCP/README.md @@ -0,0 +1,37 @@ +## tdh2hybridCCP: Hybrid TDH2 and ChaCha20-Poly1305 + +This fork of /tdh2/tdh2easy provides a hybrid encryption scheme that uses **Threshold Diffie-Hellman (TDH2)** which is secure against adaptive chosen-ciphertext attacks (CCA2), combined with a ***modern symmetric stream cipher*** **ChaCha20-Poly1305** ***instead of*** **AES-256 in Galois/Counter Mode (GCM)**. + +### ChaCha20-Poly1305 replaces AES-256-GCM +The modern stream cipher provides: +- Authenticated Encryption with Associated Data (AEAD), also called Additional Authenticated Data (AAD): + - It encrypts sensitive payload data while allowing additional, authenticated but not encrypted metadata ("associated data") to be authenticated along with the ciphertext which detects any tampering. + - AEAD has become the standard for securing communication, replacing older, less secure methods that combined encryption and Message Authentication Code (MAC) separately. +- Performance: + - Stream ciphers are often faster than AES on devices without hardware acceleration. + - Designed to be fast and efficient, often outperforming separate encryption and authentication mechanisms. + - Verification during Decryption: If the authentication tag does not match the decrypted data and associated data, the decryption fails, ensuring integrity. +- Support for larger plaintext: up to 256 GB compared to maximum ca. 64 GB with AES (RFC5084). + +### Example +The [`func TestHybrid()`](./hybrid_test.go) provides running code that steps through the cycle of Distribted Key Generation (DKG), hybrid encryption of plaintext, decryption of shares by parties and their verification before a combiner aggregates the decryption shares, and finally decrypts the ciphertext. + +Run it together with other `*_test.go` files after change into subdir `tdhhybridCCP` of this repo: +``` +~/tdh2/go/tdh2/tdh2hybridCCP$ go test +Message encrypted successfully. +Decrypted Message: The quick brown fox jumps over the lazy dog's back 0123456789. +PASS +ok github.com/hb9cwp/tdh2/go/tdh2/tdh2hybridCCP 0.109s +``` + +### References + +The implementation "SG02" of TDH2, the threshold cryptosystem proposed by Shoup and Gennaro[^1], in the Rust library "Thetacrypt"[^2] motivated the replacement of AES-GCM by ChaCha20-Poly1305 and the name for this fork of `tdh2easy`: + +> "We apply a ***hybrid*** approach to encrypt a _symmetric key_ under the _threshold key_ and the actual _plaintext_ under the _symmetric key_. As a _symmetric encryption scheme_, we use the ***ChaCha20Poly1305***, a stream cipher with a message authentication code." + +[^1]: [Securing Threshold Cryptosystems against Chosen Ciphertext Attack](https://www.shoup.net/papers/thresh1.pdf), Victor Shoup & Rosario Gennaro, September 18, 2001. + +[^2]: [Thetacrypt: A Distributed Service for Threshold Cryptography](https://arxiv.org/pdf/2502.03247), Cryptology and Data Security Research Group at the University of Bern, 6 February 2025. + diff --git a/go/tdh2/tdh2hybridCCP/hybrid_test.go b/go/tdh2/tdh2hybridCCP/hybrid_test.go new file mode 100644 index 0000000..de33255 --- /dev/null +++ b/go/tdh2/tdh2hybridCCP/hybrid_test.go @@ -0,0 +1,133 @@ +package tdh2hybridCCP + +import ( + "bytes" + "fmt" + "testing" +) + +func TestHybrid(t *testing.T) { + // Optional: Rename this to 'func main() {...}' to convert to + // a self-contained Go program. Also replace 't.' by 'log.' and + // add prefix 'tdh2hybridCCP.' to import functions & objects. + // Alternatively, rename it to 'func ExampleHybrid()' or similar to test + // only for final output, see https://pkg.go.dev/testing#hdr-Examples + + // 1. Setup: Define the threshold (k) and total participants (n). + // We need at least 2 parties to decrypt out of 3 total. + var k, n int = 2, 3 + + // Perform a distributed key generation (DKG) protocol to create a + // Master Secret, a collective Public Key, and n individual + // Private Key Shares. + // Note: The Master Secret (ms) returned is ignored here, but it will + // be required for re-keying by Redeal(pk, ms, k, n). + //ms, pubKey, privShares, err := tdh2hybridCCP.GenerateKeys(k, n) + _, pubKey, privShares, err := GenerateKeys(k, n) + if err != nil { + t.Fatalf("Failed to generate keys: %v", err) + } + + // 2. Encryption + message := []byte("The quick brown fox jumps over the lazy dog's back 0123456789.") + + // Anyone can encrypt using the Public Key only. + //cipherText, err := Encrypt(pubKey, message) + aaData := []byte("tests additional authenticated, but not encrypted metadata") + //cipherText, err := tdh2ccp.EncryptWithAaD(pubKey, message, aaData) + cipherText, err := EncryptWithAaD(pubKey, message, aaData) + //var emptyLabel [tdh2.InputSize]byte + //cipherText, err := EncryptWithLabelAndAaD(pubKey, message, emptyLabel, aaData) + if err != nil { + t.Fatalf("Encryption failed: %v", err) + } + fmt.Println("Message encrypted successfully.") + + // 3. Decryption of all n shares + // Each participant creates a 'decryption share' from the ciphertext + // using their own private key share, returns a *DecryptionShare. + // ToDo: generalize for k of n (loop) + share0, err := Decrypt(cipherText, privShares[0]) + if err != nil { + t.Fatalf("Decryption share0 by party 0 failed: %v", err) + } + share1, err := Decrypt(cipherText, privShares[1]) + if err != nil { + t.Fatalf("Decryption share1 by party 1 failed: %v", err) + } + share2, err := Decrypt(cipherText, privShares[2]) + if err != nil { + t.Fatalf("Decryption share2 by party 2 failed: %v", err) + } + + // 4. Verification: Combiner verifies decrypted shares before aggregating them. + // Observe comment from Aggregate(): "Ciphertext and shares MUST be verified + // before calling Aggregate ..." + // ToDo: generalize for k of n (loop) + err = VerifyShare(cipherText, pubKey, share0) + if err != nil { + t.Fatalf("Verify share0 by combiner failed: %v", err) + } + err = VerifyShare(cipherText, pubKey, share1) + if err != nil { + t.Fatalf("Verify share1 by combiner failed: %v", err) + } + err = VerifyShare(cipherText, pubKey, share2) + if err != nil { + t.Fatalf("Verify share2 by combiner failed: %v", err) + } + + // 5. Aggregation: Combine min. k of n decrypted shares to recover the + // original message in cleartext. + // ToDo: Perform fuzzing over cleartext of other messages lenght + // from 0 to max. (2^32 -1)*64 = 256 GB (the first block of 64 byte + // is used by Poly1305), see comment in sym.go. + + // Create a slice of the pointers, not a slice of byte slices. + // All the shares have to be distinct and their number has to be + // at least the threshold k. + //decryptionShares := []*tdh2hybridCCP.DecryptionShare{share0, share1} + decryptionShares := []*DecryptionShare{share0, share1} + if _, err := Aggregate(cipherText, decryptionShares, n); err != nil { + t.Fatalf("Aggregation of share0, share1 failed: %v", err) + } + decryptionShares = []*DecryptionShare{share0, share2} + if _, err := Aggregate(cipherText, decryptionShares, n); err != nil { + t.Fatalf("Aggregation of share0, share2 failed: %v", err) + } + decryptionShares = []*DecryptionShare{share1, share2} + if _, err := Aggregate(cipherText, decryptionShares, n); err != nil { + t.Fatalf("Aggregation of share1, share2 failed: %v", err) + } + decryptionShares = []*DecryptionShare{share2, share0} // rotate (reverse) order + if _, err := Aggregate(cipherText, decryptionShares, n); err != nil { + t.Fatalf("Aggregation of share2, share0 failed: %v", err) + } + decryptionShares = []*DecryptionShare{share0, share1, share2} // all shares + if _, err := Aggregate(cipherText, decryptionShares, n); err != nil { + t.Fatalf("Aggregation of share0, share1, share2 failed: %v", err) + } + // make Aggregate() fail: + decryptionShares = []*DecryptionShare{share1, share1} // shares not distinct + if _, err := Aggregate(cipherText, decryptionShares, n); err == nil { + t.Fatalf("Aggregation of share1, share1 must fail: %v", err) + } + decryptionShares = []*DecryptionShare{share1, share1, share2} // shares not distinct + if _, err := Aggregate(cipherText, decryptionShares, n); err == nil { + t.Fatalf("Aggregation of share1, share1, share2 must fail: %v", err) + } + decryptionShares = []*DecryptionShare{share1} // fewer shares than threshold k + if _, err := Aggregate(cipherText, decryptionShares, n); err == nil { + t.Fatalf("Aggregation of share1 must fail: %v", err) + } + + decryptionShares = []*DecryptionShare{share0, share1} // repeat one last time + decryptedMsg, err := Aggregate(cipherText, decryptionShares, n) + if err != nil { + t.Fatalf("Aggregation of share0, share1 failed: %v", err) + } + if !bytes.Equal(decryptedMsg, message) { + t.Fatalf("decrypeted message does not match cleartext\n got: %#v\n want: %#v", decryptedMsg, message) + } + fmt.Printf("Decrypted Message: %s\n", string(decryptedMsg)) +} diff --git a/go/tdh2/tdh2hybridCCP/sym.go b/go/tdh2/tdh2hybridCCP/sym.go new file mode 100644 index 0000000..2fa9bc3 --- /dev/null +++ b/go/tdh2/tdh2hybridCCP/sym.go @@ -0,0 +1,77 @@ +package tdh2hybridCCP + +import ( + "bytes" + "crypto/rand" + "fmt" + + "golang.org/x/crypto/chacha20poly1305" +) + +// symKey generates a symmetric key. +func symKey(keySize int) ([]byte, error) { + key := make([]byte, keySize) + if _, err := rand.Read(key); err != nil { + return nil, fmt.Errorf("cannot generate key") + } + return key, nil +} + +// symEncrypt encrypts the message using the ChaCha20Poly1305 AEAD cipher. +func symEncrypt(msg, key, aaData []byte) ([]byte, []byte, error) { + aead, err := chacha20poly1305.New(key) + if err != nil { + return nil, nil, fmt.Errorf("cannot use ChaCha20Poly1305: %w", err) + } + + // Counter overflow is catastrophic security failure because keystream repeats + // and attacker can XOR two ciphertexts to cancel out keystream! + // Never reuse a (key, nonce) pair for more than the limit: + // * AES-256-GCM block size is 16 byte, max. 2^32 *16 = 64 GB (conservative + // limit) and RFC 5084 2^36 - 32 bytes ≈ 68.7 GB (theoretical maximum) + // if uint64(len(msg)) > ((1<<32)-2)*uint64(block.BlockSize()) { + // * ChaCha20-Poly1305 block size is 64 byte, max. 2^32 *64 = 256 GB + // which allows 4× larger messages than AES-256-GCM. + // Its block 0 is used by Poly1305: + if uint64(len(msg)) > ((1<<32)-1)*uint64(64) { // + return nil, nil, fmt.Errorf("message too long") + } + // * XChaCha20-Poly1305 (Extended Nonce Variant) uses a 64-bit counter + // instead of 32-bit, and nonce size of 24 vs 12 bytes. Its block + // size is also 64 byte, max 2^64 *64 ≈ 1.18 × 10^21 bytes (1 Zettabyte!) + // which is far beyond any practical use case, e.g. practically unlimited. + + // Generate random nonce (12 bytes for ChaCha20Poly1305, same as AES-GCM) + nonce := make([]byte, aead.NonceSize()) + if _, err := rand.Read(nonce); err != nil { + return nil, nil, fmt.Errorf("failed to generate nonce: %w", err) + } + + // Encrypt: prepend nonce to ciphertext is done by passing nonce into first parameter 'dst' + // Format: [nonce][ciphertext + aaData + authN tag] + //return aead.Seal(nonce, nonce, msg, nil), nonce, nil // returns (ctxt, nonce, err) + return aead.Seal(nonce, nonce, msg, aaData), nonce, nil // returns (ctxt, nonce, err) +} + +// symDecrypt decrypts the ciphertext using theChaCha20-Poly1305 cipher. +func symDecrypt(nonce, ctxt, key, aaData []byte) ([]byte, error) { + aead, err := chacha20poly1305.New(key) + if err != nil { + return nil, fmt.Errorf("failed to create ChaCha20Poly1305 cipher: %w", err) + } + if len(ctxt) < aead.NonceSize() { + return nil, fmt.Errorf("ciphertext too short") + } + + // Extract nonce and encrypted data + nonceRecovered := ctxt[:aead.NonceSize()] + if !bytes.Equal(nonceRecovered, nonce) { + return nil, fmt.Errorf("nonce mismatch") + } + encryptedData := ctxt[aead.NonceSize():] + + // Decrypt and verify: AEAD authenticates additional, non-encrypted aaData which + // detects which detects any tampering with metadata + //return aead.Open(nil, nonceRecovered, encryptedData, nil) // authN fails if aaData was set + return aead.Open(nil, nonceRecovered, encryptedData, aaData) +} diff --git a/go/tdh2/tdh2hybridCCP/sym_test.go b/go/tdh2/tdh2hybridCCP/sym_test.go new file mode 100644 index 0000000..1e53ed4 --- /dev/null +++ b/go/tdh2/tdh2hybridCCP/sym_test.go @@ -0,0 +1,184 @@ +package tdh2hybridCCP + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" +) + +// const keyLength = 16 // AES-GCM supports 128, 192, and 256 bit keys +const keyLength = 32 // ChaCha20-Poly1305 supports 256 bit keys only! + +var aaData = []byte("some additional authenticated, but not encrypted metadata") + +func TestSymmetric(t *testing.T) { + key, err := symKey(keyLength) + if err != nil { + t.Fatalf("symmetricKey: %v", err) + } + for _, tc := range []struct { + name string + msg []byte + key []byte + aaD []byte // AEAD metata + err error + }{ + { + name: "OK (short message)", + msg: []byte("msg"), + key: key, + }, + { + name: "OK with AAData", + msg: []byte("msg"), + aaD: []byte("metadata"), + key: key, + }, + { + name: "OK (empty message)", + key: key, + }, + { + name: "OK (64 k message)", + msg: make([]byte, 65536), + key: key, + }, + { + name: "wrong key length", + msg: make([]byte, 65536), + key: key[:4], + err: cmpopts.AnyError, + }, + } { + t.Run(tc.name, func(t *testing.T) { + c, nonce, err := symEncrypt(tc.msg, tc.key, aaData) + if !cmp.Equal(err, tc.err, cmpopts.EquateErrors()) { + t.Errorf("err=%v, want=%v", err, tc.err) + } else if err != nil { + return + } + out, err := symDecrypt(nonce, c, key, aaData) + if err != nil { + t.Errorf("symmetricDecryption: %v", err) + } + if diff := cmp.Diff(tc.msg, out); diff != "" { + t.Errorf("encrypted/decrypted message diff=%v", diff) + } + }) + } +} + +func TestSymmetricDecryptionFail(t *testing.T) { + msg := []byte("msg") + key, err := symKey(keyLength) + if err != nil { + t.Fatalf("symmetricKey: %v", err) + } + c, nonce, err := symEncrypt(msg, key, aaData) + if err != nil { + t.Fatalf("symmetricEncryption: %v", err) + } + for _, tc := range []struct { + name string + nonce []byte + c []byte // ctxt + key []byte + aaD []byte // AEAD metadata + err error + }{ + { + name: "OK", + key: key, + nonce: nonce, + c: c, + }, + { + name: "wrong key", + key: []byte("key"), + nonce: nonce, + c: c, + err: cmpopts.AnyError, + }, + { + name: "wrong nonce", + key: key, + nonce: []byte("nonce"), + c: c, + err: cmpopts.AnyError, + }, + { + name: "wrong ciphertext", + key: key, + nonce: nonce, + c: []byte("ciphertext"), + err: cmpopts.AnyError, + }, + { + name: "wrong AAD", + key: key, + aaD: []byte("wrong"), + c: c, + }, + { + name: "nil AAD when AAD was used", + key: key, + aaD: nil, + c: c, + }, + { + name: "wrong nonce with AAD", + key: key, + aaD: aaData, + nonce: []byte("nonce"), + c: c, + }, + } { + t.Run(tc.name, func(t *testing.T) { + out, err := symDecrypt(nonce, c, key, aaData) + if err != nil { + t.Errorf("symmetricDecryption: %v", err) + } + if diff := cmp.Diff(msg, out); diff != "" { + t.Errorf("encrypted/decrypted message diff=%v", diff) + } + }) + } +} + +func FuzzSymEncryption(f *testing.F) { + f.Add(16, []byte("sample message"), aaData) + f.Add(24, []byte("another sample message"), aaData) + f.Add(32, []byte("and another sample message"), aaData) + f.Fuzz(func(t *testing.T, keySize int, msg []byte, aaD []byte) { + //if keySize != 16 && keySize != 24 && keySize != 32 { // AES-GCM + if keySize != keyLength { // ChaCha20-Poly1305 + t.Skip() + } + key, err := symKey(keySize) + if err != nil { + t.Fatalf("symKey(%v): %v", keySize, err) + } + c, n, err := symEncrypt(msg, key, aaData) + if err != nil { + t.Fatalf("symEncrypt(%v, %v): %v", msg, key, err) + } + p, err := symDecrypt(n, c, key, aaData) + if err != nil { + t.Fatalf("symDecryt(%v, %v, %v): %v", n, c, key, err) + } + if d := cmp.Diff(p, msg); d != "" { + t.Fatalf("got/want diff=%v", d) + } + // Verify wrong AAD causes failure + if len(aaD) > 0 { + wrongAAD := make([]byte, len(aaD)) + copy(wrongAAD, aaD) + wrongAAD[0] ^= 0x42 // inject errors + _, err = symDecrypt(n, c, key, wrongAAD) + if err == nil { + t.Fatal("Expected error with wrong AAD") + } + } + }) +} diff --git a/go/tdh2/tdh2hybridCCP/tdh2hybridCCP.go b/go/tdh2/tdh2hybridCCP/tdh2hybridCCP.go new file mode 100644 index 0000000..f6fa2b7 --- /dev/null +++ b/go/tdh2/tdh2hybridCCP/tdh2hybridCCP.go @@ -0,0 +1,338 @@ +// Package tdh2hybridCCP implements an easy interface of TDH2-based hybrid encryption +// with a moden stream cipher ChaCha20-Poly1305. +package tdh2hybridCCP + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/json" + "fmt" + + "github.com/smartcontractkit/tdh2/go/tdh2/lib/group/nist" + "github.com/smartcontractkit/tdh2/go/tdh2/tdh2" + "golang.org/x/crypto/chacha20poly1305" +) + +// key size used in symmetric encryption (AES replaced by ChaCha20Poly1350). +// 256 bits is a higher securitylevel than provided by the EC group deployed, +// but as tdh2.InputSize is 256 bits we decided to use the same value. +// const symKeySize = 32 // AES-256 +const symKeySize = chacha20poly1305.KeySize // 32 byte + +// defaultGroup is the default EC group used. +var defaultGroup = nist.NewP256() + +// PrivateShare encodes TDH2 private share. +type PrivateShare struct { + p *tdh2.PrivateShare +} + +// Index returns private share index. +func (p *PrivateShare) Index() int { + return p.p.Index() +} + +func (p PrivateShare) Marshal() ([]byte, error) { + return p.p.Marshal() +} + +func (p *PrivateShare) MarshalJSON() ([]byte, error) { + return p.Marshal() +} + +func (p *PrivateShare) Unmarshal(data []byte) error { + p.p = &tdh2.PrivateShare{} + return p.p.Unmarshal(data) +} + +func (p *PrivateShare) UnmarshalJSON(data []byte) error { + return p.Unmarshal(data) +} + +func (p *PrivateShare) Clear() { + p.p.Clear() +} + +// DecryptionShare encodes TDH2 decryption share. +type DecryptionShare struct { + d *tdh2.DecryptionShare +} + +// Index returns private share index. +func (d *DecryptionShare) Index() int { + return d.d.Index() +} + +func (d DecryptionShare) Marshal() ([]byte, error) { + return d.d.Marshal() +} + +func (d DecryptionShare) MarshalJSON() ([]byte, error) { + return d.Marshal() +} + +func (d *DecryptionShare) Unmarshal(data []byte) error { + d.d = &tdh2.DecryptionShare{} + return d.d.Unmarshal(data) +} + +func (d *DecryptionShare) UnmarshalJSON(data []byte) error { + return d.Unmarshal(data) +} + +// PublicKey encodes TDH2 public key. +type PublicKey struct { + p *tdh2.PublicKey +} + +func (p PublicKey) Marshal() ([]byte, error) { + return p.p.Marshal() +} + +func (p *PublicKey) MarshalJSON() ([]byte, error) { + return p.Marshal() +} + +func (p *PublicKey) Unmarshal(data []byte) error { + p.p = &tdh2.PublicKey{} + return p.p.Unmarshal(data) +} + +func (p *PublicKey) UnmarshalJSON(data []byte) error { + return p.Unmarshal(data) +} + +// MasterSecret encodes TDH2 master key. +type MasterSecret struct { + m *tdh2.MasterSecret +} + +func (m MasterSecret) Marshal() ([]byte, error) { + return m.m.Marshal() +} + +func (m MasterSecret) MarshalJSON() ([]byte, error) { + return m.Marshal() +} + +func (m *MasterSecret) Unmarshal(data []byte) error { + m.m = &tdh2.MasterSecret{} + return m.m.Unmarshal(data) +} + +func (m MasterSecret) UnmarshalJSON(data []byte) error { + return m.Unmarshal(data) +} + +func (m *MasterSecret) Clear() { + m.m.Clear() +} + +// Ciphertext encodes hybrid ciphertext. +// ChaCha20-Poly1305 implements Authenticated Encryption with Associated +// Data (AEAD), also called Additional Authenticated Data (AAD). +// It encrypts sensitive payload data while allowing additional, authenticated +// but not encrypted metadata ("associated data") to be authenticated +// along with the ciphertext which detects any tampering (modern: fast & safe). +type Ciphertext struct { + tdh2Ctxt *tdh2.Ciphertext + symCtxt []byte + nonce []byte + aaData []byte // store for verification by AEAD during decryption + //label [tdh2.InputSize]byte // also included in tdh2Ctxt, redundant? +} + +// Decrypt returns a decryption share for the ciphertext. +func Decrypt(c *Ciphertext, x_i *PrivateShare) (*DecryptionShare, error) { + r, err := randStream() + if err != nil { + return nil, err + } + d, err := c.tdh2Ctxt.Decrypt(defaultGroup, x_i.p, r) + if err != nil { + return nil, err + } + return &DecryptionShare{d}, nil +} + +// VerifyShare checks if the share matches the ciphertext and public key. +func VerifyShare(c *Ciphertext, pk *PublicKey, share *DecryptionShare) error { + return tdh2.VerifyShare(pk.p, c.tdh2Ctxt, share.d) +} + +// Aggregate decrypts the TDH2-encrypted key and using it recovers the +// symmetrically encrypted plaintext. It takes decryption shares and +// the total number of participants as the arguments. +// Ciphertext and shares MUST be verified before calling Aggregate, +// all the shares have to be distinct and their number has to be +// at least k (the scheme's threshold). +func Aggregate(c *Ciphertext, shares []*DecryptionShare, n int) ([]byte, error) { + sh := []*tdh2.DecryptionShare{} + for _, s := range shares { + sh = append(sh, s.d) + } + key, err := c.tdh2Ctxt.CombineShares(defaultGroup, sh, len(sh), n) + if err != nil { + return nil, fmt.Errorf("cannot combine shares: %w", err) + } + if symKeySize != len(key) { + return nil, fmt.Errorf("incorrect key size") + } + return symDecrypt(c.nonce, c.symCtxt, key, c.aaData) +} + +// randStream returns a stream cipher used for providing randomness. +func randStream() (cipher.Stream, error) { + key := make([]byte, symKeySize) + if _, err := rand.Read(key); err != nil { + return nil, fmt.Errorf("cannot generate key: %w", err) + } + iv := make([]byte, aes.BlockSize) + if _, err := rand.Read(iv); err != nil { + return nil, fmt.Errorf("cannot generate iv: %w", err) + } + block, err := aes.NewCipher(key) + if err != nil { + return nil, fmt.Errorf("cannot init aes: %w", err) + } + return cipher.NewCTR(block, iv), nil +} + +type ciphertextRaw struct { + TDH2Ctxt []byte + SymCtxt []byte + Nonce []byte + AaData []byte +} + +func (c *Ciphertext) Marshal() ([]byte, error) { + ctxt, err := c.tdh2Ctxt.Marshal() + if err != nil { + return nil, fmt.Errorf("cannot marshal TDH2 ciphertext: %w", err) + } + return json.Marshal(&ciphertextRaw{ + TDH2Ctxt: ctxt, + SymCtxt: c.symCtxt, + Nonce: c.nonce, + AaData: c.aaData, + }) +} + +// UnmarshalVerify unmarshals ciphertext and verifies if it matches the public key. +func (c *Ciphertext) UnmarshalVerify(data []byte, pk *PublicKey) error { + var raw ciphertextRaw + if err := json.Unmarshal(data, &raw); err != nil { + return fmt.Errorf("cannot unmarshal data: %w", err) + } + c.symCtxt = raw.SymCtxt + c.nonce = raw.Nonce + c.tdh2Ctxt = &tdh2.Ciphertext{} + c.aaData = raw.AaData + if err := c.tdh2Ctxt.Unmarshal(raw.TDH2Ctxt); err != nil { + return fmt.Errorf("cannot unmarshal TDH2 ciphertext: %w", err) + } + + if err := c.tdh2Ctxt.Verify(pk.p); err != nil { + return fmt.Errorf("tdh2 ciphertext verification: %w", err) + } + return nil +} + +// GenerateKeys generates and returns, the master secret, public key, and private shares. It takes the +// total number of nodes n and a threshold k (the number of shares sufficient for decryption). +func GenerateKeys(k, n int) (*MasterSecret, *PublicKey, []*PrivateShare, error) { + r, err := randStream() + if err != nil { + return nil, nil, nil, err + } + ms, pk, sh, err := tdh2.GenerateKeys(defaultGroup, nil, k, n, r) + if err != nil { + return nil, nil, nil, err + } + shares := []*PrivateShare{} + for i := range sh { + shares = append(shares, &PrivateShare{sh[i]}) + } + return &MasterSecret{ms}, &PublicKey{pk}, shares, nil +} + +// Redeal re-keys private shares such that new quorums can decrypt old ciphertexts. +// It takes the previous public key and master secret as well as the number of nodes +// sufficient for decrypt k, and the total number of nodes n. It returns a new public +// key and private shares. +// Note: The public key returned corresponds to the master secret passed in. Thus, +// the old public key can still be used for encryption but it cannot be used for share +// verification (the new key has to be used instead). +func Redeal(pk *PublicKey, ms *MasterSecret, k, n int) (*PublicKey, []*PrivateShare, error) { + r, err := randStream() + if err != nil { + return nil, nil, err + } + p, sh, err := tdh2.Redeal(pk.p, ms.m, k, n, r) + if err != nil { + return nil, nil, err + } + shares := []*PrivateShare{} + for i := range sh { + shares = append(shares, &PrivateShare{sh[i]}) + } + return &PublicKey{p}, shares, nil +} + +// Encrypt generates a fresh symmetric key, encrypts and authenticates +// the message with it, and encrypts the key using TDH2 with empty label. +// It returns a struct encoding the generated ciphertexts. +func Encrypt(pk *PublicKey, msg []byte) (*Ciphertext, error) { + return EncryptWithLabel(pk, msg, [tdh2.InputSize]byte{}) +} + +// EncryptWithLabel is identical to Encrypt but allows passing a +// non-empty label from TDH2 where tdh2.InputSize = sha256.Size = 32 byte +func EncryptWithLabel(pk *PublicKey, msg []byte, label [tdh2.InputSize]byte) (*Ciphertext, error) { + return EncryptWithLabelAndAaD(pk, msg, label, nil) +} + +func EncryptWithAaD(pk *PublicKey, msg, aaData []byte) (*Ciphertext, error) { + var emptyLabel [tdh2.InputSize]byte + return EncryptWithLabelAndAaD(pk, msg, emptyLabel, aaData) +} + +func EncryptWithLabelAndAaD(pk *PublicKey, msg []byte, label [tdh2.InputSize]byte, aaData []byte) (*Ciphertext, error) { + if symKeySize != tdh2.InputSize { + return nil, fmt.Errorf("incorrect key size") + } + // generate a fresh key and encrypt the message + key, err := symKey(tdh2.InputSize) + if err != nil { + return nil, fmt.Errorf("cannot generate key: %w", err) + } + // for each encryption a fresh key and nonce are generated, + // therefore the probability of nonce misuse is negligible + symCtxt, nonce, err := symEncrypt(msg, key, aaData) + if err != nil { + return nil, fmt.Errorf("cannot encrypt message: %w", err) + } + + r, err := randStream() + if err != nil { + return nil, err + } + // encrypt the key with TDH2 using the provided label + tdh2Ctxt, err := tdh2.Encrypt(pk.p, key, label[:], r) + if err != nil { + return nil, fmt.Errorf("cannot TDH2 encrypt: %w", err) + } + return &Ciphertext{ + tdh2Ctxt: tdh2Ctxt, + symCtxt: symCtxt, + nonce: nonce, + aaData: aaData, + //label: label, // also included in tdh2Ctxt, redundant? + }, nil +} + +// Label returns a defensive copy of the ciphertext's TDH2 label. +func (c *Ciphertext) Label() [tdh2.InputSize]byte { + return c.tdh2Ctxt.Label() +} diff --git a/go/tdh2/tdh2hybridCCP/tdh2hybridCCP_test.go b/go/tdh2/tdh2hybridCCP/tdh2hybridCCP_test.go new file mode 100644 index 0000000..467a2e7 --- /dev/null +++ b/go/tdh2/tdh2hybridCCP/tdh2hybridCCP_test.go @@ -0,0 +1,660 @@ +//go:build !tinygo + +package tdh2hybridCCP + +// TinyGo has limited support for reflect to save space. +// Currently, 'tinygo test' panics due to !reflect.DeepEqual() below. +// Workaround: Copy test file, remove dependencies on reflect and add a +// build tag/constraint '//go:build tinygo' at the top of the file. +// Note: The '//go:build' line must be at the very top of the file, and +// followed by a blank line before the package declaration! + +// Replacing reflect.DeepEqual with cmp.Equal +// Problem that causes panic: +// cmp.Diff cannot compare structs with unexported fields (e.g. fields that +// start with _lowercase_ letters). PublicKey (pk), PrivateShare, etc. +// wrap tdh2 types that have unexported fields. +// Solution: +// Don't compare the structs directly - instead, compare their marshaled bytes! + +import ( + "bytes" + "encoding/json" + "reflect" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/smartcontractkit/tdh2/go/tdh2/lib/group/nist" + "github.com/smartcontractkit/tdh2/go/tdh2/tdh2" +) + +func TestShareIndex(t *testing.T) { + _, pk, sh, err := GenerateKeys(5, 10) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + for i := range sh { + if sh[i].Index() != i { + t.Errorf("index=%v, want=%v", sh[i].Index(), i) + } + } + c, err := Encrypt(pk, []byte("msg")) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + for i, s := range sh { + ds, err := Decrypt(c, s) + if err != nil { + t.Fatalf("Decrypt: %v", err) + } + if ds.Index() != i { + t.Errorf("index=%v, want=%v", ds.Index(), i) + } + } +} + +func TestPrivateShareMarshal(t *testing.T) { + _, _, wantShare, err := GenerateKeys(2, 3) // returns MasterSecret (ms), PublicKey (pk), PrivateShare, err + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + wantShareBytes, err := wantShare[0].Marshal() //serialize original + if err != nil { + t.Fatalf("Marshal: %v", err) + } + /* + var got PrivateShare + if err := got.Unmarshal(b); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + if !reflect.DeepEqual(got.p, want[0].p) { // TinyGo panics! + t.Errorf("got=%v want=%v", got, want[0]) + } + */ + //if diff := cmp.Diff(wantShare[0].p, gotShare.p); diff != "" { + //if diff := cmp.Diff(wantShare[0].p, gotShare.p, cmpopts.EquateComparable()); diff != "" { + //if diff := cmp.Diff(wantShare[0].p, gotShare.p, cmpopts.IgnoreUnexported(PrivateShare{})); diff != "" { + // t.Errorf("mismatch (-want +got):\n%s", diff) + //} + gotShare := &PrivateShare{} // deserialize to new struct + if err := gotShare.Unmarshal(wantShareBytes); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + if wantShare[0].Index() != gotShare.Index() { // Compare public API + t.Errorf("index mismatch: got %d, want %d", gotShare.Index(), wantShare[0].Index()) + } + gotShareBytes, _ := gotShare.Marshal() // serialize again and compare byte slices + if diff := cmp.Diff(wantShareBytes, gotShareBytes); diff != "" { + t.Errorf("marshaled share mismatch (-want +got):\n%s", diff) + } + + if err := gotShare.Unmarshal([]byte("broken")); err == nil { + t.Errorf("Unmarshal did not fail") + } +} + +func TestDecryptionShareMarshal(t *testing.T) { + _, pk, sh, err := GenerateKeys(2, 3) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + c, err := Encrypt(pk, []byte("msg")) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + want, err := Decrypt(c, sh[0]) + if err != nil { + t.Fatalf("Decrypt: %v", err) + } + b, err := want.Marshal() + if err != nil { + t.Fatalf("Marshal: %v", err) + } + var got DecryptionShare + if err := got.Unmarshal(b); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + if !reflect.DeepEqual(got.d, want.d) { // TinyGo panics! + t.Errorf("got=%v want=%v", got, want) + } + if err := got.Unmarshal([]byte("broken")); err == nil { + t.Errorf("Unmarshal did not fail") + } +} + +func TestPublicKeyMarshal(t *testing.T) { + _, want, _, err := GenerateKeys(2, 3) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + b, err := want.Marshal() + if err != nil { + t.Fatalf("Marshal: %v", err) + } + var got PublicKey + if err := got.Unmarshal(b); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + if !got.p.Equal(want.p) { + t.Errorf("got=%v want=%v", got, want) + } + if err := got.Unmarshal([]byte("broken")); err == nil { + t.Errorf("Unmarshal did not fail") + } +} + +func TestMasterSecretMarshal(t *testing.T) { + want, _, _, err := GenerateKeys(2, 3) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + b, err := want.Marshal() + if err != nil { + t.Fatalf("Marshal: %v", err) + } + var got MasterSecret + if err := got.Unmarshal(b); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + if !reflect.DeepEqual(got.m, want.m) { // TinyGo panics! + t.Errorf("got=%v want=%v", got, want) + } + if err := got.Unmarshal([]byte("broken")); err == nil { + t.Errorf("Unmarshal did not fail") + } +} + +func TestCiphertextDecrypt(t *testing.T) { + _, pk, share, err := GenerateKeys(1, 1) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + r, err := randStream() + if err != nil { + t.Fatalf("RandStream: %v", err) + } + _, _, wrong, err := tdh2.GenerateKeys(nist.NewP521(), nil, 1, 1, r) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + c, err := Encrypt(pk, []byte("msg")) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + if _, err := Decrypt(c, share[0]); err != nil { + t.Errorf("Decrypt: %v", err) + } + if _, err := Decrypt(c, &PrivateShare{wrong[0]}); err == nil { + t.Errorf("Decrypt did not fail") + } +} + +func TestCiphertextVerifyShare(t *testing.T) { + _, pk, share, err := GenerateKeys(1, 1) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + _, _, wrongShare, err := GenerateKeys(1, 1) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + c, err := Encrypt(pk, []byte("msg")) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + ds, err := Decrypt(c, share[0]) + if err != nil { + t.Fatalf("Decrypt: %v", err) + } + wrongDs, err := Decrypt(c, wrongShare[0]) + if err != nil { + t.Fatalf("Decrypt: %v", err) + } + if err := VerifyShare(c, pk, ds); err != nil { + t.Errorf("VerifyShare: %v", err) + } + if err := VerifyShare(c, pk, wrongDs); err == nil { + t.Errorf("VerifyShare did not fail") + } +} + +func TestAggregate(t *testing.T) { + k := 3 + n := 5 + _, pk, shares, err := GenerateKeys(k, n) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + msg := []byte("message") + c, err := Encrypt(pk, msg) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + decShares := make([]*DecryptionShare, n) + for i := range shares { + ds, err := Decrypt(c, shares[i]) + if err != nil { + t.Fatalf("Decrypt: %v", err) + } + decShares[i] = ds + } + for _, tc := range []struct { + name string + ctxt *Ciphertext + shares []*DecryptionShare + err error + }{ + { + name: "OK (all shares)", + ctxt: c, + shares: decShares, + }, + { + name: "OK (min shares)", + ctxt: c, + shares: decShares[:k], + }, + { + name: "not enough shares", + ctxt: c, + shares: decShares[:2], + err: cmpopts.AnyError, + }, + { + name: "wrong nonce", + ctxt: &Ciphertext{ + tdh2Ctxt: c.tdh2Ctxt, + symCtxt: c.symCtxt, + nonce: make([]byte, len(c.nonce)), + }, + shares: decShares, + err: cmpopts.AnyError, + }, + { + name: "wrong nonce size", + ctxt: &Ciphertext{ + tdh2Ctxt: c.tdh2Ctxt, + symCtxt: c.symCtxt, + nonce: []byte("nonce"), + }, + shares: decShares, + err: cmpopts.AnyError, + }, + { + name: "wrong symmetric ciphertext", + ctxt: &Ciphertext{ + tdh2Ctxt: c.tdh2Ctxt, + symCtxt: []byte("ciphertext"), + nonce: c.nonce, + }, + shares: decShares, + err: cmpopts.AnyError, + }, + } { + t.Run(tc.name, func(t *testing.T) { + out, err := Aggregate(tc.ctxt, tc.shares, n) + if !cmp.Equal(err, tc.err, cmpopts.EquateErrors()) { + t.Errorf("err=%v, want=%v", err, tc.err) + } else if err != nil { + return + } + if diff := cmp.Diff(msg, out); diff != "" { + t.Errorf("encrypted decrypted message diff=%v", diff) + } + }) + } +} + +func TestCiphertextMarshal(t *testing.T) { + _, pk, _, err := GenerateKeys(1, 1) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + want, err := Encrypt(pk, []byte("msg")) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + b, err := want.Marshal() + if err != nil { + t.Fatalf("Marshal: %v", err) + } + var got Ciphertext + if err := got.UnmarshalVerify(b, pk); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + if d := cmp.Diff(got.symCtxt, want.symCtxt); d != "" { + t.Errorf("got/want Ciphertext diff=%v", d) + } + if d := cmp.Diff(got.nonce, want.nonce); d != "" { + t.Errorf("got/want Nonce diff=%v", d) + } + if !got.tdh2Ctxt.Equal(want.tdh2Ctxt) { + t.Errorf("different ciphertexts") + } +} + +func TestCiphertextUnmarshal(t *testing.T) { + _, pk, _, err := GenerateKeys(1, 1) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + _, wrong, _, err := GenerateKeys(1, 1) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + c, err := Encrypt(pk, []byte("msg")) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + cRaw, err := c.Marshal() + if err != nil { + t.Fatalf("Marshal: %v", err) + } + brokenTdh2, err := json.Marshal(&ciphertextRaw{ + TDH2Ctxt: []byte("broken"), + SymCtxt: []byte("ciphertext"), + Nonce: []byte("nonce"), + }) + if err != nil { + t.Fatalf("json.Marshal: %v", err) + } + for _, tc := range []struct { + name string + raw []byte + pk *PublicKey + err error + }{ + { + name: "ok", + raw: cRaw, + pk: pk, + }, + { + name: "wrong pk", + raw: cRaw, + pk: wrong, + err: cmpopts.AnyError, + }, + { + name: "broken", + raw: []byte("broken"), + pk: pk, + err: cmpopts.AnyError, + }, + { + name: "broken tdh2 ciphertext", + raw: brokenTdh2, + pk: pk, + err: cmpopts.AnyError, + }, + } { + t.Run(tc.name, func(t *testing.T) { + var hc Ciphertext + if err := hc.UnmarshalVerify(tc.raw, tc.pk); !cmp.Equal(err, tc.err, cmpopts.EquateErrors()) { + t.Errorf("got err=%v, want=%v", err, tc.err) + } + }) + } +} + +func TestRedealEncryptNew(t *testing.T) { + ms, pk, _, err := GenerateKeys(3, 5) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + want := []byte("msg") + for _, tc := range []struct { + name string + k, n int + }{ + { + name: "same n,k", + k: 3, + n: 5, + }, + { + name: "smaller quorum", + k: 2, + n: 5, + }, + { + name: "larger quorum", + k: 4, + n: 5, + }, + } { + t.Run(tc.name, func(t *testing.T) { + // generate new instance + newPk, shares, err := Redeal(pk, ms, tc.k, tc.n) + if err != nil { + t.Fatalf("Redeal: %v", err) + } + // encrypt and decrypt using new keys + c, err := Encrypt(newPk, want) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + ds := []*DecryptionShare{} + for _, sh := range shares { + d, err := Decrypt(c, sh) + if err != nil { + t.Fatalf("Decrypt: %v", err) + } + if err := VerifyShare(c, newPk, d); err != nil { + t.Fatalf("VerifyShare: %v", err) + } + ds = append(ds, d) + } + if got, err := Aggregate(c, ds[:tc.k], tc.n); err != nil { + t.Errorf("Aggregate: %v", err) + } else if !cmp.Equal(got, want) { + t.Errorf("got=%v, want=%v", got, want) + } + }) + } +} + +func TestRedealDecryptOld(t *testing.T) { + ms, pk, _, err := GenerateKeys(3, 5) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + want := []byte("msg") + c, err := Encrypt(pk, want) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + for _, tc := range []struct { + name string + k, n int + }{ + { + name: "same n,k", + k: 3, + n: 5, + }, + { + name: "smaller quorum", + k: 2, + n: 5, + }, + { + name: "larger quorum", + k: 4, + n: 5, + }, + } { + t.Run(tc.name, func(t *testing.T) { + // generate new instance + new, shares, err := Redeal(pk, ms, tc.k, tc.n) + if err != nil { + t.Fatalf("Redeal: %v", err) + } + // try to decrypt old ciphertext + ds := []*DecryptionShare{} + for _, sh := range shares { + d, err := Decrypt(c, sh) + if err != nil { + t.Fatalf("Decrypt: %v", err) + } + if err := VerifyShare(c, new, d); err != nil { + t.Fatalf("VerifyShare: %v", err) + } + ds = append(ds, d) + } + // should fail w/o enough shares + if _, err := Aggregate(c, ds[:tc.k-1], tc.n); err == nil { + t.Error("Aggregate did not fail") + } + // try with enough shares + if got, err := Aggregate(c, ds[:tc.k], tc.n); err != nil { + t.Errorf("Aggregate: %v", err) + } else if !cmp.Equal(got, want) { + t.Errorf("got=%v, want=%v", got, want) + } + }) + } +} + +func TestRedealReuseOldShares(t *testing.T) { + ms, pk, shares, err := GenerateKeys(3, 5) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + newPk, _, err := Redeal(pk, ms, 3, 5) + if err != nil { + t.Fatalf("Redeal: %v", err) + } + c, err := Encrypt(newPk, []byte("msg")) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + // use old share for decryption + ds, err := Decrypt(c, shares[0]) + if err != nil { + t.Fatalf("Decrypt: %v", err) + } + // make sure old shares cannot be used for new encryptions + if err := VerifyShare(c, newPk, ds); err == nil { + t.Error("VerifyShare did not fail") + } +} + +func FuzzCiphertextMarshal(f *testing.F) { + _, pk, _, err := GenerateKeys(1, 1) + if err != nil { + f.Fatalf("Keys: %v", err) + } + r, err := randStream() + if err != nil { + f.Fatalf("randStream: %v", err) + } + tdh2Input := make([]byte, tdh2.InputSize) + f.Add(tdh2Input, []byte("symcCtxt"), []byte("nonce")) + f.Fuzz(func(t *testing.T, key, symCtxt, nonce []byte) { + if len(key) != tdh2.InputSize { + t.Skip() + } + tdh2Ctxt, err := tdh2.Encrypt(pk.p, key, tdh2Input, r) + if err != nil { + t.Fatalf("Encrypt(%v): %v", key, err) + } + want := Ciphertext{ + tdh2Ctxt: tdh2Ctxt, + symCtxt: symCtxt, + nonce: nonce, + } + b, err := want.Marshal() + if err != nil { + t.Fatalf("Marshal(%v): %v", want, err) + } + var got Ciphertext + if err := got.UnmarshalVerify(b, pk); err != nil { + t.Fatalf("UnmarshalVerify(%v): %v", b, err) + } + }) +} + +func FuzzCiphertextUnmarshal(f *testing.F) { + _, pk, _, err := GenerateKeys(1, 1) + if err != nil { + f.Fatalf("Keys: %v", err) + } + r, err := randStream() + if err != nil { + f.Fatalf("ranStream: %v", err) + } + tdh2Ctxt, err := tdh2.Encrypt(pk.p, make([]byte, tdh2.InputSize), make([]byte, tdh2.InputSize), r) + if err != nil { + f.Fatalf("Encrypt: %v", err) + } + c := Ciphertext{ + tdh2Ctxt: tdh2Ctxt, + symCtxt: []byte("symCtxt"), + nonce: []byte("nonce"), + } + b, err := c.Marshal() + if err != nil { + f.Fatalf("Marshal: %v", err) + } + f.Add(b) + f.Fuzz(func(t *testing.T, data []byte) { + var c1, c2 Ciphertext + if err := c1.UnmarshalVerify(data, pk); err != nil { + t.Skip() + } + data1, err := c1.Marshal() + if err != nil { + t.Fatalf("Cannot marshal: data=%v err=%v", data, err) + } + if err := c2.UnmarshalVerify(data1, pk); err != nil { + t.Fatalf("Cannot unmarshal: data=%v err=%v", data1, err) + } + data2, err := c2.Marshal() + if err != nil { + t.Fatalf("Cannot marshal: data=%v err=%v", data2, err) + } + if !bytes.Equal(data1, data2) { + t.Errorf("data1=%v data2=%v", data1, data2) + } + if !bytes.Equal(c1.symCtxt, c2.symCtxt) { + t.Errorf("c1.symCtxt=%v data1=%v c2.symCtxt=%v data2=%v", c1.symCtxt, data1, c2.symCtxt, data2) + + } + if !bytes.Equal(c1.nonce, c2.nonce) { + t.Errorf("c1.nonce=%v data1=%v c2.nonce=%v data2=%v", c1.nonce, data1, c2.nonce, data2) + + } + if !c1.tdh2Ctxt.Equal(c2.tdh2Ctxt) { + t.Errorf("c1.tdh2Ctxt=%v data1=%v c2.tdh2Ctxt=%v data2=%v", c1.tdh2Ctxt, data1, c2.tdh2Ctxt, data2) + } + }) +} + +// TestEncryptWithLabel ensures non-empty labels are preserved, and default Encrypt uses empty label. +func TestEncryptWithLabel(t *testing.T) { + _, pk, _, err := GenerateKeys(2, 3) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + var label [tdh2.InputSize]byte + for i := range label { + label[i] = byte(i + 1) + } + c, err := EncryptWithLabel(pk, []byte("msg"), label) + if err != nil { + t.Fatalf("EncryptWithLabel: %v", err) + } + if got := c.Label(); got != label { + t.Errorf("label mismatch got=%v want=%v", got, label) + } + // Ensure regular Encrypt produces all-zero label. + cZero, err := Encrypt(pk, []byte("msg")) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + if got := cZero.Label(); got != [tdh2.InputSize]byte{} { + t.Errorf("expected zero label got=%v", got) + } +}