diff --git a/pyview/__init__.py b/pyview/__init__.py index 22ce821..fd61f88 100644 --- a/pyview/__init__.py +++ b/pyview/__init__.py @@ -1,5 +1,6 @@ import pyview.flash # noqa: F401 — registers flash context processor from pyview.components import ComponentMeta, ComponentsManager, ComponentSocket, LiveComponent +from pyview.connection_tracker import ConnectionTracker from pyview.depends import Depends, Session from pyview.js import JsCommand, JsCommands, js from pyview.live_socket import ( @@ -36,4 +37,6 @@ # Dependency injection "Depends", "Session", + # Connection tracking + "ConnectionTracker", ] diff --git a/pyview/connection_tracker.py b/pyview/connection_tracker.py new file mode 100644 index 0000000..3e86588 --- /dev/null +++ b/pyview/connection_tracker.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable + +if TYPE_CHECKING: + from pyview.live_socket import ConnectedLiveViewSocket + from pyview.live_view import LiveView + + +@runtime_checkable +class ConnectionTracker(Protocol): + """Protocol for tracking LiveView connection lifecycle. + + Implement this protocol to receive callbacks at key lifecycle points. + All methods should be fast and non-blocking — they're called inline + in the WebSocket handler hot path. + """ + + def on_connect( + self, + topic: str, + socket: ConnectedLiveViewSocket, + view_class: type[LiveView], + route: str, + session: dict[str, Any], + ) -> None: + """Called when a LiveView mounts and completes its first render.""" + ... + + def on_disconnect(self, topic: str) -> None: + """Called when a LiveView WebSocket connection is closed.""" + ... + + def on_event( + self, + topic: str, + event_name: str, + duration_seconds: float, + ) -> None: + """Called after an event is processed and rendered.""" + ... diff --git a/pyview/pyview.py b/pyview/pyview.py index d17ae4f..a890b16 100644 --- a/pyview/pyview.py +++ b/pyview/pyview.py @@ -12,6 +12,7 @@ from pyview.auth import AuthProviderFactory from pyview.binding import call_handle_params, call_mount, create_view from pyview.components.lifecycle import run_nested_component_lifecycle +from pyview.connection_tracker import ConnectionTracker from pyview.csrf import generate_csrf_token from pyview.instrumentation import InstrumentationProvider, NoOpInstrumentation from pyview.live_socket import UnconnectedSocket @@ -33,7 +34,13 @@ class PyView(Starlette): rootTemplate: RootTemplate instrumentation: InstrumentationProvider - def __init__(self, *args, instrumentation: Optional[InstrumentationProvider] = None, **kwargs): + def __init__( + self, + *args, + instrumentation: Optional[InstrumentationProvider] = None, + connection_tracker: Optional[ConnectionTracker] = None, + **kwargs, + ): # Extract user's lifespan if provided, then always use our composed lifespan user_lifespan = kwargs.pop("lifespan", None) kwargs["lifespan"] = self._create_lifespan(user_lifespan) @@ -41,8 +48,11 @@ def __init__(self, *args, instrumentation: Optional[InstrumentationProvider] = N super().__init__(*args, **kwargs) self.rootTemplate = defaultRootTemplate() self.instrumentation = instrumentation or NoOpInstrumentation() + self.connection_tracker = connection_tracker self.view_lookup = LiveViewLookup() - self.live_handler = LiveSocketHandler(self.view_lookup, self.instrumentation) + self.live_handler = LiveSocketHandler( + self.view_lookup, self.instrumentation, self.connection_tracker + ) self.routes.append(WebSocketRoute("/live/websocket", self.live_handler.handle)) self.add_middleware(GZipMiddleware) @@ -71,6 +81,11 @@ async def lifespan(app): return lifespan + @property + def registered_routes(self) -> list[tuple[str, type[LiveView]]]: + """Return list of (path_format, view_class) for all registered LiveViews.""" + return [(fmt, cls) for fmt, _, _, cls in self.view_lookup.routes] + def add_live_view(self, path: str, view: type[LiveView]): async def lv(request: Request): return await liveview_container(self.rootTemplate, self.view_lookup, request) diff --git a/pyview/ws_handler.py b/pyview/ws_handler.py index aca27ed..b83caee 100644 --- a/pyview/ws_handler.py +++ b/pyview/ws_handler.py @@ -1,5 +1,6 @@ import json import logging +import time from contextlib import suppress from typing import Optional from urllib.parse import parse_qs, urlparse @@ -9,6 +10,7 @@ from pyview.auth import AuthProviderFactory from pyview.binding import call_handle_event, call_handle_params, call_mount, create_view +from pyview.connection_tracker import ConnectionTracker from pyview.csrf import validate_csrf_token from pyview.instrumentation import InstrumentationProvider from pyview.live_routes import LiveViewLookup @@ -51,9 +53,15 @@ def __init__(self, instrumentation: InstrumentationProvider): class LiveSocketHandler: - def __init__(self, routes: LiveViewLookup, instrumentation: InstrumentationProvider): + def __init__( + self, + routes: LiveViewLookup, + instrumentation: InstrumentationProvider, + connection_tracker: Optional[ConnectionTracker] = None, + ): self.routes = routes self.instrumentation = instrumentation + self.connection_tracker = connection_tracker self.metrics = LiveSocketMetrics(instrumentation) self.manager = ConnectionManager() self.sessions = 0 @@ -142,6 +150,13 @@ async def handle(self, websocket: WebSocket): ] await self.manager.send_personal_message(json.dumps(resp), websocket) + + if self.connection_tracker: + with suppress(Exception): + self.connection_tracker.on_connect( + topic, socket, lv_class, url.path, session + ) + await self.handle_connected(topic, socket) except WebSocketDisconnect: @@ -156,6 +171,9 @@ async def handle(self, websocket: WebSocket): if socket: with suppress(Exception): await socket.close() + if self.connection_tracker and topic: + with suppress(Exception): + self.connection_tracker.on_disconnect(topic) self.sessions -= 1 self.metrics.active_connections.add(-1) @@ -193,6 +211,8 @@ async def _handle_connected_loop(self, myJoinId, socket: ConnectedLiveViewSocket event_name = payload["event"] view_name = socket.liveview.__class__.__name__ + t0 = time.perf_counter() + # Handle built-in lv:clear-flash event if event_name == "lv:clear-flash": raw_key = value.get("key") if isinstance(value, dict) else None @@ -242,6 +262,13 @@ async def _handle_connected_loop(self, myJoinId, socket: ConnectedLiveViewSocket resp_json = json.dumps(resp) self.metrics.message_size.record(len(resp_json)) await self.manager.send_personal_message(resp_json, socket.websocket) + + if self.connection_tracker: + with suppress(Exception): + self.connection_tracker.on_event( + topic, event_name, time.perf_counter() - t0 + ) + continue if event == "live_patch": @@ -394,6 +421,12 @@ async def _handle_connected_loop(self, myJoinId, socket: ConnectedLiveViewSocket await self.manager.send_personal_message(json.dumps(resp), socket.websocket) + if self.connection_tracker is not None: + with suppress(Exception): + self.connection_tracker.on_connect( + topic, socket, lv_class, url.path, session + ) + if event == "chunk": socket.upload_manager.add_chunk(joinRef, payload) # type: ignore @@ -443,6 +476,9 @@ async def _handle_connected_loop(self, myJoinId, socket: ConnectedLiveViewSocket if event == "phx_leave": # Handle LiveView navigation - clean up current LiveView await socket.close() + if self.connection_tracker and topic: + with suppress(Exception): + self.connection_tracker.on_disconnect(topic) resp = [ joinRef, diff --git a/tests/test_connection_tracker.py b/tests/test_connection_tracker.py new file mode 100644 index 0000000..a84f59a --- /dev/null +++ b/tests/test_connection_tracker.py @@ -0,0 +1,129 @@ +"""Tests for the ConnectionTracker protocol and its integration points.""" + +from pyview.connection_tracker import ConnectionTracker +from pyview.instrumentation import NoOpInstrumentation +from pyview.live_routes import LiveViewLookup +from pyview.live_view import LiveView +from pyview.pyview import PyView +from pyview.ws_handler import LiveSocketHandler + + +class DummyView(LiveView): + pass + + +class FakeTracker: + """A concrete ConnectionTracker implementation for testing.""" + + def __init__(self): + self.connects = [] + self.disconnects = [] + self.events = [] + + def on_connect(self, topic, socket, view_class, route, session): + self.connects.append( + { + "topic": topic, + "socket": socket, + "view_class": view_class, + "route": route, + "session": session, + } + ) + + def on_disconnect(self, topic): + self.disconnects.append(topic) + + def on_event(self, topic, event_name, duration_seconds): + self.events.append( + { + "topic": topic, + "event_name": event_name, + "duration_seconds": duration_seconds, + } + ) + + +def test_fake_tracker_satisfies_protocol(): + """FakeTracker should satisfy the ConnectionTracker protocol.""" + tracker = FakeTracker() + assert isinstance(tracker, ConnectionTracker) + + +def test_handler_accepts_none_tracker(): + """LiveSocketHandler should work fine with no tracker (default).""" + routes = LiveViewLookup() + handler = LiveSocketHandler(routes, NoOpInstrumentation()) + assert handler.connection_tracker is None + + +def test_handler_accepts_tracker(): + """LiveSocketHandler should store the tracker when provided.""" + routes = LiveViewLookup() + tracker = FakeTracker() + handler = LiveSocketHandler(routes, NoOpInstrumentation(), connection_tracker=tracker) + assert handler.connection_tracker is tracker + + +def test_protocol_is_runtime_checkable(): + """ConnectionTracker should be runtime-checkable via isinstance.""" + + class NotATracker: + pass + + class MinimalTracker: + def on_connect(self, topic, socket, view_class, route, session): + pass + + def on_disconnect(self, topic): + pass + + def on_event(self, topic, event_name, duration_seconds): + pass + + assert isinstance(MinimalTracker(), ConnectionTracker) + assert not isinstance(NotATracker(), ConnectionTracker) + + +def test_registered_routes_empty(): + """PyView.registered_routes should return empty list with no routes.""" + app = PyView() + assert app.registered_routes == [] + + +def test_registered_routes_returns_routes(): + """PyView.registered_routes should return (path, view_class) tuples.""" + app = PyView() + app.add_live_view("/test", DummyView) + routes = app.registered_routes + assert len(routes) == 1 + assert routes[0] == ("/test", DummyView) + + +def test_registered_routes_multiple(): + """PyView.registered_routes should return all registered routes.""" + + class AnotherView(LiveView): + pass + + app = PyView() + app.add_live_view("/a", DummyView) + app.add_live_view("/b", AnotherView) + routes = app.registered_routes + assert len(routes) == 2 + assert routes[0] == ("/a", DummyView) + assert routes[1] == ("/b", AnotherView) + + +def test_pyview_passes_tracker_to_handler(): + """PyView should pass connection_tracker through to LiveSocketHandler.""" + tracker = FakeTracker() + app = PyView(connection_tracker=tracker) + assert app.live_handler.connection_tracker is tracker + + +def test_pyview_no_tracker_by_default(): + """PyView should have no tracker by default.""" + app = PyView() + assert app.connection_tracker is None + assert app.live_handler.connection_tracker is None