Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
Expand Down Expand Up @@ -140,7 +139,8 @@ public class AppSecRequestContext implements DataBundle, Closeable {
private boolean responseBodyPublished;
private boolean respDataPublished;
private boolean pathParamsPublished;
private volatile Map<String, Object> derivatives;
private volatile ConcurrentHashMap<String, Object> derivatives;
private final Object derivativesSwapLock = new Object();

private final AtomicBoolean rateLimited = new AtomicBoolean(false);
private volatile boolean throttled;
Expand Down Expand Up @@ -649,10 +649,7 @@ public void close() {
requestHeaders.clear();
responseHeaders.clear();
persistentData.clear();
if (derivatives != null) {
derivatives.clear();
derivatives = null;
}
derivatives = null;
}
}

Expand Down Expand Up @@ -743,9 +740,16 @@ public void reportDerivatives(Map<String, Object> data) {
log.debug("Reporting derivatives: {}", data);
if (data == null || data.isEmpty()) return;

// Store raw derivatives
if (derivatives == null) {
derivatives = new HashMap<>();
// Ensure derivatives map exists with lock only for initialization check
ConcurrentHashMap<String, Object> map = derivatives;
if (map == null) {
synchronized (derivativesSwapLock) {
map = derivatives;
if (map == null) {
map = new ConcurrentHashMap<>();
derivatives = map;
}
}
}

// Process each attribute according to the specification
Expand All @@ -762,7 +766,7 @@ public void reportDerivatives(Map<String, Object> data) {
Object literalValue = config.get("value");
if (literalValue != null) {
// Preserve the original type - don't convert to string
derivatives.put(attributeKey, literalValue);
map.put(attributeKey, literalValue);
log.debug(
"Added literal attribute: {} = {} (type: {})",
attributeKey,
Expand All @@ -781,13 +785,13 @@ else if (config.containsKey("address")) {
Object extractedValue = extractValueFromRequestData(address, keyPath, transformers);
if (extractedValue != null) {
// For extracted values, convert to string as they come from request data
derivatives.put(attributeKey, extractedValue.toString());
map.put(attributeKey, extractedValue.toString());
log.debug("Added extracted attribute: {} = {}", attributeKey, extractedValue);
}
}
} else {
// Handle plain string/numeric values
derivatives.put(attributeKey, attributeConfig);
map.put(attributeKey, attributeConfig);
log.debug("Added direct attribute: {} = {}", attributeKey, attributeConfig);
}
}
Expand Down Expand Up @@ -938,45 +942,54 @@ private Object applyTransformers(Object value, List<String> transformers) {
}

public boolean commitDerivatives(TraceSegment traceSegment) {
log.debug("Committing derivatives: {} for {}", derivatives, traceSegment);
if (traceSegment == null) {
return false;
}

// Atomically swap out the map to iterate safely
ConcurrentHashMap<String, Object> snapshot;
synchronized (derivativesSwapLock) {
snapshot = derivatives;
derivatives = null;
}

if (snapshot == null || snapshot.isEmpty()) {
return true;
}

log.debug("Committing derivatives: {} for {}", snapshot, traceSegment);

// Process and commit derivatives directly
if (derivatives != null && !derivatives.isEmpty()) {
for (Map.Entry<String, Object> entry : derivatives.entrySet()) {
String key = entry.getKey();
Object value = entry.getValue();

// Handle different value types
if (value instanceof Number) {
traceSegment.setTagTop(key, (Number) value);
} else if (value instanceof String) {
// Try to parse as numeric, otherwise use as string
Number parsedNumber = convertToNumericAttribute((String) value);
if (parsedNumber != null) {
traceSegment.setTagTop(key, parsedNumber);
} else {
traceSegment.setTagTop(key, value);
}
} else if (value instanceof Boolean) {
traceSegment.setTagTop(key, value);
for (Map.Entry<String, Object> entry : snapshot.entrySet()) {
String key = entry.getKey();
Object value = entry.getValue();

// Handle different value types
if (value instanceof Number) {
traceSegment.setTagTop(key, (Number) value);
} else if (value instanceof String) {
// Try to parse as numeric, otherwise use as string
Number parsedNumber = convertToNumericAttribute((String) value);
if (parsedNumber != null) {
traceSegment.setTagTop(key, parsedNumber);
} else {
// Convert other types to string
traceSegment.setTagTop(key, value.toString());
traceSegment.setTagTop(key, value);
}
} else if (value instanceof Boolean) {
traceSegment.setTagTop(key, value);
} else {
// Convert other types to string
traceSegment.setTagTop(key, value.toString());
}
}

// Clear all attribute maps
derivatives = null;
return true;
}

// Mainly used for testing and logging
Set<String> getDerivativeKeys() {
return derivatives == null ? emptySet() : new HashSet<>(derivatives.keySet());
ConcurrentHashMap<String, Object> map = derivatives;
return map == null ? emptySet() : new HashSet<>(map.keySet());
}

public boolean isThrottled(RateLimiter rateLimiter) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -451,4 +451,66 @@ class AppSecRequestContextSpecification extends DDSpecification {
assert context.isHttpClientRequestSampled(requestId) == sampled
}
}

void 'concurrent reportDerivatives and commitDerivatives should not throw ConcurrentModificationException'() {
given:
def context = new AppSecRequestContext()
def startLatch = new java.util.concurrent.CountDownLatch(1)
def doneLatch = new java.util.concurrent.CountDownLatch(8)
def errors = new java.util.concurrent.ConcurrentLinkedQueue()
def traceSegment = Mock(datadog.trace.api.internal.TraceSegment)

def reporters = []
def committers = []

// Create 5 reporter threads
for (int threadNum = 1; threadNum <= 5; threadNum++) {
final int num = threadNum
reporters.add(Thread.start {
try {
startLatch.await()
for (int i = 0; i < 100; i++) {
def attrs = [:]
attrs["thread${num}.iteration${i}".toString()] = "value_${num}_${i}".toString()
attrs["thread${num}.number".toString()] = [value: i]
context.reportDerivatives(attrs)
if (i % 10 == 0) {
Thread.sleep(1)
}
}
} catch (Throwable t) {
errors.add(t)
} finally {
doneLatch.countDown()
}
})
}

// Create 3 committer threads
for (int i = 0; i < 3; i++) {
committers.add(Thread.start {
try {
startLatch.await()
for (int j = 0; j < 50; j++) {
context.commitDerivatives(traceSegment)
Thread.sleep(2)
}
} catch (Throwable t) {
errors.add(t)
} finally {
doneLatch.countDown()
}
})
}

when:
startLatch.countDown()
def completed = doneLatch.await(30, java.util.concurrent.TimeUnit.SECONDS)
(reporters + committers)*.join()

then:
completed
errors.isEmpty()
context.getDerivativeKeys() != null
}
}