diff --git a/src/java_functional_lsp/proxy.py b/src/java_functional_lsp/proxy.py index d5967a5..a68bc47 100644 --- a/src/java_functional_lsp/proxy.py +++ b/src/java_functional_lsp/proxy.py @@ -12,6 +12,7 @@ import re import shutil import subprocess +from collections import deque from collections.abc import Callable, Mapping from functools import lru_cache from pathlib import Path @@ -316,6 +317,78 @@ async def read_message(reader: asyncio.StreamReader) -> dict[str, Any] | None: _BUILD_FILES = ("pom.xml", "build.gradle", "build.gradle.kts") _WORKSPACE_DID_CHANGE_FOLDERS = "workspace/didChangeWorkspaceFolders" _MAX_QUEUED_NOTIFICATIONS = 200 +_MODULE_READY_TIMEOUT = 30.0 + + +class ModuleState: + """Module import states — UNKNOWN → ADDED → READY.""" + + UNKNOWN = "unknown" + ADDED = "added" + READY = "ready" + + +class ModuleRegistry: + """Thread-safe (asyncio) registry tracking jdtls module import states. + + Uses a plain dict for O(1) hot-path lookups and per-module ``asyncio.Event`` + for adaptive waiting — coroutines blocked on ``wait_until_ready()`` wake + instantly when ``mark_ready()`` is called, instead of a fixed sleep. + + Safe without locks because asyncio is single-threaded: dict mutations that + don't span an ``await`` are atomic. + """ + + def __init__(self) -> None: + self._states: dict[str, str] = {} + self._ready_events: dict[str, asyncio.Event] = {} + + def get_state(self, uri: str) -> str: + """O(1) state lookup. Returns ModuleState constant.""" + return self._states.get(uri, ModuleState.UNKNOWN) + + def is_ready(self, uri: str) -> bool: + """O(1) hot-path check — zero overhead when module is ready.""" + return self._states.get(uri) == ModuleState.READY + + def was_added(self, uri: str) -> bool: + """True if module was sent to jdtls (ADDED or READY).""" + return uri in self._states + + def mark_added(self, uri: str) -> None: + """Mark module as sent to jdtls. Pre-creates the Event for waiters. + + Must be called before any ``await`` to prevent duplicate add_module calls. + """ + self._states[uri] = ModuleState.ADDED + self._ready_events.setdefault(uri, asyncio.Event()) + + def mark_ready(self, uri: str) -> None: + """Mark module as confirmed working. Wakes all coroutines waiting on it.""" + self._states[uri] = ModuleState.READY + event = self._ready_events.pop(uri, None) + if event is not None: + event.set() + + def clear(self) -> None: + """Reset all state. Used by tests.""" + self._states.clear() + self._ready_events.clear() + + async def wait_until_ready(self, uri: str, timeout: float = _MODULE_READY_TIMEOUT) -> bool: + """Suspend until the module is ready or timeout expires. + + Returns True if ready, False on timeout. If already READY, returns + immediately without suspending. + """ + event = self._ready_events.setdefault(uri, asyncio.Event()) + if event.is_set(): + return True + try: + await asyncio.wait_for(event.wait(), timeout=timeout) + return True + except asyncio.TimeoutError: + return False @lru_cache(maxsize=256) @@ -336,6 +409,9 @@ def find_module_root(file_path: str) -> str | None: Returns the directory path, or ``None`` if no build file is found before reaching the filesystem root. Results are cached by parent directory. + + **Note:** cache entries are never invalidated. Build files added after the + first lookup for a given directory will not be detected until process restart. """ return _cached_module_root(str(Path(file_path).parent)) @@ -373,10 +449,10 @@ def __init__(self, on_diagnostics: Callable[[str, list[Any]], None] | None = Non self._start_failed = False self._jdtls_on_path = False self._lazy_start_fired = False - self._queued_notifications: list[tuple[str, Any]] = [] + self._queued_notifications: deque[tuple[str, Any]] = deque(maxlen=_MAX_QUEUED_NOTIFICATIONS) self._original_root_uri: str | None = None self._initial_module_uri: str | None = None - self._added_module_uris: set[str] = set() + self.modules = ModuleRegistry() self._workspace_expanded = False @property @@ -437,9 +513,9 @@ async def start(self, init_params: dict[str, Any], *, module_root_uri: str | Non ws = caps.setdefault("workspace", {}) ws["workspaceFolders"] = True - # Track the initial module as already loaded. + # Track the initial module as already loaded (mark ADDED before await). self._initial_module_uri = module_root_uri - self._added_module_uris.add(effective_root_uri) + self.modules.mark_added(effective_root_uri) # Build a clean environment for jdtls. loop = asyncio.get_running_loop() @@ -508,33 +584,45 @@ async def ensure_started(self, init_params: dict[str, Any], file_uri: str) -> bo self._start_failed = True self._queued_notifications.clear() return started + except Exception: + self._start_failed = True + self._queued_notifications.clear() + raise finally: self._starting = False def queue_notification(self, method: str, params: Any) -> None: """Buffer a notification for replay after jdtls starts. - Capped at ``_MAX_QUEUED_NOTIFICATIONS`` to prevent unbounded memory - growth during long jdtls startup. Oldest entries are dropped on overflow. + Uses a ``deque(maxlen=200)`` so oldest entries are dropped in O(1) + when the queue overflows during long jdtls startup. """ - if len(self._queued_notifications) >= _MAX_QUEUED_NOTIFICATIONS: - self._queued_notifications.pop(0) self._queued_notifications.append((method, params)) async def flush_queued_notifications(self) -> None: """Send all queued notifications to jdtls.""" - queue, self._queued_notifications = self._queued_notifications, [] + queue = list(self._queued_notifications) + self._queued_notifications.clear() for method, params in queue: await self.send_notification(method, params) - async def add_module_if_new(self, file_uri: str) -> None: - """Add the module containing *file_uri* to jdtls if not already added.""" + async def add_module_if_new(self, file_uri: str) -> str | None: + """Add the module containing *file_uri* to jdtls if not already added. + + Returns the module URI if a new module was added (UNKNOWN → ADDED), + or ``None`` if already known or unavailable. The returned URI can be + used with ``modules.wait_until_ready()`` for adaptive waiting. + + Calls ``modules.mark_added()`` before any ``await`` to prevent + duplicate add calls from concurrent coroutines. + """ if not self._available: - return + return None module_uri = _resolve_module_uri(file_uri) - if module_uri is None or module_uri in self._added_module_uris: - return - self._added_module_uris.add(module_uri) + if module_uri is None or self.modules.was_added(module_uri): + return None + # Mark ADDED before await — atomic in asyncio, prevents duplicate sends. + self.modules.mark_added(module_uri) from pygls.uris import to_fs_path logger.info("jdtls: adding module %s", _redact_path(to_fs_path(module_uri))) @@ -543,6 +631,7 @@ async def add_module_if_new(self, file_uri: str) -> None: _WORKSPACE_DID_CHANGE_FOLDERS, {"event": {"added": [{"uri": module_uri, "name": mod_name}], "removed": []}}, ) + return module_uri async def expand_full_workspace(self) -> None: """Expand jdtls workspace to the full monorepo root (background task). @@ -556,10 +645,10 @@ async def expand_full_workspace(self) -> None: root_path = to_fs_path(self._original_root_uri) or self._original_root_uri root_uri = from_fs_path(root_path) or self._original_root_uri - if root_uri in self._added_module_uris: + if self.modules.was_added(root_uri): self._workspace_expanded = True return - self._added_module_uris.add(root_uri) + self.modules.mark_added(root_uri) # Remove initial module folder to avoid double-indexing. removed: list[dict[str, str]] = [] diff --git a/src/java_functional_lsp/server.py b/src/java_functional_lsp/server.py index 030d3c0..5ede8b2 100644 --- a/src/java_functional_lsp/server.py +++ b/src/java_functional_lsp/server.py @@ -26,7 +26,7 @@ from .analyzers.null_checker import NullChecker from .analyzers.spring_checker import SpringChecker from .fixes import get_fix, get_fix_registry_keys -from .proxy import JdtlsProxy +from .proxy import JdtlsProxy, _resolve_module_uri logger = logging.getLogger(__name__) @@ -337,7 +337,7 @@ def _forward_or_queue(method: str, serialized: Any) -> None: task = asyncio.create_task(server._proxy.send_notification(method, serialized)) _bg_tasks.add(task) task.add_done_callback(_bg_tasks.discard) - elif server._proxy._starting: + elif server._proxy._lazy_start_fired and not server._proxy._start_failed: server._proxy.queue_notification(method, serialized) @@ -440,11 +440,55 @@ async def _expand_workspace_background() -> None: # they only activate after jdtls starts. +async def _ensure_module_and_forward(method: str, params: Any, file_uri: str) -> Any | None: + """Forward a request to jdtls, ensuring the file's module is loaded. + + Uses ``ModuleRegistry`` for adaptive waiting: + - **READY**: forward immediately (zero overhead on hot path) + - **UNKNOWN**: add module, wait until ready (adaptive, not fixed sleep) + - **ADDED**: module sent but not confirmed — wait until ready + + When a request succeeds, marks the module as READY so subsequent + requests skip the wait entirely. + """ + proxy = server._proxy + if not proxy.is_available: + return None + + module_uri = _resolve_module_uri(file_uri) + + # Hot path: module already confirmed working. + if module_uri and proxy.modules.is_ready(module_uri): + return await proxy.send_request(method, _serialize_params(params)) + + # Cold path: add module if unknown, then wait for ready. + new_module_uri = await proxy.add_module_if_new(file_uri) + + serialized = _serialize_params(params) + result = await proxy.send_request(method, serialized) + + if result is not None: + # Success — mark module as ready so future requests are instant. + if module_uri: + proxy.modules.mark_ready(module_uri) + return result + + # Null result and module is not yet ready — wait then retry once. + # Use a short timeout (5s) so single-caller case doesn't block for 30s. + # If a concurrent request succeeds, Event.set() wakes us early. + wait_uri = new_module_uri or module_uri + if wait_uri and not proxy.modules.is_ready(wait_uri): + await proxy.modules.wait_until_ready(wait_uri, timeout=5.0) + # Always retry once after waiting — even on timeout the module may be ready. + result = await proxy.send_request(method, serialized) + if result is not None and module_uri: + proxy.modules.mark_ready(module_uri) + return result + + async def _on_completion(params: lsp.CompletionParams) -> lsp.CompletionList | None: """Forward completion request to jdtls.""" - if not server._proxy.is_available: - return None - result = await server._proxy.send_request("textDocument/completion", _serialize_params(params)) + result = await _ensure_module_and_forward("textDocument/completion", params, params.text_document.uri) if result is None: return None try: @@ -455,9 +499,7 @@ async def _on_completion(params: lsp.CompletionParams) -> lsp.CompletionList | N async def _on_hover(params: lsp.HoverParams) -> lsp.Hover | None: """Forward hover request to jdtls.""" - if not server._proxy.is_available: - return None - result = await server._proxy.send_request("textDocument/hover", _serialize_params(params)) + result = await _ensure_module_and_forward("textDocument/hover", params, params.text_document.uri) if result is None: return None try: @@ -468,9 +510,7 @@ async def _on_hover(params: lsp.HoverParams) -> lsp.Hover | None: async def _on_definition(params: lsp.DefinitionParams) -> list[lsp.Location] | None: """Forward go-to-definition request to jdtls.""" - if not server._proxy.is_available: - return None - result = await server._proxy.send_request("textDocument/definition", _serialize_params(params)) + result = await _ensure_module_and_forward("textDocument/definition", params, params.text_document.uri) if result is None: return None try: @@ -483,9 +523,7 @@ async def _on_definition(params: lsp.DefinitionParams) -> list[lsp.Location] | N async def _on_references(params: lsp.ReferenceParams) -> list[lsp.Location] | None: """Forward find-references request to jdtls.""" - if not server._proxy.is_available: - return None - result = await server._proxy.send_request("textDocument/references", _serialize_params(params)) + result = await _ensure_module_and_forward("textDocument/references", params, params.text_document.uri) if result is None: return None try: @@ -496,9 +534,7 @@ async def _on_references(params: lsp.ReferenceParams) -> list[lsp.Location] | No async def _on_document_symbol(params: lsp.DocumentSymbolParams) -> list[lsp.DocumentSymbol] | None: """Forward document symbol request to jdtls.""" - if not server._proxy.is_available: - return None - result = await server._proxy.send_request("textDocument/documentSymbol", _serialize_params(params)) + result = await _ensure_module_and_forward("textDocument/documentSymbol", params, params.text_document.uri) if result is None: return None try: diff --git a/tests/test_proxy.py b/tests/test_proxy.py index af75ab4..7bdc2e1 100644 --- a/tests/test_proxy.py +++ b/tests/test_proxy.py @@ -862,6 +862,12 @@ async def fake_send_request(*_args: Any, **_kwargs: Any) -> None: class TestFindModuleRoot: """Tests for find_module_root — build-file detection for module scoping.""" + @pytest.fixture(autouse=True) + def _clear_cache(self) -> None: + from java_functional_lsp.proxy import _cached_module_root + + _cached_module_root.cache_clear() + def test_finds_pom_xml(self, tmp_path: Any) -> None: from java_functional_lsp.proxy import find_module_root @@ -957,7 +963,9 @@ def test_queue_caps_at_max(self) -> None: for i in range(_MAX_QUEUED_NOTIFICATIONS + 50): proxy.queue_notification("textDocument/didChange", {"i": i}) assert len(proxy._queued_notifications) == _MAX_QUEUED_NOTIFICATIONS - # Oldest entries dropped — last entry should be the most recent + # Oldest entries dropped — first surviving entry is i=50 + assert proxy._queued_notifications[0] == ("textDocument/didChange", {"i": 50}) + # Last entry is the most recent assert proxy._queued_notifications[-1] == ("textDocument/didChange", {"i": _MAX_QUEUED_NOTIFICATIONS + 49}) async def test_ensure_started_no_retry_after_failure(self) -> None: @@ -998,10 +1006,13 @@ async def test_add_module_if_new_sends_notification(self) -> None: java_file.parent.mkdir() java_file.touch() uri = java_file.as_uri() - await proxy.add_module_if_new(uri) + result = await proxy.add_module_if_new(uri) + assert result is not None # Returns module URI string proxy.send_notification.assert_called_once() # type: ignore[attr-defined] call_args = proxy.send_notification.call_args # type: ignore[attr-defined] assert call_args[0][0] == "workspace/didChangeWorkspaceFolders" + # Module should be marked ADDED in registry + assert proxy.modules.was_added(result) async def test_add_module_if_new_skips_duplicate(self) -> None: from unittest.mock import AsyncMock @@ -1020,8 +1031,10 @@ async def test_add_module_if_new_skips_duplicate(self) -> None: java_file.parent.mkdir() java_file.touch() uri = java_file.as_uri() - await proxy.add_module_if_new(uri) - await proxy.add_module_if_new(uri) # duplicate + result1 = await proxy.add_module_if_new(uri) + result2 = await proxy.add_module_if_new(uri) # duplicate + assert result1 is not None # New module URI + assert result2 is None # Already known assert proxy.send_notification.call_count == 1 # type: ignore[attr-defined] async def test_expand_full_workspace_sends_notification(self) -> None: @@ -1037,6 +1050,25 @@ async def test_expand_full_workspace_sends_notification(self) -> None: proxy.send_notification.assert_called_once() # type: ignore[attr-defined] assert proxy._workspace_expanded is True + async def test_expand_full_workspace_removes_initial_module(self) -> None: + """When _initial_module_uri differs from root, it should be in the removed list.""" + from unittest.mock import AsyncMock + + from java_functional_lsp.proxy import JdtlsProxy + + proxy = JdtlsProxy() + proxy._available = True + proxy._original_root_uri = "file:///workspace/monorepo" + proxy._initial_module_uri = "file:///workspace/monorepo/module-a" + proxy.send_notification = AsyncMock() # type: ignore[assignment] + await proxy.expand_full_workspace() + proxy.send_notification.assert_called_once() # type: ignore[attr-defined] + call_args = proxy.send_notification.call_args[0] # type: ignore[attr-defined] + event = call_args[1]["event"] + assert len(event["removed"]) == 1 + assert event["removed"][0]["uri"] == "file:///workspace/monorepo/module-a" + assert event["added"][0]["uri"] == "file:///workspace/monorepo" + async def test_expand_full_workspace_noop_when_not_available(self) -> None: from unittest.mock import AsyncMock @@ -1057,7 +1089,7 @@ async def test_expand_full_workspace_noop_when_already_added(self) -> None: proxy = JdtlsProxy() proxy._available = True proxy._original_root_uri = "file:///workspace/monorepo" - proxy._added_module_uris.add("file:///workspace/monorepo") + proxy.modules.mark_added("file:///workspace/monorepo") proxy.send_notification = AsyncMock() # type: ignore[assignment] await proxy.expand_full_workspace() proxy.send_notification.assert_not_called() # type: ignore[attr-defined] diff --git a/tests/test_server.py b/tests/test_server.py index 0b5dc5a..4876612 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -476,6 +476,78 @@ async def test_lazy_start_jdtls_silent_failure(self) -> None: mock_flush.assert_not_called() mock_expand.assert_not_called() + async def test_ensure_module_and_forward_ready_module_fast_path(self) -> None: + """READY module → single send_request, no add_module call.""" + from unittest.mock import AsyncMock, patch + + from java_functional_lsp.server import _ensure_module_and_forward + from java_functional_lsp.server import server as srv + + srv._proxy.modules.mark_added("file:///mod") + srv._proxy.modules.mark_ready("file:///mod") + mock_send = AsyncMock(return_value={"result": "ok"}) + try: + with ( + patch.object(srv._proxy, "send_request", mock_send), + patch.object(srv._proxy, "_available", True), + patch("java_functional_lsp.server._resolve_module_uri", return_value="file:///mod"), + ): + result = await _ensure_module_and_forward("textDocument/hover", {}, "file:///mod/F.java") + finally: + srv._proxy.modules.clear() + assert result == {"result": "ok"} + assert mock_send.call_count == 1 + + async def test_ensure_module_and_forward_new_module_waits_and_retries(self) -> None: + """UNKNOWN module → add, first request null, wait_until_ready, retry succeeds.""" + from unittest.mock import AsyncMock, patch + + from java_functional_lsp.server import _ensure_module_and_forward + from java_functional_lsp.server import server as srv + + mock_add = AsyncMock(return_value="file:///mod") + mock_send = AsyncMock(side_effect=[None, {"result": "ok"}]) + + async def mock_wait(uri: str, timeout: float = 30.0) -> bool: + srv._proxy.modules.mark_ready(uri) + return True + + try: + with ( + patch.object(srv._proxy, "add_module_if_new", mock_add), + patch.object(srv._proxy, "send_request", mock_send), + patch.object(srv._proxy, "_available", True), + patch.object(srv._proxy.modules, "wait_until_ready", mock_wait), + patch("java_functional_lsp.server._resolve_module_uri", return_value="file:///mod"), + ): + result = await _ensure_module_and_forward("textDocument/hover", {}, "file:///mod/F.java") + finally: + srv._proxy.modules.clear() + assert result == {"result": "ok"} + assert mock_send.call_count == 2 + + async def test_ensure_module_and_forward_success_marks_ready(self) -> None: + """First successful request marks module as READY.""" + from unittest.mock import AsyncMock, patch + + from java_functional_lsp.proxy import ModuleState + from java_functional_lsp.server import _ensure_module_and_forward + from java_functional_lsp.server import server as srv + + mock_add = AsyncMock(return_value="file:///mod") + mock_send = AsyncMock(return_value={"result": "ok"}) + try: + with ( + patch.object(srv._proxy, "add_module_if_new", mock_add), + patch.object(srv._proxy, "send_request", mock_send), + patch.object(srv._proxy, "_available", True), + patch("java_functional_lsp.server._resolve_module_uri", return_value="file:///mod"), + ): + await _ensure_module_and_forward("textDocument/hover", {}, "file:///mod/F.java") + assert srv._proxy.modules.get_state("file:///mod") == ModuleState.READY + finally: + srv._proxy.modules.clear() + def test_serialize_params_camelcase(self) -> None: from java_functional_lsp.server import _serialize_params