From f3c06e77a9238345f1b09bb07a50bd6502d353bf Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Wed, 19 Nov 2025 10:36:24 -0300 Subject: [PATCH] feat(chat): select model of run with request params --- src/askui/chat/api/runs/models.py | 3 ++- src/askui/chat/api/runs/runner/runner.py | 11 ++++++++--- src/askui/chat/api/runs/service.py | 1 + 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/askui/chat/api/runs/models.py b/src/askui/chat/api/runs/models.py index 5699a4ef..5d311b1d 100644 --- a/src/askui/chat/api/runs/models.py +++ b/src/askui/chat/api/runs/models.py @@ -40,6 +40,7 @@ class RunCreate(BaseModel): stream: bool = False assistant_id: AssistantId + model: str | None = None class RunStart(BaseModel): @@ -146,7 +147,7 @@ def create( thread_id=thread_id, created_at=now(), expires_at=now() + timedelta(minutes=10), - **params.model_dump(exclude={"stream"}), + **params.model_dump(exclude={"model", "stream"}), ) @computed_field # type: ignore[prop-decorator] diff --git a/src/askui/chat/api/runs/runner/runner.py b/src/askui/chat/api/runs/runner/runner.py index 8bacaf3d..066c273a 100644 --- a/src/askui/chat/api/runs/runner/runner.py +++ b/src/askui/chat/api/runs/runner/runner.py @@ -2,7 +2,6 @@ import logging from abc import ABC, abstractmethod from datetime import datetime, timezone -from typing import Any from anthropic.types.beta import BetaCacheControlEphemeralParam, BetaTextBlockParam from anyio.abc import ObjectStream @@ -34,7 +33,6 @@ ) from askui.chat.api.settings import Settings from askui.custom_agent import CustomAgent -from askui.models.models import ModelName from askui.models.shared.agent_message_param import MessageParam from askui.models.shared.agent_on_message_cb import OnMessageCbParam from askui.models.shared.settings import ActSettings, MessageSettings @@ -67,6 +65,7 @@ def __init__( mcp_client_manager_manager: McpClientManagerManager, run_service: RunnerRunService, settings: Settings, + model: str | None = None, ) -> None: self._run_id = run_id self._workspace_id = workspace_id @@ -76,6 +75,7 @@ def __init__( self._mcp_client_manager_manager = mcp_client_manager_manager self._run_service = run_service self._settings = settings + self._model: str | None = model def _retrieve_run(self) -> Run: return self._run_service.retrieve( @@ -164,7 +164,7 @@ def _run_agent_inner() -> None: ) betas = tools.retrieve_tool_beta_flags() system = self._build_system() - model = self._settings.model + model = self._get_model() messages = syncify(self._chat_history_manager.retrieve_message_params)( workspace_id=self._workspace_id, thread_id=self._thread_id, @@ -269,3 +269,8 @@ async def run( def _should_abort(self, run: Run) -> bool: return run.status in ("cancelled", "cancelling", "expired") + + def _get_model(self) -> str: + if self._model is not None: + return self._model + return self._settings.model diff --git a/src/askui/chat/api/runs/service.py b/src/askui/chat/api/runs/service.py index 2535f850..7bbffdbc 100644 --- a/src/askui/chat/api/runs/service.py +++ b/src/askui/chat/api/runs/service.py @@ -92,6 +92,7 @@ async def create( mcp_client_manager_manager=self._mcp_client_manager_manager, run_service=self, settings=self._settings, + model=params.model, ) async def event_generator() -> AsyncGenerator[Event, None]: