diff --git a/database.py b/database.py index aef5771..2d86e10 100644 --- a/database.py +++ b/database.py @@ -4,6 +4,7 @@ from sklearn.metrics.pairwise import cosine_similarity from sqlalchemy import and_, delete, select +from sqlalchemy.dialects.postgresql import insert from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.orm import sessionmaker @@ -44,14 +45,10 @@ async def get_or_create_user(self, user_id: str, session: Optional[AsyncSession] async def _get_or_create_user_in_session(self, session: AsyncSession, user_id: str) -> None: """Helper to get or create user within a given session""" - query = select(User).where(User.id == user_id) - result = await session.execute(query) - user = result.scalar_one_or_none() - - if not user: - user = User(id=user_id) - session.add(user) - await session.commit() + # Use ON CONFLICT DO NOTHING to atomically create the user if they don't exist + insert_stmt = insert(User).values(id=user_id).on_conflict_do_nothing(index_elements=['id']) + await session.execute(insert_stmt) + await session.commit() async def add_conversation(self, user_id: str, message: str, response: str, language: str, embedding: List[float] = None,