diff --git a/pyiron_snippets/files.py b/pyiron_snippets/files.py
index 09ea2ac..ee413d8 100644
--- a/pyiron_snippets/files.py
+++ b/pyiron_snippets/files.py
@@ -1,6 +1,8 @@
 from __future__ import annotations
 
+import tarfile
 from pathlib import Path
+from typing import cast
 
 
 def delete_files_and_directories_recursively(path):
@@ -44,11 +46,12 @@ def __init__(
         self, directory: str | Path | DirectoryObject, protected: bool = False
     ):
         if isinstance(directory, str):
-            self.path = Path(directory)
+            path = Path(directory)
         elif isinstance(directory, Path):
-            self.path = directory
+            path = directory
         elif isinstance(directory, DirectoryObject):
-            self.path = directory.path
+            path = directory.path
+        self.path: Path = path
         self.create()
         self._protected = protected
 
@@ -97,3 +100,65 @@ def remove_files(self, *files: str):
             path = self.get_path(file)
             if path.is_file():
                 path.unlink()
+
+    def compress(self, exclude_files: list[str | Path] | None = None):
+        """
+        Move the files of this directory into a sibling ``<name>.tar.gz``.
+
+        All regular files found recursively below the directory are added to
+        the archive under their path relative to the directory and are removed
+        from disk once the archive is written. Does nothing if the archive
+        already exists.
+
+        Args:
+            exclude_files (list[str | Path] | None): Files to leave in place;
+                relative entries are resolved against this directory.
+        """
+        directory = self.path.resolve()
+        # Append the extension: `with_suffix` would replace an existing suffix
+        # and map e.g. "run.v1" onto the archive of a directory named "run".
+        output_tar_path = directory.with_name(directory.name + ".tar.gz")
+        if output_tar_path.exists():
+            return
+        if exclude_files is None:
+            exclude_files = []
+        else:
+            exclude_files = [Path(f) for f in exclude_files]
+        exclude_set = {
+            f.resolve() if f.is_absolute() else (directory / f).resolve()
+            for f in cast(list[Path], exclude_files)
+        }
+        files_to_delete = []
+        try:
+            with tarfile.open(output_tar_path, "w:gz") as tar:
+                for file in directory.rglob("*"):
+                    if file.is_file() and file.resolve() not in exclude_set:
+                        arcname = file.relative_to(directory)
+                        tar.add(file, arcname=arcname)
+                        files_to_delete.append(file)
+        except Exception:
+            # A truncated archive would block later compress() calls and
+            # break decompress(), so do not leave it behind.
+            output_tar_path.unlink(missing_ok=True)
+            raise
+        # Delete only after the archive has been closed successfully.
+        for file in files_to_delete:
+            file.unlink()
+
+    def decompress(self):
+        """
+        Restore the files stored in the sibling ``<name>.tar.gz`` archive.
+
+        The archive is extracted into this directory and deleted afterwards.
+        Does nothing if the archive does not exist.
+        """
+        directory = self.path.resolve()
+        # Must match the archive name used by compress().
+        tar_path = directory.with_name(directory.name + ".tar.gz")
+        if not tar_path.exists():
+            return
+        with tarfile.open(tar_path, "r:gz") as tar:
+            # NOTE: "fully_trusted" disables tarfile's extraction safety
+            # checks; only archives written by compress() should be here.
+            tar.extractall(path=directory, filter="fully_trusted")
+        tar_path.unlink()
diff --git a/tests/unit/test_files.py b/tests/unit/test_files.py
index 20ab917..cc75548 100644
--- a/tests/unit/test_files.py
+++ b/tests/unit/test_files.py
@@ -1,4 +1,5 @@
 import pickle
+import tarfile
 import unittest
 from pathlib import Path
 
@@ -92,6 +93,44 @@ def test_remove(self):
             msg="Should be able to remove just one file",
         )
 
+    def test_compress(self):
+        archive = Path("test.tar.gz")
+        # Start from a clean slate and never leave the archive behind.
+        archive.unlink(missing_ok=True)
+        self.addCleanup(archive.unlink, missing_ok=True)
+        self.directory.write(file_name="test1.txt", content="something")
+        self.directory.write(file_name="test2.txt", content="something")
+        self.directory.compress(exclude_files=["test1.txt"])
+        self.assertTrue(archive.exists())
+        with tarfile.open(archive, "r:*") as f:
+            content = f.getnames()
+        self.assertNotIn(
+            "test1.txt", content, msg="Excluded file should not be in archive"
+        )
+        self.assertIn(
+            "test2.txt", content, msg="Included file should be in archive"
+        )
+        self.assertFalse(
+            self.directory.file_exists("test2.txt"),
+            msg="Compressed files should not be in the directory",
+        )
+        self.assertTrue(
+            self.directory.file_exists("test1.txt"),
+            msg="Excluded file should still be in the directory",
+        )
+        # Test that compressing again does not raise an error
+        self.directory.compress()
+        self.assertTrue(archive.exists())
+        self.directory.decompress()
+        self.assertTrue(
+            self.directory.file_exists("test2.txt"),
+            msg="Decompressed files should be back in the directory",
+        )
+        self.assertFalse(
+            archive.exists(),
+            msg="Archive should be deleted after decompression",
+        )
+
 
 if __name__ == "__main__":
     unittest.main()