Skip to content
Merged
123 changes: 106 additions & 17 deletions src/java_functional_lsp/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import re
import shutil
import subprocess
from collections import deque
from collections.abc import Callable, Mapping
from functools import lru_cache
from pathlib import Path
Expand Down Expand Up @@ -316,6 +317,78 @@ async def read_message(reader: asyncio.StreamReader) -> dict[str, Any] | None:
_BUILD_FILES = ("pom.xml", "build.gradle", "build.gradle.kts")
_WORKSPACE_DID_CHANGE_FOLDERS = "workspace/didChangeWorkspaceFolders"
_MAX_QUEUED_NOTIFICATIONS = 200
_MODULE_READY_TIMEOUT = 30.0


class ModuleState:
"""Module import states — UNKNOWN → ADDED → READY."""

UNKNOWN = "unknown"
ADDED = "added"
READY = "ready"


class ModuleRegistry:
"""Thread-safe (asyncio) registry tracking jdtls module import states.

Uses a plain dict for O(1) hot-path lookups and per-module ``asyncio.Event``
for adaptive waiting — coroutines blocked on ``wait_until_ready()`` wake
instantly when ``mark_ready()`` is called, instead of a fixed sleep.

Safe without locks because asyncio is single-threaded: dict mutations that
don't span an ``await`` are atomic.
"""

def __init__(self) -> None:
self._states: dict[str, str] = {}
self._ready_events: dict[str, asyncio.Event] = {}

def get_state(self, uri: str) -> str:
"""O(1) state lookup. Returns ModuleState constant."""
return self._states.get(uri, ModuleState.UNKNOWN)

def is_ready(self, uri: str) -> bool:
"""O(1) hot-path check — zero overhead when module is ready."""
return self._states.get(uri) == ModuleState.READY

def was_added(self, uri: str) -> bool:
"""True if module was sent to jdtls (ADDED or READY)."""
return uri in self._states

def mark_added(self, uri: str) -> None:
"""Mark module as sent to jdtls. Pre-creates the Event for waiters.

Must be called before any ``await`` to prevent duplicate add_module calls.
"""
self._states[uri] = ModuleState.ADDED
self._ready_events.setdefault(uri, asyncio.Event())

def mark_ready(self, uri: str) -> None:
"""Mark module as confirmed working. Wakes all coroutines waiting on it."""
self._states[uri] = ModuleState.READY
event = self._ready_events.pop(uri, None)
if event is not None:
event.set()

def clear(self) -> None:
"""Reset all state. Used by tests."""
self._states.clear()
self._ready_events.clear()

async def wait_until_ready(self, uri: str, timeout: float = _MODULE_READY_TIMEOUT) -> bool:
"""Suspend until the module is ready or timeout expires.

Returns True if ready, False on timeout. If already READY, returns
immediately without suspending.
"""
event = self._ready_events.setdefault(uri, asyncio.Event())
if event.is_set():
return True
try:
await asyncio.wait_for(event.wait(), timeout=timeout)
return True
except asyncio.TimeoutError:
return False


@lru_cache(maxsize=256)
Expand All @@ -336,6 +409,9 @@ def find_module_root(file_path: str) -> str | None:

Returns the directory path, or ``None`` if no build file is found before
reaching the filesystem root. Results are cached by parent directory.

**Note:** cache entries are never invalidated. Build files added after the
first lookup for a given directory will not be detected until process restart.
"""
return _cached_module_root(str(Path(file_path).parent))

Expand Down Expand Up @@ -373,10 +449,10 @@ def __init__(self, on_diagnostics: Callable[[str, list[Any]], None] | None = Non
self._start_failed = False
self._jdtls_on_path = False
self._lazy_start_fired = False
self._queued_notifications: list[tuple[str, Any]] = []
self._queued_notifications: deque[tuple[str, Any]] = deque(maxlen=_MAX_QUEUED_NOTIFICATIONS)
self._original_root_uri: str | None = None
self._initial_module_uri: str | None = None
self._added_module_uris: set[str] = set()
self.modules = ModuleRegistry()
self._workspace_expanded = False

@property
Expand Down Expand Up @@ -437,9 +513,9 @@ async def start(self, init_params: dict[str, Any], *, module_root_uri: str | Non
ws = caps.setdefault("workspace", {})
ws["workspaceFolders"] = True

# Track the initial module as already loaded.
# Track the initial module as already loaded (mark ADDED before await).
self._initial_module_uri = module_root_uri
self._added_module_uris.add(effective_root_uri)
self.modules.mark_added(effective_root_uri)

# Build a clean environment for jdtls.
loop = asyncio.get_running_loop()
Expand Down Expand Up @@ -508,33 +584,45 @@ async def ensure_started(self, init_params: dict[str, Any], file_uri: str) -> bo
self._start_failed = True
self._queued_notifications.clear()
return started
except Exception:
self._start_failed = True
self._queued_notifications.clear()
raise
finally:
self._starting = False

def queue_notification(self, method: str, params: Any) -> None:
"""Buffer a notification for replay after jdtls starts.

Capped at ``_MAX_QUEUED_NOTIFICATIONS`` to prevent unbounded memory
growth during long jdtls startup. Oldest entries are dropped on overflow.
Uses a ``deque(maxlen=200)`` so oldest entries are dropped in O(1)
when the queue overflows during long jdtls startup.
"""
if len(self._queued_notifications) >= _MAX_QUEUED_NOTIFICATIONS:
self._queued_notifications.pop(0)
self._queued_notifications.append((method, params))

async def flush_queued_notifications(self) -> None:
"""Send all queued notifications to jdtls."""
queue, self._queued_notifications = self._queued_notifications, []
queue = list(self._queued_notifications)
self._queued_notifications.clear()
for method, params in queue:
await self.send_notification(method, params)

async def add_module_if_new(self, file_uri: str) -> None:
"""Add the module containing *file_uri* to jdtls if not already added."""
async def add_module_if_new(self, file_uri: str) -> str | None:
"""Add the module containing *file_uri* to jdtls if not already added.

Returns the module URI if a new module was added (UNKNOWN → ADDED),
or ``None`` if already known or unavailable. The returned URI can be
used with ``modules.wait_until_ready()`` for adaptive waiting.

Calls ``modules.mark_added()`` before any ``await`` to prevent
duplicate add calls from concurrent coroutines.
"""
if not self._available:
return
return None
module_uri = _resolve_module_uri(file_uri)
if module_uri is None or module_uri in self._added_module_uris:
return
self._added_module_uris.add(module_uri)
if module_uri is None or self.modules.was_added(module_uri):
return None
# Mark ADDED before await — atomic in asyncio, prevents duplicate sends.
self.modules.mark_added(module_uri)
from pygls.uris import to_fs_path

logger.info("jdtls: adding module %s", _redact_path(to_fs_path(module_uri)))
Expand All @@ -543,6 +631,7 @@ async def add_module_if_new(self, file_uri: str) -> None:
_WORKSPACE_DID_CHANGE_FOLDERS,
{"event": {"added": [{"uri": module_uri, "name": mod_name}], "removed": []}},
)
return module_uri

async def expand_full_workspace(self) -> None:
"""Expand jdtls workspace to the full monorepo root (background task).
Expand All @@ -556,10 +645,10 @@ async def expand_full_workspace(self) -> None:

root_path = to_fs_path(self._original_root_uri) or self._original_root_uri
root_uri = from_fs_path(root_path) or self._original_root_uri
if root_uri in self._added_module_uris:
if self.modules.was_added(root_uri):
self._workspace_expanded = True
return
self._added_module_uris.add(root_uri)
self.modules.mark_added(root_uri)

# Remove initial module folder to avoid double-indexing.
removed: list[dict[str, str]] = []
Expand Down
70 changes: 53 additions & 17 deletions src/java_functional_lsp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from .analyzers.null_checker import NullChecker
from .analyzers.spring_checker import SpringChecker
from .fixes import get_fix, get_fix_registry_keys
from .proxy import JdtlsProxy
from .proxy import JdtlsProxy, _resolve_module_uri

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -337,7 +337,7 @@ def _forward_or_queue(method: str, serialized: Any) -> None:
task = asyncio.create_task(server._proxy.send_notification(method, serialized))
_bg_tasks.add(task)
task.add_done_callback(_bg_tasks.discard)
elif server._proxy._starting:
elif server._proxy._lazy_start_fired and not server._proxy._start_failed:
server._proxy.queue_notification(method, serialized)


Expand Down Expand Up @@ -440,11 +440,55 @@ async def _expand_workspace_background() -> None:
# they only activate after jdtls starts.


async def _ensure_module_and_forward(method: str, params: Any, file_uri: str) -> Any | None:
"""Forward a request to jdtls, ensuring the file's module is loaded.

Uses ``ModuleRegistry`` for adaptive waiting:
- **READY**: forward immediately (zero overhead on hot path)
- **UNKNOWN**: add module, wait until ready (adaptive, not fixed sleep)
- **ADDED**: module sent but not confirmed — wait until ready

When a request succeeds, marks the module as READY so subsequent
requests skip the wait entirely.
"""
proxy = server._proxy
if not proxy.is_available:
return None

module_uri = _resolve_module_uri(file_uri)

# Hot path: module already confirmed working.
if module_uri and proxy.modules.is_ready(module_uri):
return await proxy.send_request(method, _serialize_params(params))

# Cold path: add module if unknown, then wait for ready.
new_module_uri = await proxy.add_module_if_new(file_uri)

serialized = _serialize_params(params)
result = await proxy.send_request(method, serialized)

if result is not None:
# Success — mark module as ready so future requests are instant.
if module_uri:
proxy.modules.mark_ready(module_uri)
return result

# Null result and module is not yet ready — wait then retry once.
# Use a short timeout (5s) so single-caller case doesn't block for 30s.
# If a concurrent request succeeds, Event.set() wakes us early.
wait_uri = new_module_uri or module_uri
if wait_uri and not proxy.modules.is_ready(wait_uri):
await proxy.modules.wait_until_ready(wait_uri, timeout=5.0)
# Always retry once after waiting — even on timeout the module may be ready.
result = await proxy.send_request(method, serialized)
if result is not None and module_uri:
proxy.modules.mark_ready(module_uri)
return result


async def _on_completion(params: lsp.CompletionParams) -> lsp.CompletionList | None:
"""Forward completion request to jdtls."""
if not server._proxy.is_available:
return None
result = await server._proxy.send_request("textDocument/completion", _serialize_params(params))
result = await _ensure_module_and_forward("textDocument/completion", params, params.text_document.uri)
if result is None:
return None
try:
Expand All @@ -455,9 +499,7 @@ async def _on_completion(params: lsp.CompletionParams) -> lsp.CompletionList | N

async def _on_hover(params: lsp.HoverParams) -> lsp.Hover | None:
"""Forward hover request to jdtls."""
if not server._proxy.is_available:
return None
result = await server._proxy.send_request("textDocument/hover", _serialize_params(params))
result = await _ensure_module_and_forward("textDocument/hover", params, params.text_document.uri)
if result is None:
return None
try:
Expand All @@ -468,9 +510,7 @@ async def _on_hover(params: lsp.HoverParams) -> lsp.Hover | None:

async def _on_definition(params: lsp.DefinitionParams) -> list[lsp.Location] | None:
"""Forward go-to-definition request to jdtls."""
if not server._proxy.is_available:
return None
result = await server._proxy.send_request("textDocument/definition", _serialize_params(params))
result = await _ensure_module_and_forward("textDocument/definition", params, params.text_document.uri)
if result is None:
return None
try:
Expand All @@ -483,9 +523,7 @@ async def _on_definition(params: lsp.DefinitionParams) -> list[lsp.Location] | N

async def _on_references(params: lsp.ReferenceParams) -> list[lsp.Location] | None:
"""Forward find-references request to jdtls."""
if not server._proxy.is_available:
return None
result = await server._proxy.send_request("textDocument/references", _serialize_params(params))
result = await _ensure_module_and_forward("textDocument/references", params, params.text_document.uri)
if result is None:
return None
try:
Expand All @@ -496,9 +534,7 @@ async def _on_references(params: lsp.ReferenceParams) -> list[lsp.Location] | No

async def _on_document_symbol(params: lsp.DocumentSymbolParams) -> list[lsp.DocumentSymbol] | None:
"""Forward document symbol request to jdtls."""
if not server._proxy.is_available:
return None
result = await server._proxy.send_request("textDocument/documentSymbol", _serialize_params(params))
result = await _ensure_module_and_forward("textDocument/documentSymbol", params, params.text_document.uri)
if result is None:
return None
try:
Expand Down
Loading
Loading