Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions pyrit/score/audio_transcript_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import os
import tempfile
import uuid
from abc import ABC
from typing import Optional

import av
Expand Down Expand Up @@ -88,7 +87,7 @@ def _audio_to_wav(input_path: str, *, sample_rate: int, channels: int) -> str:
return output_path


class AudioTranscriptHelper(ABC): # noqa: B024
class AudioTranscriptHelper: # noqa: B024
"""
Abstract base class for audio scorers that process audio by transcribing and scoring the text.

Expand Down
22 changes: 10 additions & 12 deletions pyrit/score/float_scale/video_float_scale_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,14 @@
)
from pyrit.score.float_scale.float_scale_scorer import FloatScaleScorer
from pyrit.score.scorer_prompt_validator import ScorerPromptValidator
from pyrit.score.video_scorer import _BaseVideoScorer
from pyrit.score.video_scorer import VideoHelper

if TYPE_CHECKING:
from pyrit.score.score_aggregator_result import ScoreAggregatorResult


class VideoFloatScaleScorer(
FloatScaleScorer,
_BaseVideoScorer,
):
"""
A scorer that processes videos by extracting frames and scoring them using a float scale image scorer.
Expand Down Expand Up @@ -48,7 +47,7 @@ def __init__(
num_sampled_frames: Optional[int] = None,
validator: Optional[ScorerPromptValidator] = None,
score_aggregator: FloatScaleAggregatorFunc = FloatScaleScorerByCategory.MAX,
image_objective_template: Optional[str] = _BaseVideoScorer._DEFAULT_IMAGE_OBJECTIVE_TEMPLATE,
image_objective_template: Optional[str] = VideoHelper._DEFAULT_IMAGE_OBJECTIVE_TEMPLATE,
audio_objective_template: Optional[str] = None,
) -> None:
"""
Expand Down Expand Up @@ -82,8 +81,7 @@ def __init__(
"""
FloatScaleScorer.__init__(self, validator=validator or self._DEFAULT_VALIDATOR)

_BaseVideoScorer.__init__(
self,
self._video_helper = VideoHelper(
image_capable_scorer=image_capable_scorer,
num_sampled_frames=num_sampled_frames,
image_objective_template=image_objective_template,
Expand All @@ -92,7 +90,7 @@ def __init__(
self._score_aggregator = score_aggregator

if audio_scorer is not None:
self._validate_audio_scorer(audio_scorer)
VideoHelper._validate_audio_scorer(audio_scorer)
self.audio_scorer = audio_scorer

def _build_identifier(self) -> ComponentIdentifier:
Expand All @@ -102,17 +100,17 @@ def _build_identifier(self) -> ComponentIdentifier:
Returns:
ComponentIdentifier: The identifier for this scorer.
"""
sub_scorer_ids = [self.image_scorer.get_identifier()]
sub_scorer_ids = [self._video_helper.image_scorer.get_identifier()]
if self.audio_scorer:
sub_scorer_ids.append(self.audio_scorer.get_identifier())

return self._create_identifier(
params={
"score_aggregator": self._score_aggregator.__name__,
"num_sampled_frames": self.num_sampled_frames,
"num_sampled_frames": self._video_helper.num_sampled_frames,
"has_audio_scorer": self.audio_scorer is not None,
"image_objective_template": self.image_objective_template,
"audio_objective_template": self.audio_objective_template,
"image_objective_template": self._video_helper.image_objective_template,
"audio_objective_template": self._video_helper.audio_objective_template,
},
children={
"sub_scorers": sub_scorer_ids,
Expand All @@ -131,14 +129,14 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op
List of aggregated scores for the video. Returns one score if using FloatScaleScoreAggregator,
or multiple scores (one per category) if using FloatScaleScorerByCategory.
"""
frame_scores = await self._score_frames_async(message_piece=message_piece, objective=objective)
frame_scores = await self._video_helper._score_frames_async(message_piece=message_piece, objective=objective)

all_scores = list(frame_scores)
audio_scored = False

# Score audio if audio_scorer is provided
if self.audio_scorer:
audio_scores = await self._score_video_audio_async(
audio_scores = await self._video_helper._score_video_audio_async(
message_piece=message_piece, audio_scorer=self.audio_scorer, objective=objective
)
if audio_scores:
Expand Down
27 changes: 13 additions & 14 deletions pyrit/score/true_false/video_true_false_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
from pyrit.score.scorer_prompt_validator import ScorerPromptValidator
from pyrit.score.true_false.true_false_score_aggregator import TrueFalseScoreAggregator
from pyrit.score.true_false.true_false_scorer import TrueFalseScorer
from pyrit.score.video_scorer import _BaseVideoScorer
from pyrit.score.video_scorer import VideoHelper


class VideoTrueFalseScorer(TrueFalseScorer, _BaseVideoScorer):
class VideoTrueFalseScorer(TrueFalseScorer):
"""
A scorer that processes videos by extracting frames and scoring them using a true/false image scorer.

Expand All @@ -34,7 +34,7 @@ def __init__(
audio_scorer: Optional[TrueFalseScorer] = None,
num_sampled_frames: Optional[int] = None,
validator: Optional[ScorerPromptValidator] = None,
image_objective_template: Optional[str] = _BaseVideoScorer._DEFAULT_IMAGE_OBJECTIVE_TEMPLATE,
image_objective_template: Optional[str] = VideoHelper._DEFAULT_IMAGE_OBJECTIVE_TEMPLATE,
audio_objective_template: Optional[str] = None,
) -> None:
"""
Expand All @@ -59,18 +59,17 @@ def __init__(
Raises:
ValueError: If audio_scorer is provided and does not support audio_path data type.
"""
_BaseVideoScorer.__init__(
self,
super().__init__(validator=validator or self._DEFAULT_VALIDATOR)

self._video_helper = VideoHelper(
image_capable_scorer=image_capable_scorer,
num_sampled_frames=num_sampled_frames,
image_objective_template=image_objective_template,
audio_objective_template=audio_objective_template,
)

TrueFalseScorer.__init__(self, validator=validator or self._DEFAULT_VALIDATOR)

if audio_scorer is not None:
self._validate_audio_scorer(audio_scorer)
VideoHelper._validate_audio_scorer(audio_scorer)
self.audio_scorer = audio_scorer

def _build_identifier(self) -> ComponentIdentifier:
Expand All @@ -80,16 +79,16 @@ def _build_identifier(self) -> ComponentIdentifier:
Returns:
ComponentIdentifier: The identifier for this scorer.
"""
sub_scorer_ids = [self.image_scorer.get_identifier()]
sub_scorer_ids = [self._video_helper.image_scorer.get_identifier()]
if self.audio_scorer:
sub_scorer_ids.append(self.audio_scorer.get_identifier())

return self._create_identifier(
params={
"num_sampled_frames": self.num_sampled_frames,
"num_sampled_frames": self._video_helper.num_sampled_frames,
"has_audio_scorer": self.audio_scorer is not None,
"image_objective_template": self.image_objective_template,
"audio_objective_template": self.audio_objective_template,
"image_objective_template": self._video_helper.image_objective_template,
"audio_objective_template": self._video_helper.audio_objective_template,
},
children={
"sub_scorers": sub_scorer_ids,
Expand All @@ -114,7 +113,7 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op
piece_id = message_piece.id if message_piece.id is not None else message_piece.original_prompt_id

# Get scores for all frames and aggregate with OR (True if ANY frame matches)
frame_scores = await self._score_frames_async(message_piece=message_piece, objective=objective)
frame_scores = await self._video_helper._score_frames_async(message_piece=message_piece, objective=objective)
frame_result = TrueFalseScoreAggregator.OR(frame_scores)

# Create a Score from the frame aggregation result
Expand All @@ -132,7 +131,7 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op

# Score audio if audio_scorer is provided
if self.audio_scorer:
audio_scores = await self._score_video_audio_async(
audio_scores = await self._video_helper._score_video_audio_async(
message_piece=message_piece, audio_scorer=self.audio_scorer, objective=objective
)
if audio_scores:
Expand Down
9 changes: 4 additions & 5 deletions pyrit/score/video_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import random
import tempfile
import uuid
from abc import ABC
from typing import Optional

from pyrit.memory import CentralMemory
Expand All @@ -17,13 +16,13 @@
logger = logging.getLogger(__name__)


class _BaseVideoScorer(ABC): # noqa: B024
class VideoHelper:
"""
Abstract base class for video scorers that process videos by extracting frames and scoring them.
Helper class for video scorers that process videos by extracting frames and scoring them.

This class provides common functionality for extracting frames from videos and delegating
scoring to an image-capable scorer. Concrete implementations handle aggregation logic
specific to their scoring type (true/false or float scale).
scoring to an image-capable scorer. Used via composition by VideoTrueFalseScorer and
VideoFloatScaleScorer.
"""

_DEFAULT_VIDEO_FRAMES_SAMPLING_NUM = 5
Expand Down
18 changes: 9 additions & 9 deletions tests/unit/score/test_video_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,10 @@ async def test_extract_frames_true_false(video_converter_sample_video):
image_scorer = MockTrueFalseScorer()
scorer = VideoTrueFalseScorer(image_capable_scorer=image_scorer, num_sampled_frames=3)
video_path = video_converter_sample_video.converted_value
frame_paths = scorer._extract_frames(video_path=video_path)
frame_paths = scorer._video_helper._extract_frames(video_path=video_path)

assert len(frame_paths) == scorer.num_sampled_frames, (
f"Expected {scorer.num_sampled_frames} frames, got {len(frame_paths)}"
assert len(frame_paths) == scorer._video_helper.num_sampled_frames, (
f"Expected {scorer._video_helper.num_sampled_frames} frames, got {len(frame_paths)}"
)

# Verify frames are valid images and cleanup
Expand All @@ -159,10 +159,10 @@ async def test_extract_frames_float_scale(video_converter_sample_video):
image_scorer = MockFloatScaleScorer()
scorer = VideoFloatScaleScorer(image_capable_scorer=image_scorer, num_sampled_frames=3)
video_path = video_converter_sample_video.converted_value
frame_paths = scorer._extract_frames(video_path=video_path)
frame_paths = scorer._video_helper._extract_frames(video_path=video_path)

assert len(frame_paths) == scorer.num_sampled_frames, (
f"Expected {scorer.num_sampled_frames} frames, got {len(frame_paths)}"
assert len(frame_paths) == scorer._video_helper.num_sampled_frames, (
f"Expected {scorer._video_helper.num_sampled_frames} frames, got {len(frame_paths)}"
)

# Verify frames are valid images and cleanup
Expand Down Expand Up @@ -228,7 +228,7 @@ async def test_score_video_no_frames(video_converter_sample_video):
scorer = VideoTrueFalseScorer(image_capable_scorer=image_scorer, num_sampled_frames=3)

# Mock _extract_frames to return empty list
scorer._extract_frames = MagicMock(return_value=[])
scorer._video_helper._extract_frames = MagicMock(return_value=[])

with pytest.raises(ValueError, match="No frames extracted from video for scoring."):
await scorer._score_piece_async(video_converter_sample_video)
Expand Down Expand Up @@ -292,10 +292,10 @@ def test_video_scorer_default_num_frames():
image_scorer = MockTrueFalseScorer()
scorer = VideoTrueFalseScorer(image_capable_scorer=image_scorer)

assert scorer.num_sampled_frames == 5 # Default value
assert scorer._video_helper.num_sampled_frames == 5 # Default value


class MockAudioTrueFalseScorer(TrueFalseScorer, AudioTranscriptHelper):
class MockAudioTrueFalseScorer(TrueFalseScorer):
"""Mock AudioTrueFalseScorer for testing video+audio integration"""

def __init__(self, return_value: bool = True):
Expand Down
Loading