Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion app/agent/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

class AgentConfig(BaseModel):
prompt_source: Literal["file", "langfuse"] = "file"
prompt_dir: str = "agents/prompt" # For file-based prompts
checkpoint_type: Literal["memory", "postgres"] = "memory"

custom_params: dict[str, Any] = {}

Expand Down
33 changes: 16 additions & 17 deletions app/agent/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
import logging
from typing import Any

from dependency_injector.wiring import Provide, inject
from langfuse import Langfuse

from app.agent.config import AgentConfig
from app.agent.interfaces import AgentInstance
from app.agent.langgraph.checkpoint.base import BaseCheckpointer
from app.agent.prompt import create_prompt_provider
from app.agent.langgraph.checkpoint.resolver import CheckpointerResolver
from app.agent.prompt_resolver import PromptProviderResolver
from app.bootstrap.config import AppConfig

logger = logging.getLogger(__name__)
Expand All @@ -25,9 +24,17 @@ def __init__(self, agent_class_path: str, config: AgentConfig):
class AgentFactory:
_registered_agents: dict[str, AgentRegistry] = {}

def __init__(self, global_config: AppConfig, langfuse_client: Langfuse):
def __init__(
self,
global_config: AppConfig,
langfuse_client: Langfuse,
checkpointer_resolver: CheckpointerResolver,
prompt_provider_resolver: PromptProviderResolver,
):
self.global_config = global_config
self._langfuse_client = langfuse_client
self._checkpointer_resolver = checkpointer_resolver
self._prompt_provider_resolver = prompt_provider_resolver

@classmethod
def register_agent(
Expand Down Expand Up @@ -66,11 +73,9 @@ def _load_agent_class(self, agent_id: str) -> type[Any]:
f"Failed to import agent class '{class_path}': {e}"
) from e

@inject
async def create_agent(
self,
agent_id: str,
checkpointer_provider: BaseCheckpointer = Provide["checkpointer_provider"],
) -> AgentInstance:
if agent_id not in self._registered_agents:
raise ValueError(
Expand All @@ -80,17 +85,11 @@ async def create_agent(
registry_entry = self._registered_agents[agent_id]
agent_config = registry_entry.config

await checkpointer_provider.initialize() ## TODO: Refactor. Different agents may require different checkpointers, so we should not call initialize here. And we should just call smth like checkpointer_resolve(agent_config.checkpointer) which will return checkpointer by agent config or name
checkpointer = await checkpointer_provider.get_checkpointer()

prompt_provider = create_prompt_provider( ## TODO: Refactor. I think prompt providers should be registered in container and here we want to call smth like prompt_resolve(agent_config.prompt_provider) witch will return prompt provider by agent config or name
prompt_source=agent_config.prompt_source,
langfuse_client=self._langfuse_client
if agent_config.prompt_source == "langfuse"
else None,
prompt_dir=agent_config.prompt_dir
if agent_config.prompt_source == "file"
else None,
checkpointer = await self._checkpointer_resolver.get_saver(agent_config.checkpoint_type)

prompt_provider = self._prompt_provider_resolver.resolve(
agent_config.prompt_source,
agent_name=agent_id,
)

agent_class = self._load_agent_class(agent_id)
Expand Down
34 changes: 34 additions & 0 deletions app/agent/langgraph/checkpoint/resolver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from __future__ import annotations

from typing import Any, Literal

from .base import BaseCheckpointer

CheckpointType = Literal["memory", "postgres"]


class CheckpointerResolver:
def __init__(
self,
memory_checkpointer: BaseCheckpointer,
postgres_checkpointer: BaseCheckpointer,
) -> None:
self._registry: dict[str, BaseCheckpointer] = {
"memory": memory_checkpointer,
"postgres": postgres_checkpointer,
}

def resolve(self, checkpoint_type: CheckpointType) -> BaseCheckpointer:
checkpoint_type_normalized = checkpoint_type.lower()
if checkpoint_type_normalized not in self._registry:
raise ValueError(
f"Unsupported checkpointer type: {checkpoint_type_normalized}"
)
return self._registry[checkpoint_type_normalized]

async def get_saver(self, checkpoint_type: CheckpointType) -> Any:
provider = self.resolve(checkpoint_type)
await provider.initialize()
return await provider.get_checkpointer()


2 changes: 1 addition & 1 deletion app/agent/langgraph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ async def call_model(
AIMessage,
await chain.ainvoke(
{"history": state.messages},
config=config, # TODO: Pass handler here?
config=config,
),
)

Expand Down
26 changes: 3 additions & 23 deletions app/agent/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def get_prompt(self, prompt_name: str, label: str, fallback: Prompt) -> Prompt:


class JsonFilePromptProvider(PromptProvider):
"""Prompt provider that reads prompts from JSON files in *root_dir*.
"""Prompt provider that reads prompts from JSON files in *prompt_root_dir*/*agent_name*.

File naming convention: ``<prompt_name>.json``
Expected structure inside file::
Expand All @@ -72,7 +72,7 @@ class JsonFilePromptProvider(PromptProvider):
If *label* is not found, *fallback* is returned.
"""

def __init__(self, root_dir: str | Path = "agents/prompt") -> None:
def __init__(self, root_dir: str | Path = "data/prompts") -> None:
self._root_dir = Path(root_dir)
self._root_dir.mkdir(parents=True, exist_ok=True)

Expand Down Expand Up @@ -108,24 +108,4 @@ def get_prompt(self, prompt_name: str, label: str, fallback: Prompt) -> Prompt:
config=config,
metadata={"file_path": str(path), "label": label},
)


def create_prompt_provider(
prompt_source: str,
langfuse_client: object | None = None,
prompt_dir: str | Path | None = None,
) -> PromptProvider:
if prompt_source.lower() == "langfuse":
if (
langfuse_client is None
or _LF is None
or not isinstance(langfuse_client, _LF)
):
raise ValueError("Langfuse client is required for langfuse prompt type")
return LangfusePromptProvider(langfuse_client)
elif prompt_source.lower() == "file":
return JsonFilePromptProvider(prompt_dir or "agents/prompt")
else:
raise ValueError(
f"Unknown prompt source: {prompt_source}. Supported types: 'langfuse', 'file'"
)

55 changes: 55 additions & 0 deletions app/agent/prompt_resolver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from __future__ import annotations

from collections.abc import Callable
from pathlib import Path
from typing import Any

from langfuse import Langfuse

from app.bootstrap.config import AppConfig

from .prompt import JsonFilePromptProvider, LangfusePromptProvider, PromptProvider


class PromptProviderResolver:
def __init__(
self,
config: AppConfig,
langfuse_client: Langfuse | None = None,
) -> None:
self._config = config
self._langfuse_client = langfuse_client
self._registry: dict[
str, Callable[[dict[str, Any]], PromptProvider]
] = {
"file": self._build_file,
"langfuse": self._build_langfuse,
}

def register(
self, source: str, factory: Callable[[dict[str, Any]], PromptProvider]
) -> None:
self._registry[source.lower()] = factory

def resolve(
self, source: str, *, agent_name: str | None = None
) -> PromptProvider:
source_normalized = source.lower()
factory = self._registry.get(source_normalized)
if factory is None:
raise ValueError(
f"Unknown prompt source: {source_normalized}."
)
return factory({"agent_name": agent_name})

def _build_file(self, ctx: dict[str, Any]) -> PromptProvider:
agent_name = ctx.get("agent_name")
if not agent_name:
raise ValueError("agent_name is required for 'file' prompt source")
prompt_dir = Path(self._config.prompt_root_dir) / agent_name
return JsonFilePromptProvider(prompt_dir)

def _build_langfuse(self, _: dict[str, Any]) -> PromptProvider:
if self._langfuse_client is None:
raise ValueError("Langfuse client is required for 'langfuse' prompt source")
return LangfusePromptProvider(self._langfuse_client)
2 changes: 2 additions & 0 deletions app/bootstrap/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class AppConfig(BaseModel):

database_url: str | None = None
checkpoint_type: str = "memory" # Options: memory, postgres
prompt_root_dir: str = "data/prompts"


def get_config() -> AppConfig:
Expand All @@ -33,4 +34,5 @@ def get_config() -> AppConfig:
"postgresql://postgres:postgres@localhost:5432/agent_template",
),
checkpoint_type=os.getenv("CHECKPOINT_TYPE", "memory"),
prompt_root_dir=os.getenv("PROMPT_ROOT_DIR", "data/prompts"),
)
33 changes: 17 additions & 16 deletions app/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,22 @@

from app.bootstrap.config import get_config

from .agent.langgraph.checkpoint.base import BaseCheckpointer
from .agent.langgraph.checkpoint.memory import MemoryCheckpointer
from .agent.langgraph.checkpoint.postgres import PostgresCheckpointer
from .bootstrap.config import AppConfig
from .agent.langgraph.checkpoint.resolver import CheckpointerResolver
from .infrastructure import DatabaseConnection
from .infrastructure.database import PostgreSQLConnection, SQLModelManager
from .infrastructure.database.session import SessionManager


def create_checkpointer(
config: AppConfig,
def create_checkpointer_resolver(
memory_checkpointer: MemoryCheckpointer,
postgres_checkpointer: PostgresCheckpointer,
) -> BaseCheckpointer:
checkpoint_type = config.checkpoint_type.lower()

if checkpoint_type == "memory":
return memory_checkpointer
elif checkpoint_type == "postgres":
return postgres_checkpointer
else:
raise ValueError(f"Unsupported checkpointer type: {checkpoint_type}")
) -> CheckpointerResolver:
return CheckpointerResolver(
memory_checkpointer=memory_checkpointer,
postgres_checkpointer=postgres_checkpointer,
)


def create_session(sqlmodel_manager: SQLModelManager) -> Session:
Expand Down Expand Up @@ -77,9 +71,8 @@ class Container(containers.DeclarativeContainer):
database_connection=db_connection,
)

checkpointer_provider: providers.Singleton[Any] = providers.Singleton( ## TODO: Remove this
create_checkpointer,
config=config,
checkpointer_resolver: providers.Singleton[Any] = providers.Singleton(
create_checkpointer_resolver,
memory_checkpointer=memory_checkpointer,
postgres_checkpointer=postgres_checkpointer,
)
Expand All @@ -89,10 +82,18 @@ class Container(containers.DeclarativeContainer):
debug=False,
)

prompt_provider_resolver: providers.Singleton[Any] = providers.Singleton(
"app.agent.prompt_resolver.PromptProviderResolver",
config=config,
langfuse_client=langfuse_client,
)

agent_factory: providers.Singleton[Any] = providers.Singleton(
"app.agent.factory.AgentFactory",
global_config=config,
langfuse_client=langfuse_client,
checkpointer_resolver=checkpointer_resolver,
prompt_provider_resolver=prompt_provider_resolver,
)

agent_service: providers.Singleton[Any] = providers.Singleton(
Expand Down