feat: add evals for question answering #577

Merged: 7 commits, May 27, 2025
1 change: 1 addition & 0 deletions packages/ragbits-evaluate/CHANGELOG.md
@@ -2,6 +2,7 @@

## Unreleased

- Add evals for question answering (#577)
- Add support for slicing dataset (#576)
- Separate load and map ops in data loaders (#576)

@@ -0,0 +1,57 @@
from collections.abc import Iterable

from ragbits.core.sources.base import Source
from ragbits.evaluate.dataloaders.base import DataLoader
from ragbits.evaluate.pipelines.question_answer import QuestionAnswerData


class QuestionAnswerDataLoader(DataLoader[QuestionAnswerData]):
"""
Question answer evaluation data loader.

The source used for this data loader should point to a file that can be loaded by [Hugging Face](https://huggingface.co/docs/datasets/loading#local-and-remote-files).
"""

def __init__(
self,
source: Source,
*,
split: str = "data",
question_key: str = "question",
answer_key: str = "answer",
context_key: str = "context",
) -> None:
"""
Initialize the question answer data loader.

Args:
source: The source to load the data from.
split: The split to load the data from.
question_key: The dataset column name that contains the question.
answer_key: The dataset column name that contains the answer.
context_key: The dataset column name that contains the context. Context is optional.
"""
super().__init__(source=source, split=split, required_keys={question_key, answer_key})
self.question_key = question_key
self.answer_key = answer_key
self.context_key = context_key

async def map(self, dataset: Iterable[dict]) -> Iterable[QuestionAnswerData]:
"""
Map the dataset to the question answer data schema.

Args:
dataset: The dataset to map.

Returns:
The question answer data.
"""
return [
QuestionAnswerData(
question=data.get(self.question_key, ""),
reference_answer=data.get(self.answer_key, ""),
reference_context=data.get(self.context_key),
)
for data in dataset
]
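
A minimal usage sketch of the loader above. The `LocalFileSource` class, the module paths, and the `load()` entrypoint are assumptions (inferred from the base `DataLoader` and the load/map split from #576); check the ragbits-evaluate API for the exact names.

```python
# Usage sketch only; names and paths marked as assumed may differ.
import asyncio

from ragbits.core.sources.local import LocalFileSource  # assumed import path and constructor
from ragbits.evaluate.dataloaders.question_answer import QuestionAnswerDataLoader  # assumed module path


async def main() -> None:
    loader = QuestionAnswerDataLoader(
        source=LocalFileSource(path="qa_dataset.json"),  # any Source pointing at an HF-loadable file
        split="data",
        question_key="question",
        answer_key="answer",
        context_key="context",
    )
    raw_rows = await loader.load()        # assumed base-class method (load and map are separate ops, #576)
    dataset = await loader.map(raw_rows)  # -> list[QuestionAnswerData]
    for row in dataset:
        print(row.question, row.reference_answer)


asyncio.run(main())
```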
@@ -0,0 +1,182 @@
import asyncio
from abc import ABC, abstractmethod
from itertools import chain
from typing import Generic, TypeVar

from continuous_eval.llm_factory import LLMInterface
from continuous_eval.metrics.base import LLMBasedMetric
from continuous_eval.metrics.generation.text import (
LLMBasedAnswerCorrectness,
LLMBasedAnswerRelevance,
LLMBasedFaithfulness,
LLMBasedStyleConsistency,
)
from typing_extensions import Self

from ragbits.agents.types import QuestionAnswerPromptOutputT
from ragbits.core.llms.base import LLM
from ragbits.core.utils.helpers import batched
from ragbits.evaluate.metrics.base import Metric
from ragbits.evaluate.pipelines.question_answer import QuestionAnswerResult

MetricT = TypeVar("MetricT", bound=LLMBasedMetric)


class _MetricLMM(LLMInterface):
"""
Implementation of the interface required by Relari generative metrics, based on LiteLLM.
"""

def __init__(self, llm: LLM) -> None:
self._llm = llm

def run(self, prompt: dict[str, str], temperature: float = 0, max_tokens: int = 1024) -> str:
formatted_prompt = [
{"role": "system", "content": prompt["system_prompt"]},
{"role": "user", "content": prompt["user_prompt"]},
]
options = self._llm.options_cls(
temperature=temperature,
max_tokens=max_tokens,
)
return asyncio.run(self._llm.generate(formatted_prompt, options=options))
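
To make the bridge concrete: Relari metrics call `run` with a dict holding `system_prompt` and `user_prompt` keys and expect a plain string completion back. A quick sketch, assuming `llm` is an already-constructed ragbits `LLM` instance:

```python
# Sketch: how a Relari metric exercises the wrapper above.
bridge = _MetricLMM(llm)  # `llm` is any configured ragbits LLM, assumed to exist in scope
graded = bridge.run(
    {
        "system_prompt": "You are a strict grader.",
        "user_prompt": "Question: ...\nAnswer: ...\nGrade the answer.",
    },
    temperature=0,
    max_tokens=256,
)
print(graded)  # raw string completion handed back to the Relari metric
```

Note that `run` is synchronous and blocks on `asyncio.run`; `compute` below offloads each metric call to a thread so batches still run concurrently.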


class QuestionAnswerMetric(Generic[MetricT], Metric[QuestionAnswerResult], ABC):
"""
Metric for question answer evaluation based on Relari backend.
More details can be found [here](https://docs.relari.ai/category/text-generation).
"""

metric_cls: type[MetricT]

def __init__(self, llm: LLM, batch_size: int = 15, weight: float = 1.0) -> None:
"""
Initialize the question answer metric.

Args:
llm: Judge LLM instance.
batch_size: Batch size for metric computation.
weight: Metric value weight in the final score, used during optimization.
"""
super().__init__(weight=weight)
self.metric = self.metric_cls(_MetricLMM(llm))
self.batch_size = batch_size

@classmethod
def from_config(cls, config: dict) -> Self:
"""
Create an instance of `QuestionAnswerMetric` from a configuration dictionary.

Args:
config: A dictionary containing configuration settings for the metric.

Returns:
An instance of the metric class initialized with the provided configuration.
"""
config["llm"] = LLM.from_config(config["llm"])
config["batch_size"] = config.get("batch_size", 15)
config["weight"] = config.get("weight", 1.0)
return super().from_config(config)

async def compute(self, results: list[QuestionAnswerResult[QuestionAnswerPromptOutputT]]) -> dict:
"""
Compute the metric.

Args:
results: The evaluation results.

Returns:
The computed metric.
"""
metric_results = chain.from_iterable(
[
await asyncio.gather(*[asyncio.to_thread(self._call_metric, result) for result in batch])
for batch in batched(results, self.batch_size)
]
)
return self.metric.aggregate(list(metric_results))

@abstractmethod
def _call_metric(self, result: QuestionAnswerResult[QuestionAnswerPromptOutputT]) -> dict:
"""
Call the metric with the proper arguments.
"""


class QuestionAnswerAnswerCorrectness(QuestionAnswerMetric[LLMBasedAnswerCorrectness]):
"""
Metric checking answer correctness based on LLM.
More details can be found [here](https://docs.relari.ai/metrics/Generation/LLM-Based/llm_correctness).
"""

metric_cls: type[LLMBasedAnswerCorrectness] = LLMBasedAnswerCorrectness

def _call_metric(self, result: QuestionAnswerResult[QuestionAnswerPromptOutputT]) -> dict:
return self.metric(
question=result.question,
answer=(
result.predicted_result.content
if isinstance(result.predicted_result.content, str)
else result.predicted_result.content.answer
),
ground_truth_answers=result.reference_answer,
)


class QuestionAnswerAnswerFaithfulness(QuestionAnswerMetric[LLMBasedFaithfulness]):
"""
Metric checking answer faithfulness based on LLM.
More details can be found [here](https://docs.relari.ai/metrics/Generation/LLM-Based/llm_faithfulness).
"""

metric_cls: type[LLMBasedFaithfulness] = LLMBasedFaithfulness

def _call_metric(self, result: QuestionAnswerResult[QuestionAnswerPromptOutputT]) -> dict:
return self.metric(
question=result.question,
answer=(
result.predicted_result.content
if isinstance(result.predicted_result.content, str)
else result.predicted_result.content.answer
),
retrieved_context=result.reference_context,
)


class QuestionAnswerAnswerRelevance(QuestionAnswerMetric[LLMBasedAnswerRelevance]):
"""
Metric checking answer relevance based on LLM.
More details can be found [here](https://docs.relari.ai/metrics/Generation/LLM-Based/llm_relevance).
"""

metric_cls: type[LLMBasedAnswerRelevance] = LLMBasedAnswerRelevance

def _call_metric(self, result: QuestionAnswerResult[QuestionAnswerPromptOutputT]) -> dict:
return self.metric(
question=result.question,
answer=(
result.predicted_result.content
if isinstance(result.predicted_result.content, str)
else result.predicted_result.content.answer
),
)


class QuestionAnswerAnswerConsistency(QuestionAnswerMetric[LLMBasedStyleConsistency]):
"""
Metric checking answer style consistency based on LLM.
More details can be found [here](https://docs.relari.ai/metrics/Generation/LLM-Based/llm_style).
"""

metric_cls: type[LLMBasedStyleConsistency] = LLMBasedStyleConsistency

def _call_metric(self, result: QuestionAnswerResult[QuestionAnswerPromptOutputT]) -> dict:
return self.metric(
answer=(
result.predicted_result.content
if isinstance(result.predicted_result.content, str)
else result.predicted_result.content.answer
),
ground_truth_answers=result.reference_answer,
)
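
A short sketch of scoring a batch of results with one of the metrics above. The `LiteLLM` judge class and its import path are assumptions; `results` stands for the list of `QuestionAnswerResult` objects produced by the pipeline in the next file.

```python
# Sketch: score pipeline results with the correctness metric.
import asyncio

from ragbits.core.llms.litellm import LiteLLM  # assumed import path

judge = LiteLLM(model_name="gpt-4o-mini")  # illustrative judge model
metric = QuestionAnswerAnswerCorrectness(llm=judge, batch_size=8, weight=1.0)

scores = asyncio.run(metric.compute(results))  # aggregated metric values as a dict
print(scores)
```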
@@ -0,0 +1,96 @@
import asyncio
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any, Generic

from typing_extensions import Self

from ragbits.agents._main import AgentResult
from ragbits.agents.types import (
QuestionAnswerAgent,
QuestionAnswerPromptInput,
QuestionAnswerPromptOutputT,
)
from ragbits.core.llms.base import LLMClientOptionsT
from ragbits.evaluate.pipelines.base import EvaluationData, EvaluationPipeline, EvaluationResult


class QuestionAnswerData(EvaluationData):
"""
Represents the evaluation data for question answer.
"""

question: str
reference_answer: str
reference_context: Any | None = None


@dataclass
class QuestionAnswerResult(EvaluationResult, Generic[QuestionAnswerPromptOutputT]):
"""
Represents the result of a single evaluation.
"""

question: str
predicted_result: AgentResult[QuestionAnswerPromptOutputT]
reference_answer: str
reference_context: Any | None = None


class QuestionAnswerPipeline(
EvaluationPipeline[
QuestionAnswerAgent[LLMClientOptionsT, QuestionAnswerPromptInput, QuestionAnswerPromptOutputT],
QuestionAnswerData,
QuestionAnswerResult,
]
):
"""
Question answer evaluation pipeline.
"""

@classmethod
def from_config(cls, config: dict) -> Self:
"""
Create an instance of `QuestionAnswerPipeline` from a configuration dictionary.

Args:
config: A dictionary containing configuration settings for the pipeline.

Returns:
An instance of the pipeline class initialized with the provided configuration.
"""
config["evaluation_target"] = QuestionAnswerAgent.from_config(config)
return super().from_config(config)

async def __call__(
self, data: Iterable[QuestionAnswerData]
) -> Iterable[QuestionAnswerResult[QuestionAnswerPromptOutputT]]:
"""
Run the question answer evaluation pipeline.

Args:
data: The evaluation data batch.

Returns:
The evaluation result batch.
"""
results = await asyncio.gather(
*[
self.evaluation_target.run(
QuestionAnswerPromptInput(
question=row.question,
context=row.reference_context,
)
)
for row in data
]
)
return [
QuestionAnswerResult(
question=row.question,
predicted_result=result,
reference_answer=row.reference_answer,
reference_context=row.reference_context,
)
for row, result in zip(data, results, strict=False)
]
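
Putting the three new pieces together, a hedged end-to-end sketch: load and map a QA dataset, run it through the pipeline, then score the results. The `evaluation_target` constructor argument, the loader's `load()` method, and the metrics module path are assumptions inferred from `from_config` and the #576 changelog entry; `qa_agent` stands for an already-configured `QuestionAnswerAgent`.

```python
# End-to-end sketch; names marked as assumed may differ from the real API.
from ragbits.evaluate.dataloaders.question_answer import QuestionAnswerDataLoader  # assumed module path
from ragbits.evaluate.metrics.question_answer import QuestionAnswerAnswerCorrectness  # assumed module path
from ragbits.evaluate.pipelines.question_answer import QuestionAnswerPipeline


async def evaluate(
    loader: QuestionAnswerDataLoader,
    qa_agent,  # an already-configured QuestionAnswerAgent
    metric: QuestionAnswerAnswerCorrectness,
) -> dict:
    raw_rows = await loader.load()                                 # assumed base-class method (see #576)
    dataset = await loader.map(raw_rows)                           # list[QuestionAnswerData]
    pipeline = QuestionAnswerPipeline(evaluation_target=qa_agent)  # assumed constructor argument
    results = await pipeline(dataset)                              # list[QuestionAnswerResult]
    return await metric.compute(list(results))
```

Run with `asyncio.run(evaluate(loader, qa_agent, metric))`.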
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default.