From 701c229cb88f6dbbc7998fda6fc3c504a2e21ebe Mon Sep 17 00:00:00 2001 From: senthil vasan Date: Wed, 15 Apr 2026 13:19:53 +0530 Subject: [PATCH] Revert "Holla" --- README.md | 304 +++----- mini_agent | 1 - smp/__init__.py | 3 + smp/agent.py | 431 +++++++++++ smp/cli.py | 301 ++++++++ smp/client.py | 201 +++++ smp/core/__init__.py | 1 + smp/core/background.py | 182 +++++ smp/core/merkle.py | 91 +++ smp/core/models.py | 644 ++++++++++++++++ smp/engine/__init__.py | 11 + smp/engine/community.py | 499 +++++++++++++ smp/engine/enricher.py | 97 +++ smp/engine/graph_builder.py | 159 ++++ smp/engine/handoff.py | 203 +++++ smp/engine/integrity.py | 242 ++++++ smp/engine/interfaces.py | 153 ++++ smp/engine/linker.py | 205 ++++++ smp/engine/notification.py | 90 +++ smp/engine/pagerank.py | 119 +++ smp/engine/query.py | 817 +++++++++++++++++++++ smp/engine/runtime_linker.py | 212 ++++++ smp/engine/safety.py | 590 +++++++++++++++ smp/engine/seed_walk.py | 443 +++++++++++ smp/engine/telemetry.py | 161 ++++ smp/logging.py | 68 ++ smp/parser/__init__.py | 11 + smp/parser/base.py | 153 ++++ smp/parser/python_parser.py | 553 ++++++++++++++ smp/parser/registry.py | 72 ++ smp/parser/typescript_parser.py | 525 +++++++++++++ smp/protocol/__init__.py | 9 + smp/protocol/dispatcher.py | 267 +++++++ smp/protocol/handlers/__init__.py | 1 + smp/protocol/handlers/annotation.py | 117 +++ smp/protocol/handlers/base.py | 34 + smp/protocol/handlers/community.py | 94 +++ smp/protocol/handlers/enrichment.py | 185 +++++ smp/protocol/handlers/handoff.py | 68 ++ smp/protocol/handlers/memory.py | 115 +++ smp/protocol/handlers/merkle.py | 81 ++ smp/protocol/handlers/query.py | 142 ++++ smp/protocol/handlers/query_ext.py | 115 +++ smp/protocol/handlers/safety.py | 338 +++++++++ smp/protocol/handlers/sandbox.py | 110 +++ smp/protocol/handlers/telemetry.py | 122 +++ smp/protocol/router.py | 653 ++++++++++++++++ smp/protocol/server.py | 154 ++++ smp/sandbox/__init__.py | 1 + smp/sandbox/docker_sandbox.py | 57 ++ smp/sandbox/ebpf_collector.py | 29 + smp/sandbox/executor.py | 169 +++++ smp/sandbox/spawner.py | 113 +++ smp/store/__init__.py | 1 + smp/store/chroma_store.py | 155 ++++ smp/store/graph/__init__.py | 1 + smp/store/graph/neo4j_store.py | 558 ++++++++++++++ smp/store/interfaces.py | 249 +++++++ test_codebase/src/auth/manager.py | 2 - test_codebase/src/db/user_store.py | 1 - test_codebase/tests/test_auth.py | 2 - tests/practical_verification.py | 12 +- tests/test_codebase/api/routes.py | 13 +- tests/test_codebase/calculator.py | 2 - tests/test_codebase/db/order_repository.py | 1 - tests/test_codebase/db/user_repository.py | 1 - tests/test_codebase/math_utils.py | 1 - tests/test_integration_community.py | 2 +- tests/test_integration_parser_graph.py | 30 +- tests/test_integration_sandbox.py | 1 - 70 files changed, 11294 insertions(+), 254 deletions(-) delete mode 160000 mini_agent create mode 100644 smp/__init__.py create mode 100644 smp/agent.py create mode 100644 smp/cli.py create mode 100644 smp/client.py create mode 100644 smp/core/__init__.py create mode 100644 smp/core/background.py create mode 100644 smp/core/merkle.py create mode 100644 smp/core/models.py create mode 100644 smp/engine/__init__.py create mode 100644 smp/engine/community.py create mode 100644 smp/engine/enricher.py create mode 100644 smp/engine/graph_builder.py create mode 100644 smp/engine/handoff.py create mode 100644 smp/engine/integrity.py create mode 100644 smp/engine/interfaces.py create mode 100644 smp/engine/linker.py create mode 100644 
smp/engine/notification.py create mode 100644 smp/engine/pagerank.py create mode 100644 smp/engine/query.py create mode 100644 smp/engine/runtime_linker.py create mode 100644 smp/engine/safety.py create mode 100644 smp/engine/seed_walk.py create mode 100644 smp/engine/telemetry.py create mode 100644 smp/logging.py create mode 100644 smp/parser/__init__.py create mode 100644 smp/parser/base.py create mode 100644 smp/parser/python_parser.py create mode 100644 smp/parser/registry.py create mode 100644 smp/parser/typescript_parser.py create mode 100644 smp/protocol/__init__.py create mode 100644 smp/protocol/dispatcher.py create mode 100644 smp/protocol/handlers/__init__.py create mode 100644 smp/protocol/handlers/annotation.py create mode 100644 smp/protocol/handlers/base.py create mode 100644 smp/protocol/handlers/community.py create mode 100644 smp/protocol/handlers/enrichment.py create mode 100644 smp/protocol/handlers/handoff.py create mode 100644 smp/protocol/handlers/memory.py create mode 100644 smp/protocol/handlers/merkle.py create mode 100644 smp/protocol/handlers/query.py create mode 100644 smp/protocol/handlers/query_ext.py create mode 100644 smp/protocol/handlers/safety.py create mode 100644 smp/protocol/handlers/sandbox.py create mode 100644 smp/protocol/handlers/telemetry.py create mode 100644 smp/protocol/router.py create mode 100644 smp/protocol/server.py create mode 100644 smp/sandbox/__init__.py create mode 100644 smp/sandbox/docker_sandbox.py create mode 100644 smp/sandbox/ebpf_collector.py create mode 100644 smp/sandbox/executor.py create mode 100644 smp/sandbox/spawner.py create mode 100644 smp/store/__init__.py create mode 100644 smp/store/chroma_store.py create mode 100644 smp/store/graph/__init__.py create mode 100644 smp/store/graph/neo4j_store.py create mode 100644 smp/store/interfaces.py diff --git a/README.md b/README.md index 40cf155..f025172 100644 --- a/README.md +++ b/README.md @@ -2,23 +2,29 @@ **High-Fidelity Codebase Intelligence for AI Agents** ---- - -SMP (Structural Memory Protocol) is a graph-based memory system that provides AI agents with a deep, structured understanding of complex codebases. Unlike RAG which treats code as flat text, SMP models code as a multi-dimensional graph of entities, relationships, and semantic meanings. +Structural Memory Protocol (SMP) is a graph-based memory system that provides AI agents with a deep, structured understanding of complex codebases. Unlike RAG which treats code as flat text, SMP models code as a multi-dimensional graph of entities, relationships, and semantic meanings. -**Version:** 1.3.0 | **Stack:** Python 3.11+, FastAPI, Neo4j, ChromaDB +Built with **Python 3.11**, **FastAPI**, and **Neo4j**, SMP enables agents to perform precise code navigation, impact analysis, and safe refactoring — using static analysis (no LLM required). --- ## Quickstart (Docker Compose) +The fastest way to get SMP running: + ```bash -git clone https://github.com/offx-zinth/smp.git +# Clone the repository +git clone https://github.com/your-org/smp.git cd smp + +# Copy and configure environment cp .env.example .env # Edit .env with your Neo4j password +# Start all services docker compose up -d + +# Verify health curl http://localhost:8420/health # Returns: {"status":"ok"} ``` @@ -30,7 +36,6 @@ curl http://localhost:8420/health ### 1. Requirements - **Python 3.11+** - **Neo4j 5.x** (Local or AuraDB) -- **uv** (recommended) or pip ### 2. Environment ```bash @@ -43,14 +48,13 @@ cp .env.example .env ### 3. 
Install & Run ```bash +# Clone and enter the repo git clone https://github.com/offx-zinth/smp.git cd smp # Create venv with Python 3.11 python3.11 -m venv .venv source .venv/bin/activate - -# Install with dev dependencies pip install -e ".[dev]" # Start the server @@ -59,169 +63,22 @@ smp serve --- -## Architecture Overview - -``` -┌─────────────────────────────────────────────────────────────────┐ -│ CODEBASE (Files + Git) │ -└──────────────────────────┬──────────────────────────────────────┘ - │ Updates (Watch / Agent Push / commit_sha) - ▼ -┌─────────────────────────────────────────────────────────────────┐ -│ MEMORY SERVER (SMP Core) │ -│ ┌─────────────┐ ┌──────────────┐ ┌─────────────┐ │ -│ │ PARSER │──▶│ GRAPH BUILDER│──▶│ ENRICHER │ │ -│ │ (AST/Tree- │ │ + LINKER │ │ (Static │ │ -│ │ sitter) │ │ (Static + │ │ Metadata) │ │ -│ │ │ │ eBPF Runtime│ │ │ │ -│ └─────────────┘ └──────────────┘ └──────┬──────┘ │ -│ │ │ -│ ┌───────────────────────────────────────────▼──────────────┐ │ -│ │ MEMORY STORE │ │ -│ │ │ │ -│ │ ┌─────────────────────────────────────┐ │ │ -│ │ │ GRAPH DB (Neo4j) │ │ │ -│ │ │ Structure · CALLS_STATIC │ │ │ -│ │ │ CALLS_RUNTIME · PageRank │ │ │ -│ │ │ Sessions · Audit · Telemetry │ │ │ -│ │ │ Full-Text Index (BM25) │ │ │ -│ │ └─────────────────────────────────────┘ │ │ -│ │ │ │ -│ │ ┌─────────────────────────────────────┐ │ │ -│ │ │ VECTOR INDEX (ChromaDB) │ │ │ -│ │ │ code_embedding per node │ │ │ -│ │ │ community centroid embeddings │ │ │ -│ │ └─────────────────────────────────────┘ │ │ -│ │ │ │ -│ │ ┌─────────────────────────────────────┐ │ │ -│ │ │ MERKLE INDEX │ │ │ -│ │ │ SHA-256 leaf per file node │ │ │ -│ │ │ O(log n) sync & diff │ │ │ -│ │ └─────────────────────────────────────┘ │ │ -│ └──────────────────────┬───────────────────────────────────┘ │ -└─────────────────────────┼───────────────────────────────────────┘ - │ - ┌───────────────────────┬───────────────┐ - │ │ │ - ▼ ▼ ▼ -┌─────────────────┐ ┌──────────────────────┐ ┌───────────────┐ -│ QUERY ENGINE │ │ SANDBOX RUNTIME │ │ SWARM LAYER │ -│ Navigator │ │ Ephemeral microVM/ │ │ Peer Review │ -│ Reasoner │ │ Docker + CoW fork │ │ PR Handoff │ -│ SeedWalkEngine │ │ eBPF trace capture │ │ │ -│ Telemetry │ │ Egress-firewalled │ └───────┬───────┘ -│ Community │ │ Mutation Testing │ │ -└────────┬────────┘ └──────────┬───────────┘ │ - └──────────────┬────────┘ ────────┘ - │ SMP Protocol (Dispatcher) - ▼ - ┌─────────────────────────────────────────────┐ - │ AGENT LAYER │ - │ Agent A Agent B Agent C │ - │ (Coder) (Reviewer) (Architect) │ - └─────────────────────────────────────────────┘ -``` - ---- - -## Key Features - -### Memory Store - -| Component | Technology | Purpose | -|-----------|-----------|---------| -| **Graph DB** | Neo4j | Structure, CALLS, IMPORTS, PageRank, Sessions | -| **Vector Index** | ChromaDB | code_embedding, community centroids | -| **Merkle Index** | SHA-256 | O(log n) sync, state tracking | - -### Query Engine - -| Method | Description | -|--------|------------| -| `smp/navigate` | Find specific entities | -| `smp/trace` | Follow relationships | -| `smp/context` | Get relevant context (with role classification) | -| `smp/impact` | Assess change impact | -| `smp/locate` | SeedWalkEngine - Community-routed graph RAG | -| `smp/search` | BM25 full-text search | -| `smp/flow` | Trace data/logic path | - -### Community Detection - -- **Louvain Algorithm** at two resolutions (L0: coarse, L1: fine) -- **Centroid embeddings** for Phase 0 routing -- **Bridge detection** for cross-community 
coupling - -### Agent Safety Protocol - -| Method | Description | -|--------|------------| -| `smp/session/open` | Open agent session | -| `smp/session/close` | Close and persist | -| `smp/session/recover` | Resume session | -| `smp/lock` | Exclusive file lock | -| `smp/unlock` | Release lock | -| `smp/guard/check` | Pre-flight safety | -| `smp/dryrun` | Simulate changes | -| `smp/checkpoint` | Snapshot state | -| `smp/rollback` | Restore checkpoint | -| `smp/verify/integrity` | Mutation testing | -| `smp/audit` | Event logging | - -### Sandbox Runtime - -| Component | Description | -|-----------|------------| -| `smp/sandbox/spawn` | Create isolated environment | -| `smp/sandbox/execute` | Run code in sandbox | -| `smp/sandbox/destroy` | Cleanup | -| **DockerSandbox** | Container with CoW filesystem | -| **EBPFCollector** | Runtime trace capture | - -### Swarm Handoff - -| Method | Description | -|--------|------------| -| `smp/handoff/review` | Create peer review | -| `smp/handoff/pr` | Generate PR | - ---- - -## Integration Tests +## Architecture: Manual Efficient Method (SMP V2) -The SMP codebase includes comprehensive integration tests covering all major components: +SMP V2 is designed for production-grade efficiency. It relies on **static AST extraction** and **Neo4j full-text indexing** — no LLM or vector embeddings required. -```bash -# Run all integration tests -pytest tests/test_integration_*.py -v - -# Results: 163 passed, 5 skipped -``` - -| Test Suite | Tests | Status | -|-----------|-------|--------| -| Query Engine | 34 | ✅ Pass | -| Agent Safety | 42 | ✅ Pass | -| Community Detection | 20 | ✅ Pass | -| Merkle Index | 16 | ✅ Pass | -| Vector Store | 29 | ✅ Pass | -| Protocol Handlers | 21 | ✅ Pass | -| Sandbox (Directory) | 22 | ✅ Pass | - -### Tested Components - -- **Parser + Graph Builder**: Extracts nodes and creates CALLS/IMPORTS/DEFINES edges -- **Query Engine**: navigate, trace, locate (SeedWalkEngine), get_context, assess_impact, find_flow -- **Safety**: Session management, locking, guards, dry runs, checkpoints, audit logging -- **Community**: Louvain L0/L1 detection, bridge detection, centroid computation -- **Merkle**: Tree build, hash, diff, sync, export/import -- **Vector Store**: ChromaDB upsert/query/delete with metadata filtering -- **Protocol**: All JSON-RPC methods registered and instantiatable +- **Parser**: Tree-sitter extracts functions, classes, imports, and docstrings directly from AST. +- **Enricher**: Extracts docstrings, decorators, and type annotations statically. +- **Linker**: Namespaced cross-file resolution for CALLS edges. +- **Query Engine**: Neo4j full-text index (BM25) for keyword search. +- **Safety Protocol**: Session management, dry-runs, and isolated sandbox execution. 
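+
+As a rough illustration of the full-text layer, the sketch below issues the kind of BM25 keyword query the Query Engine runs against Neo4j, using the official `neo4j` Python driver. The index name `code_fulltext` and the returned properties are assumptions for illustration (SMP's actual schema may differ), and agents would normally go through `smp/locate` rather than querying Neo4j directly.
+
+```python
+from neo4j import GraphDatabase
+
+# db.index.fulltext.queryNodes is Neo4j's built-in full-text (Lucene/BM25) lookup.
+# "code_fulltext" is an assumed index name, not necessarily SMP's.
+QUERY = """
+CALL db.index.fulltext.queryNodes($index, $text) YIELD node, score
+RETURN node.id AS id, node.name AS name, score
+ORDER BY score DESC LIMIT $k
+"""
+
+def keyword_search(uri: str, user: str, password: str, text: str, k: int = 5) -> list[dict]:
+    driver = GraphDatabase.driver(uri, auth=(user, password))
+    try:
+        with driver.session() as session:
+            result = session.run(QUERY, index="code_fulltext", text=text, k=k)
+            return [dict(record) for record in result]
+    finally:
+        driver.close()
+
+# Example: keyword_search("bolt://localhost:7687", "neo4j", "password", "authentication logic")
+```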
---

## Demo: JSON-RPC Query

+Ingest a codebase and query it:
+
```bash
# Ingest a project
smp ingest /path/to/your/project
@@ -246,18 +103,30 @@ curl -X POST http://localhost:8420/rpc \
{
  "jsonrpc": "2.0",
  "result": {
-    "self": {...},
-    "imports": [...],
-    "imported_by": [...],
-    "defines": [...],
-    "entry_points": [...],
-    "data_flow_in": [...],
-    "data_flow_out": [...],
-    "summary": {
-      "role": "core_utility",
-      "blast_radius": 42,
-      "avg_complexity": 3.2,
-      "risk_level": "medium"
+    "self": {
+      "id": "smp/core/models.py::GraphNode",
+      "type": "Class",
+      "name": "GraphNode",
+      "signature": "class GraphNode",
+      "start_line": 130,
+      "end_line": 220
+    },
+    "neighbors": [
+      {
+        "id": "smp/core/models.py::StructuralProperties",
+        "type": "Class",
+        "relationship": "CONTAINS"
+      },
+      {
+        "id": "smp/core/models.py::SemanticProperties",
+        "type": "Class",
+        "relationship": "CONTAINS"
+      }
+    ],
+    "context": {
+      "file": "smp/core/models.py",
+      "imports": ["msgspec", "typing"],
+      "defines": ["GraphNode", "GraphEdge", "NodeType", "EdgeType"]
    }
  },
  "id": 1
@@ -266,15 +135,57 @@ curl -X POST http://localhost:8420/rpc \

---

-## Python SDK
+## Key Capabilities
+
+* **Graph-Augmented Retrieval:** Navigate via `CALLS`, `INHERITS`, `IMPORTS` relationships
+* **Full-Text Search:** Neo4j full-text index (BM25) for keyword search across docstrings/tags
+* **Static Enrichment:** Docstrings, decorators, and type annotations extracted from AST
+* **Impact Assessment:** Determine the "blast radius" before changes
+* **Safety & Sandboxing:** Session management, dry-runs, isolated execution
+* **Multi-Language:** Python and TypeScript/JavaScript via Tree-sitter
+
+---
+
+## Project Structure
+
+```
+smp/
+├── smp/
+│   ├── core/       # Models, logging
+│   ├── engine/     # Query, enricher, linker, safety
+│   ├── protocol/   # JSON-RPC 2.0 API
+│   │   └── handlers/   # Modular method handlers
+│   ├── store/      # Neo4j (graph + full-text)
+│   ├── parser/     # Tree-sitter parsing
+│   ├── sandbox/    # Isolated execution
+│   ├── cli.py      # CLI
+│   └── client.py   # Python SDK
+├── tests/          # Test suite
+└── .github/workflows/   # CI/CD
+```
+
+---
+
+## Usage
+
+### Ingest a Project
+```bash
+smp ingest /path/to/project --clear
+```
+
+### Run Server
+```bash
+smp serve --port 8420 --safety
+```
+
+### Python SDK
```python
import asyncio
from smp.client import SMPClient

async def main():
    async with SMPClient("http://localhost:8420") as client:
-        # Graph RAG (SeedWalkEngine)
+        # Full-text search
        results = await client.locate("authentication logic")

        # Trace call graph
@@ -302,10 +213,7 @@ ruff check .
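# Type-check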
mypy smp/ # Test -pytest tests/ - -# Integration tests -pytest tests/test_integration_*.py +pytest ``` --- @@ -314,35 +222,19 @@ pytest tests/test_integration_*.py | Issue | Solution | |:---|:---| -| `sqlite3` ImportError | Install `pysqlite3-binary` (automatically handled) | +| `sqlite3` ImportError | Install `pysqlite3-binary` | | Neo4j Connection | Check `SMP_NEO4J_URI` and credentials in `.env` | -| ChromaDB errors | Ensure sqlite3 >= 3.35.0 or use pysqlite3 | -| Docker sandbox | Run with appropriate socket permissions | +| SyntaxError | Use Python 3.11 | +| Enrichment Timeout | Set `SMP_ENRICHMENT=none` in `.env` | --- -## Project Structure +## Contributing -``` -smp/ -├── smp/ -│ ├── core/ # Models, Merkle index, logging -│ ├── engine/ # Query, enricher, linker, safety -│ │ # community, seed_walk, pagerank -│ ├── protocol/ # JSON-RPC 2.0 API -│ │ └── handlers/ # Modular method handlers -│ ├── store/ # Neo4j, ChromaDB interfaces -│ ├── parser/ # Tree-sitter parsing -│ ├── sandbox/ # Docker, eBPF collector -│ ├── cli.py # CLI -│ └── client.py # Python SDK -├── tests/ -│ ├── fixtures/ # Sample projects -│ └── test_integration_*.py # Integration tests -├── .env.example -├── pyproject.toml -└── README.md -``` +1. Use `feature/` or `fix/` branches +2. Follow patterns in `AGENTS.md` +3. Add tests for new features +4. Run `ruff check . && ruff format . && mypy smp/ && pytest` --- diff --git a/mini_agent b/mini_agent deleted file mode 160000 index cb4a7ac..0000000 --- a/mini_agent +++ /dev/null @@ -1 +0,0 @@ -Subproject commit cb4a7ac1d8605ffa4c9cbdbe3fd80e68122fd9f2 diff --git a/smp/__init__.py b/smp/__init__.py new file mode 100644 index 0000000..7fe7acb --- /dev/null +++ b/smp/__init__.py @@ -0,0 +1,3 @@ +"""SMP — Structural Memory Protocol.""" + +__version__ = "0.1.0" diff --git a/smp/agent.py b/smp/agent.py new file mode 100644 index 0000000..4d4b6ac --- /dev/null +++ b/smp/agent.py @@ -0,0 +1,431 @@ +"""CodingAgent — AI coding agent powered by Structural Memory Protocol. + +Wraps :class:`SMPClient` into a six-step workflow that gathers structural +context, assesses change impact, asks an LLM to generate an edit, writes +the result to disk, and syncs the graph back. 
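+In order, the six steps are: read the file, gather context, assess impact, generate the edit, write to disk, and sync the graph.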
+
+Usage::
+
+    from smp.agent import CodingAgent
+    from smp.client import SMPClient
+
+    async with SMPClient("http://localhost:8420") as client:
+        agent = CodingAgent(client, gemini_api_key="...")
+        result = await agent.run(
+            file_path="src/auth.py",
+            instruction="Add rate limiting to the login endpoint",
+        )
+        print(result.summary)
+"""
+
+from __future__ import annotations
+
+import os
+import re
+import time
+from pathlib import Path
+from typing import Any
+
+import msgspec
+
+from smp.client import SMPClient
+from smp.logging import get_logger
+
+log = get_logger(__name__)
+
+
+class AgentError(Exception):
+    """Raised when the agent cannot complete its workflow."""
+
+
+class AgentResult(msgspec.Struct):
+    """Outcome of a single :meth:`CodingAgent.run` invocation."""
+
+    file_path: str
+    instruction: str
+    original_content: str
+    edited_content: str
+    context: dict[str, Any] = msgspec.field(default_factory=dict)
+    impact: dict[str, Any] = msgspec.field(default_factory=dict)
+    summary: str = ""
+    nodes_synced: int = 0
+    edges_synced: int = 0
+
+
+# ---------------------------------------------------------------------------
+# Gemini LLM backend (lazy import, mirrors enricher pattern)
+# ---------------------------------------------------------------------------
+
+
+class _GeminiBackend:
+    """Wraps Google Gemini API for code-edit generation using Gemma 3."""
+
+    def __init__(self, api_key: str, model: str = "gemma-3-27b-it") -> None:
+        from google import genai
+
+        self._client = genai.Client(api_key=api_key)
+        self._model = model
+
+    def generate(self, system_prompt: str, user_prompt: str) -> str:
+        """Generate a response from the model."""
+        response = self._client.models.generate_content(
+            model=self._model,
+            contents=f"{system_prompt}\n\n{user_prompt}",
+        )
+        return str(response.text or "")
+
+
+# ---------------------------------------------------------------------------
+# CodingAgent
+# ---------------------------------------------------------------------------
+
+
+class CodingAgent:
+    """AI coding agent that uses SMP for structural awareness.
+
+    The agent follows a six-step workflow:
+
+    1. **Read** — load the target file from disk.
+    2. **Context** — query ``smp/context`` for the file's mental model.
+    3. **Impact** — query ``smp/impact`` for blast-radius analysis.
+    4. **Generate** — send context + instruction to the LLM for an edit.
+    5. **Write** — persist the edited file to disk.
+    6. **Sync** — call ``smp/update`` so SMP re-parses the changed file.
+
+    Args:
+        client: Connected :class:`SMPClient` instance.
+        gemini_api_key: Google Gemini API key. Falls back to GEMINI_API_KEY or GOOGLE_API_KEY env var.
+        model: Gemini model name (default: gemma-3-27b-it).
+    """
+
+    def __init__(
+        self,
+        client: SMPClient,
+        *,
+        gemini_api_key: str | None = None,
+        model: str = "gemma-3-27b-it",
+    ) -> None:
+        self._client = client
+        self._llm: _GeminiBackend | None = None
+
+        key = gemini_api_key or os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
+        if key:
+            try:
+                self._llm = _GeminiBackend(api_key=key, model=model)
+                log.info("agent_llm_ready", model=model)
+            except Exception as exc:
+                log.warning("agent_llm_init_failed", error=str(exc))
+        else:
+            log.warning("agent_no_llm", reason="no_api_key")
+
+    # ------------------------------------------------------------------
+    # Public API
+    # ------------------------------------------------------------------
+
+    async def run(self, file_path: str, instruction: str) -> AgentResult:
+        """Execute the full agent workflow and return the result.
+ + Args: + file_path: Path to the source file to edit. + instruction: Natural-language description of the desired change. + + Returns: + An :class:`AgentResult` with before/after content and metadata. + + Raises: + AgentError: On unrecoverable failures (missing file, no LLM, etc.). + """ + workflow_id = f"wf_{int(time.monotonic() * 1000)}" + log.info( + "agent_workflow_start", + workflow_id=workflow_id, + file_path=file_path, + instruction=instruction[:120], + ) + + t_start = time.monotonic() + + # Step 1 — read the current file + original = await self._read_file(file_path) + log.info("agent_step_complete", step=1, label="read_file", workflow_id=workflow_id) + + # Step 2 — structural context + context = await self._step_context(file_path, workflow_id) + + # Step 3 — impact assessment + impact = await self._step_impact(file_path, context, workflow_id) + + # Step 4 — LLM edit generation + edited = await self._step_generate(file_path, instruction, original, context, impact, workflow_id) + + # Step 5 — write to disk + await self._step_write(file_path, edited, workflow_id) + + # Step 6 — sync back into structural memory + sync_result = await self._step_sync(file_path, edited, workflow_id) + + elapsed = round(time.monotonic() - t_start, 2) + nodes = sync_result.get("nodes", 0) + edges = sync_result.get("edges", 0) + + summary = f"Edited {file_path}: {instruction}. Synced {nodes} nodes, {edges} edges in {elapsed}s." + + log.info( + "agent_workflow_complete", + workflow_id=workflow_id, + file_path=file_path, + elapsed_s=elapsed, + nodes=nodes, + edges=edges, + ) + + return AgentResult( + file_path=file_path, + instruction=instruction, + original_content=original, + edited_content=edited, + context=context, + impact=impact, + summary=summary, + nodes_synced=nodes, + edges_synced=edges, + ) + + # ------------------------------------------------------------------ + # Step implementations + # ------------------------------------------------------------------ + + async def _read_file(self, file_path: str) -> str: + """Read file content from disk.""" + log.info("agent_read_file", file_path=file_path) + path = Path(file_path) + if not path.exists(): + raise AgentError(f"File not found: {file_path}") + content = path.read_text(encoding="utf-8") + log.info("agent_file_read", file_path=file_path, size_bytes=len(content)) + return content + + async def _step_context(self, file_path: str, workflow_id: str) -> dict[str, Any]: + """Step 2 — query SMP for the file's structural context.""" + log.info("agent_step_start", step=2, label="context", workflow_id=workflow_id) + + ctx = await self._client.get_context(file_path, scope="edit", depth=2) + + node_count = len(ctx.get("nodes", [])) + edge_count = len(ctx.get("edges", [])) + types = self._summarise_node_types(ctx.get("nodes", [])) + + log.info( + "agent_context_ready", + workflow_id=workflow_id, + nodes=node_count, + edges=edge_count, + **types, + ) + return ctx + + async def _step_impact( + self, + file_path: str, + context: dict[str, Any], + workflow_id: str, + ) -> dict[str, Any]: + """Step 3 — assess the blast radius of modifying *file_path*.""" + log.info("agent_step_start", step=3, label="impact", workflow_id=workflow_id) + + nodes = context.get("nodes", []) + target_id = self._pick_impact_target(nodes, file_path) + + if not target_id: + log.info("agent_impact_skip", workflow_id=workflow_id, reason="no_entity_found") + return {"entity": None, "affected_nodes": [], "total_affected": 0} + + impact = await self._client.assess_impact(target_id, 
change_type="modify") + affected = impact.get("affected_nodes", []) + + # Build a concise summary of downstream effects + downstream = self._format_downstream(affected) + + log.info( + "agent_impact_assessed", + workflow_id=workflow_id, + entity=target_id, + affected_count=len(affected), + downstream=downstream[:8], + ) + return impact + + async def _step_generate( + self, + file_path: str, + instruction: str, + original: str, + context: dict[str, Any], + impact: dict[str, Any], + workflow_id: str, + ) -> str: + """Step 4 — ask the LLM to produce an edited version of the file.""" + log.info("agent_step_start", step=4, label="generate", workflow_id=workflow_id) + + if not self._llm: + raise AgentError("No LLM backend. Set GEMINI_API_KEY or GOOGLE_API_KEY to enable edit generation.") + + system_prompt = self._build_system_prompt() + user_prompt = self._build_user_prompt( + file_path=file_path, + instruction=instruction, + original=original, + context=context, + impact=impact, + ) + + log.info("agent_llm_call", workflow_id=workflow_id, model=self._llm._model) + raw = self._llm.generate(system_prompt, user_prompt) + edited = self._extract_code(raw) + + log.info( + "agent_llm_response", + workflow_id=workflow_id, + raw_chars=len(raw), + edited_chars=len(edited), + ) + return edited + + async def _step_write(self, file_path: str, content: str, workflow_id: str) -> None: + """Step 5 — write the edited content to disk.""" + log.info("agent_step_start", step=5, label="write", workflow_id=workflow_id) + + path = Path(file_path) + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(content, encoding="utf-8") + + log.info("agent_file_written", file_path=file_path, size_bytes=len(content)) + + async def _step_sync(self, file_path: str, content: str, workflow_id: str) -> dict[str, Any]: + """Step 6 — push the changed file back into the structural memory.""" + log.info("agent_step_start", step=6, label="sync", workflow_id=workflow_id) + + result = await self._client.update(file_path, content=content) + + log.info( + "agent_sync_complete", + workflow_id=workflow_id, + file_path=file_path, + nodes=result.get("nodes", 0), + edges=result.get("edges", 0), + enriched=result.get("enriched", 0), + errors=result.get("errors", 0), + ) + return result + + # ------------------------------------------------------------------ + # Prompt construction + # ------------------------------------------------------------------ + + @staticmethod + def _build_system_prompt() -> str: + return ( + "You are an expert software engineer. You will receive a source file, " + "its structural context (classes, functions, imports, relationships), " + "and an instruction for how to modify it.\n\n" + "Rules:\n" + "- Return ONLY the complete modified file content.\n" + "- Do NOT wrap in markdown code fences.\n" + "- Do NOT add explanations before or after the code.\n" + "- Preserve existing style, conventions, and imports.\n" + "- Only change what the instruction requires.\n" + "- Ensure the result is syntactically valid." 
+ ) + + @staticmethod + def _build_user_prompt( + *, + file_path: str, + instruction: str, + original: str, + context: dict[str, Any], + impact: dict[str, Any], + ) -> str: + parts: list[str] = [] + + parts.append(f"## File: {file_path}") + parts.append(f"## Instruction\n{instruction}") + + # Context block + nodes = context.get("nodes", []) + edges = context.get("edges", []) + if nodes: + parts.append("## Structural Context") + for n in nodes[:30]: + sem = n.get("semantic") + purpose = f" — {sem['purpose']}" if sem and sem.get("purpose") else "" + parts.append(f" - {n['type']} {n['name']} (L{n['start_line']}-{n['end_line']}){purpose}") + if edges: + parts.append(f" ({len(edges)} relationships)") + + # Impact block + affected = impact.get("affected_nodes", []) + if affected: + parts.append(f"## Impact Analysis — {len(affected)} downstream entities affected") + for a in affected[:10]: + parts.append(f" - {a['type']} {a['name']} in {a['file_path']}") + if len(affected) > 10: + parts.append(f" ... and {len(affected) - 10} more") + + # Original source + parts.append(f"## Current Source\n```\n{original}\n```") + parts.append("## Modified Source") + + return "\n\n".join(parts) + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + @staticmethod + def _extract_code(llm_response: str) -> str: + """Extract source code from an LLM response. + + Handles responses wrapped in markdown code fences as well as raw code. + """ + fenced: list[str] = re.findall(r"```(?:\w*)\n(.*?)```", llm_response, re.DOTALL) + if fenced: + return str(fenced[0].strip()) + # No fences — strip common LLM preamble lines + lines = llm_response.split("\n") + start = 0 + for i, line in enumerate(lines): + stripped = line.strip() + if stripped and not stripped.startswith("#") and not stripped.startswith("//"): + start = i + break + return "\n".join(lines[start:]).strip() + + @staticmethod + def _summarise_node_types(nodes: list[dict[str, Any]]) -> dict[str, int]: + """Count nodes by type for structured log output.""" + counts: dict[str, int] = {} + for n in nodes: + t = n.get("type", "UNKNOWN") + counts[t] = counts.get(t, 0) + 1 + return counts + + @staticmethod + def _pick_impact_target(nodes: list[dict[str, Any]], file_path: str) -> str | None: + """Choose the best entity for impact analysis. + + Prefers the first FUNCTION or CLASS node; falls back to the FILE node. 
+ """ + file_node_id: str | None = None + for n in nodes: + ntype = str(n.get("type", "")) + nid = str(n.get("id", "")) + if ntype in ("FUNCTION", "CLASS"): + return nid + if ntype == "FILE" and not file_node_id: + file_node_id = nid + return file_node_id + + @staticmethod + def _format_downstream(affected: list[dict[str, Any]]) -> list[str]: + """Format affected nodes into compact summary strings.""" + return [f"{a.get('type', '?')} {a.get('name', '?')} @ {a.get('file_path', '?')}" for a in affected] diff --git a/smp/cli.py b/smp/cli.py new file mode 100644 index 0000000..46af79c --- /dev/null +++ b/smp/cli.py @@ -0,0 +1,301 @@ +from __future__ import annotations + +import argparse +import asyncio +import os +import sys +import time +from pathlib import Path + +from dotenv import load_dotenv + +from smp.logging import configure_logging, get_logger + +load_dotenv(Path(__file__).parent.parent / ".env") + +log = get_logger(__name__) + +DEFAULT_EXTENSIONS = (".py", ".ts", ".tsx", ".js", ".jsx") +DEFAULT_MAX_FILE_SIZE = 1_000_000 + + +async def ingest_directory( + directory: str, + *, + neo4j_uri: str | None = None, + neo4j_user: str | None = None, + neo4j_password: str | None = None, + extensions: tuple[str, ...] = DEFAULT_EXTENSIONS, + max_file_size: int = DEFAULT_MAX_FILE_SIZE, + clear: bool = False, +) -> dict[str, int]: + """Walk *directory*, parse all matching files, and build the graph.""" + from smp.engine.enricher import StaticSemanticEnricher + from smp.engine.graph_builder import DefaultGraphBuilder + from smp.parser.registry import ParserRegistry + from smp.store.graph.neo4j_store import Neo4jGraphStore + + registry = ParserRegistry() + graph_store = Neo4jGraphStore( + uri=neo4j_uri or os.environ.get("SMP_NEO4J_URI", "bolt://localhost:7687"), + user=neo4j_user or os.environ.get("SMP_NEO4J_USER", "neo4j"), + password=neo4j_password or os.environ.get("SMP_NEO4J_PASSWORD", ""), + ) + builder = DefaultGraphBuilder(graph_store) + enricher = StaticSemanticEnricher() + + await graph_store.connect() + if clear: + await graph_store.clear() + log.warning("graph_cleared") + + root = Path(directory).resolve() + if not root.is_dir(): + raise ValueError(f"Not a directory: {root}") + + stats = {"files": 0, "nodes": 0, "edges": 0, "errors": 0, "skipped": 0} + t0 = time.monotonic() + + for file_path in sorted(root.rglob("*")): + if not file_path.is_file(): + continue + if file_path.suffix.lower() not in extensions: + continue + + try: + size = file_path.stat().st_size + except OSError: + continue + if size > max_file_size: + log.warning("file_too_large", file=str(file_path), size=size) + stats["skipped"] += 1 + continue + + parts = file_path.relative_to(root).parts + if any( + p.startswith(".") or p in ("node_modules", "__pycache__", "venv", ".venv", "dist", "build") for p in parts + ): + continue + + rel_path = str(file_path.relative_to(root)) + doc = registry.parse_file(str(file_path)) + doc = type(doc)( + file_path=rel_path, + language=doc.language, + nodes=[ + type(n)( + id=n.id.replace(str(file_path), rel_path), + type=n.type, + file_path=rel_path, + structural=n.structural, + semantic=n.semantic, + ) + for n in doc.nodes + ], + edges=[ + type(e)( + source_id=e.source_id.replace(str(file_path), rel_path), + target_id=e.target_id.replace(str(file_path), rel_path), + type=e.type, + metadata=e.metadata, + ) + for e in doc.edges + ], + errors=doc.errors, + ) + + if doc.nodes or doc.edges: + await builder.ingest_document(doc) + + if doc.nodes: + enriched = await enricher.enrich_batch(doc.nodes) + 
for en in enriched: + if en.semantic.status == "enriched": + await graph_store.upsert_node(en) + + stats["files"] += 1 + stats["nodes"] += len(doc.nodes) + stats["edges"] += len(doc.edges) + stats["errors"] += len(doc.errors) + + resolved = await builder.resolve_pending_edges() + if resolved: + log.info("post_ingest_edges_resolved", count=resolved) + + elapsed = time.monotonic() - t0 + log.info( + "ingest_complete", + directory=str(root), + files=stats["files"], + nodes=stats["nodes"], + edges=stats["edges"], + errors=stats["errors"], + skipped=stats["skipped"], + elapsed_s=round(elapsed, 2), + ) + + await graph_store.close() + return stats + + +def main() -> None: + parser = argparse.ArgumentParser(prog="smp", description="Structural Memory Protocol CLI") + sub = parser.add_subparsers(dest="command") + + ingest_cmd = sub.add_parser("ingest", help="Parse a directory and build the graph") + ingest_cmd.add_argument("directory", help="Root directory to ingest") + ingest_cmd.add_argument( + "--neo4j-uri", type=str, help="Neo4j URI (defaults to SMP_NEO4J_URI env var or bolt://localhost:7687)" + ) + ingest_cmd.add_argument("--neo4j-user", type=str, help="Neo4j user (defaults to SMP_NEO4J_USER env var or neo4j)") + ingest_cmd.add_argument( + "--neo4j-password", type=str, help="Neo4j password (defaults to SMP_NEO4J_PASSWORD env var)" + ) + ingest_cmd.add_argument("--clear", action="store_true", help="Clear graph before ingesting") + ingest_cmd.add_argument("--json-log", action="store_true", help="JSON structured logging") + ingest_cmd.add_argument("--max-size", type=int, default=DEFAULT_MAX_FILE_SIZE, help="Max file size in bytes") + + serve_cmd = sub.add_parser("serve", help="Start the SMP JSON-RPC server") + serve_cmd.add_argument("--host", default="0.0.0.0", help="Bind host") + serve_cmd.add_argument("--port", type=int, default=8420, help="Bind port") + serve_cmd.add_argument( + "--neo4j-uri", type=str, help="Neo4j URI (defaults to SMP_NEO4J_URI env var or bolt://localhost:7687)" + ) + serve_cmd.add_argument("--neo4j-user", type=str, help="Neo4j user (defaults to SMP_NEO4J_USER env var or neo4j)") + serve_cmd.add_argument("--neo4j-password", type=str, help="Neo4j password (defaults to SMP_NEO4J_PASSWORD env var)") + serve_cmd.add_argument("--safety", action="store_true", help="Enable agent safety protocol") + serve_cmd.add_argument("--json-log", action="store_true", help="JSON structured logging") + + run_cmd = sub.add_parser("run", help="Run a command in the background") + run_cmd.add_argument("name", help="Name for this background process") + run_cmd.add_argument("command", nargs="+", help="Command and arguments to run") + run_cmd.add_argument("--cwd", type=str, help="Working directory") + run_cmd.add_argument("--env", nargs="+", help="Environment variables as KEY=VALUE") + run_cmd.add_argument("--restart", action="store_true", help="Restart if already running") + + list_cmd = sub.add_parser("ps", help="List running background processes") + list_cmd.add_argument("--name", help="Show specific process details") + + stop_cmd = sub.add_parser("stop", help="Stop a background process") + stop_cmd.add_argument("name", help="Name of the process to stop") + + logs_cmd = sub.add_parser("logs", help="View logs for a background process") + logs_cmd.add_argument("name", help="Name of the process") + logs_cmd.add_argument("--stream", action="store_true", help="Stream new output") + + args = parser.parse_args() + if not args.command: + parser.print_help() + sys.exit(1) + + 
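# Configure logging before dispatching so every subcommand emits structured output. +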
configure_logging(json=getattr(args, "json_log", False)) + + if args.command == "ingest": + stats = asyncio.run( + ingest_directory( + args.directory, + neo4j_uri=args.neo4j_uri, + neo4j_user=args.neo4j_user, + neo4j_password=args.neo4j_password, + clear=args.clear, + max_file_size=args.max_size, + ) + ) + print( + f"\nIngested {stats['files']} files: {stats['nodes']} nodes, " + f"{stats['edges']} edges, {stats['errors']} errors" + ) + + elif args.command == "serve": + import os + + import uvicorn + + # Only set env vars if explicitly provided (to allow env var fallbacks) + if args.neo4j_uri: + os.environ["SMP_NEO4J_URI"] = args.neo4j_uri + if args.neo4j_user: + os.environ["SMP_NEO4J_USER"] = args.neo4j_user + if args.neo4j_password: + os.environ["SMP_NEO4J_PASSWORD"] = args.neo4j_password + + from smp.protocol.server import create_app + + application = create_app( + neo4j_uri=args.neo4j_uri, + neo4j_user=args.neo4j_user, + neo4j_password=args.neo4j_password, + safety_enabled=getattr(args, "safety", False), + ) + uvicorn.run(application, host=args.host, port=args.port) + + elif args.command == "run": + from smp.core.background import BackgroundRunner + + env = {} + if args.env: + for e in args.env: + if "=" in e: + key, val = e.split("=", 1) + env[key] = val + + runner = BackgroundRunner() + cwd = Path(args.cwd) if args.cwd else None + + try: + bg_proc = runner.start(args.name, args.command, cwd=cwd, env=env or None) + print(f"Started {args.name}: pid={bg_proc.pid}") + except ValueError as e: + if args.restart: + bg_proc = runner.restart(args.name) + print(f"Restarted {args.name}: pid={bg_proc.pid}") + else: + print(f"Error: {e}") + sys.exit(1) + + elif args.command == "ps": + from smp.core.background import BackgroundRunner + + runner = BackgroundRunner() + if args.name: + proc = runner.get(args.name) + if proc: + print(f"{args.name}: pid={proc['pid']}, running={proc['running']}") + print(f" command: {' '.join(proc['command'])}") + else: + print(f"Process not found: {args.name}") + else: + all_procs = runner.list() + if all_procs: + for name, info in all_procs.items(): + print(f"{name}: pid={info['pid']}, running={info['running']}") + else: + print("No background processes running") + + elif args.command == "stop": + from smp.core.background import BackgroundRunner + + runner = BackgroundRunner() + if runner.stop(args.name): + print(f"Stopped {args.name}") + else: + print(f"Process not found: {args.name}") + sys.exit(1) + + elif args.command == "logs": + from smp.core.background import BackgroundRunner + + runner = BackgroundRunner() + try: + logs = runner.logs(args.name) + if logs["stdout"]: + print(f"=== stdout ===\n{logs['stdout']}") + if logs["stderr"]: + print(f"=== stderr ===\n{logs['stderr']}") + except ValueError as e: + print(f"Error: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/smp/client.py b/smp/client.py new file mode 100644 index 0000000..0566545 --- /dev/null +++ b/smp/client.py @@ -0,0 +1,201 @@ +"""SMP Client — Python SDK for the Structural Memory Protocol. + +Provides an async client for interacting with the SMP JSON-RPC server. 
+ +Usage:: + + from smp.client import SMPClient + + async with SMPClient("http://localhost:8420") as client: + ctx = await client.get_context("src/auth.py") + results = await client.locate("authentication logic") + await client.update("src/auth.py", content=new_source) +""" + +from __future__ import annotations + +from typing import Any + +import httpx +import msgspec + +from smp.core.models import ( + ContextParams, + FlowParams, + ImpactParams, + JsonRpcRequest, + JsonRpcResponse, + Language, + LocateParams, + NavigateParams, + TraceParams, + UpdateParams, +) + + +class SMPClientError(Exception): + """Raised when the SMP server returns an error.""" + + def __init__(self, code: int, message: str, data: Any = None) -> None: + self.code = code + self.data = data + super().__init__(f"JSON-RPC error {code}: {message}") + + +class SMPClient: + """Async client for the Structural Memory Protocol server. + + Args: + base_url: Server base URL (e.g. ``"http://localhost:8420"``). + timeout: Request timeout in seconds. + """ + + def __init__(self, base_url: str = "http://localhost:8420", timeout: float = 30.0) -> None: + self._base_url = base_url.rstrip("/") + self._client: httpx.AsyncClient | None = None + self._timeout = timeout + self._req_id = 0 + + async def connect(self) -> None: + self._client = httpx.AsyncClient(base_url=self._base_url, timeout=self._timeout) + + async def close(self) -> None: + if self._client: + await self._client.aclose() + self._client = None + + async def __aenter__(self) -> SMPClient: + await self.connect() + return self + + async def __aexit__(self, *_: Any) -> None: + await self.close() + + def _ensure_connected(self) -> httpx.AsyncClient: + if not self._client: + raise RuntimeError("Client not connected. Use 'async with SMPClient(...)' or call connect().") + return self._client + + async def _rpc(self, method: str, params: dict[str, Any]) -> Any: + """Send a JSON-RPC request and return the result.""" + self._req_id += 1 + req = JsonRpcRequest(method=method, params=params, id=self._req_id) + body = msgspec.json.encode(req) + + client = self._ensure_connected() + resp = await client.post("/rpc", content=body, headers={"Content-Type": "application/json"}) + + if resp.status_code == 204: + return None + + rpc_resp = msgspec.json.decode(resp.content, type=JsonRpcResponse) + if rpc_resp.error: + raise SMPClientError(rpc_resp.error.code, rpc_resp.error.message, rpc_resp.error.data) + return rpc_resp.result + + # ----------------------------------------------------------------------- + # Protocol methods + # ----------------------------------------------------------------------- + + async def navigate(self, entity_id: str) -> dict[str, Any]: + """Get a node and its immediate neighbours.""" + return await self._rpc("smp/navigate", msgspec.to_builtins(NavigateParams(query=entity_id))) + + async def trace( + self, + start_id: str, + edge_type: str = "CALLS", + depth: int = 3, + direction: str = "outgoing", + ) -> list[dict[str, Any]]: + """Recursive traversal (e.g. 
full call graph).""" + return await self._rpc( + "smp/trace", + msgspec.to_builtins( + TraceParams( + start=start_id, + relationship=edge_type, + depth=depth, + direction=direction, + ) + ), + ) + + async def get_context( + self, + file_path: str, + scope: str = "edit", + depth: int = 2, + ) -> dict[str, Any]: + """Aggregate structural context for safe editing.""" + return await self._rpc( + "smp/context", + msgspec.to_builtins( + ContextParams( + file_path=file_path, + scope=scope, + depth=depth, + ) + ), + ) + + async def assess_impact(self, entity_id: str, change_type: str = "delete") -> dict[str, Any]: + """Find blast radius of a change.""" + return await self._rpc( + "smp/impact", msgspec.to_builtins(ImpactParams(entity=entity_id, change_type=change_type)) + ) + + async def locate(self, query: str, top_k: int = 5) -> list[dict[str, Any]]: + """Search by semantic intent — vector search mapping back to graph nodes.""" + return await self._rpc("smp/locate", msgspec.to_builtins(LocateParams(query=query, top_k=top_k))) + + async def find_flow(self, start: str, end: str, max_depth: int = 20) -> list[list[dict[str, Any]]]: + """Find paths between two nodes.""" + return await self._rpc( + "smp/flow", + msgspec.to_builtins( + FlowParams( + start=start, + end=end, + ) + ), + ) + + async def update( + self, + file_path: str, + content: str = "", + language: str = "python", + ) -> dict[str, Any]: + """Notify the server of a file change — incremental graph update. + + If *content* is provided it is parsed directly; otherwise the server + reads the file from disk. + """ + lang = Language(language) if language else Language.PYTHON + return await self._rpc( + "smp/update", + msgspec.to_builtins( + UpdateParams( + file_path=file_path, + content=content, + language=lang, + ) + ), + ) + + # ----------------------------------------------------------------------- + # Convenience endpoints + # ----------------------------------------------------------------------- + + async def health(self) -> dict[str, str]: + """Check server health.""" + client = self._ensure_connected() + resp = await client.get("/health") + return resp.json() + + async def stats(self) -> dict[str, int]: + """Get graph statistics (node/edge counts).""" + client = self._ensure_connected() + resp = await client.get("/stats") + return resp.json() diff --git a/smp/core/__init__.py b/smp/core/__init__.py new file mode 100644 index 0000000..b0685c0 --- /dev/null +++ b/smp/core/__init__.py @@ -0,0 +1 @@ +"""Core data models and types.""" diff --git a/smp/core/background.py b/smp/core/background.py new file mode 100644 index 0000000..b87eb4b --- /dev/null +++ b/smp/core/background.py @@ -0,0 +1,182 @@ +from __future__ import annotations + +import contextlib +import json +import os +import signal +import subprocess +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + + +@dataclass +class BackgroundProcess: + name: str + command: list[str] + pid: int + cwd: Path | None = None + env: dict[str, str] = field(default_factory=dict) + started_at: float = field(default_factory=time.time) + + +class BackgroundRunner: + """Manages long-running background processes without blocking the agent.""" + + def __init__(self) -> None: + self._base_dir = Path.home() / ".smp" / "runs" + self._processes: dict[str, BackgroundProcess] = {} + self._open_files: dict[str, tuple[Any, Any]] = {} + self._load() + + def _state_file(self) -> Path: + return self._base_dir / "state.json" + + def _load(self) -> None: + f = 
self._state_file() + if f.exists(): + with open(f) as fp: + data = json.load(fp) + for name, item in data.items(): + proc = BackgroundProcess( + name=name, + command=item["command"], + pid=item["pid"], + cwd=Path(item["cwd"]) if item.get("cwd") else None, + env=item.get("env", {}), + started_at=item.get("started_at", 0), + ) + if self._is_running(proc.pid): + self._processes[name] = proc + + def _save(self) -> None: + self._base_dir.mkdir(parents=True, exist_ok=True) + data = { + name: { + "command": proc.command, + "pid": proc.pid, + "cwd": str(proc.cwd) if proc.cwd else None, + "env": proc.env, + "started_at": proc.started_at, + } + for name, proc in self._processes.items() + } + with open(self._state_file(), "w") as fp: + json.dump(data, fp) + + def start( + self, + name: str, + command: list[str], + cwd: Path | None = None, + env: dict[str, str] | None = None, + ) -> BackgroundProcess: + """Start a command in background and return immediately.""" + if name in self._processes: + raise ValueError(f"Process already running: {name}") + + self._base_dir.mkdir(parents=True, exist_ok=True) + run_dir = self._base_dir / name + run_dir.mkdir(parents=True, exist_ok=True) + + full_env = os.environ.copy() + if env: + full_env.update(env) + + with open(run_dir / "stdout.log", "wb") as stdout_file, open(run_dir / "stderr.log", "wb") as stderr_file: + proc = subprocess.Popen( + command, + stdout=stdout_file, + stderr=stderr_file, + cwd=cwd or run_dir, + env=full_env, + start_new_session=True, + text=True, + ) + + self._open_files[name] = (None, None) # Files closed after Popen + + bg_proc = BackgroundProcess( + name=name, + command=command, + pid=proc.pid, + cwd=cwd, + env=env, + ) + self._processes[name] = bg_proc + self._save() + return bg_proc + + def stop(self, name: str) -> bool: + """Stop a running process by name.""" + if name not in self._processes: + return False + + bg_proc = self._processes[name] + with contextlib.suppress(ProcessLookupError): + os.kill(bg_proc.pid, signal.SIGTERM) + + self._open_files.pop(name, None) + + del self._processes[name] + self._save() + return True + + def restart(self, name: str) -> BackgroundProcess: + """Restart a stopped or existing process.""" + if name not in self._processes: + raise ValueError(f"Unknown process: {name}") + + bg_proc = self._processes[name] + self.stop(name) + return self.start(name, bg_proc.command, bg_proc.cwd, bg_proc.env) + + def list(self) -> dict[str, dict[str, Any]]: + """List all managed processes.""" + result = {} + for name, proc in self._processes.items(): + result[name] = { + "pid": proc.pid, + "command": proc.command, + "cwd": str(proc.cwd) if proc.cwd else None, + "running": self._is_running(proc.pid), + } + return result + + def get(self, name: str) -> dict[str, Any] | None: + """Get details of a specific process.""" + if name not in self._processes: + return None + + proc = self._processes[name] + return { + "pid": proc.pid, + "command": proc.command, + "cwd": str(proc.cwd) if proc.cwd else None, + "running": self._is_running(proc.pid), + } + + def logs(self, name: str) -> dict[str, str]: + """Get stdout/stderr log contents for a process.""" + if name not in self._processes and not (self._base_dir / name).exists(): + raise ValueError(f"Unknown process: {name}") + + run_dir = self._base_dir / name + stdout = "" + stderr = "" + if (run_dir / "stdout.log").exists(): + with open(run_dir / "stdout.log") as fp: + stdout = fp.read() + if (run_dir / "stderr.log").exists(): + with open(run_dir / "stderr.log") as fp: + stderr = 
fp.read() + return {"stdout": stdout, "stderr": stderr} + + def _is_running(self, pid: int) -> bool: + """Check if a process is still running.""" + try: + os.kill(pid, 0) + return True + except ProcessLookupError: + return False diff --git a/smp/core/merkle.py b/smp/core/merkle.py new file mode 100644 index 0000000..0f2fe42 --- /dev/null +++ b/smp/core/merkle.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +import hashlib +from typing import Any + +from smp.core.models import GraphNode, NodeType +from smp.logging import get_logger + +log = get_logger(__name__) + + +class MerkleTree: + """SHA-256 Merkle Tree for structural consistency checks.""" + + def __init__(self) -> None: + self._leaf_hashes: list[tuple[str, str]] = [] + self._levels: list[list[str]] = [] + + def _hash_single(self, data: str) -> str: + return hashlib.sha256(data.encode()).hexdigest() + + def _hash_pair(self, left: str, right: str) -> str: + return hashlib.sha256(f"{left}{right}".encode()).hexdigest() + + def build(self, nodes: list[GraphNode]) -> None: + """Build a SHA-256 tree where leaves are file nodes.""" + file_nodes = sorted([n for n in nodes if n.type == NodeType.FILE], key=lambda n: n.id) + + self._leaf_hashes = [(n.id, self._hash_single(f"{n.id}{n.semantic.source_hash}")) for n in file_nodes] + + current_level = [h for _, h in self._leaf_hashes] + self._levels = [current_level] + + while len(current_level) > 1: + next_level = [] + for i in range(0, len(current_level), 2): + left = current_level[i] + right = current_level[i + 1] if i + 1 < len(current_level) else left + next_level.append(self._hash_pair(left, right)) + current_level = next_level + self._levels.append(current_level) + + def hash(self) -> str: + """Return the root hash.""" + if not self._levels: + return "" + return self._levels[-1][0] + + def diff(self, other: MerkleTree) -> dict[str, set[str]]: + """Perform an O(log n) comparison to return {added, removed, modified} node IDs.""" + local_map = dict(self._leaf_hashes) + remote_map = dict(other._leaf_hashes) + + local_ids = set(local_map.keys()) + remote_ids = set(remote_map.keys()) + + added = remote_ids - local_ids + removed = local_ids - remote_ids + + common_ids = local_ids & remote_ids + modified = {nid for nid in common_ids if local_map[nid] != remote_map[nid]} + + return {"added": added, "removed": removed, "modified": modified} + + def export(self) -> dict[str, Any]: + """Return a serializable format of the tree for distribution.""" + return {"root": self.hash(), "levels": self._levels, "leaf_hashes": self._leaf_hashes} + + def import_data(self, data: dict[str, Any]) -> None: + """Reconstruct the tree from exported data.""" + self._levels = data["levels"] + self._leaf_hashes = [tuple(x) for x in data["leaf_hashes"]] + + +class MerkleIndex: + """Sync management using Merkle Trees.""" + + def __init__(self, tree: MerkleTree) -> None: + self._tree = tree + + def sync(self, remote_hash: str) -> dict[str, set[str]] | None: + """Compare local root hash with remote, if different, trigger diff.""" + if self._tree.hash() == remote_hash: + return None + + log.info("merkle_sync_diff_triggered", local=self._tree.hash(), remote=remote_hash) + return None + + def apply_patch(self, patch: dict[str, Any]) -> None: + """Update local state based on a patch.""" + log.info("merkle_apply_patch", patch_keys=list(patch.keys())) diff --git a/smp/core/models.py b/smp/core/models.py new file mode 100644 index 0000000..54615f1 --- /dev/null +++ b/smp/core/models.py @@ -0,0 +1,644 @@ +"""Core data models 
for SMP(3). + +Partitioned schema: structural vs semantic properties. +All models use msgspec.Struct for zero-cost serialization and validation. +""" + +from __future__ import annotations + +import enum +from typing import Any + +import msgspec + +# --------------------------------------------------------------------------- +# Enumerations (SMP(3) schema) +# --------------------------------------------------------------------------- + + +class NodeType(enum.StrEnum): + """Node types per SMP(3) specification.""" + + REPOSITORY = "Repository" + PACKAGE = "Package" + FILE = "File" + CLASS = "Class" + FUNCTION = "Function" + VARIABLE = "Variable" + INTERFACE = "Interface" + TEST = "Test" + CONFIG = "Config" + + +class EdgeType(enum.StrEnum): + """Relationship types per SMP(3) specification.""" + + CONTAINS = "CONTAINS" + IMPORTS = "IMPORTS" + DEFINES = "DEFINES" + CALLS = "CALLS" + CALLS_RUNTIME = "CALLS_RUNTIME" + INHERITS = "INHERITS" + IMPLEMENTS = "IMPLEMENTS" + DEPENDS_ON = "DEPENDS_ON" + TESTS = "TESTS" + USES = "USES" + REFERENCES = "REFERENCES" + + +class Language(enum.StrEnum): + """Supported source languages.""" + + PYTHON = "python" + TYPESCRIPT = "typescript" + UNKNOWN = "unknown" + + +# --------------------------------------------------------------------------- +# Structural properties (coordinates, signature, complexity) +# --------------------------------------------------------------------------- + + +class StructuralProperties(msgspec.Struct, frozen=True): + """Immutable structural coordinates of a code entity.""" + + name: str = "" + file: str = "" + signature: str = "" + start_line: int = 0 + end_line: int = 0 + complexity: int = 0 + lines: int = 0 + parameters: int = 0 + + +# --------------------------------------------------------------------------- +# Semantic properties (docstrings, comments, decorators, annotations, tags) +# --------------------------------------------------------------------------- + + +class InlineComment(msgspec.Struct, frozen=True): + """A single inline comment extracted from source.""" + + line: int = 0 + text: str = "" + + +class Annotations(msgspec.Struct, frozen=True): + """Structured type annotations extracted from a function/method.""" + + params: dict[str, str] = msgspec.field(default_factory=dict) + returns: str | None = None + throws: list[str] = msgspec.field(default_factory=list) + + +class SemanticProperties(msgspec.Struct): + """Mutable semantic metadata extracted via static AST analysis.""" + + status: str = "no_metadata" + docstring: str = "" + description: str | None = None + inline_comments: list[InlineComment] = msgspec.field(default_factory=list) + decorators: list[str] = msgspec.field(default_factory=list) + annotations: Annotations | None = None + tags: list[str] = msgspec.field(default_factory=list) + score: float = 0.0 + manually_set: bool = False + source_hash: str = "" + enriched_at: str = "" + + +# --------------------------------------------------------------------------- +# Graph primitives +# --------------------------------------------------------------------------- + + +class GraphNode(msgspec.Struct): + """A single node in the structural graph with partitioned properties.""" + + id: str + type: NodeType + file_path: str + structural: StructuralProperties = msgspec.field(default_factory=StructuralProperties) + semantic: SemanticProperties = msgspec.field(default_factory=SemanticProperties) + + def fingerprint(self) -> str: + """Deterministic identity key for deduplication.""" + return 
f"{self.file_path}::{self.type.value}::{self.structural.name}::{self.structural.start_line}" + + +class GraphEdge(msgspec.Struct): + """A directed edge between two nodes.""" + + source_id: str + target_id: str + type: EdgeType + metadata: dict[str, str] = msgspec.field(default_factory=dict) + + +# --------------------------------------------------------------------------- +# Document — the unit of parsing +# --------------------------------------------------------------------------- + + +class ParseError(msgspec.Struct): + """Non-fatal error encountered during parsing.""" + + message: str + line: int = 0 + column: int = 0 + severity: str = "error" + + +class Document(msgspec.Struct): + """A parsed source file with its extracted graph elements.""" + + file_path: str + language: Language = Language.UNKNOWN + content_hash: str = "" + nodes: list[GraphNode] = msgspec.field(default_factory=list) + edges: list[GraphEdge] = msgspec.field(default_factory=list) + errors: list[ParseError] = msgspec.field(default_factory=list) + + +# --------------------------------------------------------------------------- +# JSON-RPC 2.0 protocol models +# --------------------------------------------------------------------------- + + +class JsonRpcRequest(msgspec.Struct): + """JSON-RPC 2.0 request envelope.""" + + jsonrpc: str = "2.0" + method: str = "" + params: dict[str, Any] = msgspec.field(default_factory=dict) + id: int | str | None = None + + +class JsonRpcError(msgspec.Struct): + """JSON-RPC 2.0 error object.""" + + code: int + message: str + data: Any = None + + +class JsonRpcResponse(msgspec.Struct): + """JSON-RPC 2.0 response envelope.""" + + jsonrpc: str = "2.0" + result: Any = None + error: JsonRpcError | None = None + id: int | str | None = None + + +# --------------------------------------------------------------------------- +# Memory Management params +# --------------------------------------------------------------------------- + + +class UpdateParams(msgspec.Struct): + """Parameters for smp/update.""" + + file_path: str + content: str = "" + change_type: str = "modified" + language: Language = Language.PYTHON + + +class BatchUpdateParams(msgspec.Struct): + """Parameters for smp/batch_update.""" + + changes: list[dict[str, str]] = msgspec.field(default_factory=list) + + +class ReindexParams(msgspec.Struct): + """Parameters for smp/reindex.""" + + scope: str = "full" + + +# --------------------------------------------------------------------------- +# Enrichment params +# --------------------------------------------------------------------------- + + +class EnrichParams(msgspec.Struct): + """Parameters for smp/enrich.""" + + node_id: str + force: bool = False + + +class EnrichBatchParams(msgspec.Struct): + """Parameters for smp/enrich/batch.""" + + scope: str = "full" + force: bool = False + + +class EnrichStaleParams(msgspec.Struct): + """Parameters for smp/enrich/stale.""" + + scope: str = "full" + + +class EnrichStatusParams(msgspec.Struct): + """Parameters for smp/enrich/status.""" + + scope: str = "full" + + +# --------------------------------------------------------------------------- +# Annotation params +# --------------------------------------------------------------------------- + + +class AnnotateParams(msgspec.Struct): + """Parameters for smp/annotate.""" + + node_id: str + description: str = "" + tags: list[str] = msgspec.field(default_factory=list) + force: bool = False + + +class AnnotateBulkItem(msgspec.Struct): + """Single annotation in a bulk request.""" + + node_id: str + 
description: str = "" + tags: list[str] = msgspec.field(default_factory=list) + + +class AnnotateBulkParams(msgspec.Struct): + """Parameters for smp/annotate/bulk.""" + + annotations: list[AnnotateBulkItem] = msgspec.field(default_factory=list) + + +class TagParams(msgspec.Struct): + """Parameters for smp/tag.""" + + scope: str = "" + tags: list[str] = msgspec.field(default_factory=list) + action: str = "add" + + +# --------------------------------------------------------------------------- +# Session / Safety params +# --------------------------------------------------------------------------- + + +class SessionOpenParams(msgspec.Struct): + """Parameters for smp/session/open.""" + + agent_id: str = "" + task: str = "" + scope: list[str] = msgspec.field(default_factory=list) + mode: str = "read" + + +class SessionCloseParams(msgspec.Struct): + """Parameters for smp/session/close.""" + + session_id: str = "" + status: str = "completed" + + +class SessionRecoverParams(msgspec.Struct): + """Parameters for smp/session/recover.""" + + session_id: str = "" + + +class GuardCheckParams(msgspec.Struct): + target: str = "" + intended_change: str = "" + + +class DryRunParams(msgspec.Struct): + """Parameters for smp/dryrun.""" + + session_id: str = "" + file_path: str = "" + proposed_content: str = "" + change_summary: str = "" + + +class CheckpointParams(msgspec.Struct): + """Parameters for smp/checkpoint.""" + + session_id: str = "" + files: list[str] = msgspec.field(default_factory=list) + + +class RollbackParams(msgspec.Struct): + """Parameters for smp/rollback.""" + + session_id: str = "" + checkpoint_id: str = "" + + +class LockParams(msgspec.Struct): + """Parameters for smp/lock and smp/unlock.""" + + session_id: str = "" + files: list[str] = msgspec.field(default_factory=list) + + +class AuditGetParams(msgspec.Struct): + """Parameters for smp/audit/get.""" + + audit_log_id: str = "" + + +# --------------------------------------------------------------------------- +# Query params +# --------------------------------------------------------------------------- + + +class NavigateParams(msgspec.Struct): + """Parameters for smp/navigate.""" + + query: str = "" + include_relationships: bool = True + + +class TraceParams(msgspec.Struct): + """Parameters for smp/trace.""" + + start: str = "" + relationship: str = "CALLS" + depth: int = 3 + direction: str = "outgoing" + + +class ContextParams(msgspec.Struct): + """Parameters for smp/context.""" + + file_path: str = "" + scope: str = "edit" + depth: int = 2 + + +class ImpactParams(msgspec.Struct): + """Parameters for smp/impact.""" + + entity: str = "" + change_type: str = "delete" + + +class LocateParams(msgspec.Struct): + """Parameters for smp/locate.""" + + query: str = "" + fields: list[str] = msgspec.field(default_factory=lambda: ["name", "docstring", "tags"]) + node_types: list[str] = msgspec.field(default_factory=list) + top_k: int = 5 + + +class SearchParams(msgspec.Struct): + """Parameters for smp/search.""" + + query: str = "" + match: str = "any" + filter: dict[str, Any] = msgspec.field(default_factory=dict) + top_k: int = 5 + + +class FlowParams(msgspec.Struct): + """Parameters for smp/flow.""" + + start: str = "" + end: str = "" + flow_type: str = "data" + + +# --------------------------------------------------------------------------- +# SMP(3) Runtime Models +# --------------------------------------------------------------------------- + + +class RuntimeEdge(msgspec.Struct): + """Runtime edge tracking actual execution paths.""" + + 
source_id: str = "" + target_id: str = "" + edge_type: str = "CALLS_RUNTIME" + timestamp: str = "" + session_id: str = "" + trace_id: str = "" + duration_ms: int = 0 + metadata: dict[str, Any] = msgspec.field(default_factory=dict) + + +class RuntimeTrace(msgspec.Struct): + """Complete runtime trace for a session.""" + + trace_id: str = "" + session_id: str = "" + agent_id: str = "" + started_at: str = "" + ended_at: str = "" + edges: list[RuntimeEdge] = msgspec.field(default_factory=list) + nodes_visited: list[str] = msgspec.field(default_factory=list) + + +# --------------------------------------------------------------------------- +# SMP(3) Additional Query Params +# --------------------------------------------------------------------------- + + +class DiffParams(msgspec.Struct): + """Parameters for smp/diff.""" + + from_snapshot: str = "" + to_snapshot: str = "" + scope: str = "full" + + +class PlanParams(msgspec.Struct): + """Parameters for smp/plan.""" + + change_description: str = "" + target_file: str = "" + change_type: str = "refactor" + scope: str = "full" + + +class ConflictParams(msgspec.Struct): + """Parameters for smp/conflict.""" + + entity: str = "" + proposed_change: str = "" + context: dict[str, Any] = msgspec.field(default_factory=dict) + + +class WhyParams(msgspec.Struct): + """Parameters for smp/why.""" + + entity: str = "" + relationship: str = "" + depth: int = 3 + + +class TelemetryParams(msgspec.Struct): + """Parameters for smp/telemetry.""" + + action: str = "get_stats" + node_id: str | None = None + threshold: int | None = None + + +class TelemetryHotParams(msgspec.Struct): + """Parameters for smp/telemetry/hot.""" + + node_id: str + + +class TelemetryNodeParams(msgspec.Struct): + """Parameters for smp/telemetry/node.""" + + node_id: str + + +# --------------------------------------------------------------------------- +# SMP(3) Handoff Models +# --------------------------------------------------------------------------- + + +class ReviewCreateParams(msgspec.Struct): + """Parameters for smp/review/create.""" + + session_id: str = "" + files_changed: list[str] = msgspec.field(default_factory=list) + diff_summary: str = "" + reviewers: list[str] = msgspec.field(default_factory=list) + + +class ReviewApproveParams(msgspec.Struct): + """Parameters for smp/review/approve.""" + + review_id: str = "" + reviewer: str = "" + + +class ReviewRejectParams(msgspec.Struct): + """Parameters for smp/review/reject.""" + + review_id: str = "" + reviewer: str = "" + reason: str = "" + + +class ReviewCommentParams(msgspec.Struct): + """Parameters for smp/review/comment.""" + + review_id: str = "" + author: str = "" + comment: str = "" + file_path: str | None = None + line: int | None = None + + +class PRCreateParams(msgspec.Struct): + """Parameters for smp/pr/create.""" + + review_id: str = "" + title: str = "" + body: str = "" + branch: str = "" + base_branch: str = "main" + + +# --------------------------------------------------------------------------- +# SMP(3) Sandbox Models +# --------------------------------------------------------------------------- + + +class SandboxSpawnParams(msgspec.Struct): + """Parameters for smp/sandbox/spawn.""" + + name: str | None = None + template: str | None = None + files: dict[str, str] = msgspec.field(default_factory=dict) + + +class SandboxExecuteParams(msgspec.Struct): + """Parameters for smp/sandbox/execute.""" + + sandbox_id: str = "" + command: list[str] = msgspec.field(default_factory=list) + stdin: str | None = None + timeout: int | 
None = None + + +class SandboxKillParams(msgspec.Struct): + """Parameters for smp/sandbox/kill.""" + + execution_id: str = "" + + +# --------------------------------------------------------------------------- +# Community params +# --------------------------------------------------------------------------- + + +class CommunityDetectParams(msgspec.Struct): + """Parameters for smp/community/detect.""" + + resolutions: list[dict[str, Any]] = msgspec.field(default_factory=list) + relationship_types: list[str] = msgspec.field(default_factory=list) + + +class CommunityListParams(msgspec.Struct): + """Parameters for smp/community/list.""" + + level: int | None = None + + +class CommunityGetParams(msgspec.Struct): + """Parameters for smp/community/get.""" + + community_id: str + node_types: list[str] = msgspec.field(default_factory=list) + include_bridges: bool = False + + +class CommunityBoundariesParams(msgspec.Struct): + """Parameters for smp/community/boundaries.""" + + level: int = 0 + min_coupling: float = 0.05 + + +# --------------------------------------------------------------------------- +# Merkle params +# --------------------------------------------------------------------------- + + +class MerkleSyncParams(msgspec.Struct): + """Parameters for smp/sync.""" + + remote_data: dict[str, Any] = msgspec.field(default_factory=dict) + + +class MerkleImportParams(msgspec.Struct): + """Parameters for smp/index/import.""" + + data: dict[str, Any] = msgspec.field(default_factory=dict) + + +class IntegrityCheckParams(msgspec.Struct): + """Parameters for smp/integrity/check.""" + + node_id: str = "" + current_state: dict[str, Any] = msgspec.field(default_factory=dict) + + +class IntegrityBaselineParams(msgspec.Struct): + """Parameters for smp/integrity/baseline.""" + + node_id: str = "" + state: dict[str, Any] = msgspec.field(default_factory=dict) diff --git a/smp/engine/__init__.py b/smp/engine/__init__.py new file mode 100644 index 0000000..c697d8e --- /dev/null +++ b/smp/engine/__init__.py @@ -0,0 +1,11 @@ +"""Engine layer — graph building, enrichment, querying.""" + +from smp.engine.enricher import StaticSemanticEnricher +from smp.engine.graph_builder import DefaultGraphBuilder +from smp.engine.query import DefaultQueryEngine + +__all__ = [ + "DefaultGraphBuilder", + "DefaultQueryEngine", + "StaticSemanticEnricher", +] diff --git a/smp/engine/community.py b/smp/engine/community.py new file mode 100644 index 0000000..69d75aa --- /dev/null +++ b/smp/engine/community.py @@ -0,0 +1,499 @@ +"""Community detection using Louvain algorithm at two resolution levels. + +Implements two-level community detection (coarse L0, fine L1) per the SMP(3) +specification. Creates Community nodes, MEMBER_OF edges, BRIDGES edges, +and centroid embeddings stored in ChromaDB for smp/locate Phase 0 routing. 
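+ +A minimal usage sketch (the graph_store argument stands in for any +configured GraphStore implementation; the wiring here is illustrative): + + detector = CommunityDetector(graph_store) + result = await detector.detect() # coarse L0 + fine L1 by default + for comm in result["coarse_communities"]: + print(comm["id"], comm["member_count"])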
+""" + +from __future__ import annotations + +from collections import defaultdict +from dataclasses import dataclass, field +from datetime import UTC, datetime +from typing import Any + +from smp.core.models import EdgeType, GraphEdge, GraphNode, NodeType +from smp.logging import get_logger +from smp.store.interfaces import GraphStore, VectorStore + +log = get_logger(__name__) + + +@dataclass +class Community: + id: str = "" + level: int = 0 + label: str = "" + parent_community: str = "" + majority_path_prefix: str = "" + top_tags: list[str] = field(default_factory=list) + member_count: int = 0 + file_count: int = 0 + internal_edge_count: int = 0 + external_edge_count: int = 0 + modularity_score: float = 0.0 + centroid_embedding_id: str = "" + detected_at: str = "" + + +class CommunityDetector: + """Two-level Louvain community detection over the structural graph.""" + + def __init__( + self, + graph_store: GraphStore, + vector_store: VectorStore | None = None, + min_community_size: int = 5, + ) -> None: + self._graph = graph_store + self._vector = vector_store + self._min_size = min_community_size + self._communities: dict[str, Community] = {} + self._node_communities_l0: dict[str, str] = {} + self._node_communities_l1: dict[str, str] = {} + self._bridges: list[dict[str, Any]] = [] + + async def detect( + self, + resolutions: list[dict[str, Any]] | None = None, + relationship_types: list[str] | None = None, + ) -> dict[str, Any]: + if resolutions is None: + resolutions = [ + {"level": 0, "resolution": 0.5, "label": "coarse"}, + {"level": 1, "resolution": 1.5, "label": "fine"}, + ] + if relationship_types is None: + relationship_types = ["CALLS", "IMPORTS", "DEFINES"] + + all_nodes = await self._graph.find_nodes() + if not all_nodes: + return { + "nodes_assigned": 0, + "bridge_edges": 0, + "levels": {}, + "coarse_communities": [], + "fine_communities": [], + } + + edge_types = [EdgeType(rt) for rt in relationship_types if rt in EdgeType._value2member_map_] + adjacency = await self._build_adjacency(all_nodes, edge_types) + + all_results: dict[str, dict[str, Any]] = {} + for res_config in resolutions: + level = res_config.get("level", 0) + resolution = res_config.get("resolution", 1.0) + label = res_config.get("label", "coarse" if level == 0 else "fine") + + assignments = self._louvain(all_nodes, adjacency, resolution) + communities = self._build_communities(assignments, all_nodes, adjacency, level, label) + + if level == 0: + self._node_communities_l0 = assignments + else: + self._node_communities_l1 = assignments + + for comm in communities.values(): + self._communities[comm.id] = comm + await self._store_community_node(comm) + await self._write_member_of_edges(comm, all_nodes, level) + + all_results[str(level)] = { + "communities_found": len(communities), + "modularity": self._compute_modularity(assignments, adjacency), + } + + self._bridges = await self._detect_bridges(all_nodes, adjacency) + await self._write_bridges_edges() + + if self._vector is not None: + await self._compute_centroids(all_nodes) + + coarse = [ + { + "id": c.id, + "label": c.label, + "member_count": c.member_count, + "fine_children": sum(1 for fc in self._communities.values() if fc.parent_community == c.id), + } + for c in self._communities.values() + if c.level == 0 + ] + fine = [ + {"id": c.id, "parent": c.parent_community, "label": c.label, "member_count": c.member_count} + for c in self._communities.values() + if c.level == 1 + ] + + total_assigned = len(self._node_communities_l0) + return { + "nodes_assigned": 
total_assigned, + "bridge_edges": len(self._bridges), + "levels": all_results, + "coarse_communities": coarse, + "fine_communities": fine, + } + + async def list_communities(self, level: int | None = None) -> dict[str, Any]: + communities = list(self._communities.values()) + if level is not None: + communities = [c for c in communities if c.level == level] + return { + "total": len(communities), + "communities": [ + { + "id": c.id, + "level": c.level, + "parent_community": c.parent_community, + "label": c.label, + "majority_path_prefix": c.majority_path_prefix, + "top_tags": c.top_tags, + "member_count": c.member_count, + "file_count": c.file_count, + "internal_edge_count": c.internal_edge_count, + "external_edge_count": c.external_edge_count, + "modularity_score": c.modularity_score, + "bridge_communities": [b["to_community"] for b in self._bridges if b["from_community"] == c.id], + } + for c in communities + ], + } + + async def get_community( + self, + community_id: str, + node_types: list[str] | None = None, + include_bridges: bool = False, + ) -> dict[str, Any] | None: + comm = self._communities.get(community_id) + if not comm: + return None + + assignments = self._node_communities_l1 if comm.level == 1 else self._node_communities_l0 + member_ids = [nid for nid, cid in assignments.items() if cid == community_id] + + members: list[dict[str, Any]] = [] + for mid in member_ids: + node = await self._graph.get_node(mid) + if node is None: + continue + if node_types and node.type.value not in node_types: + continue + members.append( + { + "id": node.id, + "type": node.type.value, + "name": node.structural.name, + "file": node.file_path, + "pagerank": 0.0, + "heat_score": 0, + } + ) + + bridge_edges = [] + if include_bridges: + bridge_edges = [ + b for b in self._bridges if b["from_community"] == community_id or b["to_community"] == community_id + ] + + return { + "community_id": comm.id, + "level": comm.level, + "parent_community": comm.parent_community, + "label": comm.label, + "member_count": comm.member_count, + "members": members, + "bridge_edges": bridge_edges, + } + + async def get_boundaries(self, level: int = 0, min_coupling: float = 0.05) -> dict[str, Any]: + level_bridges = [ + b + for b in self._bridges + if ( + self._communities.get(b["from_community"], Community()).level == level + or self._communities.get(b["to_community"], Community()).level == level + ) + ] + filtered = [b for b in level_bridges if b.get("coupling_weight", 0) >= min_coupling] + return { + "level": level, + "boundaries": filtered, + } + + async def _build_adjacency( + self, + nodes: list[GraphNode], + edge_types: list[EdgeType], + ) -> dict[str, set[str]]: + adj: dict[str, set[str]] = defaultdict(set) + for node in nodes: + adj[node.id] = set() + for node in nodes: + for et in edge_types: + edges = await self._graph.get_edges(node.id, et, direction="outgoing") + for edge in edges: + adj[node.id].add(edge.target_id) + if edge.target_id in adj: + adj[edge.target_id].add(node.id) + return adj + + def _louvain( + self, + nodes: list[GraphNode], + adjacency: dict[str, set[str]], + resolution: float, + ) -> dict[str, str]: + community: dict[str, int] = {} + for i, node in enumerate(nodes): + community[node.id] = i + + improved = True + iterations = 0 + max_iterations = 50 + + while improved and iterations < max_iterations: + improved = False + iterations += 1 + for node in nodes: + nid = node.id + current_comm = community[nid] + neighbor_comms: dict[int, int] = defaultdict(int) + for neighbor_id in
adjacency.get(nid, set()): + neighbor_comms[community[neighbor_id]] += 1 + + if not neighbor_comms: + continue + + best_comm = current_comm + best_gain = 0.0 + total_edges = sum(neighbor_comms.values()) + ki = len(adjacency.get(nid, set())) + + for comm, ki_comm in neighbor_comms.items(): + sigma_tot = sum(1 for n, c in community.items() if c == comm and n in adjacency) + sigma_tot = max(sigma_tot, 1) + gain = resolution * ki_comm - ki * sigma_tot / (2 * total_edges) if total_edges > 0 else 0 + if gain > best_gain: + best_gain = gain + best_comm = comm + + if best_comm != current_comm: + community[nid] = best_comm + improved = True + + comm_map: dict[str, str] = {} + for nid, comm_id in community.items(): + comm_map[nid] = f"comm_{comm_id}" + return comm_map + + def _compute_modularity( + self, + assignments: dict[str, str], + adjacency: dict[str, set[str]], + ) -> float: + total_edges = sum(len(neighbors) for neighbors in adjacency.values()) + if total_edges == 0: + return 0.0 + total_edges //= 2 + + e_cc: dict[str, float] = defaultdict(float) + a_c: dict[str, float] = defaultdict(float) + + for nid, neighbors in adjacency.items(): + c_i = assignments.get(nid, "") + a_c[c_i] += len(neighbors) + for neighbor_id in neighbors: + c_j = assignments.get(neighbor_id, "") + if c_i == c_j: + e_cc[c_i] += 1 + + modularity = 0.0 + for c in e_cc: + modularity += (e_cc[c] / (2.0 * total_edges if total_edges > 0 else 1)) - ( + a_c[c] / (2.0 * total_edges if total_edges > 0 else 1) + ) ** 2 + return round(modularity, 4) + + def _build_communities( + self, + assignments: dict[str, str], + nodes: list[GraphNode], + adjacency: dict[str, set[str]], + level: int, + label: str, + ) -> dict[str, Community]: + comm_members: dict[str, list[GraphNode]] = defaultdict(list) + for node in nodes: + cid = assignments.get(node.id, "") + if cid: + comm_members[cid].append(node) + + communities: dict[str, Community] = {} + for cid, members in comm_members.items(): + if len(members) < self._min_size: + smallest_comm = min(communities, key=lambda k: len(comm_members[k])) if communities else None + if smallest_comm: + for m in members: + assignments[m.id] = smallest_comm + communities[smallest_comm].member_count += 1 + continue + + path_counts: dict[str, int] = defaultdict(int) + tag_counts: dict[str, int] = defaultdict(int) + file_set: set[str] = set() + internal_edges = 0 + external_edges = 0 + + for m in members: + path_prefix = "/".join(m.file_path.split("/")[:2]) if "/" in m.file_path else m.file_path + path_counts[path_prefix] += 1 + for tag in m.semantic.tags: + tag_counts[tag] += 1 + file_set.add(m.file_path) + for neighbor_id in adjacency.get(m.id, set()): + if assignments.get(neighbor_id) == cid: + internal_edges += 1 + else: + external_edges += 1 + + majority_path = max(path_counts, key=path_counts.get) if path_counts else "" + top_tags_sorted = sorted(tag_counts, key=tag_counts.get, reverse=True)[:5] + + parent = "" + if level == 1: + for m in members: + parent = self._node_communities_l0.get(m.id, "") + break + + communities[cid] = Community( + id=cid, + level=level, + label=label + "_" + majority_path.split("/")[-1] if majority_path else label, + parent_community=parent, + majority_path_prefix=majority_path, + top_tags=top_tags_sorted, + member_count=len(members), + file_count=len(file_set), + internal_edge_count=internal_edges // 2, + external_edge_count=external_edges, + modularity_score=0.0, + detected_at=datetime.now(UTC).isoformat(), + ) + + return communities + + async def _store_community_node(self, 
comm: Community) -> None: + from smp.core.models import SemanticProperties, StructuralProperties + + comm_node = GraphNode( + id=comm.id, + type=NodeType("Community") if "Community" in NodeType._value2member_map_ else NodeType.FILE, + file_path=comm.majority_path_prefix, + structural=StructuralProperties( + name=comm.label, + file=comm.majority_path_prefix, + ), + semantic=SemanticProperties( + tags=comm.top_tags, + enriched_at=comm.detected_at, + ), + ) + await self._graph.upsert_node(comm_node) + + async def _write_member_of_edges(self, comm: Community, nodes: list[GraphNode], level: int) -> None: + assignments = self._node_communities_l1 if level == 1 else self._node_communities_l0 + for node in nodes: + if assignments.get(node.id) == comm.id: + edge = GraphEdge( + source_id=node.id, + target_id=comm.id, + type=EdgeType.MEMBER_OF if "MEMBER_OF" in EdgeType._value2member_map_ else EdgeType.REFERENCES, + metadata={"community_level": str(level)}, + ) + await self._graph.upsert_edge(edge) + + async def _detect_bridges( + self, + nodes: list[GraphNode], + adjacency: dict[str, set[str]], + ) -> list[dict[str, Any]]: + bridges: list[dict[str, Any]] = [] + comm_pairs: dict[tuple[str, str], list[str]] = defaultdict(list) + + for node in nodes: + cid = self._node_communities_l1.get(node.id, "") + if not cid: + continue + for neighbor_id in adjacency.get(node.id, set()): + neighbor_cid = self._node_communities_l1.get(neighbor_id, "") + if neighbor_cid and neighbor_cid != cid: + pair = tuple(sorted([cid, neighbor_cid])) + comm_pairs[pair].append(node.id) + + for (c1, c2), bridge_nodes in comm_pairs.items(): + coupling = len(bridge_nodes) / max(self._communities.get(c1, Community()).member_count, 1) + bridges.append( + { + "from_community": c1, + "to_community": c2, + "edge_count": len(bridge_nodes), + "coupling_weight": round(coupling, 4), + "bridge_nodes": bridge_nodes, + } + ) + return bridges + + async def _write_bridges_edges(self) -> None: + for bridge in self._bridges: + edge_type = EdgeType.BRIDGES if "BRIDGES" in EdgeType._value2member_map_ else EdgeType.REFERENCES + edge = GraphEdge( + source_id=bridge["from_community"], + target_id=bridge["to_community"], + type=edge_type, + metadata={"coupling_weight": str(bridge.get("coupling_weight", ""))}, + ) + await self._graph.upsert_edge(edge) + + async def _compute_centroids(self, nodes: list[GraphNode]) -> None: + if self._vector is None: + return + from smp.engine.seed_walk import _simple_hash_embedding + + comm_nodes: dict[str, list[GraphNode]] = defaultdict(list) + for node in nodes: + cid = self._node_communities_l1.get(node.id, "") + if cid: + comm_nodes[cid].append(node) + + for cid, members in comm_nodes.items(): + if not members: + continue + all_vecs: list[list[float]] = [] + for m in members: + text = m.structural.name + " " + (m.semantic.docstring or "") + vec = _simple_hash_embedding(text) + all_vecs.append(vec) + + dim = len(all_vecs[0]) if all_vecs else 128 + centroid = [0.0] * dim + for vec in all_vecs: + for i in range(dim): + centroid[i] += vec[i] + n = len(all_vecs) if all_vecs else 1 + centroid = [c / n for c in centroid] + + comm = self._communities.get(cid) + label = comm.label if comm else cid + majority_path = comm.majority_path_prefix if comm else "" + + await self._vector.add_code_embedding( + node_id=f"centroid_{cid}", + embedding=centroid, + metadata={ + "collection_type": "centroid", + "community_id": cid, + "label": label, + "majority_path_prefix": majority_path,
+ "member_count": str(len(members)), + }, + document=label, + ) diff --git a/smp/engine/enricher.py b/smp/engine/enricher.py new file mode 100644 index 0000000..adc31f6 --- /dev/null +++ b/smp/engine/enricher.py @@ -0,0 +1,97 @@ +"""Static semantic enricher — AST-based extraction. + +Extracts docstrings, inline comments, decorators, type annotations, +and computes source hashes purely from the AST. +No LLM or embedding generation. +""" + +from __future__ import annotations + +import hashlib +from datetime import UTC, datetime + +from smp.core.models import GraphNode +from smp.engine.interfaces import SemanticEnricher as SemanticEnricherInterface +from smp.logging import get_logger + +log = get_logger(__name__) + + +def _compute_source_hash(name: str, file_path: str, start: int, end: int, signature: str) -> str: + """Compute deterministic source hash for a node.""" + raw = f"{file_path}:{name}:{start}:{end}:{signature}" + return hashlib.sha256(raw.encode()).hexdigest()[:8] + + +class StaticSemanticEnricher(SemanticEnricherInterface): + """Static AST-based semantic enricher. No LLM, no embeddings.""" + + def __init__(self) -> None: + self._enrichment_counts: dict[str, int] = { + "enriched": 0, + "skipped": 0, + "no_metadata": 0, + "failed": 0, + } + + async def enrich_node( + self, + node: GraphNode, + force: bool = False, + ) -> GraphNode: + """Enrich a single node with static metadata.""" + sem = node.semantic + current_hash = _compute_source_hash( + node.structural.name, + node.file_path, + node.structural.start_line, + node.structural.end_line, + node.structural.signature, + ) + + if not force and sem.source_hash and sem.source_hash == current_hash and sem.status != "no_metadata": + self._enrichment_counts["skipped"] += 1 + return node + + sem.source_hash = current_hash + + has_docstring = bool(sem.docstring and sem.docstring.strip()) + has_decorators = bool(sem.decorators) + has_annotations = bool(sem.annotations and (sem.annotations.params or sem.annotations.returns)) + + if not has_docstring and not has_decorators and not has_annotations: + sem.status = "no_metadata" + self._enrichment_counts["no_metadata"] += 1 + sem.enriched_at = datetime.now(UTC).isoformat() + return node + + sem.status = "enriched" + sem.enriched_at = datetime.now(UTC).isoformat() + + self._enrichment_counts["enriched"] += 1 + return node + + async def enrich_batch( + self, + nodes: list[GraphNode], + force: bool = False, + ) -> list[GraphNode]: + """Enrich multiple nodes.""" + enriched = [] + for node in nodes: + result = await self.enrich_node(node, force=force) + enriched.append(result) + return enriched + + async def embed(self, text: str) -> list[float]: + """No-op embedding — static enricher does not use vectors.""" + return [] + + def get_counts(self) -> dict[str, int]: + """Return enrichment statistics.""" + return dict(self._enrichment_counts) + + def reset_counts(self) -> None: + """Reset enrichment counters.""" + for key in self._enrichment_counts: + self._enrichment_counts[key] = 0 diff --git a/smp/engine/graph_builder.py b/smp/engine/graph_builder.py new file mode 100644 index 0000000..cf28008 --- /dev/null +++ b/smp/engine/graph_builder.py @@ -0,0 +1,159 @@ +"""Graph builder — maps parsed Documents into the graph store with Global Linking. + +Updated for SMP(3) partitioned data model. 
+""" + +from __future__ import annotations + +from smp.core.models import Document, GraphEdge, NodeType +from smp.engine.interfaces import GraphBuilder as GraphBuilderInterface +from smp.logging import get_logger +from smp.store.interfaces import GraphStore + +log = get_logger(__name__) + + +class DefaultGraphBuilder(GraphBuilderInterface): + def __init__(self, graph_store: GraphStore) -> None: + self._store = graph_store + self._pending_edges: list[tuple[GraphEdge, str, str]] = [] + + async def ingest_document(self, document: Document) -> None: + name_to_id = {n.structural.name: n.id for n in document.nodes} + + import_map: dict[str, tuple[str, str]] = {} + for node in document.nodes: + if node.type != NodeType.FILE: + continue + sig = node.structural.signature + if "import" not in sig: + continue + module_path = node.structural.name.replace(".", "/") + ".py" + if sig.strip().startswith("from"): + after_import = sig.split("import", 1)[1] + for raw_name in after_import.split(","): + stripped = raw_name.strip() + if not stripped: + continue + if " as " in stripped: + original, alias = stripped.split(" as ", 1) + import_map[alias.strip()] = (module_path, original.strip()) + else: + name = stripped.split()[0] + import_map[name] = (module_path, name) + else: + parts = sig.replace("import", "").strip().split(",") + for p in parts: + stripped = p.strip() + if " as " in stripped: + original, alias = stripped.split(" as ", 1) + import_map[alias.strip()] = (module_path, original.strip()) + else: + name = stripped.split()[0] + import_map[name] = (module_path, name) + + if document.nodes: + await self._store.upsert_nodes(document.nodes) + + resolved_edges: list[GraphEdge] = [] + for edge in document.edges: + parts = edge.target_id.split("::") + if len(parts) >= 4 and parts[-1] == "0": + entity_name = parts[2] + + if entity_name in name_to_id: + edge.target_id = name_to_id[entity_name] + resolved_edges.append(edge) + continue + + if entity_name in import_map: + module_path, original_name = import_map[entity_name] + target_id = await self._resolve_cross_file( + original_name, + module_path, + ) + if target_id: + edge.target_id = target_id + log.info( + "linker_resolved_cross_file", + name=entity_name, + original=original_name, + target=target_id, + ) + resolved_edges.append(edge) + else: + fallback = f"{module_path}::Function::{original_name}::1" + edge.target_id = fallback + self._pending_edges.append((edge, original_name, module_path)) + log.info( + "linker_cross_file_pending", + name=entity_name, + original=original_name, + target=fallback, + ) + else: + resolved_edges.append(edge) + else: + resolved_edges.append(edge) + + if resolved_edges: + await self._store.upsert_edges(resolved_edges) + + log.info("ingest_complete", file=document.file_path, resolved=len(resolved_edges)) + + async def _resolve_cross_file( + self, + entity_name: str, + module_path: str, + ) -> str | None: + """Look up the actual node ID for a cross-file reference.""" + candidates = await self._store.find_nodes(name=entity_name) + if not candidates: + return None + + if not module_path: + return candidates[0].id + + stem = module_path.rsplit("/", 1)[-1] + + for n in candidates: + if n.file_path == module_path: + return n.id + for n in candidates: + if n.file_path.endswith(stem): + return n.id + + return candidates[0].id + + async def resolve_pending_edges(self) -> int: + """Re-attempt cross-file edges that were deferred.""" + if not self._pending_edges: + return 0 + + fixed = 0 + still_pending: list[tuple[GraphEdge, str, str]] = 
[] + resolved: list[GraphEdge] = [] + for edge, original_name, module_path in self._pending_edges: + real_id = await self._resolve_cross_file(original_name, module_path) + if real_id: + edge.target_id = real_id + log.info( + "linker_pending_resolved", + original=original_name, + target=real_id, + ) + resolved.append(edge) + fixed += 1 + else: + still_pending.append((edge, original_name, module_path)) + + if resolved: + await self._store.upsert_edges(resolved) + + self._pending_edges = still_pending + if fixed: + log.info("resolve_pending_complete", fixed=fixed, remaining=len(still_pending)) + return fixed + + async def remove_document(self, file_path: str) -> None: + await self._store.delete_nodes_by_file(file_path) diff --git a/smp/engine/handoff.py b/smp/engine/handoff.py new file mode 100644 index 0000000..5a1492d --- /dev/null +++ b/smp/engine/handoff.py @@ -0,0 +1,203 @@ +"""Handoff layer for code review and PR creation. + +Manages the transition from AI-generated changes to human review, +including PR creation, review workflows, and approval tracking. +""" + +from __future__ import annotations + +import uuid +from dataclasses import dataclass, field +from datetime import UTC, datetime +from typing import Any + +from smp.logging import get_logger + +log = get_logger(__name__) + + +@dataclass +class ReviewRequest: + """A request for human review.""" + + review_id: str + session_id: str + files_changed: list[str] + diff_summary: str + created_at: str + status: str = "pending" + reviewers: list[str] = field(default_factory=list) + approvals: list[str] = field(default_factory=list) + rejections: list[str] = field(default_factory=list) + comments: list[dict[str, Any]] = field(default_factory=list) + + +@dataclass +class PRInfo: + """Information about a created PR.""" + + pr_id: str + review_id: str + title: str + body: str + branch: str + base_branch: str + url: str | None = None + created_at: str = "" + status: str = "open" + + +class HandoffManager: + """Manages code review and PR workflows.""" + + def __init__(self) -> None: + self._reviews: dict[str, ReviewRequest] = {} + self._prs: dict[str, PRInfo] = {} + + def create_review( + self, + session_id: str, + files_changed: list[str], + diff_summary: str, + reviewers: list[str] | None = None, + ) -> ReviewRequest: + """Create a new review request.""" + review_id = f"rev_{uuid.uuid4().hex[:8]}" + + review = ReviewRequest( + review_id=review_id, + session_id=session_id, + files_changed=files_changed, + diff_summary=diff_summary, + created_at=datetime.now(UTC).isoformat(), + reviewers=reviewers or [], + ) + self._reviews[review_id] = review + + log.info("review_created", review_id=review_id, files=len(files_changed)) + return review + + def add_comment( + self, + review_id: str, + author: str, + comment: str, + file_path: str | None = None, + line: int | None = None, + ) -> bool: + """Add a comment to a review.""" + review = self._reviews.get(review_id) + if not review: + return False + + comment_data: dict[str, Any] = { + "author": author, + "comment": comment, + "timestamp": datetime.now(UTC).isoformat(), + } + if file_path: + comment_data["file_path"] = file_path + if line: + comment_data["line"] = line + + review.comments.append(comment_data) + log.info("review_comment_added", review_id=review_id, author=author) + return True + + def approve(self, review_id: str, reviewer: str) -> bool: + """Record an approval for a review.""" + review = self._reviews.get(review_id) + if not review: + return False + + if reviewer not in review.approvals: 
+ review.approvals.append(reviewer) + + if reviewer in review.rejections: + review.rejections.remove(reviewer) + + self._update_review_status(review) + log.info("review_approved", review_id=review_id, reviewer=reviewer) + return True + + def reject(self, review_id: str, reviewer: str, reason: str = "") -> bool: + """Record a rejection for a review.""" + review = self._reviews.get(review_id) + if not review: + return False + + if reviewer not in review.rejections: + review.rejections.append(reviewer) + + if reviewer in review.approvals: + review.approvals.remove(reviewer) + + self._update_review_status(review) + log.info("review_rejected", review_id=review_id, reviewer=reviewer, reason=reason) + return True + + def _update_review_status(self, review: ReviewRequest) -> None: + """Update review status based on approvals/rejections.""" + if len(review.rejections) > 0: + review.status = "rejected" + elif len(review.approvals) >= len(review.reviewers) and review.reviewers: + review.status = "approved" + + def create_pr( + self, + review_id: str, + title: str, + body: str, + branch: str, + base_branch: str = "main", + ) -> PRInfo | None: + """Create a PR for an approved review.""" + review = self._reviews.get(review_id) + if not review: + return None + + pr_id = f"pr_{uuid.uuid4().hex[:8]}" + + pr = PRInfo( + pr_id=pr_id, + review_id=review_id, + title=title, + body=body, + branch=branch, + base_branch=base_branch, + created_at=datetime.now(UTC).isoformat(), + ) + self._prs[pr_id] = pr + + review.status = "pr_created" + log.info("pr_created", pr_id=pr_id, review_id=review_id) + return pr + + def get_review(self, review_id: str) -> ReviewRequest | None: + """Get review by ID.""" + return self._reviews.get(review_id) + + def get_pr(self, pr_id: str) -> PRInfo | None: + """Get PR by ID.""" + return self._prs.get(pr_id) + + def list_pending_reviews(self) -> list[ReviewRequest]: + """List all pending reviews.""" + return [r for r in self._reviews.values() if r.status == "pending"] + + def get_review_summary(self, review_id: str) -> dict[str, Any] | None: + """Get summary of a review.""" + review = self._reviews.get(review_id) + if not review: + return None + + return { + "review_id": review.review_id, + "session_id": review.session_id, + "status": review.status, + "files_count": len(review.files_changed), + "reviewers": len(review.reviewers), + "approvals": len(review.approvals), + "rejections": len(review.rejections), + "comments_count": len(review.comments), + } diff --git a/smp/engine/integrity.py b/smp/engine/integrity.py new file mode 100644 index 0000000..3e44bcd --- /dev/null +++ b/smp/engine/integrity.py @@ -0,0 +1,242 @@ +"""Integrity verification module for AST-based data-flow analysis. + +Verifies that runtime behavior matches structural expectations by +analyzing data flow through the AST and detecting mutations. 
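+ +Usage sketch (the node ID and state dicts are illustrative): + + verifier = IntegrityVerifier() + await verifier.capture_baseline("fn_1", {"signature": "f(x)"}) + result = await verifier.verify("fn_1", {"signature": "f(x, y)"}) + # result.passed is False; one MutationRecord is recorded for "signature"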
+""" + +from __future__ import annotations + +import subprocess +from dataclasses import dataclass, field +from datetime import UTC, datetime +from typing import Any + +from smp.logging import get_logger +from smp.store.interfaces import GraphStore + +log = get_logger(__name__) + + +@dataclass +class MutationRecord: + """Record of a detected mutation.""" + + node_id: str + mutation_type: str + field_name: str + old_value: str + new_value: str + detected_at: str + + +@dataclass +class DataFlowPath: + """Represents a data flow path through the code.""" + + source_node: str + target_node: str + path: list[str] + flow_type: str + transformations: list[str] = field(default_factory=list) + + +@dataclass +class IntegrityCheckResult: + """Result of an integrity verification.""" + + passed: bool + node_id: str + checks_run: int + mutations_detected: list[MutationRecord] = field(default_factory=list) + warnings: list[str] = field(default_factory=list) + + +class IntegrityVerifier: + """Verifies structural integrity of graph nodes.""" + + def __init__(self) -> None: + self._mutations: list[MutationRecord] = [] + self._baselines: dict[str, dict[str, Any]] = {} + + async def capture_baseline(self, node_id: str, state: dict[str, Any]) -> None: + """Capture baseline state for a node.""" + self._baselines[node_id] = { + "state": state.copy(), + "captured_at": datetime.now(UTC).isoformat(), + } + log.debug("baseline_captured", node_id=node_id) + + async def verify( + self, + node_id: str, + current_state: dict[str, Any], + ) -> IntegrityCheckResult: + """Verify node state against baseline.""" + baseline = self._baselines.get(node_id) + mutations: list[MutationRecord] = [] + warnings: list[str] = [] + + checks_run = 1 + + if baseline: + for field_name, baseline_value in baseline["state"].items(): + current_value = current_state.get(field_name) + + if baseline_value != current_value: + mutation = MutationRecord( + node_id=node_id, + mutation_type="field_change", + field_name=field_name, + old_value=str(baseline_value), + new_value=str(current_value), + detected_at=datetime.now(UTC).isoformat(), + ) + mutations.append(mutation) + self._mutations.append(mutation) + + warnings.append(f"{field_name} changed from {baseline_value} to {current_value}") + + passed = len(mutations) == 0 + + log.info( + "integrity_check", + node_id=node_id, + passed=passed, + mutations=len(mutations), + ) + + return IntegrityCheckResult( + passed=passed, + node_id=node_id, + checks_run=checks_run, + mutations_detected=mutations, + warnings=warnings, + ) + + def analyze_data_flow( + self, + source: str, + sink: str, + path_nodes: list[str], + ) -> DataFlowPath: + """Analyze data flow from source to sink.""" + transformations = [] + + for i in range(len(path_nodes) - 1): + transformations.append(f"{path_nodes[i]} → {path_nodes[i + 1]}") + + return DataFlowPath( + source_node=source, + target_node=sink, + path=path_nodes, + flow_type="data", + transformations=transformations, + ) + + def get_mutations(self, node_id: str | None = None) -> list[MutationRecord]: + """Get all detected mutations, optionally filtered by node.""" + if node_id: + return [m for m in self._mutations if m.node_id == node_id] + return list(self._mutations) + + def clear_mutations(self) -> None: + """Clear mutation history.""" + self._mutations.clear() + log.info("mutations_cleared") + + def get_mutation_summary(self) -> dict[str, Any]: + """Return summary of detected mutations.""" + by_node: dict[str, int] = {} + for m in self._mutations: + by_node[m.node_id] = 
by_node.get(m.node_id, 0) + 1 + + return { + "total_mutations": len(self._mutations), + "affected_nodes": len(by_node), + "by_node": by_node, + } + + async def run_mutation_test( + self, + node_id: str, + graph_store: GraphStore, + ) -> IntegrityCheckResult: + """Run mutation testing on a specific node. + + Mutates operators in the source code and checks if tests still pass. + """ + node = await graph_store.get_node(node_id) + if not node: + log.error("mutation_test_failed", reason="node_not_found", node_id=node_id) + return IntegrityCheckResult(passed=False, node_id=node_id, checks_run=0) + + file_path = node.file_path + try: + with open(file_path) as f: + lines = f.readlines() + except OSError as e: + log.error("mutation_test_failed", reason="file_read_error", error=str(e)) + return IntegrityCheckResult(passed=False, node_id=node_id, checks_run=0) + + mutants_survived = 0 + checks_run = 0 + detected_mutations: list[MutationRecord] = [] + + # Simple operator flips + operators = {"==": "!=", "!=": "==", ">": "<=", "<": ">=", ">=": "<=", "<=": ">"} + + start = max(0, node.structural.start_line - 1) + end = min(len(lines), node.structural.end_line) + + for i in range(start, end): + line = lines[i] + for op, replacement in operators.items(): + if op in line: + checks_run += 1 + original_line = line + lines[i] = line.replace(op, replacement, 1) + + try: + with open(file_path, "w") as f: + f.writelines(lines) + + # Run tests + result = subprocess.run(["pytest"], capture_output=True, text=True, timeout=30) + + if result.returncode == 0: + mutants_survived += 1 + mutation = MutationRecord( + node_id=node_id, + mutation_type="operator_flip", + field_name=f"line_{i + 1}", + old_value=op, + new_value=replacement, + detected_at=datetime.now(UTC).isoformat(), + ) + detected_mutations.append(mutation) + self._mutations.append(mutation) + + except (subprocess.TimeoutExpired, OSError) as e: + log.warning("mutation_test_warning", error=str(e)) + finally: + lines[i] = original_line + with open(file_path, "w") as f: + f.writelines(lines) + + passed = mutants_survived == 0 + + log.info( + "mutation_test_completed", + node_id=node_id, + passed=passed, + survived=mutants_survived, + total=checks_run, + ) + + return IntegrityCheckResult( + passed=passed, + node_id=node_id, + checks_run=checks_run, + mutations_detected=detected_mutations, + warnings=[f"{mutants_survived} mutants survived"] if mutants_survived > 0 else [], + ) diff --git a/smp/engine/interfaces.py b/smp/engine/interfaces.py new file mode 100644 index 0000000..7a68317 --- /dev/null +++ b/smp/engine/interfaces.py @@ -0,0 +1,153 @@ +"""Abstract base classes for the engine layer. + +Defines the contracts for parsing, graph building, semantic enrichment, +and querying for SMP(3). 
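+ +A minimal conforming implementation looks like this (NullParser is an +illustrative stub, not a shipped parser): + + class NullParser(Parser): + def parse(self, source: str, file_path: str) -> Document: + return Document(file_path=file_path) + + @property + def supported_languages(self) -> list[str]: + return []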
+""" + +from __future__ import annotations + +import abc +from typing import Any + +from smp.core.models import ( + Document, + GraphNode, +) + + +class Parser(abc.ABC): + """Extract typed AST nodes and edges from source code.""" + + @abc.abstractmethod + def parse(self, source: str, file_path: str) -> Document: + """Parse *source* and return a :class:`Document`.""" + + @property + @abc.abstractmethod + def supported_languages(self) -> list[str]: + """Return language names this parser handles.""" + + +class GraphBuilder(abc.ABC): + """Map parsed :class:`Document` elements into a graph store.""" + + @abc.abstractmethod + async def ingest_document(self, document: Document) -> None: + """Write the document's nodes and edges into the graph store.""" + + @abc.abstractmethod + async def remove_document(self, file_path: str) -> None: + """Remove all graph data for *file_path*.""" + + +class SemanticEnricher(abc.ABC): + """Generate static semantic summaries from AST metadata.""" + + @abc.abstractmethod + async def enrich_node(self, node: GraphNode, force: bool = False) -> GraphNode: + """Return a copy of *node* with :class:`SemanticProperties` populated.""" + + @abc.abstractmethod + async def enrich_batch(self, nodes: list[GraphNode], force: bool = False) -> list[GraphNode]: + """Enrich multiple nodes.""" + + @abc.abstractmethod + async def embed(self, text: str) -> list[float]: + """No-op for static enricher.""" + + +class QueryEngine(abc.ABC): + """High-level query interface over the memory store.""" + + @abc.abstractmethod + async def navigate(self, query: str, include_relationships: bool = True) -> dict[str, Any]: + """Find entity and return basic info with relationships.""" + + @abc.abstractmethod + async def trace( + self, + start: str, + relationship: str = "CALLS", + depth: int = 3, + direction: str = "outgoing", + ) -> list[dict[str, Any]]: + """Follow relationship chain from start node.""" + + @abc.abstractmethod + async def get_context( + self, + file_path: str, + scope: str = "edit", + depth: int = 2, + ) -> dict[str, Any]: + """Aggregate structural context for safe editing — the programmer's mental model.""" + + @abc.abstractmethod + async def assess_impact(self, entity: str, change_type: str = "delete") -> dict[str, Any]: + """Find blast radius of a change.""" + + @abc.abstractmethod + async def locate( + self, + query: str, + fields: list[str] | None = None, + node_types: list[str] | None = None, + top_k: int = 5, + ) -> list[dict[str, Any]]: + """Keyword search ranked by match quality.""" + + @abc.abstractmethod + async def search( + self, + query: str, + match: str = "any", + filters: dict[str, Any] | None = None, + top_k: int = 5, + ) -> dict[str, Any]: + """Pure keyword/token search across docstrings and tags.""" + + @abc.abstractmethod + async def find_flow( + self, + start: str, + end: str, + flow_type: str = "data", + ) -> dict[str, Any]: + """Trace execution/data flow between two nodes.""" + + @abc.abstractmethod + async def diff( + self, + from_snapshot: str, + to_snapshot: str, + scope: str = "full", + ) -> dict[str, Any]: + """Compare two snapshots and return differences.""" + + @abc.abstractmethod + async def plan( + self, + change_description: str, + target_file: str, + change_type: str = "refactor", + scope: str = "full", + ) -> dict[str, Any]: + """Generate a change plan for proposed modifications.""" + + @abc.abstractmethod + async def conflict( + self, + entity: str, + proposed_change: str, + context: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Check 
for conflicts in proposed changes.""" + + @abc.abstractmethod + async def why( + self, + entity: str, + relationship: str = "", + depth: int = 3, + ) -> dict[str, Any]: + """Explain why a relationship exists.""" diff --git a/smp/engine/linker.py b/smp/engine/linker.py new file mode 100644 index 0000000..66a10bc --- /dev/null +++ b/smp/engine/linker.py @@ -0,0 +1,205 @@ +"""Graph linker module for resolving cross-file references. + +Implements the SMP(3) linker spec: +- Resolves namespaced CALLS edges (file::function) +- Supports global linking across the graph +- Handles pending edges for forward references +""" + +from __future__ import annotations + +from typing import Any + +from smp.core.models import Document, EdgeType, GraphEdge, GraphNode, NodeType +from smp.logging import get_logger + +log = get_logger(__name__) + + +class Linker: + """Resolves cross-file references and creates CALLS edges.""" + + def __init__(self) -> None: + self._pending_edges: list[tuple[GraphEdge, str, str]] = [] + self._import_maps: dict[str, dict[str, tuple[str, str]]] = {} + + def build_import_map( + self, + document: Document, + nodes: list[GraphNode], + ) -> dict[str, tuple[str, str]]: + """Build import map from document nodes. + + Returns dict mapping imported names to (module_path, original_name). + """ + import_map: dict[str, tuple[str, str]] = {} + + for node in nodes: + if node.type != NodeType.FILE: + continue + + sig = node.structural.signature + if "import" not in sig: + continue + + module_path = node.structural.name.replace(".", "/") + ".py" + + if sig.strip().startswith("from"): + after_import = sig.split("import", 1)[1] + for raw_name in after_import.split(","): + stripped = raw_name.strip() + if not stripped: + continue + if " as " in stripped: + original, alias = stripped.split(" as ", 1) + import_map[alias.strip()] = (module_path, original.strip()) + else: + name = stripped.split()[0] + import_map[name] = (module_path, name) + else: + parts = sig.replace("import", "").strip().split(",") + for p in parts: + stripped = p.strip() + if " as " in stripped: + original, alias = stripped.split(" as ", 1) + import_map[alias.strip()] = (module_path, original.strip()) + else: + name = stripped.split()[0] + import_map[name] = (module_path, name) + + self._import_maps[document.file_path] = import_map + return import_map + + async def resolve_calls( + self, + edges: list[GraphEdge], + nodes: list[GraphNode], + graph_store: Any, + ) -> tuple[list[GraphEdge], list[tuple[GraphEdge, str, str]]]: + """Resolve CALLS edges to target node IDs. + + Returns (resolved_edges, pending_edges). 
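+ + Example (the inputs are whatever the parser produced for one file; + the variable names are illustrative): + + resolved, pending = await linker.resolve_calls(doc.edges, doc.nodes, store) + # pending entries have the shape (edge, original_name, module_path)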
+ """ + name_to_id = {n.structural.name: n.id for n in nodes} + file_path = nodes[0].file_path if nodes else "" + import_map = self._import_maps.get(file_path, {}) + + resolved: list[GraphEdge] = [] + pending: list[tuple[GraphEdge, str, str]] = [] + + for edge in edges: + if edge.type != EdgeType.CALLS: + resolved.append(edge) + continue + + target_id = edge.target_id + parts = target_id.split("::") + + if len(parts) >= 4 and parts[-1] == "0": + entity_name = parts[2] + + if entity_name in name_to_id: + edge.target_id = name_to_id[entity_name] + resolved.append(edge) + log.debug("linker_resolved_local", name=entity_name, target=edge.target_id) + continue + + if entity_name in import_map: + module_path, original_name = import_map[entity_name] + resolved_target = await self._resolve_cross_file( + original_name, + module_path, + graph_store, + ) + + if resolved_target: + edge.target_id = resolved_target + resolved.append(edge) + log.info( + "linker_resolved_cross_file", + name=entity_name, + original=original_name, + target=resolved_target, + ) + else: + fallback = f"{module_path}::Function::{original_name}::1" + edge.target_id = fallback + pending.append((edge, original_name, module_path)) + log.info( + "linker_cross_file_pending", + name=entity_name, + original=original_name, + target=fallback, + ) + resolved.append(edge) + else: + resolved.append(edge) + + return resolved, pending + + async def _resolve_cross_file( + self, + entity_name: str, + module_path: str, + graph_store: Any, + ) -> str | None: + """Look up the actual node ID for a cross-file reference.""" + candidates = await graph_store.find_nodes(name=entity_name) + if not candidates: + return None + + if not module_path: + return candidates[0].id + + stem = module_path.rsplit("/", 1)[-1] + + for n in candidates: + if n.file_path == module_path: + return n.id + + for n in candidates: + if n.file_path.endswith(stem): + return n.id + + return candidates[0].id + + async def resolve_pending(self, graph_store: Any) -> int: + """Re-attempt pending edge resolutions.""" + if not self._pending_edges: + return 0 + + fixed = 0 + still_pending: list[tuple[GraphEdge, str, str]] = [] + resolved: list[GraphEdge] = [] + + for edge, original_name, module_path in self._pending_edges: + real_id = await self._resolve_cross_file(original_name, module_path, graph_store) + if real_id: + edge.target_id = real_id + log.info( + "linker_pending_resolved", + original=original_name, + target=real_id, + ) + resolved.append(edge) + fixed += 1 + else: + still_pending.append((edge, original_name, module_path)) + + if resolved: + await graph_store.upsert_edges(resolved) + + self._pending_edges = still_pending + if fixed: + log.info("resolve_pending_complete", fixed=fixed, remaining=len(still_pending)) + + return fixed + + def get_pending_count(self) -> int: + """Return count of pending edges.""" + return len(self._pending_edges) + + def clear_pending(self) -> None: + """Clear all pending edges.""" + self._pending_edges.clear() + log.info("linker_pending_cleared") diff --git a/smp/engine/notification.py b/smp/engine/notification.py new file mode 100644 index 0000000..0943911 --- /dev/null +++ b/smp/engine/notification.py @@ -0,0 +1,90 @@ +"""Notification manager for server-push events. + +Provides a polling-based notification system for clients to receive +real-time updates about graph changes, session events, etc. 
+""" + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass +from datetime import UTC, datetime +from typing import Any + +from smp.logging import get_logger + +log = get_logger(__name__) + + +@dataclass +class Notification: + """A single notification event.""" + + notification_id: str + event_type: str + payload: dict[str, Any] + timestamp: str + session_id: str = "" + + +class NotificationManager: + """Manages notifications with in-memory storage.""" + + def __init__(self, max_events: int = 1000) -> None: + self._events: list[Notification] = [] + self._max_events = max_events + self._subscribers: dict[str, asyncio.Queue[Notification]] = {} + + def emit( + self, + event_type: str, + payload: dict[str, Any], + session_id: str = "", + ) -> None: + """Emit a new notification event.""" + notification = Notification( + notification_id=f"notif_{len(self._events)}", + event_type=event_type, + payload=payload, + timestamp=datetime.now(UTC).isoformat(), + session_id=session_id, + ) + self._events.append(notification) + + # Trim if exceeding max + if len(self._events) > self._max_events: + self._events = self._events[-self._max_events :] + + log.debug("notification_emitted", event_type=event_type) + + def poll(self, last_seen: int = 0) -> list[dict[str, Any]]: + """Poll for new notifications since last_seen index.""" + if last_seen >= len(self._events): + return [] + + recent = self._events[last_seen:] + return [ + { + "index": last_seen + i, + "notification_id": n.notification_id, + "event_type": n.event_type, + "payload": n.payload, + "timestamp": n.timestamp, + "session_id": n.session_id, + } + for i, n in enumerate(recent) + ] + + def get_recent(self, limit: int = 50) -> list[dict[str, Any]]: + """Get the most recent notifications.""" + recent = self._events[-limit:] + return [ + { + "notification_id": n.notification_id, + "event_type": n.event_type, + "payload": n.payload, + "timestamp": n.timestamp, + "session_id": n.session_id, + } + for n in recent + ] diff --git a/smp/engine/pagerank.py b/smp/engine/pagerank.py new file mode 100644 index 0000000..16672b2 --- /dev/null +++ b/smp/engine/pagerank.py @@ -0,0 +1,119 @@ +"""PageRank engine for calculating node importance in the structural graph. + +Implements an iterative PageRank algorithm to identify central entities based on +graph connectivity (in-degree and relationship importance). +""" + +from __future__ import annotations + +from collections import defaultdict + +from smp.core.models import GraphEdge, GraphNode +from smp.logging import get_logger +from smp.store.interfaces import GraphStore + +log = get_logger(__name__) + + +class PageRankEngine: + """Calculates importance scores for graph nodes using the PageRank algorithm.""" + + def __init__(self, damping: float = 0.85, max_iterations: int = 100, tol: float = 1e-6) -> None: + """Initialize PageRank engine. + + Args: + damping: Damping factor (probability of following a link). + max_iterations: Maximum number of iterations to run. + tol: Convergence threshold. + """ + self.damping = damping + self.max_iterations = max_iterations + self.tol = tol + + def compute(self, nodes: list[GraphNode], edges: list[GraphEdge]) -> dict[str, float]: + """Compute PageRank scores for the given nodes and edges. + + Args: + nodes: List of nodes in the graph. + edges: List of directed edges in the graph. + + Returns: + A dictionary mapping node IDs to their calculated PageRank scores. 
+ """ + if not nodes: + return {} + + n = len(nodes) + node_ids = [node.id for node in nodes] + id_to_idx = {node_id: i for i, node_id in enumerate(node_ids)} + + # Adjacency list and out-degrees + adj = defaultdict(list) + out_degree = defaultdict(int) + for edge in edges: + if edge.source_id in id_to_idx and edge.target_id in id_to_idx: + adj[edge.target_id].append(edge.source_id) + out_degree[edge.source_id] += 1 + + # Initial scores + scores = [1.0 / n] * n + + for iteration in range(self.max_iterations): + new_scores = [0.0] * n + total_dangling_weight = 0.0 + + # Handle dangling nodes (nodes with no outgoing edges) + for i in range(n): + if out_degree[node_ids[i]] == 0: + total_dangling_weight += scores[i] + + for i in range(n): + target_id = node_ids[i] + # Sum of PageRank from neighbors + rank_sum = sum(scores[id_to_idx[src]] / out_degree[src] for src in adj[target_id]) + + # Calculate new score + new_scores[i] = (1.0 - self.damping) / n + self.damping * (rank_sum + total_dangling_weight / n) + + # Check convergence + diff = sum(abs(new_scores[i] - scores[i]) for i in range(n)) + if diff < self.tol: + log.debug("pagerank_converged", iteration=iteration, diff=diff) + scores = new_scores + break + + scores = new_scores + + return {node_ids[i]: scores[i] for i in range(n)} + + async def update_node_scores(self, graph_store: GraphStore) -> int: + """Update nodes in the graph store with their computed PageRank scores. + + Args: + graph_store: The graph store to update. + + Returns: + Number of nodes updated. + """ + # Use a broad search to get all nodes. In a real scenario, + # we might want to filter by type or scope. + nodes = await graph_store.find_nodes() + + # We need all edges to compute PageRank. + # This is expensive for large graphs; a real implementation would use GDS. + all_edges: list[GraphEdge] = [] + for node in nodes: + edges = await graph_store.get_edges(node.id, direction="outgoing") + all_edges.extend(edges) + + scores = self.compute(nodes, all_edges) + + updated_count = 0 + for node in nodes: + score = scores.get(node.id, 0.0) + node.semantic.score = score + await graph_store.upsert_node(node) + updated_count += 1 + + log.info("pagerank_scores_updated", count=updated_count) + return updated_count diff --git a/smp/engine/query.py b/smp/engine/query.py new file mode 100644 index 0000000..cbbafa7 --- /dev/null +++ b/smp/engine/query.py @@ -0,0 +1,817 @@ +"""Query engine — high-level structural queries over the memory store. + +Provides navigate, trace, get_context, assess_impact, locate, search, +and find_flow queries backed by the graph store. 
+""" + +from __future__ import annotations + +from collections import deque +from typing import Any + +from smp.core.models import EdgeType, GraphNode, NodeType +from smp.engine.interfaces import QueryEngine as QueryEngineInterface +from smp.logging import get_logger +from smp.store.interfaces import GraphStore + +log = get_logger(__name__) + +_HTTP_VERB_DECORATORS = {"get", "post", "put", "delete", "patch", "head", "options"} +_UTILITY_PATH_SEGMENTS = {"/utils", "/lib", "/shared", "/helpers"} + + +class DefaultQueryEngine(QueryEngineInterface): + """Query engine backed by a graph store.""" + + def __init__( + self, + graph_store: GraphStore, + enricher: Any | None = None, + ) -> None: + self._graph = graph_store + self._enricher = enricher + + def _node_to_dict(self, node: GraphNode) -> dict[str, Any]: + return { + "id": node.id, + "type": node.type.value, + "file_path": node.file_path, + "name": node.structural.name, + "signature": node.structural.signature, + "start_line": node.structural.start_line, + "end_line": node.structural.end_line, + "complexity": node.structural.complexity, + "lines": node.structural.lines, + "semantic": { + "status": node.semantic.status, + "docstring": node.semantic.docstring, + "description": node.semantic.description, + "decorators": node.semantic.decorators, + "tags": node.semantic.tags, + }, + } + + async def navigate(self, query: str, include_relationships: bool = True) -> dict[str, Any]: + node = await self._graph.get_node(query) + + # If exact match fails, try to find by file path or name + if not node: + # Check if query looks like a file path + if "/" in query or query.endswith(".py"): + candidates = await self._graph.find_nodes(file_path=query) + if candidates: + node = candidates[0] + else: + # Try finding by name + candidates = await self._graph.find_nodes(name=query) + if candidates: + node = candidates[0] + + # If still not found, try partial match on node ID prefix + if not node: + all_nodes = await self._graph.find_nodes() + for n in all_nodes: + if n.id.startswith(query) or query in n.id: + node = n + break + + if not node: + return {"error": f"Node {query} not found"} + + result: dict[str, Any] = {"entity": self._node_to_dict(node)} + + if include_relationships: + outgoing = await self._graph.get_edges(node.id, direction="outgoing") + incoming = await self._graph.get_edges(node.id, direction="incoming") + + calls = [e.target_id for e in outgoing if e.type == EdgeType.CALLS] + called_by = [e.source_id for e in incoming if e.type == EdgeType.CALLS] + depends_on = [e.target_id for e in outgoing if e.type == EdgeType.DEPENDS_ON] + imported_by = [e.source_id for e in incoming if e.type == EdgeType.IMPORTS] + + result["relationships"] = { + "calls": calls, + "called_by": called_by, + "depends_on": depends_on, + "imported_by": imported_by, + } + + return result + + async def trace( + self, + start: str, + relationship: str = "CALLS", + depth: int = 3, + direction: str = "outgoing", + ) -> list[dict[str, Any]]: + try: + et = EdgeType(relationship) + except ValueError: + et = EdgeType.CALLS + nodes = await self._graph.traverse(start, et, depth, max_nodes=100, direction=direction) + return [self._node_to_dict(n) for n in nodes] + + async def get_context( + self, + file_path: str, + scope: str = "edit", + depth: int = 2, + ) -> dict[str, Any]: + file_nodes = await self._graph.find_nodes(file_path=file_path) + if not file_nodes: + return {"error": f"No nodes found for {file_path}"} + + file_node = file_nodes[0] + file_id = file_node.id + + imports = 
await self._graph.get_edges(file_id, EdgeType.IMPORTS, direction="outgoing") + imported_by = await self._graph.get_edges(file_id, EdgeType.IMPORTS, direction="incoming") + defines = await self._graph.get_edges(file_id, EdgeType.DEFINES, direction="outgoing") + tests_edges = await self._graph.get_edges(file_id, EdgeType.TESTS, direction="incoming") + + defines_nodes: list[dict[str, Any]] = [] + complexities: list[int] = [] + exported_symbols: list[str] = [] + http_decorators: list[str] = [] + test_file_paths: list[str] = [] + + for edge in defines: + target = await self._graph.get_node(edge.target_id) + if target: + defines_nodes.append(self._node_to_dict(target)) + complexities.append(target.structural.complexity) + exported_symbols.append(target.structural.name) + for dec in target.semantic.decorators: + dec_lower = dec.lstrip("@").lower() + if dec_lower in _HTTP_VERB_DECORATORS: + http_decorators.append(dec) + + for te in tests_edges: + source = await self._graph.get_node(te.source_id) + if source and source.file_path not in test_file_paths: + test_file_paths.append(source.file_path) + + has_tests = len(test_file_paths) > 0 + + related_patterns: list[dict[str, Any]] = [] + all_nodes = await self._graph.find_nodes() + for candidate in all_nodes: + if candidate.id == file_id or candidate.file_path == file_path: + continue + if candidate.type == file_node.type: + name_sim = self._name_similarity(file_node.structural.name, candidate.structural.name) + if name_sim > 0.5: + related_patterns.append( + { + "file_path": candidate.file_path, + "name": candidate.structural.name, + "similarity": round(name_sim, 2), + } + ) + related_patterns.sort(key=lambda x: -x["similarity"]) + related_patterns = related_patterns[:5] + + entry_points: list[dict[str, Any]] = [] + if http_decorators: + for edge in defines: + target = await self._graph.get_node(edge.target_id) + if target: + target_http = [ + d for d in target.semantic.decorators if d.lstrip("@").lower() in _HTTP_VERB_DECORATORS + ] + if target_http: + entry_points.append( + { + "name": target.structural.name, + "decorators": target_http, + "file_path": target.file_path, + } + ) + + data_flow_in: list[dict[str, Any]] = [] + data_flow_out: list[dict[str, Any]] = [] + + callers_in = await self._graph.traverse( + file_id, EdgeType.CALLS, depth=depth, max_nodes=50, direction="incoming" + ) + for caller in callers_in: + data_flow_in.append( + { + "node_id": caller.id, + "name": caller.structural.name, + "file_path": caller.file_path, + } + ) + + callers_out = await self._graph.traverse( + file_id, EdgeType.CALLS, depth=depth, max_nodes=50, direction="outgoing" + ) + for callee in callers_out: + data_flow_out.append( + { + "node_id": callee.id, + "name": callee.structural.name, + "file_path": callee.file_path, + } + ) + + role = self._classify_role(file_node, imported_by, defines_nodes, http_decorators) + avg_complexity = round(sum(complexities) / max(len(complexities), 1), 1) + max_complexity = max(complexities, default=0) + blast_radius = len(imported_by) + + imported_by_api = 0 + for edge in imported_by: + source = await self._graph.get_node(edge.source_id) + if source and "/api" in source.file_path: + imported_by_api += 1 + + is_hot_node = blast_radius > 10 or max_complexity > 8 + heat_score = blast_radius + max_complexity + + if blast_radius > 10 or avg_complexity > 8: + risk_level = "high" + elif blast_radius > 3 or avg_complexity > 4: + risk_level = "medium" + else: + risk_level = "low" + + summary = { + "role": role, + "blast_radius": blast_radius, 
+ "api_layer_callers": imported_by_api, + "avg_complexity": avg_complexity, + "max_complexity": max_complexity, + "exported_symbols": exported_symbols, + "has_tests": has_tests, + "test_files": test_file_paths, + "is_hot_node": is_hot_node, + "heat_score": heat_score, + "risk_level": risk_level, + } + + return { + "self": self._node_to_dict(file_node), + "imports": [{"source": e.source_id, "target": e.target_id} for e in imports], + "imported_by": [{"source": e.source_id, "target": e.target_id} for e in imported_by], + "defines": defines_nodes, + "related_patterns": related_patterns, + "entry_points": entry_points, + "data_flow_in": data_flow_in, + "data_flow_out": data_flow_out, + "summary": summary, + } + + @staticmethod + def _name_similarity(name_a: str, name_b: str) -> float: + if not name_a or not name_b: + return 0.0 + set_a = set(name_a.lower()) + set_b = set(name_b.lower()) + if not set_a or not set_b: + return 0.0 + intersection = set_a & set_b + union = set_a | set_b + return len(intersection) / len(union) + + def _classify_role( + self, + file_node: GraphNode, + imported_by: list[Any], + defines_nodes: list[dict[str, Any]], + http_decorators: list[str], + ) -> str: + path = file_node.file_path + if "/test" in path or "/spec" in path: + return "test" + if file_node.type == NodeType.CONFIG: + return "config" + if http_decorators: + return "endpoint" + if "/routes" in path or "/controllers" in path: + return "endpoint" + incoming_imports = len(imported_by) + if "/services" in path and incoming_imports > 0: + return "service" + if incoming_imports > 5 and any(seg in path for seg in _UTILITY_PATH_SEGMENTS): + return "core_utility" + if incoming_imports == 0 and not defines_nodes: + return "isolated" + return "module" + + async def assess_impact(self, entity: str, change_type: str = "delete") -> dict[str, Any]: + node = await self._graph.get_node(entity) + + # If exact match fails, try to find by file path or name + if not node: + if "/" in entity or entity.endswith(".py"): + candidates = await self._graph.find_nodes(file_path=entity) + if candidates: + node = candidates[0] + else: + candidates = await self._graph.find_nodes(name=entity) + if candidates: + node = candidates[0] + + # Try partial match if still not found + if not node: + all_nodes = await self._graph.find_nodes() + for n in all_nodes: + if entity in n.id: + node = n + break + + if not node: + return {"error": f"Node {entity} not found"} + + dependents = await self._graph.traverse(node.id, EdgeType.CALLS, depth=10, max_nodes=200, direction="incoming") + + affected_files: list[str] = [] + affected_functions: list[str] = [] + for dep in dependents: + if dep.file_path not in affected_files: + affected_files.append(dep.file_path) + affected_functions.append(dep.structural.name) + + severity = "low" + if len(dependents) > 10: + severity = "high" + elif len(dependents) > 3: + severity = "medium" + + recommendations: list[str] = [] + if change_type == "signature_change": + recommendations.append(f"Update {len(dependents)} callers to match new signature") + elif change_type == "delete": + recommendations.append(f"Remove or stub {len(dependents)} dependent references") + + return { + "affected_files": affected_files, + "affected_functions": affected_functions, + "severity": severity, + "recommendations": recommendations, + } + + async def locate( + self, + query: str, + fields: list[str] | None = None, + node_types: list[str] | None = None, + top_k: int = 5, + ) -> list[dict[str, Any]]: + if not fields: + fields = ["name", 
"docstring", "tags"] + + terms = query.lower().split() + all_nodes = await self._graph.find_nodes() + + scored: list[tuple[int, dict[str, Any]]] = [] + for node in all_nodes: + if node_types and node.type.value not in node_types: + continue + + score = 0 + matched_on = "" + + name_lower = node.structural.name.lower() + if all(t in name_lower for t in terms): + score = 100 + matched_on = "name" + elif any(t in name_lower for t in terms): + score = 50 + matched_on = "name" + elif node.semantic.docstring: + doc_lower = node.semantic.docstring.lower() + if all(t in doc_lower for t in terms): + score = 30 + matched_on = "docstring" + elif any(t in doc_lower for t in terms): + score = 15 + matched_on = "docstring" + + if score > 0: + for tag in node.semantic.tags: + if any(t in tag.lower() for t in terms): + score += 10 + if matched_on: + matched_on += ", tags" + else: + matched_on = "tags" + break + + scored.append( + ( + score, + { + "entity": node.structural.name, + "file": node.file_path, + "matched_on": matched_on, + "docstring": node.semantic.docstring, + "tags": node.semantic.tags, + }, + ) + ) + + scored.sort(key=lambda x: -x[0]) + return [item[1] for item in scored[:top_k]] + + async def search( + self, + query: str, + match: str = "any", + filters: dict[str, Any] | None = None, + top_k: int = 5, + ) -> dict[str, Any]: + filters = filters or {} + terms = query.split() + node_types = filters.get("node_types") + tags = filters.get("tags") + scope = filters.get("scope") + + results = await self._graph.search_nodes( + query_terms=terms, + match=match, + node_types=node_types, + tags=tags, + scope=scope, + top_k=top_k, + ) + + if not results: + return { + "matches": [], + "total": 0, + "hint": "Try broadening scope or using match: any", + } + + return {"matches": results, "total": len(results)} + + async def find_flow( + self, + start: str, + end: str, + flow_type: str = "data", + ) -> dict[str, Any]: + # Resolve start and end nodes + start_node = await self._graph.get_node(start) + if not start_node: + candidates = await self._graph.find_nodes(name=start) + if candidates: + start_node = candidates[0] + + end_node = await self._graph.get_node(end) + if not end_node: + candidates = await self._graph.find_nodes(name=end) + if candidates: + end_node = candidates[0] + + if start == end: + if start_node: + return { + "path": [{"node": start_node.structural.name, "type": start_node.type.value}], + "data_transformations": [], + } + return {"path": [], "data_transformations": []} + + if not start_node or not end_node: + return {"path": [], "data_transformations": []} + + paths = await self._bfs_paths(start, end) + if not paths: + return {"path": [], "data_transformations": []} + + best_path = paths[0] + path_nodes = [] + for nid in best_path: + node = await self._graph.get_node(nid) + if node: + path_nodes.append({"node": node.structural.name, "type": node.type.value}) + + transformations: list[str] = [] + for i in range(len(path_nodes) - 1): + transformations.append(f"{path_nodes[i]['node']} → {path_nodes[i + 1]['node']}") + + return { + "path": path_nodes, + "data_transformations": transformations, + } + + async def _bfs_paths(self, start_id: str, end_id: str) -> list[list[str]]: + """BFS to find shortest paths.""" + found_paths: list[list[str]] = [] + queue: deque[tuple[str, list[str]]] = deque([(start_id, [start_id])]) + visited: set[str] = set() + + while queue and len(found_paths) < 3: + current, path = queue.popleft() + if len(path) > 20: + continue + + edges = await 
self._graph.get_edges(current, direction="outgoing") + edges += await self._graph.get_edges(current, direction="incoming") + + neighbors: set[str] = set() + for e in edges: + neighbors.add(e.target_id if e.source_id == current else e.source_id) + + for neighbor in neighbors: + if neighbor == end_id: + found_paths.append(path + [neighbor]) + continue + if neighbor not in visited and neighbor not in path: + visited.add(neighbor) + queue.append((neighbor, path + [neighbor])) + + return found_paths + + async def diff( + self, + from_snapshot: str, + to_snapshot: str, + scope: str = "full", + ) -> dict[str, Any]: + """Compare two snapshots and return the differences.""" + from_nodes = await self._graph.find_nodes_by_scope(from_snapshot) + to_nodes = await self._graph.find_nodes_by_scope(to_snapshot) + + from_ids = {n.id for n in from_nodes} + to_ids = {n.id for n in to_nodes} + + added = to_ids - from_ids + removed = from_ids - to_ids + common = from_ids & to_ids + + changed: list[str] = [] + for node_id in common: + from_node = next((n for n in from_nodes if n.id == node_id), None) + to_node = next((n for n in to_nodes if n.id == node_id), None) + if from_node and to_node and from_node.semantic.source_hash != to_node.semantic.source_hash: + changed.append(node_id) + + return { + "from_snapshot": from_snapshot, + "to_snapshot": to_snapshot, + "added": list(added), + "removed": list(removed), + "changed": changed, + "stats": { + "added_count": len(added), + "removed_count": len(removed), + "changed_count": len(changed), + }, + } + + async def plan( + self, + change_description: str, + target_file: str, + change_type: str = "refactor", + scope: str = "full", + ) -> dict[str, Any]: + """Generate a change plan for proposed modifications.""" + file_nodes = await self._graph.find_nodes(file_path=target_file) + + affected_nodes: list[str] = [] + for node in file_nodes: + callers = await self._graph.traverse(node.id, EdgeType.CALLS, depth=10, max_nodes=200, direction="incoming") + if callers: + affected_nodes.append(node.id) + + steps: list[dict[str, str]] = [ + {"step": "1", "action": "Backup current state", "details": f"Snapshot {target_file}"}, + {"step": "2", "action": "Apply changes", "details": change_description}, + {"step": "3", "action": "Run tests", "details": f"Test affected nodes: {len(affected_nodes)}"}, + ] + + if change_type == "signature_change": + steps.append( + { + "step": "4", + "action": "Update callers", + "details": f"Update {len(affected_nodes)} dependent functions", + } + ) + + return { + "change_description": change_description, + "target_file": target_file, + "change_type": change_type, + "affected_nodes": affected_nodes, + "steps": steps, + "risk_level": "high" if len(affected_nodes) > 10 else "medium" if affected_nodes else "low", + } + + async def conflict( + self, + entity: str, + proposed_change: str, + context: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Check for conflicts in proposed changes.""" + node = await self._graph.get_node(entity) + if not node: + candidates = await self._graph.find_nodes(name=entity) + if candidates: + node = candidates[0] + + if not node: + return {"conflict": False, "reason": f"Entity {entity} not found"} + + edges = await self._graph.get_edges(node.id, direction="incoming") + callers = [e.source_id for e in edges if e.type == EdgeType.CALLS] + + conflicts: list[str] = [] + warnings: list[str] = [] + + if len(callers) > 5: + conflicts.append(f"Entity has {len(callers)} callers - high blast radius") + + if 
node.semantic.manually_set: + warnings.append("Entity has manually set annotations - may need re-annotation") + + if context and context.get("session_id"): + locked_files = context.get("locked_files", []) + if node.file_path in locked_files: + conflicts.append(f"File {node.file_path} is locked by another session") + + return { + "entity": entity, + "proposed_change": proposed_change, + "conflict": len(conflicts) > 0, + "conflicts": conflicts, + "warnings": warnings, + "caller_count": len(callers), + } + + async def why( + self, + entity: str, + relationship: str = "", + depth: int = 3, + ) -> dict[str, Any]: + """Explain why a relationship exists between entities.""" + node = await self._graph.get_node(entity) + + # If exact match fails, try to find by file path or name + if not node: + if "/" in entity or entity.endswith(".py"): + candidates = await self._graph.find_nodes(file_path=entity) + if candidates: + node = candidates[0] + else: + candidates = await self._graph.find_nodes(name=entity) + if candidates: + node = candidates[0] + + if not node: + return {"error": f"Entity {entity} not found"} + + reasons: list[dict[str, Any]] = [] + + incoming = await self._graph.get_edges(node.id, direction="incoming") + outgoing = await self._graph.get_edges(node.id, direction="outgoing") + + for edge in incoming[:depth]: + source = await self._graph.get_node(edge.source_id) + if source: + reasons.append( + { + "type": "incoming", + "edge_type": edge.type.value, + "from": source.structural.name, + "file": source.file_path, + "reason": f"{source.structural.name} {edge.type.value} {node.structural.name}", + } + ) + + for edge in outgoing[:depth]: + target = await self._graph.get_node(edge.target_id) + if target: + reasons.append( + { + "type": "outgoing", + "edge_type": edge.type.value, + "to": target.structural.name, + "file": target.file_path, + "reason": f"{node.structural.name} {edge.type.value} {target.structural.name}", + } + ) + + return { + "entity": entity, + "name": node.structural.name, + "file": node.file_path, + "reasons": reasons, + "total_relationships": len(incoming) + len(outgoing), + } + + async def diff_file( + self, + file_path: str, + proposed_content: str | None = None, + ) -> dict[str, Any]: + """Compare current graph state of a file against proposed new content.""" + current_nodes = await self._graph.find_nodes(file_path=file_path) + current_node_ids = {n.id for n in current_nodes} + current_calls: dict[str, set[str]] = {n.id: set() for n in current_nodes} + + for node in current_nodes: + edges = await self._graph.get_edges(node.id, direction="outgoing") + for e in edges: + if e.type == EdgeType.CALLS: + current_calls[node.id].add(e.target_id) + + if proposed_content: + from smp.parser.base import detect_language + from smp.parser.registry import ParserRegistry + + registry = ParserRegistry() + lang = detect_language(file_path) + parser = registry.get(lang) + if not parser: + from smp.core.models import Language + + parser = registry.get(Language.PYTHON) + if parser: + proposed_data = parser.parse(proposed_content, file_path) + proposed_node_ids = {n.id for n in proposed_data.nodes} + else: + proposed_node_ids = current_node_ids + else: + proposed_node_ids = current_node_ids + + nodes_added = list(proposed_node_ids - current_node_ids) + nodes_removed = list(current_node_ids - proposed_node_ids) + nodes_modified: list[str] = [] + + return { + "nodes_added": nodes_added, + "nodes_removed": nodes_removed, + "nodes_modified": nodes_modified, + "relationships_added": [], + 
"relationships_removed": [], + } + + async def plan_multi_file( + self, + session_id: str, + task: str, + intended_writes: list[str], + ) -> dict[str, Any]: + """Validate and rank a multi-file task before execution.""" + file_dependencies: dict[str, set[str]] = {} + + for file_path in intended_writes: + nodes = await self._graph.find_nodes(file_path=file_path) + deps = set() + for node in nodes: + edges = await self._graph.get_edges(node.id, direction="outgoing") + for e in edges: + if e.type == EdgeType.CALLS: + deps.add(e.target_id) + file_dependencies[file_path] = deps + + execution_order = [] + for i, file_path in enumerate(intended_writes, 1): + current_nodes = await self._graph.find_nodes(file_path=file_path) + dependants = 0 + for fp in intended_writes: + if fp != file_path: + for n in current_nodes: + if n.id in file_dependencies.get(fp, set()): + dependants += 1 + + outgoing = [] + for n in current_nodes: + edges = await self._graph.get_edges(n.id, direction="outgoing") + outgoing.extend([e.target_id for e in edges]) + + execution_order.append( + { + "step": i, + "file": file_path, + "dependants_in_plan": dependants, + "dependencies_in_plan": len(outgoing), + "blast_radius": dependants, + "risk_level": "high" if dependants > 3 else "medium" if dependants > 0 else "low", + } + ) + + return { + "execution_order": execution_order, + "inter_file_conflicts": [], + "external_files_at_risk": [], + } + + async def detect_conflict( + self, + session_a: str, + session_b: str, + ) -> dict[str, Any]: + """Detect scope overlap between two planned sessions.""" + return { + "has_conflict": False, + "overlapping_files": [], + "conflicting_nodes": [], + } diff --git a/smp/engine/runtime_linker.py b/smp/engine/runtime_linker.py new file mode 100644 index 0000000..83643a8 --- /dev/null +++ b/smp/engine/runtime_linker.py @@ -0,0 +1,212 @@ +"""Runtime linker for tracking actual execution paths. + +Records CALLS_RUNTIME edges based on telemetry data to build +a runtime call graph that complements the static analysis. 
+""" + +from __future__ import annotations + +import uuid +from collections import defaultdict +from dataclasses import dataclass +from datetime import UTC, datetime +from typing import Any + +from smp.core.models import EdgeType, GraphEdge, RuntimeEdge, RuntimeTrace +from smp.logging import get_logger +from smp.store.interfaces import GraphStore + +log = get_logger(__name__) + + +@dataclass +class RuntimeCall: + """A single runtime call observation.""" + + source_id: str + target_id: str + timestamp: str + session_id: str + duration_ms: int = 0 + + +class RuntimeLinker: + """Tracks and records runtime execution paths.""" + + def __init__(self) -> None: + self._calls: list[RuntimeCall] = [] + self._traces: dict[str, RuntimeTrace] = {} + self._session_traces: dict[str, list[str]] = defaultdict(list) + self._call_counts: dict[tuple[str, str], int] = defaultdict(int) + + def record_call( + self, + source_id: str, + target_id: str, + session_id: str, + duration_ms: int = 0, + ) -> RuntimeEdge: + """Record a runtime call observation.""" + trace_id = f"trace_{uuid.uuid4().hex[:8]}" + timestamp = datetime.now(UTC).isoformat() + + call = RuntimeCall( + source_id=source_id, + target_id=target_id, + timestamp=timestamp, + session_id=session_id, + duration_ms=duration_ms, + ) + self._calls.append(call) + + key = (source_id, target_id) + self._call_counts[key] += 1 + + edge = RuntimeEdge( + source_id=source_id, + target_id=target_id, + edge_type="CALLS_RUNTIME", + timestamp=timestamp, + session_id=session_id, + trace_id=trace_id, + duration_ms=duration_ms, + ) + + self._session_traces[session_id].append(trace_id) + + log.debug( + "runtime_call_recorded", + source=source_id, + target=target_id, + session=session_id, + ) + return edge + + def start_trace( + self, + session_id: str, + agent_id: str, + ) -> str: + """Start a new runtime trace.""" + trace_id = f"trc_{uuid.uuid4().hex[:8]}" + timestamp = datetime.now(UTC).isoformat() + + trace = RuntimeTrace( + trace_id=trace_id, + session_id=session_id, + agent_id=agent_id, + started_at=timestamp, + ) + self._traces[trace_id] = trace + + log.info("trace_started", trace_id=trace_id, session=session_id) + return trace_id + + def end_trace(self, trace_id: str) -> RuntimeTrace | None: + """End a runtime trace.""" + trace = self._traces.get(trace_id) + if not trace: + return None + + trace.ended_at = datetime.now(UTC).isoformat() + + related_calls = [c for c in self._calls if c.session_id == trace.session_id] + trace.edges = [ + RuntimeEdge( + source_id=c.source_id, + target_id=c.target_id, + edge_type="CALLS_RUNTIME", + timestamp=c.timestamp, + session_id=c.session_id, + trace_id=trace_id, + duration_ms=c.duration_ms, + ) + for c in related_calls + ] + + visited: set[str] = set() + for edge in trace.edges: + visited.add(edge.source_id) + visited.add(edge.target_id) + trace.nodes_visited = list(visited) + + log.info( + "trace_ended", + trace_id=trace_id, + edges=len(trace.edges), + nodes=len(trace.nodes_visited), + ) + return trace + + def get_trace(self, trace_id: str) -> RuntimeTrace | None: + """Get trace by ID.""" + return self._traces.get(trace_id) + + def get_session_traces(self, session_id: str) -> list[RuntimeTrace]: + """Get all traces for a session.""" + trace_ids = self._session_traces.get(session_id, []) + return [self._traces[tid] for tid in trace_ids if tid in self._traces] + + def get_hot_paths(self, threshold: int = 10) -> list[dict[str, Any]]: + """Return frequently executed paths.""" + hot = [] + + for (source, target), count in 
self._call_counts.items(): + if count >= threshold: + hot.append( + { + "source_id": source, + "target_id": target, + "call_count": count, + } + ) + + hot.sort(key=lambda x: -int(x["call_count"])) + return hot + + def get_stats(self) -> dict[str, Any]: + """Return runtime linker statistics.""" + return { + "total_calls": len(self._calls), + "unique_paths": len(self._call_counts), + "active_traces": len(self._traces), + "sessions_with_traces": len(self._session_traces), + } + + def clear(self) -> None: + """Clear all runtime data.""" + self._calls.clear() + self._traces.clear() + self._session_traces.clear() + self._call_counts.clear() + log.info("runtime_linker_cleared") + + async def inject_runtime_edges(self, graph_store: GraphStore) -> int: + """Inject recorded runtime calls as edges into the graph store. + + Args: + graph_store: The graph store to update. + + Returns: + Number of edges injected. + """ + edges_to_inject: list[GraphEdge] = [] + + for call in self._calls: + edge = GraphEdge( + source_id=call.source_id, + target_id=call.target_id, + type=EdgeType.CALLS_RUNTIME, + metadata={ + "timestamp": call.timestamp, + "session_id": call.session_id, + "duration_ms": str(call.duration_ms), + }, + ) + edges_to_inject.append(edge) + + if edges_to_inject: + await graph_store.upsert_edges(edges_to_inject) + + log.info("runtime_edges_injected", count=len(edges_to_inject)) + return len(edges_to_inject) diff --git a/smp/engine/safety.py b/smp/engine/safety.py new file mode 100644 index 0000000..ccfbf28 --- /dev/null +++ b/smp/engine/safety.py @@ -0,0 +1,590 @@ +"""Agent Safety Protocol — sessions, guards, dry-runs, locks, checkpoints, audit. + +Implements the full SMP(3) agent write lifecycle: + session/open → guard/check → dryrun → checkpoint → write → update → session/close +""" + +from __future__ import annotations + +import uuid +from dataclasses import dataclass, field +from datetime import UTC, datetime +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from smp.logging import get_logger + +if TYPE_CHECKING: + from smp.store.interfaces import GraphStore + +log = get_logger(__name__) + +_SESSION_TTL_SECONDS = 3600 + + +# --------------------------------------------------------------------------- +# Data classes +# --------------------------------------------------------------------------- + + +@dataclass +class Session: + """Represents an active agent session.""" + + session_id: str + agent_id: str + task: str + scope: list[str] + mode: str + granted_scope: list[str] + denied_scope: list[str] + opened_at: str + expires_at: str + status: str = "open" + files_written: list[str] = field(default_factory=list) + files_read: list[str] = field(default_factory=list) + + +@dataclass +class AuditEvent: + """A single event in the audit log.""" + + timestamp: str + method: str + target: str = "" + result: str = "" + checkpoint_id: str = "" + files: list[str] = field(default_factory=list) + + +@dataclass +class AuditLog: + """Full audit record for a session.""" + + audit_log_id: str + agent_id: str + task: str + session_id: str + opened_at: str + closed_at: str = "" + status: str = "open" + events: list[AuditEvent] = field(default_factory=list) + + +@dataclass +class Checkpoint: + """Snapshot of files before a write.""" + + checkpoint_id: str + session_id: str + files: dict[str, str] + snapshot_at: str + + +# --------------------------------------------------------------------------- +# Session Manager +# --------------------------------------------------------------------------- + + 
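+# Typical lifecycle (illustrative sketch; agent ID, task, and scope paths are
+# placeholders — see the SMP(3) flow in the module docstring above):
+#
+#     mgr = SessionManager(ttl_seconds=3600)
+#     opened = await mgr.open_session("agent_a", "refactor auth", ["src/auth/"], mode="write")
+#     ...guard/check, dryrun, checkpoint, write...
+#     await mgr.close_session(opened["session_id"], status="completed")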
+class SessionManager: + """Manages agent session lifecycle with scope enforcement and auto-expiry.""" + + def __init__( + self, + ttl_seconds: int = _SESSION_TTL_SECONDS, + graph_store: GraphStore | None = None, + ) -> None: + self._sessions: dict[str, Session] = {} + self._ttl = ttl_seconds + self._graph = graph_store + + def set_graph_store(self, graph_store: GraphStore) -> None: + """Set the graph store for session persistence.""" + self._graph = graph_store + + async def _persist_session(self, session: Session) -> None: + """Persist session to graph if available.""" + if self._graph: + await self._graph.upsert_session(session) + + async def _load_session(self, session_id: str) -> Session | None: + """Load session from graph if available.""" + if self._graph: + data = await self._graph.get_session(session_id) + if data: + return Session( + session_id=data["session_id"], + agent_id=data["agent_id"], + task=data["task"], + scope=data.get("scope", []), + mode=data.get("mode", "read"), + granted_scope=data.get("granted_scope", []), + denied_scope=data.get("denied_scope", []), + opened_at=data["opened_at"], + expires_at=data["expires_at"], + status=data.get("status", "open"), + files_written=data.get("files_written", []), + files_read=data.get("files_read", []), + ) + return None + + async def open_session( + self, + agent_id: str, + task: str, + scope: list[str], + mode: str = "read", + ) -> dict[str, Any]: + """Open a new session and return the result dict.""" + session_id = f"ses_{uuid.uuid4().hex[:6]}" + now = datetime.now(UTC) + expires = now.timestamp() + self._ttl + + granted = [] + denied = [] + warnings = [] + + for path in scope: + p = Path(path) + if p.exists() or not p.suffix: + granted.append(path) + else: + denied.append(path) + + for path in granted: + caller_count = 0 + if caller_count > 10: + warnings.append(f"{path} is imported by {caller_count} files — changes have wide blast radius") + + session = Session( + session_id=session_id, + agent_id=agent_id, + task=task, + scope=scope, + mode=mode, + granted_scope=granted, + denied_scope=denied, + opened_at=now.isoformat(), + expires_at=datetime.fromtimestamp(expires, tz=UTC).isoformat(), + ) + self._sessions[session_id] = session + await self._persist_session(session) + + log.info("session_opened", session_id=session_id, agent_id=agent_id, mode=mode) + return { + "session_id": session_id, + "granted_scope": granted, + "denied_scope": denied, + "active_locks": [], + "warnings": warnings, + "expires_at": session.expires_at, + } + + async def close_session(self, session_id: str, status: str = "completed") -> dict[str, Any] | None: + """Close a session and return summary.""" + session = self._sessions.get(session_id) + if not session: + return None + + session.status = status + now = datetime.now(UTC) + opened = datetime.fromisoformat(session.opened_at) + duration_ms = int((now - opened).total_seconds() * 1000) + + audit_log_id = f"aud_{uuid.uuid4().hex[:6]}" + + log.info("session_closed", session_id=session_id, status=status, duration_ms=duration_ms) + + if self._graph: + await self._graph.delete_session(session_id) + + return { + "session_id": session_id, + "files_written": session.files_written, + "files_read": session.files_read, + "duration_ms": duration_ms, + "audit_log_id": audit_log_id, + } + + async def get_session(self, session_id: str) -> Session | None: + """Get a session by ID, checking expiry.""" + session = self._sessions.get(session_id) + if not session: + return await self._load_session(session_id) + if 
session.status != "open": + return None + expires = datetime.fromisoformat(session.expires_at) + if datetime.now(UTC) > expires: + session.status = "expired" + return None + return session + + async def is_in_scope(self, session_id: str, file_path: str) -> bool: + """Check if file_path is within the session's granted scope.""" + session = await self.get_session(session_id) + if not session: + return False + return any(file_path == granted or file_path.startswith(granted) for granted in session.granted_scope) + + def record_file_access(self, session_id: str, file_path: str, access_type: str = "read") -> None: + """Record that a file was read or written in this session.""" + session = self._sessions.get(session_id) + if not session: + return + if access_type == "write" and file_path not in session.files_written: + session.files_written.append(file_path) + elif access_type == "read" and file_path not in session.files_read: + session.files_read.append(file_path) + + async def recover_session(self, session_id: str) -> dict[str, Any] | None: + """Recover a session from persistent storage.""" + session = await self._load_session(session_id) + if not session: + return None + self._sessions[session_id] = session + log.info("session_recovered", session_id=session_id) + return { + "session_id": session.session_id, + "agent_id": session.agent_id, + "task": session.task, + "scope": session.scope, + "mode": session.mode, + "opened_at": session.opened_at, + "expires_at": session.expires_at, + "status": session.status, + } + + +# --------------------------------------------------------------------------- +# Lock Manager +# --------------------------------------------------------------------------- + + +class LockManager: + """File-level locking to prevent concurrent writes.""" + + def __init__(self, graph_store: GraphStore | None = None) -> None: + self._locks: dict[str, str] = {} + self._graph = graph_store + + def set_graph_store(self, graph_store: GraphStore) -> None: + """Set the graph store for lock persistence.""" + self._graph = graph_store + + async def acquire(self, session_id: str, files: list[str]) -> dict[str, Any]: + """Acquire locks on files for a session.""" + granted = [] + denied = [] + for f in files: + if f in self._locks: + holder = self._locks[f] + if holder == session_id: + granted.append(f) + else: + denied.append(f) + else: + self._locks[f] = session_id + granted.append(f) + if self._graph: + await self._graph.upsert_lock(f, session_id) + + log.info("locks_acquired", session_id=session_id, granted=len(granted), denied=len(denied)) + return {"granted": granted, "denied": denied} + + async def release(self, session_id: str, files: list[str]) -> None: + """Release locks held by a session.""" + for f in files: + if self._locks.get(f) == session_id: + del self._locks[f] + if self._graph: + await self._graph.release_lock(f, session_id) + log.info("locks_released", session_id=session_id, files=len(files)) + + async def release_all(self, session_id: str) -> None: + """Release all locks held by a session.""" + to_release = [f for f, sid in self._locks.items() if sid == session_id] + for f in to_release: + del self._locks[f] + if self._graph: + await self._graph.release_all_locks(session_id) + + def is_locked(self, file_path: str) -> str | None: + """Return session_id that holds the lock, or None.""" + return self._locks.get(file_path) + + +# --------------------------------------------------------------------------- +# Guard Engine +# 
--------------------------------------------------------------------------- + + +class GuardEngine: + """Pre-flight safety checks before writing a file.""" + + def __init__(self, session_manager: SessionManager, lock_manager: LockManager) -> None: + self._sessions = session_manager + self._locks = lock_manager + + async def check( + self, + session_id: str, + target: str, + intended_change: str = "", + caller_count: int = 0, + has_tests: bool = False, + test_files: list[str] | None = None, + is_public_api: bool = False, + has_downstream: bool = False, + ) -> dict[str, Any]: + """Run pre-flight checks and return verdict.""" + reasons: list[str] = [] + warnings: list[str] = [] + checks: dict[str, Any] = {} + + session = await self._sessions.get_session(session_id) + if not session: + return {"verdict": "blocked", "reasons": ["Session not found or expired"]} + + in_scope = await self._sessions.is_in_scope(session_id, target) + locked_by = self._locks.is_locked(target) + locked_by_other = locked_by is not None and locked_by != session_id + + checks["in_declared_scope"] = in_scope + checks["locked_by_other_agent"] = locked_by_other + checks["has_tests"] = has_tests + checks["test_files"] = test_files or [] + checks["caller_count"] = caller_count + checks["is_public_api"] = is_public_api + checks["has_downstream_services"] = has_downstream + + if not in_scope: + reasons.append("File is outside declared session scope") + if locked_by_other: + reasons.append(f"Locked by session {locked_by}") + + if caller_count > 5: + warnings.append(f"Target has {caller_count} callers — changes will cascade") + if is_public_api: + warnings.append("Target is part of public API — signature changes are breaking") + if not has_tests and caller_count > 0: + warnings.append("No test coverage found — manual verification recommended") + + verdict = "blocked" if reasons else "clear" + + result: dict[str, Any] = { + "verdict": verdict, + "target": target, + "checks": checks, + "warnings": warnings, + } + if reasons: + result["reasons"] = reasons + + log.info("guard_check", target=target, verdict=verdict, session_id=session_id) + return result + + +# --------------------------------------------------------------------------- +# Dry Run Simulator +# --------------------------------------------------------------------------- + + +class DryRunSimulator: + """Simulate structural impact of proposed changes without disk writes.""" + + def __init__(self) -> None: + pass + + def simulate( + self, + session_id: str, + file_path: str, + proposed_content: str, + change_summary: str = "", + current_signature: str = "", + proposed_signature: str = "", + affected_files: list[str] | None = None, + broken_callers: list[dict[str, str]] | None = None, + ) -> dict[str, Any]: + """Simulate the write and return structural delta + verdict.""" + signature_changed = bool(current_signature and proposed_signature and current_signature != proposed_signature) + + nodes_added = 0 + nodes_modified = 1 + nodes_removed = 0 + + risks: list[str] = [] + if signature_changed: + risks.append("Signature change detected — may break callers") + if affected_files: + risks.append(f"{len(affected_files)} files may need updates") + if broken_callers: + for bc in broken_callers: + risks.append( + f"{bc.get('function', '?')} in {bc.get('file', '?')}: {bc.get('reason', 'incompatible change')}" + ) + + verdict = "breaking" if (signature_changed and (broken_callers or affected_files)) else "safe" + + result: dict[str, Any] = { + "structural_delta": { + "nodes_added": 
nodes_added, + "nodes_modified": nodes_modified, + "nodes_removed": nodes_removed, + "signature_changed": signature_changed, + }, + "impact": { + "affected_files": affected_files or [], + "broken_callers": broken_callers or [], + "test_coverage_delta": "unchanged", + }, + "verdict": verdict, + "risks": risks, + } + + log.info("dryrun_complete", file_path=file_path, verdict=verdict, session_id=session_id) + return result + + +# --------------------------------------------------------------------------- +# Checkpoint Manager +# --------------------------------------------------------------------------- + + +class CheckpointManager: + """Snapshot and restore file state.""" + + def __init__(self) -> None: + self._checkpoints: dict[str, Checkpoint] = {} + + def create(self, session_id: str, files: list[str]) -> dict[str, Any]: + """Create a checkpoint by snapshotting file contents.""" + checkpoint_id = f"chk_{uuid.uuid4().hex[:6]}" + now = datetime.now(UTC).isoformat() + + snapshots: dict[str, str] = {} + snapshotted: list[str] = [] + for f in files: + try: + content = Path(f).read_text(encoding="utf-8") + snapshots[f] = content + snapshotted.append(f) + except OSError: + log.warning("checkpoint_file_unreadable", file=f) + + checkpoint = Checkpoint( + checkpoint_id=checkpoint_id, + session_id=session_id, + files=snapshots, + snapshot_at=now, + ) + self._checkpoints[checkpoint_id] = checkpoint + + log.info("checkpoint_created", checkpoint_id=checkpoint_id, files=len(snapshotted)) + return { + "checkpoint_id": checkpoint_id, + "files_snapshotted": snapshotted, + "snapshot_at": now, + } + + def rollback(self, checkpoint_id: str) -> dict[str, Any]: + """Restore files from a checkpoint.""" + checkpoint = self._checkpoints.get(checkpoint_id) + if not checkpoint: + return {"status": "error", "reason": "Checkpoint not found"} + + restored: list[str] = [] + for f, content in checkpoint.files.items(): + try: + Path(f).write_text(content, encoding="utf-8") + restored.append(f) + except OSError as exc: + log.error("rollback_write_failed", file=f, error=str(exc)) + + log.info("rollback_complete", checkpoint_id=checkpoint_id, restored=len(restored)) + return { + "status": "rolled_back", + "files_restored": restored, + "memory_resynced": True, + } + + +# --------------------------------------------------------------------------- +# Audit Logger +# --------------------------------------------------------------------------- + + +class AuditLogger: + """Persistent append-only audit log for session events.""" + + def __init__(self) -> None: + self._logs: dict[str, AuditLog] = {} + + def create_log(self, agent_id: str, task: str, session_id: str) -> str: + """Create a new audit log for a session.""" + audit_log_id = f"aud_{uuid.uuid4().hex[:6]}" + now = datetime.now(UTC).isoformat() + self._logs[audit_log_id] = AuditLog( + audit_log_id=audit_log_id, + agent_id=agent_id, + task=task, + session_id=session_id, + opened_at=now, + ) + return audit_log_id + + def append_event( + self, + audit_log_id: str, + method: str, + target: str = "", + result: str = "", + checkpoint_id: str = "", + files: list[str] | None = None, + ) -> None: + """Append an event to an audit log.""" + log_entry = self._logs.get(audit_log_id) + if not log_entry: + return + event = AuditEvent( + timestamp=datetime.now(UTC).strftime("%H:%M:%S"), + method=method, + target=target, + result=result, + checkpoint_id=checkpoint_id, + files=files or [], + ) + log_entry.events.append(event) + + def close_log(self, audit_log_id: str, status: str = 
"completed") -> None: + """Mark an audit log as closed.""" + log_entry = self._logs.get(audit_log_id) + if log_entry: + log_entry.closed_at = datetime.now(UTC).isoformat() + log_entry.status = status + + def get_log(self, audit_log_id: str) -> dict[str, Any] | None: + """Retrieve an audit log.""" + log_entry = self._logs.get(audit_log_id) + if not log_entry: + return None + return { + "audit_log_id": log_entry.audit_log_id, + "agent_id": log_entry.agent_id, + "task": log_entry.task, + "session_id": log_entry.session_id, + "opened_at": log_entry.opened_at, + "closed_at": log_entry.closed_at, + "status": log_entry.status, + "events": [ + { + "t": e.timestamp, + "method": e.method, + "target": e.target, + "result": e.result, + "checkpoint_id": e.checkpoint_id, + "files": e.files, + } + for e in log_entry.events + ], + } diff --git a/smp/engine/seed_walk.py b/smp/engine/seed_walk.py new file mode 100644 index 0000000..e77c1be --- /dev/null +++ b/smp/engine/seed_walk.py @@ -0,0 +1,443 @@ +"""SeedWalkEngine — community-routed graph RAG pipeline for smp/locate. + +Phase 0 — ROUTE: Compare query embedding against community centroids. +Phase 1 — SEED: ChromaDB vector search scoped to community or global. +Phase 2 — WALK: Graph traversal from seeds via CALLS/IMPORTS/DEFINES edges. +Phase 3 — RANK: Composite score = alpha*vector + beta*pagerank + gamma*heat. +Phase 4 — ASSEMBLE: Deduplicated results + structural map. + +No LLM calls at any phase. +""" + +from __future__ import annotations + +from collections import deque +from typing import Any + +import msgspec + +from smp.engine.interfaces import QueryEngine as QueryEngineInterface +from smp.logging import get_logger +from smp.store.interfaces import GraphStore, VectorStore + +log = get_logger(__name__) + +ALPHA = 0.50 +BETA = 0.30 +GAMMA = 0.20 +ROUTE_CONFIDENCE_THRESHOLD = 0.65 +DEFAULT_SEED_K = 3 +DEFAULT_HOPS = 2 +DEFAULT_TOP_K = 10 + + +class SeedNode(msgspec.Struct, frozen=True): + node_id: str = "" + node_type: str = "" + name: str = "" + file: str = "" + signature: str = "" + docstring: str | None = None + tags: list[str] = msgspec.field(default_factory=list) + community_id: str | None = None + vector_score: float = 0.0 + pagerank: float = 0.0 + heat_score: int = 0 + + +class WalkNode(msgspec.Struct, frozen=True): + node_id: str = "" + node_type: str = "" + name: str = "" + file: str = "" + signature: str = "" + docstring: str | None = None + community_id: str | None = None + edge_type: str = "" + edge_direction: str = "" + hop: int = 0 + is_bridge: bool = False + pagerank: float = 0.0 + heat_score: int = 0 + + +class RankedResult(msgspec.Struct, frozen=True): + node_id: str = "" + node_type: str = "" + name: str = "" + file: str = "" + signature: str = "" + docstring: str | None = None + tags: list[str] = msgspec.field(default_factory=list) + community_id: str | None = None + final_score: float = 0.0 + vector_score: float = 0.0 + pagerank: float = 0.0 + heat_score: int = 0 + is_seed: bool = False + reachable_from: list[str] = msgspec.field(default_factory=list) + + +class LocateResponse(msgspec.Struct, frozen=True): + query: str = "" + routed_community: str | None = None + seed_count: int = 0 + total_walked: int = 0 + results: list[RankedResult] = msgspec.field(default_factory=list) + structural_map: list[dict[str, Any]] = msgspec.field(default_factory=list) + + +class SeedWalkEngine(QueryEngineInterface): + """Community-routed graph RAG pipeline for smp/locate.""" + + def __init__( + self, + graph_store: GraphStore, + vector_store: 
VectorStore | None = None, + enricher: Any | None = None, + alpha: float = ALPHA, + beta: float = BETA, + gamma: float = GAMMA, + route_threshold: float = ROUTE_CONFIDENCE_THRESHOLD, + ) -> None: + self._graph = graph_store + self._vector = vector_store + self._enricher = enricher + self._alpha = alpha + self._beta = beta + self._gamma = gamma + self._route_threshold = route_threshold + + async def _route_to_community(self, query: str) -> tuple[str | None, float]: + if self._vector is None: + return None, 0.0 + try: + results = await self._vector.query( + embedding=_simple_hash_embedding(query), + top_k=1, + where={"collection_type": "centroid"}, + ) + if not results: + return None, 0.0 + best = results[0] + community_id = best.get("metadata", {}).get("community_id") + score = best.get("score", 1.0) + if isinstance(score, (int, float)): + confidence = 1.0 - float(score) + else: + confidence = 0.0 + if confidence < self._route_threshold: + return None, confidence + return community_id, confidence + except Exception: + log.warning("route_community_failed", query=query) + return None, 0.0 + + async def _seed( + self, + query: str, + seed_k: int, + community_id: str | None = None, + ) -> list[SeedNode]: + all_nodes = await self._graph.find_nodes() + terms = query.lower().split() + scored: list[tuple[float, dict[str, Any]]] = [] + for node in all_nodes: + s = 0.0 + name_lower = node.structural.name.lower() + if all(t in name_lower for t in terms): + s += 100.0 + elif any(t in name_lower for t in terms): + s += 50.0 + if node.semantic.docstring: + doc_lower = node.semantic.docstring.lower() + if all(t in doc_lower for t in terms): + s += 30.0 + elif any(t in doc_lower for t in terms): + s += 15.0 + for tag in node.semantic.tags: + if any(t in tag.lower() for t in terms): + s += 10.0 + break + if community_id and hasattr(node.semantic, "tags"): + pass + if s > 0: + scored.append((s, {"node": node, "score": s})) + + if self._vector is not None: + try: + v_results = await self._vector.query( + embedding=_simple_hash_embedding(query), + top_k=seed_k, + ) + for vr in v_results: + node_id = vr.get("id", "") + v_score = vr.get("score", 0.0) + if isinstance(v_score, (int, float)): + v_sim = 1.0 - float(v_score) + else: + v_sim = 0.0 + found = False + for s_item in scored: + if s_item[1].get("node", None) and s_item[1]["node"].id == node_id: + found = True + break + if not found and v_sim > 0.3: + gnode = await self._graph.get_node(node_id) + if gnode: + scored.append((v_sim * 80.0, {"node": gnode, "score": v_sim * 80.0})) + except Exception: + log.warning("vector_seed_failed", query=query) + + scored.sort(key=lambda x: -x[0]) + seeds: list[SeedNode] = [] + for score_val, data in scored[:seed_k]: + node = data["node"] + seeds.append( + SeedNode( + node_id=node.id, + node_type=node.type.value, + name=node.structural.name, + file=node.file_path, + signature=node.structural.signature, + docstring=node.semantic.docstring or None, + tags=node.semantic.tags, + community_id=None, + vector_score=min(score_val / 100.0, 1.0), + pagerank=0.0, + heat_score=0, + ) + ) + return seeds + + async def _walk(self, seed_ids: list[str], hops: int) -> list[WalkNode]: + from smp.core.models import EdgeType + + walked: dict[str, WalkNode] = {} + queue: deque[tuple[str, int]] = deque() + for sid in seed_ids: + queue.append((sid, 0)) + visited: set[str] = set(seed_ids) + + while queue: + current_id, depth = queue.popleft() + if depth >= hops: + continue + node = await self._graph.get_node(current_id) + if not node: + continue + 
try: + edges_out = await self._graph.get_edges(current_id, direction="outgoing") + except Exception: + edges_out = [] + try: + edges_in = await self._graph.get_edges(current_id, direction="incoming") + except Exception: + edges_in = [] + all_edges = edges_out + edges_in + for edge in all_edges: + if edge.type not in (EdgeType.CALLS, EdgeType.CALLS_RUNTIME, EdgeType.IMPORTS, EdgeType.DEFINES): + continue + neighbor_id = edge.target_id if edge.source_id == current_id else edge.source_id + direction = "out" if edge.source_id == current_id else "in" + if neighbor_id in visited: + continue + visited.add(neighbor_id) + neighbor = await self._graph.get_node(neighbor_id) + if not neighbor: + continue + walked[neighbor_id] = WalkNode( + node_id=neighbor_id, + node_type=neighbor.type.value, + name=neighbor.structural.name, + file=neighbor.file_path, + signature=neighbor.structural.signature, + docstring=neighbor.semantic.docstring or None, + community_id=None, + edge_type=edge.type.value, + edge_direction=direction, + hop=depth + 1, + is_bridge=False, + pagerank=0.0, + heat_score=0, + ) + queue.append((neighbor_id, depth + 1)) + return list(walked.values()) + + def _rank( + self, + seeds: list[SeedNode], + walked: list[WalkNode], + top_k: int, + ) -> list[RankedResult]: + seed_map = {s.node_id: s for s in seeds} + max_pr = max((s.pagerank for s in seeds), default=1.0) or 1.0 + walked_max_pr = max((w.pagerank for w in walked), default=1.0) or 1.0 + max_pr = max(max_pr, walked_max_pr) + + results: dict[str, RankedResult] = {} + for s in seeds: + score = ( + self._alpha * s.vector_score + self._beta * (s.pagerank / max_pr) + self._gamma * (s.heat_score / 100.0) + ) + results[s.node_id] = RankedResult( + node_id=s.node_id, + node_type=s.node_type, + name=s.name, + file=s.file, + signature=s.signature, + docstring=s.docstring, + tags=s.tags, + community_id=s.community_id, + final_score=round(score, 4), + vector_score=s.vector_score, + pagerank=s.pagerank, + heat_score=s.heat_score, + is_seed=True, + reachable_from=[s.node_id], + ) + + for w in walked: + if w.node_id in results: + continue + seed_pr = seed_map.get(w.node_id) + v_score = seed_pr.vector_score if seed_pr else 0.0 + score = self._alpha * v_score + self._beta * (w.pagerank / max_pr) + self._gamma * (w.heat_score / 100.0) + results[w.node_id] = RankedResult( + node_id=w.node_id, + node_type=w.node_type, + name=w.name, + file=w.file, + signature=w.signature, + docstring=w.docstring, + tags=[], + community_id=w.community_id, + final_score=round(score, 4), + vector_score=v_score, + pagerank=w.pagerank, + heat_score=w.heat_score, + is_seed=False, + reachable_from=[], + ) + + ranked = sorted(results.values(), key=lambda r: r.final_score, reverse=True) + return ranked[:top_k] + + def _build_structural_map( + self, + results: list[RankedResult], + walked: list[WalkNode], + ) -> list[dict[str, Any]]: + result_ids = {r.node_id for r in results} + edges: list[dict[str, Any]] = [] + for w in walked: + if w.node_id in result_ids: + edges.append( + { + "from": w.node_id, + "to": w.node_id, + "edge_type": w.edge_type, + "hop": w.hop, + } + ) + return edges + + async def locate( + self, + query: str, + fields: list[str] | None = None, + node_types: list[str] | None = None, + top_k: int = DEFAULT_TOP_K, + ) -> list[dict[str, Any]]: + routed_community, route_confidence = await self._route_to_community(query) + seed_k = min(top_k, DEFAULT_SEED_K) + hops = DEFAULT_HOPS + seeds = await self._seed(query, seed_k, community_id=routed_community) + if node_types: + 
seeds = [s for s in seeds if s.node_type in node_types] + walked = await self._walk([s.node_id for s in seeds], hops) + if node_types: + walked = [w for w in walked if w.node_type in node_types] + ranked = self._rank(seeds, walked, top_k) + smap = self._build_structural_map(ranked, walked) + + result = LocateResponse( + query=query, + routed_community=routed_community, + seed_count=len(seeds), + total_walked=len(walked), + results=ranked, + structural_map=smap, + ) + + return [msgspec.structs.asdict(result)] + + async def navigate(self, query: str, include_relationships: bool = True) -> dict[str, Any]: + return {} + + async def trace( + self, start: str, relationship: str = "CALLS", depth: int = 3, direction: str = "outgoing" + ) -> list[dict[str, Any]]: + return [] + + async def get_context(self, file_path: str, scope: str = "edit", depth: int = 2) -> dict[str, Any]: + return {} + + async def assess_impact(self, entity: str, change_type: str = "delete") -> dict[str, Any]: + return {} + + async def search( + self, query: str, match: str = "any", filters: dict[str, Any] | None = None, top_k: int = 5 + ) -> dict[str, Any]: + return {} + + async def conflict( + self, + entity: str, + proposed_change: str = "", + context: dict[str, Any] | None = None, + ) -> dict[str, Any]: + return {"conflicts": []} + + async def diff( + self, + from_snapshot: str, + to_snapshot: str, + scope: str = "full", + ) -> dict[str, Any]: + return {"diff": {}} + + async def plan( + self, + change_description: str, + target_file: str = "", + change_type: str = "refactor", + scope: str = "full", + ) -> dict[str, Any]: + return {"steps": []} + + async def why( + self, + entity: str, + relationship: str = "", + depth: int = 3, + ) -> dict[str, Any]: + return {"reasoning": []} + + async def find_flow(self, start: str, end: str, flow_type: str = "data") -> dict[str, Any]: + return {} + + +def _simple_hash_embedding(text: str, dim: int = 128) -> list[float]: + """Deterministic hash-based embedding for prototyping. + + Maps text to a fixed-dimension float vector using character + frequency hashing. Production should use a real embedding model. + """ + vec = [0.0] * dim + for i, ch in enumerate(text): + vec[i % dim] += float(ord(ch)) + norm = sum(v * v for v in vec) ** 0.5 + if norm == 0: + return vec + return [v / norm for v in vec] diff --git a/smp/engine/telemetry.py b/smp/engine/telemetry.py new file mode 100644 index 0000000..ba55c8f --- /dev/null +++ b/smp/engine/telemetry.py @@ -0,0 +1,161 @@ +"""Telemetry engine for tracking node hotness and usage patterns. + +Collects runtime statistics to identify hot code paths and frequently +accessed nodes for optimization and safety decisions. 
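Before the telemetry module continues, a quick sketch of how the pieces of `QueryEngine.locate` defined above compose. This is an illustration only: `rank_score` is a hypothetical stand-alone mirror of `_rank`'s per-node formula, and the 0.5/0.3/0.2 weights are placeholders, not the module's actual `ALPHA`/`BETA`/`GAMMA` constants.

```python
from smp.engine.query import _simple_hash_embedding

# _simple_hash_embedding is deterministic and unit-norm, so identical
# queries always route and seed identically.
e1 = _simple_hash_embedding("user repository find")
e2 = _simple_hash_embedding("user repository find")
assert e1 == e2

# Hypothetical mirror of _rank's scoring line; example weights only.
def rank_score(vector_score: float, pagerank: float, max_pr: float, heat: int) -> float:
    alpha, beta, gamma = 0.5, 0.3, 0.2
    return round(alpha * vector_score + beta * (pagerank / max_pr) + gamma * (heat / 100.0), 4)

# A seed with strong vector similarity but modest graph authority:
print(rank_score(vector_score=0.9, pagerank=0.01, max_pr=0.05, heat=20))  # 0.55
```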
+""" + +from __future__ import annotations + +import time +from dataclasses import dataclass, field +from datetime import UTC, datetime +from typing import Any + +from smp.logging import get_logger + +log = get_logger(__name__) + +_HOT_THRESHOLD = 10 +_HOT_DECAY_SECONDS = 3600 + + +@dataclass +class NodeStats: + """Statistics for a single node.""" + + node_id: str + hit_count: int = 0 + last_hit_at: str = "" + avg_response_time_ms: float = 0.0 + error_count: int = 0 + callers: set[str] = field(default_factory=set) + + def touch(self) -> None: + """Record a hit on this node.""" + self.hit_count += 1 + self.last_hit_at = datetime.now(UTC).isoformat() + + +@dataclass +class TelemetryConfig: + """Configuration for telemetry collection.""" + + hot_threshold: int = _HOT_THRESHOLD + decay_seconds: int = _HOT_DECAY_SECONDS + max_tracked_nodes: int = 10000 + + +class TelemetryEngine: + """Tracks node access patterns and identifies hot nodes.""" + + def __init__(self, config: TelemetryConfig | None = None) -> None: + self._config = config or TelemetryConfig() + self._stats: dict[str, NodeStats] = {} + self._start_time = time.time() + + def record_access( + self, + node_id: str, + caller_id: str | None = None, + response_time_ms: float = 0.0, + error: bool = False, + ) -> None: + """Record an access to a node.""" + stats = self._stats.get(node_id) + if not stats: + if len(self._stats) >= self._config.max_tracked_nodes: + self._evict_cold() + stats = NodeStats(node_id=node_id) + self._stats[node_id] = stats + + stats.touch() + if caller_id: + stats.callers.add(caller_id) + if response_time_ms > 0: + total = stats.avg_response_time_ms * (stats.hit_count - 1) + response_time_ms + stats.avg_response_time_ms = total / stats.hit_count + if error: + stats.error_count += 1 + + log.debug("telemetry_access", node_id=node_id, hit_count=stats.hit_count) + + def get_hot_nodes(self, threshold: int | None = None) -> list[dict[str, Any]]: + """Return nodes exceeding the hot threshold.""" + hot_threshold = threshold or self._config.hot_threshold + hot = [] + + for node_id, stats in self._stats.items(): + if stats.hit_count >= hot_threshold: + hot.append( + { + "node_id": node_id, + "hit_count": stats.hit_count, + "last_hit_at": stats.last_hit_at, + "avg_response_time_ms": stats.avg_response_time_ms, + "error_count": stats.error_count, + "caller_count": len(stats.callers), + } + ) + + hot.sort(key=lambda x: -int(x["hit_count"])) + return hot + + def get_stats(self, node_id: str) -> dict[str, Any] | None: + """Get statistics for a specific node.""" + stats = self._stats.get(node_id) + if not stats: + return None + return { + "node_id": stats.node_id, + "hit_count": stats.hit_count, + "last_hit_at": stats.last_hit_at, + "avg_response_time_ms": stats.avg_response_time_ms, + "error_count": stats.error_count, + "caller_count": len(stats.callers), + } + + def get_summary(self) -> dict[str, Any]: + """Return overall telemetry summary.""" + total_hits = sum(s.hit_count for s in self._stats.values()) + total_errors = sum(s.error_count for s in self._stats.values()) + + return { + "uptime_seconds": int(time.time() - self._start_time), + "total_nodes_tracked": len(self._stats), + "total_hits": total_hits, + "total_errors": total_errors, + "hot_node_count": len(self.get_hot_nodes()), + } + + def decay(self) -> int: + """Decay old statistics to prevent unbounded growth.""" + cutoff = datetime.now(UTC).timestamp() - self._config.decay_seconds + cutoff_str = datetime.fromtimestamp(cutoff, tz=UTC).isoformat() + + to_remove = [ + 
node_id for node_id, stats in self._stats.items() if stats.last_hit_at and stats.last_hit_at < cutoff_str + ] + + for node_id in to_remove: + del self._stats[node_id] + + if to_remove: + log.info("telemetry_decayed", removed=len(to_remove)) + return len(to_remove) + + def _evict_cold(self) -> None: + """Evict the coldest nodes when at capacity.""" + if not self._stats: + return + + sorted_nodes = sorted(self._stats.items(), key=lambda x: x[1].hit_count) + for node_id, _ in sorted_nodes[: len(sorted_nodes) // 10]: + del self._stats[node_id] + + log.debug("telemetry_evicted", count=len(sorted_nodes) // 10) + + def reset(self) -> None: + """Clear all telemetry data.""" + self._stats.clear() + self._start_time = time.time() + log.info("telemetry_reset") diff --git a/smp/logging.py b/smp/logging.py new file mode 100644 index 0000000..bb899f5 --- /dev/null +++ b/smp/logging.py @@ -0,0 +1,68 @@ +"""Structured logging configuration for SMP. + +Usage: + from smp.logging import get_logger + log = get_logger(__name__) + log.info("graph_updated", nodes=42, edges=97) +""" + +from __future__ import annotations + +import logging +import sys + +import structlog + + +def configure_logging(*, json: bool = False, level: str = "INFO") -> None: + """Initialise structlog + stdlib logging. + + Args: + json: When True, render as newline-delimited JSON (production). + When False, render with colours (development). + level: Minimum log level for the root SMP logger. + """ + shared_processors: list[structlog.types.Processor] = [ + structlog.contextvars.merge_contextvars, + structlog.stdlib.add_log_level, + structlog.stdlib.add_logger_name, + structlog.processors.TimeStamper(fmt="iso"), + structlog.processors.StackInfoRenderer(), + structlog.processors.UnicodeDecoder(), + ] + + if json: + renderer: structlog.types.Processor = structlog.processors.JSONRenderer() + else: + renderer = structlog.dev.ConsoleRenderer(colors=True) + + structlog.configure( + processors=[*shared_processors, structlog.stdlib.ProcessorFormatter.wrap_for_formatter], + logger_factory=structlog.stdlib.LoggerFactory(), + wrapper_class=structlog.stdlib.BoundLogger, + cache_logger_on_first_use=True, + ) + + formatter = structlog.stdlib.ProcessorFormatter( + processors=[ + structlog.stdlib.ProcessorFormatter.remove_processors_meta, + renderer, + ], + ) + + handler = logging.StreamHandler(sys.stderr) + handler.setFormatter(formatter) + + root = logging.getLogger("smp") + root.handlers.clear() + root.addHandler(handler) + root.setLevel(level.upper()) + + +def get_logger(name: str) -> structlog.stdlib.BoundLogger: + """Return a bound structlog logger scoped to *name*.""" + return structlog.get_logger(name) + + +# Auto-configure with dev defaults on first import. 
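A minimal usage sketch for the `TelemetryEngine` above; the node and caller IDs are made up for illustration.

```python
from smp.engine.telemetry import TelemetryConfig, TelemetryEngine

engine = TelemetryEngine(TelemetryConfig(hot_threshold=3))

# Four hits on one node, one on another (IDs are hypothetical).
for _ in range(4):
    engine.record_access(
        "auth.py::function::login::10",
        caller_id="api.py::function::post_login::55",
        response_time_ms=12.0,
    )
engine.record_access("db.py::function::connect::5")

hot = engine.get_hot_nodes()   # only the login node crosses hot_threshold=3
assert hot[0]["hit_count"] == 4
assert engine.get_summary()["total_hits"] == 5
```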
+configure_logging() diff --git a/smp/parser/__init__.py b/smp/parser/__init__.py new file mode 100644 index 0000000..e45fa64 --- /dev/null +++ b/smp/parser/__init__.py @@ -0,0 +1,11 @@ +"""Parser layer — AST extraction via tree-sitter.""" + +from smp.parser.base import TreeSitterParser, detect_language, make_node_id +from smp.parser.registry import ParserRegistry + +__all__ = [ + "TreeSitterParser", + "detect_language", + "make_node_id", + "ParserRegistry", +] diff --git a/smp/parser/base.py b/smp/parser/base.py new file mode 100644 index 0000000..cfddd40 --- /dev/null +++ b/smp/parser/base.py @@ -0,0 +1,153 @@ +"""Abstract tree-sitter parser and language detection utilities.""" + +from __future__ import annotations + +import abc +from pathlib import Path + +import tree_sitter as ts + +from smp.core.models import Document, GraphEdge, GraphNode, Language, NodeType, ParseError +from smp.engine.interfaces import Parser +from smp.logging import get_logger + +log = get_logger(__name__) + +_EXT_TO_LANG: dict[str, Language] = { + ".py": Language.PYTHON, + ".ts": Language.TYPESCRIPT, + ".tsx": Language.TYPESCRIPT, + ".js": Language.TYPESCRIPT, + ".jsx": Language.TYPESCRIPT, +} + +# Extensions that use the TSX grammar variant +_TSX_EXTS = {".tsx", ".jsx"} + + +def detect_language(file_path: str) -> Language: + """Guess language from file extension.""" + suffix = Path(file_path).suffix.lower() + return _EXT_TO_LANG.get(suffix, Language.UNKNOWN) + + +def is_tsx(file_path: str) -> bool: + """Return True if the file uses JSX/TSX syntax.""" + return Path(file_path).suffix.lower() in _TSX_EXTS + + +def make_node_id(file_path: str, type: NodeType, name: str, start_line: int) -> str: + """Deterministic node ID from structural coordinates.""" + return f"{file_path}::{type.value}::{name}::{start_line}" + + +def node_text(node: ts.Node) -> str: + """Safely extract text from a tree-sitter node.""" + if node.text: + return node.text.decode("utf-8", errors="replace") + return "" + + +def line_range(node: ts.Node) -> tuple[int, int]: + """Return (start_line, end_line) as 1-indexed line numbers.""" + return node.start_point[0] + 1, node.end_point[0] + 1 + + +class TreeSitterParser(Parser, abc.ABC): + """Abstract base for tree-sitter language parsers. + + Subclasses provide the grammar language object and extraction logic. + The base class handles parsing, error recovery, and Document assembly. + """ + + @abc.abstractmethod + def _language(self, file_path: str) -> ts.Language: + """Return the tree-sitter Language object for *file_path*.""" + + @abc.abstractmethod + def _extract( + self, + root_node: ts.Node, + source_bytes: bytes, + file_path: str, + ) -> tuple[list[GraphNode], list[GraphEdge], list[ParseError]]: + """Extract nodes, edges, and errors from a parsed AST. + + Returns a tuple of (nodes, edges, errors). + """ + + @property + @abc.abstractmethod + def supported_languages(self) -> list[str]: ... 
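The module-level helpers in `smp/parser/base.py` above are easy to exercise directly. A small sketch; the enum string values shown are assumptions about `smp/core/models.py`, which defines `Language` and `NodeType`.

```python
from smp.core.models import NodeType
from smp.parser.base import detect_language, is_tsx, make_node_id

lang = detect_language("src/auth/manager.py")
print(lang.value)                      # expected: "python" (enum value assumed)

assert is_tsx("ui/App.tsx") and not is_tsx("lib/util.ts")

# IDs are deterministic functions of structural coordinates, so re-parsing
# the same file yields the same node identities.
nid = make_node_id("src/auth/manager.py", NodeType.FUNCTION, "login", 42)
print(nid)                             # e.g. src/auth/manager.py::function::login::42
```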
+ + def parse(self, source: str, file_path: str) -> Document: + lang = detect_language(file_path) + if lang == Language.UNKNOWN: + return Document( + file_path=file_path, + language=lang, + errors=[ParseError(message=f"Unsupported language for {file_path}")], + ) + + source_bytes = source.encode("utf-8") + + try: + ts_lang = self._language(file_path) + parser = ts.Parser(ts_lang) + tree = parser.parse(source_bytes) + except Exception as exc: + log.error("parse_crash", file_path=file_path, error=str(exc)) + return Document( + file_path=file_path, + language=lang, + errors=[ParseError(message=f"Parser crash: {exc}")], + ) + + errors: list[ParseError] = [] + nodes: list[GraphNode] = [] + edges: list[GraphEdge] = [] + + try: + nodes, edges, errors = self._extract(tree.root_node, source_bytes, file_path) + except Exception as exc: + log.error("extract_error", file_path=file_path, error=str(exc)) + errors.append(ParseError(message=f"Extraction error: {exc}")) + + # Detect tree-sitter error nodes + self._collect_syntax_errors(tree.root_node, source_bytes, errors) + + log.debug( + "file_parsed", + file_path=file_path, + lang=lang.value, + nodes=len(nodes), + edges=len(edges), + errors=len(errors), + ) + return Document( + file_path=file_path, + language=lang, + nodes=nodes, + edges=edges, + errors=errors, + ) + + @staticmethod + def _collect_syntax_errors( + node: ts.Node, + source: bytes, + errors: list[ParseError], + ) -> None: + """Walk the tree and collect ERROR / MISSING nodes.""" + if node.is_error or node.is_missing: + row, col = node.start_point + text = node.text.decode("utf-8", errors="replace")[:80] if node.text else "" + errors.append( + ParseError( + message=f"Syntax {'missing' if node.is_missing else 'error'}: {text}", + line=row + 1, + column=col, + ) + ) + for child in node.children: + TreeSitterParser._collect_syntax_errors(child, source, errors) diff --git a/smp/parser/python_parser.py b/smp/parser/python_parser.py new file mode 100644 index 0000000..16a72f2 --- /dev/null +++ b/smp/parser/python_parser.py @@ -0,0 +1,553 @@ +"""Python-specific tree-sitter parser. + +Extracts functions, classes, methods, imports, decorators, inline comments, +and type annotations from Python source using the ``tree-sitter-python`` grammar. 
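Note that `TreeSitterParser.parse` above never raises on bad input: parser crashes, extraction failures, and tree-sitter ERROR/MISSING nodes all land in `Document.errors`. A consumption sketch using the `PythonParser` defined below, assuming `ParseError.line` defaults to a falsy value when no location is known:

```python
from smp.parser.python_parser import PythonParser

doc = PythonParser().parse("def broken(:\n    pass\n", "bad.py")
for err in doc.errors:
    # Syntax errors carry a 1-indexed line; parser/extraction crashes may not.
    print(f"{doc.file_path}:{err.line or '?'}: {err.message}")
```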
+""" + +from __future__ import annotations + +import re + +import tree_sitter as ts +import tree_sitter_python as tsp + +from smp.core.models import ( + Annotations, + EdgeType, + GraphEdge, + GraphNode, + NodeType, + ParseError, + SemanticProperties, + StructuralProperties, +) +from smp.logging import get_logger +from smp.parser.base import TreeSitterParser, line_range, make_node_id, node_text + +log = get_logger(__name__) + +_LANGUAGE = ts.Language(tsp.language()) + +_CALL_QUERY = ts.Query( + _LANGUAGE, + """ +(call function: (identifier) @callee) @call +(call function: (attribute) @callee) @call +""", +) + +_COMMENT_QUERY = ts.Query( + _LANGUAGE, + """ +(comment) @comment +""", +) + + +def _compute_complexity(body: ts.Node) -> int: + """Estimate cyclomatic complexity from AST body node.""" + complexity = 1 + cursor = body.walk() + stack: list[ts.Node] = [cursor.node] if cursor.node else [] + while stack: + node = stack.pop() + if node.type in ( + "if_statement", + "elif_clause", + "for_statement", + "while_statement", + "conditional_expression", + "boolean_operator", + ): + complexity += 1 + for child in node.children: + stack.append(child) + return complexity + + +class PythonParser(TreeSitterParser): + """Extract structural elements from Python source.""" + + @property + def supported_languages(self) -> list[str]: + return ["python"] + + def _language(self, file_path: str) -> ts.Language: + return _LANGUAGE + + def _extract( + self, + root_node: ts.Node, + source_bytes: bytes, + file_path: str, + ) -> tuple[list[GraphNode], list[GraphEdge], list[ParseError]]: + nodes: list[GraphNode] = [] + edges: list[GraphEdge] = [] + errors: list[ParseError] = [] + seen_ids: set[str] = set() + + file_node = GraphNode( + id=make_node_id(file_path, NodeType.FILE, file_path, 1), + type=NodeType.FILE, + file_path=file_path, + structural=StructuralProperties( + name=file_path, + file=file_path, + start_line=1, + end_line=root_node.end_point[0] + 1, + lines=root_node.end_point[0] + 1, + ), + ) + self._add_node(file_node, nodes, seen_ids) + + self._walk_block( + root_node, + source_bytes, + file_path, + parent_id=file_node.id, + class_name=None, + nodes=nodes, + edges=edges, + errors=errors, + seen_ids=seen_ids, + ) + log.debug("python_parsed", file=file_path, nodes=len(nodes), edges=len(edges), errors=len(errors)) + return nodes, edges, errors + + def _add_node(self, node: GraphNode, nodes: list[GraphNode], seen: set[str]) -> bool: + """Add node if not already seen. 
Returns True if added."""
+        if node.id in seen:
+            return False
+        seen.add(node.id)
+        nodes.append(node)
+        return True
+
+    def _walk_block(
+        self,
+        block: ts.Node,
+        source: bytes,
+        file_path: str,
+        parent_id: str,
+        class_name: str | None,
+        nodes: list[GraphNode],
+        edges: list[GraphEdge],
+        errors: list[ParseError],
+        seen_ids: set[str],
+    ) -> None:
+        """Walk children of a block extracting definitions."""
+        self._walk_direct_children(
+            block,
+            source,
+            file_path,
+            parent_id,
+            class_name,
+            nodes,
+            edges,
+            errors,
+            seen_ids,
+        )
+
+    def _walk_direct_children(
+        self,
+        block: ts.Node,
+        source: bytes,
+        file_path: str,
+        parent_id: str,
+        class_name: str | None,
+        nodes: list[GraphNode],
+        edges: list[GraphEdge],
+        errors: list[ParseError],
+        seen_ids: set[str],
+    ) -> None:
+        """Walk direct children of a block, processing definitions."""
+        for child in block.children:
+            if child.type == "function_definition":
+                self._process_function(
+                    child,
+                    source,
+                    file_path,
+                    parent_id,
+                    class_name,
+                    nodes,
+                    edges,
+                    errors,
+                    seen_ids,
+                    [],
+                )
+            elif child.type == "class_definition":
+                self._process_class(
+                    child,
+                    source,
+                    file_path,
+                    parent_id,
+                    nodes,
+                    edges,
+                    errors,
+                    seen_ids,
+                    [],
+                )
+            elif child.type == "decorated_definition":
+                decorator_names = self._extract_decorators(child, source)
+                for sub in child.children:
+                    if sub.type == "function_definition":
+                        self._process_function(
+                            sub,
+                            source,
+                            file_path,
+                            parent_id,
+                            class_name,
+                            nodes,
+                            edges,
+                            errors,
+                            seen_ids,
+                            decorator_names,
+                        )
+                        break
+                    elif sub.type == "class_definition":
+                        self._process_class(
+                            sub,
+                            source,
+                            file_path,
+                            parent_id,
+                            nodes,
+                            edges,
+                            errors,
+                            seen_ids,
+                            decorator_names,
+                        )
+                        break
+            elif child.type in ("import_statement", "import_from_statement"):
+                self._process_import(child, source, file_path, parent_id, nodes, edges)
+            elif child.type == "expression_statement":
+                self._process_assignment(child, source, file_path, parent_id, nodes, edges)
+
+    def _process_function(
+        self,
+        func: ts.Node,
+        source: bytes,
+        file_path: str,
+        parent_id: str,
+        class_name: str | None,
+        nodes: list[GraphNode],
+        edges: list[GraphEdge],
+        errors: list[ParseError],
+        seen_ids: set[str],
+        decorator_names: list[str],
+    ) -> None:
+        name_node = func.child_by_field_name("name")
+        if not name_node:
+            return
+        name = node_text(name_node)
+        start, end = line_range(func)
+        # Methods are modelled as FUNCTION nodes; no separate METHOD type is used.
+        node_type = NodeType.FUNCTION
+        sig = self._extract_signature(func, source, name)
+        docstring = self._extract_docstring(func, source)
+        annotations = self._extract_annotations(func, source)
+        node_id = make_node_id(file_path, node_type, name, start)
+
+        body = func.child_by_field_name("body")
+        complexity = _compute_complexity(body) if body else 1
+        lines = end - start + 1
+        param_count = len(annotations.params) if annotations else 0
+
+        structural = StructuralProperties(
+            name=name,
+            file=file_path,
+            signature=sig,
+            start_line=start,
+            end_line=end,
+            complexity=complexity,
+            lines=lines,
+            parameters=param_count,
+        )
+
+        semantic = SemanticProperties(
+            docstring=docstring,
+            decorators=decorator_names,
+            annotations=annotations,
+        )
+
+        node = GraphNode(
+            id=node_id,
+            type=node_type,
+            file_path=file_path,
+            structural=structural,
+            semantic=semantic,
+        )
+        if not self._add_node(node, nodes, seen_ids):
+            return
+
+        edges.append(GraphEdge(source_id=parent_id, target_id=node_id,
type=EdgeType.DEFINES)) + + if body: + self._extract_calls(body, source, file_path, node_id, nodes, edges) + + def _process_class( + self, + cls: ts.Node, + source: bytes, + file_path: str, + parent_id: str, + nodes: list[GraphNode], + edges: list[GraphEdge], + errors: list[ParseError], + seen_ids: set[str], + decorator_names: list[str], + ) -> None: + name_node = cls.child_by_field_name("name") + if not name_node: + return + name = node_text(name_node) + start, end = line_range(cls) + docstring = self._extract_docstring(cls, source) + bases = self._extract_bases(cls, source) + sig = f"class {name}" + if bases: + sig += f"({', '.join(bases)})" + node_id = make_node_id(file_path, NodeType.CLASS, name, start) + + structural = StructuralProperties( + name=name, + file=file_path, + signature=sig, + start_line=start, + end_line=end, + lines=end - start + 1, + ) + + semantic = SemanticProperties( + docstring=docstring, + decorators=decorator_names, + ) + + node = GraphNode( + id=node_id, + type=NodeType.CLASS, + file_path=file_path, + structural=structural, + semantic=semantic, + ) + if not self._add_node(node, nodes, seen_ids): + return + + edges.append(GraphEdge(source_id=parent_id, target_id=node_id, type=EdgeType.DEFINES)) + + for base in bases: + base_id = make_node_id(file_path, NodeType.INTERFACE, base, 0) + edges.append(GraphEdge(source_id=node_id, target_id=base_id, type=EdgeType.IMPLEMENTS)) + + body = cls.child_by_field_name("body") + if body: + self._walk_block( + body, + source, + file_path, + parent_id=node_id, + class_name=name, + nodes=nodes, + edges=edges, + errors=errors, + seen_ids=seen_ids, + ) + + def _process_assignment( + self, + expr: ts.Node, + source: bytes, + file_path: str, + parent_id: str, + nodes: list[GraphNode], + edges: list[GraphEdge], + ) -> None: + """Process top-level variable assignments.""" + for child in expr.children: + if child.type in ("assignment", "type_alias_statement"): + start, end = line_range(child) + left = child.child_by_field_name("left") or child.child_by_field_name("name") + if not left: + continue + name = node_text(left) + if not name or name.startswith("_"): + continue + node_id = make_node_id(file_path, NodeType.VARIABLE, name, start) + structural = StructuralProperties( + name=name, + file=file_path, + signature=node_text(child), + start_line=start, + end_line=end, + lines=end - start + 1, + ) + node = GraphNode( + id=node_id, + type=NodeType.VARIABLE, + file_path=file_path, + structural=structural, + ) + nodes.append(node) + edges.append(GraphEdge(source_id=parent_id, target_id=node_id, type=EdgeType.DEFINES)) + + def _process_import( + self, + imp: ts.Node, + source: bytes, + file_path: str, + parent_id: str, + nodes: list[GraphNode], + edges: list[GraphEdge], + ) -> None: + start, end = line_range(imp) + text = node_text(imp).strip() + if imp.type == "import_from_statement": + module_name_node = imp.child_by_field_name("module_name") + module = node_text(module_name_node) if module_name_node else text + else: + module = text.replace("import ", "").split(",")[0].strip() + + node_id = make_node_id(file_path, NodeType.FILE, module, start) + structural = StructuralProperties( + name=module, + file=file_path, + signature=text, + start_line=start, + end_line=end, + lines=end - start + 1, + ) + node = GraphNode( + id=node_id, + type=NodeType.FILE, + file_path=file_path, + structural=structural, + ) + nodes.append(node) + edges.append(GraphEdge(source_id=parent_id, target_id=node_id, type=EdgeType.IMPORTS)) + + def _extract_calls( + self, + 
body: ts.Node, + source: bytes, + file_path: str, + caller_id: str, + nodes: list[GraphNode], + edges: list[GraphEdge], + ) -> None: + cursor = ts.QueryCursor(_CALL_QUERY) + seen_edges: set[tuple[str, str]] = set() + for _, caps in cursor.matches(body): + call_nodes = caps.get("call") + callee_nodes = caps.get("callee") + if not callee_nodes or not call_nodes: + continue + callee_name = node_text(callee_nodes[0]) + call_node = call_nodes[0] + start, _ = line_range(call_node) + target_id = make_node_id(file_path, NodeType.FUNCTION, callee_name, 0) + edge_key = (caller_id, target_id) + if edge_key in seen_edges: + continue + seen_edges.add(edge_key) + edges.append( + GraphEdge( + source_id=caller_id, + target_id=target_id, + type=EdgeType.CALLS, + metadata={"line": str(start)}, + ) + ) + + def _extract_decorators(self, decorated: ts.Node, source: bytes) -> list[str]: + names: list[str] = [] + for child in decorated.children: + if child.type == "decorator": + text = node_text(child).lstrip("@").strip() + if "(" in text: + text = text[: text.index("(")] + names.append(text) + return names + + def _extract_bases(self, cls: ts.Node, source: bytes) -> list[str]: + bases: list[str] = [] + arg_list = cls.child_by_field_name("superclasses") + if not arg_list: + for child in cls.children: + if child.type == "argument_list": + arg_list = child + break + if arg_list: + for child in arg_list.children: + if child.type == "identifier": + bases.append(node_text(child)) + return bases + + def _extract_signature(self, func: ts.Node, source: bytes, name: str) -> str: + params = func.child_by_field_name("parameters") + param_text = node_text(params) if params else "()" + return_type = "" + for child in func.children: + if child.type == "type": + return_type = f" -> {node_text(child)}" + break + return f"def {name}{param_text}{return_type}" + + def _extract_annotations(self, func: ts.Node, source: bytes) -> Annotations: + """Extract structured type annotations from a function.""" + params_dict: dict[str, str] = {} + returns: str | None = None + throws: list[str] = [] + + params_node = func.child_by_field_name("parameters") + if params_node: + for child in params_node.children: + if child.type == "identifier": + pname = node_text(child) + if pname in ("self", "cls"): + continue + params_dict[pname] = "Any" + elif child.type == "typed_parameter": + # In tree-sitter-python, typed_parameter has 'identifier' and 'type' as direct children + ident = None + type_node = None + for sub in child.children: + if sub.type == "identifier": + ident = sub + elif sub.type == "type": + type_node = sub + pname = node_text(ident) if ident else "" + ptype = node_text(type_node) if type_node else "Any" + if pname and pname not in ("self", "cls"): + params_dict[pname] = ptype + + for child in func.children: + if child.type == "type": + returns = node_text(child) + break + + body = func.child_by_field_name("body") + if body: + body_text = node_text(body) + raise_matches = re.findall(r"raise\s+(\w+)", body_text) + throws = list(dict.fromkeys(raise_matches)) + + return Annotations(params=params_dict, returns=returns, throws=throws) + + def _extract_docstring(self, func_or_class: ts.Node, source: bytes) -> str: + body = func_or_class.child_by_field_name("body") + if not body: + return "" + for child in body.children: + if child.type == "expression_statement": + for sub in child.children: + if sub.type == "string": + text = node_text(sub) + for quote in ('"""', "'''", '"', "'"): + if text.startswith(quote) and text.endswith(quote): + text 
= text[len(quote) : -len(quote)] + break + return text.strip() + else: + break + return "" diff --git a/smp/parser/registry.py b/smp/parser/registry.py new file mode 100644 index 0000000..9c4e59f --- /dev/null +++ b/smp/parser/registry.py @@ -0,0 +1,72 @@ +"""Parser registry — dispatches to the correct language parser.""" + +from __future__ import annotations + +from pathlib import Path + +from smp.core.models import Document, Language +from smp.logging import get_logger +from smp.parser.base import TreeSitterParser, detect_language + +log = get_logger(__name__) + + +class ParserRegistry: + """Lazy-initialised registry of language-specific parsers.""" + + def __init__(self) -> None: + self._parsers: dict[Language, TreeSitterParser] = {} + + def _ensure_parser(self, language: Language) -> TreeSitterParser | None: + if language in self._parsers: + return self._parsers[language] + + parser: TreeSitterParser | None = None + + if language == Language.PYTHON: + from smp.parser.python_parser import PythonParser + + parser = PythonParser() + elif language == Language.TYPESCRIPT: + from smp.parser.typescript_parser import TypeScriptParser + + parser = TypeScriptParser() + + if parser: + self._parsers[language] = parser + log.debug("parser_registered", language=language.value) + return parser + + def get(self, language: Language) -> TreeSitterParser | None: + """Return the parser for *language*, or ``None`` if unsupported.""" + return self._ensure_parser(language) + + def parse_file(self, file_path: str) -> Document: + """Detect language, read file, and parse. + + Returns a Document with nodes, edges, and errors. + """ + lang = detect_language(file_path) + parser = self.get(lang) + if not parser: + from smp.core.models import ParseError + + return Document( + file_path=file_path, + language=lang, + errors=[ParseError(message=f"No parser available for {lang.value}")], + ) + + try: + source = Path(file_path).read_text(encoding="utf-8", errors="replace") + except OSError as exc: + from smp.core.models import ParseError + + log.error("file_read_error", file_path=file_path, error=str(exc)) + return Document( + file_path=file_path, + language=lang, + errors=[ParseError(message=f"Cannot read file: {exc}")], + ) + + return parser.parse(source, file_path) diff --git a/smp/parser/typescript_parser.py b/smp/parser/typescript_parser.py new file mode 100644 index 0000000..a339f05 --- /dev/null +++ b/smp/parser/typescript_parser.py @@ -0,0 +1,525 @@ +"""TypeScript-specific tree-sitter parser. + +Extracts functions, classes, interfaces, methods, imports, arrow functions, +and call edges from TypeScript / TSX source using ``tree-sitter-typescript``. +Updated for SMP(3) partitioned model. 
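The registry just defined ties language detection and lazy parser construction together. A usage sketch; the `"unknown"` enum value is an assumption about `Language` in `smp/core/models.py`.

```python
from smp.parser.registry import ParserRegistry

registry = ParserRegistry()

# Language is inferred from the extension; parsers are built lazily and cached.
doc = registry.parse_file("smp/engine/query.py")
print(doc.language.value, len(doc.nodes), len(doc.edges))

# Unsupported extensions degrade to an error-carrying Document instead of raising.
miss = registry.parse_file("README.md")
print(miss.errors[0].message)   # "No parser available for ..."
```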
+""" + +from __future__ import annotations + +import tree_sitter as ts +import tree_sitter_typescript as tst + +from smp.core.models import ( + EdgeType, + GraphEdge, + GraphNode, + NodeType, + ParseError, + StructuralProperties, +) +from smp.logging import get_logger +from smp.parser.base import TreeSitterParser, is_tsx, line_range, make_node_id, node_text + +log = get_logger(__name__) + +_TS_LANG = ts.Language(tst.language_typescript()) +_TSX_LANG = ts.Language(tst.language_tsx()) + +_QUERY_STRINGS = { + "top": """ +(function_declaration name: (identifier) @name) @func +(class_declaration name: (type_identifier) @name) @class +(interface_declaration name: (type_identifier) @name) @interface +(import_statement) @import +(export_statement) @export +""", + "arrow": """ +(lexical_declaration (variable_declarator name: (identifier) @name value: (arrow_function) @arrow)) @var +""", + "method": """ +(method_definition name: (property_identifier) @name) @method +""", + "call": """ +(call_expression function: (identifier) @callee) @call +(call_expression function: (member_expression property: (property_identifier) @callee)) @call +""", +} + +_query_cache: dict[str, dict[str, ts.Query]] = {"ts": {}, "tsx": {}} + + +def _get_queries(lang: ts.Language) -> dict[str, ts.Query]: + key = "tsx" if lang is _TSX_LANG else "ts" + if not _query_cache[key]: + for name, qstr in _QUERY_STRINGS.items(): + _query_cache[key][name] = ts.Query(lang, qstr) + return _query_cache[key] + + +class TypeScriptParser(TreeSitterParser): + """Extract structural elements from TypeScript / TSX source.""" + + @property + def supported_languages(self) -> list[str]: + return ["typescript"] + + def _language(self, file_path: str) -> ts.Language: + return _TSX_LANG if is_tsx(file_path) else _TS_LANG + + def _extract( + self, + root_node: ts.Node, + source_bytes: bytes, + file_path: str, + ) -> tuple[list[GraphNode], list[GraphEdge], list[ParseError]]: + nodes: list[GraphNode] = [] + edges: list[GraphEdge] = [] + errors: list[ParseError] = [] + seen_ids: set[str] = set() + + file_node = GraphNode( + id=make_node_id(file_path, NodeType.FILE, file_path, 1), + type=NodeType.FILE, + file_path=file_path, + structural=StructuralProperties( + name=file_path, + file=file_path, + start_line=1, + end_line=root_node.end_point[0] + 1, + lines=root_node.end_point[0] + 1, + ), + ) + self._add_node(file_node, nodes, seen_ids) + + self._walk_block( + root_node, + source_bytes, + file_path, + self._language(file_path), + parent_id=file_node.id, + class_name=None, + nodes=nodes, + edges=edges, + errors=errors, + seen_ids=seen_ids, + ) + log.debug("typescript_parsed", file=file_path, nodes=len(nodes), edges=len(edges), errors=len(errors)) + return nodes, edges, errors + + def _add_node(self, node: GraphNode, nodes: list[GraphNode], seen: set[str]) -> bool: + if node.id in seen: + return False + seen.add(node.id) + nodes.append(node) + return True + + def _walk_block( + self, + block: ts.Node, + source: bytes, + file_path: str, + lang: ts.Language, + parent_id: str, + class_name: str | None, + nodes: list[GraphNode], + edges: list[GraphEdge], + errors: list[ParseError], + seen_ids: set[str], + ) -> None: + queries = _get_queries(lang) + cursor = ts.QueryCursor(queries["top"]) + for _idx, caps in cursor.matches(block): + func_nodes = caps.get("func") + class_nodes = caps.get("class") + iface_nodes = caps.get("interface") + import_nodes = caps.get("import") + export_nodes = caps.get("export") + + if func_nodes: + self._process_function( + func_nodes[0], + 
source, + file_path, + parent_id, + class_name, + nodes, + edges, + seen_ids, + ) + continue + + if class_nodes: + self._process_class( + class_nodes[0], + source, + file_path, + lang, + parent_id, + nodes, + edges, + errors, + seen_ids, + ) + continue + + if iface_nodes: + self._process_interface(iface_nodes[0], source, file_path, parent_id, nodes, edges, seen_ids) + continue + + if import_nodes: + self._process_import(import_nodes[0], source, file_path, parent_id, nodes, edges) + continue + + if export_nodes: + for child in export_nodes[0].children: + self._walk_block( + child, + source, + file_path, + lang, + parent_id, + class_name, + nodes, + edges, + errors, + seen_ids, + ) + continue + + arrow_cursor = ts.QueryCursor(queries["arrow"]) + for _idx, caps in arrow_cursor.matches(block): + name_nodes = caps.get("name") + arrow_nodes = caps.get("arrow") + if name_nodes and arrow_nodes: + self._process_arrow_function( + name_nodes[0], + arrow_nodes[0], + source, + file_path, + parent_id, + class_name, + nodes, + edges, + seen_ids, + ) + + method_cursor = ts.QueryCursor(queries["method"]) + for _idx, caps in method_cursor.matches(block): + method_nodes = caps.get("method") + name_nodes = caps.get("name") + if method_nodes and name_nodes: + self._process_method( + method_nodes[0], + name_nodes[0], + source, + file_path, + parent_id, + class_name, + nodes, + edges, + seen_ids, + ) + + def _process_function( + self, + func: ts.Node, + source: bytes, + file_path: str, + parent_id: str, + class_name: str | None, + nodes: list[GraphNode], + edges: list[GraphEdge], + seen_ids: set[str], + ) -> None: + name_node = func.child_by_field_name("name") + if not name_node: + return + name = node_text(name_node) + start, end = line_range(func) + sig = self._extract_ts_signature(func, source, name) + node_id = make_node_id(file_path, NodeType.FUNCTION, name, start) + + structural = StructuralProperties( + name=name, + file=file_path, + signature=sig, + start_line=start, + end_line=end, + lines=end - start + 1, + ) + + node = GraphNode( + id=node_id, + type=NodeType.FUNCTION, + file_path=file_path, + structural=structural, + ) + if not self._add_node(node, nodes, seen_ids): + return + + edges.append(GraphEdge(source_id=parent_id, target_id=node_id, type=EdgeType.DEFINES)) + body = func.child_by_field_name("body") + if body: + self._extract_calls(body, source, file_path, node_id, nodes, edges) + + def _process_arrow_function( + self, + name_node: ts.Node, + arrow: ts.Node, + source: bytes, + file_path: str, + parent_id: str, + class_name: str | None, + nodes: list[GraphNode], + edges: list[GraphEdge], + seen_ids: set[str], + ) -> None: + name = node_text(name_node) + start, end = line_range(arrow) + sig = f"const {name} = {self._extract_ts_signature(arrow, source, name)}" + node_id = make_node_id(file_path, NodeType.FUNCTION, name, start) + + structural = StructuralProperties( + name=name, + file=file_path, + signature=sig, + start_line=start, + end_line=end, + lines=end - start + 1, + ) + + node = GraphNode( + id=node_id, + type=NodeType.FUNCTION, + file_path=file_path, + structural=structural, + ) + if not self._add_node(node, nodes, seen_ids): + return + + edges.append(GraphEdge(source_id=parent_id, target_id=node_id, type=EdgeType.DEFINES)) + body = arrow.child_by_field_name("body") + if body: + self._extract_calls(body, source, file_path, node_id, nodes, edges) + + def _process_method( + self, + method: ts.Node, + name_node: ts.Node, + source: bytes, + file_path: str, + parent_id: str, + class_name: str 
| None, + nodes: list[GraphNode], + edges: list[GraphEdge], + seen_ids: set[str], + ) -> None: + name = node_text(name_node) + start, end = line_range(method) + sig = self._extract_ts_signature(method, source, name) + node_id = make_node_id(file_path, NodeType.FUNCTION, name, start) + + structural = StructuralProperties( + name=name, + file=file_path, + signature=sig, + start_line=start, + end_line=end, + lines=end - start + 1, + ) + + node = GraphNode( + id=node_id, + type=NodeType.FUNCTION, + file_path=file_path, + structural=structural, + ) + if not self._add_node(node, nodes, seen_ids): + return + + edges.append(GraphEdge(source_id=parent_id, target_id=node_id, type=EdgeType.DEFINES)) + body = method.child_by_field_name("body") + if body: + self._extract_calls(body, source, file_path, node_id, nodes, edges) + + def _process_class( + self, + cls: ts.Node, + source: bytes, + file_path: str, + lang: ts.Language, + parent_id: str, + nodes: list[GraphNode], + edges: list[GraphEdge], + errors: list[ParseError], + seen_ids: set[str], + ) -> None: + name_node = cls.child_by_field_name("name") + if not name_node: + return + name = node_text(name_node) + start, end = line_range(cls) + sig = f"class {name}" + node_id = make_node_id(file_path, NodeType.CLASS, name, start) + + for child in cls.children: + if child.type == "class_heritage": + for heritage_child in child.children: + if heritage_child.type == "extends_clause": + for sub in heritage_child.children: + if sub.type in ("type_identifier", "identifier"): + base_name = node_text(sub) + sig += f" extends {base_name}" + base_id = make_node_id(file_path, NodeType.INTERFACE, base_name, 0) + edges.append(GraphEdge(source_id=node_id, target_id=base_id, type=EdgeType.IMPLEMENTS)) + + structural = StructuralProperties( + name=name, + file=file_path, + signature=sig, + start_line=start, + end_line=end, + lines=end - start + 1, + ) + + node = GraphNode( + id=node_id, + type=NodeType.CLASS, + file_path=file_path, + structural=structural, + ) + if not self._add_node(node, nodes, seen_ids): + return + + edges.append(GraphEdge(source_id=parent_id, target_id=node_id, type=EdgeType.DEFINES)) + body = cls.child_by_field_name("body") + if body: + self._walk_block( + body, + source, + file_path, + lang, + parent_id=node_id, + class_name=name, + nodes=nodes, + edges=edges, + errors=errors, + seen_ids=seen_ids, + ) + + def _process_interface( + self, + iface: ts.Node, + source: bytes, + file_path: str, + parent_id: str, + nodes: list[GraphNode], + edges: list[GraphEdge], + seen_ids: set[str], + ) -> None: + name_node = iface.child_by_field_name("name") + if not name_node: + return + name = node_text(name_node) + start, end = line_range(iface) + node_id = make_node_id(file_path, NodeType.INTERFACE, name, start) + + structural = StructuralProperties( + name=name, + file=file_path, + signature=f"interface {name}", + start_line=start, + end_line=end, + lines=end - start + 1, + ) + + node = GraphNode( + id=node_id, + type=NodeType.INTERFACE, + file_path=file_path, + structural=structural, + ) + if not self._add_node(node, nodes, seen_ids): + return + + edges.append(GraphEdge(source_id=parent_id, target_id=node_id, type=EdgeType.DEFINES)) + + def _process_import( + self, + imp: ts.Node, + source: bytes, + file_path: str, + parent_id: str, + nodes: list[GraphNode], + edges: list[GraphEdge], + ) -> None: + start, end = line_range(imp) + text = node_text(imp).strip() + source_node = imp.child_by_field_name("source") + module = node_text(source_node) if source_node else 
text
+
+        node_id = make_node_id(file_path, NodeType.FILE, module, start)
+        structural = StructuralProperties(
+            name=module,
+            file=file_path,
+            signature=text,
+            start_line=start,
+            end_line=end,
+            lines=end - start + 1,
+        )
+
+        node = GraphNode(
+            id=node_id,
+            type=NodeType.FILE,
+            file_path=file_path,
+            structural=structural,
+        )
+        nodes.append(node)
+        edges.append(GraphEdge(source_id=parent_id, target_id=node_id, type=EdgeType.IMPORTS))
+
+    def _extract_calls(
+        self,
+        body: ts.Node,
+        source: bytes,
+        file_path: str,
+        caller_id: str,
+        nodes: list[GraphNode],
+        edges: list[GraphEdge],
+    ) -> None:
+        queries = _get_queries(self._language(file_path))
+        cursor = ts.QueryCursor(queries["call"])
+        seen_edges: set[tuple[str, str]] = set()
+        for _, caps in cursor.matches(body):
+            callee_nodes = caps.get("callee")
+            call_nodes = caps.get("call")
+            if not callee_nodes or not call_nodes:
+                continue
+            callee_name = node_text(callee_nodes[0])
+            call_node = call_nodes[0]
+            start, _ = line_range(call_node)
+            target_id = make_node_id(file_path, NodeType.FUNCTION, callee_name, 0)
+            edge_key = (caller_id, target_id)
+            if edge_key in seen_edges:
+                continue
+            seen_edges.add(edge_key)
+            edges.append(
+                GraphEdge(
+                    source_id=caller_id,
+                    target_id=target_id,
+                    type=EdgeType.CALLS,
+                    metadata={"line": str(start)},
+                )
+            )
+
+    def _extract_ts_signature(self, node: ts.Node, source: bytes, name: str) -> str:
+        params_node = node.child_by_field_name("parameters")
+        params_text = node_text(params_node) if params_node else "()"
+        return_type = ""
+        for child in node.children:
+            if child.type == "type_annotation":
+                return_type = node_text(child)
+                break
+        if node.type == "arrow_function":
+            # params_text already includes the parentheses of the formal_parameters node.
+            return f"{params_text} => {return_type or '...'}"
+        return f"{name}{params_text}{return_type}"
diff --git a/smp/protocol/__init__.py b/smp/protocol/__init__.py
new file mode 100644
index 0000000..d7fb9e5
--- /dev/null
+++ b/smp/protocol/__init__.py
@@ -0,0 +1,9 @@
+"""Protocol layer — JSON-RPC 2.0 over FastAPI."""
+
+from smp.protocol.router import handle_rpc
+from smp.protocol.server import create_app
+
+__all__ = [
+    "create_app",
+    "handle_rpc",
+]
diff --git a/smp/protocol/dispatcher.py b/smp/protocol/dispatcher.py
new file mode 100644
index 0000000..4f49285
--- /dev/null
+++ b/smp/protocol/dispatcher.py
@@ -0,0 +1,267 @@
+"""JSON-RPC 2.0 dispatcher using handler pattern.
+
+Routes JSON-RPC method calls to registered handler instances.
+""" + +from __future__ import annotations + +from typing import Any + +import msgspec +from fastapi import Request +from fastapi.responses import Response + +from smp.core.models import ( + JsonRpcError, + JsonRpcRequest, + JsonRpcResponse, +) +from smp.logging import get_logger +from smp.protocol.handlers.annotation import ( + AnnotateBulkHandler, + AnnotateHandler, + TagHandler, +) +from smp.protocol.handlers.base import MethodHandler +from smp.protocol.handlers.community import ( + CommunityBoundariesHandler, + CommunityDetectHandler, + CommunityGetHandler, + CommunityListHandler, +) +from smp.protocol.handlers.enrichment import ( + EnrichBatchHandler, + EnrichHandler, + EnrichStaleHandler, + EnrichStatusHandler, +) +from smp.protocol.handlers.handoff import ( + HandoffPRHandler, + HandoffReviewHandler, +) +from smp.protocol.handlers.memory import ( + BatchUpdateHandler, + ReindexHandler, + UpdateHandler, +) +from smp.protocol.handlers.merkle import ( + IndexExportHandler, + IndexImportHandler, + MerkleTreeHandler, + SyncHandler, +) +from smp.protocol.handlers.query import ( + ContextHandler, + FlowHandler, + ImpactHandler, + LocateHandler, + NavigateHandler, + SearchHandler, + TraceHandler, +) +from smp.protocol.handlers.query_ext import ( + ConflictHandler, + DiffHandler, + PlanHandler, + WhyHandler, +) +from smp.protocol.handlers.safety import ( + AuditGetHandler, + CheckpointHandler, + DryRunHandler, + GuardCheckHandler, + IntegrityVerifyHandler, + LockHandler, + RollbackHandler, + SessionCloseHandler, + SessionOpenHandler, + SessionRecoverHandler, + UnlockHandler, +) +from smp.protocol.handlers.sandbox import ( + SandboxDestroyHandler, + SandboxExecuteHandler, + SandboxSpawnHandler, +) +from smp.protocol.handlers.telemetry import ( + TelemetryHandler, + TelemetryHotHandler, + TelemetryNodeHandler, + TelemetryRecordHandler, +) + +log = get_logger(__name__) + + +def _error_response(req_id: int | str | None, code: int, message: str, data: Any = None) -> Response: + body = msgspec.json.encode( + JsonRpcResponse( + error=JsonRpcError(code=code, message=message, data=data), + id=req_id, + ) + ) + return Response(content=body, media_type="application/json", status_code=200) + + +def _success_response(req_id: int | str | None, result: Any) -> Response: + body = msgspec.json.encode(JsonRpcResponse(result=result, id=req_id)) + return Response(content=body, media_type="application/json", status_code=200) + + +class RpcDispatcher: + """Dispatches JSON-RPC requests to registered handlers.""" + + def __init__(self) -> None: + self._handlers: dict[str, MethodHandler] = {} + + for handler_cls in [ + UpdateHandler, + BatchUpdateHandler, + ReindexHandler, + EnrichHandler, + EnrichBatchHandler, + EnrichStaleHandler, + EnrichStatusHandler, + AnnotateHandler, + AnnotateBulkHandler, + TagHandler, + SessionOpenHandler, + SessionCloseHandler, + SessionRecoverHandler, + GuardCheckHandler, + DryRunHandler, + CheckpointHandler, + RollbackHandler, + LockHandler, + UnlockHandler, + AuditGetHandler, + IntegrityVerifyHandler, + NavigateHandler, + TraceHandler, + ContextHandler, + ImpactHandler, + LocateHandler, + SearchHandler, + FlowHandler, + DiffHandler, + PlanHandler, + ConflictHandler, + WhyHandler, + TelemetryHandler, + TelemetryHotHandler, + TelemetryNodeHandler, + TelemetryRecordHandler, + SandboxSpawnHandler, + SandboxExecuteHandler, + SandboxDestroyHandler, + CommunityDetectHandler, + CommunityListHandler, + CommunityGetHandler, + CommunityBoundariesHandler, + SyncHandler, + MerkleTreeHandler, + 
IndexExportHandler, + IndexImportHandler, + HandoffReviewHandler, + HandoffPRHandler, + ]: + handler = handler_cls() + self._handlers[handler.method] = handler + + def register(self, handler: MethodHandler) -> None: + """Register a handler for a method.""" + self._handlers[handler.method] = handler + log.debug("handler_registered", method=handler.method) + + def get_handler(self, method: str) -> MethodHandler | None: + """Get handler for a method.""" + return self._handlers.get(method) + + async def dispatch( + self, + request: Request, + context: dict[str, Any], + ) -> Response: + """Dispatch a JSON-RPC request to the appropriate handler.""" + try: + body = await request.body() + except Exception: + return _error_response(None, -32700, "Parse error") + + if not body: + return _error_response(None, -32700, "Empty request body") + + try: + req = msgspec.json.decode(body, type=JsonRpcRequest) + except (msgspec.DecodeError, Exception) as exc: + return _error_response(None, -32700, f"Parse error: {exc}") + + if req.jsonrpc != "2.0": + return _error_response(req.id, -32600, "Invalid Request: jsonrpc must be '2.0'") + + if not req.method: + return _error_response(req.id, -32600, "Invalid Request: method is required") + + method = req.method + params = req.params or {} + + log.debug("rpc_request", method=method, id=req.id) + + handler = self._handlers.get(method) + if not handler: + return _error_response(req.id, -32601, f"Method not found: {method}") + + try: + result = await handler.handle(params, context) + except msgspec.ValidationError as exc: + return _error_response(req.id, -32602, f"Invalid params: {exc}") + except ValueError as exc: + return _error_response(req.id, -32001, str(exc)) + except Exception as exc: + log.error("rpc_internal_error", method=method, error=str(exc)) + return _error_response(req.id, -32603, f"Internal error: {exc}") + + if req.id is None: + return Response(content=b"", status_code=204) + + return _success_response(req.id, result) + + +_dispatcher: RpcDispatcher | None = None + + +def get_dispatcher() -> RpcDispatcher: + """Get or create the global dispatcher instance.""" + global _dispatcher + if _dispatcher is None: + _dispatcher = RpcDispatcher() + return _dispatcher + + +async def handle_rpc( + request: Request, + *, + engine: Any, + enricher: Any, + builder: Any, + registry: Any, + vector: Any, + safety: dict[str, Any] | None = None, + telemetry_engine: Any = None, + handoff_manager: Any = None, + integrity_verifier: Any = None, +) -> Response: + """Dispatch a single JSON-RPC 2.0 request.""" + dispatcher = get_dispatcher() + context = { + "engine": engine, + "enricher": enricher, + "builder": builder, + "registry": registry, + "vector": vector, + "safety": safety, + "telemetry_engine": telemetry_engine, + "handoff_manager": handoff_manager, + "integrity_verifier": integrity_verifier, + } + return await dispatcher.dispatch(request, context) diff --git a/smp/protocol/handlers/__init__.py b/smp/protocol/handlers/__init__.py new file mode 100644 index 0000000..9f20d06 --- /dev/null +++ b/smp/protocol/handlers/__init__.py @@ -0,0 +1 @@ +"""Protocol handler modules.""" diff --git a/smp/protocol/handlers/annotation.py b/smp/protocol/handlers/annotation.py new file mode 100644 index 0000000..9ce7243 --- /dev/null +++ b/smp/protocol/handlers/annotation.py @@ -0,0 +1,117 @@ +"""Handler for annotation methods (smp/annotate, smp/annotate/bulk, smp/tag).""" + +from __future__ import annotations + +from datetime import UTC, datetime +from typing import Any + +import msgspec 
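Because the dispatcher above resolves methods purely through the `MethodHandler` interface (defined in `handlers/base.py` below), extending the protocol is a matter of registering one more handler. A toy sketch, not part of the patch:

```python
from typing import Any

from smp.protocol.dispatcher import get_dispatcher
from smp.protocol.handlers.base import MethodHandler


class PingHandler(MethodHandler):
    """Hypothetical handler used only for illustration."""

    @property
    def method(self) -> str:
        return "smp/ping"

    async def handle(self, params: dict[str, Any], context: dict[str, Any]) -> dict[str, Any]:
        return {"pong": True, "echo": params.get("msg", "")}


get_dispatcher().register(PingHandler())

# A matching JSON-RPC 2.0 request body:
#   {"jsonrpc": "2.0", "method": "smp/ping", "params": {"msg": "hi"}, "id": 1}
```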
+ +from smp.core.models import AnnotateBulkParams, AnnotateParams, TagParams +from smp.logging import get_logger +from smp.protocol.handlers.base import MethodHandler + +log = get_logger(__name__) + + +class AnnotateHandler(MethodHandler): + """Handles smp/annotate method.""" + + @property + def method(self) -> str: + return "smp/annotate" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + ap = msgspec.convert(params, AnnotateParams) + engine = context["engine"] + + node = await engine._graph.get_node(ap.node_id) + if not node: + raise ValueError(f"Node not found: {ap.node_id}") + + if node.semantic.docstring and not ap.force: + raise ValueError(f"Node already has extracted docstring. Set force: true to override. Node: {ap.node_id}") + + node.semantic.description = ap.description + node.semantic.tags = list(set(node.semantic.tags + ap.tags)) + node.semantic.manually_set = True + node.semantic.status = "manually_annotated" + node.semantic.enriched_at = datetime.now(UTC).isoformat() + await engine._graph.upsert_node(node) + + return { + "node_id": ap.node_id, + "status": "annotated", + "manually_set": True, + "annotated_at": node.semantic.enriched_at, + } + + +class AnnotateBulkHandler(MethodHandler): + """Handles smp/annotate/bulk method.""" + + @property + def method(self) -> str: + return "smp/annotate/bulk" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + abp = msgspec.convert(params, AnnotateBulkParams) + engine = context["engine"] + + annotated = 0 + failed = 0 + + for ann in abp.annotations: + node = await engine._graph.get_node(ann.node_id) + if not node: + failed += 1 + continue + + node.semantic.description = ann.description + node.semantic.tags = list(set(node.semantic.tags + ann.tags)) + node.semantic.manually_set = True + node.semantic.status = "manually_annotated" + node.semantic.enriched_at = datetime.now(UTC).isoformat() + await engine._graph.upsert_node(node) + annotated += 1 + + return {"annotated": annotated, "failed": failed} + + +class TagHandler(MethodHandler): + """Handles smp/tag method.""" + + @property + def method(self) -> str: + return "smp/tag" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + tp = msgspec.convert(params, TagParams) + engine = context["engine"] + + nodes = await engine._graph.find_nodes_by_scope(tp.scope) + affected = 0 + + for node in nodes: + if tp.action == "add": + node.semantic.tags = list(set(node.semantic.tags + tp.tags)) + elif tp.action == "remove": + node.semantic.tags = [t for t in node.semantic.tags if t not in tp.tags] + elif tp.action == "replace": + node.semantic.tags = list(tp.tags) + await engine._graph.upsert_node(node) + affected += 1 + + return {"nodes_affected": affected, "action": tp.action, "scope": tp.scope} diff --git a/smp/protocol/handlers/base.py b/smp/protocol/handlers/base.py new file mode 100644 index 0000000..a1c5103 --- /dev/null +++ b/smp/protocol/handlers/base.py @@ -0,0 +1,34 @@ +"""Base handler interface for JSON-RPC method handlers.""" + +from __future__ import annotations + +import abc +from typing import Any + + +class MethodHandler(abc.ABC): + """Abstract base class for JSON-RPC method handlers.""" + + @property + @abc.abstractmethod + def method(self) -> str: + """Return the JSON-RPC method name this handler processes.""" + + @abc.abstractmethod + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> 
dict[str, Any] | None: + """Handle the method call. + + Args: + params: The method parameters + context: Request context (engine, enricher, etc.) + + Returns: + Result dict or None for notifications + + Raises: + JsonRpcError: For method-specific errors + """ diff --git a/smp/protocol/handlers/community.py b/smp/protocol/handlers/community.py new file mode 100644 index 0000000..306090d --- /dev/null +++ b/smp/protocol/handlers/community.py @@ -0,0 +1,94 @@ +"""Handler for community detection methods.""" + +from __future__ import annotations + +from typing import Any, cast + +import msgspec + +from smp.core.models import ( + CommunityBoundariesParams, + CommunityDetectParams, + CommunityGetParams, + CommunityListParams, +) +from smp.protocol.handlers.base import MethodHandler + + +class CommunityDetectHandler(MethodHandler): + """Handles smp/community/detect method.""" + + @property + def method(self) -> str: + return "smp/community/detect" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + p = msgspec.convert(params, CommunityDetectParams) + detector = context["community_detector"] + return cast( + dict[str, Any], + await detector.detect( + resolutions=p.resolutions or None, + relationship_types=p.relationship_types or None, + ), + ) + + +class CommunityListHandler(MethodHandler): + """Handles smp/community/list method.""" + + @property + def method(self) -> str: + return "smp/community/list" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + p = msgspec.convert(params, CommunityListParams) + detector = context["community_detector"] + return cast(dict[str, Any], await detector.list_communities(level=p.level)) + + +class CommunityGetHandler(MethodHandler): + """Handles smp/community/get method.""" + + @property + def method(self) -> str: + return "smp/community/get" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any] | None: + p = msgspec.convert(params, CommunityGetParams) + detector = context["community_detector"] + result = await detector.get_community( + community_id=p.community_id, + node_types=p.node_types or None, + include_bridges=p.include_bridges, + ) + return cast(dict[str, Any] | None, result) + + +class CommunityBoundariesHandler(MethodHandler): + """Handles smp/community/boundaries method.""" + + @property + def method(self) -> str: + return "smp/community/boundaries" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + p = msgspec.convert(params, CommunityBoundariesParams) + detector = context["community_detector"] + return cast(dict[str, Any], await detector.get_boundaries(level=p.level, min_coupling=p.min_coupling)) diff --git a/smp/protocol/handlers/enrichment.py b/smp/protocol/handlers/enrichment.py new file mode 100644 index 0000000..b70cc51 --- /dev/null +++ b/smp/protocol/handlers/enrichment.py @@ -0,0 +1,185 @@ +"""Handler for enrichment methods (smp/enrich, smp/enrich/batch, etc.).""" + +from __future__ import annotations + +from typing import Any + +import msgspec + +from smp.core.models import ( + EnrichBatchParams, + EnrichParams, + EnrichStaleParams, + EnrichStatusParams, +) +from smp.engine.enricher import _compute_source_hash +from smp.logging import get_logger +from smp.protocol.handlers.base import MethodHandler + +log = get_logger(__name__) + + +class EnrichHandler(MethodHandler): + """Handles smp/enrich method.""" + + @property + def 
method(self) -> str: + return "smp/enrich" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + ep = msgspec.convert(params, EnrichParams) + engine = context["engine"] + enricher = context["enricher"] + + node = await engine._graph.get_node(ep.node_id) + if not node: + raise ValueError(f"Node not found: {ep.node_id}") + + enriched = await enricher.enrich_node(node, force=ep.force) + if enriched.semantic.source_hash and enriched.semantic.status == "enriched": + await engine._graph.upsert_node(enriched) + + return { + "node_id": enriched.id, + "status": enriched.semantic.status, + "docstring": enriched.semantic.docstring, + "inline_comments": [{"line": c.line, "text": c.text} for c in enriched.semantic.inline_comments], + "decorators": enriched.semantic.decorators, + "annotations": { + "params": (enriched.semantic.annotations.params if enriched.semantic.annotations else {}), + "returns": (enriched.semantic.annotations.returns if enriched.semantic.annotations else None), + "throws": (enriched.semantic.annotations.throws if enriched.semantic.annotations else []), + } + if enriched.semantic.annotations + else {}, + "tags": enriched.semantic.tags, + "source_hash": enriched.semantic.source_hash, + "enriched_at": enriched.semantic.enriched_at, + } + + +class EnrichBatchHandler(MethodHandler): + """Handles smp/enrich/batch method.""" + + @property + def method(self) -> str: + return "smp/enrich/batch" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + ebp = msgspec.convert(params, EnrichBatchParams) + engine = context["engine"] + enricher = context["enricher"] + + nodes = await engine._graph.find_nodes_by_scope(ebp.scope) + enriched_count = 0 + skipped_count = 0 + no_metadata_count = 0 + no_metadata_nodes: list[str] = [] + + for node in nodes: + enriched = await enricher.enrich_node(node, force=ebp.force) + if enriched.semantic.status == "enriched": + enriched_count += 1 + await engine._graph.upsert_node(enriched) + elif enriched.semantic.status == "skipped": + skipped_count += 1 + elif enriched.semantic.status == "no_metadata": + no_metadata_count += 1 + no_metadata_nodes.append(enriched.id) + + return { + "enriched": enriched_count, + "skipped": skipped_count, + "no_metadata": no_metadata_count, + "failed": 0, + "no_metadata_nodes": no_metadata_nodes, + } + + +class EnrichStaleHandler(MethodHandler): + """Handles smp/enrich/stale method.""" + + @property + def method(self) -> str: + return "smp/enrich/stale" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + esp = msgspec.convert(params, EnrichStaleParams) + engine = context["engine"] + + nodes = await engine._graph.find_nodes_by_scope(esp.scope) + stale_nodes = [] + + for node in nodes: + if node.semantic.source_hash: + current = _compute_source_hash( + node.structural.name, + node.file_path, + node.structural.start_line, + node.structural.end_line, + node.structural.signature, + ) + if current != node.semantic.source_hash: + stale_nodes.append( + { + "node_id": node.id, + "file": node.file_path, + "last_enriched": node.semantic.enriched_at, + "current_hash": current, + "enriched_hash": node.semantic.source_hash, + } + ) + + return {"stale_count": len(stale_nodes), "stale_nodes": stale_nodes} + + +class EnrichStatusHandler(MethodHandler): + """Handles smp/enrich/status method.""" + + @property + def method(self) -> str: + return "smp/enrich/status" + + async def handle( + self, + 
params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + estp = msgspec.convert(params, EnrichStatusParams) + engine = context["engine"] + + nodes = await engine._graph.find_nodes_by_scope(estp.scope) + total = len(nodes) + has_docstring = sum(1 for n in nodes if n.semantic.docstring) + has_annotations = sum( + 1 + for n in nodes + if n.semantic.annotations and (n.semantic.annotations.params or n.semantic.annotations.returns) + ) + has_tags = sum(1 for n in nodes if n.semantic.tags) + manually_annotated = sum(1 for n in nodes if n.semantic.manually_set) + no_metadata = sum(1 for n in nodes if n.semantic.status == "no_metadata") + coverage = round((total - no_metadata) / total * 100, 1) if total > 0 else 0 + + return { + "total_nodes": total, + "has_docstring": has_docstring, + "has_annotations": has_annotations, + "has_tags": has_tags, + "manually_annotated": manually_annotated, + "no_metadata": no_metadata, + "stale": 0, + "coverage_pct": coverage, + } diff --git a/smp/protocol/handlers/handoff.py b/smp/protocol/handlers/handoff.py new file mode 100644 index 0000000..e770799 --- /dev/null +++ b/smp/protocol/handlers/handoff.py @@ -0,0 +1,68 @@ +"""Handler for handoff and review methods.""" + +from __future__ import annotations + +from typing import Any + +import msgspec + +from smp.core.models import PRCreateParams, ReviewCreateParams +from smp.protocol.handlers.base import MethodHandler + + +class HandoffReviewHandler(MethodHandler): + """Handles smp/handoff/review method.""" + + @property + def method(self) -> str: + return "smp/handoff/review" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + p = msgspec.convert(params, ReviewCreateParams) + manager = context["handoff_manager"] + review = manager.create_review( + session_id=p.session_id, + files_changed=p.files_changed, + diff_summary=p.diff_summary, + reviewers=p.reviewers, + ) + return { + "review_id": review.review_id, + "status": review.status, + "created_at": review.created_at, + } + + +class HandoffPRHandler(MethodHandler): + """Handles smp/handoff/pr method.""" + + @property + def method(self) -> str: + return "smp/handoff/pr" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any] | None: + p = msgspec.convert(params, PRCreateParams) + manager = context["handoff_manager"] + pr = manager.create_pr( + review_id=p.review_id, + title=p.title, + body=p.body, + branch=p.branch, + base_branch=p.base_branch, + ) + if pr is None: + return None + return { + "pr_id": pr.pr_id, + "status": pr.status, + "url": pr.url, + "created_at": pr.created_at, + } diff --git a/smp/protocol/handlers/memory.py b/smp/protocol/handlers/memory.py new file mode 100644 index 0000000..6658907 --- /dev/null +++ b/smp/protocol/handlers/memory.py @@ -0,0 +1,115 @@ +"""Handler for memory management methods (smp/update, smp/batch_update, etc.).""" + +from __future__ import annotations + +from typing import Any + +import msgspec + +from smp.core.models import BatchUpdateParams, ReindexParams, UpdateParams +from smp.logging import get_logger +from smp.protocol.handlers.base import MethodHandler + +log = get_logger(__name__) + + +class UpdateHandler(MethodHandler): + """Handles smp/update method.""" + + @property + def method(self) -> str: + return "smp/update" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + p = msgspec.convert(params, UpdateParams) + enricher = context["enricher"] + 
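+        # The update flow below parses either the inline content or the file
+        # on disk (falling back to the registered Python parser when the
+        # requested language has none), re-runs static enrichment on the
+        # extracted nodes, then swaps out the file's vector entries and
+        # graph subgraph.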
builder = context["builder"] + registry = context["registry"] + vector = context.get("vector") + + file_path = p.file_path + + if p.content: + parser_obj = registry.get(p.language) + if not parser_obj: + from smp.core.models import Language + + parser_obj = registry.get(Language.PYTHON) + if not parser_obj: + return {"error": "No parser available"} + doc = parser_obj.parse(p.content, file_path) + else: + doc = registry.parse_file(file_path) + + if not doc.nodes and not doc.edges: + return { + "file_path": file_path, + "nodes": 0, + "edges": 0, + "errors": len(doc.errors), + "message": "No nodes extracted", + } + + enriched_nodes = await enricher.enrich_batch(doc.nodes) + doc = type(doc)( + file_path=doc.file_path, + language=doc.language, + nodes=enriched_nodes, + edges=doc.edges, + errors=doc.errors, + ) + + if vector: + await vector.delete_by_file(file_path) + await builder.remove_document(file_path) + await builder.ingest_document(doc) + + return { + "file_path": file_path, + "nodes": len(doc.nodes), + "edges": len(doc.edges), + "errors": len(doc.errors), + } + + +class BatchUpdateHandler(MethodHandler): + """Handles smp/batch_update method.""" + + @property + def method(self) -> str: + return "smp/batch_update" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + bp = msgspec.convert(params, BatchUpdateParams) + update_handler = UpdateHandler() + + results = [] + for change in bp.changes: + r = await update_handler.handle(change, context) + results.append(r) + + return {"updates": len(results), "results": results} + + +class ReindexHandler(MethodHandler): + """Handles smp/reindex method.""" + + @property + def method(self) -> str: + return "smp/reindex" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + rp = msgspec.convert(params, ReindexParams) + return {"status": "reindex_requested", "scope": rp.scope} diff --git a/smp/protocol/handlers/merkle.py b/smp/protocol/handlers/merkle.py new file mode 100644 index 0000000..de3cf5a --- /dev/null +++ b/smp/protocol/handlers/merkle.py @@ -0,0 +1,81 @@ +"""Handler for Merkle index and sync methods.""" + +from __future__ import annotations + +from typing import Any, cast + +from smp.protocol.handlers.base import MethodHandler + + +class SyncHandler(MethodHandler): + """Handles smp/sync method.""" + + @property + def method(self) -> str: + return "smp/sync" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any] | None: + remote_hash = params.get("remote_hash", "") + index = context["merkle_index"] + # MerkleIndex.sync returns dict[str, set[str]] | None + result = index.sync(remote_hash) + if result is None: + return {"status": "in_sync"} + return {"status": "out_of_sync", "diff": result} + + +class MerkleTreeHandler(MethodHandler): + """Handles smp/merkle/tree method.""" + + @property + def method(self) -> str: + return "smp/merkle/tree" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + index = context["merkle_index"] + tree = index._tree + return {"hash": tree.hash()} + + +class IndexExportHandler(MethodHandler): + """Handles smp/index/export method.""" + + @property + def method(self) -> str: + return "smp/index/export" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + index = context["merkle_index"] + tree = index._tree + return cast(dict[str, Any], tree.export()) + + 
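+# Illustrative smp/sync exchange (a sketch — the hash value is hypothetical):
+#
+#   -> {"jsonrpc": "2.0", "id": 7, "method": "smp/sync",
+#       "params": {"remote_hash": "9f2c41..."}}
+#   <- {"jsonrpc": "2.0", "id": 7, "result": {"status": "in_sync"}}
+#
+# When the root hashes differ, SyncHandler above instead returns
+# {"status": "out_of_sync", "diff": ...} so callers can fetch only what
+# changed.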
+class IndexImportHandler(MethodHandler): + """Handles smp/index/import method.""" + + @property + def method(self) -> str: + return "smp/index/import" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + data = params.get("data", {}) + index = context["merkle_index"] + tree = index._tree + tree.import_data(data) + return {"success": True} diff --git a/smp/protocol/handlers/query.py b/smp/protocol/handlers/query.py new file mode 100644 index 0000000..aee6582 --- /dev/null +++ b/smp/protocol/handlers/query.py @@ -0,0 +1,142 @@ +"""Handler for query methods (smp/navigate, smp/trace, smp/context, etc.).""" + +from __future__ import annotations + +from typing import Any + +import msgspec + +from smp.core.models import ( + ContextParams, + FlowParams, + ImpactParams, + LocateParams, + NavigateParams, + SearchParams, + TraceParams, +) +from smp.logging import get_logger +from smp.protocol.handlers.base import MethodHandler + +log = get_logger(__name__) + + +class NavigateHandler(MethodHandler): + """Handles smp/navigate method.""" + + @property + def method(self) -> str: + return "smp/navigate" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + np_ = msgspec.convert(params, NavigateParams) + engine = context["engine"] + return await engine.navigate(np_.query, np_.include_relationships) + + +class TraceHandler(MethodHandler): + """Handles smp/trace method.""" + + @property + def method(self) -> str: + return "smp/trace" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + trp = msgspec.convert(params, TraceParams) + engine = context["engine"] + result = await engine.trace(trp.start, trp.relationship, trp.depth, trp.direction) + return {"nodes": result} + + +class ContextHandler(MethodHandler): + """Handles smp/context method.""" + + @property + def method(self) -> str: + return "smp/context" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + ctp = msgspec.convert(params, ContextParams) + engine = context["engine"] + return await engine.get_context(ctp.file_path, ctp.scope, ctp.depth) + + +class ImpactHandler(MethodHandler): + """Handles smp/impact method.""" + + @property + def method(self) -> str: + return "smp/impact" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + imp = msgspec.convert(params, ImpactParams) + engine = context["engine"] + return await engine.assess_impact(imp.entity, imp.change_type) + + +class LocateHandler(MethodHandler): + """Handles smp/locate method.""" + + @property + def method(self) -> str: + return "smp/locate" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + loc = msgspec.convert(params, LocateParams) + engine = context["engine"] + result = await engine.locate(loc.query, loc.fields, loc.node_types, loc.top_k) + return {"matches": result} + + +class SearchHandler(MethodHandler): + """Handles smp/search method.""" + + @property + def method(self) -> str: + return "smp/search" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + sp = msgspec.convert(params, SearchParams) + engine = context["engine"] + return await engine.search(sp.query, sp.match, sp.filter, sp.top_k) + + +class FlowHandler(MethodHandler): + """Handles smp/flow method.""" + + @property + def method(self) -> str: + 
return "smp/flow" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + fp = msgspec.convert(params, FlowParams) + engine = context["engine"] + return await engine.find_flow(fp.start, fp.end, fp.flow_type) diff --git a/smp/protocol/handlers/query_ext.py b/smp/protocol/handlers/query_ext.py new file mode 100644 index 0000000..76002b7 --- /dev/null +++ b/smp/protocol/handlers/query_ext.py @@ -0,0 +1,115 @@ +"""Handler for diff, plan, conflict, why, and telemetry methods.""" + +from __future__ import annotations + +from typing import Any + +import msgspec + +from smp.core.models import ( + ConflictParams, + DiffParams, + PlanParams, + TelemetryParams, + WhyParams, +) +from smp.logging import get_logger +from smp.protocol.handlers.base import MethodHandler + +log = get_logger(__name__) + + +class DiffHandler(MethodHandler): + """Handles smp/diff method.""" + + @property + def method(self) -> str: + return "smp/diff" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + dp = msgspec.convert(params, DiffParams) + engine = context["engine"] + return await engine.diff(dp.from_snapshot, dp.to_snapshot, dp.scope) + + +class PlanHandler(MethodHandler): + """Handles smp/plan method.""" + + @property + def method(self) -> str: + return "smp/plan" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + pp = msgspec.convert(params, PlanParams) + engine = context["engine"] + return await engine.plan(pp.change_description, pp.target_file, pp.change_type, pp.scope) + + +class ConflictHandler(MethodHandler): + """Handles smp/conflict method.""" + + @property + def method(self) -> str: + return "smp/conflict" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + cp = msgspec.convert(params, ConflictParams) + engine = context["engine"] + return await engine.conflict(cp.entity, cp.proposed_change, cp.context) + + +class WhyHandler(MethodHandler): + """Handles smp/why method.""" + + @property + def method(self) -> str: + return "smp/graph/why" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + wp = msgspec.convert(params, WhyParams) + engine = context["engine"] + return await engine.why(wp.entity, wp.relationship, wp.depth) + + +class TelemetryHandler(MethodHandler): + """Handles smp/telemetry method.""" + + @property + def method(self) -> str: + return "smp/telemetry" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + tp = msgspec.convert(params, TelemetryParams) + telemetry_engine = context.get("telemetry_engine") + if not telemetry_engine: + # Return basic stats if telemetry not configured + return {"action": tp.action, "status": "not_configured"} + + if tp.action == "get_stats": + return telemetry_engine.get_summary() + elif tp.action == "get_hot" and tp.node_id: + return telemetry_engine.get_stats(tp.node_id) + elif tp.action == "decay": + return {"decayed": telemetry_engine.decay()} + else: + return {"error": "Unknown telemetry action"} diff --git a/smp/protocol/handlers/safety.py b/smp/protocol/handlers/safety.py new file mode 100644 index 0000000..079a21c --- /dev/null +++ b/smp/protocol/handlers/safety.py @@ -0,0 +1,338 @@ +"""Handler for safety protocol methods (session, guard, lock, checkpoint, etc.).""" + +from __future__ import annotations + +from typing import Any + +import 
msgspec + +from smp.core.models import ( + AuditGetParams, + CheckpointParams, + DryRunParams, + GuardCheckParams, + LockParams, + RollbackParams, + SessionCloseParams, + SessionOpenParams, + SessionRecoverParams, +) +from smp.engine.integrity import IntegrityCheckResult, IntegrityVerifier +from smp.logging import get_logger +from smp.protocol.handlers.base import MethodHandler + +log = get_logger(__name__) + + +class SessionOpenHandler(MethodHandler): + """Handles smp/session/open method.""" + + @property + def method(self) -> str: + return "smp/session/open" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + sop = msgspec.convert(params, SessionOpenParams) + safety = context.get("safety") + if not safety: + raise ValueError("Safety protocol not enabled") + + return await safety["session_manager"].open_session(sop.agent_id, sop.task, sop.scope, sop.mode) + + +class SessionCloseHandler(MethodHandler): + """Handles smp/session/close method.""" + + @property + def method(self) -> str: + return "smp/session/close" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + scp = msgspec.convert(params, SessionCloseParams) + safety = context.get("safety") + if not safety: + raise ValueError("Safety protocol not enabled") + + close_result = await safety["session_manager"].close_session(scp.session_id, scp.status) + if not close_result: + raise ValueError(f"Session not found: {scp.session_id}") + + await safety["lock_manager"].release_all(scp.session_id) + if "audit_logger" in safety: + safety["audit_logger"].close_log(close_result.get("audit_log_id", ""), scp.status) + + return close_result + + +class SessionRecoverHandler(MethodHandler): + """Handles smp/session/recover method.""" + + @property + def method(self) -> str: + return "smp/session/recover" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + srp = msgspec.convert(params, SessionRecoverParams) + safety = context.get("safety") + if not safety: + raise ValueError("Safety protocol not enabled") + + session_manager = safety.get("session_manager") + if not session_manager: + raise ValueError("Session manager not configured") + + result = await session_manager.recover_session(srp.session_id) + if not result: + raise ValueError(f"Session not found: {srp.session_id}") + + return result + + +class GuardCheckHandler(MethodHandler): + """Handles smp/guard/check method.""" + + @property + def method(self) -> str: + return "smp/guard/check" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + gcp = msgspec.convert(params, GuardCheckParams) + safety = context.get("safety") + if not safety: + raise ValueError("Safety protocol not enabled") + + return await safety["guard_engine"].check(gcp.session_id, gcp.target, gcp.intended_change) + + +class DryRunHandler(MethodHandler): + """Handles smp/dryrun method.""" + + @property + def method(self) -> str: + return "smp/dryrun" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + drp = msgspec.convert(params, DryRunParams) + safety = context.get("safety") + if not safety: + raise ValueError("Safety protocol not enabled") + + return safety["dryrun_simulator"].simulate( + drp.session_id, drp.file_path, drp.proposed_content, drp.change_summary + ) + + +class CheckpointHandler(MethodHandler): + """Handles smp/checkpoint method.""" + + @property 
+ def method(self) -> str: + return "smp/checkpoint" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + cp = msgspec.convert(params, CheckpointParams) + safety = context.get("safety") + if not safety: + raise ValueError("Safety protocol not enabled") + + return safety["checkpoint_manager"].create(cp.session_id, cp.files) + + +class RollbackHandler(MethodHandler): + """Handles smp/rollback method.""" + + @property + def method(self) -> str: + return "smp/rollback" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + rbp = msgspec.convert(params, RollbackParams) + safety = context.get("safety") + if not safety: + raise ValueError("Safety protocol not enabled") + + return safety["checkpoint_manager"].rollback(rbp.checkpoint_id) + + +class LockHandler(MethodHandler): + """Handles smp/lock method.""" + + @property + def method(self) -> str: + return "smp/lock" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + lp = msgspec.convert(params, LockParams) + safety = context.get("safety") + if not safety: + raise ValueError("Safety protocol not enabled") + + return await safety["lock_manager"].acquire(lp.session_id, lp.files) + + +class UnlockHandler(MethodHandler): + """Handles smp/unlock method.""" + + @property + def method(self) -> str: + return "smp/unlock" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + ulp = msgspec.convert(params, LockParams) + safety = context.get("safety") + if not safety: + raise ValueError("Safety protocol not enabled") + + await safety["lock_manager"].release(ulp.session_id, ulp.files) + return {"released": ulp.files} + + +class AuditGetHandler(MethodHandler): + """Handles smp/audit/get method.""" + + @property + def method(self) -> str: + return "smp/audit/get" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + agp = msgspec.convert(params, AuditGetParams) + safety = context.get("safety") + if not safety: + raise ValueError("Safety protocol not enabled") + + audit_logger = safety.get("audit_logger") + if not audit_logger: + raise ValueError("Audit logger not configured") + + # Prefer explicit audit_log_id, fall back to session_id param for convenience + audit = None + if agp.audit_log_id: + audit = audit_logger.get_log(agp.audit_log_id) + if not audit and "session_id" in params: + audit = audit_logger.get_log_by_session(params.get("session_id")) + + if not audit: + raise ValueError(f"Audit log not found: {agp.audit_log_id or params.get('session_id')}") + + return audit + + +class IntegrityVerifyHandler(MethodHandler): + """Handles smp/verify/integrity method.""" + + @property + def method(self) -> str: + return "smp/verify/integrity" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + session_id: str = params["session_id"] + node_ids: list[str] = params.get("node_ids") or [] + mode: str = params.get("mode", "ast") + if mode not in ("ast", "mutation", "both"): + raise ValueError(f"Invalid mode: {mode}. 
Must be 'ast', 'mutation', or 'both'") + + integrity_verifier: IntegrityVerifier | None = context.get("integrity_verifier") + if not integrity_verifier: + integrity_verifier = IntegrityVerifier() + + graph_store = context.get("engine") + all_mutations: list[dict[str, Any]] = [] + all_warnings: list[str] = [] + total_checks = 0 + all_passed = True + + target_ids = node_ids if node_ids else list(integrity_verifier._baselines.keys()) + + for nid in target_ids: + results: list[IntegrityCheckResult] = [] + + if mode in ("ast", "both"): + baseline = integrity_verifier._baselines.get(nid) + current_state = baseline["state"] if baseline else {} + ast_result = await integrity_verifier.verify(nid, current_state) + results.append(ast_result) + + if mode in ("mutation", "both"): + if not graph_store: + all_warnings.append(f"Graph store unavailable for mutation test on {nid}") + continue + mutation_result = await integrity_verifier.run_mutation_test(nid, graph_store) + results.append(mutation_result) + + for r in results: + if not r.passed: + all_passed = False + total_checks += r.checks_run + all_mutations.extend( + [ + { + "node_id": m.node_id, + "mutation_type": m.mutation_type, + "field_name": m.field_name, + "old_value": m.old_value, + "new_value": m.new_value, + "detected_at": m.detected_at, + } + for m in r.mutations_detected + ] + ) + all_warnings.extend(r.warnings) + + log.info( + "integrity_verify", + session_id=session_id, + mode=mode, + passed=all_passed, + checks_run=total_checks, + ) + + return { + "passed": all_passed, + "mutations_detected": all_mutations, + "warnings": all_warnings, + "checks_run": total_checks, + } diff --git a/smp/protocol/handlers/sandbox.py b/smp/protocol/handlers/sandbox.py new file mode 100644 index 0000000..9622b7f --- /dev/null +++ b/smp/protocol/handlers/sandbox.py @@ -0,0 +1,110 @@ +"""Handler for sandbox methods (smp/sandbox/spawn, etc.).""" + +from __future__ import annotations + +from datetime import UTC, datetime +from typing import Any + +import msgspec + +from smp.logging import get_logger +from smp.protocol.handlers.base import MethodHandler +from smp.sandbox.executor import SandboxExecutor +from smp.sandbox.spawner import SandboxSpawner + +log = get_logger(__name__) + + +class SandboxSpawnHandler(MethodHandler): + """Handles smp/sandbox/spawn method.""" + + @property + def method(self) -> str: + return "smp/sandbox/spawn" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + # In a real implementation, these would come from context/session + # For now, we'll use defaults and extract from params if provided + sp = msgspec.convert(params, dict) # Use raw params since no model exists yet + + spawner = SandboxSpawner() + + sandbox_info = spawner.spawn(name=sp.get("name"), template=sp.get("template"), files=sp.get("files")) + + return { + "sandbox_id": sandbox_info.sandbox_id, + "root_path": sandbox_info.root_path, + "created_at": sandbox_info.created_at, + "status": sandbox_info.status, + } + + +class SandboxExecuteHandler(MethodHandler): + """Handles smp/sandbox/execute method.""" + + @property + def method(self) -> str: + return "smp/sandbox/execute" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + sep = msgspec.convert(params, dict) # Use raw params + + # Create executor with default config + executor = SandboxExecutor() + + # Execute the command + result = await executor.execute( + command=sep.get("command", []), stdin=sep.get("stdin"), 
cwd=sep.get("working_directory") + ) + + return { + "execution_id": result.execution_id, + "exit_code": result.exit_code, + "stdout": result.stdout, + "stderr": result.stderr, + "duration_ms": result.duration_ms, + "memory_used_mb": result.memory_used_mb, + "timed_out": result.timed_out, + "killed": result.killed, + "metadata": result.metadata, + } + + +class SandboxDestroyHandler(MethodHandler): + """Handles smp/sandbox/destroy method.""" + + @property + def method(self) -> str: + return "smp/sandbox/destroy" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + sdp = msgspec.convert(params, dict) # Use raw params + + spawner = SandboxSpawner() + sandbox_id = sdp.get("sandbox_id") + + if not sandbox_id: + return {"error": "sandbox_id is required"} + + destroyed = spawner.destroy(sandbox_id) + + if destroyed: + return { + "sandbox_id": sandbox_id, + "status": "destroyed", + "destroyed_at": datetime.now(UTC).isoformat(), + } + else: + return {"error": f"Sandbox not found: {sandbox_id}"} diff --git a/smp/protocol/handlers/telemetry.py b/smp/protocol/handlers/telemetry.py new file mode 100644 index 0000000..d12f2cf --- /dev/null +++ b/smp/protocol/handlers/telemetry.py @@ -0,0 +1,122 @@ +"""Telemetry handlers for SMP(3).""" + +from __future__ import annotations + +from typing import Any + +import msgspec + +from smp.core.models import ( + TelemetryParams, +) +from smp.logging import get_logger +from smp.protocol.handlers.base import MethodHandler + +log = get_logger(__name__) + + +class TelemetryHandler(MethodHandler): + """Handles smp/telemetry method.""" + + @property + def method(self) -> str: + return "smp/telemetry" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + tp = msgspec.convert(params, TelemetryParams) + telemetry_engine = context.get("telemetry_engine") + if not telemetry_engine: + return {"action": tp.action, "status": "not_configured"} + elif tp.action == "get_stats": + return telemetry_engine.get_summary() + elif tp.action == "get_hot" and tp.node_id: + return telemetry_engine.get_stats(tp.node_id) + elif tp.action == "decay": + return {"decayed": telemetry_engine.decay()} + else: + return {"error": "Unknown telemetry action"} + + +class TelemetryHotHandler(MethodHandler): + """Handles smp/telemetry/hot method.""" + + @property + def method(self) -> str: + return "smp/telemetry/hot" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + # Extract node_id from params + node_id = params.get("node_id") + if not node_id: + return {"error": "node_id is required"} + + telemetry_engine = context.get("telemetry_engine") + if not telemetry_engine: + return {"status": "not_configured"} + + return telemetry_engine.get_stats(node_id) + + +class TelemetryNodeHandler(MethodHandler): + """Handles smp/telemetry/node method.""" + + @property + def method(self) -> str: + return "smp/telemetry/node" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + # Extract node_id from params + node_id = params.get("node_id") + if not node_id: + return {"error": "node_id is required"} + + telemetry_engine = context.get("telemetry_engine") + if not telemetry_engine: + return {"status": "not_configured"} + + return telemetry_engine.get_stats(node_id) + + +class TelemetryRecordHandler(MethodHandler): + """Handles smp/telemetry/record method.""" + + @property + def method(self) -> str: + return 
"smp/telemetry/record" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + # Extract parameters + node_id = params.get("node_id") + action = params.get("action", "access") + session_id = params.get("session_id") + agent_id = params.get("agent_id") + + if not node_id: + return {"error": "node_id is required"} + + telemetry_engine = context.get("telemetry_engine") + if not telemetry_engine: + return {"status": "not_configured"} + + return telemetry_engine.record_access( + node_id=node_id, + action=action, + session_id=session_id or "", + agent_id=agent_id or "", + ) diff --git a/smp/protocol/router.py b/smp/protocol/router.py new file mode 100644 index 0000000..01855c5 --- /dev/null +++ b/smp/protocol/router.py @@ -0,0 +1,653 @@ +"""JSON-RPC 2.0 dispatcher for the Structural Memory Protocol (SMP(3)). + +All SMP protocol methods are routed through a single ``POST /rpc`` endpoint. +""" + +from __future__ import annotations + +from datetime import UTC, datetime +from typing import Any + +import msgspec +from fastapi import Request +from fastapi.responses import Response + +from smp.core.models import ( + AnnotateBulkParams, + AnnotateParams, + AuditGetParams, + BatchUpdateParams, + CheckpointParams, + ContextParams, + DryRunParams, + EnrichBatchParams, + EnrichParams, + EnrichStaleParams, + EnrichStatusParams, + FlowParams, + GuardCheckParams, + ImpactParams, + IntegrityCheckParams, + JsonRpcError, + JsonRpcRequest, + JsonRpcResponse, + LocateParams, + LockParams, + NavigateParams, + PRCreateParams, + ReindexParams, + ReviewApproveParams, + ReviewCommentParams, + ReviewCreateParams, + ReviewRejectParams, + RollbackParams, + SearchParams, + SessionCloseParams, + SessionOpenParams, + TagParams, + TelemetryParams, + TraceParams, + UpdateParams, +) +from smp.logging import get_logger +from smp.sandbox.executor import SandboxExecutor + +log = get_logger(__name__) + + +def _error_response(req_id: int | str | None, code: int, message: str, data: Any = None) -> Response: + body = msgspec.json.encode( + JsonRpcResponse( + error=JsonRpcError(code=code, message=message, data=data), + id=req_id, + ) + ) + return Response(content=body, media_type="application/json", status_code=200) + + +def _success_response(req_id: int | str | None, result: Any) -> Response: + body = msgspec.json.encode(JsonRpcResponse(result=result, id=req_id)) + return Response(content=body, media_type="application/json", status_code=200) + + +async def _handle_update( + params: dict[str, Any], + engine: Any, + enricher: Any, + builder: Any, + registry: Any, + vector: Any, +) -> dict[str, Any]: + p = msgspec.convert(params, UpdateParams) + file_path = p.file_path + + if p.content: + parser_obj = registry.get(p.language) + if not parser_obj: + from smp.core.models import Language + + parser_obj = registry.get(Language.PYTHON) + if not parser_obj: + return {"error": "No parser available"} + doc = parser_obj.parse(p.content, file_path) + else: + doc = registry.parse_file(file_path) + + if not doc.nodes and not doc.edges: + return { + "file_path": file_path, + "nodes": 0, + "edges": 0, + "errors": len(doc.errors), + "message": "No nodes extracted", + } + + enriched_nodes = await enricher.enrich_batch(doc.nodes) + doc = type(doc)( + file_path=doc.file_path, + language=doc.language, + nodes=enriched_nodes, + edges=doc.edges, + errors=doc.errors, + ) + + if vector: + await vector.delete_by_file(file_path) + await builder.remove_document(file_path) + await builder.ingest_document(doc) + + 
return { + "file_path": file_path, + "nodes": len(doc.nodes), + "edges": len(doc.edges), + "errors": len(doc.errors), + } + + +async def handle_rpc( + request: Request, + *, + engine: Any, + enricher: Any, + builder: Any, + registry: Any, + vector: Any, + safety: dict[str, Any] | None = None, + telemetry_engine: Any = None, + handoff_manager: Any = None, + integrity_verifier: Any = None, + runtime_linker: Any = None, +) -> Response: + """Dispatch a single JSON-RPC 2.0 request.""" + + # Build context for handlers + context: dict[str, Any] = { + "engine": engine, + "enricher": enricher, + "builder": builder, + "registry": registry, + "vector": vector, + "safety": safety, + "telemetry_engine": telemetry_engine, + "handoff_manager": handoff_manager, + "integrity_verifier": integrity_verifier, + "runtime_linker": runtime_linker, + } + try: + body = await request.body() + except Exception: + return _error_response(None, -32700, "Parse error") + + if not body: + return _error_response(None, -32700, "Empty request body") + + try: + req = msgspec.json.decode(body, type=JsonRpcRequest) + except (msgspec.DecodeError, Exception) as exc: + return _error_response(None, -32700, f"Parse error: {exc}") + + if req.jsonrpc != "2.0": + return _error_response(req.id, -32600, "Invalid Request: jsonrpc must be '2.0'") + + if not req.method: + return _error_response(req.id, -32600, "Invalid Request: method is required") + + method = req.method + params = req.params + + log.debug("rpc_request", method=method, id=req.id) + + try: + # --- Memory Management --- + if method == "smp/update": + result = await _handle_update(params, engine, enricher, builder, registry, vector) + + elif method == "smp/batch_update": + bp = msgspec.convert(params, BatchUpdateParams) + results = [] + for change in bp.changes: + r = await _handle_update(change, engine, enricher, builder, registry, vector) + results.append(r) + result = {"updates": len(results), "results": results} + + elif method == "smp/reindex": + rp = msgspec.convert(params, ReindexParams) + result = {"status": "reindex_requested", "scope": rp.scope} + + # --- Enrichment --- + elif method == "smp/enrich": + ep = msgspec.convert(params, EnrichParams) + node = await engine._graph.get_node(ep.node_id) + if not node: + return _error_response(req.id, -32001, "Node not found", data={"node_id": ep.node_id}) + enriched = await enricher.enrich_node(node, force=ep.force) + if enriched.semantic.source_hash and enriched.semantic.status == "enriched": + await engine._graph.upsert_node(enriched) + result = { + "node_id": enriched.id, + "status": enriched.semantic.status, + "docstring": enriched.semantic.docstring, + "inline_comments": [{"line": c.line, "text": c.text} for c in enriched.semantic.inline_comments], + "decorators": enriched.semantic.decorators, + "annotations": { + "params": enriched.semantic.annotations.params if enriched.semantic.annotations else {}, + "returns": enriched.semantic.annotations.returns if enriched.semantic.annotations else None, + "throws": enriched.semantic.annotations.throws if enriched.semantic.annotations else [], + } + if enriched.semantic.annotations + else {}, + "tags": enriched.semantic.tags, + "source_hash": enriched.semantic.source_hash, + "enriched_at": enriched.semantic.enriched_at, + } + + elif method == "smp/enrich/batch": + ebp = msgspec.convert(params, EnrichBatchParams) + nodes = await engine._graph.find_nodes_by_scope(ebp.scope) + enriched_count = 0 + skipped_count = 0 + no_metadata_count = 0 + no_metadata_nodes: list[str] = [] + for node 
in nodes: + enriched = await enricher.enrich_node(node, force=ebp.force) + if enriched.semantic.status == "enriched": + enriched_count += 1 + await engine._graph.upsert_node(enriched) + elif enriched.semantic.status == "skipped": + skipped_count += 1 + elif enriched.semantic.status == "no_metadata": + no_metadata_count += 1 + no_metadata_nodes.append(enriched.id) + result = { + "enriched": enriched_count, + "skipped": skipped_count, + "no_metadata": no_metadata_count, + "failed": 0, + "no_metadata_nodes": no_metadata_nodes, + } + + elif method == "smp/enrich/stale": + esp = msgspec.convert(params, EnrichStaleParams) + nodes = await engine._graph.find_nodes_by_scope(esp.scope) + stale_nodes = [] + for node in nodes: + if node.semantic.source_hash: + from smp.engine.enricher import _compute_source_hash + + current = _compute_source_hash( + node.structural.name, + node.file_path, + node.structural.start_line, + node.structural.end_line, + node.structural.signature, + ) + if current != node.semantic.source_hash: + stale_nodes.append( + { + "node_id": node.id, + "file": node.file_path, + "last_enriched": node.semantic.enriched_at, + "current_hash": current, + "enriched_hash": node.semantic.source_hash, + } + ) + result = {"stale_count": len(stale_nodes), "stale_nodes": stale_nodes} + + elif method == "smp/enrich/status": + estp = msgspec.convert(params, EnrichStatusParams) + nodes = await engine._graph.find_nodes_by_scope(estp.scope) + total = len(nodes) + has_docstring = sum(1 for n in nodes if n.semantic.docstring) + has_annotations = sum( + 1 + for n in nodes + if n.semantic.annotations and (n.semantic.annotations.params or n.semantic.annotations.returns) + ) + has_tags = sum(1 for n in nodes if n.semantic.tags) + manually_annotated = sum(1 for n in nodes if n.semantic.manually_set) + no_metadata = sum(1 for n in nodes if n.semantic.status == "no_metadata") + coverage = round((total - no_metadata) / total * 100, 1) if total > 0 else 0 + result = { + "total_nodes": total, + "has_docstring": has_docstring, + "has_annotations": has_annotations, + "has_tags": has_tags, + "manually_annotated": manually_annotated, + "no_metadata": no_metadata, + "stale": 0, + "coverage_pct": coverage, + } + + # --- Annotation --- + elif method == "smp/annotate": + ap = msgspec.convert(params, AnnotateParams) + node = await engine._graph.get_node(ap.node_id) + if not node: + return _error_response(req.id, -32001, "Node not found", data={"node_id": ap.node_id}) + if node.semantic.docstring and not ap.force: + return _error_response( + req.id, + -32002, + "Node already has extracted docstring. 
Set force: true to override.",
+                    data={"node_id": ap.node_id},
+                )
+            node.semantic.description = ap.description
+            node.semantic.tags = list(set(node.semantic.tags + ap.tags))
+            node.semantic.manually_set = True
+            node.semantic.status = "manually_annotated"
+            node.semantic.enriched_at = datetime.now(UTC).isoformat()
+            await engine._graph.upsert_node(node)
+            result = {
+                "node_id": ap.node_id,
+                "status": "annotated",
+                "manually_set": True,
+                "annotated_at": node.semantic.enriched_at,
+            }
+
+        elif method == "smp/annotate/bulk":
+            abp = msgspec.convert(params, AnnotateBulkParams)
+            annotated = 0
+            failed = 0
+            for ann in abp.annotations:
+                node = await engine._graph.get_node(ann.node_id)
+                if not node:
+                    failed += 1
+                    continue
+                node.semantic.description = ann.description
+                node.semantic.tags = list(set(node.semantic.tags + ann.tags))
+                node.semantic.manually_set = True
+                node.semantic.status = "manually_annotated"
+                node.semantic.enriched_at = datetime.now(UTC).isoformat()
+                await engine._graph.upsert_node(node)
+                annotated += 1
+            result = {"annotated": annotated, "failed": failed}
+
+        elif method == "smp/tag":
+            tp = msgspec.convert(params, TagParams)
+            nodes = await engine._graph.find_nodes_by_scope(tp.scope)
+            affected = 0
+            for node in nodes:
+                if tp.action == "add":
+                    node.semantic.tags = list(set(node.semantic.tags + tp.tags))
+                elif tp.action == "remove":
+                    node.semantic.tags = [t for t in node.semantic.tags if t not in tp.tags]
+                elif tp.action == "replace":
+                    node.semantic.tags = list(tp.tags)
+                await engine._graph.upsert_node(node)
+                affected += 1
+            result = {"nodes_affected": affected, "action": tp.action, "scope": tp.scope}
+
+        # --- Safety ---
+        elif method == "smp/session/open":
+            sop = msgspec.convert(params, SessionOpenParams)
+            if not safety:
+                return _error_response(req.id, -32601, "Safety protocol not enabled")
+            result = await safety["session_manager"].open_session(sop.agent_id, sop.task, sop.scope, sop.mode)
+
+        elif method == "smp/session/close":
+            scp = msgspec.convert(params, SessionCloseParams)
+            if not safety:
+                return _error_response(req.id, -32601, "Safety protocol not enabled")
+            close_result = await safety["session_manager"].close_session(scp.session_id, scp.status)
+            if close_result:
+                # release_all is a coroutine (SessionCloseHandler awaits it);
+                # it must be awaited here too or the locks are never released
+                await safety["lock_manager"].release_all(scp.session_id)
+                if "audit_logger" in safety:
+                    safety["audit_logger"].close_log(close_result.get("audit_log_id", ""), scp.status)
+                result = close_result
+            else:
+                return _error_response(req.id, -32001, "Session not found", data={"session_id": scp.session_id})
+
+        elif method == "smp/guard/check":
+            gcp = msgspec.convert(params, GuardCheckParams)
+            if not safety:
+                return _error_response(req.id, -32601, "Safety protocol not enabled")
+            result = await safety["guard_engine"].check(gcp.session_id, gcp.target, gcp.intended_change)
+
+        elif method == "smp/dryrun":
+            drp = msgspec.convert(params, DryRunParams)
+            if not safety:
+                return _error_response(req.id, -32601, "Safety protocol not enabled")
+            result = safety["dryrun_simulator"].simulate(
+                drp.session_id, drp.file_path, drp.proposed_content, drp.change_summary
+            )
+
+        elif method == "smp/checkpoint":
+            cp = msgspec.convert(params, CheckpointParams)
+            if not safety:
+                return _error_response(req.id, -32601, "Safety protocol not enabled")
+            result = safety["checkpoint_manager"].create(cp.session_id, cp.files)
+
+        elif method == "smp/rollback":
+            rbp = msgspec.convert(params, RollbackParams)
+            if not safety:
+                return _error_response(req.id, -32601, "Safety protocol not enabled")
+            result = 
safety["checkpoint_manager"].rollback(rbp.checkpoint_id) + + elif method == "smp/lock": + lp = msgspec.convert(params, LockParams) + if not safety: + return _error_response(req.id, -32601, "Safety protocol not enabled") + result = await safety["lock_manager"].acquire(lp.session_id, lp.files) + + elif method == "smp/unlock": + ulp = msgspec.convert(params, LockParams) + if not safety: + return _error_response(req.id, -32601, "Safety protocol not enabled") + await safety["lock_manager"].release(ulp.session_id, ulp.files) + result = {"released": ulp.files} + + elif method == "smp/audit/get": + agp = msgspec.convert(params, AuditGetParams) + if not safety: + return _error_response(req.id, -32601, "Safety protocol not enabled") + audit = safety["audit_logger"].get_log(agp.audit_log_id) + if not audit: + return _error_response(req.id, -32001, "Audit log not found", data={"audit_log_id": agp.audit_log_id}) + result = audit + + # --- Query --- + elif method == "smp/navigate": + np_ = msgspec.convert(params, NavigateParams) + result = await engine.navigate(np_.query, np_.include_relationships) + + elif method == "smp/trace": + trp = msgspec.convert(params, TraceParams) + result = await engine.trace(trp.start, trp.relationship, trp.depth, trp.direction) + + elif method == "smp/context": + ctp = msgspec.convert(params, ContextParams) + result = await engine.get_context(ctp.file_path, ctp.scope, ctp.depth) + + elif method == "smp/impact": + imp = msgspec.convert(params, ImpactParams) + result = await engine.assess_impact(imp.entity, imp.change_type) + + elif method == "smp/locate": + loc = msgspec.convert(params, LocateParams) + result = await engine.locate(loc.query, loc.fields, loc.node_types, loc.top_k) + + elif method == "smp/search": + sp = msgspec.convert(params, SearchParams) + result = await engine.search(sp.query, sp.match, sp.filter, sp.top_k) + + elif method == "smp/flow": + fp = msgspec.convert(params, FlowParams) + result = await engine.find_flow(fp.start, fp.end, fp.flow_type) + + elif method == "smp/graph/why": + wp = msgspec.convert(params, dict) + result = await engine.why( + entity=wp.get("entity", ""), + relationship=wp.get("relationship", ""), + depth=wp.get("depth", 3), + ) + + elif method == "smp/diff": + dp = msgspec.convert(params, dict) + result = await engine.diff_file( + file_path=dp.get("file_path", ""), + proposed_content=dp.get("proposed_content"), + ) + + elif method == "smp/plan": + pp = msgspec.convert(params, dict) + result = await engine.plan_multi_file( + session_id=pp.get("session_id", ""), + task=pp.get("task", ""), + intended_writes=pp.get("intended_writes", []), + ) + + elif method == "smp/conflict": + cp = msgspec.convert(params, dict) + result = await engine.detect_conflict( + session_a=cp.get("session_a", ""), + session_b=cp.get("session_b", ""), + ) + + # --- Sandbox --- + elif method == "smp/sandbox/spawn": + if not safety: + return _error_response(req.id, -32601, "Sandbox functionality requires safety protocol") + result = safety["sandbox_spawner"].spawn( + name=params.get("name"), template=params.get("template"), files=params.get("files") + ) + result = { + "sandbox_id": result.sandbox_id, + "root_path": result.root_path, + "created_at": result.created_at, + "status": result.status, + } + elif method == "smp/sandbox/execute": + if not safety: + return _error_response(req.id, -32601, "Sandbox functionality requires safety protocol") + sep = msgspec.convert(params, dict) + executor = safety.get("sandbox_executor") + if not executor: + # Create a default 
executor if not in context + executor = SandboxExecutor() + result = await executor.execute( + command=sep.get("command", []), stdin=sep.get("stdin"), cwd=sep.get("working_directory") + ) + result = { + "execution_id": result.execution_id, + "exit_code": result.exit_code, + "stdout": result.stdout, + "stderr": result.stderr, + "duration_ms": result.duration_ms, + "memory_used_mb": result.memory_used_mb, + "timed_out": result.timed_out, + "killed": result.killed, + "metadata": result.metadata, + } + elif method == "smp/sandbox/destroy": + if not safety: + return _error_response(req.id, -32601, "Sandbox functionality requires safety protocol") + sdp = msgspec.convert(params, dict) + sandbox_id = sdp.get("sandbox_id") + if not sandbox_id: + return _error_response(req.id, -32602, "sandbox_id is required") + destroyed = safety["sandbox_spawner"].destroy(sandbox_id) + if destroyed: + result = { + "sandbox_id": sandbox_id, + "status": "destroyed", + "destroyed_at": datetime.now(UTC).isoformat(), + } + else: + result = {"error": f"Sandbox not found: {sandbox_id}"} + + # --- Telemetry --- + elif method == "smp/telemetry": + tp = msgspec.convert(params, TelemetryParams) + telemetry_engine = context.get("telemetry_engine") + if not telemetry_engine: + result = {"action": tp.action, "status": "not_configured"} + elif tp.action == "get_stats": + result = telemetry_engine.get_summary() + elif tp.action == "get_hot" and tp.node_id: + result = telemetry_engine.get_stats(tp.node_id) + elif tp.action == "decay": + result = {"decayed": telemetry_engine.decay()} + else: + result = {"error": "Unknown telemetry action"} + + # --- Runtime Linker --- + elif method == "smp/linker/report": + linker = context.get("runtime_linker") + if not linker: + result = {"unresolved_edges": [], "status": "not_configured"} + else: + pending_count = linker.get_pending_count() + result = {"unresolved_edges": [], "pending_count": pending_count, "status": "ok"} + elif method == "smp/linker/runtime": + linker = context.get("runtime_linker") + if not linker: + result = {"hot_paths": [], "status": "not_configured"} + else: + threshold = params.get("threshold", 10) + result = {"hot_paths": linker.get_hot_paths(threshold), "stats": linker.get_stats()} + + # --- Handoff --- + elif method == "smp/handoff/review": + rcp = msgspec.convert(params, ReviewCreateParams) + handoff_manager = context.get("handoff_manager") + if not handoff_manager: + result = {"error": "Handoff manager not configured"} + else: + review = handoff_manager.create_review( + session_id=rcp.session_id, + files_changed=rcp.files_changed, + diff_summary=rcp.diff_summary, + reviewers=rcp.reviewers, + ) + result = {"review_id": review.review_id, "status": review.status, "created_at": review.created_at} + elif method == "smp/handoff/review/comment": + rcm = msgspec.convert(params, ReviewCommentParams) + handoff_manager = context.get("handoff_manager") + if not handoff_manager: + result = {"error": "Handoff manager not configured"} + else: + success = handoff_manager.add_comment( + review_id=rcm.review_id, + author=rcm.author, + comment=rcm.comment, + file_path=rcm.file_path, + line=rcm.line, + ) + result = {"success": success, "review_id": rcm.review_id} + elif method == "smp/handoff/review/approve": + rap = msgspec.convert(params, ReviewApproveParams) + handoff_manager = context.get("handoff_manager") + if not handoff_manager: + result = {"error": "Handoff manager not configured"} + else: + success = handoff_manager.approve(rap.review_id, rap.reviewer) + result = {"success": 
success, "review_id": rap.review_id, "status": "approved" if success else "failed"} + elif method == "smp/handoff/review/reject": + rrj = msgspec.convert(params, ReviewRejectParams) + handoff_manager = context.get("handoff_manager") + if not handoff_manager: + result = {"error": "Handoff manager not configured"} + else: + success = handoff_manager.reject(rrj.review_id, rrj.reviewer, rrj.reason) + result = {"success": success, "review_id": rrj.review_id, "status": "rejected" if success else "failed"} + elif method == "smp/handoff/pr": + pcp = msgspec.convert(params, PRCreateParams) + handoff_manager = context.get("handoff_manager") + if not handoff_manager: + result = {"error": "Handoff manager not configured"} + else: + pr = handoff_manager.create_pr( + review_id=pcp.review_id, + title=pcp.title, + body=pcp.body, + branch=pcp.branch, + base_branch=pcp.base_branch, + ) + if pr: + result = {"pr_id": pr.pr_id, "status": pr.status, "created_at": pr.created_at, "url": pr.url} + else: + result = {"error": "Review not found or not approved"} + + # --- Integrity --- + elif method == "smp/verify/integrity": + icp = msgspec.convert(params, IntegrityCheckParams) + verifier = context.get("integrity_verifier") + if not verifier: + result = {"status": "not_configured", "error": "Integrity verifier not available"} + else: + result = await verifier.verify(icp.node_id, icp.current_state) + + else: + return _error_response(req.id, -32601, f"Method not found: {method}") + + except msgspec.ValidationError as exc: + return _error_response(req.id, -32602, f"Invalid params: {exc}") + except Exception as exc: + log.error("rpc_internal_error", method=method, error=str(exc)) + return _error_response(req.id, -32603, f"Internal error: {exc}") + + if req.id is None: + return Response(content=b"", status_code=204) + + return _success_response(req.id, result) diff --git a/smp/protocol/server.py b/smp/protocol/server.py new file mode 100644 index 0000000..40dd426 --- /dev/null +++ b/smp/protocol/server.py @@ -0,0 +1,154 @@ +"""FastAPI application with JSON-RPC 2.0 endpoint. 
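+
+All SMP methods are served from a single ``POST /rpc`` route; ``GET /health``
+and ``GET /stats`` expose liveness and graph node/edge counts.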
+ +Start with: ``python3.11 -m smp.cli serve`` +""" + +from __future__ import annotations + +try: + import pysqlite3 + import sys + sys.modules["sqlite3"] = pysqlite3 +except ImportError: + pass + +import os +from contextlib import asynccontextmanager +from typing import Any + +from fastapi import FastAPI, Request +from fastapi.responses import Response + +from smp.engine.enricher import StaticSemanticEnricher +from smp.engine.graph_builder import DefaultGraphBuilder +from smp.engine.seed_walk import SeedWalkEngine +from smp.engine.community import CommunityDetector +from smp.core.merkle import MerkleIndex +from smp.logging import get_logger +from smp.parser.registry import ParserRegistry +from smp.protocol.dispatcher import handle_rpc +from smp.store.graph.neo4j_store import Neo4jGraphStore +from smp.store.chroma_store import ChromaVectorStore + +log = get_logger(__name__) + + +def create_app( + neo4j_uri: str | None = None, + neo4j_user: str | None = None, + neo4j_password: str | None = None, + safety_enabled: bool = False, +) -> FastAPI: + """Create and configure the SMP FastAPI application.""" + + uri = neo4j_uri or os.environ.get("SMP_NEO4J_URI", "bolt://localhost:7687") + user = neo4j_user or os.environ.get("SMP_NEO4J_USER", "neo4j") + password = neo4j_password or os.environ.get("SMP_NEO4J_PASSWORD", "") + + @asynccontextmanager + async def lifespan(app: FastAPI): # type: ignore[no-untyped-def] # noqa: ANN202 + graph = Neo4jGraphStore(uri=uri, user=user, password=password) + await graph.connect() + + vector = ChromaVectorStore() + await vector.connect() + + enricher = StaticSemanticEnricher() + community_detector = CommunityDetector(graph_store=graph, vector_store=vector) + engine = SeedWalkEngine(graph_store=graph, vector_store=vector, enricher=enricher) + builder = DefaultGraphBuilder(graph) + registry = ParserRegistry() + merkle_index = MerkleIndex() + + safety: dict[str, Any] | None = None + if safety_enabled: + from smp.engine.handoff import HandoffManager + from smp.engine.integrity import IntegrityVerifier + from smp.engine.safety import ( + AuditLogger, + CheckpointManager, + DryRunSimulator, + GuardEngine, + LockManager, + SessionManager, + ) + from smp.engine.telemetry import TelemetryEngine + from smp.sandbox.executor import SandboxExecutor + from smp.sandbox.spawner import SandboxSpawner + + session_manager = SessionManager(graph_store=graph) + lock_manager = LockManager(graph_store=graph) + session_manager.set_graph_store(graph) + lock_manager.set_graph_store(graph) + sandbox_spawner = SandboxSpawner() + sandbox_executor = SandboxExecutor() + telemetry_engine = TelemetryEngine() + handoff_manager = HandoffManager() + integrity_verifier = IntegrityVerifier() + + # Runtime linker and linker are already available in the graph + # We'll add them to app.state for access via context + app.state.telemetry_engine = telemetry_engine + app.state.handoff_manager = handoff_manager + app.state.integrity_verifier = integrity_verifier + + safety = { + "session_manager": session_manager, + "lock_manager": lock_manager, + "guard_engine": GuardEngine(session_manager, lock_manager), + "dryrun_simulator": DryRunSimulator(), + "checkpoint_manager": CheckpointManager(), + "audit_logger": AuditLogger(), + "sandbox_spawner": sandbox_spawner, + "sandbox_executor": sandbox_executor, + } + + app.state.graph = graph + app.state.vector = vector + app.state.engine = engine + app.state.community_detector = community_detector + app.state.merkle_index = merkle_index + app.state.builder = builder + 
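+        # These app.state attributes are the only channel between lifespan
+        # setup and the /rpc endpoint below, which forwards them to
+        # handle_rpc on every request.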
+        app.state.enricher = enricher
+        app.state.registry = registry
+        app.state.safety = safety
+
+        log.info("server_started", neo4j=uri, safety=safety_enabled)
+        yield
+
+        await graph.close()
+        log.info("server_stopped")
+
+    app = FastAPI(title="SMP — Structural Memory Protocol", version="3.0.0", lifespan=lifespan)
+
+    @app.post("/rpc")
+    async def rpc_endpoint(request: Request) -> Response:
+        return await handle_rpc(
+            request,
+            engine=app.state.engine,
+            enricher=app.state.enricher,
+            builder=app.state.builder,
+            registry=app.state.registry,
+            vector=app.state.vector,
+            safety=app.state.safety,
+            telemetry_engine=getattr(app.state, "telemetry_engine", None),
+            handoff_manager=getattr(app.state, "handoff_manager", None),
+            integrity_verifier=getattr(app.state, "integrity_verifier", None),
+        )
+
+    @app.get("/health")
+    async def health() -> dict[str, str]:
+        return {"status": "ok"}
+
+    @app.get("/stats")
+    async def stats() -> dict[str, int]:
+        graph: Neo4jGraphStore = app.state.graph
+        return {
+            "nodes": await graph.count_nodes(),
+            "edges": await graph.count_edges(),
+        }
+
+    return app
+
+
+app = create_app()
diff --git a/smp/sandbox/__init__.py b/smp/sandbox/__init__.py
new file mode 100644
index 0000000..a73e2af
--- /dev/null
+++ b/smp/sandbox/__init__.py
@@ -0,0 +1 @@
+"""SMP sandbox runtime module."""
diff --git a/smp/sandbox/docker_sandbox.py b/smp/sandbox/docker_sandbox.py
new file mode 100644
index 0000000..80ba1dc
--- /dev/null
+++ b/smp/sandbox/docker_sandbox.py
@@ -0,0 +1,57 @@
+from __future__ import annotations
+
+import docker
+
+from smp.logging import get_logger
+
+log = get_logger(__name__)
+
+
+class DockerSandbox:
+    def __init__(self) -> None:
+        self._client = docker.from_env()
+        self._container: docker.models.containers.Container | None = None
+        self._network: docker.models.networks.Network | None = None
+
+    def spawn(self, name: str, image: str, services: list[str]) -> str:
+        self._network = self._client.networks.create(
+            name=f"{name}_net",
+            internal=True,
+        )
+
+        self._container = self._client.containers.run(
+            image=image,
+            name=name,
+            detach=True,
+            network=self._network.name,
+            volumes={
+                f"{name}_cow": {"bind": "/data", "mode": "rw"},
+            },
+            labels={"smp_sandbox": "true"},
+        )
+
+        log.info("docker_sandbox_spawned", container_id=str(self._container.id), name=name)
+        return str(self._container.id)
+
+    def execute(self, command: str, timeout: int) -> str:
+        if not self._container:
+            log.error("docker_sandbox_execute_failed", reason="no_container")
+            raise RuntimeError("No container spawned")
+
+        # docker-py's exec_run has no timeout parameter; bound the run with
+        # coreutils timeout(1) instead (assumes it exists in the image).
+        exit_code, output = self._container.exec_run(
+            ["/bin/sh", "-c", f"timeout {timeout} {command}"]
+        )
+
+        if exit_code != 0:
+            log.warning("docker_sandbox_exec_nonzero", exit_code=exit_code, command=command)
+
+        return output.decode("utf-8")
+
+    def destroy(self) -> None:
+        if self._container:
+            self._container.remove(force=True)
+            log.info("docker_sandbox_container_removed", container_id=self._container.id)
+            self._container = None
+
+        if self._network:
+            self._network.remove()
+            log.info("docker_sandbox_network_removed", network_id=self._network.id)
+            self._network = None
diff --git a/smp/sandbox/ebpf_collector.py b/smp/sandbox/ebpf_collector.py
new file mode 100644
index 0000000..a36cb44
--- /dev/null
+++ b/smp/sandbox/ebpf_collector.py
@@ -0,0 +1,29 @@
+from __future__ import annotations
+
+import uuid
+
+from smp.logging import get_logger
+
+log = get_logger(__name__)
+
+
+class EBPFCollector:
+    def __init__(self) -> None:
+        self._active_traces: dict[str, str] = {}
+        self._data: list[dict[str, str | int]] = []
+
+    def start_trace(self, session_id: str) -> str:
+        trace_id = str(uuid.uuid4())
+        self._active_traces[trace_id] = session_id
+        log.info("ebpf_trace_started", trace_id=trace_id, session_id=session_id)
+        return trace_id
+
+    def stop_trace(self, trace_id: str) -> None:
+        if trace_id in self._active_traces:
+            session_id = self._active_traces.pop(trace_id)
+            log.info("ebpf_trace_stopped", trace_id=trace_id, session_id=session_id)
+        else:
+            log.error("ebpf_trace_stop_failed", trace_id=trace_id)
+
+    # NOTE: collection is stubbed for now; _data is never populated.
+    def get_traces(self) -> list[dict[str, str | int]]:
+        return self._data
diff --git a/smp/sandbox/executor.py b/smp/sandbox/executor.py
new file mode 100644
index 0000000..5018e95
--- /dev/null
+++ b/smp/sandbox/executor.py
@@ -0,0 +1,169 @@
+"""SMP(3) sandbox executor for isolated runtime execution.
+
+Provides isolated execution environments for running agent code safely.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import contextlib
+import os
+import uuid
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any
+
+from smp.logging import get_logger
+
+log = get_logger(__name__)
+
+_SANDBOX_DEFAULT_TIMEOUT = 30
+_SANDBOX_DEFAULT_MEMORY_MB = 512
+
+
+@dataclass
+class ExecutionResult:
+    """Result of a sandbox execution."""
+
+    execution_id: str
+    exit_code: int
+    stdout: str
+    stderr: str
+    duration_ms: int
+    memory_used_mb: float = 0.0
+    timed_out: bool = False
+    killed: bool = False
+    metadata: dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass
+class SandboxConfig:
+    """Configuration for sandbox execution."""
+
+    timeout_seconds: int = _SANDBOX_DEFAULT_TIMEOUT
+    memory_limit_mb: int = _SANDBOX_DEFAULT_MEMORY_MB
+    allow_network: bool = False
+    allow_file_write: bool = False
+    working_directory: str = ""
+    environment: dict[str, str] = field(default_factory=dict)
+
+
+class SandboxExecutor:
+    """Executes code in an isolated sandbox environment."""
+
+    def __init__(self, config: SandboxConfig | None = None) -> None:
+        self._config = config or SandboxConfig()
+        self._active_processes: dict[str, asyncio.subprocess.Process] = {}
+
+    async def execute(
+        self,
+        command: list[str],
+        stdin: str | None = None,
+        cwd: str | None = None,
+    ) -> ExecutionResult:
+        """Execute a command in the sandbox."""
+        execution_id = f"exec_{uuid.uuid4().hex[:8]}"
+        start_time = asyncio.get_event_loop().time()
+
+        work_dir = cwd or self._config.working_directory or str(Path.cwd())
+
+        env = os.environ.copy()
+        env.update(self._config.environment)
+        if not self._config.allow_network:
+            # Advisory only: the child process must honour NO_NETWORK itself;
+            # no kernel-level network isolation is applied here.
+            env["NO_NETWORK"] = "1"
+
+        try:
+            process = await asyncio.create_subprocess_exec(
+                *command,
+                stdin=asyncio.subprocess.PIPE if stdin else None,
+                stdout=asyncio.subprocess.PIPE,
+                stderr=asyncio.subprocess.PIPE,
+                cwd=work_dir,
+                env=env,
+            )
+            self._active_processes[execution_id] = process
+
+            try:
+                stdout_bytes, stderr_bytes = await asyncio.wait_for(
+                    process.communicate(stdin.encode() if stdin else None),
+                    timeout=self._config.timeout_seconds,
+                )
+                timed_out = False
+            except TimeoutError:
+                process.kill()
+                await process.wait()
+                stdout_bytes, stderr_bytes = b"", b"Timeout exceeded"
+                timed_out = True
+
+            duration_ms = int((asyncio.get_event_loop().time() - start_time) * 1000)
+
+            result = ExecutionResult(
+                execution_id=execution_id,
+                # A successful exit code of 0 is falsy, so guard with an
+                # explicit None check rather than "returncode or -1".
+                exit_code=process.returncode if process.returncode is not None else -1,
+                stdout=stdout_bytes.decode("utf-8", errors="replace"),
+                stderr=stderr_bytes.decode("utf-8", errors="replace"),
+                duration_ms=duration_ms,
+                timed_out=timed_out,
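+                # killed mirrors timed_out: the only kill issued on this path
+                # is the one from the timeout branch above.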
killed=timed_out, + ) + + log.info( + "sandbox_execution_complete", + execution_id=execution_id, + exit_code=result.exit_code, + duration_ms=duration_ms, + timed_out=timed_out, + ) + return result + + except Exception as exc: + log.error("sandbox_execution_error", execution_id=execution_id, error=str(exc)) + return ExecutionResult( + execution_id=execution_id, + exit_code=-1, + stdout="", + stderr=str(exc), + duration_ms=int((asyncio.get_event_loop().time() - start_time) * 1000), + ) + finally: + self._active_processes.pop(execution_id, None) + + async def execute_python( + self, + code: str, + timeout: int | None = None, + ) -> ExecutionResult: + """Execute Python code in the sandbox.""" + import tempfile + + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write(code) + temp_path = f.name + + try: + config = SandboxConfig( + timeout_seconds=timeout or self._config.timeout_seconds, + **{k: v for k, v in self._config.__dict__.items() if k != "timeout_seconds"}, + ) + executor = SandboxExecutor(config) + return await executor.execute(["python3.11", temp_path]) + finally: + Path(temp_path).unlink(missing_ok=True) + + def kill(self, execution_id: str) -> bool: + """Kill an active execution.""" + process = self._active_processes.get(execution_id) + if process: + process.kill() + log.info("sandbox_killed", execution_id=execution_id) + return True + return False + + async def cleanup(self) -> None: + """Kill all active executions.""" + for exec_id, process in list(self._active_processes.items()): + process.kill() + with contextlib.suppress(Exception): + await process.wait() + log.info("sandbox_cleanup", execution_id=exec_id) + self._active_processes.clear() diff --git a/smp/sandbox/spawner.py b/smp/sandbox/spawner.py new file mode 100644 index 0000000..39bcdb4 --- /dev/null +++ b/smp/sandbox/spawner.py @@ -0,0 +1,113 @@ +"""Sandbox spawner for creating isolated execution environments. + +Manages the lifecycle of sandboxed processes and containers. 
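+
+Illustrative usage (a sketch; the file content is a placeholder)::
+
+    spawner = SandboxSpawner()
+    info = spawner.spawn(files={"hello.py": "print('hello')"})
+    print(info.root_path)
+    spawner.destroy(info.sandbox_id)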
+""" + +from __future__ import annotations + +import uuid +from dataclasses import dataclass, field +from datetime import UTC, datetime +from pathlib import Path +from typing import Any + +from smp.logging import get_logger + +log = get_logger(__name__) + +_DEFAULT_SANDBOX_ROOT = Path.home() / ".smp" / "sandboxes" + + +@dataclass +class SandboxInfo: + """Information about a spawned sandbox.""" + + sandbox_id: str + root_path: str + created_at: str + status: str = "created" + pid: int | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + +class SandboxSpawner: + """Spawns and manages isolated sandbox directories.""" + + def __init__(self, sandbox_root: Path | None = None) -> None: + self._root = sandbox_root or _DEFAULT_SANDBOX_ROOT + self._sandboxes: dict[str, SandboxInfo] = {} + + def spawn( + self, + name: str | None = None, + template: str | None = None, + files: dict[str, str] | None = None, + ) -> SandboxInfo: + """Create a new sandbox directory.""" + sandbox_id = f"sandbox_{uuid.uuid4().hex[:8]}" + sandbox_name = name or sandbox_id + sandbox_path = self._root / sandbox_name + + sandbox_path.mkdir(parents=True, exist_ok=True) + + if template: + template_path = self._root / template + if template_path.exists(): + import shutil + + shutil.copytree(template_path, sandbox_path, dirs_exist_ok=True) + + if files: + for rel_path, content in files.items(): + file_path = sandbox_path / rel_path + file_path.parent.mkdir(parents=True, exist_ok=True) + file_path.write_text(content, encoding="utf-8") + + info = SandboxInfo( + sandbox_id=sandbox_id, + root_path=str(sandbox_path), + created_at=datetime.now(UTC).isoformat(), + ) + self._sandboxes[sandbox_id] = info + + log.info("sandbox_spawned", sandbox_id=sandbox_id, path=str(sandbox_path)) + return info + + def get(self, sandbox_id: str) -> SandboxInfo | None: + """Get sandbox info by ID.""" + return self._sandboxes.get(sandbox_id) + + def list_active(self) -> list[SandboxInfo]: + """List all active sandboxes.""" + return list(self._sandboxes.values()) + + def destroy(self, sandbox_id: str) -> bool: + """Remove a sandbox directory.""" + info = self._sandboxes.get(sandbox_id) + if not info: + return False + + import shutil + + path = Path(info.root_path) + if path.exists(): + shutil.rmtree(path) + + del self._sandboxes[sandbox_id] + log.info("sandbox_destroyed", sandbox_id=sandbox_id) + return True + + async def cleanup_all(self) -> int: + """Remove all sandbox directories.""" + import shutil + + count = 0 + for sandbox_id, info in list(self._sandboxes.items()): + path = Path(info.root_path) + if path.exists(): + shutil.rmtree(path) + count += 1 + del self._sandboxes[sandbox_id] + + log.info("sandboxes_cleaned", count=count) + return count diff --git a/smp/store/__init__.py b/smp/store/__init__.py new file mode 100644 index 0000000..ecccfaa --- /dev/null +++ b/smp/store/__init__.py @@ -0,0 +1 @@ +"""Store layer — graph and vector backends.""" diff --git a/smp/store/chroma_store.py b/smp/store/chroma_store.py new file mode 100644 index 0000000..f51df7d --- /dev/null +++ b/smp/store/chroma_store.py @@ -0,0 +1,155 @@ +"""ChromaDB-backed vector store implementation.""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any + +import chromadb + +from smp.logging import get_logger +from smp.store.interfaces import VectorStore + +log = get_logger(__name__) + + +class ChromaVectorStore(VectorStore): + """Persist embeddings in ChromaDB with metadata filtering.""" + + def __init__( + self, + 
collection_name: str = "smp_code_embeddings", + persist_dir: str | None = None, + ) -> None: + self._collection_name = collection_name + self._persist_dir = persist_dir + self._client: Any = None + self._collection: Any = None + + async def connect(self) -> None: + if self._persist_dir is not None: + self._client = chromadb.PersistentClient(path=self._persist_dir) + else: + self._client = chromadb.Client() + self._collection = self._client.get_or_create_collection(name=self._collection_name) + log.info("chroma_connected", collection=self._collection_name, persist_dir=self._persist_dir) + + async def close(self) -> None: + self._client = None + self._collection = None + log.info("chroma_closed", collection=self._collection_name) + + async def clear(self) -> None: + if self._client is None: + raise RuntimeError("ChromaVectorStore is not connected") + self._client.delete_collection(name=self._collection_name) + self._collection = self._client.get_or_create_collection(name=self._collection_name) + log.info("chroma_cleared", collection=self._collection_name) + + async def upsert( + self, + ids: Sequence[str], + embeddings: Sequence[Sequence[float]], + metadatas: Sequence[dict[str, Any]], + documents: Sequence[str] | None = None, + ) -> None: + if self._collection is None: + raise RuntimeError("ChromaVectorStore is not connected") + self._collection.upsert( + ids=list(ids), + embeddings=[list(e) for e in embeddings], + metadatas=list(metadatas), + documents=list(documents) if documents is not None else None, + ) + log.info("chroma_upserted", count=len(ids)) + + async def query( + self, + embedding: Sequence[float], + top_k: int = 5, + where: dict[str, Any] | None = None, + ) -> list[dict[str, Any]]: + if self._collection is None: + raise RuntimeError("ChromaVectorStore is not connected") + result = self._collection.query( + query_embeddings=[list(embedding)], + n_results=top_k, + where=where, + ) + return _normalise_query_result(result) + + async def get(self, ids: Sequence[str]) -> list[dict[str, Any] | None]: + if self._collection is None: + raise RuntimeError("ChromaVectorStore is not connected") + result = self._collection.get(ids=list(ids)) + return _normalise_get_result(result) + + async def delete(self, ids: Sequence[str]) -> int: + if self._collection is None: + raise RuntimeError("ChromaVectorStore is not connected") + self._collection.delete(ids=list(ids)) + log.info("chroma_deleted", count=len(ids)) + return len(ids) + + async def delete_by_file(self, file_path: str) -> int: + if self._collection is None: + raise RuntimeError("ChromaVectorStore is not connected") + self._collection.delete(where={"file_path": file_path}) + log.info("chroma_deleted_by_file", file_path=file_path) + return -1 + + async def add_code_embedding( + self, + node_id: str, + embedding: list[float], + metadata: dict[str, Any], + document: str = "", + ) -> None: + await self.upsert( + ids=[node_id], + embeddings=[embedding], + metadatas=[metadata], + documents=[document], + ) + + async def query_similar( + self, + embedding: list[float], + top_k: int = 5, + where: dict[str, Any] | None = None, + ) -> list[dict[str, Any]]: + return await self.query(embedding=embedding, top_k=top_k, where=where) + + +def _normalise_query_result(result: dict[str, Any]) -> list[dict[str, Any]]: + ids_batch = result.get("ids", [[]]) + distances_batch = result.get("distances", [[]]) + metadatas_batch = result.get("metadatas", [[]]) + documents_batch = result.get("documents", [[]]) + out: list[dict[str, Any]] = [] + for i, entry_id in 
enumerate(ids_batch[0]): + out.append( + { + "id": entry_id, + "score": distances_batch[0][i] if distances_batch and i < len(distances_batch[0]) else None, + "metadata": metadatas_batch[0][i] if metadatas_batch and i < len(metadatas_batch[0]) else {}, + "document": documents_batch[0][i] if documents_batch and i < len(documents_batch[0]) else "", + } + ) + return out + + +def _normalise_get_result(result: dict[str, Any]) -> list[dict[str, Any] | None]: + ids = result.get("ids", []) + metadatas = result.get("metadatas", []) + documents = result.get("documents", []) + out: list[dict[str, Any] | None] = [] + for i, entry_id in enumerate(ids): + out.append( + { + "id": entry_id, + "metadata": metadatas[i] if metadatas and i < len(metadatas) else {}, + "document": documents[i] if documents and i < len(documents) else "", + } + ) + return out diff --git a/smp/store/graph/__init__.py b/smp/store/graph/__init__.py new file mode 100644 index 0000000..74b6ccf --- /dev/null +++ b/smp/store/graph/__init__.py @@ -0,0 +1 @@ +"""Graph store implementations.""" diff --git a/smp/store/graph/neo4j_store.py b/smp/store/graph/neo4j_store.py new file mode 100644 index 0000000..d926102 --- /dev/null +++ b/smp/store/graph/neo4j_store.py @@ -0,0 +1,558 @@ +"""Neo4j-backed graph store implementation. + +Uses the official ``neo4j`` Python driver with async support. +Updated for SMP(3) partitioned schema (structural + semantic). +""" + +from __future__ import annotations + +import os +from collections.abc import Sequence +from datetime import UTC, datetime +from typing import Any + +from neo4j import AsyncDriver, AsyncGraphDatabase + +from smp.core.models import ( + Annotations, + EdgeType, + GraphEdge, + GraphNode, + InlineComment, + NodeType, + SemanticProperties, + StructuralProperties, +) +from smp.logging import get_logger +from smp.store.interfaces import GraphStore + +log = get_logger(__name__) + +_ALL_LABEL = "SMPNode" +_SESSION_LABEL = "SMPSession" +_LOCK_LABEL = "SMPLck" + + +def _node_to_props(node: GraphNode) -> dict[str, Any]: + """Convert a GraphNode to flat Neo4j properties.""" + return { + "id": node.id, + "type": node.type.value, + "file_path": node.file_path, + "structural_name": node.structural.name, + "structural_file": node.structural.file, + "structural_signature": node.structural.signature, + "structural_start_line": node.structural.start_line, + "structural_end_line": node.structural.end_line, + "structural_complexity": node.structural.complexity, + "structural_lines": node.structural.lines, + "structural_parameters": node.structural.parameters, + "semantic_status": node.semantic.status, + "semantic_docstring": node.semantic.docstring, + "semantic_description": node.semantic.description or "", + "semantic_decorators": str(node.semantic.decorators), + "semantic_tags": str(node.semantic.tags), + "semantic_manually_set": node.semantic.manually_set, + "semantic_source_hash": node.semantic.source_hash, + "semantic_enriched_at": node.semantic.enriched_at, + "semantic_annotations": str(node.semantic.annotations) if node.semantic.annotations else "", + "semantic_inline_comments": str(node.semantic.inline_comments), + } + + +def _record_to_node(record: dict[str, Any]) -> GraphNode: + """Reconstruct a GraphNode from a Neo4j record.""" + structural = StructuralProperties( + name=record.get("structural_name", ""), + file=record.get("structural_file", ""), + signature=record.get("structural_signature", ""), + start_line=record.get("structural_start_line", 0), + end_line=record.get("structural_end_line", 
0),
+        complexity=record.get("structural_complexity", 0),
+        lines=record.get("structural_lines", 0),
+        parameters=record.get("structural_parameters", 0),
+    )
+
+    # ast is imported once here; it backs all the literal_eval decoding below.
+    import ast
+
+    annotations_raw = record.get("semantic_annotations", "")
+    annotations: Annotations | None = None
+    if annotations_raw:
+        try:
+            parsed = ast.literal_eval(annotations_raw)
+            if isinstance(parsed, dict):
+                annotations = Annotations(
+                    params=parsed.get("params", {}),
+                    returns=parsed.get("returns"),
+                    throws=parsed.get("throws", []),
+                )
+        except (ValueError, SyntaxError):
+            pass
+
+    decorators_raw = record.get("semantic_decorators", "[]")
+    try:
+        decorators = ast.literal_eval(decorators_raw) if decorators_raw else []
+        if not isinstance(decorators, list):
+            decorators = []
+    except (ValueError, SyntaxError):
+        decorators = []
+
+    tags_raw = record.get("semantic_tags", "[]")
+    try:
+        tags = ast.literal_eval(tags_raw) if tags_raw else []
+        if not isinstance(tags, list):
+            tags = []
+    except (ValueError, SyntaxError):
+        tags = []
+
+    comments_raw = record.get("semantic_inline_comments", "[]")
+    inline_comments: list[InlineComment] = []
+    try:
+        parsed_comments = ast.literal_eval(comments_raw) if comments_raw else []
+        if isinstance(parsed_comments, list):
+            for c in parsed_comments:
+                if isinstance(c, dict):
+                    inline_comments.append(InlineComment(line=c.get("line", 0), text=c.get("text", "")))
+                elif isinstance(c, InlineComment):
+                    inline_comments.append(c)
+    except (ValueError, SyntaxError):
+        pass
+
+    semantic = SemanticProperties(
+        status=record.get("semantic_status", "no_metadata"),
+        docstring=record.get("semantic_docstring", ""),
+        description=record.get("semantic_description") or None,
+        inline_comments=inline_comments,
+        decorators=decorators,
+        annotations=annotations,
+        tags=tags,
+        manually_set=record.get("semantic_manually_set", False),
+        source_hash=record.get("semantic_source_hash", ""),
+        enriched_at=record.get("semantic_enriched_at", ""),
+    )
+
+    return GraphNode(
+        id=record["id"],
+        type=NodeType(record["type"]),
+        file_path=record["file_path"],
+        structural=structural,
+        semantic=semantic,
+    )
+
+
+class Neo4jGraphStore(GraphStore):
+    """Graph store backed by a Neo4j instance."""
+
+    def __init__(
+        self,
+        uri: str = "",
+        user: str = "",
+        password: str = "",
+        database: str = "neo4j",
+    ) -> None:
+        self._uri = uri or os.environ.get("SMP_NEO4J_URI", "bolt://localhost:7687")
+        self._user = user or os.environ.get("SMP_NEO4J_USER", "neo4j")
+        self._password = password or os.environ.get("SMP_NEO4J_PASSWORD", "")
+        self._database = database
+        self._driver: AsyncDriver | None = None
+
+    async def connect(self) -> None:
+        self._driver = AsyncGraphDatabase.driver(self._uri, auth=(self._user, self._password))
+        await self._driver.verify_connectivity()
+        log.info("neo4j_connected", uri=self._uri)
+        await self._execute(f"CREATE CONSTRAINT IF NOT EXISTS FOR (n:{_ALL_LABEL}) REQUIRE n.id IS UNIQUE")
+
+        # Create full-text index for search
+        await self._execute(
+            f"CREATE FULLTEXT INDEX node_search_index IF NOT EXISTS FOR (n:{_ALL_LABEL}) "
+            "ON EACH [n.semantic_docstring, n.semantic_description, n.structural_name, n.file_path]"
+        )
+
+    async def close(self) -> None:
+        if self._driver:
+            await self._driver.close()
+            self._driver = None
+            log.info("neo4j_closed")
+
+    async def clear(self) -> None:
+        await self._execute("MATCH (n) DETACH DELETE n")
+        log.warning("neo4j_cleared")
+
+    async def upsert_node(self, node: GraphNode) -> None:
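+        """Insert or update a single node, keyed on ``id`` via Cypher MERGE."""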
props = _node_to_props(node) + cypher = f""" + MERGE (n:{_ALL_LABEL} {{id: $id}}) + SET n += $props + """ + await self._execute(cypher, {"id": node.id, "props": props}) + log.debug("node_upserted", node_id=node.id) + + async def upsert_session(self, session: Any) -> None: + """Store or update a session in the graph.""" + props = { + "session_id": session.session_id, + "agent_id": session.agent_id, + "task": session.task, + "mode": session.mode, + "opened_at": session.opened_at, + "expires_at": session.expires_at, + "status": session.status, + "files_written": session.files_written, + "files_read": session.files_read, + } + cypher = f""" + MERGE (n:{_SESSION_LABEL} {{session_id: $session_id}}) + SET n += $props + """ + await self._execute(cypher, {"session_id": session.session_id, "props": props}) + log.debug("session_upserted", session_id=session.session_id) + + async def get_session(self, session_id: str) -> dict[str, Any] | None: + """Retrieve a session by ID.""" + cypher = f"MATCH (n:{_SESSION_LABEL} {{session_id: $session_id}}) RETURN n" + records = await self._execute(cypher, {"session_id": session_id}) + if not records: + return None + return dict(records[0]["n"]) + + async def delete_session(self, session_id: str) -> bool: + """Delete a session from the graph.""" + cypher = f"MATCH (n:{_SESSION_LABEL} {{session_id: $session_id}}) DETACH DELETE n RETURN count(n) AS deleted" + records = await self._execute(cypher, {"session_id": session_id}) + deleted = records[0]["deleted"] if records else 0 + return deleted > 0 + + async def upsert_lock(self, file_path: str, session_id: str) -> None: + """Store a file lock.""" + props = { + "file_path": file_path, + "session_id": session_id, + "acquired_at": datetime.now(UTC).isoformat(), + } + cypher = f""" + MERGE (n:{_LOCK_LABEL} {{file_path: $file_path, session_id: $session_id}}) + SET n += $props + """ + await self._execute(cypher, {"file_path": file_path, "session_id": session_id, "props": props}) + log.debug("lock_upserted", file_path=file_path, session_id=session_id) + + async def get_lock(self, file_path: str) -> dict[str, Any] | None: + """Get lock info for a file.""" + cypher = f"MATCH (n:{_LOCK_LABEL} {{file_path: $file_path}}) RETURN n LIMIT 1" + records = await self._execute(cypher, {"file_path": file_path}) + if not records: + return None + return dict(records[0]["n"]) + + async def release_lock(self, file_path: str, session_id: str) -> bool: + """Release a file lock.""" + cypher = f""" + MATCH (n:{_LOCK_LABEL} {{file_path: $file_path, session_id: $session_id}}) + DETACH DELETE n + RETURN count(n) AS deleted + """ + records = await self._execute(cypher, {"file_path": file_path, "session_id": session_id}) + deleted = records[0]["deleted"] if records else 0 + if deleted > 0: + log.debug("lock_released", file_path=file_path, session_id=session_id) + return deleted > 0 + + async def release_all_locks(self, session_id: str) -> int: + """Release all locks held by a session.""" + cypher = f""" + MATCH (n:{_LOCK_LABEL} {{session_id: $session_id}}) + DETACH DELETE n + RETURN count(n) AS deleted + """ + records = await self._execute(cypher, {"session_id": session_id}) + deleted = records[0]["deleted"] if records else 0 + log.info("locks_released_by_session", session_id=session_id, count=deleted) + return deleted + + async def upsert_nodes(self, nodes: Sequence[GraphNode]) -> None: + if not nodes: + return + batch = [_node_to_props(n) for n in nodes] + cypher = f""" + UNWIND $batch AS row + MERGE (n:{_ALL_LABEL} {{id: row.id}}) + SET n += row + 
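+            // One round-trip: UNWIND streams the batch and MERGE keys on row.id.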
""" + await self._execute(cypher, {"batch": batch}) + log.info("nodes_upserted_batch", count=len(nodes)) + + async def get_node(self, node_id: str) -> GraphNode | None: + cypher = f"MATCH (n:{_ALL_LABEL} {{id: $id}}) RETURN n" + records = await self._execute(cypher, {"id": node_id}) + if not records: + return None + return _record_to_node(dict(records[0]["n"])) + + async def delete_node(self, node_id: str) -> bool: + cypher = f"MATCH (n:{_ALL_LABEL} {{id: $id}}) DETACH DELETE n RETURN count(n) AS deleted" + records = await self._execute(cypher, {"id": node_id}) + deleted = records[0]["deleted"] if records else 0 + return deleted > 0 + + async def delete_nodes_by_file(self, file_path: str) -> int: + stem = file_path.rsplit("/", 1)[-1] if "/" in file_path else file_path + cypher = f""" + MATCH (n:{_ALL_LABEL}) + WHERE n.file_path = $file_path OR n.file_path = $stem + DETACH DELETE n + RETURN count(n) AS deleted + """ + records = await self._execute(cypher, {"file_path": file_path, "stem": stem}) + deleted = records[0]["deleted"] if records else 0 + log.info("nodes_deleted_by_file", file_path=file_path, deleted=deleted) + return deleted + + async def upsert_edge(self, edge: GraphEdge) -> None: + rel_type = edge.type.value + cypher = f""" + MATCH (a:{_ALL_LABEL} {{id: $source_id}}) + MATCH (b:{_ALL_LABEL} {{id: $target_id}}) + MERGE (a)-[r:{rel_type}]->(b) + SET r += $metadata + """ + await self._execute( + cypher, + { + "source_id": edge.source_id, + "target_id": edge.target_id, + "metadata": edge.metadata, + }, + ) + log.debug("edge_upserted", src=edge.source_id, tgt=edge.target_id, type=rel_type) + + async def upsert_edges(self, edges: Sequence[GraphEdge]) -> None: + if not edges: + return + grouped: dict[str, list[dict[str, Any]]] = {} + for e in edges: + grouped.setdefault(e.type.value, []).append( + {"source_id": e.source_id, "target_id": e.target_id, "metadata": e.metadata} + ) + for rel_type, batch in grouped.items(): + cypher = f""" + UNWIND $batch AS row + MATCH (a:{_ALL_LABEL} {{id: row.source_id}}) + MATCH (b:{_ALL_LABEL} {{id: row.target_id}}) + MERGE (a)-[r:{rel_type}]->(b) + SET r += row.metadata + """ + await self._execute(cypher, {"batch": batch}) + log.info("edges_upserted_batch", count=len(edges)) + + async def get_edges( + self, + node_id: str, + edge_type: EdgeType | None = None, + direction: str = "both", + ) -> list[GraphEdge]: + type_filter = f":{edge_type.value}" if edge_type else "" + if direction == "outgoing": + pattern = f"(a:{_ALL_LABEL} {{id: $id}})-[r{type_filter}]->(b)" + elif direction == "incoming": + pattern = f"(a)-[r{type_filter}]->(b:{_ALL_LABEL} {{id: $id}})" + else: + pattern = f"(a:{_ALL_LABEL} {{id: $id}})-[r{type_filter}]-(b)" + + cypher = f"MATCH {pattern} RETURN a.id AS source, b.id AS target, type(r) AS rel_type" + records = await self._execute(cypher, {"id": node_id}) + return [ + GraphEdge( + source_id=rec["source"], + target_id=rec["target"], + type=EdgeType(rec["rel_type"]), + ) + for rec in records + ] + + async def get_neighbors( + self, + node_id: str, + edge_type: EdgeType | None = None, + depth: int = 1, + ) -> list[GraphNode]: + type_filter = f":{edge_type.value}" if edge_type else "" + depth_str = f"1..{depth}" + cypher = f""" + MATCH (start:{_ALL_LABEL} {{id: $id}})-[r{type_filter}*{depth_str}]->(neighbor:{_ALL_LABEL}) + RETURN DISTINCT neighbor + """ + records = await self._execute(cypher, {"id": node_id}) + return [_record_to_node(dict(rec["neighbor"])) for rec in records] + + async def traverse( + self, + start_id: str, + edge_type: 
EdgeType, + depth: int, + max_nodes: int = 100, + direction: str = "outgoing", + ) -> list[GraphNode]: + rel_type = edge_type.value + if direction == "incoming": + cypher = f""" + MATCH path = (start:{_ALL_LABEL} {{id: $id}})<-[r:{rel_type}*1..{depth}]-(node:{_ALL_LABEL}) + RETURN DISTINCT node + LIMIT $max_nodes + """ + else: + cypher = f""" + MATCH path = (start:{_ALL_LABEL} {{id: $id}})-[r:{rel_type}*1..{depth}]->(node:{_ALL_LABEL}) + RETURN DISTINCT node + LIMIT $max_nodes + """ + records = await self._execute(cypher, {"id": start_id, "max_nodes": max_nodes}) + return [_record_to_node(dict(rec["node"])) for rec in records] + + async def find_nodes( + self, + *, + type: NodeType | None = None, + file_path: str | None = None, + name: str | None = None, + ) -> list[GraphNode]: + conditions: list[str] = [] + params: dict[str, Any] = {} + if type: + conditions.append("n.type = $type") + params["type"] = type.value + if file_path: + stem = file_path.rsplit("/", 1)[-1] if "/" in file_path else file_path + conditions.append("(n.file_path = $file_path OR n.file_path = $stem)") + params["file_path"] = file_path + params["stem"] = stem + if name: + conditions.append("n.structural_name = $name") + params["name"] = name + + where = " AND ".join(conditions) + where_clause = f"WHERE {where}" if where else "" + cypher = f"MATCH (n:{_ALL_LABEL}) {where_clause} RETURN n" + records = await self._execute(cypher, params) + return [_record_to_node(dict(rec["n"])) for rec in records] + + async def find_nodes_by_scope(self, scope: str) -> list[GraphNode]: + """Find nodes matching a scope prefix (package:path or file:path).""" + if scope == "full": + cypher = f"MATCH (n:{_ALL_LABEL}) RETURN n" + records = await self._execute(cypher) + return [_record_to_node(dict(rec["n"])) for rec in records] + + if scope.startswith("package:"): + prefix = scope[len("package:") :] + cypher = f"MATCH (n:{_ALL_LABEL}) WHERE n.file_path STARTS WITH $prefix RETURN n" + records = await self._execute(cypher, {"prefix": prefix}) + return [_record_to_node(dict(rec["n"])) for rec in records] + + if scope.startswith("file:"): + fp = scope[len("file:") :] + cypher = f"MATCH (n:{_ALL_LABEL}) WHERE n.file_path = $fp RETURN n" + records = await self._execute(cypher, {"fp": fp}) + return [_record_to_node(dict(rec["n"])) for rec in records] + + return [] + + async def get_node_degree(self, node_id: str) -> tuple[int, int]: + """Return (in_degree, out_degree) for a node.""" + cypher = f""" + MATCH (n:{_ALL_LABEL} {{id: $id}}) + OPTIONAL MATCH (n)-[out]->() + OPTIONAL MATCH ()-[inp]->(n) + RETURN count(DISTINCT out) AS out_degree, count(DISTINCT inp) AS in_degree + """ + records = await self._execute(cypher, {"id": node_id}) + if records: + return records[0]["in_degree"], records[0]["out_degree"] + return 0, 0 + + async def count_nodes(self) -> int: + cypher = f"MATCH (n:{_ALL_LABEL}) RETURN count(n) AS cnt" + records = await self._execute(cypher) + return records[0]["cnt"] if records else 0 + + async def count_edges(self) -> int: + cypher = "MATCH ()-[r]->() RETURN count(r) AS cnt" + records = await self._execute(cypher) + return records[0]["cnt"] if records else 0 + + async def search_nodes( + self, + query_terms: list[str], + match: str = "any", + node_types: list[str] | None = None, + tags: list[str] | None = None, + scope: str | None = None, + top_k: int = 5, + ) -> list[dict[str, Any]]: + """Keyword search using Neo4j full-text index (BM25).""" + search_query = " OR ".join(query_terms) if match == "any" else " AND ".join(query_terms) + + 
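+        # The joined terms form a Lucene query string for the node_search_index
+        # created in connect(): "any" ORs the terms together, anything else ANDs them.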
# If search_query is empty, return empty list + if not search_query: + return [] + + conditions: list[str] = [] + params: dict[str, Any] = {"search_query": search_query, "limit": top_k} + + if scope and scope != "full": + if scope.startswith("package:"): + prefix = scope[len("package:") :] + conditions.append("node.file_path STARTS WITH $scope_prefix") + params["scope_prefix"] = prefix + elif scope.startswith("file:"): + fp = scope[len("file:") :] + conditions.append("node.file_path = $scope_file") + params["scope_file"] = fp + + if node_types: + placeholders = ", ".join(f"$nt{i}" for i in range(len(node_types))) + conditions.append(f"node.type IN [{placeholders}]") + for i, nt in enumerate(node_types): + params[f"nt{i}"] = nt + + where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else "" + + cypher = f""" + CALL db.index.fulltext.queryNodes('node_search_index', $search_query) + YIELD node, score + {where_clause} + RETURN node, score + LIMIT $limit + """ + + records = await self._execute(cypher, params) + + results: list[dict[str, Any]] = [] + for rec in records: + node_data = dict(rec["node"]) + node = _record_to_node(node_data) + results.append( + { + "node_id": node.id, + "node_type": node.type.value, + "file": node.file_path, + "name": node.structural.name, + "docstring": node.semantic.docstring, + "tags": node.semantic.tags, + "score": rec["score"], + } + ) + return results + + async def _execute(self, cypher: str, params: dict[str, Any] | None = None) -> list[Any]: + """Execute a Cypher query and return records.""" + if not self._driver: + raise RuntimeError("Neo4jGraphStore is not connected. Call connect() first.") + async with self._driver.session(database=self._database) as session: + result = await session.run(cypher, params or {}) + return [rec.data() async for rec in result] diff --git a/smp/store/interfaces.py b/smp/store/interfaces.py new file mode 100644 index 0000000..c3b68ed --- /dev/null +++ b/smp/store/interfaces.py @@ -0,0 +1,249 @@ +"""Abstract base classes for store backends. + +All concrete implementations must subclass these to ensure +interchangeability across graph and vector stores. +""" + +from __future__ import annotations + +import abc +from collections.abc import Sequence +from typing import Any + +from smp.core.models import EdgeType, GraphEdge, GraphNode, NodeType + + +class GraphStore(abc.ABC): + """Abstract graph store — manages nodes and directed edges.""" + + # -- Lifecycle ----------------------------------------------------------- + + @abc.abstractmethod + async def connect(self) -> None: + """Open connection / initialise the underlying store.""" + + @abc.abstractmethod + async def close(self) -> None: + """Release resources.""" + + @abc.abstractmethod + async def clear(self) -> None: + """Drop all data (useful for tests).""" + + # -- Node CRUD ----------------------------------------------------------- + + @abc.abstractmethod + async def upsert_node(self, node: GraphNode) -> None: + """Insert or update a single node.""" + + @abc.abstractmethod + async def upsert_nodes(self, nodes: Sequence[GraphNode]) -> None: + """Batch upsert nodes.""" + + @abc.abstractmethod + async def get_node(self, node_id: str) -> GraphNode | None: + """Retrieve a node by its *id*, or ``None``.""" + + @abc.abstractmethod + async def delete_node(self, node_id: str) -> bool: + """Delete a node and all its edges. 
Returns True if it existed.""" + + @abc.abstractmethod + async def delete_nodes_by_file(self, file_path: str) -> int: + """Remove all nodes (and edges) belonging to *file_path*. + + Returns the number of nodes removed. + """ + + # -- Edge CRUD ----------------------------------------------------------- + + @abc.abstractmethod + async def upsert_edge(self, edge: GraphEdge) -> None: + """Insert or update a single edge.""" + + @abc.abstractmethod + async def upsert_edges(self, edges: Sequence[GraphEdge]) -> None: + """Batch upsert edges.""" + + @abc.abstractmethod + async def get_edges( + self, + node_id: str, + edge_type: EdgeType | None = None, + direction: str = "both", + ) -> list[GraphEdge]: + """Return edges connected to *node_id*. + + *direction*: ``"outgoing"`` | ``"incoming"`` | ``"both"``. + """ + + # -- Traversal ----------------------------------------------------------- + + @abc.abstractmethod + async def get_neighbors( + self, + node_id: str, + edge_type: EdgeType | None = None, + depth: int = 1, + ) -> list[GraphNode]: + """Return neighbours up to *depth* hops from *node_id*.""" + + @abc.abstractmethod + async def traverse( + self, + start_id: str, + edge_type: EdgeType, + depth: int, + max_nodes: int = 100, + direction: str = "outgoing", + ) -> list[GraphNode]: + """BFS traversal from *start_id* following *edge_type* edges.""" + + # -- Search -------------------------------------------------------------- + + @abc.abstractmethod + async def find_nodes( + self, + *, + type: NodeType | None = None, + file_path: str | None = None, + name: str | None = None, + ) -> list[GraphNode]: + """Find nodes matching the given filters.""" + + # -- Aggregation --------------------------------------------------------- + + @abc.abstractmethod + async def count_nodes(self) -> int: + """Return total number of nodes.""" + + @abc.abstractmethod + async def count_edges(self) -> int: + """Return total number of edges.""" + + # -- SMP(3) Extensions --------------------------------------------------- + + async def find_nodes_by_scope(self, scope: str) -> list[GraphNode]: + """Find nodes matching a scope prefix.""" + return [] + + async def get_node_degree(self, node_id: str) -> tuple[int, int]: + """Return (in_degree, out_degree) for a node.""" + return 0, 0 + + async def search_nodes( + self, + query_terms: list[str], + match: str = "any", + node_types: list[str] | None = None, + tags: list[str] | None = None, + scope: str | None = None, + top_k: int = 5, + ) -> list[dict[str, Any]]: + """Keyword search across docstrings, descriptions, and tags.""" + return [] + + # -- Session Persistence --------------------------------------------------- + + async def upsert_session(self, session: Any) -> None: + """Store or update a session in the graph.""" + raise NotImplementedError + + async def get_session(self, session_id: str) -> dict[str, Any] | None: + """Retrieve a session by ID.""" + return None + + async def delete_session(self, session_id: str) -> bool: + """Delete a session from the graph.""" + return False + + # -- Lock Persistence ------------------------------------------------------ + + async def upsert_lock(self, file_path: str, session_id: str) -> None: + """Store a file lock.""" + raise NotImplementedError + + async def get_lock(self, file_path: str) -> dict[str, Any] | None: + """Get lock info for a file.""" + return None + + async def release_lock(self, file_path: str, session_id: str) -> bool: + """Release a file lock.""" + return False + + async def release_all_locks(self, session_id: str) -> 
int: + """Release all locks held by a session.""" + return 0 + + # -- Context manager convenience ----------------------------------------- + + async def __aenter__(self) -> GraphStore: + await self.connect() + return self + + async def __aexit__(self, *_: Any) -> None: + await self.close() + + +class VectorStore(abc.ABC): + """Abstract vector store — manages embeddings with metadata.""" + + # -- Lifecycle ----------------------------------------------------------- + + @abc.abstractmethod + async def connect(self) -> None: + """Open connection / initialise.""" + + @abc.abstractmethod + async def close(self) -> None: + """Release resources.""" + + @abc.abstractmethod + async def clear(self) -> None: + """Drop all data.""" + + # -- CRUD ---------------------------------------------------------------- + + @abc.abstractmethod + async def upsert( + self, + ids: Sequence[str], + embeddings: Sequence[Sequence[float]], + metadatas: Sequence[dict[str, Any]], + documents: Sequence[str] | None = None, + ) -> None: + """Insert or update vectors with associated metadata.""" + + @abc.abstractmethod + async def query( + self, + embedding: Sequence[float], + top_k: int = 5, + where: dict[str, Any] | None = None, + ) -> list[dict[str, Any]]: + """Return the *top_k* nearest neighbours. + + Each result is a dict with keys: ``id``, ``score``, ``metadata``, + ``document``. + """ + + @abc.abstractmethod + async def get(self, ids: Sequence[str]) -> list[dict[str, Any] | None]: + """Retrieve vectors by ID.""" + + @abc.abstractmethod + async def delete(self, ids: Sequence[str]) -> int: + """Delete vectors by ID. Returns count of deleted items.""" + + @abc.abstractmethod + async def delete_by_file(self, file_path: str) -> int: + """Delete all vectors whose metadata ``file_path`` matches.""" + + # -- Context manager convenience ----------------------------------------- + + async def __aenter__(self) -> VectorStore: + await self.connect() + return self + + async def __aexit__(self, *_: Any) -> None: + await self.close() diff --git a/test_codebase/src/auth/manager.py b/test_codebase/src/auth/manager.py index 22c32ec..8f6d45c 100644 --- a/test_codebase/src/auth/manager.py +++ b/test_codebase/src/auth/manager.py @@ -1,7 +1,6 @@ # src/auth/manager.py from src.db.user_store import save_user, get_user - def authenticate_user(email, password): """Validates user credentials and returns a session token.""" user = get_user(email) @@ -9,7 +8,6 @@ def authenticate_user(email, password): return "token_123" return None - def register_user(email, password): """Creates a new user account.""" data = {"email": email, "password": password} diff --git a/test_codebase/src/db/user_store.py b/test_codebase/src/db/user_store.py index a4886ce..9d2d6f8 100644 --- a/test_codebase/src/db/user_store.py +++ b/test_codebase/src/db/user_store.py @@ -4,7 +4,6 @@ def save_user(user_data: dict): print(f"Saving user {user_data.get('email')}") return True - def get_user(email: str): """Retrieves user by email.""" return {"email": email, "name": "Test User"} diff --git a/test_codebase/tests/test_auth.py b/test_codebase/tests/test_auth.py index 622609e..7c8d819 100644 --- a/test_codebase/tests/test_auth.py +++ b/test_codebase/tests/test_auth.py @@ -1,10 +1,8 @@ # tests/test_auth.py from src.auth.manager import authenticate_user - def test_auth_success(): assert authenticate_user("test@example.com", "secret") == "token_123" - def test_auth_fail(): assert authenticate_user("test@example.com", "wrong") is None diff --git a/tests/practical_verification.py 
b/tests/practical_verification.py index a558607..7921d45 100644 --- a/tests/practical_verification.py +++ b/tests/practical_verification.py @@ -5,7 +5,6 @@ log = get_logger("verification") - async def main(): async with SMPClient("http://localhost:8420") as client: # 1. Ingest Test Codebase @@ -17,7 +16,7 @@ async def main(): with open(path, "r") as file: content = file.read() await client.update(path, content=content) - + stats = await client.stats() log.info("Graph stats after ingestion", stats=stats) @@ -65,10 +64,10 @@ async def main(): session = await client._rpc("smp/session/open", {"agent_id": "test_agent", "task": "verify"}) sid = session["session_id"] log.info("Session opened", sid=sid) - + lock_res = await client._rpc("smp/lock", {"file_path": "tests/test_codebase/math_utils.py", "session_id": sid}) log.info("Lock acquired", res=lock_res) - + await client._rpc("smp/session/close", {"session_id": sid}) log.info("Session closed") @@ -77,16 +76,15 @@ async def main(): sandbox = await client._rpc("smp/sandbox/spawn", {"name": "test_sb"}) sb_id = sandbox["sandbox_id"] log.info("Sandbox spawned", sb_id=sb_id) - + exec_res = await client._rpc("smp/sandbox/execute", {"sandbox_id": sb_id, "command": ["ls", "-la"]}) log.info("Sandbox execution", res=exec_res) assert exec_res["exit_code"] == 0, "Sandbox command should succeed" - + await client._rpc("smp/sandbox/destroy", {"sandbox_id": sb_id}) log.info("Sandbox destroyed") log.info("ALL PRACTICAL TESTS PASSED") - if __name__ == "__main__": asyncio.run(main()) diff --git a/tests/test_codebase/api/routes.py b/tests/test_codebase/api/routes.py index 50b600c..ea6a97b 100644 --- a/tests/test_codebase/api/routes.py +++ b/tests/test_codebase/api/routes.py @@ -12,7 +12,12 @@ class APIRoutes: This class acts as a 'Hot Node' as it coordinates multiple services. """ - def __init__(self, session_handler: SessionHandler, user_manager: UserManager, order_repo: OrderRepository): + def __init__( + self, + session_handler: SessionHandler, + user_manager: UserManager, + order_repo: OrderRepository + ): """ Initializes the APIRoutes with necessary services. """ @@ -44,7 +49,7 @@ async def handle_get_profile(self, token: str) -> str: User profile details or an error message. """ if await self._session_handler.validate_session(token): - user_id = "user_123" # Mocked from token + user_id = "user_123" # Mocked from token user = await self._user_manager.get_user_profile(user_id) return f"User: {user.username if user else 'Unknown'}" return "Invalid Session" @@ -62,7 +67,7 @@ async def handle_create_user(self, username: str, email: str) -> str: """ if not validate_email(email): return "Invalid Email" - + user = await self._user_manager.create_user(username, email) return f"User {user.username} created" @@ -77,7 +82,7 @@ async def handle_get_orders(self, token: str) -> str: List of orders or an error message. 
""" if await self._session_handler.validate_session(token): - user_id = "user_123" # Mocked from token + user_id = "user_123" # Mocked from token orders = await self._order_repo.get_orders_by_user(user_id) return f"Orders: {len(orders)}" return "Invalid Session" diff --git a/tests/test_codebase/calculator.py b/tests/test_codebase/calculator.py index 5656850..bf526a2 100644 --- a/tests/test_codebase/calculator.py +++ b/tests/test_codebase/calculator.py @@ -1,11 +1,9 @@ from math_utils import add, multiply - def compute_sum(x: int, y: int) -> int: """Computes sum using utils.""" return add(x, y) - def compute_product(x: int, y: int) -> int: """Computes product using utils.""" return multiply(x, y) diff --git a/tests/test_codebase/db/order_repository.py b/tests/test_codebase/db/order_repository.py index d725466..84d4a0a 100644 --- a/tests/test_codebase/db/order_repository.py +++ b/tests/test_codebase/db/order_repository.py @@ -5,7 +5,6 @@ class Order: """Represents an order entity.""" - def __init__(self, order_id: str, user_id: str, amount: float): self.order_id = order_id self.user_id = user_id diff --git a/tests/test_codebase/db/user_repository.py b/tests/test_codebase/db/user_repository.py index f400e4e..81c4e1d 100644 --- a/tests/test_codebase/db/user_repository.py +++ b/tests/test_codebase/db/user_repository.py @@ -5,7 +5,6 @@ class User: """Represents a user entity.""" - def __init__(self, user_id: str, username: str, email: str): self.user_id = user_id self.username = username diff --git a/tests/test_codebase/math_utils.py b/tests/test_codebase/math_utils.py index 7e48dfb..60af0d9 100644 --- a/tests/test_codebase/math_utils.py +++ b/tests/test_codebase/math_utils.py @@ -2,7 +2,6 @@ def add(a: int, b: int) -> int: """Adds two integers.""" return a + b - def multiply(a: int, b: int) -> int: """Multiplies two integers.""" return a * b diff --git a/tests/test_integration_community.py b/tests/test_integration_community.py index a1d32a5..9df88b7 100644 --- a/tests/test_integration_community.py +++ b/tests/test_integration_community.py @@ -698,4 +698,4 @@ async def test_get_boundaries_with_min_coupling(self) -> None: result = await detector.get_boundaries(level=0, min_coupling=0.5) - assert "boundaries" in result + assert "boundaries" in result \ No newline at end of file diff --git a/tests/test_integration_parser_graph.py b/tests/test_integration_parser_graph.py index 4a244a0..13ad132 100644 --- a/tests/test_integration_parser_graph.py +++ b/tests/test_integration_parser_graph.py @@ -43,9 +43,7 @@ async def test_parse_auth_service(self, registry: ParserRegistry) -> None: classes = [n for n in doc.nodes if n.type == NodeType.CLASS] imports = [n for n in doc.nodes if n.type == NodeType.FILE and "import" in n.structural.signature] - assert len(functions) >= 4, ( - f"Expected 4+ functions, got {len(functions)}: {[f.structural.name for f in functions]}" - ) + assert len(functions) >= 4, f"Expected 4+ functions, got {len(functions)}: {[f.structural.name for f in functions]}" assert len(classes) >= 1, f"Expected 1+ classes, got {len(classes)}" assert len(imports) >= 3, f"Expected 3+ imports, got {len(imports)}" @@ -74,9 +72,7 @@ async def test_parse_api_routes(self, registry: ParserRegistry) -> None: functions = [n for n in doc.nodes if n.type == NodeType.FUNCTION] imports = [n for n in doc.nodes if n.type == NodeType.FILE and "import" in n.structural.signature] - assert len(functions) >= 3, ( - f"Expected 3+ functions, got {len(functions)}: {[f.structural.name for f in functions]}" - ) + assert 
len(functions) >= 3, f"Expected 3+ functions, got {len(functions)}: {[f.structural.name for f in functions]}" assert len(imports) >= 2, f"Expected 2+ imports, got {len(imports)}" @pytest.mark.asyncio @@ -129,15 +125,7 @@ async def test_ingest_auth_service( stored_nodes = await store.find_nodes() node_names = {n.structural.name for n in stored_nodes} - expected_functions = [ - "hash_password", - "verify_password", - "generate_token", - "login", - "logout", - "verify_token", - "get_current_user", - ] + expected_functions = ["hash_password", "verify_password", "generate_token", "login", "logout", "verify_token", "get_current_user"] expected_classes = ["AuthService"] for func_name in expected_functions: @@ -150,9 +138,7 @@ async def test_ingest_auth_service( edge_types = {e.type for e in edges} assert EdgeType.DEFINES in edge_types, f"Missing DEFINES edges. Types found: {edge_types}" - assert EdgeType.CALLS in edge_types or EdgeType.DEFINES in edge_types, ( - f"Missing relationship edges. Types found: {edge_types}" - ) + assert EdgeType.CALLS in edge_types or EdgeType.DEFINES in edge_types, f"Missing relationship edges. Types found: {edge_types}" @pytest.mark.asyncio async def test_ingest_db_models( @@ -221,9 +207,7 @@ async def test_ingest_all_files( assert NodeType.CLASS in node_types, f"No classes found. Types: {node_types}" assert EdgeType.DEFINES in edge_types, f"No DEFINES edges. Types: {edge_types}" - assert EdgeType.CALLS in edge_types or EdgeType.IMPORTS in edge_types, ( - f"No CALLS or IMPORTS edges. Types: {edge_types}" - ) + assert EdgeType.CALLS in edge_types or EdgeType.IMPORTS in edge_types, f"No CALLS or IMPORTS edges. Types: {edge_types}" class TestEdgeCreation: @@ -305,6 +289,4 @@ async def test_class_method_defines( defines_edges = [e for e in edges if e.type == EdgeType.DEFINES] auth_service_defines = [e for e in defines_edges if "AuthService" in e.source_id] - assert len(auth_service_defines) >= 3, ( - f"Expected 3+ AuthService method defines, got {len(auth_service_defines)}" - ) + assert len(auth_service_defines) >= 3, f"Expected 3+ AuthService method defines, got {len(auth_service_defines)}" diff --git a/tests/test_integration_sandbox.py b/tests/test_integration_sandbox.py index 1b1c6a1..d0935df 100644 --- a/tests/test_integration_sandbox.py +++ b/tests/test_integration_sandbox.py @@ -10,7 +10,6 @@ try: from smp.sandbox.docker_sandbox import DockerSandbox - DOCKER_AVAILABLE = True except ImportError: DOCKER_AVAILABLE = False