diff --git a/render_machine/conformance_tests.py b/render_machine/conformance_tests.py index ca9f3bd..961add5 100644 --- a/render_machine/conformance_tests.py +++ b/render_machine/conformance_tests.py @@ -26,7 +26,8 @@ def get_module_conformance_tests_folder(self, module_name: str) -> str: def _get_full_conformance_tests_definition_file_name(self, module_name: str) -> str: return os.path.join( - self.get_module_conformance_tests_folder(module_name), self.conformance_tests_definition_file_name + self.get_module_conformance_tests_folder(module_name), + self.conformance_tests_definition_file_name, ) def get_conformance_tests_json(self, module_name: str) -> dict: @@ -112,7 +113,10 @@ def store_conformance_tests_files( [source_conformance_test_folder_name, new_conformance_test_folder_name] = ( self.get_source_conformance_test_folder_name( - module_name, required_modules, current_testing_module_name, current_conformance_test_folder_name + module_name, + required_modules, + current_testing_module_name, + current_conformance_test_folder_name, ) ) @@ -120,7 +124,12 @@ def store_conformance_tests_files( console.info( f"Creating folder {new_conformance_test_folder_name} for a copy of conformance tests {source_conformance_test_folder_name}" ) - file_utils.copy_folder_content(source_conformance_test_folder_name, new_conformance_test_folder_name) + + if not os.path.exists(new_conformance_test_folder_name): + file_utils.copy_folder_content( + source_conformance_test_folder_name, + new_conformance_test_folder_name, + ) current_conformance_test_folder_name = new_conformance_test_folder_name @@ -147,7 +156,10 @@ def fetch_existing_conformance_test_files( ) -> tuple[list[str], dict[str, str]]: if module_name != current_testing_module_name: [current_conformance_test_folder_name, _] = self.get_source_conformance_test_folder_name( - module_name, required_modules, current_testing_module_name, current_conformance_test_folder_name + module_name, + required_modules, + current_testing_module_name, + current_conformance_test_folder_name, ) existing_conformance_test_files = file_utils.list_all_text_files(current_conformance_test_folder_name) diff --git a/render_machine/render_context.py b/render_machine/render_context.py index ac37c09..31fc151 100644 --- a/render_machine/render_context.py +++ b/render_machine/render_context.py @@ -209,15 +209,21 @@ def _get_first_frid_conformance_test_running_context(self, module: PlainModule | if module is None: conformance_tests_running_context.current_testing_module_name = self.module_name + if not conformance_tests_running_context.conformance_tests_json_has_module_populated( + conformance_tests_running_context.current_testing_module_name + ): + conformance_tests_running_context.set_conformance_tests_json( + conformance_tests_running_context.current_testing_module_name, + {}, + ) else: conformance_tests_running_context.current_testing_module_name = module.name - - conformance_tests_running_context.set_conformance_tests_json( - conformance_tests_running_context.current_testing_module_name, - self.conformance_tests.get_conformance_tests_json( - conformance_tests_running_context.current_testing_module_name - ), - ) + conformance_tests_running_context.set_conformance_tests_json( + conformance_tests_running_context.current_testing_module_name, + self.conformance_tests.get_conformance_tests_json( + conformance_tests_running_context.current_testing_module_name + ), + ) if module is None: conformance_tests_running_context.current_testing_frid = plain_spec.get_first_frid(self.plain_source_tree) diff --git a/render_machine/render_types.py b/render_machine/render_types.py index 1e40431..58ef1c9 100644 --- a/render_machine/render_types.py +++ b/render_machine/render_types.py @@ -54,6 +54,9 @@ def __init__( def get_conformance_tests_json(self, module_name: str) -> dict: return self._conformance_tests_json[module_name] + def conformance_tests_json_has_module_populated(self, module_name: str) -> bool: + return module_name in self._conformance_tests_json and len(self._conformance_tests_json[module_name]) > 0 + def set_conformance_tests_json(self, module_name: str, conformance_tests_json: dict): self._conformance_tests_json[module_name] = conformance_tests_json @@ -76,7 +79,7 @@ def get_current_acceptance_tests(self) -> Optional[list[str]]: plain_spec.ACCEPTANCE_TESTS ] - return None + return [] def get_current_acceptance_test(self) -> Optional[str]: """Get the current acceptance test text (raw, unformatted)."""