diff --git a/litresearch.toml.example b/litresearch.toml.example
index c514228..d29add2 100644
--- a/litresearch.toml.example
+++ b/litresearch.toml.example
@@ -2,6 +2,8 @@ default_model = "openrouter/openai/gpt-4o-mini"
 screening_threshold = 60
 top_n = 5
 max_results_per_query = 5
+s2_timeout = 10
+s2_requests_per_second = 1.0
 pdf_first_pages = 4
 pdf_last_pages = 2
 output_dir = "output-smoke"
diff --git a/src/litresearch/config.py b/src/litresearch/config.py
index e3a3f22..19c851c 100644
--- a/src/litresearch/config.py
+++ b/src/litresearch/config.py
@@ -36,6 +36,7 @@ def settings_customise_sources(
     openrouter_api_key: str | None = None
     s2_api_key: str | None = None
     s2_timeout: int = 10  # seconds; SemanticScholar client timeout
+    s2_requests_per_second: float = 1.0  # max S2 request rate across endpoints
     default_model: str = "openai/gpt-4o-mini"
     screening_threshold: int = 60  # 0-100; papers below this are filtered before analysis
     top_n: int = 20
diff --git a/src/litresearch/stages/discovery.py b/src/litresearch/stages/discovery.py
index 2659b8d..102ce69 100644
--- a/src/litresearch/stages/discovery.py
+++ b/src/litresearch/stages/discovery.py
@@ -1,5 +1,6 @@
 """Stage 2: paper discovery."""
 
+import time
 from typing import Any, cast
 
 from rich.console import Console
@@ -34,9 +35,20 @@ def run(state: PipelineState, settings: Settings) -> PipelineState:
         )
     else:
         scholar = SemanticScholar(timeout=settings.s2_timeout, retry=False)
+
+    min_interval = 0.0
+    if settings.s2_requests_per_second > 0:
+        min_interval = 1.0 / settings.s2_requests_per_second
+    last_request_at: float | None = None
+
     papers_by_id: dict[str, Paper] = {}
 
     for search_query in track(state.search_queries, description="Discovering papers"):
+        if last_request_at is not None and min_interval > 0:
+            elapsed = time.monotonic() - last_request_at
+            if elapsed < min_interval:
+                time.sleep(min_interval - elapsed)
+
         try:
             results = scholar.search_paper(
                 search_query.query,
@@ -45,8 +57,11 @@ def run(state: PipelineState, settings: Settings) -> PipelineState:
             )
         except Exception as exc:  # noqa: BLE001
             console.print(f"[yellow]Search failed:[/yellow] {search_query.query} ({exc})")
+            last_request_at = time.monotonic()
             continue
 
+        last_request_at = time.monotonic()
+
         paginated_results = cast(Any, results)
         for result in paginated_results.items:
             paper = Paper.from_s2(result)
diff --git a/src/litresearch/stages/enrichment.py b/src/litresearch/stages/enrichment.py
index 5e522da..14365ef 100644
--- a/src/litresearch/stages/enrichment.py
+++ b/src/litresearch/stages/enrichment.py
@@ -1,5 +1,6 @@
 """Stage 3: metadata enrichment."""
 
+import time
 from typing import Any, cast
 
 from rich.console import Console
@@ -43,14 +44,27 @@ def run(state: PipelineState, settings: Settings) -> PipelineState:
     else:
         scholar = SemanticScholar(timeout=settings.s2_timeout, retry=False)
 
+    min_interval = 0.0
+    if settings.s2_requests_per_second > 0:
+        min_interval = 1.0 / settings.s2_requests_per_second
+    last_request_at: float | None = None
+
     papers_by_id = {paper.paper_id: paper for paper in state.candidates}
     for batch in _chunk(list(papers_by_id), BATCH_SIZE):
+        if last_request_at is not None and min_interval > 0:
+            elapsed = time.monotonic() - last_request_at
+            if elapsed < min_interval:
+                time.sleep(min_interval - elapsed)
+
         try:
             results = scholar.get_papers(batch, fields=ENRICHMENT_FIELDS)
         except Exception as exc:  # noqa: BLE001
             console.print(f"[yellow]Enrichment failed:[/yellow] {exc}")
+            last_request_at = time.monotonic()
             continue
 
+        last_request_at = time.monotonic()
+
         for result in cast(list[Any], results):
             enriched = Paper.from_s2(result)
             papers_by_id[enriched.paper_id] = papers_by_id[enriched.paper_id].model_copy(
diff --git a/tests/unit/test_stages_discovery.py b/tests/unit/test_stages_discovery.py
index 5460d29..c0ef54d 100644
--- a/tests/unit/test_stages_discovery.py
+++ b/tests/unit/test_stages_discovery.py
@@ -3,6 +3,8 @@
 from types import SimpleNamespace
 from unittest.mock import MagicMock, patch
 
+import pytest
+
 from litresearch.config import Settings
 from litresearch.models import PipelineState, SearchQuery
 from litresearch.stages.discovery import run
@@ -110,3 +112,39 @@ def test_paper_deduplication_by_id(self, tmp_path) -> None:
 
         assert len(result.candidates) == 1
         assert result.candidates[0].paper_id == "same-id"
+
+    def test_rate_limit_waits_between_requests(self, tmp_path) -> None:
+        """Test discovery throttles requests to configured RPS."""
+        settings = Settings(
+            s2_api_key=None,
+            s2_timeout=10,
+            s2_requests_per_second=1.0,
+            max_results_per_query=10,
+        )
+
+        query1 = SearchQuery(query="query1", facet="Facet1")
+        query2 = SearchQuery(query="query2", facet="Facet2")
+        state = PipelineState(
+            questions=["Question?"],
+            search_queries=[query1, query2],
+            current_stage="query_gen",
+            output_dir=str(tmp_path),
+            created_at="2024-01-01",
+            updated_at="2024-01-01",
+        )
+
+        mock_scholar = MagicMock()
+        mock_scholar.search_paper.return_value = SimpleNamespace(items=[])
+
+        with (
+            patch("litresearch.stages.discovery.SemanticScholar", return_value=mock_scholar),
+            patch(
+                "litresearch.stages.discovery.time.monotonic",
+                side_effect=[0.0, 0.2, 0.3, 1.3],
+            ),
+            patch("litresearch.stages.discovery.time.sleep") as mock_sleep,
+        ):
+            run(state, settings)
+
+        mock_sleep.assert_called_once()
+        assert mock_sleep.call_args.args[0] == pytest.approx(0.8)