diff --git a/servlet/src/test/java/io/undertow/servlet/test/streams/Http2InputStreamTestCase.java b/servlet/src/test/java/io/undertow/servlet/test/streams/Http2InputStreamTestCase.java index 68183eb7a..1667c3d4b 100644 --- a/servlet/src/test/java/io/undertow/servlet/test/streams/Http2InputStreamTestCase.java +++ b/servlet/src/test/java/io/undertow/servlet/test/streams/Http2InputStreamTestCase.java @@ -1,10 +1,8 @@ package io.undertow.servlet.test.streams; -import io.undertow.httpcore.StatusCodes; import io.undertow.servlet.api.ServletInfo; import io.undertow.servlet.test.util.DeploymentUtils; import io.undertow.testutils.DefaultServer; -import io.undertow.testutils.HttpClientUtils; import io.undertow.testutils.TestHttpClient; import io.vertx.core.Handler; import io.vertx.core.Vertx; @@ -13,35 +11,19 @@ import io.vertx.core.http.HttpClientRequest; import io.vertx.core.http.HttpClientResponse; import io.vertx.core.http.HttpVersion; -import org.apache.commons.codec.binary.Hex; -import org.apache.http.HttpResponse; -import org.apache.http.client.methods.CloseableHttpResponse; -import org.apache.http.client.methods.HttpPost; -import org.apache.http.entity.InputStreamEntity; -import org.apache.http.entity.StringEntity; -import org.apache.http.impl.client.CloseableHttpClient; -import org.apache.http.impl.client.HttpClients; import org.junit.AfterClass; import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Test; import org.junit.runner.RunWith; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.InterruptedIOException; -import java.io.OutputStream; -import java.net.HttpURLConnection; -import java.net.URL; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; -import java.util.concurrent.Callable; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; diff --git a/websocket/core/src/main/java/io/undertow/websockets/ServerWebSocketContainer.java b/websocket/core/src/main/java/io/undertow/websockets/ServerWebSocketContainer.java index 8c2f61c67..fe046d265 100644 --- a/websocket/core/src/main/java/io/undertow/websockets/ServerWebSocketContainer.java +++ b/websocket/core/src/main/java/io/undertow/websockets/ServerWebSocketContainer.java @@ -25,20 +25,23 @@ import io.netty.handler.codec.http.websocketx.extensions.WebSocketServerExtensionHandshaker; import io.undertow.websockets.annotated.AnnotatedEndpointFactory; import io.undertow.websockets.handshake.Handshake; -import io.undertow.websockets.util.ObjectIntrospecter; import io.undertow.websockets.util.ConstructorObjectFactory; +import io.undertow.websockets.util.ContextSetupHandler; import io.undertow.websockets.util.ImmediateObjectHandle; import io.undertow.websockets.util.ObjectFactory; import io.undertow.websockets.util.ObjectHandle; +import io.undertow.websockets.util.ObjectIntrospecter; import io.undertow.websockets.util.PathTemplate; -import io.undertow.websockets.util.ContextSetupHandler; import javax.net.ssl.SSLContext; import javax.websocket.ClientEndpoint; import javax.websocket.ClientEndpointConfig; import javax.websocket.CloseReason; +import javax.websocket.Decoder; import javax.websocket.DeploymentException; +import javax.websocket.Encoder; import javax.websocket.Endpoint; +import javax.websocket.EndpointConfig; import javax.websocket.Extension; import javax.websocket.HandshakeResponse; import javax.websocket.Session; @@ -74,6 +77,7 @@ import java.util.function.Supplier; import static java.lang.System.currentTimeMillis; +import static java.util.Collections.emptyMap; /** @@ -180,18 +184,6 @@ protected Supplier getExecutorSupplier() { return executorSupplier; } - public Session connectToServer(final Object annotatedEndpointInstance, WebsocketConnectionBuilder connectionBuilder) throws DeploymentException, IOException { - if (closed) { - throw new ClosedChannelException(); - } - ConfiguredClientEndpoint config = getClientEndpoint(annotatedEndpointInstance.getClass(), false); - if (config == null) { - throw JsrWebSocketMessages.MESSAGES.notAValidClientEndpointType(annotatedEndpointInstance.getClass()); - } - Endpoint instance = config.getFactory().createInstance(new ImmediateObjectHandle<>(annotatedEndpointInstance)); - return connectToServerInternal(instance, config, connectionBuilder); - } - @Override public Session connectToServer(final Object annotatedEndpointInstance, final URI path) throws DeploymentException, IOException { if (closed) { @@ -580,7 +572,8 @@ private synchronized void addEndpointInternal(final Class endpoint, boolean r Class configuratorClass = serverEndpoint.configurator(); EncodingFactory encodingFactory = EncodingFactory.createFactory(objectIntrospecter, serverEndpoint.decoders(), serverEndpoint.encoders()); - AnnotatedEndpointFactory annotatedEndpointFactory = AnnotatedEndpointFactory.create(endpoint, encodingFactory, template.getParameterNames()); + EndpointConfig endpointConfig = createEndpointConfig(serverEndpoint.encoders(), serverEndpoint.decoders()); + AnnotatedEndpointFactory annotatedEndpointFactory = AnnotatedEndpointFactory.create(endpoint, encodingFactory, template.getParameterNames(), endpointConfig); ObjectFactory ObjectFactory = null; try { ObjectFactory = objectIntrospecter.createInstanceFactory(endpoint); @@ -605,12 +598,12 @@ public ObjectHandle createInstance() { } ServerEndpointConfig config = ServerEndpointConfig.Builder.create(endpoint, serverEndpoint.value()) - .decoders(Arrays.asList(serverEndpoint.decoders())) - .encoders(Arrays.asList(serverEndpoint.encoders())) - .subprotocols(Arrays.asList(serverEndpoint.subprotocols())) - .extensions(Collections.emptyList()) - .configurator(configurator) - .build(); + .decoders(Arrays.asList(serverEndpoint.decoders())) + .encoders(Arrays.asList(serverEndpoint.encoders())) + .subprotocols(Arrays.asList(serverEndpoint.subprotocols())) + .extensions(Collections.emptyList()) + .configurator(configurator) + .build(); ConfiguredServerEndpoint confguredServerEndpoint = new ConfiguredServerEndpoint(config, ObjectFactory, template, encodingFactory, annotatedEndpointFactory, installedExtensions); @@ -638,17 +631,19 @@ public ObjectHandle createInstance() { } } } - AnnotatedEndpointFactory factory = AnnotatedEndpointFactory.create(endpoint, encodingFactory, Collections.emptySet()); + + EndpointConfig endpointConfig = createEndpointConfig(clientEndpoint.encoders(), clientEndpoint.decoders()); + AnnotatedEndpointFactory factory = AnnotatedEndpointFactory.create(endpoint, encodingFactory, Collections.emptySet(), endpointConfig); ClientEndpointConfig.Configurator configurator = null; configurator = objectIntrospecter.createInstanceFactory(clientEndpoint.configurator()).createInstance().getInstance(); ClientEndpointConfig config = ClientEndpointConfig.Builder.create() - .decoders(Arrays.asList(clientEndpoint.decoders())) - .encoders(Arrays.asList(clientEndpoint.encoders())) - .preferredSubprotocols(Arrays.asList(clientEndpoint.subprotocols())) - .configurator(configurator) - .build(); + .decoders(Arrays.asList(clientEndpoint.decoders())) + .encoders(Arrays.asList(clientEndpoint.encoders())) + .preferredSubprotocols(Arrays.asList(clientEndpoint.subprotocols())) + .configurator(configurator) + .build(); ConfiguredClientEndpoint configuredClientEndpoint = new ConfiguredClientEndpoint(config, factory, encodingFactory, ObjectFactory); clientEndpoints.put(endpoint, configuredClientEndpoint); @@ -657,6 +652,25 @@ public ObjectHandle createInstance() { } } + private EndpointConfig createEndpointConfig(Class[] encoders, Class[] decoders) { + return new EndpointConfig() { + @Override + public List> getEncoders() { + return Arrays.asList(encoders); + } + + @Override + public List> getDecoders() { + return Arrays.asList(decoders); + } + + @Override + public Map getUserProperties() { + return emptyMap(); + } + }; + } + protected void handleAddingFilterMapping() { } @@ -684,7 +698,7 @@ public void addEndpoint(final ServerEndpointConfig endpoint) throws DeploymentEx AnnotatedEndpointFactory annotatedEndpointFactory = null; if (!Endpoint.class.isAssignableFrom(endpoint.getEndpointClass())) { // We may want to check that the path in @ServerEndpoint matches the specified path, and throw if they are not equivalent - annotatedEndpointFactory = AnnotatedEndpointFactory.create(endpoint.getEndpointClass(), encodingFactory, template.getParameterNames()); + annotatedEndpointFactory = AnnotatedEndpointFactory.create(endpoint.getEndpointClass(), encodingFactory, template.getParameterNames(), endpoint); } ConfiguredServerEndpoint confguredServerEndpoint = new ConfiguredServerEndpoint(endpoint, null, template, encodingFactory, annotatedEndpointFactory, endpoint.getExtensions()); configuredServerEndpoints.add(confguredServerEndpoint); diff --git a/websocket/core/src/main/java/io/undertow/websockets/annotated/AnnotatedEndpointFactory.java b/websocket/core/src/main/java/io/undertow/websockets/annotated/AnnotatedEndpointFactory.java index cbaf8f977..58d22aa7b 100644 --- a/websocket/core/src/main/java/io/undertow/websockets/annotated/AnnotatedEndpointFactory.java +++ b/websocket/core/src/main/java/io/undertow/websockets/annotated/AnnotatedEndpointFactory.java @@ -73,8 +73,7 @@ private AnnotatedEndpointFactory(final Class endpointClass, final BoundMethod } - public static AnnotatedEndpointFactory create(final Class endpointClass, final EncodingFactory encodingFactory, final Set paths) throws DeploymentException { - final Set> found = new HashSet<>(); + public static AnnotatedEndpointFactory create(final Class endpointClass, final EncodingFactory encodingFactory, final Set paths, final EndpointConfig endpointConfig) throws DeploymentException { BoundMethod onOpen = null; BoundMethod onClose = null; BoundMethod onError = null; @@ -86,43 +85,40 @@ public static AnnotatedEndpointFactory create(final Class endpointClass, fina do { for (final Method method : c.getDeclaredMethods()) { if (method.isAnnotationPresent(OnOpen.class)) { - if (found.contains(OnOpen.class)) { + if (onOpen != null) { if (!onOpen.overrides(method)) { throw JsrWebSocketMessages.MESSAGES.moreThanOneAnnotation(OnOpen.class); } else { continue; } } - found.add(OnOpen.class); onOpen = new BoundMethod(method, null, false, 0, new BoundSingleParameter(method, Session.class, true), new BoundSingleParameter(method, EndpointConfig.class, true), - createBoundPathParameters(method, paths, endpointClass)); + createBoundPathParameters(method, paths, endpointClass, encodingFactory, endpointConfig)); } if (method.isAnnotationPresent(OnClose.class)) { - if (found.contains(OnClose.class)) { + if (onClose != null) { if (!onClose.overrides(method)) { throw JsrWebSocketMessages.MESSAGES.moreThanOneAnnotation(OnClose.class); } else { continue; } } - found.add(OnClose.class); onClose = new BoundMethod(method, null, false, 0, new BoundSingleParameter(method, Session.class, true), new BoundSingleParameter(method, CloseReason.class, true), - createBoundPathParameters(method, paths, endpointClass)); + createBoundPathParameters(method, paths, endpointClass, encodingFactory, endpointConfig)); } if (method.isAnnotationPresent(OnError.class)) { - if (found.contains(OnError.class)) { + if (onError != null) { if (!onError.overrides(method)) { throw JsrWebSocketMessages.MESSAGES.moreThanOneAnnotation(OnError.class); } else { continue; } } - found.add(OnError.class); onError = new BoundMethod(method, null, false, 0, new BoundSingleParameter(method, Session.class, true), new BoundSingleParameter(method, Throwable.class, false), - createBoundPathParameters(method, paths, endpointClass)); + createBoundPathParameters(method, paths, endpointClass, encodingFactory, endpointConfig)); } if (method.isAnnotationPresent(OnMessage.class) && !method.isBridge()) { if (binaryMessage != null && binaryMessage.overrides(method)) { @@ -153,7 +149,7 @@ public static AnnotatedEndpointFactory create(final Class endpointClass, fina } textMessage = new BoundMethod(method, param, true, maxMessageSize, new BoundSingleParameter(method, Session.class, true), new BoundSingleParameter(i, param), - createBoundPathParameters(method, paths, endpointClass)); + createBoundPathParameters(method, paths, endpointClass, encodingFactory, endpointConfig)); messageHandled = true; break; } else if (encodingFactory.canDecodeBinary(param)) { @@ -162,7 +158,7 @@ public static AnnotatedEndpointFactory create(final Class endpointClass, fina } binaryMessage = new BoundMethod(method, param, true, maxMessageSize, new BoundSingleParameter(method, Session.class, true), new BoundSingleParameter(i, param), - createBoundPathParameters(method, paths, endpointClass)); + createBoundPathParameters(method, paths, endpointClass, encodingFactory, endpointConfig)); messageHandled = true; break; } else if (param.equals(byte[].class)) { @@ -172,7 +168,7 @@ public static AnnotatedEndpointFactory create(final Class endpointClass, fina binaryMessage = new BoundMethod(method, byte[].class, false, maxMessageSize, new BoundSingleParameter(method, Session.class, true), new BoundSingleParameter(method, boolean.class, true), new BoundSingleParameter(i, byte[].class), - createBoundPathParameters(method, paths, endpointClass)); + createBoundPathParameters(method, paths, endpointClass, encodingFactory, endpointConfig)); messageHandled = true; break; } else if (param.equals(ByteBuffer.class)) { @@ -183,7 +179,7 @@ public static AnnotatedEndpointFactory create(final Class endpointClass, fina maxMessageSize, new BoundSingleParameter(method, Session.class, true), new BoundSingleParameter(method, boolean.class, true), new BoundSingleParameter(i, ByteBuffer.class), - createBoundPathParameters(method, paths, endpointClass)); + createBoundPathParameters(method, paths, endpointClass, encodingFactory, endpointConfig)); messageHandled = true; break; @@ -194,7 +190,7 @@ maxMessageSize, new BoundSingleParameter(method, Session.class, true), binaryMessage = new BoundMethod(method, InputStream.class, false, maxMessageSize, new BoundSingleParameter(method, Session.class, true), new BoundSingleParameter(i, InputStream.class), - createBoundPathParameters(method, paths, endpointClass)); + createBoundPathParameters(method, paths, endpointClass, encodingFactory, endpointConfig)); messageHandled = true; break; @@ -205,7 +201,7 @@ maxMessageSize, new BoundSingleParameter(method, Session.class, true), textMessage = new BoundMethod(method, String.class, false, maxMessageSize, new BoundSingleParameter(method, Session.class, true), new BoundSingleParameter(method, boolean.class, true), new BoundSingleParameter(i, String.class), - createBoundPathParameters(method, paths, endpointClass)); + createBoundPathParameters(method, paths, endpointClass, encodingFactory, endpointConfig)); messageHandled = true; break; @@ -216,7 +212,7 @@ maxMessageSize, new BoundSingleParameter(method, Session.class, true), textMessage = new BoundMethod(method, Reader.class, false, maxMessageSize, new BoundSingleParameter(method, Session.class, true), new BoundSingleParameter(i, Reader.class), - createBoundPathParameters(method, paths, endpointClass)); + createBoundPathParameters(method, paths, endpointClass, encodingFactory, endpointConfig)); messageHandled = true; break; @@ -226,7 +222,7 @@ maxMessageSize, new BoundSingleParameter(method, Session.class, true), } pongMessage = new BoundMethod(method, PongMessage.class, false, maxMessageSize, new BoundSingleParameter(method, Session.class, true), new BoundSingleParameter(i, PongMessage.class), - createBoundPathParameters(method, paths, endpointClass)); + createBoundPathParameters(method, paths, endpointClass, encodingFactory, endpointConfig)); messageHandled = true; break; } @@ -240,7 +236,7 @@ maxMessageSize, new BoundSingleParameter(method, Session.class, true), textMessage = new BoundMethod(method, boolClass, true, maxMessageSize, new BoundSingleParameter(method, Session.class, true), new BoundSingleParameter(method, boolean.class, true), new BoundSingleParameter(booleanLocation, boolClass), - createBoundPathParameters(method, paths, endpointClass)); + createBoundPathParameters(method, paths, endpointClass, encodingFactory, endpointConfig)); messageHandled = true; } if (!messageHandled) { @@ -253,8 +249,8 @@ maxMessageSize, new BoundSingleParameter(method, Session.class, true), return new AnnotatedEndpointFactory(endpointClass, onOpen, onClose, onError, textMessage, binaryMessage, pongMessage); } - private static BoundPathParameters createBoundPathParameters(final Method method, Set paths, Class endpointClass) throws DeploymentException { - return new BoundPathParameters(pathParams(method), method, endpointClass, paths); + private static BoundPathParameters createBoundPathParameters(final Method method, Set paths, Class endpointClass, final EncodingFactory encodingFactory, final EndpointConfig endpointConfig) throws DeploymentException { + return new BoundPathParameters(pathParams(method), method, endpointClass, paths, encodingFactory, endpointConfig); } @@ -364,24 +360,20 @@ public Class getType() { */ private static class BoundPathParameters implements BoundParameter { - private final Class endpointClass; - private final Set paths; private final String[] positions; private final Encoding[] encoders; private final Class[] types; - BoundPathParameters(final String[] positions, final Method method, Class endpointClass, Set paths) throws DeploymentException { + BoundPathParameters(final String[] positions, final Method method, Class endpointClass, Set paths, final EncodingFactory encodingFactory, final EndpointConfig endpointConfig) throws DeploymentException { this.positions = positions; - this.endpointClass = endpointClass; - this.paths = paths; this.encoders = new Encoding[positions.length]; this.types = new Class[positions.length]; for (int i = 0; i < positions.length; ++i) { Class type = method.getParameterTypes()[i]; Annotation[] annotations = method.getParameterAnnotations()[i]; - for (int j = 0; j < annotations.length; ++j) { - if (annotations[j] instanceof PathParam) { - PathParam param = (PathParam) annotations[j]; + for (Annotation annotation : annotations) { + if (annotation instanceof PathParam) { + PathParam param = (PathParam) annotation; if (!paths.contains(param.value())) { JsrWebSocketLogger.ROOT_LOGGER.pathTemplateNotFound(endpointClass, param, method, paths); } @@ -390,8 +382,8 @@ private static class BoundPathParameters implements BoundParameter { if (positions[i] == null || type == null || type == String.class) { continue; } - if (EncodingFactory.DEFAULT.canEncodeText(type)) { - encoders[i] = EncodingFactory.DEFAULT.createEncoding(EmptyEndpointConfig.INSTANCE); + if (encodingFactory.canDecodeText(type)) { + encoders[i] = encodingFactory.createEncoding(endpointConfig); types[i] = type; } else { diff --git a/websocket/servlet/src/main/java/io/undertow/websockets/servlet/ServletServerWebSocketContainer.java b/websocket/servlet/src/main/java/io/undertow/websockets/servlet/ServletServerWebSocketContainer.java index cff2e90af..013034bf0 100644 --- a/websocket/servlet/src/main/java/io/undertow/websockets/servlet/ServletServerWebSocketContainer.java +++ b/websocket/servlet/src/main/java/io/undertow/websockets/servlet/ServletServerWebSocketContainer.java @@ -25,7 +25,10 @@ import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import javax.websocket.Decoder; +import javax.websocket.Encoder; import javax.websocket.Endpoint; +import javax.websocket.EndpointConfig; import javax.websocket.Extension; import javax.websocket.server.ServerEndpointConfig; import java.io.IOException; @@ -101,7 +104,8 @@ public ObjectHandle createInstance() { AnnotatedEndpointFactory annotatedEndpointFactory = null; if (!Endpoint.class.isAssignableFrom(sec.getEndpointClass())) { - annotatedEndpointFactory = AnnotatedEndpointFactory.create(sec.getEndpointClass(), encodingFactory, pt.getParameterNames()); + annotatedEndpointFactory = AnnotatedEndpointFactory.create(sec.getEndpointClass(), encodingFactory, pt.getParameterNames(), + createEndpointConfigurationFromConfig(sec)); } @@ -148,6 +152,25 @@ public void accept(ChannelHandlerContext context) { } } + private EndpointConfig createEndpointConfigurationFromConfig(ServerEndpointConfig sec) { + return new EndpointConfig() { + @Override + public List> getEncoders() { + return sec.getEncoders(); + } + + @Override + public List> getDecoders() { + return sec.getDecoders(); + } + + @Override + public Map getUserProperties() { + return sec.getUserProperties(); + } + }; + } + protected void handleAddingFilterMapping() { if (contextToAddFilter != null) { diff --git a/websocket/servlet/src/test/java/io/undertow/websockets/jsr/test/annotated/AnnotatedEndpointTest.java b/websocket/servlet/src/test/java/io/undertow/websockets/jsr/test/annotated/AnnotatedEndpointTest.java index 68e16d6a3..23dbe438a 100644 --- a/websocket/servlet/src/test/java/io/undertow/websockets/jsr/test/annotated/AnnotatedEndpointTest.java +++ b/websocket/servlet/src/test/java/io/undertow/websockets/jsr/test/annotated/AnnotatedEndpointTest.java @@ -17,32 +17,6 @@ */ package io.undertow.websockets.jsr.test.annotated; -import java.io.IOException; -import java.net.URI; -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; -import java.util.Set; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; - -import javax.servlet.ServletException; -import javax.servlet.http.HttpServlet; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import javax.websocket.ClientEndpoint; -import javax.websocket.CloseReason; -import javax.websocket.OnClose; -import javax.websocket.Session; -import javax.websocket.server.ServerEndpointConfig; - -import org.junit.AfterClass; -import org.junit.Assert; -import org.junit.BeforeClass; -import org.junit.Ignore; -import org.junit.Test; -import org.junit.runner.RunWith; - import io.netty.buffer.Unpooled; import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; @@ -61,6 +35,31 @@ import io.undertow.websockets.WebSocketDeploymentInfo; import io.undertow.websockets.jsr.test.FrameChecker; import io.undertow.websockets.jsr.test.WebSocketTestClient; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; +import org.junit.runner.RunWith; + +import javax.servlet.ServletException; +import javax.servlet.http.HttpServlet; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.websocket.ClientEndpoint; +import javax.websocket.CloseReason; +import javax.websocket.OnClose; +import javax.websocket.Session; +import javax.websocket.server.ServerEndpointConfig; +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; /** * @author Norman Maurer @@ -87,6 +86,7 @@ public static void setup() throws Exception { .addEndpoint(AnnotatedClientEndpoint.class) .addEndpoint(AnnotatedClientEndpointWithConfigurator.class) .addEndpoint(IncrementEndpoint.class) + .addEndpoint(UUIDEndpoint.class) .addEndpoint(EncodingEndpoint.class) .addEndpoint(EncodingGenericsEndpoint.class) .addEndpoint(TimeoutEndpoint.class) @@ -127,19 +127,23 @@ public void testStringOnMessage() throws Exception { final byte[] payload = "hello".getBytes(); final CompletableFuture latch = new CompletableFuture(); - WebSocketTestClient client = new WebSocketTestClient(new URI("ws://" + DefaultServer.getHostAddress("default") + ":" + DefaultServer.getHostPort("default") + "/ws/chat/Stuart")); + WebSocketTestClient client = createTestClient("/ws/chat/Stuart"); client.connect(); client.send(new TextWebSocketFrame(Unpooled.wrappedBuffer(payload)), new FrameChecker(TextWebSocketFrame.class, "hello Stuart".getBytes(), latch)); latch.get(); client.destroy(); } + private WebSocketTestClient createTestClient(String s) throws URISyntaxException { + return new WebSocketTestClient(new URI("ws://" + DefaultServer.getHostAddress("default") + ":" + DefaultServer.getHostPort("default") + s)); + } + @Test public void testStringOnMessageAddedProgramatically() throws Exception { final byte[] payload = "foo".getBytes(); final CompletableFuture latch = new CompletableFuture(); - WebSocketTestClient client = new WebSocketTestClient(new URI("ws://" + DefaultServer.getHostAddress("default") + ":" + DefaultServer.getHostPort("default") + "/ws/programmatic")); + WebSocketTestClient client = createTestClient("/ws/programmatic"); client.connect(); client.send(new TextWebSocketFrame(Unpooled.wrappedBuffer(payload)), new FrameChecker(TextWebSocketFrame.class, "oof".getBytes(), latch)); latch.get(); @@ -164,7 +168,7 @@ public void testWebSocketInRootContext() throws Exception { final byte[] payload = "hello".getBytes(); final CompletableFuture latch = new CompletableFuture(); - WebSocketTestClient client = new WebSocketTestClient(new URI("ws://" + DefaultServer.getHostAddress("default") + ":" + DefaultServer.getHostPort("default") + "/ws")); + WebSocketTestClient client = createTestClient("/ws"); client.connect(); client.send(new TextWebSocketFrame(Unpooled.wrappedBuffer(payload)), new FrameChecker(TextWebSocketFrame.class, "hello".getBytes(), latch)); latch.get(); @@ -306,32 +310,43 @@ public void testImplicitIntegerConversion() throws Exception { final byte[] payload = "12".getBytes(); final CompletableFuture latch = new CompletableFuture(); - WebSocketTestClient client = new WebSocketTestClient(new URI("ws://" + DefaultServer.getHostAddress("default") + ":" + DefaultServer.getHostPort("default") + "/ws/increment/2")); + WebSocketTestClient client = createTestClient("/ws/increment/2"); client.connect(); client.send(new TextWebSocketFrame(Unpooled.wrappedBuffer(payload)), new FrameChecker(TextWebSocketFrame.class, "14".getBytes(), latch)); latch.get(); client.destroy(); } - @Test public void testEncodingAndDecodingText() throws Exception { final byte[] payload = "hello".getBytes(); final CompletableFuture latch = new CompletableFuture(); - WebSocketTestClient client = new WebSocketTestClient(new URI("ws://" + DefaultServer.getHostAddress("default") + ":" + DefaultServer.getHostPort("default") + "/ws/encoding/Stuart")); + WebSocketTestClient client = createTestClient("/ws/encoding/Stuart"); client.connect(); client.send(new TextWebSocketFrame(Unpooled.wrappedBuffer(payload)), new FrameChecker(TextWebSocketFrame.class, "hello Stuart".getBytes(), latch)); latch.get(); client.destroy(); } + @Test + public void testPathParamDecoder() throws Exception { + final byte[] payload = "hello".getBytes(); + final CompletableFuture latch = new CompletableFuture(); + + WebSocketTestClient client = createTestClient("/ws/uuid/40164304-B94D-4332-AC31-09D7F9A8B943"); + client.connect(); + client.send(new TextWebSocketFrame(Unpooled.wrappedBuffer(payload)), new FrameChecker(TextWebSocketFrame.class, "hello40164304-b94d-4332-ac31-09d7f9a8b943".getBytes(), latch)); + latch.get(); + client.destroy(); + } + @Test public void testEncodingAndDecodingBinary() throws Exception { final byte[] payload = "hello".getBytes(); final CompletableFuture latch = new CompletableFuture(); - WebSocketTestClient client = new WebSocketTestClient(new URI("ws://" + DefaultServer.getHostAddress("default") + ":" + DefaultServer.getHostPort("default") + "/ws/encoding/Stuart")); + WebSocketTestClient client = createTestClient("/ws/encoding/Stuart"); client.connect(); client.send(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(payload)), new FrameChecker(TextWebSocketFrame.class, "hello Stuart".getBytes(), latch)); latch.get(); @@ -343,7 +358,7 @@ public void testEncodingWithGenericSuperclass() throws Exception { final byte[] payload = "hello".getBytes(); final CompletableFuture latch = new CompletableFuture(); - WebSocketTestClient client = new WebSocketTestClient(new URI("ws://" + DefaultServer.getHostAddress("default") + ":" + DefaultServer.getHostPort("default") + "/ws/encodingGenerics/Stuart")); + WebSocketTestClient client = createTestClient("/ws/encodingGenerics/Stuart"); client.connect(); client.send(new TextWebSocketFrame(Unpooled.wrappedBuffer(payload)), new FrameChecker(TextWebSocketFrame.class, "hello Stuart".getBytes(), latch)); latch.get(); @@ -355,7 +370,7 @@ public void testRequestUri() throws Exception { final byte[] payload = "hello".getBytes(); final CompletableFuture latch = new CompletableFuture(); - WebSocketTestClient client = new WebSocketTestClient(new URI("ws://" + DefaultServer.getHostAddress("default") + ":" + DefaultServer.getHostPort("default") + "/ws/request?a=b")); + WebSocketTestClient client = createTestClient("/ws/request?a=b"); client.connect(); client.send(new TextWebSocketFrame(Unpooled.wrappedBuffer(payload)), new FrameChecker(TextWebSocketFrame.class, "/ws/request?a=b".getBytes(), latch)); latch.get(); diff --git a/websocket/servlet/src/test/java/io/undertow/websockets/jsr/test/annotated/UUIDDecoder.java b/websocket/servlet/src/test/java/io/undertow/websockets/jsr/test/annotated/UUIDDecoder.java new file mode 100644 index 000000000..57ca1dec1 --- /dev/null +++ b/websocket/servlet/src/test/java/io/undertow/websockets/jsr/test/annotated/UUIDDecoder.java @@ -0,0 +1,31 @@ +package io.undertow.websockets.jsr.test.annotated; + +import javax.websocket.Decoder; +import javax.websocket.EndpointConfig; +import java.util.UUID; + +public class UUIDDecoder implements Decoder.Text { + + @Override + public void init(EndpointConfig config) { + } + + @Override + public void destroy() { + } + + @Override + public UUID decode(String s) { + return UUID.fromString(s); + } + + @Override + public boolean willDecode(String s) { + try { + UUID.fromString(s); + return true; + } catch (IllegalArgumentException e) { + return false; + } + } +} diff --git a/websocket/servlet/src/test/java/io/undertow/websockets/jsr/test/annotated/UUIDEndpoint.java b/websocket/servlet/src/test/java/io/undertow/websockets/jsr/test/annotated/UUIDEndpoint.java new file mode 100644 index 000000000..4f712b512 --- /dev/null +++ b/websocket/servlet/src/test/java/io/undertow/websockets/jsr/test/annotated/UUIDEndpoint.java @@ -0,0 +1,44 @@ +/* + * JBoss, Home of Professional Open Source. + * Copyright 2014 Red Hat, Inc., and individual contributors + * as indicated by the @author tags. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.undertow.websockets.jsr.test.annotated; + +import javax.websocket.EndpointConfig; +import javax.websocket.OnMessage; +import javax.websocket.OnOpen; +import javax.websocket.Session; +import javax.websocket.server.PathParam; +import javax.websocket.server.ServerEndpoint; +import java.util.UUID; + +@ServerEndpoint(value = "/uuid/{id}", decoders = UUIDDecoder.class) +public class UUIDEndpoint { + + UUID id; + + @OnOpen + public void open(Session session, EndpointConfig config, @PathParam("id") UUID id) { + this.id = id; + } + + @OnMessage + public String handleMessage(String message) { + return message + id; + } + +} diff --git a/websocket/vertx/src/main/java/io/undertow/websockets/vertx/VertxServerWebSocketContainer.java b/websocket/vertx/src/main/java/io/undertow/websockets/vertx/VertxServerWebSocketContainer.java index 7dc28a4e8..4d0906bb6 100644 --- a/websocket/vertx/src/main/java/io/undertow/websockets/vertx/VertxServerWebSocketContainer.java +++ b/websocket/vertx/src/main/java/io/undertow/websockets/vertx/VertxServerWebSocketContainer.java @@ -21,6 +21,7 @@ import io.vertx.ext.web.RoutingContext; import javax.websocket.Endpoint; +import javax.websocket.EndpointConfig; import javax.websocket.Extension; import javax.websocket.server.ServerEndpointConfig; import java.net.InetSocketAddress; @@ -92,7 +93,7 @@ public ObjectHandle createInstance() { AnnotatedEndpointFactory annotatedEndpointFactory = null; if (!Endpoint.class.isAssignableFrom(sec.getEndpointClass())) { - annotatedEndpointFactory = AnnotatedEndpointFactory.create(sec.getEndpointClass(), encodingFactory, pt.getParameterNames()); + annotatedEndpointFactory = AnnotatedEndpointFactory.create(sec.getEndpointClass(), encodingFactory, pt.getParameterNames(), config); }