19 changes: 10 additions & 9 deletions ragu/common/base.py
@@ -14,7 +14,8 @@
perform LLM-driven generation or structured response tasks.
"""

from ragu.common.prompts import PromptTemplate, DEFAULT_PROMPT_TEMPLATES
from ragu.common.prompts import DEFAULT_PROMPT_TEMPLATES
from ragu.common.prompts.prompt_storage import RAGUInstruction


class RaguGenerativeModule:
@@ -26,51 +27,51 @@ class RaguGenerativeModule:
:class:`RAGUInstruction` instances directly.
"""

def __init__(self, prompts: list[str] | dict[str, PromptTemplate]):
def __init__(self, prompts: list[str] | dict[str, RAGUInstruction]):
"""
Initialize the generative module with one or more prompts.

:param prompts: Either a list of prompt names (loaded from
:data:`DEFAULT_PROMPT_TEMPLATES`) or a dictionary
mapping prompt names to :class:`PromptTemplate` objects.
mapping prompt names to :class:`RAGUInstruction` objects.
:raises ValueError: If the input format is neither list nor dict.
"""
super().__init__()

if isinstance(prompts, list):
self.prompts: dict[str, PromptTemplate] = {
self.prompts: dict[str, RAGUInstruction] = {
prompt_name: DEFAULT_PROMPT_TEMPLATES.get(prompt_name) for prompt_name in prompts
}
elif isinstance(prompts, dict):
self.prompts = prompts
else:
raise ValueError(
f"Prompts must be a list of prompt names or a dictionary of prompt names and PromptTemplate objects, "
f"Prompts must be a list of prompt names or a dictionary of prompt names and ChatTemplate objects, "
f"got {type(prompts)}"
)

def get_prompts(self) -> dict:
"""
Retrieve all prompt templates registered in the module.

:return: Dictionary mapping prompt names to :class:`PromptTemplate` objects.
:return: Dictionary mapping prompt names to :class:`RAGUInstruction` objects.
"""
return self.prompts

def get_prompt(self, prompt_name: str) -> PromptTemplate:
def get_prompt(self, prompt_name: str) -> RAGUInstruction:
"""
Retrieve a specific prompt template by name.

:param prompt_name: The name of the prompt to retrieve.
:return: The corresponding :class:`PromptTemplate` instance.
:return: The corresponding :class:`RAGUInstruction` instance.
:raises ValueError: If the prompt name is not found.
"""
if prompt_name in self.prompts:
return self.prompts[prompt_name]
else:
raise ValueError(f"Prompt {prompt_name} not found")

def update_prompt(self, prompt_name: str, prompt: PromptTemplate) -> None:
def update_prompt(self, prompt_name: str, prompt: RAGUInstruction) -> None:
"""
Replace or add a prompt template in the module.

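For orientation, a minimal usage sketch of the retyped module; the prompt name "entity_extraction" is hypothetical and must exist in `DEFAULT_PROMPT_TEMPLATES` for the list form to resolve:

```python
from ragu.common.base import RaguGenerativeModule

# List form: each name is looked up in DEFAULT_PROMPT_TEMPLATES.
# "entity_extraction" is a hypothetical prompt name.
module = RaguGenerativeModule(prompts=["entity_extraction"])

instruction = module.get_prompt("entity_extraction")    # -> RAGUInstruction
module.update_prompt("entity_extraction", instruction)  # replace or add by name
```

Note that the list form stores `DEFAULT_PROMPT_TEMPLATES.get(name)`, so an unknown name is silently mapped to `None` rather than raising at construction time.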
18 changes: 6 additions & 12 deletions ragu/common/cache.py
@@ -10,6 +10,7 @@

from ragu.common.global_parameters import DEFAULT_FILENAMES, Settings
from ragu.common.logger import logger
from ragu.common.prompts import ChatMessages
from ragu.utils.ragu_utils import compute_mdhash_id


@@ -19,32 +20,25 @@ class PendingRequest:
Represents a request pending generation (not found in cache).
"""
index: int
prompt: str
messages: ChatMessages
cache_key: str


def make_llm_cache_key(
prompt: str,
system_prompt: Optional[str] = None,
content: str,
model_name: Optional[str] = None,
schema: Optional[Type[BaseModel]] = None,
kwargs: Optional[Dict[str, Any]] = None,
) -> str:
"""
Build a deterministic cache key from LLM request parameters.

:param content: Serialized request content (e.g. the rendered conversation text).
:param model_name: Model name used for generation.
:param schema: Optional Pydantic schema class.
:param kwargs: Additional API parameters.
:return: A unique cache key string.
"""
key_parts = []

if system_prompt:
key_parts.append(f"[system]: {system_prompt}")
key_parts.append(f"[user]: {prompt}")
key_parts = [content]

if model_name:
key_parts.append(f"[model]: {model_name}")
@@ -121,7 +115,7 @@ def _load_cache(self) -> None:
return

try:
with self._cache_path.open("r", encoding="utf-16") as f:
with self._cache_path.open("r", encoding="utf-8") as f:
cache = json.load(f)
if isinstance(cache, dict):
self._mem_cache = cache
@@ -156,7 +150,7 @@ def _write_cache_file(self, path: Path) -> None:
"""
Write cache to file.
"""
with path.open("w", encoding="utf-16") as f:
with path.open("w", encoding="utf-8") as f:
json.dump(self._mem_cache, f, ensure_ascii=False, indent=2)

async def get(
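A minimal sketch of the reworked key builder, assuming the collapsed tail of `make_llm_cache_key` hashes the joined parts (via the imported `compute_mdhash_id`); the model name and schema here are illustrative:

```python
from pydantic import BaseModel

from ragu.common.cache import make_llm_cache_key


class Summary(BaseModel):
    text: str


# The caller now serializes the whole conversation to a single string
# (ChatMessages.to_str() is one way) instead of passing prompt and
# system_prompt separately as before.
content = "[system]: You are terse.\n[user]: Summarize the report."

key = make_llm_cache_key(
    content,
    model_name="gpt-4o-mini",       # illustrative model name
    schema=Summary,
    kwargs={"temperature": 0.0},
)
```

Because the key is built from the serialized conversation, two requests that differ in roles or message order hash differently as long as the serialization reflects that structure.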
26 changes: 16 additions & 10 deletions ragu/common/prompts/__init__.py
@@ -1,10 +1,16 @@
from ragu.common.prompts.prompt_storage import (
PromptTemplate,
# MessageTemplate,
# ChatTemplate,
# ChatMessages,
# SystemMessage,
# AIMessage,
# UserMessage,
DEFAULT_PROMPT_TEMPLATES
)
from ragu.common.prompts.prompt_storage import DEFAULT_PROMPT_TEMPLATES
from ragu.common.prompts.messages import (
SystemMessage,
UserMessage,
AIMessage,
ChatMessages,
)

__all__ = [
"SystemMessage",
"UserMessage",
"AIMessage",
"ChatMessages",

"DEFAULT_PROMPT_TEMPLATES",
]
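A short sketch of the re-exported message API (implemented in `messages.py` below); the message contents are illustrative:

```python
from ragu.common.prompts import ChatMessages, SystemMessage, UserMessage

chat = ChatMessages.from_messages([
    SystemMessage(content="You are a concise assistant."),
    UserMessage(content="List three uses of embeddings."),
])

print(chat.to_str())        # "[system]: ...\n[user]: ..."
payload = chat.to_openai()  # typed dicts ready for the OpenAI SDK
```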
206 changes: 206 additions & 0 deletions ragu/common/prompts/messages.py
@@ -0,0 +1,206 @@
from dataclasses import dataclass
from typing import (
Literal,
Dict,
Any,
TypeVar,
List,
Type,
Sequence,
Mapping,
Union,
)

from jinja2 import Environment, StrictUndefined
from openai.types.chat import (
ChatCompletionMessageParam,
ChatCompletionSystemMessageParam,
ChatCompletionUserMessageParam,
ChatCompletionAssistantMessageParam,
)

Role = Literal["system", "user", "assistant"]


@dataclass(frozen=True, slots=True)
class BaseMessage:
"""
Base chat message abstraction.

Represents a single message in a chat conversation with a fixed role
(system, user, or assistant) and textual content. Provides conversion
to OpenAI SDK message types.
"""
content: str
role: Role
name: str | None = None

    def to_openai(self) -> ChatCompletionMessageParam:
        """
        Convert this message into a typed OpenAI ChatCompletion message.
        """
        # ``name`` is optional in the OpenAI schema; omit it when unset
        # rather than sending an explicit null, which the API rejects.
        extra = {"name": self.name} if self.name is not None else {}

        if self.role == "system":
            return ChatCompletionSystemMessageParam(
                role="system",
                content=self.content,
                **extra,
            )

        if self.role == "user":
            return ChatCompletionUserMessageParam(
                role="user",
                content=self.content,
                **extra,
            )

        if self.role == "assistant":
            return ChatCompletionAssistantMessageParam(
                role="assistant",
                content=self.content,
                **extra,
            )

        raise ValueError(f"Unsupported role: {self.role}")

def to_str(self) -> str:
"""
Return a human-readable string representation of the message.
"""
return f"[{self.role}]: {self.content}"


@dataclass(frozen=True, slots=True)
class SystemMessage(BaseMessage):
"""
System-level instruction message.
"""
role: Role = "system"


@dataclass(frozen=True, slots=True)
class UserMessage(BaseMessage):
"""
User input message.
"""
role: Role = "user"


@dataclass(frozen=True, slots=True)
class AIMessage(BaseMessage):
"""
Assistant (LLM) response message.
"""
role: Role = "assistant"


T = TypeVar("T", bound="ChatMessages")


@dataclass(frozen=True, slots=True)
class ChatMessages:
"""
Container for a list of chat messages.

Represents a single user-assistant conversation.
"""
messages: List[BaseMessage]

@classmethod
def from_messages(cls: Type[T], messages: Sequence[BaseMessage]) -> T:
"""
Construct a ChatMessages instance from a sequence of messages.
"""
return cls(messages=list(messages))

def to_openai(self) -> List[ChatCompletionMessageParam]:
"""
Convert all messages to OpenAI ChatCompletion message parameters.
"""
return [m.to_openai() for m in self.messages]

def __iter__(self):
return iter(self.messages)

def __len__(self):
return len(self.messages)

def to_str(self) -> str:
"""
Return a readable multi-line string representation of the conversation.
"""
return "\n".join([m.to_str() for m in self.messages])


def render(template_conversation: Union[BaseMessage, ChatMessages], **params: Any) -> List[ChatMessages]:
"""
Render Jinja2 templates inside message contents in batch mode.

Parameters
----------
template_conversation:
A single message or a ChatMessages instance used as a template.
params:
A mix of scalar parameters (shared across all prompts) and
batch parameters (lists/tuples). All batch parameters must
have the same length N.

Returns
-------
List[ChatMessages]
A list of rendered ChatMessages objects of length N.
"""

def _is_batch_value(v: Any) -> bool:
return isinstance(v, (list, tuple))

def _infer_batch_size(params: Mapping[str, Any]) -> int:
sizes = {len(v) for v in params.values() if _is_batch_value(v)}
if not sizes:
return 1
if len(sizes) != 1:
raise ValueError(f"Batch parameters have different sizes: {sorted(sizes)}.")
return next(iter(sizes))

def _build_row_context(params: Mapping[str, Any], i: int) -> Dict[str, Any]:
row: Dict[str, Any] = {}
for k, v in params.items():
row[k] = v[i] if _is_batch_value(v) else v
return row

env = Environment(
undefined=StrictUndefined,
autoescape=False,
trim_blocks=True,
lstrip_blocks=True,
)

if isinstance(template_conversation, BaseMessage):
template_cm = ChatMessages.from_messages([template_conversation])
else:
template_cm = template_conversation

n = _infer_batch_size(params)

for k, v in params.items():
if _is_batch_value(v) and len(v) != n:
raise ValueError(
f"Batch parameter '{k}' has length {len(v)}, expected {n}."
)

out: List[ChatMessages] = []
for i in range(n):
ctx = _build_row_context(params, i)

rendered_msgs: List[BaseMessage] = []
for m in template_cm.messages:
tmpl = env.from_string(m.content)
new_content = tmpl.render(**ctx)

msg_type = type(m)
rendered_msgs.append(
msg_type(role=m.role, content=new_content, name=m.name)
)

out.append(ChatMessages.from_messages(rendered_msgs))

return out
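A usage sketch of batch rendering with `render`; the texts are illustrative. Scalar parameters are shared across every row, while list or tuple parameters set the batch size:

```python
from ragu.common.prompts.messages import (
    ChatMessages,
    SystemMessage,
    UserMessage,
    render,
)

template = ChatMessages.from_messages([
    SystemMessage(content="You answer in {{ language }}."),
    UserMessage(content="Summarize: {{ text }}"),
])

batch = render(
    template,
    language="English",                # scalar: shared by all rows
    text=["first doc", "second doc"],  # batch: sets N = 2
)

assert len(batch) == 2
print(batch[0].to_str())
```

Because the Jinja2 environment uses `StrictUndefined`, a placeholder with no matching parameter fails loudly at render time instead of silently producing an empty string.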