diff --git a/servlet/src/jettyTest/java/io/grpc/servlet/JettyTransportTest.java b/servlet/src/jettyTest/java/io/grpc/servlet/JettyTransportTest.java index e9cb391ea08..c896c7a23ea 100644 --- a/servlet/src/jettyTest/java/io/grpc/servlet/JettyTransportTest.java +++ b/servlet/src/jettyTest/java/io/grpc/servlet/JettyTransportTest.java @@ -69,6 +69,7 @@ public void start(ServerListener listener) throws IOException { listener.transportCreated(new ServletServerBuilder.ServerTransportImpl(scheduler)); ServletAdapter adapter = new ServletAdapter(serverTransportListener, streamTracerFactories, + ServletAdapter.DEFAULT_METHOD_NAME_RESOLVER, Integer.MAX_VALUE); GrpcServlet grpcServlet = new GrpcServlet(adapter); diff --git a/servlet/src/main/java/io/grpc/servlet/ServletAdapter.java b/servlet/src/main/java/io/grpc/servlet/ServletAdapter.java index 4bfe8949776..e84b9341fd9 100644 --- a/servlet/src/main/java/io/grpc/servlet/ServletAdapter.java +++ b/servlet/src/main/java/io/grpc/servlet/ServletAdapter.java @@ -45,6 +45,7 @@ import java.util.Enumeration; import java.util.List; import java.util.concurrent.TimeUnit; +import java.util.function.Function; import java.util.logging.Logger; import javax.servlet.AsyncContext; import javax.servlet.AsyncEvent; @@ -72,18 +73,23 @@ public final class ServletAdapter { static final Logger logger = Logger.getLogger(ServletAdapter.class.getName()); + static final Function DEFAULT_METHOD_NAME_RESOLVER = + req -> req.getRequestURI().substring(1); // remove the leading "/" private final ServerTransportListener transportListener; private final List streamTracerFactories; + private final Function methodNameResolver; private final int maxInboundMessageSize; private final Attributes attributes; ServletAdapter( ServerTransportListener transportListener, List streamTracerFactories, + Function methodNameResolver, int maxInboundMessageSize) { this.transportListener = transportListener; this.streamTracerFactories = streamTracerFactories; + this.methodNameResolver = methodNameResolver; this.maxInboundMessageSize = maxInboundMessageSize; attributes = transportListener.transportReady(Attributes.EMPTY); } @@ -119,7 +125,7 @@ public void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOEx AsyncContext asyncCtx = req.startAsync(req, resp); - String method = req.getRequestURI().substring(1); // remove the leading "/" + String method = methodNameResolver.apply(req); Metadata headers = getHeaders(req); if (logger.isLoggable(FINEST)) { diff --git a/servlet/src/main/java/io/grpc/servlet/ServletServerBuilder.java b/servlet/src/main/java/io/grpc/servlet/ServletServerBuilder.java index 72c4383d273..aee25de01ad 100644 --- a/servlet/src/main/java/io/grpc/servlet/ServletServerBuilder.java +++ b/servlet/src/main/java/io/grpc/servlet/ServletServerBuilder.java @@ -49,8 +49,10 @@ import java.util.Collections; import java.util.List; import java.util.concurrent.ScheduledExecutorService; +import java.util.function.Function; import javax.annotation.Nullable; import javax.annotation.concurrent.NotThreadSafe; +import javax.servlet.http.HttpServletRequest; /** * Builder to build a gRPC server that can run as a servlet. This is for advanced custom settings. @@ -64,6 +66,8 @@ @NotThreadSafe public final class ServletServerBuilder extends ForwardingServerBuilder { List streamTracerFactories; + private Function methodNameResolver = + ServletAdapter.DEFAULT_METHOD_NAME_RESOLVER; int maxInboundMessageSize = DEFAULT_MAX_MESSAGE_SIZE; private final ServerImplBuilder serverImplBuilder; @@ -98,7 +102,8 @@ public Server build() { * Creates a {@link ServletAdapter}. */ public ServletAdapter buildServletAdapter() { - return new ServletAdapter(buildAndStart(), streamTracerFactories, maxInboundMessageSize); + return new ServletAdapter(buildAndStart(), streamTracerFactories, methodNameResolver, + maxInboundMessageSize); } /** @@ -176,6 +181,18 @@ public ServletServerBuilder useTransportSecurity(File certChain, File privateKey throw new UnsupportedOperationException("TLS should be configured by the servlet container"); } + /** + * Specifies how to determine gRPC method name from servlet request. + * + *

The default strategy is using {@link HttpServletRequest#getRequestURI()} without the leading + * slash.

+ */ + public ServletServerBuilder methodNameResolver( + Function methodResolver) { + this.methodNameResolver = checkNotNull(methodResolver); + return this; + } + @Override public ServletServerBuilder maxInboundMessageSize(int bytes) { checkArgument(bytes >= 0, "bytes must be >= 0"); diff --git a/servlet/src/tomcatTest/java/io/grpc/servlet/TomcatTransportTest.java b/servlet/src/tomcatTest/java/io/grpc/servlet/TomcatTransportTest.java index 262036883a9..2171c6eb2df 100644 --- a/servlet/src/tomcatTest/java/io/grpc/servlet/TomcatTransportTest.java +++ b/servlet/src/tomcatTest/java/io/grpc/servlet/TomcatTransportTest.java @@ -81,7 +81,9 @@ public void start(ServerListener listener) throws IOException { ServerTransportListener serverTransportListener = listener.transportCreated(new ServerTransportImpl(scheduler)); ServletAdapter adapter = - new ServletAdapter(serverTransportListener, streamTracerFactories, Integer.MAX_VALUE); + new ServletAdapter(serverTransportListener, streamTracerFactories, + ServletAdapter.DEFAULT_METHOD_NAME_RESOLVER, + Integer.MAX_VALUE); GrpcServlet grpcServlet = new GrpcServlet(adapter); tomcatServer = new Tomcat(); diff --git a/servlet/src/undertowTest/java/io/grpc/servlet/UndertowTransportTest.java b/servlet/src/undertowTest/java/io/grpc/servlet/UndertowTransportTest.java index e14c11985de..ef897c87d70 100644 --- a/servlet/src/undertowTest/java/io/grpc/servlet/UndertowTransportTest.java +++ b/servlet/src/undertowTest/java/io/grpc/servlet/UndertowTransportTest.java @@ -100,7 +100,9 @@ public void start(ServerListener listener) throws IOException { ServerTransportListener serverTransportListener = listener.transportCreated(new ServerTransportImpl(scheduler)); ServletAdapter adapter = - new ServletAdapter(serverTransportListener, streamTracerFactories, Integer.MAX_VALUE); + new ServletAdapter(serverTransportListener, streamTracerFactories, + ServletAdapter.DEFAULT_METHOD_NAME_RESOLVER, + Integer.MAX_VALUE); GrpcServlet grpcServlet = new GrpcServlet(adapter); InstanceFactory instanceFactory = () -> new ImmediateInstanceHandle<>(grpcServlet);