From f4b24478b8a75cfd33b5684da49b7aec46d5e51f Mon Sep 17 00:00:00 2001 From: biefan <70761325+biefan@users.noreply.github.com> Date: Tue, 17 Mar 2026 07:04:32 +0000 Subject: [PATCH] Normalize SeedPrompt file extension detection --- pyrit/models/seeds/seed_prompt.py | 2 +- tests/unit/models/test_seed.py | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/pyrit/models/seeds/seed_prompt.py b/pyrit/models/seeds/seed_prompt.py index b507cf3173..6690a53711 100644 --- a/pyrit/models/seeds/seed_prompt.py +++ b/pyrit/models/seeds/seed_prompt.py @@ -62,7 +62,7 @@ def __post_init__(self) -> None: # Note: Does not assign 'error' or 'url' implicitly if os.path.isfile(self.value): _, ext = os.path.splitext(self.value) - ext = ext.lstrip(".") + ext = ext.lstrip(".").lower() if ext in ["mp4", "avi", "mov", "mkv", "ogv", "flv", "wmv", "webm"]: self.data_type = "video_path" elif ext in ["flac", "mp3", "mpeg", "mpga", "m4a", "ogg", "wav"]: diff --git a/tests/unit/models/test_seed.py b/tests/unit/models/test_seed.py index 32414e0f75..d752dfbdff 100644 --- a/tests/unit/models/test_seed.py +++ b/tests/unit/models/test_seed.py @@ -74,6 +74,24 @@ def test_seed_prompt_initialization(seed_prompt_fixture): assert seed_prompt_fixture.parameters == ["param1"] +@pytest.mark.parametrize( + ("suffix", "expected_data_type"), + [ + (".PNG", "image_path"), + (".WAV", "audio_path"), + ], +) +def test_seed_prompt_infers_file_type_from_uppercase_extension(suffix, expected_data_type): + with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as temp_file: + file_path = temp_file.name + + try: + seed_prompt = SeedPrompt(value=file_path) + assert seed_prompt.data_type == expected_data_type + finally: + os.remove(file_path) + + def test_seed_prompt_render_template_success(seed_prompt_fixture): seed_prompt_fixture.value = "Test prompt with param1={{ param1 }}" result = seed_prompt_fixture.render_template_value(param1="value1")