Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions litresearch.toml.example
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ default_model = "openrouter/openai/gpt-4o-mini"
screening_threshold = 60
top_n = 5
max_results_per_query = 5
s2_timeout = 10
s2_requests_per_second = 1.0
pdf_first_pages = 4
pdf_last_pages = 2
output_dir = "output-smoke"
1 change: 1 addition & 0 deletions src/litresearch/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def settings_customise_sources(
openrouter_api_key: str | None = None
s2_api_key: str | None = None
s2_timeout: int = 10 # seconds; SemanticScholar client timeout
s2_requests_per_second: float = 1.0 # max S2 request rate across endpoints
default_model: str = "openai/gpt-4o-mini"
screening_threshold: int = 60 # 0-100; papers below this are filtered before analysis
top_n: int = 20
Expand Down
15 changes: 15 additions & 0 deletions src/litresearch/stages/discovery.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Stage 2: paper discovery."""

import time
from typing import Any, cast

from rich.console import Console
Expand Down Expand Up @@ -34,9 +35,20 @@ def run(state: PipelineState, settings: Settings) -> PipelineState:
)
else:
scholar = SemanticScholar(timeout=settings.s2_timeout, retry=False)

min_interval = 0.0
if settings.s2_requests_per_second > 0:
min_interval = 1.0 / settings.s2_requests_per_second
last_request_at: float | None = None

papers_by_id: dict[str, Paper] = {}

for search_query in track(state.search_queries, description="Discovering papers"):
if last_request_at is not None and min_interval > 0:
elapsed = time.monotonic() - last_request_at
if elapsed < min_interval:
time.sleep(min_interval - elapsed)

try:
results = scholar.search_paper(
search_query.query,
Expand All @@ -45,8 +57,11 @@ def run(state: PipelineState, settings: Settings) -> PipelineState:
)
except Exception as exc: # noqa: BLE001
console.print(f"[yellow]Search failed:[/yellow] {search_query.query} ({exc})")
last_request_at = time.monotonic()
continue

last_request_at = time.monotonic()

paginated_results = cast(Any, results)
for result in paginated_results.items:
paper = Paper.from_s2(result)
Expand Down
14 changes: 14 additions & 0 deletions src/litresearch/stages/enrichment.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Stage 3: metadata enrichment."""

import time
from typing import Any, cast

from rich.console import Console
Expand Down Expand Up @@ -43,14 +44,27 @@ def run(state: PipelineState, settings: Settings) -> PipelineState:
else:
scholar = SemanticScholar(timeout=settings.s2_timeout, retry=False)

min_interval = 0.0
if settings.s2_requests_per_second > 0:
min_interval = 1.0 / settings.s2_requests_per_second
last_request_at: float | None = None

papers_by_id = {paper.paper_id: paper for paper in state.candidates}
for batch in _chunk(list(papers_by_id), BATCH_SIZE):
if last_request_at is not None and min_interval > 0:
elapsed = time.monotonic() - last_request_at
if elapsed < min_interval:
time.sleep(min_interval - elapsed)

try:
results = scholar.get_papers(batch, fields=ENRICHMENT_FIELDS)
except Exception as exc: # noqa: BLE001
console.print(f"[yellow]Enrichment failed:[/yellow] {exc}")
last_request_at = time.monotonic()
continue

last_request_at = time.monotonic()

for result in cast(list[Any], results):
enriched = Paper.from_s2(result)
papers_by_id[enriched.paper_id] = papers_by_id[enriched.paper_id].model_copy(
Expand Down
38 changes: 38 additions & 0 deletions tests/unit/test_stages_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

import pytest

from litresearch.config import Settings
from litresearch.models import PipelineState, SearchQuery
from litresearch.stages.discovery import run
Expand Down Expand Up @@ -110,3 +112,39 @@ def test_paper_deduplication_by_id(self, tmp_path) -> None:

assert len(result.candidates) == 1
assert result.candidates[0].paper_id == "same-id"

def test_rate_limit_waits_between_requests(self, tmp_path) -> None:
    """Check that two consecutive searches at 1 request/second cause exactly one sleep.

    The clock and ``sleep`` used by the discovery stage are replaced with
    mocks, so the test never actually waits; it only inspects the delay the
    stage asked for.
    """
    # A client that always returns an empty result page, so the run is about
    # request pacing only and no papers need to be parsed.
    fake_client = MagicMock()
    fake_client.search_paper.return_value = SimpleNamespace(items=[])

    pipeline_state = PipelineState(
        questions=["Question?"],
        search_queries=[
            SearchQuery(query="query1", facet="Facet1"),
            SearchQuery(query="query2", facet="Facet2"),
        ],
        current_stage="query_gen",
        output_dir=str(tmp_path),
        created_at="2024-01-01",
        updated_at="2024-01-01",
    )
    throttled_settings = Settings(
        s2_api_key=None,
        s2_timeout=10,
        s2_requests_per_second=1.0,
        max_results_per_query=10,
    )

    # Scripted monotonic readings. The assertion below expects the throttle
    # check before the second query to see 0.2s elapsed against a 1.0s
    # minimum interval, i.e. a 0.8s sleep. The list has more readings than
    # that scenario needs, so an extra clock read does not exhaust the mock.
    clock_readings = [0.0, 0.2, 0.3, 1.3]

    with patch("litresearch.stages.discovery.SemanticScholar", return_value=fake_client):
        with patch(
            "litresearch.stages.discovery.time.monotonic",
            side_effect=clock_readings,
        ):
            with patch("litresearch.stages.discovery.time.sleep") as sleep_spy:
                run(pipeline_state, throttled_settings)

    # Exactly one pause, covering the remainder of the one-second interval.
    sleep_spy.assert_called_once()
    requested_delay = sleep_spy.call_args.args[0]
    assert requested_delay == pytest.approx(0.8)