diff --git a/.github/workflows/robot-tests.yml b/.github/workflows/robot-tests.yml index 3333266d..b48b5e75 100644 --- a/.github/workflows/robot-tests.yml +++ b/.github/workflows/robot-tests.yml @@ -85,6 +85,18 @@ jobs: echo "✓ Test config.yml created from tests/configs/deepgram-openai.yml" ls -lh config/config.yml + - name: Create plugins.yml from template + run: | + echo "Creating plugins.yml from template..." + if [ -f "config/plugins.yml.template" ]; then + cp config/plugins.yml.template config/plugins.yml + echo "✓ plugins.yml created from template" + ls -lh config/plugins.yml + else + echo "❌ ERROR: config/plugins.yml.template not found" + exit 1 + fi + - name: Run Robot Framework tests working-directory: tests env: diff --git a/.gitignore b/.gitignore index 23141c6b..6fa02d7f 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,10 @@ tests/setup/.env.test config/config.yml !config/config.yml.template +# Plugins config (contains secrets) +config/plugins.yml +!config/plugins.yml.template + # Config backups config/*.backup.* config/*.backup* diff --git a/CLAUDE.md b/CLAUDE.md index abe20db6..b981231a 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -114,16 +114,8 @@ cp .env.template .env # Configure API keys # Run full integration test suite ./run-test.sh -# Manual test execution (for debugging) -source .env && export DEEPGRAM_API_KEY && export OPENAI_API_KEY -uv run robot --outputdir test-results --loglevel INFO ../../tests/integration/integration_test.robot - # Leave test containers running for debugging (don't auto-cleanup) -CLEANUP_CONTAINERS=false source .env && export DEEPGRAM_API_KEY && export OPENAI_API_KEY -uv run robot --outputdir test-results --loglevel INFO ../../tests/integration/integration_test.robot - -# Manual cleanup when needed -docker compose -f docker-compose-test.yml down -v +CLEANUP_CONTAINERS=false ./run-test.sh ``` #### Test Configuration Flags diff --git a/Docs/getting-started.md b/Docs/getting-started.md index a923c99c..c1e1a4b4 100644 --- a/Docs/getting-started.md +++ b/Docs/getting-started.md @@ -175,11 +175,16 @@ PARAKEET_ASR_URL=http://host.docker.internal:8080 After configuration, verify everything works with the integration test suite: ```bash +# From backends/advanced directory ./run-test.sh -# Alternative: Manual test with detailed logging -source .env && export DEEPGRAM_API_KEY OPENAI_API_KEY && \ - uv run robot --outputdir ../../test-results --loglevel INFO ../../tests/integration/integration_test.robot +# Or run all tests from project root +cd ../.. +./run-test.sh advanced-backend + +# Or run complete Robot Framework test suite +cd tests +./run-robot-tests.sh ``` This end-to-end test validates the complete audio processing pipeline using Robot Framework. diff --git a/backends/advanced/.env.template b/backends/advanced/.env.template index a63ab6f5..9c11af67 100644 --- a/backends/advanced/.env.template +++ b/backends/advanced/.env.template @@ -216,4 +216,41 @@ CORS_ORIGINS=http://localhost:5173,http://localhost:3000,http://127.0.0.1:5173,h LANGFUSE_PUBLIC_KEY="" LANGFUSE_SECRET_KEY="" LANGFUSE_HOST="http://x.x.x.x:3002" -LANGFUSE_ENABLE_TELEMETRY=False \ No newline at end of file +LANGFUSE_ENABLE_TELEMETRY=False + +# ======================================== +# TAILSCALE CONFIGURATION (Optional) +# ======================================== +# Required for accessing remote services on Tailscale network (e.g., Home Assistant plugin) +# +# To enable Tailscale Docker integration: +# 1. Get auth key from: https://login.tailscale.com/admin/settings/keys +# 2. 
Set TS_AUTHKEY below +# 3. Start Tailscale: docker compose --profile tailscale up -d +# +# The Tailscale container provides proxy access to remote services at: +# http://host.docker.internal:18123 (proxies to Home Assistant on Tailscale) +# +TS_AUTHKEY=your-tailscale-auth-key-here + +# ======================================== +# HOME ASSISTANT PLUGIN (Optional) +# ======================================== +# Required for Home Assistant voice control via wake word (e.g., "Hey Vivi, turn off the lights") +# +# To get a long-lived access token: +# 1. Go to Home Assistant → Profile → Security tab +# 2. Scroll to "Long-lived access tokens" +# 3. Click "Create Token" +# 4. Copy the token and paste it below +# +# Configuration in config/plugins.yml: +# - Enable the homeassistant plugin +# - Set ha_url to your Home Assistant URL +# - Set ha_token to ${HA_TOKEN} (reads from this variable) +# +# SECURITY: This token grants full access to your Home Assistant. +# - Never commit .env or config/plugins.yml to version control +# - Rotate the token if it's ever exposed +# +HA_TOKEN= \ No newline at end of file diff --git a/backends/advanced/Docs/quickstart.md b/backends/advanced/Docs/quickstart.md index 0d681978..9f966242 100644 --- a/backends/advanced/Docs/quickstart.md +++ b/backends/advanced/Docs/quickstart.md @@ -173,11 +173,16 @@ PARAKEET_ASR_URL=http://host.docker.internal:8080 After configuration, verify everything works with the integration test suite: ```bash +# From backends/advanced directory ./run-test.sh -# Alternative: Manual test with detailed logging -source .env && export DEEPGRAM_API_KEY OPENAI_API_KEY && \ - uv run robot --outputdir ../../test-results --loglevel INFO ../../tests/integration/integration_test.robot +# Or run all tests from project root +cd ../.. +./run-test.sh advanced-backend + +# Or run complete Robot Framework test suite +cd tests +./run-robot-tests.sh ``` This end-to-end test validates the complete audio processing pipeline using Robot Framework. diff --git a/backends/advanced/docker-compose-test.yml b/backends/advanced/docker-compose-test.yml index 867edc5f..812d29b9 100644 --- a/backends/advanced/docker-compose-test.yml +++ b/backends/advanced/docker-compose-test.yml @@ -14,7 +14,7 @@ services: - ./data/test_audio_chunks:/app/audio_chunks - ./data/test_debug_dir:/app/debug_dir - ./data/test_data:/app/data - - ${CONFIG_FILE:-../../config/config.yml}:/app/config.yml:ro # Mount config.yml for model registry and memory settings + - ${CONFIG_FILE:-../../config/config.yml}:/app/config.yml # Mount config.yml for model registry and memory settings (writable for admin config updates) environment: # Override with test-specific settings - MONGODB_URI=mongodb://mongo-test:27017/test_db @@ -160,7 +160,7 @@ services: - ./data/test_audio_chunks:/app/audio_chunks - ./data/test_debug_dir:/app/debug_dir - ./data/test_data:/app/data - - ${CONFIG_FILE:-../../config/config.yml}:/app/config.yml:ro # Mount config.yml for model registry and memory settings + - ${CONFIG_FILE:-../../config/config.yml}:/app/config.yml # Mount config.yml for model registry and memory settings (writable for admin config updates) environment: # Same environment as backend - MONGODB_URI=mongodb://mongo-test:27017/test_db @@ -200,6 +200,39 @@ services: condition: service_healthy restart: unless-stopped + deepgram-streaming-worker-test: + build: + context: . 
+ dockerfile: Dockerfile + command: > + uv run python -m advanced_omi_backend.workers.audio_stream_deepgram_streaming_worker + volumes: + - ./src:/app/src + - ./data/test_data:/app/data + - ${CONFIG_FILE:-../../config/config.yml}:/app/config.yml + - ${PLUGINS_CONFIG:-../../config/plugins.yml}:/app/plugins.yml + environment: + - DEEPGRAM_API_KEY=${DEEPGRAM_API_KEY} + - REDIS_URL=redis://redis-test:6379/0 + - HA_TOKEN=${HA_TOKEN} + - MONGODB_URI=mongodb://mongo-test:27017/test_db + - QDRANT_BASE_URL=qdrant-test + - QDRANT_PORT=6333 + - DEBUG_DIR=/app/debug_dir + - OPENAI_API_KEY=${OPENAI_API_KEY} + - GROQ_API_KEY=${GROQ_API_KEY} + - AUTH_SECRET_KEY=test-jwt-signing-key-for-integration-tests + - ADMIN_PASSWORD=test-admin-password-123 + - ADMIN_EMAIL=test-admin@example.com + - TRANSCRIPTION_PROVIDER=${TRANSCRIPTION_PROVIDER:-deepgram} + - MEMORY_PROVIDER=${MEMORY_PROVIDER:-chronicle} + depends_on: + redis-test: + condition: service_started + mongo-test: + condition: service_healthy + restart: unless-stopped + # Mycelia - AI memory and timeline service (test environment) # mycelia-backend-test: # build: diff --git a/backends/advanced/docker-compose.yml b/backends/advanced/docker-compose.yml index f46a23fa..4e6ba153 100644 --- a/backends/advanced/docker-compose.yml +++ b/backends/advanced/docker-compose.yml @@ -1,4 +1,30 @@ services: + tailscale: + image: tailscale/tailscale:latest + container_name: advanced-tailscale + hostname: chronicle-tailscale + environment: + - TS_AUTHKEY=${TS_AUTHKEY} + - TS_STATE_DIR=/var/lib/tailscale + - TS_USERSPACE=false + - TS_ACCEPT_DNS=true + volumes: + - tailscale-state:/var/lib/tailscale + devices: + - /dev/net/tun:/dev/net/tun + cap_add: + - NET_ADMIN + restart: unless-stopped + profiles: + - tailscale # Optional profile + ports: + - "18123:18123" # HA proxy port + command: > + sh -c "tailscaled & + tailscale up --authkey=$${TS_AUTHKEY} --accept-dns=true && + apk add --no-cache socat 2>/dev/null || true && + socat TCP-LISTEN:18123,fork,reuseaddr TCP:100.99.62.5:8123" + chronicle-backend: build: context: . 
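The tailscale sidecar above forwards local port 18123 to a Home Assistant instance on the tailnet via socat. A quick way to confirm the proxy and `HA_TOKEN` are wired together is a small smoke check like the sketch below; the port and the `/api/` endpoint come from this compose file and standard Home Assistant token auth, but the host-side base URL is an assumption.

```python
# Hypothetical smoke test for the Tailscale -> Home Assistant proxy (not part of this PR).
# Assumes `docker compose --profile tailscale up -d` has been run and HA_TOKEN is exported.
import os
import urllib.request

def check_ha_proxy(base_url: str = "http://localhost:18123") -> bool:
    token = os.environ.get("HA_TOKEN", "")
    req = urllib.request.Request(
        f"{base_url}/api/",
        headers={"Authorization": f"Bearer {token}"},
    )
    try:
        with urllib.request.urlopen(req, timeout=5) as resp:
            return resp.status == 200  # Home Assistant answers {"message": "API running."}
    except Exception as exc:
        print(f"Proxy check failed: {exc}")
        return False

if __name__ == "__main__":
    print("Home Assistant reachable:", check_ha_proxy())
```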
@@ -12,7 +38,8 @@ services: - ./data/audio_chunks:/app/audio_chunks - ./data/debug_dir:/app/debug_dir - ./data:/app/data - - ../../config/config.yml:/app/config.yml # Removed :ro to allow UI config saving + - ../../config/config.yml:/app/config.yml # Main config file + - ../../config/plugins.yml:/app/plugins.yml # Plugin configuration environment: - DEEPGRAM_API_KEY=${DEEPGRAM_API_KEY} - PARAKEET_ASR_URL=${PARAKEET_ASR_URL} @@ -26,6 +53,7 @@ services: - NEO4J_HOST=${NEO4J_HOST} - NEO4J_USER=${NEO4J_USER} - NEO4J_PASSWORD=${NEO4J_PASSWORD} + - HA_TOKEN=${HA_TOKEN} - CORS_ORIGINS=http://localhost:3010,http://localhost:8000,http://192.168.1.153:3010,http://192.168.1.153:8000,https://localhost:3010,https://localhost:8000,https://100.105.225.45,https://localhost - REDIS_URL=redis://redis:6379/0 depends_on: @@ -35,6 +63,8 @@ services: condition: service_healthy redis: condition: service_healthy + extra_hosts: + - "host.docker.internal:host-gateway" # Access host's Tailscale network healthcheck: test: ["CMD", "curl", "-f", "http://localhost:8000/readiness"] interval: 30s @@ -61,11 +91,13 @@ services: - ./data/audio_chunks:/app/audio_chunks - ./data:/app/data - ../../config/config.yml:/app/config.yml # Removed :ro for consistency + - ../../config/plugins.yml:/app/plugins.yml # Plugin configuration environment: - DEEPGRAM_API_KEY=${DEEPGRAM_API_KEY} - PARAKEET_ASR_URL=${PARAKEET_ASR_URL} - OPENAI_API_KEY=${OPENAI_API_KEY} - GROQ_API_KEY=${GROQ_API_KEY} + - HA_TOKEN=${HA_TOKEN} - REDIS_URL=redis://redis:6379/0 depends_on: redis: @@ -76,6 +108,35 @@ services: condition: service_started restart: unless-stopped + # Deepgram WebSocket streaming worker + # Real-time transcription worker that processes audio via Deepgram's WebSocket API + # Publishes interim results to Redis Pub/Sub for client display + # Publishes final results to Redis Streams for storage + # Triggers plugins on final results only + deepgram-streaming-worker: + build: + context: . 
+ dockerfile: Dockerfile + command: > + uv run python -m advanced_omi_backend.workers.audio_stream_deepgram_streaming_worker + env_file: + - .env + volumes: + - ./src:/app/src + - ./data:/app/data + - ../../config/config.yml:/app/config.yml + - ../../config/plugins.yml:/app/plugins.yml + environment: + - DEEPGRAM_API_KEY=${DEEPGRAM_API_KEY} + - REDIS_URL=redis://redis:6379/0 + - HA_TOKEN=${HA_TOKEN} + depends_on: + redis: + condition: service_healthy + extra_hosts: + - "host.docker.internal:host-gateway" + restart: unless-stopped + webui: build: context: ./webui @@ -226,3 +287,5 @@ volumes: driver: local neo4j_logs: driver: local + tailscale-state: + driver: local diff --git a/backends/advanced/init.py b/backends/advanced/init.py index fe04fd15..601120ad 100644 --- a/backends/advanced/init.py +++ b/backends/advanced/init.py @@ -49,6 +49,9 @@ def __init__(self, args=None): self.console.print("[red][ERROR][/red] Run wizard.py from project root to create config.yml") sys.exit(1) + # Ensure plugins.yml exists (copy from template if missing) + self._ensure_plugins_yml_exists() + def print_header(self, title: str): """Print a colorful header""" self.console.print() @@ -107,6 +110,26 @@ def prompt_choice(self, prompt: str, choices: Dict[str, str], default: str = "1" self.console.print(f"Using default choice: {default}") return default + def _ensure_plugins_yml_exists(self): + """Ensure plugins.yml exists by copying from template if missing.""" + plugins_yml = Path("../../config/plugins.yml") + plugins_template = Path("../../config/plugins.yml.template") + + if not plugins_yml.exists(): + if plugins_template.exists(): + self.console.print("[blue][INFO][/blue] plugins.yml not found, creating from template...") + shutil.copy2(plugins_template, plugins_yml) + self.console.print(f"[green]✅[/green] Created {plugins_yml} from template") + self.console.print("[yellow][NOTE][/yellow] Edit config/plugins.yml to configure plugins") + self.console.print("[yellow][NOTE][/yellow] Set HA_TOKEN in .env for Home Assistant integration") + else: + raise RuntimeError( + f"Template file not found: {plugins_template}\n" + f"The repository structure is incomplete. Please ensure config/plugins.yml.template exists." + ) + else: + self.console.print(f"[blue][INFO][/blue] Found existing {plugins_yml}") + def backup_existing_env(self): """Backup existing .env file""" env_path = Path(".env") @@ -136,6 +159,41 @@ def mask_api_key(self, key: str, show_chars: int = 5) -> str: return f"{key_clean[:show_chars]}{'*' * min(15, len(key_clean) - show_chars * 2)}{key_clean[-show_chars:]}" + def prompt_with_existing_masked(self, prompt_text: str, env_key: str, placeholders: list, + is_password: bool = False, default: str = "") -> str: + """ + Prompt for a value, showing masked existing value from .env if present. 
+ + Args: + prompt_text: The prompt to display + env_key: The .env key to check for existing value + placeholders: List of placeholder values to treat as "not set" + is_password: Whether to mask the value (for passwords/tokens) + default: Default value if no existing value + + Returns: + User input value, existing value if reused, or default + """ + existing_value = self.read_existing_env_value(env_key) + + # Check if existing value is valid (not empty and not a placeholder) + has_valid_existing = existing_value and existing_value not in placeholders + + if has_valid_existing: + # Show masked value with option to reuse + if is_password: + masked = self.mask_api_key(existing_value) + display_prompt = f"{prompt_text} ({masked}) [press Enter to reuse, or enter new]" + else: + display_prompt = f"{prompt_text} ({existing_value}) [press Enter to reuse, or enter new]" + + user_input = self.prompt_value(display_prompt, "") + # If user pressed Enter (empty input), reuse existing value + return user_input if user_input else existing_value + else: + # No existing value, prompt normally + return self.prompt_value(prompt_text, default) + def setup_authentication(self): """Configure authentication settings""" @@ -169,15 +227,14 @@ def setup_transcription(self): self.console.print("[blue][INFO][/blue] Deepgram selected") self.console.print("Get your API key from: https://console.deepgram.com/") - # Check for existing API key - existing_key = self.read_existing_env_value("DEEPGRAM_API_KEY") - if existing_key and existing_key not in ['your_deepgram_api_key_here', 'your-deepgram-key-here']: - masked_key = self.mask_api_key(existing_key) - prompt_text = f"Deepgram API key ({masked_key}) [press Enter to reuse, or enter new]" - api_key_input = self.prompt_value(prompt_text, "") - api_key = api_key_input if api_key_input else existing_key - else: - api_key = self.prompt_value("Deepgram API key (leave empty to skip)", "") + # Use the new masked prompt function + api_key = self.prompt_with_existing_masked( + prompt_text="Deepgram API key (leave empty to skip)", + env_key="DEEPGRAM_API_KEY", + placeholders=['your_deepgram_api_key_here', 'your-deepgram-key-here'], + is_password=True, + default="" + ) if api_key: # Write API key to .env @@ -227,15 +284,14 @@ def setup_llm(self): self.console.print("[blue][INFO][/blue] OpenAI selected") self.console.print("Get your API key from: https://platform.openai.com/api-keys") - # Check for existing API key - existing_key = self.read_existing_env_value("OPENAI_API_KEY") - if existing_key and existing_key not in ['your_openai_api_key_here', 'your-openai-key-here']: - masked_key = self.mask_api_key(existing_key) - prompt_text = f"OpenAI API key ({masked_key}) [press Enter to reuse, or enter new]" - api_key_input = self.prompt_value(prompt_text, "") - api_key = api_key_input if api_key_input else existing_key - else: - api_key = self.prompt_value("OpenAI API key (leave empty to skip)", "") + # Use the new masked prompt function + api_key = self.prompt_with_existing_masked( + prompt_text="OpenAI API key (leave empty to skip)", + env_key="OPENAI_API_KEY", + placeholders=['your_openai_api_key_here', 'your-openai-key-here'], + is_password=True, + default="" + ) if api_key: self.config["OPENAI_API_KEY"] = api_key @@ -347,6 +403,11 @@ def setup_optional_services(self): self.config["PARAKEET_ASR_URL"] = self.args.parakeet_asr_url self.console.print(f"[green][SUCCESS][/green] Parakeet ASR configured via args: {self.args.parakeet_asr_url}") + # Check if Tailscale auth key provided via 
args + if hasattr(self.args, 'ts_authkey') and self.args.ts_authkey: + self.config["TS_AUTHKEY"] = self.args.ts_authkey + self.console.print(f"[green][SUCCESS][/green] Tailscale auth key configured (Docker integration enabled)") + def setup_obsidian(self): """Configure Obsidian/Neo4j integration""" # Check if enabled via command line @@ -420,14 +481,14 @@ def setup_https(self): self.console.print("[blue][INFO][/blue] For distributed deployments, use your Tailscale IP (e.g., 100.64.1.2)") self.console.print("[blue][INFO][/blue] For local-only access, use 'localhost'") - # Check for existing SERVER_IP - existing_ip = self.read_existing_env_value("SERVER_IP") - if existing_ip and existing_ip not in ['localhost', 'your-server-ip-here']: - prompt_text = f"Server IP/Domain for SSL certificate ({existing_ip}) [press Enter to reuse, or enter new]" - server_ip_input = self.prompt_value(prompt_text, "") - server_ip = server_ip_input if server_ip_input else existing_ip - else: - server_ip = self.prompt_value("Server IP/Domain for SSL certificate (Tailscale IP or localhost)", "localhost") + # Use the new masked prompt function (not masked for IP, but shows existing) + server_ip = self.prompt_with_existing_masked( + prompt_text="Server IP/Domain for SSL certificate (Tailscale IP or localhost)", + env_key="SERVER_IP", + placeholders=['localhost', 'your-server-ip-here'], + is_password=False, + default="localhost" + ) if enable_https: @@ -702,6 +763,8 @@ def main(): help="Enable Obsidian/Neo4j integration (default: prompt user)") parser.add_argument("--neo4j-password", help="Neo4j password (default: prompt user)") + parser.add_argument("--ts-authkey", + help="Tailscale auth key for Docker integration (default: prompt user)") args = parser.parse_args() diff --git a/backends/advanced/run-test.sh b/backends/advanced/run-test.sh index 01204be6..a18dc895 100755 --- a/backends/advanced/run-test.sh +++ b/backends/advanced/run-test.sh @@ -91,6 +91,29 @@ if [ -n "$_CONFIG_FILE_OVERRIDE" ]; then print_info "Using command-line override: CONFIG_FILE=$CONFIG_FILE" fi +# Load HF_TOKEN from speaker-recognition/.env (proper location for this credential) +SPEAKER_ENV="../../extras/speaker-recognition/.env" +if [ -f "$SPEAKER_ENV" ] && [ -z "$HF_TOKEN" ]; then + print_info "Loading HF_TOKEN from speaker-recognition service..." + set -a + source "$SPEAKER_ENV" + set +a +fi + +# Display HF_TOKEN status with masking +if [ -n "$HF_TOKEN" ]; then + if [ ${#HF_TOKEN} -gt 15 ]; then + MASKED_TOKEN="${HF_TOKEN:0:5}***************${HF_TOKEN: -5}" + else + MASKED_TOKEN="***************" + fi + print_info "HF_TOKEN configured: $MASKED_TOKEN" + export HF_TOKEN +else + print_warning "HF_TOKEN not found - speaker recognition tests may fail" + print_info "Configure via wizard: uv run --with-requirements ../../setup-requirements.txt python ../../wizard.py" +fi + # Set default CONFIG_FILE if not provided # This allows testing with different provider combinations # Usage: CONFIG_FILE=../../tests/configs/parakeet-ollama.yml ./run-test.sh @@ -166,6 +189,18 @@ if [ ! -f "diarization_config.json" ] && [ -f "diarization_config.json.template" print_success "diarization_config.json created" fi +# Ensure plugins.yml exists (required for Docker volume mount) +if [ ! -f "../../config/plugins.yml" ]; then + if [ -f "../../config/plugins.yml.template" ]; then + print_info "Creating config/plugins.yml from template..." 
+ cp ../../config/plugins.yml.template ../../config/plugins.yml + print_success "config/plugins.yml created" + else + print_error "config/plugins.yml.template not found - repository structure incomplete" + exit 1 + fi +fi + # Note: Robot Framework dependencies are managed via tests/test-requirements.txt # The integration tests use Docker containers for service dependencies @@ -176,15 +211,25 @@ print_info "Using environment variables from .env file for test configuration" # Clean test environment print_info "Cleaning test environment..." -sudo rm -rf ./test_audio_chunks/ ./test_data/ ./test_debug_dir/ ./mongo_data_test/ ./qdrant_data_test/ ./test_neo4j/ || true +rm -rf ./test_audio_chunks/ ./test_data/ ./test_debug_dir/ ./mongo_data_test/ ./qdrant_data_test/ ./test_neo4j/ 2>/dev/null || true + +# If cleanup fails due to permissions, try with docker +if [ -d "./data/test_audio_chunks/" ] || [ -d "./data/test_data/" ] || [ -d "./data/test_debug_dir/" ]; then + print_warning "Permission denied, using docker to clean test directories..." + docker run --rm -v "$(pwd)/data:/data" alpine sh -c 'rm -rf /data/test_*' 2>/dev/null || true +fi # Use unique project name to avoid conflicts with development environment export COMPOSE_PROJECT_NAME="advanced-backend-test" # Stop any existing test containers print_info "Stopping existing test containers..." +# Try cleanup with current project name docker compose -f docker-compose-test.yml down -v || true +# Also try cleanup with default project name (in case containers were started without COMPOSE_PROJECT_NAME) +COMPOSE_PROJECT_NAME=advanced docker compose -f docker-compose-test.yml down -v 2>/dev/null || true + # Run integration tests print_info "Running integration tests..." print_info "Using fresh mode (CACHED_MODE=False) for clean testing" @@ -222,6 +267,8 @@ else if [ "${CLEANUP_CONTAINERS:-true}" != "false" ]; then print_info "Cleaning up test containers after failure..." docker compose -f docker-compose-test.yml down -v || true + # Also cleanup with default project name + COMPOSE_PROJECT_NAME=advanced docker compose -f docker-compose-test.yml down -v 2>/dev/null || true docker system prune -f || true else print_warning "Skipping cleanup (CLEANUP_CONTAINERS=false) - containers left running for debugging" @@ -234,6 +281,8 @@ fi if [ "${CLEANUP_CONTAINERS:-true}" != "false" ]; then print_info "Cleaning up test containers..." 
docker compose -f docker-compose-test.yml down -v || true + # Also cleanup with default project name + COMPOSE_PROJECT_NAME=advanced docker compose -f docker-compose-test.yml down -v 2>/dev/null || true docker system prune -f || true else print_warning "Skipping cleanup (CLEANUP_CONTAINERS=false) - containers left running" diff --git a/backends/advanced/src/advanced_omi_backend/app_factory.py b/backends/advanced/src/advanced_omi_backend/app_factory.py index 7ccda184..8a162cec 100644 --- a/backends/advanced/src/advanced_omi_backend/app_factory.py +++ b/backends/advanced/src/advanced_omi_backend/app_factory.py @@ -111,6 +111,11 @@ async def lifespan(app: FastAPI): from advanced_omi_backend.services.audio_stream import AudioStreamProducer app.state.audio_stream_producer = AudioStreamProducer(app.state.redis_audio_stream) application_logger.info("✅ Redis client for audio streaming producer initialized") + + # Initialize ClientManager Redis for cross-container client→user mapping + from advanced_omi_backend.client_manager import initialize_redis_for_client_manager + initialize_redis_for_client_manager(config.redis_url) + except Exception as e: application_logger.error(f"Failed to initialize Redis client for audio streaming: {e}", exc_info=True) application_logger.warning("Audio streaming producer will not be available") @@ -122,6 +127,36 @@ async def lifespan(app: FastAPI): # SystemTracker is used for monitoring and debugging application_logger.info("Using SystemTracker for monitoring and debugging") + # Initialize plugins using plugin service + try: + from advanced_omi_backend.services.plugin_service import init_plugin_router, set_plugin_router + + plugin_router = init_plugin_router() + + if plugin_router: + # Initialize async resources for each enabled plugin + for plugin_id, plugin in plugin_router.plugins.items(): + if plugin.enabled: + try: + await plugin.initialize() + application_logger.info(f"✅ Plugin '{plugin_id}' initialized") + except Exception as e: + application_logger.error(f"Failed to initialize plugin '{plugin_id}': {e}", exc_info=True) + + application_logger.info(f"Plugins initialized: {len(plugin_router.plugins)} active") + + # Store in app state for API access + app.state.plugin_router = plugin_router + # Register with plugin service for worker access + set_plugin_router(plugin_router) + else: + application_logger.info("No plugins configured") + app.state.plugin_router = None + + except Exception as e: + application_logger.error(f"Failed to initialize plugin system: {e}", exc_info=True) + app.state.plugin_router = None + application_logger.info("Application ready - using application-level processing architecture.") logger.info("App ready") @@ -162,6 +197,14 @@ async def lifespan(app: FastAPI): # Stop metrics collection and save final report application_logger.info("Metrics collection stopped") + # Shutdown plugins + try: + from advanced_omi_backend.services.plugin_service import cleanup_plugin_router + await cleanup_plugin_router() + application_logger.info("Plugins shut down") + except Exception as e: + application_logger.error(f"Error shutting down plugins: {e}") + # Shutdown memory service and speaker service shutdown_memory_service() application_logger.info("Memory and speaker services shut down.") diff --git a/backends/advanced/src/advanced_omi_backend/chat_service.py b/backends/advanced/src/advanced_omi_backend/chat_service.py index de92a4b9..16cba331 100644 --- a/backends/advanced/src/advanced_omi_backend/chat_service.py +++ 
b/backends/advanced/src/advanced_omi_backend/chat_service.py @@ -22,6 +22,7 @@ from advanced_omi_backend.database import get_database from advanced_omi_backend.llm_client import get_llm_client +from advanced_omi_backend.model_registry import get_models_registry from advanced_omi_backend.services.memory import get_memory_service from advanced_omi_backend.services.memory.base import MemoryEntry from advanced_omi_backend.services.obsidian_service import ( @@ -133,7 +134,7 @@ def from_dict(cls, data: Dict) -> "ChatSession": class ChatService: """Service for managing chat sessions and memory-enhanced conversations.""" - + def __init__(self): self.db = None self.sessions_collection: Optional[AsyncIOMotorCollection] = None @@ -142,6 +143,33 @@ def __init__(self): self.memory_service = None self._initialized = False + def _get_system_prompt(self) -> str: + """ + Get system prompt from config with fallback to default. + + Returns: + str: System prompt for chat interactions + """ + try: + reg = get_models_registry() + if reg and hasattr(reg, 'chat'): + chat_config = reg.chat + prompt = chat_config.get('system_prompt') + if prompt: + logger.info(f"✅ Loaded chat system prompt from config (length: {len(prompt)} chars)") + logger.debug(f"System prompt: {prompt[:100]}...") + return prompt + except Exception as e: + logger.warning(f"Failed to load chat system prompt from config: {e}") + + # Fallback to default + logger.info("⚠️ Using default chat system prompt (config not found)") + return """You are a helpful AI assistant with access to the user's personal memories and conversation history. + +Use the provided memories and conversation context to give personalized, contextual responses. If memories are relevant, reference them naturally in your response. Be conversational and helpful. + +If no relevant memories are available, respond normally based on the conversation context.""" + async def initialize(self): """Initialize the chat service with database connections.""" if self._initialized: @@ -392,12 +420,8 @@ async def generate_response_stream( "timestamp": time.time() } - # Create system prompt - system_prompt = """You are a helpful AI assistant with access to the user's personal memories and conversation history. - -Use the provided memories and conversation context to give personalized, contextual responses. If memories are relevant, reference them naturally in your response. Be conversational and helpful. 
- -If no relevant memories are available, respond normally based on the conversation context.""" + # Get system prompt from config + system_prompt = self._get_system_prompt() # Prepare full prompt full_prompt = f"{system_prompt}\n\n{context}" diff --git a/backends/advanced/src/advanced_omi_backend/client_manager.py b/backends/advanced/src/advanced_omi_backend/client_manager.py index 5a3131b5..e55b3502 100644 --- a/backends/advanced/src/advanced_omi_backend/client_manager.py +++ b/backends/advanced/src/advanced_omi_backend/client_manager.py @@ -9,6 +9,7 @@ import logging import uuid from typing import TYPE_CHECKING, Dict, Optional +import redis.asyncio as redis if TYPE_CHECKING: from advanced_omi_backend.client import ClientState @@ -21,6 +22,9 @@ _client_to_user_mapping: Dict[str, str] = {} # Active clients only _all_client_user_mappings: Dict[str, str] = {} # All clients including disconnected +# Redis client for cross-container client→user mapping +_redis_client: Optional[redis.Redis] = None + class ClientManager: """ @@ -372,9 +376,33 @@ def unregister_client_user_mapping(client_id: str): logger.warning(f"⚠️ Attempted to unregister non-existent client {client_id}") +async def track_client_user_relationship_async(client_id: str, user_id: str, ttl: int = 86400): + """ + Track that a client belongs to a user (async, writes to Redis for cross-container support). + + Args: + client_id: The client ID + user_id: The user ID that owns this client + ttl: Time-to-live in seconds (default 24 hours) + """ + _all_client_user_mappings[client_id] = user_id # In-memory fallback + + if _redis_client: + try: + await _redis_client.setex(f"client:owner:{client_id}", ttl, user_id) + logger.debug(f"✅ Tracked client {client_id} → user {user_id} in Redis (TTL: {ttl}s)") + except Exception as e: + logger.warning(f"Failed to track client in Redis: {e}") + else: + logger.debug(f"Tracked client {client_id} relationship to user {user_id} (in-memory only)") + + def track_client_user_relationship(client_id: str, user_id: str): """ - Track that a client belongs to a user (persists after disconnection for database queries). + Track that a client belongs to a user (sync version for backward compatibility). + + WARNING: This is synchronous and cannot use Redis. Use track_client_user_relationship_async() + instead in async contexts for cross-container support. Args: client_id: The client ID @@ -444,9 +472,45 @@ def get_user_clients_active(user_id: str) -> list[str]: return user_clients +def initialize_redis_for_client_manager(redis_url: str): + """ + Initialize Redis client for cross-container client→user mapping. + + Args: + redis_url: Redis connection URL + """ + global _redis_client + _redis_client = redis.from_url(redis_url, decode_responses=True) + logger.info(f"✅ ClientManager Redis initialized: {redis_url}") + + +async def get_client_owner_async(client_id: str) -> Optional[str]: + """ + Get the user ID that owns a specific client (async Redis lookup). + + Args: + client_id: The client ID to look up + + Returns: + User ID if found, None otherwise + """ + if _redis_client: + try: + user_id = await _redis_client.get(f"client:owner:{client_id}") + return user_id + except Exception as e: + logger.warning(f"Redis lookup failed for client {client_id}: {e}") + + # Fallback to in-memory mapping + return _all_client_user_mappings.get(client_id) + + def get_client_owner(client_id: str) -> Optional[str]: """ - Get the user ID that owns a specific client. 
+ Get the user ID that owns a specific client (sync version for backward compatibility). + + WARNING: This is synchronous and cannot use Redis. Use get_client_owner_async() instead + in async contexts for cross-container support. Args: client_id: The client ID to look up diff --git a/backends/advanced/src/advanced_omi_backend/controllers/session_controller.py b/backends/advanced/src/advanced_omi_backend/controllers/session_controller.py index a3836898..d1a22695 100644 --- a/backends/advanced/src/advanced_omi_backend/controllers/session_controller.py +++ b/backends/advanced/src/advanced_omi_backend/controllers/session_controller.py @@ -9,13 +9,61 @@ import logging import time -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Literal from fastapi.responses import JSONResponse logger = logging.getLogger(__name__) +async def mark_session_complete( + redis_client, + session_id: str, + reason: Literal[ + "websocket_disconnect", + "user_stopped", + "inactivity_timeout", + "max_duration", + "all_jobs_complete" + ], +) -> None: + """ + Single source of truth for marking sessions as complete. + + This function ensures that both 'status' and 'completion_reason' are ALWAYS + set together atomically, preventing race conditions where workers check status + before completion_reason is set. + + Args: + redis_client: Redis async client + session_id: Session UUID + reason: Why the session is completing (enforced by type system) + + Usage: + # WebSocket disconnect + await mark_session_complete(redis, session_id, "websocket_disconnect") + + # User manually stopped + await mark_session_complete(redis, session_id, "user_stopped") + + # Inactivity timeout + await mark_session_complete(redis, session_id, "inactivity_timeout") + + # Max duration reached + await mark_session_complete(redis, session_id, "max_duration") + + # All jobs finished + await mark_session_complete(redis, session_id, "all_jobs_complete") + """ + session_key = f"audio:session:{session_id}" + await redis_client.hset(session_key, mapping={ + "status": "complete", + "completed_at": str(time.time()), + "completion_reason": reason + }) + logger.info(f"✅ Session {session_id[:12]} marked complete: {reason}") + + async def get_session_info(redis_client, session_id: str) -> Optional[Dict]: """ Get detailed information about a specific session. 
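Because `mark_session_complete` writes `status`, `completed_at`, and `completion_reason` in a single `hset`, a reader never observes `status == "complete"` without a reason. A worker-side sketch of relying on that guarantee (illustrative only, not part of this PR; assumes a Redis client created with `decode_responses=True`):

```python
# Illustrative consumer of the session hash written by mark_session_complete().
import asyncio

async def wait_for_completion(redis_client, session_id: str, poll_interval: float = 0.5) -> str:
    """Poll audio:session:{session_id} until it is complete, then return the completion reason."""
    session_key = f"audio:session:{session_id}"
    while True:
        state = await redis_client.hgetall(session_key)
        if state.get("status") == "complete":
            # completion_reason is always present here because both fields are set atomically.
            return state.get("completion_reason", "unknown")
        await asyncio.sleep(poll_interval)
```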
@@ -192,8 +240,7 @@ async def get_streaming_status(request): # All jobs complete - this is truly a completed session # Update Redis status if it wasn't already marked complete if status not in ["complete", "completed", "finalized"]: - await redis_client.hset(key, "status", "complete") - logger.info(f"✅ Marked session {session_id} as complete (all jobs terminal)") + await mark_session_complete(redis_client, session_id, "all_jobs_complete") # Get additional session data for completed sessions session_key = f"audio:session:{session_id}" diff --git a/backends/advanced/src/advanced_omi_backend/controllers/system_controller.py b/backends/advanced/src/advanced_omi_backend/controllers/system_controller.py index 17b9cbcf..f5ff3275 100644 --- a/backends/advanced/src/advanced_omi_backend/controllers/system_controller.py +++ b/backends/advanced/src/advanced_omi_backend/controllers/system_controller.py @@ -7,6 +7,7 @@ import shutil import time from datetime import UTC, datetime +from pathlib import Path import yaml from fastapi import HTTPException @@ -455,3 +456,239 @@ async def set_memory_provider(provider: str): except Exception as e: logger.exception("Error setting memory provider") raise e + + +# Chat Configuration Management Functions + +async def get_chat_config_yaml() -> str: + """Get chat system prompt as plain text.""" + try: + config_path = _find_config_path() + + default_prompt = """You are a helpful AI assistant with access to the user's personal memories and conversation history. + +Use the provided memories and conversation context to give personalized, contextual responses. If memories are relevant, reference them naturally in your response. Be conversational and helpful. + +If no relevant memories are available, respond normally based on the conversation context.""" + + if not os.path.exists(config_path): + return default_prompt + + with open(config_path, 'r') as f: + full_config = yaml.safe_load(f) or {} + + chat_config = full_config.get('chat', {}) + system_prompt = chat_config.get('system_prompt', default_prompt) + + # Return just the prompt text, not the YAML structure + return system_prompt + + except Exception as e: + logger.error(f"Error loading chat config: {e}") + raise + + +async def save_chat_config_yaml(prompt_text: str) -> dict: + """Save chat system prompt from plain text.""" + try: + config_path = _find_config_path() + + # Validate plain text prompt + if not prompt_text or not isinstance(prompt_text, str): + raise ValueError("Prompt must be a non-empty string") + + prompt_text = prompt_text.strip() + if len(prompt_text) < 10: + raise ValueError("Prompt too short (minimum 10 characters)") + if len(prompt_text) > 10000: + raise ValueError("Prompt too long (maximum 10000 characters)") + + # Create chat config dict + chat_config = {'system_prompt': prompt_text} + + # Load full config + if os.path.exists(config_path): + with open(config_path, 'r') as f: + full_config = yaml.safe_load(f) or {} + else: + full_config = {} + + # Backup existing config + if os.path.exists(config_path): + backup_path = str(config_path) + '.backup' + shutil.copy2(config_path, backup_path) + logger.info(f"Created config backup at {backup_path}") + + # Update chat section + full_config['chat'] = chat_config + + # Save + with open(config_path, 'w') as f: + yaml.dump(full_config, f, default_flow_style=False, allow_unicode=True) + + # Reload config in memory (hot-reload) + load_models_config(force_reload=True) + + logger.info("Chat configuration updated successfully") + + return {"success": True, 
"message": "Chat configuration updated successfully"} + + except Exception as e: + logger.error(f"Error saving chat config: {e}") + raise + + +async def validate_chat_config_yaml(prompt_text: str) -> dict: + """Validate chat system prompt plain text.""" + try: + # Validate plain text prompt + if not isinstance(prompt_text, str): + return {"valid": False, "error": "Prompt must be a string"} + + prompt_text = prompt_text.strip() + if len(prompt_text) < 10: + return {"valid": False, "error": "Prompt too short (minimum 10 characters)"} + if len(prompt_text) > 10000: + return {"valid": False, "error": "Prompt too long (maximum 10000 characters)"} + + return {"valid": True, "message": "Configuration is valid"} + + except Exception as e: + logger.error(f"Error validating chat config: {e}") + return {"valid": False, "error": f"Validation error: {str(e)}"} + + +# Plugin Configuration Management Functions + +async def get_plugins_config_yaml() -> str: + """Get plugins configuration as YAML text.""" + try: + plugins_yml_path = Path("/app/plugins.yml") + + # Default empty plugins config + default_config = """plugins: + # No plugins configured yet + # Example plugin configuration: + # homeassistant: + # enabled: true + # access_level: transcript + # trigger: + # type: wake_word + # wake_word: vivi + # ha_url: http://localhost:8123 + # ha_token: YOUR_TOKEN_HERE +""" + + if not plugins_yml_path.exists(): + return default_config + + with open(plugins_yml_path, 'r') as f: + yaml_content = f.read() + + return yaml_content + + except Exception as e: + logger.error(f"Error loading plugins config: {e}") + raise + + +async def save_plugins_config_yaml(yaml_content: str) -> dict: + """Save plugins configuration from YAML text.""" + try: + plugins_yml_path = Path("/app/plugins.yml") + + # Validate YAML can be parsed + try: + parsed_config = yaml.safe_load(yaml_content) + if not isinstance(parsed_config, dict): + raise ValueError("Configuration must be a YAML dictionary") + + # Validate has 'plugins' key + if 'plugins' not in parsed_config: + raise ValueError("Configuration must contain 'plugins' key") + + except yaml.YAMLError as e: + raise ValueError(f"Invalid YAML syntax: {e}") + + # Create config directory if it doesn't exist + plugins_yml_path.parent.mkdir(parents=True, exist_ok=True) + + # Backup existing config + if plugins_yml_path.exists(): + backup_path = str(plugins_yml_path) + '.backup' + shutil.copy2(plugins_yml_path, backup_path) + logger.info(f"Created plugins config backup at {backup_path}") + + # Save new config + with open(plugins_yml_path, 'w') as f: + f.write(yaml_content) + + # Hot-reload plugins (optional - may require restart) + try: + from advanced_omi_backend.services.plugin_service import get_plugin_router + plugin_router = get_plugin_router() + if plugin_router: + logger.info("Plugin configuration updated - restart backend for changes to take effect") + except Exception as reload_err: + logger.warning(f"Could not reload plugins: {reload_err}") + + logger.info("Plugins configuration updated successfully") + + return { + "success": True, + "message": "Plugins configuration updated successfully. Restart backend for changes to take effect." 
+ } + + except Exception as e: + logger.error(f"Error saving plugins config: {e}") + raise + + +async def validate_plugins_config_yaml(yaml_content: str) -> dict: + """Validate plugins configuration YAML.""" + try: + # Parse YAML + try: + parsed_config = yaml.safe_load(yaml_content) + except yaml.YAMLError as e: + return {"valid": False, "error": f"Invalid YAML syntax: {e}"} + + # Check structure + if not isinstance(parsed_config, dict): + return {"valid": False, "error": "Configuration must be a YAML dictionary"} + + if 'plugins' not in parsed_config: + return {"valid": False, "error": "Configuration must contain 'plugins' key"} + + plugins = parsed_config['plugins'] + if not isinstance(plugins, dict): + return {"valid": False, "error": "'plugins' must be a dictionary"} + + # Validate each plugin + valid_access_levels = ['transcript', 'conversation', 'memory'] + valid_trigger_types = ['wake_word', 'always', 'conditional'] + + for plugin_id, plugin_config in plugins.items(): + if not isinstance(plugin_config, dict): + return {"valid": False, "error": f"Plugin '{plugin_id}' config must be a dictionary"} + + # Check required fields + if 'enabled' in plugin_config and not isinstance(plugin_config['enabled'], bool): + return {"valid": False, "error": f"Plugin '{plugin_id}': 'enabled' must be boolean"} + + if 'access_level' in plugin_config and plugin_config['access_level'] not in valid_access_levels: + return {"valid": False, "error": f"Plugin '{plugin_id}': invalid access_level (must be one of {valid_access_levels})"} + + if 'trigger' in plugin_config: + trigger = plugin_config['trigger'] + if not isinstance(trigger, dict): + return {"valid": False, "error": f"Plugin '{plugin_id}': 'trigger' must be a dictionary"} + + if 'type' in trigger and trigger['type'] not in valid_trigger_types: + return {"valid": False, "error": f"Plugin '{plugin_id}': invalid trigger type (must be one of {valid_trigger_types})"} + + return {"valid": True, "message": "Configuration is valid"} + + except Exception as e: + logger.error(f"Error validating plugins config: {e}") + return {"valid": False, "error": f"Validation error: {str(e)}"} diff --git a/backends/advanced/src/advanced_omi_backend/controllers/websocket_controller.py b/backends/advanced/src/advanced_omi_backend/controllers/websocket_controller.py index 50ffc77f..2b98bcbb 100644 --- a/backends/advanced/src/advanced_omi_backend/controllers/websocket_controller.py +++ b/backends/advanced/src/advanced_omi_backend/controllers/websocket_controller.py @@ -17,10 +17,12 @@ from fastapi import WebSocket, WebSocketDisconnect, Query from friend_lite.decoder import OmiOpusDecoder +import redis.asyncio as redis from advanced_omi_backend.auth import websocket_auth from advanced_omi_backend.client_manager import generate_client_id, get_client_manager from advanced_omi_backend.constants import OMI_CHANNELS, OMI_SAMPLE_RATE, OMI_SAMPLE_WIDTH +from advanced_omi_backend.controllers.session_controller import mark_session_complete from advanced_omi_backend.utils.audio_utils import process_audio_chunk from advanced_omi_backend.services.audio_stream import AudioStreamProducer from advanced_omi_backend.services.audio_stream.producer import get_audio_stream_producer @@ -39,6 +41,89 @@ pending_connections: set[str] = set() +async def subscribe_to_interim_results(websocket: WebSocket, session_id: str) -> None: + """ + Subscribe to interim transcription results from Redis Pub/Sub and forward to client WebSocket. + + Runs as background task during WebSocket connection. 
Listens for interim and final + transcription results published by the Deepgram streaming consumer and forwards them + to the connected client for real-time transcript display. + + Args: + websocket: Connected WebSocket client + session_id: Session ID (client_id) to subscribe to + + Note: + This task runs continuously until the WebSocket disconnects or the task is cancelled. + Results are published to Redis Pub/Sub channel: transcription:interim:{session_id} + """ + redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0") + + try: + # Create Redis client for Pub/Sub + redis_client = await redis.from_url(redis_url, decode_responses=True) + + # Create Pub/Sub instance + pubsub = redis_client.pubsub() + + # Subscribe to interim results channel for this session + channel = f"transcription:interim:{session_id}" + await pubsub.subscribe(channel) + + logger.info(f"📢 Subscribed to interim results channel: {channel}") + + # Listen for messages + while True: + try: + message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0) + + if message and message['type'] == 'message': + # Parse result data + try: + result_data = json.loads(message['data']) + + # Forward to client WebSocket + await websocket.send_json({ + "type": "interim_transcript", + "data": result_data + }) + + # Log for debugging + is_final = result_data.get("is_final", False) + text_preview = result_data.get("text", "")[:50] + result_type = "FINAL" if is_final else "interim" + logger.debug(f"✉️ Forwarded {result_type} result to client {session_id}: {text_preview}...") + + except json.JSONDecodeError as e: + logger.error(f"Failed to parse interim result JSON: {e}") + except Exception as send_error: + logger.error(f"Failed to send interim result to client {session_id}: {send_error}") + # WebSocket might be closed, exit loop + break + + except asyncio.TimeoutError: + # No message received, continue waiting + continue + except asyncio.CancelledError: + logger.info(f"Interim results subscriber cancelled for session {session_id}") + break + except Exception as e: + logger.error(f"Error in interim results subscriber for {session_id}: {e}", exc_info=True) + break + + except Exception as e: + logger.error(f"Failed to initialize interim results subscriber for {session_id}: {e}", exc_info=True) + finally: + try: + # Unsubscribe and close connections + await pubsub.unsubscribe(channel) + await pubsub.close() + await redis_client.aclose() + logger.info(f"🔕 Unsubscribed from interim results channel: {channel}") + except Exception as cleanup_error: + logger.error(f"Error cleaning up interim results subscriber: {cleanup_error}") + + async def parse_wyoming_protocol(ws: WebSocket) -> tuple[dict, Optional[bytes]]: """Parse Wyoming protocol: JSON header line followed by optional binary payload. 
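Each message forwarded by `subscribe_to_interim_results` reaches the client as a JSON frame with `type: "interim_transcript"` and the transcription payload under `data`; the code above only relies on `text` and `is_final`. A minimal client-side consumer might look like this sketch (any payload fields beyond those two are assumptions):

```python
# Illustrative WebSocket client handler for interim transcripts (not part of this PR).
import json

async def consume_transcripts(ws) -> None:
    """Print interim vs. final transcript updates from the backend WebSocket."""
    async for raw in ws:  # assumes a websockets-style async iterator of text frames
        msg = json.loads(raw)
        if msg.get("type") != "interim_transcript":
            continue
        data = msg.get("data", {})
        label = "FINAL" if data.get("is_final") else "interim"
        print(f"[{label}] {data.get('text', '')}")
```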
@@ -105,9 +190,9 @@ async def create_client_state(client_id: str, user, device_name: Optional[str] = client_id, CHUNK_DIR, user.user_id, user.email ) - # Also track in persistent mapping (for database queries) - from advanced_omi_backend.client_manager import track_client_user_relationship - track_client_user_relationship(client_id, user.user_id) + # Also track in persistent mapping (for database queries + cross-container Redis) + from advanced_omi_backend.client_manager import track_client_user_relationship_async + await track_client_user_relationship_async(client_id, user.user_id) # Register client in user model (persistent) from advanced_omi_backend.users import register_client_to_user @@ -166,13 +251,8 @@ async def cleanup_client_state(client_id: str): client_id_bytes = await async_redis.hget(key, "client_id") if client_id_bytes and client_id_bytes.decode() == client_id: # Mark session as complete (WebSocket disconnected) - await async_redis.hset(key, mapping={ - "status": "complete", - "completed_at": str(time.time()), - "completion_reason": "websocket_disconnect" - }) session_id = key.decode().replace("audio:session:", "") - logger.info(f"📊 Marked session {session_id[:12]} as complete (WebSocket disconnect)") + await mark_session_complete(async_redis, session_id, "websocket_disconnect") sessions_closed += 1 if cursor == 0: @@ -181,12 +261,12 @@ async def cleanup_client_state(client_id: str): if sessions_closed > 0: logger.info(f"✅ Closed {sessions_closed} active session(s) for client {client_id}") - # Delete Redis Streams for this client + # Set TTL on Redis Streams for this client (allows consumer groups to finish processing) stream_pattern = f"audio:stream:{client_id}" stream_key = await async_redis.exists(stream_pattern) if stream_key: - await async_redis.delete(stream_pattern) - logger.info(f"🧹 Deleted Redis stream: {stream_pattern}") + await async_redis.expire(stream_pattern, 60) # 60 second TTL for consumer group fan-out + logger.info(f"⏰ Set 60s TTL on Redis stream: {stream_pattern}") else: logger.debug(f"No Redis stream found for client {client_id}") @@ -279,8 +359,9 @@ async def _initialize_streaming_session( user_id: str, user_email: str, client_id: str, - audio_format: dict -) -> None: + audio_format: dict, + websocket: Optional[WebSocket] = None +) -> Optional[asyncio.Task]: """ Initialize streaming session with Redis and enqueue processing jobs. 
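With ownership now tracked in Redis by `track_client_user_relationship_async`, a worker in a different container can resolve a client's owner without sharing process memory. A rough sketch using the helpers introduced in `client_manager.py` (illustrative, not part of this PR):

```python
# Illustrative cross-container ownership check built on this PR's client_manager helpers.
from advanced_omi_backend.client_manager import (
    get_client_owner_async,
    initialize_redis_for_client_manager,
)

async def authorize_client(client_id: str, expected_user_id: str) -> bool:
    """Return True if the Redis mapping says this client belongs to the expected user."""
    owner = await get_client_owner_async(client_id)  # falls back to in-memory mapping if Redis is unavailable
    return owner == expected_user_id

# At worker startup (once), e.g.:
# initialize_redis_for_client_manager(os.environ["REDIS_URL"])
```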
@@ -291,10 +372,14 @@ async def _initialize_streaming_session( user_email: User email client_id: Client ID audio_format: Audio format dict from audio-start event + websocket: Optional WebSocket connection to launch interim results subscriber + + Returns: + Interim results subscriber task if websocket provided and session initialized, None otherwise """ if hasattr(client_state, 'stream_session_id'): application_logger.debug(f"Session already initialized for {client_id}") - return + return None # Initialize stream session client_state.stream_session_id = str(uuid.uuid4()) @@ -340,6 +425,16 @@ async def _initialize_streaming_session( client_state.speech_detection_job_id = job_ids['speech_detection'] client_state.audio_persistence_job_id = job_ids['audio_persistence'] + # Launch interim results subscriber if WebSocket provided + subscriber_task = None + if websocket: + subscriber_task = asyncio.create_task( + subscribe_to_interim_results(websocket, client_state.stream_session_id) + ) + application_logger.info(f"📡 Launched interim results subscriber for session {client_state.stream_session_id}") + + return subscriber_task + async def _finalize_streaming_session( client_state, @@ -516,8 +611,9 @@ async def _handle_streaming_mode_audio( audio_format: dict, user_id: str, user_email: str, - client_id: str -) -> None: + client_id: str, + websocket: Optional[WebSocket] = None +) -> Optional[asyncio.Task]: """ Handle audio chunk in streaming mode. @@ -529,16 +625,22 @@ async def _handle_streaming_mode_audio( user_id: User ID user_email: User email client_id: Client ID + websocket: Optional WebSocket connection to launch interim results subscriber + + Returns: + Interim results subscriber task if websocket provided and session initialized, None otherwise """ # Initialize session if needed + subscriber_task = None if not hasattr(client_state, 'stream_session_id'): - await _initialize_streaming_session( + subscriber_task = await _initialize_streaming_session( client_state, audio_stream_producer, user_id, user_email, client_id, - audio_format + audio_format, + websocket=websocket # Pass WebSocket to launch interim results subscriber ) # Publish to Redis Stream @@ -553,6 +655,8 @@ async def _handle_streaming_mode_audio( audio_format.get("width", 2) ) + return subscriber_task + async def _handle_batch_mode_audio( client_state, @@ -589,8 +693,9 @@ async def _handle_audio_chunk( audio_format: dict, user_id: str, user_email: str, - client_id: str -) -> None: + client_id: str, + websocket: Optional[WebSocket] = None +) -> Optional[asyncio.Task]: """ Route audio chunk to appropriate mode handler (streaming or batch). 
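For reference, the frames that end up routed through `_handle_audio_chunk` follow the Wyoming-style convention parsed earlier: a JSON header line, then an optional binary payload. A client-side framing sketch is below; the `audio-start`/`audio-chunk` event names and the `rate`/`width`/`channels` fields appear in this diff, while `payload_length` and the `audio-stop` event are assumptions based on the Wyoming convention and should be checked against the actual parser.

```python
# Hypothetical client-side framing for the PCM streaming endpoint (not part of this PR).
import json

async def stream_pcm(ws, pcm_chunks, rate: int = 16000, width: int = 2, channels: int = 1) -> None:
    fmt = {"rate": rate, "width": width, "channels": channels}
    await ws.send(json.dumps({"type": "audio-start", "data": fmt}) + "\n")
    for chunk in pcm_chunks:
        header = {"type": "audio-chunk", "data": fmt, "payload_length": len(chunk)}
        await ws.send(json.dumps(header) + "\n")  # JSON header line
        await ws.send(chunk)                      # binary PCM payload
    await ws.send(json.dumps({"type": "audio-stop"}) + "\n")
```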
@@ -602,18 +707,24 @@ async def _handle_audio_chunk( user_id: User ID user_email: User email client_id: Client ID + websocket: Optional WebSocket connection to launch interim results subscriber + + Returns: + Interim results subscriber task if websocket provided and streaming mode, None otherwise """ recording_mode = getattr(client_state, 'recording_mode', 'batch') if recording_mode == "streaming": - await _handle_streaming_mode_audio( + return await _handle_streaming_mode_audio( client_state, audio_stream_producer, audio_data, - audio_format, user_id, user_email, client_id + audio_format, user_id, user_email, client_id, + websocket=websocket ) else: await _handle_batch_mode_audio( client_state, audio_data, audio_format, client_id ) + return None async def _handle_audio_session_start( @@ -788,6 +899,7 @@ async def handle_omi_websocket( client_id = None client_state = None + interim_subscriber_task = None try: # Setup connection (accept, auth, create client state) @@ -814,13 +926,14 @@ async def handle_omi_websocket( if header["type"] == "audio-start": # Handle audio session start application_logger.info(f"🎙️ OMI audio session started for {client_id}") - await _initialize_streaming_session( + interim_subscriber_task = await _initialize_streaming_session( client_state, audio_stream_producer, user.user_id, user.email, client_id, - header.get("data", {"rate": OMI_SAMPLE_RATE, "width": OMI_SAMPLE_WIDTH, "channels": OMI_CHANNELS}) + header.get("data", {"rate": OMI_SAMPLE_RATE, "width": OMI_SAMPLE_WIDTH, "channels": OMI_CHANNELS}), + websocket=ws # Pass WebSocket to launch interim results subscriber ) elif header["type"] == "audio-chunk" and payload: @@ -883,6 +996,16 @@ async def handle_omi_websocket( except Exception as e: application_logger.error(f"❌ WebSocket error for client {client_id}: {e}", exc_info=True) finally: + # Cancel interim results subscriber task if running + if interim_subscriber_task and not interim_subscriber_task.done(): + interim_subscriber_task.cancel() + try: + await interim_subscriber_task + except asyncio.CancelledError: + application_logger.info(f"Interim subscriber task cancelled for {client_id}") + except Exception as task_error: + application_logger.error(f"Error cancelling interim subscriber task: {task_error}") + # Clean up pending connection tracking pending_connections.discard(pending_client_id) @@ -909,6 +1032,7 @@ async def handle_pcm_websocket( client_id = None client_state = None + interim_subscriber_task = None try: # Setup connection (accept, auth, create client state) @@ -1011,15 +1135,19 @@ async def handle_pcm_websocket( # Route to appropriate mode handler audio_format = control_header.get("data", {}) - await _handle_audio_chunk( + task = await _handle_audio_chunk( client_state, audio_stream_producer, audio_data, audio_format, user.user_id, user.email, - client_id + client_id, + websocket=ws ) + # Store subscriber task if it was created (first streaming chunk) + if task and not interim_subscriber_task: + interim_subscriber_task = task else: application_logger.warning(f"Expected binary payload for audio-chunk, got: {payload_msg.keys()}") else: @@ -1044,15 +1172,19 @@ async def handle_pcm_websocket( # Route to appropriate mode handler with default format default_format = {"rate": 16000, "width": 2, "channels": 1} - await _handle_audio_chunk( + task = await _handle_audio_chunk( client_state, audio_stream_producer, audio_data, default_format, user.user_id, user.email, - client_id + client_id, + websocket=ws ) + # Store subscriber task if it was created 
(first streaming chunk) + if task and not interim_subscriber_task: + interim_subscriber_task = task else: application_logger.warning(f"Unexpected message format in streaming mode: {message.keys()}") @@ -1115,6 +1247,16 @@ async def handle_pcm_websocket( f"❌ PCM WebSocket error for client {client_id}: {e}", exc_info=True ) finally: + # Cancel interim results subscriber task if running + if interim_subscriber_task and not interim_subscriber_task.done(): + interim_subscriber_task.cancel() + try: + await interim_subscriber_task + except asyncio.CancelledError: + application_logger.info(f"Interim subscriber task cancelled for {client_id}") + except Exception as task_error: + application_logger.error(f"Error cancelling interim subscriber task: {task_error}") + # Clean up pending connection tracking pending_connections.discard(pending_client_id) diff --git a/backends/advanced/src/advanced_omi_backend/model_registry.py b/backends/advanced/src/advanced_omi_backend/model_registry.py index 53d919ca..18f464ae 100644 --- a/backends/advanced/src/advanced_omi_backend/model_registry.py +++ b/backends/advanced/src/advanced_omi_backend/model_registry.py @@ -160,15 +160,15 @@ def validate_model(self) -> ModelDef: class AppModels(BaseModel): """Application models registry. - + Contains default model selections and all available model definitions. """ - + model_config = ConfigDict( extra='allow', validate_assignment=True, ) - + defaults: Dict[str, str] = Field( default_factory=dict, description="Default model names for each model_type" @@ -185,6 +185,10 @@ class AppModels(BaseModel): default_factory=dict, description="Speaker recognition service configuration" ) + chat: Dict[str, Any] = Field( + default_factory=dict, + description="Chat service configuration including system prompt" + ) def get_by_name(self, name: str) -> Optional[ModelDef]: """Get a model by its unique name. @@ -318,6 +322,7 @@ def load_models_config(force_reload: bool = False) -> Optional[AppModels]: model_list = raw.get("models", []) or [] memory_settings = raw.get("memory", {}) or {} speaker_recognition_cfg = raw.get("speaker_recognition", {}) or {} + chat_settings = raw.get("chat", {}) or {} # Parse and validate models using Pydantic models: Dict[str, ModelDef] = {} @@ -336,7 +341,8 @@ def load_models_config(force_reload: bool = False) -> Optional[AppModels]: defaults=defaults, models=models, memory=memory_settings, - speaker_recognition=speaker_recognition_cfg + speaker_recognition=speaker_recognition_cfg, + chat=chat_settings ) return _REGISTRY diff --git a/backends/advanced/src/advanced_omi_backend/plugins/__init__.py b/backends/advanced/src/advanced_omi_backend/plugins/__init__.py new file mode 100644 index 00000000..3ccea7dc --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/plugins/__init__.py @@ -0,0 +1,18 @@ +""" +Chronicle plugin system for multi-level pipeline extension. 
+ +Plugins can hook into different stages of the processing pipeline: +- transcript: When new transcript segment arrives +- conversation: When conversation processing completes +- memory: After memory extraction finishes + +Trigger types control when plugins execute: +- wake_word: Only when transcript starts with specified wake word +- always: Execute on every invocation at access level +- conditional: Execute based on custom condition (future) +""" + +from .base import BasePlugin, PluginContext, PluginResult +from .router import PluginRouter + +__all__ = ['BasePlugin', 'PluginContext', 'PluginResult', 'PluginRouter'] diff --git a/backends/advanced/src/advanced_omi_backend/plugins/base.py b/backends/advanced/src/advanced_omi_backend/plugins/base.py new file mode 100644 index 00000000..84fc8967 --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/plugins/base.py @@ -0,0 +1,131 @@ +""" +Base plugin classes for Chronicle multi-level plugin architecture. + +Provides: +- PluginContext: Context passed to plugin execution +- PluginResult: Result from plugin execution +- BasePlugin: Abstract base class for all plugins +""" +from abc import ABC, abstractmethod +from typing import Optional, Dict, Any, List +from dataclasses import dataclass, field + + +@dataclass +class PluginContext: + """Context passed to plugin execution""" + user_id: str + access_level: str + data: Dict[str, Any] # Access-level specific data + metadata: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class PluginResult: + """Result from plugin execution""" + success: bool + data: Optional[Dict[str, Any]] = None + message: Optional[str] = None + should_continue: bool = True # Whether to continue normal processing + + +class BasePlugin(ABC): + """ + Base class for all Chronicle plugins. + + Plugins can hook into different stages of the processing pipeline: + - transcript: When new transcript segment arrives + - conversation: When conversation processing completes + - memory: When memory extraction finishes + + Subclasses should: + 1. Set SUPPORTED_ACCESS_LEVELS to list which levels they support + 2. Implement initialize() for plugin initialization + 3. Implement the appropriate callback methods (on_transcript, on_conversation_complete, on_memory_processed) + 4. Optionally implement cleanup() for resource cleanup + """ + + # Subclasses declare which access levels they support + SUPPORTED_ACCESS_LEVELS: List[str] = [] + + def __init__(self, config: Dict[str, Any]): + """ + Initialize plugin with configuration. + + Args: + config: Plugin configuration from config/plugins.yml + Contains: enabled, access_level, trigger, and plugin-specific config + """ + self.config = config + self.enabled = config.get('enabled', False) + self.access_level = config.get('access_level') + self.trigger = config.get('trigger', {'type': 'always'}) + + @abstractmethod + async def initialize(self): + """ + Initialize plugin resources (connect to services, etc.) + + Called during application startup after plugin registration. + Raise an exception if initialization fails. + """ + pass + + async def cleanup(self): + """ + Clean up plugin resources. + + Called during application shutdown. + Override if your plugin needs cleanup (closing connections, etc.) + """ + pass + + # Access-level specific methods (implement only what you need) + + async def on_transcript(self, context: PluginContext) -> Optional[PluginResult]: + """ + Called when new transcript segment arrives. 
+ + Context data contains: + - transcript: str - The transcript text + - segment_id: str - Unique segment identifier + - conversation_id: str - Current conversation ID + + For wake_word triggers, router adds: + - command: str - Command with wake word stripped + - original_transcript: str - Full transcript + + Returns: + PluginResult with success status, optional message, and should_continue flag + """ + pass + + async def on_conversation_complete(self, context: PluginContext) -> Optional[PluginResult]: + """ + Called when conversation processing completes. + + Context data contains: + - conversation: dict - Full conversation data + - transcript: str - Complete transcript + - duration: float - Conversation duration + - conversation_id: str - Conversation identifier + + Returns: + PluginResult with success status, optional message, and should_continue flag + """ + pass + + async def on_memory_processed(self, context: PluginContext) -> Optional[PluginResult]: + """ + Called after memory extraction finishes. + + Context data contains: + - memories: list - Extracted memories + - conversation: dict - Source conversation + - memory_count: int - Number of memories created + - conversation_id: str - Conversation identifier + + Returns: + PluginResult with success status, optional message, and should_continue flag + """ + pass diff --git a/backends/advanced/src/advanced_omi_backend/plugins/homeassistant/__init__.py b/backends/advanced/src/advanced_omi_backend/plugins/homeassistant/__init__.py new file mode 100644 index 00000000..11b831e9 --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/plugins/homeassistant/__init__.py @@ -0,0 +1,9 @@ +""" +Home Assistant plugin for Chronicle. + +Allows control of Home Assistant devices via natural language wake word commands. +""" + +from .plugin import HomeAssistantPlugin + +__all__ = ['HomeAssistantPlugin'] diff --git a/backends/advanced/src/advanced_omi_backend/plugins/homeassistant/command_parser.py b/backends/advanced/src/advanced_omi_backend/plugins/homeassistant/command_parser.py new file mode 100644 index 00000000..cc73626d --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/plugins/homeassistant/command_parser.py @@ -0,0 +1,97 @@ +""" +LLM-based command parser for Home Assistant integration. + +This module provides structured command parsing using LLM to extract +intent, target entities/areas, and parameters from natural language. +""" + +from dataclasses import dataclass, field +from typing import Any, Dict, Optional + + +@dataclass +class ParsedCommand: + """Structured representation of a parsed Home Assistant command.""" + + action: str + """Action to perform (e.g., turn_on, turn_off, set_brightness, toggle)""" + + target_type: str + """Type of target (area, entity, all_in_area)""" + + target: str + """Target identifier (area name or entity name)""" + + entity_type: Optional[str] = None + """Entity domain filter (e.g., light, switch, fan) - None means all types""" + + parameters: Dict[str, Any] = field(default_factory=dict) + """Additional parameters (e.g., brightness_pct=50, color='red')""" + + +# LLM System Prompt for Command Parsing +COMMAND_PARSER_SYSTEM_PROMPT = """You are a smart home command parser for Home Assistant. + +Extract structured information from natural language commands. 
+Return ONLY valid JSON in this exact format (no markdown, no code blocks, no explanation):
+
+{
+    "action": "turn_off",
+    "target_type": "area",
+    "target": "study",
+    "entity_type": "light",
+    "parameters": {}
+}
+
+ACTIONS (choose one):
+- turn_on: Turn on entities
+- turn_off: Turn off entities
+- toggle: Toggle entity state
+- set_brightness: Set brightness level
+- set_color: Set color
+
+TARGET_TYPE (choose one):
+- area: Targeting all entities of a type in an area (e.g., "study lights")
+- all_in_area: Targeting ALL entities in an area (e.g., "everything in study")
+- entity: Targeting a specific entity by name (e.g., "desk lamp")
+
+ENTITY_TYPE (optional, use null if not specified):
+- light: Light entities
+- switch: Switch entities
+- fan: Fan entities
+- cover: Covers/blinds
+- null: All entity types (when target_type is "all_in_area")
+
+PARAMETERS (optional, empty dict if none):
+- brightness_pct: Brightness percentage (0-100)
+- color: Color name (e.g., "red", "blue", "warm white")
+
+EXAMPLES:
+
+Command: "turn off study lights"
+Response: {"action": "turn_off", "target_type": "area", "target": "study", "entity_type": "light", "parameters": {}}
+
+Command: "turn off everything in study"
+Response: {"action": "turn_off", "target_type": "all_in_area", "target": "study", "entity_type": null, "parameters": {}}
+
+Command: "turn on desk lamp"
+Response: {"action": "turn_on", "target_type": "entity", "target": "desk lamp", "entity_type": null, "parameters": {}}
+
+Command: "set study lights to 50%"
+Response: {"action": "set_brightness", "target_type": "area", "target": "study", "entity_type": "light", "parameters": {"brightness_pct": 50}}
+
+Command: "turn on living room fan"
+Response: {"action": "turn_on", "target_type": "area", "target": "living room", "entity_type": "fan", "parameters": {}}
+
+Command: "turn off all lights"
+Response: {"action": "turn_off", "target_type": "entity", "target": "all", "entity_type": "light", "parameters": {}}
+
+Command: "toggle hallway light"
+Response: {"action": "toggle", "target_type": "entity", "target": "hallway light", "entity_type": null, "parameters": {}}
+
+Remember:
+1. Return ONLY the JSON object, no markdown formatting
+2. Use lowercase for action, target_type, target, entity_type
+3. Use null (not "null" string) for missing entity_type
+4. Always include all 5 fields: action, target_type, target, entity_type, parameters
+"""
diff --git a/backends/advanced/src/advanced_omi_backend/plugins/homeassistant/entity_cache.py b/backends/advanced/src/advanced_omi_backend/plugins/homeassistant/entity_cache.py
new file mode 100644
index 00000000..e8624f1b
--- /dev/null
+++ b/backends/advanced/src/advanced_omi_backend/plugins/homeassistant/entity_cache.py
@@ -0,0 +1,133 @@
+"""
+Entity cache for Home Assistant integration.
+
+This module provides caching and lookup functionality for Home Assistant areas and entities.
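# Minimal usage sketch (illustration only, not part of the patch): querying the
# EntityCache defined in this module once it has been populated. The area and
# entity names below are assumed example values.
cache = EntityCache(
    areas=["study"],
    area_entities={"study": ["light.tubelight_3"]},
    entity_details={"light.tubelight_3": {"attributes": {"friendly_name": "Study Light"}}},
)
cache.find_entity_by_name("Study Light")      # -> "light.tubelight_3" (exact friendly_name match)
cache.get_entities_in_area("Study", "light")  # -> ["light.tubelight_3"] (case-insensitive area lookup)
cache.is_stale(max_age_seconds=3600)          # -> False right after a refresh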
+""" + +from dataclasses import dataclass, field +from datetime import datetime +from typing import Dict, List, Optional +import logging + +logger = logging.getLogger(__name__) + + +@dataclass +class EntityCache: + """Cache for Home Assistant areas and entities.""" + + areas: List[str] = field(default_factory=list) + """List of area names (e.g., ["study", "living_room"])""" + + area_entities: Dict[str, List[str]] = field(default_factory=dict) + """Map of area names to entity IDs (e.g., {"study": ["light.tubelight_3"]})""" + + entity_details: Dict[str, Dict] = field(default_factory=dict) + """Full entity state data keyed by entity_id""" + + last_refresh: datetime = field(default_factory=datetime.now) + """Timestamp of last cache refresh""" + + def find_entity_by_name(self, name: str) -> Optional[str]: + """ + Find entity ID by fuzzy name matching. + + Matching priority: + 1. Exact friendly_name match (case-insensitive) + 2. Partial friendly_name match (case-insensitive) + 3. Entity ID match (e.g., "tubelight_3" → "light.tubelight_3") + + Args: + name: Entity name to search for + + Returns: + Entity ID if found, None otherwise + """ + name_lower = name.lower().strip() + + # Step 1: Exact friendly_name match + for entity_id, details in self.entity_details.items(): + friendly_name = details.get('attributes', {}).get('friendly_name', '') + if friendly_name.lower() == name_lower: + logger.debug(f"Exact match: {name} → {entity_id} (friendly_name: {friendly_name})") + return entity_id + + # Step 2: Partial friendly_name match + for entity_id, details in self.entity_details.items(): + friendly_name = details.get('attributes', {}).get('friendly_name', '') + if name_lower in friendly_name.lower(): + logger.debug(f"Partial match: {name} → {entity_id} (friendly_name: {friendly_name})") + return entity_id + + # Step 3: Entity ID match (try adding common domains) + common_domains = ['light', 'switch', 'fan', 'cover'] + for domain in common_domains: + candidate_id = f"{domain}.{name_lower.replace(' ', '_')}" + if candidate_id in self.entity_details: + logger.debug(f"Entity ID match: {name} → {candidate_id}") + return candidate_id + + logger.warning(f"No entity found matching: {name}") + return None + + def get_entities_in_area( + self, + area: str, + entity_type: Optional[str] = None + ) -> List[str]: + """ + Get all entities in an area, optionally filtered by domain. + + Args: + area: Area name (case-insensitive) + entity_type: Entity domain filter (e.g., "light", "switch") + + Returns: + List of entity IDs in the area + """ + area_lower = area.lower().strip() + + # Find matching area (case-insensitive) + matching_area = None + for area_name in self.areas: + if area_name.lower() == area_lower: + matching_area = area_name + break + + if not matching_area: + logger.warning(f"Area not found: {area}") + return [] + + # Get entities in area + entities = self.area_entities.get(matching_area, []) + + # Filter by entity type if specified + if entity_type: + entity_type_lower = entity_type.lower() + entities = [ + e for e in entities + if e.split('.')[0] == entity_type_lower + ] + + logger.debug( + f"Found {len(entities)} entities in area '{matching_area}'" + + (f" (type: {entity_type})" if entity_type else "") + ) + + return entities + + def get_cache_age_seconds(self) -> float: + """Get cache age in seconds.""" + return (datetime.now() - self.last_refresh).total_seconds() + + def is_stale(self, max_age_seconds: int = 3600) -> bool: + """ + Check if cache is stale. 
+ + Args: + max_age_seconds: Maximum cache age before considering stale (default: 1 hour) + + Returns: + True if cache is older than max_age_seconds + """ + return self.get_cache_age_seconds() > max_age_seconds diff --git a/backends/advanced/src/advanced_omi_backend/plugins/homeassistant/mcp_client.py b/backends/advanced/src/advanced_omi_backend/plugins/homeassistant/mcp_client.py new file mode 100644 index 00000000..42ede8dc --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/plugins/homeassistant/mcp_client.py @@ -0,0 +1,421 @@ +""" +MCP client for communicating with Home Assistant's MCP Server. + +Home Assistant exposes an MCP server at /api/mcp that provides tools +for controlling smart home devices. +""" + +import json +import logging +from typing import Any, Dict, List, Optional + +import httpx + +logger = logging.getLogger(__name__) + + +class MCPError(Exception): + """MCP protocol error""" + pass + + +class HAMCPClient: + """ + MCP Client for Home Assistant's /api/mcp endpoint. + + Implements the Model Context Protocol for communicating with + Home Assistant's built-in MCP server. + """ + + def __init__(self, base_url: str, token: str, timeout: int = 30): + """ + Initialize the MCP client. + + Args: + base_url: Base URL of Home Assistant (e.g., http://localhost:8123) + token: Long-lived access token for authentication + timeout: Request timeout in seconds + + """ + self.base_url = base_url.rstrip('/') + self.mcp_url = f"{self.base_url}/api/mcp" + self.token = token + self.timeout = timeout + self.client = httpx.AsyncClient(timeout=timeout) + self._request_id = 0 + + async def close(self): + """Close the HTTP client""" + await self.client.aclose() + + def _next_request_id(self) -> int: + """Generate next request ID""" + self._request_id += 1 + return self._request_id + + async def _send_mcp_request(self, method: str, params: Optional[Dict] = None) -> Dict[str, Any]: + """ + Send MCP protocol request to Home Assistant. + + Args: + method: MCP method name (e.g., "tools/list", "tools/call") + params: Optional method parameters + + Returns: + Response data from MCP server + + Raises: + MCPError: If request fails or returns an error + """ + payload = { + "jsonrpc": "2.0", + "id": self._next_request_id(), + "method": method + } + + if params: + payload["params"] = params + + headers = { + "Authorization": f"Bearer {self.token}", + "Content-Type": "application/json" + } + + try: + logger.debug(f"MCP Request: {method} with params: {params}") + response = await self.client.post( + self.mcp_url, + json=payload, + headers=headers + ) + response.raise_for_status() + + data = response.json() + + # Check for JSON-RPC error + if "error" in data: + error = data["error"] + raise MCPError(f"MCP Error {error.get('code')}: {error.get('message')}") + + return data.get("result", {}) + + except httpx.HTTPStatusError as e: + logger.error(f"HTTP error calling MCP endpoint: {e.response.status_code}") + raise MCPError(f"HTTP {e.response.status_code}: {e.response.text}") + except httpx.RequestError as e: + logger.error(f"Request error calling MCP endpoint: {e}") + raise MCPError(f"Request failed: {e}") + except Exception as e: + logger.error(f"Unexpected error calling MCP endpoint: {e}") + raise MCPError(f"Unexpected error: {e}") + + async def list_tools(self) -> List[Dict[str, Any]]: + """ + Get list of available MCP tools from Home Assistant. 
+ + Returns: + List of tool definitions with schema + + Example tool: + { + "name": "turn_on", + "description": "Turn on a light or switch", + "inputSchema": { + "type": "object", + "properties": { + "entity_id": {"type": "string"} + } + } + } + """ + result = await self._send_mcp_request("tools/list") + tools = result.get("tools", []) + logger.info(f"Retrieved {len(tools)} tools from Home Assistant MCP") + return tools + + async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]: + """ + Execute a tool via MCP. + + Args: + tool_name: Name of the tool to call (e.g., "turn_on", "turn_off") + arguments: Tool arguments (e.g., {"entity_id": "light.hall_light"}) + + Returns: + Tool execution result + + Raises: + MCPError: If tool execution fails + + Example: + >>> await client.call_tool("turn_off", {"entity_id": "light.hall_light"}) + {"success": True} + """ + params = { + "name": tool_name, + "arguments": arguments + } + + logger.info(f"Calling MCP tool '{tool_name}' with args: {arguments}") + result = await self._send_mcp_request("tools/call", params) + + # MCP tool results are wrapped in content blocks + content = result.get("content", []) + if content and isinstance(content, list): + # Extract text content from first block + first_block = content[0] + if isinstance(first_block, dict) and first_block.get("type") == "text": + return {"result": first_block.get("text"), "success": True} + + return result + + async def test_connection(self) -> bool: + """ + Test connection to Home Assistant MCP server. + + Returns: + True if connection successful, False otherwise + """ + try: + tools = await self.list_tools() + logger.info(f"MCP connection test successful ({len(tools)} tools available)") + return True + except Exception as e: + logger.error(f"MCP connection test failed: {e}") + return False + + async def _render_template(self, template: str) -> Any: + """ + Render a Home Assistant template using the Template API. + + Args: + template: Jinja2 template string (e.g., "{{ areas() }}") + + Returns: + Rendered template result (parsed as JSON if possible) + + Raises: + MCPError: If template rendering fails + + Example: + >>> await client._render_template("{{ areas() }}") + ["study", "living_room", "bedroom"] + """ + headers = { + "Authorization": f"Bearer {self.token}", + "Content-Type": "application/json" + } + + payload = {"template": template} + + try: + logger.debug(f"Rendering template: {template}") + response = await self.client.post( + f"{self.base_url}/api/template", + json=payload, + headers=headers + ) + response.raise_for_status() + + result = response.text.strip() + + # Try to parse as JSON (for lists, dicts) + if result.startswith('[') or result.startswith('{'): + try: + return json.loads(result) + except json.JSONDecodeError: + logger.warning(f"Failed to parse template result as JSON: {result}") + return result + + return result + + except httpx.HTTPStatusError as e: + logger.error(f"HTTP error rendering template: {e.response.status_code}") + raise MCPError(f"HTTP {e.response.status_code}: {e.response.text}") + except httpx.RequestError as e: + logger.error(f"Request error rendering template: {e}") + raise MCPError(f"Request failed: {e}") + + async def fetch_areas(self) -> List[str]: + """ + Fetch all areas from Home Assistant using Template API. 
+ + Returns: + List of area names + + Example: + >>> await client.fetch_areas() + ["study", "living_room", "bedroom"] + """ + template = "{{ areas() | to_json }}" + areas = await self._render_template(template) + + if isinstance(areas, list): + logger.info(f"Fetched {len(areas)} areas from Home Assistant") + return areas + else: + logger.warning(f"Unexpected areas format: {type(areas)}") + return [] + + async def fetch_area_entities(self, area_name: str) -> List[str]: + """ + Fetch all entity IDs in a specific area. + + Args: + area_name: Name of the area + + Returns: + List of entity IDs in the area + + Example: + >>> await client.fetch_area_entities("study") + ["light.tubelight_3", "switch.desk_fan"] + """ + template = f"{{{{ area_entities('{area_name}') | to_json }}}}" + entities = await self._render_template(template) + + if isinstance(entities, list): + logger.info(f"Fetched {len(entities)} entities from area '{area_name}'") + return entities + else: + logger.warning(f"Unexpected entities format for area '{area_name}': {type(entities)}") + return [] + + async def fetch_entity_states(self) -> Dict[str, Dict]: + """ + Fetch all entity states from Home Assistant. + + Returns: + Dict mapping entity_id to state data (includes attributes, area_id) + + Example: + >>> await client.fetch_entity_states() + { + "light.tubelight_3": { + "state": "on", + "attributes": {"friendly_name": "Study Light", ...}, + "area_id": "study" + } + } + """ + headers = { + "Authorization": f"Bearer {self.token}", + "Content-Type": "application/json" + } + + try: + logger.debug("Fetching all entity states") + response = await self.client.get( + f"{self.base_url}/api/states", + headers=headers + ) + response.raise_for_status() + + states = response.json() + entity_details = {} + + # Enrich with area information + for state in states: + entity_id = state.get('entity_id') + if entity_id: + # Get area_id using Template API + try: + area_template = f"{{{{ area_id('{entity_id}') }}}}" + area_id = await self._render_template(area_template) + state['area_id'] = area_id if area_id else None + except Exception as e: + logger.debug(f"Failed to get area for {entity_id}: {e}") + state['area_id'] = None + + entity_details[entity_id] = state + + logger.info(f"Fetched {len(entity_details)} entity states") + return entity_details + + except httpx.HTTPStatusError as e: + logger.error(f"HTTP error fetching states: {e.response.status_code}") + raise MCPError(f"HTTP {e.response.status_code}: {e.response.text}") + except httpx.RequestError as e: + logger.error(f"Request error fetching states: {e}") + raise MCPError(f"Request failed: {e}") + + async def call_service( + self, + domain: str, + service: str, + entity_ids: List[str], + **parameters + ) -> Dict[str, Any]: + """ + Call a Home Assistant service directly via REST API. 
+ + Args: + domain: Service domain (e.g., "light", "switch") + service: Service name (e.g., "turn_on", "turn_off") + entity_ids: List of entity IDs to target + **parameters: Additional service parameters (e.g., brightness_pct=50) + + Returns: + Service call response + + Example: + >>> await client.call_service("light", "turn_on", ["light.study"], brightness_pct=50) + [{"entity_id": "light.study", "state": "on"}] + """ + headers = { + "Authorization": f"Bearer {self.token}", + "Content-Type": "application/json" + } + + payload = { + "entity_id": entity_ids, + **parameters + } + + service_url = f"{self.base_url}/api/services/{domain}/{service}" + + try: + logger.info(f"Calling service {domain}.{service} for {len(entity_ids)} entities") + logger.debug(f"Service payload: {payload}") + + response = await self.client.post( + service_url, + json=payload, + headers=headers + ) + response.raise_for_status() + + result = response.json() + logger.info(f"Service call successful: {domain}.{service}") + return result + + except httpx.HTTPStatusError as e: + logger.error(f"HTTP error calling service: {e.response.status_code}") + raise MCPError(f"HTTP {e.response.status_code}: {e.response.text}") + except httpx.RequestError as e: + logger.error(f"Request error calling service: {e}") + raise MCPError(f"Request failed: {e}") + + async def discover_entities(self) -> Dict[str, Dict]: + """ + Discover available entities from MCP tools. + + Parses the available tools to build an index of entities + that can be controlled. + + Returns: + Dict mapping entity_id to metadata + """ + tools = await self.list_tools() + entities = {} + + for tool in tools: + # Extract entity information from tool schemas + # This will depend on how HA MCP structures its tools + # For now, we'll just log what we find + logger.debug(f"Tool: {tool.get('name')} - {tool.get('description')}") + + # TODO: Parse tool schemas to extract entity_id information + # For now, return empty dict - will be populated based on actual HA MCP response + + return entities diff --git a/backends/advanced/src/advanced_omi_backend/plugins/homeassistant/plugin.py b/backends/advanced/src/advanced_omi_backend/plugins/homeassistant/plugin.py new file mode 100644 index 00000000..931dd813 --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/plugins/homeassistant/plugin.py @@ -0,0 +1,598 @@ +""" +Home Assistant plugin for Chronicle. + +Enables control of Home Assistant devices through natural language commands +triggered by a wake word. +""" + +import json +import logging +from typing import Any, Dict, List, Optional + +from ..base import BasePlugin, PluginContext, PluginResult +from .entity_cache import EntityCache +from .mcp_client import HAMCPClient, MCPError + +logger = logging.getLogger(__name__) + + +class HomeAssistantPlugin(BasePlugin): + """ + Plugin for controlling Home Assistant devices via wake word commands. + + Example: + User says: "Vivi, turn off the hall lights" + -> Wake word "vivi" detected by router + -> Command "turn off the hall lights" passed to on_transcript() + -> Plugin parses command and calls HA MCP to execute + -> Returns: PluginResult with "I've turned off the hall light" + """ + + SUPPORTED_ACCESS_LEVELS: List[str] = ['transcript'] + + def __init__(self, config: Dict[str, Any]): + """ + Initialize Home Assistant plugin. 
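# Minimal sketch (illustration only, not part of the patch) of the config dict
# this plugin receives from config/plugins.yml, mirroring the keys documented
# in the Args below. All values are assumed examples; ${HA_TOKEN} stands for a
# placeholder expanded from the environment by the plugin service.
example_config = {
    "enabled": True,
    "access_level": "transcript",
    "trigger": {"type": "wake_word", "wake_word": "vivi"},
    "ha_url": "http://localhost:8123",
    "ha_token": "${HA_TOKEN}",
    "timeout": 30,
}
# plugin = HomeAssistantPlugin(example_config); then `await plugin.initialize()` at startup.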
+ + Args: + config: Plugin configuration with keys: + - ha_url: Home Assistant URL + - ha_token: Long-lived access token + - wake_word: Wake word for triggering commands (handled by router) + - enabled: Whether plugin is enabled + - access_level: Should be 'transcript' + - trigger: Should be {'type': 'wake_word', 'wake_word': '...'} + """ + super().__init__(config) + self.mcp_client: Optional[HAMCPClient] = None + self.available_tools: List[Dict] = [] + self.entities: Dict[str, Dict] = {} + + # Entity cache for area-based commands + self.entity_cache: Optional[EntityCache] = None + self.cache_initialized = False + + # Configuration + self.ha_url = config.get('ha_url', 'http://localhost:8123') + self.ha_token = config.get('ha_token', '') + self.wake_word = config.get('wake_word', 'vivi') + self.timeout = config.get('timeout', 30) + + async def initialize(self): + """ + Initialize the Home Assistant plugin. + + Connects to Home Assistant MCP server and discovers available tools. + + Raises: + MCPError: If connection or discovery fails + """ + if not self.enabled: + logger.info("Home Assistant plugin is disabled, skipping initialization") + return + + if not self.ha_token: + raise ValueError("Home Assistant token is required") + + logger.info(f"Initializing Home Assistant plugin (URL: {self.ha_url})") + + # Create MCP client (used for REST API calls, not MCP protocol) + self.mcp_client = HAMCPClient( + base_url=self.ha_url, + token=self.ha_token, + timeout=self.timeout + ) + + # Test basic API connectivity with Template API + try: + logger.info("Testing Home Assistant API connectivity...") + test_result = await self.mcp_client._render_template("{{ 1 + 1 }}") + if str(test_result).strip() != "2": + raise ValueError(f"Unexpected template result: {test_result}") + logger.info("Home Assistant API connection successful") + except Exception as e: + raise MCPError(f"Failed to connect to Home Assistant API: {e}") + + logger.info("Home Assistant plugin initialized successfully") + + async def on_transcript(self, context: PluginContext) -> Optional[PluginResult]: + """ + Execute Home Assistant command from wake word transcript. + + Called by the router when a wake word is detected in the transcript. + The router has already stripped the wake word and extracted the command. 
+ + Args: + context: PluginContext containing: + - user_id: User ID who issued the command + - access_level: 'transcript' + - data: Dict with: + - command: str - Command with wake word already stripped + - original_transcript: str - Full transcript with wake word + - transcript: str - Original transcript + - segment_id: str - Unique segment identifier + - conversation_id: str - Current conversation ID + - metadata: Optional additional metadata + + Returns: + PluginResult with: + - success: True if command executed + - message: User-friendly response + - data: Dict with action details + - should_continue: False to stop normal processing + + Example: + Context data: + { + 'command': 'turn off study lights', + 'original_transcript': 'vivi turn off study lights', + 'conversation_id': 'conv_123' + } + + Returns: + PluginResult( + success=True, + message="I've turned off 1 light in study", + data={'action': 'turn_off', 'entity_ids': ['light.tubelight_3']}, + should_continue=False + ) + """ + command = context.data.get('command', '') + + if not command: + return PluginResult( + success=False, + message="No command provided", + should_continue=True + ) + + if not self.mcp_client: + logger.error("MCP client not initialized") + return PluginResult( + success=False, + message="Sorry, Home Assistant is not connected", + should_continue=True + ) + + try: + # Step 1: Parse command using hybrid LLM + fallback parsing + logger.info(f"Processing HA command: '{command}'") + parsed = await self._parse_command_hybrid(command) + + if not parsed: + return PluginResult( + success=False, + message="Sorry, I couldn't understand that command", + should_continue=True + ) + + # Step 2: Resolve entities from parsed command + try: + entity_ids = await self._resolve_entities(parsed) + except ValueError as e: + logger.warning(f"Entity resolution failed: {e}") + return PluginResult( + success=False, + message=str(e), + should_continue=True + ) + + # Step 3: Determine service and domain + # Extract domain from first entity (all should have same domain for area-based) + domain = entity_ids[0].split('.')[0] if entity_ids else 'light' + + # Map action to service name + service_map = { + 'turn_on': 'turn_on', + 'turn_off': 'turn_off', + 'toggle': 'toggle', + 'set_brightness': 'turn_on', # brightness uses turn_on with params + 'set_color': 'turn_on' # color uses turn_on with params + } + service = service_map.get(parsed.action, 'turn_on') + + # Step 4: Call Home Assistant service + logger.info( + f"Calling {domain}.{service} for {len(entity_ids)} entities: {entity_ids}" + ) + + result = await self.mcp_client.call_service( + domain=domain, + service=service, + entity_ids=entity_ids, + **parsed.parameters + ) + + # Step 5: Format user-friendly response + entity_type_name = parsed.entity_type or domain + if parsed.target_type == 'area': + message = ( + f"I've {parsed.action.replace('_', ' ')} {len(entity_ids)} " + f"{entity_type_name}{'s' if len(entity_ids) != 1 else ''} " + f"in {parsed.target}" + ) + elif parsed.target_type == 'all_in_area': + message = ( + f"I've {parsed.action.replace('_', ' ')} {len(entity_ids)} " + f"entities in {parsed.target}" + ) + else: + message = f"I've {parsed.action.replace('_', ' ')} {parsed.target}" + + logger.info(f"HA command executed successfully: {message}") + + return PluginResult( + success=True, + data={ + 'action': parsed.action, + 'entity_ids': entity_ids, + 'target_type': parsed.target_type, + 'target': parsed.target, + 'ha_result': result + }, + message=message, + should_continue=False # 
Stop normal processing - HA command handled + ) + + except MCPError as e: + logger.error(f"Home Assistant API error: {e}", exc_info=True) + return PluginResult( + success=False, + message=f"Sorry, Home Assistant couldn't execute that: {e}", + should_continue=True + ) + except Exception as e: + logger.error(f"Command execution failed: {e}", exc_info=True) + return PluginResult( + success=False, + message="Sorry, something went wrong while executing that command", + should_continue=True + ) + + async def cleanup(self): + """Clean up resources""" + if self.mcp_client: + await self.mcp_client.close() + logger.info("Closed Home Assistant MCP client") + + async def _ensure_cache_initialized(self): + """Ensure entity cache is initialized. Lazy-load on first use.""" + if not self.cache_initialized: + logger.info("Entity cache not initialized, refreshing...") + await self._refresh_cache() + self.cache_initialized = True + + async def _refresh_cache(self): + """ + Refresh the entity cache from Home Assistant. + + Fetches: + - All areas + - Entities in each area + - Entity state details + """ + if not self.mcp_client: + logger.error("Cannot refresh cache: MCP client not initialized") + return + + try: + logger.info("Refreshing entity cache from Home Assistant...") + + # Fetch all areas + areas = await self.mcp_client.fetch_areas() + logger.debug(f"Fetched {len(areas)} areas: {areas}") + + # Fetch entities for each area + area_entities = {} + for area in areas: + entities = await self.mcp_client.fetch_area_entities(area) + area_entities[area] = entities + logger.debug(f"Area '{area}': {len(entities)} entities") + + # Fetch all entity states + entity_details = await self.mcp_client.fetch_entity_states() + logger.debug(f"Fetched {len(entity_details)} entity states") + + # Create cache + from datetime import datetime + self.entity_cache = EntityCache( + areas=areas, + area_entities=area_entities, + entity_details=entity_details, + last_refresh=datetime.now() + ) + + logger.info( + f"Entity cache refreshed: {len(areas)} areas, " + f"{len(entity_details)} entities" + ) + + except Exception as e: + logger.error(f"Failed to refresh entity cache: {e}", exc_info=True) + raise + + async def _parse_command_with_llm(self, command: str) -> Optional['ParsedCommand']: + """ + Parse command using LLM with structured system prompt. 
+ + Args: + command: Natural language command (wake word already stripped) + + Returns: + ParsedCommand if parsing succeeds, None otherwise + + Example: + >>> await self._parse_command_with_llm("turn off study lights") + ParsedCommand( + action="turn_off", + target_type="area", + target="study", + entity_type="light", + parameters={} + ) + """ + try: + from advanced_omi_backend.llm_client import get_llm_client + from .command_parser import COMMAND_PARSER_SYSTEM_PROMPT, ParsedCommand + + llm_client = get_llm_client() + + logger.debug(f"Parsing command with LLM: '{command}'") + + # Use OpenAI chat format with system + user messages + response = llm_client.client.chat.completions.create( + model=llm_client.model, + messages=[ + {"role": "system", "content": COMMAND_PARSER_SYSTEM_PROMPT}, + {"role": "user", "content": f'Command: "{command}"\n\nReturn JSON only.'} + ], + temperature=0.1, + max_tokens=150 + ) + + result_text = response.choices[0].message.content.strip() + logger.debug(f"LLM response: {result_text}") + + # Remove markdown code blocks if present + if result_text.startswith('```'): + lines = result_text.split('\n') + result_text = '\n'.join(lines[1:-1]) if len(lines) > 2 else result_text + result_text = result_text.strip() + + # Parse JSON response + result_json = json.loads(result_text) + + # Validate required fields + required_fields = ['action', 'target_type', 'target'] + if not all(field in result_json for field in required_fields): + logger.warning(f"LLM response missing required fields: {result_json}") + return None + + parsed = ParsedCommand( + action=result_json['action'], + target_type=result_json['target_type'], + target=result_json['target'], + entity_type=result_json.get('entity_type'), + parameters=result_json.get('parameters', {}) + ) + + logger.info( + f"LLM parsed command: action={parsed.action}, " + f"target_type={parsed.target_type}, target={parsed.target}, " + f"entity_type={parsed.entity_type}" + ) + + return parsed + + except json.JSONDecodeError as e: + logger.error(f"Failed to parse LLM JSON response: {e}\nResponse: {result_text}") + return None + except Exception as e: + logger.error(f"LLM command parsing failed: {e}", exc_info=True) + return None + + async def _resolve_entities(self, parsed: 'ParsedCommand') -> List[str]: + """ + Resolve ParsedCommand to actual Home Assistant entity IDs. + + Args: + parsed: ParsedCommand from LLM parsing + + Returns: + List of entity IDs to target + + Raises: + ValueError: If target not found or ambiguous + + Example: + >>> await self._resolve_entities(ParsedCommand( + ... action="turn_off", + ... target_type="area", + ... target="study", + ... entity_type="light" + ... )) + ["light.tubelight_3"] + """ + from .command_parser import ParsedCommand + + # Ensure cache is ready + await self._ensure_cache_initialized() + + if not self.entity_cache: + raise ValueError("Entity cache not initialized") + + if parsed.target_type == 'area': + # Get entities in area, filtered by type + entities = self.entity_cache.get_entities_in_area( + area=parsed.target, + entity_type=parsed.entity_type + ) + + if not entities: + entity_desc = f"{parsed.entity_type}s" if parsed.entity_type else "entities" + raise ValueError( + f"No {entity_desc} found in area '{parsed.target}'. 
" + f"Available areas: {', '.join(self.entity_cache.areas)}" + ) + + logger.info( + f"Resolved area '{parsed.target}' to {len(entities)} " + f"{parsed.entity_type or 'entity'}(s)" + ) + return entities + + elif parsed.target_type == 'all_in_area': + # Get ALL entities in area (no filter) + entities = self.entity_cache.get_entities_in_area( + area=parsed.target, + entity_type=None + ) + + if not entities: + raise ValueError( + f"No entities found in area '{parsed.target}'. " + f"Available areas: {', '.join(self.entity_cache.areas)}" + ) + + logger.info(f"Resolved 'all in {parsed.target}' to {len(entities)} entities") + return entities + + elif parsed.target_type == 'entity': + # Fuzzy match entity by name + entity_id = self.entity_cache.find_entity_by_name(parsed.target) + + if not entity_id: + raise ValueError( + f"Entity '{parsed.target}' not found. " + f"Try being more specific or check the entity name." + ) + + logger.info(f"Resolved entity '{parsed.target}' to {entity_id}") + return [entity_id] + + else: + raise ValueError(f"Unknown target type: {parsed.target_type}") + + async def _parse_command_fallback(self, command: str) -> Optional[Dict[str, Any]]: + """ + Fallback keyword-based command parser (used when LLM fails). + + Args: + command: Natural language command + + Returns: + Dict with 'tool', 'arguments', and optional metadata + None if parsing fails + + Example: + Input: "turn off the hall lights" + Output: { + "tool": "turn_off", + "arguments": {"entity_id": "light.hall_light"}, + "friendly_name": "Hall Light", + "action": "turn_off" + } + """ + logger.debug("Using fallback keyword-based parsing") + command_lower = command.lower().strip() + + # Determine action + tool = None + if any(word in command_lower for word in ['turn off', 'off', 'disable']): + tool = 'turn_off' + action_desc = 'turned off' + elif any(word in command_lower for word in ['turn on', 'on', 'enable']): + tool = 'turn_on' + action_desc = 'turned on' + elif 'toggle' in command_lower: + tool = 'toggle' + action_desc = 'toggled' + else: + logger.warning(f"Unknown action in command: {command}") + return None + + # Extract entity name from command + entity_query = command_lower + for action_word in ['turn off', 'turn on', 'toggle', 'off', 'on', 'the']: + entity_query = entity_query.replace(action_word, '').strip() + + logger.info(f"Searching for entity: '{entity_query}'") + + # Return placeholder (this will work if entity ID matches pattern) + return { + "tool": tool, + "arguments": { + "entity_id": f"light.{entity_query.replace(' ', '_')}" + }, + "friendly_name": entity_query.title(), + "action_desc": action_desc + } + + async def _parse_command_hybrid(self, command: str) -> Optional['ParsedCommand']: + """ + Hybrid command parser: Try LLM first, fallback to keywords. + + This provides the best of both worlds: + - LLM parsing for complex area-based and natural commands + - Keyword fallback for reliability when LLM fails or times out + + Args: + command: Natural language command + + Returns: + ParsedCommand if successful, None otherwise + + Example: + >>> await self._parse_command_hybrid("turn off study lights") + ParsedCommand(action="turn_off", target_type="area", target="study", ...) 
+ """ + import asyncio + from .command_parser import ParsedCommand + + # Try LLM parsing with timeout + try: + logger.debug("Attempting LLM-based command parsing...") + parsed = await asyncio.wait_for( + self._parse_command_with_llm(command), + timeout=5.0 + ) + + if parsed: + logger.info("LLM parsing succeeded") + return parsed + else: + logger.warning("LLM parsing returned None, falling back to keywords") + + except asyncio.TimeoutError: + logger.warning("LLM parsing timed out (>5s), falling back to keywords") + except Exception as e: + logger.warning(f"LLM parsing failed: {e}, falling back to keywords") + + # Fallback to keyword-based parsing + try: + logger.debug("Using fallback keyword parsing...") + fallback_result = await self._parse_command_fallback(command) + + if not fallback_result: + return None + + # Convert fallback format to ParsedCommand + # Extract entity_id from arguments + entity_id = fallback_result['arguments'].get('entity_id', '') + entity_name = entity_id.split('.', 1)[1] if '.' in entity_id else entity_id + + # Simple heuristic: assume it's targeting a single entity + parsed = ParsedCommand( + action=fallback_result['tool'], + target_type='entity', + target=entity_name.replace('_', ' '), + entity_type=None, + parameters={} + ) + + logger.info("Fallback parsing succeeded") + return parsed + + except Exception as e: + logger.error(f"Fallback parsing failed: {e}", exc_info=True) + return None diff --git a/backends/advanced/src/advanced_omi_backend/plugins/router.py b/backends/advanced/src/advanced_omi_backend/plugins/router.py new file mode 100644 index 00000000..8074feb3 --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/plugins/router.py @@ -0,0 +1,242 @@ +""" +Plugin routing system for multi-level plugin architecture. + +Routes pipeline events to appropriate plugins based on access level and triggers. +""" + +import logging +import re +import string +from typing import Dict, List, Optional + +from .base import BasePlugin, PluginContext, PluginResult + +logger = logging.getLogger(__name__) + + +def normalize_text_for_wake_word(text: str) -> str: + """ + Normalize text for wake word matching. + - Lowercase + - Replace punctuation with spaces + - Collapse multiple spaces to single space + - Strip leading/trailing whitespace + + Example: + "Hey, Vivi!" -> "hey vivi" + "HEY VIVI" -> "hey vivi" + "Hey-Vivi" -> "hey vivi" + """ + # Lowercase + text = text.lower() + # Replace punctuation with spaces (instead of removing, to preserve word boundaries) + text = text.translate(str.maketrans(string.punctuation, ' ' * len(string.punctuation))) + # Normalize whitespace (collapse multiple spaces to single space) + text = re.sub(r'\s+', ' ', text) + # Strip leading/trailing whitespace + return text.strip() + + +def extract_command_after_wake_word(transcript: str, wake_word: str) -> str: + """ + Intelligently extract command after wake word in original transcript. + + Handles punctuation and spacing variations by creating a flexible regex pattern. 
+ + Example: + transcript: "Hey, Vivi, turn off lights" + wake_word: "hey vivi" + -> extracts: "turn off lights" + + Args: + transcript: Original transcript text with punctuation + wake_word: Configured wake word (will be normalized) + + Returns: + Command text after wake word, or full transcript if wake word boundary not found + """ + # Split wake word into parts (normalized) + wake_word_parts = normalize_text_for_wake_word(wake_word).split() + + if not wake_word_parts: + return transcript.strip() + + # Create regex pattern that allows punctuation/whitespace between parts + # Example: "hey" + "vivi" -> r"hey[\s,.\-!?]*vivi[\s,.\-!?]*" + # The pattern matches the wake word parts with optional punctuation/whitespace between and after + pattern_parts = [re.escape(part) for part in wake_word_parts] + # Allow optional punctuation/whitespace between parts + pattern = r'[\s,.\-!?;:]*'.join(pattern_parts) + # Add trailing punctuation/whitespace consumption after last wake word part + pattern = '^' + pattern + r'[\s,.\-!?;:]*' + + # Try to match wake word at start of transcript (case-insensitive) + match = re.match(pattern, transcript, re.IGNORECASE) + + if match: + # Extract everything after the matched wake word (including trailing punctuation) + command = transcript[match.end():].strip() + return command + else: + # Fallback: couldn't find wake word boundary, return full transcript + logger.warning(f"Could not find wake word boundary for '{wake_word}' in '{transcript}', using full transcript") + return transcript.strip() + + +class PluginRouter: + """Routes pipeline events to appropriate plugins based on access level and triggers""" + + def __init__(self): + self.plugins: Dict[str, BasePlugin] = {} + # Index plugins by access level for fast lookup + self._plugins_by_level: Dict[str, List[str]] = { + 'transcript': [], + 'streaming_transcript': [], + 'conversation': [], + 'memory': [] + } + + def register_plugin(self, plugin_id: str, plugin: BasePlugin): + """Register a plugin with the router""" + self.plugins[plugin_id] = plugin + + # Index by access level + access_level = plugin.access_level + if access_level in self._plugins_by_level: + self._plugins_by_level[access_level].append(plugin_id) + + logger.info(f"Registered plugin '{plugin_id}' for access level '{access_level}'") + + async def trigger_plugins( + self, + access_level: str, + user_id: str, + data: Dict, + metadata: Optional[Dict] = None + ) -> List[PluginResult]: + """ + Trigger all plugins registered for this access level. 
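# Minimal sketch (illustration only, not part of the patch): registering a
# wake-word plugin and routing a transcript event through the router. It uses
# PluginRouter and BasePlugin from this module; the plugin instance, user id,
# and transcript are assumed examples.
async def route_example(ha_plugin: BasePlugin) -> None:
    router = PluginRouter()
    router.register_plugin("homeassistant", ha_plugin)
    results = await router.trigger_plugins(
        access_level="transcript",
        user_id="user_123",
        data={"transcript": "Vivi, turn off study lights", "conversation_id": "conv_123"},
    )
    # With a wake_word trigger, _should_trigger() matches "vivi", sets
    # data["command"] = "turn off study lights", and on_transcript() runs;
    # `results` then holds the PluginResult returned by the plugin.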
+ + Args: + access_level: 'transcript', 'streaming_transcript', 'conversation', or 'memory' + user_id: User ID for context + data: Access-level specific data + metadata: Optional metadata + + Returns: + List of plugin results + """ + results = [] + + # Hierarchical triggering logic: + # - 'streaming_transcript': trigger both 'streaming_transcript' AND 'transcript' plugins + # - 'transcript': trigger ONLY 'transcript' plugins (not 'streaming_transcript') + # - Other levels: exact match only + if access_level == 'streaming_transcript': + # Streaming mode: trigger both streaming_transcript AND transcript plugins + plugin_ids = ( + self._plugins_by_level.get('streaming_transcript', []) + + self._plugins_by_level.get('transcript', []) + ) + else: + # Batch mode or other modes: exact match only + plugin_ids = self._plugins_by_level.get(access_level, []) + + for plugin_id in plugin_ids: + plugin = self.plugins[plugin_id] + + if not plugin.enabled: + continue + + # Check trigger condition + if not await self._should_trigger(plugin, data): + continue + + # Execute plugin at appropriate access level + try: + context = PluginContext( + user_id=user_id, + access_level=access_level, + data=data, + metadata=metadata or {} + ) + + result = await self._execute_plugin(plugin, access_level, context) + + if result: + results.append(result) + + # If plugin says stop processing, break + if not result.should_continue: + logger.info(f"Plugin '{plugin_id}' stopped further processing") + break + + except Exception as e: + logger.error(f"Error executing plugin '{plugin_id}': {e}", exc_info=True) + + return results + + async def _should_trigger(self, plugin: BasePlugin, data: Dict) -> bool: + """Check if plugin should be triggered based on trigger configuration""" + trigger_type = plugin.trigger.get('type', 'always') + + if trigger_type == 'always': + return True + + elif trigger_type == 'wake_word': + # Normalize transcript for matching (handles punctuation and spacing) + transcript = data.get('transcript', '') + normalized_transcript = normalize_text_for_wake_word(transcript) + + # Support both singular 'wake_word' and plural 'wake_words' (list) + wake_words = plugin.trigger.get('wake_words', []) + if not wake_words: + # Fallback to singular wake_word for backward compatibility + wake_word = plugin.trigger.get('wake_word', '') + if wake_word: + wake_words = [wake_word] + + # Check if transcript starts with any wake word (after normalization) + for wake_word in wake_words: + normalized_wake_word = normalize_text_for_wake_word(wake_word) + if normalized_wake_word and normalized_transcript.startswith(normalized_wake_word): + # Smart extraction: find where wake word actually ends in original text + command = extract_command_after_wake_word(transcript, wake_word) + data['command'] = command + data['original_transcript'] = transcript + logger.debug(f"Wake word '{wake_word}' detected. 
Original: '{transcript}', Command: '{command}'") + return True + + return False + + elif trigger_type == 'conditional': + # Future: Custom condition checking + return True + + return False + + async def _execute_plugin( + self, + plugin: BasePlugin, + access_level: str, + context: PluginContext + ) -> Optional[PluginResult]: + """Execute plugin method for specified access level""" + # Both 'transcript' and 'streaming_transcript' call on_transcript() + if access_level in ('transcript', 'streaming_transcript'): + return await plugin.on_transcript(context) + elif access_level == 'conversation': + return await plugin.on_conversation_complete(context) + elif access_level == 'memory': + return await plugin.on_memory_processed(context) + + return None + + async def cleanup_all(self): + """Clean up all registered plugins""" + for plugin_id, plugin in self.plugins.items(): + try: + await plugin.cleanup() + logger.info(f"Cleaned up plugin '{plugin_id}'") + except Exception as e: + logger.error(f"Error cleaning up plugin '{plugin_id}': {e}") diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/system_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/system_routes.py index ead61ffa..93e94817 100644 --- a/backends/advanced/src/advanced_omi_backend/routers/modules/system_routes.py +++ b/backends/advanced/src/advanced_omi_backend/routers/modules/system_routes.py @@ -7,7 +7,8 @@ import logging from typing import Optional -from fastapi import APIRouter, Body, Depends, Request +from fastapi import APIRouter, Body, Depends, HTTPException, Request +from fastapi.responses import JSONResponse, Response from pydantic import BaseModel from advanced_omi_backend.auth import current_active_user, current_superuser @@ -128,6 +129,100 @@ async def delete_all_user_memories(current_user: User = Depends(current_active_u return await system_controller.delete_all_user_memories(current_user) +# Chat Configuration Management Endpoints + +@router.get("/admin/chat/config", response_class=Response) +async def get_chat_config(current_user: User = Depends(current_superuser)): + """Get chat configuration as YAML. Admin only.""" + try: + yaml_content = await system_controller.get_chat_config_yaml() + return Response(content=yaml_content, media_type="text/plain") + except Exception as e: + logger.error(f"Failed to get chat config: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/admin/chat/config") +async def save_chat_config( + request: Request, + current_user: User = Depends(current_superuser) +): + """Save chat configuration from YAML. Admin only.""" + try: + yaml_content = await request.body() + yaml_str = yaml_content.decode('utf-8') + result = await system_controller.save_chat_config_yaml(yaml_str) + return JSONResponse(content=result) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + logger.error(f"Failed to save chat config: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/admin/chat/config/validate") +async def validate_chat_config( + request: Request, + current_user: User = Depends(current_superuser) +): + """Validate chat configuration YAML. 
Admin only.""" + try: + yaml_content = await request.body() + yaml_str = yaml_content.decode('utf-8') + result = await system_controller.validate_chat_config_yaml(yaml_str) + return JSONResponse(content=result) + except Exception as e: + logger.error(f"Failed to validate chat config: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +# Plugin Configuration Management Endpoints + +@router.get("/admin/plugins/config", response_class=Response) +async def get_plugins_config(current_user: User = Depends(current_superuser)): + """Get plugins configuration as YAML. Admin only.""" + try: + yaml_content = await system_controller.get_plugins_config_yaml() + return Response(content=yaml_content, media_type="text/plain") + except Exception as e: + logger.error(f"Failed to get plugins config: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/admin/plugins/config") +async def save_plugins_config( + request: Request, + current_user: User = Depends(current_superuser) +): + """Save plugins configuration from YAML. Admin only.""" + try: + yaml_content = await request.body() + yaml_str = yaml_content.decode('utf-8') + result = await system_controller.save_plugins_config_yaml(yaml_str) + return JSONResponse(content=result) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + logger.error(f"Failed to save plugins config: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/admin/plugins/config/validate") +async def validate_plugins_config( + request: Request, + current_user: User = Depends(current_superuser) +): + """Validate plugins configuration YAML. Admin only.""" + try: + yaml_content = await request.body() + yaml_str = yaml_content.decode('utf-8') + result = await system_controller.validate_plugins_config_yaml(yaml_str) + return JSONResponse(content=result) + except Exception as e: + logger.error(f"Failed to validate plugins config: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + @router.get("/streaming/status") async def get_streaming_status(request: Request, current_user: User = Depends(current_superuser)): """Get status of active streaming sessions and Redis Streams health. Admin only.""" diff --git a/backends/advanced/src/advanced_omi_backend/services/audio_stream/consumer.py b/backends/advanced/src/advanced_omi_backend/services/audio_stream/consumer.py index 8ae0646b..aeb12e02 100644 --- a/backends/advanced/src/advanced_omi_backend/services/audio_stream/consumer.py +++ b/backends/advanced/src/advanced_omi_backend/services/audio_stream/consumer.py @@ -11,8 +11,6 @@ import redis.asyncio as redis from redis import exceptions as redis_exceptions -from redis.asyncio.lock import Lock - logger = logging.getLogger(__name__) @@ -28,8 +26,8 @@ def __init__(self, provider_name: str, redis_client: redis.Redis, buffer_chunks: """ Initialize consumer. - Dynamically discovers all audio:stream:* streams and claims them using Redis locks - to ensure exclusive processing (one consumer per stream). + Dynamically discovers all audio:stream:* streams and uses Redis consumer groups + for fan-out processing (multiple worker types can process the same stream). 
Args: provider_name: Provider name (e.g., "deepgram", "parakeet") @@ -47,9 +45,8 @@ def __init__(self, provider_name: str, redis_client: redis.Redis, buffer_chunks: self.running = False - # Dynamic stream discovery with exclusive locks + # Dynamic stream discovery - consumer groups handle fan-out self.active_streams = {} # {stream_name: True} - self.stream_locks = {} # {stream_name: Lock object} # Buffering: accumulate chunks per session self.session_buffers = {} # {session_id: {"chunks": [], "chunk_ids": [], "sample_rate": int}} @@ -73,59 +70,6 @@ async def discover_streams(self) -> list[str]: return streams - async def try_claim_stream(self, stream_name: str) -> bool: - """ - Try to claim exclusive ownership of a stream using Redis lock. - - Args: - stream_name: Stream to claim - - Returns: - True if lock acquired, False otherwise - """ - lock_key = f"consumer:lock:{stream_name}" - - # Create lock with 30 second timeout (will be renewed) - lock = Lock( - self.redis_client, - lock_key, - timeout=30, - blocking=False # Non-blocking - ) - - acquired = await lock.acquire(blocking=False) - - if acquired: - self.stream_locks[stream_name] = lock - logger.info(f"🔒 Claimed stream: {stream_name}") - return True - else: - logger.debug(f"⏭️ Stream already claimed by another consumer: {stream_name}") - return False - - async def release_stream(self, stream_name: str): - """Release lock on a stream.""" - if stream_name in self.stream_locks: - try: - await self.stream_locks[stream_name].release() - logger.info(f"🔓 Released stream: {stream_name}") - except Exception as e: - logger.warning(f"Failed to release lock for {stream_name}: {e}") - finally: - del self.stream_locks[stream_name] - - async def renew_stream_locks(self): - """Renew locks on all claimed streams.""" - for stream_name, lock in list(self.stream_locks.items()): - try: - await lock.reacquire() - except Exception as e: - logger.warning(f"Failed to renew lock for {stream_name}: {e}") - # Lock expired, remove from our list - del self.stream_locks[stream_name] - if stream_name in self.active_streams: - del self.active_streams[stream_name] - async def setup_consumer_group(self, stream_name: str): """Create consumer group if it doesn't exist.""" # Create consumer group (ignore error if already exists) @@ -257,14 +201,12 @@ async def transcribe_audio(self, audio_data: bytes, sample_rate: int) -> dict: pass async def start_consuming(self): - """Discover and consume from multiple streams with exclusive locking.""" + """Discover and consume from multiple streams using Redis consumer groups.""" self.running = True - logger.info(f"➡️ Starting dynamic stream consumer: {self.consumer_name}") + logger.info(f"➡️ Starting dynamic stream consumer: {self.consumer_name} (group: {self.group_name})") last_discovery = 0 - last_lock_renewal = 0 discovery_interval = 10 # Discover new streams every 10 seconds - lock_renewal_interval = 15 # Renew locks every 15 seconds while self.running: try: @@ -277,20 +219,13 @@ async def start_consuming(self): for stream_name in discovered: if stream_name not in self.active_streams: - # Try to claim this stream - if await self.try_claim_stream(stream_name): - # Setup consumer group for this stream - await self.setup_consumer_group(stream_name) - self.active_streams[stream_name] = True - logger.info(f"✅ Now consuming from {stream_name}") + # Setup consumer group for this stream (no manual lock needed) + await self.setup_consumer_group(stream_name) + self.active_streams[stream_name] = True + logger.info(f"✅ Now consuming from 
{stream_name} (group: {self.group_name})") last_discovery = current_time - # Periodically renew locks - if current_time - last_lock_renewal > lock_renewal_interval: - await self.renew_stream_locks() - last_lock_renewal = current_time - # Read from all active streams if not self.active_streams: # No streams claimed yet, wait and retry @@ -326,14 +261,6 @@ async def start_consuming(self): if stream_name in error_msg: logger.warning(f"➡️ [{self.consumer_name}] Stream {stream_name} was deleted, removing from active streams") - # Release the lock - lock_key = f"consumer:lock:{stream_name}" - try: - await self.redis_client.delete(lock_key) - logger.info(f"🔓 Released lock for deleted stream: {stream_name}") - except: - pass - # Remove from active streams del self.active_streams[stream_name] logger.info(f"➡️ [{self.consumer_name}] Removed {stream_name}, {len(self.active_streams)} streams remaining") @@ -419,9 +346,6 @@ async def process_message(self, message_id: bytes, fields: dict, stream_name: st # Clean up session buffer del self.session_buffers[session_id] - # Release the consumer lock for this stream - await self.release_stream(stream_name) - # ACK the END message await self.redis_client.xack(stream_name, self.group_name, message_id) return diff --git a/backends/advanced/src/advanced_omi_backend/services/memory/providers/vector_stores.py b/backends/advanced/src/advanced_omi_backend/services/memory/providers/vector_stores.py index cf153472..85ee200a 100644 --- a/backends/advanced/src/advanced_omi_backend/services/memory/providers/vector_stores.py +++ b/backends/advanced/src/advanced_omi_backend/services/memory/providers/vector_stores.py @@ -171,19 +171,19 @@ async def search_memories(self, query_embedding: List[float], user_id: str, limi # For cosine similarity, scores range from -1 to 1, where 1 is most similar search_params = { "collection_name": self.collection_name, - "query_vector": query_embedding, + "query": query_embedding, "query_filter": search_filter, "limit": limit } - + if score_threshold > 0.0: search_params["score_threshold"] = score_threshold memory_logger.debug(f"Using similarity threshold: {score_threshold}") - - results = await self.client.search(**search_params) - + + response = await self.client.query_points(**search_params) + memories = [] - for result in results: + for result in response.points: memory = MemoryEntry( id=str(result.id), content=result.payload.get("content", ""), diff --git a/backends/advanced/src/advanced_omi_backend/services/plugin_service.py b/backends/advanced/src/advanced_omi_backend/services/plugin_service.py new file mode 100644 index 00000000..2c0c9988 --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/services/plugin_service.py @@ -0,0 +1,163 @@ +"""Plugin service for accessing the global plugin router. + +This module provides singleton access to the plugin router, allowing +worker jobs to trigger plugins without accessing FastAPI app state directly. +""" + +import logging +import os +import re +from typing import Optional, Any +from pathlib import Path +import yaml + +from advanced_omi_backend.plugins import PluginRouter + +logger = logging.getLogger(__name__) + +# Global plugin router instance +_plugin_router: Optional[PluginRouter] = None + + +def expand_env_vars(value: Any) -> Any: + """ + Recursively expand environment variables in configuration values. + + Supports ${ENV_VAR} syntax. If the environment variable is not set, + the original placeholder is kept. 
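+
+    The ${ENV_VAR:-default} form is also accepted: when the variable is unset,
+    the default after ':-' is used instead of keeping the placeholder, e.g.
+    expand_env_vars('${PORT:-8080}') returns '8080' if PORT is not set.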
+ + Args: + value: Configuration value (can be str, dict, list, or other) + + Returns: + Value with environment variables expanded + + Examples: + >>> os.environ['MY_TOKEN'] = 'secret123' + >>> expand_env_vars('token: ${MY_TOKEN}') + 'token: secret123' + >>> expand_env_vars({'token': '${MY_TOKEN}'}) + {'token': 'secret123'} + """ + if isinstance(value, str): + # Pattern: ${ENV_VAR} or ${ENV_VAR:-default} + def replacer(match): + var_expr = match.group(1) + # Support default values: ${VAR:-default} + if ':-' in var_expr: + var_name, default = var_expr.split(':-', 1) + return os.environ.get(var_name.strip(), default.strip()) + else: + var_name = var_expr.strip() + env_value = os.environ.get(var_name) + if env_value is None: + logger.warning( + f"Environment variable '{var_name}' not found, " + f"keeping placeholder: ${{{var_name}}}" + ) + return match.group(0) # Keep original placeholder + return env_value + + return re.sub(r'\$\{([^}]+)\}', replacer, value) + + elif isinstance(value, dict): + return {k: expand_env_vars(v) for k, v in value.items()} + + elif isinstance(value, list): + return [expand_env_vars(item) for item in value] + + else: + return value + + +def get_plugin_router() -> Optional[PluginRouter]: + """Get the global plugin router instance. + + Returns: + Plugin router instance if initialized, None otherwise + """ + global _plugin_router + return _plugin_router + + +def set_plugin_router(router: PluginRouter) -> None: + """Set the global plugin router instance. + + This should be called during app initialization in app_factory.py. + + Args: + router: Initialized plugin router instance + """ + global _plugin_router + _plugin_router = router + logger.info("Plugin router registered with plugin service") + + +def init_plugin_router() -> Optional[PluginRouter]: + """Initialize the plugin router from configuration. + + This is called during app startup to create and register the plugin router. 
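+
+    Configuration is read from /app/plugins.yml inside the container; ${VAR}
+    placeholders in the file are expanded with expand_env_vars() before the
+    plugins are registered.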
+ + Returns: + Initialized plugin router, or None if no plugins configured + """ + global _plugin_router + + if _plugin_router is not None: + logger.warning("Plugin router already initialized") + return _plugin_router + + try: + _plugin_router = PluginRouter() + + # Load plugin configuration + plugins_yml = Path("/app/plugins.yml") + if plugins_yml.exists(): + with open(plugins_yml, 'r') as f: + plugins_config = yaml.safe_load(f) + # Expand environment variables in configuration + plugins_config = expand_env_vars(plugins_config) + plugins_data = plugins_config.get('plugins', {}) + + # Initialize each enabled plugin + for plugin_id, plugin_config in plugins_data.items(): + if not plugin_config.get('enabled', False): + continue + + try: + if plugin_id == 'homeassistant': + from advanced_omi_backend.plugins.homeassistant import HomeAssistantPlugin + plugin = HomeAssistantPlugin(plugin_config) + # Note: async initialization happens in app_factory lifespan + _plugin_router.register_plugin(plugin_id, plugin) + logger.info(f"✅ Plugin '{plugin_id}' registered") + else: + logger.warning(f"Unknown plugin: {plugin_id}") + + except Exception as e: + logger.error(f"Failed to register plugin '{plugin_id}': {e}", exc_info=True) + + logger.info(f"Plugins registered: {len(_plugin_router.plugins)} total") + else: + logger.info("No plugins.yml found, plugins disabled") + + return _plugin_router + + except Exception as e: + logger.error(f"Failed to initialize plugin router: {e}", exc_info=True) + _plugin_router = None + return None + + +async def cleanup_plugin_router() -> None: + """Clean up the plugin router and all registered plugins.""" + global _plugin_router + + if _plugin_router: + try: + await _plugin_router.cleanup_all() + logger.info("Plugin router cleanup complete") + except Exception as e: + logger.error(f"Error during plugin router cleanup: {e}") + finally: + _plugin_router = None diff --git a/backends/advanced/src/advanced_omi_backend/services/transcription/__init__.py b/backends/advanced/src/advanced_omi_backend/services/transcription/__init__.py index 2e20171b..f481ac3f 100644 --- a/backends/advanced/src/advanced_omi_backend/services/transcription/__init__.py +++ b/backends/advanced/src/advanced_omi_backend/services/transcription/__init__.py @@ -10,6 +10,7 @@ import json import logging from typing import Optional +from urllib.parse import urlencode import httpx import websockets @@ -167,26 +168,65 @@ def __init__(self): def name(self) -> str: return self._name + async def transcribe(self, audio_data: bytes, sample_rate: int, **kwargs) -> dict: + """Not used for streaming providers - use start_stream/process_audio_chunk/end_stream instead.""" + raise NotImplementedError("Streaming providers do not support batch transcription") + async def start_stream(self, client_id: str, sample_rate: int = 16000, diarize: bool = False): - url = self.model.model_url + base_url = self.model.model_url ops = self.model.operations or {} + + # Build WebSocket URL with query parameters (for Deepgram streaming) + query_params = ops.get("query", {}) + query_dict = dict(query_params) if query_params else {} + + # Override sample_rate if provided + if sample_rate and "sample_rate" in query_dict: + query_dict["sample_rate"] = sample_rate + if diarize and "diarize" in query_dict: + query_dict["diarize"] = "true" + + # Normalize boolean values to lowercase strings (Deepgram expects "true"/"false", not "True"/"False") + normalized_query = {} + for k, v in query_dict.items(): + if isinstance(v, bool): + normalized_query[k] = 
"true" if v else "false" + else: + normalized_query[k] = v + + # Build query string with proper URL encoding (NO token in query) + query_str = urlencode(normalized_query) + url = f"{base_url}?{query_str}" if query_str else base_url + + # Debug: Log the URL + logger.info(f"🔗 Connecting to Deepgram WebSocket: {url}") + + # Connect to WebSocket with Authorization header (Deepgram requires this for server-side connections) + headers = {} + if self.model.api_key: + headers["Authorization"] = f"Token {self.model.api_key}" + + ws = await websockets.connect(url, additional_headers=headers) + + # Send start message if required by provider start_msg = (ops.get("start", {}) or {}).get("message", {}) - # Inject session_id if placeholder present - start_msg = json.loads(json.dumps(start_msg)) # deep copy - start_msg.setdefault("session_id", client_id) - # Apply sample rate and diarization if present - if "config" in start_msg and isinstance(start_msg["config"], dict): - start_msg["config"].setdefault("sample_rate", sample_rate) - if diarize: - start_msg["config"]["diarize"] = True - - ws = await websockets.connect(url, open_timeout=10) - await ws.send(json.dumps(start_msg)) - # Wait for confirmation; non-fatal if not provided - try: - await asyncio.wait_for(ws.recv(), timeout=2.0) - except Exception: - pass + if start_msg: + # Inject session_id if placeholder present + start_msg = json.loads(json.dumps(start_msg)) # deep copy + start_msg.setdefault("session_id", client_id) + # Apply sample rate and diarization if present + if "config" in start_msg and isinstance(start_msg["config"], dict): + start_msg["config"].setdefault("sample_rate", sample_rate) + if diarize: + start_msg["config"]["diarize"] = True + await ws.send(json.dumps(start_msg)) + + # Wait for confirmation; non-fatal if not provided + try: + await asyncio.wait_for(ws.recv(), timeout=2.0) + except Exception: + pass + self._streams[client_id] = {"ws": ws, "sample_rate": sample_rate, "final": None, "interim": []} async def process_audio_chunk(self, client_id: str, audio_chunk: bytes) -> dict | None: @@ -194,26 +234,67 @@ async def process_audio_chunk(self, client_id: str, audio_chunk: bytes) -> dict return None ws = self._streams[client_id]["ws"] ops = self.model.operations or {} + + # Send chunk header if required (for providers like Parakeet) chunk_hdr = (ops.get("chunk_header", {}) or {}).get("message", {}) - hdr = json.loads(json.dumps(chunk_hdr)) - hdr.setdefault("type", "audio_chunk") - hdr.setdefault("session_id", client_id) - hdr.setdefault("rate", self._streams[client_id]["sample_rate"]) - await ws.send(json.dumps(hdr)) + if chunk_hdr: + hdr = json.loads(json.dumps(chunk_hdr)) + hdr.setdefault("type", "audio_chunk") + hdr.setdefault("session_id", client_id) + hdr.setdefault("rate", self._streams[client_id]["sample_rate"]) + await ws.send(json.dumps(hdr)) + + # Send audio chunk (raw bytes for Deepgram, or after header for others) await ws.send(audio_chunk) - # Non-blocking read for interim results + # Non-blocking read for results expect = (ops.get("expect", {}) or {}) + extract = expect.get("extract", {}) interim_type = expect.get("interim_type") + final_type = expect.get("final_type") + try: - while True: - msg = await asyncio.wait_for(ws.recv(), timeout=0.01) - data = json.loads(msg) - if interim_type and data.get("type") == interim_type: - self._streams[client_id]["interim"].append(data) + # Try to read a message (non-blocking) + msg = await asyncio.wait_for(ws.recv(), timeout=0.05) + data = json.loads(msg) + + # Determine if this 
is interim or final result + is_final = False + if final_type and data.get("type") == final_type: + # Check if Deepgram marks it as final + is_final = data.get("is_final", False) + elif interim_type and data.get("type") == interim_type: + is_final = data.get("is_final", False) + + # Extract result data + text = _dotted_get(data, extract.get("text")) if extract.get("text") else data.get("text", "") + words = _dotted_get(data, extract.get("words")) if extract.get("words") else data.get("words", []) + segments = _dotted_get(data, extract.get("segments")) if extract.get("segments") else data.get("segments", []) + + # Calculate confidence if available + confidence = data.get("confidence", 0.0) + if not confidence and words and isinstance(words, list): + # Calculate average word confidence + confidences = [w.get("confidence", 0.0) for w in words if isinstance(w, dict) and "confidence" in w] + if confidences: + confidence = sum(confidences) / len(confidences) + + # Return result with is_final flag + # Consumer decides what to do with interim vs final + return { + "text": text, + "words": words, + "segments": segments, + "is_final": is_final, + "confidence": confidence + } + except asyncio.TimeoutError: - pass - return None + # No message available yet + return None + except Exception as e: + logger.error(f"Error processing audio chunk result for {client_id}: {e}") + return None async def end_stream(self, client_id: str) -> dict: if client_id not in self._streams: diff --git a/backends/advanced/src/advanced_omi_backend/services/transcription/deepgram_stream_consumer.py b/backends/advanced/src/advanced_omi_backend/services/transcription/deepgram_stream_consumer.py new file mode 100644 index 00000000..ff312360 --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/services/transcription/deepgram_stream_consumer.py @@ -0,0 +1,474 @@ +""" +Deepgram WebSocket streaming consumer for real-time transcription. + +Reads from: audio:stream:* streams +Publishes interim to: Redis Pub/Sub channel transcription:interim:{session_id} +Writes final to: transcription:results:{session_id} Redis Stream +Triggers plugins: streaming_transcript level (final results only) +""" + +import asyncio +import json +import logging +import os +import time +from typing import Dict, Optional + +import redis.asyncio as redis +from redis import exceptions as redis_exceptions + +from advanced_omi_backend.plugins.router import PluginRouter +from advanced_omi_backend.services.transcription import get_transcription_provider +from advanced_omi_backend.client_manager import get_client_owner_async + +logger = logging.getLogger(__name__) + + +class DeepgramStreamingConsumer: + """ + Deepgram streaming consumer for real-time WebSocket transcription. + + - Discovers audio:stream:* streams dynamically + - Uses Redis consumer groups for fan-out (allows batch workers to process same stream) + - Starts WebSocket connections to Deepgram per stream + - Sends audio immediately (no buffering) + - Publishes interim results to Redis Pub/Sub for client display + - Publishes final results to Redis Streams for storage + - Triggers plugins only on final results + """ + + def __init__(self, redis_client: redis.Redis, plugin_router: Optional[PluginRouter] = None): + """ + Initialize Deepgram streaming consumer. 
+ + Args: + redis_client: Connected Redis client + plugin_router: Plugin router for triggering plugins on final results + """ + self.redis_client = redis_client + self.plugin_router = plugin_router + + # Get streaming transcription provider from registry + self.provider = get_transcription_provider(mode="streaming") + if not self.provider: + raise RuntimeError( + "Failed to load streaming transcription provider. " + "Ensure config.yml has a default 'stt_stream' model configured." + ) + + # Stream configuration + self.stream_pattern = "audio:stream:*" + self.group_name = "streaming-transcription" + self.consumer_name = f"streaming-worker-{os.getpid()}" + + self.running = False + + # Active stream tracking - consumer groups handle fan-out + self.active_streams: Dict[str, Dict] = {} # {stream_name: {"session_id": ...}} + + # Session tracking for WebSocket connections + self.active_sessions: Dict[str, Dict] = {} # {session_id: {"last_activity": timestamp}} + + async def discover_streams(self) -> list[str]: + """ + Discover all audio streams matching the pattern. + + Returns: + List of stream names + """ + streams = [] + cursor = b"0" + + while cursor: + cursor, keys = await self.redis_client.scan( + cursor, match=self.stream_pattern, count=100 + ) + if keys: + streams.extend([k.decode() if isinstance(k, bytes) else k for k in keys]) + + return streams + + async def setup_consumer_group(self, stream_name: str): + """Create consumer group if it doesn't exist.""" + try: + await self.redis_client.xgroup_create( + stream_name, + self.group_name, + "0", + mkstream=True + ) + logger.debug(f"➡️ Created consumer group {self.group_name} for {stream_name}") + except redis_exceptions.ResponseError as e: + if "BUSYGROUP" not in str(e): + raise + logger.debug(f"➡️ Consumer group {self.group_name} already exists for {stream_name}") + + async def start_session_stream(self, session_id: str, sample_rate: int = 16000): + """ + Start WebSocket connection to Deepgram for a session. + + Args: + session_id: Session ID (client_id from audio stream) + sample_rate: Audio sample rate in Hz + """ + try: + await self.provider.start_stream( + client_id=session_id, + sample_rate=sample_rate, + diarize=False # Deepgram streaming doesn't support diarization + ) + + self.active_sessions[session_id] = { + "last_activity": time.time(), + "sample_rate": sample_rate + } + + logger.info(f"🎙️ Started Deepgram WebSocket stream for session: {session_id}") + + except Exception as e: + logger.error(f"Failed to start Deepgram stream for {session_id}: {e}", exc_info=True) + raise + + async def end_session_stream(self, session_id: str): + """ + End WebSocket connection to Deepgram for a session. 
+ + Args: + session_id: Session ID + """ + try: + # Get final result from Deepgram + final_result = await self.provider.end_stream(client_id=session_id) + + # If there's a final result, publish it + if final_result and final_result.get("text"): + await self.publish_to_client(session_id, final_result, is_final=True) + await self.store_final_result(session_id, final_result) + + # Trigger plugins on final result + if self.plugin_router: + await self.trigger_plugins(session_id, final_result) + + self.active_sessions.pop(session_id, None) + logger.info(f"🛑 Ended Deepgram WebSocket stream for session: {session_id}") + + except Exception as e: + logger.error(f"Error ending stream for {session_id}: {e}", exc_info=True) + + async def process_audio_chunk(self, session_id: str, audio_chunk: bytes, chunk_id: str): + """ + Process a single audio chunk through Deepgram WebSocket. + + Args: + session_id: Session ID + audio_chunk: Raw audio bytes + chunk_id: Chunk identifier from Redis stream + """ + try: + # Send audio chunk to Deepgram WebSocket and get result + result = await self.provider.process_audio_chunk( + client_id=session_id, + audio_chunk=audio_chunk + ) + + # Update last activity + if session_id in self.active_sessions: + self.active_sessions[session_id]["last_activity"] = time.time() + + # Deepgram returns None if no response yet, or a dict with results + if result: + is_final = result.get("is_final", False) + + # Always publish to clients (interim + final) for real-time display + await self.publish_to_client(session_id, result, is_final=is_final) + + # If final result, also store and trigger plugins + if is_final: + await self.store_final_result(session_id, result, chunk_id=chunk_id) + + # Trigger plugins on final results only + if self.plugin_router: + await self.trigger_plugins(session_id, result) + + except Exception as e: + logger.error(f"Error processing audio chunk for {session_id}: {e}", exc_info=True) + + async def publish_to_client(self, session_id: str, result: Dict, is_final: bool): + """ + Publish interim or final results to Redis Pub/Sub for client consumption. + + Args: + session_id: Session ID + result: Transcription result from Deepgram + is_final: Whether this is a final result + """ + try: + channel = f"transcription:interim:{session_id}" + + # Prepare message for clients + message = { + "text": result.get("text", ""), + "is_final": is_final, + "words": result.get("words", []), + "confidence": result.get("confidence", 0.0), + "timestamp": time.time() + } + + # Publish to Redis Pub/Sub + await self.redis_client.publish(channel, json.dumps(message)) + + result_type = "FINAL" if is_final else "interim" + logger.debug(f"📢 Published {result_type} result to {channel}: {message['text'][:50]}...") + + except Exception as e: + logger.error(f"Error publishing to client for {session_id}: {e}", exc_info=True) + + async def store_final_result(self, session_id: str, result: Dict, chunk_id: str = None): + """ + Store final transcription result to Redis Stream. 
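+
+        Results land in the transcription:results:{session_id} stream; a reader
+        can fetch them with a plain range read, e.g.
+        await redis_client.xrange(f"transcription:results:{session_id}").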
+ + Args: + session_id: Session ID + result: Final transcription result + chunk_id: Optional chunk identifier + """ + try: + stream_name = f"transcription:results:{session_id}" + + # Prepare result entry + entry = { + "message_id": chunk_id or f"final_{int(time.time() * 1000)}", + "text": result.get("text", ""), + "confidence": result.get("confidence", 0.0), + "provider": "deepgram-stream", + "timestamp": time.time(), + "words": json.dumps(result.get("words", [])), + "segments": json.dumps(result.get("segments", [])), + "is_final": "true" + } + + # Write to Redis Stream + await self.redis_client.xadd(stream_name, entry) + + logger.info(f"💾 Stored final result to {stream_name}: {entry['text'][:50]}...") + + except Exception as e: + logger.error(f"Error storing final result for {session_id}: {e}", exc_info=True) + + async def _get_user_id_from_client_id(self, client_id: str) -> Optional[str]: + """ + Look up user_id from client_id using ClientManager (async Redis lookup). + + Args: + client_id: Client ID to search for + + Returns: + user_id if found, None otherwise + """ + user_id = await get_client_owner_async(client_id) + + if user_id: + logger.debug(f"Found user_id {user_id} for client_id {client_id} via Redis") + else: + logger.warning(f"No user_id found for client_id {client_id} in Redis") + + return user_id + + async def trigger_plugins(self, session_id: str, result: Dict): + """ + Trigger plugins at streaming_transcript access level (final results only). + + Args: + session_id: Session ID (client_id from stream name) + result: Final transcription result + """ + try: + # Find user_id by looking up session with matching client_id + # session_id here is actually the client_id extracted from stream name + user_id = await self._get_user_id_from_client_id(session_id) + + if not user_id: + logger.warning( + f"Could not find user_id for client_id {session_id}. " + "Plugins will not be triggered." + ) + return + + plugin_data = { + 'transcript': result.get("text", ""), + 'session_id': session_id, + 'words': result.get("words", []), + 'segments': result.get("segments", []), + 'confidence': result.get("confidence", 0.0), + 'is_final': True + } + + # Trigger plugins with streaming_transcript access level + logger.info(f"🎯 Triggering plugins for user {user_id}, transcript: {plugin_data['transcript'][:50]}...") + + plugin_results = await self.plugin_router.trigger_plugins( + access_level='streaming_transcript', + user_id=user_id, + data=plugin_data, + metadata={'client_id': session_id} + ) + + if plugin_results: + logger.info(f"✅ Plugins triggered successfully: {len(plugin_results)} results") + else: + logger.info(f"ℹ️ No plugins triggered (no matching conditions)") + + except Exception as e: + logger.error(f"Error triggering plugins for {session_id}: {e}", exc_info=True) + + async def process_stream(self, stream_name: str): + """ + Process a single audio stream. 
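+
+        One such call runs as its own asyncio task per discovered stream (spawned
+        by start_consuming) and loops until an end marker arrives, the session is
+        no longer active, or the stream is deleted, then closes the Deepgram
+        WebSocket via end_session_stream.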
+ + Args: + stream_name: Redis stream name (e.g., "audio:stream:user01-phone") + """ + # Extract session_id from stream name (format: audio:stream:{session_id}) + session_id = stream_name.replace("audio:stream:", "") + + # Track this stream + self.active_streams[stream_name] = { + "session_id": session_id, + "started_at": time.time() + } + + # Start WebSocket connection to Deepgram + await self.start_session_stream(session_id) + + last_id = "0" # Start from beginning + stream_ended = False + + try: + while self.running and not stream_ended: + # Read messages from Redis stream using consumer group + try: + messages = await self.redis_client.xreadgroup( + self.group_name, # "streaming-transcription" + self.consumer_name, # "streaming-worker-{pid}" + {stream_name: ">"}, # Read only new messages + count=10, + block=1000 # Block for 1 second + ) + + if not messages: + # No new messages - check if stream is still alive + # Check for stream end marker or timeout + if session_id not in self.active_sessions: + logger.info(f"Session {session_id} no longer active, ending stream processing") + stream_ended = True + continue + + for stream, stream_messages in messages: + logger.debug(f"📥 Read {len(stream_messages)} messages from {stream_name}") + for message_id, fields in stream_messages: + msg_id = message_id.decode() if isinstance(message_id, bytes) else message_id + + # Check for end marker + if fields.get(b'end_marker') or fields.get('end_marker'): + logger.info(f"End marker received for {session_id}") + stream_ended = True + # ACK the end marker + await self.redis_client.xack(stream_name, self.group_name, msg_id) + break + + # Extract audio data (producer sends as 'audio_data', not 'audio_chunk') + audio_chunk = fields.get(b'audio_data') or fields.get('audio_data') + if audio_chunk: + logger.debug(f"🎵 Processing audio chunk {msg_id} ({len(audio_chunk)} bytes)") + # Process audio chunk through Deepgram WebSocket + await self.process_audio_chunk( + session_id=session_id, + audio_chunk=audio_chunk, + chunk_id=msg_id + ) + else: + logger.warning(f"⚠️ Message {msg_id} has no audio_data field") + + # ACK the message after processing + await self.redis_client.xack(stream_name, self.group_name, msg_id) + + if stream_ended: + break + + except redis_exceptions.ResponseError as e: + if "NOGROUP" in str(e): + # Stream has expired or been deleted - exit gracefully + logger.info(f"Stream {stream_name} expired or deleted, ending processing") + stream_ended = True + break + else: + logger.error(f"Redis error reading from stream {stream_name}: {e}", exc_info=True) + await asyncio.sleep(1) + except Exception as e: + logger.error(f"Error reading from stream {stream_name}: {e}", exc_info=True) + await asyncio.sleep(1) + + finally: + # End WebSocket connection + await self.end_session_stream(session_id) + + # Remove from active streams tracking + self.active_streams.pop(stream_name, None) + logger.debug(f"Removed {stream_name} from active streams tracking") + + async def start_consuming(self): + """ + Start consuming audio streams and processing through Deepgram WebSocket. + Uses Redis consumer groups for fan-out (allows batch workers to process same stream). 
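+
+        Each worker type reads with its own group name (this consumer uses
+        "streaming-transcription"; the batch Deepgram worker uses
+        "deepgram_workers"), so every group sees the full stream while instances
+        sharing a group split the messages between them.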
+ """ + self.running = True + logger.info(f"🚀 Deepgram streaming consumer started (group: {self.group_name})") + + try: + while self.running: + # Discover available streams + streams = await self.discover_streams() + + if streams: + logger.debug(f"🔍 Discovered {len(streams)} audio streams") + else: + logger.debug("🔍 No audio streams found") + + # Setup consumer groups and spawn processing tasks + for stream_name in streams: + if stream_name in self.active_streams: + continue # Already processing + + # Setup consumer group (no manual lock needed) + await self.setup_consumer_group(stream_name) + + # Track stream and spawn task to process it + session_id = stream_name.replace("audio:stream:", "") + self.active_streams[stream_name] = {"session_id": session_id} + + # Spawn task to process this stream + asyncio.create_task(self.process_stream(stream_name)) + logger.info(f"✅ Now consuming from {stream_name} (group: {self.group_name})") + + # Sleep before next discovery cycle + await asyncio.sleep(5) + + except Exception as e: + logger.error(f"Fatal error in consumer main loop: {e}", exc_info=True) + finally: + await self.stop() + + async def stop(self): + """Stop consuming and clean up resources.""" + logger.info("🛑 Stopping Deepgram streaming consumer...") + self.running = False + + # End all active sessions + session_ids = list(self.active_sessions.keys()) + for session_id in session_ids: + try: + await self.end_session_stream(session_id) + except Exception as e: + logger.error(f"Error ending session {session_id}: {e}") + + logger.info("✅ Deepgram streaming consumer stopped") diff --git a/backends/advanced/src/advanced_omi_backend/workers/audio_stream_deepgram_streaming_worker.py b/backends/advanced/src/advanced_omi_backend/workers/audio_stream_deepgram_streaming_worker.py new file mode 100644 index 00000000..0a893e6a --- /dev/null +++ b/backends/advanced/src/advanced_omi_backend/workers/audio_stream_deepgram_streaming_worker.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python3 +""" +Deepgram WebSocket streaming audio worker. + +Starts a consumer that reads from audio:stream:* streams and transcribes via Deepgram WebSocket API. +Publishes interim results to Redis Pub/Sub for real-time client display. +Publishes final results to Redis Streams for storage. +Triggers plugins on final results only. 
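+
+A minimal client-side sketch for consuming the interim results (assumes a
+redis.asyncio client and a known session_id, run inside an async function;
+not part of this worker):
+
+    pubsub = redis_client.pubsub()
+    await pubsub.subscribe(f"transcription:interim:{session_id}")
+    async for msg in pubsub.listen():
+        if msg["type"] == "message":
+            result = json.loads(msg["data"])  # {"text", "is_final", "words", ...}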
+""" + +import asyncio +import logging +import os +import signal +import sys + +import redis.asyncio as redis + +from advanced_omi_backend.services.plugin_service import init_plugin_router +from advanced_omi_backend.services.transcription.deepgram_stream_consumer import DeepgramStreamingConsumer +from advanced_omi_backend.client_manager import initialize_redis_for_client_manager + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s" +) + +logger = logging.getLogger(__name__) + + +async def main(): + """Main worker entry point.""" + logger.info("🚀 Starting Deepgram WebSocket streaming worker") + + # Validate DEEPGRAM_API_KEY + api_key = os.getenv("DEEPGRAM_API_KEY") + if not api_key: + logger.error("DEEPGRAM_API_KEY environment variable not set") + logger.error("Cannot start Deepgram streaming worker without API key") + sys.exit(1) + + redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0") + + # Create Redis client + try: + redis_client = await redis.from_url( + redis_url, + encoding="utf-8", + decode_responses=False + ) + logger.info(f"✅ Connected to Redis: {redis_url}") + + # Initialize ClientManager Redis for cross-container client→user mapping + initialize_redis_for_client_manager(redis_url) + + except Exception as e: + logger.error(f"Failed to connect to Redis: {e}", exc_info=True) + sys.exit(1) + + # Initialize plugin router + try: + plugin_router = init_plugin_router() + if plugin_router: + logger.info(f"✅ Plugin router initialized with {len(plugin_router.plugins)} plugins") + + # Initialize async plugins + for plugin_id, plugin in plugin_router.plugins.items(): + try: + await plugin.initialize() + logger.info(f"✅ Plugin '{plugin_id}' initialized in streaming worker") + except Exception as e: + logger.exception(f"Failed to initialize plugin '{plugin_id}' in streaming worker: {e}") + else: + logger.warning("No plugin router available - plugins will not be triggered") + except Exception as e: + logger.error(f"Failed to initialize plugin router: {e}", exc_info=True) + plugin_router = None + + # Create Deepgram streaming consumer + try: + consumer = DeepgramStreamingConsumer( + redis_client=redis_client, + plugin_router=plugin_router + ) + logger.info("✅ Deepgram streaming consumer created") + except Exception as e: + logger.error(f"Failed to create Deepgram streaming consumer: {e}", exc_info=True) + await redis_client.aclose() + sys.exit(1) + + # Setup signal handlers for graceful shutdown + def signal_handler(signum, frame): + logger.info(f"Received signal {signum}, shutting down...") + asyncio.create_task(consumer.stop()) + + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + try: + logger.info("✅ Deepgram streaming worker ready") + logger.info("📡 Listening for audio streams on audio:stream:* pattern") + logger.info("📢 Publishing interim results to transcription:interim:{session_id}") + logger.info("💾 Publishing final results to transcription:results:{session_id}") + + # This blocks until consumer is stopped + await consumer.start_consuming() + + except KeyboardInterrupt: + logger.info("Keyboard interrupt received, shutting down...") + except Exception as e: + logger.error(f"Worker error: {e}", exc_info=True) + sys.exit(1) + finally: + await redis_client.aclose() + logger.info("👋 Deepgram streaming worker stopped") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/backends/advanced/src/advanced_omi_backend/workers/conversation_jobs.py 
b/backends/advanced/src/advanced_omi_backend/workers/conversation_jobs.py index d2b8c4fd..7c754d19 100644 --- a/backends/advanced/src/advanced_omi_backend/workers/conversation_jobs.py +++ b/backends/advanced/src/advanced_omi_backend/workers/conversation_jobs.py @@ -10,8 +10,11 @@ from datetime import datetime from typing import Dict, Any from rq.job import Job + from advanced_omi_backend.models.job import async_job from advanced_omi_backend.controllers.queue_controller import redis_conn +from advanced_omi_backend.controllers.session_controller import mark_session_complete +from advanced_omi_backend.services.plugin_service import get_plugin_router from advanced_omi_backend.utils.conversation_utils import ( analyze_speech, @@ -294,9 +297,9 @@ async def open_conversation_job( if status_str in ["finalizing", "complete"]: finalize_received = True - # Check if this was a WebSocket disconnect + # Get completion reason (guaranteed to exist with unified API) completion_reason = await redis_client.hget(session_key, "completion_reason") - completion_reason_str = completion_reason.decode() if completion_reason else None + completion_reason_str = completion_reason.decode() if completion_reason else "unknown" if completion_reason_str == "websocket_disconnect": logger.warning( @@ -306,7 +309,7 @@ async def open_conversation_job( timeout_triggered = False # This is a disconnect, not a timeout else: logger.info( - f"🛑 Session finalizing (reason: {completion_reason_str or 'user_stopped'}), " + f"🛑 Session finalizing (reason: {completion_reason_str}), " f"waiting for audio persistence job to complete..." ) break # Exit immediately when finalize signal received @@ -398,6 +401,42 @@ async def open_conversation_job( ) last_result_count = current_count + # Trigger transcript-level plugins on new transcript segments + try: + plugin_router = get_plugin_router() + if plugin_router: + # Get the latest transcript text for plugin processing + transcript_text = combined.get('text', '') + + if transcript_text: + plugin_data = { + 'transcript': transcript_text, + 'segment_id': f"{session_id}_{current_count}", + 'conversation_id': conversation_id, + 'segments': combined.get('segments', []), + 'word_count': speech_analysis.get('word_count', 0), + } + + plugin_results = await plugin_router.trigger_plugins( + access_level='streaming_transcript', + user_id=user_id, + data=plugin_data, + metadata={'client_id': client_id} + ) + + if plugin_results: + logger.info(f"📌 Triggered {len(plugin_results)} streaming transcript plugins") + for result in plugin_results: + if result.message: + logger.info(f" Plugin: {result.message}") + + # If plugin stopped processing, log it + if not result.should_continue: + logger.info(f" Plugin stopped normal processing") + + except Exception as e: + logger.warning(f"⚠️ Error triggering transcript-level plugins: {e}") + await asyncio.sleep(1) # Check every second for responsiveness logger.info( @@ -496,6 +535,43 @@ async def open_conversation_job( # Wait a moment to ensure jobs are registered in RQ await asyncio.sleep(0.5) + # Trigger conversation-level plugins + try: + plugin_router = get_plugin_router() + if plugin_router: + # Get conversation data for plugin context + conversation_model = await Conversation.find_one( + Conversation.conversation_id == conversation_id + ) + + plugin_data = { + 'conversation': { + 'conversation_id': conversation_id, + 'audio_uuid': session_id, + 'client_id': client_id, + 'user_id': user_id, + }, + 'transcript': conversation_model.transcript if conversation_model else 
"", + 'duration': time.time() - start_time, + 'conversation_id': conversation_id, + } + + plugin_results = await plugin_router.trigger_plugins( + access_level='conversation', + user_id=user_id, + data=plugin_data, + metadata={'end_reason': end_reason} + ) + + if plugin_results: + logger.info(f"📌 Triggered {len(plugin_results)} conversation-level plugins") + for result in plugin_results: + if result.message: + logger.info(f" Plugin result: {result.message}") + + except Exception as e: + logger.warning(f"⚠️ Error triggering conversation-level plugins: {e}") + # Call shared cleanup/restart logic return await handle_end_of_conversation( session_id=session_id, diff --git a/backends/advanced/src/advanced_omi_backend/workers/memory_jobs.py b/backends/advanced/src/advanced_omi_backend/workers/memory_jobs.py index 8b64d690..a6939bed 100644 --- a/backends/advanced/src/advanced_omi_backend/workers/memory_jobs.py +++ b/backends/advanced/src/advanced_omi_backend/workers/memory_jobs.py @@ -16,6 +16,7 @@ ) from advanced_omi_backend.models.job import BaseRQJob, JobPriority, async_job from advanced_omi_backend.services.memory.base import MemoryEntry +from advanced_omi_backend.services.plugin_service import get_plugin_router logger = logging.getLogger(__name__) @@ -240,6 +241,41 @@ async def process_memory_job(conversation_id: str, *, redis_client=None) -> Dict # This allows users to resume talking immediately after conversation closes, # without waiting for memory processing to complete. + # Trigger memory-level plugins + try: + plugin_router = get_plugin_router() + if plugin_router: + plugin_data = { + 'memories': created_memory_ids, + 'conversation': { + 'conversation_id': conversation_id, + 'client_id': client_id, + 'user_id': user_id, + 'user_email': user_email, + }, + 'memory_count': len(created_memory_ids), + 'conversation_id': conversation_id, + } + + plugin_results = await plugin_router.trigger_plugins( + access_level='memory', + user_id=user_id, + data=plugin_data, + metadata={ + 'processing_time': processing_time, + 'memory_provider': str(memory_provider), + } + ) + + if plugin_results: + logger.info(f"📌 Triggered {len(plugin_results)} memory-level plugins") + for result in plugin_results: + if result.message: + logger.info(f" Plugin result: {result.message}") + + except Exception as e: + logger.warning(f"⚠️ Error triggering memory-level plugins: {e}") + return { "success": True, "memories_created": len(created_memory_ids), diff --git a/backends/advanced/src/advanced_omi_backend/workers/transcription_jobs.py b/backends/advanced/src/advanced_omi_backend/workers/transcription_jobs.py index c9216d4f..71e64dbd 100644 --- a/backends/advanced/src/advanced_omi_backend/workers/transcription_jobs.py +++ b/backends/advanced/src/advanced_omi_backend/workers/transcription_jobs.py @@ -19,6 +19,7 @@ REDIS_URL, ) from advanced_omi_backend.utils.conversation_utils import analyze_speech, mark_conversation_deleted +from advanced_omi_backend.services.plugin_service import get_plugin_router logger = logging.getLogger(__name__) @@ -167,6 +168,10 @@ async def transcribe_full_audio_job( if not conversation: raise ValueError(f"Conversation {conversation_id} not found") + # Extract user_id and client_id for plugin context + user_id = str(conversation.user_id) if conversation.user_id else None + client_id = conversation.client_id if hasattr(conversation, 'client_id') else None + # Use the provided audio path actual_audio_path = audio_path logger.info(f"📁 Using audio for transcription: {audio_path}") @@ -202,6 +207,59 @@ 
async def transcribe_full_audio_job( f"📊 Transcription complete: {len(transcript_text)} chars, {len(segments)} segments, {len(words)} words" ) + # Trigger transcript-level plugins BEFORE speech validation + # This ensures wake-word commands execute even if conversation gets deleted + logger.info(f"🔍 DEBUG: About to trigger plugins - transcript_text exists: {bool(transcript_text)}") + if transcript_text: + try: + from advanced_omi_backend.services.plugin_service import init_plugin_router + + # Initialize plugin router if not already initialized (worker context) + plugin_router = get_plugin_router() + if not plugin_router: + logger.info("🔧 Initializing plugin router in worker process...") + plugin_router = init_plugin_router() + + # Initialize async plugins + if plugin_router: + for plugin_id, plugin in plugin_router.plugins.items(): + try: + await plugin.initialize() + logger.info(f"✅ Plugin '{plugin_id}' initialized in worker") + except Exception as e: + logger.exception(f"Failed to initialize plugin '{plugin_id}' in worker: {e}") + + logger.info(f"🔍 DEBUG: Plugin router retrieved: {plugin_router is not None}") + + if plugin_router: + logger.info(f"🔍 DEBUG: Preparing to trigger transcript plugins for conversation {conversation_id}") + plugin_data = { + 'transcript': transcript_text, + 'segment_id': f"{conversation_id}_batch", + 'conversation_id': conversation_id, + 'segments': segments, + 'word_count': len(words), + } + + logger.info(f"🔍 DEBUG: Calling trigger_plugins with user_id={user_id}, client_id={client_id}") + plugin_results = await plugin_router.trigger_plugins( + access_level='transcript', # Batch mode - only 'transcript' plugins, NOT 'streaming_transcript' + user_id=user_id, + data=plugin_data, + metadata={'client_id': client_id} + ) + logger.info(f"🔍 DEBUG: Plugin trigger returned {len(plugin_results) if plugin_results else 0} results") + + if plugin_results: + logger.info(f"✅ Triggered {len(plugin_results)} transcript plugins in batch mode") + for result in plugin_results: + if result.message: + logger.info(f" Plugin: {result.message}") + except Exception as e: + logger.exception(f"⚠️ Error triggering transcript plugins in batch mode: {e}") + + logger.info(f"🔍 DEBUG: Plugin processing complete, moving to speech validation") + # Validate meaningful speech BEFORE any further processing transcript_data = {"text": transcript_text, "words": words} speech_analysis = analyze_speech(transcript_data) diff --git a/backends/advanced/start-workers.sh b/backends/advanced/start-workers.sh index 3fea5a39..8715da4b 100755 --- a/backends/advanced/start-workers.sh +++ b/backends/advanced/start-workers.sh @@ -64,13 +64,17 @@ if registry and registry.defaults: echo "📋 Configured STT provider: ${DEFAULT_STT:-none}" - # Only start Deepgram worker if configured as default STT + # Batch Deepgram worker - uses consumer group "deepgram_workers" + # Runs alongside deepgram-streaming-worker container (consumer group "streaming-transcription") + # Both workers process same streams via Redis consumer groups (fan-out architecture) + # - Batch worker: High-quality transcription with diarization (~6s latency) + # - Streaming worker: Fast wake-word detection with plugins (~1-2s latency) if [[ "$DEFAULT_STT" == "deepgram" ]] && [ -n "$DEEPGRAM_API_KEY" ]; then - echo "🎵 Starting audio stream Deepgram worker (1 worker for sequential processing)..." + echo "🎵 Starting audio stream Deepgram batch worker (consumer group: deepgram_workers)..." 
uv run python -m advanced_omi_backend.workers.audio_stream_deepgram_worker & AUDIO_STREAM_DEEPGRAM_WORKER_PID=$! else - echo "⏭️ Skipping Deepgram stream worker (not configured as default STT or API key missing)" + echo "⏭️ Skipping Deepgram batch worker (not configured as default STT or API key missing)" AUDIO_STREAM_DEEPGRAM_WORKER_PID="" fi diff --git a/backends/advanced/start.sh b/backends/advanced/start.sh index 40fa4abf..5cc79635 100755 --- a/backends/advanced/start.sh +++ b/backends/advanced/start.sh @@ -10,7 +10,8 @@ echo "🚀 Starting Chronicle Backend..." # Function to handle shutdown shutdown() { echo "🛑 Shutting down services..." - pkill -TERM -P $$ + # Kill the backend process if running + [ -n "$BACKEND_PID" ] && kill -TERM $BACKEND_PID 2>/dev/null || true wait echo "✅ All services stopped" exit 0 diff --git a/backends/advanced/webui/src/App.tsx b/backends/advanced/webui/src/App.tsx index fca59623..42370975 100644 --- a/backends/advanced/webui/src/App.tsx +++ b/backends/advanced/webui/src/App.tsx @@ -13,6 +13,7 @@ import System from './pages/System' import Upload from './pages/Upload' import Queue from './pages/Queue' import LiveRecord from './pages/LiveRecord' +import Plugins from './pages/Plugins' import ProtectedRoute from './components/auth/ProtectedRoute' import { ErrorBoundary, PageErrorBoundary } from './components/ErrorBoundary' @@ -89,6 +90,11 @@ function App() { } /> + + + + } /> diff --git a/backends/advanced/webui/src/components/ChatSettings.tsx b/backends/advanced/webui/src/components/ChatSettings.tsx new file mode 100644 index 00000000..1acad362 --- /dev/null +++ b/backends/advanced/webui/src/components/ChatSettings.tsx @@ -0,0 +1,195 @@ +import { useState, useEffect } from 'react' +import { MessageSquare, RefreshCw, CheckCircle, Save, RotateCcw, AlertCircle } from 'lucide-react' +import { systemApi } from '../services/api' +import { useAuth } from '../contexts/AuthContext' + +interface ChatSettingsProps { + className?: string +} + +export default function ChatSettings({ className }: ChatSettingsProps) { + const [configYaml, setConfigYaml] = useState('') + const [loading, setLoading] = useState(false) + const [validating, setValidating] = useState(false) + const [saving, setSaving] = useState(false) + const [message, setMessage] = useState('') + const [error, setError] = useState('') + const { isAdmin } = useAuth() + + useEffect(() => { + loadChatConfig() + }, []) + + const loadChatConfig = async () => { + setLoading(true) + setError('') + setMessage('') + + try { + const response = await systemApi.getChatConfigRaw() + setConfigYaml(response.data.config_yaml || response.data) + setMessage('Configuration loaded successfully') + setTimeout(() => setMessage(''), 3000) + } catch (err: any) { + const status = err.response?.status + if (status === 401) { + setError('Unauthorized: admin privileges required') + } else { + setError(err.response?.data?.error || 'Failed to load configuration') + } + } finally { + setLoading(false) + } + } + + const validateConfig = async () => { + if (!configYaml.trim()) { + setError('Configuration cannot be empty') + return + } + + setValidating(true) + setError('') + setMessage('') + + try { + const response = await systemApi.validateChatConfig(configYaml) + if (response.data.valid) { + setMessage('✅ Configuration is valid') + } else { + setError(response.data.error || 'Validation failed') + } + setTimeout(() => setMessage(''), 3000) + } catch (err: any) { + setError(err.response?.data?.error || 'Validation failed') + } finally { + 
setValidating(false) + } + } + + const saveConfig = async () => { + if (!configYaml.trim()) { + setError('Configuration cannot be empty') + return + } + + setSaving(true) + setError('') + setMessage('') + + try { + await systemApi.updateChatConfigRaw(configYaml) + setMessage('✅ Configuration saved successfully') + setTimeout(() => setMessage(''), 5000) + } catch (err: any) { + setError(err.response?.data?.error || 'Failed to save configuration') + } finally { + setSaving(false) + } + } + + const resetConfig = () => { + loadChatConfig() + setMessage('Configuration reset to file version') + setTimeout(() => setMessage(''), 3000) + } + + if (!isAdmin) { + return null + } + + return ( +
+
+ {/* Header */} +
+
+ +

+ Chat System Prompt +

+
+
+ + +
+
+ + {/* Messages */} + {message && ( +
+ +

{message}

+
+ )} + + {error && ( +
+ +

{error}

+
+ )} + + {/* Editor */} +
+