From b73f618bb721fdebc8ae7657a053604046cb3b1c Mon Sep 17 00:00:00 2001 From: wecha Date: Tue, 10 Jun 2025 20:23:10 +0000 Subject: [PATCH 1/7] add poptoken support --- go.mod | 12 + go.sum | 24 + pkg/auth/auth.go | 15 + pkg/auth/faketokenprovider.go | 1 + pkg/auth/poptoken/jwkmanager.go | 60 ++ pkg/auth/poptoken/msalauthprovider.go | 94 ++++ pkg/auth/poptoken/nodeagentpoptokenscheme.go | 66 +++ .../poptoken/nodeagentpoptokenscheme_test.go | 61 +++ .../poptoken/nodeagentpoptokenvalidator.go | 327 +++++++++++ .../nodeagentpoptokenvalidator_test.go | 513 ++++++++++++++++++ pkg/auth/poptoken/poptokenauth.go | 44 ++ pkg/auth/poptoken/raskeymanager_test.go | 73 +++ pkg/auth/poptoken/rsakeymanager.go | 120 ++++ pkg/auth/poptoken/shrpoptoken.go | 206 +++++++ pkg/auth/poptoken/shrpoptoken_test.go | 211 +++++++ 15 files changed, 1827 insertions(+) create mode 100644 pkg/auth/poptoken/jwkmanager.go create mode 100644 pkg/auth/poptoken/msalauthprovider.go create mode 100644 pkg/auth/poptoken/nodeagentpoptokenscheme.go create mode 100644 pkg/auth/poptoken/nodeagentpoptokenscheme_test.go create mode 100644 pkg/auth/poptoken/nodeagentpoptokenvalidator.go create mode 100644 pkg/auth/poptoken/nodeagentpoptokenvalidator_test.go create mode 100644 pkg/auth/poptoken/poptokenauth.go create mode 100644 pkg/auth/poptoken/raskeymanager_test.go create mode 100644 pkg/auth/poptoken/rsakeymanager.go create mode 100644 pkg/auth/poptoken/shrpoptoken.go create mode 100644 pkg/auth/poptoken/shrpoptoken_test.go diff --git a/go.mod b/go.mod index b3d4e591..306394fd 100644 --- a/go.mod +++ b/go.mod @@ -3,12 +3,15 @@ module github.com/microsoft/moc go 1.23.0 require ( + github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2 github.com/go-logr/logr v1.4.2 github.com/golang-jwt/jwt/v4 v4.5.2 + github.com/golang-jwt/jwt/v5 v5.0.0 github.com/golang/mock v1.6.0 github.com/golang/protobuf v1.5.4 github.com/hectane/go-acl v0.0.0-20230122075934-ca0b05cb1adb github.com/jmespath/go-jmespath v0.4.0 + github.com/lestrrat-go/jwx v1.2.31 github.com/pkg/errors v0.9.1 github.com/stretchr/testify v1.10.0 go.uber.org/multierr v1.11.0 @@ -19,8 +22,17 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 // indirect github.com/go-jose/go-jose/v4 v4.0.4 // indirect + github.com/goccy/go-json v0.10.3 // indirect + github.com/google/uuid v1.6.0 // indirect github.com/kr/text v0.2.0 // indirect + github.com/kylelemons/godebug v1.1.0 // indirect + github.com/lestrrat-go/backoff/v2 v2.0.8 // indirect + github.com/lestrrat-go/blackmagic v1.0.2 // indirect + github.com/lestrrat-go/httpcc v1.0.1 // indirect + github.com/lestrrat-go/iter v1.0.2 // indirect + github.com/lestrrat-go/option v1.0.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spiffe/go-spiffe/v2 v2.5.0 // indirect github.com/zeebo/errs v1.4.0 // indirect diff --git a/go.sum b/go.sum index 21decf8e..ca7bc80c 100644 --- a/go.sum +++ b/go.sum @@ -1,15 +1,23 @@ +github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2 h1:oygO0locgZJe7PpYPXT5A29ZkwJaPqcva7BVeemZOZs= +github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 h1:NMZiJj8QnKe1LgsbDayM4UoHwbvwDRwnI3hwNaAHRnc= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0/go.mod h1:ZXNYxsqcloTdSy/rNShjYzMhyjf0LaoftYK0p+A3h40= github.com/go-jose/go-jose/v4 v4.0.4 h1:VsjPI33J0SB9vQM6PLmNjoHqMQNGPiZ0rHL7Ni7Q6/E= github.com/go-jose/go-jose/v4 v4.0.4/go.mod h1:NKb5HO1EZccyMpiZNbdUw/14tiXNyUJh188dfnMCAfc= github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= +github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI= github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE= +github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= @@ -28,6 +36,21 @@ github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/lestrrat-go/backoff/v2 v2.0.8 h1:oNb5E5isby2kiro9AgdHLv5N5tint1AnDVVf2E2un5A= +github.com/lestrrat-go/backoff/v2 v2.0.8/go.mod h1:rHP/q/r9aT27n24JQLa7JhSQZCKBBOiM/uP402WwN8Y= +github.com/lestrrat-go/blackmagic v1.0.2 h1:Cg2gVSc9h7sz9NOByczrbUvLopQmXrfFx//N+AkAr5k= +github.com/lestrrat-go/blackmagic v1.0.2/go.mod h1:UrEqBzIR2U6CnzVyUtfM6oZNMt/7O7Vohk2J0OGSAtU= +github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE= +github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E= +github.com/lestrrat-go/iter v1.0.2 h1:gMXo1q4c2pHmC3dn8LzRhJfP1ceCbgSiT9lUydIzltI= +github.com/lestrrat-go/iter v1.0.2/go.mod h1:Momfcq3AnRlRjI5b5O8/G5/BvpzrhoFTZcn06fEOPt4= +github.com/lestrrat-go/jwx v1.2.31 h1:/OM9oNl/fzyldpv5HKZ9m7bTywa7COUfg8gujd9nJ54= +github.com/lestrrat-go/jwx v1.2.31/go.mod h1:eQJKoRwWcLg4PfD5CFA5gIZGxhPgoPYq9pZISdxLf0c= +github.com/lestrrat-go/option v1.0.0/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I= +github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU= +github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -35,6 +58,7 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/spiffe/go-spiffe/v2 v2.5.0 h1:N2I01KCUkv1FAjZXJMwh95KK1ZIQLYbPfhaxw8WS0hE= github.com/spiffe/go-spiffe/v2 v2.5.0/go.mod h1:P+NxobPc6wXhVtINNtFjNWGBTreew1GBUCwT2wPmb7g= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 38543f18..679f1df9 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -14,6 +14,7 @@ import ( "io/ioutil" "strings" + "github.com/microsoft/moc/pkg/auth/poptoken" "github.com/microsoft/moc/pkg/config" "github.com/microsoft/moc/pkg/marshal" "github.com/microsoft/moc/rpc/common" @@ -261,6 +262,20 @@ func NewPopTokenAuthorizer() (*CmpPopTokenAuthorizer, error) { }, nil } +// TODO wecha: This will the new constructor for the poptoken authorizer and replace existing +// NewPopTokenAuthorizer(). We can't switch to the new ctor until we update wssdcloud agent. +func NewPopTokenAuthorizerPopAuth(popTokenAuth *poptoken.PopTokenAuth) (*CmpPopTokenAuthorizer, error) { + tp, err := DisableTransportAuthorization() + if err != nil { + return nil, err + } + + return &CmpPopTokenAuthorizer{ + transportProvider: tp, + rpcProvider: popTokenAuth, + }, nil +} + func (c *CmpPopTokenAuthorizer) WithTransportAuthorization() credentials.TransportCredentials { return c.transportProvider } diff --git a/pkg/auth/faketokenprovider.go b/pkg/auth/faketokenprovider.go index 0ee79f15..51bf0b29 100644 --- a/pkg/auth/faketokenprovider.go +++ b/pkg/auth/faketokenprovider.go @@ -7,6 +7,7 @@ import ( type fakeTokenProvider struct { } +// TODO wecha: temp fake token auth provider. To be replaced with PopTokenAuth func NewFakeTokenProvier() (*fakeTokenProvider, error) { return &fakeTokenProvider{}, nil } diff --git a/pkg/auth/poptoken/jwkmanager.go b/pkg/auth/poptoken/jwkmanager.go new file mode 100644 index 00000000..0079d55a --- /dev/null +++ b/pkg/auth/poptoken/jwkmanager.go @@ -0,0 +1,60 @@ +package poptoken + +import ( + "context" + "crypto/rsa" + "fmt" + "time" + + "github.com/lestrrat-go/jwx/jwk" + "github.com/pkg/errors" +) + +const ( + IssuerPostfix = "common/discovery/keys" + RefreshJwkInterval = time.Hour * 24 +) + +// Wrapper around jwk library to retrieve and refresh the jwk endpoints from Entra/AAD +type JwkManager struct { + // STS JWK endpoint, e.g. "https://login.microsoftonline.com/common/discovery/keys" + jwkEndpoint string + ar *jwk.AutoRefresh +} + +type JwkInterface interface { + GetPublicKey(kid string) (*rsa.PublicKey, error) +} + +func (j *JwkManager) GetPublicKey(kid string) (*rsa.PublicKey, error) { + ctx := context.Background() + keys, err := j.ar.Fetch(ctx, j.jwkEndpoint) + if err != nil { + return nil, errors.Wrapf(err, "failed to look up jwk endpoint %s to retrieve keys", j.jwkEndpoint) + } + + key, ok := keys.LookupKeyID(kid) + if !ok { + return nil, fmt.Errorf("failed to find kid %s in jwk endpoint %s", kid, j.jwkEndpoint) + } + + var pKey rsa.PublicKey + if err := key.Raw(&pKey); err != nil { + return nil, err + } + + return &pKey, nil +} + +func NewJwkManager(authorityUrl string, refreshInterval time.Duration) (*JwkManager, error) { + jwkEndpoint := appendUrl(authorityUrl, IssuerPostfix) + ctx, _ := context.WithCancel(context.Background()) + ar := jwk.NewAutoRefresh(ctx) + ar.Configure(jwkEndpoint, jwk.WithMinRefreshInterval(refreshInterval)) + _, err := ar.Refresh(ctx, jwkEndpoint) + if err != nil { + return nil, fmt.Errorf("failed to refresh the jwk endpoint %s", jwkEndpoint) + } + + return &JwkManager{jwkEndpoint: jwkEndpoint, ar: ar}, nil +} diff --git a/pkg/auth/poptoken/msalauthprovider.go b/pkg/auth/poptoken/msalauthprovider.go new file mode 100644 index 00000000..605ef94f --- /dev/null +++ b/pkg/auth/poptoken/msalauthprovider.go @@ -0,0 +1,94 @@ +package poptoken + +import ( + "context" + "os" + + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" + "github.com/pkg/errors" +) + +// Msal client to generate the pop token. Note that the msal sdk does not provide pop token support +// out of the box, refer to NodeAgentPopTokenScheme. +type MsalAuthProvider struct { + clientId string + tenantId string + authorityUrl string + scope []string + clientCertPath string + rsaKeyManager *RsaKeyManager +} + +func (m MsalAuthProvider) refreshConfidentialClient() (*confidential.Client, error) { + pemData, err := os.ReadFile(m.clientCertPath) + if err != nil { + return nil, err + } + + cert, privateKey, err := confidential.CertFromPEM(pemData, "") + if err != nil { + return nil, err + } + + // the PEM file can contain multiple intermediate certificates to validate TLS cert chaining but we are only interested in our + // own certificate which is always the first one. See https://www.rfc-editor.org/rfc/rfc5246#section-7.4.2 + cert = cert[:1] + cred, err := confidential.NewCredFromCert(cert, privateKey) + if err != nil { + return nil, errors.Wrapf(err, "failed to create confidential credential from certificate") + } + + // use Subject NAame and Issuer (SN+I) authentication to request for token. For this to work, Withx5c() must be set + // to pass the certificate chain in the request header. + confidentialClient, err := confidential.New(m.authorityUrl, m.clientId, cred, confidential.WithX5C()) + if err != nil { + return nil, errors.Wrapf(err, "failed to create confidential client") + } + + return &confidentialClient, nil +} + +func (m MsalAuthProvider) GetToken(targetResourceId string) (string, error) { + + // TODO: the underlying client certificate will be refreshed, hence we need to also pick up the new certificate + // Longer run we can cache the client but for now we will refresh the client for every token call. + confidentialClient, err := m.refreshConfidentialClient() + if err != nil { + return "", err + } + + keyPair, err := m.rsaKeyManager.GetKeyPair() + if err != nil { + return "", errors.Wrapf(err, "failed to get keypair for pop token") + } + + popTokenScheme, err := NewNodeAgentPopTokenAuthScheme(targetResourceId, keyPair) + if err != nil { + return "", errors.Wrapf(err, "failed to create new pop token scheme") + } + + result, err := confidentialClient.AcquireTokenByCredential(context.Background(), m.scope, confidential.WithAuthenticationScheme(popTokenScheme)) + if err != nil { + return "", errors.Wrapf(err, "failed to get token") + } + return result.AccessToken, nil +} + +func NewMsalClient(clientId string, tenantId, authorityUrl string, clientCertPath string, rsaKeyManager *RsaKeyManager) (*MsalAuthProvider, error) { + m := &MsalAuthProvider{ + clientId: clientId, + tenantId: tenantId, + authorityUrl: appendUrl(authorityUrl, tenantId), + clientCertPath: clientCertPath, + scope: []string{appendUrl(clientId, ".default")}, // intentionally target itself as the pop token custom claim will contain the actual audience. + rsaKeyManager: rsaKeyManager, + } + + // sanity check to ensure client is setup correctly + _, err := m.refreshConfidentialClient() + if err != nil { + return nil, err + } + + return m, nil +} diff --git a/pkg/auth/poptoken/nodeagentpoptokenscheme.go b/pkg/auth/poptoken/nodeagentpoptokenscheme.go new file mode 100644 index 00000000..08cc6776 --- /dev/null +++ b/pkg/auth/poptoken/nodeagentpoptokenscheme.go @@ -0,0 +1,66 @@ +package poptoken + +import ( + "time" +) + +const ( + tokenType = "token_type" + reqCnf = "req_cnf" + resourceId = "resourceid" +) + +// Implements the interface for MSAL SDK to callback when creating the poptoken. +// See AuthenticationScheme interface in https://github.com/AzureAD/microsoft-authentication-library-for-go/blob/main/apps/internal/oauth/ops/authority/authority.go#L146 +type NodeAgentPopTokenAuthScheme struct { + shrPopToken *ShrPopToken + keyPair *RsaKeyPair + ResourceId string +} + +// Return the claim containg the pop token kid that will be added to the Entra access token. +func (a *NodeAgentPopTokenAuthScheme) TokenRequestParams() map[string]string { + + reqCnfBase64, err := a.shrPopToken.GetReqCnf() + if err != nil { + return map[string]string{} + } + + return map[string]string{ + tokenType: a.shrPopToken.Header.Typ, + reqCnf: reqCnfBase64, + } +} + +// Return the keyId for MSAL to lookup for a cached access token. If it does not exist, MSAL will request a new access token +func (a *NodeAgentPopTokenAuthScheme) KeyID() string { + return a.shrPopToken.Header.Kid +} + +// Generate the pop token; adding in the accessToken generated by Entra. +func (a *NodeAgentPopTokenAuthScheme) FormatAccessToken(accessToken string) (string, error) { + // append accessToken and our custom claims to the pop token. + // Note custom claims should be compatible with JWT specs, we don't expect these claims to be complex + customClaims := map[string]interface{}{ + resourceId: a.ResourceId} + + return a.shrPopToken.GenerateToken(accessToken, time.Now(), customClaims) +} + +// Return the token type. Must be "pop" +func (a *NodeAgentPopTokenAuthScheme) AccessTokenType() string { + return a.shrPopToken.Header.Typ +} + +// Create a new instance of NodeAgentPopTokenAuthScheme. Pass in the custom claims to be set in the pop token here, e.g. resourceId +func NewNodeAgentPopTokenAuthScheme(resourceId string, rsaKeyPair *RsaKeyPair) (*NodeAgentPopTokenAuthScheme, error) { + shrPopToken, err := NewPopToken(rsaKeyPair) + if err != nil { + return nil, err + } + + return &NodeAgentPopTokenAuthScheme{ + shrPopToken: shrPopToken, + ResourceId: resourceId, + }, nil +} diff --git a/pkg/auth/poptoken/nodeagentpoptokenscheme_test.go b/pkg/auth/poptoken/nodeagentpoptokenscheme_test.go new file mode 100644 index 00000000..7f764253 --- /dev/null +++ b/pkg/auth/poptoken/nodeagentpoptokenscheme_test.go @@ -0,0 +1,61 @@ +package poptoken + +import ( + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +// This test suite focus on the testing nodeagentpoptokenscheme is returning the expected values that MSAL expected +// the actual token generation is tested in shrpoptoken_test +func Test_NodeAgentPopTokenScheme(t *testing.T) { + expectedResourceId := "myresourceId" + + kmgr, err := NewRsaKeyManager(time.Hour) + assert.Nil(t, err) + + keypair, err := kmgr.GetKeyPair() + assert.Nil(t, err) + + // create a "reference" pop token that we can use to validate some of the nodeagentpoptokenscheme content since it + // should generate the same values + refPopToken, err := NewPopToken(keypair) + assert.Nil(t, err) + + // Generate nodeagent scheme + nodeAgentScheme, err := NewNodeAgentPopTokenAuthScheme(expectedResourceId, keypair) + + //validate AccessTokenType returns "pop" + assert.Equal(t, TokenType, nodeAgentScheme.AccessTokenType()) + assert.Equal(t, refPopToken.Header.Typ, nodeAgentScheme.AccessTokenType()) + + //Validate KeyID + assert.Equal(t, refPopToken.Header.Kid, nodeAgentScheme.KeyID()) + + // Validate TokenRequestParams returns a specific struct + reqCnf := nodeAgentScheme.TokenRequestParams() + + tokenType, ok := reqCnf["token_type"] + assert.True(t, ok) + assert.Equal(t, refPopToken.Header.Typ, tokenType) + + expectedCnf, err := refPopToken.GetReqCnf() + assert.Nil(t, err) + cnf, ok := reqCnf["req_cnf"] + assert.True(t, ok) + assert.Equal(t, expectedCnf, cnf) + + // Validate FormatAccessToken. Here we just check that the custom claim "resourceId" was added. + popToken, err := nodeAgentScheme.FormatAccessToken("accessToken") + assert.Nil(t, err) + assert.NotEmpty(t, popToken) + + toks := strings.Split(popToken, ".") + assert.Equal(t, 3, len(toks)) + body, err := decodeFromBase64[nodeAgentPopTokenBody](toks[1]) + assert.Nil(t, err) + assert.Equal(t, expectedResourceId, body.ResourceId) + +} diff --git a/pkg/auth/poptoken/nodeagentpoptokenvalidator.go b/pkg/auth/poptoken/nodeagentpoptokenvalidator.go new file mode 100644 index 00000000..6cad5bed --- /dev/null +++ b/pkg/auth/poptoken/nodeagentpoptokenvalidator.go @@ -0,0 +1,327 @@ +package poptoken + +import ( + "bytes" + "crypto" + "crypto/rsa" + "crypto/sha256" + "encoding/base64" + "encoding/binary" + "encoding/json" + "fmt" + "math/big" + "strconv" + "strings" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/pkg/errors" +) + +const ( + TokenVersion1 = "1.0" + TokenVersion2 = "2.0" + PopTokenValidInterval = 1 * time.Hour //TODO: should we make this smaller? +) + +type nodeAgentPopTokenBody struct { + ShrPopTokenBody + // target resource Id. Expected to match ShrPopTokenValidator.TargetResourceId + ResourceId string `json:"resourceid"` +} + +// contains a subset of custom claims in Entra/AzureAD access tokens we want to validate. +// See https://learn.microsoft.com/en-us/entra/identity-platform/access-token-claims-reference#payload-claims +type AccessTokenCustomClaims struct { + // contains the public key kid use to sign the pop token. Verify this matches the kid in the poptoken body. + ReqCnf ReqCnf `json:"cnf"` + // requester Id, i.e. CMP 1P. Expected to match ShrPopTokenValidator.ClientId. Only valid for token v2 + Azp string `json:"azp"` + // requester Id, i.e. CMP 1P/identifuerUri. Expected to match ShrPopTokenValidator.ClientId. Only valid for token v1 + AppId string `json:"appid"` + // Tenant Id. Expected to match ShrPopTokenValidator.TenantId + Tid string `json:"tid"` + // token version. + TokenVersion string `json:"ver"` + jwt.RegisteredClaims +} + +type ShrPopTokenValidator struct { + // A4S agent resourceId + TargetResourceId string + // Tenant Id of CMP 1P + TenantId string + // The target Id. In this case, client Id or one of its identifierUri, depending on the token version. + Audience map[string]bool + // CMP 1P client Id + ClientId string + // Issuer url, e.g. https://login.microsoftonline.com/ + IssuerUrl string + // component use to query Entra JWK endpoint. + jwk JwkInterface +} + +// handle situation where url may or may not have a backslash +func appendUrl(url string, postfix string) string { + sep := "/" + if strings.HasSuffix(url, "/") { + sep = "" + } + return fmt.Sprintf("%s%s%s", url, sep, postfix) +} + +func isTokenExpire(timestamp int64, now time.Time) error { + var issuedTime time.Time + convertTime(timestamp, &issuedTime) + expireat := issuedTime.Add(PopTokenValidInterval) + if expireat.Before(now) { + return fmt.Errorf("pop token has expired. Time when validated: %v, issued At: %v, valid duration: %v", now, issuedTime, PopTokenValidInterval) + } + return nil +} + +func isHeaderValid(header *ShrPopHeader) error { + if header.Typ != TokenType { + return fmt.Errorf("expected token type %s, got %s", TokenType, header.Typ) + } + if header.Alg != Alg { + return fmt.Errorf("expected alg %s, got %s", Alg, header.Alg) + } + + return nil +} + +func isSignatureValid(signingStr *string, signature []byte, cnf *Cnf) error { + publicKey, err := publicRSA256KeyFromCnf(cnf) + if err != nil { + return err + } + return verifyPayload(signingStr, []byte(signature), publicKey) +} + +func verifyPayload(signingStr *string, sig []byte, pubKey *rsa.PublicKey) error { + hash := sha256.New() + hash.Write([]byte(*signingStr)) + err := rsa.VerifyPKCS1v15(pubKey, crypto.SHA256, hash.Sum(nil), sig) + if err != nil { + return errors.Wrapf(err, "failed to validate signature of poptoken using cnf") + } + return nil +} + +func publicRSA256KeyFromCnf(cnf *Cnf) (*rsa.PublicKey, error) { + modulus, err := base64.URLEncoding.DecodeString(cnf.Jwk.N) + if err != nil { + err := errors.Wrapf(err, "error while parsing poptoken cnf: failed to decode modulus") + return nil, err + } + n := big.NewInt(0) + n.SetString(string(modulus), 10) + + e, err := base64ToExponential(string(cnf.Jwk.E)) + if err != nil { + err := errors.Wrapf(err, "error while parsing poptoken cnf: failed to parse exponent") + return nil, err + } + pKey := rsa.PublicKey{N: n, E: int(e)} + + return &pKey, nil +} + +func base64ToExponential(encodedE string) (int, error) { + decE, err := base64.URLEncoding.DecodeString(encodedE) + if err != nil { + return 0, err + } + + var eBytes []byte + if len(decE) < 8 { + eBytes = make([]byte, 8-len(decE), 8) + eBytes = append(eBytes, decE...) + } else { + eBytes = decE + } + eReader := bytes.NewReader(eBytes) + var ee uint64 + err = binary.Read(eReader, binary.BigEndian, &ee) + if err != nil { + return 0, err + } + + return int(ee), nil +} + +func decodeFromBase64[T any](jsonData string) (T, error) { + var t T + var err error + + bytes, err := base64.RawURLEncoding.DecodeString(jsonData) + if err != nil { + return t, err + } + + err = json.Unmarshal(bytes, &t) + if err != nil { + return t, err + } + + return t, nil +} + +func convertTime(i any, tm *time.Time) { + switch iat := i.(type) { + case float64: + *tm = time.Unix(int64(iat), 0) + case int64: + *tm = time.Unix(iat, 0) + case string: + v, _ := strconv.ParseInt(iat, 10, 64) + *tm = time.Unix(v, 0) + } +} + +func (s *ShrPopTokenValidator) parseAndValidateAccessToken(tokenStr string, popTokenKid string) error { + // ParseWithClaims() will validate the token expiry date and signing. + at, err := jwt.ParseWithClaims(tokenStr, &AccessTokenCustomClaims{}, func(token *jwt.Token) (interface{}, error) { + kid, ok := token.Header["kid"].(string) + if !ok { + return nil, fmt.Errorf("failed to find kid in access token header") + } + + pKey, err := s.jwk.GetPublicKey(kid) + if err != nil { + return nil, err + } + + return pKey, nil + }, jwt.WithValidMethods([]string{jwt.SigningMethodRS256.Alg()})) + + if err != nil { + return errors.Wrapf(err, "failed to parse access token") + } + + err = s.validateAccessTokenClaims(at, popTokenKid) + return err +} + +func (s *ShrPopTokenValidator) validateAccessTokenClaims(token *jwt.Token, popTokenKid string) error { + if token == nil { + return fmt.Errorf("empty token in validateAccessTokenClaims!") + } + + // now read in Entra access token's specific claims and validate them + claims, ok := token.Claims.(*AccessTokenCustomClaims) + if !ok { + return fmt.Errorf("failed to retrieve expected claims in access token") + } + + // Handle claims that are specifc to token versions + switch claims.TokenVersion { + case TokenVersion1: + if claims.AppId != s.ClientId { + return fmt.Errorf("invalid appId claim") + } + if claims.Issuer != s.IssuerUrl { + return fmt.Errorf("invalid issuer") + } + case TokenVersion2: + if claims.Azp != s.ClientId { + return fmt.Errorf("invalid azp claim") + } + // for v2, issuer ends with v2.0 + if claims.Issuer != appendUrl(s.IssuerUrl, "v2.0") { + return fmt.Errorf("invalid issuer for v2 token") + } + + default: + return fmt.Errorf("unknown token version %s. expected either %sor %s", claims.TokenVersion, TokenVersion1, TokenVersion2) + } + + if claims.ReqCnf.Kid != popTokenKid { + return fmt.Errorf("kid in pop token did not match kid in access token. expected kid: %s, got kid: %s", claims.ReqCnf.Kid, popTokenKid) + } + + foundAud := false + for _, aud := range claims.Audience { + if _, ok := s.Audience[aud]; ok { + foundAud = true + break + } + } + if !foundAud { + return fmt.Errorf("aud claim was not expected") + } + + return nil +} + +func (s *ShrPopTokenValidator) isCustomClaimsValid(body *nodeAgentPopTokenBody) error { + if body.ResourceId != s.TargetResourceId { + return fmt.Errorf("invalid resourceId") + } + + return nil +} + +func (s *ShrPopTokenValidator) Validate(popToken string) error { + toks := strings.Split(popToken, ".") + if len(toks) != 3 { + return fmt.Errorf("invalid pop tokens expected 3 segments, got %d", len(toks)) + } + + header, err := decodeFromBase64[ShrPopHeader](toks[0]) + if err != nil { + return err + } + + if err := isHeaderValid(&header); err != nil { + return err + } + + body, err := decodeFromBase64[nodeAgentPopTokenBody](toks[1]) + if err != nil { + return err + } + + if err := isTokenExpire(body.Ts, time.Now()); err != nil { + return err + } + + if err := s.isCustomClaimsValid(&body); err != nil { + return err + } + + signature, err := base64.RawURLEncoding.DecodeString(toks[2]) + if err != nil { + return err + } + signingStr := strings.Join([]string{toks[0], toks[1]}, ".") + err = isSignatureValid(&signingStr, signature, &body.Cnf) + if err != nil { + return err + } + + // now retrieve the inner access token + err = s.parseAndValidateAccessToken(body.At, body.Cnf.Jwk.Kid) + if err != nil { + return err + } + + return nil +} + +func NewPopTokenValidator(targetResourceId string, tenantId string, audiences []string, clientId string, authorityUrl string, jwk JwkInterface) (*ShrPopTokenValidator, error) { + audienceMap := make(map[string]bool) + for _, aud := range audiences { + audienceMap[aud] = true + } + + return &ShrPopTokenValidator{ + TargetResourceId: targetResourceId, + TenantId: tenantId, + Audience: audienceMap, + ClientId: clientId, + IssuerUrl: appendUrl(authorityUrl, tenantId), + jwk: jwk, + }, nil +} diff --git a/pkg/auth/poptoken/nodeagentpoptokenvalidator_test.go b/pkg/auth/poptoken/nodeagentpoptokenvalidator_test.go new file mode 100644 index 00000000..787db388 --- /dev/null +++ b/pkg/auth/poptoken/nodeagentpoptokenvalidator_test.go @@ -0,0 +1,513 @@ +package poptoken + +import ( + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "fmt" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" +) + +func Test_NodeAgentPopTokenValidatorAppendUrl(t *testing.T) { + + tests := []struct { + name string + url string + postfix string + expectedUrl string + }{ + { + name: "without backslash at end", + url: "http://localhost", + postfix: "myapi", + expectedUrl: "http://localhost/myapi", + }, + { + name: "with backslash at end", + url: "http://localhost/", + postfix: "myapi", + expectedUrl: "http://localhost/myapi", + }} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actualUrl := appendUrl(tt.url, tt.postfix) + assert.Equal(t, tt.expectedUrl, actualUrl) + }) + } +} + +func Test_NodeAgentPopTokenValidatorIsTokenExpire(t *testing.T) { + tokenIssuedAt, err := time.Parse(time.RFC3339, "2025-12-01T15:00:00Z") + assert.Nil(t, err) + tokenIssuedAtInt := tokenIssuedAt.Truncate(time.Second).Unix() + + tests := []struct { + name string + tokenCheckAt time.Time + shouldPass bool + }{ + { + name: "token valid", + // set token evaluation time to be 1 second after token was issued, token is valid. + tokenCheckAt: tokenIssuedAt.Add(time.Second * 1), + shouldPass: true, + }, + { + name: "token expired", + // set token evaluation time to be 10 seconds after max valid period. token has expired. + tokenCheckAt: tokenIssuedAt.Add(PopTokenValidInterval).Add(time.Second * 10), + shouldPass: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := isTokenExpire(tokenIssuedAtInt, tt.tokenCheckAt) + if tt.shouldPass { + assert.Nil(t, err) + } else { + assert.NotNil(t, err) + } + }) + } +} + +func Test_NodeAgentPopTokenValidatorIsHeaderValid(t *testing.T) { + tests := []struct { + name string + header ShrPopHeader + shouldPass bool + }{ + { + name: "valid header", + shouldPass: true, + header: ShrPopHeader{Alg: Alg, Typ: TokenType}, + }, + { + name: "invalid alg", + shouldPass: false, + header: ShrPopHeader{Alg: "RSA123", Typ: TokenType}, + }, + { + name: "invalid typ", + shouldPass: false, + header: ShrPopHeader{Alg: Alg, Typ: "jwt"}, + }, + { + name: "empty header", + shouldPass: false, + header: ShrPopHeader{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := isHeaderValid(&tt.header) + if tt.shouldPass { + assert.Nil(t, err) + } else { + assert.NotNil(t, err) + } + }) + } +} + +func Test_NodeAgentPopTokenValidatorIsSignatureValid(t *testing.T) { + keypair, err := getKeyPair() + assert.Nil(t, err) + + payload := []byte("ThisIsPayload") + payloadStr := string(payload) + + sigEncoded, err := signPayload(payload, keypair.PrivateKey) + assert.Nil(t, err) + sig, err := base64.RawURLEncoding.DecodeString(sigEncoded) + assert.Nil(t, err) + + cnf := publicKeyToCnf(keypair) + + // signature is correctly validated given correct payload and public key. + err = isSignatureValid(&payloadStr, sig, cnf) + assert.Nil(t, err) + + // simulate a mangled payload, expect failure + badPayload := payloadStr + "baddata" + + err = isSignatureValid(&badPayload, sig, cnf) + assert.NotNil(t, err) + + // simulate bad sig, expect failure + badSig := string(sig) + "1111" + + err = isSignatureValid(&payloadStr, []byte(badSig), cnf) + assert.NotNil(t, err) + + // simulate wrong public key, expect failure + newKeyPair, err := getKeyPair() + assert.Nil(t, err) + misMatachCnf := publicKeyToCnf(newKeyPair) + + err = isSignatureValid(&payloadStr, sig, misMatachCnf) + assert.NotNil(t, err) +} + +func Test_NodeAgentPopTokenValidatorbase64ToExponential(t *testing.T) { + encodedExponential := "AQAB" + e, err := base64ToExponential(encodedExponential) + assert.Nil(t, err) + // this is the decoded value of a well known exponential value + assert.Equal(t, 65537, e) +} + +func Test_NodeAgentPopTokenValidatorIsCustomClaimsValid(t *testing.T) { + expectedResourceId := "myResourceId" + + tests := []struct { + name string + actualResourceId string + shouldPass bool + }{ + { + name: "valid resource Id claim", + actualResourceId: expectedResourceId, + shouldPass: true, + }, + { + name: "invalid resourceId claim", + actualResourceId: "somethingelse", + shouldPass: false, + }, + { + name: "missing resourceId claim", + actualResourceId: "", + shouldPass: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + // for this test, we only care about setting the targetResourceId + popTokenValidator, err := NewPopTokenValidator(expectedResourceId, "", []string{"aud"}, "", "", nil) + assert.Nil(t, err) + // likewise we only set the resourceId in the poptoken body + popTokenBody := nodeAgentPopTokenBody{ResourceId: tt.actualResourceId} + + err = popTokenValidator.isCustomClaimsValid(&popTokenBody) + if tt.shouldPass { + assert.Nil(t, err) + } else { + assert.NotNil(t, err) + } + }) + } +} + +// Test the access token validation against a range of invalid claims. +func Test_NodeAgentPopTokenValidatorParseAndValidateAccessToken(t *testing.T) { + expectedAuthorityUrl := "https://login.fake.microsoftonline.com" + expectedTenantId := "cmpTenantId" + expectedClientId := "cmpClientId" + expectedAudience := "cmpAudience" + expectedIssuedTime := time.Now() + expectedPopTokenKid := "defaultKid" + + defaultPKey, _ := getPrivateKey() + defaultJwkMgr := &FakeJwrMgr{PublicKey: &defaultPKey.PublicKey} + + // By default, the accesstoken and validator will use the same expected values as listed above + // hence the access token validation will succeeded. + + // test is setup such that each invalid claim declared will override the access token's ciaim, causing it to be invalid. + tests := []struct { + name string + invalidAuthorityUrl string + invalidTenantId string + invalidClientId string + invalidAudience string + invalidIssuerUrl string + invalidPopTokenKid string + tokenVersion string + isInvalidSigning bool + isMissingKid bool + isExpiredToken bool + shouldPass bool + }{ + { + name: "valid access token v1", + tokenVersion: TokenVersion1, + shouldPass: true, + }, + { + name: "valid access token v1", + tokenVersion: TokenVersion2, + shouldPass: true, + }, + { + name: "invalid tenantId", + tokenVersion: TokenVersion2, + invalidTenantId: "badTenantId", + shouldPass: false, + }, + { + name: "invalid clientId", + tokenVersion: TokenVersion2, + invalidClientId: "badClientId", + shouldPass: false, + }, + { + name: "invalid audience", + tokenVersion: TokenVersion2, + invalidAudience: "badAudienceId", + shouldPass: false, + }, + { + name: "invalid pop token id", + tokenVersion: TokenVersion2, + invalidPopTokenKid: "badPopTokenId", + shouldPass: false, + }, + { + name: "invalid signing", + tokenVersion: TokenVersion2, + isInvalidSigning: true, + shouldPass: false, + }, + { + name: "token expired", + tokenVersion: TokenVersion2, + isExpiredToken: true, + shouldPass: false, + }, + { + name: "missing kid", + tokenVersion: TokenVersion2, + isMissingKid: true, + shouldPass: false, + }, + { + name: "invalid issuer", + tokenVersion: TokenVersion2, + invalidAuthorityUrl: "https://bad.issuer", + shouldPass: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tenantId := expectedTenantId + if tt.invalidTenantId != "" { + tenantId = tt.invalidTenantId + } + + clientId := expectedClientId + if tt.invalidClientId != "" { + clientId = tt.invalidClientId + } + + // to simulate expired token, we make it invalid after one second and sleep before validating token + newexpiry := time.Second + expiredTime := expectedIssuedTime.Add(time.Hour) + if tt.isExpiredToken { + expiredTime = expectedIssuedTime.Add(newexpiry) + } + + tokenVersion := tt.tokenVersion + + authorityUrl := expectedAuthorityUrl + if tt.invalidAuthorityUrl != "" { + authorityUrl = tt.invalidAuthorityUrl + } + + audience := expectedAudience + if tt.invalidAudience != "" { + audience = tt.invalidAudience + } + + issuserUrl := fmt.Sprintf("%s/%s", authorityUrl, tenantId) + if tokenVersion == TokenVersion2 { + issuserUrl = fmt.Sprintf("%s/v2.0", issuserUrl) + } + + popTokenKid := expectedPopTokenKid + if tt.invalidPopTokenKid != "" { + popTokenKid = tt.invalidPopTokenKid + } + + jwkMgr := defaultJwkMgr + if tt.isInvalidSigning { + newPKey, _ := getPrivateKey() + jwkMgr = &FakeJwrMgr{PublicKey: &newPKey.PublicKey} + } + + if tt.isMissingKid { + jwkMgr = &FakeJwrMgr{Err: fmt.Errorf("failed to find key")} + } + + // access token to be generated can contain invalid claims depending on the test param, + // by default it will use the same data as the token validator. + claims := AccessTokenCustomClaims{ + Tid: tenantId, + ReqCnf: ReqCnf{Kid: popTokenKid}, + Azp: clientId, + AppId: clientId, + TokenVersion: tokenVersion, + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: issuserUrl, + ExpiresAt: jwt.NewNumericDate(expiredTime), + IssuedAt: jwt.NewNumericDate(expectedIssuedTime), + NotBefore: jwt.NewNumericDate(expectedIssuedTime), + Audience: jwt.ClaimStrings{audience}, + Subject: clientId, + }, + } + + // The token validator is set to the expected values, except for invalid jwk + tokenValidator, err := NewPopTokenValidator( + "notused", // this is not tested here. + expectedTenantId, + []string{expectedAudience}, + expectedClientId, + expectedAuthorityUrl, + jwkMgr) + assert.Nil(t, err) + + s, err := generateAccessToken(&claims, defaultPKey) + assert.Nil(t, err) + + // sleep for a while to ensure token expires + if tt.isExpiredToken { + time.Sleep(newexpiry * 2) + } + + err = tokenValidator.parseAndValidateAccessToken(s, expectedPopTokenKid) + if tt.shouldPass { + assert.Nil(t, err) + } else { + assert.NotNil(t, err) + } + }) + } +} + +// Test the overall validate function. We have already tested the individual functions that makes up this call +// this is just a simple test to validate end to end. +func Test_NodeAgentPopTokenValidatorValidate(t *testing.T) { + authorityUrl := "https://login.fake.microsoftonline.com" + resourceId := "resourceId" + tenantId := "cmpTenantId" + clientId := "cmpClientId" + audience := "cmpAudience" + issuedTime := time.Now() + expiredTime := issuedTime.Add(time.Hour) + tokenVersion := TokenVersion2 + issuerUrl := fmt.Sprintf("%s/%s/v2.0", authorityUrl, tenantId) + + accessTokenPKey, err := getPrivateKey() + assert.Nil(t, err) + jwkMgr := &FakeJwrMgr{PublicKey: &accessTokenPKey.PublicKey} + + rsaKeyPair, err := getKeyPair() + assert.Nil(t, err) + + // partial generate pop token, we need to add the popKid into the accesstoken + popToken, err := NewPopToken(rsaKeyPair) + assert.Nil(t, err) + + // Generate access token + claims := AccessTokenCustomClaims{ + Tid: tenantId, + ReqCnf: ReqCnf{Kid: popToken.Header.Kid}, + Azp: clientId, + AppId: clientId, + TokenVersion: tokenVersion, + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: issuerUrl, + ExpiresAt: jwt.NewNumericDate(expiredTime), + IssuedAt: jwt.NewNumericDate(issuedTime), + NotBefore: jwt.NewNumericDate(issuedTime), + Audience: jwt.ClaimStrings{audience}, + Subject: clientId, + }, + } + at, err := generateAccessToken(&claims, accessTokenPKey) + assert.Nil(t, err) + + // Generate pop token + pt, err := popToken.GenerateToken(at, time.Now(), map[string]interface{}{"resourceId": resourceId}) + assert.Nil(t, err) + + // validate poptoken + tokenValidator, err := NewPopTokenValidator( + resourceId, + tenantId, + []string{audience}, + clientId, + authorityUrl, + jwkMgr) + assert.Nil(t, err) + + err = tokenValidator.Validate(pt) + assert.Nil(t, err) +} + +func generateAccessToken(claims *AccessTokenCustomClaims, privateKey *rsa.PrivateKey) (string, error) { + tok := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ + "iss": claims.Issuer, + "sub": claims.Subject, + "aud": claims.Audience, + "exp": claims.ExpiresAt, + "nbf": claims.NotBefore, + "iat": claims.IssuedAt, + "cnf": claims.ReqCnf, + "azp": claims.Azp, + "appId": claims.AppId, + "ver": claims.TokenVersion, + "tid": claims.Tid, + }) + + // we don't care about the actual kid value here since we use a fake jwkMgr that ignores the kidm but we do + // expect it to be present in the header. + tok.Header["kid"] = "notused" + at, err := tok.SignedString(privateKey) + if err != nil { + return "", err + } + + return at, err +} + +// fake jwkMgr that will return either the public key or error +type FakeJwrMgr struct { + PublicKey *rsa.PublicKey + Err error +} + +func (j *FakeJwrMgr) GetPublicKey(kid string) (*rsa.PublicKey, error) { + if j.Err != nil { + return nil, j.Err + } else { + return j.PublicKey, nil + } +} + +func getPrivateKey() (*rsa.PrivateKey, error) { + return rsa.GenerateKey(rand.Reader, RsaSize) +} + +func publicKeyToCnf(keyPair *RsaKeyPair) *Cnf { + return &Cnf{ + Jwk: Jwk{ + JwkInner: JwkInner{ + Kty: keyPair.Kty, + E: exponential2Base64(keyPair.PublicKey.E), + N: base64.URLEncoding.EncodeToString([]byte(keyPair.PublicKey.N.String())), + }, + }, + } +} diff --git a/pkg/auth/poptoken/poptokenauth.go b/pkg/auth/poptoken/poptokenauth.go new file mode 100644 index 00000000..528c59b8 --- /dev/null +++ b/pkg/auth/poptoken/poptokenauth.go @@ -0,0 +1,44 @@ +package poptoken + +import ( + "context" + + "github.com/microsoft/moc/pkg/errors" +) + +/* +The setup of the pop token creaton is as follows: + PopTokenAuth (interface betwen grpc and msalauthprovider) + | + --> MsalAuthProvider (global component that request the token from Entra/AzureAAD via MSAL SDK) + | + --> NodeAgentPopTokenAuthScheme (implements callback MSAL requires to generate the pop token) + | + --> ShrPopToken (does most of the heavy lifing in generating the pop token) +*/ + +// This component integrates the MSAL provider to the grpc credentials.PerRPCCredentials interface +type PopTokenAuth struct { + msalauthprovider *MsalAuthProvider + targetResourceId string +} + +func NewPopTokenAuth(msalProvider *MsalAuthProvider, targetResourceId string) (*PopTokenAuth, error) { + return &PopTokenAuth{ + msalauthprovider: msalProvider, + targetResourceId: targetResourceId, + }, nil +} + +func (p *PopTokenAuth) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { + accessToken, err := p.msalauthprovider.GetToken(p.targetResourceId) + if err != nil { + return nil, errors.Wrapf(err, "failed to generate poptoken") + } + + return map[string]string{"authorization": accessToken}, nil +} + +func (p *PopTokenAuth) RequireTransportSecurity() bool { + return true +} diff --git a/pkg/auth/poptoken/raskeymanager_test.go b/pkg/auth/poptoken/raskeymanager_test.go new file mode 100644 index 00000000..abebbdfa --- /dev/null +++ b/pkg/auth/poptoken/raskeymanager_test.go @@ -0,0 +1,73 @@ +package poptoken + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func Test_RsaKeyPairGetKeyPair(t *testing.T) { + rsamgr, err := NewRsaKeyManager(time.Hour * 1) + assert.Nil(t, err) + + rsa, err := rsamgr.GetKeyPair() + assert.Nil(t, err) + + assert.Equal(t, Alg, rsa.Alg) + assert.Equal(t, Kty, rsa.Kty) + assert.Equal(t, RsaSize, rsa.RsaSize) + assert.NotNil(t, rsa.PrivateKey) + assert.NotNil(t, rsa.PublicKey) + + // now get the keypair a second time, if it has not refreshed, it will be the same value + rsa2, err := rsamgr.GetKeyPair() + //validate private key are same + assert.Equal(t, *rsa.PrivateKey.N, *rsa2.PrivateKey.N) +} + +func Test_RsaKeyPairRefresh(t *testing.T) { + rsamgr, err := NewRsaKeyManager(time.Second * 1) + assert.Nil(t, err) + + rsa, err := rsamgr.GetKeyPair() + assert.Nil(t, err) + + time.Sleep(time.Second * 2) + + rsa2, err := rsamgr.GetKeyPair() + //validate private key are different the second time we ger it + assert.NotEqual(t, *rsa.PrivateKey.N, *rsa2.PrivateKey.N) +} + +func Test_RsaKeyPairForceRefresh(t *testing.T) { + rsamgr, err := NewRsaKeyManager(time.Hour * 1) + assert.Nil(t, err) + + rsa, err := rsamgr.GetKeyPair() + assert.Nil(t, err) + + rsamgr.ForceRefresh() + // wait for some time for it to respond, note the sleep here is far less than the refesh interval of 1 hour + time.Sleep(time.Second * 1) + + rsa2, err := rsamgr.GetKeyPair() + //validate private key are different the second time we ger it + assert.NotEqual(t, *rsa.PrivateKey.N, *rsa2.PrivateKey.N) +} + +// validate keymanager will not deadlock if refresh happen quicker than get call. +func Test_RsaKeyPairNoDeadLock(t *testing.T) { + rsamgr, err := NewRsaKeyManager(time.Hour * 1) + assert.Nil(t, err) + + for i := 0; i < 5; i++ { + rsamgr.ForceRefresh() + // wait for some time for it to respond + time.Sleep(time.Second * 1) + } + + rsa2, err := rsamgr.GetKeyPair() + assert.Nil(t, err) + assert.NotNil(t, rsa2.PrivateKey) +} diff --git a/pkg/auth/poptoken/rsakeymanager.go b/pkg/auth/poptoken/rsakeymanager.go new file mode 100644 index 00000000..5daff647 --- /dev/null +++ b/pkg/auth/poptoken/rsakeymanager.go @@ -0,0 +1,120 @@ +package poptoken + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "time" +) + +type RsaKeyPair struct { + PrivateKey *rsa.PrivateKey + PublicKey *rsa.PublicKey + RsaSize int + Kty string + Alg string +} + +// a RSA Key generator that refresh the RSA KeyPair at regular interval +// Used to ensure the keys use to sign the poptoken are rotated +type RsaKeyManager struct { + refreshInterval time.Duration + refreshTicker *time.Ticker + keyPairChan chan *rsa.PrivateKey + forceRefreshChan chan bool + stopChan chan bool + privateKey *rsa.PrivateKey +} + +const ( + RsaSize = 2048 + Kty = "RSA" + Alg = "RS256" +) + +func generatePrivateKey() (*rsa.PrivateKey, error) { + return rsa.GenerateKey(rand.Reader, RsaSize) +} + +// Background job to continuously refresh the keypair in a best effort basis. +func (r *RsaKeyManager) refreshPrivateKeyJob() { + for { + select { + case <-r.stopChan: + return + case <-r.refreshTicker.C: + r.tryRefreshPrivateKey() + case <-r.forceRefreshChan: + r.tryRefreshPrivateKey() + } + } +} + +// Generate new keypair and send it back to the main go routine. +func (r *RsaKeyManager) tryRefreshPrivateKey() { + privateKey, err := generatePrivateKey() + // generatePrivateKey() should not fail, we don't have a good way to surface this error + if err == nil { + // In the unlikely event the refresh rate happens faster than getting the key, + // drop the key to prevent deadlocking the channel + select { + case r.keyPairChan <- privateKey: + default: + // imply key is dropped if channel is full. + } + } +} + +// Return a KeyPair. The keypair is its own copy and not a reference. +func (r *RsaKeyManager) GetKeyPair() (*RsaKeyPair, error) { + // non blocking wait to get new private key if available + select { + case r.privateKey = <-r.keyPairChan: + default: + //continue to use existing key + } + + // Create and return a deep copy of the private key so clients are not impacted by a rotation midway. + privateKeyBytes := x509.MarshalPKCS1PrivateKey(r.privateKey) + privateKeyCopy, err := x509.ParsePKCS1PrivateKey(privateKeyBytes) + if err != nil { + return nil, err + } + + return &RsaKeyPair{ + PrivateKey: privateKeyCopy, + PublicKey: privateKeyCopy.Public().(*rsa.PublicKey), + RsaSize: RsaSize, + Kty: Kty, + Alg: Alg, + }, nil +} + +// Force a refresh now. This can be use during test. +func (r *RsaKeyManager) ForceRefresh() { + r.forceRefreshChan <- true +} + +// Stop the refresh of the keypair. +func (r *RsaKeyManager) Stop() { + r.stopChan <- true +} + +// Create a new RSAKeyManager that will refresh the keypair in the background. +func NewRsaKeyManager(refreshInterval time.Duration) (*RsaKeyManager, error) { + var err error + rsaMgr := &RsaKeyManager{} + + rsaMgr.refreshInterval = refreshInterval + rsaMgr.refreshTicker = time.NewTicker(rsaMgr.refreshInterval) + rsaMgr.privateKey, err = generatePrivateKey() + if err != nil { + return nil, err + } + rsaMgr.forceRefreshChan = make(chan bool) + rsaMgr.stopChan = make(chan bool) + rsaMgr.keyPairChan = make(chan *rsa.PrivateKey, 2) + + go rsaMgr.refreshPrivateKeyJob() + return rsaMgr, nil +} diff --git a/pkg/auth/poptoken/shrpoptoken.go b/pkg/auth/poptoken/shrpoptoken.go new file mode 100644 index 00000000..ccd11824 --- /dev/null +++ b/pkg/auth/poptoken/shrpoptoken.go @@ -0,0 +1,206 @@ +package poptoken + +import ( + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "encoding/base64" + "encoding/binary" + "encoding/json" + "reflect" + "strings" + "time" +) + +type ShrPopHeader struct { + // RSA PS256? + Alg string `json:"alg"` + // key Id of public key + Kid string `json:"kid"` + // always pop + Typ string `json:"typ"` +} + +// https://datatracker.ietf.org/doc/html/rfc7638#section-3.1 +// contains the metadata use to calculate kid. +type JwkInner struct { + // Exponent + E string `json:"e"` + // encryption + Kty string `json:"kty"` + // modulus + N string `json:"n"` +} + +type Jwk struct { + JwkInner + // public key kid + Kid string `json:"kid"` +} + +// https://datatracker.ietf.org/doc/html/rfc7800#section-3.2 +type Cnf struct { + Jwk Jwk `json:"jwk"` + Xms_ksl string `json:"xms_ksl"` +} + +type ReqCnf struct { + Kid string `json:"kid"` +} + +// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-signed-http-request-03#section-3 +type ShrPopTokenBody struct { + Cnf Cnf `json:"cnf"` + // timestamp + Ts int64 `json:"ts"` + // access token + At string `json:"at"` + // random unique value to prevent replay attack. not used + NonCe string `json:"nonce"` +} + +// Implements the shr pop token generically. Callers of ths instance can add their own custom claims when generating the token. +type ShrPopToken struct { + Header ShrPopHeader + Body ShrPopTokenBody + ReqCnf ReqCnf + RSAKeyPair *RsaKeyPair +} + +const ( + TokenType = "pop" +) + +func calculatePublicKeyId(jwkInner *JwkInner) (string, error) { + // - https://tools.ietf.org/html/rfc7638#section-3.1 + jwkByte, err := json.Marshal(jwkInner) + if err != nil { + return "", err + } + jwk256 := sha256.Sum256(jwkByte) + if err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(jwk256[:]), nil +} + +func jsonToBase64(v any) (string, error) { + jsonValue, err := json.Marshal(v) + if err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(jsonValue), nil +} + +func signPayload(payload []byte, rsaKey *rsa.PrivateKey) (string, error) { + hash := sha256.New() + _, err := hash.Write(payload) + if err != nil { + return "", err + } + sigBytes, err := rsa.SignPKCS1v15(rand.Reader, rsaKey, crypto.SHA256, hash.Sum(nil)) + if err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(sigBytes), nil +} + +func exponential2Base64(e int) string { + bs := make([]byte, 4) + binary.BigEndian.PutUint32(bs, uint32(e)) + + bs = bs[1:] // drop most significant byte - leaving least-significant 3-bytes + ss := base64.URLEncoding.EncodeToString(bs) + return ss +} + +// Append custom claims to the existing ShrPopTokenBody. +func (pop *ShrPopToken) appendCustomClaimsToBody(customClaims map[string]interface{}) map[string]interface{} { + + bodyMap := make(map[string]interface{}) + + // first convert the existing body to a map of interface. + val := reflect.ValueOf(pop.Body) + typ := val.Type() + for i := 0; i < val.NumField(); i++ { + if name := strings.ToLower(typ.Field(i).Name); name != "" { + bodyMap[name] = val.Field(i).Interface() + } + } + // now append the custom claims + for k, v := range customClaims { + bodyMap[k] = v + } + + return bodyMap +} + +// Complete the poptoken creation by adding the custom claims and signing it. +func (pop *ShrPopToken) GenerateToken(token string, now time.Time, customClaims map[string]interface{}) (string, error) { + + pop.Body.Ts = now.Truncate(time.Second).Unix() + pop.Body.At = token + + body, err := jsonToBase64(pop.appendCustomClaimsToBody(customClaims)) + if err != nil { + return "", err + } + + header, err := jsonToBase64(pop.Header) + if err != nil { + return "", err + } + + signingStr := strings.Join([]string{header, body}, ".") + + signature, err := signPayload([]byte(signingStr), pop.RSAKeyPair.PrivateKey) + if err != nil { + return "", nil + } + + return strings.Join([]string{signingStr, signature}, "."), nil +} + +// Generate ReqCnf to be passed to Msal +func (pop *ShrPopToken) GetReqCnf() (string, error) { + refCnfb64, err := jsonToBase64(pop.ReqCnf) + if err != nil { + return "", err + } + return refCnfb64, nil +} + +// Create a new instance of ShrPopToken. This generate a partial filled, generic shrpoptoken. The custom claims will be +// added later on in GenerateToken() +func NewPopToken(keyPair *RsaKeyPair) (*ShrPopToken, error) { + pop := ShrPopToken{ + Header: ShrPopHeader{ + Alg: keyPair.Alg, + Typ: TokenType, + }, + Body: ShrPopTokenBody{ + Cnf: Cnf{ + Jwk: Jwk{ + JwkInner: JwkInner{ + Kty: keyPair.Kty, + E: exponential2Base64(keyPair.PublicKey.E), + N: base64.URLEncoding.EncodeToString([]byte(keyPair.PublicKey.N.String())), + }, + }, + }, + }, + ReqCnf: ReqCnf{}, + RSAKeyPair: keyPair, + } + + keyId, err := calculatePublicKeyId(&pop.Body.Cnf.Jwk.JwkInner) + if err != nil { + return nil, err + } + pop.Header.Kid = keyId + pop.ReqCnf.Kid = keyId + pop.Body.Cnf.Jwk.Kid = keyId + + return &pop, err +} diff --git a/pkg/auth/poptoken/shrpoptoken_test.go b/pkg/auth/poptoken/shrpoptoken_test.go new file mode 100644 index 00000000..dd4ab84a --- /dev/null +++ b/pkg/auth/poptoken/shrpoptoken_test.go @@ -0,0 +1,211 @@ +package poptoken + +import ( + "crypto" + "crypto/rsa" + "crypto/sha256" + "encoding/base64" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +type testStruct struct { + StrValue string `json:"str"` + IntValue int `json:"int"` +} + +// the pop token is partially filled out upon calling NewPopToken +func Test_ShrPopTokenNewPopToken(t *testing.T) { + keypair, err := getKeyPair() + assert.Nil(t, err) + + pop, err := NewPopToken(keypair) + assert.Nil(t, err) + + // calculate kid + expectedKid, err := calculatePublicKeyId(&pop.Body.Cnf.Jwk.JwkInner) + assert.Nil(t, err) + + // check header + assert.Equal(t, Alg, pop.Header.Alg) + assert.Equal(t, TokenType, pop.Header.Typ) + assert.Equal(t, expectedKid, pop.Header.Kid) + + // check body + expectedE := exponential2Base64(keypair.PrivateKey.E) + expectedN := base64.URLEncoding.EncodeToString([]byte(keypair.PublicKey.N.String())) + assert.Equal(t, expectedE, pop.Body.Cnf.Jwk.E) + assert.Equal(t, expectedN, pop.Body.Cnf.Jwk.N) + assert.Equal(t, keypair.Kty, pop.Body.Cnf.Jwk.Kty) + assert.Equal(t, expectedKid, pop.Body.Cnf.Jwk.Kid) + + // check ReqCnf + assert.Equal(t, expectedKid, pop.ReqCnf.Kid) +} + +func Test_ShrPopTokenGenerateToken(t *testing.T) { + keypair, err := getKeyPair() + assert.Nil(t, err) + + pop, err := NewPopToken(keypair) + assert.Nil(t, err) + + expectedAccessToken := "myFakeAccessToken" + expectedTimeStamp, err := time.Parse(time.RFC3339, "2025-12-01T15:00:00Z") + assert.Nil(t, err) + + expectedResourceIdValue := "1234" + customClaims := map[string]interface{}{"resourceId": expectedResourceIdValue} + + // calculate kid + expectedKid, err := calculatePublicKeyId(&pop.Body.Cnf.Jwk.JwkInner) + assert.Nil(t, err) + + // Generate the token and validate its content + popToken, err := pop.GenerateToken(expectedAccessToken, expectedTimeStamp, customClaims) + assert.Nil(t, err) + + toks := strings.Split(popToken, ".") + assert.Equal(t, 3, len(toks)) + + // validate header. + header, err := decodeFromBase64[ShrPopHeader](toks[0]) + assert.Nil(t, err) + assert.Equal(t, Alg, header.Alg) + assert.Equal(t, TokenType, header.Typ) + assert.Equal(t, expectedKid, header.Kid) + + // validate body + body, err := decodeFromBase64[nodeAgentPopTokenBody](toks[1]) + assert.Nil(t, err) + assert.Equal(t, expectedTimeStamp.Truncate(time.Second).Unix(), body.Ts) + assert.Equal(t, expectedResourceIdValue, body.ResourceId) + assert.Equal(t, expectedAccessToken, body.At) + assert.Equal(t, expectedKid, body.Cnf.Jwk.Kid) + + // validate signature. + signature, err := base64.RawURLEncoding.DecodeString(toks[2]) + assert.Nil(t, err) + + signingStr := strings.Join([]string{toks[0], toks[1]}, ".") + err = isSignatureValid(&signingStr, signature, &body.Cnf) + assert.Nil(t, err) +} + +func Test_ShrPopTokenAppendCustomClaims(t *testing.T) { + keypair, err := getKeyPair() + assert.Nil(t, err) + + pop, err := NewPopToken(keypair) + assert.Nil(t, err) + + expectedStringValue := "string" + expectedIntegerValue := 1234 + expectedStrArrValue := []string{"hello", "world"} + expectedStructValue := testStruct{StrValue: "string", IntValue: 1234} + + customClaims := map[string]interface{}{ + "string": expectedStringValue, + "integer": expectedIntegerValue, + "strArray": expectedStrArrValue, + "struct": expectedStructValue, + } + + actualClaims := pop.appendCustomClaimsToBody(customClaims) + + tmp, ok := actualClaims["string"] + assert.True(t, ok) + actualstringValue, ok := tmp.(string) + assert.True(t, ok) + assert.Equal(t, expectedStringValue, actualstringValue) + + tmp, ok = actualClaims["integer"] + assert.True(t, ok) + actualIntegerValue, ok := tmp.(int) + assert.True(t, ok) + assert.Equal(t, expectedIntegerValue, actualIntegerValue) + + tmp, ok = actualClaims["strArray"] + assert.True(t, ok) + actualStrArrValue, ok := tmp.([]string) + assert.True(t, ok) + assert.Equal(t, expectedStrArrValue, actualStrArrValue) + + tmp, ok = actualClaims["struct"] + assert.True(t, ok) + actualStructValue, ok := tmp.(testStruct) + assert.True(t, ok) + assert.Equal(t, expectedStructValue, actualStructValue) + + // finally sanity check that these custom claims can be converted to json + _, err = jsonToBase64(actualClaims) + assert.Nil(t, err) +} + +func Test_ShrPopTokenGetReqCnf(t *testing.T) { + keypair, err := getKeyPair() + assert.Nil(t, err) + + pop, err := NewPopToken(keypair) + assert.Nil(t, err) + + expectedReqCnfBase64, err := jsonToBase64(pop.ReqCnf) + assert.Nil(t, err) + + actualreqCnfBase64, err := pop.GetReqCnf() + assert.Equal(t, expectedReqCnfBase64, actualreqCnfBase64) +} + +func Test_ShrPopTokenExponential2Base64(t *testing.T) { + e := 65537 + base64 := exponential2Base64(e) + // this is the encoded value of a well known exponential value + assert.Equal(t, "AQAB", base64) +} + +func Test_ShrPopTokenCalculatePublicKeyId(t *testing.T) { + jwkinner := JwkInner{ + Kty: "RSA", + E: "AQAB", + N: "MjM1MDg5MDU4MzgxMDg3OTI5NTU3NjM1ODg4NTA3NDE5OTAwNzc0MzkzNzQ5NDcwNzcwMjA2MDIxNjMyNzk5NzYxNDM4NTczMjc3NTA0NzI4ODkzNDUzNjU0NDU0NjMxMjcxNjQ0MTAwMDM0NzUzNzU2MTEyMjkzODYzMDYxMjk5MDQxNzI5OTc0MDg5OTk2OTEzNTY4MjM5OTc0NDMwNTExODI3MDgyNDAzMDQxNDMxMTQ5ODA4ODc4NjE5NTc5MjcwMjAxNjc3ODM1NTQ0NDI3NDMwMDczODI2OTAwODk2MzcxNTM2NzE5NDQyNTUxNzIzNTM5MTg4OTU2MDc4MzI0MzYxNDM4MDEzNjA3OTI0NzMyNTUxMDg5ODU3NjQ1NDA0MTIyMTk3ODUwNjkyMjEyMTk4OTMxMDU1NTkzOTk4NzYyMjIwODg1NDg5NzE4MjQxNDAxMTg2MTMwMzExODAwMDQ2NjEwMjk0MDIzMzQ1MTA1NjE4ODY0ODc0OTgzNzU2NTMzMTY0OTk5MTg1NDk4ODIwOTY3NjYyNjM1NTUxMjk0NTkzNDEwNzc5MzUwODg2MjMxODkyMTc0NTcwODkxNDU4MjIwNzIwMzI5MTg3OTA3NzAxMzMzMDU1NzM0ODk0NjU3MDYzOTMzMzA3MTUwNjgzMTk1NjkyOTk0MzAxMjUxODUwNzUwMTg2MzI5MzM4ODk2NjY3OTQyMDE0OTcwODY3MTAzMTgxNTA5NDAxMTAwMzUwMzk5MDE3MDI3MTI3MTAwMDM5OTIwNjgwNjExNjcxNTQ3MDE1ODM2NzIyMTU1OTgxMTE=", + } + keyId, err := calculatePublicKeyId(&jwkinner) + assert.Nil(t, err) + assert.Equal(t, "a0CyVS__Npcx4GXYm1OCoxrlboOWKF02MXzSSh92ckY", keyId) +} + +func Test_ShrPopTokenSignPayload(t *testing.T) { + keypair, err := getKeyPair() + assert.Nil(t, err) + + payload := []byte("ThisIsMyTestPayLoad") + + sig, err := signPayload(payload, keypair.PrivateKey) + assert.Nil(t, err) + + //now verify the signature using the public key + sigDecode, err := base64.RawURLEncoding.DecodeString(sig) + hash := sha256.New() + hash.Write(payload) + err = rsa.VerifyPKCS1v15(keypair.PublicKey, crypto.SHA256, hash.Sum(nil), sigDecode) + assert.Nil(t, err) +} + +func getKeyPair() (*RsaKeyPair, error) { + + pKey, err := generatePrivateKey() + if err != nil { + return nil, err + } + + return &RsaKeyPair{ + PrivateKey: pKey, + PublicKey: pKey.Public().(*rsa.PublicKey), + RsaSize: RsaSize, + Kty: Kty, + Alg: Alg, + }, nil +} From a5937a7b5e82e1dbf7bf62f51bb2482726844769 Mon Sep 17 00:00:00 2001 From: wecha Date: Thu, 26 Jun 2025 00:08:29 +0000 Subject: [PATCH 2/7] add objectpath customclaim, add nonce check, simplify rsakeymgr, added new abstraction poptokenscheme --- go.mod | 3 +- pkg/auth/poptoken/jwkmanager.go | 8 +- pkg/auth/poptoken/msalauthprovider.go | 18 ++- pkg/auth/poptoken/nodeagentpoptokenscheme.go | 52 +++---- .../poptoken/nodeagentpoptokenscheme_test.go | 46 ++---- .../poptoken/nodeagentpoptokenvalidator.go | 126 +++++++++------- .../nodeagentpoptokenvalidator_test.go | 122 ++++++++++++---- pkg/auth/poptoken/noncecache.go | 96 +++++++++++++ pkg/auth/poptoken/noncecache_test.go | 134 ++++++++++++++++++ pkg/auth/poptoken/poptokenauth.go | 4 +- pkg/auth/poptoken/poptokenscheme.go | 59 ++++++++ pkg/auth/poptoken/poptokenscheme_test.go | 69 +++++++++ pkg/auth/poptoken/raskeymanager_test.go | 60 +++----- pkg/auth/poptoken/rsakeymanager.go | 84 +++-------- pkg/auth/poptoken/shrpoptoken.go | 73 +++++----- pkg/auth/poptoken/shrpoptoken_test.go | 39 +++-- 16 files changed, 675 insertions(+), 318 deletions(-) create mode 100644 pkg/auth/poptoken/noncecache.go create mode 100644 pkg/auth/poptoken/noncecache_test.go create mode 100644 pkg/auth/poptoken/poptokenscheme.go create mode 100644 pkg/auth/poptoken/poptokenscheme_test.go diff --git a/go.mod b/go.mod index 306394fd..7fd6fa2f 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/golang-jwt/jwt/v5 v5.0.0 github.com/golang/mock v1.6.0 github.com/golang/protobuf v1.5.4 + github.com/google/uuid v1.6.0 github.com/hectane/go-acl v0.0.0-20230122075934-ca0b05cb1adb github.com/jmespath/go-jmespath v0.4.0 github.com/lestrrat-go/jwx v1.2.31 @@ -25,7 +26,6 @@ require ( github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 // indirect github.com/go-jose/go-jose/v4 v4.0.4 // indirect github.com/goccy/go-json v0.10.3 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/kr/text v0.2.0 // indirect github.com/kylelemons/godebug v1.1.0 // indirect github.com/lestrrat-go/backoff/v2 v2.0.8 // indirect @@ -42,6 +42,7 @@ require ( golang.org/x/text v0.24.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250428153025-10db94c68c34 // indirect google.golang.org/protobuf v1.36.6 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect ) replace ( diff --git a/pkg/auth/poptoken/jwkmanager.go b/pkg/auth/poptoken/jwkmanager.go index 0079d55a..04f913b8 100644 --- a/pkg/auth/poptoken/jwkmanager.go +++ b/pkg/auth/poptoken/jwkmanager.go @@ -16,7 +16,7 @@ const ( ) // Wrapper around jwk library to retrieve and refresh the jwk endpoints from Entra/AAD -type JwkManager struct { +type jwkManager struct { // STS JWK endpoint, e.g. "https://login.microsoftonline.com/common/discovery/keys" jwkEndpoint string ar *jwk.AutoRefresh @@ -26,7 +26,7 @@ type JwkInterface interface { GetPublicKey(kid string) (*rsa.PublicKey, error) } -func (j *JwkManager) GetPublicKey(kid string) (*rsa.PublicKey, error) { +func (j *jwkManager) GetPublicKey(kid string) (*rsa.PublicKey, error) { ctx := context.Background() keys, err := j.ar.Fetch(ctx, j.jwkEndpoint) if err != nil { @@ -46,7 +46,7 @@ func (j *JwkManager) GetPublicKey(kid string) (*rsa.PublicKey, error) { return &pKey, nil } -func NewJwkManager(authorityUrl string, refreshInterval time.Duration) (*JwkManager, error) { +func NewJwkManager(authorityUrl string, refreshInterval time.Duration) (*jwkManager, error) { jwkEndpoint := appendUrl(authorityUrl, IssuerPostfix) ctx, _ := context.WithCancel(context.Background()) ar := jwk.NewAutoRefresh(ctx) @@ -56,5 +56,5 @@ func NewJwkManager(authorityUrl string, refreshInterval time.Duration) (*JwkMana return nil, fmt.Errorf("failed to refresh the jwk endpoint %s", jwkEndpoint) } - return &JwkManager{jwkEndpoint: jwkEndpoint, ar: ar}, nil + return &jwkManager{jwkEndpoint: jwkEndpoint, ar: ar}, nil } diff --git a/pkg/auth/poptoken/msalauthprovider.go b/pkg/auth/poptoken/msalauthprovider.go index 605ef94f..6af538fa 100644 --- a/pkg/auth/poptoken/msalauthprovider.go +++ b/pkg/auth/poptoken/msalauthprovider.go @@ -3,6 +3,7 @@ package poptoken import ( "context" "os" + "time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" "github.com/pkg/errors" @@ -16,7 +17,7 @@ type MsalAuthProvider struct { authorityUrl string scope []string clientCertPath string - rsaKeyManager *RsaKeyManager + rsaKeyManager *rsaKeyManager } func (m MsalAuthProvider) refreshConfidentialClient() (*confidential.Client, error) { @@ -48,7 +49,7 @@ func (m MsalAuthProvider) refreshConfidentialClient() (*confidential.Client, err return &confidentialClient, nil } -func (m MsalAuthProvider) GetToken(targetResourceId string) (string, error) { +func (m MsalAuthProvider) GetToken(targetResourceId string, grpcObjectPath string) (string, error) { // TODO: the underlying client certificate will be refreshed, hence we need to also pick up the new certificate // Longer run we can cache the client but for now we will refresh the client for every token call. @@ -57,12 +58,12 @@ func (m MsalAuthProvider) GetToken(targetResourceId string) (string, error) { return "", err } - keyPair, err := m.rsaKeyManager.GetKeyPair() + keyPair, err := m.rsaKeyManager.GetKeyPair(time.Now()) if err != nil { return "", errors.Wrapf(err, "failed to get keypair for pop token") } - popTokenScheme, err := NewNodeAgentPopTokenAuthScheme(targetResourceId, keyPair) + popTokenScheme, err := NewNodeAgentPopTokenAuthScheme(targetResourceId, grpcObjectPath, keyPair) if err != nil { return "", errors.Wrapf(err, "failed to create new pop token scheme") } @@ -74,7 +75,12 @@ func (m MsalAuthProvider) GetToken(targetResourceId string) (string, error) { return result.AccessToken, nil } -func NewMsalClient(clientId string, tenantId, authorityUrl string, clientCertPath string, rsaKeyManager *RsaKeyManager) (*MsalAuthProvider, error) { +func NewMsalClient(clientId string, tenantId, authorityUrl string, clientCertPath string) (*MsalAuthProvider, error) { + rsaKeyManager, err := NewRsaKeyManager(DefaultRefreshInterval) + if err != nil { + return nil, err + } + m := &MsalAuthProvider{ clientId: clientId, tenantId: tenantId, @@ -85,7 +91,7 @@ func NewMsalClient(clientId string, tenantId, authorityUrl string, clientCertPat } // sanity check to ensure client is setup correctly - _, err := m.refreshConfidentialClient() + _, err = m.refreshConfidentialClient() if err != nil { return nil, err } diff --git a/pkg/auth/poptoken/nodeagentpoptokenscheme.go b/pkg/auth/poptoken/nodeagentpoptokenscheme.go index 08cc6776..ead6abec 100644 --- a/pkg/auth/poptoken/nodeagentpoptokenscheme.go +++ b/pkg/auth/poptoken/nodeagentpoptokenscheme.go @@ -1,66 +1,50 @@ package poptoken -import ( - "time" -) - -const ( - tokenType = "token_type" - reqCnf = "req_cnf" - resourceId = "resourceid" -) +import "github.com/google/uuid" // Implements the interface for MSAL SDK to callback when creating the poptoken. // See AuthenticationScheme interface in https://github.com/AzureAD/microsoft-authentication-library-for-go/blob/main/apps/internal/oauth/ops/authority/authority.go#L146 type NodeAgentPopTokenAuthScheme struct { - shrPopToken *ShrPopToken - keyPair *RsaKeyPair - ResourceId string + *PopTokenAuthScheme } // Return the claim containg the pop token kid that will be added to the Entra access token. func (a *NodeAgentPopTokenAuthScheme) TokenRequestParams() map[string]string { - reqCnfBase64, err := a.shrPopToken.GetReqCnf() - if err != nil { - return map[string]string{} - } - - return map[string]string{ - tokenType: a.shrPopToken.Header.Typ, - reqCnf: reqCnfBase64, - } + return a.PopTokenAuthScheme.TokenRequestParams() } // Return the keyId for MSAL to lookup for a cached access token. If it does not exist, MSAL will request a new access token func (a *NodeAgentPopTokenAuthScheme) KeyID() string { - return a.shrPopToken.Header.Kid + return a.PopTokenAuthScheme.KeyID() } // Generate the pop token; adding in the accessToken generated by Entra. func (a *NodeAgentPopTokenAuthScheme) FormatAccessToken(accessToken string) (string, error) { - // append accessToken and our custom claims to the pop token. - // Note custom claims should be compatible with JWT specs, we don't expect these claims to be complex - customClaims := map[string]interface{}{ - resourceId: a.ResourceId} - - return a.shrPopToken.GenerateToken(accessToken, time.Now(), customClaims) + return a.PopTokenAuthScheme.FormatAccessToken(accessToken) } // Return the token type. Must be "pop" func (a *NodeAgentPopTokenAuthScheme) AccessTokenType() string { - return a.shrPopToken.Header.Typ + return a.PopTokenAuthScheme.AccessTokenType() } -// Create a new instance of NodeAgentPopTokenAuthScheme. Pass in the custom claims to be set in the pop token here, e.g. resourceId -func NewNodeAgentPopTokenAuthScheme(resourceId string, rsaKeyPair *RsaKeyPair) (*NodeAgentPopTokenAuthScheme, error) { - shrPopToken, err := NewPopToken(rsaKeyPair) +// Create a new instance of NodeAgentPopTokenAuthScheme. +// targetResourceId: the ARM resourceId representing the edge node machine. This is the Arc For Server resource Id and is part of the node entity. +// grpcObjectId: the uri to the grpc entity, e.g. container. This will be passed in as part of the grpc metadata. +func NewNodeAgentPopTokenAuthScheme(targetNodeId string, grpcObjectId string, rsaKeyPair *RsaKeyPair) (*NodeAgentPopTokenAuthScheme, error) { + popTokenScheme, err := NewPopTokenAuthScheme( + map[string]interface{}{ + "nodeid": targetNodeId, + "p": grpcObjectId, + "nonce": uuid.New().String(), + }, + rsaKeyPair) if err != nil { return nil, err } return &NodeAgentPopTokenAuthScheme{ - shrPopToken: shrPopToken, - ResourceId: resourceId, + PopTokenAuthScheme: popTokenScheme, }, nil } diff --git a/pkg/auth/poptoken/nodeagentpoptokenscheme_test.go b/pkg/auth/poptoken/nodeagentpoptokenscheme_test.go index 7f764253..13cadf5a 100644 --- a/pkg/auth/poptoken/nodeagentpoptokenscheme_test.go +++ b/pkg/auth/poptoken/nodeagentpoptokenscheme_test.go @@ -8,54 +8,34 @@ import ( "github.com/stretchr/testify/assert" ) -// This test suite focus on the testing nodeagentpoptokenscheme is returning the expected values that MSAL expected -// the actual token generation is tested in shrpoptoken_test +// This test suite focus on the testing the custom claims of nodeagentpoptokenscheme is returned +// poptokenscheme_test the underlying poptokenscheme_test func Test_NodeAgentPopTokenScheme(t *testing.T) { - expectedResourceId := "myresourceId" + expectedNodeId := "mynodeId" + expectedGrpcObjectId := "myObjectId" kmgr, err := NewRsaKeyManager(time.Hour) assert.Nil(t, err) - keypair, err := kmgr.GetKeyPair() - assert.Nil(t, err) - - // create a "reference" pop token that we can use to validate some of the nodeagentpoptokenscheme content since it - // should generate the same values - refPopToken, err := NewPopToken(keypair) + keypair, err := kmgr.GetKeyPair(time.Now()) assert.Nil(t, err) // Generate nodeagent scheme - nodeAgentScheme, err := NewNodeAgentPopTokenAuthScheme(expectedResourceId, keypair) - - //validate AccessTokenType returns "pop" - assert.Equal(t, TokenType, nodeAgentScheme.AccessTokenType()) - assert.Equal(t, refPopToken.Header.Typ, nodeAgentScheme.AccessTokenType()) - - //Validate KeyID - assert.Equal(t, refPopToken.Header.Kid, nodeAgentScheme.KeyID()) - - // Validate TokenRequestParams returns a specific struct - reqCnf := nodeAgentScheme.TokenRequestParams() + nodeAgentScheme, err := NewNodeAgentPopTokenAuthScheme(expectedNodeId, expectedGrpcObjectId, keypair) - tokenType, ok := reqCnf["token_type"] - assert.True(t, ok) - assert.Equal(t, refPopToken.Header.Typ, tokenType) - - expectedCnf, err := refPopToken.GetReqCnf() - assert.Nil(t, err) - cnf, ok := reqCnf["req_cnf"] - assert.True(t, ok) - assert.Equal(t, expectedCnf, cnf) - - // Validate FormatAccessToken. Here we just check that the custom claim "resourceId" was added. + //For nodeagentpoptokenscheme, we just verify that the custom claims were added to the token. popToken, err := nodeAgentScheme.FormatAccessToken("accessToken") assert.Nil(t, err) assert.NotEmpty(t, popToken) toks := strings.Split(popToken, ".") assert.Equal(t, 3, len(toks)) - body, err := decodeFromBase64[nodeAgentPopTokenBody](toks[1]) + + body, err := decodeFromBase64[NodeAgentPopTokenBody](toks[1]) + assert.Nil(t, err) - assert.Equal(t, expectedResourceId, body.ResourceId) + assert.Equal(t, expectedNodeId, body.NodeId) + assert.Equal(t, expectedGrpcObjectId, body.GrpcObjectId) + assert.NotEmpty(t, body.Nonce) } diff --git a/pkg/auth/poptoken/nodeagentpoptokenvalidator.go b/pkg/auth/poptoken/nodeagentpoptokenvalidator.go index 6cad5bed..fb1936ac 100644 --- a/pkg/auth/poptoken/nodeagentpoptokenvalidator.go +++ b/pkg/auth/poptoken/nodeagentpoptokenvalidator.go @@ -21,13 +21,18 @@ import ( const ( TokenVersion1 = "1.0" TokenVersion2 = "2.0" - PopTokenValidInterval = 1 * time.Hour //TODO: should we make this smaller? + PopTokenValidInterval = 5 * time.Minute + PopTokenClockSkew = 1 * time.Minute ) -type nodeAgentPopTokenBody struct { +type NodeAgentPopTokenBody struct { ShrPopTokenBody - // target resource Id. Expected to match ShrPopTokenValidator.TargetResourceId - ResourceId string `json:"resourceid"` + // target node Id. this is expected to be the Arc For Server resource Id. + NodeId string `json:"nodeid"` + // uri to the grpc object targeted + GrpcObjectId string `json:"p"` + // unique id bound to the token to prevent replay attaches. + Nonce string `json:"nonce"` } // contains a subset of custom claims in Entra/AzureAD access tokens we want to validate. @@ -46,9 +51,11 @@ type AccessTokenCustomClaims struct { jwt.RegisteredClaims } -type ShrPopTokenValidator struct { - // A4S agent resourceId - TargetResourceId string +type shrPopTokenValidator struct { + // id of node which the pop token is bounded to. This is expected to be the Arc For Server (A4S) agent resourceId + NodeId string + // pathurl of grpc entity that the pop token is bounded to, e.g. the uri to the storage container entity. + GrpcObjectId string // Tenant Id of CMP 1P TenantId string // The target Id. In this case, client Id or one of its identifierUri, depending on the token version. @@ -59,6 +66,8 @@ type ShrPopTokenValidator struct { IssuerUrl string // component use to query Entra JWK endpoint. jwk JwkInterface + // component use to cache and check request's nonceCache to prevent replay attack + nonceCache NonceCacheInterface } // handle situation where url may or may not have a backslash @@ -70,22 +79,23 @@ func appendUrl(url string, postfix string) string { return fmt.Sprintf("%s%s%s", url, sep, postfix) } -func isTokenExpire(timestamp int64, now time.Time) error { +func isTokenExpire(timestamp int64, now time.Time, clockSkew time.Duration) error { var issuedTime time.Time convertTime(timestamp, &issuedTime) expireat := issuedTime.Add(PopTokenValidInterval) - if expireat.Before(now) { - return fmt.Errorf("pop token has expired. Time when validated: %v, issued At: %v, valid duration: %v", now, issuedTime, PopTokenValidInterval) + skewedTime := now.Add(-clockSkew) + if expireat.Before(skewedTime) { + return fmt.Errorf("pop token has expired. currentTimestamp: %v, issuedAt: %v, validDuration: %v", now, issuedTime, PopTokenValidInterval) } return nil } func isHeaderValid(header *ShrPopHeader) error { if header.Typ != TokenType { - return fmt.Errorf("expected token type %s, got %s", TokenType, header.Typ) + return fmt.Errorf("unsupported token type in pop token header; expected %s, got %s", TokenType, header.Typ) } if header.Alg != Alg { - return fmt.Errorf("expected alg %s, got %s", Alg, header.Alg) + return fmt.Errorf("unsupported alg in pop token header, expected %s, got %s", Alg, header.Alg) } return nil @@ -104,32 +114,30 @@ func verifyPayload(signingStr *string, sig []byte, pubKey *rsa.PublicKey) error hash.Write([]byte(*signingStr)) err := rsa.VerifyPKCS1v15(pubKey, crypto.SHA256, hash.Sum(nil), sig) if err != nil { - return errors.Wrapf(err, "failed to validate signature of poptoken using cnf") + return errors.Wrapf(err, "failed to validate signature of pop token") } return nil } func publicRSA256KeyFromCnf(cnf *Cnf) (*rsa.PublicKey, error) { - modulus, err := base64.URLEncoding.DecodeString(cnf.Jwk.N) + modulus, err := base64.RawURLEncoding.DecodeString(cnf.Jwk.N) if err != nil { - err := errors.Wrapf(err, "error while parsing poptoken cnf: failed to decode modulus") + err := errors.Wrapf(err, "error while parsing pop token cnf: failed to decode modulus") return nil, err } - n := big.NewInt(0) - n.SetString(string(modulus), 10) + n := new(big.Int).SetBytes(modulus) e, err := base64ToExponential(string(cnf.Jwk.E)) if err != nil { - err := errors.Wrapf(err, "error while parsing poptoken cnf: failed to parse exponent") + err := errors.Wrapf(err, "error while parsing pop token cnf: failed to parse exponent") return nil, err } pKey := rsa.PublicKey{N: n, E: int(e)} - return &pKey, nil } func base64ToExponential(encodedE string) (int, error) { - decE, err := base64.URLEncoding.DecodeString(encodedE) + decE, err := base64.RawURLEncoding.DecodeString(encodedE) if err != nil { return 0, err } @@ -180,12 +188,12 @@ func convertTime(i any, tm *time.Time) { } } -func (s *ShrPopTokenValidator) parseAndValidateAccessToken(tokenStr string, popTokenKid string) error { +func (s *shrPopTokenValidator) parseAndValidateAccessToken(tokenStr string, popTokenKid string) error { // ParseWithClaims() will validate the token expiry date and signing. at, err := jwt.ParseWithClaims(tokenStr, &AccessTokenCustomClaims{}, func(token *jwt.Token) (interface{}, error) { kid, ok := token.Header["kid"].(string) if !ok { - return nil, fmt.Errorf("failed to find kid in access token header") + return nil, fmt.Errorf("missing metadata 'kid' 'in the header of claim 'at'") } pKey, err := s.jwk.GetPublicKey(kid) @@ -197,48 +205,48 @@ func (s *ShrPopTokenValidator) parseAndValidateAccessToken(tokenStr string, popT }, jwt.WithValidMethods([]string{jwt.SigningMethodRS256.Alg()})) if err != nil { - return errors.Wrapf(err, "failed to parse access token") + return errors.Wrapf(err, "failed to validate claim 'at'") } err = s.validateAccessTokenClaims(at, popTokenKid) return err } -func (s *ShrPopTokenValidator) validateAccessTokenClaims(token *jwt.Token, popTokenKid string) error { +func (s *shrPopTokenValidator) validateAccessTokenClaims(token *jwt.Token, popTokenKid string) error { if token == nil { - return fmt.Errorf("empty token in validateAccessTokenClaims!") + return fmt.Errorf("missing claim 'at' in pop token.") } // now read in Entra access token's specific claims and validate them claims, ok := token.Claims.(*AccessTokenCustomClaims) if !ok { - return fmt.Errorf("failed to retrieve expected claims in access token") + return fmt.Errorf("failed to parse claim 'at' in pop token.") } // Handle claims that are specifc to token versions switch claims.TokenVersion { case TokenVersion1: if claims.AppId != s.ClientId { - return fmt.Errorf("invalid appId claim") + return fmt.Errorf("claim 'at.appId' does not match the expected app Id. token is not from an accepted client") } if claims.Issuer != s.IssuerUrl { - return fmt.Errorf("invalid issuer") + return fmt.Errorf("claim 'at.iss' points to an invalid issuer for v1 token") } case TokenVersion2: if claims.Azp != s.ClientId { - return fmt.Errorf("invalid azp claim") + return fmt.Errorf("claim 'at.azp' does not match the expected app Id. token is not from an accepted client") } // for v2, issuer ends with v2.0 if claims.Issuer != appendUrl(s.IssuerUrl, "v2.0") { - return fmt.Errorf("invalid issuer for v2 token") + return fmt.Errorf("claim 'at.iss' points to an invalid issuer for v2 token") } default: - return fmt.Errorf("unknown token version %s. expected either %sor %s", claims.TokenVersion, TokenVersion1, TokenVersion2) + return fmt.Errorf("claim 'at.ver' has an unknown token version %s. expected either %sor %s", claims.TokenVersion, TokenVersion1, TokenVersion2) } if claims.ReqCnf.Kid != popTokenKid { - return fmt.Errorf("kid in pop token did not match kid in access token. expected kid: %s, got kid: %s", claims.ReqCnf.Kid, popTokenKid) + return fmt.Errorf("claim 'kid' in pop token did not match kid in claim 'at.cnf.kid'. pop token may have been signed with an unexpected key.") } foundAud := false @@ -249,24 +257,38 @@ func (s *ShrPopTokenValidator) validateAccessTokenClaims(token *jwt.Token, popTo } } if !foundAud { - return fmt.Errorf("aud claim was not expected") + return fmt.Errorf("claim 'at.aud' did not match the expected audience") } return nil } -func (s *ShrPopTokenValidator) isCustomClaimsValid(body *nodeAgentPopTokenBody) error { - if body.ResourceId != s.TargetResourceId { - return fmt.Errorf("invalid resourceId") +func (s *shrPopTokenValidator) isCustomClaimsValid(body *NodeAgentPopTokenBody) error { + if body.NodeId != s.NodeId { + return fmt.Errorf("claim 'nodeid'in pop token does not match the id of the edge node.") + } + if body.GrpcObjectId != s.GrpcObjectId { + return fmt.Errorf("claim 'p' in pop token does not match the targeted moc entity") + } + return nil +} + +func (s *shrPopTokenValidator) isTokenReused(nonce string, now time.Time) error { + if nonce == "" { + return fmt.Errorf("claim 'nonce' in pop token is empty/missing, expected a unique id") } + ok := s.nonceCache.IsNonceExists(nonce, now) + if ok { + return fmt.Errorf("claim 'nonce' in pop token has been used before, potentially a replay attack") + } return nil } -func (s *ShrPopTokenValidator) Validate(popToken string) error { +func (s *shrPopTokenValidator) Validate(popToken string) error { toks := strings.Split(popToken, ".") if len(toks) != 3 { - return fmt.Errorf("invalid pop tokens expected 3 segments, got %d", len(toks)) + return fmt.Errorf("invalid pop token; expected 3 segments, got %d", len(toks)) } header, err := decodeFromBase64[ShrPopHeader](toks[0]) @@ -278,12 +300,16 @@ func (s *ShrPopTokenValidator) Validate(popToken string) error { return err } - body, err := decodeFromBase64[nodeAgentPopTokenBody](toks[1]) + body, err := decodeFromBase64[NodeAgentPopTokenBody](toks[1]) if err != nil { return err } - if err := isTokenExpire(body.Ts, time.Now()); err != nil { + if err := isTokenExpire(body.Ts, time.Now(), PopTokenClockSkew); err != nil { + return err + } + + if err := s.isTokenReused(body.Nonce, time.Now()); err != nil { return err } @@ -302,7 +328,7 @@ func (s *ShrPopTokenValidator) Validate(popToken string) error { } // now retrieve the inner access token - err = s.parseAndValidateAccessToken(body.At, body.Cnf.Jwk.Kid) + err = s.parseAndValidateAccessToken(body.At, header.Kid) if err != nil { return err } @@ -310,18 +336,20 @@ func (s *ShrPopTokenValidator) Validate(popToken string) error { return nil } -func NewPopTokenValidator(targetResourceId string, tenantId string, audiences []string, clientId string, authorityUrl string, jwk JwkInterface) (*ShrPopTokenValidator, error) { +func NewPopTokenValidator(nodeId string, grpcobjectId string, tenantId string, audiences []string, clientId string, authorityUrl string, jwk JwkInterface, nonceChecker NonceCacheInterface) (*shrPopTokenValidator, error) { audienceMap := make(map[string]bool) for _, aud := range audiences { audienceMap[aud] = true } - return &ShrPopTokenValidator{ - TargetResourceId: targetResourceId, - TenantId: tenantId, - Audience: audienceMap, - ClientId: clientId, - IssuerUrl: appendUrl(authorityUrl, tenantId), - jwk: jwk, + return &shrPopTokenValidator{ + NodeId: nodeId, + GrpcObjectId: grpcobjectId, + TenantId: tenantId, + Audience: audienceMap, + ClientId: clientId, + IssuerUrl: appendUrl(authorityUrl, tenantId), + jwk: jwk, + nonceCache: nonceChecker, }, nil } diff --git a/pkg/auth/poptoken/nodeagentpoptokenvalidator_test.go b/pkg/auth/poptoken/nodeagentpoptokenvalidator_test.go index 787db388..048d9dac 100644 --- a/pkg/auth/poptoken/nodeagentpoptokenvalidator_test.go +++ b/pkg/auth/poptoken/nodeagentpoptokenvalidator_test.go @@ -49,25 +49,36 @@ func Test_NodeAgentPopTokenValidatorIsTokenExpire(t *testing.T) { tests := []struct { name string tokenCheckAt time.Time + clockSkew time.Duration shouldPass bool }{ { name: "token valid", // set token evaluation time to be 1 second after token was issued, token is valid. tokenCheckAt: tokenIssuedAt.Add(time.Second * 1), + clockSkew: 0, shouldPass: true, }, { name: "token expired", // set token evaluation time to be 10 seconds after max valid period. token has expired. tokenCheckAt: tokenIssuedAt.Add(PopTokenValidInterval).Add(time.Second * 10), + clockSkew: 0, shouldPass: false, }, + { + name: "token pass due to clock skew", + // set token evaluation time to be 10 seconds after max valid period. token should have expired + // like previous test, but passes thanks to the allowed clock skew of 11 seconds. + tokenCheckAt: tokenIssuedAt.Add(PopTokenValidInterval).Add(time.Second * 10), + clockSkew: time.Second * 11, + shouldPass: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := isTokenExpire(tokenIssuedAtInt, tt.tokenCheckAt) + err := isTokenExpire(tokenIssuedAtInt, tt.tokenCheckAt, tt.clockSkew) if tt.shouldPass { assert.Nil(t, err) } else { @@ -117,6 +128,29 @@ func Test_NodeAgentPopTokenValidatorIsHeaderValid(t *testing.T) { } } +func Test_NodeAgentPopTokenValidatorIsTokenReused(t *testing.T) { + nonceCache := &FakeNonceCache{Exists: false} + + // for this test, we only care about setting the noncecache + popTokenValidator, err := NewPopTokenValidator("", "", "", []string{"aud"}, "", "", nil, nonceCache) + assert.Nil(t, err) + + // simulate nonce entry does not exists in cache, i.e. we have not seen the token before + nonceCache.Exists = false + err = popTokenValidator.isTokenReused("myId", time.Now()) + assert.Nil(t, err) + + // simulate nonce entry exists in cache, i.e. same token is reused, potentially a replay token + nonceCache.Exists = true + err = popTokenValidator.isTokenReused("myId", time.Now()) + assert.NotNil(t, err) + + // simulate missing nonce. we will reject token + nonceCache.Exists = false + err = popTokenValidator.isTokenReused("", time.Now()) + assert.NotNil(t, err) +} + func Test_NodeAgentPopTokenValidatorIsSignatureValid(t *testing.T) { keypair, err := getKeyPair() assert.Nil(t, err) @@ -165,38 +199,55 @@ func Test_NodeAgentPopTokenValidatorbase64ToExponential(t *testing.T) { } func Test_NodeAgentPopTokenValidatorIsCustomClaimsValid(t *testing.T) { - expectedResourceId := "myResourceId" + expectedNodeId := "myNodeId" + expectedGrpcObjectId := "myObjectId" tests := []struct { - name string - actualResourceId string - shouldPass bool + name string + actualNodeId string + actualGrpcObjectId string + shouldPass bool }{ { - name: "valid resource Id claim", - actualResourceId: expectedResourceId, - shouldPass: true, + name: "valid nodeId and objectId claims", + actualNodeId: expectedNodeId, + actualGrpcObjectId: expectedGrpcObjectId, + shouldPass: true, }, { - name: "invalid resourceId claim", - actualResourceId: "somethingelse", - shouldPass: false, + name: "invalid nodeId claim", + actualNodeId: "somethingelse", + actualGrpcObjectId: expectedGrpcObjectId, + shouldPass: false, }, { - name: "missing resourceId claim", - actualResourceId: "", - shouldPass: false, + name: "missing nodeId claim", + actualNodeId: "", + actualGrpcObjectId: expectedGrpcObjectId, + shouldPass: false, + }, + { + name: "invalid grpcObjectId claim", + actualNodeId: expectedNodeId, + actualGrpcObjectId: "somethingelse", + shouldPass: false, + }, + { + name: "missing grpcObjectId claim", + actualNodeId: expectedNodeId, + actualGrpcObjectId: "", + shouldPass: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // for this test, we only care about setting the targetResourceId - popTokenValidator, err := NewPopTokenValidator(expectedResourceId, "", []string{"aud"}, "", "", nil) + // for this test, we only care about initializeing the custom claims + popTokenValidator, err := NewPopTokenValidator(expectedNodeId, expectedGrpcObjectId, "", []string{"aud"}, "", "", nil, nil) assert.Nil(t, err) - // likewise we only set the resourceId in the poptoken body - popTokenBody := nodeAgentPopTokenBody{ResourceId: tt.actualResourceId} + // likewise we only set the custom claims in the poptoken body + popTokenBody := NodeAgentPopTokenBody{NodeId: tt.actualNodeId, GrpcObjectId: tt.actualGrpcObjectId} err = popTokenValidator.isCustomClaimsValid(&popTokenBody) if tt.shouldPass { @@ -219,7 +270,7 @@ func Test_NodeAgentPopTokenValidatorParseAndValidateAccessToken(t *testing.T) { defaultPKey, _ := getPrivateKey() defaultJwkMgr := &FakeJwrMgr{PublicKey: &defaultPKey.PublicKey} - + nonceCache := &FakeNonceCache{Exists: false} // By default, the accesstoken and validator will use the same expected values as listed above // hence the access token validation will succeeded. @@ -369,12 +420,14 @@ func Test_NodeAgentPopTokenValidatorParseAndValidateAccessToken(t *testing.T) { // The token validator is set to the expected values, except for invalid jwk tokenValidator, err := NewPopTokenValidator( + "notused", // this is not tested here. "notused", // this is not tested here. expectedTenantId, []string{expectedAudience}, expectedClientId, expectedAuthorityUrl, - jwkMgr) + jwkMgr, + nonceCache) assert.Nil(t, err) s, err := generateAccessToken(&claims, defaultPKey) @@ -399,7 +452,8 @@ func Test_NodeAgentPopTokenValidatorParseAndValidateAccessToken(t *testing.T) { // this is just a simple test to validate end to end. func Test_NodeAgentPopTokenValidatorValidate(t *testing.T) { authorityUrl := "https://login.fake.microsoftonline.com" - resourceId := "resourceId" + nodeId := "nodeId" + grpcObjectId := "objectId" tenantId := "cmpTenantId" clientId := "cmpClientId" audience := "cmpAudience" @@ -411,6 +465,7 @@ func Test_NodeAgentPopTokenValidatorValidate(t *testing.T) { accessTokenPKey, err := getPrivateKey() assert.Nil(t, err) jwkMgr := &FakeJwrMgr{PublicKey: &accessTokenPKey.PublicKey} + nonceCache := &FakeNonceCache{Exists: false} rsaKeyPair, err := getKeyPair() assert.Nil(t, err) @@ -439,17 +494,19 @@ func Test_NodeAgentPopTokenValidatorValidate(t *testing.T) { assert.Nil(t, err) // Generate pop token - pt, err := popToken.GenerateToken(at, time.Now(), map[string]interface{}{"resourceId": resourceId}) + pt, err := popToken.GenerateToken(at, time.Now(), map[string]interface{}{"nodeId": nodeId, "p": grpcObjectId, "nonce": "nonceId"}) assert.Nil(t, err) // validate poptoken tokenValidator, err := NewPopTokenValidator( - resourceId, + nodeId, + grpcObjectId, tenantId, []string{audience}, clientId, authorityUrl, - jwkMgr) + jwkMgr, + nonceCache) assert.Nil(t, err) err = tokenValidator.Validate(pt) @@ -496,6 +553,15 @@ func (j *FakeJwrMgr) GetPublicKey(kid string) (*rsa.PublicKey, error) { } } +// fake nonceCache that we can set to return true or false at will. +type FakeNonceCache struct { + Exists bool +} + +func (n *FakeNonceCache) IsNonceExists(nonceId string, now time.Time) bool { + return n.Exists +} + func getPrivateKey() (*rsa.PrivateKey, error) { return rsa.GenerateKey(rand.Reader, RsaSize) } @@ -503,11 +569,9 @@ func getPrivateKey() (*rsa.PrivateKey, error) { func publicKeyToCnf(keyPair *RsaKeyPair) *Cnf { return &Cnf{ Jwk: Jwk{ - JwkInner: JwkInner{ - Kty: keyPair.Kty, - E: exponential2Base64(keyPair.PublicKey.E), - N: base64.URLEncoding.EncodeToString([]byte(keyPair.PublicKey.N.String())), - }, + Kty: keyPair.Kty, + E: exponential2Base64(keyPair.PublicKey.E), + N: base64.RawURLEncoding.EncodeToString([]byte(keyPair.PublicKey.N.Bytes())), }, } } diff --git a/pkg/auth/poptoken/noncecache.go b/pkg/auth/poptoken/noncecache.go new file mode 100644 index 00000000..307b8cc4 --- /dev/null +++ b/pkg/auth/poptoken/noncecache.go @@ -0,0 +1,96 @@ +package poptoken + +import ( + "sync" + "time" +) + +const ( + DefaultNonceCacheSize = 20 + DefaultNonceValidInterval = time.Minute * 5 +) + +type NonceCacheInterface interface { + IsNonceExists(nonceId string, now time.Time) bool +} + +type Nonce struct { + Id string + CreatedDateTime time.Time +} + +// Implement a simple LRU cache that evicts older nonce entries. +type nonceCache struct { + cache map[string]*Nonce + queue []*Nonce + nonceValidInterval time.Duration + size int + maxSize int + mutex sync.Mutex +} + +func (n *nonceCache) append(nonce *Nonce) { + n.cache[nonce.Id] = nonce + n.queue = append(n.queue, nonce) + n.size++ + return + +} + +// keep trimming the cache until all expired entries are purged or we are within the cache size +func (n *nonceCache) trim(now time.Time) { + isDelete := true + + for isDelete { + if len(n.queue) == 0 { + break + } + + nonce := n.queue[0] + isDelete = nonce.CreatedDateTime.Add(n.nonceValidInterval).Before(now) + + if !isDelete { + isDelete = n.size >= n.maxSize + } + if isDelete { + delete(n.cache, nonce.Id) + n.queue = n.queue[1:] + n.size-- + } + } +} + +func (n *nonceCache) IsNonceExists(nonceId string, now time.Time) bool { + n.mutex.Lock() + defer n.mutex.Unlock() + + existNonce, ok := n.cache[nonceId] + // entries are evicted lazily, so if entry has expired, return false even though the entry + // is still in the cache. + if ok { + return existNonce.CreatedDateTime.Add(n.nonceValidInterval).After(now) + } + + nonce := &Nonce{ + Id: nonceId, + CreatedDateTime: now, + } + n.append(nonce) + n.trim(now) + + return false + +} + +func (n *nonceCache) GetCacheSize() int { + return n.size +} + +func NewNonceCache(maxSize int, nonceValidPeriod time.Duration) (*nonceCache, error) { + return &nonceCache{ + cache: make(map[string]*Nonce), + nonceValidInterval: nonceValidPeriod, + size: 0, + maxSize: maxSize, + }, nil +} diff --git a/pkg/auth/poptoken/noncecache_test.go b/pkg/auth/poptoken/noncecache_test.go new file mode 100644 index 00000000..2c2cbfcc --- /dev/null +++ b/pkg/auth/poptoken/noncecache_test.go @@ -0,0 +1,134 @@ +package poptoken + +import ( + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +const ( + testNonceCacheSize = 3 + testNonceValidInterval = time.Minute * 1 + testNonceNowDateTimeStr = "2025-12-01T15:00:00Z" +) + +func Test_NonceCacheIdExists(t *testing.T) { + noncecache, err := NewNonceCache(testNonceCacheSize, testNonceValidInterval) + assert.Nil(t, err) + + nonceId := "nonceId_1" + now, _ := time.Parse(time.RFC3339, testNonceNowDateTimeStr) + + // first time seeing this nodeId, returning false + isexist := noncecache.IsNonceExists(nonceId, now) + assert.False(t, isexist) + + // the second time the nonceId should be cached. + isexist = noncecache.IsNonceExists(nonceId, now) + assert.True(t, isexist) + + // Validate a new entry will return false + isexist = noncecache.IsNonceExists("nonceId_2", now) + assert.False(t, isexist) +} + +// expired entries are lazily evicted, so an invalid entry can remain in the cache +// validate that even the expired entry exists, we will still return false. +func Test_NonceCacheIdExistsButExpired(t *testing.T) { + noncecache, err := NewNonceCache(testNonceCacheSize, testNonceValidInterval) + assert.Nil(t, err) + + nonceId := "nonceId_1" + now, _ := time.Parse(time.RFC3339, testNonceNowDateTimeStr) + + // first time seeing this nodeId, returning false + isexist := noncecache.IsNonceExists(nonceId, now) + assert.False(t, isexist) + + // the second time the nonceId should be cached. + isexist = noncecache.IsNonceExists(nonceId, now) + assert.True(t, isexist) + + // now simulate checking for the nonceId past the valid period and try again + // this time it should return false. + now = now.Add(testNonceValidInterval * 2) + isexist = noncecache.IsNonceExists(nonceId, now) + assert.False(t, isexist) +} + +// Validate that expired Ids will be evicted upon the addition of a new entry. +func Test_NonceCacheEvictExpiredIds(t *testing.T) { + noncecache, err := NewNonceCache(testNonceCacheSize, testNonceValidInterval) + assert.Nil(t, err) + + now, _ := time.Parse(time.RFC3339, testNonceNowDateTimeStr) + for i := 0; i < testNonceCacheSize-1; i++ { + id := fmt.Sprintf("%d", i) + now = now.Add(time.Second) + noncecache.IsNonceExists(id, now) + + //need to call twice to confirm the nonceId were added, since the first time + // it is added, it will not exist + isexist := noncecache.IsNonceExists(id, now) + assert.True(t, isexist) + } + + // simulate querying a new nonce Id after time where the previously added ids expired. + // adding the new entry will trigger an eviction of the expired entries + newId := "new" + now = now.Add(testNonceValidInterval * 2) + noncecache.IsNonceExists(newId, now) + + // validate older entry has been evicted; size of cache should be 1 + // we check the size before checking if the older ids have been evicted as they will get + // readded again. + assert.Equal(t, 1, noncecache.GetCacheSize()) + + // validate the ids no longer exists in cache + for i := 0; i < testNonceCacheSize-1; i++ { + id := fmt.Sprintf("%d", i) + isexist := noncecache.IsNonceExists(id, now) + assert.False(t, isexist) + } + +} + +// Validate that the oldest Ids will be evicted upon the addition of a new entry that exceeds the cache size. +func Test_NonceCacheEvictOverflowIds(t *testing.T) { + noncecache, err := NewNonceCache(testNonceCacheSize, testNonceValidInterval) + assert.Nil(t, err) + + idsToAddCount := testNonceCacheSize + 2 + now, _ := time.Parse(time.RFC3339, testNonceNowDateTimeStr) + for i := 0; i < idsToAddCount; i++ { + id := fmt.Sprintf("%d", i) + now = now.Add(time.Second) + noncecache.IsNonceExists(id, now) + + //need to call twice to confirm the nonceId were added, since the first time + // it is added, it will not exist + isexist := noncecache.IsNonceExists(id, now) + assert.True(t, isexist) + + // validate size of cache does not exceed the max even if more ids were added. + assert.True(t, noncecache.GetCacheSize() <= testNonceCacheSize) + } + + // when we add more ids than supported, the oldest ids get evicted. + // verify the earlier ids should no longer exist. We need to check in reverse order + // to avoid evicting the newer entries + for i := idsToAddCount - 1; i >= 0; i-- { + id := fmt.Sprintf("%d", i) + now = now.Add(time.Second) + + isexist := noncecache.IsNonceExists(id, now) + if i >= testNonceCacheSize { + assert.True(t, isexist) + } else { + assert.False(t, isexist) + } + } + +} diff --git a/pkg/auth/poptoken/poptokenauth.go b/pkg/auth/poptoken/poptokenauth.go index 528c59b8..151c6ecd 100644 --- a/pkg/auth/poptoken/poptokenauth.go +++ b/pkg/auth/poptoken/poptokenauth.go @@ -31,12 +31,12 @@ func NewPopTokenAuth(msalProvider *MsalAuthProvider, targetResourceId string) (* } func (p *PopTokenAuth) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { - accessToken, err := p.msalauthprovider.GetToken(p.targetResourceId) + accessToken, err := p.msalauthprovider.GetToken(p.targetResourceId, uri[0]) if err != nil { return nil, errors.Wrapf(err, "failed to generate poptoken") } - return map[string]string{"authorization": accessToken}, nil + return map[string]string{"authorization": accessToken, "uri": uri[0]}, nil } func (p *PopTokenAuth) RequireTransportSecurity() bool { diff --git a/pkg/auth/poptoken/poptokenscheme.go b/pkg/auth/poptoken/poptokenscheme.go new file mode 100644 index 00000000..281813de --- /dev/null +++ b/pkg/auth/poptoken/poptokenscheme.go @@ -0,0 +1,59 @@ +package poptoken + +import ( + "sync" + "time" +) + +const ( + tokenType = "token_type" + reqCnf = "req_cnf" +) + +// Implements the interface for MSAL SDK to callback when creating the poptoken. +// See AuthenticationScheme interface in https://github.com/AzureAD/microsoft-authentication-library-for-go/blob/main/apps/internal/oauth/ops/authority/authority.go#L146 +type PopTokenAuthScheme struct { + shrPopToken *shrPopToken + claims map[string]interface{} + currRsaKeyPair *RsaKeyPair + rsaKeyPairRefreshInterval time.Duration + rsaKeyPairRefreshMutex sync.Mutex +} + +// Return the claim containg the pop token kid that will be added to the Entra access token. +func (a *PopTokenAuthScheme) TokenRequestParams() map[string]string { + return map[string]string{ + tokenType: a.shrPopToken.Header.Typ, + reqCnf: a.shrPopToken.GetReqCnf(), + } +} + +// Return the keyId for MSAL to lookup for a cached access token. If it does not exist, MSAL will request a new access token +func (a *PopTokenAuthScheme) KeyID() string { + return a.shrPopToken.Header.Kid +} + +// Generate the pop token; adding in the accessToken generated by Entra. +func (a *PopTokenAuthScheme) FormatAccessToken(accessToken string) (string, error) { + // append accessToken and our custom claims to the pop token. + // Note custom claims should be compatible with JWT specs, we don't expect these claims to be complex + return a.shrPopToken.GenerateToken(accessToken, time.Now(), a.claims) +} + +// Return the token type. Must be "pop" +func (a *PopTokenAuthScheme) AccessTokenType() string { + return a.shrPopToken.Header.Typ +} + +// Create a new instance of PopTokenAuthScheme. Pass in the custom claims to be set in the pop token here, e.g. resourceId +func NewPopTokenAuthScheme(claims map[string]interface{}, rsaKeyPair *RsaKeyPair) (*PopTokenAuthScheme, error) { + shrPopToken, err := NewPopToken(rsaKeyPair) + if err != nil { + return nil, err + } + + return &PopTokenAuthScheme{ + shrPopToken: shrPopToken, + claims: claims, + }, nil +} diff --git a/pkg/auth/poptoken/poptokenscheme_test.go b/pkg/auth/poptoken/poptokenscheme_test.go new file mode 100644 index 00000000..aaee92e0 --- /dev/null +++ b/pkg/auth/poptoken/poptokenscheme_test.go @@ -0,0 +1,69 @@ +package poptoken + +import ( + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +type TestPopTokenSchemeBody struct { + ShrPopTokenBody + NodeId string `json:"nodeid"` +} + +// This test suite focus on the testing poptokenscheme is returning the expected values that MSAL expected +// the actual token generation is tested in shrpoptoken_test +func Test_PopTokenScheme(t *testing.T) { + expectedNodeId := "mynodeId" + + kmgr, err := NewRsaKeyManager(time.Hour) + assert.Nil(t, err) + + keypair, err := kmgr.GetKeyPair(time.Now()) + assert.Nil(t, err) + + // create a "reference" pop token that we can use to validate some of the nodeagentpoptokenscheme content since it + // should generate the same values + refPopToken, err := NewPopToken(keypair) + assert.Nil(t, err) + + // Generate nodeagent scheme + claims := map[string]interface{}{ + "nodeId": expectedNodeId, + } + popTokenScheme, err := NewPopTokenAuthScheme(claims, keypair) + + //validate AccessTokenType returns "pop" + assert.Equal(t, TokenType, popTokenScheme.AccessTokenType()) + assert.Equal(t, refPopToken.Header.Typ, popTokenScheme.AccessTokenType()) + + //Validate KeyID + assert.Equal(t, refPopToken.Header.Kid, popTokenScheme.KeyID()) + + // Validate TokenRequestParams returns a specific struct + reqCnf := popTokenScheme.TokenRequestParams() + + tokenType, ok := reqCnf["token_type"] + assert.True(t, ok) + assert.Equal(t, refPopToken.Header.Typ, tokenType) + + expectedCnf := refPopToken.GetReqCnf() + assert.Nil(t, err) + cnf, ok := reqCnf["req_cnf"] + assert.True(t, ok) + assert.Equal(t, expectedCnf, cnf) + + // Validate FormatAccessToken. Here we just check that the custom claim "nodeId" was added. + popToken, err := popTokenScheme.FormatAccessToken("accessToken") + assert.Nil(t, err) + assert.NotEmpty(t, popToken) + + toks := strings.Split(popToken, ".") + assert.Equal(t, 3, len(toks)) + body, err := decodeFromBase64[TestPopTokenSchemeBody](toks[1]) + assert.Nil(t, err) + assert.Equal(t, expectedNodeId, body.NodeId) + +} diff --git a/pkg/auth/poptoken/raskeymanager_test.go b/pkg/auth/poptoken/raskeymanager_test.go index abebbdfa..e6b38925 100644 --- a/pkg/auth/poptoken/raskeymanager_test.go +++ b/pkg/auth/poptoken/raskeymanager_test.go @@ -7,11 +7,17 @@ import ( "github.com/stretchr/testify/assert" ) -func Test_RsaKeyPairGetKeyPair(t *testing.T) { - rsamgr, err := NewRsaKeyManager(time.Hour * 1) +const ( + testRsaValidInterval = time.Hour * 1 + testRsaNowDateTimeStr = "2025-12-01T15:00:00Z" //time.Parse(time.RFC3339, "2025-12-01T15:00:00Z") +) + +func Test_RsaKeyManagerGetKeyPair(t *testing.T) { + now, _ := time.Parse(time.RFC3339, testNonceNowDateTimeStr) + rsamgr, err := NewRsaKeyManager(testRsaValidInterval) assert.Nil(t, err) - rsa, err := rsamgr.GetKeyPair() + rsa, err := rsamgr.GetKeyPair(now) assert.Nil(t, err) assert.Equal(t, Alg, rsa.Alg) @@ -21,53 +27,23 @@ func Test_RsaKeyPairGetKeyPair(t *testing.T) { assert.NotNil(t, rsa.PublicKey) // now get the keypair a second time, if it has not refreshed, it will be the same value - rsa2, err := rsamgr.GetKeyPair() - //validate private key are same + rsa2, err := rsamgr.GetKeyPair(now) + //validate private key are equal value assert.Equal(t, *rsa.PrivateKey.N, *rsa2.PrivateKey.N) } -func Test_RsaKeyPairRefresh(t *testing.T) { - rsamgr, err := NewRsaKeyManager(time.Second * 1) - assert.Nil(t, err) - - rsa, err := rsamgr.GetKeyPair() +func Test_RsaKeyManagerGetKeyPairRotated(t *testing.T) { + now, _ := time.Parse(time.RFC3339, testNonceNowDateTimeStr) + rsamgr, err := NewRsaKeyManager(testRsaValidInterval) assert.Nil(t, err) - time.Sleep(time.Second * 2) - - rsa2, err := rsamgr.GetKeyPair() - //validate private key are different the second time we ger it - assert.NotEqual(t, *rsa.PrivateKey.N, *rsa2.PrivateKey.N) -} - -func Test_RsaKeyPairForceRefresh(t *testing.T) { - rsamgr, err := NewRsaKeyManager(time.Hour * 1) + rsa, err := rsamgr.GetKeyPair(now) assert.Nil(t, err) - rsa, err := rsamgr.GetKeyPair() + // now get the keypair a second time past the refresh interval, a new key should be generated. + rsa2, err := rsamgr.GetKeyPair(now.Add(testRsaValidInterval * 2)) assert.Nil(t, err) - rsamgr.ForceRefresh() - // wait for some time for it to respond, note the sleep here is far less than the refesh interval of 1 hour - time.Sleep(time.Second * 1) - - rsa2, err := rsamgr.GetKeyPair() - //validate private key are different the second time we ger it + //validate the two keys are now different. assert.NotEqual(t, *rsa.PrivateKey.N, *rsa2.PrivateKey.N) } - -// validate keymanager will not deadlock if refresh happen quicker than get call. -func Test_RsaKeyPairNoDeadLock(t *testing.T) { - rsamgr, err := NewRsaKeyManager(time.Hour * 1) - assert.Nil(t, err) - - for i := 0; i < 5; i++ { - rsamgr.ForceRefresh() - // wait for some time for it to respond - time.Sleep(time.Second * 1) - } - - rsa2, err := rsamgr.GetKeyPair() - assert.Nil(t, err) - assert.NotNil(t, rsa2.PrivateKey) -} diff --git a/pkg/auth/poptoken/rsakeymanager.go b/pkg/auth/poptoken/rsakeymanager.go index 5daff647..0239c4f6 100644 --- a/pkg/auth/poptoken/rsakeymanager.go +++ b/pkg/auth/poptoken/rsakeymanager.go @@ -4,6 +4,7 @@ import ( "crypto/rand" "crypto/rsa" "crypto/x509" + "sync" "time" ) @@ -17,61 +18,36 @@ type RsaKeyPair struct { // a RSA Key generator that refresh the RSA KeyPair at regular interval // Used to ensure the keys use to sign the poptoken are rotated -type RsaKeyManager struct { - refreshInterval time.Duration - refreshTicker *time.Ticker - keyPairChan chan *rsa.PrivateKey - forceRefreshChan chan bool - stopChan chan bool - privateKey *rsa.PrivateKey +type rsaKeyManager struct { + refreshInterval time.Duration + createdDateTime time.Time + privateKey *rsa.PrivateKey + mutex sync.Mutex } const ( - RsaSize = 2048 - Kty = "RSA" - Alg = "RS256" + DefaultRefreshInterval = time.Hour * 8 + RsaSize = 2048 + Kty = "RSA" + Alg = "RS256" ) func generatePrivateKey() (*rsa.PrivateKey, error) { return rsa.GenerateKey(rand.Reader, RsaSize) } -// Background job to continuously refresh the keypair in a best effort basis. -func (r *RsaKeyManager) refreshPrivateKeyJob() { - for { - select { - case <-r.stopChan: - return - case <-r.refreshTicker.C: - r.tryRefreshPrivateKey() - case <-r.forceRefreshChan: - r.tryRefreshPrivateKey() - } - } -} +// Return a KeyPair. The keypair is its own copy and not a reference. +func (r *rsaKeyManager) GetKeyPair(now time.Time) (*RsaKeyPair, error) { + r.mutex.Lock() + defer r.mutex.Unlock() -// Generate new keypair and send it back to the main go routine. -func (r *RsaKeyManager) tryRefreshPrivateKey() { - privateKey, err := generatePrivateKey() - // generatePrivateKey() should not fail, we don't have a good way to surface this error - if err == nil { - // In the unlikely event the refresh rate happens faster than getting the key, - // drop the key to prevent deadlocking the channel - select { - case r.keyPairChan <- privateKey: - default: - // imply key is dropped if channel is full. + if r.createdDateTime.Add(r.refreshInterval).Before(now) { + newPKey, err := generatePrivateKey() + if err != nil { + return nil, err } - } -} - -// Return a KeyPair. The keypair is its own copy and not a reference. -func (r *RsaKeyManager) GetKeyPair() (*RsaKeyPair, error) { - // non blocking wait to get new private key if available - select { - case r.privateKey = <-r.keyPairChan: - default: - //continue to use existing key + r.privateKey = newPKey + r.createdDateTime = now } // Create and return a deep copy of the private key so clients are not impacted by a rotation midway. @@ -90,31 +66,15 @@ func (r *RsaKeyManager) GetKeyPair() (*RsaKeyPair, error) { }, nil } -// Force a refresh now. This can be use during test. -func (r *RsaKeyManager) ForceRefresh() { - r.forceRefreshChan <- true -} - -// Stop the refresh of the keypair. -func (r *RsaKeyManager) Stop() { - r.stopChan <- true -} - // Create a new RSAKeyManager that will refresh the keypair in the background. -func NewRsaKeyManager(refreshInterval time.Duration) (*RsaKeyManager, error) { +func NewRsaKeyManager(refreshInterval time.Duration) (*rsaKeyManager, error) { var err error - rsaMgr := &RsaKeyManager{} + rsaMgr := &rsaKeyManager{} rsaMgr.refreshInterval = refreshInterval - rsaMgr.refreshTicker = time.NewTicker(rsaMgr.refreshInterval) rsaMgr.privateKey, err = generatePrivateKey() if err != nil { return nil, err } - rsaMgr.forceRefreshChan = make(chan bool) - rsaMgr.stopChan = make(chan bool) - rsaMgr.keyPairChan = make(chan *rsa.PrivateKey, 2) - - go rsaMgr.refreshPrivateKeyJob() return rsaMgr, nil } diff --git a/pkg/auth/poptoken/shrpoptoken.go b/pkg/auth/poptoken/shrpoptoken.go index ccd11824..8b65b9fb 100644 --- a/pkg/auth/poptoken/shrpoptoken.go +++ b/pkg/auth/poptoken/shrpoptoken.go @@ -24,7 +24,7 @@ type ShrPopHeader struct { // https://datatracker.ietf.org/doc/html/rfc7638#section-3.1 // contains the metadata use to calculate kid. -type JwkInner struct { +type Jwk struct { // Exponent E string `json:"e"` // encryption @@ -33,16 +33,9 @@ type JwkInner struct { N string `json:"n"` } -type Jwk struct { - JwkInner - // public key kid - Kid string `json:"kid"` -} - // https://datatracker.ietf.org/doc/html/rfc7800#section-3.2 type Cnf struct { - Jwk Jwk `json:"jwk"` - Xms_ksl string `json:"xms_ksl"` + Jwk Jwk `json:"jwk"` } type ReqCnf struct { @@ -56,32 +49,28 @@ type ShrPopTokenBody struct { Ts int64 `json:"ts"` // access token At string `json:"at"` - // random unique value to prevent replay attack. not used - NonCe string `json:"nonce"` } // Implements the shr pop token generically. Callers of ths instance can add their own custom claims when generating the token. -type ShrPopToken struct { - Header ShrPopHeader - Body ShrPopTokenBody - ReqCnf ReqCnf - RSAKeyPair *RsaKeyPair +type shrPopToken struct { + Header ShrPopHeader + Body ShrPopTokenBody + refCnfBase64 string + RSAKeyPair *RsaKeyPair } const ( TokenType = "pop" + //TokenType = "JWT" ) -func calculatePublicKeyId(jwkInner *JwkInner) (string, error) { +func calculatePublicKeyId(jwk *Jwk) (string, error) { // - https://tools.ietf.org/html/rfc7638#section-3.1 - jwkByte, err := json.Marshal(jwkInner) + jwkByte, err := json.Marshal(jwk) if err != nil { return "", err } jwk256 := sha256.Sum256(jwkByte) - if err != nil { - return "", err - } return base64.RawURLEncoding.EncodeToString(jwk256[:]), nil } @@ -111,12 +100,12 @@ func exponential2Base64(e int) string { binary.BigEndian.PutUint32(bs, uint32(e)) bs = bs[1:] // drop most significant byte - leaving least-significant 3-bytes - ss := base64.URLEncoding.EncodeToString(bs) + ss := base64.RawURLEncoding.EncodeToString(bs) return ss } // Append custom claims to the existing ShrPopTokenBody. -func (pop *ShrPopToken) appendCustomClaimsToBody(customClaims map[string]interface{}) map[string]interface{} { +func (pop *shrPopToken) appendCustomClaimsToBody(customClaims map[string]interface{}) map[string]interface{} { bodyMap := make(map[string]interface{}) @@ -137,7 +126,7 @@ func (pop *ShrPopToken) appendCustomClaimsToBody(customClaims map[string]interfa } // Complete the poptoken creation by adding the custom claims and signing it. -func (pop *ShrPopToken) GenerateToken(token string, now time.Time, customClaims map[string]interface{}) (string, error) { +func (pop *shrPopToken) GenerateToken(token string, now time.Time, customClaims map[string]interface{}) (string, error) { pop.Body.Ts = now.Truncate(time.Second).Unix() pop.Body.At = token @@ -163,18 +152,14 @@ func (pop *ShrPopToken) GenerateToken(token string, now time.Time, customClaims } // Generate ReqCnf to be passed to Msal -func (pop *ShrPopToken) GetReqCnf() (string, error) { - refCnfb64, err := jsonToBase64(pop.ReqCnf) - if err != nil { - return "", err - } - return refCnfb64, nil +func (pop *shrPopToken) GetReqCnf() string { + return pop.refCnfBase64 } // Create a new instance of ShrPopToken. This generate a partial filled, generic shrpoptoken. The custom claims will be // added later on in GenerateToken() -func NewPopToken(keyPair *RsaKeyPair) (*ShrPopToken, error) { - pop := ShrPopToken{ +func NewPopToken(keyPair *RsaKeyPair) (*shrPopToken, error) { + pop := shrPopToken{ Header: ShrPopHeader{ Alg: keyPair.Alg, Typ: TokenType, @@ -182,25 +167,31 @@ func NewPopToken(keyPair *RsaKeyPair) (*ShrPopToken, error) { Body: ShrPopTokenBody{ Cnf: Cnf{ Jwk: Jwk{ - JwkInner: JwkInner{ - Kty: keyPair.Kty, - E: exponential2Base64(keyPair.PublicKey.E), - N: base64.URLEncoding.EncodeToString([]byte(keyPair.PublicKey.N.String())), - }, + Kty: keyPair.Kty, + N: base64.RawURLEncoding.EncodeToString([]byte(keyPair.PublicKey.N.Bytes())), + E: exponential2Base64(keyPair.PublicKey.E), }, }, }, - ReqCnf: ReqCnf{}, RSAKeyPair: keyPair, } - keyId, err := calculatePublicKeyId(&pop.Body.Cnf.Jwk.JwkInner) + keyId, err := calculatePublicKeyId(&pop.Body.Cnf.Jwk) if err != nil { return nil, err } + pop.Header.Kid = keyId - pop.ReqCnf.Kid = keyId - pop.Body.Cnf.Jwk.Kid = keyId + //pop.Body.Cnf.Jwk.Kid = keyId + + refCnfb64, err := jsonToBase64( + ReqCnf{ + Kid: keyId, + }) + if err != nil { + return nil, err + } + pop.refCnfBase64 = refCnfb64 return &pop, err } diff --git a/pkg/auth/poptoken/shrpoptoken_test.go b/pkg/auth/poptoken/shrpoptoken_test.go index dd4ab84a..2bf019d1 100644 --- a/pkg/auth/poptoken/shrpoptoken_test.go +++ b/pkg/auth/poptoken/shrpoptoken_test.go @@ -17,6 +17,14 @@ type testStruct struct { IntValue int `json:"int"` } +type TestPopTokenBody struct { + ShrPopTokenBody + // target node Id. + NodeId string `json:"nodeid"` + // uri to the grpc object targeted + ObjectPath string `json:"p"` +} + // the pop token is partially filled out upon calling NewPopToken func Test_ShrPopTokenNewPopToken(t *testing.T) { keypair, err := getKeyPair() @@ -26,7 +34,7 @@ func Test_ShrPopTokenNewPopToken(t *testing.T) { assert.Nil(t, err) // calculate kid - expectedKid, err := calculatePublicKeyId(&pop.Body.Cnf.Jwk.JwkInner) + expectedKid, err := calculatePublicKeyId(&pop.Body.Cnf.Jwk) assert.Nil(t, err) // check header @@ -36,20 +44,16 @@ func Test_ShrPopTokenNewPopToken(t *testing.T) { // check body expectedE := exponential2Base64(keypair.PrivateKey.E) - expectedN := base64.URLEncoding.EncodeToString([]byte(keypair.PublicKey.N.String())) + expectedN := base64.RawURLEncoding.EncodeToString([]byte(keypair.PublicKey.N.Bytes())) assert.Equal(t, expectedE, pop.Body.Cnf.Jwk.E) assert.Equal(t, expectedN, pop.Body.Cnf.Jwk.N) assert.Equal(t, keypair.Kty, pop.Body.Cnf.Jwk.Kty) - assert.Equal(t, expectedKid, pop.Body.Cnf.Jwk.Kid) - // check ReqCnf - assert.Equal(t, expectedKid, pop.ReqCnf.Kid) } func Test_ShrPopTokenGenerateToken(t *testing.T) { keypair, err := getKeyPair() assert.Nil(t, err) - pop, err := NewPopToken(keypair) assert.Nil(t, err) @@ -58,10 +62,12 @@ func Test_ShrPopTokenGenerateToken(t *testing.T) { assert.Nil(t, err) expectedResourceIdValue := "1234" - customClaims := map[string]interface{}{"resourceId": expectedResourceIdValue} + expectedObjectPathValue := "myObject" + // note: name of the claims must match the names of the entities in test struct testPopTokenBody + customClaims := map[string]interface{}{"nodeId": expectedResourceIdValue, "p": expectedObjectPathValue} // calculate kid - expectedKid, err := calculatePublicKeyId(&pop.Body.Cnf.Jwk.JwkInner) + expectedKid, err := calculatePublicKeyId(&pop.Body.Cnf.Jwk) assert.Nil(t, err) // Generate the token and validate its content @@ -79,12 +85,12 @@ func Test_ShrPopTokenGenerateToken(t *testing.T) { assert.Equal(t, expectedKid, header.Kid) // validate body - body, err := decodeFromBase64[nodeAgentPopTokenBody](toks[1]) + body, err := decodeFromBase64[TestPopTokenBody](toks[1]) assert.Nil(t, err) assert.Equal(t, expectedTimeStamp.Truncate(time.Second).Unix(), body.Ts) - assert.Equal(t, expectedResourceIdValue, body.ResourceId) + assert.Equal(t, expectedResourceIdValue, body.NodeId) + assert.Equal(t, expectedObjectPathValue, body.ObjectPath) assert.Equal(t, expectedAccessToken, body.At) - assert.Equal(t, expectedKid, body.Cnf.Jwk.Kid) // validate signature. signature, err := base64.RawURLEncoding.DecodeString(toks[2]) @@ -152,10 +158,13 @@ func Test_ShrPopTokenGetReqCnf(t *testing.T) { pop, err := NewPopToken(keypair) assert.Nil(t, err) - expectedReqCnfBase64, err := jsonToBase64(pop.ReqCnf) + expectedReqCnfBase64, err := jsonToBase64( + ReqCnf{ + Kid: pop.Header.Kid, + }) assert.Nil(t, err) - actualreqCnfBase64, err := pop.GetReqCnf() + actualreqCnfBase64 := pop.GetReqCnf() assert.Equal(t, expectedReqCnfBase64, actualreqCnfBase64) } @@ -167,12 +176,12 @@ func Test_ShrPopTokenExponential2Base64(t *testing.T) { } func Test_ShrPopTokenCalculatePublicKeyId(t *testing.T) { - jwkinner := JwkInner{ + jwk := Jwk{ Kty: "RSA", E: "AQAB", N: "MjM1MDg5MDU4MzgxMDg3OTI5NTU3NjM1ODg4NTA3NDE5OTAwNzc0MzkzNzQ5NDcwNzcwMjA2MDIxNjMyNzk5NzYxNDM4NTczMjc3NTA0NzI4ODkzNDUzNjU0NDU0NjMxMjcxNjQ0MTAwMDM0NzUzNzU2MTEyMjkzODYzMDYxMjk5MDQxNzI5OTc0MDg5OTk2OTEzNTY4MjM5OTc0NDMwNTExODI3MDgyNDAzMDQxNDMxMTQ5ODA4ODc4NjE5NTc5MjcwMjAxNjc3ODM1NTQ0NDI3NDMwMDczODI2OTAwODk2MzcxNTM2NzE5NDQyNTUxNzIzNTM5MTg4OTU2MDc4MzI0MzYxNDM4MDEzNjA3OTI0NzMyNTUxMDg5ODU3NjQ1NDA0MTIyMTk3ODUwNjkyMjEyMTk4OTMxMDU1NTkzOTk4NzYyMjIwODg1NDg5NzE4MjQxNDAxMTg2MTMwMzExODAwMDQ2NjEwMjk0MDIzMzQ1MTA1NjE4ODY0ODc0OTgzNzU2NTMzMTY0OTk5MTg1NDk4ODIwOTY3NjYyNjM1NTUxMjk0NTkzNDEwNzc5MzUwODg2MjMxODkyMTc0NTcwODkxNDU4MjIwNzIwMzI5MTg3OTA3NzAxMzMzMDU1NzM0ODk0NjU3MDYzOTMzMzA3MTUwNjgzMTk1NjkyOTk0MzAxMjUxODUwNzUwMTg2MzI5MzM4ODk2NjY3OTQyMDE0OTcwODY3MTAzMTgxNTA5NDAxMTAwMzUwMzk5MDE3MDI3MTI3MTAwMDM5OTIwNjgwNjExNjcxNTQ3MDE1ODM2NzIyMTU1OTgxMTE=", } - keyId, err := calculatePublicKeyId(&jwkinner) + keyId, err := calculatePublicKeyId(&jwk) assert.Nil(t, err) assert.Equal(t, "a0CyVS__Npcx4GXYm1OCoxrlboOWKF02MXzSSh92ckY", keyId) } From 6f12a53a90bd6117d572c78fa4e7c1af326076fe Mon Sep 17 00:00:00 2001 From: wecha Date: Fri, 27 Jun 2025 20:41:39 +0000 Subject: [PATCH 3/7] temp --- .../poptoken/nodeagentpoptokenvalidator.go | 6 ++--- .../nodeagentpoptokenvalidator_test.go | 10 +++---- pkg/auth/poptoken/noncecache.go | 6 ++--- pkg/auth/poptoken/noncecache_test.go | 6 ----- .../poptoken/{shrpoptoken.go => poptoken.go} | 26 +++++++++---------- .../{shrpoptoken_test.go => poptoken_test.go} | 4 +-- pkg/auth/poptoken/poptokenscheme.go | 2 +- pkg/auth/poptoken/poptokenscheme_test.go | 2 +- pkg/auth/poptoken/rsakeymanager.go | 5 ---- ...ymanager_test.go => rsakeymanager_test.go} | 0 10 files changed, 26 insertions(+), 41 deletions(-) rename pkg/auth/poptoken/{shrpoptoken.go => poptoken.go} (86%) rename pkg/auth/poptoken/{shrpoptoken_test.go => poptoken_test.go} (98%) rename pkg/auth/poptoken/{raskeymanager_test.go => rsakeymanager_test.go} (100%) diff --git a/pkg/auth/poptoken/nodeagentpoptokenvalidator.go b/pkg/auth/poptoken/nodeagentpoptokenvalidator.go index fb1936ac..f9ba017e 100644 --- a/pkg/auth/poptoken/nodeagentpoptokenvalidator.go +++ b/pkg/auth/poptoken/nodeagentpoptokenvalidator.go @@ -26,7 +26,7 @@ const ( ) type NodeAgentPopTokenBody struct { - ShrPopTokenBody + PopTokenBody // target node Id. this is expected to be the Arc For Server resource Id. NodeId string `json:"nodeid"` // uri to the grpc object targeted @@ -90,7 +90,7 @@ func isTokenExpire(timestamp int64, now time.Time, clockSkew time.Duration) erro return nil } -func isHeaderValid(header *ShrPopHeader) error { +func isHeaderValid(header *PopTokenHeader) error { if header.Typ != TokenType { return fmt.Errorf("unsupported token type in pop token header; expected %s, got %s", TokenType, header.Typ) } @@ -291,7 +291,7 @@ func (s *shrPopTokenValidator) Validate(popToken string) error { return fmt.Errorf("invalid pop token; expected 3 segments, got %d", len(toks)) } - header, err := decodeFromBase64[ShrPopHeader](toks[0]) + header, err := decodeFromBase64[PopTokenHeader](toks[0]) if err != nil { return err } diff --git a/pkg/auth/poptoken/nodeagentpoptokenvalidator_test.go b/pkg/auth/poptoken/nodeagentpoptokenvalidator_test.go index 048d9dac..91d6e04f 100644 --- a/pkg/auth/poptoken/nodeagentpoptokenvalidator_test.go +++ b/pkg/auth/poptoken/nodeagentpoptokenvalidator_test.go @@ -91,28 +91,28 @@ func Test_NodeAgentPopTokenValidatorIsTokenExpire(t *testing.T) { func Test_NodeAgentPopTokenValidatorIsHeaderValid(t *testing.T) { tests := []struct { name string - header ShrPopHeader + header PopTokenHeader shouldPass bool }{ { name: "valid header", shouldPass: true, - header: ShrPopHeader{Alg: Alg, Typ: TokenType}, + header: PopTokenHeader{Alg: Alg, Typ: TokenType}, }, { name: "invalid alg", shouldPass: false, - header: ShrPopHeader{Alg: "RSA123", Typ: TokenType}, + header: PopTokenHeader{Alg: "RSA123", Typ: TokenType}, }, { name: "invalid typ", shouldPass: false, - header: ShrPopHeader{Alg: Alg, Typ: "jwt"}, + header: PopTokenHeader{Alg: Alg, Typ: "jwt"}, }, { name: "empty header", shouldPass: false, - header: ShrPopHeader{}, + header: PopTokenHeader{}, }, } diff --git a/pkg/auth/poptoken/noncecache.go b/pkg/auth/poptoken/noncecache.go index 307b8cc4..232ba853 100644 --- a/pkg/auth/poptoken/noncecache.go +++ b/pkg/auth/poptoken/noncecache.go @@ -64,11 +64,9 @@ func (n *nonceCache) IsNonceExists(nonceId string, now time.Time) bool { n.mutex.Lock() defer n.mutex.Unlock() - existNonce, ok := n.cache[nonceId] - // entries are evicted lazily, so if entry has expired, return false even though the entry - // is still in the cache. + _, ok := n.cache[nonceId] if ok { - return existNonce.CreatedDateTime.Add(n.nonceValidInterval).After(now) + return ok } nonce := &Nonce{ diff --git a/pkg/auth/poptoken/noncecache_test.go b/pkg/auth/poptoken/noncecache_test.go index 2c2cbfcc..5288e734 100644 --- a/pkg/auth/poptoken/noncecache_test.go +++ b/pkg/auth/poptoken/noncecache_test.go @@ -50,12 +50,6 @@ func Test_NonceCacheIdExistsButExpired(t *testing.T) { // the second time the nonceId should be cached. isexist = noncecache.IsNonceExists(nonceId, now) assert.True(t, isexist) - - // now simulate checking for the nonceId past the valid period and try again - // this time it should return false. - now = now.Add(testNonceValidInterval * 2) - isexist = noncecache.IsNonceExists(nonceId, now) - assert.False(t, isexist) } // Validate that expired Ids will be evicted upon the addition of a new entry. diff --git a/pkg/auth/poptoken/shrpoptoken.go b/pkg/auth/poptoken/poptoken.go similarity index 86% rename from pkg/auth/poptoken/shrpoptoken.go rename to pkg/auth/poptoken/poptoken.go index 8b65b9fb..ea2a59a7 100644 --- a/pkg/auth/poptoken/shrpoptoken.go +++ b/pkg/auth/poptoken/poptoken.go @@ -13,7 +13,7 @@ import ( "time" ) -type ShrPopHeader struct { +type PopTokenHeader struct { // RSA PS256? Alg string `json:"alg"` // key Id of public key @@ -43,7 +43,7 @@ type ReqCnf struct { } // https://datatracker.ietf.org/doc/html/draft-ietf-oauth-signed-http-request-03#section-3 -type ShrPopTokenBody struct { +type PopTokenBody struct { Cnf Cnf `json:"cnf"` // timestamp Ts int64 `json:"ts"` @@ -52,16 +52,15 @@ type ShrPopTokenBody struct { } // Implements the shr pop token generically. Callers of ths instance can add their own custom claims when generating the token. -type shrPopToken struct { - Header ShrPopHeader - Body ShrPopTokenBody +type popToken struct { + Header PopTokenHeader + Body PopTokenBody refCnfBase64 string RSAKeyPair *RsaKeyPair } const ( TokenType = "pop" - //TokenType = "JWT" ) func calculatePublicKeyId(jwk *Jwk) (string, error) { @@ -105,7 +104,7 @@ func exponential2Base64(e int) string { } // Append custom claims to the existing ShrPopTokenBody. -func (pop *shrPopToken) appendCustomClaimsToBody(customClaims map[string]interface{}) map[string]interface{} { +func (pop *popToken) appendCustomClaimsToBody(customClaims map[string]interface{}) map[string]interface{} { bodyMap := make(map[string]interface{}) @@ -126,7 +125,7 @@ func (pop *shrPopToken) appendCustomClaimsToBody(customClaims map[string]interfa } // Complete the poptoken creation by adding the custom claims and signing it. -func (pop *shrPopToken) GenerateToken(token string, now time.Time, customClaims map[string]interface{}) (string, error) { +func (pop *popToken) GenerateToken(token string, now time.Time, customClaims map[string]interface{}) (string, error) { pop.Body.Ts = now.Truncate(time.Second).Unix() pop.Body.At = token @@ -152,19 +151,19 @@ func (pop *shrPopToken) GenerateToken(token string, now time.Time, customClaims } // Generate ReqCnf to be passed to Msal -func (pop *shrPopToken) GetReqCnf() string { +func (pop *popToken) GetReqCnf() string { return pop.refCnfBase64 } // Create a new instance of ShrPopToken. This generate a partial filled, generic shrpoptoken. The custom claims will be // added later on in GenerateToken() -func NewPopToken(keyPair *RsaKeyPair) (*shrPopToken, error) { - pop := shrPopToken{ - Header: ShrPopHeader{ +func NewPopToken(keyPair *RsaKeyPair) (*popToken, error) { + pop := popToken{ + Header: PopTokenHeader{ Alg: keyPair.Alg, Typ: TokenType, }, - Body: ShrPopTokenBody{ + Body: PopTokenBody{ Cnf: Cnf{ Jwk: Jwk{ Kty: keyPair.Kty, @@ -182,7 +181,6 @@ func NewPopToken(keyPair *RsaKeyPair) (*shrPopToken, error) { } pop.Header.Kid = keyId - //pop.Body.Cnf.Jwk.Kid = keyId refCnfb64, err := jsonToBase64( ReqCnf{ diff --git a/pkg/auth/poptoken/shrpoptoken_test.go b/pkg/auth/poptoken/poptoken_test.go similarity index 98% rename from pkg/auth/poptoken/shrpoptoken_test.go rename to pkg/auth/poptoken/poptoken_test.go index 2bf019d1..7152eb49 100644 --- a/pkg/auth/poptoken/shrpoptoken_test.go +++ b/pkg/auth/poptoken/poptoken_test.go @@ -18,7 +18,7 @@ type testStruct struct { } type TestPopTokenBody struct { - ShrPopTokenBody + PopTokenBody // target node Id. NodeId string `json:"nodeid"` // uri to the grpc object targeted @@ -78,7 +78,7 @@ func Test_ShrPopTokenGenerateToken(t *testing.T) { assert.Equal(t, 3, len(toks)) // validate header. - header, err := decodeFromBase64[ShrPopHeader](toks[0]) + header, err := decodeFromBase64[PopTokenHeader](toks[0]) assert.Nil(t, err) assert.Equal(t, Alg, header.Alg) assert.Equal(t, TokenType, header.Typ) diff --git a/pkg/auth/poptoken/poptokenscheme.go b/pkg/auth/poptoken/poptokenscheme.go index 281813de..43f49fe9 100644 --- a/pkg/auth/poptoken/poptokenscheme.go +++ b/pkg/auth/poptoken/poptokenscheme.go @@ -13,7 +13,7 @@ const ( // Implements the interface for MSAL SDK to callback when creating the poptoken. // See AuthenticationScheme interface in https://github.com/AzureAD/microsoft-authentication-library-for-go/blob/main/apps/internal/oauth/ops/authority/authority.go#L146 type PopTokenAuthScheme struct { - shrPopToken *shrPopToken + shrPopToken *popToken claims map[string]interface{} currRsaKeyPair *RsaKeyPair rsaKeyPairRefreshInterval time.Duration diff --git a/pkg/auth/poptoken/poptokenscheme_test.go b/pkg/auth/poptoken/poptokenscheme_test.go index aaee92e0..ee7dba1c 100644 --- a/pkg/auth/poptoken/poptokenscheme_test.go +++ b/pkg/auth/poptoken/poptokenscheme_test.go @@ -9,7 +9,7 @@ import ( ) type TestPopTokenSchemeBody struct { - ShrPopTokenBody + PopTokenBody NodeId string `json:"nodeid"` } diff --git a/pkg/auth/poptoken/rsakeymanager.go b/pkg/auth/poptoken/rsakeymanager.go index 0239c4f6..67fa3ef4 100644 --- a/pkg/auth/poptoken/rsakeymanager.go +++ b/pkg/auth/poptoken/rsakeymanager.go @@ -4,7 +4,6 @@ import ( "crypto/rand" "crypto/rsa" "crypto/x509" - "sync" "time" ) @@ -22,7 +21,6 @@ type rsaKeyManager struct { refreshInterval time.Duration createdDateTime time.Time privateKey *rsa.PrivateKey - mutex sync.Mutex } const ( @@ -38,9 +36,6 @@ func generatePrivateKey() (*rsa.PrivateKey, error) { // Return a KeyPair. The keypair is its own copy and not a reference. func (r *rsaKeyManager) GetKeyPair(now time.Time) (*RsaKeyPair, error) { - r.mutex.Lock() - defer r.mutex.Unlock() - if r.createdDateTime.Add(r.refreshInterval).Before(now) { newPKey, err := generatePrivateKey() if err != nil { diff --git a/pkg/auth/poptoken/raskeymanager_test.go b/pkg/auth/poptoken/rsakeymanager_test.go similarity index 100% rename from pkg/auth/poptoken/raskeymanager_test.go rename to pkg/auth/poptoken/rsakeymanager_test.go From 2dc2fd1a5a2017b43f2df02717ff8d86c713f3cd Mon Sep 17 00:00:00 2001 From: wecha Date: Sat, 28 Jun 2025 01:10:15 +0000 Subject: [PATCH 4/7] merged poptoken and rsakeymgr logic into poptokenscheme --- pkg/auth/poptoken/msalauthprovider.go | 17 +- pkg/auth/poptoken/nodeagentpoptokenscheme.go | 26 +- .../poptoken/nodeagentpoptokenscheme_test.go | 9 +- .../poptoken/nodeagentpoptokenvalidator.go | 4 +- .../nodeagentpoptokenvalidator_test.go | 20 +- pkg/auth/poptoken/noncecache.go | 26 +- pkg/auth/poptoken/noncecache_test.go | 32 +-- pkg/auth/poptoken/poptoken.go | 195 -------------- pkg/auth/poptoken/poptoken_test.go | 220 --------------- pkg/auth/poptoken/poptokenscheme.go | 253 ++++++++++++++++-- pkg/auth/poptoken/poptokenscheme_test.go | 244 ++++++++++++++--- pkg/auth/poptoken/rsakeymanager.go | 75 ------ pkg/auth/poptoken/rsakeymanager_test.go | 49 ---- 13 files changed, 483 insertions(+), 687 deletions(-) delete mode 100644 pkg/auth/poptoken/poptoken.go delete mode 100644 pkg/auth/poptoken/poptoken_test.go delete mode 100644 pkg/auth/poptoken/rsakeymanager.go delete mode 100644 pkg/auth/poptoken/rsakeymanager_test.go diff --git a/pkg/auth/poptoken/msalauthprovider.go b/pkg/auth/poptoken/msalauthprovider.go index 6af538fa..ec8b5428 100644 --- a/pkg/auth/poptoken/msalauthprovider.go +++ b/pkg/auth/poptoken/msalauthprovider.go @@ -3,7 +3,6 @@ package poptoken import ( "context" "os" - "time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" "github.com/pkg/errors" @@ -17,7 +16,6 @@ type MsalAuthProvider struct { authorityUrl string scope []string clientCertPath string - rsaKeyManager *rsaKeyManager } func (m MsalAuthProvider) refreshConfidentialClient() (*confidential.Client, error) { @@ -58,12 +56,7 @@ func (m MsalAuthProvider) GetToken(targetResourceId string, grpcObjectPath strin return "", err } - keyPair, err := m.rsaKeyManager.GetKeyPair(time.Now()) - if err != nil { - return "", errors.Wrapf(err, "failed to get keypair for pop token") - } - - popTokenScheme, err := NewNodeAgentPopTokenAuthScheme(targetResourceId, grpcObjectPath, keyPair) + popTokenScheme, err := NewNodeAgentPopTokenAuthScheme(targetResourceId, grpcObjectPath) if err != nil { return "", errors.Wrapf(err, "failed to create new pop token scheme") } @@ -76,22 +69,16 @@ func (m MsalAuthProvider) GetToken(targetResourceId string, grpcObjectPath strin } func NewMsalClient(clientId string, tenantId, authorityUrl string, clientCertPath string) (*MsalAuthProvider, error) { - rsaKeyManager, err := NewRsaKeyManager(DefaultRefreshInterval) - if err != nil { - return nil, err - } - m := &MsalAuthProvider{ clientId: clientId, tenantId: tenantId, authorityUrl: appendUrl(authorityUrl, tenantId), clientCertPath: clientCertPath, scope: []string{appendUrl(clientId, ".default")}, // intentionally target itself as the pop token custom claim will contain the actual audience. - rsaKeyManager: rsaKeyManager, } // sanity check to ensure client is setup correctly - _, err = m.refreshConfidentialClient() + _, err := m.refreshConfidentialClient() if err != nil { return nil, err } diff --git a/pkg/auth/poptoken/nodeagentpoptokenscheme.go b/pkg/auth/poptoken/nodeagentpoptokenscheme.go index ead6abec..3cbafd22 100644 --- a/pkg/auth/poptoken/nodeagentpoptokenscheme.go +++ b/pkg/auth/poptoken/nodeagentpoptokenscheme.go @@ -8,38 +8,16 @@ type NodeAgentPopTokenAuthScheme struct { *PopTokenAuthScheme } -// Return the claim containg the pop token kid that will be added to the Entra access token. -func (a *NodeAgentPopTokenAuthScheme) TokenRequestParams() map[string]string { - - return a.PopTokenAuthScheme.TokenRequestParams() -} - -// Return the keyId for MSAL to lookup for a cached access token. If it does not exist, MSAL will request a new access token -func (a *NodeAgentPopTokenAuthScheme) KeyID() string { - return a.PopTokenAuthScheme.KeyID() -} - -// Generate the pop token; adding in the accessToken generated by Entra. -func (a *NodeAgentPopTokenAuthScheme) FormatAccessToken(accessToken string) (string, error) { - return a.PopTokenAuthScheme.FormatAccessToken(accessToken) -} - -// Return the token type. Must be "pop" -func (a *NodeAgentPopTokenAuthScheme) AccessTokenType() string { - return a.PopTokenAuthScheme.AccessTokenType() -} - // Create a new instance of NodeAgentPopTokenAuthScheme. // targetResourceId: the ARM resourceId representing the edge node machine. This is the Arc For Server resource Id and is part of the node entity. // grpcObjectId: the uri to the grpc entity, e.g. container. This will be passed in as part of the grpc metadata. -func NewNodeAgentPopTokenAuthScheme(targetNodeId string, grpcObjectId string, rsaKeyPair *RsaKeyPair) (*NodeAgentPopTokenAuthScheme, error) { +func NewNodeAgentPopTokenAuthScheme(targetNodeId string, grpcObjectId string) (*NodeAgentPopTokenAuthScheme, error) { popTokenScheme, err := NewPopTokenAuthScheme( map[string]interface{}{ "nodeid": targetNodeId, "p": grpcObjectId, "nonce": uuid.New().String(), - }, - rsaKeyPair) + }) if err != nil { return nil, err } diff --git a/pkg/auth/poptoken/nodeagentpoptokenscheme_test.go b/pkg/auth/poptoken/nodeagentpoptokenscheme_test.go index 13cadf5a..0cd55d5c 100644 --- a/pkg/auth/poptoken/nodeagentpoptokenscheme_test.go +++ b/pkg/auth/poptoken/nodeagentpoptokenscheme_test.go @@ -3,7 +3,6 @@ package poptoken import ( "strings" "testing" - "time" "github.com/stretchr/testify/assert" ) @@ -14,14 +13,8 @@ func Test_NodeAgentPopTokenScheme(t *testing.T) { expectedNodeId := "mynodeId" expectedGrpcObjectId := "myObjectId" - kmgr, err := NewRsaKeyManager(time.Hour) - assert.Nil(t, err) - - keypair, err := kmgr.GetKeyPair(time.Now()) - assert.Nil(t, err) - // Generate nodeagent scheme - nodeAgentScheme, err := NewNodeAgentPopTokenAuthScheme(expectedNodeId, expectedGrpcObjectId, keypair) + nodeAgentScheme, err := NewNodeAgentPopTokenAuthScheme(expectedNodeId, expectedGrpcObjectId) //For nodeagentpoptokenscheme, we just verify that the custom claims were added to the token. popToken, err := nodeAgentScheme.FormatAccessToken("accessToken") diff --git a/pkg/auth/poptoken/nodeagentpoptokenvalidator.go b/pkg/auth/poptoken/nodeagentpoptokenvalidator.go index f9ba017e..ba845fc5 100644 --- a/pkg/auth/poptoken/nodeagentpoptokenvalidator.go +++ b/pkg/auth/poptoken/nodeagentpoptokenvalidator.go @@ -22,7 +22,7 @@ const ( TokenVersion1 = "1.0" TokenVersion2 = "2.0" PopTokenValidInterval = 5 * time.Minute - PopTokenClockSkew = 1 * time.Minute + PopTokenClockSkew = 5 * time.Minute ) type NodeAgentPopTokenBody struct { @@ -278,7 +278,7 @@ func (s *shrPopTokenValidator) isTokenReused(nonce string, now time.Time) error return fmt.Errorf("claim 'nonce' in pop token is empty/missing, expected a unique id") } - ok := s.nonceCache.IsNonceExists(nonce, now) + ok := s.nonceCache.IsNonceExists(nonce, now, PopTokenValidInterval) if ok { return fmt.Errorf("claim 'nonce' in pop token has been used before, potentially a replay attack") } diff --git a/pkg/auth/poptoken/nodeagentpoptokenvalidator_test.go b/pkg/auth/poptoken/nodeagentpoptokenvalidator_test.go index 91d6e04f..b5299fa6 100644 --- a/pkg/auth/poptoken/nodeagentpoptokenvalidator_test.go +++ b/pkg/auth/poptoken/nodeagentpoptokenvalidator_test.go @@ -152,7 +152,7 @@ func Test_NodeAgentPopTokenValidatorIsTokenReused(t *testing.T) { } func Test_NodeAgentPopTokenValidatorIsSignatureValid(t *testing.T) { - keypair, err := getKeyPair() + keypair, err := generateKeyPair() assert.Nil(t, err) payload := []byte("ThisIsPayload") @@ -182,7 +182,7 @@ func Test_NodeAgentPopTokenValidatorIsSignatureValid(t *testing.T) { assert.NotNil(t, err) // simulate wrong public key, expect failure - newKeyPair, err := getKeyPair() + newKeyPair, err := generateKeyPair() assert.Nil(t, err) misMatachCnf := publicKeyToCnf(newKeyPair) @@ -467,17 +467,17 @@ func Test_NodeAgentPopTokenValidatorValidate(t *testing.T) { jwkMgr := &FakeJwrMgr{PublicKey: &accessTokenPKey.PublicKey} nonceCache := &FakeNonceCache{Exists: false} - rsaKeyPair, err := getKeyPair() - assert.Nil(t, err) + //rsaKeyPair, err := generateKeyPair() + //assert.Nil(t, err) // partial generate pop token, we need to add the popKid into the accesstoken - popToken, err := NewPopToken(rsaKeyPair) + popToken, err := NewNodeAgentPopTokenAuthScheme(nodeId, grpcObjectId) assert.Nil(t, err) // Generate access token claims := AccessTokenCustomClaims{ Tid: tenantId, - ReqCnf: ReqCnf{Kid: popToken.Header.Kid}, + ReqCnf: ReqCnf{Kid: popToken.header.Kid}, Azp: clientId, AppId: clientId, TokenVersion: tokenVersion, @@ -494,7 +494,7 @@ func Test_NodeAgentPopTokenValidatorValidate(t *testing.T) { assert.Nil(t, err) // Generate pop token - pt, err := popToken.GenerateToken(at, time.Now(), map[string]interface{}{"nodeId": nodeId, "p": grpcObjectId, "nonce": "nonceId"}) + pt, err := popToken.FormatAccessToken(at) assert.Nil(t, err) // validate poptoken @@ -558,7 +558,7 @@ type FakeNonceCache struct { Exists bool } -func (n *FakeNonceCache) IsNonceExists(nonceId string, now time.Time) bool { +func (n *FakeNonceCache) IsNonceExists(nonceId string, now time.Time, tokenValidInterval time.Duration) bool { return n.Exists } @@ -566,10 +566,10 @@ func getPrivateKey() (*rsa.PrivateKey, error) { return rsa.GenerateKey(rand.Reader, RsaSize) } -func publicKeyToCnf(keyPair *RsaKeyPair) *Cnf { +func publicKeyToCnf(keyPair *rsaKeyPair) *Cnf { return &Cnf{ Jwk: Jwk{ - Kty: keyPair.Kty, + Kty: Kty, E: exponential2Base64(keyPair.PublicKey.E), N: base64.RawURLEncoding.EncodeToString([]byte(keyPair.PublicKey.N.Bytes())), }, diff --git a/pkg/auth/poptoken/noncecache.go b/pkg/auth/poptoken/noncecache.go index 232ba853..0e75a57e 100644 --- a/pkg/auth/poptoken/noncecache.go +++ b/pkg/auth/poptoken/noncecache.go @@ -6,17 +6,16 @@ import ( ) const ( - DefaultNonceCacheSize = 20 - DefaultNonceValidInterval = time.Minute * 5 + DefaultNonceCacheSize = 20 ) type NonceCacheInterface interface { - IsNonceExists(nonceId string, now time.Time) bool + IsNonceExists(nonceId string, now time.Time, tokenValidInterval time.Duration) bool } type Nonce struct { - Id string - CreatedDateTime time.Time + Id string + ExpireAtDateTime time.Time } // Implement a simple LRU cache that evicts older nonce entries. @@ -47,7 +46,7 @@ func (n *nonceCache) trim(now time.Time) { } nonce := n.queue[0] - isDelete = nonce.CreatedDateTime.Add(n.nonceValidInterval).Before(now) + isDelete = nonce.ExpireAtDateTime.Before(now) if !isDelete { isDelete = n.size >= n.maxSize @@ -60,7 +59,7 @@ func (n *nonceCache) trim(now time.Time) { } } -func (n *nonceCache) IsNonceExists(nonceId string, now time.Time) bool { +func (n *nonceCache) IsNonceExists(nonceId string, now time.Time, tokenValidInterval time.Duration) bool { n.mutex.Lock() defer n.mutex.Unlock() @@ -70,8 +69,8 @@ func (n *nonceCache) IsNonceExists(nonceId string, now time.Time) bool { } nonce := &Nonce{ - Id: nonceId, - CreatedDateTime: now, + Id: nonceId, + ExpireAtDateTime: now.Add(tokenValidInterval), } n.append(nonce) n.trim(now) @@ -84,11 +83,10 @@ func (n *nonceCache) GetCacheSize() int { return n.size } -func NewNonceCache(maxSize int, nonceValidPeriod time.Duration) (*nonceCache, error) { +func NewNonceCache(maxSize int) (*nonceCache, error) { return &nonceCache{ - cache: make(map[string]*Nonce), - nonceValidInterval: nonceValidPeriod, - size: 0, - maxSize: maxSize, + cache: make(map[string]*Nonce), + size: 0, + maxSize: maxSize, }, nil } diff --git a/pkg/auth/poptoken/noncecache_test.go b/pkg/auth/poptoken/noncecache_test.go index 5288e734..afb57a2d 100644 --- a/pkg/auth/poptoken/noncecache_test.go +++ b/pkg/auth/poptoken/noncecache_test.go @@ -15,57 +15,57 @@ const ( ) func Test_NonceCacheIdExists(t *testing.T) { - noncecache, err := NewNonceCache(testNonceCacheSize, testNonceValidInterval) + noncecache, err := NewNonceCache(testNonceCacheSize) assert.Nil(t, err) nonceId := "nonceId_1" now, _ := time.Parse(time.RFC3339, testNonceNowDateTimeStr) // first time seeing this nodeId, returning false - isexist := noncecache.IsNonceExists(nonceId, now) + isexist := noncecache.IsNonceExists(nonceId, now, testNonceValidInterval) assert.False(t, isexist) // the second time the nonceId should be cached. - isexist = noncecache.IsNonceExists(nonceId, now) + isexist = noncecache.IsNonceExists(nonceId, now, testNonceValidInterval) assert.True(t, isexist) // Validate a new entry will return false - isexist = noncecache.IsNonceExists("nonceId_2", now) + isexist = noncecache.IsNonceExists("nonceId_2", now, testNonceValidInterval) assert.False(t, isexist) } // expired entries are lazily evicted, so an invalid entry can remain in the cache // validate that even the expired entry exists, we will still return false. func Test_NonceCacheIdExistsButExpired(t *testing.T) { - noncecache, err := NewNonceCache(testNonceCacheSize, testNonceValidInterval) + noncecache, err := NewNonceCache(testNonceCacheSize) assert.Nil(t, err) nonceId := "nonceId_1" now, _ := time.Parse(time.RFC3339, testNonceNowDateTimeStr) // first time seeing this nodeId, returning false - isexist := noncecache.IsNonceExists(nonceId, now) + isexist := noncecache.IsNonceExists(nonceId, now, testNonceValidInterval) assert.False(t, isexist) // the second time the nonceId should be cached. - isexist = noncecache.IsNonceExists(nonceId, now) + isexist = noncecache.IsNonceExists(nonceId, now, testNonceValidInterval) assert.True(t, isexist) } // Validate that expired Ids will be evicted upon the addition of a new entry. func Test_NonceCacheEvictExpiredIds(t *testing.T) { - noncecache, err := NewNonceCache(testNonceCacheSize, testNonceValidInterval) + noncecache, err := NewNonceCache(testNonceCacheSize) assert.Nil(t, err) now, _ := time.Parse(time.RFC3339, testNonceNowDateTimeStr) for i := 0; i < testNonceCacheSize-1; i++ { id := fmt.Sprintf("%d", i) now = now.Add(time.Second) - noncecache.IsNonceExists(id, now) + noncecache.IsNonceExists(id, now, testNonceValidInterval) //need to call twice to confirm the nonceId were added, since the first time // it is added, it will not exist - isexist := noncecache.IsNonceExists(id, now) + isexist := noncecache.IsNonceExists(id, now, testNonceValidInterval) assert.True(t, isexist) } @@ -73,7 +73,7 @@ func Test_NonceCacheEvictExpiredIds(t *testing.T) { // adding the new entry will trigger an eviction of the expired entries newId := "new" now = now.Add(testNonceValidInterval * 2) - noncecache.IsNonceExists(newId, now) + noncecache.IsNonceExists(newId, now, testNonceValidInterval) // validate older entry has been evicted; size of cache should be 1 // we check the size before checking if the older ids have been evicted as they will get @@ -83,7 +83,7 @@ func Test_NonceCacheEvictExpiredIds(t *testing.T) { // validate the ids no longer exists in cache for i := 0; i < testNonceCacheSize-1; i++ { id := fmt.Sprintf("%d", i) - isexist := noncecache.IsNonceExists(id, now) + isexist := noncecache.IsNonceExists(id, now, testNonceValidInterval) assert.False(t, isexist) } @@ -91,7 +91,7 @@ func Test_NonceCacheEvictExpiredIds(t *testing.T) { // Validate that the oldest Ids will be evicted upon the addition of a new entry that exceeds the cache size. func Test_NonceCacheEvictOverflowIds(t *testing.T) { - noncecache, err := NewNonceCache(testNonceCacheSize, testNonceValidInterval) + noncecache, err := NewNonceCache(testNonceCacheSize) assert.Nil(t, err) idsToAddCount := testNonceCacheSize + 2 @@ -99,11 +99,11 @@ func Test_NonceCacheEvictOverflowIds(t *testing.T) { for i := 0; i < idsToAddCount; i++ { id := fmt.Sprintf("%d", i) now = now.Add(time.Second) - noncecache.IsNonceExists(id, now) + noncecache.IsNonceExists(id, now, testNonceValidInterval) //need to call twice to confirm the nonceId were added, since the first time // it is added, it will not exist - isexist := noncecache.IsNonceExists(id, now) + isexist := noncecache.IsNonceExists(id, now, testNonceValidInterval) assert.True(t, isexist) // validate size of cache does not exceed the max even if more ids were added. @@ -117,7 +117,7 @@ func Test_NonceCacheEvictOverflowIds(t *testing.T) { id := fmt.Sprintf("%d", i) now = now.Add(time.Second) - isexist := noncecache.IsNonceExists(id, now) + isexist := noncecache.IsNonceExists(id, now, testNonceValidInterval) if i >= testNonceCacheSize { assert.True(t, isexist) } else { diff --git a/pkg/auth/poptoken/poptoken.go b/pkg/auth/poptoken/poptoken.go deleted file mode 100644 index ea2a59a7..00000000 --- a/pkg/auth/poptoken/poptoken.go +++ /dev/null @@ -1,195 +0,0 @@ -package poptoken - -import ( - "crypto" - "crypto/rand" - "crypto/rsa" - "crypto/sha256" - "encoding/base64" - "encoding/binary" - "encoding/json" - "reflect" - "strings" - "time" -) - -type PopTokenHeader struct { - // RSA PS256? - Alg string `json:"alg"` - // key Id of public key - Kid string `json:"kid"` - // always pop - Typ string `json:"typ"` -} - -// https://datatracker.ietf.org/doc/html/rfc7638#section-3.1 -// contains the metadata use to calculate kid. -type Jwk struct { - // Exponent - E string `json:"e"` - // encryption - Kty string `json:"kty"` - // modulus - N string `json:"n"` -} - -// https://datatracker.ietf.org/doc/html/rfc7800#section-3.2 -type Cnf struct { - Jwk Jwk `json:"jwk"` -} - -type ReqCnf struct { - Kid string `json:"kid"` -} - -// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-signed-http-request-03#section-3 -type PopTokenBody struct { - Cnf Cnf `json:"cnf"` - // timestamp - Ts int64 `json:"ts"` - // access token - At string `json:"at"` -} - -// Implements the shr pop token generically. Callers of ths instance can add their own custom claims when generating the token. -type popToken struct { - Header PopTokenHeader - Body PopTokenBody - refCnfBase64 string - RSAKeyPair *RsaKeyPair -} - -const ( - TokenType = "pop" -) - -func calculatePublicKeyId(jwk *Jwk) (string, error) { - // - https://tools.ietf.org/html/rfc7638#section-3.1 - jwkByte, err := json.Marshal(jwk) - if err != nil { - return "", err - } - jwk256 := sha256.Sum256(jwkByte) - return base64.RawURLEncoding.EncodeToString(jwk256[:]), nil -} - -func jsonToBase64(v any) (string, error) { - jsonValue, err := json.Marshal(v) - if err != nil { - return "", err - } - return base64.RawURLEncoding.EncodeToString(jsonValue), nil -} - -func signPayload(payload []byte, rsaKey *rsa.PrivateKey) (string, error) { - hash := sha256.New() - _, err := hash.Write(payload) - if err != nil { - return "", err - } - sigBytes, err := rsa.SignPKCS1v15(rand.Reader, rsaKey, crypto.SHA256, hash.Sum(nil)) - if err != nil { - return "", err - } - return base64.RawURLEncoding.EncodeToString(sigBytes), nil -} - -func exponential2Base64(e int) string { - bs := make([]byte, 4) - binary.BigEndian.PutUint32(bs, uint32(e)) - - bs = bs[1:] // drop most significant byte - leaving least-significant 3-bytes - ss := base64.RawURLEncoding.EncodeToString(bs) - return ss -} - -// Append custom claims to the existing ShrPopTokenBody. -func (pop *popToken) appendCustomClaimsToBody(customClaims map[string]interface{}) map[string]interface{} { - - bodyMap := make(map[string]interface{}) - - // first convert the existing body to a map of interface. - val := reflect.ValueOf(pop.Body) - typ := val.Type() - for i := 0; i < val.NumField(); i++ { - if name := strings.ToLower(typ.Field(i).Name); name != "" { - bodyMap[name] = val.Field(i).Interface() - } - } - // now append the custom claims - for k, v := range customClaims { - bodyMap[k] = v - } - - return bodyMap -} - -// Complete the poptoken creation by adding the custom claims and signing it. -func (pop *popToken) GenerateToken(token string, now time.Time, customClaims map[string]interface{}) (string, error) { - - pop.Body.Ts = now.Truncate(time.Second).Unix() - pop.Body.At = token - - body, err := jsonToBase64(pop.appendCustomClaimsToBody(customClaims)) - if err != nil { - return "", err - } - - header, err := jsonToBase64(pop.Header) - if err != nil { - return "", err - } - - signingStr := strings.Join([]string{header, body}, ".") - - signature, err := signPayload([]byte(signingStr), pop.RSAKeyPair.PrivateKey) - if err != nil { - return "", nil - } - - return strings.Join([]string{signingStr, signature}, "."), nil -} - -// Generate ReqCnf to be passed to Msal -func (pop *popToken) GetReqCnf() string { - return pop.refCnfBase64 -} - -// Create a new instance of ShrPopToken. This generate a partial filled, generic shrpoptoken. The custom claims will be -// added later on in GenerateToken() -func NewPopToken(keyPair *RsaKeyPair) (*popToken, error) { - pop := popToken{ - Header: PopTokenHeader{ - Alg: keyPair.Alg, - Typ: TokenType, - }, - Body: PopTokenBody{ - Cnf: Cnf{ - Jwk: Jwk{ - Kty: keyPair.Kty, - N: base64.RawURLEncoding.EncodeToString([]byte(keyPair.PublicKey.N.Bytes())), - E: exponential2Base64(keyPair.PublicKey.E), - }, - }, - }, - RSAKeyPair: keyPair, - } - - keyId, err := calculatePublicKeyId(&pop.Body.Cnf.Jwk) - if err != nil { - return nil, err - } - - pop.Header.Kid = keyId - - refCnfb64, err := jsonToBase64( - ReqCnf{ - Kid: keyId, - }) - if err != nil { - return nil, err - } - pop.refCnfBase64 = refCnfb64 - - return &pop, err -} diff --git a/pkg/auth/poptoken/poptoken_test.go b/pkg/auth/poptoken/poptoken_test.go deleted file mode 100644 index 7152eb49..00000000 --- a/pkg/auth/poptoken/poptoken_test.go +++ /dev/null @@ -1,220 +0,0 @@ -package poptoken - -import ( - "crypto" - "crypto/rsa" - "crypto/sha256" - "encoding/base64" - "strings" - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -type testStruct struct { - StrValue string `json:"str"` - IntValue int `json:"int"` -} - -type TestPopTokenBody struct { - PopTokenBody - // target node Id. - NodeId string `json:"nodeid"` - // uri to the grpc object targeted - ObjectPath string `json:"p"` -} - -// the pop token is partially filled out upon calling NewPopToken -func Test_ShrPopTokenNewPopToken(t *testing.T) { - keypair, err := getKeyPair() - assert.Nil(t, err) - - pop, err := NewPopToken(keypair) - assert.Nil(t, err) - - // calculate kid - expectedKid, err := calculatePublicKeyId(&pop.Body.Cnf.Jwk) - assert.Nil(t, err) - - // check header - assert.Equal(t, Alg, pop.Header.Alg) - assert.Equal(t, TokenType, pop.Header.Typ) - assert.Equal(t, expectedKid, pop.Header.Kid) - - // check body - expectedE := exponential2Base64(keypair.PrivateKey.E) - expectedN := base64.RawURLEncoding.EncodeToString([]byte(keypair.PublicKey.N.Bytes())) - assert.Equal(t, expectedE, pop.Body.Cnf.Jwk.E) - assert.Equal(t, expectedN, pop.Body.Cnf.Jwk.N) - assert.Equal(t, keypair.Kty, pop.Body.Cnf.Jwk.Kty) - -} - -func Test_ShrPopTokenGenerateToken(t *testing.T) { - keypair, err := getKeyPair() - assert.Nil(t, err) - pop, err := NewPopToken(keypair) - assert.Nil(t, err) - - expectedAccessToken := "myFakeAccessToken" - expectedTimeStamp, err := time.Parse(time.RFC3339, "2025-12-01T15:00:00Z") - assert.Nil(t, err) - - expectedResourceIdValue := "1234" - expectedObjectPathValue := "myObject" - // note: name of the claims must match the names of the entities in test struct testPopTokenBody - customClaims := map[string]interface{}{"nodeId": expectedResourceIdValue, "p": expectedObjectPathValue} - - // calculate kid - expectedKid, err := calculatePublicKeyId(&pop.Body.Cnf.Jwk) - assert.Nil(t, err) - - // Generate the token and validate its content - popToken, err := pop.GenerateToken(expectedAccessToken, expectedTimeStamp, customClaims) - assert.Nil(t, err) - - toks := strings.Split(popToken, ".") - assert.Equal(t, 3, len(toks)) - - // validate header. - header, err := decodeFromBase64[PopTokenHeader](toks[0]) - assert.Nil(t, err) - assert.Equal(t, Alg, header.Alg) - assert.Equal(t, TokenType, header.Typ) - assert.Equal(t, expectedKid, header.Kid) - - // validate body - body, err := decodeFromBase64[TestPopTokenBody](toks[1]) - assert.Nil(t, err) - assert.Equal(t, expectedTimeStamp.Truncate(time.Second).Unix(), body.Ts) - assert.Equal(t, expectedResourceIdValue, body.NodeId) - assert.Equal(t, expectedObjectPathValue, body.ObjectPath) - assert.Equal(t, expectedAccessToken, body.At) - - // validate signature. - signature, err := base64.RawURLEncoding.DecodeString(toks[2]) - assert.Nil(t, err) - - signingStr := strings.Join([]string{toks[0], toks[1]}, ".") - err = isSignatureValid(&signingStr, signature, &body.Cnf) - assert.Nil(t, err) -} - -func Test_ShrPopTokenAppendCustomClaims(t *testing.T) { - keypair, err := getKeyPair() - assert.Nil(t, err) - - pop, err := NewPopToken(keypair) - assert.Nil(t, err) - - expectedStringValue := "string" - expectedIntegerValue := 1234 - expectedStrArrValue := []string{"hello", "world"} - expectedStructValue := testStruct{StrValue: "string", IntValue: 1234} - - customClaims := map[string]interface{}{ - "string": expectedStringValue, - "integer": expectedIntegerValue, - "strArray": expectedStrArrValue, - "struct": expectedStructValue, - } - - actualClaims := pop.appendCustomClaimsToBody(customClaims) - - tmp, ok := actualClaims["string"] - assert.True(t, ok) - actualstringValue, ok := tmp.(string) - assert.True(t, ok) - assert.Equal(t, expectedStringValue, actualstringValue) - - tmp, ok = actualClaims["integer"] - assert.True(t, ok) - actualIntegerValue, ok := tmp.(int) - assert.True(t, ok) - assert.Equal(t, expectedIntegerValue, actualIntegerValue) - - tmp, ok = actualClaims["strArray"] - assert.True(t, ok) - actualStrArrValue, ok := tmp.([]string) - assert.True(t, ok) - assert.Equal(t, expectedStrArrValue, actualStrArrValue) - - tmp, ok = actualClaims["struct"] - assert.True(t, ok) - actualStructValue, ok := tmp.(testStruct) - assert.True(t, ok) - assert.Equal(t, expectedStructValue, actualStructValue) - - // finally sanity check that these custom claims can be converted to json - _, err = jsonToBase64(actualClaims) - assert.Nil(t, err) -} - -func Test_ShrPopTokenGetReqCnf(t *testing.T) { - keypair, err := getKeyPair() - assert.Nil(t, err) - - pop, err := NewPopToken(keypair) - assert.Nil(t, err) - - expectedReqCnfBase64, err := jsonToBase64( - ReqCnf{ - Kid: pop.Header.Kid, - }) - assert.Nil(t, err) - - actualreqCnfBase64 := pop.GetReqCnf() - assert.Equal(t, expectedReqCnfBase64, actualreqCnfBase64) -} - -func Test_ShrPopTokenExponential2Base64(t *testing.T) { - e := 65537 - base64 := exponential2Base64(e) - // this is the encoded value of a well known exponential value - assert.Equal(t, "AQAB", base64) -} - -func Test_ShrPopTokenCalculatePublicKeyId(t *testing.T) { - jwk := Jwk{ - Kty: "RSA", - E: "AQAB", - N: "MjM1MDg5MDU4MzgxMDg3OTI5NTU3NjM1ODg4NTA3NDE5OTAwNzc0MzkzNzQ5NDcwNzcwMjA2MDIxNjMyNzk5NzYxNDM4NTczMjc3NTA0NzI4ODkzNDUzNjU0NDU0NjMxMjcxNjQ0MTAwMDM0NzUzNzU2MTEyMjkzODYzMDYxMjk5MDQxNzI5OTc0MDg5OTk2OTEzNTY4MjM5OTc0NDMwNTExODI3MDgyNDAzMDQxNDMxMTQ5ODA4ODc4NjE5NTc5MjcwMjAxNjc3ODM1NTQ0NDI3NDMwMDczODI2OTAwODk2MzcxNTM2NzE5NDQyNTUxNzIzNTM5MTg4OTU2MDc4MzI0MzYxNDM4MDEzNjA3OTI0NzMyNTUxMDg5ODU3NjQ1NDA0MTIyMTk3ODUwNjkyMjEyMTk4OTMxMDU1NTkzOTk4NzYyMjIwODg1NDg5NzE4MjQxNDAxMTg2MTMwMzExODAwMDQ2NjEwMjk0MDIzMzQ1MTA1NjE4ODY0ODc0OTgzNzU2NTMzMTY0OTk5MTg1NDk4ODIwOTY3NjYyNjM1NTUxMjk0NTkzNDEwNzc5MzUwODg2MjMxODkyMTc0NTcwODkxNDU4MjIwNzIwMzI5MTg3OTA3NzAxMzMzMDU1NzM0ODk0NjU3MDYzOTMzMzA3MTUwNjgzMTk1NjkyOTk0MzAxMjUxODUwNzUwMTg2MzI5MzM4ODk2NjY3OTQyMDE0OTcwODY3MTAzMTgxNTA5NDAxMTAwMzUwMzk5MDE3MDI3MTI3MTAwMDM5OTIwNjgwNjExNjcxNTQ3MDE1ODM2NzIyMTU1OTgxMTE=", - } - keyId, err := calculatePublicKeyId(&jwk) - assert.Nil(t, err) - assert.Equal(t, "a0CyVS__Npcx4GXYm1OCoxrlboOWKF02MXzSSh92ckY", keyId) -} - -func Test_ShrPopTokenSignPayload(t *testing.T) { - keypair, err := getKeyPair() - assert.Nil(t, err) - - payload := []byte("ThisIsMyTestPayLoad") - - sig, err := signPayload(payload, keypair.PrivateKey) - assert.Nil(t, err) - - //now verify the signature using the public key - sigDecode, err := base64.RawURLEncoding.DecodeString(sig) - hash := sha256.New() - hash.Write(payload) - err = rsa.VerifyPKCS1v15(keypair.PublicKey, crypto.SHA256, hash.Sum(nil), sigDecode) - assert.Nil(t, err) -} - -func getKeyPair() (*RsaKeyPair, error) { - - pKey, err := generatePrivateKey() - if err != nil { - return nil, err - } - - return &RsaKeyPair{ - PrivateKey: pKey, - PublicKey: pKey.Public().(*rsa.PublicKey), - RsaSize: RsaSize, - Kty: Kty, - Alg: Alg, - }, nil -} diff --git a/pkg/auth/poptoken/poptokenscheme.go b/pkg/auth/poptoken/poptokenscheme.go index 43f49fe9..b7f4b890 100644 --- a/pkg/auth/poptoken/poptokenscheme.go +++ b/pkg/auth/poptoken/poptokenscheme.go @@ -1,59 +1,262 @@ package poptoken import ( - "sync" + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "encoding/base64" + "encoding/binary" + "encoding/json" + "reflect" + "strings" "time" + + "github.com/pkg/errors" ) const ( - tokenType = "token_type" - reqCnf = "req_cnf" + refreshInterval time.Duration = time.Hour * 8 + TokenType = "pop" + DefaultRefreshInterval = time.Hour * 8 + RsaSize = 2048 + Kty = "RSA" + Alg = "RS256" +) + +var ( + globalRsaKey *rsaKeyPair = nil + globalRefreshInterval time.Duration = DefaultRefreshInterval + globalLastRefreshRsaKeyDateTime time.Time = time.Now() ) +type rsaKeyPair struct { + PrivateKey *rsa.PrivateKey + PublicKey *rsa.PublicKey + CreatedDateTime time.Time +} + +type PopTokenHeader struct { + // RSA PS256? + Alg string `json:"alg"` + // key Id of public key + Kid string `json:"kid"` + // always pop + Typ string `json:"typ"` +} + +// https://datatracker.ietf.org/doc/html/rfc7638#section-3.1 +// contains the metadata use to calculate kid. +type Jwk struct { + // Exponent + E string `json:"e"` + // encryption + Kty string `json:"kty"` + // modulus + N string `json:"n"` +} + +// https://datatracker.ietf.org/doc/html/rfc7800#section-3.2 +type Cnf struct { + Jwk Jwk `json:"jwk"` +} + +type ReqCnf struct { + Kid string `json:"kid"` +} + +// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-signed-http-request-03#section-3 +type PopTokenBody struct { + Cnf Cnf `json:"cnf"` + // timestamp + Ts int64 `json:"ts"` + // access token + At string `json:"at"` +} + // Implements the interface for MSAL SDK to callback when creating the poptoken. // See AuthenticationScheme interface in https://github.com/AzureAD/microsoft-authentication-library-for-go/blob/main/apps/internal/oauth/ops/authority/authority.go#L146 type PopTokenAuthScheme struct { - shrPopToken *popToken - claims map[string]interface{} - currRsaKeyPair *RsaKeyPair - rsaKeyPairRefreshInterval time.Duration - rsaKeyPairRefreshMutex sync.Mutex + header PopTokenHeader + body PopTokenBody + reqCnfBase64 string + claims map[string]interface{} + keyPair *rsaKeyPair +} + +func overwriteGlobalRefereshRate(new time.Duration) { + globalRefreshInterval = new +} + +// refresh the global rsa keypair once every 8 hours. +func generateRSAKeyPair(now time.Time) (*rsaKeyPair, error) { + + if globalRsaKey == nil || globalLastRefreshRsaKeyDateTime.Add(globalRefreshInterval).Before(now) { + pKey, err := rsa.GenerateKey(rand.Reader, RsaSize) + if err != nil { + return nil, err + } + globalRsaKey = &rsaKeyPair{ + PrivateKey: pKey, + PublicKey: pKey.Public().(*rsa.PublicKey), + CreatedDateTime: now, + } + globalLastRefreshRsaKeyDateTime = now + } + return globalRsaKey, nil +} + +func calculatePublicKeyId(jwk *Jwk) (string, error) { + // - https://tools.ietf.org/html/rfc7638#section-3.1 + jwkByte, err := json.Marshal(jwk) + if err != nil { + return "", err + } + jwk256 := sha256.Sum256(jwkByte) + return base64.RawURLEncoding.EncodeToString(jwk256[:]), nil +} + +func jsonToBase64(v any) (string, error) { + jsonValue, err := json.Marshal(v) + if err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(jsonValue), nil +} + +func signPayload(payload []byte, rsaKey *rsa.PrivateKey) (string, error) { + hash := sha256.New() + _, err := hash.Write(payload) + if err != nil { + return "", err + } + sigBytes, err := rsa.SignPKCS1v15(rand.Reader, rsaKey, crypto.SHA256, hash.Sum(nil)) + if err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(sigBytes), nil +} + +func exponential2Base64(e int) string { + bs := make([]byte, 4) + binary.BigEndian.PutUint32(bs, uint32(e)) + + bs = bs[1:] // drop most significant byte - leaving least-significant 3-bytes + ss := base64.RawURLEncoding.EncodeToString(bs) + return ss +} + +// Append custom claims to the existing ShrPopTokenBody. +func (p *PopTokenAuthScheme) appendCustomClaimsToBody(customClaims map[string]interface{}) map[string]interface{} { + + bodyMap := make(map[string]interface{}) + + // first convert the existing body to a map of interface. + val := reflect.ValueOf(p.body) + typ := val.Type() + for i := 0; i < val.NumField(); i++ { + if name := strings.ToLower(typ.Field(i).Name); name != "" { + bodyMap[name] = val.Field(i).Interface() + } + } + // now append the custom claims + for k, v := range customClaims { + bodyMap[k] = v + } + + return bodyMap +} + +// Complete the poptoken creation by adding the custom claims and signing it. +func (p *PopTokenAuthScheme) generateToken(token string, now time.Time) (string, error) { + + p.body.Ts = now.Truncate(time.Second).Unix() + p.body.At = token + + body, err := jsonToBase64(p.appendCustomClaimsToBody(p.claims)) + if err != nil { + return "", err + } + + header, err := jsonToBase64(p.header) + if err != nil { + return "", err + } + + signingStr := strings.Join([]string{header, body}, ".") + + signature, err := signPayload([]byte(signingStr), p.keyPair.PrivateKey) + if err != nil { + return "", nil + } + + return strings.Join([]string{signingStr, signature}, "."), nil } // Return the claim containg the pop token kid that will be added to the Entra access token. -func (a *PopTokenAuthScheme) TokenRequestParams() map[string]string { +func (p *PopTokenAuthScheme) TokenRequestParams() map[string]string { return map[string]string{ - tokenType: a.shrPopToken.Header.Typ, - reqCnf: a.shrPopToken.GetReqCnf(), + "token_type": p.header.Typ, + "req_cnf": p.reqCnfBase64, } } // Return the keyId for MSAL to lookup for a cached access token. If it does not exist, MSAL will request a new access token -func (a *PopTokenAuthScheme) KeyID() string { - return a.shrPopToken.Header.Kid +func (p *PopTokenAuthScheme) KeyID() string { + return p.header.Kid } // Generate the pop token; adding in the accessToken generated by Entra. -func (a *PopTokenAuthScheme) FormatAccessToken(accessToken string) (string, error) { - // append accessToken and our custom claims to the pop token. - // Note custom claims should be compatible with JWT specs, we don't expect these claims to be complex - return a.shrPopToken.GenerateToken(accessToken, time.Now(), a.claims) +func (p *PopTokenAuthScheme) FormatAccessToken(accessToken string) (string, error) { + return p.generateToken(accessToken, time.Now()) } // Return the token type. Must be "pop" -func (a *PopTokenAuthScheme) AccessTokenType() string { - return a.shrPopToken.Header.Typ +func (p *PopTokenAuthScheme) AccessTokenType() string { + return p.header.Typ } // Create a new instance of PopTokenAuthScheme. Pass in the custom claims to be set in the pop token here, e.g. resourceId -func NewPopTokenAuthScheme(claims map[string]interface{}, rsaKeyPair *RsaKeyPair) (*PopTokenAuthScheme, error) { - shrPopToken, err := NewPopToken(rsaKeyPair) +func NewPopTokenAuthScheme(claims map[string]interface{}) (*PopTokenAuthScheme, error) { + + keyPair, err := generateRSAKeyPair(time.Now()) if err != nil { return nil, err } - return &PopTokenAuthScheme{ - shrPopToken: shrPopToken, - claims: claims, - }, nil + popTokenScheme := &PopTokenAuthScheme{ + header: PopTokenHeader{ + Alg: Alg, + Typ: TokenType, + }, + body: PopTokenBody{ + Cnf: Cnf{ + Jwk: Jwk{ + Kty: Kty, + N: base64.RawURLEncoding.EncodeToString([]byte(keyPair.PublicKey.N.Bytes())), + E: exponential2Base64(keyPair.PublicKey.E), + }, + }, + }, + keyPair: keyPair, + claims: claims, + } + + keyId, err := calculatePublicKeyId(&popTokenScheme.body.Cnf.Jwk) + if err != nil { + return nil, errors.Wrapf(err, "faild to generate kid") + } + + popTokenScheme.header.Kid = keyId + + reqCnfb64, err := jsonToBase64( + ReqCnf{ + Kid: keyId, + }) + if err != nil { + return nil, errors.Wrapf(err, "faild to generate base64 representation of req_cnf") + } + popTokenScheme.reqCnfBase64 = reqCnfb64 + + return popTokenScheme, nil } diff --git a/pkg/auth/poptoken/poptokenscheme_test.go b/pkg/auth/poptoken/poptokenscheme_test.go index ee7dba1c..77e2a599 100644 --- a/pkg/auth/poptoken/poptokenscheme_test.go +++ b/pkg/auth/poptoken/poptokenscheme_test.go @@ -1,6 +1,11 @@ package poptoken import ( + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "encoding/base64" "strings" "testing" "time" @@ -8,62 +13,233 @@ import ( "github.com/stretchr/testify/assert" ) +const ( + testClaimName = "test" + testValue = "value" +) + +var ( + testClaims = map[string]interface{}{testClaimName: testValue} +) + type TestPopTokenSchemeBody struct { PopTokenBody - NodeId string `json:"nodeid"` + Test string `json:"test"` // must matchtestClaimName } -// This test suite focus on the testing poptokenscheme is returning the expected values that MSAL expected -// the actual token generation is tested in shrpoptoken_test -func Test_PopTokenScheme(t *testing.T) { - expectedNodeId := "mynodeId" +type testStruct struct { + StrValue string + IntValue int +} - kmgr, err := NewRsaKeyManager(time.Hour) - assert.Nil(t, err) +func Test_PopTokenAuthSchemeNew(t *testing.T) { - keypair, err := kmgr.GetKeyPair(time.Now()) + pop, err := NewPopTokenAuthScheme(testClaims) assert.Nil(t, err) - // create a "reference" pop token that we can use to validate some of the nodeagentpoptokenscheme content since it - // should generate the same values - refPopToken, err := NewPopToken(keypair) + // calculate kid + expectedKid, err := calculatePublicKeyId(&pop.body.Cnf.Jwk) assert.Nil(t, err) - // Generate nodeagent scheme - claims := map[string]interface{}{ - "nodeId": expectedNodeId, - } - popTokenScheme, err := NewPopTokenAuthScheme(claims, keypair) + // check header + assert.Equal(t, Alg, pop.header.Alg) + assert.Equal(t, TokenType, pop.header.Typ) + assert.Equal(t, expectedKid, pop.header.Kid) - //validate AccessTokenType returns "pop" - assert.Equal(t, TokenType, popTokenScheme.AccessTokenType()) - assert.Equal(t, refPopToken.Header.Typ, popTokenScheme.AccessTokenType()) + // check body + expectedE := exponential2Base64(pop.keyPair.PrivateKey.E) + expectedN := base64.RawURLEncoding.EncodeToString([]byte(pop.keyPair.PublicKey.N.Bytes())) + assert.Equal(t, expectedE, pop.body.Cnf.Jwk.E) + assert.Equal(t, expectedN, pop.body.Cnf.Jwk.N) - //Validate KeyID - assert.Equal(t, refPopToken.Header.Kid, popTokenScheme.KeyID()) + //check claims + actualValue, ok := pop.claims[testClaimName] + assert.True(t, ok) + assert.Equal(t, testValue, actualValue) +} - // Validate TokenRequestParams returns a specific struct - reqCnf := popTokenScheme.TokenRequestParams() +func Test_PopTokenAuthSchemeGenerateToken(t *testing.T) { + pop, err := NewPopTokenAuthScheme(testClaims) + assert.Nil(t, err) - tokenType, ok := reqCnf["token_type"] - assert.True(t, ok) - assert.Equal(t, refPopToken.Header.Typ, tokenType) + expectedAccessToken := "myFakeAccessToken" + expectedTimeStamp, err := time.Parse(time.RFC3339, "2025-12-01T15:00:00Z") + assert.Nil(t, err) - expectedCnf := refPopToken.GetReqCnf() + expectedKid, err := calculatePublicKeyId(&pop.body.Cnf.Jwk) assert.Nil(t, err) - cnf, ok := reqCnf["req_cnf"] - assert.True(t, ok) - assert.Equal(t, expectedCnf, cnf) - // Validate FormatAccessToken. Here we just check that the custom claim "nodeId" was added. - popToken, err := popTokenScheme.FormatAccessToken("accessToken") + // Generate the token and validate its content + popToken, err := pop.generateToken(expectedAccessToken, expectedTimeStamp) assert.Nil(t, err) - assert.NotEmpty(t, popToken) toks := strings.Split(popToken, ".") assert.Equal(t, 3, len(toks)) + + // validate header. + header, err := decodeFromBase64[PopTokenHeader](toks[0]) + assert.Nil(t, err) + assert.Equal(t, Alg, header.Alg) + assert.Equal(t, TokenType, header.Typ) + assert.Equal(t, expectedKid, header.Kid) + + // validate body body, err := decodeFromBase64[TestPopTokenSchemeBody](toks[1]) assert.Nil(t, err) - assert.Equal(t, expectedNodeId, body.NodeId) + assert.Equal(t, expectedTimeStamp.Truncate(time.Second).Unix(), body.Ts) + assert.Equal(t, testValue, body.Test) + assert.Equal(t, expectedAccessToken, body.At) + + // validate signature. + signature, err := base64.RawURLEncoding.DecodeString(toks[2]) + assert.Nil(t, err) + + signingStr := strings.Join([]string{toks[0], toks[1]}, ".") + err = isSignatureValid(&signingStr, signature, &body.Cnf) + assert.Nil(t, err) +} + +func Test_PopTokenAuthSchemeTokenRequestParams(t *testing.T) { + pop, err := NewPopTokenAuthScheme(testClaims) + assert.Nil(t, err) + + expectedCnfBase64, err := jsonToBase64( + ReqCnf{ + Kid: pop.header.Kid, + }) + assert.Nil(t, err) + + requestParams := pop.TokenRequestParams() + + // we expect these two entries + tokType, ok := requestParams["token_type"] + assert.True(t, ok) + assert.Equal(t, TokenType, tokType) + + actualReqCnf, ok := requestParams["req_cnf"] + assert.True(t, ok) + assert.Equal(t, expectedCnfBase64, actualReqCnf) +} + +func Test_PopTokenAuthSchemeAppendCustomClaims(t *testing.T) { + pop, err := NewPopTokenAuthScheme(testClaims) + assert.Nil(t, err) + + expectedStringValue := "string" + expectedIntegerValue := 1234 + expectedStrArrValue := []string{"hello", "world"} + expectedStructValue := testStruct{StrValue: "string", IntValue: 1234} + + customClaims := map[string]interface{}{ + "string": expectedStringValue, + "integer": expectedIntegerValue, + "strArray": expectedStrArrValue, + "struct": expectedStructValue, + } + + actualClaims := pop.appendCustomClaimsToBody(customClaims) + + tmp, ok := actualClaims["string"] + assert.True(t, ok) + actualstringValue, ok := tmp.(string) + assert.True(t, ok) + assert.Equal(t, expectedStringValue, actualstringValue) + + tmp, ok = actualClaims["integer"] + assert.True(t, ok) + actualIntegerValue, ok := tmp.(int) + assert.True(t, ok) + assert.Equal(t, expectedIntegerValue, actualIntegerValue) + + tmp, ok = actualClaims["strArray"] + assert.True(t, ok) + actualStrArrValue, ok := tmp.([]string) + assert.True(t, ok) + assert.Equal(t, expectedStrArrValue, actualStrArrValue) + + tmp, ok = actualClaims["struct"] + assert.True(t, ok) + actualStructValue, ok := tmp.(testStruct) + assert.True(t, ok) + assert.Equal(t, expectedStructValue, actualStructValue) + + // finally sanity check that these custom claims can be converted to json + _, err = jsonToBase64(actualClaims) + assert.Nil(t, err) +} + +func Test_PopTokenAuthSchemeExponential2Base64(t *testing.T) { + e := 65537 + base64 := exponential2Base64(e) + // this is the encoded value of a well known exponential value + assert.Equal(t, "AQAB", base64) +} + +func Test_PopTokenAuthSchemeCalculatePublicKeyId(t *testing.T) { + jwk := Jwk{ + Kty: "RSA", + E: "AQAB", + N: "MjM1MDg5MDU4MzgxMDg3OTI5NTU3NjM1ODg4NTA3NDE5OTAwNzc0MzkzNzQ5NDcwNzcwMjA2MDIxNjMyNzk5NzYxNDM4NTczMjc3NTA0NzI4ODkzNDUzNjU0NDU0NjMxMjcxNjQ0MTAwMDM0NzUzNzU2MTEyMjkzODYzMDYxMjk5MDQxNzI5OTc0MDg5OTk2OTEzNTY4MjM5OTc0NDMwNTExODI3MDgyNDAzMDQxNDMxMTQ5ODA4ODc4NjE5NTc5MjcwMjAxNjc3ODM1NTQ0NDI3NDMwMDczODI2OTAwODk2MzcxNTM2NzE5NDQyNTUxNzIzNTM5MTg4OTU2MDc4MzI0MzYxNDM4MDEzNjA3OTI0NzMyNTUxMDg5ODU3NjQ1NDA0MTIyMTk3ODUwNjkyMjEyMTk4OTMxMDU1NTkzOTk4NzYyMjIwODg1NDg5NzE4MjQxNDAxMTg2MTMwMzExODAwMDQ2NjEwMjk0MDIzMzQ1MTA1NjE4ODY0ODc0OTgzNzU2NTMzMTY0OTk5MTg1NDk4ODIwOTY3NjYyNjM1NTUxMjk0NTkzNDEwNzc5MzUwODg2MjMxODkyMTc0NTcwODkxNDU4MjIwNzIwMzI5MTg3OTA3NzAxMzMzMDU1NzM0ODk0NjU3MDYzOTMzMzA3MTUwNjgzMTk1NjkyOTk0MzAxMjUxODUwNzUwMTg2MzI5MzM4ODk2NjY3OTQyMDE0OTcwODY3MTAzMTgxNTA5NDAxMTAwMzUwMzk5MDE3MDI3MTI3MTAwMDM5OTIwNjgwNjExNjcxNTQ3MDE1ODM2NzIyMTU1OTgxMTE=", + } + keyId, err := calculatePublicKeyId(&jwk) + assert.Nil(t, err) + assert.Equal(t, "a0CyVS__Npcx4GXYm1OCoxrlboOWKF02MXzSSh92ckY", keyId) +} +func Test_PopTokenAuthSchemeSignPayload(t *testing.T) { + keypair, err := generateKeyPair() + assert.Nil(t, err) + + payload := []byte("ThisIsMyTestPayLoad") + + sig, err := signPayload(payload, keypair.PrivateKey) + assert.Nil(t, err) + + //now verify the signature using the public key + sigDecode, err := base64.RawURLEncoding.DecodeString(sig) + hash := sha256.New() + hash.Write(payload) + err = rsa.VerifyPKCS1v15(keypair.PublicKey, crypto.SHA256, hash.Sum(nil), sigDecode) + assert.Nil(t, err) +} + +func Test_PopTokenAuthSchemeGenerateRsaKeyPair(t *testing.T) { + + now, err := time.Parse(time.RFC3339, "2025-12-01T15:00:00Z") + + keyPair1, err := generateRSAKeyPair(now) + assert.Nil(t, err) + assert.NotNil(t, keyPair1) + + // now simulate getting a second keypair < refreshInterval; keyPair2 should be the same + keyPair2, err := generateRSAKeyPair(now) + assert.Nil(t, err) + assert.NotNil(t, keyPair2) + assert.Equal(t, keyPair1.PrivateKey.N, keyPair2.PrivateKey.N) + + // now simulate getting a third keypair > refreshInterval; keyPair3 should be different from 1 and 2. + newNow := now.Add(globalRefreshInterval).Add(time.Minute) + keyPair3, err := generateRSAKeyPair(newNow) + assert.Nil(t, err) + assert.NotNil(t, keyPair3) + assert.NotEqual(t, keyPair1.PrivateKey.N, keyPair3.PrivateKey.N) + assert.NotEqual(t, keyPair2.PrivateKey.N, keyPair3.PrivateKey.N) + + // now try again; keypair4 == keypair3 + keyPair4, err := generateRSAKeyPair(newNow) + assert.Nil(t, err) + assert.NotNil(t, keyPair4) + assert.Equal(t, keyPair3.PrivateKey.N, keyPair4.PrivateKey.N) +} + +func generateKeyPair() (*rsaKeyPair, error) { + pKey, err := rsa.GenerateKey(rand.Reader, RsaSize) + if err != nil { + return nil, err + } + return &rsaKeyPair{ + PrivateKey: pKey, + PublicKey: pKey.Public().(*rsa.PublicKey), + CreatedDateTime: time.Now(), + }, nil } diff --git a/pkg/auth/poptoken/rsakeymanager.go b/pkg/auth/poptoken/rsakeymanager.go deleted file mode 100644 index 67fa3ef4..00000000 --- a/pkg/auth/poptoken/rsakeymanager.go +++ /dev/null @@ -1,75 +0,0 @@ -package poptoken - -import ( - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "time" -) - -type RsaKeyPair struct { - PrivateKey *rsa.PrivateKey - PublicKey *rsa.PublicKey - RsaSize int - Kty string - Alg string -} - -// a RSA Key generator that refresh the RSA KeyPair at regular interval -// Used to ensure the keys use to sign the poptoken are rotated -type rsaKeyManager struct { - refreshInterval time.Duration - createdDateTime time.Time - privateKey *rsa.PrivateKey -} - -const ( - DefaultRefreshInterval = time.Hour * 8 - RsaSize = 2048 - Kty = "RSA" - Alg = "RS256" -) - -func generatePrivateKey() (*rsa.PrivateKey, error) { - return rsa.GenerateKey(rand.Reader, RsaSize) -} - -// Return a KeyPair. The keypair is its own copy and not a reference. -func (r *rsaKeyManager) GetKeyPair(now time.Time) (*RsaKeyPair, error) { - if r.createdDateTime.Add(r.refreshInterval).Before(now) { - newPKey, err := generatePrivateKey() - if err != nil { - return nil, err - } - r.privateKey = newPKey - r.createdDateTime = now - } - - // Create and return a deep copy of the private key so clients are not impacted by a rotation midway. - privateKeyBytes := x509.MarshalPKCS1PrivateKey(r.privateKey) - privateKeyCopy, err := x509.ParsePKCS1PrivateKey(privateKeyBytes) - if err != nil { - return nil, err - } - - return &RsaKeyPair{ - PrivateKey: privateKeyCopy, - PublicKey: privateKeyCopy.Public().(*rsa.PublicKey), - RsaSize: RsaSize, - Kty: Kty, - Alg: Alg, - }, nil -} - -// Create a new RSAKeyManager that will refresh the keypair in the background. -func NewRsaKeyManager(refreshInterval time.Duration) (*rsaKeyManager, error) { - var err error - rsaMgr := &rsaKeyManager{} - - rsaMgr.refreshInterval = refreshInterval - rsaMgr.privateKey, err = generatePrivateKey() - if err != nil { - return nil, err - } - return rsaMgr, nil -} diff --git a/pkg/auth/poptoken/rsakeymanager_test.go b/pkg/auth/poptoken/rsakeymanager_test.go deleted file mode 100644 index e6b38925..00000000 --- a/pkg/auth/poptoken/rsakeymanager_test.go +++ /dev/null @@ -1,49 +0,0 @@ -package poptoken - -import ( - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -const ( - testRsaValidInterval = time.Hour * 1 - testRsaNowDateTimeStr = "2025-12-01T15:00:00Z" //time.Parse(time.RFC3339, "2025-12-01T15:00:00Z") -) - -func Test_RsaKeyManagerGetKeyPair(t *testing.T) { - now, _ := time.Parse(time.RFC3339, testNonceNowDateTimeStr) - rsamgr, err := NewRsaKeyManager(testRsaValidInterval) - assert.Nil(t, err) - - rsa, err := rsamgr.GetKeyPair(now) - assert.Nil(t, err) - - assert.Equal(t, Alg, rsa.Alg) - assert.Equal(t, Kty, rsa.Kty) - assert.Equal(t, RsaSize, rsa.RsaSize) - assert.NotNil(t, rsa.PrivateKey) - assert.NotNil(t, rsa.PublicKey) - - // now get the keypair a second time, if it has not refreshed, it will be the same value - rsa2, err := rsamgr.GetKeyPair(now) - //validate private key are equal value - assert.Equal(t, *rsa.PrivateKey.N, *rsa2.PrivateKey.N) -} - -func Test_RsaKeyManagerGetKeyPairRotated(t *testing.T) { - now, _ := time.Parse(time.RFC3339, testNonceNowDateTimeStr) - rsamgr, err := NewRsaKeyManager(testRsaValidInterval) - assert.Nil(t, err) - - rsa, err := rsamgr.GetKeyPair(now) - assert.Nil(t, err) - - // now get the keypair a second time past the refresh interval, a new key should be generated. - rsa2, err := rsamgr.GetKeyPair(now.Add(testRsaValidInterval * 2)) - assert.Nil(t, err) - - //validate the two keys are now different. - assert.NotEqual(t, *rsa.PrivateKey.N, *rsa2.PrivateKey.N) -} From 9621437627843ec6699dbca72e5a4b3f847cdb4c Mon Sep 17 00:00:00 2001 From: wecha Date: Sat, 28 Jun 2025 01:15:25 +0000 Subject: [PATCH 5/7] minor cleanup --- pkg/auth/poptoken/noncecache_test.go | 30 +++++++++++++--------------- pkg/auth/poptoken/poptokenscheme.go | 4 ---- 2 files changed, 14 insertions(+), 20 deletions(-) diff --git a/pkg/auth/poptoken/noncecache_test.go b/pkg/auth/poptoken/noncecache_test.go index afb57a2d..ba174ab1 100644 --- a/pkg/auth/poptoken/noncecache_test.go +++ b/pkg/auth/poptoken/noncecache_test.go @@ -10,7 +10,7 @@ import ( const ( testNonceCacheSize = 3 - testNonceValidInterval = time.Minute * 1 + testTokenValidInterval = time.Minute * 1 testNonceNowDateTimeStr = "2025-12-01T15:00:00Z" ) @@ -22,20 +22,18 @@ func Test_NonceCacheIdExists(t *testing.T) { now, _ := time.Parse(time.RFC3339, testNonceNowDateTimeStr) // first time seeing this nodeId, returning false - isexist := noncecache.IsNonceExists(nonceId, now, testNonceValidInterval) + isexist := noncecache.IsNonceExists(nonceId, now, testTokenValidInterval) assert.False(t, isexist) // the second time the nonceId should be cached. - isexist = noncecache.IsNonceExists(nonceId, now, testNonceValidInterval) + isexist = noncecache.IsNonceExists(nonceId, now, testTokenValidInterval) assert.True(t, isexist) // Validate a new entry will return false - isexist = noncecache.IsNonceExists("nonceId_2", now, testNonceValidInterval) + isexist = noncecache.IsNonceExists("nonceId_2", now, testTokenValidInterval) assert.False(t, isexist) } -// expired entries are lazily evicted, so an invalid entry can remain in the cache -// validate that even the expired entry exists, we will still return false. func Test_NonceCacheIdExistsButExpired(t *testing.T) { noncecache, err := NewNonceCache(testNonceCacheSize) assert.Nil(t, err) @@ -44,11 +42,11 @@ func Test_NonceCacheIdExistsButExpired(t *testing.T) { now, _ := time.Parse(time.RFC3339, testNonceNowDateTimeStr) // first time seeing this nodeId, returning false - isexist := noncecache.IsNonceExists(nonceId, now, testNonceValidInterval) + isexist := noncecache.IsNonceExists(nonceId, now, testTokenValidInterval) assert.False(t, isexist) // the second time the nonceId should be cached. - isexist = noncecache.IsNonceExists(nonceId, now, testNonceValidInterval) + isexist = noncecache.IsNonceExists(nonceId, now, testTokenValidInterval) assert.True(t, isexist) } @@ -61,19 +59,19 @@ func Test_NonceCacheEvictExpiredIds(t *testing.T) { for i := 0; i < testNonceCacheSize-1; i++ { id := fmt.Sprintf("%d", i) now = now.Add(time.Second) - noncecache.IsNonceExists(id, now, testNonceValidInterval) + noncecache.IsNonceExists(id, now, testTokenValidInterval) //need to call twice to confirm the nonceId were added, since the first time // it is added, it will not exist - isexist := noncecache.IsNonceExists(id, now, testNonceValidInterval) + isexist := noncecache.IsNonceExists(id, now, testTokenValidInterval) assert.True(t, isexist) } // simulate querying a new nonce Id after time where the previously added ids expired. // adding the new entry will trigger an eviction of the expired entries newId := "new" - now = now.Add(testNonceValidInterval * 2) - noncecache.IsNonceExists(newId, now, testNonceValidInterval) + now = now.Add(testTokenValidInterval * 2) + noncecache.IsNonceExists(newId, now, testTokenValidInterval) // validate older entry has been evicted; size of cache should be 1 // we check the size before checking if the older ids have been evicted as they will get @@ -83,7 +81,7 @@ func Test_NonceCacheEvictExpiredIds(t *testing.T) { // validate the ids no longer exists in cache for i := 0; i < testNonceCacheSize-1; i++ { id := fmt.Sprintf("%d", i) - isexist := noncecache.IsNonceExists(id, now, testNonceValidInterval) + isexist := noncecache.IsNonceExists(id, now, testTokenValidInterval) assert.False(t, isexist) } @@ -99,11 +97,11 @@ func Test_NonceCacheEvictOverflowIds(t *testing.T) { for i := 0; i < idsToAddCount; i++ { id := fmt.Sprintf("%d", i) now = now.Add(time.Second) - noncecache.IsNonceExists(id, now, testNonceValidInterval) + noncecache.IsNonceExists(id, now, testTokenValidInterval) //need to call twice to confirm the nonceId were added, since the first time // it is added, it will not exist - isexist := noncecache.IsNonceExists(id, now, testNonceValidInterval) + isexist := noncecache.IsNonceExists(id, now, testTokenValidInterval) assert.True(t, isexist) // validate size of cache does not exceed the max even if more ids were added. @@ -117,7 +115,7 @@ func Test_NonceCacheEvictOverflowIds(t *testing.T) { id := fmt.Sprintf("%d", i) now = now.Add(time.Second) - isexist := noncecache.IsNonceExists(id, now, testNonceValidInterval) + isexist := noncecache.IsNonceExists(id, now, testTokenValidInterval) if i >= testNonceCacheSize { assert.True(t, isexist) } else { diff --git a/pkg/auth/poptoken/poptokenscheme.go b/pkg/auth/poptoken/poptokenscheme.go index b7f4b890..db51f39f 100644 --- a/pkg/auth/poptoken/poptokenscheme.go +++ b/pkg/auth/poptoken/poptokenscheme.go @@ -84,10 +84,6 @@ type PopTokenAuthScheme struct { keyPair *rsaKeyPair } -func overwriteGlobalRefereshRate(new time.Duration) { - globalRefreshInterval = new -} - // refresh the global rsa keypair once every 8 hours. func generateRSAKeyPair(now time.Time) (*rsaKeyPair, error) { From 2189affdb2624605a29bc4177058e759f528f6db Mon Sep 17 00:00:00 2001 From: wecha Date: Sat, 28 Jun 2025 01:23:11 +0000 Subject: [PATCH 6/7] cleanup2 --- pkg/auth/poptoken/poptokenauth.go | 4 ++-- pkg/auth/poptoken/poptokenscheme.go | 10 ++++------ pkg/auth/poptoken/poptokenscheme_test.go | 5 ++--- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/pkg/auth/poptoken/poptokenauth.go b/pkg/auth/poptoken/poptokenauth.go index 151c6ecd..5f354f10 100644 --- a/pkg/auth/poptoken/poptokenauth.go +++ b/pkg/auth/poptoken/poptokenauth.go @@ -12,9 +12,9 @@ The setup of the pop token creaton is as follows: | --> MsalAuthProvider (global component that request the token from Entra/AzureAAD via MSAL SDK) | - --> NodeAgentPopTokenAuthScheme (implements callback MSAL requires to generate the pop token) + --> NodeAgentPopTokenAuthScheme (wrapper that adds the claims specific to node agent) | - --> ShrPopToken (does most of the heavy lifing in generating the pop token) + --> PopTokenAuthScheme(a more generic pop tokens implementation) */ // This component integrates the MSAL provider to the grpc credentials.PerRPCCredentials interface diff --git a/pkg/auth/poptoken/poptokenscheme.go b/pkg/auth/poptoken/poptokenscheme.go index db51f39f..aef61a61 100644 --- a/pkg/auth/poptoken/poptokenscheme.go +++ b/pkg/auth/poptoken/poptokenscheme.go @@ -31,9 +31,8 @@ var ( ) type rsaKeyPair struct { - PrivateKey *rsa.PrivateKey - PublicKey *rsa.PublicKey - CreatedDateTime time.Time + PrivateKey *rsa.PrivateKey + PublicKey *rsa.PublicKey } type PopTokenHeader struct { @@ -93,9 +92,8 @@ func generateRSAKeyPair(now time.Time) (*rsaKeyPair, error) { return nil, err } globalRsaKey = &rsaKeyPair{ - PrivateKey: pKey, - PublicKey: pKey.Public().(*rsa.PublicKey), - CreatedDateTime: now, + PrivateKey: pKey, + PublicKey: pKey.Public().(*rsa.PublicKey), } globalLastRefreshRsaKeyDateTime = now } diff --git a/pkg/auth/poptoken/poptokenscheme_test.go b/pkg/auth/poptoken/poptokenscheme_test.go index 77e2a599..5ece9635 100644 --- a/pkg/auth/poptoken/poptokenscheme_test.go +++ b/pkg/auth/poptoken/poptokenscheme_test.go @@ -238,8 +238,7 @@ func generateKeyPair() (*rsaKeyPair, error) { return nil, err } return &rsaKeyPair{ - PrivateKey: pKey, - PublicKey: pKey.Public().(*rsa.PublicKey), - CreatedDateTime: time.Now(), + PrivateKey: pKey, + PublicKey: pKey.Public().(*rsa.PublicKey), }, nil } From 6192ae2af6594235ac472a1c5339a93f18e3831c Mon Sep 17 00:00:00 2001 From: wecha Date: Tue, 8 Jul 2025 17:09:15 +0000 Subject: [PATCH 7/7] add kid validation --- go.mod | 18 ++++----- go.sum | 38 +++++-------------- .../poptoken/nodeagentpoptokenvalidator.go | 16 ++++++++ .../nodeagentpoptokenvalidator_test.go | 36 ++++++++++++++++++ 4 files changed, 71 insertions(+), 37 deletions(-) diff --git a/go.mod b/go.mod index 7fd6fa2f..4fa21179 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,7 @@ require ( github.com/pkg/errors v0.9.1 github.com/stretchr/testify v1.10.0 go.uber.org/multierr v1.11.0 - google.golang.org/grpc v1.72.0 + google.golang.org/grpc v1.64.0 google.golang.org/grpc/security/advancedtls v1.0.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -24,8 +24,8 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 // indirect - github.com/go-jose/go-jose/v4 v4.0.4 // indirect github.com/goccy/go-json v0.10.3 // indirect + github.com/kr/pretty v0.1.0 // indirect github.com/kr/text v0.2.0 // indirect github.com/kylelemons/godebug v1.1.0 // indirect github.com/lestrrat-go/backoff/v2 v2.0.8 // indirect @@ -34,14 +34,14 @@ require ( github.com/lestrrat-go/iter v1.0.2 // indirect github.com/lestrrat-go/option v1.0.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/spiffe/go-spiffe/v2 v2.5.0 // indirect - github.com/zeebo/errs v1.4.0 // indirect - golang.org/x/crypto v0.37.0 // indirect - golang.org/x/net v0.39.0 // indirect - golang.org/x/sys v0.32.0 // indirect - golang.org/x/text v0.24.0 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20250428153025-10db94c68c34 // indirect + golang.org/x/crypto v0.39.0 // indirect + golang.org/x/net v0.41.0 // indirect + golang.org/x/sys v0.33.0 // indirect + golang.org/x/text v0.26.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 // indirect + google.golang.org/grpc/examples v0.0.0-20230224211313-3775f633ce20 // indirect google.golang.org/protobuf v1.36.6 // indirect + gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect ) diff --git a/go.sum b/go.sum index ca7bc80c..e203d11a 100644 --- a/go.sum +++ b/go.sum @@ -6,12 +6,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 h1:NMZiJj8QnKe1LgsbDayM4UoHwbvwDRwnI3hwNaAHRnc= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0/go.mod h1:ZXNYxsqcloTdSy/rNShjYzMhyjf0LaoftYK0p+A3h40= -github.com/go-jose/go-jose/v4 v4.0.4 h1:VsjPI33J0SB9vQM6PLmNjoHqMQNGPiZ0rHL7Ni7Q6/E= -github.com/go-jose/go-jose/v4 v4.0.4/go.mod h1:NKb5HO1EZccyMpiZNbdUw/14tiXNyUJh188dfnMCAfc= github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= -github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= -github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI= @@ -34,6 +30,8 @@ github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGw github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= @@ -55,28 +53,12 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/spiffe/go-spiffe/v2 v2.5.0 h1:N2I01KCUkv1FAjZXJMwh95KK1ZIQLYbPfhaxw8WS0hE= -github.com/spiffe/go-spiffe/v2 v2.5.0/go.mod h1:P+NxobPc6wXhVtINNtFjNWGBTreew1GBUCwT2wPmb7g= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -github.com/zeebo/errs v1.4.0 h1:XNdoD/RRMKP7HD0UhJnIzUy74ISdGGxURlYG8HSWSfM= -github.com/zeebo/errs v1.4.0/go.mod h1:sgbWHsvVuTPHcqJJGQ1WhI5KbWlHYz+2+2C/LSEtCw4= -go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= -go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= -go.opentelemetry.io/otel v1.34.0 h1:zRLXxLCgL1WyKsPVrgbSdMN4c0FMkDAskSTQP+0hdUY= -go.opentelemetry.io/otel v1.34.0/go.mod h1:OWFPOQ+h4G8xpyjgqo4SxJYdDQ/qmRH+wivy7zzx9oI= -go.opentelemetry.io/otel/metric v1.34.0 h1:+eTR3U0MyfWjRDhmFMxe2SsW64QrZ84AOhvqS7Y+PoQ= -go.opentelemetry.io/otel/metric v1.34.0/go.mod h1:CEDrp0fy2D0MvkXE+dPV7cMi8tWZwX3dmaIhwPOaqHE= -go.opentelemetry.io/otel/sdk v1.34.0 h1:95zS4k/2GOy069d321O8jWgYsW3MzVV+KuSPKp7Wr1A= -go.opentelemetry.io/otel/sdk v1.34.0/go.mod h1:0e/pNiaMAqaykJGKbi+tSjWfNNHMTxoC9qANsCzbyxU= -go.opentelemetry.io/otel/sdk/metric v1.34.0 h1:5CeK9ujjbFVL5c1PhLuStg1wxA7vQv7ce1EK0Gyvahk= -go.opentelemetry.io/otel/sdk/metric v1.34.0/go.mod h1:jQ/r8Ze28zRKoNRdkjCZxfs6YvBTG1+YIqyFVFYec5w= -go.opentelemetry.io/otel/trace v1.34.0 h1:+ouXS2V8Rd4hp4580a8q23bg0azF2nI8cqLYnC8mh/k= -go.opentelemetry.io/otel/trace v1.34.0/go.mod h1:Svm7lSjQD7kG7KJ/MUHPVXSDGz2OX4h0M2jHBhmSfRE= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= @@ -94,8 +76,8 @@ golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= -golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= -golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= +golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw= +golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -122,8 +104,8 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= -golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= -golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= +golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M= +golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= @@ -134,10 +116,10 @@ golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxb golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/genproto/googleapis/rpc v0.0.0-20250428153025-10db94c68c34 h1:h6p3mQqrmT1XkHVTfzLdNz1u7IhINeZkz67/xTbOuWs= -google.golang.org/genproto/googleapis/rpc v0.0.0-20250428153025-10db94c68c34/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= -google.golang.org/grpc v1.72.0 h1:S7UkcVa60b5AAQTaO6ZKamFp1zMZSU0fGDK2WZLbBnM= -google.golang.org/grpc v1.72.0/go.mod h1:wH5Aktxcg25y1I3w7H69nHfXdOG3UiadoBtjh3izSDM= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 h1:fc6jSaCT0vBduLYZHYrBBNY4dsWuvgyff9noRNDdBeE= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= +google.golang.org/grpc v1.64.0 h1:KH3VH9y/MgNQg1dE7b3XfVK0GsPSIzJwdF617gUSbvY= +google.golang.org/grpc v1.64.0/go.mod h1:oxjF8E3FBnjp+/gVFYdWacaLDx9na1aqy9oovLpxQYg= google.golang.org/grpc/examples v0.0.0-20230224211313-3775f633ce20 h1:MLBCGN1O7GzIx+cBiwfYPwtmZ41U3Mn/cotLJciaArI= google.golang.org/grpc/examples v0.0.0-20230224211313-3775f633ce20/go.mod h1:Nr5H8+MlGWr5+xX/STzdoEqJrO+YteqFbMyCsrb6mH0= google.golang.org/grpc/security/advancedtls v1.0.0 h1:/KQ7VP/1bs53/aopk9QhuPyFAp9Dm9Ejix3lzYkCrDA= diff --git a/pkg/auth/poptoken/nodeagentpoptokenvalidator.go b/pkg/auth/poptoken/nodeagentpoptokenvalidator.go index ba845fc5..bf98697f 100644 --- a/pkg/auth/poptoken/nodeagentpoptokenvalidator.go +++ b/pkg/auth/poptoken/nodeagentpoptokenvalidator.go @@ -89,6 +89,17 @@ func isTokenExpire(timestamp int64, now time.Time, clockSkew time.Duration) erro } return nil } +func isKidValid(cnf *Cnf, actualKid string) error { + expectedKid, err := calculatePublicKeyId(&cnf.Jwk) + if err != nil { + return errors.Wrapf(err, "failed to generate kid from pop token cnf") + } + if expectedKid != actualKid { + return fmt.Errorf("pop token header kid %s does not match kid %s as generated from 'cnf.jwk'", actualKid, expectedKid) + } + + return nil +} func isHeaderValid(header *PopTokenHeader) error { if header.Typ != TokenType { @@ -106,6 +117,7 @@ func isSignatureValid(signingStr *string, signature []byte, cnf *Cnf) error { if err != nil { return err } + return verifyPayload(signingStr, []byte(signature), publicKey) } @@ -313,6 +325,10 @@ func (s *shrPopTokenValidator) Validate(popToken string) error { return err } + if err := isKidValid(&body.Cnf, header.Kid); err != nil { + return err + } + if err := s.isCustomClaimsValid(&body); err != nil { return err } diff --git a/pkg/auth/poptoken/nodeagentpoptokenvalidator_test.go b/pkg/auth/poptoken/nodeagentpoptokenvalidator_test.go index b5299fa6..b4d7359f 100644 --- a/pkg/auth/poptoken/nodeagentpoptokenvalidator_test.go +++ b/pkg/auth/poptoken/nodeagentpoptokenvalidator_test.go @@ -41,6 +41,42 @@ func Test_NodeAgentPopTokenValidatorAppendUrl(t *testing.T) { } } +func Test_NodeAgentPopTokenValidatorIsKidValid(t *testing.T) { + + newKeyPair, err := generateKeyPair() + assert.Nil(t, err) + cnf := publicKeyToCnf(newKeyPair) + expectedKid, err := calculatePublicKeyId(&cnf.Jwk) + assert.Nil(t, err) + + tests := []struct { + name string + expectedKid string + shouldPass bool + }{ + { + name: "valid kid", + expectedKid: expectedKid, + shouldPass: true, + }, + { + name: "invalid kid", + expectedKid: "randomKid", + shouldPass: false, + }} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := isKidValid(cnf, tt.expectedKid) + if tt.shouldPass { + assert.Nil(t, err) + } else { + assert.NotNil(t, err) + } + }) + } +} + func Test_NodeAgentPopTokenValidatorIsTokenExpire(t *testing.T) { tokenIssuedAt, err := time.Parse(time.RFC3339, "2025-12-01T15:00:00Z") assert.Nil(t, err)