diff --git a/plexus/docs/score-yaml-format.md b/plexus/docs/score-yaml-format.md
index 1c2dbcb33..f2f0f23c2 100644
--- a/plexus/docs/score-yaml-format.md
+++ b/plexus/docs/score-yaml-format.md
@@ -310,6 +310,61 @@ data:
     balance: false  # Whether to balance positive/negative examples
 ```
 
+## Validation
+
+Validation rules raise exceptions when score results don't match expected formats.
+
+```yaml
+validation:
+  value:
+    valid_classes: ["Yes", "No", "Maybe"]    # Must be one of these (case-insensitive by default)
+    patterns: ["^NQ - (?!Other$).*"]         # Must match regex ("NQ - " followed by anything except "Other")
+    minimum_length: 1                        # Min string length
+    maximum_length: 50                       # Max string length
+    case_sensitive: false                    # Case-insensitive comparison (default: false)
+  explanation:
+    minimum_length: 10
+    patterns: [".*evidence.*", ".*found.*"]  # Must contain "evidence" or "found"
+```
+
+**All constraints must pass (AND logic). Validation runs automatically in `predict()`.**
+
+### Case Sensitivity
+
+By default, `valid_classes` comparisons are **case-insensitive** for convenience:
+
+```yaml
+validation:
+  value:
+    valid_classes: ["Yes", "No"]  # Matches "yes", "YES", "Yes", "no", "NO", "No"
+```
+
+Enable case-sensitive comparison when exact case matters:
+
+```yaml
+validation:
+  value:
+    valid_classes: ["Yes", "No"]
+    case_sensitive: true  # Only matches exact case: "Yes", "No"
+```
+
+```yaml
+# Simple cases
+validation:
+  value:
+    valid_classes: ["Yes", "No"]
+
+validation:
+  value:
+    patterns: ["^Critical - .*"]
+
+# Mixed validation - value must be in list AND match pattern
+validation:
+  value:
+    valid_classes: ["Yes", "No", "NQ - Pricing"]
+    patterns: ["^(Yes|No)$", "^NQ - (?!Other$).*"]
+```
+
 ## Best Practices
 
 1. Use modern `Classifier` over legacy classifier types
@@ -317,4 +372,5 @@ data:
 3. Use descriptive node names and field mappings
 4. Leverage slicers for complex transcript analysis
 5. Include clear system and user prompts
-6. Avoid redundant processing by sharing results between nodes
\ No newline at end of file
+6. Avoid redundant processing by sharing results between nodes
+7. Add validation constraints to ensure consistent score outputs
\ No newline at end of file
diff --git a/plexus/scores/Score.py b/plexus/scores/Score.py
index 14472965b..1da273f4a 100644
--- a/plexus/scores/Score.py
+++ b/plexus/scores/Score.py
@@ -4,9 +4,9 @@ import pandas as pd
 import numpy as np
 import inspect
-import functools
+import re
 from pydantic import BaseModel, ValidationError, field_validator, ConfigDict
-from typing import Optional, Union, List, Any
+from typing import Optional, Union, List, Any, Dict
 from abc import ABC, abstractmethod
 import matplotlib.pyplot as plt
 import seaborn as sns
@@ -31,6 +31,8 @@ from plexus.scores.core.ScoreMLFlow import ScoreMLFlow
 from plexus.scores.core.utils import ensure_report_directory_exists
 
 
+
+
 class Score(ABC, mlflow.pyfunc.PythonModel,
     # Core Score functionality.
     ScoreData,
@@ -55,8 +57,23 @@ class Score(ABC, mlflow.pyfunc.PythonModel,
     - Visualization tools for model performance
     - Cost tracking for API-based models
     - Metrics computation and logging
+    - Automatic result validation based on YAML configuration
+
+    ## Validation System
+
+    The Score class automatically validates prediction results against configured
+    validation rules. When a Score has validation configured in its YAML parameters,
+    every call to predict() will automatically validate the returned results.
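+
+    For example (a sketch, assuming a score whose parameters include a
+    validation block like the one documented in plexus/docs/score-yaml-format.md):
+
+        result = my_score.predict(context, model_input)
+        # raises Score.ValidationError if the result violates any constraint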
+
+    Validation rules can specify:
+    - valid_classes: List of allowed result values
+    - patterns: Regex patterns that results must match
+    - minimum_length/maximum_length: String length constraints
+
+    If validation fails, a Score.ValidationError is raised with a descriptive message.
+
+    ## Common Usage Patterns
-    Common usage patterns:
     1. Creating a custom classifier:
         class MyClassifier(Score):
             def predict(self, context, model_input: Score.Input) -> Score.Result:
@@ -67,14 +84,18 @@ def predict(self, context, model_input: Score.Input) -> Score.Result:
                 value="Yes" if is_positive(text) else "No"
             )
 
-    2. Using in a Scorecard:
+    2. Using in a Scorecard with validation:
         scores:
           MyScore:
             class: MyClassifier
+            validation:
+              value:
+                valid_classes: ["Yes", "No", "Maybe"]
+                patterns: ["^NQ - (?!Other$).*"]
             parameters:
               threshold: 0.8
 
-    3. Training a model:
+    3. Training and evaluating:
         classifier = MyClassifier()
         classifier.train_model()
         classifier.evaluate_model()
@@ -85,11 +106,20 @@ def predict(self, context, model_input: Score.Input) -> Score.Result:
             text="content to classify",
             metadata={"source": "email"}
         ))
+        # Result is automatically validated against configured rules
 
     The Score class is designed to be extended for different classification approaches
     while maintaining a consistent interface for use in Scorecards and Evaluations.
     """
 
+    class ValidationError(Exception):
+        """Raised when score result validation fails."""
+        def __init__(self, field_name: str, value: Any, message: str):
+            self.field_name = field_name
+            self.value = value
+            self.message = message
+            super().__init__(f"Validation failed for field '{field_name}': {message}")
+
     class SkippedScoreException(Exception):
         """Raised when a score is skipped due to dependency conditions not being met."""
         def __init__(self, score_name: str, reason: str):
@@ -97,6 +127,19 @@ def __init__(self, score_name: str, reason: str):
             self.reason = reason
             super().__init__(f"Score '{score_name}' was skipped: {reason}")
 
+    class FieldValidation(BaseModel):
+        """Configuration for validating a specific field."""
+        valid_classes: Optional[List[str]] = None
+        patterns: Optional[List[str]] = None
+        minimum_length: Optional[int] = None
+        maximum_length: Optional[int] = None
+        case_sensitive: bool = False
+
+    class ValidationConfig(BaseModel):
+        """Configuration for result validation."""
+        value: Optional['Score.FieldValidation'] = None
+        explanation: Optional['Score.FieldValidation'] = None
+
     class Parameters(BaseModel):
         """
         Parameters required for scoring.
 
         Attributes
         ----------
         data : dict
             Dictionary containing data-related parameters.
+        validation : ValidationConfig
+            Configuration for validating score results.
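+
+        Example
+        -------
+        A minimal sketch of configuring validation programmatically; the same
+        structure is normally parsed from a score's YAML:
+
+            Score.Parameters(
+                validation=Score.ValidationConfig(
+                    value=Score.FieldValidation(valid_classes=["Yes", "No"])
+                )
+            )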
""" model_config = ConfigDict(protected_namespaces=()) scorecard_name: Optional[str] = None @@ -116,6 +161,7 @@ class Parameters(BaseModel): number_of_classes: Optional[int] = None label_score_name: Optional[str] = None label_field: Optional[str] = None + validation: Optional['Score.ValidationConfig'] = None @field_validator('data') def convert_data_percentage(cls, value): @@ -175,7 +221,7 @@ class Input(BaseModel): ) """ model_config = ConfigDict(protected_namespaces=()) - text: str + text: str metadata: dict = {} results: Optional[List[Any]] = None @@ -231,11 +277,11 @@ class Result(BaseModel): """ model_config = ConfigDict(protected_namespaces=()) parameters: 'Score.Parameters' - value: Union[str, bool] + value: Union[str, bool] explanation: Optional[str] = None - confidence: Optional[float] = None - metadata: dict = {} - error: Optional[str] = None + confidence: Optional[float] = None + metadata: dict = {} + error: Optional[str] = None def __eq__(self, other): if isinstance(other, Score.Result): @@ -260,6 +306,110 @@ def confidence_from_metadata(self) -> Optional[float]: """Backwards compatibility: confidence from metadata""" return self.metadata.get('confidence') if self.metadata else None + def validate(self, validation_config: Optional['Score.ValidationConfig']) -> None: + """ + Validate this result against the provided validation configuration. + + Parameters + ---------- + validation_config : Score.ValidationConfig, optional + The validation configuration to check against + + Raises + ------ + Score.ValidationError + If validation fails for any field + """ + if not validation_config: + return + + # Validate value field + if validation_config.value: + self._validate_field('value', str(self.value), validation_config.value) + + # Validate explanation field + if validation_config.explanation and self.explanation is not None: + self._validate_field('explanation', self.explanation, validation_config.explanation) + + def _validate_field(self, field_name: str, field_value: str, field_config: 'Score.FieldValidation') -> None: + """ + Validate a single field against its configuration. 
+
+            Parameters
+            ----------
+            field_name : str
+                Name of the field being validated
+            field_value : str
+                Value of the field being validated
+            field_config : Score.FieldValidation
+                Validation configuration for this field
+
+            Raises
+            ------
+            Score.ValidationError
+                If validation fails
+            """
+            # Check valid_classes
+            if field_config.valid_classes is not None:
+                if field_config.case_sensitive:
+                    # Case-sensitive comparison
+                    if field_value not in field_config.valid_classes:
+                        raise Score.ValidationError(
+                            field_name,
+                            field_value,
+                            f"'{field_value}' is not in valid_classes {field_config.valid_classes}"
+                        )
+                else:
+                    # Case-insensitive comparison (default)
+                    field_value_lower = field_value.lower()
+                    valid_classes_lower = [cls.lower() for cls in field_config.valid_classes]
+                    if field_value_lower not in valid_classes_lower:
+                        raise Score.ValidationError(
+                            field_name,
+                            field_value,
+                            f"'{field_value}' is not in valid_classes {field_config.valid_classes} (case-insensitive)"
+                        )
+
+            # Check patterns
+            if field_config.patterns is not None:
+                pattern_matched = False
+                for pattern in field_config.patterns:
+                    try:
+                        if re.match(pattern, field_value):
+                            pattern_matched = True
+                            break
+                    except re.error as e:
+                        raise Score.ValidationError(
+                            field_name,
+                            field_value,
+                            f"Invalid regex pattern '{pattern}': {str(e)}"
+                        )
+
+                if not pattern_matched:
+                    raise Score.ValidationError(
+                        field_name,
+                        field_value,
+                        f"'{field_value}' does not match any required patterns {field_config.patterns}"
+                    )
+
+            # Check minimum_length
+            if field_config.minimum_length is not None:
+                if len(field_value) < field_config.minimum_length:
+                    raise Score.ValidationError(
+                        field_name,
+                        field_value,
+                        f"length {len(field_value)} is below minimum_length {field_config.minimum_length}"
+                    )
+
+            # Check maximum_length
+            if field_config.maximum_length is not None:
+                if len(field_value) > field_config.maximum_length:
+                    raise Score.ValidationError(
+                        field_name,
+                        field_value,
+                        f"length {len(field_value)} exceeds maximum_length {field_config.maximum_length}"
+                    )
+
     def __init__(self, **parameters):
         """
         Initialize the Score instance with the given parameters.
@@ -453,6 +603,103 @@ def predict(self, context, model_input: Input) -> Union[Result, List[Result]]:
             Either a single Score.Result or a list of Score.Results
         """
         pass
+
+    def __getattribute__(self, name):
+        """
+        Automatic validation interceptor for predict() method calls.
+
+        This method intercepts all attribute access on Score instances. When the
+        'predict' method is accessed, it returns a wrapped version that automatically
+        applies validation to the results.
+
+        ## How It Works
+
+        1. Normal attribute access (non-predict methods) passes through unchanged
+        2. When predict() is accessed, we wrap it with validation logic
+        3. The wrapper calls the original predict() method implementation
+        4. If validation is configured, results are validated before returning
+        5. If validation fails, Score.ValidationError is raised
+
+        ## Compatibility
+
+        The wrapper handles different predict() method signatures:
+        - Standard: predict(self, context, model_input)
+        - Keyword-only: predict(self, *, context, model_input)
+
+        This ensures compatibility with existing Score implementations while
+        providing automatic validation for all subclasses.
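+
+        ## Example
+
+        A sketch of the effective behavior (assuming the score's parameters
+        include a validation config):
+
+            score = MyClassifier(validation=config)
+            score.predict(context, model_input)  # validated; may raise Score.ValidationError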
+
+        Parameters
+        ----------
+        name : str
+            Name of the attribute being accessed
+
+        Returns
+        -------
+        Any
+            The requested attribute, with predict() methods automatically wrapped
+            for validation
+        """
+        attr = super().__getattribute__(name)
+
+        # If it's the predict method, wrap it with validation
+        if name == 'predict' and callable(attr):
+            original_predict = attr
+
+            def validated_predict(*args, **kwargs):
+                # Handle different predict method signatures:
+                # 1. predict(context, model_input) - standard signature
+                # 2. predict(model_input) - legacy signature
+                # 3. predict(*, context, model_input) - keyword-only signature
+                #
+                # TODO: This complexity can be removed once predict() signatures are standardized.
+                # See: https://github.com/AnthusAI/Plexus/issues/77
+
+                try:
+                    # Try calling with original arguments
+                    results = original_predict(*args, **kwargs)
+                except TypeError as e:
+                    # If that fails, adapt the call based on signature inspection
+                    # (inspect is imported at module level)
+                    sig = inspect.signature(original_predict)
+
+                    # Check if this is a legacy single-argument predict method
+                    if len(sig.parameters) == 1:
+                        # Legacy signature: predict(model_input)
+                        if len(args) == 2:
+                            # Called with (context, model_input), use just model_input
+                            results = original_predict(args[1])
+                        elif len(args) == 1:
+                            # Called with (model_input), pass through
+                            results = original_predict(args[0])
+                        elif 'model_input' in kwargs:
+                            # Called with keyword args
+                            results = original_predict(kwargs['model_input'])
+                        else:
+                            # Re-raise original error if we can't adapt
+                            raise e
+                    else:
+                        # Standard signature but different calling convention
+                        if len(args) == 2:
+                            # Try keyword-only arguments
+                            results = original_predict(context=args[0], model_input=args[1])
+                        else:
+                            # Re-raise original error
+                            raise e
+
+                # Apply validation if configured
+                if hasattr(self, 'parameters') and self.parameters.validation:
+                    if isinstance(results, list):
+                        for result in results:
+                            result.validate(self.parameters.validation)
+                    else:
+                        results.validate(self.parameters.validation)
+
+                return results
+
+            return validated_predict
+
+        return attr
 
     def is_relevant(self, text):
         """
@@ -579,7 +826,7 @@ def get_accumulated_costs(self):
         """
         return {
-            "total_cost": 0 
+            "total_cost": 0
         }
 
     def get_label_score_name(self):
@@ -614,4 +861,8 @@ def from_name(cls, scorecard, score):
 
         return score_class(**score_parameters)
 
+# Rebuild models to resolve forward references
+Score.FieldValidation.model_rebuild()
+Score.ValidationConfig.model_rebuild()
+Score.Parameters.model_rebuild()
 Score.Result.model_rebuild()
\ No newline at end of file
diff --git a/tests/test_score_validation.py b/tests/test_score_validation.py
new file mode 100644
index 000000000..2f62aab17
--- /dev/null
+++ b/tests/test_score_validation.py
@@ -0,0 +1,526 @@
+"""
+Comprehensive tests for Score validation functionality.
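+
+These tests exercise Score.Result.validate() directly and the automatic
+validation that predict() applies via Score.__getattribute__().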
+
+Tests cover all validation scenarios including:
+- valid_classes validation
+- pattern validation
+- mixed validation (valid_classes + patterns)
+- length validation
+- NQ- pattern exclusion scenarios
+- automatic validation applied by predict()
+"""
+
+import pytest
+from plexus.scores.Score import Score
+
+
+class MockScore(Score):
+    """Mock implementation of Score for testing validation."""
+
+    def __init__(self, **parameters):
+        super().__init__(**parameters)
+        self.mock_result = None
+
+    def set_mock_result(self, result: Score.Result):
+        """Set the result this mock should return."""
+        self.mock_result = result
+
+    def predict(self, context, model_input: Score.Input) -> Score.Result:
+        """Return the pre-configured mock result."""
+        if self.mock_result is None:
+            return Score.Result(
+                parameters=self.parameters,
+                value="Yes",
+                explanation="Mock result"
+            )
+        return self.mock_result
+
+
+class TestScoreValidation:
+    """Test cases for Score validation functionality."""
+
+    def test_no_validation_config_passes(self):
+        """Test that results pass validation when no validation config is provided."""
+        score = MockScore()
+        result = Score.Result(
+            parameters=score.parameters,
+            value="Any Value",
+            explanation="Any explanation"
+        )
+
+        # Should not raise any exception
+        result.validate(None)
+        result.validate(Score.ValidationConfig())
+
+    def test_valid_classes_validation_success(self):
+        """Test successful validation with valid_classes."""
+        validation_config = Score.ValidationConfig(
+            value=Score.FieldValidation(valid_classes=["Yes", "No", "Maybe"])
+        )
+
+        result = Score.Result(
+            parameters=Score.Parameters(),
+            value="Yes",
+            explanation="Test explanation"
+        )
+
+        # Should not raise exception
+        result.validate(validation_config)
+
+    def test_valid_classes_validation_failure(self):
+        """Test validation failure with invalid class."""
+        validation_config = Score.ValidationConfig(
+            value=Score.FieldValidation(valid_classes=["Yes", "No"])
+        )
+
+        result = Score.Result(
+            parameters=Score.Parameters(),
+            value="Maybe",
+            explanation="Test explanation"
+        )
+
+        with pytest.raises(Score.ValidationError) as exc_info:
+            result.validate(validation_config)
+
+        assert exc_info.value.field_name == "value"
+        assert exc_info.value.value == "Maybe"
+        assert "'Maybe' is not in valid_classes ['Yes', 'No']" in str(exc_info.value)
+
+    def test_case_insensitive_validation_success(self):
+        """Test case-insensitive validation (default behavior)."""
+        validation_config = Score.ValidationConfig(
+            value=Score.FieldValidation(valid_classes=["Yes", "No", "Maybe"])
+        )
+
+        # Test different cases - should all pass
+        test_cases = ["yes", "YES", "Yes", "no", "NO", "No", "maybe", "MAYBE", "Maybe"]
+
+        for test_value in test_cases:
+            result = Score.Result(
+                parameters=Score.Parameters(),
+                value=test_value,
+                explanation="Test explanation"
+            )
+            # Should not raise exception (case-insensitive by default)
+            result.validate(validation_config)
+
+    def test_case_insensitive_validation_failure(self):
+        """Test case-insensitive validation failure."""
+        validation_config = Score.ValidationConfig(
+            value=Score.FieldValidation(valid_classes=["Yes", "No"])
+        )
+
+        result = Score.Result(
+            parameters=Score.Parameters(),
+            value="maybe",  # Not in valid_classes even case-insensitively
+            explanation="Test explanation"
+        )
+
+        with pytest.raises(Score.ValidationError) as exc_info:
+            result.validate(validation_config)
+
+        assert exc_info.value.field_name == "value"
+        assert exc_info.value.value == "maybe"
"(case-insensitive)" in str(exc_info.value) + + def test_case_sensitive_validation_success(self): + """Test case-sensitive validation when explicitly enabled.""" + validation_config = Score.ValidationConfig( + value=Score.FieldValidation( + valid_classes=["Yes", "No", "Maybe"], + case_sensitive=True + ) + ) + + result = Score.Result( + parameters=Score.Parameters(), + value="Yes", # Exact case match + explanation="Test explanation" + ) + + # Should not raise exception + result.validate(validation_config) + + def test_case_sensitive_validation_failure(self): + """Test case-sensitive validation failure.""" + validation_config = Score.ValidationConfig( + value=Score.FieldValidation( + valid_classes=["Yes", "No"], + case_sensitive=True + ) + ) + + result = Score.Result( + parameters=Score.Parameters(), + value="yes", # Wrong case, should fail when case_sensitive=True + explanation="Test explanation" + ) + + with pytest.raises(Score.ValidationError) as exc_info: + result.validate(validation_config) + + assert exc_info.value.field_name == "value" + assert exc_info.value.value == "yes" + # Should not have "(case-insensitive)" in message for case-sensitive validation + assert "(case-insensitive)" not in str(exc_info.value) + + def test_pattern_validation_success(self): + """Test successful validation with regex patterns.""" + validation_config = Score.ValidationConfig( + value=Score.FieldValidation(patterns=["^(Yes|No)$", "^Maybe.*"]) + ) + + # Test first pattern match + result1 = Score.Result( + parameters=Score.Parameters(), + value="Yes", + explanation="Test" + ) + result1.validate(validation_config) + + # Test second pattern match + result2 = Score.Result( + parameters=Score.Parameters(), + value="Maybe sometimes", + explanation="Test" + ) + result2.validate(validation_config) + + def test_pattern_validation_failure(self): + """Test validation failure with patterns that don't match.""" + validation_config = Score.ValidationConfig( + value=Score.FieldValidation(patterns=["^(Yes|No)$"]) + ) + + result = Score.Result( + parameters=Score.Parameters(), + value="Maybe", + explanation="Test" + ) + + with pytest.raises(Score.ValidationError) as exc_info: + result.validate(validation_config) + + assert "'Maybe' does not match any required patterns" in str(exc_info.value) + + def test_nq_pattern_exclusion_success(self): + """Test NQ- pattern that excludes 'NQ - Other'.""" + # Pattern matches "NQ - " followed by anything except "Other" + validation_config = Score.ValidationConfig( + value=Score.FieldValidation(patterns=["^NQ - (?!Other$).*"]) + ) + + # These should pass + valid_values = [ + "NQ - Pricing", + "NQ - Technical Support", + "NQ - Billing", + "NQ - General Info" + ] + + for value in valid_values: + result = Score.Result( + parameters=Score.Parameters(), + value=value, + explanation="Test" + ) + result.validate(validation_config) # Should not raise + + def test_nq_pattern_exclusion_failure(self): + """Test NQ- pattern correctly excludes 'NQ - Other'.""" + validation_config = Score.ValidationConfig( + value=Score.FieldValidation(patterns=["^NQ - (?!Other$).*"]) + ) + + result = Score.Result( + parameters=Score.Parameters(), + value="NQ - Other", + explanation="Test" + ) + + with pytest.raises(Score.ValidationError) as exc_info: + result.validate(validation_config) + + assert "'NQ - Other' does not match any required patterns" in str(exc_info.value) + + def test_mixed_validation_success(self): + """Test successful validation with both valid_classes and patterns.""" + validation_config = 
+        validation_config = Score.ValidationConfig(
+            value=Score.FieldValidation(
+                valid_classes=["Yes", "No", "NQ - Pricing"],
+                patterns=["^(Yes|No)$", "^NQ - (?!Other$).*"]
+            )
+        )
+
+        # Value must be in valid_classes AND match a pattern
+        result = Score.Result(
+            parameters=Score.Parameters(),
+            value="NQ - Pricing",  # In valid_classes AND matches NQ pattern
+            explanation="Test"
+        )
+
+        result.validate(validation_config)  # Should not raise
+
+    def test_mixed_validation_failure_valid_classes(self):
+        """Test mixed validation fails when valid_classes check fails."""
+        validation_config = Score.ValidationConfig(
+            value=Score.FieldValidation(
+                valid_classes=["Yes", "No"],
+                patterns=["^NQ - (?!Other$).*"]
+            )
+        )
+
+        result = Score.Result(
+            parameters=Score.Parameters(),
+            value="NQ - Pricing",  # Matches pattern but NOT in valid_classes
+            explanation="Test"
+        )
+
+        with pytest.raises(Score.ValidationError) as exc_info:
+            result.validate(validation_config)
+
+        assert "'NQ - Pricing' is not in valid_classes ['Yes', 'No']" in str(exc_info.value)
+
+    def test_mixed_validation_failure_patterns(self):
+        """Test mixed validation fails when pattern check fails."""
+        validation_config = Score.ValidationConfig(
+            value=Score.FieldValidation(
+                valid_classes=["Yes", "No", "Maybe"],
+                patterns=["^(Yes|No)$"]  # "Maybe" won't match this pattern
+            )
+        )
+
+        result = Score.Result(
+            parameters=Score.Parameters(),
+            value="Maybe",  # In valid_classes but doesn't match pattern
+            explanation="Test"
+        )
+
+        with pytest.raises(Score.ValidationError) as exc_info:
+            result.validate(validation_config)
+
+        assert "'Maybe' does not match any required patterns" in str(exc_info.value)
+
+    def test_length_validation_success(self):
+        """Test successful length validation."""
+        validation_config = Score.ValidationConfig(
+            explanation=Score.FieldValidation(
+                minimum_length=5,
+                maximum_length=50
+            )
+        )
+
+        result = Score.Result(
+            parameters=Score.Parameters(),
+            value="Yes",
+            explanation="This explanation is just right"
+        )
+
+        result.validate(validation_config)  # Should not raise
+
+    def test_minimum_length_validation_failure(self):
+        """Test validation failure when text is too short."""
+        validation_config = Score.ValidationConfig(
+            explanation=Score.FieldValidation(minimum_length=10)
+        )
+
+        result = Score.Result(
+            parameters=Score.Parameters(),
+            value="Yes",
+            explanation="Short"  # Only 5 characters
+        )
+
+        with pytest.raises(Score.ValidationError) as exc_info:
+            result.validate(validation_config)
+
+        assert "length 5 is below minimum_length 10" in str(exc_info.value)
+
+    def test_maximum_length_validation_failure(self):
+        """Test validation failure when text is too long."""
+        validation_config = Score.ValidationConfig(
+            explanation=Score.FieldValidation(maximum_length=10)
+        )
+
+        result = Score.Result(
+            parameters=Score.Parameters(),
+            value="Yes",
+            explanation="This explanation is way too long for the limit"
+        )
+
+        with pytest.raises(Score.ValidationError) as exc_info:
+            result.validate(validation_config)
+
+        assert "exceeds maximum_length 10" in str(exc_info.value)
+
+    def test_explanation_none_skips_validation(self):
+        """Test that None explanation skips validation."""
+        validation_config = Score.ValidationConfig(
+            explanation=Score.FieldValidation(minimum_length=10)
+        )
+
+        result = Score.Result(
+            parameters=Score.Parameters(),
+            value="Yes",
+            explanation=None
+        )
+
+        # Should not raise exception even though explanation is None
+        result.validate(validation_config)
+
+    def test_invalid_regex_pattern_error(self):
+        """Test that invalid regex patterns raise appropriate errors."""
+        validation_config = Score.ValidationConfig(
+            value=Score.FieldValidation(patterns=["[invalid regex"])
+        )
+
+        result = Score.Result(
+            parameters=Score.Parameters(),
+            value="Test",
+            explanation="Test"
+        )
+
+        with pytest.raises(Score.ValidationError) as exc_info:
+            result.validate(validation_config)
+
+        assert "Invalid regex pattern" in str(exc_info.value)
+
+    def test_predict_with_validation_success(self):
+        """Test predict method with successful validation."""
+        validation_config = Score.ValidationConfig(
+            value=Score.FieldValidation(valid_classes=["Yes", "No"])
+        )
+
+        score = MockScore(validation=validation_config)
+        mock_result = Score.Result(
+            parameters=score.parameters,
+            value="Yes",
+            explanation="Valid result"
+        )
+        score.set_mock_result(mock_result)
+
+        input_data = Score.Input(text="test input")
+        result = score.predict(None, input_data)
+
+        assert result.value == "Yes"
+
+    def test_predict_with_validation_failure(self):
+        """Test predict method with validation failure."""
+        validation_config = Score.ValidationConfig(
+            value=Score.FieldValidation(valid_classes=["Yes", "No"])
+        )
+
+        score = MockScore(validation=validation_config)
+        mock_result = Score.Result(
+            parameters=score.parameters,
+            value="Maybe",  # Invalid according to validation config
+            explanation="Invalid result"
+        )
+        score.set_mock_result(mock_result)
+
+        input_data = Score.Input(text="test input")
+
+        with pytest.raises(Score.ValidationError) as exc_info:
+            score.predict(None, input_data)
+
+        assert "'Maybe' is not in valid_classes ['Yes', 'No']" in str(exc_info.value)
+
+    def test_predict_with_validation_list_results(self):
+        """Test predict method with list of results."""
+        validation_config = Score.ValidationConfig(
+            value=Score.FieldValidation(valid_classes=["Yes", "No"])
+        )
+
+        # Create a custom MockScore class for this test
+        class ListResultsMockScore(MockScore):
+            def predict(self, context, model_input):
+                return [
+                    Score.Result(parameters=self.parameters, value="Yes", explanation="First"),
+                    Score.Result(parameters=self.parameters, value="No", explanation="Second")
+                ]
+
+        score = ListResultsMockScore(validation=validation_config)
+
+        input_data = Score.Input(text="test input")
+        results = score.predict(None, input_data)
+
+        assert len(results) == 2
+        assert results[0].value == "Yes"
+        assert results[1].value == "No"
+
+    def test_predict_with_validation_list_results_failure(self):
+        """Test predict method with invalid result in list."""
+        validation_config = Score.ValidationConfig(
+            value=Score.FieldValidation(valid_classes=["Yes", "No"])
+        )
+
+        # Create a custom MockScore class for this test
+        class ListResultsMockScore(MockScore):
+            def predict(self, context, model_input):
+                return [
+                    Score.Result(parameters=self.parameters, value="Yes", explanation="Valid"),
+                    Score.Result(parameters=self.parameters, value="Maybe", explanation="Invalid")
+                ]
+
+        score = ListResultsMockScore(validation=validation_config)
+
+        input_data = Score.Input(text="test input")
+
+        with pytest.raises(Score.ValidationError) as exc_info:
+            score.predict(None, input_data)
+
+        assert "'Maybe' is not in valid_classes ['Yes', 'No']" in str(exc_info.value)
+
+    def test_predict_with_validation_no_config(self):
+        """Test predict method works normally without validation config."""
+        score = MockScore()  # No validation config
+        mock_result = Score.Result(
+            parameters=score.parameters,
+            value="Any Value",
+            explanation="Any explanation"
+        )
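+        # With no validation configured, the mock's value should pass through unchanged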
+        score.set_mock_result(mock_result)
+
+        input_data = Score.Input(text="test input")
+        result = score.predict(None, input_data)
+
+        assert result.value == "Any Value"
+
+    def test_comprehensive_validation_scenario(self):
+        """Test a comprehensive validation scenario with multiple constraints."""
+        validation_config = Score.ValidationConfig(
+            value=Score.FieldValidation(
+                valid_classes=["NQ - Pricing", "NQ - Technical", "NQ - Billing", "Yes", "No"],
+                patterns=["^(Yes|No)$", "^NQ - (?!Other$).*"]
+            ),
+            explanation=Score.FieldValidation(
+                minimum_length=15,
+                maximum_length=200,
+                patterns=[".*found.*", ".*evidence.*", ".*clear.*"]
+            )
+        )
+
+        # This should pass all validations
+        result = Score.Result(
+            parameters=Score.Parameters(),
+            value="NQ - Pricing",  # In valid_classes AND matches NQ pattern
+            explanation="Clear evidence found in the transcript"  # Right length AND contains required words
+        )
+
+        result.validate(validation_config)  # Should not raise
+
+        # This should fail explanation pattern validation
+        result_bad_explanation = Score.Result(
+            parameters=Score.Parameters(),
+            value="NQ - Pricing",
+            explanation="This explanation does not contain the required terms"
+        )
+
+        with pytest.raises(Score.ValidationError) as exc_info:
+            result_bad_explanation.validate(validation_config)
+
+        assert "does not match any required patterns" in str(exc_info.value)
+        assert exc_info.value.field_name == "explanation"
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
\ No newline at end of file