11"""Data types for Scenario Test Evaluation results."""
2- from dataclasses import InitVar , dataclass , field
2+ from dataclasses import dataclass , field
33from enum import Enum
44from typing import List , Optional
55
@@ -77,31 +77,30 @@ class ScenarioTestEvaluation:
7777 status : ScenarioTestEvaluationStatus = field (init = False )
7878 result : Optional [float ] = field (init = False )
7979 passed : bool = field (init = False )
80- item_evals : List [ScenarioTestItemEvaluation ] = field (init = False )
81- connection : InitVar [Connection ]
82-
83- def __post_init__ (self , connection : Connection ):
84- # TODO(gunnar): Having the function call /info on every construction is too slow. The original
85- # endpoint should rather return the necessary human-readable information
86- response = connection .make_request (
80+ connection : Connection = field (init = False , repr = False )
81+
82+ @classmethod
83+ def from_request (cls , response , connection ):
84+ instance = cls (response ["id" ])
85+ instance .connection = connection
86+
87+ instance .scenario_test_id = response [SCENARIO_TEST_ID_KEY ]
88+ instance .eval_function_id = response [EVAL_FUNCTION_ID_KEY ]
89+ instance .model_id = response [MODEL_ID_KEY ]
90+ instance .status = ScenarioTestEvaluationStatus (response [STATUS_KEY ])
91+ instance .result = try_convert_float (response [RESULT_KEY ])
92+ instance .passed = bool (response [PASS_KEY ])
93+ return instance
94+
95+ @property
96+ def item_evals (self ) -> List [ScenarioTestItemEvaluation ]:
97+ response = self .connection .make_request (
8798 {},
8899 f"validate/eval/{ self .id } /info" ,
89100 requests_command = requests .get ,
90101 )
91- eval_response = response [SCENARIO_TEST_EVAL_KEY ]
92102 items_response = response [ITEM_EVAL_KEY ]
93-
94- self .scenario_test_id : str = eval_response [SCENARIO_TEST_ID_KEY ]
95- self .eval_function_id : str = eval_response [EVAL_FUNCTION_ID_KEY ]
96- self .model_id : str = eval_response [MODEL_ID_KEY ]
97- self .status : ScenarioTestEvaluationStatus = (
98- ScenarioTestEvaluationStatus (eval_response [STATUS_KEY ])
99- )
100- self .result : Optional [float ] = try_convert_float (
101- eval_response [RESULT_KEY ]
102- )
103- self .passed : bool = bool (eval_response [PASS_KEY ])
104- self .item_evals : List [ScenarioTestItemEvaluation ] = [
103+ items = [
105104 ScenarioTestItemEvaluation (
106105 evaluation_id = res [EVALUATION_ID_KEY ],
107106 scenario_test_id = res [SCENARIO_TEST_ID_KEY ],
@@ -112,3 +111,4 @@ def __post_init__(self, connection: Connection):
112111 )
113112 for res in items_response
114113 ]
114+ return items
0 commit comments