Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions app/infra/db/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,16 +107,16 @@ async def upsert_job_result(
instagram_meta: dict[str, Any] | None,
extraction_result: dict[str, Any] | None = None,
) -> JobResultRecord:
_ = extraction_result
sql = f"""
INSERT INTO {self._results_table}
(job_id, caption, instagram_meta)
(job_id, caption, instagram_meta, extraction_result)
VALUES
($1, $2, $3::jsonb)
($1, $2, $3::jsonb, $4::jsonb)
ON CONFLICT (job_id)
DO UPDATE SET
caption = EXCLUDED.caption,
instagram_meta = EXCLUDED.instagram_meta,
extraction_result = EXCLUDED.extraction_result,
updated_at = NOW()
RETURNING *
"""
Expand All @@ -125,6 +125,7 @@ async def upsert_job_result(
job_id,
caption,
json.dumps(instagram_meta or {}),
json.dumps(extraction_result) if extraction_result is not None else None,
)
if row is None:
raise RuntimeError("Failed to upsert job result")
Expand All @@ -146,7 +147,7 @@ def _to_job_result_record(self, row: asyncpg.Record) -> JobResultRecord:
job_id=row["job_id"],
caption=row["caption"],
instagram_meta=self._json_to_dict(row["instagram_meta"]),
extraction_result=None,
extraction_result=self._json_to_dict(row["extraction_result"]),
created_at=row["created_at"],
updated_at=row["updated_at"],
)
Expand Down
48 changes: 48 additions & 0 deletions migrations/000_processing_schema.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
-- Migration 000: bootstrap the `processing` schema — jobs table, job_results
-- table, and a shared trigger that refreshes updated_at on every UPDATE.
--
-- NOTE(review): DROP SCHEMA ... CASCADE wipes all existing data in the schema
-- each time this file is re-applied; presumably a dev-only reset — confirm
-- before running against a database with persistent data.
DROP SCHEMA IF EXISTS processing CASCADE;

CREATE SCHEMA IF NOT EXISTS processing;

-- One row per processing job; status is constrained to four lifecycle states.
CREATE TABLE IF NOT EXISTS processing.jobs (
    job_id UUID PRIMARY KEY,
    room_id UUID NOT NULL,
    source_url TEXT NOT NULL,
    status VARCHAR(16) NOT NULL DEFAULT 'QUEUED',
    error_message TEXT,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    CONSTRAINT chk_processing_jobs_status
        CHECK (status IN ('QUEUED', 'PROCESSING', 'SUCCEEDED', 'FAILED'))
);

-- Supports queue polling: filter by status, newest first.
CREATE INDEX IF NOT EXISTS idx_processing_jobs_status_created_at
    ON processing.jobs (status, created_at DESC);

-- At most one result per job; removed automatically when the job is deleted.
CREATE TABLE IF NOT EXISTS processing.job_results (
    job_id UUID PRIMARY KEY REFERENCES processing.jobs(job_id) ON DELETE CASCADE,
    caption TEXT,
    instagram_meta JSONB NOT NULL DEFAULT '{}'::jsonb,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    -- Reject JSON arrays/scalars where an object is expected.
    CONSTRAINT chk_processing_job_results_instagram_meta_object
        CHECK (jsonb_typeof(instagram_meta) = 'object')
);

-- Trigger function: stamp updated_at with the current time before an UPDATE.
CREATE OR REPLACE FUNCTION processing.touch_updated_at()
RETURNS TRIGGER AS $$
BEGIN
    NEW.updated_at = NOW();
    RETURN NEW;
END;
$$ LANGUAGE plpgsql;

DROP TRIGGER IF EXISTS trg_processing_jobs_updated_at ON processing.jobs;
CREATE TRIGGER trg_processing_jobs_updated_at
    BEFORE UPDATE ON processing.jobs
    FOR EACH ROW
    EXECUTE FUNCTION processing.touch_updated_at();

DROP TRIGGER IF EXISTS trg_processing_job_results_updated_at ON processing.job_results;
CREATE TRIGGER trg_processing_job_results_updated_at
    BEFORE UPDATE ON processing.job_results
    FOR EACH ROW
    EXECUTE FUNCTION processing.touch_updated_at();
2 changes: 2 additions & 0 deletions migrations/001_add_extraction_result_to_job_results.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
-- Migration 001: add the JSONB column that stores the LLM extraction payload
-- for a job result. Nullable on purpose: rows written before this migration
-- (or jobs without an extraction) simply have no payload.
ALTER TABLE processing.job_results
    ADD COLUMN IF NOT EXISTS extraction_result JSONB;
99 changes: 99 additions & 0 deletions scripts/run_hf_extraction_samples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from __future__ import annotations

import argparse
import asyncio
import json
import sys
from pathlib import Path
from typing import Any

# Make the repository root importable so `app.*` resolves when this script is
# executed directly (e.g. `python scripts/run_hf_extraction_samples.py`).
ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT))

# These must come after the sys.path tweak, hence the E402 suppressions.
from app.core.config import Settings  # noqa: E402
from app.domain.job import as_extraction_result_dict  # noqa: E402
from app.infra.llm import HFExtractionClient  # noqa: E402

# Default sample input/output locations under artifacts/.
DEFAULT_INPUT_PATH = ROOT / "artifacts" / "hf_extraction_sample_inputs.json"
DEFAULT_OUTPUT_PATH = ROOT / "artifacts" / "hf_extraction_samples.json"


def _load_samples(path: Path) -> list[dict[str, Any]]:
with path.open("r", encoding="utf-8") as file:
raw = json.load(file)

if not isinstance(raw, list):
raise ValueError("Input JSON must be a list of captions or sample objects.")

samples: list[dict[str, Any]] = []
for index, item in enumerate(raw, start=1):
if isinstance(item, str):
samples.append({"id": index, "caption": item})
continue
if isinstance(item, dict) and isinstance(item.get("caption"), str):
samples.append(
{
"id": item.get("id", index),
"caption": item["caption"],
"source_url": item.get("source_url"),
"media_type": item.get("media_type", "reel"),
}
)
continue
raise ValueError(f"Sample #{index} must be a string or an object with a caption field.")
return samples


async def _run_samples(input_path: Path, output_path: Path) -> None:
    """Run HF extraction over every sample in *input_path* and write a JSON
    report (one entry per sample with either a prediction or an error) to
    *output_path*.
    """
    settings = Settings()
    client = HFExtractionClient(settings)
    report: list[dict[str, Any]] = []

    for sample in _load_samples(input_path):
        sample_id = sample["id"]
        caption = sample["caption"]
        print(f"[{sample_id}] extracting...", flush=True)

        entry: dict[str, Any] = {"id": sample_id, "caption": caption}
        try:
            prediction = await client.extract(
                text=caption,
                # Fall back to a synthetic reel URL when the sample has none.
                source_url=sample.get("source_url") or f"https://www.instagram.com/reel/sample-{sample_id}/",
                media_type=sample.get("media_type") or "reel",
            )
        except Exception as exc:  # noqa: BLE001
            # Best-effort batch run: record the failure and keep going.
            entry["prediction"] = None
            entry["error"] = f"{type(exc).__name__}: {exc}"
        else:
            entry["prediction"] = as_extraction_result_dict(prediction) if prediction else None
            entry["error"] = None
        report.append(entry)

    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open("w", encoding="utf-8") as file:
        json.dump(report, file, ensure_ascii=False, indent=2)
        file.write("\n")
    print(f"Wrote {len(report)} results to {output_path}", flush=True)


def main() -> None:
    """CLI entry point: parse input/output paths and run the async batch."""
    parser = argparse.ArgumentParser(description="Run HF extraction against local caption samples.")
    parser.add_argument("--input", type=Path, default=DEFAULT_INPUT_PATH)
    parser.add_argument("--output", type=Path, default=DEFAULT_OUTPUT_PATH)
    namespace = parser.parse_args()
    asyncio.run(_run_samples(namespace.input, namespace.output))

if __name__ == "__main__":
main()
64 changes: 64 additions & 0 deletions tests/test_hf_extraction_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,70 @@ async def handler(request: httpx.Request) -> httpx.Response:
assert seen_requests[0]["max_tokens"] == 512


def test_hf_extraction_client_accepts_long_realistic_caption() -> None:
    """A long, multi-paragraph Korean caption must be sent verbatim to the HF
    endpoint, and the JSON response must parse into store/address/certainty.
    """
    # Realistic Instagram reel caption (restaurant review) used as the fixture.
    long_caption = """실제 광화문 직장인 지인이 여기가 최고라고 소개해줘서 알게 된 집

실내 분위기 너무 좋았던 브런치 맛집 커먼맨션 입니다

샌드위치에 샐러드 파스타 이렇게 3종류로 크게 나눌 수 있는데 샌드위치 먹고 있으면 샌드위치 전문점인 거 같고

샐러드 먹으면 샐러드 파스타면 파스타

모든 메뉴가 전문점 수준으로 너무 맛있어서 정말 대만족했던 집 입니다

광화문 직장인 상권이다보니 점심시간에 가면 웨이팅이 심해서 못 먹고 올 수 있으니까

방문 예정이시면 꼭 예약을 미리 하고 가시는 걸 추천 드릴게요

실제 근처 직장인이시라면 점심 혹은 미팅 잡기도 정말 좋은 곳 일

거 같아요

• 커먼맨션

서울 종로구 신문로2가 1-102
10:00 - 21:00
20:00 라스트오더"""
    # Outgoing request payloads captured by the mock transport for inspection.
    seen_requests: list[dict[str, object]] = []

    async def handler(request: httpx.Request) -> httpx.Response:
        # Record the request body, then reply with a canned extraction result
        # wrapped in the endpoint's "generated_text" envelope.
        seen_requests.append(json.loads(request.content.decode("utf-8")))
        return httpx.Response(
            200,
            json={
                "generated_text": json.dumps(
                    {
                        "store_name": "커먼맨션",
                        "address": "서울 종로구 신문로2가 1-102",
                        "store_name_evidence": "커먼맨션",
                        "address_evidence": "서울 종로구 신문로2가 1-102",
                        "certainty": "high",
                    },
                    ensure_ascii=False,
                )
            },
        )

    extractor = HFExtractionClient(
        _settings(),
        transport=httpx.MockTransport(handler),
    )

    result = _run(
        extractor.extract(
            text=long_caption,
            source_url="https://www.instagram.com/reel/example/",
            media_type="reel",
        )
    )

    # Parsed fields must match the canned response exactly.
    assert result is not None
    assert result.store_name == "커먼맨션"
    assert result.address == "서울 종로구 신문로2가 1-102"
    assert result.certainty is ExtractionCertainty.HIGH
    # The full caption must reach the model untouched as the user message
    # content (index 1; index 0 is presumably the system prompt — confirm).
    assert seen_requests[0]["messages"][1]["content"] == long_caption


def test_hf_extraction_client_raises_on_http_error() -> None:
async def handler(request: httpx.Request) -> httpx.Response:
return httpx.Response(500, json={"error": "temporary failure"})
Expand Down
118 changes: 118 additions & 0 deletions tests/test_job_repository.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from __future__ import annotations

import asyncio
import json
from datetime import datetime, timezone
from uuid import uuid4

import pytest

from app.infra.db.repository import JobRepository

# Force the selector event-loop policy where available (the attribute exists
# only on Windows builds, so this is a no-op elsewhere); presumably to avoid
# Proactor-loop incompatibilities in these tests — TODO confirm.
if hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())


def _can_create_event_loop() -> bool:
try:
loop = asyncio.new_event_loop()
loop.close()
return True
except OSError:
return False


EVENT_LOOP_AVAILABLE = _can_create_event_loop()


def _run(coro):
try:
return asyncio.run(coro)
except OSError as exc:
pytest.skip(f"Event loop creation is blocked in this environment: {exc}")


class FakePool:
def __init__(self, row: dict | None = None) -> None:
self.row = row
self.sql: str | None = None
self.args: tuple | None = None

async def fetchrow(self, sql: str, *args):
self.sql = sql
self.args = args
if self.row is not None:
return self.row

now = datetime.now(timezone.utc)
return {
"job_id": args[0],
"caption": args[1],
"instagram_meta": args[2],
"extraction_result": args[3],
"created_at": now,
"updated_at": now,
}


@pytest.mark.skipif(not EVENT_LOOP_AVAILABLE, reason="Event loop creation is blocked in this environment")
def test_upsert_job_result_persists_extraction_result() -> None:
    """Upserting a result must bind extraction_result and include it in SQL."""
    fake_pool = FakePool()
    repo = JobRepository(fake_pool, "processing")
    new_job_id = uuid4()
    payload = {
        "store_name": "Common Mansion",
        "address": "1-102 Sinmunro 2-ga, Jongno-gu, Seoul",
        "store_name_evidence": "Common Mansion",
        "address_evidence": "1-102 Sinmunro 2-ga, Jongno-gu, Seoul",
        "certainty": "high",
    }

    record = _run(
        repo.upsert_job_result(
            job_id=new_job_id,
            caption="Common Mansion review",
            instagram_meta={"media_type": "reel"},
            extraction_result=payload,
        )
    )

    # The generated SQL must both insert and conflict-update the new column.
    assert fake_pool.sql is not None
    assert "extraction_result" in fake_pool.sql
    assert "extraction_result = EXCLUDED.extraction_result" in fake_pool.sql
    # Bind order: job_id, caption, meta JSON, extraction JSON.
    expected_args = (
        new_job_id,
        "Common Mansion review",
        json.dumps({"media_type": "reel"}),
        json.dumps(payload),
    )
    assert fake_pool.args == expected_args
    assert record.extraction_result == payload


@pytest.mark.skipif(not EVENT_LOOP_AVAILABLE, reason="Event loop creation is blocked in this environment")
def test_get_job_result_maps_extraction_result() -> None:
    """Rows read back must have their extraction_result JSON decoded to a dict."""
    target_job_id = uuid4()
    stamp = datetime.now(timezone.utc)
    stored_extraction = {
        "store_name": None,
        "address": None,
        "store_name_evidence": None,
        "address_evidence": None,
        "certainty": "low",
    }
    # Canned DB row with JSON columns stored as serialized strings, as the
    # driver would hand them back.
    canned_row = {
        "job_id": target_job_id,
        "caption": "caption",
        "instagram_meta": json.dumps({"caption": "caption"}),
        "extraction_result": json.dumps(stored_extraction),
        "created_at": stamp,
        "updated_at": stamp,
    }
    repo = JobRepository(FakePool(canned_row), "processing")

    record = _run(repo.get_job_result(target_job_id))

    assert record is not None
    assert record.extraction_result == stored_extraction
Loading
Loading