Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions taf/tests/test_updater/test_interrupt_handling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import os
import signal
import pytest
from pathlib import Path

from taf.exceptions import UpdateFailedError
from taf.updater.updater import update_repository
from taf.updater.updater import UpdateConfig
from taf.updater.types.update import OperationType
from taf.git import GitRepository


class DummyRepo:
"""Fake empty git repo used for path presence."""
def __init__(self, path):
self.path = path
self.name = "dummy"
self.is_git_repository = True
self.is_bare_repository = False

def get_remote_url(self):
return "https://example.com/repo.git"


@pytest.fixture
def fake_repo(tmp_path, monkeypatch):
"""Create a tmp directory and mock GitRepository so no actual git commands run."""
repo_path = tmp_path / "auth"
repo_path.mkdir(parents=True, exist_ok=True)

monkeypatch.setattr("taf.updater.updater.GitRepository", lambda path: DummyRepo(path))

return repo_path


def test_update_repository_interrupts_gracefully(fake_repo, monkeypatch):
"""
Simulate Ctrl+C (SIGINT) during update_repository and ensure UpdateFailedError is raised.
"""

config = UpdateConfig(
path=fake_repo,
operation=OperationType.UPDATE,
remote_url="https://example.com/repo.git",
strict=True,
run_scripts=False,
)

def fake_update(*args, **kwargs):
# Immediately simulate Ctrl+C by sending SIGINT to the current process
os.kill(os.getpid(), signal.SIGINT)

monkeypatch.setattr("taf.updater.updater._update_or_clone_repository", fake_update)

with pytest.raises(UpdateFailedError, match="interrupted"):
update_repository(config)
51 changes: 34 additions & 17 deletions taf/updater/updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
loads data from a most recent commit.
"""
import copy
import signal
from logging import ERROR

from typing import Dict, Tuple, Any
Expand Down Expand Up @@ -327,29 +328,45 @@ def update_repository(config: UpdateConfig):
Returns:
None
"""
settings.strict = config.strict
settings.run_scripts = config.run_scripts

# if path is not specified, name should be read from info.json
# which is available after the remote repository is cloned and validated
# --- Graceful Ctrl+C / SIGTERM handling ---
interrupted = {"flag": False}

auth_repo = GitRepository(path=config.path)
if not config.path.is_dir() or not auth_repo.is_git_repository:
raise UpdateFailedError(
f"{config.path} is not a Git repository. Run 'taf repo clone' instead"
)
def _handle_interrupt(signum, frame):
interrupted["flag"] = True
raise UpdateFailedError("Update interrupted by user (Ctrl+C)")

taf_logger.info(f"Updating repository {auth_repo.name}")
old_sigint = signal.signal(signal.SIGINT, _handle_interrupt)
old_sigterm = signal.signal(signal.SIGTERM, _handle_interrupt)

try:
# --- BEGIN ORIGINAL UNMODIFIED LOGIC ---
settings.strict = config.strict
settings.run_scripts = config.run_scripts

auth_repo = GitRepository(path=config.path)
if not config.path.is_dir() or not auth_repo.is_git_repository:
raise UpdateFailedError(
f"{config.path} is not a Git repository. Run 'taf repo clone' instead"
)

taf_logger.info(f"Updating repository {auth_repo.name}")

if config.remote_url is None:
config.remote_url = auth_repo.get_remote_url()
if config.remote_url is None:
raise UpdateFailedError("URL cannot be determined. Please specify it")
config.remote_url = auth_repo.get_remote_url()
if config.remote_url is None:
raise UpdateFailedError("URL cannot be determined. Please specify it")

if auth_repo.is_bare_repository:
# Handle updates for bare repositories
config.bare = True
return _update_or_clone_repository(config)
if auth_repo.is_bare_repository:
config.bare = True

return _update_or_clone_repository(config)
# --- END ORIGINAL LOGIC ---

finally:
# Restore original handlers so we do not break global signal behavior
signal.signal(signal.SIGINT, old_sigint)
signal.signal(signal.SIGTERM, old_sigterm)


def _update_or_clone_repository(config: UpdateConfig):
Expand Down