From 3acb048dd1565ff6affd736a563163a66448c6a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Fri, 24 Jan 2025 20:09:13 +0000 Subject: [PATCH 01/39] add class to configure a decorator that treats Codex as a backup --- src/cleanlab_codex/__init__.py | 3 +- src/cleanlab_codex/codex_backup.py | 105 ++++++++++++++++++ .../utils/response_validators.py | 80 +++++++++++++ tests/test_codex_backup.py | 65 +++++++++++ 4 files changed, 252 insertions(+), 1 deletion(-) create mode 100644 src/cleanlab_codex/codex_backup.py create mode 100644 src/cleanlab_codex/utils/response_validators.py create mode 100644 tests/test_codex_backup.py diff --git a/src/cleanlab_codex/__init__.py b/src/cleanlab_codex/__init__.py index 472f034..78c2535 100644 --- a/src/cleanlab_codex/__init__.py +++ b/src/cleanlab_codex/__init__.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: MIT from cleanlab_codex.codex import Codex from cleanlab_codex.codex_tool import CodexTool +from cleanlab_codex.codex_backup import CodexBackup -__all__ = ["Codex", "CodexTool"] +__all__ = ["Codex", "CodexTool", "CodexBackup"] diff --git a/src/cleanlab_codex/codex_backup.py b/src/cleanlab_codex/codex_backup.py new file mode 100644 index 0000000..7e94ab4 --- /dev/null +++ b/src/cleanlab_codex/codex_backup.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +from typing import Any, Callable, Optional +from functools import wraps + +from cleanlab_codex.codex import Codex +from cleanlab_codex.utils.response_validators import is_bad_response + +def handle_backup_default(backup_response: str, decorated_instance: Any) -> None: + """Default implementation is a no-op.""" + return None + + +class CodexBackup: + """A backup decorator that connects to a Codex project to answer questions that + cannot be adequately answered by the existing agent. + """ + DEFAULT_FALLBACK_ANSWER = "Based on the available information, I cannot provide a complete answer to this question." + + def __init__( + self, + codex_client: Codex, + *, + project_id: Optional[str] = None, + fallback_answer: Optional[str] = DEFAULT_FALLBACK_ANSWER, + backup_handler: Callable[[str, Any], None] = handle_backup_default, + ): + self._codex_client = codex_client + self._project_id = project_id + self._fallback_answer = fallback_answer + self._backup_handler = backup_handler + + @classmethod + def from_access_key( + cls, + access_key: str, + *, + project_id: Optional[str] = None, + fallback_answer: Optional[str] = DEFAULT_FALLBACK_ANSWER, + backup_handler: Callable[[str, Any], None] = handle_backup_default, + ) -> CodexBackup: + """Creates a CodexBackup from an access key. The project ID that the CodexBackup will use is the one that is associated with the access key.""" + return cls( + codex_client=Codex(key=access_key), + project_id=project_id, + fallback_answer=fallback_answer, + backup_handler=backup_handler, + ) + + @classmethod + def from_client( + cls, + codex_client: Codex, + *, + project_id: Optional[str] = None, + fallback_answer: Optional[str] = DEFAULT_FALLBACK_ANSWER, + backup_handler: Callable[[str, Any], None] = handle_backup_default, + ) -> CodexBackup: + """Creates a CodexBackup from a Codex client. + If the Codex client is initialized with a project access key, the CodexBackup will use the project ID that is associated with the access key. + If the Codex client is initialized with a user API key, a project ID must be provided. 
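
        Example (sketch; the key shown is a placeholder):
            >>> client = Codex(key="<project access key>")
            >>> backup = CodexBackup.from_client(client)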
+ """ + return cls( + codex_client=codex_client, + project_id=project_id, + fallback_answer=fallback_answer, + backup_handler=backup_handler, + ) + + def to_decorator(self): + """Factory that creates a backup decorator using the provided Codex client""" + def decorator(chat_method): + """ + Decorator for RAG chat methods that adds backup response handling. + + If the original chat method returns an inadequate response, attempts to get + a backup response from Codex. Returns the backup response if available, + otherwise returns the original response. + + Args: + chat_method: Method with signature (self, user_message: str) -> str + where 'self' refers to the instance being decorated, not an instance of CodexBackup. + """ + @wraps(chat_method) + def wrapper(decorated_instance, user_message): + # Call the original chat method + assistant_response = chat_method(decorated_instance, user_message) + + # Return original response if it's adequate + if not is_bad_response(assistant_response): + return assistant_response + + # Query Codex for a backup response + cache_result = self._codex_client.query(user_message)[0] + if not cache_result: + return assistant_response + + # Handle backup response if handler exists + self._backup_handler( + backup_response=cache_result, + decorated_instance=decorated_instance, + ) + return cache_result + return wrapper + return decorator diff --git a/src/cleanlab_codex/utils/response_validators.py b/src/cleanlab_codex/utils/response_validators.py new file mode 100644 index 0000000..a07a5c1 --- /dev/null +++ b/src/cleanlab_codex/utils/response_validators.py @@ -0,0 +1,80 @@ +""" +This module provides validation functions for checking if an LLM response is inadequate/unhelpful. +The default implementation checks for common fallback phrases, but alternative implementations +are provided below as examples that can be adapted for specific needs. +""" + + +def is_bad_response(response: str) -> bool: + """ + Default implementation that checks for common fallback phrases from LLM assistants. + + NOTE: YOU SHOULD MODIFY THIS METHOD YOURSELF. + """ + return basic_validator(response) + +def basic_validator(response: str) -> bool: + """Basic implementation that checks for common fallback phrases from LLM assistants. + + Args: + response: The response from the assistant + + Returns: + bool: True if the response appears to be a fallback/inadequate response + """ + partial_fallback_responses = [ + "Based on the available information", + "I cannot provide a complete answer to this question", + # Add more substrings here to improve the recall of the check + ] + return any( + partial_fallback_response.lower() in response.lower() + for partial_fallback_response in partial_fallback_responses + ) + +# Alternative Implementations +# --------------------------- +# The following implementations are provided as examples and inspiration. +# They should be adapted to your specific needs. 
+ + +# Fuzzy String Matching +""" +from thefuzz import fuzz + +def fuzzy_match_validator(response: str, fallback_answer: str, threshold: int = 70) -> bool: + partial_ratio = fuzz.partial_ratio(fallback_answer.lower(), response.lower()) + return partial_ratio >= threshold +""" + +# TLM Score Thresholding +""" +from cleanlab_studio import Studio + +studio = Studio("") +tlm = studio.TLM() + +def tlm_score_validator(response: str, context: str, query: str, tlm: TLM, threshold: float = 0.5) -> bool: + prompt = f"Context: {context}\n\n Query: {query}\n\n Query: {query}" + resp = tlm.get_trustworthiness_score(prompt, response) + score = resp['trustworthiness_score'] + return score < threshold +""" + +# TLM Binary Classification +""" +from typing import Optional + +from cleanlab_studio import Studio + +studio = Studio("") +tlm = studio.TLM() + +def tlm_binary_validator(response: str, tlm: TLM, query: Optional[str] = None) -> bool: + if query is None: + prompt = f"Here is a response from an AI assistant: {response}\n\n Is it helpful? Answer Yes/No only." + else: + prompt = f"Here is a response from an AI assistant: {response}\n\n Considering the following query: {query}\n\n Is the response helpful? Answer Yes/No only." + output = tlm.prompt(prompt) + return output["response"].lower() == "no" +""" \ No newline at end of file diff --git a/tests/test_codex_backup.py b/tests/test_codex_backup.py new file mode 100644 index 0000000..1e82dbf --- /dev/null +++ b/tests/test_codex_backup.py @@ -0,0 +1,65 @@ +from unittest.mock import MagicMock + +from cleanlab_codex.codex_backup import CodexBackup + +MOCK_BACKUP_RESPONSE = "This is a test response" +FALLBACK_MESSAGE = "Based on the available information, I cannot provide a complete answer to this question." +TEST_MESSAGE = "Hello, world!" + + +def test_codex_backup(mock_client: MagicMock): # noqa: ARG001 + mock_response = MagicMock() + mock_response.answer = MOCK_BACKUP_RESPONSE + mock_client.projects.entries.query.return_value = mock_response + + codex_backup = CodexBackup.from_access_key("") + + class MockApp: + @codex_backup.to_decorator() + def chat(self, user_message: str) -> str: + # Just echo the user message + return user_message + + app = MockApp() + + # Echo works well + response = app.chat(TEST_MESSAGE) + assert response == TEST_MESSAGE + + # Backup works well for fallback responses + response = app.chat(FALLBACK_MESSAGE) + assert response == MOCK_BACKUP_RESPONSE + +def test_backup_handler(mock_client: MagicMock): + mock_response = MagicMock() + mock_response.answer = MOCK_BACKUP_RESPONSE + mock_client.projects.entries.query.return_value = mock_response + + mock_handler = MagicMock() + mock_handler.return_value = None + codex_backup = CodexBackup.from_access_key("", backup_handler=mock_handler) + + class MockApp: + @codex_backup.to_decorator() + def chat(self, user_message: str) -> str: + # Just echo the user message + return user_message + + app = MockApp() + + response = app.chat(TEST_MESSAGE) + assert response == TEST_MESSAGE + + # Handler should not be called for good responses + assert mock_handler.call_count == 0 + + response = app.chat(FALLBACK_MESSAGE) + assert response == MOCK_BACKUP_RESPONSE + + # Handler should be called for bad responses + assert mock_handler.call_count == 1 + # The MockApp is the second argument to the handler, i.e. 
it has the necessary context + # to handle the new response + assert mock_handler.call_args.kwargs["decorated_instance"] == app + + \ No newline at end of file From 74755d2d7223b0d80caa63b4696b40d0aa53ed18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Fri, 24 Jan 2025 22:17:07 +0000 Subject: [PATCH 02/39] formatting --- src/cleanlab_codex/__init__.py | 2 +- src/cleanlab_codex/codex_backup.py | 22 ++++++++++------ .../utils/response_validators.py | 12 +++++---- tests/test_codex_backup.py | 25 +++++++++---------- 4 files changed, 34 insertions(+), 27 deletions(-) diff --git a/src/cleanlab_codex/__init__.py b/src/cleanlab_codex/__init__.py index 78c2535..ffae4d2 100644 --- a/src/cleanlab_codex/__init__.py +++ b/src/cleanlab_codex/__init__.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: MIT from cleanlab_codex.codex import Codex -from cleanlab_codex.codex_tool import CodexTool from cleanlab_codex.codex_backup import CodexBackup +from cleanlab_codex.codex_tool import CodexTool __all__ = ["Codex", "CodexTool", "CodexBackup"] diff --git a/src/cleanlab_codex/codex_backup.py b/src/cleanlab_codex/codex_backup.py index 7e94ab4..8eba18e 100644 --- a/src/cleanlab_codex/codex_backup.py +++ b/src/cleanlab_codex/codex_backup.py @@ -1,12 +1,13 @@ from __future__ import annotations -from typing import Any, Callable, Optional from functools import wraps +from typing import Any, Callable, Optional from cleanlab_codex.codex import Codex from cleanlab_codex.utils.response_validators import is_bad_response -def handle_backup_default(backup_response: str, decorated_instance: Any) -> None: + +def handle_backup_default(backup_response: str, decorated_instance: Any) -> None: # noqa: ARG001 """Default implementation is a no-op.""" return None @@ -15,8 +16,9 @@ class CodexBackup: """A backup decorator that connects to a Codex project to answer questions that cannot be adequately answered by the existing agent. """ + DEFAULT_FALLBACK_ANSWER = "Based on the available information, I cannot provide a complete answer to this question." - + def __init__( self, codex_client: Codex, @@ -69,37 +71,41 @@ def from_client( def to_decorator(self): """Factory that creates a backup decorator using the provided Codex client""" + def decorator(chat_method): """ Decorator for RAG chat methods that adds backup response handling. - + If the original chat method returns an inadequate response, attempts to get a backup response from Codex. Returns the backup response if available, otherwise returns the original response. - + Args: chat_method: Method with signature (self, user_message: str) -> str where 'self' refers to the instance being decorated, not an instance of CodexBackup. 
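
            Example (mirrors the usage in the test suite):
                @codex_backup.to_decorator()
                def chat(self, user_message: str) -> str:
                    ...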
""" + @wraps(chat_method) def wrapper(decorated_instance, user_message): # Call the original chat method assistant_response = chat_method(decorated_instance, user_message) - + # Return original response if it's adequate if not is_bad_response(assistant_response): return assistant_response - + # Query Codex for a backup response cache_result = self._codex_client.query(user_message)[0] if not cache_result: return assistant_response - + # Handle backup response if handler exists self._backup_handler( backup_response=cache_result, decorated_instance=decorated_instance, ) return cache_result + return wrapper + return decorator diff --git a/src/cleanlab_codex/utils/response_validators.py b/src/cleanlab_codex/utils/response_validators.py index a07a5c1..deaec9c 100644 --- a/src/cleanlab_codex/utils/response_validators.py +++ b/src/cleanlab_codex/utils/response_validators.py @@ -8,22 +8,23 @@ def is_bad_response(response: str) -> bool: """ Default implementation that checks for common fallback phrases from LLM assistants. - + NOTE: YOU SHOULD MODIFY THIS METHOD YOURSELF. """ return basic_validator(response) - + + def basic_validator(response: str) -> bool: """Basic implementation that checks for common fallback phrases from LLM assistants. Args: response: The response from the assistant - + Returns: bool: True if the response appears to be a fallback/inadequate response """ partial_fallback_responses = [ - "Based on the available information", + "Based on the available information", "I cannot provide a complete answer to this question", # Add more substrings here to improve the recall of the check ] @@ -32,6 +33,7 @@ def basic_validator(response: str) -> bool: for partial_fallback_response in partial_fallback_responses ) + # Alternative Implementations # --------------------------- # The following implementations are provided as examples and inspiration. @@ -77,4 +79,4 @@ def tlm_binary_validator(response: str, tlm: TLM, query: Optional[str] = None) - prompt = f"Here is a response from an AI assistant: {response}\n\n Considering the following query: {query}\n\n Is the response helpful? Answer Yes/No only." output = tlm.prompt(prompt) return output["response"].lower() == "no" -""" \ No newline at end of file +""" diff --git a/tests/test_codex_backup.py b/tests/test_codex_backup.py index 1e82dbf..59caea0 100644 --- a/tests/test_codex_backup.py +++ b/tests/test_codex_backup.py @@ -7,29 +7,30 @@ TEST_MESSAGE = "Hello, world!" 
-def test_codex_backup(mock_client: MagicMock): # noqa: ARG001 +def test_codex_backup(mock_client: MagicMock): mock_response = MagicMock() mock_response.answer = MOCK_BACKUP_RESPONSE mock_client.projects.entries.query.return_value = mock_response codex_backup = CodexBackup.from_access_key("") - + class MockApp: @codex_backup.to_decorator() def chat(self, user_message: str) -> str: # Just echo the user message return user_message - + app = MockApp() # Echo works well response = app.chat(TEST_MESSAGE) assert response == TEST_MESSAGE - + # Backup works well for fallback responses response = app.chat(FALLBACK_MESSAGE) assert response == MOCK_BACKUP_RESPONSE - + + def test_backup_handler(mock_client: MagicMock): mock_response = MagicMock() mock_response.answer = MOCK_BACKUP_RESPONSE @@ -38,28 +39,26 @@ def test_backup_handler(mock_client: MagicMock): mock_handler = MagicMock() mock_handler.return_value = None codex_backup = CodexBackup.from_access_key("", backup_handler=mock_handler) - + class MockApp: @codex_backup.to_decorator() def chat(self, user_message: str) -> str: # Just echo the user message return user_message - + app = MockApp() - + response = app.chat(TEST_MESSAGE) assert response == TEST_MESSAGE - + # Handler should not be called for good responses assert mock_handler.call_count == 0 - + response = app.chat(FALLBACK_MESSAGE) assert response == MOCK_BACKUP_RESPONSE - + # Handler should be called for bad responses assert mock_handler.call_count == 1 # The MockApp is the second argument to the handler, i.e. it has the necessary context # to handle the new response assert mock_handler.call_args.kwargs["decorated_instance"] == app - - \ No newline at end of file From f7f81568552da539acbec58832d593a3576b0024 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Thu, 30 Jan 2025 02:18:15 +0000 Subject: [PATCH 03/39] Move is_bad_response helper functions to a validation.py module for checking LLM response quality --- .../utils/response_validators.py | 173 +++++++++--------- src/cleanlab_codex/validation.py | 124 +++++++++++++ 2 files changed, 215 insertions(+), 82 deletions(-) create mode 100644 src/cleanlab_codex/validation.py diff --git a/src/cleanlab_codex/utils/response_validators.py b/src/cleanlab_codex/utils/response_validators.py index deaec9c..592db01 100644 --- a/src/cleanlab_codex/utils/response_validators.py +++ b/src/cleanlab_codex/utils/response_validators.py @@ -1,82 +1,91 @@ -""" -This module provides validation functions for checking if an LLM response is inadequate/unhelpful. -The default implementation checks for common fallback phrases, but alternative implementations -are provided below as examples that can be adapted for specific needs. -""" - - -def is_bad_response(response: str) -> bool: - """ - Default implementation that checks for common fallback phrases from LLM assistants. - - NOTE: YOU SHOULD MODIFY THIS METHOD YOURSELF. - """ - return basic_validator(response) - - -def basic_validator(response: str) -> bool: - """Basic implementation that checks for common fallback phrases from LLM assistants. 
- - Args: - response: The response from the assistant - - Returns: - bool: True if the response appears to be a fallback/inadequate response - """ - partial_fallback_responses = [ - "Based on the available information", - "I cannot provide a complete answer to this question", - # Add more substrings here to improve the recall of the check - ] - return any( - partial_fallback_response.lower() in response.lower() - for partial_fallback_response in partial_fallback_responses - ) - - -# Alternative Implementations -# --------------------------- -# The following implementations are provided as examples and inspiration. -# They should be adapted to your specific needs. - - -# Fuzzy String Matching -""" -from thefuzz import fuzz - -def fuzzy_match_validator(response: str, fallback_answer: str, threshold: int = 70) -> bool: - partial_ratio = fuzz.partial_ratio(fallback_answer.lower(), response.lower()) - return partial_ratio >= threshold -""" - -# TLM Score Thresholding -""" -from cleanlab_studio import Studio - -studio = Studio("") -tlm = studio.TLM() - -def tlm_score_validator(response: str, context: str, query: str, tlm: TLM, threshold: float = 0.5) -> bool: - prompt = f"Context: {context}\n\n Query: {query}\n\n Query: {query}" - resp = tlm.get_trustworthiness_score(prompt, response) - score = resp['trustworthiness_score'] - return score < threshold -""" - -# TLM Binary Classification -""" -from typing import Optional - -from cleanlab_studio import Studio - -studio = Studio("") -tlm = studio.TLM() - -def tlm_binary_validator(response: str, tlm: TLM, query: Optional[str] = None) -> bool: - if query is None: - prompt = f"Here is a response from an AI assistant: {response}\n\n Is it helpful? Answer Yes/No only." - else: - prompt = f"Here is a response from an AI assistant: {response}\n\n Considering the following query: {query}\n\n Is the response helpful? Answer Yes/No only." - output = tlm.prompt(prompt) - return output["response"].lower() == "no" -""" +# """ +# This module provides validation functions for checking if an LLM response is unhelpful. +# """ + +# from typing import Optional, TYPE_CHECKING + +# if TYPE_CHECKING: +# try: +# from cleanlab_studio.studio.trustworthy_language_model import TLM # noqa: F401 +# except ImportError: +# raise ImportError("The 'cleanlab_studio' library is required to run this validator. Please install it with `pip install cleanlab-studio`.") + + +# def is_bad_response(response: str, fallback_answer: str, threshold: int = 70) -> bool: +# """Use partial ratio to match a fallback_answer to the response, indicating how unhelpful the response is. +# If the partial ratio is greater than or equal to the threshold, return True. +# """ +# try: +# from thefuzz import fuzz +# except ImportError: +# raise ImportError("The 'thefuzz' library is required to run this validator. Please install it with `pip install thefuzz`.") + +# partial_ratio = fuzz.partial_ratio(fallback_answer.lower(), response.lower()) +# return partial_ratio >= threshold + +# def is_bad_response_contains_phrase(response: str, fallback_responses: list[str]) -> bool: +# """Check whether the response matches a fallback phrase, indicating the response is not helpful. + +# Args: +# response: The response from the assistant +# fallback_responses: A list of fallback phrases to check against the response. 
+ +# Returns: +# bool: True if the response appears to be a fallback/inadequate response +# """ +# return any( +# phrase.lower() in response.lower() +# for phrase in fallback_responses +# ) + +# def is_bad_response_untrustworthy(response: str, context: str, query: str, tlm: TLM, threshold: float = 0.5) -> bool: +# """ +# Check whether the response is untrustworthy based on the TLM score. + +# Args: +# response: The response from the assistant +# context: The context of the query +# query: The user query +# tlm: The TLM model to use +# threshold: The threshold for the TLM score. If the score is less than this threshold, the response is considered untrustworthy. + +# Returns: +# bool: True if the response is untrustworthy +# """ +# prompt = f"Context: {context}\n\n Query: {query}\n\n Query: {query}" +# resp = tlm.get_trustworthiness_score(prompt, response) +# score = resp['trustworthiness_score'] +# return score < threshold + +# # TLM Binary Classification +# def is_bad_response_unhelpful(response: str, tlm: TLM, query: Optional[str] = None, trustworthiness_score_threshold: Optional[float] = None) -> bool: +# """ +# Check whether the response is unhelpful based on the TLM score. A query may optionally be provided to help the TLM determine if the response is helpful to answer the given query. + +# Args: +# response: The response from the assistant +# tlm: The TLM model to use +# query: The user query +# trustworthiness_score_threshold: The threshold for the TLM score. If the score is less than this threshold, the response is considered unhelpful. + +# Returns: +# bool: True if the response is unhelpful +# """ +# if query is None: +# prompt = ( +# "Consider the following AI Assistant Response.\n\n" +# f"AI Assistant Response: {response}\n\n" +# "Is the AI Assistant Response helpful? Answer Yes/No only." +# ) +# else: +# prompt = ( +# "Consider the following User Query and AI Assistant Response.\n\n" +# f"User Query: {query}\n\n" +# f"AI Assistant Response: {response}\n\n" +# "Is the AI Assistant Response helpful? Answer Yes/No only." +# ) +# output = tlm.prompt(prompt, constrain_outputs=["Yes", "No"]) +# response_marked_unhelpful = output["response"].lower() == "no" +# # TODO: Decide if we should keep the trustworthiness score threshold. +# is_trustworthy = trustworthiness_score_threshold is None or (output["trustworthiness_score"] > trustworthiness_score_threshold) +# return response_marked_unhelpful and is_trustworthy diff --git a/src/cleanlab_codex/validation.py b/src/cleanlab_codex/validation.py new file mode 100644 index 0000000..7fee73a --- /dev/null +++ b/src/cleanlab_codex/validation.py @@ -0,0 +1,124 @@ +""" +This module provides validation functions for checking if an LLM response is unhelpful. +""" +from __future__ import annotations + +from typing import Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from cleanlab_studio.studio.trustworthy_language_model import TLM + + +def is_bad_response(response: str, fallback_answer: str, threshold: int = 70) -> bool: + """Check if a response is too similar to a known fallback/unhelpful answer. + + Uses fuzzy string matching to compare the response against a known fallback answer. + Returns True if the response is similar enough to be considered unhelpful. + + Args: + response: The response to check + fallback_answer: A known unhelpful/fallback response to compare against + threshold: Similarity threshold (0-100). Higher values require more similarity. + Default 70 means responses that are 70% or more similar are considered bad. 
+ + Returns: + bool: True if the response is too similar to the fallback answer, False otherwise + """ + try: + from thefuzz import fuzz + except ImportError: + raise ImportError("The 'thefuzz' library is required to run this validator. Please install it with `pip install thefuzz`.") + + partial_ratio = fuzz.partial_ratio(fallback_answer.lower(), response.lower()) + return partial_ratio >= threshold + +def is_bad_response_contains_phrase(response: str, fallback_responses: list[str]) -> bool: + """Check if a response is unhelpful by looking for known fallback phrases. + + Uses simple substring matching to check if the response contains any known fallback phrases + that indicate the response is unhelpful (e.g. "I cannot help with that", "I don't know"). + Returns True if any fallback phrase is found in the response. + + Args: + response: The response to check from the assistant + fallback_responses: List of known fallback phrases that indicate an unhelpful response. + The check is case-insensitive. + + Returns: + bool: True if the response contains any fallback phrase, False otherwise + """ + return any(phrase.lower() in response.lower() for phrase in fallback_responses) + +def is_bad_response_untrustworthy( + response: str, + context: str, + query: str, + tlm: TLM, + threshold: float = 0.6, + # TODO: Optimize prompt template + prompt_template: str = "Using the following Context, provide a helpful answer to the Query.\n\n Context:\n{context}\n\n Query: {query}", +) -> bool: + """Check if a response is untrustworthy based on TLM's evaluation. + + Uses TLM to evaluate whether a response is trustworthy given the context and query. + Returns True if TLM's trustworthiness score falls below the threshold, indicating + the response may be incorrect or unreliable. + + Args: + response: The response to check from the assistant + context: The context information available for answering the query + query: The user's question or request + tlm: The TLM model to use for evaluation + threshold: Score threshold (0.0-1.0). Lower values allow less trustworthy responses. + Default 0.6, meaning responses with scores less than 0.6 are considered untrustworthy. + prompt_template: Template for formatting the evaluation prompt. Must contain {context} + and {query} placeholders. + + Returns: + bool: True if the response is deemed untrustworthy by TLM, False otherwise + """ + prompt = prompt_template.format(context=context, query=query) + resp = tlm.get_trustworthiness_score(prompt, response) + score = resp['trustworthiness_score'] + return score < threshold + +# TLM Binary Classification +def is_bad_response_unhelpful(response: str, tlm: TLM, query: Optional[str] = None, trustworthiness_score_threshold: Optional[float] = None) -> bool: + """Check if a response is unhelpful by asking TLM to evaluate it. + + Uses TLM to evaluate whether a response is helpful by asking it to make a Yes/No judgment. + The evaluation considers both the TLM's binary classification of helpfulness and its + confidence score. Returns True only if TLM classifies the response as unhelpful AND + is sufficiently confident in that assessment (if a threshold is provided). + + Args: + response: The response to check from the assistant + tlm: The TLM model to use for evaluation + query: Optional user query to provide context for evaluating helpfulness. + If provided, TLM will assess if the response helpfully answers this query. + trustworthiness_score_threshold: Optional confidence threshold (0.0-1.0). 
+ If provided, responses are only marked unhelpful if TLM's + confidence score exceeds this threshold. + + Returns: + bool: True if TLM determines the response is unhelpful with sufficient confidence, + False otherwise + """ + if query is None: + prompt = ( + "Consider the following AI Assistant Response.\n\n" + f"AI Assistant Response: {response}\n\n" + "Is the AI Assistant Response helpful? Remember that abstaining from responding is not helpful. Answer Yes/No only." + ) + else: + prompt = ( + "Consider the following User Query and AI Assistant Response.\n\n" + f"User Query: {query}\n\n" + f"AI Assistant Response: {response}\n\n" + "Is the AI Assistant Response helpful? Remember that abstaining from responding is not helpful. Answer Yes/No only." + ) + output = tlm.prompt(prompt, constrain_outputs=["Yes", "No"]) + response_marked_unhelpful = output["response"].lower() == "no" + # TODO: Decide if we should keep the trustworthiness score threshold. + is_trustworthy = trustworthiness_score_threshold is None or (output["trustworthiness_score"] > trustworthiness_score_threshold) + return response_marked_unhelpful and is_trustworthy From d20d31c830eb6b6a1c1cd11a6d141c7a6d70ff72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Thu, 30 Jan 2025 02:18:27 +0000 Subject: [PATCH 04/39] update import --- src/cleanlab_codex/codex_backup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cleanlab_codex/codex_backup.py b/src/cleanlab_codex/codex_backup.py index 8eba18e..13b53af 100644 --- a/src/cleanlab_codex/codex_backup.py +++ b/src/cleanlab_codex/codex_backup.py @@ -4,7 +4,7 @@ from typing import Any, Callable, Optional from cleanlab_codex.codex import Codex -from cleanlab_codex.utils.response_validators import is_bad_response +from cleanlab_codex.validation import is_bad_response def handle_backup_default(backup_response: str, decorated_instance: Any) -> None: # noqa: ARG001 From a223b04893931632f6b541874e3092d8c5259b5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Thu, 30 Jan 2025 02:50:37 +0000 Subject: [PATCH 05/39] formatting and typing (wip) --- src/cleanlab_codex/validation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cleanlab_codex/validation.py b/src/cleanlab_codex/validation.py index 7fee73a..f41a730 100644 --- a/src/cleanlab_codex/validation.py +++ b/src/cleanlab_codex/validation.py @@ -27,7 +27,7 @@ def is_bad_response(response: str, fallback_answer: str, threshold: int = 70) -> try: from thefuzz import fuzz except ImportError: - raise ImportError("The 'thefuzz' library is required to run this validator. Please install it with `pip install thefuzz`.") + raise ImportError("The 'thefuzz' library is required. 
Please install it with `pip install thefuzz`.") partial_ratio = fuzz.partial_ratio(fallback_answer.lower(), response.lower()) return partial_ratio >= threshold @@ -79,7 +79,7 @@ def is_bad_response_untrustworthy( """ prompt = prompt_template.format(context=context, query=query) resp = tlm.get_trustworthiness_score(prompt, response) - score = resp['trustworthiness_score'] + score: float = resp['trustworthiness_score'] return score < threshold # TLM Binary Classification From f15424c1ca3175b5c30267f0882cd7625c8301c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Thu, 30 Jan 2025 02:51:42 +0000 Subject: [PATCH 06/39] Remove response_validators.py module --- .../utils/response_validators.py | 91 ------------------- 1 file changed, 91 deletions(-) delete mode 100644 src/cleanlab_codex/utils/response_validators.py diff --git a/src/cleanlab_codex/utils/response_validators.py b/src/cleanlab_codex/utils/response_validators.py deleted file mode 100644 index 592db01..0000000 --- a/src/cleanlab_codex/utils/response_validators.py +++ /dev/null @@ -1,91 +0,0 @@ -# """ -# This module provides validation functions for checking if an LLM response is unhelpful. -# """ - -# from typing import Optional, TYPE_CHECKING - -# if TYPE_CHECKING: -# try: -# from cleanlab_studio.studio.trustworthy_language_model import TLM # noqa: F401 -# except ImportError: -# raise ImportError("The 'cleanlab_studio' library is required to run this validator. Please install it with `pip install cleanlab-studio`.") - - -# def is_bad_response(response: str, fallback_answer: str, threshold: int = 70) -> bool: -# """Use partial ratio to match a fallback_answer to the response, indicating how unhelpful the response is. -# If the partial ratio is greater than or equal to the threshold, return True. -# """ -# try: -# from thefuzz import fuzz -# except ImportError: -# raise ImportError("The 'thefuzz' library is required to run this validator. Please install it with `pip install thefuzz`.") - -# partial_ratio = fuzz.partial_ratio(fallback_answer.lower(), response.lower()) -# return partial_ratio >= threshold - -# def is_bad_response_contains_phrase(response: str, fallback_responses: list[str]) -> bool: -# """Check whether the response matches a fallback phrase, indicating the response is not helpful. - -# Args: -# response: The response from the assistant -# fallback_responses: A list of fallback phrases to check against the response. - -# Returns: -# bool: True if the response appears to be a fallback/inadequate response -# """ -# return any( -# phrase.lower() in response.lower() -# for phrase in fallback_responses -# ) - -# def is_bad_response_untrustworthy(response: str, context: str, query: str, tlm: TLM, threshold: float = 0.5) -> bool: -# """ -# Check whether the response is untrustworthy based on the TLM score. - -# Args: -# response: The response from the assistant -# context: The context of the query -# query: The user query -# tlm: The TLM model to use -# threshold: The threshold for the TLM score. If the score is less than this threshold, the response is considered untrustworthy. 
- -# Returns: -# bool: True if the response is untrustworthy -# """ -# prompt = f"Context: {context}\n\n Query: {query}\n\n Query: {query}" -# resp = tlm.get_trustworthiness_score(prompt, response) -# score = resp['trustworthiness_score'] -# return score < threshold - -# # TLM Binary Classification -# def is_bad_response_unhelpful(response: str, tlm: TLM, query: Optional[str] = None, trustworthiness_score_threshold: Optional[float] = None) -> bool: -# """ -# Check whether the response is unhelpful based on the TLM score. A query may optionally be provided to help the TLM determine if the response is helpful to answer the given query. - -# Args: -# response: The response from the assistant -# tlm: The TLM model to use -# query: The user query -# trustworthiness_score_threshold: The threshold for the TLM score. If the score is less than this threshold, the response is considered unhelpful. - -# Returns: -# bool: True if the response is unhelpful -# """ -# if query is None: -# prompt = ( -# "Consider the following AI Assistant Response.\n\n" -# f"AI Assistant Response: {response}\n\n" -# "Is the AI Assistant Response helpful? Answer Yes/No only." -# ) -# else: -# prompt = ( -# "Consider the following User Query and AI Assistant Response.\n\n" -# f"User Query: {query}\n\n" -# f"AI Assistant Response: {response}\n\n" -# "Is the AI Assistant Response helpful? Answer Yes/No only." -# ) -# output = tlm.prompt(prompt, constrain_outputs=["Yes", "No"]) -# response_marked_unhelpful = output["response"].lower() == "no" -# # TODO: Decide if we should keep the trustworthiness score threshold. -# is_trustworthy = trustworthiness_score_threshold is None or (output["trustworthiness_score"] > trustworthiness_score_threshold) -# return response_marked_unhelpful and is_trustworthy From 23eeb580b47897295b87fba910600e6f1a38174b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Thu, 30 Jan 2025 03:02:09 +0000 Subject: [PATCH 07/39] remove is_bad_response_contains_phrase --- src/cleanlab_codex/validation.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/src/cleanlab_codex/validation.py b/src/cleanlab_codex/validation.py index f41a730..f52377a 100644 --- a/src/cleanlab_codex/validation.py +++ b/src/cleanlab_codex/validation.py @@ -32,23 +32,6 @@ def is_bad_response(response: str, fallback_answer: str, threshold: int = 70) -> partial_ratio = fuzz.partial_ratio(fallback_answer.lower(), response.lower()) return partial_ratio >= threshold -def is_bad_response_contains_phrase(response: str, fallback_responses: list[str]) -> bool: - """Check if a response is unhelpful by looking for known fallback phrases. - - Uses simple substring matching to check if the response contains any known fallback phrases - that indicate the response is unhelpful (e.g. "I cannot help with that", "I don't know"). - Returns True if any fallback phrase is found in the response. - - Args: - response: The response to check from the assistant - fallback_responses: List of known fallback phrases that indicate an unhelpful response. - The check is case-insensitive. 
- - Returns: - bool: True if the response contains any fallback phrase, False otherwise - """ - return any(phrase.lower() in response.lower() for phrase in fallback_responses) - def is_bad_response_untrustworthy( response: str, context: str, From 9e8690c027f98a75762f2e83999f89bb65d2e2d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Fri, 31 Jan 2025 19:20:07 +0000 Subject: [PATCH 08/39] Improve helpfer functions for detecting bad responses --- src/cleanlab_codex/codex_backup.py | 3 +- src/cleanlab_codex/utils/__init__.py | 3 +- src/cleanlab_codex/utils/prompt.py | 20 +++++ src/cleanlab_codex/validation.py | 119 +++++++++++++++++++++------ 4 files changed, 116 insertions(+), 29 deletions(-) create mode 100644 src/cleanlab_codex/utils/prompt.py diff --git a/src/cleanlab_codex/codex_backup.py b/src/cleanlab_codex/codex_backup.py index 13b53af..94194ea 100644 --- a/src/cleanlab_codex/codex_backup.py +++ b/src/cleanlab_codex/codex_backup.py @@ -91,7 +91,8 @@ def wrapper(decorated_instance, user_message): assistant_response = chat_method(decorated_instance, user_message) # Return original response if it's adequate - if not is_bad_response(assistant_response): + # TODO: Update usage of is_bad_response + if not is_bad_response(assistant_response, self._fallback_answer): return assistant_response # Query Codex for a backup response diff --git a/src/cleanlab_codex/utils/__init__.py b/src/cleanlab_codex/utils/__init__.py index da217b3..3fb7e09 100644 --- a/src/cleanlab_codex/utils/__init__.py +++ b/src/cleanlab_codex/utils/__init__.py @@ -2,5 +2,6 @@ from cleanlab_codex.utils.openai import FunctionParameters as OpenAIFunctionParameters from cleanlab_codex.utils.openai import Tool as OpenAITool from cleanlab_codex.utils.openai import format_as_openai_tool +from cleanlab_codex.utils.prompt import default_format_prompt -__all__ = ["OpenAIFunction", "OpenAIFunctionParameters", "OpenAITool", "format_as_openai_tool"] +__all__ = ["OpenAIFunction", "OpenAIFunctionParameters", "OpenAITool", "format_as_openai_tool", "default_format_prompt"] \ No newline at end of file diff --git a/src/cleanlab_codex/utils/prompt.py b/src/cleanlab_codex/utils/prompt.py new file mode 100644 index 0000000..3fb64d4 --- /dev/null +++ b/src/cleanlab_codex/utils/prompt.py @@ -0,0 +1,20 @@ +""" +Utility functions for RAG (Retrieval Augmented Generation) operations. +""" + +def default_format_prompt(query: str, context: str) -> str: + """Default function for formatting RAG prompts. 
+ + Args: + query: The user's question + context: The context/documents to use for answering + + Returns: + str: A formatted prompt combining the query and context + """ + template = ( + "Using only information from the following Context, answer the following Query.\n\n" + "Context:\n{context}\n\n" + "Query: {query}" + ) + return template.format(context=context, query=query) \ No newline at end of file diff --git a/src/cleanlab_codex/validation.py b/src/cleanlab_codex/validation.py index f52377a..208b62c 100644 --- a/src/cleanlab_codex/validation.py +++ b/src/cleanlab_codex/validation.py @@ -3,14 +3,63 @@ """ from __future__ import annotations -from typing import Optional, TYPE_CHECKING +from typing import Callable, Optional, TYPE_CHECKING + +from cleanlab_codex.utils.prompt import default_format_prompt if TYPE_CHECKING: from cleanlab_studio.studio.trustworthy_language_model import TLM -def is_bad_response(response: str, fallback_answer: str, threshold: int = 70) -> bool: - """Check if a response is too similar to a known fallback/unhelpful answer. +DEFAULT_FALLBACK_ANSWER = "Based on the available information, I cannot provide a complete answer to this question." +DEFAULT_PARTIAL_RATIO_THRESHOLD = 70 +DEFAULT_TRUSTWORTHINESS_THRESHOLD = 0.5 + +def is_bad_response( + response: str, + context: str, + tlm: TLM, # TODO: Make this optional + query: Optional[str] = None, + # is_fallback_response args + fallback_answer: str = DEFAULT_FALLBACK_ANSWER, + partial_ratio_threshold: int = DEFAULT_PARTIAL_RATIO_THRESHOLD, + # is_untrustworthy_response args + trustworthiness_threshold: float = DEFAULT_TRUSTWORTHINESS_THRESHOLD, + # is_unhelpful_response args + unhelpful_trustworthiness_threshold: Optional[float] = None, +) -> bool: + """Run a series of checks to determine if a response is bad. If any of the checks pass, return True. + + Checks: + - Is the response too similar to a known fallback answer? + - Is the response untrustworthy? + - Is the response unhelpful? + + Args: + response: The response to check. See `is_fallback_response`, `is_untrustworthy_response`, and `is_unhelpful_response`. + context: The context/documents to use for answering. See `is_untrustworthy_response`. + tlm: The TLM model to use for evaluation. See `is_untrustworthy_response` and `is_unhelpful_response`. + query: The user's question (optional). See `is_untrustworthy_response` and `is_unhelpful_response`. + fallback_answer: The fallback answer to compare against. See `is_fallback_response`. + partial_ratio_threshold: The threshold for detecting fallback responses. See `is_fallback_response`. + trustworthiness_threshold: The threshold for detecting untrustworthy responses. See `is_untrustworthy_response`. + unhelpful_trustworthiness_threshold: The threshold for detecting unhelpful responses. See `is_unhelpful_response`. + """ + validation_checks = [ + lambda: is_fallback_response(response, fallback_answer, threshold=partial_ratio_threshold), + lambda: ( + is_untrustworthy_response(response, context, query, tlm, threshold=trustworthiness_threshold) + if query is not None + else False + ), + lambda: is_unhelpful_response(response, tlm, query, trustworthiness_score_threshold=unhelpful_trustworthiness_threshold) + ] + + return any(check() for check in validation_checks) + + +def is_fallback_response(response: str, fallback_answer: str = DEFAULT_FALLBACK_ANSWER, threshold: int=DEFAULT_PARTIAL_RATIO_THRESHOLD) -> bool: + """Check if a response is too similar to a known fallback answer. 
Uses fuzzy string matching to compare the response against a known fallback answer. Returns True if the response is similar enough to be considered unhelpful. @@ -32,14 +81,13 @@ def is_bad_response(response: str, fallback_answer: str, threshold: int = 70) -> partial_ratio = fuzz.partial_ratio(fallback_answer.lower(), response.lower()) return partial_ratio >= threshold -def is_bad_response_untrustworthy( +def is_untrustworthy_response( response: str, context: str, query: str, tlm: TLM, - threshold: float = 0.6, - # TODO: Optimize prompt template - prompt_template: str = "Using the following Context, provide a helpful answer to the Query.\n\n Context:\n{context}\n\n Query: {query}", + threshold: float = DEFAULT_TRUSTWORTHINESS_THRESHOLD, + format_prompt: Callable[[str, str], str] = default_format_prompt ) -> bool: """Check if a response is untrustworthy based on TLM's evaluation. @@ -54,19 +102,25 @@ def is_bad_response_untrustworthy( tlm: The TLM model to use for evaluation threshold: Score threshold (0.0-1.0). Lower values allow less trustworthy responses. Default 0.6, meaning responses with scores less than 0.6 are considered untrustworthy. - prompt_template: Template for formatting the evaluation prompt. Must contain {context} - and {query} placeholders. + format_prompt: Function that takes (query, context) and returns a formatted prompt string. + Users should provide their RAG app's own prompt formatting function here + to match how their LLM is prompted. Returns: bool: True if the response is deemed untrustworthy by TLM, False otherwise """ - prompt = prompt_template.format(context=context, query=query) + try: + from cleanlab_studio.studio.trustworthy_language_model import TLM # noqa: F401 + except ImportError: + raise ImportError("The 'cleanlab_studio' library is required. Please install it with `pip install cleanlab-studio`.") + + prompt = format_prompt(query, context) resp = tlm.get_trustworthiness_score(prompt, response) score: float = resp['trustworthiness_score'] return score < threshold -# TLM Binary Classification -def is_bad_response_unhelpful(response: str, tlm: TLM, query: Optional[str] = None, trustworthiness_score_threshold: Optional[float] = None) -> bool: + +def is_unhelpful_response(response: str, tlm: TLM, query: Optional[str] = None, trustworthiness_score_threshold: Optional[float] = None) -> bool: """Check if a response is unhelpful by asking TLM to evaluate it. Uses TLM to evaluate whether a response is helpful by asking it to make a Yes/No judgment. @@ -87,21 +141,32 @@ def is_bad_response_unhelpful(response: str, tlm: TLM, query: Optional[str] = No bool: True if TLM determines the response is unhelpful with sufficient confidence, False otherwise """ - if query is None: - prompt = ( - "Consider the following AI Assistant Response.\n\n" - f"AI Assistant Response: {response}\n\n" - "Is the AI Assistant Response helpful? Remember that abstaining from responding is not helpful. Answer Yes/No only." - ) - else: - prompt = ( - "Consider the following User Query and AI Assistant Response.\n\n" - f"User Query: {query}\n\n" - f"AI Assistant Response: {response}\n\n" - "Is the AI Assistant Response helpful? Remember that abstaining from responding is not helpful. Answer Yes/No only." - ) + try: + from cleanlab_studio.studio.trustworthy_language_model import TLM # noqa: F401 + except ImportError: + raise ImportError("The 'cleanlab_studio' library is required. 
Please install it with `pip install cleanlab-studio`.") + + # The question and expected "unhelpful" response are linked: + # - When asking "is helpful?" -> "no" means unhelpful + # - When asking "is unhelpful?" -> "yes" means unhelpful + question = ( + "Is the AI Assistant Response unhelpful? " + "Unhelpful responses include answers that:\n" + "- Are not useful, incomplete, incorrect, uncertain or unclear.\n" + "- Abstain or refuse to answer the question\n" + "- Leave the original question unresolved\n" + "- Are irrelevant to the question\n" + "Answer Yes/No only." + ) + expected_unhelpful_response = "yes" + + prompt = ( + "Consider the following" + + (f" User Query and AI Assistant Response.\n\nUser Query: {query}\n\n" if query else " AI Assistant Response.\n\n") + + f"AI Assistant Response: {response}\n\n{question}" + ) + output = tlm.prompt(prompt, constrain_outputs=["Yes", "No"]) - response_marked_unhelpful = output["response"].lower() == "no" - # TODO: Decide if we should keep the trustworthiness score threshold. + response_marked_unhelpful = output["response"].lower() == expected_unhelpful_response is_trustworthy = trustworthiness_score_threshold is None or (output["trustworthiness_score"] > trustworthiness_score_threshold) return response_marked_unhelpful and is_trustworthy From d0ad8df1d05549181a73fc955bcbb9e41976f7af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Fri, 31 Jan 2025 19:56:47 +0000 Subject: [PATCH 09/39] formatting and add dependencies --- pyproject.toml | 2 ++ src/cleanlab_codex/utils/__init__.py | 2 +- src/cleanlab_codex/utils/prompt.py | 2 +- src/cleanlab_codex/validation.py | 18 +++++++++--------- 4 files changed, 13 insertions(+), 11 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b036139..5db58d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,8 @@ extra-dependencies = [ "pytest", "llama-index-core", "smolagents", + "cleanlab-studio", + "thefuzz", ] [tool.hatch.envs.types.scripts] check = "mypy --install-types --non-interactive {args:src/cleanlab_codex tests}" diff --git a/src/cleanlab_codex/utils/__init__.py b/src/cleanlab_codex/utils/__init__.py index 3fb7e09..442e4ee 100644 --- a/src/cleanlab_codex/utils/__init__.py +++ b/src/cleanlab_codex/utils/__init__.py @@ -4,4 +4,4 @@ from cleanlab_codex.utils.openai import format_as_openai_tool from cleanlab_codex.utils.prompt import default_format_prompt -__all__ = ["OpenAIFunction", "OpenAIFunctionParameters", "OpenAITool", "format_as_openai_tool", "default_format_prompt"] \ No newline at end of file +__all__ = ["OpenAIFunction", "OpenAIFunctionParameters", "OpenAITool", "format_as_openai_tool", "default_format_prompt"] diff --git a/src/cleanlab_codex/utils/prompt.py b/src/cleanlab_codex/utils/prompt.py index 3fb64d4..8ea7e32 100644 --- a/src/cleanlab_codex/utils/prompt.py +++ b/src/cleanlab_codex/utils/prompt.py @@ -17,4 +17,4 @@ def default_format_prompt(query: str, context: str) -> str: "Context:\n{context}\n\n" "Query: {query}" ) - return template.format(context=context, query=query) \ No newline at end of file + return template.format(context=context, query=query) diff --git a/src/cleanlab_codex/validation.py b/src/cleanlab_codex/validation.py index 208b62c..da843cf 100644 --- a/src/cleanlab_codex/validation.py +++ b/src/cleanlab_codex/validation.py @@ -3,12 +3,12 @@ """ from __future__ import annotations -from typing import Callable, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Callable, Optional from cleanlab_codex.utils.prompt import 
default_format_prompt if TYPE_CHECKING: - from cleanlab_studio.studio.trustworthy_language_model import TLM + from cleanlab_studio.studio.trustworthy_language_model import TLM # type: ignore DEFAULT_FALLBACK_ANSWER = "Based on the available information, I cannot provide a complete answer to this question." @@ -54,7 +54,7 @@ def is_bad_response( ), lambda: is_unhelpful_response(response, tlm, query, trustworthiness_score_threshold=unhelpful_trustworthiness_threshold) ] - + return any(check() for check in validation_checks) @@ -74,7 +74,7 @@ def is_fallback_response(response: str, fallback_answer: str = DEFAULT_FALLBACK_ bool: True if the response is too similar to the fallback answer, False otherwise """ try: - from thefuzz import fuzz + from thefuzz import fuzz # type: ignore except ImportError: raise ImportError("The 'thefuzz' library is required. Please install it with `pip install thefuzz`.") @@ -113,7 +113,7 @@ def is_untrustworthy_response( from cleanlab_studio.studio.trustworthy_language_model import TLM # noqa: F401 except ImportError: raise ImportError("The 'cleanlab_studio' library is required. Please install it with `pip install cleanlab-studio`.") - + prompt = format_prompt(query, context) resp = tlm.get_trustworthiness_score(prompt, response) score: float = resp['trustworthiness_score'] @@ -145,7 +145,7 @@ def is_unhelpful_response(response: str, tlm: TLM, query: Optional[str] = None, from cleanlab_studio.studio.trustworthy_language_model import TLM # noqa: F401 except ImportError: raise ImportError("The 'cleanlab_studio' library is required. Please install it with `pip install cleanlab-studio`.") - + # The question and expected "unhelpful" response are linked: # - When asking "is helpful?" -> "no" means unhelpful # - When asking "is unhelpful?" -> "yes" means unhelpful @@ -159,13 +159,13 @@ def is_unhelpful_response(response: str, tlm: TLM, query: Optional[str] = None, "Answer Yes/No only." ) expected_unhelpful_response = "yes" - + prompt = ( - "Consider the following" + + "Consider the following" + (f" User Query and AI Assistant Response.\n\nUser Query: {query}\n\n" if query else " AI Assistant Response.\n\n") + f"AI Assistant Response: {response}\n\n{question}" ) - + output = tlm.prompt(prompt, constrain_outputs=["Yes", "No"]) response_marked_unhelpful = output["response"].lower() == expected_unhelpful_response is_trustworthy = trustworthiness_score_threshold is None or (output["trustworthiness_score"] > trustworthiness_score_threshold) From b4bff54a61798af67e6cb4b3ee4b05436c12e1dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Sat, 1 Feb 2025 00:58:37 +0000 Subject: [PATCH 10/39] formatting --- src/cleanlab_codex/utils/prompt.py | 5 +-- src/cleanlab_codex/validation.py | 52 ++++++++++++++++++++---------- 2 files changed, 38 insertions(+), 19 deletions(-) diff --git a/src/cleanlab_codex/utils/prompt.py b/src/cleanlab_codex/utils/prompt.py index 8ea7e32..c04fc71 100644 --- a/src/cleanlab_codex/utils/prompt.py +++ b/src/cleanlab_codex/utils/prompt.py @@ -2,13 +2,14 @@ Utility functions for RAG (Retrieval Augmented Generation) operations. """ + def default_format_prompt(query: str, context: str) -> str: """Default function for formatting RAG prompts. 
- + Args: query: The user's question context: The context/documents to use for answering - + Returns: str: A formatted prompt combining the query and context """ diff --git a/src/cleanlab_codex/validation.py b/src/cleanlab_codex/validation.py index da843cf..76a4c9f 100644 --- a/src/cleanlab_codex/validation.py +++ b/src/cleanlab_codex/validation.py @@ -1,6 +1,7 @@ """ This module provides validation functions for checking if an LLM response is unhelpful. """ + from __future__ import annotations from typing import TYPE_CHECKING, Callable, Optional @@ -15,10 +16,11 @@ DEFAULT_PARTIAL_RATIO_THRESHOLD = 70 DEFAULT_TRUSTWORTHINESS_THRESHOLD = 0.5 + def is_bad_response( response: str, context: str, - tlm: TLM, # TODO: Make this optional + tlm: TLM, # TODO: Make this optional query: Optional[str] = None, # is_fallback_response args fallback_answer: str = DEFAULT_FALLBACK_ANSWER, @@ -34,7 +36,7 @@ def is_bad_response( - Is the response too similar to a known fallback answer? - Is the response untrustworthy? - Is the response unhelpful? - + Args: response: The response to check. See `is_fallback_response`, `is_untrustworthy_response`, and `is_unhelpful_response`. context: The context/documents to use for answering. See `is_untrustworthy_response`. @@ -52,13 +54,17 @@ def is_bad_response( if query is not None else False ), - lambda: is_unhelpful_response(response, tlm, query, trustworthiness_score_threshold=unhelpful_trustworthiness_threshold) + lambda: is_unhelpful_response( + response, tlm, query, trustworthiness_score_threshold=unhelpful_trustworthiness_threshold + ), ] return any(check() for check in validation_checks) -def is_fallback_response(response: str, fallback_answer: str = DEFAULT_FALLBACK_ANSWER, threshold: int=DEFAULT_PARTIAL_RATIO_THRESHOLD) -> bool: +def is_fallback_response( + response: str, fallback_answer: str = DEFAULT_FALLBACK_ANSWER, threshold: int = DEFAULT_PARTIAL_RATIO_THRESHOLD +) -> bool: """Check if a response is too similar to a known fallback answer. Uses fuzzy string matching to compare the response against a known fallback answer. @@ -75,19 +81,21 @@ def is_fallback_response(response: str, fallback_answer: str = DEFAULT_FALLBACK_ """ try: from thefuzz import fuzz # type: ignore - except ImportError: - raise ImportError("The 'thefuzz' library is required. Please install it with `pip install thefuzz`.") + except ImportError as e: + error_msg = "The 'thefuzz' library is required. Please install it with `pip install thefuzz`." + raise ImportError(error_msg) from e partial_ratio = fuzz.partial_ratio(fallback_answer.lower(), response.lower()) return partial_ratio >= threshold + def is_untrustworthy_response( response: str, context: str, query: str, tlm: TLM, threshold: float = DEFAULT_TRUSTWORTHINESS_THRESHOLD, - format_prompt: Callable[[str, str], str] = default_format_prompt + format_prompt: Callable[[str, str], str] = default_format_prompt, ) -> bool: """Check if a response is untrustworthy based on TLM's evaluation. @@ -111,16 +119,19 @@ def is_untrustworthy_response( """ try: from cleanlab_studio.studio.trustworthy_language_model import TLM # noqa: F401 - except ImportError: - raise ImportError("The 'cleanlab_studio' library is required. Please install it with `pip install cleanlab-studio`.") + except ImportError as e: + error_msg = "The 'cleanlab_studio' library is required. Please install it with `pip install cleanlab-studio`." 
+ raise ImportError(error_msg) from e prompt = format_prompt(query, context) resp = tlm.get_trustworthiness_score(prompt, response) - score: float = resp['trustworthiness_score'] + score: float = resp["trustworthiness_score"] return score < threshold -def is_unhelpful_response(response: str, tlm: TLM, query: Optional[str] = None, trustworthiness_score_threshold: Optional[float] = None) -> bool: +def is_unhelpful_response( + response: str, tlm: TLM, query: Optional[str] = None, trustworthiness_score_threshold: Optional[float] = None +) -> bool: """Check if a response is unhelpful by asking TLM to evaluate it. Uses TLM to evaluate whether a response is helpful by asking it to make a Yes/No judgment. @@ -143,8 +154,9 @@ def is_unhelpful_response(response: str, tlm: TLM, query: Optional[str] = None, """ try: from cleanlab_studio.studio.trustworthy_language_model import TLM # noqa: F401 - except ImportError: - raise ImportError("The 'cleanlab_studio' library is required. Please install it with `pip install cleanlab-studio`.") + except ImportError as e: + error_msg = "The 'cleanlab_studio' library is required. Please install it with `pip install cleanlab-studio`." + raise ImportError(error_msg) from e # The question and expected "unhelpful" response are linked: # - When asking "is helpful?" -> "no" means unhelpful @@ -161,12 +173,18 @@ def is_unhelpful_response(response: str, tlm: TLM, query: Optional[str] = None, expected_unhelpful_response = "yes" prompt = ( - "Consider the following" + - (f" User Query and AI Assistant Response.\n\nUser Query: {query}\n\n" if query else " AI Assistant Response.\n\n") + - f"AI Assistant Response: {response}\n\n{question}" + "Consider the following" + + ( + f" User Query and AI Assistant Response.\n\nUser Query: {query}\n\n" + if query + else " AI Assistant Response.\n\n" + ) + + f"AI Assistant Response: {response}\n\n{question}" ) output = tlm.prompt(prompt, constrain_outputs=["Yes", "No"]) response_marked_unhelpful = output["response"].lower() == expected_unhelpful_response - is_trustworthy = trustworthiness_score_threshold is None or (output["trustworthiness_score"] > trustworthiness_score_threshold) + is_trustworthy = trustworthiness_score_threshold is None or ( + output["trustworthiness_score"] > trustworthiness_score_threshold + ) return response_marked_unhelpful and is_trustworthy From 038a475a3cba55af054364b658cf389652f0465b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Sat, 1 Feb 2025 01:00:34 +0000 Subject: [PATCH 11/39] address type checker complaints --- src/cleanlab_codex/validation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cleanlab_codex/validation.py b/src/cleanlab_codex/validation.py index 76a4c9f..4b9f98e 100644 --- a/src/cleanlab_codex/validation.py +++ b/src/cleanlab_codex/validation.py @@ -85,8 +85,8 @@ def is_fallback_response( error_msg = "The 'thefuzz' library is required. Please install it with `pip install thefuzz`." 
raise ImportError(error_msg) from e - partial_ratio = fuzz.partial_ratio(fallback_answer.lower(), response.lower()) - return partial_ratio >= threshold + partial_ratio: int = fuzz.partial_ratio(fallback_answer.lower(), response.lower()) + return bool(partial_ratio >= threshold) def is_untrustworthy_response( From 5fbb48e16a4d54e1f3628b6afd55ad52bf28f7c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Sat, 1 Feb 2025 01:07:49 +0000 Subject: [PATCH 12/39] temporarily skip tests for codex_backup module --- tests/test_codex_backup.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_codex_backup.py b/tests/test_codex_backup.py index 59caea0..b0e61ce 100644 --- a/tests/test_codex_backup.py +++ b/tests/test_codex_backup.py @@ -2,6 +2,10 @@ from cleanlab_codex.codex_backup import CodexBackup +# TODO: Remove this skip once we update codex_backup.py +import pytest; pytest.skip(allow_module_level=True) + + MOCK_BACKUP_RESPONSE = "This is a test response" FALLBACK_MESSAGE = "Based on the available information, I cannot provide a complete answer to this question." TEST_MESSAGE = "Hello, world!" From 3892b521e87acb1e6b29b2c8a0571c48ba5789fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Sat, 1 Feb 2025 01:10:53 +0000 Subject: [PATCH 13/39] formatting --- tests/test_codex_backup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_codex_backup.py b/tests/test_codex_backup.py index b0e61ce..03cae20 100644 --- a/tests/test_codex_backup.py +++ b/tests/test_codex_backup.py @@ -1,9 +1,10 @@ +import pytest from unittest.mock import MagicMock from cleanlab_codex.codex_backup import CodexBackup # TODO: Remove this skip once we update codex_backup.py -import pytest; pytest.skip(allow_module_level=True) +pytest.skip(allow_module_level=True) MOCK_BACKUP_RESPONSE = "This is a test response" From 22253e977a7e92f6da1064550819ea752d5c3bb3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Wed, 5 Feb 2025 17:10:30 +0000 Subject: [PATCH 14/39] address comments --- pyproject.toml | 2 + src/cleanlab_codex/validation.py | 120 +++++++++++++++--------- tests/test_codex_backup.py | 3 +- tests/test_validation.py | 154 +++++++++++++++++++++++++++++++ 4 files changed, 236 insertions(+), 43 deletions(-) create mode 100644 tests/test_validation.py diff --git a/pyproject.toml b/pyproject.toml index 5db58d5..0342e15 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,8 @@ allow-direct-references = true extra-dependencies = [ "llama-index-core", "smolagents; python_version >= '3.10'", + "cleanlab-studio", + "thefuzz", ] [tool.hatch.envs.hatch-test.env-vars] diff --git a/src/cleanlab_codex/validation.py b/src/cleanlab_codex/validation.py index 4b9f98e..addaf96 100644 --- a/src/cleanlab_codex/validation.py +++ b/src/cleanlab_codex/validation.py @@ -19,48 +19,84 @@ def is_bad_response( response: str, - context: str, - tlm: TLM, # TODO: Make this optional + *, + context: Optional[str] = None, query: Optional[str] = None, + tlm: Optional[TLM] = None, # is_fallback_response args fallback_answer: str = DEFAULT_FALLBACK_ANSWER, partial_ratio_threshold: int = DEFAULT_PARTIAL_RATIO_THRESHOLD, # is_untrustworthy_response args trustworthiness_threshold: float = DEFAULT_TRUSTWORTHINESS_THRESHOLD, + format_prompt: Callable[[str, str], str] = default_format_prompt, # is_unhelpful_response args unhelpful_trustworthiness_threshold: Optional[float] = None, ) -> bool: """Run a series of checks to determine if a response is bad. 
If any of the checks flags the response, return True.
 
-    Checks:
-    - Is the response too similar to a known fallback answer?
-    - Is the response untrustworthy?
-    - Is the response unhelpful?
+    This function runs three possible validation checks:
+    1. Fallback check: Detects if the response is too similar to known fallback answers.
+    2. Untrustworthy check: Evaluates the response's trustworthiness given the context and query.
+    3. Unhelpful check: Evaluates whether the response is helpful for the given query.
 
     Args:
-        response: The response to check. See `is_fallback_response`, `is_untrustworthy_response`, and `is_unhelpful_response`.
-        context: The context/documents to use for answering. See `is_untrustworthy_response`.
-        tlm: The TLM model to use for evaluation. See `is_untrustworthy_response` and `is_unhelpful_response`.
-        query: The user's question (optional). See `is_untrustworthy_response` and `is_unhelpful_response`.
-        fallback_answer: The fallback answer to compare against. See `is_fallback_response`.
-        partial_ratio_threshold: The threshold for detecting fallback responses. See `is_fallback_response`.
-        trustworthiness_threshold: The threshold for detecting untrustworthy responses. See `is_untrustworthy_response`.
-        unhelpful_trustworthiness_threshold: The threshold for detecting unhelpful responses. See `is_unhelpful_response`.
+        response: The response to check.
+        context: Optional context/documents used for answering. Required for untrustworthy check.
+        query: Optional user question. Required for untrustworthy and unhelpful checks.
+        tlm: Optional TLM model for evaluation. Required for untrustworthy and unhelpful checks.
+
+        # Fallback check parameters
+        fallback_answer: Known unhelpful response to compare against.
+        partial_ratio_threshold: Similarity threshold (0-100). Higher values require more similarity.
+
+        # Untrustworthy check parameters
+        trustworthiness_threshold: Score threshold (0.0-1.0). Lower values allow less trustworthy responses.
+        format_prompt: Function to format (query, context) into a prompt string.
+
+        # Unhelpful check parameters
+        unhelpful_trustworthiness_threshold: Optional confidence threshold (0.0-1.0) for unhelpful classification.
+
+    Returns:
+        bool: True if any validation check fails, False if all pass.
""" - validation_checks = [ - lambda: is_fallback_response(response, fallback_answer, threshold=partial_ratio_threshold), - lambda: ( - is_untrustworthy_response(response, context, query, tlm, threshold=trustworthiness_threshold) - if query is not None - else False - ), - lambda: is_unhelpful_response( - response, tlm, query, trustworthiness_score_threshold=unhelpful_trustworthiness_threshold - ), - ] - return any(check() for check in validation_checks) + validation_checks = [] + + # All required inputs are available for checking fallback responses + validation_checks.append( + lambda: is_fallback_response(response, fallback_answer, threshold=partial_ratio_threshold) + ) + + can_run_untrustworthy_check = all(x is not None for x in (query, context, tlm)) + if can_run_untrustworthy_check: + assert tlm is not None + assert query is not None + assert context is not None + validation_checks.append( + lambda: is_untrustworthy_response( + response=response, + context=context, + query=query, + tlm=tlm, + threshold=trustworthiness_threshold, + format_prompt=format_prompt, + ) + ) + + can_run_unhelpful_check = query is not None and tlm is not None + if can_run_unhelpful_check: + assert tlm is not None + assert query is not None + validation_checks.append( + lambda: is_unhelpful_response( + response=response, + tlm=tlm, + query=query, + trustworthiness_score_threshold=unhelpful_trustworthiness_threshold, + ) + ) + return any(check() for check in validation_checks) def is_fallback_response( response: str, fallback_answer: str = DEFAULT_FALLBACK_ANSWER, threshold: int = DEFAULT_PARTIAL_RATIO_THRESHOLD @@ -71,8 +107,8 @@ def is_fallback_response( Returns True if the response is similar enough to be considered unhelpful. Args: - response: The response to check - fallback_answer: A known unhelpful/fallback response to compare against + response: The response to check. + fallback_answer: A known unhelpful/fallback response to compare against. threshold: Similarity threshold (0-100). Higher values require more similarity. Default 70 means responses that are 70% or more similar are considered bad. @@ -97,7 +133,7 @@ def is_untrustworthy_response( threshold: float = DEFAULT_TRUSTWORTHINESS_THRESHOLD, format_prompt: Callable[[str, str], str] = default_format_prompt, ) -> bool: - """Check if a response is untrustworthy based on TLM's evaluation. + """Check if a response is untrustworthy. Uses TLM to evaluate whether a response is trustworthy given the context and query. Returns True if TLM's trustworthiness score falls below the threshold, indicating @@ -109,7 +145,7 @@ def is_untrustworthy_response( query: The user's question or request tlm: The TLM model to use for evaluation threshold: Score threshold (0.0-1.0). Lower values allow less trustworthy responses. - Default 0.6, meaning responses with scores less than 0.6 are considered untrustworthy. + Default 0.5, meaning responses with scores less than 0.5 are considered untrustworthy. format_prompt: Function that takes (query, context) and returns a formatted prompt string. Users should provide their RAG app's own prompt formatting function here to match how their LLM is prompted. @@ -118,19 +154,22 @@ def is_untrustworthy_response( bool: True if the response is deemed untrustworthy by TLM, False otherwise """ try: - from cleanlab_studio.studio.trustworthy_language_model import TLM # noqa: F401 + from cleanlab_studio import Studio # type: ignore # noqa: F401 except ImportError as e: error_msg = "The 'cleanlab_studio' library is required. 
Please install it with `pip install cleanlab-studio`." raise ImportError(error_msg) from e prompt = format_prompt(query, context) - resp = tlm.get_trustworthiness_score(prompt, response) - score: float = resp["trustworthiness_score"] + result = tlm.get_trustworthiness_score(prompt, response) + score: float = result["trustworthiness_score"] return score < threshold def is_unhelpful_response( - response: str, tlm: TLM, query: Optional[str] = None, trustworthiness_score_threshold: Optional[float] = None + response: str, + query: str, + tlm: TLM, + trustworthiness_score_threshold: Optional[float] = None, ) -> bool: """Check if a response is unhelpful by asking TLM to evaluate it. @@ -153,7 +192,7 @@ def is_unhelpful_response( False otherwise """ try: - from cleanlab_studio.studio.trustworthy_language_model import TLM # noqa: F401 + from cleanlab_studio import Studio # type: ignore # noqa: F401 except ImportError as e: error_msg = "The 'cleanlab_studio' library is required. Please install it with `pip install cleanlab-studio`." raise ImportError(error_msg) from e @@ -173,13 +212,10 @@ def is_unhelpful_response( expected_unhelpful_response = "yes" prompt = ( - "Consider the following" - + ( - f" User Query and AI Assistant Response.\n\nUser Query: {query}\n\n" - if query - else " AI Assistant Response.\n\n" - ) - + f"AI Assistant Response: {response}\n\n{question}" + "Consider the following User Query and AI Assistant Response.\n\n" + f"User Query: {query}\n\n" + f"AI Assistant Response: {response}\n\n" + f"{question}" ) output = tlm.prompt(prompt, constrain_outputs=["Yes", "No"]) diff --git a/tests/test_codex_backup.py b/tests/test_codex_backup.py index 03cae20..03fc1c9 100644 --- a/tests/test_codex_backup.py +++ b/tests/test_codex_backup.py @@ -1,6 +1,7 @@ -import pytest from unittest.mock import MagicMock +import pytest + from cleanlab_codex.codex_backup import CodexBackup # TODO: Remove this skip once we update codex_backup.py diff --git a/tests/test_validation.py b/tests/test_validation.py new file mode 100644 index 0000000..87ca1e9 --- /dev/null +++ b/tests/test_validation.py @@ -0,0 +1,154 @@ +"""Unit tests for validation module functions.""" + +from unittest.mock import Mock, patch + +import pytest + +from cleanlab_codex.validation import ( + is_bad_response, + is_fallback_response, + is_unhelpful_response, + is_untrustworthy_response, +) + +# Mock responses for testing +GOOD_RESPONSE = "This is a helpful and specific response that answers the question completely." +BAD_RESPONSE = "Based on the available information, I cannot provide a complete answer." +QUERY = "What is the capital of France?" +CONTEXT = "Paris is the capital and largest city of France." 
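
These constants drive the tests for the reworked `is_bad_response` API above, which is now keyword-only and simply skips any check whose inputs are missing. A minimal sketch of how a caller might use it (illustrative, not part of the patch; `needs_codex_backup` is a hypothetical helper, and the threshold values shown are just the defaults):

from cleanlab_codex.validation import is_bad_response

def needs_codex_backup(response: str, query: str, context: str, tlm=None) -> bool:
    # With tlm=None only the fuzzy fallback check runs; passing a TLM-like
    # handle (anything exposing prompt() and get_trustworthiness_score())
    # also enables the untrustworthy and unhelpful checks.
    return is_bad_response(
        response,
        query=query,
        context=context,
        tlm=tlm,
        partial_ratio_threshold=70,     # fallback check cutoff (0-100); 70 is the default
        trustworthiness_threshold=0.5,  # untrustworthy check cutoff (0.0-1.0); 0.5 is the default
    )
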
+ +@pytest.fixture +def mock_tlm(): + """Create a mock TLM instance.""" + mock = Mock() + # Configure default return values + mock.get_trustworthiness_score.return_value = {"trustworthiness_score": 0.8} + mock.prompt.return_value = {"response": "No", "trustworthiness_score": 0.9} + return mock + +@pytest.mark.parametrize( + "response,threshold,fallback_answer,expected", + [ + # Test threshold variations + (GOOD_RESPONSE, 30, None, True), + (GOOD_RESPONSE, 55, None, False), + + # Test default behavior (BAD_RESPONSE should be flagged) + (BAD_RESPONSE, None, None, True), + + # Test default behavior for different response (GOOD_RESPONSE should not be flagged) + (GOOD_RESPONSE, None, None, False), + + # Test custom fallback answer + (GOOD_RESPONSE, 80, "This is an unhelpful response", False), + ] +) +def test_is_fallback_response(response, threshold, fallback_answer, expected): + """Test fallback response detection.""" + kwargs = {} + if threshold is not None: + kwargs["threshold"] = threshold + if fallback_answer is not None: + kwargs["fallback_answer"] = fallback_answer + + assert is_fallback_response(response, **kwargs) is expected + + +def test_is_untrustworthy_response(mock_tlm): + """Test untrustworthy response detection.""" + # Test trustworthy response + mock_tlm.get_trustworthiness_score.return_value = {"trustworthiness_score": 0.8} + assert is_untrustworthy_response( + GOOD_RESPONSE, CONTEXT, QUERY, mock_tlm, threshold=0.5 + ) is False + + # Test untrustworthy response + mock_tlm.get_trustworthiness_score.return_value = {"trustworthiness_score": 0.3} + assert is_untrustworthy_response( + BAD_RESPONSE, CONTEXT, QUERY, mock_tlm, threshold=0.5 + ) is True + +@pytest.mark.parametrize( + "response,tlm_response,tlm_score,threshold,expected", + [ + # Test helpful response + (GOOD_RESPONSE, "No", 0.9, 0.5, False), + + # Test unhelpful response + (BAD_RESPONSE, "Yes", 0.9, 0.5, True), + + # Test unhelpful response but low trustworthiness score + (BAD_RESPONSE, "Yes", 0.3, 0.5, False), + + # Test without threshold - Yes prediction + (BAD_RESPONSE, "Yes", 0.3, None, True), + (GOOD_RESPONSE, "Yes", 0.3, None, True), + + # Test without threshold - No prediction + (BAD_RESPONSE, "No", 0.3, None, False), + (GOOD_RESPONSE, "No", 0.3, None, False), + ] +) +def test_is_unhelpful_response(mock_tlm, response, tlm_response, tlm_score, threshold, expected): + """Test unhelpful response detection.""" + mock_tlm.prompt.return_value = {"response": tlm_response, "trustworthiness_score": tlm_score} + assert is_unhelpful_response( + response, QUERY, mock_tlm, trustworthiness_score_threshold=threshold + ) is expected + +@pytest.mark.parametrize( + "response,trustworthiness_score,prompt_response,prompt_score,expected", + [ + # Good response passes all checks + (GOOD_RESPONSE, 0.8, "No", 0.9, False), + # Bad response fails at least one check + (BAD_RESPONSE, 0.3, "Yes", 0.9, True), + ] +) +def test_is_bad_response( + mock_tlm, + response, + trustworthiness_score, + prompt_response, + prompt_score, + expected +): + """Test the main is_bad_response function.""" + mock_tlm.get_trustworthiness_score.return_value = {"trustworthiness_score": trustworthiness_score} + mock_tlm.prompt.return_value = {"response": prompt_response, "trustworthiness_score": prompt_score} + + assert is_bad_response( + response, + context=CONTEXT, + query=QUERY, + tlm=mock_tlm, + ) is expected + +@pytest.mark.parametrize( + "response,fuzz_ratio,prompt_response,prompt_score,query,tlm,expected", + [ + # Test with only fallback check (no 
context/query/tlm) + (BAD_RESPONSE, 90, None, None, None, None, True), + + # Test with fallback and unhelpful checks (no context) + (GOOD_RESPONSE, 30, "No", 0.9, QUERY, "mock_tlm", False), + ] +) +def test_is_bad_response_partial_inputs(mock_tlm, response, fuzz_ratio, prompt_response, prompt_score, query, tlm, expected): + """Test is_bad_response with partial inputs (some checks disabled).""" + mock_fuzz = Mock() + mock_fuzz.partial_ratio.return_value = fuzz_ratio + + with patch.dict('sys.modules', {'thefuzz': Mock(fuzz=mock_fuzz)}): + if prompt_response is not None: + mock_tlm.prompt.return_value = { + "response": prompt_response, + "trustworthiness_score": prompt_score + } + tlm = mock_tlm + + assert is_bad_response( + response, + query=query, + tlm=tlm, + ) is expected From e5a61649978f6a215ef0574408b517945181538b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Thu, 6 Feb 2025 16:42:12 +0000 Subject: [PATCH 15/39] formatting & type hints --- src/cleanlab_codex/codex_backup.py | 217 +++++++++++++++++++---------- src/cleanlab_codex/validation.py | 31 ++--- tests/test_validation.py | 89 ++++++------ 3 files changed, 196 insertions(+), 141 deletions(-) diff --git a/src/cleanlab_codex/codex_backup.py b/src/cleanlab_codex/codex_backup.py index 94194ea..5e8ac9d 100644 --- a/src/cleanlab_codex/codex_backup.py +++ b/src/cleanlab_codex/codex_backup.py @@ -1,17 +1,97 @@ from __future__ import annotations +import os from functools import wraps -from typing import Any, Callable, Optional +from typing import Any, Optional, Protocol, Sequence, Union, cast -from cleanlab_codex.codex import Codex +import requests + +from cleanlab_codex.project import Project from cleanlab_codex.validation import is_bad_response -def handle_backup_default(backup_response: str, decorated_instance: Any) -> None: # noqa: ARG001 +def handle_backup_default(codex_response: str, primary_system: Any) -> None: # noqa: ARG001 """Default implementation is a no-op.""" return None +class _TLM(Protocol): + def get_trustworthiness_score( + self, + query: Union[str, Sequence[str]], + response: Union[str, Sequence[str]], + **kwargs: Any, + ) -> dict[str, Any]: ... + + def prompt( + self, + prompt: Union[str, Sequence[str]], + /, + **kwargs: Any, + ) -> dict[str, Any]: ... 
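
Because `_TLM` is a `typing.Protocol`, anything with structurally matching methods satisfies it; no inheritance is needed. A minimal sketch of a conforming stub (hypothetical, not part of the patch), similar in spirit to the `Mock` objects the tests substitute for a real TLM:

from __future__ import annotations

from typing import Any, Sequence, Union

class AlwaysTrustedTLM:
    """Hypothetical stub that satisfies _TLM structurally (no subclassing needed)."""

    def get_trustworthiness_score(
        self, query: Union[str, Sequence[str]], response: Union[str, Sequence[str]], **kwargs: Any
    ) -> dict[str, Any]:
        return {"trustworthiness_score": 1.0}  # treat every response as fully trustworthy

    def prompt(self, prompt: Union[str, Sequence[str]], /, **kwargs: Any) -> dict[str, Any]:
        return {"response": "No", "trustworthiness_score": 1.0}  # never flag as unhelpful
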
+ + +class _TemporaryTLM(_TLM): + def __init__( + self, + api_key: Optional[str] = None, + api_base_url: Optional[str] = None, + ): + self.api_base_url = api_base_url.rstrip("/") if api_base_url else os.getenv("CODEX_API_BASE_URL") + self._headers = { + "X-API-Key": api_key or os.getenv("CODEX_API_KEY"), + "Content-Type": "application/json", + } + + def _make_request(self, endpoint: str, data: dict[str, Any]) -> dict[str, Any]: + """Make a request to the TLM API.""" + url = f"{self.api_base_url}/api/tlm/{endpoint}" + response = requests.post( + url, + json=data, + headers=self._headers, + ) + response.raise_for_status() + return cast(dict[str, Any], response.json()) + + def get_trustworthiness_score( + self, query: Union[str, Sequence[str]], response: Union[str, Sequence[str]], **kwargs: Any + ) -> dict[str, Any]: + """Get trustworthiness score for a query-response pair.""" + data = {"prompt": query, "response": response, **kwargs} + return self._make_request("score", data) + + def prompt(self, prompt: Union[str, Sequence[str]], /, **kwargs: Any) -> dict[str, Any]: + """Send a prompt to the TLM API.""" + data = {"prompt": prompt, **kwargs} + return self._make_request("prompt", data) + + +class BackupHandler(Protocol): + """Protocol defining how to handle backup responses from Codex. + + This protocol defines a callable interface for processing Codex responses that are + retrieved when the primary response system (e.g., a RAG system) fails to provide + an adequate answer. Implementations of this protocol can be used to: + + - Update the primary system's context or knowledge base + - Log Codex responses for analysis + - Trigger system improvements or retraining + - Perform any other necessary side effects + + Args: + codex_response (str): The response received from Codex + primary_system (Any): The instance of the primary RAG system that + generated the inadequate response. This allows the handler to + update or modify the primary system if needed. + + Returns: + None: The handler performs side effects but doesn't return a value + """ + + def __call__(self, codex_response: str, primary_system: Any) -> None: ... + + class CodexBackup: """A backup decorator that connects to a Codex project to answer questions that cannot be adequately answered by the existing agent. @@ -21,91 +101,82 @@ class CodexBackup: def __init__( self, - codex_client: Codex, *, - project_id: Optional[str] = None, - fallback_answer: Optional[str] = DEFAULT_FALLBACK_ANSWER, - backup_handler: Callable[[str, Any], None] = handle_backup_default, + project: Project, + fallback_answer: str = DEFAULT_FALLBACK_ANSWER, + backup_handler: BackupHandler = handle_backup_default, ): - self._codex_client = codex_client - self._project_id = project_id + self._tlm = _TemporaryTLM() # TODO: Improve integration + self._project = project self._fallback_answer = fallback_answer self._backup_handler = backup_handler - @classmethod - def from_access_key( - cls, - access_key: str, - *, - project_id: Optional[str] = None, - fallback_answer: Optional[str] = DEFAULT_FALLBACK_ANSWER, - backup_handler: Callable[[str, Any], None] = handle_backup_default, - ) -> CodexBackup: - """Creates a CodexBackup from an access key. 
The project ID that the CodexBackup will use is the one that is associated with the access key.""" - return cls( - codex_client=Codex(key=access_key), - project_id=project_id, - fallback_answer=fallback_answer, - backup_handler=backup_handler, - ) - - @classmethod - def from_client( - cls, - codex_client: Codex, - *, - project_id: Optional[str] = None, - fallback_answer: Optional[str] = DEFAULT_FALLBACK_ANSWER, - backup_handler: Callable[[str, Any], None] = handle_backup_default, - ) -> CodexBackup: - """Creates a CodexBackup from a Codex client. - If the Codex client is initialized with a project access key, the CodexBackup will use the project ID that is associated with the access key. - If the Codex client is initialized with a user API key, a project ID must be provided. + def run( + self, + primary_system: Any, + response: str, + query: str, + context: Optional[str] = None, + bad_response_kwargs: Optional[dict[str, Any]] = None, + ) -> str: + """Check if a response is adequate and provide a backup from Codex if needed. + + Args: + primary_system: The system that generated the original response + response: The response to evaluate + query: The original query that generated the response + context: Optional context used to generate the response + bad_response_kwargs: Optional kwargs for response evaluation + + Returns: + str: Either the original response if adequate, or a backup response from Codex """ - return cls( - codex_client=codex_client, - project_id=project_id, - fallback_answer=fallback_answer, - backup_handler=backup_handler, + if not is_bad_response( + response, + query=query, + context=context, + tlm=self._tlm, + fallback_answer=self._fallback_answer, + **(bad_response_kwargs or {}), + ): + return response + + cache_result = self._project.query(query, fallback_answer=self._fallback_answer)[0] + if not cache_result: + return response + + self._backup_handler( + codex_response=cache_result, + primary_system=primary_system, ) + return cache_result def to_decorator(self): - """Factory that creates a backup decorator using the provided Codex client""" - - def decorator(chat_method): - """ - Decorator for RAG chat methods that adds backup response handling. + """Create a decorator that uses this backup system. - If the original chat method returns an inadequate response, attempts to get - a backup response from Codex. Returns the backup response if available, - otherwise returns the original response. - - Args: - chat_method: Method with signature (self, user_message: str) -> str - where 'self' refers to the instance being decorated, not an instance of CodexBackup. - """ + Returns a decorator that can be applied to chat methods to automatically + check responses and provide backups when needed. 
+ """ + def decorator(chat_method): @wraps(chat_method) def wrapper(decorated_instance, user_message): # Call the original chat method - assistant_response = chat_method(decorated_instance, user_message) - - # Return original response if it's adequate - # TODO: Update usage of is_bad_response - if not is_bad_response(assistant_response, self._fallback_answer): - return assistant_response - - # Query Codex for a backup response - cache_result = self._codex_client.query(user_message)[0] - if not cache_result: - return assistant_response - - # Handle backup response if handler exists - self._backup_handler( - backup_response=cache_result, - decorated_instance=decorated_instance, + result = chat_method(decorated_instance, user_message) + + # Handle both single response and (response, context) tuple returns + if isinstance(result, tuple): + assistant_response, context = result + else: + assistant_response, context = result, None + + # Use the run method to handle backup logic + return self.run( + primary_system=decorated_instance, + response=assistant_response, + query=user_message, + context=context, ) - return cache_result return wrapper diff --git a/src/cleanlab_codex/validation.py b/src/cleanlab_codex/validation.py index addaf96..57d3c21 100644 --- a/src/cleanlab_codex/validation.py +++ b/src/cleanlab_codex/validation.py @@ -4,7 +4,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING, Callable, Optional, cast from cleanlab_codex.utils.prompt import default_format_prompt @@ -44,15 +44,15 @@ def is_bad_response( context: Optional context/documents used for answering. Required for untrustworthy check. query: Optional user question. Required for untrustworthy and unhelpful checks. tlm: Optional TLM model for evaluation. Required for untrustworthy and unhelpful checks. - + # Fallback check parameters fallback_answer: Known unhelpful response to compare against. partial_ratio_threshold: Similarity threshold (0-100). Higher values require more similarity. - + # Untrustworthy check parameters trustworthiness_threshold: Score threshold (0.0-1.0). Lower values allow less trustworthy responses. format_prompt: Function to format (query, context) into a prompt string. - + # Unhelpful check parameters unhelpful_trustworthiness_threshold: Optional confidence threshold (0.0-1.0) for unhelpful classification. @@ -60,23 +60,19 @@ def is_bad_response( bool: True if any validation check fails, False if all pass. 
""" - validation_checks = [] + validation_checks: list[Callable[[], bool]] = [] # All required inputs are available for checking fallback responses - validation_checks.append( - lambda: is_fallback_response(response, fallback_answer, threshold=partial_ratio_threshold) - ) + validation_checks.append(lambda: is_fallback_response(response, fallback_answer, threshold=partial_ratio_threshold)) - can_run_untrustworthy_check = all(x is not None for x in (query, context, tlm)) + can_run_untrustworthy_check = query is not None and context is not None and tlm is not None if can_run_untrustworthy_check: - assert tlm is not None - assert query is not None - assert context is not None + # The if condition guarantees these are not None validation_checks.append( lambda: is_untrustworthy_response( response=response, - context=context, - query=query, + context=cast(str, context), + query=cast(str, query), tlm=tlm, threshold=trustworthiness_threshold, format_prompt=format_prompt, @@ -85,19 +81,18 @@ def is_bad_response( can_run_unhelpful_check = query is not None and tlm is not None if can_run_unhelpful_check: - assert tlm is not None - assert query is not None validation_checks.append( lambda: is_unhelpful_response( response=response, + query=cast(str, query), tlm=tlm, - query=query, trustworthiness_score_threshold=unhelpful_trustworthiness_threshold, ) ) return any(check() for check in validation_checks) + def is_fallback_response( response: str, fallback_answer: str = DEFAULT_FALLBACK_ANSWER, threshold: int = DEFAULT_PARTIAL_RATIO_THRESHOLD ) -> bool: @@ -192,7 +187,7 @@ def is_unhelpful_response( False otherwise """ try: - from cleanlab_studio import Studio # type: ignore # noqa: F401 + from cleanlab_studio import Studio # noqa: F401 except ImportError as e: error_msg = "The 'cleanlab_studio' library is required. Please install it with `pip install cleanlab-studio`." raise ImportError(error_msg) from e diff --git a/tests/test_validation.py b/tests/test_validation.py index 87ca1e9..bf1624a 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -17,6 +17,7 @@ QUERY = "What is the capital of France?" CONTEXT = "Paris is the capital and largest city of France." 
+ @pytest.fixture def mock_tlm(): """Create a mock TLM instance.""" @@ -26,22 +27,20 @@ def mock_tlm(): mock.prompt.return_value = {"response": "No", "trustworthiness_score": 0.9} return mock + @pytest.mark.parametrize( - "response,threshold,fallback_answer,expected", + ("response", "threshold", "fallback_answer", "expected"), [ # Test threshold variations (GOOD_RESPONSE, 30, None, True), (GOOD_RESPONSE, 55, None, False), - # Test default behavior (BAD_RESPONSE should be flagged) (BAD_RESPONSE, None, None, True), - # Test default behavior for different response (GOOD_RESPONSE should not be flagged) (GOOD_RESPONSE, None, None, False), - # Test custom fallback answer (GOOD_RESPONSE, 80, "This is an unhelpful response", False), - ] + ], ) def test_is_fallback_response(response, threshold, fallback_answer, expected): """Test fallback response detection.""" @@ -58,97 +57,87 @@ def test_is_untrustworthy_response(mock_tlm): """Test untrustworthy response detection.""" # Test trustworthy response mock_tlm.get_trustworthiness_score.return_value = {"trustworthiness_score": 0.8} - assert is_untrustworthy_response( - GOOD_RESPONSE, CONTEXT, QUERY, mock_tlm, threshold=0.5 - ) is False + assert is_untrustworthy_response(GOOD_RESPONSE, CONTEXT, QUERY, mock_tlm, threshold=0.5) is False # Test untrustworthy response mock_tlm.get_trustworthiness_score.return_value = {"trustworthiness_score": 0.3} - assert is_untrustworthy_response( - BAD_RESPONSE, CONTEXT, QUERY, mock_tlm, threshold=0.5 - ) is True + assert is_untrustworthy_response(BAD_RESPONSE, CONTEXT, QUERY, mock_tlm, threshold=0.5) is True + @pytest.mark.parametrize( - "response,tlm_response,tlm_score,threshold,expected", + ("response", "tlm_response", "tlm_score", "threshold", "expected"), [ # Test helpful response (GOOD_RESPONSE, "No", 0.9, 0.5, False), - # Test unhelpful response (BAD_RESPONSE, "Yes", 0.9, 0.5, True), - # Test unhelpful response but low trustworthiness score (BAD_RESPONSE, "Yes", 0.3, 0.5, False), - # Test without threshold - Yes prediction (BAD_RESPONSE, "Yes", 0.3, None, True), (GOOD_RESPONSE, "Yes", 0.3, None, True), - # Test without threshold - No prediction (BAD_RESPONSE, "No", 0.3, None, False), (GOOD_RESPONSE, "No", 0.3, None, False), - ] + ], ) def test_is_unhelpful_response(mock_tlm, response, tlm_response, tlm_score, threshold, expected): """Test unhelpful response detection.""" mock_tlm.prompt.return_value = {"response": tlm_response, "trustworthiness_score": tlm_score} - assert is_unhelpful_response( - response, QUERY, mock_tlm, trustworthiness_score_threshold=threshold - ) is expected + assert is_unhelpful_response(response, QUERY, mock_tlm, trustworthiness_score_threshold=threshold) is expected + @pytest.mark.parametrize( - "response,trustworthiness_score,prompt_response,prompt_score,expected", + ("response", "trustworthiness_score", "prompt_response", "prompt_score", "expected"), [ # Good response passes all checks (GOOD_RESPONSE, 0.8, "No", 0.9, False), # Bad response fails at least one check (BAD_RESPONSE, 0.3, "Yes", 0.9, True), - ] + ], ) -def test_is_bad_response( - mock_tlm, - response, - trustworthiness_score, - prompt_response, - prompt_score, - expected -): +def test_is_bad_response(mock_tlm, response, trustworthiness_score, prompt_response, prompt_score, expected): """Test the main is_bad_response function.""" mock_tlm.get_trustworthiness_score.return_value = {"trustworthiness_score": trustworthiness_score} mock_tlm.prompt.return_value = {"response": prompt_response, "trustworthiness_score": prompt_score} - 
assert is_bad_response( - response, - context=CONTEXT, - query=QUERY, - tlm=mock_tlm, - ) is expected + assert ( + is_bad_response( + response, + context=CONTEXT, + query=QUERY, + tlm=mock_tlm, + ) + is expected + ) + @pytest.mark.parametrize( - "response,fuzz_ratio,prompt_response,prompt_score,query,tlm,expected", + ("response", "fuzz_ratio", "prompt_response", "prompt_score", "query", "tlm", "expected"), [ # Test with only fallback check (no context/query/tlm) (BAD_RESPONSE, 90, None, None, None, None, True), - # Test with fallback and unhelpful checks (no context) (GOOD_RESPONSE, 30, "No", 0.9, QUERY, "mock_tlm", False), - ] + ], ) -def test_is_bad_response_partial_inputs(mock_tlm, response, fuzz_ratio, prompt_response, prompt_score, query, tlm, expected): +def test_is_bad_response_partial_inputs( + mock_tlm, response, fuzz_ratio, prompt_response, prompt_score, query, tlm, expected +): """Test is_bad_response with partial inputs (some checks disabled).""" mock_fuzz = Mock() mock_fuzz.partial_ratio.return_value = fuzz_ratio - with patch.dict('sys.modules', {'thefuzz': Mock(fuzz=mock_fuzz)}): + with patch.dict("sys.modules", {"thefuzz": Mock(fuzz=mock_fuzz)}): if prompt_response is not None: - mock_tlm.prompt.return_value = { - "response": prompt_response, - "trustworthiness_score": prompt_score - } + mock_tlm.prompt.return_value = {"response": prompt_response, "trustworthiness_score": prompt_score} tlm = mock_tlm - assert is_bad_response( - response, - query=query, - tlm=tlm, - ) is expected + assert ( + is_bad_response( + response, + query=query, + tlm=tlm, + ) + is expected + ) From 807d7fa866eb6ffa0a5b8ce1fb8007c0dc73b58f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Thu, 6 Feb 2025 16:42:48 +0000 Subject: [PATCH 16/39] comment out to_decorator --- src/cleanlab_codex/codex_backup.py | 59 +++++++++++++++--------------- 1 file changed, 29 insertions(+), 30 deletions(-) diff --git a/src/cleanlab_codex/codex_backup.py b/src/cleanlab_codex/codex_backup.py index 5e8ac9d..75e360f 100644 --- a/src/cleanlab_codex/codex_backup.py +++ b/src/cleanlab_codex/codex_backup.py @@ -151,33 +151,32 @@ def run( ) return cache_result - def to_decorator(self): - """Create a decorator that uses this backup system. - - Returns a decorator that can be applied to chat methods to automatically - check responses and provide backups when needed. - """ - - def decorator(chat_method): - @wraps(chat_method) - def wrapper(decorated_instance, user_message): - # Call the original chat method - result = chat_method(decorated_instance, user_message) - - # Handle both single response and (response, context) tuple returns - if isinstance(result, tuple): - assistant_response, context = result - else: - assistant_response, context = result, None - - # Use the run method to handle backup logic - return self.run( - primary_system=decorated_instance, - response=assistant_response, - query=user_message, - context=context, - ) - - return wrapper - - return decorator + # def to_decorator(self): + # """Create a decorator that uses this backup system. + + # Returns a decorator that can be applied to chat methods to automatically + # check responses and provide backups when needed. 
+ # """ + + # def decorator(chat_method): + # @wraps(chat_method) + # def wrapper(decorated_instance, user_message): + # # Call the original chat method + # result = chat_method(decorated_instance, user_message) + + # # Handle both single response and (response, context) tuple returns + # if isinstance(result, tuple): + # assistant_response, context = result + # else: + # assistant_response, context = result, None + + # # Use the run method to handle backup logic + # return self.run( + # primary_system=decorated_instance, + # response=assistant_response, + # query=user_message, + # context=context, + # ) + + # return wrapper + # return decorator From d8a6e868e168ccd82f4bdac7168753432253f923 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Fri, 7 Feb 2025 23:08:56 +0000 Subject: [PATCH 17/39] enhance CodexBackup --- src/cleanlab_codex/codex_backup.py | 57 ++++++++++++++++++++++++------ 1 file changed, 46 insertions(+), 11 deletions(-) diff --git a/src/cleanlab_codex/codex_backup.py b/src/cleanlab_codex/codex_backup.py index 75e360f..bc19f54 100644 --- a/src/cleanlab_codex/codex_backup.py +++ b/src/cleanlab_codex/codex_backup.py @@ -2,7 +2,7 @@ import os from functools import wraps -from typing import Any, Optional, Protocol, Sequence, Union, cast +from typing import Any, Optional, Protocol, Sequence, Union, cast, Self import requests @@ -38,6 +38,7 @@ def __init__( api_base_url: Optional[str] = None, ): self.api_base_url = api_base_url.rstrip("/") if api_base_url else os.getenv("CODEX_API_BASE_URL") + assert self.api_base_url is not None, "CODEX_API_BASE_URL is not set" self._headers = { "X-API-Key": api_key or os.getenv("CODEX_API_KEY"), "Content-Type": "application/json", @@ -95,6 +96,13 @@ def __call__(self, codex_response: str, primary_system: Any) -> None: ... class CodexBackup: """A backup decorator that connects to a Codex project to answer questions that cannot be adequately answered by the existing agent. + + Args: + project: The Codex project to use for backup responses + fallback_answer: The fallback answer to use if the primary system fails + backup_handler: A callback function that processes Codex's response and updates the primary RAG system. This handler is called whenever Codex provides a backup response after the primary system fails. By default, the backup handler is a no-op. + primary_system: The existing RAG system that needs to be backed up by Codex + is_bad_response_kwargs: Additional keyword arguments to pass to the is_bad_response function, for detecting inadequate responses from the primary system """ DEFAULT_FALLBACK_ANSWER = "Based on the available information, I cannot provide a complete answer to this question." @@ -105,19 +113,44 @@ def __init__( project: Project, fallback_answer: str = DEFAULT_FALLBACK_ANSWER, backup_handler: BackupHandler = handle_backup_default, + primary_system: Optional[Any] = None, + is_bad_response_kwargs: Optional[dict[str, Any]] = None, ): - self._tlm = _TemporaryTLM() # TODO: Improve integration self._project = project self._fallback_answer = fallback_answer self._backup_handler = backup_handler + self._tlm = _TemporaryTLM() # TODO: Improve integration + self._primary_system: Optional[Any] = primary_system + self._is_bad_response_kwargs = is_bad_response_kwargs + + @classmethod + def from_project(cls, project: Project, **kwargs: Any) -> Self: + return cls(project=project, **kwargs) + + @property + def primary_system(self) -> Any: + if self._primary_system is None: + raise ValueError("Primary system not set. 
Please set a primary system using the `add_primary_system` method.") + return self._primary_system + + @primary_system.setter + def primary_system(self, primary_system: Any) -> None: + """Set the primary RAG system that will be used to generate responses.""" + self._primary_system = primary_system + + @property + def is_bad_response_kwargs(self) -> dict[str, Any]: + return self._is_bad_response_kwargs or {} + + @is_bad_response_kwargs.setter + def is_bad_response_kwargs(self, is_bad_response_kwargs: dict[str, Any]) -> None: + self._is_bad_response_kwargs = is_bad_response_kwargs def run( self, - primary_system: Any, response: str, query: str, context: Optional[str] = None, - bad_response_kwargs: Optional[dict[str, Any]] = None, ) -> str: """Check if a response is adequate and provide a backup from Codex if needed. @@ -126,18 +159,19 @@ def run( response: The response to evaluate query: The original query that generated the response context: Optional context used to generate the response - bad_response_kwargs: Optional kwargs for response evaluation - + Returns: str: Either the original response if adequate, or a backup response from Codex """ + + _is_bad_response_kwargs = self.is_bad_response_kwargs if not is_bad_response( response, query=query, context=context, tlm=self._tlm, fallback_answer=self._fallback_answer, - **(bad_response_kwargs or {}), + **_is_bad_response_kwargs, ): return response @@ -145,10 +179,11 @@ def run( if not cache_result: return response - self._backup_handler( - codex_response=cache_result, - primary_system=primary_system, - ) + if self._primary_system is not None: + self._backup_handler( + codex_response=cache_result, + primary_system=self._primary_system, + ) return cache_result # def to_decorator(self): From 2630a2c11e80b3b6a0208371d7bf3322e1548882 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Fri, 7 Feb 2025 23:09:33 +0000 Subject: [PATCH 18/39] delete commented-out to_decorator method --- src/cleanlab_codex/codex_backup.py | 30 ------------------------------ 1 file changed, 30 deletions(-) diff --git a/src/cleanlab_codex/codex_backup.py b/src/cleanlab_codex/codex_backup.py index bc19f54..7b8f596 100644 --- a/src/cleanlab_codex/codex_backup.py +++ b/src/cleanlab_codex/codex_backup.py @@ -185,33 +185,3 @@ def run( primary_system=self._primary_system, ) return cache_result - - # def to_decorator(self): - # """Create a decorator that uses this backup system. - - # Returns a decorator that can be applied to chat methods to automatically - # check responses and provide backups when needed. 
- # """ - - # def decorator(chat_method): - # @wraps(chat_method) - # def wrapper(decorated_instance, user_message): - # # Call the original chat method - # result = chat_method(decorated_instance, user_message) - - # # Handle both single response and (response, context) tuple returns - # if isinstance(result, tuple): - # assistant_response, context = result - # else: - # assistant_response, context = result, None - - # # Use the run method to handle backup logic - # return self.run( - # primary_system=decorated_instance, - # response=assistant_response, - # query=user_message, - # context=context, - # ) - - # return wrapper - # return decorator From 0ebd4fea7326e11de60d1bf9fc366207d1185f6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Fri, 7 Feb 2025 23:11:51 +0000 Subject: [PATCH 19/39] formatting --- src/cleanlab_codex/codex_backup.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/cleanlab_codex/codex_backup.py b/src/cleanlab_codex/codex_backup.py index 7b8f596..564eb0c 100644 --- a/src/cleanlab_codex/codex_backup.py +++ b/src/cleanlab_codex/codex_backup.py @@ -1,8 +1,7 @@ from __future__ import annotations import os -from functools import wraps -from typing import Any, Optional, Protocol, Sequence, Union, cast, Self +from typing import Any, Optional, Protocol, Self, Sequence, Union, cast import requests @@ -96,7 +95,7 @@ def __call__(self, codex_response: str, primary_system: Any) -> None: ... class CodexBackup: """A backup decorator that connects to a Codex project to answer questions that cannot be adequately answered by the existing agent. - + Args: project: The Codex project to use for backup responses fallback_answer: The fallback answer to use if the primary system fails @@ -159,7 +158,7 @@ def run( response: The response to evaluate query: The original query that generated the response context: Optional context used to generate the response - + Returns: str: Either the original response if adequate, or a backup response from Codex """ From 4eca7d36d0e4f861df2f71d7e94ae36aa7e0d891 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Fri, 7 Feb 2025 23:46:01 +0000 Subject: [PATCH 20/39] fix tests for CodexBackup --- tests/test_codex_backup.py | 66 ++++++++++++++++++++------------------ 1 file changed, 34 insertions(+), 32 deletions(-) diff --git a/tests/test_codex_backup.py b/tests/test_codex_backup.py index 03fc1c9..5becbf8 100644 --- a/tests/test_codex_backup.py +++ b/tests/test_codex_backup.py @@ -1,70 +1,72 @@ from unittest.mock import MagicMock -import pytest - from cleanlab_codex.codex_backup import CodexBackup -# TODO: Remove this skip once we update codex_backup.py -pytest.skip(allow_module_level=True) - - MOCK_BACKUP_RESPONSE = "This is a test response" FALLBACK_MESSAGE = "Based on the available information, I cannot provide a complete answer to this question." TEST_MESSAGE = "Hello, world!" 
-def test_codex_backup(mock_client: MagicMock): - mock_response = MagicMock() - mock_response.answer = MOCK_BACKUP_RESPONSE - mock_client.projects.entries.query.return_value = mock_response - - codex_backup = CodexBackup.from_access_key("") +def test_codex_backup(): + # Create a mock project directly + mock_project = MagicMock() + mock_project.query.return_value = (MOCK_BACKUP_RESPONSE,) + class MockApp: - @codex_backup.to_decorator() def chat(self, user_message: str) -> str: # Just echo the user message return user_message app = MockApp() + codex_backup = CodexBackup.from_project(mock_project) # Echo works well - response = app.chat(TEST_MESSAGE) - assert response == TEST_MESSAGE + query = TEST_MESSAGE + response = app.chat(query) + assert response == query # Backup works well for fallback responses - response = app.chat(FALLBACK_MESSAGE) - assert response == MOCK_BACKUP_RESPONSE - - -def test_backup_handler(mock_client: MagicMock): - mock_response = MagicMock() - mock_response.answer = MOCK_BACKUP_RESPONSE - mock_client.projects.entries.query.return_value = mock_response - + query = FALLBACK_MESSAGE + response = app.chat(query) + assert response == query + response = codex_backup.run(response, query=query) + assert response == MOCK_BACKUP_RESPONSE, f"Response was {response}" + +def test_backup_handler(): + mock_project = MagicMock() + mock_project.query.return_value = (MOCK_BACKUP_RESPONSE,) + mock_handler = MagicMock() mock_handler.return_value = None - codex_backup = CodexBackup.from_access_key("", backup_handler=mock_handler) class MockApp: - @codex_backup.to_decorator() def chat(self, user_message: str) -> str: # Just echo the user message return user_message app = MockApp() + codex_backup = CodexBackup.from_project(mock_project, primary_system=app, backup_handler=mock_handler) - response = app.chat(TEST_MESSAGE) - assert response == TEST_MESSAGE + query = TEST_MESSAGE + response = app.chat(query) + assert response == query + + response = codex_backup.run(response, query=query) + assert response == query, f"Response was {response}" # Handler should not be called for good responses assert mock_handler.call_count == 0 - - response = app.chat(FALLBACK_MESSAGE) - assert response == MOCK_BACKUP_RESPONSE + + + query = FALLBACK_MESSAGE + response = app.chat(query) + assert response == query + response = codex_backup.run(response, query=query) + assert response == MOCK_BACKUP_RESPONSE, f"Response was {response}" # Handler should be called for bad responses assert mock_handler.call_count == 1 # The MockApp is the second argument to the handler, i.e. it has the necessary context # to handle the new response - assert mock_handler.call_args.kwargs["decorated_instance"] == app + assert mock_handler.call_args.kwargs["primary_system"] == app From c59cec577721ec2a571c5dc87534f58acec9d314 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Fri, 7 Feb 2025 23:51:00 +0000 Subject: [PATCH 21/39] formatting and typing --- tests/test_codex_backup.py | 12 ++++++------ tests/test_validation.py | 20 +++++++++++--------- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/tests/test_codex_backup.py b/tests/test_codex_backup.py index 5becbf8..6861897 100644 --- a/tests/test_codex_backup.py +++ b/tests/test_codex_backup.py @@ -7,11 +7,11 @@ TEST_MESSAGE = "Hello, world!" 
-def test_codex_backup(): +def test_codex_backup() -> None: # Create a mock project directly mock_project = MagicMock() mock_project.query.return_value = (MOCK_BACKUP_RESPONSE,) - + class MockApp: def chat(self, user_message: str) -> str: @@ -33,10 +33,10 @@ def chat(self, user_message: str) -> str: response = codex_backup.run(response, query=query) assert response == MOCK_BACKUP_RESPONSE, f"Response was {response}" -def test_backup_handler(): +def test_backup_handler() -> None: mock_project = MagicMock() mock_project.query.return_value = (MOCK_BACKUP_RESPONSE,) - + mock_handler = MagicMock() mock_handler.return_value = None @@ -57,8 +57,8 @@ def chat(self, user_message: str) -> str: # Handler should not be called for good responses assert mock_handler.call_count == 0 - - + + query = FALLBACK_MESSAGE response = app.chat(query) assert response == query diff --git a/tests/test_validation.py b/tests/test_validation.py index bf1624a..621d361 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -1,5 +1,7 @@ """Unit tests for validation module functions.""" +from __future__ import annotations + from unittest.mock import Mock, patch import pytest @@ -19,7 +21,7 @@ @pytest.fixture -def mock_tlm(): +def mock_tlm() -> Mock: """Create a mock TLM instance.""" mock = Mock() # Configure default return values @@ -42,18 +44,18 @@ def mock_tlm(): (GOOD_RESPONSE, 80, "This is an unhelpful response", False), ], ) -def test_is_fallback_response(response, threshold, fallback_answer, expected): +def test_is_fallback_response(response: str, threshold: float | None, fallback_answer: str | None, expected: bool) -> None: """Test fallback response detection.""" - kwargs = {} + kwargs: dict[str, float | str] = {} if threshold is not None: kwargs["threshold"] = threshold if fallback_answer is not None: kwargs["fallback_answer"] = fallback_answer - assert is_fallback_response(response, **kwargs) is expected + assert is_fallback_response(response, **kwargs) is expected # type: ignore -def test_is_untrustworthy_response(mock_tlm): +def test_is_untrustworthy_response(mock_tlm: Mock) -> None: """Test untrustworthy response detection.""" # Test trustworthy response mock_tlm.get_trustworthiness_score.return_value = {"trustworthiness_score": 0.8} @@ -81,7 +83,7 @@ def test_is_untrustworthy_response(mock_tlm): (GOOD_RESPONSE, "No", 0.3, None, False), ], ) -def test_is_unhelpful_response(mock_tlm, response, tlm_response, tlm_score, threshold, expected): +def test_is_unhelpful_response(mock_tlm: Mock, response: str, tlm_response: str, tlm_score: float, threshold: float | None, expected: bool) -> None: """Test unhelpful response detection.""" mock_tlm.prompt.return_value = {"response": tlm_response, "trustworthiness_score": tlm_score} assert is_unhelpful_response(response, QUERY, mock_tlm, trustworthiness_score_threshold=threshold) is expected @@ -96,7 +98,7 @@ def test_is_unhelpful_response(mock_tlm, response, tlm_response, tlm_score, thre (BAD_RESPONSE, 0.3, "Yes", 0.9, True), ], ) -def test_is_bad_response(mock_tlm, response, trustworthiness_score, prompt_response, prompt_score, expected): +def test_is_bad_response(mock_tlm: Mock, response: str, trustworthiness_score: float, prompt_response: str, prompt_score: float, expected: bool) -> None: """Test the main is_bad_response function.""" mock_tlm.get_trustworthiness_score.return_value = {"trustworthiness_score": trustworthiness_score} mock_tlm.prompt.return_value = {"response": prompt_response, "trustworthiness_score": prompt_score} @@ -122,8 +124,8 @@ def 
test_is_bad_response(mock_tlm, response, trustworthiness_score, prompt_respo ], ) def test_is_bad_response_partial_inputs( - mock_tlm, response, fuzz_ratio, prompt_response, prompt_score, query, tlm, expected -): + mock_tlm: Mock, response: str, fuzz_ratio: int, prompt_response: str, prompt_score: float, query: str, tlm: Mock, expected: bool +) -> None: """Test is_bad_response with partial inputs (some checks disabled).""" mock_fuzz = Mock() mock_fuzz.partial_ratio.return_value = fuzz_ratio From 602617996e506b57191ee01dcd1dafb1eccf18a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Sat, 8 Feb 2025 00:00:46 +0000 Subject: [PATCH 22/39] formatting --- src/cleanlab_codex/codex_backup.py | 17 ++++++++++++----- tests/test_validation.py | 8 ++++---- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/src/cleanlab_codex/codex_backup.py b/src/cleanlab_codex/codex_backup.py index 564eb0c..93beb3a 100644 --- a/src/cleanlab_codex/codex_backup.py +++ b/src/cleanlab_codex/codex_backup.py @@ -1,13 +1,15 @@ from __future__ import annotations import os -from typing import Any, Optional, Protocol, Self, Sequence, Union, cast +from typing import TYPE_CHECKING, Any, Optional, Protocol, Self, Sequence, Union, cast import requests -from cleanlab_codex.project import Project from cleanlab_codex.validation import is_bad_response +if TYPE_CHECKING: + from cleanlab_codex.project import Project + def handle_backup_default(codex_response: str, primary_system: Any) -> None: # noqa: ARG001 """Default implementation is a no-op.""" @@ -35,14 +37,17 @@ def __init__( self, api_key: Optional[str] = None, api_base_url: Optional[str] = None, + **kwargs: Any, ): self.api_base_url = api_base_url.rstrip("/") if api_base_url else os.getenv("CODEX_API_BASE_URL") - assert self.api_base_url is not None, "CODEX_API_BASE_URL is not set" + if self.api_base_url is None: + error_message = "Please set the CODEX_API_BASE_URL environment variable or pass api_base_url to the _TemporaryTLM constructor." + raise ValueError(error_message) self._headers = { "X-API-Key": api_key or os.getenv("CODEX_API_KEY"), "Content-Type": "application/json", } - + self._timeout = kwargs.get("timeout", 10) # type: ignore def _make_request(self, endpoint: str, data: dict[str, Any]) -> dict[str, Any]: """Make a request to the TLM API.""" url = f"{self.api_base_url}/api/tlm/{endpoint}" @@ -50,6 +55,7 @@ def _make_request(self, endpoint: str, data: dict[str, Any]) -> dict[str, Any]: url, json=data, headers=self._headers, + timeout=self._timeout, ) response.raise_for_status() return cast(dict[str, Any], response.json()) @@ -129,7 +135,8 @@ def from_project(cls, project: Project, **kwargs: Any) -> Self: @property def primary_system(self) -> Any: if self._primary_system is None: - raise ValueError("Primary system not set. Please set a primary system using the `add_primary_system` method.") + error_message = "Primary system not set. Please set a primary system using the `add_primary_system` method." 
+ raise ValueError(error_message) return self._primary_system @primary_system.setter diff --git a/tests/test_validation.py b/tests/test_validation.py index 621d361..d0e82c8 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -44,7 +44,7 @@ def mock_tlm() -> Mock: (GOOD_RESPONSE, 80, "This is an unhelpful response", False), ], ) -def test_is_fallback_response(response: str, threshold: float | None, fallback_answer: str | None, expected: bool) -> None: +def test_is_fallback_response(response: str, threshold: float | None, fallback_answer: str | None, *, expected: bool) -> None: """Test fallback response detection.""" kwargs: dict[str, float | str] = {} if threshold is not None: @@ -83,7 +83,7 @@ def test_is_untrustworthy_response(mock_tlm: Mock) -> None: (GOOD_RESPONSE, "No", 0.3, None, False), ], ) -def test_is_unhelpful_response(mock_tlm: Mock, response: str, tlm_response: str, tlm_score: float, threshold: float | None, expected: bool) -> None: +def test_is_unhelpful_response(mock_tlm: Mock, response: str, tlm_response: str, tlm_score: float, threshold: float | None, *, expected: bool) -> None: """Test unhelpful response detection.""" mock_tlm.prompt.return_value = {"response": tlm_response, "trustworthiness_score": tlm_score} assert is_unhelpful_response(response, QUERY, mock_tlm, trustworthiness_score_threshold=threshold) is expected @@ -98,7 +98,7 @@ def test_is_unhelpful_response(mock_tlm: Mock, response: str, tlm_response: str, (BAD_RESPONSE, 0.3, "Yes", 0.9, True), ], ) -def test_is_bad_response(mock_tlm: Mock, response: str, trustworthiness_score: float, prompt_response: str, prompt_score: float, expected: bool) -> None: +def test_is_bad_response(mock_tlm: Mock, response: str, trustworthiness_score: float, prompt_response: str, prompt_score: float, *, expected: bool) -> None: """Test the main is_bad_response function.""" mock_tlm.get_trustworthiness_score.return_value = {"trustworthiness_score": trustworthiness_score} mock_tlm.prompt.return_value = {"response": prompt_response, "trustworthiness_score": prompt_score} @@ -124,7 +124,7 @@ def test_is_bad_response(mock_tlm: Mock, response: str, trustworthiness_score: f ], ) def test_is_bad_response_partial_inputs( - mock_tlm: Mock, response: str, fuzz_ratio: int, prompt_response: str, prompt_score: float, query: str, tlm: Mock, expected: bool + mock_tlm: Mock, response: str, fuzz_ratio: int, prompt_response: str, prompt_score: float, query: str, tlm: Mock, *, expected: bool ) -> None: """Test is_bad_response with partial inputs (some checks disabled).""" mock_fuzz = Mock() From a94ffb59d1e59ba3e802da2cf78e5f1e2d8faff3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Sat, 8 Feb 2025 00:04:31 +0000 Subject: [PATCH 23/39] formatting --- src/cleanlab_codex/codex_backup.py | 1 + tests/test_codex_backup.py | 3 +-- tests/test_validation.py | 39 ++++++++++++++++++++++++++---- 3 files changed, 36 insertions(+), 7 deletions(-) diff --git a/src/cleanlab_codex/codex_backup.py b/src/cleanlab_codex/codex_backup.py index 93beb3a..140e114 100644 --- a/src/cleanlab_codex/codex_backup.py +++ b/src/cleanlab_codex/codex_backup.py @@ -48,6 +48,7 @@ def __init__( "Content-Type": "application/json", } self._timeout = kwargs.get("timeout", 10) # type: ignore + def _make_request(self, endpoint: str, data: dict[str, Any]) -> dict[str, Any]: """Make a request to the TLM API.""" url = f"{self.api_base_url}/api/tlm/{endpoint}" diff --git a/tests/test_codex_backup.py b/tests/test_codex_backup.py index 6861897..e8e9001 
100644 --- a/tests/test_codex_backup.py +++ b/tests/test_codex_backup.py @@ -12,7 +12,6 @@ def test_codex_backup() -> None: mock_project = MagicMock() mock_project.query.return_value = (MOCK_BACKUP_RESPONSE,) - class MockApp: def chat(self, user_message: str) -> str: # Just echo the user message @@ -33,6 +32,7 @@ def chat(self, user_message: str) -> str: response = codex_backup.run(response, query=query) assert response == MOCK_BACKUP_RESPONSE, f"Response was {response}" + def test_backup_handler() -> None: mock_project = MagicMock() mock_project.query.return_value = (MOCK_BACKUP_RESPONSE,) @@ -58,7 +58,6 @@ def chat(self, user_message: str) -> str: # Handler should not be called for good responses assert mock_handler.call_count == 0 - query = FALLBACK_MESSAGE response = app.chat(query) assert response == query diff --git a/tests/test_validation.py b/tests/test_validation.py index d0e82c8..de88446 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -44,7 +44,13 @@ def mock_tlm() -> Mock: (GOOD_RESPONSE, 80, "This is an unhelpful response", False), ], ) -def test_is_fallback_response(response: str, threshold: float | None, fallback_answer: str | None, *, expected: bool) -> None: +def test_is_fallback_response( + response: str, + threshold: float | None, + fallback_answer: str | None, + *, + expected: bool, +) -> None: """Test fallback response detection.""" kwargs: dict[str, float | str] = {} if threshold is not None: @@ -83,7 +89,15 @@ def test_is_untrustworthy_response(mock_tlm: Mock) -> None: (GOOD_RESPONSE, "No", 0.3, None, False), ], ) -def test_is_unhelpful_response(mock_tlm: Mock, response: str, tlm_response: str, tlm_score: float, threshold: float | None, *, expected: bool) -> None: +def test_is_unhelpful_response( + mock_tlm: Mock, + response: str, + tlm_response: str, + tlm_score: float, + threshold: float | None, + *, + expected: bool, +) -> None: """Test unhelpful response detection.""" mock_tlm.prompt.return_value = {"response": tlm_response, "trustworthiness_score": tlm_score} assert is_unhelpful_response(response, QUERY, mock_tlm, trustworthiness_score_threshold=threshold) is expected @@ -98,7 +112,15 @@ def test_is_unhelpful_response(mock_tlm: Mock, response: str, tlm_response: str, (BAD_RESPONSE, 0.3, "Yes", 0.9, True), ], ) -def test_is_bad_response(mock_tlm: Mock, response: str, trustworthiness_score: float, prompt_response: str, prompt_score: float, *, expected: bool) -> None: +def test_is_bad_response( + mock_tlm: Mock, + response: str, + trustworthiness_score: float, + prompt_response: str, + prompt_score: float, + *, + expected: bool, +) -> None: """Test the main is_bad_response function.""" mock_tlm.get_trustworthiness_score.return_value = {"trustworthiness_score": trustworthiness_score} mock_tlm.prompt.return_value = {"response": prompt_response, "trustworthiness_score": prompt_score} @@ -124,11 +146,18 @@ def test_is_bad_response(mock_tlm: Mock, response: str, trustworthiness_score: f ], ) def test_is_bad_response_partial_inputs( - mock_tlm: Mock, response: str, fuzz_ratio: int, prompt_response: str, prompt_score: float, query: str, tlm: Mock, *, expected: bool + mock_tlm: Mock, + response: str, + fuzz_ratio: int, + prompt_response: str, + prompt_score: float, + query: str, + tlm: Mock, + *, + expected: bool, ) -> None: """Test is_bad_response with partial inputs (some checks disabled).""" mock_fuzz = Mock() - mock_fuzz.partial_ratio.return_value = fuzz_ratio with patch.dict("sys.modules", {"thefuzz": Mock(fuzz=mock_fuzz)}): if prompt_response 
is not None: From 2510255878225621872dfead9427c06e8af45f16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Sat, 8 Feb 2025 00:08:32 +0000 Subject: [PATCH 24/39] fix unused fixture --- tests/test_validation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_validation.py b/tests/test_validation.py index de88446..d1bcf70 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -158,7 +158,7 @@ def test_is_bad_response_partial_inputs( ) -> None: """Test is_bad_response with partial inputs (some checks disabled).""" mock_fuzz = Mock() - + mock_fuzz.partial_ratio.return_value = fuzz_ratio with patch.dict("sys.modules", {"thefuzz": Mock(fuzz=mock_fuzz)}): if prompt_response is not None: mock_tlm.prompt.return_value = {"response": prompt_response, "trustworthiness_score": prompt_score} From b4391136ac7e49268a310099c4015c9596f5f591 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Sat, 8 Feb 2025 00:12:35 +0000 Subject: [PATCH 25/39] remove Self imported from typing, doesn't work for Python 3.8 --- src/cleanlab_codex/codex_backup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cleanlab_codex/codex_backup.py b/src/cleanlab_codex/codex_backup.py index 140e114..5f14e03 100644 --- a/src/cleanlab_codex/codex_backup.py +++ b/src/cleanlab_codex/codex_backup.py @@ -1,7 +1,7 @@ from __future__ import annotations import os -from typing import TYPE_CHECKING, Any, Optional, Protocol, Self, Sequence, Union, cast +from typing import TYPE_CHECKING, Any, Optional, Protocol, Sequence, Union, cast import requests @@ -130,7 +130,7 @@ def __init__( self._is_bad_response_kwargs = is_bad_response_kwargs @classmethod - def from_project(cls, project: Project, **kwargs: Any) -> Self: + def from_project(cls, project: Project, **kwargs: Any) -> "CodexBackup": return cls(project=project, **kwargs) @property From 38666def71941d670571b2a15965255deb8ca6aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Sat, 8 Feb 2025 00:13:22 +0000 Subject: [PATCH 26/39] remove unused type ignore --- src/cleanlab_codex/codex_backup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cleanlab_codex/codex_backup.py b/src/cleanlab_codex/codex_backup.py index 5f14e03..194dfab 100644 --- a/src/cleanlab_codex/codex_backup.py +++ b/src/cleanlab_codex/codex_backup.py @@ -47,7 +47,7 @@ def __init__( "X-API-Key": api_key or os.getenv("CODEX_API_KEY"), "Content-Type": "application/json", } - self._timeout = kwargs.get("timeout", 10) # type: ignore + self._timeout = kwargs.get("timeout", 10) def _make_request(self, endpoint: str, data: dict[str, Any]) -> dict[str, Any]: """Make a request to the TLM API.""" From a5d655b36f339856a81daa9da15414324e9c62e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Sat, 8 Feb 2025 01:16:51 +0000 Subject: [PATCH 27/39] Add explanation to is_unhelpful_response question --- src/cleanlab_codex/validation.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/cleanlab_codex/validation.py b/src/cleanlab_codex/validation.py index 57d3c21..1781812 100644 --- a/src/cleanlab_codex/validation.py +++ b/src/cleanlab_codex/validation.py @@ -196,10 +196,11 @@ def is_unhelpful_response( # - When asking "is helpful?" -> "no" means unhelpful # - When asking "is unhelpful?" -> "yes" means unhelpful question = ( - "Is the AI Assistant Response unhelpful? 
" - "Unhelpful responses include answers that:\n" + "Does the AI Assistant Response seem unhelpful? " + "Things that are not helpful include answers that:\n" "- Are not useful, incomplete, incorrect, uncertain or unclear.\n" "- Abstain or refuse to answer the question\n" + "- Statements which are similar to 'I don't know', 'Sorry', or 'No information available'.\n" "- Leave the original question unresolved\n" "- Are irrelevant to the question\n" "Answer Yes/No only." From e776dfed1566f80fad5f25dddcf1554d1bb395de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Sat, 8 Feb 2025 01:21:56 +0000 Subject: [PATCH 28/39] Remove quotes from type annotation --- src/cleanlab_codex/codex_backup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cleanlab_codex/codex_backup.py b/src/cleanlab_codex/codex_backup.py index 194dfab..7e4de5e 100644 --- a/src/cleanlab_codex/codex_backup.py +++ b/src/cleanlab_codex/codex_backup.py @@ -130,7 +130,7 @@ def __init__( self._is_bad_response_kwargs = is_bad_response_kwargs @classmethod - def from_project(cls, project: Project, **kwargs: Any) -> "CodexBackup": + def from_project(cls, project: Project, **kwargs: Any) -> CodexBackup: return cls(project=project, **kwargs) @property From 7866f0ce7b8673d79a0e2336b66a118cef6a49e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Mon, 10 Feb 2025 17:53:03 +0000 Subject: [PATCH 29/39] remove _TLM protocol MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Rely on the Studio client for now, which the user can optionally pass in. Without TLM, CodexBackup only runs the `is_fallback_response`  check. --- src/cleanlab_codex/codex_backup.py | 69 +++--------------------------- 1 file changed, 6 insertions(+), 63 deletions(-) diff --git a/src/cleanlab_codex/codex_backup.py b/src/cleanlab_codex/codex_backup.py index 7e4de5e..9d97ece 100644 --- a/src/cleanlab_codex/codex_backup.py +++ b/src/cleanlab_codex/codex_backup.py @@ -1,14 +1,12 @@ from __future__ import annotations -import os -from typing import TYPE_CHECKING, Any, Optional, Protocol, Sequence, Union, cast - -import requests +from typing import TYPE_CHECKING, Any, Optional, Protocol from cleanlab_codex.validation import is_bad_response if TYPE_CHECKING: from cleanlab_codex.project import Project + from cleanlab_studio.studio.trustworthy_language_model import TLM def handle_backup_default(codex_response: str, primary_system: Any) -> None: # noqa: ARG001 @@ -16,63 +14,6 @@ def handle_backup_default(codex_response: str, primary_system: Any) -> None: # return None -class _TLM(Protocol): - def get_trustworthiness_score( - self, - query: Union[str, Sequence[str]], - response: Union[str, Sequence[str]], - **kwargs: Any, - ) -> dict[str, Any]: ... - - def prompt( - self, - prompt: Union[str, Sequence[str]], - /, - **kwargs: Any, - ) -> dict[str, Any]: ... - - -class _TemporaryTLM(_TLM): - def __init__( - self, - api_key: Optional[str] = None, - api_base_url: Optional[str] = None, - **kwargs: Any, - ): - self.api_base_url = api_base_url.rstrip("/") if api_base_url else os.getenv("CODEX_API_BASE_URL") - if self.api_base_url is None: - error_message = "Please set the CODEX_API_BASE_URL environment variable or pass api_base_url to the _TemporaryTLM constructor." 
- raise ValueError(error_message) - self._headers = { - "X-API-Key": api_key or os.getenv("CODEX_API_KEY"), - "Content-Type": "application/json", - } - self._timeout = kwargs.get("timeout", 10) - - def _make_request(self, endpoint: str, data: dict[str, Any]) -> dict[str, Any]: - """Make a request to the TLM API.""" - url = f"{self.api_base_url}/api/tlm/{endpoint}" - response = requests.post( - url, - json=data, - headers=self._headers, - timeout=self._timeout, - ) - response.raise_for_status() - return cast(dict[str, Any], response.json()) - - def get_trustworthiness_score( - self, query: Union[str, Sequence[str]], response: Union[str, Sequence[str]], **kwargs: Any - ) -> dict[str, Any]: - """Get trustworthiness score for a query-response pair.""" - data = {"prompt": query, "response": response, **kwargs} - return self._make_request("score", data) - - def prompt(self, prompt: Union[str, Sequence[str]], /, **kwargs: Any) -> dict[str, Any]: - """Send a prompt to the TLM API.""" - data = {"prompt": prompt, **kwargs} - return self._make_request("prompt", data) - class BackupHandler(Protocol): """Protocol defining how to handle backup responses from Codex. @@ -105,9 +46,10 @@ class CodexBackup: Args: project: The Codex project to use for backup responses - fallback_answer: The fallback answer to use if the primary system fails + fallback_answer: The fallback answer to use if the primary system fails to provide an adequate response backup_handler: A callback function that processes Codex's response and updates the primary RAG system. This handler is called whenever Codex provides a backup response after the primary system fails. By default, the backup handler is a no-op. primary_system: The existing RAG system that needs to be backed up by Codex + tlm: The client for the Trustworthy Language Model, which evaluates the quality of responses from the primary system is_bad_response_kwargs: Additional keyword arguments to pass to the is_bad_response function, for detecting inadequate responses from the primary system """ @@ -120,13 +62,14 @@ def __init__( fallback_answer: str = DEFAULT_FALLBACK_ANSWER, backup_handler: BackupHandler = handle_backup_default, primary_system: Optional[Any] = None, + tlm: Optional[TLM] = None, is_bad_response_kwargs: Optional[dict[str, Any]] = None, ): self._project = project self._fallback_answer = fallback_answer self._backup_handler = backup_handler - self._tlm = _TemporaryTLM() # TODO: Improve integration self._primary_system: Optional[Any] = primary_system + self._tlm = tlm self._is_bad_response_kwargs = is_bad_response_kwargs @classmethod From 26adbf1a04607da26844c7a49bd7eac284f6de8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Mon, 10 Feb 2025 17:53:27 +0000 Subject: [PATCH 30/39] formatting --- src/cleanlab_codex/codex_backup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cleanlab_codex/codex_backup.py b/src/cleanlab_codex/codex_backup.py index 9d97ece..e32182f 100644 --- a/src/cleanlab_codex/codex_backup.py +++ b/src/cleanlab_codex/codex_backup.py @@ -5,16 +5,16 @@ from cleanlab_codex.validation import is_bad_response if TYPE_CHECKING: - from cleanlab_codex.project import Project from cleanlab_studio.studio.trustworthy_language_model import TLM + from cleanlab_codex.project import Project + def handle_backup_default(codex_response: str, primary_system: Any) -> None: # noqa: ARG001 """Default implementation is a no-op.""" return None - class BackupHandler(Protocol): """Protocol defining how to handle backup 
responses from Codex. From 36f80e98a53a5d4bca0a39f91df74c72c0a611b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Tue, 11 Feb 2025 00:31:18 +0000 Subject: [PATCH 31/39] threshold -> trustworthiness_threshold --- src/cleanlab_codex/validation.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/cleanlab_codex/validation.py b/src/cleanlab_codex/validation.py index 1781812..d03f39f 100644 --- a/src/cleanlab_codex/validation.py +++ b/src/cleanlab_codex/validation.py @@ -74,7 +74,7 @@ def is_bad_response( context=cast(str, context), query=cast(str, query), tlm=tlm, - threshold=trustworthiness_threshold, + trustworthiness_threshold=trustworthiness_threshold, format_prompt=format_prompt, ) ) @@ -125,7 +125,7 @@ def is_untrustworthy_response( context: str, query: str, tlm: TLM, - threshold: float = DEFAULT_TRUSTWORTHINESS_THRESHOLD, + trustworthiness_threshold: float = DEFAULT_TRUSTWORTHINESS_THRESHOLD, format_prompt: Callable[[str, str], str] = default_format_prompt, ) -> bool: """Check if a response is untrustworthy. @@ -139,7 +139,7 @@ def is_untrustworthy_response( context: The context information available for answering the query query: The user's question or request tlm: The TLM model to use for evaluation - threshold: Score threshold (0.0-1.0). Lower values allow less trustworthy responses. + trustworthiness_threshold: Score threshold (0.0-1.0). Lower values allow less trustworthy responses. Default 0.5, meaning responses with scores less than 0.5 are considered untrustworthy. format_prompt: Function that takes (query, context) and returns a formatted prompt string. Users should provide their RAG app's own prompt formatting function here @@ -157,7 +157,7 @@ def is_untrustworthy_response( prompt = format_prompt(query, context) result = tlm.get_trustworthiness_score(prompt, response) score: float = result["trustworthiness_score"] - return score < threshold + return score < trustworthiness_threshold def is_unhelpful_response( From febbfd04c8132b7b7f6f358857dddb793eee7dad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Tue, 11 Feb 2025 00:43:56 +0000 Subject: [PATCH 32/39] update is_bad_response docstring --- src/cleanlab_codex/validation.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/src/cleanlab_codex/validation.py b/src/cleanlab_codex/validation.py index d03f39f..95f14a0 100644 --- a/src/cleanlab_codex/validation.py +++ b/src/cleanlab_codex/validation.py @@ -32,12 +32,18 @@ def is_bad_response( # is_unhelpful_response args unhelpful_trustworthiness_threshold: Optional[float] = None, ) -> bool: - """Run a series of checks to determine if a response is bad. If any of the checks pass, return True. + """Run a series of checks to determine if a response is bad. + + If any check detects an issue (i.e. fails), the function returns True, indicating the response is bad. This function runs three possible validation checks: - 1. Fallback check: Detects if response is too similar to known fallback answers. - 2. Untrustworthy check: Evaluates response trustworthiness given context and query. - 3. Unhelpful check: Evaluates if response is helpful for the given query. + 1. **Fallback check**: Detects if response is too similar to a known fallback answer. + 2. **Untrustworthy check**: Assesses response trustworthiness based on the given context and query. + 3. **Unhelpful check**: Predicts if the response adequately answers the query or not, in a useful way. 
+ + Note: + Each validation check runs conditionally based on whether the required arguments are provided. + As soon as any validation check fails, the function returns True. Args: response: The response to check. @@ -45,15 +51,15 @@ def is_bad_response( query: Optional user question. Required for untrustworthy and unhelpful checks. tlm: Optional TLM model for evaluation. Required for untrustworthy and unhelpful checks. - # Fallback check parameters + --- Fallback check parameters --- fallback_answer: Known unhelpful response to compare against. partial_ratio_threshold: Similarity threshold (0-100). Higher values require more similarity. - # Untrustworthy check parameters + --- Untrustworthy check parameters --- trustworthiness_threshold: Score threshold (0.0-1.0). Lower values allow less trustworthy responses. format_prompt: Function to format (query, context) into a prompt string. - # Unhelpful check parameters + --- Unhelpful check parameters --- unhelpful_trustworthiness_threshold: Optional confidence threshold (0.0-1.0) for unhelpful classification. Returns: From dc1d003d9d013296e7389bb497d717488c726f77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Tue, 11 Feb 2025 00:49:36 +0000 Subject: [PATCH 33/39] update docstrings for is_unhelpful_response --- src/cleanlab_codex/validation.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/cleanlab_codex/validation.py b/src/cleanlab_codex/validation.py index 95f14a0..ab78add 100644 --- a/src/cleanlab_codex/validation.py +++ b/src/cleanlab_codex/validation.py @@ -180,13 +180,13 @@ def is_unhelpful_response( is sufficiently confident in that assessment (if a threshold is provided). Args: - response: The response to check from the assistant + response: The response to check + query: User query that will be used to evaluate if the response is helpful tlm: The TLM model to use for evaluation - query: Optional user query to provide context for evaluating helpfulness. - If provided, TLM will assess if the response helpfully answers this query. - trustworthiness_score_threshold: Optional confidence threshold (0.0-1.0). - If provided, responses are only marked unhelpful if TLM's - confidence score exceeds this threshold. + trustworthiness_score_threshold: Optional confidence threshold (0.0-1.0) + If provided and the response is marked as unhelpful, + the confidence score must exceed this threshold for + the response to be considered truly unhelpful. Returns: bool: True if TLM determines the response is unhelpful with sufficient confidence, From 739ffc6b4063752d94e9ae87f1b212e8b1607eb1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Tue, 11 Feb 2025 00:51:25 +0000 Subject: [PATCH 34/39] unhelpful_trustworthiness_threshold -> unhelpfulness_confidence_threshold in is_bad_response --- src/cleanlab_codex/validation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/cleanlab_codex/validation.py b/src/cleanlab_codex/validation.py index ab78add..23850d9 100644 --- a/src/cleanlab_codex/validation.py +++ b/src/cleanlab_codex/validation.py @@ -30,7 +30,7 @@ def is_bad_response( trustworthiness_threshold: float = DEFAULT_TRUSTWORTHINESS_THRESHOLD, format_prompt: Callable[[str, str], str] = default_format_prompt, # is_unhelpful_response args - unhelpful_trustworthiness_threshold: Optional[float] = None, + unhelpfulness_confidence_threshold: Optional[float] = None, ) -> bool: """Run a series of checks to determine if a response is bad. 
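For illustration, a call site under the keyword renamed in this patch might read as the sketch below. The stub client and the query/context strings are hypothetical stand-ins (any object exposing these two methods works), and the fallback check additionally needs `thefuzz` installed:

from cleanlab_codex.validation import is_bad_response


class StubTLM:
    """Hypothetical TLM stand-in: always low trust, always judged unhelpful."""

    def prompt(self, prompt, /, **kwargs):
        # "Yes" means the response is unhelpful; 0.9 is the confidence score.
        return {"response": "Yes", "trustworthiness_score": 0.9}

    def get_trustworthiness_score(self, prompt, response, **kwargs):
        # Always report a low trustworthiness score.
        return {"trustworthiness_score": 0.3}


flagged = is_bad_response(
    "Based on the available information, I cannot provide a complete answer.",
    query="What is the capital of France?",
    context="France is a country in Europe. Its capital is Paris.",
    tlm=StubTLM(),
    unhelpfulness_confidence_threshold=0.5,  # renamed from unhelpful_trustworthiness_threshold
)
assert flagged  # every check trips: fallback phrasing, low trust, unhelpful verdict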
@@ -60,7 +60,7 @@ def is_bad_response( format_prompt: Function to format (query, context) into a prompt string. --- Unhelpful check parameters --- - unhelpful_trustworthiness_threshold: Optional confidence threshold (0.0-1.0) for unhelpful classification. + unhelpfulness_confidence_threshold: Optional confidence threshold (0.0-1.0) for unhelpful classification. Returns: bool: True if any validation check fails, False if all pass. @@ -92,7 +92,7 @@ def is_bad_response( response=response, query=cast(str, query), tlm=tlm, - trustworthiness_score_threshold=unhelpful_trustworthiness_threshold, + trustworthiness_score_threshold=unhelpfulness_confidence_threshold, ) ) From 3e4864aa3026e643ef9433449818f9035290cba3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Tue, 11 Feb 2025 00:55:13 +0000 Subject: [PATCH 35/39] update module docstring for validation.py --- src/cleanlab_codex/validation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cleanlab_codex/validation.py b/src/cleanlab_codex/validation.py index 23850d9..1011294 100644 --- a/src/cleanlab_codex/validation.py +++ b/src/cleanlab_codex/validation.py @@ -1,5 +1,5 @@ """ -This module provides validation functions for checking if an LLM response is unhelpful. +This module provides validation functions for evaluating LLM responses and determining if they should be replaced with Codex-generated alternatives. """ from __future__ import annotations From 81cc934688f4b944d26fa22d2b3ee3c417eddc2c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Tue, 11 Feb 2025 01:00:13 +0000 Subject: [PATCH 36/39] rename module validation.py -> response_validation.py --- src/cleanlab_codex/codex_backup.py | 2 +- src/cleanlab_codex/{validation.py => response_validation.py} | 0 tests/test_validation.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) rename src/cleanlab_codex/{validation.py => response_validation.py} (100%) diff --git a/src/cleanlab_codex/codex_backup.py b/src/cleanlab_codex/codex_backup.py index e32182f..7bf8377 100644 --- a/src/cleanlab_codex/codex_backup.py +++ b/src/cleanlab_codex/codex_backup.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any, Optional, Protocol -from cleanlab_codex.validation import is_bad_response +from cleanlab_codex.response_validation import is_bad_response if TYPE_CHECKING: from cleanlab_studio.studio.trustworthy_language_model import TLM diff --git a/src/cleanlab_codex/validation.py b/src/cleanlab_codex/response_validation.py similarity index 100% rename from src/cleanlab_codex/validation.py rename to src/cleanlab_codex/response_validation.py diff --git a/tests/test_validation.py b/tests/test_validation.py index d1bcf70..facda62 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -6,7 +6,7 @@ import pytest -from cleanlab_codex.validation import ( +from cleanlab_codex.response_validation import ( is_bad_response, is_fallback_response, is_unhelpful_response, From 9e91e9b491b588d9f80703a7fe00d39776d6ba0f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Tue, 11 Feb 2025 03:55:22 +0000 Subject: [PATCH 37/39] move is_bad_response optional parameters to a parameter object (typed dictionary) --- src/cleanlab_codex/codex_backup.py | 14 ++- src/cleanlab_codex/response_validation.py | 138 ++++++++++++++++------ tests/test_validation.py | 8 +- 3 files changed, 112 insertions(+), 48 deletions(-) diff --git a/src/cleanlab_codex/codex_backup.py b/src/cleanlab_codex/codex_backup.py index 7bf8377..d0f95d5 100644 
--- a/src/cleanlab_codex/codex_backup.py +++ b/src/cleanlab_codex/codex_backup.py @@ -1,11 +1,11 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Optional, Protocol +from typing import TYPE_CHECKING, Any, Optional, Protocol, cast -from cleanlab_codex.response_validation import is_bad_response +from cleanlab_codex.response_validation import BadResponseDetectionConfig, is_bad_response if TYPE_CHECKING: - from cleanlab_studio.studio.trustworthy_language_model import TLM + from cleanlab_studio.studio.trustworthy_language_model import TLM # type: ignore from cleanlab_codex.project import Project @@ -119,9 +119,11 @@ def run( response, query=query, context=context, - tlm=self._tlm, - fallback_answer=self._fallback_answer, - **_is_bad_response_kwargs, + config=cast(BadResponseDetectionConfig, { + "tlm": self._tlm, + "fallback_answer": self._fallback_answer, + **_is_bad_response_kwargs, + }), ): return response diff --git a/src/cleanlab_codex/response_validation.py b/src/cleanlab_codex/response_validation.py index 1011294..c1c2eb9 100644 --- a/src/cleanlab_codex/response_validation.py +++ b/src/cleanlab_codex/response_validation.py @@ -4,12 +4,36 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Optional, cast +from typing import TYPE_CHECKING, Callable, Optional, Sequence, TypedDict, Union, cast +from cleanlab_codex.utils.errors import MissingDependencyError from cleanlab_codex.utils.prompt import default_format_prompt if TYPE_CHECKING: - from cleanlab_studio.studio.trustworthy_language_model import TLM # type: ignore + try: + from cleanlab_studio.studio.trustworthy_language_model import TLM # type: ignore + except ImportError: + from typing import Any, Dict, Protocol, Sequence + + class _TLMProtocol(Protocol): + def get_trustworthiness_score( + self, + prompt: Union[str, Sequence[str]], + response: Union[str, Sequence[str]], + **kwargs: Any, + ) -> Dict[str, Any]: + ... + + def prompt( + self, + prompt: Union[str, Sequence[str]], + /, + **kwargs: Any, + ) -> Dict[str, Any]: + ... + + TLM = _TLMProtocol + DEFAULT_FALLBACK_ANSWER = "Based on the available information, I cannot provide a complete answer to this question." @@ -17,20 +41,54 @@ DEFAULT_TRUSTWORTHINESS_THRESHOLD = 0.5 +class BadResponseDetectionConfig(TypedDict, total=False): + """Configuration for bad response detection functions. + See get_bad_response_config() for default values. + + Attributes: + fallback_answer: Known unhelpful response to compare against + partial_ratio_threshold: Similarity threshold (0-100). Higher values require more similarity + trustworthiness_threshold: Score threshold (0.0-1.0). Lower values allow less trustworthy responses + format_prompt: Function to format (query, context) into a prompt string + unhelpfulness_confidence_threshold: Optional confidence threshold (0.0-1.0) for unhelpful classification + tlm: TLM model to use for evaluation (required for untrustworthiness and unhelpfulness checks) + """ + # Fallback check config + fallback_answer: str + partial_ratio_threshold: int + + # Untrustworthy check config + trustworthiness_threshold: float + format_prompt: Callable[[str, str], str] + + # Unhelpful check config + unhelpfulness_confidence_threshold: Optional[float] + + # Shared config (for untrustworthiness and unhelpfulness checks) + tlm: Optional[TLM] + +def get_bad_response_config() -> BadResponseDetectionConfig: + """Get the default configuration for bad response detection functions. 
+ + Returns: + BadResponseDetectionConfig: Default configuration for bad response detection functions + """ + return { + "fallback_answer": DEFAULT_FALLBACK_ANSWER, + "partial_ratio_threshold": DEFAULT_PARTIAL_RATIO_THRESHOLD, + "trustworthiness_threshold": DEFAULT_TRUSTWORTHINESS_THRESHOLD, + "format_prompt": default_format_prompt, + "unhelpfulness_confidence_threshold": None, + "tlm": None, + } + + def is_bad_response( response: str, *, context: Optional[str] = None, query: Optional[str] = None, - tlm: Optional[TLM] = None, - # is_fallback_response args - fallback_answer: str = DEFAULT_FALLBACK_ANSWER, - partial_ratio_threshold: int = DEFAULT_PARTIAL_RATIO_THRESHOLD, - # is_untrustworthy_response args - trustworthiness_threshold: float = DEFAULT_TRUSTWORTHINESS_THRESHOLD, - format_prompt: Callable[[str, str], str] = default_format_prompt, - # is_unhelpful_response args - unhelpfulness_confidence_threshold: Optional[float] = None, + config: Optional[BadResponseDetectionConfig] = None, ) -> bool: """Run a series of checks to determine if a response is bad. @@ -49,29 +107,25 @@ def is_bad_response( response: The response to check. context: Optional context/documents used for answering. Required for untrustworthy check. query: Optional user question. Required for untrustworthy and unhelpful checks. - tlm: Optional TLM model for evaluation. Required for untrustworthy and unhelpful checks. - - --- Fallback check parameters --- - fallback_answer: Known unhelpful response to compare against. - partial_ratio_threshold: Similarity threshold (0-100). Higher values require more similarity. - - --- Untrustworthy check parameters --- - trustworthiness_threshold: Score threshold (0.0-1.0). Lower values allow less trustworthy responses. - format_prompt: Function to format (query, context) into a prompt string. - - --- Unhelpful check parameters --- - unhelpfulness_confidence_threshold: Optional confidence threshold (0.0-1.0) for unhelpful classification. + config: Optional typed dictionary of configuration parameters. See BadResponseDetectionConfig for details. Returns: bool: True if any validation check fails, False if all pass. 
""" + default_cfg = get_bad_response_config() + cfg: BadResponseDetectionConfig + cfg = {**default_cfg, **(config or {})} validation_checks: list[Callable[[], bool]] = [] # All required inputs are available for checking fallback responses - validation_checks.append(lambda: is_fallback_response(response, fallback_answer, threshold=partial_ratio_threshold)) + validation_checks.append(lambda: is_fallback_response( + response, + cfg["fallback_answer"], + threshold=cfg["partial_ratio_threshold"], + )) - can_run_untrustworthy_check = query is not None and context is not None and tlm is not None + can_run_untrustworthy_check = query is not None and context is not None and cfg["tlm"] is not None if can_run_untrustworthy_check: # The if condition guarantees these are not None validation_checks.append( @@ -79,20 +133,20 @@ def is_bad_response( response=response, context=cast(str, context), query=cast(str, query), - tlm=tlm, - trustworthiness_threshold=trustworthiness_threshold, - format_prompt=format_prompt, + tlm=cfg["tlm"], + trustworthiness_threshold=cfg["trustworthiness_threshold"], + format_prompt=cfg["format_prompt"], ) ) - can_run_unhelpful_check = query is not None and tlm is not None + can_run_unhelpful_check = query is not None and cfg["tlm"] is not None if can_run_unhelpful_check: validation_checks.append( lambda: is_unhelpful_response( response=response, query=cast(str, query), - tlm=tlm, - trustworthiness_score_threshold=unhelpfulness_confidence_threshold, + tlm=cfg["tlm"], + trustworthiness_score_threshold=cast(float, cfg["unhelpfulness_confidence_threshold"]), ) ) @@ -119,8 +173,10 @@ def is_fallback_response( try: from thefuzz import fuzz # type: ignore except ImportError as e: - error_msg = "The 'thefuzz' library is required. Please install it with `pip install thefuzz`." - raise ImportError(error_msg) from e + raise MissingDependencyError( + import_name=e.name or "thefuzz", + package_url="https://github.com/seatgeek/thefuzz", + ) from e partial_ratio: int = fuzz.partial_ratio(fallback_answer.lower(), response.lower()) return bool(partial_ratio >= threshold) @@ -155,10 +211,13 @@ def is_untrustworthy_response( bool: True if the response is deemed untrustworthy by TLM, False otherwise """ try: - from cleanlab_studio import Studio # type: ignore # noqa: F401 + from cleanlab_studio import Studio # type: ignore[import-untyped] # noqa: F401 except ImportError as e: - error_msg = "The 'cleanlab_studio' library is required. Please install it with `pip install cleanlab-studio`." - raise ImportError(error_msg) from e + raise MissingDependencyError( + import_name=e.name or "cleanlab_studio", + package_name="cleanlab-studio", + package_url="https://github.com/cleanlab/cleanlab-studio", + ) from e prompt = format_prompt(query, context) result = tlm.get_trustworthiness_score(prompt, response) @@ -195,8 +254,11 @@ def is_unhelpful_response( try: from cleanlab_studio import Studio # noqa: F401 except ImportError as e: - error_msg = "The 'cleanlab_studio' library is required. Please install it with `pip install cleanlab-studio`." - raise ImportError(error_msg) from e + raise MissingDependencyError( + import_name=e.name or "cleanlab_studio", + package_name="cleanlab-studio", + package_url="https://github.com/cleanlab/cleanlab-studio", + ) from e # The question and expected "unhelpful" response are linked: # - When asking "is helpful?" 
-> "no" means unhelpful diff --git a/tests/test_validation.py b/tests/test_validation.py index facda62..d10e661 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -65,11 +65,11 @@ def test_is_untrustworthy_response(mock_tlm: Mock) -> None: """Test untrustworthy response detection.""" # Test trustworthy response mock_tlm.get_trustworthiness_score.return_value = {"trustworthiness_score": 0.8} - assert is_untrustworthy_response(GOOD_RESPONSE, CONTEXT, QUERY, mock_tlm, threshold=0.5) is False + assert is_untrustworthy_response(GOOD_RESPONSE, CONTEXT, QUERY, mock_tlm, trustworthiness_threshold=0.5) is False # Test untrustworthy response mock_tlm.get_trustworthiness_score.return_value = {"trustworthiness_score": 0.3} - assert is_untrustworthy_response(BAD_RESPONSE, CONTEXT, QUERY, mock_tlm, threshold=0.5) is True + assert is_untrustworthy_response(BAD_RESPONSE, CONTEXT, QUERY, mock_tlm, trustworthiness_threshold=0.5) is True @pytest.mark.parametrize( @@ -130,7 +130,7 @@ def test_is_bad_response( response, context=CONTEXT, query=QUERY, - tlm=mock_tlm, + config={"tlm": mock_tlm}, ) is expected ) @@ -168,7 +168,7 @@ def test_is_bad_response_partial_inputs( is_bad_response( response, query=query, - tlm=tlm, + config={"tlm": tlm}, ) is expected ) From c5843c945a8c51ab76da692c7c0b64be9b2df59d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Tue, 11 Feb 2025 05:01:03 +0000 Subject: [PATCH 38/39] formatting --- src/cleanlab_codex/codex_backup.py | 13 ++++++++----- src/cleanlab_codex/response_validation.py | 21 +++++++++++---------- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/src/cleanlab_codex/codex_backup.py b/src/cleanlab_codex/codex_backup.py index d0f95d5..70cc1f0 100644 --- a/src/cleanlab_codex/codex_backup.py +++ b/src/cleanlab_codex/codex_backup.py @@ -119,11 +119,14 @@ def run( response, query=query, context=context, - config=cast(BadResponseDetectionConfig, { - "tlm": self._tlm, - "fallback_answer": self._fallback_answer, - **_is_bad_response_kwargs, - }), + config=cast( + BadResponseDetectionConfig, + { + "tlm": self._tlm, + "fallback_answer": self._fallback_answer, + **_is_bad_response_kwargs, + }, + ), ): return response diff --git a/src/cleanlab_codex/response_validation.py b/src/cleanlab_codex/response_validation.py index c1c2eb9..f239397 100644 --- a/src/cleanlab_codex/response_validation.py +++ b/src/cleanlab_codex/response_validation.py @@ -21,21 +21,18 @@ def get_trustworthiness_score( prompt: Union[str, Sequence[str]], response: Union[str, Sequence[str]], **kwargs: Any, - ) -> Dict[str, Any]: - ... + ) -> Dict[str, Any]: ... def prompt( self, prompt: Union[str, Sequence[str]], /, **kwargs: Any, - ) -> Dict[str, Any]: - ... + ) -> Dict[str, Any]: ... TLM = _TLMProtocol - DEFAULT_FALLBACK_ANSWER = "Based on the available information, I cannot provide a complete answer to this question." 
DEFAULT_PARTIAL_RATIO_THRESHOLD = 70 DEFAULT_TRUSTWORTHINESS_THRESHOLD = 0.5 @@ -53,6 +50,7 @@ class BadResponseDetectionConfig(TypedDict, total=False): unhelpfulness_confidence_threshold: Optional confidence threshold (0.0-1.0) for unhelpful classification tlm: TLM model to use for evaluation (required for untrustworthiness and unhelpfulness checks) """ + # Fallback check config fallback_answer: str partial_ratio_threshold: int @@ -67,6 +65,7 @@ class BadResponseDetectionConfig(TypedDict, total=False): # Shared config (for untrustworthiness and unhelpfulness checks) tlm: Optional[TLM] + def get_bad_response_config() -> BadResponseDetectionConfig: """Get the default configuration for bad response detection functions. @@ -119,11 +118,13 @@ def is_bad_response( validation_checks: list[Callable[[], bool]] = [] # All required inputs are available for checking fallback responses - validation_checks.append(lambda: is_fallback_response( - response, - cfg["fallback_answer"], - threshold=cfg["partial_ratio_threshold"], - )) + validation_checks.append( + lambda: is_fallback_response( + response, + cfg["fallback_answer"], + threshold=cfg["partial_ratio_threshold"], + ) + ) can_run_untrustworthy_check = query is not None and context is not None and cfg["tlm"] is not None if can_run_untrustworthy_check: From 49f9a9dd85f8fe06527db396e0affd68cafe1835 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Tue, 11 Feb 2025 05:16:39 +0000 Subject: [PATCH 39/39] rename test_validation.py -> test_response_validation.py --- tests/{test_validation.py => test_response_validation.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{test_validation.py => test_response_validation.py} (100%) diff --git a/tests/test_validation.py b/tests/test_response_validation.py similarity index 100% rename from tests/test_validation.py rename to tests/test_response_validation.py
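Taken together, the series leaves response validation driven by a single typed configuration object. A hedged end-to-end sketch of the resulting API — the query, context, and override values below are placeholders, and the fallback check relies on the optional `thefuzz` dependency:

from cleanlab_codex.response_validation import (
    BadResponseDetectionConfig,
    get_bad_response_config,
    is_bad_response,
)

# Start from the documented defaults and override only what is needed.
config: BadResponseDetectionConfig = get_bad_response_config()
config["partial_ratio_threshold"] = 80  # demand closer matches to the fallback answer
config["fallback_answer"] = "I cannot answer this question."

# With no "tlm" entry supplied, only the fallback-similarity check runs;
# the untrustworthy and unhelpful checks are skipped.
if is_bad_response(
    "Sorry, I cannot answer this question.",
    query="Who wrote the report?",
    context="The report was written by the data team.",
    config=config,
):
    print("Routing the query to Codex as a backup...")

Supplying a cleanlab-studio TLM client under the config's "tlm" key would enable the remaining two checks, mirroring how CodexBackup.run assembles its config from self._tlm and self._fallback_answer.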