@@ -516,6 +516,7 @@ def train_supervised(*kargs, **kwargs):
516516 'model' : "supervised"
517517 })
518518
519+ callback = kwargs .pop ("callback" , None )
519520 arg_names = ['input' , 'lr' , 'dim' , 'ws' , 'epoch' , 'minCount' ,
520521 'minCountLabel' , 'minn' , 'maxn' , 'neg' , 'wordNgrams' , 'loss' , 'bucket' ,
521522 'thread' , 'lrUpdateRate' , 't' , 'label' , 'verbose' , 'pretrainedVectors' ,
@@ -525,7 +526,10 @@ def train_supervised(*kargs, **kwargs):
525526 supervised_default )
526527 a = _build_args (args , manually_set_args )
527528 ft = _FastText (args = a )
528- fasttext .train (ft .f , a )
529+ if callback :
530+ fasttext .train_with_callback (ft .f , a , callback )
531+ else :
532+ fasttext .train (ft .f , a )
529533 ft .set_args (ft .f .getArgs ())
530534 return ft
531535
@@ -544,13 +548,18 @@ def train_unsupervised(*kargs, **kwargs):
544548 dataset pulled by the example script word-vector-example.sh, which is
545549 part of the fastText repository.
546550 """
551+ callback = kwargs .pop ("callback" , None )
547552 arg_names = ['input' , 'model' , 'lr' , 'dim' , 'ws' , 'epoch' , 'minCount' ,
548553 'minCountLabel' , 'minn' , 'maxn' , 'neg' , 'wordNgrams' , 'loss' , 'bucket' ,
549554 'thread' , 'lrUpdateRate' , 't' , 'label' , 'verbose' , 'pretrainedVectors' ]
550555 args , manually_set_args = read_args (kargs , kwargs , arg_names ,
551556 unsupervised_default )
552557 a = _build_args (args , manually_set_args )
553558 ft = _FastText (args = a )
559+ if callback :
560+ fasttext .train_with_callback (ft .f , a , callback )
561+ else :
562+ fasttext .train (ft .f , a )
554563 fasttext .train (ft .f , a )
555564 ft .set_args (ft .f .getArgs ())
556565 return ft
0 commit comments