@@ -18,9 +18,15 @@
     THRESHOLD_KEY,
     ThresholdComparison,
 )
-from .data_transfer_objects.scenario_test_evaluations import GetEvalHistory
+from .data_transfer_objects.scenario_test_evaluations import (
+    EvaluationResult,
+    GetEvalHistory,
+)
 from .data_transfer_objects.scenario_test_metric import AddScenarioTestFunction
-from .eval_functions.available_eval_functions import EvalFunction
+from .eval_functions.available_eval_functions import (
+    EvalFunction,
+    ExternalEvalFunction,
+)
 from .scenario_test_evaluation import ScenarioTestEvaluation
 from .scenario_test_metric import ScenarioTestMetric
 
@@ -83,9 +89,13 @@ def add_eval_function( |
         Args:
             eval_function: :class:`EvalFunction`
 
+        Raises:
+            NucleusAPIError: If adding this function would mix external and non-external eval functions on the same scenario test, which is not permitted.
+
         Returns:
             The created ScenarioTestMetric object.
         """
+
         response = self.connection.post(
             AddScenarioTestFunction(
                 scenario_test_name=self.name,
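For illustration, a minimal sketch of how the new `Raises` contract surfaces to callers; `scenario_test` and `builtin_fn` are hypothetical handles, and the `nucleus.errors.NucleusAPIError` import path is assumed:

```python
from nucleus.errors import NucleusAPIError

# Hypothetical setup: `scenario_test` already has an external eval
# function attached, and `builtin_fn` is a non-external eval function.
try:
    scenario_test.add_eval_function(builtin_fn)
except NucleusAPIError:
    # The API rejects mixing external and non-external functions
    # on a single scenario test.
    print("Cannot mix external and non-external eval functions.")
```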
@@ -174,3 +184,43 @@ def set_baseline_model(self, model_id: str): |
         )
         self.baseline_model_id = response.get("baseline_model_id")
         return self.baseline_model_id
+
+    def upload_external_evaluation_results(
+        self,
+        eval_fn: ExternalEvalFunction,
+        results: List[EvaluationResult],
+        model_id: str,
+    ):
+        assert (
+            eval_fn.eval_func_entry.is_external_function
+        ), "Submitting evaluation results is only available for external functions."
+
+        assert (
+            len(results) > 0
+        ), "Submitting evaluation results requires at least one result."
+
+        metric_per_ref_id = {}
+        weight_per_ref_id = {}
+        aggregate_weighted_sum = 0.0
+        aggregate_weight = 0.0
+        # Aggregation based on https://en.wikipedia.org/wiki/Weighted_arithmetic_mean
+        for r in results:
+            metric_per_ref_id[r.item_ref_id] = r.score
+            weight_per_ref_id[r.item_ref_id] = r.weight
+            aggregate_weighted_sum += r.score * r.weight
+            aggregate_weight += r.weight
+
+        payload = {
+            "unit_test_id": self.id,
+            "eval_function_id": eval_fn.id,
+            "result_per_ref_id": metric_per_ref_id,
+            "weight_per_ref_id": weight_per_ref_id,
+            "overall_metric": aggregate_weighted_sum / aggregate_weight,
+            "model_id": model_id,
+            "slice_id": self.slice_id,
+        }
+        response = self.connection.post(
+            payload,
+            "validate/scenario_test/upload_results",
+        )
+        return response
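The `overall_metric` sent in the payload is the weighted arithmetic mean of the per-item scores, `sum(score_i * weight_i) / sum(weight_i)`; a total weight of zero would raise `ZeroDivisionError`. A minimal usage sketch follows; the import path is inferred from the relative imports above, and `scenario_test`, `external_fn`, the ref IDs, and the model ID are all hypothetical:

```python
# Import path assumed from the relative imports in this diff.
from nucleus.validate.data_transfer_objects.scenario_test_evaluations import (
    EvaluationResult,
)

# Hypothetical: one score per dataset item in the scenario test's slice,
# with per-item weights used in the aggregate.
results = [
    EvaluationResult(item_ref_id="img_001", score=0.8, weight=1.0),
    EvaluationResult(item_ref_id="img_002", score=0.2, weight=3.0),
]
scenario_test.upload_external_evaluation_results(
    eval_fn=external_fn,  # an ExternalEvalFunction handle
    results=results,
    model_id="model_abc123",  # hypothetical model ID
)
# overall_metric = (0.8 * 1.0 + 0.2 * 3.0) / (1.0 + 3.0) = 0.35
```

Note that the payload key `unit_test_id` carries the scenario test's `self.id`, presumably a holdover from earlier "unit test" naming.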