diff --git a/go.mod b/go.mod index b3d4e591..4fa21179 100644 --- a/go.mod +++ b/go.mod @@ -3,33 +3,46 @@ module github.com/microsoft/moc go 1.23.0 require ( + github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2 github.com/go-logr/logr v1.4.2 github.com/golang-jwt/jwt/v4 v4.5.2 + github.com/golang-jwt/jwt/v5 v5.0.0 github.com/golang/mock v1.6.0 github.com/golang/protobuf v1.5.4 + github.com/google/uuid v1.6.0 github.com/hectane/go-acl v0.0.0-20230122075934-ca0b05cb1adb github.com/jmespath/go-jmespath v0.4.0 + github.com/lestrrat-go/jwx v1.2.31 github.com/pkg/errors v0.9.1 github.com/stretchr/testify v1.10.0 go.uber.org/multierr v1.11.0 - google.golang.org/grpc v1.72.0 + google.golang.org/grpc v1.64.0 google.golang.org/grpc/security/advancedtls v1.0.0 gopkg.in/yaml.v3 v3.0.1 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/go-jose/go-jose/v4 v4.0.4 // indirect + github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 // indirect + github.com/goccy/go-json v0.10.3 // indirect + github.com/kr/pretty v0.1.0 // indirect github.com/kr/text v0.2.0 // indirect + github.com/kylelemons/godebug v1.1.0 // indirect + github.com/lestrrat-go/backoff/v2 v2.0.8 // indirect + github.com/lestrrat-go/blackmagic v1.0.2 // indirect + github.com/lestrrat-go/httpcc v1.0.1 // indirect + github.com/lestrrat-go/iter v1.0.2 // indirect + github.com/lestrrat-go/option v1.0.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/spiffe/go-spiffe/v2 v2.5.0 // indirect - github.com/zeebo/errs v1.4.0 // indirect - golang.org/x/crypto v0.37.0 // indirect - golang.org/x/net v0.39.0 // indirect - golang.org/x/sys v0.32.0 // indirect - golang.org/x/text v0.24.0 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20250428153025-10db94c68c34 // indirect + golang.org/x/crypto v0.39.0 // indirect + golang.org/x/net v0.41.0 // indirect + golang.org/x/sys v0.33.0 // indirect + golang.org/x/text v0.26.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 // indirect + google.golang.org/grpc/examples v0.0.0-20230224211313-3775f633ce20 // indirect google.golang.org/protobuf v1.36.6 // indirect + gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect ) replace ( diff --git a/go.sum b/go.sum index 21decf8e..e203d11a 100644 --- a/go.sum +++ b/go.sum @@ -1,15 +1,19 @@ +github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2 h1:oygO0locgZJe7PpYPXT5A29ZkwJaPqcva7BVeemZOZs= +github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/go-jose/go-jose/v4 v4.0.4 h1:VsjPI33J0SB9vQM6PLmNjoHqMQNGPiZ0rHL7Ni7Q6/E= -github.com/go-jose/go-jose/v4 v4.0.4/go.mod h1:NKb5HO1EZccyMpiZNbdUw/14tiXNyUJh188dfnMCAfc= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 h1:NMZiJj8QnKe1LgsbDayM4UoHwbvwDRwnI3hwNaAHRnc= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0/go.mod h1:ZXNYxsqcloTdSy/rNShjYzMhyjf0LaoftYK0p+A3h40= github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= -github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= -github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= +github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI= github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE= +github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= @@ -26,33 +30,35 @@ github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGw github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/lestrrat-go/backoff/v2 v2.0.8 h1:oNb5E5isby2kiro9AgdHLv5N5tint1AnDVVf2E2un5A= +github.com/lestrrat-go/backoff/v2 v2.0.8/go.mod h1:rHP/q/r9aT27n24JQLa7JhSQZCKBBOiM/uP402WwN8Y= +github.com/lestrrat-go/blackmagic v1.0.2 h1:Cg2gVSc9h7sz9NOByczrbUvLopQmXrfFx//N+AkAr5k= +github.com/lestrrat-go/blackmagic v1.0.2/go.mod h1:UrEqBzIR2U6CnzVyUtfM6oZNMt/7O7Vohk2J0OGSAtU= +github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE= +github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E= +github.com/lestrrat-go/iter v1.0.2 h1:gMXo1q4c2pHmC3dn8LzRhJfP1ceCbgSiT9lUydIzltI= +github.com/lestrrat-go/iter v1.0.2/go.mod h1:Momfcq3AnRlRjI5b5O8/G5/BvpzrhoFTZcn06fEOPt4= +github.com/lestrrat-go/jwx v1.2.31 h1:/OM9oNl/fzyldpv5HKZ9m7bTywa7COUfg8gujd9nJ54= +github.com/lestrrat-go/jwx v1.2.31/go.mod h1:eQJKoRwWcLg4PfD5CFA5gIZGxhPgoPYq9pZISdxLf0c= +github.com/lestrrat-go/option v1.0.0/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I= +github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU= +github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/spiffe/go-spiffe/v2 v2.5.0 h1:N2I01KCUkv1FAjZXJMwh95KK1ZIQLYbPfhaxw8WS0hE= -github.com/spiffe/go-spiffe/v2 v2.5.0/go.mod h1:P+NxobPc6wXhVtINNtFjNWGBTreew1GBUCwT2wPmb7g= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -github.com/zeebo/errs v1.4.0 h1:XNdoD/RRMKP7HD0UhJnIzUy74ISdGGxURlYG8HSWSfM= -github.com/zeebo/errs v1.4.0/go.mod h1:sgbWHsvVuTPHcqJJGQ1WhI5KbWlHYz+2+2C/LSEtCw4= -go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= -go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= -go.opentelemetry.io/otel v1.34.0 h1:zRLXxLCgL1WyKsPVrgbSdMN4c0FMkDAskSTQP+0hdUY= -go.opentelemetry.io/otel v1.34.0/go.mod h1:OWFPOQ+h4G8xpyjgqo4SxJYdDQ/qmRH+wivy7zzx9oI= -go.opentelemetry.io/otel/metric v1.34.0 h1:+eTR3U0MyfWjRDhmFMxe2SsW64QrZ84AOhvqS7Y+PoQ= -go.opentelemetry.io/otel/metric v1.34.0/go.mod h1:CEDrp0fy2D0MvkXE+dPV7cMi8tWZwX3dmaIhwPOaqHE= -go.opentelemetry.io/otel/sdk v1.34.0 h1:95zS4k/2GOy069d321O8jWgYsW3MzVV+KuSPKp7Wr1A= -go.opentelemetry.io/otel/sdk v1.34.0/go.mod h1:0e/pNiaMAqaykJGKbi+tSjWfNNHMTxoC9qANsCzbyxU= -go.opentelemetry.io/otel/sdk/metric v1.34.0 h1:5CeK9ujjbFVL5c1PhLuStg1wxA7vQv7ce1EK0Gyvahk= -go.opentelemetry.io/otel/sdk/metric v1.34.0/go.mod h1:jQ/r8Ze28zRKoNRdkjCZxfs6YvBTG1+YIqyFVFYec5w= -go.opentelemetry.io/otel/trace v1.34.0 h1:+ouXS2V8Rd4hp4580a8q23bg0azF2nI8cqLYnC8mh/k= -go.opentelemetry.io/otel/trace v1.34.0/go.mod h1:Svm7lSjQD7kG7KJ/MUHPVXSDGz2OX4h0M2jHBhmSfRE= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= @@ -70,8 +76,8 @@ golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= -golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= -golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= +golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw= +golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -98,8 +104,8 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= -golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= -golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= +golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M= +golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= @@ -110,10 +116,10 @@ golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxb golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/genproto/googleapis/rpc v0.0.0-20250428153025-10db94c68c34 h1:h6p3mQqrmT1XkHVTfzLdNz1u7IhINeZkz67/xTbOuWs= -google.golang.org/genproto/googleapis/rpc v0.0.0-20250428153025-10db94c68c34/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= -google.golang.org/grpc v1.72.0 h1:S7UkcVa60b5AAQTaO6ZKamFp1zMZSU0fGDK2WZLbBnM= -google.golang.org/grpc v1.72.0/go.mod h1:wH5Aktxcg25y1I3w7H69nHfXdOG3UiadoBtjh3izSDM= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 h1:fc6jSaCT0vBduLYZHYrBBNY4dsWuvgyff9noRNDdBeE= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= +google.golang.org/grpc v1.64.0 h1:KH3VH9y/MgNQg1dE7b3XfVK0GsPSIzJwdF617gUSbvY= +google.golang.org/grpc v1.64.0/go.mod h1:oxjF8E3FBnjp+/gVFYdWacaLDx9na1aqy9oovLpxQYg= google.golang.org/grpc/examples v0.0.0-20230224211313-3775f633ce20 h1:MLBCGN1O7GzIx+cBiwfYPwtmZ41U3Mn/cotLJciaArI= google.golang.org/grpc/examples v0.0.0-20230224211313-3775f633ce20/go.mod h1:Nr5H8+MlGWr5+xX/STzdoEqJrO+YteqFbMyCsrb6mH0= google.golang.org/grpc/security/advancedtls v1.0.0 h1:/KQ7VP/1bs53/aopk9QhuPyFAp9Dm9Ejix3lzYkCrDA= diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 38543f18..679f1df9 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -14,6 +14,7 @@ import ( "io/ioutil" "strings" + "github.com/microsoft/moc/pkg/auth/poptoken" "github.com/microsoft/moc/pkg/config" "github.com/microsoft/moc/pkg/marshal" "github.com/microsoft/moc/rpc/common" @@ -261,6 +262,20 @@ func NewPopTokenAuthorizer() (*CmpPopTokenAuthorizer, error) { }, nil } +// TODO wecha: This will the new constructor for the poptoken authorizer and replace existing +// NewPopTokenAuthorizer(). We can't switch to the new ctor until we update wssdcloud agent. +func NewPopTokenAuthorizerPopAuth(popTokenAuth *poptoken.PopTokenAuth) (*CmpPopTokenAuthorizer, error) { + tp, err := DisableTransportAuthorization() + if err != nil { + return nil, err + } + + return &CmpPopTokenAuthorizer{ + transportProvider: tp, + rpcProvider: popTokenAuth, + }, nil +} + func (c *CmpPopTokenAuthorizer) WithTransportAuthorization() credentials.TransportCredentials { return c.transportProvider } diff --git a/pkg/auth/faketokenprovider.go b/pkg/auth/faketokenprovider.go index 0ee79f15..51bf0b29 100644 --- a/pkg/auth/faketokenprovider.go +++ b/pkg/auth/faketokenprovider.go @@ -7,6 +7,7 @@ import ( type fakeTokenProvider struct { } +// TODO wecha: temp fake token auth provider. To be replaced with PopTokenAuth func NewFakeTokenProvier() (*fakeTokenProvider, error) { return &fakeTokenProvider{}, nil } diff --git a/pkg/auth/poptoken/jwkmanager.go b/pkg/auth/poptoken/jwkmanager.go new file mode 100644 index 00000000..04f913b8 --- /dev/null +++ b/pkg/auth/poptoken/jwkmanager.go @@ -0,0 +1,60 @@ +package poptoken + +import ( + "context" + "crypto/rsa" + "fmt" + "time" + + "github.com/lestrrat-go/jwx/jwk" + "github.com/pkg/errors" +) + +const ( + IssuerPostfix = "common/discovery/keys" + RefreshJwkInterval = time.Hour * 24 +) + +// Wrapper around jwk library to retrieve and refresh the jwk endpoints from Entra/AAD +type jwkManager struct { + // STS JWK endpoint, e.g. "https://login.microsoftonline.com/common/discovery/keys" + jwkEndpoint string + ar *jwk.AutoRefresh +} + +type JwkInterface interface { + GetPublicKey(kid string) (*rsa.PublicKey, error) +} + +func (j *jwkManager) GetPublicKey(kid string) (*rsa.PublicKey, error) { + ctx := context.Background() + keys, err := j.ar.Fetch(ctx, j.jwkEndpoint) + if err != nil { + return nil, errors.Wrapf(err, "failed to look up jwk endpoint %s to retrieve keys", j.jwkEndpoint) + } + + key, ok := keys.LookupKeyID(kid) + if !ok { + return nil, fmt.Errorf("failed to find kid %s in jwk endpoint %s", kid, j.jwkEndpoint) + } + + var pKey rsa.PublicKey + if err := key.Raw(&pKey); err != nil { + return nil, err + } + + return &pKey, nil +} + +func NewJwkManager(authorityUrl string, refreshInterval time.Duration) (*jwkManager, error) { + jwkEndpoint := appendUrl(authorityUrl, IssuerPostfix) + ctx, _ := context.WithCancel(context.Background()) + ar := jwk.NewAutoRefresh(ctx) + ar.Configure(jwkEndpoint, jwk.WithMinRefreshInterval(refreshInterval)) + _, err := ar.Refresh(ctx, jwkEndpoint) + if err != nil { + return nil, fmt.Errorf("failed to refresh the jwk endpoint %s", jwkEndpoint) + } + + return &jwkManager{jwkEndpoint: jwkEndpoint, ar: ar}, nil +} diff --git a/pkg/auth/poptoken/msalauthprovider.go b/pkg/auth/poptoken/msalauthprovider.go new file mode 100644 index 00000000..ec8b5428 --- /dev/null +++ b/pkg/auth/poptoken/msalauthprovider.go @@ -0,0 +1,87 @@ +package poptoken + +import ( + "context" + "os" + + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" + "github.com/pkg/errors" +) + +// Msal client to generate the pop token. Note that the msal sdk does not provide pop token support +// out of the box, refer to NodeAgentPopTokenScheme. +type MsalAuthProvider struct { + clientId string + tenantId string + authorityUrl string + scope []string + clientCertPath string +} + +func (m MsalAuthProvider) refreshConfidentialClient() (*confidential.Client, error) { + pemData, err := os.ReadFile(m.clientCertPath) + if err != nil { + return nil, err + } + + cert, privateKey, err := confidential.CertFromPEM(pemData, "") + if err != nil { + return nil, err + } + + // the PEM file can contain multiple intermediate certificates to validate TLS cert chaining but we are only interested in our + // own certificate which is always the first one. See https://www.rfc-editor.org/rfc/rfc5246#section-7.4.2 + cert = cert[:1] + cred, err := confidential.NewCredFromCert(cert, privateKey) + if err != nil { + return nil, errors.Wrapf(err, "failed to create confidential credential from certificate") + } + + // use Subject NAame and Issuer (SN+I) authentication to request for token. For this to work, Withx5c() must be set + // to pass the certificate chain in the request header. + confidentialClient, err := confidential.New(m.authorityUrl, m.clientId, cred, confidential.WithX5C()) + if err != nil { + return nil, errors.Wrapf(err, "failed to create confidential client") + } + + return &confidentialClient, nil +} + +func (m MsalAuthProvider) GetToken(targetResourceId string, grpcObjectPath string) (string, error) { + + // TODO: the underlying client certificate will be refreshed, hence we need to also pick up the new certificate + // Longer run we can cache the client but for now we will refresh the client for every token call. + confidentialClient, err := m.refreshConfidentialClient() + if err != nil { + return "", err + } + + popTokenScheme, err := NewNodeAgentPopTokenAuthScheme(targetResourceId, grpcObjectPath) + if err != nil { + return "", errors.Wrapf(err, "failed to create new pop token scheme") + } + + result, err := confidentialClient.AcquireTokenByCredential(context.Background(), m.scope, confidential.WithAuthenticationScheme(popTokenScheme)) + if err != nil { + return "", errors.Wrapf(err, "failed to get token") + } + return result.AccessToken, nil +} + +func NewMsalClient(clientId string, tenantId, authorityUrl string, clientCertPath string) (*MsalAuthProvider, error) { + m := &MsalAuthProvider{ + clientId: clientId, + tenantId: tenantId, + authorityUrl: appendUrl(authorityUrl, tenantId), + clientCertPath: clientCertPath, + scope: []string{appendUrl(clientId, ".default")}, // intentionally target itself as the pop token custom claim will contain the actual audience. + } + + // sanity check to ensure client is setup correctly + _, err := m.refreshConfidentialClient() + if err != nil { + return nil, err + } + + return m, nil +} diff --git a/pkg/auth/poptoken/nodeagentpoptokenscheme.go b/pkg/auth/poptoken/nodeagentpoptokenscheme.go new file mode 100644 index 00000000..3cbafd22 --- /dev/null +++ b/pkg/auth/poptoken/nodeagentpoptokenscheme.go @@ -0,0 +1,28 @@ +package poptoken + +import "github.com/google/uuid" + +// Implements the interface for MSAL SDK to callback when creating the poptoken. +// See AuthenticationScheme interface in https://github.com/AzureAD/microsoft-authentication-library-for-go/blob/main/apps/internal/oauth/ops/authority/authority.go#L146 +type NodeAgentPopTokenAuthScheme struct { + *PopTokenAuthScheme +} + +// Create a new instance of NodeAgentPopTokenAuthScheme. +// targetResourceId: the ARM resourceId representing the edge node machine. This is the Arc For Server resource Id and is part of the node entity. +// grpcObjectId: the uri to the grpc entity, e.g. container. This will be passed in as part of the grpc metadata. +func NewNodeAgentPopTokenAuthScheme(targetNodeId string, grpcObjectId string) (*NodeAgentPopTokenAuthScheme, error) { + popTokenScheme, err := NewPopTokenAuthScheme( + map[string]interface{}{ + "nodeid": targetNodeId, + "p": grpcObjectId, + "nonce": uuid.New().String(), + }) + if err != nil { + return nil, err + } + + return &NodeAgentPopTokenAuthScheme{ + PopTokenAuthScheme: popTokenScheme, + }, nil +} diff --git a/pkg/auth/poptoken/nodeagentpoptokenscheme_test.go b/pkg/auth/poptoken/nodeagentpoptokenscheme_test.go new file mode 100644 index 00000000..0cd55d5c --- /dev/null +++ b/pkg/auth/poptoken/nodeagentpoptokenscheme_test.go @@ -0,0 +1,34 @@ +package poptoken + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +// This test suite focus on the testing the custom claims of nodeagentpoptokenscheme is returned +// poptokenscheme_test the underlying poptokenscheme_test +func Test_NodeAgentPopTokenScheme(t *testing.T) { + expectedNodeId := "mynodeId" + expectedGrpcObjectId := "myObjectId" + + // Generate nodeagent scheme + nodeAgentScheme, err := NewNodeAgentPopTokenAuthScheme(expectedNodeId, expectedGrpcObjectId) + + //For nodeagentpoptokenscheme, we just verify that the custom claims were added to the token. + popToken, err := nodeAgentScheme.FormatAccessToken("accessToken") + assert.Nil(t, err) + assert.NotEmpty(t, popToken) + + toks := strings.Split(popToken, ".") + assert.Equal(t, 3, len(toks)) + + body, err := decodeFromBase64[NodeAgentPopTokenBody](toks[1]) + + assert.Nil(t, err) + assert.Equal(t, expectedNodeId, body.NodeId) + assert.Equal(t, expectedGrpcObjectId, body.GrpcObjectId) + assert.NotEmpty(t, body.Nonce) + +} diff --git a/pkg/auth/poptoken/nodeagentpoptokenvalidator.go b/pkg/auth/poptoken/nodeagentpoptokenvalidator.go new file mode 100644 index 00000000..bf98697f --- /dev/null +++ b/pkg/auth/poptoken/nodeagentpoptokenvalidator.go @@ -0,0 +1,371 @@ +package poptoken + +import ( + "bytes" + "crypto" + "crypto/rsa" + "crypto/sha256" + "encoding/base64" + "encoding/binary" + "encoding/json" + "fmt" + "math/big" + "strconv" + "strings" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/pkg/errors" +) + +const ( + TokenVersion1 = "1.0" + TokenVersion2 = "2.0" + PopTokenValidInterval = 5 * time.Minute + PopTokenClockSkew = 5 * time.Minute +) + +type NodeAgentPopTokenBody struct { + PopTokenBody + // target node Id. this is expected to be the Arc For Server resource Id. + NodeId string `json:"nodeid"` + // uri to the grpc object targeted + GrpcObjectId string `json:"p"` + // unique id bound to the token to prevent replay attaches. + Nonce string `json:"nonce"` +} + +// contains a subset of custom claims in Entra/AzureAD access tokens we want to validate. +// See https://learn.microsoft.com/en-us/entra/identity-platform/access-token-claims-reference#payload-claims +type AccessTokenCustomClaims struct { + // contains the public key kid use to sign the pop token. Verify this matches the kid in the poptoken body. + ReqCnf ReqCnf `json:"cnf"` + // requester Id, i.e. CMP 1P. Expected to match ShrPopTokenValidator.ClientId. Only valid for token v2 + Azp string `json:"azp"` + // requester Id, i.e. CMP 1P/identifuerUri. Expected to match ShrPopTokenValidator.ClientId. Only valid for token v1 + AppId string `json:"appid"` + // Tenant Id. Expected to match ShrPopTokenValidator.TenantId + Tid string `json:"tid"` + // token version. + TokenVersion string `json:"ver"` + jwt.RegisteredClaims +} + +type shrPopTokenValidator struct { + // id of node which the pop token is bounded to. This is expected to be the Arc For Server (A4S) agent resourceId + NodeId string + // pathurl of grpc entity that the pop token is bounded to, e.g. the uri to the storage container entity. + GrpcObjectId string + // Tenant Id of CMP 1P + TenantId string + // The target Id. In this case, client Id or one of its identifierUri, depending on the token version. + Audience map[string]bool + // CMP 1P client Id + ClientId string + // Issuer url, e.g. https://login.microsoftonline.com/ + IssuerUrl string + // component use to query Entra JWK endpoint. + jwk JwkInterface + // component use to cache and check request's nonceCache to prevent replay attack + nonceCache NonceCacheInterface +} + +// handle situation where url may or may not have a backslash +func appendUrl(url string, postfix string) string { + sep := "/" + if strings.HasSuffix(url, "/") { + sep = "" + } + return fmt.Sprintf("%s%s%s", url, sep, postfix) +} + +func isTokenExpire(timestamp int64, now time.Time, clockSkew time.Duration) error { + var issuedTime time.Time + convertTime(timestamp, &issuedTime) + expireat := issuedTime.Add(PopTokenValidInterval) + skewedTime := now.Add(-clockSkew) + if expireat.Before(skewedTime) { + return fmt.Errorf("pop token has expired. currentTimestamp: %v, issuedAt: %v, validDuration: %v", now, issuedTime, PopTokenValidInterval) + } + return nil +} +func isKidValid(cnf *Cnf, actualKid string) error { + expectedKid, err := calculatePublicKeyId(&cnf.Jwk) + if err != nil { + return errors.Wrapf(err, "failed to generate kid from pop token cnf") + } + if expectedKid != actualKid { + return fmt.Errorf("pop token header kid %s does not match kid %s as generated from 'cnf.jwk'", actualKid, expectedKid) + } + + return nil +} + +func isHeaderValid(header *PopTokenHeader) error { + if header.Typ != TokenType { + return fmt.Errorf("unsupported token type in pop token header; expected %s, got %s", TokenType, header.Typ) + } + if header.Alg != Alg { + return fmt.Errorf("unsupported alg in pop token header, expected %s, got %s", Alg, header.Alg) + } + + return nil +} + +func isSignatureValid(signingStr *string, signature []byte, cnf *Cnf) error { + publicKey, err := publicRSA256KeyFromCnf(cnf) + if err != nil { + return err + } + + return verifyPayload(signingStr, []byte(signature), publicKey) +} + +func verifyPayload(signingStr *string, sig []byte, pubKey *rsa.PublicKey) error { + hash := sha256.New() + hash.Write([]byte(*signingStr)) + err := rsa.VerifyPKCS1v15(pubKey, crypto.SHA256, hash.Sum(nil), sig) + if err != nil { + return errors.Wrapf(err, "failed to validate signature of pop token") + } + return nil +} + +func publicRSA256KeyFromCnf(cnf *Cnf) (*rsa.PublicKey, error) { + modulus, err := base64.RawURLEncoding.DecodeString(cnf.Jwk.N) + if err != nil { + err := errors.Wrapf(err, "error while parsing pop token cnf: failed to decode modulus") + return nil, err + } + n := new(big.Int).SetBytes(modulus) + + e, err := base64ToExponential(string(cnf.Jwk.E)) + if err != nil { + err := errors.Wrapf(err, "error while parsing pop token cnf: failed to parse exponent") + return nil, err + } + pKey := rsa.PublicKey{N: n, E: int(e)} + return &pKey, nil +} + +func base64ToExponential(encodedE string) (int, error) { + decE, err := base64.RawURLEncoding.DecodeString(encodedE) + if err != nil { + return 0, err + } + + var eBytes []byte + if len(decE) < 8 { + eBytes = make([]byte, 8-len(decE), 8) + eBytes = append(eBytes, decE...) + } else { + eBytes = decE + } + eReader := bytes.NewReader(eBytes) + var ee uint64 + err = binary.Read(eReader, binary.BigEndian, &ee) + if err != nil { + return 0, err + } + + return int(ee), nil +} + +func decodeFromBase64[T any](jsonData string) (T, error) { + var t T + var err error + + bytes, err := base64.RawURLEncoding.DecodeString(jsonData) + if err != nil { + return t, err + } + + err = json.Unmarshal(bytes, &t) + if err != nil { + return t, err + } + + return t, nil +} + +func convertTime(i any, tm *time.Time) { + switch iat := i.(type) { + case float64: + *tm = time.Unix(int64(iat), 0) + case int64: + *tm = time.Unix(iat, 0) + case string: + v, _ := strconv.ParseInt(iat, 10, 64) + *tm = time.Unix(v, 0) + } +} + +func (s *shrPopTokenValidator) parseAndValidateAccessToken(tokenStr string, popTokenKid string) error { + // ParseWithClaims() will validate the token expiry date and signing. + at, err := jwt.ParseWithClaims(tokenStr, &AccessTokenCustomClaims{}, func(token *jwt.Token) (interface{}, error) { + kid, ok := token.Header["kid"].(string) + if !ok { + return nil, fmt.Errorf("missing metadata 'kid' 'in the header of claim 'at'") + } + + pKey, err := s.jwk.GetPublicKey(kid) + if err != nil { + return nil, err + } + + return pKey, nil + }, jwt.WithValidMethods([]string{jwt.SigningMethodRS256.Alg()})) + + if err != nil { + return errors.Wrapf(err, "failed to validate claim 'at'") + } + + err = s.validateAccessTokenClaims(at, popTokenKid) + return err +} + +func (s *shrPopTokenValidator) validateAccessTokenClaims(token *jwt.Token, popTokenKid string) error { + if token == nil { + return fmt.Errorf("missing claim 'at' in pop token.") + } + + // now read in Entra access token's specific claims and validate them + claims, ok := token.Claims.(*AccessTokenCustomClaims) + if !ok { + return fmt.Errorf("failed to parse claim 'at' in pop token.") + } + + // Handle claims that are specifc to token versions + switch claims.TokenVersion { + case TokenVersion1: + if claims.AppId != s.ClientId { + return fmt.Errorf("claim 'at.appId' does not match the expected app Id. token is not from an accepted client") + } + if claims.Issuer != s.IssuerUrl { + return fmt.Errorf("claim 'at.iss' points to an invalid issuer for v1 token") + } + case TokenVersion2: + if claims.Azp != s.ClientId { + return fmt.Errorf("claim 'at.azp' does not match the expected app Id. token is not from an accepted client") + } + // for v2, issuer ends with v2.0 + if claims.Issuer != appendUrl(s.IssuerUrl, "v2.0") { + return fmt.Errorf("claim 'at.iss' points to an invalid issuer for v2 token") + } + + default: + return fmt.Errorf("claim 'at.ver' has an unknown token version %s. expected either %sor %s", claims.TokenVersion, TokenVersion1, TokenVersion2) + } + + if claims.ReqCnf.Kid != popTokenKid { + return fmt.Errorf("claim 'kid' in pop token did not match kid in claim 'at.cnf.kid'. pop token may have been signed with an unexpected key.") + } + + foundAud := false + for _, aud := range claims.Audience { + if _, ok := s.Audience[aud]; ok { + foundAud = true + break + } + } + if !foundAud { + return fmt.Errorf("claim 'at.aud' did not match the expected audience") + } + + return nil +} + +func (s *shrPopTokenValidator) isCustomClaimsValid(body *NodeAgentPopTokenBody) error { + if body.NodeId != s.NodeId { + return fmt.Errorf("claim 'nodeid'in pop token does not match the id of the edge node.") + } + if body.GrpcObjectId != s.GrpcObjectId { + return fmt.Errorf("claim 'p' in pop token does not match the targeted moc entity") + } + return nil +} + +func (s *shrPopTokenValidator) isTokenReused(nonce string, now time.Time) error { + if nonce == "" { + return fmt.Errorf("claim 'nonce' in pop token is empty/missing, expected a unique id") + } + + ok := s.nonceCache.IsNonceExists(nonce, now, PopTokenValidInterval) + if ok { + return fmt.Errorf("claim 'nonce' in pop token has been used before, potentially a replay attack") + } + return nil +} + +func (s *shrPopTokenValidator) Validate(popToken string) error { + toks := strings.Split(popToken, ".") + if len(toks) != 3 { + return fmt.Errorf("invalid pop token; expected 3 segments, got %d", len(toks)) + } + + header, err := decodeFromBase64[PopTokenHeader](toks[0]) + if err != nil { + return err + } + + if err := isHeaderValid(&header); err != nil { + return err + } + + body, err := decodeFromBase64[NodeAgentPopTokenBody](toks[1]) + if err != nil { + return err + } + + if err := isTokenExpire(body.Ts, time.Now(), PopTokenClockSkew); err != nil { + return err + } + + if err := s.isTokenReused(body.Nonce, time.Now()); err != nil { + return err + } + + if err := isKidValid(&body.Cnf, header.Kid); err != nil { + return err + } + + if err := s.isCustomClaimsValid(&body); err != nil { + return err + } + + signature, err := base64.RawURLEncoding.DecodeString(toks[2]) + if err != nil { + return err + } + signingStr := strings.Join([]string{toks[0], toks[1]}, ".") + err = isSignatureValid(&signingStr, signature, &body.Cnf) + if err != nil { + return err + } + + // now retrieve the inner access token + err = s.parseAndValidateAccessToken(body.At, header.Kid) + if err != nil { + return err + } + + return nil +} + +func NewPopTokenValidator(nodeId string, grpcobjectId string, tenantId string, audiences []string, clientId string, authorityUrl string, jwk JwkInterface, nonceChecker NonceCacheInterface) (*shrPopTokenValidator, error) { + audienceMap := make(map[string]bool) + for _, aud := range audiences { + audienceMap[aud] = true + } + + return &shrPopTokenValidator{ + NodeId: nodeId, + GrpcObjectId: grpcobjectId, + TenantId: tenantId, + Audience: audienceMap, + ClientId: clientId, + IssuerUrl: appendUrl(authorityUrl, tenantId), + jwk: jwk, + nonceCache: nonceChecker, + }, nil +} diff --git a/pkg/auth/poptoken/nodeagentpoptokenvalidator_test.go b/pkg/auth/poptoken/nodeagentpoptokenvalidator_test.go new file mode 100644 index 00000000..b4d7359f --- /dev/null +++ b/pkg/auth/poptoken/nodeagentpoptokenvalidator_test.go @@ -0,0 +1,613 @@ +package poptoken + +import ( + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "fmt" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" +) + +func Test_NodeAgentPopTokenValidatorAppendUrl(t *testing.T) { + + tests := []struct { + name string + url string + postfix string + expectedUrl string + }{ + { + name: "without backslash at end", + url: "http://localhost", + postfix: "myapi", + expectedUrl: "http://localhost/myapi", + }, + { + name: "with backslash at end", + url: "http://localhost/", + postfix: "myapi", + expectedUrl: "http://localhost/myapi", + }} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actualUrl := appendUrl(tt.url, tt.postfix) + assert.Equal(t, tt.expectedUrl, actualUrl) + }) + } +} + +func Test_NodeAgentPopTokenValidatorIsKidValid(t *testing.T) { + + newKeyPair, err := generateKeyPair() + assert.Nil(t, err) + cnf := publicKeyToCnf(newKeyPair) + expectedKid, err := calculatePublicKeyId(&cnf.Jwk) + assert.Nil(t, err) + + tests := []struct { + name string + expectedKid string + shouldPass bool + }{ + { + name: "valid kid", + expectedKid: expectedKid, + shouldPass: true, + }, + { + name: "invalid kid", + expectedKid: "randomKid", + shouldPass: false, + }} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := isKidValid(cnf, tt.expectedKid) + if tt.shouldPass { + assert.Nil(t, err) + } else { + assert.NotNil(t, err) + } + }) + } +} + +func Test_NodeAgentPopTokenValidatorIsTokenExpire(t *testing.T) { + tokenIssuedAt, err := time.Parse(time.RFC3339, "2025-12-01T15:00:00Z") + assert.Nil(t, err) + tokenIssuedAtInt := tokenIssuedAt.Truncate(time.Second).Unix() + + tests := []struct { + name string + tokenCheckAt time.Time + clockSkew time.Duration + shouldPass bool + }{ + { + name: "token valid", + // set token evaluation time to be 1 second after token was issued, token is valid. + tokenCheckAt: tokenIssuedAt.Add(time.Second * 1), + clockSkew: 0, + shouldPass: true, + }, + { + name: "token expired", + // set token evaluation time to be 10 seconds after max valid period. token has expired. + tokenCheckAt: tokenIssuedAt.Add(PopTokenValidInterval).Add(time.Second * 10), + clockSkew: 0, + shouldPass: false, + }, + { + name: "token pass due to clock skew", + // set token evaluation time to be 10 seconds after max valid period. token should have expired + // like previous test, but passes thanks to the allowed clock skew of 11 seconds. + tokenCheckAt: tokenIssuedAt.Add(PopTokenValidInterval).Add(time.Second * 10), + clockSkew: time.Second * 11, + shouldPass: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := isTokenExpire(tokenIssuedAtInt, tt.tokenCheckAt, tt.clockSkew) + if tt.shouldPass { + assert.Nil(t, err) + } else { + assert.NotNil(t, err) + } + }) + } +} + +func Test_NodeAgentPopTokenValidatorIsHeaderValid(t *testing.T) { + tests := []struct { + name string + header PopTokenHeader + shouldPass bool + }{ + { + name: "valid header", + shouldPass: true, + header: PopTokenHeader{Alg: Alg, Typ: TokenType}, + }, + { + name: "invalid alg", + shouldPass: false, + header: PopTokenHeader{Alg: "RSA123", Typ: TokenType}, + }, + { + name: "invalid typ", + shouldPass: false, + header: PopTokenHeader{Alg: Alg, Typ: "jwt"}, + }, + { + name: "empty header", + shouldPass: false, + header: PopTokenHeader{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := isHeaderValid(&tt.header) + if tt.shouldPass { + assert.Nil(t, err) + } else { + assert.NotNil(t, err) + } + }) + } +} + +func Test_NodeAgentPopTokenValidatorIsTokenReused(t *testing.T) { + nonceCache := &FakeNonceCache{Exists: false} + + // for this test, we only care about setting the noncecache + popTokenValidator, err := NewPopTokenValidator("", "", "", []string{"aud"}, "", "", nil, nonceCache) + assert.Nil(t, err) + + // simulate nonce entry does not exists in cache, i.e. we have not seen the token before + nonceCache.Exists = false + err = popTokenValidator.isTokenReused("myId", time.Now()) + assert.Nil(t, err) + + // simulate nonce entry exists in cache, i.e. same token is reused, potentially a replay token + nonceCache.Exists = true + err = popTokenValidator.isTokenReused("myId", time.Now()) + assert.NotNil(t, err) + + // simulate missing nonce. we will reject token + nonceCache.Exists = false + err = popTokenValidator.isTokenReused("", time.Now()) + assert.NotNil(t, err) +} + +func Test_NodeAgentPopTokenValidatorIsSignatureValid(t *testing.T) { + keypair, err := generateKeyPair() + assert.Nil(t, err) + + payload := []byte("ThisIsPayload") + payloadStr := string(payload) + + sigEncoded, err := signPayload(payload, keypair.PrivateKey) + assert.Nil(t, err) + sig, err := base64.RawURLEncoding.DecodeString(sigEncoded) + assert.Nil(t, err) + + cnf := publicKeyToCnf(keypair) + + // signature is correctly validated given correct payload and public key. + err = isSignatureValid(&payloadStr, sig, cnf) + assert.Nil(t, err) + + // simulate a mangled payload, expect failure + badPayload := payloadStr + "baddata" + + err = isSignatureValid(&badPayload, sig, cnf) + assert.NotNil(t, err) + + // simulate bad sig, expect failure + badSig := string(sig) + "1111" + + err = isSignatureValid(&payloadStr, []byte(badSig), cnf) + assert.NotNil(t, err) + + // simulate wrong public key, expect failure + newKeyPair, err := generateKeyPair() + assert.Nil(t, err) + misMatachCnf := publicKeyToCnf(newKeyPair) + + err = isSignatureValid(&payloadStr, sig, misMatachCnf) + assert.NotNil(t, err) +} + +func Test_NodeAgentPopTokenValidatorbase64ToExponential(t *testing.T) { + encodedExponential := "AQAB" + e, err := base64ToExponential(encodedExponential) + assert.Nil(t, err) + // this is the decoded value of a well known exponential value + assert.Equal(t, 65537, e) +} + +func Test_NodeAgentPopTokenValidatorIsCustomClaimsValid(t *testing.T) { + expectedNodeId := "myNodeId" + expectedGrpcObjectId := "myObjectId" + + tests := []struct { + name string + actualNodeId string + actualGrpcObjectId string + shouldPass bool + }{ + { + name: "valid nodeId and objectId claims", + actualNodeId: expectedNodeId, + actualGrpcObjectId: expectedGrpcObjectId, + shouldPass: true, + }, + { + name: "invalid nodeId claim", + actualNodeId: "somethingelse", + actualGrpcObjectId: expectedGrpcObjectId, + shouldPass: false, + }, + { + name: "missing nodeId claim", + actualNodeId: "", + actualGrpcObjectId: expectedGrpcObjectId, + shouldPass: false, + }, + { + name: "invalid grpcObjectId claim", + actualNodeId: expectedNodeId, + actualGrpcObjectId: "somethingelse", + shouldPass: false, + }, + { + name: "missing grpcObjectId claim", + actualNodeId: expectedNodeId, + actualGrpcObjectId: "", + shouldPass: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + // for this test, we only care about initializeing the custom claims + popTokenValidator, err := NewPopTokenValidator(expectedNodeId, expectedGrpcObjectId, "", []string{"aud"}, "", "", nil, nil) + assert.Nil(t, err) + // likewise we only set the custom claims in the poptoken body + popTokenBody := NodeAgentPopTokenBody{NodeId: tt.actualNodeId, GrpcObjectId: tt.actualGrpcObjectId} + + err = popTokenValidator.isCustomClaimsValid(&popTokenBody) + if tt.shouldPass { + assert.Nil(t, err) + } else { + assert.NotNil(t, err) + } + }) + } +} + +// Test the access token validation against a range of invalid claims. +func Test_NodeAgentPopTokenValidatorParseAndValidateAccessToken(t *testing.T) { + expectedAuthorityUrl := "https://login.fake.microsoftonline.com" + expectedTenantId := "cmpTenantId" + expectedClientId := "cmpClientId" + expectedAudience := "cmpAudience" + expectedIssuedTime := time.Now() + expectedPopTokenKid := "defaultKid" + + defaultPKey, _ := getPrivateKey() + defaultJwkMgr := &FakeJwrMgr{PublicKey: &defaultPKey.PublicKey} + nonceCache := &FakeNonceCache{Exists: false} + // By default, the accesstoken and validator will use the same expected values as listed above + // hence the access token validation will succeeded. + + // test is setup such that each invalid claim declared will override the access token's ciaim, causing it to be invalid. + tests := []struct { + name string + invalidAuthorityUrl string + invalidTenantId string + invalidClientId string + invalidAudience string + invalidIssuerUrl string + invalidPopTokenKid string + tokenVersion string + isInvalidSigning bool + isMissingKid bool + isExpiredToken bool + shouldPass bool + }{ + { + name: "valid access token v1", + tokenVersion: TokenVersion1, + shouldPass: true, + }, + { + name: "valid access token v1", + tokenVersion: TokenVersion2, + shouldPass: true, + }, + { + name: "invalid tenantId", + tokenVersion: TokenVersion2, + invalidTenantId: "badTenantId", + shouldPass: false, + }, + { + name: "invalid clientId", + tokenVersion: TokenVersion2, + invalidClientId: "badClientId", + shouldPass: false, + }, + { + name: "invalid audience", + tokenVersion: TokenVersion2, + invalidAudience: "badAudienceId", + shouldPass: false, + }, + { + name: "invalid pop token id", + tokenVersion: TokenVersion2, + invalidPopTokenKid: "badPopTokenId", + shouldPass: false, + }, + { + name: "invalid signing", + tokenVersion: TokenVersion2, + isInvalidSigning: true, + shouldPass: false, + }, + { + name: "token expired", + tokenVersion: TokenVersion2, + isExpiredToken: true, + shouldPass: false, + }, + { + name: "missing kid", + tokenVersion: TokenVersion2, + isMissingKid: true, + shouldPass: false, + }, + { + name: "invalid issuer", + tokenVersion: TokenVersion2, + invalidAuthorityUrl: "https://bad.issuer", + shouldPass: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tenantId := expectedTenantId + if tt.invalidTenantId != "" { + tenantId = tt.invalidTenantId + } + + clientId := expectedClientId + if tt.invalidClientId != "" { + clientId = tt.invalidClientId + } + + // to simulate expired token, we make it invalid after one second and sleep before validating token + newexpiry := time.Second + expiredTime := expectedIssuedTime.Add(time.Hour) + if tt.isExpiredToken { + expiredTime = expectedIssuedTime.Add(newexpiry) + } + + tokenVersion := tt.tokenVersion + + authorityUrl := expectedAuthorityUrl + if tt.invalidAuthorityUrl != "" { + authorityUrl = tt.invalidAuthorityUrl + } + + audience := expectedAudience + if tt.invalidAudience != "" { + audience = tt.invalidAudience + } + + issuserUrl := fmt.Sprintf("%s/%s", authorityUrl, tenantId) + if tokenVersion == TokenVersion2 { + issuserUrl = fmt.Sprintf("%s/v2.0", issuserUrl) + } + + popTokenKid := expectedPopTokenKid + if tt.invalidPopTokenKid != "" { + popTokenKid = tt.invalidPopTokenKid + } + + jwkMgr := defaultJwkMgr + if tt.isInvalidSigning { + newPKey, _ := getPrivateKey() + jwkMgr = &FakeJwrMgr{PublicKey: &newPKey.PublicKey} + } + + if tt.isMissingKid { + jwkMgr = &FakeJwrMgr{Err: fmt.Errorf("failed to find key")} + } + + // access token to be generated can contain invalid claims depending on the test param, + // by default it will use the same data as the token validator. + claims := AccessTokenCustomClaims{ + Tid: tenantId, + ReqCnf: ReqCnf{Kid: popTokenKid}, + Azp: clientId, + AppId: clientId, + TokenVersion: tokenVersion, + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: issuserUrl, + ExpiresAt: jwt.NewNumericDate(expiredTime), + IssuedAt: jwt.NewNumericDate(expectedIssuedTime), + NotBefore: jwt.NewNumericDate(expectedIssuedTime), + Audience: jwt.ClaimStrings{audience}, + Subject: clientId, + }, + } + + // The token validator is set to the expected values, except for invalid jwk + tokenValidator, err := NewPopTokenValidator( + "notused", // this is not tested here. + "notused", // this is not tested here. + expectedTenantId, + []string{expectedAudience}, + expectedClientId, + expectedAuthorityUrl, + jwkMgr, + nonceCache) + assert.Nil(t, err) + + s, err := generateAccessToken(&claims, defaultPKey) + assert.Nil(t, err) + + // sleep for a while to ensure token expires + if tt.isExpiredToken { + time.Sleep(newexpiry * 2) + } + + err = tokenValidator.parseAndValidateAccessToken(s, expectedPopTokenKid) + if tt.shouldPass { + assert.Nil(t, err) + } else { + assert.NotNil(t, err) + } + }) + } +} + +// Test the overall validate function. We have already tested the individual functions that makes up this call +// this is just a simple test to validate end to end. +func Test_NodeAgentPopTokenValidatorValidate(t *testing.T) { + authorityUrl := "https://login.fake.microsoftonline.com" + nodeId := "nodeId" + grpcObjectId := "objectId" + tenantId := "cmpTenantId" + clientId := "cmpClientId" + audience := "cmpAudience" + issuedTime := time.Now() + expiredTime := issuedTime.Add(time.Hour) + tokenVersion := TokenVersion2 + issuerUrl := fmt.Sprintf("%s/%s/v2.0", authorityUrl, tenantId) + + accessTokenPKey, err := getPrivateKey() + assert.Nil(t, err) + jwkMgr := &FakeJwrMgr{PublicKey: &accessTokenPKey.PublicKey} + nonceCache := &FakeNonceCache{Exists: false} + + //rsaKeyPair, err := generateKeyPair() + //assert.Nil(t, err) + + // partial generate pop token, we need to add the popKid into the accesstoken + popToken, err := NewNodeAgentPopTokenAuthScheme(nodeId, grpcObjectId) + assert.Nil(t, err) + + // Generate access token + claims := AccessTokenCustomClaims{ + Tid: tenantId, + ReqCnf: ReqCnf{Kid: popToken.header.Kid}, + Azp: clientId, + AppId: clientId, + TokenVersion: tokenVersion, + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: issuerUrl, + ExpiresAt: jwt.NewNumericDate(expiredTime), + IssuedAt: jwt.NewNumericDate(issuedTime), + NotBefore: jwt.NewNumericDate(issuedTime), + Audience: jwt.ClaimStrings{audience}, + Subject: clientId, + }, + } + at, err := generateAccessToken(&claims, accessTokenPKey) + assert.Nil(t, err) + + // Generate pop token + pt, err := popToken.FormatAccessToken(at) + assert.Nil(t, err) + + // validate poptoken + tokenValidator, err := NewPopTokenValidator( + nodeId, + grpcObjectId, + tenantId, + []string{audience}, + clientId, + authorityUrl, + jwkMgr, + nonceCache) + assert.Nil(t, err) + + err = tokenValidator.Validate(pt) + assert.Nil(t, err) +} + +func generateAccessToken(claims *AccessTokenCustomClaims, privateKey *rsa.PrivateKey) (string, error) { + tok := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ + "iss": claims.Issuer, + "sub": claims.Subject, + "aud": claims.Audience, + "exp": claims.ExpiresAt, + "nbf": claims.NotBefore, + "iat": claims.IssuedAt, + "cnf": claims.ReqCnf, + "azp": claims.Azp, + "appId": claims.AppId, + "ver": claims.TokenVersion, + "tid": claims.Tid, + }) + + // we don't care about the actual kid value here since we use a fake jwkMgr that ignores the kidm but we do + // expect it to be present in the header. + tok.Header["kid"] = "notused" + at, err := tok.SignedString(privateKey) + if err != nil { + return "", err + } + + return at, err +} + +// fake jwkMgr that will return either the public key or error +type FakeJwrMgr struct { + PublicKey *rsa.PublicKey + Err error +} + +func (j *FakeJwrMgr) GetPublicKey(kid string) (*rsa.PublicKey, error) { + if j.Err != nil { + return nil, j.Err + } else { + return j.PublicKey, nil + } +} + +// fake nonceCache that we can set to return true or false at will. +type FakeNonceCache struct { + Exists bool +} + +func (n *FakeNonceCache) IsNonceExists(nonceId string, now time.Time, tokenValidInterval time.Duration) bool { + return n.Exists +} + +func getPrivateKey() (*rsa.PrivateKey, error) { + return rsa.GenerateKey(rand.Reader, RsaSize) +} + +func publicKeyToCnf(keyPair *rsaKeyPair) *Cnf { + return &Cnf{ + Jwk: Jwk{ + Kty: Kty, + E: exponential2Base64(keyPair.PublicKey.E), + N: base64.RawURLEncoding.EncodeToString([]byte(keyPair.PublicKey.N.Bytes())), + }, + } +} diff --git a/pkg/auth/poptoken/noncecache.go b/pkg/auth/poptoken/noncecache.go new file mode 100644 index 00000000..0e75a57e --- /dev/null +++ b/pkg/auth/poptoken/noncecache.go @@ -0,0 +1,92 @@ +package poptoken + +import ( + "sync" + "time" +) + +const ( + DefaultNonceCacheSize = 20 +) + +type NonceCacheInterface interface { + IsNonceExists(nonceId string, now time.Time, tokenValidInterval time.Duration) bool +} + +type Nonce struct { + Id string + ExpireAtDateTime time.Time +} + +// Implement a simple LRU cache that evicts older nonce entries. +type nonceCache struct { + cache map[string]*Nonce + queue []*Nonce + nonceValidInterval time.Duration + size int + maxSize int + mutex sync.Mutex +} + +func (n *nonceCache) append(nonce *Nonce) { + n.cache[nonce.Id] = nonce + n.queue = append(n.queue, nonce) + n.size++ + return + +} + +// keep trimming the cache until all expired entries are purged or we are within the cache size +func (n *nonceCache) trim(now time.Time) { + isDelete := true + + for isDelete { + if len(n.queue) == 0 { + break + } + + nonce := n.queue[0] + isDelete = nonce.ExpireAtDateTime.Before(now) + + if !isDelete { + isDelete = n.size >= n.maxSize + } + if isDelete { + delete(n.cache, nonce.Id) + n.queue = n.queue[1:] + n.size-- + } + } +} + +func (n *nonceCache) IsNonceExists(nonceId string, now time.Time, tokenValidInterval time.Duration) bool { + n.mutex.Lock() + defer n.mutex.Unlock() + + _, ok := n.cache[nonceId] + if ok { + return ok + } + + nonce := &Nonce{ + Id: nonceId, + ExpireAtDateTime: now.Add(tokenValidInterval), + } + n.append(nonce) + n.trim(now) + + return false + +} + +func (n *nonceCache) GetCacheSize() int { + return n.size +} + +func NewNonceCache(maxSize int) (*nonceCache, error) { + return &nonceCache{ + cache: make(map[string]*Nonce), + size: 0, + maxSize: maxSize, + }, nil +} diff --git a/pkg/auth/poptoken/noncecache_test.go b/pkg/auth/poptoken/noncecache_test.go new file mode 100644 index 00000000..ba174ab1 --- /dev/null +++ b/pkg/auth/poptoken/noncecache_test.go @@ -0,0 +1,126 @@ +package poptoken + +import ( + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +const ( + testNonceCacheSize = 3 + testTokenValidInterval = time.Minute * 1 + testNonceNowDateTimeStr = "2025-12-01T15:00:00Z" +) + +func Test_NonceCacheIdExists(t *testing.T) { + noncecache, err := NewNonceCache(testNonceCacheSize) + assert.Nil(t, err) + + nonceId := "nonceId_1" + now, _ := time.Parse(time.RFC3339, testNonceNowDateTimeStr) + + // first time seeing this nodeId, returning false + isexist := noncecache.IsNonceExists(nonceId, now, testTokenValidInterval) + assert.False(t, isexist) + + // the second time the nonceId should be cached. + isexist = noncecache.IsNonceExists(nonceId, now, testTokenValidInterval) + assert.True(t, isexist) + + // Validate a new entry will return false + isexist = noncecache.IsNonceExists("nonceId_2", now, testTokenValidInterval) + assert.False(t, isexist) +} + +func Test_NonceCacheIdExistsButExpired(t *testing.T) { + noncecache, err := NewNonceCache(testNonceCacheSize) + assert.Nil(t, err) + + nonceId := "nonceId_1" + now, _ := time.Parse(time.RFC3339, testNonceNowDateTimeStr) + + // first time seeing this nodeId, returning false + isexist := noncecache.IsNonceExists(nonceId, now, testTokenValidInterval) + assert.False(t, isexist) + + // the second time the nonceId should be cached. + isexist = noncecache.IsNonceExists(nonceId, now, testTokenValidInterval) + assert.True(t, isexist) +} + +// Validate that expired Ids will be evicted upon the addition of a new entry. +func Test_NonceCacheEvictExpiredIds(t *testing.T) { + noncecache, err := NewNonceCache(testNonceCacheSize) + assert.Nil(t, err) + + now, _ := time.Parse(time.RFC3339, testNonceNowDateTimeStr) + for i := 0; i < testNonceCacheSize-1; i++ { + id := fmt.Sprintf("%d", i) + now = now.Add(time.Second) + noncecache.IsNonceExists(id, now, testTokenValidInterval) + + //need to call twice to confirm the nonceId were added, since the first time + // it is added, it will not exist + isexist := noncecache.IsNonceExists(id, now, testTokenValidInterval) + assert.True(t, isexist) + } + + // simulate querying a new nonce Id after time where the previously added ids expired. + // adding the new entry will trigger an eviction of the expired entries + newId := "new" + now = now.Add(testTokenValidInterval * 2) + noncecache.IsNonceExists(newId, now, testTokenValidInterval) + + // validate older entry has been evicted; size of cache should be 1 + // we check the size before checking if the older ids have been evicted as they will get + // readded again. + assert.Equal(t, 1, noncecache.GetCacheSize()) + + // validate the ids no longer exists in cache + for i := 0; i < testNonceCacheSize-1; i++ { + id := fmt.Sprintf("%d", i) + isexist := noncecache.IsNonceExists(id, now, testTokenValidInterval) + assert.False(t, isexist) + } + +} + +// Validate that the oldest Ids will be evicted upon the addition of a new entry that exceeds the cache size. +func Test_NonceCacheEvictOverflowIds(t *testing.T) { + noncecache, err := NewNonceCache(testNonceCacheSize) + assert.Nil(t, err) + + idsToAddCount := testNonceCacheSize + 2 + now, _ := time.Parse(time.RFC3339, testNonceNowDateTimeStr) + for i := 0; i < idsToAddCount; i++ { + id := fmt.Sprintf("%d", i) + now = now.Add(time.Second) + noncecache.IsNonceExists(id, now, testTokenValidInterval) + + //need to call twice to confirm the nonceId were added, since the first time + // it is added, it will not exist + isexist := noncecache.IsNonceExists(id, now, testTokenValidInterval) + assert.True(t, isexist) + + // validate size of cache does not exceed the max even if more ids were added. + assert.True(t, noncecache.GetCacheSize() <= testNonceCacheSize) + } + + // when we add more ids than supported, the oldest ids get evicted. + // verify the earlier ids should no longer exist. We need to check in reverse order + // to avoid evicting the newer entries + for i := idsToAddCount - 1; i >= 0; i-- { + id := fmt.Sprintf("%d", i) + now = now.Add(time.Second) + + isexist := noncecache.IsNonceExists(id, now, testTokenValidInterval) + if i >= testNonceCacheSize { + assert.True(t, isexist) + } else { + assert.False(t, isexist) + } + } + +} diff --git a/pkg/auth/poptoken/poptokenauth.go b/pkg/auth/poptoken/poptokenauth.go new file mode 100644 index 00000000..5f354f10 --- /dev/null +++ b/pkg/auth/poptoken/poptokenauth.go @@ -0,0 +1,44 @@ +package poptoken + +import ( + "context" + + "github.com/microsoft/moc/pkg/errors" +) + +/* +The setup of the pop token creaton is as follows: + PopTokenAuth (interface betwen grpc and msalauthprovider) + | + --> MsalAuthProvider (global component that request the token from Entra/AzureAAD via MSAL SDK) + | + --> NodeAgentPopTokenAuthScheme (wrapper that adds the claims specific to node agent) + | + --> PopTokenAuthScheme(a more generic pop tokens implementation) +*/ + +// This component integrates the MSAL provider to the grpc credentials.PerRPCCredentials interface +type PopTokenAuth struct { + msalauthprovider *MsalAuthProvider + targetResourceId string +} + +func NewPopTokenAuth(msalProvider *MsalAuthProvider, targetResourceId string) (*PopTokenAuth, error) { + return &PopTokenAuth{ + msalauthprovider: msalProvider, + targetResourceId: targetResourceId, + }, nil +} + +func (p *PopTokenAuth) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { + accessToken, err := p.msalauthprovider.GetToken(p.targetResourceId, uri[0]) + if err != nil { + return nil, errors.Wrapf(err, "failed to generate poptoken") + } + + return map[string]string{"authorization": accessToken, "uri": uri[0]}, nil +} + +func (p *PopTokenAuth) RequireTransportSecurity() bool { + return true +} diff --git a/pkg/auth/poptoken/poptokenscheme.go b/pkg/auth/poptoken/poptokenscheme.go new file mode 100644 index 00000000..aef61a61 --- /dev/null +++ b/pkg/auth/poptoken/poptokenscheme.go @@ -0,0 +1,256 @@ +package poptoken + +import ( + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "encoding/base64" + "encoding/binary" + "encoding/json" + "reflect" + "strings" + "time" + + "github.com/pkg/errors" +) + +const ( + refreshInterval time.Duration = time.Hour * 8 + TokenType = "pop" + DefaultRefreshInterval = time.Hour * 8 + RsaSize = 2048 + Kty = "RSA" + Alg = "RS256" +) + +var ( + globalRsaKey *rsaKeyPair = nil + globalRefreshInterval time.Duration = DefaultRefreshInterval + globalLastRefreshRsaKeyDateTime time.Time = time.Now() +) + +type rsaKeyPair struct { + PrivateKey *rsa.PrivateKey + PublicKey *rsa.PublicKey +} + +type PopTokenHeader struct { + // RSA PS256? + Alg string `json:"alg"` + // key Id of public key + Kid string `json:"kid"` + // always pop + Typ string `json:"typ"` +} + +// https://datatracker.ietf.org/doc/html/rfc7638#section-3.1 +// contains the metadata use to calculate kid. +type Jwk struct { + // Exponent + E string `json:"e"` + // encryption + Kty string `json:"kty"` + // modulus + N string `json:"n"` +} + +// https://datatracker.ietf.org/doc/html/rfc7800#section-3.2 +type Cnf struct { + Jwk Jwk `json:"jwk"` +} + +type ReqCnf struct { + Kid string `json:"kid"` +} + +// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-signed-http-request-03#section-3 +type PopTokenBody struct { + Cnf Cnf `json:"cnf"` + // timestamp + Ts int64 `json:"ts"` + // access token + At string `json:"at"` +} + +// Implements the interface for MSAL SDK to callback when creating the poptoken. +// See AuthenticationScheme interface in https://github.com/AzureAD/microsoft-authentication-library-for-go/blob/main/apps/internal/oauth/ops/authority/authority.go#L146 +type PopTokenAuthScheme struct { + header PopTokenHeader + body PopTokenBody + reqCnfBase64 string + claims map[string]interface{} + keyPair *rsaKeyPair +} + +// refresh the global rsa keypair once every 8 hours. +func generateRSAKeyPair(now time.Time) (*rsaKeyPair, error) { + + if globalRsaKey == nil || globalLastRefreshRsaKeyDateTime.Add(globalRefreshInterval).Before(now) { + pKey, err := rsa.GenerateKey(rand.Reader, RsaSize) + if err != nil { + return nil, err + } + globalRsaKey = &rsaKeyPair{ + PrivateKey: pKey, + PublicKey: pKey.Public().(*rsa.PublicKey), + } + globalLastRefreshRsaKeyDateTime = now + } + return globalRsaKey, nil +} + +func calculatePublicKeyId(jwk *Jwk) (string, error) { + // - https://tools.ietf.org/html/rfc7638#section-3.1 + jwkByte, err := json.Marshal(jwk) + if err != nil { + return "", err + } + jwk256 := sha256.Sum256(jwkByte) + return base64.RawURLEncoding.EncodeToString(jwk256[:]), nil +} + +func jsonToBase64(v any) (string, error) { + jsonValue, err := json.Marshal(v) + if err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(jsonValue), nil +} + +func signPayload(payload []byte, rsaKey *rsa.PrivateKey) (string, error) { + hash := sha256.New() + _, err := hash.Write(payload) + if err != nil { + return "", err + } + sigBytes, err := rsa.SignPKCS1v15(rand.Reader, rsaKey, crypto.SHA256, hash.Sum(nil)) + if err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(sigBytes), nil +} + +func exponential2Base64(e int) string { + bs := make([]byte, 4) + binary.BigEndian.PutUint32(bs, uint32(e)) + + bs = bs[1:] // drop most significant byte - leaving least-significant 3-bytes + ss := base64.RawURLEncoding.EncodeToString(bs) + return ss +} + +// Append custom claims to the existing ShrPopTokenBody. +func (p *PopTokenAuthScheme) appendCustomClaimsToBody(customClaims map[string]interface{}) map[string]interface{} { + + bodyMap := make(map[string]interface{}) + + // first convert the existing body to a map of interface. + val := reflect.ValueOf(p.body) + typ := val.Type() + for i := 0; i < val.NumField(); i++ { + if name := strings.ToLower(typ.Field(i).Name); name != "" { + bodyMap[name] = val.Field(i).Interface() + } + } + // now append the custom claims + for k, v := range customClaims { + bodyMap[k] = v + } + + return bodyMap +} + +// Complete the poptoken creation by adding the custom claims and signing it. +func (p *PopTokenAuthScheme) generateToken(token string, now time.Time) (string, error) { + + p.body.Ts = now.Truncate(time.Second).Unix() + p.body.At = token + + body, err := jsonToBase64(p.appendCustomClaimsToBody(p.claims)) + if err != nil { + return "", err + } + + header, err := jsonToBase64(p.header) + if err != nil { + return "", err + } + + signingStr := strings.Join([]string{header, body}, ".") + + signature, err := signPayload([]byte(signingStr), p.keyPair.PrivateKey) + if err != nil { + return "", nil + } + + return strings.Join([]string{signingStr, signature}, "."), nil +} + +// Return the claim containg the pop token kid that will be added to the Entra access token. +func (p *PopTokenAuthScheme) TokenRequestParams() map[string]string { + return map[string]string{ + "token_type": p.header.Typ, + "req_cnf": p.reqCnfBase64, + } +} + +// Return the keyId for MSAL to lookup for a cached access token. If it does not exist, MSAL will request a new access token +func (p *PopTokenAuthScheme) KeyID() string { + return p.header.Kid +} + +// Generate the pop token; adding in the accessToken generated by Entra. +func (p *PopTokenAuthScheme) FormatAccessToken(accessToken string) (string, error) { + return p.generateToken(accessToken, time.Now()) +} + +// Return the token type. Must be "pop" +func (p *PopTokenAuthScheme) AccessTokenType() string { + return p.header.Typ +} + +// Create a new instance of PopTokenAuthScheme. Pass in the custom claims to be set in the pop token here, e.g. resourceId +func NewPopTokenAuthScheme(claims map[string]interface{}) (*PopTokenAuthScheme, error) { + + keyPair, err := generateRSAKeyPair(time.Now()) + if err != nil { + return nil, err + } + + popTokenScheme := &PopTokenAuthScheme{ + header: PopTokenHeader{ + Alg: Alg, + Typ: TokenType, + }, + body: PopTokenBody{ + Cnf: Cnf{ + Jwk: Jwk{ + Kty: Kty, + N: base64.RawURLEncoding.EncodeToString([]byte(keyPair.PublicKey.N.Bytes())), + E: exponential2Base64(keyPair.PublicKey.E), + }, + }, + }, + keyPair: keyPair, + claims: claims, + } + + keyId, err := calculatePublicKeyId(&popTokenScheme.body.Cnf.Jwk) + if err != nil { + return nil, errors.Wrapf(err, "faild to generate kid") + } + + popTokenScheme.header.Kid = keyId + + reqCnfb64, err := jsonToBase64( + ReqCnf{ + Kid: keyId, + }) + if err != nil { + return nil, errors.Wrapf(err, "faild to generate base64 representation of req_cnf") + } + popTokenScheme.reqCnfBase64 = reqCnfb64 + + return popTokenScheme, nil +} diff --git a/pkg/auth/poptoken/poptokenscheme_test.go b/pkg/auth/poptoken/poptokenscheme_test.go new file mode 100644 index 00000000..5ece9635 --- /dev/null +++ b/pkg/auth/poptoken/poptokenscheme_test.go @@ -0,0 +1,244 @@ +package poptoken + +import ( + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "encoding/base64" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +const ( + testClaimName = "test" + testValue = "value" +) + +var ( + testClaims = map[string]interface{}{testClaimName: testValue} +) + +type TestPopTokenSchemeBody struct { + PopTokenBody + Test string `json:"test"` // must matchtestClaimName +} + +type testStruct struct { + StrValue string + IntValue int +} + +func Test_PopTokenAuthSchemeNew(t *testing.T) { + + pop, err := NewPopTokenAuthScheme(testClaims) + assert.Nil(t, err) + + // calculate kid + expectedKid, err := calculatePublicKeyId(&pop.body.Cnf.Jwk) + assert.Nil(t, err) + + // check header + assert.Equal(t, Alg, pop.header.Alg) + assert.Equal(t, TokenType, pop.header.Typ) + assert.Equal(t, expectedKid, pop.header.Kid) + + // check body + expectedE := exponential2Base64(pop.keyPair.PrivateKey.E) + expectedN := base64.RawURLEncoding.EncodeToString([]byte(pop.keyPair.PublicKey.N.Bytes())) + assert.Equal(t, expectedE, pop.body.Cnf.Jwk.E) + assert.Equal(t, expectedN, pop.body.Cnf.Jwk.N) + + //check claims + actualValue, ok := pop.claims[testClaimName] + assert.True(t, ok) + assert.Equal(t, testValue, actualValue) +} + +func Test_PopTokenAuthSchemeGenerateToken(t *testing.T) { + pop, err := NewPopTokenAuthScheme(testClaims) + assert.Nil(t, err) + + expectedAccessToken := "myFakeAccessToken" + expectedTimeStamp, err := time.Parse(time.RFC3339, "2025-12-01T15:00:00Z") + assert.Nil(t, err) + + expectedKid, err := calculatePublicKeyId(&pop.body.Cnf.Jwk) + assert.Nil(t, err) + + // Generate the token and validate its content + popToken, err := pop.generateToken(expectedAccessToken, expectedTimeStamp) + assert.Nil(t, err) + + toks := strings.Split(popToken, ".") + assert.Equal(t, 3, len(toks)) + + // validate header. + header, err := decodeFromBase64[PopTokenHeader](toks[0]) + assert.Nil(t, err) + assert.Equal(t, Alg, header.Alg) + assert.Equal(t, TokenType, header.Typ) + assert.Equal(t, expectedKid, header.Kid) + + // validate body + body, err := decodeFromBase64[TestPopTokenSchemeBody](toks[1]) + assert.Nil(t, err) + assert.Equal(t, expectedTimeStamp.Truncate(time.Second).Unix(), body.Ts) + assert.Equal(t, testValue, body.Test) + assert.Equal(t, expectedAccessToken, body.At) + + // validate signature. + signature, err := base64.RawURLEncoding.DecodeString(toks[2]) + assert.Nil(t, err) + + signingStr := strings.Join([]string{toks[0], toks[1]}, ".") + err = isSignatureValid(&signingStr, signature, &body.Cnf) + assert.Nil(t, err) +} + +func Test_PopTokenAuthSchemeTokenRequestParams(t *testing.T) { + pop, err := NewPopTokenAuthScheme(testClaims) + assert.Nil(t, err) + + expectedCnfBase64, err := jsonToBase64( + ReqCnf{ + Kid: pop.header.Kid, + }) + assert.Nil(t, err) + + requestParams := pop.TokenRequestParams() + + // we expect these two entries + tokType, ok := requestParams["token_type"] + assert.True(t, ok) + assert.Equal(t, TokenType, tokType) + + actualReqCnf, ok := requestParams["req_cnf"] + assert.True(t, ok) + assert.Equal(t, expectedCnfBase64, actualReqCnf) +} + +func Test_PopTokenAuthSchemeAppendCustomClaims(t *testing.T) { + pop, err := NewPopTokenAuthScheme(testClaims) + assert.Nil(t, err) + + expectedStringValue := "string" + expectedIntegerValue := 1234 + expectedStrArrValue := []string{"hello", "world"} + expectedStructValue := testStruct{StrValue: "string", IntValue: 1234} + + customClaims := map[string]interface{}{ + "string": expectedStringValue, + "integer": expectedIntegerValue, + "strArray": expectedStrArrValue, + "struct": expectedStructValue, + } + + actualClaims := pop.appendCustomClaimsToBody(customClaims) + + tmp, ok := actualClaims["string"] + assert.True(t, ok) + actualstringValue, ok := tmp.(string) + assert.True(t, ok) + assert.Equal(t, expectedStringValue, actualstringValue) + + tmp, ok = actualClaims["integer"] + assert.True(t, ok) + actualIntegerValue, ok := tmp.(int) + assert.True(t, ok) + assert.Equal(t, expectedIntegerValue, actualIntegerValue) + + tmp, ok = actualClaims["strArray"] + assert.True(t, ok) + actualStrArrValue, ok := tmp.([]string) + assert.True(t, ok) + assert.Equal(t, expectedStrArrValue, actualStrArrValue) + + tmp, ok = actualClaims["struct"] + assert.True(t, ok) + actualStructValue, ok := tmp.(testStruct) + assert.True(t, ok) + assert.Equal(t, expectedStructValue, actualStructValue) + + // finally sanity check that these custom claims can be converted to json + _, err = jsonToBase64(actualClaims) + assert.Nil(t, err) +} + +func Test_PopTokenAuthSchemeExponential2Base64(t *testing.T) { + e := 65537 + base64 := exponential2Base64(e) + // this is the encoded value of a well known exponential value + assert.Equal(t, "AQAB", base64) +} + +func Test_PopTokenAuthSchemeCalculatePublicKeyId(t *testing.T) { + jwk := Jwk{ + Kty: "RSA", + E: "AQAB", + N: "MjM1MDg5MDU4MzgxMDg3OTI5NTU3NjM1ODg4NTA3NDE5OTAwNzc0MzkzNzQ5NDcwNzcwMjA2MDIxNjMyNzk5NzYxNDM4NTczMjc3NTA0NzI4ODkzNDUzNjU0NDU0NjMxMjcxNjQ0MTAwMDM0NzUzNzU2MTEyMjkzODYzMDYxMjk5MDQxNzI5OTc0MDg5OTk2OTEzNTY4MjM5OTc0NDMwNTExODI3MDgyNDAzMDQxNDMxMTQ5ODA4ODc4NjE5NTc5MjcwMjAxNjc3ODM1NTQ0NDI3NDMwMDczODI2OTAwODk2MzcxNTM2NzE5NDQyNTUxNzIzNTM5MTg4OTU2MDc4MzI0MzYxNDM4MDEzNjA3OTI0NzMyNTUxMDg5ODU3NjQ1NDA0MTIyMTk3ODUwNjkyMjEyMTk4OTMxMDU1NTkzOTk4NzYyMjIwODg1NDg5NzE4MjQxNDAxMTg2MTMwMzExODAwMDQ2NjEwMjk0MDIzMzQ1MTA1NjE4ODY0ODc0OTgzNzU2NTMzMTY0OTk5MTg1NDk4ODIwOTY3NjYyNjM1NTUxMjk0NTkzNDEwNzc5MzUwODg2MjMxODkyMTc0NTcwODkxNDU4MjIwNzIwMzI5MTg3OTA3NzAxMzMzMDU1NzM0ODk0NjU3MDYzOTMzMzA3MTUwNjgzMTk1NjkyOTk0MzAxMjUxODUwNzUwMTg2MzI5MzM4ODk2NjY3OTQyMDE0OTcwODY3MTAzMTgxNTA5NDAxMTAwMzUwMzk5MDE3MDI3MTI3MTAwMDM5OTIwNjgwNjExNjcxNTQ3MDE1ODM2NzIyMTU1OTgxMTE=", + } + keyId, err := calculatePublicKeyId(&jwk) + assert.Nil(t, err) + assert.Equal(t, "a0CyVS__Npcx4GXYm1OCoxrlboOWKF02MXzSSh92ckY", keyId) +} + +func Test_PopTokenAuthSchemeSignPayload(t *testing.T) { + keypair, err := generateKeyPair() + assert.Nil(t, err) + + payload := []byte("ThisIsMyTestPayLoad") + + sig, err := signPayload(payload, keypair.PrivateKey) + assert.Nil(t, err) + + //now verify the signature using the public key + sigDecode, err := base64.RawURLEncoding.DecodeString(sig) + hash := sha256.New() + hash.Write(payload) + err = rsa.VerifyPKCS1v15(keypair.PublicKey, crypto.SHA256, hash.Sum(nil), sigDecode) + assert.Nil(t, err) +} + +func Test_PopTokenAuthSchemeGenerateRsaKeyPair(t *testing.T) { + + now, err := time.Parse(time.RFC3339, "2025-12-01T15:00:00Z") + + keyPair1, err := generateRSAKeyPair(now) + assert.Nil(t, err) + assert.NotNil(t, keyPair1) + + // now simulate getting a second keypair < refreshInterval; keyPair2 should be the same + keyPair2, err := generateRSAKeyPair(now) + assert.Nil(t, err) + assert.NotNil(t, keyPair2) + assert.Equal(t, keyPair1.PrivateKey.N, keyPair2.PrivateKey.N) + + // now simulate getting a third keypair > refreshInterval; keyPair3 should be different from 1 and 2. + newNow := now.Add(globalRefreshInterval).Add(time.Minute) + keyPair3, err := generateRSAKeyPair(newNow) + assert.Nil(t, err) + assert.NotNil(t, keyPair3) + assert.NotEqual(t, keyPair1.PrivateKey.N, keyPair3.PrivateKey.N) + assert.NotEqual(t, keyPair2.PrivateKey.N, keyPair3.PrivateKey.N) + + // now try again; keypair4 == keypair3 + keyPair4, err := generateRSAKeyPair(newNow) + assert.Nil(t, err) + assert.NotNil(t, keyPair4) + assert.Equal(t, keyPair3.PrivateKey.N, keyPair4.PrivateKey.N) +} + +func generateKeyPair() (*rsaKeyPair, error) { + pKey, err := rsa.GenerateKey(rand.Reader, RsaSize) + if err != nil { + return nil, err + } + return &rsaKeyPair{ + PrivateKey: pKey, + PublicKey: pKey.Public().(*rsa.PublicKey), + }, nil +}