diff --git a/ragu/common/base.py b/ragu/common/base.py index 5b39dcd..bab11aa 100644 --- a/ragu/common/base.py +++ b/ragu/common/base.py @@ -14,7 +14,8 @@ perform LLM-driven generation or structured response tasks. """ -from ragu.common.prompts import PromptTemplate, DEFAULT_PROMPT_TEMPLATES +from ragu.common.prompts import DEFAULT_PROMPT_TEMPLATES +from ragu.common.prompts.prompt_storage import RAGUInstruction class RaguGenerativeModule: @@ -26,26 +27,26 @@ class RaguGenerativeModule: :class:`PromptTemplate` instances directly. """ - def __init__(self, prompts: list[str] | dict[str, PromptTemplate]): + def __init__(self, prompts: list[str] | dict[str, RAGUInstruction]): """ Initialize the generative module with one or more prompts. :param prompts: Either a list of prompt names (loaded from :data:`DEFAULT_PROMPT_TEMPLATES`) or a dictionary - mapping prompt names to :class:`PromptTemplate` objects. + mapping prompt names to :class:`ChatTemplate` objects. :raises ValueError: If the input format is neither list nor dict. """ super().__init__() if isinstance(prompts, list): - self.prompts: dict[str, PromptTemplate] = { + self.prompts: dict[str, RAGUInstruction] = { prompt_name: DEFAULT_PROMPT_TEMPLATES.get(prompt_name) for prompt_name in prompts } elif isinstance(prompts, dict): self.prompts = prompts else: raise ValueError( - f"Prompts must be a list of prompt names or a dictionary of prompt names and PromptTemplate objects, " + f"Prompts must be a list of prompt names or a dictionary of prompt names and ChatTemplate objects, " f"got {type(prompts)}" ) @@ -53,16 +54,16 @@ def get_prompts(self) -> dict: """ Retrieve all prompt templates registered in the module. - :return: Dictionary mapping prompt names to :class:`PromptTemplate` objects. + :return: Dictionary mapping prompt names to :class:`ChatTemplate` objects. 
""" return self.prompts - def get_prompt(self, prompt_name: str) -> PromptTemplate: + def get_prompt(self, prompt_name: str) -> RAGUInstruction: """ Retrieve a specific prompt template by name. :param prompt_name: The name of the prompt to retrieve. - :return: The corresponding :class:`PromptTemplate` instance. + :return: The corresponding :class:`ChatTemplate` instance. :raises ValueError: If the prompt name is not found. """ if prompt_name in self.prompts: @@ -70,7 +71,7 @@ def get_prompt(self, prompt_name: str) -> PromptTemplate: else: raise ValueError(f"Prompt {prompt_name} not found") - def update_prompt(self, prompt_name: str, prompt: PromptTemplate) -> None: + def update_prompt(self, prompt_name: str, prompt: RAGUInstruction) -> None: """ Replace or add a prompt template in the module. diff --git a/ragu/common/cache.py b/ragu/common/cache.py index 8997beb..7a9f97a 100644 --- a/ragu/common/cache.py +++ b/ragu/common/cache.py @@ -10,6 +10,7 @@ from ragu.common.global_parameters import DEFAULT_FILENAMES, Settings from ragu.common.logger import logger +from ragu.common.prompts import ChatMessages from ragu.utils.ragu_utils import compute_mdhash_id @@ -19,13 +20,12 @@ class PendingRequest: Represents a request pending generation (not found in cache). """ index: int - prompt: str + messages: ChatMessages cache_key: str def make_llm_cache_key( - prompt: str, - system_prompt: Optional[str] = None, + content: str, model_name: Optional[str] = None, schema: Optional[Type[BaseModel]] = None, kwargs: Optional[Dict[str, Any]] = None, @@ -33,18 +33,12 @@ def make_llm_cache_key( """ Build a deterministic cache key from LLM request parameters. - :param prompt: The user prompt. - :param system_prompt: Optional system prompt. :param model_name: Model name used for generation. :param schema: Optional Pydantic schema class. :param kwargs: Additional API parameters. :return: A unique cache key string. 
""" - key_parts = [] - - if system_prompt: - key_parts.append(f"[system]: {system_prompt}") - key_parts.append(f"[user]: {prompt}") + key_parts = [content] if model_name: key_parts.append(f"[model]: {model_name}") @@ -121,7 +115,7 @@ def _load_cache(self) -> None: return try: - with self._cache_path.open("r", encoding="utf-16") as f: + with self._cache_path.open("r", encoding="utf-8") as f: cache = json.load(f) if isinstance(cache, dict): self._mem_cache = cache @@ -156,7 +150,7 @@ def _write_cache_file(self, path: Path) -> None: """ Write cache to file. """ - with path.open("w", encoding="utf-16") as f: + with path.open("w", encoding="utf-8") as f: json.dump(self._mem_cache, f, ensure_ascii=False, indent=2) async def get( diff --git a/ragu/common/prompts/__init__.py b/ragu/common/prompts/__init__.py index 7b7cb46..6453d89 100644 --- a/ragu/common/prompts/__init__.py +++ b/ragu/common/prompts/__init__.py @@ -1,10 +1,16 @@ -from ragu.common.prompts.prompt_storage import ( - PromptTemplate, - # MessageTemplate, - # ChatTemplate, - # ChatMessages, - # SystemMessage, - # AIMessage, - # UserMessage, - DEFAULT_PROMPT_TEMPLATES -) \ No newline at end of file +from ragu.common.prompts.prompt_storage import DEFAULT_PROMPT_TEMPLATES +from ragu.common.prompts.messages import ( + SystemMessage, + UserMessage, + AIMessage, + ChatMessages, +) + +__all__ = [ + "SystemMessage", + "UserMessage", + "AIMessage", + "ChatMessages", + + "DEFAULT_PROMPT_TEMPLATES", +] \ No newline at end of file diff --git a/ragu/common/prompts/messages.py b/ragu/common/prompts/messages.py new file mode 100644 index 0000000..a62a155 --- /dev/null +++ b/ragu/common/prompts/messages.py @@ -0,0 +1,206 @@ +from dataclasses import dataclass +from typing import ( + Literal, + Dict, + Any, + TypeVar, + List, + Type, + Sequence, + Mapping, + Union, +) + +from jinja2 import Environment, StrictUndefined +from openai.types.chat import ( + ChatCompletionMessageParam, + ChatCompletionSystemMessageParam, + 
ChatCompletionUserMessageParam, + ChatCompletionAssistantMessageParam, +) + +Role = Literal["system", "user", "assistant"] + + +@dataclass(frozen=True, slots=True) +class BaseMessage: + """ + Base chat message abstraction. + + Represents a single message in a chat conversation with a fixed role + (system, user, or assistant) and textual content. Provides conversion + to OpenAI SDK message types. + """ + content: str + role: Role + name: str | None = None + + def to_openai(self) -> ChatCompletionMessageParam: + """ + Convert this message into a typed OpenAI ChatCompletion message. + """ + if self.role == "system": + return ChatCompletionSystemMessageParam( + role="system", + content=self.content, + name=self.name, + ) + + if self.role == "user": + return ChatCompletionUserMessageParam( + role="user", + content=self.content, + name=self.name, + ) + + if self.role == "assistant": + return ChatCompletionAssistantMessageParam( + role="assistant", + content=self.content, + name=self.name, + ) + + raise ValueError(f"Unsupported role: {self.role}") + + def to_str(self) -> str: + """ + Return a human-readable string representation of the message. + """ + return f"[{self.role}]: {self.content}" + + +@dataclass(frozen=True, slots=True) +class SystemMessage(BaseMessage): + """ + System-level instruction message. + """ + role: Role = "system" + + +@dataclass(frozen=True, slots=True) +class UserMessage(BaseMessage): + """ + User input message. + """ + role: Role = "user" + + +@dataclass(frozen=True, slots=True) +class AIMessage(BaseMessage): + """ + Assistant (LLM) response message. + """ + role: Role = "assistant" + + +T = TypeVar("T", bound="ChatMessages") + + +@dataclass(frozen=True, slots=True) +class ChatMessages: + """ + Container for a list of chat messages. + + Represents a single user-assistant conversation. 
+ """ + messages: List[BaseMessage] + + @classmethod + def from_messages(cls: Type[T], messages: Sequence[BaseMessage]) -> T: + """ + Construct a ChatMessages instance from a sequence of messages. + """ + return cls(messages=list(messages)) + + def to_openai(self) -> List[ChatCompletionMessageParam]: + """ + Convert all messages to OpenAI ChatCompletion message parameters. + """ + return [m.to_openai() for m in self.messages] + + def __iter__(self): + return iter(self.messages) + + def __len__(self): + return len(self.messages) + + def to_str(self) -> str: + """ + Return a readable multi-line string representation of the conversation. + """ + return "\n".join([m.to_str() for m in self.messages]) + + +def render(template_conversation: Union[BaseMessage, ChatMessages], **params: Any) -> List[ChatMessages]: + """ + Render Jinja2 templates inside message contents in batch mode. + + Parameters + ---------- + template_conversation: + A single message or a ChatMessages instance used as a template. + params: + A mix of scalar parameters (shared across all prompts) and + batch parameters (lists/tuples). All batch parameters must + have the same length N. + + Returns + ------- + List[ChatMessages] + A list of rendered ChatMessages objects of length N. 
+ """ + + def _is_batch_value(v: Any) -> bool: + return isinstance(v, (list, tuple)) + + def _infer_batch_size(params: Mapping[str, Any]) -> int: + sizes = {len(v) for v in params.values() if _is_batch_value(v)} + if not sizes: + return 1 + if len(sizes) != 1: + raise ValueError(f"Batch parameters have different sizes: {sorted(sizes)}.") + return next(iter(sizes)) + + def _build_row_context(params: Mapping[str, Any], i: int) -> Dict[str, Any]: + row: Dict[str, Any] = {} + for k, v in params.items(): + row[k] = v[i] if _is_batch_value(v) else v + return row + + env = Environment( + undefined=StrictUndefined, + autoescape=False, + trim_blocks=True, + lstrip_blocks=True, + ) + + if isinstance(template_conversation, BaseMessage): + template_cm = ChatMessages.from_messages([template_conversation]) + else: + template_cm = template_conversation + + n = _infer_batch_size(params) + + for k, v in params.items(): + if _is_batch_value(v) and len(v) != n: + raise ValueError( + f"Batch parameter '{k}' has length {len(v)}, expected {n}." 
+ ) + + out: List[ChatMessages] = [] + for i in range(n): + ctx = _build_row_context(params, i) + + rendered_msgs: List[BaseMessage] = [] + for m in template_cm.messages: + tmpl = env.from_string(m.content) + new_content = tmpl.render(**ctx) + + msg_type = type(m) + rendered_msgs.append( + msg_type(role=m.role, content=new_content, name=m.name) + ) + + out.append(ChatMessages.from_messages(rendered_msgs)) + + return out diff --git a/ragu/common/prompts/prompt_storage.py b/ragu/common/prompts/prompt_storage.py index 2497af9..4b21da5 100644 --- a/ragu/common/prompts/prompt_storage.py +++ b/ragu/common/prompts/prompt_storage.py @@ -1,9 +1,8 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Type, Tuple, List +from typing import Optional, Type -from jinja2 import Template from pydantic import BaseModel from ragu.common.prompts.default_models import ( @@ -16,7 +15,7 @@ RelationDescriptionModel, ClusterSummarizationModel, QueryPlan, - RewriteQuery + RewriteQuery, ) from ragu.common.prompts.default_templates import ( DEFAULT_ARTIFACTS_EXTRACTOR_PROMPT, @@ -26,7 +25,8 @@ DEFAULT_ENTITY_SUMMARIZER_PROMPT, DEFAULT_RESPONSE_ONLY_PROMPT, DEFAULT_GLOBAL_SEARCH_CONTEXT_PROMPT, - DEFAULT_GLOBAL_SEARCH_PROMPT, DEFAULT_CLUSTER_SUMMARIZER_PROMPT, + DEFAULT_GLOBAL_SEARCH_PROMPT, + DEFAULT_CLUSTER_SUMMARIZER_PROMPT, DEFAULT_RAGU_LM_ENTITY_EXTRACTION_PROMPT, DEFAULT_RAGU_LM_ENTITY_NORMALIZATION_PROMPT, DEFAULT_RAGU_LM_ENTITY_DESCRIPTION_PROMPT, @@ -34,150 +34,178 @@ DEFAULT_QUERY_DECOMPOSITION_PROMPT, DEFAULT_QUERY_REWRITE_PROMPT, ) +from ragu.common.prompts.messages import ( + ChatMessages, + UserMessage, + SystemMessage +) -@dataclass -class PromptTemplate: - """ - Represents a Jinja2-based prompt template for instruction generation. 
- - Each template defines: - - a Jinja2 text pattern (`template`) - - an optional Pydantic schema for structured output validation (`schema`) - - a short description of its purpose (`description`) - - The template can be rendered dynamically with keyword arguments, - supporting both single-instance and batched (list/tuple) generation. - """ - - template: str - schema: Type[BaseModel] = None - description: str = "" - - def __post_init__(self): - """ - Compile the Jinja2 template upon initialization for faster rendering. - """ - self.compiled_template = Template(self.template) - - def get_instruction(self, **batch_kwargs) -> Tuple[List[str], Type[BaseModel]]: - """ - Render one or more prompt instructions using the template. - - Supports both single-value rendering and batch processing - (when lists or tuples are passed as arguments). - - :param batch_kwargs: Key-value pairs passed into the Jinja2 template. - Lists and tuples trigger batch rendering. - :return: A tuple of (list of rendered instructions, associated schema). 
- """ - batch_lengths = { - key: len(value) for key, value in batch_kwargs.items() - if isinstance(value, (list, tuple)) - } - - # No batched parameters → single instruction - if not batch_lengths: - return [self.compiled_template.render(**batch_kwargs)], self.schema - - # Validate that all batched parameters have equal length - unique_lengths = set(batch_lengths.values()) - if len(unique_lengths) > 1: - raise ValueError("All batch parameters must have the same length") - - batch_size = next(iter(unique_lengths)) - batch_params = [] - for i in range(batch_size): - params = {} - for key, value in batch_kwargs.items(): - if isinstance(value, (list, tuple)): - params[key] = value[i] - else: - params[key] = value - batch_params.append(params) - - instructions = [ - self.compiled_template.render(**params) - for params in batch_params - ] - - return instructions, self.schema - - -DEFAULT_PROMPT_TEMPLATES = { - "artifact_extraction": PromptTemplate( - template=DEFAULT_ARTIFACTS_EXTRACTOR_PROMPT, - schema=ArtifactsModel, - description="Prompt for extracting artifacts (entities and relations) from a text passage." - ), - "artifact_validation": PromptTemplate( - template=DEFAULT_ARTIFACTS_VALIDATOR_PROMPT, - schema=ArtifactsModel, - description="Prompt for validating extracted artifacts against a schema." - ), - "community_report": PromptTemplate( - template=DEFAULT_COMMUNITY_REPORT_PROMPT, - schema=CommunityReportModel, - description="Prompt for generating community summaries from contextual data." - ), - "entity_summarizer": PromptTemplate( - template=DEFAULT_ENTITY_SUMMARIZER_PROMPT, - schema=EntityDescriptionModel, - description="Prompt for summarizing entity descriptions." - ), - "relation_summarizer": PromptTemplate( - template=DEFAULT_RELATIONSHIP_SUMMARIZER_PROMPT, - schema=RelationDescriptionModel, - description="Prompt for summarizing relationship descriptions." 
- ), - "global_search_context": PromptTemplate( - template=DEFAULT_GLOBAL_SEARCH_CONTEXT_PROMPT, - schema=GlobalSearchContextModel, - description="Prompt for generating contextual information for a global search." - ), - "global_search": PromptTemplate( - template=DEFAULT_GLOBAL_SEARCH_PROMPT, - schema=GlobalSearchResponseModel, - description="Prompt for generating a synthesized global search response." - ), - "local_search": PromptTemplate( - template=DEFAULT_RESPONSE_ONLY_PROMPT, - schema=DefaultResponseModel, - description="Prompt for generating a local context-based search response." - ), - "naive_search": PromptTemplate( - template=DEFAULT_RESPONSE_ONLY_PROMPT, - schema=DefaultResponseModel, - description="Prompt for generating a naive vector RAG search response." - ), - "cluster_summarize": PromptTemplate( - template=DEFAULT_CLUSTER_SUMMARIZER_PROMPT, - schema=ClusterSummarizationModel - ), - "ragu_lm_entity_extraction": PromptTemplate( - template=DEFAULT_RAGU_LM_ENTITY_EXTRACTION_PROMPT, - description="Instruction for RAGU-lm entity extraction stage." - ), - "ragu_lm_entity_normalization": PromptTemplate( - template=DEFAULT_RAGU_LM_ENTITY_NORMALIZATION_PROMPT, - description="Instruction for RAGU-lm entity normalization stage." - ), - "ragu_lm_entity_description": PromptTemplate( - template=DEFAULT_RAGU_LM_ENTITY_DESCRIPTION_PROMPT, - description="Instruction for RAGU-lm entity description stage." - ), - "ragu_lm_relation_description": PromptTemplate( - template=DEFAULT_RAGU_LM_RELATION_DESCRIPTION_PROMPT, - description="Instruction for RAGU-lm relation description stage." - ), - "query_decomposition": PromptTemplate( - template=DEFAULT_QUERY_DECOMPOSITION_PROMPT, - schema=QueryPlan, - description="Prompt for decomposing a complex query into atomic subqueries with dependencies." 
- ), - "query_rewrite": PromptTemplate( - template=DEFAULT_QUERY_REWRITE_PROMPT, - schema=RewriteQuery, - description="Prompt for rewriting a subquery using answers from its dependencies." +@dataclass(frozen=True, slots=True) +class RAGUInstruction: + messages: ChatMessages + pydantic_model: Optional[Type[BaseModel]] = None + description: Optional[str] = None + + +DEFAULT_PROMPT_TEMPLATES: dict[str, RAGUInstruction] = { + "artifact_extraction": RAGUInstruction( + messages=ChatMessages.from_messages( + [ + UserMessage(content=DEFAULT_ARTIFACTS_EXTRACTOR_PROMPT), + ] + ), + pydantic_model=ArtifactsModel, + description="Prompt for extracting artifacts (entities and relations) from a text passage.", + ), + + "artifact_validation": RAGUInstruction( + messages=ChatMessages.from_messages( + [ + UserMessage(content=DEFAULT_ARTIFACTS_VALIDATOR_PROMPT), + ] + ), + pydantic_model=ArtifactsModel, + description="Prompt for validating extracted artifacts against a schema.", + ), + + "community_report": RAGUInstruction( + messages=ChatMessages.from_messages( + [ + UserMessage(content=DEFAULT_COMMUNITY_REPORT_PROMPT), + ] + ), + pydantic_model=CommunityReportModel, + description="Prompt for generating community summaries from contextual data.", + ), + + "entity_summarizer": RAGUInstruction( + messages=ChatMessages.from_messages( + [ + UserMessage(content=DEFAULT_ENTITY_SUMMARIZER_PROMPT), + ] + ), + pydantic_model=EntityDescriptionModel, + description="Prompt for summarizing entity descriptions.", + ), + + "relation_summarizer": RAGUInstruction( + messages=ChatMessages.from_messages( + [ + UserMessage(content=DEFAULT_RELATIONSHIP_SUMMARIZER_PROMPT), + ] + ), + pydantic_model=RelationDescriptionModel, + description="Prompt for summarizing relationship descriptions.", + ), + + "global_search_context": RAGUInstruction( + messages=ChatMessages.from_messages( + [ + UserMessage(content=DEFAULT_GLOBAL_SEARCH_CONTEXT_PROMPT), + ] + ), + pydantic_model=GlobalSearchContextModel, + 
description="Prompt for generating contextual information for a global search.", + ), + + "global_search": RAGUInstruction( + messages=ChatMessages.from_messages( + [ + UserMessage(content=DEFAULT_GLOBAL_SEARCH_PROMPT), + ] + ), + pydantic_model=GlobalSearchResponseModel, + description="Prompt for generating a synthesized global search response.", + ), + + "local_search": RAGUInstruction( + messages=ChatMessages.from_messages( + [ + UserMessage(content=DEFAULT_RESPONSE_ONLY_PROMPT), + ] + ), + pydantic_model=DefaultResponseModel, + description="Prompt for generating a local context-based search response.", + ), + + "naive_search": RAGUInstruction( + messages=ChatMessages.from_messages( + [ + UserMessage(content=DEFAULT_RESPONSE_ONLY_PROMPT), + ] + ), + pydantic_model=DefaultResponseModel, + description="Prompt for generating a naive vector RAG search response.", + ), + + "cluster_summarize": RAGUInstruction( + messages=ChatMessages.from_messages( + [ + UserMessage(content=DEFAULT_CLUSTER_SUMMARIZER_PROMPT), + ] + ), + pydantic_model=ClusterSummarizationModel, + description=None, + ), + + "ragu_lm_entity_extraction": RAGUInstruction( + messages=ChatMessages.from_messages( + [ + UserMessage(content=DEFAULT_RAGU_LM_ENTITY_EXTRACTION_PROMPT), + ] + ), + pydantic_model=None, + description="Instruction for RAGU-lm entity extraction stage.", + ), + + "ragu_lm_entity_normalization": RAGUInstruction( + messages=ChatMessages.from_messages( + [ + UserMessage(content=DEFAULT_RAGU_LM_ENTITY_NORMALIZATION_PROMPT), + ] + ), + pydantic_model=None, + description="Instruction for RAGU-lm entity normalization stage.", + ), + + "ragu_lm_entity_description": RAGUInstruction( + messages=ChatMessages.from_messages( + [ + UserMessage(content=DEFAULT_RAGU_LM_ENTITY_DESCRIPTION_PROMPT), + ] + ), + pydantic_model=None, + description="Instruction for RAGU-lm entity description stage.", + ), + + "ragu_lm_relation_description": RAGUInstruction( + messages=ChatMessages.from_messages( + [ + 
UserMessage(content=DEFAULT_RAGU_LM_RELATION_DESCRIPTION_PROMPT), + ] + ), + pydantic_model=None, + description="Instruction for RAGU-lm relation description stage.", + ), + + "query_decomposition": RAGUInstruction( + messages=ChatMessages.from_messages( + [ + UserMessage(content=DEFAULT_QUERY_DECOMPOSITION_PROMPT), + ] + ), + pydantic_model=QueryPlan, + description="Prompt for decomposing a complex query into atomic subqueries with dependencies.", + ), + + "query_rewrite": RAGUInstruction( + messages=ChatMessages.from_messages( + [ + UserMessage(content=DEFAULT_QUERY_REWRITE_PROMPT), + ] + ), + pydantic_model=RewriteQuery, + description="Prompt for rewriting a subquery using answers from its dependencies.", ), } diff --git a/ragu/graph/artifacts_summarizer.py b/ragu/graph/artifacts_summarizer.py index 9dce945..438b78e 100644 --- a/ragu/graph/artifacts_summarizer.py +++ b/ragu/graph/artifacts_summarizer.py @@ -3,9 +3,9 @@ from typing import List, Any import pandas as pd -from pandas.core.interchange.dataframe_protocol import DataFrame from sklearn.cluster import DBSCAN +from ragu.common.global_parameters import Settings from ragu.common.base import RaguGenerativeModule from ragu.common.logger import logger from ragu.common.prompts.default_models import RelationDescriptionModel, EntityDescriptionModel @@ -13,23 +13,26 @@ from ragu.graph.types import Entity, Relation from ragu.llm.base_llm import BaseLLM +from ragu.common.prompts.prompt_storage import RAGUInstruction +from ragu.common.prompts.messages import ChatMessages, render + class EntitySummarizer(RaguGenerativeModule): def __init__( - self, - client: BaseLLM = None, - use_llm_summarization: bool = True, - use_clustering: bool = False, - embedder: BaseEmbedder = None, - cluster_only_if_more_than: int = 128, - summarize_only_if_more_than: int = 5, - language: str = "russian" + self, + client: BaseLLM = None, + use_llm_summarization: bool = True, + use_clustering: bool = False, + embedder: BaseEmbedder = None, + 
cluster_only_if_more_than: int = 128, + summarize_only_if_more_than: int = 5, + language: str | None = None, ): _PROMPTS = ["entity_summarizer", "cluster_summarize"] super().__init__(prompts=_PROMPTS) self.client = client - self.language = language + self.language = language if language else Settings.language self.use_llm_summarization = use_llm_summarization self.summarize_only_if_more_than = summarize_only_if_more_than @@ -51,23 +54,22 @@ def __init__( if self.use_clustering and not self.embedder: raise ValueError( - f"Clustering is enabled but no embedder is provided. Please provide an embedder." + "Clustering is enabled but no embedder is provided. Please provide an embedder." ) async def run(self, entities: List[Entity]) -> Any: """ Execute the full artifact summarization pipeline. - The pipeline performs the following steps: - - 1. Group duplicated entities and relations into aggregated dataframes. - 2. Summarize merged entity and relation descriptions if enabled. - 3. Return the updated lists of :class:`Entity` and :class:`Relation` objects. + Steps: + 1. Group duplicated entities by (entity_name, entity_type), + 2. Optionally cluster large description sets and summarize cluster-wise, + 3. Optionally summarize entities with many duplicates via LLM, + 4. Return updated list of Entity objects. - :param entities: List of extracted entities to summarize or merge. - :return: A tuple ``(entities, relations)`` containing updated objects. + :param entities: List of extracted entities. + :return: Summarized/deduplicated entities list. """ - if len(entities) == 0: logger.warning("Empty list of entities. Seems that something goes wrong.") return [] @@ -75,14 +77,16 @@ async def run(self, entities: List[Entity]) -> Any: grouped_entities_df = self.group_entities(entities) num_of_duplicated_entities = len(entities) - len(grouped_entities_df) - logger.info(f"Found {num_of_duplicated_entities} duplicated entities. 
" - f"Number of unique entities: {len(grouped_entities_df)} ") + logger.info( + f"Found {num_of_duplicated_entities} duplicated entities. " + f"Number of unique entities: {len(grouped_entities_df)} " + ) entities_to_return = await self.summarize_entities(grouped_entities_df) if len(entities_to_return) != len(grouped_entities_df): logger.warning( - f"{len(entities_to_return) - len(grouped_entities_df)} from {len(grouped_entities_df)} entities" + f"{len(entities_to_return) - len(grouped_entities_df)} from {len(grouped_entities_df)} entities " f"were missed during summarization." ) @@ -115,29 +119,31 @@ async def summarize_entities(self, grouped_entities_df: pd.DataFrame) -> List[En if entity_multi_desc.empty: return [Entity(**row) for _, row in entity_single_desc.iterrows()] - entities_to_summarize = [] + entities_to_summarize: List[Entity] = [] if self.use_llm_summarization: entities_to_summarize = [Entity(**row) for _, row in entity_multi_desc.iterrows()] - prompt, schema = self.get_prompt("entity_summarizer").get_instruction( + + instruction: RAGUInstruction = self.get_prompt("entity_summarizer") + rendered_list: List[ChatMessages] = render( + instruction.messages, entity=entities_to_summarize, language=self.language, ) - response: List[EntityDescriptionModel] = await self.client.generate( # type: ignore - prompt=prompt, - schema=schema, - progress_bar_desc="Entity summarization" + + response: List[EntityDescriptionModel] = await self.client.generate( # type: ignore + conversations=rendered_list, + response_model=instruction.pydantic_model, + progress_bar_desc="Entity summarization", ) for i, summary in enumerate(response): if summary: entities_to_summarize[i].description = summary.description - else: entities_to_summarize = [Entity(**row) for _, row in entity_multi_desc.iterrows()] return [Entity(**row) for _, row in entity_single_desc.iterrows()] + entities_to_summarize - @staticmethod def group_entities(entities: List[Entity]) -> pd.DataFrame: """ @@ -162,30 
+168,42 @@ def group_entities(entities: List[Entity]) -> pd.DataFrame: return grouped_entities async def _summarize_by_cluster_if_needed(self, descriptions: List[str]) -> str: + """ + Optionally cluster a large set of descriptions and summarize each cluster via LLM. + + If clustering is disabled or there are not enough descriptions, returns the + concatenation of descriptions. + + :param descriptions: List of raw descriptions for one entity. + :return: A single merged (and optionally cluster-summarized) description string. + """ if len(descriptions) > self.cluster_only_if_more_than and self.use_clustering: cluster = DBSCAN(eps=0.5, min_samples=2).fit(await self.embedder.embed(descriptions)) labels = cluster.labels_ - clusters = {} + clusters: dict[int, list[str]] = {} for label, text in zip(labels, descriptions): - if label not in clusters: - clusters[label] = [] - clusters[label].append(text) - - result_description = [] - for cluster in clusters.values(): - prompt, schema = self.get_prompt("cluster_summarize").get_instruction(content=cluster) - result = await self.client.generate( - prompt=prompt, - schema=schema, - progress_bar_desc="Map reduce for clustering" - ) # type: ignore - result_description.extend([r.content for r in result]) + clusters.setdefault(int(label), []).append(text) + + result_description: List[str] = [] + for texts in clusters.values(): + instruction: RAGUInstruction = self.get_prompt("cluster_summarize") + rendered_list: List[ChatMessages] = render( + instruction.messages, + content=texts, + language=self.language, + ) + + result = await self.client.generate( # type: ignore + conversations=rendered_list, + response_model=instruction.pydantic_model, + progress_bar_desc="Map reduce for clustering", + ) + result_description.extend([r.content for r in result if r]) return ". ".join(result_description) - else: - return ". ".join(descriptions) + return ". 
".join(descriptions) class RelationSummarizer(RaguGenerativeModule): @@ -207,11 +225,11 @@ class RelationSummarizer(RaguGenerativeModule): """ def __init__( - self, - client: BaseLLM = None, - use_llm_summarization: bool = True, - summarize_only_if_more_than: int = 5, - language: str = "russian" + self, + client: BaseLLM = None, + use_llm_summarization: bool = True, + summarize_only_if_more_than: int = 5, + language: str = "russian", ): _PROMPTS = ["relation_summarizer"] super().__init__(prompts=_PROMPTS) @@ -245,9 +263,11 @@ async def run(self, relations: List[Relation], **kwargs) -> Any: grouped_relations_df = self.group_relations(relations) - num_of_duplicated_entities = len(relations) - len(grouped_relations_df) - logger.info(f"Found {num_of_duplicated_entities} duplicated relations. " - f"Number of unique relations: {len(grouped_relations_df)} ") + num_of_duplicated_relations = len(relations) - len(grouped_relations_df) + logger.info( + f"Found {num_of_duplicated_relations} duplicated relations. 
" + f"Number of unique relations: {len(grouped_relations_df)} " + ) relations_to_return = await self.summarize_relations(grouped_relations_df) @@ -282,14 +302,21 @@ async def summarize_relations(self, grouped_relations_df: pd.DataFrame) -> List[ if relation_multi_desc.empty: return [Relation(**row) for _, row in relation_single_desc.iterrows()] - relations_to_summarize = [] + relations_to_summarize: List[Relation] = [] if self.use_llm_summarization: relations_to_summarize = [Relation(**row) for _, row in relation_multi_desc.iterrows()] - prompt, schema = self.get_prompt("relation_summarizer").get_instruction( + + instruction: RAGUInstruction = self.get_prompt("relation_summarizer") + rendered_list: List[ChatMessages] = render( + instruction.messages, relation=relations_to_summarize, language=self.language, ) - response: List[RelationDescriptionModel] = await self.client.generate(prompt=prompt, schema=schema) # type: ignore + + response: List[RelationDescriptionModel] = await self.client.generate( # type: ignore + conversations=rendered_list, + response_model=instruction.pydantic_model, + ) for i, summary in enumerate(response): if summary: @@ -302,10 +329,10 @@ async def summarize_relations(self, grouped_relations_df: pd.DataFrame) -> List[ @staticmethod def group_relations(relations: List[Relation]) -> pd.DataFrame: """ - Group relations by ``subject_id`` and ``object_id`` and merge their fields. + Group relations by (subject_id, object_id) and merge their fields. - :param relations: List of :class:`Relation` objects to group. - :return: Aggregated relations as a :class:`pandas.DataFrame`. + :param relations: List of Relation objects. + :return: Aggregated relations as a pandas DataFrame. 
""" relations_df = pd.DataFrame([asdict(relation) for relation in relations]) grouped_relations = relations_df.groupby(["subject_id", "object_id"]).agg( diff --git a/ragu/graph/community_summarizer.py b/ragu/graph/community_summarizer.py index f4ad8f2..d9e154f 100644 --- a/ragu/graph/community_summarizer.py +++ b/ragu/graph/community_summarizer.py @@ -1,8 +1,12 @@ +from textwrap import dedent from typing import List from jinja2 import Template from ragu.common.base import RaguGenerativeModule +from ragu.common.global_parameters import Settings from ragu.common.prompts.default_models import CommunityReportModel +from ragu.common.prompts.prompt_storage import RAGUInstruction +from ragu.common.prompts.messages import ChatMessages, render from ragu.graph.types import Community, CommunitySummary from ragu.llm.base_llm import BaseLLM @@ -18,29 +22,44 @@ class CommunitySummarizer(RaguGenerativeModule): Attributes ---------- client : BaseLLM - The underlying LLM client used for generating community summaries. + LLM client used for generating community reports. language : str Language of generated summaries. """ - def __init__(self, client: BaseLLM, language: str = "english") -> None: - _PROMPT = ["community_report"] - super().__init__(prompts=_PROMPT) + def __init__(self, client: BaseLLM, language: str | None = None) -> None: + _PROMPTS = ["community_report"] + super().__init__(prompts=_PROMPTS) + self.client = client - self.language = language + self.language = language if language else Settings.language async def summarize(self, communities: List[Community]) -> List[CommunitySummary]: """ Generate structured summaries for a list of graph communities. 
""" - instructions, schema = self.get_prompt("community_report").get_instruction( - community=communities, + sorted_communities = [] + for community in communities: + sorted_communities.append( + Community(entities=sorted( + community.entities, key=lambda e: e.id), + relations=sorted(community.relations, key=lambda e: e.id), + level=community.level, + cluster_id=community.cluster_id, + ) + ) + instruction: RAGUInstruction = self.get_prompt("community_report") + + rendered_list: List[ChatMessages] = render( + instruction.messages, + community=sorted_communities, language=self.language, ) summaries: List[CommunityReportModel] = await self.client.generate( # type: ignore - prompt=instructions, - schema=schema, + conversations=rendered_list, + response_model=instruction.pydantic_model, + progress_bar_desc="Summarized communities", ) output: List[CommunitySummary] = [ @@ -48,7 +67,7 @@ async def summarize(self, communities: List[Community]) -> List[CommunitySummary id=community.id, summary=self.combine_report_text(summary), ) - for (community, summary) in zip(communities, summaries) + for (community, summary) in zip(sorted_communities, summaries) ] return output @@ -61,16 +80,16 @@ def combine_report_text(report: CommunityReportModel) -> str: if not report: return "" - template = Template( + template = Template(dedent( """ Report title: {{ report.title }} Report summary: {{ report.summary }} - + {% for finding in report.findings %} Finding summary: {{ finding.summary }} Finding explanation: {{ finding.explanation }} {% endfor %} - """.strip() + """) ) return template.render(report=report) diff --git a/ragu/graph/graph_builder_pipeline.py b/ragu/graph/graph_builder_pipeline.py index ee8fc7a..28870a5 100644 --- a/ragu/graph/graph_builder_pipeline.py +++ b/ragu/graph/graph_builder_pipeline.py @@ -1,17 +1,60 @@ -import logging +from dataclasses import dataclass from typing import Any, List, Tuple from ragu.chunker import BaseChunker from ragu.chunker.types import Chunk from 
ragu.common.global_parameters import Settings from ragu.embedder.base_embedder import BaseEmbedder +from ragu.graph.artifacts_summarizer import EntitySummarizer, RelationSummarizer from ragu.graph.community_summarizer import CommunitySummarizer from ragu.graph.types import CommunitySummary, Community, Entity, Relation -from ragu.graph.artifacts_summarizer import EntitySummarizer, RelationSummarizer from ragu.llm.base_llm import BaseLLM from ragu.triplet.base_artifact_extractor import BaseArtifactExtractor +@dataclass +class KnowledgeGraphBuilderSettings: + """ + Configuration settings for the knowledge graph building pipeline. + + This dataclass controls various aspects of graph construction including + summarization strategies, clustering behavior, and optimization modes. + + Attributes + ---------- + use_llm_summarization : bool, default=True + Enable LLM-based summarization for merging and deduplicating similar + entity and relation descriptions. + use_clustering : bool, default=False + Apply clustering to group similar entities before summarization. + Helps when number of similar entities is very large. + build_only_vector_context : bool, default=False + Skip entity/relation extraction and build a context only for naive (vector) RAG. + make_community_summary : bool, default=True + Generate high-level summaries for detected communities in the graph. + Required for global search operations that rely on community-level context. + remove_isolated_nodes : bool, default=True + Remove entities that have no relations to other entities in the graph. + vectorize_chunks : bool, default=False + Generate and store embeddings for text chunks. + cluster_only_if_more_than : int, default=10000 + Minimum number of entities required before clustering is applied. + max_cluster_size : int, default=128 + Maximum number of entities per cluster during summarization. + random_seed : int, default=42 + Random seed for reproducible clustering and community detection results. 
+ """ + use_llm_summarization: bool = True + use_clustering: bool = False + build_only_vector_context: bool = False + make_community_summary: bool = True + remove_isolated_nodes: bool = True + vectorize_chunks: bool = False + cluster_only_if_more_than: int = 10000 + max_cluster_size: int = 128 + random_seed: int = 42 + + class GraphBuilderModule: """ Abstract interface for modules that extend the graph-building pipeline. @@ -56,27 +99,37 @@ class InMemoryGraphBuilder: 4. (Optional) **Additional modules** for graph enrichment. 5. **Community summarization** (aggregated graph-level summaries). - When `build_only_vector_context=True`, steps 2-5 are skipped, and only chunking - is performed. This is useful for naive vector RAG where only chunk embeddings - are needed without knowledge graph construction. + When `build_parameters.build_only_vector_context=True`, steps 2-5 are skipped, + and only chunking is performed. This is useful for naive vector RAG where only + chunk embeddings are needed without knowledge graph construction. Parameters ---------- client : BaseLLM, optional LLM client used for all text understanding and summarization tasks. - Not required if build_only_vector_context=True. - chunker : BaseChunker - Module responsible for splitting documents into semantically meaningful chunks. + Not required if build_parameters.build_only_vector_context=True. + chunker : BaseChunker, optional + Module responsible for splitting documents into chunks. artifact_extractor : BaseArtifactExtractor, optional - Extracts entities and relations from text chunks (triplet-based). - Not required if build_only_vector_context=True. + Extracts entities and relations from text chunks. + Not required if build_parameters.build_only_vector_context=True. + build_parameters : KnowledgeGraphBuilderSettings, optional + Configuration settings controlling graph building behavior including + summarization, clustering, and optimization modes. 
+ embedder : BaseEmbedder, optional + Embedding model used for vectorizing entities, relations, and optionally chunks. + llm_cache_flush_every : int, default=100 + Number of LLM calls between cache flushes to disk. + Lower values increase I/O. + embedder_cache_flush_every : int, default=100 + Number of embedder calls between cache flushes to disk. + Embedder caches are typically larger, so default flush frequency is lower. additional_pipeline : list[GraphBuilderModule], optional Optional list of post-processing modules applied after main extraction. - build_only_vector_context : bool, default=False - If True, skip entity/relation extraction and only perform chunking. - Use this for naive vector RAG without knowledge graph construction. - language : str, default="english" - Working language for summarization and extraction tasks. + Used for custom normalization, filtering and others logic. + language : str, optional + Working language for all tasks. + Default: inherited from global Settings.language ("english"). 
""" def __init__( @@ -84,46 +137,41 @@ def __init__( client: BaseLLM = None, chunker: BaseChunker = None, artifact_extractor: BaseArtifactExtractor = None, + build_parameters: KnowledgeGraphBuilderSettings = KnowledgeGraphBuilderSettings(), embedder: BaseEmbedder = None, - use_llm_summarization: bool = True, - use_clustering: bool = False, - cluster_only_if_more_than: int = 128, llm_cache_flush_every: int = 100, - embedder_cache_flush_every: int = 100000, + embedder_cache_flush_every: int = 100, additional_pipeline: List[GraphBuilderModule] = None, - build_only_vector_context: bool = False, - language: str | None = None, + language: str | None = None ): self.client = client self.chunker = chunker self.artifact_extractor = artifact_extractor self.additional_pipeline = additional_pipeline - self.language = language if language else Settings.language self.embedder = embedder - self.use_llm_summarization = use_llm_summarization - self.use_clustering = use_clustering self.llm_cache_flush_every = llm_cache_flush_every self.embedder_cache_flush_every = embedder_cache_flush_every - self.build_only_vector_context = build_only_vector_context + self.language = language if language else Settings.language + self.build_parameters = build_parameters - if build_only_vector_context: + if self.build_parameters.build_only_vector_context: # No need to create those instances => we are able not to think about its parameters self.entity_summarizer, self.relation_summarizer, self.community_summarizer = None, None, None else: self.entity_summarizer = EntitySummarizer( client, - use_llm_summarization=use_llm_summarization, - use_clustering=use_clustering, - cluster_only_if_more_than=cluster_only_if_more_than, + use_llm_summarization=self.build_parameters.use_llm_summarization, + use_clustering=self.build_parameters.use_clustering, + cluster_only_if_more_than=self.build_parameters.cluster_only_if_more_than, embedder=embedder, - language=language, + language=self.language, ) 
self.relation_summarizer = RelationSummarizer( client, - use_llm_summarization=use_llm_summarization, - language=language + use_llm_summarization=self.build_parameters.use_llm_summarization, + language=self.language ) - self.community_summarizer = CommunitySummarizer(self.client, language=language) + self.community_summarizer = CommunitySummarizer(self.client, language=self.language) async def extract_graph( self, documents: List[str] @@ -148,7 +196,7 @@ async def extract_graph( chunks = self.chunker(documents) # If only building vector context, skip entity/relation extraction - if self.build_only_vector_context: + if self.build_parameters.build_only_vector_context: return [], [], chunks # Step 2: extract entities and relations diff --git a/ragu/graph/knowledge_graph.py b/ragu/graph/knowledge_graph.py index 2d5ba5d..fb0343c 100644 --- a/ragu/graph/knowledge_graph.py +++ b/ragu/graph/knowledge_graph.py @@ -1,75 +1,80 @@ import asyncio from typing import List +from ragu.common.global_parameters import Settings from ragu.common.logger import logger from ragu.graph.graph_builder_pipeline import InMemoryGraphBuilder from ragu.graph.types import Entity, Relation, CommunitySummary from ragu.storage.index import Index -from ragu.common.global_parameters import Settings - -# TODO: implement all methods +# TODO: add all "atomic" operation (CRUD for artifacts) class KnowledgeGraph: + """ + High-level facade for knowledge graph operations. + + Handles graph construction, entity/relation merging logic, and + delegates all CRUD operations to the Index class. 
+ """ + def __init__( self, extraction_pipeline: InMemoryGraphBuilder, index: Index, - make_community_summary: bool = True, - remove_isolated_nodes: bool = True, - vectorize_chunks: bool = False, language: str | None = None, ): - self.pipeline = extraction_pipeline self.index = index - self.make_community_summary = make_community_summary - self.remove_isolated_nodes = remove_isolated_nodes - self.vectorize_chunks = vectorize_chunks + self.pipeline = extraction_pipeline + + self.build_params = extraction_pipeline.build_parameters + + self.make_community_summary = self.build_params.make_community_summary + self.remove_isolated_nodes = self.build_params.remove_isolated_nodes + self.vectorize_chunks = self.build_params.vectorize_chunks self.language = language if language else Settings.language - if self.language != self.pipeline.language: + if self.language != self.language: logger.warning( - "Override language from %s to %s", - self.pipeline.language, self.language + f"Override language from {self.pipeline.language} to {self.language}" ) - for o in self.pipeline.__dict__: - if getattr(o, "language", None): - setattr(o, "language", self.language) + for o in self.pipeline.__dict__.values(): + if hasattr(o, "language"): + o.language = self.language - self._id_to_entity_map = {} - self._id_to_relation_map = {} - - # Initialize storage folder if it doesn't exist Settings.init_storage_folder() - async def build_from_docs(self, docs) -> "KnowledgeGraph": + async def build_from_docs(self, docs: List[str]) -> "KnowledgeGraph": + """ + Build knowledge graph from documents. 
+ """ entities, relations, chunks = await self.pipeline.extract_graph(docs) - # Check if we're in vector-only mode is_vector_only = getattr(self.pipeline, 'build_only_vector_context', False) - # Add entities and relations (skip if vector-only mode) if not is_vector_only: - await self.add_entity(entities) - await self.add_relation(relations) + await self.index.make_index( + entities=entities, + relations=relations, + ) if self.remove_isolated_nodes: await self.index.graph_backend.remove_isolated_nodes() - # Save chunks (always vectorize in vector-only mode) should_vectorize = self.vectorize_chunks or is_vector_only await self.index.insert_chunks(chunks, vectorize=should_vectorize) - # Build community summaries (skip if vector-only mode) if self.make_community_summary and not is_vector_only: communities, summaries = await self.high_level_build() - await self.index._insert_communities(communities) - await self.index._insert_summaries(summaries) + await self.index.insert_communities(communities) + await self.index.insert_summaries(summaries) return self async def high_level_build(self): + """ + Build communities and their summaries. 
+ """ communities = await self.index.graph_backend.cluster() summaries = await self.pipeline.get_community_summary( communities=communities diff --git a/ragu/llm/base_llm.py b/ragu/llm/base_llm.py index 264a069..44d03ee 100644 --- a/ragu/llm/base_llm.py +++ b/ragu/llm/base_llm.py @@ -1,10 +1,16 @@ +import asyncio from abc import ABC, abstractmethod -from typing import ( - Optional, - Union -) +from typing import List, Type, Any +from aiolimiter import AsyncLimiter from pydantic import BaseModel +from tqdm.asyncio import tqdm_asyncio + +from ragu.common.batch_generator import BatchGenerator +from ragu.common.cache import PendingRequest, make_llm_cache_key, TextCache +from ragu.common.logger import logger +from ragu.common.prompts import ChatMessages +from ragu.utils.ragu_utils import AsyncRunner class BaseLLM(ABC): @@ -15,10 +21,26 @@ class BaseLLM(ABC): and maintains statistics about token usage and request outcomes. """ - def __init__(self): + def __init__( + self, + model_name: str, + max_requests_per_minute: int = 60, + max_requests_per_second: int = 1, + concurrency: int = 10, + time_period: int | float = 1, + cache_flush_every: int = 100, + ): """ Initialize the LLM base client with default usage statistics. 
""" + self.model_name = model_name + self._sem = asyncio.Semaphore(max(1, concurrency)) + self._rpm = AsyncLimiter(max_requests_per_minute, time_period=60) + self._rps = AsyncLimiter(max_requests_per_second, time_period=time_period) + self._cache_flush_every = cache_flush_every + + self.cache = TextCache(flush_every_n_writes=cache_flush_every) + self._save_stats = True self.statistics = { "total_tokens": 0, @@ -47,25 +69,77 @@ def reset_statistics(self): self.statistics[k] = 0 @abstractmethod - async def generate( + async def complete( self, - prompt: str | list[str], - system_prompt: str = None, - pydantic_model: type[BaseModel] = None, - model_name: str = None, - **kwargs - ) -> Optional[Union[BaseModel, str]]: - """ - Abstract method for text or structured output generation. + messages: ChatMessages, + response_model: Type[BaseModel] | None = None, + model_name: str | None = None, + **kwargs: Any, + ) -> str | BaseModel | None: + ... - Implementations must perform the actual LLM call and return - either a plain string response or a parsed Pydantic model. 
+ async def generate( + self, + conversations: List[ChatMessages], + response_model: Type[BaseModel] | None = None, + model_name: str | None = None, + progress_bar_desc: str = "Processing", + **kwargs: Any, + ) -> List[str | BaseModel | None]: + + results: List[str | BaseModel | None] = [None] * len(conversations) + pending: List[PendingRequest] = [] + + for i, conversation in enumerate(conversations): + key = make_llm_cache_key( + content=conversation.to_str(), + model_name=model_name or self.model_name, + schema=response_model, + kwargs=kwargs, + ) + + cached = await self.cache.get(key, schema=response_model) + if cached is not None: + results[i] = cached + else: + pending.append(PendingRequest(i, conversation, key)) + + logger.info( + f"[OpenAIClientService]: Found {len(conversations) - len(pending)}/{len(conversations)} requests in cache.") + + if not pending: + return results + + with tqdm_asyncio(total=len(pending), desc=progress_bar_desc) as pbar: + runner = AsyncRunner(self._sem, self._rps, self._rpm, pbar) + + for batch in BatchGenerator(pending, self._cache_flush_every).get_batches(): + tasks = [ + runner.make_request( + self.complete, + messages=req.messages, + model_name=model_name or self.model_name, + response_model=response_model, + **kwargs + ) + for req in batch + ] + + generated = await asyncio.gather(*tasks) + + for req, value in zip(batch, generated): + if not isinstance(value, Exception) and value is not None: + await self.cache.set( + req.cache_key, + value, + input_instruction=req.messages.to_str(), + model_name=model_name or self.model_name, + ) + results[req.index] = value + else: + results[req.index] = None + + await self.cache.flush_cache() + + return results - :param prompt: User prompt or a list of prompts for batch generation. - :param system_prompt: Optional system-level context or instruction. - :param pydantic_model: Optional Pydantic model for structured output parsing. - :param model_name: Optional override for model selection. 
- :param kwargs: Additional generation parameters (e.g., temperature, max_tokens). - :return: Generated text or structured model instance, or ``None`` on failure. - """ - pass diff --git a/ragu/llm/openai_client.py b/ragu/llm/openai_client.py index c32aac8..e8062a8 100644 --- a/ragu/llm/openai_client.py +++ b/ragu/llm/openai_client.py @@ -1,13 +1,10 @@ -import asyncio from typing import ( Any, - List, Optional, - Union, + Union, Type, ) import instructor -from aiolimiter import AsyncLimiter from openai import AsyncOpenAI from pydantic import BaseModel from tenacity import ( @@ -15,14 +12,10 @@ wait_exponential, retry, ) -from tqdm.asyncio import tqdm_asyncio -from ragu.common.batch_generator import BatchGenerator -from ragu.common.cache import TextCache, PendingRequest, make_llm_cache_key from ragu.common.logger import logger -from ragu.common.decorator import no_throw +from ragu.common.prompts import ChatMessages from ragu.llm.base_llm import BaseLLM -from ragu.utils.ragu_utils import AsyncRunner class OpenAIClient(BaseLLM): @@ -35,13 +28,13 @@ def __init__( model_name: str, base_url: str, api_token: str, - concurrency: int = 8, - request_timeout: float = 60.0, - instructor_mode: instructor.Mode = instructor.Mode.JSON, max_requests_per_minute: int = 60, max_requests_per_second: int = 1, + concurrency: int = 10, time_period: int | float = 1, cache_flush_every: int = 100, + request_timeout: float = 60.0, + instructor_mode: instructor.Mode = instructor.Mode.JSON, **openai_kwargs: Any, ): """ @@ -58,13 +51,14 @@ def __init__( :param cache_flush_every: Flush cache to disk every N requests (default 100). :param openai_kwargs: Additional keyword arguments passed to AsyncOpenAI. 
""" - super().__init__() - - self.model_name = model_name - self._sem = asyncio.Semaphore(max(1, concurrency)) - self._rpm = AsyncLimiter(max_requests_per_minute, time_period=60) - self._rps = AsyncLimiter(max_requests_per_second, time_period=time_period) - self._cache_flush_every = cache_flush_every + super().__init__( + model_name=model_name, + max_requests_per_minute=max_requests_per_minute, + max_requests_per_second=max_requests_per_second, + concurrency=concurrency, + time_period=time_period, + cache_flush_every=cache_flush_every, + ) base_client = AsyncOpenAI( base_url=base_url, @@ -75,38 +69,28 @@ def __init__( self._client = instructor.from_openai(client=base_client, mode=instructor_mode) - self.cache = TextCache(flush_every_n_writes=cache_flush_every) - @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10)) - async def _one_call( + async def complete( self, - prompt: str, - schema: Optional[BaseModel] = None, - system_prompt: Optional[str] = None, - model_name: Optional[str] = None, + messages: ChatMessages, + response_model: Type[BaseModel]=None, + model_name: str = None, **kwargs: Any, ) -> Optional[Union[str, BaseModel]]: """ Perform a single generation request to the LLM with retry logic. - :param prompt: The input text or instruction prompt. - :param schema: Optional Pydantic model defining the structured response format. - :param system_prompt: Optional system-level instruction prepended to the prompt. :param model_name: Override model name for this call (defaults to client model). :param kwargs: Additional API call parameters. :return: Parsed model output or raw string, or ``None`` if failed. 
""" - messages = [{"role": "user", "content": prompt}] - if system_prompt: - messages.insert(0, {"role": "system", "content": system_prompt}) - try: self.statistics["requests"] += 1 parsed: BaseModel = await self._client.chat.completions.create( model=model_name or self.model_name, - messages=messages, # type: ignore - response_model=schema, + messages=messages.to_openai(), + response_model=response_model, **kwargs, ) self.statistics["success"] += 1 @@ -117,82 +101,6 @@ async def _one_call( self.statistics["fail"] += 1 raise - @no_throw - async def generate( - self, - prompt: str | list[str], - *, - system_prompt: Optional[str] = None, - model_name: Optional[str] = None, - progress_bar_desc: Optional[str] = "Processing", - schema: Optional[type[BaseModel]] = None, - **kwargs: Any, - ) -> List[Optional[Union[str, BaseModel]]]: - - prompts: List[str] = [prompt] if isinstance(prompt, str) else list(prompt) - - results: list[Optional[Union[str, BaseModel]]] = [None] * len(prompts) - pending: list[PendingRequest] = [] - - for i, p in enumerate(prompts): - key = make_llm_cache_key( - prompt=p, - system_prompt=system_prompt, - model_name=model_name or self.model_name, - schema=schema, - kwargs=kwargs, - ) - - cached = await self.cache.get(key, schema=schema) - if cached is not None: - results[i] = cached - else: - pending.append(PendingRequest(i, p, key)) - - logger.info(f"[OpenAIClientService]: Found {len(prompts) - len(pending)}/{len(prompts)} requests in cache.") - - if not pending: - return results - - with tqdm_asyncio(total=len(pending), desc=progress_bar_desc) as pbar: - runner = AsyncRunner(self._sem, self._rps, self._rpm, pbar) - - for batch in BatchGenerator(pending, self._cache_flush_every).get_batches(): - tasks = [ - runner.make_request( - self._one_call, - prompt=req.prompt, - system_prompt=system_prompt, - model_name=model_name, - schema=schema, - **kwargs - ) - for req in batch - ] - - generated = await asyncio.gather(*tasks, return_exceptions=True) - - 
for req, value in zip(batch, generated): - if not isinstance(value, Exception) and value is not None: - if system_prompt: - input_instruction = f"[system]: {system_prompt}\n[user]: {req.prompt}" - else: - input_instruction = req.prompt - - await self.cache.set( - req.cache_key, - value, - input_instruction=input_instruction, - model_name=model_name or self.model_name, - ) - results[req.index] = value - else: - results[req.index] = None - - await self.cache.flush_cache() - - return results - async def async_close(self) -> None: """ Close the underlying asynchronous OpenAI client and flush cache. @@ -205,4 +113,4 @@ async def async_close(self) -> None: try: await self._client.close() except Exception: - pass + pass \ No newline at end of file diff --git a/ragu/search_engine/base_engine.py b/ragu/search_engine/base_engine.py index cc93b57..8b948fb 100644 --- a/ragu/search_engine/base_engine.py +++ b/ragu/search_engine/base_engine.py @@ -1,30 +1,32 @@ from abc import ABC, abstractmethod -from pydantic import BaseModel - from ragu.common.base import RaguGenerativeModule +from ragu.common.prompts.default_models import GlobalSearchContextModel +from ragu.llm.base_llm import BaseLLM +from ragu.search_engine.types import NaiveSearchResult, LocalSearchResult from ragu.utils.ragu_utils import always_get_an_event_loop class BaseEngine(RaguGenerativeModule, ABC): - def __init__(self, *args, **kwargs): + def __init__(self, client: BaseLLM, *args, **kwargs): super().__init__(*args, **kwargs) + self.client = client @abstractmethod - async def a_search(self, query, *args, **kwargs): + async def a_search(self, query, *args, **kwargs) -> NaiveSearchResult | LocalSearchResult | GlobalSearchContextModel: """ Get relevant information from knowledge graph """ pass @abstractmethod - async def a_query(self, query: str) -> BaseModel: + async def a_query(self, query: str) -> str: """ Get answer on query from knowledge graph """ pass - async def query(self, query: str) -> BaseModel: + async 
def query(self, query: str) -> str: """ Get answer on query from knowledge graph """ @@ -33,7 +35,7 @@ async def query(self, query: str) -> BaseModel: self.a_query(query) ) - async def search(self, query, *args, **kwargs): + async def search(self, query, *args, **kwargs) -> NaiveSearchResult | LocalSearchResult | GlobalSearchContextModel: """ Get relevant information from knowledge graph """ diff --git a/ragu/search_engine/global_search.py b/ragu/search_engine/global_search.py index 6762519..f4a05f7 100644 --- a/ragu/search_engine/global_search.py +++ b/ragu/search_engine/global_search.py @@ -4,12 +4,16 @@ from pydantic import BaseModel from ragu.common.base import RaguGenerativeModule +from ragu.common.global_parameters import Settings from ragu.graph.knowledge_graph import KnowledgeGraph from ragu.llm.base_llm import BaseLLM from ragu.search_engine.base_engine import BaseEngine from ragu.search_engine.types import GlobalSearchResult from ragu.utils.token_truncation import TokenTruncation +from ragu.common.prompts.prompt_storage import RAGUInstruction +from ragu.common.prompts.messages import ChatMessages, render + class GlobalSearchEngine(BaseEngine, RaguGenerativeModule): """ @@ -27,8 +31,9 @@ def __init__( max_context_length: int = 30_000, tokenizer_backend: str = "tiktoken", tokenizer_model: str = "gpt-4", + language: str | None = None, *args, - **kwargs + **kwargs, ): """ Initialize a new `GlobalSearchEngine`. @@ -40,14 +45,16 @@ def __init__( :param tokenizer_model: Model name for tokenizer calibration (default: ``"gpt-4"``). 
""" _PROMPTS = ["global_search_context", "global_search"] - super().__init__(prompts=_PROMPTS, *args, **kwargs) + super().__init__(client=client, prompts=_PROMPTS, *args, **kwargs) - self.knowledge_graph = knowledge_graph self.client = client + self.knowledge_graph = knowledge_graph + self.language = language if language else Settings.language + self.truncation = TokenTruncation( tokenizer_model, tokenizer_backend, - max_context_length + max_context_length, ) async def a_search(self, query: str, *args, **kwargs) -> GlobalSearchResult: @@ -66,12 +73,12 @@ async def a_search(self, query: str, *args, **kwargs) -> GlobalSearchResult: self.knowledge_graph.index.community_summary_kv_storage.get_by_id(community_cluster_id) for community_cluster_id in await self.knowledge_graph.index.communities_kv_storage.all_keys() ]) - communities = list(filter(lambda x: x is not None, communities)) + communities = [c for c in communities if c is not None] responses = await self.get_meta_responses(query, communities) - responses: list[dict] = list(filter(lambda x: int(x.get("rating", 0)) > 0, responses)) - responses: list[dict] = sorted(responses, key=lambda x: int(x.get("rating", 0)), reverse=True) + responses = [r for r in responses if int(r.get("rating", 0)) > 0] + responses = sorted(responses, key=lambda x: int(x.get("rating", 0)), reverse=True) return GlobalSearchResult(responses) @@ -87,19 +94,23 @@ async def get_meta_responses(self, query: str, context: List[str]) -> List[dict] :param context: A list of community summary texts to evaluate. :return: A list of structured responses with fields such as ``response`` and ``rating``. 
""" - prompts, schema = self.get_prompt("global_search_context").get_instruction( + instruction: RAGUInstruction = self.get_prompt("global_search_context") + + rendered_list: List[ChatMessages] = render( + instruction.messages, query=query, - context=context + context=context, + language=self.language, ) meta_responses = await self.client.generate( - prompt=prompts, - schema=schema + conversations=rendered_list, + response_model=instruction.pydantic_model, ) - return [response.model_dump() for response in meta_responses if response] + return [r.model_dump() for r in meta_responses if r] - async def a_query(self, query: str) -> BaseModel: + async def a_query(self, query: str) -> str: """ Execute a full global retrieval-augmented generation query. @@ -112,12 +123,19 @@ async def a_query(self, query: str) -> BaseModel: context = await self.a_search(query) truncated_context: str = self.truncation(str(context)) - prompts, schema = self.get_prompt("global_search").get_instruction( + instruction: RAGUInstruction = self.get_prompt("global_search") + + rendered_list: List[ChatMessages] = render( + instruction.messages, query=query, - context=truncated_context + context=truncated_context, + language=self.language, ) + rendered = rendered_list[0] - return await self.client.generate( - prompt=prompts, - schema=schema + result = await self.client.generate( + conversations=[rendered], + response_model=instruction.pydantic_model, ) + + return result[0].response if hasattr(result, "response") else result diff --git a/ragu/search_engine/local_search.py b/ragu/search_engine/local_search.py index d76d3ff..34305f4 100644 --- a/ragu/search_engine/local_search.py +++ b/ragu/search_engine/local_search.py @@ -1,7 +1,9 @@ # Partially based on https://github.com/gusye1234/nano-graphrag/blob/main/nano_graphrag/ import asyncio +from typing import List +from ragu.common.global_parameters import Settings from ragu.embedder.base_embedder import BaseEmbedder from ragu.graph.knowledge_graph import 
KnowledgeGraph from ragu.llm.base_llm import BaseLLM @@ -10,18 +12,24 @@ _find_most_related_edges_from_entities, _find_most_related_text_unit_from_entities, _find_documents_id, - _find_most_related_community_from_entities + _find_most_related_community_from_entities, ) from ragu.search_engine.types import LocalSearchResult from ragu.utils.token_truncation import TokenTruncation +from ragu.common.prompts.prompt_storage import RAGUInstruction +from ragu.common.prompts.messages import ChatMessages, render + class LocalSearchEngine(BaseEngine): """ Performs local retrieval-augmented search (RAG) over a knowledge graph. - This engine finds entities, relations, and text units most relevant to a given - query, builds a local context, and passes it to an LLM for response generation. + The engine: + 1) Retrieves relevant entities/relations/chunks/summaries for the query, + 2) Builds and truncates a local context, + 3) Renders a prompt template (Jinja2) into concrete ChatMessages, + 4) Sends OpenAI-typed messages to the LLM for generation/parsing. Reference --------- @@ -36,40 +44,41 @@ def __init__( max_context_length: int = 30_000, tokenizer_backend: str = "tiktoken", tokenizer_model: str = "gpt-4", + language: str | None = None, *args, - **kwargs + **kwargs, ): """ Initialize a `LocalSearchEngine`. - :param client: Language model client for generation. + :param client: LLM client used to generate the final answer. :param knowledge_graph: Knowledge graph used for entity and relation retrieval. - :param embedder: Embedding model for similarity search. - :param max_context_length: Maximum number of tokens allowed in the truncated context. - :param tokenizer_backend: Tokenizer backend to use (e.g. ``tiktoken``). - :param tokenizer_model: Model name used for token counting and truncation. + :param embedder: Embedding model used for similarity search. + :param max_context_length: Max tokens allowed for the final context (after truncation). 
+ :param tokenizer_backend: Tokenizer backend used for token counting/truncation. + :param tokenizer_model: Model name used by the tokenizer backend. + :param language: Default output language (fed into prompt template). """ _PROMPTS_NAMES = ["local_search"] - super().__init__(prompts=_PROMPTS_NAMES, *args, **kwargs) + super().__init__(client=client, prompts=_PROMPTS_NAMES, *args, **kwargs) self.truncation = TokenTruncation( tokenizer_model, tokenizer_backend, - max_context_length + max_context_length, ) self.knowledge_graph = knowledge_graph self.embedder = embedder - self.client = client - self.community_reports = None + self.language = language if language else Settings.language async def a_search(self, query: str, top_k: int = 20, *args, **kwargs) -> LocalSearchResult: """ - Perform a local search on the knowledge graph. + Retrieve local graph context for the given query. - :param query: The input text query to search for. - :param top_k: Number of top entities to include in context (default: 20). - :return: A :class:`SearchResult` object containing entities, relations, and chunks. + :param query: Input query string. + :param top_k: Number of top entities to retrieve from the entity vector DB. + :return: LocalSearchResult containing entities, relations, summaries, chunks, and document ids. """ entities_id = await self.knowledge_graph.index.entity_vector_db.query(query, top_k=top_k) @@ -94,24 +103,39 @@ async def a_search(self, query: str, top_k: int = 20, *args, **kwargs) -> LocalS relations=relations, summaries=summaries, chunks=relevant_chunks, - documents_id=documents_id + documents_id=documents_id, ) async def a_query(self, query: str, top_k: int = 20) -> str: """ - Execute a retrieval-augmented query over the local knowledge graph. + Execute a local RAG query. 
+ + Steps: + 1) Run `a_search` to build a local graph-derived context + 2) Truncate the context to fit `max_context_length` + 3) Take the stored RAGUInstruction, render its ChatMessages via Jinja2 + using (query, context, language) + 4) Send rendered messages to the LLM. If the instruction has `pydantic_model`, + the LLM client may parse into that model. :param query: User query in natural language. - :param top_k: Number of entities to search in the local context (default: 20). - :return: Generated response text from the language model. + :param top_k: Number of entities to retrieve into context. + :return: Final model response (string or extracted field if returned model-like). """ - context: LocalSearchResult = await self.a_search(query, top_k) truncated_context: str = self.truncation(str(context)) + instruction: RAGUInstruction = self.get_prompt("local_search") - prompt, schema = self.get_prompt("local_search").get_instruction( + rendered_conversations: List[ChatMessages] = render( + instruction.messages, query=query, - context=truncated_context + context=truncated_context, + language=self.language, + ) + rendered: ChatMessages = rendered_conversations[0] + result = await self.client.generate( + conversations=[rendered], + response_model=instruction.pydantic_model, ) - return await self.client.generate(prompt=prompt, schema=schema) + return result[0].response if hasattr(result[0], "response") else result[0] diff --git a/ragu/search_engine/naive_search.py b/ragu/search_engine/naive_search.py index 69d65c6..628b128 100644 --- a/ragu/search_engine/naive_search.py +++ b/ragu/search_engine/naive_search.py @@ -1,15 +1,18 @@ from typing import Optional, List from ragu.chunker.types import Chunk +from ragu.common.global_parameters import Settings from ragu.embedder.base_embedder import BaseEmbedder from ragu.graph.knowledge_graph import KnowledgeGraph from ragu.llm.base_llm import BaseLLM from ragu.rerank.base_reranker import BaseReranker from 
ragu.search_engine.base_engine import BaseEngine from ragu.search_engine.types import NaiveSearchResult -from ragu.storage.index import Index from ragu.utils.token_truncation import TokenTruncation +from ragu.common.prompts.prompt_storage import RAGUInstruction +from ragu.common.prompts.messages import ChatMessages, render + class NaiveSearchEngine(BaseEngine): """ @@ -28,33 +31,36 @@ def __init__( max_context_length: int = 30_000, tokenizer_backend: str = "tiktoken", tokenizer_model: str = "gpt-4", + language: str | None = None, *args, **kwargs ): """ Initialize a `NaiveSearchEngine`. - :param client: Language model client for generation. - :param knowledge_graph: Knowledge graph containing chunk vector database and KV storage. - :param embedder: Embedding model for similarity search. - :param reranker: Optional reranker for improving retrieval quality. - :param max_context_length: Maximum number of tokens allowed in the truncated context. - :param tokenizer_backend: Tokenizer backend to use (e.g. ``tiktoken``). - :param tokenizer_model: Model name used for token counting and truncation. + :param client: LLM client used to generate the final answer. + :param knowledge_graph: Knowledge graph containing chunk vector DB and chunk KV storage. + :param embedder: Embedding model (kept for interface parity; retrieval uses graph index DBs). + :param reranker: Optional reranker used to improve ranking of retrieved chunks. + :param max_context_length: Max tokens allowed for context after truncation. + :param tokenizer_backend: Tokenizer backend used for token truncation. + :param tokenizer_model: Model name used by the tokenizer backend. 
+ :param language: Default output language """ _PROMPTS_NAMES = ["naive_search"] - super().__init__(prompts=_PROMPTS_NAMES, *args, **kwargs) + super().__init__(client=client, prompts=_PROMPTS_NAMES, *args, **kwargs) self.truncation = TokenTruncation( tokenizer_model, tokenizer_backend, - max_context_length + max_context_length, ) self.graph = knowledge_graph self.embedder = embedder self.reranker = reranker self.client = client + self.language = language if language else Settings.language async def a_search( self, @@ -67,11 +73,11 @@ async def a_search( """ Perform a naive vector search over chunks. - :param query: The input text query to search for. - :param top_k: Number of top chunks to retrieve initially (default: 20). - :param rerank_top_k: Number of chunks to return after reranking. - If None, returns all reranked chunks. Only used if reranker is set. - :return: A :class:`NaiveSearchResult` object containing chunks and scores. + :param query: Input query string. + :param top_k: Number of top chunks to retrieve initially. + :param rerank_top_k: Number of chunks to keep after reranking. + If None, keeps all reranked chunks. Used only when reranker is set. + :return: NaiveSearchResult with retrieved chunks, scores, and document ids. 
""" results = await self.graph.index.chunk_vector_db.query(query, top_k=top_k) @@ -115,13 +121,12 @@ async def a_search( chunks = chunks[:rerank_top_k] scores = scores[:rerank_top_k] - # Collect document IDs - documents_id = list(set(c.doc_id for c in chunks if c.doc_id)) + documents_id = list({c.doc_id for c in chunks if c.doc_id}) return NaiveSearchResult( chunks=chunks, scores=scores, - documents_id=documents_id + documents_id=documents_id, ) async def a_query(self, query: str, top_k: int = 20, rerank_top_k: Optional[int] = None) -> str: @@ -136,9 +141,19 @@ async def a_query(self, query: str, top_k: int = 20, rerank_top_k: Optional[int] context: NaiveSearchResult = await self.a_search(query, top_k, rerank_top_k) truncated_context: str = self.truncation(str(context)) - prompt, schema = self.get_prompt("naive_search").get_instruction( + instruction: RAGUInstruction = self.get_prompt("naive_search") + + rendered_list: List[ChatMessages] = render( + instruction.messages, query=query, - context=truncated_context + context=truncated_context, + language=self.language, + ) + rendered: ChatMessages = rendered_list[0] + + result = await self.client.generate( + conversations=[rendered], + response_model=instruction.pydantic_model, ) - return await self.client.generate(prompt=prompt, schema=schema) + return result.response if hasattr(result, "response") else str(result) diff --git a/ragu/search_engine/query_plan.py b/ragu/search_engine/query_plan.py index ad9b21d..2a3afb0 100644 --- a/ragu/search_engine/query_plan.py +++ b/ragu/search_engine/query_plan.py @@ -4,20 +4,28 @@ from ragu.search_engine.base_engine import BaseEngine from ragu.search_engine.search_functional import _topological_sort +from ragu.common.prompts.prompt_storage import RAGUInstruction +from ragu.common.prompts.messages import ChatMessages, render + class QueryPlanEngine(BaseEngine): """ - A query planning engine that decomposes complex queries into subqueries. 
- - It analyzes the input query, breaks it down into a dependency graph of simpler - subqueries, executes them in topological order, and combines results to produce - a final answer. - - :param engine: The base search engine used to execute individual subqueries. + Query planning engine that decomposes complex queries into a DAG of subqueries + and executes them in topological order. + + Pipeline: + 1. Decompose query -> list[SubQuery] (DAG) + 2. Topological sort + 3. For each subquery: + - rewrite using dependency answers (if needed) + - execute with underlying engine + - store answer in context + 4. Return answer of the last subquery """ - def __init__(self, engine, *args, **kwargs): + + def __init__(self, engine: BaseEngine, *args, **kwargs): _PROMPTS_NAMES = ["query_decomposition", "query_rewrite"] - super().__init__(prompts=_PROMPTS_NAMES, *args, **kwargs) + super().__init__(client=engine.client, prompts=_PROMPTS_NAMES, *args, **kwargs) self.engine: BaseEngine = engine async def process_query(self, query: str) -> List[SubQuery]: @@ -28,39 +36,59 @@ async def process_query(self, query: str) -> List[SubQuery]: independent subqueries. Each subquery is assigned a unique ID and may declare dependencies on other subqueries that must be resolved first. - :param query: The complex natural-language query to decompose. - :return: List of :class:`SubQuery` objects forming a directed acyclic graph (DAG). + :param query: Complex natural-language query to decompose. + :return: List of SubQuery objects forming a DAG. 
""" - prompt, schema = self.get_prompt("query_decomposition").get_instruction( - query=query + instruction: RAGUInstruction = self.get_prompt("query_decomposition") + + rendered_list: List[ChatMessages] = render( + instruction.messages, + query=query, ) + rendered = rendered_list[0] response = await self.engine.client.generate( - prompt=prompt, - schema=schema + conversations=[rendered], + response_model=instruction.pydantic_model, ) - print(response[0].subqueries) + return response[0].subqueries async def _rewrite_subquery(self, subquery: SubQuery, context: Dict[str, str]) -> str: """ - Rewrites a subquery using answers of its dependencies. + Rewrite a subquery by injecting answers from its dependency subqueries. + + Only dependency answers listed in `subquery.depends_on` are provided + to the rewrite prompt. + + :param subquery: The subquery to rewrite. + :param context: Mapping of {subquery_id -> answer} accumulated so far. + :return: Rewritten, self-contained query string. """ - context = {k: v for k, v in context.items() if k in subquery.depends_on} - prompt, schema = self.get_prompt("query_rewrite").get_instruction( + dep_context = {k: v for k, v in context.items() if k in subquery.depends_on} + + instruction: RAGUInstruction = self.get_prompt("query_rewrite") + rendered_list: List[ChatMessages] = render( + instruction.messages, original_query=subquery.query, - context=context + context=dep_context, ) + rendered = rendered_list[0] + response = await self.engine.client.generate( - prompt=prompt, - schema=schema + conversations=[rendered], + response_model=instruction.pydantic_model, ) - return response[0].query.strip() + + return response[0].query if hasattr(response[0], "query") else response async def _answer_subquery(self, subquery: SubQuery, context: Dict[str, str]) -> str: """ - Executes a single subquery. - Injects answers of dependencies into the prompt. + Execute a single subquery, rewriting it first if it has dependencies. 
+ + :param subquery: The subquery to execute. + :param context: Mapping of {subquery_id -> answer} for dependency injection. + :return: Answer string for this subquery. """ if subquery.depends_on: query = await self._rewrite_subquery(subquery, context) @@ -68,7 +96,8 @@ async def _answer_subquery(self, subquery: SubQuery, context: Dict[str, str]) -> query = subquery.query result = await self.engine.a_query(query) - return result[0].model_dump().get("response") + + return result async def a_query(self, query: str) -> str: """ @@ -90,7 +119,6 @@ async def a_query(self, query: str) -> str: ordered = _topological_sort(subqueries) context: Dict[str, str] = {} - for subquery in ordered: answer = await self._answer_subquery(subquery, context) context[subquery.id] = answer @@ -107,5 +135,3 @@ async def a_search(self, query, *args, **kwargs): :return: Search results from the underlying engine. """ return await self.engine.a_search(query, *args, **kwargs) - - diff --git a/ragu/search_engine/types.py b/ragu/search_engine/types.py index 3800d66..78db5a5 100644 --- a/ragu/search_engine/types.py +++ b/ragu/search_engine/types.py @@ -1,4 +1,5 @@ from dataclasses import dataclass, field +from textwrap import dedent from jinja2 import Template @@ -11,32 +12,32 @@ class LocalSearchResult: chunks: list=field(default_factory=list) documents_id: list[str]=field(default_factory=list) - _template: Template = Template( -""" -**Entities**\nEntity, entity type, entity description -{%- for e in entities %} -{{ e.entity_name }}, {{ e.entity_type }}, {{ e.description }} -{%- endfor %} - -**Relations**\nSubject, object, relation description, rank -{%- for r in relations %} -{{ r.subject_name }}, {{ r.object_name }}, {{ r.description }}, {{ r.rank }} -{%- endfor %} - -{%- if summaries %} -Summary -{%- for s in summaries %} -{{ s }} -{%- endfor %} -{% endif %} - -{%- if chunks %} -Chunks -{%- for c in chunks %} -{{ c.content }} -{%- endfor %} -{% endif %} -""" + _template: Template = 
Template(dedent( + """ + **Entities**\nEntity, entity type, entity description + {%- for e in entities %} + {{ e.entity_name }}, {{ e.entity_type }}, {{ e.description }} + {%- endfor %} + + **Relations**\nSubject, object, relation description, rank + {%- for r in relations %} + {{ r.subject_name }}, {{ r.object_name }}, {{ r.description }}, {{ r.rank }} + {%- endfor %} + + {%- if summaries %} + **Summary** + {%- for s in summaries %} + {{ s }} + {%- endfor %} + {% endif %} + + {%- if chunks %} + **Chunks** + {%- for c in chunks %} + {{ c.content }} + {%- endfor %} + {% endif %} + """) ) def __str__(self) -> str: @@ -52,12 +53,12 @@ def __str__(self) -> str: class GlobalSearchResult: insights: list=field(default_factory=list) - _template: Template = Template( + _template: Template = Template(dedent( """ {%- for insight in insights %} {{ loop.index}}. Insight: {{ insight.response }}, rating: {{ insight.rating }} {%- endfor %} - """.strip() + """) ) def __str__(self) -> str: @@ -70,14 +71,14 @@ class NaiveSearchResult: scores: list=field(default_factory=list) documents_id: list[str]=field(default_factory=list) - _template: Template = Template( -""" -**Retrieved Chunks** -{%- for chunk, score in zip(chunks, scores) %} -[{{ loop.index }}] (score: {{ "%.3f"|format(score) }}) -{{ chunk.content }} -{%- endfor %} -""" + _template: Template = Template(dedent( + """ + **Retrieved Chunks** + {%- for chunk, score in zip(chunks, scores) %} + [{{ loop.index }}] (score: {{ "%.3f"|format(score) }}) + {{ chunk.content }} + {%- endfor %} + """) ) def __str__(self) -> str: diff --git a/ragu/storage/base_storage.py b/ragu/storage/base_storage.py index 725c80c..e410f00 100644 --- a/ragu/storage/base_storage.py +++ b/ragu/storage/base_storage.py @@ -108,7 +108,7 @@ async def remove_isolated_nodes(self): ... @abstractmethod - async def cluster(self, **kwargs): + async def cluster(self, max_cluster_size: int=128, **additional_cluster_kwargs): ... 
@abstractmethod diff --git a/ragu/storage/graph_storage_adapters/networkx_adapter.py b/ragu/storage/graph_storage_adapters/networkx_adapter.py index e014da5..955c959 100644 --- a/ragu/storage/graph_storage_adapters/networkx_adapter.py +++ b/ragu/storage/graph_storage_adapters/networkx_adapter.py @@ -57,7 +57,9 @@ class NetworkXStorage(BaseGraphStorage): def __init__( self, filename: str, - clustering_params=None, + random_seed: Optional[int] = 42, + max_cluster_size: int=128, + clustering_params: Dict | None=None, **kwargs, ): """ @@ -67,7 +69,11 @@ def __init__( :param clustering_params: Optional parameters for community detection. """ if clustering_params is None: - clustering_params = {"max_community_size": 1000} + clustering_params = {} + if "max_cluster_size" not in clustering_params: + clustering_params.update({"max_cluster_size": max_cluster_size}) + if "random_seed" not in clustering_params: + clustering_params.update({"random_seed": random_seed}) self._graph: nx.Graph = nx.read_gml(filename) if os.path.exists(filename) else nx.Graph() self._where_to_save = filename @@ -134,6 +140,48 @@ async def get_node_edges(self, source_node_id: str) -> List[Relation]: unique_relations.append(r) return unique_relations + async def get_all_edges_for_node(self, node_id: str) -> List[Relation]: + """ + Retrieve all edges where the node is either subject or object. + + For undirected graphs, this is equivalent to get_node_edges. + For directed graphs, this would include both incoming and outgoing edges. + + :param node_id: ID of the node whose edges to fetch. + :return: List of all relations involving this node. 
+ """ + if not self._graph.has_node(node_id): + return [] + + relations: List[Relation] = [] + + for u, v, metadata in self._graph.edges(node_id, data=True): + subject_id = str(u) + object_id = str(v) + subject_name = self._graph.nodes.get(u, {}).get("entity_name", subject_id) + object_name = self._graph.nodes.get(v, {}).get("entity_name", object_id) + relation = Relation( + subject_id=subject_id, + object_id=object_id, + subject_name=subject_name, + object_name=object_name, + description=metadata.get("description", ""), + relation_strength=float(metadata.get("relation_strength", 1.0)), + source_chunk_id=list(metadata.get("source_chunk_id", [])), + id=metadata.get("id"), + ) + relations.append(relation) + + seen: set[Tuple[str, str]] = set() + unique_relations: List[Relation] = [] + for r in relations: + key = tuple(sorted((r.subject_id, r.object_id))) + if key in seen: + continue + seen.add(key) + unique_relations.append(r) + return unique_relations + # TODO: add calculating async def get_edge_degree(self, source_node_id: str, target_node_id: str) -> int: """ diff --git a/ragu/storage/index.py b/ragu/storage/index.py index 2c35cbe..58be296 100644 --- a/ragu/storage/index.py +++ b/ragu/storage/index.py @@ -9,6 +9,7 @@ Iterable, List, Optional, + Tuple, Type ) @@ -16,12 +17,14 @@ from ragu.common.global_parameters import DEFAULT_FILENAMES from ragu.common.global_parameters import Settings from ragu.embedder.base_embedder import BaseEmbedder +from ragu.graph.graph_builder_pipeline import KnowledgeGraphBuilderSettings from ragu.graph.types import ( Entity, Relation, Community, CommunitySummary ) +from ragu.storage.transaction import DeleteTransaction from ragu.storage.base_storage import ( BaseKVStorage, BaseVectorStorage, @@ -35,12 +38,17 @@ class Index: """ - Index class that manages storages for a knowledge graph. + Index class that manages all storage operations for a knowledge graph. 
+ + Provides CRUD operations for entities, relations, chunks, communities, + and community summaries with proper cascading deletes and multi-storage + consistency. """ def __init__( self, embedder: BaseEmbedder, + builder_parameters: KnowledgeGraphBuilderSettings = KnowledgeGraphBuilderSettings, graph_backend_storage: Type[BaseGraphStorage] = NetworkXStorage, kv_storage_type: Type[BaseKVStorage] = JsonKVStorage, vdb_storage_type: Type[BaseVectorStorage] = NanoVectorDBStorage, @@ -57,6 +65,8 @@ def __init__( Settings.init_storage_folder() storage_folder: str = Settings.storage_folder + self.builder_parameters = builder_parameters + self.embedder = embedder self.summary_kv_storage_kwargs = self._build_storage_kwargs( storage_folder, @@ -108,7 +118,11 @@ def __init__( self.chunk_vector_db = vdb_storage_type(embedder=embedder, **chunk_vdb_storage_kwargs) # type: ignore # Graph storage - self.graph_backend = graph_backend_storage(**self.graph_storage_kwargs) # type: ignore + self.graph_backend = graph_backend_storage( + max_cluster_size=builder_parameters.max_cluster_size, # type: ignore + random_seed=builder_parameters.random_seed, # type: ignore + **self.graph_storage_kwargs # type: ignore + ) async def make_index( self, @@ -118,54 +132,51 @@ async def make_index( summaries: List[CommunitySummary] = None, ) -> None: """ - Creates an index for the given knowledge graph. Save entities, relations, communities and community summaries. + Creates an index for the given knowledge graph items. 
""" tasks = [] if entities: tasks.extend( [ - self._insert_entities_to_graph(entities), - self._insert_entities_to_vdb(entities), + self.insert_entities_to_graph(entities), + self.insert_entities_to_vdb(entities), ] ) if relations: tasks.extend( [ - self._insert_relations_to_graph(relations), - self._insert_relations_to_vdb(relations), + self.insert_relations_to_graph(relations), + self.insert_relations_to_vdb(relations), ] ) if communities: - tasks.append(self._insert_communities(communities)) + tasks.append(self.insert_communities(communities)) if summaries: - tasks.append(self._insert_summaries(summaries)) + tasks.append(self.insert_summaries(summaries)) if tasks: await asyncio.gather(*tasks) - async def _insert_entities_to_graph(self, entities: List[Entity]) -> None: + async def insert_entities_to_graph(self, entities: List[Entity]) -> None: if not entities: return backend = self.graph_backend if self.graph_backend is None: logger.warning("Graph storage is not initialized.") return - await self._graph_bulk_upsert(backend, entities, backend.upsert_node, "entities") + await self.graph_bulk_upsert(backend, entities, backend.upsert_node, "entities") - async def _insert_relations_to_graph(self, relations: List[Relation]) -> None: + async def insert_relations_to_graph(self, relations: List[Relation]) -> None: if not relations: return backend = self.graph_backend if backend is None: logger.warning("Graph storage is not initialized.") return - await self._graph_bulk_upsert(backend, relations, backend.upsert_edge, "relations") + await self.graph_bulk_upsert(backend, relations, backend.upsert_edge, "relations") - async def _insert_entities_to_vdb(self, entities: List[Entity]) -> None: - """ - Inserts entities from the knowledge graph into the vector database. 
- """ + async def insert_entities_to_vdb(self, entities: List[Entity]) -> None: if not entities: return @@ -178,10 +189,7 @@ async def _insert_entities_to_vdb(self, entities: List[Entity]) -> None: } await self._vdb_upsert(self.entity_vector_db, data_for_vdb, "entities") - async def _insert_relations_to_vdb(self, relations: List[Relation]) -> None: - """ - Inserts relations from the knowledge graph into the vector database. - """ + async def insert_relations_to_vdb(self, relations: List[Relation]) -> None: if not relations: return @@ -197,11 +205,7 @@ async def _insert_relations_to_vdb(self, relations: List[Relation]) -> None: async def insert_chunks(self, chunks: List[Chunk], vectorize: bool = False) -> None: """ - Stores raw chunks in a KV storage (id -> chunk fields). - Optionally vectorizes chunks and stores them in the vector database. - - :param chunks: List of Chunk objects to store. - :param vectorize: If True, also insert chunks into the vector database for similarity search. + Stores raw chunks in a KV storage. """ tasks = [] @@ -219,15 +223,12 @@ async def insert_to_kv(): tasks.append(insert_to_kv()) if vectorize: - tasks.append(self._insert_chunks_to_vdb(chunks)) + tasks.append(self.insert_chunks_to_vdb(chunks)) if tasks: await asyncio.gather(*tasks) - async def _insert_chunks_to_vdb(self, chunks: List[Chunk]) -> None: - """ - Inserts chunks into the vector database for similarity search. - """ + async def insert_chunks_to_vdb(self, chunks: List[Chunk]) -> None: if not chunks: return @@ -240,16 +241,7 @@ async def _insert_chunks_to_vdb(self, chunks: List[Chunk]) -> None: } await self._vdb_upsert(self.chunk_vector_db, data_for_vdb, "chunks") - async def _insert_communities(self, communities: List[Community]) -> None: - """ - Store communities as ids only: - community.id -> { - "level": int, - "cluster_id": int, - "entity_ids": [str, ...], - "relation_ids": [str, ...] 
- } - """ + async def insert_communities(self, communities: List[Community]) -> None: if self.community_kv_storage is None: logger.warning("Community KV storage is not initialized.") return @@ -270,10 +262,7 @@ async def _insert_communities(self, communities: List[Community]) -> None: except Exception as e: logger.error(f"Failed to insert communities into KV storage: {e}") - async def _insert_summaries(self, summaries: List[CommunitySummary]) -> None: - """ - Store summaries as id -> text. - """ + async def insert_summaries(self, summaries: List[CommunitySummary]) -> None: if self.community_summary_kv_storage is None: logger.warning("Community summary KV storage is not initialized.") return @@ -286,7 +275,7 @@ async def _insert_summaries(self, summaries: List[CommunitySummary]) -> None: except Exception as e: logger.error(f"Failed to insert community summaries into KV storage: {e}") - async def _graph_bulk_upsert( + async def graph_bulk_upsert( self, backend: BaseGraphStorage, items: Iterable[Any], diff --git a/ragu/triplet/base_artifact_extractor.py b/ragu/triplet/base_artifact_extractor.py index c5a9b30..d191fac 100644 --- a/ragu/triplet/base_artifact_extractor.py +++ b/ragu/triplet/base_artifact_extractor.py @@ -1,8 +1,9 @@ from abc import ABC, abstractmethod from typing import Tuple, List, Iterable -from ragu.common.prompts import PromptTemplate + from ragu.chunker.types import Chunk +from ragu.common.prompts.prompt_storage import RAGUInstruction from ragu.graph.types import Entity, Relation from ragu.common.base import RaguGenerativeModule @@ -16,7 +17,7 @@ class BaseArtifactExtractor(RaguGenerativeModule, ABC): method to transform raw text chunks into structured graph entities and relations. """ - def __init__(self, prompts: list[str] | dict[str, PromptTemplate]) -> None: + def __init__(self, prompts: list[str] | dict[str, RAGUInstruction]) -> None: """ Initialize a new :class:`BaseArtifactExtractor`. 
diff --git a/ragu/triplet/llm_artifact_extractor.py b/ragu/triplet/llm_artifact_extractor.py index 2fc740f..56b196a 100644 --- a/ragu/triplet/llm_artifact_extractor.py +++ b/ragu/triplet/llm_artifact_extractor.py @@ -1,9 +1,11 @@ from __future__ import annotations -from typing import List, Tuple, Optional, Iterable +from typing import List, Tuple, Optional from ragu.chunker.types import Chunk from ragu.common.global_parameters import Settings +from ragu.common.prompts.prompt_storage import RAGUInstruction +from ragu.common.prompts.messages import ChatMessages, render from ragu.graph.types import Entity, Relation from ragu.llm.base_llm import BaseLLM from ragu.triplet.base_artifact_extractor import BaseArtifactExtractor @@ -14,9 +16,11 @@ class ArtifactsExtractorLLM(BaseArtifactExtractor): """ Extracts entities and relations from text chunks using LLM. - The class implements an LLM-driven pipeline for artifact extraction: - - Extract entities and relations from raw texts. - - Optionally performs LLM-based validation to refine the extracted artifacts. + Pipeline: + 1. Render the `artifact_extraction` instruction in batch mode over chunk texts. + 2. Call the LLM to produce structured artifacts for each chunk. + 3. Optionally render and run `artifact_validation` to refine extracted artifacts. + 4. Convert model outputs into Entity/Relation objects, preserving source chunk ids. """ def __init__( @@ -25,14 +29,14 @@ def __init__( do_validation: bool = False, language: str | None = None, entity_types: Optional[List[str]] = NEREL_ENTITY_TYPES, - relation_types: Optional[List[str]] = None + relation_types: Optional[List[str]] = None, ): """ Initialize a new :class:`ArtifactsExtractorLLM`. :param client: Language model client for generation and validation. :param do_validation: Whether to perform additional LLM-based validation of artifacts. - :param language: Input text language (used for prompt conditioning). + :param language: Output text language. 
async def extract(self, chunks: List[Chunk], *args, **kwargs) -> Tuple[List[Entity], List[Relation]]:
    """
    Extract entities and relations from a collection of chunks.

    Steps:
      1) Batch-render the extraction prompt with ``context=``,
      2) Generate structured artifacts per chunk,
      3) Optionally validate artifacts against the original context,
      4) Convert artifacts into Entity/Relation objects.

    Results that are ``None`` or an ``Exception`` are skipped together with
    their chunk, so every produced entity/relation keeps the ID of the chunk
    it was actually extracted from.

    :param chunks: Chunk objects whose ``content`` is sent to the model.
    :param args: Additional positional arguments (unused here).
    :param kwargs: Additional keyword arguments (unused here).
    :return: ``(entities, relations)`` extracted from all chunks.
    """
    # Materialise once: the chunks are traversed twice (to build the context
    # and to pair results with chunks), which would exhaust a one-shot iterable.
    chunks = list(chunks)

    entities_result: List[Entity] = []
    relations_result: List[Relation] = []
    if not chunks:
        return entities_result, relations_result

    context: List[str] = [chunk.content for chunk in chunks]

    extraction_instruction: RAGUInstruction = self.get_prompt("artifact_extraction")
    extraction_conversations: List[ChatMessages] = render(
        extraction_instruction.messages,
        context=context,
        language=self.language,
        entity_types=self.entity_types,
        relation_types=self.relation_types,
    )

    result_list = await self.client.generate(
        conversations=extraction_conversations,
        response_model=extraction_instruction.pydantic_model,
        progress_bar_desc="Extracting a knowledge graph from chunks",
    )

    if self.do_validation:
        validation_instruction: RAGUInstruction = self.get_prompt("artifact_validation")
        validation_conversations: List[ChatMessages] = render(
            validation_instruction.messages,
            artifacts=result_list,
            context=context,
            entity_types=self.entity_types,
            relation_types=self.relation_types,
            language=self.language,
        )

        result_list = await self.client.generate(
            conversations=validation_conversations,
            response_model=validation_instruction.pydantic_model,
            progress_bar_desc="Validation of extracted artifacts",
        )

    # NOTE(review): assumes `generate` returns one result per conversation, in
    # input order, with None/Exception placeholders for failures - confirm.
    # Pair results with chunks BEFORE dropping failures: filtering the result
    # list first would shift later results onto the wrong chunks.
    for artifacts, chunk in zip(result_list, chunks):
        if artifacts is None or isinstance(artifacts, Exception):
            continue

        dumped = artifacts.model_dump()

        # First occurrence wins, matching a front-to-back name search.
        entities_by_name = {}
        for item in dumped.get("entities") or []:
            if item is None:
                continue
            entity = Entity(
                entity_name=item.get("entity_name", ""),
                entity_type=item.get("entity_type", ""),
                description=item.get("description", ""),
                source_chunk_id=[chunk.id],
                documents_id=[],
                clusters=[],
            )
            entities_result.append(entity)
            entities_by_name.setdefault(entity.entity_name, entity)

        for item in dumped.get("relations") or []:
            if item is None:
                continue

            subject_name = item.get("source_entity", "")
            object_name = item.get("target_entity", "")
            if not (subject_name and object_name):
                continue

            # Relations are kept only when both endpoints were extracted
            # from this same chunk.
            subject_entity = entities_by_name.get(subject_name)
            object_entity = entities_by_name.get(object_name)
            if subject_entity is None or object_entity is None:
                continue

            relations_result.append(
                Relation(
                    subject_name=subject_name,
                    object_name=object_name,
                    subject_id=subject_entity.id,
                    object_id=object_entity.id,
                    description=item.get("description", ""),
                    relation_strength=item.get("relation_strength", 1.0),
                    source_chunk_id=[chunk.id],
                )
            )

    return entities_result, relations_result