diff --git a/alembic/versions/20260504_0013_add_regulation_chunk_search_tsvector.py b/alembic/versions/20260504_0013_add_regulation_chunk_search_tsvector.py new file mode 100644 index 0000000..4492208 --- /dev/null +++ b/alembic/versions/20260504_0013_add_regulation_chunk_search_tsvector.py @@ -0,0 +1,49 @@ +"""add regulation chunk search tsvector + +Revision ID: 20260504_0013 +Revises: 20260429_0012 +Create Date: 2026-05-04 00:00:00.000000 +""" + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + + +revision = "20260504_0013" +down_revision = "20260429_0012" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "regulation_chunk", + sa.Column("search_tsvector", postgresql.TSVECTOR(), nullable=True), + ) + + op.execute( + """ + UPDATE regulation_chunk AS rc + SET search_tsvector = to_tsvector( + 'simple', + COALESCE(rc.chunk_text, '') || ' ' || + COALESCE(rd.content, '') || ' ' || + COALESCE(rc.keywords::text, '') + ) + FROM regulation_document AS rd + WHERE rd.regulation_document_id = rc.regulation_document_id + """ + ) + + op.create_index( + "idx_regulation_chunk_search_tsvector", + "regulation_chunk", + ["search_tsvector"], + postgresql_using="gin", + ) + + +def downgrade() -> None: + op.drop_index("idx_regulation_chunk_search_tsvector", table_name="regulation_chunk") + op.drop_column("regulation_chunk", "search_tsvector") diff --git a/app/db/models/regulation_chunk.py b/app/db/models/regulation_chunk.py index 1cc77ba..7b92d04 100644 --- a/app/db/models/regulation_chunk.py +++ b/app/db/models/regulation_chunk.py @@ -5,6 +5,7 @@ from pgvector.sqlalchemy import Vector from sqlalchemy import BigInteger, Boolean, DateTime, ForeignKey, Index, Integer, String, Text from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.dialects.postgresql import TSVECTOR from sqlalchemy import func from sqlalchemy import text from sqlalchemy.orm import Mapped, mapped_column @@ -20,6 +21,7 @@ class 
RegulationChunk(Base): Index("idx_regulation_chunk_document_version", "document_version"), Index("idx_regulation_chunk_chunk_id", "chunk_id"), Index("idx_regulation_chunk_is_active", "is_active"), + Index("idx_regulation_chunk_search_tsvector", "search_tsvector", postgresql_using="gin"), ) regulation_chunk_id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True) @@ -33,6 +35,7 @@ class RegulationChunk(Base): chunk_index: Mapped[int] = mapped_column(Integer, nullable=False) chunk_text: Mapped[Optional[str]] = mapped_column(Text, nullable=True) keywords: Mapped[Optional[list[str]]] = mapped_column(JSONB, nullable=True) + search_tsvector: Mapped[Optional[str]] = mapped_column(TSVECTOR, nullable=True) embedding: Mapped[Optional[List[float]]] = mapped_column(Vector(1536), nullable=True) chunk_hash: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) embedding_model: Mapped[Optional[str]] = mapped_column(String(100), nullable=True) diff --git a/app/repositories/regulation_chunk_repository.py b/app/repositories/regulation_chunk_repository.py index 22bb337..76b3795 100644 --- a/app/repositories/regulation_chunk_repository.py +++ b/app/repositories/regulation_chunk_repository.py @@ -53,11 +53,45 @@ def create_regulation_chunks_for_document( created_chunks.append(regulation_chunk) db.flush() + refresh_search_vectors_for_chunks( + db, + [ + regulation_chunk.regulation_chunk_id + for regulation_chunk in created_chunks + if regulation_chunk.regulation_chunk_id is not None + ], + ) for regulation_chunk in created_chunks: db.refresh(regulation_chunk) return created_chunks +def refresh_search_vectors_for_chunks(db: Session, regulation_chunk_ids: list[int]) -> int: + """저장된 청크 검색 텍스트를 tsvector 컬럼에 반영합니다.""" + + if not regulation_chunk_ids: + return 0 + + result = db.execute( + text( + """ + UPDATE regulation_chunk AS rc + SET search_tsvector = to_tsvector( + 'simple', + COALESCE(rc.chunk_text, '') || ' ' || + COALESCE(rd.content, '') || ' ' || 
+ COALESCE(rc.keywords::text, '') + ) + FROM regulation_document AS rd + WHERE rd.regulation_document_id = rc.regulation_document_id + AND rc.regulation_chunk_id = ANY(:regulation_chunk_ids) + """ + ), + {"regulation_chunk_ids": regulation_chunk_ids}, + ) + return result.rowcount or 0 + + def deactivate_chunks_for_document(db: Session, regulation_document_id: int) -> int: statement = ( update(RegulationChunk) @@ -392,23 +426,13 @@ def _search_hybrid_chunks( rd.dormitory, 1 - (rc.embedding <=> CAST(:embedding AS vector)) AS vector_similarity, ts_rank_cd( - to_tsvector( - 'simple', - COALESCE(rc.chunk_text, '') || ' ' || - COALESCE(rd.content, '') || ' ' || - COALESCE(rc.keywords::text, '') - ), + rc.search_tsvector, websearch_to_tsquery('simple', :query_text) ) AS keyword_score, NULL::bigint AS vector_rank, ROW_NUMBER() OVER ( ORDER BY ts_rank_cd( - to_tsvector( - 'simple', - COALESCE(rc.chunk_text, '') || ' ' || - COALESCE(rd.content, '') || ' ' || - COALESCE(rc.keywords::text, '') - ), + rc.search_tsvector, websearch_to_tsquery('simple', :query_text) ) DESC ) AS keyword_rank @@ -419,12 +443,7 @@ def _search_hybrid_chunks( AND rc.is_active = TRUE AND rd.is_active = TRUE AND rc.embedding IS NOT NULL - AND to_tsvector( - 'simple', - COALESCE(rc.chunk_text, '') || ' ' || - COALESCE(rd.content, '') || ' ' || - COALESCE(rc.keywords::text, '') - ) @@ websearch_to_tsquery('simple', :query_text) + AND rc.search_tsvector @@ websearch_to_tsquery('simple', :query_text) ORDER BY keyword_score DESC LIMIT :candidate_k ), diff --git a/app/services/chat_service.py b/app/services/chat_service.py index 37d7ef0..a70c308 100644 --- a/app/services/chat_service.py +++ b/app/services/chat_service.py @@ -25,6 +25,9 @@ from app.repositories.regulation_chunk_repository import search_hybrid_chunks from app.repositories.regulation_chunk_repository import search_hybrid_chunks_all_dormitories from app.repositories.regulation_chunk_repository import search_hybrid_chunks_for_dormitories +from 
app.repositories.regulation_chunk_repository import search_similar_chunks +from app.repositories.regulation_chunk_repository import search_similar_chunks_for_dormitories +from app.repositories.regulation_chunk_repository import search_similar_chunks_all_dormitories from app.schemas.chat import ChatRequest from app.schemas.chat import ChatResponse from app.services.embeddings import create_query_embedding @@ -32,7 +35,6 @@ from app.services.generator import generate_answer from app.services.validator import validate_question from app.services.query_rewriter import expand_query_for_retrieval -from app.repositories.regulation_chunk_repository import search_similar_chunks_all_dormitories from app.services.room_floor_resolver import resolve_room_floor_question ERROR_TYPE_TIMEOUT = "TIMEOUT" @@ -274,10 +276,13 @@ def _answer_single_dormitory_chat( # 전체 생활관 fallback까지 했는데도 답변을 못 만들면 # LLM query expansion으로 검색용 질의를 확장한 뒤 재검색 if _is_no_answer(answer_result.answer): - expanded_query = expand_query_for_retrieval( - question=question, - dormitory=dormitory, - ) + if rewritten_query != question: + expanded_query = rewritten_query + else: + expanded_query = expand_query_for_retrieval( + question=question, + dormitory=dormitory, + ) if expanded_query != question: expanded_query_embedding = create_query_embedding(expanded_query) @@ -442,10 +447,13 @@ def _answer_unspecified_dormitory_chat( # 비로그인/생활관 미지정 상태에서 원문 검색으로 답변을 못 만들면 # query expansion으로 검색용 질의를 확장한 뒤 전체 생활관 대상으로 재검색 if _is_no_answer(answer_result.answer): - expanded_query = expand_query_for_retrieval( - question=question, - dormitory=None, - ) + if rewritten_query != question: + expanded_query = rewritten_query + else: + expanded_query = expand_query_for_retrieval( + question=question, + dormitory=None, + ) if expanded_query != question: diff --git a/scripts/benchmark_hybrid_search.py b/scripts/benchmark_hybrid_search.py new file mode 100644 index 0000000..168b851 --- /dev/null +++ b/scripts/benchmark_hybrid_search.py @@ 
-0,0 +1,572 @@ +"""Benchmark hybrid regulation chunk search before/after DB indexing changes. + +This script intentionally avoids calling external embedding or LLM APIs. It uses +a stable synthetic query embedding with the same dimension as stored chunk +embeddings, then runs the repository search functions directly so DB search cost +can be compared with less application-level noise. +""" + +from __future__ import annotations + +import argparse +import json +import os +import statistics +import time +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from typing import Any + +from sqlalchemy import create_engine +from sqlalchemy import text +from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker + +from app.core.config import get_settings +from app.repositories import regulation_chunk_repository as repo + + +DEFAULT_CASES = [ + { + "name": "single_curfew_keyword", + "function": "single", + "query_text": "통금", + "dormitory": "제1학생생활관", + "top_k": 3, + "candidate_k": 20, + }, + { + "name": "single_curfew_when", + "function": "single", + "query_text": "통금시간 언제까지야??", + "dormitory": "제2학생생활관", + "top_k": 3, + "candidate_k": 20, + }, + { + "name": "single_ramen_cook", + "function": "single", + "query_text": "방에서 라면 끓여먹어도 돼?", + "dormitory": "제1학생생활관", + "top_k": 3, + "candidate_k": 20, + }, + { + "name": "single_ramen_eat", + "function": "single", + "query_text": "방에서 라면 먹어도 돼?", + "dormitory": "제2학생생활관", + "top_k": 3, + "candidate_k": 20, + }, + { + "name": "single_convenience_store", + "function": "single", + "query_text": "2긱에 편의점 있어?", + "dormitory": "제2학생생활관", + "top_k": 3, + "candidate_k": 20, + }, + { + "name": "single_atm", + "function": "single", + "query_text": "2긱에 atm기 있어?", + "dormitory": "제2학생생활관", + "top_k": 3, + "candidate_k": 20, + }, + { + "name": "single_no_match", + "function": "single", + "query_text": "화성 탐사선 주차 규정", + "dormitory": "제1학생생활관", + "top_k": 3, + "candidate_k": 20, + }, + { 
+ "name": "all_curfew_keyword", + "function": "all", + "query_text": "통금", + "top_k": 5, + "candidate_k": 30, + }, + { + "name": "all_capacity", + "function": "all", + "query_text": "기숙사 수용인원 몇명이야?", + "top_k": 5, + "candidate_k": 30, + }, + { + "name": "all_convenience_store", + "function": "all", + "query_text": "편의점 어디있어?", + "top_k": 5, + "candidate_k": 30, + }, + { + "name": "all_cost", + "function": "all", + "query_text": "기숙사 비용 얼마야?", + "top_k": 5, + "candidate_k": 30, + }, + { + "name": "grouped_cooking_keyword", + "function": "grouped", + "query_text": "취사", + "top_k": 3, + "candidate_k": 20, + }, + { + "name": "grouped_ramen_cook", + "function": "grouped", + "query_text": "라면 끓여 먹어도 돼?", + "top_k": 3, + "candidate_k": 20, + }, + { + "name": "grouped_ramen_eat_typo", + "function": "grouped", + "query_text": "방에서 라면 먹어도돼??", + "top_k": 3, + "candidate_k": 20, + }, + { + "name": "grouped_printer", + "function": "grouped", + "query_text": "프린트할 곳 있어?", + "top_k": 3, + "candidate_k": 20, + }, + { + "name": "grouped_cash_machine", + "function": "grouped", + "query_text": "현금자동입출기 있어?", + "top_k": 3, + "candidate_k": 20, + }, + { + "name": "grouped_no_match", + "function": "grouped", + "query_text": "화성 탐사선 주차 규정", + "top_k": 3, + "candidate_k": 20, + }, +] + + +@dataclass(frozen=True) +class CapturedQuery: + sql: str + params: dict[str, Any] + + +class _EmptyMappingResult: + def mappings(self) -> "_EmptyMappingResult": + return self + + def all(self) -> list[Any]: + return [] + + +class _CaptureSession: + def __init__(self) -> None: + self.captured: CapturedQuery | None = None + + def execute(self, statement: Any, params: dict[str, Any] | None = None) -> _EmptyMappingResult: + self.captured = CapturedQuery(sql=str(statement), params=params or {}) + return _EmptyMappingResult() + + +def main() -> None: + parser = argparse.ArgumentParser(description="Benchmark hybrid search repository functions.") + parser.add_argument("--iterations", type=int, default=50) + 
parser.add_argument("--warmup", type=int, default=5) + parser.add_argument("--keyword-weight", type=float, default=0.3) + parser.add_argument( + "--embedding-mode", + choices=["anchor", "zero"], + default="anchor", + help="Use an existing chunk embedding as a stable anchor, or a zero vector.", + ) + parser.add_argument( + "--sample-keyword-cases", + type=int, + default=3, + help="Append this many DB-derived keyword cases with known active chunks.", + ) + parser.add_argument("--explain", action="store_true", help="Also collect EXPLAIN ANALYZE BUFFERS JSON.") + parser.add_argument( + "--include-raw-plan", + action="store_true", + help="Include full EXPLAIN FORMAT JSON plans. By default only plan summaries are emitted.", + ) + parser.add_argument("--output", type=Path, help="Write benchmark result JSON to this path.") + parser.add_argument( + "--database-url", + default=os.getenv("DATABASE_URL"), + help="Override settings.database_url, useful for local before/after benchmarking.", + ) + args = parser.parse_args() + + settings = get_settings() + database_url = args.database_url or settings.database_url + session_factory = sessionmaker( + bind=create_engine(database_url, future=True), + autocommit=False, + autoflush=False, + class_=Session, + ) + db = session_factory() + + try: + embedding_dimension = _get_embedding_dimension(db, settings.openai_embedding_dimension) + query_embedding = _get_query_embedding(db, args.embedding_mode, embedding_dimension) + cases = DEFAULT_CASES + _get_sample_keyword_cases(db, args.sample_keyword_cases) + + result = { + "metadata": { + "created_at": datetime.now().isoformat(timespec="seconds"), + "app_env": settings.app_env, + "database_url": _redact_database_url(database_url), + "iterations": args.iterations, + "warmup": args.warmup, + "keyword_weight": args.keyword_weight, + "embedding_mode": args.embedding_mode, + "embedding_dimension": embedding_dimension, + "case_count": len(cases), + "active_chunk_count": _scalar_int( + db, + """ + 
SELECT COUNT(*) + FROM regulation_chunk rc + JOIN regulation_document rd + ON rd.regulation_document_id = rc.regulation_document_id + WHERE rc.is_active = TRUE + AND rd.is_active = TRUE + AND rc.embedding IS NOT NULL + """, + ), + }, + "cases": [], + } + + for case in cases: + benchmark = _benchmark_case( + db=db, + case=case, + query_embedding=query_embedding, + iterations=args.iterations, + warmup=args.warmup, + keyword_weight=args.keyword_weight, + grouped_dormitories=settings.chat_grouped_dormitories, + ) + + if args.explain: + benchmark["explain"] = _explain_case( + db=db, + case=case, + query_embedding=query_embedding, + keyword_weight=args.keyword_weight, + grouped_dormitories=settings.chat_grouped_dormitories, + include_raw_plan=args.include_raw_plan, + ) + + result["cases"].append(benchmark) + + payload = json.dumps(result, ensure_ascii=False, indent=2) + if args.output: + args.output.parent.mkdir(parents=True, exist_ok=True) + args.output.write_text(payload + "\n", encoding="utf-8") + + print(payload) + finally: + db.close() + + +def _benchmark_case( + *, + db: Any, + case: dict[str, Any], + query_embedding: list[float], + iterations: int, + warmup: int, + keyword_weight: float, + grouped_dormitories: list[str], +) -> dict[str, Any]: + for _ in range(warmup): + _run_case( + db=db, + case=case, + query_embedding=query_embedding, + keyword_weight=keyword_weight, + grouped_dormitories=grouped_dormitories, + ) + + elapsed_ms = [] + last_rows = [] + for _ in range(iterations): + started_at = time.perf_counter() + last_rows = _run_case( + db=db, + case=case, + query_embedding=query_embedding, + keyword_weight=keyword_weight, + grouped_dormitories=grouped_dormitories, + ) + elapsed_ms.append((time.perf_counter() - started_at) * 1000) + + return { + "name": case["name"], + "function": case["function"], + "query_text": case["query_text"], + "top_k": case["top_k"], + "candidate_k": case["candidate_k"], + "timing_ms": _summarize_timings(elapsed_ms), + "top_results": 
[ + { + "regulation_chunk_id": row["regulation_chunk_id"], + "document_id": row["document_id"], + "retrieval_group": row["retrieval_group"], + "hybrid_score": row.get("hybrid_score"), + "vector_score": row.get("vector_score"), + "keyword_score": row.get("keyword_score"), + } + for row in last_rows + ], + } + + +def _run_case( + *, + db: Any, + case: dict[str, Any], + query_embedding: list[float], + keyword_weight: float, + grouped_dormitories: list[str], +) -> list[dict[str, Any]]: + common = { + "query_text": case["query_text"], + "query_embedding": query_embedding, + "top_k": case["top_k"], + "candidate_k": case["candidate_k"], + "keyword_weight": keyword_weight, + } + + if case["function"] == "single": + return repo.search_hybrid_chunks(db=db, dormitory=case["dormitory"], **common) + if case["function"] == "all": + return repo.search_hybrid_chunks_all_dormitories(db=db, **common) + if case["function"] == "grouped": + return repo.search_hybrid_chunks_for_dormitories( + db=db, + dormitories=grouped_dormitories, + **common, + ) + + raise ValueError(f"Unknown benchmark function: {case['function']}") + + +def _explain_case( + *, + db: Any, + case: dict[str, Any], + query_embedding: list[float], + keyword_weight: float, + grouped_dormitories: list[str], + include_raw_plan: bool, +) -> dict[str, Any]: + capture_db = _CaptureSession() + _run_case( + db=capture_db, + case=case, + query_embedding=query_embedding, + keyword_weight=keyword_weight, + grouped_dormitories=grouped_dormitories, + ) + + if capture_db.captured is None: + raise RuntimeError(f"Failed to capture SQL for benchmark case {case['name']}") + + explain_sql = text("EXPLAIN (ANALYZE, BUFFERS, FORMAT JSON) " + capture_db.captured.sql) + explain_result = db.execute(explain_sql, capture_db.captured.params).scalar_one() + plan = explain_result[0] + + explain = { + "planning_time_ms": plan.get("Planning Time"), + "execution_time_ms": plan.get("Execution Time"), + "top_plan_node": plan.get("Plan", {}).get("Node 
Type"), + "plan_summary": _summarize_plan(plan.get("Plan", {})), + } + if include_raw_plan: + explain["raw_plan"] = plan + + return explain + + +def _summarize_timings(values: list[float]) -> dict[str, float]: + sorted_values = sorted(values) + return { + "min": round(sorted_values[0], 3), + "avg": round(statistics.fmean(sorted_values), 3), + "median": round(statistics.median(sorted_values), 3), + "p95": round(_percentile(sorted_values, 95), 3), + "max": round(sorted_values[-1], 3), + } + + +def _percentile(sorted_values: list[float], percentile: float) -> float: + if len(sorted_values) == 1: + return sorted_values[0] + + rank = (len(sorted_values) - 1) * (percentile / 100) + lower = int(rank) + upper = min(lower + 1, len(sorted_values) - 1) + weight = rank - lower + return sorted_values[lower] * (1 - weight) + sorted_values[upper] * weight + + +def _summarize_plan(plan: dict[str, Any]) -> dict[str, Any]: + node_counts: dict[str, int] = {} + index_names: set[str] = set() + + def visit(node: dict[str, Any]) -> None: + node_type = node.get("Node Type") + if node_type: + node_counts[node_type] = node_counts.get(node_type, 0) + 1 + index_name = node.get("Index Name") + if index_name: + index_names.add(index_name) + for child in node.get("Plans", []): + visit(child) + + visit(plan) + return { + "node_counts": dict(sorted(node_counts.items())), + "index_names": sorted(index_names), + "shared_hit_blocks": _sum_plan_value(plan, "Shared Hit Blocks"), + "shared_read_blocks": _sum_plan_value(plan, "Shared Read Blocks"), + "temp_read_blocks": _sum_plan_value(plan, "Temp Read Blocks"), + "temp_written_blocks": _sum_plan_value(plan, "Temp Written Blocks"), + } + + +def _sum_plan_value(plan: dict[str, Any], key: str) -> int: + total = int(plan.get(key, 0) or 0) + for child in plan.get("Plans", []): + total += _sum_plan_value(child, key) + return total + + +def _get_embedding_dimension(db: Any, fallback: int) -> int: + dimension = db.execute( + text( + """ + SELECT 
vector_dims(embedding) + FROM regulation_chunk + WHERE embedding IS NOT NULL + LIMIT 1 + """ + ) + ).scalar_one_or_none() + return int(dimension or fallback) + + +def _get_query_embedding(db: Any, embedding_mode: str, embedding_dimension: int) -> list[float]: + if embedding_mode == "zero": + return [0.0] * embedding_dimension + + embedding = db.execute( + text( + """ + SELECT embedding::text + FROM regulation_chunk + WHERE is_active = TRUE + AND embedding IS NOT NULL + ORDER BY regulation_chunk_id + LIMIT 1 + """ + ) + ).scalar_one_or_none() + if embedding is None: + return [0.0] * embedding_dimension + + return [float(value) for value in embedding.strip("[]").split(",")] + + +def _get_sample_keyword_cases(db: Any, limit: int) -> list[dict[str, Any]]: + if limit <= 0: + return [] + + rows = db.execute( + text( + """ + SELECT + rc.regulation_chunk_id, + rd.dormitory, + keyword + FROM regulation_chunk rc + JOIN regulation_document rd + ON rd.regulation_document_id = rc.regulation_document_id + CROSS JOIN LATERAL jsonb_array_elements_text(rc.keywords) AS keyword + WHERE rc.is_active = TRUE + AND rd.is_active = TRUE + AND rc.embedding IS NOT NULL + AND rc.keywords IS NOT NULL + AND jsonb_array_length(rc.keywords) > 0 + ORDER BY rc.regulation_chunk_id + LIMIT :limit + """ + ), + {"limit": limit}, + ).mappings().all() + + cases = [] + for index, row in enumerate(rows, start=1): + keyword = str(row["keyword"]).strip() + if not keyword: + continue + + if row["dormitory"]: + cases.append( + { + "name": f"sample_keyword_single_{index}", + "function": "single", + "query_text": keyword, + "dormitory": row["dormitory"], + "top_k": 3, + "candidate_k": 20, + "sample_regulation_chunk_id": row["regulation_chunk_id"], + } + ) + else: + cases.append( + { + "name": f"sample_keyword_all_{index}", + "function": "all", + "query_text": keyword, + "top_k": 5, + "candidate_k": 30, + "sample_regulation_chunk_id": row["regulation_chunk_id"], + } + ) + + return cases + + +def _scalar_int(db: Any, sql: str) 
-> int: + return int(db.execute(text(sql)).scalar_one()) + + +def _redact_database_url(database_url: str) -> str: + if "@" not in database_url or "://" not in database_url: + return database_url + + scheme, rest = database_url.split("://", 1) + return f"{scheme}://***@{rest.split('@', 1)[1]}" + + +if __name__ == "__main__": + main() diff --git a/tests/test_db.py b/tests/test_db.py index 7019bec..9baebc6 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -76,6 +76,7 @@ def test_regulation_chunk_columns_match_expected_schema() -> None: "chunk_index", "chunk_text", "keywords", + "search_tsvector", "embedding", "chunk_hash", "embedding_model", @@ -93,6 +94,7 @@ def test_regulation_chunk_indexes_match_expected_schema() -> None: "idx_regulation_chunk_document_version", "idx_regulation_chunk_chunk_id", "idx_regulation_chunk_is_active", + "idx_regulation_chunk_search_tsvector", } diff --git a/tests/test_regulation_chunk_repository.py b/tests/test_regulation_chunk_repository.py index 5ba7ccd..cb780c7 100644 --- a/tests/test_regulation_chunk_repository.py +++ b/tests/test_regulation_chunk_repository.py @@ -16,16 +16,25 @@ def __init__(self) -> None: self.added: list[object] = [] self.flush_called = False self.refresh_called_values: list[object] = [] + self.executed_statements: list[object] = [] + self.executed_params: list[dict] = [] def add(self, value) -> None: self.added.append(value) def flush(self) -> None: self.flush_called = True + for index, value in enumerate(self.added, start=1): + value.regulation_chunk_id = index def refresh(self, value) -> None: self.refresh_called_values.append(value) + def execute(self, statement, params): + self.executed_statements.append(statement) + self.executed_params.append(params) + return SimpleNamespace(rowcount=len(params["regulation_chunk_ids"])) + def test_create_regulation_chunks_for_document_maps_document_fields_to_model() -> None: db = FakeSession() @@ -54,6 +63,8 @@ def 
test_create_regulation_chunks_for_document_maps_document_fields_to_model() - assert regulation_chunk.keywords == ["외박", "외출"] assert regulation_chunk.embedding_model == "text-embedding-3-small" assert db.flush_called is True + assert db.executed_params == [{"regulation_chunk_ids": [1]}] + assert "search_tsvector = to_tsvector" in str(db.executed_statements[0]) assert db.refresh_called_values == [regulation_chunk] @@ -98,6 +109,7 @@ def execute(self, statement): def test_search_hybrid_chunks_maps_hybrid_score_to_similarity() -> None: executed_params: list[dict] = [] + executed_statements: list[object] = [] class MappingResult: def mappings(self): @@ -125,7 +137,8 @@ def all(self): ] class HybridSession: - def execute(self, _statement, params): + def execute(self, statement, params): + executed_statements.append(statement) executed_params.append(params) return MappingResult() @@ -139,6 +152,9 @@ def execute(self, _statement, params): assert executed_params[0]["query_text"] == "외박 신청" assert executed_params[0]["dormitory"] == "제1학생생활관" + executed_sql = str(executed_statements[0]) + assert "rc.search_tsvector" in executed_sql + assert "to_tsvector(" not in executed_sql assert result[0]["similarity"] == 0.874 assert result[0]["vector_similarity"] == 0.82 assert result[0]["vector_score"] == 0.82 @@ -170,3 +186,26 @@ def execute(self, _statement, params): assert result == [] assert executed_params[0]["dormitories"] == ["제1학생생활관", "제2학생생활관"] + + +def test_refresh_search_vectors_for_chunks_updates_tsvector_from_document_content() -> None: + executed_statements: list[object] = [] + executed_params: list[dict] = [] + + class RefreshSession: + def execute(self, statement, params): + executed_statements.append(statement) + executed_params.append(params) + return SimpleNamespace(rowcount=2) + + result = regulation_chunk_repository.refresh_search_vectors_for_chunks( + RefreshSession(), + [10, 11], + ) + + executed_sql = str(executed_statements[0]) + assert result == 2 + assert 
"UPDATE regulation_chunk AS rc" in executed_sql + assert "search_tsvector = to_tsvector" in executed_sql + assert "COALESCE(rd.content, '')" in executed_sql + assert executed_params == [{"regulation_chunk_ids": [10, 11]}]