Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,27 @@
package org.springframework.boot.rsocket.autoconfigure;

import io.rsocket.transport.netty.server.TcpServerTransport;
import org.jspecify.annotations.Nullable;

import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.messaging.handler.MessagingAdviceBean;
import org.springframework.messaging.rsocket.RSocketRequester;
import org.springframework.messaging.rsocket.RSocketStrategies;
import org.springframework.messaging.rsocket.annotation.support.RSocketMessageHandler;
import org.springframework.web.method.ControllerAdviceBean;

/**
* {@link EnableAutoConfiguration Auto-configuration} for Spring RSocket support in Spring
* Messaging.
*
* @author Brian Clozel
* @author Dmitry Sulman
* @since 4.0.0
*/
@AutoConfiguration(after = RSocketStrategiesAutoConfiguration.class)
Expand All @@ -42,11 +47,44 @@ public final class RSocketMessagingAutoConfiguration {
@Bean
@ConditionalOnMissingBean
RSocketMessageHandler messageHandler(RSocketStrategies rSocketStrategies,
ObjectProvider<RSocketMessageHandlerCustomizer> customizers) {
ObjectProvider<RSocketMessageHandlerCustomizer> customizers, ApplicationContext context) {
RSocketMessageHandler messageHandler = new RSocketMessageHandler();
messageHandler.setRSocketStrategies(rSocketStrategies);
customizers.orderedStream().forEach((customizer) -> customizer.customize(messageHandler));
ControllerAdviceBean.findAnnotatedBeans(context)
.forEach((controllerAdviceBean) -> messageHandler
.registerMessagingAdvice(new ControllerAdviceBeanWrapper(controllerAdviceBean)));
return messageHandler;
}

private static final class ControllerAdviceBeanWrapper implements MessagingAdviceBean {

private final ControllerAdviceBean adviceBean;

private ControllerAdviceBeanWrapper(ControllerAdviceBean adviceBean) {
this.adviceBean = adviceBean;
}

@Override
public @Nullable Class<?> getBeanType() {
return this.adviceBean.getBeanType();
}

@Override
public Object resolveBean() {
return this.adviceBean.resolveBean();
}

@Override
public boolean isApplicableToBeanType(Class<?> beanType) {
return this.adviceBean.isApplicableToBeanType(beanType);
}

@Override
public int getOrder() {
return this.adviceBean.getOrder();
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,30 @@

package org.springframework.boot.rsocket.autoconfigure;

import io.rsocket.frame.FrameType;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;

import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.codec.CharSequenceEncoder;
import org.springframework.core.codec.StringDecoder;
import org.springframework.messaging.Message;
import org.springframework.messaging.handler.DestinationPatternsMessageCondition;
import org.springframework.messaging.handler.annotation.MessageExceptionHandler;
import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.rsocket.RSocketStrategies;
import org.springframework.messaging.rsocket.annotation.support.RSocketFrameTypeMessageCondition;
import org.springframework.messaging.rsocket.annotation.support.RSocketMessageHandler;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.stereotype.Controller;
import org.springframework.util.MimeType;
import org.springframework.util.RouteMatcher;
import org.springframework.web.bind.annotation.ControllerAdvice;

import static org.assertj.core.api.Assertions.assertThat;

Expand Down Expand Up @@ -72,6 +85,22 @@ void shouldApplyMessageHandlerCustomizers() {
});
}

@Test
void shouldRegisterControllerAdvice() {
this.contextRunner.withBean(TestControllerAdvice.class).withBean(TestController.class).run((context) -> {
RSocketMessageHandler handler = context.getBean(RSocketMessageHandler.class);
TestControllerAdvice controllerAdvice = context.getBean(TestControllerAdvice.class);

MessageHeaderAccessor headers = new MessageHeaderAccessor();
RouteMatcher.Route route = handler.getRouteMatcher().parseRoute("exception");
headers.setHeader(DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER, route);
headers.setHeader(RSocketFrameTypeMessageCondition.FRAME_TYPE_HEADER, FrameType.REQUEST_FNF);
Message<?> message = MessageBuilder.createMessage(Mono.empty(), headers.getMessageHeaders());
StepVerifier.create(handler.handleMessage(message)).expectComplete().verify();
assertThat(controllerAdvice.isExceptionHandled()).isTrue();
});
}

@Configuration(proxyBeanMethods = false)
static class BaseConfiguration {

Expand Down Expand Up @@ -111,4 +140,30 @@ RSocketMessageHandlerCustomizer customizer() {

}

@Controller
static final class TestController {

@MessageMapping("exception")
void handleWithSimulatedException() {
throw new IllegalStateException("simulated exception");
}

}

@ControllerAdvice
static final class TestControllerAdvice {

boolean exceptionHandled;

boolean isExceptionHandled() {
return this.exceptionHandled;
}

@MessageExceptionHandler
void handleException(IllegalStateException ex) {
this.exceptionHandled = true;
}

}

}