Skip to content
This repository was archived by the owner on Aug 1, 2024. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ and if not system properties and `META-INF/geronimo/microprofile/jwt-auth.proper
|geronimo.jwt-auth.jca.provider|The JCA provider (java security)|- (built-in one)
|geronimo.jwt-auth.groups.mapping|The mapping for the groups|-
|geronimo.jwt-auth.public-key.cache.active|Should public keys be cached|true
|geronimo.jwt-auth.jwks.invalidationInterval|Invalidation interval in seconds|-
|geronimo.jwt-auth.public-key.default|Default public key to verify JWT|-
|===

Expand Down
1 change: 1 addition & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@
<version>2.21.0</version>
<configuration>
<suiteXmlFiles>
<suiteXmlFile>${project.basedir}/src/test/resources/geronimo.xml</suiteXmlFile>
<suiteXmlFile>${project.basedir}/src/test/resources/tck.xml</suiteXmlFile>
</suiteXmlFiles>
</configuration>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.geronimo.microprofile.impl.jwtauth.jwt;

import java.math.BigInteger;
import java.security.AlgorithmParameters;
import java.security.KeyFactory;
import java.security.NoSuchAlgorithmException;
import java.security.PublicKey;
import java.security.spec.*;
import java.util.Base64;
import java.util.Base64.Decoder;

import javax.json.JsonObject;

import static java.security.KeyFactory.getInstance;
import static java.util.Optional.*;

public class JWK {

private String kid;
private String kty;
private String n;
private String e;
private String x;
private String y;
private String crv;
private String use;

public JWK(JsonObject jsonObject) {
kid = jsonObject.getString("kid", null);
kty = jsonObject.getString("kty", null);
x = jsonObject.getString("x", null);
y = jsonObject.getString("y", null);
crv = jsonObject.getString("crv", null);
n = jsonObject.getString("n", null);
e = jsonObject.getString("e", null);
use = jsonObject.getString("use", null);
}

public String getKid() {
return kid;
}

public String getUse() {
return use;
}


public String toPemKey() {
PublicKey publicKey = toPublicKey();
String base64PublicKey = Base64.getMimeEncoder(64, "\n".getBytes()).encodeToString(publicKey.getEncoded());
String result = "-----BEGIN PUBLIC KEY-----" + base64PublicKey + "-----END PUBLIC KEY-----";
return result.replace("\n", "");
}

public PublicKey toPublicKey() {
if ("RSA".equals(kty)) {
return toRSAPublicKey();
} else if("EC".equals(kty)) {
return toECPublicKey();
} else {
throw new IllegalStateException("Unsupported kty. Only RSA and EC are supported.");
}
}

private PublicKey toRSAPublicKey() {
Decoder decoder = Base64.getUrlDecoder();
BigInteger modulus = ofNullable(n).map(mod -> new BigInteger(1, decoder.decode(mod))).orElseThrow(() -> new IllegalStateException("n must be set for RSA keys."));
BigInteger exponent = ofNullable(e).map(exp -> new BigInteger(1, decoder.decode(exp))).orElseThrow(() -> new IllegalStateException("e must be set for RSA keys."));
RSAPublicKeySpec spec = new RSAPublicKeySpec(modulus, exponent);
try {
KeyFactory factory = getInstance("RSA");
return factory.generatePublic(spec);
} catch (NoSuchAlgorithmException | InvalidKeySpecException e) {
throw new IllegalStateException(e);
}
}

private PublicKey toECPublicKey() {
Decoder decoder = Base64.getUrlDecoder();
BigInteger pointX = ofNullable(x).map(x -> new BigInteger(1, decoder.decode(x))).orElseThrow(() -> new IllegalStateException("x must be set for EC keys."));
BigInteger pointY = ofNullable(y).map(y -> new BigInteger(1, decoder.decode(y))).orElseThrow(() -> new IllegalStateException("y must be set for EC keys."));
ECPoint pubPoint = new ECPoint(pointX, pointY);
try {
AlgorithmParameters parameters = AlgorithmParameters.getInstance("EC");
parameters.init(ofNullable(crv).map(JWK::mapCrv).map(ECGenParameterSpec::new).orElseThrow(() -> new IllegalStateException("crv must be set for EC keys.")));
ECParameterSpec ecParameters = parameters.getParameterSpec(ECParameterSpec.class);
return getInstance("EC").generatePublic(new ECPublicKeySpec(pubPoint, ecParameters));
} catch (NoSuchAlgorithmException | InvalidParameterSpecException | InvalidKeySpecException e) {
throw new IllegalStateException(e);
}
}

private static String mapCrv(String crv) {
if (crv.endsWith("256")) {
return "secp256r1";
} else if (crv.endsWith("384")) {
return "secp384r1";
} else {
return "secp521r1";
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,43 +16,55 @@
*/
package org.apache.geronimo.microprofile.impl.jwtauth.jwt;

import static java.util.Optional.ofNullable;
import static java.util.stream.Collectors.joining;
import org.apache.geronimo.microprofile.impl.jwtauth.config.GeronimoJwtAuthConfig;
import org.apache.geronimo.microprofile.impl.jwtauth.io.PropertiesLoader;
import org.eclipse.microprofile.jwt.config.Names;

import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
import javax.enterprise.context.ApplicationScoped;
import javax.inject.Inject;
import javax.json.Json;
import javax.json.JsonArray;
import javax.json.JsonObject;
import javax.json.JsonReader;
import javax.json.JsonReaderFactory;
import javax.json.JsonValue;
import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.StringReader;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.file.Files;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.*;
import java.util.concurrent.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import javax.annotation.PostConstruct;
import javax.enterprise.context.ApplicationScoped;
import javax.inject.Inject;

import org.apache.geronimo.microprofile.impl.jwtauth.config.GeronimoJwtAuthConfig;
import org.apache.geronimo.microprofile.impl.jwtauth.io.PropertiesLoader;
import org.eclipse.microprofile.jwt.config.Names;
import static java.util.Collections.emptyMap;
import static java.util.Optional.ofNullable;
import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.toList;

@ApplicationScoped
public class KidMapper {
@Inject
private GeronimoJwtAuthConfig config;

private final ConcurrentMap<String, String> keyMapping = new ConcurrentHashMap<>();
private volatile ConcurrentMap<String, String> keyMapping = new ConcurrentHashMap<>();
private final Map<String, Collection<String>> issuerMapping = new HashMap<>();
private String defaultKey;
private String jwksUrl;
private Set<String> defaultIssuers;

private JsonReaderFactory readerFactory;
private volatile CompletableFuture<Void> reloadJwksRequest;
HttpClient httpClient;
ScheduledExecutorService backgroundThread;
@PostConstruct
private void init() {
ofNullable(config.read("kids.key.mapping", null))
Expand All @@ -79,9 +91,48 @@ private void init() {
.collect(Collectors.toSet()))
.orElseGet(HashSet::new);
ofNullable(config.read("issuer.default", config.read(Names.ISSUER, null))).ifPresent(defaultIssuers::add);
jwksUrl = config.read("mp.jwt.verify.publickey.location", null);
readerFactory = Json.createReaderFactory(emptyMap());
ofNullable(jwksUrl).ifPresent(url -> {
HttpClient.Builder builder = HttpClient.newBuilder();
if (getJwksRefreshInterval() != null) {
long secondsRefresh = getJwksRefreshInterval();
backgroundThread = Executors.newSingleThreadScheduledExecutor();
builder.executor(backgroundThread);
backgroundThread.scheduleAtFixedRate(this::reloadRemoteKeys, getJwksRefreshInterval(), secondsRefresh, TimeUnit.SECONDS );
}
httpClient = builder.build();
reloadJwksRequest = reloadRemoteKeys();// inital load, otherwise the background thread is too slow to start and serve
});
defaultKey = config.read("public-key.default", config.read(Names.VERIFIER_PUBLIC_KEY, null));
}

private Integer getJwksRefreshInterval() {
String interval = config.read("jwks.invalidation.interval",null);
if (interval != null) {
return Integer.parseInt(interval);
} else {
return null;
}
}

private CompletableFuture<Void> reloadRemoteKeys() {
HttpRequest request = HttpRequest.newBuilder().GET().uri(URI.create(jwksUrl)).header("Accept", "application/json").build();
CompletableFuture<HttpResponse<String>> httpResponseCompletableFuture = httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString());
CompletableFuture<Void> ongoingRequest = httpResponseCompletableFuture.thenApply(result -> {
List<JWK> jwks = parseKeys(result);
ConcurrentHashMap<String, String> newKeys = new ConcurrentHashMap<>();
jwks.forEach(key -> newKeys.put(key.getKid(), key.toPemKey()));
keyMapping = newKeys;
return null;
});

ongoingRequest.thenRun(() -> {
reloadJwksRequest = ongoingRequest;
});
return ongoingRequest;
}

public String loadKey(final String property) {
String value = keyMapping.get(property);
if (value == null) {
Expand Down Expand Up @@ -120,7 +171,44 @@ private String tryLoad(final String value) {
throw new IllegalArgumentException(e);
}

// else direct value
// load jwks via url
if (jwksUrl != null) {
if(reloadJwksRequest != null) {
try {
reloadJwksRequest.get();
} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
}
}
String key = keyMapping.get(value);
if (key != null) {
return key;
}

}
return value;
}

private List<JWK> parseKeys(HttpResponse<String> keyResponse) {
StringReader stringReader = new StringReader(keyResponse.body());
JsonReader jwksReader = readerFactory.createReader(stringReader);
JsonObject keySet = jwksReader.readObject();
JsonArray keys = keySet.getJsonArray("keys");
return keys.stream()
.map(JsonValue::asJsonObject)
.map(JWK::new)
.filter(it -> it.getUse() == null || "sig".equals(it.getUse()))
.collect(toList());
}

@PreDestroy
private void destroy() {
if (backgroundThread != null) {
backgroundThread.shutdown();
}
if (reloadJwksRequest != null) {
reloadJwksRequest.cancel(true);
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package org.apache.geronimo.microprofile.impl.jwtauth.jwt;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.net.ServerSocket;
import java.net.Socket;

class JwksServer {

private static final String HEADER = "HTTP/1.0 200 OK\r\nConnection: close\r\n";

private ServerSocket serverSocket;

JwksServer() throws IOException {
serverSocket = new ServerSocket(0);
}

int getPort() {
return serverSocket.getLocalPort();
}

void start() {
Thread server = new Thread(() -> {
while (!serverSocket.isClosed()) {
try (Socket client = serverSocket.accept();
BufferedReader request = new BufferedReader(new InputStreamReader(client.getInputStream()));
BufferedReader reader = new BufferedReader(new InputStreamReader(
getClass().getResourceAsStream(request.readLine().split("\\s")[1])));
PrintWriter writer = new PrintWriter(client.getOutputStream())) {

writer.println(HEADER);
writer.print(load(reader));
} catch (IOException e) {
if (!serverSocket.isClosed()) {
e.printStackTrace(System.err);
}
}
}
});
server.start();
}

void stop() throws IOException {
serverSocket.close();
}

private String load(BufferedReader reader) throws IOException {
StringBuilder content = new StringBuilder();
for (String line = reader.readLine(); line != null; line = reader.readLine()) {
content.append(line).append("\r\n");
}
return content.toString();
}
}
Loading