3
3
from itertools import chain
4
4
from typing import Generic , TypeVar
5
5
6
- import litellm
7
6
from continuous_eval .llm_factory import LLMInterface
8
7
from continuous_eval .metrics .base import LLMBasedMetric
9
8
from continuous_eval .metrics .generation .text import (
16
15
17
16
from ragbits .agents .types import QuestionAnswerPromptOutputT
18
17
from ragbits .core .llms .base import LLM
19
- from ragbits .core .llms .litellm import LiteLLM
20
18
from ragbits .core .utils .helpers import batched
21
19
from ragbits .evaluate .metrics .base import Metric
22
20
from ragbits .evaluate .pipelines .question_answer import QuestionAnswerResult
@@ -30,12 +28,10 @@ class _MetricLMM(LLMInterface):
30
28
"""
31
29
32
30
def __init__(
    self,
    llm: LLM,
) -> None:
    """Adapt a ragbits LLM to the continuous-eval ``LLMInterface`` contract.

    Args:
        llm: Ragbits LLM instance that will serve the metric prompts.
    """
    # Keep a reference to the wrapped model; ``run`` drives it per prompt.
    self._llm = llm
39
35
40
36
def run(self, prompt: dict[str, str], temperature: float = 0, max_tokens: int = 1024) -> str:
    """Run a metric prompt against the wrapped ragbits LLM and return its text reply.

    Args:
        prompt: Mapping with ``"system_prompt"`` and ``"user_prompt"`` entries.
        temperature: Temperature to use.
        max_tokens: Max tokens to use.

    Returns:
        The model's response text.
    """
    conversation = [
        {"role": "system", "content": prompt["system_prompt"]},
        {"role": "user", "content": prompt["user_prompt"]},
    ]
    generation_options = self._llm.options_cls(
        temperature=temperature,
        max_tokens=max_tokens,
    )
    # NOTE(review): asyncio.run raises RuntimeError when an event loop is
    # already running in this thread — confirm this method is only invoked
    # from synchronous (continuous-eval) contexts.
    response = asyncio.run(self._llm.generate(conversation, options=generation_options))
    return response
60
58
61
59
62
60
class QuestionAnswerMetric (Generic [MetricT ], Metric [QuestionAnswerResult ], ABC ):
@@ -67,7 +65,7 @@ class QuestionAnswerMetric(Generic[MetricT], Metric[QuestionAnswerResult], ABC):
67
65
68
66
metric_cls : type [MetricT ]
69
67
70
def __init__(self, llm: LLM, batch_size: int = 15, weight: float = 1.0) -> None:
    """
    Initialize the agent metric.

    Args:
        llm: Ragbits LLM used by the underlying continuous-eval metric.
        batch_size: Number of results evaluated per batch.
        weight: Metric value weight in the final score, used during optimization.
    """
    super().__init__(weight=weight)
    # Wrap the ragbits LLM so continuous-eval's metric class can call it.
    adapter = _MetricLMM(llm)
    self.metric = self.metric_cls(adapter)
    self.batch_size = batch_size
89
80
90
81
@classmethod
0 commit comments