diff --git a/.gitignore b/.gitignore index ac90423f3..7d4fcb09d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,35 @@ **/__pycache__/** -.venv/* +**/.venv/** **/policies/** +**/memory/** .env .vscode .backup* -example_sbom* \ No newline at end of file +example_sbom* +reports +.adk +.venv +# Logs +logs +*.log +npm-debug.log* +yarn-debug.log* +yarn-error.log* +pnpm-debug.log* +lerna-debug.log* + +node_modules +dist +dist-ssr +*.local + +# Editor directories and files +.vscode/* +!.vscode/extensions.json +.idea +.DS_Store +*.suo +*.ntvs* +*.njsproj +*.sln +*.sw? \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 2f816fca7..3a417ea17 100644 --- a/requirements.txt +++ b/requirements.txt @@ -30,3 +30,7 @@ pytz requests Flask Flask-Cors +google-cloud-storage +networkx +matplotlib +scipy diff --git a/run_agent.py b/run_agent.py index d3e23cabd..9cee7dce7 100644 --- a/run_agent.py +++ b/run_agent.py @@ -11,8 +11,9 @@ async def main(instruction): return # The agent.run method is async - response = await secmind.run(instruction) - print(response) + response = secmind.run_async(instruction) + async for chunk in response: + print(chunk) if __name__ == "__main__": if len(sys.argv) > 1: diff --git a/secmind/agent.py b/secmind/agent.py index e4fa3efb9..b5bea97f0 100644 --- a/secmind/agent.py +++ b/secmind/agent.py @@ -28,7 +28,7 @@ class AgentConfig: """Configuration for the Security Mind agent.""" NAME = "secmind" - MODEL = "gemini-2.5-flash" + MODEL = "gemini-2.5-pro" DESCRIPTION = "Master security agent that delegates tasks." # Task priorities (lower number = higher priority) @@ -59,7 +59,7 @@ class AgentConfig: "code_reviews": "code_review_agent", "cloud_security": "cloud_compliance_agent", "cloud_compliance": "cloud_compliance_agent", - "threat_modelling": "app_sec_agent", + "threat_modelling": "threat_modeling_agent", "application_security": "app_sec_agent", "policy_governance": "policy_agent", "jira_tickets": "jira_agent", @@ -107,7 +107,7 @@ def build_delegation_rules() -> str: - Vulnerabilities and license checks → vuln_triage_agent - Code reviews → code_review_agent - Cloud security posture/compliance → cloud_compliance_agent - - Application security review/Threat Modelling → app_sec_agent + - Application security review/Threat Modelling → threat_modeling_agent - Policy governance questions → policy_agent - Jira tickets → jira_agent @@ -184,7 +184,7 @@ def create_secmind_agent( Create and configure the Security Mind master agent. Args: - model: The AI model to use (default: gemini-2.5-flash) + model: The AI model to use (default: gemini-2.5-pro) sub_agents: List of sub-agents to delegate to (optional) validate: Whether to validate sub-agents before creating agent diff --git a/secmind/memory_manager.py b/secmind/memory_manager.py index 3f777b994..b87479b4a 100644 --- a/secmind/memory_manager.py +++ b/secmind/memory_manager.py @@ -185,6 +185,14 @@ def _setup_sqlite(self): timestamp DATETIME ) """) + # Table for storing public GCS buckets + cursor.execute(""" + CREATE TABLE IF NOT EXISTS public_gcs_buckets ( + project_id TEXT PRIMARY KEY, + buckets_json TEXT, + timestamp DATETIME + ) + """) self.sqlite_conn.commit() def add_triage_result(self, cve_id: str, severity: str, recommendation: str, details: dict): @@ -934,6 +942,48 @@ def get_gce_instance_details(self, project_id: str, instance_name: str, zone: st logger.info(f"GCE instance details cache miss for key: {cache_key}") return None + def add_public_gcs_buckets(self, project_id: str, buckets: dict): + """ + Adds public GCS buckets to the cache. + + Args: + project_id (str): The project ID of the query. + buckets (dict): The buckets to cache. + """ + cursor = self.sqlite_conn.cursor() + timestamp = datetime.now(timezone.utc) + + cursor.execute( + """ + INSERT OR REPLACE INTO public_gcs_buckets (project_id, buckets_json, timestamp) + VALUES (?, ?, ?) + """, + (project_id, json.dumps(buckets), timestamp) + ) + self.sqlite_conn.commit() + logger.info(f"Public GCS buckets cached for project: {project_id}") + + def get_public_gcs_buckets(self, project_id: str) -> dict | None: + """ + Retrieves cached public GCS buckets. + + Args: + project_id (str): The project ID of the query. + + Returns: + The cached buckets, or None if not found. + """ + cursor = self.sqlite_conn.cursor() + cursor.execute("SELECT buckets_json FROM public_gcs_buckets WHERE project_id = ?", (project_id,)) + row = cursor.fetchone() + + if row: + logger.info(f"Public GCS buckets cache hit for project: {project_id}") + return json.loads(row['buckets_json']) + + logger.info(f"Public GCS buckets cache miss for project: {project_id}") + return None + def search_semantic_memory(self, query_text: str, n_results: int = 2) -> list[str]: """ Searches the semantic memory in ChromaDB for contextually relevant information. diff --git a/secmind/sub_agents/cloud_compliance_agent/agent.py b/secmind/sub_agents/cloud_compliance_agent/agent.py index 68beda101..9240c1313 100644 --- a/secmind/sub_agents/cloud_compliance_agent/agent.py +++ b/secmind/sub_agents/cloud_compliance_agent/agent.py @@ -275,6 +275,39 @@ def check_access_keys( return response.to_dict() +def check_public_gcs_buckets(cloud: str, project_id: str) -> dict: + """ + Check for publicly accessible GCS buckets, with caching. + + Args: + cloud: The cloud provider to use (e.g., "gcp") + project_id: GCP project ID (e.g., "my-project") + + Returns: + Dictionary with a list of public buckets and a summary. + + Example: + >>> check_public_gcs_buckets("gcp", "my-project") + """ + logger.info(f"Tool called: check_public_gcs_buckets(cloud={cloud}, project_id={project_id})") + + memory = MemoryManager() + + # Check cache first + cached_buckets = memory.get_public_gcs_buckets(project_id) + if cached_buckets: + return cached_buckets + + client = _get_client(cloud) + response = client.list_public_gcs_buckets(project_id=project_id) + + # Add to cache + if response.status == "success": + memory.add_public_gcs_buckets(project_id, response.to_dict()) + + return response.to_dict() + + def generate_compliance_report(cloud: str, parent: str) -> dict: """ Generates a comprehensive compliance report in HTML format. @@ -309,6 +342,10 @@ def generate_compliance_report(cloud: str, parent: str) -> dict: if keys_result.get("status") == "success": all_data["access_keys"] = keys_result.get("data", {}) + buckets_result = check_public_gcs_buckets(cloud, project_id) + if buckets_result.get("status") == "success": + all_data["public_gcs_buckets"] = buckets_result.get("data", {}) + if org_id: org_policies_result = check_org_policies(cloud, org_id) if org_policies_result.get("status") == "success": @@ -367,6 +404,7 @@ def check_gcp_workload_security(instruction: str) -> dict: check_iam_recommendations, check_org_policies, check_access_keys, + check_public_gcs_buckets, generate_compliance_report, check_gcp_workload_security, ] @@ -374,7 +412,7 @@ def check_gcp_workload_security(instruction: str) -> dict: # Create the agent instance cloud_compliance_agent = Agent( name=build_agent_name(), - model="gemini-2.5-flash", + model="gemini-2.5-pro", description=build_short_description(), instruction=build_agent_instructions(), tools=AGENT_TOOLS, diff --git a/secmind/sub_agents/cloud_compliance_agent/clients/gcp.py b/secmind/sub_agents/cloud_compliance_agent/clients/gcp.py index a91afc4eb..3c7a1fb4c 100644 --- a/secmind/sub_agents/cloud_compliance_agent/clients/gcp.py +++ b/secmind/sub_agents/cloud_compliance_agent/clients/gcp.py @@ -210,6 +210,9 @@ def _get_client(self, client_type: str) -> Any: self._clients[client_type] = orgpolicy_v2.OrgPolicyClient() elif client_type == "iam_admin": self._clients[client_type] = iam_admin_v1.IAMClient() + elif client_type == "storage": + from google.cloud import storage + self._clients[client_type] = storage.Client() else: raise ValueError(f"Unknown client type: {client_type}") @@ -408,6 +411,7 @@ def list_iam_recommendations( "op": op.get("op"), "value": op.get("value"), "originalValue": op.get("originalValue"), + "pathFilters": op.get("pathFilters"), } details["operations"].append(op_details) if not details["operations"]: @@ -535,3 +539,48 @@ def list_service_account_keys( }, message=f"Analyzed {len(keys)} service account keys" ) + + # ======================================================================== + # GCS METHODS + # ======================================================================== + + @handle_gcp_errors + @retry_on_failure() + def list_public_gcs_buckets(self, project_id: str) -> APIResponse: + """ + List publicly accessible GCS buckets. + + Args: + project_id: GCP project ID + + Returns: + APIResponse with list of public buckets and summary + """ + logger.info(f"Listing public GCS buckets for project: {project_id}") + + storage_client = self._get_client("storage") + + public_buckets = [] + + for bucket in storage_client.list_buckets(project=project_id): + policy = bucket.get_iam_policy(requested_policy_version=3) + + for binding in policy.bindings: + if "allUsers" in binding["members"] or "allAuthenticatedUsers" in binding["members"]: + public_buckets.append({ + "name": bucket.name, + "url": f"gs://{bucket.name}", + "roles": binding["role"], + "members": list(binding["members"]), + }) + break # Move to the next bucket once a public binding is found + + logger.info(f"Found {len(public_buckets)} public GCS buckets") + + return APIResponse.success( + data={ + "public_buckets": public_buckets, + "summary": f"Found {len(public_buckets)} publicly accessible buckets." + }, + message=f"Analyzed GCS buckets in project {project_id}" + ) diff --git a/secmind/sub_agents/cloud_compliance_agent/report_generator.py b/secmind/sub_agents/cloud_compliance_agent/report_generator.py index ae8b4543c..c01a7da35 100644 --- a/secmind/sub_agents/cloud_compliance_agent/report_generator.py +++ b/secmind/sub_agents/cloud_compliance_agent/report_generator.py @@ -4,6 +4,16 @@ import datetime from typing import Dict, Any, Optional +import json + +def _format_value(value: Any) -> str: + """Format a value for HTML display.""" + if isinstance(value, dict): + return "
".join([f"  {k}: {_format_value(v)}" for k, v in value.items()]) + elif isinstance(value, list): + return "
".join([f"  - {_format_value(i)}" for i in value]) + else: + return str(value) if value is not None else "N/A" def generate_html_report(data: Dict[str, Any], parent: str, cloud: str) -> str: """ @@ -28,6 +38,8 @@ def generate_html_report(data: Dict[str, Any], parent: str, cloud: str) -> str: org_policies = data.get("org_policies", []) access_keys = data.get("access_keys", {}) non_compliant_keys = access_keys.get("non_compliant", []) + public_buckets_data = data.get("public_gcs_buckets", {}) + public_buckets = public_buckets_data.get("public_buckets", []) html = f""" @@ -135,13 +147,17 @@ def generate_html_report(data: Dict[str, Any], parent: str, cloud: str) -> str:
Non-compliant Keys
{len(non_compliant_keys)}
+
+
Public GCS Buckets
+
{len(public_buckets)}
+

Security Posture Findings

{"" + "".join([f"" for f in findings]) + "
SeverityCategoryDescriptionResource
{f['severity']}{f['category']}{f['description']}{f['resource_name']}
" if findings else "

No security posture findings.

"}

IAM Recommendations

- {"" + "".join([f"" for r in iam_recs]) + "
PriorityDescriptionRecommenderDetails
{r['priority']}{r['description']}{r['recommender_subtype']}{'
'.join([f'Path: {op.get("path", "N/A")}
Op: {op.get("op", "N/A")}
Value: {op.get("value", "N/A")}' for op in r.get('details', {}).get('operations', [])])}
" if iam_recs else "

No IAM recommendations found.

"} + {"" + "".join([f"" for r in iam_recs]) + "
PriorityDescriptionRecommenderDetails
{r['priority']}{r['description']}{r['recommender_subtype']}{'
'.join([f"Resource: {op.get('resource', 'N/A')}
Path: {op.get('path', 'N/A')}
Path Filters: {_format_value(op.get('pathFilters'))}
Value: {_format_value(op.get('value'))}" for op in r.get('details', {}).get('operations', [])])}
" if iam_recs else "

No IAM recommendations found.

"}

Organization Policies

{"" + "".join([f"" for p in org_policies]) + "
ConstraintRules
{p['constraint']}{str(p['rules'])}
" if org_policies else "

No organization policies found.

"} @@ -149,6 +165,9 @@ def generate_html_report(data: Dict[str, Any], parent: str, cloud: str) -> str:

Non-Compliant Access Keys (>{access_keys.get('max_age_days', 90)} days)

{"" + "".join([f"" for k in non_compliant_keys]) + "
Service AccountKey NameAge (days)
{k['service_account']}{k['key_name']}{k['age_days']}
" if non_compliant_keys else "

No non-compliant access keys found.

"} +

Public GCS Buckets

+ {"" + "".join([f"" for b in public_buckets]) + "
Bucket NameURLExposed RolesExposed Members
{b['name']}{b['url']}{b['roles']}{', '.join(b['members'])}
" if public_buckets else "

No publicly accessible GCS buckets found.

"} + diff --git a/secmind/sub_agents/code_review_agent/agent.py b/secmind/sub_agents/code_review_agent/agent.py index bd3dd39d1..8f7746528 100644 --- a/secmind/sub_agents/code_review_agent/agent.py +++ b/secmind/sub_agents/code_review_agent/agent.py @@ -3,7 +3,7 @@ import requests from typing import List from google.adk.agents import Agent -import google.generativeai as genai +import google.genai as genai from pydantic import BaseModel @@ -37,11 +37,9 @@ def review_code(code_snippet: str) -> dict: api_key = os.getenv("GOOGLE_API_KEY") if not api_key: return {"issues": [], "fixes": [], "overall_comments": "Google API key not set."} - - genai.configure(api_key=api_key) - + model = genai.GenerativeModel( - 'gemini-2.5-flash', + 'gemini-2.5-pro', ) # Step 1: Auto-detect the language @@ -140,7 +138,7 @@ def get_github_pr_diff(pr_url: str) -> str: # - Update the Agent configuration to include the new tool: code_review_agent = Agent( name="code_review_agent", - model="gemini-2.5-flash", + model="gemini-2.5-pro", description="Reviews code for security,code smells and best practices.And delegates to jira_agent if issues found.", instruction="""You are a code review agent. Your role is to review the provided code for security vulnerabilities, code smells, readability, and efficiency. The user can provide either a direct code snippet or a GitHub pull request URL. Your goal is to provide a thorough review and suggest necessary fixes. diff --git a/secmind/sub_agents/jira_agent/agent.py b/secmind/sub_agents/jira_agent/agent.py index 017bd6e30..59b9e4ae8 100644 --- a/secmind/sub_agents/jira_agent/agent.py +++ b/secmind/sub_agents/jira_agent/agent.py @@ -27,7 +27,7 @@ def create_jira_issue(project_key: str, summary: str, description: str, issue_ty jira_agent = Agent( name="jira_agent", - model="gemini-2.5-flash", + model="gemini-2.5-pro", description="Creates Jira issues from findings.", instruction=""" Create issues using create_jira_issue with provided context. diff --git a/secmind/sub_agents/policy_agent/agent.py b/secmind/sub_agents/policy_agent/agent.py index a82936ed7..aa1bb9d9a 100644 --- a/secmind/sub_agents/policy_agent/agent.py +++ b/secmind/sub_agents/policy_agent/agent.py @@ -60,7 +60,7 @@ def list_policy_documents() -> dict: policy_agent = Agent( name="policy_agent", - model="gemini-2.5-flash", + model="gemini-2.5-pro", description="Reads policies from local files and answers to user questions from the policy files.", instruction=""" Answer from local policies. diff --git a/secmind/sub_agents/threat_modeling_agent/README.md b/secmind/sub_agents/threat_modeling_agent/README.md index d34a88b0e..1ecece54d 100644 --- a/secmind/sub_agents/threat_modeling_agent/README.md +++ b/secmind/sub_agents/threat_modeling_agent/README.md @@ -69,7 +69,7 @@ app_sec_agent/ ```bash pip install google-adk -pip install google-generativeai +pip install google-genai ``` ## ⚙️ Configuration diff --git a/secmind/sub_agents/threat_modeling_agent/agent.py b/secmind/sub_agents/threat_modeling_agent/agent.py index 71c0138eb..dd5dc8147 100644 --- a/secmind/sub_agents/threat_modeling_agent/agent.py +++ b/secmind/sub_agents/threat_modeling_agent/agent.py @@ -24,7 +24,7 @@ # Create the Threat Modeling Agent threat_modeling_agent = Agent( name="threat_modeling_agent", - model="gemini-2.0-flash-exp", + model="gemini-2.5-pro", description=( "Expert threat modeling agent specializing in threat modeling, " "security architecture review, and vulnerability assessment using STRIDE methodology." diff --git a/secmind/sub_agents/threat_modeling_agent/constants.py b/secmind/sub_agents/threat_modeling_agent/constants.py index f30043ffc..d62a5afa0 100644 --- a/secmind/sub_agents/threat_modeling_agent/constants.py +++ b/secmind/sub_agents/threat_modeling_agent/constants.py @@ -42,7 +42,7 @@ ] # Gemini model configuration -DEFAULT_MODEL = "gemini-1.5-flash" +DEFAULT_MODEL = "gemini-2.5-pro" GENERATION_TEMPERATURE = 0.5 MAX_RETRIES = 3 REQUEST_TIMEOUT = 30 diff --git a/secmind/sub_agents/threat_modeling_agent/dfd_generator.py b/secmind/sub_agents/threat_modeling_agent/dfd_generator.py new file mode 100644 index 000000000..d1aa0fd92 --- /dev/null +++ b/secmind/sub_agents/threat_modeling_agent/dfd_generator.py @@ -0,0 +1,149 @@ +""" +DFD (Data Flow Diagram) generator using NetworkX and Matplotlib. +""" + +import logging +import networkx as nx +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt +import os +import textwrap + +logger = logging.getLogger(__name__) + +class DFDGenerator: + """Generates a DFD from application details.""" + + def __init__(self, app_details): + self.app_details = app_details + self.graph = nx.DiGraph() + self.node_labels = {} + self.edge_labels = {} + + def _wrap_text(self, text, width=20): + """Wraps text to a specified width.""" + return '\n'.join(textwrap.wrap(text, width=width)) + + def generate_dfd(self): + """ + Generates the DFD from the application details and saves it to a file. + + Returns: + A string containing the path to the generated DFD image. + """ + self._add_components() + self._add_data_flows() + self._add_external_services() + + if not self.graph: + logger.warning("DFD graph is empty. Skipping DFD generation.") + return None + + plt.figure(figsize=(25, 20)) + pos = nx.spring_layout(self.graph, k=0.8, iterations=100, seed=42) + + # Group nodes by type + entity_nodes = [node for node, attr in self.graph.nodes(data=True) if attr.get('type') in ['frontend', 'external_service']] + process_nodes = [node for node, attr in self.graph.nodes(data=True) if attr.get('type') == 'service'] + store_nodes = [node for node, attr in self.graph.nodes(data=True) if attr.get('type') == 'database'] + other_nodes = [node for node, attr in self.graph.nodes(data=True) if attr.get('type') not in ['frontend', 'external_service', 'service', 'database']] + + + # Draw nodes + nx.draw_networkx_nodes(self.graph, pos, nodelist=entity_nodes, node_size=5000, node_color='lightblue', node_shape='s') + nx.draw_networkx_nodes(self.graph, pos, nodelist=process_nodes, node_size=5000, node_color='lightgreen', node_shape='o') + nx.draw_networkx_nodes(self.graph, pos, nodelist=store_nodes, node_size=5000, node_color='lightyellow', node_shape='s') + nx.draw_networkx_nodes(self.graph, pos, nodelist=other_nodes, node_size=5000, node_color='lightgray', node_shape='s') + + + # Draw edges + nx.draw_networkx_edges( + self.graph, pos, arrowstyle='->', arrowsize=20, + connectionstyle='arc3,rad=0.1' + ) + + # Draw labels + for node in store_nodes: + self.node_labels[node] = f"<>\n{self.node_labels[node]}" + + entity_and_other_labels = {node: self.node_labels[node] for node in entity_nodes + other_nodes} + process_labels = {node: self.node_labels[node] for node in process_nodes} + store_labels = {node: self.node_labels[node] for node in store_nodes} + + nx.draw_networkx_labels(self.graph, pos, labels=entity_and_other_labels, font_size=10, font_weight='bold', + bbox=dict(facecolor="lightblue", edgecolor='black', boxstyle='round,pad=0.2')) + nx.draw_networkx_labels(self.graph, pos, labels=process_labels, font_size=10, font_weight='bold', + bbox=dict(facecolor="lightgreen", edgecolor='black', boxstyle='round,pad=0.2')) + nx.draw_networkx_labels(self.graph, pos, labels=store_labels, font_size=10, font_weight='bold', + bbox=dict(facecolor="lightyellow", edgecolor='black', boxstyle='round,pad=0.2')) + + nx.draw_networkx_edge_labels( + self.graph, pos, edge_labels=self.edge_labels, font_size=10, + label_pos=0.3, font_color='red' + ) + + # Draw trust boundaries + self._add_trust_boundaries(pos) + + plt.title("Data Flow Diagram", size=15) + plt.axis('off') + + if not os.path.exists('reports'): + os.makedirs('reports') + + dfd_image_path = "reports/dfd.png" + plt.savefig(dfd_image_path, bbox_inches='tight') + plt.close() + + logger.info(f"DFD image saved to {dfd_image_path}") + return dfd_image_path + + def _add_components(self): + """Adds components to the DFD.""" + if 'components' in self.app_details: + for component in self.app_details['components']: + self.graph.add_node(component['id'], type=component.get('type')) + self.node_labels[component['id']] = self._wrap_text(component['name']) + + def _add_data_flows(self): + """Adds data flows to the DFD.""" + if 'data_flows' in self.app_details: + for flow in self.app_details['data_flows']: + source = flow.get('from') or flow.get('source') or flow.get('src') + destination = flow.get('to') or flow.get('destination') or flow.get('dest') + + if not source or not destination: + logger.warning(f"Skipping data flow due to missing source or destination: {flow}") + continue + + self.graph.add_edge(source, destination) + self.edge_labels[(source, destination)] = self._wrap_text(flow.get('label', '')) + + def _add_external_services(self): + """Adds external services to the DFD.""" + if 'external_services' in self.app_details: + for service in self.app_details['external_services']: + self.graph.add_node(service['id'], type='external_service') + self.node_labels[service['id']] = self._wrap_text(service['name']) + + def _add_trust_boundaries(self, pos): + """Adds trust boundaries to the DFD.""" + if 'trust_boundaries' in self.app_details: + for i, boundary in enumerate(self.app_details['trust_boundaries']): + component_pos = {comp_id: pos[comp_id] for comp_id in boundary['components'] if comp_id in pos} + if not component_pos: + continue + + # Get the bounding box of the components in the trust boundary + min_x = min(p[0] for p in component_pos.values()) - 0.1 + max_x = max(p[0] for p in component_pos.values()) + 0.1 + min_y = min(p[1] for p in component_pos.values()) - 0.1 + max_y = max(p[1] for p in component_pos.values()) + 0.1 + + # Draw a rectangle around the components + rect = plt.Rectangle((min_x, min_y), max_x - min_x, max_y - min_y, + fill=False, edgecolor='red', linestyle='--', linewidth=2) + plt.gca().add_patch(rect) + plt.text(min_x, max_y + 0.05, boundary['name'], fontsize=10, color='red') + diff --git a/secmind/sub_agents/threat_modeling_agent/instruction_builder.py b/secmind/sub_agents/threat_modeling_agent/instruction_builder.py index 66813570a..07db4f47e 100644 --- a/secmind/sub_agents/threat_modeling_agent/instruction_builder.py +++ b/secmind/sub_agents/threat_modeling_agent/instruction_builder.py @@ -70,6 +70,71 @@ def build_agent_instructions() -> str: - Call `generate_threat_model_report(app_details_json)` - The tool will return a comprehensive threat model + **JSON Format for app_details:** + When you have gathered all the necessary details, you must format them into a JSON object with the following structure. Pay close attention to the `data_flows` section. + + ```json + {{ + "name": "Application Name", + "description": "A brief description of the application.", + "components": [ + {{ + "id": "user_interface", + "name": "Web UI", + "type": "frontend", + "technology": "React" + }}, + {{ + "id": "api_gateway", + "name": "API Gateway", + "type": "service", + "technology": "nginx" + }}, + {{ + "id": "backend_api", + "name": "Backend API", + "type": "service", + "technology": "Node.js" + }}, + {{ + "id": "database", + "name": "PostgreSQL DB", + "type": "database", + "technology": "PostgreSQL" + }} + ], + "data_flows": [ + {{ + "from": "user_interface", + "to": "api_gateway", + "label": "HTTP requests" + }}, + {{ + "from": "api_gateway", + "to": "backend_api", + "label": "proxied requests" + }}, + {{ + "from": "backend_api", + "to": "database", + "label": "SQL queries" + }} + ], + "external_services": [ + {{ + "id": "auth_service", + "name": "OAuth Provider" + }} + ], + "trust_boundaries": [ + {{ + "name": "DMZ", + "components": ["api_gateway"] + }} + ] + }} + ``` + 4. **Report Presentation:** Present the threat model report in a clear, structured format: diff --git a/secmind/sub_agents/threat_modeling_agent/models.py b/secmind/sub_agents/threat_modeling_agent/models.py index b8e3ae6cf..1b556e3d5 100644 --- a/secmind/sub_agents/threat_modeling_agent/models.py +++ b/secmind/sub_agents/threat_modeling_agent/models.py @@ -45,6 +45,7 @@ class ThreatModelReport(TypedDict): vulnerabilities: List[VulnerabilityDetails] recommendations: Recommendations compliance_notes: Optional[List[str]] + dfd: Optional[str] class ThreatModelResult(TypedDict): diff --git a/secmind/sub_agents/threat_modeling_agent/report_generator.py b/secmind/sub_agents/threat_modeling_agent/report_generator.py new file mode 100644 index 000000000..5c6b94425 --- /dev/null +++ b/secmind/sub_agents/threat_modeling_agent/report_generator.py @@ -0,0 +1,173 @@ +""" +HTML report generator for the threat modeling agent. +""" + +import datetime +from typing import Dict, Any +import base64 +import os + +def generate_html_report(data: Dict[str, Any]) -> str: + """ + Generate an HTML report from threat model data. + + Args: + data: Dictionary containing threat model data. + + Returns: + HTML report as a string. + """ + report_date = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + overview = data.get("overview", "No overview provided") + risk_score = data.get("risk_score", "N/A") + identified_threats = data.get("identified_threats", []) + vulnerabilities = data.get("vulnerabilities", []) + recommendations = data.get("recommendations", {}) + compliance_notes = data.get("compliance_notes", []) + dfd_path = data.get("dfd") + + dfd_html = "" + if dfd_path and os.path.exists(dfd_path): + try: + with open(dfd_path, "rb") as f: + dfd_base64 = base64.b64encode(f.read()).decode("utf-8") + dfd_html = f'Data Flow Diagram' + except Exception as e: + dfd_html = f"

Error rendering DFD: {e}

" + elif dfd_path: + dfd_html = f"

DFD image not found at: {dfd_path}

" + + html = f""" + + + + + + Threat Model Report + + + +
+

Threat Model Report

+

Report Date: {report_date}

+ +
+

The threats and vulnerabilities identified in this document are theoretical findings from the threat modeling exercise and do not represent confirmed or active security incidents.

+
+ +

Executive Summary

+
+
+
Risk Score
+
{risk_score}
+
+
+

Overview

+

{overview}

+ +

Data Flow Diagram

+
+ {dfd_html} +
+ +

Identified Threats

+ {"" + "".join([f"" for t in identified_threats]) + "
ThreatDescriptionSTRIDE CategoryLikelihoodImpactAffected Components
{t['threat']}{t['description']}{t['stride_category']}{t['likelihood']}{t['impact']}{', '.join(t['affected_components'])}
" if identified_threats else "

No identified threats.

"} + +

Vulnerabilities

+ {"" + "".join([f"" for v in vulnerabilities]) + "
VulnerabilityDescriptionSeverityComponentCWE IDRemediation
{v['vulnerability']}{v['description']}{v['severity']}{v['component']}{v['cwe_id'] or 'N/A'}{v['remediation']}
" if vulnerabilities else "

No vulnerabilities found.

"} + +

Recommendations

+ {"".join([f"

{category.replace('_', ' ').title()}

" for category, rec_list in recommendations.items() if rec_list]) if recommendations else "

No recommendations.

"} + +

Compliance Notes

+ {"" if compliance_notes else "

No compliance notes.

"} + + +
+ + + """ + return html diff --git a/secmind/sub_agents/threat_modeling_agent/threat_modeler.py b/secmind/sub_agents/threat_modeling_agent/threat_modeler.py index 3f79c0acb..f6f62618b 100644 --- a/secmind/sub_agents/threat_modeling_agent/threat_modeler.py +++ b/secmind/sub_agents/threat_modeling_agent/threat_modeler.py @@ -11,6 +11,8 @@ from .prompt_builder import ThreatModelPromptBuilder from .constants import DEFAULT_MODEL, GENERATION_TEMPERATURE, MAX_RETRIES from secmind.memory_manager import MemoryManager +from .dfd_generator import DFDGenerator +from . import report_generator logger = logging.getLogger(__name__) @@ -18,34 +20,28 @@ class ThreatModeler: """Handles threat modeling operations with Gemini AI.""" - def __init__(self, api_key: Optional[str] = None, model: str = DEFAULT_MODEL, memory_manager: Optional[MemoryManager] = None): + def __init__(self, model: str = DEFAULT_MODEL, memory_manager: Optional[MemoryManager] = None): """ Initialize threat modeler. Args: - api_key: Google API key (defaults to GOOGLE_API_KEY env var) model: Gemini model to use memory_manager: Instance of MemoryManager for caching """ - self.api_key = api_key or os.getenv("GOOGLE_API_KEY") self.model_name = model self.prompt_builder = ThreatModelPromptBuilder() self.memory = memory_manager or MemoryManager() - - if not self.api_key: - raise ValueError("Google API key not set. Set GOOGLE_API_KEY environment variable.") - - genai.configure(api_key=self.api_key) - self.model = genai.GenerativeModel(self.model_name) + self.client = genai.GenerativeModel(self.model_name) + logger.info(f"Initialized ThreatModeler with model: {self.model_name}") - + def generate_threat_model(self, app_details: Dict[str, Any]) -> ThreatModelResult: """ Generate a threat model report for an application, with caching. - + Args: app_details: Dictionary containing application details - + Returns: ThreatModelResult with status and report """ @@ -60,7 +56,7 @@ def generate_threat_model(self, app_details: Dict[str, Any]) -> ThreatModelResul try: logger.info("Generating threat model...") - + # Validate input if not app_details: return { @@ -68,34 +64,38 @@ def generate_threat_model(self, app_details: Dict[str, Any]) -> ThreatModelResul "message": "Application details cannot be empty.", "report": None } - + + # Generate DFD + dfd_generator = DFDGenerator(app_details) + dfd = dfd_generator.generate_dfd() + # Build prompt prompt = self.prompt_builder.build_threat_model_prompt(app_details) logger.debug(f"Generated prompt (length: {len(prompt)} chars)") - + # Generate content with retries report_data = self._generate_with_retry(prompt) - + if not report_data: return { "status": "error", "message": "Failed to generate threat model after retries.", "report": None } - + # Validate report structure - validated_report = self._validate_report(report_data) - + validated_report = self._validate_report(report_data, dfd=dfd) + # Add to cache self.memory.add_threat_model(app_details, validated_report) - + logger.info("Threat model generated successfully") return { "status": "success", "report": validated_report, "message": None } - + except json.JSONDecodeError as e: logger.error(f"JSON parsing error: {e}") return { @@ -110,42 +110,43 @@ def generate_threat_model(self, app_details: Dict[str, Any]) -> ThreatModelResul "message": f"Failed to generate threat model: {str(e)}", "report": None } - + def _generate_with_retry(self, prompt: str, retries: int = MAX_RETRIES) -> Optional[Dict[str, Any]]: """ Generate content with retry logic. - + Args: prompt: Prompt to send to the model retries: Number of retries - + Returns: Parsed JSON response or None """ for attempt in range(retries): try: logger.debug(f"Generation attempt {attempt + 1}/{retries}") - - response = self.model.generate_content( + + generation_config = genai.types.GenerationConfig( + response_mime_type="application/json", + temperature=GENERATION_TEMPERATURE, + ) + response = self.client.generate_content( prompt, - generation_config=genai.GenerationConfig( - response_mime_type="application/json", - temperature=GENERATION_TEMPERATURE - ) + generation_config=generation_config ) - + # Parse response json_str = response.text.strip() report_data = json.loads(json_str) - + return report_data - + except json.JSONDecodeError as e: logger.warning(f"Attempt {attempt + 1} - JSON decode error: {e}") if attempt == retries - 1: raise continue - + except Exception as e: logger.warning(f"Attempt {attempt + 1} - Error: {e}") if attempt == retries - 1: @@ -154,12 +155,13 @@ def _generate_with_retry(self, prompt: str, retries: int = MAX_RETRIES) -> Optio return None - def _validate_report(self, report_data: Dict[str, Any]) -> ThreatModelReport: + def _validate_report(self, report_data: Dict[str, Any], dfd: Optional[str] = None) -> ThreatModelReport: """ Validate and normalize report structure. Args: report_data: Raw report data from AI + dfd: Path to the DFD image file Returns: Validated ThreatModelReport @@ -171,7 +173,8 @@ def _validate_report(self, report_data: Dict[str, Any]) -> ThreatModelReport: "identified_threats": report_data.get("identified_threats", []), "vulnerabilities": report_data.get("vulnerabilities", []), "recommendations": report_data.get("recommendations", {}), - "compliance_notes": report_data.get("compliance_notes") + "compliance_notes": report_data.get("compliance_notes"), + "dfd": dfd } # Validate threats @@ -199,7 +202,7 @@ def get_threat_modeler() -> ThreatModeler: return _threat_modeler -def generate_threat_model_report(app_details: str) -> dict: +def generate_threat_model_report(app_details: str) -> str: """ Generate a threat modeling report based on application details. @@ -209,7 +212,7 @@ def generate_threat_model_report(app_details: str) -> dict: app_details: JSON string with application details Returns: - Dictionary with status and report + A string indicating the success or failure of the report generation. Example: >>> details = json.dumps({ @@ -218,8 +221,8 @@ def generate_threat_model_report(app_details: str) -> dict: ... "cloud_config": "EC2 + RDS" ... }) >>> result = generate_threat_model_report(details) - >>> print(result["status"]) - "success" + >>> print(result) + "Successfully generated threat model report: threat_model_report.html" """ try: # Parse JSON string @@ -231,20 +234,23 @@ def generate_threat_model_report(app_details: str) -> dict: # Get threat modeler and generate report modeler = get_threat_modeler() result = modeler.generate_threat_model(app_details_dict) + + if result["status"] == "error": + return f"Failed to generate threat model: {result['message']}" + + # Generate HTML report + html_report = report_generator.generate_html_report(result["report"]) + + # Save the report + report_path = "threat_model_report.html" + with open(report_path, "w") as f: + f.write(html_report) - return result + return f"Successfully generated threat model report: {report_path}" except json.JSONDecodeError as e: logger.error(f"Invalid JSON in app_details: {e}") - return { - "status": "error", - "message": f"Invalid JSON format: {str(e)}", - "report": None - } + return f"Error: Invalid JSON format: {str(e)}" except Exception as e: logger.error(f"Error in generate_threat_model_report: {e}") - return { - "status": "error", - "message": f"Error: {str(e)}", - "report": None - } + return f"Error: {str(e)}" diff --git a/secmind/sub_agents/vuln_triage_agent/agent.py b/secmind/sub_agents/vuln_triage_agent/agent.py index bb9847cb2..3a1a3a407 100644 --- a/secmind/sub_agents/vuln_triage_agent/agent.py +++ b/secmind/sub_agents/vuln_triage_agent/agent.py @@ -35,7 +35,7 @@ search_agent = Agent( - model='gemini-2.5-flash', + model='gemini-2.5-pro', name='SearchAgent', instruction=""" You're a specialist in Google Search @@ -49,7 +49,7 @@ vuln_triage_agent = Agent( name="vuln_triage_agent", - model="gemini-2.5-flash", + model="gemini-2.5-pro", description=( "Triages vulnerabilities and verifies software package licenses " "across multiple ecosystems, with SBOM parsing support. "