diff --git a/build.sbt b/build.sbt index 51bc65470..98928a372 100644 --- a/build.sbt +++ b/build.sbt @@ -91,6 +91,7 @@ lazy val runtime = Project(id = "runtime", base = file("runtime")) .addPekkoModuleDependency("pekko-discovery", "", PekkoCoreDependency.default) .addPekkoModuleDependency("pekko-http-cors", "", PekkoHttpDependency.default) .addPekkoModuleDependency("pekko-testkit", "test", PekkoCoreDependency.default) + .addPekkoModuleDependency("pekko-http-testkit", "test", PekkoHttpDependency.default) .addPekkoModuleDependency("pekko-stream-testkit", "test", PekkoCoreDependency.default) .settings(Dependencies.runtime) .settings(VersionGenerator.settings) diff --git a/runtime/src/main/scala/org/apache/pekko/grpc/javadsl/GrpcDirectives.scala b/runtime/src/main/scala/org/apache/pekko/grpc/javadsl/GrpcDirectives.scala new file mode 100644 index 000000000..ca4993a24 --- /dev/null +++ b/runtime/src/main/scala/org/apache/pekko/grpc/javadsl/GrpcDirectives.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.pekko.grpc.javadsl + +import org.apache.pekko +import pekko.http.javadsl.server.Route + +import java.util.function.Supplier + +/** + * Provides directives to support serving of gRPC services. + */ +object GrpcDirectives { + + import pekko.grpc.scaladsl.{ GrpcDirectives => G } + import pekko.http.javadsl.server.directives.RouteAdapter + + /** + * Wraps the inner route, passing only standard gRPC (i.e. not grpc-web) requests. + * + * @since 2.0.0 + */ + def grpc(inner: Supplier[Route]): Route = + RouteAdapter { + G.grpc { + inner.get() match { + case ra: RouteAdapter => ra.delegate + } + } + } + + /** + * Wraps the inner route, passing only gRPC-Web requests. + * + * @since 2.0.0 + */ + def grpcWeb(inner: Supplier[Route]): Route = + RouteAdapter { + G.grpcWeb { + inner.get() match { + case ra: RouteAdapter => ra.delegate + } + } + } + + /** + * Wraps the inner route, passing requests for all gRPC protocols. + * + * Unlike a combined grpc | grpcWeb directive, this will provide a single rejection specifying all supported protocols. + * + * @since 2.0.0 + */ + def grpcAll(inner: Supplier[Route]): Route = + RouteAdapter { + G.grpcAll { + inner.get() match { + case ra: RouteAdapter => ra.delegate + } + } + } + +} diff --git a/runtime/src/main/scala/org/apache/pekko/grpc/scaladsl/GrpcDirectives.scala b/runtime/src/main/scala/org/apache/pekko/grpc/scaladsl/GrpcDirectives.scala new file mode 100644 index 000000000..b243e62da --- /dev/null +++ b/runtime/src/main/scala/org/apache/pekko/grpc/scaladsl/GrpcDirectives.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.pekko.grpc.scaladsl + +import org.apache.pekko +import pekko.http.scaladsl.server._ + +/** + * Provides directives to support serving of gRPC services. + */ +object GrpcDirectives { + import Directives._ + import pekko.grpc.GrpcProtocol + import pekko.grpc.internal.{ GrpcProtocolNative, GrpcProtocolWeb, GrpcProtocolWebText } + import pekko.http.scaladsl.model.{ ContentTypeRange, MediaType } + + /** + * Wraps the inner route, passing only standard gRPC (i.e. not grpc-web) requests. + * + * @since 2.0.0 + */ + def grpc: Directive0 = grpc(GrpcProtocolNative) + + /** + * Wraps the inner route, passing only gRPC-Web requests. + * + * @since 2.0.0 + */ + def grpcWeb: Directive0 = grpc(GrpcProtocolWeb, GrpcProtocolWebText) + + /** + * Wraps the inner route, passing requests for all gRPC protocols. + * + * Unlike a combined grpc | grpcWeb directive, this will provide a single rejection specifying all supported protocols. + * + * @since 2.0.0 + */ + def grpcAll: Directive0 = grpc(GrpcProtocolNative, GrpcProtocolWeb, GrpcProtocolWebText) + + /** + * Wraps the inner route, passing requests only for a specific set of Grpc protocols. + * @param protocols the protocols to accept and pass to the inner route. + */ + private def grpc(protocols: GrpcProtocol*): Directive0 = { + val acceptedMediaTypes = protocols.flatMap(_.mediaTypes).map(_.asInstanceOf[MediaType]).toSet + extractRequest.flatMap { request => + if (acceptedMediaTypes.contains(request.entity.contentType.mediaType)) + pass + else + reject( + UnsupportedRequestContentTypeRejection( + acceptedMediaTypes.map(mt => ContentTypeRange(mt)), + Some(request.entity.contentType) + ) + ) + } + } + +} diff --git a/runtime/src/test/scala/org/apache/pekko/grpc/scaladsl/GrpcDirectivesSpec.scala b/runtime/src/test/scala/org/apache/pekko/grpc/scaladsl/GrpcDirectivesSpec.scala new file mode 100644 index 000000000..d65df6d51 --- /dev/null +++ b/runtime/src/test/scala/org/apache/pekko/grpc/scaladsl/GrpcDirectivesSpec.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.pekko.grpc.scaladsl + +import org.apache.pekko +import org.scalatest.Inside.inside +import org.scalatest.Inspectors +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec +import pekko.grpc.GrpcProtocol +import pekko.grpc.internal.{ GrpcProtocolNative, GrpcProtocolWeb, GrpcProtocolWebText } +import pekko.http.scaladsl.model._ +import pekko.http.scaladsl.server.{ Directives, Route, UnsupportedRequestContentTypeRejection } +import pekko.http.scaladsl.testkit.ScalatestRouteTest + +class GrpcDirectivesSpec extends AnyWordSpec with Matchers with Inspectors with Directives with ScalatestRouteTest { + import pekko.grpc.scaladsl.GrpcDirectives._ + + private val actual = "actual" + private val exampleStatus = StatusCodes.Created + + private val requestContent = Array[Byte]() + + private def protocolRequests(grpcProtocol: GrpcProtocol*): Seq[HttpRequest] = + grpcProtocol.flatMap(_.mediaTypes).map { mt => + Post("/service", + HttpEntity( + ContentType(mt.asInstanceOf[MediaType.Binary]), + requestContent + )) + } + + private def validRequest(route: Route)(request: HttpRequest): Unit = { + request ~> route ~> check { + responseAs[String] shouldBe actual + response.status shouldBe exampleStatus + } + } + + private def invalidRequest(route: Route, acceptedProtocols: GrpcProtocol*)(request: HttpRequest): Unit = { + val expectedContentTypes = acceptedProtocols.flatMap(_.mediaTypes).map(_.asInstanceOf[MediaType.Binary]).map(mt => + ContentTypeRange(mt)).toSet + request ~> route ~> check { + inside(rejections) { + case UnsupportedRequestContentTypeRejection(contentTypeRanges) +: _ => contentTypeRanges shouldBe + expectedContentTypes + } + } + } + + private val nonGrpcRequests = Seq( + Get("/healthz"), + Post("/service", + HttpEntity( + ContentType(MediaType.applicationBinary("grpc-not", MediaType.Compressible)), + requestContent + )), + Post("/service", + HttpEntity( + ContentType(MediaTypes.`application/json`), + requestContent + )) + ) + + "The grpc directive" should { + val route = grpc { + complete(HttpResponse(exampleStatus, Nil, HttpEntity(actual))) + } + "pass only grpc native protocol" in { + forAll(protocolRequests(GrpcProtocolNative))(validRequest(route)) + } + + "not pass non-grpc native protocols" in { + forAll(nonGrpcRequests ++ protocolRequests(GrpcProtocolWeb, GrpcProtocolWebText))(invalidRequest(route, + GrpcProtocolNative)) + } + } + + "The grpcWeb directive" should { + val route = grpcWeb { + complete(HttpResponse(exampleStatus, Nil, HttpEntity(actual))) + } + "pass all matching grpc-web protocols" in { + forAll(protocolRequests(GrpcProtocolWeb, GrpcProtocolWebText))(validRequest(route)) + } + + "not pass non-grpc and non-matching grpc protocols" in { + forAll(nonGrpcRequests ++ protocolRequests(GrpcProtocolNative))(invalidRequest(route, GrpcProtocolWeb, + GrpcProtocolWebText)) + } + } + + "The grpcAll directive" should { + val route = grpcAll { + complete(HttpResponse(exampleStatus, Nil, HttpEntity(actual))) + } + "pass all grpc protocols" in { + forAll(protocolRequests(GrpcProtocolNative, GrpcProtocolWeb, GrpcProtocolWebText))(validRequest(route)) + } + + "not pass non-grpc and non-matching grpc protocols" in { + forAll(nonGrpcRequests)(invalidRequest(route, GrpcProtocolNative, GrpcProtocolWeb, GrpcProtocolWebText)) + } + } + + "Combined grpc | grpcWeb directive" should { + val route = (grpc | grpcWeb) { + complete(HttpResponse(exampleStatus, Nil, HttpEntity(actual))) + } + "pass all grpc protocols" in { + forAll(protocolRequests(GrpcProtocolNative, GrpcProtocolWeb, GrpcProtocolWebText))(validRequest(route)) + } + + "not pass non-grpc and non-matching grpc protocols" in { + forAll(nonGrpcRequests)(invalidRequest(route, GrpcProtocolNative)) + } + } + +}