diff --git a/src/main/java/org/opensearch/security/auth/http/jwt/keybyoidc/JwtVerifier.java b/src/main/java/org/opensearch/security/auth/http/jwt/keybyoidc/JwtVerifier.java index 68e024088e..de2f47a3f3 100644 --- a/src/main/java/org/opensearch/security/auth/http/jwt/keybyoidc/JwtVerifier.java +++ b/src/main/java/org/opensearch/security/auth/http/jwt/keybyoidc/JwtVerifier.java @@ -25,8 +25,10 @@ import com.nimbusds.jose.JOSEException; import com.nimbusds.jose.JWSVerifier; import com.nimbusds.jose.crypto.factories.DefaultJWSVerifierFactory; +import com.nimbusds.jose.jwk.ECKey; import com.nimbusds.jose.jwk.JWK; import com.nimbusds.jose.jwk.OctetSequenceKey; +import com.nimbusds.jose.jwk.RSAKey; import com.nimbusds.jose.proc.SimpleSecurityContext; import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.SignedJWT; @@ -52,7 +54,6 @@ public JwtVerifier(KeyProvider keyProvider, int clockSkewToleranceSeconds, Strin public SignedJWT getVerifiedJwtToken(String encodedJwt) throws BadCredentialsException { try { SignedJWT jwt = SignedJWT.parse(encodedJwt); - String escapedKid = jwt.getHeader().getKeyID(); String kid = escapedKid; if (!Strings.isNullOrEmpty(kid)) { @@ -61,7 +62,6 @@ public SignedJWT getVerifiedJwtToken(String encodedJwt) throws BadCredentialsExc log.debug("JWT token is missing 'kid' (Key ID) claim in header. This may cause key selection issues."); } JWK key = keyProvider.getKey(kid); - JWSVerifier signatureVerifier = getInitializedSignatureVerifier(key, jwt); boolean signatureValid = jwt.verify(signatureVerifier); @@ -104,10 +104,14 @@ private JWSVerifier getInitializedSignatureVerifier(JWK key, SignedJWT jwt) thro validateSignatureAlgorithm(key, jwt); final JWSVerifier result; - if (key.getClass() == OctetSequenceKey.class) { + if (key instanceof OctetSequenceKey) { result = new DefaultJWSVerifierFactory().createJWSVerifier(jwt.getHeader(), key.toOctetSequenceKey().toSecretKey()); - } else { + } else if (key instanceof RSAKey) { result = new DefaultJWSVerifierFactory().createJWSVerifier(jwt.getHeader(), key.toRSAKey().toRSAPublicKey()); + } else if (key instanceof ECKey) { + result = new DefaultJWSVerifierFactory().createJWSVerifier(jwt.getHeader(), key.toECKey().toECPublicKey()); + } else { + throw new IllegalArgumentException("Unsupported JWK key type: " + key.getClass()); } if (result == null) { diff --git a/src/test/java/org/opensearch/security/auth/http/jwt/keybyoidc/HTTPJwtKeyByJWKSAuthenticatorTest.java b/src/test/java/org/opensearch/security/auth/http/jwt/keybyoidc/HTTPJwtKeyByJWKSAuthenticatorTest.java index 178d055409..ad19a9b16e 100644 --- a/src/test/java/org/opensearch/security/auth/http/jwt/keybyoidc/HTTPJwtKeyByJWKSAuthenticatorTest.java +++ b/src/test/java/org/opensearch/security/auth/http/jwt/keybyoidc/HTTPJwtKeyByJWKSAuthenticatorTest.java @@ -53,6 +53,28 @@ public void testBasicJwksAuthentication() throws Exception { } } + @Test + public void testJwksAuthenticationWithEC() throws Exception { + MockJwksServer mockJwksServer = new MockJwksServer(TestJwk.Jwks.ALL); + + try { + Settings settings = Settings.builder().put("jwks_uri", mockJwksServer.getJwksUri()).build(); + + HTTPJwtKeyByJWKSAuthenticator jwtAuth = new HTTPJwtKeyByJWKSAuthenticator(settings, null); + + AuthCredentials creds = jwtAuth.extractCredentials( + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_EC_1), new HashMap<>()).asSecurityRequest(), + null + ); + + Assert.assertNotNull(creds); + assertThat(creds.getUsername(), is(TestJwts.MCCOY_SUBJECT)); + + } finally { + mockJwksServer.close(); + } + } + @Test public void testJwksAuthenticationWithBearerPrefix() throws Exception { MockJwksServer mockJwksServer = new MockJwksServer(TestJwk.Jwks.ALL); diff --git a/src/test/java/org/opensearch/security/auth/http/jwt/keybyoidc/TestJwk.java b/src/test/java/org/opensearch/security/auth/http/jwt/keybyoidc/TestJwk.java index 531f433e0f..10b8a29754 100644 --- a/src/test/java/org/opensearch/security/auth/http/jwt/keybyoidc/TestJwk.java +++ b/src/test/java/org/opensearch/security/auth/http/jwt/keybyoidc/TestJwk.java @@ -15,6 +15,8 @@ import java.util.Arrays; import com.nimbusds.jose.JWSAlgorithm; +import com.nimbusds.jose.jwk.Curve; +import com.nimbusds.jose.jwk.ECKey; import com.nimbusds.jose.jwk.JWK; import com.nimbusds.jose.jwk.JWKSet; import com.nimbusds.jose.jwk.KeyUse; @@ -47,6 +49,10 @@ class TestJwk { "hMSoV74FRtoaU7xpp0llsXbHE4oUseKoSNga-C_YIXuoGc3pajHh1WtJppZQNYM1Xy07nHchLJAdgqL2_q_Lk8cFHmmL1KTjwPflK9zZ9C0-8QTOrrqU9vkp3gT00jWWJ0HJbUvXIGxPGPnxoJoI--ToE0EWsYEWqWyx1TqYol--oUUPlY5r7vXRKIn5UZNz6VGkW8nI4fXaqDUpXH9uVM9A-nJX2B0Xjwu3VOn2zrgkCZeGTHjNgfLISOTFe9m8lHWLKcuxOWPuCZyCN0C6ZdWB1YP2NhxYFQwQfGV8yfnTImgL-DuV4WPSRVj7W_GJr213-oXBrBR0CnQEPbi_3w"; static final String RSA_1_E = "AQAB"; + static final String EC_1_X = "K9a34L6QkEkWggKi700OyBCRghR2Xt-0Wym8qz0GdsM"; + static final String EC_1_Y = "NaINav68UqRb9D9MWkUJZN6acnuOYSb2iWAVI05iKpw"; + static final String EC_1_D = "lkTOA6MmmMDUg44V0coRAHbB5Zw-0748N8l8EOK8-5A"; + static final String RSA_2_D = "QQ18k_buZHOSVYzkXL1FaqdodZVNZ_hrBtDcmCVUYjm3dfDVQYt70h8LUdLUCSUA2-_VEwqVdQ-L2FTg7NZVvZJXIyQXp3yrdY1vGKebs3oaIB_VQT8jt-64s12r_8V2ksK2myRrvfm2Fgqi32H5QkspuaQYb9s4NJwKSk7mVAz5dRWQdCx9JNVWknWDJxgHzh3Uku1tNwUOyvSYcRnSZ9X7oWNHaHkSGLEYE_mxD7YXs6HEdCDwc3WuvR5AiVKg2OGec0lL1hY_AWX5UxnR00mhAa0qPytFfaPe-Sc5tQ5regQRqRNDyDESVGIvqXsY8ePjZPOFyoxrcJ2wN3bt4Q"; static final String RSA_2_N = @@ -61,6 +67,9 @@ class TestJwk { static final JWK RSA_1 = createRsa("kid/1", "RS256", RSA_1_E, RSA_1_N, RSA_1_D); + static final JWK EC_1 = createEc("kid/ec1", "ES256", EC_1_X, EC_1_Y, EC_1_D); + static final JWK EC_1_PUBLIC = createEc("kid/ec1", "ES256", EC_1_X, EC_1_Y, null); + static final JWK RSA_1_PUBLIC = createRsaPublic("kid/1", "RS256", RSA_1_E, RSA_1_N); static final JWK RSA_1_PUBLIC_NO_ALG = createRsaPublic("kid/1", null, RSA_1_E, RSA_1_N); static final JWK RSA_1_PUBLIC_WRONG_ALG = createRsaPublic("kid/1", "HS256", RSA_1_E, RSA_1_N); @@ -74,9 +83,10 @@ class TestJwk { static final JWKSet RSA_1_2_PUBLIC = createJwks(RSA_1_PUBLIC, RSA_2_PUBLIC); static class Jwks { - static final JWKSet ALL = createJwks(OCT_1, OCT_2, OCT_3, FORWARD_SLASH_KID_OCT_1, RSA_1_PUBLIC, RSA_2_PUBLIC); + static final JWKSet ALL = createJwks(OCT_1, OCT_2, OCT_3, FORWARD_SLASH_KID_OCT_1, RSA_1_PUBLIC, RSA_2_PUBLIC, EC_1_PUBLIC); static final JWKSet RSA_1 = createJwks(RSA_1_PUBLIC); static final JWKSet RSA_2 = createJwks(RSA_2_PUBLIC); + static final JWKSet EC_1 = createJwks(EC_1_PUBLIC); static final JWKSet RSA_1_NO_ALG = createJwks(RSA_1_PUBLIC_NO_ALG); static final JWKSet RSA_1_WRONG_ALG = createJwks(RSA_1_PUBLIC_WRONG_ALG); } @@ -104,6 +114,18 @@ private static JWK createRsaPublic(String keyId, String algorithm, String e, Str return createRsa(keyId, algorithm, e, n, null); } + private static JWK createEc(String keyId, String algorithm, String x, String y, String d) { + ECKey.Builder builder = new ECKey.Builder(Curve.P_256, Base64URL.from(x), Base64URL.from(y)).keyUse(KeyUse.SIGNATURE) + .algorithm(algorithm == null ? null : JWSAlgorithm.parse(algorithm)) + .keyID(keyId); + + if (d != null) { + builder.d(Base64URL.from(d)); + } + + return builder.build(); + } + private static JWKSet createJwks(JWK... array) { return new JWKSet(Arrays.asList(array)); } diff --git a/src/test/java/org/opensearch/security/auth/http/jwt/keybyoidc/TestJwts.java b/src/test/java/org/opensearch/security/auth/http/jwt/keybyoidc/TestJwts.java index 9971ca4bc5..d1ae23bccc 100644 --- a/src/test/java/org/opensearch/security/auth/http/jwt/keybyoidc/TestJwts.java +++ b/src/test/java/org/opensearch/security/auth/http/jwt/keybyoidc/TestJwts.java @@ -104,6 +104,8 @@ class TestJwts { static final String MC_COY_SIGNED_RSA_1 = createSigned(MC_COY, TestJwk.RSA_1); + static final String MC_COY_SIGNED_EC_1 = createSigned(MC_COY, TestJwk.EC_1); + static final String MC_COY_SIGNED_RSA_X = createSigned(MC_COY, TestJwk.RSA_X); static final String MC_COY_EXPIRED_SIGNED_OCT_1 = createSigned(MC_COY_EXPIRED, TestJwk.OCT_1);