From ae72a4afe4d8ddc6d135cecdd27b7266cdb59ce6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Tue, 11 Feb 2025 05:31:56 +0000 Subject: [PATCH 01/13] add helper functions to detect bad responses Files copied from https://github.com/cleanlab/cleanlab-codex/pull/11. Specifically: https://github.com/cleanlab/cleanlab-codex/pull/11/commits/49f9a9dd85f8fe06527db396e0affd68cafe1835 --- src/cleanlab_codex/response_validation.py | 291 ++++++++++++++++++++++ src/cleanlab_codex/utils/prompt.py | 21 ++ tests/test_response_validation.py | 174 +++++++++++++ 3 files changed, 486 insertions(+) create mode 100644 src/cleanlab_codex/response_validation.py create mode 100644 src/cleanlab_codex/utils/prompt.py create mode 100644 tests/test_response_validation.py diff --git a/src/cleanlab_codex/response_validation.py b/src/cleanlab_codex/response_validation.py new file mode 100644 index 0000000..f239397 --- /dev/null +++ b/src/cleanlab_codex/response_validation.py @@ -0,0 +1,291 @@ +""" +This module provides validation functions for evaluating LLM responses and determining if they should be replaced with Codex-generated alternatives. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable, Optional, Sequence, TypedDict, Union, cast + +from cleanlab_codex.utils.errors import MissingDependencyError +from cleanlab_codex.utils.prompt import default_format_prompt + +if TYPE_CHECKING: + try: + from cleanlab_studio.studio.trustworthy_language_model import TLM # type: ignore + except ImportError: + from typing import Any, Dict, Protocol, Sequence + + class _TLMProtocol(Protocol): + def get_trustworthiness_score( + self, + prompt: Union[str, Sequence[str]], + response: Union[str, Sequence[str]], + **kwargs: Any, + ) -> Dict[str, Any]: ... + + def prompt( + self, + prompt: Union[str, Sequence[str]], + /, + **kwargs: Any, + ) -> Dict[str, Any]: ... + + TLM = _TLMProtocol + + +DEFAULT_FALLBACK_ANSWER = "Based on the available information, I cannot provide a complete answer to this question." +DEFAULT_PARTIAL_RATIO_THRESHOLD = 70 +DEFAULT_TRUSTWORTHINESS_THRESHOLD = 0.5 + + +class BadResponseDetectionConfig(TypedDict, total=False): + """Configuration for bad response detection functions. + See get_bad_response_config() for default values. + + Attributes: + fallback_answer: Known unhelpful response to compare against + partial_ratio_threshold: Similarity threshold (0-100). Higher values require more similarity + trustworthiness_threshold: Score threshold (0.0-1.0). Lower values allow less trustworthy responses + format_prompt: Function to format (query, context) into a prompt string + unhelpfulness_confidence_threshold: Optional confidence threshold (0.0-1.0) for unhelpful classification + tlm: TLM model to use for evaluation (required for untrustworthiness and unhelpfulness checks) + """ + + # Fallback check config + fallback_answer: str + partial_ratio_threshold: int + + # Untrustworthy check config + trustworthiness_threshold: float + format_prompt: Callable[[str, str], str] + + # Unhelpful check config + unhelpfulness_confidence_threshold: Optional[float] + + # Shared config (for untrustworthiness and unhelpfulness checks) + tlm: Optional[TLM] + + +def get_bad_response_config() -> BadResponseDetectionConfig: + """Get the default configuration for bad response detection functions. 
+ + Returns: + BadResponseDetectionConfig: Default configuration for bad response detection functions + """ + return { + "fallback_answer": DEFAULT_FALLBACK_ANSWER, + "partial_ratio_threshold": DEFAULT_PARTIAL_RATIO_THRESHOLD, + "trustworthiness_threshold": DEFAULT_TRUSTWORTHINESS_THRESHOLD, + "format_prompt": default_format_prompt, + "unhelpfulness_confidence_threshold": None, + "tlm": None, + } + + +def is_bad_response( + response: str, + *, + context: Optional[str] = None, + query: Optional[str] = None, + config: Optional[BadResponseDetectionConfig] = None, +) -> bool: + """Run a series of checks to determine if a response is bad. + + If any check detects an issue (i.e. fails), the function returns True, indicating the response is bad. + + This function runs three possible validation checks: + 1. **Fallback check**: Detects if response is too similar to a known fallback answer. + 2. **Untrustworthy check**: Assesses response trustworthiness based on the given context and query. + 3. **Unhelpful check**: Predicts if the response adequately answers the query or not, in a useful way. + + Note: + Each validation check runs conditionally based on whether the required arguments are provided. + As soon as any validation check fails, the function returns True. + + Args: + response: The response to check. + context: Optional context/documents used for answering. Required for untrustworthy check. + query: Optional user question. Required for untrustworthy and unhelpful checks. + config: Optional, typed dictionary of configuration parameters. See <_BadReponseConfig> for details. + + Returns: + bool: True if any validation check fails, False if all pass. + """ + default_cfg = get_bad_response_config() + cfg: BadResponseDetectionConfig + cfg = {**default_cfg, **(config or {})} + + validation_checks: list[Callable[[], bool]] = [] + + # All required inputs are available for checking fallback responses + validation_checks.append( + lambda: is_fallback_response( + response, + cfg["fallback_answer"], + threshold=cfg["partial_ratio_threshold"], + ) + ) + + can_run_untrustworthy_check = query is not None and context is not None and cfg["tlm"] is not None + if can_run_untrustworthy_check: + # The if condition guarantees these are not None + validation_checks.append( + lambda: is_untrustworthy_response( + response=response, + context=cast(str, context), + query=cast(str, query), + tlm=cfg["tlm"], + trustworthiness_threshold=cfg["trustworthiness_threshold"], + format_prompt=cfg["format_prompt"], + ) + ) + + can_run_unhelpful_check = query is not None and cfg["tlm"] is not None + if can_run_unhelpful_check: + validation_checks.append( + lambda: is_unhelpful_response( + response=response, + query=cast(str, query), + tlm=cfg["tlm"], + trustworthiness_score_threshold=cast(float, cfg["unhelpfulness_confidence_threshold"]), + ) + ) + + return any(check() for check in validation_checks) + + +def is_fallback_response( + response: str, fallback_answer: str = DEFAULT_FALLBACK_ANSWER, threshold: int = DEFAULT_PARTIAL_RATIO_THRESHOLD +) -> bool: + """Check if a response is too similar to a known fallback answer. + + Uses fuzzy string matching to compare the response against a known fallback answer. + Returns True if the response is similar enough to be considered unhelpful. + + Args: + response: The response to check. + fallback_answer: A known unhelpful/fallback response to compare against. + threshold: Similarity threshold (0-100). Higher values require more similarity. 
+ Default 70 means responses that are 70% or more similar are considered bad. + + Returns: + bool: True if the response is too similar to the fallback answer, False otherwise + """ + try: + from thefuzz import fuzz # type: ignore + except ImportError as e: + raise MissingDependencyError( + import_name=e.name or "thefuzz", + package_url="https://github.com/seatgeek/thefuzz", + ) from e + + partial_ratio: int = fuzz.partial_ratio(fallback_answer.lower(), response.lower()) + return bool(partial_ratio >= threshold) + + +def is_untrustworthy_response( + response: str, + context: str, + query: str, + tlm: TLM, + trustworthiness_threshold: float = DEFAULT_TRUSTWORTHINESS_THRESHOLD, + format_prompt: Callable[[str, str], str] = default_format_prompt, +) -> bool: + """Check if a response is untrustworthy. + + Uses TLM to evaluate whether a response is trustworthy given the context and query. + Returns True if TLM's trustworthiness score falls below the threshold, indicating + the response may be incorrect or unreliable. + + Args: + response: The response to check from the assistant + context: The context information available for answering the query + query: The user's question or request + tlm: The TLM model to use for evaluation + trustworthiness_threshold: Score threshold (0.0-1.0). Lower values allow less trustworthy responses. + Default 0.5, meaning responses with scores less than 0.5 are considered untrustworthy. + format_prompt: Function that takes (query, context) and returns a formatted prompt string. + Users should provide their RAG app's own prompt formatting function here + to match how their LLM is prompted. + + Returns: + bool: True if the response is deemed untrustworthy by TLM, False otherwise + """ + try: + from cleanlab_studio import Studio # type: ignore[import-untyped] # noqa: F401 + except ImportError as e: + raise MissingDependencyError( + import_name=e.name or "cleanlab_studio", + package_name="cleanlab-studio", + package_url="https://github.com/cleanlab/cleanlab-studio", + ) from e + + prompt = format_prompt(query, context) + result = tlm.get_trustworthiness_score(prompt, response) + score: float = result["trustworthiness_score"] + return score < trustworthiness_threshold + + +def is_unhelpful_response( + response: str, + query: str, + tlm: TLM, + trustworthiness_score_threshold: Optional[float] = None, +) -> bool: + """Check if a response is unhelpful by asking TLM to evaluate it. + + Uses TLM to evaluate whether a response is helpful by asking it to make a Yes/No judgment. + The evaluation considers both the TLM's binary classification of helpfulness and its + confidence score. Returns True only if TLM classifies the response as unhelpful AND + is sufficiently confident in that assessment (if a threshold is provided). + + Args: + response: The response to check + query: User query that will be used to evaluate if the response is helpful + tlm: The TLM model to use for evaluation + trustworthiness_score_threshold: Optional confidence threshold (0.0-1.0) + If provided and the response is marked as unhelpful, + the confidence score must exceed this threshold for + the response to be considered truly unhelpful. 
+ + Returns: + bool: True if TLM determines the response is unhelpful with sufficient confidence, + False otherwise + """ + try: + from cleanlab_studio import Studio # noqa: F401 + except ImportError as e: + raise MissingDependencyError( + import_name=e.name or "cleanlab_studio", + package_name="cleanlab-studio", + package_url="https://github.com/cleanlab/cleanlab-studio", + ) from e + + # The question and expected "unhelpful" response are linked: + # - When asking "is helpful?" -> "no" means unhelpful + # - When asking "is unhelpful?" -> "yes" means unhelpful + question = ( + "Does the AI Assistant Response seem unhelpful? " + "Things that are not helpful include answers that:\n" + "- Are not useful, incomplete, incorrect, uncertain or unclear.\n" + "- Abstain or refuse to answer the question\n" + "- Statements which are similar to 'I don't know', 'Sorry', or 'No information available'.\n" + "- Leave the original question unresolved\n" + "- Are irrelevant to the question\n" + "Answer Yes/No only." + ) + expected_unhelpful_response = "yes" + + prompt = ( + "Consider the following User Query and AI Assistant Response.\n\n" + f"User Query: {query}\n\n" + f"AI Assistant Response: {response}\n\n" + f"{question}" + ) + + output = tlm.prompt(prompt, constrain_outputs=["Yes", "No"]) + response_marked_unhelpful = output["response"].lower() == expected_unhelpful_response + is_trustworthy = trustworthiness_score_threshold is None or ( + output["trustworthiness_score"] > trustworthiness_score_threshold + ) + return response_marked_unhelpful and is_trustworthy diff --git a/src/cleanlab_codex/utils/prompt.py b/src/cleanlab_codex/utils/prompt.py new file mode 100644 index 0000000..c04fc71 --- /dev/null +++ b/src/cleanlab_codex/utils/prompt.py @@ -0,0 +1,21 @@ +""" +Utility functions for RAG (Retrieval Augmented Generation) operations. +""" + + +def default_format_prompt(query: str, context: str) -> str: + """Default function for formatting RAG prompts. + + Args: + query: The user's question + context: The context/documents to use for answering + + Returns: + str: A formatted prompt combining the query and context + """ + template = ( + "Using only information from the following Context, answer the following Query.\n\n" + "Context:\n{context}\n\n" + "Query: {query}" + ) + return template.format(context=context, query=query) diff --git a/tests/test_response_validation.py b/tests/test_response_validation.py new file mode 100644 index 0000000..d10e661 --- /dev/null +++ b/tests/test_response_validation.py @@ -0,0 +1,174 @@ +"""Unit tests for validation module functions.""" + +from __future__ import annotations + +from unittest.mock import Mock, patch + +import pytest + +from cleanlab_codex.response_validation import ( + is_bad_response, + is_fallback_response, + is_unhelpful_response, + is_untrustworthy_response, +) + +# Mock responses for testing +GOOD_RESPONSE = "This is a helpful and specific response that answers the question completely." +BAD_RESPONSE = "Based on the available information, I cannot provide a complete answer." +QUERY = "What is the capital of France?" +CONTEXT = "Paris is the capital and largest city of France." 
+ + +@pytest.fixture +def mock_tlm() -> Mock: + """Create a mock TLM instance.""" + mock = Mock() + # Configure default return values + mock.get_trustworthiness_score.return_value = {"trustworthiness_score": 0.8} + mock.prompt.return_value = {"response": "No", "trustworthiness_score": 0.9} + return mock + + +@pytest.mark.parametrize( + ("response", "threshold", "fallback_answer", "expected"), + [ + # Test threshold variations + (GOOD_RESPONSE, 30, None, True), + (GOOD_RESPONSE, 55, None, False), + # Test default behavior (BAD_RESPONSE should be flagged) + (BAD_RESPONSE, None, None, True), + # Test default behavior for different response (GOOD_RESPONSE should not be flagged) + (GOOD_RESPONSE, None, None, False), + # Test custom fallback answer + (GOOD_RESPONSE, 80, "This is an unhelpful response", False), + ], +) +def test_is_fallback_response( + response: str, + threshold: float | None, + fallback_answer: str | None, + *, + expected: bool, +) -> None: + """Test fallback response detection.""" + kwargs: dict[str, float | str] = {} + if threshold is not None: + kwargs["threshold"] = threshold + if fallback_answer is not None: + kwargs["fallback_answer"] = fallback_answer + + assert is_fallback_response(response, **kwargs) is expected # type: ignore + + +def test_is_untrustworthy_response(mock_tlm: Mock) -> None: + """Test untrustworthy response detection.""" + # Test trustworthy response + mock_tlm.get_trustworthiness_score.return_value = {"trustworthiness_score": 0.8} + assert is_untrustworthy_response(GOOD_RESPONSE, CONTEXT, QUERY, mock_tlm, trustworthiness_threshold=0.5) is False + + # Test untrustworthy response + mock_tlm.get_trustworthiness_score.return_value = {"trustworthiness_score": 0.3} + assert is_untrustworthy_response(BAD_RESPONSE, CONTEXT, QUERY, mock_tlm, trustworthiness_threshold=0.5) is True + + +@pytest.mark.parametrize( + ("response", "tlm_response", "tlm_score", "threshold", "expected"), + [ + # Test helpful response + (GOOD_RESPONSE, "No", 0.9, 0.5, False), + # Test unhelpful response + (BAD_RESPONSE, "Yes", 0.9, 0.5, True), + # Test unhelpful response but low trustworthiness score + (BAD_RESPONSE, "Yes", 0.3, 0.5, False), + # Test without threshold - Yes prediction + (BAD_RESPONSE, "Yes", 0.3, None, True), + (GOOD_RESPONSE, "Yes", 0.3, None, True), + # Test without threshold - No prediction + (BAD_RESPONSE, "No", 0.3, None, False), + (GOOD_RESPONSE, "No", 0.3, None, False), + ], +) +def test_is_unhelpful_response( + mock_tlm: Mock, + response: str, + tlm_response: str, + tlm_score: float, + threshold: float | None, + *, + expected: bool, +) -> None: + """Test unhelpful response detection.""" + mock_tlm.prompt.return_value = {"response": tlm_response, "trustworthiness_score": tlm_score} + assert is_unhelpful_response(response, QUERY, mock_tlm, trustworthiness_score_threshold=threshold) is expected + + +@pytest.mark.parametrize( + ("response", "trustworthiness_score", "prompt_response", "prompt_score", "expected"), + [ + # Good response passes all checks + (GOOD_RESPONSE, 0.8, "No", 0.9, False), + # Bad response fails at least one check + (BAD_RESPONSE, 0.3, "Yes", 0.9, True), + ], +) +def test_is_bad_response( + mock_tlm: Mock, + response: str, + trustworthiness_score: float, + prompt_response: str, + prompt_score: float, + *, + expected: bool, +) -> None: + """Test the main is_bad_response function.""" + mock_tlm.get_trustworthiness_score.return_value = {"trustworthiness_score": trustworthiness_score} + mock_tlm.prompt.return_value = {"response": prompt_response, 
"trustworthiness_score": prompt_score} + + assert ( + is_bad_response( + response, + context=CONTEXT, + query=QUERY, + config={"tlm": mock_tlm}, + ) + is expected + ) + + +@pytest.mark.parametrize( + ("response", "fuzz_ratio", "prompt_response", "prompt_score", "query", "tlm", "expected"), + [ + # Test with only fallback check (no context/query/tlm) + (BAD_RESPONSE, 90, None, None, None, None, True), + # Test with fallback and unhelpful checks (no context) + (GOOD_RESPONSE, 30, "No", 0.9, QUERY, "mock_tlm", False), + ], +) +def test_is_bad_response_partial_inputs( + mock_tlm: Mock, + response: str, + fuzz_ratio: int, + prompt_response: str, + prompt_score: float, + query: str, + tlm: Mock, + *, + expected: bool, +) -> None: + """Test is_bad_response with partial inputs (some checks disabled).""" + mock_fuzz = Mock() + mock_fuzz.partial_ratio.return_value = fuzz_ratio + with patch.dict("sys.modules", {"thefuzz": Mock(fuzz=mock_fuzz)}): + if prompt_response is not None: + mock_tlm.prompt.return_value = {"response": prompt_response, "trustworthiness_score": prompt_score} + tlm = mock_tlm + + assert ( + is_bad_response( + response, + query=query, + config={"tlm": tlm}, + ) + is expected + ) From 7d17a4fcac2c569e47ed86c14202da5357a381f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Tue, 11 Feb 2025 05:36:18 +0000 Subject: [PATCH 02/13] update pyproject.toml dependencies --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index c7fa840..259170e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,8 @@ extra-dependencies = [ "pytest", "llama-index-core", "smolagents", + "cleanlab-studio", + "thefuzz", "langchain-core", ] [tool.hatch.envs.types.scripts] From 0a7c72ae75825eb40141c233d622cb60d58d7998 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Tue, 11 Feb 2025 05:38:07 +0000 Subject: [PATCH 03/13] tests also like those dependenies --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 259170e..a2b53a4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,8 @@ allow-direct-references = true extra-dependencies = [ "llama-index-core", "smolagents; python_version >= '3.10'", + "cleanlab-studio", + "thefuzz", "langchain-core", ] From af6d4cbd681a6a247b89916acb4ecb13c9f24a34 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Tue, 11 Feb 2025 15:19:35 +0000 Subject: [PATCH 04/13] partial ratio threshold -> fallback similarity threshold --- src/cleanlab_codex/response_validation.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/cleanlab_codex/response_validation.py b/src/cleanlab_codex/response_validation.py index f239397..1d6aecc 100644 --- a/src/cleanlab_codex/response_validation.py +++ b/src/cleanlab_codex/response_validation.py @@ -33,9 +33,9 @@ def prompt( TLM = _TLMProtocol -DEFAULT_FALLBACK_ANSWER = "Based on the available information, I cannot provide a complete answer to this question." -DEFAULT_PARTIAL_RATIO_THRESHOLD = 70 -DEFAULT_TRUSTWORTHINESS_THRESHOLD = 0.5 +DEFAULT_FALLBACK_ANSWER: str = "Based on the available information, I cannot provide a complete answer to this question." 
+DEFAULT_FALLBACK_SIMILARITY_THRESHOLD: int = 70 +DEFAULT_TRUSTWORTHINESS_THRESHOLD: float = 0.5 class BadResponseDetectionConfig(TypedDict, total=False): @@ -44,7 +44,9 @@ class BadResponseDetectionConfig(TypedDict, total=False): Attributes: fallback_answer: Known unhelpful response to compare against - partial_ratio_threshold: Similarity threshold (0-100). Higher values require more similarity + fallback_similarity_threshold: Fuzzy string matching similarity threshold (0-100). + Higher values mean responses must be more similar to fallback_answer + to be considered bad. trustworthiness_threshold: Score threshold (0.0-1.0). Lower values allow less trustworthy responses format_prompt: Function to format (query, context) into a prompt string unhelpfulness_confidence_threshold: Optional confidence threshold (0.0-1.0) for unhelpful classification @@ -53,7 +55,7 @@ class BadResponseDetectionConfig(TypedDict, total=False): # Fallback check config fallback_answer: str - partial_ratio_threshold: int + fallback_similarity_threshold: int # Fuzzy matching similarity threshold (0-100) # Untrustworthy check config trustworthiness_threshold: float @@ -74,7 +76,7 @@ def get_bad_response_config() -> BadResponseDetectionConfig: """ return { "fallback_answer": DEFAULT_FALLBACK_ANSWER, - "partial_ratio_threshold": DEFAULT_PARTIAL_RATIO_THRESHOLD, + "fallback_similarity_threshold": DEFAULT_FALLBACK_SIMILARITY_THRESHOLD, "trustworthiness_threshold": DEFAULT_TRUSTWORTHINESS_THRESHOLD, "format_prompt": default_format_prompt, "unhelpfulness_confidence_threshold": None, @@ -122,7 +124,7 @@ def is_bad_response( lambda: is_fallback_response( response, cfg["fallback_answer"], - threshold=cfg["partial_ratio_threshold"], + threshold=cfg["fallback_similarity_threshold"], ) ) @@ -155,7 +157,9 @@ def is_bad_response( def is_fallback_response( - response: str, fallback_answer: str = DEFAULT_FALLBACK_ANSWER, threshold: int = DEFAULT_PARTIAL_RATIO_THRESHOLD + response: str, + fallback_answer: str = DEFAULT_FALLBACK_ANSWER, + threshold: int = DEFAULT_FALLBACK_SIMILARITY_THRESHOLD, ) -> bool: """Check if a response is too similar to a known fallback answer. From 14f78af5496131f78e62dbf1a84595e5b3f32233 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Tue, 11 Feb 2025 15:23:28 +0000 Subject: [PATCH 05/13] Define type aliases used for format_prompt --- src/cleanlab_codex/response_validation.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/cleanlab_codex/response_validation.py b/src/cleanlab_codex/response_validation.py index 1d6aecc..1c30c43 100644 --- a/src/cleanlab_codex/response_validation.py +++ b/src/cleanlab_codex/response_validation.py @@ -37,6 +37,10 @@ def prompt( DEFAULT_FALLBACK_SIMILARITY_THRESHOLD: int = 70 DEFAULT_TRUSTWORTHINESS_THRESHOLD: float = 0.5 +Query = str +Context = str +Prompt = str + class BadResponseDetectionConfig(TypedDict, total=False): """Configuration for bad response detection functions. 
@@ -59,7 +63,7 @@ class BadResponseDetectionConfig(TypedDict, total=False): # Untrustworthy check config trustworthiness_threshold: float - format_prompt: Callable[[str, str], str] + format_prompt: Callable[[Query, Context], Prompt] # Unhelpful check config unhelpfulness_confidence_threshold: Optional[float] From 5d33c32461cace50d728c463629c21315f13cdfd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Tue, 11 Feb 2025 16:23:25 +0000 Subject: [PATCH 06/13] make get_default_config and apply_defaults class methods --- src/cleanlab_codex/response_validation.py | 50 ++++++++++++++--------- 1 file changed, 30 insertions(+), 20 deletions(-) diff --git a/src/cleanlab_codex/response_validation.py b/src/cleanlab_codex/response_validation.py index 1c30c43..69e6a6e 100644 --- a/src/cleanlab_codex/response_validation.py +++ b/src/cleanlab_codex/response_validation.py @@ -70,23 +70,35 @@ class BadResponseDetectionConfig(TypedDict, total=False): # Shared config (for untrustworthiness and unhelpfulness checks) tlm: Optional[TLM] - - -def get_bad_response_config() -> BadResponseDetectionConfig: - """Get the default configuration for bad response detection functions. - - Returns: - BadResponseDetectionConfig: Default configuration for bad response detection functions - """ - return { - "fallback_answer": DEFAULT_FALLBACK_ANSWER, - "fallback_similarity_threshold": DEFAULT_FALLBACK_SIMILARITY_THRESHOLD, - "trustworthiness_threshold": DEFAULT_TRUSTWORTHINESS_THRESHOLD, - "format_prompt": default_format_prompt, - "unhelpfulness_confidence_threshold": None, - "tlm": None, - } - + + @classmethod + def get_default_config(cls) -> BadResponseDetectionConfig: + """Get the default configuration for bad response detection functions. + + Returns: + BadResponseDetectionConfig: Default configuration for bad response detection functions + """ + return { + "fallback_answer": DEFAULT_FALLBACK_ANSWER, + "fallback_similarity_threshold": DEFAULT_FALLBACK_SIMILARITY_THRESHOLD, + "trustworthiness_threshold": DEFAULT_TRUSTWORTHINESS_THRESHOLD, + "format_prompt": default_format_prompt, + "unhelpfulness_confidence_threshold": None, + "tlm": None, + } + + @classmethod + def apply_defaults(cls, config: Optional[BadResponseDetectionConfig] = None) -> BadResponseDetectionConfig: + """Apply default values to a configuration dictionary with missing entries. + + Args: + config: Configuration dictionary to apply defaults to + + Returns: + BadResponseDetectionConfig: Configuration dictionary with defaults applied + """ + default_cfg = cls.get_default_config() + return {**default_cfg, **(config or {})} def is_bad_response( response: str, @@ -117,9 +129,7 @@ def is_bad_response( Returns: bool: True if any validation check fails, False if all pass. 
""" - default_cfg = get_bad_response_config() - cfg: BadResponseDetectionConfig - cfg = {**default_cfg, **(config or {})} + cfg = BadResponseDetectionConfig.apply_defaults(config) validation_checks: list[Callable[[], bool]] = [] From 41853a4ea9b1f79764e04d3331ea46bed9c3111d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Tue, 11 Feb 2025 16:36:30 +0000 Subject: [PATCH 07/13] update comment in is_unhelpful_response --- src/cleanlab_codex/response_validation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cleanlab_codex/response_validation.py b/src/cleanlab_codex/response_validation.py index 69e6a6e..fb26edf 100644 --- a/src/cleanlab_codex/response_validation.py +++ b/src/cleanlab_codex/response_validation.py @@ -279,7 +279,7 @@ def is_unhelpful_response( package_url="https://github.com/cleanlab/cleanlab-studio", ) from e - # The question and expected "unhelpful" response are linked: + # If editing `question`, make sure `expected_unhelpful_response` is still correct: # - When asking "is helpful?" -> "no" means unhelpful # - When asking "is unhelpful?" -> "yes" means unhelpful question = ( From e9bd0c589328c10eea8bec27b27edaa8fd3cdf29 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Tue, 11 Feb 2025 17:24:41 +0000 Subject: [PATCH 08/13] formatting --- src/cleanlab_codex/response_validation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cleanlab_codex/response_validation.py b/src/cleanlab_codex/response_validation.py index fb26edf..da56d2c 100644 --- a/src/cleanlab_codex/response_validation.py +++ b/src/cleanlab_codex/response_validation.py @@ -70,7 +70,7 @@ class BadResponseDetectionConfig(TypedDict, total=False): # Shared config (for untrustworthiness and unhelpfulness checks) tlm: Optional[TLM] - + @classmethod def get_default_config(cls) -> BadResponseDetectionConfig: """Get the default configuration for bad response detection functions. From 3da219404403f3af038492560a898c0e29759639 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Tue, 11 Feb 2025 17:26:37 +0000 Subject: [PATCH 09/13] more formatting --- src/cleanlab_codex/response_validation.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/cleanlab_codex/response_validation.py b/src/cleanlab_codex/response_validation.py index da56d2c..d5a3f52 100644 --- a/src/cleanlab_codex/response_validation.py +++ b/src/cleanlab_codex/response_validation.py @@ -33,7 +33,9 @@ def prompt( TLM = _TLMProtocol -DEFAULT_FALLBACK_ANSWER: str = "Based on the available information, I cannot provide a complete answer to this question." +DEFAULT_FALLBACK_ANSWER: str = ( + "Based on the available information, I cannot provide a complete answer to this question." 
+) DEFAULT_FALLBACK_SIMILARITY_THRESHOLD: int = 70 DEFAULT_TRUSTWORTHINESS_THRESHOLD: float = 0.5 @@ -100,6 +102,7 @@ def apply_defaults(cls, config: Optional[BadResponseDetectionConfig] = None) -> default_cfg = cls.get_default_config() return {**default_cfg, **(config or {})} + def is_bad_response( response: str, *, From a7c80586cff47a689be63a6ba74fc1fdd379d2b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Tue, 11 Feb 2025 18:23:02 +0000 Subject: [PATCH 10/13] update module docstring --- src/cleanlab_codex/utils/prompt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cleanlab_codex/utils/prompt.py b/src/cleanlab_codex/utils/prompt.py index c04fc71..2717ef5 100644 --- a/src/cleanlab_codex/utils/prompt.py +++ b/src/cleanlab_codex/utils/prompt.py @@ -1,5 +1,5 @@ """ -Utility functions for RAG (Retrieval Augmented Generation) operations. +Helper functions for processing prompts in RAG applications. """ From 90560134b117281bca66c62d64058f4136374de6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Tue, 11 Feb 2025 20:08:55 +0000 Subject: [PATCH 11/13] switch BadResponseDetectionConfig to a pydantic model --- src/cleanlab_codex/response_validation.py | 134 +++++++++------------- tests/test_response_validation.py | 63 +++++++--- 2 files changed, 105 insertions(+), 92 deletions(-) diff --git a/src/cleanlab_codex/response_validation.py b/src/cleanlab_codex/response_validation.py index d5a3f52..648f5b0 100644 --- a/src/cleanlab_codex/response_validation.py +++ b/src/cleanlab_codex/response_validation.py @@ -4,34 +4,48 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Optional, Sequence, TypedDict, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Optional, + Protocol, + Sequence, + Union, + cast, + runtime_checkable, +) + +from pydantic import BaseModel, ConfigDict, Field from cleanlab_codex.utils.errors import MissingDependencyError from cleanlab_codex.utils.prompt import default_format_prompt + +@runtime_checkable +class TLMProtocol(Protocol): + def get_trustworthiness_score( + self, + prompt: Union[str, Sequence[str]], + response: Union[str, Sequence[str]], + **kwargs: Any, + ) -> Dict[str, Any]: ... + + def prompt( + self, + prompt: Union[str, Sequence[str]], + /, + **kwargs: Any, + ) -> Dict[str, Any]: ... + if TYPE_CHECKING: try: from cleanlab_studio.studio.trustworthy_language_model import TLM # type: ignore except ImportError: - from typing import Any, Dict, Protocol, Sequence - - class _TLMProtocol(Protocol): - def get_trustworthiness_score( - self, - prompt: Union[str, Sequence[str]], - response: Union[str, Sequence[str]], - **kwargs: Any, - ) -> Dict[str, Any]: ... - - def prompt( - self, - prompt: Union[str, Sequence[str]], - /, - **kwargs: Any, - ) -> Dict[str, Any]: ... - - TLM = _TLMProtocol - + TLM = TLMProtocol +else: + TLM = TLMProtocol DEFAULT_FALLBACK_ANSWER: str = ( "Based on the available information, I cannot provide a complete answer to this question." @@ -44,71 +58,33 @@ def prompt( Prompt = str -class BadResponseDetectionConfig(TypedDict, total=False): - """Configuration for bad response detection functions. - See get_bad_response_config() for default values. +class BadResponseDetectionConfig(BaseModel): + """Configuration for bad response detection functions.""" - Attributes: - fallback_answer: Known unhelpful response to compare against - fallback_similarity_threshold: Fuzzy string matching similarity threshold (0-100). 
- Higher values mean responses must be more similar to fallback_answer - to be considered bad. - trustworthiness_threshold: Score threshold (0.0-1.0). Lower values allow less trustworthy responses - format_prompt: Function to format (query, context) into a prompt string - unhelpfulness_confidence_threshold: Optional confidence threshold (0.0-1.0) for unhelpful classification - tlm: TLM model to use for evaluation (required for untrustworthiness and unhelpfulness checks) - """ + model_config = ConfigDict(arbitrary_types_allowed=True) # Fallback check config - fallback_answer: str - fallback_similarity_threshold: int # Fuzzy matching similarity threshold (0-100) + fallback_answer: str = Field(default=DEFAULT_FALLBACK_ANSWER, description="Known unhelpful response to compare against") + fallback_similarity_threshold: int = Field(default=DEFAULT_FALLBACK_SIMILARITY_THRESHOLD, description="Fuzzy matching similarity threshold (0-100). Higher values mean responses must be more similar to fallback_answer to be considered bad.") # Untrustworthy check config - trustworthiness_threshold: float - format_prompt: Callable[[Query, Context], Prompt] + trustworthiness_threshold: float = Field(default=DEFAULT_TRUSTWORTHINESS_THRESHOLD, description="Score threshold (0.0-1.0). Lower values allow less trustworthy responses.") + format_prompt: Callable[[Query, Context], Prompt] = Field(default=default_format_prompt, description="Function to format (query, context) into a prompt string.") # Unhelpful check config - unhelpfulness_confidence_threshold: Optional[float] + unhelpfulness_confidence_threshold: Optional[float] = Field(default=None, description="Optional confidence threshold (0.0-1.0) for unhelpful classification.") # Shared config (for untrustworthiness and unhelpfulness checks) - tlm: Optional[TLM] - - @classmethod - def get_default_config(cls) -> BadResponseDetectionConfig: - """Get the default configuration for bad response detection functions. - - Returns: - BadResponseDetectionConfig: Default configuration for bad response detection functions - """ - return { - "fallback_answer": DEFAULT_FALLBACK_ANSWER, - "fallback_similarity_threshold": DEFAULT_FALLBACK_SIMILARITY_THRESHOLD, - "trustworthiness_threshold": DEFAULT_TRUSTWORTHINESS_THRESHOLD, - "format_prompt": default_format_prompt, - "unhelpfulness_confidence_threshold": None, - "tlm": None, - } - - @classmethod - def apply_defaults(cls, config: Optional[BadResponseDetectionConfig] = None) -> BadResponseDetectionConfig: - """Apply default values to a configuration dictionary with missing entries. - - Args: - config: Configuration dictionary to apply defaults to - - Returns: - BadResponseDetectionConfig: Configuration dictionary with defaults applied - """ - default_cfg = cls.get_default_config() - return {**default_cfg, **(config or {})} + tlm: Optional[TLM] = Field(default=None, description="TLM model to use for evaluation (required for untrustworthiness and unhelpfulness checks).") +DEFAULT_CONFIG = BadResponseDetectionConfig() def is_bad_response( response: str, *, context: Optional[str] = None, query: Optional[str] = None, - config: Optional[BadResponseDetectionConfig] = None, + config: Union[BadResponseDetectionConfig, Dict[str, Any]] = DEFAULT_CONFIG, ) -> bool: """Run a series of checks to determine if a response is bad. @@ -132,7 +108,7 @@ def is_bad_response( Returns: bool: True if any validation check fails, False if all pass. 
""" - cfg = BadResponseDetectionConfig.apply_defaults(config) + config = BadResponseDetectionConfig.model_validate(config) validation_checks: list[Callable[[], bool]] = [] @@ -140,12 +116,12 @@ def is_bad_response( validation_checks.append( lambda: is_fallback_response( response, - cfg["fallback_answer"], - threshold=cfg["fallback_similarity_threshold"], + config.fallback_answer, + threshold=config.fallback_similarity_threshold, ) ) - can_run_untrustworthy_check = query is not None and context is not None and cfg["tlm"] is not None + can_run_untrustworthy_check = query is not None and context is not None and config.tlm is not None if can_run_untrustworthy_check: # The if condition guarantees these are not None validation_checks.append( @@ -153,20 +129,20 @@ def is_bad_response( response=response, context=cast(str, context), query=cast(str, query), - tlm=cfg["tlm"], - trustworthiness_threshold=cfg["trustworthiness_threshold"], - format_prompt=cfg["format_prompt"], + tlm=config.tlm, + trustworthiness_threshold=config.trustworthiness_threshold, + format_prompt=config.format_prompt, ) ) - can_run_unhelpful_check = query is not None and cfg["tlm"] is not None + can_run_unhelpful_check = query is not None and config.tlm is not None if can_run_unhelpful_check: validation_checks.append( lambda: is_unhelpful_response( response=response, query=cast(str, query), - tlm=cfg["tlm"], - trustworthiness_score_threshold=cast(float, cfg["unhelpfulness_confidence_threshold"]), + tlm=config.tlm, + trustworthiness_score_threshold=cast(float, config.unhelpfulness_confidence_threshold), ) ) diff --git a/tests/test_response_validation.py b/tests/test_response_validation.py index d10e661..fa9e397 100644 --- a/tests/test_response_validation.py +++ b/tests/test_response_validation.py @@ -2,6 +2,7 @@ from __future__ import annotations +from typing import Any, Dict, Sequence, Union from unittest.mock import Mock, patch import pytest @@ -20,14 +21,47 @@ CONTEXT = "Paris is the capital and largest city of France." 
+class MockTLM(Mock): + + _trustworthiness_score: float = 0.8 + _response: str = "No" + + @property + def trustworthiness_score(self) -> float: + return self._trustworthiness_score + + @trustworthiness_score.setter + def trustworthiness_score(self, value: float) -> None: + self._trustworthiness_score = value + + @property + def response(self) -> str: + return self._response + + @response.setter + def response(self, value: str) -> None: + self._response = value + + def get_trustworthiness_score( + self, + prompt: Union[str, Sequence[str]], # noqa: ARG002 + response: Union[str, Sequence[str]], # noqa: ARG002 + **kwargs: Any, # noqa: ARG002 + ) -> Dict[str, Any]: + return {"trustworthiness_score": self._trustworthiness_score} + + def prompt( + self, + prompt: Union[str, Sequence[str]], # noqa: ARG002 + /, + **kwargs: Any, # noqa: ARG002 + ) -> Dict[str, Any]: + return {"response": self._response, "trustworthiness_score": self._trustworthiness_score} + + @pytest.fixture -def mock_tlm() -> Mock: - """Create a mock TLM instance.""" - mock = Mock() - # Configure default return values - mock.get_trustworthiness_score.return_value = {"trustworthiness_score": 0.8} - mock.prompt.return_value = {"response": "No", "trustworthiness_score": 0.9} - return mock +def mock_tlm() -> MockTLM: + return MockTLM() @pytest.mark.parametrize( @@ -64,11 +98,11 @@ def test_is_fallback_response( def test_is_untrustworthy_response(mock_tlm: Mock) -> None: """Test untrustworthy response detection.""" # Test trustworthy response - mock_tlm.get_trustworthiness_score.return_value = {"trustworthiness_score": 0.8} + mock_tlm.trustworthiness_score = 0.8 assert is_untrustworthy_response(GOOD_RESPONSE, CONTEXT, QUERY, mock_tlm, trustworthiness_threshold=0.5) is False # Test untrustworthy response - mock_tlm.get_trustworthiness_score.return_value = {"trustworthiness_score": 0.3} + mock_tlm.trustworthiness_score = 0.3 assert is_untrustworthy_response(BAD_RESPONSE, CONTEXT, QUERY, mock_tlm, trustworthiness_threshold=0.5) is True @@ -99,7 +133,8 @@ def test_is_unhelpful_response( expected: bool, ) -> None: """Test unhelpful response detection.""" - mock_tlm.prompt.return_value = {"response": tlm_response, "trustworthiness_score": tlm_score} + mock_tlm.response = tlm_response + mock_tlm.trustworthiness_score = tlm_score assert is_unhelpful_response(response, QUERY, mock_tlm, trustworthiness_score_threshold=threshold) is expected @@ -122,8 +157,9 @@ def test_is_bad_response( expected: bool, ) -> None: """Test the main is_bad_response function.""" - mock_tlm.get_trustworthiness_score.return_value = {"trustworthiness_score": trustworthiness_score} - mock_tlm.prompt.return_value = {"response": prompt_response, "trustworthiness_score": prompt_score} + mock_tlm.trustworthiness_score = trustworthiness_score + mock_tlm.response = prompt_response + mock_tlm.trustworthiness_score = prompt_score assert ( is_bad_response( @@ -161,7 +197,8 @@ def test_is_bad_response_partial_inputs( mock_fuzz.partial_ratio.return_value = fuzz_ratio with patch.dict("sys.modules", {"thefuzz": Mock(fuzz=mock_fuzz)}): if prompt_response is not None: - mock_tlm.prompt.return_value = {"response": prompt_response, "trustworthiness_score": prompt_score} + mock_tlm.response = prompt_response + mock_tlm.trustworthiness_score = prompt_score tlm = mock_tlm assert ( From d4b6264c12edb4b872d87009bd81d6f1f1a72c8d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Tue, 11 Feb 2025 20:15:41 +0000 Subject: [PATCH 12/13] formatting --- 
src/cleanlab_codex/response_validation.py | 32 ++++++++++++++++++----- tests/test_response_validation.py | 1 - 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/src/cleanlab_codex/response_validation.py b/src/cleanlab_codex/response_validation.py index 648f5b0..05bc039 100644 --- a/src/cleanlab_codex/response_validation.py +++ b/src/cleanlab_codex/response_validation.py @@ -39,6 +39,7 @@ def prompt( **kwargs: Any, ) -> Dict[str, Any]: ... + if TYPE_CHECKING: try: from cleanlab_studio.studio.trustworthy_language_model import TLM # type: ignore @@ -64,21 +65,40 @@ class BadResponseDetectionConfig(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) # Fallback check config - fallback_answer: str = Field(default=DEFAULT_FALLBACK_ANSWER, description="Known unhelpful response to compare against") - fallback_similarity_threshold: int = Field(default=DEFAULT_FALLBACK_SIMILARITY_THRESHOLD, description="Fuzzy matching similarity threshold (0-100). Higher values mean responses must be more similar to fallback_answer to be considered bad.") + fallback_answer: str = Field( + default=DEFAULT_FALLBACK_ANSWER, description="Known unhelpful response to compare against" + ) + fallback_similarity_threshold: int = Field( + default=DEFAULT_FALLBACK_SIMILARITY_THRESHOLD, + description="Fuzzy matching similarity threshold (0-100). Higher values mean responses must be more similar to fallback_answer to be considered bad.", + ) # Untrustworthy check config - trustworthiness_threshold: float = Field(default=DEFAULT_TRUSTWORTHINESS_THRESHOLD, description="Score threshold (0.0-1.0). Lower values allow less trustworthy responses.") - format_prompt: Callable[[Query, Context], Prompt] = Field(default=default_format_prompt, description="Function to format (query, context) into a prompt string.") + trustworthiness_threshold: float = Field( + default=DEFAULT_TRUSTWORTHINESS_THRESHOLD, + description="Score threshold (0.0-1.0). 
Lower values allow less trustworthy responses.", + ) + format_prompt: Callable[[Query, Context], Prompt] = Field( + default=default_format_prompt, + description="Function to format (query, context) into a prompt string.", + ) # Unhelpful check config - unhelpfulness_confidence_threshold: Optional[float] = Field(default=None, description="Optional confidence threshold (0.0-1.0) for unhelpful classification.") + unhelpfulness_confidence_threshold: Optional[float] = Field( + default=None, + description="Optional confidence threshold (0.0-1.0) for unhelpful classification.", + ) # Shared config (for untrustworthiness and unhelpfulness checks) - tlm: Optional[TLM] = Field(default=None, description="TLM model to use for evaluation (required for untrustworthiness and unhelpfulness checks).") + tlm: Optional[TLM] = Field( + default=None, + description="TLM model to use for evaluation (required for untrustworthiness and unhelpfulness checks).", + ) + DEFAULT_CONFIG = BadResponseDetectionConfig() + def is_bad_response( response: str, *, diff --git a/tests/test_response_validation.py b/tests/test_response_validation.py index fa9e397..cbc1d29 100644 --- a/tests/test_response_validation.py +++ b/tests/test_response_validation.py @@ -22,7 +22,6 @@ class MockTLM(Mock): - _trustworthiness_score: float = 0.8 _response: str = "No" From bd89b195b191d22120a4d1855463181558bba3c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Tue, 11 Feb 2025 21:41:30 +0000 Subject: [PATCH 13/13] update TLM typing --- src/cleanlab_codex/response_validation.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/src/cleanlab_codex/response_validation.py b/src/cleanlab_codex/response_validation.py index 05bc039..dcc15d5 100644 --- a/src/cleanlab_codex/response_validation.py +++ b/src/cleanlab_codex/response_validation.py @@ -5,7 +5,6 @@ from __future__ import annotations from typing import ( - TYPE_CHECKING, Any, Callable, Dict, @@ -24,7 +23,7 @@ @runtime_checkable -class TLMProtocol(Protocol): +class TLM(Protocol): def get_trustworthiness_score( self, prompt: Union[str, Sequence[str]], @@ -40,14 +39,6 @@ def prompt( ) -> Dict[str, Any]: ... -if TYPE_CHECKING: - try: - from cleanlab_studio.studio.trustworthy_language_model import TLM # type: ignore - except ImportError: - TLM = TLMProtocol -else: - TLM = TLMProtocol - DEFAULT_FALLBACK_ANSWER: str = ( "Based on the available information, I cannot provide a complete answer to this question." ) @@ -149,7 +140,7 @@ def is_bad_response( response=response, context=cast(str, context), query=cast(str, query), - tlm=config.tlm, + tlm=cast(TLM, config.tlm), trustworthiness_threshold=config.trustworthiness_threshold, format_prompt=config.format_prompt, ) @@ -161,7 +152,7 @@ def is_bad_response( lambda: is_unhelpful_response( response=response, query=cast(str, query), - tlm=config.tlm, + tlm=cast(TLM, config.tlm), trustworthiness_score_threshold=cast(float, config.unhelpfulness_confidence_threshold), ) )