13 changes: 13 additions & 0 deletions README.md
@@ -90,6 +90,19 @@ uv run aif generate config/aif_config.yaml allenai/OLMo-1B-hf
uv run aif validate
```

To log validation results to Opik for automated evaluation, set the following environment variables:

```sh
export OPIK_BASE_URL="..."
export OPIK_PROJECT_NAME="..."

# optional for self-hosted installation
export OPIK_API_KEY="..."

# optionally, specify the dataset name sent to Opik
export DATASET_NAME="education_qna_direct"
```
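As a reference for how these variables are consumed, here is a minimal sketch of the client construction on the Python side. It mirrors the `llm_judge.py` change in this PR; the `opik.Opik` constructor arguments are taken directly from that diff.

```python
import os

import opik

# Build the Opik client from the environment, as the validation code does.
# OPIK_BASE_URL selects the Opik instance; the project name and API key
# are read with .get() and may be absent (e.g. self-hosted installations).
client = opik.Opik(
    host=os.environ['OPIK_BASE_URL'],
    project_name=os.environ.get('OPIK_PROJECT_NAME'),
    api_key=os.environ.get('OPIK_API_KEY'),
)
```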

### Transform Data

```sh
44 changes: 43 additions & 1 deletion aif_gen/dataset/validation/llm_judge.py
@@ -1,9 +1,13 @@
import asyncio
import logging
import os
import random
from collections import defaultdict
from functools import lru_cache
from typing import Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

if TYPE_CHECKING:
    pass

import backoff
import numpy as np
@@ -52,6 +56,18 @@ async def llm_judge_validation(
cache = await AsyncElasticsearchCache.maybe_from_env_var(
    f'CACHE_VALIDATION_{model_name}'
)
dataset_name = os.environ.get('DATASET_NAME', 'unspecified')
source_model_name = os.environ.get('MODEL_NAME', 'unspecified')
opik_base_url = os.environ.get('OPIK_BASE_URL')
opik_client = None
if opik_base_url is not None:
    import opik

    opik_client = opik.Opik(
        host=opik_base_url,
        project_name=os.environ.get('OPIK_PROJECT_NAME'),
        api_key=os.environ.get('OPIK_API_KEY'),
    )

if dry_run:
    logging.info(f'Doing dry-run data validation on a single sample...')
@@ -139,6 +155,32 @@ async def llm_judge_validation(
if score is not None:
    results[dataset_idx][metric_name].append(score)

# Log validation samples to Opik if a client was configured.
for _dataset_idx, (dataset, stats) in enumerate(zip(datasets, results)):
    for _sample_idx, sample in enumerate(dataset.samples):
        responses = [sample.chosen, sample.rejected]
        random.shuffle(responses)

        if opik_client is not None:
Member commented: Nit, but this check can go outside the loop (see the sketch after this diff).

            opik_client.trace(
                name=f'{_sample_idx:05d} of dataset {dataset_name}/{_dataset_idx:02d}',
                input={'prompt': sample.prompt},
                output={
                    'task': str(dataset.task),
                    **{f'response {k}': v for k, v in enumerate(responses)},
                },
                metadata={
                    'chosen': sample.chosen,
                    'rejected': sample.rejected,
                    'responses': responses,
                    'source_model_name': source_model_name,
                    **{
                        _metric_name: metrics[_sample_idx]
                        for _metric_name, metrics in stats.items()
                    },
                },
            )

aggregated_results: List[Optional[Dict[str, float]]] = []
for i, dataset in enumerate(datasets):
    if not len(dataset):
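Following up on the review nit above, here is a minimal sketch of the suggested refactor: the `opik_client is not None` guard is hoisted out of the per-sample loop, so the shuffle-and-log pass is skipped entirely when Opik is not configured. Names and the `trace` call are copied from the diff; this is a sketch of the suggestion, not the merged code.

```python
# Check the client once, outside the loops, instead of per sample.
if opik_client is not None:
    for _dataset_idx, (dataset, stats) in enumerate(zip(datasets, results)):
        for _sample_idx, sample in enumerate(dataset.samples):
            # The shuffle is only needed for logging, so it is also skipped
            # when no client exists; the original order stays recoverable
            # from the 'chosen'/'rejected' metadata fields.
            responses = [sample.chosen, sample.rejected]
            random.shuffle(responses)
            opik_client.trace(
                name=f'{_sample_idx:05d} of dataset {dataset_name}/{_dataset_idx:02d}',
                input={'prompt': sample.prompt},
                output={
                    'task': str(dataset.task),
                    **{f'response {k}': v for k, v in enumerate(responses)},
                },
                metadata={
                    'chosen': sample.chosen,
                    'rejected': sample.rejected,
                    'responses': responses,
                    'source_model_name': source_model_name,
                    **{
                        _metric_name: metrics[_sample_idx]
                        for _metric_name, metrics in stats.items()
                    },
                },
            )
```

The trace payload is unchanged; only the placement of the `None` check differs.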
1 change: 1 addition & 0 deletions pyproject.toml
@@ -23,6 +23,7 @@ dependencies = [
    "nltk>=3.9.1",
    "numpy>=2.0.2",
    "openai>=1.61.1",
    "opik>=1.7.9",
    "pydantic>=2.10.4",
    "pytest-asyncio>=0.25.3",
    "pytest-mock>=3.14.0",