Skip to content
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,4 @@ build/

### Beads ###
.beads/
test-plan-*.md
7 changes: 7 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,13 @@
<plugin>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-maven-plugin</artifactId>
<executions>
<execution>
<goals>
<goal>build-info</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
Expand Down
117 changes: 50 additions & 67 deletions src/main/java/com/contrast/labs/ai/mcp/contrast/AssessService.java
Original file line number Diff line number Diff line change
Expand Up @@ -221,89 +221,72 @@ public List<VulnLight> listVulnsByAppId(



@Tool(name = "list_vulnerabilities_by_application_and_session_metadata", description = "Takes an application name ( app_name ) and session metadata in the form of name / value. and returns a list of vulnerabilities matching that application name and session metadata.")
public List<VulnLight> listVulnsInAppByNameAndSessionMetadata(
@ToolParam(description = "Application name") String app_name,
@Tool(name = "list_vulns_by_app_and_metadata", description = "Takes an application ID (appID) and session metadata in the form of name / value and returns a list of vulnerabilities matching that application ID and session metadata. Use list_applications_with_name first to get the application ID from a name.")
public List<VulnLight> listVulnsByAppIdAndSessionMetadata(
@ToolParam(description = "Application ID") String appID,
@ToolParam(description = "Session metadata field name") String session_Metadata_Name,
@ToolParam(description = "Session metadata field value") String session_Metadata_Value) throws IOException {
logger.info("Listing vulnerabilities for application: {}", app_name);
ContrastSDK contrastSDK = SDKHelper.getSDK(hostName, apiKey, serviceKey, userName,httpProxyHost, httpProxyPort);
logger.info("Listing vulnerabilities for application: {}", appID);

logger.info("metadata : " + session_Metadata_Name+session_Metadata_Value);

logger.debug("Searching for application ID matching name: {}", app_name);

Optional<Application> application = SDKHelper.getApplicationByName(app_name, orgID, contrastSDK);
if(application.isPresent()) {
try {
List<VulnLight> vulns = listVulnsByAppId(application.get().getAppId());
List<VulnLight> returnVulns = new ArrayList<>();
for(VulnLight vuln : vulns) {
if(vuln.sessionMetadata()!=null) {
for(SessionMetadata sm : vuln.sessionMetadata()) {
for(MetadataItem metadataItem : sm.getMetadata()) {
if(metadataItem.getDisplayLabel().equalsIgnoreCase(session_Metadata_Name) &&
metadataItem.getValue().equalsIgnoreCase(session_Metadata_Value)) {
returnVulns.add(vuln);
logger.debug("Found matching vulnerability with ID: {}", vuln.vulnID());
break;
}
}
}
}
try {
List<VulnLight> vulns = listVulnsByAppId(appID);
List<VulnLight> returnVulns = new ArrayList<>();
for(VulnLight vuln : vulns) {
if (vuln.sessionMetadata() == null) {
continue;
}
for (SessionMetadata sm : vuln.sessionMetadata()) {
for (MetadataItem metadataItem : sm.getMetadata()) {
if (metadataItem.getDisplayLabel().equalsIgnoreCase(session_Metadata_Name) &&
metadataItem.getValue().equalsIgnoreCase(session_Metadata_Value)) {
returnVulns.add(vuln);
logger.debug("Found matching vulnerability with ID: {}", vuln.vulnID());
break;
}
}
return returnVulns;
} catch (Exception e) {
logger.error("Error listing vulnerabilities for application: {}", app_name, e);
throw new IOException("Failed to list vulnerabilities: " + e.getMessage(), e);
}
}
} else {
logger.debug("Application with name {} not found, returning empty list", app_name);
return new ArrayList<>();
return returnVulns;
} catch (Exception e) {
logger.error("Error listing vulnerabilities for application: {}", appID, e);
throw new IOException("Failed to list vulnerabilities: " + e.getMessage(), e);
}
}


@Tool(name = "list_vulnerabilities_by_application_and_latest_session", description = "Takes an application name ( app_name ) and returns a list of vulnerabilities for the latest session matching that application name. This is useful for getting the most recent vulnerabilities without needing to specify session metadata.")
public List<VulnLight> listVulnsInAppByNameForLatestSession(
@ToolParam(description = "Application name") String app_name) throws IOException {
logger.info("Listing vulnerabilities for application: {}", app_name);
@Tool(name = "list_vulns_by_app_latest_session", description = "Takes an application ID (appID) and returns a list of vulnerabilities for the latest session matching that application ID. This is useful for getting the most recent vulnerabilities without needing to specify session metadata. Use list_applications_with_name first to get the application ID from a name.")
public List<VulnLight> listVulnsByAppIdForLatestSession(
@ToolParam(description = "Application ID") String appID) throws IOException {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this doesn't need a more complete description?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like it does. I just consolidated methods, so this must be how this already was. All the tool descriptions and tool names are getting a full makeover. I am getting down to the the tools I want to keep before I invest in that. So yes, this will get a full description,.

logger.info("Listing vulnerabilities for application: {}", appID);
ContrastSDK contrastSDK = SDKHelper.getSDK(hostName, apiKey, serviceKey, userName,httpProxyHost, httpProxyPort);

try {
SDKExtension extension = new SDKExtension(contrastSDK);
SessionMetadataResponse latest = extension.getLatestSessionMetadata(orgID, appID);

logger.debug("Searching for application ID matching name: {}", app_name);
Optional<Application> application = SDKHelper.getApplicationByName(app_name, orgID, contrastSDK);

if(application.isPresent()) {
try {
SDKExtension extension = new SDKExtension(contrastSDK);
SessionMetadataResponse latest = extension.getLatestSessionMetadata(orgID,application.get().getAppId());

// Use SDK's native TraceFilterBody with agentSessionId field
var filterBody = new com.contrastsecurity.models.TraceFilterBody();
if (latest != null && latest.getAgentSession() != null && latest.getAgentSession().getAgentSessionId() != null) {
filterBody.setAgentSessionId(latest.getAgentSession().getAgentSessionId());
}
// Use SDK's native TraceFilterBody with agentSessionId field
com.contrastsecurity.models.TraceFilterBody filterBody = new com.contrastsecurity.models.TraceFilterBody();
if(latest!=null&&latest.getAgentSession()!=null&&latest.getAgentSession().getAgentSessionId()!=null) {
filterBody.setAgentSessionId(latest.getAgentSession().getAgentSessionId());
}

// Use SDK's native getTraces() with expand parameter
Traces tracesResponse = contrastSDK.getTraces(
orgID,
application.get().getAppId(),
filterBody,
EnumSet.of(TraceFilterForm.TraceExpandValue.SESSION_METADATA)
);
// Use SDK's native getTraces() with expand parameter
Traces tracesResponse = contrastSDK.getTraces(
orgID,
appID,
filterBody,
EnumSet.of(TraceFilterForm.TraceExpandValue.SESSION_METADATA)
);

List<VulnLight> vulns = tracesResponse.getTraces().stream()
.map(vulnerabilityMapper::toVulnLight)
.collect(Collectors.toList());
return vulns;
} catch (Exception e) {
logger.error("Error listing vulnerabilities for application: {}", app_name, e);
throw new IOException("Failed to list vulnerabilities: " + e.getMessage(), e);
}
} else {
logger.debug("Application with name {} not found, returning empty list", app_name);
return new ArrayList<>();
List<VulnLight> vulns = tracesResponse.getTraces().stream()
.map(vulnerabilityMapper::toVulnLight)
.collect(Collectors.toList());
return vulns;
} catch (Exception e) {
logger.error("Error listing vulnerabilities for application: {}", appID, e);
throw new IOException("Failed to list vulnerabilities: " + e.getMessage(), e);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,15 @@
*/
package com.contrast.labs.ai.mcp.contrast;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.support.ToolCallbacks;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.boot.ApplicationArguments;
import org.springframework.boot.ApplicationRunner;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.info.BuildProperties;
import org.springframework.context.annotation.Bean;

import java.util.List;
Expand All @@ -28,10 +33,26 @@
@SpringBootApplication
public class McpContrastApplication {

private static final Logger logger = LoggerFactory.getLogger(McpContrastApplication.class);

public static void main(String[] args) {
SpringApplication.run(McpContrastApplication.class, args);
}

@Bean
public ApplicationRunner logVersion(org.springframework.beans.factory.ObjectProvider<BuildProperties> buildPropertiesProvider) {
return args -> {
BuildProperties buildProperties = buildPropertiesProvider.getIfAvailable();
logger.info("=".repeat(60));
if (buildProperties != null) {
logger.info("Contrast MCP Server - Version {}", buildProperties.getVersion());
} else {
logger.info("Contrast MCP Server - Version information not available");
}
logger.info("=".repeat(60));
};
}

@Bean
public List<ToolCallback> tools(AssessService assessService, SastService sastService,SCAService scaService,ADRService adrService,RouteCoverageService routeCoverageService) {
return of(ToolCallbacks.from(assessService,sastService,scaService,adrService,routeCoverageService));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public class SCAService {

@Tool(name = "list_application_libraries", description = "Takes an application ID and returns the libraries used in the application. Use list_applications_with_name first to get the application ID from a name. Note: if class usage count is 0 the library is unlikely to be used")
public List<LibraryExtended> getApplicationLibrariesByID(String appID) throws IOException {
if (appID == null || appID.isEmpty()) {
if (appID == null || appID.isBlank()) {
throw new IllegalArgumentException("Application ID cannot be null or empty");
}
logger.info("Retrieving libraries for application id: {}", appID);
Expand Down
61 changes: 27 additions & 34 deletions src/test/java/com/contrast/labs/ai/mcp/contrast/ADRServiceTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import com.contrast.labs.ai.mcp.contrast.data.PaginatedResponse;
import com.contrast.labs.ai.mcp.contrast.sdkexstension.SDKExtension;
import com.contrast.labs.ai.mcp.contrast.sdkexstension.SDKHelper;
import com.contrast.labs.ai.mcp.contrast.sdkexstension.data.ProtectData;
import com.contrast.labs.ai.mcp.contrast.sdkexstension.data.Rule;
import com.contrast.labs.ai.mcp.contrast.sdkexstension.data.adr.Attack;
import com.contrast.labs.ai.mcp.contrast.sdkexstension.data.adr.AttacksFilterBody;
import com.contrast.labs.ai.mcp.contrast.sdkexstension.data.adr.AttacksResponse;
Expand Down Expand Up @@ -545,15 +547,15 @@ void testGetAttacks_MultipleValidationErrors_CombinesErrors() throws Exception {
@Test
void testGetProtectDataByAppID_Success() throws Exception {
// Given
com.contrast.labs.ai.mcp.contrast.sdkexstension.data.ProtectData mockProtectData = createMockProtectData(3);
ProtectData mockProtectData = createMockProtectData(3);

mockedSDKExtension = mockConstruction(SDKExtension.class, (mock, context) -> {
when(mock.getProtectConfig(eq(TEST_ORG_ID), eq(TEST_APP_ID)))
.thenReturn(mockProtectData);
});

// When
com.contrast.labs.ai.mcp.contrast.sdkexstension.data.ProtectData result =
ProtectData result =
adrService.getProtectDataByAppID(TEST_APP_ID);

// Then
Expand All @@ -565,15 +567,15 @@ void testGetProtectDataByAppID_Success() throws Exception {
@Test
void testGetProtectDataByAppID_WithRules() throws Exception {
// Given
com.contrast.labs.ai.mcp.contrast.sdkexstension.data.ProtectData mockProtectData = createMockProtectDataWithRules();
ProtectData mockProtectData = createMockProtectDataWithRules();

mockedSDKExtension = mockConstruction(SDKExtension.class, (mock, context) -> {
when(mock.getProtectConfig(eq(TEST_ORG_ID), eq(TEST_APP_ID)))
.thenReturn(mockProtectData);
});

// When
com.contrast.labs.ai.mcp.contrast.sdkexstension.data.ProtectData result =
ProtectData result =
adrService.getProtectDataByAppID(TEST_APP_ID);

// Then
Expand Down Expand Up @@ -637,7 +639,7 @@ void testGetProtectDataByAppID_NoProtectDataReturned() throws Exception {
});

// When
com.contrast.labs.ai.mcp.contrast.sdkexstension.data.ProtectData result =
ProtectData result =
adrService.getProtectDataByAppID(TEST_APP_ID);

// Then
Expand All @@ -647,8 +649,8 @@ void testGetProtectDataByAppID_NoProtectDataReturned() throws Exception {
@Test
void testGetProtectDataByAppID_EmptyRulesList() throws Exception {
// Given - Protect enabled but no rules configured
com.contrast.labs.ai.mcp.contrast.sdkexstension.data.ProtectData mockProtectData =
new com.contrast.labs.ai.mcp.contrast.sdkexstension.data.ProtectData();
ProtectData mockProtectData =
new ProtectData();
mockProtectData.setRules(new ArrayList<>());

mockedSDKExtension = mockConstruction(SDKExtension.class, (mock, context) -> {
Expand All @@ -657,7 +659,7 @@ void testGetProtectDataByAppID_EmptyRulesList() throws Exception {
});

// When
com.contrast.labs.ai.mcp.contrast.sdkexstension.data.ProtectData result =
ProtectData result =
adrService.getProtectDataByAppID(TEST_APP_ID);

// Then
Expand All @@ -668,19 +670,14 @@ void testGetProtectDataByAppID_EmptyRulesList() throws Exception {

// ========== Helper Methods ==========

/**
* Creates mock AttacksResponse for testing
*/
private AttacksResponse createMockAttacksResponse(int count, Integer totalCount) {
AttacksResponse response = new AttacksResponse();
response.setAttacks(createMockAttacks(count));
response.setCount(totalCount);
return response;
}

/**
* Creates mock Attack objects for testing
*/

private List<Attack> createMockAttacks(int count) {
List<Attack> attacks = new ArrayList<>();
long baseTime = System.currentTimeMillis();
Expand All @@ -703,17 +700,15 @@ private List<Attack> createMockAttacks(int count) {
return attacks;
}

/**
* Creates mock ProtectData for testing
*/
private com.contrast.labs.ai.mcp.contrast.sdkexstension.data.ProtectData createMockProtectData(int ruleCount) {
com.contrast.labs.ai.mcp.contrast.sdkexstension.data.ProtectData protectData =
new com.contrast.labs.ai.mcp.contrast.sdkexstension.data.ProtectData();

List<com.contrast.labs.ai.mcp.contrast.sdkexstension.data.Rule> rules = new ArrayList<>();
private ProtectData createMockProtectData(int ruleCount) {
ProtectData protectData =
new ProtectData();

List<Rule> rules = new ArrayList<>();
for (int i = 0; i < ruleCount; i++) {
com.contrast.labs.ai.mcp.contrast.sdkexstension.data.Rule rule =
new com.contrast.labs.ai.mcp.contrast.sdkexstension.data.Rule();
Rule rule =
new Rule();
rule.setName("protect-rule-" + i);
rule.setProduction(i % 2 == 0 ? "block" : "monitor");
rules.add(rule);
Expand All @@ -723,25 +718,23 @@ private com.contrast.labs.ai.mcp.contrast.sdkexstension.data.ProtectData createM
return protectData;
}

/**
* Creates mock ProtectData with realistic rule configuration
*/
private com.contrast.labs.ai.mcp.contrast.sdkexstension.data.ProtectData createMockProtectDataWithRules() {
com.contrast.labs.ai.mcp.contrast.sdkexstension.data.ProtectData protectData =
new com.contrast.labs.ai.mcp.contrast.sdkexstension.data.ProtectData();

List<com.contrast.labs.ai.mcp.contrast.sdkexstension.data.Rule> rules = new ArrayList<>();
private ProtectData createMockProtectDataWithRules() {
ProtectData protectData =
new ProtectData();

List<Rule> rules = new ArrayList<>();

// SQL Injection rule
com.contrast.labs.ai.mcp.contrast.sdkexstension.data.Rule sqlRule =
new com.contrast.labs.ai.mcp.contrast.sdkexstension.data.Rule();
Rule sqlRule =
new Rule();
sqlRule.setName("sql-injection");
sqlRule.setProduction("block");
rules.add(sqlRule);

// XSS rule
com.contrast.labs.ai.mcp.contrast.sdkexstension.data.Rule xssRule =
new com.contrast.labs.ai.mcp.contrast.sdkexstension.data.Rule();
Rule xssRule =
new Rule();
xssRule.setName("xss-reflected");
xssRule.setProduction("monitor");
rules.add(xssRule);
Expand Down
Loading