diff --git a/tests/functions/test_file_utils.py b/tests/functions/test_file_utils.py index d3449f76..4e7b1c7e 100644 --- a/tests/functions/test_file_utils.py +++ b/tests/functions/test_file_utils.py @@ -1,22 +1,20 @@ # tests/functions/test_file_utils.py +import pytest +from pathlib import Path +import sys import logging import json -import sys from pathlib import Path from pprint import pprint -import pytest - -from loguru import logger - -# Ensure the repository root is on sys.path so ``ryan_library`` can be imported during tests -REPO_ROOT = Path(__file__).resolve().parents[2] -if str(REPO_ROOT) not in sys.path: - sys.path.insert(0, str(REPO_ROOT)) +# Ensure the project root is on sys.path for direct test execution +PROJECT_ROOT = Path(__file__).resolve().parents[2] +if str(PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(PROJECT_ROOT)) # Import the function to be tested -from ryan_library.functions.file_utils import find_files_parallel, is_non_zero_file +from ryan_library.functions.file_utils import ensure_output_directory, find_files_parallel, is_non_zero_file # Configure logging for tests logging.basicConfig( @@ -49,18 +47,18 @@ def setup_test_environment(): return TEST_DATA_DIR -def resolve_paths(relative_paths): +def resolve_paths(relative_paths: list[Path | str]) -> list[Path]: """ Helper function to resolve relative paths to absolute Path objects. """ return [TEST_DATA_DIR / Path(p) for p in relative_paths] -def test_find_files_inclusion_only(setup_test_environment, load_expected_files): +def test_find_files_inclusion_only(setup_test_environment: Path, load_expected_files) -> None: """ Test the find_files_parallel function with only inclusion patterns. """ - root_dir = setup_test_environment + root_dir: str | Path = setup_test_environment # Define inclusion patterns with wildcard include_patterns = "*.hpc.dt.csv" @@ -69,7 +67,7 @@ def test_find_files_inclusion_only(setup_test_environment, load_expected_files): exclude_patterns = None # Call the function - matched_files = find_files_parallel( + matched_files: list[Path] = find_files_parallel( root_dirs=[root_dir], patterns=include_patterns, excludes=exclude_patterns, @@ -329,6 +327,33 @@ def test_find_files_with_report_level(setup_test_environment, load_expected_file # assert any("Searching in folder" in message.message for message in caplog.records), "Expected log messages not found." +def test_ensure_output_directory_creates_missing_nested(tmp_path): + """Ensure missing nested directories are created.""" + + nested_dir = tmp_path / "level_one" / "level_two" + assert not nested_dir.exists() + + ensure_output_directory(nested_dir) + + assert nested_dir.exists() + assert nested_dir.is_dir() + + +def test_ensure_output_directory_idempotent_existing(tmp_path): + """Ensure calling on an existing directory does not recreate it.""" + + existing_dir = tmp_path / "already_there" + existing_dir.mkdir(parents=True) + sentinel_file = existing_dir / "sentinel.txt" + sentinel_file.write_text("sentinel") + + ensure_output_directory(existing_dir) + + assert existing_dir.exists() + assert existing_dir.is_dir() + assert sentinel_file.exists() + + @pytest.mark.parametrize( ("scenario", "expected_result", "expected_log_fragment"), [