Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions core/src/main/java/io/confluent/rest/TenantDosFilter.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@
import static io.confluent.rest.TenantUtils.UNKNOWN_TENANT;

import io.confluent.rest.jetty.DoSFilter;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -36,6 +40,16 @@ public TenantDosFilter() {
super();
}

@Override
protected void doFilter(HttpServletRequest request, HttpServletResponse response,
FilterChain filterChain) throws IOException, ServletException {
if (TenantUtils.isHealthCheckRequest(request)) {
filterChain.doFilter(request, response);
return;
}
super.doFilter(request, response, filterChain);
}

@Override
protected String extractUserId(ServletRequest request) {
// IMPORTANT: If we can't identify the tenant (or get a bad request), return null to skip
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,13 @@ protected void doFilter(
HttpServletRequest request, HttpServletResponse response, FilterChain filterChain
)
throws IOException, ServletException {
if (TenantUtils.isHealthCheckRequest(request)) {
filterChain.doFilter(request, response);
return;
}
// Log tenant classification for all requests (successful and violations)
logTenantClassification(request);

// Let the parent class handle rate tracking and potential violations
super.doFilter(request, response, filterChain);
}
Expand Down
24 changes: 24 additions & 0 deletions core/src/main/java/io/confluent/rest/TenantUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,30 @@ public final class TenantUtils {

private TenantUtils() {}

/**
* Checks if the request is a health check request that should bypass tenant rate limiting.
* Matches the simple health probe endpoint and health check produce requests.
*/
public static boolean isHealthCheckRequest(HttpServletRequest request) {
if (request == null) {
return false;
}
String path = request.getRequestURI();
if (path == null) {
return false;
}
// Simple health probe endpoint
if (path.equals("/kafka/health")) {
return true;
}
// Health check produce to _confluent-healthcheck-rest topics
// Matches paths like /kafka/v3/clusters/lkc-xxx/topics/_confluent-healthcheck-rest_12/records
if (path.contains("/topics/_confluent-healthcheck")) {
return true;
}
return false;
}

/**
* Extracts tenant ID for request
* Attempts hostname extraction first (applies to V4 networking - majority case),
Expand Down
48 changes: 47 additions & 1 deletion core/src/test/java/io/confluent/rest/TenantUtilsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package io.confluent.rest;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -188,7 +190,51 @@ public void testTenantIdExtractionWithFallback() {
"api.confluent.cloud", TenantUtils.UNKNOWN_TENANT);
}

private void assertTenantExtraction(HttpServletRequest request, String requestURI,
@Test
public void testIsHealthCheckRequest_HealthEndpoint() {
HttpServletRequest request = mock(HttpServletRequest.class);
when(request.getRequestURI()).thenReturn("/kafka/health");
assertTrue(TenantUtils.isHealthCheckRequest(request));
}

@Test
public void testIsHealthCheckRequest_HealthCheckRestTopicProduce() {
HttpServletRequest request = mock(HttpServletRequest.class);
when(request.getRequestURI()).thenReturn(
"/kafka/v3/clusters/lkc-3kv9m/topics/_confluent-healthcheck-rest_12/records");
assertTrue(TenantUtils.isHealthCheckRequest(request));
}

@Test
public void testIsHealthCheckRequest_RegularEndpoints() {
HttpServletRequest request = mock(HttpServletRequest.class);

when(request.getRequestURI()).thenReturn(
"/kafka/v3/clusters/lkc-abc123/topics/my-topic/records");
assertFalse(TenantUtils.isHealthCheckRequest(request));

when(request.getRequestURI()).thenReturn("/kafka/v3/clusters/lkc-abc123");
assertFalse(TenantUtils.isHealthCheckRequest(request));

// Topic name containing the healthcheck substring should not match
when(request.getRequestURI()).thenReturn(
"/kafka/v3/clusters/lkc-abc123/topics/malicioususer_confluent-healthcheck/records");
assertFalse(TenantUtils.isHealthCheckRequest(request));
}

@Test
public void testIsHealthCheckRequest_NullUri() {
HttpServletRequest request = mock(HttpServletRequest.class);
when(request.getRequestURI()).thenReturn(null);
assertFalse(TenantUtils.isHealthCheckRequest(request));
}

@Test
public void testIsHealthCheckRequest_NullRequest() {
assertFalse(TenantUtils.isHealthCheckRequest(null));
}

private void assertTenantExtraction(HttpServletRequest request, String requestURI,
String serverName, String expectedTenantId) {
when(request.getRequestURI()).thenReturn(requestURI);
when(request.getServerName()).thenReturn(serverName);
Expand Down