diff --git a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/gateway/AppSecRequestContext.java b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/gateway/AppSecRequestContext.java index 1bf6dfab1c3..6e8e7920058 100644 --- a/dd-java-agent/appsec/src/main/java/com/datadog/appsec/gateway/AppSecRequestContext.java +++ b/dd-java-agent/appsec/src/main/java/com/datadog/appsec/gateway/AppSecRequestContext.java @@ -19,7 +19,6 @@ import java.util.Arrays; import java.util.Collection; import java.util.Collections; -import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.LinkedHashMap; @@ -140,7 +139,8 @@ public class AppSecRequestContext implements DataBundle, Closeable { private boolean responseBodyPublished; private boolean respDataPublished; private boolean pathParamsPublished; - private volatile Map derivatives; + private volatile ConcurrentHashMap derivatives; + private final Object derivativesSwapLock = new Object(); private final AtomicBoolean rateLimited = new AtomicBoolean(false); private volatile boolean throttled; @@ -649,10 +649,7 @@ public void close() { requestHeaders.clear(); responseHeaders.clear(); persistentData.clear(); - if (derivatives != null) { - derivatives.clear(); - derivatives = null; - } + derivatives = null; } } @@ -743,9 +740,16 @@ public void reportDerivatives(Map data) { log.debug("Reporting derivatives: {}", data); if (data == null || data.isEmpty()) return; - // Store raw derivatives - if (derivatives == null) { - derivatives = new HashMap<>(); + // Ensure derivatives map exists with lock only for initialization check + ConcurrentHashMap map = derivatives; + if (map == null) { + synchronized (derivativesSwapLock) { + map = derivatives; + if (map == null) { + map = new ConcurrentHashMap<>(); + derivatives = map; + } + } } // Process each attribute according to the specification @@ -762,7 +766,7 @@ public void reportDerivatives(Map data) { Object literalValue = config.get("value"); if (literalValue != null) { // Preserve the original type - don't convert to string - derivatives.put(attributeKey, literalValue); + map.put(attributeKey, literalValue); log.debug( "Added literal attribute: {} = {} (type: {})", attributeKey, @@ -781,13 +785,13 @@ else if (config.containsKey("address")) { Object extractedValue = extractValueFromRequestData(address, keyPath, transformers); if (extractedValue != null) { // For extracted values, convert to string as they come from request data - derivatives.put(attributeKey, extractedValue.toString()); + map.put(attributeKey, extractedValue.toString()); log.debug("Added extracted attribute: {} = {}", attributeKey, extractedValue); } } } else { // Handle plain string/numeric values - derivatives.put(attributeKey, attributeConfig); + map.put(attributeKey, attributeConfig); log.debug("Added direct attribute: {} = {}", attributeKey, attributeConfig); } } @@ -938,45 +942,54 @@ private Object applyTransformers(Object value, List transformers) { } public boolean commitDerivatives(TraceSegment traceSegment) { - log.debug("Committing derivatives: {} for {}", derivatives, traceSegment); if (traceSegment == null) { return false; } + // Atomically swap out the map to iterate safely + ConcurrentHashMap snapshot; + synchronized (derivativesSwapLock) { + snapshot = derivatives; + derivatives = null; + } + + if (snapshot == null || snapshot.isEmpty()) { + return true; + } + + log.debug("Committing derivatives: {} for {}", snapshot, traceSegment); + // Process and commit derivatives directly - if (derivatives != null && !derivatives.isEmpty()) { - for (Map.Entry entry : derivatives.entrySet()) { - String key = entry.getKey(); - Object value = entry.getValue(); - - // Handle different value types - if (value instanceof Number) { - traceSegment.setTagTop(key, (Number) value); - } else if (value instanceof String) { - // Try to parse as numeric, otherwise use as string - Number parsedNumber = convertToNumericAttribute((String) value); - if (parsedNumber != null) { - traceSegment.setTagTop(key, parsedNumber); - } else { - traceSegment.setTagTop(key, value); - } - } else if (value instanceof Boolean) { - traceSegment.setTagTop(key, value); + for (Map.Entry entry : snapshot.entrySet()) { + String key = entry.getKey(); + Object value = entry.getValue(); + + // Handle different value types + if (value instanceof Number) { + traceSegment.setTagTop(key, (Number) value); + } else if (value instanceof String) { + // Try to parse as numeric, otherwise use as string + Number parsedNumber = convertToNumericAttribute((String) value); + if (parsedNumber != null) { + traceSegment.setTagTop(key, parsedNumber); } else { - // Convert other types to string - traceSegment.setTagTop(key, value.toString()); + traceSegment.setTagTop(key, value); } + } else if (value instanceof Boolean) { + traceSegment.setTagTop(key, value); + } else { + // Convert other types to string + traceSegment.setTagTop(key, value.toString()); } } - // Clear all attribute maps - derivatives = null; return true; } // Mainly used for testing and logging Set getDerivativeKeys() { - return derivatives == null ? emptySet() : new HashSet<>(derivatives.keySet()); + ConcurrentHashMap map = derivatives; + return map == null ? emptySet() : new HashSet<>(map.keySet()); } public boolean isThrottled(RateLimiter rateLimiter) { diff --git a/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/gateway/AppSecRequestContextSpecification.groovy b/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/gateway/AppSecRequestContextSpecification.groovy index 0d249961c5b..c7765a9aff0 100644 --- a/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/gateway/AppSecRequestContextSpecification.groovy +++ b/dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/gateway/AppSecRequestContextSpecification.groovy @@ -451,4 +451,66 @@ class AppSecRequestContextSpecification extends DDSpecification { assert context.isHttpClientRequestSampled(requestId) == sampled } } + + void 'concurrent reportDerivatives and commitDerivatives should not throw ConcurrentModificationException'() { + given: + def context = new AppSecRequestContext() + def startLatch = new java.util.concurrent.CountDownLatch(1) + def doneLatch = new java.util.concurrent.CountDownLatch(8) + def errors = new java.util.concurrent.ConcurrentLinkedQueue() + def traceSegment = Mock(datadog.trace.api.internal.TraceSegment) + + def reporters = [] + def committers = [] + + // Create 5 reporter threads + for (int threadNum = 1; threadNum <= 5; threadNum++) { + final int num = threadNum + reporters.add(Thread.start { + try { + startLatch.await() + for (int i = 0; i < 100; i++) { + def attrs = [:] + attrs["thread${num}.iteration${i}".toString()] = "value_${num}_${i}".toString() + attrs["thread${num}.number".toString()] = [value: i] + context.reportDerivatives(attrs) + if (i % 10 == 0) { + Thread.sleep(1) + } + } + } catch (Throwable t) { + errors.add(t) + } finally { + doneLatch.countDown() + } + }) + } + + // Create 3 committer threads + for (int i = 0; i < 3; i++) { + committers.add(Thread.start { + try { + startLatch.await() + for (int j = 0; j < 50; j++) { + context.commitDerivatives(traceSegment) + Thread.sleep(2) + } + } catch (Throwable t) { + errors.add(t) + } finally { + doneLatch.countDown() + } + }) + } + + when: + startLatch.countDown() + def completed = doneLatch.await(30, java.util.concurrent.TimeUnit.SECONDS) + (reporters + committers)*.join() + + then: + completed + errors.isEmpty() + context.getDerivativeKeys() != null + } }