diff --git a/binder/src/test/java/io/grpc/binder/internal/PendingAuthListenerTest.java b/binder/src/test/java/io/grpc/binder/internal/PendingAuthListenerTest.java index 9cdf123033b..868a774689c 100644 --- a/binder/src/test/java/io/grpc/binder/internal/PendingAuthListenerTest.java +++ b/binder/src/test/java/io/grpc/binder/internal/PendingAuthListenerTest.java @@ -1,22 +1,39 @@ package io.grpc.binder.internal; import static com.google.common.truth.Truth.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.verify; +import static org.junit.Assert.assertThrows; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.ForwardingClientCall; +import io.grpc.ForwardingClientCallListener; +import io.grpc.ManagedChannel; import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.Server; import io.grpc.ServerCall; import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.ServerServiceDefinition; import io.grpc.Status; +import io.grpc.StatusRuntimeException; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.stub.ClientCalls; +import io.grpc.stub.ServerCalls; +import io.grpc.testing.GrpcCleanupRule; +import io.grpc.testing.TestMethodDescriptors; +import java.io.IOException; +import java.time.Duration; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -import org.mockito.ArgumentCaptor; -import org.mockito.Captor; import org.mockito.InOrder; import org.mockito.Mock; import org.mockito.Mockito; @@ -26,12 +43,15 @@ @RunWith(JUnit4.class) public final class PendingAuthListenerTest { + private static final MethodDescriptor TEST_METHOD = + TestMethodDescriptors.voidMethod(); + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); + @Rule public final GrpcCleanupRule grpcCleanupRule = new GrpcCleanupRule(); @Mock ServerCallHandler next; @Mock ServerCall call; @Mock ServerCall.Listener delegate; - @Captor ArgumentCaptor statusCaptor; private final Metadata headers = new Metadata(); private final PendingAuthListener listener = new PendingAuthListener<>(); @@ -86,16 +106,80 @@ public void onCallbacks_withCancellation_runsPendingCallbacksAfterStartCall() { } @Test - public void whenStartCallFails_closesTheCallWithInternalStatus() { - IllegalStateException exception = new IllegalStateException("oops"); - when(next.startCall(any(), any())).thenThrow(exception); - - listener.onReady(); - listener.startCall(call, headers, next); + public void whenStartCallFails_closesTheCallWithInternalStatus() throws Exception { + // Arrange + ServerCallHandler callHandler = + ServerCalls.asyncUnaryCall( + (req, respObserver) -> { + throw new IllegalStateException("ooops"); + }); + ManagedChannel channel = startServer(callHandler); + + // Act + StatusRuntimeException ex = + assertThrows( + StatusRuntimeException.class, + () -> + ClientCalls.blockingUnaryCall( + channel, + TEST_METHOD, + CallOptions.DEFAULT.withDeadlineAfter(Duration.ofSeconds(5)), + /* request= */ null)); + + // Assert + assertThat(ex.getStatus().getCode()).isEqualTo(Status.Code.INTERNAL); + } - verify(call).close(statusCaptor.capture(), any()); - Status status = statusCaptor.getValue(); - assertThat(status.getCode()).isEqualTo(Status.Code.INTERNAL); - assertThat(status.getCause()).isSameInstanceAs(exception); + private ManagedChannel startServer(ServerCallHandler callHandler) throws IOException { + String name = TestMethodDescriptors.SERVICE_NAME; + ServerServiceDefinition serviceDef = + ServerServiceDefinition.builder(name).addMethod(TEST_METHOD, callHandler).build(); + Server server = + InProcessServerBuilder.forName(name) + .addService(serviceDef) + .intercept( + new ServerInterceptor() { + @SuppressWarnings("unchecked") + @Override + public ServerCall.Listener interceptCall( + ServerCall call, + Metadata headers, + ServerCallHandler next) { + + listener.startCall( + (ServerCall) call, + headers, + (ServerCallHandler) next); + return (ServerCall.Listener) listener; + } + }) + .build() + .start(); + ManagedChannel channel = + InProcessChannelBuilder.forName(name) + .intercept( + new ClientInterceptor() { + @Override + public ClientCall interceptCall( + MethodDescriptor method, CallOptions callOptions, Channel next) { + ClientCall delegate = next.newCall(method, callOptions); + return new ForwardingClientCall.SimpleForwardingClientCall( + delegate) { + @Override + public void start(Listener responseListener, Metadata headers) { + ClientCall.Listener wrappedListener = + new ForwardingClientCallListener.SimpleForwardingClientCallListener< + RespT>(responseListener) {}; + super.start(wrappedListener, headers); + } + }; + } + }) + .build(); + + grpcCleanupRule.register(server); + grpcCleanupRule.register(channel); + + return channel; } } diff --git a/testing/src/main/java/io/grpc/testing/TestMethodDescriptors.java b/testing/src/main/java/io/grpc/testing/TestMethodDescriptors.java index 5ffc60e369b..c53d872c49f 100644 --- a/testing/src/main/java/io/grpc/testing/TestMethodDescriptors.java +++ b/testing/src/main/java/io/grpc/testing/TestMethodDescriptors.java @@ -30,16 +30,19 @@ public final class TestMethodDescriptors { private TestMethodDescriptors() {} + /** The name of the service that the method returned by {@link #voidMethod()} uses. */ + public static final String SERVICE_NAME = "service_foo"; + /** * Creates a new method descriptor that always creates zero length messages, and always parses to - * null objects. + * null objects. It is part of the service named {@link #SERVICE_NAME}. * * @since 1.1.0 */ public static MethodDescriptor voidMethod() { return MethodDescriptor.newBuilder() .setType(MethodType.UNARY) - .setFullMethodName(MethodDescriptor.generateFullMethodName("service_foo", "method_bar")) + .setFullMethodName(MethodDescriptor.generateFullMethodName(SERVICE_NAME, "method_bar")) .setRequestMarshaller(TestMethodDescriptors.voidMarshaller()) .setResponseMarshaller(TestMethodDescriptors.voidMarshaller()) .build();