|
1 | 1 | # -*- coding: utf-8 -*- |
2 | 2 | import torch |
3 | | -from transformers import AutoProcessor, AutoModelForCTC |
4 | 3 | import torchaudio |
5 | 4 | import numpy as np |
6 | 5 |
|
7 | 6 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
8 | 7 |
|
9 | 8 |
|
10 | 9 | class ASR: |
11 | | - def __init__(self, model: str="airesearch/wav2vec2-large-xlsr-53-th", device=None) -> None: |
| 10 | + def __init__(self, model: str="airesearch/wav2vec2-large-xlsr-53-th", lm: bool=False, device: str=None) -> None: |
12 | 11 | """ |
13 | | - :param str model: The ASR model |
| 12 | + :param str model: The ASR model name |
| 13 | + :param bool lm: Use the language model when decoding (default: False; does not apply to the *airesearch/wav2vec2-large-xlsr-53-th* model) |
14 | 14 | :param str device: device name passed to torch.device (e.g. 'cuda' or 'cpu'); optional |
15 | 15 |
|
16 | 16 | **Options for model** |
17 | 17 | * *airesearch/wav2vec2-large-xlsr-53-th* (default) - AI RESEARCH - PyThaiNLP model |
18 | 18 | * *wannaphong/wav2vec2-large-xlsr-53-th-cv8-newmm* - Thai Wav2Vec2 with CommonVoice V8 (newmm tokenizer) + language model |
19 | 19 | * *wannaphong/wav2vec2-large-xlsr-53-th-cv8-deepcut* - Thai Wav2Vec2 with CommonVoice V8 (deepcut tokenizer) + language model |
20 | 20 | """ |
21 | | - self.processor = AutoProcessor.from_pretrained(model) |
22 | 21 | self.model_name = model |
23 | | - self.model = AutoModelForCTC.from_pretrained(model) |
| 22 | + self.support_model =[ |
| 23 | + "airesearch/wav2vec2-large-xlsr-53-th", |
| 24 | + "wannaphong/wav2vec2-large-xlsr-53-th-cv8-newmm", |
| 25 | + "wannaphong/wav2vec2-large-xlsr-53-th-cv8-deepcut" |
| 26 | + ] |
| 27 | + assert self.model_name in self.support_model |
| 28 | + self.lm =lm |
| 29 | + if not self.lm: |
| 30 | + from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor |
| 31 | + self.processor = Wav2Vec2Processor.from_pretrained(self.model_name) |
| 32 | + self.model = Wav2Vec2ForCTC.from_pretrained(self.model_name) |
| 33 | + else: |
| 34 | + from transformers import AutoProcessor, AutoModelForCTC |
| 35 | + self.processor = AutoProcessor.from_pretrained(self.model_name) |
| 36 | + self.model = AutoModelForCTC.from_pretrained(self.model_name) |
24 | 37 | if device!=None: |
25 | 38 | self.device = torch.device(device) |
26 | 39 |
|
@@ -54,29 +67,33 @@ def __call__(self, file: str) -> str: |
54 | 67 | pred_ids = torch.argmax(logits, dim=-1)[0] |
55 | 68 | if self.model_name == "airesearch/wav2vec2-large-xlsr-53-th": |
56 | 69 | txt = self.processor.decode(pred_ids) |
57 | | - else: |
| 70 | + elif self.lm: |
58 | 71 | txt = self.processor.batch_decode(logits.detach().numpy()).text[0] |
| 72 | + else: |
| 73 | + txt = self.processor.decode(pred_ids) |
59 | 74 | return txt |
60 | 75 |
|
61 | 76 | _model_name = "airesearch/wav2vec2-large-xlsr-53-th" |
62 | 77 | _model = None |
63 | 78 |
|
64 | 79 |
|
65 | | -def asr(file: str, model: str = _model_name) -> str: |
| 80 | +def asr(file: str, model: str = _model_name, lm: bool=False, device: str=None) -> str: |
66 | 81 | """ |
67 | 82 | :param str file: path of sound file |
68 | | - :param str model: The ASR model |
| 83 | + :param str model: The ASR model name |
| 84 | + :param bool lm: Use the language model when decoding (default: False; does not apply to the *airesearch/wav2vec2-large-xlsr-53-th* model) |
| 85 | + :param str device: device name forwarded to the ASR class (e.g. 'cuda' or 'cpu'); optional |
69 | 86 | :return: Thai text from ASR |
70 | 87 | :rtype: str |
71 | 88 |
|
72 | 89 | **Options for model** |
73 | 90 | * *airesearch/wav2vec2-large-xlsr-53-th* (default) - AI RESEARCH - PyThaiNLP model |
74 | | - * *wannaphong/wav2vec2-large-xlsr-53-th-cv8-newmm* - Thai Wav2Vec2 with CommonVoice V8 (newmm tokenizer) + language model |
75 | | - * *wannaphong/wav2vec2-large-xlsr-53-th-cv8-deepcut* - Thai Wav2Vec2 with CommonVoice V8 (deepcut tokenizer) + language model |
| 91 | + * *wannaphong/wav2vec2-large-xlsr-53-th-cv8-newmm* - Thai Wav2Vec2 with CommonVoice V8 (newmm tokenizer) (+ language model) |
| 92 | + * *wannaphong/wav2vec2-large-xlsr-53-th-cv8-deepcut* - Thai Wav2Vec2 with CommonVoice V8 (deepcut tokenizer) (+ language model) |
76 | 93 | """ |
77 | 94 | global _model, _model_name |
78 | 95 | if model!=_model or _model == None: |
79 | | - _model = ASR(model) |
| 96 | + _model = ASR(model, lm=lm, device=device) |
80 | 97 | _model_name = model |
81 | 98 |
|
82 | 99 | return _model(file=file) |
0 commit comments