Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 45 additions & 13 deletions pyrit/identifiers/component_identifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ class ComponentIdentifier:
KEY_CLASS_NAME: ClassVar[str] = "class_name"
KEY_CLASS_MODULE: ClassVar[str] = "class_module"
KEY_HASH: ClassVar[str] = "hash"
KEY_EVAL_HASH: ClassVar[str] = "eval_hash"
KEY_PYRIT_VERSION: ClassVar[str] = "pyrit_version"
KEY_CHILDREN: ClassVar[str] = "children"
LEGACY_KEY_TYPE: ClassVar[str] = "__type__"
Expand All @@ -127,19 +128,52 @@ class ComponentIdentifier:
#: Named child identifiers for compositional identity (e.g., a scorer's target).
children: dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]] = field(default_factory=dict)
#: Content-addressed SHA256 hash computed from class, params, and children.
hash: str = field(init=False, compare=False)
#: When ``None`` (the default), it is computed automatically in ``__post_init__``.
#: Pass an explicit value to preserve a pre-computed hash (e.g. from DB storage
#: where params may have been truncated).
hash: Optional[str] = field(default=None, compare=False)
#: Version tag for storage. Not included in hash.
pyrit_version: str = field(default_factory=lambda: pyrit.__version__, compare=False)
#: Evaluation hash. Computed by EvaluationIdentifier subclasses (e.g. ScorerEvaluationIdentifier)
#: and attached to the identifier so it is always available via ``to_dict()``.
#: Survives DB round-trips even when param values are truncated.
eval_hash: Optional[str] = field(default=None, compare=False)

def __post_init__(self) -> None:
"""Compute the content-addressed hash at creation time."""
hash_dict = _build_hash_dict(
"""Compute the content-addressed hash at creation time if not already provided."""
if self.hash is None:
hash_dict = _build_hash_dict(
class_name=self.class_name,
class_module=self.class_module,
params=self.params,
children=self.children,
)
object.__setattr__(self, "hash", config_hash(hash_dict))

def with_eval_hash(self, eval_hash: str) -> ComponentIdentifier:
"""
Return a new frozen ComponentIdentifier with ``eval_hash`` set.

The original ``hash`` is preserved (important for identifiers
reconstructed from truncated DB data where recomputation would
produce a wrong hash).

Args:
eval_hash: The evaluation hash to attach.

Returns:
A new ComponentIdentifier identical to this one but with
``eval_hash`` set to the given value.
"""
return ComponentIdentifier(
class_name=self.class_name,
class_module=self.class_module,
params=self.params,
children=self.children,
hash=self.hash,
pyrit_version=self.pyrit_version,
eval_hash=eval_hash,
)
object.__setattr__(self, "hash", config_hash(hash_dict))

@property
def short_hash(self) -> str:
Expand Down Expand Up @@ -258,6 +292,9 @@ def to_dict(self, *, max_value_length: Optional[int] = None) -> dict[str, Any]:
self.KEY_PYRIT_VERSION: self.pyrit_version,
}

if self.eval_hash is not None:
result[self.KEY_EVAL_HASH] = self.eval_hash

for key, value in self.params.items():
result[key] = self._truncate_value(value=value, max_length=max_value_length)

Expand Down Expand Up @@ -324,6 +361,7 @@ def from_dict(cls, data: dict[str, Any]) -> ComponentIdentifier:
class_module = data.pop(cls.KEY_CLASS_MODULE, None) or data.pop(cls.LEGACY_KEY_MODULE, None) or "unknown"

stored_hash = data.pop(cls.KEY_HASH, None)
stored_eval_hash = data.pop(cls.KEY_EVAL_HASH, None)
pyrit_version = data.pop(cls.KEY_PYRIT_VERSION, pyrit.__version__)

# Reconstruct children
Expand All @@ -332,22 +370,16 @@ def from_dict(cls, data: dict[str, Any]) -> ComponentIdentifier:
# Everything remaining is a param
params = data

identifier = cls(
return cls(
class_name=class_name,
class_module=class_module,
params=params,
children=children,
hash=stored_hash,
pyrit_version=pyrit_version,
eval_hash=stored_eval_hash,
)

# Preserve stored hash if available — the stored hash was computed from
# untruncated data and is the correct identity. Recomputing from
# potentially truncated DB values would produce a wrong hash.
if stored_hash:
object.__setattr__(identifier, "hash", stored_hash)

return identifier

def get_child(self, key: str) -> Optional[ComponentIdentifier]:
"""
Get a single child by key.
Expand Down
20 changes: 15 additions & 5 deletions pyrit/identifiers/evaluation_identifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,12 +170,22 @@ class EvaluationIdentifier(ABC):
CHILD_EVAL_RULES: ClassVar[dict[str, ChildEvalRule]]

def __init__(self, identifier: ComponentIdentifier) -> None:
"""Wrap a ComponentIdentifier and eagerly compute its eval hash."""
"""
Wrap a ComponentIdentifier and resolve its eval hash.

If the identifier carries an ``eval_hash`` (preserved from a prior
DB round-trip or set by the scorer), that value is used directly.
Otherwise the eval hash is computed from the identifier's params
and children using the subclass's ``CHILD_EVAL_RULES``.
"""
self._identifier = identifier
self._eval_hash = compute_eval_hash(
identifier,
child_eval_rules=self.CHILD_EVAL_RULES,
)
if identifier.eval_hash is not None:
self._eval_hash = identifier.eval_hash
else:
self._eval_hash = compute_eval_hash(
identifier,
child_eval_rules=self.CHILD_EVAL_RULES,
)

@property
def identifier(self) -> ComponentIdentifier:
Expand Down
35 changes: 31 additions & 4 deletions pyrit/memory/memory_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@
import pyrit
from pyrit.common.utils import to_sha256
from pyrit.identifiers.component_identifier import ComponentIdentifier
from pyrit.identifiers.evaluation_identifier import (
AtomicAttackEvaluationIdentifier,
ScorerEvaluationIdentifier,
)
from pyrit.models import (
AttackOutcome,
AttackResult,
Expand All @@ -51,6 +55,8 @@
SeedType,
)

logger = logging.getLogger(__name__)

# Default pyrit_version for database records created before version tracking was added
LEGACY_PYRIT_VERSION = "<0.10.0"

Expand Down Expand Up @@ -398,7 +404,14 @@ def __init__(self, *, entry: Score):
self.score_metadata = entry.score_metadata
# Normalize to ComponentIdentifier (handles dict with deprecation warning) then convert to dict for JSON storage
normalized_scorer = ComponentIdentifier.normalize(entry.scorer_class_identifier)
self.scorer_class_identifier = normalized_scorer.to_dict(max_value_length=MAX_IDENTIFIER_VALUE_LENGTH)
# Ensure eval_hash is set before truncation so it survives the DB round-trip
if normalized_scorer.eval_hash is None:
normalized_scorer = normalized_scorer.with_eval_hash(
ScorerEvaluationIdentifier(normalized_scorer).eval_hash
)
self.scorer_class_identifier = normalized_scorer.to_dict(
max_value_length=MAX_IDENTIFIER_VALUE_LENGTH,
)
self.prompt_request_response_id = entry.message_piece_id if entry.message_piece_id else None
self.timestamp = entry.timestamp
# Store in both columns for backward compatibility
Expand Down Expand Up @@ -770,8 +783,15 @@ def __init__(self, *, entry: AttackResult):
self.attack_identifier = (
_attack_strategy_id.to_dict(max_value_length=MAX_IDENTIFIER_VALUE_LENGTH) if _attack_strategy_id else {}
)
# Ensure eval_hash is set before truncation so it survives the DB round-trip
if entry.atomic_attack_identifier and entry.atomic_attack_identifier.eval_hash is None:
entry.atomic_attack_identifier = entry.atomic_attack_identifier.with_eval_hash(
AtomicAttackEvaluationIdentifier(entry.atomic_attack_identifier).eval_hash
)
self.atomic_attack_identifier = (
entry.atomic_attack_identifier.to_dict(max_value_length=MAX_IDENTIFIER_VALUE_LENGTH)
entry.atomic_attack_identifier.to_dict(
max_value_length=MAX_IDENTIFIER_VALUE_LENGTH,
)
if entry.atomic_attack_identifier
else None
)
Expand Down Expand Up @@ -974,9 +994,16 @@ def __init__(self, *, entry: ScenarioResult):
self.objective_target_identifier = entry.objective_target_identifier.to_dict(
max_value_length=MAX_IDENTIFIER_VALUE_LENGTH
)
# Convert ComponentIdentifier to dict for JSON storage
# Ensure eval_hash is set before truncation so it survives the DB round-trip.
if entry.objective_scorer_identifier and entry.objective_scorer_identifier.eval_hash is None:
entry.objective_scorer_identifier = entry.objective_scorer_identifier.with_eval_hash(
ScorerEvaluationIdentifier(entry.objective_scorer_identifier).eval_hash
)

self.objective_scorer_identifier = (
entry.objective_scorer_identifier.to_dict(max_value_length=MAX_IDENTIFIER_VALUE_LENGTH)
entry.objective_scorer_identifier.to_dict(
max_value_length=MAX_IDENTIFIER_VALUE_LENGTH,
)
if entry.objective_scorer_identifier
else None
)
Expand Down
11 changes: 9 additions & 2 deletions pyrit/scenario/core/atomic_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from pyrit.executor.attack import AttackExecutor, AttackStrategy
from pyrit.executor.attack.core.attack_executor import AttackExecutorResult
from pyrit.identifiers import build_atomic_attack_identifier
from pyrit.identifiers.evaluation_identifier import AtomicAttackEvaluationIdentifier
from pyrit.memory import CentralMemory
from pyrit.memory.memory_models import MAX_IDENTIFIER_VALUE_LENGTH
from pyrit.models import AttackResult, SeedAttackGroup
Expand Down Expand Up @@ -251,13 +252,19 @@ def _enrich_atomic_attack_identifiers(self, *, results: AttackExecutorResult[Att
seed_group=self._seed_groups[idx],
)

# Persist the enriched identifier back to the database
# Persist the enriched identifier back to the database.
# Set eval_hash before truncation so it survives the DB round-trip.
if result.atomic_attack_identifier.eval_hash is None:
result.atomic_attack_identifier = result.atomic_attack_identifier.with_eval_hash(
AtomicAttackEvaluationIdentifier(result.atomic_attack_identifier).eval_hash
)

if result.attack_result_id:
memory.update_attack_result_by_id(
attack_result_id=result.attack_result_id,
update_fields={
"atomic_attack_identifier": result.atomic_attack_identifier.to_dict(
max_value_length=MAX_IDENTIFIER_VALUE_LENGTH
max_value_length=MAX_IDENTIFIER_VALUE_LENGTH,
),
},
)
2 changes: 1 addition & 1 deletion pyrit/score/float_scale/float_scale_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def get_scorer_metrics(self) -> Optional["HarmScorerMetrics"]:
return None

return find_harm_metrics_by_eval_hash(
eval_hash=self.get_eval_hash(),
eval_hash=self.get_identifier().eval_hash,
harm_category=self.evaluation_file_mapping.harm_category,
)

Expand Down
22 changes: 11 additions & 11 deletions pyrit/score/scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
pyrit_json_retry,
remove_markdown_json,
)
from pyrit.identifiers import ComponentIdentifier, Identifiable
from pyrit.identifiers import ComponentIdentifier, Identifiable, ScorerEvaluationIdentifier
from pyrit.memory import CentralMemory, MemoryInterface
from pyrit.models import (
ChatMessageRole,
Expand Down Expand Up @@ -70,21 +70,21 @@ def __init__(self, *, validator: ScorerPromptValidator):
"""
self._validator = validator

def get_eval_hash(self) -> str:
def get_identifier(self) -> ComponentIdentifier:
"""
Compute a behavioral equivalence hash for evaluation grouping.
Get the scorer's identifier with eval_hash always attached.

Delegates to ``ScorerEvaluationIdentifier`` which filters target children
(prompt_target, converter_target) to behavioral params only, so the same
scorer configuration on different deployments produces the same eval hash.
Overrides the base ``Identifiable.get_identifier()`` so that
``to_dict()`` always emits the ``eval_hash`` key.

Returns:
str: A hex-encoded SHA256 hash suitable for eval registry keying.
ComponentIdentifier: The identity with ``eval_hash`` set.
"""
# Deferred import to avoid circular dependency (evaluation_identifier → identifiers → …)
from pyrit.identifiers.evaluation_identifier import ScorerEvaluationIdentifier

return ScorerEvaluationIdentifier(self.get_identifier()).eval_hash
identifier = super().get_identifier()
if identifier.eval_hash is None:
identifier = identifier.with_eval_hash(ScorerEvaluationIdentifier(identifier).eval_hash)
self._identifier = identifier
return identifier

@property
def scorer_type(self) -> ScoreType:
Expand Down
4 changes: 2 additions & 2 deletions pyrit/score/scorer_evaluation/scorer_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def _should_skip_evaluation(
- (False, None) if should run evaluation
"""
try:
scorer_hash = self.scorer.get_eval_hash()
scorer_hash = self.scorer.get_identifier().eval_hash

# Determine if this is a harm or objective evaluation
metrics_type = MetricsType.OBJECTIVE if isinstance(self.scorer, TrueFalseScorer) else MetricsType.HARM
Expand Down Expand Up @@ -489,7 +489,7 @@ def _write_metrics_to_registry(
replace_evaluation_results(
file_path=result_file_path,
scorer_identifier=self.scorer.get_identifier(),
eval_hash=self.scorer.get_eval_hash(),
eval_hash=self.scorer.get_identifier().eval_hash,
metrics=metrics,
)
except Exception as e:
Expand Down
2 changes: 1 addition & 1 deletion pyrit/score/true_false/true_false_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def get_scorer_metrics(self) -> Optional["ObjectiveScorerMetrics"]:
if not result_file.exists():
return None

return find_objective_metrics_by_eval_hash(eval_hash=self.get_eval_hash(), file_path=result_file)
return find_objective_metrics_by_eval_hash(eval_hash=self.get_identifier().eval_hash, file_path=result_file)

async def _score_async(self, message: Message, *, objective: Optional[str] = None) -> list[Score]:
"""
Expand Down
Loading
Loading