Commit 42912b3

Merge pull request #163 from code-kern-ai/hotfix-embedder-rework
fix: huggingface config dump
2 parents 79d909b + fd73dc6 commit 42912b3

File tree

1 file changed: +15 −7 lines changed

src/embedders/classification/contextual.py

Lines changed: 15 additions & 7 deletions
@@ -38,18 +38,22 @@ def _encode(
 class HuggingFaceSentenceEmbedder(TransformerSentenceEmbedder):
     def __init__(self, config_string: str, batch_size: int = 128):
         super().__init__(config_string, batch_size)
+        self.config_string = config_string

     @staticmethod
     def load(embedder: dict) -> "HuggingFaceSentenceEmbedder":
+        if os.path.exists(embedder["config_string"]):
+            config_string = embedder["config_string"]
+        else:
+            config_string = request_util.get_model_path(embedder["config_string"])
         return HuggingFaceSentenceEmbedder(
-            config_string=request_util.get_model_path(embedder["config_string"]),
-            batch_size=embedder["batch_size"],
+            config_string=config_string, batch_size=embedder["batch_size"]
         )

     def to_json(self) -> dict:
         return {
             "cls": "HuggingFaceSentenceEmbedder",
-            "config_string": self.model.model_card_data.base_model,
+            "config_string": self.config_string,
             "batch_size": self.batch_size,
         }

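For context, a runnable, self-contained sketch of the dump/load contract this hunk repairs: the original config string is now kept on the instance, so to_json() no longer depends on model.model_card_data.base_model (which can be unset), and load() reuses an already-local path instead of always re-resolving. MiniEmbedder and resolve_model_path are illustrative stand-ins, not the repository's API; the real code uses HuggingFaceSentenceEmbedder and request_util.get_model_path.

import os


def resolve_model_path(config_string: str) -> str:
    # Hypothetical stand-in for request_util.get_model_path(): pretend the
    # model name maps to a download directory under /models.
    return os.path.join("/models", config_string.replace("/", "--"))


class MiniEmbedder:
    """Minimal stand-in for HuggingFaceSentenceEmbedder's dump/load contract."""

    def __init__(self, config_string: str, batch_size: int = 128):
        # The fix: remember the config string passed in, so dumping does not
        # rely on model card metadata that may be missing.
        self.config_string = config_string
        self.batch_size = batch_size

    @staticmethod
    def load(embedder: dict) -> "MiniEmbedder":
        # Patched behavior: reuse the dumped value when it is already a local
        # path, and only resolve it to a model path otherwise.
        config_string = embedder["config_string"]
        if not os.path.exists(config_string):
            config_string = resolve_model_path(config_string)
        return MiniEmbedder(config_string, embedder["batch_size"])

    def to_json(self) -> dict:
        return {
            "cls": "MiniEmbedder",
            "config_string": self.config_string,
            "batch_size": self.batch_size,
        }


# An existing local path round-trips unchanged ("/tmp" assumed to exist here).
restored = MiniEmbedder.load(
    {"cls": "MiniEmbedder", "config_string": "/tmp", "batch_size": 64}
)
assert restored.to_json()["config_string"] == "/tmp"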
@@ -239,7 +243,9 @@ def _encode(
         self, documents: List[Union[str, Doc]], fit_model: bool
     ) -> Generator[List[List[float]], None, None]:
         for documents_batch in util.batch(documents, self.batch_size):
-            documents_batch = [self._trim_length(doc.replace("\n", " ")) for doc in documents_batch]
+            documents_batch = [
+                self._trim_length(doc.replace("\n", " ")) for doc in documents_batch
+            ]
             try:
                 response = self.openai_client.embeddings.create(
                     input=documents_batch, model=self.model_name
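This hunk is formatting only. For readers tracing the surrounding logic, a small sketch of the batch-and-embed pattern it reformats, assuming the openai>=1.0 client that self.openai_client appears to wrap; the model name and batch size below are illustrative:

from typing import Generator, List

from openai import OpenAI  # assumes the openai>=1.0 client

client = OpenAI()  # reads OPENAI_API_KEY from the environment


def embed_in_batches(
    documents: List[str],
    batch_size: int = 128,
    model_name: str = "text-embedding-ada-002",  # illustrative model
) -> Generator[List[List[float]], None, None]:
    # Mirrors _encode: newline-stripped batches go to the embeddings endpoint
    # one at a time, and each batch of vectors is yielded as it arrives.
    for start in range(0, len(documents), batch_size):
        batch = [
            doc.replace("\n", " ") for doc in documents[start : start + batch_size]
        ]
        response = client.embeddings.create(input=batch, model=model_name)
        yield [item.embedding for item in response.data]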
@@ -270,11 +276,13 @@ def dump(self, project_id: str, embedding_id: str) -> None:
         export_file.parent.mkdir(parents=True, exist_ok=True)
         util.write_json(self.to_json(), export_file, indent=2)

-    def _trim_length(self, text: str, max_length: int=512) -> str:
+    def _trim_length(self, text: str, max_length: int = 512) -> str:
         tokens = self._auto_tokenizer(
             text,
             truncation=True,
             max_length=max_length,
-            return_tensors=None # No tensors needed for just truncating
+            return_tensors=None,  # No tensors needed for just truncating
+        )
+        return self._auto_tokenizer.decode(
+            tokens["input_ids"], skip_special_tokens=True
         )
-        return self._auto_tokenizer.decode(tokens["input_ids"], skip_special_tokens=True)
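For reference, the corrected _trim_length logic as a standalone function, assuming the _auto_tokenizer attribute is a transformers AutoTokenizer; the checkpoint name is illustrative:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # illustrative checkpoint


def trim_length(text: str, max_length: int = 512) -> str:
    # Tokenize with truncation so the text fits the model's context window;
    # return_tensors=None keeps plain Python lists since only the ids are used.
    tokens = tokenizer(
        text,
        truncation=True,
        max_length=max_length,
        return_tensors=None,
    )
    # Decode back to a string, dropping special tokens such as [CLS]/[SEP].
    return tokenizer.decode(tokens["input_ids"], skip_special_tokens=True)


print(trim_length("word " * 1000, max_length=16))  # only the leading words survive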
