diff --git a/app/agent/config.py b/app/agent/config.py index e2dc241..86e10d3 100644 --- a/app/agent/config.py +++ b/app/agent/config.py @@ -5,7 +5,7 @@ class AgentConfig(BaseModel): prompt_source: Literal["file", "langfuse"] = "file" - prompt_dir: str = "agents/prompt" # For file-based prompts + checkpoint_type: Literal["memory", "postgres"] = "memory" custom_params: dict[str, Any] = {} diff --git a/app/agent/factory.py b/app/agent/factory.py index 49c8d33..e81fe2e 100644 --- a/app/agent/factory.py +++ b/app/agent/factory.py @@ -4,13 +4,12 @@ import logging from typing import Any -from dependency_injector.wiring import Provide, inject from langfuse import Langfuse from app.agent.config import AgentConfig from app.agent.interfaces import AgentInstance -from app.agent.langgraph.checkpoint.base import BaseCheckpointer -from app.agent.prompt import create_prompt_provider +from app.agent.langgraph.checkpoint.resolver import CheckpointerResolver +from app.agent.prompt_resolver import PromptProviderResolver from app.bootstrap.config import AppConfig logger = logging.getLogger(__name__) @@ -25,9 +24,17 @@ def __init__(self, agent_class_path: str, config: AgentConfig): class AgentFactory: _registered_agents: dict[str, AgentRegistry] = {} - def __init__(self, global_config: AppConfig, langfuse_client: Langfuse): + def __init__( + self, + global_config: AppConfig, + langfuse_client: Langfuse, + checkpointer_resolver: CheckpointerResolver, + prompt_provider_resolver: PromptProviderResolver, + ): self.global_config = global_config self._langfuse_client = langfuse_client + self._checkpointer_resolver = checkpointer_resolver + self._prompt_provider_resolver = prompt_provider_resolver @classmethod def register_agent( @@ -66,11 +73,9 @@ def _load_agent_class(self, agent_id: str) -> type[Any]: f"Failed to import agent class '{class_path}': {e}" ) from e - @inject async def create_agent( self, agent_id: str, - checkpointer_provider: BaseCheckpointer = Provide["checkpointer_provider"], ) -> AgentInstance: if agent_id not in self._registered_agents: raise ValueError( @@ -80,17 +85,11 @@ async def create_agent( registry_entry = self._registered_agents[agent_id] agent_config = registry_entry.config - await checkpointer_provider.initialize() ## TODO: Refactor. Different agents may require different checkpointers, so we should not call initialize here. And we should just call smth like checkpointer_resolve(agent_config.checkpointer) which will return checkpointer by agent config or name - checkpointer = await checkpointer_provider.get_checkpointer() - - prompt_provider = create_prompt_provider( ## TODO: Refactor. I think prompt providers should be registered in container and here we want to call smth like prompt_resolve(agent_config.prompt_provider) witch will return prompt provider by agent config or name - prompt_source=agent_config.prompt_source, - langfuse_client=self._langfuse_client - if agent_config.prompt_source == "langfuse" - else None, - prompt_dir=agent_config.prompt_dir - if agent_config.prompt_source == "file" - else None, + checkpointer = await self._checkpointer_resolver.get_saver(agent_config.checkpoint_type) + + prompt_provider = self._prompt_provider_resolver.resolve( + agent_config.prompt_source, + agent_name=agent_id, ) agent_class = self._load_agent_class(agent_id) diff --git a/app/agent/langgraph/checkpoint/resolver.py b/app/agent/langgraph/checkpoint/resolver.py new file mode 100644 index 0000000..a31f0d0 --- /dev/null +++ b/app/agent/langgraph/checkpoint/resolver.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from typing import Any, Literal + +from .base import BaseCheckpointer + +CheckpointType = Literal["memory", "postgres"] + + +class CheckpointerResolver: + def __init__( + self, + memory_checkpointer: BaseCheckpointer, + postgres_checkpointer: BaseCheckpointer, + ) -> None: + self._registry: dict[str, BaseCheckpointer] = { + "memory": memory_checkpointer, + "postgres": postgres_checkpointer, + } + + def resolve(self, checkpoint_type: CheckpointType) -> BaseCheckpointer: + checkpoint_type_normalized = checkpoint_type.lower() + if checkpoint_type_normalized not in self._registry: + raise ValueError( + f"Unsupported checkpointer type: {checkpoint_type_normalized}" + ) + return self._registry[checkpoint_type_normalized] + + async def get_saver(self, checkpoint_type: CheckpointType) -> Any: + provider = self.resolve(checkpoint_type) + await provider.initialize() + return await provider.get_checkpointer() + + diff --git a/app/agent/langgraph/graph.py b/app/agent/langgraph/graph.py index f84b36c..86e9b89 100644 --- a/app/agent/langgraph/graph.py +++ b/app/agent/langgraph/graph.py @@ -145,7 +145,7 @@ async def call_model( AIMessage, await chain.ainvoke( {"history": state.messages}, - config=config, # TODO: Pass handler here? + config=config, ), ) diff --git a/app/agent/prompt.py b/app/agent/prompt.py index 123f498..e8b4f57 100644 --- a/app/agent/prompt.py +++ b/app/agent/prompt.py @@ -58,7 +58,7 @@ def get_prompt(self, prompt_name: str, label: str, fallback: Prompt) -> Prompt: class JsonFilePromptProvider(PromptProvider): - """Prompt provider that reads prompts from JSON files in *root_dir*. + """Prompt provider that reads prompts from JSON files in *prompt_root_dir*/*agent_name*. File naming convention: ``.json`` Expected structure inside file:: @@ -72,7 +72,7 @@ class JsonFilePromptProvider(PromptProvider): If *label* is not found, *fallback* is returned. """ - def __init__(self, root_dir: str | Path = "agents/prompt") -> None: + def __init__(self, root_dir: str | Path = "data/prompts") -> None: self._root_dir = Path(root_dir) self._root_dir.mkdir(parents=True, exist_ok=True) @@ -108,24 +108,4 @@ def get_prompt(self, prompt_name: str, label: str, fallback: Prompt) -> Prompt: config=config, metadata={"file_path": str(path), "label": label}, ) - - -def create_prompt_provider( - prompt_source: str, - langfuse_client: object | None = None, - prompt_dir: str | Path | None = None, -) -> PromptProvider: - if prompt_source.lower() == "langfuse": - if ( - langfuse_client is None - or _LF is None - or not isinstance(langfuse_client, _LF) - ): - raise ValueError("Langfuse client is required for langfuse prompt type") - return LangfusePromptProvider(langfuse_client) - elif prompt_source.lower() == "file": - return JsonFilePromptProvider(prompt_dir or "agents/prompt") - else: - raise ValueError( - f"Unknown prompt source: {prompt_source}. Supported types: 'langfuse', 'file'" - ) + diff --git a/app/agent/prompt_resolver.py b/app/agent/prompt_resolver.py new file mode 100644 index 0000000..72dea52 --- /dev/null +++ b/app/agent/prompt_resolver.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from collections.abc import Callable +from pathlib import Path +from typing import Any + +from langfuse import Langfuse + +from app.bootstrap.config import AppConfig + +from .prompt import JsonFilePromptProvider, LangfusePromptProvider, PromptProvider + + +class PromptProviderResolver: + def __init__( + self, + config: AppConfig, + langfuse_client: Langfuse | None = None, + ) -> None: + self._config = config + self._langfuse_client = langfuse_client + self._registry: dict[ + str, Callable[[dict[str, Any]], PromptProvider] + ] = { + "file": self._build_file, + "langfuse": self._build_langfuse, + } + + def register( + self, source: str, factory: Callable[[dict[str, Any]], PromptProvider] + ) -> None: + self._registry[source.lower()] = factory + + def resolve( + self, source: str, *, agent_name: str | None = None + ) -> PromptProvider: + source_normalized = source.lower() + factory = self._registry.get(source_normalized) + if factory is None: + raise ValueError( + f"Unknown prompt source: {source_normalized}." + ) + return factory({"agent_name": agent_name}) + + def _build_file(self, ctx: dict[str, Any]) -> PromptProvider: + agent_name = ctx.get("agent_name") + if not agent_name: + raise ValueError("agent_name is required for 'file' prompt source") + prompt_dir = Path(self._config.prompt_root_dir) / agent_name + return JsonFilePromptProvider(prompt_dir) + + def _build_langfuse(self, _: dict[str, Any]) -> PromptProvider: + if self._langfuse_client is None: + raise ValueError("Langfuse client is required for 'langfuse' prompt source") + return LangfusePromptProvider(self._langfuse_client) diff --git a/app/bootstrap/config.py b/app/bootstrap/config.py index 9f658ed..37086b4 100644 --- a/app/bootstrap/config.py +++ b/app/bootstrap/config.py @@ -19,6 +19,7 @@ class AppConfig(BaseModel): database_url: str | None = None checkpoint_type: str = "memory" # Options: memory, postgres + prompt_root_dir: str = "data/prompts" def get_config() -> AppConfig: @@ -33,4 +34,5 @@ def get_config() -> AppConfig: "postgresql://postgres:postgres@localhost:5432/agent_template", ), checkpoint_type=os.getenv("CHECKPOINT_TYPE", "memory"), + prompt_root_dir=os.getenv("PROMPT_ROOT_DIR", "data/prompts"), ) diff --git a/app/container.py b/app/container.py index 6ac842e..e158cb9 100644 --- a/app/container.py +++ b/app/container.py @@ -6,28 +6,22 @@ from app.bootstrap.config import get_config -from .agent.langgraph.checkpoint.base import BaseCheckpointer from .agent.langgraph.checkpoint.memory import MemoryCheckpointer from .agent.langgraph.checkpoint.postgres import PostgresCheckpointer -from .bootstrap.config import AppConfig +from .agent.langgraph.checkpoint.resolver import CheckpointerResolver from .infrastructure import DatabaseConnection from .infrastructure.database import PostgreSQLConnection, SQLModelManager from .infrastructure.database.session import SessionManager -def create_checkpointer( - config: AppConfig, +def create_checkpointer_resolver( memory_checkpointer: MemoryCheckpointer, postgres_checkpointer: PostgresCheckpointer, -) -> BaseCheckpointer: - checkpoint_type = config.checkpoint_type.lower() - - if checkpoint_type == "memory": - return memory_checkpointer - elif checkpoint_type == "postgres": - return postgres_checkpointer - else: - raise ValueError(f"Unsupported checkpointer type: {checkpoint_type}") +) -> CheckpointerResolver: + return CheckpointerResolver( + memory_checkpointer=memory_checkpointer, + postgres_checkpointer=postgres_checkpointer, + ) def create_session(sqlmodel_manager: SQLModelManager) -> Session: @@ -77,9 +71,8 @@ class Container(containers.DeclarativeContainer): database_connection=db_connection, ) - checkpointer_provider: providers.Singleton[Any] = providers.Singleton( ## TODO: Remove this - create_checkpointer, - config=config, + checkpointer_resolver: providers.Singleton[Any] = providers.Singleton( + create_checkpointer_resolver, memory_checkpointer=memory_checkpointer, postgres_checkpointer=postgres_checkpointer, ) @@ -89,10 +82,18 @@ class Container(containers.DeclarativeContainer): debug=False, ) + prompt_provider_resolver: providers.Singleton[Any] = providers.Singleton( + "app.agent.prompt_resolver.PromptProviderResolver", + config=config, + langfuse_client=langfuse_client, + ) + agent_factory: providers.Singleton[Any] = providers.Singleton( "app.agent.factory.AgentFactory", global_config=config, langfuse_client=langfuse_client, + checkpointer_resolver=checkpointer_resolver, + prompt_provider_resolver=prompt_provider_resolver, ) agent_service: providers.Singleton[Any] = providers.Singleton(