Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSVerifier;
import com.nimbusds.jose.crypto.factories.DefaultJWSVerifierFactory;
import com.nimbusds.jose.jwk.ECKey;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.OctetSequenceKey;
import com.nimbusds.jose.jwk.RSAKey;
import com.nimbusds.jose.proc.SimpleSecurityContext;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
Expand All @@ -52,7 +54,6 @@ public JwtVerifier(KeyProvider keyProvider, int clockSkewToleranceSeconds, Strin
public SignedJWT getVerifiedJwtToken(String encodedJwt) throws BadCredentialsException {
try {
SignedJWT jwt = SignedJWT.parse(encodedJwt);

String escapedKid = jwt.getHeader().getKeyID();
String kid = escapedKid;
if (!Strings.isNullOrEmpty(kid)) {
Expand All @@ -61,7 +62,6 @@ public SignedJWT getVerifiedJwtToken(String encodedJwt) throws BadCredentialsExc
log.debug("JWT token is missing 'kid' (Key ID) claim in header. This may cause key selection issues.");
}
JWK key = keyProvider.getKey(kid);

JWSVerifier signatureVerifier = getInitializedSignatureVerifier(key, jwt);
boolean signatureValid = jwt.verify(signatureVerifier);

Expand Down Expand Up @@ -104,10 +104,14 @@ private JWSVerifier getInitializedSignatureVerifier(JWK key, SignedJWT jwt) thro

validateSignatureAlgorithm(key, jwt);
final JWSVerifier result;
if (key.getClass() == OctetSequenceKey.class) {
if (key instanceof OctetSequenceKey) {
result = new DefaultJWSVerifierFactory().createJWSVerifier(jwt.getHeader(), key.toOctetSequenceKey().toSecretKey());
} else {
} else if (key instanceof RSAKey) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we write tests to hit all of these branches?

Copy link
Copy Markdown
Author

@wheresNasha wheresNasha May 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have already addded test coverage for the missing EC (ECKey) branch here.

I also looked into adding OCT (OctetSequenceKey) coverage, but the current JWKS authentication flow does not appear to successfully authenticate symmetric OCT keys in the existing test infrastructure (extractCredentials() returns null), so that test is currently failing.

RSA coverage is already present, so this PR focuses on covering the missing EC path.

result = new DefaultJWSVerifierFactory().createJWSVerifier(jwt.getHeader(), key.toRSAKey().toRSAPublicKey());
} else if (key instanceof ECKey) {
result = new DefaultJWSVerifierFactory().createJWSVerifier(jwt.getHeader(), key.toECKey().toECPublicKey());
} else {
throw new IllegalArgumentException("Unsupported JWK key type: " + key.getClass());
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if it makes sense but should we keep the fallback logic in place?

Any idea what the entire universe of key types can be?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I leaned toward explicit key type handling here since failing closed felt safer than broadly falling back and potentially masking unsupported JWK types.
From what I understand, Nimbus JWK currently supports symmetric (oct), RSA, EC, and additional types like OKP.

If preserving backward compatibility with prior RSA fallback makes more sense here, I’m happy to adjust it.

}

if (result == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,28 @@ public void testBasicJwksAuthentication() throws Exception {
}
}

@Test
public void testJwksAuthenticationWithEC() throws Exception {
MockJwksServer mockJwksServer = new MockJwksServer(TestJwk.Jwks.ALL);

try {
Settings settings = Settings.builder().put("jwks_uri", mockJwksServer.getJwksUri()).build();

HTTPJwtKeyByJWKSAuthenticator jwtAuth = new HTTPJwtKeyByJWKSAuthenticator(settings, null);

AuthCredentials creds = jwtAuth.extractCredentials(
new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_EC_1), new HashMap<>()).asSecurityRequest(),
null
);

Assert.assertNotNull(creds);
assertThat(creds.getUsername(), is(TestJwts.MCCOY_SUBJECT));

} finally {
mockJwksServer.close();
}
}

@Test
public void testJwksAuthenticationWithBearerPrefix() throws Exception {
MockJwksServer mockJwksServer = new MockJwksServer(TestJwk.Jwks.ALL);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import java.util.Arrays;

import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.jwk.Curve;
import com.nimbusds.jose.jwk.ECKey;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.KeyUse;
Expand Down Expand Up @@ -47,6 +49,10 @@ class TestJwk {
"hMSoV74FRtoaU7xpp0llsXbHE4oUseKoSNga-C_YIXuoGc3pajHh1WtJppZQNYM1Xy07nHchLJAdgqL2_q_Lk8cFHmmL1KTjwPflK9zZ9C0-8QTOrrqU9vkp3gT00jWWJ0HJbUvXIGxPGPnxoJoI--ToE0EWsYEWqWyx1TqYol--oUUPlY5r7vXRKIn5UZNz6VGkW8nI4fXaqDUpXH9uVM9A-nJX2B0Xjwu3VOn2zrgkCZeGTHjNgfLISOTFe9m8lHWLKcuxOWPuCZyCN0C6ZdWB1YP2NhxYFQwQfGV8yfnTImgL-DuV4WPSRVj7W_GJr213-oXBrBR0CnQEPbi_3w";
static final String RSA_1_E = "AQAB";

static final String EC_1_X = "K9a34L6QkEkWggKi700OyBCRghR2Xt-0Wym8qz0GdsM";
static final String EC_1_Y = "NaINav68UqRb9D9MWkUJZN6acnuOYSb2iWAVI05iKpw";
static final String EC_1_D = "lkTOA6MmmMDUg44V0coRAHbB5Zw-0748N8l8EOK8-5A";

static final String RSA_2_D =
"QQ18k_buZHOSVYzkXL1FaqdodZVNZ_hrBtDcmCVUYjm3dfDVQYt70h8LUdLUCSUA2-_VEwqVdQ-L2FTg7NZVvZJXIyQXp3yrdY1vGKebs3oaIB_VQT8jt-64s12r_8V2ksK2myRrvfm2Fgqi32H5QkspuaQYb9s4NJwKSk7mVAz5dRWQdCx9JNVWknWDJxgHzh3Uku1tNwUOyvSYcRnSZ9X7oWNHaHkSGLEYE_mxD7YXs6HEdCDwc3WuvR5AiVKg2OGec0lL1hY_AWX5UxnR00mhAa0qPytFfaPe-Sc5tQ5regQRqRNDyDESVGIvqXsY8ePjZPOFyoxrcJ2wN3bt4Q";
static final String RSA_2_N =
Expand All @@ -61,6 +67,9 @@ class TestJwk {

static final JWK RSA_1 = createRsa("kid/1", "RS256", RSA_1_E, RSA_1_N, RSA_1_D);

static final JWK EC_1 = createEc("kid/ec1", "ES256", EC_1_X, EC_1_Y, EC_1_D);
static final JWK EC_1_PUBLIC = createEc("kid/ec1", "ES256", EC_1_X, EC_1_Y, null);

static final JWK RSA_1_PUBLIC = createRsaPublic("kid/1", "RS256", RSA_1_E, RSA_1_N);
static final JWK RSA_1_PUBLIC_NO_ALG = createRsaPublic("kid/1", null, RSA_1_E, RSA_1_N);
static final JWK RSA_1_PUBLIC_WRONG_ALG = createRsaPublic("kid/1", "HS256", RSA_1_E, RSA_1_N);
Expand All @@ -74,9 +83,10 @@ class TestJwk {
static final JWKSet RSA_1_2_PUBLIC = createJwks(RSA_1_PUBLIC, RSA_2_PUBLIC);

static class Jwks {
static final JWKSet ALL = createJwks(OCT_1, OCT_2, OCT_3, FORWARD_SLASH_KID_OCT_1, RSA_1_PUBLIC, RSA_2_PUBLIC);
static final JWKSet ALL = createJwks(OCT_1, OCT_2, OCT_3, FORWARD_SLASH_KID_OCT_1, RSA_1_PUBLIC, RSA_2_PUBLIC, EC_1_PUBLIC);
static final JWKSet RSA_1 = createJwks(RSA_1_PUBLIC);
static final JWKSet RSA_2 = createJwks(RSA_2_PUBLIC);
static final JWKSet EC_1 = createJwks(EC_1_PUBLIC);
static final JWKSet RSA_1_NO_ALG = createJwks(RSA_1_PUBLIC_NO_ALG);
static final JWKSet RSA_1_WRONG_ALG = createJwks(RSA_1_PUBLIC_WRONG_ALG);
}
Expand Down Expand Up @@ -104,6 +114,18 @@ private static JWK createRsaPublic(String keyId, String algorithm, String e, Str
return createRsa(keyId, algorithm, e, n, null);
}

private static JWK createEc(String keyId, String algorithm, String x, String y, String d) {
ECKey.Builder builder = new ECKey.Builder(Curve.P_256, Base64URL.from(x), Base64URL.from(y)).keyUse(KeyUse.SIGNATURE)
.algorithm(algorithm == null ? null : JWSAlgorithm.parse(algorithm))
.keyID(keyId);

if (d != null) {
builder.d(Base64URL.from(d));
}

return builder.build();
}

private static JWKSet createJwks(JWK... array) {
return new JWKSet(Arrays.asList(array));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ class TestJwts {

static final String MC_COY_SIGNED_RSA_1 = createSigned(MC_COY, TestJwk.RSA_1);

static final String MC_COY_SIGNED_EC_1 = createSigned(MC_COY, TestJwk.EC_1);

static final String MC_COY_SIGNED_RSA_X = createSigned(MC_COY, TestJwk.RSA_X);

static final String MC_COY_EXPIRED_SIGNED_OCT_1 = createSigned(MC_COY_EXPIRED, TestJwk.OCT_1);
Expand Down
Loading