27 changes: 16 additions & 11 deletions src/litresearch/cli.py
@@ -22,16 +22,17 @@ def _build_settings(
threshold: int | None = None,
) -> Settings:
"""Load settings and apply CLI overrides."""
settings = Settings()
if model is not None:
settings.default_model = model
if top_n is not None:
settings.top_n = top_n
if output_dir is not None:
settings.output_dir = output_dir
if threshold is not None:
settings.screening_threshold = threshold
return settings
overrides = {
key: value
for key, value in {
"default_model": model,
"top_n": top_n,
"output_dir": output_dir,
"screening_threshold": threshold,
}.items()
if value is not None
}
return Settings(**overrides)


@app.command()
@@ -65,6 +66,10 @@ def run(
int | None,
typer.Option("--threshold", help="Override the screening threshold."),
] = None,
overwrite: Annotated[
bool,
typer.Option("--overwrite", help="Overwrite existing output directory."),
] = False,
) -> None:
"""Run the literature research pipeline."""
settings = _build_settings(
@@ -74,7 +79,7 @@
threshold=threshold,
)

state = run_pipeline(questions, settings)
state = run_pipeline(questions, settings, overwrite=overwrite)
console.print(f"[green]Run complete.[/green] Output: {state.output_dir}")


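The rewritten `_build_settings` collects only the flags the user actually passed and hands them to the `Settings` constructor in one call, so anything left unset falls through to defaults or environment values. A minimal sketch of the pattern, assuming `Settings` is a pydantic-settings `BaseSettings` (the field set here is illustrative):

```python
from pydantic_settings import BaseSettings


class Settings(BaseSettings):
    default_model: str = "openai/gpt-4o-mini"
    top_n: int = 20


def build_settings(model: str | None = None, top_n: int | None = None) -> Settings:
    # Keep only the overrides the caller actually provided.
    overrides = {
        key: value
        for key, value in {"default_model": model, "top_n": top_n}.items()
        if value is not None
    }
    return Settings(**overrides)


settings = build_settings(top_n=5)
assert settings.default_model == "openai/gpt-4o-mini"  # default preserved
assert settings.top_n == 5  # CLI override applied
```

A side benefit over the old mutate-after-construct version: overrides now pass through pydantic validation at construction time, whereas plain attribute assignment goes unvalidated unless `validate_assignment` is enabled.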
3 changes: 2 additions & 1 deletion src/litresearch/config.py
@@ -35,8 +35,9 @@ def settings_customise_sources(
anthropic_api_key: str | None = None
openrouter_api_key: str | None = None
s2_api_key: str | None = None
s2_timeout: int = 10 # seconds; SemanticScholar client timeout
default_model: str = "openai/gpt-4o-mini"
screening_threshold: int = 40
screening_threshold: int = 60 # 0-100; papers below this are filtered before analysis
top_n: int = 20
max_results_per_query: int = 20
pdf_first_pages: int = 4
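Two behavior changes sit in these three lines: Semantic Scholar calls gain a configurable timeout (consumed by the discovery and enrichment stages below), and the default screening bar rises from 40 to 60. To see what the stricter default means in practice (the comparison in the elided screening code is assumed to be `>=`):

```python
scores = [10, 45, 60, 85]  # hypothetical screening scores on the 0-100 scale
passed_old = [s for s in scores if s >= 40]  # [45, 60, 85]
passed_new = [s for s in scores if s >= 60]  # [60, 85]
```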
7 changes: 4 additions & 3 deletions src/litresearch/models.py
@@ -1,5 +1,6 @@
"""Shared data models for the litresearch pipeline."""

import html
from pathlib import Path
from typing import Protocol

@@ -70,12 +71,12 @@ def from_s2(cls, s2_paper: S2PaperLike) -> "Paper":
return cls(
paper_id=s2_paper.paperId,
corpus_id=s2_paper.corpusId,
title=s2_paper.title,
abstract=s2_paper.abstract,
title=html.unescape(s2_paper.title),
abstract=html.unescape(s2_paper.abstract) if s2_paper.abstract else None,
authors=[author.name for author in authors if author.name],
year=s2_paper.year,
citation_count=s2_paper.citationCount or 0,
venue=s2_paper.venue,
venue=html.unescape(s2_paper.venue) if s2_paper.venue else None,
doi=external_ids.get("DOI"),
open_access_pdf_url=open_access_pdf.get("url"),
bibtex=citation_styles.get("bibtex"),
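Semantic Scholar metadata sometimes arrives with HTML entities in titles, abstracts, and venue names; unescaping once at ingestion means prompts, exports, and reports all see clean text. A quick illustration with a made-up title:

```python
import html

title = "Bayesian A/B Testing: P &lt; 0.05 &amp; What Comes After"
print(html.unescape(title))  # Bayesian A/B Testing: P < 0.05 & What Comes After
```

The `if ... else None` guards are load-bearing: `html.unescape(None)` raises `TypeError`, and `abstract` and `venue` can be missing from S2 records.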
20 changes: 20 additions & 0 deletions src/litresearch/pipeline.py
@@ -32,14 +32,25 @@ def run_pipeline(
questions: list[str],
settings: Settings,
resume_path: Path | None = None,
overwrite: bool = False,
) -> PipelineState:
"""Run the configured pipeline from scratch or from a saved state."""
start_time = time.perf_counter()

if resume_path is not None:
state = PipelineState.load(resume_path)
output_dir = Path(state.output_dir)
start_index = STAGE_ORDER.index(state.current_stage) + 1
else:
output_dir = Path(settings.output_dir)
if output_dir.exists() and any(output_dir.iterdir()) and not overwrite:
base_name = output_dir.name
parent = output_dir.parent
counter = 2
while output_dir.exists() and any(output_dir.iterdir()):
output_dir = parent / f"{base_name}-{counter}"
counter += 1
console.print(f"[yellow]Output directory already exists. Using:[/yellow] {output_dir}")
state = PipelineState(
questions=questions,
current_stage="start",
@@ -69,4 +80,13 @@ def run_pipeline(
elapsed = time.perf_counter() - started
console.print(f"[green]Completed[/green] {stage_name} in {elapsed:.2f}s")

# Print run summary
console.print("\n[bold]Run Summary[/bold]")
console.print(f" Total time: {time.perf_counter() - start_time:.1f}s")
console.print(f" Candidates: {len(state.candidates)}")
console.print(f" Screened: {len(state.screening_results)}")
console.print(f" Analyzed: {len(state.analyses)}")
console.print(f" Exported: {len(state.ranked_paper_ids)}")
console.print(f" Output: {state.output_dir}")

return state
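The collision policy never deletes data: without `--overwrite`, a non-empty output directory is left intact and the run moves to the first free `-2`, `-3`, ... sibling, while an existing but empty directory is simply reused. The same logic as a standalone sketch:

```python
from pathlib import Path


def next_output_dir(base: Path) -> Path:
    """Illustrative mirror of the collision handling above."""
    candidate = base
    counter = 2
    while candidate.exists() and any(candidate.iterdir()):
        candidate = base.parent / f"{base.name}-{counter}"
        counter += 1
    return candidate


# If runs/review exists and is non-empty, this yields runs/review-2
# (or the next free suffix).
print(next_output_dir(Path("runs/review")))
```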
73 changes: 55 additions & 18 deletions src/litresearch/stages/analysis.py
@@ -1,6 +1,7 @@
"""Stage 4: screening and extended paper analysis."""

import json
from pathlib import Path

from rich.console import Console
from rich.progress import track
@@ -39,24 +40,34 @@ def _screen_paper(
console.print(f"[yellow]Screening failed:[/yellow] {paper.title} ({exc})")
return None

payload = json.loads(response)
return ScreeningResult(
paper_id=paper.paper_id,
relevance_score=payload["relevance_score"],
rationale=payload["rationale"],
)
try:
payload = json.loads(response)
return ScreeningResult(
paper_id=paper.paper_id,
relevance_score=payload["relevance_score"],
rationale=payload["rationale"],
)
except json.JSONDecodeError:
console.print(f"[yellow]JSON parse failed:[/yellow] {paper.title}")
return None


def _analyze_paper(
paper: Paper,
questions: list[str],
settings: Settings,
prompt: str,
) -> AnalysisResult | None:
output_dir: str,
) -> tuple[AnalysisResult | None, bool]:
pdf_text = ""
pdf_downloaded = False
if paper.open_access_pdf_url:
pdf_bytes = download_pdf(paper.open_access_pdf_url)
if pdf_bytes is not None:
papers_dir = Path(output_dir) / "papers"
papers_dir.mkdir(parents=True, exist_ok=True)
(papers_dir / f"{paper.paper_id}.pdf").write_bytes(pdf_bytes)
pdf_downloaded = True
pdf_text = extract_text(
pdf_bytes,
first_pages=settings.pdf_first_pages,
@@ -84,28 +95,43 @@ def _analyze_paper(
response = call_llm(settings, prompt, user_content)
except LLMError as exc:
console.print(f"[yellow]Analysis failed:[/yellow] {paper.title} ({exc})")
return None
return (None, pdf_downloaded)

payload = json.loads(response)
return AnalysisResult(
paper_id=paper.paper_id,
summary=payload["summary"],
key_findings=payload.get("key_findings", []),
methodology=payload["methodology"],
relevance_score=payload["relevance_score"],
relevance_rationale=payload["relevance_rationale"],
)
try:
payload = json.loads(response)
return (
AnalysisResult(
paper_id=paper.paper_id,
summary=payload["summary"],
key_findings=payload.get("key_findings", []),
methodology=payload["methodology"],
relevance_score=payload["relevance_score"],
relevance_rationale=payload["relevance_rationale"],
),
pdf_downloaded,
)
except json.JSONDecodeError:
console.print(f"[yellow]JSON parse failed:[/yellow] {paper.title}")
return (None, pdf_downloaded)


def run(state: PipelineState, settings: Settings) -> PipelineState:
"""Screen candidate papers and analyze the relevant ones."""
screening_prompt = load_prompt("screening")
analysis_prompt = load_prompt("analysis")
papers_by_id = {paper.paper_id: paper for paper in state.candidates}

screening_results: list[ScreeningResult] = []
passed_papers: list[Paper] = []
for paper in track(state.candidates, description="Screening papers"):
if not paper.abstract:
screening_results.append(
ScreeningResult(
paper_id=paper.paper_id,
relevance_score=0,
rationale="no abstract available",
)
)
continue

screening_result = _screen_paper(paper, state.questions, settings, screening_prompt)
@@ -118,12 +144,23 @@ def run(state: PipelineState, settings: Settings) -> PipelineState:

analyses: list[AnalysisResult] = []
for paper in track(passed_papers, description="Analyzing papers"):
analysis_result = _analyze_paper(paper, state.questions, settings, analysis_prompt)
analysis_result, pdf_downloaded = _analyze_paper(
paper,
state.questions,
settings,
analysis_prompt,
state.output_dir,
)
if pdf_downloaded:
papers_by_id[paper.paper_id] = paper.model_copy(update={"pdf_downloaded": True})
if analysis_result is not None:
analyses.append(analysis_result)

updated_candidates = [papers_by_id[paper.paper_id] for paper in state.candidates]

return state.model_copy(
update={
"candidates": updated_candidates,
"screening_results": screening_results,
"analyses": analyses,
"current_stage": "analysis",
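Both `_screen_paper` and `_analyze_paper` now treat a malformed LLM response the same way as a transport failure: log and skip rather than abort the stage. One caveat: the `try` blocks only catch `json.JSONDecodeError`, so a response that is valid JSON but missing a required key (say `relevance_score`) still raises `KeyError`. A shared helper along these lines (hypothetical, not part of this diff) would close that gap and remove the duplicated parsing:

```python
import json
from typing import Any


def parse_llm_payload(response: str, required: tuple[str, ...]) -> dict[str, Any] | None:
    """Return the parsed payload only if it is a dict containing all required keys."""
    try:
        payload = json.loads(response)
    except json.JSONDecodeError:
        return None
    if not isinstance(payload, dict) or any(key not in payload for key in required):
        return None
    return payload


result = parse_llm_payload('{"relevance_score": 72}', required=("relevance_score", "rationale"))
assert result is None  # missing "rationale" -> rejected instead of raising KeyError
```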
8 changes: 6 additions & 2 deletions src/litresearch/stages/discovery.py
@@ -27,9 +27,13 @@
def run(state: PipelineState, settings: Settings) -> PipelineState:
"""Discover candidate papers for the generated search queries."""
if settings.s2_api_key:
scholar = SemanticScholar(api_key=settings.s2_api_key)
scholar = SemanticScholar(
api_key=settings.s2_api_key,
timeout=settings.s2_timeout,
retry=False,
)
else:
scholar = SemanticScholar()
scholar = SemanticScholar(timeout=settings.s2_timeout, retry=False)
papers_by_id: dict[str, Paper] = {}

for search_query in track(state.search_queries, description="Discovering papers"):
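`retry=False` moves retry policy out of the client's built-in loop, and the timeout now comes from settings in both branches. The if/else duplication here (repeated in enrichment.py below) could be collapsed into a kwargs dict; a sketch, assuming the `semanticscholar` constructor keywords shown in this diff:

```python
from semanticscholar import SemanticScholar

# `settings` is the Settings instance passed into run().
kwargs: dict = {"timeout": settings.s2_timeout, "retry": False}
if settings.s2_api_key:
    kwargs["api_key"] = settings.s2_api_key
scholar = SemanticScholar(**kwargs)
```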
10 changes: 7 additions & 3 deletions src/litresearch/stages/enrichment.py
@@ -22,7 +22,7 @@
"citationStyles",
]

BATCH_SIZE = 500
BATCH_SIZE = 500 # S2 /papers batch endpoint limit


def _chunk(items: list[str], size: int) -> list[list[str]]:
@@ -35,9 +35,13 @@ def run(state: PipelineState, settings: Settings) -> PipelineState:
return state.model_copy(update={"current_stage": "enrichment"})

if settings.s2_api_key:
scholar = SemanticScholar(api_key=settings.s2_api_key)
scholar = SemanticScholar(
api_key=settings.s2_api_key,
timeout=settings.s2_timeout,
retry=False,
)
else:
scholar = SemanticScholar()
scholar = SemanticScholar(timeout=settings.s2_timeout, retry=False)

papers_by_id = {paper.paper_id: paper for paper in state.candidates}
for batch in _chunk(list(papers_by_id), BATCH_SIZE):
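`_chunk`'s body is elided in this hunk; a plausible implementation, plus the batching arithmetic the new comment implies (1,200 candidate IDs against the 500-ID limit means three batch requests):

```python
def _chunk(items: list[str], size: int) -> list[list[str]]:
    # Assumed body for the elided helper above.
    return [items[i : i + size] for i in range(0, len(items), size)]


ids = [f"paper-{i}" for i in range(1200)]
assert [len(batch) for batch in _chunk(ids, 500)] == [500, 500, 200]
```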
2 changes: 2 additions & 0 deletions src/litresearch/stages/export.py
@@ -157,6 +157,8 @@ def run(state: PipelineState, settings: Settings) -> PipelineState:
for paper in track(top_papers, description="Downloading PDFs"):
if not paper.open_access_pdf_url:
continue
if paper.pdf_downloaded:
continue
pdf_bytes = download_pdf(paper.open_access_pdf_url)
if pdf_bytes is None:
continue
7 changes: 5 additions & 2 deletions src/litresearch/stages/query_gen.py
@@ -3,7 +3,7 @@
import json

from litresearch.config import Settings
from litresearch.llm import call_llm
from litresearch.llm import LLMError, call_llm
from litresearch.models import Facet, PipelineState, SearchQuery
from litresearch.prompts import load_prompt

@@ -14,7 +14,10 @@ def run(state: PipelineState, settings: Settings) -> PipelineState:
user_content = "Research questions:\n" + "\n".join(
f"- {question}" for question in state.questions
)
response = call_llm(settings, prompt, user_content)
try:
response = call_llm(settings, prompt, user_content)
except LLMError as exc:
raise LLMError(f"Query generation failed: {exc}") from exc
payload = json.loads(response)

facets = [Facet.model_validate(item) for item in payload.get("facets", [])]
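Unlike screening and analysis, query generation has no per-item fallback, so the error is re-raised with stage context instead of swallowed. Using `raise ... from exc` keeps the original exception attached as `__cause__`, so the full chain still shows up in tracebacks; a minimal self-contained illustration:

```python
class LLMError(Exception):  # stand-in for litresearch.llm.LLMError
    pass


try:
    try:
        raise LLMError("rate limited by provider")
    except LLMError as exc:
        raise LLMError(f"Query generation failed: {exc}") from exc
except LLMError as wrapped:
    assert isinstance(wrapped.__cause__, LLMError)
    print(wrapped)  # Query generation failed: rate limited by provider
```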
58 changes: 58 additions & 0 deletions tests/unit/test_analysis.py
@@ -0,0 +1,58 @@
import json

from litresearch.config import Settings
from litresearch.models import Paper, PipelineState, ScreeningResult
from litresearch.stages.analysis import run


def test_analysis_saves_pdf_and_marks_candidate_downloaded(tmp_path, monkeypatch) -> None:
state = PipelineState(
questions=["q"],
candidates=[
Paper(
paper_id="p1",
title="One",
abstract="abstract",
open_access_pdf_url="https://example.com/p1.pdf",
)
],
ranked_paper_ids=[],
current_stage="enrichment",
output_dir=str(tmp_path),
created_at="2026-03-09T16:00:00Z",
updated_at="2026-03-09T16:00:00Z",
)

import litresearch.stages.analysis as analysis_stage

monkeypatch.setattr(analysis_stage, "load_prompt", lambda _name: "prompt")
monkeypatch.setattr(
analysis_stage,
"_screen_paper",
lambda paper, questions, settings, prompt: ScreeningResult(
paper_id=paper.paper_id,
relevance_score=100,
rationale="fit",
),
)
monkeypatch.setattr(analysis_stage, "download_pdf", lambda _url: b"%PDF-1.0")
monkeypatch.setattr(analysis_stage, "extract_text", lambda *_args, **_kwargs: "body")
monkeypatch.setattr(
analysis_stage,
"call_llm",
lambda settings, system_prompt, user_content: json.dumps(
{
"summary": "summary",
"key_findings": ["finding"],
"methodology": "experiment",
"relevance_score": 80,
"relevance_rationale": "fit",
}
),
)

updated_state = run(state, Settings())

assert updated_state.candidates[0].pdf_downloaded is True
assert (tmp_path / "papers" / "p1.pdf").read_bytes() == b"%PDF-1.0"
assert len(updated_state.analyses) == 1
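Note that the test patches attributes on `litresearch.stages.analysis` itself rather than on the modules where `call_llm` and `download_pdf` are defined. Assuming the stage imports those names directly (`from litresearch.llm import call_llm`, and so on), it holds its own references after import, so only the stage module's attributes matter at call time:

```python
# Patching the defining module would not reach the stage's local binding:
#   monkeypatch.setattr("litresearch.llm.call_llm", fake)   # no effect on the stage
#   monkeypatch.setattr(analysis_stage, "call_llm", fake)   # what this test does
```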
1 change: 1 addition & 0 deletions tests/unit/test_cli.py
@@ -28,6 +28,7 @@ def test_run_help_shows_expected_options() -> None:
assert "final top-N cutoff" in output
assert "output directory" in output
assert "screening threshold" in output
assert "Overwrite existing output directory" in output


def test_resume_help_shows_expected_options() -> None: