diff --git a/app/core/config.py b/app/core/config.py index c3a8ac5..93278f8 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -18,12 +18,14 @@ class Settings(BaseSettings): openai_timeout_seconds: float = 10.0 chat_answer_model: str = "gpt-4o-mini" notice_summary_model: str = "gpt-4o-mini" + + chat_retrieval_version_single: str = "hybrid-dormitory-search-v1" + chat_retrieval_version_grouped: str = "hybrid-dormitory-search-unspecified-v1" + chat_retrieval_method_single: str = "hybrid_dormitory_top_k" + chat_retrieval_method_grouped: str = "hybrid_unspecified_dormitory_top_k" + chat_prompt_version_single: str = "chat-answer-citation-v2" chat_prompt_version_grouped: str = "chat-answer-grouped-citation-v2" - chat_retrieval_version_single: str = "dormitory-search-v1" - chat_retrieval_version_grouped: str = "dormitory-search-unspecified-v1" - chat_retrieval_method_single: str = "vector_dormitory_top_k" - chat_retrieval_method_grouped: str = "vector_unspecified_dormitory_top_k" chat_no_answer_message: str = "관련 정보를 찾을 수 없습니다." chat_invalid_question_message: str = "기숙사 관련 질문을 입력해주세요." chat_single_dormitory_top_k: int = 3 @@ -33,8 +35,8 @@ class Settings(BaseSettings): chat_session_timeout_minutes: int = 30 chat_fallback_top_k: int = 5 - chat_retrieval_method_fallback: str = "vector_all_dormitories_fallback" - chat_retrieval_version_fallback: str = "dormitory-search-fallback-v1" + chat_retrieval_method_fallback: str = "hybrid_all_dormitories_fallback" + chat_retrieval_version_fallback: str = "hybrid-dormitory-search-fallback-v1" chat_fallback_similarity_threshold: float = 0.35 @@ -56,4 +58,3 @@ def get_settings() -> Settings: Settings.model_config = SettingsConfigDict(extra="ignore") - diff --git a/app/repositories/regulation_chunk_repository.py b/app/repositories/regulation_chunk_repository.py index a0aa7be..22bb337 100644 --- a/app/repositories/regulation_chunk_repository.py +++ b/app/repositories/regulation_chunk_repository.py @@ -261,4 +261,270 @@ def search_similar_chunks_all_dormitories( "similarity": float(row.similarity), } for row in result - ] \ No newline at end of file + ] + + +def search_hybrid_chunks( + db: Session, + query_text: str, + query_embedding: list[float], + dormitory: str, + top_k: int = 3, + candidate_k: int = 20, + keyword_weight: float = 0.3, +): + """단일 생활관과 공통 문서를 대상으로 하이브리드 검색합니다.""" + + return _search_hybrid_chunks( + db=db, + query_text=query_text, + query_embedding=query_embedding, + top_k=top_k, + candidate_k=candidate_k, + keyword_weight=keyword_weight, + filter_sql="(rd.dormitory = :dormitory OR rd.dormitory IS NULL)", + params={"dormitory": dormitory}, + ) + + +def search_hybrid_chunks_for_dormitories( + db: Session, + query_text: str, + query_embedding: list[float], + dormitories: list[str], + top_k: int = 3, + candidate_k: int = 20, + keyword_weight: float = 0.3, +): + """여러 생활관과 공통 문서를 대상으로 하이브리드 검색합니다.""" + + return _search_hybrid_chunks( + db=db, + query_text=query_text, + query_embedding=query_embedding, + top_k=top_k, + candidate_k=candidate_k, + keyword_weight=keyword_weight, + filter_sql="(rd.dormitory = ANY(:dormitories) OR rd.dormitory IS NULL)", + params={"dormitories": dormitories}, + ) + + +def search_hybrid_chunks_all_dormitories( + db: Session, + query_text: str, + query_embedding: list[float], + top_k: int = 5, + candidate_k: int = 30, + keyword_weight: float = 0.3, +): + """생활관 필터 없이 전체 활성 regulation_chunk를 대상으로 하이브리드 검색합니다.""" + + return _search_hybrid_chunks( + db=db, + query_text=query_text, + query_embedding=query_embedding, + top_k=top_k, + candidate_k=candidate_k, + keyword_weight=keyword_weight, + filter_sql="TRUE", + params={}, + ) + + +def _search_hybrid_chunks( + db: Session, + *, + query_text: str, + query_embedding: list[float], + top_k: int, + candidate_k: int, + keyword_weight: float, + filter_sql: str, + params: dict, +): + """ + 벡터 유사도와 키워드 점수를 가중합해 검색합니다. + + 반환값의 similarity는 최종 결합 점수이며, 벡터 단독 점수는 vector_similarity에 남깁니다. + """ + + embedding_str = "[" + ",".join(map(str, query_embedding)) + "]" + + sql = text( + f""" + WITH vector_search AS ( + SELECT + rc.regulation_chunk_id, + rd.document_id, + rd.document_version, + rc.chunk_id, + COALESCE(rc.chunk_text, rd.content, '') AS content, + rd.source, + rd.source_url, + rd.dormitory, + 1 - (rc.embedding <=> CAST(:embedding AS vector)) AS vector_similarity, + NULL::float AS keyword_score, + ROW_NUMBER() OVER ( + ORDER BY rc.embedding <=> CAST(:embedding AS vector) + ) AS vector_rank, + NULL::bigint AS keyword_rank + FROM regulation_chunk rc + JOIN regulation_document rd + ON rd.regulation_document_id = rc.regulation_document_id + WHERE {filter_sql} + AND rc.is_active = TRUE + AND rd.is_active = TRUE + AND rc.embedding IS NOT NULL + ORDER BY rc.embedding <=> CAST(:embedding AS vector) + LIMIT :candidate_k + ), + + keyword_search AS ( + SELECT + rc.regulation_chunk_id, + rd.document_id, + rd.document_version, + rc.chunk_id, + COALESCE(rc.chunk_text, rd.content, '') AS content, + rd.source, + rd.source_url, + rd.dormitory, + 1 - (rc.embedding <=> CAST(:embedding AS vector)) AS vector_similarity, + ts_rank_cd( + to_tsvector( + 'simple', + COALESCE(rc.chunk_text, '') || ' ' || + COALESCE(rd.content, '') || ' ' || + COALESCE(rc.keywords::text, '') + ), + websearch_to_tsquery('simple', :query_text) + ) AS keyword_score, + NULL::bigint AS vector_rank, + ROW_NUMBER() OVER ( + ORDER BY ts_rank_cd( + to_tsvector( + 'simple', + COALESCE(rc.chunk_text, '') || ' ' || + COALESCE(rd.content, '') || ' ' || + COALESCE(rc.keywords::text, '') + ), + websearch_to_tsquery('simple', :query_text) + ) DESC + ) AS keyword_rank + FROM regulation_chunk rc + JOIN regulation_document rd + ON rd.regulation_document_id = rc.regulation_document_id + WHERE {filter_sql} + AND rc.is_active = TRUE + AND rd.is_active = TRUE + AND rc.embedding IS NOT NULL + AND to_tsvector( + 'simple', + COALESCE(rc.chunk_text, '') || ' ' || + COALESCE(rd.content, '') || ' ' || + COALESCE(rc.keywords::text, '') + ) @@ websearch_to_tsquery('simple', :query_text) + ORDER BY keyword_score DESC + LIMIT :candidate_k + ), + + combined AS ( + SELECT * FROM vector_search + UNION ALL + SELECT * FROM keyword_search + ), + + dedup AS ( + SELECT + regulation_chunk_id, + MAX(document_id) AS document_id, + MAX(document_version) AS document_version, + MAX(chunk_id) AS chunk_id, + MAX(content) AS content, + MAX(source) AS source, + MAX(source_url) AS source_url, + MAX(dormitory) AS dormitory, + MAX(vector_similarity) AS vector_similarity, + MAX(keyword_score) AS keyword_score, + MIN(vector_rank) AS vector_rank, + MIN(keyword_rank) AS keyword_rank + FROM combined + GROUP BY regulation_chunk_id + ), + + scored AS ( + SELECT + *, + LEAST(1, GREATEST(0, COALESCE(vector_similarity, 0))) AS vector_score, + COALESCE(keyword_score / NULLIF(MAX(keyword_score) OVER (), 0), 0) AS normalized_keyword_score, + ( + LEAST(1, GREATEST(0, COALESCE(vector_similarity, 0))) + + ( + :keyword_weight * + COALESCE(keyword_score / NULLIF(MAX(keyword_score) OVER (), 0), 0) * + (1 - LEAST(1, GREATEST(0, COALESCE(vector_similarity, 0)))) + ) + ) AS hybrid_score + FROM dedup + ) + + SELECT + regulation_chunk_id, + document_id, + document_version, + chunk_id, + content, + source, + source_url, + dormitory, + vector_similarity, + vector_score, + keyword_score, + normalized_keyword_score, + vector_rank, + keyword_rank, + hybrid_score + FROM scored + ORDER BY hybrid_score DESC + LIMIT :top_k + """ + ) + + result = db.execute( + sql, + { + "embedding": embedding_str, + "query_text": query_text.strip(), + "top_k": top_k, + "candidate_k": candidate_k, + "keyword_weight": keyword_weight, + **params, + }, + ).mappings().all() + + return [ + { + "regulation_chunk_id": row.regulation_chunk_id, + "document_id": row.document_id, + "document_version": row.document_version, + "chunk_id": row.chunk_id, + "content": row.content, + "source": row.source, + "source_url": row.source_url, + "retrieval_group": row.dormitory, + "similarity": float(row.hybrid_score), + "vector_similarity": float(row.vector_similarity) if row.vector_similarity is not None else None, + "vector_score": float(row.vector_score) if row.vector_score is not None else None, + "keyword_score": float(row.keyword_score) if row.keyword_score is not None else None, + "normalized_keyword_score": ( + float(row.normalized_keyword_score) + if row.normalized_keyword_score is not None + else None + ), + "vector_rank": int(row.vector_rank) if row.vector_rank is not None else None, + "keyword_rank": int(row.keyword_rank) if row.keyword_rank is not None else None, + "hybrid_score": float(row.hybrid_score), + } + for row in result + ] diff --git a/app/services/chat_service.py b/app/services/chat_service.py index ac2fabb..37d7ef0 100644 --- a/app/services/chat_service.py +++ b/app/services/chat_service.py @@ -22,8 +22,9 @@ from app.repositories.chat_log_repository import update_chat_log_result from app.repositories.chat_retrieval_result_repository import create_chat_retrieval_results from app.repositories.chat_retrieval_result_repository import mark_chat_retrieval_results_used_in_answer -from app.repositories.regulation_chunk_repository import search_similar_chunks -from app.repositories.regulation_chunk_repository import search_similar_chunks_for_dormitories +from app.repositories.regulation_chunk_repository import search_hybrid_chunks +from app.repositories.regulation_chunk_repository import search_hybrid_chunks_all_dormitories +from app.repositories.regulation_chunk_repository import search_hybrid_chunks_for_dormitories from app.schemas.chat import ChatRequest from app.schemas.chat import ChatResponse from app.services.embeddings import create_query_embedding @@ -31,8 +32,6 @@ from app.services.generator import generate_answer from app.services.validator import validate_question from app.services.query_rewriter import expand_query_for_retrieval - - from app.repositories.regulation_chunk_repository import search_similar_chunks_all_dormitories from app.services.room_floor_resolver import resolve_room_floor_question @@ -193,17 +192,21 @@ def _answer_single_dormitory_chat( query_embedding = create_query_embedding(retrieval_query) - chunks = search_similar_chunks( + chunks = search_hybrid_chunks( db=db, + query_text=question, query_embedding=query_embedding, dormitory=dormitory, top_k=settings.chat_single_dormitory_top_k, + candidate_k=20, + keyword_weight=0.3, ) # 1차 검색 결과가 없거나 유사도가 낮으면 전체 생활관 fallback 검색 if _should_fallback_retrieval(chunks): - chunks = search_similar_chunks_all_dormitories( + chunks = search_hybrid_chunks_all_dormitories( db=db, + query_text=question, query_embedding=query_embedding, top_k=settings.chat_fallback_top_k, ) @@ -248,8 +251,9 @@ def _answer_single_dormitory_chat( _is_no_answer(answer_result.answer) and retrieval_method != settings.chat_retrieval_method_fallback ): - fallback_chunks = search_similar_chunks_all_dormitories( + fallback_chunks = search_hybrid_chunks_all_dormitories( db=db, + query_text=question, query_embedding=query_embedding, top_k=settings.chat_fallback_top_k, ) @@ -348,7 +352,6 @@ def _answer_single_dormitory_chat( raise db.commit() - db.close() final_answer_status = ChatAnswerStatus.SUCCESS final_source_url = answer_result.source_url or "" @@ -407,8 +410,9 @@ def _answer_unspecified_dormitory_chat( raise try: - chunks = search_similar_chunks_for_dormitories( + chunks = search_hybrid_chunks_for_dormitories( db=db, + query_text=question, query_embedding=query_embedding, dormitories=settings.chat_grouped_dormitories, top_k=settings.chat_grouped_dormitory_top_k, @@ -499,7 +503,6 @@ def _answer_unspecified_dormitory_chat( raise db.commit() - db.close() final_answer_status = ChatAnswerStatus.SUCCESS final_source_url = answer_result.source_url or "" @@ -686,11 +689,16 @@ def _should_fallback_retrieval(chunks: list[dict]) -> bool: if not chunks: return True - top_similarity = chunks[0].get("similarity") - if top_similarity is None: + top_vector_score = chunks[0].get("vector_score") + if top_vector_score is None: + top_vector_score = chunks[0].get("vector_similarity") + if top_vector_score is None: + top_vector_score = chunks[0].get("similarity") + if top_vector_score is None: return True - return float(top_similarity) < settings.chat_fallback_similarity_threshold + + return float(top_vector_score) < settings.chat_fallback_similarity_threshold def _is_no_answer(answer: str) -> bool: settings = get_settings() @@ -842,4 +850,4 @@ def _should_pre_expand_query(question: str) -> bool: if any(trigger in compact_question for trigger in cooking_triggers): return True - return False \ No newline at end of file + return False diff --git a/tests/test_chat_service.py b/tests/test_chat_service.py index ca77a21..fcfbf6e 100644 --- a/tests/test_chat_service.py +++ b/tests/test_chat_service.py @@ -87,7 +87,7 @@ def test_answer_chat_question_returns_success_for_single_dormitory(monkeypatch: monkeypatch.setattr(chat_service, "create_query_embedding", lambda *_args, **_kwargs: [0.1, 0.2, 0.3]) monkeypatch.setattr( chat_service, - "search_similar_chunks", + "search_hybrid_chunks", lambda *_args, **_kwargs: [ { "regulation_chunk_id": 1001, @@ -154,7 +154,7 @@ def test_answer_chat_question_returns_success_for_single_dormitory(monkeypatch: assert retrieval_calls["mark_used_chat_log_id"] == 501 assert retrieval_calls["cited_regulation_chunk_ids"] == [1001] assert db.commit_count == 2 - assert db.close_count == 1 + assert db.close_count == 0 assert finalize_db.commit_count == 1 assert db.flush_count == 0 assert finalize_db.flush_count == 1 @@ -181,7 +181,7 @@ def test_answer_chat_question_uses_top_scored_chunks_when_dormitory_is_missing( monkeypatch.setattr(chat_service, "validate_question", lambda *_args, **_kwargs: (True, "택배는 어디서 받나요?")) monkeypatch.setattr(chat_service, "create_query_embedding", lambda *_args, **_kwargs: [0.1, 0.2, 0.3]) - def fake_search_similar_chunks_for_dormitories(*_args, **kwargs): + def fake_search_hybrid_chunks_for_dormitories(*_args, **kwargs): multi_search_calls.append(kwargs) return [ { @@ -218,13 +218,8 @@ def fake_generate_answer(_question, chunks): monkeypatch.setattr( chat_service, - "search_similar_chunks", - lambda *_args, **_kwargs: pytest.fail("unspecified dormitory should use a single multi-dormitory query"), - ) - monkeypatch.setattr( - chat_service, - "search_similar_chunks_for_dormitories", - fake_search_similar_chunks_for_dormitories, + "search_hybrid_chunks_for_dormitories", + fake_search_hybrid_chunks_for_dormitories, ) monkeypatch.setattr(chat_service, "generate_answer", fake_generate_answer) monkeypatch.setattr( @@ -400,7 +395,7 @@ def test_answer_chat_question_marks_error_when_generation_fails(monkeypatch: pyt monkeypatch.setattr(chat_service, "create_query_embedding", lambda *_args, **_kwargs: [0.1, 0.2, 0.3]) monkeypatch.setattr( chat_service, - "search_similar_chunks", + "search_hybrid_chunks", lambda *_args, **_kwargs: [ { "regulation_chunk_id": 1001, @@ -448,14 +443,14 @@ def raise_generation_error(*_args, **_kwargs): assert chat_log.answer_status == ChatAnswerStatus.ERROR assert chat_log.rewritten_query == "외박 신청은 어디서 하나요?" assert chat_log.answer == "" - assert retrieval_calls["chat_log_id"] == 501 + assert retrieval_calls == {} assert error_log_calls["chat_log_id"] == 501 assert error_log_calls["session_id"] == "session-123" assert error_log_calls["error_type"] == chat_service.ERROR_TYPE_LLM_API assert error_log_calls["occurred_step"] == chat_service.STEP_ANSWER_GENERATION assert error_log_calls["error_message"] == "llm failed" assert error_log_calls["error_detail"] == "RuntimeError: llm failed" - assert db.commit_count == 2 + assert db.commit_count == 1 assert db.close_count == 0 assert finalize_db.commit_count == 1 assert db.flush_count == 0 @@ -495,3 +490,24 @@ def test_build_chat_error_metadata_defaults_timeout_error() -> None: assert result.error_type == chat_service.ERROR_TYPE_TIMEOUT assert result.occurred_step is None assert result.error_message == "request timed out" + + +def test_should_fallback_retrieval_uses_vector_score_before_hybrid_similarity() -> None: + chunks = [ + { + "similarity": 0.9, + "vector_score": 0.2, + } + ] + + assert chat_service._should_fallback_retrieval(chunks) is True + + +def test_should_fallback_retrieval_falls_back_to_similarity_for_legacy_results() -> None: + chunks = [ + { + "similarity": 0.9, + } + ] + + assert chat_service._should_fallback_retrieval(chunks) is False diff --git a/tests/test_regulation_chunk_repository.py b/tests/test_regulation_chunk_repository.py index fe6acbd..5ba7ccd 100644 --- a/tests/test_regulation_chunk_repository.py +++ b/tests/test_regulation_chunk_repository.py @@ -94,3 +94,79 @@ def execute(self, statement): assert result == 3 assert len(executed_statements) == 1 + + +def test_search_hybrid_chunks_maps_hybrid_score_to_similarity() -> None: + executed_params: list[dict] = [] + + class MappingResult: + def mappings(self): + return self + + def all(self): + return [ + SimpleNamespace( + regulation_chunk_id=1001, + document_id="dorm-rule", + document_version="v1", + chunk_id="chunk-1", + content="외박 신청은 포털에서 가능합니다.", + source="생활관 규정집", + source_url="https://example.com/rules/1", + dormitory="제1학생생활관", + vector_similarity=0.82, + vector_score=0.82, + keyword_score=0.4, + normalized_keyword_score=1.0, + vector_rank=2, + keyword_rank=1, + hybrid_score=0.874, + ) + ] + + class HybridSession: + def execute(self, _statement, params): + executed_params.append(params) + return MappingResult() + + result = regulation_chunk_repository.search_hybrid_chunks( + db=HybridSession(), + query_text=" 외박 신청 ", + query_embedding=[0.1, 0.2, 0.3], + dormitory="제1학생생활관", + top_k=3, + ) + + assert executed_params[0]["query_text"] == "외박 신청" + assert executed_params[0]["dormitory"] == "제1학생생활관" + assert result[0]["similarity"] == 0.874 + assert result[0]["vector_similarity"] == 0.82 + assert result[0]["vector_score"] == 0.82 + assert result[0]["keyword_score"] == 0.4 + assert result[0]["normalized_keyword_score"] == 1.0 + + +def test_search_hybrid_chunks_for_dormitories_passes_dormitory_list() -> None: + executed_params: list[dict] = [] + + class EmptyMappingResult: + def mappings(self): + return self + + def all(self): + return [] + + class HybridSession: + def execute(self, _statement, params): + executed_params.append(params) + return EmptyMappingResult() + + result = regulation_chunk_repository.search_hybrid_chunks_for_dormitories( + db=HybridSession(), + query_text="택배", + query_embedding=[0.1, 0.2, 0.3], + dormitories=["제1학생생활관", "제2학생생활관"], + ) + + assert result == [] + assert executed_params[0]["dormitories"] == ["제1학생생활관", "제2학생생활관"]