diff --git a/.gitignore b/.gitignore index 1aad821..57f97a0 100644 --- a/.gitignore +++ b/.gitignore @@ -476,3 +476,12 @@ $RECYCLE.BIN/ # Windows shortcuts *.lnk *.db + +.env +docs + +# Local development overrides +docker-compose.override.yml + +# Shadow mode metrics (generated at runtime) +shadow-mode-metrics.json \ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..5606fcc --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,92 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Build Commands + +### Docker (Full Stack) +```bash +# Build both containers +docker-compose build + +# Run the full stack +docker-compose up +``` + +### Orchestrator (.NET 8) +```bash +# Build +dotnet build orchestrator/ModelScanner.sln + +# Run (development mode with in-memory job storage) +dotnet run --project orchestrator/ModelScanner + +# Run tests +dotnet test orchestrator/ModelScanner.sln + +# Run a specific test +dotnet test orchestrator/ModelScanner.sln --filter "FullyQualifiedName~HashTaskTests" +``` + +### Model Scanner Container +```bash +docker build -t civitai-model-scanner ./model-scanner/ +docker run -it --rm civitai-model-scanner 'https://example.com/model.bin' +``` + +## Architecture + +This is a distributed AI model scanning system with two main components: + +### 1. Model Scanner Container (Python) +Located in `model-scanner/`. A Docker container running: +- **picklescan**: Detects dangerous pickle imports in PyTorch models +- **clamscan**: ClamAV malware detection +- Python ML libraries (PyTorch CPU, safetensors) for model processing + +### 2. Orchestrator Service (.NET 8) +Located in `orchestrator/ModelScanner/`. 
An ASP.NET Core web API that: +- Receives scan requests via HTTP endpoints +- Queues jobs using Hangfire (SQLite persistent or in-memory storage) +- Executes the scanner container via Docker API +- Calculates file hashes (SHA256, Blake3, CRC32, AutoV1/V2/V3) +- Converts models between formats (CKPT ↔ SafeTensors) +- Uploads processed files to S3/R2 cloud storage +- Reports results via webhook callbacks + +### Processing Pipeline +Jobs are enqueued via `POST /enqueue` with configurable task flags: +- `Import` (1): Upload to cloud storage +- `Convert` (2): Format conversion +- `Scan` (4): Malware/pickle scanning +- `Hash` (8): Calculate cryptographic hashes +- `ParseMetadata` (16): Extract safetensors metadata +- `Default`: Import | Hash | Scan | ParseMetadata +- `All`: All tasks including Convert + +Key flow: `FileProcessor.cs` downloads the model, runs requested tasks via `IJobTask` implementations, and POSTs results to the callback URL. + +### Job Queues (Priority Order) +- `default`: Normal priority +- `low-prio`: Lower priority processing +- `x-low-prio`: Lowest priority (conversions) +- `cleanup`: Storage cleanup +- `delete-objects`: Deletion jobs + +## Key Files + +- `orchestrator/ModelScanner/Program.cs`: API endpoints and DI setup +- `orchestrator/ModelScanner/FileProcessor.cs`: Main job processing logic +- `orchestrator/ModelScanner/Tasks/`: Individual task implementations (HashTask, ScanTask, ImportTask, ConvertTask, ParseMetadataTask) +- `orchestrator/ModelScanner/CloudStorageService.cs`: S3/R2 integration +- `orchestrator/ModelScanner/DockerService.cs`: Scanner container execution +- `model-scanner/scripts/`: Python conversion scripts (ckpt_to_safetensors.py, safetensors_to_ckpt.py) + +## Configuration + +Settings in `appsettings.json`: +- `ValidTokens`: API authentication tokens +- `CloudStorageOptions`: S3/R2 credentials and bucket names +- `LocalStorageOptions`: Temp folder path +- `ConnectionStrings:JobStorage`: SQLite path for Hangfire (omit for 
in-memory) +- `Concurrency`: Worker thread count (defaults to CPU count) diff --git a/README.md b/README.md new file mode 100644 index 0000000..b15fe66 --- /dev/null +++ b/README.md @@ -0,0 +1,252 @@ +# Civitai Model Scanner + +A distributed AI model scanning system that detects malware and malicious code in machine learning model files. + +## Architecture + +The system consists of several components: + +``` + +------------------+ + | Cloud Storage | + | (S3/R2) | + +--------^---------+ + | ++----------------+ +-------------------+ | +------------------+ +| HTTP POST | | Orchestrator |-+-| Callback URL | +| /enqueue +---->| (.NET 8) | | (Webhook) | ++----------------+ | | +------------------+ + | - Hangfire | + | - Job Queue | + +--------+----------+ + | + +-----------------+------------------+ + | | | + +---------v------+ +-------v--------+ +-----v------+ + | Legacy Scanner | | Unified Scanner| | ClamAV | + | (picklescan) | | (TensorTrap) | | Updater | + +----------------+ +----------------+ +------------+ +``` + +### Components + +1. **Orchestrator Service** (.NET 8 / ASP.NET Core) + - Receives scan requests via HTTP API + - Queues jobs using Hangfire (SQLite or in-memory) + - Downloads model files, executes scanners via Docker + - Reports results via webhook callbacks + +2. **Legacy Scanner** (Python/Docker) + - Picklescan: Detects dangerous pickle imports in PyTorch models + - ClamAV: Malware signature scanning + +3. **Unified Scanner** (Python/Docker) + - TensorTrap: ML security scanner supporting 13+ formats + - ClamAV: Integrated malware scanning + - Detects 11+ CVEs and security vulnerabilities + +4. 
**ClamAV Updater** (Sidecar) + - Automatically updates virus definitions every 2 hours + - Shares definitions with scanner containers via Docker volume + +## Quick Start + +### Prerequisites + +- Docker Desktop +- .NET 8 SDK (for local development) + +### Running with Docker Compose + +```bash +# Build all images +docker-compose build + +# Start the stack +docker-compose up -d + +# Check status +docker-compose ps +``` + +### API Endpoints + +| Endpoint | Method | Description | +|----------|--------|-------------| +| `/enqueue` | POST | Submit a scan job | +| `/cleanup` | POST | Trigger temp storage cleanup | +| `/delete` | POST | Delete an object from storage | +| `/metrics/shadow` | GET | Get shadow mode metrics summary | +| `/metrics/shadow/full` | GET | Get full shadow mode metrics | +| `/metrics/shadow/reset` | POST | Reset shadow mode metrics | + +### Submitting a Scan Job + +```bash +curl -X POST "http://localhost/enqueue?token=YOUR_TOKEN&fileUrl=https://example.com/model.safetensors&callbackUrl=https://your-callback.com/result" +``` + +**Parameters:** +- `fileUrl` (required): URL of the model file to scan +- `callbackUrl` (required): Webhook URL for scan results +- `tasks` (optional): Bitmask of tasks to run (default: 28) + - Import = 1 + - Convert = 2 + - Scan = 4 + - Hash = 8 + - ParseMetadata = 16 +- `lowPrio` / `extraLowPrio` (optional): Queue priority flags + +## Configuration + +### Environment Variables + +| Variable | Description | Default | +|----------|-------------|---------| +| `ValidTokens__0`, `__1`, etc. | API authentication tokens | - | +| `ScannerOptions__UseUnifiedScanner` | Use TensorTrap instead of picklescan | `false` | +| `ScannerOptions__ShadowMode` | Run both scanners for comparison | `false` | +| `CloudStorageOptions__*` | S3/R2 credentials | - | +| `ConnectionStrings__JobStorage` | SQLite path for Hangfire | (in-memory) | +| `Concurrency` | Worker thread count | CPU count | + +### Scanner Modes + +1. 
**Legacy Mode** (default): Uses picklescan + ClamAV +2. **Unified Mode**: Uses TensorTrap + ClamAV +3. **Shadow Mode**: Runs both scanners, compares results, uses legacy for response + +Shadow mode is useful for validating the new scanner before full migration. + +## Development + +### Building the Orchestrator + +```bash +cd orchestrator +dotnet build ModelScanner.sln +dotnet run --project ModelScanner +``` + +### Running Tests + +```bash +dotnet test orchestrator/ModelScanner.sln +``` + +### End-to-End Testing + +The `e2e/` directory contains a test script that runs a full scan workflow: + +```bash +# Start the stack with test configuration +docker-compose -f docker-compose.yml -f docker-compose.test.yml up -d + +# Run the e2e test +cd e2e +python e2e_test.py /path/to/model.ckpt --timeout 300 + +# Example with options +python e2e_test.py ./model.safetensors \ + --orchestrator-url http://localhost:80 \ + --token test-token \ + --tasks 28 \ + --timeout 300 \ + --json +``` + +**E2E Test Options:** +- `--orchestrator-url`: Orchestrator API URL (default: http://localhost:8080) +- `--token`: API token (default: test-token) +- `--tasks`: Task flags bitmask (default: 28 = Scan|Hash|ParseMetadata) +- `--timeout`: Timeout in seconds (default: 300) +- `--json`: Output raw JSON results + +## Callback Response Format + +```json +{ + "url": "https://example.com/model.safetensors", + "fileExists": 1, + "picklescanExitCode": 0, + "picklescanOutput": "...", + "picklescanGlobalImports": ["torch", "collections"], + "picklescanDangerousImports": [], + "tensorTrapScanned": true, + "tensorTrapMaxSeverity": "info", + "tensorTrapIsSafe": true, + "tensorTrapFindings": [...], + "clamscanExitCode": 0, + "clamscanOutput": "OK", + "hashes": { + "SHA256": "...", + "Blake3": "...", + "CRC32": "...", + "AutoV1": "...", + "AutoV2": "...", + "AutoV3": "..." 
+ } +} +``` + +## Shadow Mode Metrics + +When running in shadow mode, metrics are collected comparing legacy and unified scanner results: + +```bash +curl "http://localhost/metrics/shadow?token=YOUR_TOKEN" +``` + +```json +{ + "totalScans": 1000, + "matches": 985, + "discrepancies": 15, + "agreementRate": 98.5, + "unifiedFoundMoreThreats": 12, + "legacyFoundMoreThreats": 3, + "bothSafe": 970, + "bothDangerous": 15, + "errors": 0, + "recommendation": "Unified scanner is finding MORE threats - safe to migrate" +} +``` + +## Supported File Formats + +### TensorTrap (Unified Scanner) +- PyTorch: `.pt`, `.pth`, `.bin`, `.ckpt` +- Pickle: `.pkl`, `.pickle` +- NumPy: `.npy`, `.npz` +- Safetensors: `.safetensors` +- ONNX: `.onnx` +- GGUF: `.gguf` + +### Legacy Scanner (Picklescan) +- PyTorch pickle files +- Does NOT scan `.safetensors` (considered safe by design) + +## Security Considerations + +- API endpoints require authentication via `token` query parameter +- Scanner containers run with memory limits (2GB default) +- ClamAV definitions are updated automatically +- Model files are deleted after scanning + +## Troubleshooting + +### Scanner Timeout +Large model files (>1GB) may take several minutes to scan. Adjust timeouts as needed. + +### ClamAV Definitions Not Found +Ensure the clamav-updater container is running and has completed initial download: +```bash +docker logs model-scanner-clamav-updater-1 +``` + +### Connection Refused on Callback +Ensure your callback URL is accessible from the Docker network. Use `host.docker.internal` for local development. 
+ +## License + +[Your license here] diff --git a/clamav-updater/Dockerfile b/clamav-updater/Dockerfile new file mode 100644 index 0000000..629274b --- /dev/null +++ b/clamav-updater/Dockerfile @@ -0,0 +1,19 @@ +FROM debian:bookworm-slim + +RUN apt-get update && \ + apt-get install -y --no-install-recommends clamav clamav-freshclam ca-certificates && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +# Create directory for definitions +RUN mkdir -p /var/lib/clamav && \ + chown clamav:clamav /var/lib/clamav + +COPY freshclam.conf /etc/clamav/freshclam.conf +COPY update-loop.sh /usr/local/bin/update-loop.sh +RUN chmod +x /usr/local/bin/update-loop.sh + +# Run as clamav user for security +USER clamav + +ENTRYPOINT ["/usr/local/bin/update-loop.sh"] diff --git a/clamav-updater/freshclam.conf b/clamav-updater/freshclam.conf new file mode 100644 index 0000000..6a50d4a --- /dev/null +++ b/clamav-updater/freshclam.conf @@ -0,0 +1,21 @@ +# ClamAV freshclam configuration +DatabaseDirectory /var/lib/clamav +UpdateLogFile /var/log/clamav/freshclam.log +LogTime yes +LogVerbose yes + +# Database mirrors +DatabaseMirror database.clamav.net + +# Check for updates every 2 hours (12 times per day) +Checks 12 + +# Notify on database update (optional - for monitoring) +# NotifyClamd /etc/clamav/clamd.conf + +# Download bytecode signatures +Bytecode yes + +# Connection timeout +ConnectTimeout 30 +ReceiveTimeout 60 diff --git a/clamav-updater/update-loop.sh b/clamav-updater/update-loop.sh new file mode 100644 index 0000000..4dbf504 --- /dev/null +++ b/clamav-updater/update-loop.sh @@ -0,0 +1,19 @@ +#!/bin/bash +set -e + +echo "ClamAV Definition Updater starting..." +echo "Definitions will be stored in: /var/lib/clamav" + +# Initial update on startup +echo "Running initial database update..." +freshclam --config-file=/etc/clamav/freshclam.conf || true + +echo "Initial update complete. Entering update loop..." 
+ +# Run freshclam every 2 hours +while true; do + sleep 7200 # 2 hours in seconds + echo "$(date): Running scheduled database update..." + freshclam --config-file=/etc/clamav/freshclam.conf || true + echo "$(date): Update complete." +done diff --git a/docker-compose.override.yml.example b/docker-compose.override.yml.example new file mode 100644 index 0000000..4ba74d1 --- /dev/null +++ b/docker-compose.override.yml.example @@ -0,0 +1,17 @@ +# Local development overrides +# Copy this to docker-compose.override.yml (gitignored) and customize paths +# +# Usage: +# cp docker-compose.override.yml.example docker-compose.override.yml +# # Edit the TensorTrap path below +# docker-compose build unified-scanner +# docker-compose up -d + +services: + unified-scanner: + build: + context: ./unified-scanner/ + dockerfile: Dockerfile.local + additional_contexts: + # Change this path to your local TensorTrap repo + tensortrap: C:\Dev\Repos\open-source\TensorTrap diff --git a/docker-compose.test.yml b/docker-compose.test.yml new file mode 100644 index 0000000..d4ec349 --- /dev/null +++ b/docker-compose.test.yml @@ -0,0 +1,33 @@ +version: '3.9' + +# Test environment override - use with: docker-compose -f docker-compose.yml -f docker-compose.test.yml up + +services: + orchestrator: + environment: + # Dummy cloud storage settings (cloud features disabled) + - CloudStorageOptions__AccessKey=test-access-key + - CloudStorageOptions__SecretKey=test-secret-key + - CloudStorageOptions__ServiceUrl=http://localhost:9999 + - CloudStorageOptions__UploadBucket=test-bucket + # Test authentication token + - ValidTokens__0=test-token + # Scanner settings - Shadow mode runs both scanners and compares results + - ScannerOptions__UseUnifiedScanner=false + - ScannerOptions__ShadowMode=true + # Local storage (in-memory for tests) + - ConnectionStrings__JobStorage= + # Use shared temp volume path (must match volume mount) + - LocalStorageOptions__TempFolder=/shared-temp/ + volumes: + # Shared temp volume 
for scanner containers to access files + # IMPORTANT: This bind mount allows Docker-in-Docker to work correctly + # The orchestrator writes files here, and scanner containers can read them + - model-scanner-temp:/shared-temp/ + # Add extra_hosts for host.docker.internal on Linux + extra_hosts: + - "host.docker.internal:host-gateway" + +volumes: + model-scanner-temp: + name: model-scanner-temp # Explicit name to avoid compose prefix diff --git a/docker-compose.yml b/docker-compose.yml index a19256c..bdca117 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,10 +1,32 @@ version: '3.9' services: + # Legacy scanner (kept for backwards compatibility during migration) model-scanner: image: civitai-model-scanner build: context: ./model-scanner/ dockerfile: Dockerfile + + # Unified scanner with TensorTrap + ClamAV + unified-scanner: + image: unified-scanner + build: + context: ./unified-scanner/ + dockerfile: Dockerfile + volumes: + - clamav-definitions:/var/lib/clamav:ro + + # ClamAV definition updater sidecar + # Runs continuously, updating definitions every 2 hours + clamav-updater: + image: clamav-updater + build: + context: ./clamav-updater/ + dockerfile: Dockerfile + restart: always + volumes: + - clamav-definitions:/var/lib/clamav + orchestrator: image: civitai-model-scanner-orchestrator build: @@ -16,9 +38,14 @@ services: volumes: - "/var/run/docker.sock:/var/run/docker.sock" - db:/data/ + # Note: clamav-definitions is NOT mounted here - DockerService mounts it + # dynamically when starting the unified-scanner container ports: - "80:8080" depends_on: - - model-scanner + - unified-scanner + - clamav-updater + volumes: db: + clamav-definitions: diff --git a/e2e/.env.test b/e2e/.env.test new file mode 100644 index 0000000..0fb510a --- /dev/null +++ b/e2e/.env.test @@ -0,0 +1,18 @@ +# E2E Test Environment Configuration +# These are dummy values - cloud storage features won't work but scanning will + +# Required cloud storage settings (dummy values for 
testing) +CloudStorageOptions__AccessKey=test-access-key +CloudStorageOptions__SecretKey=test-secret-key +CloudStorageOptions__ServiceUrl=http://localhost:9999 +CloudStorageOptions__UploadBucket=test-bucket + +# Authentication +ValidTokens__0=test-token + +# Scanner options +ScannerOptions__UseUnifiedScanner=true +ScannerOptions__ShadowMode=false + +# Local storage +LocalStorageOptions__TempFolder=/tmp/model-scanner diff --git a/e2e/appsettings.Testing.json b/e2e/appsettings.Testing.json new file mode 100644 index 0000000..73e8f9c --- /dev/null +++ b/e2e/appsettings.Testing.json @@ -0,0 +1,17 @@ +{ + "ValidTokens": ["test-token"], + "ScannerOptions": { + "UseUnifiedScanner": true, + "ShadowMode": false, + "MetricsFilePath": "./shadow-mode-metrics.json" + }, + "LocalStorageOptions": { + "TempFolder": "/tmp/model-scanner" + }, + "Logging": { + "LogLevel": { + "Default": "Information", + "Microsoft.AspNetCore": "Warning" + } + } +} diff --git a/e2e/e2e_test.py b/e2e/e2e_test.py new file mode 100644 index 0000000..a8381d0 --- /dev/null +++ b/e2e/e2e_test.py @@ -0,0 +1,344 @@ +#!/usr/bin/env python3 +""" +End-to-end test for the Model Scanner orchestrator. + +This script: +1. Starts a local file server to serve test files +2. Starts a callback receiver to capture scan results +3. Submits a scan job to the orchestrator +4. 
Displays the results + +Usage: + python e2e_test.py [--orchestrator-url URL] [--token TOKEN] + +Example: + python e2e_test.py ./test-models/safe-model.safetensors + python e2e_test.py E:\models\model.ckpt --orchestrator-url http://localhost:8080 +""" + +import argparse +import http.server +import json +import os +import socketserver +import sys +import threading +import time +import urllib.parse +import urllib.request +from pathlib import Path + + +class FileServerHandler(http.server.SimpleHTTPRequestHandler): + """HTTP handler that serves a single file.""" + + file_path: str = None + file_name: str = None + + def do_GET(self): + if self.path == f"/{self.file_name}": + self.send_response(200) + self.send_header("Content-Type", "application/octet-stream") + file_size = os.path.getsize(self.file_path) + self.send_header("Content-Length", str(file_size)) + self.end_headers() + + with open(self.file_path, "rb") as f: + # Stream in chunks to handle large files + while chunk := f.read(8192): + self.wfile.write(chunk) + else: + self.send_response(404) + self.end_headers() + + def log_message(self, format, *args): + print(f"[FileServer] {args[0]}") + + +class CallbackHandler(http.server.BaseHTTPRequestHandler): + """HTTP handler that receives callback results.""" + + results: list = None + + def do_POST(self): + # Read Content-Length header + content_length_header = self.headers.get("Content-Length", "0") + try: + content_length = int(content_length_header) + except ValueError: + content_length = 0 + + print(f"[Callback] POST received, Content-Length: {content_length}") + + # Read body + if content_length > 0: + body = self.rfile.read(content_length) + else: + body = b"" + + print(f"[Callback] Body length: {len(body)}, first 200 chars: {body[:200]}") + + try: + if body: + data = json.loads(body.decode("utf-8")) + self.results.append(data) + print(f"[Callback] Parsed JSON successfully!") + else: + print(f"[Callback] Empty body received") + self.results.append({"error": 
"empty body"}) + except json.JSONDecodeError as e: + print(f"[Callback] JSON decode error: {e}") + print(f"[Callback] Raw body: {body[:500]}") + self.results.append({"raw": body.decode("utf-8", errors="replace")}) + + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(b'{"status": "ok"}') + + def log_message(self, format, *args): + pass # Suppress default logging + + +class ThreadedTCPServer(socketserver.ThreadingMixIn, socketserver.TCPServer): + """TCP Server that handles each request in a new thread.""" + allow_reuse_address = True + daemon_threads = True + + +def start_file_server(file_path: str, port: int) -> socketserver.TCPServer: + """Start a file server in a background thread.""" + file_path = os.path.abspath(file_path) + file_name = os.path.basename(file_path) + + handler = FileServerHandler + handler.file_path = file_path + handler.file_name = file_name + + server = socketserver.TCPServer(("", port), handler) + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + + return server, file_name + + +def start_callback_server(port: int) -> tuple: + """Start a callback receiver in a background thread.""" + results = [] + + handler = CallbackHandler + handler.results = results + + server = ThreadedTCPServer(("", port), handler) + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + + return server, results + + +def submit_scan_job(orchestrator_url: str, token: str, file_url: str, callback_url: str, tasks: int = None) -> bool: + """Submit a scan job to the orchestrator. 
+ + Task flags: + Import = 1, Convert = 2, Scan = 4, Hash = 8, ParseMetadata = 16 + Default (without Import) = Scan | Hash | ParseMetadata = 4 | 8 | 16 = 28 + """ + params_dict = { + "fileUrl": file_url, + "callbackUrl": callback_url, + "token": token + } + if tasks is not None: + params_dict["tasks"] = tasks + + params = urllib.parse.urlencode(params_dict) + + url = f"{orchestrator_url}/enqueue?{params}" + + try: + req = urllib.request.Request(url, method="POST") + with urllib.request.urlopen(req, timeout=10) as response: + print(f"[Orchestrator] Job submitted, status: {response.status}") + return True + except urllib.error.HTTPError as e: + print(f"[Orchestrator] HTTP Error: {e.code} - {e.reason}") + return False + except urllib.error.URLError as e: + print(f"[Orchestrator] Connection Error: {e.reason}") + return False + + +def format_results(results: dict) -> str: + """Format scan results for display.""" + lines = [] + + # File info + if "url" in results: + lines.append(f"File URL: {results['url']}") + + # Scan results + if "picklescanExitCode" in results: + exit_code = results["picklescanExitCode"] + status = "SAFE" if exit_code == 0 else "DANGEROUS" + lines.append(f"Picklescan: {status} (exit code: {exit_code})") + + if "clamscanExitCode" in results: + exit_code = results["clamscanExitCode"] + status = "CLEAN" if exit_code == 0 else "INFECTED" if exit_code == 1 else "ERROR" + lines.append(f"ClamAV: {status} (exit code: {exit_code})") + + # TensorTrap results (unified scanner) + if results.get("tensorTrapScanned"): + severity = results.get("tensorTrapMaxSeverity", "unknown") + is_safe = results.get("tensorTrapIsSafe", False) + status = "SAFE" if is_safe else f"UNSAFE ({severity})" + lines.append(f"TensorTrap: {status}") + + findings = results.get("tensorTrapFindings", []) + if findings: + lines.append(f" Findings: {len(findings)}") + for finding in findings[:5]: # Show first 5 + msg = finding.get("message", "unknown") + sev = finding.get("severity", "?") + 
lines.append(f" - [{sev}] {msg}") + + # Hashes + if "sha256" in results: + lines.append(f"SHA256: {results['sha256']}") + if "blake3" in results: + lines.append(f"Blake3: {results['blake3']}") + + # Dangerous imports + dangerous = results.get("picklescanDangerousImports", []) + if dangerous: + lines.append(f"Dangerous Imports: {', '.join(dangerous)}") + + return "\n".join(lines) + + +def main(): + parser = argparse.ArgumentParser( + description="End-to-end test for Model Scanner", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python e2e_test.py ./test-model.safetensors + python e2e_test.py E:\\models\\model.ckpt --token mytoken + python e2e_test.py ./model.pt --orchestrator-url http://192.168.1.100:8080 + """ + ) + parser.add_argument("file", help="Path to the model file to scan") + parser.add_argument("--orchestrator-url", default="http://localhost:8080", + help="Orchestrator URL (default: http://localhost:8080)") + parser.add_argument("--token", default="test-token", + help="API token (default: test-token)") + parser.add_argument("--file-server-port", type=int, default=9000, + help="Port for file server (default: 9000)") + parser.add_argument("--callback-port", type=int, default=9001, + help="Port for callback receiver (default: 9001)") + parser.add_argument("--timeout", type=int, default=300, + help="Timeout in seconds (default: 300)") + parser.add_argument("--tasks", type=int, default=28, + help="Task flags: Import=1, Convert=2, Scan=4, Hash=8, ParseMetadata=16 (default: 28 = Scan|Hash|ParseMetadata)") + parser.add_argument("--json", action="store_true", + help="Output raw JSON results") + + args = parser.parse_args() + + # Validate file exists + if not os.path.exists(args.file): + print(f"Error: File not found: {args.file}") + sys.exit(1) + + file_size = os.path.getsize(args.file) + print(f"Test file: {args.file}") + print(f"File size: {file_size / (1024*1024):.2f} MB") + print() + + # Start servers + print(f"Starting file 
server on port {args.file_server_port}...") + file_server, file_name = start_file_server(args.file, args.file_server_port) + + print(f"Starting callback receiver on port {args.callback_port}...") + callback_server, results = start_callback_server(args.callback_port) + + # Build URLs + file_url = f"http://host.docker.internal:{args.file_server_port}/{file_name}" + callback_url = f"http://host.docker.internal:{args.callback_port}/callback" + + print() + print(f"File URL: {file_url}") + print(f"Callback URL: {callback_url}") + print() + + # Submit job + print(f"Submitting scan job (tasks={args.tasks})...") + if not submit_scan_job(args.orchestrator_url, args.token, file_url, callback_url, args.tasks): + print("Failed to submit job. Is the orchestrator running?") + print(f" Try: docker-compose up orchestrator") + sys.exit(1) + + # Wait for results + # Note: The orchestrator sends a callback after each task, so we wait a bit + # after receiving the first callback to collect all of them + print() + print(f"Waiting for results (timeout: {args.timeout}s)...") + start_time = time.time() + last_result_time = None + last_count = 0 + + while (time.time() - start_time) < args.timeout: + current_count = len(results) + if current_count > last_count: + # New result received + last_result_time = time.time() + last_count = current_count + print(f" [Callback #{current_count} received]") + elif last_result_time and (time.time() - last_result_time) > 15: + # 15 seconds since last callback, assume we're done + break + + time.sleep(0.5) + elapsed = int(time.time() - start_time) + if elapsed % 10 == 0 and elapsed > 0 and not results: + print(f" ... 
{elapsed}s elapsed") + + # Display results + print() + if results: + print(f"Received {len(results)} callback(s)") + elapsed = time.time() - start_time + print(f"Total time: {elapsed:.1f}s") + print("=" * 60) + + # Use the last result (most complete) + final_result = results[-1] + + if args.json: + print(json.dumps(final_result, indent=2)) + else: + print(format_results(final_result)) + + print("=" * 60) + + # Determine overall status + is_safe = ( + final_result.get("picklescanExitCode", 0) == 0 and + final_result.get("clamscanExitCode", 0) == 0 and + final_result.get("tensorTrapIsSafe", True) + ) + + if is_safe: + print("\nOverall: SAFE") + sys.exit(0) + else: + print("\nOverall: POTENTIALLY DANGEROUS") + sys.exit(1) + else: + print("Timeout waiting for results!") + print("Check the orchestrator logs for errors.") + sys.exit(2) + + +if __name__ == "__main__": + main() diff --git a/orchestrator/ModelScanner/DockerService.cs b/orchestrator/ModelScanner/DockerService.cs index 25caf2a..287fe11 100644 --- a/orchestrator/ModelScanner/DockerService.cs +++ b/orchestrator/ModelScanner/DockerService.cs @@ -2,6 +2,7 @@ using System.Diagnostics; using System.IO; using System.Text; +using System.Text.RegularExpressions; using System.Threading; namespace ModelScanner; @@ -11,24 +12,156 @@ class DockerService readonly ILogger _logger; readonly LocalStorageOptions _localStorageOptions; + public const string DefaultScannerImage = "civitai-model-scanner"; + public const string UnifiedScannerImage = "unified-scanner"; + + // Pattern for safe file paths - alphanumeric, common path characters, no shell metacharacters + static readonly Regex SafePathPattern = new(@"^[a-zA-Z0-9_./\-\\: ]+$", RegexOptions.Compiled); + + // Pattern for safe Docker volume names - alphanumeric, hyphens, underscores + static readonly Regex SafeVolumeNamePattern = new(@"^[a-zA-Z0-9_\-]+$", RegexOptions.Compiled); + public DockerService(ILogger logger, IOptions localStorageOptions) { _logger = logger; 
_localStorageOptions = localStorageOptions.Value; } - public const string InPath = "/data/model.in"; + public const string InPathBase = "/data/model"; public const string OutFolderPath = "/data/"; + // Legacy path for backward compatibility + public const string InPath = "/data/model.in"; + + // Named volume for Docker-in-Docker temp file sharing + public const string TempVolumeName = "model-scanner-temp"; + public const string SharedTempPath = "/shared-temp"; + + // ClamAV virus definitions shared volume + public const string ClamavVolumeName = "clamav-definitions"; + + /// + /// Gets the container path with the original file extension preserved. + /// TensorTrap needs the extension to identify file formats. + /// + public static string GetContainerPath(string filePath) + { + var extension = Path.GetExtension(filePath); + return string.IsNullOrEmpty(extension) ? InPath : $"{InPathBase}{extension}"; + } + + /// + /// Validates that a file path contains only safe characters to prevent command injection. + /// + static void ValidatePath(string path, string paramName) + { + if (string.IsNullOrWhiteSpace(path)) + throw new ArgumentException("Path cannot be empty", paramName); + + if (!SafePathPattern.IsMatch(path)) + throw new ArgumentException($"Path contains invalid characters: {path}", paramName); + } + + /// + /// Validates that a Docker volume name contains only safe characters. + /// + static void ValidateVolumeName(string volumeName, string paramName) + { + if (string.IsNullOrWhiteSpace(volumeName)) + throw new ArgumentException("Volume name cannot be empty", paramName); + + if (!SafeVolumeNamePattern.IsMatch(volumeName)) + throw new ArgumentException($"Volume name contains invalid characters: {volumeName}", paramName); + } + + /// + /// Builds Docker volume mounts and determines the effective file path in the scanner container. + /// Handles Docker-in-Docker by using named volumes when the temp folder is in a shared volume. 
+ /// + /// Tuple of (volumeMounts, effectiveFilePath) + (string volumeMounts, string effectiveFilePath) BuildVolumeMountsAndPath(string tempFolderPath, string filePath, string image) + { + // Check if we're using a named volume (temp folder starts with shared path) + var normalizedTempPath = tempFolderPath.Replace("\\", "/"); + var isUsingNamedVolume = normalizedTempPath.Contains(SharedTempPath); + + // Get container path (preserving extension for unified scanner) + var containerPath = image == UnifiedScannerImage ? GetContainerPath(filePath) : InPath; + + if (isUsingNamedVolume) + { + // Docker-in-Docker mode: mount the named volume + // The file is already at a path like /shared-temp/file.ckpt + // Mount the same volume so the scanner can access it at the same path + var fullFilePath = Path.GetFullPath(filePath).Replace("\\", "/"); + _logger.LogInformation("Using named volume {volume} for Docker-in-Docker, file at {path}", TempVolumeName, fullFilePath); + + var volumeMounts = $"-v {TempVolumeName}:{SharedTempPath}"; + // File path in scanner is the same as in orchestrator (e.g., /shared-temp/file.ckpt) + return (volumeMounts, fullFilePath); + } + else + { + // Direct file mount mode (for local development on host) + var volumeMounts = $"-v {Path.GetFullPath(filePath)}:{containerPath} -v {Path.GetFullPath(_localStorageOptions.TempFolder)}:{OutFolderPath}"; + // File path in scanner is the mapped containerPath (e.g., /data/model.ckpt) + return (volumeMounts, containerPath); + } + } + + /// + /// Runs a command in the default scanner container (civitai-model-scanner). + /// StdOut and StdErr are merged into the output string. 
+ /// public async Task<(int exitCode, string output)> RunCommandInDocker(string command, string filePath, CancellationToken cancellationToken) { - _logger.LogInformation("Executing {command} for file {filePath}", command, filePath); + var (exitCode, stdOut, stdErr) = await RunCommandInDocker(DefaultScannerImage, command, filePath, cancellationToken); + return (exitCode, stdOut + stdErr); + } + + /// + /// Runs a command in the specified Docker container. + /// StdOut and StdErr are captured separately for proper JSON parsing. + /// + public async Task<(int exitCode, string stdOut, string stdErr)> RunCommandInDocker(string image, string command, string filePath, CancellationToken cancellationToken) + { + // Validate inputs to prevent command injection + ValidatePath(filePath, nameof(filePath)); + ValidateVolumeName(TempVolumeName, nameof(TempVolumeName)); + + _logger.LogInformation("Executing {command} in {image} for file {filePath}", command, image, filePath); var stopwatch = Stopwatch.StartNew(); - var process = new Process + // Build volume mounts and determine the effective file path in the scanner container + var tempFolderPath = Path.GetFullPath(_localStorageOptions.TempFolder); + ValidatePath(tempFolderPath, nameof(tempFolderPath)); + var (volumeMounts, effectiveFilePath) = BuildVolumeMountsAndPath(tempFolderPath, filePath, image); + + // For unified scanner, also mount ClamAV definitions from shared volume + if (image == UnifiedScannerImage) + { + ValidateVolumeName(ClamavVolumeName, nameof(ClamavVolumeName)); + volumeMounts += $" -v {ClamavVolumeName}:/var/lib/clamav:ro"; + } + + // Build the effective command based on scanner type and volume configuration + string effectiveCommand; + if (image == UnifiedScannerImage) + { + // Unified scanner: pass file path directly to scan-wrapper.py + effectiveCommand = effectiveFilePath; + } + else + { + // Legacy scanner: replace placeholder path with effective path + // Commands like "clamscan /data/model.in" become 
"clamscan /shared-temp/file.ckpt" + effectiveCommand = command.Replace(InPath, effectiveFilePath); + } + + using var process = new Process { - StartInfo = new ProcessStartInfo("docker", $"run -v {Path.GetFullPath(filePath)}:{InPath} -v {Path.GetFullPath(_localStorageOptions.TempFolder)}:{OutFolderPath} --rm civitai-model-scanner {command}") + StartInfo = new ProcessStartInfo("docker", $"run {volumeMounts} --rm --memory=2g {image} {effectiveCommand}") { CreateNoWindow = true, WindowStyle = ProcessWindowStyle.Hidden, @@ -37,32 +170,45 @@ public DockerService(ILogger logger, IOptions - outputBuilder.Append(e.Data); + { + if (e.Data != null) stdOutBuilder.AppendLine(e.Data); + }; process.ErrorDataReceived += (_, e) => - outputBuilder.Append(e.Data); + { + if (e.Data != null) stdErrBuilder.AppendLine(e.Data); + }; process.Start(); process.BeginOutputReadLine(); process.BeginErrorReadLine(); + int exitCode; try { await process.WaitForExitAsync(cancellationToken); + exitCode = process.ExitCode; } catch (TaskCanceledException) { - process.Kill(); // Ensure that we abort the docker process when we cancel quickly + process.Kill(entireProcessTree: true); // Ensure that we abort the docker process and children when we cancel throw; } - var output = outputBuilder.ToString(); + var stdOut = stdOutBuilder.ToString(); + var stdErr = stdErrBuilder.ToString(); + + _logger.LogInformation("Executed {command} in {image} for file {filePath} completed with exit code {exitCode} in {elapsed}", command, image, filePath, exitCode, stopwatch.Elapsed); + + if (!string.IsNullOrWhiteSpace(stdOut)) + _logger.LogInformation("StdOut: {stdOut}", stdOut); - _logger.LogInformation("Executed {command} for file {filePath} completed with exit code {exitCode} in {elapsed}", command, filePath, process.ExitCode, stopwatch.Elapsed); - _logger.LogInformation(output); + if (!string.IsNullOrWhiteSpace(stdErr)) + _logger.LogWarning("StdErr: {stdErr}", stdErr); - return (process.ExitCode, output); + return 
(exitCode, stdOut, stdErr); } } diff --git a/orchestrator/ModelScanner/FileProcessor.cs b/orchestrator/ModelScanner/FileProcessor.cs index e1435a4..5827ce1 100644 --- a/orchestrator/ModelScanner/FileProcessor.cs +++ b/orchestrator/ModelScanner/FileProcessor.cs @@ -112,7 +112,23 @@ public async Task ReportFileAsync(string callbackUrl, ScanResult result, Cancell using var httpClient = new HttpClient(); _logger.LogInformation("Invoking {callbackUrl} with result {result}", callbackUrl, result); - var response = await httpClient.PostAsJsonAsync(callbackUrl, result, cancellationToken); + + // Explicitly serialize to JSON for debugging + string jsonPayload; + try + { + jsonPayload = System.Text.Json.JsonSerializer.Serialize(result); + _logger.LogInformation("JSON payload length: {length}", jsonPayload.Length); + } + catch (Exception ex) + { + _logger.LogError(ex, "Failed to serialize ScanResult to JSON"); + throw; + } + + // Send with explicit content + var content = new StringContent(jsonPayload, Encoding.UTF8, "application/json"); + var response = await httpClient.PostAsync(callbackUrl, content, cancellationToken); response.EnsureSuccessStatusCode(); } diff --git a/orchestrator/ModelScanner/Program.cs b/orchestrator/ModelScanner/Program.cs index eefac6e..98932cc 100644 --- a/orchestrator/ModelScanner/Program.cs +++ b/orchestrator/ModelScanner/Program.cs @@ -8,9 +8,13 @@ using ModelScanner.Tasks; using System.Security.Claims; +// Import ScannerOptions from Tasks namespace +using ScannerOptions = ModelScanner.Tasks.ScannerOptions; + var builder = WebApplication.CreateBuilder(args); builder.Services.AddSingleton(); builder.Services.AddSingleton(); +builder.Services.AddSingleton(); builder.Services.AddSingleton(); builder.Services.AddSingleton(); builder.Services.AddSingleton(); @@ -28,6 +32,9 @@ .ValidateDataAnnotations() .ValidateOnStart(); +builder.Services.AddOptions() + .BindConfiguration(nameof(ScannerOptions)); + var connectionString = 
builder.Configuration.GetConnectionString("JobStorage"); if (string.IsNullOrWhiteSpace(connectionString)) { @@ -131,6 +138,23 @@ return Results.Accepted(); }); +// Shadow mode metrics endpoints +app.MapGet("/metrics/shadow", (ShadowModeMetrics metrics) => +{ + return Results.Json(metrics.GetSummary()); +}); + +app.MapGet("/metrics/shadow/full", (ShadowModeMetrics metrics) => +{ + return Results.Json(metrics.GetMetrics()); +}); + +app.MapPost("/metrics/shadow/reset", (ShadowModeMetrics metrics) => +{ + metrics.Reset(); + return Results.Ok("Metrics reset"); +}); + #pragma warning disable ASP0014 // Hangfire dashboard is not compatible with top level routing app.UseRouting(); app.UseEndpoints(routes => diff --git a/orchestrator/ModelScanner/ScanResult.cs b/orchestrator/ModelScanner/ScanResult.cs index ebb76a2..3873134 100644 --- a/orchestrator/ModelScanner/ScanResult.cs +++ b/orchestrator/ModelScanner/ScanResult.cs @@ -1,4 +1,4 @@ -using System.Globalization; +using System.Globalization; using System.Text.Json; using System.Text.Json.Serialization; @@ -8,14 +8,23 @@ public record Conversion(string? Url, Dictionary? Hashes, string [JsonPropertyName("sizeKB")] public double? SizeKB { get; set; } } - + public required string Url { get; set; } public int FileExists { get; set; } + + // Legacy picklescan fields (maintained for backwards compatibility) public int PicklescanExitCode { get; set; } public string? PicklescanOutput { get; set; } public HashSet? PicklescanGlobalImports { get; set; } public HashSet? PicklescanDangerousImports { get; set; } + + // TensorTrap scanner fields + public bool TensorTrapScanned { get; set; } + public string? TensorTrapMaxSeverity { get; set; } + public bool TensorTrapIsSafe { get; set; } + public List? TensorTrapFindings { get; set; } + public Dictionary Conversions { get; set; } = new(); public Dictionary Hashes { get; set; } = new(); public JsonDocument? Metadata { get; set; } @@ -23,3 +32,24 @@ public record Conversion(string? 
Url, Dictionary? Hashes, string public string? ClamscanOutput { get; set; } public HashSet? Fixed { get; set; } } + +/// +/// Represents a finding from TensorTrap ML security scanner. +/// +public record TensorTrapFinding +{ + [JsonPropertyName("severity")] + public string? Severity { get; init; } + + [JsonPropertyName("message")] + public string? Message { get; init; } + + [JsonPropertyName("location")] + public int? Location { get; init; } + + [JsonPropertyName("details")] + public Dictionary? Details { get; init; } + + [JsonPropertyName("recommendation")] + public string? Recommendation { get; init; } +} diff --git a/orchestrator/ModelScanner/ShadowModeMetrics.cs b/orchestrator/ModelScanner/ShadowModeMetrics.cs new file mode 100644 index 0000000..d18cb65 --- /dev/null +++ b/orchestrator/ModelScanner/ShadowModeMetrics.cs @@ -0,0 +1,250 @@ +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace ModelScanner; + +/// +/// Tracks shadow mode comparison metrics between legacy and unified scanners. +/// Persists to a JSON file and exposes via HTTP endpoint. +/// +public class ShadowModeMetrics : IDisposable +{ + private readonly ILogger _logger; + private readonly string _metricsFilePath; + private readonly object _lock = new(); + private readonly Timer _saveTimer; + private MetricsData _data; + private bool _isDirty; + private bool _disposed; + + public ShadowModeMetrics(ILogger logger, IConfiguration configuration) + { + _logger = logger; + _metricsFilePath = configuration.GetValue("ScannerOptions:MetricsFilePath") + ?? 
"./shadow-mode-metrics.json"; + _data = LoadMetrics(); + + // Save metrics every 30 seconds if dirty, instead of on every update + _saveTimer = new Timer(_ => FlushIfDirty(), null, TimeSpan.FromSeconds(30), TimeSpan.FromSeconds(30)); + } + + private void FlushIfDirty() + { + lock (_lock) + { + if (_isDirty) + { + SaveMetricsInternal(); + _isDirty = false; + } + } + } + + public void RecordMatch(string filePath, bool bothSafe) + { + lock (_lock) + { + _data.TotalScans++; + _data.Matches++; + if (bothSafe) + _data.BothSafe++; + else + _data.BothDangerous++; + _data.LastUpdated = DateTime.UtcNow; + _isDirty = true; + } + } + + public void RecordDiscrepancy(string filePath, DiscrepancyType type, string? details = null) + { + lock (_lock) + { + _data.TotalScans++; + _data.Discrepancies++; + + switch (type) + { + case DiscrepancyType.UnifiedFoundMore: + _data.UnifiedFoundMoreThreats++; + break; + case DiscrepancyType.LegacyFoundMore: + _data.LegacyFoundMoreThreats++; + break; + } + + // Keep last 100 discrepancies for review + _data.RecentDiscrepancies.Add(new DiscrepancyRecord + { + Timestamp = DateTime.UtcNow, + FilePath = Path.GetFileName(filePath), + Type = type.ToString(), + Details = details + }); + + if (_data.RecentDiscrepancies.Count > 100) + { + _data.RecentDiscrepancies.RemoveAt(0); + } + + _data.LastUpdated = DateTime.UtcNow; + _isDirty = true; + } + } + + public void RecordError(string filePath, string scanner, string error) + { + lock (_lock) + { + _data.TotalScans++; + _data.Errors++; + _data.LastUpdated = DateTime.UtcNow; + _isDirty = true; + } + } + + public MetricsData GetMetrics() + { + lock (_lock) + { + // Deep copy to avoid race condition - the list must be copied too + return _data with + { + RecentDiscrepancies = new List(_data.RecentDiscrepancies) + }; + } + } + + public MetricsSummary GetSummary() + { + lock (_lock) + { + var agreementRate = _data.TotalScans > 0 + ? 
(double)_data.Matches / _data.TotalScans * 100 + : 0; + + return new MetricsSummary + { + TotalScans = _data.TotalScans, + Matches = _data.Matches, + Discrepancies = _data.Discrepancies, + AgreementRate = Math.Round(agreementRate, 2), + UnifiedFoundMoreThreats = _data.UnifiedFoundMoreThreats, + LegacyFoundMoreThreats = _data.LegacyFoundMoreThreats, + BothSafe = _data.BothSafe, + BothDangerous = _data.BothDangerous, + Errors = _data.Errors, + LastUpdated = _data.LastUpdated, + RecentDiscrepancies = _data.RecentDiscrepancies.TakeLast(10).ToList() + }; + } + } + + public void Reset() + { + lock (_lock) + { + _data = new MetricsData(); + _isDirty = true; + } + } + + public void Dispose() + { + if (_disposed) return; + _disposed = true; + + _saveTimer.Dispose(); + + // Final flush on dispose + lock (_lock) + { + if (_isDirty) + { + SaveMetricsInternal(); + } + } + } + + private MetricsData LoadMetrics() + { + try + { + if (File.Exists(_metricsFilePath)) + { + var json = File.ReadAllText(_metricsFilePath); + return JsonSerializer.Deserialize(json) ?? 
new MetricsData(); + } + } + catch (Exception ex) + { + _logger.LogWarning(ex, "Failed to load metrics from {path}, starting fresh", _metricsFilePath); + } + return new MetricsData(); + } + + private void SaveMetricsInternal() + { + try + { + // Use atomic write: write to temp file, then rename (atomic on most filesystems) + var tempPath = _metricsFilePath + ".tmp"; + var json = JsonSerializer.Serialize(_data, new JsonSerializerOptions { WriteIndented = true }); + File.WriteAllText(tempPath, json); + File.Move(tempPath, _metricsFilePath, overwrite: true); + } + catch (Exception ex) + { + _logger.LogError(ex, "Failed to save metrics to {path}", _metricsFilePath); + } + } +} + +public enum DiscrepancyType +{ + UnifiedFoundMore, // TensorTrap found threat, legacy said safe + LegacyFoundMore // Legacy found threat, TensorTrap said safe +} + +public record MetricsData +{ + public int TotalScans { get; set; } + public int Matches { get; set; } + public int Discrepancies { get; set; } + public int UnifiedFoundMoreThreats { get; set; } + public int LegacyFoundMoreThreats { get; set; } + public int BothSafe { get; set; } + public int BothDangerous { get; set; } + public int Errors { get; set; } + public DateTime LastUpdated { get; set; } = DateTime.UtcNow; + public List RecentDiscrepancies { get; set; } = new(); +} + +public record DiscrepancyRecord +{ + public DateTime Timestamp { get; init; } + public string? FilePath { get; init; } + public string? Type { get; init; } + public string? 
Details { get; init; } +} + +public record MetricsSummary +{ + public int TotalScans { get; init; } + public int Matches { get; init; } + public int Discrepancies { get; init; } + public double AgreementRate { get; init; } + public int UnifiedFoundMoreThreats { get; init; } + public int LegacyFoundMoreThreats { get; init; } + public int BothSafe { get; init; } + public int BothDangerous { get; init; } + public int Errors { get; init; } + public DateTime LastUpdated { get; init; } + public List RecentDiscrepancies { get; init; } = new(); + + [JsonIgnore] + public string Recommendation => UnifiedFoundMoreThreats > LegacyFoundMoreThreats && Discrepancies > 0 + ? "Unified scanner is finding MORE threats - safe to migrate" + : LegacyFoundMoreThreats > UnifiedFoundMoreThreats && Discrepancies > 0 + ? "WARNING: Legacy finding threats unified misses - investigate before migrating" + : "Scanners agree - migration should be safe"; +} diff --git a/orchestrator/ModelScanner/Tasks/ScanTask.cs b/orchestrator/ModelScanner/Tasks/ScanTask.cs index 886a085..82f79b9 100644 --- a/orchestrator/ModelScanner/Tasks/ScanTask.cs +++ b/orchestrator/ModelScanner/Tasks/ScanTask.cs @@ -1,98 +1,498 @@ -using System.Diagnostics; +using System.Diagnostics; +using System.Text.Json; using System.Text.RegularExpressions; using System.Text; -using Microsoft.AspNetCore.Mvc.ModelBinding.Binders; +using Microsoft.Extensions.Options; namespace ModelScanner.Tasks; class ScanTask : IJobTask { readonly DockerService _dockerService; + readonly ILogger _logger; + readonly ScannerOptions _options; + readonly ShadowModeMetrics _metrics; - public ScanTask(DockerService dockerService) + public ScanTask(DockerService dockerService, ILogger logger, IOptions options, ShadowModeMetrics metrics) { _dockerService = dockerService; + _logger = logger; + _options = options.Value; + _metrics = metrics; } public JobTaskTypes TaskType => JobTaskTypes.Scan; public async Task Process(string filePath, ScanResult result, 
CancellationToken cancellationToken) { - var fileExtension = Path.GetExtension(filePath); + if (_options.ShadowMode) + { + // Shadow mode: run both scanners, compare results, but use legacy for actual result + await RunShadowModeComparison(filePath, result, cancellationToken); + } + else if (_options.UseUnifiedScanner) + { + // Production mode with unified scanner + await RunUnifiedScan(filePath, result, cancellationToken); + } + else + { + // Legacy mode: use old picklescan + clamscan + await RunLegacyScan(filePath, result, cancellationToken); + } - await RunClamScan(result); - await RunPickleScan(result); - return true; + } + + async Task RunShadowModeComparison(string filePath, ScanResult result, CancellationToken cancellationToken) + { + var fileExtension = Path.GetExtension(filePath); + + // Run both scanners in parallel + var legacyTask = RunLegacyScanInternal(filePath, fileExtension, cancellationToken); + var unifiedTask = RunUnifiedScanInternal(filePath, cancellationToken); + + ScanResultInternal legacyResult; + ScanResultInternal unifiedResult; + + try + { + await Task.WhenAll(legacyTask, unifiedTask); + legacyResult = await legacyTask; + unifiedResult = await unifiedTask; + } + catch (Exception ex) + { + _logger.LogError(ex, "[SHADOW-ERROR] One or both scanner tasks failed for {filePath}", filePath); + + // Try to get results from whichever task succeeded + // Use IsCompletedSuccessfully to ensure task completed without faulting or cancellation + legacyResult = legacyTask.IsCompletedSuccessfully + ? await legacyTask + : new ScanResultInternal + { + PicklescanExitCode = -1, + PicklescanOutput = $"Scanner error: {ex.Message}", + ClamscanExitCode = -1 + }; + + unifiedResult = unifiedTask.IsCompletedSuccessfully + ? 
await unifiedTask + : new ScanResultInternal { TensorTrapScanned = false, ClamscanExitCode = -1 }; - async Task RunPickleScan(ScanResult result) + _metrics.RecordError(filePath, "parallel-scan", ex.Message); + } + + // Use legacy result for actual webhook response (safe during shadow period) + result.PicklescanExitCode = legacyResult.PicklescanExitCode; + result.PicklescanOutput = legacyResult.PicklescanOutput; + result.PicklescanGlobalImports = legacyResult.PicklescanGlobalImports; + result.PicklescanDangerousImports = legacyResult.PicklescanDangerousImports; + result.ClamscanExitCode = legacyResult.ClamscanExitCode; + result.ClamscanOutput = legacyResult.ClamscanOutput; + + // Also include TensorTrap results for visibility + result.TensorTrapScanned = unifiedResult.TensorTrapScanned; + result.TensorTrapMaxSeverity = unifiedResult.TensorTrapMaxSeverity; + result.TensorTrapIsSafe = unifiedResult.TensorTrapIsSafe; + result.TensorTrapFindings = unifiedResult.TensorTrapFindings; + + // Compare and log discrepancies (only if both scans completed) + if (legacyResult.PicklescanExitCode >= 0 && unifiedResult.TensorTrapScanned) + { + CompareAndLogResults(filePath, legacyResult, unifiedResult); + } + } + + void CompareAndLogResults(string filePath, ScanResultInternal legacy, ScanResultInternal unified) + { + // Use > 0 to distinguish "dangerous" from "error" (-1) exit codes + var legacyDangerous = legacy.PicklescanExitCode > 0 || legacy.ClamscanExitCode > 0; + var unifiedDangerous = unified.TensorTrapMaxSeverity is "CRITICAL" or "HIGH" || unified.ClamscanExitCode > 0; + + if (legacyDangerous != unifiedDangerous) { - // safetensors are safe... 
- if (fileExtension.EndsWith("safetensors", StringComparison.OrdinalIgnoreCase)) + // Log what TensorTrap found that picklescan missed (or vice versa) + if (unifiedDangerous && !legacyDangerous) { - result.PicklescanExitCode = 0; - result.PicklescanOutput = "safetensors"; - // TODO Improve Pickle Scan: It probably makes sense to verify that this is indeed a safetensor file - return; + var details = $"TensorTrap: {unified.TensorTrapMaxSeverity}, Findings: {unified.TensorTrapFindings?.Count ?? 0}"; + _metrics.RecordDiscrepancy(filePath, DiscrepancyType.UnifiedFoundMore, details); + + // Log limited findings to avoid excessive log size + var limitedFindings = unified.TensorTrapFindings? + .Take(3) + .Select(f => new { f.Severity, f.Message }) + .ToList(); + _logger.LogWarning( + "[SHADOW-NEW-DETECTION] TensorTrap detected threat missed by legacy scanner: {Findings}", + JsonSerializer.Serialize(limitedFindings) + ); } + else if (legacyDangerous && !unifiedDangerous) + { + var imports = string.Join(", ", legacy.PicklescanDangerousImports ?? new HashSet()); + _metrics.RecordDiscrepancy(filePath, DiscrepancyType.LegacyFoundMore, $"DangerousImports: {imports}"); + + _logger.LogWarning( + "[SHADOW-MISSED-DETECTION] Legacy detected threat not found by TensorTrap: DangerousImports={Imports}", + imports + ); + } + } + else + { + _metrics.RecordMatch(filePath, bothSafe: !legacyDangerous); - var (exitCode, output) = await _dockerService.RunCommandInDocker($"picklescan -p {DockerService.InPath} -l DEBUG", filePath, cancellationToken); + _logger.LogInformation( + "[SHADOW-MATCH] File: {FilePath} | Both scanners agree: {Result}", + filePath, + legacyDangerous ? 
"DANGEROUS" : "SAFE" + ); + } + } + + async Task RunLegacyScanInternal(string filePath, string fileExtension, CancellationToken cancellationToken) + { + var result = new ScanResultInternal(); - result.PicklescanExitCode = exitCode; - result.PicklescanOutput = output; - result.PicklescanGlobalImports = ParseGlobalImports(output); - result.PicklescanDangerousImports = ParseDangerousImports(output); + // ClamAV scan + var (clamExitCode, clamOutput) = await _dockerService.RunCommandInDocker( + $"clamscan {DockerService.InPath}", filePath, cancellationToken); + result.ClamscanExitCode = clamExitCode; + result.ClamscanOutput = clamOutput; - HashSet ParseGlobalImports(string? picklescanOutput) + // Picklescan (with legacy safetensors skip) + if (fileExtension.EndsWith("safetensors", StringComparison.OrdinalIgnoreCase)) + { + result.PicklescanExitCode = 0; + result.PicklescanOutput = "safetensors"; + } + else + { + var (pickleExitCode, pickleOutput) = await _dockerService.RunCommandInDocker( + $"picklescan -p {DockerService.InPath} -l DEBUG", filePath, cancellationToken); + result.PicklescanExitCode = pickleExitCode; + result.PicklescanOutput = pickleOutput; + result.PicklescanGlobalImports = ParseGlobalImports(pickleOutput); + result.PicklescanDangerousImports = ParseDangerousImports(pickleOutput); + } + + return result; + } + + async Task RunUnifiedScanInternal(string filePath, CancellationToken cancellationToken) + { + var result = new ScanResultInternal(); + + // Use dynamic container path to preserve file extension for TensorTrap + var containerPath = DockerService.GetContainerPath(filePath); + + var (exitCode, stdOut, stdErr) = await _dockerService.RunCommandInDocker( + DockerService.UnifiedScannerImage, + containerPath, + filePath, + cancellationToken + ); + + UnifiedScanOutput? 
scanOutput = null; + try + { + if (!string.IsNullOrWhiteSpace(stdOut)) { - var result = new HashSet(); + scanOutput = JsonSerializer.Deserialize(stdOut.Trim()); + } + } + catch (JsonException ex) + { + _logger.LogError(ex, "Failed to parse unified scanner JSON: {output}", stdOut); + _metrics.RecordError(filePath, "unified-json-parse", ex.Message); + // Mark as scanned but with error - allows shadow mode comparison to continue + result.TensorTrapScanned = false; + result.PicklescanOutput = stdOut; // Preserve raw output for debugging + result.PicklescanExitCode = -1; + } - if (picklescanOutput is not null) - { - const string globalImportListsRegex = """Global imports in (?:.+): {(.+)}"""; - - foreach (Match globalImportListMatch in Regex.Matches(picklescanOutput, globalImportListsRegex)) - { - var globalImportList = globalImportListMatch.Groups[1]; - const string globalImportsRegex = """\((.+?)\)"""; - - foreach (Match globalImportMatch in Regex.Matches(globalImportList.Value, globalImportsRegex)) - { - result.Add(globalImportMatch.Groups[1].Value); - } - } - } + if (scanOutput?.FirstResult != null) + { + result.TensorTrapScanned = true; + result.TensorTrapMaxSeverity = scanOutput.FirstResult.MaxSeverity; + result.TensorTrapIsSafe = scanOutput.FirstResult.IsSafe; + result.TensorTrapFindings = scanOutput.FirstResult.Findings; + } - return result; + if (scanOutput?.Clamav != null) + { + result.ClamscanExitCode = scanOutput.Clamav.ExitCode; + result.ClamscanOutput = scanOutput.Clamav.Output; + } + + return result; + } + + async Task RunLegacyScan(string filePath, ScanResult result, CancellationToken cancellationToken) + { + var fileExtension = Path.GetExtension(filePath); + var legacyResult = await RunLegacyScanInternal(filePath, fileExtension, cancellationToken); + + result.PicklescanExitCode = legacyResult.PicklescanExitCode; + result.PicklescanOutput = legacyResult.PicklescanOutput; + result.PicklescanGlobalImports = legacyResult.PicklescanGlobalImports; + 
result.PicklescanDangerousImports = legacyResult.PicklescanDangerousImports; + result.ClamscanExitCode = legacyResult.ClamscanExitCode; + result.ClamscanOutput = legacyResult.ClamscanOutput; + } + + async Task RunUnifiedScan(string filePath, ScanResult result, CancellationToken cancellationToken) + { + // Use dynamic container path to preserve file extension for TensorTrap + var containerPath = DockerService.GetContainerPath(filePath); + + var (exitCode, stdOut, stdErr) = await _dockerService.RunCommandInDocker( + DockerService.UnifiedScannerImage, + containerPath, + filePath, + cancellationToken + ); + + if (!string.IsNullOrWhiteSpace(stdErr)) + { + _logger.LogWarning("Unified scanner stderr: {stdErr}", stdErr); + } + + UnifiedScanOutput? scanOutput = null; + bool jsonParseFailed = false; + try + { + if (!string.IsNullOrWhiteSpace(stdOut)) + { + scanOutput = JsonSerializer.Deserialize(stdOut.Trim()); } + } + catch (JsonException ex) + { + _logger.LogError(ex, "Failed to parse unified scanner JSON output: {output}", stdOut); + _metrics.RecordError(filePath, "unified-json-parse", ex.Message); + jsonParseFailed = true; + } + + if (scanOutput?.FirstResult != null) + { + result.TensorTrapScanned = true; + result.TensorTrapMaxSeverity = scanOutput.FirstResult.MaxSeverity; + result.TensorTrapIsSafe = scanOutput.FirstResult.IsSafe; + result.TensorTrapFindings = scanOutput.FirstResult.Findings; + + var hasCriticalOrHigh = scanOutput.FirstResult.MaxSeverity is "CRITICAL" or "HIGH"; + result.PicklescanExitCode = hasCriticalOrHigh ? 1 : 0; + result.PicklescanOutput = stdOut; + result.PicklescanDangerousImports = ExtractDangerousImports(scanOutput.FirstResult.Findings); + result.PicklescanGlobalImports = ExtractGlobalImports(scanOutput.FirstResult.Findings); + } + else + { + result.TensorTrapScanned = false; + // Use exit code -1 for parse failures to indicate error vs clean exit + result.PicklescanExitCode = jsonParseFailed ? 
-1 : exitCode; + result.PicklescanOutput = stdOut + stdErr; + } + + if (scanOutput?.Clamav != null) + { + result.ClamscanExitCode = scanOutput.Clamav.ExitCode; + result.ClamscanOutput = scanOutput.Clamav.Output; + } + else + { + result.ClamscanExitCode = exitCode; + result.ClamscanOutput = stdErr; + } + } - HashSet ParseDangerousImports(string? picklescanOutput) + static HashSet ParseGlobalImports(string? picklescanOutput) + { + var result = new HashSet(); + if (picklescanOutput is null) return result; + + const string globalImportListsRegex = """Global imports in (?:.+): {(.+)}"""; + foreach (Match globalImportListMatch in Regex.Matches(picklescanOutput, globalImportListsRegex)) + { + var globalImportList = globalImportListMatch.Groups[1]; + const string globalImportsRegex = """\((.+?)\)"""; + foreach (Match globalImportMatch in Regex.Matches(globalImportList.Value, globalImportsRegex)) { - var result = new HashSet(); + result.Add(globalImportMatch.Groups[1].Value); + } + } + return result; + } + + static HashSet ParseDangerousImports(string? picklescanOutput) + { + var result = new HashSet(); + if (picklescanOutput is null) return result; + + const string dangerousImportsRegex = """dangerous import '(.+)'"""; + foreach (Match match in Regex.Matches(picklescanOutput, dangerousImportsRegex)) + { + result.Add(match.Groups[1].Value); + } + return result; + } - if (picklescanOutput is not null) + static HashSet ExtractDangerousImports(List? 
findings) + { + var imports = new HashSet(); + if (findings == null) return imports; + + foreach (var finding in findings) + { + if (finding.Severity is "CRITICAL" or "HIGH" && finding.Details != null) + { + if (finding.Details.TryGetValue("module", out var moduleEl) && + finding.Details.TryGetValue("function", out var functionEl)) { - const string dangerousImportsRegex = """dangerous import '(.+)'"""; - var dangerousImportMatches = Regex.Matches(picklescanOutput, dangerousImportsRegex); - - foreach (Match dangerousImporMatch in dangerousImportMatches) - { - var dangerousImport = dangerousImporMatch.Groups[1]; - result.Add(dangerousImport.Value); - } + var module = moduleEl.GetString(); + var function = functionEl.GetString(); + if (module != null && function != null) + imports.Add($"{module}.{function}"); + } + else if (finding.Details.TryGetValue("import", out var importEl)) + { + var import = importEl.GetString(); + if (import != null) + imports.Add(import); } - - return result; } } + return imports; + } - async Task RunClamScan(ScanResult result) - { - var (exitCode, output) = await _dockerService.RunCommandInDocker($"clamscan {DockerService.InPath}", filePath, cancellationToken); + static HashSet ExtractGlobalImports(List? findings) + { + var imports = new HashSet(); + if (findings == null) return imports; - result.ClamscanExitCode = exitCode; - result.ClamscanOutput = output; + foreach (var finding in findings) + { + if (finding.Details != null) + { + if (finding.Details.TryGetValue("module", out var moduleEl) && + finding.Details.TryGetValue("function", out var functionEl)) + { + var module = moduleEl.GetString(); + var function = functionEl.GetString(); + if (module != null && function != null) + imports.Add($"{module}, {function}"); + } + } } + return imports; + } + // Internal result class for shadow mode comparison + class ScanResultInternal + { + public int PicklescanExitCode { get; set; } + public string? 
PicklescanOutput { get; set; } + public HashSet? PicklescanGlobalImports { get; set; } + public HashSet? PicklescanDangerousImports { get; set; } + public int ClamscanExitCode { get; set; } + public string? ClamscanOutput { get; set; } + public bool TensorTrapScanned { get; set; } + public string? TensorTrapMaxSeverity { get; set; } + public bool TensorTrapIsSafe { get; set; } + public List? TensorTrapFindings { get; set; } } } + +// Configuration options for scanner +public class ScannerOptions +{ + /// + /// Use the unified scanner (TensorTrap + ClamAV) instead of legacy picklescan. + /// + public bool UseUnifiedScanner { get; set; } = false; + + /// + /// Shadow mode: run both scanners, compare results, but use legacy for actual response. + /// Logs discrepancies for analysis. + /// + public bool ShadowMode { get; set; } = false; +} + +// JSON deserialization types for unified scanner output +record UnifiedScanOutput +{ + [System.Text.Json.Serialization.JsonPropertyName("file")] + public string? File { get; init; } + + [System.Text.Json.Serialization.JsonPropertyName("tensortrap")] + public TensorTrapWrapper? TensorTrap { get; init; } + + [System.Text.Json.Serialization.JsonPropertyName("clamav")] + public ClamavOutput? Clamav { get; init; } + + // Helper to get the first (and typically only) file result + public TensorTrapFileResult? FirstResult => TensorTrap?.Results?.FirstOrDefault(); +} + +// TensorTrap returns { results: [...], summary: {...} } +record TensorTrapWrapper +{ + [System.Text.Json.Serialization.JsonPropertyName("results")] + public List? Results { get; init; } + + [System.Text.Json.Serialization.JsonPropertyName("summary")] + public TensorTrapSummary? Summary { get; init; } +} + +record TensorTrapFileResult +{ + [System.Text.Json.Serialization.JsonPropertyName("filepath")] + public string? FilePath { get; init; } + + [System.Text.Json.Serialization.JsonPropertyName("format")] + public string? 
Format { get; init; }

    [System.Text.Json.Serialization.JsonPropertyName("is_safe")]
    public bool IsSafe { get; init; }

    [System.Text.Json.Serialization.JsonPropertyName("max_severity")]
    public string? MaxSeverity { get; init; }

    // NOTE(review): the generic argument was stripped from the diff text by markup
    // extraction ("List?"); reconstructed as the finding DTO — confirm the exact
    // type name against the repository.
    [System.Text.Json.Serialization.JsonPropertyName("findings")]
    public List<TensorTrapFinding>? Findings { get; init; }

    [System.Text.Json.Serialization.JsonPropertyName("scan_time_ms")]
    public double ScanTimeMs { get; init; }

    [System.Text.Json.Serialization.JsonPropertyName("file_size")]
    public long FileSize { get; init; }

    [System.Text.Json.Serialization.JsonPropertyName("file_hash")]
    public string? FileHash { get; init; }
}

record TensorTrapSummary
{
    [System.Text.Json.Serialization.JsonPropertyName("total_files")]
    public int TotalFiles { get; init; }

    [System.Text.Json.Serialization.JsonPropertyName("safe_files")]
    public int SafeFiles { get; init; }

    [System.Text.Json.Serialization.JsonPropertyName("unsafe_files")]
    public int UnsafeFiles { get; init; }
}

record ClamavOutput
{
    [System.Text.Json.Serialization.JsonPropertyName("exit_code")]
    public int ExitCode { get; init; }

    [System.Text.Json.Serialization.JsonPropertyName("output")]
    public string?
Output { get; init; } + + [System.Text.Json.Serialization.JsonPropertyName("infected")] + public bool Infected { get; init; } +} diff --git a/orchestrator/ModelScanner/appsettings.json b/orchestrator/ModelScanner/appsettings.json index 43f6712..f128f9f 100644 --- a/orchestrator/ModelScanner/appsettings.json +++ b/orchestrator/ModelScanner/appsettings.json @@ -15,6 +15,10 @@ "SecretKey": null, "UploadBucket": null }, + "ScannerOptions": { + "UseUnifiedScanner": false, + "ShadowMode": true + }, "ValidTokens": [ "wBwmoQFpp592A0pSoYkb" ] diff --git a/unified-scanner/Dockerfile b/unified-scanner/Dockerfile new file mode 100644 index 0000000..92ef0d0 --- /dev/null +++ b/unified-scanner/Dockerfile @@ -0,0 +1,19 @@ +FROM python:3.11-slim + +# Install ClamAV and git (git needed for pip install from GitHub) +RUN apt-get update && \ + apt-get install -y --no-install-recommends clamav ca-certificates git && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +# Install TensorTrap from fork +RUN pip install --no-cache-dir git+https://github.com/JustMaier/TensorTrap.git + +# Copy wrapper script +COPY scan-wrapper.py /usr/local/bin/ +RUN chmod +x /usr/local/bin/scan-wrapper.py + +# ClamAV definitions will be mounted at /var/lib/clamav from shared volume +# This allows the clamav-updater sidecar to keep definitions fresh + +ENTRYPOINT ["python", "/usr/local/bin/scan-wrapper.py"] diff --git a/unified-scanner/Dockerfile.local b/unified-scanner/Dockerfile.local new file mode 100644 index 0000000..0b38fbb --- /dev/null +++ b/unified-scanner/Dockerfile.local @@ -0,0 +1,18 @@ +FROM python:3.11-slim + +# Install ClamAV +RUN apt-get update && \ + apt-get install -y --no-install-recommends clamav ca-certificates && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +# Install TensorTrap from local source (using additional_contexts) +COPY --from=tensortrap . 
/tmp/TensorTrap/ +RUN pip install --no-cache-dir /tmp/TensorTrap && \ + rm -rf /tmp/TensorTrap + +# Copy wrapper script +COPY scan-wrapper.py /usr/local/bin/ +RUN chmod +x /usr/local/bin/scan-wrapper.py + +ENTRYPOINT ["python", "/usr/local/bin/scan-wrapper.py"] diff --git a/unified-scanner/scan-wrapper.py b/unified-scanner/scan-wrapper.py new file mode 100644 index 0000000..9f67c5e --- /dev/null +++ b/unified-scanner/scan-wrapper.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python3 +""" +Unified scanner wrapper that runs both TensorTrap and ClamAV, +outputting combined results as JSON to stdout. +""" +import json +import subprocess +import sys +from pathlib import Path + + +def run_tensortrap(filepath: str) -> dict: + """Run TensorTrap ML security scanner.""" + try: + result = subprocess.run( + ["tensortrap", "scan", filepath, "-j", "--no-report"], + capture_output=True, + text=True, + timeout=300 # 5 minute timeout + ) + + try: + data = json.loads(result.stdout) + # Handle single file result (TensorTrap returns a list for directory scans) + if isinstance(data, list) and len(data) == 1: + data = data[0] + return data + except json.JSONDecodeError: + return { + "error": "Failed to parse TensorTrap output", + "stderr": result.stderr, + "stdout": result.stdout, + "exit_code": result.returncode + } + except subprocess.TimeoutExpired: + return {"error": "TensorTrap scan timed out", "exit_code": -1} + except Exception as e: + return {"error": str(e), "exit_code": -1} + + +def run_clamav(filepath: str) -> dict: + """Run ClamAV antivirus scanner.""" + try: + result = subprocess.run( + ["clamscan", "--no-summary", filepath], + capture_output=True, + text=True, + timeout=300 # 5 minute timeout + ) + return { + "exit_code": result.returncode, + "output": result.stdout.strip(), + "infected": result.returncode == 1 + } + except subprocess.TimeoutExpired: + return {"error": "ClamAV scan timed out", "exit_code": -1, "infected": False} + except Exception as e: + return {"error": str(e), 
"exit_code": -1, "infected": False} + + +def main(): + filepath = sys.argv[1] if len(sys.argv) > 1 else "/data/model.in" + + # Run both scanners + tensortrap_result = run_tensortrap(filepath) + clamav_result = run_clamav(filepath) + + # Build combined output + output = { + "file": filepath, + "tensortrap": tensortrap_result, + "clamav": clamav_result + } + + # Output JSON to stdout + print(json.dumps(output, indent=None)) + + # Determine exit code + # TensorTrap: check for errors first, then max_severity for CRITICAL or HIGH + tt_error = "error" in tensortrap_result + tt_max_severity = tensortrap_result.get("max_severity", "") + tt_unsafe = tt_max_severity in ("CRITICAL", "HIGH") + + # ClamAV: check for errors first, then infected flag + clam_error = "error" in clamav_result + clam_infected = clamav_result.get("infected", False) + + # Exit 1 if either scanner found issues OR had errors (fail-safe) + sys.exit(1 if (tt_unsafe or clam_infected or tt_error or clam_error) else 0) + + +if __name__ == "__main__": + main()