From 035c48ab7186180dff7b31556ee738ac29162471 Mon Sep 17 00:00:00 2001 From: "nicolas.fraison@datadoghq.com" Date: Tue, 5 Aug 2025 14:58:09 +0200 Subject: [PATCH 1/4] [FLINK-37504] Add ssl certificate reload mechanism --- .../flink/configuration/SecurityOptions.java | 15 ++ .../security/watch/LocalFSWatchService.java | 73 +++++++ .../watch/LocalFSWatchServiceListener.java | 30 +++ .../security/watch/LocalFSWatchSingleton.java | 65 +++++++ .../rpc/pekko/CustomSSLEngineProvider.java | 98 +++++----- .../flink/runtime/rpc/pekko/PekkoUtils.java | 4 + .../runtime/rpc/pekko/SSLContextLoader.java | 180 ++++++++++++++++++ flink-runtime/pom.xml | 2 +- .../apache/flink/runtime/blob/BlobServer.java | 116 +++++------ .../flink/runtime/blob/BlobServerSocket.java | 166 ++++++++++++++++ .../runtime/entrypoint/ClusterEntrypoint.java | 8 + .../runtime/net/ReloadableJdkSslContext.java | 84 ++++++++ .../runtime/net/ReloadableSslContext.java | 164 ++++++++++++++++ .../apache/flink/runtime/net/SSLUtils.java | 155 ++++++++------- .../taskexecutor/TaskManagerRunner.java | 8 + 15 files changed, 973 insertions(+), 195 deletions(-) create mode 100644 flink-core/src/main/java/org/apache/flink/core/security/watch/LocalFSWatchService.java create mode 100644 flink-core/src/main/java/org/apache/flink/core/security/watch/LocalFSWatchServiceListener.java create mode 100644 flink-core/src/main/java/org/apache/flink/core/security/watch/LocalFSWatchSingleton.java create mode 100644 flink-rpc/flink-rpc-akka/src/main/java/org/apache/flink/runtime/rpc/pekko/SSLContextLoader.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobServerSocket.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/net/ReloadableJdkSslContext.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/net/ReloadableSslContext.java diff --git a/flink-core/src/main/java/org/apache/flink/configuration/SecurityOptions.java 
b/flink-core/src/main/java/org/apache/flink/configuration/SecurityOptions.java index 24ecf2f6eaf1e..b856847b9c6a7 100644 --- a/flink-core/src/main/java/org/apache/flink/configuration/SecurityOptions.java +++ b/flink-core/src/main/java/org/apache/flink/configuration/SecurityOptions.java @@ -618,6 +618,16 @@ public static Configuration forProvider(Configuration configuration, String prov + "forcibly. (-1 = use system default)") .withDeprecatedKeys("security.ssl.close-notify-flush-timeout"); + // TODO: check that all documentation is updated (explain the reload mechanism) + /** Indicates whether keystore/truststore changes should lead to a certificate reload. */ + @Documentation.Section(Documentation.Sections.SECURITY_SSL) + public static final ConfigOption SSL_RELOAD = + key("security.ssl.reload") + .booleanType() + .defaultValue(false) + .withDescription( + "Indicate if changes on keystore/truststore should leads to reload of the certificate."); + /** * Checks whether SSL for internal communication (rpc, data transport, blob server) is enabled. */ @@ -635,4 +645,9 @@ public static boolean isRestSSLAuthenticationEnabled(Configuration sslConfig) { checkNotNull(sslConfig, "sslConfig"); return isRestSSLEnabled(sslConfig) && sslConfig.get(SSL_REST_AUTHENTICATION_ENABLED); } + + /** Checks whether certificates must be reloaded in case of keystore or truststore changes. */ + public static boolean isReloadCertificate(Configuration sslConfig) { + return sslConfig.get(SSL_RELOAD); + } } diff --git a/flink-core/src/main/java/org/apache/flink/core/security/watch/LocalFSWatchService.java b/flink-core/src/main/java/org/apache/flink/core/security/watch/LocalFSWatchService.java new file mode 100644 index 0000000000000..3f61bf06569cf --- /dev/null +++ b/flink-core/src/main/java/org/apache/flink/core/security/watch/LocalFSWatchService.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.core.security.watch; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.file.Path; +import java.nio.file.WatchEvent; +import java.nio.file.WatchKey; +import java.nio.file.WatchService; +import java.util.Map; + +import static java.nio.file.StandardWatchEventKinds.ENTRY_CREATE; +import static java.nio.file.StandardWatchEventKinds.ENTRY_DELETE; +import static java.nio.file.StandardWatchEventKinds.ENTRY_MODIFY; +import static java.nio.file.StandardWatchEventKinds.OVERFLOW; + +public class LocalFSWatchService extends Thread { + private static final Logger LOG = LoggerFactory.getLogger(LocalFSWatchService.class); + + public void run() { + try { + while (true) { + for (Map.Entry entry : + LocalFSWatchSingleton.getInstance().watchers.entrySet()) { + LOG.debug("Taking watch key"); + WatchKey watchKey = entry.getKey().poll(); + if (watchKey == null) { + continue; + } + LOG.debug("Watch key arrived"); + for (WatchEvent watchEvent : watchKey.pollEvents()) { + System.out.println(watchEvent.kind()); + System.out.println(watchEvent.context()); + if (watchEvent.kind() == OVERFLOW) { + LOG.error("Filesystem events may have been lost or discarded"); + Thread.yield(); + } else if (watchEvent.kind() == ENTRY_CREATE) { + 
entry.getValue().onFileOrDirectoryCreated((Path) watchEvent.context()); + } else if (watchEvent.kind() == ENTRY_DELETE) { + entry.getValue().onFileOrDirectoryDeleted((Path) watchEvent.context()); + } else if (watchEvent.kind() == ENTRY_MODIFY) { + entry.getValue().onFileOrDirectoryModified((Path) watchEvent.context()); + } else { + LOG.warn("Unhandled watch event {}", watchEvent.kind()); + } + } + watchKey.reset(); + } + } + } catch (Exception e) { + LOG.error("Filesystem watcher received exception and stopped: ", e); + throw new RuntimeException(e); + } + } +} diff --git a/flink-core/src/main/java/org/apache/flink/core/security/watch/LocalFSWatchServiceListener.java b/flink-core/src/main/java/org/apache/flink/core/security/watch/LocalFSWatchServiceListener.java new file mode 100644 index 0000000000000..33c3c0b107509 --- /dev/null +++ b/flink-core/src/main/java/org/apache/flink/core/security/watch/LocalFSWatchServiceListener.java @@ -0,0 +1,30 @@ +package org.apache.flink.core.security.watch; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import java.nio.file.Path; + +public interface LocalFSWatchServiceListener { + default void onWatchStarted(Path realDirectoryPath) {} + + default void onFileOrDirectoryCreated(Path relativePath) {} + + default void onFileOrDirectoryDeleted(Path relativePath) {} + + default void onFileOrDirectoryModified(Path relativePath) {} +} diff --git a/flink-core/src/main/java/org/apache/flink/core/security/watch/LocalFSWatchSingleton.java b/flink-core/src/main/java/org/apache/flink/core/security/watch/LocalFSWatchSingleton.java new file mode 100644 index 0000000000000..23e55860ab370 --- /dev/null +++ b/flink-core/src/main/java/org/apache/flink/core/security/watch/LocalFSWatchSingleton.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.core.security.watch; + +import java.io.IOException; +import java.nio.file.FileSystems; +import java.nio.file.Path; +import java.nio.file.WatchService; +import java.util.concurrent.ConcurrentHashMap; + +import static java.nio.file.StandardWatchEventKinds.ENTRY_CREATE; +import static java.nio.file.StandardWatchEventKinds.ENTRY_DELETE; +import static java.nio.file.StandardWatchEventKinds.ENTRY_MODIFY; + +public final class LocalFSWatchSingleton { + // The field must be declared volatile so that double check lock would work + // correctly. + private static volatile LocalFSWatchSingleton instance; + + ConcurrentHashMap watchers = + new ConcurrentHashMap<>(); + + private LocalFSWatchSingleton() {} + + public static LocalFSWatchSingleton getInstance() { + LocalFSWatchSingleton result = instance; + if (result != null) { + return result; + } + synchronized (LocalFSWatchSingleton.class) { + if (instance == null) { + instance = new LocalFSWatchSingleton(); + } + return instance; + } + } + + public void registerPath(Path[] pathsToWatch, LocalFSWatchServiceListener callback) + throws IOException { + + WatchService watcher = FileSystems.getDefault().newWatchService(); + for (Path pathToWatch : pathsToWatch) { + Path realDirectoryPath = pathToWatch.toRealPath(); + realDirectoryPath.register(watcher, ENTRY_CREATE, ENTRY_DELETE, ENTRY_MODIFY); + } + callback.onWatchStarted(pathsToWatch[0]); + watchers.put(watcher, callback); + } +} diff --git a/flink-rpc/flink-rpc-akka/src/main/java/org/apache/flink/runtime/rpc/pekko/CustomSSLEngineProvider.java b/flink-rpc/flink-rpc-akka/src/main/java/org/apache/flink/runtime/rpc/pekko/CustomSSLEngineProvider.java index e2ea1801ea860..4b1d318e96c4d 100644 --- a/flink-rpc/flink-rpc-akka/src/main/java/org/apache/flink/runtime/rpc/pekko/CustomSSLEngineProvider.java +++ b/flink-rpc/flink-rpc-akka/src/main/java/org/apache/flink/runtime/rpc/pekko/CustomSSLEngineProvider.java @@ -17,86 +17,78 @@ package 
org.apache.flink.runtime.rpc.pekko; +import org.apache.flink.core.security.watch.LocalFSWatchSingleton; + import org.apache.flink.shaded.netty4.io.netty.handler.ssl.util.FingerprintTrustManagerFactory; import com.typesafe.config.Config; import org.apache.pekko.actor.ActorSystem; -import org.apache.pekko.remote.RemoteTransportException; import org.apache.pekko.remote.transport.netty.ConfigSSLEngineProvider; +import org.apache.pekko.remote.transport.netty.SSLEngineProvider; +import org.apache.pekko.stream.TLSRole; -import javax.net.ssl.TrustManager; -import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.SSLEngine; import java.io.IOException; -import java.io.InputStream; -import java.nio.file.Files; -import java.nio.file.Paths; -import java.security.GeneralSecurityException; -import java.security.KeyStore; +import java.nio.file.Path; import java.util.List; /** * Extension of the {@link ConfigSSLEngineProvider} to use a {@link FingerprintTrustManagerFactory}. */ @SuppressWarnings("deprecation") -public class CustomSSLEngineProvider extends ConfigSSLEngineProvider { +public class CustomSSLEngineProvider implements SSLEngineProvider { + private final String sslTrustStore; - private final String sslTrustStorePassword; - private final List sslCertFingerprints; - private final String sslKeyStoreType; - private final String sslTrustStoreType; + private final List sslEnabledAlgorithms; + private final String sslProtocol; + private final Boolean sslRequireMutualAuthentication; + private final SSLContextLoader sslContextLoader; - public CustomSSLEngineProvider(ActorSystem system) { - super(system); + public CustomSSLEngineProvider(ActorSystem system) throws IOException { final Config securityConfig = system.settings().config().getConfig("pekko.remote.classic.netty.ssl.security"); sslTrustStore = securityConfig.getString("trust-store"); - sslTrustStorePassword = securityConfig.getString("trust-store-password"); - sslCertFingerprints = 
securityConfig.getStringList("cert-fingerprints"); - sslKeyStoreType = securityConfig.getString("key-store-type"); - sslTrustStoreType = securityConfig.getString("trust-store-type"); + String sslKeyStore = securityConfig.getString("key-store"); + sslEnabledAlgorithms = securityConfig.getStringList("enabled-algorithms"); + sslProtocol = securityConfig.getString("protocol"); + sslRequireMutualAuthentication = securityConfig.getBoolean("require-mutual-authentication"); + Boolean sslEnabledCertReload = securityConfig.getBoolean("enabled-cert-reload"); + + sslContextLoader = new SSLContextLoader(sslTrustStore, sslProtocol, securityConfig); + if (sslEnabledCertReload) { + LocalFSWatchSingleton localFSWatchSingleton = LocalFSWatchSingleton.getInstance(); + localFSWatchSingleton.registerPath( + new Path[] { + Path.of(sslTrustStore).getParent(), Path.of(sslKeyStore).getParent() + }, + sslContextLoader); + } } @Override - public TrustManager[] trustManagers() { - try { - final TrustManagerFactory trustManagerFactory = - sslCertFingerprints.isEmpty() - ? 
TrustManagerFactory.getInstance( - TrustManagerFactory.getDefaultAlgorithm()) - : FingerprintTrustManagerFactory.builder("SHA1") - .fingerprints(sslCertFingerprints) - .build(); - - trustManagerFactory.init( - loadKeystore(sslTrustStore, sslTrustStorePassword, sslTrustStoreType)); - return trustManagerFactory.getTrustManagers(); - } catch (GeneralSecurityException | IOException e) { - // replicate exception handling from SSLEngineProvider - throw new RemoteTransportException( - "Server SSL connection could not be established because SSL context could not be constructed", - e); - } + public SSLEngine createServerSSLEngine() { + return createSSLEngine(TLSRole.server()); } @Override - public KeyStore loadKeystore(String filename, String password) { - try { - return loadKeystore(filename, password, sslKeyStoreType); - } catch (IOException | GeneralSecurityException e) { - throw new RemoteTransportException( - "Server SSL connection could not be established because key store could not be loaded", - e); - } + public SSLEngine createClientSSLEngine() { + return createSSLEngine(TLSRole.client()); + } + + private SSLEngine createSSLEngine(TLSRole role) { + return createSSLEngine(sslContextLoader.createSSLEngine(), role); } - private KeyStore loadKeystore(String filename, String password, String keystoreType) - throws IOException, GeneralSecurityException { - KeyStore keyStore = KeyStore.getInstance(keystoreType); - try (InputStream fin = Files.newInputStream(Paths.get(filename))) { - char[] passwordCharArray = password.toCharArray(); - keyStore.load(fin, passwordCharArray); + private SSLEngine createSSLEngine(SSLEngine engine, TLSRole role) { + engine.setUseClientMode(role == TLSRole.client()); + engine.setEnabledCipherSuites(sslEnabledAlgorithms.toArray(String[]::new)); + engine.setEnabledProtocols(new String[] {sslProtocol}); + + if ((role != TLSRole.client()) && sslRequireMutualAuthentication) { + engine.setNeedClientAuth(true); } - return keyStore; + + return engine; } 
} diff --git a/flink-rpc/flink-rpc-akka/src/main/java/org/apache/flink/runtime/rpc/pekko/PekkoUtils.java b/flink-rpc/flink-rpc-akka/src/main/java/org/apache/flink/runtime/rpc/pekko/PekkoUtils.java index 2825c834ffc5b..aef433b7affd0 100644 --- a/flink-rpc/flink-rpc-akka/src/main/java/org/apache/flink/runtime/rpc/pekko/PekkoUtils.java +++ b/flink-rpc/flink-rpc-akka/src/main/java/org/apache/flink/runtime/rpc/pekko/PekkoUtils.java @@ -359,6 +359,9 @@ private static void addSslRemoteConfig( Arrays.stream(sslAlgorithmsString.split(",")) .collect(Collectors.joining(",", "[", "]")); + final boolean enabledCertReloadConfig = SecurityOptions.isReloadCertificate(configuration); + final String enabledCertReload = booleanToOnOrOff(enabledCertReloadConfig); + final String sslEngineProviderName = CustomSSLEngineProvider.class.getCanonicalName(); configBuilder @@ -383,6 +386,7 @@ private static void addSslRemoteConfig( .add(" random-number-generator = \"\"") .add(" require-mutual-authentication = on") .add(" cert-fingerprints = " + sslCertFingerprints + "") + .add(" enabled-cert-reload = " + enabledCertReload + "") .add(" }") .add(" }") .add(" }") diff --git a/flink-rpc/flink-rpc-akka/src/main/java/org/apache/flink/runtime/rpc/pekko/SSLContextLoader.java b/flink-rpc/flink-rpc-akka/src/main/java/org/apache/flink/runtime/rpc/pekko/SSLContextLoader.java new file mode 100644 index 0000000000000..e56bb534402d1 --- /dev/null +++ b/flink-rpc/flink-rpc-akka/src/main/java/org/apache/flink/runtime/rpc/pekko/SSLContextLoader.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.rpc.pekko; + +import org.apache.flink.core.security.watch.LocalFSWatchServiceListener; + +import org.apache.flink.shaded.netty4.io.netty.handler.ssl.util.FingerprintTrustManagerFactory; + +import com.typesafe.config.Config; +import org.apache.pekko.remote.RemoteTransportException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.KeyManager; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.security.GeneralSecurityException; +import java.security.KeyManagementException; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; +import java.security.UnrecoverableKeyException; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; + +public class SSLContextLoader implements LocalFSWatchServiceListener { + + private static final Logger LOG = LoggerFactory.getLogger(SSLContextLoader.class); + + private final String sslTrustStore; + private final String sslTrustStorePassword; + private final List sslCertFingerprints; + private final String sslKeyStoreType; + private final String sslTrustStoreType; + private final String sslProtocol; + private final String sslKeyStore; + private 
final String sslKeyStorePassword; + private final String sslKeyPassword; + private final String sslRandomNumberGenerator; + + private final AtomicBoolean toReload = new AtomicBoolean(false); + + private volatile SSLContext sslContext; + + public SSLContextLoader(String sslTrustStore, String sslProtocol, Config securityConfig) { + this.sslTrustStore = sslTrustStore; + this.sslProtocol = sslProtocol; + + this.sslTrustStorePassword = securityConfig.getString("trust-store-password"); + this.sslCertFingerprints = securityConfig.getStringList("cert-fingerprints"); + this.sslKeyStoreType = securityConfig.getString("key-store-type"); + this.sslTrustStoreType = securityConfig.getString("trust-store-type"); + this.sslKeyStore = securityConfig.getString("key-store"); + sslKeyStorePassword = securityConfig.getString("key-store-password"); + sslKeyPassword = securityConfig.getString("key-password"); + sslRandomNumberGenerator = securityConfig.getString("random-number-generator"); + + loadSSLContext(); + } + + void loadSSLContext() { + SSLContext ctx; + try { + LOG.debug("Loading SSL context for pekko"); + SecureRandom rng = createSecureRandom(); + ctx = SSLContext.getInstance(sslProtocol); + ctx.init(keyManagers(), trustManagers(), rng); + } catch (KeyManagementException + | NoSuchAlgorithmException + | UnrecoverableKeyException + | KeyStoreException e) { + throw new RuntimeException("Cannot load SSL context", e); + } + + this.sslContext = ctx; + } + + public SSLEngine createSSLEngine() { + reloadContextIfNeeded(); + return sslContext.createSSLEngine(); + } + + public SecureRandom createSecureRandom() throws NoSuchAlgorithmException { + SecureRandom rng; + if ("".equals(sslRandomNumberGenerator)) { + rng = new SecureRandom(); + } else { + rng = SecureRandom.getInstance(sslRandomNumberGenerator); + } + rng.nextInt(); + return rng; + } + + @Override + public void onFileOrDirectoryModified(Path relativePath) { + toReload.set(true); + } + + private synchronized void 
reloadContextIfNeeded() { + if (toReload.compareAndSet(true, false)) { + loadSSLContext(); + } + } + + /** Creates the key managers from the key store loaded via {@code loadKeystore}. */ + private KeyManager[] keyManagers() + throws NoSuchAlgorithmException, UnrecoverableKeyException, KeyStoreException { + KeyManagerFactory factory = + KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + factory.init(loadKeystore(sslKeyStore, sslKeyStorePassword), sslKeyPassword.toCharArray()); + return factory.getKeyManagers(); + } + + public TrustManager[] trustManagers() { + try { + final TrustManagerFactory trustManagerFactory = + sslCertFingerprints.isEmpty() + ? TrustManagerFactory.getInstance( + TrustManagerFactory.getDefaultAlgorithm()) + : FingerprintTrustManagerFactory.builder("SHA1") + .fingerprints(sslCertFingerprints) + .build(); + + trustManagerFactory.init( + loadKeystore(sslTrustStore, sslTrustStorePassword, sslTrustStoreType)); + return trustManagerFactory.getTrustManagers(); + } catch (GeneralSecurityException | IOException e) { + // replicate exception handling from SSLEngineProvider + throw new RemoteTransportException( + "Server SSL connection could not be established because SSL context could not be constructed", + e); + } + } + + public KeyStore loadKeystore(String filename, String password) { + try { + return loadKeystore(filename, password, sslKeyStoreType); + } catch (IOException | GeneralSecurityException e) { + throw new RemoteTransportException( + "Server SSL connection could not be established because key store could not be loaded", + e); + } + } + + private KeyStore loadKeystore(String filename, String password, String keystoreType) + throws IOException, GeneralSecurityException { + KeyStore keyStore = KeyStore.getInstance(keystoreType); + try (InputStream fin = Files.newInputStream(Paths.get(filename))) { + char[] passwordCharArray = password.toCharArray(); + keyStore.load(fin, passwordCharArray); + } + return keyStore; + } +} diff --git 
a/flink-runtime/pom.xml b/flink-runtime/pom.xml index 61a2a5beac4ce..e0f11e4f2deaa 100644 --- a/flink-runtime/pom.xml +++ b/flink-runtime/pom.xml @@ -442,7 +442,7 @@ under the License. - + diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobServer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobServer.java index 758e7924afc53..52b6fd41cdb19 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobServer.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobServer.java @@ -23,11 +23,10 @@ import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.configuration.BlobServerOptions; import org.apache.flink.configuration.Configuration; -import org.apache.flink.configuration.JobManagerOptions; import org.apache.flink.configuration.SecurityOptions; +import org.apache.flink.core.security.watch.LocalFSWatchSingleton; import org.apache.flink.runtime.dispatcher.cleanup.GloballyCleanableResource; import org.apache.flink.runtime.dispatcher.cleanup.LocallyCleanableResource; -import org.apache.flink.runtime.net.SSLUtils; import org.apache.flink.util.ExceptionUtils; import org.apache.flink.util.FileUtils; import org.apache.flink.util.NetUtils; @@ -42,7 +41,6 @@ import javax.annotation.Nullable; import javax.annotation.concurrent.GuardedBy; -import javax.net.ServerSocketFactory; import java.io.File; import java.io.FileNotFoundException; @@ -53,11 +51,11 @@ import java.net.InetSocketAddress; import java.net.ServerSocket; import java.net.UnknownHostException; +import java.nio.file.Path; import java.security.MessageDigest; import java.util.ArrayList; import java.util.Collection; import java.util.HashSet; -import java.util.Iterator; import java.util.List; import java.util.Set; import java.util.Timer; @@ -99,7 +97,7 @@ public class BlobServer extends Thread /** The server socket listening for incoming connections. 
*/ // can be null if BlobServer is shut down before constructor completion - @Nullable private final ServerSocket serverSocket; + @Nullable private final BlobServerSocket blobServerSocket; /** Blob Server configuration. */ private final Configuration blobServiceConfiguration; @@ -200,51 +198,32 @@ public BlobServer(Configuration config, Reference storageDir, BlobStore bl // ----------------------- start the server ------------------- - final String serverPortRange = config.get(BlobServerOptions.PORT); - final Iterator ports = NetUtils.getPortRangeFromString(serverPortRange); - - final ServerSocketFactory socketFactory; + blobServerSocket = new BlobServerSocket(config, backlog, maxConnections); if (SecurityOptions.isInternalSSLEnabled(config) - && config.get(BlobServerOptions.SSL_ENABLED)) { - try { - socketFactory = SSLUtils.createSSLServerSocketFactory(config); - } catch (Exception e) { - throw new IOException("Failed to initialize SSL for the blob server", e); - } - } else { - socketFactory = ServerSocketFactory.getDefault(); - } - - final int finalBacklog = backlog; - final String bindHost = - config.getOptional(JobManagerOptions.BIND_HOST) - .orElseGet(NetUtils::getWildcardIPAddress); - - this.serverSocket = - NetUtils.createSocketFromPorts( - ports, - (port) -> - socketFactory.createServerSocket( - port, finalBacklog, InetAddress.getByName(bindHost))); - - if (serverSocket == null) { - throw new IOException( - "Unable to open BLOB Server in specified port range: " + serverPortRange); + && config.get(BlobServerOptions.SSL_ENABLED) + && SecurityOptions.isReloadCertificate(config)) { + String keystoreFilePath = + config.get( + SecurityOptions.SSL_INTERNAL_KEYSTORE, + config.get(SecurityOptions.SSL_KEYSTORE)); + String truststoreFilePath = + config.get( + SecurityOptions.SSL_INTERNAL_TRUSTSTORE, + config.get(SecurityOptions.SSL_TRUSTSTORE)); + + LocalFSWatchSingleton localFSWatchSingleton = LocalFSWatchSingleton.getInstance(); + localFSWatchSingleton.registerPath( 
+ new Path[] { + Path.of(keystoreFilePath).getParent(), + Path.of(truststoreFilePath).getParent() + }, + blobServerSocket); } // start the server thread setName("BLOB Server listener at " + getPort()); setDaemon(true); - if (LOG.isInfoEnabled()) { - LOG.info( - "Started BLOB server at {}:{} - max concurrent requests: {} - max backlog: {}", - serverSocket.getInetAddress().getHostAddress(), - getPort(), - maxConnections, - backlog); - } - checkStoredBlobsForCorruption(); registerBlobExpiryTimes(); } @@ -314,8 +293,13 @@ ReadWriteLock getReadWriteLock() { public void run() { try { while (!this.shutdownRequested.get()) { + if (this.blobServerSocket.reloadContextIfNeeded()) { + closeActiveConnections(); + } BlobServerConnection conn = - new BlobServerConnection(NetUtils.acceptWithoutTimeout(serverSocket), this); + new BlobServerConnection( + NetUtils.acceptWithoutTimeout(blobServerSocket.getServerSocket()), + this); try { synchronized (activeConnections) { while (activeConnections.size() >= maxConnections) { @@ -356,9 +340,9 @@ public void close() throws IOException { if (shutdownRequested.compareAndSet(false, true)) { Exception exception = null; - if (serverSocket != null) { + if (blobServerSocket != null) { try { - this.serverSocket.close(); + this.blobServerSocket.close(); } catch (IOException ioe) { exception = ioe; } @@ -375,15 +359,7 @@ public void close() throws IOException { LOG.debug("Error while waiting for this thread to die.", ie); } - synchronized (activeConnections) { - if (!activeConnections.isEmpty()) { - for (BlobServerConnection conn : activeConnections) { - LOG.debug("Shutting down connection {}.", conn.getName()); - conn.close(); - } - activeConnections.clear(); - } - } + closeActiveConnections(); // Clean up the storage directory if it is owned try { @@ -397,24 +373,26 @@ public void close() throws IOException { // Remove shutdown hook to prevent resource leaks ShutdownHookUtil.removeShutdownHook(shutdownHook, getClass().getSimpleName(), LOG); - if 
(LOG.isInfoEnabled()) { - if (serverSocket != null) { - LOG.info( - "Stopped BLOB server at {}:{}", - serverSocket.getInetAddress().getHostAddress(), - getPort()); - } else { - LOG.info("Stopped BLOB server before initializing the socket"); + ExceptionUtils.tryRethrowIOException(exception); + } + } + + private void closeActiveConnections() { + synchronized (activeConnections) { + if (!activeConnections.isEmpty()) { + for (BlobServerConnection conn : activeConnections) { + LOG.debug("Shutting down connection {}.", conn.getName()); + conn.close(); } + activeConnections.clear(); } - - ExceptionUtils.tryRethrowIOException(exception); } } protected BlobClient createClient() throws IOException { return new BlobClient( - new InetSocketAddress(serverSocket.getInetAddress(), getPort()), + new InetSocketAddress( + blobServerSocket.getServerSocket().getInetAddress(), getPort()), blobServiceConfiguration); } @@ -1007,7 +985,7 @@ public final int getMinOffloadingSize() { */ @Override public int getPort() { - return this.serverSocket.getLocalPort(); + return this.blobServerSocket.getPort(); } /** @@ -1017,7 +995,7 @@ public int getPort() { */ @Override public InetAddress getAddress() { - InetAddress bindAddr = serverSocket.getInetAddress(); + InetAddress bindAddr = blobServerSocket.getServerSocket().getInetAddress(); if (bindAddr.getHostAddress().equals(NetUtils.getWildcardIPAddress())) { try { return InetAddress.getLocalHost(); @@ -1049,7 +1027,7 @@ public boolean isShutdown() { /** Access to the server socket, for testing. 
*/ ServerSocket getServerSocket() { - return this.serverSocket; + return this.blobServerSocket.getServerSocket(); } void unregisterConnection(BlobServerConnection conn) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobServerSocket.java b/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobServerSocket.java new file mode 100644 index 0000000000000..259c73390a9cc --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobServerSocket.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.flink.runtime.blob;
+
+import org.apache.flink.configuration.BlobServerOptions;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.JobManagerOptions;
+import org.apache.flink.configuration.SecurityOptions;
+import org.apache.flink.core.security.watch.LocalFSWatchServiceListener;
+import org.apache.flink.runtime.net.SSLUtils;
+import org.apache.flink.util.NetUtils;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.net.ServerSocketFactory;
+
+import java.io.IOException;
+import java.net.InetAddress;
+import java.net.ServerSocket;
+import java.nio.file.Path;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+/**
+ * This class implements socket management (open, close) for the BLOB server.
+ *
+ * <p>It is also a {@link LocalFSWatchServiceListener}: a file change notification only raises a
+ * flag, and the socket is closed and re-created (with a freshly built socket factory) on the next
+ * call to {@link #reloadContextIfNeeded()}.
+ */
+public class BlobServerSocket implements LocalFSWatchServiceListener {
+
+    private static final Logger LOG = LoggerFactory.getLogger(BlobServerSocket.class);
+
+    private final Configuration config;
+    private final int backlog;
+    private final String serverPortRange;
+
+    /**
+     * The socket currently in use. It is replaced inside synchronized methods but read without
+     * the monitor by {@link #getServerSocket()} and {@link #getPort()}, hence volatile.
+     */
+    private volatile ServerSocket serverSocket;
+
+    private final int maxConnections;
+
+    /** True until the first socket has been created successfully. */
+    private final AtomicBoolean firstCreation = new AtomicBoolean(true);
+
+    /** Raised by the file watcher callback, consumed by {@link #reloadContextIfNeeded()}. */
+    private final AtomicBoolean toReload = new AtomicBoolean(false);
+
+    /**
+     * Creates the socket manager and immediately opens the server socket.
+     *
+     * @param config configuration providing the port range, bind host and SSL settings
+     * @param backlog backlog passed to the server socket factory
+     * @param maxConnections maximum number of concurrent requests (only used for logging here)
+     * @throws IOException if the socket factory cannot be built or no port could be bound
+     */
+    public BlobServerSocket(Configuration config, int backlog, int maxConnections)
+            throws IOException {
+        this.config = config;
+        this.backlog = backlog;
+        this.maxConnections = maxConnections;
+
+        serverPortRange = config.get(BlobServerOptions.PORT);
+        createSocket();
+    }
+
+    /** Only records that a reload is pending; the socket itself is not touched here. */
+    @Override
+    public void onFileOrDirectoryModified(Path relativePath) {
+        toReload.set(true);
+    }
+
+    /** Returns the server socket currently in use. */
+    public ServerSocket getServerSocket() {
+        return serverSocket;
+    }
+
+    /**
+     * Closes the current socket and opens a new one if a reload is pending.
+     *
+     * @return {@code true} if the socket was re-created, {@code false} if no reload was pending
+     *     or the reload failed (in which case it stays pending and is retried on the next call)
+     */
+    public synchronized boolean reloadContextIfNeeded() {
+        if (toReload.compareAndSet(true, false)) {
+            try {
+                close();
+                createSocket();
+                return true;
+            } catch (Exception e) {
+                LOG.warn("Failed to reload SSL context", e);
+                // Keep the request pending so that the next call retries.
+                toReload.set(true);
+            }
+        }
+        return false;
+    }
+
+    /**
+     * Opens a new server socket and publishes it in {@link #serverSocket}.
+     *
+     * <p>The first call picks a port from the configured port range; later calls re-bind to the
+     * port of the previous socket. The field is only updated once a socket was actually bound, so
+     * a failed attempt leaves the previous socket reference (and therefore {@link #getPort()} and
+     * the port used by the next retry) intact.
+     *
+     * @throws IOException if SSL could not be initialized or no port could be bound
+     */
+    private synchronized void createSocket() throws IOException {
+        final Iterator<Integer> ports;
+        if (firstCreation.get()) {
+            ports = NetUtils.getPortRangeFromString(serverPortRange);
+        } else {
+            // ServerSocket#getLocalPort() keeps returning the bound port after close().
+            ports = Collections.singleton(serverSocket.getLocalPort()).iterator();
+        }
+
+        final ServerSocketFactory socketFactory;
+        if (SecurityOptions.isInternalSSLEnabled(config)
+                && config.get(BlobServerOptions.SSL_ENABLED)) {
+            try {
+                socketFactory = SSLUtils.createSSLServerSocketFactory(config);
+            } catch (Exception e) {
+                throw new IOException("Failed to initialize SSL for the blob server", e);
+            }
+        } else {
+            socketFactory = ServerSocketFactory.getDefault();
+        }
+
+        final String bindHost =
+                config.getOptional(JobManagerOptions.BIND_HOST)
+                        .orElseGet(NetUtils::getWildcardIPAddress);
+
+        final ServerSocket newSocket =
+                NetUtils.createSocketFromPorts(
+                        ports,
+                        (port) ->
+                                socketFactory.createServerSocket(
+                                        port, backlog, InetAddress.getByName(bindHost)));
+
+        if (newSocket == null) {
+            throw new IOException(
+                    "Unable to open BLOB Server in specified port range: " + serverPortRange);
+        }
+        this.serverSocket = newSocket;
+
+        if (LOG.isInfoEnabled()) {
+            LOG.info(
+                    "Started BLOB server at {}:{} - max concurrent requests: {} - max backlog: {}",
+                    newSocket.getInetAddress().getHostAddress(),
+                    newSocket.getLocalPort(),
+                    maxConnections,
+                    backlog);
+        }
+        firstCreation.set(false);
+    }
+
+    /**
+     * Returns the port on which the server is listening.
+     *
+     * @return port on which the server is listening
+     */
+    public int getPort() {
+        return serverSocket.getLocalPort();
+    }
+
+    /**
+     * Closes the current server socket, if there is one.
+     *
+     * @throws IOException if closing the socket fails
+     */
+    public void close() throws IOException {
+        // Read the volatile field once so the null check and the close see the same socket.
+        final ServerSocket socketToClose = serverSocket;
+        if (socketToClose != null) {
+            close(socketToClose);
+        }
+    }
+
+    /** Logs the shutdown and closes the given socket; a {@code null} socket is only logged. */
+    private void close(ServerSocket serverSocketToClose) throws IOException {
+        if (LOG.isInfoEnabled()) {
+            if (serverSocketToClose != null) {
+                LOG.info(
+                        "Stopped BLOB server at {}:{}",
+                        serverSocketToClose.getInetAddress().getHostAddress(),
+                        serverSocketToClose.getLocalPort());
+            } else {
+                LOG.info("Stopped BLOB server before initializing the socket");
+            }
+        }
+        if (serverSocketToClose != null) {
+            serverSocketToClose.close();
+        }
+    }
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/entrypoint/ClusterEntrypoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/entrypoint/ClusterEntrypoint.java
index ced8edeb072b5..9d6f91d2900c0 100755
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/entrypoint/ClusterEntrypoint.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/entrypoint/ClusterEntrypoint.java
@@ -31,12 +31,14 @@
 import org.apache.flink.configuration.JMXServerOptions;
 import org.apache.flink.configuration.JobManagerOptions;
 import org.apache.flink.configuration.SchedulerExecutionMode;
+import org.apache.flink.configuration.SecurityOptions;
 import org.apache.flink.configuration.WebOptions;
 import org.apache.flink.core.failure.FailureEnricher;
 import org.apache.flink.core.fs.FileSystem;
 import org.apache.flink.core.plugin.PluginManager;
 import org.apache.flink.core.plugin.PluginUtils;
 import org.apache.flink.core.security.FlinkSecurityManager;
+import org.apache.flink.core.security.watch.LocalFSWatchService;
 import org.apache.flink.management.jmx.JMXService;
 import org.apache.flink.runtime.blob.BlobServer;
 import org.apache.flink.runtime.blob.BlobUtils;
@@ -349,6 +351,12 @@ protected void initializeServices(Configuration configuration, PluginManager plu
                                             DeterminismEnvelope.nondeterministicValue(
ResourceID.generate())); + if (SecurityOptions.isReloadCertificate(configuration)) { + LOG.debug("Initialize local file system watch service for certificate reloading."); + LocalFSWatchService localFSWatchService = new LocalFSWatchService(); + localFSWatchService.start(); + } + LOG.debug( "Initialize cluster entrypoint {} with resource id {}.", getClass().getSimpleName(), diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/net/ReloadableJdkSslContext.java b/flink-runtime/src/main/java/org/apache/flink/runtime/net/ReloadableJdkSslContext.java new file mode 100644 index 0000000000000..25f3e899b7e2c --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/net/ReloadableJdkSslContext.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.flink.runtime.net;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.SecurityOptions;
+
+import org.apache.flink.shaded.netty4.io.netty.handler.ssl.ClientAuth;
+import org.apache.flink.shaded.netty4.io.netty.handler.ssl.SslContextBuilder;
+import org.apache.flink.shaded.netty4.io.netty.handler.ssl.SslProvider;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.net.ssl.KeyManagerFactory;
+import javax.net.ssl.TrustManagerFactory;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.Optional;
+
+import static org.apache.flink.runtime.net.SSLUtils.getEnabledCipherSuites;
+import static org.apache.flink.runtime.net.SSLUtils.getEnabledProtocols;
+import static org.apache.flink.runtime.net.SSLUtils.getKeyManagerFactory;
+import static org.apache.flink.runtime.net.SSLUtils.getTrustManagerFactory;
+
+/**
+ * JDK SSL context which is able to reload keystore.
+ *
+ * <p>The context is built from the internal SSL settings (key and trust managers are requested
+ * with {@code internal = true}) and always requires client authentication.
+ */
+public class ReloadableJdkSslContext extends ReloadableSslContext {
+
+    private static final Logger LOG = LoggerFactory.getLogger(ReloadableJdkSslContext.class);
+
+    /**
+     * Client authentication mode of every context built by this class. Kept as a constant so that
+     * {@link #loadContext()}, which the superclass constructor invokes, does not depend on
+     * instance state and does not shadow the inherited {@code clientAuth} field.
+     */
+    private static final ClientAuth INTERNAL_CLIENT_AUTH = ClientAuth.REQUIRE;
+
+    /**
+     * Creates the context and loads it once.
+     *
+     * @param config configuration holding the SSL settings
+     * @param clientMode whether a client side or a server side context is built
+     * @param provider Netty SSL provider to use
+     * @throws Exception if the initial context cannot be built
+     */
+    public ReloadableJdkSslContext(Configuration config, boolean clientMode, SslProvider provider)
+            throws Exception {
+        super(config, clientMode, INTERNAL_CLIENT_AUTH, provider);
+    }
+
+    /**
+     * Builds a new context from the current configuration and key/trust stores and publishes it
+     * in {@code sslContext}.
+     *
+     * @throws Exception if the key or trust managers cannot be created or the build fails
+     */
+    @Override
+    protected void loadContext() throws Exception {
+        // NOTE(review): this logs the whole Configuration at INFO level - confirm that
+        // Configuration#toString masks the keystore/truststore passwords.
+        LOG.info("Loading JDK SSL context from {}", this.config);
+
+        String[] sslProtocols = getEnabledProtocols(config);
+        List<String> ciphers = Arrays.asList(getEnabledCipherSuites(config));
+        int sessionCacheSize = config.get(SecurityOptions.SSL_INTERNAL_SESSION_CACHE_SIZE);
+        int sessionTimeoutMs = config.get(SecurityOptions.SSL_INTERNAL_SESSION_TIMEOUT);
+
+        KeyManagerFactory kmf = getKeyManagerFactory(config, true, provider);
+
+        final SslContextBuilder sslContextBuilder;
+        if (clientMode) {
+            sslContextBuilder = SslContextBuilder.forClient().keyManager(kmf);
+        } else {
+            sslContextBuilder = SslContextBuilder.forServer(kmf);
+        }
+
+        // Only install a trust manager when one is configured.
+        Optional<TrustManagerFactory> tmf = getTrustManagerFactory(config, true);
+        tmf.ifPresent(sslContextBuilder::trustManager);
+
+        this.sslContext =
+                sslContextBuilder
+                        .sslProvider(provider)
+                        .protocols(sslProtocols)
+                        .ciphers(ciphers)
+                        .clientAuth(INTERNAL_CLIENT_AUTH)
+                        .sessionCacheSize(sessionCacheSize)
+                        // The option is read in milliseconds, the builder is given seconds.
+                        .sessionTimeout(sessionTimeoutMs / 1000)
+                        .build();
+    }
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/net/ReloadableSslContext.java b/flink-runtime/src/main/java/org/apache/flink/runtime/net/ReloadableSslContext.java
new file mode 100644
index 0000000000000..f85dc625c7db1
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/net/ReloadableSslContext.java
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.flink.runtime.net; + +import org.apache.flink.configuration.Configuration; +import org.apache.flink.core.security.watch.LocalFSWatchServiceListener; + +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBufAllocator; +import org.apache.flink.shaded.netty4.io.netty.handler.ssl.ApplicationProtocolNegotiator; +import org.apache.flink.shaded.netty4.io.netty.handler.ssl.ClientAuth; +import org.apache.flink.shaded.netty4.io.netty.handler.ssl.JdkSslContext; +import org.apache.flink.shaded.netty4.io.netty.handler.ssl.SslContext; +import org.apache.flink.shaded.netty4.io.netty.handler.ssl.SslContextBuilder; +import org.apache.flink.shaded.netty4.io.netty.handler.ssl.SslProvider; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLSessionContext; +import javax.net.ssl.TrustManagerFactory; + +import java.nio.file.Path; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.apache.flink.runtime.net.SSLUtils.getEnabledCipherSuites; +import static org.apache.flink.runtime.net.SSLUtils.getEnabledProtocols; +import static org.apache.flink.runtime.net.SSLUtils.getKeyManagerFactory; +import static org.apache.flink.runtime.net.SSLUtils.getTrustManagerFactory; + +/** SSL context which is able to reload keystore. 
+ *
+ * <p>Every {@link SslContext} method first applies a pending reload and then delegates to the
+ * most recently loaded context. A reload is requested through {@link
+ * #onFileOrDirectoryModified(Path)} and carried out lazily by {@link #reloadContextIfNeeded()}.
+ */
+public class ReloadableSslContext extends SslContext implements LocalFSWatchServiceListener {
+
+    private static final Logger LOG = LoggerFactory.getLogger(ReloadableSslContext.class);
+
+    protected final Configuration config;
+    protected final boolean clientMode;
+    protected final ClientAuth clientAuth;
+    protected final SslProvider provider;
+
+    /** The context all calls are delegated to; replaced by {@link #loadContext()}. */
+    protected volatile SslContext sslContext;
+
+    /** Raised by the file watcher callback, consumed by {@link #reloadContextIfNeeded()}. */
+    private final AtomicBoolean toReload = new AtomicBoolean(false);
+
+    /**
+     * Stores the settings and loads the context once.
+     *
+     * @param config configuration holding the SSL settings
+     * @param clientMode whether a client side or a server side context is built
+     * @param clientAuth client authentication mode
+     * @param provider Netty SSL provider to use
+     * @throws Exception if the initial context cannot be built
+     */
+    public ReloadableSslContext(
+            Configuration config, boolean clientMode, ClientAuth clientAuth, SslProvider provider)
+            throws Exception {
+        this.config = config;
+        this.clientMode = clientMode;
+        this.clientAuth = clientAuth;
+        this.provider = provider;
+        // Note: overridable call from the constructor; overrides must not rely on subclass
+        // instance fields, which are not initialized yet at this point.
+        loadContext();
+    }
+
+    /**
+     * Returns the JDK {@link SSLContext} of the current delegate.
+     *
+     * <p>The delegate is cast to {@link JdkSslContext}, so this fails with a {@link
+     * ClassCastException} if the context was built with a provider that does not produce one.
+     */
+    public SSLContext getSSLContext() {
+        reloadContextIfNeeded();
+        return ((JdkSslContext) this.sslContext).context();
+    }
+
+    @Override
+    public boolean isClient() {
+        reloadContextIfNeeded();
+        return sslContext.isClient();
+    }
+
+    @Override
+    public List<String> cipherSuites() {
+        reloadContextIfNeeded();
+        return sslContext.cipherSuites();
+    }
+
+    @Override
+    public ApplicationProtocolNegotiator applicationProtocolNegotiator() {
+        reloadContextIfNeeded();
+        return sslContext.applicationProtocolNegotiator();
+    }
+
+    @Override
+    public SSLEngine newEngine(ByteBufAllocator byteBufAllocator) {
+        reloadContextIfNeeded();
+        return sslContext.newEngine(byteBufAllocator);
+    }
+
+    @Override
+    public SSLEngine newEngine(ByteBufAllocator byteBufAllocator, String s, int i) {
+        reloadContextIfNeeded();
+        return sslContext.newEngine(byteBufAllocator, s, i);
+    }
+
+    @Override
+    public SSLSessionContext sessionContext() {
+        reloadContextIfNeeded();
+        return sslContext.sessionContext();
+    }
+
+    /** Only records that a reload is pending; the context is rebuilt on its next use. */
+    @Override
+    public void onFileOrDirectoryModified(Path relativePath) {
+        toReload.set(true);
+    }
+
+    /**
+     * Rebuilds the delegate context if a reload is pending. On failure the previous context stays
+     * in place and the reload remains pending, so the next call tries again.
+     */
+    protected void reloadContextIfNeeded() {
+        // This runs on every delegated call; skip the monitor when nothing is pending.
+        if (!toReload.get()) {
+            return;
+        }
+        synchronized (this) {
+            if (toReload.compareAndSet(true, false)) {
+                try {
+                    // NOTE(review): the replaced context is not released here - confirm
+                    // whether reference-counted (OpenSSL) contexts need an explicit release.
+                    loadContext();
+                } catch (Exception e) {
+                    LOG.warn("Failed to reload SSL context", e);
+                    toReload.set(true);
+                }
+            }
+        }
+    }
+
+    /**
+     * Builds a new context from the current configuration and key/trust stores (requested with
+     * {@code internal = false}) and publishes it in {@link #sslContext}.
+     *
+     * @throws Exception if the key or trust managers cannot be created or the build fails
+     */
+    protected void loadContext() throws Exception {
+        LOG.info("Loading SSL context from {}", config);
+
+        String[] sslProtocols = getEnabledProtocols(config);
+        List<String> ciphers = Arrays.asList(getEnabledCipherSuites(config));
+
+        final SslContextBuilder sslContextBuilder;
+        if (clientMode) {
+            sslContextBuilder = SslContextBuilder.forClient();
+            if (clientAuth != ClientAuth.NONE) {
+                KeyManagerFactory kmf = getKeyManagerFactory(config, false, provider);
+                sslContextBuilder.keyManager(kmf);
+            }
+        } else {
+            KeyManagerFactory kmf = getKeyManagerFactory(config, false, provider);
+            sslContextBuilder = SslContextBuilder.forServer(kmf);
+        }
+
+        if (clientMode || clientAuth != ClientAuth.NONE) {
+            Optional<TrustManagerFactory> tmf = getTrustManagerFactory(config, false);
+            tmf.ifPresent(
+                    // Use specific ciphers and protocols if SSL is configured with self-signed
+                    // certificates (user-supplied truststore)
+                    tm ->
+                            sslContextBuilder
+                                    .trustManager(tm)
+                                    .protocols(sslProtocols)
+                                    .ciphers(ciphers)
+                                    .clientAuth(clientAuth));
+        }
+
+        sslContext = sslContextBuilder.sslProvider(provider).build();
+    }
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/net/SSLUtils.java b/flink-runtime/src/main/java/org/apache/flink/runtime/net/SSLUtils.java
index 0660873b1cb88..eba3aa0d5ad4b 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/net/SSLUtils.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/net/SSLUtils.java
@@ -23,18 +23,20 @@
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.configuration.IllegalConfigurationException;
 import org.apache.flink.configuration.SecurityOptions;
+import org.apache.flink.core.security.watch.LocalFSWatchSingleton;
 import org.apache.flink.runtime.io.network.netty.SSLHandlerFactory;
 import org.apache.flink.util.StringUtils;
 
 import org.apache.flink.shaded.netty4.io.netty.handler.ssl.ClientAuth;
-import
org.apache.flink.shaded.netty4.io.netty.handler.ssl.JdkSslContext; import org.apache.flink.shaded.netty4.io.netty.handler.ssl.OpenSsl; import org.apache.flink.shaded.netty4.io.netty.handler.ssl.OpenSslX509KeyManagerFactory; import org.apache.flink.shaded.netty4.io.netty.handler.ssl.SslContext; -import org.apache.flink.shaded.netty4.io.netty.handler.ssl.SslContextBuilder; import org.apache.flink.shaded.netty4.io.netty.handler.ssl.SslProvider; import org.apache.flink.shaded.netty4.io.netty.handler.ssl.util.FingerprintTrustManagerFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import javax.annotation.Nullable; import javax.net.ServerSocketFactory; import javax.net.SocketFactory; @@ -50,13 +52,13 @@ import java.net.InetAddress; import java.net.ServerSocket; import java.nio.file.Files; +import java.nio.file.Path; import java.security.KeyStore; import java.security.KeyStoreException; import java.security.NoSuchAlgorithmException; import java.security.UnrecoverableKeyException; import java.security.cert.CertificateException; -import java.util.Arrays; -import java.util.List; +import java.util.HashSet; import java.util.Optional; import static org.apache.flink.shaded.netty4.io.netty.handler.ssl.SslProvider.JDK; @@ -67,13 +69,15 @@ /** Common utilities to manage SSL transport settings. */ public class SSLUtils { + private static final Logger LOG = LoggerFactory.getLogger(SSLUtils.class); + /** * Creates a factory for SSL Server Sockets from the given configuration. SSL Server Sockets are * always part of internal communication. 
*/ public static ServerSocketFactory createSSLServerSocketFactory(Configuration config) throws Exception { - SSLContext sslContext = createInternalSSLContext(config, false); + SSLContext sslContext = createInternalSSLContext(config, false, false); if (sslContext == null) { throw new IllegalConfigurationException("SSL is not enabled"); } @@ -91,7 +95,7 @@ public static ServerSocketFactory createSSLServerSocketFactory(Configuration con */ public static SocketFactory createSSLClientSocketFactory(Configuration config) throws Exception { - SSLContext sslContext = createInternalSSLContext(config, true); + SSLContext sslContext = createInternalSSLContext(config, true, true); if (sslContext == null) { throw new IllegalConfigurationException("SSL is not enabled"); } @@ -102,7 +106,7 @@ public static SocketFactory createSSLClientSocketFactory(Configuration config) /** Creates a SSLEngineFactory to be used by internal communication server endpoints. */ public static SSLHandlerFactory createInternalServerSSLEngineFactory(final Configuration config) throws Exception { - SslContext sslContext = createInternalNettySSLContext(config, false); + SslContext sslContext = createInternalNettySSLContext(config, false, true); if (sslContext == null) { throw new IllegalConfigurationException( "SSL is not enabled for internal communication."); @@ -117,7 +121,7 @@ public static SSLHandlerFactory createInternalServerSSLEngineFactory(final Confi /** Creates a SSLEngineFactory to be used by internal communication client endpoints. 
*/ public static SSLHandlerFactory createInternalClientSSLEngineFactory(final Configuration config) throws Exception { - SslContext sslContext = createInternalNettySSLContext(config, true); + SslContext sslContext = createInternalNettySSLContext(config, true, true); if (sslContext == null) { throw new IllegalConfigurationException( "SSL is not enabled for internal communication."); @@ -167,12 +171,12 @@ public static SSLHandlerFactory createRestClientSSLEngineFactory(final Configura return new SSLHandlerFactory(sslContext, -1, -1); } - private static String[] getEnabledProtocols(final Configuration config) { + static String[] getEnabledProtocols(final Configuration config) { checkNotNull(config, "config must not be null"); return config.get(SecurityOptions.SSL_PROTOCOL).split(","); } - private static String[] getEnabledCipherSuites(final Configuration config) { + static String[] getEnabledCipherSuites(final Configuration config) { checkNotNull(config, "config must not be null"); return config.get(SecurityOptions.SSL_ALGORITHMS).split(","); } @@ -195,7 +199,7 @@ static SslProvider getSSLProvider(final Configuration config) { } } - private static Optional getTrustManagerFactory( + static Optional getTrustManagerFactory( Configuration config, boolean internal) throws KeyStoreException, IOException, NoSuchAlgorithmException, CertificateException { @@ -263,7 +267,7 @@ private static Optional getTrustManagerFactory( return Optional.of(tmf); } - private static KeyManagerFactory getKeyManagerFactory( + static KeyManagerFactory getKeyManagerFactory( Configuration config, boolean internal, SslProvider provider) throws KeyStoreException, IOException, @@ -322,12 +326,15 @@ private static KeyManagerFactory getKeyManagerFactory( * the client and server side configuration are identical, because of mutual authentication. 
*/ @Nullable - private static SSLContext createInternalSSLContext(Configuration config, boolean clientMode) + private static SSLContext createInternalSSLContext( + Configuration config, boolean clientMode, Boolean watchForCertificateChange) throws Exception { - JdkSslContext nettySSLContext = - (JdkSslContext) createInternalNettySSLContext(config, clientMode, JDK); + ReloadableJdkSslContext nettySSLContext = + (ReloadableJdkSslContext) + createInternalNettySSLContext( + config, clientMode, JDK, watchForCertificateChange); if (nettySSLContext != null) { - return nettySSLContext.context(); + return nettySSLContext.getSSLContext(); } else { return null; } @@ -335,8 +342,10 @@ private static SSLContext createInternalSSLContext(Configuration config, boolean @Nullable private static SslContext createInternalNettySSLContext( - Configuration config, boolean clientMode) throws Exception { - return createInternalNettySSLContext(config, clientMode, getSSLProvider(config)); + Configuration config, boolean clientMode, Boolean watchForCertificateChange) + throws Exception { + return createInternalNettySSLContext( + config, clientMode, getSSLProvider(config), watchForCertificateChange); } /** @@ -345,39 +354,42 @@ private static SslContext createInternalNettySSLContext( */ @Nullable private static SslContext createInternalNettySSLContext( - Configuration config, boolean clientMode, SslProvider provider) throws Exception { + Configuration config, + boolean clientMode, + SslProvider provider, + Boolean watchForCertificateChange) + throws Exception { checkNotNull(config, "config"); if (!SecurityOptions.isInternalSSLEnabled(config)) { return null; } - String[] sslProtocols = getEnabledProtocols(config); - List ciphers = Arrays.asList(getEnabledCipherSuites(config)); - int sessionCacheSize = config.get(SecurityOptions.SSL_INTERNAL_SESSION_CACHE_SIZE); - int sessionTimeoutMs = config.get(SecurityOptions.SSL_INTERNAL_SESSION_TIMEOUT); - - KeyManagerFactory kmf = 
getKeyManagerFactory(config, true, provider); - ClientAuth clientAuth = ClientAuth.REQUIRE; - - final SslContextBuilder sslContextBuilder; - if (clientMode) { - sslContextBuilder = SslContextBuilder.forClient().keyManager(kmf); - } else { - sslContextBuilder = SslContextBuilder.forServer(kmf); + ReloadableJdkSslContext reloadableJdkSslContext = + new ReloadableJdkSslContext(config, clientMode, provider); + if (SecurityOptions.isReloadCertificate(config) && watchForCertificateChange) { + HashSet certificatePaths = new HashSet<>(); + certificatePaths.add( + Path.of( + getAndCheckOption( + config, + SecurityOptions.SSL_INTERNAL_KEYSTORE, + SecurityOptions.SSL_KEYSTORE)) + .getParent()); + String truststoreFilePath = + config.get( + SecurityOptions.SSL_INTERNAL_TRUSTSTORE, + config.get(SecurityOptions.SSL_TRUSTSTORE)); + if (truststoreFilePath != null) { + certificatePaths.add(Path.of(truststoreFilePath).getParent()); + } + Path[] pathsToWatch = new Path[certificatePaths.size()]; + certificatePaths.toArray(pathsToWatch); + LocalFSWatchSingleton localFSWatchSingleton = LocalFSWatchSingleton.getInstance(); + localFSWatchSingleton.registerPath(pathsToWatch, reloadableJdkSslContext); } - Optional tmf = getTrustManagerFactory(config, true); - tmf.map(sslContextBuilder::trustManager); - - return sslContextBuilder - .sslProvider(provider) - .protocols(sslProtocols) - .ciphers(ciphers) - .clientAuth(clientAuth) - .sessionCacheSize(sessionCacheSize) - .sessionTimeout(sessionTimeoutMs / 1000) - .build(); + return reloadableJdkSslContext; } /** Creates an SSL context for clients against the external REST endpoint. */ @@ -389,10 +401,11 @@ public static SSLContext createRestSSLContext(Configuration config, boolean clie SecurityOptions.isRestSSLAuthenticationEnabled(config) ? 
ClientAuth.REQUIRE : ClientAuth.NONE; - JdkSslContext nettySSLContext = - (JdkSslContext) createRestNettySSLContext(config, clientMode, clientAuth, JDK); - if (nettySSLContext != null) { - return nettySSLContext.context(); + ReloadableSslContext reloadableSslContext = + (ReloadableSslContext) + createRestNettySSLContext(config, clientMode, clientAuth, JDK); + if (reloadableSslContext != null) { + return reloadableSslContext.getSSLContext(); } else { return null; } @@ -418,35 +431,33 @@ public static SslContext createRestNettySSLContext( return null; } - String[] sslProtocols = getEnabledProtocols(config); - List ciphers = Arrays.asList(getEnabledCipherSuites(config)); - - final SslContextBuilder sslContextBuilder; - if (clientMode) { - sslContextBuilder = SslContextBuilder.forClient(); - if (clientAuth != ClientAuth.NONE) { - KeyManagerFactory kmf = getKeyManagerFactory(config, false, provider); - sslContextBuilder.keyManager(kmf); + ReloadableSslContext reloadableSslContext = + new ReloadableSslContext(config, clientMode, clientAuth, provider); + + if (SecurityOptions.isReloadCertificate(config)) { + HashSet certificatePaths = new HashSet<>(); + String keystoreFilePath = + config.get( + SecurityOptions.SSL_REST_KEYSTORE, + config.get(SecurityOptions.SSL_KEYSTORE)); + if (keystoreFilePath != null) { + certificatePaths.add(Path.of(keystoreFilePath).getParent()); + } + String truststoreFilePath = + config.get( + SecurityOptions.SSL_REST_TRUSTSTORE, + config.get(SecurityOptions.SSL_TRUSTSTORE)); + if (truststoreFilePath != null) { + certificatePaths.add(Path.of(truststoreFilePath).getParent()); } - } else { - KeyManagerFactory kmf = getKeyManagerFactory(config, false, provider); - sslContextBuilder = SslContextBuilder.forServer(kmf); - } - if (clientMode || clientAuth != ClientAuth.NONE) { - Optional tmf = getTrustManagerFactory(config, false); - tmf.map( - // Use specific ciphers and protocols if SSL is configured with self-signed - // certificates (user-supplied 
truststore) - tm -> - sslContextBuilder - .trustManager(tm) - .protocols(sslProtocols) - .ciphers(ciphers) - .clientAuth(clientAuth)); + Path[] pathsToWatch = new Path[certificatePaths.size()]; + certificatePaths.toArray(pathsToWatch); + LocalFSWatchSingleton localFSWatchSingleton = LocalFSWatchSingleton.getInstance(); + localFSWatchSingleton.registerPath(pathsToWatch, reloadableSslContext); } - return sslContextBuilder.sslProvider(provider).build(); + return reloadableSslContext; } // ------------------------------------------------------------------------ diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskManagerRunner.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskManagerRunner.java index c8f0e24f73218..b15d8622a3163 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskManagerRunner.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskManagerRunner.java @@ -23,12 +23,14 @@ import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.JMXServerOptions; import org.apache.flink.configuration.RpcOptions; +import org.apache.flink.configuration.SecurityOptions; import org.apache.flink.configuration.TaskManagerOptions; import org.apache.flink.configuration.TaskManagerOptionsInternal; import org.apache.flink.core.fs.FileSystem; import org.apache.flink.core.plugin.PluginManager; import org.apache.flink.core.plugin.PluginUtils; import org.apache.flink.core.security.FlinkSecurityManager; +import org.apache.flink.core.security.watch.LocalFSWatchService; import org.apache.flink.management.jmx.JMXService; import org.apache.flink.metrics.MetricGroup; import org.apache.flink.runtime.blob.BlobCacheService; @@ -191,6 +193,12 @@ private void startTaskManagerRunnerServices() throws Exception { Hardware.getNumberCPUCores(), new ExecutorThreadFactory("taskmanager-future")); + if (SecurityOptions.isReloadCertificate(configuration)) { + 
LOG.debug("Initialize local file system watch service for certificate reloading."); + LocalFSWatchService localFSWatchService = new LocalFSWatchService(); + localFSWatchService.start(); + } + highAvailabilityServices = HighAvailabilityServicesUtils.createHighAvailabilityServices( configuration, From cc14c9e58fa18ff7ff20771a88a9fd7abeb408f6 Mon Sep 17 00:00:00 2001 From: "oleksandr.nitavskyi" Date: Tue, 23 Sep 2025 21:30:22 +0200 Subject: [PATCH 2/4] [FLINK-37504] Refactor for test compatibility * Unit test LocalFSWatchSingleton/SSLContextLoader * Test BlobStoreSsl to ensure certificates are reloaded --- .../flink/configuration/SecurityOptions.java | 9 +- .../watch/LocalFSDirectoryWatcher.java | 33 + .../security/watch/LocalFSWatchService.java | 61 +- .../watch/LocalFSWatchServiceListener.java | 89 ++- .../security/watch/LocalFSWatchSingleton.java | 23 +- .../watch/LocalFSWatchSingletonTest.java | 620 ++++++++++++++++++ .../rpc/pekko/CustomSSLEngineProvider.java | 24 +- .../flink/runtime/rpc/pekko/PekkoUtils.java | 2 +- .../runtime/rpc/pekko/SSLContextLoader.java | 42 +- .../rpc/pekko/SSLContextLoaderTest.java | 81 +++ flink-runtime/pom.xml | 2 +- .../apache/flink/runtime/blob/BlobServer.java | 21 +- .../flink/runtime/blob/BlobServerSocket.java | 44 +- .../runtime/entrypoint/ClusterEntrypoint.java | 2 +- .../runtime/net/ReloadableSslContext.java | 35 +- .../apache/flink/runtime/net/SSLUtils.java | 13 +- .../taskexecutor/TaskManagerRunner.java | 2 +- .../runtime/blob/BlobClientSslReloadTest.java | 255 +++++++ .../flink/runtime/blob/BlobClientSslTest.java | 10 + .../flink/runtime/net/SSLUtilsTest.java | 14 + 20 files changed, 1244 insertions(+), 138 deletions(-) create mode 100644 flink-core/src/main/java/org/apache/flink/core/security/watch/LocalFSDirectoryWatcher.java create mode 100644 flink-core/src/test/java/org/apache/flink/core/security/watch/LocalFSWatchSingletonTest.java create mode 100644 
flink-rpc/flink-rpc-akka/src/test/java/org/apache/flink/runtime/rpc/pekko/SSLContextLoaderTest.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobClientSslReloadTest.java diff --git a/flink-core/src/main/java/org/apache/flink/configuration/SecurityOptions.java b/flink-core/src/main/java/org/apache/flink/configuration/SecurityOptions.java index b856847b9c6a7..fb7dd332f38e6 100644 --- a/flink-core/src/main/java/org/apache/flink/configuration/SecurityOptions.java +++ b/flink-core/src/main/java/org/apache/flink/configuration/SecurityOptions.java @@ -618,15 +618,14 @@ public static Configuration forProvider(Configuration configuration, String prov + "forcibly. (-1 = use system default)") .withDeprecatedKeys("security.ssl.close-notify-flush-timeout"); - // TODO check all documentation are well updated (explain mechanism) - /** Indicate if changes on keystore/truststore should leads to reload of the certificate. */ + /** Indicate if changes on keystore/truststore should trigger reload of the certificate. */ @Documentation.Section(Documentation.Sections.SECURITY_SSL) public static final ConfigOption SSL_RELOAD = key("security.ssl.reload") .booleanType() .defaultValue(false) .withDescription( - "Indicate if changes on keystore/truststore should leads to reload of the certificate."); + "If enabled, the application will monitor the keystore and truststore files for any changes. When a change is detected, internal network components (like Netty, Pekko, or BlobServer) will automatically reload the keystore/truststore certificates."); /** * Checks whether SSL for internal communication (rpc, data transport, blob server) is enabled. @@ -646,8 +645,8 @@ public static boolean isRestSSLAuthenticationEnabled(Configuration sslConfig) { return isRestSSLEnabled(sslConfig) && sslConfig.get(SSL_REST_AUTHENTICATION_ENABLED); } - /** Checks whether certificates must be reloaded in case of keytstore or trusttore changes. 
*/ - public static boolean isReloadCertificate(Configuration sslConfig) { + /** Checks whether certificates must be reloaded in case of keystore or truststore changes. */ + public static boolean isCertificateReloadEnabled(Configuration sslConfig) { return sslConfig.get(SSL_RELOAD); } } diff --git a/flink-core/src/main/java/org/apache/flink/core/security/watch/LocalFSDirectoryWatcher.java b/flink-core/src/main/java/org/apache/flink/core/security/watch/LocalFSDirectoryWatcher.java new file mode 100644 index 0000000000000..4c1a81c8cce8b --- /dev/null +++ b/flink-core/src/main/java/org/apache/flink/core/security/watch/LocalFSDirectoryWatcher.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.core.security.watch; + +import java.io.IOException; +import java.nio.file.Path; +import java.nio.file.WatchService; +import java.util.Map; +import java.util.Set; + +public interface LocalFSDirectoryWatcher { + + Set<Map.Entry<WatchService, LocalFSWatchServiceListener>> getWatchers(); + + void registerDirectory(Path[] dirsToWatch, LocalFSWatchServiceListener listener) + throws IOException; +} diff --git a/flink-core/src/main/java/org/apache/flink/core/security/watch/LocalFSWatchService.java b/flink-core/src/main/java/org/apache/flink/core/security/watch/LocalFSWatchService.java index 3f61bf06569cf..3afd382924a20 100644 --- a/flink-core/src/main/java/org/apache/flink/core/security/watch/LocalFSWatchService.java +++ b/flink-core/src/main/java/org/apache/flink/core/security/watch/LocalFSWatchService.java @@ -18,6 +18,8 @@ package org.apache.flink.core.security.watch; +import org.apache.flink.annotation.VisibleForTesting; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -25,7 +27,9 @@ import java.nio.file.WatchEvent; import java.nio.file.WatchKey; import java.nio.file.WatchService; +import java.time.Duration; import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; import static java.nio.file.StandardWatchEventKinds.ENTRY_CREATE; import static java.nio.file.StandardWatchEventKinds.ENTRY_DELETE; @@ -35,39 +39,58 @@ public class LocalFSWatchService extends Thread { private static final Logger LOG = LoggerFactory.getLogger(LocalFSWatchService.class); + private final long sleepDurationMs; + @VisibleForTesting AtomicBoolean running = new AtomicBoolean(false); + + public LocalFSWatchService() { + this(Duration.ofMillis(100)); + } + + public LocalFSWatchService(Duration sleepDuration) { + setDaemon(true); + setName("LocalFSWatchServiceThread"); + sleepDurationMs = sleepDuration.toMillis(); + } + public void run() { try { + running.set(true); while (true) { for (Map.Entry<WatchService, LocalFSWatchServiceListener> entry : - LocalFSWatchSingleton.getInstance().watchers.entrySet()) { - LOG.debug("Taking watch
key"); + LocalFSWatchSingleton.getInstance().getWatchers()) { WatchKey watchKey = entry.getKey().poll(); if (watchKey == null) { continue; } - LOG.debug("Watch key arrived"); - for (WatchEvent watchEvent : watchKey.pollEvents()) { - System.out.println(watchEvent.kind()); - System.out.println(watchEvent.context()); - if (watchEvent.kind() == OVERFLOW) { - LOG.error("Filesystem events may have been lost or discarded"); - Thread.yield(); - } else if (watchEvent.kind() == ENTRY_CREATE) { - entry.getValue().onFileOrDirectoryCreated((Path) watchEvent.context()); - } else if (watchEvent.kind() == ENTRY_DELETE) { - entry.getValue().onFileOrDirectoryDeleted((Path) watchEvent.context()); - } else if (watchEvent.kind() == ENTRY_MODIFY) { - entry.getValue().onFileOrDirectoryModified((Path) watchEvent.context()); - } else { - LOG.warn("Unhandled watch event {}", watchEvent.kind()); - } - } + LOG.debug("Watch key arrived - {}", watchKey); + processWatchKey(entry, watchKey); watchKey.reset(); } + Thread.sleep(sleepDurationMs); } } catch (Exception e) { LOG.error("Filesystem watcher received exception and stopped: ", e); throw new RuntimeException(e); + } finally { + running.set(false); + } + } + + protected void processWatchKey( + Map.Entry<WatchService, LocalFSWatchServiceListener> entry, WatchKey watchKey) { + for (WatchEvent watchEvent : watchKey.pollEvents()) { + if (watchEvent.kind() == OVERFLOW) { + LOG.error("Filesystem events may have been lost or discarded"); + Thread.yield(); + } else if (watchEvent.kind() == ENTRY_CREATE) { + entry.getValue().onFileOrDirectoryCreated((Path) watchEvent.context()); + } else if (watchEvent.kind() == ENTRY_DELETE) { + entry.getValue().onFileOrDirectoryDeleted((Path) watchEvent.context()); + } else if (watchEvent.kind() == ENTRY_MODIFY) { + entry.getValue().onFileOrDirectoryModified((Path) watchEvent.context()); + } else { + LOG.warn("Unhandled watch event {}", watchEvent.kind()); + } } } } diff --git
a/flink-core/src/main/java/org/apache/flink/core/security/watch/LocalFSWatchServiceListener.java b/flink-core/src/main/java/org/apache/flink/core/security/watch/LocalFSWatchServiceListener.java index 33c3c0b107509..a8fc773522278 100644 --- a/flink-core/src/main/java/org/apache/flink/core/security/watch/LocalFSWatchServiceListener.java +++ b/flink-core/src/main/java/org/apache/flink/core/security/watch/LocalFSWatchServiceListener.java @@ -1,14 +1,13 @@ -package org.apache.flink.core.security.watch; - /* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -17,14 +16,84 @@ * limitations under the License. 
*/ + +package org.apache.flink.core.security.watch; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import java.nio.file.Path; +import java.util.concurrent.atomic.AtomicReference; public interface LocalFSWatchServiceListener { + + enum ReloadState { + CLEAN, // Context is up to date + DIRTY, // Context needs reloading + RELOADING // Context is currently being reloaded + } + + @FunctionalInterface + interface ContextLoader { + + void loadContext() throws Exception; + } + + Logger LOG = LoggerFactory.getLogger(LocalFSWatchServiceListener.class); + + /** + * Get the current reload state. Implementations should provide their own state management. + * + * @return the current reload state + */ + AtomicReference<ReloadState> getReloadStateReference(); + default void onWatchStarted(Path realDirectoryPath) {} default void onFileOrDirectoryCreated(Path relativePath) {} default void onFileOrDirectoryDeleted(Path relativePath) {} - default void onFileOrDirectoryModified(Path relativePath) {} + default void onFileOrDirectoryModified(Path relativePath) { + getReloadStateReference().compareAndSet(ReloadState.CLEAN, ReloadState.DIRTY); + LOG.debug( + "File {} has been modified in {}, reloadState={}", + relativePath, + this, + getReloadStateReference().get()); + } + + default boolean reloadContextIfNeeded(ContextLoader loader) { + AtomicReference<ReloadState> reloadState = getReloadStateReference(); + // Only one thread can transition from DIRTY to RELOADING + if (reloadState.compareAndSet(ReloadState.DIRTY, ReloadState.RELOADING)) { + try { + loader.loadContext(); + // Successfully loaded, mark as clean + reloadState.set(ReloadState.CLEAN); + return true; + } catch (Exception e) { + LOG.warn("Failed to reload context", e); + // Failed to load, mark as dirty for retry + reloadState.set(ReloadState.DIRTY); + } + } + return false; + // If state is CLEAN, do nothing + // If state is RELOADING, another thread is handling it, so we can proceed with current + // context + } + + /** + * Abstract base
class that provides a default implementation of LocalFSWatchServiceListener + * with instance-level reload state management. + */ + abstract class AbstractLocalFSWatchServiceListener implements LocalFSWatchServiceListener { + private final AtomicReference<ReloadState> reloadState = + new AtomicReference<>(ReloadState.CLEAN); + + @Override + public final AtomicReference<ReloadState> getReloadStateReference() { + return reloadState; + } + } } diff --git a/flink-core/src/main/java/org/apache/flink/core/security/watch/LocalFSWatchSingleton.java b/flink-core/src/main/java/org/apache/flink/core/security/watch/LocalFSWatchSingleton.java index 23e55860ab370..29a99fd8d2267 100644 --- a/flink-core/src/main/java/org/apache/flink/core/security/watch/LocalFSWatchSingleton.java +++ b/flink-core/src/main/java/org/apache/flink/core/security/watch/LocalFSWatchSingleton.java @@ -22,24 +22,26 @@ import java.nio.file.FileSystems; import java.nio.file.Path; import java.nio.file.WatchService; +import java.util.Map; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import static java.nio.file.StandardWatchEventKinds.ENTRY_CREATE; import static java.nio.file.StandardWatchEventKinds.ENTRY_DELETE; import static java.nio.file.StandardWatchEventKinds.ENTRY_MODIFY; -public final class LocalFSWatchSingleton { +public final class LocalFSWatchSingleton implements LocalFSDirectoryWatcher { // The field must be declared volatile so that double check lock would work // correctly.
- private static volatile LocalFSWatchSingleton instance; + private static volatile LocalFSDirectoryWatcher instance; ConcurrentHashMap<WatchService, LocalFSWatchServiceListener> watchers = new ConcurrentHashMap<>(); private LocalFSWatchSingleton() {} - public static LocalFSWatchSingleton getInstance() { - LocalFSWatchSingleton result = instance; + public static LocalFSDirectoryWatcher getInstance() { + LocalFSDirectoryWatcher result = instance; if (result != null) { return result; } @@ -51,15 +53,20 @@ public static LocalFSWatchSingleton getInstance() { } } - public void registerPath(Path[] pathsToWatch, LocalFSWatchServiceListener callback) + public Set<Map.Entry<WatchService, LocalFSWatchServiceListener>> getWatchers() { + return watchers.entrySet(); + } + + @Override + public void registerDirectory(Path[] dirsToWatch, LocalFSWatchServiceListener listener) throws IOException { WatchService watcher = FileSystems.getDefault().newWatchService(); - for (Path pathToWatch : pathsToWatch) { + for (Path pathToWatch : dirsToWatch) { Path realDirectoryPath = pathToWatch.toRealPath(); realDirectoryPath.register(watcher, ENTRY_CREATE, ENTRY_DELETE, ENTRY_MODIFY); } - callback.onWatchStarted(pathsToWatch[0]); - watchers.put(watcher, callback); + listener.onWatchStarted(dirsToWatch[0]); + watchers.put(watcher, listener); } } diff --git a/flink-core/src/test/java/org/apache/flink/core/security/watch/LocalFSWatchSingletonTest.java b/flink-core/src/test/java/org/apache/flink/core/security/watch/LocalFSWatchSingletonTest.java new file mode 100644 index 0000000000000..718f0feae012c --- /dev/null +++ b/flink-core/src/test/java/org/apache/flink/core/security/watch/LocalFSWatchSingletonTest.java @@ -0,0 +1,620 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.core.security.watch; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.io.TempDir; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Stream; + +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +class LocalFSWatchSingletonTest { + + private static final Logger LOG = LoggerFactory.getLogger(LocalFSWatchSingletonTest.class); + + private static final int EXTRA_WATCHER_ITERATIONS = 100; + + private ExecutorService writerExecutor; + 
+ @BeforeEach + void setUp() { + writerExecutor = Executors.newFixedThreadPool(10); + } + + @AfterEach + void tearDown() throws InterruptedException { + if (writerExecutor != null && !writerExecutor.isShutdown()) { + writerExecutor.shutdown(); + if (!writerExecutor.awaitTermination(5, TimeUnit.SECONDS)) { + writerExecutor.shutdownNow(); + } + } + } + + static Stream testParameters() { + return Stream.of( + Arguments.of("single file; single notification; single writer", 1, 1, 1, 1, 1), + Arguments.of("single file; few writers", 1, 8, 1, 1, 1), + Arguments.of("single file; few writes", 1, 1, 8, 1, 1), + Arguments.of("single file; few watchers", 1, 1, 1, 8, 1), + Arguments.of("multiple writes/single watcher scenario", 3, 2, 10, 1, 3), + Arguments.of("multiple writes/multiple watchers scenario", 3, 2, 10, 2, 3)); + } + + static Stream manualParameters() { + return Stream.of( + Arguments.of("single file notification to ensure it works", 1, 1, 1, 1, 1), + Arguments.of("stress test scenario", 50, 30, 80, 5, 4)); + } + + @ParameterizedTest(name = "{0}") + @MethodSource("testParameters") + void testFileWatchingScenarios( + String testName, + int fileCount, + int writerCount, + int writesPerWriter, + int watcherCount, + int contextReloaderThreads, + @TempDir Path tempDir) + throws Exception { + + new TestCase(tempDir) + .fileCount(fileCount) + .writerCount(writerCount) + .writesPerWriter(writesPerWriter) + .watcherCount(watcherCount) + .contextReloaderThreads(contextReloaderThreads) + .run(); + } + + @Disabled( + "manual test due to long and potentially flacky execution: ensures that reload function is executed under heavy write load") + @ParameterizedTest(name = "{0}") + @MethodSource("manualParameters") + void testManuallyFileWatchingScenarios( + String testName, + int fileCount, + int writerCount, + int writesPerWriter, + int watcherCount, + int contextReloaderThreads, + @TempDir Path tempDir) + throws Exception { + + new TestCase(tempDir) + .fileCount(fileCount) + 
.writerCount(writerCount) + .writesPerWriter(writesPerWriter) + .watcherCount(watcherCount) + .contextReloaderThreads(contextReloaderThreads) + .waitMultiplicator(20) + .run(); + } + + static class TestCase { + private final Path tempDir; + + private ExecutorService writerExecutor; + private ExecutorService contextReloaderThreadPool; + private final List detailedLogs = Collections.synchronizedList(new ArrayList<>()); + + private int fileCount = 1; + private int writerCount = 1; + private int writesPerWriter = 1; + private int listenersCount = 1; + private int contextReloaderThreads = 1; + private int waitMultiplicator = 1; + + public TestCase(Path tempDir) { + this.tempDir = tempDir; + } + + private void setUp() { + writerExecutor = Executors.newFixedThreadPool(10); + contextReloaderThreadPool = + Executors.newFixedThreadPool(listenersCount * contextReloaderThreads); + } + + private void tearDown() throws InterruptedException { + shutdownGracefully(writerExecutor); + shutdownGracefully(contextReloaderThreadPool); + } + + private void shutdownGracefully(ExecutorService executor) throws InterruptedException { + if (executor != null && !executor.isShutdown()) { + executor.shutdown(); + if (!executor.awaitTermination(5L * waitMultiplicator, TimeUnit.SECONDS)) { + executor.shutdownNow(); + } + } + } + + private void assertWithDetailedLogs(boolean condition, String message) { + try { + assertTrue(condition, message); + } catch (AssertionError e) { + detailedLogs.forEach(LOG::error); + throw e; + } + } + + public TestCase fileCount(int fileCount) { + this.fileCount = fileCount; + return this; + } + + public TestCase writerCount(int writerCount) { + this.writerCount = writerCount; + return this; + } + + public TestCase writesPerWriter(int writesPerWriter) { + this.writesPerWriter = writesPerWriter; + return this; + } + + public TestCase watcherCount(int watcherCount) { + this.listenersCount = watcherCount; + return this; + } + + public TestCase contextReloaderThreads(int 
contextReloaderThreads) { + this.contextReloaderThreads = contextReloaderThreads; + return this; + } + + public TestCase waitMultiplicator(int waitMultiplicator) { + this.waitMultiplicator = waitMultiplicator; + return this; + } + + public void run() throws Exception { + if (tempDir == null) { + throw new IllegalStateException("tempDir must be provided to use run() method"); + } + setUp(); + try { + executeTest(); + } finally { + tearDown(); + } + } + + private void executeTest() throws Exception { + detailedLogs.add( + String.format( + "Test started with parameters: fileCount=%d, writerCount=%d, writesPerWriter=%d, watcherCount=%d", + fileCount, writerCount, writesPerWriter, listenersCount)); + startLocalFSWatchService(); + LocalFSDirectoryWatcher localFsWatchSingleton = LocalFSWatchSingleton.getInstance(); + + Path[] testFiles = createFiles(); + + Map contextReloadCounts = new ConcurrentHashMap<>(); + Map contextReloaderAttemptCounts = new ConcurrentHashMap<>(); + + // Each watcher will have multiple context reloaders that all try to reload context in + // parallel + CountDownLatch contextReloaderThreadsRunning = + new CountDownLatch(listenersCount * contextReloaderThreads); + AtomicBoolean stopContextReloader = new AtomicBoolean(false); + + startWatchers( + contextReloadCounts, + contextReloaderAttemptCounts, + localFsWatchSingleton, + contextReloaderThreadsRunning, + stopContextReloader); + + detailedLogs.add( + "Waiting for ContextReloader threads to be running (contextReloaderThreadsRunning)"); + boolean contextReloaderThreadsStarted = + contextReloaderThreadsRunning.await(5L * waitMultiplicator, TimeUnit.SECONDS); + detailedLogs.add( + String.format( + "ContextReloader threads started result: %s (remaining count: %d)", + contextReloaderThreadsStarted, + contextReloaderThreadsRunning.getCount())); + assertWithDetailedLogs( + contextReloaderThreadsStarted, + "ContextReloader threads did not start within timeout"); + + CountDownLatch writersFinishedLatch = 
writeToFiles(testFiles); + + detailedLogs.add("Waiting for writers to finish (writersFinishedLatch)"); + boolean writersFinished = + writersFinishedLatch.await(10L * waitMultiplicator, TimeUnit.SECONDS); + detailedLogs.add( + String.format( + "Writers finished result: %s (remaining count: %d)", + writersFinished, writersFinishedLatch.getCount())); + assertWithDetailedLogs(writersFinished, "Writers did not complete within timeout"); + detailedLogs.add("All writers finished"); + + // Stop watcher threads + detailedLogs.add("Stopping context reloaders threads"); + stopContextReloader.set(true); + contextReloaderThreadPool.shutdown(); + detailedLogs.add("Waiting for context reloader thread pool termination"); + boolean watcherExecutorTerminated = + contextReloaderThreadPool.awaitTermination( + 10L * waitMultiplicator, TimeUnit.SECONDS); + detailedLogs.add( + String.format( + "Watcher executor terminated result: %s", watcherExecutorTerminated)); + assertWithDetailedLogs(watcherExecutorTerminated, "Watcher threads did not stop"); + detailedLogs.add("All watcher threads stopped"); + + int totalReloads = + contextReloadCounts.values().stream().mapToInt(AtomicInteger::get).sum(); + int totalAttempts = + contextReloaderAttemptCounts.values().stream() + .mapToInt(AtomicInteger::get) + .sum(); + int totalFileWrites = writerCount * writesPerWriter * fileCount; + + logDetailedResult( + totalFileWrites, + totalAttempts, + totalReloads, + contextReloadCounts, + contextReloaderAttemptCounts); + + assertWithDetailedLogs( + totalReloads > 0, + "Expected at least one context reload, but got " + totalReloads); + + // Verify that we have more attempts than successful reloads (proving concurrency + // control works) + assertWithDetailedLogs( + totalAttempts > totalReloads, + String.format( + "Expected more attempts (%d) than successful reloads (%d) due to concurrent access control", + totalAttempts, totalReloads)); + + // Verify that reloads are distributed (not all done by one watcher) 
+ long contextsWithReload = + contextReloadCounts.values().stream() + .mapToInt(AtomicInteger::get) + .filter(count -> count > 0) + .count(); + + // Verify that each watcher receives at least one modification event + for (int currentContextReloaderId = 0; + currentContextReloaderId < listenersCount; + currentContextReloaderId++) { + int reloadCount = contextReloadCounts.get(currentContextReloaderId).get(); + if (reloadCount == 0) { + detailedLogs.add( + String.format( + "ASSERTION FAILED: %d did not receive any modification events", + currentContextReloaderId)); + } + assertWithDetailedLogs( + reloadCount > 0, + String.format( + "Expected watcher with ID: %d to receive at least one modification event and perform a reload, but got %d reloads; total reloads - %s", + currentContextReloaderId, reloadCount, contextReloadCounts)); + } + + if (totalReloads >= 2) { + assertWithDetailedLogs( + contextsWithReload >= 1, + "Expected at least 1 watcher to perform reloads, but only " + + contextsWithReload + + " did"); + } + } + + private void startWatchers( + Map contextReloadCounts, + Map contextReloaderAttemptCounts, + LocalFSDirectoryWatcher localFsWatchSingleton, + CountDownLatch contextReloaderThreadsRunning, + AtomicBoolean stopContextReloader) + throws IOException { + for (int listenerId = 0; listenerId < listenersCount; listenerId++) { + final int curListenerId = listenerId; + contextReloadCounts.put(curListenerId, new AtomicInteger(0)); + contextReloaderAttemptCounts.put(curListenerId, new AtomicInteger(0)); + + LocalFSWatchServiceListener listener = + new LocalFSWatchServiceListener.AbstractLocalFSWatchServiceListener() { + public void onWatchStarted(Path relativePath) { + detailedLogs.add( + String.format( + "Listener-%d started to listen file modification: %s", + curListenerId, relativePath)); + } + + @Override + public void onFileOrDirectoryModified(Path relativePath) { + super.onFileOrDirectoryModified(relativePath); + detailedLogs.add( + String.format( + 
"Listener-%d detected file modification: %s; watchers - %s", + curListenerId, + relativePath, + localFsWatchSingleton.getWatchers())); + } + + @Override + public String toString() { + return String.format( + "TestLocalFSWatchServiceListener{id=%d, RELOAD_STATE=%s}", + curListenerId, getReloadStateReference().get()); + } + }; + + // Register the fsListener to watch the same path + localFsWatchSingleton.registerDirectory(new Path[] {tempDir}, listener); + detailedLogs.add( + String.format( + "Listener-%d is registered; current watchers is - %s", + curListenerId, localFsWatchSingleton.getWatchers())); + + // single listener, but we try to reload from different threads + // expect single reload per listener + startContextReloader( + contextReloaderThreadsRunning, + curListenerId, + stopContextReloader, + contextReloaderAttemptCounts, + listener, + contextReloadCounts); + } + } + + private void logDetailedResult( + int totalFileWrites, + int totalAttempts, + int totalReloads, + Map contextReloadCounts, + Map contextReloaderAttemptCounts) { + detailedLogs.add(String.format("Total file writes: %d", totalFileWrites)); + detailedLogs.add(String.format("Total context reload attempts: %d", totalAttempts)); + detailedLogs.add(String.format("Total successful context reloads: %d", totalReloads)); + } + + private CountDownLatch writeToFiles(Path[] testFiles) { + CountDownLatch writersFinishedLatch = new CountDownLatch(writerCount); + for (int writerId = 0; writerId < writerCount; writerId++) { + final int currentWriterId = writerId; + writerExecutor.submit( + runnableWriter(testFiles, currentWriterId, writersFinishedLatch)); + } + return writersFinishedLatch; + } + + private Runnable runnableWriter( + Path[] testFiles, int currentWriterId, CountDownLatch writersFinishedLatch) { + return () -> { + try { + detailedLogs.add(String.format("Writer %d starting", currentWriterId)); + for (int writeNum = 0; writeNum < writesPerWriter; writeNum++) { + writeFiles(testFiles, writeNum); + 
detailedLogs.add( + String.format( + "Writer %d completed batch %d (wrote to %d files)", + currentWriterId, writeNum, testFiles.length)); + Thread.sleep(100L); // Delay between batches to let watchers process + } + detailedLogs.add( + String.format("Writer %d finished all writes", currentWriterId)); + } catch (Exception e) { + detailedLogs.add( + String.format("Writer %d failed: %s", currentWriterId, e.getMessage())); + fail("Writer failed: " + e.getMessage()); + } finally { + writersFinishedLatch.countDown(); + } + }; + } + + private static void writeFiles(Path[] testFiles, int writeNum) { + final int finalWriteNum = writeNum; + Arrays.stream(testFiles) + .forEach( + f -> { + try { + Files.write( + f, + ("Write " + finalWriteNum + "\n").getBytes(), + StandardOpenOption.APPEND); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + private void startContextReloader( + CountDownLatch contextReloaderThreadsRunning, + int currentContextReloaderId, + AtomicBoolean stopContextReloader, + Map contextReloaderAttemptCounts, + LocalFSWatchServiceListener listener, + Map contextReloadCounts) { + // Start multiple threads for each fsListener to try reloading context in parallel + // Emulate different subsystem which would interact with a Service and perform the + // context reload if needed + for (int threadIdx = 0; threadIdx < contextReloaderThreads; threadIdx++) { + final int currentThreadIdx = threadIdx; + contextReloaderThreadPool.submit( + runnableForContextReloader( + contextReloaderThreadsRunning, + currentContextReloaderId, + stopContextReloader, + contextReloaderAttemptCounts, + listener, + contextReloadCounts, + currentThreadIdx)); + } + } + + private Runnable runnableForContextReloader( + CountDownLatch contextReloaderThreadsRunning, + int currentContextReloaderId, + AtomicBoolean stopContextReloader, + Map contextReloaderAttemptCounts, + LocalFSWatchServiceListener listener, + Map contextReloadCounts, + int currentThreadIdx) { + return () 
-> { + contextReloaderThreadsRunning.countDown(); + detailedLogs.add( + String.format( + "ContextReloader-%d Thread-%d starting", + currentContextReloaderId, currentThreadIdx)); + + AtomicInteger threadReloadAttempts = new AtomicInteger(0); + AtomicInteger extraIterations = new AtomicInteger(0); + while (!stopContextReloader.get() + || extraIterations.get() < EXTRA_WATCHER_ITERATIONS) { + try { + int currentAttempt = threadReloadAttempts.incrementAndGet(); + contextReloaderAttemptCounts + .get(currentContextReloaderId) + .incrementAndGet(); + + if (stopContextReloader.get()) { + // If stopContextReloader is set, start counting extra + // iterations + extraIterations.incrementAndGet(); + } + + // Context will be reloaded only if fsListener detected file + // state change and mark flag as DIRTY + boolean contextWasReloaded = + reloadContextIfNeeded( + currentContextReloaderId, + listener, + currentThreadIdx, + currentAttempt); + + if (contextWasReloaded) { + int reloadCount = + contextReloadCounts + .get(currentContextReloaderId) + .incrementAndGet(); + detailedLogs.add( + String.format( + "ContextReloader-%d Thread-%d successfully reloaded context (fsListener reload #%d, thread attempt #%d, watchers - %s)", + currentContextReloaderId, + currentThreadIdx, + reloadCount, + currentAttempt, + LocalFSWatchSingleton.getInstance().getWatchers())); + } + + Thread.sleep(10L); + } catch (InterruptedException e) { + detailedLogs.add( + String.format( + "ContextReloader-%d Thread-%d interrupted", + currentContextReloaderId, currentThreadIdx)); + Thread.currentThread().interrupt(); + break; + } catch (Exception e) { + detailedLogs.add( + String.format( + "ContextReloader-%d Thread-%d failed to reload context: %s", + currentContextReloaderId, + currentThreadIdx, + e.getMessage())); + } + } + detailedLogs.add( + String.format( + "ContextReloader-%d Thread-%d stopping (thread attempts: %d)", + currentContextReloaderId, + currentThreadIdx, + threadReloadAttempts.get())); + }; + } + 
+ private boolean reloadContextIfNeeded( + int currentContextReloaderId, + LocalFSWatchServiceListener listener, + int currentThreadIdx, + int currentAttempt) { + return listener.reloadContextIfNeeded( + () -> { + detailedLogs.add( + String.format( + "ContextReloader-%d Thread-%d performing context reload (attempt %d)", + currentContextReloaderId, + currentThreadIdx, + currentAttempt)); + try { + long sleepMs = + (10 + currentContextReloaderId * 2L + currentThreadIdx) % 50; + Thread.sleep(sleepMs); // Different reload times per thread + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + }); + } + + private Path[] createFiles() throws IOException { + Path[] testFiles = new Path[fileCount]; + for (int i = 0; i < fileCount; i++) { + testFiles[i] = tempDir.resolve("testfile_" + i + ".txt"); + Files.createFile(testFiles[i]); + } + detailedLogs.add("Files are created in " + tempDir); + return testFiles; + } + + private void startLocalFSWatchService() throws InterruptedException { + LocalFSWatchService watchService = new LocalFSWatchService(); + watchService.setDaemon(true); + watchService.start(); + while (!watchService.running.get()) { + Thread.sleep(100L); + } + detailedLogs.add("LocalFSWatchService started and running"); + } + } +} diff --git a/flink-rpc/flink-rpc-akka/src/main/java/org/apache/flink/runtime/rpc/pekko/CustomSSLEngineProvider.java b/flink-rpc/flink-rpc-akka/src/main/java/org/apache/flink/runtime/rpc/pekko/CustomSSLEngineProvider.java index 4b1d318e96c4d..f897305fd8f1f 100644 --- a/flink-rpc/flink-rpc-akka/src/main/java/org/apache/flink/runtime/rpc/pekko/CustomSSLEngineProvider.java +++ b/flink-rpc/flink-rpc-akka/src/main/java/org/apache/flink/runtime/rpc/pekko/CustomSSLEngineProvider.java @@ -17,6 +17,7 @@ package org.apache.flink.runtime.rpc.pekko; +import org.apache.flink.core.security.watch.LocalFSDirectoryWatcher; import org.apache.flink.core.security.watch.LocalFSWatchSingleton; import 
org.apache.flink.shaded.netty4.io.netty.handler.ssl.util.FingerprintTrustManagerFactory; @@ -46,19 +47,20 @@ public class CustomSSLEngineProvider implements SSLEngineProvider { private final SSLContextLoader sslContextLoader; public CustomSSLEngineProvider(ActorSystem system) throws IOException { - final Config securityConfig = + final Config pekkoSecurityConfig = system.settings().config().getConfig("pekko.remote.classic.netty.ssl.security"); - sslTrustStore = securityConfig.getString("trust-store"); - String sslKeyStore = securityConfig.getString("key-store"); - sslEnabledAlgorithms = securityConfig.getStringList("enabled-algorithms"); - sslProtocol = securityConfig.getString("protocol"); - sslRequireMutualAuthentication = securityConfig.getBoolean("require-mutual-authentication"); - Boolean sslEnabledCertReload = securityConfig.getBoolean("enabled-cert-reload"); - - sslContextLoader = new SSLContextLoader(sslTrustStore, sslProtocol, securityConfig); + sslTrustStore = pekkoSecurityConfig.getString("trust-store"); + String sslKeyStore = pekkoSecurityConfig.getString("key-store"); + sslEnabledAlgorithms = pekkoSecurityConfig.getStringList("enabled-algorithms"); + sslProtocol = pekkoSecurityConfig.getString("protocol"); + sslRequireMutualAuthentication = + pekkoSecurityConfig.getBoolean("require-mutual-authentication"); + Boolean sslEnabledCertReload = pekkoSecurityConfig.getBoolean("enabled-cert-reload"); + + sslContextLoader = new SSLContextLoader(sslTrustStore, sslProtocol, pekkoSecurityConfig); if (sslEnabledCertReload) { - LocalFSWatchSingleton localFSWatchSingleton = LocalFSWatchSingleton.getInstance(); - localFSWatchSingleton.registerPath( + LocalFSDirectoryWatcher localFSWatchSingleton = LocalFSWatchSingleton.getInstance(); + localFSWatchSingleton.registerDirectory( new Path[] { Path.of(sslTrustStore).getParent(), Path.of(sslKeyStore).getParent() }, diff --git a/flink-rpc/flink-rpc-akka/src/main/java/org/apache/flink/runtime/rpc/pekko/PekkoUtils.java 
b/flink-rpc/flink-rpc-akka/src/main/java/org/apache/flink/runtime/rpc/pekko/PekkoUtils.java index aef433b7affd0..fa4623e84b238 100644 --- a/flink-rpc/flink-rpc-akka/src/main/java/org/apache/flink/runtime/rpc/pekko/PekkoUtils.java +++ b/flink-rpc/flink-rpc-akka/src/main/java/org/apache/flink/runtime/rpc/pekko/PekkoUtils.java @@ -359,7 +359,7 @@ private static void addSslRemoteConfig( Arrays.stream(sslAlgorithmsString.split(",")) .collect(Collectors.joining(",", "[", "]")); - final boolean enabledCertReloadConfig = SecurityOptions.isReloadCertificate(configuration); + final boolean enabledCertReloadConfig = SecurityOptions.isCertificateReloadEnabled(configuration); final String enabledCertReload = booleanToOnOrOff(enabledCertReloadConfig); final String sslEngineProviderName = CustomSSLEngineProvider.class.getCanonicalName(); diff --git a/flink-rpc/flink-rpc-akka/src/main/java/org/apache/flink/runtime/rpc/pekko/SSLContextLoader.java b/flink-rpc/flink-rpc-akka/src/main/java/org/apache/flink/runtime/rpc/pekko/SSLContextLoader.java index e56bb534402d1..874a589ae44c3 100644 --- a/flink-rpc/flink-rpc-akka/src/main/java/org/apache/flink/runtime/rpc/pekko/SSLContextLoader.java +++ b/flink-rpc/flink-rpc-akka/src/main/java/org/apache/flink/runtime/rpc/pekko/SSLContextLoader.java @@ -37,7 +37,6 @@ import java.io.IOException; import java.io.InputStream; import java.nio.file.Files; -import java.nio.file.Path; import java.nio.file.Paths; import java.security.GeneralSecurityException; import java.security.KeyManagementException; @@ -47,39 +46,37 @@ import java.security.SecureRandom; import java.security.UnrecoverableKeyException; import java.util.List; -import java.util.concurrent.atomic.AtomicBoolean; -public class SSLContextLoader implements LocalFSWatchServiceListener { +public class SSLContextLoader + extends LocalFSWatchServiceListener.AbstractLocalFSWatchServiceListener { private static final Logger LOG = LoggerFactory.getLogger(SSLContextLoader.class); private final String 
sslTrustStore; private final String sslTrustStorePassword; + private final String sslKeyStore; + private final String sslKeyStorePassword; private final List sslCertFingerprints; private final String sslKeyStoreType; private final String sslTrustStoreType; private final String sslProtocol; - private final String sslKeyStore; - private final String sslKeyStorePassword; private final String sslKeyPassword; private final String sslRandomNumberGenerator; - private final AtomicBoolean toReload = new AtomicBoolean(false); - private volatile SSLContext sslContext; - public SSLContextLoader(String sslTrustStore, String sslProtocol, Config securityConfig) { + public SSLContextLoader(String sslTrustStore, String sslProtocol, Config pekkoSecurityConfig) { this.sslTrustStore = sslTrustStore; this.sslProtocol = sslProtocol; - this.sslTrustStorePassword = securityConfig.getString("trust-store-password"); - this.sslCertFingerprints = securityConfig.getStringList("cert-fingerprints"); - this.sslKeyStoreType = securityConfig.getString("key-store-type"); - this.sslTrustStoreType = securityConfig.getString("trust-store-type"); - this.sslKeyStore = securityConfig.getString("key-store"); - sslKeyStorePassword = securityConfig.getString("key-store-password"); - sslKeyPassword = securityConfig.getString("key-password"); - sslRandomNumberGenerator = securityConfig.getString("random-number-generator"); + this.sslTrustStorePassword = pekkoSecurityConfig.getString("trust-store-password"); + this.sslCertFingerprints = pekkoSecurityConfig.getStringList("cert-fingerprints"); + this.sslKeyStoreType = pekkoSecurityConfig.getString("key-store-type"); + this.sslTrustStoreType = pekkoSecurityConfig.getString("trust-store-type"); + this.sslKeyStore = pekkoSecurityConfig.getString("key-store"); + this.sslKeyStorePassword = pekkoSecurityConfig.getString("key-store-password"); + this.sslKeyPassword = pekkoSecurityConfig.getString("key-password"); + this.sslRandomNumberGenerator = 
pekkoSecurityConfig.getString("random-number-generator"); loadSSLContext(); } @@ -102,7 +99,7 @@ void loadSSLContext() { } public SSLEngine createSSLEngine() { - reloadContextIfNeeded(); + reloadContextIfNeeded(this::loadSSLContext); return sslContext.createSSLEngine(); } @@ -117,17 +114,6 @@ public SecureRandom createSecureRandom() throws NoSuchAlgorithmException { return rng; } - @Override - public void onFileOrDirectoryModified(Path relativePath) { - toReload.set(true); - } - - private synchronized void reloadContextIfNeeded() { - if (toReload.compareAndSet(true, false)) { - loadSSLContext(); - } - } - /** Subclass may override to customize `KeyManager`. */ private KeyManager[] keyManagers() throws NoSuchAlgorithmException, UnrecoverableKeyException, KeyStoreException { diff --git a/flink-rpc/flink-rpc-akka/src/test/java/org/apache/flink/runtime/rpc/pekko/SSLContextLoaderTest.java b/flink-rpc/flink-rpc-akka/src/test/java/org/apache/flink/runtime/rpc/pekko/SSLContextLoaderTest.java new file mode 100644 index 0000000000000..8e1993f5240d1 --- /dev/null +++ b/flink-rpc/flink-rpc-akka/src/test/java/org/apache/flink/runtime/rpc/pekko/SSLContextLoaderTest.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.rpc.pekko; + +import com.typesafe.config.Config; +import com.typesafe.config.ConfigFactory; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class SSLContextLoaderTest { + + @Test + public void testCreaseSSLContextLoaderWithUnexistedCertificates() throws Exception { + final Config pekkoSecurityConfig = pekkoConfig(""); + String sslTrustStore = pekkoSecurityConfig.getString("trust-store"); + + assertThatThrownBy( + () -> + new SSLContextLoader( + sslTrustStore, "sslProtocol", pekkoSecurityConfig)) + .isInstanceOf(RuntimeException.class) + .hasMessage("Cannot load SSL context"); + } + + @Test + public void testCreaseSSLContextLoaderWithWrongPekkoConfig() throws Exception { + final Config pekkoSecurityConfig = pekkoConfig("wrong"); + String sslTrustStore = pekkoSecurityConfig.getString("trust-store"); + + assertThatThrownBy( + () -> + new SSLContextLoader( + sslTrustStore, "sslProtocol", pekkoSecurityConfig)) + .isInstanceOf(RuntimeException.class) + .hasMessage( + "hardcoded value: No configuration setting found for key 'trust-store-password'"); + } + + private static Config pekkoConfig(String prefix) { + return ConfigFactory.parseMap( + Map.of( + "trust-store", + "non-trust-store", + prefix + "trust-store-password", + "ts-pwd-123", + prefix + "cert-fingerprints", + List.of("F1:INGER:PRINT:01", "F2:INGER:PRINT:02"), + prefix + "key-store-type", + "JKS", + prefix + "trust-store-type", + "JKS", + prefix + "key-store", + "/tmp/keystore.jks", + prefix + "key-store-password", + "ks-pwd-456", + prefix + "key-password", + "key-pwd-789", + prefix + "random-number-generator", + "SHA1PRNG")); + } +} diff --git a/flink-runtime/pom.xml b/flink-runtime/pom.xml index e0f11e4f2deaa..61a2a5beac4ce 100644 --- a/flink-runtime/pom.xml 
+++ b/flink-runtime/pom.xml @@ -442,7 +442,7 @@ under the License. - + diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobServer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobServer.java index 52b6fd41cdb19..338bbb2ccd556 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobServer.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobServer.java @@ -24,6 +24,7 @@ import org.apache.flink.configuration.BlobServerOptions; import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.SecurityOptions; +import org.apache.flink.core.security.watch.LocalFSDirectoryWatcher; import org.apache.flink.core.security.watch.LocalFSWatchSingleton; import org.apache.flink.runtime.dispatcher.cleanup.GloballyCleanableResource; import org.apache.flink.runtime.dispatcher.cleanup.LocallyCleanableResource; @@ -138,6 +139,8 @@ public class BlobServer extends Thread /** Timer task to execute the cleanup at regular intervals. 
*/ private final Timer cleanupTimer; + private final boolean socketRecreationIsNeeded; + @VisibleForTesting public BlobServer(Configuration config, File storageDir, BlobStore blobStore) throws IOException { @@ -198,10 +201,13 @@ public BlobServer(Configuration config, Reference storageDir, BlobStore bl // ----------------------- start the server ------------------- + socketRecreationIsNeeded = + SecurityOptions.isInternalSSLEnabled(config) + && SecurityOptions.isCertificateReloadEnabled(config); blobServerSocket = new BlobServerSocket(config, backlog, maxConnections); if (SecurityOptions.isInternalSSLEnabled(config) && config.get(BlobServerOptions.SSL_ENABLED) - && SecurityOptions.isReloadCertificate(config)) { + && SecurityOptions.isCertificateReloadEnabled(config)) { String keystoreFilePath = config.get( SecurityOptions.SSL_INTERNAL_KEYSTORE, @@ -211,8 +217,8 @@ public BlobServer(Configuration config, Reference storageDir, BlobStore bl SecurityOptions.SSL_INTERNAL_TRUSTSTORE, config.get(SecurityOptions.SSL_TRUSTSTORE)); - LocalFSWatchSingleton localFSWatchSingleton = LocalFSWatchSingleton.getInstance(); - localFSWatchSingleton.registerPath( + LocalFSDirectoryWatcher localFSWatchSingleton = LocalFSWatchSingleton.getInstance(); + localFSWatchSingleton.registerDirectory( new Path[] { Path.of(keystoreFilePath).getParent(), Path.of(truststoreFilePath).getParent() @@ -293,7 +299,7 @@ ReadWriteLock getReadWriteLock() { public void run() { try { while (!this.shutdownRequested.get()) { - if (this.blobServerSocket.reloadContextIfNeeded()) { + if (socketRecreationIsNeeded && this.blobServerSocket.reloadContextIfNeeded()) { closeActiveConnections(); } BlobServerConnection conn = @@ -1026,10 +1032,17 @@ public boolean isShutdown() { } /** Access to the server socket, for testing. */ + @VisibleForTesting ServerSocket getServerSocket() { return this.blobServerSocket.getServerSocket(); } + /** Access to the reload counter, for testing. 
*/ + @VisibleForTesting + int getReloadCounter() { + return this.blobServerSocket.getReloadCounter(); + } + void unregisterConnection(BlobServerConnection conn) { synchronized (activeConnections) { activeConnections.remove(conn); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobServerSocket.java b/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobServerSocket.java index 259c73390a9cc..fbff4ffbfda83 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobServerSocket.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobServerSocket.java @@ -34,13 +34,13 @@ import java.io.IOException; import java.net.InetAddress; import java.net.ServerSocket; -import java.nio.file.Path; import java.util.Collections; import java.util.Iterator; import java.util.concurrent.atomic.AtomicBoolean; /** This class implements socket management (open, close) for the BLOB server. */ -public class BlobServerSocket implements LocalFSWatchServiceListener { +public class BlobServerSocket + extends LocalFSWatchServiceListener.AbstractLocalFSWatchServiceListener { private static final Logger LOG = LoggerFactory.getLogger(BlobServerSocket.class); @@ -49,8 +49,9 @@ public class BlobServerSocket implements LocalFSWatchServiceListener { private final String serverPortRange; private ServerSocket serverSocket; private final int maxConnections; + private int reloadCounter = 0; + private final AtomicBoolean firstCreation = new AtomicBoolean(true); - private final AtomicBoolean toReload = new AtomicBoolean(false); public BlobServerSocket(Configuration config, int backlog, int maxConnections) throws IOException { @@ -62,27 +63,26 @@ public BlobServerSocket(Configuration config, int backlog, int maxConnections) createSocket(); } - @Override - public void onFileOrDirectoryModified(Path relativePath) { - toReload.set(true); - } - public ServerSocket getServerSocket() { return serverSocket; } + /** + * Recreates a socket with a new ssl 
certificates. + * + * @return true if socket was recreated, false otherwise + */ public synchronized boolean reloadContextIfNeeded() { - if (toReload.compareAndSet(true, false)) { - try { - close(); - createSocket(); - return true; - } catch (Exception e) { - LOG.warn("Failed to reload SSL context", e); - toReload.set(true); - } - } - return false; + return reloadContextIfNeeded(this::reloadContext); + } + + private void reloadContext() throws IOException { + LOG.info("Reloading blob server context."); + close(); + // in case of SSL reload, at this moment we cannot serve requests (we hope clients would + // retry) + createSocket(); + reloadCounter++; } private synchronized void createSocket() throws IOException { @@ -142,6 +142,10 @@ public int getPort() { return serverSocket.getLocalPort(); } + public int getReloadCounter() { + return reloadCounter; + } + public void close() throws IOException { if (serverSocket != null) { close(serverSocket); @@ -152,7 +156,7 @@ private void close(ServerSocket serverSocketToClose) throws IOException { if (LOG.isInfoEnabled()) { if (serverSocketToClose != null) { LOG.info( - "Stopped BLOB server at {}:{}", + "Stopped BLOB server socket at {}:{}", serverSocketToClose.getInetAddress().getHostAddress(), getPort()); } else { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/entrypoint/ClusterEntrypoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/entrypoint/ClusterEntrypoint.java index 9d6f91d2900c0..125fe172396f2 100755 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/entrypoint/ClusterEntrypoint.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/entrypoint/ClusterEntrypoint.java @@ -351,7 +351,7 @@ protected void initializeServices(Configuration configuration, PluginManager plu DeterminismEnvelope.nondeterministicValue( ResourceID.generate())); - if (SecurityOptions.isReloadCertificate(configuration)) { + if (SecurityOptions.isCertificateReloadEnabled(configuration)) { 
LOG.debug("Initialize local file system watch service for certificate reloading."); LocalFSWatchService localFSWatchService = new LocalFSWatchService(); localFSWatchService.start(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/net/ReloadableSslContext.java b/flink-runtime/src/main/java/org/apache/flink/runtime/net/ReloadableSslContext.java index f85dc625c7db1..44f4117bf9c9d 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/net/ReloadableSslContext.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/net/ReloadableSslContext.java @@ -37,11 +37,10 @@ import javax.net.ssl.SSLSessionContext; import javax.net.ssl.TrustManagerFactory; -import java.nio.file.Path; import java.util.Arrays; import java.util.List; import java.util.Optional; -import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import static org.apache.flink.runtime.net.SSLUtils.getEnabledCipherSuites; import static org.apache.flink.runtime.net.SSLUtils.getEnabledProtocols; @@ -59,7 +58,8 @@ public class ReloadableSslContext extends SslContext implements LocalFSWatchServ protected final SslProvider provider; protected volatile SslContext sslContext; - private final AtomicBoolean toReload = new AtomicBoolean(false); + private final AtomicReference reloadState = + new AtomicReference<>(ReloadState.CLEAN); public ReloadableSslContext( Configuration config, boolean clientMode, ClientAuth clientAuth, SslProvider provider) @@ -72,60 +72,49 @@ public ReloadableSslContext( } public SSLContext getSSLContext() { - reloadContextIfNeeded(); + reloadContextIfNeeded(this::loadContext); return ((JdkSslContext) this.sslContext).context(); } @Override public boolean isClient() { - reloadContextIfNeeded(); + reloadContextIfNeeded(this::loadContext); return sslContext.isClient(); } @Override public List cipherSuites() { - reloadContextIfNeeded(); + reloadContextIfNeeded(this::loadContext); return sslContext.cipherSuites(); } 
@Override public ApplicationProtocolNegotiator applicationProtocolNegotiator() { - reloadContextIfNeeded(); + reloadContextIfNeeded(this::loadContext); return sslContext.applicationProtocolNegotiator(); } @Override public SSLEngine newEngine(ByteBufAllocator byteBufAllocator) { - reloadContextIfNeeded(); + reloadContextIfNeeded(this::loadContext); return sslContext.newEngine(byteBufAllocator); } @Override public SSLEngine newEngine(ByteBufAllocator byteBufAllocator, String s, int i) { - reloadContextIfNeeded(); + reloadContextIfNeeded(this::loadContext); return sslContext.newEngine(byteBufAllocator, s, i); } @Override public SSLSessionContext sessionContext() { - reloadContextIfNeeded(); + reloadContextIfNeeded(this::loadContext); return sslContext.sessionContext(); } @Override - public void onFileOrDirectoryModified(Path relativePath) { - toReload.set(true); - } - - protected synchronized void reloadContextIfNeeded() { - if (toReload.compareAndSet(true, false)) { - try { - loadContext(); - } catch (Exception e) { - LOG.warn("Failed to reload SSL context", e); - toReload.set(true); - } - } + public AtomicReference getReloadStateReference() { + return reloadState; } protected void loadContext() throws Exception { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/net/SSLUtils.java b/flink-runtime/src/main/java/org/apache/flink/runtime/net/SSLUtils.java index eba3aa0d5ad4b..2beea6790f78b 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/net/SSLUtils.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/net/SSLUtils.java @@ -23,6 +23,7 @@ import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.IllegalConfigurationException; import org.apache.flink.configuration.SecurityOptions; +import org.apache.flink.core.security.watch.LocalFSDirectoryWatcher; import org.apache.flink.core.security.watch.LocalFSWatchSingleton; import org.apache.flink.runtime.io.network.netty.SSLHandlerFactory; import 
org.apache.flink.util.StringUtils; @@ -367,7 +368,7 @@ private static SslContext createInternalNettySSLContext( ReloadableJdkSslContext reloadableJdkSslContext = new ReloadableJdkSslContext(config, clientMode, provider); - if (SecurityOptions.isReloadCertificate(config) && watchForCertificateChange) { + if (SecurityOptions.isCertificateReloadEnabled(config) && watchForCertificateChange) { HashSet certificatePaths = new HashSet<>(); certificatePaths.add( Path.of( @@ -385,8 +386,8 @@ private static SslContext createInternalNettySSLContext( } Path[] pathsToWatch = new Path[certificatePaths.size()]; certificatePaths.toArray(pathsToWatch); - LocalFSWatchSingleton localFSWatchSingleton = LocalFSWatchSingleton.getInstance(); - localFSWatchSingleton.registerPath(pathsToWatch, reloadableJdkSslContext); + LocalFSDirectoryWatcher localFSWatchSingleton = LocalFSWatchSingleton.getInstance(); + localFSWatchSingleton.registerDirectory(pathsToWatch, reloadableJdkSslContext); } return reloadableJdkSslContext; @@ -434,7 +435,7 @@ public static SslContext createRestNettySSLContext( ReloadableSslContext reloadableSslContext = new ReloadableSslContext(config, clientMode, clientAuth, provider); - if (SecurityOptions.isReloadCertificate(config)) { + if (SecurityOptions.isCertificateReloadEnabled(config)) { HashSet certificatePaths = new HashSet<>(); String keystoreFilePath = config.get( @@ -453,8 +454,8 @@ public static SslContext createRestNettySSLContext( Path[] pathsToWatch = new Path[certificatePaths.size()]; certificatePaths.toArray(pathsToWatch); - LocalFSWatchSingleton localFSWatchSingleton = LocalFSWatchSingleton.getInstance(); - localFSWatchSingleton.registerPath(pathsToWatch, reloadableSslContext); + LocalFSDirectoryWatcher localFSWatchSingleton = LocalFSWatchSingleton.getInstance(); + localFSWatchSingleton.registerDirectory(pathsToWatch, reloadableSslContext); } return reloadableSslContext; diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskManagerRunner.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskManagerRunner.java index b15d8622a3163..8577f47e6b82e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskManagerRunner.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskManagerRunner.java @@ -193,7 +193,7 @@ private void startTaskManagerRunnerServices() throws Exception { Hardware.getNumberCPUCores(), new ExecutorThreadFactory("taskmanager-future")); - if (SecurityOptions.isReloadCertificate(configuration)) { + if (SecurityOptions.isCertificateReloadEnabled(configuration)) { LOG.debug("Initialize local file system watch service for certificate reloading."); LocalFSWatchService localFSWatchService = new LocalFSWatchService(); localFSWatchService.start(); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobClientSslReloadTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobClientSslReloadTest.java new file mode 100644 index 0000000000000..8f247a059bec3 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobClientSslReloadTest.java @@ -0,0 +1,255 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.blob; + +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.SecurityOptions; +import org.apache.flink.core.security.watch.LocalFSWatchService; +import org.apache.flink.core.security.watch.LocalFSWatchServiceListener; +import org.apache.flink.core.security.watch.LocalFSWatchSingleton; +import org.apache.flink.runtime.net.SSLUtilsTest; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.nio.file.WatchKey; +import java.nio.file.WatchService; +import java.time.Duration; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Stream; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively; + +public class BlobClientSslReloadTest { + + private static final Logger LOG = LoggerFactory.getLogger(BlobClientSslReloadTest.class); + + private static final Duration TIMEOUT = Duration.ofSeconds(20); + + private static BlobServer blobReloadableSslServer; + + private static Configuration reloadableSslClientConfig; + + private static TestLocalFSWatchService watchService; + + @TempDir static java.nio.file.Path tempDir; + + @BeforeAll + static void startReloadableSSLServer() throws IOException { + Configuration config = + SSLUtilsTest.createInternalSslConfigWithKeyAndTrustStores( + SecurityOptions.SSL_PROVIDER.defaultValue()); + config.set(SecurityOptions.SSL_RELOAD, true); + + blobReloadableSslServer = + 
+ TestingBlobUtils.createServer(tempDir.resolve("realoadable_ssl"), config); + blobReloadableSslServer.start(); + + reloadableSslClientConfig = config; + } + + @BeforeAll + static void startLocalFSWatchService() throws InterruptedException { + watchService = new TestLocalFSWatchService(); + watchService.start(); + } + + /** Shuts the BLOB server down. */ + @AfterAll + static void stopServers() throws IOException { + if (blobReloadableSslServer != null) { + blobReloadableSslServer.close(); + } + } + + /** Verify that the blob server registers watchers to watch for ssl certificate changes. */ + @Test + public void testWatchersRegistered() throws Exception { + LocalFSWatchSingleton watchSingleton = + (LocalFSWatchSingleton) LocalFSWatchSingleton.getInstance(); + assertThat(watchSingleton.getWatchers().size()).isGreaterThan(0); + } + + static Stream sslReloadTestParameters() { + return Stream.of( + Arguments.of(true, true, "both keystore and truststore"), + Arguments.of(true, false, "keystore only"), + Arguments.of(false, true, "truststore only")); + } + + /** Verify ssl client to ssl server upload with different certificate modification scenarios. 
*/ + @ParameterizedTest + @MethodSource("sslReloadTestParameters") + public void testUploadJarFilesHelperReloadable( + boolean touchKeyStore, boolean touchTrustStore, String description) throws Exception { + int initialReloadCounter = prepare(); + int watchServiceReloadCounter = watchService.getServerSideReloadCounter(); + + LOG.debug( + "Testing SSL reload scenario: {}; initialReloadCounter={}", + description, + initialReloadCounter); + + // Touch the specified certificate files + if (touchKeyStore) { + SSLUtilsTest.touchKeyStore(); + } + if (touchTrustStore) { + SSLUtilsTest.touchTrustStore(); + } + + LOG.debug("Modified SSL certificate files for: {}", description); + + waitServerSideWatchEventReceived(watchServiceReloadCounter); + assertSslReloaded(initialReloadCounter); + } + + private static void assertSslReloaded(int initialReloadCounter) throws Exception { + LOG.debug("Initiating another file upload, which should lead to the context reload"); + + // Retry file upload with exponential backoff to handle SSL reload timing issues + uploadJarFileWithRetry(blobReloadableSslServer, reloadableSslClientConfig, 3, 100); + + // wait when server reloads certificates + assertTimeoutPreemptively( + TIMEOUT, + () -> { + while (blobReloadableSslServer.getReloadCounter() == initialReloadCounter) { + Thread.sleep(100); + } + assertThat(blobReloadableSslServer.getReloadCounter()) + .withFailMessage( + "Expect ssl changes to be reloaded for BlobServer in " + + TIMEOUT) + .isGreaterThan(initialReloadCounter); + }); + } + + private static void uploadJarFileWithRetry( + BlobServer server, Configuration config, int maxRetries, long baseDelayMs) + throws Exception { + Exception lastException = null; + + for (int attempt = 0; attempt < maxRetries; attempt++) { + try { + LOG.debug("Upload attempt {} of {}", attempt, maxRetries); + BlobClientTest.uploadJarFile(server, config); + LOG.debug("Upload successful on attempt {}", attempt); + return; // Success, exit retry loop + } catch 
(Exception e) { + lastException = e; + String errorMsg = e.getMessage(); + + // Check if this is a retryable SSL/connection error + if (isRetryableError(e)) { + long delayMs = baseDelayMs * (1L << attempt); // Exponential backoff, 0-based + LOG.debug( + "Upload failed on attempt {} with retryable error: {}. Retrying in {}ms", + attempt, + errorMsg, + delayMs); + Thread.sleep(delayMs); + } else { + LOG.warn( + "Upload failed on attempt {} with non-retryable error or max retries reached: {}", + attempt, + errorMsg); + break; + } + } + } + + // If we get here, all retries failed + throw new Exception("File upload failed after " + maxRetries + " attempts", lastException); + } + + private static boolean isRetryableError(Exception e) { + String message = e.getMessage(); + if (message == null) { + return false; + } + + // Check for SSL reload related errors + return message.contains("Broken pipe") + || message.contains("Connection reset") + || message.contains("Connection refused") + || message.contains("Socket closed") + || message.contains("SSL handshake") + || e instanceof java.net.SocketException + || e instanceof java.io.IOException + && (message.contains("PUT operation failed") + || message.contains("Connection or inbound has closed")); + } + + private static void waitServerSideWatchEventReceived(int watchServiceReloadCounter) { + assertTimeoutPreemptively( + TIMEOUT, + () -> { + while (watchService.getServerSideReloadCounter() == watchServiceReloadCounter) { + Thread.sleep(100); + } + assertThat(watchService.getServerSideReloadCounter()) + .withFailMessage( + "Expect sll changes by FileWatcher to be reloaded in " + + TIMEOUT) + .isGreaterThan(watchServiceReloadCounter); + }); + LOG.debug("SSL file modifications are seen"); + } + + private static int prepare() throws Exception { + LOG.debug("Initial upload of jar files"); + + BlobClientTest.uploadJarFile(blobReloadableSslServer, reloadableSslClientConfig); + + return blobReloadableSslServer.getReloadCounter(); + } + + 
private static class TestLocalFSWatchService extends LocalFSWatchService { + + AtomicInteger serverSideReloadCounter = new AtomicInteger(0); + + TestLocalFSWatchService() { + super(Duration.ofMillis(0)); + } + + protected void processWatchKey( + Map.Entry entry, WatchKey watchKey) { + super.processWatchKey(entry, watchKey); + LOG.debug("Watch key has been processed for {}", entry); + if (entry.getValue() instanceof BlobServerSocket) { + serverSideReloadCounter.incrementAndGet(); + } + } + + public int getServerSideReloadCounter() { + return serverSideReloadCounter.get(); + } + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobClientSslTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobClientSslTest.java index 8cdce6e95b7a3..5a25307c139fe 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobClientSslTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/blob/BlobClientSslTest.java @@ -21,6 +21,7 @@ import org.apache.flink.configuration.BlobServerOptions; import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.SecurityOptions; +import org.apache.flink.core.security.watch.LocalFSWatchSingleton; import org.apache.flink.runtime.net.SSLUtilsTest; import org.junit.jupiter.api.AfterAll; @@ -29,6 +30,7 @@ import java.io.IOException; +import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; /** This class contains unit tests for the {@link BlobClient} with ssl enabled. */ @@ -157,4 +159,12 @@ public void testNonSSLConnection3() throws Exception { public void testNonSSLConnection4() throws Exception { uploadJarFile(blobNonSslServer, nonSslClientConfig); } + + /** Verify that blob server doesn't run watchers to watch the ssl certificates change. 
*/ + @Test + public void testNoWatchersRegistered() throws Exception { + LocalFSWatchSingleton watchSingleton = + (LocalFSWatchSingleton) LocalFSWatchSingleton.getInstance(); + assertThat(watchSingleton.getWatchers().size()).isEqualTo(0); + } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/net/SSLUtilsTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/net/SSLUtilsTest.java index 183665fcb6a3c..ceb4263b01f97 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/net/SSLUtilsTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/net/SSLUtilsTest.java @@ -36,9 +36,12 @@ import javax.net.ssl.SSLServerSocket; import java.io.File; +import java.io.IOException; import java.io.InputStream; import java.net.ServerSocket; import java.nio.file.Files; +import java.nio.file.Paths; +import java.nio.file.attribute.FileTime; import java.security.KeyStore; import java.security.KeyStoreException; import java.security.MessageDigest; @@ -515,6 +518,17 @@ public static String getRestCertificateFingerprint( return getSha1Fingerprint(keyStore.getCertificate(certificateAlias)); } + public static void touchKeyStore() throws IOException { + FileTime newTime = FileTime.fromMillis(System.currentTimeMillis()); + Files.setLastModifiedTime(Paths.get(KEY_STORE_PATH), newTime); + Files.setLastModifiedTime(Paths.get(TRUST_STORE_PATH), newTime); + } + + public static void touchTrustStore() throws IOException { + FileTime newTime = FileTime.fromMillis(System.currentTimeMillis()); + Files.setLastModifiedTime(Paths.get(TRUST_STORE_PATH), newTime); + } + private static void addSslProviderConfig(Configuration config, String sslProvider) { if (sslProvider.equalsIgnoreCase("OPENSSL")) { OpenSsl.ensureAvailability(); From 4dedf3b0b390b438b0b217f7a2c52e0f50ede182 Mon Sep 17 00:00:00 2001 From: "oleksandr.nitavskyi" Date: Thu, 2 Oct 2025 15:05:23 +0200 Subject: [PATCH 3/4] [FLINK-37504] Develop an Integration test for SSL certificate reload * test Blob 
server certificate exposure and reload in case of different ssl configs * check files are not read after certificate reload when reload/ssl is disabled --- flink-end-to-end-tests/flink-ssl-test/pom.xml | 65 +++ .../flink/ssl/tests/NoSslNoReloadIT.java | 72 +++ .../ssl/tests/SslEndToEndITCaseBase.java | 382 +++++++++++++ .../apache/flink/ssl/tests/SslNoReloadIT.java | 69 +++ .../apache/flink/ssl/tests/SslTestUtils.java | 541 ++++++++++++++++++ .../flink/ssl/tests/SslWithReloadIT.java | 87 +++ .../src/test/resources/log4j2-test.properties | 28 + flink-end-to-end-tests/pom.xml | 1 + .../flink/runtime/rpc/pekko/PekkoUtils.java | 3 +- 9 files changed, 1247 insertions(+), 1 deletion(-) create mode 100644 flink-end-to-end-tests/flink-ssl-test/pom.xml create mode 100644 flink-end-to-end-tests/flink-ssl-test/src/test/java/org/apache/flink/ssl/tests/NoSslNoReloadIT.java create mode 100644 flink-end-to-end-tests/flink-ssl-test/src/test/java/org/apache/flink/ssl/tests/SslEndToEndITCaseBase.java create mode 100644 flink-end-to-end-tests/flink-ssl-test/src/test/java/org/apache/flink/ssl/tests/SslNoReloadIT.java create mode 100644 flink-end-to-end-tests/flink-ssl-test/src/test/java/org/apache/flink/ssl/tests/SslTestUtils.java create mode 100644 flink-end-to-end-tests/flink-ssl-test/src/test/java/org/apache/flink/ssl/tests/SslWithReloadIT.java create mode 100644 flink-end-to-end-tests/flink-ssl-test/src/test/resources/log4j2-test.properties diff --git a/flink-end-to-end-tests/flink-ssl-test/pom.xml b/flink-end-to-end-tests/flink-ssl-test/pom.xml new file mode 100644 index 0000000000000..03548509143be --- /dev/null +++ b/flink-end-to-end-tests/flink-ssl-test/pom.xml @@ -0,0 +1,65 @@ + + + + + flink-end-to-end-tests + org.apache.flink + 2.2-SNAPSHOT + + + 4.0.0 + + flink-ssl-test + Flink : E2E Tests : SSL Test + + + + org.apache.flink + flink-core + ${project.version} + test + + + org.apache.flink + flink-runtime + ${project.version} + test + + + org.apache.flink + flink-runtime + 
${project.version} + test-jar + test + + + org.apache.flink + flink-end-to-end-tests-common + ${project.version} + test + + + org.apache.flink + flink-test-utils-junit + + + + \ No newline at end of file diff --git a/flink-end-to-end-tests/flink-ssl-test/src/test/java/org/apache/flink/ssl/tests/NoSslNoReloadIT.java b/flink-end-to-end-tests/flink-ssl-test/src/test/java/org/apache/flink/ssl/tests/NoSslNoReloadIT.java new file mode 100644 index 0000000000000..579cc91b6739a --- /dev/null +++ b/flink-end-to-end-tests/flink-ssl-test/src/test/java/org/apache/flink/ssl/tests/NoSslNoReloadIT.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ssl.tests; + +import org.apache.flink.tests.util.flink.ClusterController; + +import org.junit.Test; + +import java.io.IOException; +import java.nio.file.attribute.FileTime; +import java.util.Optional; + +import static org.junit.Assert.assertTrue; + +/** + * End-to-end test for SSL disabled scenario. This test verifies that when SSL is disabled, + * certificates are not used and certificate files are not accessed. 
+ */ +public class NoSslNoReloadIT extends SslEndToEndITCaseBase { + + public NoSslNoReloadIT() throws IOException { + super(false, false); + } + + /** + * Test Flink operations with SSL disabled. Verifies that certificates are NOT accessible and + * certificate files are NOT accessed when SSL is disabled. + */ + @Test + public void testWithSslDisabled() throws Exception { + LOG.info("Starting SSL end-to-end test: SSL disabled"); + + // Start Flink cluster with the SSL configuration set in constructor + try (ClusterController ignored = flinkResource.startCluster(1)) { + final FlinkPorts ports = getAllPorts(); + + // Verify certificate on RPC port is NOT accessible + final Optional maybeCertDate = + getSslCertExpirationDate(ports.getJobManagerRpcPort()); + assertTrue( + "No certificates on rpc port should be accessible when SSL is disabled: " + + maybeCertDate, + maybeCertDate.isEmpty()); + + LOG.info("Generating new SSL certificates with {}-day validity", NEW_VALIDITY_DAYS); + SslTestUtils.generateAndInstallCertificates( + internalSslDir, SSL_PASSWORD, NEW_VALIDITY_DAYS); + FileTime keystoreAccessTimeBefore = + getFileAccessTime(internalSslDir.resolve(KEYSTORE_FILENAME)); + FileTime truststoreAccessTimeBefore = + getFileAccessTime(internalSslDir.resolve(TRUSTSTORE_FILENAME)); + + verifyCertificatesAreNotAccessed(keystoreAccessTimeBefore, truststoreAccessTimeBefore); + } + } +} diff --git a/flink-end-to-end-tests/flink-ssl-test/src/test/java/org/apache/flink/ssl/tests/SslEndToEndITCaseBase.java b/flink-end-to-end-tests/flink-ssl-test/src/test/java/org/apache/flink/ssl/tests/SslEndToEndITCaseBase.java new file mode 100644 index 0000000000000..1eada6dd5f04b --- /dev/null +++ b/flink-end-to-end-tests/flink-ssl-test/src/test/java/org/apache/flink/ssl/tests/SslEndToEndITCaseBase.java @@ -0,0 +1,382 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ssl.tests; + +import org.apache.flink.configuration.BlobServerOptions; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.JobManagerOptions; +import org.apache.flink.configuration.SecurityOptions; +import org.apache.flink.configuration.TaskManagerOptions; +import org.apache.flink.tests.util.flink.FlinkResource; +import org.apache.flink.tests.util.flink.FlinkResourceSetup; +import org.apache.flink.tests.util.flink.LocalStandaloneFlinkResourceFactory; +import org.apache.flink.util.TestLogger; + +import org.junit.Rule; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.attribute.FileTime; +import java.time.Duration; +import java.util.Optional; + +import static org.junit.Assert.assertEquals; + +/** + * Base class for SSL end-to-end tests. Contains shared fields and utility methods for SSL testing. 
+ */ +public abstract class SslEndToEndITCaseBase extends TestLogger { + protected static final Logger LOG = LoggerFactory.getLogger(SslEndToEndITCaseBase.class); + + protected static final int INITIAL_VALIDITY_DAYS = 2; + protected static final int NEW_VALIDITY_DAYS = 365; + protected static final int RELOAD_CHECK_INTERVAL_MS = 10_000; + protected static final String SSL_PASSWORD = "password"; + protected static final Duration WAIT_MS = Duration.ofMillis(60_000); + protected static final String KEYSTORE_FILENAME = "node.keystore"; + protected static final String TRUSTSTORE_FILENAME = "ca.truststore"; + + // Fixed ports for deterministic testing + protected static final int BLOB_SERVER_PORT = 59873; + protected static final int JOB_MANAGER_RPC_PORT = 6123; + protected static final int NETTY_SERVER_PORT = 59874; + + @Rule public final FlinkResource flinkResource; + + protected final Path tempDir; + protected final Path internalSslDir; + + protected SslEndToEndITCaseBase(boolean sslEnabled, boolean sslReloadEnabled) + throws IOException { + // Create temp directory for SSL certificates + this.tempDir = java.nio.file.Files.createTempDirectory("flink-ssl-test-"); + this.internalSslDir = tempDir.resolve("ssl").resolve("internal"); + + SslTestUtils.generateAndInstallCertificates( + internalSslDir, SSL_PASSWORD, INITIAL_VALIDITY_DAYS); + + // Create SSL configuration + Configuration sslConfig = createSslConfiguration(tempDir, sslEnabled, sslReloadEnabled); + + // Create FlinkResource with SSL configuration + final FlinkResourceSetup.FlinkResourceSetupBuilder builder = FlinkResourceSetup.builder(); + builder.addConfiguration(sslConfig); + flinkResource = new LocalStandaloneFlinkResourceFactory().create(builder.build()); + } + + private Configuration createSslConfiguration( + Path sslDir, boolean sslEnabled, boolean sslReloadEnabled) { + Configuration config = new Configuration(); + + // Set fixed ports for deterministic testing + config.set(BlobServerOptions.PORT, 
String.valueOf(BLOB_SERVER_PORT)); + config.set(JobManagerOptions.PORT, JOB_MANAGER_RPC_PORT); + config.set(TaskManagerOptions.RPC_PORT, String.valueOf(NETTY_SERVER_PORT)); + config.set(SecurityOptions.SSL_INTERNAL_ENABLED, sslEnabled); + + if (sslEnabled) { + config.set(SecurityOptions.SSL_PROVIDER, "JDK"); + config.set(SecurityOptions.SSL_RELOAD, sslReloadEnabled); + + Path internalSslDir = sslDir.resolve("ssl").resolve("internal"); + config.set( + SecurityOptions.SSL_INTERNAL_KEYSTORE, + internalSslDir.resolve(KEYSTORE_FILENAME).toString()); + config.set(SecurityOptions.SSL_INTERNAL_KEYSTORE_PASSWORD, SSL_PASSWORD); + config.set(SecurityOptions.SSL_INTERNAL_KEY_PASSWORD, SSL_PASSWORD); + config.set( + SecurityOptions.SSL_INTERNAL_TRUSTSTORE, + internalSslDir.resolve(TRUSTSTORE_FILENAME).toString()); + config.set(SecurityOptions.SSL_INTERNAL_TRUSTSTORE_PASSWORD, SSL_PASSWORD); + } + + return config; + } + + /** + * Returns the configured Flink ports. + * + * @return FlinkPorts object containing all configured ports + */ + protected FlinkPorts getAllPorts() { + FlinkPorts ports = + new FlinkPorts(BLOB_SERVER_PORT, JOB_MANAGER_RPC_PORT, NETTY_SERVER_PORT); + LOG.info("Using configured ports: {}", ports); + return ports; + } + + protected Optional getSslCertExpirationDate(int port) throws InterruptedException { + LOG.info("Verifying initial certificate on port {}", port); + String[] initialDates = waitForCertificate("localhost", port); + if (initialDates == null) { + return Optional.empty(); + } + String initialNotAfter = initialDates[1]; + LOG.info("Initial certificate notAfter: {}", initialNotAfter); + return Optional.of(initialNotAfter); + } + + /** + * Retrieves certificate expiration dates for all Flink ports. 
+ * + * @param ports the FlinkPorts object containing all port numbers + * @return CertificateDates object containing certificate dates for all ports + * @throws InterruptedException if interrupted while waiting for certificates + */ + protected CertificateDates getAllCertificateDates(FlinkPorts ports) + throws InterruptedException { + final Optional blobServerCertDate = + getSslCertExpirationDate(ports.getBlobServerPort()); + final Optional jobManagerRpcCertDate = + getSslCertExpirationDate(ports.getJobManagerRpcPort()); + final Optional nettyServerCertDate = + getSslCertExpirationDate(ports.getNettyServerPort()); + + CertificateDates certDates = + new CertificateDates( + blobServerCertDate, jobManagerRpcCertDate, nettyServerCertDate); + LOG.info("Retrieved certificate dates: {}", certDates); + return certDates; + } + + /** + * Waits for and retrieves new certificate dates for all Flink ports after reload. + * + * @param ports the FlinkPorts object containing all port numbers + * @param initialCertDates the initial certificate dates to compare against + * @return CertificateDates object containing new certificate dates for all ports + * @throws InterruptedException if interrupted while waiting for certificates + */ + protected CertificateDates getAllNewCertificateDates( + FlinkPorts ports, CertificateDates initialCertDates) throws InterruptedException { + final Optional blobServerCertDate = + getNewCertificateDate( + "localhost", + ports.getBlobServerPort(), + initialCertDates.getBlobServerCertDate().orElse("")); + final Optional jobManagerRpcCertDate = + getNewCertificateDate( + "localhost", + ports.getJobManagerRpcPort(), + initialCertDates.getJobManagerRpcCertDate().orElse("")); + final Optional nettyServerCertDate = + getNewCertificateDate( + "localhost", + ports.getNettyServerPort(), + initialCertDates.getNettyServerCertDate().orElse("")); + + CertificateDates newCertDates = + new CertificateDates( + blobServerCertDate, jobManagerRpcCertDate, 
nettyServerCertDate); + LOG.info("Retrieved new certificate dates: {}", newCertDates); + return newCertDates; + } + + /** + * Gets the file access time for a given path. + * + * @param path the path to the file + * @return the file's last access time + * @throws IOException if unable to read file attributes + */ + protected FileTime getFileAccessTime(Path path) throws IOException { + return (FileTime) Files.getAttribute(path, "lastAccessTime"); + } + + /** + * Waits for a certificate to become available on the given host and port. + * + * @param host the host to check + * @param port the port to check + * @return certificate validity dates [notBefore, notAfter] + * @throws InterruptedException if interrupted while waiting + */ + protected String[] waitForCertificate(String host, int port) throws InterruptedException { + long startTime = System.currentTimeMillis(); + while (System.currentTimeMillis() - startTime < WAIT_MS.toMillis()) { + String[] dates = SslTestUtils.getCertificateValidityDates(host, port); + if (dates != null) { + return dates; + } + LOG.info( + "Certificate not yet available, waiting... {} ms left", + WAIT_MS.toMillis() - (System.currentTimeMillis() - startTime)); + Thread.sleep(5_000); + } + return null; + } + + /** + * Waits for the certificate to be reloaded (notAfter date changes). 
+ * + * @param host the host to check + * @param port the port to check + * @param initialCertDate the original date to compare against + * @return Optional containing the new certificate expiration date if reload occurred, empty + * otherwise + * @throws InterruptedException if interrupted while waiting + */ + protected Optional getNewCertificateDate(String host, int port, String initialCertDate) + throws InterruptedException { + long startTime = System.currentTimeMillis(); + int checkCount = 0; + + while (System.currentTimeMillis() - startTime < WAIT_MS.toMillis()) { + Thread.sleep(RELOAD_CHECK_INTERVAL_MS); + checkCount++; + + String[] dates = SslTestUtils.getCertificateValidityDates(host, port); + if (dates != null) { + String currentNotAfter = dates[1]; + LOG.info( + "Check #{}: Current certificate notAfter: {}", checkCount, currentNotAfter); + + if (!currentNotAfter.equals(initialCertDate)) { + LOG.info( + "Certificate reload detected after {} ms!", + System.currentTimeMillis() - startTime); + return Optional.of(currentNotAfter); + } + } else { + LOG.warn("Could not retrieve certificate on check #{}", checkCount); + } + } + + LOG.warn("Certificate reload not detected within {} ms", WAIT_MS.toMillis()); + return Optional.empty(); + } + + protected void verifyCertificatesAreNotAccessed( + FileTime keystoreAccessTimeBefore, FileTime truststoreAccessTimeBefore) + throws InterruptedException, IOException { + LOG.info( + "Waiting {} seconds to verify certificates are not accessed...", + WAIT_MS.toSeconds()); + Thread.sleep(WAIT_MS.toMillis()); + + FileTime keystoreAccessTimeAfter = + getFileAccessTime(internalSslDir.resolve(KEYSTORE_FILENAME)); + FileTime truststoreAccessTimeAfter = + getFileAccessTime(internalSslDir.resolve(TRUSTSTORE_FILENAME)); + + assertEquals( + "Keystore should not be accessed when SSL is disabled", + keystoreAccessTimeBefore, + keystoreAccessTimeAfter); + assertEquals( + "Truststore should not be accessed when SSL is disabled", + 
truststoreAccessTimeBefore, + truststoreAccessTimeAfter); + + LOG.info("SSL end-to-end test completed successfully (SSL disabled verified)"); + } + + /** POJO class to hold Flink port information. */ + protected static class FlinkPorts { + private final int blobServerPort; + private final int jobManagerRpcPort; + private final int nettyServerPort; + + public FlinkPorts(int blobServerPort, int jobManagerRpcPort, int nettyServerPort) { + this.blobServerPort = blobServerPort; + this.jobManagerRpcPort = jobManagerRpcPort; + this.nettyServerPort = nettyServerPort; + } + + public int getBlobServerPort() { + return blobServerPort; + } + + public int getJobManagerRpcPort() { + return jobManagerRpcPort; + } + + public int getNettyServerPort() { + return nettyServerPort; + } + + @Override + public String toString() { + return String.format( + "FlinkPorts{blobServer=%d, jobManagerRpc=%d, nettyServer=%d}", + blobServerPort, jobManagerRpcPort, nettyServerPort); + } + } + + protected static class CertificateDates { + private final Optional blobServerCertDate; + private final Optional jobManagerRpcCertDate; + private final Optional nettyServerCertDate; + + public CertificateDates( + Optional blobServerCertDate, + Optional jobManagerRpcCertDate, + Optional nettyServerCertDate) { + this.blobServerCertDate = blobServerCertDate; + this.jobManagerRpcCertDate = jobManagerRpcCertDate; + this.nettyServerCertDate = nettyServerCertDate; + } + + public Optional getBlobServerCertDate() { + return blobServerCertDate; + } + + public Optional getJobManagerRpcCertDate() { + return jobManagerRpcCertDate; + } + + public Optional getNettyServerCertDate() { + return nettyServerCertDate; + } + + /** + * Checks if all certificate dates are present. 
+ * + * @return true if all certificates are present, false otherwise + */ + public boolean isAllPresent() { + return blobServerCertDate.isPresent() + && jobManagerRpcCertDate.isPresent() + && nettyServerCertDate.isPresent(); + } + + /** + * Checks if none of the certificate dates are present. + * + * @return true if none of the certificates are present, false otherwise + */ + public boolean isNonePresent() { + return !blobServerCertDate.isPresent() + && !jobManagerRpcCertDate.isPresent() + && !nettyServerCertDate.isPresent(); + } + + @Override + public String toString() { + return String.format( + "CertificateDates{blobServer=%s, jobManagerRpc=%s, nettyServer=%s}", + blobServerCertDate.orElse("N/A"), + jobManagerRpcCertDate.orElse("N/A"), + nettyServerCertDate.orElse("N/A")); + } + } +} diff --git a/flink-end-to-end-tests/flink-ssl-test/src/test/java/org/apache/flink/ssl/tests/SslNoReloadIT.java b/flink-end-to-end-tests/flink-ssl-test/src/test/java/org/apache/flink/ssl/tests/SslNoReloadIT.java new file mode 100644 index 0000000000000..2ffaaec8bb06c --- /dev/null +++ b/flink-end-to-end-tests/flink-ssl-test/src/test/java/org/apache/flink/ssl/tests/SslNoReloadIT.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ssl.tests; + +import org.apache.flink.tests.util.flink.ClusterController; + +import org.junit.Test; + +import java.io.IOException; +import java.nio.file.attribute.FileTime; + +import static org.junit.Assert.assertTrue; + +/** + * End-to-end test for SSL enabled without certificate reload. This test verifies that when SSL + * reload is disabled, certificates are not reloaded even when they change on disk. + */ +public class SslNoReloadIT extends SslEndToEndITCaseBase { + + public SslNoReloadIT() throws IOException { + super(true, false); + } + + /** + * Test SSL functionality without certificate reload. Verifies that certificates are NOT + * reloaded and certificate files are NOT accessed when reload is disabled. + */ + @Test + public void testSslBlobOperationsWithoutCertificateReload() throws Exception { + LOG.info("Starting SSL end-to-end test: SSL enabled without reload"); + + // Start Flink cluster with the SSL configuration set in constructor + try (ClusterController ignored = flinkResource.startCluster(1)) { + final FlinkPorts ports = getAllPorts(); + + // Verify all certificates are accessible + final CertificateDates initialCertDates = getAllCertificateDates(ports); + assertTrue( + "All certificates should be accessible: " + initialCertDates, + initialCertDates.isAllPresent()); + + LOG.info("Generating new SSL certificates with {}-day validity", NEW_VALIDITY_DAYS); + SslTestUtils.generateAndInstallCertificates( + internalSslDir, SSL_PASSWORD, NEW_VALIDITY_DAYS); + FileTime keystoreAccessTimeBefore = + getFileAccessTime(internalSslDir.resolve(KEYSTORE_FILENAME)); + FileTime truststoreAccessTimeBefore = + getFileAccessTime(internalSslDir.resolve(TRUSTSTORE_FILENAME)); + + verifyCertificatesAreNotAccessed(keystoreAccessTimeBefore, truststoreAccessTimeBefore); + } + } +} diff --git 
a/flink-end-to-end-tests/flink-ssl-test/src/test/java/org/apache/flink/ssl/tests/SslTestUtils.java b/flink-end-to-end-tests/flink-ssl-test/src/test/java/org/apache/flink/ssl/tests/SslTestUtils.java new file mode 100644 index 0000000000000..11ca3e7f23226 --- /dev/null +++ b/flink-end-to-end-tests/flink-ssl-test/src/test/java/org/apache/flink/ssl/tests/SslTestUtils.java @@ -0,0 +1,541 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ssl.tests; + +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.SecurityOptions; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.NetworkInterface; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Enumeration; +import java.util.List; +import java.util.stream.Collectors; + +import static org.apache.flink.tests.util.AutoClosableProcess.runBlocking; + +/** + * Utility class for SSL test setup and certificate generation. This class provides Java-based + * alternatives to the common_ssl.sh bash scripts used in end-to-end tests. 
+ */ +public class SslTestUtils { + + private static final Logger LOG = LoggerFactory.getLogger(SslTestUtils.class); + + /** SSL provider types. */ + public enum SslProvider { + JDK, + OPENSSL + } + + /** SSL provider library linking type. */ + public enum ProviderLibrary { + DYNAMIC, + STATIC + } + + /** SSL connectivity type. */ + public enum SslType { + INTERNAL("internal"), + REST("rest"); + + private final String value; + + SslType(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + } + + /** SSL authentication mode. */ + public enum AuthenticationMode { + SERVER, + MUTUAL + } + + /** + * Generates SSL certificates and configures Flink SSL settings. + * + * @param testDataDir the root test data directory + * @param type SSL type (internal or rest) + * @param provider SSL provider (JDK or OPENSSL) + * @param providerLib provider library type (dynamic or static) + * @return Configuration with SSL settings + * @throws IOException if certificate generation fails + */ + public static Configuration setupSslHelper( + Path testDataDir, SslType type, SslProvider provider, ProviderLibrary providerLib) + throws IOException { + return setupSslHelper(testDataDir, type, provider, providerLib, 2); + } + + /** + * Generates SSL certificates and configures Flink SSL settings with custom validity period. 
+ * + * @param testDataDir the root test data directory + * @param type SSL type (internal or rest) + * @param provider SSL provider (JDK or OPENSSL) + * @param providerLib provider library type (dynamic or static) + * @param validityDays certificate validity period in days + * @return Configuration with SSL settings + * @throws IOException if certificate generation fails + */ + public static Configuration setupSslHelper( + Path testDataDir, + SslType type, + SslProvider provider, + ProviderLibrary providerLib, + int validityDays) + throws IOException { + + LOG.info( + "Setting up SSL with: {} {} {} (validity: {} days)", + type, + provider, + providerLib, + validityDays); + + Path sslDir = testDataDir.resolve("ssl").resolve(type.getValue()); + String password = type.getValue() + ".password"; + + // Generate and install certificates + generateAndInstallCertificates(sslDir, password, validityDays); + + // Configure OpenSSL if needed + if (provider == SslProvider.OPENSSL) { + configureOpenSsl(providerLib); + } + + // Build and return configuration + return buildSslConfiguration(type, provider, sslDir, password); + } + + /** + * Generates SSL certificates and installs them in the specified directory. + * + * @param sslDir the directory where certificates will be stored + * @param password the password for keystores + * @param validityDays certificate validity period in days + * @throws IOException if certificate generation fails + */ + public static void generateAndInstallCertificates( + Path sslDir, String password, int validityDays) throws IOException { + + // Clean up and create SSL directory + if (Files.exists(sslDir)) { + LOG.info("Directory {} exists. 
Deleting it...", sslDir); + deleteRecursively(sslDir); + } + Files.createDirectories(sslDir); + + // Build SAN string + String nodeName = getNodeName(); + List nodeIps = getNodeIps(); + StringBuilder sanString = new StringBuilder("dns:" + nodeName); + for (String ip : nodeIps) { + sanString.append(",ip:").append(ip); + } + + LOG.info("Using SAN {}", sanString); + + // Create certificates + createCertificates(sslDir, password, nodeName, sanString.toString(), validityDays); + + // Export keystore to PEM format for curl + convertKeystoreToPem(sslDir, password); + } + + /** Sets up internal SSL configuration. */ + public static Configuration setupInternalSsl( + Path testDataDir, SslProvider provider, ProviderLibrary providerLib) + throws IOException { + return setupSslHelper(testDataDir, SslType.INTERNAL, provider, providerLib); + } + + /** Sets up REST SSL configuration. */ + public static Configuration setupRestSsl( + Path testDataDir, + AuthenticationMode auth, + SslProvider provider, + ProviderLibrary providerLib) + throws IOException { + Configuration config = setupSslHelper(testDataDir, SslType.REST, provider, providerLib); + + boolean mutualAuth = auth == AuthenticationMode.MUTUAL; + LOG.info("Mutual ssl auth: {}", mutualAuth); + config.set(SecurityOptions.SSL_REST_AUTHENTICATION_ENABLED, mutualAuth); + + return config; + } + + /** + * Creates SSL certificates using keytool. 
+ * + * @param sslDir the directory where certificates will be stored + * @param password the password for keystores + * @param nodeName the node hostname + * @param sanString the Subject Alternative Names string + * @param validityDays certificate validity period in days + * @throws IOException if certificate generation fails + */ + private static void createCertificates( + Path sslDir, String password, String nodeName, String sanString, int validityDays) + throws IOException { + + // Generate CA certificate + runBlocking( + "keytool", + "-genkeypair", + "-alias", + "ca", + "-keystore", + sslDir.resolve("ca.keystore").toString(), + "-dname", + "CN=Sample CA", + "-storepass", + password, + "-keypass", + password, + "-keyalg", + "RSA", + "-ext", + "bc=ca:true", + "-storetype", + "PKCS12", + "-validity", + String.valueOf(validityDays)); + + // Export CA certificate + runBlocking( + "keytool", + "-keystore", + sslDir.resolve("ca.keystore").toString(), + "-storepass", + password, + "-alias", + "ca", + "-exportcert", + "-file", + sslDir.resolve("ca.cer").toString()); + + // Import CA certificate to truststore + runBlocking( + "keytool", + "-importcert", + "-keystore", + sslDir.resolve("ca.truststore").toString(), + "-alias", + "ca", + "-storepass", + password, + "-noprompt", + "-file", + sslDir.resolve("ca.cer").toString()); + + // Generate node certificate + runBlocking( + "keytool", + "-genkeypair", + "-alias", + "node", + "-keystore", + sslDir.resolve("node.keystore").toString(), + "-dname", + "CN=" + nodeName, + "-ext", + "SAN=" + sanString, + "-storepass", + password, + "-keypass", + password, + "-keyalg", + "RSA", + "-storetype", + "PKCS12", + "-validity", + String.valueOf(validityDays)); + + // Create certificate signing request + runBlocking( + "keytool", + "-certreq", + "-keystore", + sslDir.resolve("node.keystore").toString(), + "-storepass", + password, + "-alias", + "node", + "-file", + sslDir.resolve("node.csr").toString()); + + // Sign certificate + 
runBlocking( + "keytool", + "-gencert", + "-keystore", + sslDir.resolve("ca.keystore").toString(), + "-storepass", + password, + "-alias", + "ca", + "-ext", + "SAN=" + sanString, + "-validity", + String.valueOf(validityDays), + "-infile", + sslDir.resolve("node.csr").toString(), + "-outfile", + sslDir.resolve("node.cer").toString()); + + // Import CA certificate to node keystore + runBlocking( + "keytool", + "-importcert", + "-keystore", + sslDir.resolve("node.keystore").toString(), + "-storepass", + password, + "-file", + sslDir.resolve("ca.cer").toString(), + "-alias", + "ca", + "-noprompt"); + + // Import signed node certificate + runBlocking( + "keytool", + "-importcert", + "-keystore", + sslDir.resolve("node.keystore").toString(), + "-storepass", + password, + "-file", + sslDir.resolve("node.cer").toString(), + "-alias", + "node", + "-noprompt"); + } + + /** Converts keystore to PEM format using OpenSSL. */ + private static void convertKeystoreToPem(Path sslDir, String password) throws IOException { + List command = new ArrayList<>(); + command.add("openssl"); + command.add("pkcs12"); + + // Check OpenSSL version and add legacy flag if needed + if (isOpenSsl3OrHigher()) { + command.add("-legacy"); + } + + command.add("-passin"); + command.add("pass:" + password); + command.add("-in"); + command.add(sslDir.resolve("node.keystore").toString()); + command.add("-out"); + command.add(sslDir.resolve("node.pem").toString()); + command.add("-nodes"); + + runBlocking(command.toArray(new String[0])); + } + + /** + * Gets certificate validity dates from a given host and port. 
+ * + * @param host the host to check + * @param port the port to check + * @return array with [notBefore, notAfter] date strings, or null if unable to retrieve + */ + public static String[] getCertificateValidityDates(String host, int port) { + try { + ProcessBuilder pb = + new ProcessBuilder( + "sh", + "-c", + String.format( + "openssl s_client -connect %s:%d /dev/null | openssl x509 -noout -dates", + host, port)); + + Process process = pb.start(); + StringBuilder output = new StringBuilder(); + + try (var reader = + new java.io.BufferedReader( + new java.io.InputStreamReader(process.getInputStream()))) { + String line; + while ((line = reader.readLine()) != null) { + output.append(line).append("\n"); + } + } + + process.waitFor(5, java.util.concurrent.TimeUnit.SECONDS); + process.destroyForcibly(); + + String result = output.toString(); + String notBefore = null; + String notAfter = null; + + // Parse output + for (String line : result.split("\n")) { + if (line.startsWith("notBefore=")) { + notBefore = line.substring("notBefore=".length()).trim(); + } else if (line.startsWith("notAfter=")) { + notAfter = line.substring("notAfter=".length()).trim(); + } + } + + if (notBefore != null && notAfter != null) { + LOG.info( + "Certificate validity for {}:{} - notBefore: {}, notAfter: {}", + host, + port, + notBefore, + notAfter); + return new String[] {notBefore, notAfter}; + } + + LOG.warn("Could not retrieve certificate validity dates from {}:{}", host, port); + return null; + } catch (Exception e) { + LOG.debug("Failed to get certificate dates from {}:{}", host, port, e); + return null; + } + } + + /** Checks if OpenSSL version is 3.x or higher. 
*/ + private static boolean isOpenSsl3OrHigher() { + try { + Process process = new ProcessBuilder("openssl", "version").start(); + try (var reader = + new java.io.BufferedReader( + new java.io.InputStreamReader(process.getInputStream()))) { + String version = reader.readLine(); + if (version != null) { + return !version.contains("OpenSSL 1"); + } + } + process.waitFor(); + } catch (Exception e) { + LOG.warn("Could not determine OpenSSL version, assuming OpenSSL 3+", e); + } + return true; + } + + /** + * Configures OpenSSL library for Flink. + * + * @param providerLib the provider library type (dynamic or static) + */ + private static void configureOpenSsl(ProviderLibrary providerLib) { + // This would copy the appropriate netty-tcnative jar to Flink's lib directory + // For test purposes, this might not be needed if using JDK provider + LOG.info("OpenSSL configuration for {} library type would be applied here", providerLib); + // Implementation depends on test environment setup + // In bash script this copies flink-shaded-netty-tcnative-*.jar to $FLINK_DIR/lib/ + } + + /** Builds Flink SSL configuration. 
*/ + private static Configuration buildSslConfiguration( + SslType type, SslProvider provider, Path sslDir, String password) { + + Configuration config = new Configuration(); + + config.set(SecurityOptions.SSL_PROVIDER, provider.name()); + + if (type == SslType.INTERNAL) { + config.set(SecurityOptions.SSL_INTERNAL_ENABLED, true); + config.set( + SecurityOptions.SSL_INTERNAL_KEYSTORE, + sslDir.resolve("node.keystore").toString()); + config.set(SecurityOptions.SSL_INTERNAL_KEYSTORE_PASSWORD, password); + config.set(SecurityOptions.SSL_INTERNAL_KEY_PASSWORD, password); + config.set( + SecurityOptions.SSL_INTERNAL_TRUSTSTORE, + sslDir.resolve("ca.truststore").toString()); + config.set(SecurityOptions.SSL_INTERNAL_TRUSTSTORE_PASSWORD, password); + } else { // REST + config.set(SecurityOptions.SSL_REST_ENABLED, true); + config.set( + SecurityOptions.SSL_REST_KEYSTORE, sslDir.resolve("node.keystore").toString()); + config.set(SecurityOptions.SSL_REST_KEYSTORE_PASSWORD, password); + config.set(SecurityOptions.SSL_REST_KEY_PASSWORD, password); + config.set( + SecurityOptions.SSL_REST_TRUSTSTORE, + sslDir.resolve("ca.truststore").toString()); + config.set(SecurityOptions.SSL_REST_TRUSTSTORE_PASSWORD, password); + } + + return config; + } + + /** Gets the node name (hostname). */ + private static String getNodeName() { + try { + return InetAddress.getLocalHost().getHostName(); + } catch (Exception e) { + LOG.warn("Could not determine hostname, using localhost", e); + return "localhost"; + } + } + + /** Gets all IP addresses of the node. 
*/ + private static List getNodeIps() { + List ips = new ArrayList<>(); + try { + Enumeration interfaces = NetworkInterface.getNetworkInterfaces(); + while (interfaces.hasMoreElements()) { + NetworkInterface iface = interfaces.nextElement(); + if (iface.isLoopback() || !iface.isUp()) { + continue; + } + Enumeration addresses = iface.getInetAddresses(); + while (addresses.hasMoreElements()) { + InetAddress addr = addresses.nextElement(); + ips.add(addr.getHostAddress()); + } + } + } catch (Exception e) { + LOG.warn("Could not enumerate network interfaces, using localhost", e); + } + if (ips.isEmpty()) { + ips.add("127.0.0.1"); + } + return ips; + } + + /** Recursively deletes a directory. */ + private static void deleteRecursively(Path path) throws IOException { + if (Files.isDirectory(path)) { + try (var stream = Files.list(path)) { + for (Path child : stream.collect(Collectors.toList())) { + deleteRecursively(child); + } + } + } + Files.deleteIfExists(path); + } +} diff --git a/flink-end-to-end-tests/flink-ssl-test/src/test/java/org/apache/flink/ssl/tests/SslWithReloadIT.java b/flink-end-to-end-tests/flink-ssl-test/src/test/java/org/apache/flink/ssl/tests/SslWithReloadIT.java new file mode 100644 index 0000000000000..f173f2aec6e11 --- /dev/null +++ b/flink-end-to-end-tests/flink-ssl-test/src/test/java/org/apache/flink/ssl/tests/SslWithReloadIT.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ssl.tests; + +import org.apache.flink.tests.util.flink.ClusterController; + +import org.junit.Test; + +import java.io.IOException; + +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertTrue; + +/** + * End-to-end test for SSL with certificate reload enabled. This test verifies that SSL-enabled + * components can handle certificate changes without service disruption. + */ +public class SslWithReloadIT extends SslEndToEndITCaseBase { + + public SslWithReloadIT() throws IOException { + super(true, true); + } + + /** + * Test SSL functionality with certificate reload enabled. Verifies that new certificates are + * properly reloaded and used by the BlobServer. 
+ */ + @Test + public void testSslBlobOperationsAndCertificateReload() throws Exception { + LOG.info("Starting SSL end-to-end test: SSL enabled with reload"); + + // Start Flink cluster with the SSL configuration set in constructor + try (ClusterController ignored = flinkResource.startCluster(1)) { + final FlinkPorts ports = getAllPorts(); + + // Verify all certificates are accessible + final CertificateDates initialCertDates = getAllCertificateDates(ports); + assertTrue( + "All certificates should be accessible: " + initialCertDates, + initialCertDates.isAllPresent()); + + LOG.info("Generating new SSL certificates with {}-day validity", NEW_VALIDITY_DAYS); + SslTestUtils.generateAndInstallCertificates( + internalSslDir, SSL_PASSWORD, NEW_VALIDITY_DAYS); + LOG.info("New certificates generated, waiting for reload..."); + + // Wait for certificate reload on all ports + final CertificateDates newCertDates = + getAllNewCertificateDates(ports, initialCertDates); + + // Verify all certificates were reloaded + assertTrue( + "All certificates should be reloaded: " + newCertDates, + newCertDates.isAllPresent()); + + // Verify certificate dates changed after reload + assertNotEquals( + "BlobServer certificate notAfter date should change after reload", + initialCertDates.getBlobServerCertDate(), + newCertDates.getBlobServerCertDate()); + assertNotEquals( + "JobManager RPC certificate notAfter date should change after reload", + initialCertDates.getJobManagerRpcCertDate(), + newCertDates.getJobManagerRpcCertDate()); + assertNotEquals( + "Netty server certificate notAfter date should change after reload", + initialCertDates.getNettyServerCertDate(), + newCertDates.getNettyServerCertDate()); + } + } +} diff --git a/flink-end-to-end-tests/flink-ssl-test/src/test/resources/log4j2-test.properties b/flink-end-to-end-tests/flink-ssl-test/src/test/resources/log4j2-test.properties new file mode 100644 index 0000000000000..835c2ec9a3d02 --- /dev/null +++ 
b/flink-end-to-end-tests/flink-ssl-test/src/test/resources/log4j2-test.properties @@ -0,0 +1,28 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +# Set root logger level to OFF to not flood build logs +# set manually to INFO for debugging purposes +rootLogger.level = OFF +rootLogger.appenderRef.test.ref = TestLogger + +appender.testlogger.name = TestLogger +appender.testlogger.type = CONSOLE +appender.testlogger.target = SYSTEM_ERR +appender.testlogger.layout.type = PatternLayout +appender.testlogger.layout.pattern = %-4r [%t] %-5p %c %x - %m%n diff --git a/flink-end-to-end-tests/pom.xml b/flink-end-to-end-tests/pom.xml index c6137a188da48..2b9a5adab7cef 100644 --- a/flink-end-to-end-tests/pom.xml +++ b/flink-end-to-end-tests/pom.xml @@ -60,6 +60,7 @@ under the License. 
flink-end-to-end-tests-common flink-metrics-availability-test flink-metrics-reporter-prometheus-test + flink-ssl-test flink-heavy-deployment-stress-test flink-plugins-test flink-tpch-test diff --git a/flink-rpc/flink-rpc-akka/src/main/java/org/apache/flink/runtime/rpc/pekko/PekkoUtils.java b/flink-rpc/flink-rpc-akka/src/main/java/org/apache/flink/runtime/rpc/pekko/PekkoUtils.java index fa4623e84b238..dae36e8ee25dc 100644 --- a/flink-rpc/flink-rpc-akka/src/main/java/org/apache/flink/runtime/rpc/pekko/PekkoUtils.java +++ b/flink-rpc/flink-rpc-akka/src/main/java/org/apache/flink/runtime/rpc/pekko/PekkoUtils.java @@ -359,7 +359,8 @@ private static void addSslRemoteConfig( Arrays.stream(sslAlgorithmsString.split(",")) .collect(Collectors.joining(",", "[", "]")); - final boolean enabledCertReloadConfig = SecurityOptions.isCertificateReloadEnabled(configuration); + final boolean enabledCertReloadConfig = + SecurityOptions.isCertificateReloadEnabled(configuration); final String enabledCertReload = booleanToOnOrOff(enabledCertReloadConfig); final String sslEngineProviderName = CustomSSLEngineProvider.class.getCanonicalName(); From 8baa9a310162017546478fdb01e7fbb40e93ec31 Mon Sep 17 00:00:00 2001 From: "oleksandr.nitavskyi" Date: Sat, 4 Oct 2025 13:18:07 +0200 Subject: [PATCH 4/4] Fix CI * javadoc * flacky tests * another tests --- .../shortcodes/generated/security_configuration.html | 6 ++++++ .../layouts/shortcodes/generated/security_ssl_section.html | 6 ++++++ .../java/org/apache/flink/ssl/tests/SslWithReloadIT.java | 7 ++++--- .../java/org/apache/flink/runtime/net/SSLUtilsTest.java | 7 +++---- 4 files changed, 19 insertions(+), 7 deletions(-) diff --git a/docs/layouts/shortcodes/generated/security_configuration.html b/docs/layouts/shortcodes/generated/security_configuration.html index ff19479042e88..6c8d216a983e4 100644 --- a/docs/layouts/shortcodes/generated/security_configuration.html +++ b/docs/layouts/shortcodes/generated/security_configuration.html @@ -182,6 +182,12 
@@ String The SSL engine provider to use for the ssl transport:
  • JDK: default Java-based SSL engine
  • OPENSSL: openSSL-based SSL engine using system libraries
OPENSSL is based on netty-tcnative and comes in two flavours:
  • dynamically linked: This will use your system's openSSL libraries (if compatible) and requires opt/flink-shaded-netty-tcnative-dynamic-*.jar to be copied to lib/
  • statically linked: Due to potential licensing issues with openSSL (see LEGAL-393), we cannot ship pre-built libraries. However, you can build the required library yourself and put it into lib/:
    git clone https://github.com/apache/flink-shaded.git && cd flink-shaded && mvn clean package -Pinclude-netty-tcnative-static -pl flink-shaded-netty-tcnative-static
+ +
security.ssl.reload
+ false + Boolean + If enabled, the application will monitor the keystore and truststore files for any changes. When a change is detected, internal network components (like Netty, Pekko, or BlobServer) will automatically reload the keystore/truststore certificates. +
security.ssl.rest.authentication-enabled
false diff --git a/docs/layouts/shortcodes/generated/security_ssl_section.html b/docs/layouts/shortcodes/generated/security_ssl_section.html index ad5c72b3cf232..14e4643c4649d 100644 --- a/docs/layouts/shortcodes/generated/security_ssl_section.html +++ b/docs/layouts/shortcodes/generated/security_ssl_section.html @@ -74,6 +74,12 @@ String The SSL protocol version to be supported for the ssl transport. Note that it doesn’t support comma separated list. + +
security.ssl.reload
+ false + Boolean + If enabled, the application will monitor the keystore and truststore files for any changes. When a change is detected, internal network components (like Netty, Pekko, or BlobServer) will automatically reload the keystore/truststore certificates. +
security.ssl.rest.authentication-enabled
false diff --git a/flink-end-to-end-tests/flink-ssl-test/src/test/java/org/apache/flink/ssl/tests/SslWithReloadIT.java b/flink-end-to-end-tests/flink-ssl-test/src/test/java/org/apache/flink/ssl/tests/SslWithReloadIT.java index f173f2aec6e11..8967f8b3c4284 100644 --- a/flink-end-to-end-tests/flink-ssl-test/src/test/java/org/apache/flink/ssl/tests/SslWithReloadIT.java +++ b/flink-end-to-end-tests/flink-ssl-test/src/test/java/org/apache/flink/ssl/tests/SslWithReloadIT.java @@ -60,13 +60,14 @@ public void testSslBlobOperationsAndCertificateReload() throws Exception { internalSslDir, SSL_PASSWORD, NEW_VALIDITY_DAYS); LOG.info("New certificates generated, waiting for reload..."); - // Wait for certificate reload on all ports final CertificateDates newCertDates = getAllNewCertificateDates(ports, initialCertDates); - // Verify all certificates were reloaded assertTrue( - "All certificates should be reloaded: " + newCertDates, + "All certificates should be reloaded: " + + newCertDates + + ", intial certificate dates: " + + initialCertDates, newCertDates.isAllPresent()); // Verify certificate dates changed after reload diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/net/SSLUtilsTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/net/SSLUtilsTest.java index ceb4263b01f97..375e2c778ce08 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/net/SSLUtilsTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/net/SSLUtilsTest.java @@ -25,8 +25,8 @@ import org.apache.flink.shaded.netty4.io.netty.buffer.UnpooledByteBufAllocator; import org.apache.flink.shaded.netty4.io.netty.handler.ssl.ClientAuth; -import org.apache.flink.shaded.netty4.io.netty.handler.ssl.JdkSslContext; import org.apache.flink.shaded.netty4.io.netty.handler.ssl.OpenSsl; +import org.apache.flink.shaded.netty4.io.netty.handler.ssl.SslContext; import org.apache.flink.shaded.netty4.io.netty.handler.ssl.SslHandler; import org.junit.jupiter.api.Test; @@ -170,9 
+170,8 @@ void testRESTSSLConfigCipherAlgorithms(String sslProvider) throws Exception { Configuration config = createRestSslConfigWithTrustStore(sslProvider); config.set(SecurityOptions.SSL_REST_ENABLED, true); config.setString(SecurityOptions.SSL_ALGORITHMS.key(), testSSLAlgorithms); - JdkSslContext nettySSLContext = - (JdkSslContext) - SSLUtils.createRestNettySSLContext(config, true, ClientAuth.NONE, JDK); + SslContext nettySSLContext = + SSLUtils.createRestNettySSLContext(config, true, ClientAuth.NONE, JDK); List cipherSuites = checkNotNull(nettySSLContext).cipherSuites(); assertThat(cipherSuites).hasSize(2); assertThat(cipherSuites).containsExactlyInAnyOrder(testSSLAlgorithms.split(","));