Commit 9ed33ec

add qa evals
1 parent 984753a commit 9ed33ec

3 files changed: +360 -0 lines changed

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
from collections.abc import Iterable

from datasets import Dataset

from ragbits.core.sources.base import Source
from ragbits.evaluate.dataloaders.base import DataLoader
from ragbits.evaluate.pipelines.question_answer import QuestionAnswerData


class QuestionAnswerDataLoader(DataLoader[QuestionAnswerData]):
    """
    Question answer evaluation data loader.

    The source used for this data loader should point to a file that can be loaded by [Hugging Face](https://huggingface.co/docs/datasets/loading#local-and-remote-files).
    """

    def __init__(
        self,
        source: Source,
        *,
        split: str = "data",
        question_key: str = "question",
        answer_key: str = "answer",
        context_key: str = "context",
    ) -> None:
        """
        Initialize the question answer data loader.

        Args:
            source: The source to load the data from.
            split: The split to load the data from.
            question_key: The dataset column name that contains the question.
            answer_key: The dataset column name that contains the answer.
            context_key: The dataset column name that contains the context. Context is optional.
        """
        super().__init__(source=source, split=split, required_keys={question_key, answer_key})
        self.question_key = question_key
        self.answer_key = answer_key
        self.context_key = context_key

    async def map(self, dataset: Dataset) -> Iterable[QuestionAnswerData]:
        """
        Map the dataset to the question answer data.

        Args:
            dataset: The dataset to map.

        Returns:
            The question answer data.
        """
        return [
            QuestionAnswerData(
                question=data.get(self.question_key),
                reference_answer=data.get(self.answer_key),
                reference_context=data.get(self.context_key),
            )
            for data in dataset
        ]
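
For orientation, a minimal usage sketch of the loader added above. It is not part of this commit: the `LocalFileSource` import path, the loader's module path, and the `load()` entry point on the `DataLoader` base class are assumptions, so substitute whatever `Source` implementation and loading call your setup actually provides.

# Hypothetical usage sketch; everything outside this diff (source class, module paths, load()) is assumed.
import asyncio

from ragbits.core.sources.local import LocalFileSource  # assumed concrete Source implementation
from ragbits.evaluate.dataloaders.question_answer import QuestionAnswerDataLoader  # assumed module path


async def main() -> None:
    loader = QuestionAnswerDataLoader(
        source=LocalFileSource(path="qa_eval.jsonl"),  # any file loadable by Hugging Face datasets
        split="data",
        question_key="question",
        answer_key="answer",
        context_key="context",
    )
    qa_data = await loader.load()  # assumed: the base DataLoader fetches the source and calls map()
    for item in qa_data:
        print(item.question, item.reference_answer)


if __name__ == "__main__":
    asyncio.run(main())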
Lines changed: 206 additions & 0 deletions
@@ -0,0 +1,206 @@
import asyncio
from abc import ABC, abstractmethod
from itertools import chain
from typing import Generic, TypeVar

import litellm
from continuous_eval.llm_factory import LLMInterface
from continuous_eval.metrics.base import LLMBasedMetric
from continuous_eval.metrics.generation.text import (
    LLMBasedAnswerCorrectness,
    LLMBasedAnswerRelevance,
    LLMBasedFaithfulness,
    LLMBasedStyleConsistency,
)
from typing_extensions import Self

from ragbits.agents.types import QuestionAnswerPromptOutputT
from ragbits.core.llms.base import LLM
from ragbits.core.llms.litellm import LiteLLM
from ragbits.core.utils.helpers import batched
from ragbits.evaluate.metrics.base import Metric
from ragbits.evaluate.pipelines.question_answer import QuestionAnswerResult

MetricT = TypeVar("MetricT", bound=LLMBasedMetric)


class _MetricLMM(LLMInterface):
    """
    Implementation of the LLM interface required by Relari generative metrics, based on LiteLLM.
    """

    def __init__(
        self, model_name: str, api_base: str | None = None, api_version: str | None = None, api_key: str | None = None
    ) -> None:
        self._model_name = model_name
        self._api_base = api_base
        self._api_version = api_version
        self._api_key = api_key

    def run(self, prompt: dict[str, str], temperature: float = 0, max_tokens: int = 1024) -> str:
        """
        Run the prompt.

        Args:
            prompt: Dict with system_prompt and user_prompt entries.
            temperature: Temperature to use.
            max_tokens: Max tokens to use.
        """
        response = litellm.completion(
            model=self._model_name,
            messages=[
                {"role": "system", "content": prompt["system_prompt"]},
                {"role": "user", "content": prompt["user_prompt"]},
            ],
            temperature=temperature,
            max_tokens=max_tokens,
            base_url=self._api_base,
            api_version=self._api_version,
            api_key=self._api_key,
        )
        return response.choices[0].message.content


class QuestionAnswerMetric(Generic[MetricT], Metric[QuestionAnswerResult], ABC):
    """
    Metric for question answer evaluation based on the Relari backend.
    More details can be found [here](https://docs.relari.ai/category/text-generation).
    """

    metric_cls: type[MetricT]

    def __init__(self, llm: LiteLLM, batch_size: int = 15, weight: float = 1.0) -> None:
        """
        Initialize the question answer metric.

        Args:
            llm: Judge LLM instance.
            batch_size: Batch size for metric computation.
            weight: Metric value weight in the final score, used during optimization.
        """
        super().__init__(weight=weight)
        self.metric = self.metric_cls(
            _MetricLMM(
                model_name=llm.model_name,
                api_base=llm.api_base,
                api_version=llm.api_version,
                api_key=llm.api_key,
            )
        )
        self.batch_size = batch_size

    @classmethod
    def from_config(cls, config: dict) -> Self:
        """
        Create an instance of `QuestionAnswerMetric` from a configuration dictionary.

        Args:
            config: A dictionary containing configuration settings for the metric.

        Returns:
            An instance of the metric class initialized with the provided configuration.
        """
        config["llm"] = LLM.from_config(config["llm"])
        config["batch_size"] = config.get("batch_size", 15)
        config["weight"] = config.get("weight", 1.0)
        return super().from_config(config)

    async def compute(self, results: list[QuestionAnswerResult[QuestionAnswerPromptOutputT]]) -> dict:
        """
        Compute the metric.

        Args:
            results: The evaluation results.

        Returns:
            The computed metric.
        """
        metric_results = chain.from_iterable(
            [
                await asyncio.gather(*[asyncio.to_thread(self._call_metric, result) for result in batch])
                for batch in batched(results, self.batch_size)
            ]
        )
        return self.metric.aggregate(list(metric_results))

    @abstractmethod
    def _call_metric(self, result: QuestionAnswerResult[QuestionAnswerPromptOutputT]) -> dict:
        """
        Call the metric with the proper arguments.
        """


class QuestionAnswerAnswerCorrectness(QuestionAnswerMetric[LLMBasedAnswerCorrectness]):
    """
    Metric checking answer correctness based on LLM.
    More details can be found [here](https://docs.relari.ai/metrics/Generation/LLM-Based/llm_correctness).
    """

    metric_cls: type[LLMBasedAnswerCorrectness] = LLMBasedAnswerCorrectness

    def _call_metric(self, result: QuestionAnswerResult[QuestionAnswerPromptOutputT]) -> dict:
        return self.metric(
            question=result.question,
            answer=(
                result.predicted_result.content
                if isinstance(result.predicted_result.content, str)
                else result.predicted_result.content.answer
            ),
            ground_truth_answers=result.reference_answer,
        )


class QuestionAnswerAnswerFaithfulness(QuestionAnswerMetric[LLMBasedFaithfulness]):
    """
    Metric checking answer faithfulness based on LLM.
    More details can be found [here](https://docs.relari.ai/metrics/Generation/LLM-Based/llm_faithfulness).
    """

    metric_cls: type[LLMBasedFaithfulness] = LLMBasedFaithfulness

    def _call_metric(self, result: QuestionAnswerResult[QuestionAnswerPromptOutputT]) -> dict:
        return self.metric(
            question=result.question,
            answer=(
                result.predicted_result.content
                if isinstance(result.predicted_result.content, str)
                else result.predicted_result.content.answer
            ),
            retrieved_context=result.reference_context,
        )


class QuestionAnswerAnswerRelevance(QuestionAnswerMetric[LLMBasedAnswerRelevance]):
    """
    Metric checking answer relevance based on LLM.
    More details can be found [here](https://docs.relari.ai/metrics/Generation/LLM-Based/llm_relevance).
    """

    metric_cls: type[LLMBasedAnswerRelevance] = LLMBasedAnswerRelevance

    def _call_metric(self, result: QuestionAnswerResult[QuestionAnswerPromptOutputT]) -> dict:
        return self.metric(
            question=result.question,
            answer=(
                result.predicted_result.content
                if isinstance(result.predicted_result.content, str)
                else result.predicted_result.content.answer
            ),
        )


class QuestionAnswerAnswerConsistency(QuestionAnswerMetric[LLMBasedStyleConsistency]):
    """
    Metric checking answer style consistency based on LLM.
    More details can be found [here](https://docs.relari.ai/metrics/Generation/LLM-Based/llm_style).
    """

    metric_cls: type[LLMBasedStyleConsistency] = LLMBasedStyleConsistency

    def _call_metric(self, result: QuestionAnswerResult[QuestionAnswerPromptOutputT]) -> dict:
        return self.metric(
            answer=(
                result.predicted_result.content
                if isinstance(result.predicted_result.content, str)
                else result.predicted_result.content.answer
            ),
            ground_truth_answers=result.reference_answer,
        )
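
A short sketch of how one of the metrics above might be wired to a judge LLM. Only the constructor and `compute()` signatures come from this diff; the metric's module path, the `LiteLLM` constructor arguments, and the pre-existing list of `QuestionAnswerResult` objects are assumptions.

# Hypothetical wiring sketch; module paths and LiteLLM construction outside this diff are assumed.
from ragbits.core.llms.litellm import LiteLLM
from ragbits.evaluate.metrics.question_answer import QuestionAnswerAnswerCorrectness  # assumed module path
from ragbits.evaluate.pipelines.question_answer import QuestionAnswerResult


async def score(results: list[QuestionAnswerResult]) -> dict:
    judge = LiteLLM(model_name="gpt-4o-mini")  # assumed constructor; any LiteLLM-compatible judge model
    metric = QuestionAnswerAnswerCorrectness(llm=judge, batch_size=15, weight=1.0)
    # Results are scored in batches of 15 via threads, then aggregated by the Relari metric.
    return await metric.compute(results)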
Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
import asyncio
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any, Generic

from typing_extensions import Self

from ragbits.agents._main import AgentResult
from ragbits.agents.types import (
    QuestionAnswerAgent,
    QuestionAnswerPromptInput,
    QuestionAnswerPromptOutputT,
)
from ragbits.evaluate.pipelines.base import EvaluationData, EvaluationPipeline, EvaluationResult


class QuestionAnswerData(EvaluationData):
    """
    Represents the evaluation data for question answering.
    """

    question: str
    reference_answer: str
    reference_context: Any | None = None


@dataclass
class QuestionAnswerResult(EvaluationResult, Generic[QuestionAnswerPromptOutputT]):
    """
    Represents the result of a single evaluation.
    """

    question: str
    predicted_result: AgentResult[QuestionAnswerPromptOutputT]
    reference_answer: str
    reference_context: Any | None = None


class QuestionAnswerPipeline(
    EvaluationPipeline[
        QuestionAnswerAgent[QuestionAnswerPromptInput, QuestionAnswerPromptOutputT],
        QuestionAnswerData,
        QuestionAnswerResult,
    ]
):
    """
    Question answer evaluation pipeline.
    """

    @classmethod
    def from_config(cls, config: dict) -> Self:
        """
        Create an instance of `QuestionAnswerPipeline` from a configuration dictionary.

        Args:
            config: A dictionary containing configuration settings for the pipeline.

        Returns:
            An instance of the pipeline class initialized with the provided configuration.
        """
        config["evaluation_target"] = QuestionAnswerAgent.from_config(config)
        return super().from_config(config)

    async def __call__(
        self, data: Iterable[QuestionAnswerData]
    ) -> Iterable[QuestionAnswerResult[QuestionAnswerPromptOutputT]]:
        """
        Run the question answer evaluation pipeline.

        Args:
            data: The evaluation data batch.

        Returns:
            The evaluation result batch.
        """
        rows = list(data)  # materialize so the iterable can be consumed by both the gather and the zip below
        results = await asyncio.gather(
            *[
                self.evaluation_target.run(
                    QuestionAnswerPromptInput(
                        question=row.question,
                        context=row.reference_context,
                    )
                )
                for row in rows
            ]
        )
        return [
            QuestionAnswerResult(
                question=row.question,
                predicted_result=result,
                reference_answer=row.reference_answer,
                reference_context=row.reference_context,
            )
            for row, result in zip(rows, results, strict=True)
        ]
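
Putting the three files together, an end-to-end sketch of how the pipeline might be driven. The `__call__` contract (QuestionAnswerData in, QuestionAnswerResult out) is taken from the diff; constructing the pipeline directly with an `evaluation_target` keyword and the agent itself are assumptions.

# Hypothetical end-to-end sketch; agent construction and the pipeline constructor are assumed.
from ragbits.agents.types import QuestionAnswerAgent
from ragbits.evaluate.pipelines.question_answer import QuestionAnswerData, QuestionAnswerPipeline


async def evaluate(agent: QuestionAnswerAgent) -> None:
    pipeline = QuestionAnswerPipeline(evaluation_target=agent)  # assumed: the base EvaluationPipeline accepts the target
    data = [
        QuestionAnswerData(
            question="What is the capital of France?",
            reference_answer="Paris",
            reference_context="France is a country in Europe; its capital is Paris.",
        )
    ]
    results = await pipeline(data)  # one agent run per row, gathered concurrently
    for result in results:
        print(result.question, "->", result.predicted_result.content)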
