1313
1414
1515class ThaiTextAugmenter :
16- def __init__ (self ,) -> None :
17- from transformers import (AutoTokenizer ,
18- AutoModelForMaskedLM ,
19- pipeline ,)
16+ def __init__ (self ) -> None :
17+ from transformers import (
18+ AutoTokenizer ,
19+ AutoModelForMaskedLM ,
20+ pipeline ,
21+ )
22+
2023 self .tokenizer = AutoTokenizer .from_pretrained (_MODEL_NAME )
21- self .model_for_masked_lm = AutoModelForMaskedLM .from_pretrained (_MODEL_NAME )
22- self .model = pipeline ("fill-mask" , tokenizer = self .tokenizer , model = self .model_for_masked_lm )
24+ self .model_for_masked_lm = AutoModelForMaskedLM .from_pretrained (
25+ _MODEL_NAME
26+ )
27+ self .model = pipeline (
28+ "fill-mask" ,
29+ tokenizer = self .tokenizer ,
30+ model = self .model_for_masked_lm ,
31+ )
2332 self .processor = ThaiTextProcessor ()
2433
25- def generate (self ,
26- sample_text : str ,
27- word_rank : int ,
28- max_length : int = 3 ,
29- sample : bool = False
30- ) -> str :
34+ def generate (
35+ self ,
36+ sample_text : str ,
37+ word_rank : int ,
38+ max_length : int = 3 ,
39+ sample : bool = False ,
40+ ) -> str :
3141 sample_txt = sample_text
3242 final_text = ""
3343
@@ -45,11 +55,9 @@ def generate(self,
4555
4656 return gen_txt
4757
48- def augment (self ,
49- text : str ,
50- num_augs : int = 3 ,
51- sample : bool = False
52- ) -> List [str ]:
58+ def augment (
59+ self , text : str , num_augs : int = 3 , sample : bool = False
60+ ) -> List [str ]:
5361 """
5462 Text augmentation from PhayaThaiBERT
5563
@@ -84,11 +92,14 @@ def augment(self,
8492 if num_augs <= MAX_NUM_AUGS :
8593 for rank in range (num_augs ):
8694 gen_text = self .generate (text , rank , sample = sample )
87- processed_text = re .sub ("<_>" , " " , self .processor .preprocess (gen_text ))
95+ processed_text = re .sub (
96+ "<_>" , " " , self .processor .preprocess (gen_text )
97+ )
8898 augment_list .append (processed_text )
99+ else :
100+ raise ValueError (
101+ f"augmentation of more than { num_augs } is exceeded \
102+ the default limit: { MAX_NUM_AUGS } "
103+ )
89104
90- return augment_list
91-
92- raise ValueError (
93- f"augmentation of more than { num_augs } is exceeded the default limit: { MAX_NUM_AUGS } "
94- )
105+ return augment_list
0 commit comments