Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyview/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pyview.flash # noqa: F401 — registers flash context processor
from pyview.components import ComponentMeta, ComponentsManager, ComponentSocket, LiveComponent
from pyview.connection_tracker import ConnectionTracker
from pyview.depends import Depends, Session
from pyview.js import JsCommand, JsCommands, js
from pyview.live_socket import (
Expand Down Expand Up @@ -36,4 +37,6 @@
# Dependency injection
"Depends",
"Session",
# Connection tracking
"ConnectionTracker",
]
41 changes: 41 additions & 0 deletions pyview/connection_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable

if TYPE_CHECKING:
from pyview.live_socket import ConnectedLiveViewSocket
from pyview.live_view import LiveView


@runtime_checkable
class ConnectionTracker(Protocol):
"""Protocol for tracking LiveView connection lifecycle.

Implement this protocol to receive callbacks at key lifecycle points.
All methods should be fast and non-blocking — they're called inline
in the WebSocket handler hot path.
"""

def on_connect(
self,
topic: str,
socket: ConnectedLiveViewSocket,
view_class: type[LiveView],
route: str,
session: dict[str, Any],
) -> None:
"""Called when a LiveView mounts and completes its first render."""
...

def on_disconnect(self, topic: str) -> None:
"""Called when a LiveView WebSocket connection is closed."""
...

def on_event(
self,
topic: str,
event_name: str,
duration_seconds: float,
) -> None:
"""Called after an event is processed and rendered."""
...
19 changes: 17 additions & 2 deletions pyview/pyview.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from pyview.auth import AuthProviderFactory
from pyview.binding import call_handle_params, call_mount, create_view
from pyview.components.lifecycle import run_nested_component_lifecycle
from pyview.connection_tracker import ConnectionTracker
from pyview.csrf import generate_csrf_token
from pyview.instrumentation import InstrumentationProvider, NoOpInstrumentation
from pyview.live_socket import UnconnectedSocket
Expand All @@ -33,16 +34,25 @@ class PyView(Starlette):
rootTemplate: RootTemplate
instrumentation: InstrumentationProvider

def __init__(self, *args, instrumentation: Optional[InstrumentationProvider] = None, **kwargs):
def __init__(
self,
*args,
instrumentation: Optional[InstrumentationProvider] = None,
connection_tracker: Optional[ConnectionTracker] = None,
**kwargs,
):
# Extract user's lifespan if provided, then always use our composed lifespan
user_lifespan = kwargs.pop("lifespan", None)
kwargs["lifespan"] = self._create_lifespan(user_lifespan)

super().__init__(*args, **kwargs)
self.rootTemplate = defaultRootTemplate()
self.instrumentation = instrumentation or NoOpInstrumentation()
self.connection_tracker = connection_tracker
self.view_lookup = LiveViewLookup()
self.live_handler = LiveSocketHandler(self.view_lookup, self.instrumentation)
self.live_handler = LiveSocketHandler(
self.view_lookup, self.instrumentation, self.connection_tracker
)

self.routes.append(WebSocketRoute("/live/websocket", self.live_handler.handle))
self.add_middleware(GZipMiddleware)
Expand Down Expand Up @@ -71,6 +81,11 @@ async def lifespan(app):

return lifespan

@property
def registered_routes(self) -> list[tuple[str, type[LiveView]]]:
"""Return list of (path_format, view_class) for all registered LiveViews."""
return [(fmt, cls) for fmt, _, _, cls in self.view_lookup.routes]

def add_live_view(self, path: str, view: type[LiveView]):
async def lv(request: Request):
return await liveview_container(self.rootTemplate, self.view_lookup, request)
Expand Down
38 changes: 37 additions & 1 deletion pyview/ws_handler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import logging
import time
from contextlib import suppress
from typing import Optional
from urllib.parse import parse_qs, urlparse
Expand All @@ -9,6 +10,7 @@

from pyview.auth import AuthProviderFactory
from pyview.binding import call_handle_event, call_handle_params, call_mount, create_view
from pyview.connection_tracker import ConnectionTracker
from pyview.csrf import validate_csrf_token
from pyview.instrumentation import InstrumentationProvider
from pyview.live_routes import LiveViewLookup
Expand Down Expand Up @@ -51,9 +53,15 @@ def __init__(self, instrumentation: InstrumentationProvider):


class LiveSocketHandler:
def __init__(self, routes: LiveViewLookup, instrumentation: InstrumentationProvider):
def __init__(
self,
routes: LiveViewLookup,
instrumentation: InstrumentationProvider,
connection_tracker: Optional[ConnectionTracker] = None,
):
self.routes = routes
self.instrumentation = instrumentation
self.connection_tracker = connection_tracker
self.metrics = LiveSocketMetrics(instrumentation)
self.manager = ConnectionManager()
self.sessions = 0
Expand Down Expand Up @@ -142,6 +150,13 @@ async def handle(self, websocket: WebSocket):
]

await self.manager.send_personal_message(json.dumps(resp), websocket)

if self.connection_tracker:
Copy link

Copilot AI Mar 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use if self.connection_tracker is not None: rather than a truthiness check. As written, a tracker instance that defines __bool__/__len__ and evaluates to False would silently disable tracking.

Suggested change
if self.connection_tracker:
if self.connection_tracker is not None:

Copilot uses AI. Check for mistakes.
with suppress(Exception):
self.connection_tracker.on_connect(
topic, socket, lv_class, url.path, session
)

Comment on lines +154 to +159
Copy link

Copilot AI Mar 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

on_connect is only invoked for the initial phx_join in handle(). When the client navigates within the same WebSocket (the phx_join handled inside _handle_connected_loop), a new LiveView is mounted and first-rendered but the tracker never receives on_connect, so connection accounting will be incomplete. Consider invoking connection_tracker.on_connect(...) for navigation joins as well (and likewise for any mount path that results in a new ConnectedLiveViewSocket).

Copilot uses AI. Check for mistakes.
await self.handle_connected(topic, socket)

except WebSocketDisconnect:
Expand All @@ -156,6 +171,9 @@ async def handle(self, websocket: WebSocket):
if socket:
with suppress(Exception):
await socket.close()
if self.connection_tracker and topic:
with suppress(Exception):
self.connection_tracker.on_disconnect(topic)
Comment on lines +174 to +176
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Use active topic when reporting disconnects

The finally block always emits on_disconnect(topic) using the topic captured from the initial phx_join, but _handle_connected_loop can switch LiveViews via later phx_join messages. In a navigation flow (phx_leave old topic -> phx_join new topic -> socket closes), this reports the old topic again and never reports disconnect for the active topic, which can leave trackers with incorrect connection state/counts.

Useful? React with 👍 / 👎.

self.sessions -= 1
self.metrics.active_connections.add(-1)

Expand Down Expand Up @@ -193,6 +211,8 @@ async def _handle_connected_loop(self, myJoinId, socket: ConnectedLiveViewSocket
event_name = payload["event"]
view_name = socket.liveview.__class__.__name__

t0 = time.perf_counter()

# Handle built-in lv:clear-flash event
if event_name == "lv:clear-flash":
raw_key = value.get("key") if isinstance(value, dict) else None
Expand Down Expand Up @@ -242,6 +262,13 @@ async def _handle_connected_loop(self, myJoinId, socket: ConnectedLiveViewSocket
resp_json = json.dumps(resp)
self.metrics.message_size.record(len(resp_json))
await self.manager.send_personal_message(resp_json, socket.websocket)

if self.connection_tracker:
with suppress(Exception):
self.connection_tracker.on_event(
topic, event_name, time.perf_counter() - t0
)

continue

if event == "live_patch":
Expand Down Expand Up @@ -394,6 +421,12 @@ async def _handle_connected_loop(self, myJoinId, socket: ConnectedLiveViewSocket

await self.manager.send_personal_message(json.dumps(resp), socket.websocket)

if self.connection_tracker is not None:
with suppress(Exception):
self.connection_tracker.on_connect(
topic, socket, lv_class, url.path, session
)

if event == "chunk":
socket.upload_manager.add_chunk(joinRef, payload) # type: ignore

Expand Down Expand Up @@ -443,6 +476,9 @@ async def _handle_connected_loop(self, myJoinId, socket: ConnectedLiveViewSocket
if event == "phx_leave":
# Handle LiveView navigation - clean up current LiveView
await socket.close()
if self.connection_tracker and topic:
with suppress(Exception):
self.connection_tracker.on_disconnect(topic)
Comment on lines +479 to +481
Copy link

Copilot AI Mar 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

on_disconnect(topic) is called here on phx_leave, but handle() also unconditionally calls on_disconnect(topic) in its finally block when the WebSocket closes. If the client sends phx_leave and then disconnects without re-joining a new topic, the tracker will receive a duplicate disconnect for the same topic. Consider clearing/resetting the tracked topic (or tracking an "already disconnected" flag) after phx_leave so the finally block doesn’t emit a second disconnect for the same view.

Suggested change
if self.connection_tracker and topic:
with suppress(Exception):
self.connection_tracker.on_disconnect(topic)

Copilot uses AI. Check for mistakes.

resp = [
joinRef,
Expand Down
129 changes: 129 additions & 0 deletions tests/test_connection_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
"""Tests for the ConnectionTracker protocol and its integration points."""

from pyview.connection_tracker import ConnectionTracker
from pyview.instrumentation import NoOpInstrumentation
from pyview.live_routes import LiveViewLookup
from pyview.live_view import LiveView
from pyview.pyview import PyView
from pyview.ws_handler import LiveSocketHandler


class DummyView(LiveView):
pass


class FakeTracker:
"""A concrete ConnectionTracker implementation for testing."""

def __init__(self):
self.connects = []
self.disconnects = []
self.events = []

def on_connect(self, topic, socket, view_class, route, session):
self.connects.append(
{
"topic": topic,
"socket": socket,
"view_class": view_class,
"route": route,
"session": session,
}
)

def on_disconnect(self, topic):
self.disconnects.append(topic)

def on_event(self, topic, event_name, duration_seconds):
self.events.append(
{
"topic": topic,
"event_name": event_name,
"duration_seconds": duration_seconds,
}
)


def test_fake_tracker_satisfies_protocol():
"""FakeTracker should satisfy the ConnectionTracker protocol."""
tracker = FakeTracker()
assert isinstance(tracker, ConnectionTracker)


def test_handler_accepts_none_tracker():
"""LiveSocketHandler should work fine with no tracker (default)."""
routes = LiveViewLookup()
handler = LiveSocketHandler(routes, NoOpInstrumentation())
assert handler.connection_tracker is None


def test_handler_accepts_tracker():
"""LiveSocketHandler should store the tracker when provided."""
routes = LiveViewLookup()
tracker = FakeTracker()
handler = LiveSocketHandler(routes, NoOpInstrumentation(), connection_tracker=tracker)
assert handler.connection_tracker is tracker


def test_protocol_is_runtime_checkable():
"""ConnectionTracker should be runtime-checkable via isinstance."""

class NotATracker:
pass

class MinimalTracker:
def on_connect(self, topic, socket, view_class, route, session):
pass

def on_disconnect(self, topic):
pass

def on_event(self, topic, event_name, duration_seconds):
pass

assert isinstance(MinimalTracker(), ConnectionTracker)
assert not isinstance(NotATracker(), ConnectionTracker)


def test_registered_routes_empty():
"""PyView.registered_routes should return empty list with no routes."""
app = PyView()
assert app.registered_routes == []


def test_registered_routes_returns_routes():
"""PyView.registered_routes should return (path, view_class) tuples."""
app = PyView()
app.add_live_view("/test", DummyView)
routes = app.registered_routes
assert len(routes) == 1
assert routes[0] == ("/test", DummyView)


def test_registered_routes_multiple():
"""PyView.registered_routes should return all registered routes."""

class AnotherView(LiveView):
pass

app = PyView()
app.add_live_view("/a", DummyView)
app.add_live_view("/b", AnotherView)
routes = app.registered_routes
assert len(routes) == 2
assert routes[0] == ("/a", DummyView)
assert routes[1] == ("/b", AnotherView)


def test_pyview_passes_tracker_to_handler():
"""PyView should pass connection_tracker through to LiveSocketHandler."""
tracker = FakeTracker()
app = PyView(connection_tracker=tracker)
assert app.live_handler.connection_tracker is tracker

Comment on lines +118 to +123
Copy link

Copilot AI Mar 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests validate protocol conformance and that the tracker is plumbed through, but they don’t assert the new runtime behavior (that on_connect, on_event, and on_disconnect are actually invoked by LiveSocketHandler, and that exceptions from the tracker are suppressed). Adding an async WebSocket handler test (similar to existing test_ws_handler_cleanup.py) that drives a phx_join + event + disconnect (and a phx_leave path) would prevent regressions in the hook semantics.

Copilot uses AI. Check for mistakes.

def test_pyview_no_tracker_by_default():
"""PyView should have no tracker by default."""
app = PyView()
assert app.connection_tracker is None
assert app.live_handler.connection_tracker is None
Loading