diff --git a/plain2code_arguments.py b/plain2code_arguments.py index 992ae4e..3e3217f 100644 --- a/plain2code_arguments.py +++ b/plain2code_arguments.py @@ -2,6 +2,8 @@ import os import re +from plain2code_console import console +from plain2code_exceptions import AmbiguousConfigFileError from plain2code_read_config import get_args_from_config CODEPLAIN_API_KEY = os.getenv("CODEPLAIN_API_KEY") @@ -83,9 +85,51 @@ def frid_range_string(s): return s +def resolve_config_file(config_name: str, plain_file_path: str): + """ + Resolve the config file path by searching in two locations: + 1. Directory of the plain file + 2. Current working directory (where render is called from) + + Returns the resolved absolute path, or None if the file is not found in either location. + Raises AmbiguousConfigFileError if the file exists in both locations (and they differ). + """ + plain_file_dir = os.path.dirname(os.path.abspath(plain_file_path)) + cwd = os.getcwd() + + plain_dir_config = os.path.normpath(os.path.join(plain_file_dir, config_name)) + cwd_config = os.path.normpath(os.path.join(cwd, config_name)) + + in_plain_dir = os.path.exists(plain_dir_config) + in_cwd = os.path.exists(cwd_config) + same_location = plain_dir_config == cwd_config + + if in_plain_dir and in_cwd and not same_location: + raise AmbiguousConfigFileError( + f"Config file '{config_name}' was found in two locations:\n" + f" - Plain file directory: {plain_file_dir}\n" + f" - Current working directory: {cwd}\n" + f"Remove the config file from one of these locations to resolve the ambiguity." + ) + + if in_plain_dir: + return plain_dir_config + if in_cwd: + return cwd_config + return None + + def update_args_with_config(args, parser): try: - config_args = get_args_from_config(args.config_name, parser) + resolved_config = resolve_config_file(args.config_name, args.filename) + + if resolved_config is None: + console.info(f"No config file '{args.config_name}' found. Proceeding without one.") + return args + + args.config_name = resolved_config + config_args = get_args_from_config(resolved_config, parser) + # Get all action types from the parser action_types = {action.dest: action for action in parser._actions} @@ -109,6 +153,8 @@ def update_args_with_config(args, parser): else: parser.error(f"Invalid argument: {key}") + except AmbiguousConfigFileError as e: + parser.error(str(e)) except Exception as e: parser.error(f"Error reading config file: {str(e)}") @@ -152,7 +198,7 @@ def create_parser(): "--config-name", type=non_empty_string, default="config.yaml", - help="Path to the config file, defaults to config.yaml", + help="Name of the config file to look for. Looked up in the plain file directory and the current working directory. Defaults to config.yaml.", ) render_range_group = parser.add_mutually_exclusive_group() diff --git a/plain2code_exceptions.py b/plain2code_exceptions.py index 3a307ed..d75086b 100644 --- a/plain2code_exceptions.py +++ b/plain2code_exceptions.py @@ -73,3 +73,9 @@ class NetworkConnectionError(Exception): """Raised when there is a network connectivity issue with the API server.""" pass + + +class AmbiguousConfigFileError(Exception): + """Raised when a config file is found in both the plain file directory and the current working directory.""" + + pass diff --git a/plain2code_read_config.py b/plain2code_read_config.py index 192692b..0956743 100644 --- a/plain2code_read_config.py +++ b/plain2code_read_config.py @@ -1,4 +1,3 @@ -import os from argparse import ArgumentParser, Namespace from typing import Any, Dict @@ -46,11 +45,6 @@ def get_args_from_config(config_file: str, parser: ArgumentParser) -> Namespace: args = Namespace() - if config_file == "config.yaml": - if not os.path.exists(config_file): - console.info(f"Default config file {config_file} not found. No config file is read.") - return args - # Load config config = load_config(config_file) config = validate_config(config, parser) diff --git a/tests/test_resolve_config_file.py b/tests/test_resolve_config_file.py new file mode 100644 index 0000000..93c1363 --- /dev/null +++ b/tests/test_resolve_config_file.py @@ -0,0 +1,89 @@ +import os +import tempfile +from pathlib import Path +from unittest.mock import patch + +import pytest + +from plain2code_arguments import resolve_config_file +from plain2code_exceptions import AmbiguousConfigFileError + + +@pytest.fixture +def two_dirs(): + """Provide two separate temporary directories: one for the plain file, one as CWD.""" + with tempfile.TemporaryDirectory() as plain_dir: + with tempfile.TemporaryDirectory() as cwd: + yield plain_dir, cwd + + +def _plain_file(plain_dir): + return os.path.join(plain_dir, "module.plain") + + +def test_config_in_plain_file_dir_only(two_dirs): + plain_dir, cwd = two_dirs + config = Path(plain_dir) / "config.yaml" + config.write_text("verbose: true\n") + + with patch("os.getcwd", return_value=cwd): + result = resolve_config_file("config.yaml", _plain_file(plain_dir)) + + assert result == os.path.normpath(str(config)) + + +def test_config_in_cwd_only(two_dirs): + plain_dir, cwd = two_dirs + config = Path(cwd) / "config.yaml" + config.write_text("verbose: true\n") + + with patch("os.getcwd", return_value=cwd): + result = resolve_config_file("config.yaml", _plain_file(plain_dir)) + + assert result == os.path.normpath(str(config)) + + +def test_config_in_both_locations_raises(two_dirs): + plain_dir, cwd = two_dirs + (Path(plain_dir) / "config.yaml").write_text("verbose: true\n") + (Path(cwd) / "config.yaml").write_text("verbose: false\n") + + with patch("os.getcwd", return_value=cwd): + with pytest.raises(AmbiguousConfigFileError) as exc_info: + resolve_config_file("config.yaml", _plain_file(plain_dir)) + + assert plain_dir in str(exc_info.value) + assert cwd in str(exc_info.value) + + +def test_config_not_found_returns_none(two_dirs): + plain_dir, cwd = two_dirs + + with patch("os.getcwd", return_value=cwd): + result = resolve_config_file("config.yaml", _plain_file(plain_dir)) + + assert result is None + + +def test_config_same_dir_no_error(): + """When the plain file and CWD are in the same directory, a single config file is fine.""" + with tempfile.TemporaryDirectory() as d: + config = Path(d) / "config.yaml" + config.write_text("verbose: true\n") + + with patch("os.getcwd", return_value=d): + result = resolve_config_file("config.yaml", os.path.join(d, "module.plain")) + + assert result == os.path.normpath(str(config)) + + +def test_custom_config_name_found_in_plain_file_dir(two_dirs): + """A custom --config-name is also looked up in the two locations.""" + plain_dir, cwd = two_dirs + config = Path(plain_dir) / "myconfig.yaml" + config.write_text("verbose: true\n") + + with patch("os.getcwd", return_value=cwd): + result = resolve_config_file("myconfig.yaml", _plain_file(plain_dir)) + + assert result == os.path.normpath(str(config))