Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 32 additions & 34 deletions api/api.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,24 @@
# API endpoints for TZU application
import datetime
import shutil
import os
import shutil
import locale
import zoneinfo
from uuid import UUID
from datetime import datetime, timedelta
from typing import List, Optional

# Third-party imports
from fastapi import FastAPI, HTTPException, Depends, UploadFile, Body, status, Security, Path
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles

from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from pydantic import BaseModel, validator
from sqlalchemy.orm import Session
from sqlalchemy.orm import joinedload
from datetime import datetime, timedelta
from typing import List, Optional
from sqlalchemy.orm import Session, joinedload
from jose import JWTError, jwt
from passlib.context import CryptContext

# Configure timezone from environment variable
import locale
import zoneinfo

# Set timezone if TZ environment variable is set
if 'TZ' in os.environ:
try:
Expand All @@ -40,19 +38,7 @@
import init_db
from tzu_ai import clientAI
from utils import save_image

# Local imports
import models
import schemas
import crud
import database
import utils
import init_db

from tzu_ai import clientAI
from utils import save_image

import os
from stride_validator import normalize_stride_category, get_valid_stride_categories

# Configuración basada en entorno
ENVIRONMENT = os.getenv("ENVIRONMENT", "development")
Expand Down Expand Up @@ -187,8 +173,6 @@ async def get_current_active_user(current_user: schemas.User = Depends(get_curre
return current_user


# models.Base.metadata.create_all(bind=database.engine)

# Batch endpoint to update risk values for multiple threats at once
class ThreatRiskUpdate(BaseModel):
threat_id: str
Expand Down Expand Up @@ -273,7 +257,7 @@ async def create_threat_for_system(
except Exception:
raise HTTPException(status_code=400, detail="El id del sistema no es un UUID válido")

# Verify that system exists
# Verify that system exists
system = db.query(models.InformationSystem).filter(models.InformationSystem.id == system_uuid).first()
if not system:
raise HTTPException(status_code=404, detail="Information System not found")
Expand Down Expand Up @@ -308,11 +292,18 @@ async def create_threat_for_system(
remediation = crud.create_remediation(db, remediation_data.get('description', ''))

# Crear threat
# Normalizar categoría STRIDE para amenaza manual
raw_type = threat_data.get('type', 'Spoofing')
normalized_type = normalize_stride_category(raw_type)
if not normalized_type:
print(f"⚠️ Warning: Invalid STRIDE category '{raw_type}' in manual threat, using 'Spoofing'")
normalized_type = 'Spoofing'

threat = crud.create_threat(
db,
title=threat_data.get('title', 'Nueva Amenaza'),
description=threat_data.get('description', ''),
type=threat_data.get('type', 'Spoofing'),
type=normalized_type,
information_system_id=system_uuid,
risk_id=risk.id,
remediation_id=remediation.id
Expand Down Expand Up @@ -345,8 +336,12 @@ async def evaluate(file: UploadFile, information_system_id: str, db: Session = D
try:
# Guardar la imagen y obtener base64
print(f"Procesando archivo: {file.filename}")
image_b64 = save_image(file)
db_information_system = crud.attach_diagram(db, information_system_id=information_system_id, image_path=file.filename)
image_b64, saved_filename = save_image(file)

if not image_b64 or not saved_filename:
return {"message": "Error al procesar la imagen", "success": False}

db_information_system = crud.attach_diagram(db, information_system_id=information_system_id, image_path=saved_filename)

# Obtener análisis de la IA
print("Llamando a clientAI...")
Expand All @@ -372,9 +367,16 @@ async def evaluate(file: UploadFile, information_system_id: str, db: Session = D
# Procesar las amenazas encontradas
for i in result.threats:
print(i)

# Normalizar categoría STRIDE
normalized_type = normalize_stride_category(i.type)
if not normalized_type:
print(f"⚠️ Warning: Invalid STRIDE category '{i.type}' normalized to 'Spoofing'")
normalized_type = 'Spoofing' # Default fallback

remediation = crud.create_remediation(db, i.remediation)
risk = crud.create_risk(db, i.risk)
threat = crud.create_threat(db, i.title, i.description, i.categories, UUID(information_system_id), risk.id, remediation.id)
threat = crud.create_threat(db, i.title, i.description, normalized_type, UUID(information_system_id), risk.id, remediation.id)

print(f"Se encontraron {len(result.threats)} amenazas")
return {"information_system": db_information_system, "message": f"Se analizó el diagrama exitosamente y se encontraron {len(result.threats)} amenazas", "success": True}
Expand All @@ -389,10 +391,6 @@ async def evaluate(information_system:schemas.InformationSystemBaseCreate, db: S
db_information_system = crud.create_information_system(db, information_system=information_system)
return db_information_system

# result = clientAI(item)
# print(result.content)
# return JSONResponse(content=result.content)

# Endpoint para actualizar los riesgos de todas las amenazas asociadas a un information_system_id
@app.put("/information_systems/{information_system_id}/threats/risk/batch", response_model=list[schemas.Threat])
async def update_threats_risk_by_system(
Expand Down
16 changes: 14 additions & 2 deletions api/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

import schemas

# Import STRIDE normalization function
from stride_validator import normalize_stride_category

# Security configuration for passwords and JWT
import os
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
Expand Down Expand Up @@ -95,8 +98,17 @@ def update_threat_risk(db: Session, threat_id: str, data: dict):
# Update Threat fields
for key in ['title', 'type', 'description']:
if key in data:
print(f"Actualizando Threat {key}: {getattr(threat, key)} -> {data[key]}")
setattr(threat, key, data[key])
value = data[key]
# Normalizar categoría STRIDE si es el campo 'type'
if key == 'type':
normalized_value = normalize_stride_category(value)
if not normalized_value:
print(f"⚠️ Warning: Invalid STRIDE category '{value}' in update, using 'Spoofing'")
value = 'Spoofing'
else:
value = normalized_value
print(f"Actualizando Threat {key}: {getattr(threat, key)} -> {value}")
setattr(threat, key, value)
# Update Remediation field
if remediation and 'remediation' in data and isinstance(data['remediation'], dict):
remediation_data = data['remediation']
Expand Down
70 changes: 70 additions & 0 deletions api/stride_validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# STRIDE Category Validation and Normalization
"""
Módulo para validar y normalizar categorías STRIDE.
Asegura consistencia en la nomenclatura de las categorías.
"""

# STRIDE Categories constants for validation
VALID_STRIDE_CATEGORIES = {
'Spoofing',
'Tampering',
'Repudiation',
'Information Disclosure',
'Denial of Service',
'Elevation of Privilege'
}

def normalize_stride_category(category_input):
"""
Normaliza y valida categorías STRIDE para asegurar consistencia.

Args:
category_input (str): Categoría STRIDE a normalizar

Returns:
str: Categoría normalizada si es válida, None si no es válida

Examples:
normalize_stride_category("spoofing") -> "Spoofing"
normalize_stride_category("INFORMATION DISCLOSURE") -> "Information Disclosure"
normalize_stride_category("invalid") -> None
"""
if not category_input or not isinstance(category_input, str):
return None

# Limpiar espacios y normalizar
cleaned = category_input.strip()

# Buscar coincidencia exacta (case-insensitive)
for valid_category in VALID_STRIDE_CATEGORIES:
if cleaned.lower() == valid_category.lower():
return valid_category

# Buscar coincidencia parcial (por si la IA devuelve texto extra)
for valid_category in VALID_STRIDE_CATEGORIES:
if valid_category.lower() in cleaned.lower():
return valid_category

# Si no encuentra coincidencia, retornar None
return None

def is_valid_stride_category(category):
"""
Verifica si una categoría es válida.

Args:
category (str): Categoría a verificar

Returns:
bool: True si es válida, False en caso contrario
"""
return normalize_stride_category(category) is not None

def get_valid_stride_categories():
"""
Retorna la lista de categorías STRIDE válidas.

Returns:
set: Conjunto de categorías válidas
"""
return VALID_STRIDE_CATEGORIES.copy()
11 changes: 8 additions & 3 deletions api/tzu_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,17 @@

Important requirements:
- Each threat must explicitly mention the **asset or flow** affected in the diagram (e.g., login form, API Gateway, session token, OTP mechanism, transaction service).
- Each threat must be classified into at least one **STRIDE category** and mapped to **MASVS/ASVS controls** if relevant.
- Each threat must be classified into exactly ONE **STRIDE category**: Spoofing, Tampering, Repudiation, Information Disclosure, Denial of Service, or Elevation of Privilege.
- Each threat must include **concrete remediation controls**, aligned with ASVS/MASVS requirements and the Reglamento de Ciberseguridad de la SBS Perú (e.g., MFA required for financial transactions, SMS OTP not valid, secure session management, signed audit logs).
- For compliance-related threats, explicitly reference the **SBS Perú Cybersecurity Regulation**.
- Use ONLY the allowed numeric values for OWASP Risk Rating factors (no decimals, no values outside the list).
- Output MUST be in **Spanish** and ONLY in JSON format.

Reference format for remediation:
- ASVS: "ASVS V[número].[subnúmero] - [nombre del control]" (e.g., "ASVS V2.6 - Multi-factor Authentication")
- MASVS: "MASVS-[categoría]-[número] - [nombre del control]" (e.g., "MASVS-AUTH-2 - Session Management")
- SBS: "SBS Reg. Ciberseguridad Art. [número] - [descripción breve]" (e.g., "SBS Reg. Ciberseguridad Art. 12 - Autenticación Multifactor")

Allowed values:
Threat Agent Factors:
- skill_level: [0, 1, 3, 5, 6, 9]
Expand Down Expand Up @@ -51,8 +56,8 @@
{{
"title": "Threat Title",
"description": "Detailed threat description.",
"categories": "STRIDE Category and MASVS/ASVS Category if applicable",
"remediation": "Recommended mitigation aligned with ASVS/MASVS and SBS regulation",
"type": "One STRIDE category: Spoofing | Tampering | Repudiation | Information Disclosure | Denial of Service | Elevation of Privilege",
"remediation": "Recommended mitigation with references (ASVS V[x].[y] - [control name], MASVS-[CAT]-[num] - [control name], SBS Reg. Ciberseguridad Art. [num] - [description])",
"risk": {{
"skill_level": "value from list",
"motive": "value from list",
Expand Down
34 changes: 27 additions & 7 deletions api/utils.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,55 @@
import base64
import shutil
import os
import uuid
from pathlib import Path

def save_image(file):
"""
Save image file and encode it to base64
Save image file with UUID name and encode it to base64
Returns: tuple (image_base64, saved_filename)
"""
image_base64 = None
saved_filename = None

try:
# Ensure diagrams directory exists
import os
if not os.path.exists("diagrams"):
os.makedirs("diagrams")

# Get file extension from original filename
original_filename = file.filename
file_extension = Path(original_filename).suffix.lower()

# Validate file extension (security measure)
allowed_extensions = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp', '.svg'}
if file_extension not in allowed_extensions:
raise ValueError(f"Tipo de archivo no permitido: {file_extension}")

# Generate UUID-based filename
unique_id = str(uuid.uuid4())
saved_filename = f"{unique_id}{file_extension}"
file_path = f"diagrams/{saved_filename}"

# Save file physically
file_path = f"diagrams/{file.filename}"
with open(file_path, "wb") as f:
shutil.copyfileobj(file.file, f)

# Reset file pointer and read to encode
file.file.seek(0)
image_base64 = base64.b64encode(file.file.read()).decode('utf-8')

print(f"Imagen guardada en {file_path} y codificada en base64")
print(f"Imagen '{original_filename}' guardada como '{saved_filename}' y codificada en base64")

except Exception as e:
print(f"Error al guardar la imagen: {e}")
return None
return None, None
finally:
# Ensure file is closed
file.file.close()
if hasattr(file.file, 'close'):
file.file.close()

return image_base64
return image_base64, saved_filename


def encode_image(image_path):
Expand Down
6 changes: 4 additions & 2 deletions docker/docker-compose.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
version: "0.1"

services:
postgresql:
image: postgres:15
Expand All @@ -25,6 +23,8 @@ services:
command: ["./backend.entrypoint.sh"]
# volumes: # Uncomment for development
# - ../api:/app
volumes:
- diagrams_data:/app/diagrams # Volumen persistente para diagramas
depends_on:
postgresql:
condition: service_healthy
Expand Down Expand Up @@ -54,12 +54,14 @@ services:
volumes:
- ./nginx.conf:/etc/nginx/nginx.conf:ro
- frontend_build:/usr/share/nginx/html:ro
- diagrams_data:/usr/share/nginx/diagrams:ro # Servir diagramas estáticamente
networks:
- tzu_net

volumes:
postgres_data:
frontend_build:
diagrams_data: # Volumen persistente para las imágenes de diagramas

networks:
tzu_net:
Expand Down
14 changes: 8 additions & 6 deletions docker/nginx.conf
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,15 @@ http {
proxy_read_timeout 60s;
}

# Static files from backend
# Static files from backend - servir diagramas directamente desde volumen
location /diagrams/ {
proxy_pass http://backend;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
alias /usr/share/nginx/diagrams/;
expires 1d;
add_header Cache-Control "public, max-age=86400";

# Security headers
add_header X-Frame-Options DENY always;
add_header X-Content-Type-Options nosniff always;
}

# Serve React app for all other routes
Expand Down
Loading
Loading