Skip to content

Commit 89f6a1b

Browse files
authored
fix(loader): close qdrant client after run (#146)
1 parent 8cbae1f commit 89f6a1b

File tree

2 files changed

+101
-81
lines changed

2 files changed

+101
-81
lines changed

mcp_plex/loader/__init__.py

Lines changed: 79 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -364,86 +364,89 @@ async def run(
364364
https=qdrant_https,
365365
prefer_grpc=qdrant_prefer_grpc,
366366
)
367-
collection_name = "media-items"
368-
await _ensure_collection(
369-
client,
370-
collection_name,
371-
dense_size=dense_size,
372-
dense_distance=dense_distance,
373-
)
374-
375-
items: list[AggregatedItem]
376-
if sample_dir is not None:
377-
logger.info("Loading sample data from %s", sample_dir)
378-
sample_items = samples._load_from_sample(sample_dir)
379-
orchestrator, items, qdrant_retry_queue = _build_loader_orchestrator(
380-
client=client,
381-
collection_name=collection_name,
382-
dense_model_name=dense_model_name,
383-
sparse_model_name=sparse_model_name,
384-
tmdb_api_key=None,
385-
sample_items=sample_items,
386-
plex_server=None,
387-
plex_chunk_size=plex_chunk_size,
388-
enrichment_batch_size=enrichment_batch_size,
389-
enrichment_workers=enrichment_workers,
390-
upsert_buffer_size=upsert_buffer_size,
391-
max_concurrent_upserts=max_concurrent_upserts,
392-
imdb_config=IMDbRuntimeConfig(
393-
cache=imdb_config.cache,
394-
max_retries=imdb_config.max_retries,
395-
backoff=imdb_config.backoff,
396-
retry_queue=IMDbRetryQueue(),
397-
requests_per_window=imdb_config.requests_per_window,
398-
window_seconds=imdb_config.window_seconds,
399-
),
400-
qdrant_config=qdrant_config,
367+
try:
368+
collection_name = "media-items"
369+
await _ensure_collection(
370+
client,
371+
collection_name,
372+
dense_size=dense_size,
373+
dense_distance=dense_distance,
401374
)
402-
logger.info("Starting staged loader (sample mode)")
403-
await orchestrator.run()
404-
else:
405-
if PlexServer is None:
406-
raise RuntimeError("plexapi is required for live loading")
407-
if not plex_url or not plex_token:
408-
raise RuntimeError("PLEX_URL and PLEX_TOKEN must be provided")
409-
if not tmdb_api_key:
410-
raise RuntimeError("TMDB_API_KEY must be provided")
411-
logger.info("Loading data from Plex server %s", plex_url)
412-
server = PlexServer(plex_url, plex_token)
413-
orchestrator, items, qdrant_retry_queue = _build_loader_orchestrator(
414-
client=client,
415-
collection_name=collection_name,
416-
dense_model_name=dense_model_name,
417-
sparse_model_name=sparse_model_name,
418-
tmdb_api_key=tmdb_api_key,
419-
sample_items=None,
420-
plex_server=server,
421-
plex_chunk_size=plex_chunk_size,
422-
enrichment_batch_size=enrichment_batch_size,
423-
enrichment_workers=enrichment_workers,
424-
upsert_buffer_size=upsert_buffer_size,
425-
max_concurrent_upserts=max_concurrent_upserts,
426-
imdb_config=imdb_config,
427-
qdrant_config=qdrant_config,
428-
)
429-
logger.info("Starting staged loader (Plex mode)")
430-
await orchestrator.run()
431-
logger.info("Loaded %d items", len(items))
432-
if not items:
433-
logger.info("No points to upsert")
434375

435-
await _process_qdrant_retry_queue(
436-
client,
437-
collection_name,
438-
qdrant_retry_queue,
439-
config=qdrant_config,
440-
)
376+
items: list[AggregatedItem]
377+
if sample_dir is not None:
378+
logger.info("Loading sample data from %s", sample_dir)
379+
sample_items = samples._load_from_sample(sample_dir)
380+
orchestrator, items, qdrant_retry_queue = _build_loader_orchestrator(
381+
client=client,
382+
collection_name=collection_name,
383+
dense_model_name=dense_model_name,
384+
sparse_model_name=sparse_model_name,
385+
tmdb_api_key=None,
386+
sample_items=sample_items,
387+
plex_server=None,
388+
plex_chunk_size=plex_chunk_size,
389+
enrichment_batch_size=enrichment_batch_size,
390+
enrichment_workers=enrichment_workers,
391+
upsert_buffer_size=upsert_buffer_size,
392+
max_concurrent_upserts=max_concurrent_upserts,
393+
imdb_config=IMDbRuntimeConfig(
394+
cache=imdb_config.cache,
395+
max_retries=imdb_config.max_retries,
396+
backoff=imdb_config.backoff,
397+
retry_queue=IMDbRetryQueue(),
398+
requests_per_window=imdb_config.requests_per_window,
399+
window_seconds=imdb_config.window_seconds,
400+
),
401+
qdrant_config=qdrant_config,
402+
)
403+
logger.info("Starting staged loader (sample mode)")
404+
await orchestrator.run()
405+
else:
406+
if PlexServer is None:
407+
raise RuntimeError("plexapi is required for live loading")
408+
if not plex_url or not plex_token:
409+
raise RuntimeError("PLEX_URL and PLEX_TOKEN must be provided")
410+
if not tmdb_api_key:
411+
raise RuntimeError("TMDB_API_KEY must be provided")
412+
logger.info("Loading data from Plex server %s", plex_url)
413+
server = PlexServer(plex_url, plex_token)
414+
orchestrator, items, qdrant_retry_queue = _build_loader_orchestrator(
415+
client=client,
416+
collection_name=collection_name,
417+
dense_model_name=dense_model_name,
418+
sparse_model_name=sparse_model_name,
419+
tmdb_api_key=tmdb_api_key,
420+
sample_items=None,
421+
plex_server=server,
422+
plex_chunk_size=plex_chunk_size,
423+
enrichment_batch_size=enrichment_batch_size,
424+
enrichment_workers=enrichment_workers,
425+
upsert_buffer_size=upsert_buffer_size,
426+
max_concurrent_upserts=max_concurrent_upserts,
427+
imdb_config=imdb_config,
428+
qdrant_config=qdrant_config,
429+
)
430+
logger.info("Starting staged loader (Plex mode)")
431+
await orchestrator.run()
432+
logger.info("Loaded %d items", len(items))
433+
if not items:
434+
logger.info("No points to upsert")
435+
436+
await _process_qdrant_retry_queue(
437+
client,
438+
collection_name,
439+
qdrant_retry_queue,
440+
config=qdrant_config,
441+
)
441442

442-
if imdb_queue_path:
443-
_persist_imdb_retry_queue(imdb_queue_path, imdb_config.retry_queue)
443+
if imdb_queue_path:
444+
_persist_imdb_retry_queue(imdb_queue_path, imdb_config.retry_queue)
444445

445-
json.dump([item.model_dump(mode="json") for item in items], fp=sys.stdout, indent=2)
446-
sys.stdout.write("\n")
446+
json.dump([item.model_dump(mode="json") for item in items], fp=sys.stdout, indent=2)
447+
sys.stdout.write("\n")
448+
finally:
449+
await client.close()
447450

448451

449452
async def load_media(

tests/test_loader_integration.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ class CaptureClient(AsyncQdrantClient):
1717
captured_points: list[models.PointStruct] = []
1818
upsert_calls: int = 0
1919
created_indexes: list[tuple[str, Any]] = []
20+
close_calls: int = 0
2021

2122
def __init__(self, *args, **kwargs):
2223
super().__init__(*args, **kwargs)
@@ -44,6 +45,10 @@ async def create_payload_index(
4445
wait=wait,
4546
)
4647

48+
async def close(self) -> None:
    """Record one more close on the class-level tally, then close the client.

    The tally is kept on ``CaptureClient`` itself (not on the instance) so
    it can be read through the class after the run has finished.
    """
    CaptureClient.close_calls = CaptureClient.close_calls + 1
    await super().close()
51+
4752

4853
async def _run_loader(sample_dir: Path, **kwargs) -> None:
4954
await loader.run(
@@ -62,6 +67,7 @@ def test_run_writes_points(monkeypatch):
6267
CaptureClient.captured_points = []
6368
CaptureClient.upsert_calls = 0
6469
CaptureClient.created_indexes = []
70+
CaptureClient.close_calls = 0
6571
sample_dir = Path(__file__).resolve().parents[1] / "sample-data"
6672
asyncio.run(_run_loader(sample_dir))
6773
client = CaptureClient.instance
@@ -70,9 +76,6 @@ def test_run_writes_points(monkeypatch):
7076
assert index_map.get("show_title") == models.PayloadSchemaType.KEYWORD
7177
assert index_map.get("season_number") == models.PayloadSchemaType.INTEGER
7278
assert index_map.get("episode_number") == models.PayloadSchemaType.INTEGER
73-
points, _ = asyncio.run(client.scroll("media-items", limit=10, with_payload=True))
74-
assert len(points) == 2
75-
assert all("title" in p.payload and "type" in p.payload for p in points)
7679
captured = CaptureClient.captured_points
7780
assert len(captured) == 2
7881
assert all(isinstance(p.vector["dense"], models.Document) for p in captured)
@@ -85,7 +88,7 @@ def test_run_writes_points(monkeypatch):
8588
texts = [p.vector["dense"].text for p in captured]
8689
assert any("Directed by" in t for t in texts)
8790
assert any("Starring" in t for t in texts)
88-
movie_point = next(p for p in points if p.payload["type"] == "movie")
91+
movie_point = next(p for p in captured if p.payload["type"] == "movie")
8992
assert "directors" in movie_point.payload and "Guy Ritchie" in movie_point.payload["directors"]
9093
assert "writers" in movie_point.payload and movie_point.payload["writers"]
9194
assert "genres" in movie_point.payload and movie_point.payload["genres"]
@@ -94,7 +97,7 @@ def test_run_writes_points(monkeypatch):
9497
assert movie_point.payload.get("plot")
9598
assert movie_point.payload.get("tagline")
9699
assert movie_point.payload.get("reviews")
97-
episode_point = next(p for p in points if p.payload["type"] == "episode")
100+
episode_point = next(p for p in captured if p.payload["type"] == "episode")
98101
assert episode_point.payload.get("show_title") == "Alien: Earth"
99102
assert episode_point.payload.get("season_title") == "Season 1"
100103
assert episode_point.payload.get("season_number") == 1
@@ -110,6 +113,7 @@ def test_run_processes_imdb_queue(monkeypatch, tmp_path):
110113
monkeypatch.setattr(loader, "AsyncQdrantClient", CaptureClient)
111114
CaptureClient.captured_points = []
112115
CaptureClient.upsert_calls = 0
116+
CaptureClient.close_calls = 0
113117
queue_file = tmp_path / "queue.json"
114118
queue_file.write_text(json.dumps(["tt0111161"]))
115119
sample_dir = Path(__file__).resolve().parents[1] / "sample-data"
@@ -138,12 +142,25 @@ def test_run_upserts_in_batches(monkeypatch):
138142
monkeypatch.setattr(loader, "AsyncQdrantClient", CaptureClient)
139143
CaptureClient.captured_points = []
140144
CaptureClient.upsert_calls = 0
145+
CaptureClient.close_calls = 0
141146
sample_dir = Path(__file__).resolve().parents[1] / "sample-data"
142147
asyncio.run(_run_loader(sample_dir, qdrant_batch_size=1))
143148
assert CaptureClient.upsert_calls == 2
144149
assert len(CaptureClient.captured_points) == 2
145150

146151

152+
def test_run_closes_client_once(monkeypatch):
    """Check that every loader run closes its Qdrant client exactly one time.

    The loader is run twice against the sample data; the class-level close
    counter must grow by one per run (1 after the first, 2 after the second).
    """
    monkeypatch.setattr(loader, "AsyncQdrantClient", CaptureClient)
    CaptureClient.close_calls = 0
    sample_dir = Path(__file__).resolve().parents[1] / "sample-data"

    # The counter is cumulative across runs, so the expected value is the
    # number of runs completed so far.
    for expected_total in (1, 2):
        asyncio.run(_run_loader(sample_dir))
        assert CaptureClient.close_calls == expected_total
162+
163+
147164
def test_run_raises_for_unknown_dense_model():
148165
sample_dir = Path(__file__).resolve().parents[1] / "sample-data"
149166

0 commit comments

Comments
 (0)