diff --git a/openviking/session/session.py b/openviking/session/session.py index bdb6500b..13b3f565 100644 --- a/openviking/session/session.py +++ b/openviking/session/session.py @@ -226,12 +226,14 @@ def commit(self) -> Dict[str, Any]: async def commit_async(self) -> Dict[str, Any]: """Async commit session: two-phase approach. - Phase 1 (Archive): Write archive, clear messages. + Phase 1 (Archive prep, PathLock-protected): Copy messages, clear live + session, increment compression index. Uses a distributed filesystem lock + (PathLock) so this works across workers and processes. Phase 2 (Memory, redo-log protected): Extract memories, write, enqueue. """ import uuid - from openviking.storage.transaction import get_lock_manager + from openviking.storage.transaction import LockContext, get_lock_manager result = { "session_id": self.session_id, @@ -241,19 +243,34 @@ async def commit_async(self) -> Dict[str, Any]: "archived": False, "stats": None, } - if not self._messages: - get_current_telemetry().set("memory.extracted", 0) - return result - # ===== Preparation ===== - self._compression.compression_index += 1 - messages_to_archive = self._messages.copy() + # ===== Phase 1: Snapshot + clear (PathLock-protected) ===== + # Use filesystem-based distributed lock so this works across workers/processes. + session_path = self._viking_fs._uri_to_path(self._session_uri, ctx=self.ctx) + async with LockContext(get_lock_manager(), [session_path], lock_mode="point"): + if not self._messages: + get_current_telemetry().set("memory.extracted", 0) + return result + + self._compression.compression_index += 1 + messages_to_archive = self._messages.copy() + self._messages.clear() + try: + await self._write_to_agfs_async(messages=[]) + except Exception: + # Rollback: restore messages so they aren't lost + self._messages.extend(messages_to_archive) + self._compression.compression_index -= 1 + raise + # Lock released — live session is now clean. 
+ # Any add_message() from here appends to the fresh empty list. + + # ===== Phase 1 continued: Archive write (no lock needed) ===== summary = await self._generate_archive_summary_async(messages_to_archive) archive_abstract = self._extract_abstract_from_summary(summary) archive_overview = summary - # ===== Phase 1: Archive (no lock) ===== archive_uri = ( f"{self._session_uri}/history/archive_{self._compression.compression_index:03d}" ) @@ -263,8 +280,6 @@ async def commit_async(self) -> Dict[str, Any]: abstract=archive_abstract, overview=archive_overview, ) - await self._write_to_agfs_async(messages=[]) - self._messages.clear() self._compression.original_count += len(messages_to_archive) result["archived"] = True diff --git a/tests/session/test_session_commit_race.py b/tests/session/test_session_commit_race.py new file mode 100644 index 00000000..e71a471f --- /dev/null +++ b/tests/session/test_session_commit_race.py @@ -0,0 +1,67 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. 
# SPDX-License-Identifier: Apache-2.0

"""Tests for session commit race condition fix (#580)."""

import asyncio

from openviking import AsyncOpenViking
from openviking.message import TextPart


class TestCommitRace:
    """Concurrency tests for ``commit_async`` on a single session object.

    Both tests take a ``client`` argument, which is expected to be supplied
    by a pytest fixture defined outside this module.
    """

    async def test_concurrent_commit_no_duplicate(self, client: AsyncOpenViking):
        """Two concurrent commits on the same session: only one should archive."""
        session = client.session(session_id="race_test_dedup")
        session.add_message("user", [TextPart("Hello")])
        session.add_message("assistant", [TextPart("Hi there")])

        # Launch both commits together so they contend for the same session.
        results = await asyncio.gather(
            session.commit_async(),
            session.commit_async(),
        )

        archived_count = sum(1 for r in results if r.get("archived") is True)
        assert archived_count == 1, f"Expected exactly 1 archived commit, got {archived_count}"

        # Messages should be cleared after commit
        assert len(session.messages) == 0

        # Compression index should have incremented exactly once
        # (assumes a freshly created session starts at index 0).
        assert session._compression.compression_index == 1

    async def test_message_added_during_commit_not_lost(self, client: AsyncOpenViking):
        """Messages added while commit is running should not be lost."""
        session = client.session(session_id="race_test_msg_safety")
        session.add_message("user", [TextPart("Original message")])

        # Two events give a deterministic hand-shake instead of sleeps:
        #   phase1_done   - commit has reached summary generation, i.e. it has
        #                   left the locked snapshot/clear step.
        #   message_added - the test has appended its new message, so the
        #                   commit may continue.
        phase1_done = asyncio.Event()
        message_added = asyncio.Event()
        original_generate = session._generate_archive_summary_async

        async def gated_generate(messages):
            # Signal that the snapshot/clear step is over, then hold the commit
            # here until the new message has really been added.
            phase1_done.set()
            await message_added.wait()
            return await original_generate(messages)

        session._generate_archive_summary_async = gated_generate

        commit_task = asyncio.create_task(session.commit_async())
        phase1_waiter = asyncio.create_task(phase1_done.wait())
        try:
            # Wake up on whichever happens first: the commit reaching summary
            # generation, or the commit finishing/failing early. Waiting on the
            # event alone would hang forever if the commit raised beforehand.
            await asyncio.wait(
                {commit_task, phase1_waiter},
                return_when=asyncio.FIRST_COMPLETED,
            )
            if not phase1_done.is_set():
                # Re-raises the commit's own exception when it failed.
                await commit_task
                raise AssertionError(
                    "commit_async finished without generating an archive summary"
                )

            # Add a message while the commit is paused in its archive step
            # (after the lock was released, before the archive is written).
            session.add_message("user", [TextPart("New message during commit")])
        finally:
            phase1_waiter.cancel()
            # Always release the gate so the commit task can never stay blocked.
            message_added.set()

        result = await commit_task

        assert result.get("archived") is True
        # The new message should still be in the session
        assert len(session.messages) == 1
        assert session.messages[0].content == "New message during commit"