diff --git a/project/nanoeval/nanoeval/evaluation.py b/project/nanoeval/nanoeval/evaluation.py index 068b900e..73a9aa6c 100644 --- a/project/nanoeval/nanoeval/evaluation.py +++ b/project/nanoeval/nanoeval/evaluation.py @@ -43,26 +43,16 @@ async def _make_pbar(spec: EvalSpec, n_tasks: int, recorder: RecorderProtocol) - def _create_clean_results(results: list[tuple[Task, Any]]) -> list[tuple[Task, Any]]: - """ - Clean results are defined as the LAST retry_idx for each task. These results are what get - used to compute metrics in the final summary. - """ + """Return only the latest attempt for each task.""" - # Pick the latest retry idx for each task - clean_results: list[tuple[Task, Any]] = [] + latest: dict[tuple[str, int], tuple[Task, Any]] = {} for task, result in results: - found = False - for i, (clean_task, _) in enumerate(clean_results): - if (clean_task.question_id, clean_task.attempt_id) == ( - task.question_id, - task.attempt_id, - ): - found = True - if task.retry_idx > clean_task.retry_idx: - clean_results[i] = (task, result) - break - if not found: - clean_results.append((task, result)) + key = (task.question_id, task.attempt_id) + current = latest.get(key) + if current is None or task.retry_idx > current[0].retry_idx: + latest[key] = (task, result) + + clean_results = list(latest.values()) # Sanity check the deduplication function. seen_tasks: set[tuple[str, int]] = set()