From 5fe7fe179ce9bd17524ef1fdabe55ea77665b5fb Mon Sep 17 00:00:00 2001 From: schmidtw Date: Fri, 7 Feb 2025 16:36:26 -0800 Subject: [PATCH] BREAKING CHANGE:Streamline to align better with jwk api. BREAKING CHANGE! A few of the options go away in favor of leaning on the jws/jwk library and jwskeychain to provide more of that verification. Removed: - RequirePolicies - Require - TrustRootCAs - Verifier - SignWith - SignWithRaw Added: - Signer - NewSigner - SignWithX509Chain - SignWithKey - Decoder - NewDecoder - WithKeyProvider - WithKeySet - WithKeyUsed - WithVerifyAuto This was done to incorporate the flexability available from the jws/jwk libraries but limit some of the things that could be configured that would break this library. --- decode.go | 25 ++++---- decodeopts.go | 85 +++++++++++--------------- encode_decode_test.go | 127 ++++++++++++++++++++++++++------------ go.mod | 2 +- go.sum | 6 +- securly.go | 4 +- sign.go | 42 +++---------- signopts.go | 138 ++++++++++++++---------------------------- signopts_test.go | 115 ++++++++++++++++++----------------- 9 files changed, 253 insertions(+), 291 deletions(-) diff --git a/decode.go b/decode.go index 007ab01..e71fc73 100644 --- a/decode.go +++ b/decode.go @@ -5,7 +5,6 @@ package securly import ( "github.com/lestrrat-go/jwx/v2/jws" - "github.com/xmidt-org/jwskeychain" ) // Decode converts a slice of bytes into a *Message if possible. Depending on @@ -16,27 +15,25 @@ import ( // message. If you want to skip this verification, you can pass the // NoVerification() option. func Decode(buf []byte, opts ...DecoderOption) (*Message, error) { - p, err := newDecoder(opts...) + p, err := NewDecoder(opts...) if err != nil { return nil, err } - return p.decode(buf) + return p.Decode(buf) } -// decoder contains the configuration for decoding a set of messages. -type decoder struct { +// Decoder contains the configuration for decoding a set of messages. +type Decoder struct { noVerification bool - opts []jwskeychain.Option - provider *jwskeychain.Provider + verifyOpts []jws.VerifyOption } -// newDecoder converts a slice of bytes plus options into a Message. -func newDecoder(opts ...DecoderOption) (*decoder, error) { - var p decoder +// NewDecoder converts a slice of bytes plus options into a Message. +func NewDecoder(opts ...DecoderOption) (*Decoder, error) { + var p Decoder vadors := []DecoderOption{ - createTrust(), validateRoots(), } @@ -52,8 +49,8 @@ func newDecoder(opts ...DecoderOption) (*decoder, error) { return &p, nil } -// decode converts a slice of bytes into a *Message if possible. -func (p *decoder) decode(buf []byte) (*Message, error) { +// Decode converts a slice of bytes into a *Message if possible. +func (p *Decoder) Decode(buf []byte) (*Message, error) { JWS, err := decompress(buf) if err != nil { return nil, err @@ -68,7 +65,7 @@ func (p *decoder) decode(buf []byte) (*Message, error) { payload = trusted.Payload() } } else { - payload, err = jws.Verify(JWS, jws.WithKeyProvider(p.provider)) + payload, err = jws.Verify(JWS, p.verifyOpts...) } if err != nil { return nil, err diff --git a/decodeopts.go b/decodeopts.go index 9cf575d..9710b07 100644 --- a/decodeopts.go +++ b/decodeopts.go @@ -4,94 +4,81 @@ package securly import ( - "context" - "crypto/x509" "fmt" - "time" - "github.com/xmidt-org/jwskeychain" + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/lestrrat-go/jwx/v2/jws" ) // Option is a functional option for the Instructions constructor. type DecoderOption interface { - apply(*decoder) error + apply(*Decoder) error } -type decoderOptionFunc func(*decoder) error +type errDecoderOptionFunc func(*Decoder) error -func (f decoderOptionFunc) apply(p *decoder) error { +func (f errDecoderOptionFunc) apply(p *Decoder) error { return f(p) } -// TrustRootCAs specifies a list of root CAs to trust when verifying the signature. -func TrustRootCAs(certs ...*x509.Certificate) DecoderOption { - return decoderOptionFunc(func(p *decoder) error { - p.opts = append(p.opts, jwskeychain.TrustedRoots(certs...)) +func decoderOptionFunc(f func(*Decoder)) DecoderOption { + return errDecoderOptionFunc(func(p *Decoder) error { + f(p) return nil }) } -// RequirePolicies specifies a list of policies that must be present in the -// signing chain intermediates. -func RequirePolicies(policies ...string) DecoderOption { - return decoderOptionFunc(func(p *decoder) error { - p.opts = append(p.opts, jwskeychain.RequirePolicies(policies...)) - return nil +func verifyOptionFunc(opt jws.VerifyOption) DecoderOption { + return decoderOptionFunc(func(p *Decoder) { + if opt != nil { + p.verifyOpts = append(p.verifyOpts, opt) + } }) } -// Verifier is an interface that defines a function to verify a certificate chain. -type Verifier interface { - Verify(ctx context.Context, chain []*x509.Certificate, now time.Time) error +// WithKeyProvider enables using a jws.KeyProvider. See [jws.WithKeyProvider] +// for more details. +// +// It is likely you will want to use this option with [jwskeychain.Provider] +// package. +func WithKeyProvider(provider jws.KeyProvider) DecoderOption { + return verifyOptionFunc(jws.WithKeyProvider(provider)) } -// VerifierFunc is a function type that implements the Verifier interface. -type VerifierFunc func(ctx context.Context, chain []*x509.Certificate, now time.Time) error - -func (vf VerifierFunc) Verify(ctx context.Context, chain []*x509.Certificate, now time.Time) error { - return vf(ctx, chain, now) +// WithKeySet enables using a jwk.Set. See [jws.WithKeySet] for more details. +func WithKeySet(set jwk.Set, options ...jws.WithKeySetSuboption) DecoderOption { + return verifyOptionFunc(jws.WithKeySet(set, options...)) } -// _ is a compile-time assertion that VerifierFunc implements the Verifier interface. -var _ jwskeychain.Verifier = VerifierFunc(nil) +// WithKeyUsed enables using the [jws.WithKeyUsed] option. See [jws.WithKeyUsed] +// for more details. +func WithKeyUsed(v any) DecoderOption { + return verifyOptionFunc(jws.WithKeyUsed(v)) +} -// Require provides a way to provide a custom verifier for the certificate chain. -func Require(v Verifier) DecoderOption { - return decoderOptionFunc(func(p *decoder) error { - p.opts = append(p.opts, jwskeychain.Require(v)) - return nil - }) +// WithVerifyAuto enables using the [jws.WithVerifyAuto] option. See +// [jws.WithVeriftyAuto] for more details. +func WithVerifyAuto(f jwk.Fetcher, options ...jwk.FetchOption) DecoderOption { + return verifyOptionFunc(jws.WithVerifyAuto(f, options...)) } // NoVerification does not verify the signature or credentials, but decodes // the Message. Generally this is only useful if testing. DO NOT use this in // production. This will intentionally conflict with the TrustedRootCA() option. func NoVerification() DecoderOption { - return decoderOptionFunc(func(p *decoder) error { + return decoderOptionFunc(func(p *Decoder) { p.noVerification = true - return nil }) } // ------------------------------------------------------------------------------ -func createTrust() DecoderOption { - return decoderOptionFunc(func(p *decoder) error { - trusted, err := jwskeychain.New(p.opts...) - if err != nil { - return err - } - p.provider = trusted - return nil - }) -} - func validateRoots() DecoderOption { - return decoderOptionFunc(func(p *decoder) error { - if p.noVerification || len(p.provider.Roots()) > 0 { + return errDecoderOptionFunc(func(p *Decoder) error { + if p.noVerification || len(p.verifyOpts) > 0 { return nil } - return fmt.Errorf("no trusted root CAs provided") + return fmt.Errorf("no valid sources of trust") }) } diff --git a/encode_decode_test.go b/encode_decode_test.go index a73efcf..121124c 100644 --- a/encode_decode_test.go +++ b/encode_decode_test.go @@ -11,8 +11,10 @@ import ( "time" "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/lestrrat-go/jwx/v2/jws" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/xmidt-org/jwskeychain" ) type encodeDecodeTest struct { @@ -28,21 +30,25 @@ type encodeDecodeTest struct { var simpleWorking = encodeDecodeTest{ desc: "simple, working", encOpts: []SignOption{ - SignWithRaw(jwa.ES256, chainA.Leaf().Public, chainA.Leaf().Private, chainA.Included()...), + SignWithX509Chain(jwa.ES256, chainA.Leaf().Private, chainA.Included()), }, input: Message{ Payload: []byte("Hello, world."), }, decOpts: []DecoderOption{ - TrustRootCAs(chainA.Root().Public), - RequirePolicies("1.2.100"), + WithKeyProvider( + must(jwskeychain.New( + jwskeychain.TrustedRoots(chainA.Root().Public), + jwskeychain.RequirePolicies("1.2.100"), + )).(jws.KeyProvider), + ), }, } var complexWorking = encodeDecodeTest{ desc: "complex, working", encOpts: []SignOption{ - SignWithRaw(jwa.ES256, chainA.Leaf().Public, chainA.Leaf().Private, chainA.Included()...), + SignWithX509Chain(jwa.ES256, chainA.Leaf().Private, chainA.Included()), }, input: Message{ Payload: []byte("Hello, world."), @@ -60,14 +66,17 @@ var complexWorking = encodeDecodeTest{ }, }, decOpts: []DecoderOption{ - TrustRootCAs(chainA.Root().Public), - RequirePolicies("1.2.100"), - Require( - // Show this verifier is called and works. - VerifierFunc( - func(_ context.Context, _ []*x509.Certificate, _ time.Time) error { - return nil - })), + WithKeyProvider( + must(jwskeychain.New( + jwskeychain.TrustedRoots(chainA.Root().Public), + jwskeychain.RequirePolicies("1.2.100"), + jwskeychain.Require( + jwskeychain.VerifierFunc( + func(_ context.Context, _ []*x509.Certificate, _ time.Time) error { + return nil + })), + )).(jws.KeyProvider), + ), }, } @@ -77,7 +86,7 @@ var encodeDecodeTests = []encodeDecodeTest{ { desc: "No trusted roots", encOpts: []SignOption{ - SignWithRaw(jwa.ES256, chainA.Leaf().Public, chainA.Leaf().Private, chainA.Included()...), + SignWithX509Chain(jwa.ES256, chainA.Leaf().Private, chainA.Included()), }, input: Message{ Payload: []byte("Hello, world."), @@ -86,19 +95,23 @@ var encodeDecodeTests = []encodeDecodeTest{ }, { desc: "Require() not met", encOpts: []SignOption{ - SignWithRaw(jwa.ES256, chainA.Leaf().Public, chainA.Leaf().Private, chainA.Included()...), + SignWithX509Chain(jwa.ES256, chainA.Leaf().Private, chainA.Included()), }, input: Message{ Payload: []byte("Hello, world."), }, decOpts: []DecoderOption{ - TrustRootCAs(chainA.Root().Public), - RequirePolicies("1.2.100"), - Require( - VerifierFunc( - func(_ context.Context, _ []*x509.Certificate, _ time.Time) error { - return errors.New("custom verifier failed") - })), + WithKeyProvider( + must(jwskeychain.New( + jwskeychain.TrustedRoots(chainA.Root().Public), + jwskeychain.RequirePolicies("1.2.100"), + jwskeychain.Require( + jwskeychain.VerifierFunc( + func(_ context.Context, _ []*x509.Certificate, _ time.Time) error { + return errors.New("custom verifier failed") + })), + )).(jws.KeyProvider), + ), }, decErr: errUnknown, }, { @@ -115,7 +128,7 @@ var encodeDecodeTests = []encodeDecodeTest{ }, { desc: "Signature with NoVerification(), working", encOpts: []SignOption{ - SignWithRaw(jwa.ES256, chainA.Leaf().Public, chainA.Leaf().Private, chainA.Included()...), + SignWithX509Chain(jwa.ES256, chainA.Leaf().Private, chainA.Included()), }, input: Message{ Payload: []byte("Hello, world."), @@ -132,23 +145,18 @@ var encodeDecodeTests = []encodeDecodeTest{ Payload: []byte("Hello, world."), }, decOpts: []DecoderOption{ - TrustRootCAs(chainA.Root().Public), + WithKeyProvider( + must(jwskeychain.New( + jwskeychain.TrustedRoots(chainA.Root().Public), + )).(jws.KeyProvider), + ), }, decErr: errUnknown, - }, { - desc: "Try using signing algorith none, should fail", - encOpts: []SignOption{ - SignWith("none", nil, nil), - }, - input: Message{ - Payload: []byte("Hello, world."), - }, - encErr: ErrInvalidSignAlg, }, { desc: "Try setting multiple signing algorithms, should fail", encOpts: []SignOption{ NoSignature(), - SignWithRaw(jwa.ES256, chainA.Leaf().Public, chainA.Leaf().Private, chainA.Included()...), + SignWithX509Chain(jwa.ES256, chainA.Leaf().Private, chainA.Included()), }, input: Message{ Payload: []byte("Hello, world."), @@ -157,7 +165,7 @@ var encodeDecodeTests = []encodeDecodeTest{ }, { desc: "invalid response encryption algorithm", encOpts: []SignOption{ - SignWithRaw(jwa.ES256, chainA.Leaf().Public, chainA.Leaf().Private, chainA.Included()...), + SignWithX509Chain(jwa.ES256, chainA.Leaf().Private, chainA.Included()), }, input: Message{ Payload: []byte("Hello, world."), @@ -170,7 +178,7 @@ var encodeDecodeTests = []encodeDecodeTest{ }, { desc: "unsafe response encryption algorithm in the clear is not allowed", encOpts: []SignOption{ - SignWithRaw(jwa.ES256, chainA.Leaf().Public, chainA.Leaf().Private, chainA.Included()...), + SignWithX509Chain(jwa.ES256, chainA.Leaf().Private, chainA.Included()), }, input: Message{ Payload: []byte("Hello, world."), @@ -183,7 +191,7 @@ var encodeDecodeTests = []encodeDecodeTest{ }, { desc: "invalid response encryption key/alg combination", encOpts: []SignOption{ - SignWithRaw(jwa.ES256, chainA.Leaf().Public, chainA.Leaf().Private, chainA.Included()...), + SignWithX509Chain(jwa.ES256, chainA.Leaf().Private, chainA.Included()), }, input: Message{ Payload: []byte("Hello, world."), @@ -196,13 +204,49 @@ var encodeDecodeTests = []encodeDecodeTest{ }, { desc: "untrusted chain", encOpts: []SignOption{ - SignWithRaw(jwa.ES256, chainB.Leaf().Public, chainB.Leaf().Private, chainB.Included()...), + SignWithX509Chain(jwa.ES256, chainB.Leaf().Private, chainB.Included()), + }, + input: Message{ + Payload: []byte("Hello, world."), + }, + decOpts: []DecoderOption{ + WithKeyProvider( + must(jwskeychain.New( + jwskeychain.TrustedRoots(chainA.Root().Public), + )).(jws.KeyProvider), + ), + }, + decErr: errUnknown, + }, { + desc: "untrusted chain (mixed B intermediates)", + encOpts: []SignOption{ + SignWithX509Chain(jwa.ES256, chainA.Leaf().Private, chainB.Included()), + }, + input: Message{ + Payload: []byte("Hello, world."), + }, + decOpts: []DecoderOption{ + WithKeyProvider( + must(jwskeychain.New( + jwskeychain.TrustedRoots(chainA.Root().Public), + )).(jws.KeyProvider), + ), + }, + decErr: errUnknown, + }, { + desc: "untrusted chain (mixed A intermediates)", + encOpts: []SignOption{ + SignWithX509Chain(jwa.ES256, chainB.Leaf().Private, chainA.Included()), }, input: Message{ Payload: []byte("Hello, world."), }, decOpts: []DecoderOption{ - TrustRootCAs(chainA.Root().Public), + WithKeyProvider( + must(jwskeychain.New( + jwskeychain.TrustedRoots(chainA.Root().Public), + )).(jws.KeyProvider), + ), }, decErr: errUnknown, }, @@ -278,3 +322,10 @@ func runEncDecTest(t *testing.T, tt encodeDecodeTest) { require.Equal(want, msg) }) } + +func must(a any, err error) any { + if err != nil { + panic(err) + } + return a +} diff --git a/go.mod b/go.mod index 366dd17..a17ace6 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/lestrrat-go/jwx/v2 v2.1.3 github.com/stretchr/testify v1.10.0 github.com/tinylib/msgp v1.2.5 - github.com/xmidt-org/jwskeychain v1.1.0 + github.com/xmidt-org/jwskeychain v1.2.0 ) require ( diff --git a/go.sum b/go.sum index 27ce311..0b04f0a 100644 --- a/go.sum +++ b/go.sum @@ -30,12 +30,10 @@ github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOf github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tinylib/msgp v1.2.5 h1:WeQg1whrXRFiZusidTQqzETkRpGjFjcIhW6uqWH09po= github.com/tinylib/msgp v1.2.5/go.mod h1:ykjzy2wzgrlvpDCRc4LA8UXy6D8bzMSuAF3WD57Gok0= -github.com/xmidt-org/jwskeychain v1.1.0 h1:WhC6AcVMcy5IuLWbcYp//GGRPDmOpXmLQS02NPNdUwc= -github.com/xmidt-org/jwskeychain v1.1.0/go.mod h1:aDQ9lGHwJYxCubgeJGXQQXGFkL3ZZStvfEHFEh4MO+Y= +github.com/xmidt-org/jwskeychain v1.2.0 h1:5lLSNon/6po9ObuJTNAwHzQU8D9HIjknRHgLxMmyUWU= +github.com/xmidt-org/jwskeychain v1.2.0/go.mod h1:aDQ9lGHwJYxCubgeJGXQQXGFkL3ZZStvfEHFEh4MO+Y= golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= -golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= -golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= diff --git a/securly.go b/securly.go index 6e6e87e..ec76f88 100644 --- a/securly.go +++ b/securly.go @@ -130,12 +130,12 @@ func (m Message) Encode() (data []byte, isEncrypted bool, err error) { // Sign converts a Message into a slice of bytes and signs it using the // provided options. func (m Message) Sign(opts ...SignOption) ([]byte, error) { - enc, err := newEncoder(opts...) + enc, err := NewSigner(opts...) if err != nil { return nil, err } - return enc.encode(m) + return enc.Encode(m) } // Encrypt encrypts the message using the provided options. diff --git a/sign.go b/sign.go index 751bf8a..7fc7ada 100644 --- a/sign.go +++ b/sign.go @@ -4,26 +4,19 @@ package securly import ( - "crypto/x509" - "encoding/base64" - - "github.com/lestrrat-go/jwx/v2/cert" - "github.com/lestrrat-go/jwx/v2/jwa" - "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jws" ) -type encoder struct { +// Signer is a type that can encode & sign a Message. +type Signer struct { doNotSign bool - leaf *x509.Certificate - intermediates []*x509.Certificate - signAlg jwa.SignatureAlgorithm - key jwk.Key + key jws.SignVerifyOption skipResponseCheck bool } -func newEncoder(opts ...SignOption) (*encoder, error) { - enc := encoder{} +// NewSigner creates a new Signer with the given options. +func NewSigner(opts ...SignOption) (*Signer, error) { + enc := Signer{} opts = append(opts, validateSigAlg()) @@ -39,7 +32,8 @@ func newEncoder(opts ...SignOption) (*encoder, error) { return &enc, nil } -func (enc *encoder) encode(m Message) ([]byte, error) { +// Encode encodes the given Message and signs it. +func (enc *Signer) Encode(m Message) ([]byte, error) { if err := m.Response.safeInTheClear(); err != nil { return nil, err } @@ -62,26 +56,8 @@ func (enc *encoder) encode(m Message) ([]byte, error) { return nil, err } } else { - // Build certificate chain. - var chain cert.Chain - for _, cert := range append([]*x509.Certificate{enc.leaf}, enc.intermediates...) { - err = chain.AddString(base64.URLEncoding.EncodeToString(cert.Raw)) - if err != nil { - return nil, err - } - } - - // Create headers and set x5c with certificate chain. - headers := jws.NewHeaders() - err = headers.Set(jws.X509CertChainKey, &chain) - if err != nil { - return nil, err - } - - key := jws.WithKey(enc.signAlg, enc.key, jws.WithProtectedHeaders(headers)) - // Sign the inner payload with the private key. - signed, err = jws.Sign(data, key) + signed, err = jws.Sign(data, enc.key) if err != nil { return nil, err } diff --git a/signopts.go b/signopts.go index 34c8b2e..7b1b635 100644 --- a/signopts.go +++ b/signopts.go @@ -8,123 +8,75 @@ import ( "fmt" "github.com/lestrrat-go/jwx/v2/jwa" - "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/lestrrat-go/jwx/v2/jws" + "github.com/xmidt-org/jwskeychain" ) // SignOption is a functional option for the Instructions constructor. type SignOption interface { - apply(*encoder) error + apply(*Signer) error } -// SignWith sets the signing algorithm, public key, and private key used to sign -// the Message, as well as any intermediaries. -// -// The following combinations are valid (the public/private keys must match): -// - ES256, private: *ecdsa.PrivateKey -// - ES384, private: *ecdsa.PrivateKey -// - ES512, private: *ecdsa.PrivateKey -// - RS256, private: *rsa.PrivateKey -// - RS384, private: *rsa.PrivateKey -// - RS512, private: *rsa.PrivateKey -// - PS256, private: *rsa.PrivateKey -// - PS384, private: *rsa.PrivateKey -// - PS512, private: *rsa.PrivateKey -// - EdDSA, private: ed25519.PrivateKey -// -// Unfortunately, to make this work the private type needs to be an interface{}. -func SignWith(alg jwa.SignatureAlgorithm, - public *x509.Certificate, private jwk.Key, - intermediates ...*x509.Certificate, -) SignOption { - return signAlgOption{ - alg: alg, - public: public, - key: private, - intermediates: intermediates, - } -} - -func SignWithRaw(alg jwa.SignatureAlgorithm, - public *x509.Certificate, private any, - intermediates ...*x509.Certificate, -) SignOption { - key, err := jwk.FromRaw(private) - if err != nil { - return errorSign(err) - } +type errSignOptionFunc func(*Signer) error - return SignWith(alg, public, key, intermediates...) +func (f errSignOptionFunc) apply(enc *Signer) error { + return f(enc) } -type signAlgOption struct { - alg jwa.SignatureAlgorithm - public *x509.Certificate - key jwk.Key - intermediates []*x509.Certificate +func signOptionFunc(f func(*Signer)) SignOption { + return errSignOptionFunc(func(enc *Signer) error { + f(enc) + return nil + }) } -func (s signAlgOption) apply(enc *encoder) error { - if s.alg.IsSymmetric() || s.alg == jwa.NoSignature { - return ErrInvalidSignAlg - } +// SignWithX509Chain sets the signing algorithm, public key, and private key +// used to sign the Message, as well as any intermediaries. See +// [jwskeychain.Signer] for more details. +func SignWithX509Chain(alg jwa.SignatureAlgorithm, private any, certs []*x509.Certificate) SignOption { + return errSignOptionFunc(func(enc *Signer) error { + key, err := jwskeychain.Signer(alg, private, certs) + if err != nil { + return err + } - enc.signAlg = s.alg - enc.leaf = s.public - enc.key = s.key - enc.intermediates = s.intermediates - return nil + enc.key = key + return nil + }) +} + +// SignWithKey creates a signing key for the Message. See [jws.WithKey] for more +// details about how to use this option. +func SignWithKey(alg jwa.SignatureAlgorithm, key any, opts ...jws.WithKeySuboption) SignOption { + return signOptionFunc(func(enc *Signer) { + enc.key = jws.WithKey(alg, key, opts...) + }) } // NoSignature indicates that the Message should not be signed. It cannot be used // with any SignWith options or an error will be returned. This is to ensure that // the caller is aware that the Message will not be signed. func NoSignature() SignOption { - return noSignatureOption{} -} - -type noSignatureOption struct{} - -func (n noSignatureOption) apply(enc *encoder) error { - enc.doNotSign = true - return nil + return signOptionFunc(func(enc *Signer) { + enc.doNotSign = true + }) } //------------------------------------------------------------------------------ -func errorSign(err error) SignOption { - return errorSignOption{ - err: err, - } -} - -type errorSignOption struct { - err error -} - -func (e errorSignOption) apply(*encoder) error { - return e.err -} - func validateSigAlg() SignOption { - return validateSigAlgOption{} -} - -type validateSigAlgOption struct{} - -func (v validateSigAlgOption) apply(enc *encoder) error { - if enc.doNotSign { - if enc.signAlg != "" || - enc.leaf != nil || - enc.key != nil || - len(enc.intermediates) > 0 { - return fmt.Errorf("%w: NoSignature() must be used in isolation", ErrInvalidSignAlg) + return errSignOptionFunc(func(enc *Signer) error { + if enc.doNotSign { + if enc.key != nil { + return fmt.Errorf("%w: NoSignature() must be used in isolation", ErrInvalidSignAlg) + } + return nil } - return nil - } - if enc.signAlg == "" || enc.key == nil { - return fmt.Errorf("%w: algorithm and key are required", ErrInvalidSignAlg) - } + if enc.key == nil { + return fmt.Errorf("%w: key is required", ErrInvalidSignAlg) + } - return nil + return nil + }) } diff --git a/signopts_test.go b/signopts_test.go index 01039df..c80e2ff 100644 --- a/signopts_test.go +++ b/signopts_test.go @@ -4,91 +4,91 @@ package securly import ( - "crypto/ecdsa" - "crypto/ed25519" - "crypto/elliptic" "crypto/rand" "crypto/rsa" "crypto/x509" "testing" "github.com/lestrrat-go/jwx/v2/jwa" - "github.com/lestrrat-go/jwx/v2/jwk" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestSignWithRaw(t *testing.T) { - // Generate test keys - rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048) - require.NoError(t, err) +func TestSignWithX509(t *testing.T) { + chain := mustGeneratecertChain("leaf<-ica<-root") - ecdsaPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - require.NoError(t, err) + tests := []struct { + name string + alg jwa.SignatureAlgorithm + private any + certs []*x509.Certificate + err bool + }{ + { + name: "valid RS256 key and chain", + alg: jwa.RS256, + private: chain.Leaf().Private, + certs: chain.Included(), + }, { + name: "invalid symmetric key", + alg: jwa.HS256, + private: chain.Leaf().Private, + certs: chain.Included(), + err: true, + }, + } - _, ed25519PrivateKey, err := ed25519.GenerateKey(rand.Reader) - require.NoError(t, err) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + opt := SignWithX509Chain(tt.alg, tt.private, tt.certs) - // Generate a test certificate - cert := &x509.Certificate{} + assert.NotNil(t, opt) + + var enc Signer + err := opt.apply(&enc) + + if tt.err { + assert.Error(t, err) + return + } + + assert.NoError(t, err) + }) + } +} + +func TestSignWithKey(t *testing.T) { + // Generate test keys + rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) tests := []struct { - name string - alg jwa.SignatureAlgorithm - privateKey any - expectErr bool - expectedAlg jwa.SignatureAlgorithm - expectedKey any + name string + alg jwa.SignatureAlgorithm + privateKey any }{ { - name: "valid RS256 key", - alg: jwa.RS256, - privateKey: rsaPrivateKey, - expectErr: false, - expectedAlg: jwa.RS256, - }, - { - name: "valid ES256 key", - alg: jwa.ES256, - privateKey: ecdsaPrivateKey, - expectErr: false, - expectedAlg: jwa.ES256, - }, - { - name: "valid EdDSA key", - alg: jwa.EdDSA, - privateKey: ed25519PrivateKey, - expectErr: false, - expectedAlg: jwa.EdDSA, - }, - { - name: "invalid key type", + name: "valid RS256 key", alg: jwa.RS256, - privateKey: "invalid-key", - expectErr: true, + privateKey: rsaPrivateKey, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assert := assert.New(t) - require := require.New(t) + opt := SignWithKey(tt.alg, tt.privateKey) - opt := SignWithRaw(tt.alg, cert, tt.privateKey) - enc := &encoder{} + assert.NotNil(t, opt) - err := opt.apply(enc) - if tt.expectErr { - assert.Error(err) - } else { - require.NoError(err) - assert.Equal(tt.expectedAlg, enc.signAlg) - assert.NotNil(enc.key) - } + var enc Signer + err := opt.apply(&enc) + + assert.NoError(t, err) }) } } +/* func TestValidateSigAlg(t *testing.T) { // Generate a test certificate cert := &x509.Certificate{} @@ -139,7 +139,7 @@ func TestValidateSigAlg(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - enc := &encoder{ + enc := &Encoder{ signAlg: tt.signAlg, doNotSign: tt.doNotSign, leaf: tt.leaf, @@ -158,3 +158,4 @@ func TestValidateSigAlg(t *testing.T) { }) } } +*/