diff --git a/.gitignore b/.gitignore index 8be03fc..21bb0b4 100644 --- a/.gitignore +++ b/.gitignore @@ -373,3 +373,4 @@ components.d.ts **/.vscode/* .wxt var +.vscfavoriterc diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index b75e472..f01b87b 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -33,6 +33,7 @@ Before submitting changes, please run checks and tests as described in [Automate - [Docker](https://www.docker.com/) - [devbox](https://www.jetpack.io/devbox) - [git-lfs](https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage) +- If you use a Visual Studio Code-based IDE (such as Cursor or Antigravity), consider installing the [EditorConfig extension](https://marketplace.visualstudio.com/items?itemName=EditorConfig.EditorConfig) so that the rules defined in `.editorconfig` are applied and the code format stays consistent. ### Setup diff --git a/devbox.json b/devbox.json index 92a1cb3..57a5e9d 100644 --- a/devbox.json +++ b/devbox.json @@ -2,6 +2,8 @@ "$schema": "https://raw.githubusercontent.com/jetify-com/devbox/0.14.0/.schema/devbox.schema.json", "packages": { "python": "3.13", + "git": "latest", + "uv": "latest", "poetry": { "version": "latest", // this fails because is trying to find a pyproject.toml at the root of the project @@ -56,6 +58,9 @@ "cd src/backend", "poetry run python -m tero.secrets_cleanup" ], + "vllm": [ + "./scripts/vllm.sh" + ], "playwright": [ "docker compose up playwright" ], diff --git a/devbox.lock b/devbox.lock index 3fc848d..4ae8f00 100644 --- a/devbox.lock +++ b/devbox.lock @@ -1,6 +1,78 @@ { "lockfile_version": "1", "packages": { + "git@latest": { + "last_modified": "2026-01-23T17:20:52Z", + "resolved": "github:NixOS/nixpkgs/a1bab9e494f5f4939442a57a58d0449a109593fe#git", + "source": "devbox-search", + "version": "2.52.0", + "systems": { + "aarch64-darwin": { + "outputs": [ + { + "name": "out", + "path": "/nix/store/vnhprisb777byfjpp5mdd0mxwkpvhbc0-git-2.52.0", + "default": true + }, + { + "name": "doc", + "path": "/nix/store/p0dx3053175fpr3kjf0fqgs9x6gm3dri-git-2.52.0-doc" + } + ], + "store_path": "/nix/store/vnhprisb777byfjpp5mdd0mxwkpvhbc0-git-2.52.0" + }, + "aarch64-linux": { + "outputs": [ + { + "name": "out", + "path": "/nix/store/sd4bjblxfljbm13mpl50x8g336gw2dri-git-2.52.0", + "default": true + }, + { + "name": "debug", + "path": "/nix/store/cx033x4nmr12d9dcn8hxjfrajc3983f0-git-2.52.0-debug" + }, + { + "name": "doc", + "path": "/nix/store/ylsmgdrnp78p8hn7n9lxpada78dka41v-git-2.52.0-doc" + } + ], + "store_path": "/nix/store/sd4bjblxfljbm13mpl50x8g336gw2dri-git-2.52.0" + }, + "x86_64-darwin": { + "outputs": [ + { + "name": "out", + "path": "/nix/store/sr42bfqa0pc2ysba678mrm49g78jdynp-git-2.52.0", + "default": true + }, + { + "name": "doc", + "path": "/nix/store/fky0ci2bhwgyh9klg5682pmqswqg7wk4-git-2.52.0-doc" + } + ], + "store_path": "/nix/store/sr42bfqa0pc2ysba678mrm49g78jdynp-git-2.52.0" + }, + "x86_64-linux": { + "outputs": [ + { + "name": "out", + "path": "/nix/store/ipwndaag56mm8g8gn8j98z0jvn8x4mk1-git-2.52.0", + "default": true + }, + { + "name": "debug", + "path": "/nix/store/zd8di23fmzjwrhb5ij1bjnlfqkx9j7d6-git-2.52.0-debug" + }, + { + "name": "doc", + "path": "/nix/store/d8kc4kqzcrbcj92msxrpdsdjglh2q5gp-git-2.52.0-doc" + } + ], + "store_path": "/nix/store/ipwndaag56mm8g8gn8j98z0jvn8x4mk1-git-2.52.0" + } + } + }, "github:NixOS/nixpkgs/nixpkgs-unstable": { "last_modified": "2025-10-23T16:27:14Z", "resolved":
"github:NixOS/nixpkgs/d5faa84122bc0a1fd5d378492efce4e289f8eac1?lastModified=1761236834&narHash=sha256-%2Bpthv6hrL5VLW2UqPdISGuLiUZ6SnAXdd2DdUE%2BfV2Q%3D" @@ -238,6 +310,54 @@ "store_path": "/nix/store/2mab9iiwhcqwk75qwvp3zv0bvbiaq6cs-python3-3.13.3" } } + }, + "uv@latest": { + "last_modified": "2026-01-30T02:32:49Z", + "resolved": "github:NixOS/nixpkgs/6308c3b21396534d8aaeac46179c14c439a89b8a#uv", + "source": "devbox-search", + "version": "0.9.27", + "systems": { + "aarch64-darwin": { + "outputs": [ + { + "name": "out", + "path": "/nix/store/bj9jidx7h317hmnz220qpka6850sjvi1-uv-0.9.27", + "default": true + } + ], + "store_path": "/nix/store/bj9jidx7h317hmnz220qpka6850sjvi1-uv-0.9.27" + }, + "aarch64-linux": { + "outputs": [ + { + "name": "out", + "path": "/nix/store/2703hnsgll85c1rpzzqycnmp0jvy579z-uv-0.9.27", + "default": true + } + ], + "store_path": "/nix/store/2703hnsgll85c1rpzzqycnmp0jvy579z-uv-0.9.27" + }, + "x86_64-darwin": { + "outputs": [ + { + "name": "out", + "path": "/nix/store/il24573a65xa0bfy2d6kli6nvnx80qhv-uv-0.9.27", + "default": true + } + ], + "store_path": "/nix/store/il24573a65xa0bfy2d6kli6nvnx80qhv-uv-0.9.27" + }, + "x86_64-linux": { + "outputs": [ + { + "name": "out", + "path": "/nix/store/k9b6fbx5pzj04fws4na77i4pb5l64li8-uv-0.9.27", + "default": true + } + ], + "store_path": "/nix/store/k9b6fbx5pzj04fws4na77i4pb5l64li8-uv-0.9.27" + } + } } } } diff --git a/scripts/vllm.sh b/scripts/vllm.sh new file mode 100644 index 0000000..8e89433 --- /dev/null +++ b/scripts/vllm.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env bash +set -e + +# using v0.14.1 since when using v0.15.1 we get "Failed to import from vllm._C ... _C.abi.so" and then "AttributeError: '_OpNamespace' '_C_utils' object has no attribute 'init_cpu_threads_env'" when starting vllm serve. +[ -d var/vllm/.git ] || git clone -b v0.14.1 https://github.com/vllm-project/vllm.git var/vllm +cd var/vllm + +# C++: 0 < M <= 8 is parsed as (0 < M) <= 8; fix to (0 < M) && (M <= 8) so Clang on macOS accepts it +sed -i.bak 's/static_assert(0 < M <= 8);/static_assert(0 < M \&\& M <= 8);/g' csrc/cpu/cpu_attn_vec.hpp 2>/dev/null || true +sed -i.bak 's/static_assert(0 < M <= 16);/static_assert(0 < M \&\& M <= 16);/g' csrc/cpu/cpu_attn_vec16.hpp 2>/dev/null || true + +if [ ! -x .venv/bin/vllm ]; then + ([ -d .venv ] || uv venv --seed .venv) + . .venv/bin/activate + uv pip install -r requirements/cpu.txt --index-strategy unsafe-best-match + uv pip install -e ".[audio]" +else + . .venv/bin/activate +fi + +cleanup() { + [ -n "$CHAT_PID" ] && kill "$CHAT_PID" 2>/dev/null || true + [ -n "$EMBED_PID" ] && kill "$EMBED_PID" 2>/dev/null || true + [ -n "$TRANSCRIBE_PID" ] && kill "$TRANSCRIBE_PID" 2>/dev/null || true + wait 2>/dev/null || true +} +trap cleanup EXIT INT TERM + +vllm serve Qwen/Qwen2.5-1.5B-Instruct --dtype auto --api-key test-token --port 8001 --enable-auto-tool-choice --tool-call-parser hermes & +CHAT_PID=$! +sleep 2 +vllm serve nomic-ai/nomic-embed-text-v1 --api-key test-token --trust-remote-code --port 8002 & +EMBED_PID=$! 
+ +wait + diff --git a/src/backend/alembic/versions/20260205-613cc99427e2-qwen_2_5_1_5b_model.py b/src/backend/alembic/versions/20260205-613cc99427e2-qwen_2_5_1_5b_model.py new file mode 100644 index 0000000..bc9b8ab --- /dev/null +++ b/src/backend/alembic/versions/20260205-613cc99427e2-qwen_2_5_1_5b_model.py @@ -0,0 +1,41 @@ +"""qwen_2.5_1.5b_model + +Revision ID: 613cc99427e2 +Revises: c3ae0aefa4d1 +Create Date: 2026-02-05 16:23:32.147821 + +""" +from typing import Sequence, Union +from alembic import op +from alembic_postgresql_enum import TableReference + +# revision identifiers, used by Alembic. +revision: str = '613cc99427e2' +down_revision: Union[str, None] = 'c3ae0aefa4d1' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.sync_enum_values( # type: ignore + enum_schema='public', + enum_name='llmmodelvendor', + new_values=['ANTHROPIC', 'GOOGLE', 'OPENAI', 'QWEN'], + affected_columns=[TableReference(table_schema='public', table_name='llm_model', column_name='model_vendor')], + enum_values_to_rename=[], + ) + op.execute(""" + INSERT INTO llm_model (id, name, description, token_limit, output_token_limit, prompt_1k_token_usd, completion_1k_token_usd, model_type, model_vendor) VALUES + ('qwen-2.5-1.5b', 'Qwen 2.5 1.5B', 'This is a free, open‑source model for simple tasks and basic coding; less capable than GPT‑4o Mini and GPT‑4.1 Nano.', 128000, 8000, 0.0, 0.0, 'CHAT', 'QWEN') + """) + + +def downgrade() -> None: + op.execute("DELETE FROM llm_model WHERE id = 'qwen-2.5-1.5b'") + op.sync_enum_values( # type: ignore + enum_schema='public', + enum_name='llmmodelvendor', + new_values=['ANTHROPIC', 'GOOGLE', 'OPENAI'], + affected_columns=[TableReference(table_schema='public', table_name='llm_model', column_name='model_vendor')], + enum_values_to_rename=[], + ) diff --git a/src/backend/alembic/versions/20260210-fd24763a078c-recursion_limit.py b/src/backend/alembic/versions/20260210-fd24763a078c-recursion_limit.py new file mode 100644 index 0000000..be8b55c --- /dev/null +++ b/src/backend/alembic/versions/20260210-fd24763a078c-recursion_limit.py @@ -0,0 +1,43 @@ +"""recursion-limit + +Revision ID: fd24763a078c +Revises: 613cc99427e2 +Create Date: 2026-02-10 12:30:05.846201 + +""" + +import sqlalchemy as sa +import sqlmodel +from typing import Sequence, Union +from alembic import op + + +# revision identifiers, used by Alembic. 
+revision: str = "fd24763a078c" +down_revision: Union[str, None] = "613cc99427e2" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column( + "agent", + sa.Column( + "recursion_limit", + sa.Integer(), + nullable=False, + server_default=sa.literal_column("40"), + ), + ) + op.create_check_constraint( + "agent_recursion_limit_range", + "agent", + "recursion_limit >= 20 AND recursion_limit <= 100", + ) + + +def downgrade() -> None: + op.execute( + "ALTER TABLE agent DROP CONSTRAINT IF EXISTS agent_recursion_limit_range" + ) + op.drop_column("agent", "recursion_limit") diff --git a/src/backend/alembic/versions/20260211-8e9d8ca4dd0c-thread_message_origin_system.py b/src/backend/alembic/versions/20260211-8e9d8ca4dd0c-thread_message_origin_system.py new file mode 100644 index 0000000..6c2557e --- /dev/null +++ b/src/backend/alembic/versions/20260211-8e9d8ca4dd0c-thread_message_origin_system.py @@ -0,0 +1,38 @@ +"""thread_message_origin_system + +Revision ID: 8e9d8ca4dd0c +Revises: fd24763a078c +Create Date: 2026-02-11 19:05:25.867844 + +""" +import sqlalchemy as sa +import sqlmodel +from typing import Sequence, Union +from alembic import op +from alembic_postgresql_enum import TableReference + +# revision identifiers, used by Alembic. +revision: str = '8e9d8ca4dd0c' +down_revision: Union[str, None] = 'fd24763a078c' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.sync_enum_values( # type: ignore + enum_schema="public", + enum_name="threadmessageorigin", + new_values=["USER", "AGENT", "SYSTEM"], + affected_columns=[TableReference(table_schema="public", table_name="thread_message", column_name="origin")], + enum_values_to_rename=[], + ) + + +def downgrade() -> None: + op.sync_enum_values( # type: ignore + enum_schema="public", + enum_name="threadmessageorigin", + new_values=["USER", "AGENT"], + affected_columns=[TableReference(table_schema="public", table_name="thread_message", column_name="origin")], + enum_values_to_rename=[], + ) diff --git a/src/backend/tero/agents/api.py b/src/backend/tero/agents/api.py index 9aa2c30..29cc3cf 100644 --- a/src/backend/tero/agents/api.py +++ b/src/backend/tero/agents/api.py @@ -14,9 +14,8 @@ from ..core.env import env from ..core.repos import get_db from ..files.api import build_file_download_response +from ..files.core import QuotaExceededError, add_encoding_to_content_type from ..files.domain import File, FileStatus, FileUpdate, FileMetadata, FileMetadataWithContent -from ..files.file_quota import QuotaExceededError -from ..files.parser import add_encoding_to_content_type from ..files.repos import FileRepository from ..teams.domain import GLOBAL_TEAM_ID, Role from ..tools.core import AgentTool diff --git a/src/backend/tero/agents/domain.py b/src/backend/tero/agents/domain.py index 5b1a1b7..bcfb913 100644 --- a/src/backend/tero/agents/domain.py +++ b/src/backend/tero/agents/domain.py @@ -22,10 +22,10 @@ class BaseAgent(CamelCaseModel, abc.ABC): id: int = Field(primary_key=True, default=None) - name: Optional[str] = Field(max_length=NAME_MAX_LENGTH, default=None) + name: Optional[str] = Field(max_length=NAME_MAX_LENGTH, default=None) description: Optional[str] = Field(max_length=100, default=None) last_update: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) - + def set_default_name(self): self.name = f"Agent #{self.id}" @@ -38,6 +38,7 @@ class 
AgentUpdate(CamelCaseModel): system_prompt: Optional[str] = None temperature: Optional[LlmTemperature] = None reasoning_effort: Optional[ReasoningEffort] = None + recursion_limit: Optional[int] = None publish_prompts: Optional[bool] = None team_id: Optional[int] = None @@ -55,10 +56,11 @@ class Agent(BaseAgent, table=True): system_prompt: str = Field(sa_column=Column(Text)) temperature: LlmTemperature = LlmTemperature.NEUTRAL reasoning_effort: ReasoningEffort = ReasoningEffort.LOW + recursion_limit: int = Field(default=20, ge=20, le=100) team_id: Optional[int] = Field(default=None, foreign_key="team.id") team: Optional[Team] = Relationship() evaluator_id: Optional[int] = Field(default=None, foreign_key="evaluator.id") - + def update_with(self, update: AgentUpdate): update_dict = update.model_dump(exclude_none=True) update_dict["icon"] = base64.b64decode(update_dict["icon"]) if update_dict.get("icon") else None @@ -79,7 +81,7 @@ def is_visible_by(self, user: User) -> bool: def is_editable_by(self, user: User) -> bool: return self.user_id == user.id or any( - tr.role in [Role.TEAM_OWNER, Role.TEAM_EDITOR] and cast(Team, tr.team).id == self.team_id + tr.role in [Role.TEAM_OWNER, Role.TEAM_EDITOR] and cast(Team, tr.team).id == self.team_id for tr in user.team_roles ) @@ -130,9 +132,10 @@ class PublicAgent(BaseAgent): system_prompt: str temperature: LlmTemperature reasoning_effort: ReasoningEffort + recursion_limit: int user: Optional[UserListItem] = None team: Optional[Team] = None - + @staticmethod def from_agent(a: Agent, can_edit: bool) -> 'PublicAgent': agent_dict = a.model_dump() diff --git a/src/backend/tero/agents/evaluators/api.py b/src/backend/tero/agents/evaluators/api.py index 47684a4..100612f 100644 --- a/src/backend/tero/agents/evaluators/api.py +++ b/src/backend/tero/agents/evaluators/api.py @@ -12,9 +12,9 @@ from ..test_cases.api import TEST_CASE_PATH, find_test_case_by_id from ..test_cases.repos import TestCaseRepository from ..test_cases.runner import ( + EVALUATOR_DEFAULT_INSTRUCTIONS, EVALUATOR_DEFAULT_REASONING_EFFORT, EVALUATOR_DEFAULT_TEMPERATURE, - EVALUATOR_HUMAN_MESSAGE, ) from .domain import Evaluator, PublicEvaluator from .repos import EvaluatorRepository @@ -33,10 +33,10 @@ async def find_agent_evaluator(agent_id: int, user: Annotated[User, Depends(get_ return PublicEvaluator.from_evaluator(evaluator) else: return PublicEvaluator( - model_id=cast(str, env.internal_evaluator_model), - temperature=EVALUATOR_DEFAULT_TEMPERATURE, - reasoning_effort=EVALUATOR_DEFAULT_REASONING_EFFORT, - prompt=EVALUATOR_HUMAN_MESSAGE + model_id=cast(str, env.internal_evaluator_model), + temperature=EVALUATOR_DEFAULT_TEMPERATURE, + reasoning_effort=EVALUATOR_DEFAULT_REASONING_EFFORT, + prompt=EVALUATOR_DEFAULT_INSTRUCTIONS ) diff --git a/src/backend/tero/agents/test_cases/runner.py b/src/backend/tero/agents/test_cases/runner.py index 60617ea..557abb4 100644 --- a/src/backend/tero/agents/test_cases/runner.py +++ b/src/backend/tero/agents/test_cases/runner.py @@ -31,7 +31,7 @@ logger = logging.getLogger(__name__) -EVALUATOR_HUMAN_MESSAGE = """ +EVALUATOR_DEFAULT_INSTRUCTIONS = """ Compare the actual output with the reference output based on these criteria: 1. Semantic equivalence - Does the actual output convey the same meaning as the reference output? 2. Completeness - Does the actual output contain all key information from the reference output? @@ -40,10 +40,10 @@ 5. Conciseness - Does the actual output avoid including extra information not present in the reference output? 
If the reference output is concise the response should also be concise for example if the reference output is "Agent response" the actual output should also be "Agent response" or similar. Be lenient with minor differences in wording, formatting, or style. Focus on whether the core meaning and key information match. Be strict about factual errors, missing critical information, or extraneous details that go beyond the expected output. - +""".strip() +EVALUATOR_APPENDED_BLOCK = """ Respond with 'Y' if the actual output sufficiently matches the reference output, or 'N' if there are significant discrepancies. Then provide a brief explanation. - Input: {{inputs}} @@ -52,7 +52,7 @@ Actual Output: {{outputs}} -""" +""".strip() EVALUATOR_DEFAULT_TEMPERATURE = LlmTemperature.NEUTRAL EVALUATOR_DEFAULT_REASONING_EFFORT = ReasoningEffort.MEDIUM @@ -100,7 +100,8 @@ async def _broadcast_event(self, db: AsyncSession, suite_run_id: int, event_type )) def _build_test_case_evaluator_prompt(self, evaluator: Optional[Evaluator]) -> ChatPromptTemplate: - human_message = evaluator.prompt if evaluator else EVALUATOR_HUMAN_MESSAGE + instructions = evaluator.prompt if evaluator else EVALUATOR_DEFAULT_INSTRUCTIONS + human_message = instructions.rstrip() + "\n\n" + EVALUATOR_APPENDED_BLOCK return ChatPromptTemplate( [("system", "You are an expert evaluator assessing whether the actual output from an AI agent matches the expected output for a given test case."), ("human", human_message)], diff --git a/src/backend/tero/agents/tool_file.py b/src/backend/tero/agents/tool_file.py index 0300345..1b4c682 100644 --- a/src/backend/tero/agents/tool_file.py +++ b/src/backend/tero/agents/tool_file.py @@ -5,9 +5,8 @@ from fastapi.background import BackgroundTasks from ..core import repos as repos_module +from ..files.core import add_encoding_to_content_type, QuotaExceededError from ..files.domain import File, FileStatus, FileMetadata -from ..files.file_quota import QuotaExceededError -from ..files.parser import add_encoding_to_content_type from ..files.repos import FileRepository from ..tools.core import AgentTool from ..tools.repos import ToolRepository diff --git a/src/backend/tero/ai_models/ai_factory.py b/src/backend/tero/ai_models/ai_factory.py index f49a094..7931611 100644 --- a/src/backend/tero/ai_models/ai_factory.py +++ b/src/backend/tero/ai_models/ai_factory.py @@ -18,7 +18,7 @@ providers.append(AWSProvider()) if env.google_api_key: providers.append(GoogleProvider()) -if env.vllm_base_url and env.vllm_api_key: +if env.vllm_urls and env.vllm_api_keys: providers.append(VllmAiProvider()) diff --git a/src/backend/tero/ai_models/azure_provider.py b/src/backend/tero/ai_models/azure_provider.py index baafc51..ea3833f 100644 --- a/src/backend/tero/ai_models/azure_provider.py +++ b/src/backend/tero/ai_models/azure_provider.py @@ -1,5 +1,5 @@ import io -from typing import Optional, cast +from typing import Callable, Iterable, Optional, cast from langchain_core.language_models.chat_models import BaseChatModel from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings @@ -9,7 +9,7 @@ from ..core.env import env from .domain import AiModelProvider -from .openai_provider import get_encoding_model +from .openai_provider import get_encoding_model, count_tokens class AzureProvider(AiModelProvider): @@ -43,17 +43,30 @@ async def transcribe_audio(self, file: io.BytesIO, model: str) -> str: ) return response.text - def build_embedding(self, model: str) -> AzureOpenAIEmbeddings: + def build_embedding(self, model: str, 
usage_tracker: Callable[[int], None]) -> AzureOpenAIEmbeddings: deployment = env.azure_model_deployments[model] - return AzureOpenAIEmbeddings( + return UsageTrackingAzureOpenAIEmbeddings( + usage_tracker=usage_tracker, azure_endpoint=env.azure_endpoints[deployment.endpoint_index], azure_deployment=deployment.deployment_name, api_version=env.azure_api_version, api_key=env.azure_api_keys[deployment.endpoint_index]) + def count_tokens(self, txt: str, model: str) -> int: + return count_tokens(txt, model) + class ReasoningTokenCountingAzureChatOpenAI(AzureChatOpenAI): # we override this method which is the one used by get_num_tokens_from_messages to count the tokens def _get_encoding_model(self) -> tuple[str, tiktoken.Encoding]: return get_encoding_model(self.model_name, lambda: AzureChatOpenAI._get_encoding_model(self)) + + +class UsageTrackingAzureOpenAIEmbeddings(AzureOpenAIEmbeddings): + usage_tracker: Callable[[int], None] + + def _tokenize(self, texts: list[str], chunk_size: int) -> tuple[Iterable[int], list[list[int] | str], list[int], list[int]]: + ret = super()._tokenize(texts, chunk_size) + self.usage_tracker(sum(ret[3])) + return ret diff --git a/src/backend/tero/ai_models/domain.py b/src/backend/tero/ai_models/domain.py index 4ded5ab..98da44f 100644 --- a/src/backend/tero/ai_models/domain.py +++ b/src/backend/tero/ai_models/domain.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from enum import Enum import io -from typing import Any, Optional +from typing import Any, Callable, Optional from langchain_core.callbacks import StdOutCallbackHandler from langchain_core.embeddings import Embeddings @@ -9,6 +9,7 @@ from langchain_core.tracers import ConsoleCallbackHandler from pydantic import computed_field from sqlmodel import Field +from tokenizers import Tokenizer from ..core.env import env from ..core.domain import CamelCaseModel @@ -23,6 +24,8 @@ class LlmModelVendor(Enum): ANTHROPIC = 'ANTHROPIC' GOOGLE = 'GOOGLE' OPENAI = 'OPENAI' + QWEN = 'QWEN' + class LlmModel(CamelCaseModel, table=True): @@ -36,7 +39,7 @@ class LlmModel(CamelCaseModel, table=True): output_token_limit: int prompt_1k_token_usd: float completion_1k_token_usd: float - + @computed_field @property def is_basic(self) -> bool: @@ -63,7 +66,7 @@ class AiModelProvider(ABC): def build_chat_model(self, model: str, temperature: Optional[float]=None, reasoning_effort: Optional[str] = None) -> BaseChatModel: ret = self._build_chat_model(model, temperature, reasoning_effort, False) return self._prepare_chat_model(ret) - + def _prepare_chat_model(self, model: Any) -> BaseChatModel: model.verbose = True model.callbacks=[StdOutCallbackHandler(), ConsoleCallbackHandler()] if not env.azure_app_insights_connection else [] @@ -80,9 +83,12 @@ def _build_chat_model(self, model: str, temperature: Optional[float], reasoning_ @abstractmethod def supports_model(self, model: str) -> bool: pass - + async def transcribe_audio(self, file: io.BytesIO, model: str) -> str: raise NotImplementedError("Transcription is not yet supported by this provider") - def build_embedding(self, model: str) -> Embeddings: + def build_embedding(self, model: str, usage_tracker: Callable[[int], None]) -> Embeddings: raise NotImplementedError("Embedding is not yet supported by this provider") + + def count_tokens(self, txt: str, model: str) -> int: + raise NotImplementedError("Counting tokens is not yet supported by this provider") diff --git a/src/backend/tero/ai_models/openai_provider.py b/src/backend/tero/ai_models/openai_provider.py index 675e40d..e6ab979 
100644 --- a/src/backend/tero/ai_models/openai_provider.py +++ b/src/backend/tero/ai_models/openai_provider.py @@ -1,5 +1,5 @@ import io -from typing import Callable, Optional, cast +from typing import Callable, Iterable, Optional, cast from langchain_core.embeddings import Embeddings from langchain_core.language_models.chat_models import BaseChatModel @@ -33,11 +33,20 @@ async def transcribe_audio(self, file: io.BytesIO, model: str) -> str: ) return response.text - def build_embedding(self, model: str) -> Embeddings: - return OpenAIEmbeddings( + def build_embedding(self, model: str, usage_tracker: Callable[[int], None]) -> Embeddings: + return UsageTrackingOpenAIEmbeddings( + usage_tracker=usage_tracker, api_key=env.openai_api_key, + embedding_ctx_length=env.embedding_context_limit, model=env.openai_model_id_mapping[model]) + def count_tokens(self, txt: str, model: str) -> int: + return count_tokens(txt, model) + + +def count_tokens(txt: str, model: str) -> int: + return len(tiktoken.encoding_for_model(model).encode(txt)) + class ReasoningTokenCountingChatOpenAI(ChatOpenAI): @@ -51,3 +60,12 @@ def get_encoding_model(model_name: Optional[str], default: Callable[[], tuple[st # we return gpt-4o for o- series since it is supported by existing implementation of get_num_tokens_from_messages return "gpt-4o", tiktoken.get_encoding("o200k_base") return default() + + +class UsageTrackingOpenAIEmbeddings(OpenAIEmbeddings): + usage_tracker: Callable[[int], None] + + def _tokenize(self, texts: list[str], chunk_size: int) -> tuple[Iterable[int], list[list[int] | str], list[int], list[int]]: + ret = super()._tokenize(texts, chunk_size) + self.usage_tracker(sum(ret[3])) + return ret diff --git a/src/backend/tero/ai_models/vllm_provider.py b/src/backend/tero/ai_models/vllm_provider.py index fb988c3..ece798d 100644 --- a/src/backend/tero/ai_models/vllm_provider.py +++ b/src/backend/tero/ai_models/vllm_provider.py @@ -1,20 +1,20 @@ -import io +from functools import cache import logging import json -from typing import Any, Callable, Optional, Sequence, cast +from typing import Any, Callable, Optional, Sequence from langchain_core.embeddings import Embeddings from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import BaseMessage from langchain_core.tools import BaseTool -from langchain_openai import ChatOpenAI, OpenAIEmbeddings +from langchain_openai import ChatOpenAI from langchain_core.utils.function_calling import convert_to_openai_tool -from openai import AsyncOpenAI -from pydantic import SecretStr from tokenizers import Tokenizer from ..core.env import env from .domain import AiModelProvider +from .openai_provider import UsageTrackingOpenAIEmbeddings + logger = logging.getLogger(__name__) @@ -22,89 +22,52 @@ class VllmAiProvider(AiModelProvider): def _build_chat_model(self, model: str, temperature: Optional[float], reasoning_effort: Optional[str], streaming: bool) -> BaseChatModel: - vllm_model_id = env.vllm_model_id_mapping[model] + index, model_id = self._find_vllm_model(model) return VLLMChatModel( - base_url=env.vllm_base_url, - api_key=env.vllm_api_key, - model=vllm_model_id, + base_url=env.vllm_urls[index], + api_key=env.vllm_api_keys[index], + model=model_id, temperature=temperature, streaming=streaming) + def _find_vllm_model(self, model: str) -> tuple[int, str]: + # the nth entry of vllm_model_id_mapping is paired by position with the nth entry of vllm_urls and vllm_api_keys + return next((index, item[1]) for index, item in enumerate(env.vllm_model_id_mapping.items()) if item[0] == model) + def supports_model(self, model: str) -> bool: return model in
env.vllm_model_id_mapping - async def transcribe_audio(self, file: io.BytesIO, model: str) -> str: - client = AsyncOpenAI(api_key=cast(SecretStr, env.vllm_api_key).get_secret_value()) - response = await client.audio.transcriptions.create( - file=file, - model=env.vllm_model_id_mapping[model] - ) - return response.text - - def build_embedding(self, model: str) -> Embeddings: - return OpenAIEmbeddings( - api_key=env.vllm_api_key, - model=env.vllm_model_id_mapping[model]) - -# this cache is used to avoid downloading the tokenizer for the same model multiple times -_hf_tokenizer_cache: dict[str, "HuggingFaceTokenizerEncoding"] = {} - -class HuggingFaceTokenizerEncoding: - def __init__(self, tokenizer, model_name: str): - self._tokenizer = tokenizer - self._model_name = model_name - - def encode(self, text: str, *, allowed_special: set = set(), disallowed_special: set = set()) -> list[int]: - return self._tokenizer.encode(text).ids - - def decode(self, tokens: list[int], errors: str = "replace") -> str: - return self._tokenizer.decode(tokens) - - def encode_ordinary(self, text: str) -> list[int]: - return self._tokenizer.encode(text).ids - - @classmethod - def from_pretrained(cls, model_name: str) -> Optional["HuggingFaceTokenizerEncoding"]: - if model_name in _hf_tokenizer_cache: - return _hf_tokenizer_cache[model_name] - - try: - tokenizer = Tokenizer.from_pretrained(model_name) - wrapper = cls(tokenizer, model_name) - _hf_tokenizer_cache[model_name] = wrapper - return wrapper - except Exception as e: - logger.debug(f"Failed to load HuggingFace tokenizer for {model_name}: {e}") - return None + def build_embedding(self, model: str, usage_tracker: Callable[[int], None]) -> Embeddings: + index, model_id = self._find_vllm_model(model) + logger.debug(f"Building embedding model {model_id} with context limit {env.embedding_context_limit}") + return UsageTrackingOpenAIEmbeddings( + usage_tracker=usage_tracker, + base_url=env.vllm_urls[index], + api_key=env.vllm_api_keys[index], + model=model_id, + embedding_ctx_length=env.embedding_context_limit, + tiktoken_enabled=False) -class VLLMChatModel(ChatOpenAI): - _hf_tokenizer: Optional[HuggingFaceTokenizerEncoding] = None + def count_tokens(self, txt: str, model: str) -> int: + _, model_id = self._find_vllm_model(model) + return len(get_tokenizer(model_id).encode(txt).ids) - def _get_tokenizer(self) -> HuggingFaceTokenizerEncoding: - if self._hf_tokenizer is None: - self._hf_tokenizer = HuggingFaceTokenizerEncoding.from_pretrained(self.model_name) - - if self._hf_tokenizer is None: - raise ValueError( - f"Failed to load HuggingFace tokenizer for model '{self.model_name}'. " - f"Verify the model exists on HuggingFace Hub and has a tokenizer available."
- ) - - return self._hf_tokenizer - def _get_encoding_model(self) -> tuple[str, HuggingFaceTokenizerEncoding]: # type: ignore[override] - return self.model_name, self._get_tokenizer() +@cache +def get_tokenizer(model_name: str) -> Tokenizer: + return Tokenizer.from_pretrained(model_name) - def get_num_tokens(self, text: str) -> int: - tokenizer = self._get_tokenizer() - return len(tokenizer.encode(text)) + +class VLLMChatModel(ChatOpenAI): + + def get_token_ids(self, text: str) -> list[int]: + tokenizer = get_tokenizer(self.model_name) + return tokenizer.encode(text).ids def get_num_tokens_from_messages( self, messages: Sequence[BaseMessage], tools: Sequence[dict[str, Any] | type | Callable | BaseTool] | None = None, ) -> int: - tokenizer = self._get_tokenizer() total = 0 for msg in messages: # Message overhead (role, separators, etc.) @@ -112,18 +75,18 @@ def get_num_tokens_from_messages( content = msg.content if isinstance(content, str): - total += len(tokenizer.encode(content)) + total += len(self.get_token_ids(content)) elif isinstance(content, list): for item in content: if isinstance(item, dict) and 'text' in item: - total += len(tokenizer.encode(item['text'])) + total += len(self.get_token_ids(item['text'])) elif isinstance(item, str): - total += len(tokenizer.encode(item)) + total += len(self.get_token_ids(item)) if tools: openai_tools = [convert_to_openai_tool(tool) for tool in tools] tools_json = json.dumps(openai_tools) - total += len(tokenizer.encode(tools_json)) + total += len(self.get_token_ids(tools_json)) total += 2 return total diff --git a/src/backend/tero/core/env.py b/src/backend/tero/core/env.py index c3ff6bb..b2f0cbb 100644 --- a/src/backend/tero/core/env.py +++ b/src/backend/tero/core/env.py @@ -30,13 +30,13 @@ class Settings(BaseSettings): disable_publish_global : Optional[bool] = False contact_email : str azure_app_insights_connection : Optional[str] = None - azure_endpoints : list[str] - azure_api_keys : list[SecretStr] - azure_api_version : str - azure_model_deployments : dict[str, AzureModelDeployment] + azure_endpoints : list[str] = [] + azure_api_keys : list[SecretStr] = [] + azure_api_version : Optional[str] = None + azure_model_deployments : dict[str, AzureModelDeployment] = {} azure_doc_intelligence_endpoint : Optional[str] = None azure_doc_intelligence_key : Optional[SecretStr] = None - azure_doc_intelligence_cost_per_1k_pages_usd : float + azure_doc_intelligence_cost_per_1k_pages_usd : Optional[float] = None temperatures: dict[str, float] monthly_usd_limit_default : int internal_generator_model : str @@ -46,6 +46,7 @@ class Settings(BaseSettings): agent_basic_models : List[str] default_agent_name : str embedding_model : str + embedding_context_limit : int = 8191 embedding_cost_per_1k_tokens : float transcription_model : str aws_access_key_id : Optional[SecretStr] = None @@ -56,6 +57,9 @@ class Settings(BaseSettings): google_model_id_mapping : dict[str, str] openai_api_key : Optional[SecretStr] = None openai_model_id_mapping : dict[str, str] + vllm_urls : List[str] = [] + vllm_api_keys : List[SecretStr] = [] + vllm_model_id_mapping : dict[str, str] = {} docs_tool_chunk_size : int docs_tool_chunk_overlap : int docs_tool_retrieve_top : int @@ -71,9 +75,6 @@ class Settings(BaseSettings): web_tool_google_cost_per_1k_searches_usd : float browser_tool_playwright_mcp_url : str browser_tool_playwright_output_dir : str - vllm_base_url : Optional[str] = None - vllm_api_key : Optional[SecretStr] = None - vllm_model_id_mapping : dict[str, str] = {} def 
is_local_env(self) -> bool: found = re.search('@([^/]+)(?:\\d+)?/', self.db_url) @@ -102,7 +103,7 @@ def decode_model_id_mapping(cls, v: str) -> dict[str, str]: def decode_temperatures(cls, v: str) -> dict[str, float]: return {k: float(v) for k, v in (pair.split(':', 1) for pair in v.split(','))} if v else {} - @field_validator('allowed_users', 'azure_endpoints', 'azure_api_keys', 'agent_basic_models', mode='before') + @field_validator('allowed_users', 'azure_endpoints', 'azure_api_keys', 'agent_basic_models', 'vllm_urls', 'vllm_api_keys', mode='before') @classmethod def decode_list(cls, v: str) -> list[str]: return v.split(',') if v else [] diff --git a/src/backend/tero/files/core.py b/src/backend/tero/files/core.py new file mode 100644 index 0000000..e59d320 --- /dev/null +++ b/src/backend/tero/files/core.py @@ -0,0 +1,59 @@ +import abc +from typing import Optional + +import chardet +from langchain_core.messages import HumanMessage + +from ..ai_models import ai_factory +from ..agents.domain import Agent +from ..usage.domain import Usage +from .domain import File + + +class CurrentQuota: + def __init__(self, current_usage: float, user_quota: float): + self.current_usage = current_usage + self.user_quota = user_quota + + +class FileQuota: + + def __init__(self, pdf_parsing_usage: Usage, agent: Optional[Agent], current_quota: CurrentQuota): + self.pdf_parsing_usage = pdf_parsing_usage + self.current_quota = current_quota + self.model = ai_factory.build_streaming_chat_model(agent.model_id, agent.model_temperature, agent.model_reasoning_effort) if agent else None + self.available_tokens = agent.model.token_limit - agent.model.output_token_limit if agent else None + + def has_reached_token_limit(self, text: str) -> bool: + if not self.model or not self.available_tokens: + return False + + current_tokens = self.model.get_num_tokens_from_messages(messages=[HumanMessage(content=text)]) + return current_tokens >= self.available_tokens + + def has_reached_quota_limit(self) -> bool: + return self.current_quota.current_usage + self.pdf_parsing_usage.usd_cost > self.current_quota.user_quota + + +class BaseFileProcessor(abc.ABC): + + @abc.abstractmethod + def supports(self, file: File) -> bool: + pass + + @abc.abstractmethod + def extract_text(self, file: File, file_quota: FileQuota) -> str: + pass + + +class QuotaExceededError(Exception): + pass + + +def add_encoding_to_content_type(content_type: Optional[str], content: bytes) -> str: + # add the encoding to the content type so later on it can be used (for example in tools file processing) and is available to the frontend for proper file visualization + if content_type and content_type.startswith('text/') and 'charset=' not in content_type: + detected = chardet.detect(content) + encoding = detected['encoding'] if detected and detected['encoding'] else 'utf-8' + content_type = f"{content_type}; charset={encoding.lower()}" + return content_type or "application/octet-stream" diff --git a/src/backend/tero/files/file_quota.py b/src/backend/tero/files/file_quota.py deleted file mode 100644 index 15fbaed..0000000 --- a/src/backend/tero/files/file_quota.py +++ /dev/null @@ -1,40 +0,0 @@ -import logging -from typing import TYPE_CHECKING, Optional - -from langchain_core.messages import HumanMessage - -from ..ai_models import ai_factory -if TYPE_CHECKING: - from ..threads.engine import AgentEngine -from ..usage.domain import Usage - - -logger = logging.getLogger(__name__) - - -class QuotaExceededError(Exception): - pass - - -class CurrentQuota: - def __init__(self,
current_usage: float, user_quota: float): - self.current_usage = current_usage - self.user_quota = user_quota - - -class FileQuota: - def __init__(self, pdf_parsing_usage: Usage, engine: Optional["AgentEngine"], current_quota: CurrentQuota): - self.pdf_parsing_usage = pdf_parsing_usage - self.current_quota = current_quota - self.model = ai_factory.build_streaming_chat_model(engine._agent.model_id, engine._agent.model_temperature, engine._agent.model_reasoning_effort) if engine else None - self.available_tokens = engine._agent.model.token_limit - engine._agent.model.output_token_limit if engine else None - - def has_reached_token_limit(self, current_content: str) -> bool: - if not self.model or not self.available_tokens: - return False - - current_tokens = self.model.get_num_tokens_from_messages(messages=[HumanMessage(content=current_content)]) - return current_tokens >= self.available_tokens - - def has_reached_quota_limit(self) -> bool: - return self.current_quota.current_usage + self.pdf_parsing_usage.usd_cost > self.current_quota.user_quota diff --git a/src/backend/tero/files/parser.py b/src/backend/tero/files/parser.py index 65d7ee9..73cdc17 100644 --- a/src/backend/tero/files/parser.py +++ b/src/backend/tero/files/parser.py @@ -1,12 +1,13 @@ import asyncio import logging -from typing import Optional - -import chardet from ..files.domain import File, FileProcessor -from ..files.file_processor import BaseFileProcessor, PlainTextFileProcessor, XlsxFileProcessor, XlsFileProcessor, BasicPdfFileProcessor, EnhancedPdfFileProcessor, ImageFileProcessor -from ..files.file_quota import FileQuota +from .core import BaseFileProcessor, FileQuota +from .domain import File, FileProcessor +from .processors.plaintext import PlainTextFileProcessor +from .processors.spreadsheet import XlsxFileProcessor, XlsFileProcessor +from .processors.image import ImageFileProcessor +from .processors.pdf import build_basic_pdf_processor, build_enhanced_pdf_processor logger = logging.getLogger(__name__) @@ -17,29 +18,20 @@ def __init__(self, file_name: str): super().__init__(f"Unsupported file type: {file_name}") -def add_encoding_to_content_type(content_type: Optional[str], content: bytes) -> str: - # add the encoding to the content type so later on it can be used (for exammple in tools file processing) and is avaible to frontend for proper file visualization - if content_type and content_type.startswith('text/') and not 'charset=' in content_type: - detected = chardet.detect(content) - encoding = detected['encoding'] if detected and detected['encoding'] else 'utf-8' - content_type = f"{content_type}; charset={encoding.lower()}" - return content_type or "application/octet-stream" +async def extract_file_text(file: File, file_quota: FileQuota) -> str: + processor = _find_file_processor(file) + return await asyncio.to_thread(processor.extract_text, file, file_quota) -def find_file_processor(file: File) -> BaseFileProcessor: +def _find_file_processor(file: File) -> BaseFileProcessor: processors = [ PlainTextFileProcessor(), XlsxFileProcessor(), XlsFileProcessor(), ImageFileProcessor(), - BasicPdfFileProcessor() if file.file_processor == FileProcessor.BASIC else EnhancedPdfFileProcessor() + build_basic_pdf_processor() if file.file_processor == FileProcessor.BASIC else build_enhanced_pdf_processor() ] found = next((processor for processor in processors if processor.supports(file)), None) if found is None: raise UnsupportedFileError(file.name) return found - - -async def extract_file_text(file: File, file_quota: FileQuota) 
-> str: - processor = find_file_processor(file) - return await asyncio.to_thread(processor.extract_text, file, file_quota) diff --git a/src/backend/tero/files/pdf_processor.py b/src/backend/tero/files/pdf_processor.py deleted file mode 100644 index fe05f20..0000000 --- a/src/backend/tero/files/pdf_processor.py +++ /dev/null @@ -1,307 +0,0 @@ -import abc -from dataclasses import dataclass -import io -import logging -from typing import Optional, TypeVar, Generic, cast - -from azure.ai.documentintelligence import DocumentIntelligenceClient -from azure.ai.documentintelligence.models import AnalyzeDocumentRequest, AnalyzeResult -from azure.core.credentials import AzureKeyCredential -from pydantic import SecretStr -from pypdf import PdfReader, PdfWriter -import pypdfium2 as pdfium -from tabulate import tabulate - -from ..core.env import env -from ..files.domain import File -from ..files.file_quota import FileQuota, QuotaExceededError - - -logger = logging.getLogger(__name__) -T = TypeVar('T', bound='BoundedElement') -PAGES_CHUNK_SIZE = 50 - - -@dataclass -class BoundingBox: - x: float - y: float - width: float - height: float - - @classmethod - def from_polygon(cls, polygon: list) -> 'BoundingBox | None': - if not polygon or len(polygon) != 8: - return None - - x_coords = [polygon[i] for i in range(0, 8, 2)] - y_coords = [polygon[i] for i in range(1, 8, 2)] - - x = min(x_coords) - y = min(y_coords) - width = max(x_coords) - x - height = max(y_coords) - y - return cls(x=x, y=y, width=width, height=height) - - def contains(self, other: 'BoundingBox') -> bool: - return (other.y >= self.y and other.y + other.height <= self.y + self.height) - - -@dataclass -class BoundedElement(Generic[T]): - content: str - y: float - height: float - bbox: Optional[BoundingBox] = None - - @classmethod - def create(cls: type[T], content: str, y: float, height: float, bbox: Optional[BoundingBox] = None) -> T: - return cls(content=content, y=y, height=height, bbox=bbox) - - -@dataclass -class BoundedParagraph(BoundedElement['BoundedParagraph']): - @classmethod - def from_paragraph(cls, paragraph: dict) -> Optional['BoundedParagraph']: - content = paragraph.get("content", "").strip() - if not content: - return None - - bounding_regions = paragraph.get("boundingRegions", []) - if not bounding_regions: - return cls.create(content=content, y=0.0, height=0.0, bbox=None) - - polygon = bounding_regions[0].get("polygon", []) - bbox = BoundingBox.from_polygon(polygon) if polygon else None - - if bbox: - return cls.create(content=content, y=bbox.y, height=bbox.height, bbox=bbox) - else: - return cls.create(content=content, y=0.0, height=0.0, bbox=None) - - -@dataclass -class BoundedTable(BoundedElement['BoundedTable']): - @classmethod - def from_cells(cls, table: dict) -> Optional['BoundedTable']: - cells = table.get("cells", []) - if not cells: - return None - - bbox = cls._get_table_bounding_box(table) - if not bbox: - return None - - grid = cls._create_grid_from_cells(cells) - content = cls._format_grid_as_markdown(grid) - - if not content.strip(): - return None - - return cls.create(content=content, y=bbox.y, height=bbox.height, bbox=bbox) - - @staticmethod - def _get_table_bounding_box(table: dict) -> BoundingBox | None: - table_regions = table.get("boundingRegions", []) - if not table_regions: - return None - - table_polygon = table_regions[0].get("polygon", []) - if not table_polygon: - return None - - return BoundingBox.from_polygon(table_polygon) - - @staticmethod - def _create_grid_from_cells(cells: list) -> list: - if 
not cells: - return [] - - max_row = max(cell.get("rowIndex", 0) for cell in cells) - max_col = max(cell.get("columnIndex", 0) for cell in cells) - grid = [["" for _ in range(max_col + 1)] for _ in range(max_row + 1)] - - for cell in cells: - row = cell.get("rowIndex", 0) - col = cell.get("columnIndex", 0) - content = BoundedTable._normalize_cell_text(cell.get("content", "")) - grid[row][col] = content - return grid - - @staticmethod - def _normalize_cell_text(text: str) -> str: - return text.replace(":unselected:", "").replace(":selected:", "").replace("\n", " ").strip() - - @staticmethod - def _format_grid_as_markdown(grid: list) -> str: - if not grid: - return "" - - header, *data = grid - table = tabulate(data, headers=header, tablefmt="pipe") - return f"\n{table}\n" - - -def process_pdf_basic(upload_file: File, file_quota: FileQuota) -> str: - processor = BasicPDFProcessor() - return processor.extract_content(upload_file, file_quota) - - -def process_pdf_enhanced(upload_file: File, file_quota: FileQuota) -> str: - processor = EnhancedPDFProcessor(endpoint=cast(str, env.azure_doc_intelligence_endpoint), key=cast(SecretStr, env.azure_doc_intelligence_key).get_secret_value()) - return processor.extract_content(upload_file, file_quota) - - -class BasePDFProcessor(abc.ABC): - - @abc.abstractmethod - def extract_content(self, upload_file: File, file_quota: FileQuota) -> str: - pass - - def _get_total_pages(self, content: bytes): - pdf = PdfReader(io.BytesIO(content)) - return len(pdf.pages) - - def _format_pages_content(self, all_pages_content: dict) -> str: - return "\n\n".join(f"## Page {page_num}\n{all_pages_content[page_num]}" for page_num in sorted(all_pages_content.keys())) - - def _write_pdf_chunk(self, content: bytes, start_page: int, end_page: int) -> bytes: - try: - pdf_reader = PdfReader(io.BytesIO(content)) - pdf_writer = PdfWriter() - - for page_num in range(start_page, end_page + 1): - pdf_writer.add_page(pdf_reader.pages[page_num - 1]) - - output_buffer = io.BytesIO() - pdf_writer.write(output_buffer) - return output_buffer.getvalue() - - except Exception as e: - logger.warning(f"Failed to write PDF chunk {start_page}-{end_page}: {e}. Using original content.") - return content - - -class BasicPDFProcessor(BasePDFProcessor): - - def extract_content(self, upload_file: File, file_quota: FileQuota) -> str: - content = upload_file.content - total_pages = self._get_total_pages(content) - all_pages_content = {} - - for start_page in range(1, total_pages + 1, PAGES_CHUNK_SIZE): - end_page = min(start_page + PAGES_CHUNK_SIZE - 1, total_pages) - - if file_quota.has_reached_quota_limit(): - raise QuotaExceededError(f"Quota exceeded when analyzing pdf {upload_file.id} {upload_file.name}") - - current_content = self._format_pages_content(all_pages_content) - if file_quota.has_reached_token_limit(current_content): - logger.warning(f"Token limit reached when analyzing pdf {upload_file.id} {upload_file.name}. 
Stopping analysis at page {start_page-1}") - break - - pdf_chunk = self._write_pdf_chunk(content, start_page, end_page) - pages_content = self._process_with_pypdfium2(pdf_chunk, start_page) - all_pages_content.update(pages_content) - - return self._format_pages_content(all_pages_content) - - def _process_with_pypdfium2(self, pdf_chunk: bytes, start_page_offset: int = 0) -> dict: - pages_content = {} - with pdfium.PdfDocument(pdf_chunk) as pdf: - for relative_page_number, page in enumerate(pdf, start=1): - actual_page_number = relative_page_number + start_page_offset - 1 - textpage = page.get_textpage() - content = self._clean_pypdfium2_content(textpage.get_text_bounded()) - pages_content[actual_page_number] = content - return pages_content - - def _clean_pypdfium2_content(self, content: str) -> str: - return content.replace("\r", "").strip() - - -class EnhancedPDFProcessor(BasePDFProcessor): - - def __init__(self, endpoint: str, key: str): - self.client = DocumentIntelligenceClient(endpoint=endpoint, credential=AzureKeyCredential(key)) - - def extract_content(self, upload_file: File, file_quota: FileQuota) -> str: - content = upload_file.content - total_pages = self._get_total_pages(content) - all_pages_content = {} - analyzed_pages = 0 - - for start_page in range(1, total_pages + 1, PAGES_CHUNK_SIZE): - end_page = min(start_page + PAGES_CHUNK_SIZE - 1, total_pages) - - if file_quota.has_reached_quota_limit(): - raise QuotaExceededError(f"Quota exceeded when analyzing pdf {upload_file.id} {upload_file.name}") - - current_content = self._format_pages_content(all_pages_content) - if file_quota.has_reached_token_limit(current_content): - logger.warning(f"Token limit reached when analyzing pdf {upload_file.id} {upload_file.name}. Stopping analysis at page {start_page-1}") - break - - pdf_chunk = self._write_pdf_chunk(content, start_page, end_page) - chunk_pages = end_page - start_page + 1 - - layout = self._analyze_layout(pdf_chunk) - pages_content = self._extract_pages_content(layout, start_page) - self._update_with_pdf_parsing_usage(file_quota, chunk_pages) - - all_pages_content.update(pages_content) - analyzed_pages += chunk_pages - - return self._format_pages_content(all_pages_content) - - def _analyze_layout(self, content: bytes) -> AnalyzeResult: - request = AnalyzeDocumentRequest(bytes_source=content) - # https://tech-depth-and-breadth.medium.com/azure-ai-document-intelligence-for-rag-use-cases-4e242b0ba7de - poller = self.client.begin_analyze_document("prebuilt-layout", request) - return poller.result() - - def _extract_pages_content(self, result: AnalyzeResult, start_page_offset: int = 0) -> dict: - pages_content = {} - for page in result.get("pages", []): - relative_page_number = page.get("pageNumber", 1) - actual_page_number = relative_page_number + start_page_offset - 1 - tables = self._get_page_elements(result, "tables", relative_page_number) - paragraphs = self._get_page_elements(result, "paragraphs", relative_page_number) - elements = self._create_page_elements(paragraphs, tables) - pages_content[actual_page_number] = self._combine_elements_content(elements) - return pages_content - - def _update_with_pdf_parsing_usage(self, file_quota: FileQuota, analyzed_pages: int): - # https://azure.microsoft.com/en-us/pricing/details/ai-document-intelligence/ - file_quota.pdf_parsing_usage.increment(new_quantity=analyzed_pages, cost_per_1k_units=env.azure_doc_intelligence_cost_per_1k_pages_usd) - - def _get_page_elements(self, result: AnalyzeResult, element_type: str, page_number: int) -> 
list: - return [element for element in result.get(element_type, []) if element.get("boundingRegions", [{}])[0].get("pageNumber", -1) == page_number] - - def _create_page_elements(self, paragraphs: list, tables: list) -> list[BoundedElement]: - elements: list[BoundedElement] = [] - - paragraph_elements = [] - for paragraph in paragraphs: - paragraph_element = BoundedParagraph.from_paragraph(paragraph) - if paragraph_element: - paragraph_elements.append(paragraph_element) - - table_elements = [] - for table in tables: - table_element = BoundedTable.from_cells(table) - if table_element: - table_elements.append(table_element) - - for paragraph_element in paragraph_elements: - if paragraph_element.bbox and not any(table_element.bbox and table_element.bbox.contains(paragraph_element.bbox) for table_element in table_elements): - elements.append(paragraph_element) - elif not paragraph_element.bbox: - elements.append(paragraph_element) - - elements.extend(table_elements) - return elements - - def _combine_elements_content(self, elements: list[BoundedElement]) -> str: - elements.sort(key=lambda x: x.y) - return "\n".join(element.content for element in elements) diff --git a/src/backend/tero/files/processors/image.py b/src/backend/tero/files/processors/image.py new file mode 100644 index 0000000..2b06a2e --- /dev/null +++ b/src/backend/tero/files/processors/image.py @@ -0,0 +1,27 @@ +import io +import logging + +from PIL import Image + +from ..core import BaseFileProcessor, FileQuota +from ..domain import File + + +logger = logging.getLogger(__name__) + + +class ImageFileProcessor(BaseFileProcessor): + + def supports(self, file: File) -> bool: + return any(file.name.lower().endswith(ext) for ext in {'.jpg', '.jpeg', '.png'}) + + def extract_text(self, file: File, file_quota: FileQuota) -> str: + try: + image_bytes = io.BytesIO(file.content) + image = Image.open(image_bytes) + image.verify() + except Exception as e: + logger.error(f"Invalid image file {file.name}: {e}") + raise ValueError(f"Invalid image file: {file.name}") + + return f"Image file: {file.name}" diff --git a/src/backend/tero/files/processors/pdf/__init__.py b/src/backend/tero/files/processors/pdf/__init__.py new file mode 100644 index 0000000..8633d63 --- /dev/null +++ b/src/backend/tero/files/processors/pdf/__init__.py @@ -0,0 +1,18 @@ +from ...core import BaseFileProcessor +from .pypdfium import PyPdfiumPdfProcessor +from .azure_document_intelligence import AzureDocumentIntelligencePdfProcessor + + +def build_basic_pdf_processor() -> BaseFileProcessor: + return PyPdfiumPdfProcessor() + + +def is_enhanced_pdf_processor_available() -> bool: + return AzureDocumentIntelligencePdfProcessor.is_configured() + + +def build_enhanced_pdf_processor() -> BaseFileProcessor: + if AzureDocumentIntelligencePdfProcessor.is_configured(): + return AzureDocumentIntelligencePdfProcessor() + else: + raise RuntimeError("No enhanced PDF processor available") diff --git a/src/backend/tero/files/processors/pdf/azure_document_intelligence.py b/src/backend/tero/files/processors/pdf/azure_document_intelligence.py new file mode 100644 index 0000000..e3ef0dd --- /dev/null +++ b/src/backend/tero/files/processors/pdf/azure_document_intelligence.py @@ -0,0 +1,190 @@ +from dataclasses import dataclass +import logging +from typing import cast, Generic, Optional, TypeVar, Callable + +from pydantic import SecretStr +from tabulate import tabulate + +from azure.ai.documentintelligence import DocumentIntelligenceClient +from azure.ai.documentintelligence.models import 
AnalyzeDocumentRequest, AnalyzeResult +from azure.core.credentials import AzureKeyCredential + +from ....core.env import env +from .core import BasePdfProcessor + + +logger = logging.getLogger(__name__) +T = TypeVar('T', bound='BoundedElement') + + +@dataclass +class BoundingBox: + x: float + y: float + width: float + height: float + + @classmethod + def from_polygon(cls, polygon: list) -> 'BoundingBox | None': + if not polygon or len(polygon) != 8: + return None + + x_coords = [polygon[i] for i in range(0, 8, 2)] + y_coords = [polygon[i] for i in range(1, 8, 2)] + + x = min(x_coords) + y = min(y_coords) + width = max(x_coords) - x + height = max(y_coords) - y + return cls(x=x, y=y, width=width, height=height) + + def contains(self, other: 'BoundingBox') -> bool: + return (other.y >= self.y and other.y + other.height <= self.y + self.height) + + +@dataclass +class BoundedElement(Generic[T]): + content: str + y: float + height: float + bbox: Optional[BoundingBox] = None + + @classmethod + def create(cls: type[T], content: str, y: float, height: float, bbox: Optional[BoundingBox] = None) -> T: + return cls(content=content, y=y, height=height, bbox=bbox) + + +@dataclass +class BoundedParagraph(BoundedElement['BoundedParagraph']): + + @classmethod + def from_paragraph(cls, paragraph: dict) -> Optional['BoundedParagraph']: + content = paragraph.get("content", "").strip() + if not content: + return None + + bounding_regions = paragraph.get("boundingRegions", []) + if not bounding_regions: + return cls.create(content=content, y=0.0, height=0.0, bbox=None) + + polygon = bounding_regions[0].get("polygon", []) + bbox = BoundingBox.from_polygon(polygon) if polygon else None + + if bbox: + return cls.create(content=content, y=bbox.y, height=bbox.height, bbox=bbox) + else: + return cls.create(content=content, y=0.0, height=0.0, bbox=None) + + +@dataclass +class BoundedTable(BoundedElement['BoundedTable']): + @classmethod + def from_cells(cls, table: dict) -> Optional['BoundedTable']: + cells = table.get("cells", []) + if not cells: + return None + + bbox = cls._get_table_bounding_box(table) + if not bbox: + return None + + grid = cls._create_grid_from_cells(cells) + content = cls._format_grid_as_markdown(grid) + + if not content.strip(): + return None + + return cls.create(content=content, y=bbox.y, height=bbox.height, bbox=bbox) + + @staticmethod + def _get_table_bounding_box(table: dict) -> BoundingBox | None: + table_regions = table.get("boundingRegions", []) + if not table_regions: + return None + + table_polygon = table_regions[0].get("polygon", []) + if not table_polygon: + return None + + return BoundingBox.from_polygon(table_polygon) + + @staticmethod + def _create_grid_from_cells(cells: list) -> list: + if not cells: + return [] + + max_row = max(cell.get("rowIndex", 0) for cell in cells) + max_col = max(cell.get("columnIndex", 0) for cell in cells) + grid = [["" for _ in range(max_col + 1)] for _ in range(max_row + 1)] + + for cell in cells: + row = cell.get("rowIndex", 0) + col = cell.get("columnIndex", 0) + content = BoundedTable._normalize_cell_text(cell.get("content", "")) + grid[row][col] = content + return grid + + @staticmethod + def _normalize_cell_text(text: str) -> str: + return text.replace(":unselected:", "").replace(":selected:", "").replace("\n", " ").strip() + + @staticmethod + def _format_grid_as_markdown(grid: list) -> str: + if not grid: + return "" + + header, *data = grid + table = tabulate(data, headers=header, tablefmt="pipe") + return f"\n{table}\n" + + +class
AzureDocumentIntelligencePdfProcessor(BasePdfProcessor): + + @staticmethod + def is_configured() -> bool: + return bool(env.azure_doc_intelligence_endpoint and env.azure_doc_intelligence_key and env.azure_doc_intelligence_cost_per_1k_pages_usd) + + def __init__(self): + super().__init__(cast(float, env.azure_doc_intelligence_cost_per_1k_pages_usd)) + self._client = DocumentIntelligenceClient( + endpoint=cast(str, env.azure_doc_intelligence_endpoint), + credential=AzureKeyCredential(cast(SecretStr, env.azure_doc_intelligence_key).get_secret_value())) + + def _extract_pages_content(self, pdf_chunk: bytes, page_offset: int) -> dict[int, str]: + ret = {} + request = AnalyzeDocumentRequest(bytes_source=pdf_chunk) + # https://tech-depth-and-breadth.medium.com/azure-ai-document-intelligence-for-rag-use-cases-4e242b0ba7de + poller = self._client.begin_analyze_document("prebuilt-layout", request) + result = poller.result() + for page in result.get("pages", []): + page_number = page.get("pageNumber", 1) + elements = self._create_page_elements(result, page_number) + ret[page_number + page_offset - 1] = self._combine_elements_content(elements) + return ret + + def _create_page_elements(self, result: AnalyzeResult, page_number: int) -> list[BoundedElement]: + paragraph_elements = self._create_page_elements_by_type("paragraphs", BoundedParagraph.from_paragraph, result, page_number) + table_elements = self._create_page_elements_by_type("tables", BoundedTable.from_cells, result, page_number) + ret = [] + for paragraph_element in paragraph_elements: + if paragraph_element.bbox and not any(table_element.bbox and table_element.bbox.contains(paragraph_element.bbox) for table_element in table_elements): + ret.append(paragraph_element) + elif not paragraph_element.bbox: + ret.append(paragraph_element) + ret.extend(table_elements) + return ret + + def _create_page_elements_by_type(self, element_type: str, factory: Callable[[dict], Optional[BoundedElement]], result: AnalyzeResult, page_number: int) -> list[BoundedElement]: + ret = [] + for element in self._get_page_elements(result, element_type, page_number): + elem = factory(element) + if elem: + ret.append(cast(BoundedElement, elem)) + return ret + + def _get_page_elements(self, result: AnalyzeResult, element_type: str, page_number: int) -> list: + return [element for element in result.get(element_type, []) if element.get("boundingRegions", [{}])[0].get("pageNumber", -1) == page_number] + + def _combine_elements_content(self, elements: list[BoundedElement]) -> str: + elements.sort(key=lambda x: x.y) + return "\n".join(element.content for element in elements) diff --git a/src/backend/tero/files/processors/pdf/core.py b/src/backend/tero/files/processors/pdf/core.py new file mode 100644 index 0000000..ac0523b --- /dev/null +++ b/src/backend/tero/files/processors/pdf/core.py @@ -0,0 +1,74 @@ +import abc +import io +import logging + +from pypdf import PdfReader, PdfWriter + +from ...core import BaseFileProcessor, FileQuota, QuotaExceededError +from ...domain import File + + +logger = logging.getLogger(__name__) +_PAGES_CHUNK_SIZE = 50 + + +class BasePdfProcessor(BaseFileProcessor, abc.ABC): + + def __init__(self, cost_per_1k_pages_usd: float): + self._cost_per_1k_pages_usd = cost_per_1k_pages_usd + + def supports(self, file: File) -> bool: + return file.name.lower().endswith('.pdf') + + def extract_text(self, file: File, file_quota: FileQuota) -> str: + content = file.content + total_pages = self._get_total_pages(content) + all_pages_content = {} + + for start_page 
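Reviewer note: `contains` above compares only vertical extents, which drives the paragraph/table de-duplication in `_create_page_elements`. A plain-Python illustration using the `BoundingBox` dataclass from this file (polygon coordinates are made up):

```python
# Polygons are 8 floats: x1, y1, x2, y2, x3, y3, x4, y4.
table_bbox = BoundingBox.from_polygon([1, 5, 6, 5, 6, 9, 1, 9])  # spans y 5..9
para_bbox = BoundingBox.from_polygon([2, 6, 5, 6, 5, 7, 2, 7])   # spans y 6..7
assert table_bbox is not None and para_bbox is not None
assert table_bbox.contains(para_bbox)  # paragraph overlaps the table rows, so it is dropped
# contains() ignores x entirely: a paragraph printed beside the table at the
# same height would also count as "inside" and be filtered from the output.
```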
diff --git a/src/backend/tero/files/processors/pdf/core.py b/src/backend/tero/files/processors/pdf/core.py
new file mode 100644
index 0000000..ac0523b
--- /dev/null
+++ b/src/backend/tero/files/processors/pdf/core.py
@@ -0,0 +1,74 @@
+import abc
+import io
+import logging
+
+from pypdf import PdfReader, PdfWriter
+
+from ...core import BaseFileProcessor, FileQuota, QuotaExceededError
+from ...domain import File
+
+
+logger = logging.getLogger(__name__)
+_PAGES_CHUNK_SIZE = 50
+
+
+class BasePdfProcessor(BaseFileProcessor, abc.ABC):
+
+    def __init__(self, cost_per_1k_pages_usd: float):
+        self._cost_per_1k_pages_usd = cost_per_1k_pages_usd
+
+    def supports(self, file: File) -> bool:
+        return file.name.lower().endswith('.pdf')
+
+    def extract_text(self, file: File, file_quota: FileQuota) -> str:
+        content = file.content
+        total_pages = self._get_total_pages(content)
+        all_pages_content = {}
+
+        for start_page in range(1, total_pages + 1, _PAGES_CHUNK_SIZE):
+            end_page = min(start_page + _PAGES_CHUNK_SIZE - 1, total_pages)
+
+            if file_quota.has_reached_quota_limit():
+                raise QuotaExceededError(f"Quota exceeded when analyzing pdf {file.id} {file.name}")
+
+            current_content = self._format_pages_content(all_pages_content)
+            if file_quota.has_reached_token_limit(current_content):
+                logger.warning(f"Token limit reached when analyzing pdf {file.id} {file.name}. Stopping analysis at page {start_page-1}")
+                break
+
+            pdf_chunk = self._write_pdf_chunk(content, start_page, end_page)
+            chunk_pages = end_page - start_page + 1
+
+            pages_content = self._extract_pages_content(pdf_chunk, start_page)
+            file_quota.pdf_parsing_usage.increment(new_quantity=chunk_pages, cost_per_1k_units=self._cost_per_1k_pages_usd)
+
+            all_pages_content.update(pages_content)
+
+        return self._format_pages_content(all_pages_content)
+
+    def _get_total_pages(self, content: bytes):
+        pdf = PdfReader(io.BytesIO(content))
+        return len(pdf.pages)
+
+    def _write_pdf_chunk(self, content: bytes, start_page: int, end_page: int) -> bytes:
+        try:
+            pdf_reader = PdfReader(io.BytesIO(content))
+            pdf_writer = PdfWriter()
+
+            for page_num in range(start_page, end_page + 1):
+                pdf_writer.add_page(pdf_reader.pages[page_num - 1])
+
+            output_buffer = io.BytesIO()
+            pdf_writer.write(output_buffer)
+            return output_buffer.getvalue()
+
+        except Exception as e:
+            logger.warning(f"Failed to write PDF chunk {start_page}-{end_page}: {e}. Using original content.")
+            return content
+
+    @abc.abstractmethod
+    def _extract_pages_content(self, pdf_chunk: bytes, page_offset: int) -> dict[int, str]:
+        pass
+
+    def _format_pages_content(self, all_pages_content: dict) -> str:
+        return "\n\n".join(f"## Page {page_num}\n{all_pages_content[page_num]}" for page_num in sorted(all_pages_content.keys()))
diff --git a/src/backend/tero/files/processors/pdf/pypdfium.py b/src/backend/tero/files/processors/pdf/pypdfium.py
new file mode 100644
index 0000000..b99fea2
--- /dev/null
+++ b/src/backend/tero/files/processors/pdf/pypdfium.py
@@ -0,0 +1,24 @@
+import logging
+
+import pypdfium2
+
+from .core import BasePdfProcessor
+
+
+logger = logging.getLogger(__name__)
+
+
+class PyPdfiumPdfProcessor(BasePdfProcessor):
+
+    def __init__(self):
+        super().__init__(0.0)
+
+    def _extract_pages_content(self, pdf_chunk: bytes, page_offset: int) -> dict[int, str]:
+        pages_content = {}
+        with pypdfium2.PdfDocument(pdf_chunk) as pdf:
+            for relative_page_number, page in enumerate(pdf, start=1):
+                actual_page_number = relative_page_number + page_offset - 1
+                textpage = page.get_textpage()
+                content = textpage.get_text_bounded()
+                pages_content[actual_page_number] = content.replace("\r", "").strip()
+        return pages_content
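Reviewer note: a worked example of the chunking arithmetic in `BasePdfProcessor.extract_text` and the page-offset mapping mirrored by `PyPdfiumPdfProcessor._extract_pages_content` (the 120-page total is illustrative):

```python
# Chunks for an imaginary 120-page PDF with _PAGES_CHUNK_SIZE = 50.
_PAGES_CHUNK_SIZE = 50
total_pages = 120  # illustrative value

for start_page in range(1, total_pages + 1, _PAGES_CHUNK_SIZE):
    end_page = min(start_page + _PAGES_CHUNK_SIZE - 1, total_pages)
    # page 1 of the extracted chunk corresponds to absolute page `start_page`
    absolute_first_page = 1 + start_page - 1
    print(start_page, end_page, absolute_first_page)
# prints: 1 50 1 / 51 100 51 / 101 120 101
```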
diff --git a/src/backend/tero/files/processors/plaintext.py b/src/backend/tero/files/processors/plaintext.py
new file mode 100644
index 0000000..0970550
--- /dev/null
+++ b/src/backend/tero/files/processors/plaintext.py
@@ -0,0 +1,31 @@
+import logging
+from typing import Optional
+
+from ..core import BaseFileProcessor, FileQuota
+from ..domain import File
+
+logger = logging.getLogger(__name__)
+
+class PlainTextFileProcessor(BaseFileProcessor):
+
+    def supports(self, file: File) -> bool:
+        return any(file.name.lower().endswith(ext) for ext in {'.txt', '.md', '.csv', '.har', '.json', '.svg'})
+
+    def extract_text(self, file: File, file_quota: FileQuota) -> str:
+        encoding = self._get_encoding(file.content_type)
+        try:
+            return file.content.decode(encoding)
+        except (UnicodeDecodeError, LookupError):
+            logger.warning(f"Failed to decode {file.name} with {encoding}. Trying fallback encodings.", exc_info=True)
+            for fallback_encoding in [ e for e in ['utf-8', 'latin-1', 'cp1252'] if e != encoding]:
+                try:
+                    return file.content.decode(fallback_encoding)
+                except (UnicodeDecodeError, LookupError):
+                    continue
+            logger.warning(f"All encodings failed for {file.name}, using {encoding} with error replacement")
+            return file.content.decode(encoding, errors='replace')
+
+    def _get_encoding(self, content_type: Optional[str]) -> str:
+        charset_param = '; charset='
+        encoding = content_type.split(charset_param, 1)[1] if content_type and charset_param in content_type else 'utf-8'
+        return encoding
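Reviewer note: the decode-with-fallbacks behaviour above can be sanity-checked in isolation; a minimal sketch (the sample bytes are made up):

```python
# latin-1 bytes fail as UTF-8 and are recovered by the first working fallback.
data = "café".encode("latin-1")
try:
    text = data.decode("utf-8")
except UnicodeDecodeError:
    for fallback in ["latin-1", "cp1252"]:
        try:
            text = data.decode(fallback)
            break
        except UnicodeDecodeError:
            continue
print(text)  # "café", decoded via latin-1
```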
diff --git a/src/backend/tero/files/file_processor.py b/src/backend/tero/files/processors/spreadsheet.py
similarity index 51%
rename from src/backend/tero/files/file_processor.py
rename to src/backend/tero/files/processors/spreadsheet.py
index af0fe0e..11a82ef 100644
--- a/src/backend/tero/files/file_processor.py
+++ b/src/backend/tero/files/processors/spreadsheet.py
@@ -1,58 +1,17 @@
 from abc import ABC, abstractmethod
-import io
 from io import BytesIO
-import logging
-from typing import Optional, Any
+from typing import Any
 
 import openpyxl
 import openpyxl.worksheet.worksheet
-from PIL import Image
 import xlrd
 
-from ..files.domain import File
-from ..files.file_quota import FileQuota
-from ..files.pdf_processor import process_pdf_basic, process_pdf_enhanced
+from ..core import BaseFileProcessor, FileQuota
+from ..domain import File
 
-logger = logging.getLogger(__name__)
-
-
-def get_encoding(content_type: Optional[str]) -> str:
-    charset_param = '; charset='
-    encoding = content_type.split(charset_param, 1)[1] if content_type and charset_param in content_type else 'utf-8'
-    return encoding
-
-class BaseFileProcessor(ABC):
-    @abstractmethod
-    def supports(self, file: File) -> bool:
-        # Checks if this processor supports the given file
-        pass
-
-    @abstractmethod
-    def extract_text(self, file: File, file_quota: FileQuota) -> str:
-        pass
-
-class PlainTextFileProcessor(BaseFileProcessor):
-
-    def supports(self, file: File) -> bool:
-        return any(file.name.lower().endswith(ext) for ext in {'.txt', '.md', '.csv', '.har', '.json', '.svg'})
-
-    def extract_text(self, file: File, file_quota: FileQuota) -> str:
-        encoding = get_encoding(file.content_type)
-        try:
-            return file.content.decode(encoding)
-        except (UnicodeDecodeError, LookupError):
-            logger.warning(f"Failed to decode {file.name} with {encoding}. Trying fallback encodings.", exc_info=True)
-            for fallback_encoding in [ e for e in ['utf-8', 'latin-1', 'cp1252'] if e != encoding]:
-                try:
-                    return file.content.decode(fallback_encoding)
-                except (UnicodeDecodeError, LookupError):
-                    continue
-            logger.warning(f"All encodings failed for {file.name}, using {encoding} with error replacement")
-            return file.content.decode(encoding, errors='replace')
-
 class Sheet(ABC):
-    
+
     @property
     @abstractmethod
     def title(self) -> str:
@@ -72,6 +31,7 @@ def column_count(self) -> int:
     def cell(self, row_idx: int, col_idx: int) -> Any:
         pass
 
+
 class SpreadsheetFileProcessor(BaseFileProcessor, ABC):
     file_extension: str
 
@@ -98,6 +58,7 @@ def _format_cell(self, row_idx: int, col_idx: int, sheet: Sheet) -> str:
         ret = sheet.cell(row_idx, col_idx)
         return str(ret) if ret is not None else ""
 
+
 class XlsxSheet(Sheet):
 
     def __init__(self, sheet:openpyxl.worksheet.worksheet.Worksheet):
@@ -106,18 +67,19 @@ def __init__(self, sheet:openpyxl.worksheet.worksheet.Worksheet):
     @property
     def title(self) -> str:
         return self._sheet.title
-    
+
     @property
     def row_count(self) -> int:
         return self._sheet.max_row
-    
+
     @property
     def column_count(self) -> int:
         return self._sheet.max_column
-    
+
     def cell(self, row_idx: int, col_idx: int) -> Any:
         return self._sheet.cell(row_idx + 1, col_idx + 1).value
 
+
 class XlsxFileProcessor(SpreadsheetFileProcessor):
     file_extension = '.xlsx'
 
@@ -125,6 +87,7 @@ def _load_sheets(self, content: bytes) -> list[Sheet]:
         wb = openpyxl.load_workbook(BytesIO(content))
         return [XlsxSheet(sheet) for sheet in wb.worksheets]
 
+
 class XlsSheet(Sheet):
 
     def __init__(self, sheet: xlrd.sheet.Sheet):
@@ -148,41 +111,10 @@ def column_count(self) -> int:
     def cell(self, row_idx: int, col_idx: int) -> Any:
         return self._sheet.cell(row_idx, col_idx).value
 
+
 class XlsFileProcessor(SpreadsheetFileProcessor):
     file_extension = '.xls'
 
     def _load_sheets(self, content: bytes) -> list[Sheet]:
         wb = xlrd.open_workbook(file_contents=content)
         return [XlsSheet(sheet) for sheet in wb.sheets()]
-
-class BasicPdfFileProcessor(BaseFileProcessor):
-
-    def supports(self, file: File) -> bool:
-        return file.name.lower().endswith('.pdf')
-
-    def extract_text(self, file: File, file_quota: FileQuota) -> str:
-        return process_pdf_basic(file, file_quota)
-
-class EnhancedPdfFileProcessor(BaseFileProcessor):
-
-    def supports(self, file: File) -> bool:
-        return file.name.lower().endswith('.pdf')
-
-    def extract_text(self, file: File, file_quota: FileQuota) -> str:
-        return process_pdf_enhanced(file, file_quota)
-
-class ImageFileProcessor(BaseFileProcessor):
-
-    def supports(self, file: File) -> bool:
-        return any(file.name.lower().endswith(ext) for ext in {'.jpg', '.jpeg', '.png'})
-
-    def extract_text(self, file: File, file_quota: FileQuota) -> str:
-        try:
-            image_bytes = io.BytesIO(file.content)
-            image = Image.open(image_bytes)
-            image.verify()
-        except Exception as e:
-            logger.error(f"Invalid image file {file.name}: {e}")
-            raise ValueError(f"Invalid image file: {file.name}")
-
-        return f"Image file: {file.name}"
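Reviewer note: since `Sheet` is now the seam between `SpreadsheetFileProcessor` and the two workbook backends, a tiny in-memory implementation is enough to exercise the formatting logic in tests. A hypothetical helper (not part of this diff), assuming the `Sheet` ABC above:

```python
from typing import Any

class ListSheet(Sheet):
    """In-memory Sheet backed by a list of rows, for unit tests."""

    def __init__(self, title: str, rows: list[list[Any]]):
        self._title = title
        self._rows = rows

    @property
    def title(self) -> str:
        return self._title

    @property
    def row_count(self) -> int:
        return len(self._rows)

    @property
    def column_count(self) -> int:
        return max((len(row) for row in self._rows), default=0)

    def cell(self, row_idx: int, col_idx: int) -> Any:
        row = self._rows[row_idx]
        return row[col_idx] if col_idx < len(row) else None
```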
diff --git a/src/backend/tero/threads/api.py b/src/backend/tero/threads/api.py
index c9f5948..7aa2fdd 100644
--- a/src/backend/tero/threads/api.py
+++ b/src/backend/tero/threads/api.py
@@ -10,6 +10,7 @@
 from fastapi.responses import StreamingResponse
 from sqlmodel.ext.asyncio.session import AsyncSession
 from sse_starlette.event import ServerSentEvent
+from langgraph.errors import GraphRecursionError
 
 from ..agents.repos import AgentRepository
 from ..ai_models import ai_factory
@@ -19,9 +20,10 @@
 from ..core.env import env
 from ..core.repos import get_db
 from ..files.api import build_file_download_response
+from ..files.core import FileQuota, CurrentQuota, QuotaExceededError, add_encoding_to_content_type
 from ..files.domain import File, FileStatus, FileMetadata, FileProcessor, FileMetadataWithContent
-from ..files.file_quota import FileQuota, CurrentQuota, QuotaExceededError
-from ..files.parser import add_encoding_to_content_type, extract_file_text
+from ..files.parser import extract_file_text
+from ..files.processors.pdf import is_enhanced_pdf_processor_available
 from ..files.repos import FileRepository
 from ..tools.auth import ToolAuthRequestException, build_tool_auth_request_http_exception
 from ..usage.domain import Usage, UsageType, MessageUsage
@@ -165,7 +167,7 @@ async def add_message(thread_id: int, request: Request, user: Annotated[User, De
     user_message = await repo.add(initial_thread_message)
     await _attach_existing_files_to_message(existing_files, user_message, db)
-    await _handle_file_contents(files, user_message, user, thread, engine, db)
+    await _handle_file_contents(files, user_message, user, thread, db)
     user_message = await repo.refresh_with_files(user_message)
 
     return StreamingResponse(
@@ -198,23 +200,22 @@ async def _attach_existing_files_to_message(files: List[ThreadMessageFile], user
         await repo.add(ThreadMessageFile(thread_message_id=user_message.id, file_id=f.file_id))
 
 
-async def _handle_file_contents(files: List[UploadFile], user_message: ThreadMessage, user: User, thread: Thread, engine: AgentEngine, db: AsyncSession):
+async def _handle_file_contents(files: List[UploadFile], user_message: ThreadMessage, user: User, thread: Thread, db: AsyncSession):
     file_repo = FileRepository(db)
     if files:
         for f in files:
             content = await f.read()
             content_type = add_encoding_to_content_type(f.content_type, content)
-            file_processor = FileProcessor.ENHANCED if env.azure_doc_intelligence_endpoint and env.azure_doc_intelligence_key else FileProcessor.BASIC
+            file_processor = FileProcessor.ENHANCED if is_enhanced_pdf_processor_available() else FileProcessor.BASIC
            file = File(name=f.filename or "uploaded-file", content_type=content_type, content=content, user_id=user.id, file_processor=file_processor)
             pdf_parsing_usage = Usage(message_id=user_message.id, user_id=user.id, agent_id=thread.agent_id, model_id=None, type=UsageType.PDF_PARSING)
             current_usage = await UsageRepository(db).find_current_month_user_usage_usd(user.id)
-            file_quota = FileQuota(pdf_parsing_usage, engine, CurrentQuota(current_usage, user.monthly_usd_limit))
+            file_quota = FileQuota(pdf_parsing_usage, thread.agent, CurrentQuota(current_usage, user.monthly_usd_limit))
             try:
                 file.processed_content = await extract_file_text(file, file_quota)
                 file.status = FileStatus.PROCESSED
                 saved_file = await file_repo.add(file)
                 await ThreadMessageFileRepository(db).add(ThreadMessageFile(thread_message_id=user_message.id, file_id=saved_file.id))
-
             finally:
                 await UsageRepository(db).add(pdf_parsing_usage)
@@ -222,15 +223,18 @@ async def _agent_response(message: ThreadMessage, thread: Thread, user_id: int, db: AsyncSession, is_in_agent_edition: bool) \
         -> AsyncIterator[bytes]:
     message_usage = None
+    repo = ThreadMessageRepository(db)
     yield ServerSentEvent(event="userMessage", data=json.dumps({
         "id": message.id,
         "files": [FileMetadata.from_file(f.file).model_dump(mode="json", by_alias=True) for f in message.files if f.file]
     })).encode()
+    complete_answer = ""
+    files: List[FileMetadata] = []
+    status_updates: List[AgentActionEvent] = []
     try:
         stop_event = asyncio.Event()
         active_streaming_connections[thread.id] = stop_event
-        repo = ThreadMessageRepository(db)
         message_usage = MessageUsage(user_id=user_id, agent_id=thread.agent_id, model_id=thread.agent.model_id, message_id=message.id)
         thread_messages = await repo.find_previous_messages(message)
@@ -239,9 +243,6 @@
         await ThreadRepository(db).update(thread)
 
         answer_stream = AgentEngine(thread.agent, user_id, db).answer([*thread_messages, message], message_usage, stop_event)
-        complete_answer = ""
-        files: List[FileMetadata] = []
-        status_updates: List[AgentActionEvent] = []
 
         async for event in answer_stream:
             if isinstance(event, AgentActionEvent):
@@ -275,7 +276,7 @@
             parent_id=message.id,
             minutes_saved=minutes_saved,
             stopped=stop_event.is_set(),
-            status_updates=[event.model_dump(mode="json", by_alias=True) for event in status_updates] if status_updates else None
+            status_updates=_dump_status_updates(status_updates)
         ))
         for f in files:
             await ThreadMessageFileRepository(db).add(ThreadMessageFile(thread_message_id=answer.id, file_id=f.id))
@@ -286,14 +287,34 @@
             "minutesSaved": answer.minutes_saved,
             "stopped": answer.stopped
         })).encode()
-    except Exception:
+
+    except* GraphRecursionError:
+        await repo.add(ThreadMessage(
+            thread_id=thread.id,
+            text="ERROR_RECURSION_LIMIT_EXCEEDED",
+            origin=ThreadMessageOrigin.SYSTEM,
+            parent_id=message.id,
+            status_updates=_dump_status_updates(status_updates)
+        ))
+        yield ServerSentEvent(event="error", data="recursionLimitExceeded").encode()
+    except* Exception:
         logger.exception(f"Problem answering message in thread {thread.id}")
+        await repo.add(ThreadMessage(
+            thread_id=thread.id,
+            text="ERROR_GENERIC",
+            origin=ThreadMessageOrigin.SYSTEM,
+            parent_id=message.id,
+            status_updates=_dump_status_updates(status_updates)
+        ))
         yield ServerSentEvent(event="error").encode()
     finally:
         await UsageRepository(db).add(message_usage)
         del active_streaming_connections[thread.id]
 
+
+def _dump_status_updates(status_updates: List[AgentActionEvent]) -> Optional[List[dict]]:
+    return [event.model_dump(mode="json", by_alias=True) for event in status_updates] if status_updates else None
+
+
 @router.post(THREAD_PATH + "/stop", status_code=status.HTTP_200_OK)
 async def stop_message(thread_id: int, user: Annotated[User, Depends(get_current_user)], db: Annotated[AsyncSession, Depends(get_db)]):
diff --git a/src/backend/tero/threads/domain.py b/src/backend/tero/threads/domain.py
index 25add8d..7e3d3fd 100644
--- a/src/backend/tero/threads/domain.py
+++ b/src/backend/tero/threads/domain.py
@@ -61,15 +61,16 @@ class ThreadListItem(BaseThread, CamelCaseModel):
     @staticmethod
     def from_thread(thread: Thread, last_message: Optional[datetime] = None, creation: Optional[datetime] = None) -> 'ThreadListItem':
         return ThreadListItem.model_validate(
-            {**thread.model_dump(), 
-             "agent": AgentListItem.from_agent(thread.agent, thread.agent.is_editable_by(thread.user)), 
-             "creation": thread.creation if creation is None else creation, 
+            {**thread.model_dump(),
+             "agent": AgentListItem.from_agent(thread.agent, thread.agent.is_editable_by(thread.user)),
+             "creation": thread.creation if creation is None else creation,
              "last_message": last_message})
 
 
 class ThreadMessageOrigin(Enum):
     USER = 'USER'
     AGENT = 'AGENT'
+    SYSTEM = 'SYSTEM'
 
 
 class ThreadMessageUpdate(CamelCaseModel):
@@ -111,7 +112,7 @@ class ThreadMessageFile(CamelCaseModel, table=True):
     __tablename__ : Any = "thread_message_file"
     thread_message_id: int = Field(primary_key=True, foreign_key="thread_message.id", index=True, ondelete="CASCADE")
     file_id: int = Field(primary_key=True, foreign_key="file.id", index=True, ondelete="CASCADE")
-    
+
     thread_message: "ThreadMessage" = Relationship(back_populates="files")
     file: "File" = Relationship()
@@ -135,7 +136,7 @@ class ThreadMessagePublic(CamelCaseModel, table=False):
     @staticmethod
     def from_message(message: ThreadMessage) -> 'ThreadMessagePublic':
         return ThreadMessagePublic.model_validate(
-            {**message.model_dump(), 
+            {**message.model_dump(),
             "files": [FileMetadata.from_file(m.file) for m in message.files if m.file] if message.files else None})
diff --git a/src/backend/tero/threads/engine.py b/src/backend/tero/threads/engine.py
index c29f0ed..8691bc3 100644
--- a/src/backend/tero/threads/engine.py
+++ b/src/backend/tero/threads/engine.py
@@ -62,15 +62,15 @@ async def load_tools(self, stack: AsyncExitStack, thread_id: Optional[int] = Non
                 tool = await stack.enter_async_context(agent_tool.load())
                 ret.append(tool)
         return ret
-    
+
     async def answer(self, messages: List[ThreadMessage], message_usage: MessageUsage, stop_event: asyncio.Event) -> AsyncIterator[AgentEvent]:
         llm = ai_factory.build_streaming_chat_model(self._agent.model.id, self._agent.model_temperature, self._agent.model_reasoning_effort)
         async with AsyncExitStack() as stack:
             agent_tools = await self.load_tools(stack, thread_id=messages[0].thread_id)
             tools = [ lt for t in agent_tools for lt in await t.build_langchain_tools() ]
             tools.append(clock)
-            # Enable error handling so ToolException from MCP tools (execution errors) 
-            # are shown to the LLM instead of crashing the agent 
+            # Enable error handling so ToolException from MCP tools (execution errors)
+            # are shown to the LLM instead of crashing the agent
             agent = create_react_agent(
                 llm, ToolNode(tools, handle_tool_errors=True), pre_model_hook=self._build_message_trimmer(llm, tools)
             )
@@ -80,8 +80,7 @@
             stream = agent.astream(
                 input,
                 {
-                    # multiply by 2 and add 1 because recursion counts every event (find tool & call tool)
-                    "recursion_limit": 20 * 2 + 1
+                    "recursion_limit": self._agent.recursion_limit
                 },
                 stream_mode=["updates", "messages", "custom"],
             )
@@ -110,7 +109,7 @@
                     message_usage.increment_tool_usage(agent_tool_metadata.tool_usage)
                 if agent_tool_metadata.file:
                     yield AgentFileEvent(file=agent_tool_metadata.file)
-        
+
         # If the response was stopped, approximate the token usage
         if stop_event.is_set():
             approximate_input_tokens = llm.get_num_tokens_from_messages(input["messages"]) + self._count_tools_tokens(tools, llm)
@@ -121,7 +120,7 @@
                 "output_tokens": approximate_output_tokens,
                 "total_tokens": approximate_input_tokens + approximate_output_tokens
             }, self._agent.model)
-    
+
     def _get_content(self, msg: str | list[str | dict]) -> str:
         if isinstance(msg, str):
             return msg
@@ -136,7 +135,7 @@
                     texts.append(text)
                 return "".join(texts)
         raise ValueError(f"Invalid message type: {type(msg)}")
-    
+
     async def _process_updates(self, content: Any) -> AsyncIterator[AgentActionEvent]:
         if isinstance(content, dict):
             ((key, value), *_) = content.items()
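Reviewer note: the old hard-coded budget allowed 20 tool round-trips (20 * 2 + 1 = 41 graph steps); the new per-agent `recursion_limit` (seeded as 20 in the test fixtures) appears to be handed to LangGraph unchanged, so a stored 20 now budgets 20 graph steps rather than 20 tool calls. The arithmetic, for comparison:

```python
# Illustrative comparison of the removed constant vs the new per-agent value.
def old_graph_steps(tool_calls: int) -> int:
    return tool_calls * 2 + 1  # each tool use cost a "find tool" + "call tool" step

print(old_graph_steps(20))  # 41 steps under the removed hard-coded budget
print(20)                   # steps allowed by the new per-agent setting
```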
diff --git a/src/backend/tero/threads/time_saved_estimation.py b/src/backend/tero/threads/time_saved_estimation.py
index 1876259..585610c 100644
--- a/src/backend/tero/threads/time_saved_estimation.py
+++ b/src/backend/tero/threads/time_saved_estimation.py
@@ -32,11 +32,11 @@ async def estimate_minutes_saved(user_message: str, agent_response: str, thread:
         HumanMessage(content=user_message),
         AIMessage(content=agent_response),
     ] + [
-        HumanMessage(content=message.text) if message.origin == ThreadMessageOrigin.USER else AIMessage(content=message.text) 
-        for message in thread_messages
+        HumanMessage(content=message.text) if message.origin == ThreadMessageOrigin.USER else AIMessage(content=message.text)
+        for message in thread_messages if message.origin != ThreadMessageOrigin.SYSTEM
     ] + [
-        HumanMessage(content=message.text) if message.origin == ThreadMessageOrigin.USER else AIMessage(content=message.text) 
-        for message in feedback_messages
+        HumanMessage(content=message.text) if message.origin == ThreadMessageOrigin.USER else AIMessage(content=message.text)
+        for message in feedback_messages if message.origin != ThreadMessageOrigin.SYSTEM
     ]
 
     token_counter = llm.get_num_tokens_from_messages
@@ -58,11 +58,11 @@
     )
 
     system_prompt = SYSTEM_PROMPT.format(
-        agent_name=thread.agent.name, 
+        agent_name=thread.agent.name,
         agent_description=thread.agent.description,
-        user_message=trimmed_messages.pop(0).content if trimmed_messages else "[NO USER MESSAGE]", 
-        agent_response=trimmed_messages.pop(0).content if trimmed_messages else "[NO AGENT RESPONSE]", 
-        previous_message=trimmed_messages.pop(0).content if trimmed_messages and thread_messages else "[NO PREVIOUS USER MESSAGE IN CONVERSATION]", 
+        user_message=trimmed_messages.pop(0).content if trimmed_messages else "[NO USER MESSAGE]",
+        agent_response=trimmed_messages.pop(0).content if trimmed_messages else "[NO AGENT RESPONSE]",
+        previous_message=trimmed_messages.pop(0).content if trimmed_messages and thread_messages else "[NO PREVIOUS USER MESSAGE IN CONVERSATION]",
         previous_agent_response=trimmed_messages.pop(0).content if trimmed_messages and thread_messages else "[NO PREVIOUS AGENT RESPONSE IN CONVERSATION]",
         reference_examples=_add_reference_examples(feedback_messages, trimmed_messages)
     )
@@ -88,11 +88,11 @@ def _add_reference_examples(feedback_thread_messages: List[ThreadMessage], feedb
         feedback_thread_messages = feedback_thread_messages[2:]
 
         examples.append(f"""
-Reference example {len(examples) + 1}: 
+Reference example {len(examples) + 1}:
 User message: \"\"\"{user_message.content}\"\"\"
 Agent response: \"\"\"{agent_response.content}\"\"\"
 Minutes saved: {minutes_saved}""")
-    
+
     return "\n".join(examples) if examples else "[NO REFERENCE EXAMPLES]"
 
 
@@ -105,7 +105,7 @@ def _add_reference_examples(feedback_thread_messages: List[ThreadMessage], feedb
 ---
 
 INSTRUCTIONS:
-- Reply with the exact number of minutes saved. 
+- Reply with the exact number of minutes saved.
 - Reply 0 only if the response is purely a greeting, confirmation, vague promise of help, or only a follow-up question with no actual work done.
 
 ---
@@ -173,4 +173,4 @@
 INTERACTION TO EVALUATE (this is the only part to score):
 User message: \"\"\"{user_message}\"\"\"
 Agent response: \"\"\"{agent_response}\"\"\"
-"""
\ No newline at end of file
+"""
diff --git a/src/backend/tero/tools/auth.py b/src/backend/tero/tools/auth.py
index 7ddb8d4..c748764 100644
--- a/src/backend/tero/tools/auth.py
+++ b/src/backend/tero/tools/auth.py
@@ -77,6 +77,7 @@ class ToolOAuthClientInfo(SQLModel, table=True):
 
 class ToolAuthRequest(CamelCaseModel):
     request_type: str
+    tool_id: str
     agent_id: int
 
@@ -93,7 +94,6 @@ class ToolOAuthRequest(ToolAuthRequest):
 
 class ToolAuthTokenRequest(ToolAuthRequest):
     request_type: str = "auth_token"
-    tool_id: str
 
 
 def build_tool_auth_request_http_exception(request: ToolAuthRequest) -> HTTPException:
@@ -303,7 +303,7 @@ async def _redirect_handler(self, auth_url: str):
             code_verifier=self.code_verifier,
             token_endpoint=self.context.oauth_metadata.token_endpoint.unicode_string() if self.context.oauth_metadata else None)
         await self._oauth_repo.save_state(tool_state)
-        raise ToolAuthRequestException(ToolOAuthRequest(oauth_url=auth_url, oauth_state=self.state, agent_id=self._agent_id))
+        raise ToolAuthRequestException(ToolOAuthRequest(tool_id=self._tool_id, oauth_url=auth_url, oauth_state=self.state, agent_id=self._agent_id))
 
     # this is just to satisfy the callback_handler. It should never be called due to the redirect_handler
     async def _callback_handler(self) -> tuple[str, str | None]:
diff --git a/src/backend/tero/tools/docs/tool.py b/src/backend/tero/tools/docs/tool.py
index 4cd1818..be89e9f 100644
--- a/src/backend/tero/tools/docs/tool.py
+++ b/src/backend/tero/tools/docs/tool.py
@@ -1,11 +1,12 @@
 import aiofiles
 from collections.abc import AsyncIterator
 from contextlib import asynccontextmanager
+from enum import Enum
+from functools import cache
 import logging
+from tokenizers import Tokenizer
 from typing import List, Any, Optional, cast, Sequence
 from uuid import UUID
-from enum import Enum
-import tiktoken
 
 from langchain_classic.indexes import SQLRecordManager, aindex
 from langchain_core.callbacks import AsyncCallbackHandler
@@ -32,12 +33,12 @@
 from ...ai_models import ai_factory
 from ...ai_models.domain import LlmModel
 from ...ai_models.repos import AiModelRepository
-from ...ai_models.vllm_provider import HuggingFaceTokenizerEncoding
 from ...core.assets import solve_asset_path
 from ...core.env import env
 from ...files.domain import File, FileProcessor
-from ...files.file_quota import FileQuota, CurrentQuota
+from ...files.core import FileQuota, CurrentQuota
 from ...files.parser import extract_file_text
+from ...files.processors.pdf import is_enhanced_pdf_processor_available
 from ...files.repos import FileRepository
 from ...threads.domain import AgentActionEvent, AgentAction
 from ...usage.domain import Usage, MessageUsage, UsageType
@@ -48,21 +49,14 @@
 from .repos import DocToolFileRepository, DocToolConfigRepository
 
-
 logger = logging.getLogger(__name__)
 
 DOCS_TOOL_ID = "docs"
 ADVANCED_FILE_PROCESSING = "advancedFileProcessing"
 
 
-def embedding_tokens_from_text(text: str) -> int:
-    embeddings_encoding = tiktoken.encoding_for_model(env.embedding_model)
-    return len(embeddings_encoding.encode(text))
-
-
 class DocumentUrlSolvingRetriever(VectorStoreRetriever):
     agent_id: int
     tool_id: str
-    embedding_usage: Usage
 
     async def _aget_relevant_documents(
         self,
@@ -71,7 +65,6 @@
         run_manager: AsyncCallbackManagerForRetrieverRun,
         **kwargs: Any,
     ) -> list[Document]:
-        self.embedding_usage.increment(embedding_tokens_from_text(query), env.embedding_cost_per_1k_tokens)
         ret = await super()._aget_relevant_documents(
             query, run_manager=run_manager, **kwargs
         )
@@ -97,7 +90,7 @@ class DocsTool(AgentToolWithFiles):
 
     @model_validator(mode="after")
     def remove_advanced_processing_if_not_configured(self):
-        if not env.azure_doc_intelligence_endpoint or not env.azure_doc_intelligence_key:
+        if not is_enhanced_pdf_processor_available():
             self.config_schema["properties"].pop(ADVANCED_FILE_PROCESSING)
             self.config_schema["required"].remove(ADVANCED_FILE_PROCESSING)
         return self
@@ -136,7 +129,8 @@ async def teardown(self):
 
     def _build_vectorstore(self):
         ai_provider = ai_factory.get_provider(env.embedding_model)
-        return PGVector(embeddings=ai_provider.build_embedding(env.embedding_model), connection=self._get_async_engine(),
+        usage_tracker = lambda tokens: self.embedding_usage.increment(tokens, env.embedding_cost_per_1k_tokens)
+        return PGVector(embeddings=ai_provider.build_embedding(env.embedding_model, usage_tracker), connection=self._get_async_engine(),
                         collection_name=self._build_collection_name(self.agent.id), use_jsonb=True)
 
     async def add_file(self, file: File, user: User):
@@ -156,25 +150,33 @@ async def _handle_file(self, file: File, user: User):
             file.processed_content = file_doc.page_content
             await FileRepository(self.db).update(file)
             await self._update_tool_description_with_file(file, model, message_usage)
-            docs = MarkdownTextSplitter.from_tiktoken_encoder(encoding_name=tiktoken.encoding_for_model(env.embedding_model).name,
-                                                              chunk_size=env.docs_tool_chunk_size, chunk_overlap=env.docs_tool_chunk_overlap).split_documents([file_doc])
-            embeddings_tokens = sum(embedding_tokens_from_text(doc.page_content) for doc in docs)
-            self.embedding_usage.increment(embeddings_tokens, env.embedding_cost_per_1k_tokens)
-            await aindex(docs, self._build_record_manager(), self._build_vectorstore(), cleanup="incremental",
-                         source_id_key="id", key_encoder="sha256")
-
+            await aindex(
+                self._split_file_content(file_doc),
+                self._build_record_manager(),
+                self._build_vectorstore(),
+                cleanup="incremental",
+                source_id_key="id",
+                key_encoder="sha256")
         finally:
             usage_repo = UsageRepository(self.db)
             await usage_repo.add(pdf_parsing_usage)
             await usage_repo.add(message_usage)
             await usage_repo.add(self.embedding_usage)
-    
+
     async def _update_tool_description_with_file(self, file: File, model: LlmModel, message_usage: MessageUsage):
         description = await self._generate_file_description(file, model, message_usage)
         await DocToolFileRepository(self.db).add(
             DocToolFile(file_id=file.id, description=description, agent_id=self.agent.id))
         await self._update_tool_description(model, message_usage)
 
+    def _split_file_content(self, file_doc: Document) -> list[Document]:
+        ai_provider = ai_factory.get_provider(env.embedding_model)
+        text_splitter = MarkdownTextSplitter(
+            length_function=lambda text: ai_provider.count_tokens(text, env.embedding_model),
+            chunk_size=env.docs_tool_chunk_size,
+            chunk_overlap=env.docs_tool_chunk_overlap)
+        return text_splitter.split_documents([file_doc])
+
     async def _find_description_model(self) -> LlmModel:
         ret = await AiModelRepository(self.db).find_by_id(env.internal_generator_model)
         if not ret:
@@ -188,35 +190,22 @@ async def _build_document(file: File, file_quota: FileQuota):
         return Document(page_content=content, metadata=metadata)
 
     async def _generate_file_description(self, file: File, model: LlmModel, message_usage: MessageUsage) -> str:
+        llm = ai_factory.build_chat_model(model.id, env.internal_generator_temperature)
         async with aiofiles.open(solve_asset_path('file-description-prompt.md', __file__)) as f:
             system_prompt = await f.read()
-        text_splitter = self._create_text_splitter(model.id)
+        text_splitter = CharacterTextSplitter(
+            length_function=lambda text: llm.get_num_tokens(text),
+            chunk_size=env.docs_tool_description_chunk_size,
+            chunk_overlap=env.docs_tool_description_chunk_overlap)
         chunks = text_splitter.split_text(cast(str, file.processed_content))
         ret = "none"
         for chunk in chunks:
             prompt = system_prompt + f"Previous Description: {ret}\n\n" + f"## File contents\n\n{chunk}"
-            ret = await self._generate_description(prompt, 200, model, message_usage)
+            ret = await self._generate_description(prompt, 200, llm, model, message_usage)
         return ret
 
     @staticmethod
-    def _create_text_splitter(model_id: str) -> CharacterTextSplitter:
-        hf_tokenizer = HuggingFaceTokenizerEncoding.from_pretrained(model_id)
-        if hf_tokenizer:
-            return CharacterTextSplitter(
-                chunk_size=env.docs_tool_description_chunk_size,
-                chunk_overlap=env.docs_tool_description_chunk_overlap,
-                length_function=lambda text: len(hf_tokenizer.encode(text)),
-            )
-
-        return CharacterTextSplitter.from_tiktoken_encoder(
-            model_name=model_id,
-            chunk_size=env.docs_tool_description_chunk_size,
-            chunk_overlap=env.docs_tool_description_chunk_overlap,
-        )
-
-    @staticmethod
-    async def _generate_description(prompt: str, max_length: int, model: LlmModel, message_usage: MessageUsage) -> str:
-        llm = ai_factory.build_chat_model(model.id, env.internal_generator_temperature)
+    async def _generate_description(prompt: str, max_length: int, llm: BaseChatModel, model: LlmModel, message_usage: MessageUsage) -> str:
         response = await llm.ainvoke([HumanMessage(prompt)])
         response = cast(AIMessage, response)
         message_usage.increment_with_metadata(response.usage_metadata, model)
@@ -241,12 +230,13 @@ async def _generate_tool_description(self, files: List[DocToolFile], model: LlmM
             prompt = await f.read()
         for f in files:
             prompt += f"\n- {f.description}"
-        return await self._generate_description(prompt, 200, model, message_usage)
+        llm = ai_factory.build_chat_model(model.id, env.internal_generator_temperature)
+        return await self._generate_description(prompt, 200, llm, model, message_usage)
 
     async def update_file(self, file: File, user: User):
         # clear processed content before update to avoid partial quota-exceeded state
         await self._remove_file_processed_content(file)
-        await self._handle_file(file, user) 
+        await self._handle_file(file, user)
 
     async def remove_file(self, file: File):
         # langchain does not provide an abstraction to just remove one document from the index, so we built this logic
@@ -318,8 +308,7 @@ def _build_retriever(self) -> VectorStoreRetriever:
             vectorstore=self._build_vectorstore(),
             search_kwargs={"k": env.docs_tool_retrieve_top},
             agent_id=self.agent.id,
-            tool_id=self.id,
-            embedding_usage=self.embedding_usage,
+            tool_id=self.id
         )
 
     @staticmethod
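Reviewer note: `_split_file_content` now delegates token counting to the provider through `length_function` instead of hard-wiring tiktoken. A self-contained sketch of the same pattern with a stand-in counter (the `langchain_text_splitters` import path and the chunk sizes are assumptions; the real code plugs in `ai_provider.count_tokens` and the `env.docs_tool_*` values):

```python
from langchain_text_splitters import MarkdownTextSplitter

def count_tokens(text: str) -> int:
    return max(1, len(text) // 4)  # crude ~4-chars-per-token approximation

splitter = MarkdownTextSplitter(
    length_function=count_tokens,  # budget is measured in "tokens", not characters
    chunk_size=512,
    chunk_overlap=64,
)
chunks = splitter.split_text("# Title\n\n" + "Some markdown body text. " * 400)
print(len(chunks), max(count_tokens(c) for c in chunks))  # every chunk fits the budget
```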
diff --git a/src/backend/tests/assets/init_db.sql b/src/backend/tests/assets/init_db.sql
index a1f2757..1b8eb9c 100644
--- a/src/backend/tests/assets/init_db.sql
+++ b/src/backend/tests/assets/init_db.sql
@@ -31,13 +31,13 @@ insert into team_role (team_id, user_id, role, status) values
 (4, 2, 'TEAM_MEMBER', 'PENDING'),
 (4, 5, 'TEAM_MEMBER', 'REJECTED');
 
-insert into agent (name, description, user_id, last_update, team_id, model_id, system_prompt, temperature, reasoning_effort) values
-('Agent 1', 'This is the first agent', 1, '2025-02-21 12:00', Null, 'gpt-4o-mini', 'You are a helpful AI agent.', 'PRECISE', 'LOW'),
-('Agent 2', 'This is the second agent', 1, '2025-02-21 12:01', 1, 'o4-mini', 'You are a helpful AI agent.', 'CREATIVE', 'LOW'),
-('Agent 3', 'This is the third agent', 2, '2025-02-21 12:02', 4, 'gpt-4o', 'You are a helpful AI agent.', 'PRECISE', 'LOW'),
-('Agent 4', 'This is the fourth agent', 2, '2025-02-21 12:03', Null, 'gpt-4o', 'You are a helpful AI agent.', 'CREATIVE', 'LOW'),
-('Agent 5', 'This is the fifth agent', 2, '2025-02-21 12:04', 2, 'gpt-4o', 'You are a helpful AI agent.', 'PRECISE', 'LOW'),
-('GPT-5 Nano', 'This is the default agent', Null, '2025-02-21 12:00', 1, 'gpt-5-nano', 'You are a helpful AI agent.', 'NEUTRAL', 'LOW');
+insert into agent (name, description, user_id, last_update, team_id, model_id, system_prompt, temperature, reasoning_effort, recursion_limit) values
+('Agent 1', 'This is the first agent', 1, '2025-02-21 12:00', Null, 'gpt-4o-mini', 'You are a helpful AI agent.', 'PRECISE', 'LOW', 20),
+('Agent 2', 'This is the second agent', 1, '2025-02-21 12:01', 1, 'o4-mini', 'You are a helpful AI agent.', 'CREATIVE', 'LOW', 20),
+('Agent 3', 'This is the third agent', 2, '2025-02-21 12:02', 4, 'gpt-4o', 'You are a helpful AI agent.', 'PRECISE', 'LOW', 20),
+('Agent 4', 'This is the fourth agent', 2, '2025-02-21 12:03', Null, 'gpt-4o', 'You are a helpful AI agent.', 'CREATIVE', 'LOW', 20),
+('Agent 5', 'This is the fifth agent', 2, '2025-02-21 12:04', 2, 'gpt-4o', 'You are a helpful AI agent.', 'PRECISE', 'LOW', 20),
+('GPT-5 Nano', 'This is the default agent', Null, '2025-02-21 12:00', 1, 'gpt-5-nano', 'You are a helpful AI agent.', 'NEUTRAL', 'LOW', 20);
 
 insert into user_agent (user_id, agent_id, creation) values
 (1, 1, '2025-02-21 12:00'),
@@ -181,6 +181,6 @@ $$ LANGUAGE plpgsql;
 
 CREATE TRIGGER after_update_test_suite_run_status
 AFTER UPDATE OF status ON test_suite_run
-FOR EACH ROW 
+FOR EACH ROW
 WHEN (OLD.status IS DISTINCT FROM NEW.status)
 EXECUTE FUNCTION notify_test_suite_run_status();
diff --git a/src/backend/tests/test_agent_distribution.py b/src/backend/tests/test_agent_distribution.py
index 14434d6..5efe53f 100644
--- a/src/backend/tests/test_agent_distribution.py
+++ b/src/backend/tests/test_agent_distribution.py
@@ -30,8 +30,8 @@ async def test_import_exported_minimal_agent(users: List[UserListItem], client:
     resp = await _find_agent(target_agent_id, client)
     assert_response(resp, PublicAgent(
         id=target_agent_id, name=f"Agent #{source_agent_id}", description="", last_update=CURRENT_TIME, user_id=USER_ID,
-        model_id=cast(str, env.agent_default_model), system_prompt=DEFAULT_SYSTEM_PROMPT, temperature=LlmTemperature.NEUTRAL, reasoning_effort=ReasoningEffort.LOW,
-        icon=None, can_edit=True, user=users[0]))
+        model_id=cast(str, env.agent_default_model), system_prompt=DEFAULT_SYSTEM_PROMPT, temperature=LlmTemperature.NEUTRAL, reasoning_effort=ReasoningEffort.LOW,
+        recursion_limit=20, icon=None, can_edit=True, user=users[0]))
 
 
 async def _create_agent(client: AsyncClient) -> int:
@@ -76,7 +76,7 @@ async def fixture_last_file_id(session: AsyncSession) -> int:
 @freeze_time(CURRENT_TIME)
 async def test_import_exported_agent_with_all_tools_and_configs(
         users: List[UserListItem], last_prompt_id: int, last_file_id: int, last_thread_id: int, last_message_id: int, client: AsyncClient):
-    agent_update = AgentUpdate(name="Test Agent 1", description="Test description", system_prompt="Test system prompt", 
+    agent_update = AgentUpdate(name="Test Agent 1", description="Test description", system_prompt="Test system prompt",
                                model_id="o4-mini", icon=TEST_ICON, reasoning_effort= ReasoningEffort.MEDIUM)
     prompts = [
         AgentPromptCreate(name="Starter", content="Starter Text", shared=True, starter=True),
@@ -95,12 +95,12 @@
     zip_file_content = await _export_agent(source_agent_id, client)
     target_agent_id = await _create_agent(client)
     await _import_agent(target_agent_id, zip_file_content, client)
-    await _assert_imported_agent(target_agent_id, agent_update, prompts, advanced_file_processing, file_name, file_content, test_messages, 
+    await _assert_imported_agent(target_agent_id, agent_update, prompts, advanced_file_processing, file_name, file_content, test_messages,
                                  last_prompt_id + len(prompts), last_file_id + 1, last_thread_id + 1, last_message_id + len(test_messages), test_case_name, users, client)
 
 
 async def _create_agent_with_all_tools_and_configs(
-        agent_update: AgentUpdate, prompts: list[AgentPromptCreate], advanced_file_processing: bool, file_name: str, file_content: bytes, 
+        agent_update: AgentUpdate, prompts: list[AgentPromptCreate], advanced_file_processing: bool, file_name: str, file_content: bytes,
         test_messages: list[NewTestCaseMessage], client: AsyncClient) -> tuple[int, str]:
     agent_id = await _create_agent(client)
     await _update_agent(agent_id, agent_update, client)
@@ -138,27 +138,27 @@ async def _add_test(agent_id: int, client: AsyncClient) -> int:
 
 async def _add_test_message(agent_id: int, test_case_id: int, message: NewTestCaseMessage, client: AsyncClient):
-    resp = await client.post(TEST_CASE_MESSAGES_PATH.format(agent_id=agent_id, test_case_id=test_case_id), 
+    resp = await client.post(TEST_CASE_MESSAGES_PATH.format(agent_id=agent_id, test_case_id=test_case_id),
                              json=message.model_dump(mode="json", by_alias=True))
     resp.raise_for_status()
 
 
-async def _assert_imported_agent(agent_id: int, agent_update: AgentUpdate, prompts: list[AgentPromptCreate], 
-                                 advanced_file_processing: bool, file_name: str, file_content: bytes, test_messages: list[NewTestCaseMessage], 
+async def _assert_imported_agent(agent_id: int, agent_update: AgentUpdate, prompts: list[AgentPromptCreate],
+                                 advanced_file_processing: bool, file_name: str, file_content: bytes, test_messages: list[NewTestCaseMessage],
                                  last_prompt_id: int, last_file_id: int, last_thread_id: int, last_message_id: int, test_case_name: str,
                                  users: List[UserListItem], client: AsyncClient):
     resp = await _find_agent(agent_id, client)
     resp.raise_for_status()
     assert_response(resp, PublicAgent(
-        id=agent_id, name=agent_update.name, description=agent_update.description, last_update=CURRENT_TIME, 
-        user_id=USER_ID, model_id=cast(str, agent_update.model_id), system_prompt=cast(str, agent_update.system_prompt), 
-        temperature=LlmTemperature.NEUTRAL, reasoning_effort=cast(ReasoningEffort, agent_update.reasoning_effort), 
-        icon=agent_update.icon, can_edit=True, user=users[0]))
+        id=agent_id, name=agent_update.name, description=agent_update.description, last_update=CURRENT_TIME,
+        user_id=USER_ID, model_id=cast(str, agent_update.model_id), system_prompt=cast(str, agent_update.system_prompt),
+        temperature=LlmTemperature.NEUTRAL, reasoning_effort=cast(ReasoningEffort, agent_update.reasoning_effort),
+        recursion_limit=20, icon=agent_update.icon, can_edit=True, user=users[0]))
 
     agent_prompts = await _find_agent_prompts(agent_id, client)
     expected_prompts = [
         AgentPromptPublic(
-            id=last_prompt_id + 1 + i, name=p.name, content=p.content, shared=p.shared, 
-            last_update=CURRENT_TIME, user_id=USER_ID, can_edit=True, starter=p.starter)
+            id=last_prompt_id + 1 + i, name=p.name, content=p.content, shared=p.shared,
+            last_update=CURRENT_TIME, user_id=USER_ID, can_edit=True, starter=p.starter)
         for i, p in enumerate(prompts)]
     expected_prompts.sort(key=lambda x: x.name or "")
     assert_response(agent_prompts, expected_prompts)
@@ -172,7 +172,7 @@
     resp = await find_agent_tool_config_files(agent_id, DOCS_TOOL_ID, client)
     file_id = last_file_id + 1
     assert_response(resp, [FileMetadata(
-        id=file_id, name=file_name, status=FileStatus.PROCESSED, file_processor=FileProcessor.BASIC, 
+        id=file_id, name=file_name, status=FileStatus.PROCESSED, file_processor=FileProcessor.BASIC,
         content_type="text/plain; charset=ascii", user_id=USER_ID, timestamp=CURRENT_TIME)])
     resp = await _find_agent_tool_config_file_content(agent_id, DOCS_TOOL_ID, file_id, client)
     resp.raise_for_status()
@@ -181,11 +181,11 @@
     resp = await _find_agent_tests(agent_id, client)
     test_thread_id = last_thread_id + 1
     assert_response(resp, [PublicTestCase(
-        agent_id=agent_id, thread=Thread(id=test_thread_id, name=test_case_name, user_id=USER_ID, agent_id=agent_id, creation=CURRENT_TIME, is_test_case=True), 
+        agent_id=agent_id, thread=Thread(id=test_thread_id, name=test_case_name, user_id=USER_ID, agent_id=agent_id, creation=CURRENT_TIME, is_test_case=True),
         last_update=CURRENT_TIME)])
     resp = await _find_test_case_messages(agent_id, test_thread_id, client)
     assert_response(resp, [ThreadMessagePublic(
-        id=last_message_id + 1 + i, thread_id=test_thread_id, origin=message.origin, text=message.text, files=[], timestamp=CURRENT_TIME) 
+        id=last_message_id + 1 + i, thread_id=test_thread_id, origin=message.origin, text=message.text, files=[], timestamp=CURRENT_TIME)
         for i, message in enumerate(test_messages)])
diff --git a/src/backend/tests/test_agents.py b/src/backend/tests/test_agents.py
index 7dbc60b..bd01a97 100644
--- a/src/backend/tests/test_agents.py
+++ b/src/backend/tests/test_agents.py
@@ -145,8 +145,8 @@ async def test_update_agent(client: AsyncClient, users: List[UserListItem], team
     assert_response(
         resp,
         PublicAgent(id=AGENT_ID, name=update.name, description=update.description, last_update=CURRENT_TIME, team=teams[0], user_id=USER_ID,
-                    model_id=cast(str, update.model_id), system_prompt=cast(str, update.system_prompt), temperature=cast(LlmTemperature, update.temperature), 
-                    reasoning_effort=cast(ReasoningEffort, update.reasoning_effort), icon=update.icon, can_edit=True, user=users[0]))
+                    model_id=cast(str, update.model_id), system_prompt=cast(str, update.system_prompt), temperature=cast(LlmTemperature, update.temperature),
+                    reasoning_effort=cast(ReasoningEffort, update.reasoning_effort), recursion_limit=20, icon=update.icon, can_edit=True, user=users[0]))
 
 
 async def _update_agent(agent_id: int, update: AgentUpdate, client: AsyncClient) -> Response:
@@ -354,30 +354,30 @@ async def test_delete_non_existent_file(client: AsyncClient):
 
 async def _reprocess_agent_tool_file(client: AsyncClient, file_processor: FileProcessor):
     await _configure_docs_tool(client, FileProcessor.BASIC if file_processor == FileProcessor.ENHANCED else FileProcessor.ENHANCED)
-    
+
     file_content = await find_asset_bytes("Emma's routine.pdf")
-    
+
     filename = "Emma's routine.pdf"
     file_id = await upload_agent_tool_config_file(AGENT_ID, DOCS_TOOL_ID, client, filename, file_content)
     await _await_docs_tool_file_processed(file_id, client)
-    
+
     await _configure_docs_tool(client, file_processor)
-    
+
     resp = await _update_agent_tool_config_file(AGENT_ID, DOCS_TOOL_ID, file_id, client, filename, file_content)
     resp.raise_for_status()
-    
+
     await _await_docs_tool_file_processed(file_id, client)
-    
+
     resp = await client.get(AGENT_TOOL_FILE_PATH.format(agent_id=AGENT_ID, tool_id=DOCS_TOOL_ID, file_id=file_id))
     resp.raise_for_status()
     toolFile = resp.json()
-    
+
     assert toolFile["fileProcessor"] == file_processor.value
     content = toolFile["processedContent"]
-    
+
     expected_content = await find_asset_text(f"pdf_{file_processor.value.lower()}_content.txt")
     assert content.strip() == expected_content.strip()
-    
+
     return file_id
@@ -437,6 +437,7 @@ async def test_clone_agent(users: dict[int, UserListItem], last_agent_id: int,cl
         system_prompt="You are a helpful AI agent.",
         temperature=LlmTemperature.CREATIVE,
         reasoning_effort=ReasoningEffort.LOW,
+        recursion_limit=20,
         user=users[0]
     ))
@@ -445,7 +446,7 @@ async def _clone_agent(agent_id: int, client: AsyncClient) -> int:
     resp = await client.post(f"{AGENT_PATH.format(agent_id=agent_id)}/clone")
     resp.raise_for_status()
     return resp.json()["id"]
-    
+
 
 
 @pytest.fixture(name="last_prompt_id")
@@ -458,9 +459,9 @@ async def test_clone_agent_prompts(last_prompt_id: int,client: AsyncClient):
     cloned_agent_id = await _clone_agent(AGENT_ID, client)
     resp = await _find_agent_prompts(cloned_agent_id, client)
     assert_response(resp, [
-        AgentPromptPublic(id=last_prompt_id + 1, name="Test prompt private 1", content="Test prompt content", shared=False, 
+        AgentPromptPublic(id=last_prompt_id + 1, name="Test prompt private 1", content="Test prompt content", shared=False,
                           last_update=CURRENT_TIME, user_id=USER_ID, can_edit=True, starter=False),
-        AgentPromptPublic(id=last_prompt_id + 2, name="Test prompt shared", content="Test shared prompt content", shared=True, 
+        AgentPromptPublic(id=last_prompt_id + 2, name="Test prompt shared", content="Test shared prompt content", shared=True,
                           last_update=CURRENT_TIME, user_id=USER_ID, can_edit=True, starter=False)])
@@ -471,4 +472,4 @@ async def _find_agent_prompts(agent_id: int, client: AsyncClient) -> Response:
 async def test_find_default_agent(client: AsyncClient, teams: List[Team]):
     resp = await client.get(AGENTS_PATH + "/default")
     assert_response(resp, PublicAgent(id=6, name="GPT-5 Nano", description="This is the default agent", last_update=PAST_TIME, team=teams[0], user_id=None,
-                                      model_id="gpt-5-nano", system_prompt="You are a helpful AI agent.", temperature=LlmTemperature.NEUTRAL, reasoning_effort=ReasoningEffort.LOW, icon=None, can_edit=True, user=None))
+                                      model_id="gpt-5-nano", system_prompt="You are a helpful AI agent.", temperature=LlmTemperature.NEUTRAL, reasoning_effort=ReasoningEffort.LOW, recursion_limit=20, icon=None, can_edit=True, user=None))
diff --git a/src/browser-extension/components/CopilotChat.vue b/src/browser-extension/components/CopilotChat.vue
index 2225acb..a9846df 100644
--- a/src/browser-extension/components/CopilotChat.vue
+++ b/src/browser-extension/components/CopilotChat.vue
@@ -7,13 +7,14 @@
 import { Agent } from '~/utils/agent'
 import ChatInput from '../../common/src/components/chat/ChatInput.vue'
 import AgentChatMenu from '../../common/src/components/common/AgentChatMenu.vue'
 import { AgentPrompt } from '../../common/src/utils/domain'
+import { useToolAuthModal } from '@tero/common/utils/useToolAuthModal.js'
 
 const props = defineProps<{
     messages: ChatMessage[],
     minimized?: boolean,
     agent: Agent,
    audioTranscriber: (blob: Blob) => Promise<string>,
-    errorHandler: (error: unknown) => void
+    errorHandler: (error: any) => void
 }>()
 const emit = defineEmits<{
     (e: 'close'): void,
@@ -30,9 +31,20 @@
 const inputText = ref('');
 const chatInputRef = ref<InstanceType<typeof ChatInput>>()
 
 const lastMessage = computed((): ChatMessage => props.messages[props.messages.length - 1])
+const allPrompts = ref([])
+const starterPrompts = computed(() => allPrompts.value.filter(p => p.starter))
+
+const findPrompts = async () => {
+    if (!allPrompts.value.length) {
+        allPrompts.value = await props.agent.getPrompts()
+    }
+    return allPrompts.value
+}
+
+const { showToolAuthModal, toolAuthType, toolId, submitAuth, closeAuth } = useToolAuthModal()
 
 onMounted(async () => {
-    await chatInputRef.value?.focus();
+    await chatInputRef.value?.focus()
 });
 
 watch(props.messages, async () => {
@@ -72,6 +84,7 @@ const adjustMessagesScroll = async () => {
         }]" />
+
         { ref="chatInputRef"
           :chat="{
             supportsStopResponse: () => agent.supportsStopResponse(),
-            findPrompts: async () => await agent.getPrompts(),
+            findPrompts: async () => await findPrompts(),
             savePrompt: async (p: AgentPrompt) => await agent.savePrompt(p),
             deletePrompt: async (id: number) => await agent.deletePrompt(id),
             supportsFileUpload: () => false,
             supportsTranscriptions: () => agent.supportsTranscriptions(),
             transcribe: async (blob: Blob) => audioTranscriber(blob),
             handleError: errorHandler
           }"
-          :is-answering="!lastMessage || !lastMessage.isComplete"
+          :is-answering="!!(lastMessage && !lastMessage.isComplete)"
           :enable-prompts="true"
           :shareable-prompts="false"
           @send="emit('userMessage', inputText)"
@@ -97,7 +110,13 @@ const adjustMessagesScroll = async () => {
diff --git a/src/browser-extension/components/ToolAuthModal.vue b/src/browser-extension/components/ToolAuthModal.vue
new file mode 100644
index 0000000..8abca1d
--- /dev/null
+++ b/src/browser-extension/components/ToolAuthModal.vue
@@ -0,0 +1,50 @@
+
+
+
+
+{
+  "en": {
+    "authenticationRequired": "Authentication required"
+  },
+  "es": {
+    "authenticationRequired": "Autenticación requerida"
+  }
+}
+
diff --git a/src/browser-extension/entrypoints/iframe/App.vue b/src/browser-extension/entrypoints/iframe/App.vue
index 8e6ea83..2b186d7 100644
--- a/src/browser-extension/entrypoints/iframe/App.vue
+++ b/src/browser-extension/entrypoints/iframe/App.vue
@@ -198,7 +198,6 @@ const onAgentActivation = async (msg: AgentActivation) => {
         toast.error({ component: ToastMessage, props: { message: text } }, { icon: IconAlertCircleFilled })
     } else {
         agent.value = Agent.fromJsonObject(msg.agent)
-        messages.value.push(ChatMessage.agentMessage(agent.value.manifest.welcomeMessage))
     }
 }
 
@@ -331,13 +330,13 @@ const sidebarClasses = computed(() => [
 "en": {
     "activationError": "Could not activate {agentName}. You can try again and if the issue persists then contact [{agentName} support](mailto:{contactEmail}?subject=Activation%20issue)",
     "interactionSummaryError": "I could not process some information from the current site. This might impact the information and answers I provide. If the issue persists please contact [support](mailto:{contactEmail}?subject=Interaction%20issue)",
-    "agentAnswerError": "I am currently unable to complete your request. You can try again and if the issue persists contact [support](mailto:{contactEmail}?subject=Question%20issue)",
+    "agentAnswerError": "I can't help with that message. Edit it or send a new one. If the problem continues, [contact the support team](mailto:{contactEmail}?subject=Question%20issue)",
     "flowStepMissingElement": "I could not find the element '{selector}'. This might be due to recent changes in the page which I am not aware of. Please try again and if the issue persists contact [support](mailto:{contactEmail}?subject=Navigation%20element).",
 },
 "es": {
     "activationError": "No se pudo activar {agentName}. Puedes intentar de nuevo y si el problema persiste contactar al [soporte de {agentName}](mailto:{contactEmail}?subject=Activation%20issue)",
     "interactionSummaryError": "No pude procesar informacion generada por la página actual. Esto puede impactar en la información y respuestas que te puedo dar. Si el problema persiste por favor contacta a [soporte](mailto:{contactEmail})?subject=Interaction%20issue",
-    "agentAnswerError": "Ahora no puedo completar tu pedido. Puedes intentar de nuevo y si el problema persiste contactar a [soporte](mailto:{contactEmail}?subject=Question%20issue)",
+    "agentAnswerError": "No puedo ayudarte con ese mensaje. Probá editarlo o enviar uno nuevo. Si el problema continúa, podés [contactar al equipo de soporte](mailto:{contactEmail}?subject=Question%20issue)",
     "flowStepMissingElement": "No pude encontrar el elemento '{selector}'. Esto puede ser debido a cambios recientes en la página de los cuales no tengo conocimiento. Por favor intenta de nuevo y si el problema persiste contacta a [soporte](mailto:{contactEmail}?subject=Navigation%20element).",
 }
 }
diff --git a/src/browser-extension/utils/agent.ts b/src/browser-extension/utils/agent.ts
index a31b815..8dbe52e 100644
--- a/src/browser-extension/utils/agent.ts
+++ b/src/browser-extension/utils/agent.ts
@@ -5,7 +5,7 @@ import { fetchJson, fetchStreamJson, fetchResponse, streamFromResponse, ServerSe
 import { AgentFlow } from "./flow"
 import { addAgent, findAgentById, ExistingAgentError } from "./agent-repository"
 import { AgentPrompt } from "../../common/src/utils/domain"
-import { handleOAuthRequestsIn } from "./tool-oauth"
+import { handleToolAuthRequestsIn } from "./tool-oauth"
 
 export abstract class AgentSource {
     abstract findAgents(authService?: AuthService): Promise;
@@ -354,7 +354,7 @@
             formData.append("parentMessageId", parentMessageId.toString());
         }
 
-        yield* await handleOAuthRequestsIn(
+        yield* await handleToolAuthRequestsIn(
             async () => {
                 const response = await fetchResponse(url, await Agent.buildHttpRequest("POST", formData, authService));
                 const stream = await streamFromResponse(response, url);
diff --git a/src/browser-extension/utils/tool-oauth.ts b/src/browser-extension/utils/tool-oauth.ts
index 8b1740b..abfe85b 100644
--- a/src/browser-extension/utils/tool-oauth.ts
+++ b/src/browser-extension/utils/tool-oauth.ts
@@ -3,161 +3,135 @@ import { HttpServiceError } from "./http";
 import { fetchJson } from "./http";
 import { Agent } from "./agent";
 import { AuthService } from "./auth";
-
-export const handleOAuthRequestsIn = async <T>(
-  fn: () => Promise<T>,
-  agent: Agent,
-  authService?: AuthService
-): Promise<T> => {
-  while (true) {
-    try {
-      return await fn();
-    } catch (e) {
-      const oauthRequest = parseOAuthRequest(e);
-      if (oauthRequest) {
-        await oauthExtensionFlow(oauthRequest, agent, authService);
-      } else {
-        throw e;
-      }
-    }
-  }
-};
+import {
+  createAuthFlowHandler,
+  OAuthRequest,
+  AuthenticationError,
+  toolAuthManager,
+  type HttpErrorLike
+} from "@tero/common/utils/toolAuth.js";
+
+const isHttpServiceError = (e: any): e is HttpErrorLike => {
+  return e instanceof HttpServiceError;
+};
 
-const parseOAuthRequest = (e: unknown): { url: string; state: string; agentId: number } | undefined => {
-  if (!(e instanceof HttpServiceError && e.status === 401 && e.body)) {
-    return undefined;
-  }
+const createOAuthExtensionFlow = (agent: Agent, authService?: AuthService) => async (
+  oauthRequest: OAuthRequest
+): Promise<void> => {
   try {
-    const body = JSON.parse(e.body);
-    if (!body.detail) {
-      return undefined;
+    const redirectUrl = await openOAuthPopup(oauthRequest);
+    await processOAuthCallback(redirectUrl, oauthRequest, agent, authService);
+  } catch (error: any) {
+    if (error instanceof HttpServiceError && error.status === 400 && error.body) {
+      let detail: string | undefined;
+      try {
+        detail = JSON.parse(error.body).detail;
+      } catch (e) { /* body was not JSON; rethrow the original error below */ }
+      if (detail === "Authentication cancelled") {
+        throw new AuthenticationError("authenticationCancelled");
+      }
     }
-    return new OAuthRequest(body.detail?.oauthUrl, body.detail?.oauthState, body.detail?.agentId);
-  } catch (_) {
-    return undefined;
+    throw error;
+  } finally {
+    toolAuthManager.hideOAuthRequired();
   }
 };
 
-class OAuthRequest {
-  url: string;
-  state: string;
-  agentId: number;
-
-  constructor(url: string, state: string, agentId: number) {
-    this.url = url;
-    this.state = state;
-    this.agentId = agentId;
-  }
-}
-
-export class AuthenticationError extends Error {
-  errorCode: string;
-
-  constructor(errorCode: string) {
-    super("Authentication error: " + errorCode);
-    this.errorCode = errorCode;
-  }
-}
-
-const oauthExtensionFlow = async (
-  oauthRequest: OAuthRequest,
-  agent: Agent,
-  authService?: AuthService
-): Promise<void> => {
-  try {
-    const redirectUrl = await new Promise<string>((resolve, reject) => {
-      browser.windows.create({
-        url: oauthRequest.url,
-        type: 'popup',
-        width: 600,
-        height: 600,
-      }).then((window) => {
-        if (!window?.tabs?.[0]) {
-          reject(new Error('Failed to create popup'));
-          return;
-        }
-
-        const windowId = window.id!;
-        const tabId = window.tabs[0].id!;
-        let resolved = false;
-
-        const cleanup = () => {
-          if (resolved) return;
-          resolved = true;
-          browser.tabs.onUpdated.removeListener(tabUpdateListener);
-          browser.windows.onRemoved.removeListener(windowRemovedListener);
-          browser.windows.remove(windowId).catch(() => {});
-        };
-
-        const isCallbackUrl = (url: string): boolean => {
-          try {
-            const parsedUrl = new URL(url);
-            return parsedUrl.pathname.match(/\/tools\/[^\/]+\/oauth-callback/) !== null &&
-              (parsedUrl.searchParams.has('code') || parsedUrl.searchParams.has('error'));
-          } catch (_) {
-            return false;
-          }
-        };
-
-        const tabUpdateListener = (updatedTabId: number, changeInfo: { url?: string }) => {
-          if (resolved || updatedTabId !== tabId) return;
-          if (changeInfo.url && isCallbackUrl(changeInfo.url)) {
-            cleanup();
-            resolve(changeInfo.url);
-          }
-        };
-
-        const windowRemovedListener = (removedWindowId: number) => {
-          if (removedWindowId === windowId && !resolved) {
-            cleanup();
-            reject(new AuthenticationError("authenticationCancelled"));
-          }
-        };
-
-        browser.tabs.onUpdated.addListener(tabUpdateListener);
-        browser.windows.onRemoved.addListener(windowRemovedListener);
-      }).catch(reject);
-    });
-
-    const url = new URL(redirectUrl);
-    const code = url.searchParams.get("code");
-    const state = url.searchParams.get("state");
-    const error = url.searchParams.get("error");
-    const toolId = url.pathname.match(/\/tools\/([^\/]+)\/oauth-callback/)?.[1];
-
-    if (state !== oauthRequest.state) {
-      throw new AuthenticationError("authenticationStateMismatch");
-    }
-
-    if (error) {
-      if (toolId) {
-        await deleteToolAuth(agent, toolId, state!, authService);
-      }
-      throw new AuthenticationError(error == "access_denied" ? "authenticationAccessDenied" : "authenticationUnknownError");
-    }
-
-    if (!code || !toolId) {
-      throw new AuthenticationError("authenticationCancelled");
-    }
-
-    await completeToolAuth(agent, toolId, oauthRequest.agentId, state!, code, authService);
-  } catch (error: any) {
-    if (error?.message?.includes("The user canceled the sign-in flow")) {
-      throw new AuthenticationError("authenticationCancelled");
-    }
-    if (error instanceof HttpServiceError && error.status === 400) {
-      try {
-        const body = JSON.parse(error.body || "{}");
-        if (body.detail && body.detail === "Authentication cancelled") {
-          throw new AuthenticationError("authenticationCancelled");
-        }
-      } catch (_) {}
-    }
-    throw error;
-  }
-};
+const openOAuthPopup = (oauthRequest: OAuthRequest): Promise<string> => {
+  return new Promise((resolve, reject) => {
+    let windowId: number | undefined;
+
+    const handleCancel = () => {
+      if (windowId) {
+        browser.windows.remove(windowId).catch(() => {});
+      }
+      reject(new AuthenticationError("authenticationCancelled"));
+    };
+
+    toolAuthManager.showOAuthRequired(oauthRequest.toolId, handleCancel);
+
+    browser.windows.create({
+      url: oauthRequest.url,
+      type: 'popup',
+      width: 600,
+      height: 600,
+    }).then((window) => {
+      if (!window?.tabs?.[0]) {
+        reject(new Error('Failed to create popup'));
+        return;
+      }
+
+      windowId = window.id!;
+      const tabId = window.tabs[0].id!;
+
+      const cleanup = () => {
+        browser.tabs.onUpdated.removeListener(tabUpdateListener);
+        browser.windows.onRemoved.removeListener(windowRemovedListener);
+        browser.windows.remove(windowId!).catch(() => {});
+      };
+
+      const isCallbackUrl = (url: string): boolean => {
+        try {
+          const parsedUrl = new URL(url);
+          return parsedUrl.pathname.match(/\/tools\/[^\/]+\/oauth-callback/) !== null &&
+            (parsedUrl.searchParams.has('code') || parsedUrl.searchParams.has('error'));
+        } catch (_) {
+          return false;
+        }
+      };
+
+      const tabUpdateListener = (updatedTabId: number, changeInfo: { url?: string }) => {
+        if (updatedTabId !== tabId) return;
+        if (changeInfo.url && isCallbackUrl(changeInfo.url)) {
+          cleanup();
+          resolve(changeInfo.url);
+        }
+      };
+
+      const windowRemovedListener = (removedWindowId: number) => {
+        if (removedWindowId === windowId) {
+          cleanup();
+          reject(new AuthenticationError("authenticationCancelled"));
+        }
+      };
+
+      browser.tabs.onUpdated.addListener(tabUpdateListener);
+      browser.windows.onRemoved.addListener(windowRemovedListener);
+    }).catch(reject);
+  });
+};
+
+const processOAuthCallback = async (
+  redirectUrl: string,
+  oauthRequest: OAuthRequest,
+  agent: Agent,
+  authService?: AuthService
+): Promise<void> => {
+  const url = new URL(redirectUrl);
+  const code = url.searchParams.get("code");
+  const state = url.searchParams.get("state");
+  const error = url.searchParams.get("error");
+  const toolId = url.pathname.match(/\/tools\/([^\/]+)\/oauth-callback/)?.[1];
+
+  if (state !== oauthRequest.state) {
+    throw new AuthenticationError("authenticationStateMismatch");
+  }
+
+  if (error) {
+    if (toolId) {
+      await deleteToolAuth(agent, toolId, state!, authService);
+    }
+    throw new AuthenticationError(error == "access_denied" ? "authenticationAccessDenied" : "authenticationUnknownError");
+  }
+
+  if (!code || !toolId) {
+    throw new AuthenticationError("authenticationCancelled");
+  }
+
+  await completeToolOAuth(agent, toolId, oauthRequest.agentId, state!, code, authService);
+};
 
-const completeToolAuth = async (
+const completeToolOAuth = async (
   agent: Agent,
   toolId: string,
   agentId: number,
@@ -171,6 +145,19 @@ const completeToolAuth = async (
   );
 };
 
+const completeToolAuthToken = async (
+  agent: Agent,
+  toolId: string,
+  agentId: number,
+  authToken: string,
+  authService?: AuthService
+): Promise<void> => {
+  await fetchJson(
+    `${agent.url}/api/tools/${toolId}/agents/${agentId}/auth`,
+    await Agent.buildHttpRequest("PUT", { auth_token: authToken }, authService)
+  );
+};
+
 const deleteToolAuth = async (
   agent: Agent,
   toolId: string,
@@ -182,3 +169,16 @@ const deleteToolAuth = async (
     await Agent.buildHttpRequest("DELETE", undefined, authService)
   );
 };
+
+export const handleToolAuthRequestsIn = async <T>(
+  fn: () => Promise<T>,
+  agent: Agent,
+  authService?: AuthService
+): Promise<T> => {
+  const handler = createAuthFlowHandler({
+    isHttpError: isHttpServiceError,
+    handleOAuthFlow: createOAuthExtensionFlow(agent, authService),
+    completeAuthToken: (toolId, agentId, token) => completeToolAuthToken(agent, toolId, agentId, token, authService),
+  });
+  return handler(fn);
+};
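Note: @tero/common/utils/toolAuth.js is not part of this diff, so the shape of createAuthFlowHandler is not visible here. A minimal sketch, assuming it reproduces the retry loop that the removed handleOAuthRequestsIn implemented; the completeAuthToken option is accepted but whatever server response triggers it is not shown in this diff, so that branch is omitted:

interface HttpErrorLike { status: number; body?: string }

interface OAuthRequest { url: string; state: string; agentId: number; toolId: string }

interface AuthFlowOptions {
  isHttpError: (e: unknown) => e is HttpErrorLike
  handleOAuthFlow: (request: OAuthRequest) => Promise<void>
  completeAuthToken: (toolId: string, agentId: number, token: string) => Promise<void>
}

// Assumed payload parsing, mirroring the removed parseOAuthRequest above.
const parseOAuthRequest = (
  isHttpError: AuthFlowOptions["isHttpError"],
  e: unknown
): OAuthRequest | undefined => {
  if (!isHttpError(e) || e.status !== 401 || !e.body) return undefined
  try {
    const detail = JSON.parse(e.body).detail
    return detail?.oauthUrl
      ? { url: detail.oauthUrl, state: detail.oauthState, agentId: detail.agentId, toolId: detail.toolId }
      : undefined
  } catch {
    return undefined
  }
}

export const createAuthFlowHandler = (options: AuthFlowOptions) =>
  async <T>(fn: () => Promise<T>): Promise<T> => {
    // Retry the wrapped call: each 401 carrying an OAuth request runs the
    // interactive flow, any other failure propagates unchanged.
    while (true) {
      try {
        return await fn()
      } catch (e) {
        const request = parseOAuthRequest(options.isHttpError, e)
        if (!request) throw e
        await options.handleOAuthFlow(request)
      }
    }
  }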
diff --git a/src/common/src/assets/images/csv-icon.svg b/src/common/src/assets/images/csv-icon.svg
index aa8a9a1..3c6df49 100644
--- a/src/common/src/assets/images/csv-icon.svg
+++ b/src/common/src/assets/images/csv-icon.svg
@@ -7,7 +7,7 @@
-[SVG markup stripped during extraction]
+[SVG markup stripped during extraction]
diff --git a/src/common/src/assets/images/jira-icon.svg b/src/common/src/assets/images/jira-icon.svg
new file mode 100644
index 0000000..63917a8
--- /dev/null
+++ b/src/common/src/assets/images/jira-icon.svg
@@ -0,0 +1,30 @@
+[SVG markup stripped during extraction: 30 added lines]
diff --git a/src/common/src/assets/images/mcp-icon.svg b/src/common/src/assets/images/mcp-icon.svg
new file mode 100644
index 0000000..6bc5a93
--- /dev/null
+++ b/src/common/src/assets/images/mcp-icon.svg
@@ -0,0 +1,23 @@
+[SVG markup stripped during extraction: 23 added lines]
diff --git a/src/common/src/components/chat/ChatInput.vue b/src/common/src/components/chat/ChatInput.vue
index e363f89..48a9b1a 100644
--- a/src/common/src/components/chat/ChatInput.vue
+++ b/src/common/src/components/chat/ChatInput.vue
@@ -64,7 +64,6 @@ const MAX_FILES = 5
 const fileInputRef = ref<… | null>(null)
 const attachedFiles = ref(props.initialFiles || [])
 const attachedFilesError = ref(undefined)
-const allowedExtensions = ['pdf', 'txt', 'md', 'csv', 'xlsx', 'xls', 'png', 'jpg', 'jpeg', 'har', 'json', 'svg']
 const transcriptionRef = ref<… | null>(null)
 const isRecordingAudio = ref(false);
@@ -98,10 +97,6 @@ const loadAgentPrompts = async () => {
   }
 }
 
-watch(props.chat.findPrompts, async () => {
-  await loadAgentPrompts()
-})
-
 watch(inputText, async() => {
   if (inputText.value.startsWith('/') && props.enablePrompts !== false) {
     if (!isShowingPrompts.value) {
@@ -151,6 +146,24 @@ const onKeydown = async (e: KeyboardEvent) => {
   }
 }
 
+const onPaste = (event: ClipboardEvent) => {
+  if (!props.chat.supportsFileUpload()) return
+
+  const items = event.clipboardData?.items
+  if (!items?.length) return
+
+  const files = Array.from(items)
+    .filter((item) => item.kind === 'file')
+    .map((item) => item.getAsFile())
+    .filter((file): file is File => file !== null)
+
+  if (!files.length) return
+
+  event.preventDefault()
+  resetAttachedFilesError()
+  ;(fileInputRef.value as { addFiles?: (files: File[]) => void } | null)?.addFiles?.(files)
+}
+
 const sendMessage = async () => {
   if (props.isAnswering || (inputText.value.trim() === '' && attachedFiles.value.length === 0)) {
     return
@@ -307,7 +320,7 @@ const handleAudioReady = async (blob: Blob) => {
     isCancellingTranscription.value = false
     return
   }
-  
+
   try {
     const text = await props.chat.transcribe(blob)
     isWaitingTranscript.value = false
@@ -330,7 +343,7 @@
       trimmed +
      inputText.value.slice(end)
-    
+
     await nextTick()
     await focusInputAt(start + trimmed.length)
   } catch (error) {
@@ -364,7 +377,8 @@ const openFileBrowser = () => {
 defineExpose({
   focus,
   createPromptFromMessage,
-  selectPrompt
+  selectPrompt,
+  reloadPrompts: loadAgentPrompts
 })
@@ -376,7 +390,7 @@ defineExpose({
     class="flex flex-col gap-1 relative rounded-xl bg-surface"
     :class="!borderless ? 'border focus-within:border-abstracta shadow-sm p-2' : ''">
-[three template lines stripped during extraction]
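Note: onPaste hands the pasted files to an addFiles method on the file input component, which is outside this diff. A hypothetical sketch of the contract it relies on; the class name and the drop-beyond-limit rule are assumptions, while MAX_FILES = 5 comes from ChatInput.vue above:

// Hypothetical model of the component behind fileInputRef.
class FileInputModel {
  private static readonly MAX_FILES = 5 // mirrors ChatInput.vue

  files: File[] = []

  // Append pasted or browsed files, ignoring anything past the limit.
  addFiles(incoming: File[]): void {
    for (const file of incoming) {
      if (this.files.length >= FileInputModel.MAX_FILES) break
      this.files.push(file)
    }
  }
}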