Skip to content

Commit ad7f714

Browse files
hfrunner.classify should return list[list[float]] not list[str] (#29671)
Signed-off-by: Chukwuma Nwaugha <nwaughac@gmail.com>
1 parent f4341f4 commit ad7f714

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

tests/conftest.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -459,14 +459,17 @@ def get_prompt_embeddings(self, prompts: list[str]) -> list[torch.Tensor]:
459459
embeddings.append(embedding)
460460
return embeddings
461461

462-
def classify(self, prompts: list[str]) -> list[str]:
462+
def classify(self, prompts: list[str]) -> list[list[float]]:
463463
# output is final logits
464464
all_inputs = self.get_inputs(prompts)
465-
outputs = []
465+
outputs: list[list[float]] = []
466466
problem_type = getattr(self.config, "problem_type", "")
467467

468468
for inputs in all_inputs:
469469
output = self.model(**self.wrap_device(inputs))
470+
471+
assert isinstance(output.logits, torch.Tensor)
472+
470473
if problem_type == "regression":
471474
logits = output.logits[0].tolist()
472475
elif problem_type == "multi_label_classification":

0 commit comments

Comments
 (0)