diff --git a/websockets-jsr/src/main/java/io/undertow/websockets/jsr/EncodingFactory.java b/websockets-jsr/src/main/java/io/undertow/websockets/jsr/EncodingFactory.java index d20eb1b66..543876001 100644 --- a/websockets-jsr/src/main/java/io/undertow/websockets/jsr/EncodingFactory.java +++ b/websockets-jsr/src/main/java/io/undertow/websockets/jsr/EncodingFactory.java @@ -56,7 +56,7 @@ public class EncodingFactory { /** * An encoding factory that can deal with primitive types. */ - public static final EncodingFactory DEFAULT = new EncodingFactory(Collections.EMPTY_MAP, Collections.EMPTY_MAP, Collections.EMPTY_MAP, Collections.EMPTY_MAP); + public static final EncodingFactory DEFAULT = new EncodingFactory(Collections.emptyMap(), Collections.emptyMap(), Collections.emptyMap(), Collections.emptyMap()); private final Map, List>> binaryEncoders; private final Map, List>> binaryDecoders; @@ -96,49 +96,46 @@ public boolean canDecodeBinary(final Class type) { } public Encoding createEncoding(final EndpointConfig endpointConfig) { - try { - Map, List>> binaryEncoders = this.binaryEncoders.isEmpty() ? Collections., List>>emptyMap() : new HashMap, List>>(); - Map, List>> binaryDecoders = this.binaryDecoders.isEmpty() ? Collections., List>>emptyMap() : new HashMap, List>>(); - Map, List>> textEncoders = this.textEncoders.isEmpty() ? Collections., List>>emptyMap() : new HashMap, List>>(); - Map, List>> textDecoders = this.textDecoders.isEmpty() ? Collections., List>>emptyMap() : new HashMap, List>>(); - - for (Map.Entry, List>> entry : this.binaryEncoders.entrySet()) { - final List> val = new ArrayList<>(entry.getValue().size()); - binaryEncoders.put(entry.getKey(), val); - for (InstanceFactory factory : entry.getValue()) { - InstanceHandle instance = factory.createInstance(); - instance.getInstance().init(endpointConfig); - val.add(instance); - } - } - for (Map.Entry, List>> entry : this.binaryDecoders.entrySet()) { - final List> val = new ArrayList<>(entry.getValue().size()); - binaryDecoders.put(entry.getKey(), val); - for (InstanceFactory factory : entry.getValue()) { - InstanceHandle instance = factory.createInstance(); - instance.getInstance().init(endpointConfig); - val.add(instance); - } - } - for (Map.Entry, List>> entry : this.textEncoders.entrySet()) { - final List> val = new ArrayList<>(entry.getValue().size()); - textEncoders.put(entry.getKey(), val); - for (InstanceFactory factory : entry.getValue()) { - InstanceHandle instance = factory.createInstance(); - instance.getInstance().init(endpointConfig); - val.add(instance); - } - } - for (Map.Entry, List>> entry : this.textDecoders.entrySet()) { - final List> val = new ArrayList<>(entry.getValue().size()); - textDecoders.put(entry.getKey(), val); - for (InstanceFactory factory : entry.getValue()) { - InstanceHandle instance = factory.createInstance(); - instance.getInstance().init(endpointConfig); - val.add(instance); - } + Map, List>> binaryEncoders = registerCoders(endpointConfig, this.binaryEncoders); + Map, List>> binaryDecoders = registerCoders(endpointConfig, this.binaryDecoders); + Map, List>> textEncoders = registerCoders(endpointConfig, this.textEncoders); + Map, List>> textDecoders = registerCoders(endpointConfig, this.textDecoders); + return new Encoding(binaryEncoders, binaryDecoders, textEncoders, textDecoders); + } + + // Sacrifice some type checks to avoid generics-hell + @SuppressWarnings({"unchecked", "rawtypes"}) + private T registerCoders(EndpointConfig endpointConfig, Map coders) { + if (coders.isEmpty()) { + return (T) Collections.emptyMap(); + } + + Map result = new HashMap(); + for (Map.Entry, List>> entry : ((Map, List>>) coders).entrySet()) { + final List> val = new ArrayList<>(entry.getValue().size()); + result.put(entry.getKey(), val); + for (InstanceFactory factory : entry.getValue()) { + InstanceHandle instance = createInstance(factory); + initializeCoder(endpointConfig, instance); + val.add(instance); } - return new Encoding(binaryEncoders, binaryDecoders, textEncoders, textDecoders); + } + return (T) result; + } + + private void initializeCoder(EndpointConfig endpointConfig, InstanceHandle instance) { + if (instance.getInstance() instanceof Encoder) { + ((Encoder) instance.getInstance()).init(endpointConfig); + } else if (!(instance.getInstance() instanceof Decoder)) { + ((Decoder) instance.getInstance()).init(endpointConfig); + } else { + throw new IllegalStateException("Illegal type: " + ((Decoder) instance.getInstance()).getClass()); + } + } + + private static InstanceHandle createInstance(InstanceFactory factory) { + try { + return factory.createInstance(); } catch (InstantiationException e) { throw new RuntimeException(e); } @@ -155,93 +152,77 @@ public static EncodingFactory createFactory(final ClassIntrospecter classIntrosp final Map, List>> textDecoders = new HashMap<>(); for (Class decoder : decoders) { - if (Decoder.Binary.class.isAssignableFrom(decoder)) { - try { - Method method = decoder.getMethod("decode", ByteBuffer.class); - final Class type = resolveReturnType(method, decoder); - List> list = binaryDecoders.get(type); - if (list == null) { - binaryDecoders.put(type, list = new ArrayList<>()); - } - list.add(classIntrospecter.createInstanceFactory(decoder)); - } catch (NoSuchMethodException e) { - throw JsrWebSocketMessages.MESSAGES.couldNotDetermineTypeOfDecodeMethodForClass(decoder, e); - } - } else if (Decoder.BinaryStream.class.isAssignableFrom(decoder)) { - try { - Method method = decoder.getMethod("decode", InputStream.class); - final Class type = resolveReturnType(method, decoder); - List> list = binaryDecoders.get(type); - if (list == null) { - binaryDecoders.put(type, list = new ArrayList<>()); - } - list.add(classIntrospecter.createInstanceFactory(decoder)); - } catch (NoSuchMethodException e) { - throw JsrWebSocketMessages.MESSAGES.couldNotDetermineTypeOfDecodeMethodForClass(decoder, e); - } - } else if (Decoder.Text.class.isAssignableFrom(decoder)) { - try { - Method method = decoder.getMethod("decode", String.class); - final Class type = resolveReturnType(method, decoder); - List> list = textDecoders.get(type); - if (list == null) { - textDecoders.put(type, list = new ArrayList<>()); - } - list.add(classIntrospecter.createInstanceFactory(decoder)); - } catch (NoSuchMethodException e) { - throw JsrWebSocketMessages.MESSAGES.couldNotDetermineTypeOfDecodeMethodForClass(decoder, e); - } - } else if (Decoder.TextStream.class.isAssignableFrom(decoder)) { - try { - Method method = decoder.getMethod("decode", Reader.class); - final Class type = resolveReturnType(method, decoder); - List> list = textDecoders.get(type); - if (list == null) { - textDecoders.put(type, list = new ArrayList<>()); - } - list.add(createInstanceFactory(classIntrospecter, decoder)); - } catch (NoSuchMethodException e) { - throw JsrWebSocketMessages.MESSAGES.couldNotDetermineTypeOfDecodeMethodForClass(decoder, e); - } - } else { + if (isUnknownDecoderSubclass(decoder)) { throw JsrWebSocketMessages.MESSAGES.didNotImplementKnownDecoderSubclass(decoder); } + + tryRegisterDecoder(classIntrospecter, binaryDecoders, decoder, Decoder.Binary.class, ByteBuffer.class); + tryRegisterDecoder(classIntrospecter, binaryDecoders, decoder, Decoder.BinaryStream.class, InputStream.class); + tryRegisterDecoder(classIntrospecter, textDecoders, decoder, Decoder.Text.class, String.class); + tryRegisterDecoder(classIntrospecter, textDecoders, decoder, Decoder.TextStream.class, Reader.class); } for (Class encoder : encoders) { + if (isUnknownEncoderSubclass(encoder)) { + throw JsrWebSocketMessages.MESSAGES.didNotImplementKnownEncoderSubclass(encoder); + } + if (Encoder.Binary.class.isAssignableFrom(encoder)) { final Class type = findEncodeMethod(encoder, ByteBuffer.class); - List> list = binaryEncoders.get(type); - if (list == null) { - binaryEncoders.put(type, list = new ArrayList<>()); - } + List> list = binaryEncoders.computeIfAbsent(type, k -> new ArrayList<>()); list.add(createInstanceFactory(classIntrospecter, encoder)); - } else if (Encoder.BinaryStream.class.isAssignableFrom(encoder)) { + } + if (Encoder.BinaryStream.class.isAssignableFrom(encoder)) { final Class type = findEncodeMethod(encoder, void.class, OutputStream.class); - List> list = binaryEncoders.get(type); - if (list == null) { - binaryEncoders.put(type, list = new ArrayList<>()); - } + List> list = binaryEncoders.computeIfAbsent(type, k -> new ArrayList<>()); list.add(createInstanceFactory(classIntrospecter, encoder)); - } else if (Encoder.Text.class.isAssignableFrom(encoder)) { + } + if (Encoder.Text.class.isAssignableFrom(encoder)) { final Class type = findEncodeMethod(encoder, String.class); - List> list = textEncoders.get(type); - if (list == null) { - textEncoders.put(type, list = new ArrayList<>()); - } + List> list = textEncoders.computeIfAbsent(type, k -> new ArrayList<>()); list.add(createInstanceFactory(classIntrospecter, encoder)); - } else if (Encoder.TextStream.class.isAssignableFrom(encoder)) { + } + if (Encoder.TextStream.class.isAssignableFrom(encoder)) { final Class type = findEncodeMethod(encoder, void.class, Writer.class); - List> list = textEncoders.get(type); - if (list == null) { - textEncoders.put(type, list = new ArrayList<>()); - } + List> list = textEncoders.computeIfAbsent(type, k -> new ArrayList<>()); list.add(createInstanceFactory(classIntrospecter, encoder)); } } return new EncodingFactory(binaryEncoders, binaryDecoders, textEncoders, textDecoders); } + private static boolean isUnknownEncoderSubclass(Class encoder) { + return !Encoder.Binary.class.isAssignableFrom(encoder) + && !Encoder.BinaryStream.class.isAssignableFrom(encoder) + && !Encoder.Text.class.isAssignableFrom(encoder) + && !Encoder.TextStream.class.isAssignableFrom(encoder); + } + + private static boolean isUnknownDecoderSubclass(Class decoder) { + return !Decoder.Binary.class.isAssignableFrom(decoder) + && !Decoder.BinaryStream.class.isAssignableFrom(decoder) + && !Decoder.Text.class.isAssignableFrom(decoder) + && !Decoder.TextStream.class.isAssignableFrom(decoder); + } + + private static void tryRegisterDecoder(ClassIntrospecter classIntrospecter, Map, List>> binaryDecoders, Class decoder, Class decoderType, Class decodedType) throws DeploymentException { + if (!decoderType.isAssignableFrom(decoder)) { + return; + } + Method method = findDecodeMethod(decoder, decodedType); + final Class type = resolveReturnType(method, decoder); + List> list = binaryDecoders.computeIfAbsent(type, k -> new ArrayList<>()); + list.add(createInstanceFactory(classIntrospecter, decoder)); + } + + private static Method findDecodeMethod(Class decoder, Class type) throws DeploymentException { + try { + return decoder.getMethod("decode", type); + } catch (NoSuchMethodException e) { + throw JsrWebSocketMessages.MESSAGES.couldNotDetermineTypeOfDecodeMethodForClass(decoder, e); + } + } + private static Class resolveReturnType(Method method, Class decoder) { Type genericReturnType = method.getGenericReturnType(); if (genericReturnType instanceof Class) { diff --git a/websockets-jsr/src/main/java/io/undertow/websockets/jsr/JsrWebSocketMessages.java b/websockets-jsr/src/main/java/io/undertow/websockets/jsr/JsrWebSocketMessages.java index 6078b0959..3e71ec4e4 100644 --- a/websockets-jsr/src/main/java/io/undertow/websockets/jsr/JsrWebSocketMessages.java +++ b/websockets-jsr/src/main/java/io/undertow/websockets/jsr/JsrWebSocketMessages.java @@ -161,4 +161,7 @@ public interface JsrWebSocketMessages { @Message(id = 3042, value = "Deployment failed due to invalid programmatically added endpoints") RuntimeException deploymentFailedDueToProgramaticErrors(); + + @Message(id = 3043, value = "%s did not implement known decoder interface") + DeploymentException didNotImplementKnownEncoderSubclass(Class decoder); }