Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion interface/lang2sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

from db_utils import get_db_connector
from db_utils.base_connector import BaseConnector
from infra.db.connect_db import ConnectDB
from viz.display_chart import DisplayChart
from engine.query_executor import execute_query as execute_query_common
from llm_utils.llm_response_parser import LLMResponseParser
Expand All @@ -30,6 +29,8 @@
"show_sql": "Show SQL",
"show_question_reinterpreted_by_ai": "Show User Question Reinterpreted by AI",
"show_referenced_tables": "Show List of Referenced Tables",
"show_question_gate_result": "Show Question Gate Result",
"show_document_suitability": "Show Document Suitability",
"show_table": "Show Table",
"show_chart": "Show Chart",
}
Expand Down Expand Up @@ -103,8 +104,55 @@ def should_show(_key: str) -> bool:
show_sql_section = has_query and should_show("show_sql")
show_result_desc = has_query and should_show("show_result_description")
show_reinterpreted = has_query and should_show("show_question_reinterpreted_by_ai")
show_gate_result = should_show("show_question_gate_result")
show_doc_suitability = should_show("show_document_suitability")
show_table_section = has_query and should_show("show_table")
show_chart_section = has_query and should_show("show_chart")
if show_gate_result and ("question_gate_result" in res):
st.markdown("---")
st.markdown("**Question Gate 결과:**")
details = res.get("question_gate_result")
if details:
try:
import json as _json

st.code(
_json.dumps(details, ensure_ascii=False, indent=2), language="json"
)
except Exception:
st.write(details)

if show_doc_suitability and ("document_suitability" in res):
st.markdown("---")
st.markdown("**문서 적합성 평가:**")
ds = res.get("document_suitability")
if not isinstance(ds, dict):
st.write(ds)
else:

def _as_float(value):
try:
return float(value)
except Exception:
return -1.0

rows = [
{
"table": table_name,
"score": _as_float(info.get("score", -1)),
"matched_columns": ", ".join(info.get("matched_columns", [])),
"missing_entities": ", ".join(info.get("missing_entities", [])),
"reason": info.get("reason", ""),
}
for table_name, info in ds.items()
if isinstance(info, dict)
]

rows.sort(key=lambda r: r["score"], reverse=True)
if rows:
st.dataframe(rows, use_container_width=True)
else:
st.info("문서 적합성 평가 결과가 비어 있습니다.")

if should_show("show_token_usage"):
st.markdown("---")
Expand Down
92 changes: 88 additions & 4 deletions llm_utils/chains.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,23 @@
"""
LLM 체인 생성 모듈.

이 모듈은 Lang2SQL에서 사용하는 다양한 LangChain 기반 체인을 정의합니다.
- Query Maker
- Query Enrichment
- Profile Extraction
- Question Gate (SQL 적합성 분류)
"""

import os
from langchain_core.prompts import (
ChatPromptTemplate,
MessagesPlaceholder,
SystemMessagePromptTemplate,
)
from pydantic import BaseModel, Field
from llm_utils.output_parser.question_suitability import QuestionSuitability
from llm_utils.output_parser.document_suitability import (
DocumentSuitabilityList,
)

from llm_utils.llm import get_llm

Expand All @@ -15,6 +28,12 @@


class QuestionProfile(BaseModel):
"""
자연어 질문의 특징을 구조화해 표현하는 프로파일 모델.

이 프로파일은 이후 컨텍스트 보강 및 SQL 생성 시 힌트로 사용됩니다.
"""

is_timeseries: bool = Field(description="시계열 분석 필요 여부")
is_aggregation: bool = Field(description="집계 함수 필요 여부")
has_filter: bool = Field(description="조건 필터 필요 여부")
Expand All @@ -26,6 +45,15 @@ class QuestionProfile(BaseModel):

# QueryMakerChain
def create_query_maker_chain(llm):
"""
SQL 쿼리 생성을 위한 체인을 생성합니다.

Args:
llm: LangChain 호환 LLM 인스턴스

Returns:
Runnable: 입력 프롬프트를 받아 SQL을 생성하는 체인
"""
prompt = get_prompt_template("query_maker_prompt")
query_maker_prompt = ChatPromptTemplate.from_messages(
[
Expand All @@ -36,6 +64,15 @@ def create_query_maker_chain(llm):


def create_query_enrichment_chain(llm):
"""
사용자 질문을 메타데이터로 보강하기 위한 체인을 생성합니다.

Args:
llm: LangChain 호환 LLM 인스턴스

Returns:
Runnable: 보강된 질문 텍스트를 반환하는 체인
"""
prompt = get_prompt_template("query_enrichment_prompt")

enrichment_prompt = ChatPromptTemplate.from_messages(
Expand All @@ -49,6 +86,15 @@ def create_query_enrichment_chain(llm):


def create_profile_extraction_chain(llm):
"""
질문으로부터 `QuestionProfile`을 추출하는 체인을 생성합니다.

Args:
llm: LangChain 호환 LLM 인스턴스

Returns:
Runnable: `QuestionProfile` 구조화 출력을 반환하는 체인
"""
prompt = get_prompt_template("profile_extraction_prompt")

profile_prompt = ChatPromptTemplate.from_messages(
Expand All @@ -61,9 +107,47 @@ def create_profile_extraction_chain(llm):
return chain


def create_question_gate_chain(llm):
"""
질문 적합성(Question Gate) 체인을 생성합니다.

ChatPromptTemplate(SystemMessage) + LLM 구조화 출력으로
`QuestionSuitability`를 반환합니다.

Args:
llm: LangChain 호환 LLM 인스턴스

Returns:
Runnable: invoke({"question": str}) -> QuestionSuitability
"""

prompt = get_prompt_template("question_gate_prompt")
gate_prompt = ChatPromptTemplate.from_messages(
[SystemMessagePromptTemplate.from_template(prompt)]
)
return gate_prompt | llm.with_structured_output(QuestionSuitability)


def create_document_suitability_chain(llm):
"""
문서 적합성 평가 체인을 생성합니다.

질문(question)과 검색 결과(tables)를 입력으로 받아
테이블별 적합도 점수를 포함한 JSON 딕셔너리를 반환합니다.

Returns:
Runnable: invoke({"question": str, "tables": dict}) -> {"results": DocumentSuitability[]}
"""

prompt = get_prompt_template("document_suitability_prompt")
doc_prompt = ChatPromptTemplate.from_messages(
[SystemMessagePromptTemplate.from_template(prompt)]
)
return doc_prompt | llm.with_structured_output(DocumentSuitabilityList)


query_maker_chain = create_query_maker_chain(llm)
profile_extraction_chain = create_profile_extraction_chain(llm)
query_enrichment_chain = create_query_enrichment_chain(llm)

if __name__ == "__main__":
pass
question_gate_chain = create_question_gate_chain(llm)
document_suitability_chain = create_document_suitability_chain(llm)
99 changes: 95 additions & 4 deletions llm_utils/graph_utils/base.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
import os
import json

from typing_extensions import TypedDict, Annotated
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages


from llm_utils.chains import (
query_maker_chain,
profile_extraction_chain,
query_enrichment_chain,
question_gate_chain,
document_suitability_chain,
)

from llm_utils.tools import get_info_from_db
from llm_utils.retrieval import search_tables
from llm_utils.graph_utils.profile_utils import profile_to_text

# 노드 식별자 정의
QUESTION_GATE = "question_gate"
EVALUATE_DOCUMENT_SUITABILITY = "evaluate_document_suitability"
GET_TABLE_INFO = "get_table_info"
TOOL = "tool"
TABLE_FILTER = "table_filter"
Expand All @@ -30,12 +30,39 @@ class QueryMakerState(TypedDict):
messages: Annotated[list, add_messages]
user_database_env: str
searched_tables: dict[str, dict[str, str]]
document_suitability: dict
best_practice_query: str
question_profile: dict
generated_query: str
retriever_name: str
top_n: int
device: str
question_gate_result: dict


# 노드 함수: QUESTION_GATE 노드
def question_gate_node(state: QueryMakerState):
"""
사용자의 질문이 SQL로 답변 가능한지 판별하고, 구조화된 결과를 반환하는 게이트 노드입니다.

- question_gate_chain 으로 적합성을 판정하여
`question_gate_result`를 설정합니다.

Args:
state (QueryMakerState): 그래프 상태

Returns:
QueryMakerState: 게이트 판정 결과가 반영된 상태
"""

question_text = state["messages"][0].content
suitability = question_gate_chain.invoke({"question": question_text})
state["question_gate_result"] = {
"reason": getattr(suitability, "reason", ""),
"missing_entities": getattr(suitability, "missing_entities", []),
"requires_data_science": getattr(suitability, "requires_data_science", False),
}
return state


# 노드 함수: PROFILE_EXTRACTION 노드
Expand Down Expand Up @@ -132,6 +159,70 @@ def get_table_info_node(state: QueryMakerState):
return state


# 노드 함수: DOCUMENT_SUITABILITY 노드
def document_suitability_node(state: QueryMakerState):
"""
GET_TABLE_INFO에서 수집된 테이블 후보들에 대해 문서 적합성 점수를 계산하는 노드입니다.

질문(`messages[0].content`)과 `searched_tables`(테이블→칼럼 설명 맵)를 입력으로
프롬프트 체인(`document_suitability_chain`)을 호출하고, 결과 딕셔너리를
`document_suitability` 상태 키에 저장합니다.

Returns:
QueryMakerState: 문서 적합성 평가 결과가 포함된 상태
"""

# 관련 테이블이 없으면 즉시 반환
if not state.get("searched_tables"):
state["document_suitability"] = {}
return state

res = document_suitability_chain.invoke(
{
"question": state["messages"][0].content,
"tables": state["searched_tables"],
}
)

items = (
res.get("results", [])
if isinstance(res, dict)
else getattr(res, "results", None)
or (res.model_dump().get("results", []) if hasattr(res, "model_dump") else [])
)

normalized = {}
for x in items:
d = (
x.model_dump()
if hasattr(x, "model_dump")
else (
x
if isinstance(x, dict)
else {
"table_name": getattr(x, "table_name", ""),
"score": getattr(x, "score", 0),
"reason": getattr(x, "reason", ""),
"matched_columns": getattr(x, "matched_columns", []),
"missing_entities": getattr(x, "missing_entities", []),
}
)
)
t = d.get("table_name")
if not t:
continue
normalized[t] = {
"score": float(d.get("score", 0)),
"reason": d.get("reason", ""),
"matched_columns": d.get("matched_columns", []),
"missing_entities": d.get("missing_entities", []),
}

state["document_suitability"] = normalized

return state


# 노드 함수: QUERY_MAKER 노드
def query_maker_node(state: QueryMakerState):
# 사용자 원 질문 + (있다면) 컨텍스트 보강 결과를 하나의 문자열로 결합
Expand Down
25 changes: 23 additions & 2 deletions llm_utils/graph_utils/basic_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@
from langgraph.graph import StateGraph, END
from llm_utils.graph_utils.base import (
QueryMakerState,
QUESTION_GATE,
GET_TABLE_INFO,
EVALUATE_DOCUMENT_SUITABILITY,
QUERY_MAKER,
question_gate_node,
get_table_info_node,
document_suitability_node,
query_maker_node,
)

Expand All @@ -16,14 +20,31 @@

# StateGraph 생성 및 구성
builder = StateGraph(QueryMakerState)
builder.set_entry_point(GET_TABLE_INFO)
builder.set_entry_point(QUESTION_GATE)

# 노드 추가
builder.add_node(QUESTION_GATE, question_gate_node)
builder.add_node(GET_TABLE_INFO, get_table_info_node)
builder.add_node(EVALUATE_DOCUMENT_SUITABILITY, document_suitability_node)
builder.add_node(QUERY_MAKER, query_maker_node)


def _route_after_gate(state: QueryMakerState):
return GET_TABLE_INFO


builder.add_conditional_edges(
QUESTION_GATE,
_route_after_gate,
{
GET_TABLE_INFO: GET_TABLE_INFO,
END: END,
},
)

# 기본 엣지 설정
builder.add_edge(GET_TABLE_INFO, QUERY_MAKER)
builder.add_edge(GET_TABLE_INFO, EVALUATE_DOCUMENT_SUITABILITY)
builder.add_edge(EVALUATE_DOCUMENT_SUITABILITY, QUERY_MAKER)

# QUERY_MAKER 노드 후 종료
builder.add_edge(QUERY_MAKER, END)
Loading