2 changes: 2 additions & 0 deletions litresearch.toml.example
@@ -1,4 +1,6 @@
default_model = "openrouter/openai/gpt-4o-mini"
screening_selection_mode = "top_percent"
screening_top_percent = 0.3
screening_threshold = 60
top_n = 5
max_results_per_query = 5
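For orientation, the two new TOML keys map one-to-one onto the `Settings` fields added in `src/litresearch/config.py` below. A minimal sketch of the three selection modes, assuming only the field names and defaults shown in this diff; the concrete values are illustrative:

```python
# Sketch only: kwargs mirror the new litresearch.toml keys; values are illustrative.
from litresearch.config import Settings

# Keep the top 30% of screened papers (the new default mode).
by_share = Settings(screening_selection_mode="top_percent", screening_top_percent=0.3)

# Keep a fixed number of papers regardless of corpus size.
by_count = Settings(screening_selection_mode="top_k", screening_top_k=10)

# Legacy behavior: keep every paper scoring at or above the threshold.
by_score = Settings(screening_selection_mode="threshold", screening_threshold=60)
```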
7 changes: 6 additions & 1 deletion src/litresearch/config.py
@@ -1,5 +1,7 @@
"""Application settings for litresearch."""

from typing import Literal

from pydantic import computed_field
from pydantic_settings import BaseSettings, SettingsConfigDict, TomlConfigSettingsSource

@@ -38,7 +40,10 @@ def settings_customise_sources(
s2_timeout: int = 10 # seconds; SemanticScholar client timeout
s2_requests_per_second: float = 1.0 # max S2 request rate across endpoints
default_model: str = "openai/gpt-4o-mini"
screening_threshold: int = 60 # 0-100; papers below this are filtered before analysis
screening_selection_mode: Literal["top_percent", "threshold", "top_k"] = "top_percent"
screening_top_percent: float = 0.3 # 0-1; used when screening_selection_mode=top_percent
screening_top_k: int | None = None # used when screening_selection_mode=top_k
screening_threshold: int = 60 # 0-100; used when screening_selection_mode=threshold
top_n: int = 20
max_results_per_query: int = 20
pdf_first_pages: int = 4
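Because `screening_selection_mode` is declared as a `Literal`, pydantic should reject unrecognized modes at construction time rather than deep inside the pipeline. A hedged sketch (the exact error wording is pydantic's and is not verified here):

```python
# Sketch: Literal["top_percent", "threshold", "top_k"] constrains the mode field.
from pydantic import ValidationError

from litresearch.config import Settings

try:
    Settings(screening_selection_mode="best_only")  # not one of the allowed literals
except ValidationError as exc:
    print(exc)  # pydantic lists the permitted values for screening_selection_mode
```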
49 changes: 45 additions & 4 deletions src/litresearch/stages/analysis.py
@@ -1,6 +1,7 @@
"""Stage 4: screening and extended paper analysis."""

import json
import math
from pathlib import Path

from rich.console import Console
@@ -15,6 +16,45 @@
console = Console()


def _select_papers_for_analysis(
screened_papers: list[tuple[Paper, ScreeningResult, int]],
settings: Settings,
) -> list[Paper]:
sorted_screened = sorted(
screened_papers,
key=lambda item: (
-item[1].relevance_score,
-item[0].citation_count,
-(item[0].year or 0),
item[2],
),
)

if settings.screening_selection_mode == "threshold":
return [
paper
for paper, screening_result, _ in sorted_screened
if screening_result.relevance_score >= settings.screening_threshold
]

if settings.screening_selection_mode == "top_k":
if settings.screening_top_k is None or settings.screening_top_k <= 0:
raise ValueError("screening_top_k must be > 0 when screening_selection_mode=top_k")
return [paper for paper, _, _ in sorted_screened[: settings.screening_top_k]]

if settings.screening_selection_mode == "top_percent":
if not (0 < settings.screening_top_percent <= 1):
raise ValueError(
"screening_top_percent must be in (0, 1] when screening_selection_mode=top_percent"
)
if not sorted_screened:
return []
selected_count = max(1, math.ceil(len(sorted_screened) * settings.screening_top_percent))
return [paper for paper, _, _ in sorted_screened[:selected_count]]

raise ValueError(f"Unsupported screening_selection_mode: {settings.screening_selection_mode}")


def _screen_paper(
paper: Paper,
questions: list[str],
@@ -122,8 +162,8 @@ def run(state: PipelineState, settings: Settings) -> PipelineState:
papers_by_id = {paper.paper_id: paper for paper in state.candidates}

screening_results: list[ScreeningResult] = []
passed_papers: list[Paper] = []
for paper in track(state.candidates, description="Screening papers"):
screened_papers: list[tuple[Paper, ScreeningResult, int]] = []
for index, paper in enumerate(track(state.candidates, description="Screening papers")):
if not paper.abstract:
screening_results.append(
ScreeningResult(
@@ -139,8 +179,9 @@
continue

screening_results.append(screening_result)
if screening_result.relevance_score >= settings.screening_threshold:
passed_papers.append(paper)
screened_papers.append((paper, screening_result, index))

passed_papers = _select_papers_for_analysis(screened_papers, settings)

analyses: list[AnalysisResult] = []
for paper in track(passed_papers, description="Analyzing papers"):
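To make the selection order concrete, here is an illustrative-only sketch of `_select_papers_for_analysis` in `top_percent` mode. The `Paper` and `ScreeningResult` values are made up, but the tie-breaking (relevance score, then citation count, then year, then screening order) follows the sort key added above:

```python
# Illustrative sketch of the ranking used by _select_papers_for_analysis.
from litresearch.config import Settings
from litresearch.models import Paper, ScreeningResult
from litresearch.stages.analysis import _select_papers_for_analysis

screened = [
    (Paper(paper_id="a", title="A", abstract="x", citation_count=5, year=2021),
     ScreeningResult(paper_id="a", relevance_score=80, rationale="fit"), 0),
    (Paper(paper_id="b", title="B", abstract="x", citation_count=9, year=2019),
     ScreeningResult(paper_id="b", relevance_score=80, rationale="fit"), 1),
    (Paper(paper_id="c", title="C", abstract="x", citation_count=1, year=2024),
     ScreeningResult(paper_id="c", relevance_score=60, rationale="weak"), 2),
]

settings = Settings(screening_selection_mode="top_percent", screening_top_percent=0.5)
selected = _select_papers_for_analysis(screened, settings)

# ceil(3 * 0.5) == 2 papers are kept; "b" wins the score tie on citation count.
assert [paper.paper_id for paper in selected] == ["b", "a"]
```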
163 changes: 148 additions & 15 deletions tests/unit/test_stages_screening.py
@@ -1,23 +1,56 @@
"""Tests for screening and analysis stage."""

from collections.abc import Callable
from unittest.mock import patch

import pytest

from litresearch.config import Settings
from litresearch.models import Paper, PipelineState
from litresearch.models import AnalysisResult, Paper, PipelineState, ScreeningResult
from litresearch.stages.analysis import run


class TestScreeningStage:
"""Test paper screening behavior."""

@staticmethod
def _state_with_papers(tmp_path, papers: list[Paper]) -> PipelineState:
return PipelineState(
questions=["Test question?"],
candidates=papers,
current_stage="enrichment",
output_dir=str(tmp_path),
created_at="2024-01-01",
updated_at="2024-01-01",
)

@staticmethod
def _analysis_stub(analyzed_ids: list[str]) -> Callable[..., tuple[AnalysisResult, bool]]:
def _stub(
paper: Paper,
questions: list[str],
settings: Settings,
prompt: str,
output_dir: str,
) -> tuple[AnalysisResult, bool]:
analyzed_ids.append(paper.paper_id)
return (
AnalysisResult(
paper_id=paper.paper_id,
summary="summary",
key_findings=[],
methodology="method",
relevance_score=paper.citation_count,
relevance_rationale="rationale",
),
False,
)

return _stub

def test_paper_without_abstract_gets_zero_score(self, tmp_path) -> None:
"""Test that papers without abstract get screening result with score 0."""
settings = Settings(
default_model="test-model",
screening_threshold=50,
pdf_first_pages=4,
pdf_last_pages=2,
)
settings = Settings(default_model="test-model", screening_selection_mode="top_percent")

paper_no_abstract = Paper(
paper_id="123",
@@ -27,14 +60,7 @@ def test_paper_without_abstract_gets_zero_score(self, tmp_path) -> None:
abstract=None,
)

state = PipelineState(
questions=["Test question?"],
candidates=[paper_no_abstract],
current_stage="enrichment",
output_dir=str(tmp_path),
created_at="2024-01-01",
updated_at="2024-01-01",
)
state = self._state_with_papers(tmp_path, [paper_no_abstract])

with patch("litresearch.stages.analysis.load_prompt", return_value="prompt"):
with patch("litresearch.stages.analysis.call_llm"):
@@ -44,6 +70,113 @@ def test_paper_without_abstract_gets_zero_score(self, tmp_path) -> None:
assert result.screening_results[0].relevance_score == 0
assert "no abstract available" in result.screening_results[0].rationale

def test_top_percent_selection_analyzes_global_top_share(self, tmp_path, monkeypatch) -> None:
"""Test global top-percent selection after screening."""
settings = Settings(screening_selection_mode="top_percent", screening_top_percent=0.4)
papers = [
Paper(paper_id="p1", title="P1", abstract="a", citation_count=1),
Paper(paper_id="p2", title="P2", abstract="a", citation_count=2),
Paper(paper_id="p3", title="P3", abstract="a", citation_count=3),
Paper(paper_id="p4", title="P4", abstract="a", citation_count=4),
Paper(paper_id="p5", title="P5", abstract="a", citation_count=5),
]
scores = {"p1": 90, "p2": 80, "p3": 70, "p4": 60, "p5": 50}
analyzed_ids: list[str] = []

monkeypatch.setattr("litresearch.stages.analysis.load_prompt", lambda _name: "prompt")
monkeypatch.setattr(
"litresearch.stages.analysis._screen_paper",
lambda paper, questions, settings, prompt: ScreeningResult(
paper_id=paper.paper_id,
relevance_score=scores[paper.paper_id],
rationale="fit",
),
)
monkeypatch.setattr(
"litresearch.stages.analysis._analyze_paper",
self._analysis_stub(analyzed_ids),
)

run(self._state_with_papers(tmp_path, papers), settings)

assert analyzed_ids == ["p1", "p2"]

def test_top_k_selection_uses_tiebreakers(self, tmp_path, monkeypatch) -> None:
"""Test top-k selection uses score, citation_count, year, then order."""
settings = Settings(screening_selection_mode="top_k", screening_top_k=1)
papers = [
Paper(paper_id="p1", title="P1", abstract="a", citation_count=10, year=2020),
Paper(paper_id="p2", title="P2", abstract="a", citation_count=20, year=2019),
Paper(paper_id="p3", title="P3", abstract="a", citation_count=1, year=2024),
]
scores = {"p1": 80, "p2": 80, "p3": 70}
analyzed_ids: list[str] = []

monkeypatch.setattr("litresearch.stages.analysis.load_prompt", lambda _name: "prompt")
monkeypatch.setattr(
"litresearch.stages.analysis._screen_paper",
lambda paper, questions, settings, prompt: ScreeningResult(
paper_id=paper.paper_id,
relevance_score=scores[paper.paper_id],
rationale="fit",
),
)
monkeypatch.setattr(
"litresearch.stages.analysis._analyze_paper",
self._analysis_stub(analyzed_ids),
)

run(self._state_with_papers(tmp_path, papers), settings)

assert analyzed_ids == ["p2"]

def test_threshold_selection_mode_still_supported(self, tmp_path, monkeypatch) -> None:
"""Test legacy threshold mode still controls deep analysis."""
settings = Settings(screening_selection_mode="threshold", screening_threshold=70)
papers = [
Paper(paper_id="p1", title="P1", abstract="a"),
Paper(paper_id="p2", title="P2", abstract="a"),
Paper(paper_id="p3", title="P3", abstract="a"),
]
scores = {"p1": 90, "p2": 70, "p3": 69}
analyzed_ids: list[str] = []

monkeypatch.setattr("litresearch.stages.analysis.load_prompt", lambda _name: "prompt")
monkeypatch.setattr(
"litresearch.stages.analysis._screen_paper",
lambda paper, questions, settings, prompt: ScreeningResult(
paper_id=paper.paper_id,
relevance_score=scores[paper.paper_id],
rationale="fit",
),
)
monkeypatch.setattr(
"litresearch.stages.analysis._analyze_paper",
self._analysis_stub(analyzed_ids),
)

run(self._state_with_papers(tmp_path, papers), settings)

assert analyzed_ids == ["p1", "p2"]

def test_invalid_top_percent_raises_value_error(self, tmp_path, monkeypatch) -> None:
"""Test invalid top-percent config fails fast with clear error."""
settings = Settings(screening_selection_mode="top_percent", screening_top_percent=0.0)
papers = [Paper(paper_id="p1", title="P1", abstract="a")]

monkeypatch.setattr("litresearch.stages.analysis.load_prompt", lambda _name: "prompt")
monkeypatch.setattr(
"litresearch.stages.analysis._screen_paper",
lambda paper, questions, settings, prompt: ScreeningResult(
paper_id=paper.paper_id,
relevance_score=90,
rationale="fit",
),
)

with pytest.raises(ValueError, match="screening_top_percent"):
run(self._state_with_papers(tmp_path, papers), settings)

def test_json_parse_failure_skips_paper(self) -> None:
"""Test that JSON parse failure returns None and skips paper."""
from litresearch.stages.analysis import _screen_paper