diff --git a/.gitignore b/.gitignore index 7961a03..c5c75a2 100644 --- a/.gitignore +++ b/.gitignore @@ -14,4 +14,6 @@ docs/ agent.py call.py .scannerwork -.pytest_cache \ No newline at end of file +.pytest_cache +.claude +testing/ \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 05b22fb..5795f9f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,7 +85,7 @@ keywords = [ dependencies = [ "livekit", "livekit-api", - "livekit-agents", + "livekit-agents>=1.5.1", "boto3", "python-dotenv", "livekit-plugins-noise-cancellation", diff --git a/siphon/agent/core/entrypoint.py b/siphon/agent/core/entrypoint.py index 6befa3a..ddff860 100644 --- a/siphon/agent/core/entrypoint.py +++ b/siphon/agent/core/entrypoint.py @@ -1,7 +1,9 @@ from livekit.agents import JobContext, room_io from livekit import rtc from livekit.agents.voice import AgentSession +from livekit.agents.voice.agent_session import TurnHandlingOptions from livekit.plugins import silero, noise_cancellation +from livekit.plugins.turn_detector.multilingual import MultilingualModel from .voice_agent import AgentSetup import json import asyncio @@ -210,6 +212,7 @@ def _build_agent_session( min_endpointing_delay: float, max_endpointing_delay: float, min_interruption_duration: float, + preemptive_generation: bool = True, ) -> AgentSession: """Construct an AgentSession with the configured models and settings.""" vad_instance = silero.VAD.load( @@ -218,16 +221,27 @@ def _build_agent_session( prefix_padding_duration=prefix_padding_duration, ) + turn_options: TurnHandlingOptions = { + "turn_detection": MultilingualModel(), + "endpointing": { + "min_delay": min_endpointing_delay, + "max_delay": max_endpointing_delay, + }, + "interruption": { + "enabled": allow_interruptions, + "mode": "adaptive", + "discard_audio_if_uninterruptible": False, + "min_duration": min_interruption_duration, + }, + } + return AgentSession( llm=session_llm, tts=session_tts, stt=session_stt, vad=vad_instance, - turn_detection="stt", - allow_interruptions=allow_interruptions, - min_endpointing_delay=min_endpointing_delay, - max_endpointing_delay=max_endpointing_delay, - min_interruption_duration=min_interruption_duration, + turn_handling=turn_options, + preemptive_generation=preemptive_generation, max_tool_steps=1000, ) @@ -329,16 +343,19 @@ async def entrypoint( greeting_instructions = kwargs.get("greeting_instructions", "Greet and introduce yourself briefly") system_instructions = kwargs.get("system_instructions", "You are a helpful voice assistant") allow_interruptions = kwargs.get("allow_interruptions", True) - min_silence_duration = kwargs.get("min_silence_duration", 0.25) - activation_threshold = kwargs.get("activation_threshold", 0.25) - prefix_padding_duration = kwargs.get("prefix_padding_duration", 1.0) - min_endpointing_delay = kwargs.get("min_endpointing_delay", 0.25) - max_endpointing_delay = kwargs.get("max_endpointing_delay", 1.5) - min_interruption_duration = kwargs.get("min_interruption_duration", 0.05) + min_silence_duration = kwargs.get("min_silence_duration", 0.5) + activation_threshold = kwargs.get("activation_threshold", 0.5) + prefix_padding_duration = kwargs.get("prefix_padding_duration", 0.3) + min_endpointing_delay = kwargs.get("min_endpointing_delay", 0.2) + max_endpointing_delay = kwargs.get("max_endpointing_delay", 0.6) + min_interruption_duration = kwargs.get("min_interruption_duration", 0.3) + preemptive_generation = kwargs.get("preemptive_generation", True) tools = kwargs.get("tools", None) google_calendar = kwargs.get("google_calendar", False) date_time = kwargs.get("date_time", True) remember_call = kwargs.get("remember_call", False) + noise_cancellation_sip = kwargs.get("noise_cancellation_sip", False) + debug = kwargs.get("debug", False) agent_setup: Optional[AgentSetup] = None try: @@ -408,6 +425,7 @@ async def entrypoint( min_endpointing_delay=min_endpointing_delay, max_endpointing_delay=max_endpointing_delay, min_interruption_duration=min_interruption_duration, + preemptive_generation=preemptive_generation, ) await session.start( @@ -415,12 +433,68 @@ async def entrypoint( agent=agent_setup, room_options=room_io.RoomOptions( audio_input=room_io.AudioInputOptions( - noise_cancellation=lambda params: noise_cancellation.BVCTelephony() - if params.participant.kind == rtc.ParticipantKind.PARTICIPANT_KIND_SIP - else noise_cancellation.BVC(), + noise_cancellation=( + lambda params: noise_cancellation.BVCTelephony() + if params.participant.kind == rtc.ParticipantKind.PARTICIPANT_KIND_SIP + else noise_cancellation.BVC() + ) if noise_cancellation_sip else ( + lambda params: noise_cancellation.BVC() + if params.participant.kind != rtc.ParticipantKind.PARTICIPANT_KIND_SIP + else None + ), ), ), ) + + # ------------------------------------------------------------------ + # Diagnostic logging (only when debug=True) + # ------------------------------------------------------------------ + + if debug: + interruption_opts = session.options.interruption + logger.info( + "Session config: interrupt_enabled=%s, interrupt_mode=%s, " + "discard_audio=%s, min_duration=%.2f, preemptive=%s", + interruption_opts.get("enabled"), + interruption_opts.get("mode"), + interruption_opts.get("discard_audio_if_uninterruptible"), + interruption_opts.get("min_duration"), + session.options.preemptive_generation, + ) + + @session.on("user_input_transcribed") + def _on_user_input_transcribed(ev): + """Log ALL transcripts with speech handle state for interrupt debugging.""" + speech = session.current_speech + if speech: + speech_info = ( + f"allow_int={speech.allow_interruptions}, " + f"interrupted={speech.interrupted}, " + f"done={speech.done()}" + ) + else: + speech_info = "no active speech" + + logger.info( + "TRANSCRIPT: %r (final=%s, agent_state=%s, speech=[%s])", + ev.transcript[:120] if ev.transcript else "", + ev.is_final, + session.agent_state, + speech_info, + ) + + @session.on("agent_state_changed") + def _on_agent_state_changed(ev): + """Log agent state transitions and clear text buffer.""" + logger.info("AGENT_STATE: %s -> %s", ev.old_state, ev.new_state) + if ev.new_state != "speaking": + agent_setup.clear_agent_text_buffer() + + @session.on("agent_false_interruption") + def _on_false_interruption(ev): + """Log when the SDK detects a false interruption (echo).""" + logger.info("FALSE_INTERRUPTION: resumed=%s", ev.resumed) + logger.info("Agent session started successfully.") call_result = await monitor_call(ctx, agent_setup) diff --git a/siphon/agent/core/voice_agent.py b/siphon/agent/core/voice_agent.py index 0f58562..d1d674d 100644 --- a/siphon/agent/core/voice_agent.py +++ b/siphon/agent/core/voice_agent.py @@ -1,7 +1,8 @@ import asyncio import time from datetime import datetime -from livekit.agents.voice import Agent +from typing import AsyncIterable, Optional +from livekit.agents.voice import Agent, ModelSettings from livekit import rtc from livekit.agents import ChatContext from siphon.config import get_logger, HangupCall, CallTranscription @@ -11,9 +12,11 @@ logger = get_logger("calling-agent") -from typing import Optional from siphon.memory import MemoryService +# Maximum characters kept in the rolling agent-text buffer (for echo detection). +_AGENT_TEXT_BUFFER_MAX = 1000 + def _get_current_datetime_stamp() -> str: """Generate a current date/time stamp for injection into the system prompt. @@ -134,6 +137,10 @@ def __init__(self, # Initialize transcription mixin for conversation tracking CallTranscription.__init__(self) + # Rolling buffer of recent agent output text (for echo detection). + # Filled by transcription_node() in real-time as LLM text streams to TTS. + self._agent_text_buffer: str = "" + async def _setup_recording_task(self): if self.call_recording: try: @@ -166,6 +173,41 @@ async def _send_greeting_task(self): except Exception as e: logger.error(f"Greeting error: {e}", exc_info=True) + # ------------------------------------------------------------------ + # Echo detection helpers + # ------------------------------------------------------------------ + + async def transcription_node( + self, text: AsyncIterable, model_settings + ) -> AsyncIterable: + """Override to capture real-time LLM output for echo comparison. + + The transcription_node receives every text chunk the LLM produces + (the same text that flows to TTS). We accumulate it into a rolling + buffer so the echo filter in entrypoint can compare incoming STT + transcripts against what the agent is *currently* saying. + """ + async for delta in text: + self._agent_text_buffer += delta + # Keep only the most-recent characters (tail) + if len(self._agent_text_buffer) > _AGENT_TEXT_BUFFER_MAX: + self._agent_text_buffer = self._agent_text_buffer[ + -_AGENT_TEXT_BUFFER_MAX: + ] + yield delta + + def get_recent_agent_text(self, max_chars: int = 500) -> str: + """Return the last *max_chars* characters the agent has generated.""" + if len(self._agent_text_buffer) <= max_chars: + return self._agent_text_buffer + return self._agent_text_buffer[-max_chars:] + + def clear_agent_text_buffer(self) -> None: + """Clear the buffer (called when agent stops speaking).""" + self._agent_text_buffer = "" + + # ------------------------------------------------------------------ + def update_phone_number(self, phone_number: Optional[str]) -> None: """Update memory phone number when SIP participant data becomes available.""" if not phone_number: diff --git a/siphon/plugins/sarvam.py b/siphon/plugins/sarvam.py index 651b605..a475a6c 100644 --- a/siphon/plugins/sarvam.py +++ b/siphon/plugins/sarvam.py @@ -1,11 +1,13 @@ from typing import Optional import os +from livekit.agents import tts as _tts from livekit.plugins import sarvam from . import ClientWrapperMixin + class STT(ClientWrapperMixin): - """Clova-backed STT wrapper around the LiveKit Clova plugin.""" + """Sarvam-backed STT wrapper around the LiveKit Sarvam plugin.""" def __init__( self, language: Optional[str] = "unknown", @@ -31,7 +33,6 @@ def _build_client(self): api_key=self.api_key, ) - # JSON-serializable view (no Python objects) def to_config(self) -> dict: return { @@ -48,8 +49,23 @@ def from_config(cls, cfg: dict) -> "STT": model=cfg.get("model", "saarika:v2.5"), ) + class TTS(ClientWrapperMixin): - """Sarvam-backed TTS wrapper around the Sarvam plugin.""" + """Sarvam-backed TTS wrapper around the Sarvam plugin. + + IMPORTANT: The Sarvam WebSocket streaming path is broken upstream + (livekit-agents v1.5.1). Sarvam's WS API returns raw PCM without + WAV headers, but the LiveKit plugin declares mime_type="audio/wav", + causing the WAV decoder to fail on every synthesis after the first. + + We disable streaming (forcing REST API path) which returns proper + WAV with RIFF headers. This is a known upstream bug tracked in: + - https://github.com/livekit/agents/pull/5209 (merged, unreleased) + - https://github.com/livekit/agents/issues/5267 (still open) + + When a fixed version of livekit-plugins-sarvam is released, this + workaround can be removed. + """ def __init__( self, target_language_code: Optional[str] = "en-IN", @@ -75,7 +91,7 @@ def __init__( self._client = self._build_client() def _build_client(self): - return sarvam.TTS( + client = sarvam.TTS( target_language_code=self.target_language_code, model=self.model, speaker=self.speaker, @@ -83,6 +99,11 @@ def _build_client(self): enable_preprocessing=self.enable_preprocessing, api_key=self.api_key, ) + # Disable WebSocket streaming — Sarvam WS returns raw PCM (no RIFF + # headers) but the plugin declares audio/wav. REST API returns + # proper WAV and works correctly. + client._capabilities = _tts.TTSCapabilities(streaming=False) + return client # JSON-serializable view (no Python objects) def to_config(self) -> dict: @@ -99,9 +120,9 @@ def to_config(self) -> dict: @classmethod def from_config(cls, cfg: dict) -> "TTS": return cls( - target_language_code=cfg.get("target_language_code", "hi-IN"), - model=cfg.get("model", "bulbul:v2"), - speaker=cfg.get("speaker", "anushka"), + target_language_code=cfg.get("target_language_code", "en-IN"), + model=cfg.get("model", "bulbul:v3"), + speaker=cfg.get("speaker", "shubh"), speech_sample_rate=cfg.get("speech_sample_rate", 22050), - enable_preprocessing=cfg.get("enable_preprocessing", False), - ) \ No newline at end of file + enable_preprocessing=cfg.get("enable_preprocessing", True), + )