|
1 | 1 | # -*- coding: utf-8 -*- |
2 | 2 | import torch |
3 | | -from transformers import AutoProcessor, AutoModelForCTC |
4 | 3 | import torchaudio |
5 | 4 | import numpy as np |
6 | 5 |
|
7 | 6 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
8 | 7 |
|
9 | 8 |
|
10 | 9 | class ASR: |
11 | | - def __init__(self, model: str="airesearch/wav2vec2-large-xlsr-53-th", device=None) -> None: |
| 10 | + def __init__(self, model: str="airesearch/wav2vec2-large-xlsr-53-th", lm=False, device=None) -> None: |
12 | 11 | """ |
13 | | - :param str model: The ASR model |
| 12 | + :param str model: The ASR model name |
| 13 | + :param bool lm: Use language model (default is False and except *airesearch/wav2vec2-large-xlsr-53-th* model) |
14 | 14 | :param str device: device |
15 | 15 |
|
16 | 16 | **Options for model** |
17 | 17 | * *airesearch/wav2vec2-large-xlsr-53-th* (default) - AI RESEARCH - PyThaiNLP model |
18 | 18 | * *wannaphong/wav2vec2-large-xlsr-53-th-cv8-newmm* - Thai Wav2Vec2 with CommonVoice V8 (newmm tokenizer) + language model |
19 | 19 | * *wannaphong/wav2vec2-large-xlsr-53-th-cv8-deepcut* - Thai Wav2Vec2 with CommonVoice V8 (deepcut tokenizer) + language model |
20 | 20 | """ |
21 | | - self.processor = AutoProcessor.from_pretrained(model) |
22 | 21 | self.model_name = model |
23 | | - self.model = AutoModelForCTC.from_pretrained(model) |
| 22 | + self.support_model =[ |
| 23 | + "airesearch/wav2vec2-large-xlsr-53-th", |
| 24 | + "wannaphong/wav2vec2-large-xlsr-53-th-cv8-newmm", |
| 25 | + "wannaphong/wav2vec2-large-xlsr-53-th-cv8-deepcut" |
| 26 | + ] |
| 27 | + assert self.model_name in self.support_model |
| 28 | + if not lm: |
| 29 | + from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor |
| 30 | + self.processor = Wav2Vec2Processor.from_pretrained(self.model_name) |
| 31 | + self.model = Wav2Vec2ForCTC.from_pretrained(self.model_name) |
| 32 | + else: |
| 33 | + from transformers import AutoProcessor, AutoModelForCTC |
| 34 | + self.processor = AutoProcessor.from_pretrained(self.model_name) |
| 35 | + self.model = AutoModelForCTC.from_pretrained(self.model_name) |
24 | 36 | if device!=None: |
25 | 37 | self.device = torch.device(device) |
26 | 38 |
|
|
0 commit comments