Skip to content

Commit 20c876e

Browse files
authored
Update __init__.py
1 parent 4071489 commit 20c876e

File tree

1 file changed

+17
-5
lines changed

1 file changed

+17
-5
lines changed

pythaiasr/__init__.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,38 @@
11
# -*- coding: utf-8 -*-
22
import torch
3-
from transformers import AutoProcessor, AutoModelForCTC
43
import torchaudio
54
import numpy as np
65

76
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
87

98

109
class ASR:
11-
def __init__(self, model: str="airesearch/wav2vec2-large-xlsr-53-th", device=None) -> None:
10+
def __init__(self, model: str="airesearch/wav2vec2-large-xlsr-53-th", lm=False, device=None) -> None:
1211
"""
13-
:param str model: The ASR model
12+
:param str model: The ASR model name
13+
:param bool lm: Use language model (default is False and except *airesearch/wav2vec2-large-xlsr-53-th* model)
1414
:param str device: device
1515
1616
**Options for model**
1717
* *airesearch/wav2vec2-large-xlsr-53-th* (default) - AI RESEARCH - PyThaiNLP model
1818
* *wannaphong/wav2vec2-large-xlsr-53-th-cv8-newmm* - Thai Wav2Vec2 with CommonVoice V8 (newmm tokenizer) + language model
1919
* *wannaphong/wav2vec2-large-xlsr-53-th-cv8-deepcut* - Thai Wav2Vec2 with CommonVoice V8 (deepcut tokenizer) + language model
2020
"""
21-
self.processor = AutoProcessor.from_pretrained(model)
2221
self.model_name = model
23-
self.model = AutoModelForCTC.from_pretrained(model)
22+
self.support_model =[
23+
"airesearch/wav2vec2-large-xlsr-53-th",
24+
"wannaphong/wav2vec2-large-xlsr-53-th-cv8-newmm",
25+
"wannaphong/wav2vec2-large-xlsr-53-th-cv8-deepcut"
26+
]
27+
assert self.model_name in self.support_model
28+
if not lm:
29+
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
30+
self.processor = Wav2Vec2Processor.from_pretrained(self.model_name)
31+
self.model = Wav2Vec2ForCTC.from_pretrained(self.model_name)
32+
else:
33+
from transformers import AutoProcessor, AutoModelForCTC
34+
self.processor = AutoProcessor.from_pretrained(self.model_name)
35+
self.model = AutoModelForCTC.from_pretrained(self.model_name)
2436
if device!=None:
2537
self.device = torch.device(device)
2638

0 commit comments

Comments
 (0)