diff --git a/src/main/java/org/opensearch/security/transport/SecurityInterceptor.java b/src/main/java/org/opensearch/security/transport/SecurityInterceptor.java index 7bccd487bf..2310ae2705 100644 --- a/src/main/java/org/opensearch/security/transport/SecurityInterceptor.java +++ b/src/main/java/org/opensearch/security/transport/SecurityInterceptor.java @@ -178,8 +178,12 @@ public void sendRequestDecorate( requestHeadersToCopy.removeAll(Task.REQUEST_HEADERS); // Special case where this header is preserved during stashContext. } + final Supplier restorableContextSupplier = getThreadContext().newRestorableContext(true); try (ThreadContext.StoredContext stashedContext = getThreadContext().stashContext()) { - final TransportResponseHandler restoringHandler = new RestoringTransportResponseHandler(handler, stashedContext); + final TransportResponseHandler restoringHandler = new RestoringTransportResponseHandler( + handler, + restorableContextSupplier + ); getThreadContext().putHeader("_opendistro_security_remotecn", cs.getClusterName().value()); final Map headerMap = new HashMap<>( @@ -377,10 +381,13 @@ private ThreadContext getThreadContext() { // which is private scoped private class RestoringTransportResponseHandler implements TransportResponseHandler { - private final ThreadContext.StoredContext contextToRestore; + private final Supplier contextToRestore; private final TransportResponseHandler innerHandler; - private RestoringTransportResponseHandler(TransportResponseHandler innerHandler, ThreadContext.StoredContext contextToRestore) { + private RestoringTransportResponseHandler( + TransportResponseHandler innerHandler, + Supplier contextToRestore + ) { this.contextToRestore = contextToRestore; this.innerHandler = innerHandler; } @@ -407,7 +414,7 @@ public void handleResponse(T response) { List dlsResponseHeader = responseHeaders.get(ConfigConstants.OPENDISTRO_SECURITY_DLS_QUERY_HEADER); List maskedFieldsResponseHeader = responseHeaders.get(ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_HEADER); - contextToRestore.restore(); + contextToRestore.get(); final boolean isDebugEnabled = log.isDebugEnabled(); if (response instanceof ClusterSearchShardsResponse) { @@ -438,7 +445,7 @@ public void handleResponse(T response) { @Override public void handleException(TransportException e) { - contextToRestore.restore(); + contextToRestore.get(); innerHandler.handleException(e); } diff --git a/src/test/java/org/opensearch/security/transport/SecurityInterceptorTests.java b/src/test/java/org/opensearch/security/transport/SecurityInterceptorTests.java index 318f3e8984..0685cf7793 100644 --- a/src/test/java/org/opensearch/security/transport/SecurityInterceptorTests.java +++ b/src/test/java/org/opensearch/security/transport/SecurityInterceptorTests.java @@ -11,6 +11,10 @@ import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.UnknownHostException; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Random; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; @@ -19,6 +23,7 @@ import org.junit.Test; import org.opensearch.Version; +import org.opensearch.action.admin.cluster.shards.ClusterSearchShardsResponse; import org.opensearch.action.search.PitService; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.node.DiscoveryNode; @@ -45,6 +50,7 @@ import org.opensearch.test.transport.MockTransport; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.Transport.Connection; +import org.opensearch.transport.TransportException; import org.opensearch.transport.TransportInterceptor.AsyncSender; import org.opensearch.transport.TransportRequest; import org.opensearch.transport.TransportRequestOptions; @@ -58,8 +64,10 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.is; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -449,4 +457,601 @@ public void testStreamRequestType() { completableRequestDecorate(jdkSerializedSender, connection1, action, request, streamOptions, handler, localNode); } + /** + * Verifies that TASK_RESOURCE_USAGE response header survives context restore + * in RestoringTransportResponseHandler.handleResponse(). + */ + @SuppressWarnings({ "rawtypes", "unchecked" }) + @Test + public void testTaskResourceUsageResponseHeaderSurvivesContextRestore() { + final String TASK_RESOURCE_USAGE = "TASK_RESOURCE_USAGE"; + final String resourceUsageValue = "{\"action\":\"indices:data/read/search[phase/query]\"," + + "\"taskId\":1,\"parentTaskId\":2,\"nodeId\":\"dataNode1\"," + + "\"taskResourceUsage\":{\"cpu_time_in_nanos\":123,\"memory_in_bytes\":456}}"; + + final AtomicReference>> responseHeadersAfterRestore = new AtomicReference<>(); + + AsyncSender resourceUsageSender = new AsyncSender() { + @Override + public void sendRequest( + Connection connection, + String action, + TransportRequest request, + TransportRequestOptions options, + TransportResponseHandler handler + ) { + threadPool.getThreadContext().addResponseHeader(TASK_RESOURCE_USAGE, resourceUsageValue); + + handler.handleResponse((T) new TransportResponse.Empty()); + + responseHeadersAfterRestore.set(threadPool.getThreadContext().getResponseHeaders()); + + senderLatch.get().countDown(); + } + }; + + securityInterceptor.sendRequestDecorate(resourceUsageSender, connection3, action, request, options, handler, localNode); + + try { + senderLatch.get().await(1, TimeUnit.SECONDS); + } catch (final InterruptedException e) { + throw new RuntimeException(e); + } + senderLatch.set(new CountDownLatch(1)); + + Map> headers = responseHeadersAfterRestore.get(); + assertThat( + "TASK_RESOURCE_USAGE response header should be present after handleResponse() context restore", + headers.containsKey(TASK_RESOURCE_USAGE), + is(true) + ); + assertThat( + "TASK_RESOURCE_USAGE response header value should match", + headers.get(TASK_RESOURCE_USAGE).get(0), + is(resourceUsageValue) + ); + } + + /** + * Verifies that ALL response headers (TASK_RESOURCE_USAGE + arbitrary custom headers) + * survive context restore in RestoringTransportResponseHandler.handleResponse(). + */ + @SuppressWarnings({ "rawtypes", "unchecked" }) + @Test + public void testMultipleResponseHeadersSurviveContextRestore() { + final String TASK_RESOURCE_USAGE = "TASK_RESOURCE_USAGE"; + final String resourceUsageValue = "{\"action\":\"indices:data/read/search[phase/query]\"," + + "\"taskId\":3,\"parentTaskId\":4,\"nodeId\":\"dataNode2\"," + + "\"taskResourceUsage\":{\"cpu_time_in_nanos\":789,\"memory_in_bytes\":1024}}"; + final String CUSTOM_HEADER = "X-Custom-Plugin-Header"; + final String customHeaderValue = "custom-plugin-data-value"; + + final AtomicReference>> responseHeadersAfterRestore = new AtomicReference<>(); + + AsyncSender multiHeaderSender = new AsyncSender() { + @Override + public void sendRequest( + Connection connection, + String action, + TransportRequest request, + TransportRequestOptions options, + TransportResponseHandler handler + ) { + threadPool.getThreadContext().addResponseHeader(TASK_RESOURCE_USAGE, resourceUsageValue); + threadPool.getThreadContext().addResponseHeader(CUSTOM_HEADER, customHeaderValue); + + handler.handleResponse((T) new TransportResponse.Empty()); + + responseHeadersAfterRestore.set(threadPool.getThreadContext().getResponseHeaders()); + + senderLatch.get().countDown(); + } + }; + + securityInterceptor.sendRequestDecorate(multiHeaderSender, connection3, action, request, options, handler, localNode); + + try { + senderLatch.get().await(1, TimeUnit.SECONDS); + } catch (final InterruptedException e) { + throw new RuntimeException(e); + } + senderLatch.set(new CountDownLatch(1)); + + Map> headers = responseHeadersAfterRestore.get(); + + assertThat( + "TASK_RESOURCE_USAGE response header should be present after handleResponse() context restore", + headers.containsKey(TASK_RESOURCE_USAGE), + is(true) + ); + assertThat( + "TASK_RESOURCE_USAGE response header value should match", + headers.get(TASK_RESOURCE_USAGE).get(0), + is(resourceUsageValue) + ); + assertThat( + "Custom response header should be present after handleResponse() context restore", + headers.containsKey(CUSTOM_HEADER), + is(true) + ); + assertThat("Custom response header value should match", headers.get(CUSTOM_HEADER).get(0), is(customHeaderValue)); + } + + /** + * Preservation test: ClusterSearchShardsResponse with DLS response header sets + * OPENDISTRO_SECURITY_DLS_QUERY_CCS transient after handleResponse(). + */ + @SuppressWarnings({ "rawtypes", "unchecked" }) + @Test + public void testPreservation_ClusterSearchShardsResponse_DlsTransientSet() { + final String dlsValue = "{\"bool\":{\"must\":[{\"term\":{\"department\":\"HR\"}}]}}"; + final AtomicReference dlsTransientAfterRestore = new AtomicReference<>(); + + AsyncSender dlsSender = new AsyncSender() { + @Override + public void sendRequest( + Connection connection, + String action, + TransportRequest request, + TransportRequestOptions options, + TransportResponseHandler handler + ) { + threadPool.getThreadContext().addResponseHeader(ConfigConstants.OPENDISTRO_SECURITY_DLS_QUERY_HEADER, dlsValue); + + ClusterSearchShardsResponse shardsResponse = new ClusterSearchShardsResponse(null, null, null); + handler.handleResponse((T) shardsResponse); + + dlsTransientAfterRestore.set(threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_DLS_QUERY_CCS)); + + senderLatch.get().countDown(); + } + }; + + securityInterceptor.sendRequestDecorate(dlsSender, connection3, action, request, options, handler, localNode); + + try { + senderLatch.get().await(1, TimeUnit.SECONDS); + } catch (final InterruptedException e) { + throw new RuntimeException(e); + } + senderLatch.set(new CountDownLatch(1)); + + assertNotNull("DLS CCS transient should be set for ClusterSearchShardsResponse", dlsTransientAfterRestore.get()); + assertThat(dlsTransientAfterRestore.get(), is(dlsValue)); + } + + /** + * Preservation test: ClusterSearchShardsResponse with FLS response header sets + * OPENDISTRO_SECURITY_FLS_FIELDS_CCS transient after handleResponse(). + */ + @SuppressWarnings({ "rawtypes", "unchecked" }) + @Test + public void testPreservation_ClusterSearchShardsResponse_FlsTransientSet() { + final String flsValue = "field1,field2,field3"; + final AtomicReference flsTransientAfterRestore = new AtomicReference<>(); + + AsyncSender flsSender = new AsyncSender() { + @Override + public void sendRequest( + Connection connection, + String action, + TransportRequest request, + TransportRequestOptions options, + TransportResponseHandler handler + ) { + threadPool.getThreadContext().addResponseHeader(ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_HEADER, flsValue); + + ClusterSearchShardsResponse shardsResponse = new ClusterSearchShardsResponse(null, null, null); + handler.handleResponse((T) shardsResponse); + + flsTransientAfterRestore.set( + threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_CCS) + ); + + senderLatch.get().countDown(); + } + }; + + securityInterceptor.sendRequestDecorate(flsSender, connection3, action, request, options, handler, localNode); + + try { + senderLatch.get().await(1, TimeUnit.SECONDS); + } catch (final InterruptedException e) { + throw new RuntimeException(e); + } + senderLatch.set(new CountDownLatch(1)); + + assertNotNull("FLS CCS transient should be set for ClusterSearchShardsResponse", flsTransientAfterRestore.get()); + assertThat(flsTransientAfterRestore.get(), is(flsValue)); + } + + /** + * Preservation test: ClusterSearchShardsResponse with masked fields response header sets + * OPENDISTRO_SECURITY_MASKED_FIELD_CCS transient after handleResponse(). + */ + @SuppressWarnings({ "rawtypes", "unchecked" }) + @Test + public void testPreservation_ClusterSearchShardsResponse_MaskedFieldTransientSet() { + final String maskedFieldValue = "ssn,credit_card,phone_number"; + final AtomicReference maskedTransientAfterRestore = new AtomicReference<>(); + + AsyncSender maskedSender = new AsyncSender() { + @Override + public void sendRequest( + Connection connection, + String action, + TransportRequest request, + TransportRequestOptions options, + TransportResponseHandler handler + ) { + threadPool.getThreadContext().addResponseHeader(ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_HEADER, maskedFieldValue); + + ClusterSearchShardsResponse shardsResponse = new ClusterSearchShardsResponse(null, null, null); + handler.handleResponse((T) shardsResponse); + + maskedTransientAfterRestore.set( + threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_CCS) + ); + + senderLatch.get().countDown(); + } + }; + + securityInterceptor.sendRequestDecorate(maskedSender, connection3, action, request, options, handler, localNode); + + try { + senderLatch.get().await(1, TimeUnit.SECONDS); + } catch (final InterruptedException e) { + throw new RuntimeException(e); + } + senderLatch.set(new CountDownLatch(1)); + + assertNotNull("Masked field CCS transient should be set for ClusterSearchShardsResponse", maskedTransientAfterRestore.get()); + assertThat(maskedTransientAfterRestore.get(), is(maskedFieldValue)); + } + + /** + * Preservation test: handleException() restores context and propagates TransportException + * to the inner handler. + */ + @SuppressWarnings({ "rawtypes", "unchecked" }) + @Test + public void testPreservation_HandleExceptionRestoresContextAndPropagates() { + final TransportException testException = new TransportException("test exception for preservation"); + final AtomicReference capturedException = new AtomicReference<>(); + final AtomicReference userAfterRestore = new AtomicReference<>(); + + TransportResponseHandler capturingHandler = new TransportResponseHandler() { + @Override + public TransportResponse read(org.opensearch.core.common.io.stream.StreamInput in) { + return null; + } + + @Override + public void handleResponse(TransportResponse response) {} + + @Override + public void handleException(TransportException exp) { + capturedException.set(exp); + userAfterRestore.set(threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER)); + } + + @Override + public String executor() { + return "same"; + } + }; + + AsyncSender exceptionSender = new AsyncSender() { + @Override + public void sendRequest( + Connection connection, + String action, + TransportRequest request, + TransportRequestOptions options, + TransportResponseHandler handler + ) { + handler.handleException(testException); + senderLatch.get().countDown(); + } + }; + + securityInterceptor.sendRequestDecorate(exceptionSender, connection3, action, request, options, capturingHandler, localNode); + + try { + senderLatch.get().await(1, TimeUnit.SECONDS); + } catch (final InterruptedException e) { + throw new RuntimeException(e); + } + senderLatch.set(new CountDownLatch(1)); + + assertNotNull("Exception should be propagated to inner handler", capturedException.get()); + assertThat(capturedException.get().getMessage(), is("test exception for preservation")); + + assertNotNull("User transient should be restored before inner handler receives exception", userAfterRestore.get()); + assertThat(userAfterRestore.get(), is(user)); + } + + /** + * Preservation test: handleStreamResponse() delegates directly to inner handler + * without header processing. + */ + @SuppressWarnings({ "rawtypes", "unchecked" }) + @Test + public void testPreservation_HandleStreamResponseDelegatesDirectly() { + final AtomicReference streamHandlerCalled = new AtomicReference<>(false); + + TransportResponseHandler streamCapturingHandler = new TransportResponseHandler() { + @Override + public TransportResponse read(org.opensearch.core.common.io.stream.StreamInput in) { + return null; + } + + @Override + public void handleResponse(TransportResponse response) {} + + @Override + public void handleException(TransportException exp) {} + + @Override + public void handleStreamResponse(org.opensearch.transport.stream.StreamTransportResponse response) { + streamHandlerCalled.set(true); + } + + @Override + public String executor() { + return "same"; + } + }; + + AsyncSender streamSender = new AsyncSender() { + @Override + public void sendRequest( + Connection connection, + String action, + TransportRequest request, + TransportRequestOptions options, + TransportResponseHandler handler + ) { + handler.handleStreamResponse(null); + senderLatch.get().countDown(); + } + }; + + securityInterceptor.sendRequestDecorate(streamSender, connection3, action, request, options, streamCapturingHandler, localNode); + + try { + senderLatch.get().await(1, TimeUnit.SECONDS); + } catch (final InterruptedException e) { + throw new RuntimeException(e); + } + senderLatch.set(new CountDownLatch(1)); + + assertTrue("handleStreamResponse should delegate directly to inner handler", streamHandlerCalled.get()); + } + + /** + * Preservation test: Non-ClusterSearchShardsResponse responses do NOT set DLS/FLS/masked-field + * transients even when those response headers are present. + */ + @SuppressWarnings({ "rawtypes", "unchecked" }) + @Test + public void testPreservation_NonClusterSearchShardsResponse_NoTransientsSet() { + final String dlsValue = "{\"bool\":{\"must\":[{\"term\":{\"department\":\"HR\"}}]}}"; + final String flsValue = "field1,field2"; + final String maskedValue = "ssn,credit_card"; + + final AtomicReference dlsTransient = new AtomicReference<>(); + final AtomicReference flsTransient = new AtomicReference<>(); + final AtomicReference maskedTransient = new AtomicReference<>(); + + AsyncSender nonShardsSender = new AsyncSender() { + @Override + public void sendRequest( + Connection connection, + String action, + TransportRequest request, + TransportRequestOptions options, + TransportResponseHandler handler + ) { + threadPool.getThreadContext().addResponseHeader(ConfigConstants.OPENDISTRO_SECURITY_DLS_QUERY_HEADER, dlsValue); + threadPool.getThreadContext().addResponseHeader(ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_HEADER, flsValue); + threadPool.getThreadContext().addResponseHeader(ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_HEADER, maskedValue); + + handler.handleResponse((T) new TransportResponse.Empty()); + + dlsTransient.set(threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_DLS_QUERY_CCS)); + flsTransient.set(threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_CCS)); + maskedTransient.set(threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_CCS)); + + senderLatch.get().countDown(); + } + }; + + securityInterceptor.sendRequestDecorate(nonShardsSender, connection3, action, request, options, handler, localNode); + + try { + senderLatch.get().await(1, TimeUnit.SECONDS); + } catch (final InterruptedException e) { + throw new RuntimeException(e); + } + senderLatch.set(new CountDownLatch(1)); + + assertNull("DLS transient should NOT be set for non-ClusterSearchShardsResponse", dlsTransient.get()); + assertNull("FLS transient should NOT be set for non-ClusterSearchShardsResponse", flsTransient.get()); + assertNull("Masked field transient should NOT be set for non-ClusterSearchShardsResponse", maskedTransient.get()); + } + + /** + * Property-based style test: For random combinations of DLS/FLS/masked-field response headers, + * verify transient propagation for ClusterSearchShardsResponse. + * + * Generates all 7 non-empty subsets of {DLS, FLS, MaskedField} and verifies that + * each present header results in the corresponding transient being set, and each + * absent header results in no transient. + * + */ + @SuppressWarnings({ "rawtypes", "unchecked" }) + @Test + public void testPreservation_RandomDlsFlsMaskedCombinations_ClusterSearchShardsResponse() { + final String dlsValue = "{\"term\":{\"dept\":\"eng\"}}"; + final String flsValue = "name,email,role"; + final String maskedValue = "ssn,phone"; + + // Test all 8 combinations (including empty set) of {DLS, FLS, MaskedField} + for (int combo = 0; combo < 8; combo++) { + final boolean includeDls = (combo & 1) != 0; + final boolean includeFls = (combo & 2) != 0; + final boolean includeMasked = (combo & 4) != 0; + + final AtomicReference dlsTransient = new AtomicReference<>(); + final AtomicReference flsTransient = new AtomicReference<>(); + final AtomicReference maskedTransient = new AtomicReference<>(); + + AsyncSender comboSender = new AsyncSender() { + @Override + public void sendRequest( + Connection connection, + String action, + TransportRequest request, + TransportRequestOptions options, + TransportResponseHandler handler + ) { + if (includeDls) { + threadPool.getThreadContext().addResponseHeader(ConfigConstants.OPENDISTRO_SECURITY_DLS_QUERY_HEADER, dlsValue); + } + if (includeFls) { + threadPool.getThreadContext().addResponseHeader(ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_HEADER, flsValue); + } + if (includeMasked) { + threadPool.getThreadContext() + .addResponseHeader(ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_HEADER, maskedValue); + } + + ClusterSearchShardsResponse shardsResponse = new ClusterSearchShardsResponse(null, null, null); + handler.handleResponse((T) shardsResponse); + + dlsTransient.set(threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_DLS_QUERY_CCS)); + flsTransient.set(threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_CCS)); + maskedTransient.set(threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_CCS)); + + senderLatch.get().countDown(); + } + }; + + securityInterceptor.sendRequestDecorate(comboSender, connection3, action, request, options, handler, localNode); + + try { + senderLatch.get().await(1, TimeUnit.SECONDS); + } catch (final InterruptedException e) { + throw new RuntimeException(e); + } + senderLatch.set(new CountDownLatch(1)); + + String comboDesc = String.format("combo=%d (DLS=%b, FLS=%b, Masked=%b)", combo, includeDls, includeFls, includeMasked); + + if (includeDls) { + assertNotNull("DLS transient should be set for " + comboDesc, dlsTransient.get()); + assertThat("DLS value mismatch for " + comboDesc, dlsTransient.get(), is(dlsValue)); + } else { + assertNull("DLS transient should NOT be set for " + comboDesc, dlsTransient.get()); + } + + if (includeFls) { + assertNotNull("FLS transient should be set for " + comboDesc, flsTransient.get()); + assertThat("FLS value mismatch for " + comboDesc, flsTransient.get(), is(flsValue)); + } else { + assertNull("FLS transient should NOT be set for " + comboDesc, flsTransient.get()); + } + + if (includeMasked) { + assertNotNull("Masked transient should be set for " + comboDesc, maskedTransient.get()); + assertThat("Masked value mismatch for " + comboDesc, maskedTransient.get(), is(maskedValue)); + } else { + assertNull("Masked transient should NOT be set for " + comboDesc, maskedTransient.get()); + } + } + } + + /** + * Property-based style test: Generate random TransportException instances and verify + * handleException() restores context and delegates to inner handler for each. + * + * Tests with various exception messages and causes to ensure robust exception handling. + * + */ + @SuppressWarnings({ "rawtypes", "unchecked" }) + @Test + public void testPreservation_RandomTransportExceptions_HandleExceptionRestoresAndDelegates() { + Random random = new Random(42); // Fixed seed for reproducibility + + List exceptions = Arrays.asList( + new TransportException("simple message"), + new TransportException("message with special chars: <>&\"'"), + new TransportException((String) null), + new TransportException("caused exception", new RuntimeException("root cause")), + new TransportException(new IllegalStateException("state error")), + new TransportException("unicode: \u00e9\u00e8\u00ea\u00eb"), + new TransportException("long message " + "x".repeat(1000)), + new TransportException("empty cause", null) + ); + + for (int i = 0; i < exceptions.size(); i++) { + final TransportException testException = exceptions.get(i); + final int testIndex = i; + final AtomicReference capturedException = new AtomicReference<>(); + final AtomicReference userAfterRestore = new AtomicReference<>(); + + TransportResponseHandler capturingHandler = new TransportResponseHandler() { + @Override + public TransportResponse read(org.opensearch.core.common.io.stream.StreamInput in) { + return null; + } + + @Override + public void handleResponse(TransportResponse response) {} + + @Override + public void handleException(TransportException exp) { + capturedException.set(exp); + userAfterRestore.set(threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER)); + } + + @Override + public String executor() { + return "same"; + } + }; + + AsyncSender exceptionSender = new AsyncSender() { + @Override + public void sendRequest( + Connection connection, + String action, + TransportRequest request, + TransportRequestOptions options, + TransportResponseHandler handler + ) { + handler.handleException(testException); + senderLatch.get().countDown(); + } + }; + + securityInterceptor.sendRequestDecorate(exceptionSender, connection3, action, request, options, capturingHandler, localNode); + + try { + senderLatch.get().await(1, TimeUnit.SECONDS); + } catch (final InterruptedException e) { + throw new RuntimeException(e); + } + senderLatch.set(new CountDownLatch(1)); + + String desc = "exception[" + testIndex + "]"; + assertNotNull("Exception should be propagated for " + desc, capturedException.get()); + assertThat("Same exception instance should be propagated for " + desc, capturedException.get(), is(testException)); + assertNotNull("User transient should be restored for " + desc, userAfterRestore.get()); + assertThat("User should match original for " + desc, userAfterRestore.get(), is(user)); + } + } + }