diff --git a/taf/tests/test_updater/test_interrupt_handling.py b/taf/tests/test_updater/test_interrupt_handling.py new file mode 100644 index 00000000..ee2ffd26 --- /dev/null +++ b/taf/tests/test_updater/test_interrupt_handling.py @@ -0,0 +1,56 @@ +import os +import signal +import pytest +from pathlib import Path + +from taf.exceptions import UpdateFailedError +from taf.updater.updater import update_repository +from taf.updater.updater import UpdateConfig +from taf.updater.types.update import OperationType +from taf.git import GitRepository + + +class DummyRepo: + """Fake empty git repo used for path presence.""" + def __init__(self, path): + self.path = path + self.name = "dummy" + self.is_git_repository = True + self.is_bare_repository = False + + def get_remote_url(self): + return "https://example.com/repo.git" + + +@pytest.fixture +def fake_repo(tmp_path, monkeypatch): + """Create a tmp directory and mock GitRepository so no actual git commands run.""" + repo_path = tmp_path / "auth" + repo_path.mkdir(parents=True, exist_ok=True) + + monkeypatch.setattr("taf.updater.updater.GitRepository", lambda path: DummyRepo(path)) + + return repo_path + + +def test_update_repository_interrupts_gracefully(fake_repo, monkeypatch): + """ + Simulate Ctrl+C (SIGINT) during update_repository and ensure UpdateFailedError is raised. + """ + + config = UpdateConfig( + path=fake_repo, + operation=OperationType.UPDATE, + remote_url="https://example.com/repo.git", + strict=True, + run_scripts=False, + ) + + def fake_update(*args, **kwargs): + # Immediately simulate Ctrl+C by sending SIGINT to the current process + os.kill(os.getpid(), signal.SIGINT) + + monkeypatch.setattr("taf.updater.updater._update_or_clone_repository", fake_update) + + with pytest.raises(UpdateFailedError, match="interrupted"): + update_repository(config) diff --git a/taf/updater/updater.py b/taf/updater/updater.py index 6315b6d8..f06f487c 100644 --- a/taf/updater/updater.py +++ b/taf/updater/updater.py @@ -26,6 +26,7 @@ loads data from a most recent commit. """ import copy +import signal from logging import ERROR from typing import Dict, Tuple, Any @@ -327,29 +328,45 @@ def update_repository(config: UpdateConfig): Returns: None """ - settings.strict = config.strict - settings.run_scripts = config.run_scripts - # if path is not specified, name should be read from info.json - # which is available after the remote repository is cloned and validated + # --- Graceful Ctrl+C / SIGTERM handling --- + interrupted = {"flag": False} - auth_repo = GitRepository(path=config.path) - if not config.path.is_dir() or not auth_repo.is_git_repository: - raise UpdateFailedError( - f"{config.path} is not a Git repository. Run 'taf repo clone' instead" - ) + def _handle_interrupt(signum, frame): + interrupted["flag"] = True + raise UpdateFailedError("Update interrupted by user (Ctrl+C)") - taf_logger.info(f"Updating repository {auth_repo.name}") + old_sigint = signal.signal(signal.SIGINT, _handle_interrupt) + old_sigterm = signal.signal(signal.SIGTERM, _handle_interrupt) + + try: + # --- BEGIN ORIGINAL UNMODIFIED LOGIC --- + settings.strict = config.strict + settings.run_scripts = config.run_scripts + + auth_repo = GitRepository(path=config.path) + if not config.path.is_dir() or not auth_repo.is_git_repository: + raise UpdateFailedError( + f"{config.path} is not a Git repository. Run 'taf repo clone' instead" + ) + + taf_logger.info(f"Updating repository {auth_repo.name}") - if config.remote_url is None: - config.remote_url = auth_repo.get_remote_url() if config.remote_url is None: - raise UpdateFailedError("URL cannot be determined. Please specify it") + config.remote_url = auth_repo.get_remote_url() + if config.remote_url is None: + raise UpdateFailedError("URL cannot be determined. Please specify it") - if auth_repo.is_bare_repository: - # Handle updates for bare repositories - config.bare = True - return _update_or_clone_repository(config) + if auth_repo.is_bare_repository: + config.bare = True + + return _update_or_clone_repository(config) + # --- END ORIGINAL LOGIC --- + + finally: + # Restore original handlers so we do not break global signal behavior + signal.signal(signal.SIGINT, old_sigint) + signal.signal(signal.SIGTERM, old_sigterm) def _update_or_clone_repository(config: UpdateConfig):