Skip to content

Commit 72fc92c

Browse files
meilame-tayebjeemicedre
authored andcommitted
feat: Enforce right values in y_train to avoid out of index error
np.max(y_train) == len(np.unique(y_train))-1 should be True Solves #53 and #54
1 parent 70ddd64 commit 72fc92c

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

torchFastText/torchFastText.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,12 @@ def build(
311311
self.num_classes = len(
312312
np.unique(y_train)
313313
) # Be sure that y_train contains all the classes !
314+
315+
if np.max(y_train) >= self.num_classes:
316+
raise ValueError(
317+
f"y_train must contain values between 0 and {self.num_classes - 1}. Make sure that np.max(y_train) == len(np.unique(y_train))-1."
318+
)
319+
314320
else:
315321
if self.num_classes is None:
316322
raise ValueError(

0 commit comments

Comments
 (0)