Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 26 additions & 11 deletions openviking/session/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,12 +226,14 @@ def commit(self) -> Dict[str, Any]:
async def commit_async(self) -> Dict[str, Any]:
"""Async commit session: two-phase approach.

Phase 1 (Archive): Write archive, clear messages.
Phase 1 (Archive prep, PathLock-protected): Copy messages, clear live
session, increment compression index. Uses a distributed filesystem lock
(PathLock) so this works across workers and processes.
Phase 2 (Memory, redo-log protected): Extract memories, write, enqueue.
"""
import uuid

from openviking.storage.transaction import get_lock_manager
from openviking.storage.transaction import LockContext, get_lock_manager

result = {
"session_id": self.session_id,
Expand All @@ -241,19 +243,34 @@ async def commit_async(self) -> Dict[str, Any]:
"archived": False,
"stats": None,
}
if not self._messages:
get_current_telemetry().set("memory.extracted", 0)
return result

# ===== Preparation =====
self._compression.compression_index += 1
messages_to_archive = self._messages.copy()
# ===== Phase 1: Snapshot + clear (PathLock-protected) =====
# Use filesystem-based distributed lock so this works across workers/processes.
session_path = self._viking_fs._uri_to_path(self._session_uri, ctx=self.ctx)
async with LockContext(get_lock_manager(), [session_path], lock_mode="point"):
if not self._messages:
get_current_telemetry().set("memory.extracted", 0)
return result

self._compression.compression_index += 1
messages_to_archive = self._messages.copy()
self._messages.clear()

try:
await self._write_to_agfs_async(messages=[])
except Exception:
# Rollback: restore messages so they aren't lost
self._messages.extend(messages_to_archive)
self._compression.compression_index -= 1
raise
# Lock released — live session is now clean.
# Any add_message() from here appends to the fresh empty list.

# ===== Phase 1 continued: Archive write (no lock needed) =====
summary = await self._generate_archive_summary_async(messages_to_archive)
archive_abstract = self._extract_abstract_from_summary(summary)
archive_overview = summary

# ===== Phase 1: Archive (no lock) =====
archive_uri = (
f"{self._session_uri}/history/archive_{self._compression.compression_index:03d}"
)
Expand All @@ -263,8 +280,6 @@ async def commit_async(self) -> Dict[str, Any]:
abstract=archive_abstract,
overview=archive_overview,
)
await self._write_to_agfs_async(messages=[])
self._messages.clear()

self._compression.original_count += len(messages_to_archive)
result["archived"] = True
Expand Down
67 changes: 67 additions & 0 deletions tests/session/test_session_commit_race.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
# SPDX-License-Identifier: Apache-2.0

"""Tests for session commit race condition fix (#580)."""

import asyncio

from openviking import AsyncOpenViking
from openviking.message import TextPart


class TestCommitRace:
    """Concurrency checks for ``commit_async`` on a single session (issue #580)."""

    async def test_concurrent_commit_no_duplicate(self, client: AsyncOpenViking):
        """Of two overlapping commits on one session, exactly one may archive."""
        session = client.session(session_id="race_test_dedup")
        for role, text in (("user", "Hello"), ("assistant", "Hi there")):
            session.add_message(role, [TextPart(text)])

        # Start both commits together so they contend for the same session.
        first, second = await asyncio.gather(
            session.commit_async(),
            session.commit_async(),
        )

        archived_count = len([r for r in (first, second) if r.get("archived") is True])
        assert archived_count == 1, f"Expected exactly 1 archived commit, got {archived_count}"

        # The winning commit must have emptied the live message list ...
        assert len(session.messages) == 0

        # ... and bumped the compression index once, not once per caller.
        assert session._compression.compression_index == 1

    async def test_message_added_during_commit_not_lost(self, client: AsyncOpenViking):
        """A message appended while a commit is in flight must stay in the session."""
        session = client.session(session_id="race_test_msg_safety")
        session.add_message("user", [TextPart("Original message")])

        # An Event gives a deterministic hand-off point; no timing-based sleeps.
        snapshot_taken = asyncio.Event()
        real_generate = session._generate_archive_summary_async

        async def slow_generate(messages):
            # The test relies on commit_async reaching summary generation only
            # after it has snapshotted and cleared the live messages, so getting
            # here means the clear step is already done.
            snapshot_taken.set()
            # Hand control back to the event loop so the test body can append
            # its message before the archive step finishes.
            await asyncio.sleep(0)
            summary = await real_generate(messages)
            return summary

        # Swap in the instrumented generator on this session instance only.
        session._generate_archive_summary_async = slow_generate

        commit_task = asyncio.create_task(session.commit_async())
        # Block until the commit has passed the snapshot-and-clear step.
        await snapshot_taken.wait()
        # Append while the commit is still busy writing its archive.
        session.add_message("user", [TextPart("New message during commit")])
        result = await commit_task

        assert result.get("archived") is True
        # Only the late message should remain; the original one was archived.
        assert len(session.messages) == 1
        assert session.messages[0].content == "New message during commit"
Loading