diff --git a/pyrit/setup/initialization.py b/pyrit/setup/initialization.py index 0aff8deaf..ea34362c3 100644 --- a/pyrit/setup/initialization.py +++ b/pyrit/setup/initialization.py @@ -157,7 +157,12 @@ def _load_initializers_from_scripts( obj = getattr(module, name) # Check if it's a class, is a subclass of PyRITInitializer, # and is not the base class itself - if isinstance(obj, type) and issubclass(obj, PyRITInitializer) and obj is not PyRITInitializer: + if ( + isinstance(obj, type) + and issubclass(obj, PyRITInitializer) + and obj is not PyRITInitializer + and obj.__module__ == module.__name__ + ): try: # Instantiate the initializer class initializer = obj() diff --git a/tests/unit/setup/test_initialization.py b/tests/unit/setup/test_initialization.py index 9c386ba68..6167d25f0 100644 --- a/tests/unit/setup/test_initialization.py +++ b/tests/unit/setup/test_initialization.py @@ -53,6 +53,59 @@ def test_script_not_found_raises_error(self): with pytest.raises(FileNotFoundError): _load_initializers_from_scripts(script_paths=["nonexistent_script.py"]) + def test_ignores_imported_initializer_classes(self): + """Test that imported initializer classes are not instantiated from the script.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = pathlib.Path(temp_dir) + helper_path = temp_path / "helper_init.py" + script_path = temp_path / "script_init.py" + + helper_path.write_text( + """ +from pyrit.setup.initializers import PyRITInitializer + +class ImportedInitializer(PyRITInitializer): + @property + def name(self) -> str: + return "Imported" + + @property + def description(self) -> str: + return "Imported initializer" + + async def initialize_async(self) -> None: + pass +""" + ) + + script_path.write_text( + f""" +import sys + +sys.path.insert(0, {temp_dir!r}) + +from helper_init import ImportedInitializer +from pyrit.setup.initializers import PyRITInitializer + +class LocalInitializer(PyRITInitializer): + @property + def name(self) -> str: + return "Local" + + @property + def description(self) -> str: + return "Local initializer" + + async def initialize_async(self) -> None: + pass +""" + ) + + initializers = _load_initializers_from_scripts(script_paths=[script_path]) + + assert len(initializers) == 1 + assert initializers[0].name == "Local" + class TestInitializePyrit: """Tests for initialize_pyrit_async function - basic orchestration tests."""