Skip to content

Commit d3a201f

Browse files
authored
Mh/add qa evals llmfix (#580)
1 parent 9ae0ac2 commit d3a201f

File tree

1 file changed: +17 additions, −26 deletions

packages/ragbits-evaluate/src/ragbits/evaluate/metrics/question_answer.py

Lines changed: 17 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from itertools import chain
44
from typing import Generic, TypeVar
55

6-
import litellm
76
from continuous_eval.llm_factory import LLMInterface
87
from continuous_eval.metrics.base import LLMBasedMetric
98
from continuous_eval.metrics.generation.text import (
@@ -16,7 +15,6 @@
1615

1716
from ragbits.agents.types import QuestionAnswerPromptOutputT
1817
from ragbits.core.llms.base import LLM
19-
from ragbits.core.llms.litellm import LiteLLM
2018
from ragbits.core.utils.helpers import batched
2119
from ragbits.evaluate.metrics.base import Metric
2220
from ragbits.evaluate.pipelines.question_answer import QuestionAnswerResult
@@ -30,12 +28,10 @@ class _MetricLMM(LLMInterface):
3028
"""
3129

3230
def __init__(
33-
self, model_name: str, api_base: str | None = None, api_version: str | None = None, api_key: str | None = None
31+
self,
32+
llm: LLM,
3433
) -> None:
35-
self._model_name = model_name
36-
self._api_base = api_base
37-
self._api_version = api_version
38-
self._api_key = api_key
34+
self._llm = llm
3935

4036
def run(self, prompt: dict[str, str], temperature: float = 0, max_tokens: int = 1024) -> str:
4137
"""
@@ -46,17 +42,19 @@ def run(self, prompt: dict[str, str], temperature: float = 0, max_tokens: int =
4642
temperature: Temperature to use.
4743
max_tokens: Max tokens to use.
4844
"""
49-
response = litellm.completion(
50-
model=self._model_name,
51-
messages=[
52-
{"role": "system", "content": prompt["system_prompt"]},
53-
{"role": "user", "content": prompt["user_prompt"]},
54-
],
55-
base_url=self._api_base,
56-
api_version=self._api_version,
57-
api_key=self._api_key,
45+
response = asyncio.run(
46+
self._llm.generate(
47+
[
48+
{"role": "system", "content": prompt["system_prompt"]},
49+
{"role": "user", "content": prompt["user_prompt"]},
50+
],
51+
options=self._llm.options_cls(
52+
temperature=temperature,
53+
max_tokens=max_tokens,
54+
),
55+
)
5856
)
59-
return response.choices[0].message.content
57+
return response
6058

6159

6260
class QuestionAnswerMetric(Generic[MetricT], Metric[QuestionAnswerResult], ABC):
@@ -67,7 +65,7 @@ class QuestionAnswerMetric(Generic[MetricT], Metric[QuestionAnswerResult], ABC):
6765

6866
metric_cls: type[MetricT]
6967

70-
def __init__(self, llm: LiteLLM, batch_size: int = 15, weight: float = 1.0) -> None:
68+
def __init__(self, llm: LLM, batch_size: int = 15, weight: float = 1.0) -> None:
7169
"""
7270
Initialize the agent metric.
7371
@@ -77,14 +75,7 @@ def __init__(self, llm: LiteLLM, batch_size: int = 15, weight: float = 1.0) -> N
7775
weight: Metric value weight in the final score, used during optimization.
7876
"""
7977
super().__init__(weight=weight)
80-
self.metric = self.metric_cls(
81-
_MetricLMM(
82-
model_name=llm.model_name,
83-
api_base=llm.api_base,
84-
api_version=llm.api_version,
85-
api_key=llm.api_key,
86-
)
87-
)
78+
self.metric = self.metric_cls(_MetricLMM(llm))
8879
self.batch_size = batch_size
8980

9081
@classmethod

0 commit comments

Comments (0)