diff --git a/agent/core/session_uploader.py b/agent/core/session_uploader.py index 404fd224..989e3f0b 100644 --- a/agent/core/session_uploader.py +++ b/agent/core/session_uploader.py @@ -20,6 +20,7 @@ import json import os import sys +from contextlib import contextmanager from datetime import datetime from pathlib import Path from typing import Any @@ -300,16 +301,37 @@ def _url_field(format: str) -> str: return "personal_upload_url" if format == "claude_code" else "upload_url" -def _read_session_file(session_file: str) -> dict: - """Read a local session file while respecting uploader file locks.""" +@contextmanager +def _file_lock(file_obj, *, exclusive: bool): + """Lock a session file across uploader processes on Unix and Windows.""" + if os.name == "nt": + import msvcrt + + file_obj.seek(0) + mode = msvcrt.LK_LOCK + try: + msvcrt.locking(file_obj.fileno(), mode, 1) + yield + finally: + file_obj.seek(0) + msvcrt.locking(file_obj.fileno(), msvcrt.LK_UNLCK, 1) + return + import fcntl + lock = fcntl.LOCK_EX if exclusive else fcntl.LOCK_SH + fcntl.flock(file_obj, lock) + try: + yield + finally: + fcntl.flock(file_obj, fcntl.LOCK_UN) + + +def _read_session_file(session_file: str) -> dict: + """Read a local session file while respecting uploader file locks.""" with open(session_file, "r") as f: - fcntl.flock(f, fcntl.LOCK_SH) - try: + with _file_lock(f, exclusive=False): return json.load(f) - finally: - fcntl.flock(f, fcntl.LOCK_UN) def _update_upload_status( @@ -325,11 +347,8 @@ def _update_upload_status( local session JSON file. Re-read under an exclusive lock so one uploader cannot clobber fields written by the other. """ - import fcntl - with open(session_file, "r+") as f: - fcntl.flock(f, fcntl.LOCK_EX) - try: + with _file_lock(f, exclusive=True): data = json.load(f) data[status_key] = status if dataset_url is not None: @@ -340,8 +359,6 @@ def _update_upload_status( f.truncate() f.flush() os.fsync(f.fileno()) - finally: - fcntl.flock(f, fcntl.LOCK_UN) def dataset_card_readme(repo_id: str) -> str: diff --git a/tests/unit/test_prioritize_backlog.py b/tests/unit/test_prioritize_backlog.py index 9a8fd316..fd6dce2f 100644 --- a/tests/unit/test_prioritize_backlog.py +++ b/tests/unit/test_prioritize_backlog.py @@ -718,4 +718,4 @@ def test_cli_defaults_without_live_network_or_llm(): assert args.github_report_label == mod.DEFAULT_GITHUB_REPORT_LABEL assert args.output_dir is None assert out.name == "20260504T123000Z" - assert "scratch/backlog-prioritization" in str(out) + assert out.parts[-3:-1] == ("scratch", "backlog-prioritization")