@@ -38,18 +38,22 @@ def _encode(
38
38
class HuggingFaceSentenceEmbedder (TransformerSentenceEmbedder ):
39
39
def __init__ (self , config_string : str , batch_size : int = 128 ):
40
40
super ().__init__ (config_string , batch_size )
41
+ self .config_string = config_string
41
42
42
43
@staticmethod
43
44
def load (embedder : dict ) -> "HuggingFaceSentenceEmbedder" :
45
+ if os .path .exists (embedder ["config_string" ]):
46
+ config_string = embedder ["config_string" ]
47
+ else :
48
+ config_string = request_util .get_model_path (embedder ["config_string" ])
44
49
return HuggingFaceSentenceEmbedder (
45
- config_string = request_util .get_model_path (embedder ["config_string" ]),
46
- batch_size = embedder ["batch_size" ],
50
+ config_string = config_string , batch_size = embedder ["batch_size" ]
47
51
)
48
52
49
53
def to_json (self ) -> dict :
50
54
return {
51
55
"cls" : "HuggingFaceSentenceEmbedder" ,
52
- "config_string" : self .model . model_card_data . base_model ,
56
+ "config_string" : self .config_string ,
53
57
"batch_size" : self .batch_size ,
54
58
}
55
59
@@ -239,7 +243,9 @@ def _encode(
239
243
self , documents : List [Union [str , Doc ]], fit_model : bool
240
244
) -> Generator [List [List [float ]], None , None ]:
241
245
for documents_batch in util .batch (documents , self .batch_size ):
242
- documents_batch = [self ._trim_length (doc .replace ("\n " , " " )) for doc in documents_batch ]
246
+ documents_batch = [
247
+ self ._trim_length (doc .replace ("\n " , " " )) for doc in documents_batch
248
+ ]
243
249
try :
244
250
response = self .openai_client .embeddings .create (
245
251
input = documents_batch , model = self .model_name
@@ -270,11 +276,13 @@ def dump(self, project_id: str, embedding_id: str) -> None:
270
276
export_file .parent .mkdir (parents = True , exist_ok = True )
271
277
util .write_json (self .to_json (), export_file , indent = 2 )
272
278
273
- def _trim_length (self , text : str , max_length : int = 512 ) -> str :
279
+ def _trim_length (self , text : str , max_length : int = 512 ) -> str :
274
280
tokens = self ._auto_tokenizer (
275
281
text ,
276
282
truncation = True ,
277
283
max_length = max_length ,
278
- return_tensors = None # No tensors needed for just truncating
284
+ return_tensors = None , # No tensors needed for just truncating
285
+ )
286
+ return self ._auto_tokenizer .decode (
287
+ tokens ["input_ids" ], skip_special_tokens = True
279
288
)
280
- return self ._auto_tokenizer .decode (tokens ["input_ids" ], skip_special_tokens = True )
0 commit comments