1
+ # Copyright (c) "Neo4j"
2
+ # Neo4j Sweden AB [https://neo4j.com]
3
+ # #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ # #
8
+ # https://www.apache.org/licenses/LICENSE-2.0
9
+ # #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
from __future__ import annotations
2
17
3
18
from typing import Any
6
21
7
22
8
23
class OpenAIEmbeddings (Embedder ):
9
- def __init__ (self , * args : Any , ** kwargs : Any ) -> None :
24
+ def __init__ (self , model : str = "text-embedding-ada-002" ) -> None :
10
25
try :
11
26
import openai
12
27
except ImportError :
@@ -15,10 +30,11 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
15
30
"Please install it with `pip install openai`."
16
31
)
17
32
18
- self .model = openai .OpenAI (* args , ** kwargs )
33
+ self .openai_model = openai .OpenAI ()
34
+ self .model = model
19
35
20
- def embed_query (
21
- self , text : str , model : str = "text-embedding-ada-002" , ** kwargs : Any
22
- ) -> list [ float ]:
23
- response = self . model . embeddings . create ( input = text , model = model , ** kwargs )
36
+ def embed_query (self , text : str , ** kwargs : Any ) -> list [ float ]:
37
+ response = self . openai_model . embeddings . create (
38
+ input = text , model = self . model , ** kwargs
39
+ )
24
40
return response .data [0 ].embedding
0 commit comments