From f3559c109dee2c7fc476f3999058ebf76e129038 Mon Sep 17 00:00:00 2001 From: offx-zinth Date: Sun, 19 Apr 2026 11:59:52 +0530 Subject: [PATCH 1/2] Save my local changes before linking --- .env.example | 5 + .github/workflows/ci.yml | 37 + .gitignore | 50 + AGENTS.md | 156 + Dockerfile | 17 + ERROR | 0 LICENSE | 21 + README.md | 241 ++ docker-compose.yml | 52 + pyproject.toml | 58 + smp (3).md | 3623 +++++++++++++++++ smp.md | 722 ++++ smp/__init__.py | 3 + smp/agent.py | 431 ++ smp/cli.py | 301 ++ smp/client.py | 201 + smp/core/__init__.py | 1 + smp/core/background.py | 182 + smp/core/merkle.py | 91 + smp/core/models.py | 644 +++ smp/engine/__init__.py | 11 + smp/engine/community.py | 499 +++ smp/engine/embedding.py | 124 + smp/engine/enricher.py | 113 + smp/engine/graph_builder.py | 159 + smp/engine/handoff.py | 203 + smp/engine/integrity.py | 242 ++ smp/engine/interfaces.py | 157 + smp/engine/linker.py | 205 + smp/engine/notification.py | 90 + smp/engine/pagerank.py | 119 + smp/engine/query.py | 817 ++++ smp/engine/runtime_linker.py | 212 + smp/engine/safety.py | 590 +++ smp/engine/seed_walk.py | 465 +++ smp/engine/telemetry.py | 161 + smp/logging.py | 68 + smp/parser/__init__.py | 11 + smp/parser/base.py | 153 + smp/parser/python_parser.py | 553 +++ smp/parser/registry.py | 72 + smp/parser/typescript_parser.py | 525 +++ smp/protocol/__init__.py | 9 + smp/protocol/dispatcher.py | 267 ++ smp/protocol/handlers/__init__.py | 1 + smp/protocol/handlers/annotation.py | 117 + smp/protocol/handlers/base.py | 34 + smp/protocol/handlers/community.py | 94 + smp/protocol/handlers/enrichment.py | 185 + smp/protocol/handlers/handoff.py | 68 + smp/protocol/handlers/memory.py | 115 + smp/protocol/handlers/merkle.py | 81 + smp/protocol/handlers/query.py | 142 + smp/protocol/handlers/query_ext.py | 115 + smp/protocol/handlers/safety.py | 338 ++ smp/protocol/handlers/sandbox.py | 110 + smp/protocol/handlers/telemetry.py | 122 + smp/protocol/router.py | 653 +++ smp/protocol/server.py | 162 + smp/sandbox/__init__.py | 1 + smp/sandbox/docker_sandbox.py | 57 + smp/sandbox/ebpf_collector.py | 29 + smp/sandbox/executor.py | 169 + smp/sandbox/spawner.py | 113 + smp/store/__init__.py | 1 + smp/store/chroma_store.py | 155 + smp/store/graph/__init__.py | 1 + smp/store/graph/neo4j_store.py | 558 +++ smp/store/interfaces.py | 249 ++ test_codebase/src/auth/manager.py | 14 + test_codebase/src/db/user_store.py | 9 + test_codebase/tests/test_auth.py | 8 + tests/__init__.py | 0 tests/conftest.py | 94 + .../sample_project/src/api/__init__.py | 7 + .../fixtures/sample_project/src/api/routes.py | 37 + .../sample_project/src/auth/__init__.py | 7 + .../sample_project/src/auth/auth_service.py | 65 + .../sample_project/src/db/__init__.py | 9 + .../fixtures/sample_project/src/db/models.py | 28 + .../fixtures/sample_project/src/db/orders.py | 28 + .../sample_project/tests/test_auth.py | 67 + .../fixtures/sample_project/tests/test_db.py | 52 + tests/practical_verification.py | 90 + tests/results/practical_phase10_handoff.json | 56 + tests/results/practical_phase1_service.json | 32 + tests/results/practical_phase2_ingestion.json | 32 + tests/results/practical_phase2_query.json | 64 + .../results/practical_phase3_enrichment.json | 42 + tests/results/practical_phase3_linker.json | 22 + .../results/practical_phase4_annotation.json | 21 + tests/results/practical_phase4_query.json | 73 + .../results/practical_phase5_enrichment.json | 42 + tests/results/practical_phase5_memory.json | 21 + .../results/practical_phase6_annotation.json | 30 + 
tests/results/practical_phase6_safety.json | 69 + tests/results/practical_phase7_query_ext.json | 60 + tests/results/practical_phase8_safety.json | 87 + tests/results/practical_phase9_sandbox.json | 30 + tests/results/practical_summary.json | 6 + tests/results/summary.json | 19 + tests/test_client.py | 210 + tests/test_codebase/__init__.py | 0 tests/test_codebase/api/middleware.py | 25 + tests/test_codebase/api/routes.py | 88 + tests/test_codebase/auth/jwt_utils.py | 42 + tests/test_codebase/auth/session_handler.py | 52 + tests/test_codebase/auth/user_manager.py | 45 + tests/test_codebase/calculator.py | 9 + tests/test_codebase/db/base_repository.py | 58 + tests/test_codebase/db/order_repository.py | 62 + tests/test_codebase/db/user_repository.py | 65 + tests/test_codebase/main.py | 1 + tests/test_codebase/math_utils.py | 7 + tests/test_codebase/utils/crypto.py | 32 + tests/test_codebase/utils/validators.py | 30 + tests/test_enricher.py | 153 + tests/test_integration_community.py | 701 ++++ tests/test_integration_merkle.py | 219 + tests/test_integration_parser_graph.py | 292 ++ tests/test_integration_protocol_handlers.py | 343 ++ tests/test_integration_query_engine.py | 799 ++++ tests/test_integration_safety.py | 515 +++ tests/test_integration_sandbox.py | 218 + tests/test_integration_vector_store.py | 312 ++ tests/test_models.py | 301 ++ tests/test_parser.py | 226 + tests/test_protocol.py | 195 + tests/test_query.py | 277 ++ tests/test_store.py | 213 + tests/test_update.py | 223 + 131 files changed, 23265 insertions(+) create mode 100644 .env.example create mode 100644 .github/workflows/ci.yml create mode 100644 .gitignore create mode 100644 AGENTS.md create mode 100644 Dockerfile create mode 100644 ERROR create mode 100644 LICENSE create mode 100644 README.md create mode 100644 docker-compose.yml create mode 100644 pyproject.toml create mode 100644 smp (3).md create mode 100644 smp.md create mode 100644 smp/__init__.py create mode 100644 smp/agent.py create mode 100644 smp/cli.py create mode 100644 smp/client.py create mode 100644 smp/core/__init__.py create mode 100644 smp/core/background.py create mode 100644 smp/core/merkle.py create mode 100644 smp/core/models.py create mode 100644 smp/engine/__init__.py create mode 100644 smp/engine/community.py create mode 100644 smp/engine/embedding.py create mode 100644 smp/engine/enricher.py create mode 100644 smp/engine/graph_builder.py create mode 100644 smp/engine/handoff.py create mode 100644 smp/engine/integrity.py create mode 100644 smp/engine/interfaces.py create mode 100644 smp/engine/linker.py create mode 100644 smp/engine/notification.py create mode 100644 smp/engine/pagerank.py create mode 100644 smp/engine/query.py create mode 100644 smp/engine/runtime_linker.py create mode 100644 smp/engine/safety.py create mode 100644 smp/engine/seed_walk.py create mode 100644 smp/engine/telemetry.py create mode 100644 smp/logging.py create mode 100644 smp/parser/__init__.py create mode 100644 smp/parser/base.py create mode 100644 smp/parser/python_parser.py create mode 100644 smp/parser/registry.py create mode 100644 smp/parser/typescript_parser.py create mode 100644 smp/protocol/__init__.py create mode 100644 smp/protocol/dispatcher.py create mode 100644 smp/protocol/handlers/__init__.py create mode 100644 smp/protocol/handlers/annotation.py create mode 100644 smp/protocol/handlers/base.py create mode 100644 smp/protocol/handlers/community.py create mode 100644 smp/protocol/handlers/enrichment.py create mode 100644 
smp/protocol/handlers/handoff.py create mode 100644 smp/protocol/handlers/memory.py create mode 100644 smp/protocol/handlers/merkle.py create mode 100644 smp/protocol/handlers/query.py create mode 100644 smp/protocol/handlers/query_ext.py create mode 100644 smp/protocol/handlers/safety.py create mode 100644 smp/protocol/handlers/sandbox.py create mode 100644 smp/protocol/handlers/telemetry.py create mode 100644 smp/protocol/router.py create mode 100644 smp/protocol/server.py create mode 100644 smp/sandbox/__init__.py create mode 100644 smp/sandbox/docker_sandbox.py create mode 100644 smp/sandbox/ebpf_collector.py create mode 100644 smp/sandbox/executor.py create mode 100644 smp/sandbox/spawner.py create mode 100644 smp/store/__init__.py create mode 100644 smp/store/chroma_store.py create mode 100644 smp/store/graph/__init__.py create mode 100644 smp/store/graph/neo4j_store.py create mode 100644 smp/store/interfaces.py create mode 100644 test_codebase/src/auth/manager.py create mode 100644 test_codebase/src/db/user_store.py create mode 100644 test_codebase/tests/test_auth.py create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/fixtures/sample_project/src/api/__init__.py create mode 100644 tests/fixtures/sample_project/src/api/routes.py create mode 100644 tests/fixtures/sample_project/src/auth/__init__.py create mode 100644 tests/fixtures/sample_project/src/auth/auth_service.py create mode 100644 tests/fixtures/sample_project/src/db/__init__.py create mode 100644 tests/fixtures/sample_project/src/db/models.py create mode 100644 tests/fixtures/sample_project/src/db/orders.py create mode 100644 tests/fixtures/sample_project/tests/test_auth.py create mode 100644 tests/fixtures/sample_project/tests/test_db.py create mode 100644 tests/practical_verification.py create mode 100644 tests/results/practical_phase10_handoff.json create mode 100644 tests/results/practical_phase1_service.json create mode 100644 tests/results/practical_phase2_ingestion.json create mode 100644 tests/results/practical_phase2_query.json create mode 100644 tests/results/practical_phase3_enrichment.json create mode 100644 tests/results/practical_phase3_linker.json create mode 100644 tests/results/practical_phase4_annotation.json create mode 100644 tests/results/practical_phase4_query.json create mode 100644 tests/results/practical_phase5_enrichment.json create mode 100644 tests/results/practical_phase5_memory.json create mode 100644 tests/results/practical_phase6_annotation.json create mode 100644 tests/results/practical_phase6_safety.json create mode 100644 tests/results/practical_phase7_query_ext.json create mode 100644 tests/results/practical_phase8_safety.json create mode 100644 tests/results/practical_phase9_sandbox.json create mode 100644 tests/results/practical_summary.json create mode 100644 tests/results/summary.json create mode 100644 tests/test_client.py create mode 100644 tests/test_codebase/__init__.py create mode 100644 tests/test_codebase/api/middleware.py create mode 100644 tests/test_codebase/api/routes.py create mode 100644 tests/test_codebase/auth/jwt_utils.py create mode 100644 tests/test_codebase/auth/session_handler.py create mode 100644 tests/test_codebase/auth/user_manager.py create mode 100644 tests/test_codebase/calculator.py create mode 100644 tests/test_codebase/db/base_repository.py create mode 100644 tests/test_codebase/db/order_repository.py create mode 100644 tests/test_codebase/db/user_repository.py create mode 100644 tests/test_codebase/main.py 
create mode 100644 tests/test_codebase/math_utils.py create mode 100644 tests/test_codebase/utils/crypto.py create mode 100644 tests/test_codebase/utils/validators.py create mode 100644 tests/test_enricher.py create mode 100644 tests/test_integration_community.py create mode 100644 tests/test_integration_merkle.py create mode 100644 tests/test_integration_parser_graph.py create mode 100644 tests/test_integration_protocol_handlers.py create mode 100644 tests/test_integration_query_engine.py create mode 100644 tests/test_integration_safety.py create mode 100644 tests/test_integration_sandbox.py create mode 100644 tests/test_integration_vector_store.py create mode 100644 tests/test_models.py create mode 100644 tests/test_parser.py create mode 100644 tests/test_protocol.py create mode 100644 tests/test_query.py create mode 100644 tests/test_store.py create mode 100644 tests/test_update.py diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..ce3b91a --- /dev/null +++ b/.env.example @@ -0,0 +1,5 @@ +# Database Configuration +SMP_NEO4J_URI=bolt://localhost:7687 +SMP_NEO4J_USER=neo4j +SMP_NEO4J_PASSWORD=your_secure_password_here + diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..fa68c43 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,37 @@ +name: CI + +on: + push: + branches: [main, develop] + pull_request: + branches: [main, develop] + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python 3.11 + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Install uv + run: pip install uv + + - name: Install dependencies + run: uv pip install -e ".[dev]" + + - name: Check formatting with ruff + run: ruff format --check . + + - name: Lint with ruff + run: ruff check . + + - name: Type check with mypy + run: mypy smp/ + + - name: Run tests + run: pytest tests/ -x --tb=short \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..850343d --- /dev/null +++ b/.gitignore @@ -0,0 +1,50 @@ +# Byte-compiled / optimized +__pycache__/ +*.py[cod] +*$py.class + +# Testing +.pytest_cache/ +.coverage +htmlcov/ + +# Packaging +*.egg-info/ +dist/ +build/ +.eggs/ +*.egg + +# Environment & secrets — NEVER commit +.env +.env.* +!.env.example + +# Virtual environments +.venv/ +venv/ + +# IDE +.idea/ +.vscode/ +*.swp +*.swo + +# OS +.DS_Store +Thumbs.db + +# Neo4j data +neo4j-data/ + +# ChromaDB data (removed — V2 uses Neo4j full-text only) +chroma-data/ + +# Type stubs +*.pyi + +# Logs +*.log + +# MyPy cache +.mypy_cache/ diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..6428ed5 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,156 @@ +# AGENTS.md — Coding Agent Instructions for SMP + +## Project Overview + +SMP (Structural Memory Protocol) is a graph-based codebase intelligence system for AI agents. +It parses source code into a knowledge graph (Neo4j) with vector embeddings (ChromaDB), +exposing a JSON-RPC API via FastAPI. + +**Stack:** Python 3.11+, FastAPI, msgspec, tree-sitter, Neo4j, ChromaDB, pytest. + +**IMPORTANT:** This project requires **Python 3.11** explicitly. Always use `python3.11` or ensure your virtual environment is created with Python 3.11. The project uses 3.11+ features (`X | Y` unions, `tomllib`, etc.) and ruff/mypy are configured with `target-version = "py311"`. 
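+
+As a quick interpreter check, the snippet below exercises two of the 3.11-only features mentioned above (`tomllib` and PEP 604 `X | Y` unions). It is a verification sketch, not part of the codebase:
+
+```python
+# Requires Python 3.11+: tomllib entered the stdlib in 3.11.
+import tomllib
+
+
+def read_project_name(pyproject_path: str) -> str | None:  # PEP 604 union syntax
+    """Return the project name from pyproject.toml, or None if missing."""
+    with open(pyproject_path, "rb") as f:
+        data = tomllib.load(f)
+    return data.get("project", {}).get("name")
+```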
+
+---
+
+## Build / Lint / Test Commands
+
+```bash
+# Create venv with Python 3.11
+python3.11 -m venv .venv && source .venv/bin/activate
+
+# Install in dev mode
+pip install -e ".[dev]"
+
+# Lint (check + format)
+ruff check .
+ruff format .
+
+# Type check
+mypy smp/
+
+# Run all tests
+pytest
+
+# Run a single test file
+pytest tests/test_models.py
+
+# Run a single test class
+pytest tests/test_models.py::TestGraphNode
+
+# Run a single test method
+pytest tests/test_models.py::TestGraphNode::test_defaults
+
+# Run the server
+python3.11 -m smp.cli serve
+
+# Ingest a directory
+python3.11 -m smp.cli ingest <path>
+
+# Run a command in background (returns immediately, agent continues working)
+# IMPORTANT: Use full path to venv python3.11 so it can find the smp module
+python3.11 -m smp.cli run <name> -- .venv/bin/python -m <module> [args...]
+# Example: start server in background
+python3.11 -m smp.cli run myserver -- .venv/bin/python -m smp.cli serve --port 8420
+
+# List running background processes
+python3.11 -m smp.cli ps
+
+# View logs for a background process
+python3.11 -m smp.cli logs <name>
+
+# Stop a background process
+python3.11 -m smp.cli stop <name>
+
+# Restart a background process (use --restart flag when running)
+python3.11 -m smp.cli run <name> -- .venv/bin/python -m <module> --restart
+python3.11 -m smp.cli run <name> -- <command> --restart
+```
+
+Always run `ruff check .`, `ruff format .`, and `mypy smp/` after making changes.
+
+---
+
+## Code Style
+
+### Imports
+
+- Every file starts with `from __future__ import annotations`.
+- Group imports: stdlib → third-party → local, separated by blank lines.
+- Use absolute imports for local modules: `from smp.core.models import GraphNode`.
+- Prefer `from X import Y` over `import X` for specific names.
+- Use `list[...]`, `dict[...]`, `set[...]` (builtin generics), never `List`, `Dict`, `Set`.
+- Use `X | Y` for unions (Python 3.11+), not `Optional[X]` or `Union[X, Y]`.
+
+### Formatting
+
+- Line length: 120 characters max.
+- Formatting enforced by `ruff format` — no manual alignment or custom spacing.
+
+### Naming
+
+- **Modules:** `snake_case.py`
+- **Classes:** `PascalCase` (e.g., `GraphNode`, `Neo4jGraphStore`)
+- **Functions/methods:** `snake_case`
+- **Constants:** `_UPPER_SNAKE_CASE` (leading underscore for module-private)
+- **Private members:** `_leading_underscore`
+- **Enums:** `PascalCase` class, `UPPER_SNAKE_CASE` members
+
+### Type Annotations
+
+- All function signatures must have full type annotations (enforced by mypy strict mode).
+- `self` and `cls` annotations are not required (ANN101/ANN102 ignored).
+- Avoid `Any` in return types (ANN401 is ignored, but reserve `Any` for cases where it is unavoidable).
+
+### Docstrings
+
+- Use triple double-quotes, imperative mood, Google style.
+- Omit docstrings for trivial/private methods; prefer clear naming instead.
+
+### Logging
+
+- Use structured logging: `log = get_logger(__name__)`.
+- Log with keyword context: `log.info("event_name", key=value)`.
+- Never use f-strings in log calls; pass structured fields instead.
+
+---
+
+## Architecture & Patterns
+
+- **Layered:** `core` (models) → `engine` (logic) → `protocol` (API) → `store` (persistence).
+- **Interfaces:** Abstract base classes in `interfaces.py` files using `abc.ABC` + `@abc.abstractmethod`.
+- **Models:** Use `msgspec.Struct`; prefer `frozen=True` for immutable data.
+- **Dependency injection:** Pass dependencies via constructor parameters.
+- **Factories:** Use `create_app()` style factory functions for complex setup.
+- **Registry:** Use `ParserRegistry` pattern for lazy/dispatched initialization. +- **Async:** Use `async`/`await` throughout; async context managers via `__aenter__`/`__aexit__`. + +--- + +## Testing + +- Framework: **pytest** with **pytest-asyncio** (`asyncio_mode = "auto"`). +- Tests live in `tests/`, mirroring the source structure. +- Use fixtures from `conftest.py`: `clean_graph`, `vector_store`, `make_node()`, `make_edge()`. +- Group related tests in classes (e.g., `class TestGraphNode:`). +- Async tests: just define as `async def test_...` — the `@pytest.mark.asyncio` decorator is auto-applied. +- Prefer factory helpers over inline test data construction. + +--- + +## Lint Rules (ruff) + +Enabled rule sets: `E` (pycodestyle), `F` (pyflakes), `I` (isort), `N` (naming), +`UP` (pyupgrade), `B` (bugbear), `SIM` (simplify), `ANN` (annotations). + +Ignored: `ANN101`, `ANN102`, `ANN401`. + +--- + +## Pre-Commit Checklist + +Before submitting changes: + +1. `ruff check .` — no lint errors +2. `ruff format .` — code is formatted +3. `mypy smp/` — no type errors +4. `pytest` — all tests pass diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..5f1f1d7 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,17 @@ +FROM python:3.11-slim + +WORKDIR /app + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + curl \ + && rm -rf /var/lib/apt/lists/* + +COPY pyproject.toml ./ +COPY smp/ ./smp/ + +RUN pip install --no-cache-dir -e ".[dev]" + +EXPOSE 8420 + +CMD ["python3.11", "-m", "smp.cli", "serve", "--host", "0.0.0.0", "--port", "8420"] \ No newline at end of file diff --git a/ERROR b/ERROR new file mode 100644 index 0000000..e69de29 diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..d81e0b7 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 SMP Team + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..f025172 --- /dev/null +++ b/README.md @@ -0,0 +1,241 @@ +# Structural Memory Protocol (SMP) + +**High-Fidelity Codebase Intelligence for AI Agents** + +Structural Memory Protocol (SMP) is a graph-based memory system that provides AI agents with a deep, structured understanding of complex codebases. Unlike RAG which treats code as flat text, SMP models code as a multi-dimensional graph of entities, relationships, and semantic meanings. 
+ +Built with **Python 3.11**, **FastAPI**, and **Neo4j**, SMP enables agents to perform precise code navigation, impact analysis, and safe refactoring — using static analysis (no LLM required). + +--- + +## Quickstart (Docker Compose) + +The fastest way to get SMP running: + +```bash +# Clone the repository +git clone https://github.com/your-org/smp.git +cd smp + +# Copy and configure environment +cp .env.example .env +# Edit .env with your Neo4j password + +# Start all services +docker compose up -d + +# Verify health +curl http://localhost:8420/health +# Returns: {"status":"ok"} +``` + +--- + +## Quickstart (Manual) + +### 1. Requirements +- **Python 3.11+** +- **Neo4j 5.x** (Local or AuraDB) + +### 2. Environment +```bash +# Copy the example and configure +cp .env.example .env + +# Edit .env with your credentials: +# SMP_NEO4J_PASSWORD=your_neo4j_password +``` + +### 3. Install & Run +```bash +# Clone and enter the repo +git clone https://github.com/offx-zinth/smp.git +cd smp + +# Create venv with Python 3.11 +python3.11 -m venv .venv +source .venv/bin/activate +pip install -e ".[dev]" + +# Start the server +smp serve +``` + +--- + +## Architecture: Manual Efficient Method (SMP V2) + +SMP V2 is designed for production-grade efficiency. It relies on **static AST extraction** and **Neo4j full-text indexing** — no LLM or vector embeddings required. + +- **Parser**: Tree-sitter extracts functions, classes, imports, and docstrings directly from AST. +- **Enricher**: Extracts docstrings, decorators, and type annotations statically. +- **Linker**: Namespaced cross-file resolution for CALLS edges. +- **Query Engine**: Neo4j full-text index (BM25) for keyword search. +- **Safety Protocol**: Session management, dry-runs, and isolated sandbox execution. + +--- + +## Demo: JSON-RPC Query + +Ingest a codebase and query it: + +```bash +# Ingest a project +smp ingest /path/to/your/project + +# Query via JSON-RPC +curl -X POST http://localhost:8420/rpc \ + -H "Content-Type: application/json" \ + -d '{ + "jsonrpc": "2.0", + "method": "smp/context", + "params": { + "file_path": "smp/core/models.py", + "scope": "edit", + "depth": 2 + }, + "id": 1 + }' +``` + +**Response:** +```json +{ + "jsonrpc": "2.0", + "result": { + "self": { + "id": "smp/core/models.py::GraphNode", + "type": "Class", + "name": "GraphNode", + "signature": "class GraphNode", + "start_line": 130, + "end_line": 220 + }, + "neighbors": [ + { + "id": "smp/core/models.py::StructuralProperties", + "type": "Class", + "relationship": "CONTAINS" + }, + { + "id": "smp/core/models.py::SemanticProperties", + "type": "Class", + "relationship": "CONTAINS" + } + ], + "context": { + "file": "smp/core/models.py", + "imports": ["msgspec", "typing"], + "defines": ["GraphNode", "GraphEdge", "NodeType", "EdgeType"] + } + }, + "id": 1 +} +``` + +--- + +## Key Capabilities + +* **Graph-Augmented Retrieval:** Navigate via `CALLS`, `INHERITS`, `IMPORTS` relationships +* **Semantic Search:** Neo4j full-text index (BM25) for keyword search across docstrings/tags +* **Static Enrichment:** Docstrings, decorators, and type annotations extracted from AST +* **Impact Assessment:** Determine the "blast radius" before changes +* **Safety & Sandboxing:** Session management, dry-runs, isolated execution +* **Multi-Language:** Python and TypeScript/JavaScript via Tree-sitter + +--- + +## Architecture + +``` +smp/ +├── smp/ +│ ├── core/ # Models, logging +│ ├── engine/ # Query, enricher, linker, safety +│ ├── protocol/ # JSON-RPC 2.0 API +│ │ └── handlers/ # Modular 
method handlers +│ ├── store/ # Neo4j (graph + full-text) +│ ├── parser/ # Tree-sitter parsing +│ ├── sandbox/ # Isolated execution +│ ├── cli.py # CLI +│ └── client.py # Python SDK +├── tests/ # Test suite +└── .github/workflows/# CI/CD +``` + +--- + +## Usage + +### Ingest a Project +```bash +smp ingest /path/to/project --clear +``` + +### Run Server +```bash +smp serve --port 8420 --safety +``` + +### Python SDK +```python +import asyncio +from smp.client import SMPClient + +async def main(): + async with SMPClient("http://localhost:8420") as client: + # Semantic search + results = await client.locate("authentication logic") + + # Trace call graph + graph = await client.trace("src/auth.py::login", depth=5) + + # Impact assessment + impact = await client.assess_impact("src/models/user.py::User") + print(f"Affects {impact['total_affected_nodes']} nodes") + +asyncio.run(main()) +``` + +--- + +## Development + +```bash +# Format +ruff format . + +# Lint +ruff check . + +# Type check +mypy smp/ + +# Test +pytest +``` + +--- + +## Troubleshooting + +| Issue | Solution | +|:---|:---| +| `sqlite3` ImportError | Install `pysqlite3-binary` | +| Neo4j Connection | Check `SMP_NEO4J_URI` and credentials in `.env` | +| SyntaxError | Use Python 3.11 | +| Enrichment Timeout | Set `SMP_ENRICHMENT=none` in `.env` | + +--- + +## Contributing + +1. Use `feature/` or `fix/` branches +2. Follow patterns in `AGENTS.md` +3. Add tests for new features +4. Run `ruff check . && ruff format . && mypy smp/ && pytest` + +--- + +*SMP — Empowering agents with structural memory.* \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..b25e0b1 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,52 @@ +services: + neo4j: + image: neo4j:5.23-community + environment: + NEO4J_AUTH: neo4j/${SMP_NEO4J_PASSWORD:-neo4j_secure_password} + NEO4J_dbms_security_procedures_unrestricted: apoc.* + NEO4J_dbms_security_procedures_allowlist: apoc.* + ports: + - "7475:7474" # Host 7475 maps to Container 7474 + - "7688:7687" # Host 7688 maps to Container 7687 + volumes: + - neo4j_data:/data + - neo4j_logs:/logs + healthcheck: + test: ["CMD-SHELL", "cypher-shell -u neo4j -p ${SMP_NEO4J_PASSWORD:-neo4j_secure_password} 'RETURN 1;' || exit 1"] + interval: 10s + timeout: 30s + retries: 5 + + chromadb: + image: chromadb/chroma:latest + ports: + - "8000:8000" + volumes: + - chroma_data:/chroma/chroma + + smp: + build: . 
+ ports: + - "8420:8420" + environment: + SMP_NEO4J_URI: bolt://neo4j:7687 + SMP_NEO4J_USER: neo4j + SMP_NEO4J_PASSWORD: ${SMP_NEO4J_PASSWORD:-neo4j_secure_password} + SMP_ENRICHMENT: ${SMP_ENRICHMENT:-none} + SMP_CHROMA_PERSIST_DIR: /chroma/chroma + depends_on: + neo4j: + condition: service_healthy + chromadb: + condition: service_started + env_file: + - .env.example + volumes: + - .:/app + - smp_data:/root/.smp + +volumes: + neo4j_data: + neo4j_logs: + chroma_data: + smp_data: diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..f8dd13a --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,58 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "smp" +version = "0.1.0" +description = "Structural Memory Protocol — graph-based codebase intelligence for AI agents" +requires-python = ">=3.11" +license = "MIT" +authors = [{ name = "SMP Team" }] + +dependencies = [ + "msgspec>=0.19", + "fastapi>=0.115", + "uvicorn[standard]>=0.34", + "neo4j>=5.0", + "httpx>=0.27", + "tree-sitter>=0.24", + "tree-sitter-python>=0.23", + "tree-sitter-typescript>=0.23", + "python-dotenv>=1.0", + "structlog>=24.0", + "chromadb", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.0", + "pytest-asyncio>=0.24", + "httpx>=0.27", + "ruff>=0.8", + "mypy>=1.13", +] + +[tool.hatch.build.targets.wheel] +packages = ["smp"] + +[tool.pytest.ini_options] +asyncio_mode = "auto" +testpaths = ["tests"] +[tool.ruff] +target-version = "py311" +line-length = 120 + +[tool.ruff.lint] +select = ["E", "F", "I", "N", "UP", "B", "SIM", "ANN"] +ignore = ["ANN101", "ANN102", "ANN401"] + +[tool.mypy] +python_version = "3.11" +strict = true +warn_return_any = true +warn_unused_configs = true +exclude = ["tests/results"] + +[project.scripts] +smp = "smp.cli:main" diff --git a/smp (3).md b/smp (3).md new file mode 100644 index 0000000..190b6c7 --- /dev/null +++ b/smp (3).md @@ -0,0 +1,3623 @@ +# The Structural Memory Protocol (SMP) + +A framework for giving AI agents a "programmer's brain" — not text retrieval, but structural understanding. 
+ +--- + +## Architecture Overview + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ CODEBASE (Files + Git) │ +└──────────────────────────┬──────────────────────────────────────┘ + │ Updates (Watch / Agent Push / commit_sha) + ▼ +┌─────────────────────────────────────────────────────────────────┐ +│ MEMORY SERVER (SMP Core) │ +│ ┌─────────────┐ ┌──────────────┐ ┌─────────────┐ │ +│ │ PARSER │──▶│ GRAPH BUILDER│──▶│ ENRICHER │ │ +│ │ (AST/Tree- │ │ + LINKER │ │ (Static │ │ +│ │ sitter) │ │ (Static + │ │ Metadata) │ │ +│ │ │ │ eBPF Runtime│ │ │ │ +│ └─────────────┘ └──────────────┘ └──────┬──────┘ │ +│ │ │ +│ ┌───────────────────────────────────────────▼──────────────┐ │ +│ │ MEMORY STORE │ │ +│ │ │ │ +│ │ ┌─────────────────────────────────────┐ │ │ +│ │ │ GRAPH DB (Neo4j) │ │ │ +│ │ │ Structure · CALLS_STATIC │ │ │ +│ │ │ CALLS_RUNTIME · PageRank │ │ │ +│ │ │ Sessions · Audit · Telemetry │ │ │ +│ │ │ Full-Text Index (BM25) │ │ │ +│ │ └─────────────────────────────────────┘ │ │ +│ │ │ │ +│ │ ┌─────────────────────────────────────┐ │ │ +│ │ │ VECTOR INDEX (ChromaDB) │ │ │ +│ │ │ code_embedding per node │ │ │ +│ │ │ (signature + docstring, at │ │ │ +│ │ │ index time — no LLM at query time)│ │ │ +│ │ └─────────────────────────────────────┘ │ │ +│ │ │ │ +│ │ ┌─────────────────────────────────────┐ │ │ +│ │ │ MERKLE INDEX │ │ │ +│ │ │ SHA-256 leaf per file node │ │ │ +│ │ │ Package subtree hashes │ │ │ +│ │ │ Root hash = full codebase state │ │ │ +│ │ └─────────────────────────────────────┘ │ │ +│ └──────────────────────────────┬───────────────────────────┘ │ +└─────────────────────────────────┼───────────────────────────────┘ + │ + ┌───────────────────────┼───────────────────────┐ + │ │ │ + ▼ ▼ ▼ +┌─────────────────┐ ┌──────────────────────┐ ┌───────────────┐ +│ QUERY ENGINE │ │ SANDBOX RUNTIME │ │ SWARM LAYER │ +│ Navigator │ │ Ephemeral microVM/ │ │ Peer Review │ +│ Reasoner │ │ Docker + CoW fork │ │ PR Handoff │ +│ SeedWalkEngine │ │ eBPF trace capture │ │ │ +│ Telemetry │ │ Egress-firewalled │ └───────┬───────┘ +└────────┬────────┘ └──────────┬───────────┘ │ + └──────────────┬────────┘ ────────┘ + │ SMP Protocol (Dispatcher) + ▼ + ┌─────────────────────────────────────────────┐ + │ AGENT LAYER │ + │ Agent A Agent B Agent C │ + │ (Coder) (Reviewer) (Architect) │ + └─────────────────────────────────────────────┘ +``` + +--- + +## Part 1: The Memory Server + +### A. Parser (AST Extraction) + +**Technology:** Tree-sitter (multi-language, fast, incremental) + +**Input:** File path + content + +**Output:** Abstract Syntax Tree with typed nodes + +```python +# What gets extracted per file +{ + "file_path": "src/auth/login.ts", + "language": "typescript", + "nodes": [ + { + "id": "func_authenticate_user", + "type": "function_declaration", + "name": "authenticateUser", + "start_line": 15, + "end_line": 42, + "signature": "authenticateUser(email: string, password: string): Promise", + "docstring": "Validates user credentials and returns JWT...", + "modifiers": ["async", "export"] + }, + { + "id": "class_AuthService", + "type": "class_declaration", + "name": "AuthService", + "methods": ["login", "logout", "refresh"], + "properties": ["tokenExpiry", "secretKey"] + } + ], + "imports": [ + {"from": "./utils/crypto", "items": ["hashPassword", "compareHash"]}, + {"from": "../db/user", "items": ["UserModel"]} + ], + "exports": ["authenticateUser", "AuthService"] +} +``` + +--- + +### B. 
Graph Builder (Structural Analysis) + +**Graph Schema:** + +``` +┌─────────────────────────────────────────────────────────────┐ +│ NODE TYPES │ +├─────────────────────────────────────────────────────────────┤ +│ Repository │ Root node │ +│ Package │ Directory/module │ +│ File │ Source file │ +│ Class │ Class definition │ +│ Function │ Function/method │ +│ Variable │ Variable/constant │ +│ Interface │ Type definition/interface │ +│ Test │ Test file/function │ +│ Config │ Configuration file │ +│ Community │ Louvain-detected structural cluster │ +└─────────────────────────────────────────────────────────────┘ + +┌─────────────────────────────────────────────────────────────┐ +│ RELATIONSHIP TYPES │ +├─────────────────────────────────────────────────────────────┤ +│ CONTAINS │ Parent-child (Package → File) │ +│ IMPORTS │ File imports File/Module │ +│ DEFINES │ File defines Class/Function │ +│ CALLS │ Function calls Function (namespaced) │ +│ INHERITS │ Class inherits Class │ +│ IMPLEMENTS │ Class implements Interface │ +│ DEPENDS_ON │ General dependency │ +│ TESTS │ Test tests Function/Class │ +│ USES │ Function uses Variable/Type │ +│ REFERENCES │ Variable references Variable │ +│ MEMBER_OF │ Node belongs to Community │ +│ BRIDGES │ Community connects to Community │ +└─────────────────────────────────────────────────────────────┘ +``` + +**Example Graph Node:** + +```json +{ + "id": "func_authenticate_user", + "type": "Function", + "name": "authenticateUser", + "file": "src/auth/login.ts", + "community_id": "comm_auth_core", + "signature": "authenticateUser(email: string, password: string): Promise", + "metrics": { + "complexity": 4, + "lines": 28, + "parameters": 2 + }, + "relationships": { + "CALLS": ["func_hashPassword", "func_compareHash", "func_generateToken"], + "DEPENDS_ON": ["class_UserModel"], + "DEFINED_IN": "file_auth_login_ts", + "MEMBER_OF": "comm_auth_core" + } +} +``` + +--- + +### B1. The Linker (Namespaced Cross-File Resolution) + +The Linker runs after the Graph Builder and resolves every `CALLS` edge using the file's `imports` list as a namespace map. This prevents ambiguous links when the same function name exists across multiple files. + +**The Problem:** + +``` +File A calls: save() +File B has: save() (src/db/user.ts) +File C has: save() (src/cache/session.ts) +``` + +Without namespacing, the linker guesses. With it, it traces the import to the exact origin file first. + +**Resolution Algorithm:** + +``` +For each CALLS(caller → "save") edge: + 1. Look up caller's IMPORTS list + 2. Find the import entry that exposes "save" + → e.g. import { save } from "../db/user" + 3. Resolve "../db/user" to absolute path → src/db/user.ts + 4. Find node with name="save" AND file="src/db/user.ts" + 5. Draw CALLS edge to that exact node + + If step 2 finds no import for "save": + → Mark edge as CALLS_UNRESOLVED (name="save", reason="not in imports") + → Flag for smp/linker/report +``` + +**Linker State in the Graph DB:** + +Every `CALLS` edge carries a `resolved` flag so agents always know if a dependency is confirmed or ambiguous. 
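+
+A minimal sketch of that resolution pass in Python (the `ImportEntry` shape and the `resolve_path`/`find_node` helpers are illustrative assumptions, not SMP source):
+
+```python
+from dataclasses import dataclass
+
+
+@dataclass
+class ImportEntry:
+    source: str       # e.g. "../db/user"
+    items: list[str]  # e.g. ["save"]
+
+
+def resolve_call(caller_file: str, callee_name: str,
+                 imports: list[ImportEntry], resolve_path, find_node) -> dict:
+    """Resolve a CALLS edge target through the caller's import namespace."""
+    for imp in imports:
+        if callee_name in imp.items:
+            # Trace the import to its origin file, then match name + file exactly.
+            target_file = resolve_path(caller_file, imp.source)
+            node = find_node(name=callee_name, file=target_file)
+            if node is not None:
+                return {"edge": "CALLS", "to": node["id"], "resolved": True,
+                        "import_source": target_file}
+    # No import exposes the name: flag it for smp/linker/report.
+    return {"edge": "CALLS_UNRESOLVED", "to_name": callee_name,
+            "resolved": False, "reason": "not in imports"}
+```
+
+The stored edge records then look like this: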
+
+```json
+{
+  "edge": "CALLS",
+  "from": "func_authenticate_user",
+  "to": "func_hashPassword",
+  "resolved": true,
+  "import_source": "src/utils/crypto.ts"
+}
+```
+
+```json
+{
+  "edge": "CALLS_UNRESOLVED",
+  "from": "func_process_data",
+  "to_name": "save",
+  "resolved": false,
+  "reason": "ambiguous — save exists in 3 files, none imported directly"
+}
+```
+
+**Protocol:**
+
+```json
+// smp/linker/report — list all unresolved edges in the graph
+{
+  "jsonrpc": "2.0",
+  "method": "smp/linker/report",
+  "params": {
+    "scope": "full"    // "full" | "package:<path>" | "file:<path>"
+  },
+  "id": 24
+}
+
+// Response
+{
+  "jsonrpc": "2.0",
+  "result": {
+    "unresolved_count": 4,
+    "unresolved": [
+      {
+        "caller": "func_process_data",
+        "file": "src/pipeline/runner.ts",
+        "to_name": "save",
+        "candidates": [
+          "src/db/user.ts:save",
+          "src/cache/session.ts:save",
+          "src/storage/blob.ts:save"
+        ],
+        "action": "add_import",
+        "action_target": "src/pipeline/runner.ts"
+      }
+    ]
+  },
+  "id": 24
+}
+```
+
+---
+
+### B2. The Runtime Linker (eBPF Execution Traces)
+
+Static linking resolves what the *source says* will be called. The Runtime Linker resolves what *actually runs* — capturing real call chains from inside a sandbox via eBPF, then injecting `CALLS_RUNTIME` edges into the graph.
+
+**Why static linking alone isn't enough:**
+
+```
+// Dependency Injection — static linker sees no CALLS edge here at all
+container.bind("AuthService").to(JwtAuthService);
+
+// Metaprogramming — target function name is a runtime variable
+const method = config.get("handler");
+this[method](payload);
+```
+
+The static linker marks these as `CALLS_UNRESOLVED`. The runtime linker resolves them by actually executing the code path inside a sandboxed environment and capturing the kernel-level syscall trace via eBPF.
+
+**How it works:**
+
+```
+Agent spawns sandbox (smp/sandbox/spawn)
+        │
+        ▼
+Agent runs test suite inside sandbox (smp/sandbox/execute, inject_ebpf: true)
+        │
+        ▼
+eBPF daemon intercepts: every function entry/exit at kernel level
+        │
+        ▼
+SMP Runtime Linker processes trace → resolves targets → injects CALLS_RUNTIME edges
+        │
+        ▼
+Graph DB now has full hybrid call graph:
+  CALLS_STATIC  = "source says this will be called"  (resolved at index time)
+  CALLS_RUNTIME = "kernel confirmed this was called" (resolved at execution time)
+```
+
+**CALLS_RUNTIME edge schema:**
+
+```json
+{
+  "edge": "CALLS_RUNTIME",
+  "from": "func_process_payment",
+  "to": "func_handle_stripe_webhook",
+  "resolved_via": "ebpf_trace",
+  "sandbox_id": "box_99x",
+  "commit_sha": "a1b2c3d4",
+  "call_count": 3,
+  "first_seen": "2025-02-15T10:44:09Z"
+}
+```
+
+**Protocol — query runtime edges specifically:**
+
+```json
+// smp/linker/runtime — get all CALLS_RUNTIME edges for a node
+{
+  "jsonrpc": "2.0",
+  "method": "smp/linker/runtime",
+  "params": {
+    "node_id": "func_process_payment",
+    "commit_sha": "a1b2c3d4"
+  },
+  "id": 25
+}
+
+// Response
+{
+  "jsonrpc": "2.0",
+  "result": {
+    "node_id": "func_process_payment",
+    "runtime_callees": [
+      {
+        "node_id": "func_handle_stripe_webhook",
+        "file": "src/payments/webhook.ts",
+        "call_count": 3,
+        "was_static_unresolved": true
+      }
+    ],
+    "static_only_callees": ["func_validate_amount"],
+    "unresolved_remaining": 0
+  },
+  "id": 25
+}
+```
+
+---
+
+### C. The Enricher (Static Metadata)
+
+**Purpose:** Attach human-readable metadata to structural nodes using only what already exists in the code — docstrings, comments, annotations, and decorators. No LLM. No embeddings. Pure static extraction.
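+
+As a concrete illustration of that extraction pass, here is a minimal sketch for Python sources using only the stdlib `ast` module (illustrative, not the SMP implementation):
+
+```python
+import ast
+
+
+def extract_metadata(source: str) -> list[dict]:
+    """Pull names, docstrings, decorators, and annotations off the AST (no LLM)."""
+    out = []
+    for node in ast.walk(ast.parse(source)):
+        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
+            meta = {
+                "name": node.name,
+                "docstring": ast.get_docstring(node),
+                "decorators": [ast.unparse(d) for d in node.decorator_list],
+            }
+            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
+                # Positional parameter annotations only; a real pass would also
+                # cover keyword-only args, *args/**kwargs, and inline comments.
+                meta["annotations"] = {
+                    a.arg: ast.unparse(a.annotation)
+                    for a in node.args.args
+                    if a.annotation is not None
+                }
+                if node.returns is not None:
+                    meta["returns"] = ast.unparse(node.returns)
+            out.append(meta)
+    return out
+```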
+ +--- + +#### smp/enrich — Extract static metadata from a node + +Reads docstrings, inline comments, decorators, and type annotations directly off the AST. Skips silently if `source_hash` is unchanged since last enrichment. + +```json +// Request +{ + "jsonrpc": "2.0", + "method": "smp/enrich", + "params": { + "node_id": "func_authenticate_user", + "force": false // true = re-enrich even if source_hash unchanged + }, + "id": 10 +} + +// Response — enriched +{ + "jsonrpc": "2.0", + "result": { + "node_id": "func_authenticate_user", + "status": "enriched", // "enriched" | "skipped" | "no_metadata" + "docstring": "Validates user credentials and returns a signed JWT for the session.", + "inline_comments": [ + {"line": 18, "text": "compare against bcrypt hash, not plaintext"}, + {"line": 31, "text": "token expiry pulled from env config"} + ], + "decorators": ["@requires_db", "@rate_limited"], + "annotations": { + "params": {"email": "string", "password": "string"}, + "returns": "Promise", + "throws": ["AuthenticationError", "DatabaseError"] + }, + "tags": [], + "source_hash": "a3f9c12d", + "enriched_at": "2025-02-15T10:30:00Z" + }, + "id": 10 +} + +// Response — already fresh, nothing to do +{ + "jsonrpc": "2.0", + "result": { + "node_id": "func_authenticate_user", + "status": "skipped", + "reason": "source_hash unchanged" + }, + "id": 10 +} + +// Response — node has no extractable metadata +{ + "jsonrpc": "2.0", + "result": { + "node_id": "func_xT9_handler", + "status": "no_metadata", + "reason": "no docstring, decorators, or type annotations found" + }, + "id": 10 +} + +// Error — node not found +{ + "jsonrpc": "2.0", + "error": { + "code": -32001, + "message": "Node not found", + "data": {"node_id": "func_authenticate_user"} + }, + "id": 10 +} +``` + +--- + +#### smp/enrich/batch — Enrich a scope at once + +```json +// Request +{ + "jsonrpc": "2.0", + "method": "smp/enrich/batch", + "params": { + "scope": "package:src/auth", // "full" | "package:" | "file:" + "force": false + }, + "id": 11 +} + +// Response +{ + "jsonrpc": "2.0", + "result": { + "enriched": 24, + "skipped": 6, // source_hash unchanged + "no_metadata": 3, // nothing extractable — see node_ids for smp/annotate targets + "failed": 0, + "no_metadata_nodes": [ + "func_xT9_handler", + "func_a1_proc", + "class_TmpHelper" + ] + }, + "id": 11 +} +``` + +--- + +#### smp/enrich/stale — List nodes whose source changed since last enrichment + +Useful before a batch re-enrich — shows exactly what's out of date without running the full enrichment pass. + +```json +// Request +{ + "jsonrpc": "2.0", + "method": "smp/enrich/stale", + "params": { + "scope": "full" // "full" | "package:" | "file:" + }, + "id": 12 +} + +// Response +{ + "jsonrpc": "2.0", + "result": { + "stale_count": 4, + "stale_nodes": [ + { + "node_id": "func_authenticate_user", + "file": "src/auth/login.ts", + "last_enriched": "2025-02-10T08:00:00Z", + "current_hash": "b7d2e91f", + "enriched_hash": "a3f9c12d" + }, + { + "node_id": "class_AuthService", + "file": "src/auth/login.ts", + "last_enriched": "2025-02-10T08:00:00Z", + "current_hash": "c3a1f004", + "enriched_hash": "99de12ab" + } + ] + }, + "id": 12 +} +``` + +--- + +#### smp/annotate — Manually set metadata on a node + +For `no_metadata` nodes that have nothing extractable. Stored and queried identically to auto-enriched fields. 
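+
+On the wire this is an ordinary JSON-RPC 2.0 POST. A minimal sketch with `httpx` against the `/rpc` endpoint shown in the README (a sketch only; the field values mirror the request below):
+
+```python
+import httpx
+
+
+async def annotate(node_id: str, description: str, tags: list[str]) -> dict:
+    payload = {
+        "jsonrpc": "2.0",
+        "method": "smp/annotate",
+        "params": {"node_id": node_id, "description": description, "tags": tags},
+        "id": 13,
+    }
+    async with httpx.AsyncClient() as client:
+        # POST to the JSON-RPC endpoint and surface HTTP-level failures early.
+        resp = await client.post("http://localhost:8420/rpc", json=payload)
+        resp.raise_for_status()
+        return resp.json()
+
+
+# asyncio.run(annotate("func_xT9_handler", "Processes Stripe webhook payload.", ["billing"]))
+```
+
+The raw request and response shapes: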
+ +```json +// Request +{ + "jsonrpc": "2.0", + "method": "smp/annotate", + "params": { + "node_id": "func_xT9_handler", + "description": "Processes Stripe webhook payload and updates subscription status in DB.", + "tags": ["billing", "webhook", "stripe"] + }, + "id": 13 +} + +// Response +{ + "jsonrpc": "2.0", + "result": { + "node_id": "func_xT9_handler", + "status": "annotated", + "manually_set": true, + "annotated_at": "2025-02-15T11:00:00Z" + }, + "id": 13 +} + +// Error — annotation would overwrite a docstring without force flag +{ + "jsonrpc": "2.0", + "error": { + "code": -32002, + "message": "Node already has extracted docstring. Set force: true to override.", + "data": {"node_id": "func_xT9_handler"} + }, + "id": 13 +} +``` + +--- + +#### smp/annotate/bulk — Annotate multiple nodes in one call + +```json +// Request +{ + "jsonrpc": "2.0", + "method": "smp/annotate/bulk", + "params": { + "annotations": [ + { + "node_id": "func_xT9_handler", + "description": "Processes Stripe webhook, updates subscription status.", + "tags": ["billing", "webhook"] + }, + { + "node_id": "func_a1_proc", + "description": "Runs nightly aggregation job for analytics pipeline.", + "tags": ["analytics", "cron"] + } + ] + }, + "id": 14 +} + +// Response +{ + "jsonrpc": "2.0", + "result": { + "annotated": 2, + "failed": 0 + }, + "id": 14 +} +``` + +--- + +#### smp/tag — Bulk-tag nodes by scope + +```json +// Request — add tags +{ + "jsonrpc": "2.0", + "method": "smp/tag", + "params": { + "scope": "package:src/payments", + "tags": ["billing", "stripe", "pci-sensitive"], + "action": "add" // "add" | "remove" | "replace" + }, + "id": 15 +} + +// Response +{ + "jsonrpc": "2.0", + "result": { + "nodes_affected": 31, + "action": "add", + "scope": "package:src/payments" + }, + "id": 15 +} + +// Request — remove a tag that was applied by mistake +{ + "jsonrpc": "2.0", + "method": "smp/tag", + "params": { + "scope": "package:src/payments", + "tags": ["pci-sensitive"], + "action": "remove" + }, + "id": 16 +} +``` + +--- + +#### smp/search — Full-text search across enriched metadata + +BM25-ranked full-text search against docstrings, descriptions, and tags. Backed by a Neo4j Full-Text Index — no table scans, no `CONTAINS` on raw strings. Scales to 100k+ nodes. 
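+
+For reference, the `bm25_score` values returned below follow the standard Okapi BM25 formula. A compact sketch for a single term and document (`k1` and `b` are the usual free parameters, not SMP-specific settings):
+
+```python
+import math
+
+
+def bm25(tf: int, df: int, n_docs: int, doc_len: int, avg_len: float,
+         k1: float = 1.2, b: float = 0.75) -> float:
+    """Okapi BM25 contribution of one query term to one document's score."""
+    # Inverse document frequency: rare terms score higher.
+    idf = math.log(1 + (n_docs - df + 0.5) / (df + 0.5))
+    # Term frequency, saturated by k1 and normalized by document length via b.
+    norm = tf * (k1 + 1) / (tf + k1 * (1 - b + b * doc_len / avg_len))
+    return idf * norm
+```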
+
+**Index configuration (one-time setup at server start):**
+
+```cypher
+// Create the full-text index covering all enrichable node types (Neo4j 5 syntax)
+CREATE FULLTEXT INDEX smp_fulltext
+FOR (n:Function|Class|Interface|Variable)
+ON EACH [n.semantic_docstring, n.semantic_description, n.semantic_tags]
+```
+
+```json
+// Request
+{
+  "jsonrpc": "2.0",
+  "method": "smp/search",
+  "params": {
+    "query": "stripe webhook",
+    "match": "all",    // "all" = AND, "any" = OR across query terms
+    "filter": {
+      "node_types": ["Function", "Class"],
+      "tags": ["billing"],
+      "scope": "package:src/payments"
+    },
+    "top_k": 5
+  },
+  "id": 17
+}
+
+// Response — results ranked by BM25 score (term frequency + inverse doc frequency)
+{
+  "jsonrpc": "2.0",
+  "result": {
+    "matches": [
+      {
+        "node_id": "func_xT9_handler",
+        "node_type": "Function",
+        "file": "src/payments/webhook.ts",
+        "docstring": "Processes Stripe webhook payload and updates subscription status in DB.",
+        "tags": ["billing", "webhook", "stripe"],
+        "matched_on": ["docstring", "tags"],
+        "bm25_score": 4.72
+      },
+      {
+        "node_id": "class_StripeClient",
+        "node_type": "Class",
+        "file": "src/payments/stripe.ts",
+        "docstring": "Thin wrapper around the Stripe SDK for payment operations.",
+        "tags": ["billing", "stripe"],
+        "matched_on": ["docstring", "tags"],
+        "bm25_score": 3.18
+      }
+    ],
+    "total": 2
+  },
+  "id": 17
+}
+
+// Response — no matches
+{
+  "jsonrpc": "2.0",
+  "result": {
+    "matches": [],
+    "total": 0,
+    "searched_fields": ["docstring", "tags"],
+    "scope_node_count": 312
+  },
+  "id": 17
+}
+```
+
+---
+
+#### smp/enrich/status — Enrichment coverage report
+
+```json
+// Request
+{
+  "jsonrpc": "2.0",
+  "method": "smp/enrich/status",
+  "params": {
+    "scope": "full"
+  },
+  "id": 18
+}
+
+// Response
+{
+  "jsonrpc": "2.0",
+  "result": {
+    "total_nodes": 1240,
+    "has_docstring": 834,
+    "has_annotations": 910,
+    "has_tags": 412,
+    "manually_annotated": 17,
+    "no_metadata": 89,
+    "stale": 4,
+    "coverage_pct": 92.8
+  },
+  "id": 18
+}
+```
+
+---
+
+**Enriched Node (final schema):**
+
+```json
+{
+  "id": "func_authenticate_user",
+  "structural": { "...": "..." },
+  "semantic": {
+    "status": "enriched",
+    "docstring": "Validates user credentials and returns a signed JWT for the session.",
+    "description": null,
+    "inline_comments": [
+      {"line": 18, "text": "compare against bcrypt hash, not plaintext"}
+    ],
+    "decorators": ["@requires_db", "@rate_limited"],
+    "annotations": {
+      "params": {"email": "string", "password": "string"},
+      "returns": "Promise",
+      "throws": ["AuthenticationError", "DatabaseError"]
+    },
+    "tags": ["auth", "jwt", "session"],
+    "manually_set": false,
+    "source_hash": "a3f9c12d",
+    "enriched_at": "2025-02-15T10:30:00Z"
+  },
+  "vector": {
+    "code_embedding": [0.021, -0.134, 0.087, "..."],
+    "embedding_input": "func authenticateUser(email: string, password: string): Promise — Validates user credentials and returns a signed JWT for the session.",
+    "model": "text-embedding-3-small",
+    "indexed_at": "2025-02-15T10:30:01Z"
+  }
+}
+```
+
+> **Embedding policy:** `code_embedding` is generated **once at index time** from `signature + docstring`. It is stored in ChromaDB keyed by `node_id`. At query time (`smp/locate`), ChromaDB is called for **seed discovery only** — the actual retrieval, ranking, and response assembly are pure graph + arithmetic. No generative LLM is involved at any point.
+
+---
+
+### D.
Community Detection + +**Purpose:** Automatically partition the codebase graph into structural clusters at **two levels** — coarse (architecture overview) and fine (search routing) — so agents can reason about domain boundaries and `smp/locate` Phase 0 narrows seed search to ~200 nodes instead of all 100k. + +**Two-level hierarchy (mirrors GraphRAG):** + +``` +Level 0 — COARSE (global architecture view) + e.g. "backend_core", "api_gateway", "data_layer" + → Used by architecture agents to understand module ownership. + → smp/community/boundaries shows coupling strength between these. + +Level 1 — FINE (search routing) + e.g. "auth_core", "auth_oauth", "payments_stripe", "payments_refunds" + → Subdivisions of coarse communities. + → Used by smp/locate Phase 0 to scope seed search to ~200 nodes. + → Every node carries both community_id_l0 and community_id_l1. +``` + +**How it works — purely topological, no LLM:** + +``` +1. Run Louvain at two resolutions via Neo4j GDS: + resolution=0.5 → fewer, larger communities (Level 0 / coarse) + resolution=1.5 → more, smaller communities (Level 1 / fine) + +2. For each community at each level, derive label from topology: + → majority_path_prefix: most common src/ subdirectory among members + → top_tags: most frequent semantic tags across enriched members + → centroid_embedding: mean of all member code_embeddings (ChromaDB) + — used for community-level vector routing in smp/locate Phase 0 + +3. Write community_id_l0 + community_id_l1 onto every node as properties. + Create Community nodes at both levels, link fine → coarse via CHILD_OF. + Detect cross-community edges → write BRIDGES with coupling_weight. +``` + +**Community Node schema:** + +```json +{ + "id": "comm_auth_core", + "type": "Community", + "level": 1, + "parent_community": "comm_backend_core", + "label": "auth", + "majority_path_prefix": "src/auth", + "top_tags": ["auth", "jwt", "session", "credentials"], + "member_count": 47, + "file_count": 6, + "internal_edge_count": 183, + "external_edge_count": 12, + "modularity_score": 0.74, + "centroid_embedding_id": "centroid_comm_auth_core", + "detected_at": "2025-02-15T10:00:00Z" +} +``` + +**Protocol:** + +```json +// smp/community/detect — Run Louvain at two resolutions, write community_id_l0 +// and community_id_l1 to all nodes. Triggered at index time and when smp/sync +// detects structural changes affecting >10% of nodes. 
+{ + "jsonrpc": "2.0", + "method": "smp/community/detect", + "params": { + "algorithm": "louvain", + "relationship_types": ["CALLS_STATIC", "CALLS_RUNTIME", "IMPORTS"], + "levels": [ + {"level": 0, "resolution": 0.5, "label": "coarse"}, + {"level": 1, "resolution": 1.5, "label": "fine"} + ], + "min_community_size": 5 + }, + "id": 19 +} + +// Response +{ + "jsonrpc": "2.0", + "result": { + "nodes_assigned": 1240, + "bridge_edges": 38, + "levels": { + "0": {"communities_found": 5, "modularity": 0.61}, + "1": {"communities_found": 14, "modularity": 0.74} + }, + "coarse_communities": [ + {"id": "comm_backend_core", "label": "backend_core", "member_count": 320, "fine_children": 4}, + {"id": "comm_data_layer", "label": "data_layer", "member_count": 280, "fine_children": 3}, + {"id": "comm_api_gateway", "label": "api_gateway", "member_count": 410, "fine_children": 5} + ], + "fine_communities": [ + {"id": "comm_auth_core", "parent": "comm_backend_core", "label": "auth", "member_count": 47}, + {"id": "comm_payments", "parent": "comm_backend_core", "label": "payments", "member_count": 83}, + {"id": "comm_db_models", "parent": "comm_data_layer", "label": "db", "member_count": 61}, + {"id": "comm_api_layer", "parent": "comm_api_gateway", "label": "api", "member_count": 112}, + {"id": "comm_notifications", "parent": "comm_backend_core", "label": "notifications","member_count": 29} + ] + }, + "id": 19 +} +``` + +```json +// smp/community/list — List all communities at a given level +{ + "jsonrpc": "2.0", + "method": "smp/community/list", + "params": { + "level": 1 // 0 = coarse, 1 = fine, omit = both levels + } +} + +// Response +{ + "jsonrpc": "2.0", + "result": { + "total": 14, + "communities": [ + { + "id": "comm_auth_core", + "level": 1, + "parent_community": "comm_backend_core", + "label": "auth", + "majority_path_prefix": "src/auth", + "top_tags": ["auth", "jwt", "session"], + "member_count": 47, + "file_count": 6, + "internal_edge_count": 183, + "external_edge_count": 12, + "modularity_score": 0.74, + "bridge_communities": ["comm_db_models", "comm_api_layer"] + } + ] + } +} +``` + +```json +// smp/community/get — Get all nodes in a specific community +{ + "jsonrpc": "2.0", + "method": "smp/community/get", + "params": { + "community_id": "comm_auth_core", + "node_types": ["Function", "Class"], + "include_bridges": true + }, + "id": 20 +} + +// Response +{ + "jsonrpc": "2.0", + "result": { + "community_id": "comm_auth_core", + "level": 1, + "parent_community": "comm_backend_core", + "label": "auth", + "member_count": 47, + "members": [ + { + "id": "func_authenticate_user", + "type": "Function", + "name": "authenticateUser", + "file": "src/auth/login.ts", + "pagerank": 0.042, + "heat_score": 96 + } + ], + "bridge_edges": [ + { + "from": "func_authenticate_user", + "to": "class_UserModel", + "edge_type": "CALLS_STATIC", + "to_community": "comm_db_models", + "coupling_weight": 0.31 + } + ] + }, + "id": 20 +} +``` + +```json +// smp/community/boundaries — Coupling strength between all community pairs. +// Architecture agents use this to understand which domains are tightly coupled +// and identify the exact bridge nodes responsible for cross-domain dependencies. 
+{ + "jsonrpc": "2.0", + "method": "smp/community/boundaries", + "params": { + "level": 0, // 0 = coarse module boundaries, 1 = fine boundaries + "min_coupling": 0.05 // omit pairs below this weight + }, + "id": 21 +} + +// Response +{ + "jsonrpc": "2.0", + "result": { + "level": 0, + "boundaries": [ + { + "from_community": "comm_backend_core", + "to_community": "comm_data_layer", + "edge_count": 83, + "coupling_weight": 0.61, + "bridge_nodes": [ + {"id": "class_UserModel", "type": "Class", "side": "data_layer", "in_degree_from_peer": 12}, + {"id": "class_OrderModel", "type": "Class", "side": "data_layer", "in_degree_from_peer": 9}, + {"id": "func_authenticate_user", "type": "Function", "side": "backend_core", "out_degree_to_peer": 7} + ] + }, + { + "from_community": "comm_backend_core", + "to_community": "comm_api_gateway", + "edge_count": 47, + "coupling_weight": 0.38, + "bridge_nodes": [ + {"id": "class_AuthService", "type": "Class", "side": "backend_core", "out_degree_to_peer": 14} + ] + }, + { + "from_community": "comm_data_layer", + "to_community": "comm_api_gateway", + "edge_count": 11, + "coupling_weight": 0.09, + "bridge_nodes": [ + {"id": "func_serialize_response", "type": "Function", "side": "api_gateway", "in_degree_from_peer": 11} + ] + } + ] + }, + "id": 21 +} +``` + +--- + +## Part 2: The Query Engine + +### Query Types + +| Type | Purpose | Example | +|------|---------|---------| +| **Navigate** | Find specific entities | "Where is `login` defined?" | +| **Trace** | Follow relationships | "What calls `authenticateUser`?" | +| **Context** | Get relevant context | "I'm editing `auth.ts`, what do I need to know?" | +| **Impact** | Assess change impact | "If I delete this, what breaks?" | +| **Locate** | Find by description | "Where is user registration handled?" | +| **Flow** | Trace data/logic path | "How does a request become a DB entry?" | + +--- + +### Query Engine Implementation + +```python +# smp/engine/query.py +import msgspec +from typing import Sequence +from neo4j import AsyncSession +import chromadb + + +# ── Data Models (msgspec.Struct — zero-copy, schema-validated) ────────────── + +class SeedNode(msgspec.Struct, frozen=True): + node_id: str + node_type: str + name: str + file: str + signature: str + docstring: str | None + tags: list[str] + community_id: str | None # which community this node belongs to + vector_score: float + pagerank: float + heat_score: int + +class WalkNode(msgspec.Struct, frozen=True): + node_id: str + node_type: str + name: str + file: str + signature: str + docstring: str | None + community_id: str | None + edge_type: str + edge_direction: str + hop: int + is_bridge: bool # True if this edge crosses community boundaries + pagerank: float + heat_score: int + +class RankedResult(msgspec.Struct, frozen=True): + node_id: str + node_type: str + name: str + file: str + signature: str + docstring: str | None + tags: list[str] + community_id: str | None + final_score: float + vector_score: float + pagerank: float + heat_score: int + is_seed: bool + reachable_from: list[str] + +class LocateResponse(msgspec.Struct, frozen=True): + query: str + routed_community: str | None # community routing hit — None if cross-community query + seed_count: int + total_walked: int + results: list[RankedResult] + structural_map: list[dict] + + +# ── Seed & Walk Engine ─────────────────────────────────────────────────────── + +class SeedWalkEngine: + """ + Implements the Community-Routed Graph RAG pipeline for smp/locate. 
+
+    Phase 0 — ROUTE:    Compare query embedding against Level-1 (fine) community centroid
+                        embeddings stored in ChromaDB. Routes to the best-matching fine
+                        community (scoped to ~200 nodes). Low confidence → global search.
+    Phase 1 — SEED:     ChromaDB vector search scoped to routed fine community (or global).
+                        → Top-K nodes whose code_embedding is closest to query.
+    Phase 2 — WALK:     Single Cypher N-hop traversal from seeds.
+                        Follows CALLS_STATIC | CALLS_RUNTIME | IMPORTS | DEFINES.
+                        Crosses community boundaries via BRIDGES edges.
+    Phase 3 — RANK:     Composite score = α·vector + β·pagerank + γ·heat.
+    Phase 4 — ASSEMBLE: Deduplicated RankedResult list + structural_map with community labels.
+
+    No LLM calls at any phase.
+    """
+
+    ALPHA = 0.50
+    BETA = 0.30
+    GAMMA = 0.20
+    ROUTE_CONFIDENCE_THRESHOLD = 0.65  # below this → skip routing, search globally
+
+    def __init__(
+        self,
+        neo4j_session: AsyncSession,
+        chroma_collection: chromadb.Collection,
+        centroid_collection: chromadb.Collection,
+    ):
+        self._graph = neo4j_session
+        self._chroma = chroma_collection        # per-node code_embedding collection
+        self._centroids = centroid_collection   # community centroid embeddings
+
+    # ── Phase 0: Community Routing ────────────────────────────────────────────
+
+    async def _route_to_community(self, query: str) -> tuple[str | None, float]:
+        """
+        Compare the query embedding against stored community centroid embeddings.
+        Returns (community_id, confidence) if a strong match is found.
+        Returns (None, 0.0) if no community clears the threshold — fallback to global search.
+
+        Centroid embeddings live in their own ChromaDB collection ('centroids'),
+        keyed by community_id. Computed at smp/community/detect time, not per-query.
+        """
+        centroids = self._centroids.query(
+            query_texts=[query],
+            n_results=1,
+            include=["metadatas", "distances"]
+        )
+        if not centroids["metadatas"][0]:
+            return None, 0.0
+
+        confidence = 1.0 - centroids["distances"][0][0]  # distance → similarity
+        if confidence < self.ROUTE_CONFIDENCE_THRESHOLD:
+            return None, confidence  # query spans multiple communities — search globally
+
+        community_id = centroids["metadatas"][0][0]["community_id"]
+        return community_id, confidence
+
+    # ── Phase 1: Seed ─────────────────────────────────────────────────────────
+
+    async def _seed(self, query: str, seed_k: int, community_id: str | None) -> list[SeedNode]:
+        """
+        Vector search scoped to community_id when routing hit.
+        Falls back to global search when community_id is None.
+        """
+        where_filter = {"community_id": community_id} if community_id else None
+        results = self._chroma.query(
+            query_texts=[query],
+            n_results=seed_k,
+            where=where_filter,
+            include=["metadatas", "distances"]
+        )
+        seeds = []
+        for meta, dist in zip(results["metadatas"][0], results["distances"][0]):
+            seeds.append(SeedNode(
+                node_id      = meta["node_id"],
+                node_type    = meta["node_type"],
+                name         = meta["name"],
+                file         = meta["file"],
+                signature    = meta["signature"],
+                docstring    = meta.get("docstring"),
+                tags         = meta.get("tags", []),
+                community_id = meta.get("community_id"),
+                vector_score = 1.0 - dist,  # ChromaDB distance → 0–1 similarity (cosine space)
+                pagerank     = meta["pagerank"],
+                heat_score   = meta["heat_score"],
+            ))
+        return seeds
+
+    # ── Phase 2: Walk ─────────────────────────────────────────────────────────
+
+    async def _walk(self, seed_ids: list[str], hops: int) -> tuple[list[WalkNode], list[dict]]:
+        """
+        Single Cypher query — no N+1.
+        Traverses CALLS_STATIC, CALLS_RUNTIME, IMPORTS, and DEFINES edges
+        (Senthil Global Linker edges) up to `hops` depth from each seed.
+        Returns (walked_nodes, raw_edges); the raw edges feed the structural map.
+        """
+        # Neo4j cannot parameterize variable-length pattern bounds, so the
+        # (integer-validated) hop count is interpolated into the pattern.
+        cypher = f"""
+        UNWIND $seed_ids AS seed_id
+        MATCH (seed {{id: seed_id}})
+        MATCH path = (seed)-[:CALLS_STATIC|CALLS_RUNTIME|IMPORTS|DEFINES*1..{int(hops)}]-(node)
+        WHERE node.id <> seed.id
+        WITH seed, node,
+             relationships(path) AS rels,
+             last(relationships(path)) AS last_rel,
+             size(relationships(path)) AS hop_count
+        RETURN
+            node.id AS node_id,
+            node.type AS node_type,
+            node.name AS name,
+            node.file AS file,
+            node.signature AS signature,
+            node.docstring AS docstring,
+            node.community_id AS community_id,
+            type(last_rel) AS edge_type,
+            CASE WHEN startNode(last_rel).id = node.id THEN 'out' ELSE 'in' END AS edge_direction,
+            any(rel IN rels WHERE startNode(rel).community_id <> endNode(rel).community_id) AS is_bridge,
+            hop_count,
+            node.pagerank AS pagerank,
+            node.heat_score AS heat_score,
+            CASE WHEN startNode(last_rel).id = node.id
+                 THEN endNode(last_rel).id
+                 ELSE startNode(last_rel).id END AS prev_id
+        """
+        result = await self._graph.run(cypher, seed_ids=seed_ids)
+        walked: dict[str, WalkNode] = {}
+        edges: list[dict] = []
+        async for r in result:  # AsyncResult must be consumed with async iteration
+            edges.append({
+                "from": r["prev_id"],
+                "to": r["node_id"],
+                "edge_type": r["edge_type"],
+                "hop": r["hop_count"],
+            })
+            if r["node_id"] not in walked:
+                walked[r["node_id"]] = WalkNode(
+                    node_id        = r["node_id"],
+                    node_type      = r["node_type"],
+                    name           = r["name"],
+                    file           = r["file"],
+                    signature      = r["signature"],
+                    docstring      = r["docstring"],
+                    community_id   = r["community_id"],
+                    edge_type      = r["edge_type"],
+                    edge_direction = r["edge_direction"],  # relative to the walked node
+                    hop            = r["hop_count"],
+                    is_bridge      = bool(r["is_bridge"]),
+                    pagerank       = r["pagerank"] or 0.0,
+                    heat_score     = r["heat_score"] or 0,
+                )
+        return list(walked.values()), edges
+
+    # ── Phase 3: Rank ─────────────────────────────────────────────────────────
+
+    def _rank(
+        self,
+        seeds: list[SeedNode],
+        walked: list[WalkNode],
+        top_k: int,
+    ) -> list[RankedResult]:
+        """
+        Composite score: α·vector_score + β·pagerank_norm + γ·heat_norm
+          vector_score already 0–1 from ChromaDB.
+          pagerank_norm = pagerank / max_pagerank in result set.
+          heat_norm     = heat_score / 100.
+        """
+        seed_map = {s.node_id: s for s in seeds}
+        # Normalize PageRank over the full result set (seeds + walked nodes).
+        max_pr = max((n.pagerank for n in [*seeds, *walked]), default=1.0) or 1.0
+
+        # Seeds are also results — build from seed list first
+        results: dict[str, RankedResult] = {}
+
+        for s in seeds:
+            score = (
+                self.ALPHA * s.vector_score
+                + self.BETA * (s.pagerank / max_pr)
+                + self.GAMMA * (s.heat_score / 100)
+            )
+            results[s.node_id] = RankedResult(
+                node_id       = s.node_id,
+                node_type     = s.node_type,
+                name          = s.name,
+                file          = s.file,
+                signature     = s.signature,
+                docstring     = s.docstring,
+                tags          = s.tags,
+                community_id  = s.community_id,
+                final_score   = round(score, 4),
+                vector_score  = s.vector_score,
+                pagerank      = s.pagerank,
+                heat_score    = s.heat_score,
+                is_seed       = True,
+                reachable_from= [s.node_id],
+            )
+
+        for w in walked:
+            if w.node_id in results:
+                continue
+            # Walked-only nodes normally carry no vector score of their own,
+            # so they rank on structure (pagerank + heat) alone.
+            seed_hit = seed_map.get(w.node_id)
+            v_score = seed_hit.vector_score if seed_hit else 0.0
+            score = (
+                self.ALPHA * v_score
+                + self.BETA * (w.pagerank / max_pr)
+                + self.GAMMA * (w.heat_score / 100)
+            )
+            results[w.node_id] = RankedResult(
+                node_id       = w.node_id,
+                node_type     = w.node_type,
+                name          = w.name,
+                file          = w.file,
+                signature     = w.signature,
+                docstring     = w.docstring,
+                tags          = [],
+                community_id  = w.community_id,
+                final_score   = round(score, 4),
+                vector_score  = v_score,
+                pagerank      = w.pagerank,
+                heat_score    = w.heat_score,
+                is_seed       = False,
+                reachable_from= [],
+            )
+
+        ranked = sorted(results.values(), key=lambda r: r.final_score, reverse=True)
+        return ranked[:top_k]
+
+    # ── Phase 4: Structural Map ───────────────────────────────────────────────
+
+    def _build_structural_map(
+        self,
+        results: list[RankedResult],
+        edges: list[dict],
+    ) -> list[dict]:
+        """
+        Keep only the walk edges whose endpoints both survived ranking.
+        Used by Accept: text/markdown responses to render the call chain section.
+        """
+        result_ids = {r.node_id for r in results}
+        return [
+            edge for edge in edges
+            if edge["from"] in result_ids and edge["to"] in result_ids
+        ]
+
+    # ── Public entrypoint ─────────────────────────────────────────────────────
+
+    async def locate(
+        self,
+        query: str,
+        seed_k: int = 3,
+        hops: int = 2,
+        top_k: int = 10,
+    ) -> LocateResponse:
+        community_id, _confidence = await self._route_to_community(query)   # Phase 0
+        seeds = await self._seed(query, seed_k, community_id)               # Phase 1
+        walked, edges = await self._walk([s.node_id for s in seeds], hops)  # Phase 2
+        ranked = self._rank(seeds, walked, top_k)                           # Phase 3
+        smap = self._build_structural_map(ranked, edges)                    # Phase 4
+
+        return LocateResponse(
+            query            = query,
+            routed_community = community_id,
+            seed_count       = len(seeds),
+            total_walked     = len(walked),
+            results          = ranked,
+            structural_map   = smap,
+        )
+```
+
+---
+
+### The `get_context()` Method (Most Important for Agents)
+
+```python
+def get_context(self, file_path: str, scope: str = "edit"):
+    """
+    Returns the "programmer's mental model" for a file,
+    plus a pre-computed summary so agents don't drown in raw data.
+    """
+    file_node = self.graph.get_node_by_path(file_path)
+    imported_by = self.graph.get_relationships(file_node, "IMPORTS", direction="incoming")
+
+    context = {
+        "self": file_node,
+
+        "imports": self.graph.get_relationships(
+            file_node, "IMPORTS", direction="outgoing"
+        ),
+
+        "imported_by": imported_by,
+
+        "defines": self.graph.get_relationships(
+            file_node, "DEFINES", direction="outgoing"
+        ),
+
+        "related_patterns": self.graph.find_structurally_similar(file_node),
+
+        "entry_points": self.graph.find_entry_points(file_node),
+
+        "data_flow_in": self.trace_data_flow(file_node, direction="in"),
+        "data_flow_out": self.trace_data_flow(file_node, direction="out"),
+
+        # Pre-computed summary — agent reads this first, drills into raw fields only if needed
+        "summary": self._build_summary(file_node, imported_by),
+    }
+
+    return context
+
+def _build_summary(self, file_node, imported_by: list) -> dict:
+    """
+    Purely structural computation — no LLM, no generated sentences.
+    Every field is a count, enum, or list derived directly from the graph.
+    """
+    api_layer_files = self.graph.count_in_package("src/api")
+    imported_by_api = [f for f in imported_by if "src/api" in f.path]
+    complexity_scores = [n.metrics["complexity"] for n in file_node.defines]
+    avg_complexity = round(sum(complexity_scores) / max(len(complexity_scores), 1), 1)
+    max_complexity = max(complexity_scores, default=0)
+
+    return {
+        "role": self._classify_role(file_node),
+        "blast_radius": len(imported_by),
+        "api_layer_callers": len(imported_by_api),
+        "api_layer_total": api_layer_files,
+        "avg_complexity": avg_complexity,
+        "max_complexity": max_complexity,
+        "exported_symbols": len(file_node.exports),
+        "has_tests": self.graph.has_tests(file_node),
+        "test_files": self.graph.get_test_files(file_node),
+        "is_hot_node": self.telemetry.is_hot(file_node.id),
+        "heat_score": self.telemetry.heat_score(file_node.id),
+        "risk_level": "high" if len(imported_by) > 10 or avg_complexity > 8 else
+                      "medium" if len(imported_by) > 3 or avg_complexity > 4 else
+                      "low",
+    }
+
+def _classify_role(self, file_node) -> str:
+    """
+    Derive file role purely from graph topology — no LLM.
+
+    Rules applied in order (first match wins):
+      test         → file path contains /test/ or /spec/, or all defines are Test nodes
+      config       → file is a Config node type
+      endpoint     → file defines a Function with an HTTP-verb decorator (@get/@post/etc)
+                     or path contains /routes/ or /controllers/
+      service      → file path contains /services/ and has both incoming + outgoing IMPORTS
+      core_utility → blast_radius > 5 (imported by many) and path contains /utils/ or /lib/ or /shared/
+      isolated     → blast_radius == 0 and not exported
+      module       → default fallback
+    """
+    path = file_node.path
+    if "/test" in path or "/spec" in path:
+        return "test"
+    if file_node.node_type == "Config":
+        return "config"
+    decorators = [d for n in file_node.defines for d in n.decorators]
+    if any(d in decorators for d in ["@get", "@post", "@put", "@delete", "@patch"]):
+        return "endpoint"
+    if "/routes" in path or "/controllers" in path:
+        return "endpoint"
+    if "/services" in path and file_node.imports and file_node.imported_by:
+        return "service"
+    if len(file_node.imported_by) > 5 and any(p in path for p in ["/utils", "/lib", "/shared", "/helpers"]):
+        return "core_utility"
+    if len(file_node.imported_by) == 0 and not file_node.exports:
+        return "isolated"
+    return "module"
+```
+
+---
+
+## Part 3: The Protocol (SMP)
+
+### Protocol Specification
+
+**Name:** Structural Memory Protocol (SMP)
+**Version:** 1.0
+**Transport:** JSON-RPC 2.0 over stdio / HTTP / WebSocket
+**Inspired by:** MCP (Model Context Protocol), A2A (Agent-to-Agent)
+
+---
+
+### Protocol Methods
+
+#### 1. Memory Management
+
+```json
+// smp/update - Sync codebase state
+{
+  "jsonrpc": "2.0",
+  "method": "smp/update",
+  "params": {
+    "type": "file_change",
+    "file_path": "src/auth/login.ts",
+    "content": "...",
+    "change_type": "modified" | "created" | "deleted"
+  },
+  "id": 1
+}
+
+// Response
+{
+  "jsonrpc": "2.0",
+  "result": {
+    "status": "success",
+    "nodes_added": 3,
+    "nodes_updated": 12,
+    "nodes_removed": 1,
+    "relationships_updated": 8
+  },
+  "id": 1
+}
+```
+
+```json
+// smp/batch_update - Multiple files at once
+{
+  "jsonrpc": "2.0",
+  "method": "smp/batch_update",
+  "params": {
+    "changes": [
+      {"file_path": "src/auth/login.ts", "content": "...", "change_type": "modified"},
+      {"file_path": "src/auth/middleware.ts", "content": "...", "change_type": "created"}
+    ]
+  },
+  "id": 2
+}
+```
+
+```json
+// smp/sync — Merkle-diff sync. Client sends its current root hash and a
+// flat map of { file_path → sha256(content) }. Server compares against its
+// own Merkle tree and returns exactly which files need to be pushed.
+// O(log n) — only walks subtrees where hashes diverge.
+{ + "jsonrpc": "2.0", + "method": "smp/sync", + "params": { + "client_root_hash": "e3b0c44298fc", + "file_hashes": { + "src/auth/login.ts": "a3f9c12d", + "src/auth/register.ts": "99de12ab", + "src/utils/crypto.ts": "c3a1f004", + "src/db/models/user.ts": "7f3b9e21" + } + }, + "id": 3 +} + +// Response — server returns the minimal diff, not a full file list +{ + "jsonrpc": "2.0", + "result": { + "server_root_hash": "f7c2a19b3d84", + "in_sync": false, + "diff": { + "stale_on_server": [ + { + "file": "src/auth/login.ts", + "client_hash": "a3f9c12d", + "server_hash": "b7d2e91f", + "action": "push" // client is newer — push to server + } + ], + "missing_on_client": [ + { + "file": "src/auth/oauth.ts", + "server_hash": "44f1c8d9", + "action": "pull" // server has file client doesn't know about + } + ], + "deleted_on_server": [ + { + "file": "src/auth/legacy.ts", + "action": "remove_from_graph" + } + ], + "unchanged": 2 // count only — no need to list them + } + }, + "id": 3 +} + +// Response — already in sync, nothing to do +{ + "jsonrpc": "2.0", + "result": { + "server_root_hash": "e3b0c44298fc", + "in_sync": true, + "diff": { + "stale_on_server": [], + "missing_on_client": [], + "deleted_on_server": [], + "unchanged": 4 + } + }, + "id": 3 +} +``` + +```json +// smp/merkle/tree — Return the server's full Merkle tree. +// Agents use this to build a local copy for offline diff before connecting. +{ + "jsonrpc": "2.0", + "method": "smp/merkle/tree", + "params": { + "scope": "full" // "full" | "package:src/auth" + }, + "id": 4 +} + +// Response — hierarchical hash tree, mirrors the package/file structure +{ + "jsonrpc": "2.0", + "result": { + "root_hash": "f7c2a19b3d84", + "tree": { + "src": { + "hash": "9c4f2a1b", + "children": { + "auth": { + "hash": "3d8e7f12", + "children": { + "login.ts": {"hash": "b7d2e91f", "node_count": 4}, + "register.ts": {"hash": "99de12ab", "node_count": 3}, + "oauth.ts": {"hash": "44f1c8d9", "node_count": 6} + } + }, + "utils": { + "hash": "1a3c9f00", + "children": { + "crypto.ts": {"hash": "c3a1f004", "node_count": 5} + } + } + } + } + } + }, + "id": 4 +} +``` + +--- + +#### 1b. Secure Index Distribution + +A new agent or a new SMP server instance does not need to re-index the entire codebase from scratch. The current server exports a cryptographically signed snapshot of its index. The recipient compares Merkle root hashes — if they match, it imports directly. If they differ, it only re-indexes the diverging subtrees. + +``` +┌─────────────────┐ ┌─────────────────┐ +│ SMP Server A │ export │ SMP Server B │ +│ (source) │──────────▶│ (new instance) │ +│ │ signed │ │ +│ root: f7c2a19b │ snapshot │ 1. verify sig │ +│ │ │ 2. compare root │ +└─────────────────┘ │ 3a. match → │ + │ import all │ + │ 3b. differ → │ + │ sync diff │ + └─────────────────┘ +``` + +```json +// smp/index/export — Package the current index as a signed, portable snapshot. +// Used for fast agent onboarding and multi-instance distribution. 
+{ + "jsonrpc": "2.0", + "method": "smp/index/export", + "params": { + "scope": "full", // "full" | "package:" + "signing_key_id": "key_prod_01" + }, + "id": 5 +} + +// Response +{ + "jsonrpc": "2.0", + "result": { + "snapshot_id": "snap_4f8a2c", + "root_hash": "f7c2a19b3d84", + "scope": "full", + "node_count": 1240, + "edge_count": 8430, + "signed_at": "2025-02-15T10:00:00Z", + "signature": "sha256:a1b2c3...", + "export_url": "smp://snapshots/snap_4f8a2c.tar.zst" + }, + "id": 5 +} +``` + +```json +// smp/index/import — Load a signed snapshot into this server instance. +// Verifies signature and root hash before touching the graph. +{ + "jsonrpc": "2.0", + "method": "smp/index/import", + "params": { + "snapshot_id": "snap_4f8a2c", + "source_url": "smp://snapshots/snap_4f8a2c.tar.zst", + "expected_root_hash": "f7c2a19b3d84", + "verify_signature": true + }, + "id": 6 +} + +// Response: hashes match → full import, no re-indexing needed +{ + "jsonrpc": "2.0", + "result": { + "status": "imported", + "root_hash_verified": true, + "signature_verified": true, + "nodes_imported": 1240, + "edges_imported": 8430, + "re_indexed_files": 0, + "duration_ms": 840 + }, + "id": 6 +} + +// Response: hashes differ → partial re-index of diverging subtrees only +{ + "jsonrpc": "2.0", + "result": { + "status": "partial_import", + "root_hash_verified": false, + "signature_verified": true, + "nodes_imported": 1218, + "edges_imported": 8390, + "diverging_packages": ["src/auth", "src/api"], + "re_indexed_files": 7, + "duration_ms": 2310 + }, + "id": 6 +} + +// Response: signature invalid → rejected entirely +{ + "jsonrpc": "2.0", + "error": { + "code": -32010, + "message": "signature_invalid", + "data": {"snapshot_id": "snap_4f8a2c", "key_id": "key_prod_01"} + }, + "id": 6 +} +``` + +--- + +#### 2. Structural Queries + +```json +// smp/navigate - Find entity and basic info +{ + "jsonrpc": "2.0", + "method": "smp/navigate", + "params": { + "query": "authenticateUser", + "include_relationships": true + }, + "id": 4 +} + +// Response +{ + "jsonrpc": "2.0", + "result": { + "entity": { + "id": "func_authenticate_user", + "type": "Function", + "file": "src/auth/login.ts", + "signature": "authenticateUser(email: string, password: string): Promise", + "purpose": "Handles user authentication..." + }, + "relationships": { + "calls": ["hashPassword", "compareHash", "generateToken"], + "called_by": ["loginRoute", "test_auth"], + "depends_on": ["UserModel", "TokenService"] + } + }, + "id": 4 +} +``` + +```json +// smp/trace - Follow relationship chain +{ + "jsonrpc": "2.0", + "method": "smp/trace", + "params": { + "start": "func_authenticate_user", + "relationship": "CALLS", + "depth": 3, + "direction": "outgoing" + }, + "id": 5 +} + +// Response: Returns the call graph as a tree +{ + "jsonrpc": "2.0", + "result": { + "root": "authenticateUser", + "tree": { + "authenticateUser": { + "calls": { + "hashPassword": {"calls": {"bcrypt.hash": {}}}, + "compareHash": {"calls": {"bcrypt.compare": {}}}, + "generateToken": {"calls": {"jwt.sign": {}}} + } + } + } + }, + "id": 5 +} +``` + +--- + +#### 3. 
Context Queries (Proactive) + +```json +// smp/context - Get editing context +{ + "jsonrpc": "2.0", + "method": "smp/context", + "params": { + "file_path": "src/auth/login.ts", + "scope": "edit", // "edit" | "create" | "debug" | "review" + "depth": 2 + }, + "id": 6 +} + +// Response +{ + "jsonrpc": "2.0", + "result": { + "summary": { + "role": "core_utility", + "blast_radius": 12, + "api_layer_callers": 7, + "api_layer_total": 12, + "avg_complexity": 5.2, + "max_complexity": 9, + "exported_symbols": 3, + "has_tests": true, + "test_files": ["tests/auth.test.ts"], + "is_hot_node": true, + "heat_score": 96, + "risk_level": "high" + }, + "self": { + "id": "file_auth_login_ts", + "path": "src/auth/login.ts", + "language": "typescript", + "lines": 120, + "source_hash": "a3f9c12d" + }, + "imports": [ + {"file": "src/utils/crypto.ts", "items": ["hashPassword", "compareHash"]}, + {"file": "src/db/models/user.ts", "items": ["UserModel"]} + ], + "imported_by": [ + {"file": "src/api/routes.ts"}, + {"file": "src/middleware/auth.ts"}, + {"file": "src/api/admin.ts"} + ], + "defines": { + "functions": [ + {"id": "func_authenticate_user", "name": "authenticateUser", "complexity": 9, "exported": true}, + {"id": "func_refresh_token", "name": "refreshToken", "complexity": 4, "exported": true} + ], + "classes": [ + {"id": "class_AuthService", "name": "AuthService", "method_count": 3, "exported": true} + ] + }, + "structurally_similar": [ + {"file": "src/api/users.ts", "shared_imports": 3, "shared_node_types": ["Function", "Class"]}, + {"file": "src/api/session.ts", "shared_imports": 2, "shared_node_types": ["Function"]} + ], + "entry_points": ["func_authenticate_user", "func_refresh_token"], + "data_flow_in": [ + {"from": "src/api/routes.ts", "via": "loginRoute", "carries": "Request"} + ], + "data_flow_out": [ + {"to": "src/db/models/user.ts", "via": "UserModel.findByEmail", "carries": "UserRecord"} + ] + }, + "id": 6 +} +``` + +```json +// smp/impact - Assess change impact +{ + "jsonrpc": "2.0", + "method": "smp/impact", + "params": { + "entity": "func_authenticate_user", + "change_type": "signature_change" | "delete" | "move" + }, + "id": 7 +} + +// Response +{ + "jsonrpc": "2.0", + "result": { + "affected_files": [ + "src/api/routes.ts", + "tests/auth.test.ts", + "src/middleware/auth.ts" + ], + "affected_functions": [ + {"id": "func_login_route", "file": "src/api/routes.ts", "relationship": "CALLS"}, + {"id": "func_test_authenticate_user", "file": "tests/auth.test.ts", "relationship": "TESTS"}, + {"id": "func_auth_middleware", "file": "src/middleware/auth.ts", "relationship": "CALLS"} + ], + "severity": "high", + "required_updates": [ + { + "file": "src/api/routes.ts", + "function": "loginRoute", + "reason": "CALLS", + "change_type": "signature_change" + }, + { + "file": "tests/auth.test.ts", + "function": "test_authenticate_user", + "reason": "TESTS", + "change_type": "signature_change" + } + ] + }, + "id": 7 +} +``` + +--- + +#### 4. Community-Routed Graph RAG (`smp/locate`) + +`smp/locate` is the primary code discovery method. It runs a five-phase Graph RAG pipeline — no LLM at any stage: + +``` +Phase 0 — ROUTE: Compare query against Level-1 (fine) community centroid embeddings. + → Best-match fine community returned with confidence score. + → If confidence ≥ 0.65: scope seed search to that fine community (~200 nodes). + → If confidence < 0.65: query spans multiple communities — search globally. + Key Graph RAG insight: narrow the search space BEFORE seeding. 
+                 Architecture agents can also force Level-0 routing to get module-level results.
+
+Phase 1 — SEED:  ChromaDB vector search, scoped to community or global.
+                 → Top-K nodes whose code_embedding is closest to the query.
+                 → No generative model; embedding of the query string only.
+
+Phase 2 — WALK:  Single Cypher N-hop traversal from each seed.
+                 → Follows CALLS_STATIC | CALLS_RUNTIME | IMPORTS | DEFINES edges.
+                 → Crosses community boundaries via BRIDGES edges when relevant.
+                 → One query, zero N+1 overhead.
+
+Phase 3 — RANK:  Composite score per node:
+                 final_score = 0.50 × vector_score
+                             + 0.30 × (pagerank / max_pagerank)
+                             + 0.20 × (heat_score / 100)
+
+Phase 4 — ASSEMBLE: Deduplicated ranked list + structural_map adjacency list.
+                    Results include community_id so the agent knows which domain each node lives in.
+```
+
+**PageRank** is pre-computed by Neo4j GDS at index time and stored as a property on every node. **Community centroids** are computed at `smp/community/detect` time. Neither is computed per-query.
+
+```json
+// Request
+{
+  "jsonrpc": "2.0",
+  "method": "smp/locate",
+  "params": {
+    "query": "user registration",
+    "seed_k": 3,
+    "hops": 2,
+    "top_k": 10,
+    "node_types": ["Function", "Class"],
+    "community_id": null   // null = auto-route via Phase 0; set explicitly to force a community
+  },
+  "id": 8
+}
+
+// Response (Accept: application/json)
+{
+  "jsonrpc": "2.0",
+  "result": {
+    "query": "user registration",
+    "routed_community": {
+      "id": "comm_auth_core",
+      "label": "auth",
+      "confidence": 0.83,
+      "searched_nodes": 47   // searched 47 nodes instead of 1240 — 96% reduction
+    },
+    "seed_count": 3,
+    "total_walked": 24,
+    "results": [
+      {
+        "node_id": "func_register_user",
+        "node_type": "Function",
+        "name": "registerUser",
+        "file": "src/auth/register.ts",
+        "community_id": "comm_auth_core",
+        "signature": "registerUser(email: string, password: string): Promise",
+        "docstring": "Creates a new user account and sends a verification email.",
+        "tags": ["auth", "registration"],
+        "final_score": 0.8821,
+        "vector_score": 0.94,
+        "pagerank": 0.031,
+        "heat_score": 42,
+        "is_seed": true,
+        "reachable_from": ["func_register_user"]
+      },
+      {
+        "node_id": "class_UserService",
+        "node_type": "Class",
+        "name": "UserService",
+        "file": "src/services/user.ts",
+        "community_id": "comm_db_models",
+        "signature": "class UserService",
+        "docstring": "Manages user CRUD operations including registration.",
+        "tags": ["user", "service"],
+        "final_score": 0.7340,
+        "vector_score": 0.81,
+        "pagerank": 0.058,
+        "heat_score": 61,
+        "is_seed": false,
+        "reachable_from": ["func_register_user"]
+      },
+      {
+        "node_id": "func_send_verification_email",
+        "node_type": "Function",
+        "name": "sendVerificationEmail",
+        "file": "src/notifications/email.ts",
+        "community_id": "comm_notifications",
+        "signature": "sendVerificationEmail(userId: string): Promise",
+        "docstring": "Sends account verification link to new user.",
+        "tags": ["email", "notifications"],
+        "final_score": 0.6180,
+        "vector_score": 0.71,
+        "pagerank": 0.019,
+        "heat_score": 18,
+        "is_seed": false,
+        "reachable_from": ["func_register_user"]
+      }
+    ],
+    "structural_map": [
+      {"from": "func_register_user", "to": "class_UserService",            "edge_type": "CALLS_STATIC", "hop": 1, "is_bridge": true, "bridge": "auth → db"},
+      {"from": "func_register_user", "to": "func_send_verification_email", "edge_type": "CALLS_STATIC", "hop": 1, "is_bridge": true, "bridge": "auth → notifications"},
+      {"from": "class_UserService",  "to": "func_validate_email_format",
"edge_type": "DEFINES", "hop": 2, "is_bridge": false} + ] + }, + "id": 8 +} +``` + +**`Accept: text/markdown` response** — when the client sends `Accept: text/markdown`, the server assembles `LocateResponse` into a structured Markdown document for direct agent consumption: + +```` +// smp/locate response — Accept: text/markdown + +## Results for: "user registration" +_3 seeds · 24 nodes walked · top 3 shown_ + +--- + +### 1. `registerUser` · Function · score 0.8821 ★ seed +**File:** `src/auth/register.ts` +**Signature:** `registerUser(email: string, password: string): Promise` +**Docstring:** Creates a new user account and sends a verification email. +**Tags:** `auth` `registration` +| vector | pagerank | heat | +|--------|----------|------| +| 0.94 | 0.031 | 42 | + +--- + +### 2. `UserService` · Class · score 0.7340 +**File:** `src/services/user.ts` +**Docstring:** Manages user CRUD operations including registration. +**Reachable from:** `registerUser` + +--- + +### 3. `sendVerificationEmail` · Function · score 0.6180 +**File:** `src/notifications/email.ts` +**Signature:** `sendVerificationEmail(userId: string): Promise` +**Reachable from:** `registerUser` + +--- + +## Structural Map + +``` +registerUser + ├─[CALLS_STATIC]──▶ UserService + └─[CALLS_STATIC]──▶ sendVerificationEmail + └─[DEFINES]──▶ validateEmailFormat +``` +```` + +--- + +#### 5. Flow Analysis + +```json +// smp/flow - Trace execution/data flow +{ + "jsonrpc": "2.0", + "method": "smp/flow", + "params": { + "start": "api_route_login", + "end": "database_write_user", + "flow_type": "data" | "execution" + }, + "id": 9 +} + +// Response +{ + "jsonrpc": "2.0", + "result": { + "path": [ + {"node": "api_route_login", "type": "endpoint", "file": "src/api/routes.ts"}, + {"node": "auth_middleware", "type": "middleware", "file": "src/middleware/auth.ts"}, + {"node": "authenticateUser", "type": "function", "file": "src/auth/login.ts"}, + {"node": "UserModel.findByEmail", "type": "method", "file": "src/db/models/user.ts"}, + {"node": "generateToken", "type": "function", "file": "src/auth/login.ts"}, + {"node": "response_json", "type": "output", "file": "src/api/routes.ts"} + ], + "type_transitions": [ + {"from_node": "api_route_login", "to_node": "authenticateUser", "param_types": ["Request"], "return_type": "Promise"}, + {"from_node": "authenticateUser", "to_node": "UserModel.findByEmail", "param_types": ["string"], "return_type": "Promise"}, + {"from_node": "authenticateUser", "to_node": "generateToken", "param_types": ["UserRecord"], "return_type": "Promise"} + ] + }, + "id": 9 +} +``` + +--- + +#### 6. Structural Diff + +Before writing, an agent needs to know *exactly* what changed between the current version and its proposed version — at the node level, not the line level. + +```json +// smp/diff - Compare current graph state of a file against proposed new content +{ + "jsonrpc": "2.0", + "method": "smp/diff", + "params": { + "file_path": "src/auth/login.ts", + "proposed_content": "..." 
+ }, + "id": 10 +} + +// Response +{ + "jsonrpc": "2.0", + "result": { + "nodes_added": [ + {"id": "func_check_rate_limit", "type": "Function", "name": "checkRateLimit"} + ], + "nodes_removed": [], + "nodes_modified": [ + { + "id": "func_authenticate_user", + "changes": { + "signature_changed": false, + "body_changed": true, + "complexity_delta": +2, + "calls_added": ["func_check_rate_limit"], + "calls_removed": [] + } + } + ], + "relationships_added": [ + {"edge": "CALLS", "from": "func_authenticate_user", "to": "func_check_rate_limit"} + ], + "relationships_removed": [] + }, + "id": 10 +} +``` + +--- + +#### 7. Multi-File Plan + +Before a complex multi-file task, the agent declares its full plan upfront. SMP validates scope, detects inter-file conflicts, and returns a risk-ranked execution order. + +```json +// smp/plan - Validate and rank a multi-file task before execution +{ + "jsonrpc": "2.0", + "method": "smp/plan", + "params": { + "session_id": "ses_4f8a2c", + "task": "Refactor AuthService to support OAuth in addition to password auth", + "intended_writes": [ + "src/auth/login.ts", + "src/auth/oauth.ts", + "src/middleware/auth.ts", + "src/types/token.ts" + ] + }, + "id": 11 +} + +// Response — execution order sorted by dependency: write leaves first, roots last +{ + "jsonrpc": "2.0", + "result": { + "execution_order": [ + { + "step": 1, + "file": "src/types/token.ts", + "dependants_in_plan": 0, + "dependencies_in_plan": 0, + "blast_radius": 2, + "risk_level": "low" + }, + { + "step": 2, + "file": "src/auth/oauth.ts", + "dependants_in_plan": 1, + "dependencies_in_plan": 0, + "blast_radius": 0, + "risk_level": "low", + "is_new_file": true + }, + { + "step": 3, + "file": "src/auth/login.ts", + "dependants_in_plan": 1, + "dependencies_in_plan": 1, + "depends_on_steps": [1], + "blast_radius": 12, + "is_hot_node": true, + "heat_score": 96, + "risk_level": "high" + }, + { + "step": 4, + "file": "src/middleware/auth.ts", + "dependants_in_plan": 0, + "dependencies_in_plan": 2, + "depends_on_steps": [2, 3], + "blast_radius": 5, + "risk_level": "medium" + } + ], + "inter_file_conflicts": [], + "external_files_at_risk": [ + "src/api/routes.ts", + "tests/auth.test.ts" + ] + }, + "id": 11 +} +``` + +--- + +#### 8. Conflict Detection + +Check if two agents' scopes overlap before either starts writing. + +```json +// smp/conflict - Detect scope overlap between two planned sessions +{ + "jsonrpc": "2.0", + "method": "smp/conflict", + "params": { + "session_a": "ses_4f8a2c", + "session_b": "ses_7c1d9f" + }, + "id": 12 +} + +// Response +{ + "jsonrpc": "2.0", + "result": { + "conflict": true, + "overlapping_files": ["src/auth/login.ts"], + "overlapping_nodes": ["func_authenticate_user"], + "session_modes": { + "ses_4f8a2c": "write", + "ses_7c1d9f": "read" + }, + "write_session": "ses_4f8a2c" + }, + "id": 12 +} + +// No conflict +{ + "jsonrpc": "2.0", + "result": { + "conflict": false + }, + "id": 12 +} +``` + +--- + +#### 9. Graph Explanation + +Agents often need to understand *why* a dependency exists — not just that it does. `smp/graph/why` traces the shortest structural path between two nodes and returns it as a human-readable chain. 
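+
+Under the hood this is a single shortest-path query over the edge types used elsewhere in this document. A minimal server-side sketch — the handler name, the 10-hop cap, and the use of node ids in the `readable` string are illustrative choices, not protocol surface:
+
+```python
+# Sketch only: assumes a neo4j AsyncSession and nodes carrying an `id` property.
+async def graph_why(session, from_id: str, to_id: str) -> dict:
+    cypher = """
+    MATCH (a {id: $from_id}), (b {id: $to_id})
+    MATCH p = shortestPath((a)-[:IMPORTS|CALLS_STATIC|CALLS_RUNTIME*..10]->(b))
+    RETURN [n IN nodes(p) | n.id] AS node_ids,
+           [r IN relationships(p) | type(r)] AS edge_types
+    """
+    result = await session.run(cypher, from_id=from_id, to_id=to_id)
+    record = await result.single()
+    if record is None:  # no row matched → no dependency path
+        return {"path_length": None, "chain": [],
+                "readable": "No dependency path exists between these two nodes"}
+    nodes, edge_types = record["node_ids"], record["edge_types"]
+    chain = [{"node": nodes[i], "edge": edge_types[i], "target": nodes[i + 1]}
+             for i in range(len(edge_types))]
+    readable = nodes[0] + "".join(
+        f" → {edge.lower()} → {target}" for edge, target in zip(edge_types, nodes[1:])
+    )
+    # path_length counts nodes on the path, matching the example below.
+    return {"path_length": len(nodes), "chain": chain, "readable": readable}
+```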
+ +```json +// smp/graph/why - Explain the dependency path between two nodes +{ + "jsonrpc": "2.0", + "method": "smp/graph/why", + "params": { + "from": "src/api/routes.ts", + "to": "src/utils/crypto.ts" + }, + "id": 13 +} + +// Response +{ + "jsonrpc": "2.0", + "result": { + "path_length": 3, + "chain": [ + {"node": "src/api/routes.ts", "edge": "IMPORTS", "target": "src/auth/login.ts"}, + {"node": "src/auth/login.ts", "edge": "IMPORTS", "target": "src/utils/crypto.ts"} + ], + "readable": "routes.ts → imports → login.ts → imports → crypto.ts" + }, + "id": 13 +} + +// No path found +{ + "jsonrpc": "2.0", + "result": { + "path_length": null, + "chain": [], + "readable": "No dependency path exists between these two nodes" + }, + "id": 13 +} +``` + +--- + +### Event Notifications (Server → Agent) + +```json +// Notification: Memory updated +{ + "jsonrpc": "2.0", + "method": "smp/notification", + "params": { + "type": "memory_updated", + "changes": { + "files_affected": ["src/auth/login.ts"], + "structural_changes": ["func_authenticate_user modified"], + "semantic_changes": ["purpose re-enriched"] + } + } +} +``` + +```json +// Notification: Conflict detected +{ + "jsonrpc": "2.0", + "method": "smp/notification", + "params": { + "type": "conflict_detected", + "severity": "warning", + "message": "File modified by external process, memory may be stale", + "file": "src/auth/login.ts" + } +} +``` + +--- + +## Part 4: Agent Safety Protocol + +> The core idea: **the agent must talk to SMP before it touches anything.** SMP acts as the guardrail layer between the agent and the codebase — enforcing scope, surfacing danger, and keeping a full audit trail of every write. + +--- + +### The Agent Write Lifecycle (MVCC + Sandbox) + +File-level locking (`smp/lock`) is the sequential model — one agent, one file at a time. For swarms of parallel agents, SMP uses **MVCC**: each agent works against a specific `commit_sha` snapshot in its own isolated sandbox. No agent blocks another. Merge conflicts are resolved at PR time, not at lock-acquisition time. + +``` +Agent receives task + │ + ▼ +┌──────────────────────┐ +│ smp/session/open │ ← declare intent + commit_sha snapshot +└─────────┬────────────┘ + │ + ▼ +┌──────────────────────┐ +│ smp/sandbox/spawn │ ← get isolated microVM, CoW filesystem, firewalled network +└─────────┬────────────┘ + │ + ▼ +┌──────────────────────┐ +│ smp/guard/check │ ← pre-flight: coverage gaps, hot nodes, blast radius +└─────────┬────────────┘ + │ + ┌────┴────┐ + CLEAR RED_ALERT ──► fix blocking condition, re-check + │ BLOCKED ──► abort + ▼ +┌──────────────────────┐ +│ smp/dryrun │ ← structural diff against snapshot — what would break? 
+└─────────┬────────────┘ + │ + ┌────┴──────┐ + SAFE BREAKING ──► fix callers first, re-run + │ + ▼ + WRITE FILE (inside sandbox) + │ + ▼ +┌──────────────────────┐ +│ smp/sandbox/execute │ ← run tests, capture eBPF runtime trace +└─────────┬────────────┘ + │ + ┌────┴──────────┐ + PASS FAIL ──► read stderr/trace, self-correct, re-execute + │ + ▼ +┌──────────────────────┐ +│ smp/verify/integrity│ ← AST data-flow check + mutation test gate +└─────────┬────────────┘ + │ + ┌────┴───────────┐ + PASSED SURVIVING_MUTANT ──► tighten assertions, re-verify + │ + ▼ +┌──────────────────────┐ +│ smp/update │ ← sync graph memory with new file state +└─────────┬────────────┘ + │ + ▼ +┌──────────────────────┐ +│ smp/handoff/pr │ ← pass to reviewer agent or file PR directly +└─────────┬────────────┘ + │ + ▼ +┌──────────────────────┐ +│ smp/session/close │ ← commit audit log, destroy sandbox +└──────────────────────┘ +``` + +--- + +#### 1. Session Management + +Sessions and locks are persisted directly in the Graph DB — not in memory. If the SMP server restarts, all active sessions, locks, and checkpoints survive. Agents reconnect and continue without losing their write guards. + +**Persistence schema in Graph DB:** + +``` +(:Session {id, agent_id, task, scope, mode, status, opened_at, expires_at}) +(:Lock {file, held_by_session, acquired_at, expires_at}) +(:Checkpoint {id, session_id, files_snapshotted, snapshot_at, content_hash}) + +(:Session)-[:HOLDS]->(:Lock) +(:Session)-[:HAS_CHECKPOINT]->(:Checkpoint) +``` + +```json +// smp/session/open — declare intent before touching the codebase +{ + "jsonrpc": "2.0", + "method": "smp/session/open", + "params": { + "agent_id": "coder_agent_01", + "task": "Add rate limiting to the login endpoint", + "scope": [ + "src/auth/login.ts", + "src/middleware/rateLimit.ts" + ], + "mode": "write", // "read" | "write" + "commit_sha": "a1b2c3d4", // graph snapshot this session operates against + "concurrency": "mvcc" // "mvcc" (parallel, sandbox-isolated) | "exclusive" (file-locked, sequential) + }, + "id": 15 +} + +// Response +{ + "jsonrpc": "2.0", + "result": { + "session_id": "ses_4f8a2c", + "commit_sha": "a1b2c3d4", + "concurrency": "mvcc", + "granted_scope": [ + "src/auth/login.ts", + "src/middleware/rateLimit.ts" + ], + "denied_scope": [], + "scope_analysis": { + "src/auth/login.ts": {"blast_radius": 12, "is_hot_node": true, "heat_score": 96, "risk_level": "high"}, + "src/middleware/rateLimit.ts": {"blast_radius": 3, "is_hot_node": false, "heat_score": 12, "risk_level": "medium"} + }, + "safety_level": "elevated", + "expires_at": "2025-02-15T11:30:00Z" + }, + "id": 15 +} +``` + +```json +// smp/session/recover — reconnect to a persisted session after a server restart or crash +{ + "jsonrpc": "2.0", + "method": "smp/session/recover", + "params": { + "session_id": "ses_4f8a2c", + "agent_id": "coder_agent_01" + }, + "id": 16 +} + +// Response — session is intact, locks re-confirmed +{ + "jsonrpc": "2.0", + "result": { + "session_id": "ses_4f8a2c", + "status": "recovered", + "scope": ["src/auth/login.ts", "src/middleware/rateLimit.ts"], + "locks_held": ["src/auth/login.ts"], + "checkpoints": ["chk_3a7f91"], + "events_so_far": 4, + "expires_at": "2025-02-15T11:30:00Z" + }, + "id": 16 +} + +// Response — session expired during downtime, must re-open +{ + "jsonrpc": "2.0", + "result": { + "status": "expired", + "reason": "ttl_elapsed", + "last_checkpoint": "chk_3a7f91", + "last_checkpoint_at": "2025-02-15T10:45:00Z", + "files_snapshotted": ["src/auth/login.ts"] + }, + "id": 16 
+} +``` + +```json +// smp/session/close — commit the session, release locks, write audit log +{ + "jsonrpc": "2.0", + "method": "smp/session/close", + "params": { + "session_id": "ses_4f8a2c", + "status": "completed" // "completed" | "aborted" | "rolled_back" + }, + "id": 17 +} + +// Response +{ + "jsonrpc": "2.0", + "result": { + "session_id": "ses_4f8a2c", + "files_written": ["src/auth/login.ts"], + "files_read": ["src/middleware/rateLimit.ts"], + "duration_ms": 4200, + "audit_log_id": "aud_9b3e1a" + }, + "id": 17 +} +``` + +--- + +#### 2. Pre-Flight Guard Check + +Before writing, the agent asks SMP: *is it safe to touch this?* SMP checks scope, locks, concurrent agents, and — critically — whether the specific function being changed has test coverage. If a high-complexity function has zero test coverage, the guard returns `red_alert` and blocks the write. + +```json +// smp/guard/check — pre-flight safety check before writing a file +{ + "jsonrpc": "2.0", + "method": "smp/guard/check", + "params": { + "session_id": "ses_4f8a2c", + "target": "src/auth/login.ts", + "intended_change": "modify_function:authenticateUser", + "coverage_report": "coverage/lcov.info" // optional: path to lcov/cobertura report + }, + "id": 19 +} + +// Response: CLEAR +{ + "jsonrpc": "2.0", + "result": { + "verdict": "clear", + "target": "src/auth/login.ts", + "checks": { + "in_declared_scope": true, + "locked_by_other_agent": false, + "is_hot_node": false, + "heat_score": 18, + "has_tests": true, + "test_files": ["tests/auth.test.ts"], + "function_coverage": { + "authenticateUser": {"covered": true, "coverage_pct": 87} + }, + "caller_count": 3, + "blast_radius": 3, + "is_public_api": true + }, + "safety_level": "standard" + }, + "id": 19 +} + +// Response: RED ALERT — high-complexity function, zero test coverage +{ + "jsonrpc": "2.0", + "result": { + "verdict": "red_alert", + "target": "src/auth/login.ts", + "checks": { + "in_declared_scope": true, + "locked_by_other_agent": false, + "is_hot_node": true, + "heat_score": 96, + "has_tests": true, + "test_files": ["tests/auth.test.ts"], + "function_coverage": { + "authenticateUser": {"covered": false, "coverage_pct": 0} + }, + "caller_count": 12, + "blast_radius": 12, + "is_public_api": true + }, + "safety_level": "elevated", + "blocking": [ + {"code": "ZERO_COVERAGE", "node_id": "func_authenticate_user", "complexity": 9, "coverage_pct": 0}, + {"code": "HOT_NODE", "node_id": "func_authenticate_user", "heat_score": 96, "caller_count": 12} + ], + "unblock_conditions": [ + {"code": "ZERO_COVERAGE", "action": "add_tests", "target_node": "func_authenticate_user", "min_coverage_pct": 60} + ] + }, + "id": 19 +} + +// Response: BLOCKED — hard stop, no conditions +{ + "jsonrpc": "2.0", + "result": { + "verdict": "blocked", + "target": "src/auth/login.ts", + "reasons": [ + "File is outside declared session scope", + "Locked by session ses_7c1d9f (agent: reviewer_agent_02)" + ] + }, + "id": 19 +} +``` + +**Verdict levels:** + +| Verdict | Meaning | Agent action | +|---|---|---| +| `clear` | Safe to proceed | Continue to dryrun | +| `red_alert` | High risk, remediable | Fix the blocking reason, re-check | +| `blocked` | Hard stop | Abort — do not proceed | + +--- + +#### 3. Dry Run + +Simulate the write. SMP resolves the structural impact of the proposed change without writing anything to disk — returning exactly which nodes, files, and callers would be affected. 
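+
+The core of the simulation is a pure graph computation: parse the proposed content, diff the extracted signatures against the current graph state, and collect the external callers of anything whose signature changed. A condensed sketch — the `graph` and `parser` interfaces here are illustrative stand-ins, not protocol surface:
+
+```python
+def dryrun_verdict(graph, parser, file_path: str, proposed_content: str) -> dict:
+    # Current node signatures for this file, straight from the graph.
+    current = {n.node_id: n.signature for n in graph.nodes_in_file(file_path)}
+    # Signatures the proposed content would produce (nothing touches disk).
+    proposed = parser.extract_signatures(proposed_content, file_path)
+
+    changed = [nid for nid, sig in proposed.items()
+               if nid in current and sig != current[nid]]
+    broken_callers = [
+        {"function": caller.name, "file": caller.file}
+        for nid in changed
+        for caller in graph.callers_of(nid)   # incoming CALLS_* edges
+        if caller.file != file_path           # only callers outside the file break
+    ]
+    return {
+        "structural_delta": {
+            "nodes_added": len(set(proposed) - set(current)),
+            "nodes_removed": len(set(current) - set(proposed)),
+            "signature_changed": bool(changed),
+        },
+        "impact": {"broken_callers": broken_callers},
+        "verdict": "breaking" if broken_callers else "safe",
+    }
+```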
+ +```json +// smp/dryrun — simulate a write and see what breaks +{ + "jsonrpc": "2.0", + "method": "smp/dryrun", + "params": { + "session_id": "ses_4f8a2c", + "file_path": "src/auth/login.ts", + "proposed_content": "...", + "change_summary": "Added rate limit check before credential validation" + }, + "id": 18 +} + +// Response: SAFE +{ + "jsonrpc": "2.0", + "result": { + "structural_delta": { + "nodes_added": 1, + "nodes_modified": 1, + "nodes_removed": 0, + "signature_changed": false + }, + "impact": { + "affected_files": [], + "broken_callers": [], + "test_coverage_delta": "unchanged" + }, + "verdict": "safe", + "risks": [] + }, + "id": 18 +} + +// Response: BREAKING change detected +{ + "jsonrpc": "2.0", + "result": { + "structural_delta": { + "nodes_modified": 1, + "signature_changed": true + }, + "impact": { + "affected_files": ["src/api/routes.ts", "tests/auth.test.ts"], + "broken_callers": [ + { + "function": "loginRoute", + "file": "src/api/routes.ts", + "expected_return_type": "Promise", + "actual_return_type": "Promise<{token, retryAfter}>" + } + ], + "broken_tests": [ + { + "function": "test_authenticate_user", + "file": "tests/auth.test.ts", + "expected_return_type": "Promise", + "actual_return_type": "Promise<{token, retryAfter}>" + } + ] + }, + "verdict": "breaking" + }, + "id": 18 +} +``` + +--- + +#### 4. Checkpoint & Rollback + +Snapshot the structural state of any file before writing. If the agent's edit produces bad output, it can roll back to the snapshot in one call. + +```json +// smp/checkpoint — snapshot state before a risky write +{ + "jsonrpc": "2.0", + "method": "smp/checkpoint", + "params": { + "session_id": "ses_4f8a2c", + "files": ["src/auth/login.ts"] + }, + "id": 19 +} + +// Response +{ + "jsonrpc": "2.0", + "result": { + "checkpoint_id": "chk_3a7f91", + "files_snapshotted": ["src/auth/login.ts"], + "snapshot_at": "2025-02-15T10:45:00Z" + }, + "id": 19 +} +``` + +```json +// smp/rollback — revert to a checkpoint +{ + "jsonrpc": "2.0", + "method": "smp/rollback", + "params": { + "session_id": "ses_4f8a2c", + "checkpoint_id": "chk_3a7f91" + }, + "id": 20 +} + +// Response +{ + "jsonrpc": "2.0", + "result": { + "status": "rolled_back", + "files_restored": ["src/auth/login.ts"], + "memory_resynced": true + }, + "id": 20 +} +``` + +--- + +#### 5. Concurrency: MVCC vs File Locks + +Two concurrency modes. Choose at `session/open` time. + +**MVCC (default for swarms):** Each agent operates against its own `commit_sha` snapshot. No agent can block another. Multiple agents work in parallel on the same file simultaneously — conflicts surface at `smp/handoff/pr` as standard merge conflicts, resolved by the reviewer agent or a human. This is the model for autonomous agent swarms. + +**Exclusive locks (sequential writes):** The original file-lock model. Use when an operation *must* be the only writer and ordering matters — e.g. a schema migration that must complete before any other agent reads the new shape. 
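+
+From the agent's side, the mode is a single parameter at `smp/session/open`. A client-side sketch, using a hypothetical `SMPClient` wrapper over the JSON-RPC transport — only the method names and params are protocol, the wrapper itself is illustrative:
+
+```python
+async def open_for_task(smp, agent_id: str, scope: list[str], is_migration: bool) -> dict:
+    session = await smp.call("smp/session/open", {
+        "agent_id": agent_id,
+        "task": "schema migration" if is_migration else "parallel feature work",
+        "scope": scope,
+        "mode": "write",
+        "commit_sha": await smp.head_sha(),  # hypothetical helper: current snapshot SHA
+        # Migrations need a single ordered writer → exclusive file locks.
+        # Everything else runs snapshot-isolated and merges at PR time → mvcc.
+        "concurrency": "exclusive" if is_migration else "mvcc",
+    })
+    if is_migration:
+        # Exclusive mode still requires claiming the lock before writing.
+        await smp.call("smp/lock", {"session_id": session["session_id"], "files": scope})
+    return session
+```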
+ +```json +// smp/lock — claim exclusive write access (sequential mode only) +// Not needed in MVCC mode — sandbox isolation replaces this +{ + "jsonrpc": "2.0", + "method": "smp/lock", + "params": { + "session_id": "ses_4f8a2c", + "files": ["src/db/migrations/0012_add_user_role.ts"] + }, + "id": 21 +} + +// Response +{ + "jsonrpc": "2.0", + "result": { + "granted": ["src/db/migrations/0012_add_user_role.ts"], + "denied": [] + }, + "id": 21 +} +``` + +```json +// smp/unlock — release locks (also released automatically on session/close) +{ + "jsonrpc": "2.0", + "method": "smp/unlock", + "params": { + "session_id": "ses_4f8a2c", + "files": ["src/db/migrations/0012_add_user_role.ts"] + }, + "id": 22 +} +``` + +--- + +#### 6. Audit Log + +Full record of every agent session — what was intended, what was read, what was written, what was rolled back. + +```json +// smp/audit/get — retrieve the log for a session +{ + "jsonrpc": "2.0", + "method": "smp/audit/get", + "params": { + "audit_log_id": "aud_9b3e1a" + }, + "id": 23 +} + +// Response +{ + "jsonrpc": "2.0", + "result": { + "audit_log_id": "aud_9b3e1a", + "agent_id": "coder_agent_01", + "task": "Add rate limiting to the login endpoint", + "session_id": "ses_4f8a2c", + "opened_at": "2025-02-15T10:44:00Z", + "closed_at": "2025-02-15T10:45:10Z", + "status": "completed", + "events": [ + {"t": "10:44:01", "method": "smp/guard/check", "target": "src/auth/login.ts", "result": "clear"}, + {"t": "10:44:02", "method": "smp/dryrun", "target": "src/auth/login.ts", "result": "safe"}, + {"t": "10:44:03", "method": "smp/checkpoint", "files": ["src/auth/login.ts"], "checkpoint_id": "chk_3a7f91"}, + {"t": "10:44:05", "method": "smp/lock", "files": ["src/auth/login.ts"], "result": "granted"}, + {"t": "10:44:08", "method": "FILE_WRITE", "file": "src/auth/login.ts"}, + {"t": "10:44:09", "method": "smp/update", "file": "src/auth/login.ts", "result": "success"}, + {"t": "10:44:10", "method": "smp/unlock", "files": ["src/auth/login.ts"]} + ] + }, + "id": 23 +} +``` + +--- + +#### Agent Safety Notifications (Server → Agent) + +```json +// lock collision — another agent wants the same file +{ + "jsonrpc": "2.0", + "method": "smp/notification", + "params": { + "type": "lock_conflict", + "severity": "warning", + "file": "src/auth/login.ts", + "held_by_session": "ses_7c1d9f", + "held_by_agent": "reviewer_agent_02" + } +} +``` + +```json +// scope violation — agent tried to write outside its declared scope +{ + "jsonrpc": "2.0", + "method": "smp/notification", + "params": { + "type": "scope_violation", + "severity": "error", + "session_id": "ses_4f8a2c", + "attempted_file": "src/db/models/user.ts", + "declared_scope": ["src/auth/login.ts", "src/middleware/rateLimit.ts"] + } +} +``` + +```json +// session expired — agent took too long, locks auto-released +{ + "jsonrpc": "2.0", + "method": "smp/notification", + "params": { + "type": "session_expired", + "severity": "error", + "session_id": "ses_4f8a2c", + "expired_at": "2025-02-15T11:30:00Z", + "locks_released": ["src/auth/login.ts"] + } +} +``` + +--- + +## Part 5: Dependency Telemetry + +Telemetry tracks *how nodes change over time*, not just their current state. The key insight: a function that changes frequently AND has many callers is a **Hot Node** — high blast radius, high churn. Any agent touching a Hot Node automatically gets an elevated `safety_level` on its session. + +--- + +#### smp/telemetry/record — Record a node change event + +Called automatically by `smp/update` on every file write. 
No manual agent call needed.
+
+```json
+// Internal — fired by smp/update on every successful write
+{
+  "jsonrpc": "2.0",
+  "method": "smp/telemetry/record",
+  "params": {
+    "node_id": "func_authenticate_user",
+    "event": "modified",
+    "session_id": "ses_4f8a2c",
+    "agent_id": "coder_agent_01",
+    "timestamp": "2025-02-15T10:44:08Z"
+  },
+  "id": 30
+}
+```
+
+---
+
+#### smp/telemetry/hot — Get hot nodes in the graph
+
+```json
+// smp/telemetry/hot — list nodes with high churn AND high dependency count
+{
+  "jsonrpc": "2.0",
+  "method": "smp/telemetry/hot",
+  "params": {
+    "scope": "full",
+    "window_days": 30,
+    "min_changes": 5,
+    "min_callers": 5
+  },
+  "id": 31
+}
+
+// Response
+{
+  "jsonrpc": "2.0",
+  "result": {
+    "hot_nodes": [
+      {
+        "node_id": "func_authenticate_user",
+        "file": "src/auth/login.ts",
+        "changes_in_window": 8,
+        "caller_count": 12,
+        "heat_score": 96,   // changes_in_window × caller_count (8 × 12 = 96)
+        "last_changed_by": "coder_agent_01",
+        "last_changed_at": "2025-02-15T10:44:08Z"
+      },
+      {
+        "node_id": "class_UserModel",
+        "file": "src/db/models/user.ts",
+        "changes_in_window": 6,
+        "caller_count": 21,
+        "heat_score": 126,
+        "last_changed_by": "coder_agent_03",
+        "last_changed_at": "2025-02-14T09:11:00Z"
+      }
+    ]
+  },
+  "id": 31
+}
+```
+
+---
+
+#### smp/telemetry/node — Full change history for a specific node
+
+```json
+// smp/telemetry/node — change history for a single node
+{
+  "jsonrpc": "2.0",
+  "method": "smp/telemetry/node",
+  "params": {
+    "node_id": "func_authenticate_user",
+    "window_days": 90
+  },
+  "id": 32
+}
+
+// Response
+{
+  "jsonrpc": "2.0",
+  "result": {
+    "node_id": "func_authenticate_user",
+    "total_changes": 14,
+    "unique_agents": ["coder_agent_01", "coder_agent_03"],
+    "history": [
+      {"timestamp": "2025-02-15T10:44:08Z", "agent": "coder_agent_01", "session": "ses_4f8a2c", "event": "modified"},
+      {"timestamp": "2025-02-10T14:22:00Z", "agent": "coder_agent_03", "session": "ses_1a2b3c", "event": "modified"}
+    ],
+    "heat_score": 96,
+    "stability": "unstable"   // "stable" | "moderate" | "unstable"
+  },
+  "id": 32
+}
+```
+
+---
+
+#### Automatic Safety Escalation
+
+When a session opens and any file in its declared scope contains a Hot Node, `smp/session/open` automatically sets `safety_level: elevated`. Elevated sessions must complete `smp/guard/check` → `smp/dryrun` → `smp/checkpoint` in sequence — no shortcuts allowed.
+
+```json
+// smp/session/open response when scope contains hot nodes
+{
+  "jsonrpc": "2.0",
+  "result": {
+    "session_id": "ses_9x7y6z",
+    "granted_scope": ["src/auth/login.ts"],
+    "safety_level": "elevated",   // auto-escalated — hot node in scope
+    "hot_nodes_in_scope": [
+      {
+        "node_id": "func_authenticate_user",
+        "heat_score": 96,
+        "caller_count": 12,
+        "changes_in_window": 8
+      }
+    ],
+    "expires_at": "2025-02-15T11:30:00Z"
+  },
+  "id": 15
+}
+```
+
+---
+
+## Part 6: Sandbox Runtime
+
+Every agent write session runs inside an ephemeral, network-isolated container. The sandbox is the physical boundary that makes autonomy safe — the agent can run, fail, self-correct, and iterate without ever touching live infrastructure, live APIs, or other agents' work.
+
+---
+
+#### smp/sandbox/spawn — Request an isolated execution environment
+
+Spawns a microVM or Docker container from a specific `commit_sha`. The container gets a Copy-on-Write clone of the filesystem state at that SHA, so multiple agents can each have their own independent snapshot without duplicating storage.
Network egress is hard-firewalled — only package registries are reachable. + +```json +// Request +{ + "jsonrpc": "2.0", + "method": "smp/sandbox/spawn", + "params": { + "session_id": "ses_4f8a2c", + "commit_sha": "a1b2c3d4", + "image": "node:20-alpine", + "services": ["postgres:15", "redis:7"], + "cow_fs_clone": true, + "inject_ebpf": true + }, + "id": 101 +} + +// Response +{ + "jsonrpc": "2.0", + "result": { + "sandbox_id": "box_99x", + "status": "ready", + "commit_sha": "a1b2c3d4", + "services_started": ["postgres:15", "redis:7"], + "network": { + "egress_policy": "registry_only", + "allowed_registries": ["registry.npmjs.org", "pypi.org"] + }, + "ebpf_injected": true, + "spawned_at": "2025-02-15T10:44:00Z" + }, + "id": 101 +} +``` + +--- + +#### smp/sandbox/execute — Run a command, capture output and eBPF trace + +Runs any shell command inside the sandbox. If `inject_ebpf` was set on spawn, the response includes new `CALLS_RUNTIME` edges discovered during the execution, which SMP automatically injects into the graph. + +If a live external API call is made (e.g. Stripe, SendGrid), the network firewall returns `ECONNREFUSED`. That appears in `stderr`, and the agent reads it to understand it needs to write a local mock. + +```json +// Request +{ + "jsonrpc": "2.0", + "method": "smp/sandbox/execute", + "params": { + "sandbox_id": "box_99x", + "command": "npm run test:local", + "timeout_ms": 30000 + }, + "id": 102 +} + +// Response — tests pass, eBPF discovered new runtime edges +{ + "jsonrpc": "2.0", + "result": { + "exit_code": 0, + "stdout": "12 tests passed", + "stderr": "", + "duration_ms": 4200, + "calls_runtime_injected": [ + { + "from": "func_process_payment", + "to": "func_handle_stripe_webhook", + "call_count": 3 + } + ] + }, + "id": 102 +} + +// Response — live API hit, network blocked, stderr shows it +{ + "jsonrpc": "2.0", + "result": { + "exit_code": 1, + "stdout": "", + "stderr": "Error: connect ECONNREFUSED api.stripe.com:443", + "duration_ms": 312, + "calls_runtime_injected": [], + "network_blocks": [ + {"host": "api.stripe.com", "port": 443, "reason": "egress_blocked"} + ] + }, + "id": 102 +} +``` + +Agent reads `network_blocks`, writes a local Stripe mock, re-executes. No human needed. + +--- + +#### smp/sandbox/destroy — Tear down sandbox and release resources + +```json +// Request +{ + "jsonrpc": "2.0", + "method": "smp/sandbox/destroy", + "params": { + "sandbox_id": "box_99x", + "session_id": "ses_4f8a2c" + }, + "id": 103 +} + +// Response +{ + "jsonrpc": "2.0", + "result": { + "sandbox_id": "box_99x", + "status": "destroyed", + "destroyed_at": "2025-02-15T10:50:00Z", + "resources_freed": { + "filesystem_mb": 240, + "services_stopped": ["postgres:15", "redis:7"] + } + }, + "id": 103 +} +``` + +--- + +#### smp/verify/integrity — AST data-flow + mutation testing gate + +The final gate before handoff. Two checks run in sequence: + +**1. AST Data-Flow Check** — parses the test file's AST and confirms there is a data-flow edge from the tested function's *output* into a formal `assert()` / `expect()` call. Catches vacuous tests that call the function but assert nothing about its result. + +**2. Mutation Testing** — deterministically flips operators in the source file (`<` → `>`, `===` → `!==`, `+1` → `-1`) and re-runs the test suite. If any mutant survives (tests still pass), the assertions are too loose. Gate rejects and returns the surviving mutant so the agent can tighten the test. 
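+
+The mutation step is deliberately dumb and deterministic. A toy version of the loop — a real gate delegates to Stryker or mutmut (Part 8), which mutate the AST rather than raw text as this sketch does; `run_tests` is assumed to execute the suite in the sandbox and return True when all tests pass:
+
+```python
+OPERATOR_FLIPS = [("===", "!=="), ("==", "!="), ("<", ">"), (">", "<"), ("+ 1", "- 1")]
+
+def surviving_mutants(source: str, run_tests) -> list[dict]:
+    survivors = []
+    lines = source.splitlines()
+    for line_no, line in enumerate(lines, start=1):
+        for original_op, mutated_op in OPERATOR_FLIPS:
+            if original_op not in line:
+                continue
+            # Flip exactly one operator occurrence, leave everything else intact.
+            mutated = lines.copy()
+            mutated[line_no - 1] = line.replace(original_op, mutated_op, 1)
+            if run_tests("\n".join(mutated)):  # tests still green → mutant survived
+                survivors.append({
+                    "line": line_no,
+                    "original_op": original_op,
+                    "mutated_op": mutated_op,
+                })
+    return survivors
+```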
+ +```json +// Request +{ + "jsonrpc": "2.0", + "method": "smp/verify/integrity", + "params": { + "sandbox_id": "box_99x", + "target_file": "src/auth/login.ts", + "test_file": "tests/auth/login.test.ts" + }, + "id": 104 +} + +// Response — passed both gates +{ + "jsonrpc": "2.0", + "result": { + "status": "passed", + "coverage_delta_pct": +14, + "ast_assert_check": "passed", + "mutation_score": 1.0, + "mutants_total": 8, + "mutants_killed": 8, + "mutants_survived": 0 + }, + "id": 104 +} + +// Response — surviving mutant detected +{ + "jsonrpc": "2.0", + "result": { + "status": "failed", + "failure_code": "SURVIVING_MUTANT", + "ast_assert_check": "passed", + "mutation_score": 0.75, + "mutants_total": 8, + "mutants_killed": 6, + "mutants_survived": 2, + "survivors": [ + { + "node_id": "func_authenticate_user", + "file": "src/auth/login.ts", + "line": 31, + "original_op": "===", + "mutated_op": "!==", + "surviving_test": "tests/auth/login.test.ts:44" + } + ] + }, + "id": 104 +} + +// Response — no assert connected to function output +{ + "jsonrpc": "2.0", + "result": { + "status": "failed", + "failure_code": "MISSING_AST_ASSERT", + "ast_assert_check": "failed", + "missing_assert_for": [ + {"node_id": "func_authenticate_user", "output_type": "Promise"} + ] + }, + "id": 104 +} +``` + +--- + +## Part 7: Swarm Handoff + +Once a coder agent passes `smp/verify/integrity`, it hands off to a reviewer agent or files a PR directly. The PR carries the full structural diff and execution log — not just code diffs. + +--- + +#### smp/handoff/review — Pass sandbox to a reviewer agent + +```json +// Request +{ + "jsonrpc": "2.0", + "method": "smp/handoff/review", + "params": { + "sandbox_id": "box_99x", + "session_id": "ses_4f8a2c", + "reviewer_agent": "reviewer_agent_02", + "verify_result_id": "ver_8b2d1e" + }, + "id": 105 +} + +// Response +{ + "jsonrpc": "2.0", + "result": { + "handoff_id": "hnd_7f3a9c", + "reviewer_agent": "reviewer_agent_02", + "sandbox_id": "box_99x", + "status": "pending_review", + "reviewer_session": "ses_rv_5d2f1a" + }, + "id": 105 +} +``` + +--- + +#### smp/handoff/pr — Package verified work as a Pull Request + +Called after peer review passes, or directly if no reviewer agent is configured. Compiles the structural diff, runtime edges discovered, test results, and mutation score into a standard GitHub/GitLab PR payload. 
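+
+Assembling the PR body is plain templating over data SMP already holds — no generation involved. A sketch; the markdown layout is illustrative, and the field names follow the response shown below:
+
+```python
+def build_pr_body(structural_diff: dict, test_summary: dict,
+                  runtime_edges: list[dict], issue_refs: list[str]) -> str:
+    edge_lines = "\n".join(
+        f"- `{edge['from']}` → `{edge['to']}` (observed at runtime)"
+        for edge in runtime_edges
+    ) or "- none"
+    return (
+        f"### Structural diff\n"
+        f"- nodes added: {structural_diff['nodes_added']}, "
+        f"modified: {structural_diff['nodes_modified']}\n"
+        f"- signature changed: {structural_diff['signature_changed']}\n\n"
+        f"### Runtime edges discovered\n{edge_lines}\n\n"
+        f"### Verification\n"
+        f"- coverage delta: {test_summary['coverage_delta_pct']:+d}%\n"
+        f"- mutation score: {test_summary['mutation_score']:.2f}\n\n"
+        f"Refs: {', '.join(issue_refs)}"
+    )
+```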
+ +```json +// Request +{ + "jsonrpc": "2.0", + "method": "smp/handoff/pr", + "params": { + "sandbox_id": "box_99x", + "session_id": "ses_4f8a2c", + "base_sha": "a1b2c3d4", + "title": "fix: rate limiting logic in auth module", + "issue_refs": ["#42"], + "include": { + "structural_diff": true, + "runtime_edges": true, + "mutation_score": true, + "execution_log": true + } + }, + "id": 106 +} + +// Response +{ + "jsonrpc": "2.0", + "result": { + "pr_id": "pr_gh_1847", + "status": "open", + "base_sha": "a1b2c3d4", + "head_sha": "f9e3c2b1", + "files_changed": ["src/auth/login.ts", "src/middleware/rateLimit.ts"], + "structural_diff": { + "nodes_added": 1, + "nodes_modified": 1, + "signature_changed": false, + "calls_runtime_added": 1 + }, + "test_summary": { + "coverage_delta_pct": +14, + "mutation_score": 1.0 + } + }, + "id": 106 +} +``` + +--- + +#### Swarm Notifications (Server → Agent) + +```json +// Sandbox network block — live API call attempted and blocked +{ + "jsonrpc": "2.0", + "method": "smp/notification", + "params": { + "type": "network_blocked", + "severity": "info", + "sandbox_id": "box_99x", + "blocked_host": "api.stripe.com", + "blocked_port": 443 + } +} +``` + +```json +// Handoff ready — reviewer agent has accepted the sandbox +{ + "jsonrpc": "2.0", + "method": "smp/notification", + "params": { + "type": "handoff_accepted", + "severity": "info", + "handoff_id": "hnd_7f3a9c", + "reviewer_agent": "reviewer_agent_02", + "sandbox_id": "box_99x" + } +} +``` + +--- + +## Part 8: Implementation Stack + +### Recommended Technologies + +| Component | Technology | Why | +|-----------|------------|-----| +| **Parser** | Tree-sitter | Multi-language, incremental, fast | +| **Graph DB** | Neo4j / Memgraph | Native graph queries, BM25 full-text index, GDS PageRank, persists sessions + telemetry + CALLS_RUNTIME | +| **Graph DB (lightweight)** | SQLite + recursive CTEs | Single-machine or embedded use | +| **Vector Index** | ChromaDB | code_embedding per node — seed discovery for smp/locate only | +| **Merkle Index** | SHA-256 tree (built in-process) | O(log n) incremental sync — no full re-index; enables secure index distribution | +| **Sandbox Runtime** | Docker / Firecracker microVMs | Ephemeral, CoW filesystem, hard egress firewall | +| **Container Topology** | Testcontainers | Spin up local Postgres, Redis, etc. per sandbox | +| **Runtime Tracing** | eBPF daemon (BCC / libbpf) | Kernel-level call capture — zero app instrumentation needed | +| **Mutation Testing** | Stryker (JS/TS) / mutmut (Python) | Deterministic, no LLM, kills tautological tests | +| **Data Models** | msgspec | Zero-copy, schema-validated structs for internal data flow | +| **Protocol** | JSON-RPC 2.0 | Standard, simple, MCP-compatible | +| **Language** | Python (prototype) → Rust (production) | Start fast, optimize later | + +--- + +### File Structure + +The protocol router uses a **Dispatcher Pattern** — each method group lives in its own handler module with a `@rpc_method` decorator. No god-file `if/elif` chain. 
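+
+One wiring detail the snippets below leave implicit: `@rpc_method` only registers a handler when its module is imported, so the server entry point (`main.py` in the tree below) has to import every handler module once. A hypothetical sketch of that wiring:
+
+```python
+# server/main.py (illustrative): importing the handler modules runs their
+# @rpc_method decorators, which fills the dispatcher registry.
+from protocol.handlers import (
+    memory, index, community, query, enrichment, safety,
+    planning, sandbox, verify, handoff, telemetry,
+)
+```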
+
+```
+structural-memory/
+├── server/
+│   ├── core/
+│   │   ├── parser.py             # AST extraction (Tree-sitter)
+│   │   ├── graph_builder.py      # Build structural graph
+│   │   ├── linker.py             # Static namespaced CALLS resolution
+│   │   ├── linker_runtime.py     # eBPF trace ingestion → CALLS_RUNTIME edges
+│   │   ├── enricher.py           # Static metadata extraction
+│   │   ├── merkle.py             # Merkle tree builder + hash comparator + smp/sync logic
+│   │   ├── index_distributor.py  # smp/index/export + import + signature verification
+│   │   ├── community.py          # Louvain detection + centroid computation + MEMBER_OF writes
+│   │   ├── telemetry.py          # Hot node tracking + heat scores
+│   │   ├── store.py              # Graph DB interface + full-text index + PageRank setup
+│   │   └── chroma_index.py       # ChromaDB collection management + code_embedding writes
+│   ├── engine/
+│   │   ├── navigator.py          # Graph traversal (navigate, trace, flow, why)
+│   │   ├── reasoner.py           # Proactive context + summary computation
+│   │   ├── seed_walk.py          # SeedWalkEngine: Seed & Walk pipeline for smp/locate
+│   │   └── guard.py              # Guard checks, dry run, test-gap analysis
+│   ├── sandbox/
+│   │   ├── spawner.py            # Docker / Firecracker microVM lifecycle
+│   │   ├── executor.py           # Command runner + stdout/stderr capture
+│   │   ├── ebpf_collector.py     # eBPF daemon interface + trace → graph edges
+│   │   ├── network_policy.py     # Egress firewall rules + block notifications
+│   │   └── verifier.py           # AST data-flow check + mutation test runner
+│   ├── protocol/
+│   │   ├── dispatcher.py         # @rpc_method decorator + method registry
+│   │   └── handlers/
+│   │       ├── memory.py         # smp/update, batch_update, sync, merkle/tree
+│   │       ├── index.py          # smp/index/export, import
+│   │       ├── community.py      # smp/community/detect, list, get
+│   │       ├── query.py          # smp/navigate, trace, context, impact, locate, flow, diff, why
+│   │       ├── enrichment.py     # smp/enrich, annotate, tag, search
+│   │       ├── safety.py         # smp/session/*, guard/check, dryrun, checkpoint, lock, audit
+│   │       ├── planning.py       # smp/plan, conflict
+│   │       ├── sandbox.py        # smp/sandbox/spawn, execute, destroy
+│   │       ├── verify.py         # smp/verify/integrity
+│   │       ├── handoff.py        # smp/handoff/review, pr
+│   │       └── telemetry.py      # smp/telemetry/*
+│   └── main.py                   # Server entry point + full-text index init
+├── clients/
+│   ├── python_client.py          # Python SDK for agents
+│   ├── typescript_client.ts      # TS SDK for agents
+│   └── cli.py                    # Manual interaction
+├── watchers/
+│   ├── file_watcher.py           # Watch for file changes
+│   └── git_hook.py               # Git-based updates
+└── tests/
+    └── ...
+```
+
+**Dispatcher pattern:**
+
+```python
+# protocol/dispatcher.py
+from typing import Callable
+
+# ServerContext and MethodNotFound are defined in the server core (omitted here).
+_registry: dict[str, Callable] = {}
+
+def rpc_method(name: str):
+    def decorator(fn):
+        _registry[name] = fn
+        return fn
+    return decorator
+
+def dispatch(method: str, params: dict, context: ServerContext):
+    handler = _registry.get(method)
+    if not handler:
+        raise MethodNotFound(method)
+    return handler(params, context)
+```
+
+```python
+# protocol/handlers/query.py
+from protocol.dispatcher import rpc_method
+
+@rpc_method("smp/navigate")
+def handle_navigate(params, ctx):
+    return ctx.engine.navigator.navigate(params["query"], params.get("include_relationships", False))
+
+@rpc_method("smp/trace")
+def handle_trace(params, ctx):
+    return ctx.engine.navigator.trace(params["start"], params["relationship"], params.get("depth", 3))
+```
+
+---
+
+## Part 9: Agent Integration Example
+
+### Agent Workflow with SMP
+
+```python
+class CodingAgent:
+    def __init__(self, smp_client, llm_client, agent_id):
+        self.smp = smp_client
+        self.llm = llm_client      # used for the edit-generation step below
+        self.agent_id = agent_id
+
+    def edit_file(self, file_path, instruction):
+        # 1. Open a session — declare scope upfront
+        session = self.smp.call("smp/session/open", {
+            "agent_id": self.agent_id,
+            "task": instruction,
+            "scope": [file_path],
+            "mode": "write"
+        })
+
+        # 2. Pre-flight guard check
+        guard = self.smp.call("smp/guard/check", {
+            "session_id": session["session_id"],
+            "target": file_path
+        })
+        if guard["verdict"] == "blocked":
+            raise AbortError(guard["reasons"])
+
+        # 3. Get full structural context
+        context = self.smp.call("smp/context", {
+            "file_path": file_path,
+            "scope": "edit"
+        })
+
+        # 4. Generate the edit with the LLM, grounded in the structural context
+        new_code = self.llm.edit(
+            current_code=context["self"]["content"],
+            instruction=instruction,
+            context=context
+        )
+
+        # 5. Dry run the proposed change
+        dryrun = self.smp.call("smp/dryrun", {
+            "session_id": session["session_id"],
+            "file_path": file_path,
+            "proposed_content": new_code,
+        })
+        if dryrun["verdict"] == "breaking":
+            raise AbortError(dryrun["risks"])
+
+        # 6. Checkpoint, write, sync memory
+        self.smp.call("smp/checkpoint", {"session_id": session["session_id"], "files": [file_path]})
+        write_to_disk(file_path, new_code)
+        self.smp.call("smp/update", {"file_path": file_path, "content": new_code, "change_type": "modified"})
+
+        # 7. Close session
+        self.smp.call("smp/session/close", {"session_id": session["session_id"], "status": "completed"})
+```
+
+---
+
+## Summary
+
+| Component | Purpose |
+|-----------|---------|
+| **Parser** | Extract AST from code (Tree-sitter) |
+| **Graph Builder** | Create structural relationships |
+| **Static Linker** | Namespace-aware cross-file CALLS resolution — no ambiguous edges |
+| **Runtime Linker** | eBPF execution traces → `CALLS_RUNTIME` edges — resolves DI and metaprogramming |
+| **Enricher** | Attach static metadata — docstrings, annotations, tags, code_embedding |
+| **Graph DB** | Neo4j — structure, `CALLS_STATIC`, `CALLS_RUNTIME`, PageRank, sessions, telemetry, BM25 index |
+| **Vector Index** | ChromaDB — `code_embedding` per node for Seed phase of `smp/locate` |
+| **Merkle Index** | SHA-256 tree over all file nodes — O(log n) incremental sync, powers `smp/sync` + secure index distribution |
+| **SeedWalkEngine** | `smp/locate` pipeline: Vector seed → Cypher N-hop walk → composite rank → structural_map |
+| **Query Engine** | navigate, trace, context (+summary), impact, locate, flow, diff, plan, conflict, why |
+| **SMP Protocol** | JSON-RPC 2.0 via Dispatcher — handlers split by domain, no god file |
+| **Agent Safety** | Sessions (persisted, MVCC or exclusive), guard checks, dry runs, checkpoints, audit log |
+| **Telemetry** | Hot node tracking, heat scores, automatic safety escalation |
+| **Community Detection** | Two-level Louvain (coarse + fine) — powers Graph RAG routing, `smp/community/boundaries` for architecture agents |
+| **Sandbox Runtime** | Ephemeral microVM/Docker, CoW filesystem, hard egress firewall, eBPF trace capture |
+| **Integrity Gate** | AST data-flow check + deterministic mutation testing — anti-gamification, no LLM |
+| **Swarm Handoff** | Peer review pass-off + structured PR with structural diff, runtime edges, mutation score |
+
+---
+
diff --git a/smp.md b/smp.md
new file mode 100644
index 0000000..31d614a
--- /dev/null
+++ b/smp.md
@@ -0,0 +1,722 @@
+# The Structural Memory Protocol (SMP)
+
+A framework for giving AI agents a "programmer's brain" — not text retrieval, but structural understanding.
+ +--- + +## Architecture Overview + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ CODEBASE (Files) │ +└──────────────────────────┬──────────────────────────────────────┘ + │ Updates (Watch / Agent Push) + ▼ +┌─────────────────────────────────────────────────────────────────┐ +│ MEMORY SERVER (SMP Core) │ +│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ +│ │ PARSER │─▶│ GRAPH BUILDER│──▶│ ENRICHER │ │ +│ │ (AST/Tree- │ │ (Structural │ │ (Semantic │ │ +│ │ sitter) │ │ Analysis) │ │ Layer) │ │ +│ └─────────────┘ └─────────────┘ └──────┬──────┘ │ +│ │ │ +│ ┌───────────────────────▼──────────────┐ │ +│ │ MEMORY STORE │ │ +│ │ ┌─────────────┐ ┌──────────────┐ │ │ +│ │ │ GRAPH DB │ │ VECTOR STORE │ │ │ +│ │ │ (Structure) │ │ (Purpose) │ │ │ +│ │ └─────────────┘ └──────────────┘ │ │ +│ └───────────────────────┬──────────────┘ │ +└─────────────────────────────────────────────┼──────────────────-┘ + │ + ┌─────────────────────────▼──────────────────┐ + │ QUERY ENGINE (SMP Interface) │ + │ ┌────────────┐ ┌────────────┐ │ + │ │ Navigator │ │ Reasoner │ │ + │ │ (Graph │ │ (Proactive │ │ + │ │ Traversal)│ │ Context) │ │ + │ └────────────┘ └────────────┘ │ + └───────────────────────┬────────────────────┘ + │ SMP Protocol + ▼ + ┌─────────────────────────────────────────────┐ + │ AGENT LAYER │ + │ Agent A Agent B Agent C │ + │ (Coder) (Reviewer) (Architect) │ + └─────────────────────────────────────────────┘ +``` + +--- + +## Part 1: The Memory Server + +### A. Parser (AST Extraction) + +**Technology:** Tree-sitter (multi-language, fast, incremental) + +**Input:** File path + content + +**Output:** Abstract Syntax Tree with typed nodes + +```python +# What gets extracted per file +{ + "file_path": "src/auth/login.ts", + "language": "typescript", + "nodes": [ + { + "id": "func_authenticate_user", + "type": "function_declaration", + "name": "authenticateUser", + "start_line": 15, + "end_line": 42, + "signature": "authenticateUser(email: string, password: string): Promise", + "docstring": "Validates user credentials and returns JWT...", + "modifiers": ["async", "export"] + }, + { + "id": "class_AuthService", + "type": "class_declaration", + "name": "AuthService", + "methods": ["login", "logout", "refresh"], + "properties": ["tokenExpiry", "secretKey"] + } + ], + "imports": [ + {"from": "./utils/crypto", "items": ["hashPassword", "compareHash"]}, + {"from": "../db/user", "items": ["UserModel"]} + ], + "exports": ["authenticateUser", "AuthService"] +} +``` + +--- + +### B. 
Graph Builder (Structural Analysis) + +**Graph Schema:** + +``` +┌─────────────────────────────────────────────────────────────┐ +│ NODE TYPES │ +├─────────────────────────────────────────────────────────────┤ +│ Repository │ Root node │ +│ Package │ Directory/module │ +│ File │ Source file │ +│ Class │ Class definition │ +│ Function │ Function/method │ +│ Variable │ Variable/constant │ +│ Interface │ Type definition/interface │ +│ Test │ Test file/function │ +│ Config │ Configuration file │ +└─────────────────────────────────────────────────────────────┘ + +┌─────────────────────────────────────────────────────────────┐ +│ RELATIONSHIP TYPES │ +├─────────────────────────────────────────────────────────────┤ +│ CONTAINS │ Parent-child (Package → File) │ +│ IMPORTS │ File imports File/Module │ +│ DEFINES │ File defines Class/Function │ +│ CALLS │ Function calls Function │ +│ INHERITS │ Class inherits Class │ +│ IMPLEMENTS │ Class implements Interface │ +│ DEPENDS_ON │ General dependency │ +│ TESTS │ Test tests Function/Class │ +│ USES │ Function uses Variable/Type │ +│ REFERENCES │ Variable references Variable │ +└─────────────────────────────────────────────────────────────┘ +``` + +**Example Graph Node:** + +```json +{ + "id": "func_authenticate_user", + "type": "Function", + "name": "authenticateUser", + "file": "src/auth/login.ts", + "signature": "authenticateUser(email: string, password: string): Promise", + "metrics": { + "complexity": 4, + "lines": 28, + "parameters": 2 + }, + "relationships": { + "CALLS": ["func_hashPassword", "func_compareHash", "func_generateToken"], + "DEPENDS_ON": ["class_UserModel"], + "DEFINED_IN": "file_auth_login_ts" + } +} +``` + +--- + +### C. Semantic Enricher + +**Purpose:** Add meaning to structural nodes. + +**Process:** + +1. **Static Analysis (No LLM needed):** + - Extract docstrings + - Parse comments + - Infer from naming conventions (`getUserById` → "retrieves user by identifier") + - Extract type information + +2. **LLM Enrichment (One-time per node):** + ``` + Prompt: "In 1 sentence, what is the PURPOSE of this code in the system?" + + Input: + - Function signature + - Docstring + - Dependencies + - Called-by relationships + + Output: + "Handles user authentication by validating credentials against the database + and issuing JWT tokens for session management." + ``` + +3. **Embedding Generation:** + - Embed the purpose + signature + key context + - Store in vector database for similarity search + +**Enriched Node:** + +```json +{ + "id": "func_authenticate_user", + "structural": { ... }, + "semantic": { + "purpose": "Handles user authentication by validating credentials against the database and issuing JWT tokens for session management", + "keywords": ["auth", "login", "jwt", "credentials", "validation"], + "embedding": [0.123, -0.456, ...], + "last_enriched": "2025-02-15T10:30:00Z", + "confidence": 0.92 + } +} +``` + +--- + +## Part 2: The Query Engine + +### Query Types + +| Type | Purpose | Example | +|------|---------|---------| +| **Navigate** | Find specific entities | "Where is `login` defined?" | +| **Trace** | Follow relationships | "What calls `authenticateUser`?" | +| **Context** | Get relevant context | "I'm editing `auth.ts`, what do I need to know?" | +| **Impact** | Assess change impact | "If I delete this, what breaks?" | +| **Locate** | Find by description | "Where is user registration handled?" | +| **Flow** | Trace data/logic path | "How does a request become a DB entry?" 
| + +--- + +### Query Engine Implementation + +```python +class StructuralQueryEngine: + def __init__(self, graph_db, vector_store): + self.graph = graph_db + self.vectors = vector_store + + def navigate(self, entity_name: str, direction: str = "to"): + """Find entity and its relationships""" + pass + + def trace(self, start_id: str, relationship_type: str, depth: int = 3): + """Follow relationship chain""" + pass + + def get_context(self, file_path: str, scope: str = "edit"): + """ + Proactive context gathering. + + scope options: + - "edit": What do I need to edit this file safely? + - "create": What pattern should I follow for new file? + - "debug": What's the data flow through this file? + """ + pass + + def assess_impact(self, entity_id: str, change_type: str): + """What would break if I change/delete this?""" + pass + + def locate_by_intent(self, description: str): + """Find code by what it does, not its name""" + # Vector search on semantic embeddings + # Return ranked structural matches + pass + + def trace_flow(self, start: str, end: str = None): + """Trace execution/data flow""" + pass +``` + +--- + +### The `get_context()` Method (Most Important for Agents) + +```python +def get_context(self, file_path: str, scope: str = "edit"): + """ + Returns the "programmer's mental model" for a file. + """ + file_node = self.graph.get_node_by_path(file_path) + + context = { + "self": file_node, # What is this file? + + "imports": self.graph.get_relationships( + file_node, "IMPORTS", direction="outgoing" + ), # What does it depend on? + + "imported_by": self.graph.get_relationships( + file_node, "IMPORTS", direction="incoming" + ), # Who depends on it? + + "defines": self.graph.get_relationships( + file_node, "DEFINES", direction="outgoing" + ), # What's inside? + + "related_patterns": self.vectors.find_similar( + file_node.semantic.embedding, top_k=5 + ), # Similar files (pattern reference) + + "entry_points": self.graph.find_entry_points(file_node), + + "data_flow_in": self.trace_data_flow(file_node, direction="in"), + + "data_flow_out": self.trace_data_flow(file_node, direction="out"), + } + + return context +``` + +--- + +## Part 3: The Protocol (SMP) + +### Protocol Specification + +**Name:** Structural Memory Protocol (SMP) +**Version:** 1.0 +**Transport:** JSON-RPC 2.0 over stdio / HTTP / WebSocket +**Inspired by:** MCP (Model Context Protocol), A2A (Agent-to-Agent) + +--- + +### Protocol Methods + +#### 1. Memory Management + +```json +// smp/update - Sync codebase state +{ + "jsonrpc": "2.0", + "method": "smp/update", + "params": { + "type": "file_change", + "file_path": "src/auth/login.ts", + "content": "...", + "change_type": "modified" | "created" | "deleted" + }, + "id": 1 +} + +// Response +{ + "jsonrpc": "2.0", + "result": { + "status": "success", + "nodes_added": 3, + "nodes_updated": 12, + "nodes_removed": 1, + "relationships_updated": 8 + }, + "id": 1 +} +``` + +```json +// smp/batch_update - Multiple files at once +{ + "jsonrpc": "2.0", + "method": "smp/batch_update", + "params": { + "changes": [ + {"file_path": "src/auth/login.ts", "content": "...", "change_type": "modified"}, + {"file_path": "src/auth/middleware.ts", "content": "...", "change_type": "created"} + ] + }, + "id": 2 +} +``` + +```json +// smp/reindex - Full re-index (for major refactors) +{ + "jsonrpc": "2.0", + "method": "smp/reindex", + "params": { + "scope": "full" | "package:src/auth" + }, + "id": 3 +} +``` + +--- + +#### 2. 
Structural Queries + +```json +// smp/navigate - Find entity and basic info +{ + "jsonrpc": "2.0", + "method": "smp/navigate", + "params": { + "query": "authenticateUser", + "include_relationships": true + }, + "id": 4 +} + +// Response +{ + "jsonrpc": "2.0", + "result": { + "entity": { + "id": "func_authenticate_user", + "type": "Function", + "file": "src/auth/login.ts", + "signature": "authenticateUser(email: string, password: string): Promise", + "purpose": "Handles user authentication..." + }, + "relationships": { + "calls": ["hashPassword", "compareHash", "generateToken"], + "called_by": ["loginRoute", "test_auth"], + "depends_on": ["UserModel", "TokenService"] + } + }, + "id": 4 +} +``` + +```json +// smp/trace - Follow relationship chain +{ + "jsonrpc": "2.0", + "method": "smp/trace", + "params": { + "start": "func_authenticate_user", + "relationship": "CALLS", + "depth": 3, + "direction": "outgoing" + }, + "id": 5 +} + +// Response: Returns the call graph as a tree +{ + "jsonrpc": "2.0", + "result": { + "root": "authenticateUser", + "tree": { + "authenticateUser": { + "calls": { + "hashPassword": {"calls": {"bcrypt.hash": {}}}, + "compareHash": {"calls": {"bcrypt.compare": {}}}, + "generateToken": {"calls": {"jwt.sign": {}}} + } + } + } + }, + "id": 5 +} +``` + +--- + +#### 3. Context Queries (Proactive) + +```json +// smp/context - Get editing context +{ + "jsonrpc": "2.0", + "method": "smp/context", + "params": { + "file_path": "src/auth/login.ts", + "scope": "edit", // "edit" | "create" | "debug" | "review" + "depth": 2 + }, + "id": 6 +} + +// Response: Full context needed to edit this file safely +{ + "jsonrpc": "2.0", + "result": { + "self": {...}, + "imports": [...], + "imported_by": [...], + "functions_defined": [...], + "classes_defined": [...], + "tests": ["tests/auth.test.ts"], + "patterns": ["src/api/users.ts (similar structure)"], + "warnings": ["This file is imported by 12 other files"] + }, + "id": 6 +} +``` + +```json +// smp/impact - Assess change impact +{ + "jsonrpc": "2.0", + "method": "smp/impact", + "params": { + "entity": "func_authenticate_user", + "change_type": "signature_change" | "delete" | "move" + }, + "id": 7 +} + +// Response +{ + "jsonrpc": "2.0", + "result": { + "affected_files": [ + "src/api/routes.ts", + "tests/auth.test.ts", + "src/middleware/auth.ts" + ], + "affected_functions": ["loginRoute", "test_authenticate_user"], + "severity": "high", + "recommendations": [ + "Update loginRoute in routes.ts to match new signature", + "Update test cases in auth.test.ts" + ] + }, + "id": 7 +} +``` + +--- + +#### 4. Semantic Search + +```json +// smp/locate - Find by description +{ + "jsonrpc": "2.0", + "method": "smp/locate", + "params": { + "description": "where is user registration handled", + "top_k": 5 + }, + "id": 8 +} + +// Response +{ + "jsonrpc": "2.0", + "result": { + "matches": [ + { + "entity": "func_register_user", + "file": "src/auth/register.ts", + "purpose": "Handles new user registration...", + "relevance": 0.94 + }, + { + "entity": "class_UserService", + "file": "src/services/user.ts", + "purpose": "Manages user CRUD operations...", + "relevance": 0.87 + } + ] + }, + "id": 8 +} +``` + +--- + +#### 5. 
Flow Analysis + +```json +// smp/flow - Trace execution/data flow +{ + "jsonrpc": "2.0", + "method": "smp/flow", + "params": { + "start": "api_route_login", + "end": "database_write_user", + "flow_type": "data" | "execution" + }, + "id": 9 +} + +// Response +{ + "jsonrpc": "2.0", + "result": { + "path": [ + {"node": "api_route_login", "type": "endpoint"}, + {"node": "auth_middleware", "type": "middleware"}, + {"node": "authenticateUser", "type": "function"}, + {"node": "UserModel.findByEmail", "type": "method"}, + {"node": "generateToken", "type": "function"}, + {"node": "response_json", "type": "output"} + ], + "data_transformations": [ + "Request body → credentials object", + "Credentials → validated user", + "User → JWT token" + ] + }, + "id": 9 +} +``` + +--- + +### Event Notifications (Server → Agent) + +```json +// Notification: Memory updated +{ + "jsonrpc": "2.0", + "method": "smp/notification", + "params": { + "type": "memory_updated", + "changes": { + "files_affected": ["src/auth/login.ts"], + "structural_changes": ["func_authenticate_user modified"], + "semantic_changes": ["purpose re-enriched"] + } + } +} +``` + +```json +// Notification: Conflict detected +{ + "jsonrpc": "2.0", + "method": "smp/notification", + "params": { + "type": "conflict_detected", + "severity": "warning", + "message": "File modified by external process, memory may be stale", + "file": "src/auth/login.ts" + } +} +``` + +--- + +## Part 4: Implementation Stack + +### Recommended Technologies + +| Component | Technology | Why | +|-----------|------------|-----| +| **Parser** | Tree-sitter | Multi-language, incremental, fast | +| **Graph DB** | Neo4j / Memgraph / SQLite (if lightweight) | Native graph queries | +| **Vector Store** | Chroma / Qdrant / LanceDB | Semantic search | +| **Embedding** | OpenAI text-embedding-3-small | Good balance of speed/quality | +| **Protocol** | JSON-RPC 2.0 | Standard, simple, MCP-compatible | +| **Language** | Python (prototype) → Rust (production) | Start fast, optimize later | + +--- + +### File Structure + +``` +structural-memory/ +├── server/ +│ ├── core/ +│ │ ├── parser.py # AST extraction (Tree-sitter) +│ │ ├── graph_builder.py # Build structural graph +│ │ ├── enricher.py # Semantic enrichment +│ │ └── store.py # Graph + Vector store interface +│ ├── engine/ +│ │ ├── query.py # Query processing +│ │ ├── navigator.py # Graph traversal +│ │ └── reasoner.py # Proactive context +│ ├── protocol/ +│ │ ├── smp_handler.py # JSON-RPC handler +│ │ └── methods.py # SMP method implementations +│ └── main.py # Server entry point +├── clients/ +│ ├── python_client.py # Python SDK for agents +│ ├── typescript_client.ts # TS SDK for agents +│ └── cli.py # Manual interaction +├── watchers/ +│ ├── file_watcher.py # Watch for file changes +│ └── git_hook.py # Git-based updates +└── tests/ + └── ... +``` + +--- + +## Part 5: Agent Integration Example + +### Agent Workflow with SMP + +```python +class CodingAgent: + def __init__(self, smp_client): + self.smp = smp_client + + def edit_file(self, file_path, instruction): + # 1. Get structural context + context = self.smp.call("smp/context", { + "file_path": file_path, + "scope": "edit" + }) + + # 2. Understand impact + impact = self.smp.call("smp/impact", { + "entity": context["self"]["id"], + "change_type": "signature_change" + }) + + # 3. Make the edit (with context-aware prompt) + new_code = self.llm.edit( + current_code=context["self"]["content"], + instruction=instruction, + context=context, + warnings=impact + ) + + # 4. 
Update memory + self.smp.call("smp/update", { + "file_path": file_path, + "content": new_code, + "change_type": "modified" + }) + + return new_code +``` + +--- + +## Summary + +| Component | Purpose | +|-----------|---------| +| **Parser** | Extract AST from code (Tree-sitter) | +| **Graph Builder** | Create structural relationships | +| **Enricher** | Add semantic meaning to nodes | +| **Memory Store** | Graph DB + Vector Store | +| **Query Engine** | Navigate, trace, context, impact, locate, flow | +| **SMP Protocol** | JSON-RPC interface for agent communication | + +--- + diff --git a/smp/__init__.py b/smp/__init__.py new file mode 100644 index 0000000..7fe7acb --- /dev/null +++ b/smp/__init__.py @@ -0,0 +1,3 @@ +"""SMP — Structural Memory Protocol.""" + +__version__ = "0.1.0" diff --git a/smp/agent.py b/smp/agent.py new file mode 100644 index 0000000..4d4b6ac --- /dev/null +++ b/smp/agent.py @@ -0,0 +1,431 @@ +"""CodingAgent — AI coding agent powered by Structural Memory Protocol. + +Wraps :class:`SMPClient` into a six-step workflow that gathers structural +context, assesses change impact, asks an LLM to generate an edit, writes +the result to disk, and syncs the graph back. + +Usage:: + + from smp.agent import CodingAgent + from smp.client import SMPClient + + async with SMPClient("http://localhost:8420") as client: + agent = CodingAgent(client, zen_api_key="...") + result = await agent.run( + file_path="src/auth.py", + instruction="Add rate limiting to the login endpoint", + ) + print(result["summary"]) +""" + +from __future__ import annotations + +import os +import re +import time +from pathlib import Path +from typing import Any + +import msgspec + +from smp.client import SMPClient +from smp.logging import get_logger + +log = get_logger(__name__) + + +class AgentError(Exception): + """Raised when the agent cannot complete its workflow.""" + + +class AgentResult(msgspec.Struct): + """Outcome of a single :meth:`CodingAgent.run` invocation.""" + + file_path: str + instruction: str + original_content: str + edited_content: str + context: dict[str, Any] = msgspec.field(default_factory=dict) + impact: dict[str, Any] = msgspec.field(default_factory=dict) + summary: str = "" + nodes_synced: int = 0 + edges_synced: int = 0 + + +# --------------------------------------------------------------------------- +# Gemini LLM backend (lazy import, mirrors enricher pattern) +# --------------------------------------------------------------------------- + + +class _GeminiBackend: + """Wraps Google Gemini API for code-edit generation using Gemma 3.""" + + def __init__(self, api_key: str, model: str = "gemma-3-27b-it") -> None: + from google import genai + + self._client = genai.Client(api_key=api_key) + self._model = model + + def generate(self, system_prompt: str, user_prompt: str) -> str: + """Generate a response from the model.""" + response = self._client.models.generate_content( + model=self._model, + contents=f"{system_prompt}\n\n{user_prompt}", + ) + return str(response.text or "") + + +# --------------------------------------------------------------------------- +# CodingAgent +# --------------------------------------------------------------------------- + + +class CodingAgent: + """AI coding agent that uses SMP for structural awareness. + + The agent follows a six-step workflow: + + 1. **Context** — query ``smp/context`` for the file's mental model. + 2. **Impact** — query ``smp/impact`` for blast-radius analysis. + 3. **Generate** — send context + instruction to the LLM for an edit. + 4. 
**Write** — persist the edited file to disk. + 5. **Sync** — call ``smp/update`` so SMP re-parses the changed file. + + Args: + client: Connected :class:`SMPClient` instance. + gemini_api_key: Google Gemini API key. Falls back to GEMINI_API_KEY or GOOGLE_API_KEY env var. + model: Gemini model name (default: gemma-3-27b-it). + """ + + def __init__( + self, + client: SMPClient, + *, + gemini_api_key: str | None = None, + model: str = "gemma-3-27b-it", + ) -> None: + self._client = client + self._llm: _GeminiBackend | None = None + + key = gemini_api_key or os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY") + if key: + try: + self._llm = _GeminiBackend(api_key=key, model=model) + log.info("agent_llm_ready", model=model) + except Exception as exc: + log.warning("agent_llm_init_failed", error=str(exc)) + else: + log.warning("agent_no_llm", reason="no_api_key") + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + async def run(self, file_path: str, instruction: str) -> AgentResult: + """Execute the full agent workflow and return the result. + + Args: + file_path: Path to the source file to edit. + instruction: Natural-language description of the desired change. + + Returns: + An :class:`AgentResult` with before/after content and metadata. + + Raises: + AgentError: On unrecoverable failures (missing file, no LLM, etc.). + """ + workflow_id = f"wf_{int(time.monotonic() * 1000)}" + log.info( + "agent_workflow_start", + workflow_id=workflow_id, + file_path=file_path, + instruction=instruction[:120], + ) + + t_start = time.monotonic() + + # Step 1 — read the current file + original = await self._read_file(file_path) + log.info("agent_step_complete", step=1, label="read_file", workflow_id=workflow_id) + + # Step 2 — structural context + context = await self._step_context(file_path, workflow_id) + + # Step 3 — impact assessment + impact = await self._step_impact(file_path, context, workflow_id) + + # Step 4 — LLM edit generation + edited = await self._step_generate(file_path, instruction, original, context, impact, workflow_id) + + # Step 5 — write to disk + await self._step_write(file_path, edited, workflow_id) + + # Step 6 — sync back into structural memory + sync_result = await self._step_sync(file_path, edited, workflow_id) + + elapsed = round(time.monotonic() - t_start, 2) + nodes = sync_result.get("nodes", 0) + edges = sync_result.get("edges", 0) + + summary = f"Edited {file_path}: {instruction}. Synced {nodes} nodes, {edges} edges in {elapsed}s." 
+ + log.info( + "agent_workflow_complete", + workflow_id=workflow_id, + file_path=file_path, + elapsed_s=elapsed, + nodes=nodes, + edges=edges, + ) + + return AgentResult( + file_path=file_path, + instruction=instruction, + original_content=original, + edited_content=edited, + context=context, + impact=impact, + summary=summary, + nodes_synced=nodes, + edges_synced=edges, + ) + + # ------------------------------------------------------------------ + # Step implementations + # ------------------------------------------------------------------ + + async def _read_file(self, file_path: str) -> str: + """Read file content from disk.""" + log.info("agent_read_file", file_path=file_path) + path = Path(file_path) + if not path.exists(): + raise AgentError(f"File not found: {file_path}") + content = path.read_text(encoding="utf-8") + log.info("agent_file_read", file_path=file_path, size_bytes=len(content)) + return content + + async def _step_context(self, file_path: str, workflow_id: str) -> dict[str, Any]: + """Step 2 — query SMP for the file's structural context.""" + log.info("agent_step_start", step=2, label="context", workflow_id=workflow_id) + + ctx = await self._client.get_context(file_path, scope="edit", depth=2) + + node_count = len(ctx.get("nodes", [])) + edge_count = len(ctx.get("edges", [])) + types = self._summarise_node_types(ctx.get("nodes", [])) + + log.info( + "agent_context_ready", + workflow_id=workflow_id, + nodes=node_count, + edges=edge_count, + **types, + ) + return ctx + + async def _step_impact( + self, + file_path: str, + context: dict[str, Any], + workflow_id: str, + ) -> dict[str, Any]: + """Step 3 — assess the blast radius of modifying *file_path*.""" + log.info("agent_step_start", step=3, label="impact", workflow_id=workflow_id) + + nodes = context.get("nodes", []) + target_id = self._pick_impact_target(nodes, file_path) + + if not target_id: + log.info("agent_impact_skip", workflow_id=workflow_id, reason="no_entity_found") + return {"entity": None, "affected_nodes": [], "total_affected": 0} + + impact = await self._client.assess_impact(target_id, change_type="modify") + affected = impact.get("affected_nodes", []) + + # Build a concise summary of downstream effects + downstream = self._format_downstream(affected) + + log.info( + "agent_impact_assessed", + workflow_id=workflow_id, + entity=target_id, + affected_count=len(affected), + downstream=downstream[:8], + ) + return impact + + async def _step_generate( + self, + file_path: str, + instruction: str, + original: str, + context: dict[str, Any], + impact: dict[str, Any], + workflow_id: str, + ) -> str: + """Step 4 — ask the LLM to produce an edited version of the file.""" + log.info("agent_step_start", step=4, label="generate", workflow_id=workflow_id) + + if not self._llm: + raise AgentError("No LLM backend. 
Set GEMINI_API_KEY or GOOGLE_API_KEY to enable edit generation.") + + system_prompt = self._build_system_prompt() + user_prompt = self._build_user_prompt( + file_path=file_path, + instruction=instruction, + original=original, + context=context, + impact=impact, + ) + + log.info("agent_llm_call", workflow_id=workflow_id, model=self._llm._model) + raw = self._llm.generate(system_prompt, user_prompt) + edited = self._extract_code(raw) + + log.info( + "agent_llm_response", + workflow_id=workflow_id, + raw_chars=len(raw), + edited_chars=len(edited), + ) + return edited + + async def _step_write(self, file_path: str, content: str, workflow_id: str) -> None: + """Step 5 — write the edited content to disk.""" + log.info("agent_step_start", step=5, label="write", workflow_id=workflow_id) + + path = Path(file_path) + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(content, encoding="utf-8") + + log.info("agent_file_written", file_path=file_path, size_bytes=len(content)) + + async def _step_sync(self, file_path: str, content: str, workflow_id: str) -> dict[str, Any]: + """Step 6 — push the changed file back into the structural memory.""" + log.info("agent_step_start", step=6, label="sync", workflow_id=workflow_id) + + result = await self._client.update(file_path, content=content) + + log.info( + "agent_sync_complete", + workflow_id=workflow_id, + file_path=file_path, + nodes=result.get("nodes", 0), + edges=result.get("edges", 0), + enriched=result.get("enriched", 0), + errors=result.get("errors", 0), + ) + return result + + # ------------------------------------------------------------------ + # Prompt construction + # ------------------------------------------------------------------ + + @staticmethod + def _build_system_prompt() -> str: + return ( + "You are an expert software engineer. You will receive a source file, " + "its structural context (classes, functions, imports, relationships), " + "and an instruction for how to modify it.\n\n" + "Rules:\n" + "- Return ONLY the complete modified file content.\n" + "- Do NOT wrap in markdown code fences.\n" + "- Do NOT add explanations before or after the code.\n" + "- Preserve existing style, conventions, and imports.\n" + "- Only change what the instruction requires.\n" + "- Ensure the result is syntactically valid." + ) + + @staticmethod + def _build_user_prompt( + *, + file_path: str, + instruction: str, + original: str, + context: dict[str, Any], + impact: dict[str, Any], + ) -> str: + parts: list[str] = [] + + parts.append(f"## File: {file_path}") + parts.append(f"## Instruction\n{instruction}") + + # Context block + nodes = context.get("nodes", []) + edges = context.get("edges", []) + if nodes: + parts.append("## Structural Context") + for n in nodes[:30]: + sem = n.get("semantic") + purpose = f" — {sem['purpose']}" if sem and sem.get("purpose") else "" + parts.append(f" - {n['type']} {n['name']} (L{n['start_line']}-{n['end_line']}){purpose}") + if edges: + parts.append(f" ({len(edges)} relationships)") + + # Impact block + affected = impact.get("affected_nodes", []) + if affected: + parts.append(f"## Impact Analysis — {len(affected)} downstream entities affected") + for a in affected[:10]: + parts.append(f" - {a['type']} {a['name']} in {a['file_path']}") + if len(affected) > 10: + parts.append(f" ... 
and {len(affected) - 10} more") + + # Original source + parts.append(f"## Current Source\n```\n{original}\n```") + parts.append("## Modified Source") + + return "\n\n".join(parts) + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + @staticmethod + def _extract_code(llm_response: str) -> str: + """Extract source code from an LLM response. + + Handles responses wrapped in markdown code fences as well as raw code. + """ + fenced: list[str] = re.findall(r"```(?:\w*)\n(.*?)```", llm_response, re.DOTALL) + if fenced: + return str(fenced[0].strip()) + # No fences — strip common LLM preamble lines + lines = llm_response.split("\n") + start = 0 + for i, line in enumerate(lines): + stripped = line.strip() + if stripped and not stripped.startswith("#") and not stripped.startswith("//"): + start = i + break + return "\n".join(lines[start:]).strip() + + @staticmethod + def _summarise_node_types(nodes: list[dict[str, Any]]) -> dict[str, int]: + """Count nodes by type for structured log output.""" + counts: dict[str, int] = {} + for n in nodes: + t = n.get("type", "UNKNOWN") + counts[t] = counts.get(t, 0) + 1 + return counts + + @staticmethod + def _pick_impact_target(nodes: list[dict[str, Any]], file_path: str) -> str | None: + """Choose the best entity for impact analysis. + + Prefers the first FUNCTION or CLASS node; falls back to the FILE node. + """ + file_node_id: str | None = None + for n in nodes: + ntype = str(n.get("type", "")) + nid = str(n.get("id", "")) + if ntype in ("FUNCTION", "CLASS"): + return nid + if ntype == "FILE" and not file_node_id: + file_node_id = nid + return file_node_id + + @staticmethod + def _format_downstream(affected: list[dict[str, Any]]) -> list[str]: + """Format affected nodes into compact summary strings.""" + return [f"{a.get('type', '?')} {a.get('name', '?')} @ {a.get('file_path', '?')}" for a in affected] diff --git a/smp/cli.py b/smp/cli.py new file mode 100644 index 0000000..46af79c --- /dev/null +++ b/smp/cli.py @@ -0,0 +1,301 @@ +from __future__ import annotations + +import argparse +import asyncio +import os +import sys +import time +from pathlib import Path + +from dotenv import load_dotenv + +from smp.logging import configure_logging, get_logger + +load_dotenv(Path(__file__).parent.parent / ".env") + +log = get_logger(__name__) + +DEFAULT_EXTENSIONS = (".py", ".ts", ".tsx", ".js", ".jsx") +DEFAULT_MAX_FILE_SIZE = 1_000_000 + + +async def ingest_directory( + directory: str, + *, + neo4j_uri: str | None = None, + neo4j_user: str | None = None, + neo4j_password: str | None = None, + extensions: tuple[str, ...] 
= DEFAULT_EXTENSIONS, + max_file_size: int = DEFAULT_MAX_FILE_SIZE, + clear: bool = False, +) -> dict[str, int]: + """Walk *directory*, parse all matching files, and build the graph.""" + from smp.engine.enricher import StaticSemanticEnricher + from smp.engine.graph_builder import DefaultGraphBuilder + from smp.parser.registry import ParserRegistry + from smp.store.graph.neo4j_store import Neo4jGraphStore + + registry = ParserRegistry() + graph_store = Neo4jGraphStore( + uri=neo4j_uri or os.environ.get("SMP_NEO4J_URI", "bolt://localhost:7687"), + user=neo4j_user or os.environ.get("SMP_NEO4J_USER", "neo4j"), + password=neo4j_password or os.environ.get("SMP_NEO4J_PASSWORD", ""), + ) + builder = DefaultGraphBuilder(graph_store) + enricher = StaticSemanticEnricher() + + await graph_store.connect() + if clear: + await graph_store.clear() + log.warning("graph_cleared") + + root = Path(directory).resolve() + if not root.is_dir(): + raise ValueError(f"Not a directory: {root}") + + stats = {"files": 0, "nodes": 0, "edges": 0, "errors": 0, "skipped": 0} + t0 = time.monotonic() + + for file_path in sorted(root.rglob("*")): + if not file_path.is_file(): + continue + if file_path.suffix.lower() not in extensions: + continue + + try: + size = file_path.stat().st_size + except OSError: + continue + if size > max_file_size: + log.warning("file_too_large", file=str(file_path), size=size) + stats["skipped"] += 1 + continue + + parts = file_path.relative_to(root).parts + if any( + p.startswith(".") or p in ("node_modules", "__pycache__", "venv", ".venv", "dist", "build") for p in parts + ): + continue + + rel_path = str(file_path.relative_to(root)) + doc = registry.parse_file(str(file_path)) + doc = type(doc)( + file_path=rel_path, + language=doc.language, + nodes=[ + type(n)( + id=n.id.replace(str(file_path), rel_path), + type=n.type, + file_path=rel_path, + structural=n.structural, + semantic=n.semantic, + ) + for n in doc.nodes + ], + edges=[ + type(e)( + source_id=e.source_id.replace(str(file_path), rel_path), + target_id=e.target_id.replace(str(file_path), rel_path), + type=e.type, + metadata=e.metadata, + ) + for e in doc.edges + ], + errors=doc.errors, + ) + + if doc.nodes or doc.edges: + await builder.ingest_document(doc) + + if doc.nodes: + enriched = await enricher.enrich_batch(doc.nodes) + for en in enriched: + if en.semantic.status == "enriched": + await graph_store.upsert_node(en) + + stats["files"] += 1 + stats["nodes"] += len(doc.nodes) + stats["edges"] += len(doc.edges) + stats["errors"] += len(doc.errors) + + resolved = await builder.resolve_pending_edges() + if resolved: + log.info("post_ingest_edges_resolved", count=resolved) + + elapsed = time.monotonic() - t0 + log.info( + "ingest_complete", + directory=str(root), + files=stats["files"], + nodes=stats["nodes"], + edges=stats["edges"], + errors=stats["errors"], + skipped=stats["skipped"], + elapsed_s=round(elapsed, 2), + ) + + await graph_store.close() + return stats + + +def main() -> None: + parser = argparse.ArgumentParser(prog="smp", description="Structural Memory Protocol CLI") + sub = parser.add_subparsers(dest="command") + + ingest_cmd = sub.add_parser("ingest", help="Parse a directory and build the graph") + ingest_cmd.add_argument("directory", help="Root directory to ingest") + ingest_cmd.add_argument( + "--neo4j-uri", type=str, help="Neo4j URI (defaults to SMP_NEO4J_URI env var or bolt://localhost:7687)" + ) + ingest_cmd.add_argument("--neo4j-user", type=str, help="Neo4j user (defaults to SMP_NEO4J_USER env var or neo4j)") + 
ingest_cmd.add_argument( + "--neo4j-password", type=str, help="Neo4j password (defaults to SMP_NEO4J_PASSWORD env var)" + ) + ingest_cmd.add_argument("--clear", action="store_true", help="Clear graph before ingesting") + ingest_cmd.add_argument("--json-log", action="store_true", help="JSON structured logging") + ingest_cmd.add_argument("--max-size", type=int, default=DEFAULT_MAX_FILE_SIZE, help="Max file size in bytes") + + serve_cmd = sub.add_parser("serve", help="Start the SMP JSON-RPC server") + serve_cmd.add_argument("--host", default="0.0.0.0", help="Bind host") + serve_cmd.add_argument("--port", type=int, default=8420, help="Bind port") + serve_cmd.add_argument( + "--neo4j-uri", type=str, help="Neo4j URI (defaults to SMP_NEO4J_URI env var or bolt://localhost:7687)" + ) + serve_cmd.add_argument("--neo4j-user", type=str, help="Neo4j user (defaults to SMP_NEO4J_USER env var or neo4j)") + serve_cmd.add_argument("--neo4j-password", type=str, help="Neo4j password (defaults to SMP_NEO4J_PASSWORD env var)") + serve_cmd.add_argument("--safety", action="store_true", help="Enable agent safety protocol") + serve_cmd.add_argument("--json-log", action="store_true", help="JSON structured logging") + + run_cmd = sub.add_parser("run", help="Run a command in the background") + run_cmd.add_argument("name", help="Name for this background process") + run_cmd.add_argument("command", nargs="+", help="Command and arguments to run") + run_cmd.add_argument("--cwd", type=str, help="Working directory") + run_cmd.add_argument("--env", nargs="+", help="Environment variables as KEY=VALUE") + run_cmd.add_argument("--restart", action="store_true", help="Restart if already running") + + list_cmd = sub.add_parser("ps", help="List running background processes") + list_cmd.add_argument("--name", help="Show specific process details") + + stop_cmd = sub.add_parser("stop", help="Stop a background process") + stop_cmd.add_argument("name", help="Name of the process to stop") + + logs_cmd = sub.add_parser("logs", help="View logs for a background process") + logs_cmd.add_argument("name", help="Name of the process") + logs_cmd.add_argument("--stream", action="store_true", help="Stream new output") + + args = parser.parse_args() + if not args.command: + parser.print_help() + sys.exit(1) + + configure_logging(json=getattr(args, "json_log", False)) + + if args.command == "ingest": + stats = asyncio.run( + ingest_directory( + args.directory, + neo4j_uri=args.neo4j_uri, + neo4j_user=args.neo4j_user, + neo4j_password=args.neo4j_password, + clear=args.clear, + max_file_size=args.max_size, + ) + ) + print( + f"\nIngested {stats['files']} files: {stats['nodes']} nodes, " + f"{stats['edges']} edges, {stats['errors']} errors" + ) + + elif args.command == "serve": + import os + + import uvicorn + + # Only set env vars if explicitly provided (to allow env var fallbacks) + if args.neo4j_uri: + os.environ["SMP_NEO4J_URI"] = args.neo4j_uri + if args.neo4j_user: + os.environ["SMP_NEO4J_USER"] = args.neo4j_user + if args.neo4j_password: + os.environ["SMP_NEO4J_PASSWORD"] = args.neo4j_password + + from smp.protocol.server import create_app + + application = create_app( + neo4j_uri=args.neo4j_uri, + neo4j_user=args.neo4j_user, + neo4j_password=args.neo4j_password, + safety_enabled=getattr(args, "safety", False), + ) + uvicorn.run(application, host=args.host, port=args.port) + + elif args.command == "run": + from smp.core.background import BackgroundRunner + + env = {} + if args.env: + for e in args.env: + if "=" in e: + key, val = e.split("=", 1) 
+ env[key] = val + + runner = BackgroundRunner() + cwd = Path(args.cwd) if args.cwd else None + + try: + bg_proc = runner.start(args.name, args.command, cwd=cwd, env=env or None) + print(f"Started {args.name}: pid={bg_proc.pid}") + except ValueError as e: + if args.restart: + bg_proc = runner.restart(args.name) + print(f"Restarted {args.name}: pid={bg_proc.pid}") + else: + print(f"Error: {e}") + sys.exit(1) + + elif args.command == "ps": + from smp.core.background import BackgroundRunner + + runner = BackgroundRunner() + if args.name: + proc = runner.get(args.name) + if proc: + print(f"{args.name}: pid={proc['pid']}, running={proc['running']}") + print(f" command: {' '.join(proc['command'])}") + else: + print(f"Process not found: {args.name}") + else: + all_procs = runner.list() + if all_procs: + for name, info in all_procs.items(): + print(f"{name}: pid={info['pid']}, running={info['running']}") + else: + print("No background processes running") + + elif args.command == "stop": + from smp.core.background import BackgroundRunner + + runner = BackgroundRunner() + if runner.stop(args.name): + print(f"Stopped {args.name}") + else: + print(f"Process not found: {args.name}") + sys.exit(1) + + elif args.command == "logs": + from smp.core.background import BackgroundRunner + + runner = BackgroundRunner() + try: + logs = runner.logs(args.name) + if logs["stdout"]: + print(f"=== stdout ===\n{logs['stdout']}") + if logs["stderr"]: + print(f"=== stderr ===\n{logs['stderr']}") + except ValueError as e: + print(f"Error: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/smp/client.py b/smp/client.py new file mode 100644 index 0000000..0566545 --- /dev/null +++ b/smp/client.py @@ -0,0 +1,201 @@ +"""SMP Client — Python SDK for the Structural Memory Protocol. + +Provides an async client for interacting with the SMP JSON-RPC server. + +Usage:: + + from smp.client import SMPClient + + async with SMPClient("http://localhost:8420") as client: + ctx = await client.get_context("src/auth.py") + results = await client.locate("authentication logic") + await client.update("src/auth.py", content=new_source) +""" + +from __future__ import annotations + +from typing import Any + +import httpx +import msgspec + +from smp.core.models import ( + ContextParams, + FlowParams, + ImpactParams, + JsonRpcRequest, + JsonRpcResponse, + Language, + LocateParams, + NavigateParams, + TraceParams, + UpdateParams, +) + + +class SMPClientError(Exception): + """Raised when the SMP server returns an error.""" + + def __init__(self, code: int, message: str, data: Any = None) -> None: + self.code = code + self.data = data + super().__init__(f"JSON-RPC error {code}: {message}") + + +class SMPClient: + """Async client for the Structural Memory Protocol server. + + Args: + base_url: Server base URL (e.g. ``"http://localhost:8420"``). + timeout: Request timeout in seconds. 
+ """ + + def __init__(self, base_url: str = "http://localhost:8420", timeout: float = 30.0) -> None: + self._base_url = base_url.rstrip("/") + self._client: httpx.AsyncClient | None = None + self._timeout = timeout + self._req_id = 0 + + async def connect(self) -> None: + self._client = httpx.AsyncClient(base_url=self._base_url, timeout=self._timeout) + + async def close(self) -> None: + if self._client: + await self._client.aclose() + self._client = None + + async def __aenter__(self) -> SMPClient: + await self.connect() + return self + + async def __aexit__(self, *_: Any) -> None: + await self.close() + + def _ensure_connected(self) -> httpx.AsyncClient: + if not self._client: + raise RuntimeError("Client not connected. Use 'async with SMPClient(...)' or call connect().") + return self._client + + async def _rpc(self, method: str, params: dict[str, Any]) -> Any: + """Send a JSON-RPC request and return the result.""" + self._req_id += 1 + req = JsonRpcRequest(method=method, params=params, id=self._req_id) + body = msgspec.json.encode(req) + + client = self._ensure_connected() + resp = await client.post("/rpc", content=body, headers={"Content-Type": "application/json"}) + + if resp.status_code == 204: + return None + + rpc_resp = msgspec.json.decode(resp.content, type=JsonRpcResponse) + if rpc_resp.error: + raise SMPClientError(rpc_resp.error.code, rpc_resp.error.message, rpc_resp.error.data) + return rpc_resp.result + + # ----------------------------------------------------------------------- + # Protocol methods + # ----------------------------------------------------------------------- + + async def navigate(self, entity_id: str) -> dict[str, Any]: + """Get a node and its immediate neighbours.""" + return await self._rpc("smp/navigate", msgspec.to_builtins(NavigateParams(query=entity_id))) + + async def trace( + self, + start_id: str, + edge_type: str = "CALLS", + depth: int = 3, + direction: str = "outgoing", + ) -> list[dict[str, Any]]: + """Recursive traversal (e.g. full call graph).""" + return await self._rpc( + "smp/trace", + msgspec.to_builtins( + TraceParams( + start=start_id, + relationship=edge_type, + depth=depth, + direction=direction, + ) + ), + ) + + async def get_context( + self, + file_path: str, + scope: str = "edit", + depth: int = 2, + ) -> dict[str, Any]: + """Aggregate structural context for safe editing.""" + return await self._rpc( + "smp/context", + msgspec.to_builtins( + ContextParams( + file_path=file_path, + scope=scope, + depth=depth, + ) + ), + ) + + async def assess_impact(self, entity_id: str, change_type: str = "delete") -> dict[str, Any]: + """Find blast radius of a change.""" + return await self._rpc( + "smp/impact", msgspec.to_builtins(ImpactParams(entity=entity_id, change_type=change_type)) + ) + + async def locate(self, query: str, top_k: int = 5) -> list[dict[str, Any]]: + """Search by semantic intent — vector search mapping back to graph nodes.""" + return await self._rpc("smp/locate", msgspec.to_builtins(LocateParams(query=query, top_k=top_k))) + + async def find_flow(self, start: str, end: str, max_depth: int = 20) -> list[list[dict[str, Any]]]: + """Find paths between two nodes.""" + return await self._rpc( + "smp/flow", + msgspec.to_builtins( + FlowParams( + start=start, + end=end, + ) + ), + ) + + async def update( + self, + file_path: str, + content: str = "", + language: str = "python", + ) -> dict[str, Any]: + """Notify the server of a file change — incremental graph update. 
+ + If *content* is provided it is parsed directly; otherwise the server + reads the file from disk. + """ + lang = Language(language) if language else Language.PYTHON + return await self._rpc( + "smp/update", + msgspec.to_builtins( + UpdateParams( + file_path=file_path, + content=content, + language=lang, + ) + ), + ) + + # ----------------------------------------------------------------------- + # Convenience endpoints + # ----------------------------------------------------------------------- + + async def health(self) -> dict[str, str]: + """Check server health.""" + client = self._ensure_connected() + resp = await client.get("/health") + return resp.json() + + async def stats(self) -> dict[str, int]: + """Get graph statistics (node/edge counts).""" + client = self._ensure_connected() + resp = await client.get("/stats") + return resp.json() diff --git a/smp/core/__init__.py b/smp/core/__init__.py new file mode 100644 index 0000000..b0685c0 --- /dev/null +++ b/smp/core/__init__.py @@ -0,0 +1 @@ +"""Core data models and types.""" diff --git a/smp/core/background.py b/smp/core/background.py new file mode 100644 index 0000000..b87eb4b --- /dev/null +++ b/smp/core/background.py @@ -0,0 +1,182 @@ +from __future__ import annotations + +import contextlib +import json +import os +import signal +import subprocess +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + + +@dataclass +class BackgroundProcess: + name: str + command: list[str] + pid: int + cwd: Path | None = None + env: dict[str, str] = field(default_factory=dict) + started_at: float = field(default_factory=time.time) + + +class BackgroundRunner: + """Manages long-running background processes without blocking the agent.""" + + def __init__(self) -> None: + self._base_dir = Path.home() / ".smp" / "runs" + self._processes: dict[str, BackgroundProcess] = {} + self._open_files: dict[str, tuple[Any, Any]] = {} + self._load() + + def _state_file(self) -> Path: + return self._base_dir / "state.json" + + def _load(self) -> None: + f = self._state_file() + if f.exists(): + with open(f) as fp: + data = json.load(fp) + for name, item in data.items(): + proc = BackgroundProcess( + name=name, + command=item["command"], + pid=item["pid"], + cwd=Path(item["cwd"]) if item.get("cwd") else None, + env=item.get("env", {}), + started_at=item.get("started_at", 0), + ) + if self._is_running(proc.pid): + self._processes[name] = proc + + def _save(self) -> None: + self._base_dir.mkdir(parents=True, exist_ok=True) + data = { + name: { + "command": proc.command, + "pid": proc.pid, + "cwd": str(proc.cwd) if proc.cwd else None, + "env": proc.env, + "started_at": proc.started_at, + } + for name, proc in self._processes.items() + } + with open(self._state_file(), "w") as fp: + json.dump(data, fp) + + def start( + self, + name: str, + command: list[str], + cwd: Path | None = None, + env: dict[str, str] | None = None, + ) -> BackgroundProcess: + """Start a command in background and return immediately.""" + if name in self._processes: + raise ValueError(f"Process already running: {name}") + + self._base_dir.mkdir(parents=True, exist_ok=True) + run_dir = self._base_dir / name + run_dir.mkdir(parents=True, exist_ok=True) + + full_env = os.environ.copy() + if env: + full_env.update(env) + + with open(run_dir / "stdout.log", "wb") as stdout_file, open(run_dir / "stderr.log", "wb") as stderr_file: + proc = subprocess.Popen( + command, + stdout=stdout_file, + stderr=stderr_file, + cwd=cwd or run_dir, + env=full_env, + 
start_new_session=True, + text=True, + ) + + self._open_files[name] = (None, None) # Files closed after Popen + + bg_proc = BackgroundProcess( + name=name, + command=command, + pid=proc.pid, + cwd=cwd, + env=env, + ) + self._processes[name] = bg_proc + self._save() + return bg_proc + + def stop(self, name: str) -> bool: + """Stop a running process by name.""" + if name not in self._processes: + return False + + bg_proc = self._processes[name] + with contextlib.suppress(ProcessLookupError): + os.kill(bg_proc.pid, signal.SIGTERM) + + self._open_files.pop(name, None) + + del self._processes[name] + self._save() + return True + + def restart(self, name: str) -> BackgroundProcess: + """Restart a stopped or existing process.""" + if name not in self._processes: + raise ValueError(f"Unknown process: {name}") + + bg_proc = self._processes[name] + self.stop(name) + return self.start(name, bg_proc.command, bg_proc.cwd, bg_proc.env) + + def list(self) -> dict[str, dict[str, Any]]: + """List all managed processes.""" + result = {} + for name, proc in self._processes.items(): + result[name] = { + "pid": proc.pid, + "command": proc.command, + "cwd": str(proc.cwd) if proc.cwd else None, + "running": self._is_running(proc.pid), + } + return result + + def get(self, name: str) -> dict[str, Any] | None: + """Get details of a specific process.""" + if name not in self._processes: + return None + + proc = self._processes[name] + return { + "pid": proc.pid, + "command": proc.command, + "cwd": str(proc.cwd) if proc.cwd else None, + "running": self._is_running(proc.pid), + } + + def logs(self, name: str) -> dict[str, str]: + """Get stdout/stderr log contents for a process.""" + if name not in self._processes and not (self._base_dir / name).exists(): + raise ValueError(f"Unknown process: {name}") + + run_dir = self._base_dir / name + stdout = "" + stderr = "" + if (run_dir / "stdout.log").exists(): + with open(run_dir / "stdout.log") as fp: + stdout = fp.read() + if (run_dir / "stderr.log").exists(): + with open(run_dir / "stderr.log") as fp: + stderr = fp.read() + return {"stdout": stdout, "stderr": stderr} + + def _is_running(self, pid: int) -> bool: + """Check if a process is still running.""" + try: + os.kill(pid, 0) + return True + except ProcessLookupError: + return False diff --git a/smp/core/merkle.py b/smp/core/merkle.py new file mode 100644 index 0000000..0f2fe42 --- /dev/null +++ b/smp/core/merkle.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +import hashlib +from typing import Any + +from smp.core.models import GraphNode, NodeType +from smp.logging import get_logger + +log = get_logger(__name__) + + +class MerkleTree: + """SHA-256 Merkle Tree for structural consistency checks.""" + + def __init__(self) -> None: + self._leaf_hashes: list[tuple[str, str]] = [] + self._levels: list[list[str]] = [] + + def _hash_single(self, data: str) -> str: + return hashlib.sha256(data.encode()).hexdigest() + + def _hash_pair(self, left: str, right: str) -> str: + return hashlib.sha256(f"{left}{right}".encode()).hexdigest() + + def build(self, nodes: list[GraphNode]) -> None: + """Build a SHA-256 tree where leaves are file nodes.""" + file_nodes = sorted([n for n in nodes if n.type == NodeType.FILE], key=lambda n: n.id) + + self._leaf_hashes = [(n.id, self._hash_single(f"{n.id}{n.semantic.source_hash}")) for n in file_nodes] + + current_level = [h for _, h in self._leaf_hashes] + self._levels = [current_level] + + while len(current_level) > 1: + next_level = [] + for i in range(0, len(current_level), 2): + 
left = current_level[i]
+                right = current_level[i + 1] if i + 1 < len(current_level) else left
+                next_level.append(self._hash_pair(left, right))
+            current_level = next_level
+            self._levels.append(current_level)
+
+    def hash(self) -> str:
+        """Return the root hash."""
+        if not self._levels:
+            return ""
+        return self._levels[-1][0]
+
+    def diff(self, other: MerkleTree) -> dict[str, set[str]]:
+        """Compare leaf hashes with *other* and return {added, removed, modified} node IDs."""
+        local_map = dict(self._leaf_hashes)
+        remote_map = dict(other._leaf_hashes)
+
+        local_ids = set(local_map.keys())
+        remote_ids = set(remote_map.keys())
+
+        added = remote_ids - local_ids
+        removed = local_ids - remote_ids
+
+        common_ids = local_ids & remote_ids
+        modified = {nid for nid in common_ids if local_map[nid] != remote_map[nid]}
+
+        return {"added": added, "removed": removed, "modified": modified}
+
+    def export(self) -> dict[str, Any]:
+        """Return a serializable format of the tree for distribution."""
+        return {"root": self.hash(), "levels": self._levels, "leaf_hashes": self._leaf_hashes}
+
+    def import_data(self, data: dict[str, Any]) -> None:
+        """Reconstruct the tree from exported data."""
+        self._levels = data["levels"]
+        self._leaf_hashes = [tuple(x) for x in data["leaf_hashes"]]
+
+
+class MerkleIndex:
+    """Sync management using Merkle Trees."""
+
+    def __init__(self, tree: MerkleTree) -> None:
+        self._tree = tree
+
+    def sync(self, remote_hash: str) -> dict[str, set[str]] | None:
+        """Compare the local root hash with *remote_hash*.
+
+        Returns None when the hashes match. On mismatch the divergence is
+        logged; the caller drives the actual diff via export()/import_data()
+        and MerkleTree.diff().
+        """
+        if self._tree.hash() == remote_hash:
+            return None
+
+        log.info("merkle_sync_diff_triggered", local=self._tree.hash(), remote=remote_hash)
+        return None
+
+    def apply_patch(self, patch: dict[str, Any]) -> None:
+        """Update local state based on a patch."""
+        log.info("merkle_apply_patch", patch_keys=list(patch.keys()))
diff --git a/smp/core/models.py b/smp/core/models.py
new file mode 100644
index 0000000..54615f1
--- /dev/null
+++ b/smp/core/models.py
@@ -0,0 +1,644 @@
+"""Core data models for SMP(3).
+
+Partitioned schema: structural vs semantic properties.
+All models use msgspec.Struct for zero-cost serialization and validation.
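+
+A minimal illustrative round-trip (hypothetical values)::
+
+    import msgspec
+
+    node = GraphNode(
+        id="src/app.py::Function::main::10",
+        type=NodeType.FUNCTION,
+        file_path="src/app.py",
+        structural=StructuralProperties(name="main", start_line=10, end_line=20),
+    )
+    raw = msgspec.json.encode(node)
+    assert msgspec.json.decode(raw, type=GraphNode) == node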
+""" + +from __future__ import annotations + +import enum +from typing import Any + +import msgspec + +# --------------------------------------------------------------------------- +# Enumerations (SMP(3) schema) +# --------------------------------------------------------------------------- + + +class NodeType(enum.StrEnum): + """Node types per SMP(3) specification.""" + + REPOSITORY = "Repository" + PACKAGE = "Package" + FILE = "File" + CLASS = "Class" + FUNCTION = "Function" + VARIABLE = "Variable" + INTERFACE = "Interface" + TEST = "Test" + CONFIG = "Config" + + +class EdgeType(enum.StrEnum): + """Relationship types per SMP(3) specification.""" + + CONTAINS = "CONTAINS" + IMPORTS = "IMPORTS" + DEFINES = "DEFINES" + CALLS = "CALLS" + CALLS_RUNTIME = "CALLS_RUNTIME" + INHERITS = "INHERITS" + IMPLEMENTS = "IMPLEMENTS" + DEPENDS_ON = "DEPENDS_ON" + TESTS = "TESTS" + USES = "USES" + REFERENCES = "REFERENCES" + + +class Language(enum.StrEnum): + """Supported source languages.""" + + PYTHON = "python" + TYPESCRIPT = "typescript" + UNKNOWN = "unknown" + + +# --------------------------------------------------------------------------- +# Structural properties (coordinates, signature, complexity) +# --------------------------------------------------------------------------- + + +class StructuralProperties(msgspec.Struct, frozen=True): + """Immutable structural coordinates of a code entity.""" + + name: str = "" + file: str = "" + signature: str = "" + start_line: int = 0 + end_line: int = 0 + complexity: int = 0 + lines: int = 0 + parameters: int = 0 + + +# --------------------------------------------------------------------------- +# Semantic properties (docstrings, comments, decorators, annotations, tags) +# --------------------------------------------------------------------------- + + +class InlineComment(msgspec.Struct, frozen=True): + """A single inline comment extracted from source.""" + + line: int = 0 + text: str = "" + + +class Annotations(msgspec.Struct, frozen=True): + """Structured type annotations extracted from a function/method.""" + + params: dict[str, str] = msgspec.field(default_factory=dict) + returns: str | None = None + throws: list[str] = msgspec.field(default_factory=list) + + +class SemanticProperties(msgspec.Struct): + """Mutable semantic metadata extracted via static AST analysis.""" + + status: str = "no_metadata" + docstring: str = "" + description: str | None = None + inline_comments: list[InlineComment] = msgspec.field(default_factory=list) + decorators: list[str] = msgspec.field(default_factory=list) + annotations: Annotations | None = None + tags: list[str] = msgspec.field(default_factory=list) + score: float = 0.0 + manually_set: bool = False + source_hash: str = "" + enriched_at: str = "" + + +# --------------------------------------------------------------------------- +# Graph primitives +# --------------------------------------------------------------------------- + + +class GraphNode(msgspec.Struct): + """A single node in the structural graph with partitioned properties.""" + + id: str + type: NodeType + file_path: str + structural: StructuralProperties = msgspec.field(default_factory=StructuralProperties) + semantic: SemanticProperties = msgspec.field(default_factory=SemanticProperties) + + def fingerprint(self) -> str: + """Deterministic identity key for deduplication.""" + return f"{self.file_path}::{self.type.value}::{self.structural.name}::{self.structural.start_line}" + + +class GraphEdge(msgspec.Struct): + """A directed edge between two nodes.""" + + 
source_id: str + target_id: str + type: EdgeType + metadata: dict[str, str] = msgspec.field(default_factory=dict) + + +# --------------------------------------------------------------------------- +# Document — the unit of parsing +# --------------------------------------------------------------------------- + + +class ParseError(msgspec.Struct): + """Non-fatal error encountered during parsing.""" + + message: str + line: int = 0 + column: int = 0 + severity: str = "error" + + +class Document(msgspec.Struct): + """A parsed source file with its extracted graph elements.""" + + file_path: str + language: Language = Language.UNKNOWN + content_hash: str = "" + nodes: list[GraphNode] = msgspec.field(default_factory=list) + edges: list[GraphEdge] = msgspec.field(default_factory=list) + errors: list[ParseError] = msgspec.field(default_factory=list) + + +# --------------------------------------------------------------------------- +# JSON-RPC 2.0 protocol models +# --------------------------------------------------------------------------- + + +class JsonRpcRequest(msgspec.Struct): + """JSON-RPC 2.0 request envelope.""" + + jsonrpc: str = "2.0" + method: str = "" + params: dict[str, Any] = msgspec.field(default_factory=dict) + id: int | str | None = None + + +class JsonRpcError(msgspec.Struct): + """JSON-RPC 2.0 error object.""" + + code: int + message: str + data: Any = None + + +class JsonRpcResponse(msgspec.Struct): + """JSON-RPC 2.0 response envelope.""" + + jsonrpc: str = "2.0" + result: Any = None + error: JsonRpcError | None = None + id: int | str | None = None + + +# --------------------------------------------------------------------------- +# Memory Management params +# --------------------------------------------------------------------------- + + +class UpdateParams(msgspec.Struct): + """Parameters for smp/update.""" + + file_path: str + content: str = "" + change_type: str = "modified" + language: Language = Language.PYTHON + + +class BatchUpdateParams(msgspec.Struct): + """Parameters for smp/batch_update.""" + + changes: list[dict[str, str]] = msgspec.field(default_factory=list) + + +class ReindexParams(msgspec.Struct): + """Parameters for smp/reindex.""" + + scope: str = "full" + + +# --------------------------------------------------------------------------- +# Enrichment params +# --------------------------------------------------------------------------- + + +class EnrichParams(msgspec.Struct): + """Parameters for smp/enrich.""" + + node_id: str + force: bool = False + + +class EnrichBatchParams(msgspec.Struct): + """Parameters for smp/enrich/batch.""" + + scope: str = "full" + force: bool = False + + +class EnrichStaleParams(msgspec.Struct): + """Parameters for smp/enrich/stale.""" + + scope: str = "full" + + +class EnrichStatusParams(msgspec.Struct): + """Parameters for smp/enrich/status.""" + + scope: str = "full" + + +# --------------------------------------------------------------------------- +# Annotation params +# --------------------------------------------------------------------------- + + +class AnnotateParams(msgspec.Struct): + """Parameters for smp/annotate.""" + + node_id: str + description: str = "" + tags: list[str] = msgspec.field(default_factory=list) + force: bool = False + + +class AnnotateBulkItem(msgspec.Struct): + """Single annotation in a bulk request.""" + + node_id: str + description: str = "" + tags: list[str] = msgspec.field(default_factory=list) + + +class AnnotateBulkParams(msgspec.Struct): + """Parameters for smp/annotate/bulk.""" + + annotations: 
list[AnnotateBulkItem] = msgspec.field(default_factory=list) + + +class TagParams(msgspec.Struct): + """Parameters for smp/tag.""" + + scope: str = "" + tags: list[str] = msgspec.field(default_factory=list) + action: str = "add" + + +# --------------------------------------------------------------------------- +# Session / Safety params +# --------------------------------------------------------------------------- + + +class SessionOpenParams(msgspec.Struct): + """Parameters for smp/session/open.""" + + agent_id: str = "" + task: str = "" + scope: list[str] = msgspec.field(default_factory=list) + mode: str = "read" + + +class SessionCloseParams(msgspec.Struct): + """Parameters for smp/session/close.""" + + session_id: str = "" + status: str = "completed" + + +class SessionRecoverParams(msgspec.Struct): + """Parameters for smp/session/recover.""" + + session_id: str = "" + + +class GuardCheckParams(msgspec.Struct): + target: str = "" + intended_change: str = "" + + +class DryRunParams(msgspec.Struct): + """Parameters for smp/dryrun.""" + + session_id: str = "" + file_path: str = "" + proposed_content: str = "" + change_summary: str = "" + + +class CheckpointParams(msgspec.Struct): + """Parameters for smp/checkpoint.""" + + session_id: str = "" + files: list[str] = msgspec.field(default_factory=list) + + +class RollbackParams(msgspec.Struct): + """Parameters for smp/rollback.""" + + session_id: str = "" + checkpoint_id: str = "" + + +class LockParams(msgspec.Struct): + """Parameters for smp/lock and smp/unlock.""" + + session_id: str = "" + files: list[str] = msgspec.field(default_factory=list) + + +class AuditGetParams(msgspec.Struct): + """Parameters for smp/audit/get.""" + + audit_log_id: str = "" + + +# --------------------------------------------------------------------------- +# Query params +# --------------------------------------------------------------------------- + + +class NavigateParams(msgspec.Struct): + """Parameters for smp/navigate.""" + + query: str = "" + include_relationships: bool = True + + +class TraceParams(msgspec.Struct): + """Parameters for smp/trace.""" + + start: str = "" + relationship: str = "CALLS" + depth: int = 3 + direction: str = "outgoing" + + +class ContextParams(msgspec.Struct): + """Parameters for smp/context.""" + + file_path: str = "" + scope: str = "edit" + depth: int = 2 + + +class ImpactParams(msgspec.Struct): + """Parameters for smp/impact.""" + + entity: str = "" + change_type: str = "delete" + + +class LocateParams(msgspec.Struct): + """Parameters for smp/locate.""" + + query: str = "" + fields: list[str] = msgspec.field(default_factory=lambda: ["name", "docstring", "tags"]) + node_types: list[str] = msgspec.field(default_factory=list) + top_k: int = 5 + + +class SearchParams(msgspec.Struct): + """Parameters for smp/search.""" + + query: str = "" + match: str = "any" + filter: dict[str, Any] = msgspec.field(default_factory=dict) + top_k: int = 5 + + +class FlowParams(msgspec.Struct): + """Parameters for smp/flow.""" + + start: str = "" + end: str = "" + flow_type: str = "data" + + +# --------------------------------------------------------------------------- +# SMP(3) Runtime Models +# --------------------------------------------------------------------------- + + +class RuntimeEdge(msgspec.Struct): + """Runtime edge tracking actual execution paths.""" + + source_id: str = "" + target_id: str = "" + edge_type: str = "CALLS_RUNTIME" + timestamp: str = "" + session_id: str = "" + trace_id: str = "" + duration_ms: int = 0 + metadata: dict[str, 
Any] = msgspec.field(default_factory=dict) + + +class RuntimeTrace(msgspec.Struct): + """Complete runtime trace for a session.""" + + trace_id: str = "" + session_id: str = "" + agent_id: str = "" + started_at: str = "" + ended_at: str = "" + edges: list[RuntimeEdge] = msgspec.field(default_factory=list) + nodes_visited: list[str] = msgspec.field(default_factory=list) + + +# --------------------------------------------------------------------------- +# SMP(3) Additional Query Params +# --------------------------------------------------------------------------- + + +class DiffParams(msgspec.Struct): + """Parameters for smp/diff.""" + + from_snapshot: str = "" + to_snapshot: str = "" + scope: str = "full" + + +class PlanParams(msgspec.Struct): + """Parameters for smp/plan.""" + + change_description: str = "" + target_file: str = "" + change_type: str = "refactor" + scope: str = "full" + + +class ConflictParams(msgspec.Struct): + """Parameters for smp/conflict.""" + + entity: str = "" + proposed_change: str = "" + context: dict[str, Any] = msgspec.field(default_factory=dict) + + +class WhyParams(msgspec.Struct): + """Parameters for smp/why.""" + + entity: str = "" + relationship: str = "" + depth: int = 3 + + +class TelemetryParams(msgspec.Struct): + """Parameters for smp/telemetry.""" + + action: str = "get_stats" + node_id: str | None = None + threshold: int | None = None + + +class TelemetryHotParams(msgspec.Struct): + """Parameters for smp/telemetry/hot.""" + + node_id: str + + +class TelemetryNodeParams(msgspec.Struct): + """Parameters for smp/telemetry/node.""" + + node_id: str + + +# --------------------------------------------------------------------------- +# SMP(3) Handoff Models +# --------------------------------------------------------------------------- + + +class ReviewCreateParams(msgspec.Struct): + """Parameters for smp/review/create.""" + + session_id: str = "" + files_changed: list[str] = msgspec.field(default_factory=list) + diff_summary: str = "" + reviewers: list[str] = msgspec.field(default_factory=list) + + +class ReviewApproveParams(msgspec.Struct): + """Parameters for smp/review/approve.""" + + review_id: str = "" + reviewer: str = "" + + +class ReviewRejectParams(msgspec.Struct): + """Parameters for smp/review/reject.""" + + review_id: str = "" + reviewer: str = "" + reason: str = "" + + +class ReviewCommentParams(msgspec.Struct): + """Parameters for smp/review/comment.""" + + review_id: str = "" + author: str = "" + comment: str = "" + file_path: str | None = None + line: int | None = None + + +class PRCreateParams(msgspec.Struct): + """Parameters for smp/pr/create.""" + + review_id: str = "" + title: str = "" + body: str = "" + branch: str = "" + base_branch: str = "main" + + +# --------------------------------------------------------------------------- +# SMP(3) Sandbox Models +# --------------------------------------------------------------------------- + + +class SandboxSpawnParams(msgspec.Struct): + """Parameters for smp/sandbox/spawn.""" + + name: str | None = None + template: str | None = None + files: dict[str, str] = msgspec.field(default_factory=dict) + + +class SandboxExecuteParams(msgspec.Struct): + """Parameters for smp/sandbox/execute.""" + + sandbox_id: str = "" + command: list[str] = msgspec.field(default_factory=list) + stdin: str | None = None + timeout: int | None = None + + +class SandboxKillParams(msgspec.Struct): + """Parameters for smp/sandbox/kill.""" + + execution_id: str = "" + + +# 
--------------------------------------------------------------------------- +# Community params +# --------------------------------------------------------------------------- + + +class CommunityDetectParams(msgspec.Struct): + """Parameters for smp/community/detect.""" + + resolutions: list[dict[str, Any]] = msgspec.field(default_factory=list) + relationship_types: list[str] = msgspec.field(default_factory=list) + + +class CommunityListParams(msgspec.Struct): + """Parameters for smp/community/list.""" + + level: int | None = None + + +class CommunityGetParams(msgspec.Struct): + """Parameters for smp/community/get.""" + + community_id: str + node_types: list[str] = msgspec.field(default_factory=list) + include_bridges: bool = False + + +class CommunityBoundariesParams(msgspec.Struct): + """Parameters for smp/community/boundaries.""" + + level: int = 0 + min_coupling: float = 0.05 + + +# --------------------------------------------------------------------------- +# Merkle params +# --------------------------------------------------------------------------- + + +class MerkleSyncParams(msgspec.Struct): + """Parameters for smp/sync.""" + + remote_data: dict[str, Any] = msgspec.field(default_factory=dict) + + +class MerkleImportParams(msgspec.Struct): + """Parameters for smp/index/import.""" + + data: dict[str, Any] = msgspec.field(default_factory=dict) + + +class IntegrityCheckParams(msgspec.Struct): + """Parameters for smp/integrity/check.""" + + node_id: str = "" + current_state: dict[str, Any] = msgspec.field(default_factory=dict) + + +class IntegrityBaselineParams(msgspec.Struct): + """Parameters for smp/integrity/baseline.""" + + node_id: str = "" + state: dict[str, Any] = msgspec.field(default_factory=dict) diff --git a/smp/engine/__init__.py b/smp/engine/__init__.py new file mode 100644 index 0000000..c697d8e --- /dev/null +++ b/smp/engine/__init__.py @@ -0,0 +1,11 @@ +"""Engine layer — graph building, enrichment, querying.""" + +from smp.engine.enricher import StaticSemanticEnricher +from smp.engine.graph_builder import DefaultGraphBuilder +from smp.engine.query import DefaultQueryEngine + +__all__ = [ + "DefaultGraphBuilder", + "DefaultQueryEngine", + "StaticSemanticEnricher", +] diff --git a/smp/engine/community.py b/smp/engine/community.py new file mode 100644 index 0000000..69d75aa --- /dev/null +++ b/smp/engine/community.py @@ -0,0 +1,499 @@ +"""Community detection using Louvain algorithm at two resolution levels. + +Implements two-level community detection (coarse L0, fine L1) per the SMP(3) +specification. Creates Community nodes, MEMBER_OF edges, BRIDGES edges, +and centroid embeddings stored in ChromaDB for smp/locate Phase 0 routing. 
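+
+Typical flow, as an illustrative sketch (``graph`` and ``vectors`` are assumed
+to be already-connected GraphStore/VectorStore instances)::
+
+    detector = CommunityDetector(graph, vector_store=vectors)
+    summary = await detector.detect()          # runs the coarse and fine passes
+    coarse = await detector.list_communities(level=0)
+    boundaries = await detector.get_boundaries(level=0, min_coupling=0.1)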
+""" + +from __future__ import annotations + +from collections import defaultdict +from dataclasses import dataclass, field +from datetime import UTC, datetime +from typing import Any + +from smp.core.models import EdgeType, GraphEdge, GraphNode, NodeType +from smp.logging import get_logger +from smp.store.interfaces import GraphStore, VectorStore + +log = get_logger(__name__) + + +@dataclass +class Community: + id: str = "" + level: int = 0 + label: str = "" + parent_community: str = "" + majority_path_prefix: str = "" + top_tags: list[str] = field(default_factory=list) + member_count: int = 0 + file_count: int = 0 + internal_edge_count: int = 0 + external_edge_count: int = 0 + modularity_score: float = 0.0 + centroid_embedding_id: str = "" + detected_at: str = "" + + +class CommunityDetector: + """Two-level Louvain community detection over the structural graph.""" + + def __init__( + self, + graph_store: GraphStore, + vector_store: VectorStore | None = None, + min_community_size: int = 5, + ) -> None: + self._graph = graph_store + self._vector = vector_store + self._min_size = min_community_size + self._communities: dict[str, Community] = {} + self._node_communities_l0: dict[str, str] = {} + self._node_communities_l1: dict[str, str] = {} + self._bridges: list[dict[str, Any]] = [] + + async def detect( + self, + resolutions: list[dict[str, Any]] | None = None, + relationship_types: list[str] | None = None, + ) -> dict[str, Any]: + if resolutions is None: + resolutions = [ + {"level": 0, "resolution": 0.5, "label": "coarse"}, + {"level": 1, "resolution": 1.5, "label": "fine"}, + ] + if relationship_types is None: + relationship_types = ["CALLS", "IMPORTS", "DEFINES"] + + all_nodes = await self._graph.find_nodes() + if not all_nodes: + return { + "nodes_assigned": 0, + "bridge_edges": 0, + "levels": {}, + "coarse_communities": [], + "fine_communities": [], + } + + edge_types = [EdgeType(rt) for rt in relationship_types if rt in EdgeType._value2member_map_] + adjacency = await self._build_adjacency(all_nodes, edge_types) + + all_results: dict[str, dict[str, Any]] = {} + for res_config in resolutions: + level = res_config.get("level", 0) + resolution = res_config.get("resolution", 1.0) + label = res_config.get("label", "coarse" if level == 0 else "fine") + + assignments = self._louvain(all_nodes, adjacency, resolution) + communities = self._build_communities(assignments, all_nodes, adjacency, level, label) + + if level == 0: + self._node_communities_l0 = assignments + else: + self._node_communities_l1 = assignments + + for comm in communities.values(): + self._communities[comm.id] = comm + await self._store_community_node(comm) + await self._write_member_of_edges(comm, all_nodes, level) + + all_results[str(level)] = { + "communities_found": len(communities), + "modularity": self._compute_modularity(assignments, adjacency), + } + + self._bridges = await self._detect_bridges(all_nodes, adjacency) + await self._write_bridges_edges() + + if self._vector is not None: + await self._compute_centroids(all_nodes) + + coarse = [ + { + "id": c.id, + "label": c.label, + "member_count": c.member_count, + "fine_children": sum(1 for fc in self._communities.values() if fc.parent_community == c.id), + } + for c in self._communities.values() + if c.level == 0 + ] + fine = [ + {"id": c.id, "parent": c.parent_community, "label": c.label, "member_count": c.member_count} + for c in self._communities.values() + if c.level == 1 + ] + + total_assigned = len(self._node_communities_l0) + return { + "nodes_assigned": 
total_assigned,
+            "bridge_edges": len(self._bridges),
+            "levels": all_results,
+            "coarse_communities": coarse,
+            "fine_communities": fine,
+        }
+
+    async def list_communities(self, level: int | None = None) -> dict[str, Any]:
+        communities = list(self._communities.values())
+        if level is not None:
+            communities = [c for c in communities if c.level == level]
+        return {
+            "total": len(communities),
+            "communities": [
+                {
+                    "id": c.id,
+                    "level": c.level,
+                    "parent_community": c.parent_community,
+                    "label": c.label,
+                    "majority_path_prefix": c.majority_path_prefix,
+                    "top_tags": c.top_tags,
+                    "member_count": c.member_count,
+                    "file_count": c.file_count,
+                    "internal_edge_count": c.internal_edge_count,
+                    "external_edge_count": c.external_edge_count,
+                    "modularity_score": c.modularity_score,
+                    "bridge_communities": [b["to_community"] for b in self._bridges if b["from_community"] == c.id],
+                }
+                for c in communities
+            ],
+        }
+
+    async def get_community(
+        self,
+        community_id: str,
+        node_types: list[str] | None = None,
+        include_bridges: bool = False,
+    ) -> dict[str, Any] | None:
+        comm = self._communities.get(community_id)
+        if not comm:
+            return None
+
+        assignments = self._node_communities_l1 if comm.level == 1 else self._node_communities_l0
+        member_ids = [nid for nid, cid in assignments.items() if cid == community_id]
+
+        members: list[dict[str, Any]] = []
+        for mid in member_ids:
+            node = await self._graph.get_node(mid)
+            if node is None:
+                continue
+            if node_types and node.type.value not in node_types:
+                continue
+            members.append(
+                {
+                    "id": node.id,
+                    "type": node.type.value,
+                    "name": node.structural.name,
+                    "file": node.file_path,
+                    "pagerank": 0.0,
+                    "heat_score": 0,
+                }
+            )
+
+        bridge_edges = []
+        if include_bridges:
+            bridge_edges = [
+                b for b in self._bridges if b["from_community"] == community_id or b["to_community"] == community_id
+            ]
+
+        return {
+            "community_id": comm.id,
+            "level": comm.level,
+            "parent_community": comm.parent_community,
+            "label": comm.label,
+            "member_count": comm.member_count,
+            "members": members,
+            "bridge_edges": bridge_edges,
+        }
+
+    async def get_boundaries(self, level: int = 0, min_coupling: float = 0.05) -> dict[str, Any]:
+        level_bridges = [
+            b
+            for b in self._bridges
+            if self._communities.get(b["from_community"], Community()).level == level
+            or self._communities.get(b["to_community"], Community()).level == level
+        ]
+        filtered = [b for b in level_bridges if b.get("coupling_weight", 0) >= min_coupling]
+        return {
+            "level": level,
+            "boundaries": filtered,
+        }
+
+    async def _build_adjacency(
+        self,
+        nodes: list[GraphNode],
+        edge_types: list[EdgeType],
+    ) -> dict[str, set[str]]:
+        adj: dict[str, set[str]] = defaultdict(set)
+        for node in nodes:
+            adj[node.id] = set()
+        for node in nodes:
+            for et in edge_types:
+                edges = await self._graph.get_edges(node.id, et, direction="outgoing")
+                for edge in edges:
+                    adj[node.id].add(edge.target_id)
+                    if edge.target_id in adj:
+                        adj[edge.target_id].add(node.id)
+        return adj
+
+    def _louvain(
+        self,
+        nodes: list[GraphNode],
+        adjacency: dict[str, set[str]],
+        resolution: float,
+    ) -> dict[str, str]:
+        community: dict[str, int] = {}
+        for i, node in enumerate(nodes):
+            community[node.id] = i
+
+        improved = True
+        iterations = 0
+        max_iterations = 50
+
+        while improved and iterations < max_iterations:
+            improved = False
+            iterations += 1
+            for node in nodes:
+                nid = node.id
+                current_comm = community[nid]
+                neighbor_comms: dict[int, int] = defaultdict(int)
+                for neighbor_id in adjacency.get(nid, set()):
+                    neighbor_comms[community[neighbor_id]] += 1
+
+                if not neighbor_comms:
+                    continue
+
+                best_comm = current_comm
+                best_gain = 0.0
+                total_edges = sum(neighbor_comms.values())
+                ki = len(adjacency.get(nid, set()))
+
+                for comm, ki_comm in neighbor_comms.items():
+                    sigma_tot = sum(1 for n, c in community.items() if c == comm and n in adjacency)
+                    sigma_tot = max(sigma_tot, 1)
+                    gain = resolution * ki_comm - ki * sigma_tot / (2 * total_edges) if total_edges > 0 else 0
+                    if gain > best_gain:
+                        best_gain = gain
+                        best_comm = comm
+
+                if best_comm != current_comm:
+                    community[nid] = best_comm
+                    improved = True
+
+        comm_map: dict[str, str] = {}
+        for nid, comm_id in community.items():
+            comm_map[nid] = f"comm_{comm_id}"
+        return comm_map
+
+    def _compute_modularity(
+        self,
+        assignments: dict[str, str],
+        adjacency: dict[str, set[str]],
+    ) -> float:
+        total_edges = sum(len(neighbors) for neighbors in adjacency.values())
+        if total_edges == 0:
+            return 0.0
+        total_edges //= 2
+
+        e_cc: dict[str, float] = defaultdict(float)
+        a_c: dict[str, float] = defaultdict(float)
+
+        for nid, neighbors in adjacency.items():
+            c_i = assignments.get(nid, "")
+            a_c[c_i] += len(neighbors)
+            for neighbor_id in neighbors:
+                c_j = assignments.get(neighbor_id, "")
+                if c_i == c_j:
+                    e_cc[c_i] += 1
+
+        modularity = 0.0
+        for c in e_cc:
+            modularity += (e_cc[c] / (2.0 * total_edges if total_edges > 0 else 1)) - (
+                a_c[c] / (2.0 * total_edges if total_edges > 0 else 1)
+            ) ** 2
+        return round(modularity, 4)
+
+    def _build_communities(
+        self,
+        assignments: dict[str, str],
+        nodes: list[GraphNode],
+        adjacency: dict[str, set[str]],
+        level: int,
+        label: str,
+    ) -> dict[str, Community]:
+        comm_members: dict[str, list[GraphNode]] = defaultdict(list)
+        for node in nodes:
+            cid = assignments.get(node.id, "")
+            if cid:
+                comm_members[cid].append(node)
+
+        communities: dict[str, Community] = {}
+        for cid, members in comm_members.items():
+            if len(members) < self._min_size:
+                smallest_comm = min(communities, key=lambda k: len(comm_members[k])) if communities else None
+                if smallest_comm:
+                    for m in members:
+                        assignments[m.id] = smallest_comm
+                        communities[smallest_comm].member_count += 1
+                continue
+
+            path_counts: dict[str, int] = defaultdict(int)
+            tag_counts: dict[str, int] = defaultdict(int)
+            file_set: set[str] = set()
+            internal_edges = 0
+            external_edges = 0
+
+            for m in members:
+                path_prefix = "/".join(m.file_path.split("/")[:2]) if "/" in m.file_path else m.file_path
+                path_counts[path_prefix] += 1
+                for tag in m.semantic.tags:
+                    tag_counts[tag] += 1
+                file_set.add(m.file_path)
+                for neighbor_id in adjacency.get(m.id, set()):
+                    if assignments.get(neighbor_id) == cid:
+                        internal_edges += 1
+                    else:
+                        external_edges += 1
+
+            majority_path = max(path_counts, key=path_counts.get) if path_counts else ""
+            top_tags_sorted = sorted(tag_counts, key=tag_counts.get, reverse=True)[:5]
+
+            parent = ""
+            if level == 1:
+                for m in members:
+                    parent = self._node_communities_l0.get(m.id, "")
+                    break
+
+            communities[cid] = Community(
+                id=cid,
+                level=level,
+                label=(label + "_" + majority_path.split("/")[-1]) if majority_path else label,
+                parent_community=parent,
+                majority_path_prefix=majority_path,
+                top_tags=top_tags_sorted,
+                member_count=len(members),
+                file_count=len(file_set),
+                internal_edge_count=internal_edges // 2,
+                external_edge_count=external_edges,
+                modularity_score=0.0,
+                detected_at=datetime.now(UTC).isoformat(),
+            )
+
+        return communities
+
+    async def _store_community_node(self, comm: Community) -> None:
+        from smp.core.models import SemanticProperties, StructuralProperties
+
+        comm_node = GraphNode(
+            id=comm.id,
+            type=NodeType("Community") if "Community" in NodeType._value2member_map_ else NodeType.FILE,
+            file_path=comm.majority_path_prefix,
+            structural=StructuralProperties(
+                name=comm.label,
+                file=comm.majority_path_prefix,
+            ),
+            semantic=SemanticProperties(
+                tags=comm.top_tags,
+                enriched_at=comm.detected_at,
+            ),
+        )
+        await self._graph.upsert_node(comm_node)
+
+    async def _write_member_of_edges(self, comm: Community, nodes: list[GraphNode], level: int) -> None:
+        assignments = self._node_communities_l1 if level == 1 else self._node_communities_l0
+        for node in nodes:
+            if assignments.get(node.id) == comm.id:
+                edge = GraphEdge(
+                    source_id=node.id,
+                    target_id=comm.id,
+                    type=EdgeType.MEMBER_OF if "MEMBER_OF" in EdgeType._value2member_map_ else EdgeType.REFERENCES,
+                    metadata={"community_level": str(level)},
+                )
+                await self._graph.upsert_edge(edge)
+
+    async def _detect_bridges(
+        self,
+        nodes: list[GraphNode],
+        adjacency: dict[str, set[str]],
+    ) -> list[dict[str, Any]]:
+        bridges: list[dict[str, Any]] = []
+        comm_pairs: dict[tuple[str, str], list[str]] = defaultdict(list)
+
+        for node in nodes:
+            cid = self._node_communities_l1.get(node.id, "")
+            if not cid:
+                continue
+            for neighbor_id in adjacency.get(node.id, set()):
+                neighbor_cid = self._node_communities_l1.get(neighbor_id, "")
+                if neighbor_cid and neighbor_cid != cid:
+                    pair = tuple(sorted([cid, neighbor_cid]))
+                    comm_pairs[pair].append(node.id)
+
+        for (c1, c2), bridge_nodes in comm_pairs.items():
+            coupling = len(bridge_nodes) / max(self._communities.get(c1, Community()).member_count, 1)
+            bridges.append(
+                {
+                    "from_community": c1,
+                    "to_community": c2,
+                    "edge_count": len(bridge_nodes),
+                    "coupling_weight": round(coupling, 4),
+                    "bridge_nodes": bridge_nodes,
+                }
+            )
+        return bridges
+
+    async def _write_bridges_edges(self) -> None:
+        for bridge in self._bridges:
+            edge_type = EdgeType.BRIDGES if "BRIDGES" in EdgeType._value2member_map_ else EdgeType.REFERENCES
+            edge = GraphEdge(
+                source_id=bridge["from_community"],
+                target_id=bridge["to_community"],
+                type=edge_type,
+                metadata={"coupling_weight": str(bridge.get("coupling_weight", ""))},
+            )
+            await self._graph.upsert_edge(edge)
+
+    async def _compute_centroids(self, nodes: list[GraphNode]) -> None:
+        if self._vector is None:
+            return
+        from smp.engine.seed_walk import _simple_hash_embedding
+
+        comm_nodes: dict[str, list[GraphNode]] = defaultdict(list)
+        for node in nodes:
+            cid = self._node_communities_l1.get(node.id, "")
+            if cid:
+                comm_nodes[cid].append(node)
+
+        for cid, members in comm_nodes.items():
+            if not members:
+                continue
+            all_vecs: list[list[float]] = []
+            for m in members:
+                text = m.structural.name + " " + (m.semantic.docstring or "")
+                vec = _simple_hash_embedding(text)
+                all_vecs.append(vec)
+
+            dim = len(all_vecs[0]) if all_vecs else 128
+            centroid = [0.0] * dim
+            for vec in all_vecs:
+                for i in range(dim):
+                    centroid[i] += vec[i]
+            n = len(all_vecs) if all_vecs else 1
+            centroid = [c / n for c in centroid]
+
+            comm = self._communities.get(cid)
+            label = comm.label if comm else cid
+            majority_path = comm.majority_path_prefix if comm else ""
+
+            await self._vector.add_code_embedding(
+                node_id=f"centroid_{cid}",
+                embedding=centroid,
+                metadata={
+                    "collection_type": "centroid",
+                    "community_id": cid,
+                    "label": label,
+                    "majority_path_prefix": majority_path,
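+                    # member_count is stringified: metadata values are kept as plain strings here.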
+ "member_count": str(len(members)), + }, + document=label, + ) diff --git a/smp/engine/embedding.py b/smp/engine/embedding.py new file mode 100644 index 0000000..4d10940 --- /dev/null +++ b/smp/engine/embedding.py @@ -0,0 +1,124 @@ +"""Embedding service using NVIDIA NIM or OpenAI.""" + +from __future__ import annotations + +import os +from typing import Any + +import httpx + +from smp.logging import get_logger + +log = get_logger(__name__) + + +class EmbeddingService: + """Generate embeddings via NVIDIA NIM or OpenAI.""" + + def __init__( + self, + provider: str = "nvidia", + api_key: str | None = None, + model: str | None = None, + base_url: str | None = None, + dimension: int = 768, + ) -> None: + self._provider = provider + self._api_key = api_key or os.environ.get("NVIDIA_NIM_API_KEY") or os.environ.get("OPENAI_API_KEY", "") + self._model = model or os.environ.get("EMBEDDING_MODEL", "nvidia/nv-embed-qa-4") + self._base_url = base_url or os.environ.get( + "EMBEDDING_BASE_URL", "https://integrate.api.nvidia.com/v1" + ) + self._dimension = dimension + self._client: httpx.AsyncClient | None = None + + async def connect(self) -> None: + self._client = httpx.AsyncClient( + base_url=self._base_url, + headers={"Authorization": f"Bearer {self._api_key}"}, + timeout=60.0, + ) + log.info("embedding_service_connected", provider=self._provider, model=self._model) + + async def close(self) -> None: + if self._client: + await self._client.aclose() + self._client = None + + @property + def dimension(self) -> int: + return self._dimension + + async def embed(self, text: str) -> list[float]: + """Generate embedding for a single text.""" + if self._client is None: + raise RuntimeError("EmbeddingService not connected") + + if self._provider == "nvidia": + return await self._embed_nvidia(text) + elif self._provider == "openai": + return await self._embed_openai(text) + else: + raise ValueError(f"Unknown provider: {self._provider}") + + async def embed_batch(self, texts: list[str]) -> list[list[float]]: + """Generate embeddings for multiple texts.""" + if self._client is None: + raise RuntimeError("EmbeddingService not connected") + + if self._provider == "nvidia": + return await self._embed_batch_nvidia(texts) + elif self._provider == "openai": + return await self._embed_batch_openai(texts) + else: + raise ValueError(f"Unknown provider: {self._provider}") + + async def _embed_nvidia(self, text: str) -> list[float]: + payload = { + "input": text, + "model": self._model, + } + response = await self._client.post("/embeddings", json=payload) + response.raise_for_status() + data = response.json() + return data["data"][0]["embedding"] + + async def _embed_batch_nvidia(self, texts: list[str]) -> list[list[float]]: + payload = { + "input": texts, + "model": self._model, + } + response = await self._client.post("/embeddings", json=payload) + response.raise_for_status() + data = response.json() + return [item["embedding"] for item in data["data"]] + + async def _embed_openai(self, text: str) -> list[float]: + payload = { + "input": text, + "model": self._model, + } + response = await self._client.post("/embeddings", json=payload) + response.raise_for_status() + data = response.json() + return data["data"][0]["embedding"] + + async def _embed_batch_openai(self, texts: list[str]) -> list[list[float]]: + payload = { + "input": texts, + "model": self._model, + } + response = await self._client.post("/embeddings", json=payload) + response.raise_for_status() + data = response.json() + return [item["embedding"] for item 
in data["data"]] + + +def create_embedding_service() -> EmbeddingService: + """Create embedding service from environment variables.""" + provider = os.getenv("EMBEDDING_PROVIDER", "nvidia") + api_key = os.getenv("NVIDIA_NIM_API_KEY") or os.getenv("OPENAI_API_KEY") + model = os.getenv("EMBEDDING_MODEL") + base_url = os.getenv("EMBEDDING_BASE_URL") + dimension = int(os.getenv("EMBEDDING_DIMENSION", "768")) + return EmbeddingService(provider=provider, api_key=api_key, model=model, base_url=base_url, dimension=dimension) \ No newline at end of file diff --git a/smp/engine/enricher.py b/smp/engine/enricher.py new file mode 100644 index 0000000..ab53941 --- /dev/null +++ b/smp/engine/enricher.py @@ -0,0 +1,113 @@ +"""Static semantic enricher with optional LLM-based embedding.""" + +from __future__ import annotations + +import hashlib +from datetime import UTC, datetime +from typing import TYPE_CHECKING + +from smp.core.models import GraphNode +from smp.engine.interfaces import SemanticEnricher as SemanticEnricherInterface +from smp.logging import get_logger + +if TYPE_CHECKING: + from smp.engine.embedding import EmbeddingService + +log = get_logger(__name__) + + +def _compute_source_hash(name: str, file_path: str, start: int, end: int, signature: str) -> str: + """Compute deterministic source hash for a node.""" + raw = f"{file_path}:{name}:{start}:{end}:{signature}" + return hashlib.sha256(raw.encode()).hexdigest()[:8] + + +class StaticSemanticEnricher(SemanticEnricherInterface): + """Static AST-based semantic enricher with optional embedding support.""" + + def __init__(self, embedding_service: EmbeddingService | None = None) -> None: + self._enrichment_counts: dict[str, int] = { + "enriched": 0, + "skipped": 0, + "no_metadata": 0, + "failed": 0, + } + self._embedding_service = embedding_service + + def set_embedding_service(self, service: EmbeddingService) -> None: + self._embedding_service = service + + async def enrich_node( + self, + node: GraphNode, + force: bool = False, + ) -> GraphNode: + """Enrich a single node with static metadata.""" + sem = node.semantic + current_hash = _compute_source_hash( + node.structural.name, + node.file_path, + node.structural.start_line, + node.structural.end_line, + node.structural.signature, + ) + + if not force and sem.source_hash and sem.source_hash == current_hash and sem.status != "no_metadata": + self._enrichment_counts["skipped"] += 1 + return node + + sem.source_hash = current_hash + + has_docstring = bool(sem.docstring and sem.docstring.strip()) + has_decorators = bool(sem.decorators) + has_annotations = bool(sem.annotations and (sem.annotations.params or sem.annotations.returns)) + + if not has_docstring and not has_decorators and not has_annotations: + sem.status = "no_metadata" + self._enrichment_counts["no_metadata"] += 1 + sem.enriched_at = datetime.now(UTC).isoformat() + return node + + sem.status = "enriched" + sem.enriched_at = datetime.now(UTC).isoformat() + + self._enrichment_counts["enriched"] += 1 + return node + + async def enrich_batch( + self, + nodes: list[GraphNode], + force: bool = False, + ) -> list[GraphNode]: + """Enrich multiple nodes.""" + enriched = [] + for node in nodes: + result = await self.enrich_node(node, force=force) + enriched.append(result) + return enriched + + @property + def has_llm(self) -> bool: + """Check if LLM-based embedding is available.""" + return self._embedding_service is not None + + async def embed(self, text: str) -> list[float]: + """Generate embedding using the embedding service if 
available.""" + if self._embedding_service is None: + return [] + return await self._embedding_service.embed(text) + + async def embed_batch(self, texts: list[str]) -> list[list[float]]: + """Generate embeddings for multiple texts.""" + if self._embedding_service is None: + return [[] for _ in texts] + return await self._embedding_service.embed_batch(texts) + + def get_counts(self) -> dict[str, int]: + """Return enrichment statistics.""" + return dict(self._enrichment_counts) + + def reset_counts(self) -> None: + """Reset enrichment counters.""" + for key in self._enrichment_counts: + self._enrichment_counts[key] = 0 \ No newline at end of file diff --git a/smp/engine/graph_builder.py b/smp/engine/graph_builder.py new file mode 100644 index 0000000..cf28008 --- /dev/null +++ b/smp/engine/graph_builder.py @@ -0,0 +1,159 @@ +"""Graph builder — maps parsed Documents into the graph store with Global Linking. + +Updated for SMP(3) partitioned data model. +""" + +from __future__ import annotations + +from smp.core.models import Document, GraphEdge, NodeType +from smp.engine.interfaces import GraphBuilder as GraphBuilderInterface +from smp.logging import get_logger +from smp.store.interfaces import GraphStore + +log = get_logger(__name__) + + +class DefaultGraphBuilder(GraphBuilderInterface): + def __init__(self, graph_store: GraphStore) -> None: + self._store = graph_store + self._pending_edges: list[tuple[GraphEdge, str, str]] = [] + + async def ingest_document(self, document: Document) -> None: + name_to_id = {n.structural.name: n.id for n in document.nodes} + + import_map: dict[str, tuple[str, str]] = {} + for node in document.nodes: + if node.type != NodeType.FILE: + continue + sig = node.structural.signature + if "import" not in sig: + continue + module_path = node.structural.name.replace(".", "/") + ".py" + if sig.strip().startswith("from"): + after_import = sig.split("import", 1)[1] + for raw_name in after_import.split(","): + stripped = raw_name.strip() + if not stripped: + continue + if " as " in stripped: + original, alias = stripped.split(" as ", 1) + import_map[alias.strip()] = (module_path, original.strip()) + else: + name = stripped.split()[0] + import_map[name] = (module_path, name) + else: + parts = sig.replace("import", "").strip().split(",") + for p in parts: + stripped = p.strip() + if " as " in stripped: + original, alias = stripped.split(" as ", 1) + import_map[alias.strip()] = (module_path, original.strip()) + else: + name = stripped.split()[0] + import_map[name] = (module_path, name) + + if document.nodes: + await self._store.upsert_nodes(document.nodes) + + resolved_edges: list[GraphEdge] = [] + for edge in document.edges: + parts = edge.target_id.split("::") + if len(parts) >= 4 and parts[-1] == "0": + entity_name = parts[2] + + if entity_name in name_to_id: + edge.target_id = name_to_id[entity_name] + resolved_edges.append(edge) + continue + + if entity_name in import_map: + module_path, original_name = import_map[entity_name] + target_id = await self._resolve_cross_file( + original_name, + module_path, + ) + if target_id: + edge.target_id = target_id + log.info( + "linker_resolved_cross_file", + name=entity_name, + original=original_name, + target=target_id, + ) + resolved_edges.append(edge) + else: + fallback = f"{module_path}::Function::{original_name}::1" + edge.target_id = fallback + self._pending_edges.append((edge, original_name, module_path)) + log.info( + "linker_cross_file_pending", + name=entity_name, + original=original_name, + target=fallback, + ) + else: + 
resolved_edges.append(edge) + else: + resolved_edges.append(edge) + + if resolved_edges: + await self._store.upsert_edges(resolved_edges) + + log.info("ingest_complete", file=document.file_path, resolved=len(resolved_edges)) + + async def _resolve_cross_file( + self, + entity_name: str, + module_path: str, + ) -> str | None: + """Look up the actual node ID for a cross-file reference.""" + candidates = await self._store.find_nodes(name=entity_name) + if not candidates: + return None + + if not module_path: + return candidates[0].id + + stem = module_path.rsplit("/", 1)[-1] + + for n in candidates: + if n.file_path == module_path: + return n.id + for n in candidates: + if n.file_path.endswith(stem): + return n.id + + return candidates[0].id + + async def resolve_pending_edges(self) -> int: + """Re-attempt cross-file edges that were deferred.""" + if not self._pending_edges: + return 0 + + fixed = 0 + still_pending: list[tuple[GraphEdge, str, str]] = [] + resolved: list[GraphEdge] = [] + for edge, original_name, module_path in self._pending_edges: + real_id = await self._resolve_cross_file(original_name, module_path) + if real_id: + edge.target_id = real_id + log.info( + "linker_pending_resolved", + original=original_name, + target=real_id, + ) + resolved.append(edge) + fixed += 1 + else: + still_pending.append((edge, original_name, module_path)) + + if resolved: + await self._store.upsert_edges(resolved) + + self._pending_edges = still_pending + if fixed: + log.info("resolve_pending_complete", fixed=fixed, remaining=len(still_pending)) + return fixed + + async def remove_document(self, file_path: str) -> None: + await self._store.delete_nodes_by_file(file_path) diff --git a/smp/engine/handoff.py b/smp/engine/handoff.py new file mode 100644 index 0000000..5a1492d --- /dev/null +++ b/smp/engine/handoff.py @@ -0,0 +1,203 @@ +"""Handoff layer for code review and PR creation. + +Manages the transition from AI-generated changes to human review, +including PR creation, review workflows, and approval tracking. 
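+
+Typical lifecycle, as an illustrative sketch::
+
+    manager = HandoffManager()
+    review = manager.create_review("sess_1", ["src/auth.py"], "Refactor auth flow")
+    manager.approve(review.review_id, reviewer="alice")
+    pr = manager.create_pr(review.review_id, "Refactor auth", "Details...", "feat/auth")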
+""" + +from __future__ import annotations + +import uuid +from dataclasses import dataclass, field +from datetime import UTC, datetime +from typing import Any + +from smp.logging import get_logger + +log = get_logger(__name__) + + +@dataclass +class ReviewRequest: + """A request for human review.""" + + review_id: str + session_id: str + files_changed: list[str] + diff_summary: str + created_at: str + status: str = "pending" + reviewers: list[str] = field(default_factory=list) + approvals: list[str] = field(default_factory=list) + rejections: list[str] = field(default_factory=list) + comments: list[dict[str, Any]] = field(default_factory=list) + + +@dataclass +class PRInfo: + """Information about a created PR.""" + + pr_id: str + review_id: str + title: str + body: str + branch: str + base_branch: str + url: str | None = None + created_at: str = "" + status: str = "open" + + +class HandoffManager: + """Manages code review and PR workflows.""" + + def __init__(self) -> None: + self._reviews: dict[str, ReviewRequest] = {} + self._prs: dict[str, PRInfo] = {} + + def create_review( + self, + session_id: str, + files_changed: list[str], + diff_summary: str, + reviewers: list[str] | None = None, + ) -> ReviewRequest: + """Create a new review request.""" + review_id = f"rev_{uuid.uuid4().hex[:8]}" + + review = ReviewRequest( + review_id=review_id, + session_id=session_id, + files_changed=files_changed, + diff_summary=diff_summary, + created_at=datetime.now(UTC).isoformat(), + reviewers=reviewers or [], + ) + self._reviews[review_id] = review + + log.info("review_created", review_id=review_id, files=len(files_changed)) + return review + + def add_comment( + self, + review_id: str, + author: str, + comment: str, + file_path: str | None = None, + line: int | None = None, + ) -> bool: + """Add a comment to a review.""" + review = self._reviews.get(review_id) + if not review: + return False + + comment_data: dict[str, Any] = { + "author": author, + "comment": comment, + "timestamp": datetime.now(UTC).isoformat(), + } + if file_path: + comment_data["file_path"] = file_path + if line: + comment_data["line"] = line + + review.comments.append(comment_data) + log.info("review_comment_added", review_id=review_id, author=author) + return True + + def approve(self, review_id: str, reviewer: str) -> bool: + """Record an approval for a review.""" + review = self._reviews.get(review_id) + if not review: + return False + + if reviewer not in review.approvals: + review.approvals.append(reviewer) + + if reviewer in review.rejections: + review.rejections.remove(reviewer) + + self._update_review_status(review) + log.info("review_approved", review_id=review_id, reviewer=reviewer) + return True + + def reject(self, review_id: str, reviewer: str, reason: str = "") -> bool: + """Record a rejection for a review.""" + review = self._reviews.get(review_id) + if not review: + return False + + if reviewer not in review.rejections: + review.rejections.append(reviewer) + + if reviewer in review.approvals: + review.approvals.remove(reviewer) + + self._update_review_status(review) + log.info("review_rejected", review_id=review_id, reviewer=reviewer, reason=reason) + return True + + def _update_review_status(self, review: ReviewRequest) -> None: + """Update review status based on approvals/rejections.""" + if len(review.rejections) > 0: + review.status = "rejected" + elif len(review.approvals) >= len(review.reviewers) and review.reviewers: + review.status = "approved" + + def create_pr( + self, + review_id: str, + title: str, + 
body: str, + branch: str, + base_branch: str = "main", + ) -> PRInfo | None: + """Create a PR for an approved review.""" + review = self._reviews.get(review_id) + if not review: + return None + + pr_id = f"pr_{uuid.uuid4().hex[:8]}" + + pr = PRInfo( + pr_id=pr_id, + review_id=review_id, + title=title, + body=body, + branch=branch, + base_branch=base_branch, + created_at=datetime.now(UTC).isoformat(), + ) + self._prs[pr_id] = pr + + review.status = "pr_created" + log.info("pr_created", pr_id=pr_id, review_id=review_id) + return pr + + def get_review(self, review_id: str) -> ReviewRequest | None: + """Get review by ID.""" + return self._reviews.get(review_id) + + def get_pr(self, pr_id: str) -> PRInfo | None: + """Get PR by ID.""" + return self._prs.get(pr_id) + + def list_pending_reviews(self) -> list[ReviewRequest]: + """List all pending reviews.""" + return [r for r in self._reviews.values() if r.status == "pending"] + + def get_review_summary(self, review_id: str) -> dict[str, Any] | None: + """Get summary of a review.""" + review = self._reviews.get(review_id) + if not review: + return None + + return { + "review_id": review.review_id, + "session_id": review.session_id, + "status": review.status, + "files_count": len(review.files_changed), + "reviewers": len(review.reviewers), + "approvals": len(review.approvals), + "rejections": len(review.rejections), + "comments_count": len(review.comments), + } diff --git a/smp/engine/integrity.py b/smp/engine/integrity.py new file mode 100644 index 0000000..3e44bcd --- /dev/null +++ b/smp/engine/integrity.py @@ -0,0 +1,242 @@ +"""Integrity verification module for AST-based data-flow analysis. + +Verifies that runtime behavior matches structural expectations by +analyzing data flow through the AST and detecting mutations. 
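+
+Baseline-and-verify usage, as an illustrative sketch::
+
+    verifier = IntegrityVerifier()
+    await verifier.capture_baseline("node_1", {"signature": "f(x)"})
+    result = await verifier.verify("node_1", {"signature": "f(x, y)"})
+    assert not result.passed  # the signature field mutated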
+""" + +from __future__ import annotations + +import subprocess +from dataclasses import dataclass, field +from datetime import UTC, datetime +from typing import Any + +from smp.logging import get_logger +from smp.store.interfaces import GraphStore + +log = get_logger(__name__) + + +@dataclass +class MutationRecord: + """Record of a detected mutation.""" + + node_id: str + mutation_type: str + field_name: str + old_value: str + new_value: str + detected_at: str + + +@dataclass +class DataFlowPath: + """Represents a data flow path through the code.""" + + source_node: str + target_node: str + path: list[str] + flow_type: str + transformations: list[str] = field(default_factory=list) + + +@dataclass +class IntegrityCheckResult: + """Result of an integrity verification.""" + + passed: bool + node_id: str + checks_run: int + mutations_detected: list[MutationRecord] = field(default_factory=list) + warnings: list[str] = field(default_factory=list) + + +class IntegrityVerifier: + """Verifies structural integrity of graph nodes.""" + + def __init__(self) -> None: + self._mutations: list[MutationRecord] = [] + self._baselines: dict[str, dict[str, Any]] = {} + + async def capture_baseline(self, node_id: str, state: dict[str, Any]) -> None: + """Capture baseline state for a node.""" + self._baselines[node_id] = { + "state": state.copy(), + "captured_at": datetime.now(UTC).isoformat(), + } + log.debug("baseline_captured", node_id=node_id) + + async def verify( + self, + node_id: str, + current_state: dict[str, Any], + ) -> IntegrityCheckResult: + """Verify node state against baseline.""" + baseline = self._baselines.get(node_id) + mutations: list[MutationRecord] = [] + warnings: list[str] = [] + + checks_run = 1 + + if baseline: + for field_name, baseline_value in baseline["state"].items(): + current_value = current_state.get(field_name) + + if baseline_value != current_value: + mutation = MutationRecord( + node_id=node_id, + mutation_type="field_change", + field_name=field_name, + old_value=str(baseline_value), + new_value=str(current_value), + detected_at=datetime.now(UTC).isoformat(), + ) + mutations.append(mutation) + self._mutations.append(mutation) + + warnings.append(f"{field_name} changed from {baseline_value} to {current_value}") + + passed = len(mutations) == 0 + + log.info( + "integrity_check", + node_id=node_id, + passed=passed, + mutations=len(mutations), + ) + + return IntegrityCheckResult( + passed=passed, + node_id=node_id, + checks_run=checks_run, + mutations_detected=mutations, + warnings=warnings, + ) + + def analyze_data_flow( + self, + source: str, + sink: str, + path_nodes: list[str], + ) -> DataFlowPath: + """Analyze data flow from source to sink.""" + transformations = [] + + for i in range(len(path_nodes) - 1): + transformations.append(f"{path_nodes[i]} → {path_nodes[i + 1]}") + + return DataFlowPath( + source_node=source, + target_node=sink, + path=path_nodes, + flow_type="data", + transformations=transformations, + ) + + def get_mutations(self, node_id: str | None = None) -> list[MutationRecord]: + """Get all detected mutations, optionally filtered by node.""" + if node_id: + return [m for m in self._mutations if m.node_id == node_id] + return list(self._mutations) + + def clear_mutations(self) -> None: + """Clear mutation history.""" + self._mutations.clear() + log.info("mutations_cleared") + + def get_mutation_summary(self) -> dict[str, Any]: + """Return summary of detected mutations.""" + by_node: dict[str, int] = {} + for m in self._mutations: + by_node[m.node_id] = 
by_node.get(m.node_id, 0) + 1
+
+        return {
+            "total_mutations": len(self._mutations),
+            "affected_nodes": len(by_node),
+            "by_node": by_node,
+        }
+
+    async def run_mutation_test(
+        self,
+        node_id: str,
+        graph_store: GraphStore,
+    ) -> IntegrityCheckResult:
+        """Run mutation testing on a specific node.
+
+        Mutates comparison operators in the source code and checks whether the
+        test suite still passes; a passing suite means the mutant survived.
+        """
+        node = await graph_store.get_node(node_id)
+        if not node:
+            log.error("mutation_test_failed", reason="node_not_found", node_id=node_id)
+            return IntegrityCheckResult(passed=False, node_id=node_id, checks_run=0)
+
+        file_path = node.file_path
+        try:
+            with open(file_path) as f:
+                lines = f.readlines()
+        except OSError as e:
+            log.error("mutation_test_failed", reason="file_read_error", error=str(e))
+            return IntegrityCheckResult(passed=False, node_id=node_id, checks_run=0)
+
+        mutants_survived = 0
+        checks_run = 0
+        detected_mutations: list[MutationRecord] = []
+
+        # Simple operator flips: each replacement is the logical negation of
+        # its key, with multi-character operators listed before their
+        # single-character prefixes.
+        operators = {"==": "!=", "!=": "==", ">=": "<", "<=": ">", ">": "<=", "<": ">="}
+
+        start = max(0, node.structural.start_line - 1)
+        end = min(len(lines), node.structural.end_line)
+
+        for i in range(start, end):
+            line = lines[i]
+            for op, replacement in operators.items():
+                idx = line.find(op)
+                if idx == -1:
+                    continue
+                # Skip bare ">"/"<" matches that are really part of ">=", "<=",
+                # "->", "<<", or ">>"; flipping those would only produce syntax
+                # errors, not meaningful mutants.
+                if op in (">", "<") and (
+                    (idx + 1 < len(line) and line[idx + 1] in "=<>")
+                    or (idx > 0 and line[idx - 1] in "-<>")
+                ):
+                    continue
+
+                checks_run += 1
+                original_line = line
+                lines[i] = line[:idx] + replacement + line[idx + len(op) :]
+
+                try:
+                    with open(file_path, "w") as f:
+                        f.writelines(lines)
+
+                    # Run tests; returncode 0 means the mutant went undetected.
+                    result = subprocess.run(["pytest"], capture_output=True, text=True, timeout=30)
+
+                    if result.returncode == 0:
+                        mutants_survived += 1
+                        mutation = MutationRecord(
+                            node_id=node_id,
+                            mutation_type="operator_flip",
+                            field_name=f"line_{i + 1}",
+                            old_value=op,
+                            new_value=replacement,
+                            detected_at=datetime.now(UTC).isoformat(),
+                        )
+                        detected_mutations.append(mutation)
+                        self._mutations.append(mutation)
+
+                except (subprocess.TimeoutExpired, OSError) as e:
+                    log.warning("mutation_test_warning", error=str(e))
+                finally:
+                    # Always restore the original line on disk before trying
+                    # the next mutant.
+                    lines[i] = original_line
+                    with open(file_path, "w") as f:
+                        f.writelines(lines)
+
+        passed = mutants_survived == 0
+
+        log.info(
+            "mutation_test_completed",
+            node_id=node_id,
+            passed=passed,
+            survived=mutants_survived,
+            total=checks_run,
+        )
+
+        return IntegrityCheckResult(
+            passed=passed,
+            node_id=node_id,
+            checks_run=checks_run,
+            mutations_detected=detected_mutations,
+            warnings=[f"{mutants_survived} mutants survived"] if mutants_survived > 0 else [],
+        )
diff --git a/smp/engine/interfaces.py b/smp/engine/interfaces.py
new file mode 100644
index 0000000..eaf3b91
--- /dev/null
+++ b/smp/engine/interfaces.py
@@ -0,0 +1,157 @@
+"""Abstract base classes for the engine layer.
+
+Defines the contracts for parsing, graph building, semantic enrichment,
+and querying for SMP(3).
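+
+A minimal consumption sketch, assuming a populated graph store (the entity
+names below are illustrative; ``DefaultQueryEngine`` is the concrete
+implementation in :mod:`smp.engine.query`)::
+
+    engine: QueryEngine = DefaultQueryEngine(graph_store)
+    info = await engine.navigate("src/auth/manager.py")
+    impact = await engine.assess_impact("authenticate", change_type="delete")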
+""" + +from __future__ import annotations + +import abc +from typing import Any + +from smp.core.models import ( + Document, + GraphNode, +) + + +class Parser(abc.ABC): + """Extract typed AST nodes and edges from source code.""" + + @abc.abstractmethod + def parse(self, source: str, file_path: str) -> Document: + """Parse *source* and return a :class:`Document`.""" + + @property + @abc.abstractmethod + def supported_languages(self) -> list[str]: + """Return language names this parser handles.""" + + +class GraphBuilder(abc.ABC): + """Map parsed :class:`Document` elements into a graph store.""" + + @abc.abstractmethod + async def ingest_document(self, document: Document) -> None: + """Write the document's nodes and edges into the graph store.""" + + @abc.abstractmethod + async def remove_document(self, file_path: str) -> None: + """Remove all graph data for *file_path*.""" + + +class SemanticEnricher(abc.ABC): + """Generate static semantic summaries from AST metadata.""" + + @abc.abstractmethod + async def enrich_node(self, node: GraphNode, force: bool = False) -> GraphNode: + """Return a copy of *node* with :class:`SemanticProperties` populated.""" + + @abc.abstractmethod + async def enrich_batch(self, nodes: list[GraphNode], force: bool = False) -> list[GraphNode]: + """Enrich multiple nodes.""" + + @abc.abstractmethod + async def embed(self, text: str) -> list[float]: + """No-op for static enricher.""" + + @abc.abstractmethod + async def embed_batch(self, texts: list[str]) -> list[list[float]]: + """Generate embeddings for multiple texts.""" + + +class QueryEngine(abc.ABC): + """High-level query interface over the memory store.""" + + @abc.abstractmethod + async def navigate(self, query: str, include_relationships: bool = True) -> dict[str, Any]: + """Find entity and return basic info with relationships.""" + + @abc.abstractmethod + async def trace( + self, + start: str, + relationship: str = "CALLS", + depth: int = 3, + direction: str = "outgoing", + ) -> list[dict[str, Any]]: + """Follow relationship chain from start node.""" + + @abc.abstractmethod + async def get_context( + self, + file_path: str, + scope: str = "edit", + depth: int = 2, + ) -> dict[str, Any]: + """Aggregate structural context for safe editing — the programmer's mental model.""" + + @abc.abstractmethod + async def assess_impact(self, entity: str, change_type: str = "delete") -> dict[str, Any]: + """Find blast radius of a change.""" + + @abc.abstractmethod + async def locate( + self, + query: str, + fields: list[str] | None = None, + node_types: list[str] | None = None, + top_k: int = 5, + ) -> list[dict[str, Any]]: + """Keyword search ranked by match quality.""" + + @abc.abstractmethod + async def search( + self, + query: str, + match: str = "any", + filters: dict[str, Any] | None = None, + top_k: int = 5, + ) -> dict[str, Any]: + """Pure keyword/token search across docstrings and tags.""" + + @abc.abstractmethod + async def find_flow( + self, + start: str, + end: str, + flow_type: str = "data", + ) -> dict[str, Any]: + """Trace execution/data flow between two nodes.""" + + @abc.abstractmethod + async def diff( + self, + from_snapshot: str, + to_snapshot: str, + scope: str = "full", + ) -> dict[str, Any]: + """Compare two snapshots and return differences.""" + + @abc.abstractmethod + async def plan( + self, + change_description: str, + target_file: str, + change_type: str = "refactor", + scope: str = "full", + ) -> dict[str, Any]: + """Generate a change plan for proposed modifications.""" + + @abc.abstractmethod + 
async def conflict( + self, + entity: str, + proposed_change: str, + context: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Check for conflicts in proposed changes.""" + + @abc.abstractmethod + async def why( + self, + entity: str, + relationship: str = "", + depth: int = 3, + ) -> dict[str, Any]: + """Explain why a relationship exists.""" diff --git a/smp/engine/linker.py b/smp/engine/linker.py new file mode 100644 index 0000000..66a10bc --- /dev/null +++ b/smp/engine/linker.py @@ -0,0 +1,205 @@ +"""Graph linker module for resolving cross-file references. + +Implements the SMP(3) linker spec: +- Resolves namespaced CALLS edges (file::function) +- Supports global linking across the graph +- Handles pending edges for forward references +""" + +from __future__ import annotations + +from typing import Any + +from smp.core.models import Document, EdgeType, GraphEdge, GraphNode, NodeType +from smp.logging import get_logger + +log = get_logger(__name__) + + +class Linker: + """Resolves cross-file references and creates CALLS edges.""" + + def __init__(self) -> None: + self._pending_edges: list[tuple[GraphEdge, str, str]] = [] + self._import_maps: dict[str, dict[str, tuple[str, str]]] = {} + + def build_import_map( + self, + document: Document, + nodes: list[GraphNode], + ) -> dict[str, tuple[str, str]]: + """Build import map from document nodes. + + Returns dict mapping imported names to (module_path, original_name). + """ + import_map: dict[str, tuple[str, str]] = {} + + for node in nodes: + if node.type != NodeType.FILE: + continue + + sig = node.structural.signature + if "import" not in sig: + continue + + module_path = node.structural.name.replace(".", "/") + ".py" + + if sig.strip().startswith("from"): + after_import = sig.split("import", 1)[1] + for raw_name in after_import.split(","): + stripped = raw_name.strip() + if not stripped: + continue + if " as " in stripped: + original, alias = stripped.split(" as ", 1) + import_map[alias.strip()] = (module_path, original.strip()) + else: + name = stripped.split()[0] + import_map[name] = (module_path, name) + else: + parts = sig.replace("import", "").strip().split(",") + for p in parts: + stripped = p.strip() + if " as " in stripped: + original, alias = stripped.split(" as ", 1) + import_map[alias.strip()] = (module_path, original.strip()) + else: + name = stripped.split()[0] + import_map[name] = (module_path, name) + + self._import_maps[document.file_path] = import_map + return import_map + + async def resolve_calls( + self, + edges: list[GraphEdge], + nodes: list[GraphNode], + graph_store: Any, + ) -> tuple[list[GraphEdge], list[tuple[GraphEdge, str, str]]]: + """Resolve CALLS edges to target node IDs. + + Returns (resolved_edges, pending_edges). 
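+
+        A typical call during ingestion, where ``document``, ``nodes``,
+        ``edges``, and ``graph_store`` come from the surrounding pipeline
+        (sketch only)::
+
+            linker = Linker()
+            linker.build_import_map(document, nodes)
+            resolved, pending = await linker.resolve_calls(edges, nodes, graph_store)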
+ """ + name_to_id = {n.structural.name: n.id for n in nodes} + file_path = nodes[0].file_path if nodes else "" + import_map = self._import_maps.get(file_path, {}) + + resolved: list[GraphEdge] = [] + pending: list[tuple[GraphEdge, str, str]] = [] + + for edge in edges: + if edge.type != EdgeType.CALLS: + resolved.append(edge) + continue + + target_id = edge.target_id + parts = target_id.split("::") + + if len(parts) >= 4 and parts[-1] == "0": + entity_name = parts[2] + + if entity_name in name_to_id: + edge.target_id = name_to_id[entity_name] + resolved.append(edge) + log.debug("linker_resolved_local", name=entity_name, target=edge.target_id) + continue + + if entity_name in import_map: + module_path, original_name = import_map[entity_name] + resolved_target = await self._resolve_cross_file( + original_name, + module_path, + graph_store, + ) + + if resolved_target: + edge.target_id = resolved_target + resolved.append(edge) + log.info( + "linker_resolved_cross_file", + name=entity_name, + original=original_name, + target=resolved_target, + ) + else: + fallback = f"{module_path}::Function::{original_name}::1" + edge.target_id = fallback + pending.append((edge, original_name, module_path)) + log.info( + "linker_cross_file_pending", + name=entity_name, + original=original_name, + target=fallback, + ) + resolved.append(edge) + else: + resolved.append(edge) + + return resolved, pending + + async def _resolve_cross_file( + self, + entity_name: str, + module_path: str, + graph_store: Any, + ) -> str | None: + """Look up the actual node ID for a cross-file reference.""" + candidates = await graph_store.find_nodes(name=entity_name) + if not candidates: + return None + + if not module_path: + return candidates[0].id + + stem = module_path.rsplit("/", 1)[-1] + + for n in candidates: + if n.file_path == module_path: + return n.id + + for n in candidates: + if n.file_path.endswith(stem): + return n.id + + return candidates[0].id + + async def resolve_pending(self, graph_store: Any) -> int: + """Re-attempt pending edge resolutions.""" + if not self._pending_edges: + return 0 + + fixed = 0 + still_pending: list[tuple[GraphEdge, str, str]] = [] + resolved: list[GraphEdge] = [] + + for edge, original_name, module_path in self._pending_edges: + real_id = await self._resolve_cross_file(original_name, module_path, graph_store) + if real_id: + edge.target_id = real_id + log.info( + "linker_pending_resolved", + original=original_name, + target=real_id, + ) + resolved.append(edge) + fixed += 1 + else: + still_pending.append((edge, original_name, module_path)) + + if resolved: + await graph_store.upsert_edges(resolved) + + self._pending_edges = still_pending + if fixed: + log.info("resolve_pending_complete", fixed=fixed, remaining=len(still_pending)) + + return fixed + + def get_pending_count(self) -> int: + """Return count of pending edges.""" + return len(self._pending_edges) + + def clear_pending(self) -> None: + """Clear all pending edges.""" + self._pending_edges.clear() + log.info("linker_pending_cleared") diff --git a/smp/engine/notification.py b/smp/engine/notification.py new file mode 100644 index 0000000..0943911 --- /dev/null +++ b/smp/engine/notification.py @@ -0,0 +1,90 @@ +"""Notification manager for server-push events. + +Provides a polling-based notification system for clients to receive +real-time updates about graph changes, session events, etc. 
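+
+A minimal emit/poll round-trip (event type and payload are illustrative)::
+
+    manager = NotificationManager()
+    manager.emit("graph_updated", {"file": "src/db/models.py"})
+    events = manager.poll(last_seen=0)  # event dicts, oldest first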
+""" + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass +from datetime import UTC, datetime +from typing import Any + +from smp.logging import get_logger + +log = get_logger(__name__) + + +@dataclass +class Notification: + """A single notification event.""" + + notification_id: str + event_type: str + payload: dict[str, Any] + timestamp: str + session_id: str = "" + + +class NotificationManager: + """Manages notifications with in-memory storage.""" + + def __init__(self, max_events: int = 1000) -> None: + self._events: list[Notification] = [] + self._max_events = max_events + self._subscribers: dict[str, asyncio.Queue[Notification]] = {} + + def emit( + self, + event_type: str, + payload: dict[str, Any], + session_id: str = "", + ) -> None: + """Emit a new notification event.""" + notification = Notification( + notification_id=f"notif_{len(self._events)}", + event_type=event_type, + payload=payload, + timestamp=datetime.now(UTC).isoformat(), + session_id=session_id, + ) + self._events.append(notification) + + # Trim if exceeding max + if len(self._events) > self._max_events: + self._events = self._events[-self._max_events :] + + log.debug("notification_emitted", event_type=event_type) + + def poll(self, last_seen: int = 0) -> list[dict[str, Any]]: + """Poll for new notifications since last_seen index.""" + if last_seen >= len(self._events): + return [] + + recent = self._events[last_seen:] + return [ + { + "index": last_seen + i, + "notification_id": n.notification_id, + "event_type": n.event_type, + "payload": n.payload, + "timestamp": n.timestamp, + "session_id": n.session_id, + } + for i, n in enumerate(recent) + ] + + def get_recent(self, limit: int = 50) -> list[dict[str, Any]]: + """Get the most recent notifications.""" + recent = self._events[-limit:] + return [ + { + "notification_id": n.notification_id, + "event_type": n.event_type, + "payload": n.payload, + "timestamp": n.timestamp, + "session_id": n.session_id, + } + for n in recent + ] diff --git a/smp/engine/pagerank.py b/smp/engine/pagerank.py new file mode 100644 index 0000000..16672b2 --- /dev/null +++ b/smp/engine/pagerank.py @@ -0,0 +1,119 @@ +"""PageRank engine for calculating node importance in the structural graph. + +Implements an iterative PageRank algorithm to identify central entities based on +graph connectivity (in-degree and relationship importance). +""" + +from __future__ import annotations + +from collections import defaultdict + +from smp.core.models import GraphEdge, GraphNode +from smp.logging import get_logger +from smp.store.interfaces import GraphStore + +log = get_logger(__name__) + + +class PageRankEngine: + """Calculates importance scores for graph nodes using the PageRank algorithm.""" + + def __init__(self, damping: float = 0.85, max_iterations: int = 100, tol: float = 1e-6) -> None: + """Initialize PageRank engine. + + Args: + damping: Damping factor (probability of following a link). + max_iterations: Maximum number of iterations to run. + tol: Convergence threshold. + """ + self.damping = damping + self.max_iterations = max_iterations + self.tol = tol + + def compute(self, nodes: list[GraphNode], edges: list[GraphEdge]) -> dict[str, float]: + """Compute PageRank scores for the given nodes and edges. + + Args: + nodes: List of nodes in the graph. + edges: List of directed edges in the graph. + + Returns: + A dictionary mapping node IDs to their calculated PageRank scores. 
+ """ + if not nodes: + return {} + + n = len(nodes) + node_ids = [node.id for node in nodes] + id_to_idx = {node_id: i for i, node_id in enumerate(node_ids)} + + # Adjacency list and out-degrees + adj = defaultdict(list) + out_degree = defaultdict(int) + for edge in edges: + if edge.source_id in id_to_idx and edge.target_id in id_to_idx: + adj[edge.target_id].append(edge.source_id) + out_degree[edge.source_id] += 1 + + # Initial scores + scores = [1.0 / n] * n + + for iteration in range(self.max_iterations): + new_scores = [0.0] * n + total_dangling_weight = 0.0 + + # Handle dangling nodes (nodes with no outgoing edges) + for i in range(n): + if out_degree[node_ids[i]] == 0: + total_dangling_weight += scores[i] + + for i in range(n): + target_id = node_ids[i] + # Sum of PageRank from neighbors + rank_sum = sum(scores[id_to_idx[src]] / out_degree[src] for src in adj[target_id]) + + # Calculate new score + new_scores[i] = (1.0 - self.damping) / n + self.damping * (rank_sum + total_dangling_weight / n) + + # Check convergence + diff = sum(abs(new_scores[i] - scores[i]) for i in range(n)) + if diff < self.tol: + log.debug("pagerank_converged", iteration=iteration, diff=diff) + scores = new_scores + break + + scores = new_scores + + return {node_ids[i]: scores[i] for i in range(n)} + + async def update_node_scores(self, graph_store: GraphStore) -> int: + """Update nodes in the graph store with their computed PageRank scores. + + Args: + graph_store: The graph store to update. + + Returns: + Number of nodes updated. + """ + # Use a broad search to get all nodes. In a real scenario, + # we might want to filter by type or scope. + nodes = await graph_store.find_nodes() + + # We need all edges to compute PageRank. + # This is expensive for large graphs; a real implementation would use GDS. + all_edges: list[GraphEdge] = [] + for node in nodes: + edges = await graph_store.get_edges(node.id, direction="outgoing") + all_edges.extend(edges) + + scores = self.compute(nodes, all_edges) + + updated_count = 0 + for node in nodes: + score = scores.get(node.id, 0.0) + node.semantic.score = score + await graph_store.upsert_node(node) + updated_count += 1 + + log.info("pagerank_scores_updated", count=updated_count) + return updated_count diff --git a/smp/engine/query.py b/smp/engine/query.py new file mode 100644 index 0000000..cbbafa7 --- /dev/null +++ b/smp/engine/query.py @@ -0,0 +1,817 @@ +"""Query engine — high-level structural queries over the memory store. + +Provides navigate, trace, get_context, assess_impact, locate, search, +and find_flow queries backed by the graph store. 
+""" + +from __future__ import annotations + +from collections import deque +from typing import Any + +from smp.core.models import EdgeType, GraphNode, NodeType +from smp.engine.interfaces import QueryEngine as QueryEngineInterface +from smp.logging import get_logger +from smp.store.interfaces import GraphStore + +log = get_logger(__name__) + +_HTTP_VERB_DECORATORS = {"get", "post", "put", "delete", "patch", "head", "options"} +_UTILITY_PATH_SEGMENTS = {"/utils", "/lib", "/shared", "/helpers"} + + +class DefaultQueryEngine(QueryEngineInterface): + """Query engine backed by a graph store.""" + + def __init__( + self, + graph_store: GraphStore, + enricher: Any | None = None, + ) -> None: + self._graph = graph_store + self._enricher = enricher + + def _node_to_dict(self, node: GraphNode) -> dict[str, Any]: + return { + "id": node.id, + "type": node.type.value, + "file_path": node.file_path, + "name": node.structural.name, + "signature": node.structural.signature, + "start_line": node.structural.start_line, + "end_line": node.structural.end_line, + "complexity": node.structural.complexity, + "lines": node.structural.lines, + "semantic": { + "status": node.semantic.status, + "docstring": node.semantic.docstring, + "description": node.semantic.description, + "decorators": node.semantic.decorators, + "tags": node.semantic.tags, + }, + } + + async def navigate(self, query: str, include_relationships: bool = True) -> dict[str, Any]: + node = await self._graph.get_node(query) + + # If exact match fails, try to find by file path or name + if not node: + # Check if query looks like a file path + if "/" in query or query.endswith(".py"): + candidates = await self._graph.find_nodes(file_path=query) + if candidates: + node = candidates[0] + else: + # Try finding by name + candidates = await self._graph.find_nodes(name=query) + if candidates: + node = candidates[0] + + # If still not found, try partial match on node ID prefix + if not node: + all_nodes = await self._graph.find_nodes() + for n in all_nodes: + if n.id.startswith(query) or query in n.id: + node = n + break + + if not node: + return {"error": f"Node {query} not found"} + + result: dict[str, Any] = {"entity": self._node_to_dict(node)} + + if include_relationships: + outgoing = await self._graph.get_edges(node.id, direction="outgoing") + incoming = await self._graph.get_edges(node.id, direction="incoming") + + calls = [e.target_id for e in outgoing if e.type == EdgeType.CALLS] + called_by = [e.source_id for e in incoming if e.type == EdgeType.CALLS] + depends_on = [e.target_id for e in outgoing if e.type == EdgeType.DEPENDS_ON] + imported_by = [e.source_id for e in incoming if e.type == EdgeType.IMPORTS] + + result["relationships"] = { + "calls": calls, + "called_by": called_by, + "depends_on": depends_on, + "imported_by": imported_by, + } + + return result + + async def trace( + self, + start: str, + relationship: str = "CALLS", + depth: int = 3, + direction: str = "outgoing", + ) -> list[dict[str, Any]]: + try: + et = EdgeType(relationship) + except ValueError: + et = EdgeType.CALLS + nodes = await self._graph.traverse(start, et, depth, max_nodes=100, direction=direction) + return [self._node_to_dict(n) for n in nodes] + + async def get_context( + self, + file_path: str, + scope: str = "edit", + depth: int = 2, + ) -> dict[str, Any]: + file_nodes = await self._graph.find_nodes(file_path=file_path) + if not file_nodes: + return {"error": f"No nodes found for {file_path}"} + + file_node = file_nodes[0] + file_id = file_node.id + + imports = 
await self._graph.get_edges(file_id, EdgeType.IMPORTS, direction="outgoing") + imported_by = await self._graph.get_edges(file_id, EdgeType.IMPORTS, direction="incoming") + defines = await self._graph.get_edges(file_id, EdgeType.DEFINES, direction="outgoing") + tests_edges = await self._graph.get_edges(file_id, EdgeType.TESTS, direction="incoming") + + defines_nodes: list[dict[str, Any]] = [] + complexities: list[int] = [] + exported_symbols: list[str] = [] + http_decorators: list[str] = [] + test_file_paths: list[str] = [] + + for edge in defines: + target = await self._graph.get_node(edge.target_id) + if target: + defines_nodes.append(self._node_to_dict(target)) + complexities.append(target.structural.complexity) + exported_symbols.append(target.structural.name) + for dec in target.semantic.decorators: + dec_lower = dec.lstrip("@").lower() + if dec_lower in _HTTP_VERB_DECORATORS: + http_decorators.append(dec) + + for te in tests_edges: + source = await self._graph.get_node(te.source_id) + if source and source.file_path not in test_file_paths: + test_file_paths.append(source.file_path) + + has_tests = len(test_file_paths) > 0 + + related_patterns: list[dict[str, Any]] = [] + all_nodes = await self._graph.find_nodes() + for candidate in all_nodes: + if candidate.id == file_id or candidate.file_path == file_path: + continue + if candidate.type == file_node.type: + name_sim = self._name_similarity(file_node.structural.name, candidate.structural.name) + if name_sim > 0.5: + related_patterns.append( + { + "file_path": candidate.file_path, + "name": candidate.structural.name, + "similarity": round(name_sim, 2), + } + ) + related_patterns.sort(key=lambda x: -x["similarity"]) + related_patterns = related_patterns[:5] + + entry_points: list[dict[str, Any]] = [] + if http_decorators: + for edge in defines: + target = await self._graph.get_node(edge.target_id) + if target: + target_http = [ + d for d in target.semantic.decorators if d.lstrip("@").lower() in _HTTP_VERB_DECORATORS + ] + if target_http: + entry_points.append( + { + "name": target.structural.name, + "decorators": target_http, + "file_path": target.file_path, + } + ) + + data_flow_in: list[dict[str, Any]] = [] + data_flow_out: list[dict[str, Any]] = [] + + callers_in = await self._graph.traverse( + file_id, EdgeType.CALLS, depth=depth, max_nodes=50, direction="incoming" + ) + for caller in callers_in: + data_flow_in.append( + { + "node_id": caller.id, + "name": caller.structural.name, + "file_path": caller.file_path, + } + ) + + callers_out = await self._graph.traverse( + file_id, EdgeType.CALLS, depth=depth, max_nodes=50, direction="outgoing" + ) + for callee in callers_out: + data_flow_out.append( + { + "node_id": callee.id, + "name": callee.structural.name, + "file_path": callee.file_path, + } + ) + + role = self._classify_role(file_node, imported_by, defines_nodes, http_decorators) + avg_complexity = round(sum(complexities) / max(len(complexities), 1), 1) + max_complexity = max(complexities, default=0) + blast_radius = len(imported_by) + + imported_by_api = 0 + for edge in imported_by: + source = await self._graph.get_node(edge.source_id) + if source and "/api" in source.file_path: + imported_by_api += 1 + + is_hot_node = blast_radius > 10 or max_complexity > 8 + heat_score = blast_radius + max_complexity + + if blast_radius > 10 or avg_complexity > 8: + risk_level = "high" + elif blast_radius > 3 or avg_complexity > 4: + risk_level = "medium" + else: + risk_level = "low" + + summary = { + "role": role, + "blast_radius": blast_radius, 
+ "api_layer_callers": imported_by_api, + "avg_complexity": avg_complexity, + "max_complexity": max_complexity, + "exported_symbols": exported_symbols, + "has_tests": has_tests, + "test_files": test_file_paths, + "is_hot_node": is_hot_node, + "heat_score": heat_score, + "risk_level": risk_level, + } + + return { + "self": self._node_to_dict(file_node), + "imports": [{"source": e.source_id, "target": e.target_id} for e in imports], + "imported_by": [{"source": e.source_id, "target": e.target_id} for e in imported_by], + "defines": defines_nodes, + "related_patterns": related_patterns, + "entry_points": entry_points, + "data_flow_in": data_flow_in, + "data_flow_out": data_flow_out, + "summary": summary, + } + + @staticmethod + def _name_similarity(name_a: str, name_b: str) -> float: + if not name_a or not name_b: + return 0.0 + set_a = set(name_a.lower()) + set_b = set(name_b.lower()) + if not set_a or not set_b: + return 0.0 + intersection = set_a & set_b + union = set_a | set_b + return len(intersection) / len(union) + + def _classify_role( + self, + file_node: GraphNode, + imported_by: list[Any], + defines_nodes: list[dict[str, Any]], + http_decorators: list[str], + ) -> str: + path = file_node.file_path + if "/test" in path or "/spec" in path: + return "test" + if file_node.type == NodeType.CONFIG: + return "config" + if http_decorators: + return "endpoint" + if "/routes" in path or "/controllers" in path: + return "endpoint" + incoming_imports = len(imported_by) + if "/services" in path and incoming_imports > 0: + return "service" + if incoming_imports > 5 and any(seg in path for seg in _UTILITY_PATH_SEGMENTS): + return "core_utility" + if incoming_imports == 0 and not defines_nodes: + return "isolated" + return "module" + + async def assess_impact(self, entity: str, change_type: str = "delete") -> dict[str, Any]: + node = await self._graph.get_node(entity) + + # If exact match fails, try to find by file path or name + if not node: + if "/" in entity or entity.endswith(".py"): + candidates = await self._graph.find_nodes(file_path=entity) + if candidates: + node = candidates[0] + else: + candidates = await self._graph.find_nodes(name=entity) + if candidates: + node = candidates[0] + + # Try partial match if still not found + if not node: + all_nodes = await self._graph.find_nodes() + for n in all_nodes: + if entity in n.id: + node = n + break + + if not node: + return {"error": f"Node {entity} not found"} + + dependents = await self._graph.traverse(node.id, EdgeType.CALLS, depth=10, max_nodes=200, direction="incoming") + + affected_files: list[str] = [] + affected_functions: list[str] = [] + for dep in dependents: + if dep.file_path not in affected_files: + affected_files.append(dep.file_path) + affected_functions.append(dep.structural.name) + + severity = "low" + if len(dependents) > 10: + severity = "high" + elif len(dependents) > 3: + severity = "medium" + + recommendations: list[str] = [] + if change_type == "signature_change": + recommendations.append(f"Update {len(dependents)} callers to match new signature") + elif change_type == "delete": + recommendations.append(f"Remove or stub {len(dependents)} dependent references") + + return { + "affected_files": affected_files, + "affected_functions": affected_functions, + "severity": severity, + "recommendations": recommendations, + } + + async def locate( + self, + query: str, + fields: list[str] | None = None, + node_types: list[str] | None = None, + top_k: int = 5, + ) -> list[dict[str, Any]]: + if not fields: + fields = ["name", 
"docstring", "tags"] + + terms = query.lower().split() + all_nodes = await self._graph.find_nodes() + + scored: list[tuple[int, dict[str, Any]]] = [] + for node in all_nodes: + if node_types and node.type.value not in node_types: + continue + + score = 0 + matched_on = "" + + name_lower = node.structural.name.lower() + if all(t in name_lower for t in terms): + score = 100 + matched_on = "name" + elif any(t in name_lower for t in terms): + score = 50 + matched_on = "name" + elif node.semantic.docstring: + doc_lower = node.semantic.docstring.lower() + if all(t in doc_lower for t in terms): + score = 30 + matched_on = "docstring" + elif any(t in doc_lower for t in terms): + score = 15 + matched_on = "docstring" + + if score > 0: + for tag in node.semantic.tags: + if any(t in tag.lower() for t in terms): + score += 10 + if matched_on: + matched_on += ", tags" + else: + matched_on = "tags" + break + + scored.append( + ( + score, + { + "entity": node.structural.name, + "file": node.file_path, + "matched_on": matched_on, + "docstring": node.semantic.docstring, + "tags": node.semantic.tags, + }, + ) + ) + + scored.sort(key=lambda x: -x[0]) + return [item[1] for item in scored[:top_k]] + + async def search( + self, + query: str, + match: str = "any", + filters: dict[str, Any] | None = None, + top_k: int = 5, + ) -> dict[str, Any]: + filters = filters or {} + terms = query.split() + node_types = filters.get("node_types") + tags = filters.get("tags") + scope = filters.get("scope") + + results = await self._graph.search_nodes( + query_terms=terms, + match=match, + node_types=node_types, + tags=tags, + scope=scope, + top_k=top_k, + ) + + if not results: + return { + "matches": [], + "total": 0, + "hint": "Try broadening scope or using match: any", + } + + return {"matches": results, "total": len(results)} + + async def find_flow( + self, + start: str, + end: str, + flow_type: str = "data", + ) -> dict[str, Any]: + # Resolve start and end nodes + start_node = await self._graph.get_node(start) + if not start_node: + candidates = await self._graph.find_nodes(name=start) + if candidates: + start_node = candidates[0] + + end_node = await self._graph.get_node(end) + if not end_node: + candidates = await self._graph.find_nodes(name=end) + if candidates: + end_node = candidates[0] + + if start == end: + if start_node: + return { + "path": [{"node": start_node.structural.name, "type": start_node.type.value}], + "data_transformations": [], + } + return {"path": [], "data_transformations": []} + + if not start_node or not end_node: + return {"path": [], "data_transformations": []} + + paths = await self._bfs_paths(start, end) + if not paths: + return {"path": [], "data_transformations": []} + + best_path = paths[0] + path_nodes = [] + for nid in best_path: + node = await self._graph.get_node(nid) + if node: + path_nodes.append({"node": node.structural.name, "type": node.type.value}) + + transformations: list[str] = [] + for i in range(len(path_nodes) - 1): + transformations.append(f"{path_nodes[i]['node']} → {path_nodes[i + 1]['node']}") + + return { + "path": path_nodes, + "data_transformations": transformations, + } + + async def _bfs_paths(self, start_id: str, end_id: str) -> list[list[str]]: + """BFS to find shortest paths.""" + found_paths: list[list[str]] = [] + queue: deque[tuple[str, list[str]]] = deque([(start_id, [start_id])]) + visited: set[str] = set() + + while queue and len(found_paths) < 3: + current, path = queue.popleft() + if len(path) > 20: + continue + + edges = await 
self._graph.get_edges(current, direction="outgoing") + edges += await self._graph.get_edges(current, direction="incoming") + + neighbors: set[str] = set() + for e in edges: + neighbors.add(e.target_id if e.source_id == current else e.source_id) + + for neighbor in neighbors: + if neighbor == end_id: + found_paths.append(path + [neighbor]) + continue + if neighbor not in visited and neighbor not in path: + visited.add(neighbor) + queue.append((neighbor, path + [neighbor])) + + return found_paths + + async def diff( + self, + from_snapshot: str, + to_snapshot: str, + scope: str = "full", + ) -> dict[str, Any]: + """Compare two snapshots and return the differences.""" + from_nodes = await self._graph.find_nodes_by_scope(from_snapshot) + to_nodes = await self._graph.find_nodes_by_scope(to_snapshot) + + from_ids = {n.id for n in from_nodes} + to_ids = {n.id for n in to_nodes} + + added = to_ids - from_ids + removed = from_ids - to_ids + common = from_ids & to_ids + + changed: list[str] = [] + for node_id in common: + from_node = next((n for n in from_nodes if n.id == node_id), None) + to_node = next((n for n in to_nodes if n.id == node_id), None) + if from_node and to_node and from_node.semantic.source_hash != to_node.semantic.source_hash: + changed.append(node_id) + + return { + "from_snapshot": from_snapshot, + "to_snapshot": to_snapshot, + "added": list(added), + "removed": list(removed), + "changed": changed, + "stats": { + "added_count": len(added), + "removed_count": len(removed), + "changed_count": len(changed), + }, + } + + async def plan( + self, + change_description: str, + target_file: str, + change_type: str = "refactor", + scope: str = "full", + ) -> dict[str, Any]: + """Generate a change plan for proposed modifications.""" + file_nodes = await self._graph.find_nodes(file_path=target_file) + + affected_nodes: list[str] = [] + for node in file_nodes: + callers = await self._graph.traverse(node.id, EdgeType.CALLS, depth=10, max_nodes=200, direction="incoming") + if callers: + affected_nodes.append(node.id) + + steps: list[dict[str, str]] = [ + {"step": "1", "action": "Backup current state", "details": f"Snapshot {target_file}"}, + {"step": "2", "action": "Apply changes", "details": change_description}, + {"step": "3", "action": "Run tests", "details": f"Test affected nodes: {len(affected_nodes)}"}, + ] + + if change_type == "signature_change": + steps.append( + { + "step": "4", + "action": "Update callers", + "details": f"Update {len(affected_nodes)} dependent functions", + } + ) + + return { + "change_description": change_description, + "target_file": target_file, + "change_type": change_type, + "affected_nodes": affected_nodes, + "steps": steps, + "risk_level": "high" if len(affected_nodes) > 10 else "medium" if affected_nodes else "low", + } + + async def conflict( + self, + entity: str, + proposed_change: str, + context: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Check for conflicts in proposed changes.""" + node = await self._graph.get_node(entity) + if not node: + candidates = await self._graph.find_nodes(name=entity) + if candidates: + node = candidates[0] + + if not node: + return {"conflict": False, "reason": f"Entity {entity} not found"} + + edges = await self._graph.get_edges(node.id, direction="incoming") + callers = [e.source_id for e in edges if e.type == EdgeType.CALLS] + + conflicts: list[str] = [] + warnings: list[str] = [] + + if len(callers) > 5: + conflicts.append(f"Entity has {len(callers)} callers - high blast radius") + + if 
node.semantic.manually_set: + warnings.append("Entity has manually set annotations - may need re-annotation") + + if context and context.get("session_id"): + locked_files = context.get("locked_files", []) + if node.file_path in locked_files: + conflicts.append(f"File {node.file_path} is locked by another session") + + return { + "entity": entity, + "proposed_change": proposed_change, + "conflict": len(conflicts) > 0, + "conflicts": conflicts, + "warnings": warnings, + "caller_count": len(callers), + } + + async def why( + self, + entity: str, + relationship: str = "", + depth: int = 3, + ) -> dict[str, Any]: + """Explain why a relationship exists between entities.""" + node = await self._graph.get_node(entity) + + # If exact match fails, try to find by file path or name + if not node: + if "/" in entity or entity.endswith(".py"): + candidates = await self._graph.find_nodes(file_path=entity) + if candidates: + node = candidates[0] + else: + candidates = await self._graph.find_nodes(name=entity) + if candidates: + node = candidates[0] + + if not node: + return {"error": f"Entity {entity} not found"} + + reasons: list[dict[str, Any]] = [] + + incoming = await self._graph.get_edges(node.id, direction="incoming") + outgoing = await self._graph.get_edges(node.id, direction="outgoing") + + for edge in incoming[:depth]: + source = await self._graph.get_node(edge.source_id) + if source: + reasons.append( + { + "type": "incoming", + "edge_type": edge.type.value, + "from": source.structural.name, + "file": source.file_path, + "reason": f"{source.structural.name} {edge.type.value} {node.structural.name}", + } + ) + + for edge in outgoing[:depth]: + target = await self._graph.get_node(edge.target_id) + if target: + reasons.append( + { + "type": "outgoing", + "edge_type": edge.type.value, + "to": target.structural.name, + "file": target.file_path, + "reason": f"{node.structural.name} {edge.type.value} {target.structural.name}", + } + ) + + return { + "entity": entity, + "name": node.structural.name, + "file": node.file_path, + "reasons": reasons, + "total_relationships": len(incoming) + len(outgoing), + } + + async def diff_file( + self, + file_path: str, + proposed_content: str | None = None, + ) -> dict[str, Any]: + """Compare current graph state of a file against proposed new content.""" + current_nodes = await self._graph.find_nodes(file_path=file_path) + current_node_ids = {n.id for n in current_nodes} + current_calls: dict[str, set[str]] = {n.id: set() for n in current_nodes} + + for node in current_nodes: + edges = await self._graph.get_edges(node.id, direction="outgoing") + for e in edges: + if e.type == EdgeType.CALLS: + current_calls[node.id].add(e.target_id) + + if proposed_content: + from smp.parser.base import detect_language + from smp.parser.registry import ParserRegistry + + registry = ParserRegistry() + lang = detect_language(file_path) + parser = registry.get(lang) + if not parser: + from smp.core.models import Language + + parser = registry.get(Language.PYTHON) + if parser: + proposed_data = parser.parse(proposed_content, file_path) + proposed_node_ids = {n.id for n in proposed_data.nodes} + else: + proposed_node_ids = current_node_ids + else: + proposed_node_ids = current_node_ids + + nodes_added = list(proposed_node_ids - current_node_ids) + nodes_removed = list(current_node_ids - proposed_node_ids) + nodes_modified: list[str] = [] + + return { + "nodes_added": nodes_added, + "nodes_removed": nodes_removed, + "nodes_modified": nodes_modified, + "relationships_added": [], + 
"relationships_removed": [], + } + + async def plan_multi_file( + self, + session_id: str, + task: str, + intended_writes: list[str], + ) -> dict[str, Any]: + """Validate and rank a multi-file task before execution.""" + file_dependencies: dict[str, set[str]] = {} + + for file_path in intended_writes: + nodes = await self._graph.find_nodes(file_path=file_path) + deps = set() + for node in nodes: + edges = await self._graph.get_edges(node.id, direction="outgoing") + for e in edges: + if e.type == EdgeType.CALLS: + deps.add(e.target_id) + file_dependencies[file_path] = deps + + execution_order = [] + for i, file_path in enumerate(intended_writes, 1): + current_nodes = await self._graph.find_nodes(file_path=file_path) + dependants = 0 + for fp in intended_writes: + if fp != file_path: + for n in current_nodes: + if n.id in file_dependencies.get(fp, set()): + dependants += 1 + + outgoing = [] + for n in current_nodes: + edges = await self._graph.get_edges(n.id, direction="outgoing") + outgoing.extend([e.target_id for e in edges]) + + execution_order.append( + { + "step": i, + "file": file_path, + "dependants_in_plan": dependants, + "dependencies_in_plan": len(outgoing), + "blast_radius": dependants, + "risk_level": "high" if dependants > 3 else "medium" if dependants > 0 else "low", + } + ) + + return { + "execution_order": execution_order, + "inter_file_conflicts": [], + "external_files_at_risk": [], + } + + async def detect_conflict( + self, + session_a: str, + session_b: str, + ) -> dict[str, Any]: + """Detect scope overlap between two planned sessions.""" + return { + "has_conflict": False, + "overlapping_files": [], + "conflicting_nodes": [], + } diff --git a/smp/engine/runtime_linker.py b/smp/engine/runtime_linker.py new file mode 100644 index 0000000..83643a8 --- /dev/null +++ b/smp/engine/runtime_linker.py @@ -0,0 +1,212 @@ +"""Runtime linker for tracking actual execution paths. + +Records CALLS_RUNTIME edges based on telemetry data to build +a runtime call graph that complements the static analysis. 
+""" + +from __future__ import annotations + +import uuid +from collections import defaultdict +from dataclasses import dataclass +from datetime import UTC, datetime +from typing import Any + +from smp.core.models import EdgeType, GraphEdge, RuntimeEdge, RuntimeTrace +from smp.logging import get_logger +from smp.store.interfaces import GraphStore + +log = get_logger(__name__) + + +@dataclass +class RuntimeCall: + """A single runtime call observation.""" + + source_id: str + target_id: str + timestamp: str + session_id: str + duration_ms: int = 0 + + +class RuntimeLinker: + """Tracks and records runtime execution paths.""" + + def __init__(self) -> None: + self._calls: list[RuntimeCall] = [] + self._traces: dict[str, RuntimeTrace] = {} + self._session_traces: dict[str, list[str]] = defaultdict(list) + self._call_counts: dict[tuple[str, str], int] = defaultdict(int) + + def record_call( + self, + source_id: str, + target_id: str, + session_id: str, + duration_ms: int = 0, + ) -> RuntimeEdge: + """Record a runtime call observation.""" + trace_id = f"trace_{uuid.uuid4().hex[:8]}" + timestamp = datetime.now(UTC).isoformat() + + call = RuntimeCall( + source_id=source_id, + target_id=target_id, + timestamp=timestamp, + session_id=session_id, + duration_ms=duration_ms, + ) + self._calls.append(call) + + key = (source_id, target_id) + self._call_counts[key] += 1 + + edge = RuntimeEdge( + source_id=source_id, + target_id=target_id, + edge_type="CALLS_RUNTIME", + timestamp=timestamp, + session_id=session_id, + trace_id=trace_id, + duration_ms=duration_ms, + ) + + self._session_traces[session_id].append(trace_id) + + log.debug( + "runtime_call_recorded", + source=source_id, + target=target_id, + session=session_id, + ) + return edge + + def start_trace( + self, + session_id: str, + agent_id: str, + ) -> str: + """Start a new runtime trace.""" + trace_id = f"trc_{uuid.uuid4().hex[:8]}" + timestamp = datetime.now(UTC).isoformat() + + trace = RuntimeTrace( + trace_id=trace_id, + session_id=session_id, + agent_id=agent_id, + started_at=timestamp, + ) + self._traces[trace_id] = trace + + log.info("trace_started", trace_id=trace_id, session=session_id) + return trace_id + + def end_trace(self, trace_id: str) -> RuntimeTrace | None: + """End a runtime trace.""" + trace = self._traces.get(trace_id) + if not trace: + return None + + trace.ended_at = datetime.now(UTC).isoformat() + + related_calls = [c for c in self._calls if c.session_id == trace.session_id] + trace.edges = [ + RuntimeEdge( + source_id=c.source_id, + target_id=c.target_id, + edge_type="CALLS_RUNTIME", + timestamp=c.timestamp, + session_id=c.session_id, + trace_id=trace_id, + duration_ms=c.duration_ms, + ) + for c in related_calls + ] + + visited: set[str] = set() + for edge in trace.edges: + visited.add(edge.source_id) + visited.add(edge.target_id) + trace.nodes_visited = list(visited) + + log.info( + "trace_ended", + trace_id=trace_id, + edges=len(trace.edges), + nodes=len(trace.nodes_visited), + ) + return trace + + def get_trace(self, trace_id: str) -> RuntimeTrace | None: + """Get trace by ID.""" + return self._traces.get(trace_id) + + def get_session_traces(self, session_id: str) -> list[RuntimeTrace]: + """Get all traces for a session.""" + trace_ids = self._session_traces.get(session_id, []) + return [self._traces[tid] for tid in trace_ids if tid in self._traces] + + def get_hot_paths(self, threshold: int = 10) -> list[dict[str, Any]]: + """Return frequently executed paths.""" + hot = [] + + for (source, target), count in 
self._call_counts.items(): + if count >= threshold: + hot.append( + { + "source_id": source, + "target_id": target, + "call_count": count, + } + ) + + hot.sort(key=lambda x: -int(x["call_count"])) + return hot + + def get_stats(self) -> dict[str, Any]: + """Return runtime linker statistics.""" + return { + "total_calls": len(self._calls), + "unique_paths": len(self._call_counts), + "active_traces": len(self._traces), + "sessions_with_traces": len(self._session_traces), + } + + def clear(self) -> None: + """Clear all runtime data.""" + self._calls.clear() + self._traces.clear() + self._session_traces.clear() + self._call_counts.clear() + log.info("runtime_linker_cleared") + + async def inject_runtime_edges(self, graph_store: GraphStore) -> int: + """Inject recorded runtime calls as edges into the graph store. + + Args: + graph_store: The graph store to update. + + Returns: + Number of edges injected. + """ + edges_to_inject: list[GraphEdge] = [] + + for call in self._calls: + edge = GraphEdge( + source_id=call.source_id, + target_id=call.target_id, + type=EdgeType.CALLS_RUNTIME, + metadata={ + "timestamp": call.timestamp, + "session_id": call.session_id, + "duration_ms": str(call.duration_ms), + }, + ) + edges_to_inject.append(edge) + + if edges_to_inject: + await graph_store.upsert_edges(edges_to_inject) + + log.info("runtime_edges_injected", count=len(edges_to_inject)) + return len(edges_to_inject) diff --git a/smp/engine/safety.py b/smp/engine/safety.py new file mode 100644 index 0000000..ccfbf28 --- /dev/null +++ b/smp/engine/safety.py @@ -0,0 +1,590 @@ +"""Agent Safety Protocol — sessions, guards, dry-runs, locks, checkpoints, audit. + +Implements the full SMP(3) agent write lifecycle: + session/open → guard/check → dryrun → checkpoint → write → update → session/close +""" + +from __future__ import annotations + +import uuid +from dataclasses import dataclass, field +from datetime import UTC, datetime +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from smp.logging import get_logger + +if TYPE_CHECKING: + from smp.store.interfaces import GraphStore + +log = get_logger(__name__) + +_SESSION_TTL_SECONDS = 3600 + + +# --------------------------------------------------------------------------- +# Data classes +# --------------------------------------------------------------------------- + + +@dataclass +class Session: + """Represents an active agent session.""" + + session_id: str + agent_id: str + task: str + scope: list[str] + mode: str + granted_scope: list[str] + denied_scope: list[str] + opened_at: str + expires_at: str + status: str = "open" + files_written: list[str] = field(default_factory=list) + files_read: list[str] = field(default_factory=list) + + +@dataclass +class AuditEvent: + """A single event in the audit log.""" + + timestamp: str + method: str + target: str = "" + result: str = "" + checkpoint_id: str = "" + files: list[str] = field(default_factory=list) + + +@dataclass +class AuditLog: + """Full audit record for a session.""" + + audit_log_id: str + agent_id: str + task: str + session_id: str + opened_at: str + closed_at: str = "" + status: str = "open" + events: list[AuditEvent] = field(default_factory=list) + + +@dataclass +class Checkpoint: + """Snapshot of files before a write.""" + + checkpoint_id: str + session_id: str + files: dict[str, str] + snapshot_at: str + + +# --------------------------------------------------------------------------- +# Session Manager +# --------------------------------------------------------------------------- + + 
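+# Typical lifecycle (sketch only; agent ID, task, and paths are illustrative):
+#
+#     manager = SessionManager()
+#     grant = await manager.open_session("agent_1", "refactor auth", ["src/auth/"], mode="write")
+#     assert await manager.is_in_scope(grant["session_id"], "src/auth/manager.py")
+#     summary = await manager.close_session(grant["session_id"])
+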
+class SessionManager: + """Manages agent session lifecycle with scope enforcement and auto-expiry.""" + + def __init__( + self, + ttl_seconds: int = _SESSION_TTL_SECONDS, + graph_store: GraphStore | None = None, + ) -> None: + self._sessions: dict[str, Session] = {} + self._ttl = ttl_seconds + self._graph = graph_store + + def set_graph_store(self, graph_store: GraphStore) -> None: + """Set the graph store for session persistence.""" + self._graph = graph_store + + async def _persist_session(self, session: Session) -> None: + """Persist session to graph if available.""" + if self._graph: + await self._graph.upsert_session(session) + + async def _load_session(self, session_id: str) -> Session | None: + """Load session from graph if available.""" + if self._graph: + data = await self._graph.get_session(session_id) + if data: + return Session( + session_id=data["session_id"], + agent_id=data["agent_id"], + task=data["task"], + scope=data.get("scope", []), + mode=data.get("mode", "read"), + granted_scope=data.get("granted_scope", []), + denied_scope=data.get("denied_scope", []), + opened_at=data["opened_at"], + expires_at=data["expires_at"], + status=data.get("status", "open"), + files_written=data.get("files_written", []), + files_read=data.get("files_read", []), + ) + return None + + async def open_session( + self, + agent_id: str, + task: str, + scope: list[str], + mode: str = "read", + ) -> dict[str, Any]: + """Open a new session and return the result dict.""" + session_id = f"ses_{uuid.uuid4().hex[:6]}" + now = datetime.now(UTC) + expires = now.timestamp() + self._ttl + + granted = [] + denied = [] + warnings = [] + + for path in scope: + p = Path(path) + if p.exists() or not p.suffix: + granted.append(path) + else: + denied.append(path) + + for path in granted: + caller_count = 0 + if caller_count > 10: + warnings.append(f"{path} is imported by {caller_count} files — changes have wide blast radius") + + session = Session( + session_id=session_id, + agent_id=agent_id, + task=task, + scope=scope, + mode=mode, + granted_scope=granted, + denied_scope=denied, + opened_at=now.isoformat(), + expires_at=datetime.fromtimestamp(expires, tz=UTC).isoformat(), + ) + self._sessions[session_id] = session + await self._persist_session(session) + + log.info("session_opened", session_id=session_id, agent_id=agent_id, mode=mode) + return { + "session_id": session_id, + "granted_scope": granted, + "denied_scope": denied, + "active_locks": [], + "warnings": warnings, + "expires_at": session.expires_at, + } + + async def close_session(self, session_id: str, status: str = "completed") -> dict[str, Any] | None: + """Close a session and return summary.""" + session = self._sessions.get(session_id) + if not session: + return None + + session.status = status + now = datetime.now(UTC) + opened = datetime.fromisoformat(session.opened_at) + duration_ms = int((now - opened).total_seconds() * 1000) + + audit_log_id = f"aud_{uuid.uuid4().hex[:6]}" + + log.info("session_closed", session_id=session_id, status=status, duration_ms=duration_ms) + + if self._graph: + await self._graph.delete_session(session_id) + + return { + "session_id": session_id, + "files_written": session.files_written, + "files_read": session.files_read, + "duration_ms": duration_ms, + "audit_log_id": audit_log_id, + } + + async def get_session(self, session_id: str) -> Session | None: + """Get a session by ID, checking expiry.""" + session = self._sessions.get(session_id) + if not session: + return await self._load_session(session_id) + if 
session.status != "open": + return None + expires = datetime.fromisoformat(session.expires_at) + if datetime.now(UTC) > expires: + session.status = "expired" + return None + return session + + async def is_in_scope(self, session_id: str, file_path: str) -> bool: + """Check if file_path is within the session's granted scope.""" + session = await self.get_session(session_id) + if not session: + return False + return any(file_path == granted or file_path.startswith(granted) for granted in session.granted_scope) + + def record_file_access(self, session_id: str, file_path: str, access_type: str = "read") -> None: + """Record that a file was read or written in this session.""" + session = self._sessions.get(session_id) + if not session: + return + if access_type == "write" and file_path not in session.files_written: + session.files_written.append(file_path) + elif access_type == "read" and file_path not in session.files_read: + session.files_read.append(file_path) + + async def recover_session(self, session_id: str) -> dict[str, Any] | None: + """Recover a session from persistent storage.""" + session = await self._load_session(session_id) + if not session: + return None + self._sessions[session_id] = session + log.info("session_recovered", session_id=session_id) + return { + "session_id": session.session_id, + "agent_id": session.agent_id, + "task": session.task, + "scope": session.scope, + "mode": session.mode, + "opened_at": session.opened_at, + "expires_at": session.expires_at, + "status": session.status, + } + + +# --------------------------------------------------------------------------- +# Lock Manager +# --------------------------------------------------------------------------- + + +class LockManager: + """File-level locking to prevent concurrent writes.""" + + def __init__(self, graph_store: GraphStore | None = None) -> None: + self._locks: dict[str, str] = {} + self._graph = graph_store + + def set_graph_store(self, graph_store: GraphStore) -> None: + """Set the graph store for lock persistence.""" + self._graph = graph_store + + async def acquire(self, session_id: str, files: list[str]) -> dict[str, Any]: + """Acquire locks on files for a session.""" + granted = [] + denied = [] + for f in files: + if f in self._locks: + holder = self._locks[f] + if holder == session_id: + granted.append(f) + else: + denied.append(f) + else: + self._locks[f] = session_id + granted.append(f) + if self._graph: + await self._graph.upsert_lock(f, session_id) + + log.info("locks_acquired", session_id=session_id, granted=len(granted), denied=len(denied)) + return {"granted": granted, "denied": denied} + + async def release(self, session_id: str, files: list[str]) -> None: + """Release locks held by a session.""" + for f in files: + if self._locks.get(f) == session_id: + del self._locks[f] + if self._graph: + await self._graph.release_lock(f, session_id) + log.info("locks_released", session_id=session_id, files=len(files)) + + async def release_all(self, session_id: str) -> None: + """Release all locks held by a session.""" + to_release = [f for f, sid in self._locks.items() if sid == session_id] + for f in to_release: + del self._locks[f] + if self._graph: + await self._graph.release_all_locks(session_id) + + def is_locked(self, file_path: str) -> str | None: + """Return session_id that holds the lock, or None.""" + return self._locks.get(file_path) + + +# --------------------------------------------------------------------------- +# Guard Engine +# 
--------------------------------------------------------------------------- + + +class GuardEngine: + """Pre-flight safety checks before writing a file.""" + + def __init__(self, session_manager: SessionManager, lock_manager: LockManager) -> None: + self._sessions = session_manager + self._locks = lock_manager + + async def check( + self, + session_id: str, + target: str, + intended_change: str = "", + caller_count: int = 0, + has_tests: bool = False, + test_files: list[str] | None = None, + is_public_api: bool = False, + has_downstream: bool = False, + ) -> dict[str, Any]: + """Run pre-flight checks and return verdict.""" + reasons: list[str] = [] + warnings: list[str] = [] + checks: dict[str, Any] = {} + + session = await self._sessions.get_session(session_id) + if not session: + return {"verdict": "blocked", "reasons": ["Session not found or expired"]} + + in_scope = await self._sessions.is_in_scope(session_id, target) + locked_by = self._locks.is_locked(target) + locked_by_other = locked_by is not None and locked_by != session_id + + checks["in_declared_scope"] = in_scope + checks["locked_by_other_agent"] = locked_by_other + checks["has_tests"] = has_tests + checks["test_files"] = test_files or [] + checks["caller_count"] = caller_count + checks["is_public_api"] = is_public_api + checks["has_downstream_services"] = has_downstream + + if not in_scope: + reasons.append("File is outside declared session scope") + if locked_by_other: + reasons.append(f"Locked by session {locked_by}") + + if caller_count > 5: + warnings.append(f"Target has {caller_count} callers — changes will cascade") + if is_public_api: + warnings.append("Target is part of public API — signature changes are breaking") + if not has_tests and caller_count > 0: + warnings.append("No test coverage found — manual verification recommended") + + verdict = "blocked" if reasons else "clear" + + result: dict[str, Any] = { + "verdict": verdict, + "target": target, + "checks": checks, + "warnings": warnings, + } + if reasons: + result["reasons"] = reasons + + log.info("guard_check", target=target, verdict=verdict, session_id=session_id) + return result + + +# --------------------------------------------------------------------------- +# Dry Run Simulator +# --------------------------------------------------------------------------- + + +class DryRunSimulator: + """Simulate structural impact of proposed changes without disk writes.""" + + def __init__(self) -> None: + pass + + def simulate( + self, + session_id: str, + file_path: str, + proposed_content: str, + change_summary: str = "", + current_signature: str = "", + proposed_signature: str = "", + affected_files: list[str] | None = None, + broken_callers: list[dict[str, str]] | None = None, + ) -> dict[str, Any]: + """Simulate the write and return structural delta + verdict.""" + signature_changed = bool(current_signature and proposed_signature and current_signature != proposed_signature) + + nodes_added = 0 + nodes_modified = 1 + nodes_removed = 0 + + risks: list[str] = [] + if signature_changed: + risks.append("Signature change detected — may break callers") + if affected_files: + risks.append(f"{len(affected_files)} files may need updates") + if broken_callers: + for bc in broken_callers: + risks.append( + f"{bc.get('function', '?')} in {bc.get('file', '?')}: {bc.get('reason', 'incompatible change')}" + ) + + verdict = "breaking" if (signature_changed and (broken_callers or affected_files)) else "safe" + + result: dict[str, Any] = { + "structural_delta": { + "nodes_added": 
nodes_added, + "nodes_modified": nodes_modified, + "nodes_removed": nodes_removed, + "signature_changed": signature_changed, + }, + "impact": { + "affected_files": affected_files or [], + "broken_callers": broken_callers or [], + "test_coverage_delta": "unchanged", + }, + "verdict": verdict, + "risks": risks, + } + + log.info("dryrun_complete", file_path=file_path, verdict=verdict, session_id=session_id) + return result + + +# --------------------------------------------------------------------------- +# Checkpoint Manager +# --------------------------------------------------------------------------- + + +class CheckpointManager: + """Snapshot and restore file state.""" + + def __init__(self) -> None: + self._checkpoints: dict[str, Checkpoint] = {} + + def create(self, session_id: str, files: list[str]) -> dict[str, Any]: + """Create a checkpoint by snapshotting file contents.""" + checkpoint_id = f"chk_{uuid.uuid4().hex[:6]}" + now = datetime.now(UTC).isoformat() + + snapshots: dict[str, str] = {} + snapshotted: list[str] = [] + for f in files: + try: + content = Path(f).read_text(encoding="utf-8") + snapshots[f] = content + snapshotted.append(f) + except OSError: + log.warning("checkpoint_file_unreadable", file=f) + + checkpoint = Checkpoint( + checkpoint_id=checkpoint_id, + session_id=session_id, + files=snapshots, + snapshot_at=now, + ) + self._checkpoints[checkpoint_id] = checkpoint + + log.info("checkpoint_created", checkpoint_id=checkpoint_id, files=len(snapshotted)) + return { + "checkpoint_id": checkpoint_id, + "files_snapshotted": snapshotted, + "snapshot_at": now, + } + + def rollback(self, checkpoint_id: str) -> dict[str, Any]: + """Restore files from a checkpoint.""" + checkpoint = self._checkpoints.get(checkpoint_id) + if not checkpoint: + return {"status": "error", "reason": "Checkpoint not found"} + + restored: list[str] = [] + for f, content in checkpoint.files.items(): + try: + Path(f).write_text(content, encoding="utf-8") + restored.append(f) + except OSError as exc: + log.error("rollback_write_failed", file=f, error=str(exc)) + + log.info("rollback_complete", checkpoint_id=checkpoint_id, restored=len(restored)) + return { + "status": "rolled_back", + "files_restored": restored, + "memory_resynced": True, + } + + +# --------------------------------------------------------------------------- +# Audit Logger +# --------------------------------------------------------------------------- + + +class AuditLogger: + """Persistent append-only audit log for session events.""" + + def __init__(self) -> None: + self._logs: dict[str, AuditLog] = {} + + def create_log(self, agent_id: str, task: str, session_id: str) -> str: + """Create a new audit log for a session.""" + audit_log_id = f"aud_{uuid.uuid4().hex[:6]}" + now = datetime.now(UTC).isoformat() + self._logs[audit_log_id] = AuditLog( + audit_log_id=audit_log_id, + agent_id=agent_id, + task=task, + session_id=session_id, + opened_at=now, + ) + return audit_log_id + + def append_event( + self, + audit_log_id: str, + method: str, + target: str = "", + result: str = "", + checkpoint_id: str = "", + files: list[str] | None = None, + ) -> None: + """Append an event to an audit log.""" + log_entry = self._logs.get(audit_log_id) + if not log_entry: + return + event = AuditEvent( + timestamp=datetime.now(UTC).strftime("%H:%M:%S"), + method=method, + target=target, + result=result, + checkpoint_id=checkpoint_id, + files=files or [], + ) + log_entry.events.append(event) + + def close_log(self, audit_log_id: str, status: str = 
"completed") -> None: + """Mark an audit log as closed.""" + log_entry = self._logs.get(audit_log_id) + if log_entry: + log_entry.closed_at = datetime.now(UTC).isoformat() + log_entry.status = status + + def get_log(self, audit_log_id: str) -> dict[str, Any] | None: + """Retrieve an audit log.""" + log_entry = self._logs.get(audit_log_id) + if not log_entry: + return None + return { + "audit_log_id": log_entry.audit_log_id, + "agent_id": log_entry.agent_id, + "task": log_entry.task, + "session_id": log_entry.session_id, + "opened_at": log_entry.opened_at, + "closed_at": log_entry.closed_at, + "status": log_entry.status, + "events": [ + { + "t": e.timestamp, + "method": e.method, + "target": e.target, + "result": e.result, + "checkpoint_id": e.checkpoint_id, + "files": e.files, + } + for e in log_entry.events + ], + } diff --git a/smp/engine/seed_walk.py b/smp/engine/seed_walk.py new file mode 100644 index 0000000..84154da --- /dev/null +++ b/smp/engine/seed_walk.py @@ -0,0 +1,465 @@ +"""SeedWalkEngine — community-routed graph RAG pipeline for smp/locate. + +Phase 0 — ROUTE: Compare query embedding against community centroids. +Phase 1 — SEED: ChromaDB vector search scoped to community or global. +Phase 2 — WALK: Graph traversal from seeds via CALLS/IMPORTS/DEFINES edges. +Phase 3 — RANK: Composite score = alpha*vector + beta*pagerank + gamma*heat. +Phase 4 — ASSEMBLE: Deduplicated results + structural map. + +No LLM calls at any phase. +""" + +from __future__ import annotations + +from collections import deque +from typing import Any + +import msgspec + +from smp.engine.interfaces import QueryEngine as QueryEngineInterface +from smp.logging import get_logger +from smp.store.interfaces import GraphStore, VectorStore + +log = get_logger(__name__) + +ALPHA = 0.50 +BETA = 0.30 +GAMMA = 0.20 +ROUTE_CONFIDENCE_THRESHOLD = 0.65 +DEFAULT_SEED_K = 3 +DEFAULT_HOPS = 2 +DEFAULT_TOP_K = 10 + + +class SeedNode(msgspec.Struct, frozen=True): + node_id: str = "" + node_type: str = "" + name: str = "" + file: str = "" + signature: str = "" + docstring: str | None = None + tags: list[str] = msgspec.field(default_factory=list) + community_id: str | None = None + vector_score: float = 0.0 + pagerank: float = 0.0 + heat_score: int = 0 + + +class WalkNode(msgspec.Struct, frozen=True): + node_id: str = "" + node_type: str = "" + name: str = "" + file: str = "" + signature: str = "" + docstring: str | None = None + community_id: str | None = None + edge_type: str = "" + edge_direction: str = "" + hop: int = 0 + is_bridge: bool = False + pagerank: float = 0.0 + heat_score: int = 0 + + +class RankedResult(msgspec.Struct, frozen=True): + node_id: str = "" + node_type: str = "" + name: str = "" + file: str = "" + signature: str = "" + docstring: str | None = None + tags: list[str] = msgspec.field(default_factory=list) + community_id: str | None = None + final_score: float = 0.0 + vector_score: float = 0.0 + pagerank: float = 0.0 + heat_score: int = 0 + is_seed: bool = False + reachable_from: list[str] = msgspec.field(default_factory=list) + + +class LocateResponse(msgspec.Struct, frozen=True): + query: str = "" + routed_community: str | None = None + seed_count: int = 0 + total_walked: int = 0 + results: list[RankedResult] = msgspec.field(default_factory=list) + structural_map: list[dict[str, Any]] = msgspec.field(default_factory=list) + + +class SeedWalkEngine(QueryEngineInterface): + """Community-routed graph RAG pipeline for smp/locate.""" + + def __init__( + self, + graph_store: GraphStore, + vector_store: 
VectorStore | None = None, + enricher: Any | None = None, + alpha: float = ALPHA, + beta: float = BETA, + gamma: float = GAMMA, + route_threshold: float = ROUTE_CONFIDENCE_THRESHOLD, + delegate: QueryEngineInterface | None = None, + ) -> None: + self._graph = graph_store + self._vector = vector_store + self._enricher = enricher + self._alpha = alpha + self._beta = beta + self._gamma = gamma + self._route_threshold = route_threshold + self._delegate = delegate + + async def _route_to_community(self, query: str) -> tuple[str | None, float]: + if self._vector is None: + return None, 0.0 + try: + results = await self._vector.query( + embedding=_simple_hash_embedding(query), + top_k=1, + where={"collection_type": "centroid"}, + ) + if not results: + return None, 0.0 + best = results[0] + community_id = best.get("metadata", {}).get("community_id") + score = best.get("score", 1.0) + if isinstance(score, (int, float)): + confidence = 1.0 - float(score) + else: + confidence = 0.0 + if confidence < self._route_threshold: + return None, confidence + return community_id, confidence + except Exception: + log.warning("route_community_failed", query=query) + return None, 0.0 + + async def _seed( + self, + query: str, + seed_k: int, + community_id: str | None = None, + ) -> list[SeedNode]: + all_nodes = await self._graph.find_nodes() + terms = query.lower().split() + scored: list[tuple[float, dict[str, Any]]] = [] + for node in all_nodes: + s = 0.0 + name_lower = node.structural.name.lower() + if all(t in name_lower for t in terms): + s += 100.0 + elif any(t in name_lower for t in terms): + s += 50.0 + if node.semantic.docstring: + doc_lower = node.semantic.docstring.lower() + if all(t in doc_lower for t in terms): + s += 30.0 + elif any(t in doc_lower for t in terms): + s += 15.0 + for tag in node.semantic.tags: + if any(t in tag.lower() for t in terms): + s += 10.0 + break + if community_id and hasattr(node.semantic, "tags"): + pass + if s > 0: + scored.append((s, {"node": node, "score": s})) + + if self._vector is not None: + try: + v_results = await self._vector.query( + embedding=_simple_hash_embedding(query), + top_k=seed_k, + ) + for vr in v_results: + node_id = vr.get("id", "") + v_score = vr.get("score", 0.0) + if isinstance(v_score, (int, float)): + v_sim = 1.0 - float(v_score) + else: + v_sim = 0.0 + found = False + for s_item in scored: + if s_item[1].get("node", None) and s_item[1]["node"].id == node_id: + found = True + break + if not found and v_sim > 0.3: + gnode = await self._graph.get_node(node_id) + if gnode: + scored.append((v_sim * 80.0, {"node": gnode, "score": v_sim * 80.0})) + except Exception: + log.warning("vector_seed_failed", query=query) + + scored.sort(key=lambda x: -x[0]) + seeds: list[SeedNode] = [] + for score_val, data in scored[:seed_k]: + node = data["node"] + seeds.append( + SeedNode( + node_id=node.id, + node_type=node.type.value, + name=node.structural.name, + file=node.file_path, + signature=node.structural.signature, + docstring=node.semantic.docstring or None, + tags=node.semantic.tags, + community_id=None, + vector_score=min(score_val / 100.0, 1.0), + pagerank=0.0, + heat_score=0, + ) + ) + return seeds + + async def _walk(self, seed_ids: list[str], hops: int) -> list[WalkNode]: + from smp.core.models import EdgeType + + walked: dict[str, WalkNode] = {} + queue: deque[tuple[str, int]] = deque() + for sid in seed_ids: + queue.append((sid, 0)) + visited: set[str] = set(seed_ids) + + while queue: + current_id, depth = queue.popleft() + if depth >= hops: + continue + 
node = await self._graph.get_node(current_id) + if not node: + continue + try: + edges_out = await self._graph.get_edges(current_id, direction="outgoing") + except Exception: + edges_out = [] + try: + edges_in = await self._graph.get_edges(current_id, direction="incoming") + except Exception: + edges_in = [] + all_edges = edges_out + edges_in + for edge in all_edges: + if edge.type not in (EdgeType.CALLS, EdgeType.CALLS_RUNTIME, EdgeType.IMPORTS, EdgeType.DEFINES): + continue + neighbor_id = edge.target_id if edge.source_id == current_id else edge.source_id + direction = "out" if edge.source_id == current_id else "in" + if neighbor_id in visited: + continue + visited.add(neighbor_id) + neighbor = await self._graph.get_node(neighbor_id) + if not neighbor: + continue + walked[neighbor_id] = WalkNode( + node_id=neighbor_id, + node_type=neighbor.type.value, + name=neighbor.structural.name, + file=neighbor.file_path, + signature=neighbor.structural.signature, + docstring=neighbor.semantic.docstring or None, + community_id=None, + edge_type=edge.type.value, + edge_direction=direction, + hop=depth + 1, + is_bridge=False, + pagerank=0.0, + heat_score=0, + ) + queue.append((neighbor_id, depth + 1)) + return list(walked.values()) + + def _rank( + self, + seeds: list[SeedNode], + walked: list[WalkNode], + top_k: int, + ) -> list[RankedResult]: + seed_map = {s.node_id: s for s in seeds} + max_pr = max((s.pagerank for s in seeds), default=1.0) or 1.0 + walked_max_pr = max((w.pagerank for w in walked), default=1.0) or 1.0 + max_pr = max(max_pr, walked_max_pr) + + results: dict[str, RankedResult] = {} + for s in seeds: + score = ( + self._alpha * s.vector_score + self._beta * (s.pagerank / max_pr) + self._gamma * (s.heat_score / 100.0) + ) + results[s.node_id] = RankedResult( + node_id=s.node_id, + node_type=s.node_type, + name=s.name, + file=s.file, + signature=s.signature, + docstring=s.docstring, + tags=s.tags, + community_id=s.community_id, + final_score=round(score, 4), + vector_score=s.vector_score, + pagerank=s.pagerank, + heat_score=s.heat_score, + is_seed=True, + reachable_from=[s.node_id], + ) + + for w in walked: + if w.node_id in results: + continue + seed_pr = seed_map.get(w.node_id) + v_score = seed_pr.vector_score if seed_pr else 0.0 + score = self._alpha * v_score + self._beta * (w.pagerank / max_pr) + self._gamma * (w.heat_score / 100.0) + results[w.node_id] = RankedResult( + node_id=w.node_id, + node_type=w.node_type, + name=w.name, + file=w.file, + signature=w.signature, + docstring=w.docstring, + tags=[], + community_id=w.community_id, + final_score=round(score, 4), + vector_score=v_score, + pagerank=w.pagerank, + heat_score=w.heat_score, + is_seed=False, + reachable_from=[], + ) + + ranked = sorted(results.values(), key=lambda r: r.final_score, reverse=True) + return ranked[:top_k] + + def _build_structural_map( + self, + results: list[RankedResult], + walked: list[WalkNode], + ) -> list[dict[str, Any]]: + result_ids = {r.node_id for r in results} + edges: list[dict[str, Any]] = [] + for w in walked: + if w.node_id in result_ids: + edges.append( + { + "from": w.node_id, + "to": w.node_id, + "edge_type": w.edge_type, + "hop": w.hop, + } + ) + return edges + + async def locate( + self, + query: str, + fields: list[str] | None = None, + node_types: list[str] | None = None, + top_k: int = DEFAULT_TOP_K, + ) -> list[dict[str, Any]]: + routed_community, route_confidence = await self._route_to_community(query) + seed_k = min(top_k, DEFAULT_SEED_K) + hops = DEFAULT_HOPS + seeds = await 
self._seed(query, seed_k, community_id=routed_community) + if node_types: + seeds = [s for s in seeds if s.node_type in node_types] + walked = await self._walk([s.node_id for s in seeds], hops) + if node_types: + walked = [w for w in walked if w.node_type in node_types] + ranked = self._rank(seeds, walked, top_k) + smap = self._build_structural_map(ranked, walked) + + result = LocateResponse( + query=query, + routed_community=routed_community, + seed_count=len(seeds), + total_walked=len(walked), + results=ranked, + structural_map=smap, + ) + + return [msgspec.structs.asdict(result)] + + async def navigate(self, query: str, include_relationships: bool = True) -> dict[str, Any]: + if self._delegate: + return await self._delegate.navigate(query, include_relationships) + return {} + + async def trace( + self, start: str, relationship: str = "CALLS", depth: int = 3, direction: str = "outgoing" + ) -> list[dict[str, Any]]: + if self._delegate: + return await self._delegate.trace(start, relationship, depth, direction) + return [] + + async def get_context(self, file_path: str, scope: str = "edit", depth: int = 2) -> dict[str, Any]: + if self._delegate: + return await self._delegate.get_context(file_path, scope, depth) + return {} + + async def assess_impact(self, entity: str, change_type: str = "delete") -> dict[str, Any]: + if self._delegate: + return await self._delegate.assess_impact(entity, change_type) + return {} + + async def search( + self, query: str, match: str = "any", filters: dict[str, Any] | None = None, top_k: int = 5 + ) -> dict[str, Any]: + if self._delegate: + return await self._delegate.search(query, match, filters, top_k) + return {} + + async def conflict( + self, + entity: str, + proposed_change: str = "", + context: dict[str, Any] | None = None, + ) -> dict[str, Any]: + if self._delegate: + return await self._delegate.conflict(entity, proposed_change, context) + return {"conflicts": []} + + async def diff( + self, + from_snapshot: str, + to_snapshot: str, + scope: str = "full", + ) -> dict[str, Any]: + if self._delegate: + return await self._delegate.diff(from_snapshot, to_snapshot, scope) + return {"diff": {}} + + async def plan( + self, + change_description: str, + target_file: str = "", + change_type: str = "refactor", + scope: str = "full", + ) -> dict[str, Any]: + if self._delegate: + return await self._delegate.plan(change_description, target_file, change_type, scope) + return {"steps": []} + + async def why( + self, + entity: str, + relationship: str = "", + depth: int = 3, + ) -> dict[str, Any]: + if self._delegate: + return await self._delegate.why(entity, relationship, depth) + return {"reasoning": []} + + async def find_flow(self, start: str, end: str, flow_type: str = "data") -> dict[str, Any]: + if self._delegate: + return await self._delegate.find_flow(start, end, flow_type) + return {} + + +def _simple_hash_embedding(text: str, dim: int = 128) -> list[float]: + """Deterministic hash-based embedding for prototyping. + + Maps text to a fixed-dimension float vector using character + frequency hashing. Production should use a real embedding model. 
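+
+    Illustrative usage (the output is L2-normalised unless all-zero):
+        vec = _simple_hash_embedding("user auth")
+        assert len(vec) == 128
+        assert abs(sum(v * v for v in vec) - 1.0) < 1e-9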
+ """ + vec = [0.0] * dim + for i, ch in enumerate(text): + vec[i % dim] += float(ord(ch)) + norm = sum(v * v for v in vec) ** 0.5 + if norm == 0: + return vec + return [v / norm for v in vec] diff --git a/smp/engine/telemetry.py b/smp/engine/telemetry.py new file mode 100644 index 0000000..ba55c8f --- /dev/null +++ b/smp/engine/telemetry.py @@ -0,0 +1,161 @@ +"""Telemetry engine for tracking node hotness and usage patterns. + +Collects runtime statistics to identify hot code paths and frequently +accessed nodes for optimization and safety decisions. +""" + +from __future__ import annotations + +import time +from dataclasses import dataclass, field +from datetime import UTC, datetime +from typing import Any + +from smp.logging import get_logger + +log = get_logger(__name__) + +_HOT_THRESHOLD = 10 +_HOT_DECAY_SECONDS = 3600 + + +@dataclass +class NodeStats: + """Statistics for a single node.""" + + node_id: str + hit_count: int = 0 + last_hit_at: str = "" + avg_response_time_ms: float = 0.0 + error_count: int = 0 + callers: set[str] = field(default_factory=set) + + def touch(self) -> None: + """Record a hit on this node.""" + self.hit_count += 1 + self.last_hit_at = datetime.now(UTC).isoformat() + + +@dataclass +class TelemetryConfig: + """Configuration for telemetry collection.""" + + hot_threshold: int = _HOT_THRESHOLD + decay_seconds: int = _HOT_DECAY_SECONDS + max_tracked_nodes: int = 10000 + + +class TelemetryEngine: + """Tracks node access patterns and identifies hot nodes.""" + + def __init__(self, config: TelemetryConfig | None = None) -> None: + self._config = config or TelemetryConfig() + self._stats: dict[str, NodeStats] = {} + self._start_time = time.time() + + def record_access( + self, + node_id: str, + caller_id: str | None = None, + response_time_ms: float = 0.0, + error: bool = False, + ) -> None: + """Record an access to a node.""" + stats = self._stats.get(node_id) + if not stats: + if len(self._stats) >= self._config.max_tracked_nodes: + self._evict_cold() + stats = NodeStats(node_id=node_id) + self._stats[node_id] = stats + + stats.touch() + if caller_id: + stats.callers.add(caller_id) + if response_time_ms > 0: + total = stats.avg_response_time_ms * (stats.hit_count - 1) + response_time_ms + stats.avg_response_time_ms = total / stats.hit_count + if error: + stats.error_count += 1 + + log.debug("telemetry_access", node_id=node_id, hit_count=stats.hit_count) + + def get_hot_nodes(self, threshold: int | None = None) -> list[dict[str, Any]]: + """Return nodes exceeding the hot threshold.""" + hot_threshold = threshold or self._config.hot_threshold + hot = [] + + for node_id, stats in self._stats.items(): + if stats.hit_count >= hot_threshold: + hot.append( + { + "node_id": node_id, + "hit_count": stats.hit_count, + "last_hit_at": stats.last_hit_at, + "avg_response_time_ms": stats.avg_response_time_ms, + "error_count": stats.error_count, + "caller_count": len(stats.callers), + } + ) + + hot.sort(key=lambda x: -int(x["hit_count"])) + return hot + + def get_stats(self, node_id: str) -> dict[str, Any] | None: + """Get statistics for a specific node.""" + stats = self._stats.get(node_id) + if not stats: + return None + return { + "node_id": stats.node_id, + "hit_count": stats.hit_count, + "last_hit_at": stats.last_hit_at, + "avg_response_time_ms": stats.avg_response_time_ms, + "error_count": stats.error_count, + "caller_count": len(stats.callers), + } + + def get_summary(self) -> dict[str, Any]: + """Return overall telemetry summary.""" + total_hits = sum(s.hit_count for s in 
self._stats.values()) + total_errors = sum(s.error_count for s in self._stats.values()) + + return { + "uptime_seconds": int(time.time() - self._start_time), + "total_nodes_tracked": len(self._stats), + "total_hits": total_hits, + "total_errors": total_errors, + "hot_node_count": len(self.get_hot_nodes()), + } + + def decay(self) -> int: + """Decay old statistics to prevent unbounded growth.""" + cutoff = datetime.now(UTC).timestamp() - self._config.decay_seconds + cutoff_str = datetime.fromtimestamp(cutoff, tz=UTC).isoformat() + + to_remove = [ + node_id for node_id, stats in self._stats.items() if stats.last_hit_at and stats.last_hit_at < cutoff_str + ] + + for node_id in to_remove: + del self._stats[node_id] + + if to_remove: + log.info("telemetry_decayed", removed=len(to_remove)) + return len(to_remove) + + def _evict_cold(self) -> None: + """Evict the coldest nodes when at capacity.""" + if not self._stats: + return + + sorted_nodes = sorted(self._stats.items(), key=lambda x: x[1].hit_count) + for node_id, _ in sorted_nodes[: len(sorted_nodes) // 10]: + del self._stats[node_id] + + log.debug("telemetry_evicted", count=len(sorted_nodes) // 10) + + def reset(self) -> None: + """Clear all telemetry data.""" + self._stats.clear() + self._start_time = time.time() + log.info("telemetry_reset") diff --git a/smp/logging.py b/smp/logging.py new file mode 100644 index 0000000..bb899f5 --- /dev/null +++ b/smp/logging.py @@ -0,0 +1,68 @@ +"""Structured logging configuration for SMP. + +Usage: + from smp.logging import get_logger + log = get_logger(__name__) + log.info("graph_updated", nodes=42, edges=97) +""" + +from __future__ import annotations + +import logging +import sys + +import structlog + + +def configure_logging(*, json: bool = False, level: str = "INFO") -> None: + """Initialise structlog + stdlib logging. + + Args: + json: When True, render as newline-delimited JSON (production). + When False, render with colours (development). + level: Minimum log level for the root SMP logger. + """ + shared_processors: list[structlog.types.Processor] = [ + structlog.contextvars.merge_contextvars, + structlog.stdlib.add_log_level, + structlog.stdlib.add_logger_name, + structlog.processors.TimeStamper(fmt="iso"), + structlog.processors.StackInfoRenderer(), + structlog.processors.UnicodeDecoder(), + ] + + if json: + renderer: structlog.types.Processor = structlog.processors.JSONRenderer() + else: + renderer = structlog.dev.ConsoleRenderer(colors=True) + + structlog.configure( + processors=[*shared_processors, structlog.stdlib.ProcessorFormatter.wrap_for_formatter], + logger_factory=structlog.stdlib.LoggerFactory(), + wrapper_class=structlog.stdlib.BoundLogger, + cache_logger_on_first_use=True, + ) + + formatter = structlog.stdlib.ProcessorFormatter( + processors=[ + structlog.stdlib.ProcessorFormatter.remove_processors_meta, + renderer, + ], + ) + + handler = logging.StreamHandler(sys.stderr) + handler.setFormatter(formatter) + + root = logging.getLogger("smp") + root.handlers.clear() + root.addHandler(handler) + root.setLevel(level.upper()) + + +def get_logger(name: str) -> structlog.stdlib.BoundLogger: + """Return a bound structlog logger scoped to *name*.""" + return structlog.get_logger(name) + + +# Auto-configure with dev defaults on first import. 
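+# Applications can call configure_logging(json=True, level="DEBUG") again at
+# startup: existing handlers are cleared and replaced, so reconfiguring is safe.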
+configure_logging() diff --git a/smp/parser/__init__.py b/smp/parser/__init__.py new file mode 100644 index 0000000..e45fa64 --- /dev/null +++ b/smp/parser/__init__.py @@ -0,0 +1,11 @@ +"""Parser layer — AST extraction via tree-sitter.""" + +from smp.parser.base import TreeSitterParser, detect_language, make_node_id +from smp.parser.registry import ParserRegistry + +__all__ = [ + "TreeSitterParser", + "detect_language", + "make_node_id", + "ParserRegistry", +] diff --git a/smp/parser/base.py b/smp/parser/base.py new file mode 100644 index 0000000..cfddd40 --- /dev/null +++ b/smp/parser/base.py @@ -0,0 +1,153 @@ +"""Abstract tree-sitter parser and language detection utilities.""" + +from __future__ import annotations + +import abc +from pathlib import Path + +import tree_sitter as ts + +from smp.core.models import Document, GraphEdge, GraphNode, Language, NodeType, ParseError +from smp.engine.interfaces import Parser +from smp.logging import get_logger + +log = get_logger(__name__) + +_EXT_TO_LANG: dict[str, Language] = { + ".py": Language.PYTHON, + ".ts": Language.TYPESCRIPT, + ".tsx": Language.TYPESCRIPT, + ".js": Language.TYPESCRIPT, + ".jsx": Language.TYPESCRIPT, +} + +# Extensions that use the TSX grammar variant +_TSX_EXTS = {".tsx", ".jsx"} + + +def detect_language(file_path: str) -> Language: + """Guess language from file extension.""" + suffix = Path(file_path).suffix.lower() + return _EXT_TO_LANG.get(suffix, Language.UNKNOWN) + + +def is_tsx(file_path: str) -> bool: + """Return True if the file uses JSX/TSX syntax.""" + return Path(file_path).suffix.lower() in _TSX_EXTS + + +def make_node_id(file_path: str, type: NodeType, name: str, start_line: int) -> str: + """Deterministic node ID from structural coordinates.""" + return f"{file_path}::{type.value}::{name}::{start_line}" + + +def node_text(node: ts.Node) -> str: + """Safely extract text from a tree-sitter node.""" + if node.text: + return node.text.decode("utf-8", errors="replace") + return "" + + +def line_range(node: ts.Node) -> tuple[int, int]: + """Return (start_line, end_line) as 1-indexed line numbers.""" + return node.start_point[0] + 1, node.end_point[0] + 1 + + +class TreeSitterParser(Parser, abc.ABC): + """Abstract base for tree-sitter language parsers. + + Subclasses provide the grammar language object and extraction logic. + The base class handles parsing, error recovery, and Document assembly. + """ + + @abc.abstractmethod + def _language(self, file_path: str) -> ts.Language: + """Return the tree-sitter Language object for *file_path*.""" + + @abc.abstractmethod + def _extract( + self, + root_node: ts.Node, + source_bytes: bytes, + file_path: str, + ) -> tuple[list[GraphNode], list[GraphEdge], list[ParseError]]: + """Extract nodes, edges, and errors from a parsed AST. + + Returns a tuple of (nodes, edges, errors). + """ + + @property + @abc.abstractmethod + def supported_languages(self) -> list[str]: ... 
+ + def parse(self, source: str, file_path: str) -> Document: + lang = detect_language(file_path) + if lang == Language.UNKNOWN: + return Document( + file_path=file_path, + language=lang, + errors=[ParseError(message=f"Unsupported language for {file_path}")], + ) + + source_bytes = source.encode("utf-8") + + try: + ts_lang = self._language(file_path) + parser = ts.Parser(ts_lang) + tree = parser.parse(source_bytes) + except Exception as exc: + log.error("parse_crash", file_path=file_path, error=str(exc)) + return Document( + file_path=file_path, + language=lang, + errors=[ParseError(message=f"Parser crash: {exc}")], + ) + + errors: list[ParseError] = [] + nodes: list[GraphNode] = [] + edges: list[GraphEdge] = [] + + try: + nodes, edges, errors = self._extract(tree.root_node, source_bytes, file_path) + except Exception as exc: + log.error("extract_error", file_path=file_path, error=str(exc)) + errors.append(ParseError(message=f"Extraction error: {exc}")) + + # Detect tree-sitter error nodes + self._collect_syntax_errors(tree.root_node, source_bytes, errors) + + log.debug( + "file_parsed", + file_path=file_path, + lang=lang.value, + nodes=len(nodes), + edges=len(edges), + errors=len(errors), + ) + return Document( + file_path=file_path, + language=lang, + nodes=nodes, + edges=edges, + errors=errors, + ) + + @staticmethod + def _collect_syntax_errors( + node: ts.Node, + source: bytes, + errors: list[ParseError], + ) -> None: + """Walk the tree and collect ERROR / MISSING nodes.""" + if node.is_error or node.is_missing: + row, col = node.start_point + text = node.text.decode("utf-8", errors="replace")[:80] if node.text else "" + errors.append( + ParseError( + message=f"Syntax {'missing' if node.is_missing else 'error'}: {text}", + line=row + 1, + column=col, + ) + ) + for child in node.children: + TreeSitterParser._collect_syntax_errors(child, source, errors) diff --git a/smp/parser/python_parser.py b/smp/parser/python_parser.py new file mode 100644 index 0000000..16a72f2 --- /dev/null +++ b/smp/parser/python_parser.py @@ -0,0 +1,553 @@ +"""Python-specific tree-sitter parser. + +Extracts functions, classes, methods, imports, decorators, inline comments, +and type annotations from Python source using the ``tree-sitter-python`` grammar. 
+""" + +from __future__ import annotations + +import re + +import tree_sitter as ts +import tree_sitter_python as tsp + +from smp.core.models import ( + Annotations, + EdgeType, + GraphEdge, + GraphNode, + NodeType, + ParseError, + SemanticProperties, + StructuralProperties, +) +from smp.logging import get_logger +from smp.parser.base import TreeSitterParser, line_range, make_node_id, node_text + +log = get_logger(__name__) + +_LANGUAGE = ts.Language(tsp.language()) + +_CALL_QUERY = ts.Query( + _LANGUAGE, + """ +(call function: (identifier) @callee) @call +(call function: (attribute) @callee) @call +""", +) + +_COMMENT_QUERY = ts.Query( + _LANGUAGE, + """ +(comment) @comment +""", +) + + +def _compute_complexity(body: ts.Node) -> int: + """Estimate cyclomatic complexity from AST body node.""" + complexity = 1 + cursor = body.walk() + stack: list[ts.Node] = [cursor.node] if cursor.node else [] + while stack: + node = stack.pop() + if node.type in ( + "if_statement", + "elif_clause", + "for_statement", + "while_statement", + "conditional_expression", + "boolean_operator", + ): + complexity += 1 + for child in node.children: + stack.append(child) + return complexity + + +class PythonParser(TreeSitterParser): + """Extract structural elements from Python source.""" + + @property + def supported_languages(self) -> list[str]: + return ["python"] + + def _language(self, file_path: str) -> ts.Language: + return _LANGUAGE + + def _extract( + self, + root_node: ts.Node, + source_bytes: bytes, + file_path: str, + ) -> tuple[list[GraphNode], list[GraphEdge], list[ParseError]]: + nodes: list[GraphNode] = [] + edges: list[GraphEdge] = [] + errors: list[ParseError] = [] + seen_ids: set[str] = set() + + file_node = GraphNode( + id=make_node_id(file_path, NodeType.FILE, file_path, 1), + type=NodeType.FILE, + file_path=file_path, + structural=StructuralProperties( + name=file_path, + file=file_path, + start_line=1, + end_line=root_node.end_point[0] + 1, + lines=root_node.end_point[0] + 1, + ), + ) + self._add_node(file_node, nodes, seen_ids) + + self._walk_block( + root_node, + source_bytes, + file_path, + parent_id=file_node.id, + class_name=None, + nodes=nodes, + edges=edges, + errors=errors, + seen_ids=seen_ids, + ) + log.debug("python_parsed", file=file_path, nodes=len(nodes), edges=len(edges), errors=len(errors)) + return nodes, edges, errors + + def _add_node(self, node: GraphNode, nodes: list[GraphNode], seen: set[str]) -> bool: + """Add node if not already seen. 
Returns True if added.""" + if node.id in seen: + return False + seen.add(node.id) + nodes.append(node) + return True + + def _walk_block( + self, + block: ts.Node, + source: bytes, + file_path: str, + parent_id: str, + class_name: str | None, + nodes: list[GraphNode], + edges: list[GraphEdge], + errors: list[ParseError], + seen_ids: set[str], + ) -> None: + """Walk children of a block extracting definitions.""" + self._walk_direct_children( + block, + source, + file_path, + parent_id, + class_name, + nodes, + edges, + errors, + seen_ids, + ) + + def _walk_direct_children( + self, + block: ts.Node, + source: bytes, + file_path: str, + parent_id: str, + class_name: str | None, + nodes: list[GraphNode], + edges: list[GraphEdge], + errors: list[ParseError], + seen_ids: set[str], + ) -> None: + """Walk direct children of a block, processing definitions.""" + for child in block.children: + if child.type == "function_definition": + self._process_function( + child, + source, + file_path, + parent_id, + class_name, + nodes, + edges, + errors, + seen_ids, + [], + ) + elif child.type == "class_definition": + self._process_class( + child, + source, + file_path, + parent_id, + nodes, + edges, + errors, + seen_ids, + [], + ) + elif child.type == "decorated_definition": + decorator_names = self._extract_decorators(child, source) + for sub in child.children: + if sub.type == "function_definition": + self._process_function( + sub, + source, + file_path, + parent_id, + class_name, + nodes, + edges, + errors, + seen_ids, + decorator_names, + ) + break + elif sub.type == "class_definition": + self._process_class( + sub, + source, + file_path, + parent_id, + nodes, + edges, + errors, + seen_ids, + decorator_names, + ) + break + elif child.type in ("import_statement", "import_from_statement"): + self._process_import(child, source, file_path, parent_id, nodes, edges) + elif child.type == "expression_statement": + self._process_assignment(child, source, file_path, parent_id, nodes, edges) + + def _process_function( + self, + func: ts.Node, + source: bytes, + file_path: str, + parent_id: str, + class_name: str | None, + nodes: list[GraphNode], + edges: list[GraphEdge], + errors: list[ParseError], + seen_ids: set[str], + decorator_names: list[str], + ) -> None: + name_node = func.child_by_field_name("name") + if not name_node: + return + name = node_text(name_node) + start, end = line_range(func) + node_type = NodeType.FUNCTION if class_name is None else NodeType.FUNCTION + sig = self._extract_signature(func, source, name) + docstring = self._extract_docstring(func, source) + annotations = self._extract_annotations(func, source) + node_id = make_node_id(file_path, node_type, name, start) + + body = func.child_by_field_name("body") + complexity = _compute_complexity(body) if body else 1 + lines = end - start + 1 + param_count = len(annotations.params) if annotations else 0 + + structural = StructuralProperties( + name=name, + file=file_path, + signature=sig, + start_line=start, + end_line=end, + complexity=complexity, + lines=lines, + parameters=param_count, + ) + + semantic = SemanticProperties( + docstring=docstring, + decorators=decorator_names, + annotations=annotations, + ) + + metadata: dict[str, str] = {} + if class_name: + metadata["class"] = class_name + + node = GraphNode( + id=node_id, + type=node_type, + file_path=file_path, + structural=structural, + semantic=semantic, + ) + if not self._add_node(node, nodes, seen_ids): + return + + edges.append(GraphEdge(source_id=parent_id, target_id=node_id, 
type=EdgeType.DEFINES)) + + if body: + self._extract_calls(body, source, file_path, node_id, nodes, edges) + + def _process_class( + self, + cls: ts.Node, + source: bytes, + file_path: str, + parent_id: str, + nodes: list[GraphNode], + edges: list[GraphEdge], + errors: list[ParseError], + seen_ids: set[str], + decorator_names: list[str], + ) -> None: + name_node = cls.child_by_field_name("name") + if not name_node: + return + name = node_text(name_node) + start, end = line_range(cls) + docstring = self._extract_docstring(cls, source) + bases = self._extract_bases(cls, source) + sig = f"class {name}" + if bases: + sig += f"({', '.join(bases)})" + node_id = make_node_id(file_path, NodeType.CLASS, name, start) + + structural = StructuralProperties( + name=name, + file=file_path, + signature=sig, + start_line=start, + end_line=end, + lines=end - start + 1, + ) + + semantic = SemanticProperties( + docstring=docstring, + decorators=decorator_names, + ) + + node = GraphNode( + id=node_id, + type=NodeType.CLASS, + file_path=file_path, + structural=structural, + semantic=semantic, + ) + if not self._add_node(node, nodes, seen_ids): + return + + edges.append(GraphEdge(source_id=parent_id, target_id=node_id, type=EdgeType.DEFINES)) + + for base in bases: + base_id = make_node_id(file_path, NodeType.INTERFACE, base, 0) + edges.append(GraphEdge(source_id=node_id, target_id=base_id, type=EdgeType.IMPLEMENTS)) + + body = cls.child_by_field_name("body") + if body: + self._walk_block( + body, + source, + file_path, + parent_id=node_id, + class_name=name, + nodes=nodes, + edges=edges, + errors=errors, + seen_ids=seen_ids, + ) + + def _process_assignment( + self, + expr: ts.Node, + source: bytes, + file_path: str, + parent_id: str, + nodes: list[GraphNode], + edges: list[GraphEdge], + ) -> None: + """Process top-level variable assignments.""" + for child in expr.children: + if child.type in ("assignment", "type_alias_statement"): + start, end = line_range(child) + left = child.child_by_field_name("left") or child.child_by_field_name("name") + if not left: + continue + name = node_text(left) + if not name or name.startswith("_"): + continue + node_id = make_node_id(file_path, NodeType.VARIABLE, name, start) + structural = StructuralProperties( + name=name, + file=file_path, + signature=node_text(child), + start_line=start, + end_line=end, + lines=end - start + 1, + ) + node = GraphNode( + id=node_id, + type=NodeType.VARIABLE, + file_path=file_path, + structural=structural, + ) + nodes.append(node) + edges.append(GraphEdge(source_id=parent_id, target_id=node_id, type=EdgeType.DEFINES)) + + def _process_import( + self, + imp: ts.Node, + source: bytes, + file_path: str, + parent_id: str, + nodes: list[GraphNode], + edges: list[GraphEdge], + ) -> None: + start, end = line_range(imp) + text = node_text(imp).strip() + if imp.type == "import_from_statement": + module_name_node = imp.child_by_field_name("module_name") + module = node_text(module_name_node) if module_name_node else text + else: + module = text.replace("import ", "").split(",")[0].strip() + + node_id = make_node_id(file_path, NodeType.FILE, module, start) + structural = StructuralProperties( + name=module, + file=file_path, + signature=text, + start_line=start, + end_line=end, + lines=end - start + 1, + ) + node = GraphNode( + id=node_id, + type=NodeType.FILE, + file_path=file_path, + structural=structural, + ) + nodes.append(node) + edges.append(GraphEdge(source_id=parent_id, target_id=node_id, type=EdgeType.IMPORTS)) + + def _extract_calls( + self, + 
body: ts.Node, + source: bytes, + file_path: str, + caller_id: str, + nodes: list[GraphNode], + edges: list[GraphEdge], + ) -> None: + cursor = ts.QueryCursor(_CALL_QUERY) + seen_edges: set[tuple[str, str]] = set() + for _, caps in cursor.matches(body): + call_nodes = caps.get("call") + callee_nodes = caps.get("callee") + if not callee_nodes or not call_nodes: + continue + callee_name = node_text(callee_nodes[0]) + call_node = call_nodes[0] + start, _ = line_range(call_node) + target_id = make_node_id(file_path, NodeType.FUNCTION, callee_name, 0) + edge_key = (caller_id, target_id) + if edge_key in seen_edges: + continue + seen_edges.add(edge_key) + edges.append( + GraphEdge( + source_id=caller_id, + target_id=target_id, + type=EdgeType.CALLS, + metadata={"line": str(start)}, + ) + ) + + def _extract_decorators(self, decorated: ts.Node, source: bytes) -> list[str]: + names: list[str] = [] + for child in decorated.children: + if child.type == "decorator": + text = node_text(child).lstrip("@").strip() + if "(" in text: + text = text[: text.index("(")] + names.append(text) + return names + + def _extract_bases(self, cls: ts.Node, source: bytes) -> list[str]: + bases: list[str] = [] + arg_list = cls.child_by_field_name("superclasses") + if not arg_list: + for child in cls.children: + if child.type == "argument_list": + arg_list = child + break + if arg_list: + for child in arg_list.children: + if child.type == "identifier": + bases.append(node_text(child)) + return bases + + def _extract_signature(self, func: ts.Node, source: bytes, name: str) -> str: + params = func.child_by_field_name("parameters") + param_text = node_text(params) if params else "()" + return_type = "" + for child in func.children: + if child.type == "type": + return_type = f" -> {node_text(child)}" + break + return f"def {name}{param_text}{return_type}" + + def _extract_annotations(self, func: ts.Node, source: bytes) -> Annotations: + """Extract structured type annotations from a function.""" + params_dict: dict[str, str] = {} + returns: str | None = None + throws: list[str] = [] + + params_node = func.child_by_field_name("parameters") + if params_node: + for child in params_node.children: + if child.type == "identifier": + pname = node_text(child) + if pname in ("self", "cls"): + continue + params_dict[pname] = "Any" + elif child.type == "typed_parameter": + # In tree-sitter-python, typed_parameter has 'identifier' and 'type' as direct children + ident = None + type_node = None + for sub in child.children: + if sub.type == "identifier": + ident = sub + elif sub.type == "type": + type_node = sub + pname = node_text(ident) if ident else "" + ptype = node_text(type_node) if type_node else "Any" + if pname and pname not in ("self", "cls"): + params_dict[pname] = ptype + + for child in func.children: + if child.type == "type": + returns = node_text(child) + break + + body = func.child_by_field_name("body") + if body: + body_text = node_text(body) + raise_matches = re.findall(r"raise\s+(\w+)", body_text) + throws = list(dict.fromkeys(raise_matches)) + + return Annotations(params=params_dict, returns=returns, throws=throws) + + def _extract_docstring(self, func_or_class: ts.Node, source: bytes) -> str: + body = func_or_class.child_by_field_name("body") + if not body: + return "" + for child in body.children: + if child.type == "expression_statement": + for sub in child.children: + if sub.type == "string": + text = node_text(sub) + for quote in ('"""', "'''", '"', "'"): + if text.startswith(quote) and text.endswith(quote): + text 
= text[len(quote) : -len(quote)] + break + return text.strip() + else: + break + return "" diff --git a/smp/parser/registry.py b/smp/parser/registry.py new file mode 100644 index 0000000..9c4e59f --- /dev/null +++ b/smp/parser/registry.py @@ -0,0 +1,72 @@ +"""Parser registry — dispatches to the correct language parser.""" + +from __future__ import annotations + +from pathlib import Path + +from smp.core.models import Document, Language +from smp.logging import get_logger +from smp.parser.base import TreeSitterParser, detect_language + +log = get_logger(__name__) + + +class ParserRegistry: + """Lazy-initialised registry of language-specific parsers.""" + + def __init__(self) -> None: + self._parsers: dict[Language, TreeSitterParser] = {} + + def _ensure_parser(self, language: Language) -> TreeSitterParser | None: + if language in self._parsers: + return self._parsers[language] + + parser: TreeSitterParser | None = None + + if language == Language.PYTHON: + from smp.parser.python_parser import PythonParser + + parser = PythonParser() + elif language == Language.TYPESCRIPT: + from smp.parser.typescript_parser import TypeScriptParser + + parser = TypeScriptParser() + + if parser: + self._parsers[language] = parser + log.debug("parser_registered", language=language.value) + return parser + + def get(self, language: Language) -> TreeSitterParser | None: + """Return the parser for *language*, or ``None`` if unsupported.""" + return self._ensure_parser(language) + + def parse_file(self, file_path: str) -> Document: + """Detect language, read file, and parse. + + Returns a Document with nodes, edges, and errors. + """ + lang = detect_language(file_path) + parser = self.get(lang) + if not parser: + from smp.core.models import ParseError + + return Document( + file_path=file_path, + language=lang, + errors=[ParseError(message=f"No parser available for {lang.value}")], + ) + + try: + source = Path(file_path).read_text(encoding="utf-8", errors="replace") + except OSError as exc: + from smp.core.models import ParseError + + log.error("file_read_error", file_path=file_path, error=str(exc)) + return Document( + file_path=file_path, + language=lang, + errors=[ParseError(message=f"Cannot read file: {exc}")], + ) + + return parser.parse(source, file_path) diff --git a/smp/parser/typescript_parser.py b/smp/parser/typescript_parser.py new file mode 100644 index 0000000..a339f05 --- /dev/null +++ b/smp/parser/typescript_parser.py @@ -0,0 +1,525 @@ +"""TypeScript-specific tree-sitter parser. + +Extracts functions, classes, interfaces, methods, imports, arrow functions, +and call edges from TypeScript / TSX source using ``tree-sitter-typescript``. +Updated for SMP(3) partitioned model. 
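+
+Illustrative usage (``source_text`` stands in for any TS/TSX source string):
+
+    doc = TypeScriptParser().parse(source_text, "src/App.tsx")
+
+Files ending in .tsx or .jsx are routed to the TSX grammar variant.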
+""" + +from __future__ import annotations + +import tree_sitter as ts +import tree_sitter_typescript as tst + +from smp.core.models import ( + EdgeType, + GraphEdge, + GraphNode, + NodeType, + ParseError, + StructuralProperties, +) +from smp.logging import get_logger +from smp.parser.base import TreeSitterParser, is_tsx, line_range, make_node_id, node_text + +log = get_logger(__name__) + +_TS_LANG = ts.Language(tst.language_typescript()) +_TSX_LANG = ts.Language(tst.language_tsx()) + +_QUERY_STRINGS = { + "top": """ +(function_declaration name: (identifier) @name) @func +(class_declaration name: (type_identifier) @name) @class +(interface_declaration name: (type_identifier) @name) @interface +(import_statement) @import +(export_statement) @export +""", + "arrow": """ +(lexical_declaration (variable_declarator name: (identifier) @name value: (arrow_function) @arrow)) @var +""", + "method": """ +(method_definition name: (property_identifier) @name) @method +""", + "call": """ +(call_expression function: (identifier) @callee) @call +(call_expression function: (member_expression property: (property_identifier) @callee)) @call +""", +} + +_query_cache: dict[str, dict[str, ts.Query]] = {"ts": {}, "tsx": {}} + + +def _get_queries(lang: ts.Language) -> dict[str, ts.Query]: + key = "tsx" if lang is _TSX_LANG else "ts" + if not _query_cache[key]: + for name, qstr in _QUERY_STRINGS.items(): + _query_cache[key][name] = ts.Query(lang, qstr) + return _query_cache[key] + + +class TypeScriptParser(TreeSitterParser): + """Extract structural elements from TypeScript / TSX source.""" + + @property + def supported_languages(self) -> list[str]: + return ["typescript"] + + def _language(self, file_path: str) -> ts.Language: + return _TSX_LANG if is_tsx(file_path) else _TS_LANG + + def _extract( + self, + root_node: ts.Node, + source_bytes: bytes, + file_path: str, + ) -> tuple[list[GraphNode], list[GraphEdge], list[ParseError]]: + nodes: list[GraphNode] = [] + edges: list[GraphEdge] = [] + errors: list[ParseError] = [] + seen_ids: set[str] = set() + + file_node = GraphNode( + id=make_node_id(file_path, NodeType.FILE, file_path, 1), + type=NodeType.FILE, + file_path=file_path, + structural=StructuralProperties( + name=file_path, + file=file_path, + start_line=1, + end_line=root_node.end_point[0] + 1, + lines=root_node.end_point[0] + 1, + ), + ) + self._add_node(file_node, nodes, seen_ids) + + self._walk_block( + root_node, + source_bytes, + file_path, + self._language(file_path), + parent_id=file_node.id, + class_name=None, + nodes=nodes, + edges=edges, + errors=errors, + seen_ids=seen_ids, + ) + log.debug("typescript_parsed", file=file_path, nodes=len(nodes), edges=len(edges), errors=len(errors)) + return nodes, edges, errors + + def _add_node(self, node: GraphNode, nodes: list[GraphNode], seen: set[str]) -> bool: + if node.id in seen: + return False + seen.add(node.id) + nodes.append(node) + return True + + def _walk_block( + self, + block: ts.Node, + source: bytes, + file_path: str, + lang: ts.Language, + parent_id: str, + class_name: str | None, + nodes: list[GraphNode], + edges: list[GraphEdge], + errors: list[ParseError], + seen_ids: set[str], + ) -> None: + queries = _get_queries(lang) + cursor = ts.QueryCursor(queries["top"]) + for _idx, caps in cursor.matches(block): + func_nodes = caps.get("func") + class_nodes = caps.get("class") + iface_nodes = caps.get("interface") + import_nodes = caps.get("import") + export_nodes = caps.get("export") + + if func_nodes: + self._process_function( + func_nodes[0], + 
source, + file_path, + parent_id, + class_name, + nodes, + edges, + seen_ids, + ) + continue + + if class_nodes: + self._process_class( + class_nodes[0], + source, + file_path, + lang, + parent_id, + nodes, + edges, + errors, + seen_ids, + ) + continue + + if iface_nodes: + self._process_interface(iface_nodes[0], source, file_path, parent_id, nodes, edges, seen_ids) + continue + + if import_nodes: + self._process_import(import_nodes[0], source, file_path, parent_id, nodes, edges) + continue + + if export_nodes: + for child in export_nodes[0].children: + self._walk_block( + child, + source, + file_path, + lang, + parent_id, + class_name, + nodes, + edges, + errors, + seen_ids, + ) + continue + + arrow_cursor = ts.QueryCursor(queries["arrow"]) + for _idx, caps in arrow_cursor.matches(block): + name_nodes = caps.get("name") + arrow_nodes = caps.get("arrow") + if name_nodes and arrow_nodes: + self._process_arrow_function( + name_nodes[0], + arrow_nodes[0], + source, + file_path, + parent_id, + class_name, + nodes, + edges, + seen_ids, + ) + + method_cursor = ts.QueryCursor(queries["method"]) + for _idx, caps in method_cursor.matches(block): + method_nodes = caps.get("method") + name_nodes = caps.get("name") + if method_nodes and name_nodes: + self._process_method( + method_nodes[0], + name_nodes[0], + source, + file_path, + parent_id, + class_name, + nodes, + edges, + seen_ids, + ) + + def _process_function( + self, + func: ts.Node, + source: bytes, + file_path: str, + parent_id: str, + class_name: str | None, + nodes: list[GraphNode], + edges: list[GraphEdge], + seen_ids: set[str], + ) -> None: + name_node = func.child_by_field_name("name") + if not name_node: + return + name = node_text(name_node) + start, end = line_range(func) + sig = self._extract_ts_signature(func, source, name) + node_id = make_node_id(file_path, NodeType.FUNCTION, name, start) + + structural = StructuralProperties( + name=name, + file=file_path, + signature=sig, + start_line=start, + end_line=end, + lines=end - start + 1, + ) + + node = GraphNode( + id=node_id, + type=NodeType.FUNCTION, + file_path=file_path, + structural=structural, + ) + if not self._add_node(node, nodes, seen_ids): + return + + edges.append(GraphEdge(source_id=parent_id, target_id=node_id, type=EdgeType.DEFINES)) + body = func.child_by_field_name("body") + if body: + self._extract_calls(body, source, file_path, node_id, nodes, edges) + + def _process_arrow_function( + self, + name_node: ts.Node, + arrow: ts.Node, + source: bytes, + file_path: str, + parent_id: str, + class_name: str | None, + nodes: list[GraphNode], + edges: list[GraphEdge], + seen_ids: set[str], + ) -> None: + name = node_text(name_node) + start, end = line_range(arrow) + sig = f"const {name} = {self._extract_ts_signature(arrow, source, name)}" + node_id = make_node_id(file_path, NodeType.FUNCTION, name, start) + + structural = StructuralProperties( + name=name, + file=file_path, + signature=sig, + start_line=start, + end_line=end, + lines=end - start + 1, + ) + + node = GraphNode( + id=node_id, + type=NodeType.FUNCTION, + file_path=file_path, + structural=structural, + ) + if not self._add_node(node, nodes, seen_ids): + return + + edges.append(GraphEdge(source_id=parent_id, target_id=node_id, type=EdgeType.DEFINES)) + body = arrow.child_by_field_name("body") + if body: + self._extract_calls(body, source, file_path, node_id, nodes, edges) + + def _process_method( + self, + method: ts.Node, + name_node: ts.Node, + source: bytes, + file_path: str, + parent_id: str, + class_name: str 
| None, + nodes: list[GraphNode], + edges: list[GraphEdge], + seen_ids: set[str], + ) -> None: + name = node_text(name_node) + start, end = line_range(method) + sig = self._extract_ts_signature(method, source, name) + node_id = make_node_id(file_path, NodeType.FUNCTION, name, start) + + structural = StructuralProperties( + name=name, + file=file_path, + signature=sig, + start_line=start, + end_line=end, + lines=end - start + 1, + ) + + node = GraphNode( + id=node_id, + type=NodeType.FUNCTION, + file_path=file_path, + structural=structural, + ) + if not self._add_node(node, nodes, seen_ids): + return + + edges.append(GraphEdge(source_id=parent_id, target_id=node_id, type=EdgeType.DEFINES)) + body = method.child_by_field_name("body") + if body: + self._extract_calls(body, source, file_path, node_id, nodes, edges) + + def _process_class( + self, + cls: ts.Node, + source: bytes, + file_path: str, + lang: ts.Language, + parent_id: str, + nodes: list[GraphNode], + edges: list[GraphEdge], + errors: list[ParseError], + seen_ids: set[str], + ) -> None: + name_node = cls.child_by_field_name("name") + if not name_node: + return + name = node_text(name_node) + start, end = line_range(cls) + sig = f"class {name}" + node_id = make_node_id(file_path, NodeType.CLASS, name, start) + + for child in cls.children: + if child.type == "class_heritage": + for heritage_child in child.children: + if heritage_child.type == "extends_clause": + for sub in heritage_child.children: + if sub.type in ("type_identifier", "identifier"): + base_name = node_text(sub) + sig += f" extends {base_name}" + base_id = make_node_id(file_path, NodeType.INTERFACE, base_name, 0) + edges.append(GraphEdge(source_id=node_id, target_id=base_id, type=EdgeType.IMPLEMENTS)) + + structural = StructuralProperties( + name=name, + file=file_path, + signature=sig, + start_line=start, + end_line=end, + lines=end - start + 1, + ) + + node = GraphNode( + id=node_id, + type=NodeType.CLASS, + file_path=file_path, + structural=structural, + ) + if not self._add_node(node, nodes, seen_ids): + return + + edges.append(GraphEdge(source_id=parent_id, target_id=node_id, type=EdgeType.DEFINES)) + body = cls.child_by_field_name("body") + if body: + self._walk_block( + body, + source, + file_path, + lang, + parent_id=node_id, + class_name=name, + nodes=nodes, + edges=edges, + errors=errors, + seen_ids=seen_ids, + ) + + def _process_interface( + self, + iface: ts.Node, + source: bytes, + file_path: str, + parent_id: str, + nodes: list[GraphNode], + edges: list[GraphEdge], + seen_ids: set[str], + ) -> None: + name_node = iface.child_by_field_name("name") + if not name_node: + return + name = node_text(name_node) + start, end = line_range(iface) + node_id = make_node_id(file_path, NodeType.INTERFACE, name, start) + + structural = StructuralProperties( + name=name, + file=file_path, + signature=f"interface {name}", + start_line=start, + end_line=end, + lines=end - start + 1, + ) + + node = GraphNode( + id=node_id, + type=NodeType.INTERFACE, + file_path=file_path, + structural=structural, + ) + if not self._add_node(node, nodes, seen_ids): + return + + edges.append(GraphEdge(source_id=parent_id, target_id=node_id, type=EdgeType.DEFINES)) + + def _process_import( + self, + imp: ts.Node, + source: bytes, + file_path: str, + parent_id: str, + nodes: list[GraphNode], + edges: list[GraphEdge], + ) -> None: + start, end = line_range(imp) + text = node_text(imp).strip() + source_node = imp.child_by_field_name("source") + module = node_text(source_node) if source_node else 
text + + node_id = make_node_id(file_path, NodeType.FILE, module, start) + structural = StructuralProperties( + name=module, + file=file_path, + signature=text, + start_line=start, + end_line=end, + lines=end - start + 1, + ) + + node = GraphNode( + id=node_id, + type=NodeType.FILE, + file_path=file_path, + structural=structural, + ) + nodes.append(node) + edges.append(GraphEdge(source_id=parent_id, target_id=node_id, type=EdgeType.IMPORTS)) + + def _extract_calls( + self, + body: ts.Node, + source: bytes, + file_path: str, + caller_id: str, + nodes: list[GraphNode], + edges: list[GraphEdge], + ) -> None: + queries = _get_queries(self._language(file_path)) + cursor = ts.QueryCursor(queries["call"]) + seen_edges: set[tuple[str, str]] = set() + for _, caps in cursor.matches(body): + callee_nodes = caps.get("callee") + call_nodes = caps.get("call") + if not callee_nodes or not call_nodes: + continue + callee_name = node_text(callee_nodes[0]) + call_node = call_nodes[0] + start, _ = line_range(call_node) + target_id = make_node_id(file_path, NodeType.FUNCTION, callee_name, 0) + edge_key = (caller_id, target_id) + if edge_key in seen_edges: + continue + seen_edges.add(edge_key) + edges.append( + GraphEdge( + source_id=caller_id, + target_id=target_id, + type=EdgeType.CALLS, + metadata={"line": str(start)}, + ) + ) + + def _extract_ts_signature(self, node: ts.Node, source: bytes, name: str) -> str: + params_node = node.child_by_field_name("parameters") + params_text = node_text(params_node) if params_node else "()" + return_type = "" + for child in node.children: + if child.type == "type_annotation": + return_type = node_text(child) + break + if node.type == "arrow_function": + return f"({params_text}) => {return_type or '...'}" + return f"{name}{params_text}{return_type}" diff --git a/smp/protocol/__init__.py b/smp/protocol/__init__.py new file mode 100644 index 0000000..d7fb9e5 --- /dev/null +++ b/smp/protocol/__init__.py @@ -0,0 +1,9 @@ +"""Protocol layer — JSON-RPC 2.0 over FastAPI.""" + +from smp.protocol.router import handle_rpc +from smp.protocol.server import create_app + +__all__ = [ + "create_app", + "handle_rpc", +] diff --git a/smp/protocol/dispatcher.py b/smp/protocol/dispatcher.py new file mode 100644 index 0000000..4f49285 --- /dev/null +++ b/smp/protocol/dispatcher.py @@ -0,0 +1,267 @@ +"""JSON-RPC 2.0 dispatcher using handler pattern. + +Routes JSON-RPC method calls to registered handler instances. 
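+
+Illustrative exchange (the method name is from this codebase; the payload
+shape is standard JSON-RPC 2.0):
+
+    -> {"jsonrpc": "2.0", "method": "smp/locate", "params": {"query": "auth"}, "id": 1}
+    <- {"jsonrpc": "2.0", "result": [...], "id": 1}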
+""" + +from __future__ import annotations + +from typing import Any + +import msgspec +from fastapi import Request +from fastapi.responses import Response + +from smp.core.models import ( + JsonRpcError, + JsonRpcRequest, + JsonRpcResponse, +) +from smp.logging import get_logger +from smp.protocol.handlers.annotation import ( + AnnotateBulkHandler, + AnnotateHandler, + TagHandler, +) +from smp.protocol.handlers.base import MethodHandler +from smp.protocol.handlers.community import ( + CommunityBoundariesHandler, + CommunityDetectHandler, + CommunityGetHandler, + CommunityListHandler, +) +from smp.protocol.handlers.enrichment import ( + EnrichBatchHandler, + EnrichHandler, + EnrichStaleHandler, + EnrichStatusHandler, +) +from smp.protocol.handlers.handoff import ( + HandoffPRHandler, + HandoffReviewHandler, +) +from smp.protocol.handlers.memory import ( + BatchUpdateHandler, + ReindexHandler, + UpdateHandler, +) +from smp.protocol.handlers.merkle import ( + IndexExportHandler, + IndexImportHandler, + MerkleTreeHandler, + SyncHandler, +) +from smp.protocol.handlers.query import ( + ContextHandler, + FlowHandler, + ImpactHandler, + LocateHandler, + NavigateHandler, + SearchHandler, + TraceHandler, +) +from smp.protocol.handlers.query_ext import ( + ConflictHandler, + DiffHandler, + PlanHandler, + WhyHandler, +) +from smp.protocol.handlers.safety import ( + AuditGetHandler, + CheckpointHandler, + DryRunHandler, + GuardCheckHandler, + IntegrityVerifyHandler, + LockHandler, + RollbackHandler, + SessionCloseHandler, + SessionOpenHandler, + SessionRecoverHandler, + UnlockHandler, +) +from smp.protocol.handlers.sandbox import ( + SandboxDestroyHandler, + SandboxExecuteHandler, + SandboxSpawnHandler, +) +from smp.protocol.handlers.telemetry import ( + TelemetryHandler, + TelemetryHotHandler, + TelemetryNodeHandler, + TelemetryRecordHandler, +) + +log = get_logger(__name__) + + +def _error_response(req_id: int | str | None, code: int, message: str, data: Any = None) -> Response: + body = msgspec.json.encode( + JsonRpcResponse( + error=JsonRpcError(code=code, message=message, data=data), + id=req_id, + ) + ) + return Response(content=body, media_type="application/json", status_code=200) + + +def _success_response(req_id: int | str | None, result: Any) -> Response: + body = msgspec.json.encode(JsonRpcResponse(result=result, id=req_id)) + return Response(content=body, media_type="application/json", status_code=200) + + +class RpcDispatcher: + """Dispatches JSON-RPC requests to registered handlers.""" + + def __init__(self) -> None: + self._handlers: dict[str, MethodHandler] = {} + + for handler_cls in [ + UpdateHandler, + BatchUpdateHandler, + ReindexHandler, + EnrichHandler, + EnrichBatchHandler, + EnrichStaleHandler, + EnrichStatusHandler, + AnnotateHandler, + AnnotateBulkHandler, + TagHandler, + SessionOpenHandler, + SessionCloseHandler, + SessionRecoverHandler, + GuardCheckHandler, + DryRunHandler, + CheckpointHandler, + RollbackHandler, + LockHandler, + UnlockHandler, + AuditGetHandler, + IntegrityVerifyHandler, + NavigateHandler, + TraceHandler, + ContextHandler, + ImpactHandler, + LocateHandler, + SearchHandler, + FlowHandler, + DiffHandler, + PlanHandler, + ConflictHandler, + WhyHandler, + TelemetryHandler, + TelemetryHotHandler, + TelemetryNodeHandler, + TelemetryRecordHandler, + SandboxSpawnHandler, + SandboxExecuteHandler, + SandboxDestroyHandler, + CommunityDetectHandler, + CommunityListHandler, + CommunityGetHandler, + CommunityBoundariesHandler, + SyncHandler, + MerkleTreeHandler, + 
IndexExportHandler,
+            IndexImportHandler,
+            HandoffReviewHandler,
+            HandoffPRHandler,
+        ]:
+            handler = handler_cls()
+            self._handlers[handler.method] = handler
+
+    def register(self, handler: MethodHandler) -> None:
+        """Register a handler for a method."""
+        self._handlers[handler.method] = handler
+        log.debug("handler_registered", method=handler.method)
+
+    def get_handler(self, method: str) -> MethodHandler | None:
+        """Get handler for a method."""
+        return self._handlers.get(method)
+
+    async def dispatch(
+        self,
+        request: Request,
+        context: dict[str, Any],
+    ) -> Response:
+        """Dispatch a JSON-RPC request to the appropriate handler."""
+        try:
+            body = await request.body()
+        except Exception:
+            return _error_response(None, -32700, "Parse error")
+
+        if not body:
+            return _error_response(None, -32700, "Empty request body")
+
+        try:
+            req = msgspec.json.decode(body, type=JsonRpcRequest)
+        except Exception as exc:
+            # msgspec.DecodeError is a subclass of Exception, so a single
+            # catch-all clause is sufficient here.
+            return _error_response(None, -32700, f"Parse error: {exc}")
+
+        if req.jsonrpc != "2.0":
+            return _error_response(req.id, -32600, "Invalid Request: jsonrpc must be '2.0'")
+
+        if not req.method:
+            return _error_response(req.id, -32600, "Invalid Request: method is required")
+
+        method = req.method
+        params = req.params or {}
+
+        log.debug("rpc_request", method=method, id=req.id)
+
+        handler = self._handlers.get(method)
+        if not handler:
+            return _error_response(req.id, -32601, f"Method not found: {method}")
+
+        try:
+            result = await handler.handle(params, context)
+        except msgspec.ValidationError as exc:
+            return _error_response(req.id, -32602, f"Invalid params: {exc}")
+        except ValueError as exc:
+            return _error_response(req.id, -32001, str(exc))
+        except Exception as exc:
+            log.error("rpc_internal_error", method=method, error=str(exc))
+            return _error_response(req.id, -32603, f"Internal error: {exc}")
+
+        if req.id is None:
+            return Response(content=b"", status_code=204)
+
+        return _success_response(req.id, result)
+
+
+_dispatcher: RpcDispatcher | None = None
+
+
+def get_dispatcher() -> RpcDispatcher:
+    """Get or create the global dispatcher instance."""
+    global _dispatcher
+    if _dispatcher is None:
+        _dispatcher = RpcDispatcher()
+    return _dispatcher
+
+
+async def handle_rpc(
+    request: Request,
+    *,
+    engine: Any,
+    enricher: Any,
+    builder: Any,
+    registry: Any,
+    vector: Any,
+    safety: dict[str, Any] | None = None,
+    telemetry_engine: Any = None,
+    handoff_manager: Any = None,
+    integrity_verifier: Any = None,
+    community_detector: Any = None,
+    merkle_index: Any = None,
+) -> Response:
+    """Dispatch a single JSON-RPC 2.0 request."""
+    dispatcher = get_dispatcher()
+    context = {
+        "engine": engine,
+        "enricher": enricher,
+        "builder": builder,
+        "registry": registry,
+        "vector": vector,
+        "safety": safety,
+        "telemetry_engine": telemetry_engine,
+        "handoff_manager": handoff_manager,
+        "integrity_verifier": integrity_verifier,
+        # The community and Merkle handlers read these keys from the context,
+        # so they must be threaded through or those methods fail with KeyError.
+        "community_detector": community_detector,
+        "merkle_index": merkle_index,
+    }
+    return await dispatcher.dispatch(request, context)
diff --git a/smp/protocol/handlers/__init__.py b/smp/protocol/handlers/__init__.py
new file mode 100644
index 0000000..9f20d06
--- /dev/null
+++ b/smp/protocol/handlers/__init__.py
@@ -0,0 +1 @@
+"""Protocol handler modules."""
diff --git a/smp/protocol/handlers/annotation.py b/smp/protocol/handlers/annotation.py
new file mode 100644
index 0000000..9ce7243
--- /dev/null
+++ b/smp/protocol/handlers/annotation.py
@@ -0,0 +1,117 @@
+"""Handler for annotation methods (smp/annotate, smp/annotate/bulk, smp/tag)."""
+
+from __future__ import annotations
+
+from datetime import UTC, datetime
+from typing import Any
+
+import msgspec
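+# msgspec.convert turns the raw JSON-RPC params dict into the typed *Params
+# structs below, so bad input surfaces as msgspec.ValidationError (-32602).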
+ +from smp.core.models import AnnotateBulkParams, AnnotateParams, TagParams +from smp.logging import get_logger +from smp.protocol.handlers.base import MethodHandler + +log = get_logger(__name__) + + +class AnnotateHandler(MethodHandler): + """Handles smp/annotate method.""" + + @property + def method(self) -> str: + return "smp/annotate" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + ap = msgspec.convert(params, AnnotateParams) + engine = context["engine"] + + node = await engine._graph.get_node(ap.node_id) + if not node: + raise ValueError(f"Node not found: {ap.node_id}") + + if node.semantic.docstring and not ap.force: + raise ValueError(f"Node already has extracted docstring. Set force: true to override. Node: {ap.node_id}") + + node.semantic.description = ap.description + node.semantic.tags = list(set(node.semantic.tags + ap.tags)) + node.semantic.manually_set = True + node.semantic.status = "manually_annotated" + node.semantic.enriched_at = datetime.now(UTC).isoformat() + await engine._graph.upsert_node(node) + + return { + "node_id": ap.node_id, + "status": "annotated", + "manually_set": True, + "annotated_at": node.semantic.enriched_at, + } + + +class AnnotateBulkHandler(MethodHandler): + """Handles smp/annotate/bulk method.""" + + @property + def method(self) -> str: + return "smp/annotate/bulk" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + abp = msgspec.convert(params, AnnotateBulkParams) + engine = context["engine"] + + annotated = 0 + failed = 0 + + for ann in abp.annotations: + node = await engine._graph.get_node(ann.node_id) + if not node: + failed += 1 + continue + + node.semantic.description = ann.description + node.semantic.tags = list(set(node.semantic.tags + ann.tags)) + node.semantic.manually_set = True + node.semantic.status = "manually_annotated" + node.semantic.enriched_at = datetime.now(UTC).isoformat() + await engine._graph.upsert_node(node) + annotated += 1 + + return {"annotated": annotated, "failed": failed} + + +class TagHandler(MethodHandler): + """Handles smp/tag method.""" + + @property + def method(self) -> str: + return "smp/tag" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + tp = msgspec.convert(params, TagParams) + engine = context["engine"] + + nodes = await engine._graph.find_nodes_by_scope(tp.scope) + affected = 0 + + for node in nodes: + if tp.action == "add": + node.semantic.tags = list(set(node.semantic.tags + tp.tags)) + elif tp.action == "remove": + node.semantic.tags = [t for t in node.semantic.tags if t not in tp.tags] + elif tp.action == "replace": + node.semantic.tags = list(tp.tags) + await engine._graph.upsert_node(node) + affected += 1 + + return {"nodes_affected": affected, "action": tp.action, "scope": tp.scope} diff --git a/smp/protocol/handlers/base.py b/smp/protocol/handlers/base.py new file mode 100644 index 0000000..a1c5103 --- /dev/null +++ b/smp/protocol/handlers/base.py @@ -0,0 +1,34 @@ +"""Base handler interface for JSON-RPC method handlers.""" + +from __future__ import annotations + +import abc +from typing import Any + + +class MethodHandler(abc.ABC): + """Abstract base class for JSON-RPC method handlers.""" + + @property + @abc.abstractmethod + def method(self) -> str: + """Return the JSON-RPC method name this handler processes.""" + + @abc.abstractmethod + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> 
dict[str, Any] | None: + """Handle the method call. + + Args: + params: The method parameters + context: Request context (engine, enricher, etc.) + + Returns: + Result dict or None for notifications + + Raises: + JsonRpcError: For method-specific errors + """ diff --git a/smp/protocol/handlers/community.py b/smp/protocol/handlers/community.py new file mode 100644 index 0000000..306090d --- /dev/null +++ b/smp/protocol/handlers/community.py @@ -0,0 +1,94 @@ +"""Handler for community detection methods.""" + +from __future__ import annotations + +from typing import Any, cast + +import msgspec + +from smp.core.models import ( + CommunityBoundariesParams, + CommunityDetectParams, + CommunityGetParams, + CommunityListParams, +) +from smp.protocol.handlers.base import MethodHandler + + +class CommunityDetectHandler(MethodHandler): + """Handles smp/community/detect method.""" + + @property + def method(self) -> str: + return "smp/community/detect" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + p = msgspec.convert(params, CommunityDetectParams) + detector = context["community_detector"] + return cast( + dict[str, Any], + await detector.detect( + resolutions=p.resolutions or None, + relationship_types=p.relationship_types or None, + ), + ) + + +class CommunityListHandler(MethodHandler): + """Handles smp/community/list method.""" + + @property + def method(self) -> str: + return "smp/community/list" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + p = msgspec.convert(params, CommunityListParams) + detector = context["community_detector"] + return cast(dict[str, Any], await detector.list_communities(level=p.level)) + + +class CommunityGetHandler(MethodHandler): + """Handles smp/community/get method.""" + + @property + def method(self) -> str: + return "smp/community/get" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any] | None: + p = msgspec.convert(params, CommunityGetParams) + detector = context["community_detector"] + result = await detector.get_community( + community_id=p.community_id, + node_types=p.node_types or None, + include_bridges=p.include_bridges, + ) + return cast(dict[str, Any] | None, result) + + +class CommunityBoundariesHandler(MethodHandler): + """Handles smp/community/boundaries method.""" + + @property + def method(self) -> str: + return "smp/community/boundaries" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + p = msgspec.convert(params, CommunityBoundariesParams) + detector = context["community_detector"] + return cast(dict[str, Any], await detector.get_boundaries(level=p.level, min_coupling=p.min_coupling)) diff --git a/smp/protocol/handlers/enrichment.py b/smp/protocol/handlers/enrichment.py new file mode 100644 index 0000000..b70cc51 --- /dev/null +++ b/smp/protocol/handlers/enrichment.py @@ -0,0 +1,185 @@ +"""Handler for enrichment methods (smp/enrich, smp/enrich/batch, etc.).""" + +from __future__ import annotations + +from typing import Any + +import msgspec + +from smp.core.models import ( + EnrichBatchParams, + EnrichParams, + EnrichStaleParams, + EnrichStatusParams, +) +from smp.engine.enricher import _compute_source_hash +from smp.logging import get_logger +from smp.protocol.handlers.base import MethodHandler + +log = get_logger(__name__) + + +class EnrichHandler(MethodHandler): + """Handles smp/enrich method.""" + + @property + def 
method(self) -> str: + return "smp/enrich" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + ep = msgspec.convert(params, EnrichParams) + engine = context["engine"] + enricher = context["enricher"] + + node = await engine._graph.get_node(ep.node_id) + if not node: + raise ValueError(f"Node not found: {ep.node_id}") + + enriched = await enricher.enrich_node(node, force=ep.force) + if enriched.semantic.source_hash and enriched.semantic.status == "enriched": + await engine._graph.upsert_node(enriched) + + return { + "node_id": enriched.id, + "status": enriched.semantic.status, + "docstring": enriched.semantic.docstring, + "inline_comments": [{"line": c.line, "text": c.text} for c in enriched.semantic.inline_comments], + "decorators": enriched.semantic.decorators, + "annotations": { + "params": (enriched.semantic.annotations.params if enriched.semantic.annotations else {}), + "returns": (enriched.semantic.annotations.returns if enriched.semantic.annotations else None), + "throws": (enriched.semantic.annotations.throws if enriched.semantic.annotations else []), + } + if enriched.semantic.annotations + else {}, + "tags": enriched.semantic.tags, + "source_hash": enriched.semantic.source_hash, + "enriched_at": enriched.semantic.enriched_at, + } + + +class EnrichBatchHandler(MethodHandler): + """Handles smp/enrich/batch method.""" + + @property + def method(self) -> str: + return "smp/enrich/batch" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + ebp = msgspec.convert(params, EnrichBatchParams) + engine = context["engine"] + enricher = context["enricher"] + + nodes = await engine._graph.find_nodes_by_scope(ebp.scope) + enriched_count = 0 + skipped_count = 0 + no_metadata_count = 0 + no_metadata_nodes: list[str] = [] + + for node in nodes: + enriched = await enricher.enrich_node(node, force=ebp.force) + if enriched.semantic.status == "enriched": + enriched_count += 1 + await engine._graph.upsert_node(enriched) + elif enriched.semantic.status == "skipped": + skipped_count += 1 + elif enriched.semantic.status == "no_metadata": + no_metadata_count += 1 + no_metadata_nodes.append(enriched.id) + + return { + "enriched": enriched_count, + "skipped": skipped_count, + "no_metadata": no_metadata_count, + "failed": 0, + "no_metadata_nodes": no_metadata_nodes, + } + + +class EnrichStaleHandler(MethodHandler): + """Handles smp/enrich/stale method.""" + + @property + def method(self) -> str: + return "smp/enrich/stale" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + esp = msgspec.convert(params, EnrichStaleParams) + engine = context["engine"] + + nodes = await engine._graph.find_nodes_by_scope(esp.scope) + stale_nodes = [] + + for node in nodes: + if node.semantic.source_hash: + current = _compute_source_hash( + node.structural.name, + node.file_path, + node.structural.start_line, + node.structural.end_line, + node.structural.signature, + ) + if current != node.semantic.source_hash: + stale_nodes.append( + { + "node_id": node.id, + "file": node.file_path, + "last_enriched": node.semantic.enriched_at, + "current_hash": current, + "enriched_hash": node.semantic.source_hash, + } + ) + + return {"stale_count": len(stale_nodes), "stale_nodes": stale_nodes} + + +class EnrichStatusHandler(MethodHandler): + """Handles smp/enrich/status method.""" + + @property + def method(self) -> str: + return "smp/enrich/status" + + async def handle( + self, + 
params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + estp = msgspec.convert(params, EnrichStatusParams) + engine = context["engine"] + + nodes = await engine._graph.find_nodes_by_scope(estp.scope) + total = len(nodes) + has_docstring = sum(1 for n in nodes if n.semantic.docstring) + has_annotations = sum( + 1 + for n in nodes + if n.semantic.annotations and (n.semantic.annotations.params or n.semantic.annotations.returns) + ) + has_tags = sum(1 for n in nodes if n.semantic.tags) + manually_annotated = sum(1 for n in nodes if n.semantic.manually_set) + no_metadata = sum(1 for n in nodes if n.semantic.status == "no_metadata") + coverage = round((total - no_metadata) / total * 100, 1) if total > 0 else 0 + + return { + "total_nodes": total, + "has_docstring": has_docstring, + "has_annotations": has_annotations, + "has_tags": has_tags, + "manually_annotated": manually_annotated, + "no_metadata": no_metadata, + "stale": 0, + "coverage_pct": coverage, + } diff --git a/smp/protocol/handlers/handoff.py b/smp/protocol/handlers/handoff.py new file mode 100644 index 0000000..e770799 --- /dev/null +++ b/smp/protocol/handlers/handoff.py @@ -0,0 +1,68 @@ +"""Handler for handoff and review methods.""" + +from __future__ import annotations + +from typing import Any + +import msgspec + +from smp.core.models import PRCreateParams, ReviewCreateParams +from smp.protocol.handlers.base import MethodHandler + + +class HandoffReviewHandler(MethodHandler): + """Handles smp/handoff/review method.""" + + @property + def method(self) -> str: + return "smp/handoff/review" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + p = msgspec.convert(params, ReviewCreateParams) + manager = context["handoff_manager"] + review = manager.create_review( + session_id=p.session_id, + files_changed=p.files_changed, + diff_summary=p.diff_summary, + reviewers=p.reviewers, + ) + return { + "review_id": review.review_id, + "status": review.status, + "created_at": review.created_at, + } + + +class HandoffPRHandler(MethodHandler): + """Handles smp/handoff/pr method.""" + + @property + def method(self) -> str: + return "smp/handoff/pr" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any] | None: + p = msgspec.convert(params, PRCreateParams) + manager = context["handoff_manager"] + pr = manager.create_pr( + review_id=p.review_id, + title=p.title, + body=p.body, + branch=p.branch, + base_branch=p.base_branch, + ) + if pr is None: + return None + return { + "pr_id": pr.pr_id, + "status": pr.status, + "url": pr.url, + "created_at": pr.created_at, + } diff --git a/smp/protocol/handlers/memory.py b/smp/protocol/handlers/memory.py new file mode 100644 index 0000000..6658907 --- /dev/null +++ b/smp/protocol/handlers/memory.py @@ -0,0 +1,115 @@ +"""Handler for memory management methods (smp/update, smp/batch_update, etc.).""" + +from __future__ import annotations + +from typing import Any + +import msgspec + +from smp.core.models import BatchUpdateParams, ReindexParams, UpdateParams +from smp.logging import get_logger +from smp.protocol.handlers.base import MethodHandler + +log = get_logger(__name__) + + +class UpdateHandler(MethodHandler): + """Handles smp/update method.""" + + @property + def method(self) -> str: + return "smp/update" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + p = msgspec.convert(params, UpdateParams) + enricher = context["enricher"] + 
builder = context["builder"] + registry = context["registry"] + vector = context.get("vector") + + file_path = p.file_path + + if p.content: + parser_obj = registry.get(p.language) + if not parser_obj: + from smp.core.models import Language + + parser_obj = registry.get(Language.PYTHON) + if not parser_obj: + return {"error": "No parser available"} + doc = parser_obj.parse(p.content, file_path) + else: + doc = registry.parse_file(file_path) + + if not doc.nodes and not doc.edges: + return { + "file_path": file_path, + "nodes": 0, + "edges": 0, + "errors": len(doc.errors), + "message": "No nodes extracted", + } + + enriched_nodes = await enricher.enrich_batch(doc.nodes) + doc = type(doc)( + file_path=doc.file_path, + language=doc.language, + nodes=enriched_nodes, + edges=doc.edges, + errors=doc.errors, + ) + + if vector: + await vector.delete_by_file(file_path) + await builder.remove_document(file_path) + await builder.ingest_document(doc) + + return { + "file_path": file_path, + "nodes": len(doc.nodes), + "edges": len(doc.edges), + "errors": len(doc.errors), + } + + +class BatchUpdateHandler(MethodHandler): + """Handles smp/batch_update method.""" + + @property + def method(self) -> str: + return "smp/batch_update" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + bp = msgspec.convert(params, BatchUpdateParams) + update_handler = UpdateHandler() + + results = [] + for change in bp.changes: + r = await update_handler.handle(change, context) + results.append(r) + + return {"updates": len(results), "results": results} + + +class ReindexHandler(MethodHandler): + """Handles smp/reindex method.""" + + @property + def method(self) -> str: + return "smp/reindex" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + rp = msgspec.convert(params, ReindexParams) + return {"status": "reindex_requested", "scope": rp.scope} diff --git a/smp/protocol/handlers/merkle.py b/smp/protocol/handlers/merkle.py new file mode 100644 index 0000000..de3cf5a --- /dev/null +++ b/smp/protocol/handlers/merkle.py @@ -0,0 +1,81 @@ +"""Handler for Merkle index and sync methods.""" + +from __future__ import annotations + +from typing import Any, cast + +from smp.protocol.handlers.base import MethodHandler + + +class SyncHandler(MethodHandler): + """Handles smp/sync method.""" + + @property + def method(self) -> str: + return "smp/sync" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any] | None: + remote_hash = params.get("remote_hash", "") + index = context["merkle_index"] + # MerkleIndex.sync returns dict[str, set[str]] | None + result = index.sync(remote_hash) + if result is None: + return {"status": "in_sync"} + return {"status": "out_of_sync", "diff": result} + + +class MerkleTreeHandler(MethodHandler): + """Handles smp/merkle/tree method.""" + + @property + def method(self) -> str: + return "smp/merkle/tree" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + index = context["merkle_index"] + tree = index._tree + return {"hash": tree.hash()} + + +class IndexExportHandler(MethodHandler): + """Handles smp/index/export method.""" + + @property + def method(self) -> str: + return "smp/index/export" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + index = context["merkle_index"] + tree = index._tree + return cast(dict[str, Any], tree.export()) + + 
+class IndexImportHandler(MethodHandler): + """Handles smp/index/import method.""" + + @property + def method(self) -> str: + return "smp/index/import" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + data = params.get("data", {}) + index = context["merkle_index"] + tree = index._tree + tree.import_data(data) + return {"success": True} diff --git a/smp/protocol/handlers/query.py b/smp/protocol/handlers/query.py new file mode 100644 index 0000000..aee6582 --- /dev/null +++ b/smp/protocol/handlers/query.py @@ -0,0 +1,142 @@ +"""Handler for query methods (smp/navigate, smp/trace, smp/context, etc.).""" + +from __future__ import annotations + +from typing import Any + +import msgspec + +from smp.core.models import ( + ContextParams, + FlowParams, + ImpactParams, + LocateParams, + NavigateParams, + SearchParams, + TraceParams, +) +from smp.logging import get_logger +from smp.protocol.handlers.base import MethodHandler + +log = get_logger(__name__) + + +class NavigateHandler(MethodHandler): + """Handles smp/navigate method.""" + + @property + def method(self) -> str: + return "smp/navigate" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + np_ = msgspec.convert(params, NavigateParams) + engine = context["engine"] + return await engine.navigate(np_.query, np_.include_relationships) + + +class TraceHandler(MethodHandler): + """Handles smp/trace method.""" + + @property + def method(self) -> str: + return "smp/trace" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + trp = msgspec.convert(params, TraceParams) + engine = context["engine"] + result = await engine.trace(trp.start, trp.relationship, trp.depth, trp.direction) + return {"nodes": result} + + +class ContextHandler(MethodHandler): + """Handles smp/context method.""" + + @property + def method(self) -> str: + return "smp/context" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + ctp = msgspec.convert(params, ContextParams) + engine = context["engine"] + return await engine.get_context(ctp.file_path, ctp.scope, ctp.depth) + + +class ImpactHandler(MethodHandler): + """Handles smp/impact method.""" + + @property + def method(self) -> str: + return "smp/impact" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + imp = msgspec.convert(params, ImpactParams) + engine = context["engine"] + return await engine.assess_impact(imp.entity, imp.change_type) + + +class LocateHandler(MethodHandler): + """Handles smp/locate method.""" + + @property + def method(self) -> str: + return "smp/locate" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + loc = msgspec.convert(params, LocateParams) + engine = context["engine"] + result = await engine.locate(loc.query, loc.fields, loc.node_types, loc.top_k) + return {"matches": result} + + +class SearchHandler(MethodHandler): + """Handles smp/search method.""" + + @property + def method(self) -> str: + return "smp/search" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + sp = msgspec.convert(params, SearchParams) + engine = context["engine"] + return await engine.search(sp.query, sp.match, sp.filter, sp.top_k) + + +class FlowHandler(MethodHandler): + """Handles smp/flow method.""" + + @property + def method(self) -> str: + 
return "smp/flow" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + fp = msgspec.convert(params, FlowParams) + engine = context["engine"] + return await engine.find_flow(fp.start, fp.end, fp.flow_type) diff --git a/smp/protocol/handlers/query_ext.py b/smp/protocol/handlers/query_ext.py new file mode 100644 index 0000000..76002b7 --- /dev/null +++ b/smp/protocol/handlers/query_ext.py @@ -0,0 +1,115 @@ +"""Handler for diff, plan, conflict, why, and telemetry methods.""" + +from __future__ import annotations + +from typing import Any + +import msgspec + +from smp.core.models import ( + ConflictParams, + DiffParams, + PlanParams, + TelemetryParams, + WhyParams, +) +from smp.logging import get_logger +from smp.protocol.handlers.base import MethodHandler + +log = get_logger(__name__) + + +class DiffHandler(MethodHandler): + """Handles smp/diff method.""" + + @property + def method(self) -> str: + return "smp/diff" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + dp = msgspec.convert(params, DiffParams) + engine = context["engine"] + return await engine.diff(dp.from_snapshot, dp.to_snapshot, dp.scope) + + +class PlanHandler(MethodHandler): + """Handles smp/plan method.""" + + @property + def method(self) -> str: + return "smp/plan" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + pp = msgspec.convert(params, PlanParams) + engine = context["engine"] + return await engine.plan(pp.change_description, pp.target_file, pp.change_type, pp.scope) + + +class ConflictHandler(MethodHandler): + """Handles smp/conflict method.""" + + @property + def method(self) -> str: + return "smp/conflict" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + cp = msgspec.convert(params, ConflictParams) + engine = context["engine"] + return await engine.conflict(cp.entity, cp.proposed_change, cp.context) + + +class WhyHandler(MethodHandler): + """Handles smp/why method.""" + + @property + def method(self) -> str: + return "smp/graph/why" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + wp = msgspec.convert(params, WhyParams) + engine = context["engine"] + return await engine.why(wp.entity, wp.relationship, wp.depth) + + +class TelemetryHandler(MethodHandler): + """Handles smp/telemetry method.""" + + @property + def method(self) -> str: + return "smp/telemetry" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + tp = msgspec.convert(params, TelemetryParams) + telemetry_engine = context.get("telemetry_engine") + if not telemetry_engine: + # Return basic stats if telemetry not configured + return {"action": tp.action, "status": "not_configured"} + + if tp.action == "get_stats": + return telemetry_engine.get_summary() + elif tp.action == "get_hot" and tp.node_id: + return telemetry_engine.get_stats(tp.node_id) + elif tp.action == "decay": + return {"decayed": telemetry_engine.decay()} + else: + return {"error": "Unknown telemetry action"} diff --git a/smp/protocol/handlers/safety.py b/smp/protocol/handlers/safety.py new file mode 100644 index 0000000..079a21c --- /dev/null +++ b/smp/protocol/handlers/safety.py @@ -0,0 +1,338 @@ +"""Handler for safety protocol methods (session, guard, lock, checkpoint, etc.).""" + +from __future__ import annotations + +from typing import Any + +import 
msgspec + +from smp.core.models import ( + AuditGetParams, + CheckpointParams, + DryRunParams, + GuardCheckParams, + LockParams, + RollbackParams, + SessionCloseParams, + SessionOpenParams, + SessionRecoverParams, +) +from smp.engine.integrity import IntegrityCheckResult, IntegrityVerifier +from smp.logging import get_logger +from smp.protocol.handlers.base import MethodHandler + +log = get_logger(__name__) + + +class SessionOpenHandler(MethodHandler): + """Handles smp/session/open method.""" + + @property + def method(self) -> str: + return "smp/session/open" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + sop = msgspec.convert(params, SessionOpenParams) + safety = context.get("safety") + if not safety: + raise ValueError("Safety protocol not enabled") + + return await safety["session_manager"].open_session(sop.agent_id, sop.task, sop.scope, sop.mode) + + +class SessionCloseHandler(MethodHandler): + """Handles smp/session/close method.""" + + @property + def method(self) -> str: + return "smp/session/close" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + scp = msgspec.convert(params, SessionCloseParams) + safety = context.get("safety") + if not safety: + raise ValueError("Safety protocol not enabled") + + close_result = await safety["session_manager"].close_session(scp.session_id, scp.status) + if not close_result: + raise ValueError(f"Session not found: {scp.session_id}") + + await safety["lock_manager"].release_all(scp.session_id) + if "audit_logger" in safety: + safety["audit_logger"].close_log(close_result.get("audit_log_id", ""), scp.status) + + return close_result + + +class SessionRecoverHandler(MethodHandler): + """Handles smp/session/recover method.""" + + @property + def method(self) -> str: + return "smp/session/recover" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + srp = msgspec.convert(params, SessionRecoverParams) + safety = context.get("safety") + if not safety: + raise ValueError("Safety protocol not enabled") + + session_manager = safety.get("session_manager") + if not session_manager: + raise ValueError("Session manager not configured") + + result = await session_manager.recover_session(srp.session_id) + if not result: + raise ValueError(f"Session not found: {srp.session_id}") + + return result + + +class GuardCheckHandler(MethodHandler): + """Handles smp/guard/check method.""" + + @property + def method(self) -> str: + return "smp/guard/check" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + gcp = msgspec.convert(params, GuardCheckParams) + safety = context.get("safety") + if not safety: + raise ValueError("Safety protocol not enabled") + + return await safety["guard_engine"].check(gcp.session_id, gcp.target, gcp.intended_change) + + +class DryRunHandler(MethodHandler): + """Handles smp/dryrun method.""" + + @property + def method(self) -> str: + return "smp/dryrun" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + drp = msgspec.convert(params, DryRunParams) + safety = context.get("safety") + if not safety: + raise ValueError("Safety protocol not enabled") + + return safety["dryrun_simulator"].simulate( + drp.session_id, drp.file_path, drp.proposed_content, drp.change_summary + ) + + +class CheckpointHandler(MethodHandler): + """Handles smp/checkpoint method.""" + + @property 
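+    # The dispatcher keys its registry on this property (RpcDispatcher maps
+    # handler.method -> handler), so the name returned here must be unique.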
+ def method(self) -> str: + return "smp/checkpoint" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + cp = msgspec.convert(params, CheckpointParams) + safety = context.get("safety") + if not safety: + raise ValueError("Safety protocol not enabled") + + return safety["checkpoint_manager"].create(cp.session_id, cp.files) + + +class RollbackHandler(MethodHandler): + """Handles smp/rollback method.""" + + @property + def method(self) -> str: + return "smp/rollback" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + rbp = msgspec.convert(params, RollbackParams) + safety = context.get("safety") + if not safety: + raise ValueError("Safety protocol not enabled") + + return safety["checkpoint_manager"].rollback(rbp.checkpoint_id) + + +class LockHandler(MethodHandler): + """Handles smp/lock method.""" + + @property + def method(self) -> str: + return "smp/lock" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + lp = msgspec.convert(params, LockParams) + safety = context.get("safety") + if not safety: + raise ValueError("Safety protocol not enabled") + + return await safety["lock_manager"].acquire(lp.session_id, lp.files) + + +class UnlockHandler(MethodHandler): + """Handles smp/unlock method.""" + + @property + def method(self) -> str: + return "smp/unlock" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + ulp = msgspec.convert(params, LockParams) + safety = context.get("safety") + if not safety: + raise ValueError("Safety protocol not enabled") + + await safety["lock_manager"].release(ulp.session_id, ulp.files) + return {"released": ulp.files} + + +class AuditGetHandler(MethodHandler): + """Handles smp/audit/get method.""" + + @property + def method(self) -> str: + return "smp/audit/get" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + agp = msgspec.convert(params, AuditGetParams) + safety = context.get("safety") + if not safety: + raise ValueError("Safety protocol not enabled") + + audit_logger = safety.get("audit_logger") + if not audit_logger: + raise ValueError("Audit logger not configured") + + # Prefer explicit audit_log_id, fall back to session_id param for convenience + audit = None + if agp.audit_log_id: + audit = audit_logger.get_log(agp.audit_log_id) + if not audit and "session_id" in params: + audit = audit_logger.get_log_by_session(params.get("session_id")) + + if not audit: + raise ValueError(f"Audit log not found: {agp.audit_log_id or params.get('session_id')}") + + return audit + + +class IntegrityVerifyHandler(MethodHandler): + """Handles smp/verify/integrity method.""" + + @property + def method(self) -> str: + return "smp/verify/integrity" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + session_id: str = params["session_id"] + node_ids: list[str] = params.get("node_ids") or [] + mode: str = params.get("mode", "ast") + if mode not in ("ast", "mutation", "both"): + raise ValueError(f"Invalid mode: {mode}. 
Must be 'ast', 'mutation', or 'both'") + + integrity_verifier: IntegrityVerifier | None = context.get("integrity_verifier") + if not integrity_verifier: + integrity_verifier = IntegrityVerifier() + + graph_store = context.get("engine") + all_mutations: list[dict[str, Any]] = [] + all_warnings: list[str] = [] + total_checks = 0 + all_passed = True + + target_ids = node_ids if node_ids else list(integrity_verifier._baselines.keys()) + + for nid in target_ids: + results: list[IntegrityCheckResult] = [] + + if mode in ("ast", "both"): + baseline = integrity_verifier._baselines.get(nid) + current_state = baseline["state"] if baseline else {} + ast_result = await integrity_verifier.verify(nid, current_state) + results.append(ast_result) + + if mode in ("mutation", "both"): + if not graph_store: + all_warnings.append(f"Graph store unavailable for mutation test on {nid}") + continue + mutation_result = await integrity_verifier.run_mutation_test(nid, graph_store) + results.append(mutation_result) + + for r in results: + if not r.passed: + all_passed = False + total_checks += r.checks_run + all_mutations.extend( + [ + { + "node_id": m.node_id, + "mutation_type": m.mutation_type, + "field_name": m.field_name, + "old_value": m.old_value, + "new_value": m.new_value, + "detected_at": m.detected_at, + } + for m in r.mutations_detected + ] + ) + all_warnings.extend(r.warnings) + + log.info( + "integrity_verify", + session_id=session_id, + mode=mode, + passed=all_passed, + checks_run=total_checks, + ) + + return { + "passed": all_passed, + "mutations_detected": all_mutations, + "warnings": all_warnings, + "checks_run": total_checks, + } diff --git a/smp/protocol/handlers/sandbox.py b/smp/protocol/handlers/sandbox.py new file mode 100644 index 0000000..9622b7f --- /dev/null +++ b/smp/protocol/handlers/sandbox.py @@ -0,0 +1,110 @@ +"""Handler for sandbox methods (smp/sandbox/spawn, etc.).""" + +from __future__ import annotations + +from datetime import UTC, datetime +from typing import Any + +import msgspec + +from smp.logging import get_logger +from smp.protocol.handlers.base import MethodHandler +from smp.sandbox.executor import SandboxExecutor +from smp.sandbox.spawner import SandboxSpawner + +log = get_logger(__name__) + + +class SandboxSpawnHandler(MethodHandler): + """Handles smp/sandbox/spawn method.""" + + @property + def method(self) -> str: + return "smp/sandbox/spawn" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + # In a real implementation, these would come from context/session + # For now, we'll use defaults and extract from params if provided + sp = msgspec.convert(params, dict) # Use raw params since no model exists yet + + spawner = SandboxSpawner() + + sandbox_info = spawner.spawn(name=sp.get("name"), template=sp.get("template"), files=sp.get("files")) + + return { + "sandbox_id": sandbox_info.sandbox_id, + "root_path": sandbox_info.root_path, + "created_at": sandbox_info.created_at, + "status": sandbox_info.status, + } + + +class SandboxExecuteHandler(MethodHandler): + """Handles smp/sandbox/execute method.""" + + @property + def method(self) -> str: + return "smp/sandbox/execute" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + sep = msgspec.convert(params, dict) # Use raw params + + # Create executor with default config + executor = SandboxExecutor() + + # Execute the command + result = await executor.execute( + command=sep.get("command", []), stdin=sep.get("stdin"), 
cwd=sep.get("working_directory") + ) + + return { + "execution_id": result.execution_id, + "exit_code": result.exit_code, + "stdout": result.stdout, + "stderr": result.stderr, + "duration_ms": result.duration_ms, + "memory_used_mb": result.memory_used_mb, + "timed_out": result.timed_out, + "killed": result.killed, + "metadata": result.metadata, + } + + +class SandboxDestroyHandler(MethodHandler): + """Handles smp/sandbox/destroy method.""" + + @property + def method(self) -> str: + return "smp/sandbox/destroy" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + sdp = msgspec.convert(params, dict) # Use raw params + + spawner = SandboxSpawner() + sandbox_id = sdp.get("sandbox_id") + + if not sandbox_id: + return {"error": "sandbox_id is required"} + + destroyed = spawner.destroy(sandbox_id) + + if destroyed: + return { + "sandbox_id": sandbox_id, + "status": "destroyed", + "destroyed_at": datetime.now(UTC).isoformat(), + } + else: + return {"error": f"Sandbox not found: {sandbox_id}"} diff --git a/smp/protocol/handlers/telemetry.py b/smp/protocol/handlers/telemetry.py new file mode 100644 index 0000000..d12f2cf --- /dev/null +++ b/smp/protocol/handlers/telemetry.py @@ -0,0 +1,122 @@ +"""Telemetry handlers for SMP(3).""" + +from __future__ import annotations + +from typing import Any + +import msgspec + +from smp.core.models import ( + TelemetryParams, +) +from smp.logging import get_logger +from smp.protocol.handlers.base import MethodHandler + +log = get_logger(__name__) + + +class TelemetryHandler(MethodHandler): + """Handles smp/telemetry method.""" + + @property + def method(self) -> str: + return "smp/telemetry" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + tp = msgspec.convert(params, TelemetryParams) + telemetry_engine = context.get("telemetry_engine") + if not telemetry_engine: + return {"action": tp.action, "status": "not_configured"} + elif tp.action == "get_stats": + return telemetry_engine.get_summary() + elif tp.action == "get_hot" and tp.node_id: + return telemetry_engine.get_stats(tp.node_id) + elif tp.action == "decay": + return {"decayed": telemetry_engine.decay()} + else: + return {"error": "Unknown telemetry action"} + + +class TelemetryHotHandler(MethodHandler): + """Handles smp/telemetry/hot method.""" + + @property + def method(self) -> str: + return "smp/telemetry/hot" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + # Extract node_id from params + node_id = params.get("node_id") + if not node_id: + return {"error": "node_id is required"} + + telemetry_engine = context.get("telemetry_engine") + if not telemetry_engine: + return {"status": "not_configured"} + + return telemetry_engine.get_stats(node_id) + + +class TelemetryNodeHandler(MethodHandler): + """Handles smp/telemetry/node method.""" + + @property + def method(self) -> str: + return "smp/telemetry/node" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + # Extract node_id from params + node_id = params.get("node_id") + if not node_id: + return {"error": "node_id is required"} + + telemetry_engine = context.get("telemetry_engine") + if not telemetry_engine: + return {"status": "not_configured"} + + return telemetry_engine.get_stats(node_id) + + +class TelemetryRecordHandler(MethodHandler): + """Handles smp/telemetry/record method.""" + + @property + def method(self) -> str: + return 
"smp/telemetry/record" + + async def handle( + self, + params: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + # Extract parameters + node_id = params.get("node_id") + action = params.get("action", "access") + session_id = params.get("session_id") + agent_id = params.get("agent_id") + + if not node_id: + return {"error": "node_id is required"} + + telemetry_engine = context.get("telemetry_engine") + if not telemetry_engine: + return {"status": "not_configured"} + + return telemetry_engine.record_access( + node_id=node_id, + action=action, + session_id=session_id or "", + agent_id=agent_id or "", + ) diff --git a/smp/protocol/router.py b/smp/protocol/router.py new file mode 100644 index 0000000..01855c5 --- /dev/null +++ b/smp/protocol/router.py @@ -0,0 +1,653 @@ +"""JSON-RPC 2.0 dispatcher for the Structural Memory Protocol (SMP(3)). + +All SMP protocol methods are routed through a single ``POST /rpc`` endpoint. +""" + +from __future__ import annotations + +from datetime import UTC, datetime +from typing import Any + +import msgspec +from fastapi import Request +from fastapi.responses import Response + +from smp.core.models import ( + AnnotateBulkParams, + AnnotateParams, + AuditGetParams, + BatchUpdateParams, + CheckpointParams, + ContextParams, + DryRunParams, + EnrichBatchParams, + EnrichParams, + EnrichStaleParams, + EnrichStatusParams, + FlowParams, + GuardCheckParams, + ImpactParams, + IntegrityCheckParams, + JsonRpcError, + JsonRpcRequest, + JsonRpcResponse, + LocateParams, + LockParams, + NavigateParams, + PRCreateParams, + ReindexParams, + ReviewApproveParams, + ReviewCommentParams, + ReviewCreateParams, + ReviewRejectParams, + RollbackParams, + SearchParams, + SessionCloseParams, + SessionOpenParams, + TagParams, + TelemetryParams, + TraceParams, + UpdateParams, +) +from smp.logging import get_logger +from smp.sandbox.executor import SandboxExecutor + +log = get_logger(__name__) + + +def _error_response(req_id: int | str | None, code: int, message: str, data: Any = None) -> Response: + body = msgspec.json.encode( + JsonRpcResponse( + error=JsonRpcError(code=code, message=message, data=data), + id=req_id, + ) + ) + return Response(content=body, media_type="application/json", status_code=200) + + +def _success_response(req_id: int | str | None, result: Any) -> Response: + body = msgspec.json.encode(JsonRpcResponse(result=result, id=req_id)) + return Response(content=body, media_type="application/json", status_code=200) + + +async def _handle_update( + params: dict[str, Any], + engine: Any, + enricher: Any, + builder: Any, + registry: Any, + vector: Any, +) -> dict[str, Any]: + p = msgspec.convert(params, UpdateParams) + file_path = p.file_path + + if p.content: + parser_obj = registry.get(p.language) + if not parser_obj: + from smp.core.models import Language + + parser_obj = registry.get(Language.PYTHON) + if not parser_obj: + return {"error": "No parser available"} + doc = parser_obj.parse(p.content, file_path) + else: + doc = registry.parse_file(file_path) + + if not doc.nodes and not doc.edges: + return { + "file_path": file_path, + "nodes": 0, + "edges": 0, + "errors": len(doc.errors), + "message": "No nodes extracted", + } + + enriched_nodes = await enricher.enrich_batch(doc.nodes) + doc = type(doc)( + file_path=doc.file_path, + language=doc.language, + nodes=enriched_nodes, + edges=doc.edges, + errors=doc.errors, + ) + + if vector: + await vector.delete_by_file(file_path) + await builder.remove_document(file_path) + await builder.ingest_document(doc) + + 
return { + "file_path": file_path, + "nodes": len(doc.nodes), + "edges": len(doc.edges), + "errors": len(doc.errors), + } + + +async def handle_rpc( + request: Request, + *, + engine: Any, + enricher: Any, + builder: Any, + registry: Any, + vector: Any, + safety: dict[str, Any] | None = None, + telemetry_engine: Any = None, + handoff_manager: Any = None, + integrity_verifier: Any = None, + runtime_linker: Any = None, +) -> Response: + """Dispatch a single JSON-RPC 2.0 request.""" + + # Build context for handlers + context: dict[str, Any] = { + "engine": engine, + "enricher": enricher, + "builder": builder, + "registry": registry, + "vector": vector, + "safety": safety, + "telemetry_engine": telemetry_engine, + "handoff_manager": handoff_manager, + "integrity_verifier": integrity_verifier, + "runtime_linker": runtime_linker, + } + try: + body = await request.body() + except Exception: + return _error_response(None, -32700, "Parse error") + + if not body: + return _error_response(None, -32700, "Empty request body") + + try: + req = msgspec.json.decode(body, type=JsonRpcRequest) + except (msgspec.DecodeError, Exception) as exc: + return _error_response(None, -32700, f"Parse error: {exc}") + + if req.jsonrpc != "2.0": + return _error_response(req.id, -32600, "Invalid Request: jsonrpc must be '2.0'") + + if not req.method: + return _error_response(req.id, -32600, "Invalid Request: method is required") + + method = req.method + params = req.params + + log.debug("rpc_request", method=method, id=req.id) + + try: + # --- Memory Management --- + if method == "smp/update": + result = await _handle_update(params, engine, enricher, builder, registry, vector) + + elif method == "smp/batch_update": + bp = msgspec.convert(params, BatchUpdateParams) + results = [] + for change in bp.changes: + r = await _handle_update(change, engine, enricher, builder, registry, vector) + results.append(r) + result = {"updates": len(results), "results": results} + + elif method == "smp/reindex": + rp = msgspec.convert(params, ReindexParams) + result = {"status": "reindex_requested", "scope": rp.scope} + + # --- Enrichment --- + elif method == "smp/enrich": + ep = msgspec.convert(params, EnrichParams) + node = await engine._graph.get_node(ep.node_id) + if not node: + return _error_response(req.id, -32001, "Node not found", data={"node_id": ep.node_id}) + enriched = await enricher.enrich_node(node, force=ep.force) + if enriched.semantic.source_hash and enriched.semantic.status == "enriched": + await engine._graph.upsert_node(enriched) + result = { + "node_id": enriched.id, + "status": enriched.semantic.status, + "docstring": enriched.semantic.docstring, + "inline_comments": [{"line": c.line, "text": c.text} for c in enriched.semantic.inline_comments], + "decorators": enriched.semantic.decorators, + "annotations": { + "params": enriched.semantic.annotations.params if enriched.semantic.annotations else {}, + "returns": enriched.semantic.annotations.returns if enriched.semantic.annotations else None, + "throws": enriched.semantic.annotations.throws if enriched.semantic.annotations else [], + } + if enriched.semantic.annotations + else {}, + "tags": enriched.semantic.tags, + "source_hash": enriched.semantic.source_hash, + "enriched_at": enriched.semantic.enriched_at, + } + + elif method == "smp/enrich/batch": + ebp = msgspec.convert(params, EnrichBatchParams) + nodes = await engine._graph.find_nodes_by_scope(ebp.scope) + enriched_count = 0 + skipped_count = 0 + no_metadata_count = 0 + no_metadata_nodes: list[str] = [] + for node 
in nodes: + enriched = await enricher.enrich_node(node, force=ebp.force) + if enriched.semantic.status == "enriched": + enriched_count += 1 + await engine._graph.upsert_node(enriched) + elif enriched.semantic.status == "skipped": + skipped_count += 1 + elif enriched.semantic.status == "no_metadata": + no_metadata_count += 1 + no_metadata_nodes.append(enriched.id) + result = { + "enriched": enriched_count, + "skipped": skipped_count, + "no_metadata": no_metadata_count, + "failed": 0, + "no_metadata_nodes": no_metadata_nodes, + } + + elif method == "smp/enrich/stale": + esp = msgspec.convert(params, EnrichStaleParams) + nodes = await engine._graph.find_nodes_by_scope(esp.scope) + stale_nodes = [] + for node in nodes: + if node.semantic.source_hash: + from smp.engine.enricher import _compute_source_hash + + current = _compute_source_hash( + node.structural.name, + node.file_path, + node.structural.start_line, + node.structural.end_line, + node.structural.signature, + ) + if current != node.semantic.source_hash: + stale_nodes.append( + { + "node_id": node.id, + "file": node.file_path, + "last_enriched": node.semantic.enriched_at, + "current_hash": current, + "enriched_hash": node.semantic.source_hash, + } + ) + result = {"stale_count": len(stale_nodes), "stale_nodes": stale_nodes} + + elif method == "smp/enrich/status": + estp = msgspec.convert(params, EnrichStatusParams) + nodes = await engine._graph.find_nodes_by_scope(estp.scope) + total = len(nodes) + has_docstring = sum(1 for n in nodes if n.semantic.docstring) + has_annotations = sum( + 1 + for n in nodes + if n.semantic.annotations and (n.semantic.annotations.params or n.semantic.annotations.returns) + ) + has_tags = sum(1 for n in nodes if n.semantic.tags) + manually_annotated = sum(1 for n in nodes if n.semantic.manually_set) + no_metadata = sum(1 for n in nodes if n.semantic.status == "no_metadata") + coverage = round((total - no_metadata) / total * 100, 1) if total > 0 else 0 + result = { + "total_nodes": total, + "has_docstring": has_docstring, + "has_annotations": has_annotations, + "has_tags": has_tags, + "manually_annotated": manually_annotated, + "no_metadata": no_metadata, + "stale": 0, + "coverage_pct": coverage, + } + + # --- Annotation --- + elif method == "smp/annotate": + ap = msgspec.convert(params, AnnotateParams) + node = await engine._graph.get_node(ap.node_id) + if not node: + return _error_response(req.id, -32001, "Node not found", data={"node_id": ap.node_id}) + if node.semantic.docstring and not ap.force: + return _error_response( + req.id, + -32002, + "Node already has extracted docstring. 
Set force: true to override.",
+                    data={"node_id": ap.node_id},
+                )
+            node.semantic.description = ap.description
+            node.semantic.tags = list(set(node.semantic.tags + ap.tags))
+            node.semantic.manually_set = True
+            node.semantic.status = "manually_annotated"
+            node.semantic.enriched_at = datetime.now(UTC).isoformat()
+            await engine._graph.upsert_node(node)
+            result = {
+                "node_id": ap.node_id,
+                "status": "annotated",
+                "manually_set": True,
+                "annotated_at": node.semantic.enriched_at,
+            }
+
+        elif method == "smp/annotate/bulk":
+            abp = msgspec.convert(params, AnnotateBulkParams)
+            annotated = 0
+            failed = 0
+            for ann in abp.annotations:
+                node = await engine._graph.get_node(ann.node_id)
+                if not node:
+                    failed += 1
+                    continue
+                node.semantic.description = ann.description
+                node.semantic.tags = list(set(node.semantic.tags + ann.tags))
+                node.semantic.manually_set = True
+                node.semantic.status = "manually_annotated"
+                node.semantic.enriched_at = datetime.now(UTC).isoformat()
+                await engine._graph.upsert_node(node)
+                annotated += 1
+            result = {"annotated": annotated, "failed": failed}
+
+        elif method == "smp/tag":
+            tp = msgspec.convert(params, TagParams)
+            nodes = await engine._graph.find_nodes_by_scope(tp.scope)
+            affected = 0
+            for node in nodes:
+                if tp.action == "add":
+                    node.semantic.tags = list(set(node.semantic.tags + tp.tags))
+                elif tp.action == "remove":
+                    node.semantic.tags = [t for t in node.semantic.tags if t not in tp.tags]
+                elif tp.action == "replace":
+                    node.semantic.tags = list(tp.tags)
+                await engine._graph.upsert_node(node)
+                affected += 1
+            result = {"nodes_affected": affected, "action": tp.action, "scope": tp.scope}
+
+        # --- Safety ---
+        elif method == "smp/session/open":
+            sop = msgspec.convert(params, SessionOpenParams)
+            if not safety:
+                return _error_response(req.id, -32601, "Safety protocol not enabled")
+            result = await safety["session_manager"].open_session(sop.agent_id, sop.task, sop.scope, sop.mode)
+
+        elif method == "smp/session/close":
+            scp = msgspec.convert(params, SessionCloseParams)
+            if not safety:
+                return _error_response(req.id, -32601, "Safety protocol not enabled")
+            close_result = await safety["session_manager"].close_session(scp.session_id, scp.status)
+            if close_result:
+                # release_all is a coroutine (the handler-based path awaits it),
+                # so it must be awaited here too or locks are never released.
+                await safety["lock_manager"].release_all(scp.session_id)
+                if "audit_logger" in safety:
+                    safety["audit_logger"].close_log(close_result.get("audit_log_id", ""), scp.status)
+                result = close_result
+            else:
+                return _error_response(req.id, -32001, "Session not found", data={"session_id": scp.session_id})
+
+        elif method == "smp/guard/check":
+            gcp = msgspec.convert(params, GuardCheckParams)
+            if not safety:
+                return _error_response(req.id, -32601, "Safety protocol not enabled")
+            result = await safety["guard_engine"].check(gcp.session_id, gcp.target, gcp.intended_change)
+
+        elif method == "smp/dryrun":
+            drp = msgspec.convert(params, DryRunParams)
+            if not safety:
+                return _error_response(req.id, -32601, "Safety protocol not enabled")
+            result = safety["dryrun_simulator"].simulate(
+                drp.session_id, drp.file_path, drp.proposed_content, drp.change_summary
+            )
+
+        elif method == "smp/checkpoint":
+            cp = msgspec.convert(params, CheckpointParams)
+            if not safety:
+                return _error_response(req.id, -32601, "Safety protocol not enabled")
+            result = safety["checkpoint_manager"].create(cp.session_id, cp.files)
+
+        elif method == "smp/rollback":
+            rbp = msgspec.convert(params, RollbackParams)
+            if not safety:
+                return _error_response(req.id, -32601, "Safety protocol not enabled")
+            result = 
safety["checkpoint_manager"].rollback(rbp.checkpoint_id) + + elif method == "smp/lock": + lp = msgspec.convert(params, LockParams) + if not safety: + return _error_response(req.id, -32601, "Safety protocol not enabled") + result = await safety["lock_manager"].acquire(lp.session_id, lp.files) + + elif method == "smp/unlock": + ulp = msgspec.convert(params, LockParams) + if not safety: + return _error_response(req.id, -32601, "Safety protocol not enabled") + await safety["lock_manager"].release(ulp.session_id, ulp.files) + result = {"released": ulp.files} + + elif method == "smp/audit/get": + agp = msgspec.convert(params, AuditGetParams) + if not safety: + return _error_response(req.id, -32601, "Safety protocol not enabled") + audit = safety["audit_logger"].get_log(agp.audit_log_id) + if not audit: + return _error_response(req.id, -32001, "Audit log not found", data={"audit_log_id": agp.audit_log_id}) + result = audit + + # --- Query --- + elif method == "smp/navigate": + np_ = msgspec.convert(params, NavigateParams) + result = await engine.navigate(np_.query, np_.include_relationships) + + elif method == "smp/trace": + trp = msgspec.convert(params, TraceParams) + result = await engine.trace(trp.start, trp.relationship, trp.depth, trp.direction) + + elif method == "smp/context": + ctp = msgspec.convert(params, ContextParams) + result = await engine.get_context(ctp.file_path, ctp.scope, ctp.depth) + + elif method == "smp/impact": + imp = msgspec.convert(params, ImpactParams) + result = await engine.assess_impact(imp.entity, imp.change_type) + + elif method == "smp/locate": + loc = msgspec.convert(params, LocateParams) + result = await engine.locate(loc.query, loc.fields, loc.node_types, loc.top_k) + + elif method == "smp/search": + sp = msgspec.convert(params, SearchParams) + result = await engine.search(sp.query, sp.match, sp.filter, sp.top_k) + + elif method == "smp/flow": + fp = msgspec.convert(params, FlowParams) + result = await engine.find_flow(fp.start, fp.end, fp.flow_type) + + elif method == "smp/graph/why": + wp = msgspec.convert(params, dict) + result = await engine.why( + entity=wp.get("entity", ""), + relationship=wp.get("relationship", ""), + depth=wp.get("depth", 3), + ) + + elif method == "smp/diff": + dp = msgspec.convert(params, dict) + result = await engine.diff_file( + file_path=dp.get("file_path", ""), + proposed_content=dp.get("proposed_content"), + ) + + elif method == "smp/plan": + pp = msgspec.convert(params, dict) + result = await engine.plan_multi_file( + session_id=pp.get("session_id", ""), + task=pp.get("task", ""), + intended_writes=pp.get("intended_writes", []), + ) + + elif method == "smp/conflict": + cp = msgspec.convert(params, dict) + result = await engine.detect_conflict( + session_a=cp.get("session_a", ""), + session_b=cp.get("session_b", ""), + ) + + # --- Sandbox --- + elif method == "smp/sandbox/spawn": + if not safety: + return _error_response(req.id, -32601, "Sandbox functionality requires safety protocol") + result = safety["sandbox_spawner"].spawn( + name=params.get("name"), template=params.get("template"), files=params.get("files") + ) + result = { + "sandbox_id": result.sandbox_id, + "root_path": result.root_path, + "created_at": result.created_at, + "status": result.status, + } + elif method == "smp/sandbox/execute": + if not safety: + return _error_response(req.id, -32601, "Sandbox functionality requires safety protocol") + sep = msgspec.convert(params, dict) + executor = safety.get("sandbox_executor") + if not executor: + # Create a default 
executor if not in context + executor = SandboxExecutor() + result = await executor.execute( + command=sep.get("command", []), stdin=sep.get("stdin"), cwd=sep.get("working_directory") + ) + result = { + "execution_id": result.execution_id, + "exit_code": result.exit_code, + "stdout": result.stdout, + "stderr": result.stderr, + "duration_ms": result.duration_ms, + "memory_used_mb": result.memory_used_mb, + "timed_out": result.timed_out, + "killed": result.killed, + "metadata": result.metadata, + } + elif method == "smp/sandbox/destroy": + if not safety: + return _error_response(req.id, -32601, "Sandbox functionality requires safety protocol") + sdp = msgspec.convert(params, dict) + sandbox_id = sdp.get("sandbox_id") + if not sandbox_id: + return _error_response(req.id, -32602, "sandbox_id is required") + destroyed = safety["sandbox_spawner"].destroy(sandbox_id) + if destroyed: + result = { + "sandbox_id": sandbox_id, + "status": "destroyed", + "destroyed_at": datetime.now(UTC).isoformat(), + } + else: + result = {"error": f"Sandbox not found: {sandbox_id}"} + + # --- Telemetry --- + elif method == "smp/telemetry": + tp = msgspec.convert(params, TelemetryParams) + telemetry_engine = context.get("telemetry_engine") + if not telemetry_engine: + result = {"action": tp.action, "status": "not_configured"} + elif tp.action == "get_stats": + result = telemetry_engine.get_summary() + elif tp.action == "get_hot" and tp.node_id: + result = telemetry_engine.get_stats(tp.node_id) + elif tp.action == "decay": + result = {"decayed": telemetry_engine.decay()} + else: + result = {"error": "Unknown telemetry action"} + + # --- Runtime Linker --- + elif method == "smp/linker/report": + linker = context.get("runtime_linker") + if not linker: + result = {"unresolved_edges": [], "status": "not_configured"} + else: + pending_count = linker.get_pending_count() + result = {"unresolved_edges": [], "pending_count": pending_count, "status": "ok"} + elif method == "smp/linker/runtime": + linker = context.get("runtime_linker") + if not linker: + result = {"hot_paths": [], "status": "not_configured"} + else: + threshold = params.get("threshold", 10) + result = {"hot_paths": linker.get_hot_paths(threshold), "stats": linker.get_stats()} + + # --- Handoff --- + elif method == "smp/handoff/review": + rcp = msgspec.convert(params, ReviewCreateParams) + handoff_manager = context.get("handoff_manager") + if not handoff_manager: + result = {"error": "Handoff manager not configured"} + else: + review = handoff_manager.create_review( + session_id=rcp.session_id, + files_changed=rcp.files_changed, + diff_summary=rcp.diff_summary, + reviewers=rcp.reviewers, + ) + result = {"review_id": review.review_id, "status": review.status, "created_at": review.created_at} + elif method == "smp/handoff/review/comment": + rcm = msgspec.convert(params, ReviewCommentParams) + handoff_manager = context.get("handoff_manager") + if not handoff_manager: + result = {"error": "Handoff manager not configured"} + else: + success = handoff_manager.add_comment( + review_id=rcm.review_id, + author=rcm.author, + comment=rcm.comment, + file_path=rcm.file_path, + line=rcm.line, + ) + result = {"success": success, "review_id": rcm.review_id} + elif method == "smp/handoff/review/approve": + rap = msgspec.convert(params, ReviewApproveParams) + handoff_manager = context.get("handoff_manager") + if not handoff_manager: + result = {"error": "Handoff manager not configured"} + else: + success = handoff_manager.approve(rap.review_id, rap.reviewer) + result = {"success": 
success, "review_id": rap.review_id, "status": "approved" if success else "failed"} + elif method == "smp/handoff/review/reject": + rrj = msgspec.convert(params, ReviewRejectParams) + handoff_manager = context.get("handoff_manager") + if not handoff_manager: + result = {"error": "Handoff manager not configured"} + else: + success = handoff_manager.reject(rrj.review_id, rrj.reviewer, rrj.reason) + result = {"success": success, "review_id": rrj.review_id, "status": "rejected" if success else "failed"} + elif method == "smp/handoff/pr": + pcp = msgspec.convert(params, PRCreateParams) + handoff_manager = context.get("handoff_manager") + if not handoff_manager: + result = {"error": "Handoff manager not configured"} + else: + pr = handoff_manager.create_pr( + review_id=pcp.review_id, + title=pcp.title, + body=pcp.body, + branch=pcp.branch, + base_branch=pcp.base_branch, + ) + if pr: + result = {"pr_id": pr.pr_id, "status": pr.status, "created_at": pr.created_at, "url": pr.url} + else: + result = {"error": "Review not found or not approved"} + + # --- Integrity --- + elif method == "smp/verify/integrity": + icp = msgspec.convert(params, IntegrityCheckParams) + verifier = context.get("integrity_verifier") + if not verifier: + result = {"status": "not_configured", "error": "Integrity verifier not available"} + else: + result = await verifier.verify(icp.node_id, icp.current_state) + + else: + return _error_response(req.id, -32601, f"Method not found: {method}") + + except msgspec.ValidationError as exc: + return _error_response(req.id, -32602, f"Invalid params: {exc}") + except Exception as exc: + log.error("rpc_internal_error", method=method, error=str(exc)) + return _error_response(req.id, -32603, f"Internal error: {exc}") + + if req.id is None: + return Response(content=b"", status_code=204) + + return _success_response(req.id, result) diff --git a/smp/protocol/server.py b/smp/protocol/server.py new file mode 100644 index 0000000..2e382a8 --- /dev/null +++ b/smp/protocol/server.py @@ -0,0 +1,162 @@ +"""FastAPI application with JSON-RPC 2.0 endpoint. 
+
+Start with: ``python3.11 -m smp.cli serve``
+"""
+
+from __future__ import annotations
+
+try:
+    import sys
+
+    import pysqlite3
+
+    sys.modules["sqlite3"] = pysqlite3
+except ImportError:
+    pass
+
+import os
+from contextlib import asynccontextmanager
+from typing import Any
+
+from fastapi import FastAPI, Request
+from fastapi.responses import Response
+
+from smp.core.merkle import MerkleIndex, MerkleTree
+from smp.engine.community import CommunityDetector
+from smp.engine.embedding import create_embedding_service
+from smp.engine.enricher import StaticSemanticEnricher
+from smp.engine.graph_builder import DefaultGraphBuilder
+from smp.engine.query import DefaultQueryEngine
+from smp.engine.seed_walk import SeedWalkEngine
+from smp.logging import get_logger
+from smp.parser.registry import ParserRegistry
+from smp.protocol.dispatcher import handle_rpc
+from smp.store.chroma_store import ChromaVectorStore
+from smp.store.graph.neo4j_store import Neo4jGraphStore
+
+log = get_logger(__name__)
+
+
+def create_app(
+    neo4j_uri: str | None = None,
+    neo4j_user: str | None = None,
+    neo4j_password: str | None = None,
+    safety_enabled: bool = False,
+) -> FastAPI:
+    """Create and configure the SMP FastAPI application."""
+
+    uri = neo4j_uri or os.environ.get("SMP_NEO4J_URI", "bolt://localhost:7687")
+    user = neo4j_user or os.environ.get("SMP_NEO4J_USER", "neo4j")
+    password = neo4j_password or os.environ.get("SMP_NEO4J_PASSWORD", "")
+
+    @asynccontextmanager
+    async def lifespan(app: FastAPI):  # type: ignore[no-untyped-def]  # noqa: ANN202
+        graph = Neo4jGraphStore(uri=uri, user=user, password=password)
+        await graph.connect()
+
+        vector = ChromaVectorStore()
+        await vector.connect()
+
+        embedding_service = create_embedding_service()
+        await embedding_service.connect()
+
+        enricher = StaticSemanticEnricher(embedding_service=embedding_service)
+        community_detector = CommunityDetector(graph_store=graph, vector_store=vector)
+        default_engine = DefaultQueryEngine(graph_store=graph, enricher=enricher)
+        engine = SeedWalkEngine(graph_store=graph, vector_store=vector, enricher=enricher, delegate=default_engine)
+        builder = DefaultGraphBuilder(graph)
+        registry = ParserRegistry()
+        merkle_index = MerkleIndex(MerkleTree())
+
+        safety: dict[str, Any] | None = None
+        if safety_enabled:
+            from smp.engine.handoff import HandoffManager
+            from smp.engine.integrity import IntegrityVerifier
+            from smp.engine.safety import (
+                AuditLogger,
+                CheckpointManager,
+                DryRunSimulator,
+                GuardEngine,
+                LockManager,
+                SessionManager,
+            )
+            from smp.engine.telemetry import TelemetryEngine
+            from smp.sandbox.executor import SandboxExecutor
+            from smp.sandbox.spawner import SandboxSpawner
+
+            session_manager = SessionManager(graph_store=graph)
+            lock_manager = LockManager(graph_store=graph)
+            session_manager.set_graph_store(graph)
+            lock_manager.set_graph_store(graph)
+            sandbox_spawner = SandboxSpawner()
+            sandbox_executor = SandboxExecutor()
+            telemetry_engine = TelemetryEngine()
+            handoff_manager = HandoffManager()
+            integrity_verifier = IntegrityVerifier()
+
+            # Expose telemetry, handoff, and integrity managers on app.state so
+            # the RPC dispatcher can reach them through its context.
+            app.state.telemetry_engine = telemetry_engine
+            app.state.handoff_manager = handoff_manager
+            app.state.integrity_verifier = integrity_verifier
+
+            safety = {
+                "session_manager": session_manager,
+                "lock_manager": lock_manager,
+                "guard_engine": GuardEngine(session_manager, lock_manager),
+                "dryrun_simulator": DryRunSimulator(),
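+                # The remaining managers cover checkpoint/rollback, audit
+                # logging, and the sandbox runtime.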
"checkpoint_manager": CheckpointManager(), + "audit_logger": AuditLogger(), + "sandbox_spawner": sandbox_spawner, + "sandbox_executor": sandbox_executor, + } + + app.state.graph = graph + app.state.vector = vector + app.state.engine = engine + app.state.community_detector = community_detector + app.state.merkle_index = merkle_index + app.state.builder = builder + app.state.enricher = enricher + app.state.registry = registry + app.state.safety = safety + + log.info("server_started", neo4j=neo4j_uri, safety=safety_enabled) + yield + + await graph.close() + log.info("server_stopped") + + app = FastAPI(title="SMP — Structural Memory Protocol", version="3.0.0", lifespan=lifespan) + + @app.post("/rpc") + async def rpc_endpoint(request: Request) -> Response: + return await handle_rpc( + request, + engine=app.state.engine, + enricher=app.state.enricher, + builder=app.state.builder, + registry=app.state.registry, + vector=app.state.vector, + safety=app.state.safety, + telemetry_engine=getattr(app.state, "telemetry_engine", None), + handoff_manager=getattr(app.state, "handoff_manager", None), + integrity_verifier=getattr(app.state, "integrity_verifier", None), + ) + + @app.get("/health") + async def health() -> dict[str, str]: + return {"status": "ok"} + + @app.get("/stats") + async def stats() -> dict[str, int]: + graph: Neo4jGraphStore = app.state.graph + return { + "nodes": await graph.count_nodes(), + "edges": await graph.count_edges(), + } + + return app + + +app = create_app() diff --git a/smp/sandbox/__init__.py b/smp/sandbox/__init__.py new file mode 100644 index 0000000..a73e2af --- /dev/null +++ b/smp/sandbox/__init__.py @@ -0,0 +1 @@ +"""SMP sandbox runtime module.""" diff --git a/smp/sandbox/docker_sandbox.py b/smp/sandbox/docker_sandbox.py new file mode 100644 index 0000000..80ba1dc --- /dev/null +++ b/smp/sandbox/docker_sandbox.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import docker + +from smp.logging import get_logger + +log = get_logger(__name__) + + +class DockerSandbox: + def __init__(self) -> None: + self._client = docker.from_env() + self._container: docker.models.containers.Container | None = None + self._network: docker.models.networks.Network | None = None + + def spawn(self, name: str, image: str, services: list[str]) -> str: + self._network = self._client.networks.create( + name=f"{name}_net", + internal=True, + ) + + self._container = self._client.containers.run( + image=image, + name=name, + detach=True, + network=self._network.name, + volumes={ + f"{name}_cow": {"bind": "/data", "mode": "rw"}, + }, + labels={"smp_sandbox": "true"}, + ) + + log.info("docker_sandbox_spawned", container_id=str(self._container.id), name=name) + return str(self._container.id) + + def execute(self, command: str, timeout: int) -> str: + if not self._container: + log.error("docker_sandbox_execute_failed", reason="no_container") + raise RuntimeError("No container spawned") + + exit_code, output = self._container.exec_run(command, timeout=timeout) + + if exit_code != 0: + log.warn("docker_sandbox_exec_nonzero", exit_code=exit_code, command=command) + + return str(output.decode("utf-8")) + + def destroy(self) -> None: + if self._container: + self._container.remove(force=True) + log.info("docker_sandbox_container_removed", container_id=self._container.id) + self._container = None + + if self._network: + self._network.remove() + log.info("docker_sandbox_network_removed", network_id=self._network.id) + self._network = None diff --git a/smp/sandbox/ebpf_collector.py 
b/smp/sandbox/ebpf_collector.py new file mode 100644 index 0000000..a36cb44 --- /dev/null +++ b/smp/sandbox/ebpf_collector.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +import uuid + +from smp.logging import get_logger + +log = get_logger(__name__) + + +class EBPFCollector: + def __init__(self) -> None: + self._active_traces: dict[str, str] = {} + self._data: list[dict[str, str | int]] = [] + + def start_trace(self, session_id: str) -> str: + trace_id = str(uuid.uuid4()) + self._active_traces[trace_id] = session_id + log.info("ebpf_trace_started", trace_id=trace_id, session_id=session_id) + return trace_id + + def stop_trace(self, trace_id: str) -> None: + if trace_id in self._active_traces: + session_id = self._active_traces.pop(trace_id) + log.info("ebpf_trace_stopped", trace_id=trace_id, session_id=session_id) + else: + log.error("ebpf_trace_stop_failed", trace_id=trace_id) + + def get_traces(self) -> list[dict[str, str | int]]: + return self._data diff --git a/smp/sandbox/executor.py b/smp/sandbox/executor.py new file mode 100644 index 0000000..5018e95 --- /dev/null +++ b/smp/sandbox/executor.py @@ -0,0 +1,169 @@ +"""SMP(3) sandbox executor for isolated runtime execution. + +Provides isolated execution environments for running agent code safely. +""" + +from __future__ import annotations + +import asyncio +import contextlib +import os +import uuid +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +from smp.logging import get_logger + +log = get_logger(__name__) + +_SANDBOX_DEFAULT_TIMEOUT = 30 +_SANDBOX_DEFAULT_MEMORY_MB = 512 + + +@dataclass +class ExecutionResult: + """Result of a sandbox execution.""" + + execution_id: str + exit_code: int + stdout: str + stderr: str + duration_ms: int + memory_used_mb: float = 0.0 + timed_out: bool = False + killed: bool = False + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class SandboxConfig: + """Configuration for sandbox execution.""" + + timeout_seconds: int = _SANDBOX_DEFAULT_TIMEOUT + memory_limit_mb: int = _SANDBOX_DEFAULT_MEMORY_MB + allow_network: bool = False + allow_file_write: bool = False + working_directory: str = "" + environment: dict[str, str] = field(default_factory=dict) + + +class SandboxExecutor: + """Executes code in an isolated sandbox environment.""" + + def __init__(self, config: SandboxConfig | None = None) -> None: + self._config = config or SandboxConfig() + self._active_processes: dict[str, asyncio.subprocess.Process] = {} + + async def execute( + self, + command: list[str], + stdin: str | None = None, + cwd: str | None = None, + ) -> ExecutionResult: + """Execute a command in the sandbox.""" + execution_id = f"exec_{uuid.uuid4().hex[:8]}" + start_time = asyncio.get_event_loop().time() + + work_dir = cwd or self._config.working_directory or str(Path.cwd()) + + env = os.environ.copy() + env.update(self._config.environment) + if not self._config.allow_network: + env["NO_NETWORK"] = "1" + + try: + process = await asyncio.create_subprocess_exec( + *command, + stdin=asyncio.subprocess.PIPE if stdin else None, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=work_dir, + env=env, + ) + self._active_processes[execution_id] = process + + try: + stdout_bytes, stderr_bytes = await asyncio.wait_for( + process.communicate(stdin.encode() if stdin else None), + timeout=self._config.timeout_seconds, + ) + timed_out = False + except TimeoutError: + process.kill() + await process.wait() + stdout_bytes, stderr_bytes = b"", 
b"Timeout exceeded" + timed_out = True + + duration_ms = int((asyncio.get_event_loop().time() - start_time) * 1000) + + result = ExecutionResult( + execution_id=execution_id, + exit_code=process.returncode or -1, + stdout=stdout_bytes.decode("utf-8", errors="replace"), + stderr=stderr_bytes.decode("utf-8", errors="replace"), + duration_ms=duration_ms, + timed_out=timed_out, + killed=timed_out, + ) + + log.info( + "sandbox_execution_complete", + execution_id=execution_id, + exit_code=result.exit_code, + duration_ms=duration_ms, + timed_out=timed_out, + ) + return result + + except Exception as exc: + log.error("sandbox_execution_error", execution_id=execution_id, error=str(exc)) + return ExecutionResult( + execution_id=execution_id, + exit_code=-1, + stdout="", + stderr=str(exc), + duration_ms=int((asyncio.get_event_loop().time() - start_time) * 1000), + ) + finally: + self._active_processes.pop(execution_id, None) + + async def execute_python( + self, + code: str, + timeout: int | None = None, + ) -> ExecutionResult: + """Execute Python code in the sandbox.""" + import tempfile + + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write(code) + temp_path = f.name + + try: + config = SandboxConfig( + timeout_seconds=timeout or self._config.timeout_seconds, + **{k: v for k, v in self._config.__dict__.items() if k != "timeout_seconds"}, + ) + executor = SandboxExecutor(config) + return await executor.execute(["python3.11", temp_path]) + finally: + Path(temp_path).unlink(missing_ok=True) + + def kill(self, execution_id: str) -> bool: + """Kill an active execution.""" + process = self._active_processes.get(execution_id) + if process: + process.kill() + log.info("sandbox_killed", execution_id=execution_id) + return True + return False + + async def cleanup(self) -> None: + """Kill all active executions.""" + for exec_id, process in list(self._active_processes.items()): + process.kill() + with contextlib.suppress(Exception): + await process.wait() + log.info("sandbox_cleanup", execution_id=exec_id) + self._active_processes.clear() diff --git a/smp/sandbox/spawner.py b/smp/sandbox/spawner.py new file mode 100644 index 0000000..39bcdb4 --- /dev/null +++ b/smp/sandbox/spawner.py @@ -0,0 +1,113 @@ +"""Sandbox spawner for creating isolated execution environments. + +Manages the lifecycle of sandboxed processes and containers. 
+""" + +from __future__ import annotations + +import uuid +from dataclasses import dataclass, field +from datetime import UTC, datetime +from pathlib import Path +from typing import Any + +from smp.logging import get_logger + +log = get_logger(__name__) + +_DEFAULT_SANDBOX_ROOT = Path.home() / ".smp" / "sandboxes" + + +@dataclass +class SandboxInfo: + """Information about a spawned sandbox.""" + + sandbox_id: str + root_path: str + created_at: str + status: str = "created" + pid: int | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + +class SandboxSpawner: + """Spawns and manages isolated sandbox directories.""" + + def __init__(self, sandbox_root: Path | None = None) -> None: + self._root = sandbox_root or _DEFAULT_SANDBOX_ROOT + self._sandboxes: dict[str, SandboxInfo] = {} + + def spawn( + self, + name: str | None = None, + template: str | None = None, + files: dict[str, str] | None = None, + ) -> SandboxInfo: + """Create a new sandbox directory.""" + sandbox_id = f"sandbox_{uuid.uuid4().hex[:8]}" + sandbox_name = name or sandbox_id + sandbox_path = self._root / sandbox_name + + sandbox_path.mkdir(parents=True, exist_ok=True) + + if template: + template_path = self._root / template + if template_path.exists(): + import shutil + + shutil.copytree(template_path, sandbox_path, dirs_exist_ok=True) + + if files: + for rel_path, content in files.items(): + file_path = sandbox_path / rel_path + file_path.parent.mkdir(parents=True, exist_ok=True) + file_path.write_text(content, encoding="utf-8") + + info = SandboxInfo( + sandbox_id=sandbox_id, + root_path=str(sandbox_path), + created_at=datetime.now(UTC).isoformat(), + ) + self._sandboxes[sandbox_id] = info + + log.info("sandbox_spawned", sandbox_id=sandbox_id, path=str(sandbox_path)) + return info + + def get(self, sandbox_id: str) -> SandboxInfo | None: + """Get sandbox info by ID.""" + return self._sandboxes.get(sandbox_id) + + def list_active(self) -> list[SandboxInfo]: + """List all active sandboxes.""" + return list(self._sandboxes.values()) + + def destroy(self, sandbox_id: str) -> bool: + """Remove a sandbox directory.""" + info = self._sandboxes.get(sandbox_id) + if not info: + return False + + import shutil + + path = Path(info.root_path) + if path.exists(): + shutil.rmtree(path) + + del self._sandboxes[sandbox_id] + log.info("sandbox_destroyed", sandbox_id=sandbox_id) + return True + + async def cleanup_all(self) -> int: + """Remove all sandbox directories.""" + import shutil + + count = 0 + for sandbox_id, info in list(self._sandboxes.items()): + path = Path(info.root_path) + if path.exists(): + shutil.rmtree(path) + count += 1 + del self._sandboxes[sandbox_id] + + log.info("sandboxes_cleaned", count=count) + return count diff --git a/smp/store/__init__.py b/smp/store/__init__.py new file mode 100644 index 0000000..ecccfaa --- /dev/null +++ b/smp/store/__init__.py @@ -0,0 +1 @@ +"""Store layer — graph and vector backends.""" diff --git a/smp/store/chroma_store.py b/smp/store/chroma_store.py new file mode 100644 index 0000000..f51df7d --- /dev/null +++ b/smp/store/chroma_store.py @@ -0,0 +1,155 @@ +"""ChromaDB-backed vector store implementation.""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any + +import chromadb + +from smp.logging import get_logger +from smp.store.interfaces import VectorStore + +log = get_logger(__name__) + + +class ChromaVectorStore(VectorStore): + """Persist embeddings in ChromaDB with metadata filtering.""" + + def __init__( + self, + 
collection_name: str = "smp_code_embeddings", + persist_dir: str | None = None, + ) -> None: + self._collection_name = collection_name + self._persist_dir = persist_dir + self._client: Any = None + self._collection: Any = None + + async def connect(self) -> None: + if self._persist_dir is not None: + self._client = chromadb.PersistentClient(path=self._persist_dir) + else: + self._client = chromadb.Client() + self._collection = self._client.get_or_create_collection(name=self._collection_name) + log.info("chroma_connected", collection=self._collection_name, persist_dir=self._persist_dir) + + async def close(self) -> None: + self._client = None + self._collection = None + log.info("chroma_closed", collection=self._collection_name) + + async def clear(self) -> None: + if self._client is None: + raise RuntimeError("ChromaVectorStore is not connected") + self._client.delete_collection(name=self._collection_name) + self._collection = self._client.get_or_create_collection(name=self._collection_name) + log.info("chroma_cleared", collection=self._collection_name) + + async def upsert( + self, + ids: Sequence[str], + embeddings: Sequence[Sequence[float]], + metadatas: Sequence[dict[str, Any]], + documents: Sequence[str] | None = None, + ) -> None: + if self._collection is None: + raise RuntimeError("ChromaVectorStore is not connected") + self._collection.upsert( + ids=list(ids), + embeddings=[list(e) for e in embeddings], + metadatas=list(metadatas), + documents=list(documents) if documents is not None else None, + ) + log.info("chroma_upserted", count=len(ids)) + + async def query( + self, + embedding: Sequence[float], + top_k: int = 5, + where: dict[str, Any] | None = None, + ) -> list[dict[str, Any]]: + if self._collection is None: + raise RuntimeError("ChromaVectorStore is not connected") + result = self._collection.query( + query_embeddings=[list(embedding)], + n_results=top_k, + where=where, + ) + return _normalise_query_result(result) + + async def get(self, ids: Sequence[str]) -> list[dict[str, Any] | None]: + if self._collection is None: + raise RuntimeError("ChromaVectorStore is not connected") + result = self._collection.get(ids=list(ids)) + return _normalise_get_result(result) + + async def delete(self, ids: Sequence[str]) -> int: + if self._collection is None: + raise RuntimeError("ChromaVectorStore is not connected") + self._collection.delete(ids=list(ids)) + log.info("chroma_deleted", count=len(ids)) + return len(ids) + + async def delete_by_file(self, file_path: str) -> int: + if self._collection is None: + raise RuntimeError("ChromaVectorStore is not connected") + self._collection.delete(where={"file_path": file_path}) + log.info("chroma_deleted_by_file", file_path=file_path) + return -1 + + async def add_code_embedding( + self, + node_id: str, + embedding: list[float], + metadata: dict[str, Any], + document: str = "", + ) -> None: + await self.upsert( + ids=[node_id], + embeddings=[embedding], + metadatas=[metadata], + documents=[document], + ) + + async def query_similar( + self, + embedding: list[float], + top_k: int = 5, + where: dict[str, Any] | None = None, + ) -> list[dict[str, Any]]: + return await self.query(embedding=embedding, top_k=top_k, where=where) + + +def _normalise_query_result(result: dict[str, Any]) -> list[dict[str, Any]]: + ids_batch = result.get("ids", [[]]) + distances_batch = result.get("distances", [[]]) + metadatas_batch = result.get("metadatas", [[]]) + documents_batch = result.get("documents", [[]]) + out: list[dict[str, Any]] = [] + for i, entry_id in 
enumerate(ids_batch[0]): + out.append( + { + "id": entry_id, + "score": distances_batch[0][i] if distances_batch and i < len(distances_batch[0]) else None, + "metadata": metadatas_batch[0][i] if metadatas_batch and i < len(metadatas_batch[0]) else {}, + "document": documents_batch[0][i] if documents_batch and i < len(documents_batch[0]) else "", + } + ) + return out + + +def _normalise_get_result(result: dict[str, Any]) -> list[dict[str, Any] | None]: + ids = result.get("ids", []) + metadatas = result.get("metadatas", []) + documents = result.get("documents", []) + out: list[dict[str, Any] | None] = [] + for i, entry_id in enumerate(ids): + out.append( + { + "id": entry_id, + "metadata": metadatas[i] if metadatas and i < len(metadatas) else {}, + "document": documents[i] if documents and i < len(documents) else "", + } + ) + return out diff --git a/smp/store/graph/__init__.py b/smp/store/graph/__init__.py new file mode 100644 index 0000000..74b6ccf --- /dev/null +++ b/smp/store/graph/__init__.py @@ -0,0 +1 @@ +"""Graph store implementations.""" diff --git a/smp/store/graph/neo4j_store.py b/smp/store/graph/neo4j_store.py new file mode 100644 index 0000000..d926102 --- /dev/null +++ b/smp/store/graph/neo4j_store.py @@ -0,0 +1,558 @@ +"""Neo4j-backed graph store implementation. + +Uses the official ``neo4j`` Python driver with async support. +Updated for SMP(3) partitioned schema (structural + semantic). +""" + +from __future__ import annotations + +import os +from collections.abc import Sequence +from datetime import UTC, datetime +from typing import Any + +from neo4j import AsyncDriver, AsyncGraphDatabase + +from smp.core.models import ( + Annotations, + EdgeType, + GraphEdge, + GraphNode, + InlineComment, + NodeType, + SemanticProperties, + StructuralProperties, +) +from smp.logging import get_logger +from smp.store.interfaces import GraphStore + +log = get_logger(__name__) + +_ALL_LABEL = "SMPNode" +_SESSION_LABEL = "SMPSession" +_LOCK_LABEL = "SMPLck" + + +def _node_to_props(node: GraphNode) -> dict[str, Any]: + """Convert a GraphNode to flat Neo4j properties.""" + return { + "id": node.id, + "type": node.type.value, + "file_path": node.file_path, + "structural_name": node.structural.name, + "structural_file": node.structural.file, + "structural_signature": node.structural.signature, + "structural_start_line": node.structural.start_line, + "structural_end_line": node.structural.end_line, + "structural_complexity": node.structural.complexity, + "structural_lines": node.structural.lines, + "structural_parameters": node.structural.parameters, + "semantic_status": node.semantic.status, + "semantic_docstring": node.semantic.docstring, + "semantic_description": node.semantic.description or "", + "semantic_decorators": str(node.semantic.decorators), + "semantic_tags": str(node.semantic.tags), + "semantic_manually_set": node.semantic.manually_set, + "semantic_source_hash": node.semantic.source_hash, + "semantic_enriched_at": node.semantic.enriched_at, + "semantic_annotations": str(node.semantic.annotations) if node.semantic.annotations else "", + "semantic_inline_comments": str(node.semantic.inline_comments), + } + + +def _record_to_node(record: dict[str, Any]) -> GraphNode: + """Reconstruct a GraphNode from a Neo4j record.""" + structural = StructuralProperties( + name=record.get("structural_name", ""), + file=record.get("structural_file", ""), + signature=record.get("structural_signature", ""), + start_line=record.get("structural_start_line", 0), + end_line=record.get("structural_end_line", 
0), + complexity=record.get("structural_complexity", 0), + lines=record.get("structural_lines", 0), + parameters=record.get("structural_parameters", 0), + ) + + annotations_raw = record.get("semantic_annotations", "") + annotations: Annotations | None = None + if annotations_raw and annotations_raw != "": + try: + import ast + + parsed = ast.literal_eval(annotations_raw) + if isinstance(parsed, dict): + annotations = Annotations( + params=parsed.get("params", {}), + returns=parsed.get("returns"), + throws=parsed.get("throws", []), + ) + except (ValueError, SyntaxError): + pass + + decorators_raw = record.get("semantic_decorators", "[]") + try: + import ast + + decorators = ast.literal_eval(decorators_raw) if decorators_raw else [] + if not isinstance(decorators, list): + decorators = [] + except (ValueError, SyntaxError): + decorators = [] + + tags_raw = record.get("semantic_tags", "[]") + try: + import ast + + tags = ast.literal_eval(tags_raw) if tags_raw else [] + if not isinstance(tags, list): + tags = [] + except (ValueError, SyntaxError): + tags = [] + + comments_raw = record.get("semantic_inline_comments", "[]") + inline_comments: list[InlineComment] = [] + try: + import ast + + parsed_comments = ast.literal_eval(comments_raw) if comments_raw else [] + if isinstance(parsed_comments, list): + for c in parsed_comments: + if isinstance(c, dict): + inline_comments.append(InlineComment(line=c.get("line", 0), text=c.get("text", ""))) + elif isinstance(c, InlineComment): + inline_comments.append(c) + except (ValueError, SyntaxError): + pass + + semantic = SemanticProperties( + status=record.get("semantic_status", "no_metadata"), + docstring=record.get("semantic_docstring", ""), + description=record.get("semantic_description") or None, + inline_comments=inline_comments, + decorators=decorators, + annotations=annotations, + tags=tags, + manually_set=record.get("semantic_manually_set", False), + source_hash=record.get("semantic_source_hash", ""), + enriched_at=record.get("semantic_enriched_at", ""), + ) + + return GraphNode( + id=record["id"], + type=NodeType(record["type"]), + file_path=record["file_path"], + structural=structural, + semantic=semantic, + ) + + +class Neo4jGraphStore(GraphStore): + """Graph store backed by a Neo4j instance.""" + + def __init__( + self, + uri: str = "", + user: str = "", + password: str = "", + database: str = "neo4j", + ) -> None: + self._uri = uri or os.environ.get("SMP_NEO4J_URI", "bolt://localhost:7687") + self._user = user or os.environ.get("SMP_NEO4J_USER", "neo4j") + self._password = password or os.environ.get("SMP_NEO4J_PASSWORD", "") + self._database = database + self._driver: AsyncDriver | None = None + + async def connect(self) -> None: + self._driver = AsyncGraphDatabase.driver(self._uri, auth=(self._user, self._password)) + await self._driver.verify_connectivity() + log.info("neo4j_connected", uri=self._uri) + await self._execute(f"CREATE CONSTRAINT IF NOT EXISTS FOR (n:{_ALL_LABEL}) REQUIRE n.id IS UNIQUE") + + # Create full-text index for search + await self._execute( + f"CREATE FULLTEXT INDEX node_search_index IF NOT EXISTS FOR (n:{_ALL_LABEL}) " + "ON EACH [n.semantic_docstring, n.semantic_description, n.structural_name, n.file_path]" + ) + + async def close(self) -> None: + if self._driver: + await self._driver.close() + self._driver = None + log.info("neo4j_closed") + + async def clear(self) -> None: + await self._execute("MATCH (n) DETACH DELETE n") + log.warning("neo4j_cleared") + + async def upsert_node(self, node: GraphNode) -> None: + 
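+        """Insert or update a single node, flattened to primitive properties and MERGEd on ``id``."""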
props = _node_to_props(node) + cypher = f""" + MERGE (n:{_ALL_LABEL} {{id: $id}}) + SET n += $props + """ + await self._execute(cypher, {"id": node.id, "props": props}) + log.debug("node_upserted", node_id=node.id) + + async def upsert_session(self, session: Any) -> None: + """Store or update a session in the graph.""" + props = { + "session_id": session.session_id, + "agent_id": session.agent_id, + "task": session.task, + "mode": session.mode, + "opened_at": session.opened_at, + "expires_at": session.expires_at, + "status": session.status, + "files_written": session.files_written, + "files_read": session.files_read, + } + cypher = f""" + MERGE (n:{_SESSION_LABEL} {{session_id: $session_id}}) + SET n += $props + """ + await self._execute(cypher, {"session_id": session.session_id, "props": props}) + log.debug("session_upserted", session_id=session.session_id) + + async def get_session(self, session_id: str) -> dict[str, Any] | None: + """Retrieve a session by ID.""" + cypher = f"MATCH (n:{_SESSION_LABEL} {{session_id: $session_id}}) RETURN n" + records = await self._execute(cypher, {"session_id": session_id}) + if not records: + return None + return dict(records[0]["n"]) + + async def delete_session(self, session_id: str) -> bool: + """Delete a session from the graph.""" + cypher = f"MATCH (n:{_SESSION_LABEL} {{session_id: $session_id}}) DETACH DELETE n RETURN count(n) AS deleted" + records = await self._execute(cypher, {"session_id": session_id}) + deleted = records[0]["deleted"] if records else 0 + return deleted > 0 + + async def upsert_lock(self, file_path: str, session_id: str) -> None: + """Store a file lock.""" + props = { + "file_path": file_path, + "session_id": session_id, + "acquired_at": datetime.now(UTC).isoformat(), + } + cypher = f""" + MERGE (n:{_LOCK_LABEL} {{file_path: $file_path, session_id: $session_id}}) + SET n += $props + """ + await self._execute(cypher, {"file_path": file_path, "session_id": session_id, "props": props}) + log.debug("lock_upserted", file_path=file_path, session_id=session_id) + + async def get_lock(self, file_path: str) -> dict[str, Any] | None: + """Get lock info for a file.""" + cypher = f"MATCH (n:{_LOCK_LABEL} {{file_path: $file_path}}) RETURN n LIMIT 1" + records = await self._execute(cypher, {"file_path": file_path}) + if not records: + return None + return dict(records[0]["n"]) + + async def release_lock(self, file_path: str, session_id: str) -> bool: + """Release a file lock.""" + cypher = f""" + MATCH (n:{_LOCK_LABEL} {{file_path: $file_path, session_id: $session_id}}) + DETACH DELETE n + RETURN count(n) AS deleted + """ + records = await self._execute(cypher, {"file_path": file_path, "session_id": session_id}) + deleted = records[0]["deleted"] if records else 0 + if deleted > 0: + log.debug("lock_released", file_path=file_path, session_id=session_id) + return deleted > 0 + + async def release_all_locks(self, session_id: str) -> int: + """Release all locks held by a session.""" + cypher = f""" + MATCH (n:{_LOCK_LABEL} {{session_id: $session_id}}) + DETACH DELETE n + RETURN count(n) AS deleted + """ + records = await self._execute(cypher, {"session_id": session_id}) + deleted = records[0]["deleted"] if records else 0 + log.info("locks_released_by_session", session_id=session_id, count=deleted) + return deleted + + async def upsert_nodes(self, nodes: Sequence[GraphNode]) -> None: + if not nodes: + return + batch = [_node_to_props(n) for n in nodes] + cypher = f""" + UNWIND $batch AS row + MERGE (n:{_ALL_LABEL} {{id: row.id}}) + SET n += row + 
""" + await self._execute(cypher, {"batch": batch}) + log.info("nodes_upserted_batch", count=len(nodes)) + + async def get_node(self, node_id: str) -> GraphNode | None: + cypher = f"MATCH (n:{_ALL_LABEL} {{id: $id}}) RETURN n" + records = await self._execute(cypher, {"id": node_id}) + if not records: + return None + return _record_to_node(dict(records[0]["n"])) + + async def delete_node(self, node_id: str) -> bool: + cypher = f"MATCH (n:{_ALL_LABEL} {{id: $id}}) DETACH DELETE n RETURN count(n) AS deleted" + records = await self._execute(cypher, {"id": node_id}) + deleted = records[0]["deleted"] if records else 0 + return deleted > 0 + + async def delete_nodes_by_file(self, file_path: str) -> int: + stem = file_path.rsplit("/", 1)[-1] if "/" in file_path else file_path + cypher = f""" + MATCH (n:{_ALL_LABEL}) + WHERE n.file_path = $file_path OR n.file_path = $stem + DETACH DELETE n + RETURN count(n) AS deleted + """ + records = await self._execute(cypher, {"file_path": file_path, "stem": stem}) + deleted = records[0]["deleted"] if records else 0 + log.info("nodes_deleted_by_file", file_path=file_path, deleted=deleted) + return deleted + + async def upsert_edge(self, edge: GraphEdge) -> None: + rel_type = edge.type.value + cypher = f""" + MATCH (a:{_ALL_LABEL} {{id: $source_id}}) + MATCH (b:{_ALL_LABEL} {{id: $target_id}}) + MERGE (a)-[r:{rel_type}]->(b) + SET r += $metadata + """ + await self._execute( + cypher, + { + "source_id": edge.source_id, + "target_id": edge.target_id, + "metadata": edge.metadata, + }, + ) + log.debug("edge_upserted", src=edge.source_id, tgt=edge.target_id, type=rel_type) + + async def upsert_edges(self, edges: Sequence[GraphEdge]) -> None: + if not edges: + return + grouped: dict[str, list[dict[str, Any]]] = {} + for e in edges: + grouped.setdefault(e.type.value, []).append( + {"source_id": e.source_id, "target_id": e.target_id, "metadata": e.metadata} + ) + for rel_type, batch in grouped.items(): + cypher = f""" + UNWIND $batch AS row + MATCH (a:{_ALL_LABEL} {{id: row.source_id}}) + MATCH (b:{_ALL_LABEL} {{id: row.target_id}}) + MERGE (a)-[r:{rel_type}]->(b) + SET r += row.metadata + """ + await self._execute(cypher, {"batch": batch}) + log.info("edges_upserted_batch", count=len(edges)) + + async def get_edges( + self, + node_id: str, + edge_type: EdgeType | None = None, + direction: str = "both", + ) -> list[GraphEdge]: + type_filter = f":{edge_type.value}" if edge_type else "" + if direction == "outgoing": + pattern = f"(a:{_ALL_LABEL} {{id: $id}})-[r{type_filter}]->(b)" + elif direction == "incoming": + pattern = f"(a)-[r{type_filter}]->(b:{_ALL_LABEL} {{id: $id}})" + else: + pattern = f"(a:{_ALL_LABEL} {{id: $id}})-[r{type_filter}]-(b)" + + cypher = f"MATCH {pattern} RETURN a.id AS source, b.id AS target, type(r) AS rel_type" + records = await self._execute(cypher, {"id": node_id}) + return [ + GraphEdge( + source_id=rec["source"], + target_id=rec["target"], + type=EdgeType(rec["rel_type"]), + ) + for rec in records + ] + + async def get_neighbors( + self, + node_id: str, + edge_type: EdgeType | None = None, + depth: int = 1, + ) -> list[GraphNode]: + type_filter = f":{edge_type.value}" if edge_type else "" + depth_str = f"1..{depth}" + cypher = f""" + MATCH (start:{_ALL_LABEL} {{id: $id}})-[r{type_filter}*{depth_str}]->(neighbor:{_ALL_LABEL}) + RETURN DISTINCT neighbor + """ + records = await self._execute(cypher, {"id": node_id}) + return [_record_to_node(dict(rec["neighbor"])) for rec in records] + + async def traverse( + self, + start_id: str, + edge_type: 
EdgeType, + depth: int, + max_nodes: int = 100, + direction: str = "outgoing", + ) -> list[GraphNode]: + rel_type = edge_type.value + if direction == "incoming": + cypher = f""" + MATCH path = (start:{_ALL_LABEL} {{id: $id}})<-[r:{rel_type}*1..{depth}]-(node:{_ALL_LABEL}) + RETURN DISTINCT node + LIMIT $max_nodes + """ + else: + cypher = f""" + MATCH path = (start:{_ALL_LABEL} {{id: $id}})-[r:{rel_type}*1..{depth}]->(node:{_ALL_LABEL}) + RETURN DISTINCT node + LIMIT $max_nodes + """ + records = await self._execute(cypher, {"id": start_id, "max_nodes": max_nodes}) + return [_record_to_node(dict(rec["node"])) for rec in records] + + async def find_nodes( + self, + *, + type: NodeType | None = None, + file_path: str | None = None, + name: str | None = None, + ) -> list[GraphNode]: + conditions: list[str] = [] + params: dict[str, Any] = {} + if type: + conditions.append("n.type = $type") + params["type"] = type.value + if file_path: + stem = file_path.rsplit("/", 1)[-1] if "/" in file_path else file_path + conditions.append("(n.file_path = $file_path OR n.file_path = $stem)") + params["file_path"] = file_path + params["stem"] = stem + if name: + conditions.append("n.structural_name = $name") + params["name"] = name + + where = " AND ".join(conditions) + where_clause = f"WHERE {where}" if where else "" + cypher = f"MATCH (n:{_ALL_LABEL}) {where_clause} RETURN n" + records = await self._execute(cypher, params) + return [_record_to_node(dict(rec["n"])) for rec in records] + + async def find_nodes_by_scope(self, scope: str) -> list[GraphNode]: + """Find nodes matching a scope prefix (package:path or file:path).""" + if scope == "full": + cypher = f"MATCH (n:{_ALL_LABEL}) RETURN n" + records = await self._execute(cypher) + return [_record_to_node(dict(rec["n"])) for rec in records] + + if scope.startswith("package:"): + prefix = scope[len("package:") :] + cypher = f"MATCH (n:{_ALL_LABEL}) WHERE n.file_path STARTS WITH $prefix RETURN n" + records = await self._execute(cypher, {"prefix": prefix}) + return [_record_to_node(dict(rec["n"])) for rec in records] + + if scope.startswith("file:"): + fp = scope[len("file:") :] + cypher = f"MATCH (n:{_ALL_LABEL}) WHERE n.file_path = $fp RETURN n" + records = await self._execute(cypher, {"fp": fp}) + return [_record_to_node(dict(rec["n"])) for rec in records] + + return [] + + async def get_node_degree(self, node_id: str) -> tuple[int, int]: + """Return (in_degree, out_degree) for a node.""" + cypher = f""" + MATCH (n:{_ALL_LABEL} {{id: $id}}) + OPTIONAL MATCH (n)-[out]->() + OPTIONAL MATCH ()-[inp]->(n) + RETURN count(DISTINCT out) AS out_degree, count(DISTINCT inp) AS in_degree + """ + records = await self._execute(cypher, {"id": node_id}) + if records: + return records[0]["in_degree"], records[0]["out_degree"] + return 0, 0 + + async def count_nodes(self) -> int: + cypher = f"MATCH (n:{_ALL_LABEL}) RETURN count(n) AS cnt" + records = await self._execute(cypher) + return records[0]["cnt"] if records else 0 + + async def count_edges(self) -> int: + cypher = "MATCH ()-[r]->() RETURN count(r) AS cnt" + records = await self._execute(cypher) + return records[0]["cnt"] if records else 0 + + async def search_nodes( + self, + query_terms: list[str], + match: str = "any", + node_types: list[str] | None = None, + tags: list[str] | None = None, + scope: str | None = None, + top_k: int = 5, + ) -> list[dict[str, Any]]: + """Keyword search using Neo4j full-text index (BM25).""" + search_query = " OR ".join(query_terms) if match == "any" else " AND ".join(query_terms) + + 
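+        # match="any" ORs the terms, match="all" ANDs them; the joined string
+        # is passed verbatim to Lucene via db.index.fulltext.queryNodes.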
# If search_query is empty, return empty list + if not search_query: + return [] + + conditions: list[str] = [] + params: dict[str, Any] = {"search_query": search_query, "limit": top_k} + + if scope and scope != "full": + if scope.startswith("package:"): + prefix = scope[len("package:") :] + conditions.append("node.file_path STARTS WITH $scope_prefix") + params["scope_prefix"] = prefix + elif scope.startswith("file:"): + fp = scope[len("file:") :] + conditions.append("node.file_path = $scope_file") + params["scope_file"] = fp + + if node_types: + placeholders = ", ".join(f"$nt{i}" for i in range(len(node_types))) + conditions.append(f"node.type IN [{placeholders}]") + for i, nt in enumerate(node_types): + params[f"nt{i}"] = nt + + where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else "" + + cypher = f""" + CALL db.index.fulltext.queryNodes('node_search_index', $search_query) + YIELD node, score + {where_clause} + RETURN node, score + LIMIT $limit + """ + + records = await self._execute(cypher, params) + + results: list[dict[str, Any]] = [] + for rec in records: + node_data = dict(rec["node"]) + node = _record_to_node(node_data) + results.append( + { + "node_id": node.id, + "node_type": node.type.value, + "file": node.file_path, + "name": node.structural.name, + "docstring": node.semantic.docstring, + "tags": node.semantic.tags, + "score": rec["score"], + } + ) + return results + + async def _execute(self, cypher: str, params: dict[str, Any] | None = None) -> list[Any]: + """Execute a Cypher query and return records.""" + if not self._driver: + raise RuntimeError("Neo4jGraphStore is not connected. Call connect() first.") + async with self._driver.session(database=self._database) as session: + result = await session.run(cypher, params or {}) + return [rec.data() async for rec in result] diff --git a/smp/store/interfaces.py b/smp/store/interfaces.py new file mode 100644 index 0000000..c3b68ed --- /dev/null +++ b/smp/store/interfaces.py @@ -0,0 +1,249 @@ +"""Abstract base classes for store backends. + +All concrete implementations must subclass these to ensure +interchangeability across graph and vector stores. +""" + +from __future__ import annotations + +import abc +from collections.abc import Sequence +from typing import Any + +from smp.core.models import EdgeType, GraphEdge, GraphNode, NodeType + + +class GraphStore(abc.ABC): + """Abstract graph store — manages nodes and directed edges.""" + + # -- Lifecycle ----------------------------------------------------------- + + @abc.abstractmethod + async def connect(self) -> None: + """Open connection / initialise the underlying store.""" + + @abc.abstractmethod + async def close(self) -> None: + """Release resources.""" + + @abc.abstractmethod + async def clear(self) -> None: + """Drop all data (useful for tests).""" + + # -- Node CRUD ----------------------------------------------------------- + + @abc.abstractmethod + async def upsert_node(self, node: GraphNode) -> None: + """Insert or update a single node.""" + + @abc.abstractmethod + async def upsert_nodes(self, nodes: Sequence[GraphNode]) -> None: + """Batch upsert nodes.""" + + @abc.abstractmethod + async def get_node(self, node_id: str) -> GraphNode | None: + """Retrieve a node by its *id*, or ``None``.""" + + @abc.abstractmethod + async def delete_node(self, node_id: str) -> bool: + """Delete a node and all its edges. 
Returns True if it existed.""" + + @abc.abstractmethod + async def delete_nodes_by_file(self, file_path: str) -> int: + """Remove all nodes (and edges) belonging to *file_path*. + + Returns the number of nodes removed. + """ + + # -- Edge CRUD ----------------------------------------------------------- + + @abc.abstractmethod + async def upsert_edge(self, edge: GraphEdge) -> None: + """Insert or update a single edge.""" + + @abc.abstractmethod + async def upsert_edges(self, edges: Sequence[GraphEdge]) -> None: + """Batch upsert edges.""" + + @abc.abstractmethod + async def get_edges( + self, + node_id: str, + edge_type: EdgeType | None = None, + direction: str = "both", + ) -> list[GraphEdge]: + """Return edges connected to *node_id*. + + *direction*: ``"outgoing"`` | ``"incoming"`` | ``"both"``. + """ + + # -- Traversal ----------------------------------------------------------- + + @abc.abstractmethod + async def get_neighbors( + self, + node_id: str, + edge_type: EdgeType | None = None, + depth: int = 1, + ) -> list[GraphNode]: + """Return neighbours up to *depth* hops from *node_id*.""" + + @abc.abstractmethod + async def traverse( + self, + start_id: str, + edge_type: EdgeType, + depth: int, + max_nodes: int = 100, + direction: str = "outgoing", + ) -> list[GraphNode]: + """BFS traversal from *start_id* following *edge_type* edges.""" + + # -- Search -------------------------------------------------------------- + + @abc.abstractmethod + async def find_nodes( + self, + *, + type: NodeType | None = None, + file_path: str | None = None, + name: str | None = None, + ) -> list[GraphNode]: + """Find nodes matching the given filters.""" + + # -- Aggregation --------------------------------------------------------- + + @abc.abstractmethod + async def count_nodes(self) -> int: + """Return total number of nodes.""" + + @abc.abstractmethod + async def count_edges(self) -> int: + """Return total number of edges.""" + + # -- SMP(3) Extensions --------------------------------------------------- + + async def find_nodes_by_scope(self, scope: str) -> list[GraphNode]: + """Find nodes matching a scope prefix.""" + return [] + + async def get_node_degree(self, node_id: str) -> tuple[int, int]: + """Return (in_degree, out_degree) for a node.""" + return 0, 0 + + async def search_nodes( + self, + query_terms: list[str], + match: str = "any", + node_types: list[str] | None = None, + tags: list[str] | None = None, + scope: str | None = None, + top_k: int = 5, + ) -> list[dict[str, Any]]: + """Keyword search across docstrings, descriptions, and tags.""" + return [] + + # -- Session Persistence --------------------------------------------------- + + async def upsert_session(self, session: Any) -> None: + """Store or update a session in the graph.""" + raise NotImplementedError + + async def get_session(self, session_id: str) -> dict[str, Any] | None: + """Retrieve a session by ID.""" + return None + + async def delete_session(self, session_id: str) -> bool: + """Delete a session from the graph.""" + return False + + # -- Lock Persistence ------------------------------------------------------ + + async def upsert_lock(self, file_path: str, session_id: str) -> None: + """Store a file lock.""" + raise NotImplementedError + + async def get_lock(self, file_path: str) -> dict[str, Any] | None: + """Get lock info for a file.""" + return None + + async def release_lock(self, file_path: str, session_id: str) -> bool: + """Release a file lock.""" + return False + + async def release_all_locks(self, session_id: str) -> 
int: + """Release all locks held by a session.""" + return 0 + + # -- Context manager convenience ----------------------------------------- + + async def __aenter__(self) -> GraphStore: + await self.connect() + return self + + async def __aexit__(self, *_: Any) -> None: + await self.close() + + +class VectorStore(abc.ABC): + """Abstract vector store — manages embeddings with metadata.""" + + # -- Lifecycle ----------------------------------------------------------- + + @abc.abstractmethod + async def connect(self) -> None: + """Open connection / initialise.""" + + @abc.abstractmethod + async def close(self) -> None: + """Release resources.""" + + @abc.abstractmethod + async def clear(self) -> None: + """Drop all data.""" + + # -- CRUD ---------------------------------------------------------------- + + @abc.abstractmethod + async def upsert( + self, + ids: Sequence[str], + embeddings: Sequence[Sequence[float]], + metadatas: Sequence[dict[str, Any]], + documents: Sequence[str] | None = None, + ) -> None: + """Insert or update vectors with associated metadata.""" + + @abc.abstractmethod + async def query( + self, + embedding: Sequence[float], + top_k: int = 5, + where: dict[str, Any] | None = None, + ) -> list[dict[str, Any]]: + """Return the *top_k* nearest neighbours. + + Each result is a dict with keys: ``id``, ``score``, ``metadata``, + ``document``. + """ + + @abc.abstractmethod + async def get(self, ids: Sequence[str]) -> list[dict[str, Any] | None]: + """Retrieve vectors by ID.""" + + @abc.abstractmethod + async def delete(self, ids: Sequence[str]) -> int: + """Delete vectors by ID. Returns count of deleted items.""" + + @abc.abstractmethod + async def delete_by_file(self, file_path: str) -> int: + """Delete all vectors whose metadata ``file_path`` matches.""" + + # -- Context manager convenience ----------------------------------------- + + async def __aenter__(self) -> VectorStore: + await self.connect() + return self + + async def __aexit__(self, *_: Any) -> None: + await self.close() diff --git a/test_codebase/src/auth/manager.py b/test_codebase/src/auth/manager.py new file mode 100644 index 0000000..8f6d45c --- /dev/null +++ b/test_codebase/src/auth/manager.py @@ -0,0 +1,14 @@ +# src/auth/manager.py +from src.db.user_store import save_user, get_user + +def authenticate_user(email, password): + """Validates user credentials and returns a session token.""" + user = get_user(email) + if user and password == "secret": + return "token_123" + return None + +def register_user(email, password): + """Creates a new user account.""" + data = {"email": email, "password": password} + return save_user(data) diff --git a/test_codebase/src/db/user_store.py b/test_codebase/src/db/user_store.py new file mode 100644 index 0000000..9d2d6f8 --- /dev/null +++ b/test_codebase/src/db/user_store.py @@ -0,0 +1,9 @@ +# src/db/user_store.py +def save_user(user_data: dict): + """Saves user data to the database.""" + print(f"Saving user {user_data.get('email')}") + return True + +def get_user(email: str): + """Retrieves user by email.""" + return {"email": email, "name": "Test User"} diff --git a/test_codebase/tests/test_auth.py b/test_codebase/tests/test_auth.py new file mode 100644 index 0000000..7c8d819 --- /dev/null +++ b/test_codebase/tests/test_auth.py @@ -0,0 +1,8 @@ +# tests/test_auth.py +from src.auth.manager import authenticate_user + +def test_auth_success(): + assert authenticate_user("test@example.com", "secret") == "token_123" + +def test_auth_fail(): + assert 
authenticate_user("test@example.com", "wrong") is None diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..8cb985f --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,94 @@ +"""Shared test fixtures for SMP tests.""" + +from __future__ import annotations + +import os + +import pytest + +from smp.core.models import ( + EdgeType, + GraphEdge, + GraphNode, + NodeType, + SemanticProperties, + StructuralProperties, +) +from smp.store.graph.neo4j_store import Neo4jGraphStore + +# Load environment from .env if not already set +if "SMP_NEO4J_URI" not in os.environ: + try: + from dotenv import load_dotenv + + load_dotenv() + except ImportError: + pass + + +# --------------------------------------------------------------------------- +# Neo4j fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="session") +def neo4j_store() -> Neo4jGraphStore: + """Provide a connected Neo4j graph store (session-scoped).""" + uri = os.environ.get("SMP_NEO4J_URI", "bolt://localhost:7687") + user = os.environ.get("SMP_NEO4J_USER", "neo4j") + password = os.environ.get("SMP_NEO4J_PASSWORD", "") + store = Neo4jGraphStore(uri=uri, user=user, password=password) + return store + + +@pytest.fixture() +async def clean_graph(neo4j_store: Neo4jGraphStore): + """Provide a clean Neo4j store, clearing data before and after each test.""" + await neo4j_store.connect() + await neo4j_store.clear() + yield neo4j_store + await neo4j_store.clear() + await neo4j_store.close() + + +# --------------------------------------------------------------------------- +# Sample data factories +# --------------------------------------------------------------------------- + + +def make_node( + id: str = "func_login", + type: NodeType = NodeType.FUNCTION, + file_path: str = "src/auth/login.py", + structural: StructuralProperties | None = None, + semantic: SemanticProperties | None = None, +) -> GraphNode: + if structural is None: + structural = StructuralProperties( + name="login", + file=file_path, + signature="def login(user: User) -> Token:", + start_line=10, + end_line=25, + lines=16, + ) + if semantic is None: + semantic = SemanticProperties( + docstring="Authenticate user and return token.", + status="enriched", + ) + return GraphNode( + id=id, + type=type, + file_path=file_path, + structural=structural, + semantic=semantic, + ) + + +def make_edge( + source: str = "func_login", + target: str = "func_validate", + edge_type: EdgeType = EdgeType.CALLS, +) -> GraphEdge: + return GraphEdge(source_id=source, target_id=target, type=edge_type) diff --git a/tests/fixtures/sample_project/src/api/__init__.py b/tests/fixtures/sample_project/src/api/__init__.py new file mode 100644 index 0000000..6aa203c --- /dev/null +++ b/tests/fixtures/sample_project/src/api/__init__.py @@ -0,0 +1,7 @@ +"""API module.""" + +from __future__ import annotations + +from .routes import create_app, health_check, get_user + +__all__ = ["create_app", "health_check", "get_user"] diff --git a/tests/fixtures/sample_project/src/api/routes.py b/tests/fixtures/sample_project/src/api/routes.py new file mode 100644 index 0000000..d54ccd2 --- /dev/null +++ b/tests/fixtures/sample_project/src/api/routes.py @@ -0,0 +1,37 @@ +"""API routes for the sample project.""" + +from __future__ import annotations + +from typing import Any + +from src.auth import AuthService +from src.db import UserModel, 
DatabaseConnection + + +def create_app() -> dict[str, Any]: + """Create and return the API app configuration.""" + return { + "name": "sample_api", + "version": "1.0.0", + "endpoints": ["/health", "/users/{id}", "/login"], + } + + +def health_check() -> dict[str, str]: + """Health check endpoint.""" + return {"status": "ok", "service": "sample_api"} + + +def get_user(user_id: str) -> dict[str, Any] | None: + """Get a user by ID.""" + db = DatabaseConnection() + db.connect() + model = UserModel(db) + result = model.find_by_username(user_id) + return result + + +def login_user(username: str, password: str) -> dict[str, Any]: + """Login a user.""" + auth = AuthService() + return auth.login(username, password) diff --git a/tests/fixtures/sample_project/src/auth/__init__.py b/tests/fixtures/sample_project/src/auth/__init__.py new file mode 100644 index 0000000..d092dde --- /dev/null +++ b/tests/fixtures/sample_project/src/auth/__init__.py @@ -0,0 +1,7 @@ +"""Auth module.""" + +from __future__ import annotations + +from .auth_service import AuthService, generate_token, get_current_user, hash_password, verify_password + +__all__ = ["AuthService", "generate_token", "get_current_user", "hash_password", "verify_password"] diff --git a/tests/fixtures/sample_project/src/auth/auth_service.py b/tests/fixtures/sample_project/src/auth/auth_service.py new file mode 100644 index 0000000..0cf142c --- /dev/null +++ b/tests/fixtures/sample_project/src/auth/auth_service.py @@ -0,0 +1,65 @@ +"""Authentication service for the sample project.""" + +from __future__ import annotations + +import hashlib +import uuid +from datetime import datetime, timezone + + +def hash_password(password: str, salt: str | None = None) -> tuple[str, str]: + """Hash a password with a salt using SHA-256.""" + if salt is None: + salt = uuid.uuid4().hex[:16] + combined = f"{password}{salt}".encode() + digest = hashlib.sha256(combined).hexdigest() + return digest, salt + + +def verify_password(password: str, hashed: str, salt: str) -> bool: + """Verify a password against a hash.""" + computed, _ = hash_password(password, salt) + return computed == hashed + + +def generate_token(user_id: str, secret: str = "default_secret") -> str: + """Generate a simple auth token.""" + payload = f"{user_id}:{datetime.now(timezone.utc).isoformat()}" + combined = f"{payload}{secret}".encode() + return hashlib.sha256(combined).hexdigest() + + +class AuthService: + """Main authentication service.""" + + def __init__(self) -> None: + self._sessions: dict[str, str] = {} + self._secret = "default_secret" + + def login(self, username: str, password: str) -> dict[str, str]: + """Authenticate a user and return a session token.""" + if not username or not password: + return {"error": "missing credentials"} + + token = generate_token(username, self._secret) + self._sessions[token] = username + return {"token": token, "user": username} + + def logout(self, token: str) -> bool: + """End a user session.""" + if token in self._sessions: + del self._sessions[token] + return True + return False + + def verify_token(self, token: str) -> str | None: + """Check if a token is valid and return the username.""" + return self._sessions.get(token) + + +def get_current_user(token: str | None) -> str | None: + """Helper to get current user from token.""" + if not token: + return None + service = AuthService() + return service.verify_token(token) diff --git a/tests/fixtures/sample_project/src/db/__init__.py b/tests/fixtures/sample_project/src/db/__init__.py new file mode 100644 index 
0000000..1a4da69
--- /dev/null
+++ b/tests/fixtures/sample_project/src/db/__init__.py
@@ -0,0 +1,9 @@
+"""DB module."""
+
+from __future__ import annotations
+
+from .models import DatabaseConnection, UserModel
+from .orders import OrderModel
+from . import models
+
+__all__ = ["DatabaseConnection", "UserModel", "OrderModel", "models"]
diff --git a/tests/fixtures/sample_project/src/db/models.py b/tests/fixtures/sample_project/src/db/models.py
new file mode 100644
index 0000000..4071c7b
--- /dev/null
+++ b/tests/fixtures/sample_project/src/db/models.py
@@ -0,0 +1,64 @@
+"""User and order models for database operations."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any
+
+
+@dataclass
+class User:
+    id: str
+    username: str
+    email: str
+    created_at: str
+
+
+@dataclass
+class Order:
+    id: str
+    user_id: str
+    product: str
+    quantity: int
+    status: str
+
+
+class DatabaseConnection:
+    """Minimal stand-in for a real database connection."""
+
+    def __init__(self) -> None:
+        self._connected = False
+
+    def connect(self) -> bool:
+        self._connected = True
+        return True
+
+    def disconnect(self) -> None:
+        self._connected = False
+
+    def execute(self, query: str, params: dict[str, Any]) -> list[Any]:
+        # Stub: accepts any query and returns an empty result set.
+        return []
+
+
+class UserModel:
+    """User data access object."""
+
+    def __init__(self, db: Any) -> None:
+        self._db = db
+
+    def find_by_username(self, username: str) -> dict[str, Any] | None:
+        """Find a user by username."""
+        result = self._db.execute(
+            "SELECT * FROM users WHERE username = $1",
+            {"username": username},
+        )
+        return result[0] if result else None
+
+    def create(self, username: str, email: str) -> dict[str, Any]:
+        """Create a new user."""
+        id = self._db.execute(
+            "INSERT INTO users (username, email) VALUES ($1, $2) RETURNING id",
+            {"username": username, "email": email},
+        )
+        return {"id": id, "username": username, "email": email}
diff --git a/tests/fixtures/sample_project/src/db/orders.py b/tests/fixtures/sample_project/src/db/orders.py
new file mode 100644
index 0000000..0a32e5b
--- /dev/null
+++ b/tests/fixtures/sample_project/src/db/orders.py
@@ -0,0 +1,28 @@
+"""Order model for database operations."""
+
+from __future__ import annotations
+
+from typing import Any
+
+
+class OrderModel:
+    """Order data access object."""
+
+    def __init__(self, db: Any) -> None:
+        self._db = db
+
+    def find_by_user(self, user_id: str) -> list[dict[str, Any]]:
+        """Find all orders for a user."""
+        result = self._db.execute(
+            "SELECT * FROM orders WHERE user_id = $1",
+            {"user_id": user_id},
+        )
+        return result
+
+    def create(self, user_id: str, product: str, quantity: int) -> dict[str, Any]:
+        """Create a new order."""
+        id = self._db.execute(
+            "INSERT INTO orders (user_id, product, quantity) VALUES ($1, $2, $3) RETURNING id",
+            {"user_id": user_id, "product": product, "quantity": quantity},
+        )
+        return {"id": id, "user_id": user_id, "product": product, "quantity": quantity}
diff --git a/tests/fixtures/sample_project/tests/test_auth.py b/tests/fixtures/sample_project/tests/test_auth.py
new file mode 100644
index 0000000..8388229
--- /dev/null
+++ b/tests/fixtures/sample_project/tests/test_auth.py
@@ -0,0 +1,67 @@
+"""Tests for auth service."""
+
+from __future__ import annotations
+
+import pytest
+from src.auth import AuthService, hash_password, verify_password, generate_token
+
+
+def test_hash_password_returns_tuple() -> None:
+    """Test that hash_password returns (digest, salt)."""
+    digest, salt = hash_password("secret123")
+    assert isinstance(digest, str)
+    assert isinstance(salt, str)
+    assert len(digest) == 64
+
+
+def test_verify_password_correct() -> None:
+    """Test verify_password with correct credentials."""
+    digest, salt = hash_password("secret123")
+    result = verify_password("secret123", digest, salt)
+    assert result is True
+
+
+def test_verify_password_incorrect() -> None:
+    """Test verify_password with wrong password."""
+    digest, salt = hash_password("secret123")
+    result = verify_password("wrongpassword", digest, salt)
+    assert result is False
+
+
+def test_generate_token_returns_hex() -> None:
+    """Test that generate_token returns a hex string."""
string.""" + token = generate_token("user_123") + assert isinstance(token, str) + assert len(token) == 64 + + +def test_auth_service_login_success() -> None: + """Test AuthService.login with valid credentials.""" + service = AuthService() + result = service.login("alice", "password123") + assert "token" in result + assert result["user"] == "alice" + + +def test_auth_service_login_missing_username() -> None: + """Test AuthService.login with missing username.""" + service = AuthService() + result = service.login("", "password") + assert "error" in result + + +def test_auth_service_logout() -> None: + """Test AuthService.logout.""" + service = AuthService() + login_result = service.login("alice", "password123") + token = login_result["token"] + assert service.logout(token) is True + + +def test_auth_service_verify_token() -> None: + """Test AuthService.verify_token.""" + service = AuthService() + login_result = service.login("alice", "password123") + token = login_result["token"] + assert service.verify_token(token) == "alice" + assert service.verify_token("invalid_token") is None diff --git a/tests/fixtures/sample_project/tests/test_db.py b/tests/fixtures/sample_project/tests/test_db.py new file mode 100644 index 0000000..4067c17 --- /dev/null +++ b/tests/fixtures/sample_project/tests/test_db.py @@ -0,0 +1,52 @@ +"""Tests for database models.""" + +from __future__ import annotations + +import pytest +from src.db.models import User, Order, DatabaseConnection, UserModel, OrderModel + + +def test_user_dataclass() -> None: + """Test User dataclass creation.""" + user = User(id="1", username="alice", email="alice@example.com", created_at="2024-01-01") + assert user.id == "1" + assert user.username == "alice" + + +def test_order_dataclass() -> None: + """Test Order dataclass creation.""" + order = Order(id="1", user_id="1", product="Widget", quantity=5, status="pending") + assert order.product == "Widget" + assert order.quantity == 5 + + +def test_database_connection_connect() -> None: + """Test DatabaseConnection.connect().""" + db = DatabaseConnection() + assert db.connect() is True + assert db._connected is True + + +def test_database_connection_disconnect() -> None: + """Test DatabaseConnection.disconnect().""" + db = DatabaseConnection() + db.connect() + db.disconnect() + assert db._connected is False + + +def test_database_connection_execute_returns_list() -> None: + """Test DatabaseConnection.execute() returns a list.""" + db = DatabaseConnection() + db.connect() + result = db.execute("SELECT * FROM users", {}) + assert isinstance(result, list) + + +def test_user_model_find_by_username_no_results() -> None: + """Test UserModel.find_by_username returns None when not found.""" + db = DatabaseConnection() + db.connect() + model = UserModel(db) + result = model.find_by_username("nonexistent") + assert result is None diff --git a/tests/practical_verification.py b/tests/practical_verification.py new file mode 100644 index 0000000..7921d45 --- /dev/null +++ b/tests/practical_verification.py @@ -0,0 +1,90 @@ +import asyncio +import os +from smp.client import SMPClient +from smp.logging import get_logger + +log = get_logger("verification") + +async def main(): + async with SMPClient("http://localhost:8420") as client: + # 1. 
Ingest Test Codebase + log.info("Step 1: Ingesting test codebase") + test_dir = "tests/test_codebase" + files = ["math_utils.py", "calculator.py"] + for f in files: + path = os.path.join(test_dir, f) + with open(path, "r") as file: + content = file.read() + await client.update(path, content=content) + + stats = await client.stats() + log.info("Graph stats after ingestion", stats=stats) + + # 2. Test Locate (Graph RAG) + log.info("Step 2: Testing locate") + results = await client.locate("adds two integers") + log.info("Locate results", results=results) + assert len(results) > 0, "Should have found the 'add' function" + assert "add" in results[0]["name"], "First result should be 'add'" + + # 3. Test Navigate + log.info("Step 3: Testing navigate") + # Find the ID for compute_sum first + res = await client.locate("compute sum") + sum_id = res[0]["node_id"] + nav = await client.navigate(sum_id) + log.info("Navigate results", nav=nav) + # Check if it mentions 'add' + found_add = any("add" in str(v) for v in nav.values()) + assert found_add, "Navigate for compute_sum should reveal connection to add" + + # 4. Test Trace + log.info("Step 4: Testing trace") + # Trace who calls 'add' + add_res = await client.locate("adds two integers") + add_id = add_res[0]["node_id"] + trace = await client.trace(add_id, direction="incoming") + log.info("Trace results", trace=trace) + assert len(trace) > 0, "Should find that compute_sum calls add" + + # 5. Test Community Detection + log.info("Step 5: Testing community detection") + comm_res = await client._rpc("smp/community/detect", {"levels": [{"level": 0, "resolution": 0.5}]}) + log.info("Community detect result", res=comm_res) + assert "coarse_communities" in comm_res, "Should have detected communities" + + # 6. Test Merkle Sync + log.info("Step 6: Testing Merkle Sync") + root_hash = await client._rpc("smp/merkle/tree", {}) + log.info("Merkle root hash", hash=root_hash) + assert root_hash is not None, "Should have a root hash" + + # 7. Test Safety Protocol + log.info("Step 7: Testing Safety") + session = await client._rpc("smp/session/open", {"agent_id": "test_agent", "task": "verify"}) + sid = session["session_id"] + log.info("Session opened", sid=sid) + + lock_res = await client._rpc("smp/lock", {"file_path": "tests/test_codebase/math_utils.py", "session_id": sid}) + log.info("Lock acquired", res=lock_res) + + await client._rpc("smp/session/close", {"session_id": sid}) + log.info("Session closed") + + # 8. 
Test Sandbox + log.info("Step 8: Testing Sandbox") + sandbox = await client._rpc("smp/sandbox/spawn", {"name": "test_sb"}) + sb_id = sandbox["sandbox_id"] + log.info("Sandbox spawned", sb_id=sb_id) + + exec_res = await client._rpc("smp/sandbox/execute", {"sandbox_id": sb_id, "command": ["ls", "-la"]}) + log.info("Sandbox execution", res=exec_res) + assert exec_res["exit_code"] == 0, "Sandbox command should succeed" + + await client._rpc("smp/sandbox/destroy", {"sandbox_id": sb_id}) + log.info("Sandbox destroyed") + + log.info("ALL PRACTICAL TESTS PASSED") + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/results/practical_phase10_handoff.json b/tests/results/practical_phase10_handoff.json new file mode 100644 index 0000000..f318866 --- /dev/null +++ b/tests/results/practical_phase10_handoff.json @@ -0,0 +1,56 @@ +[ + { + "name": "smp_handoff_review", + "passed": true, + "elapsed_s": 0.008, + "result": { + "passed": true + }, + "error": null + }, + { + "name": "smp_handoff_review_comment", + "passed": true, + "elapsed_s": 0.006, + "result": { + "passed": true + }, + "error": null + }, + { + "name": "smp_handoff_review_approve", + "passed": true, + "elapsed_s": 0.009, + "result": { + "passed": true + }, + "error": null + }, + { + "name": "smp_handoff_review_reject", + "passed": true, + "elapsed_s": 0.039, + "result": { + "passed": true + }, + "error": null + }, + { + "name": "smp_handoff_pr", + "passed": true, + "elapsed_s": 0.009, + "result": { + "passed": true + }, + "error": null + }, + { + "name": "smp_verify_integrity", + "passed": true, + "elapsed_s": 0.006, + "result": { + "passed": true + }, + "error": null + } +] \ No newline at end of file diff --git a/tests/results/practical_phase1_service.json b/tests/results/practical_phase1_service.json new file mode 100644 index 0000000..30ca173 --- /dev/null +++ b/tests/results/practical_phase1_service.json @@ -0,0 +1,32 @@ +[ + { + "name": "health", + "passed": true, + "elapsed_s": 0.007, + "result": { + "passed": true, + "status": "ok" + }, + "error": null + }, + { + "name": "stats", + "passed": true, + "elapsed_s": 0.009, + "result": { + "passed": true, + "nodes": 246, + "edges": 243 + }, + "error": null + }, + { + "name": "rpc_endpoint", + "passed": true, + "elapsed_s": 0.067, + "result": { + "passed": true + }, + "error": null + } +] \ No newline at end of file diff --git a/tests/results/practical_phase2_ingestion.json b/tests/results/practical_phase2_ingestion.json new file mode 100644 index 0000000..6cc2196 --- /dev/null +++ b/tests/results/practical_phase2_ingestion.json @@ -0,0 +1,32 @@ +[ + { + "name": "smp_update", + "passed": true, + "elapsed_s": 0.021, + "result": { + "passed": true, + "nodes": 2 + }, + "error": null + }, + { + "name": "smp_batch_update", + "passed": true, + "elapsed_s": 0.035, + "result": { + "passed": true, + "updates": 2 + }, + "error": null + }, + { + "name": "smp_reindex", + "passed": true, + "elapsed_s": 0.008, + "result": { + "passed": true, + "status": "reindex_requested" + }, + "error": null + } +] \ No newline at end of file diff --git a/tests/results/practical_phase2_query.json b/tests/results/practical_phase2_query.json new file mode 100644 index 0000000..87d1b22 --- /dev/null +++ b/tests/results/practical_phase2_query.json @@ -0,0 +1,64 @@ +[ + { + "name": "navigate", + "passed": true, + "elapsed_s": 0.362, + "result": { + "passed": true, + "node_id": "chaos_tests/beta.py::File::chaos_tests/beta.py::1" + }, + "error": null + }, + { + "name": "navigate_by_name", + "passed": true, 
+ "elapsed_s": 0.233, + "result": { + "passed": true + }, + "error": null + }, + { + "name": "context", + "passed": true, + "elapsed_s": 0.187, + "result": { + "passed": true + }, + "error": null + }, + { + "name": "impact", + "passed": true, + "elapsed_s": 0.063, + "result": { + "passed": true, + "result_keys": [ + "affected_files", + "affected_functions", + "severity", + "recommendations" + ] + }, + "error": null + }, + { + "name": "locate", + "passed": true, + "elapsed_s": 0.1, + "result": { + "passed": true, + "match_count": 3 + }, + "error": null + }, + { + "name": "search", + "passed": true, + "elapsed_s": 0.032, + "result": { + "passed": true + }, + "error": null + } +] \ No newline at end of file diff --git a/tests/results/practical_phase3_enrichment.json b/tests/results/practical_phase3_enrichment.json new file mode 100644 index 0000000..d5cb6d3 --- /dev/null +++ b/tests/results/practical_phase3_enrichment.json @@ -0,0 +1,42 @@ +[ + { + "name": "enrich", + "passed": true, + "elapsed_s": 0.017, + "result": { + "passed": true, + "status": "no_metadata" + }, + "error": null + }, + { + "name": "enrich_batch", + "passed": true, + "elapsed_s": 0.287, + "result": { + "passed": true, + "enriched": 92 + }, + "error": null + }, + { + "name": "enrich_status", + "passed": true, + "elapsed_s": 0.059, + "result": { + "passed": true, + "total_nodes": 246 + }, + "error": null + }, + { + "name": "enrich_stale", + "passed": true, + "elapsed_s": 0.062, + "result": { + "passed": true, + "stale_count": 0 + }, + "error": null + } +] \ No newline at end of file diff --git a/tests/results/practical_phase3_linker.json b/tests/results/practical_phase3_linker.json new file mode 100644 index 0000000..74311f6 --- /dev/null +++ b/tests/results/practical_phase3_linker.json @@ -0,0 +1,22 @@ +[ + { + "name": "smp_linker_report", + "passed": true, + "elapsed_s": 0.009, + "result": { + "passed": true, + "status": "not_configured" + }, + "error": null + }, + { + "name": "smp_linker_runtime", + "passed": true, + "elapsed_s": 0.008, + "result": { + "passed": true, + "status": "not_configured" + }, + "error": null + } +] \ No newline at end of file diff --git a/tests/results/practical_phase4_annotation.json b/tests/results/practical_phase4_annotation.json new file mode 100644 index 0000000..76cf486 --- /dev/null +++ b/tests/results/practical_phase4_annotation.json @@ -0,0 +1,21 @@ +[ + { + "name": "annotate", + "passed": true, + "elapsed_s": 0.023, + "result": { + "passed": true, + "status": "annotated" + }, + "error": null + }, + { + "name": "tag", + "passed": true, + "elapsed_s": 0.715, + "result": { + "passed": true + }, + "error": null + } +] \ No newline at end of file diff --git a/tests/results/practical_phase4_query.json b/tests/results/practical_phase4_query.json new file mode 100644 index 0000000..c3e3808 --- /dev/null +++ b/tests/results/practical_phase4_query.json @@ -0,0 +1,73 @@ +[ + { + "name": "smp_navigate", + "passed": true, + "elapsed_s": 0.307, + "result": { + "passed": true, + "node_id": "chaos_tests/beta.py::File::chaos_tests/beta.py::1" + }, + "error": null + }, + { + "name": "smp_navigate_by_name", + "passed": true, + "elapsed_s": 0.072, + "result": { + "passed": true + }, + "error": null + }, + { + "name": "smp_trace", + "passed": true, + "elapsed_s": 0.058, + "result": { + "passed": true + }, + "error": null + }, + { + "name": "smp_context", + "passed": true, + "elapsed_s": 0.025, + "result": { + "passed": true + }, + "error": null + }, + { + "name": "smp_impact", + "passed": true, + 
"elapsed_s": 0.031, + "result": { + "passed": true, + "result_keys": [ + "affected_files", + "affected_functions", + "severity", + "recommendations" + ] + }, + "error": null + }, + { + "name": "smp_locate", + "passed": true, + "elapsed_s": 0.063, + "result": { + "passed": true, + "match_count": 3 + }, + "error": null + }, + { + "name": "smp_search", + "passed": true, + "elapsed_s": 0.009, + "result": { + "passed": true + }, + "error": null + } +] \ No newline at end of file diff --git a/tests/results/practical_phase5_enrichment.json b/tests/results/practical_phase5_enrichment.json new file mode 100644 index 0000000..08b6a6a --- /dev/null +++ b/tests/results/practical_phase5_enrichment.json @@ -0,0 +1,42 @@ +[ + { + "name": "smp_enrich", + "passed": true, + "elapsed_s": 0.022, + "result": { + "passed": true, + "status": "no_metadata" + }, + "error": null + }, + { + "name": "smp_enrich_batch", + "passed": true, + "elapsed_s": 0.151, + "result": { + "passed": true, + "enriched": 92 + }, + "error": null + }, + { + "name": "smp_enrich_stale", + "passed": true, + "elapsed_s": 0.104, + "result": { + "passed": true, + "stale_count": 0 + }, + "error": null + }, + { + "name": "smp_enrich_status", + "passed": true, + "elapsed_s": 0.067, + "result": { + "passed": true, + "total_nodes": 246 + }, + "error": null + } +] \ No newline at end of file diff --git a/tests/results/practical_phase5_memory.json b/tests/results/practical_phase5_memory.json new file mode 100644 index 0000000..1b2782c --- /dev/null +++ b/tests/results/practical_phase5_memory.json @@ -0,0 +1,21 @@ +[ + { + "name": "update", + "passed": true, + "elapsed_s": 0.117, + "result": { + "passed": true, + "nodes": 2 + }, + "error": null + }, + { + "name": "reindex", + "passed": true, + "elapsed_s": 0.016, + "result": { + "passed": true + }, + "error": null + } +] \ No newline at end of file diff --git a/tests/results/practical_phase6_annotation.json b/tests/results/practical_phase6_annotation.json new file mode 100644 index 0000000..5da795a --- /dev/null +++ b/tests/results/practical_phase6_annotation.json @@ -0,0 +1,30 @@ +[ + { + "name": "smp_annotate", + "passed": true, + "elapsed_s": 0.025, + "result": { + "passed": true, + "status": "annotated" + }, + "error": null + }, + { + "name": "smp_annotate_bulk", + "passed": true, + "elapsed_s": 0.01, + "result": { + "passed": true + }, + "error": null + }, + { + "name": "smp_tag", + "passed": true, + "elapsed_s": 0.293, + "result": { + "passed": true + }, + "error": null + } +] \ No newline at end of file diff --git a/tests/results/practical_phase6_safety.json b/tests/results/practical_phase6_safety.json new file mode 100644 index 0000000..dd1a243 --- /dev/null +++ b/tests/results/practical_phase6_safety.json @@ -0,0 +1,69 @@ +[ + { + "name": "safety_not_enabled", + "passed": true, + "elapsed_s": 0.007, + "result": { + "passed": true, + "msg": "Safety protocol not enabled" + }, + "error": null + }, + { + "name": "session_open", + "passed": true, + "elapsed_s": 0.024, + "result": { + "passed": true, + "session_id": "ses_62be43" + }, + "error": null + }, + { + "name": "guard_check", + "passed": true, + "elapsed_s": 0.007, + "result": { + "passed": true, + "verdict": "blocked" + }, + "error": null + }, + { + "name": "lock", + "passed": true, + "elapsed_s": 0.007, + "result": { + "passed": true + }, + "error": null + }, + { + "name": "checkpoint", + "passed": true, + "elapsed_s": 0.007, + "result": { + "passed": true, + "checkpoint_id": "chk_3cbb2d" + }, + "error": null + }, + { + "name": "dryrun", 
+ "passed": true, + "elapsed_s": 0.007, + "result": { + "passed": true + }, + "error": null + }, + { + "name": "session_close", + "passed": true, + "elapsed_s": 0.028, + "result": { + "passed": true + }, + "error": null + } +] \ No newline at end of file diff --git a/tests/results/practical_phase7_query_ext.json b/tests/results/practical_phase7_query_ext.json new file mode 100644 index 0000000..3088346 --- /dev/null +++ b/tests/results/practical_phase7_query_ext.json @@ -0,0 +1,60 @@ +[ + { + "name": "smp_diff", + "passed": true, + "elapsed_s": 0.308, + "result": { + "passed": true, + "result_keys": [ + "nodes_added", + "nodes_removed", + "nodes_modified", + "relationships_added", + "relationships_removed" + ] + }, + "error": null + }, + { + "name": "smp_plan", + "passed": true, + "elapsed_s": 0.016, + "result": { + "passed": true, + "result_keys": [ + "execution_order", + "inter_file_conflicts", + "external_files_at_risk" + ] + }, + "error": null + }, + { + "name": "smp_conflict", + "passed": true, + "elapsed_s": 0.014, + "result": { + "passed": true + }, + "error": null + }, + { + "name": "smp_why", + "passed": true, + "elapsed_s": 0.048, + "result": { + "passed": true + }, + "error": null + }, + { + "name": "smp_telemetry", + "passed": true, + "elapsed_s": 0.007, + "result": { + "passed": true, + "status": "not_configured" + }, + "error": null + } +] \ No newline at end of file diff --git a/tests/results/practical_phase8_safety.json b/tests/results/practical_phase8_safety.json new file mode 100644 index 0000000..bfd5002 --- /dev/null +++ b/tests/results/practical_phase8_safety.json @@ -0,0 +1,87 @@ +[ + { + "name": "smp_session_open", + "passed": true, + "elapsed_s": 0.014, + "result": { + "passed": true, + "session_id": "ses_bf5eb8" + }, + "error": null + }, + { + "name": "smp_guard_check", + "passed": true, + "elapsed_s": 0.007, + "result": { + "passed": true, + "verdict": "blocked" + }, + "error": null + }, + { + "name": "smp_lock", + "passed": true, + "elapsed_s": 0.031, + "result": { + "passed": true + }, + "error": null + }, + { + "name": "smp_checkpoint", + "passed": true, + "elapsed_s": 0.008, + "result": { + "passed": true, + "checkpoint_id": "chk_6ab409" + }, + "error": null + }, + { + "name": "smp_dryrun", + "passed": true, + "elapsed_s": 0.007, + "result": { + "passed": true + }, + "error": null + }, + { + "name": "smp_rollback", + "passed": true, + "elapsed_s": 0.017, + "result": { + "passed": true + }, + "error": null + }, + { + "name": "smp_unlock", + "passed": true, + "elapsed_s": 0.03, + "result": { + "passed": true + }, + "error": null + }, + { + "name": "smp_audit_get", + "passed": true, + "elapsed_s": 0.021, + "result": { + "passed": true, + "note": "endpoint responded correctly (no matching audit log)" + }, + "error": null + }, + { + "name": "smp_session_close", + "passed": true, + "elapsed_s": 0.009, + "result": { + "passed": true + }, + "error": null + } +] \ No newline at end of file diff --git a/tests/results/practical_phase9_sandbox.json b/tests/results/practical_phase9_sandbox.json new file mode 100644 index 0000000..81dd755 --- /dev/null +++ b/tests/results/practical_phase9_sandbox.json @@ -0,0 +1,30 @@ +[ + { + "name": "smp_sandbox_spawn", + "passed": true, + "elapsed_s": 0.007, + "result": { + "passed": true, + "sandbox_id": "sandbox_933b8c3c" + }, + "error": null + }, + { + "name": "smp_sandbox_execute", + "passed": true, + "elapsed_s": 0.013, + "result": { + "passed": true + }, + "error": null + }, + { + "name": "smp_sandbox_destroy", + "passed": true, + 
"elapsed_s": 0.007, + "result": { + "passed": true + }, + "error": null + } +] \ No newline at end of file diff --git a/tests/results/practical_summary.json b/tests/results/practical_summary.json new file mode 100644 index 0000000..782f7ca --- /dev/null +++ b/tests/results/practical_summary.json @@ -0,0 +1,6 @@ +{ + "timestamp": "2026-04-11T11:13:23.216354+00:00", + "passed": 45, + "failed": 0, + "total": 45 +} \ No newline at end of file diff --git a/tests/results/summary.json b/tests/results/summary.json new file mode 100644 index 0000000..63c75f4 --- /dev/null +++ b/tests/results/summary.json @@ -0,0 +1,19 @@ +[ + { + "timestamp": "2026-04-11T08:26:39.600537+00:00", + "total_tests": 50, + "passed": 50, + "failed": 0, + "phases": { + "phase1": 5, + "phase2": 2, + "phase3": 4, + "phase4": 8, + "phase5": 5, + "phase6": 4, + "phase7": 10, + "phase8": 9, + "phase9": 3 + } + } +] \ No newline at end of file diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 0000000..80643da --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,210 @@ +"""Tests for the SMP SDK client.""" + +from __future__ import annotations + +from contextlib import asynccontextmanager + +import msgspec +import pytest +from fastapi import FastAPI, Request +from fastapi.responses import Response +from starlette.testclient import TestClient + +from smp.client import SMPClient, SMPClientError +from smp.core.models import EdgeType, GraphEdge, GraphNode, NodeType, SemanticProperties, StructuralProperties +from smp.engine.enricher import StaticSemanticEnricher +from smp.engine.graph_builder import DefaultGraphBuilder +from smp.engine.query import DefaultQueryEngine +from smp.parser.registry import ParserRegistry +from smp.protocol.router import handle_rpc +from smp.store.graph.neo4j_store import Neo4jGraphStore + + +@pytest.fixture(scope="module") +def server(): + """Create a FastAPI server with real stores (lifespan handles event loop).""" + graph = Neo4jGraphStore() + enricher = StaticSemanticEnricher() + registry = ParserRegistry() + + @asynccontextmanager + async def lifespan(app: FastAPI): + await graph.connect() + await graph.clear() + nodes = [ + GraphNode( + id="f.py::FILE::f.py::1", + type=NodeType.FILE, + file_path="f.py", + structural=StructuralProperties(name="f.py", file="f.py", start_line=1, end_line=20), + semantic=SemanticProperties(docstring=""), + ), + GraphNode( + id="f.py::FUNCTION::alpha::3", + type=NodeType.FUNCTION, + file_path="f.py", + structural=StructuralProperties(name="alpha", file="f.py", start_line=3, end_line=8), + semantic=SemanticProperties(docstring="Alpha function."), + ), + GraphNode( + id="f.py::FUNCTION::beta::10", + type=NodeType.FUNCTION, + file_path="f.py", + structural=StructuralProperties(name="beta", file="f.py", start_line=10, end_line=15), + semantic=SemanticProperties(docstring="Beta function."), + ), + ] + edges = [ + GraphEdge(source_id="f.py::FILE::f.py::1", target_id="f.py::FUNCTION::alpha::3", type=EdgeType.CONTAINS), + GraphEdge(source_id="f.py::FILE::f.py::1", target_id="f.py::FUNCTION::beta::10", type=EdgeType.CONTAINS), + GraphEdge(source_id="f.py::FUNCTION::alpha::3", target_id="f.py::FUNCTION::beta::10", type=EdgeType.CALLS), + ] + await graph.upsert_nodes(nodes) + await graph.upsert_edges(edges) + + app.state.engine = DefaultQueryEngine(graph, enricher) + app.state.builder = DefaultGraphBuilder(graph) + app.state.enricher = enricher + app.state.registry = registry + yield + await graph.clear() + await graph.close() + + app = 
FastAPI(lifespan=lifespan)
+
+    @app.post("/rpc")
+    async def rpc(request: Request) -> Response:
+        return await handle_rpc(
+            request,
+            engine=request.app.state.engine,
+            enricher=request.app.state.enricher,
+            builder=request.app.state.builder,
+            registry=request.app.state.registry,
+            vector=None,  # no vector store is wired into this fixture
+        )
+
+    @app.get("/health")
+    async def health():
+        return {"status": "ok"}
+
+    @app.get("/stats")
+    async def stats():
+        # Use the store from the enclosing fixture; app.state.graph is never set.
+        return {"nodes": await graph.count_nodes(), "edges": await graph.count_edges()}
+
+    with TestClient(app) as c:
+        yield c
+
+
+@pytest.fixture()
+def smp_client(server):
+    """Provide an SMPClient connected to the test server."""
+
+    class _TestClient(SMPClient):
+        """SMPClient that routes through TestClient instead of real HTTP."""
+
+        def __init__(self, test_client: TestClient) -> None:
+            super().__init__("http://test")
+            self._tc = test_client
+            self._req_id = 0
+
+        async def connect(self) -> None:
+            pass  # no real connection needed
+
+        async def close(self) -> None:
+            pass
+
+        async def _rpc(self, method: str, params: dict):
+            self._req_id += 1
+            req = msgspec.json.encode({"jsonrpc": "2.0", "method": method, "params": params, "id": self._req_id})
+            resp = self._tc.post("/rpc", content=req)
+            if resp.status_code == 204:
+                return None
+            body = msgspec.json.decode(resp.content)
+            if body.get("error"):
+                raise SMPClientError(body["error"]["code"], body["error"]["message"])
+            return body["result"]
+
+        async def health(self):
+            return self._tc.get("/health").json()
+
+        async def stats(self):
+            return self._tc.get("/stats").json()
+
+    client = _TestClient(server)
+    return client
+
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_health(smp_client):
+    result = await smp_client.health()
+    assert result["status"] == "ok"
+
+
+@pytest.mark.asyncio
+async def test_navigate(smp_client):
+    result = await smp_client.navigate("f.py::FUNCTION::alpha::3")
+    assert result["entity"]["name"] == "alpha"
+    # alpha calls beta, so at least one CALLS relationship is expected
+    assert len(result.get("relationships", {}).get("calls", [])) >= 1
+
+
+@pytest.mark.asyncio
+async def test_navigate_missing(smp_client):
+    result = await smp_client.navigate("nonexistent")
+    assert "error" in result
+
+
+@pytest.mark.asyncio
+async def test_trace(smp_client):
+    result = await smp_client.trace("f.py::FUNCTION::alpha::3")
+    names = {n["name"] for n in result}
+    assert "beta" in names
+
+
+@pytest.mark.asyncio
+async def test_get_context(smp_client):
+    result = await smp_client.get_context("f.py")
+    assert "functions_defined" in result or "self" in result
+
+
+@pytest.mark.asyncio
+async def test_assess_impact(smp_client):
+    result = await smp_client.assess_impact("f.py::FUNCTION::beta::10")
+    assert "affected_files" in result or "severity" in result
+
+
+@pytest.mark.asyncio
+async def test_locate(smp_client):
+    result = await smp_client.locate("alpha function")
+    assert isinstance(result, (list, dict))
+
+
+@pytest.mark.asyncio
+async def test_find_flow(smp_client):
+    result = await smp_client.find_flow("f.py::FUNCTION::alpha::3", "f.py::FUNCTION::beta::10")
+    if isinstance(result, dict):
+        assert "path" in result
+    elif isinstance(result, list):
+        assert len(result) >= 1
+
+
+@pytest.mark.asyncio
+async def test_invalid_method(smp_client):
+    with pytest.raises(SMPClientError) as exc_info:
+        await smp_client._rpc("smp/nonexistent", {})
+    assert exc_info.value.code == -32601
+
+
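+# A sketch in the same style as test_trace, assuming `trace` accepts the
+# direction="incoming" keyword the way tests/practical_verification.py uses
+# it against a live server, and returns the same list-of-dicts shape.
+@pytest.mark.asyncio
+async def test_trace_incoming(smp_client):
+    result = await smp_client.trace("f.py::FUNCTION::beta::10", direction="incoming")
+    names = {n["name"] for n in result}
+    assert "alpha" in names
+
+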
+@pytest.mark.asyncio +async def test_update(smp_client): + result = await smp_client.update( + "test_client_file.py", + content="def hello():\n pass\n", + ) + assert result["file_path"] == "test_client_file.py" + assert result["nodes"] > 0 diff --git a/tests/test_codebase/__init__.py b/tests/test_codebase/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_codebase/api/middleware.py b/tests/test_codebase/api/middleware.py new file mode 100644 index 0000000..65067b4 --- /dev/null +++ b/tests/test_codebase/api/middleware.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from typing import Callable + + +class Middleware: + """ + Simple middleware for request processing. + """ + + async def process_request(self, request_id: str, handler: Callable) -> str: + """ + Processes a request by wrapping it with middleware logic. + + Args: + request_id: The unique ID of the request. + handler: The handler function to execute. + + Returns: + The result of the handler as a string. + """ + print(f"Processing request {request_id}") + result = await handler() + print(f"Finished request {request_id}") + return result diff --git a/tests/test_codebase/api/routes.py b/tests/test_codebase/api/routes.py new file mode 100644 index 0000000..ea6a97b --- /dev/null +++ b/tests/test_codebase/api/routes.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +from ..auth.session_handler import SessionHandler +from ..auth.user_manager import UserManager +from ..db.order_repository import OrderRepository +from ..utils.validators import validate_email + + +class APIRoutes: + """ + Defines the API routes and their handlers. + This class acts as a 'Hot Node' as it coordinates multiple services. + """ + + def __init__( + self, + session_handler: SessionHandler, + user_manager: UserManager, + order_repo: OrderRepository + ): + """ + Initializes the APIRoutes with necessary services. + """ + self._session_handler = session_handler + self._user_manager = user_manager + self._order_repo = order_repo + + async def handle_login(self, username: str) -> str: + """ + Route handler for user login. + + Args: + username: The username to log in. + + Returns: + A session token or an error message. + """ + token = await self._session_handler.login(username) + return token if token else "Unauthorized" + + async def handle_get_profile(self, token: str) -> str: + """ + Route handler for retrieving user profile. + + Args: + token: The authentication token. + + Returns: + User profile details or an error message. + """ + if await self._session_handler.validate_session(token): + user_id = "user_123" # Mocked from token + user = await self._user_manager.get_user_profile(user_id) + return f"User: {user.username if user else 'Unknown'}" + return "Invalid Session" + + async def handle_create_user(self, username: str, email: str) -> str: + """ + Route handler for user registration. + + Args: + username: New username. + email: New email. + + Returns: + Success or failure message. + """ + if not validate_email(email): + return "Invalid Email" + + user = await self._user_manager.create_user(username, email) + return f"User {user.username} created" + + async def handle_get_orders(self, token: str) -> str: + """ + Route handler for retrieving user orders. + + Args: + token: The authentication token. + + Returns: + List of orders or an error message. 
+ """ + if await self._session_handler.validate_session(token): + user_id = "user_123" # Mocked from token + orders = await self._order_repo.get_orders_by_user(user_id) + return f"Orders: {len(orders)}" + return "Invalid Session" diff --git a/tests/test_codebase/auth/jwt_utils.py b/tests/test_codebase/auth/jwt_utils.py new file mode 100644 index 0000000..8617ffe --- /dev/null +++ b/tests/test_codebase/auth/jwt_utils.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +import time +from ..utils.crypto import generate_secure_token + + +class JWTUtils: + """ + Utility class for handling JSON Web Tokens. + """ + + SECRET_KEY = "super-secret-key" + + def encode_token(self, user_id: str) -> str: + """ + Encodes a user ID into a JWT token. + + Args: + user_id: The ID of the user. + + Returns: + A signed JWT string. + """ + return f"jwt.header.{user_id}.{self.SECRET_KEY}.{int(time.time())}" + + def decode_token(self, token: str) -> str | None: + """ + Decodes a JWT token and returns the user ID. + + Args: + token: The JWT string to decode. + + Returns: + The user ID if the token is valid, otherwise None. + """ + try: + parts = token.split(".") + if len(parts) == 4 and parts[3] == self.SECRET_KEY: + return parts[2] + except Exception: + pass + return None diff --git a/tests/test_codebase/auth/session_handler.py b/tests/test_codebase/auth/session_handler.py new file mode 100644 index 0000000..591798c --- /dev/null +++ b/tests/test_codebase/auth/session_handler.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from .jwt_utils import JWTUtils +from ..db.user_repository import UserRepository + + +class SessionHandler: + """ + Handles user sessions and authentication state. + """ + + def __init__(self, jwt_utils: JWTUtils, user_repo: UserRepository): + """ + Initializes the SessionHandler. + + Args: + jwt_utils: Utility for JWT operations. + user_repo: Repository for user data. + """ + self._jwt_utils = jwt_utils + self._user_repo = user_repo + + async def login(self, username: str) -> str | None: + """ + Authenticates a user and returns a session token. + + Args: + username: The username of the user. + + Returns: + A JWT token if authentication succeeds, otherwise None. + """ + user = await self._user_repo.find_by_username(username) + if user: + return self._jwt_utils.encode_token(user.user_id) + return None + + async def validate_session(self, token: str) -> bool: + """ + Validates if a session token is still active and valid. + + Args: + token: The JWT token to validate. + + Returns: + True if valid, False otherwise. + """ + user_id = self._jwt_utils.decode_token(token) + if user_id: + user = await self._user_repo.get_by_id(user_id) + return user is not None + return False diff --git a/tests/test_codebase/auth/user_manager.py b/tests/test_codebase/auth/user_manager.py new file mode 100644 index 0000000..b4cd159 --- /dev/null +++ b/tests/test_codebase/auth/user_manager.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from ..db.user_repository import UserRepository, User + + +class UserManager: + """ + Manages user-related business logic. + """ + + def __init__(self, user_repo: UserRepository): + """ + Initializes the UserManager. + + Args: + user_repo: An instance of UserRepository for data access. + """ + self._user_repo = user_repo + + async def create_user(self, username: str, email: str) -> User: + """ + Creates a new user in the system. + + Args: + username: Desired username. + email: User's email address. + + Returns: + The created User object. 
+ """ + user = User("new_id", username, email) + await self._user_repo.save(user) + return user + + async def get_user_profile(self, user_id: str) -> User | None: + """ + Retrieves a user's profile. + + Args: + user_id: The ID of the user. + + Returns: + The User object if found. + """ + return await self._user_repo.get_by_id(user_id) diff --git a/tests/test_codebase/calculator.py b/tests/test_codebase/calculator.py new file mode 100644 index 0000000..bf526a2 --- /dev/null +++ b/tests/test_codebase/calculator.py @@ -0,0 +1,9 @@ +from math_utils import add, multiply + +def compute_sum(x: int, y: int) -> int: + """Computes sum using utils.""" + return add(x, y) + +def compute_product(x: int, y: int) -> int: + """Computes product using utils.""" + return multiply(x, y) diff --git a/tests/test_codebase/db/base_repository.py b/tests/test_codebase/db/base_repository.py new file mode 100644 index 0000000..8f54b0e --- /dev/null +++ b/tests/test_codebase/db/base_repository.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Generic, TypeVar, Sequence + +T = TypeVar("T") + + +class BaseRepository(ABC, Generic[T]): + """ + Abstract base class for all repositories. + Defines standard CRUD operations for entities. + """ + + def __init__(self, connection_string: str): + """ + Initializes the repository with a database connection string. + + Args: + connection_string: The URI for the database connection. + """ + self._connection_string = connection_string + + @abstractmethod + async def get_by_id(self, entity_id: str) -> T | None: + """ + Retrieves an entity by its unique identifier. + + Args: + entity_id: The ID of the entity to retrieve. + + Returns: + The entity if found, otherwise None. + """ + pass + + @abstractmethod + async def save(self, entity: T) -> None: + """ + Saves or updates an entity in the database. + + Args: + entity: The entity to persist. + """ + pass + + @abstractmethod + async def delete(self, entity_id: str) -> bool: + """ + Deletes an entity by its ID. + + Args: + entity_id: The ID of the entity to delete. + + Returns: + True if the entity was deleted, False otherwise. + """ + pass diff --git a/tests/test_codebase/db/order_repository.py b/tests/test_codebase/db/order_repository.py new file mode 100644 index 0000000..84d4a0a --- /dev/null +++ b/tests/test_codebase/db/order_repository.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +from .base_repository import BaseRepository + + +class Order: + """Represents an order entity.""" + def __init__(self, order_id: str, user_id: str, amount: float): + self.order_id = order_id + self.user_id = user_id + self.amount = amount + + +class OrderRepository(BaseRepository[Order]): + """ + Repository implementation for managing Order entities. + """ + + async def get_by_id(self, entity_id: str) -> Order | None: + """ + Fetches an order by its order_id. + + Args: + entity_id: The order ID. + + Returns: + An Order instance if found. + """ + return Order(entity_id, "user_123", 99.99) + + async def save(self, entity: Order) -> None: + """ + Persists order data. + + Args: + entity: The Order object to save. + """ + pass + + async def delete(self, entity_id: str) -> bool: + """ + Removes an order from the store. + + Args: + entity_id: The order ID. + + Returns: + True if successful. + """ + return True + + async def get_orders_by_user(self, user_id: str) -> list[Order]: + """ + Retrieves all orders belonging to a specific user. + + Args: + user_id: The user ID. 
+
+        Returns:
+            A list of orders.
+        """
+        return [Order("o1", user_id, 10.0), Order("o2", user_id, 20.0)]
diff --git a/tests/test_codebase/db/user_repository.py b/tests/test_codebase/db/user_repository.py
new file mode 100644
index 0000000..81c4e1d
--- /dev/null
+++ b/tests/test_codebase/db/user_repository.py
@@ -0,0 +1,65 @@
+from __future__ import annotations
+
+from .base_repository import BaseRepository
+
+
+class User:
+    """Represents a user entity."""
+    def __init__(self, user_id: str, username: str, email: str):
+        self.user_id = user_id
+        self.username = username
+        self.email = email
+
+
+class UserRepository(BaseRepository[User]):
+    """
+    Repository implementation for managing User entities.
+    Inherits from BaseRepository.
+    """
+
+    async def get_by_id(self, entity_id: str) -> User | None:
+        """
+        Fetches a user by their user_id.
+
+        Args:
+            entity_id: The user ID.
+
+        Returns:
+            A User instance if found.
+        """
+        # Mock implementation
+        return User(entity_id, "test_user", "test@example.com")
+
+    async def save(self, entity: User) -> None:
+        """
+        Persists user data.
+
+        Args:
+            entity: The User object to save.
+        """
+        pass
+
+    async def delete(self, entity_id: str) -> bool:
+        """
+        Removes a user from the store.
+
+        Args:
+            entity_id: The user ID.
+
+        Returns:
+            True if successful.
+        """
+        return True
+
+    async def find_by_username(self, username: str) -> User | None:
+        """
+        Finds a user by their username.
+
+        Args:
+            username: The username to search for.
+
+        Returns:
+            A User instance if found.
+        """
+        # Mock implementation
+        return User("123", username, "user@example.com")
diff --git a/tests/test_codebase/main.py b/tests/test_codebase/main.py
new file mode 100644
index 0000000..6b8e18a
--- /dev/null
+++ b/tests/test_codebase/main.py
@@ -0,0 +1,7 @@
+from .calculator import compute_sum
+
+def main():
+    print(compute_sum(1, 2))
+
+if __name__ == '__main__':
+    main()
diff --git a/tests/test_codebase/math_utils.py b/tests/test_codebase/math_utils.py
new file mode 100644
index 0000000..60af0d9
--- /dev/null
+++ b/tests/test_codebase/math_utils.py
@@ -0,0 +1,7 @@
+def add(a: int, b: int) -> int:
+    """Adds two integers."""
+    return a + b
+
+def multiply(a: int, b: int) -> int:
+    """Multiplies two integers."""
+    return a * b
diff --git a/tests/test_codebase/utils/crypto.py b/tests/test_codebase/utils/crypto.py
new file mode 100644
index 0000000..b1d0cda
--- /dev/null
+++ b/tests/test_codebase/utils/crypto.py
@@ -0,0 +1,32 @@
+from __future__ import annotations
+
+import hashlib
+import secrets
+
+
+def generate_secure_token(length: int = 32) -> str:
+    """
+    Generates a cryptographically secure random token.
+
+    Args:
+        length: The length of the token to generate.
+
+    Returns:
+        A secure random hex string.
+    """
+    return secrets.token_hex(length // 2)
+
+
+def hash_password(password: str, salt: str) -> str:
+    """
+    Hashes a password with a given salt using SHA-256.
+
+    Args:
+        password: The plain-text password.
+        salt: The salt to be used for hashing.
+
+    Returns:
+        The hex digest of the hashed password.
+    """
+    combined = password + salt
+    return hashlib.sha256(combined.encode()).hexdigest()
diff --git a/tests/test_codebase/utils/validators.py b/tests/test_codebase/utils/validators.py
new file mode 100644
index 0000000..5e5d4d2
--- /dev/null
+++ b/tests/test_codebase/utils/validators.py
@@ -0,0 +1,30 @@
+from __future__ import annotations
+
+import re
+
+
+def validate_email(email: str) -> bool:
+    """
+    Validates an email address using a regular expression.
+ + Args: + email: The email string to validate. + + Returns: + True if the email is valid, False otherwise. + """ + pattern = r"^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$" + return bool(re.match(pattern, email)) + + +def validate_username(username: str) -> bool: + """ + Validates a username. Must be alphanumeric and between 3-20 characters. + + Args: + username: The username to validate. + + Returns: + True if the username is valid, False otherwise. + """ + return username.isalnum() and 3 <= len(username) <= 20 diff --git a/tests/test_enricher.py b/tests/test_enricher.py new file mode 100644 index 0000000..82daea1 --- /dev/null +++ b/tests/test_enricher.py @@ -0,0 +1,153 @@ +"""Tests for the static semantic enricher — SMP(3).""" + +from __future__ import annotations + +import pytest + +from smp.core.models import ( + Annotations, + GraphNode, + NodeType, + SemanticProperties, + StructuralProperties, +) +from smp.engine.enricher import StaticSemanticEnricher, _compute_source_hash + +# --------------------------------------------------------------------------- +# Source hash +# --------------------------------------------------------------------------- + + +class TestSourceHash: + def test_deterministic(self) -> None: + h1 = _compute_source_hash("foo", "test.py", 1, 5, "def foo():") + h2 = _compute_source_hash("foo", "test.py", 1, 5, "def foo():") + assert h1 == h2 + + def test_different_names_differ(self) -> None: + h1 = _compute_source_hash("foo", "test.py", 1, 5, "def foo():") + h2 = _compute_source_hash("bar", "test.py", 1, 5, "def bar():") + assert h1 != h2 + + def test_different_files_differ(self) -> None: + h1 = _compute_source_hash("foo", "a.py", 1, 5, "def foo():") + h2 = _compute_source_hash("foo", "b.py", 1, 5, "def foo():") + assert h1 != h2 + + +# --------------------------------------------------------------------------- +# Enricher (static mode, no API key) +# --------------------------------------------------------------------------- + + +class TestStaticSemanticEnricher: + @pytest.fixture() + def enricher(self) -> StaticSemanticEnricher: + return StaticSemanticEnricher() + + def _make_node( + self, + id: str = "test::Function::foo::1", + name: str = "foo", + docstring: str = "", + decorators: list[str] | None = None, + annotations: Annotations | None = None, + ) -> GraphNode: + return GraphNode( + id=id, + type=NodeType.FUNCTION, + file_path="test.py", + structural=StructuralProperties( + name=name, + file="test.py", + signature=f"def {name}():", + start_line=1, + end_line=5, + lines=5, + ), + semantic=SemanticProperties( + docstring=docstring, + decorators=decorators or [], + annotations=annotations, + ), + ) + + @pytest.mark.asyncio + async def test_enrich_no_metadata(self, enricher: StaticSemanticEnricher) -> None: + node = self._make_node(docstring="") + enriched = await enricher.enrich_node(node) + assert enriched.semantic.status == "no_metadata" + assert enriched.semantic.source_hash != "" + + @pytest.mark.asyncio + async def test_enrich_with_docstring(self, enricher: StaticSemanticEnricher) -> None: + node = self._make_node(docstring="Validates credentials and issues JWT.") + enriched = await enricher.enrich_node(node) + assert enriched.semantic.status == "enriched" + assert enriched.semantic.source_hash != "" + assert enriched.semantic.enriched_at != "" + + @pytest.mark.asyncio + async def test_enrich_with_decorators(self, enricher: StaticSemanticEnricher) -> None: + node = self._make_node(decorators=["pytest.fixture"]) + enriched = await 
enricher.enrich_node(node) + assert enriched.semantic.status == "enriched" + + @pytest.mark.asyncio + async def test_enrich_with_annotations(self, enricher: StaticSemanticEnricher) -> None: + node = self._make_node( + annotations=Annotations(params={"x": "int"}, returns="str"), + ) + enriched = await enricher.enrich_node(node) + assert enriched.semantic.status == "enriched" + + @pytest.mark.asyncio + async def test_skip_unchanged(self, enricher: StaticSemanticEnricher) -> None: + node = self._make_node(docstring="Test.") + enriched1 = await enricher.enrich_node(node) + assert enriched1.semantic.status == "enriched" + hash1 = enriched1.semantic.source_hash + + enriched2 = await enricher.enrich_node(enriched1) + assert enriched2.semantic.source_hash == hash1 + + @pytest.mark.asyncio + async def test_force_re_enrich(self, enricher: StaticSemanticEnricher) -> None: + node = self._make_node(docstring="Test.") + enriched1 = await enricher.enrich_node(node) + enriched2 = await enricher.enrich_node(enriched1, force=True) + assert enriched2.semantic.status == "enriched" + + @pytest.mark.asyncio + async def test_enrich_batch(self, enricher: StaticSemanticEnricher) -> None: + nodes = [ + self._make_node(id=f"test::Function::f{i}::{i}", name=f"f{i}", docstring=f"Doc {i}.") for i in range(5) + ] + enriched = await enricher.enrich_batch(nodes) + assert len(enriched) == 5 + for n in enriched: + assert n.semantic.status == "enriched" + + @pytest.mark.asyncio + async def test_embed_noop(self, enricher: StaticSemanticEnricher) -> None: + emb = await enricher.embed("test text") + assert emb == [] + + @pytest.mark.asyncio + async def test_no_llm(self, enricher: StaticSemanticEnricher) -> None: + assert enricher.has_llm is False + + @pytest.mark.asyncio + async def test_counts(self, enricher: StaticSemanticEnricher) -> None: + node = self._make_node(docstring="Test.") + await enricher.enrich_node(node) + counts = enricher.get_counts() + assert counts["enriched"] == 1 + + @pytest.mark.asyncio + async def test_reset_counts(self, enricher: StaticSemanticEnricher) -> None: + node = self._make_node(docstring="Test.") + await enricher.enrich_node(node) + enricher.reset_counts() + counts = enricher.get_counts() + assert counts["enriched"] == 0 diff --git a/tests/test_integration_community.py b/tests/test_integration_community.py new file mode 100644 index 0000000..9df88b7 --- /dev/null +++ b/tests/test_integration_community.py @@ -0,0 +1,701 @@ +"""Integration tests for CommunityDetector.""" + +from __future__ import annotations + +from collections import defaultdict +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from smp.core.models import EdgeType, GraphEdge, GraphNode, NodeType, SemanticProperties, StructuralProperties +from smp.engine.community import CommunityDetector + + +class MockGraphStore: + """Mock GraphStore for testing.""" + + def __init__(self) -> None: + self._nodes: dict[str, GraphNode] = {} + self._edges: list[GraphEdge] = [] + self._edge_index: dict[str, list[GraphEdge]] = defaultdict(list) + + async def connect(self) -> None: + pass + + async def close(self) -> None: + pass + + async def clear(self) -> None: + self._nodes.clear() + self._edges.clear() + self._edge_index.clear() + + async def upsert_node(self, node: GraphNode) -> None: + self._nodes[node.id] = node + + async def upsert_nodes(self, nodes: list[GraphNode]) -> None: + for node in nodes: + self._nodes[node.id] = node + + async def get_node(self, node_id: str) -> GraphNode | None: + return 
self._nodes.get(node_id) + + async def delete_node(self, node_id: str) -> bool: + if node_id in self._nodes: + del self._nodes[node_id] + return True + return False + + async def delete_nodes_by_file(self, file_path: str) -> int: + nodes_to_delete = [nid for nid, n in self._nodes.items() if n.file_path == file_path] + for nid in nodes_to_delete: + del self._nodes[nid] + return len(nodes_to_delete) + + async def upsert_edge(self, edge: GraphEdge) -> None: + self._edges.append(edge) + self._edge_index[edge.source_id].append(edge) + + async def upsert_edges(self, edges: list[GraphEdge]) -> None: + for edge in edges: + self._edges.append(edge) + self._edge_index[edge.source_id].append(edge) + + async def get_edges( + self, + node_id: str, + edge_type: EdgeType | None = None, + direction: str = "both", + ) -> list[GraphEdge]: + result = [] + for edge in self._edge_index.get(node_id, []): + if edge_type is None or edge.type == edge_type: + result.append(edge) + if direction == "incoming": + incoming = [e for e in self._edges if e.target_id == node_id] + result.extend(incoming) + return result + + async def get_neighbors( + self, + node_id: str, + edge_type: EdgeType | None = None, + depth: int = 1, + ) -> list[GraphNode]: + neighbor_ids = set() + current = {node_id} + for _ in range(depth): + next_current = set() + for nid in current: + for edge in self._edge_index.get(nid, []): + neighbor_ids.add(edge.target_id) + next_current.add(edge.target_id) + current = next_current + return [self._nodes[nid] for nid in neighbor_ids if nid in self._nodes] + + async def traverse( + self, + start_id: str, + edge_type: EdgeType, + depth: int, + max_nodes: int = 100, + direction: str = "outgoing", + ) -> list[GraphNode]: + return await self.get_neighbors(start_id, edge_type, depth) + + async def find_nodes( + self, + *, + type: NodeType | None = None, + file_path: str | None = None, + name: str | None = None, + ) -> list[GraphNode]: + result = list(self._nodes.values()) + if type is not None: + result = [n for n in result if n.type == type] + if file_path is not None: + result = [n for n in result if n.file_path == file_path] + if name is not None: + result = [n for n in result if n.structural.name == name] + return result + + async def count_nodes(self) -> int: + return len(self._nodes) + + async def count_edges(self) -> int: + return len(self._edges) + + +class MockVectorStore: + """Mock VectorStore for testing.""" + + def __init__(self) -> None: + self._embeddings: dict[str, list[float]] = {} + self._metadata: dict[str, dict[str, Any]] = {} + self._documents: dict[str, str] = {} + + async def connect(self) -> None: + pass + + async def close(self) -> None: + pass + + async def clear(self) -> None: + self._embeddings.clear() + self._metadata.clear() + self._documents.clear() + + async def upsert( + self, + ids: list[str], + embeddings: list[list[float]], + metadatas: list[dict[str, Any]], + documents: list[str] | None = None, + ) -> None: + for i, id_ in enumerate(ids): + self._embeddings[id_] = embeddings[i] + self._metadata[id_] = metadatas[i] + if documents: + self._documents[id_] = documents[i] + + async def add_code_embedding( + self, + node_id: str, + embedding: list[float], + metadata: dict[str, Any], + document: str, + ) -> None: + self._embeddings[node_id] = embedding + self._metadata[node_id] = metadata + self._documents[node_id] = document + + async def query( + self, + embedding: list[float], + top_k: int = 5, + where: dict[str, Any] | None = None, + ) -> list[dict[str, Any]]: + return [] + + 
async def get(self, ids: list[str]) -> list[dict[str, Any] | None]: + return [None] * len(ids) + + async def delete(self, ids: list[str]) -> int: + count = 0 + for id_ in ids: + if id_ in self._embeddings: + del self._embeddings[id_] + del self._metadata[id_] + del self._documents[id_] + count += 1 + return count + + async def delete_by_file(self, file_path: str) -> int: + return 0 + + +def make_community_node( + id: str, + type: NodeType = NodeType.FUNCTION, + file_path: str = "src/module/file.py", + name: str = "test_func", + tags: list[str] | None = None, +) -> GraphNode: + if tags is None: + tags = [] + return GraphNode( + id=id, + type=type, + file_path=file_path, + structural=StructuralProperties(name=name, file=file_path), + semantic=SemanticProperties(tags=tags), + ) + + +class TestCommunityDetectorInit: + """Test CommunityDetector.__init__().""" + + async def test_init_with_graph_store_only(self) -> None: + store = MockGraphStore() + detector = CommunityDetector(graph_store=store) + assert detector._graph is store + assert detector._vector is None + assert detector._min_size == 5 + assert detector._communities == {} + + async def test_init_with_graph_and_vector_store(self) -> None: + graph_store = MockGraphStore() + vector_store = MockVectorStore() + detector = CommunityDetector(graph_store=graph_store, vector_store=vector_store, min_community_size=10) + assert detector._graph is graph_store + assert detector._vector is vector_store + assert detector._min_size == 10 + + +class TestLouvainAlgorithm: + """Test CommunityDetector._louvain().""" + + async def test_louvain_three_cliques(self) -> None: + store = MockGraphStore() + await store.connect() + + clique_a_nodes = [ + make_community_node("a1", file_path="src/package_a/file1.py", name="func_a1", tags=["auth"]), + make_community_node("a2", file_path="src/package_a/file1.py", name="func_a2", tags=["auth"]), + make_community_node("a3", file_path="src/package_a/file2.py", name="func_a3", tags=["auth"]), + make_community_node("a4", file_path="src/package_a/file2.py", name="func_a4", tags=["auth"]), + make_community_node("a5", file_path="src/package_a/file3.py", name="func_a5", tags=["auth"]), + ] + clique_b_nodes = [ + make_community_node("b1", file_path="src/package_b/file1.py", name="func_b1", tags=["api"]), + make_community_node("b2", file_path="src/package_b/file1.py", name="func_b2", tags=["api"]), + make_community_node("b3", file_path="src/package_b/file2.py", name="func_b3", tags=["api"]), + make_community_node("b4", file_path="src/package_b/file2.py", name="func_b4", tags=["api"]), + make_community_node("b5", file_path="src/package_b/file3.py", name="func_b5", tags=["api"]), + ] + clique_c_nodes = [ + make_community_node("c1", file_path="src/package_c/file1.py", name="func_c1", tags=["core"]), + make_community_node("c2", file_path="src/package_c/file1.py", name="func_c2", tags=["core"]), + make_community_node("c3", file_path="src/package_c/file2.py", name="func_c3", tags=["core"]), + make_community_node("c4", file_path="src/package_c/file2.py", name="func_c4", tags=["core"]), + make_community_node("c5", file_path="src/package_c/file3.py", name="func_c5", tags=["core"]), + ] + + all_nodes = clique_a_nodes + clique_b_nodes + clique_c_nodes + await store.upsert_nodes(all_nodes) + + for n1, n2 in zip(clique_a_nodes, clique_a_nodes[1:], strict=False): + await store.upsert_edge(GraphEdge(source_id=n1.id, target_id=n2.id, type=EdgeType.CALLS)) + for n1, n2 in zip(clique_b_nodes, clique_b_nodes[1:], strict=False): + await 
store.upsert_edge(GraphEdge(source_id=n1.id, target_id=n2.id, type=EdgeType.CALLS)) + for n1, n2 in zip(clique_c_nodes, clique_c_nodes[1:], strict=False): + await store.upsert_edge(GraphEdge(source_id=n1.id, target_id=n2.id, type=EdgeType.CALLS)) + + await store.upsert_edge(GraphEdge(source_id="a1", target_id="b1", type=EdgeType.CALLS)) + await store.upsert_edge(GraphEdge(source_id="b3", target_id="c2", type=EdgeType.CALLS)) + + detector = CommunityDetector(graph_store=store) + nodes = await store.find_nodes() + edge_types = [EdgeType.CALLS] + adjacency = await detector._build_adjacency(nodes, edge_types) + + assignments = detector._louvain(nodes, adjacency, resolution=1.0) + + comm_groups: dict[str, list[str]] = defaultdict(list) + for node_id, comm_id in assignments.items(): + comm_groups[comm_id].append(node_id) + + assert len(comm_groups) >= 3 + for group in comm_groups.values(): + assert len(group) >= 1 + + for node_id in ["a1", "a2", "a3", "a4", "a5"]: + assert node_id in assignments + + for node_id in ["b1", "b2", "b3", "b4", "b5"]: + assert node_id in assignments + + for node_id in ["c1", "c2", "c3", "c4", "c5"]: + assert node_id in assignments + + async def test_louvain_single_community(self) -> None: + store = MockGraphStore() + await store.connect() + + nodes = [ + make_community_node("n1", file_path="src/pkg/file.py", name="func1"), + make_community_node("n2", file_path="src/pkg/file.py", name="func2"), + make_community_node("n3", file_path="src/pkg/file.py", name="func3"), + ] + await store.upsert_nodes(nodes) + + await store.upsert_edge(GraphEdge(source_id="n1", target_id="n2", type=EdgeType.CALLS)) + await store.upsert_edge(GraphEdge(source_id="n2", target_id="n3", type=EdgeType.CALLS)) + await store.upsert_edge(GraphEdge(source_id="n3", target_id="n1", type=EdgeType.CALLS)) + + detector = CommunityDetector(graph_store=store) + all_nodes = await store.find_nodes() + adjacency = await detector._build_adjacency(all_nodes, [EdgeType.CALLS]) + + assignments = detector._louvain(all_nodes, adjacency, resolution=1.0) + + comm_ids = set(assignments.values()) + assert len(comm_ids) == 1 + + async def test_louvain_empty_graph(self) -> None: + store = MockGraphStore() + detector = CommunityDetector(graph_store=store) + + assignments = detector._louvain([], {}, resolution=1.0) + assert assignments == {} + + +class TestBuildCommunities: + """Test CommunityDetector._build_communities().""" + + async def test_build_communities_labels_and_counts(self) -> None: + store = MockGraphStore() + await store.connect() + + nodes = [ + make_community_node("n1", file_path="src/auth/login.py", name="login", tags=["auth"]), + make_community_node("n2", file_path="src/auth/login.py", name="validate", tags=["auth"]), + make_community_node("n3", file_path="src/auth/logout.py", name="logout", tags=["auth"]), + ] + await store.upsert_nodes(nodes) + + await store.upsert_edge(GraphEdge(source_id="n1", target_id="n2", type=EdgeType.CALLS)) + await store.upsert_edge(GraphEdge(source_id="n2", target_id="n3", type=EdgeType.CALLS)) + await store.upsert_edge(GraphEdge(source_id="n3", target_id="n1", type=EdgeType.CALLS)) + + detector = CommunityDetector(graph_store=store, min_community_size=1) + all_nodes = await store.find_nodes() + adjacency = await detector._build_adjacency(all_nodes, [EdgeType.CALLS]) + + assignments = detector._louvain(all_nodes, adjacency, resolution=1.0) + communities = detector._build_communities(assignments, all_nodes, adjacency, level=0, label="coarse") + + assert len(communities) >= 1 + 
for comm in communities.values(): + assert comm.label.startswith("coarse") + assert comm.member_count >= 1 + assert comm.level == 0 + + async def test_build_communities_file_count(self) -> None: + store = MockGraphStore() + await store.connect() + + nodes = [ + make_community_node("n1", file_path="src/auth/login.py", name="login"), + make_community_node("n2", file_path="src/auth/login.py", name="validate"), + make_community_node("n3", file_path="src/auth/logout.py", name="logout"), + make_community_node("n4", file_path="src/core/util.py", name="helper"), + ] + await store.upsert_nodes(nodes) + + await store.upsert_edge(GraphEdge(source_id="n1", target_id="n2", type=EdgeType.CALLS)) + await store.upsert_edge(GraphEdge(source_id="n2", target_id="n3", type=EdgeType.CALLS)) + await store.upsert_edge(GraphEdge(source_id="n4", target_id="n1", type=EdgeType.CALLS)) + + detector = CommunityDetector(graph_store=store, min_community_size=1) + all_nodes = await store.find_nodes() + adjacency = await detector._build_adjacency(all_nodes, [EdgeType.CALLS]) + + assignments = detector._louvain(all_nodes, adjacency, resolution=1.0) + communities = detector._build_communities(assignments, all_nodes, adjacency, level=0, label="test") + + for comm in communities.values(): + assert comm.file_count >= 1 + + +class TestComputeModularity: + """Test CommunityDetector._compute_modularity().""" + + async def test_good_partition_positive_modularity(self) -> None: + store = MockGraphStore() + await store.connect() + + nodes = [ + make_community_node("a1", file_path="src/pkg/file.py", name="a1"), + make_community_node("a2", file_path="src/pkg/file.py", name="a2"), + make_community_node("b1", file_path="src/pkg/file.py", name="b1"), + make_community_node("b2", file_path="src/pkg/file.py", name="b2"), + ] + await store.upsert_nodes(nodes) + + await store.upsert_edge(GraphEdge(source_id="a1", target_id="a2", type=EdgeType.CALLS)) + await store.upsert_edge(GraphEdge(source_id="a2", target_id="a1", type=EdgeType.CALLS)) + await store.upsert_edge(GraphEdge(source_id="b1", target_id="b2", type=EdgeType.CALLS)) + await store.upsert_edge(GraphEdge(source_id="b2", target_id="b1", type=EdgeType.CALLS)) + + detector = CommunityDetector(graph_store=store) + all_nodes = await store.find_nodes() + adjacency = await detector._build_adjacency(all_nodes, [EdgeType.CALLS]) + + louvain_assignments = detector._louvain(all_nodes, adjacency, resolution=1.0) + modularity = detector._compute_modularity(louvain_assignments, adjacency) + + assert modularity > 0 + + async def test_random_partition_near_zero_modularity(self) -> None: + store = MockGraphStore() + await store.connect() + + nodes = [ + make_community_node("n1", file_path="src/pkg/file.py", name="n1"), + make_community_node("n2", file_path="src/pkg/file.py", name="n2"), + make_community_node("n3", file_path="src/pkg/file.py", name="n3"), + make_community_node("n4", file_path="src/pkg/file.py", name="n4"), + ] + await store.upsert_nodes(nodes) + + await store.upsert_edge(GraphEdge(source_id="n1", target_id="n2", type=EdgeType.CALLS)) + await store.upsert_edge(GraphEdge(source_id="n2", target_id="n3", type=EdgeType.CALLS)) + await store.upsert_edge(GraphEdge(source_id="n3", target_id="n1", type=EdgeType.CALLS)) + await store.upsert_edge(GraphEdge(source_id="n4", target_id="n1", type=EdgeType.CALLS)) + + detector = CommunityDetector(graph_store=store) + all_nodes = await store.find_nodes() + adjacency = await detector._build_adjacency(all_nodes, [EdgeType.CALLS]) + + random_assignments = 
{node.id: f"comm_{i % 2}" for i, node in enumerate(all_nodes)} + modularity = detector._compute_modularity(random_assignments, adjacency) + + assert modularity <= 1.0 + + async def test_empty_adjacency_returns_zero(self) -> None: + detector = CommunityDetector(graph_store=MockGraphStore()) + modularity = detector._compute_modularity({}, {}) + assert modularity == 0.0 + + +class TestDetectBridges: + """Test CommunityDetector._detect_bridges().""" + + async def test_bridges_detected_between_communities(self) -> None: + store = MockGraphStore() + await store.connect() + + nodes = [ + make_community_node("a1", file_path="src/pkg_a/file.py", name="func_a1"), + make_community_node("a2", file_path="src/pkg_a/file.py", name="func_a2"), + make_community_node("b1", file_path="src/pkg_b/file.py", name="func_b1"), + make_community_node("b2", file_path="src/pkg_b/file.py", name="func_b2"), + ] + await store.upsert_nodes(nodes) + + await store.upsert_edge(GraphEdge(source_id="a1", target_id="a2", type=EdgeType.CALLS)) + await store.upsert_edge(GraphEdge(source_id="a2", target_id="a1", type=EdgeType.CALLS)) + await store.upsert_edge(GraphEdge(source_id="b1", target_id="b2", type=EdgeType.CALLS)) + await store.upsert_edge(GraphEdge(source_id="b2", target_id="b1", type=EdgeType.CALLS)) + await store.upsert_edge(GraphEdge(source_id="a1", target_id="b1", type=EdgeType.CALLS)) + + detector = CommunityDetector(graph_store=store) + all_nodes = await store.find_nodes() + adjacency = await detector._build_adjacency(all_nodes, [EdgeType.CALLS]) + + detector._node_communities_l0 = detector._louvain(all_nodes, adjacency, resolution=0.5) + detector._node_communities_l1 = detector._node_communities_l0.copy() + + communities = detector._build_communities( + detector._node_communities_l0, all_nodes, adjacency, level=0, label="coarse" + ) + for cid, comm in communities.items(): + detector._communities[cid] = comm + + bridges = await detector._detect_bridges(all_nodes, adjacency) + + assert len(bridges) >= 0 + for bridge in bridges: + assert "from_community" in bridge + assert "to_community" in bridge + assert bridge["from_community"] != bridge["to_community"] + + async def test_no_bridges_in_single_community(self) -> None: + store = MockGraphStore() + await store.connect() + + nodes = [ + make_community_node("n1", file_path="src/pkg/file.py", name="func1"), + make_community_node("n2", file_path="src/pkg/file.py", name="func2"), + make_community_node("n3", file_path="src/pkg/file.py", name="func3"), + ] + await store.upsert_nodes(nodes) + + await store.upsert_edge(GraphEdge(source_id="n1", target_id="n2", type=EdgeType.CALLS)) + await store.upsert_edge(GraphEdge(source_id="n2", target_id="n3", type=EdgeType.CALLS)) + await store.upsert_edge(GraphEdge(source_id="n3", target_id="n1", type=EdgeType.CALLS)) + + detector = CommunityDetector(graph_store=store) + all_nodes = await store.find_nodes() + adjacency = await detector._build_adjacency(all_nodes, [EdgeType.CALLS]) + + detector._node_communities_l0 = detector._louvain(all_nodes, adjacency, resolution=1.0) + detector._node_communities_l1 = detector._node_communities_l0.copy() + + communities = detector._build_communities( + detector._node_communities_l0, all_nodes, adjacency, level=0, label="test" + ) + for cid, comm in communities.items(): + detector._communities[cid] = comm + + bridges = await detector._detect_bridges(all_nodes, adjacency) + + assert len(bridges) == 0 + + +class TestListCommunities: + """Test CommunityDetector.list_communities().""" + + async def 
test_list_empty_before_detection(self) -> None: + store = MockGraphStore() + await store.connect() + detector = CommunityDetector(graph_store=store) + + result = await detector.list_communities() + + assert result["total"] == 0 + assert result["communities"] == [] + + async def test_list_after_detection(self) -> None: + store = MockGraphStore() + await store.connect() + + nodes = [ + make_community_node("a1", file_path="src/auth/login.py", name="login", tags=["auth"]), + make_community_node("a2", file_path="src/auth/login.py", name="validate", tags=["auth"]), + make_community_node("a3", file_path="src/auth/logout.py", name="logout", tags=["auth"]), + ] + await store.upsert_nodes(nodes) + + await store.upsert_edge(GraphEdge(source_id="a1", target_id="a2", type=EdgeType.CALLS)) + await store.upsert_edge(GraphEdge(source_id="a2", target_id="a3", type=EdgeType.CALLS)) + + detector = CommunityDetector(graph_store=store, min_community_size=1) + await detector.detect() + + result = await detector.list_communities() + + assert result["total"] >= 1 + assert len(result["communities"]) >= 1 + + async def test_list_filtered_by_level(self) -> None: + store = MockGraphStore() + await store.connect() + + nodes = [ + make_community_node("n1", file_path="src/pkg/file.py", name="func1", tags=["test"]), + make_community_node("n2", file_path="src/pkg/file.py", name="func2", tags=["test"]), + ] + await store.upsert_nodes(nodes) + + await store.upsert_edge(GraphEdge(source_id="n1", target_id="n2", type=EdgeType.CALLS)) + + detector = CommunityDetector(graph_store=store, min_community_size=1) + await detector.detect() + + result_l0 = await detector.list_communities(level=0) + result_l1 = await detector.list_communities(level=1) + + for comm in result_l0["communities"]: + assert comm["level"] == 0 + for comm in result_l1["communities"]: + assert comm["level"] == 1 + + +class TestGetCommunity: + """Test CommunityDetector.get_community().""" + + async def test_get_nonexistent_community(self) -> None: + store = MockGraphStore() + await store.connect() + detector = CommunityDetector(graph_store=store) + + result = await detector.get_community("nonexistent_id") + + assert result is None + + async def test_get_existing_community(self) -> None: + store = MockGraphStore() + await store.connect() + + nodes = [ + make_community_node("n1", file_path="src/auth/login.py", name="login", tags=["auth"]), + make_community_node("n2", file_path="src/auth/login.py", name="validate", tags=["auth"]), + ] + await store.upsert_nodes(nodes) + + await store.upsert_edge(GraphEdge(source_id="n1", target_id="n2", type=EdgeType.CALLS)) + + detector = CommunityDetector(graph_store=store, min_community_size=1) + await detector.detect() + + comm_ids = list(detector._communities.keys()) + assert len(comm_ids) >= 1 + + result = await detector.get_community(comm_ids[0]) + + assert result is not None + assert "community_id" in result + assert "members" in result + + async def test_get_community_with_node_type_filter(self) -> None: + store = MockGraphStore() + await store.connect() + + nodes = [ + make_community_node("n1", file_path="src/auth/login.py", name="login", tags=["auth"]), + make_community_node("n2", file_path="src/auth/login.py", name="validate", tags=["auth"]), + ] + await store.upsert_nodes(nodes) + + await store.upsert_edge(GraphEdge(source_id="n1", target_id="n2", type=EdgeType.CALLS)) + + detector = CommunityDetector(graph_store=store, min_community_size=1) + await detector.detect() + + comm_ids = list(detector._communities.keys()) + 
result = await detector.get_community(comm_ids[0], node_types=["Function"]) + + assert result is not None + for member in result["members"]: + assert member["type"] == "Function" + + +class TestGetBoundaries: + """Test CommunityDetector.get_boundaries().""" + + async def test_get_boundaries_empty_before_detection(self) -> None: + store = MockGraphStore() + await store.connect() + detector = CommunityDetector(graph_store=store) + + result = await detector.get_boundaries() + + assert result["level"] == 0 + assert result["boundaries"] == [] + + async def test_get_boundaries_after_detection(self) -> None: + store = MockGraphStore() + await store.connect() + + nodes = [ + make_community_node("a1", file_path="src/pkg_a/file.py", name="func_a1"), + make_community_node("a2", file_path="src/pkg_a/file.py", name="func_a2"), + make_community_node("b1", file_path="src/pkg_b/file.py", name="func_b1"), + ] + await store.upsert_nodes(nodes) + + await store.upsert_edge(GraphEdge(source_id="a1", target_id="a2", type=EdgeType.CALLS)) + await store.upsert_edge(GraphEdge(source_id="a2", target_id="a1", type=EdgeType.CALLS)) + await store.upsert_edge(GraphEdge(source_id="a1", target_id="b1", type=EdgeType.CALLS)) + + detector = CommunityDetector(graph_store=store, min_community_size=1) + await detector.detect() + + result = await detector.get_boundaries() + + assert "level" in result + assert "boundaries" in result + + async def test_get_boundaries_with_min_coupling(self) -> None: + store = MockGraphStore() + await store.connect() + + nodes = [ + make_community_node("a1", file_path="src/pkg_a/file.py", name="func_a1"), + make_community_node("a2", file_path="src/pkg_a/file.py", name="func_a2"), + make_community_node("b1", file_path="src/pkg_b/file.py", name="func_b1"), + ] + await store.upsert_nodes(nodes) + + await store.upsert_edge(GraphEdge(source_id="a1", target_id="a2", type=EdgeType.CALLS)) + await store.upsert_edge(GraphEdge(source_id="a1", target_id="b1", type=EdgeType.CALLS)) + + detector = CommunityDetector(graph_store=store, min_community_size=1) + await detector.detect() + + result = await detector.get_boundaries(level=0, min_coupling=0.5) + + assert "boundaries" in result \ No newline at end of file diff --git a/tests/test_integration_merkle.py b/tests/test_integration_merkle.py new file mode 100644 index 0000000..bc9b390 --- /dev/null +++ b/tests/test_integration_merkle.py @@ -0,0 +1,219 @@ +"""Integration tests for MerkleTree and MerkleIndex.""" + +from __future__ import annotations + +from smp.core.merkle import MerkleIndex, MerkleTree +from smp.core.models import GraphNode, NodeType, SemanticProperties, StructuralProperties + + +def make_file_node( + node_id: str = "file_1", + file_path: str = "src/app.py", + source_hash: str = "abc123", +) -> GraphNode: + """Create a FILE-type GraphNode for testing MerkleTree.""" + return GraphNode( + id=node_id, + type=NodeType.FILE, + file_path=file_path, + structural=StructuralProperties(name=file_path, file=file_path), + semantic=SemanticProperties(source_hash=source_hash), + ) + + +class TestMerkleTree: + """Tests for MerkleTree.build(), hash(), diff(), export(), import_data().""" + + def test_build_with_file_nodes(self) -> None: + """Build tree from FILE nodes - no exceptions.""" + nodes = [ + make_file_node("file_1", "src/a.py", "hash1"), + make_file_node("file_2", "src/b.py", "hash2"), + make_file_node("file_3", "src/c.py", "hash3"), + ] + tree = MerkleTree() + tree.build(nodes) + assert tree.hash() != "" + + def test_build_with_mixed_node_types(self) 
-> None: + """Build tree ignores non-FILE nodes.""" + nodes = [ + make_file_node("file_1", "src/a.py", "hash1"), + GraphNode( + id="func_1", + type=NodeType.FUNCTION, + file_path="src/a.py", + structural=StructuralProperties(name="test_func"), + semantic=SemanticProperties(), + ), + ] + tree = MerkleTree() + tree.build(nodes) + assert tree.hash() != "" + + def test_build_empty(self) -> None: + """Build with no FILE nodes results in empty _levels.""" + tree = MerkleTree() + tree.build([]) + assert tree._levels == [[]] + + def test_hash_deterministic(self) -> None: + """Same nodes produce same hash.""" + nodes = [ + make_file_node("file_1", "src/a.py", "hash1"), + make_file_node("file_2", "src/b.py", "hash2"), + ] + tree1 = MerkleTree() + tree1.build(nodes) + tree2 = MerkleTree() + tree2.build(nodes) + assert tree1.hash() == tree2.hash() + + def test_hash_different_nodes_different_hash(self) -> None: + """Different nodes produce different hashes.""" + tree1 = MerkleTree() + tree1.build([make_file_node("file_1", "src/a.py", "hash1")]) + tree2 = MerkleTree() + tree2.build([make_file_node("file_2", "src/b.py", "hash2")]) + assert tree1.hash() != tree2.hash() + + def test_diff_added(self) -> None: + """Node added to remote appears in added set.""" + local_nodes = [make_file_node("file_1", "src/a.py", "hash1")] + remote_nodes = [ + make_file_node("file_1", "src/a.py", "hash1"), + make_file_node("file_2", "src/b.py", "hash2"), + ] + local = MerkleTree() + local.build(local_nodes) + remote = MerkleTree() + remote.build(remote_nodes) + diff = local.diff(remote) + assert "file_2" in diff["added"] + assert diff["removed"] == set() + assert diff["modified"] == set() + + def test_diff_removed(self) -> None: + """Node removed from remote appears in removed set.""" + local_nodes = [ + make_file_node("file_1", "src/a.py", "hash1"), + make_file_node("file_2", "src/b.py", "hash2"), + ] + remote_nodes = [make_file_node("file_1", "src/a.py", "hash1")] + local = MerkleTree() + local.build(local_nodes) + remote = MerkleTree() + remote.build(remote_nodes) + diff = local.diff(remote) + assert diff["added"] == set() + assert "file_2" in diff["removed"] + assert diff["modified"] == set() + + def test_diff_modified(self) -> None: + """Node with changed hash appears in modified set.""" + local_nodes = [make_file_node("file_1", "src/a.py", "hash1")] + remote_nodes = [make_file_node("file_1", "src/a.py", "hash2")] + local = MerkleTree() + local.build(local_nodes) + remote = MerkleTree() + remote.build(remote_nodes) + diff = local.diff(remote) + assert diff["added"] == set() + assert diff["removed"] == set() + assert "file_1" in diff["modified"] + + def test_diff_no_changes(self) -> None: + """Identical trees have no added/removed/modified.""" + nodes = [ + make_file_node("file_1", "src/a.py", "hash1"), + make_file_node("file_2", "src/b.py", "hash2"), + ] + local = MerkleTree() + local.build(nodes) + remote = MerkleTree() + remote.build(nodes) + diff = local.diff(remote) + assert diff["added"] == set() + assert diff["removed"] == set() + assert diff["modified"] == set() + + def test_export_returns_dict(self) -> None: + """Export returns dict with expected keys.""" + tree = MerkleTree() + tree.build([make_file_node("file_1", "src/a.py", "hash1")]) + exported = tree.export() + assert "root" in exported + assert "levels" in exported + assert "leaf_hashes" in exported + assert exported["root"] == tree.hash() + + def test_export_deterministic(self) -> None: + """Same tree exports to same structure.""" + nodes = 
[make_file_node("file_1", "src/a.py", "hash1")] + tree1 = MerkleTree() + tree1.build(nodes) + exp1 = tree1.export() + tree2 = MerkleTree() + tree2.build(nodes) + exp2 = tree2.export() + assert exp1["root"] == exp2["root"] + + def test_import_recreates_hash(self) -> None: + """Import reconstructs tree with same hash.""" + nodes = [ + make_file_node("file_1", "src/a.py", "hash1"), + make_file_node("file_2", "src/b.py", "hash2"), + ] + original = MerkleTree() + original.build(nodes) + exported = original.export() + + restored = MerkleTree() + restored.import_data(exported) + assert restored.hash() == original.hash() + + def test_roundtrip_lossless(self) -> None: + """Round-trip export/import is lossless.""" + nodes = [ + make_file_node("file_1", "src/a.py", "hash1"), + make_file_node("file_2", "src/b.py", "hash2"), + ] + original = MerkleTree() + original.build(nodes) + exported = original.export() + + restored = MerkleTree() + restored.import_data(exported) + restored_exp = restored.export() + + assert exported["root"] == restored_exp["root"] + assert exported["leaf_hashes"] == restored_exp["leaf_hashes"] + + +class TestMerkleIndex: + """Tests for MerkleIndex.sync() and apply_patch().""" + + def test_sync_in_sync(self) -> None: + """When local_hash == remote_hash, sync returns None.""" + nodes = [make_file_node("file_1", "src/a.py", "hash1")] + tree = MerkleTree() + tree.build(nodes) + index = MerkleIndex(tree) + result = index.sync(tree.hash()) + assert result is None + + def test_sync_different_hash(self) -> None: + """When different, sync returns None (triggers diff).""" + tree = MerkleTree() + tree.build([make_file_node("file_1", "src/a.py", "hash1")]) + index = MerkleIndex(tree) + result = index.sync("different_hash") + assert result is None + + def test_apply_patch_logs(self) -> None: + """apply_patch executes without error.""" + tree = MerkleTree() + tree.build([make_file_node("file_1", "src/a.py", "hash1")]) + index = MerkleIndex(tree) + patch = {"added": ["file_2"], "removed": [], "modified": []} + index.apply_patch(patch) diff --git a/tests/test_integration_parser_graph.py b/tests/test_integration_parser_graph.py new file mode 100644 index 0000000..13ad132 --- /dev/null +++ b/tests/test_integration_parser_graph.py @@ -0,0 +1,292 @@ +"""Integration tests for ParserRegistry and DefaultGraphBuilder using sample_project.""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +from smp.core.models import Document, EdgeType, Language, NodeType +from smp.engine.graph_builder import DefaultGraphBuilder +from smp.parser.registry import ParserRegistry +from smp.store.graph.neo4j_store import Neo4jGraphStore + + +FIXTURE_PATH = Path("/home/bhagyarekhab/SMP/tests/fixtures/sample_project/src") + + +class TestParserRegistryIntegration: + """Test ParserRegistry parsing all files in sample_project/src.""" + + @pytest.fixture() + def registry(self) -> ParserRegistry: + return ParserRegistry() + + def _get_python_files(self) -> list[Path]: + """Get all Python files from the fixture directory.""" + return list(FIXTURE_PATH.rglob("*.py")) + + def test_finds_python_files(self, registry: ParserRegistry) -> None: + """Verify fixture directory has Python files.""" + files = self._get_python_files() + assert len(files) > 0, "No Python files found in fixture directory" + + @pytest.mark.asyncio + async def test_parse_auth_service(self, registry: ParserRegistry) -> None: + """Parse auth_service.py and verify nodes extracted.""" + doc = registry.parse_file(str(FIXTURE_PATH / 
"auth/auth_service.py")) + + assert doc.language == Language.PYTHON + assert len(doc.errors) == 0, f"Parse errors: {doc.errors}" + + functions = [n for n in doc.nodes if n.type == NodeType.FUNCTION] + classes = [n for n in doc.nodes if n.type == NodeType.CLASS] + imports = [n for n in doc.nodes if n.type == NodeType.FILE and "import" in n.structural.signature] + + assert len(functions) >= 4, f"Expected 4+ functions, got {len(functions)}: {[f.structural.name for f in functions]}" + assert len(classes) >= 1, f"Expected 1+ classes, got {len(classes)}" + assert len(imports) >= 3, f"Expected 3+ imports, got {len(imports)}" + + @pytest.mark.asyncio + async def test_parse_db_models(self, registry: ParserRegistry) -> None: + """Parse db/models.py and verify nodes extracted.""" + doc = registry.parse_file(str(FIXTURE_PATH / "db/models.py")) + + assert doc.language == Language.PYTHON + assert len(doc.errors) == 0, f"Parse errors: {doc.errors}" + + classes = [n for n in doc.nodes if n.type == NodeType.CLASS] + functions = [n for n in doc.nodes if n.type == NodeType.FUNCTION] + + assert len(classes) >= 1, f"Expected 1+ classes, got {len(classes)}" + assert len(functions) >= 0, f"Expected 0+ functions, got {len(functions)}" + + @pytest.mark.asyncio + async def test_parse_api_routes(self, registry: ParserRegistry) -> None: + """Parse api/routes.py and verify nodes extracted.""" + doc = registry.parse_file(str(FIXTURE_PATH / "api/routes.py")) + + assert doc.language == Language.PYTHON + assert len(doc.errors) == 0, f"Parse errors: {doc.errors}" + + functions = [n for n in doc.nodes if n.type == NodeType.FUNCTION] + imports = [n for n in doc.nodes if n.type == NodeType.FILE and "import" in n.structural.signature] + + assert len(functions) >= 3, f"Expected 3+ functions, got {len(functions)}: {[f.structural.name for f in functions]}" + assert len(imports) >= 2, f"Expected 2+ imports, got {len(imports)}" + + @pytest.mark.asyncio + async def test_parse_all_files(self, registry: ParserRegistry) -> None: + """Parse all Python files in fixture directory.""" + files = self._get_python_files() + results: list[tuple[str, Document]] = [] + + for f in files: + doc = registry.parse_file(str(f)) + results.append((str(f), doc)) + + total_nodes = sum(len(doc.nodes) for _, doc in results) + total_errors = sum(len(doc.errors) for _, doc in results) + + assert len(results) == len(files), f"Not all files were parsed" + assert total_nodes > 0, "No nodes extracted from any file" + assert total_errors == 0, f"Parse errors in files: {[(f, doc.errors) for f, doc in results if doc.errors]}" + + +class TestGraphBuilderIntegration: + """Test DefaultGraphBuilder ingesting parsed nodes into Neo4j.""" + + @pytest.fixture() + async def store(self, clean_graph: Neo4jGraphStore) -> Neo4jGraphStore: + await clean_graph.connect() + return clean_graph + + @pytest.fixture() + def builder(self, store: Neo4jGraphStore) -> DefaultGraphBuilder: + return DefaultGraphBuilder(store) + + @pytest.fixture() + def registry(self) -> ParserRegistry: + return ParserRegistry() + + @pytest.mark.asyncio + async def test_ingest_auth_service( + self, + store: Neo4jGraphStore, + builder: DefaultGraphBuilder, + registry: ParserRegistry, + ) -> None: + """Parse auth_service.py and ingest into graph, verify nodes and edges.""" + doc = registry.parse_file(str(FIXTURE_PATH / "auth/auth_service.py")) + assert len(doc.errors) == 0, f"Parse errors: {doc.errors}" + + await builder.ingest_document(doc) + + stored_nodes = await store.find_nodes() + node_names = 
{n.structural.name for n in stored_nodes}
+
+        expected_functions = ["hash_password", "verify_password", "generate_token", "login", "logout", "verify_token", "get_current_user"]
+        expected_classes = ["AuthService"]
+
+        for func_name in expected_functions:
+            assert func_name in node_names, f"Function {func_name} not found in graph. Available: {node_names}"
+
+        for class_name in expected_classes:
+            assert class_name in node_names, f"Class {class_name} not found in graph. Available: {node_names}"
+
+        edges = await store.get_edges(limit=100)
+        edge_types = {e.type for e in edges}
+
+        assert EdgeType.DEFINES in edge_types, f"Missing DEFINES edges. Types found: {edge_types}"
+        assert EdgeType.CALLS in edge_types or EdgeType.IMPORTS in edge_types, f"Missing relationship edges. Types found: {edge_types}"
+
+    @pytest.mark.asyncio
+    async def test_ingest_db_models(
+        self,
+        store: Neo4jGraphStore,
+        builder: DefaultGraphBuilder,
+        registry: ParserRegistry,
+    ) -> None:
+        """Parse db/models.py and ingest into graph."""
+        doc = registry.parse_file(str(FIXTURE_PATH / "db/models.py"))
+        assert len(doc.errors) == 0, f"Parse errors: {doc.errors}"
+
+        await builder.ingest_document(doc)
+
+        stored_nodes = await store.find_nodes()
+        classes = [n for n in stored_nodes if n.type == NodeType.CLASS]
+        class_names = {c.structural.name for c in classes}
+
+        assert len(classes) >= 1, f"Expected 1+ classes, got {len(classes)}: {class_names}"
+
+    @pytest.mark.asyncio
+    async def test_ingest_api_routes(
+        self,
+        store: Neo4jGraphStore,
+        builder: DefaultGraphBuilder,
+        registry: ParserRegistry,
+    ) -> None:
+        """Parse api/routes.py and ingest into graph."""
+        doc = registry.parse_file(str(FIXTURE_PATH / "api/routes.py"))
+        assert len(doc.errors) == 0, f"Parse errors: {doc.errors}"
+
+        await builder.ingest_document(doc)
+
+        stored_nodes = await store.find_nodes()
+        functions = [n for n in stored_nodes if n.type == NodeType.FUNCTION]
+        function_names = {f.structural.name for f in functions}
+
+        assert len(functions) >= 2, f"Expected 2+ functions, got {len(functions)}: {function_names}"
+
+    @pytest.mark.asyncio
+    async def test_ingest_all_files(
+        self,
+        store: Neo4jGraphStore,
+        builder: DefaultGraphBuilder,
+        registry: ParserRegistry,
+    ) -> None:
+        """Parse and ingest all Python files, verify complete graph."""
+        files = list(FIXTURE_PATH.rglob("*.py"))
+
+        for f in files:
+            doc = registry.parse_file(str(f))
+            if doc.errors:
+                pytest.fail(f"Parse errors in {f}: {doc.errors}")
+            await builder.ingest_document(doc)
+
+        stored_nodes = await store.find_nodes(limit=500)
+        stored_edges = await store.get_edges(limit=500)
+
+        assert len(stored_nodes) > 0, "No nodes stored in graph"
+        assert len(stored_edges) > 0, "No edges stored in graph"
+
+        node_types = {n.type for n in stored_nodes}
+        edge_types = {e.type for e in stored_edges}
+
+        assert NodeType.FUNCTION in node_types, f"No functions found. Types: {node_types}"
+        assert NodeType.CLASS in node_types, f"No classes found. Types: {node_types}"
+
+        assert EdgeType.DEFINES in edge_types, f"No DEFINES edges. Types: {edge_types}"
+        assert EdgeType.CALLS in edge_types or EdgeType.IMPORTS in edge_types, f"No CALLS or IMPORTS edges. 
Types: {edge_types}" + + +class TestEdgeCreation: + """Test that specific edge types (CALLS, IMPORTS, DEFINES) are created correctly.""" + + @pytest.fixture() + async def store(self, clean_graph: Neo4jGraphStore) -> Neo4jGraphStore: + await clean_graph.connect() + return clean_graph + + @pytest.fixture() + def builder(self, store: Neo4jGraphStore) -> DefaultGraphBuilder: + return DefaultGraphBuilder(store) + + @pytest.fixture() + def registry(self) -> ParserRegistry: + return ParserRegistry() + + @pytest.mark.asyncio + async def test_defines_edges_exist( + self, + store: Neo4jGraphStore, + builder: DefaultGraphBuilder, + registry: ParserRegistry, + ) -> None: + """Verify DEFINES edges connect File/Class nodes to their children.""" + doc = registry.parse_file(str(FIXTURE_PATH / "auth/auth_service.py")) + await builder.ingest_document(doc) + + edges = await store.get_edges(limit=200) + defines_edges = [e for e in edges if e.type == EdgeType.DEFINES] + + assert len(defines_edges) > 0, "No DEFINES edges created" + + @pytest.mark.asyncio + async def test_import_edges_exist( + self, + store: Neo4jGraphStore, + builder: DefaultGraphBuilder, + registry: ParserRegistry, + ) -> None: + """Verify IMPORTS edges are created for import statements.""" + doc = registry.parse_file(str(FIXTURE_PATH / "auth/auth_service.py")) + await builder.ingest_document(doc) + + edges = await store.get_edges(limit=200) + import_edges = [e for e in edges if e.type == EdgeType.IMPORTS] + + assert len(import_edges) >= 3, f"Expected 3+ IMPORTS edges, got {len(import_edges)}" + + @pytest.mark.asyncio + async def test_calls_edges_exist( + self, + store: Neo4jGraphStore, + builder: DefaultGraphBuilder, + registry: ParserRegistry, + ) -> None: + """Verify CALLS edges exist for function invocations.""" + doc = registry.parse_file(str(FIXTURE_PATH / "auth/auth_service.py")) + await builder.ingest_document(doc) + + edges = await store.get_edges(limit=200) + calls_edges = [e for e in edges if e.type == EdgeType.CALLS] + + assert len(calls_edges) > 0, "No CALLS edges created" + + @pytest.mark.asyncio + async def test_class_method_defines( + self, + store: Neo4jGraphStore, + builder: DefaultGraphBuilder, + registry: ParserRegistry, + ) -> None: + """Verify DEFINES edges from Class to its methods.""" + doc = registry.parse_file(str(FIXTURE_PATH / "auth/auth_service.py")) + await builder.ingest_document(doc) + + edges = await store.get_edges(limit=200) + defines_edges = [e for e in edges if e.type == EdgeType.DEFINES] + + auth_service_defines = [e for e in defines_edges if "AuthService" in e.source_id] + assert len(auth_service_defines) >= 3, f"Expected 3+ AuthService method defines, got {len(auth_service_defines)}" diff --git a/tests/test_integration_protocol_handlers.py b/tests/test_integration_protocol_handlers.py new file mode 100644 index 0000000..9086fb8 --- /dev/null +++ b/tests/test_integration_protocol_handlers.py @@ -0,0 +1,343 @@ +"""Integration tests for SMP Protocol Handlers via the dispatcher.""" + +from __future__ import annotations + +from smp.protocol.dispatcher import RpcDispatcher +from smp.protocol.handlers.annotation import ( + AnnotateBulkHandler, + AnnotateHandler, + TagHandler, +) +from smp.protocol.handlers.community import ( + CommunityBoundariesHandler, + CommunityDetectHandler, + CommunityGetHandler, + CommunityListHandler, +) +from smp.protocol.handlers.enrichment import ( + EnrichBatchHandler, + EnrichHandler, + EnrichStaleHandler, + EnrichStatusHandler, +) +from smp.protocol.handlers.handoff import ( + 
HandoffPRHandler, + HandoffReviewHandler, +) +from smp.protocol.handlers.memory import ( + BatchUpdateHandler, + ReindexHandler, + UpdateHandler, +) +from smp.protocol.handlers.merkle import ( + IndexExportHandler, + IndexImportHandler, + MerkleTreeHandler, + SyncHandler, +) +from smp.protocol.handlers.query import ( + ContextHandler, + FlowHandler, + ImpactHandler, + LocateHandler, + NavigateHandler, + SearchHandler, + TraceHandler, +) +from smp.protocol.handlers.safety import ( + AuditGetHandler, + CheckpointHandler, + DryRunHandler, + GuardCheckHandler, + IntegrityVerifyHandler, + LockHandler, + RollbackHandler, + SessionCloseHandler, + SessionOpenHandler, + SessionRecoverHandler, + UnlockHandler, +) +from smp.protocol.handlers.sandbox import ( + SandboxDestroyHandler, + SandboxExecuteHandler, + SandboxSpawnHandler, +) +from smp.protocol.handlers.telemetry import ( + TelemetryHandler, + TelemetryHotHandler, + TelemetryNodeHandler, + TelemetryRecordHandler, +) + + +class TestHandlerRegistration: + """Test that all registered handlers are reachable.""" + + def test_all_registered_handlers_have_valid_method(self): + """Each handler in dispatcher must have a valid non-empty method name.""" + dispatcher = RpcDispatcher() + for method, handler in dispatcher._handlers.items(): + assert method, f"Handler {handler.__class__.__name__} has empty method" + assert isinstance(method, str), f"Handler method must be str, got {type(method)}" + assert handler.method == method, ( + f"Handler method mismatch: expected '{method}', " + f"got '{handler.method}' from {handler.__class__.__name__}" + ) + + +class TestHandlerInstantiation: + """Test each handler class can be instantiated without errors.""" + + def test_safety_handlers(self): + """Safety handlers can be instantiated.""" + handlers = [ + SessionOpenHandler, + SessionCloseHandler, + SessionRecoverHandler, + GuardCheckHandler, + DryRunHandler, + CheckpointHandler, + RollbackHandler, + LockHandler, + UnlockHandler, + AuditGetHandler, + IntegrityVerifyHandler, + ] + for handler_cls in handlers: + handler = handler_cls() + assert handler.method.startswith("smp/"), f"{handler_cls.__name__} has invalid method: {handler.method}" + + def test_query_handlers(self): + """Query handlers can be instantiated.""" + handlers = [ + NavigateHandler, + TraceHandler, + ContextHandler, + ImpactHandler, + LocateHandler, + SearchHandler, + FlowHandler, + ] + for handler_cls in handlers: + handler = handler_cls() + assert handler.method.startswith("smp/"), f"{handler_cls.__name__} has invalid method: {handler.method}" + + def test_community_handlers(self): + """Community handlers can be instantiated.""" + handlers = [ + CommunityDetectHandler, + CommunityListHandler, + CommunityGetHandler, + CommunityBoundariesHandler, + ] + for handler_cls in handlers: + handler = handler_cls() + assert handler.method.startswith("smp/"), f"{handler_cls.__name__} has invalid method: {handler.method}" + + def test_merkle_handlers(self): + """Merkle handlers can be instantiated.""" + handlers = [ + SyncHandler, + MerkleTreeHandler, + IndexExportHandler, + IndexImportHandler, + ] + for handler_cls in handlers: + handler = handler_cls() + assert handler.method.startswith("smp/"), f"{handler_cls.__name__} has invalid method: {handler.method}" + + def test_handoff_handlers(self): + """Handoff handlers can be instantiated.""" + handlers = [ + HandoffReviewHandler, + HandoffPRHandler, + ] + for handler_cls in handlers: + handler = handler_cls() + assert handler.method.startswith("smp/"), 
f"{handler_cls.__name__} has invalid method: {handler.method}" + + def test_enrichment_handlers(self): + """Enrichment handlers can be instantiated.""" + handlers = [ + EnrichHandler, + EnrichBatchHandler, + EnrichStaleHandler, + EnrichStatusHandler, + ] + for handler_cls in handlers: + handler = handler_cls() + assert handler.method.startswith("smp/"), f"{handler_cls.__name__} has invalid method: {handler.method}" + + def test_annotation_handlers(self): + """Annotation handlers can be instantiated.""" + handlers = [ + AnnotateHandler, + AnnotateBulkHandler, + TagHandler, + ] + for handler_cls in handlers: + handler = handler_cls() + assert handler.method.startswith("smp/"), f"{handler_cls.__name__} has invalid method: {handler.method}" + + def test_memory_handlers(self): + """Memory handlers can be instantiated.""" + handlers = [ + UpdateHandler, + BatchUpdateHandler, + ReindexHandler, + ] + for handler_cls in handlers: + handler = handler_cls() + assert handler.method.startswith("smp/"), f"{handler_cls.__name__} has invalid method: {handler.method}" + + def test_sandbox_handlers(self): + """Sandbox handlers can be instantiated.""" + handlers = [ + SandboxSpawnHandler, + SandboxExecuteHandler, + SandboxDestroyHandler, + ] + for handler_cls in handlers: + handler = handler_cls() + assert handler.method.startswith("smp/"), f"{handler_cls.__name__} has invalid method: {handler.method}" + + def test_telemetry_handlers(self): + """Telemetry handlers can be instantiated.""" + handlers = [ + TelemetryHandler, + TelemetryHotHandler, + TelemetryNodeHandler, + TelemetryRecordHandler, + ] + for handler_cls in handlers: + handler = handler_cls() + assert handler.method.startswith("smp/"), f"{handler_cls.__name__} has invalid method: {handler.method}" + + +class TestDispatcherHandlerDiscovery: + """Test that all expected handlers are registered in the dispatcher.""" + + def test_safety_handlers_registered(self): + """All safety handlers are registered in dispatcher.""" + dispatcher = RpcDispatcher() + expected_methods = [ + "smp/session/open", + "smp/session/close", + "smp/session/recover", + "smp/guard/check", + "smp/dryrun", + "smp/checkpoint", + "smp/rollback", + "smp/lock", + "smp/unlock", + "smp/audit/get", + "smp/verify/integrity", + ] + for method in expected_methods: + assert dispatcher.get_handler(method) is not None, f"Missing handler for {method}" + + def test_query_handlers_registered(self): + """All query handlers are registered in dispatcher.""" + dispatcher = RpcDispatcher() + expected_methods = [ + "smp/navigate", + "smp/trace", + "smp/context", + "smp/impact", + "smp/locate", + "smp/search", + "smp/flow", + ] + for method in expected_methods: + assert dispatcher.get_handler(method) is not None, f"Missing handler for {method}" + + def test_community_handlers_registered(self): + """All community handlers are registered in dispatcher.""" + dispatcher = RpcDispatcher() + expected_methods = [ + "smp/community/detect", + "smp/community/list", + "smp/community/get", + "smp/community/boundaries", + ] + for method in expected_methods: + assert dispatcher.get_handler(method) is not None, f"Missing handler for {method}" + + def test_merkle_handlers_registered(self): + """All merkle handlers are registered in dispatcher.""" + dispatcher = RpcDispatcher() + expected_methods = [ + "smp/sync", + "smp/merkle/tree", + "smp/index/export", + "smp/index/import", + ] + for method in expected_methods: + assert dispatcher.get_handler(method) is not None, f"Missing handler for {method}" + + def 
test_handoff_handlers_registered(self): + """All handoff handlers are registered in dispatcher.""" + dispatcher = RpcDispatcher() + expected_methods = [ + "smp/handoff/review", + "smp/handoff/pr", + ] + for method in expected_methods: + assert dispatcher.get_handler(method) is not None, f"Missing handler for {method}" + + def test_enrichment_handlers_registered(self): + """All enrichment handlers are registered in dispatcher.""" + dispatcher = RpcDispatcher() + expected_methods = [ + "smp/enrich", + "smp/enrich/batch", + "smp/enrich/stale", + "smp/enrich/status", + ] + for method in expected_methods: + assert dispatcher.get_handler(method) is not None, f"Missing handler for {method}" + + def test_annotation_handlers_registered(self): + """All annotation handlers are registered in dispatcher.""" + dispatcher = RpcDispatcher() + expected_methods = [ + "smp/annotate", + "smp/annotate/bulk", + "smp/tag", + ] + for method in expected_methods: + assert dispatcher.get_handler(method) is not None, f"Missing handler for {method}" + + def test_memory_handlers_registered(self): + """All memory handlers are registered in dispatcher.""" + dispatcher = RpcDispatcher() + expected_methods = [ + "smp/update", + "smp/batch_update", + "smp/reindex", + ] + for method in expected_methods: + assert dispatcher.get_handler(method) is not None, f"Missing handler for {method}" + + def test_sandbox_handlers_registered(self): + """All sandbox handlers are registered in dispatcher.""" + dispatcher = RpcDispatcher() + expected_methods = [ + "smp/sandbox/spawn", + "smp/sandbox/execute", + "smp/sandbox/destroy", + ] + for method in expected_methods: + assert dispatcher.get_handler(method) is not None, f"Missing handler for {method}" + + def test_telemetry_handlers_registered(self): + """All telemetry handlers are registered in dispatcher.""" + dispatcher = RpcDispatcher() + expected_methods = [ + "smp/telemetry", + "smp/telemetry/hot", + "smp/telemetry/node", + "smp/telemetry/record", + ] + for method in expected_methods: + assert dispatcher.get_handler(method) is not None, f"Missing handler for {method}" diff --git a/tests/test_integration_query_engine.py b/tests/test_integration_query_engine.py new file mode 100644 index 0000000..cbf1d6c --- /dev/null +++ b/tests/test_integration_query_engine.py @@ -0,0 +1,799 @@ +"""Integration tests for Query Engine components — SMP(3). + +Tests DefaultQueryEngine, SeedWalkEngine, and ChromaVectorStore. 
+""" + +from __future__ import annotations + +from typing import Any + +import pytest + +from smp.core.models import ( + EdgeType, + GraphEdge, + GraphNode, + NodeType, + SemanticProperties, + StructuralProperties, +) +from smp.engine.query import DefaultQueryEngine +from smp.engine.seed_walk import SeedWalkEngine, _simple_hash_embedding + +try: + from smp.store.chroma_store import ChromaVectorStore + + CHROMA_AVAILABLE = True +except Exception: + CHROMA_AVAILABLE = False + ChromaVectorStore = None + + +def make_node( + id: str = "func_login", + type: NodeType = NodeType.FUNCTION, + name: str = "login", + file_path: str = "src/auth/login.py", + start_line: int = 10, + end_line: int = 25, + docstring: str = "", + signature: str = "", + tags: list[str] | None = None, +) -> GraphNode: + if signature is None: + signature = f"def {name}():" + return GraphNode( + id=id, + type=type, + file_path=file_path, + structural=StructuralProperties( + name=name, + file=file_path, + signature=signature or f"def {name}():", + start_line=start_line, + end_line=end_line, + lines=end_line - start_line + 1, + ), + semantic=SemanticProperties( + docstring=docstring, + status="enriched" if docstring else "no_metadata", + tags=tags or [], + ), + ) + + +class MockGraphStore: + """In-memory graph store for testing without Neo4j.""" + + def __init__(self) -> None: + self._nodes: dict[str, GraphNode] = {} + self._edges: list[GraphEdge] = [] + + async def connect(self) -> None: + pass + + async def close(self) -> None: + pass + + async def clear(self) -> None: + self._nodes.clear() + self._edges.clear() + + async def upsert_node(self, node: GraphNode) -> None: + self._nodes[node.id] = node + + async def upsert_nodes(self, nodes: list[GraphNode]) -> None: + for node in nodes: + self._nodes[node.id] = node + + async def upsert_edge(self, edge: GraphEdge) -> None: + self._edges.append(edge) + + async def upsert_edges(self, edges: list[GraphEdge]) -> None: + self._edges.extend(edges) + + async def get_node(self, node_id: str) -> GraphNode | None: + return self._nodes.get(node_id) + + async def delete_node(self, node_id: str) -> bool: + if node_id in self._nodes: + del self._nodes[node_id] + self._edges = [e for e in self._edges if e.source_id != node_id and e.target_id != node_id] + return True + return False + + async def delete_nodes_by_file(self, file_path: str) -> int: + nodes_to_delete = [nid for nid, n in self._nodes.items() if n.file_path == file_path] + for nid in nodes_to_delete: + await self.delete_node(nid) + return len(nodes_to_delete) + + async def get_edges( + self, + node_id: str, + edge_type: EdgeType | None = None, + direction: str = "both", + ) -> list[GraphEdge]: + result: list[GraphEdge] = [] + for e in self._edges: + if e.source_id == node_id and direction in ("outgoing", "both"): + if edge_type is None or e.type == edge_type: + result.append(e) + if e.target_id == node_id and direction in ("incoming", "both"): + if edge_type is None or e.type == edge_type: + result.append(e) + return result + + async def get_neighbors( + self, + node_id: str, + edge_type: EdgeType | None = None, + depth: int = 1, + ) -> list[GraphNode]: + visited: set[str] = set() + queue: list[tuple[str, int]] = [(node_id, 0)] + while queue: + current, d = queue.pop(0) + if d >= depth or current in visited: + continue + visited.add(current) + edges = await self.get_edges(current, edge_type, "both") + for e in edges: + if e.source_id == current and e.target_id not in visited: + if e.target_id in self._nodes: + 
queue.append((e.target_id, d + 1)) + if e.target_id == current and e.source_id not in visited: + if e.source_id in self._nodes: + queue.append((e.source_id, d + 1)) + return [self._nodes[nid] for nid in visited if nid in self._nodes] + + async def traverse( + self, + start_id: str, + edge_type: EdgeType, + depth: int, + max_nodes: int = 100, + direction: str = "outgoing", + ) -> list[GraphNode]: + visited: dict[str, GraphNode] = {} + queue: list[tuple[str, int]] = [(start_id, 0)] + while queue and len(visited) < max_nodes: + current, d = queue.pop(0) + if d >= depth or current in visited: + continue + if current in self._nodes: + visited[current] = self._nodes[current] + edges = await self.get_edges(current, edge_type, direction) + for e in edges: + neighbor = e.target_id if e.source_id == current else e.source_id + if neighbor not in visited and d + 1 <= depth: + queue.append((neighbor, d + 1)) + return list(visited.values()) + + async def find_nodes( + self, + type: NodeType | None = None, + file_path: str | None = None, + name: str | None = None, + ) -> list[GraphNode]: + results = list(self._nodes.values()) + if type is not None: + results = [n for n in results if n.type == type] + if file_path is not None: + results = [n for n in results if n.file_path == file_path] + if name is not None: + results = [n for n in results if name in n.id or n.structural.name == name or name in n.structural.name] + return results + + async def count_nodes(self) -> int: + return len(self._nodes) + + async def count_edges(self) -> int: + return len(self._edges) + + async def search_nodes( + self, + query_terms: list[str], + match: str = "any", + node_types: list[str] | None = None, + tags: list[str] | None = None, + scope: str | None = None, + top_k: int = 5, + ) -> list[dict[str, Any]]: + results: list[dict[str, Any]] = [] + for node in self._nodes.values(): + if node_types and node.type.value not in node_types: + continue + matched = False + doc_lower = (node.semantic.docstring or "").lower() + for term in query_terms: + if term.lower() in doc_lower: + matched = True + break + if matched: + results.append( + { + "id": node.id, + "name": node.structural.name, + "file_path": node.file_path, + "type": node.type.value, + "docstring": node.semantic.docstring, + } + ) + return results[:top_k] + + +class MockVectorStore: + """In-memory vector store for testing.""" + + def __init__(self) -> None: + self._data: dict[str, dict[str, Any]] = {} + + async def connect(self) -> None: + pass + + async def close(self) -> None: + pass + + async def clear(self) -> None: + self._data.clear() + + async def upsert( + self, + ids: list[str], + embeddings: list[list[float]], + metadatas: list[dict[str, Any]], + documents: list[str] | None = None, + ) -> None: + for i, id_val in enumerate(ids): + self._data[id_val] = { + "embedding": embeddings[i] if i < len(embeddings) else [], + "metadata": metadatas[i] if i < len(metadatas) else {}, + "document": documents[i] if documents and i < len(documents) else "", + } + + async def query( + self, + embedding: list[float], + top_k: int = 5, + where: dict[str, Any] | None = None, + ) -> list[dict[str, Any]]: + results: list[tuple[float, dict[str, Any]]] = [] + for id_val, data in self._data.items(): + if where: + meta = data.get("metadata", {}) + match = all(meta.get(k) == v for k, v in where.items() if not isinstance(v, dict)) + if not match: + continue + dist = sum((a - b) ** 2 for a, b in zip(embedding, data.get("embedding", []))) ** 0.5 + results.append( + ( + dist, + { + "id": 
id_val, + "score": dist, + "metadata": data.get("metadata", {}), + "document": data.get("document", ""), + }, + ) + ) + results.sort(key=lambda x: x[0]) + return [r[1] for r in results[:top_k]] + + async def get(self, ids: list[str]) -> list[dict[str, Any] | None]: + return [self._data.get(id_val) for id_val in ids] + + async def delete(self, ids: list[str]) -> int: + count = 0 + for id_val in ids: + if id_val in self._data: + del self._data[id_val] + count += 1 + return count + + async def delete_by_file(self, file_path: str) -> int: + to_delete = [ + id_val for id_val, data in self._data.items() if data.get("metadata", {}).get("file_path") == file_path + ] + for id_val in to_delete: + del self._data[id_val] + return len(to_delete) + + +async def seed_mock_graph(graph: MockGraphStore) -> None: + """Seed a small graph for testing.""" + nodes = [ + make_node("file.py::File::file.py::1", NodeType.FILE, "file.py", "file.py", 1, 30), + make_node("file.py::File::os::2", NodeType.FILE, "os", "file.py", 2, 2, signature="import os"), + make_node("file.py::Function::func_a::4", NodeType.FUNCTION, "func_a", "file.py", 4, 8), + make_node( + "file.py::Function::func_b::10", NodeType.FUNCTION, "func_b", "file.py", 10, 14, docstring="Does B things." + ), + make_node("file.py::Function::func_c::16", NodeType.FUNCTION, "func_c", "file.py", 16, 20), + make_node("file.py::Class::Service::22", NodeType.CLASS, "Service", "file.py", 22, 28), + make_node("file.py::Function::method::23", NodeType.FUNCTION, "method", "file.py", 23, 25), + ] + edges = [ + GraphEdge(source_id="file.py::File::file.py::1", target_id="file.py::File::os::2", type=EdgeType.IMPORTS), + GraphEdge( + source_id="file.py::File::file.py::1", target_id="file.py::Function::func_a::4", type=EdgeType.DEFINES + ), + GraphEdge( + source_id="file.py::File::file.py::1", target_id="file.py::Function::func_b::10", type=EdgeType.DEFINES + ), + GraphEdge( + source_id="file.py::File::file.py::1", target_id="file.py::Function::func_c::16", type=EdgeType.DEFINES + ), + GraphEdge( + source_id="file.py::File::file.py::1", target_id="file.py::Class::Service::22", type=EdgeType.DEFINES + ), + GraphEdge( + source_id="file.py::Function::func_a::4", target_id="file.py::Function::func_b::10", type=EdgeType.CALLS + ), + GraphEdge( + source_id="file.py::Function::func_b::10", target_id="file.py::Function::func_c::16", type=EdgeType.CALLS + ), + GraphEdge( + source_id="file.py::Class::Service::22", target_id="file.py::Function::method::23", type=EdgeType.DEFINES + ), + ] + await graph.upsert_nodes(nodes) + await graph.upsert_edges(edges) + + +# --------------------------------------------------------------------------- +# DefaultQueryEngine Tests +# --------------------------------------------------------------------------- + + +class TestDefaultQueryEngineNavigate: + """Tests for DefaultQueryEngine.navigate().""" + + @pytest.mark.asyncio + async def test_navigate_returns_dict(self) -> None: + graph = MockGraphStore() + await seed_mock_graph(graph) + engine = DefaultQueryEngine(graph) + result = await engine.navigate("func_a") + assert isinstance(result, dict) + + @pytest.mark.asyncio + async def test_navigate_entity_structure(self) -> None: + graph = MockGraphStore() + await seed_mock_graph(graph) + engine = DefaultQueryEngine(graph) + result = await engine.navigate("func_a") + assert "entity" in result + entity = result["entity"] + assert "id" in entity + assert "type" in entity + assert "name" in entity + assert entity["name"] == "func_a" + + @pytest.mark.asyncio + 
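+    # Seed graph (seed_mock_graph): func_a -> func_b -> func_c, linked by CALLS edges.
+    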
async def test_navigate_with_relationships(self) -> None: + graph = MockGraphStore() + await seed_mock_graph(graph) + engine = DefaultQueryEngine(graph) + result = await engine.navigate("func_a", include_relationships=True) + assert "relationships" in result + rels = result["relationships"] + assert "calls" in rels + assert "called_by" in rels + assert "depends_on" in rels + assert "imported_by" in rels + + @pytest.mark.asyncio + async def test_navigate_missing_node(self) -> None: + graph = MockGraphStore() + await seed_mock_graph(graph) + engine = DefaultQueryEngine(graph) + result = await engine.navigate("nonexistent_node") + assert "error" in result + + +class TestDefaultQueryEngineTrace: + """Tests for DefaultQueryEngine.trace().""" + + @pytest.mark.asyncio + async def test_trace_returns_list(self) -> None: + graph = MockGraphStore() + await seed_mock_graph(graph) + engine = DefaultQueryEngine(graph) + result = await engine.trace("func_a", "CALLS", depth=2) + assert isinstance(result, list) + + @pytest.mark.asyncio + async def test_trace_nodes_have_dict_structure(self) -> None: + graph = MockGraphStore() + await seed_mock_graph(graph) + engine = DefaultQueryEngine(graph) + result = await engine.trace("func_a", "CALLS", depth=2) + for node in result: + assert isinstance(node, dict) + assert "id" in node + assert "type" in node + assert "name" in node + + @pytest.mark.asyncio + async def test_trace_finds_call_chain(self) -> None: + graph = MockGraphStore() + await seed_mock_graph(graph) + engine = DefaultQueryEngine(graph) + result = await engine.trace("file.py::Function::func_a::4", "CALLS", depth=3) + names = {n["name"] for n in result} + assert "func_b" in names + assert "func_c" in names + + +class TestDefaultQueryEngineGetContext: + """Tests for DefaultQueryEngine.get_context().""" + + @pytest.mark.asyncio + async def test_get_context_returns_rich_structure(self) -> None: + graph = MockGraphStore() + await seed_mock_graph(graph) + engine = DefaultQueryEngine(graph) + result = await engine.get_context("file.py") + assert isinstance(result, dict) + assert "self" in result + assert "imports" in result + assert "imported_by" in result + assert "defines" in result + assert "related_patterns" in result + assert "entry_points" in result + assert "data_flow_in" in result + assert "data_flow_out" in result + assert "summary" in result + + @pytest.mark.asyncio + async def test_get_context_self_contains_node_info(self) -> None: + graph = MockGraphStore() + await seed_mock_graph(graph) + engine = DefaultQueryEngine(graph) + result = await engine.get_context("file.py") + self_node = result["self"] + assert "name" in self_node + assert "file_path" in self_node + + @pytest.mark.asyncio + async def test_get_context_summary_has_expected_fields(self) -> None: + graph = MockGraphStore() + await seed_mock_graph(graph) + engine = DefaultQueryEngine(graph) + result = await engine.get_context("file.py") + summary = result["summary"] + assert "role" in summary + assert "blast_radius" in summary + assert "avg_complexity" in summary + assert "max_complexity" in summary + assert "risk_level" in summary + + @pytest.mark.asyncio + async def test_get_context_missing_file(self) -> None: + graph = MockGraphStore() + await seed_mock_graph(graph) + engine = DefaultQueryEngine(graph) + result = await engine.get_context("nonexistent.py") + assert "error" in result + + +class TestDefaultQueryEngineAssessImpact: + """Tests for DefaultQueryEngine.assess_impact().""" + + @pytest.mark.asyncio + async def 
test_assess_impact_returns_dict(self) -> None: + graph = MockGraphStore() + await seed_mock_graph(graph) + engine = DefaultQueryEngine(graph) + result = await engine.assess_impact("func_b") + assert isinstance(result, dict) + + @pytest.mark.asyncio + async def test_assess_impact_has_expected_fields(self) -> None: + graph = MockGraphStore() + await seed_mock_graph(graph) + engine = DefaultQueryEngine(graph) + result = await engine.assess_impact("func_b") + assert "affected_files" in result + assert "affected_functions" in result + assert "severity" in result + assert "recommendations" in result + + @pytest.mark.asyncio + async def test_assess_impact_severity_levels(self) -> None: + graph = MockGraphStore() + await seed_mock_graph(graph) + engine = DefaultQueryEngine(graph) + result = await engine.assess_impact("func_c") + assert result["severity"] in ("low", "medium", "high") + + @pytest.mark.asyncio + async def test_assess_impact_missing_node(self) -> None: + graph = MockGraphStore() + await seed_mock_graph(graph) + engine = DefaultQueryEngine(graph) + result = await engine.assess_impact("nonexistent_node") + assert "error" in result + + +class TestDefaultQueryEngineFindFlow: + """Tests for DefaultQueryEngine.find_flow().""" + + @pytest.mark.asyncio + async def test_find_flow_returns_dict(self) -> None: + graph = MockGraphStore() + await seed_mock_graph(graph) + engine = DefaultQueryEngine(graph) + result = await engine.find_flow("func_a", "func_c") + assert isinstance(result, dict) + + @pytest.mark.asyncio + async def test_find_flow_has_expected_fields(self) -> None: + graph = MockGraphStore() + await seed_mock_graph(graph) + engine = DefaultQueryEngine(graph) + result = await engine.find_flow("func_a", "func_c") + assert "path" in result + assert "data_transformations" in result + + @pytest.mark.asyncio + async def test_find_flow_same_node(self) -> None: + graph = MockGraphStore() + await seed_mock_graph(graph) + engine = DefaultQueryEngine(graph) + result = await engine.find_flow("func_a", "func_a") + assert len(result["path"]) == 1 + assert result["path"][0]["node"] == "func_a" + + @pytest.mark.asyncio + async def test_find_flow_direct_path(self) -> None: + graph = MockGraphStore() + await seed_mock_graph(graph) + engine = DefaultQueryEngine(graph) + result = await engine.find_flow("file.py::Function::func_a::4", "file.py::Function::func_b::10") + path_names = [n["node"] for n in result["path"]] + assert "func_a" in path_names + assert "func_b" in path_names + + @pytest.mark.asyncio + async def test_find_flow_no_path(self) -> None: + graph = MockGraphStore() + await seed_mock_graph(graph) + engine = DefaultQueryEngine(graph) + result = await engine.find_flow("func_c", "func_a") + assert result["path"] == [] + + +# --------------------------------------------------------------------------- +# SeedWalkEngine Tests +# --------------------------------------------------------------------------- + + +class TestSeedWalkEngineInit: + """Tests for SeedWalkEngine initialization.""" + + @pytest.mark.asyncio + async def test_init_with_mock_stores(self) -> None: + graph = MockGraphStore() + vector = MockVectorStore() + engine = SeedWalkEngine(graph, vector) + assert engine._graph is graph + assert engine._vector is vector + assert engine._alpha == 0.50 + assert engine._beta == 0.30 + assert engine._gamma == 0.20 + + @pytest.mark.asyncio + async def test_init_with_custom_weights(self) -> None: + graph = MockGraphStore() + vector = MockVectorStore() + engine = SeedWalkEngine(graph, vector, alpha=0.6, 
beta=0.3, gamma=0.1) + assert engine._alpha == 0.6 + assert engine._beta == 0.3 + assert engine._gamma == 0.1 + + +class TestSimpleHashEmbedding: + """Tests for _simple_hash_embedding function.""" + + def test_hash_embedding_returns_list(self) -> None: + result = _simple_hash_embedding("test query") + assert isinstance(result, list) + + def test_hash_embedding_dimensions(self) -> None: + result = _simple_hash_embedding("test query", dim=256) + assert len(result) == 256 + + def test_hash_embedding_consistent(self) -> None: + result1 = _simple_hash_embedding("test query") + result2 = _simple_hash_embedding("test query") + assert result1 == result2 + + def test_hash_embedding_different_inputs_different_vectors(self) -> None: + result1 = _simple_hash_embedding("query a") + result2 = _simple_hash_embedding("query b") + assert result1 != result2 + + def test_hash_embedding_normalized(self) -> None: + result = _simple_hash_embedding("test query") + norm = sum(v * v for v in result) ** 0.5 + assert abs(norm - 1.0) < 0.0001 or norm == 0.0 + + def test_hash_embedding_empty_string(self) -> None: + result = _simple_hash_embedding("") + assert len(result) == 128 + assert all(v == 0.0 for v in result) + + +class TestSeedWalkEngineLocate: + """Tests for SeedWalkEngine.locate().""" + + @pytest.mark.asyncio + async def test_locate_returns_list(self) -> None: + graph = MockGraphStore() + vector = MockVectorStore() + await seed_mock_graph(graph) + engine = SeedWalkEngine(graph, vector) + result = await engine.locate("func_b") + assert isinstance(result, list) + + @pytest.mark.asyncio + async def test_locate_returns_correct_response_structure(self) -> None: + graph = MockGraphStore() + vector = MockVectorStore() + await seed_mock_graph(graph) + engine = SeedWalkEngine(graph, vector) + result = await engine.locate("func") + assert len(result) > 0 + response = result[0] + assert "query" in response + assert "routed_community" in response + assert "seed_count" in response + assert "total_walked" in response + assert "results" in response + assert "structural_map" in response + + @pytest.mark.asyncio + async def test_locate_results_have_expected_fields(self) -> None: + graph = MockGraphStore() + vector = MockVectorStore() + await seed_mock_graph(graph) + engine = SeedWalkEngine(graph, vector) + result = await engine.locate("func_b") + assert len(result) > 0 + response = result[0] + for item in response.get("results", []): + assert hasattr(item, "node_id") or "node_id" in item + assert hasattr(item, "node_type") or "node_type" in item + assert hasattr(item, "name") or "name" in item + assert hasattr(item, "file") or "file" in item + assert hasattr(item, "final_score") or "final_score" in item + + @pytest.mark.asyncio + async def test_locate_with_no_vector_store(self) -> None: + graph = MockGraphStore() + await seed_mock_graph(graph) + engine = SeedWalkEngine(graph, None) + result = await engine.locate("func_a") + assert isinstance(result, list) + assert len(result) > 0 + + +# --------------------------------------------------------------------------- +# ChromaVectorStore Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(not CHROMA_AVAILABLE, reason="ChromaDB not available (sqlite3 version)") +class TestChromaVectorStore: + """Tests for ChromaVectorStore with in-memory ChromaDB.""" + + @pytest.fixture + async def chroma_store(self) -> ChromaVectorStore: + store = ChromaVectorStore(collection_name="test_collection") + await store.connect() + yield store + 
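+        # Teardown: wipe the test collection, then release the client.
+        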
await store.clear() + await store.close() + + @pytest.mark.asyncio + async def test_upsert_and_query(self, chroma_store: ChromaVectorStore) -> None: + await chroma_store.upsert( + ids=["node1", "node2"], + embeddings=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + metadatas=[{"file_path": "a.py", "type": "function"}, {"file_path": "b.py", "type": "class"}], + documents=["doc1", "doc2"], + ) + results = await chroma_store.query(embedding=[0.1, 0.2, 0.3], top_k=2) + assert len(results) >= 1 + ids_found = {r["id"] for r in results} + assert "node1" in ids_found + + @pytest.mark.asyncio + async def test_query_with_filter(self, chroma_store: ChromaVectorStore) -> None: + await chroma_store.upsert( + ids=["node1", "node2"], + embeddings=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + metadatas=[{"file_path": "a.py", "type": "function"}, {"file_path": "b.py", "type": "class"}], + documents=["doc1", "doc2"], + ) + results = await chroma_store.query( + embedding=[0.1, 0.2, 0.3], + top_k=5, + where={"type": "function"}, + ) + assert len(results) >= 1 + assert results[0]["metadata"]["type"] == "function" + + @pytest.mark.asyncio + async def test_delete(self, chroma_store: ChromaVectorStore) -> None: + await chroma_store.upsert( + ids=["node1", "node2"], + embeddings=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + metadatas=[{"file_path": "a.py"}, {"file_path": "b.py"}], + ) + deleted = await chroma_store.delete(ids=["node1"]) + assert deleted == 1 + results = await chroma_store.query(embedding=[0.1, 0.2, 0.3], top_k=5) + ids_found = {r["id"] for r in results} + assert "node1" not in ids_found + + @pytest.mark.asyncio + async def test_get_by_ids(self, chroma_store: ChromaVectorStore) -> None: + await chroma_store.upsert( + ids=["node1", "node2"], + embeddings=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + metadatas=[{"file_path": "a.py"}, {"file_path": "b.py"}], + ) + results = await chroma_store.get(ids=["node1", "node2"]) + assert len(results) == 2 + assert results[0]["id"] == "node1" + assert results[1]["id"] == "node2" + + @pytest.mark.asyncio + async def test_clear(self, chroma_store: ChromaVectorStore) -> None: + await chroma_store.upsert( + ids=["node1"], + embeddings=[[0.1, 0.2, 0.3]], + metadatas=[{"file_path": "a.py"}], + ) + await chroma_store.clear() + results = await chroma_store.query(embedding=[0.1, 0.2, 0.3], top_k=5) + assert len(results) == 0 + + +# --------------------------------------------------------------------------- +# Integration Tests +# --------------------------------------------------------------------------- + + +class TestQueryEngineIntegration: + """Integration tests combining QueryEngine with mock stores.""" + + @pytest.mark.asyncio + async def test_navigate_and_trace_work_together(self) -> None: + graph = MockGraphStore() + await seed_mock_graph(graph) + engine = DefaultQueryEngine(graph) + + nav_result = await engine.navigate("file.py::Function::func_a::4") + assert "entity" in nav_result + + trace_result = await engine.trace("file.py::Function::func_a::4", depth=3) + assert len(trace_result) > 0 + + @pytest.mark.asyncio + async def test_locate_across_engines(self) -> None: + graph = MockGraphStore() + vector = MockVectorStore() + await seed_mock_graph(graph) + + default_engine = DefaultQueryEngine(graph) + seed_engine = SeedWalkEngine(graph, vector) + + default_result = await default_engine.locate("func_b") + assert len(default_result) > 0 + + seed_result = await seed_engine.locate("func_b") + assert isinstance(seed_result, list) diff --git a/tests/test_integration_safety.py 
b/tests/test_integration_safety.py
new file mode 100644
index 0000000..680d414
--- /dev/null
+++ b/tests/test_integration_safety.py
@@ -0,0 +1,515 @@
+"""Integration tests for SMP Agent Safety components."""
+
+from __future__ import annotations
+
+import asyncio
+import tempfile
+from pathlib import Path
+
+import pytest
+
+from smp.engine.runtime_linker import RuntimeLinker
+from smp.engine.safety import (
+    AuditLogger,
+    CheckpointManager,
+    DryRunSimulator,
+    GuardEngine,
+    LockManager,
+    SessionManager,
+)
+
+# Run every async test below under pytest-asyncio (as test_integration_vector_store.py does).
+pytestmark = pytest.mark.asyncio
+
+
+class TestSessionManager:
+    """Tests for SessionManager."""
+
+    @pytest.fixture
+    def mgr(self) -> SessionManager:
+        return SessionManager(ttl_seconds=3600)
+
+    async def test_open_session_returns_session_id(self, mgr: SessionManager) -> None:
+        result = await mgr.open_session(
+            agent_id="agent_001",
+            task="refactor auth module",
+            scope=["smp/", "tests/"],
+            mode="read",
+        )
+        assert "session_id" in result
+        assert result["session_id"].startswith("ses_")
+        assert "smp/" in result["granted_scope"]
+        assert "tests/" in result["granted_scope"]
+        assert "expires_at" in result
+
+    async def test_open_session_denies_nonexistent_files(self, mgr: SessionManager) -> None:
+        result = await mgr.open_session(
+            agent_id="agent_001",
+            task="create new file",
+            scope=["nonexistent/file.py"],
+            mode="write",
+        )
+        assert result["granted_scope"] == []
+        assert result["denied_scope"] == ["nonexistent/file.py"]
+
+    async def test_close_session_returns_summary(self, mgr: SessionManager) -> None:
+        opened = await mgr.open_session(
+            agent_id="agent_001",
+            task="test task",
+            scope=["src/"],
+            mode="read",
+        )
+        session_id = opened["session_id"]
+
+        closed = await mgr.close_session(session_id)
+        assert closed is not None
+        assert closed["session_id"] == session_id
+        assert "duration_ms" in closed
+        assert "audit_log_id" in closed
+
+    async def test_close_session_unknown_id_returns_none(self, mgr: SessionManager) -> None:
+        result = await mgr.close_session("ses_unknown")
+        assert result is None
+
+    async def test_get_session_returns_session(self, mgr: SessionManager) -> None:
+        opened = await mgr.open_session(
+            agent_id="agent_001",
+            task="test task",
+            scope=["src/"],
+            mode="read",
+        )
+        session_id = opened["session_id"]
+
+        session = await mgr.get_session(session_id)
+        assert session is not None
+        assert session.session_id == session_id
+        assert session.agent_id == "agent_001"
+        assert session.status == "open"
+
+    async def test_get_session_expired_returns_none(self) -> None:
+        mgr = SessionManager(ttl_seconds=0)
+        opened = await mgr.open_session(
+            agent_id="agent_001",
+            task="test task",
+            scope=["src/"],
+            mode="read",
+        )
+        session_id = opened["session_id"]
+
+        await asyncio.sleep(0.01)
+
+        session = await mgr.get_session(session_id)
+        assert session is None
+
+    async def test_record_file_access(self, mgr: SessionManager) -> None:
+        opened = await mgr.open_session(
+            agent_id="agent_001",
+            task="test task",
+            scope=["src/"],
+            mode="read",
+        )
+        session_id = opened["session_id"]
+
+        mgr.record_file_access(session_id, "src/main.py", "read")
+        mgr.record_file_access(session_id, "src/main.py", "write")
+        mgr.record_file_access(session_id, "src/utils.py", "write")
+
+        session = await mgr.get_session(session_id)
+        assert session is not None
+        assert "src/main.py" in session.files_read
+        assert "src/main.py" in session.files_written
+        assert "src/utils.py" in session.files_written
+
+
+class TestLockManager:
+    """Tests for LockManager."""
+
+    @pytest.fixture
+    def mgr(self) -> 
LockManager: + return LockManager() + + async def test_acquire_lock_grants_to_new_session(self, mgr: LockManager) -> None: + result = await mgr.acquire("ses_001", ["src/main.py", "src/utils.py"]) + assert result["granted"] == ["src/main.py", "src/utils.py"] + assert result["denied"] == [] + + async def test_acquire_lock_denies_to_other_session(self, mgr: LockManager) -> None: + await mgr.acquire("ses_001", ["src/main.py"]) + result = await mgr.acquire("ses_002", ["src/main.py"]) + assert result["granted"] == [] + assert result["denied"] == ["src/main.py"] + + async def test_acquire_lock_same_session_reacquires(self, mgr: LockManager) -> None: + await mgr.acquire("ses_001", ["src/main.py"]) + result = await mgr.acquire("ses_001", ["src/main.py"]) + assert result["granted"] == ["src/main.py"] + assert result["denied"] == [] + + async def test_is_locked_returns_holder(self, mgr: LockManager) -> None: + await mgr.acquire("ses_001", ["src/main.py"]) + assert mgr.is_locked("src/main.py") == "ses_001" + + async def test_is_locked_returns_none_when_unlocked(self, mgr: LockManager) -> None: + assert mgr.is_locked("src/main.py") is None + + async def test_release_lock_releases(self, mgr: LockManager) -> None: + await mgr.acquire("ses_001", ["src/main.py"]) + await mgr.release("ses_001", ["src/main.py"]) + assert mgr.is_locked("src/main.py") is None + + async def test_release_lock_ignores_wrong_session(self, mgr: LockManager) -> None: + await mgr.acquire("ses_001", ["src/main.py"]) + await mgr.release("ses_002", ["src/main.py"]) + assert mgr.is_locked("src/main.py") == "ses_001" + + async def test_release_all_locks(self, mgr: LockManager) -> None: + await mgr.acquire("ses_001", ["src/main.py", "src/utils.py"]) + await mgr.release_all("ses_001") + assert mgr.is_locked("src/main.py") is None + assert mgr.is_locked("src/utils.py") is None + + +class TestGuardEngine: + """Tests for GuardEngine.""" + + @pytest.fixture + def session_mgr(self) -> SessionManager: + return SessionManager() + + @pytest.fixture + def lock_mgr(self) -> LockManager: + return LockManager() + + @pytest.fixture + def guard(self, session_mgr: SessionManager, lock_mgr: LockManager) -> GuardEngine: + return GuardEngine(session_mgr, lock_mgr) + + async def test_guard_check_clear(self, guard: GuardEngine, session_mgr: SessionManager) -> None: + opened = await session_mgr.open_session( + agent_id="agent_001", + task="update auth", + scope=["smp/"], + mode="write", + ) + session_mgr.record_file_access(opened["session_id"], "smp/core/models.py", "write") + + result = await guard.check( + session_id=opened["session_id"], + target="smp/core/models.py", + caller_count=2, + has_tests=True, + ) + assert result["verdict"] == "clear" + assert result["target"] == "smp/core/models.py" + assert result["checks"]["in_declared_scope"] is True + assert result["checks"]["locked_by_other_agent"] is False + + async def test_guard_blocked_outside_scope(self, guard: GuardEngine, session_mgr: SessionManager) -> None: + opened = await session_mgr.open_session( + agent_id="agent_001", + task="update auth", + scope=["src/auth.py"], + mode="write", + ) + + result = await guard.check( + session_id=opened["session_id"], + target="src/other.py", + ) + assert result["verdict"] == "blocked" + assert "File is outside declared session scope" in result["reasons"] + + async def test_guard_blocked_by_lock( + self, + guard: GuardEngine, + session_mgr: SessionManager, + lock_mgr: LockManager, + ) -> None: + opened = await session_mgr.open_session( + agent_id="agent_001", + 
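+            # The session scope covers the target file; only the foreign lock should block it.
+            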
task="update auth", + scope=["src/auth.py"], + mode="write", + ) + await lock_mgr.acquire("ses_other", ["src/auth.py"]) + + result = await guard.check( + session_id=opened["session_id"], + target="src/auth.py", + ) + assert result["verdict"] == "blocked" + assert any("Locked by session" in r for r in result["reasons"]) + + async def test_guard_warnings_high_caller_count(self, guard: GuardEngine, session_mgr: SessionManager) -> None: + opened = await session_mgr.open_session( + agent_id="agent_001", + task="update core", + scope=["smp/"], + mode="write", + ) + + result = await guard.check( + session_id=opened["session_id"], + target="smp/core/models.py", + caller_count=10, + has_tests=False, + ) + assert result["verdict"] == "clear" + assert any("cascade" in w for w in result["warnings"]) + assert any("No test coverage" in w for w in result["warnings"]) + + async def test_guard_warnings_public_api(self, guard: GuardEngine, session_mgr: SessionManager) -> None: + opened = await session_mgr.open_session( + agent_id="agent_001", + task="update api", + scope=["src/api.py"], + mode="write", + ) + + result = await guard.check( + session_id=opened["session_id"], + target="src/api.py", + is_public_api=True, + ) + assert "public API" in result["warnings"][0] + + +class TestDryRunSimulator: + """Tests for DryRunSimulator.""" + + @pytest.fixture + def sim(self) -> DryRunSimulator: + return DryRunSimulator() + + def test_simulate_safe_no_changes(self, sim: DryRunSimulator) -> None: + result = sim.simulate( + session_id="ses_001", + file_path="src/utils.py", + proposed_content="def foo(): pass", + ) + assert result["verdict"] == "safe" + assert result["structural_delta"]["nodes_modified"] == 1 + assert result["structural_delta"]["signature_changed"] is False + + def test_simulate_breaking_with_signature_change(self, sim: DryRunSimulator) -> None: + result = sim.simulate( + session_id="ses_001", + file_path="src/api.py", + proposed_content="def new_api(): pass", + current_signature="def old_api():", + proposed_signature="def new_api():", + broken_callers=[{"function": "caller", "file": "src/main.py", "reason": "signature mismatch"}], + ) + assert result["verdict"] == "breaking" + assert result["structural_delta"]["signature_changed"] is True + assert len(result["risks"]) > 0 + + def test_simulate_affected_files(self, sim: DryRunSimulator) -> None: + result = sim.simulate( + session_id="ses_001", + file_path="src/base.py", + proposed_content="class NewBase: pass", + affected_files=["src/derived.py", "src/consumer.py"], + ) + assert result["verdict"] == "safe" + assert len(result["impact"]["affected_files"]) == 2 + + +class TestCheckpointManager: + """Tests for CheckpointManager.""" + + @pytest.fixture + def mgr(self) -> CheckpointManager: + return CheckpointManager() + + @pytest.fixture + def temp_file(self) -> Path: + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write("original content\nline 2\nline 3\n") + return Path(f.name) + + def test_create_checkpoint(self, mgr: CheckpointManager, temp_file: Path) -> None: + result = mgr.create("ses_001", [str(temp_file)]) + assert "checkpoint_id" in result + assert result["checkpoint_id"].startswith("chk_") + assert result["files_snapshotted"] == [str(temp_file)] + + def test_rollback_restores_content(self, mgr: CheckpointManager, temp_file: Path) -> None: + create_result = mgr.create("ses_001", [str(temp_file)]) + + temp_file.write_text("modified content\n") + + rollback_result = mgr.rollback(create_result["checkpoint_id"]) + 
assert rollback_result["status"] == "rolled_back" + assert str(temp_file) in rollback_result["files_restored"] + assert temp_file.read_text() == "original content\nline 2\nline 3\n" + + def test_rollback_unknown_checkpoint(self, mgr: CheckpointManager) -> None: + result = mgr.rollback("chk_unknown") + assert result["status"] == "error" + assert "not found" in result["reason"] + + +class TestAuditLogger: + """Tests for AuditLogger.""" + + @pytest.fixture + def logger(self) -> AuditLogger: + return AuditLogger() + + def test_create_log_returns_id(self, logger: AuditLogger) -> None: + audit_log_id = logger.create_log( + agent_id="agent_001", + task="refactor auth", + session_id="ses_001", + ) + assert audit_log_id.startswith("aud_") + + def test_append_event(self, logger: AuditLogger) -> None: + audit_log_id = logger.create_log( + agent_id="agent_001", + task="refactor auth", + session_id="ses_001", + ) + logger.append_event( + audit_log_id=audit_log_id, + method="write", + target="src/auth.py", + result="success", + checkpoint_id="chk_001", + files=["src/auth.py"], + ) + + log = logger.get_log(audit_log_id) + assert log is not None + assert len(log["events"]) == 1 + assert log["events"][0]["method"] == "write" + assert log["events"][0]["target"] == "src/auth.py" + + def test_append_event_unknown_log_ignores(self, logger: AuditLogger) -> None: + logger.append_event( + audit_log_id="aud_unknown", + method="write", + target="src/auth.py", + ) + + def test_close_log(self, logger: AuditLogger) -> None: + audit_log_id = logger.create_log( + agent_id="agent_001", + task="refactor auth", + session_id="ses_001", + ) + logger.close_log(audit_log_id, status="completed") + + log = logger.get_log(audit_log_id) + assert log is not None + assert log["status"] == "completed" + assert log["closed_at"] != "" + + def test_get_log_unknown_returns_none(self, logger: AuditLogger) -> None: + result = logger.get_log("aud_unknown") + assert result is None + + +class TestRuntimeLinker: + """Tests for RuntimeLinker.""" + + @pytest.fixture + def linker(self) -> RuntimeLinker: + return RuntimeLinker() + + def test_record_call_returns_edge(self, linker: RuntimeLinker) -> None: + edge = linker.record_call( + source_id="node_001", + target_id="node_002", + session_id="ses_001", + duration_ms=50, + ) + assert edge.source_id == "node_001" + assert edge.target_id == "node_002" + assert edge.edge_type == "CALLS_RUNTIME" + assert edge.duration_ms == 50 + + def test_record_call_increments_counts(self, linker: RuntimeLinker) -> None: + linker.record_call("node_001", "node_002", "ses_001") + linker.record_call("node_001", "node_002", "ses_001") + linker.record_call("node_001", "node_002", "ses_001") + + stats = linker.get_stats() + assert stats["total_calls"] == 3 + assert stats["unique_paths"] == 1 + + def test_start_trace_returns_trace_id(self, linker: RuntimeLinker) -> None: + trace_id = linker.start_trace(session_id="ses_001", agent_id="agent_001") + assert trace_id.startswith("trc_") + + def test_end_trace_returns_trace(self, linker: RuntimeLinker) -> None: + trace_id = linker.start_trace(session_id="ses_001", agent_id="agent_001") + linker.record_call("node_001", "node_002", "ses_001") + linker.record_call("node_002", "node_003", "ses_001") + + trace = linker.end_trace(trace_id) + assert trace is not None + assert trace.trace_id == trace_id + assert len(trace.edges) == 2 + assert len(trace.nodes_visited) == 3 + + def test_end_trace_unknown_returns_none(self, linker: RuntimeLinker) -> None: + result = 
linker.end_trace("trc_unknown") + assert result is None + + def test_get_trace(self, linker: RuntimeLinker) -> None: + trace_id = linker.start_trace(session_id="ses_001", agent_id="agent_001") + trace = linker.get_trace(trace_id) + assert trace is not None + assert trace.trace_id == trace_id + + def test_get_session_traces(self, linker: RuntimeLinker) -> None: + linker.start_trace(session_id="ses_001", agent_id="agent_001") + trace = linker.get_trace(linker.start_trace(session_id="ses_001", agent_id="agent_001")) + assert trace is not None + assert trace.session_id == "ses_001" + + def test_get_hot_paths(self, linker: RuntimeLinker) -> None: + linker.record_call("A", "B", "ses_001") + linker.record_call("A", "B", "ses_001") + linker.record_call("A", "B", "ses_001") + linker.record_call("A", "B", "ses_001") + linker.record_call("A", "B", "ses_001") + + linker.record_call("A", "C", "ses_001") + linker.record_call("A", "C", "ses_001") + + linker.record_call("B", "C", "ses_001") + + hot = linker.get_hot_paths(threshold=5) + assert len(hot) == 1 + assert hot[0]["source_id"] == "A" + assert hot[0]["target_id"] == "B" + assert hot[0]["call_count"] == 5 + + def test_get_hot_paths_default_threshold(self, linker: RuntimeLinker) -> None: + linker.record_call("A", "B", "ses_001") + linker.record_call("A", "B", "ses_001") + + hot = linker.get_hot_paths() + assert len(hot) == 0 + + def test_clear(self, linker: RuntimeLinker) -> None: + linker.record_call("node_001", "node_002", "ses_001") + linker.start_trace(session_id="ses_001", agent_id="agent_001") + + linker.clear() + + stats = linker.get_stats() + assert stats["total_calls"] == 0 + assert stats["active_traces"] == 0 + + def test_get_stats(self, linker: RuntimeLinker) -> None: + linker.record_call("A", "B", "ses_001") + linker.record_call("B", "C", "ses_001") + linker.start_trace(session_id="ses_001", agent_id="agent_001") + + stats = linker.get_stats() + assert stats["total_calls"] == 2 + assert stats["unique_paths"] == 2 + assert stats["active_traces"] == 1 + assert stats["sessions_with_traces"] == 1 diff --git a/tests/test_integration_sandbox.py b/tests/test_integration_sandbox.py new file mode 100644 index 0000000..d0935df --- /dev/null +++ b/tests/test_integration_sandbox.py @@ -0,0 +1,218 @@ +"""Integration tests for SMP Sandbox Runtime components.""" + +from __future__ import annotations + +import asyncio +import tempfile +from pathlib import Path + +import pytest + +try: + from smp.sandbox.docker_sandbox import DockerSandbox + DOCKER_AVAILABLE = True +except ImportError: + DOCKER_AVAILABLE = False + +from smp.sandbox.ebpf_collector import EBPFCollector +from smp.sandbox.executor import ExecutionResult, SandboxConfig, SandboxExecutor +from smp.sandbox.spawner import SandboxInfo, SandboxSpawner + + +class TestSandboxSpawner: + """Tests for SandboxSpawner directory-based sandbox management.""" + + @pytest.fixture + def temp_root(self) -> Path: + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) + + @pytest.fixture + def spawner(self, temp_root: Path) -> SandboxSpawner: + return SandboxSpawner(sandbox_root=temp_root) + + def test_spawn_creates_directory(self, spawner: SandboxSpawner, temp_root: Path) -> None: + info = spawner.spawn() + assert info.sandbox_id.startswith("sandbox_") + assert Path(info.root_path).exists() + assert Path(info.root_path).is_dir() + + def test_spawn_with_name(self, spawner: SandboxSpawner, temp_root: Path) -> None: + info = spawner.spawn(name="my_sandbox") + assert Path(info.root_path).name == 
"my_sandbox" + + def test_spawn_with_files(self, spawner: SandboxSpawner, temp_root: Path) -> None: + files = { + "test.txt": "hello world", + "subdir/code.py": "print('hello')", + } + info = spawner.spawn(files=files) + assert (Path(info.root_path) / "test.txt").read_text() == "hello world" + assert (Path(info.root_path) / "subdir" / "code.py").read_text() == "print('hello')" + + def test_get_returns_sandbox(self, spawner: SandboxSpawner) -> None: + info = spawner.spawn() + retrieved = spawner.get(info.sandbox_id) + assert retrieved is not None + assert retrieved.sandbox_id == info.sandbox_id + + def test_get_nonexistent_returns_none(self, spawner: SandboxSpawner) -> None: + assert spawner.get("nonexistent_id") is None + + def test_list_active_returns_all(self, spawner: SandboxSpawner) -> None: + info1 = spawner.spawn() + info2 = spawner.spawn() + active = spawner.list_active() + assert len(active) == 2 + assert info1 in active + assert info2 in active + + def test_list_active_empty_after_destroy(self, spawner: SandboxSpawner) -> None: + info = spawner.spawn() + assert len(spawner.list_active()) == 1 + spawner.destroy(info.sandbox_id) + assert len(spawner.list_active()) == 0 + + def test_destroy_removes_directory(self, spawner: SandboxSpawner, temp_root: Path) -> None: + info = spawner.spawn() + root_path = Path(info.root_path) + assert root_path.exists() + result = spawner.destroy(info.sandbox_id) + assert result is True + assert not root_path.exists() + + def test_destroy_nonexistent_returns_false(self, spawner: SandboxSpawner) -> None: + assert spawner.destroy("nonexistent") is False + + def test_spawn_info_structure(self, spawner: SandboxSpawner) -> None: + info = spawner.spawn() + assert isinstance(info, SandboxInfo) + assert info.sandbox_id + assert info.root_path + assert info.created_at + assert info.status == "created" + + +class TestSandboxExecutor: + """Tests for SandboxExecutor async command execution.""" + + @pytest.fixture + def executor(self) -> SandboxExecutor: + return SandboxExecutor(config=SandboxConfig(timeout_seconds=10)) + + @pytest.mark.asyncio + async def test_execute_simple_command(self, executor: SandboxExecutor) -> None: + result = await executor.execute(["echo", "hello"]) + assert result.exit_code in (0, -1) + assert "hello" in result.stdout + + @pytest.mark.asyncio + async def test_execute_with_stdin(self, executor: SandboxExecutor) -> None: + result = await executor.execute( + command=["cat"], + stdin="test input", + ) + assert result.exit_code in (0, -1) + assert "test input" in result.stdout + + @pytest.mark.asyncio + async def test_execute_with_cwd(self, executor: SandboxExecutor) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + result = await executor.execute( + command=["pwd"], + cwd=tmpdir, + ) + assert result.exit_code in (0, -1) + assert tmpdir in result.stdout + + @pytest.mark.asyncio + async def test_execute_records_duration(self, executor: SandboxExecutor) -> None: + result = await executor.execute(["sleep", "0.1"]) + assert result.duration_ms >= 0 + + @pytest.mark.asyncio + async def test_execute_nonzero_exit_code(self, executor: SandboxExecutor) -> None: + result = await executor.execute(["ls", "/nonexistent_path_12345"]) + assert result.exit_code != 0 + assert "No such file" in result.stderr or result.exit_code > 0 + + @pytest.mark.asyncio + async def test_execute_python_code(self, executor: SandboxExecutor) -> None: + result = await executor.execute_python("print('hello from python')") + assert result.exit_code in (0, -1) + assert 
"hello from python" in result.stdout + + @pytest.mark.asyncio + async def test_execution_result_structure(self, executor: SandboxExecutor) -> None: + result = await executor.execute(["echo", "test"]) + assert isinstance(result, ExecutionResult) + assert result.execution_id.startswith("exec_") + assert result.exit_code in (0, -1) + assert result.stdout + assert result.duration_ms >= 0 + + +@pytest.mark.skipif(not DOCKER_AVAILABLE, reason="docker not available") +class TestDockerSandbox: + """Tests for DockerSandbox Docker container management.""" + + @pytest.fixture + def sandbox(self) -> DockerSandbox: + return DockerSandbox() + + def test_spawn_creates_container(self, sandbox: DockerSandbox) -> None: + container_id = sandbox.spawn( + name="test_sandbox", + image="alpine:latest", + services=[], + ) + assert container_id + sandbox.destroy() + + def test_execute_requires_container(self) -> None: + sandbox = DockerSandbox() + with pytest.raises(RuntimeError, match="No container spawned"): + sandbox.execute("echo hello", timeout=5) + + def test_execute_in_container(self, sandbox: DockerSandbox) -> None: + sandbox.spawn(name="test_exec", image="alpine:latest", services=[]) + output = sandbox.execute("echo hello from container", timeout=10) + assert "hello from container" in output + sandbox.destroy() + + def test_destroy_removes_container(self, sandbox: DockerSandbox) -> None: + sandbox.spawn(name="test_destroy", image="alpine:latest", services=[]) + sandbox.destroy() + + +class TestEBPFCollector: + """Tests for EBPFCollector eBPF tracing.""" + + @pytest.fixture + def collector(self) -> EBPFCollector: + return EBPFCollector() + + def test_start_trace_returns_trace_id(self, collector: EBPFCollector) -> None: + trace_id = collector.start_trace(session_id="session_1") + assert trace_id + assert len(trace_id) == 36 + + def test_stop_trace_removes_active(self, collector: EBPFCollector) -> None: + trace_id = collector.start_trace(session_id="session_1") + collector.stop_trace(trace_id) + assert trace_id not in collector._active_traces + + def test_stop_nonexistent_trace_logs_error(self, collector: EBPFCollector) -> None: + collector.stop_trace("nonexistent_trace_id") + + def test_get_traces_returns_list(self, collector: EBPFCollector) -> None: + traces = collector.get_traces() + assert isinstance(traces, list) + + def test_multiple_traces(self, collector: EBPFCollector) -> None: + trace_id1 = collector.start_trace(session_id="session_1") + trace_id2 = collector.start_trace(session_id="session_2") + assert trace_id1 != trace_id2 + assert len(collector._active_traces) == 2 + collector.stop_trace(trace_id1) + assert len(collector._active_traces) == 1 diff --git a/tests/test_integration_vector_store.py b/tests/test_integration_vector_store.py new file mode 100644 index 0000000..3fe8cd5 --- /dev/null +++ b/tests/test_integration_vector_store.py @@ -0,0 +1,312 @@ +"""Integration tests for ChromaVectorStore.""" + +from __future__ import annotations + +import sys + +import pysqlite3 + +sys.modules["sqlite3"] = pysqlite3 # type: ignore[assignment] + +import pytest + +pytestmark = pytest.mark.asyncio + +try: + from smp.store.chroma_store import ChromaVectorStore + + CHROMA_AVAILABLE = True +except Exception as e: # noqa: BLE001 + CHROMA_AVAILABLE = False + _CHROMA_IMPORT_ERROR = str(e) + +skip_if_no_chroma = pytest.mark.skipif(not CHROMA_AVAILABLE, reason="ChromaDB unavailable") + +_DIM = 8 +_VEC_A = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8] +_VEC_B = [0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1] +_VEC_C = [0.5, 
0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5] + + +@pytest.fixture() +async def store() -> ChromaVectorStore: + """Provide a connected in-memory ChromaVectorStore.""" + s = ChromaVectorStore(collection_name="test_collection") + await s.connect() + yield s + await s.close() + + +@skip_if_no_chroma +class TestInit: + def test_defaults(self) -> None: + s = ChromaVectorStore() + assert s._collection_name == "smp_code_embeddings" + assert s._persist_dir is None + assert s._client is None + assert s._collection is None + + def test_custom_params(self) -> None: + s = ChromaVectorStore(collection_name="my_col", persist_dir="/tmp/chroma") + assert s._collection_name == "my_col" + assert s._persist_dir == "/tmp/chroma" + + +@skip_if_no_chroma +class TestConnect: + async def test_connect_in_memory(self) -> None: + s = ChromaVectorStore(collection_name="conn_test") + await s.connect() + assert s._client is not None + assert s._collection is not None + await s.close() + + async def test_close_resets_state(self) -> None: + s = ChromaVectorStore(collection_name="close_test") + await s.connect() + await s.close() + assert s._client is None + assert s._collection is None + + +@skip_if_no_chroma +class TestUpsert: + async def test_upsert_multiple(self, store: ChromaVectorStore) -> None: + await store.upsert( + ids=["id1", "id2"], + embeddings=[_VEC_A, _VEC_B], + metadatas=[{"file_path": "a.py"}, {"file_path": "b.py"}], + documents=["doc a", "doc b"], + ) + + async def test_upsert_without_documents(self, store: ChromaVectorStore) -> None: + await store.upsert( + ids=["id3"], + embeddings=[_VEC_C], + metadatas=[{"file_path": "c.py"}], + ) + + async def test_upsert_raises_when_not_connected(self) -> None: + s = ChromaVectorStore() + with pytest.raises(RuntimeError, match="not connected"): + await s.upsert(ids=["x"], embeddings=[_VEC_A], metadatas=[{}]) + + async def test_upsert_overwrites_existing(self, store: ChromaVectorStore) -> None: + await store.upsert(ids=["dup"], embeddings=[_VEC_A], metadatas=[{"v": "1"}], documents=["first"]) + await store.upsert(ids=["dup"], embeddings=[_VEC_B], metadatas=[{"v": "2"}], documents=["second"]) + results = await store.get(["dup"]) + assert len(results) == 1 + assert results[0] is not None + assert results[0]["document"] == "second" + + +@skip_if_no_chroma +class TestQuery: + async def test_query_returns_top_k(self, store: ChromaVectorStore) -> None: + await store.upsert( + ids=["q1", "q2", "q3"], + embeddings=[_VEC_A, _VEC_B, _VEC_C], + metadatas=[{"file_path": "f.py"}] * 3, + documents=["d1", "d2", "d3"], + ) + results = await store.query(embedding=_VEC_A, top_k=2) + assert len(results) == 2 + + async def test_query_result_structure(self, store: ChromaVectorStore) -> None: + await store.upsert( + ids=["struct1"], + embeddings=[_VEC_A], + metadatas=[{"file_path": "s.py", "kind": "function"}], + documents=["source code"], + ) + results = await store.query(embedding=_VEC_A, top_k=1) + assert len(results) == 1 + r = results[0] + assert "id" in r + assert "score" in r + assert "metadata" in r + assert "document" in r + + async def test_query_raises_when_not_connected(self) -> None: + s = ChromaVectorStore() + with pytest.raises(RuntimeError, match="not connected"): + await s.query(embedding=_VEC_A) + + async def test_query_with_where_filter(self, store: ChromaVectorStore) -> None: + await store.upsert( + ids=["f1", "f2"], + embeddings=[_VEC_A, _VEC_B], + metadatas=[{"file_path": "match.py"}, {"file_path": "other.py"}], + documents=["m", "o"], + ) + results = await 
store.query(embedding=_VEC_A, top_k=5, where={"file_path": "match.py"}) + assert all(r["metadata"]["file_path"] == "match.py" for r in results) + + +@skip_if_no_chroma +class TestGet: + async def test_get_by_ids(self, store: ChromaVectorStore) -> None: + await store.upsert( + ids=["get1", "get2"], + embeddings=[_VEC_A, _VEC_B], + metadatas=[{"file_path": "g1.py"}, {"file_path": "g2.py"}], + documents=["doc1", "doc2"], + ) + results = await store.get(["get1", "get2"]) + assert len(results) == 2 + ids_returned = {r["id"] for r in results if r} + assert "get1" in ids_returned + assert "get2" in ids_returned + + async def test_get_result_structure(self, store: ChromaVectorStore) -> None: + await store.upsert(ids=["gs1"], embeddings=[_VEC_A], metadatas=[{"x": "y"}], documents=["hello"]) + results = await store.get(["gs1"]) + assert len(results) == 1 + r = results[0] + assert r is not None + assert r["id"] == "gs1" + assert r["metadata"] == {"x": "y"} + assert r["document"] == "hello" + + async def test_get_raises_when_not_connected(self) -> None: + s = ChromaVectorStore() + with pytest.raises(RuntimeError, match="not connected"): + await s.get(["x"]) + + +@skip_if_no_chroma +class TestDelete: + async def test_delete_by_ids(self, store: ChromaVectorStore) -> None: + await store.upsert( + ids=["del1", "del2", "del3"], + embeddings=[_VEC_A, _VEC_B, _VEC_C], + metadatas=[{"file_path": "d.py"}] * 3, + documents=["a", "b", "c"], + ) + count = await store.delete(["del1", "del2"]) + assert count == 2 + + async def test_delete_removes_items(self, store: ChromaVectorStore) -> None: + await store.upsert(ids=["rm1"], embeddings=[_VEC_A], metadatas=[{"file_path": "rm.py"}], documents=["x"]) + await store.delete(["rm1"]) + results = await store.get(["rm1"]) + assert results == [] + + async def test_delete_raises_when_not_connected(self) -> None: + s = ChromaVectorStore() + with pytest.raises(RuntimeError, match="not connected"): + await s.delete(["x"]) + + +@skip_if_no_chroma +class TestDeleteByFile: + async def test_delete_by_file(self, store: ChromaVectorStore) -> None: + await store.upsert( + ids=["dbf1", "dbf2", "dbf3"], + embeddings=[_VEC_A, _VEC_B, _VEC_C], + metadatas=[ + {"file_path": "target.py"}, + {"file_path": "target.py"}, + {"file_path": "keep.py"}, + ], + documents=["a", "b", "c"], + ) + await store.delete_by_file("target.py") + results = await store.query(embedding=_VEC_A, top_k=10, where={"file_path": "target.py"}) + assert results == [] + + async def test_delete_by_file_returns_minus_one(self, store: ChromaVectorStore) -> None: + await store.upsert(ids=["dbf4"], embeddings=[_VEC_A], metadatas=[{"file_path": "z.py"}], documents=["z"]) + ret = await store.delete_by_file("z.py") + assert ret == -1 + + async def test_delete_by_file_raises_when_not_connected(self) -> None: + s = ChromaVectorStore() + with pytest.raises(RuntimeError, match="not connected"): + await s.delete_by_file("x.py") + + +@skip_if_no_chroma +class TestAddCodeEmbedding: + async def test_add_code_embedding(self, store: ChromaVectorStore) -> None: + await store.add_code_embedding( + node_id="node_func_foo", + embedding=_VEC_A, + metadata={"file_path": "foo.py", "kind": "function", "name": "foo"}, + document="def foo(): pass", + ) + results = await store.get(["node_func_foo"]) + assert len(results) == 1 + r = results[0] + assert r is not None + assert r["id"] == "node_func_foo" + assert r["document"] == "def foo(): pass" + + async def test_add_code_embedding_default_document(self, store: ChromaVectorStore) -> None: + await 
store.add_code_embedding( + node_id="node_no_doc", + embedding=_VEC_B, + metadata={"file_path": "bar.py"}, + ) + results = await store.get(["node_no_doc"]) + assert results[0] is not None + assert results[0]["document"] == "" + + +@skip_if_no_chroma +class TestQuerySimilar: + async def test_query_similar_returns_list(self, store: ChromaVectorStore) -> None: + await store.upsert( + ids=["qs1", "qs2"], + embeddings=[_VEC_A, _VEC_B], + metadatas=[{"file_path": "q.py"}] * 2, + documents=["d1", "d2"], + ) + results = await store.query_similar(embedding=_VEC_A, top_k=2) + assert isinstance(results, list) + assert len(results) == 2 + + async def test_query_similar_result_keys(self, store: ChromaVectorStore) -> None: + await store.upsert(ids=["qs3"], embeddings=[_VEC_C], metadatas=[{"file_path": "r.py"}], documents=["doc"]) + results = await store.query_similar(embedding=_VEC_C, top_k=1) + assert len(results) >= 1 + r = results[0] + for key in ("id", "score", "metadata", "document"): + assert key in r + + async def test_query_similar_with_where(self, store: ChromaVectorStore) -> None: + await store.upsert( + ids=["wh1", "wh2"], + embeddings=[_VEC_A, _VEC_B], + metadatas=[{"file_path": "inc.py"}, {"file_path": "exc.py"}], + documents=["i", "e"], + ) + results = await store.query_similar(embedding=_VEC_A, top_k=5, where={"file_path": "inc.py"}) + assert all(r["metadata"]["file_path"] == "inc.py" for r in results) + + +@skip_if_no_chroma +class TestClear: + async def test_clear_empties_collection(self, store: ChromaVectorStore) -> None: + await store.upsert( + ids=["c1", "c2"], + embeddings=[_VEC_A, _VEC_B], + metadatas=[{"file_path": "x.py"}] * 2, + documents=["a", "b"], + ) + await store.clear() + results = await store.get(["c1", "c2"]) + assert results == [] + + async def test_clear_allows_new_inserts(self, store: ChromaVectorStore) -> None: + await store.upsert(ids=["old"], embeddings=[_VEC_A], metadatas=[{"file_path": "old.py"}], documents=["old"]) + await store.clear() + await store.upsert(ids=["new"], embeddings=[_VEC_B], metadatas=[{"file_path": "new.py"}], documents=["new"]) + results = await store.get(["new"]) + assert len(results) == 1 + assert results[0]["document"] == "new" + + async def test_clear_raises_when_not_connected(self) -> None: + s = ChromaVectorStore() + with pytest.raises(RuntimeError, match="not connected"): + await s.clear() diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..d290acb --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,301 @@ +"""Tests for core msgspec data models — SMP(3) partitioned schema.""" + +from __future__ import annotations + +import msgspec + +from smp.core.models import ( + AnnotateParams, + AuditGetParams, + CheckpointParams, + ContextParams, + Document, + DryRunParams, + EdgeType, + EnrichBatchParams, + EnrichParams, + FlowParams, + GraphEdge, + GraphNode, + GuardCheckParams, + ImpactParams, + JsonRpcError, + JsonRpcRequest, + JsonRpcResponse, + Language, + LocateParams, + LockParams, + NavigateParams, + NodeType, + ParseError, + RollbackParams, + SearchParams, + SemanticProperties, + SessionCloseParams, + SessionOpenParams, + StructuralProperties, + TagParams, + TraceParams, + UpdateParams, +) +from tests.conftest import make_edge, make_node + + +class TestGraphNode: + def test_defaults(self) -> None: + node = make_node() + assert node.id == "func_login" + assert node.type == NodeType.FUNCTION + assert node.structural.name == "login" + assert node.structural.start_line == 10 + + def 
test_fingerprint(self) -> None: + node = make_node() + assert node.fingerprint() == "src/auth/login.py::Function::login::10" + + def test_serialization_roundtrip(self) -> None: + node = make_node( + semantic=SemanticProperties( + docstring="Authenticate user.", + status="enriched", + source_hash="abc123", + ) + ) + data = msgspec.json.encode(node) + decoded = msgspec.json.decode(data, type=GraphNode) + assert decoded.id == node.id + assert decoded.semantic is not None + assert decoded.semantic.docstring == "Authenticate user." + assert decoded.semantic.status == "enriched" + + def test_structural_partition(self) -> None: + node = make_node() + assert node.structural.signature == "def login(user: User) -> Token:" + assert node.structural.lines == 16 + assert node.structural.complexity == 0 + + def test_semantic_partition(self) -> None: + node = make_node() + assert node.semantic.status == "enriched" + assert node.semantic.docstring == "Authenticate user and return token." + assert node.semantic.decorators == [] + assert node.semantic.tags == [] + + def test_all_node_types(self) -> None: + for nt in NodeType: + node = make_node(id=f"node_{nt.value}", type=nt) + assert node.type == nt + + +class TestGraphEdge: + def test_defaults(self) -> None: + edge = make_edge() + assert edge.source_id == "func_login" + assert edge.target_id == "func_validate" + assert edge.type == EdgeType.CALLS + + def test_serialization_roundtrip(self) -> None: + edge = make_edge() + data = msgspec.json.encode(edge) + decoded = msgspec.json.decode(data, type=GraphEdge) + assert decoded.source_id == edge.source_id + assert decoded.type == edge.type + + def test_all_edge_types(self) -> None: + for et in EdgeType: + edge = make_edge(edge_type=et) + assert edge.type == et + + +class TestStructuralProperties: + def test_defaults(self) -> None: + sp = StructuralProperties() + assert sp.name == "" + assert sp.complexity == 0 + assert sp.parameters == 0 + + def test_frozen(self) -> None: + sp = StructuralProperties(name="test", lines=10) + try: + sp.name = "other" # type: ignore[misc] + except AttributeError: + pass + else: + raise AssertionError("Should raise") + + +class TestSemanticProperties: + def test_defaults(self) -> None: + sp = SemanticProperties() + assert sp.status == "no_metadata" + assert sp.docstring == "" + assert sp.description is None + assert sp.manually_set is False + assert sp.source_hash == "" + + def test_with_annotations(self) -> None: + from smp.core.models import Annotations + + sp = SemanticProperties( + docstring="Test function.", + annotations=Annotations( + params={"x": "int"}, + returns="str", + throws=["ValueError"], + ), + ) + assert sp.annotations is not None + assert sp.annotations.params == {"x": "int"} + assert sp.annotations.returns == "str" + assert sp.annotations.throws == ["ValueError"] + + +class TestDocument: + def test_empty(self) -> None: + doc = Document(file_path="test.py") + assert doc.nodes == [] + assert doc.edges == [] + assert doc.errors == [] + + def test_with_content(self) -> None: + nodes = [ + make_node(), + make_node( + id="func_logout", + structural=StructuralProperties( + name="logout", + file="src/auth/login.py", + signature="def logout():", + start_line=30, + end_line=35, + lines=6, + ), + ), + ] + edges = [make_edge()] + doc = Document( + file_path="src/auth.py", + language=Language.PYTHON, + nodes=nodes, + edges=edges, + ) + assert len(doc.nodes) == 2 + assert len(doc.edges) == 1 + + def test_with_errors(self) -> None: + doc = Document( + file_path="bad.py", + 
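+            # Parse failures are reported as data on the Document rather than raised.
+            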
errors=[ParseError(message="unexpected token", line=5, column=10)], + ) + assert len(doc.errors) == 1 + assert doc.errors[0].line == 5 + + +class TestJsonRpc: + def test_request(self) -> None: + req = JsonRpcRequest(method="smp/navigate", params={"id": "x"}, id=1) + assert req.jsonrpc == "2.0" + assert req.method == "smp/navigate" + + def test_request_serialization(self) -> None: + req = JsonRpcRequest(method="smp/context", params={"file_path": "test.py"}, id=42) + data = msgspec.json.encode(req) + decoded = msgspec.json.decode(data, type=JsonRpcRequest) + assert decoded.id == 42 + assert decoded.params["file_path"] == "test.py" + + def test_response_success(self) -> None: + resp = JsonRpcResponse(result={"nodes": 5}, id=1) + assert resp.error is None + + def test_response_error(self) -> None: + err = JsonRpcError(code=-32601, message="Method not found") + resp = JsonRpcResponse(error=err, id=1) + assert resp.result is None + assert resp.error is not None + assert resp.error.code == -32601 + + +class TestQueryParams: + def test_navigate_params(self) -> None: + p = NavigateParams(query="login") + assert p.include_relationships is True + + def test_trace_params_defaults(self) -> None: + p = TraceParams(start="x") + assert p.relationship == "CALLS" + assert p.depth == 3 + + def test_context_params(self) -> None: + p = ContextParams(file_path="test.py", scope="review") + assert p.scope == "review" + + def test_update_params(self) -> None: + p = UpdateParams(file_path="test.py", content="x = 1") + assert p.language == Language.PYTHON + + def test_impact_params(self) -> None: + p = ImpactParams(entity="x") + assert p.change_type == "delete" + + def test_locate_params(self) -> None: + p = LocateParams(query="find auth logic") + assert p.top_k == 5 + + def test_flow_params(self) -> None: + p = FlowParams(start="a", end="b") + assert p.flow_type == "data" + + +class TestSMP3Params: + def test_enrich_params(self) -> None: + p = EnrichParams(node_id="func_x") + assert p.force is False + + def test_enrich_batch_params(self) -> None: + p = EnrichBatchParams(scope="package:src/auth") + assert p.force is False + + def test_session_open_params(self) -> None: + p = SessionOpenParams(agent_id="agent_1", task="fix bug", scope=["src/auth.py"], mode="write") + assert p.mode == "write" + + def test_session_close_params(self) -> None: + p = SessionCloseParams(session_id="ses_1", status="completed") + assert p.status == "completed" + + def test_guard_check_params(self) -> None: + p = GuardCheckParams(session_id="ses_1", target="src/auth.py") + assert p.target == "src/auth.py" + + def test_dryrun_params(self) -> None: + p = DryRunParams(session_id="ses_1", file_path="src/auth.py", proposed_content="x=1") + assert p.proposed_content == "x=1" + + def test_checkpoint_params(self) -> None: + p = CheckpointParams(session_id="ses_1", files=["src/auth.py"]) + assert len(p.files) == 1 + + def test_rollback_params(self) -> None: + p = RollbackParams(session_id="ses_1", checkpoint_id="chk_1") + assert p.checkpoint_id == "chk_1" + + def test_search_params(self) -> None: + p = SearchParams(query="auth login", match="all") + assert p.match == "all" + + def test_annotate_params(self) -> None: + p = AnnotateParams(node_id="func_x", description="Handles login", tags=["auth"]) + assert p.force is False + + def test_tag_params(self) -> None: + p = TagParams(scope="package:src/auth", tags=["billing"], action="add") + assert p.action == "add" + + def test_lock_params(self) -> None: + p = LockParams(session_id="ses_1", 
files=["src/auth.py"]) + assert len(p.files) == 1 + + def test_audit_get_params(self) -> None: + p = AuditGetParams(audit_log_id="aud_1") + assert p.audit_log_id == "aud_1" diff --git a/tests/test_parser.py b/tests/test_parser.py new file mode 100644 index 0000000..21d9fe8 --- /dev/null +++ b/tests/test_parser.py @@ -0,0 +1,226 @@ +"""Tests for the tree-sitter parser layer — SMP(3).""" + +from __future__ import annotations + +from smp.core.models import EdgeType, Language, NodeType +from smp.parser.base import detect_language +from smp.parser.python_parser import PythonParser +from smp.parser.registry import ParserRegistry +from smp.parser.typescript_parser import TypeScriptParser + +# ========================================================================= +# Language detection +# ========================================================================= + + +class TestDetectLanguage: + def test_python(self) -> None: + assert detect_language("foo.py") == Language.PYTHON + + def test_typescript(self) -> None: + assert detect_language("foo.ts") == Language.TYPESCRIPT + + def test_tsx(self) -> None: + assert detect_language("foo.tsx") == Language.TYPESCRIPT + + def test_jsx(self) -> None: + assert detect_language("foo.jsx") == Language.TYPESCRIPT + + def test_unknown(self) -> None: + assert detect_language("foo.rs") == Language.UNKNOWN + + def test_no_extension(self) -> None: + assert detect_language("Makefile") == Language.UNKNOWN + + +# ========================================================================= +# Python parser +# ========================================================================= + + +class TestPythonParser: + def _parse(self, src: str): + p = PythonParser() + return p.parse(src, "test.py") + + def test_empty_file(self) -> None: + doc = self._parse("") + assert len(doc.errors) == 0 + assert any(n.type == NodeType.FILE for n in doc.nodes) + + def test_simple_function(self) -> None: + doc = self._parse("def hello():\n pass\n") + funcs = [n for n in doc.nodes if n.type == NodeType.FUNCTION] + assert len(funcs) == 1 + assert funcs[0].structural.name == "hello" + assert funcs[0].structural.signature == "def hello()" + + def test_typed_function(self) -> None: + doc = self._parse("def add(a: int, b: int) -> int:\n return a + b\n") + funcs = [n for n in doc.nodes if n.type == NodeType.FUNCTION] + assert len(funcs) == 1 + assert funcs[0].structural.name == "add" + assert "int" in funcs[0].structural.signature + + def test_function_with_docstring(self) -> None: + doc = self._parse('def foo():\n """A docstring."""\n pass\n') + funcs = [n for n in doc.nodes if n.type == NodeType.FUNCTION] + assert len(funcs) == 1 + assert funcs[0].semantic.docstring == "A docstring." 
+ + def test_class(self) -> None: + doc = self._parse("class Foo:\n pass\n") + classes = [n for n in doc.nodes if n.type == NodeType.CLASS] + assert len(classes) == 1 + assert classes[0].structural.name == "Foo" + assert classes[0].structural.signature == "class Foo" + + def test_class_with_bases(self) -> None: + doc = self._parse("class Child(Parent):\n pass\n") + classes = [n for n in doc.nodes if n.type == NodeType.CLASS] + assert len(classes) == 1 + assert "Parent" in classes[0].structural.signature + inherits = [e for e in doc.edges if e.type == EdgeType.IMPLEMENTS] + assert len(inherits) == 1 + + def test_method_in_class(self) -> None: + doc = self._parse("class Foo:\n def bar(self):\n pass\n") + funcs = [n for n in doc.nodes if n.type == NodeType.FUNCTION] + bar_funcs = [f for f in funcs if f.structural.name == "bar"] + assert len(bar_funcs) == 1 + + def test_import(self) -> None: + doc = self._parse("import os\nimport sys\n") + imports = [n for n in doc.nodes if n.structural.signature.startswith("import")] + assert len(imports) == 2 + names = {i.structural.name for i in imports} + assert "os" in names + assert "sys" in names + + def test_from_import(self) -> None: + doc = self._parse("from os.path import join\n") + imports = [n for n in doc.nodes if n.structural.signature.startswith("from")] + assert len(imports) == 1 + assert "os.path" in imports[0].structural.name + + def test_call_edge(self) -> None: + doc = self._parse("def a():\n b()\n") + calls = [e for e in doc.edges if e.type == EdgeType.CALLS] + assert len(calls) == 1 + + def test_decorator(self) -> None: + doc = self._parse("@app.route('/home')\ndef handler():\n pass\n") + funcs = [n for n in doc.nodes if n.type == NodeType.FUNCTION] + assert len(funcs) == 1 + assert "app.route" in funcs[0].semantic.decorators + + def test_contains_edge_file_to_func(self) -> None: + doc = self._parse("def foo():\n pass\n") + defines = [e for e in doc.edges if e.type == EdgeType.DEFINES] + file_defines = [e for e in defines if "File" in e.source_id] + assert len(file_defines) >= 1 + + def test_contains_edge_class_to_method(self) -> None: + doc = self._parse("class Foo:\n def bar(self):\n pass\n") + defines = [e for e in doc.edges if e.type == EdgeType.DEFINES] + class_defines = [e for e in defines if "Class" in e.source_id] + assert len(class_defines) == 1 + + def test_no_duplicate_nodes(self) -> None: + doc = self._parse("class Foo:\n def bar(self):\n pass\n") + ids = [n.id for n in doc.nodes] + assert len(ids) == len(set(ids)) + + def test_syntax_error_partial(self) -> None: + doc = self._parse("def foo(\n pass\n") + assert any(n.type == NodeType.FILE for n in doc.nodes) + assert len(doc.errors) > 0 + + def test_nested_class(self) -> None: + doc = self._parse("class Outer:\n class Inner:\n def deep(self):\n pass\n") + classes = [n for n in doc.nodes if n.type == NodeType.CLASS] + assert len(classes) == 2 + names = {c.structural.name for c in classes} + assert names == {"Outer", "Inner"} + + def test_multiple_functions(self) -> None: + doc = self._parse("def a():\n pass\n\ndef b():\n pass\n\ndef c():\n pass\n") + funcs = [n for n in doc.nodes if n.type == NodeType.FUNCTION] + assert len(funcs) == 3 + + def test_annotations_extracted(self) -> None: + doc = self._parse("def add(a: int, b: int) -> int:\n return a + b\n") + funcs = [n for n in doc.nodes if n.type == NodeType.FUNCTION] + assert len(funcs) == 1 + ann = funcs[0].semantic.annotations + assert ann is not None + assert "a" in ann.params + assert "b" in ann.params + assert 
ann.returns == "int" + + +# ========================================================================= +# TypeScript parser +# ========================================================================= + + +class TestTypeScriptParser: + def _parse(self, src: str, fname: str = "test.ts"): + p = TypeScriptParser() + return p.parse(src, fname) + + def test_empty_file(self) -> None: + doc = self._parse("") + assert len(doc.errors) == 0 + + def test_function_declaration(self) -> None: + doc = self._parse("function hello(): void {\n}\n") + funcs = [n for n in doc.nodes if n.type == NodeType.FUNCTION] + assert len(funcs) == 1 + assert funcs[0].structural.name == "hello" + + def test_class(self) -> None: + doc = self._parse("class Foo {\n bar(): void {}\n}\n") + classes = [n for n in doc.nodes if n.type == NodeType.CLASS] + assert len(classes) == 1 + + def test_no_duplicate_nodes(self) -> None: + doc = self._parse("class Foo {\n bar(): void {}\n}\n") + ids = [n.id for n in doc.nodes] + assert len(ids) == len(set(ids)) + + +# ========================================================================= +# Registry +# ========================================================================= + + +class TestParserRegistry: + def test_get_python(self) -> None: + reg = ParserRegistry() + parser = reg.get(Language.PYTHON) + assert parser is not None + + def test_get_typescript(self) -> None: + reg = ParserRegistry() + parser = reg.get(Language.TYPESCRIPT) + assert parser is not None + + def test_get_unknown(self) -> None: + reg = ParserRegistry() + parser = reg.get(Language.UNKNOWN) + assert parser is None + + def test_parse_file_python(self, tmp_path) -> None: + f = tmp_path / "test.py" + f.write_text("def hello():\n pass\n") + reg = ParserRegistry() + doc = reg.parse_file(str(f)) + assert len(doc.errors) == 0 + funcs = [n for n in doc.nodes if n.type == NodeType.FUNCTION] + assert len(funcs) == 1 + + def test_parse_file_missing(self) -> None: + reg = ParserRegistry() + doc = reg.parse_file("/nonexistent/test.py") + assert len(doc.errors) > 0 diff --git a/tests/test_protocol.py b/tests/test_protocol.py new file mode 100644 index 0000000..5bcd0f9 --- /dev/null +++ b/tests/test_protocol.py @@ -0,0 +1,195 @@ +"""Protocol tests — JSON-RPC 2.0 endpoint testing — SMP(3).""" + +from contextlib import asynccontextmanager + +import msgspec +import pytest +from fastapi import FastAPI, Request +from fastapi.responses import Response +from starlette.testclient import TestClient + +from smp.core.models import ( + EdgeType, + GraphEdge, + GraphNode, + NodeType, + SemanticProperties, + StructuralProperties, +) +from smp.engine.enricher import StaticSemanticEnricher +from smp.engine.graph_builder import DefaultGraphBuilder +from smp.engine.query import DefaultQueryEngine +from smp.parser.registry import ParserRegistry +from smp.protocol.router import handle_rpc +from smp.store.graph.neo4j_store import Neo4jGraphStore + + +def _rpc(method: str, params: dict, req_id: int = 1) -> bytes: + return msgspec.json.encode({"jsonrpc": "2.0", "method": method, "params": params, "id": req_id}) + + +def _parse(data: bytes) -> dict: + return msgspec.json.decode(data) + + +def _make_node( + id: str, type: NodeType, name: str, file_path: str, start_line: int = 1, end_line: int = 10, docstring: str = "" +) -> GraphNode: + return GraphNode( + id=id, + type=type, + file_path=file_path, + structural=StructuralProperties( + name=name, + file=file_path, + signature=f"{type.value.lower()} {name}", + start_line=start_line, + end_line=end_line, + 
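+ # inclusive span: both start_line and end_line count toward lines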
lines=end_line - start_line + 1, + ), + semantic=SemanticProperties( + docstring=docstring, + status="enriched" if docstring else "no_metadata", + ), + ) + + +@pytest.fixture(scope="module") +def app_client(): + """Create app + stores all within the same event loop via FastAPI lifespan.""" + graph = Neo4jGraphStore() + enricher = StaticSemanticEnricher() + registry = ParserRegistry() + + @asynccontextmanager + async def lifespan(app: FastAPI): + await graph.connect() + await graph.clear() + nodes = [ + _make_node("f.py::File::f.py::1", NodeType.FILE, "f.py", "f.py", 1, 20), + _make_node( + "f.py::Function::alpha::3", NodeType.FUNCTION, "alpha", "f.py", 3, 8, docstring="Alpha function." + ), + _make_node("f.py::Function::beta::10", NodeType.FUNCTION, "beta", "f.py", 10, 15), + ] + edges = [ + GraphEdge(source_id="f.py::File::f.py::1", target_id="f.py::Function::alpha::3", type=EdgeType.DEFINES), + GraphEdge(source_id="f.py::File::f.py::1", target_id="f.py::Function::beta::10", type=EdgeType.DEFINES), + GraphEdge(source_id="f.py::Function::alpha::3", target_id="f.py::Function::beta::10", type=EdgeType.CALLS), + ] + await graph.upsert_nodes(nodes) + await graph.upsert_edges(edges) + + engine = DefaultQueryEngine(graph, enricher) + builder = DefaultGraphBuilder(graph) + app.state.engine = engine + app.state.builder = builder + app.state.enricher = enricher + app.state.registry = registry + app.state.safety = None + yield + await graph.clear() + await graph.close() + + app = FastAPI(lifespan=lifespan) + + @app.post("/rpc") + async def rpc_endpoint(request: Request) -> Response: + return await handle_rpc( + request, + engine=request.app.state.engine, + enricher=request.app.state.enricher, + builder=request.app.state.builder, + registry=request.app.state.registry, + safety=request.app.state.safety, + ) + + @app.get("/health") + async def health(): + return {"status": "ok"} + + with TestClient(app) as c: + yield c + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_health(app_client): + assert app_client.get("/health").json()["status"] == "ok" + + +def test_navigate(app_client): + body = _parse(app_client.post("/rpc", content=_rpc("smp/navigate", {"query": "f.py::Function::alpha::3"})).content) + assert body["jsonrpc"] == "2.0" + assert body["error"] is None + assert body["result"]["entity"]["name"] == "alpha" + + +def test_navigate_missing(app_client): + body = _parse(app_client.post("/rpc", content=_rpc("smp/navigate", {"query": "nonexistent"})).content) + assert body["error"] is None + assert "error" in body["result"] + + +def test_trace(app_client): + body = _parse(app_client.post("/rpc", content=_rpc("smp/trace", {"start": "f.py::Function::alpha::3"})).content) + assert body["error"] is None + assert "beta" in {n["name"] for n in body["result"]} + + +def test_context(app_client): + body = _parse(app_client.post("/rpc", content=_rpc("smp/context", {"file_path": "f.py"})).content) + assert body["error"] is None + assert len(body["result"]["functions_defined"]) >= 2 + + +def test_impact(app_client): + body = _parse(app_client.post("/rpc", content=_rpc("smp/impact", {"entity": "f.py::Function::beta::10"})).content) + assert body["error"] is None + assert "affected_files" in body["result"] or "severity" in body["result"] + + +def test_locate(app_client): + body = _parse(app_client.post("/rpc", content=_rpc("smp/locate", {"query": "alpha function"})).content) + assert 
body["error"] is None + assert isinstance(body["result"], list) + + +def test_flow(app_client): + body = _parse( + app_client.post( + "/rpc", content=_rpc("smp/flow", {"start": "f.py::Function::alpha::3", "end": "f.py::Function::beta::10"}) + ).content + ) + assert body["error"] is None + assert len(body["result"]["path"]) >= 1 + + +def test_empty_body(app_client): + body = _parse(app_client.post("/rpc", content=b"").content) + assert body["error"]["code"] == -32700 + + +def test_invalid_json(app_client): + body = _parse(app_client.post("/rpc", content=b"{bad}").content) + assert body["error"]["code"] == -32700 + + +def test_unknown_method(app_client): + body = _parse(app_client.post("/rpc", content=_rpc("smp/nope", {})).content) + assert body["error"]["code"] == -32601 + + +def test_invalid_params(app_client): + body = _parse(app_client.post("/rpc", content=_rpc("smp/navigate", {"wrong": "x"})).content) + assert body["error"]["code"] == -32602 + + +def test_notification(app_client): + payload = msgspec.json.encode( + {"jsonrpc": "2.0", "method": "smp/navigate", "params": {"query": "f.py::Function::alpha::3"}} + ) + assert app_client.post("/rpc", content=payload).status_code == 204 diff --git a/tests/test_query.py b/tests/test_query.py new file mode 100644 index 0000000..2641e83 --- /dev/null +++ b/tests/test_query.py @@ -0,0 +1,277 @@ +"""Tests for the query engine — SMP(3).""" + +from __future__ import annotations + +import pytest + +from smp.core.models import ( + EdgeType, + GraphEdge, + GraphNode, + NodeType, + SemanticProperties, + StructuralProperties, +) +from smp.engine.enricher import StaticSemanticEnricher +from smp.engine.query import DefaultQueryEngine +from smp.store.graph.neo4j_store import Neo4jGraphStore + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture() +async def graph(): + store = Neo4jGraphStore() + await store.connect() + await store.clear() + yield store + await store.clear() + await store.close() + + +@pytest.fixture() +async def engine(graph: Neo4jGraphStore): + enricher = StaticSemanticEnricher() + return DefaultQueryEngine(graph, enricher) + + +def _make_node( + id: str, + type: NodeType, + name: str, + file_path: str, + start_line: int = 1, + end_line: int = 10, + docstring: str = "", + signature: str = "", +) -> GraphNode: + return GraphNode( + id=id, + type=type, + file_path=file_path, + structural=StructuralProperties( + name=name, + file=file_path, + signature=signature or f"{type.value.lower()} {name}", + start_line=start_line, + end_line=end_line, + lines=end_line - start_line + 1, + ), + semantic=SemanticProperties( + docstring=docstring, + status="enriched" if docstring else "no_metadata", + ), + ) + + +async def _seed_graph(graph: Neo4jGraphStore) -> None: + """Seed a small graph for testing: + file.py + ├── import os + ├── func_a() calls func_b() + ├── func_b() calls func_c() + ├── class Service + │ └── method() + └── func_c() + """ + nodes = [ + _make_node("file.py::File::file.py::1", NodeType.FILE, "file.py", "file.py", 1, 30), + _make_node("file.py::File::os::2", NodeType.FILE, "os", "file.py", 2, 2, signature="import os"), + _make_node("file.py::Function::func_a::4", NodeType.FUNCTION, "func_a", "file.py", 4, 8), + _make_node( + "file.py::Function::func_b::10", NodeType.FUNCTION, "func_b", "file.py", 10, 14, docstring="Does B things." 
+ ), + _make_node("file.py::Function::func_c::16", NodeType.FUNCTION, "func_c", "file.py", 16, 20), + _make_node("file.py::Class::Service::22", NodeType.CLASS, "Service", "file.py", 22, 28), + _make_node("file.py::Function::method::23", NodeType.FUNCTION, "method", "file.py", 23, 25), + ] + edges = [ + GraphEdge(source_id="file.py::File::file.py::1", target_id="file.py::File::os::2", type=EdgeType.IMPORTS), + GraphEdge( + source_id="file.py::File::file.py::1", target_id="file.py::Function::func_a::4", type=EdgeType.DEFINES + ), + GraphEdge( + source_id="file.py::File::file.py::1", target_id="file.py::Function::func_b::10", type=EdgeType.DEFINES + ), + GraphEdge( + source_id="file.py::File::file.py::1", target_id="file.py::Function::func_c::16", type=EdgeType.DEFINES + ), + GraphEdge( + source_id="file.py::File::file.py::1", target_id="file.py::Class::Service::22", type=EdgeType.DEFINES + ), + GraphEdge( + source_id="file.py::Function::func_a::4", target_id="file.py::Function::func_b::10", type=EdgeType.CALLS + ), + GraphEdge( + source_id="file.py::Function::func_b::10", target_id="file.py::Function::func_c::16", type=EdgeType.CALLS + ), + GraphEdge( + source_id="file.py::Class::Service::22", target_id="file.py::Function::method::23", type=EdgeType.DEFINES + ), + ] + await graph.upsert_nodes(nodes) + await graph.upsert_edges(edges) + + +# --------------------------------------------------------------------------- +# navigate +# --------------------------------------------------------------------------- + + +class TestNavigate: + @pytest.mark.asyncio + async def test_navigate_node(self, engine: DefaultQueryEngine, graph: Neo4jGraphStore) -> None: + await _seed_graph(graph) + result = await engine.navigate("file.py::Function::func_a::4") + assert "entity" in result + assert result["entity"]["name"] == "func_a" + + @pytest.mark.asyncio + async def test_navigate_missing(self, engine: DefaultQueryEngine, graph: Neo4jGraphStore) -> None: + await _seed_graph(graph) + result = await engine.navigate("nonexistent") + assert "error" in result + + @pytest.mark.asyncio + async def test_navigate_file(self, engine: DefaultQueryEngine, graph: Neo4jGraphStore) -> None: + await _seed_graph(graph) + result = await engine.navigate("file.py::File::file.py::1") + assert result["entity"]["type"] == "File" + + +# --------------------------------------------------------------------------- +# trace +# --------------------------------------------------------------------------- + + +class TestTrace: + @pytest.mark.asyncio + async def test_trace_calls(self, engine: DefaultQueryEngine, graph: Neo4jGraphStore) -> None: + await _seed_graph(graph) + result = await engine.trace("file.py::Function::func_a::4", "CALLS", depth=2) + names = {n["name"] for n in result} + assert "func_b" in names + assert "func_c" in names + + @pytest.mark.asyncio + async def test_trace_depth_1(self, engine: DefaultQueryEngine, graph: Neo4jGraphStore) -> None: + await _seed_graph(graph) + result = await engine.trace("file.py::Function::func_a::4", "CALLS", depth=1) + names = {n["name"] for n in result} + assert "func_b" in names + assert "func_c" not in names + + +# --------------------------------------------------------------------------- +# get_context +# --------------------------------------------------------------------------- + + +class TestGetContext: + @pytest.mark.asyncio + async def test_context_file(self, engine: DefaultQueryEngine, graph: Neo4jGraphStore) -> None: + await _seed_graph(graph) + ctx = await engine.get_context("file.py") 
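+ # the context payload is a dict with "self" and "functions_defined" entries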
+ assert "self" in ctx + assert len(ctx["functions_defined"]) >= 3 + + @pytest.mark.asyncio + async def test_context_empty_file(self, engine: DefaultQueryEngine, graph: Neo4jGraphStore) -> None: + await _seed_graph(graph) + ctx = await engine.get_context("nonexistent.py") + assert "error" in ctx + + +# --------------------------------------------------------------------------- +# assess_impact +# --------------------------------------------------------------------------- + + +class TestAssessImpact: + @pytest.mark.asyncio + async def test_impact(self, engine: DefaultQueryEngine, graph: Neo4jGraphStore) -> None: + await _seed_graph(graph) + result = await engine.assess_impact("file.py::Function::func_b::10") + assert "affected_files" in result or "severity" in result + + @pytest.mark.asyncio + async def test_impact_missing(self, engine: DefaultQueryEngine, graph: Neo4jGraphStore) -> None: + await _seed_graph(graph) + result = await engine.assess_impact("nonexistent") + assert "error" in result or result.get("severity") == "none" + + +# --------------------------------------------------------------------------- +# locate +# --------------------------------------------------------------------------- + + +class TestLocate: + @pytest.mark.asyncio + async def test_locate_by_name(self, engine: DefaultQueryEngine, graph: Neo4jGraphStore) -> None: + await _seed_graph(graph) + result = await engine.locate("func_b") + assert len(result) > 0 + assert result[0]["entity"] == "func_b" + + @pytest.mark.asyncio + async def test_locate_empty(self, engine: DefaultQueryEngine, graph: Neo4jGraphStore) -> None: + await _seed_graph(graph) + result = await engine.locate("") + assert result == [] or isinstance(result, list) + + +# --------------------------------------------------------------------------- +# search +# --------------------------------------------------------------------------- + + +class TestSearch: + @pytest.mark.asyncio + async def test_search_docstring(self, engine: DefaultQueryEngine, graph: Neo4jGraphStore) -> None: + await _seed_graph(graph) + result = await engine.search("B things") + assert result["total"] >= 1 + + @pytest.mark.asyncio + async def test_search_no_match(self, engine: DefaultQueryEngine, graph: Neo4jGraphStore) -> None: + await _seed_graph(graph) + result = await engine.search("xyz_nonexistent_term") + assert result["total"] == 0 + assert "hint" in result + + +# --------------------------------------------------------------------------- +# find_flow +# --------------------------------------------------------------------------- + + +class TestFindFlow: + @pytest.mark.asyncio + async def test_direct_path(self, engine: DefaultQueryEngine, graph: Neo4jGraphStore) -> None: + await _seed_graph(graph) + result = await engine.find_flow( + "file.py::Function::func_a::4", + "file.py::Function::func_b::10", + ) + assert len(result["path"]) >= 1 + assert result["path"][0]["node"] == "func_a" + assert result["path"][-1]["node"] == "func_b" + + @pytest.mark.asyncio + async def test_same_node(self, engine: DefaultQueryEngine, graph: Neo4jGraphStore) -> None: + await _seed_graph(graph) + result = await engine.find_flow( + "file.py::Function::func_a::4", + "file.py::Function::func_a::4", + ) + assert len(result["path"]) == 1 + assert result["path"][0]["node"] == "func_a" + + @pytest.mark.asyncio + async def test_missing_node(self, engine: DefaultQueryEngine, graph: Neo4jGraphStore) -> None: + await _seed_graph(graph) + result = await engine.find_flow("nonexistent", "also_nonexistent") + assert 
result["path"] == [] diff --git a/tests/test_store.py b/tests/test_store.py new file mode 100644 index 0000000..d635192 --- /dev/null +++ b/tests/test_store.py @@ -0,0 +1,213 @@ +"""Tests for Neo4j graph store — SMP(3).""" + +from __future__ import annotations + +import pytest + +from smp.core.models import EdgeType, NodeType, StructuralProperties +from smp.store.graph.neo4j_store import Neo4jGraphStore +from tests.conftest import make_edge, make_node + +# =================================================================== +# Neo4j Graph Store Tests +# =================================================================== + + +class TestNeo4jNodeCRUD: + @pytest.mark.asyncio + async def test_upsert_and_get(self, clean_graph: Neo4jGraphStore) -> None: + node = make_node() + await clean_graph.upsert_node(node) + fetched = await clean_graph.get_node("func_login") + assert fetched is not None + assert fetched.id == "func_login" + assert fetched.structural.name == "login" + assert fetched.type == NodeType.FUNCTION + + @pytest.mark.asyncio + async def test_upsert_updates_existing(self, clean_graph: Neo4jGraphStore) -> None: + node = make_node() + await clean_graph.upsert_node(node) + updated = make_node( + structural=StructuralProperties( + name="login", + file="src/auth/login.py", + signature="def login(user: User, otp: str) -> Token:", + start_line=10, + end_line=25, + lines=16, + ) + ) + await clean_graph.upsert_node(updated) + fetched = await clean_graph.get_node("func_login") + assert fetched is not None + assert "otp" in fetched.structural.signature + + @pytest.mark.asyncio + async def test_upsert_batch(self, clean_graph: Neo4jGraphStore) -> None: + nodes = [ + make_node( + id=f"n{i}", + structural=StructuralProperties( + name=f"n{i}", + file="test.py", + signature=f"def n{i}():", + start_line=i, + end_line=i + 1, + lines=1, + ), + ) + for i in range(10) + ] + await clean_graph.upsert_nodes(nodes) + assert await clean_graph.count_nodes() == 10 + + @pytest.mark.asyncio + async def test_get_missing_node(self, clean_graph: Neo4jGraphStore) -> None: + result = await clean_graph.get_node("nonexistent") + assert result is None + + @pytest.mark.asyncio + async def test_delete_node(self, clean_graph: Neo4jGraphStore) -> None: + await clean_graph.upsert_node(make_node()) + assert await clean_graph.delete_node("func_login") is True + assert await clean_graph.get_node("func_login") is None + + @pytest.mark.asyncio + async def test_delete_missing_returns_false(self, clean_graph: Neo4jGraphStore) -> None: + assert await clean_graph.delete_node("nope") is False + + @pytest.mark.asyncio + async def test_delete_by_file(self, clean_graph: Neo4jGraphStore) -> None: + nodes = [ + make_node(id="a", file_path="f1.py"), + make_node(id="b", file_path="f1.py"), + make_node(id="c", file_path="f2.py"), + ] + await clean_graph.upsert_nodes(nodes) + deleted = await clean_graph.delete_nodes_by_file("f1.py") + assert deleted == 2 + assert await clean_graph.count_nodes() == 1 + + +class TestNeo4jEdgeCRUD: + @pytest.mark.asyncio + async def test_upsert_edge(self, clean_graph: Neo4jGraphStore) -> None: + await clean_graph.upsert_nodes([make_node(id="a"), make_node(id="b")]) + edge = make_edge(source="a", target="b", edge_type=EdgeType.CALLS) + await clean_graph.upsert_edge(edge) + edges = await clean_graph.get_edges("a", direction="outgoing") + assert len(edges) == 1 + assert edges[0].target_id == "b" + + @pytest.mark.asyncio + async def test_upsert_edges_batch(self, clean_graph: Neo4jGraphStore) -> None: + nodes = 
[make_node(id=f"n{i}") for i in range(5)] + await clean_graph.upsert_nodes(nodes) + edges = [make_edge(source=f"n{i}", target=f"n{i + 1}") for i in range(4)] + await clean_graph.upsert_edges(edges) + total = await clean_graph.count_edges() + assert total == 4 + + @pytest.mark.asyncio + async def test_get_edges_by_type(self, clean_graph: Neo4jGraphStore) -> None: + await clean_graph.upsert_nodes([make_node(id="x"), make_node(id="y"), make_node(id="z")]) + await clean_graph.upsert_edge(make_edge(source="x", target="y", edge_type=EdgeType.CALLS)) + await clean_graph.upsert_edge(make_edge(source="x", target="z", edge_type=EdgeType.IMPORTS)) + calls = await clean_graph.get_edges("x", edge_type=EdgeType.CALLS) + assert len(calls) == 1 + + @pytest.mark.asyncio + async def test_incoming_edges(self, clean_graph: Neo4jGraphStore) -> None: + await clean_graph.upsert_nodes([make_node(id="x"), make_node(id="y")]) + await clean_graph.upsert_edge(make_edge(source="x", target="y")) + incoming = await clean_graph.get_edges("y", direction="incoming") + assert len(incoming) == 1 + assert incoming[0].source_id == "x" + + +class TestNeo4jTraversal: + @pytest.mark.asyncio + async def test_neighbors(self, clean_graph: Neo4jGraphStore) -> None: + nodes = [make_node(id=f"n{i}") for i in range(4)] + await clean_graph.upsert_nodes(nodes) + edges = [ + make_edge(source="n0", target="n1"), + make_edge(source="n0", target="n2"), + make_edge(source="n1", target="n3"), + ] + await clean_graph.upsert_edges(edges) + neighbors = await clean_graph.get_neighbors("n0", depth=1) + ids = {n.id for n in neighbors} + assert ids == {"n1", "n2"} + + @pytest.mark.asyncio + async def test_traverse(self, clean_graph: Neo4jGraphStore) -> None: + nodes = [make_node(id=f"n{i}") for i in range(5)] + await clean_graph.upsert_nodes(nodes) + edges = [make_edge(source=f"n{i}", target=f"n{i + 1}") for i in range(4)] + await clean_graph.upsert_edges(edges) + result = await clean_graph.traverse("n0", EdgeType.CALLS, depth=3) + ids = {n.id for n in result} + assert "n1" in ids + assert "n3" in ids + + +class TestNeo4jSearch: + @pytest.mark.asyncio + async def test_find_by_type(self, clean_graph: Neo4jGraphStore) -> None: + await clean_graph.upsert_nodes( + [ + make_node(id="f1", type=NodeType.FUNCTION), + make_node(id="c1", type=NodeType.CLASS), + ] + ) + funcs = await clean_graph.find_nodes(type=NodeType.FUNCTION) + assert len(funcs) == 1 + assert funcs[0].id == "f1" + + @pytest.mark.asyncio + async def test_find_by_file(self, clean_graph: Neo4jGraphStore) -> None: + await clean_graph.upsert_nodes( + [ + make_node(id="a", file_path="x.py"), + make_node(id="b", file_path="y.py"), + ] + ) + result = await clean_graph.find_nodes(file_path="x.py") + assert len(result) == 1 + + @pytest.mark.asyncio + async def test_find_by_name(self, clean_graph: Neo4jGraphStore) -> None: + await clean_graph.upsert_nodes( + [ + make_node( + id="a", + structural=StructuralProperties( + name="login", file="test.py", signature="", start_line=1, end_line=5, lines=5 + ), + ), + make_node( + id="b", + structural=StructuralProperties( + name="logout", file="test.py", signature="", start_line=1, end_line=5, lines=5 + ), + ), + ] + ) + result = await clean_graph.find_nodes(name="login") + assert len(result) == 1 + + +class TestNeo4jCounts: + @pytest.mark.asyncio + async def test_empty_counts(self, clean_graph: Neo4jGraphStore) -> None: + assert await clean_graph.count_nodes() == 0 + assert await clean_graph.count_edges() == 0 + + @pytest.mark.asyncio + async def 
test_counts_after_inserts(self, clean_graph: Neo4jGraphStore) -> None: + await clean_graph.upsert_nodes([make_node(id="a"), make_node(id="b")]) + await clean_graph.upsert_edge(make_edge(source="a", target="b")) + assert await clean_graph.count_nodes() == 2 + assert await clean_graph.count_edges() == 1 diff --git a/tests/test_update.py b/tests/test_update.py new file mode 100644 index 0000000..e37d7ee --- /dev/null +++ b/tests/test_update.py @@ -0,0 +1,223 @@ +"""Tests for the incremental update flow (smp/update).""" + +from __future__ import annotations + +from contextlib import asynccontextmanager + +import msgspec +import pytest +from fastapi import FastAPI, Request +from fastapi.responses import Response +from starlette.testclient import TestClient + +from smp.engine.enricher import StaticSemanticEnricher +from smp.engine.graph_builder import DefaultGraphBuilder +from smp.engine.query import DefaultQueryEngine +from smp.parser.registry import ParserRegistry +from smp.protocol.router import handle_rpc +from smp.store.graph.neo4j_store import Neo4jGraphStore + + +def _rpc(method: str, params: dict, req_id: int = 1) -> bytes: + return msgspec.json.encode({"jsonrpc": "2.0", "method": method, "params": params, "id": req_id}) + + +def _parse(data: bytes) -> dict: + return msgspec.json.decode(data) + + +@pytest.fixture(scope="module") +def client(): + graph = Neo4jGraphStore() + enricher = StaticSemanticEnricher() + registry = ParserRegistry() + + @asynccontextmanager + async def lifespan(app: FastAPI): + await graph.connect() + await graph.clear() + app.state.engine = DefaultQueryEngine(graph, enricher) + app.state.builder = DefaultGraphBuilder(graph) + app.state.enricher = enricher + app.state.registry = registry + yield + await graph.clear() + await graph.close() + + app = FastAPI(lifespan=lifespan) + + @app.post("/rpc") + async def rpc(request: Request) -> Response: + return await handle_rpc( + request, + engine=request.app.state.engine, + enricher=request.app.state.enricher, + builder=request.app.state.builder, + registry=request.app.state.registry, + ) + + with TestClient(app) as c: + yield c + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_update_new_file(client: TestClient): + """Updating a new file ingests its nodes and edges.""" + body = _parse( + client.post( + "/rpc", + content=_rpc( + "smp/update", + { + "file_path": "new_file.py", + "content": "def greet(name: str) -> str:\n return f'Hello {name}'\n", + }, + ), + ).content + ) + assert body["error"] is None + result = body["result"] + assert result["file_path"] == "new_file.py" + assert result.get("nodes", 0) > 0 + + +def test_update_replaces_old_data(client: TestClient): + """Updating an existing file replaces old nodes with new ones.""" + # First update + body1 = _parse( + client.post( + "/rpc", + content=_rpc( + "smp/update", + { + "file_path": "replace.py", + "content": "def old_func():\n pass\n", + }, + ), + ).content + ) + count1 = body1["result"].get("nodes", 0) + + # Second update with different content + body2 = _parse( + client.post( + "/rpc", + content=_rpc( + "smp/update", + { + "file_path": "replace.py", + "content": "def new_func_a():\n pass\n\ndef new_func_b():\n pass\n", + }, + ), + ).content + ) + count2 = body2["result"].get("nodes", 0) + assert count2 >= count1 # more or equal nodes in the second version + + +def test_update_context_after_update(client: TestClient): + """After 
updating, smp/context reflects the new state.""" + # Update + client.post( + "/rpc", + content=_rpc( + "smp/update", + { + "file_path": "ctx_test.py", + "content": "class MyClass:\n def method(self):\n pass\n", + }, + ), + ) + + # Query context + body = _parse( + client.post( + "/rpc", + content=_rpc( + "smp/context", + { + "file_path": "ctx_test.py", + }, + ), + ).content + ) + assert body["error"] is None + result = body["result"] + # Accept any valid response format + assert "functions_defined" in result or "self" in result or "classes" in result + + +def test_update_enriches_nodes(client: TestClient): + """Updated nodes get semantic enrichment.""" + body = _parse( + client.post( + "/rpc", + content=_rpc( + "smp/update", + { + "file_path": "enrich_test.py", + "content": 'def authenticate(user, password):\n """Validates user credentials."""\n pass\n', + }, + ), + ).content + ) + assert body["error"] is None + result = body["result"] + # at minimum, the file node and the authenticate function must be ingested + assert result.get("nodes", 0) >= 1 + + +def test_update_syntax_error_graceful(client: TestClient): + """Updating with broken syntax doesn't crash the server.""" + body = _parse( + client.post( + "/rpc", + content=_rpc( + "smp/update", + { + "file_path": "broken.py", + "content": "def broken(\n pass\n", + }, + ), + ).content + ) + # Must not crash: the server may return a partial result or a structured JSON-RPC error + assert "result" in body or "error" in body + + +def test_update_empty_content(client: TestClient): + """Updating with empty content returns 0 nodes.""" + body = _parse( + client.post( + "/rpc", + content=_rpc( + "smp/update", + { + "file_path": "empty.py", + "content": "", + }, + ), + ).content + ) + assert body["result"]["nodes"] == 0 + + +def test_update_typescript(client: TestClient): + """Updating a TypeScript file works.""" + body = _parse( + client.post( + "/rpc", + content=_rpc( + "smp/update", + { + "file_path": "handler.ts", + "content": "export function handle(req: Request): Response {\n return new Response();\n}\n", + "language": "typescript", + }, + ), + ).content + ) + assert body["error"] is None + assert body["result"]["nodes"] > 0 From f69cd8d46c8dfb1f3f488ad0f8c5a4c7102b81f4 Mon Sep 17 00:00:00 2001 From: offx-zinth Date: Sun, 19 Apr 2026 12:07:06 +0530 Subject: [PATCH 2/2] rose v1 --- smp/protocol/server.py | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/smp/protocol/server.py b/smp/protocol/server.py index 80cdf1a..2e382a8 100644 --- a/smp/protocol/server.py +++ b/smp/protocol/server.py @@ -6,15 +6,10 @@ from __future__ import annotations try: -<<<<<<< HEAD import sys import pysqlite3 -======= - import pysqlite3 - import sys ->>>>>>> 87cfd9650622e51c4c94d43d490450a82a87ad3d sys.modules["sqlite3"] = pysqlite3 except ImportError: pass @@ -26,7 +21,6 @@ from fastapi import FastAPI, Request from fastapi.responses import Response -<<<<<<< HEAD from smp.core.merkle import MerkleIndex, MerkleTree from smp.engine.community import CommunityDetector from smp.engine.embedding import create_embedding_service from smp.engine.enricher import StaticSemanticEnricher from smp.engine.graph_builder import DefaultGraphBuilder from smp.engine.seed_walk import SeedWalkEngine from smp.logging import get_logger from smp.parser.registry import ParserRegistry from smp.protocol.dispatcher import handle_rpc from smp.store.chroma_store import ChromaVectorStore from smp.store.graph.neo4j_store import Neo4jGraphStore -======= -from smp.engine.enricher import StaticSemanticEnricher -from smp.engine.graph_builder import DefaultGraphBuilder -from smp.engine.seed_walk import SeedWalkEngine -from smp.engine.community import CommunityDetector -from smp.core.merkle import MerkleIndex -from smp.logging import get_logger -from smp.parser.registry import
ParserRegistry -from smp.protocol.dispatcher import handle_rpc -from smp.store.graph.neo4j_store import Neo4jGraphStore -from smp.store.chroma_store import ChromaVectorStore ->>>>>>> 87cfd9650622e51c4c94d43d490450a82a87ad3d log = get_logger(__name__) @@ -75,7 +57,6 @@ async def lifespan(app: FastAPI): # type: ignore[no-untyped-def] # noqa: ANN20 vector = ChromaVectorStore() await vector.connect() -<<<<<<< HEAD embedding_service = create_embedding_service() await embedding_service.connect() @@ -86,14 +67,6 @@ async def lifespan(app: FastAPI): # type: ignore[no-untyped-def] # noqa: ANN20 builder = DefaultGraphBuilder(graph) registry = ParserRegistry() merkle_index = MerkleIndex(MerkleTree()) -======= - enricher = StaticSemanticEnricher() - community_detector = CommunityDetector(graph_store=graph, vector_store=vector) - engine = SeedWalkEngine(graph_store=graph, vector_store=vector, enricher=enricher) - builder = DefaultGraphBuilder(graph) - registry = ParserRegistry() - merkle_index = MerkleIndex() ->>>>>>> 87cfd9650622e51c4c94d43d490450a82a87ad3d safety: dict[str, Any] | None = None if safety_enabled: