From 2761d3f3e506b9d754e90316b022221f7620d544 Mon Sep 17 00:00:00 2001
From: Silas Pignotti
Date: Mon, 23 Mar 2026 17:45:41 +0100
Subject: [PATCH] feat(screening): add global top-percent selection for deep analysis

Introduce configurable screening selection modes (top_percent, top_k,
threshold) with top_percent defaulting to 30%. Refactor analysis stage
to rank screened papers globally before deep analysis and add tests for
mode behavior, tie-breaking, and config validation.
---
 litresearch.toml.example            |   2 +
 src/litresearch/config.py           |   7 +-
 src/litresearch/stages/analysis.py  |  49 ++++++++-
 tests/unit/test_stages_screening.py | 163 +++++++++++++++++++++++++---
 4 files changed, 201 insertions(+), 20 deletions(-)

diff --git a/litresearch.toml.example b/litresearch.toml.example
index d29add2..4760656 100644
--- a/litresearch.toml.example
+++ b/litresearch.toml.example
@@ -1,4 +1,6 @@
 default_model = "openrouter/openai/gpt-4o-mini"
+screening_selection_mode = "top_percent"
+screening_top_percent = 0.3
 screening_threshold = 60
 top_n = 5
 max_results_per_query = 5
diff --git a/src/litresearch/config.py b/src/litresearch/config.py
index 19c851c..7c1daea 100644
--- a/src/litresearch/config.py
+++ b/src/litresearch/config.py
@@ -1,5 +1,7 @@
 """Application settings for litresearch."""
 
+from typing import Literal
+
 from pydantic import computed_field
 from pydantic_settings import BaseSettings, SettingsConfigDict, TomlConfigSettingsSource
 
@@ -38,7 +40,10 @@ def settings_customise_sources(
     s2_timeout: int = 10  # seconds; SemanticScholar client timeout
     s2_requests_per_second: float = 1.0  # max S2 request rate across endpoints
     default_model: str = "openai/gpt-4o-mini"
-    screening_threshold: int = 60  # 0-100; papers below this are filtered before analysis
+    screening_selection_mode: Literal["top_percent", "threshold", "top_k"] = "top_percent"
+    screening_top_percent: float = 0.3  # 0-1; used when screening_selection_mode=top_percent
+    screening_top_k: int | None = None  # used when screening_selection_mode=top_k
+    screening_threshold: int = 60  # 0-100; used when screening_selection_mode=threshold
     top_n: int = 20
     max_results_per_query: int = 20
     pdf_first_pages: int = 4
diff --git a/src/litresearch/stages/analysis.py b/src/litresearch/stages/analysis.py
index bb19ca1..3abaa3e 100644
--- a/src/litresearch/stages/analysis.py
+++ b/src/litresearch/stages/analysis.py
@@ -1,6 +1,7 @@
 """Stage 4: screening and extended paper analysis."""
 
 import json
+import math
 from pathlib import Path
 
 from rich.console import Console
@@ -15,6 +16,45 @@
 console = Console()
 
 
+def _select_papers_for_analysis(
+    screened_papers: list[tuple[Paper, ScreeningResult, int]],
+    settings: Settings,
+) -> list[Paper]:
+    sorted_screened = sorted(
+        screened_papers,
+        key=lambda item: (
+            -item[1].relevance_score,
+            -item[0].citation_count,
+            -(item[0].year or 0),
+            item[2],
+        ),
+    )
+
+    if settings.screening_selection_mode == "threshold":
+        return [
+            paper
+            for paper, screening_result, _ in sorted_screened
+            if screening_result.relevance_score >= settings.screening_threshold
+        ]
+
+    if settings.screening_selection_mode == "top_k":
+        if settings.screening_top_k is None or settings.screening_top_k <= 0:
+            raise ValueError("screening_top_k must be > 0 when screening_selection_mode=top_k")
+        return [paper for paper, _, _ in sorted_screened[: settings.screening_top_k]]
+
+    if settings.screening_selection_mode == "top_percent":
+        if not (0 < settings.screening_top_percent <= 1):
+            raise ValueError(
+                "screening_top_percent must be in (0, 1] when screening_selection_mode=top_percent"
+            )
+        if not sorted_screened:
+            return []
+        selected_count = max(1, math.ceil(len(sorted_screened) * settings.screening_top_percent))
+        return [paper for paper, _, _ in sorted_screened[:selected_count]]
+
+    raise ValueError(f"Unsupported screening_selection_mode: {settings.screening_selection_mode}")
+
+
 def _screen_paper(
     paper: Paper,
     questions: list[str],
@@ -122,8 +162,8 @@ def run(state: PipelineState, settings: Settings) -> PipelineState:
     papers_by_id = {paper.paper_id: paper for paper in state.candidates}
 
     screening_results: list[ScreeningResult] = []
-    passed_papers: list[Paper] = []
-    for paper in track(state.candidates, description="Screening papers"):
+    screened_papers: list[tuple[Paper, ScreeningResult, int]] = []
+    for index, paper in enumerate(track(state.candidates, description="Screening papers")):
         if not paper.abstract:
             screening_results.append(
                 ScreeningResult(
@@ -139,8 +179,9 @@ def run(state: PipelineState, settings: Settings) -> PipelineState:
             continue
 
         screening_results.append(screening_result)
-        if screening_result.relevance_score >= settings.screening_threshold:
-            passed_papers.append(paper)
+        screened_papers.append((paper, screening_result, index))
+
+    passed_papers = _select_papers_for_analysis(screened_papers, settings)
 
     analyses: list[AnalysisResult] = []
     for paper in track(passed_papers, description="Analyzing papers"):
diff --git a/tests/unit/test_stages_screening.py b/tests/unit/test_stages_screening.py
index 1cde0ff..ffbddf4 100644
--- a/tests/unit/test_stages_screening.py
+++ b/tests/unit/test_stages_screening.py
@@ -1,23 +1,56 @@
 """Tests for screening and analysis stage."""
 
+from collections.abc import Callable
 from unittest.mock import patch
 
+import pytest
+
 from litresearch.config import Settings
-from litresearch.models import Paper, PipelineState
+from litresearch.models import AnalysisResult, Paper, PipelineState, ScreeningResult
 from litresearch.stages.analysis import run
 
 
 class TestScreeningStage:
     """Test paper screening behavior."""
 
+    @staticmethod
+    def _state_with_papers(tmp_path, papers: list[Paper]) -> PipelineState:
+        return PipelineState(
+            questions=["Test question?"],
+            candidates=papers,
+            current_stage="enrichment",
+            output_dir=str(tmp_path),
+            created_at="2024-01-01",
+            updated_at="2024-01-01",
+        )
+
+    @staticmethod
+    def _analysis_stub(analyzed_ids: list[str]) -> Callable[..., tuple[AnalysisResult, bool]]:
+        def _stub(
+            paper: Paper,
+            questions: list[str],
+            settings: Settings,
+            prompt: str,
+            output_dir: str,
+        ) -> tuple[AnalysisResult, bool]:
+            analyzed_ids.append(paper.paper_id)
+            return (
+                AnalysisResult(
+                    paper_id=paper.paper_id,
+                    summary="summary",
+                    key_findings=[],
+                    methodology="method",
+                    relevance_score=paper.citation_count,
+                    relevance_rationale="rationale",
+                ),
+                False,
+            )
+
+        return _stub
+
     def test_paper_without_abstract_gets_zero_score(self, tmp_path) -> None:
         """Test that papers without abstract get screening result with score 0."""
-        settings = Settings(
-            default_model="test-model",
-            screening_threshold=50,
-            pdf_first_pages=4,
-            pdf_last_pages=2,
-        )
+        settings = Settings(default_model="test-model", screening_selection_mode="top_percent")
 
         paper_no_abstract = Paper(
             paper_id="123",
@@ -27,14 +60,7 @@ def test_paper_without_abstract_gets_zero_score(self, tmp_path) -> None:
             abstract=None,
         )
 
-        state = PipelineState(
-            questions=["Test question?"],
-            candidates=[paper_no_abstract],
-            current_stage="enrichment",
-            output_dir=str(tmp_path),
-            created_at="2024-01-01",
-            updated_at="2024-01-01",
-        )
+        state = self._state_with_papers(tmp_path, [paper_no_abstract])
 
         with patch("litresearch.stages.analysis.load_prompt", return_value="prompt"):
             with patch("litresearch.stages.analysis.call_llm"):
@@ -44,6 +70,113 @@ def test_paper_without_abstract_gets_zero_score(self, tmp_path) -> None:
         assert result.screening_results[0].relevance_score == 0
         assert "no abstract available" in result.screening_results[0].rationale
 
+    def test_top_percent_selection_analyzes_global_top_share(self, tmp_path, monkeypatch) -> None:
+        """Test global top-percent selection after screening."""
+        settings = Settings(screening_selection_mode="top_percent", screening_top_percent=0.4)
+        papers = [
+            Paper(paper_id="p1", title="P1", abstract="a", citation_count=1),
+            Paper(paper_id="p2", title="P2", abstract="a", citation_count=2),
+            Paper(paper_id="p3", title="P3", abstract="a", citation_count=3),
+            Paper(paper_id="p4", title="P4", abstract="a", citation_count=4),
+            Paper(paper_id="p5", title="P5", abstract="a", citation_count=5),
+        ]
+        scores = {"p1": 90, "p2": 80, "p3": 70, "p4": 60, "p5": 50}
+        analyzed_ids: list[str] = []
+
+        monkeypatch.setattr("litresearch.stages.analysis.load_prompt", lambda _name: "prompt")
+        monkeypatch.setattr(
+            "litresearch.stages.analysis._screen_paper",
+            lambda paper, questions, settings, prompt: ScreeningResult(
+                paper_id=paper.paper_id,
+                relevance_score=scores[paper.paper_id],
+                rationale="fit",
+            ),
+        )
+        monkeypatch.setattr(
+            "litresearch.stages.analysis._analyze_paper",
+            self._analysis_stub(analyzed_ids),
+        )
+
+        run(self._state_with_papers(tmp_path, papers), settings)
+
+        assert analyzed_ids == ["p1", "p2"]
+
+    def test_top_k_selection_uses_tiebreakers(self, tmp_path, monkeypatch) -> None:
+        """Test top-k selection uses score, citation_count, year, then order."""
+        settings = Settings(screening_selection_mode="top_k", screening_top_k=1)
+        papers = [
+            Paper(paper_id="p1", title="P1", abstract="a", citation_count=10, year=2020),
+            Paper(paper_id="p2", title="P2", abstract="a", citation_count=20, year=2019),
+            Paper(paper_id="p3", title="P3", abstract="a", citation_count=1, year=2024),
+        ]
+        scores = {"p1": 80, "p2": 80, "p3": 70}
+        analyzed_ids: list[str] = []
+
+        monkeypatch.setattr("litresearch.stages.analysis.load_prompt", lambda _name: "prompt")
+        monkeypatch.setattr(
+            "litresearch.stages.analysis._screen_paper",
+            lambda paper, questions, settings, prompt: ScreeningResult(
+                paper_id=paper.paper_id,
+                relevance_score=scores[paper.paper_id],
+                rationale="fit",
+            ),
+        )
+        monkeypatch.setattr(
+            "litresearch.stages.analysis._analyze_paper",
+            self._analysis_stub(analyzed_ids),
+        )
+
+        run(self._state_with_papers(tmp_path, papers), settings)
+
+        assert analyzed_ids == ["p2"]
+
+    def test_threshold_selection_mode_still_supported(self, tmp_path, monkeypatch) -> None:
+        """Test legacy threshold mode still controls deep analysis."""
+        settings = Settings(screening_selection_mode="threshold", screening_threshold=70)
+        papers = [
+            Paper(paper_id="p1", title="P1", abstract="a"),
+            Paper(paper_id="p2", title="P2", abstract="a"),
+            Paper(paper_id="p3", title="P3", abstract="a"),
+        ]
+        scores = {"p1": 90, "p2": 70, "p3": 69}
+        analyzed_ids: list[str] = []
+
+        monkeypatch.setattr("litresearch.stages.analysis.load_prompt", lambda _name: "prompt")
+        monkeypatch.setattr(
+            "litresearch.stages.analysis._screen_paper",
+            lambda paper, questions, settings, prompt: ScreeningResult(
+                paper_id=paper.paper_id,
+                relevance_score=scores[paper.paper_id],
+                rationale="fit",
+            ),
+        )
+        monkeypatch.setattr(
+            "litresearch.stages.analysis._analyze_paper",
+            self._analysis_stub(analyzed_ids),
+        )
+
+        run(self._state_with_papers(tmp_path, papers), settings)
+
+        assert analyzed_ids == ["p1", "p2"]
+
+    def test_invalid_top_percent_raises_value_error(self, tmp_path, monkeypatch) -> None:
+        """Test invalid top-percent config fails fast with clear error."""
+        settings = Settings(screening_selection_mode="top_percent", screening_top_percent=0.0)
+        papers = [Paper(paper_id="p1", title="P1", abstract="a")]
+
+        monkeypatch.setattr("litresearch.stages.analysis.load_prompt", lambda _name: "prompt")
+        monkeypatch.setattr(
+            "litresearch.stages.analysis._screen_paper",
+            lambda paper, questions, settings, prompt: ScreeningResult(
+                paper_id=paper.paper_id,
+                relevance_score=90,
+                rationale="fit",
+            ),
+        )
+
+        with pytest.raises(ValueError, match="screening_top_percent"):
+            run(self._state_with_papers(tmp_path, papers), settings)
+
     def test_json_parse_failure_skips_paper(self) -> None:
         """Test that JSON parse failure returns None and skips paper."""
         from litresearch.stages.analysis import _screen_paper
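
Usage note (not part of the patch): the top-percent selection introduced above can be illustrated with a small, self-contained sketch. FakePaper, FakeScreening, and select_top_percent below are illustrative stand-ins, not litresearch APIs; in the real pipeline the equivalent logic lives in _select_papers_for_analysis and operates on Paper, ScreeningResult, and Settings.

import math
from dataclasses import dataclass


# Illustrative stand-ins for litresearch's Paper and ScreeningResult models.
@dataclass
class FakePaper:
    paper_id: str
    citation_count: int = 0
    year: int | None = None


@dataclass
class FakeScreening:
    relevance_score: int


def select_top_percent(
    screened: list[tuple[FakePaper, FakeScreening, int]],
    top_percent: float = 0.3,
) -> list[str]:
    """Mirror the patch's ranking: score desc, citations desc, year desc, input order."""
    ranked = sorted(
        screened,
        key=lambda item: (
            -item[1].relevance_score,
            -item[0].citation_count,
            -(item[0].year or 0),
            item[2],
        ),
    )
    if not ranked:
        return []
    # Always keep at least one paper, rounding the requested share up.
    keep = max(1, math.ceil(len(ranked) * top_percent))
    return [paper.paper_id for paper, _, _ in ranked[:keep]]


if __name__ == "__main__":
    screened = [
        (FakePaper("p1", citation_count=5, year=2021), FakeScreening(80), 0),
        (FakePaper("p2", citation_count=9, year=2019), FakeScreening(80), 1),
        (FakePaper("p3", citation_count=1, year=2024), FakeScreening(60), 2),
    ]
    # ceil(3 * 0.4) = 2 papers; the score tie between p1 and p2 is broken by citations.
    print(select_top_percent(screened, top_percent=0.4))  # ['p2', 'p1']

Ranking globally before truncation is what makes the share deterministic: ties on relevance_score fall back to citation count, then recency, then the original screening order, so reruns over the same candidates select the same papers.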