diff --git a/src/itest/java/com/hierynomus/sshj/KeepAliveTest.java b/src/itest/java/com/hierynomus/sshj/KeepAliveTest.java new file mode 100644 index 000000000..c95bfe6ce --- /dev/null +++ b/src/itest/java/com/hierynomus/sshj/KeepAliveTest.java @@ -0,0 +1,59 @@ +package com.hierynomus.sshj; + +import net.schmizz.keepalive.BoundedKeepAliveProvider; +import net.schmizz.sshj.Config; +import net.schmizz.sshj.DefaultConfig; +import net.schmizz.sshj.SSHClient; +import net.schmizz.sshj.common.LoggerFactory; +import net.schmizz.sshj.transport.verification.PromiscuousVerifier; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.testcontainers.junit.jupiter.Container; + +import java.util.ArrayList; +import java.util.List; + +public class KeepAliveTest { + @Container + SshdContainer sshd = new SshdContainer(SshdContainer.Builder + .defaultBuilder() + .withAllKeys() + .withPackages("iptables") + .withPrivileged(true)); + + @Test + void testKeepAlive() throws Exception { + sshd.start(); + + Config config = new DefaultConfig(); + BoundedKeepAliveProvider p = new BoundedKeepAliveProvider(LoggerFactory.DEFAULT, 4); + p.setKeepAliveInterval(1); + p.setMaxKeepAliveCount(1); + config.setKeepAliveProvider(p); + List clients = new ArrayList<>(); + for (int i=0; i<10; i++) { + SSHClient c = new SSHClient(config); + c.addHostKeyVerifier(new PromiscuousVerifier()); + c.connect("127.0.0.1", sshd.getFirstMappedPort()); + c.authPassword("sshj", "ultrapassword"); + var sess = c.startSession(); + sess.allocateDefaultPTY(); + clients.add(c); + } + + for (SSHClient client : clients) { + Assertions.assertTrue(client.isConnected()); + } + + var res = sshd.execInContainer("iptables", "-A", "INPUT", "-p", "tcp", "--dport", "22", "-j", "DROP"); + Assertions.assertEquals(0, res.getExitCode()); + // wait for keepalive to take action + Thread.sleep(2000); + + for (SSHClient client : clients) { + Assertions.assertFalse(client.isConnected()); + } + + p.shutdown(); + } +} diff --git a/src/itest/java/com/hierynomus/sshj/SshdContainer.java b/src/itest/java/com/hierynomus/sshj/SshdContainer.java index 91b531e68..8d17560fb 100644 --- a/src/itest/java/com/hierynomus/sshj/SshdContainer.java +++ b/src/itest/java/com/hierynomus/sshj/SshdContainer.java @@ -106,13 +106,24 @@ public static class Builder implements Consumer { private List hostKeys = new ArrayList<>(); private List certificates = new ArrayList<>(); private @NotNull SshdConfigBuilder sshdConfig = SshdConfigBuilder.defaultBuilder(); + private boolean privileged = false; + private List packages = new ArrayList<>(); public static Builder defaultBuilder() { Builder b = new Builder(); - return b; } + public @NotNull Builder withPrivileged(boolean privileged) { + this.privileged = privileged; + return this; + } + + public @NotNull Builder withPackages(@NotNull String... packages) { + this.packages.addAll(List.of(packages)); + return this; + } + public @NotNull Builder withSshdConfig(@NotNull SshdConfigBuilder sshdConfig) { this.sshdConfig = sshdConfig; @@ -153,6 +164,9 @@ public void accept(@NotNull DockerfileBuilder builder) { builder.expose(22); builder.copy("entrypoint.sh", "/entrypoint.sh"); + if (!packages.isEmpty()) { + builder.run("apk add --no-cache " + String.join(" ", packages)); + } builder.add("authorized_keys", "/home/sshj/.ssh/authorized_keys"); builder.copy("test-container/trusted_ca_keys", "/etc/ssh/trusted_ca_keys"); @@ -201,6 +215,9 @@ public SshdContainer() { public SshdContainer(SshdContainer.Builder builder) { this(builder.buildInner()); + if (builder.privileged) { + withPrivilegedMode(true); + } } public SshdContainer(@NotNull Future future) { diff --git a/src/main/java/net/schmizz/keepalive/BoundedKeepAliveProvider.java b/src/main/java/net/schmizz/keepalive/BoundedKeepAliveProvider.java new file mode 100644 index 000000000..5702b2aaa --- /dev/null +++ b/src/main/java/net/schmizz/keepalive/BoundedKeepAliveProvider.java @@ -0,0 +1,201 @@ +package net.schmizz.keepalive; + +import net.schmizz.sshj.Config; +import net.schmizz.sshj.common.LoggerFactory; +import net.schmizz.sshj.connection.ConnectionException; +import net.schmizz.sshj.connection.ConnectionImpl; +import net.schmizz.sshj.transport.TransportException; +import org.slf4j.Logger; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.concurrent.PriorityBlockingQueue; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.ReentrantLock; + +/** + * This implementation manages all {@link KeepAlive}s using configured number of threads. It works like a + * thread pool, thus {@link BoundedKeepAliveProvider#shutdown()} must be called to clean up resources. + *
+ * This provider uses {@link KeepAliveRunner#doKeepAlive()} as delegate, so it supports maxKeepAliveCount + * parameter. All instances provided by this provider have identical configuration. + */ +public class BoundedKeepAliveProvider extends KeepAliveProvider { + + public int maxKeepAliveCount = 3; + public int keepAliveInterval = 5; + + protected final KeepAliveMonitor monitor; + + + public BoundedKeepAliveProvider(LoggerFactory loggerFactory, int numberOfThreads) { + this.monitor = new KeepAliveMonitor(loggerFactory, numberOfThreads); + } + + public void setKeepAliveInterval(int interval) { + keepAliveInterval = interval; + } + + public void setMaxKeepAliveCount(int count) { + maxKeepAliveCount = count; + } + + @Override + public KeepAlive provide(ConnectionImpl connection) { + return new Impl(connection, "bounded-keepalive-impl"); + } + + public void shutdown() throws InterruptedException { + monitor.shutdown(); + } + + class Impl extends KeepAlive { + + private final KeepAliveRunner delegate; + + protected Impl(ConnectionImpl conn, String name) { + super(conn, name); + this.delegate = new KeepAliveRunner(conn); + + // take care here, some parameters are set to both delegate and this + this.delegate.setMaxAliveCount(BoundedKeepAliveProvider.this.maxKeepAliveCount); + super.keepAliveInterval = BoundedKeepAliveProvider.this.keepAliveInterval; + } + + @Override + protected void doKeepAlive() throws TransportException, ConnectionException { + delegate.doKeepAlive(); + } + + @Override + public void startKeepAlive() { + monitor.register(this); + } + + } + + protected static class KeepAliveMonitor { + private final Logger logger; + + private final PriorityBlockingQueue q = + new PriorityBlockingQueue<>(32, Comparator.comparingLong(w -> w.nextTimeMillis)); + private static final List workerThreads = new ArrayList<>(); + + private volatile long idleSleepMillis = 100; + private final int numberOfThreads; + + volatile boolean started = false; + + private final ReentrantLock lock = new ReentrantLock(); + private final Condition shutDown = lock.newCondition(); + private final AtomicInteger shutDownCnt = new AtomicInteger(0); + + public KeepAliveMonitor(LoggerFactory loggerFactory, int numberOfThreads) { + this.numberOfThreads = numberOfThreads; + logger = loggerFactory.getLogger(KeepAliveMonitor.class); + } + + // made public for test + public void register(KeepAlive keepAlive) { + if (!started) { + start(); + } + q.add(new Wrapper(keepAlive)); + } + + public void setIdleSleepMillis(long idleSleepMillis) { + this.idleSleepMillis = idleSleepMillis; + } + + private void sleep() { + sleep(idleSleepMillis); + } + + private void sleep(long millis) { + try { + Thread.sleep(millis); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + + private synchronized void start() { + if (started) { + return; + } + + for (int i = 0; i < numberOfThreads; i++) { + Thread t = new Thread(this::doStart); + workerThreads.add(t); + } + workerThreads.forEach(Thread::start); + started = true; + } + + + private void doStart() { + while (!Thread.currentThread().isInterrupted()) { + Wrapper wrapper; + + if (q.isEmpty() || (wrapper = q.poll()) == null) { + sleep(); + continue; + } + + long currentTimeMillis = System.currentTimeMillis(); + if (wrapper.nextTimeMillis > currentTimeMillis) { + long sleepMillis = wrapper.nextTimeMillis - currentTimeMillis; + logger.debug("{} millis until next check, sleep", sleepMillis); + sleep(sleepMillis); + } + + try { + wrapper.keepAlive.doKeepAlive(); + q.add(wrapper.reschedule()); + } catch (Exception e) { + // If we weren't interrupted, kill the transport, then this exception was unexpected. + // Else we're in shutdown-mode already, so don't forcibly kill the transport. + if (!Thread.currentThread().isInterrupted()) { + wrapper.keepAlive.conn.getTransport().die(e); + } + } + } + lock.lock(); + try { + if (shutDownCnt.incrementAndGet() == numberOfThreads) { + shutDown.signal(); + } + } finally { + lock.unlock(); + } + } + + private synchronized void shutdown() throws InterruptedException { + if (workerThreads.isEmpty()) { + return; + } + for (Thread t : workerThreads) { + t.interrupt(); + } + lock.lock(); + logger.info("waiting for all {} threads to finish", numberOfThreads); + shutDown.await(); + } + + private static class Wrapper { + private final KeepAlive keepAlive; + private final long nextTimeMillis; + + private Wrapper(KeepAlive keepAlive) { + this.keepAlive = keepAlive; + this.nextTimeMillis = System.currentTimeMillis() + keepAlive.keepAliveInterval * 1000L; + } + + private Wrapper reschedule() { + return new Wrapper(keepAlive); + } + } + } +} diff --git a/src/main/java/net/schmizz/keepalive/KeepAlive.java b/src/main/java/net/schmizz/keepalive/KeepAlive.java index 05e771f73..badbb948d 100644 --- a/src/main/java/net/schmizz/keepalive/KeepAlive.java +++ b/src/main/java/net/schmizz/keepalive/KeepAlive.java @@ -89,4 +89,11 @@ public void run() { } protected abstract void doKeepAlive() throws TransportException, ConnectionException; + + /** + * Start keep-alive loop. Implementations MUST NOT block current thread. + */ + public void startKeepAlive() { + start(); + } } diff --git a/src/main/java/net/schmizz/sshj/SSHClient.java b/src/main/java/net/schmizz/sshj/SSHClient.java index 792038c55..93d62c90b 100644 --- a/src/main/java/net/schmizz/sshj/SSHClient.java +++ b/src/main/java/net/schmizz/sshj/SSHClient.java @@ -809,7 +809,7 @@ protected void onConnect() final KeepAlive keepAliveThread = conn.getKeepAlive(); if (keepAliveThread.isEnabled()) { ThreadNameProvider.setThreadName(conn.getKeepAlive(), trans); - keepAliveThread.start(); + keepAliveThread.startKeepAlive(); } } diff --git a/src/test/java/com/hierynomus/sshj/keepalive/BoundedKeepAliveProviderTest.java b/src/test/java/com/hierynomus/sshj/keepalive/BoundedKeepAliveProviderTest.java new file mode 100644 index 000000000..0ee50f6c4 --- /dev/null +++ b/src/test/java/com/hierynomus/sshj/keepalive/BoundedKeepAliveProviderTest.java @@ -0,0 +1,96 @@ +package com.hierynomus.sshj.keepalive; + +import com.hierynomus.sshj.test.SshServerExtension; +import net.schmizz.keepalive.BoundedKeepAliveProvider; +import net.schmizz.keepalive.KeepAlive; +import net.schmizz.sshj.DefaultConfig; +import net.schmizz.sshj.SSHClient; +import net.schmizz.sshj.common.LoggerFactory; +import net.schmizz.sshj.connection.ConnectionException; +import net.schmizz.sshj.connection.ConnectionImpl; +import net.schmizz.sshj.transport.TransportException; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +class EventuallyFailKeepAlive extends KeepAlive { + // they can survive first 2 checks, and fail at 3rd + int failAfter = 1; + volatile int current = 0; + + protected EventuallyFailKeepAlive(ConnectionImpl conn, String name) { + super(conn, name); + setKeepAliveInterval(1); + } + + @Override + protected void doKeepAlive() throws TransportException, ConnectionException { + current++; + if (current > failAfter) { + throw new ConnectionException("failed"); + } + } +} + +public class BoundedKeepAliveProviderTest { + + static BoundedKeepAliveProvider kp; + static final DefaultConfig defaultConfig = new DefaultConfig(); + + + @BeforeAll + static void setUpBeforeClass() throws Exception { + + kp = new BoundedKeepAliveProvider(LoggerFactory.DEFAULT, 2) { + @Override + public KeepAlive provide(ConnectionImpl connection) { + return new EventuallyFailKeepAlive(connection, "test") { + @Override + public void startKeepAlive() { + monitor.register(this); + } + }; + } + }; + } + + @RegisterExtension + public SshServerExtension fixture = new SshServerExtension(); + + void testWithConnections(int numOfConnections) throws IOException, InterruptedException { + List clients = setupClients(numOfConnections); + for (SSHClient client : clients) { + fixture.connectClient(client); + } + // first two checks are ok + Thread.sleep(1000); + Assertions.assertTrue(clients.stream().allMatch(SSHClient::isConnected)); + + // wait for 2nd check to take place, we wait additional 200ms for it to finish + Thread.sleep(1200); + Assertions.assertTrue(clients.stream().noneMatch(SSHClient::isConnected)); + Assertions.assertEquals(0, fixture.getServer().getActiveSessions().size()); + } + + @Test + void testBoundedKeepAlive() throws IOException, InterruptedException { + // 2 threads can handle 32 connections + testWithConnections(32); + } + + private List setupClients(int numOfConnections) { + List clients = new ArrayList<>(); + defaultConfig.setKeepAliveProvider(kp); + + for (int i = 0; i < numOfConnections; i++) { + final SSHClient sshClient = fixture.createClient(defaultConfig); + clients.add(sshClient); + } + return clients; + } +} diff --git a/src/test/java/com/hierynomus/sshj/test/SshServerExtension.java b/src/test/java/com/hierynomus/sshj/test/SshServerExtension.java index 1d07bbe36..9711530f7 100644 --- a/src/test/java/com/hierynomus/sshj/test/SshServerExtension.java +++ b/src/test/java/com/hierynomus/sshj/test/SshServerExtension.java @@ -97,6 +97,15 @@ public SSHClient setupClient(Config config) { return client; } + /** + * create a new uncached client + */ + public SSHClient createClient(Config config) { + SSHClient client = new SSHClient(config); + client.addHostKeyVerifier(fingerprint); + return client; + } + public SSHClient getClient() { if (client != null) { return client;