From a6c4df24f81a862ecc2e843e93c03acfdabdd78a Mon Sep 17 00:00:00 2001 From: Silas Pignotti Date: Mon, 23 Mar 2026 16:00:31 +0100 Subject: [PATCH 1/6] fix: critical issues - JSON parsing, S2 timeout, PDF deduplication - Guard json.loads() in analysis.py with try/except JSONDecodeError - Add s2_timeout config setting (default 10s) with retry=False for S2 client - Prevent PDF double-download by saving during analysis and marking pdf_downloaded - Skip already-downloaded PDFs in export stage --- src/litresearch/config.py | 1 + src/litresearch/stages/analysis.py | 66 ++++++++++++++++++++-------- src/litresearch/stages/discovery.py | 8 +++- src/litresearch/stages/enrichment.py | 8 +++- src/litresearch/stages/export.py | 2 + tests/unit/test_analysis.py | 58 ++++++++++++++++++++++++ tests/unit/test_export.py | 50 +++++++++++++++++++++ 7 files changed, 171 insertions(+), 22 deletions(-) create mode 100644 tests/unit/test_analysis.py diff --git a/src/litresearch/config.py b/src/litresearch/config.py index 941f414..6ef3022 100644 --- a/src/litresearch/config.py +++ b/src/litresearch/config.py @@ -35,6 +35,7 @@ def settings_customise_sources( anthropic_api_key: str | None = None openrouter_api_key: str | None = None s2_api_key: str | None = None + s2_timeout: int = 10 # seconds; SemanticScholar client timeout default_model: str = "openai/gpt-4o-mini" screening_threshold: int = 40 top_n: int = 20 diff --git a/src/litresearch/stages/analysis.py b/src/litresearch/stages/analysis.py index 48da9cd..f77c545 100644 --- a/src/litresearch/stages/analysis.py +++ b/src/litresearch/stages/analysis.py @@ -1,6 +1,7 @@ """Stage 4: screening and extended paper analysis.""" import json +from pathlib import Path from rich.console import Console from rich.progress import track @@ -39,12 +40,16 @@ def _screen_paper( console.print(f"[yellow]Screening failed:[/yellow] {paper.title} ({exc})") return None - payload = json.loads(response) - return ScreeningResult( - paper_id=paper.paper_id, - relevance_score=payload["relevance_score"], - rationale=payload["rationale"], - ) + try: + payload = json.loads(response) + return ScreeningResult( + paper_id=paper.paper_id, + relevance_score=payload["relevance_score"], + rationale=payload["rationale"], + ) + except json.JSONDecodeError: + console.print(f"[yellow]JSON parse failed:[/yellow] {paper.title}") + return None def _analyze_paper( @@ -52,11 +57,17 @@ def _analyze_paper( questions: list[str], settings: Settings, prompt: str, -) -> AnalysisResult | None: + output_dir: str, +) -> tuple[AnalysisResult | None, bool]: pdf_text = "" + pdf_downloaded = False if paper.open_access_pdf_url: pdf_bytes = download_pdf(paper.open_access_pdf_url) if pdf_bytes is not None: + papers_dir = Path(output_dir) / "papers" + papers_dir.mkdir(parents=True, exist_ok=True) + (papers_dir / f"{paper.paper_id}.pdf").write_bytes(pdf_bytes) + pdf_downloaded = True pdf_text = extract_text( pdf_bytes, first_pages=settings.pdf_first_pages, @@ -84,23 +95,31 @@ def _analyze_paper( response = call_llm(settings, prompt, user_content) except LLMError as exc: console.print(f"[yellow]Analysis failed:[/yellow] {paper.title} ({exc})") - return None + return (None, pdf_downloaded) - payload = json.loads(response) - return AnalysisResult( - paper_id=paper.paper_id, - summary=payload["summary"], - key_findings=payload.get("key_findings", []), - methodology=payload["methodology"], - relevance_score=payload["relevance_score"], - relevance_rationale=payload["relevance_rationale"], - ) + try: + payload = json.loads(response) + return ( + AnalysisResult( + paper_id=paper.paper_id, + summary=payload["summary"], + key_findings=payload.get("key_findings", []), + methodology=payload["methodology"], + relevance_score=payload["relevance_score"], + relevance_rationale=payload["relevance_rationale"], + ), + pdf_downloaded, + ) + except json.JSONDecodeError: + console.print(f"[yellow]JSON parse failed:[/yellow] {paper.title}") + return (None, pdf_downloaded) def run(state: PipelineState, settings: Settings) -> PipelineState: """Screen candidate papers and analyze the relevant ones.""" screening_prompt = load_prompt("screening") analysis_prompt = load_prompt("analysis") + papers_by_id = {paper.paper_id: paper for paper in state.candidates} screening_results: list[ScreeningResult] = [] passed_papers: list[Paper] = [] @@ -118,12 +137,23 @@ def run(state: PipelineState, settings: Settings) -> PipelineState: analyses: list[AnalysisResult] = [] for paper in track(passed_papers, description="Analyzing papers"): - analysis_result = _analyze_paper(paper, state.questions, settings, analysis_prompt) + analysis_result, pdf_downloaded = _analyze_paper( + paper, + state.questions, + settings, + analysis_prompt, + state.output_dir, + ) + if pdf_downloaded: + papers_by_id[paper.paper_id] = paper.model_copy(update={"pdf_downloaded": True}) if analysis_result is not None: analyses.append(analysis_result) + updated_candidates = [papers_by_id[paper.paper_id] for paper in state.candidates] + return state.model_copy( update={ + "candidates": updated_candidates, "screening_results": screening_results, "analyses": analyses, "current_stage": "analysis", diff --git a/src/litresearch/stages/discovery.py b/src/litresearch/stages/discovery.py index 25aef21..2659b8d 100644 --- a/src/litresearch/stages/discovery.py +++ b/src/litresearch/stages/discovery.py @@ -27,9 +27,13 @@ def run(state: PipelineState, settings: Settings) -> PipelineState: """Discover candidate papers for the generated search queries.""" if settings.s2_api_key: - scholar = SemanticScholar(api_key=settings.s2_api_key) + scholar = SemanticScholar( + api_key=settings.s2_api_key, + timeout=settings.s2_timeout, + retry=False, + ) else: - scholar = SemanticScholar() + scholar = SemanticScholar(timeout=settings.s2_timeout, retry=False) papers_by_id: dict[str, Paper] = {} for search_query in track(state.search_queries, description="Discovering papers"): diff --git a/src/litresearch/stages/enrichment.py b/src/litresearch/stages/enrichment.py index b9025b7..72765ef 100644 --- a/src/litresearch/stages/enrichment.py +++ b/src/litresearch/stages/enrichment.py @@ -35,9 +35,13 @@ def run(state: PipelineState, settings: Settings) -> PipelineState: return state.model_copy(update={"current_stage": "enrichment"}) if settings.s2_api_key: - scholar = SemanticScholar(api_key=settings.s2_api_key) + scholar = SemanticScholar( + api_key=settings.s2_api_key, + timeout=settings.s2_timeout, + retry=False, + ) else: - scholar = SemanticScholar() + scholar = SemanticScholar(timeout=settings.s2_timeout, retry=False) papers_by_id = {paper.paper_id: paper for paper in state.candidates} for batch in _chunk(list(papers_by_id), BATCH_SIZE): diff --git a/src/litresearch/stages/export.py b/src/litresearch/stages/export.py index b87179a..76bab97 100644 --- a/src/litresearch/stages/export.py +++ b/src/litresearch/stages/export.py @@ -157,6 +157,8 @@ def run(state: PipelineState, settings: Settings) -> PipelineState: for paper in track(top_papers, description="Downloading PDFs"): if not paper.open_access_pdf_url: continue + if paper.pdf_downloaded: + continue pdf_bytes = download_pdf(paper.open_access_pdf_url) if pdf_bytes is None: continue diff --git a/tests/unit/test_analysis.py b/tests/unit/test_analysis.py new file mode 100644 index 0000000..c60473f --- /dev/null +++ b/tests/unit/test_analysis.py @@ -0,0 +1,58 @@ +import json + +from litresearch.config import Settings +from litresearch.models import Paper, PipelineState, ScreeningResult +from litresearch.stages.analysis import run + + +def test_analysis_saves_pdf_and_marks_candidate_downloaded(tmp_path, monkeypatch) -> None: + state = PipelineState( + questions=["q"], + candidates=[ + Paper( + paper_id="p1", + title="One", + abstract="abstract", + open_access_pdf_url="https://example.com/p1.pdf", + ) + ], + ranked_paper_ids=[], + current_stage="enrichment", + output_dir=str(tmp_path), + created_at="2026-03-09T16:00:00Z", + updated_at="2026-03-09T16:00:00Z", + ) + + import litresearch.stages.analysis as analysis_stage + + monkeypatch.setattr(analysis_stage, "load_prompt", lambda _name: "prompt") + monkeypatch.setattr( + analysis_stage, + "_screen_paper", + lambda paper, questions, settings, prompt: ScreeningResult( + paper_id=paper.paper_id, + relevance_score=100, + rationale="fit", + ), + ) + monkeypatch.setattr(analysis_stage, "download_pdf", lambda _url: b"%PDF-1.0") + monkeypatch.setattr(analysis_stage, "extract_text", lambda *_args, **_kwargs: "body") + monkeypatch.setattr( + analysis_stage, + "call_llm", + lambda settings, system_prompt, user_content: json.dumps( + { + "summary": "summary", + "key_findings": ["finding"], + "methodology": "experiment", + "relevance_score": 80, + "relevance_rationale": "fit", + } + ), + ) + + updated_state = run(state, Settings()) + + assert updated_state.candidates[0].pdf_downloaded is True + assert (tmp_path / "papers" / "p1.pdf").read_bytes() == b"%PDF-1.0" + assert len(updated_state.analyses) == 1 diff --git a/tests/unit/test_export.py b/tests/unit/test_export.py index 685b12f..fb4dfe7 100644 --- a/tests/unit/test_export.py +++ b/tests/unit/test_export.py @@ -62,3 +62,53 @@ def test_export_skips_missing_bibtex(tmp_path) -> None: bibtex = (tmp_path / "references.bib").read_text(encoding="utf-8") assert bibtex.strip() == "@article{p1}" + + +def test_export_skips_pdf_download_when_already_downloaded(tmp_path, monkeypatch) -> None: + state = PipelineState( + questions=["q"], + candidates=[ + Paper( + paper_id="p1", + title="One", + open_access_pdf_url="https://example.com/p1.pdf", + pdf_downloaded=True, + ) + ], + analyses=[ + AnalysisResult( + paper_id="p1", + summary="summary", + key_findings=["finding"], + methodology="experiment", + relevance_score=80, + relevance_rationale="fit", + ) + ], + ranked_paper_ids=["p1"], + current_stage="ranking", + output_dir=str(tmp_path), + created_at="2026-03-09T16:00:00Z", + updated_at="2026-03-09T16:00:00Z", + ) + + import litresearch.stages.export as export_stage + + monkeypatch.setattr( + export_stage, + "call_llm", + lambda settings, system_prompt, user_content, expect_json=False: "## Consensus\n\nDone.", + ) + + download_calls = 0 + + def fake_download(_url: str): + nonlocal download_calls + download_calls += 1 + return b"%PDF-1.0" + + monkeypatch.setattr(export_stage, "download_pdf", fake_download) + + run(state, Settings()) + + assert download_calls == 0 From e1c37b1cf7e9e9e0bfc824631f6e91798611af70 Mon Sep 17 00:00:00 2001 From: Silas Pignotti Date: Mon, 23 Mar 2026 16:03:04 +0100 Subject: [PATCH 2/6] fix(cli): immutable settings construction and output collision handling - Refactor _build_settings to use immutable Settings(**overrides) pattern - Add --overwrite flag to run command - Auto-increment output directory name when directory exists and is populated - Add tests for collision detection and overwrite behavior --- src/litresearch/cli.py | 27 ++++++++++++-------- src/litresearch/pipeline.py | 9 +++++++ tests/unit/test_cli.py | 1 + tests/unit/test_pipeline.py | 50 +++++++++++++++++++++++++++++++++++++ 4 files changed, 76 insertions(+), 11 deletions(-) create mode 100644 tests/unit/test_pipeline.py diff --git a/src/litresearch/cli.py b/src/litresearch/cli.py index 3e3412c..cf659fa 100644 --- a/src/litresearch/cli.py +++ b/src/litresearch/cli.py @@ -22,16 +22,17 @@ def _build_settings( threshold: int | None = None, ) -> Settings: """Load settings and apply CLI overrides.""" - settings = Settings() - if model is not None: - settings.default_model = model - if top_n is not None: - settings.top_n = top_n - if output_dir is not None: - settings.output_dir = output_dir - if threshold is not None: - settings.screening_threshold = threshold - return settings + overrides = { + key: value + for key, value in { + "default_model": model, + "top_n": top_n, + "output_dir": output_dir, + "screening_threshold": threshold, + }.items() + if value is not None + } + return Settings(**overrides) @app.command() @@ -65,6 +66,10 @@ def run( int | None, typer.Option("--threshold", help="Override the screening threshold."), ] = None, + overwrite: Annotated[ + bool, + typer.Option("--overwrite", help="Overwrite existing output directory."), + ] = False, ) -> None: """Run the literature research pipeline.""" settings = _build_settings( @@ -74,7 +79,7 @@ def run( threshold=threshold, ) - state = run_pipeline(questions, settings) + state = run_pipeline(questions, settings, overwrite=overwrite) console.print(f"[green]Run complete.[/green] Output: {state.output_dir}") diff --git a/src/litresearch/pipeline.py b/src/litresearch/pipeline.py index 8e9a8f9..bc34fde 100644 --- a/src/litresearch/pipeline.py +++ b/src/litresearch/pipeline.py @@ -32,6 +32,7 @@ def run_pipeline( questions: list[str], settings: Settings, resume_path: Path | None = None, + overwrite: bool = False, ) -> PipelineState: """Run the configured pipeline from scratch or from a saved state.""" if resume_path is not None: @@ -40,6 +41,14 @@ def run_pipeline( start_index = STAGE_ORDER.index(state.current_stage) + 1 else: output_dir = Path(settings.output_dir) + if output_dir.exists() and any(output_dir.iterdir()) and not overwrite: + base_name = output_dir.name + parent = output_dir.parent + counter = 2 + while output_dir.exists() and any(output_dir.iterdir()): + output_dir = parent / f"{base_name}-{counter}" + counter += 1 + console.print(f"[yellow]Output directory already exists. Using:[/yellow] {output_dir}") state = PipelineState( questions=questions, current_stage="start", diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index e9134b9..ad58db3 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -28,6 +28,7 @@ def test_run_help_shows_expected_options() -> None: assert "final top-N cutoff" in output assert "output directory" in output assert "screening threshold" in output + assert "Overwrite existing output directory" in output def test_resume_help_shows_expected_options() -> None: diff --git a/tests/unit/test_pipeline.py b/tests/unit/test_pipeline.py new file mode 100644 index 0000000..5482e1d --- /dev/null +++ b/tests/unit/test_pipeline.py @@ -0,0 +1,50 @@ +from pathlib import Path + +from litresearch import pipeline +from litresearch.config import Settings + + +def test_run_pipeline_auto_increments_non_empty_output_dir(tmp_path: Path, monkeypatch) -> None: + monkeypatch.setattr(pipeline, "STAGE_ORDER", []) + + base_output = tmp_path / "output" + base_output.mkdir() + (base_output / "existing.txt").write_text("data", encoding="utf-8") + + state = pipeline.run_pipeline( + questions=["q"], + settings=Settings(output_dir=str(base_output)), + ) + + assert state.output_dir == str(tmp_path / "output-2") + assert (tmp_path / "output-2").exists() + + +def test_run_pipeline_keeps_output_dir_when_overwrite_enabled(tmp_path: Path, monkeypatch) -> None: + monkeypatch.setattr(pipeline, "STAGE_ORDER", []) + + base_output = tmp_path / "output" + base_output.mkdir() + (base_output / "existing.txt").write_text("data", encoding="utf-8") + + state = pipeline.run_pipeline( + questions=["q"], + settings=Settings(output_dir=str(base_output)), + overwrite=True, + ) + + assert state.output_dir == str(base_output) + + +def test_run_pipeline_keeps_empty_existing_output_dir(tmp_path: Path, monkeypatch) -> None: + monkeypatch.setattr(pipeline, "STAGE_ORDER", []) + + base_output = tmp_path / "output" + base_output.mkdir() + + state = pipeline.run_pipeline( + questions=["q"], + settings=Settings(output_dir=str(base_output)), + ) + + assert state.output_dir == str(base_output) From 993589f15bdd8e6acf68afa511a750e965ba41f3 Mon Sep 17 00:00:00 2001 From: Silas Pignotti Date: Mon, 23 Mar 2026 16:04:18 +0100 Subject: [PATCH 3/6] fix: handle no-abstract papers and LLMError in query_gen - Write ScreeningResult with score=0 for papers without abstract - Wrap call_llm in try/except LLMError in query_gen with clear error message --- src/litresearch/stages/analysis.py | 7 +++++++ src/litresearch/stages/query_gen.py | 7 +++++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/litresearch/stages/analysis.py b/src/litresearch/stages/analysis.py index f77c545..bb19ca1 100644 --- a/src/litresearch/stages/analysis.py +++ b/src/litresearch/stages/analysis.py @@ -125,6 +125,13 @@ def run(state: PipelineState, settings: Settings) -> PipelineState: passed_papers: list[Paper] = [] for paper in track(state.candidates, description="Screening papers"): if not paper.abstract: + screening_results.append( + ScreeningResult( + paper_id=paper.paper_id, + relevance_score=0, + rationale="no abstract available", + ) + ) continue screening_result = _screen_paper(paper, state.questions, settings, screening_prompt) diff --git a/src/litresearch/stages/query_gen.py b/src/litresearch/stages/query_gen.py index e83f2c3..06b4476 100644 --- a/src/litresearch/stages/query_gen.py +++ b/src/litresearch/stages/query_gen.py @@ -3,7 +3,7 @@ import json from litresearch.config import Settings -from litresearch.llm import call_llm +from litresearch.llm import LLMError, call_llm from litresearch.models import Facet, PipelineState, SearchQuery from litresearch.prompts import load_prompt @@ -14,7 +14,10 @@ def run(state: PipelineState, settings: Settings) -> PipelineState: user_content = "Research questions:\n" + "\n".join( f"- {question}" for question in state.questions ) - response = call_llm(settings, prompt, user_content) + try: + response = call_llm(settings, prompt, user_content) + except LLMError as exc: + raise LLMError(f"Query generation failed: {exc}") from exc payload = json.loads(response) facets = [Facet.model_validate(item) for item in payload.get("facets", [])] From 34e415d75b513eb95ed39d7766e86e7ee0f2d2d9 Mon Sep 17 00:00:00 2001 From: Silas Pignotti Date: Mon, 23 Mar 2026 16:05:24 +0100 Subject: [PATCH 4/6] fix: rename litresearch.toml to example and unescape HTML entities - Rename litresearch.toml to litresearch.toml.example (git mv) - Add html.unescape() for title, abstract, venue in Paper.from_s2() --- litresearch.toml => litresearch.toml.example | 0 src/litresearch/models.py | 7 ++++--- 2 files changed, 4 insertions(+), 3 deletions(-) rename litresearch.toml => litresearch.toml.example (100%) diff --git a/litresearch.toml b/litresearch.toml.example similarity index 100% rename from litresearch.toml rename to litresearch.toml.example diff --git a/src/litresearch/models.py b/src/litresearch/models.py index 1ff8ca1..a88ad0a 100644 --- a/src/litresearch/models.py +++ b/src/litresearch/models.py @@ -1,5 +1,6 @@ """Shared data models for the litresearch pipeline.""" +import html from pathlib import Path from typing import Protocol @@ -70,12 +71,12 @@ def from_s2(cls, s2_paper: S2PaperLike) -> "Paper": return cls( paper_id=s2_paper.paperId, corpus_id=s2_paper.corpusId, - title=s2_paper.title, - abstract=s2_paper.abstract, + title=html.unescape(s2_paper.title), + abstract=html.unescape(s2_paper.abstract) if s2_paper.abstract else None, authors=[author.name for author in authors if author.name], year=s2_paper.year, citation_count=s2_paper.citationCount or 0, - venue=s2_paper.venue, + venue=html.unescape(s2_paper.venue) if s2_paper.venue else None, doi=external_ids.get("DOI"), open_access_pdf_url=open_access_pdf.get("url"), bibtex=citation_styles.get("bibtex"), From ac9f8991dec613a54dcd779810c0521762ebb4e5 Mon Sep 17 00:00:00 2001 From: Silas Pignotti Date: Mon, 23 Mar 2026 16:08:45 +0100 Subject: [PATCH 5/6] test: add stage-level tests for query_gen, screening, discovery - Test query generation with successful LLM response and error handling - Test screening behavior for no-abstract papers and JSON parse failures - Test discovery S2 client configuration and paper deduplication --- tests/unit/test_stages_discovery.py | 112 ++++++++++++++++++++++++++++ tests/unit/test_stages_query_gen.py | 62 +++++++++++++++ tests/unit/test_stages_screening.py | 64 ++++++++++++++++ 3 files changed, 238 insertions(+) create mode 100644 tests/unit/test_stages_discovery.py create mode 100644 tests/unit/test_stages_query_gen.py create mode 100644 tests/unit/test_stages_screening.py diff --git a/tests/unit/test_stages_discovery.py b/tests/unit/test_stages_discovery.py new file mode 100644 index 0000000..5460d29 --- /dev/null +++ b/tests/unit/test_stages_discovery.py @@ -0,0 +1,112 @@ +"""Tests for discovery stage.""" + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +from litresearch.config import Settings +from litresearch.models import PipelineState, SearchQuery +from litresearch.stages.discovery import run + + +class TestDiscoveryStage: + """Test paper discovery behavior.""" + + def test_s2_client_configured_with_timeout_and_retry(self, tmp_path) -> None: + """Test that S2 client is created with timeout and retry=False.""" + settings = Settings( + s2_api_key=None, + s2_timeout=10, + max_results_per_query=10, + ) + + query = SearchQuery(query="machine learning", facet="AI") + state = PipelineState( + questions=["What is ML?"], + search_queries=[query], + current_stage="query_gen", + output_dir=str(tmp_path), + created_at="2024-01-01", + updated_at="2024-01-01", + ) + + mock_scholar = MagicMock() + mock_scholar.search_paper.return_value = SimpleNamespace(items=[]) + + with patch( + "litresearch.stages.discovery.SemanticScholar", + return_value=mock_scholar, + ) as mock_init: + run(state, settings) + + mock_init.assert_called_once_with(timeout=10, retry=False) + + def test_s2_client_with_api_key(self, tmp_path) -> None: + """Test S2 client creation with API key.""" + settings = Settings( + s2_api_key="test-key", + s2_timeout=15, + max_results_per_query=10, + ) + + query = SearchQuery(query="AI", facet="Tech") + state = PipelineState( + questions=["What is AI?"], + search_queries=[query], + current_stage="query_gen", + output_dir=str(tmp_path), + created_at="2024-01-01", + updated_at="2024-01-01", + ) + + mock_scholar = MagicMock() + mock_scholar.search_paper.return_value = SimpleNamespace(items=[]) + + with patch( + "litresearch.stages.discovery.SemanticScholar", + return_value=mock_scholar, + ) as mock_init: + run(state, settings) + + mock_init.assert_called_once_with(api_key="test-key", timeout=15, retry=False) + + def test_paper_deduplication_by_id(self, tmp_path) -> None: + """Test that duplicate papers are deduplicated by paper_id.""" + settings = Settings( + s2_api_key=None, + s2_timeout=10, + max_results_per_query=10, + ) + + query1 = SearchQuery(query="query1", facet="Facet1") + query2 = SearchQuery(query="query2", facet="Facet2") + state = PipelineState( + questions=["Question?"], + search_queries=[query1, query2], + current_stage="query_gen", + output_dir=str(tmp_path), + created_at="2024-01-01", + updated_at="2024-01-01", + ) + + mock_paper = SimpleNamespace( + paperId="same-id", + title="Same Paper", + corpusId=123, + abstract="Abstract", + authors=[SimpleNamespace(name="Author")], + year=2024, + citationCount=10, + venue="Venue", + externalIds={"DOI": "10.1234"}, + openAccessPdf={"url": "http://example.com"}, + citationStyles={"bibtex": "@article{...}"}, + ) + + mock_scholar = MagicMock() + mock_scholar.search_paper.return_value = SimpleNamespace(items=[mock_paper]) + + with patch("litresearch.stages.discovery.SemanticScholar", return_value=mock_scholar): + result = run(state, settings) + + assert len(result.candidates) == 1 + assert result.candidates[0].paper_id == "same-id" diff --git a/tests/unit/test_stages_query_gen.py b/tests/unit/test_stages_query_gen.py new file mode 100644 index 0000000..c77f238 --- /dev/null +++ b/tests/unit/test_stages_query_gen.py @@ -0,0 +1,62 @@ +"""Tests for query generation stage.""" + +import json +from unittest.mock import patch + +import pytest + +from litresearch.config import Settings +from litresearch.llm import LLMError +from litresearch.models import PipelineState +from litresearch.stages.query_gen import run + + +class TestQueryGenStage: + """Test query generation stage behavior.""" + + def test_successful_query_generation(self, tmp_path) -> None: + """Test successful facet and query generation.""" + settings = Settings(default_model="test-model") + state = PipelineState( + questions=["What is machine learning?"], + current_stage="start", + output_dir=str(tmp_path), + created_at="2024-01-01", + updated_at="2024-01-01", + ) + + mock_response = json.dumps( + { + "facets": [ + { + "name": "Supervised Learning", + "queries": ["supervised learning algorithms"], + }, + {"name": "Deep Learning", "queries": ["neural networks", "deep learning"]}, + ] + } + ) + + with patch("litresearch.stages.query_gen.call_llm", return_value=mock_response): + result = run(state, settings) + + assert len(result.facets) == 2 + assert len(result.search_queries) == 3 # 1 + 2 queries + assert result.current_stage == "query_gen" + + def test_llm_error_raises_with_message(self, tmp_path) -> None: + """Test that LLMError is re-raised with clear message.""" + settings = Settings(default_model="test-model") + state = PipelineState( + questions=["What is AI?"], + current_stage="start", + output_dir=str(tmp_path), + created_at="2024-01-01", + updated_at="2024-01-01", + ) + + with patch("litresearch.stages.query_gen.call_llm", side_effect=LLMError("API error")): + with pytest.raises(LLMError) as exc_info: + run(state, settings) + + assert "Query generation failed" in str(exc_info.value) diff --git a/tests/unit/test_stages_screening.py b/tests/unit/test_stages_screening.py new file mode 100644 index 0000000..1cde0ff --- /dev/null +++ b/tests/unit/test_stages_screening.py @@ -0,0 +1,64 @@ +"""Tests for screening and analysis stage.""" + +from unittest.mock import patch + +from litresearch.config import Settings +from litresearch.models import Paper, PipelineState +from litresearch.stages.analysis import run + + +class TestScreeningStage: + """Test paper screening behavior.""" + + def test_paper_without_abstract_gets_zero_score(self, tmp_path) -> None: + """Test that papers without abstract get screening result with score 0.""" + settings = Settings( + default_model="test-model", + screening_threshold=50, + pdf_first_pages=4, + pdf_last_pages=2, + ) + + paper_no_abstract = Paper( + paper_id="123", + title="Test Paper", + authors=["Author"], + year=2024, + abstract=None, + ) + + state = PipelineState( + questions=["Test question?"], + candidates=[paper_no_abstract], + current_stage="enrichment", + output_dir=str(tmp_path), + created_at="2024-01-01", + updated_at="2024-01-01", + ) + + with patch("litresearch.stages.analysis.load_prompt", return_value="prompt"): + with patch("litresearch.stages.analysis.call_llm"): + result = run(state, settings) + + assert len(result.screening_results) == 1 + assert result.screening_results[0].relevance_score == 0 + assert "no abstract available" in result.screening_results[0].rationale + + def test_json_parse_failure_skips_paper(self) -> None: + """Test that JSON parse failure returns None and skips paper.""" + from litresearch.stages.analysis import _screen_paper + + paper = Paper( + paper_id="456", + title="Another Paper", + authors=["Author"], + year=2024, + abstract="This is an abstract", + ) + + settings = Settings(default_model="test-model") + + with patch("litresearch.stages.analysis.call_llm", return_value="invalid json"): + result = _screen_paper(paper, ["question"], settings, "prompt") + + assert result is None From 42f744f2087a67bac49786af216e14d7a694fb3c Mon Sep 17 00:00:00 2001 From: Silas Pignotti Date: Mon, 23 Mar 2026 16:09:50 +0100 Subject: [PATCH 6/6] chore: minor polish - comments, summary output, threshold default - Add comment for BATCH_SIZE in enrichment.py - Add run summary block in pipeline.py with timing and counts - Change screening_threshold default from 40 to 60 with documentation --- src/litresearch/config.py | 2 +- src/litresearch/pipeline.py | 11 +++++++++++ src/litresearch/stages/enrichment.py | 2 +- 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/litresearch/config.py b/src/litresearch/config.py index 6ef3022..e3a3f22 100644 --- a/src/litresearch/config.py +++ b/src/litresearch/config.py @@ -37,7 +37,7 @@ def settings_customise_sources( s2_api_key: str | None = None s2_timeout: int = 10 # seconds; SemanticScholar client timeout default_model: str = "openai/gpt-4o-mini" - screening_threshold: int = 40 + screening_threshold: int = 60 # 0-100; papers below this are filtered before analysis top_n: int = 20 max_results_per_query: int = 20 pdf_first_pages: int = 4 diff --git a/src/litresearch/pipeline.py b/src/litresearch/pipeline.py index bc34fde..a454731 100644 --- a/src/litresearch/pipeline.py +++ b/src/litresearch/pipeline.py @@ -35,6 +35,8 @@ def run_pipeline( overwrite: bool = False, ) -> PipelineState: """Run the configured pipeline from scratch or from a saved state.""" + start_time = time.perf_counter() + if resume_path is not None: state = PipelineState.load(resume_path) output_dir = Path(state.output_dir) @@ -78,4 +80,13 @@ def run_pipeline( elapsed = time.perf_counter() - started console.print(f"[green]Completed[/green] {stage_name} in {elapsed:.2f}s") + # Print run summary + console.print("\n[bold]Run Summary[/bold]") + console.print(f" Total time: {time.perf_counter() - start_time:.1f}s") + console.print(f" Candidates: {len(state.candidates)}") + console.print(f" Screened: {len(state.screening_results)}") + console.print(f" Analyzed: {len(state.analyses)}") + console.print(f" Exported: {len(state.ranked_paper_ids)}") + console.print(f" Output: {state.output_dir}") + return state diff --git a/src/litresearch/stages/enrichment.py b/src/litresearch/stages/enrichment.py index 72765ef..5e522da 100644 --- a/src/litresearch/stages/enrichment.py +++ b/src/litresearch/stages/enrichment.py @@ -22,7 +22,7 @@ "citationStyles", ] -BATCH_SIZE = 500 +BATCH_SIZE = 500 # S2 /papers batch endpoint limit def _chunk(items: list[str], size: int) -> list[list[str]]: