Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 98 additions & 29 deletions context_scribe/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import shutil
from pathlib import Path
from datetime import datetime
from typing import Optional
from typing import List, Optional
import click
from rich.console import Console
from rich.live import Live
Expand All @@ -20,6 +20,7 @@
from context_scribe.evaluator import get_evaluator, EVALUATOR_REGISTRY
from context_scribe.bridge.mcp_client import MemoryBankClient


logger = logging.getLogger("context_scribe")
console: Console = Console()

Expand Down Expand Up @@ -163,6 +164,31 @@ def bootstrap_claude_config() -> None:
f.write(f"\n{MASTER_RETRIEVAL_RULE}\n")


# Registry of supported CLI tools: maps each tool name to the provider
# class that watches its logs and the bootstrap function that prepares
# that tool's on-disk configuration. Lookup is by key only, so order is
# cosmetic (kept alphabetical here).
TOOL_REGISTRY = {
    "claude": (ClaudeProvider, bootstrap_claude_config),
    "copilot": (CopilotProvider, bootstrap_copilot_config),
    "gemini-cli": (GeminiCliProvider, bootstrap_global_config),
}


def _create_providers(tools: List[str]):
    """Instantiate and bootstrap one provider per requested tool name.

    Returns a list of ``(tool_name, provider_instance)`` pairs in the same
    order as *tools*.  Raises ``ValueError`` for any name that is not
    registered in ``TOOL_REGISTRY``.
    """
    created = []
    for name in tools:
        registered = TOOL_REGISTRY.get(name)
        if registered is None:
            available = ", ".join(sorted(TOOL_REGISTRY))
            raise ValueError(f"Unknown tool '{name}'. Available: {available}")
        provider_cls, bootstrap_fn = registered
        # Bootstrap first so the provider finds its config when constructed.
        bootstrap_fn()
        created.append((name, provider_cls()))
    return created


def _detect_evaluator(preferred_tool: Optional[str] = None) -> str:
"""Auto-detect which evaluator CLI is available, prioritizing the preferred tool."""
# Map tool names to their corresponding CLI commands
Expand Down Expand Up @@ -209,22 +235,20 @@ def _status(msg: str, db, live, debug: bool):
live.update(db.generate_layout())


async def run_daemon(tool: str, bank_path: str, debug: bool = False, evaluator_name: str = "auto") -> bool:
if tool == "gemini-cli":
bootstrap_global_config()
provider = GeminiCliProvider()
elif tool == "copilot":
bootstrap_copilot_config()
provider = CopilotProvider()
elif tool == "claude":
bootstrap_claude_config()
provider = ClaudeProvider()
async def run_daemon(tool: str, bank_path: str, debug: bool = False, evaluator_name: str = "auto", tools: Optional[List[str]] = None) -> bool:
# Build provider list: --tools takes precedence over --tool
if tools is not None:
if not tools:
raise ValueError("--tools was provided but resolved to an empty list.")
tool_names = tools
else:
provider = None
if not provider: return False
tool_names = [tool]
providers = _create_providers(tool_names)
if not providers:
return False

if evaluator_name == "auto":
evaluator_name = _detect_evaluator(tool)
evaluator_name = _detect_evaluator(tool_names[0])
evaluator = get_evaluator(evaluator_name)
mcp_client = MemoryBankClient(bank_path=bank_path)

Expand All @@ -234,29 +258,54 @@ async def run_daemon(tool: str, bank_path: str, debug: bool = False, evaluator_n
console.print("[bold red]Fatal Error: Could not connect to the Memory Bank MCP server.[/bold red]")
raise SystemExit(1)

db = Dashboard(tool, bank_path)
display_name = ",".join(tool_names)
db = Dashboard(display_name, bank_path)
queue: asyncio.Queue = asyncio.Queue(maxsize=1000)

async def _watch_provider(tool_name: str, provider):
    """Run a provider's blocking watch() iterator in a worker thread and
    feed each interaction into the shared queue as (tool_name, interaction).

    NOTE(review): this is a closure inside run_daemon — it relies on the
    enclosing scope's ``queue``.
    """
    # get_running_loop() is the correct, non-deprecated call inside a
    # coroutine (get_event_loop() emits DeprecationWarnings on recent Pythons).
    loop = asyncio.get_running_loop()
    watch_iter = provider.watch()
    try:
        while True:
            # next() blocks on file-system events, so run it off-loop.
            interaction = await loop.run_in_executor(None, next, watch_iter)
            if interaction is not None:
                await queue.put((tool_name, interaction))
    except (StopIteration, asyncio.CancelledError, KeyboardInterrupt):
        pass
    except Exception as e:
        logger.error("Watcher for %s failed: %s", tool_name, e)
    finally:
        # Close the generator so the provider's own finally block runs
        # (BaseProvider stops/joins its watchdog Observer there); otherwise
        # observer threads can leak on shutdown or early watcher exit.
        close = getattr(watch_iter, "close", None)
        if callable(close):
            try:
                close()
            except Exception as e:
                logger.warning("Failed to close watcher for %s: %s", tool_name, e)

async def _loop(live=None):
watcher_tasks = []
try:
loop = asyncio.get_event_loop()
watch_iter = provider.watch()
# Start a watcher task for each provider
watcher_tasks = [
asyncio.create_task(_watch_provider(name, prov))
for name, prov in providers
]
_status("🔍 Watching log stream...", db, live, debug)

while True:
if live: live.update(db.generate_layout())
interaction = await loop.run_in_executor(None, next, watch_iter)
if interaction is None:
if live:
live.update(db.generate_layout())

# Wait for next interaction from any provider
try:
tool_name, interaction = await asyncio.wait_for(queue.get(), timeout=1.0)
except asyncio.TimeoutError:
continue

_status(f"🤔 Analyzing user message ({interaction.project_name})", db, live, debug)
_status(f"🤔 [{tool_name}] Analyzing user message ({interaction.project_name})", db, live, debug)
if debug:
logging.getLogger("context_scribe").info(" content: %s", interaction.content[:120])
logger.info(" content: %s", interaction.content[:120])

_status(f"📖 Accessing Memory Bank ({interaction.project_name})...", db, live, debug)
_status(f"📖 [{tool_name}] Accessing Memory Bank ({interaction.project_name})...", db, live, debug)
existing_global = await mcp_client.read_rules("global", "global_rules.md")
existing_project = await mcp_client.read_rules(interaction.project_name, "rules.md")

_status(f"🧠 Thinking: Extracting rules for {interaction.project_name}...", db, live, debug)
_status(f"🧠 [{tool_name}] Extracting rules for {interaction.project_name}...", db, live, debug)
loop = asyncio.get_event_loop()
rule_output = await loop.run_in_executor(None, evaluator.evaluate_interaction, interaction, existing_global, existing_project)

if rule_output:
Expand All @@ -276,11 +325,11 @@ async def _loop(live=None):
seen.add(stripped)
deduped_content = "\n".join(unique_lines).strip()

_status(f"📝 Committing: {dest_path}", db, live, debug)
_status(f"📝 [{tool_name}] Committing: {dest_path}", db, live, debug)
await mcp_client.save_rule(deduped_content, dest_proj, dest_file)

db.add_history(dest_path, rule_output.description)
_status(f"✅ SUCCESS: Updated {dest_path}", db, live, debug)
_status(f"✅ [{tool_name}] Updated {dest_path}", db, live, debug)
if not debug:
console.print(f"[bold green]▶ UPDATED:[/bold green] [cyan]{dest_path}[/cyan] ({rule_output.description})")
else:
Expand All @@ -291,6 +340,8 @@ async def _loop(live=None):
except (KeyboardInterrupt, asyncio.CancelledError):
_status("🛑 Stopping...", db, live, debug)
finally:
for task in watcher_tasks:
task.cancel()
Copy link

Copilot AI Apr 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Watcher tasks are cancelled in finally, but they are never awaited. This can leave pending tasks and produce "Task was destroyed but it is pending" / unhandled-exception warnings. After cancelling, await them (e.g., await asyncio.gather(*watcher_tasks, return_exceptions=True)) before closing the MCP client.

Suggested change
task.cancel()
task.cancel()
if watcher_tasks:
await asyncio.gather(*watcher_tasks, return_exceptions=True)

Copilot uses AI. Check for mistakes.
await mcp_client.close()

if debug:
Expand All @@ -301,16 +352,34 @@ async def _loop(live=None):
return True

@click.command()
@click.option('--tool', default='gemini-cli', type=click.Choice(['gemini-cli', 'copilot', 'claude']), help='Single AI tool to monitor (use --tools for multiple)')
@click.option('--tools', 'tools_csv', default=None, help='Comma-separated tools to monitor concurrently (e.g. gemini-cli,claude,copilot)')
@click.option('--bank-path', default='~/.memory-bank', help='Path to your Memory Bank root')
@click.option('--evaluator', 'evaluator_name', default='auto', type=click.Choice(['auto'] + sorted(EVALUATOR_REGISTRY)), help='Evaluator LLM to use (default: auto-detect)')
@click.option('--debug', is_flag=True, default=False, help='Stream plain debug logs instead of dashboard UI')
def cli(tool, tools_csv, bank_path, evaluator_name, debug):
    """Context-Scribe: Persistent Secretary Daemon"""
    if debug:
        logging.basicConfig(level=logging.DEBUG, format='%(asctime)s [%(levelname)s] %(name)s: %(message)s')

    # Parse --tools if provided; it takes precedence over --tool inside
    # run_daemon, so validate its entries eagerly with friendly errors.
    tools = None
    if tools_csv is not None:
        # dict.fromkeys deduplicates while preserving user-given order.
        tools = list(dict.fromkeys(
            t.strip() for t in tools_csv.split(",") if t.strip()
        ))
        if not tools:
            raise click.ClickException("--tools requires at least one tool name.")
        invalid = [t for t in tools if t not in TOOL_REGISTRY]
        if invalid:
            raise click.ClickException(
                f"Unknown tool(s): {', '.join(invalid)}. "
                f"Available: {', '.join(sorted(TOOL_REGISTRY))}"
            )

    try:
        asyncio.run(run_daemon(tool, bank_path, debug=debug, evaluator_name=evaluator_name, tools=tools))
    except KeyboardInterrupt:
        pass

Expand Down
14 changes: 7 additions & 7 deletions tests/test_daemons.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
from context_scribe.main import run_daemon

@pytest.mark.asyncio
@pytest.mark.parametrize("tool, provider_class, bootstrap_func, evaluator_name", [
("gemini-cli", "GeminiCliProvider", "bootstrap_global_config", "gemini"),
("copilot", "CopilotProvider", "bootstrap_copilot_config", "copilot"),
("claude", "ClaudeProvider", "bootstrap_claude_config", "claude"),
@pytest.mark.parametrize("tool, bootstrap_func, evaluator_name", [
("gemini-cli", "bootstrap_global_config", "gemini"),
("copilot", "bootstrap_copilot_config", "copilot"),
("claude", "bootstrap_claude_config", "claude"),
])
async def test_run_daemon_tools(tool, provider_class, bootstrap_func, evaluator_name, daemon_mocks):
async def test_run_daemon_tools(tool, bootstrap_func, evaluator_name, daemon_mocks):
"""Test the daemon run loop for all supported tools."""
with patch(f"context_scribe.main.{provider_class}", return_value=daemon_mocks.provider):

with patch("context_scribe.main._create_providers", return_value=[(tool, daemon_mocks.provider)]):
with patch("context_scribe.main.get_evaluator", return_value=daemon_mocks.evaluator):
with patch("context_scribe.main.MemoryBankClient", return_value=daemon_mocks.mcp):
with patch(f"context_scribe.main.{bootstrap_func}"):
Expand Down
141 changes: 141 additions & 0 deletions tests/test_multi_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
"""Tests for concurrent multi-tool daemon support."""
import asyncio
import sys
from unittest.mock import patch, MagicMock
import pytest


@pytest.fixture(autouse=True)
def mock_heavy_deps():
    """Mock heavy imports so we can import main without mcp/rich.

    Installs MagicMock stand-ins for the ``mcp`` and ``rich`` module trees
    (only when a real module is not installed), then evicts any cached
    ``context_scribe.main`` / ``context_scribe.bridge`` modules so the next
    import inside a test re-resolves against the mocks.  ``patch.dict``
    restores ``sys.modules`` after each test.
    """
    mocks = {}
    for mod in ["mcp", "mcp.client", "mcp.client.stdio",
                "rich", "rich.console", "rich.live", "rich.panel",
                "rich.text", "rich.layout", "rich.table", "rich.spinner"]:
        # A genuinely installed module has a __file__ attribute; only mock
        # entries that are missing or are themselves already mocks.
        if mod not in sys.modules or not hasattr(sys.modules.get(mod), '__file__'):
            mocks[mod] = MagicMock()
    with patch.dict(sys.modules, mocks):
        # Clear cached imports so they re-resolve with mocks
        for key in list(sys.modules.keys()):
            if key.startswith("context_scribe.main") or key.startswith("context_scribe.bridge"):
                del sys.modules[key]
        yield


def test_tool_registry_populated():
    """Every supported tool name must be present in the registry."""
    from context_scribe.main import TOOL_REGISTRY
    for expected in ("gemini-cli", "copilot", "claude"):
        assert expected in TOOL_REGISTRY


def test_create_providers_single():
    """A single known tool yields exactly one (name, provider) pair."""
    from context_scribe.main import _create_providers
    with patch("context_scribe.main.bootstrap_global_config"), \
         patch("context_scribe.main.GeminiCliProvider") as provider_cls:
        provider_cls.return_value = MagicMock()
        result = _create_providers(["gemini-cli"])
    assert len(result) == 1
    assert result[0][0] == "gemini-cli"


def test_create_providers_multiple():
    """Requesting two tools produces one provider per requested name."""
    from context_scribe.main import _create_providers
    with patch("context_scribe.main.bootstrap_global_config"), \
         patch("context_scribe.main.bootstrap_claude_config"), \
         patch("context_scribe.main.GeminiCliProvider", return_value=MagicMock()), \
         patch("context_scribe.main.ClaudeProvider", return_value=MagicMock()):
        result = _create_providers(["gemini-cli", "claude"])
    assert len(result) == 2
    names = [name for name, _ in result]
    assert "gemini-cli" in names
    assert "claude" in names


def test_create_providers_unknown_raises():
    """An unregistered tool name is rejected with a descriptive ValueError."""
    from context_scribe.main import _create_providers
    with pytest.raises(ValueError) as excinfo:
        _create_providers(["nonexistent"])
    assert "Unknown tool" in str(excinfo.value)


def test_create_providers_calls_bootstrap():
    """Each tool's bootstrap function runs exactly once during creation."""
    from context_scribe.main import _create_providers, TOOL_REGISTRY
    fake_bootstrap = MagicMock()
    patched_entry = (TOOL_REGISTRY["gemini-cli"][0], fake_bootstrap)
    # patch.dict restores the original registry entry automatically.
    with patch.dict(TOOL_REGISTRY, {"gemini-cli": patched_entry}):
        with patch("context_scribe.main.GeminiCliProvider", return_value=MagicMock()):
            _create_providers(["gemini-cli"])
    fake_bootstrap.assert_called_once()


def test_create_providers_all_three():
    """All registered tools can be created together in one call."""
    from context_scribe.main import _create_providers
    patchers = [
        patch("context_scribe.main.bootstrap_global_config"),
        patch("context_scribe.main.bootstrap_copilot_config"),
        patch("context_scribe.main.bootstrap_claude_config"),
        patch("context_scribe.main.GeminiCliProvider", return_value=MagicMock()),
        patch("context_scribe.main.CopilotProvider", return_value=MagicMock()),
        patch("context_scribe.main.ClaudeProvider", return_value=MagicMock()),
    ]
    for p in patchers:
        p.start()
    try:
        result = _create_providers(["gemini-cli", "copilot", "claude"])
        assert len(result) == 3
    finally:
        for p in patchers:
            p.stop()


# --- CLI tests for --tools flag ---

def test_cli_tools_deduplication():
    """--tools with duplicate entries should deduplicate preserving order."""
    from click.testing import CliRunner
    from context_scribe.main import cli

    runner = CliRunner()
    captured = {}

    async def fake_run_daemon(tool, bank_path, debug=False, evaluator_name="auto", tools=None):
        captured["tools"] = tools
        return True

    # Drive the patched coroutine on a private loop and ALWAYS close it:
    # an unclosed loop can raise ResourceWarnings / fail the suite on some
    # Python/pytest configurations.
    loop = asyncio.new_event_loop()
    try:
        with patch("context_scribe.main.run_daemon", side_effect=fake_run_daemon):
            with patch("asyncio.run", side_effect=lambda coro: loop.run_until_complete(coro)):
                result = runner.invoke(cli, ["--tools", "gemini-cli,gemini-cli,claude"])
    finally:
        loop.close()

    assert result.exit_code == 0
    assert captured["tools"] == ["gemini-cli", "claude"]


def test_cli_tools_invalid_tool():
    """--tools with an unknown tool name should fail with a clear error."""
    from click.testing import CliRunner
    from context_scribe.main import cli

    outcome = CliRunner().invoke(cli, ["--tools", "gemini-cli,nonexistent"])
    assert outcome.exit_code != 0
    assert "Unknown tool(s): nonexistent" in outcome.output


def test_cli_tools_empty_string():
    """--tools with an empty string should fail."""
    from click.testing import CliRunner
    from context_scribe.main import cli

    outcome = CliRunner().invoke(cli, ["--tools", ""])
    assert outcome.exit_code != 0
    assert "--tools requires at least one tool name" in outcome.output


def test_cli_tools_whitespace_only():
    """--tools with only whitespace/commas should fail."""
    from click.testing import CliRunner
    from context_scribe.main import cli

    outcome = CliRunner().invoke(cli, ["--tools", " , , "])
    assert outcome.exit_code != 0
    assert "--tools requires at least one tool name" in outcome.output
Loading