From d0a79bed4311a21740f1e034833d730a4719bf1e Mon Sep 17 00:00:00 2001 From: Haze Date: Fri, 16 Jan 2026 19:11:27 -0600 Subject: [PATCH 1/4] init --- backend/chainlit/data/models.py | 160 ++++ backend/chainlit/data/sql_alchemy.py | 1139 +++++++++++------------- backend/tests/data/test_sql_alchemy.py | 109 +-- 3 files changed, 693 insertions(+), 715 deletions(-) create mode 100644 backend/chainlit/data/models.py diff --git a/backend/chainlit/data/models.py b/backend/chainlit/data/models.py new file mode 100644 index 0000000000..7392bc3820 --- /dev/null +++ b/backend/chainlit/data/models.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +from typing import Any, Optional + +from sqlalchemy import Boolean, ForeignKey, Integer, String, Text +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship +from sqlalchemy.types import JSON, TypeDecorator + + +class CrossDialectJSON(TypeDecorator): + """JSON type that uses JSONB on PostgreSQL and JSON on other databases.""" + + impl = JSON + cache_ok = True + + def load_dialect_impl(self, dialect): + if dialect.name == "postgresql": + return dialect.type_descriptor(JSONB()) + return dialect.type_descriptor(JSON()) + + +class Base(DeclarativeBase): + """Shared base for all ORM models. Required so Base.metadata.create_all() discovers all tables.""" + + pass + + +class UserModel(Base): + __tablename__ = "users" + + id: Mapped[str] = mapped_column(String(36), primary_key=True) + identifier: Mapped[str] = mapped_column(String(255), unique=True, nullable=False) + metadata_: Mapped[Optional[dict[str, Any]]] = mapped_column( + "metadata", CrossDialectJSON, nullable=True, default=dict + ) + createdAt: Mapped[Optional[str]] = mapped_column(String(50), nullable=True) + + threads: Mapped[list[ThreadModel]] = relationship( + back_populates="user", cascade="all, delete-orphan", passive_deletes=True + ) + + +class ThreadModel(Base): + __tablename__ = "threads" + + id: Mapped[str] = mapped_column(String(36), primary_key=True) + createdAt: Mapped[Optional[str]] = mapped_column(String(50), nullable=True) + name: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + userId: Mapped[Optional[str]] = mapped_column( + String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=True + ) + userIdentifier: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + tags: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + metadata_: Mapped[Optional[dict[str, Any]]] = mapped_column( + "metadata", CrossDialectJSON, nullable=True + ) + + user: Mapped[Optional[UserModel]] = relationship(back_populates="threads") + steps: Mapped[list[StepModel]] = relationship( + back_populates="thread", cascade="all, delete-orphan", passive_deletes=True + ) + elements: Mapped[list[ElementModel]] = relationship( + back_populates="thread", + cascade="all, delete-orphan", + passive_deletes=True, + foreign_keys="ElementModel.threadId", + ) + + +class StepModel(Base): + __tablename__ = "steps" + + id: Mapped[str] = mapped_column(String(36), primary_key=True) + name: Mapped[str] = mapped_column(String(255), nullable=False) + type: Mapped[str] = mapped_column(String(50), nullable=False) + threadId: Mapped[str] = mapped_column( + String(36), ForeignKey("threads.id", ondelete="CASCADE"), nullable=False + ) + parentId: Mapped[Optional[str]] = mapped_column(String(36), nullable=True) + disableFeedback: Mapped[Optional[bool]] = mapped_column( + Boolean, nullable=True, default=False + ) + streaming: Mapped[Optional[bool]] = mapped_column( + Boolean, nullable=True, default=False + ) + waitForAnswer: Mapped[Optional[bool]] = mapped_column(Boolean, nullable=True) + isError: Mapped[Optional[bool]] = mapped_column(Boolean, nullable=True) + metadata_: Mapped[Optional[dict[str, Any]]] = mapped_column( + "metadata", CrossDialectJSON, nullable=True + ) + tags: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + input: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + output: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + createdAt: Mapped[Optional[str]] = mapped_column(String(50), nullable=True) + start: Mapped[Optional[str]] = mapped_column(String(50), nullable=True) + end: Mapped[Optional[str]] = mapped_column(String(50), nullable=True) + generation: Mapped[Optional[dict[str, Any]]] = mapped_column( + CrossDialectJSON, nullable=True + ) + showInput: Mapped[Optional[str]] = mapped_column(String(20), nullable=True) + language: Mapped[Optional[str]] = mapped_column(String(50), nullable=True) + indent: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) + + thread: Mapped[ThreadModel] = relationship(back_populates="steps") + feedbacks: Mapped[list[FeedbackModel]] = relationship( + back_populates="step", cascade="all, delete-orphan", passive_deletes=True + ) + elements: Mapped[list[ElementModel]] = relationship( + back_populates="step", + cascade="all, delete-orphan", + passive_deletes=True, + foreign_keys="ElementModel.forId", + ) + + +class ElementModel(Base): + __tablename__ = "elements" + + id: Mapped[str] = mapped_column(String(36), primary_key=True) + threadId: Mapped[Optional[str]] = mapped_column( + String(36), ForeignKey("threads.id", ondelete="CASCADE"), nullable=True + ) + type: Mapped[Optional[str]] = mapped_column(String(50), nullable=True) + url: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + chainlitKey: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + name: Mapped[str] = mapped_column(String(255), nullable=False) + display: Mapped[Optional[str]] = mapped_column(String(50), nullable=True) + objectKey: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + size: Mapped[Optional[str]] = mapped_column(String(50), nullable=True) + page: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) + language: Mapped[Optional[str]] = mapped_column(String(50), nullable=True) + forId: Mapped[Optional[str]] = mapped_column( + String(36), ForeignKey("steps.id", ondelete="CASCADE"), nullable=True + ) + mime: Mapped[Optional[str]] = mapped_column(String(100), nullable=True) + props: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + autoPlay: Mapped[Optional[bool]] = mapped_column(Boolean, nullable=True) + playerConfig: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + + thread: Mapped[Optional[ThreadModel]] = relationship( + back_populates="elements", foreign_keys=[threadId] + ) + step: Mapped[Optional[StepModel]] = relationship( + back_populates="elements", foreign_keys=[forId] + ) + + +class FeedbackModel(Base): + __tablename__ = "feedbacks" + + id: Mapped[str] = mapped_column(String(36), primary_key=True) + forId: Mapped[str] = mapped_column( + String(36), ForeignKey("steps.id", ondelete="CASCADE"), nullable=False + ) + threadId: Mapped[str] = mapped_column(String(36), nullable=False) + value: Mapped[int] = mapped_column(Integer, nullable=False) + comment: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + + step: Mapped[StepModel] = relationship(back_populates="feedbacks") diff --git a/backend/chainlit/data/sql_alchemy.py b/backend/chainlit/data/sql_alchemy.py index 8d69c1d9fb..b9ae0ad46e 100644 --- a/backend/chainlit/data/sql_alchemy.py +++ b/backend/chainlit/data/sql_alchemy.py @@ -1,18 +1,31 @@ +from __future__ import annotations + import json import ssl import uuid -from dataclasses import asdict +from contextlib import asynccontextmanager from datetime import datetime -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union import aiofiles import aiohttp -from sqlalchemy import text +from sqlalchemy import delete, func, or_, select +from sqlalchemy.dialects.mysql import insert as mysql_insert +from sqlalchemy.dialects.postgresql import insert as pg_insert +from sqlalchemy.dialects.sqlite import insert as sqlite_insert from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import selectinload, sessionmaker from chainlit.data.base import BaseDataLayer +from chainlit.data.models import ( + Base, + ElementModel, + FeedbackModel, + StepModel, + ThreadModel, + UserModel, +) from chainlit.data.storage_clients.base import BaseStorageClient from chainlit.data.utils import queue_until_user_message from chainlit.element import ElementDict @@ -30,8 +43,7 @@ from chainlit.user import PersistedUser, User if TYPE_CHECKING: - from chainlit.element import Element, ElementDict - from chainlit.step import StepDict + from chainlit.element import Element class SQLAlchemyDataLayer(BaseDataLayer): @@ -43,6 +55,7 @@ def __init__( storage_provider: Optional[BaseStorageClient] = None, user_thread_limit: Optional[int] = 1000, show_logger: Optional[bool] = False, + create_tables: bool = False, ): self._conninfo = conninfo self.user_thread_limit = user_thread_limit @@ -50,7 +63,6 @@ def __init__( if connect_args is None: connect_args = {} if ssl_require: - # Create an SSL context to require an SSL connection ssl_context = ssl.create_default_context() ssl_context.check_hostname = False ssl_context.verify_mode = ssl.CERT_NONE @@ -61,6 +73,10 @@ def __init__( self.async_session = sessionmaker( bind=self.engine, expire_on_commit=False, class_=AsyncSession ) # type: ignore + self._dialect_name = self.engine.dialect.name + self._tables_created = False + self._create_tables_on_init = create_tables + if storage_provider: self.storage_provider: Optional[BaseStorageClient] = storage_provider if self.show_logger: @@ -71,236 +87,245 @@ def __init__( "SQLAlchemyDataLayer storage client is not initialized and elements will not be persisted!" ) - async def build_debug_url(self) -> str: - return "" + async def _ensure_tables(self): + if self._create_tables_on_init and not self._tables_created: + async with self.engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + self._tables_created = True - ###### SQL Helpers ###### - async def execute_sql( - self, query: str, parameters: dict - ) -> Union[List[Dict[str, Any]], int, None]: - parameterized_query = text(query) + @asynccontextmanager + async def _session(self): + await self._ensure_tables() async with self.async_session() as session: try: - await session.begin() - result = await session.execute(parameterized_query, parameters) + yield session await session.commit() - if result.returns_rows: - json_result = [dict(row._mapping) for row in result.fetchall()] - clean_json_result = self.clean_result(json_result) - assert isinstance(clean_json_result, list) or isinstance( - clean_json_result, int - ) - return clean_json_result - else: - return result.rowcount except SQLAlchemyError as e: await session.rollback() - logger.warning(f"An error occurred: {e}") - return None + logger.warning(f"Database error: {e}") + raise except Exception as e: await session.rollback() - logger.warning(f"An unexpected error occurred: {e}") - return None + logger.warning(f"Unexpected error: {e}") + raise + + def _get_upsert( + self, model: type, values: dict[str, Any], index_elements: list[str] + ) -> Any: + if self._dialect_name == "postgresql": + stmt = pg_insert(model).values(**values) + update_dict = { + k: stmt.excluded[k] for k in values if k not in index_elements + } + return stmt.on_conflict_do_update( + index_elements=index_elements, set_=update_dict + ) + elif self._dialect_name == "sqlite": + stmt = sqlite_insert(model).values(**values) # type: ignore[arg-type] + update_dict = { + k: stmt.excluded[k] for k in values if k not in index_elements + } + return stmt.on_conflict_do_update( + index_elements=index_elements, set_=update_dict + ) + elif self._dialect_name in ("mysql", "mariadb"): + stmt = mysql_insert(model).values(**values) # type: ignore[arg-type] + update_dict = {k: v for k, v in values.items() if k not in index_elements} + return stmt.on_duplicate_key_update(**update_dict) # type: ignore[attr-defined] + else: + stmt = pg_insert(model).values(**values) + update_dict = { + k: stmt.excluded[k] for k in values if k not in index_elements + } + return stmt.on_conflict_do_update( + index_elements=index_elements, set_=update_dict + ) + + async def build_debug_url(self) -> str: + return "" async def get_current_timestamp(self) -> str: return datetime.now().isoformat() + "Z" - def clean_result(self, obj): - """Recursively change UUID -> str and serialize dictionaries""" - if isinstance(obj, dict): - return {k: self.clean_result(v) for k, v in obj.items()} - elif isinstance(obj, list): - return [self.clean_result(item) for item in obj] - elif isinstance(obj, uuid.UUID): - return str(obj) - return obj - - ###### User ###### + def _parse_json(self, value: Any) -> dict: + if value is None: + return {} + if isinstance(value, dict): + return value + if isinstance(value, str): + try: + return json.loads(value) + except json.JSONDecodeError: + return {} + return {} + async def get_user(self, identifier: str) -> Optional[PersistedUser]: if self.show_logger: logger.info(f"SQLAlchemy: get_user, identifier={identifier}") - query = "SELECT * FROM users WHERE identifier = :identifier" - parameters = {"identifier": identifier} - result = await self.execute_sql(query=query, parameters=parameters) - if result and isinstance(result, list): - user_data = result[0] - - # SQLite returns JSON as string, we most convert it. (#1137) - metadata = user_data.get("metadata", {}) - if isinstance(metadata, str): - metadata = json.loads(metadata) - - assert isinstance(metadata, dict) - assert isinstance(user_data["id"], str) - assert isinstance(user_data["identifier"], str) - assert isinstance(user_data["createdAt"], str) - + async with self._session() as session: + stmt = select(UserModel).where(UserModel.identifier == identifier) + result = await session.execute(stmt) + user = result.scalar_one_or_none() + if not user: + return None return PersistedUser( - id=user_data["id"], - identifier=user_data["identifier"], - createdAt=user_data["createdAt"], - metadata=metadata, + id=user.id, + identifier=user.identifier, + createdAt=user.createdAt or "", + metadata=self._parse_json(user.metadata_), ) - return None - async def _get_user_identifer_by_id(self, user_id: str) -> str: + async def _get_user_identifier_by_id(self, user_id: str) -> str: if self.show_logger: - logger.info(f"SQLAlchemy: _get_user_identifer_by_id, user_id={user_id}") - query = "SELECT identifier FROM users WHERE id = :user_id" - parameters = {"user_id": user_id} - result = await self.execute_sql(query=query, parameters=parameters) - - assert result - assert isinstance(result, list) - - return result[0]["identifier"] + logger.info(f"SQLAlchemy: _get_user_identifier_by_id, user_id={user_id}") + async with self._session() as session: + stmt = select(UserModel.identifier).where(UserModel.id == user_id) + result = await session.execute(stmt) + identifier = result.scalar_one_or_none() + if identifier is None: + raise ValueError(f"User not found: {user_id}") + return identifier async def _get_user_id_by_thread(self, thread_id: str) -> Optional[str]: if self.show_logger: logger.info(f"SQLAlchemy: _get_user_id_by_thread, thread_id={thread_id}") - query = """SELECT "userId" FROM threads WHERE id = :thread_id""" - parameters = {"thread_id": thread_id} - result = await self.execute_sql(query=query, parameters=parameters) - if result: - assert isinstance(result, list) - return result[0]["userId"] - - return None + async with self._session() as session: + stmt = select(ThreadModel.userId).where(ThreadModel.id == thread_id) + result = await session.execute(stmt) + return result.scalar_one_or_none() async def create_user(self, user: User) -> Optional[PersistedUser]: if self.show_logger: logger.info(f"SQLAlchemy: create_user, user_identifier={user.identifier}") - existing_user: Optional[PersistedUser] = await self.get_user(user.identifier) - user_dict: Dict[str, Any] = { - "identifier": str(user.identifier), - "metadata": json.dumps(user.metadata) or {}, - } - if not existing_user: # create the user - if self.show_logger: - logger.info("SQLAlchemy: create_user, creating the user") - user_dict["id"] = str(uuid.uuid4()) - user_dict["createdAt"] = await self.get_current_timestamp() - query = """INSERT INTO users ("id", "identifier", "createdAt", "metadata") VALUES (:id, :identifier, :createdAt, :metadata)""" - await self.execute_sql(query=query, parameters=user_dict) - else: # update the user - if self.show_logger: - logger.info("SQLAlchemy: update user metadata") - query = """UPDATE users SET "metadata" = :metadata WHERE "identifier" = :identifier""" - await self.execute_sql( - query=query, parameters=user_dict - ) # We want to update the metadata + existing_user = await self.get_user(user.identifier) + async with self._session() as session: + if not existing_user: + if self.show_logger: + logger.info("SQLAlchemy: create_user, creating the user") + new_user = UserModel( + id=str(uuid.uuid4()), + identifier=user.identifier, + metadata_=user.metadata or {}, + createdAt=await self.get_current_timestamp(), + ) + session.add(new_user) + else: + if self.show_logger: + logger.info("SQLAlchemy: update user metadata") + stmt = select(UserModel).where(UserModel.identifier == user.identifier) + result = await session.execute(stmt) + db_user = result.scalar_one() + db_user.metadata_ = user.metadata or {} return await self.get_user(user.identifier) - ###### Threads ###### async def get_thread_author(self, thread_id: str) -> str: if self.show_logger: logger.info(f"SQLAlchemy: get_thread_author, thread_id={thread_id}") - query = """SELECT "userIdentifier" FROM threads WHERE "id" = :id""" - parameters = {"id": thread_id} - result = await self.execute_sql(query=query, parameters=parameters) - if isinstance(result, list) and result: - author_identifier = result[0].get("userIdentifier") - if author_identifier is not None: - return author_identifier - raise ValueError(f"Author not found for thread_id {thread_id}") + async with self._session() as session: + stmt = select(ThreadModel.userIdentifier).where(ThreadModel.id == thread_id) + result = await session.execute(stmt) + author = result.scalar_one_or_none() + if author is None: + raise ValueError(f"Author not found for thread_id {thread_id}") + return author async def get_thread(self, thread_id: str) -> Optional[ThreadDict]: if self.show_logger: logger.info(f"SQLAlchemy: get_thread, thread_id={thread_id}") - user_threads: Optional[List[ThreadDict]] = await self.get_all_user_threads( - thread_id=thread_id - ) + user_threads = await self.get_all_user_threads(thread_id=thread_id) if user_threads: return user_threads[0] - else: - return None + return None async def update_thread( self, thread_id: str, name: Optional[str] = None, user_id: Optional[str] = None, - metadata: Optional[Dict] = None, - tags: Optional[List[str]] = None, + metadata: Optional[dict] = None, + tags: Optional[list[str]] = None, ): if self.show_logger: logger.info(f"SQLAlchemy: update_thread, thread_id={thread_id}") user_identifier = None if user_id: - user_identifier = await self._get_user_identifer_by_id(user_id) - - if metadata is not None: - existing = await self.execute_sql( - query='SELECT "metadata" FROM threads WHERE "id" = :id', - parameters={"id": thread_id}, - ) - base = {} - if isinstance(existing, list) and existing: - raw = existing[0].get("metadata") or {} - if isinstance(raw, str): - try: - base = json.loads(raw) - except json.JSONDecodeError: - base = {} - elif isinstance(raw, dict): - base = raw - incoming = {k: v for k, v in metadata.items() if v is not None} - metadata = {**base, **incoming} - - name_value = name - if name_value is None and metadata: - name_value = metadata.get("name") - created_at_value = ( - await self.get_current_timestamp() if metadata is None else None - ) - - data = { - "id": thread_id, - "createdAt": created_at_value, - "name": name_value, - "userId": user_id, - "userIdentifier": user_identifier, - "tags": tags, - "metadata": json.dumps(metadata) if metadata else None, - } - parameters = { - key: value for key, value in data.items() if value is not None - } # Remove keys with None values - columns = ", ".join(f'"{key}"' for key in parameters.keys()) - values = ", ".join(f":{key}" for key in parameters.keys()) - updates = ", ".join( - f'"{key}" = EXCLUDED."{key}"' for key in parameters.keys() if key != "id" - ) - query = f""" - INSERT INTO threads ({columns}) - VALUES ({values}) - ON CONFLICT ("id") DO UPDATE - SET {updates}; - """ - await self.execute_sql(query=query, parameters=parameters) + user_identifier = await self._get_user_identifier_by_id(user_id) + + async with self._session() as session: + stmt = select(ThreadModel).where(ThreadModel.id == thread_id) + result = await session.execute(stmt) + existing = result.scalar_one_or_none() + + merged_metadata = {} + if existing and existing.metadata_: + merged_metadata = self._parse_json(existing.metadata_) + + if metadata is not None: + incoming = {k: v for k, v in metadata.items() if v is not None} + merged_metadata = {**merged_metadata, **incoming} + + name_value = name + if name_value is None and metadata: + name_value = metadata.get("name") + + if existing: + if name_value is not None: + existing.name = name_value + if user_id is not None: + existing.userId = user_id + if user_identifier is not None: + existing.userIdentifier = user_identifier + if tags is not None: + existing.tags = json.dumps(tags) + if metadata is not None: + existing.metadata_ = merged_metadata + else: + new_thread = ThreadModel( + id=thread_id, + createdAt=await self.get_current_timestamp(), + name=name_value, + userId=user_id, + userIdentifier=user_identifier, + tags=json.dumps(tags) if tags else None, + metadata_=merged_metadata if metadata else None, + ) + session.add(new_thread) async def delete_thread(self, thread_id: str): if self.show_logger: logger.info(f"SQLAlchemy: delete_thread, thread_id={thread_id}") - elements_query = """SELECT * FROM elements WHERE "threadId" = :id""" - elements = await self.execute_sql(elements_query, {"id": thread_id}) + async with self._session() as session: + stmt = select(ElementModel).where(ElementModel.threadId == thread_id) + result = await session.execute(stmt) + elements = result.scalars().all() + + if self.storage_provider is not None: + for elem in elements: + if elem.objectKey: + await self.storage_provider.delete_file( + object_key=elem.objectKey + ) + + step_ids_stmt = select(StepModel.id).where(StepModel.threadId == thread_id) + step_ids_result = await session.execute(step_ids_stmt) + step_ids = [row[0] for row in step_ids_result.fetchall()] - if self.storage_provider is not None and isinstance(elements, list): - for elem in filter(lambda x: x["objectKey"], elements): - await self.storage_provider.delete_file(object_key=elem["objectKey"]) + if step_ids: + await session.execute( + delete(FeedbackModel).where(FeedbackModel.forId.in_(step_ids)) + ) - # Delete feedbacks/elements/steps/thread - feedbacks_query = """DELETE FROM feedbacks WHERE "forId" IN (SELECT "id" FROM steps WHERE "threadId" = :id)""" - elements_query = """DELETE FROM elements WHERE "threadId" = :id""" - steps_query = """DELETE FROM steps WHERE "threadId" = :id""" - thread_query = """DELETE FROM threads WHERE "id" = :id""" - parameters = {"id": thread_id} - await self.execute_sql(query=feedbacks_query, parameters=parameters) - await self.execute_sql(query=elements_query, parameters=parameters) - await self.execute_sql(query=steps_query, parameters=parameters) - await self.execute_sql(query=thread_query, parameters=parameters) + await session.execute( + delete(ElementModel).where(ElementModel.threadId == thread_id) + ) + await session.execute( + delete(StepModel).where(StepModel.threadId == thread_id) + ) + await session.execute( + delete(ThreadModel).where(ThreadModel.id == thread_id) + ) async def list_threads( self, pagination: Pagination, filters: ThreadFilter @@ -311,7 +336,7 @@ async def list_threads( ) if not filters.userId: raise ValueError("userId is required") - all_user_threads: List[ThreadDict] = ( + all_user_threads: list[ThreadDict] = ( await self.get_all_user_threads(user_id=filters.userId) or [] ) @@ -330,7 +355,7 @@ async def list_threads( if "output" in step ) if feedback_value is not None: - feedback_match = False # Assume no match until found + feedback_match = False for step in thread["steps"]: feedback = step.get("feedback") if feedback and feedback.get("value") == feedback_value: @@ -342,9 +367,7 @@ async def list_threads( start = 0 if pagination.cursor: for i, thread in enumerate(filtered_threads): - if ( - thread["id"] == pagination.cursor - ): # Find the start index using pagination.cursor + if thread["id"] == pagination.cursor: start = i + 1 break end = start + pagination.first @@ -363,41 +386,57 @@ async def list_threads( data=paginated_threads, ) - ###### Steps ###### @queue_until_user_message() - async def create_step(self, step_dict: "StepDict"): + async def create_step(self, step_dict: StepDict): await self.update_thread(step_dict["threadId"]) if self.show_logger: logger.info(f"SQLAlchemy: create_step, step_id={step_dict.get('id')}") - step_dict["showInput"] = ( + show_input = ( str(step_dict.get("showInput", "")).lower() if "showInput" in step_dict else None ) - parameters = { - key: value - for key, value in step_dict.items() - if value is not None and not (isinstance(value, dict) and not value) + + values = { + "id": step_dict["id"], + "name": step_dict.get("name", ""), + "type": step_dict.get("type", ""), + "threadId": step_dict["threadId"], + "parentId": step_dict.get("parentId"), + "streaming": step_dict.get("streaming", False), + "waitForAnswer": step_dict.get("waitForAnswer"), + "isError": step_dict.get("isError"), + "metadata_": step_dict.get("metadata", {}), + "tags": json.dumps(step_dict.get("tags")) + if step_dict.get("tags") + else None, + "input": step_dict.get("input"), + "output": step_dict.get("output"), + "createdAt": step_dict.get("createdAt"), + "start": step_dict.get("start"), + "end": step_dict.get("end"), + "generation": step_dict.get("generation", {}), + "showInput": show_input, + "language": step_dict.get("language"), } - parameters["metadata"] = json.dumps(step_dict.get("metadata", {})) - parameters["generation"] = json.dumps(step_dict.get("generation", {})) - columns = ", ".join(f'"{key}"' for key in parameters.keys()) - values = ", ".join(f":{key}" for key in parameters.keys()) - updates = ", ".join( - f'"{key}" = :{key}' for key in parameters.keys() if key != "id" - ) - query = f""" - INSERT INTO steps ({columns}) - VALUES ({values}) - ON CONFLICT (id) DO UPDATE - SET {updates}; - """ - await self.execute_sql(query=query, parameters=parameters) + values = { + k: v + for k, v in values.items() + if v is not None or k in ("parentId", "waitForAnswer", "isError") + } + if "name" not in values: + values["name"] = "" + if "type" not in values: + values["type"] = "" + + async with self._session() as session: + stmt = self._get_upsert(StepModel, values, ["id"]) + await session.execute(stmt) @queue_until_user_message() - async def update_step(self, step_dict: "StepDict"): + async def update_step(self, step_dict: StepDict): if self.show_logger: logger.info(f"SQLAlchemy: update_step, step_id={step_dict.get('id')}") await self.create_step(step_dict) @@ -406,161 +445,125 @@ async def update_step(self, step_dict: "StepDict"): async def delete_step(self, step_id: str): if self.show_logger: logger.info(f"SQLAlchemy: delete_step, step_id={step_id}") - # Delete feedbacks/elements/steps - feedbacks_query = """DELETE FROM feedbacks WHERE "forId" = :id""" - elements_query = """DELETE FROM elements WHERE "forId" = :id""" - steps_query = """DELETE FROM steps WHERE "id" = :id""" - parameters = {"id": step_id} - await self.execute_sql(query=feedbacks_query, parameters=parameters) - await self.execute_sql(query=elements_query, parameters=parameters) - await self.execute_sql(query=steps_query, parameters=parameters) - - async def get_step(self, step_id: str) -> Optional["StepDict"]: + async with self._session() as session: + await session.execute( + delete(FeedbackModel).where(FeedbackModel.forId == step_id) + ) + await session.execute( + delete(ElementModel).where(ElementModel.forId == step_id) + ) + await session.execute(delete(StepModel).where(StepModel.id == step_id)) + + async def get_step(self, step_id: str) -> Optional[StepDict]: if self.show_logger: logger.info(f"SQLAlchemy: get_step, step_id={step_id}") - steps_feedbacks_query = """ - SELECT - s."id" AS step_id, - s."name" AS step_name, - s."type" AS step_type, - s."threadId" AS step_threadid, - s."parentId" AS step_parentid, - s."streaming" AS step_streaming, - s."waitForAnswer" AS step_waitforanswer, - s."isError" AS step_iserror, - s."metadata" AS step_metadata, - s."tags" AS step_tags, - s."input" AS step_input, - s."output" AS step_output, - s."createdAt" AS step_createdat, - s."start" AS step_start, - s."end" AS step_end, - s."generation" AS step_generation, - s."showInput" AS step_showinput, - s."language" AS step_language, - f."value" AS feedback_value, - f."comment" AS feedback_comment, - f."id" AS feedback_id - FROM steps s LEFT JOIN feedbacks f ON s."id" = f."forId" - WHERE s."id" = :step_id - """ - steps_feedbacks = await self.execute_sql( - query=steps_feedbacks_query, parameters={"step_id": step_id} - ) - - if not isinstance(steps_feedbacks, list) or not steps_feedbacks: - return None + async with self._session() as session: + stmt = ( + select(StepModel) + .options(selectinload(StepModel.feedbacks)) + .where(StepModel.id == step_id) + ) + result = await session.execute(stmt) + step = result.scalar_one_or_none() + if not step: + return None - step_feedback = steps_feedbacks[0] + feedback = None + if step.feedbacks: + fb = step.feedbacks[0] + feedback = FeedbackDict( + forId=step.id, + id=fb.id, + value=fb.value, + comment=fb.comment, + ) - feedback = None - if step_feedback["feedback_value"] is not None: - feedback = FeedbackDict( - forId=step_feedback["step_id"], - id=step_feedback.get("feedback_id"), - value=step_feedback["feedback_value"], - comment=step_feedback.get("feedback_comment"), + return StepDict( + id=step.id, + name=step.name, + type=step.type, + threadId=step.threadId, + parentId=step.parentId, + streaming=step.streaming or False, + waitForAnswer=step.waitForAnswer, + isError=step.isError, + metadata=self._parse_json(step.metadata_), + tags=json.loads(step.tags) if step.tags else None, + input=step.input if step.showInput not in (None, "false") else "", + output=step.output or "", + createdAt=step.createdAt, + start=step.start, + end=step.end, + generation=self._parse_json(step.generation), + showInput=step.showInput, + language=step.language, + feedback=feedback, ) - return StepDict( - id=step_feedback["step_id"], - name=step_feedback["step_name"], - type=step_feedback["step_type"], - threadId=step_feedback.get("step_threadid", ""), - parentId=step_feedback.get("step_parentid"), - streaming=step_feedback.get("step_streaming", False), - waitForAnswer=step_feedback.get("step_waitforanswer"), - isError=step_feedback.get("step_iserror"), - metadata=( - step_feedback["step_metadata"] - if step_feedback.get("step_metadata") is not None - else {} - ), - tags=step_feedback.get("step_tags"), - input=( - step_feedback.get("step_input", "") - if step_feedback.get("step_showinput") not in [None, "false"] - else "" - ), - output=step_feedback.get("step_output", ""), - createdAt=step_feedback.get("step_createdat"), - start=step_feedback.get("step_start"), - end=step_feedback.get("step_end"), - generation=step_feedback.get("step_generation"), - showInput=step_feedback.get("step_showinput"), - language=step_feedback.get("step_language"), - feedback=feedback, - ) - ###### Feedback ###### async def upsert_feedback(self, feedback: Feedback) -> str: if self.show_logger: logger.info(f"SQLAlchemy: upsert_feedback, feedback_id={feedback.id}") - feedback.id = feedback.id or str(uuid.uuid4()) - feedback_dict = asdict(feedback) - parameters = { - key: value for key, value in feedback_dict.items() if value is not None + feedback_id = feedback.id or str(uuid.uuid4()) + + values = { + "id": feedback_id, + "forId": feedback.forId, + "threadId": feedback.threadId or "", + "value": feedback.value, + "comment": feedback.comment, } + values = {k: v for k, v in values.items() if v is not None} - columns = ", ".join(f'"{key}"' for key in parameters.keys()) - values = ", ".join(f":{key}" for key in parameters.keys()) - updates = ", ".join( - f'"{key}" = :{key}' for key in parameters.keys() if key != "id" - ) - query = f""" - INSERT INTO feedbacks ({columns}) - VALUES ({values}) - ON CONFLICT (id) DO UPDATE - SET {updates}; - """ - await self.execute_sql(query=query, parameters=parameters) - return feedback.id + async with self._session() as session: + stmt = self._get_upsert(FeedbackModel, values, ["id"]) + await session.execute(stmt) + return feedback_id async def delete_feedback(self, feedback_id: str) -> bool: if self.show_logger: logger.info(f"SQLAlchemy: delete_feedback, feedback_id={feedback_id}") - query = """DELETE FROM feedbacks WHERE "id" = :feedback_id""" - parameters = {"feedback_id": feedback_id} - await self.execute_sql(query=query, parameters=parameters) + async with self._session() as session: + await session.execute( + delete(FeedbackModel).where(FeedbackModel.id == feedback_id) + ) return True - ###### Elements ###### async def get_element( self, thread_id: str, element_id: str - ) -> Optional["ElementDict"]: + ) -> Optional[ElementDict]: if self.show_logger: logger.info( f"SQLAlchemy: get_element, thread_id={thread_id}, element_id={element_id}" ) - query = """SELECT * FROM elements WHERE "threadId" = :thread_id AND "id" = :element_id""" - parameters = {"thread_id": thread_id, "element_id": element_id} - element: Union[List[Dict[str, Any]], int, None] = await self.execute_sql( - query=query, parameters=parameters - ) - if isinstance(element, list) and element: - element_dict: Dict[str, Any] = element[0] + async with self._session() as session: + stmt = select(ElementModel).where( + ElementModel.threadId == thread_id, ElementModel.id == element_id + ) + result = await session.execute(stmt) + elem = result.scalar_one_or_none() + if not elem: + return None return ElementDict( - id=element_dict["id"], - threadId=element_dict.get("threadId"), - type=element_dict["type"], - chainlitKey=element_dict.get("chainlitKey"), - url=element_dict.get("url"), - objectKey=element_dict.get("objectKey"), - name=element_dict["name"], - props=json.loads(element_dict.get("props", "{}")), - display=element_dict["display"], - size=element_dict.get("size"), - language=element_dict.get("language"), - page=element_dict.get("page"), - autoPlay=element_dict.get("autoPlay"), - playerConfig=element_dict.get("playerConfig"), - forId=element_dict.get("forId"), - mime=element_dict.get("mime"), + id=elem.id, + threadId=elem.threadId, + type=elem.type, + chainlitKey=elem.chainlitKey, + url=elem.url, + objectKey=elem.objectKey, + name=elem.name, + props=self._parse_json(elem.props), + display=elem.display, + size=elem.size, + language=elem.language, + page=elem.page, + autoPlay=elem.autoPlay, + playerConfig=elem.playerConfig, + forId=elem.forId, + mime=elem.mime, ) - else: - return None @queue_until_user_message() - async def create_element(self, element: "Element"): + async def create_element(self, element: Element): if self.show_logger: logger.info(f"SQLAlchemy: create_element, element_id = {element.id}") @@ -578,8 +581,8 @@ async def create_element(self, element: "Element"): async with aiofiles.open(element.path, "rb") as f: content = await f.read() elif element.url: - async with aiohttp.ClientSession() as session: - async with session.get(element.url) as response: + async with aiohttp.ClientSession() as http_session: + async with http_session.get(element.url) as response: if response.status == 200: content = await response.read() else: @@ -591,7 +594,7 @@ async def create_element(self, element: "Element"): if content is None: raise ValueError("Content is None, cannot upload file") - user_id: str = await self._get_user_id_by_thread(element.thread_id) or "unknown" + user_id = await self._get_user_id_by_thread(element.thread_id) or "unknown" file_object_key = f"{user_id}/{element.id}" + ( f"/{element.name}" if element.name else "" ) @@ -607,329 +610,245 @@ async def create_element(self, element: "Element"): "SQLAlchemy Error: create_element, Failed to persist data in storage_provider" ) - element_dict: ElementDict = element.to_dict() - - element_dict["url"] = uploaded_file.get("url") - element_dict["objectKey"] = uploaded_file.get("object_key") - - element_dict_cleaned = {k: v for k, v in element_dict.items() if v is not None} - if "props" in element_dict_cleaned: - element_dict_cleaned["props"] = json.dumps(element_dict_cleaned["props"]) + element_dict = element.to_dict() + + values = { + "id": element.id, + "threadId": element.thread_id, + "type": element.type, + "url": uploaded_file.get("url"), + "chainlitKey": element.chainlit_key, + "name": element.name, + "display": element.display, + "objectKey": uploaded_file.get("object_key"), + "size": element.size, + "page": getattr(element, "page", None), + "language": element.language, + "forId": element.for_id, + "mime": element.mime, + "props": json.dumps(element_dict.get("props", {})), + "autoPlay": getattr(element, "autoPlay", None), + "playerConfig": getattr(element, "playerConfig", None), + } + values = {k: v for k, v in values.items() if v is not None} + if "name" not in values: + values["name"] = "" - columns = ", ".join(f'"{column}"' for column in element_dict_cleaned.keys()) - placeholders = ", ".join(f":{column}" for column in element_dict_cleaned.keys()) - updates = ", ".join( - f'"{column}" = :{column}' - for column in element_dict_cleaned.keys() - if column != "id" - ) - query = f"INSERT INTO elements ({columns}) VALUES ({placeholders}) ON CONFLICT (id) DO UPDATE SET {updates};" - await self.execute_sql(query=query, parameters=element_dict_cleaned) + async with self._session() as session: + stmt = self._get_upsert(ElementModel, values, ["id"]) + await session.execute(stmt) @queue_until_user_message() async def delete_element(self, element_id: str, thread_id: Optional[str] = None): if self.show_logger: logger.info(f"SQLAlchemy: delete_element, element_id={element_id}") - query = """SELECT * FROM elements WHERE "id" = :id""" - elements = await self.execute_sql(query, {"id": element_id}) - - if ( - self.storage_provider is not None - and isinstance(elements, list) - and len(elements) > 0 - and elements[0]["objectKey"] - ): - await self.storage_provider.delete_file(object_key=elements[0]["objectKey"]) + async with self._session() as session: + stmt = select(ElementModel).where(ElementModel.id == element_id) + result = await session.execute(stmt) + elem = result.scalar_one_or_none() - query = """DELETE FROM elements WHERE "id" = :id""" - parameters = {"id": element_id} + if elem and self.storage_provider and elem.objectKey: + await self.storage_provider.delete_file(object_key=elem.objectKey) - await self.execute_sql(query=query, parameters=parameters) + await session.execute( + delete(ElementModel).where(ElementModel.id == element_id) + ) async def get_all_user_threads( self, user_id: Optional[str] = None, thread_id: Optional[str] = None - ) -> Optional[List[ThreadDict]]: - """Fetch all user threads up to self.user_thread_limit, or one thread by id if thread_id is provided.""" + ) -> Optional[list[ThreadDict]]: if self.show_logger: logger.info("SQLAlchemy: get_all_user_threads") - user_threads_query = """ - SELECT - t."id" AS thread_id, - t."createdAt" AS thread_createdat, - t."name" AS thread_name, - t."userId" AS user_id, - t."userIdentifier" AS user_identifier, - t."tags" AS thread_tags, - t."metadata" AS thread_metadata, - MAX(s."createdAt") AS updatedAt - FROM threads t - LEFT JOIN steps s ON t."id" = s."threadId" - WHERE t."userId" = :user_id OR t."id" = :thread_id - GROUP BY - t."id", - t."createdAt", - t."name", - t."userId", - t."userIdentifier", - t."tags", - t."metadata" - ORDER BY updatedAt DESC NULLS LAST - LIMIT :limit - """ - user_threads = await self.execute_sql( - query=user_threads_query, - parameters={ - "user_id": user_id, - "limit": self.user_thread_limit, - "thread_id": thread_id, - }, - ) - if not isinstance(user_threads, list): - return None - if not user_threads: - return [] - else: - thread_ids = ( - "('" - + "','".join(map(str, [thread["thread_id"] for thread in user_threads])) - + "')" - ) - steps_feedbacks_query = f""" - SELECT - s."id" AS step_id, - s."name" AS step_name, - s."type" AS step_type, - s."threadId" AS step_threadid, - s."parentId" AS step_parentid, - s."streaming" AS step_streaming, - s."waitForAnswer" AS step_waitforanswer, - s."isError" AS step_iserror, - s."metadata" AS step_metadata, - s."tags" AS step_tags, - s."input" AS step_input, - s."output" AS step_output, - s."createdAt" AS step_createdat, - s."start" AS step_start, - s."end" AS step_end, - s."generation" AS step_generation, - s."showInput" AS step_showinput, - s."language" AS step_language, - f."value" AS feedback_value, - f."comment" AS feedback_comment, - f."id" AS feedback_id - FROM steps s LEFT JOIN feedbacks f ON s."id" = f."forId" - WHERE s."threadId" IN {thread_ids} - ORDER BY s."createdAt" ASC - """ - steps_feedbacks = await self.execute_sql( - query=steps_feedbacks_query, parameters={} - ) + async with self._session() as session: + subq = ( + select( + StepModel.threadId, + func.max(StepModel.createdAt).label("max_created"), + ) + .group_by(StepModel.threadId) + .subquery() + ) - elements_query = f""" - SELECT - e."id" AS element_id, - e."threadId" as element_threadid, - e."type" AS element_type, - e."chainlitKey" AS element_chainlitkey, - e."url" AS element_url, - e."objectKey" as element_objectkey, - e."name" AS element_name, - e."display" AS element_display, - e."size" AS element_size, - e."language" AS element_language, - e."page" AS element_page, - e."forId" AS element_forid, - e."mime" AS element_mime, - e."props" AS props - FROM elements e - WHERE e."threadId" IN {thread_ids} - """ - elements = await self.execute_sql(query=elements_query, parameters={}) - - thread_dicts = {} - for thread in user_threads: - thread_id = thread["thread_id"] + conditions = [] + if user_id is not None: + conditions.append(ThreadModel.userId == user_id) if thread_id is not None: - thread_dicts[thread_id] = ThreadDict( - id=thread_id, - createdAt=thread["thread_createdat"], - name=thread["thread_name"], - userId=thread["user_id"], - userIdentifier=thread["user_identifier"], - tags=thread["thread_tags"], - metadata=thread["thread_metadata"], + conditions.append(ThreadModel.id == thread_id) + + if not conditions: + return [] + + stmt = ( + select(ThreadModel, subq.c.max_created) + .outerjoin(subq, ThreadModel.id == subq.c.threadId) + .where(or_(*conditions)) + .order_by(subq.c.max_created.desc().nulls_last()) + .limit(self.user_thread_limit) + ) + + result = await session.execute(stmt) + rows = result.fetchall() + + if not rows: + return [] + + thread_ids = [row[0].id for row in rows] + + steps_stmt = ( + select(StepModel) + .options(selectinload(StepModel.feedbacks)) + .where(StepModel.threadId.in_(thread_ids)) + .order_by(StepModel.createdAt.asc()) + ) + steps_result = await session.execute(steps_stmt) + all_steps = steps_result.scalars().all() + + elements_stmt = select(ElementModel).where( + ElementModel.threadId.in_(thread_ids) + ) + elements_result = await session.execute(elements_stmt) + all_elements = elements_result.scalars().all() + + thread_dicts: dict[str, ThreadDict] = {} + for row in rows: + thread = row[0] + thread_dicts[thread.id] = ThreadDict( + id=thread.id, + createdAt=thread.createdAt, + name=thread.name, + userId=thread.userId, + userIdentifier=thread.userIdentifier, + tags=json.loads(thread.tags) if thread.tags else None, + metadata=self._parse_json(thread.metadata_), steps=[], elements=[], ) - # Process steps_feedbacks to populate the steps in the corresponding ThreadDict - if isinstance(steps_feedbacks, list): - for step_feedback in steps_feedbacks: - thread_id = step_feedback["step_threadid"] - if thread_id is not None: + + for step in all_steps: + if step.threadId in thread_dicts: feedback = None - if step_feedback["feedback_value"] is not None: + if step.feedbacks: + fb = step.feedbacks[0] feedback = FeedbackDict( - forId=step_feedback["step_id"], - id=step_feedback.get("feedback_id"), - value=step_feedback["feedback_value"], - comment=step_feedback.get("feedback_comment"), + forId=step.id, + id=fb.id, + value=fb.value, + comment=fb.comment, ) step_dict = StepDict( - id=step_feedback["step_id"], - name=step_feedback["step_name"], - type=step_feedback["step_type"], - threadId=thread_id, - parentId=step_feedback.get("step_parentid"), - streaming=step_feedback.get("step_streaming", False), - waitForAnswer=step_feedback.get("step_waitforanswer"), - isError=step_feedback.get("step_iserror"), - metadata=( - step_feedback["step_metadata"] - if step_feedback.get("step_metadata") is not None - else {} - ), - tags=step_feedback.get("step_tags"), - input=( - step_feedback.get("step_input", "") - if step_feedback.get("step_showinput") - not in [None, "false"] - else "" - ), - output=step_feedback.get("step_output", ""), - createdAt=step_feedback.get("step_createdat"), - start=step_feedback.get("step_start"), - end=step_feedback.get("step_end"), - generation=step_feedback.get("step_generation"), - showInput=step_feedback.get("step_showinput"), - language=step_feedback.get("step_language"), + id=step.id, + name=step.name, + type=step.type, + threadId=step.threadId, + parentId=step.parentId, + streaming=step.streaming or False, + waitForAnswer=step.waitForAnswer, + isError=step.isError, + metadata=self._parse_json(step.metadata_), + tags=json.loads(step.tags) if step.tags else None, + input=step.input + if step.showInput not in (None, "false") + else "", + output=step.output or "", + createdAt=step.createdAt, + start=step.start, + end=step.end, + generation=self._parse_json(step.generation), + showInput=step.showInput, + language=step.language, feedback=feedback, ) - # Append the step to the steps list of the corresponding ThreadDict - thread_dicts[thread_id]["steps"].append(step_dict) - - if isinstance(elements, list): - for element in elements: - thread_id = element["element_threadid"] - if thread_id is not None: - element_url: str | None = None - object_key_val = element.get("element_objectkey") + thread_dicts[step.threadId]["steps"].append(step_dict) + + for elem in all_elements: + tid = elem.threadId + if tid and tid in thread_dicts: + element_url: Optional[str] = None if ( - self.storage_provider is not None - and isinstance(object_key_val, str) - and object_key_val.strip() + self.storage_provider + and elem.objectKey + and elem.objectKey.strip() ): try: element_url = await self.storage_provider.get_read_url( - object_key=object_key_val, + object_key=elem.objectKey ) except Exception as e: logger.warning( - f"Failed to get read URL for object_key '{object_key_val}': {e}. Falling back to stored URL." + f"Failed to get read URL for object_key '{elem.objectKey}': {e}" ) - element_url = element.get("element_url") + element_url = elem.url else: - element_url = element.get("element_url") + element_url = elem.url + element_dict = ElementDict( - id=element["element_id"], - threadId=thread_id, - type=element["element_type"], - chainlitKey=element.get("element_chainlitkey"), + id=elem.id, + threadId=tid, + type=elem.type, + chainlitKey=elem.chainlitKey, url=element_url, - objectKey=element.get("element_objectkey"), - name=element["element_name"], - display=element["element_display"], - size=element.get("element_size"), - language=element.get("element_language"), - autoPlay=element.get("element_autoPlay"), - playerConfig=element.get("element_playerconfig"), - page=element.get("element_page"), - props=element.get("props", "{}"), - forId=element.get("element_forid"), - mime=element.get("element_mime"), + objectKey=elem.objectKey, + name=elem.name, + display=elem.display, + size=elem.size, + language=elem.language, + autoPlay=elem.autoPlay, + playerConfig=elem.playerConfig, + page=elem.page, + props=self._parse_json(elem.props), + forId=elem.forId, + mime=elem.mime, ) - thread_dicts[thread_id]["elements"].append(element_dict) # type: ignore + thread_dicts[tid]["elements"].append(element_dict) # type: ignore - return list(thread_dicts.values()) + return list(thread_dicts.values()) - async def get_favorite_steps(self, user_id: str) -> List[StepDict]: + async def get_favorite_steps(self, user_id: str) -> list[StepDict]: if self.show_logger: logger.info(f"SQLAlchemy: get_favorite_steps, user_id={user_id}") - query = """ - SELECT - s."id" AS step_id, - s."name" AS step_name, - s."type" AS step_type, - s."threadId" AS step_threadid, - s."parentId" AS step_parentid, - s."streaming" AS step_streaming, - s."waitForAnswer" AS step_waitforanswer, - s."isError" AS step_iserror, - s."metadata" AS step_metadata, - s."tags" AS step_tags, - s."input" AS step_input, - s."output" AS step_output, - s."createdAt" AS step_createdat, - s."start" AS step_start, - s."end" AS step_end, - s."generation" AS step_generation, - s."showInput" AS step_showinput, - s."language" AS step_language - FROM steps s - JOIN threads t ON s."threadId" = t.id - WHERE t."userId" = :user_id - AND s."metadata" LIKE :favorite_pattern - ORDER BY s."createdAt" DESC \ - """ - - result = await self.execute_sql( - query, {"user_id": user_id, "favorite_pattern": '%"favorite": true%'} - ) - - steps = [] - if isinstance(result, list): - for row in result: - metadata_raw = row["step_metadata"] - meta_dict = {} - if isinstance(metadata_raw, str): - try: - meta_dict = json.loads(metadata_raw) - except Exception: - pass - elif isinstance(metadata_raw, dict): - meta_dict = metadata_raw + async with self._session() as session: + stmt = ( + select(StepModel) + .join(ThreadModel, StepModel.threadId == ThreadModel.id) + .where(ThreadModel.userId == user_id) + .order_by(StepModel.createdAt.desc()) + ) + result = await session.execute(stmt) + all_steps = result.scalars().all() + steps = [] + for step in all_steps: + meta_dict = self._parse_json(step.metadata_) if meta_dict.get("favorite"): steps.append( StepDict( - id=row["step_id"], - name=row["step_name"], - type=row["step_type"], - threadId=row["step_threadid"], - parentId=row["step_parentid"], - streaming=row.get("step_streaming", False), - waitForAnswer=row.get("step_waitforanswer"), - isError=row.get("step_iserror"), + id=step.id, + name=step.name, + type=step.type, + threadId=step.threadId, + parentId=step.parentId, + streaming=step.streaming or False, + waitForAnswer=step.waitForAnswer, + isError=step.isError, metadata=meta_dict, - tags=row.get("step_tags"), - input=( - row.get("step_input", "") - if row.get("step_showinput") not in [None, "false"] - else "" - ), - output=row.get("step_output", ""), - createdAt=row.get("step_createdat"), - start=row.get("step_start"), - end=row.get("step_end"), - generation=row.get("step_generation"), - showInput=row.get("step_showinput"), - language=row.get("step_language"), + tags=json.loads(step.tags) if step.tags else None, + input=step.input + if step.showInput not in (None, "false") + else "", + output=step.output or "", + createdAt=step.createdAt, + start=step.start, + end=step.end, + generation=self._parse_json(step.generation), + showInput=step.showInput, + language=step.language, feedback=None, ) ) - return steps + return steps async def close(self) -> None: if self.storage_provider: diff --git a/backend/tests/data/test_sql_alchemy.py b/backend/tests/data/test_sql_alchemy.py index decd5e34c5..e34eb49c81 100644 --- a/backend/tests/data/test_sql_alchemy.py +++ b/backend/tests/data/test_sql_alchemy.py @@ -2,8 +2,6 @@ from pathlib import Path import pytest -from sqlalchemy import text -from sqlalchemy.ext.asyncio import create_async_engine from chainlit import User from chainlit.data.sql_alchemy import SQLAlchemyDataLayer @@ -16,109 +14,10 @@ async def data_layer(mock_storage_client: BaseStorageClient, tmp_path: Path): db_file = tmp_path / "test_db.sqlite" conninfo = f"sqlite+aiosqlite:///{db_file}" - # Create async engine - engine = create_async_engine(conninfo) - - # Execute initialization statements - # Ref: https://docs.chainlit.io/data-persistence/custom#sql-alchemy-data-layer - async with engine.begin() as conn: - await conn.execute( - text( - """ - CREATE TABLE users ( - "id" UUID PRIMARY KEY, - "identifier" TEXT NOT NULL UNIQUE, - "metadata" JSONB NOT NULL, - "createdAt" TEXT - ); - """ - ) - ) - - await conn.execute( - text( - """ - CREATE TABLE IF NOT EXISTS threads ( - "id" UUID PRIMARY KEY, - "createdAt" TEXT, - "name" TEXT, - "userId" UUID, - "userIdentifier" TEXT, - "tags" TEXT[], - "metadata" JSONB, - FOREIGN KEY ("userId") REFERENCES users("id") ON DELETE CASCADE - ); - """ - ) - ) - - await conn.execute( - text( - """ - CREATE TABLE IF NOT EXISTS steps ( - "id" UUID PRIMARY KEY, - "name" TEXT NOT NULL, - "type" TEXT NOT NULL, - "threadId" UUID NOT NULL, - "parentId" UUID, - "disableFeedback" BOOLEAN NOT NULL, - "streaming" BOOLEAN NOT NULL, - "waitForAnswer" BOOLEAN, - "isError" BOOLEAN, - "metadata" JSONB, - "tags" TEXT[], - "input" TEXT, - "output" TEXT, - "createdAt" TEXT, - "start" TEXT, - "end" TEXT, - "generation" JSONB, - "showInput" TEXT, - "language" TEXT, - "indent" INT - ); - """ - ) - ) - - await conn.execute( - text( - """ - CREATE TABLE IF NOT EXISTS elements ( - "id" UUID PRIMARY KEY, - "threadId" UUID, - "type" TEXT, - "url" TEXT, - "chainlitKey" TEXT, - "name" TEXT NOT NULL, - "display" TEXT, - "objectKey" TEXT, - "size" TEXT, - "page" INT, - "language" TEXT, - "forId" UUID, - "mime" TEXT - ); - """ - ) - ) - - await conn.execute( - text( - """ - CREATE TABLE IF NOT EXISTS feedbacks ( - "id" UUID PRIMARY KEY, - "forId" UUID NOT NULL, - "threadId" UUID NOT NULL, - "value" INT NOT NULL, - "comment" TEXT - ); - """ - ) - ) - - # Create SQLAlchemyDataLayer instance - data_layer = SQLAlchemyDataLayer(conninfo, storage_provider=mock_storage_client) + # Create SQLAlchemyDataLayer instance with automatic table creation + data_layer = SQLAlchemyDataLayer( + conninfo, storage_provider=mock_storage_client, create_tables=True + ) return data_layer From 8961c6434b344b5639a120589c9aadc90dc101da Mon Sep 17 00:00:00 2001 From: Haze Date: Fri, 16 Jan 2026 19:17:30 -0600 Subject: [PATCH 2/4] fix mypy --- backend/chainlit/data/sql_alchemy.py | 37 +++++++++------------------- 1 file changed, 11 insertions(+), 26 deletions(-) diff --git a/backend/chainlit/data/sql_alchemy.py b/backend/chainlit/data/sql_alchemy.py index b9ae0ad46e..fd9f2fd175 100644 --- a/backend/chainlit/data/sql_alchemy.py +++ b/backend/chainlit/data/sql_alchemy.py @@ -112,34 +112,19 @@ async def _session(self): def _get_upsert( self, model: type, values: dict[str, Any], index_elements: list[str] ) -> Any: - if self._dialect_name == "postgresql": - stmt = pg_insert(model).values(**values) - update_dict = { - k: stmt.excluded[k] for k in values if k not in index_elements - } - return stmt.on_conflict_do_update( - index_elements=index_elements, set_=update_dict - ) - elif self._dialect_name == "sqlite": - stmt = sqlite_insert(model).values(**values) # type: ignore[arg-type] - update_dict = { - k: stmt.excluded[k] for k in values if k not in index_elements - } - return stmt.on_conflict_do_update( - index_elements=index_elements, set_=update_dict - ) - elif self._dialect_name in ("mysql", "mariadb"): - stmt = mysql_insert(model).values(**values) # type: ignore[arg-type] + # MySQL/MariaDB uses different upsert syntax + if self._dialect_name in ("mysql", "mariadb"): + stmt = mysql_insert(model).values(**values) # type: ignore[assignment] update_dict = {k: v for k, v in values.items() if k not in index_elements} return stmt.on_duplicate_key_update(**update_dict) # type: ignore[attr-defined] - else: - stmt = pg_insert(model).values(**values) - update_dict = { - k: stmt.excluded[k] for k in values if k not in index_elements - } - return stmt.on_conflict_do_update( - index_elements=index_elements, set_=update_dict - ) + + # PostgreSQL, SQLite, and others use on_conflict_do_update + insert_fn = sqlite_insert if self._dialect_name == "sqlite" else pg_insert + stmt = insert_fn(model).values(**values) # type: ignore[assignment] + update_dict = {k: stmt.excluded[k] for k in values if k not in index_elements} # type: ignore[attr-defined] + return stmt.on_conflict_do_update( # type: ignore[attr-defined] + index_elements=index_elements, set_=update_dict + ) async def build_debug_url(self) -> str: return "" From 9e2c3b0fd05b8aaf2ee8ab064c5fa2695c70cad9 Mon Sep 17 00:00:00 2001 From: Josh Hayes <35790761+hayescode@users.noreply.github.com> Date: Fri, 16 Jan 2026 19:22:43 -0600 Subject: [PATCH 3/4] Update backend/chainlit/data/sql_alchemy.py Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com> --- backend/chainlit/data/sql_alchemy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/chainlit/data/sql_alchemy.py b/backend/chainlit/data/sql_alchemy.py index fd9f2fd175..f2b6973696 100644 --- a/backend/chainlit/data/sql_alchemy.py +++ b/backend/chainlit/data/sql_alchemy.py @@ -612,7 +612,7 @@ async def create_element(self, element: Element): "forId": element.for_id, "mime": element.mime, "props": json.dumps(element_dict.get("props", {})), - "autoPlay": getattr(element, "autoPlay", None), + "autoPlay": element_dict.get("autoPlay"), "playerConfig": getattr(element, "playerConfig", None), } values = {k: v for k, v in values.items() if v is not None} From 469274c210d66b2a46129c4f6e6d5de0a514bed0 Mon Sep 17 00:00:00 2001 From: Josh Hayes <35790761+hayescode@users.noreply.github.com> Date: Fri, 16 Jan 2026 19:22:55 -0600 Subject: [PATCH 4/4] Update backend/chainlit/data/sql_alchemy.py Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com> --- backend/chainlit/data/sql_alchemy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/chainlit/data/sql_alchemy.py b/backend/chainlit/data/sql_alchemy.py index f2b6973696..45cf88debe 100644 --- a/backend/chainlit/data/sql_alchemy.py +++ b/backend/chainlit/data/sql_alchemy.py @@ -613,7 +613,7 @@ async def create_element(self, element: Element): "mime": element.mime, "props": json.dumps(element_dict.get("props", {})), "autoPlay": element_dict.get("autoPlay"), - "playerConfig": getattr(element, "playerConfig", None), + "playerConfig": element_dict.get("playerConfig"), } values = {k: v for k, v in values.items() if v is not None} if "name" not in values: