diff --git a/pyrit/executor/attack/core/attack_executor.py b/pyrit/executor/attack/core/attack_executor.py index 06850fb33..9355f62bd 100644 --- a/pyrit/executor/attack/core/attack_executor.py +++ b/pyrit/executor/attack/core/attack_executor.py @@ -183,7 +183,7 @@ async def execute_attack_from_seed_groups_async( if not seed_groups: raise ValueError("At least one seed_group must be provided") - if field_overrides and len(field_overrides) != len(seed_groups): + if field_overrides is not None and len(field_overrides) != len(seed_groups): raise ValueError( f"field_overrides length ({len(field_overrides)}) must match seed_groups length ({len(seed_groups)})" ) @@ -197,7 +197,7 @@ async def execute_attack_from_seed_groups_async( async def build_params(i: int, sg: SeedAttackGroup) -> AttackParameters: async with semaphore: combined_overrides = dict(broadcast_fields) - if field_overrides: + if field_overrides is not None: combined_overrides.update(field_overrides[i]) return await params_type.from_seed_group_async( seed_group=sg, @@ -248,7 +248,7 @@ async def execute_attack_async( if not objectives: raise ValueError("At least one objective must be provided") - if field_overrides and len(field_overrides) != len(objectives): + if field_overrides is not None and len(field_overrides) != len(objectives): raise ValueError( f"field_overrides length ({len(field_overrides)}) must match objectives length ({len(objectives)})" ) @@ -262,7 +262,7 @@ async def execute_attack_async( fields = dict(broadcast_fields) # Apply per-objective overrides - if field_overrides: + if field_overrides is not None: fields.update(field_overrides[i]) # Add objective diff --git a/tests/unit/executor/attack/core/test_attack_executor.py b/tests/unit/executor/attack/core/test_attack_executor.py index 8bc6b70e5..3fea1106d 100644 --- a/tests/unit/executor/attack/core/test_attack_executor.py +++ b/tests/unit/executor/attack/core/test_attack_executor.py @@ -173,6 +173,19 @@ async def test_validates_field_overrides_length(self): field_overrides=[{}], # Wrong length ) + @pytest.mark.asyncio + async def test_validates_explicit_empty_field_overrides(self): + """Test that explicit empty field_overrides still validate length.""" + attack = create_mock_attack() + executor = AttackExecutor() + + with pytest.raises(ValueError, match="field_overrides length .* must match"): + await executor.execute_attack_async( + attack=attack, + objectives=["Obj1", "Obj2"], + field_overrides=[], + ) + @pytest.mark.asyncio async def test_concurrency_control(self): """Test that concurrency is properly limited.""" @@ -322,6 +335,21 @@ async def capture_from_seed_group_async(*, seed_group, **kwargs): # Restore the original to prevent test pollution in parallel test runs attack.params_type.from_seed_group_async = original_from_seed_group_async + @pytest.mark.asyncio + async def test_validates_explicit_empty_field_overrides_for_seed_groups(self): + """Test that explicit empty field_overrides still validate seed group length.""" + attack = create_mock_attack() + executor = AttackExecutor() + sg1 = create_seed_group("Objective 1") + sg2 = create_seed_group("Objective 2") + + with pytest.raises(ValueError, match="field_overrides length .* must match"): + await executor.execute_attack_from_seed_groups_async( + attack=attack, + seed_groups=[sg1, sg2], + field_overrides=[], + ) + @pytest.mark.usefixtures("patch_central_database") class TestPartialFailureHandling: