diff --git a/src/manage/commands.py b/src/manage/commands.py index fbc84d7..74b773b 100644 --- a/src/manage/commands.py +++ b/src/manage/commands.py @@ -39,8 +39,6 @@ WELCOME = f"""!B!Python install manager was successfully updated to {__version__}.!W! -!Y!Start menu shortcuts have been changed in this update.!W! -Run !G!py install --refresh!W! to update any existing shortcuts. """ # The 'py help' or 'pymanager help' output is constructed by these default docs, @@ -254,6 +252,7 @@ def execute(self): "enable_shortcut_kinds": (str, config_split_append), "disable_shortcut_kinds": (str, config_split_append), "default_install_tag": (str, None), + "preserve_site_on_upgrade": (config_bool, None), }, "first_run": { @@ -794,6 +793,7 @@ class InstallCommand(BaseCommand): enable_shortcut_kinds = None disable_shortcut_kinds = None default_install_tag = None + preserve_site_on_upgrade = True def __init__(self, args, root=None): super().__init__(args, root) diff --git a/src/manage/install_command.py b/src/manage/install_command.py index 2edd97f..eae2cc7 100644 --- a/src/manage/install_command.py +++ b/src/manage/install_command.py @@ -458,6 +458,87 @@ def _download_one(cmd, source, install, download_dir, *, must_copy=False): return package +def _preserve_site(cmd, root): + if not root.is_dir(): + return None + if not cmd.preserve_site_on_upgrade: + LOGGER.verbose("Not preserving site directory because of config") + return None + if cmd.force: + LOGGER.verbose("Not preserving site directory because of --force") + return None + if cmd.repair: + LOGGER.verbose("Not preserving site directory because of --repair") + return None + state = [] + i = 0 + dirs = [root] + root = root.with_name(f"_{root.name}") + root.mkdir(parents=True, exist_ok=True) + while dirs: + if dirs[0].match("site-packages"): + while True: + target = root / str(i) + i += 1 + try: + unlink(target) + break + except FileNotFoundError: + break + except OSError: + LOGGER.verbose("Failed to remove %s.", target) + LOGGER.info("Preserving %s during update as %s.", dirs[0], target) + try: + dirs[0].rename(target) + except OSError: + LOGGER.warn("Failed to preserve %s during update.", dirs[0]) + LOGGER.verbose("TRACEBACK", exc_info=True) + else: + state.append((dirs[0], target)) + else: + dirs.extend(d for d in dirs[0].iterdir() if d.is_dir()) + dirs.pop(0) + # Append None, root last so that root gets cleaned up after restore is done + state.append((None, root)) + return state + + +def _restore_site(cmd, state): + if not state: + return + for dest, src in state: + if not dest: + LOGGER.verbose("Removing preserved directory at %s", src) + try: + rmtree( + src, + "Removing temporary files is taking some time. " + + "You can continue to wait or press Ctrl+C to abort. " + + "Python has been installed, but some harmless temporary " + + "files may remain on disk." + ) + except KeyboardInterrupt: + break + continue + LOGGER.info("Restoring %s from %s after update.", dest, src) + try: + for i in src.iterdir(): + if not i.is_dir() and not i.is_file(): + LOGGER.verbose("Not restoring %s because it is not a " + + "normal file or directory.", i) + continue + d = dest / i.name + if d.exists(): + LOGGER.verbose("Not restoring %s because %s exists", i, d) + continue + LOGGER.verbose("Restoring %s to %s", i, d) + d.parent.mkdir(parents=True, exist_ok=True) + i.rename(d) + except OSError: + LOGGER.warn("Failed to restore %s during update.", dest) + LOGGER.verbose("TRACEBACK", exc_info=True) + + def _install_one(cmd, source, install, *, target=None): if cmd.repair: LOGGER.info("Repairing %s.", install['display-name']) @@ -475,6 +556,8 @@ def _install_one(cmd, source, install, *, target=None): dest = target or (cmd.install_dir / install["id"]) + preserved_site = _preserve_site(cmd, dest) + LOGGER.verbose("Extracting %s to %s", package, dest) if not cmd.repair: try: @@ -544,6 +627,8 @@ def _install_one(cmd, source, install, *, target=None): with open(dest / "__install__.json", "w", encoding="utf-8") as f: json.dump(install, f, default=str) + _restore_site(cmd, preserved_site) + LOGGER.verbose("Install complete") @@ -560,7 +645,6 @@ def _merge_existing_index(versions, index_json): else: LOGGER.debug("Merging into existing %s", index_json) current = {i["url"].casefold() for i in versions} - added = [] for install in existing_index["versions"]: if install.get("url", "").casefold() not in current: LOGGER.debug("Merging %s", install.get("url", "")) diff --git a/tests/test_install_command.py b/tests/test_install_command.py index 6c0cae3..2e0f445 100644 --- a/tests/test_install_command.py +++ b/tests/test_install_command.py @@ -230,3 +230,62 @@ def test_merge_existing_index_not_valid(tmp_path): new = [1, 2, 3] IC._merge_existing_index(new, existing) assert new == [1, 2, 3] + + +def test_preserve_site(tmp_path): + root = tmp_path / "root" + preserved = tmp_path / "_root" + site = root / "site-packages" + not_site = root / "site-not-packages" + A = site / "A" + B = site / "B.txt" + C = site / "C.txt" + A.mkdir(parents=True, exist_ok=True) + B.write_bytes(b"") + C.write_bytes(b"original") + + class Cmd: + preserve_site_on_upgrade = False + force = False + repair = False + + state = IC._preserve_site(Cmd, root) + assert not state + assert not preserved.exists() + Cmd.preserve_site_on_upgrade = True + Cmd.force = True + state = IC._preserve_site(Cmd, root) + assert not state + assert not preserved.exists() + Cmd.force = False + Cmd.repair = True + state = IC._preserve_site(Cmd, root) + assert not state + assert not preserved.exists() + + Cmd.repair = False + state = IC._preserve_site(Cmd, root) + assert state == [(site, preserved / "0"), (None, preserved)] + assert preserved.is_dir() + + root.rename(root.parent / "ex_root_1") + IC._restore_site(Cmd, state) + assert root.is_dir() + assert A.is_dir() + assert B.is_file() + assert C.is_file() + assert b"original" == C.read_bytes() + assert not preserved.exists() + + state = IC._preserve_site(Cmd, root) + assert state == [(site, preserved / "0"), (None, preserved)] + + assert not C.exists() + C.parent.mkdir(parents=True, exist_ok=True) + C.write_bytes(b"updated") + IC._restore_site(Cmd, state) + assert A.is_dir() + assert B.is_file() + assert C.is_file() + assert b"updated" == C.read_bytes() + assert not preserved.exists()