diff --git a/src/palabra_ai/enum.py b/src/palabra_ai/enum.py index 483b28f..dd9163b 100644 --- a/src/palabra_ai/enum.py +++ b/src/palabra_ai/enum.py @@ -7,6 +7,7 @@ class MessageType(StrEnum): VALIDATED_TRANSCRIPTION = "validated_transcription" PARTIAL_TRANSLATED_TRANSCRIPTION = "partial_translated_transcription" PIPELINE_TIMINGS = "pipeline_timings" + TTS_TEXT = "tts_text" _QUEUE_STATUS = "queue_status" # For "es" messages _EMPTY = "empty" # For empty {} messages _UNKNOWN = "unknown" # For unrecognized message formats @@ -19,6 +20,7 @@ class MessageType(StrEnum): MessageType.TRANSLATED_TRANSCRIPTION, MessageType.VALIDATED_TRANSCRIPTION, MessageType.PARTIAL_TRANSLATED_TRANSCRIPTION, + MessageType.TTS_TEXT, ) } ALLOWED_MESSAGE_TYPES = { diff --git a/src/palabra_ai/message.py b/src/palabra_ai/message.py index b2a2c49..5434f84 100644 --- a/src/palabra_ai/message.py +++ b/src/palabra_ai/message.py @@ -116,6 +116,7 @@ class Type(StrEnum): PARTIAL_TRANSLATED_TRANSCRIPTION = "partial_translated_transcription" PIPELINE_TIMINGS = "pipeline_timings" TTS_BUFFER_STATS = "tts_buffer_stats" + TTS_TEXT = "tts_text" ERROR = "error" # For error messages END_TASK = "end_task" # For end_task messages SET_TASK = "set_task" # For set_task messages @@ -131,6 +132,7 @@ class Type(StrEnum): Type.TRANSLATED_TRANSCRIPTION, Type.VALIDATED_TRANSCRIPTION, Type.PARTIAL_TRANSLATED_TRANSCRIPTION, + Type.TTS_TEXT, } IN_PROCESS_TYPES: ClassVar[set[Type]] = TRANSCRIPTION_TYPES @@ -586,7 +588,7 @@ class TranscriptionMessage(Message): id_: str = Field(alias="transcription_id") text: str language: Language - segments: list[TranscriptionSegment] + segments: list[TranscriptionSegment] = Field(default_factory=list) model_config = ConfigDict(populate_by_name=True) @@ -604,13 +606,15 @@ def extract_from_nested(cls, values: dict[str, Any]) -> dict[str, Any]: transcription = values["data"]["transcription"] # Convert language string to Language object lang_code = transcription["language"] - return { + result = { "message_type": values["message_type"], "transcription_id": transcription["transcription_id"], "language": Language.get_or_create(lang_code), - "segments": transcription["segments"], "text": transcription["text"], } + if values["message_type"] != "tts_text": + result["segments"] = transcription["segments"] + return result return values def model_dump(self, **kwargs) -> dict[str, Any]: diff --git a/tests/test_enum.py b/tests/test_enum.py index e7bedbe..e1a3e4d 100644 --- a/tests/test_enum.py +++ b/tests/test_enum.py @@ -32,6 +32,7 @@ def test_transcription_message_types(): "translated_transcription", "validated_transcription", "partial_translated_transcription", + "tts_text", } assert TRANSCRIPTION_MESSAGE_TYPES == expected @@ -44,6 +45,7 @@ def test_allowed_message_types(): "translated_transcription", "validated_transcription", "partial_translated_transcription", + "tts_text", } assert ALLOWED_MESSAGE_TYPES == expected