diff --git a/src/main/java/com/timgroup/statsd/NonBlockingStatsDClient.java b/src/main/java/com/timgroup/statsd/NonBlockingStatsDClient.java index 4ea7591c..eacaba62 100644 --- a/src/main/java/com/timgroup/statsd/NonBlockingStatsDClient.java +++ b/src/main/java/com/timgroup/statsd/NonBlockingStatsDClient.java @@ -96,6 +96,7 @@ String tag() { public static final int SOCKET_BUFFER_BYTES = -1; public static final boolean DEFAULT_BLOCKING = false; public static final boolean DEFAULT_ENABLE_TELEMETRY = true; + public static final boolean DEFAULT_ENABLE_JDK_SOCKET = true; public static final boolean DEFAULT_ENABLE_AGGREGATION = true; public static final boolean DEFAULT_ENABLE_ORIGIN_DETECTION = true; @@ -248,7 +249,8 @@ public NonBlockingStatsDClient(final NonBlockingStatsDClientBuilder builder) builder.addressLookup, builder.timeout, builder.connectionTimeout, - builder.socketBufferSize); + builder.socketBufferSize, + builder.enableJdkSocket); ThreadFactory threadFactory = builder.threadFactory != null @@ -296,7 +298,8 @@ public NonBlockingStatsDClient(final NonBlockingStatsDClientBuilder builder) builder.telemetryAddressLookup, builder.timeout, builder.connectionTimeout, - builder.socketBufferSize); + builder.socketBufferSize, + builder.enableJdkSocket); // similar settings, but a single worker and non-blocking. telemetryStatsDProcessor = @@ -482,7 +485,8 @@ ClientChannel createByteChannel( Callable addressLookup, int timeout, int connectionTimeout, - int bufferSize) + int bufferSize, + boolean enableJdkSocket) throws Exception { final SocketAddress address = addressLookup.call(); if (address instanceof NamedPipeSocketAddress) { @@ -497,7 +501,11 @@ ClientChannel createByteChannel( switch (unixAddr.getTransportType()) { case UDS_STREAM: return new UnixStreamClientChannel( - unixAddr.getAddress(), timeout, connectionTimeout, bufferSize); + unixAddr.getAddress(), + timeout, + connectionTimeout, + bufferSize, + enableJdkSocket); case UDS_DATAGRAM: case UDS: return new UnixDatagramClientChannel( diff --git a/src/main/java/com/timgroup/statsd/NonBlockingStatsDClientBuilder.java b/src/main/java/com/timgroup/statsd/NonBlockingStatsDClientBuilder.java index 289b16c2..ba07fbd1 100644 --- a/src/main/java/com/timgroup/statsd/NonBlockingStatsDClientBuilder.java +++ b/src/main/java/com/timgroup/statsd/NonBlockingStatsDClientBuilder.java @@ -1,5 +1,6 @@ package com.timgroup.statsd; +import java.lang.reflect.Method; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.SocketAddress; @@ -52,6 +53,9 @@ public class NonBlockingStatsDClientBuilder implements Cloneable { public boolean enableAggregation = NonBlockingStatsDClient.DEFAULT_ENABLE_AGGREGATION; + /** Enable native JDK support for UDS. Only available on Java 16+. */ + public boolean enableJdkSocket = NonBlockingStatsDClient.DEFAULT_ENABLE_JDK_SOCKET; + /** Telemetry flush interval, in milliseconds. */ public int telemetryFlushInterval = Telemetry.DEFAULT_FLUSH_INTERVAL; @@ -322,6 +326,11 @@ public NonBlockingStatsDClientBuilder originDetectionEnabled(boolean val) { return this; } + public NonBlockingStatsDClientBuilder enableJdkSocket(boolean val) { + enableJdkSocket = val; + return this; + } + /** * Request that all metrics from this client to be enriched to specified tag cardinality. * @@ -523,8 +532,30 @@ protected static Callable staticUnixResolution( return new Callable() { @Override public SocketAddress call() { - final UnixSocketAddress socketAddress = new UnixSocketAddress(path); - return new UnixSocketAddressWithTransport(socketAddress, transportType); + SocketAddress socketAddress; + + // Use native JDK support for UDS on Java 16+ and jnr-unixsocket otherwise + if (VersionUtils.isJavaVersionAtLeast(16) + && NonBlockingStatsDClient.DEFAULT_ENABLE_JDK_SOCKET) { + try { + // Avoid compiling Java 16+ classes in incompatible versions + Class unixDomainSocketAddressClass = + Class.forName("java.net.UnixDomainSocketAddress"); + Method ofMethod = + unixDomainSocketAddressClass.getMethod("of", String.class); + socketAddress = (SocketAddress) ofMethod.invoke(null, path); + } catch (Exception e) { + throw new StatsDClientException( + "Failed to create UnixSocketAddress for native JDK UDS implementation", + e); + } + } else { + socketAddress = new UnixSocketAddress(path); + } + UnixSocketAddressWithTransport result = + new UnixSocketAddressWithTransport(socketAddress, transportType); + + return result; } }; } diff --git a/src/main/java/com/timgroup/statsd/UnixDatagramClientChannel.java b/src/main/java/com/timgroup/statsd/UnixDatagramClientChannel.java index 4fccddf6..b8dae3e0 100644 --- a/src/main/java/com/timgroup/statsd/UnixDatagramClientChannel.java +++ b/src/main/java/com/timgroup/statsd/UnixDatagramClientChannel.java @@ -16,6 +16,10 @@ class UnixDatagramClientChannel extends DatagramClientChannel { */ UnixDatagramClientChannel(SocketAddress address, int timeout, int bufferSize) throws IOException { + // Ideally we could use native JDK UDS support such as with the UnixStreamClientChannel. + // However, DatagramChannels do not support StandardProtocolFamily.UNIX, so this is + // unavailable. + // See this open issue for updates: https://bugs.openjdk.org/browse/JDK-8297837? super(UnixDatagramChannel.open(), address); // Set send timeout, to handle the case where the transmission buffer is full // If no timeout is set, the send becomes blocking diff --git a/src/main/java/com/timgroup/statsd/UnixStreamClientChannel.java b/src/main/java/com/timgroup/statsd/UnixStreamClientChannel.java index c86c3c57..66a53efc 100644 --- a/src/main/java/com/timgroup/statsd/UnixStreamClientChannel.java +++ b/src/main/java/com/timgroup/statsd/UnixStreamClientChannel.java @@ -1,9 +1,12 @@ package com.timgroup.statsd; import java.io.IOException; +import java.lang.reflect.Method; import java.net.SocketAddress; import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; import java.nio.channels.SocketChannel; import jnr.unixsocket.UnixSocketAddress; import jnr.unixsocket.UnixSocketChannel; @@ -11,10 +14,11 @@ /** A ClientChannel for Unix domain sockets. */ public class UnixStreamClientChannel implements ClientChannel { - private final UnixSocketAddress address; + private final SocketAddress address; private final int timeout; private final int connectionTimeout; private final int bufferSize; + private final boolean enableJdkSocket; private SocketChannel delegate; private final ByteBuffer delimiterBuffer = @@ -26,13 +30,18 @@ public class UnixStreamClientChannel implements ClientChannel { * @param address Location of named pipe */ UnixStreamClientChannel( - SocketAddress address, int timeout, int connectionTimeout, int bufferSize) + SocketAddress address, + int timeout, + int connectionTimeout, + int bufferSize, + boolean enableJdkSocket) throws IOException { this.delegate = null; - this.address = (UnixSocketAddress) address; + this.address = address; this.timeout = timeout; this.connectionTimeout = connectionTimeout; this.bufferSize = bufferSize; + this.enableJdkSocket = enableJdkSocket; } @Override @@ -87,19 +96,37 @@ public int writeAll(ByteBuffer bb, boolean canReturnOnTimeout, long deadline) throws IOException { int remaining = bb.remaining(); int written = 0; + long timeoutMs = timeout; + while (remaining > 0) { int read = delegate.write(bb); - - // If we haven't written anything yet, we can still return - if (read == 0 && canReturnOnTimeout && written == 0) { - return written; + if (read > 0) { + remaining -= read; + written += read; + continue; } - remaining -= read; - written += read; + if (read == 0) { + if (canReturnOnTimeout && written == 0) { + return written; + } + + try (Selector selector = Selector.open()) { + SelectionKey key = delegate.register(selector, SelectionKey.OP_WRITE); + long selectTimeout = timeoutMs; + + if (deadline > 0) { + long remainingNs = deadline - System.nanoTime(); + if (remainingNs <= 0) { + throw new IOException("Write timed out"); + } + selectTimeout = Math.min(timeoutMs, remainingNs / 1_000_000L); + } - if (deadline > 0 && System.nanoTime() > deadline) { - throw new IOException("Write timed out"); + if (selector.select(selectTimeout) == 0) { + throw new IOException("Write timed out after " + selectTimeout + "ms"); + } + } } } return written; @@ -127,40 +154,112 @@ private void connect() throws IOException { } } - UnixSocketChannel delegate = UnixSocketChannel.create(); - long deadline = System.nanoTime() + connectionTimeout * 1_000_000L; + // Use native JDK support for UDS on Java 16+ and jnr-unixsocket otherwise + if (VersionUtils.isJavaVersionAtLeast(16) && enableJdkSocket) { + try { + // Avoid compiling Java 16+ classes in incompatible versions + Class protocolFamilyClass = Class.forName("java.net.ProtocolFamily"); + Class standardProtocolFamilyClass = + Class.forName("java.net.StandardProtocolFamily"); + Object unixProtocol = + Enum.valueOf((Class) standardProtocolFamilyClass, "UNIX"); + Method openMethod = SocketChannel.class.getMethod("open", protocolFamilyClass); + SocketChannel channel = (SocketChannel) openMethod.invoke(null, unixProtocol); + + channel.configureBlocking(false); + + try { + SocketAddress connectAddress = address; + if (address instanceof UnixSocketAddressWithTransport) { + connectAddress = ((UnixSocketAddressWithTransport) address).getAddress(); + } + + Method connectMethod = + SocketChannel.class.getMethod("connect", SocketAddress.class); + boolean connected = (boolean) connectMethod.invoke(channel, connectAddress); + + if (!connected) { + try (Selector selector = Selector.open()) { + SelectionKey key = channel.register(selector, SelectionKey.OP_CONNECT); + int timeoutMs = connectionTimeout > 0 ? connectionTimeout : 1000; + int ready = selector.select(timeoutMs); + + if (ready == 0) { + throw new IOException( + "Connection timed out after " + timeoutMs + "ms"); + } + + if (key.isConnectable()) { + connected = channel.finishConnect(); + if (!connected) { + throw new IOException("Failed to complete connection"); + } + } + } + } + } catch (Exception e) { + try { + channel.close(); + } catch (IOException __) { + // ignore + } + throw e; + } + + this.delegate = channel; + return; + } catch (Exception e) { + Throwable cause = e.getCause(); + if (e instanceof java.lang.reflect.InvocationTargetException + && cause instanceof IOException) { + throw (IOException) cause; + } + throw new IOException( + "Failed to create UnixStreamClientChannel for native UDS implementation", + e); + } + } + // Default to jnr-unixsocket if Java version is < 16 or native support is disabled + UnixSocketChannel channel = UnixSocketChannel.create(); + if (connectionTimeout > 0) { // Set connect timeout, this should work at least on linux // https://elixir.bootlin.com/linux/v5.7.4/source/net/unix/af_unix.c#L1696 - // We'd have better timeout support if we used Java 16's native Unix domain socket - // support (JEP 380) - delegate.setOption(UnixSocketOptions.SO_SNDTIMEO, connectionTimeout); + channel.setOption(UnixSocketOptions.SO_SNDTIMEO, connectionTimeout); } + try { - if (!delegate.connect(address)) { + UnixSocketAddress unixAddress; + if (address instanceof UnixSocketAddress) { + unixAddress = (UnixSocketAddress) address; + } else { + unixAddress = new UnixSocketAddress(address.toString()); + } + + if (!channel.connect(unixAddress)) { if (connectionTimeout > 0 && System.nanoTime() > deadline) { throw new IOException("Connection timed out"); } - if (!delegate.finishConnect()) { + if (!channel.finishConnect()) { throw new IOException("Connection failed"); } } - delegate.setOption(UnixSocketOptions.SO_SNDTIMEO, Math.max(timeout, 0)); + channel.setOption(UnixSocketOptions.SO_SNDTIMEO, Math.max(timeout, 0)); if (bufferSize > 0) { - delegate.setOption(UnixSocketOptions.SO_SNDBUF, bufferSize); + channel.setOption(UnixSocketOptions.SO_SNDBUF, bufferSize); } } catch (Exception e) { try { - delegate.close(); + channel.close(); } catch (IOException __) { // ignore } throw e; } - this.delegate = delegate; + this.delegate = channel; } @Override diff --git a/src/main/java/com/timgroup/statsd/VersionUtils.java b/src/main/java/com/timgroup/statsd/VersionUtils.java new file mode 100644 index 00000000..7fba319c --- /dev/null +++ b/src/main/java/com/timgroup/statsd/VersionUtils.java @@ -0,0 +1,114 @@ +package com.timgroup.statsd; + +import java.util.ArrayList; +import java.util.List; + +// Logic copied from dd-trace-java Platform class. See: +// https://github.com/DataDog/dd-trace-java/blob/master/internal-api/src/main/java/datadog/trace/api/Platform.java +public class VersionUtils { + private static final Version JAVA_VERSION = + parseJavaVersion(System.getProperty("java.version")); + + private static Version parseJavaVersion(String javaVersion) { + // Remove pre-release part, usually -ea + final int indexOfDash = javaVersion.indexOf('-'); + if (indexOfDash >= 0) { + javaVersion = javaVersion.substring(0, indexOfDash); + } + + int major = 0; + int minor = 0; + int update = 0; + + try { + List nums = splitDigits(javaVersion); + major = nums.get(0); + + // for java 1.6/1.7/1.8 + if (major == 1) { + major = nums.get(1); + minor = nums.get(2); + update = nums.get(3); + } else { + minor = nums.get(1); + update = nums.get(2); + } + } catch (NumberFormatException | IndexOutOfBoundsException e) { + // unable to parse version string - do nothing + } + return new Version(major, minor, update); + } + + private static List splitDigits(String str) { + List results = new ArrayList<>(); + + int len = str.length(); + + int value = 0; + for (int i = 0; i < len; i++) { + char ch = str.charAt(i); + if (ch >= '0' && ch <= '9') { + value = value * 10 + (ch - '0'); + } else if (ch == '.' || ch == '_' || ch == '+') { + results.add(value); + value = 0; + } else { + throw new NumberFormatException(); + } + } + results.add(value); + return results; + } + + static final class Version { + public final int major; + public final int minor; + public final int update; + + public Version(int major, int minor, int update) { + this.major = major; + this.minor = minor; + this.update = update; + } + + public boolean is(int major) { + return this.major == major; + } + + public boolean is(int major, int minor) { + return this.major == major && this.minor == minor; + } + + public boolean is(int major, int minor, int update) { + return this.major == major && this.minor == minor && this.update == update; + } + + public boolean isAtLeast(int major, int minor, int update) { + return isAtLeast(this.major, this.minor, this.update, major, minor, update); + } + + private static boolean isAtLeast( + int major, + int minor, + int update, + int atLeastMajor, + int atLeastMinor, + int atLeastUpdate) { + return (major > atLeastMajor) + || (major == atLeastMajor && minor > atLeastMinor) + || (major == atLeastMajor && minor == atLeastMinor && update >= atLeastUpdate); + } + } + + public static boolean isJavaVersionAtLeast(int major) { + return isJavaVersionAtLeast(major, 0, 0); + } + + public static boolean isJavaVersionAtLeast(int major, int minor) { + return isJavaVersionAtLeast(major, minor, 0); + } + + public static boolean isJavaVersionAtLeast(int major, int minor, int update) { + return JAVA_VERSION.isAtLeast(major, minor, update); + } +} diff --git a/src/test/java/com/timgroup/statsd/BuilderAddressTest.java b/src/test/java/com/timgroup/statsd/BuilderAddressTest.java index d37703d5..bf9de64d 100644 --- a/src/test/java/com/timgroup/statsd/BuilderAddressTest.java +++ b/src/test/java/com/timgroup/statsd/BuilderAddressTest.java @@ -213,9 +213,12 @@ public void address_resolution() throws Exception { if (expected instanceof UnixSocketAddressWithTransport) { UnixSocketAddressWithTransport a = (UnixSocketAddressWithTransport) actual; UnixSocketAddressWithTransport e = (UnixSocketAddressWithTransport) expected; + // Native JDK UDS support returns a SocketAddress rather than a UnixSocketAddress assertEquals( ((FakeUnixSocketAddress) e.getAddress()).getPath(), - ((UnixSocketAddress) a.getAddress()).path()); + a.getAddress() instanceof UnixSocketAddress + ? ((UnixSocketAddress) a.getAddress()).path() + : a.getAddress().toString()); assertEquals(e.getTransportType(), a.getTransportType()); } else { assertEquals(expected, actual); diff --git a/src/test/java/com/timgroup/statsd/NonBlockingStatsDClientTest.java b/src/test/java/com/timgroup/statsd/NonBlockingStatsDClientTest.java index 8ed7a030..4895657d 100644 --- a/src/test/java/com/timgroup/statsd/NonBlockingStatsDClientTest.java +++ b/src/test/java/com/timgroup/statsd/NonBlockingStatsDClientTest.java @@ -1285,7 +1285,8 @@ ClientChannel createByteChannel( Callable addressLookup, int timeout, int connectionTimeout, - int bufferSize) + int bufferSize, + boolean enableJdkSocket) throws Exception { return new DatagramClientChannel(addressLookup.call()) { @Override @@ -1336,7 +1337,8 @@ ClientChannel createByteChannel( Callable addressLookup, int timeout, int connectionTimeout, - int bufferSize) + int bufferSize, + boolean enableJdkSocket) throws Exception { return new DatagramClientChannel(addressLookup.call()) { @Override diff --git a/src/test/java/com/timgroup/statsd/UnixStreamSocketDummyStatsDServer.java b/src/test/java/com/timgroup/statsd/UnixStreamSocketDummyStatsDServer.java index 6c480a6a..56a425a7 100644 --- a/src/test/java/com/timgroup/statsd/UnixStreamSocketDummyStatsDServer.java +++ b/src/test/java/com/timgroup/statsd/UnixStreamSocketDummyStatsDServer.java @@ -23,7 +23,8 @@ public class UnixStreamSocketDummyStatsDServer extends DummyStatsDServer { public UnixStreamSocketDummyStatsDServer(String socketPath) throws IOException { server = UnixServerSocketChannel.open(); server.configureBlocking(true); - server.socket().bind(new UnixSocketAddress(socketPath)); + UnixSocketAddress address = new UnixSocketAddress(socketPath); + server.socket().bind(address); this.listen(); } diff --git a/src/test/java/com/timgroup/statsd/UnixStreamSocketTest.java b/src/test/java/com/timgroup/statsd/UnixStreamSocketTest.java index d7e64e4d..dc621861 100644 --- a/src/test/java/com/timgroup/statsd/UnixStreamSocketTest.java +++ b/src/test/java/com/timgroup/statsd/UnixStreamSocketTest.java @@ -95,11 +95,14 @@ public void assert_default_uds_size() throws Exception { @Test(timeout = 5000L) public void sends_to_statsd() throws Exception { + Thread.sleep(100); + server.clear(); + for (long i = 0; i < 5; i++) { client.gauge("mycount", i); server.waitForMessage(); String expected = String.format("my.prefix.mycount:%d|g", i); - assertThat(server.messagesReceived(), contains(expected)); + assertThat(server.messagesReceived(), hasItem(expected)); server.clear(); } assertThat(lastException.getMessage(), nullValue());