diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 79d88f0c..4e0aebbc 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -31,6 +31,7 @@ junit-jupiter-engine = { group = "org.junit.jupiter", name = "junit-jupiter-engi mockito-core = { group = "org.mockito", name = "mockito-core", version = "5.17.0" } protoc = { group = "com.google.protobuf", name = "protoc", version = "3.25.3" } protobuf-java = { group = "com.google.protobuf", name = "protobuf-java", version = "3.25.3" } +protobuf-kotlin = { group = "com.google.protobuf", name = "protobuf-kotlin", version = "3.25.3" } truth = { group = "com.google.truth", name = "truth", version = "1.4.4" } truth-proto-extension = { group = "com.google.truth.extensions", name = "truth-proto-extension", version = "1.4.4" } okhttp = { group = "com.squareup.okhttp", name = "okhttp", version = "2.7.5" } diff --git a/stub/build.gradle.kts b/stub/build.gradle.kts index 6e00b59a..e12b2a12 100644 --- a/stub/build.gradle.kts +++ b/stub/build.gradle.kts @@ -22,12 +22,14 @@ dependencies { api(libs.javax.annotation.api) // Testing + testImplementation(kotlin("test")) testImplementation(libs.junit) testImplementation(libs.junit.jupiter.engine) testImplementation(libs.truth.proto.extension) testImplementation(libs.grpc.protobuf) testImplementation(libs.grpc.testing) testImplementation(libs.grpc.inprocess) + testImplementation(libs.protobuf.kotlin) } java { @@ -53,6 +55,10 @@ protobuf { id("grpc") id("grpckt") } + + it.builtins { + id("kotlin") + } } } } diff --git a/stub/src/main/java/io/grpc/kotlin/Helpers.kt b/stub/src/main/java/io/grpc/kotlin/Helpers.kt index bc0e2728..511ce4de 100644 --- a/stub/src/main/java/io/grpc/kotlin/Helpers.kt +++ b/stub/src/main/java/io/grpc/kotlin/Helpers.kt @@ -16,8 +16,10 @@ package io.grpc.kotlin +import io.grpc.Metadata import io.grpc.Status import io.grpc.StatusException +import kotlin.coroutines.coroutineContext import kotlinx.coroutines.Deferred import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job @@ -75,3 +77,17 @@ internal fun Flow.singleOrStatusFlow(expected: String, descriptor: Any): */ internal suspend fun Flow.singleOrStatus(expected: String, descriptor: Any): T = singleOrStatusFlow(expected, descriptor).single() + +/** + * Returns gRPC Metadata. Throws [StatusException] if [MetadataCoroutineContextInterceptor] is not + * injected to gRPC server. + */ +suspend fun grpcMetadata(): Metadata { + val metadataElement = + coroutineContext[MetadataElement] + ?: throw Status.INTERNAL.withDescription( + "gRPC Metadata not found in coroutineContext. Ensure that MetadataCoroutineContextInterceptor is used in gRPC server." + ) + .asException() + return metadataElement.value +} diff --git a/stub/src/main/java/io/grpc/kotlin/MetadataCoroutineContextInterceptor.kt b/stub/src/main/java/io/grpc/kotlin/MetadataCoroutineContextInterceptor.kt new file mode 100644 index 00000000..24e3e03d --- /dev/null +++ b/stub/src/main/java/io/grpc/kotlin/MetadataCoroutineContextInterceptor.kt @@ -0,0 +1,36 @@ +package io.grpc.kotlin + +import io.grpc.Metadata +import io.grpc.ServerCall +import kotlin.coroutines.CoroutineContext +import kotlin.coroutines.coroutineContext + +/** + * A server interceptor which propagates gRPC [Metadata] (HTTP Headers) to coroutineContext. To use + * it attach the interceptor to gRPC Server and then access the [Metadata] using grpcMetadata() + * function. + * + * Example usage: + * + * ServerBuilder.forPort(8060).addService(GreeterImpl()) + * .intercept(MetadataCoroutineContextInterceptor()) + * + * Then in RPC implementation code call grpcMetadata() + */ +class MetadataCoroutineContextInterceptor : CoroutineContextServerInterceptor() { + final override fun coroutineContext(call: ServerCall<*, *>, headers: Metadata): CoroutineContext = + MetadataElement(value = headers) +} + +/** + * A metadata element for coroutine context. It is used for accessing the gRPC [Metadata] from + * [coroutineContext]. + * + * Example usage: coroutineContext[MetadataElement]?.value + */ +internal data class MetadataElement(val value: Metadata) : CoroutineContext.Element { + companion object Key : CoroutineContext.Key + + override val key: CoroutineContext.Key + get() = Key +} diff --git a/stub/src/test/java/io/grpc/kotlin/MetadataCoroutineContextInterceptorTest.kt b/stub/src/test/java/io/grpc/kotlin/MetadataCoroutineContextInterceptorTest.kt new file mode 100644 index 00000000..deaddb79 --- /dev/null +++ b/stub/src/test/java/io/grpc/kotlin/MetadataCoroutineContextInterceptorTest.kt @@ -0,0 +1,98 @@ +package io.grpc.kotlin + +import com.google.common.truth.Truth.assertThat +import io.grpc.BindableService +import io.grpc.Channel +import io.grpc.Metadata +import io.grpc.Status +import io.grpc.StatusException +import io.grpc.examples.helloworld.GreeterGrpcKt +import io.grpc.examples.helloworld.HelloReply +import io.grpc.examples.helloworld.HelloRequest +import io.grpc.examples.helloworld.helloReply +import io.grpc.examples.helloworld.helloRequest +import io.grpc.inprocess.InProcessChannelBuilder +import io.grpc.inprocess.InProcessServerBuilder +import io.grpc.testing.GrpcCleanupRule +import kotlin.test.assertFailsWith +import kotlinx.coroutines.runBlocking +import org.junit.Rule +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 + +@RunWith(JUnit4::class) +class MetadataCoroutineContextInterceptorTest { + @Rule @JvmField val grpcCleanup = GrpcCleanupRule() + + @Test + fun interceptCall_providesMetadataToCoroutineContext() { + val keyA = Metadata.Key.of("test-header-a", Metadata.ASCII_STRING_MARSHALLER) + val keyB = Metadata.Key.of("test-header-b", Metadata.ASCII_STRING_MARSHALLER) + val clientStub = + GreeterGrpcKt.GreeterCoroutineStub( + testChannel( + object : GreeterGrpcKt.GreeterCoroutineImplBase() { + override suspend fun sayHello(request: HelloRequest): HelloReply { + val metadata = grpcMetadata() + return helloReply { + message = listOf(metadata.get(keyA), metadata.get(keyB)).joinToString() + } + } + } + ) + ) + val metadata = Metadata() + metadata.put(keyA, "Test message A") + metadata.put(keyB, "Test message B") + + val response = runBlocking { clientStub.sayHello(helloRequest {}, metadata) } + + assertThat(response.message).isEqualTo("Test message A, Test message B") + } + + @Test + fun grpcMetadata_interceptorNotInjected_throwsStatusExceptionInternal() { + val keyA = Metadata.Key.of("test-header-a", Metadata.ASCII_STRING_MARSHALLER) + val keyB = Metadata.Key.of("test-header-b", Metadata.ASCII_STRING_MARSHALLER) + val clientStub = + GreeterGrpcKt.GreeterCoroutineStub( + testChannel( + object : GreeterGrpcKt.GreeterCoroutineImplBase() { + override suspend fun sayHello(request: HelloRequest): HelloReply { + val metadata = grpcMetadata() + return helloReply { + message = listOf(metadata.get(keyA), metadata.get(keyB)).joinToString() + } + } + }, + false + ) + ) + val metadata = Metadata() + metadata.put(keyA, "Test message A") + metadata.put(keyB, "Test message B") + + val exception = + assertFailsWith { + runBlocking { clientStub.sayHello(helloRequest {}, metadata) } + } + assertThat(exception.status.code).isEqualTo(Status.INTERNAL.code) + assertThat(exception.status.description) + .isEqualTo( + "gRPC Metadata not found in coroutineContext. Ensure that MetadataCoroutineContextInterceptor is used in gRPC server." + ) + } + + private fun testChannel(service: BindableService, attachInterceptor: Boolean = true): Channel { + val serverName = InProcessServerBuilder.generateName() + var builder = InProcessServerBuilder.forName(serverName).directExecutor() + if (attachInterceptor) { + builder = builder.intercept(MetadataCoroutineContextInterceptor()) + } + grpcCleanup.register(builder.addService(service).build().start()) + return grpcCleanup.register( + InProcessChannelBuilder.forName(serverName).directExecutor().build() + ) + } +}