From e4ab579684318ac9c68759a6df3dcc32b36e6872 Mon Sep 17 00:00:00 2001 From: Ellis Low Date: Mon, 6 Apr 2026 12:36:20 -0400 Subject: [PATCH 1/6] Add --metrics CLI flag to filter which metrics run Allows running a subset of configured metrics without editing YAML configs. Example: --metrics custom:answer_correctness to skip RAGAS metrics. --- src/lightspeed_evaluation/core/system/validator.py | 12 ++++++++++++ src/lightspeed_evaluation/runner/evaluation.py | 7 +++++++ 2 files changed, 19 insertions(+) diff --git a/src/lightspeed_evaluation/core/system/validator.py b/src/lightspeed_evaluation/core/system/validator.py index fba4b882..98c691ee 100644 --- a/src/lightspeed_evaluation/core/system/validator.py +++ b/src/lightspeed_evaluation/core/system/validator.py @@ -171,6 +171,7 @@ def load_evaluation_data( data_path: str, tags: Optional[list[str]] = None, conv_ids: Optional[list[str]] = None, + metrics: Optional[list[str]] = None, ) -> list[EvaluationData]: """Load, filter, and validate evaluation data from YAML file. @@ -184,6 +185,7 @@ def load_evaluation_data( data_path: Path to the evaluation data YAML file tags: Optional list of tags to filter by conv_ids: Optional list of conversation group IDs to filter by + metrics: Optional list of metrics to run (filters each turn's turn_metrics) Returns: Filtered and validated list of Evaluation Data @@ -230,6 +232,16 @@ def load_evaluation_data( # Filter by scope before validation evaluation_data = self._filter_by_scope(evaluation_data, tags, conv_ids) + # Filter turn_metrics if --metrics was specified + if metrics: + metrics_set = set(metrics) + for eval_data in evaluation_data: + for turn in eval_data.turns: + if turn.turn_metrics: + turn.turn_metrics = [ + m for m in turn.turn_metrics if m in metrics_set + ] + # Semantic validation (metrics availability and requirements) if not self._validate_evaluation_data(evaluation_data): raise DataValidationError("Evaluation data validation failed") diff --git a/src/lightspeed_evaluation/runner/evaluation.py b/src/lightspeed_evaluation/runner/evaluation.py index 39a1f34f..6e5f440c 100644 --- a/src/lightspeed_evaluation/runner/evaluation.py +++ b/src/lightspeed_evaluation/runner/evaluation.py @@ -132,6 +132,7 @@ def run_evaluation( # pylint: disable=too-many-locals eval_args.eval_data, tags=eval_args.tags, conv_ids=eval_args.conv_ids, + metrics=eval_args.metrics, ) print( @@ -236,6 +237,12 @@ def main() -> int: default=None, help="Filter by conversation group IDs (run only specified conversations)", ) + parser.add_argument( + "--metrics", + nargs="+", + default=None, + help="Filter to only run specified metrics (e.g. custom:answer_correctness)", + ) parser.add_argument( "--cache-warmup", action="store_true", From b563026db3c7bf5eb0535f3a2dc2d10e09ff50fe Mon Sep 17 00:00:00 2001 From: Ellis Low Date: Thu, 9 Apr 2026 17:17:55 -0400 Subject: [PATCH 2/6] Add skip field to disable conversations with a reason MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds skip and skip_reason fields to EvaluationData. Conversations with skip: true are silently excluded during loading. skip_reason is documentation-only — it stays in the YAML for humans to read. 
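For example, a conversation can be disabled directly in the evaluation data
YAML (illustrative entry; the values are made up, and all keys other than
skip and skip_reason follow the existing evaluation data schema used in the
test fixtures):

    # This conversation is excluded from the run; the reason stays in the
    # file so reviewers can see why it was disabled.
    - conversation_group_id: conv_needs_rework
      skip: true
      skip_reason: "Test needs rewrite"
      turns:
        - turn_id: t1
          query: "Q"
          response: "A"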
--- src/lightspeed_evaluation/core/models/data.py | 8 ++++ .../core/system/validator.py | 3 ++ tests/unit/core/system/test_validator.py | 41 +++++++++++++++++++ 3 files changed, 52 insertions(+) diff --git a/src/lightspeed_evaluation/core/models/data.py b/src/lightspeed_evaluation/core/models/data.py index 4408f440..f3f064de 100644 --- a/src/lightspeed_evaluation/core/models/data.py +++ b/src/lightspeed_evaluation/core/models/data.py @@ -370,6 +370,14 @@ class EvaluationData(BaseModel): min_length=1, description="Tag for grouping and filtering conversations", ) + skip: bool = Field( + default=False, + description="Skip this conversation during evaluation", + ) + skip_reason: Optional[str] = Field( + default=None, + description="Why this conversation is skipped (documentation only)", + ) # Conversation-level metrics conversation_metrics: Optional[list[str]] = Field( diff --git a/src/lightspeed_evaluation/core/system/validator.py b/src/lightspeed_evaluation/core/system/validator.py index 98c691ee..3c97ea21 100644 --- a/src/lightspeed_evaluation/core/system/validator.py +++ b/src/lightspeed_evaluation/core/system/validator.py @@ -232,6 +232,9 @@ def load_evaluation_data( # Filter by scope before validation evaluation_data = self._filter_by_scope(evaluation_data, tags, conv_ids) + # Remove skipped conversations + evaluation_data = [e for e in evaluation_data if not e.skip] + # Filter turn_metrics if --metrics was specified if metrics: metrics_set = set(metrics) diff --git a/tests/unit/core/system/test_validator.py b/tests/unit/core/system/test_validator.py index 44bba743..09840369 100644 --- a/tests/unit/core/system/test_validator.py +++ b/tests/unit/core/system/test_validator.py @@ -563,3 +563,44 @@ def test_filter_by_scope_no_match_returns_empty(self) -> None: ] result = validator._filter_by_scope(data, tags=["nonexistent"]) assert len(result) == 0 + + def test_skip_removes_conversation(self, mocker: MockerFixture) -> None: + """Test that conversations with skip=True are excluded.""" + yaml_data = [ + { + "conversation_group_id": "active", + "turns": [{"turn_id": "t1", "query": "Q", "response": "A"}], + }, + { + "conversation_group_id": "skipped", + "skip": True, + "skip_reason": "Test needs rewrite", + "turns": [{"turn_id": "t1", "query": "Q", "response": "A"}], + }, + { + "conversation_group_id": "also_active", + "turns": [{"turn_id": "t1", "query": "Q", "response": "A"}], + }, + ] + mocker.patch("builtins.open", mocker.mock_open(read_data="")) + mocker.patch("yaml.safe_load", return_value=yaml_data) + validator = DataValidator() + result = validator.load_evaluation_data("dummy.yaml") + assert len(result) == 2 + assert {r.conversation_group_id for r in result} == {"active", "also_active"} + + def test_skip_false_keeps_conversation(self, mocker: MockerFixture) -> None: + """Test that skip=False does not exclude the conversation.""" + yaml_data = [ + { + "conversation_group_id": "explicit_no_skip", + "skip": False, + "turns": [{"turn_id": "t1", "query": "Q", "response": "A"}], + }, + ] + mocker.patch("builtins.open", mocker.mock_open(read_data="")) + mocker.patch("yaml.safe_load", return_value=yaml_data) + validator = DataValidator() + result = validator.load_evaluation_data("dummy.yaml") + assert len(result) == 1 + assert result[0].conversation_group_id == "explicit_no_skip" From 4efda3c6116470adc7f01d850d2dda11f611f9d3 Mon Sep 17 00:00:00 2001 From: Ellis Low Date: Wed, 22 Apr 2026 16:14:46 -0400 Subject: [PATCH 3/6] feat: add retry for all server errors and /infer endpoint support Broaden 
retry logic from HTTP 429 only to include 5xx server errors, enabling automatic retry with exponential backoff for transient server failures. Add RLSAPI /v1/infer endpoint support for tool call and RAG chunk metadata retrieval, used by RHEL Lightspeed backend testing. --- src/lightspeed_evaluation/core/api/client.py | 146 ++++++++- src/lightspeed_evaluation/core/constants.py | 2 +- .../core/models/system.py | 2 +- tests/unit/core/api/conftest.py | 13 + tests/unit/core/api/test_client.py | 298 +++++++++++++++++- 5 files changed, 430 insertions(+), 31 deletions(-) diff --git a/src/lightspeed_evaluation/core/api/client.py b/src/lightspeed_evaluation/core/api/client.py index 2c3ceb3d..35528280 100644 --- a/src/lightspeed_evaluation/core/api/client.py +++ b/src/lightspeed_evaluation/core/api/client.py @@ -27,12 +27,19 @@ logger = logging.getLogger(__name__) -def _is_too_many_requests_error(exception: BaseException) -> bool: - """Check if exception is a 429 error.""" - return ( - isinstance(exception, httpx.HTTPStatusError) - and exception.response.status_code == 429 - ) +def _is_retryable_server_error(exception: BaseException) -> bool: + """Check if exception is a retryable HTTP error (429 or 5xx). + + Args: + exception: The exception to check. + + Returns: + True if the exception is a retryable HTTP status error. + """ + if not isinstance(exception, httpx.HTTPStatusError): + return False + status = exception.response.status_code + return status == 429 or 500 <= status < 600 class APIClient: @@ -59,10 +66,13 @@ def __init__( retry_decorator = self._create_retry_decorator() self._standard_query_with_retry = retry_decorator(self._standard_query) self._streaming_query_with_retry = retry_decorator(self._streaming_query) + self._rlsapi_infer_query_with_retry = retry_decorator( + self._rlsapi_infer_query + ) def _create_retry_decorator(self) -> Any: return retry( - retry=retry_if_exception(_is_too_many_requests_error), + retry=retry_if_exception(_is_retryable_server_error), stop=stop_after_attempt( self.config.num_retries + 1 ), # +1 to account for the initial attempt @@ -186,6 +196,8 @@ def query( if self.config.endpoint_type == "streaming": response = self._streaming_query_with_retry(api_request) + elif self.config.endpoint_type == "infer": + response = self._rlsapi_infer_query_with_retry(api_request) else: response = self._standard_query_with_retry(api_request) @@ -196,7 +208,7 @@ def query( except RetryError as e: raise APIError( f"Maximum retry attempts ({self.config.num_retries}) reached " - "due to persistent rate limiting (HTTP 429)." + "due to retryable server errors (HTTP 429/5xx)." 
) from e def _prepare_request( @@ -285,8 +297,7 @@ def _standard_query(self, api_request: APIRequest) -> APIResponse: except httpx.TimeoutException as e: raise self._handle_timeout_error("standard", self.config.timeout) from e except httpx.HTTPStatusError as e: - # Re-raise 429 errors without conversion to allow retry decorator to handle them - if e.response.status_code == 429: + if _is_retryable_server_error(e): raise raise self._handle_http_error(e) from e except ValueError as e: @@ -313,8 +324,7 @@ def _streaming_query(self, api_request: APIRequest) -> APIResponse: except httpx.TimeoutException as e: raise self._handle_timeout_error("streaming", self.config.timeout) from e except httpx.HTTPStatusError as e: - # Re-raise 429 errors without conversion to allow retry decorator to handle them - if e.response.status_code == 429: + if _is_retryable_server_error(e): raise raise self._handle_http_error(e) from e except ValueError as e: @@ -324,6 +334,118 @@ def _streaming_query(self, api_request: APIRequest) -> APIResponse: except Exception as e: raise self._handle_unexpected_error(e, "streaming query") from e + def _rlsapi_infer_query(self, api_request: APIRequest) -> APIResponse: + """Query the RLSAPI /infer endpoint for tool call and RAG metadata. + + The infer endpoint uses a different request/response format than + the standard query/streaming endpoints, converting "query" to + "question" and parsing tool_calls and rag_chunks from tool_results. + + Args: + api_request: The prepared API request. + + Returns: + APIResponse with response text, tool calls, and RAG contexts. + + Raises: + APIError: If the request fails or response is invalid. + """ + if not self.client: + raise APIError("HTTP client not initialized") + try: + request_data = api_request.model_dump(exclude_none=True) + infer_request: dict[str, object] = { + "question": request_data.pop("query"), + "include_metadata": True, + } + + logger.debug( + "RLSAPI infer request URL: /api/lightspeed/%s/infer", + self.config.version, + ) + logger.debug("RLSAPI infer request body: %s", infer_request) + + response = self.client.post( + f"/api/lightspeed/{self.config.version}/infer", + json=infer_request, + ) + response.raise_for_status() + + response_data = response.json() + + if "data" in response_data: + data = response_data["data"] + if "text" in data: + response_data["response"] = data["text"] + if "request_id" in data: + response_data["conversation_id"] = data["request_id"] + if "input_tokens" in data: + response_data["input_tokens"] = data["input_tokens"] + if "output_tokens" in data: + response_data["output_tokens"] = data["output_tokens"] + if "tool_calls" in data: + response_data["tool_calls"] = data["tool_calls"] + if "tool_results" in data: + tool_results = data["tool_results"] + for result in tool_results: + if result.get("type") == "mcp_call": + content = result["content"].split("---") + response_data["rag_chunks"] = [ + {"content": chunk} for chunk in content + ] + + if "response" not in response_data: + raise APIError("API response missing 'response' field") + + if "tool_calls" in response_data and response_data["tool_calls"]: + raw_tool_calls = response_data["tool_calls"] + formatted_tool_calls = [] + + for tool_call in raw_tool_calls: + if isinstance(tool_call, dict): + formatted_tool: dict[str, object] = { + "tool_name": ( + tool_call.get("tool_name") + or tool_call.get("name") + or "" + ), + "arguments": ( + tool_call.get("arguments") + or tool_call.get("args") + or {} + ), + } + if "tool_results" in 
response_data.get("data", {}): + tool_call_id = tool_call.get("id") + matching_result = next( + ( + r + for r in response_data["data"]["tool_results"] + if r.get("id") == tool_call_id + ), + None, + ) + if matching_result: + formatted_tool["result"] = matching_result["status"] + formatted_tool_calls.append([formatted_tool]) + + response_data["tool_calls"] = formatted_tool_calls + + return APIResponse.from_raw_response(response_data) + + except httpx.TimeoutException as e: + raise self._handle_timeout_error("infer", self.config.timeout) from e + except httpx.HTTPStatusError as e: + if _is_retryable_server_error(e): + raise + raise self._handle_http_error(e) from e + except ValueError as e: + raise self._handle_validation_error(e) from e + except APIError: + raise + except Exception as e: + raise self._handle_unexpected_error(e, "infer query") from e + def _handle_response_errors(self, response: httpx.Response) -> None: """Handle HTTP response errors for streaming endpoint.""" if response.status_code != 200: diff --git a/src/lightspeed_evaluation/core/constants.py b/src/lightspeed_evaluation/core/constants.py index 7faefe26..b9ec20c3 100644 --- a/src/lightspeed_evaluation/core/constants.py +++ b/src/lightspeed_evaluation/core/constants.py @@ -56,7 +56,7 @@ DEFAULT_API_VERSION = "v1" DEFAULT_API_TIMEOUT = 300 DEFAULT_ENDPOINT_TYPE = "streaming" -SUPPORTED_ENDPOINT_TYPES = ["streaming", "query"] +SUPPORTED_ENDPOINT_TYPES = ["streaming", "query", "infer"] DEFAULT_API_CACHE_DIR = ".caches/api_cache" DEFAULT_API_NUM_RETRIES = 3 diff --git a/src/lightspeed_evaluation/core/models/system.py b/src/lightspeed_evaluation/core/models/system.py index 5067d3b2..7945d0e9 100644 --- a/src/lightspeed_evaluation/core/models/system.py +++ b/src/lightspeed_evaluation/core/models/system.py @@ -301,7 +301,7 @@ class APIConfig(BaseModel): ge=0, description=( "Maximum number of retry attempts for API calls on " - "429 Too Many Requests errors" + "retryable server errors (HTTP 429/5xx)" ), ) diff --git a/tests/unit/core/api/conftest.py b/tests/unit/core/api/conftest.py index 036501a5..f6ed6901 100644 --- a/tests/unit/core/api/conftest.py +++ b/tests/unit/core/api/conftest.py @@ -35,6 +35,19 @@ def basic_api_config_streaming_endpoint() -> APIConfig: ) +@pytest.fixture +def basic_api_config_infer_endpoint() -> APIConfig: + """Create test API config for infer endpoint.""" + return APIConfig( + enabled=True, + api_base="http://localhost:8080", + version="v1", + endpoint_type="infer", + timeout=30, + cache_enabled=False, + ) + + @pytest.fixture def mock_response(mocker: MockerFixture) -> Any: """Create a mock streaming response.""" diff --git a/tests/unit/core/api/test_client.py b/tests/unit/core/api/test_client.py index d8fe1ce8..a707bd97 100644 --- a/tests/unit/core/api/test_client.py +++ b/tests/unit/core/api/test_client.py @@ -10,7 +10,7 @@ from lightspeed_evaluation.core.models import APIConfig, APIResponse from lightspeed_evaluation.core.system.exceptions import APIError -from lightspeed_evaluation.core.api.client import APIClient, _is_too_many_requests_error +from lightspeed_evaluation.core.api.client import APIClient, _is_retryable_server_error class TestAPIClient: @@ -121,16 +121,15 @@ def test_query_with_attachments( assert request_data["attachments"][0]["content"] == "file1.txt" assert request_data["attachments"][1]["content"] == "file2.pdf" - def test_query_http_error( + def test_query_http_error_non_retryable( self, basic_api_config_query_endpoint: APIConfig, mocker: MockerFixture ) -> None: - """Test query 
handling HTTP errors.""" - + """Test query handling non-retryable HTTP errors (4xx except 429).""" mock_response = mocker.Mock() - mock_response.status_code = 500 - mock_response.text = "Internal server error" + mock_response.status_code = 400 + mock_response.text = "Bad request" mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( - "500 error", request=mocker.Mock(), response=mock_response + "400 error", request=mocker.Mock(), response=mock_response ) mock_client = mocker.Mock() @@ -144,7 +143,7 @@ def test_query_http_error( client = APIClient(basic_api_config_query_endpoint) - with pytest.raises(APIError, match="API error: 500"): + with pytest.raises(APIError, match="API error: 400"): client.query("Test query") def test_query_timeout_error( @@ -690,20 +689,40 @@ def test_serialize_request_skips_reserved_fields( class TestRetryLogic: """Unit tests for retry logic in APIClient.""" - def test_is_too_many_requests_error(self, mocker: MockerFixture) -> None: - """Test _is_too_many_requests_error identifies 429 errors.""" - # Test with 429 status code + def test_is_retryable_server_error(self, mocker: MockerFixture) -> None: + """Test _is_retryable_server_error identifies 429 and 5xx errors.""" resp_429 = mocker.Mock(status_code=429) - assert _is_too_many_requests_error( + assert _is_retryable_server_error( httpx.HTTPStatusError("", request=mocker.Mock(), response=resp_429) ) - # Test with non-429 status code resp_500 = mocker.Mock(status_code=500) - assert not _is_too_many_requests_error( + assert _is_retryable_server_error( httpx.HTTPStatusError("", request=mocker.Mock(), response=resp_500) ) + resp_502 = mocker.Mock(status_code=502) + assert _is_retryable_server_error( + httpx.HTTPStatusError("", request=mocker.Mock(), response=resp_502) + ) + + resp_503 = mocker.Mock(status_code=503) + assert _is_retryable_server_error( + httpx.HTTPStatusError("", request=mocker.Mock(), response=resp_503) + ) + + resp_400 = mocker.Mock(status_code=400) + assert not _is_retryable_server_error( + httpx.HTTPStatusError("", request=mocker.Mock(), response=resp_400) + ) + + resp_404 = mocker.Mock(status_code=404) + assert not _is_retryable_server_error( + httpx.HTTPStatusError("", request=mocker.Mock(), response=resp_404) + ) + + assert not _is_retryable_server_error(ValueError("not an HTTP error")) + def test_standard_query_retries_on_429_then_succeeds( self, basic_api_config_query_endpoint: APIConfig, mocker: MockerFixture ) -> None: @@ -790,9 +809,254 @@ def test_query_raises_api_error_after_max_retries( client = APIClient(basic_api_config_query_endpoint) - with pytest.raises( - APIError, match=str(basic_api_config_query_endpoint.num_retries) - ): + with pytest.raises(APIError, match="Maximum retry attempts"): client.query("Test query") assert mock_client.post.call_count == 4 # 3 retries + 1 initial attempt + + def test_standard_query_retries_on_500_then_succeeds( + self, basic_api_config_query_endpoint: APIConfig, mocker: MockerFixture + ) -> None: + """Test standard query retries on 500 error and succeeds on retry.""" + mock_response_500 = mocker.Mock(status_code=500, text="Internal server error") + mock_response_500.raise_for_status.side_effect = httpx.HTTPStatusError( + "500 error", request=mocker.Mock(), response=mock_response_500 + ) + + mock_response_success = mocker.Mock(status_code=200) + mock_response_success.json.return_value = { + "response": "Success after 500 retry", + "conversation_id": "conv_123", + } + + mock_client = mocker.Mock() + mock_client.post.side_effect = 
[mock_response_500, mock_response_success] + mock_client.headers = {} + + mocker.patch( + "lightspeed_evaluation.core.api.client.httpx.Client", + return_value=mock_client, + ) + + client = APIClient(basic_api_config_query_endpoint) + result = client.query("Test standard query") + + assert result.response == "Success after 500 retry" + assert mock_client.post.call_count == 2 + + +class TestInferEndpoint: + """Tests for RLSAPI /infer endpoint support.""" + + def test_query_infer_endpoint_success( + self, basic_api_config_infer_endpoint: APIConfig, mocker: MockerFixture + ) -> None: + """Test successful query to infer endpoint.""" + mock_response = mocker.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": { + "text": "Infer response", + "request_id": "req_abc", + "input_tokens": 10, + "output_tokens": 20, + } + } + + mock_client = mocker.Mock() + mock_client.post.return_value = mock_response + mock_client.headers = {} + + mocker.patch( + "lightspeed_evaluation.core.api.client.httpx.Client", + return_value=mock_client, + ) + + client = APIClient(basic_api_config_infer_endpoint) + result = client.query("What is RHEL?") + + assert isinstance(result, APIResponse) + assert result.response == "Infer response" + assert result.conversation_id == "req_abc" + assert result.input_tokens == 10 + assert result.output_tokens == 20 + + call_kwargs = mock_client.post.call_args + assert "/api/lightspeed/v1/infer" in call_kwargs[0][0] + request_body = call_kwargs[1]["json"] + assert request_body["question"] == "What is RHEL?" + assert request_body["include_metadata"] is True + + def test_infer_query_formats_tool_calls( + self, basic_api_config_infer_endpoint: APIConfig, mocker: MockerFixture + ) -> None: + """Test that infer query formats tool calls correctly.""" + mock_response = mocker.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": { + "text": "Response with tools", + "request_id": "req_abc", + "tool_calls": [ + {"id": "tc1", "name": "search_documentation", "args": {"q": "rhel"}}, + {"id": "tc2", "tool_name": "mcp_list_tools", "arguments": {}}, + ], + "tool_results": [ + {"id": "tc1", "type": "mcp_call", "status": "success", "content": "result1"}, + {"id": "tc2", "type": "tool_list", "status": "completed", "content": "tools"}, + ], + } + } + + mock_client = mocker.Mock() + mock_client.post.return_value = mock_response + mock_client.headers = {} + + mocker.patch( + "lightspeed_evaluation.core.api.client.httpx.Client", + return_value=mock_client, + ) + + client = APIClient(basic_api_config_infer_endpoint) + result = client.query("Test query") + + assert len(result.tool_calls) == 2 + assert isinstance(result.tool_calls[0], list) + assert result.tool_calls[0][0]["tool_name"] == "search_documentation" + assert result.tool_calls[0][0]["arguments"] == {"q": "rhel"} + assert result.tool_calls[0][0]["result"] == "success" + assert result.tool_calls[1][0]["tool_name"] == "mcp_list_tools" + assert result.tool_calls[1][0]["result"] == "completed" + + def test_infer_query_extracts_rag_chunks( + self, basic_api_config_infer_endpoint: APIConfig, mocker: MockerFixture + ) -> None: + """Test that infer query extracts RAG chunks from tool_results.""" + mock_response = mocker.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": { + "text": "Response with RAG", + "request_id": "req_abc", + "tool_results": [ + { + "id": "tr1", + "type": "mcp_call", + "status": "success", + "content": "Chunk one---Chunk two---Chunk 
three", + } + ], + } + } + + mock_client = mocker.Mock() + mock_client.post.return_value = mock_response + mock_client.headers = {} + + mocker.patch( + "lightspeed_evaluation.core.api.client.httpx.Client", + return_value=mock_client, + ) + + client = APIClient(basic_api_config_infer_endpoint) + result = client.query("Test query") + + assert len(result.contexts) == 3 + assert "Chunk one" in result.contexts[0] + assert "Chunk two" in result.contexts[1] + assert "Chunk three" in result.contexts[2] + + def test_infer_query_missing_response_field( + self, basic_api_config_infer_endpoint: APIConfig, mocker: MockerFixture + ) -> None: + """Test infer query handles missing response field.""" + mock_response = mocker.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"data": {"request_id": "req_abc"}} + + mock_client = mocker.Mock() + mock_client.post.return_value = mock_response + mock_client.headers = {} + + mocker.patch( + "lightspeed_evaluation.core.api.client.httpx.Client", + return_value=mock_client, + ) + + client = APIClient(basic_api_config_infer_endpoint) + + with pytest.raises(APIError, match="missing 'response' field"): + client.query("Test query") + + def test_infer_query_timeout_error( + self, basic_api_config_infer_endpoint: APIConfig, mocker: MockerFixture + ) -> None: + """Test infer query handles timeout.""" + mock_client = mocker.Mock() + mock_client.post.side_effect = httpx.TimeoutException("Timeout") + mock_client.headers = {} + + mocker.patch( + "lightspeed_evaluation.core.api.client.httpx.Client", + return_value=mock_client, + ) + + client = APIClient(basic_api_config_infer_endpoint) + + with pytest.raises(APIError, match="timeout"): + client.query("Test query") + + def test_infer_query_retries_on_429( + self, basic_api_config_infer_endpoint: APIConfig, mocker: MockerFixture + ) -> None: + """Test infer query retries on 429 then succeeds.""" + mock_response_429 = mocker.Mock(status_code=429) + mock_response_429.raise_for_status.side_effect = httpx.HTTPStatusError( + "429 error", request=mocker.Mock(), response=mock_response_429 + ) + + mock_response_success = mocker.Mock(status_code=200) + mock_response_success.json.return_value = { + "data": { + "text": "Success after retry", + "request_id": "req_abc", + } + } + + mock_client = mocker.Mock() + mock_client.post.side_effect = [mock_response_429, mock_response_success] + mock_client.headers = {} + + mocker.patch( + "lightspeed_evaluation.core.api.client.httpx.Client", + return_value=mock_client, + ) + + client = APIClient(basic_api_config_infer_endpoint) + result = client.query("Test query") + + assert result.response == "Success after retry" + assert mock_client.post.call_count == 2 + + def test_infer_query_http_error_non_retryable( + self, basic_api_config_infer_endpoint: APIConfig, mocker: MockerFixture + ) -> None: + """Test infer query raises APIError for non-retryable HTTP errors.""" + mock_response = mocker.Mock(status_code=400, text="Bad request") + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "400 error", request=mocker.Mock(), response=mock_response + ) + + mock_client = mocker.Mock() + mock_client.post.return_value = mock_response + mock_client.headers = {} + + mocker.patch( + "lightspeed_evaluation.core.api.client.httpx.Client", + return_value=mock_client, + ) + + client = APIClient(basic_api_config_infer_endpoint) + + with pytest.raises(APIError, match="API error: 400"): + client.query("Test query") From 8171fa25c5fb7d2b46c1b34be43bfd9959782ca1 Mon Sep 17 00:00:00 
2001 From: Ellis Low Date: Thu, 23 Apr 2026 10:12:52 -0400 Subject: [PATCH 4/6] config: increase default retry attempts from 3 to 5 Provides more resilience against transient server failures, especially during long evaluation runs. --- src/lightspeed_evaluation/core/constants.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightspeed_evaluation/core/constants.py b/src/lightspeed_evaluation/core/constants.py index b9ec20c3..3fbfa738 100644 --- a/src/lightspeed_evaluation/core/constants.py +++ b/src/lightspeed_evaluation/core/constants.py @@ -70,7 +70,7 @@ DEFAULT_SSL_CERT_FILE = None DEFAULT_LLM_TEMPERATURE = 0.0 DEFAULT_LLM_MAX_TOKENS = 512 -DEFAULT_LLM_RETRIES = 3 +DEFAULT_LLM_RETRIES = 5 DEFAULT_LLM_CACHE_DIR = ".caches/llm_cache" DEFAULT_EMBEDDING_PROVIDER = "openai" From ce00854aaeec2b55fd5eba8f5dcef0dbdcb29bde Mon Sep 17 00:00:00 2001 From: Ellis Low Date: Thu, 23 Apr 2026 10:20:33 -0400 Subject: [PATCH 5/6] style: apply black formatting to client and tests --- src/lightspeed_evaluation/core/api/client.py | 4 +--- tests/unit/core/api/test_client.py | 20 +++++++++++++++++--- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/src/lightspeed_evaluation/core/api/client.py b/src/lightspeed_evaluation/core/api/client.py index 35528280..9827a818 100644 --- a/src/lightspeed_evaluation/core/api/client.py +++ b/src/lightspeed_evaluation/core/api/client.py @@ -66,9 +66,7 @@ def __init__( retry_decorator = self._create_retry_decorator() self._standard_query_with_retry = retry_decorator(self._standard_query) self._streaming_query_with_retry = retry_decorator(self._streaming_query) - self._rlsapi_infer_query_with_retry = retry_decorator( - self._rlsapi_infer_query - ) + self._rlsapi_infer_query_with_retry = retry_decorator(self._rlsapi_infer_query) def _create_retry_decorator(self) -> Any: return retry( diff --git a/tests/unit/core/api/test_client.py b/tests/unit/core/api/test_client.py index a707bd97..e5850f5d 100644 --- a/tests/unit/core/api/test_client.py +++ b/tests/unit/core/api/test_client.py @@ -898,12 +898,26 @@ def test_infer_query_formats_tool_calls( "text": "Response with tools", "request_id": "req_abc", "tool_calls": [ - {"id": "tc1", "name": "search_documentation", "args": {"q": "rhel"}}, + { + "id": "tc1", + "name": "search_documentation", + "args": {"q": "rhel"}, + }, {"id": "tc2", "tool_name": "mcp_list_tools", "arguments": {}}, ], "tool_results": [ - {"id": "tc1", "type": "mcp_call", "status": "success", "content": "result1"}, - {"id": "tc2", "type": "tool_list", "status": "completed", "content": "tools"}, + { + "id": "tc1", + "type": "mcp_call", + "status": "success", + "content": "result1", + }, + { + "id": "tc2", + "type": "tool_list", + "status": "completed", + "content": "tools", + }, ], } } From be8019de598a1c0f90878276378ae0c0739b45c1 Mon Sep 17 00:00:00 2001 From: Ellis Low Date: Mon, 27 Apr 2026 14:44:04 -0400 Subject: [PATCH 6/6] fix: address PR review feedback (asamal4 + CodeRabbit) - Revert DEFAULT_LLM_RETRIES from 5 to 3 - Narrow retry codes to (429, 502, 503, 504), exclude 500 - Use RLSAPI native fields (name/args) in _rlsapi_infer_query - Fix RAG chunk accumulation across multiple mcp_call results - Redact prompt from debug log, log only metadata - Add comment about extra_request_params not forwarded to /infer - Fix tool result capture: use content with status fallback - Update endpoint_type description to include infer - Move skip tests from TestFilterByScope to TestDataValidator - Fix MockerFixture import in test_validator.py - 
Fix --metrics filter: handle turn_metrics=None by materializing system defaults before filtering; add conversation_metrics filter - Add metrics=None to runner test fixture for --metrics support - Add tests for metrics filter materialization Signed-off-by: Ellis Low --- src/lightspeed_evaluation/core/api/client.py | 45 +++-- src/lightspeed_evaluation/core/constants.py | 2 +- .../core/models/system.py | 2 +- .../core/system/validator.py | 28 +++- tests/unit/core/api/test_client.py | 26 +-- tests/unit/core/system/test_validator.py | 156 +++++++++++++++--- tests/unit/runner/test_evaluation.py | 7 +- 7 files changed, 208 insertions(+), 58 deletions(-) diff --git a/src/lightspeed_evaluation/core/api/client.py b/src/lightspeed_evaluation/core/api/client.py index 9827a818..cef5a655 100644 --- a/src/lightspeed_evaluation/core/api/client.py +++ b/src/lightspeed_evaluation/core/api/client.py @@ -28,7 +28,11 @@ def _is_retryable_server_error(exception: BaseException) -> bool: - """Check if exception is a retryable HTTP error (429 or 5xx). + """Check if exception is a retryable HTTP error (429 or transient 5xx). + + Only 502 Bad Gateway, 503 Service Unavailable, and 504 Gateway Timeout + are retried. 500 Internal Server Error is excluded as it may indicate + permanent server bugs. Args: exception: The exception to check. @@ -39,7 +43,7 @@ def _is_retryable_server_error(exception: BaseException) -> bool: if not isinstance(exception, httpx.HTTPStatusError): return False status = exception.response.status_code - return status == 429 or 500 <= status < 600 + return status in (429, 502, 503, 504) class APIClient: @@ -352,6 +356,10 @@ def _rlsapi_infer_query(self, api_request: APIRequest) -> APIResponse: raise APIError("HTTP client not initialized") try: request_data = api_request.model_dump(exclude_none=True) + # `extra_request_params` are not forwarded to `/infer` — the + # endpoint only accepts `question` and `include_metadata`. + # Other params (model, provider, etc.) are not part of the + # RLSAPI `/infer` API contract. 
infer_request: dict[str, object] = { "question": request_data.pop("query"), "include_metadata": True, @@ -361,7 +369,13 @@ def _rlsapi_infer_query(self, api_request: APIRequest) -> APIResponse: "RLSAPI infer request URL: /api/lightspeed/%s/infer", self.config.version, ) - logger.debug("RLSAPI infer request body: %s", infer_request) + logger.debug( + "RLSAPI infer request: version=%s, include_metadata=%s, " + "question_length=%d", + self.config.version, + True, + len(str(infer_request.get("question", ""))), + ) response = self.client.post( f"/api/lightspeed/{self.config.version}/infer", @@ -385,12 +399,12 @@ def _rlsapi_infer_query(self, api_request: APIRequest) -> APIResponse: response_data["tool_calls"] = data["tool_calls"] if "tool_results" in data: tool_results = data["tool_results"] + rag_chunks: list[dict[str, str]] = [] for result in tool_results: if result.get("type") == "mcp_call": content = result["content"].split("---") - response_data["rag_chunks"] = [ - {"content": chunk} for chunk in content - ] + rag_chunks.extend([{"content": chunk} for chunk in content]) + response_data["rag_chunks"] = rag_chunks if "response" not in response_data: raise APIError("API response missing 'response' field") @@ -402,16 +416,8 @@ def _rlsapi_infer_query(self, api_request: APIRequest) -> APIResponse: for tool_call in raw_tool_calls: if isinstance(tool_call, dict): formatted_tool: dict[str, object] = { - "tool_name": ( - tool_call.get("tool_name") - or tool_call.get("name") - or "" - ), - "arguments": ( - tool_call.get("arguments") - or tool_call.get("args") - or {} - ), + "tool_name": tool_call.get("name", ""), + "arguments": tool_call.get("args", {}), } if "tool_results" in response_data.get("data", {}): tool_call_id = tool_call.get("id") @@ -424,7 +430,12 @@ def _rlsapi_infer_query(self, api_request: APIRequest) -> APIResponse: None, ) if matching_result: - formatted_tool["result"] = matching_result["status"] + formatted_tool["result"] = matching_result.get( + "content", matching_result.get("status", "") + ) + formatted_tool["status"] = matching_result.get( + "status", "" + ) formatted_tool_calls.append([formatted_tool]) response_data["tool_calls"] = formatted_tool_calls diff --git a/src/lightspeed_evaluation/core/constants.py b/src/lightspeed_evaluation/core/constants.py index 3fbfa738..b9ec20c3 100644 --- a/src/lightspeed_evaluation/core/constants.py +++ b/src/lightspeed_evaluation/core/constants.py @@ -70,7 +70,7 @@ DEFAULT_SSL_CERT_FILE = None DEFAULT_LLM_TEMPERATURE = 0.0 DEFAULT_LLM_MAX_TOKENS = 512 -DEFAULT_LLM_RETRIES = 5 +DEFAULT_LLM_RETRIES = 3 DEFAULT_LLM_CACHE_DIR = ".caches/llm_cache" DEFAULT_EMBEDDING_PROVIDER = "openai" diff --git a/src/lightspeed_evaluation/core/models/system.py b/src/lightspeed_evaluation/core/models/system.py index 7945d0e9..166e98bb 100644 --- a/src/lightspeed_evaluation/core/models/system.py +++ b/src/lightspeed_evaluation/core/models/system.py @@ -271,7 +271,7 @@ class APIConfig(BaseModel): ) endpoint_type: str = Field( default=DEFAULT_ENDPOINT_TYPE, - description="API endpoint type (streaming or query)", + description="API endpoint type (streaming, query, or infer)", ) timeout: int = Field( default=DEFAULT_API_TIMEOUT, ge=1, description="Request timeout in seconds" diff --git a/src/lightspeed_evaluation/core/system/validator.py b/src/lightspeed_evaluation/core/system/validator.py index 3c97ea21..314997f9 100644 --- a/src/lightspeed_evaluation/core/system/validator.py +++ b/src/lightspeed_evaluation/core/system/validator.py @@ -159,6 +159,7 @@ def __init__( 
self.api_enabled = api_enabled self.original_data_path: Optional[str] = None self.fail_on_invalid_data = fail_on_invalid_data + self._system_config = system_config self._turn_level_metrics: set[str] = ( system_config.turn_level_metric_names if system_config else set() ) @@ -235,15 +236,38 @@ def load_evaluation_data( # Remove skipped conversations evaluation_data = [e for e in evaluation_data if not e.skip] - # Filter turn_metrics if --metrics was specified + # Filter turn_metrics and conversation_metrics if --metrics was specified if metrics: metrics_set = set(metrics) for eval_data in evaluation_data: for turn in eval_data.turns: - if turn.turn_metrics: + if turn.turn_metrics is not None: turn.turn_metrics = [ m for m in turn.turn_metrics if m in metrics_set ] + elif self._system_config is not None: + turn_defaults = ( + self._system_config.default_turn_metrics_metadata + ) + turn.turn_metrics = [ + m + for m, meta in turn_defaults.items() + if meta.get("default", False) and m in metrics_set + ] + + if eval_data.conversation_metrics is not None: + eval_data.conversation_metrics = [ + m for m in eval_data.conversation_metrics if m in metrics_set + ] + elif self._system_config is not None: + conv_defaults = ( + self._system_config.default_conversation_metrics_metadata + ) + eval_data.conversation_metrics = [ + m + for m, meta in conv_defaults.items() + if meta.get("default", False) and m in metrics_set + ] # Semantic validation (metrics availability and requirements) if not self._validate_evaluation_data(evaluation_data): diff --git a/tests/unit/core/api/test_client.py b/tests/unit/core/api/test_client.py index e5850f5d..a73aa8d1 100644 --- a/tests/unit/core/api/test_client.py +++ b/tests/unit/core/api/test_client.py @@ -697,7 +697,7 @@ def test_is_retryable_server_error(self, mocker: MockerFixture) -> None: ) resp_500 = mocker.Mock(status_code=500) - assert _is_retryable_server_error( + assert not _is_retryable_server_error( httpx.HTTPStatusError("", request=mocker.Mock(), response=resp_500) ) @@ -814,23 +814,23 @@ def test_query_raises_api_error_after_max_retries( assert mock_client.post.call_count == 4 # 3 retries + 1 initial attempt - def test_standard_query_retries_on_500_then_succeeds( + def test_standard_query_retries_on_502_then_succeeds( self, basic_api_config_query_endpoint: APIConfig, mocker: MockerFixture ) -> None: - """Test standard query retries on 500 error and succeeds on retry.""" - mock_response_500 = mocker.Mock(status_code=500, text="Internal server error") - mock_response_500.raise_for_status.side_effect = httpx.HTTPStatusError( - "500 error", request=mocker.Mock(), response=mock_response_500 + """Test standard query retries on 502 error and succeeds on retry.""" + mock_response_502 = mocker.Mock(status_code=502, text="Bad gateway") + mock_response_502.raise_for_status.side_effect = httpx.HTTPStatusError( + "502 error", request=mocker.Mock(), response=mock_response_502 ) mock_response_success = mocker.Mock(status_code=200) mock_response_success.json.return_value = { - "response": "Success after 500 retry", + "response": "Success after 502 retry", "conversation_id": "conv_123", } mock_client = mocker.Mock() - mock_client.post.side_effect = [mock_response_500, mock_response_success] + mock_client.post.side_effect = [mock_response_502, mock_response_success] mock_client.headers = {} mocker.patch( @@ -841,7 +841,7 @@ def test_standard_query_retries_on_500_then_succeeds( client = APIClient(basic_api_config_query_endpoint) result = client.query("Test standard query") - 
assert result.response == "Success after 500 retry" + assert result.response == "Success after 502 retry" assert mock_client.post.call_count == 2 @@ -903,7 +903,7 @@ def test_infer_query_formats_tool_calls( "name": "search_documentation", "args": {"q": "rhel"}, }, - {"id": "tc2", "tool_name": "mcp_list_tools", "arguments": {}}, + {"id": "tc2", "name": "mcp_list_tools", "args": {}}, ], "tool_results": [ { @@ -938,9 +938,11 @@ def test_infer_query_formats_tool_calls( assert isinstance(result.tool_calls[0], list) assert result.tool_calls[0][0]["tool_name"] == "search_documentation" assert result.tool_calls[0][0]["arguments"] == {"q": "rhel"} - assert result.tool_calls[0][0]["result"] == "success" + assert result.tool_calls[0][0]["result"] == "result1" + assert result.tool_calls[0][0]["status"] == "success" assert result.tool_calls[1][0]["tool_name"] == "mcp_list_tools" - assert result.tool_calls[1][0]["result"] == "completed" + assert result.tool_calls[1][0]["result"] == "tools" + assert result.tool_calls[1][0]["status"] == "completed" def test_infer_query_extracts_rag_chunks( self, basic_api_config_infer_endpoint: APIConfig, mocker: MockerFixture diff --git a/tests/unit/core/system/test_validator.py b/tests/unit/core/system/test_validator.py index 09840369..d98e3573 100644 --- a/tests/unit/core/system/test_validator.py +++ b/tests/unit/core/system/test_validator.py @@ -6,6 +6,7 @@ from pathlib import Path import pytest +from pytest_mock import MockerFixture from pydantic import ValidationError @@ -463,6 +464,47 @@ def test_validate_evaluation_data_accumulates_errors(self) -> None: # Should have errors for both issues assert len(validator.validation_errors) >= 2 + def test_skip_removes_conversation(self, mocker: MockerFixture) -> None: + """Test that conversations with skip=True are excluded.""" + yaml_data = [ + { + "conversation_group_id": "active", + "turns": [{"turn_id": "t1", "query": "Q", "response": "A"}], + }, + { + "conversation_group_id": "skipped", + "skip": True, + "skip_reason": "Test needs rewrite", + "turns": [{"turn_id": "t1", "query": "Q", "response": "A"}], + }, + { + "conversation_group_id": "also_active", + "turns": [{"turn_id": "t1", "query": "Q", "response": "A"}], + }, + ] + mocker.patch("builtins.open", mocker.mock_open(read_data="")) + mocker.patch("yaml.safe_load", return_value=yaml_data) + validator = DataValidator() + result = validator.load_evaluation_data("dummy.yaml") + assert len(result) == 2 + assert {r.conversation_group_id for r in result} == {"active", "also_active"} + + def test_skip_false_keeps_conversation(self, mocker: MockerFixture) -> None: + """Test that skip=False does not exclude the conversation.""" + yaml_data = [ + { + "conversation_group_id": "explicit_no_skip", + "skip": False, + "turns": [{"turn_id": "t1", "query": "Q", "response": "A"}], + }, + ] + mocker.patch("builtins.open", mocker.mock_open(read_data="")) + mocker.patch("yaml.safe_load", return_value=yaml_data) + validator = DataValidator() + result = validator.load_evaluation_data("dummy.yaml") + assert len(result) == 1 + assert result[0].conversation_group_id == "explicit_no_skip" + class TestFilterByScope: """Unit test for filter by scope.""" @@ -564,43 +606,109 @@ def test_filter_by_scope_no_match_returns_empty(self) -> None: result = validator._filter_by_scope(data, tags=["nonexistent"]) assert len(result) == 0 - def test_skip_removes_conversation(self, mocker: MockerFixture) -> None: - """Test that conversations with skip=True are excluded.""" + +class TestMetricsFilter: + """Tests 
for --metrics filter in load_evaluation_data.""" + + def test_turn_metrics_none_materializes_defaults_and_filters( + self, mocker: MockerFixture + ) -> None: + """Test turn_metrics=None materializes system defaults, then filters.""" yaml_data = [ { - "conversation_group_id": "active", - "turns": [{"turn_id": "t1", "query": "Q", "response": "A"}], + "conversation_group_id": "conv1", + "turns": [ + { + "turn_id": "t1", + "query": "Q", + "response": "A", + "contexts": ["C"], + "expected_response": "E", + }, + ], }, - { - "conversation_group_id": "skipped", - "skip": True, - "skip_reason": "Test needs rewrite", - "turns": [{"turn_id": "t1", "query": "Q", "response": "A"}], + ] + mocker.patch("builtins.open", mocker.mock_open(read_data="")) + mocker.patch("yaml.safe_load", return_value=yaml_data) + config = SystemConfig( + default_turn_metrics_metadata={ + "ragas:faithfulness": {"default": True, "threshold": 0.7}, + "ragas:response_relevancy": {"default": True, "threshold": 0.7}, + "custom:answer_correctness": {"default": False, "threshold": 0.8}, }, + ) + validator = DataValidator(system_config=config) + result = validator.load_evaluation_data( + "dummy.yaml", metrics=["ragas:faithfulness"] + ) + assert result[0].turns[0].turn_metrics == ["ragas:faithfulness"] + + def test_conversation_metrics_none_materializes_defaults_and_filters( + self, mocker: MockerFixture + ) -> None: + """Test conversation_metrics=None materializes defaults, then filters.""" + yaml_data = [ { - "conversation_group_id": "also_active", - "turns": [{"turn_id": "t1", "query": "Q", "response": "A"}], + "conversation_group_id": "conv1", + "turns": [ + {"turn_id": "t1", "query": "Q", "response": "A"}, + ], }, ] mocker.patch("builtins.open", mocker.mock_open(read_data="")) mocker.patch("yaml.safe_load", return_value=yaml_data) - validator = DataValidator() - result = validator.load_evaluation_data("dummy.yaml") - assert len(result) == 2 - assert {r.conversation_group_id for r in result} == {"active", "also_active"} + config = SystemConfig( + default_conversation_metrics_metadata={ + "deepeval:conversation_completeness": { + "default": True, + "threshold": 0.6, + }, + "deepeval:conversation_relevancy": { + "default": False, + "threshold": 0.5, + }, + }, + ) + validator = DataValidator(system_config=config) + result = validator.load_evaluation_data( + "dummy.yaml", + metrics=["deepeval:conversation_completeness"], + ) + assert result[0].conversation_metrics == ["deepeval:conversation_completeness"] - def test_skip_false_keeps_conversation(self, mocker: MockerFixture) -> None: - """Test that skip=False does not exclude the conversation.""" + def test_conversation_metrics_explicit_list_filters( + self, mocker: MockerFixture + ) -> None: + """Test explicit conversation_metrics list is filtered by --metrics.""" yaml_data = [ { - "conversation_group_id": "explicit_no_skip", - "skip": False, - "turns": [{"turn_id": "t1", "query": "Q", "response": "A"}], + "conversation_group_id": "conv1", + "conversation_metrics": [ + "deepeval:conversation_completeness", + "deepeval:conversation_relevancy", + ], + "turns": [ + {"turn_id": "t1", "query": "Q", "response": "A"}, + ], }, ] mocker.patch("builtins.open", mocker.mock_open(read_data="")) mocker.patch("yaml.safe_load", return_value=yaml_data) - validator = DataValidator() - result = validator.load_evaluation_data("dummy.yaml") - assert len(result) == 1 - assert result[0].conversation_group_id == "explicit_no_skip" + config = SystemConfig( + default_conversation_metrics_metadata={ + 
"deepeval:conversation_completeness": { + "default": True, + "threshold": 0.6, + }, + "deepeval:conversation_relevancy": { + "default": True, + "threshold": 0.5, + }, + }, + ) + validator = DataValidator(system_config=config) + result = validator.load_evaluation_data( + "dummy.yaml", + metrics=["deepeval:conversation_completeness"], + ) + assert result[0].conversation_metrics == ["deepeval:conversation_completeness"] diff --git a/tests/unit/runner/test_evaluation.py b/tests/unit/runner/test_evaluation.py index 432c8cf2..1a2ac21b 100644 --- a/tests/unit/runner/test_evaluation.py +++ b/tests/unit/runner/test_evaluation.py @@ -30,6 +30,7 @@ def _make_eval_args(**kwargs: Any) -> argparse.Namespace: "output_dir": None, "tags": None, "conv_ids": None, + "metrics": None, "cache_warmup": False, } defaults.update(kwargs) @@ -353,7 +354,10 @@ def test_run_evaluation_with_empty_filter_result( assert result is not None assert result["TOTAL"] == 0 mock_validator.return_value.load_evaluation_data.assert_called_once_with( - "config/evaluation_data.yaml", tags=["nonexistent"], conv_ids=None + "config/evaluation_data.yaml", + tags=["nonexistent"], + conv_ids=None, + metrics=None, ) # Verify warning message appears @@ -409,6 +413,7 @@ def test_run_evaluation_with_filter_parameters(self, mocker: MockerFixture) -> N "config/evaluation_data.yaml", tags=["basic"], conv_ids=["conv_1"], + metrics=None, )