Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ jobs:

- name: Test with pytest
run: |
pytest tests/ -v --cov=sugar --cov-report=xml --cov-report=term-missing --tb=short --ignore=tests/plugin/
pytest tests/ -v --cov=sugar --cov-report=xml --cov-report=term-missing --tb=short --ignore=tests/plugin/ -m 'not benchmark'

- name: Upload coverage to Codecov
uses: codecov/codecov-action@v4
Expand Down
7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,10 @@ testpaths = ["tests"]
python_files = ["test_*.py", "*_test.py"]
addopts = "-v --cov=sugar --cov-branch --cov-report=term-missing --cov-report=xml"
markers = [
"unit: Unit tests",
"integration: Integration tests",
"slow: Slow running tests"
"unit: Unit tests (no I/O, no database)",
"integration: Integration tests (real database, aiosqlite)",
"slow: Slow running tests (throughput, load)",
"benchmark: Throughput benchmarks (skip in CI, run with: pytest -m benchmark)"
]

# MCP Registry identification
Expand Down
9 changes: 8 additions & 1 deletion sugar/agent/subagent_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,10 +316,17 @@ async def cancel_all(self) -> None:

Note: This will attempt graceful shutdown but may not
interrupt tasks that are already executing.

Takes a snapshot of the dict before iterating to avoid
RuntimeError if spawn()'s finally block mutates the dict
concurrently (e.g., during shutdown while tasks are in-flight).
"""
logger.warning(f"Cancelling {len(self._active_subagents)} active sub-agents")

for task_id, subagent in self._active_subagents.items():
# Snapshot to avoid RuntimeError: dictionary changed size during iteration
active_snapshot = dict(self._active_subagents)

for task_id, subagent in active_snapshot.items():
try:
await subagent.end_session()
logger.debug(f"Cancelled sub-agent task: {task_id}")
Expand Down
15 changes: 12 additions & 3 deletions sugar/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,10 @@ def signal_handler(signum, frame):
shutdown_event.set()
logger.info("🔔 Shutdown event triggered")
else:
logger.warning("⚠️ Shutdown event not available")
# Fallback: if shutdown_event isn't ready yet (shouldn't happen
# now that we create it before registering handlers), exit cleanly.
logger.warning("⚠️ Shutdown event not available, forcing exit")
sys.exit(128 + signum)


@click.group(invoke_without_command=True)
Expand Down Expand Up @@ -2035,6 +2038,11 @@ def run(ctx, dry_run, once, validate):
asyncio.run(validate_config(sugar_loop))
return

# Create shutdown event BEFORE registering signal handlers to avoid
# a window where a signal arrives but the event doesn't exist yet.
global shutdown_event
shutdown_event = asyncio.Event()

# Set up signal handlers for graceful shutdown
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
Expand Down Expand Up @@ -2117,9 +2125,10 @@ async def run_once(sugar_loop):


async def run_continuous(sugar_loop):
"""Run Sugar continuously"""
"""Run Sugar continuously (shutdown_event created before signal handlers)"""
global shutdown_event
shutdown_event = asyncio.Event()
if shutdown_event is None:
shutdown_event = asyncio.Event()

# Create PID file for stop command
import os
Expand Down
162 changes: 90 additions & 72 deletions sugar/memory/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import json
import logging
import sqlite3
import threading
import uuid
from datetime import datetime, timezone
from pathlib import Path
Expand Down Expand Up @@ -59,6 +60,7 @@ def __init__(
self.embedder = embedder or create_embedder()
self._has_vec = self._check_sqlite_vec()
self._conn: Optional[sqlite3.Connection] = None
self._lock = threading.Lock()

self._init_db()

Expand All @@ -73,9 +75,14 @@ def _check_sqlite_vec(self) -> bool:
return False

def _get_connection(self) -> sqlite3.Connection:
"""Get or create database connection."""
"""Get or create database connection.

Uses check_same_thread=False to allow safe cross-thread use when
called from asyncio's run_in_executor. Thread safety is ensured
by self._lock around all public methods that access the connection.
"""
if self._conn is None:
self._conn = sqlite3.connect(str(self.db_path))
self._conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
self._conn.row_factory = sqlite3.Row

if self._has_vec:
Expand Down Expand Up @@ -209,96 +216,106 @@ def store(self, entry: MemoryEntry) -> str:
if entry.created_at is None:
entry.created_at = datetime.now(timezone.utc)

conn = self._get_connection()
cursor = conn.cursor()
with self._lock:
conn = self._get_connection()
cursor = conn.cursor()

# Store main entry
cursor.execute(
"""
INSERT OR REPLACE INTO memory_entries
(id, memory_type, source_id, content, summary, metadata,
importance, created_at, last_accessed_at, access_count, expires_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
entry.id,
# Store main entry
cursor.execute(
"""
INSERT OR REPLACE INTO memory_entries
(id, memory_type, source_id, content, summary, metadata,
importance, created_at, last_accessed_at, access_count, expires_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
entry.memory_type.value
if isinstance(entry.memory_type, MemoryType)
else entry.memory_type
entry.id,
(
entry.memory_type.value
if isinstance(entry.memory_type, MemoryType)
else entry.memory_type
),
entry.source_id,
entry.content,
entry.summary,
json.dumps(entry.metadata) if entry.metadata else None,
entry.importance,
entry.created_at.isoformat() if entry.created_at else None,
(
entry.last_accessed_at.isoformat()
if entry.last_accessed_at
else None
),
entry.access_count,
entry.expires_at.isoformat() if entry.expires_at else None,
),
entry.source_id,
entry.content,
entry.summary,
json.dumps(entry.metadata) if entry.metadata else None,
entry.importance,
entry.created_at.isoformat() if entry.created_at else None,
entry.last_accessed_at.isoformat() if entry.last_accessed_at else None,
entry.access_count,
entry.expires_at.isoformat() if entry.expires_at else None,
),
)
)

# Generate and store embedding if we have semantic search
if self._has_vec and not isinstance(self.embedder, FallbackEmbedder):
try:
embedding = self.embedder.embed(entry.content)
if embedding:
cursor.execute(
"""
INSERT OR REPLACE INTO memory_vectors (id, embedding)
VALUES (?, ?)
""",
(entry.id, _serialize_embedding(embedding)),
)
except Exception as e:
logger.warning(f"Failed to store embedding: {e}")
# Generate and store embedding if we have semantic search
if self._has_vec and not isinstance(self.embedder, FallbackEmbedder):
try:
embedding = self.embedder.embed(entry.content)
if embedding:
cursor.execute(
"""
INSERT OR REPLACE INTO memory_vectors (id, embedding)
VALUES (?, ?)
""",
(entry.id, _serialize_embedding(embedding)),
)
except Exception as e:
logger.warning(f"Failed to store embedding: {e}")

conn.commit()
return entry.id
conn.commit()
return entry.id

def get(self, entry_id: str) -> Optional[MemoryEntry]:
"""Get a memory entry by ID."""
conn = self._get_connection()
cursor = conn.cursor()
with self._lock:
conn = self._get_connection()
cursor = conn.cursor()

cursor.execute(
"""
SELECT * FROM memory_entries WHERE id = ?
""",
(entry_id,),
)
cursor.execute(
"""
SELECT * FROM memory_entries WHERE id = ?
""",
(entry_id,),
)

row = cursor.fetchone()
if row:
return self._row_to_entry(row)
return None
row = cursor.fetchone()
if row:
return self._row_to_entry(row)
return None

def delete(self, entry_id: str) -> bool:
"""Delete a memory entry."""
conn = self._get_connection()
cursor = conn.cursor()
with self._lock:
conn = self._get_connection()
cursor = conn.cursor()

cursor.execute("DELETE FROM memory_entries WHERE id = ?", (entry_id,))
cursor.execute("DELETE FROM memory_entries WHERE id = ?", (entry_id,))

if self._has_vec:
try:
cursor.execute("DELETE FROM memory_vectors WHERE id = ?", (entry_id,))
except Exception:
pass
if self._has_vec:
try:
cursor.execute(
"DELETE FROM memory_vectors WHERE id = ?", (entry_id,)
)
except Exception:
pass

conn.commit()
return cursor.rowcount > 0
conn.commit()
return cursor.rowcount > 0

def search(self, query: MemoryQuery) -> List[MemorySearchResult]:
"""
Search memories.

Uses vector similarity if available, falls back to FTS5.
"""
if self._has_vec and not isinstance(self.embedder, FallbackEmbedder):
return self._search_semantic(query)
return self._search_keyword(query)
with self._lock:
if self._has_vec and not isinstance(self.embedder, FallbackEmbedder):
return self._search_semantic(query)
return self._search_keyword(query)

def _search_semantic(self, query: MemoryQuery) -> List[MemorySearchResult]:
"""Search using vector similarity."""
Expand Down Expand Up @@ -638,6 +655,7 @@ def prune_expired(self) -> int:

def close(self):
"""Close database connection."""
if self._conn:
self._conn.close()
self._conn = None
with self._lock:
if self._conn:
self._conn.close()
self._conn = None
52 changes: 29 additions & 23 deletions sugar/storage/issue_response_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Issue Response Manager - Track GitHub issue responses
"""

import asyncio
import json
import logging
import uuid
Expand All @@ -19,41 +20,46 @@ class IssueResponseManager:
def __init__(self, db_path: str = ".sugar/sugar.db"):
self.db_path = db_path
self._initialized = False
self._init_lock = asyncio.Lock()

async def initialize(self) -> None:
"""Create table if not exists"""
if self._initialized:
return

async with aiosqlite.connect(self.db_path) as db:
await db.execute(
async with self._init_lock:
if self._initialized:
return

async with aiosqlite.connect(self.db_path) as db:
await db.execute(
"""
CREATE TABLE IF NOT EXISTS issue_responses (
id TEXT PRIMARY KEY,
repo TEXT NOT NULL,
issue_number INTEGER NOT NULL,
response_type TEXT NOT NULL,
work_item_id TEXT,
confidence REAL,
posted_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
response_content TEXT,
labels_applied TEXT,
was_auto_posted BOOLEAN DEFAULT 0,
UNIQUE(repo, issue_number, response_type)
)
"""
CREATE TABLE IF NOT EXISTS issue_responses (
id TEXT PRIMARY KEY,
repo TEXT NOT NULL,
issue_number INTEGER NOT NULL,
response_type TEXT NOT NULL,
work_item_id TEXT,
confidence REAL,
posted_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
response_content TEXT,
labels_applied TEXT,
was_auto_posted BOOLEAN DEFAULT 0,
UNIQUE(repo, issue_number, response_type)
)
"""
)

await db.execute(
await db.execute(
"""
CREATE INDEX IF NOT EXISTS idx_issue_responses_repo_number
ON issue_responses (repo, issue_number)
"""
CREATE INDEX IF NOT EXISTS idx_issue_responses_repo_number
ON issue_responses (repo, issue_number)
"""
)
)

await db.commit()
await db.commit()

self._initialized = True
self._initialized = True
logger.debug(f"Issue response manager initialized: {self.db_path}")

async def has_responded(
Expand Down
Loading
Loading