diff --git a/src/aedifix/config.py b/src/aedifix/config.py index 2833fba..248c237 100644 --- a/src/aedifix/config.py +++ b/src/aedifix/config.py @@ -29,6 +29,7 @@ class ConfigFile(Configurable): "_cmake_configure_file", "_config_file_template", "_default_subst", + "_project_variables_file", ) def __init__( @@ -45,6 +46,13 @@ def __init__( """ super().__init__(manager=manager) self._config_file_template = config_file_template.resolve() + + config_file = self.template_file + if config_file.suffix == ".in": + # remove the trailing .in + config_file = config_file.with_suffix("") + + self._project_variables_file = self.project_arch_dir / config_file.name self._default_subst = {"PYTHON_EXECUTABLE": sys.executable} @property @@ -72,7 +80,7 @@ def project_variables_file(self) -> Path: The file is not guaranteed to exist, or be up to date. Usually it is created/refreshed during finalization of this object. """ - return self.project_arch_dir / "gmakevariables" + return self._project_variables_file def _read_entire_cmake_cache(self, cmake_cache: Path) -> dict[str, str]: r"""Read a CMakeCache.txt and convert all of the cache values to @@ -164,7 +172,7 @@ def finalize(self) -> None: If the user config file contains an unknown AEDIFIX substitution. """ project_file = self.project_variables_file - template_file = self._config_file_template + template_file = self.template_file self.log(f"Using project file: {project_file}") self.log(f"Using template file: {template_file}") diff --git a/tests/test_config.py b/tests/test_config.py index 7f2d67a..9b68b6b 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -32,9 +32,29 @@ def test_create(self, manager: DummyManager, tmp_path: Path) -> None: assert config.template_file.exists() assert config.template_file.is_file() assert config.template_file == template - + assert ( + config.project_variables_file.name == template.with_suffix("").name + ) assert config._default_subst == {"PYTHON_EXECUTABLE": sys.executable} + @pytest.mark.parametrize( + ("base", "expected"), + ( + ("foo.bar.baz.in", "foo.bar.baz"), + ("foo.bar", "foo.bar"), + ("foo", "foo"), + ), + ) + def test_project_variables_file( + self, manager: DummyManager, tmp_path: Path, base: str, expected: str + ) -> None: + template = tmp_path / base + template.touch() + + config = ConfigFile(manager=manager, config_file_template=template) + + assert config.project_variables_file.name == expected + if __name__ == "__main__": sys.exit(pytest.main())