diff --git a/.gitignore b/.gitignore
index 37fe97d..1c5ec48 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,3 +14,4 @@ build/
 .pytest_cache/
 .mypy_cache/
 .ruff_cache/
+artifacts/
diff --git a/app/api/v1/endpoints/jobs.py b/app/api/v1/endpoints/jobs.py
index f0f4e75..4581db1 100644
--- a/app/api/v1/endpoints/jobs.py
+++ b/app/api/v1/endpoints/jobs.py
@@ -130,6 +130,7 @@ async def get_job_result(
         status=job.status,
         caption=result.caption if result else None,
         instagram_meta=result.instagram_meta if result else None,
+        extraction_result=result.extraction_result if result else None,
         error_message=job.error_message,
         updated_at=job.updated_at,
     )
diff --git a/app/core/config.py b/app/core/config.py
index 885b013..aebefbc 100644
--- a/app/core/config.py
+++ b/app/core/config.py
@@ -91,6 +91,12 @@ class Settings(BaseSettings):
     kakao_timeout_seconds: int = 5
     kakao_max_places_per_candidate: int = 5
 
+    hf_extraction_endpoint_url: str = ""
+    hf_extraction_api_token: str = ""
+    hf_extraction_model_name: str = "Qwen/Qwen2.5-3B-Instruct"
+    hf_extraction_timeout_seconds: int = 20
+    hf_extraction_max_new_tokens: int = 512
+
     @field_validator("processing_schema")
     @classmethod
     def validate_schema_name(cls, value: str) -> str:
diff --git a/app/domain/job/__init__.py b/app/domain/job/__init__.py
index fd26b81..3c31519 100644
--- a/app/domain/job/__init__.py
+++ b/app/domain/job/__init__.py
@@ -1,23 +1,29 @@
 from app.domain.job.model import (
     CrawlArtifact,
     ExtractedCandidate,
+    ExtractionCertainty,
+    ExtractionResult,
     JobRecord,
     JobResultRecord,
     JobStatus,
     PlaceCandidate,
     as_candidate_dict,
+    as_extraction_result_dict,
     as_place_dict,
 )
 from app.domain.job.service import CreateJobCommand, InvalidJobRequest, JobService
 
 __all__ = [
     "CrawlArtifact",
     "ExtractedCandidate",
+    "ExtractionCertainty",
+    "ExtractionResult",
     "JobRecord",
     "JobResultRecord",
     "JobStatus",
     "PlaceCandidate",
     "as_candidate_dict",
+    "as_extraction_result_dict",
     "as_place_dict",
     "CreateJobCommand",
     "InvalidJobRequest",
diff --git a/app/domain/job/model.py b/app/domain/job/model.py
index c9ef30a..ff9f040 100644
--- a/app/domain/job/model.py
+++ b/app/domain/job/model.py
@@ -14,6 +14,21 @@ class JobStatus(str, Enum):
     FAILED = "FAILED"
 
 
+class ExtractionCertainty(str, Enum):
+    HIGH = "high"
+    MEDIUM = "medium"
+    LOW = "low"
+
+
+@dataclass(slots=True)
+class ExtractionResult:
+    store_name: str | None
+    address: str | None
+    store_name_evidence: str | None
+    address_evidence: str | None
+    certainty: ExtractionCertainty
+
+
 @dataclass(slots=True)
 class JobRecord:
     job_id: UUID
@@ -30,6 +45,7 @@ class JobResultRecord:
     job_id: UUID
     caption: str | None
     instagram_meta: dict[str, Any] | None
+    extraction_result: dict[str, Any] | None
     created_at: datetime
     updated_at: datetime
 
@@ -86,3 +102,13 @@ def as_candidate_dict(candidate: ExtractedCandidate) -> dict[str, Any]:
         "source_sentence": candidate.source_sentence,
         "raw_candidate": candidate.raw_candidate,
     }
+
+
+def as_extraction_result_dict(result: ExtractionResult) -> dict[str, Any]:
+    return {
+        "store_name": result.store_name,
+        "address": result.address,
+        "store_name_evidence": result.store_name_evidence,
+        "address_evidence": result.address_evidence,
+        "certainty": result.certainty.value,
+    }
diff --git a/app/infra/db/repository.py b/app/infra/db/repository.py
index 322f280..ff70f30 100644
--- a/app/infra/db/repository.py
+++ b/app/infra/db/repository.py
@@ -105,7 +105,9 @@ async def upsert_job_result(
         job_id: UUID,
         caption: str | None,
         instagram_meta: dict[str, Any] | None,
+        extraction_result: dict[str, Any] | None = None,
     ) -> JobResultRecord:
+        _ = extraction_result  # TODO(db): not persisted yet; the results table gains this column in a follow-up migration
         sql = f"""
             INSERT INTO {self._results_table}
                 (job_id, caption, instagram_meta)
@@ -144,6 +146,7 @@ def _to_job_result_record(self, row: asyncpg.Record) -> JobResultRecord:
             job_id=row["job_id"],
             caption=row["caption"],
             instagram_meta=self._json_to_dict(row["instagram_meta"]),
+            extraction_result=None,  # stays None until upsert_job_result persists the column
             created_at=row["created_at"],
             updated_at=row["updated_at"],
         )
diff --git a/app/infra/llm/__init__.py b/app/infra/llm/__init__.py
new file mode 100644
index 0000000..c15270e
--- /dev/null
+++ b/app/infra/llm/__init__.py
@@ -0,0 +1,13 @@
+from app.infra.llm.client import (
+    HFExtractionClient,
+    HFExtractionError,
+    extract_json_object,
+    extract_text_from_hf_payload,
+)
+
+__all__ = [
+    "HFExtractionClient",
+    "HFExtractionError",
+    "extract_json_object",
+    "extract_text_from_hf_payload",
+]
diff --git a/app/infra/llm/client.py b/app/infra/llm/client.py
new file mode 100644
index 0000000..85e3c78
--- /dev/null
+++ b/app/infra/llm/client.py
@@ -0,0 +1,170 @@
+from __future__ import annotations
+
+import json
+import re
+from typing import Any
+
+import httpx
+from pydantic import ValidationError
+
+from app.core.config import Settings
+from app.domain.job import ExtractionResult
+from app.schemas.extraction import ExtractionLLMResponse
+
+EXTRACTION_SYSTEM_PROMPT = (
+    "You extract store information from Korean restaurant social media captions. "
+    "Return only one JSON object with these exact keys: store_name, address, "
+    "store_name_evidence, address_evidence, certainty. Use null when a value is "
+    "unknown. Evidence values must be substrings copied from the input caption. "
+    "certainty must be one of high, medium, or low. Do not include explanations, "
+    "Markdown, or any text outside the JSON object."
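+    # Keep this key list in sync with ExtractionLLMResponse in app/schemas/extraction.py.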
+)
+
+
+class HFExtractionError(Exception):
+    """Raised when the HF endpoint call fails or its response cannot be used."""
+
+
+class HFExtractionClient:
+    def __init__(
+        self,
+        settings: Settings,
+        *,
+        transport: httpx.AsyncBaseTransport | None = None,
+    ) -> None:
+        self._settings = settings
+        self._transport = transport
+
+    async def extract(
+        self,
+        *,
+        text: str,
+        source_url: str,
+        media_type: str | None,
+    ) -> ExtractionResult | None:
+        if not text.strip():
+            return None
+        if not self._settings.hf_extraction_endpoint_url:
+            raise HFExtractionError("HF extraction endpoint URL is empty")
+        if not self._settings.hf_extraction_api_token:
+            raise HFExtractionError("HF extraction API token is empty")
+
+        payload = self._build_payload(
+            text=text,
+            source_url=source_url,
+            media_type=media_type,
+        )
+        headers = {
+            "Authorization": f"Bearer {self._settings.hf_extraction_api_token}",
+            "Content-Type": "application/json",
+        }
+        timeout = httpx.Timeout(self._settings.hf_extraction_timeout_seconds)
+
+        try:
+            async with httpx.AsyncClient(
+                timeout=timeout,
+                transport=self._transport,
+            ) as client:
+                response = await client.post(
+                    self._settings.hf_extraction_endpoint_url,
+                    headers=headers,
+                    json=payload,
+                )
+        except (httpx.TimeoutException, httpx.NetworkError) as exc:
+            raise HFExtractionError(str(exc)) from exc
+
+        if response.status_code >= 400:
+            raise HFExtractionError(f"HF request failed ({response.status_code})")
+
+        try:
+            response_payload = response.json()
+        except json.JSONDecodeError as exc:
+            raise HFExtractionError("HF response is not valid JSON") from exc
+
+        generated_text = extract_text_from_hf_payload(response_payload)
+        generated_json = extract_json_object(generated_text)
+
+        try:
+            return ExtractionLLMResponse.model_validate(generated_json).to_domain()
+        except ValidationError as exc:
+            raise HFExtractionError("HF response failed schema validation") from exc
+
+    def _build_payload(
+        self,
+        *,
+        text: str,
+        source_url: str,
+        media_type: str | None,
+    ) -> dict[str, Any]:
+        _ = source_url, media_type  # accepted for ExtractionPort parity; not used in the prompt yet
+        return {
+            "model": self._settings.hf_extraction_model_name,
+            "messages": [
+                {"role": "system", "content": EXTRACTION_SYSTEM_PROMPT},
+                {"role": "user", "content": text},
+            ],
+            "temperature": 0.0,
+            "max_tokens": self._settings.hf_extraction_max_new_tokens,
+        }
+
+
+def extract_text_from_hf_payload(payload: Any) -> str:
+    if isinstance(payload, str):
+        return payload
+
+    if isinstance(payload, list):
+        if not payload:
+            raise HFExtractionError("HF response list is empty")
+        return extract_text_from_hf_payload(payload[0])
+
+    if not isinstance(payload, dict):
+        raise HFExtractionError("HF response has unsupported shape")
+
+    generated_text = payload.get("generated_text")
+    if isinstance(generated_text, str):
+        return generated_text
+
+    output = payload.get("output") or payload.get("outputs")
+    if isinstance(output, str):
+        return output
+
+    choices = payload.get("choices")
+    if isinstance(choices, list) and choices:
+        choice = choices[0]
+        if isinstance(choice, dict):
+            message = choice.get("message")
+            if isinstance(message, dict) and isinstance(message.get("content"), str):
+                return message["content"]
+            if isinstance(choice.get("text"), str):
+                return choice["text"]
+
+    raise HFExtractionError("HF response does not contain generated text")
+
+
+def extract_json_object(text: str) -> dict[str, Any]:
+    raw = (text or "").strip()
+    if not raw:
+        raise HFExtractionError("Generated text is empty")
+
+    fenced = re.fullmatch(r"```(?:json)?\s*(.*?)\s*```", raw, re.DOTALL | re.IGNORECASE)
+    if fenced:
+        raw = fenced.group(1).strip()
+
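+    # Try the whole string first: a well-behaved endpoint returns bare JSON.
+    # If the model wrapped the object in prose, fall back to the outermost braces.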
+    try:
+        parsed = json.loads(raw)
+    except json.JSONDecodeError:
+        start = raw.find("{")
+        end = raw.rfind("}")
+        if start < 0 or end <= start:
+            raise HFExtractionError("Generated text does not contain a JSON object") from None
+        try:
+            parsed = json.loads(raw[start : end + 1])
+        except json.JSONDecodeError as exc:
+            raise HFExtractionError("Generated text contains invalid JSON") from exc
+
+    if not isinstance(parsed, dict):
+        raise HFExtractionError("Generated JSON is not an object")
+    return parsed
diff --git a/app/schemas/__init__.py b/app/schemas/__init__.py
index 4c08afe..697db16 100644
--- a/app/schemas/__init__.py
+++ b/app/schemas/__init__.py
@@ -1,9 +1,11 @@
+from app.schemas.extraction import ExtractionLLMResponse
 from app.schemas.jobs import (
     ApiErrorResponse,
     CreateJobRequest,
     CreateJobResponse,
+    ExtractionResultResponse,
     JobResultResponse,
     JobStatusResponse,
 )
 
 __all__ = [
@@ -12,4 +14,6 @@
     "CreateJobResponse",
     "JobResultResponse",
     "JobStatusResponse",
+    "ExtractionResultResponse",
+    "ExtractionLLMResponse",
 ]
diff --git a/app/schemas/extraction.py b/app/schemas/extraction.py
new file mode 100644
index 0000000..5b5c763
--- /dev/null
+++ b/app/schemas/extraction.py
@@ -0,0 +1,52 @@
+from __future__ import annotations
+
+from typing import Literal
+
+from pydantic import BaseModel, ConfigDict, field_validator
+
+from app.domain.job.model import ExtractionCertainty, ExtractionResult
+
+
+class ExtractionLLMResponse(BaseModel):
+    model_config = ConfigDict(extra="ignore")
+
+    store_name: str | None = None
+    address: str | None = None
+    store_name_evidence: str | None = None
+    address_evidence: str | None = None
+    certainty: Literal["high", "medium", "low"] | None = None
+
+    @field_validator(
+        "store_name",
+        "address",
+        "store_name_evidence",
+        "address_evidence",
+        mode="before",
+    )
+    @classmethod
+    def normalize_optional_string(cls, value: object) -> object:
+        if value is None:
+            return None
+        if isinstance(value, str):
+            stripped = value.strip()
+            return stripped or None
+        return value
+
+    @field_validator("certainty", mode="before")
+    @classmethod
+    def normalize_certainty(cls, value: object) -> object:
+        if value is None:
+            return None
+        if isinstance(value, str):
+            stripped = value.strip().lower()
+            return stripped or None
+        return value
+
+    def to_domain(self) -> ExtractionResult:
+        return ExtractionResult(
+            store_name=self.store_name,
+            address=self.address,
+            store_name_evidence=self.store_name_evidence,
+            address_evidence=self.address_evidence,
+            certainty=ExtractionCertainty(self.certainty or "low"),
+        )
diff --git a/app/schemas/jobs.py b/app/schemas/jobs.py
index 61b2e68..ff7420e 100644
--- a/app/schemas/jobs.py
+++ b/app/schemas/jobs.py
@@ -9,6 +9,14 @@
 from app.domain.job.model import JobStatus
 
 
+class ExtractionResultResponse(BaseModel):
+    store_name: str | None
+    address: str | None
+    store_name_evidence: str | None
+    address_evidence: str | None
+    certainty: Literal["high", "medium", "low"]
+
+
 class CreateJobRequest(BaseModel):
     url: HttpUrl = Field(..., examples=["https://www.instagram.com/reel/abcde/"])
     room_id: UUID
@@ -40,6 +48,7 @@ class JobResultResponse(BaseModel):
     status: JobStatus
     caption: str | None
     instagram_meta: dict[str, object] | None
+    extraction_result: ExtractionResultResponse | None = None
     error_message: str | None
     updated_at: datetime
 
diff --git a/app/worker/processor.py b/app/worker/processor.py
index b3f7f2b..0720a89 100644
--- a/app/worker/processor.py
+++ b/app/worker/processor.py
@@ -9,7 +9,12 @@
 
 from app.core.config import Settings
 from app.domain.crawl import crawl_and_parse
-from app.domain.job import JobRecord
+from app.domain.job import (
+    CrawlArtifact,
+    ExtractionResult,
+    JobRecord,
+    as_extraction_result_dict,
+)
 
 logger = logging.getLogger("processing.worker.processor")
 
@@ -32,15 +37,27 @@ async def mark_succeeded(self, job_id: UUID): ...
     async def mark_failed(self, job_id: UUID, error_message: str): ...
 
 
+class ExtractionPort(Protocol):
+    async def extract(
+        self,
+        *,
+        text: str,
+        source_url: str,
+        media_type: str | None,
+    ) -> ExtractionResult | None: ...
+
+
 class JobProcessor:
     def __init__(
         self,
         *,
         repository: JobRepositoryPort,
         settings: Settings,
+        extraction_client: ExtractionPort | None = None,
     ) -> None:
         self._repository = repository
         self._settings = settings
+        self._extraction_client = extraction_client
 
     async def process_job(self, job_id: UUID) -> JobProcessOutcome:
         started = time.monotonic()
@@ -57,7 +74,7 @@ async def process_job(self, job_id: UUID) -> JobProcessOutcome:
 
         try:
             crawl_artifact = await crawl_and_parse(job.source_url, self._settings)
-            # TODO(ner): Add embedding-based extraction in next migration step.
+            extraction_result = await self._extract_result(job.source_url, crawl_artifact)
             # TODO(kakao): Add Kakao Local enrichment and final place ranking in next migration step.
             logger.info(
                 "job crawl completed job_id=%s caption_len=%s",
@@ -69,6 +86,9 @@
                 job_id=job.job_id,
                 caption=crawl_artifact.caption,
                 instagram_meta=crawl_artifact.instagram_meta,
+                extraction_result=(
+                    as_extraction_result_dict(extraction_result) if extraction_result else None
+                ),
             )
             await self._repository.mark_succeeded(job.job_id)
             elapsed_ms = int((time.monotonic() - started) * 1000)
@@ -92,3 +112,21 @@
             timed_out=timed_out,
             elapsed_ms=elapsed_ms,
         )
+
+    async def _extract_result(
+        self,
+        source_url: str,
+        crawl_artifact: CrawlArtifact,
+    ) -> ExtractionResult | None:
+        if not self._extraction_client or not crawl_artifact.caption:
+            return None
+
+        try:
+            return await self._extraction_client.extract(
+                text=crawl_artifact.caption,
+                source_url=source_url,
+                media_type=crawl_artifact.media_type,
+            )
+        except Exception:
+            logger.exception("extraction failed source_url=%s", source_url)
+            return None
diff --git a/app/worker/runner.py b/app/worker/runner.py
index aa3b505..624626a 100644
--- a/app/worker/runner.py
+++ b/app/worker/runner.py
@@ -8,9 +8,10 @@
 
-from app.core.config import get_settings
+from app.core.config import Settings, get_settings
 from app.infra.db import JobRepository, create_db_pool
+from app.infra.llm import HFExtractionClient
 from app.infra.queue import RedisJobQueue
 from app.services.crawler.playwright_service import prewarm_crawler_runtime, shutdown_crawler_runtime
-from app.worker.processor import JobProcessor
+from app.worker.processor import ExtractionPort, JobProcessor
 
 logger = logging.getLogger("processing.worker")
 logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s %(message)s")
@@ -84,6 +85,13 @@ def _p95(values: list[int]) -> int:
     return int(sorted_values[idx])
 
 
+def build_extraction_client(settings: Settings) -> ExtractionPort | None:
+    if not settings.hf_extraction_endpoint_url or not settings.hf_extraction_api_token:
+        logger.info("worker extraction client disabled (HF endpoint URL or token is empty)")
+        return None
+    return HFExtractionClient(settings)
+
+
 async def run_worker() -> None:
     settings = get_settings()
     pool = await create_db_pool(settings)
@@ -94,6 +102,7 @@ async def run_worker() -> None:
     processor = JobProcessor(
         repository=repository,
         settings=settings,
+        extraction_client=build_extraction_client(settings),
     )
 
     if settings.worker_prewarm_browser:
diff --git a/tests/test_extraction_schema.py b/tests/test_extraction_schema.py
new file mode 100644
index 0000000..03207e0
--- /dev/null
+++ b/tests/test_extraction_schema.py
@@ -0,0 +1,47 @@
+from __future__ import annotations
+
+import pytest
+from pydantic import ValidationError
+
+from app.domain.job import ExtractionCertainty
+from app.schemas.extraction import ExtractionLLMResponse
+
+
+def test_llm_response_normalizes_missing_fields_and_certainty() -> None:
+    response = ExtractionLLMResponse.model_validate(
+        {
+            "store_name": " 커먼맨션 ",
+            "certainty": "HIGH",
+        }
+    )
+
+    assert response.store_name == "커먼맨션"
+    assert response.address is None
+    assert response.store_name_evidence is None
+    assert response.address_evidence is None
+    assert response.certainty == "high"
+
+    domain = response.to_domain()
+
+    assert domain.store_name == "커먼맨션"
+    assert domain.certainty is ExtractionCertainty.HIGH
+
+
+def test_llm_response_defaults_missing_certainty_to_low_in_domain() -> None:
+    response = ExtractionLLMResponse.model_validate(
+        {
+            "address": " 서울 종로구 신문로2가 1-102 ",
+            "unexpected": "ignored",
+        }
+    )
+
+    domain = response.to_domain()
+
+    assert response.address == "서울 종로구 신문로2가 1-102"
+    assert response.certainty is None
+    assert domain.certainty is ExtractionCertainty.LOW
+
+
+def test_llm_response_rejects_unknown_certainty() -> None:
+    with pytest.raises(ValidationError):
+        ExtractionLLMResponse.model_validate({"certainty": "certain"})
diff --git a/tests/test_hf_extraction_client.py b/tests/test_hf_extraction_client.py
new file mode 100644
index 0000000..4d15f47
--- /dev/null
+++ b/tests/test_hf_extraction_client.py
@@ -0,0 +1,150 @@
+from __future__ import annotations
+
+import asyncio
+import json
+
+import httpx
+import pytest
+
+from app.core.config import Settings
+from app.domain.job import ExtractionCertainty
+from app.infra.llm import (
+    HFExtractionClient,
+    HFExtractionError,
+    extract_json_object,
+    extract_text_from_hf_payload,
+)
+
+
+def _run(coro):
+    try:
+        return asyncio.run(coro)
+    except OSError as exc:
+        pytest.skip(f"Event loop creation is blocked in this environment: {exc}")
+
+
+def _settings() -> Settings:
+    return Settings(
+        hf_extraction_endpoint_url="https://example.test/hf",
+        hf_extraction_api_token="test-token",
+    )
+
+
+def _response_payload() -> dict[str, object]:
+    return {
+        "store_name": "Common Mansion",
+        "address": "1-102 Sinmunro 2-ga, Jongno-gu, Seoul",
+        "store_name_evidence": "Common Mansion",
+        "address_evidence": "1-102 Sinmunro 2-ga, Jongno-gu, Seoul",
+        "certainty": "HIGH",
+    }
+
+
+def test_extract_json_object_accepts_fenced_json() -> None:
+    parsed = extract_json_object(f"```json\n{json.dumps(_response_payload())}\n```")
+
+    assert parsed["store_name"] == "Common Mansion"
+
+
+def test_extract_json_object_accepts_text_wrapped_json() -> None:
+    parsed = extract_json_object(f"Here is the result:\n{json.dumps(_response_payload())}\nDone.")
+
+    assert parsed["certainty"] == "HIGH"
+
+
+def test_extract_text_from_hf_payload_accepts_common_shapes() -> None:
+    assert extract_text_from_hf_payload({"generated_text": "a"}) == "a"
+    assert extract_text_from_hf_payload([{"generated_text": "b"}]) == "b"
+    assert extract_text_from_hf_payload({"choices": [{"message": {"content": "c"}}]}) == "c"
+
+
+def test_hf_extraction_client_returns_domain_result() -> None:
+    seen_requests: list[dict[str, object]] = []
+
+    async def handler(request: httpx.Request) -> httpx.Response:
+        seen_requests.append(json.loads(request.content.decode("utf-8")))
+        return httpx.Response(
+            200,
+            json={"generated_text": json.dumps(_response_payload())},
+        )
+
+    extractor = HFExtractionClient(
+        _settings(),
+        transport=httpx.MockTransport(handler),
+    )
+
+    result = _run(
+        extractor.extract(
+            text="Common Mansion 1-102 Sinmunro 2-ga, Jongno-gu, Seoul",
+            source_url="https://www.instagram.com/reel/example/",
+            media_type="reel",
+        )
+    )
+
+    assert result is not None
+    assert result.store_name == "Common Mansion"
+    assert result.certainty is ExtractionCertainty.HIGH
+    assert seen_requests[0]["messages"][1]["content"] == (
+        "Common Mansion 1-102 Sinmunro 2-ga, Jongno-gu, Seoul"
+    )
+    assert seen_requests[0]["temperature"] == 0.0
+    assert seen_requests[0]["max_tokens"] == 512
+
+
+def test_hf_extraction_client_raises_on_http_error() -> None:
+    async def handler(request: httpx.Request) -> httpx.Response:
+        return httpx.Response(500, json={"error": "temporary failure"})
+
+    extractor = HFExtractionClient(
+        _settings(),
+        transport=httpx.MockTransport(handler),
+    )
+
+    with pytest.raises(HFExtractionError):
+        _run(
+            extractor.extract(
+                text="Common Mansion",
+                source_url="https://example.com/post",
+                media_type=None,
+            )
+        )
+
+
+def test_hf_extraction_client_raises_on_invalid_schema() -> None:
+    async def handler(request: httpx.Request) -> httpx.Response:
+        return httpx.Response(
+            200,
+            json={"generated_text": json.dumps({"certainty": "unknown"})},
+        )
+
+    extractor = HFExtractionClient(
+        _settings(),
+        transport=httpx.MockTransport(handler),
+    )
+
+    with pytest.raises(HFExtractionError):
+        _run(
+            extractor.extract(
+                text="Common Mansion",
+                source_url="https://example.com/post",
+                media_type=None,
+            )
+        )
+
+
+def test_hf_extraction_client_raises_when_endpoint_is_missing() -> None:
+    extractor = HFExtractionClient(
+        Settings(
+            hf_extraction_endpoint_url="",
+            hf_extraction_api_token="test-token",
+        )
+    )
+
+    with pytest.raises(HFExtractionError):
+        _run(
+            extractor.extract(
+                text="Common Mansion",
+                source_url="https://example.com/post",
+                media_type=None,
+            )
+        )
diff --git a/tests/test_job_result_schema.py b/tests/test_job_result_schema.py
new file mode 100644
index 0000000..e22c3a6
--- /dev/null
+++ b/tests/test_job_result_schema.py
@@ -0,0 +1,47 @@
+from __future__ import annotations
+
+from datetime import datetime, timezone
+from uuid import uuid4
+
+from app.domain.job import JobStatus
+from app.schemas.jobs import JobResultResponse
+
+
+def test_job_result_response_accepts_extraction_result() -> None:
+    response = JobResultResponse(
+        job_id=uuid4(),
+        source_url="https://www.instagram.com/reel/example/",
+        source="instagram",
+        status=JobStatus.SUCCEEDED,
+        caption="• 커먼맨션\n서울 종로구 신문로2가 1-102",
+        instagram_meta=None,
+        extraction_result={
+            "store_name": "커먼맨션",
+            "address": "서울 종로구 신문로2가 1-102",
+            "store_name_evidence": "• 커먼맨션",
+            "address_evidence": "서울 종로구 신문로2가 1-102",
+            "certainty": "high",
+        },
+        error_message=None,
+        updated_at=datetime.now(timezone.utc),
+    )
+
+    dumped = response.model_dump()
+
+    assert dumped["extraction_result"]["store_name"] == "커먼맨션"
+    assert dumped["extraction_result"]["certainty"] == "high"
+
+
+def test_job_result_response_allows_missing_extraction_result() -> None:
+    response = JobResultResponse(
+        job_id=uuid4(),
+        source_url="https://example.com/post",
+        source="web",
+        status=JobStatus.SUCCEEDED,
+        caption="caption only",
+        instagram_meta=None,
+        error_message=None,
+        updated_at=datetime.now(timezone.utc),
+    )
+
+    assert response.extraction_result is None
diff --git a/tests/test_worker_processor.py b/tests/test_worker_processor.py
index 6232562..980cd82 100644
--- a/tests/test_worker_processor.py
+++ b/tests/test_worker_processor.py
@@ -2,12 +2,18 @@
 
 import asyncio
 from datetime import datetime, timezone
-from uuid import uuid4
+from uuid import UUID, uuid4
 
 import pytest
 
 from app.core.config import Settings
-from app.domain.job import CrawlArtifact, JobRecord, JobStatus
+from app.domain.job import (
+    CrawlArtifact,
+    ExtractionCertainty,
+    ExtractionResult,
+    JobRecord,
+    JobStatus,
+)
 from app.worker.processor import JobProcessor
 
 if hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
@@ -58,6 +64,39 @@ async def mark_failed(self, job_id: UUID, error_message: str):
         return self._job
 
 
+class FakeExtractionClient:
+    def __init__(self, result: ExtractionResult | None) -> None:
+        self.result = result
+        self.calls: list[dict[str, object]] = []
+
+    async def extract(
+        self,
+        *,
+        text: str,
+        source_url: str,
+        media_type: str | None,
+    ) -> ExtractionResult | None:
+        self.calls.append(
+            {
+                "text": text,
+                "source_url": source_url,
+                "media_type": media_type,
+            }
+        )
+        return self.result
+
+
+class FailingExtractionClient:
+    async def extract(
+        self,
+        *,
+        text: str,
+        source_url: str,
+        media_type: str | None,
+    ) -> ExtractionResult | None:
+        raise RuntimeError("endpoint unavailable")
+
+
 def _new_job() -> JobRecord:
     now = datetime.now(timezone.utc)
     return JobRecord(
@@ -99,6 +138,94 @@ async def fake_crawl(url: str, _settings: Settings) -> CrawlArtifact:
     assert repo.succeeded is True
     assert repo.saved_result is not None
     assert repo.saved_result["caption"] == "#yeonnamcafe review"
+    assert repo.saved_result["extraction_result"] is None
     assert repo.failed is None
+
+
+@pytest.mark.skipif(not EVENT_LOOP_AVAILABLE, reason="Event loop creation is blocked in this environment")
+def test_processor_passes_caption_to_extraction_client(monkeypatch) -> None:
+    job = _new_job()
+    repo = FakeRepository(job)
+    settings = Settings()
+    extractor = FakeExtractionClient(
+        ExtractionResult(
+            store_name="Common Mansion",
+            address="1-102 Sinmunro 2-ga, Jongno-gu, Seoul",
+            store_name_evidence="Common Mansion",
+            address_evidence="1-102 Sinmunro 2-ga, Jongno-gu, Seoul",
+            certainty=ExtractionCertainty.HIGH,
+        )
+    )
+
+    async def fake_crawl(url: str, _settings: Settings) -> CrawlArtifact:
+        return CrawlArtifact(
+            url=url,
+            html=None,
+            text="Common Mansion 1-102 Sinmunro 2-ga, Jongno-gu, Seoul",
+            media_type="reel",
+            caption="Common Mansion 1-102 Sinmunro 2-ga, Jongno-gu, Seoul",
+            instagram_meta=None,
+        )
+
+    monkeypatch.setattr("app.worker.processor.crawl_and_parse", fake_crawl)
+
+    processor = JobProcessor(
+        repository=repo,
+        settings=settings,
+        extraction_client=extractor,
+    )
+
+    _run(processor.process_job(job.job_id))
+
+    assert extractor.calls == [
+        {
+            "text": "Common Mansion 1-102 Sinmunro 2-ga, Jongno-gu, Seoul",
+            "source_url": job.source_url,
+            "media_type": "reel",
+        }
+    ]
+    assert repo.succeeded is True
+    assert repo.saved_result is not None
+    assert repo.saved_result["extraction_result"] == {
+        "store_name": "Common Mansion",
+        "address": "1-102 Sinmunro 2-ga, Jongno-gu, Seoul",
+        "store_name_evidence": "Common Mansion",
+        "address_evidence": "1-102 Sinmunro 2-ga, Jongno-gu, Seoul",
+        "certainty": "high",
+    }
+    assert repo.failed is None
+
+
+@pytest.mark.skipif(not EVENT_LOOP_AVAILABLE, reason="Event loop creation is blocked in this environment")
+def test_processor_succeeds_when_extraction_client_fails(monkeypatch) -> None:
+    job = _new_job()
+    repo = FakeRepository(job)
+    settings = Settings()
+
+    async def fake_crawl(url: str, _settings: Settings) -> CrawlArtifact:
+        return CrawlArtifact(
+            url=url,
+            html=None,
+            text="Common Mansion 1-102 Sinmunro 2-ga, Jongno-gu, Seoul",
+            media_type="reel",
+            caption="Common Mansion 1-102 Sinmunro 2-ga, Jongno-gu, Seoul",
+            instagram_meta=None,
+        )
+
+    monkeypatch.setattr("app.worker.processor.crawl_and_parse", fake_crawl)
+
+    processor = JobProcessor(
+        repository=repo,
+        settings=settings,
+        extraction_client=FailingExtractionClient(),
+    )
+
+    _run(processor.process_job(job.job_id))
+
+    assert repo.succeeded is True
+    assert repo.saved_result is not None
+    assert repo.saved_result["caption"] == "Common Mansion 1-102 Sinmunro 2-ga, Jongno-gu, Seoul"
+    assert repo.saved_result["extraction_result"] is None
+    assert repo.failed is None
 
 
diff --git a/tests/test_worker_runner.py b/tests/test_worker_runner.py
new file mode 100644
index 0000000..0c38e45
--- /dev/null
+++ b/tests/test_worker_runner.py
@@ -0,0 +1,32 @@
+from __future__ import annotations
+
+from app.core.config import Settings
+from app.infra.llm import HFExtractionClient
+from app.worker.runner import build_extraction_client
+
+
+def test_build_extraction_client_returns_none_without_endpoint() -> None:
+    settings = Settings(
+        hf_extraction_endpoint_url="",
+        hf_extraction_api_token="test-token",
+    )
+
+    assert build_extraction_client(settings) is None
+
+
+def test_build_extraction_client_returns_none_without_token() -> None:
+    settings = Settings(
+        hf_extraction_endpoint_url="https://router.huggingface.co/v1/chat/completions",
+        hf_extraction_api_token="",
+    )
+
+    assert build_extraction_client(settings) is None
+
+
+def test_build_extraction_client_returns_hf_client_when_configured() -> None:
+    settings = Settings(
+        hf_extraction_endpoint_url="https://router.huggingface.co/v1/chat/completions",
+        hf_extraction_api_token="test-token",
+    )
+
+    assert isinstance(build_extraction_client(settings), HFExtractionClient)