Skip to content

Commit fa22d7d

Browse files
authored
Update __init__.py
1 parent c2bc010 commit fa22d7d

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

pythaiasr/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ def __init__(self, model: str="airesearch/wav2vec2-large-xlsr-53-th", lm: bool=F
2525
"wannaphong/wav2vec2-large-xlsr-53-th-cv8-deepcut"
2626
]
2727
assert self.model_name in self.support_model
28-
if not lm:
28+
self.lm =lm
29+
if not self.lm:
2930
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
3031
self.processor = Wav2Vec2Processor.from_pretrained(self.model_name)
3132
self.model = Wav2Vec2ForCTC.from_pretrained(self.model_name)
@@ -64,7 +65,7 @@ def __call__(self, file: str) -> str:
6465
input_dict = self.processor(a["input_values"][0], return_tensors="pt", padding=True)
6566
logits = self.model(input_dict.input_values).logits
6667
pred_ids = torch.argmax(logits, dim=-1)[0]
67-
if self.model_name == "airesearch/wav2vec2-large-xlsr-53-th":
68+
if self.model_name == "airesearch/wav2vec2-large-xlsr-53-th" or self.lm:
6869
txt = self.processor.decode(pred_ids)
6970
else:
7071
txt = self.processor.batch_decode(logits.detach().numpy()).text[0]

0 commit comments

Comments
 (0)