diff --git a/acquire/outputs/tar.py b/acquire/outputs/tar.py
index 69b72906..73237d25 100644
--- a/acquire/outputs/tar.py
+++ b/acquire/outputs/tar.py
@@ -1,6 +1,8 @@
 from __future__ import annotations
 
+import copy
 import io
+import shutil
 import tarfile
 from typing import TYPE_CHECKING, BinaryIO
 
@@ -100,7 +102,49 @@ def write(
         if stat:
             info.mtime = stat.st_mtime
 
-        self.tar.addfile(info, fh)
+        # Inline version of Python stdlib's tarfile.addfile & tarfile.copyfileobj,
+        # to allow for padding and more control over the tar file writing.
+        self.tar._check("awx")
+
+        if fh is None and info.isreg() and info.size != 0:
+            raise ValueError("fileobj not provided for non zero-size regular file")
+
+        info = copy.copy(info)
+
+        buf = info.tobuf(self.tar.format, self.tar.encoding, self.tar.errors)
+        self.tar.fileobj.write(buf)
+        self.tar.offset += len(buf)
+        bufsize = self.tar.copybufsize
+        if fh is not None:
+            bufsize = bufsize or 16 * 1024
+
+            if info.size is None:
+                # Without a size the block padding below cannot be computed, so just stream the data
+                shutil.copyfileobj(fh, self.tar.fileobj, bufsize)
+                return
+
+            # Write exactly info.size bytes of data, at most bufsize bytes at a time, so that the
+            # data always matches the size recorded in the header that was written above.
+            remaining = info.size
+            while remaining > 0:
+                chunk_size = min(bufsize, remaining)
+                # Prevents "long reads": anything beyond the requested size is dropped
+                buf = fh.read(chunk_size)[:chunk_size]
+                if len(buf) < chunk_size:
+                    # PATCH; instead of raising an exception, pad the data to the desired length
+                    buf += tarfile.NUL * (chunk_size - len(buf))
+                self.tar.fileobj.write(buf)
+                remaining -= chunk_size
+
+            # Pad the data to a multiple of the tar block size
+            blocks, remainder = divmod(info.size, tarfile.BLOCKSIZE)
+            if remainder > 0:
+                self.tar.fileobj.write(tarfile.NUL * (tarfile.BLOCKSIZE - remainder))
+                blocks += 1
+            self.tar.offset += blocks * tarfile.BLOCKSIZE
+
+        # Reached for zero-size files too, so that every written member is registered
+        self.tar.members.append(info)
 
     def close(self) -> None:
         """Closes the tar file."""
diff --git a/tests/test_outputs_tar.py b/tests/test_outputs_tar.py
index 81059bb2..49d00f86 100644
--- a/tests/test_outputs_tar.py
+++ b/tests/test_outputs_tar.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import io
 import tarfile
 from pathlib import Path
 from typing import TYPE_CHECKING
@@ -63,3 +64,77 @@ def test_tar_output_encrypt(mock_fs: VirtualFilesystem, public_key: bytes, tmp_p
 
     with tarfile.open(name=decrypted_path, mode="r") as tar_file:
         assert entry.open().read() == tar_file.extractfile(entry_name).read()
+
+
+def test_tar_output_race_condition_with_shrinking_file(tmp_path: Path, public_key: bytes) -> None:
+    class ShrinkingFile(io.BytesIO):
+        """
+        A file-like object that returns 5 bytes less than required.
+        Simulates a file on disk that has shrunk in between the time of
+        determining the size and actually reading the data.
+        """
+
+        def __init__(self, data: bytes):
+            super().__init__(data)
+
+        def read(self, size: int) -> bytes:
+            return super().read(size - 5)
+
+    content = b"some text"
+
+    content_padded = content[:-5] + tarfile.NUL * 5
+    file = ShrinkingFile(content)
+
+    tar_output = TarOutput(tmp_path / "race.tar", encrypt=True, public_key=public_key)
+    tar_output.write("file.log", file)
+    tar_output.close()
+    file.close()
+
+    encrypted_stream = EncryptedFile(tar_output.path.open("rb"), Path("tests/_data/private_key.pem"))
+    decrypted_path = tmp_path / "decrypted.tar"
+
+    # Direct streaming is not an option because tarfile needs seek when reading from encrypted files directly
+    Path(decrypted_path).write_bytes(encrypted_stream.read())
+
+    with tarfile.open(name=decrypted_path, mode="r") as tar_file:
+        member = tar_file.getmember("file.log")
+        extracted = tar_file.extractfile(member).read()
+        # The content should be padded with zeros to match the original size, despite the fact that the file shrunk
+        assert extracted == content_padded
+
+
+def test_tar_output_race_condition_with_growing_file(tmp_path: Path, public_key: bytes) -> None:
+    class GrowingFile(io.BytesIO):
+        """
+        A file-like object that returns 3 extra bytes.
+        Simulates a file on disk that has grown in between the time of
+        determining the size and actually reading the data.
+        """
+
+        def __init__(self, data: bytes):
+            super().__init__(data)
+
+        def read(self, size: int) -> bytes:
+            return super().read(size) + b"FOX"
+
+    content = b"some text"
+
+    file = GrowingFile(content)
+
+    tar_output = TarOutput(tmp_path / "race.tar", encrypt=True, public_key=public_key)
+    tar_output.write("file.log", file)
+    tar_output.close()
+    file.close()
+
+    encrypted_stream = EncryptedFile(tar_output.path.open("rb"), Path("tests/_data/private_key.pem"))
+    decrypted_path = tmp_path / "decrypted.tar"
+
+    # Direct streaming is not an option because tarfile needs seek when reading from encrypted files directly
+    Path(decrypted_path).write_bytes(encrypted_stream.read())
+
+    with tarfile.open(name=decrypted_path, mode="r") as tar_file:
+        member = tar_file.getmember("file.log")
+        extracted = tar_file.extractfile(member).read()
+        # The content should match the original content, without the extra bytes
+        # because the file was read with the original size
+        assert extracted == content