diff --git a/dd-java-agent/agent-aiguard/build.gradle b/dd-java-agent/agent-aiguard/build.gradle index f8dcb4df379..5e4841dbf3c 100644 --- a/dd-java-agent/agent-aiguard/build.gradle +++ b/dd-java-agent/agent-aiguard/build.gradle @@ -18,6 +18,7 @@ dependencies { implementation libs.okhttp api project(':dd-trace-api') + api project(':utils:version-utils') implementation project(':internal-api') implementation project(':communication') diff --git a/dd-java-agent/agent-aiguard/src/main/java/com/datadog/aiguard/AIGuardInternal.java b/dd-java-agent/agent-aiguard/src/main/java/com/datadog/aiguard/AIGuardInternal.java index a7d098a4b91..fc2d2dafb64 100644 --- a/dd-java-agent/agent-aiguard/src/main/java/com/datadog/aiguard/AIGuardInternal.java +++ b/dd-java-agent/agent-aiguard/src/main/java/com/datadog/aiguard/AIGuardInternal.java @@ -1,5 +1,8 @@ package com.datadog.aiguard; +import static datadog.communication.ddagent.TracerVersion.TRACER_VERSION; +import static datadog.trace.api.telemetry.WafMetricCollector.AIGuardTruncationType.CONTENT; +import static datadog.trace.api.telemetry.WafMetricCollector.AIGuardTruncationType.MESSAGES; import static datadog.trace.util.Strings.isBlank; import static java.util.Collections.singletonMap; @@ -21,6 +24,7 @@ import datadog.trace.api.aiguard.AIGuard.ToolCall.Function; import datadog.trace.api.aiguard.Evaluator; import datadog.trace.api.aiguard.noop.NoOpEvaluator; +import datadog.trace.api.telemetry.WafMetricCollector; import datadog.trace.bootstrap.instrumentation.api.AgentScope; import datadog.trace.bootstrap.instrumentation.api.AgentSpan; import datadog.trace.bootstrap.instrumentation.api.AgentTracer; @@ -79,7 +83,18 @@ public static void install() { if (isBlank(endpoint)) { endpoint = String.format("https://app.%s/api/v2/ai-guard", config.getSite()); } - final Map headers = mapOf("DD-API-KEY", apiKey, "DD-APPLICATION-KEY", appKey); + final Map headers = + mapOf( + "DD-API-KEY", + apiKey, + "DD-APPLICATION-KEY", + appKey, + "DD-AI-GUARD-VERSION", + TRACER_VERSION, + "DD-AI-GUARD-SOURCE", + "SDK", + "DD-AI-GUARD-LANGUAGE", + "jvm"); final HttpUrl url = HttpUrl.get(endpoint).newBuilder().addPathSegment("evaluate").build(); final int timeout = config.getAiGuardTimeout(); final OkHttpClient client = buildClient(url, timeout); @@ -113,12 +128,17 @@ static void uninstall() { private static List messagesForMetaStruct(List messages) { final Config config = Config.get(); final int size = Math.min(messages.size(), config.getAiGuardMaxMessagesLength()); + if (size < messages.size()) { + WafMetricCollector.get().aiGuardTruncated(MESSAGES); + } final List result = new ArrayList<>(size); final int maxContent = config.getAiGuardMaxContentSize(); + boolean contentTruncated = false; for (int i = 0; i < size; i++) { Message source = messages.get(i); final String content = source.getContent(); if (content != null && content.length() > maxContent) { + contentTruncated = true; source = new Message( source.getRole(), @@ -128,6 +148,9 @@ private static List messagesForMetaStruct(List messages) { } result.add(source); } + if (contentTruncated) { + WafMetricCollector.get().aiGuardTruncated(CONTENT); + } return result; } @@ -203,20 +226,27 @@ public Evaluation evaluate(final List messages, final Options options) final String reason = (String) result.get("reason"); span.setTag(ACTION_TAG, action); span.setTag(REASON_TAG, reason); - final boolean blockingEnabled = - isBlockingEnabled(options, result.get("is_blocking_enabled")); - if (blockingEnabled && action != Action.ALLOW) { + final boolean shouldBlock = + isBlockingEnabled(options, result.get("is_blocking_enabled")) && action != Action.ALLOW; + WafMetricCollector.get().aiGuardRequest(action, shouldBlock); + if (shouldBlock) { span.setTag(BLOCKED_TAG, true); throw new AIGuardAbortError(action, reason); } return new Evaluation(action, reason); } - } catch (AIGuardAbortError | AIGuardClientError e) { + } catch (AIGuardAbortError e) { + span.addThrowable(e); + throw e; + } catch (AIGuardClientError e) { + WafMetricCollector.get().aiGuardError(); span.addThrowable(e); throw e; } catch (final Exception e) { + WafMetricCollector.get().aiGuardError(); final AIGuardClientError error = - new AIGuardClientError("AI Guard service returned unexpected response", e); + new AIGuardClientError( + "AI Guard service returned unexpected response: " + e.getMessage(), e); span.addThrowable(error); throw error; } finally { @@ -248,11 +278,14 @@ private static OkHttpClient buildClient(final HttpUrl url, final long timeout) { return OkHttpUtils.buildHttpClient(url, timeout).newBuilder().build(); } - private static Map mapOf( - final String key1, final String prop1, final String key2, final String prop2) { - final Map map = new HashMap<>(2); - map.put(key1, prop1); - map.put(key2, prop2); + private static Map mapOf(final String... props) { + if (props.length % 2 != 0) { + throw new IllegalArgumentException("Props must be even"); + } + final Map map = new HashMap<>(props.length << 1); + for (int i = 0; i < props.length; ) { + map.put(props[i++], props[i++]); + } return map; } diff --git a/dd-java-agent/agent-aiguard/src/test/groovy/com/datadog/aiguard/AIGuardInternalTests.groovy b/dd-java-agent/agent-aiguard/src/test/groovy/com/datadog/aiguard/AIGuardInternalTests.groovy index 913224f70ac..366b977a7a3 100644 --- a/dd-java-agent/agent-aiguard/src/test/groovy/com/datadog/aiguard/AIGuardInternalTests.groovy +++ b/dd-java-agent/agent-aiguard/src/test/groovy/com/datadog/aiguard/AIGuardInternalTests.groovy @@ -4,8 +4,10 @@ import com.fasterxml.jackson.annotation.JsonInclude import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.databind.PropertyNamingStrategies import com.squareup.moshi.Moshi +import datadog.common.version.VersionInfo import datadog.trace.api.Config import datadog.trace.api.aiguard.AIGuard +import datadog.trace.api.telemetry.WafMetricCollector import datadog.trace.bootstrap.instrumentation.api.AgentSpan import datadog.trace.bootstrap.instrumentation.api.AgentTracer import datadog.trace.test.util.DDSpecification @@ -35,7 +37,11 @@ class AIGuardInternalTests extends DDSpecification { protected static final URL = HttpUrl.parse('https://app.datadoghq.com/api/v2/ai-guard/evaluate') @Shared - protected static final HEADERS = ['DD-API-KEY': 'api', 'DD-APPLICATION-KEY': 'app'] + protected static final HEADERS = ['DD-API-KEY': 'api', + 'DD-APPLICATION-KEY': 'app', + 'DD-AI-GUARD-VERSION': VersionInfo.VERSION, + 'DD-AI-GUARD-SOURCE': 'SDK', + 'DD-AI-GUARD-LANGUAGE': 'jvm'] @Shared protected static final ORIGINAL_TRACER = AgentTracer.get() @@ -79,6 +85,11 @@ class AIGuardInternalTests extends DDSpecification { buildSpan(_ as String, _ as String) >> builder } AgentTracer.forceRegister(tracer) + + WafMetricCollector.get().tap { + prepareMetrics() + drain() + } } void cleanup() { @@ -193,6 +204,7 @@ class AIGuardInternalTests extends DDSpecification { eval.action == suite.action eval.reason == suite.reason } + assertTelemetry('ai_guard.requests', "action:$suite.action", "block:$throwAbortError", 'error:false') where: suite << TestSuite.build() @@ -222,6 +234,7 @@ class AIGuardInternalTests extends DDSpecification { final exception = thrown(AIGuard.AIGuardClientError) exception.errors == errors 1 * span.addThrowable(_ as AIGuard.AIGuardClientError) + assertTelemetry('ai_guard.requests', 'error:true') } void 'test evaluate with invalid JSON'() { @@ -246,6 +259,7 @@ class AIGuardInternalTests extends DDSpecification { then: thrown(AIGuard.AIGuardClientError) 1 * span.addThrowable(_ as AIGuard.AIGuardClientError) + assertTelemetry('ai_guard.requests', 'error:true') } void 'test evaluate with missing action'() { @@ -270,6 +284,7 @@ class AIGuardInternalTests extends DDSpecification { then: thrown(AIGuard.AIGuardClientError) 1 * span.addThrowable(_ as AIGuard.AIGuardClientError) + assertTelemetry('ai_guard.requests', 'error:true') } void 'test evaluate with non JSON response'() { @@ -294,6 +309,7 @@ class AIGuardInternalTests extends DDSpecification { then: thrown(AIGuard.AIGuardClientError) 1 * span.addThrowable(_ as AIGuard.AIGuardClientError) + assertTelemetry('ai_guard.requests', 'error:true') } void 'test evaluate with empty response'() { @@ -318,6 +334,7 @@ class AIGuardInternalTests extends DDSpecification { then: thrown(AIGuard.AIGuardClientError) 1 * span.addThrowable(_ as AIGuard.AIGuardClientError) + assertTelemetry('ai_guard.requests', 'error:true') } void 'test message length truncation'() { @@ -349,6 +366,7 @@ class AIGuardInternalTests extends DDSpecification { assert received.size() == maxMessages assert received.size() < messages.size() } + assertTelemetry('ai_guard.truncated', 'type:messages') } void 'test message content truncation'() { @@ -380,6 +398,7 @@ class AIGuardInternalTests extends DDSpecification { assert it.content.length() < message.content.length() } } + assertTelemetry('ai_guard.truncated', 'type:content') } void 'test no messages'() { @@ -425,6 +444,21 @@ class AIGuardInternalTests extends DDSpecification { 0 * span.setTag(AIGuardInternal.TOOL_TAG, _) } + private static assertTelemetry(final String metric, final String...tags) { + final metrics = WafMetricCollector.get().with { + prepareMetrics() + drain() + } + final filtered = metrics.findAll { + it.namespace == 'appsec' + && it.metricName == metric + && it.tags == tags.toList() + } + assert filtered.size() == 1 : metrics + assert filtered*.value.sum() == 1 + return true + } + private static assertRequest(final Request request, final List messages) { assert request.url() == URL assert request.method() == 'POST' @@ -452,12 +486,12 @@ class AIGuardInternalTests extends DDSpecification { private static Response mockResponse(final Request request, final int status, final Object body) { return new Response.Builder() - .protocol(Protocol.HTTP_1_1) - .message('ok') - .request(request) - .code(status) - .body(body == null ? null : ResponseBody.create(MediaType.parse('application/json'), MOSHI.adapter(Object).toJson(body))) - .build() + .protocol(Protocol.HTTP_1_1) + .message('ok') + .request(request) + .code(status) + .body(body == null ? null : ResponseBody.create(MediaType.parse('application/json'), MOSHI.adapter(Object).toJson(body))) + .build() } private static class TestSuite { @@ -495,13 +529,13 @@ class AIGuardInternalTests extends DDSpecification { @Override String toString() { return "TestSuite{" + - "description='" + description + '\'' + - ", action=" + action + - ", reason='" + reason + '\'' + - ", blocking=" + blocking + - ", target='" + target + '\'' + - ", messages=" + messages + - '}' + "description='" + description + '\'' + + ", action=" + action + + ", reason='" + reason + '\'' + + ", blocking=" + blocking + + ", target='" + target + '\'' + + ", messages=" + messages + + '}' } } } diff --git a/internal-api/src/main/java/datadog/trace/api/telemetry/WafMetricCollector.java b/internal-api/src/main/java/datadog/trace/api/telemetry/WafMetricCollector.java index 1834ef73da5..032212a7f1a 100644 --- a/internal-api/src/main/java/datadog/trace/api/telemetry/WafMetricCollector.java +++ b/internal-api/src/main/java/datadog/trace/api/telemetry/WafMetricCollector.java @@ -1,5 +1,6 @@ package datadog.trace.api.telemetry; +import datadog.trace.api.aiguard.AIGuard; import java.util.Collection; import java.util.Collections; import java.util.HashMap; @@ -60,6 +61,11 @@ private WafMetricCollector() { private static final AtomicLongArray appSecSdkEventQueue = new AtomicLongArray(LoginEvent.getNumValues() * LoginVersion.getNumValues()); private static final AtomicInteger wafConfigErrorCounter = new AtomicInteger(); + private static final AtomicLongArray aiGuardRequests = + new AtomicLongArray(AIGuard.Action.values().length * 2); // 3 actions * block + private static final AtomicInteger aiGuardErrors = new AtomicInteger(); + private static final AtomicLongArray aiGuardTruncated = + new AtomicLongArray(AIGuardTruncationType.values().length); /** WAF version that will be initialized with wafInit and reused for all metrics. */ private static String wafVersion = ""; @@ -195,6 +201,18 @@ public void appSecSdkEvent(final LoginEvent event, final LoginVersion version) { appSecSdkEventQueue.incrementAndGet(index); } + public void aiGuardRequest(final AIGuard.Action action, final boolean block) { + aiGuardRequests.incrementAndGet(action.ordinal() * 2 + (block ? 1 : 0)); + } + + public void aiGuardError() { + aiGuardErrors.incrementAndGet(); + } + + public void aiGuardTruncated(final AIGuardTruncationType type) { + aiGuardTruncated.incrementAndGet(type.ordinal()); + } + @Override public Collection drain() { if (!rawMetricsQueue.isEmpty()) { @@ -364,6 +382,40 @@ public void prepareMetrics() { return; } } + + // AI Guard successful requests + for (final AIGuard.Action action : AIGuard.Action.values()) { + final long blocked = aiGuardRequests.getAndSet(action.ordinal() * 2 + 1, 0); + if (blocked > 0) { + if (!rawMetricsQueue.offer(AIGuardRequests.success(blocked, action, true))) { + break; + } + } + final long nonBlocked = aiGuardRequests.getAndSet(action.ordinal() * 2, 0); + if (nonBlocked > 0) { + if (!rawMetricsQueue.offer(AIGuardRequests.success(nonBlocked, action, false))) { + break; + } + } + } + + // AI Guard failed requests + final int aiGuardErrorRequests = aiGuardErrors.getAndSet(0); + if (aiGuardErrorRequests > 0) { + if (!rawMetricsQueue.offer(AIGuardRequests.error(aiGuardErrorRequests))) { + return; + } + } + + // AI Guard truncated messages + for (final AIGuardTruncationType type : AIGuardTruncationType.values()) { + final long count = aiGuardTruncated.getAndSet(type.ordinal(), 0); + if (count > 0) { + if (!rawMetricsQueue.offer(new AIGuardTruncated(count, type))) { + return; + } + } + } } public abstract static class WafMetric extends MetricCollector.Metric { @@ -579,6 +631,37 @@ public WafInputTruncated(final long counter, final int bitfield) { } } + public static class AIGuardRequests extends WafMetric { + private AIGuardRequests(final long count, final String... tags) { + super("ai_guard.requests", count, tags); + } + + public static AIGuardRequests success( + final long count, final AIGuard.Action action, final boolean block) { + return new AIGuardRequests(count, "action:" + action, "block:" + block, "error:false"); + } + + public static AIGuardRequests error(final long count) { + return new AIGuardRequests(count, "error:true"); + } + } + + public static class AIGuardTruncated extends WafMetric { + public AIGuardTruncated(final long count, final AIGuardTruncationType type) { + super("ai_guard.truncated", count, "type:" + type.tagValue); + } + } + + public enum AIGuardTruncationType { + MESSAGES("messages"), + CONTENT("content"); + public final String tagValue; + + AIGuardTruncationType(final String tagValue) { + this.tagValue = tagValue; + } + } + /** * Mirror of the {@code WafErrorCode} enum defined in the {@code libddwaf-java} module. * diff --git a/internal-api/src/test/groovy/datadog/trace/api/telemetry/WafMetricCollectorTest.groovy b/internal-api/src/test/groovy/datadog/trace/api/telemetry/WafMetricCollectorTest.groovy index b0166987bf1..251bdb4baef 100644 --- a/internal-api/src/test/groovy/datadog/trace/api/telemetry/WafMetricCollectorTest.groovy +++ b/internal-api/src/test/groovy/datadog/trace/api/telemetry/WafMetricCollectorTest.groovy @@ -1,5 +1,11 @@ package datadog.trace.api.telemetry +import static datadog.trace.api.aiguard.AIGuard.Action.ABORT +import static datadog.trace.api.aiguard.AIGuard.Action.ALLOW +import static datadog.trace.api.aiguard.AIGuard.Action.DENY +import static datadog.trace.api.telemetry.WafMetricCollector.AIGuardTruncationType.CONTENT +import static datadog.trace.api.telemetry.WafMetricCollector.AIGuardTruncationType.MESSAGES + import datadog.trace.test.util.DDSpecification import java.util.concurrent.CountDownLatch @@ -507,6 +513,78 @@ class WafMetricCollectorTest extends DDSpecification { metric.tags.toSet() == ['waf_version:waf_ver1', 'event_rules_version:rules.1'].toSet() } + void 'test ai guard request'() { + given: + final collector = WafMetricCollector.get() + + when: + collector.aiGuardRequest(action, block) + + then: + collector.prepareMetrics() + final metrics = collector.drain() + final configErrorMetrics = metrics.findAll { it.metricName == 'ai_guard.requests' } + + final metric = configErrorMetrics[0] + metric.type == 'count' + metric.metricName == 'ai_guard.requests' + metric.namespace == 'appsec' + metric.value == 1 + metric.tags.toSet() == ["action:${action.name()}", "block:${block}", 'error:false'].toSet() + + where: + action | block + ALLOW | true + ALLOW | false + DENY | true + DENY | false + ABORT | true + ABORT | false + } + + void 'test ai guard error'() { + given: + final collector = WafMetricCollector.get() + + when: + collector.aiGuardError() + + then: + collector.prepareMetrics() + final metrics = collector.drain() + final configErrorMetrics = metrics.findAll { it.metricName == 'ai_guard.requests' } + + final metric = configErrorMetrics[0] + metric.type == 'count' + metric.metricName == 'ai_guard.requests' + metric.namespace == 'appsec' + metric.value == 1 + metric.tags.toSet() == ['error:true'].toSet() + } + + void 'test ai guard truncated'() { + given: + final collector = WafMetricCollector.get() + + when: + collector.aiGuardTruncated(type) + + then: + collector.prepareMetrics() + final metrics = collector.drain() + final configErrorMetrics = metrics.findAll { it.metricName == 'ai_guard.truncated' } + + final metric = configErrorMetrics[0] + metric.type == 'count' + metric.metricName == 'ai_guard.truncated' + metric.namespace == 'appsec' + metric.value == 1 + metric.tags.toSet() == ["type:${type.tagValue}"].toSet() + + where: + type << [MESSAGES, CONTENT] + } + /** * Helper method to generate all combinations of n boolean values. */