Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 96 additions & 21 deletions runner/app/live/streamer/protocol/trickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
import logging
import queue
import json
from typing import AsyncGenerator, Optional

from PIL import Image
import time
from typing import AsyncGenerator, Optional, Any

from trickle import media, TricklePublisher, TrickleSubscriber, InputFrame, OutputFrame, AudioFrame, AudioOutput, DEFAULT_WIDTH, DEFAULT_HEIGHT

Expand All @@ -17,18 +16,21 @@ def __init__(self, subscribe_url: str, publish_url: str, control_url: Optional[s
self.publish_url = publish_url
self.control_url = control_url
self.events_url = events_url
self.subscribe_queue = queue.Queue[InputFrame]()
self.publish_queue = queue.Queue[OutputFrame]()
self.subscribe_queue = queue.Queue[InputFrame | None]()
self.publish_queue = queue.Queue[OutputFrame | None]()
self.event_queue = asyncio.Queue[dict | None]()
self.control_subscriber = None
self.events_publisher = None
self.subscribe_task = None
self.publish_task = None
self.event_batcher_task = None
self.width = width
self.height = height

async def start(self):
self.subscribe_queue = queue.Queue[InputFrame]()
self.publish_queue = queue.Queue[OutputFrame]()
self.subscribe_queue = queue.Queue[InputFrame | None]()
self.publish_queue = queue.Queue[OutputFrame | None]()
self.event_queue = asyncio.Queue[dict | None]()
metadata_cache = LastValueCache[dict]() # to pass video metadata from decoder to encoder
self.subscribe_task = asyncio.create_task(
media.run_subscribe(self.subscribe_url, self.subscribe_queue.put, metadata_cache.put, self.emit_monitoring_event, self.width, self.height)
Expand All @@ -40,6 +42,7 @@ async def start(self):
self.control_subscriber = TrickleSubscriber(self.control_url)
if self.events_url and self.events_url.strip() != "":
self.events_publisher = TricklePublisher(self.events_url, "application/json")
self.event_batcher_task = asyncio.create_task(self._event_batcher_loop())

async def stop(self):
if not self.subscribe_task or not self.publish_task:
Expand All @@ -48,6 +51,7 @@ async def stop(self):
# send sentinel None values to stop the trickle tasks gracefully
self.subscribe_queue.put(None)
self.publish_queue.put(None)
await self.event_queue.put(None)

if self.control_subscriber:
await self.control_subscriber.close()
Expand All @@ -57,15 +61,59 @@ async def stop(self):
await self.events_publisher.close()
self.events_publisher = None

tasks = [self.subscribe_task, self.publish_task]
tasks = [self.subscribe_task, self.publish_task, self.event_batcher_task]
try:
await asyncio.wait(tasks, timeout=10.0)
except asyncio.TimeoutError:
for task in tasks:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass

self.subscribe_task = None
self.publish_task = None
self.event_batcher_task = None

async def _event_batcher_loop(self):
    """Background task that batches monitoring events into JSONL trickle segments.

    Pulls event dicts from ``self.event_queue`` and writes them one JSON
    object per line into a single publisher segment. A segment stays open
    for at most ``BATCH_TIMEOUT`` seconds, after which it is closed and a
    new segment is started for the next event. A ``None`` sentinel on the
    queue stops the loop gracefully.

    Raises:
        ValueError: if called before ``self.events_publisher`` is set up.
    """
    BATCH_TIMEOUT = 5.0  # max seconds a single segment stays open

    if not self.events_publisher:
        raise ValueError("Events publisher not initialized")

    while True:
        try:
            event_data = await self.event_queue.get()
            # Compare against None explicitly: a falsy-but-valid event
            # (e.g. an empty dict) must not be mistaken for the sentinel.
            if event_data is None:
                return

            async with await self.events_publisher.next() as current_segment:
                segment_start_time = time.time()

                event_json = json.dumps(event_data)
                await current_segment.write(event_json.encode() + b'\n')

                while True:
                    # Remaining budget for this segment; stop batching when
                    # it is (almost) exhausted to avoid a pointless tiny wait.
                    timeout = BATCH_TIMEOUT - (time.time() - segment_start_time)
                    if timeout <= 0.1:
                        break

                    try:
                        next_event = await asyncio.wait_for(self.event_queue.get(), timeout=timeout)
                        if next_event is None:
                            # Sentinel: exit the task; the `async with`
                            # closes the in-flight segment on the way out.
                            return
                    except asyncio.TimeoutError:
                        break

                    event_json = json.dumps(next_event)
                    await current_segment.write(event_json.encode() + b'\n')

            logging.debug("Sent batched events segment")

        except Exception:
            # Keep the batcher alive on transient publish errors; plain
            # string (no f-string needed — there are no placeholders).
            logging.error("Error in event batcher loop", exc_info=True)

async def ingress_loop(self, done: asyncio.Event) -> AsyncGenerator[InputFrame, None]:
subscribe_queue = self.subscribe_queue
Expand Down Expand Up @@ -99,12 +147,9 @@ def enqueue_bytes(frame: OutputFrame):
async def emit_monitoring_event(self, event: dict, queue_event_type: str = "ai_stream_events"):
    """Enqueue a monitoring event for asynchronous, batched publishing.

    Wraps the event with its queue type and hands it to the event batcher
    loop via ``self.event_queue``. Does nothing when no events publisher
    is configured.

    Args:
        event: payload to publish.
        queue_event_type: routing tag stored alongside the event.
    """
    if not self.events_publisher:
        return

    event_data = {"event": event, "queue_event_type": queue_event_type}
    await self.event_queue.put(event_data)

async def control_loop(self, done: asyncio.Event) -> AsyncGenerator[dict, None]:
if not self.control_subscriber:
Expand All @@ -114,20 +159,50 @@ async def control_loop(self, done: asyncio.Event) -> AsyncGenerator[dict, None]:
logging.info("Starting Control subscriber at %s", self.control_url)
keepalive_message = {"keep": "alive"}

def read_event_line(line: str) -> Any:
    """Parse one JSONL control line.

    Returns the decoded payload, or None when the line is either a
    periodic keepalive message or not valid JSON.
    """
    try:
        payload = json.loads(line)
    except json.JSONDecodeError as e:
        logging.error(f"Failed to parse JSON line: {line}, error: {e}")
        return None

    if payload == keepalive_message:
        # Ignore periodic keepalive messages
        return None

    logging.info("Received control message with params: %s", payload)
    return payload

while not done.is_set():
try:
segment = await self.control_subscriber.next()
if not segment or segment.eos():
return

params = await segment.read()
data = json.loads(params)
if data == keepalive_message:
# Ignore periodic keepalive messages
continue
buffer = ""

logging.info("Received control message with params: %s", data)
yield data
while True:
chunk = await segment.read()
if not chunk:
break

buffer += chunk.decode('utf-8')

while '\n' in buffer:
line, buffer = buffer.split('\n', 1)
line = line.strip()
if not line:
continue

event = read_event_line(line)
if event:
yield event

buffer = buffer.strip()
if buffer:
event = read_event_line(buffer)
if event:
yield event

except Exception:
logging.error(f"Error in control loop", exc_info=True)
Expand Down