Skip to content

Commit de40473

Browse files
authored
Merge pull request #10 from wannaphong/add-wav2vec2-mode
PyThaiASR v1.2.0
2 parents 4071489 + c0ab28c commit de40473

File tree

3 files changed

+36
-17
lines changed

3 files changed

+36
-17
lines changed

README.md

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ pip install pythaiasr
1919
```
2020

2121
**For Wav2Vec2 with language model:**
22-
if you want to use wannaphong/wav2vec2-large-xlsr-53-th-cv8-* model, you needs to install by the step.
22+
If you want to use a wannaphong/wav2vec2-large-xlsr-53-th-cv8-* model with a language model, you need to install the extra dependencies with the following step.
2323

2424
```sh
2525
pip install pythaiasr[lm]
@@ -37,17 +37,19 @@ print(asr(file))
3737
### API
3838

3939
```python
40-
asr(file: str, model: str = "airesearch/wav2vec2-large-xlsr-53-th")
40+
asr(file: str, model: str = _model_name, lm: bool=False, device: str=None)
4141
```
4242

4343
- file: path of sound file
4444
- model: The ASR model
45+
- lm: use a language model for decoding (not supported by the *airesearch/wav2vec2-large-xlsr-53-th* model)
46+
- device: device to run the model on, e.g. `"cpu"` or `"cuda"`
4547
- return: thai text from ASR
4648

4749
**Options for model**
4850
- *airesearch/wav2vec2-large-xlsr-53-th* (default) - AI RESEARCH - PyThaiNLP model
49-
- *wannaphong/wav2vec2-large-xlsr-53-th-cv8-newmm* - Thai Wav2Vec2 with CommonVoice V8 (newmm tokenizer) + language model
50-
- *wannaphong/wav2vec2-large-xlsr-53-th-cv8-deepcut* - Thai Wav2Vec2 with CommonVoice V8 (deepcut tokenizer) + language model
51+
- *wannaphong/wav2vec2-large-xlsr-53-th-cv8-newmm* - Thai Wav2Vec2 with CommonVoice V8 (newmm tokenizer)
52+
- *wannaphong/wav2vec2-large-xlsr-53-th-cv8-deepcut* - Thai Wav2Vec2 with CommonVoice V8 (deepcut tokenizer)
5153

5254
You can read about models from the list:
5355

pythaiasr/__init__.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,39 @@
11
# -*- coding: utf-8 -*-
22
import torch
3-
from transformers import AutoProcessor, AutoModelForCTC
43
import torchaudio
54
import numpy as np
65

76
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
87

98

109
class ASR:
11-
def __init__(self, model: str="airesearch/wav2vec2-large-xlsr-53-th", device=None) -> None:
10+
def __init__(self, model: str="airesearch/wav2vec2-large-xlsr-53-th", lm: bool=False, device: str=None) -> None:
1211
"""
13-
:param str model: The ASR model
12+
:param str model: The ASR model name
13+
:param bool lm: Use language model (default is False and except *airesearch/wav2vec2-large-xlsr-53-th* model)
1414
:param str device: device
1515
1616
**Options for model**
1717
* *airesearch/wav2vec2-large-xlsr-53-th* (default) - AI RESEARCH - PyThaiNLP model
1818
* *wannaphong/wav2vec2-large-xlsr-53-th-cv8-newmm* - Thai Wav2Vec2 with CommonVoice V8 (newmm tokenizer) + language model
1919
* *wannaphong/wav2vec2-large-xlsr-53-th-cv8-deepcut* - Thai Wav2Vec2 with CommonVoice V8 (deepcut tokenizer) + language model
2020
"""
21-
self.processor = AutoProcessor.from_pretrained(model)
2221
self.model_name = model
23-
self.model = AutoModelForCTC.from_pretrained(model)
22+
self.support_model =[
23+
"airesearch/wav2vec2-large-xlsr-53-th",
24+
"wannaphong/wav2vec2-large-xlsr-53-th-cv8-newmm",
25+
"wannaphong/wav2vec2-large-xlsr-53-th-cv8-deepcut"
26+
]
27+
assert self.model_name in self.support_model
28+
self.lm =lm
29+
if not self.lm:
30+
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
31+
self.processor = Wav2Vec2Processor.from_pretrained(self.model_name)
32+
self.model = Wav2Vec2ForCTC.from_pretrained(self.model_name)
33+
else:
34+
from transformers import AutoProcessor, AutoModelForCTC
35+
self.processor = AutoProcessor.from_pretrained(self.model_name)
36+
self.model = AutoModelForCTC.from_pretrained(self.model_name)
2437
if device!=None:
2538
self.device = torch.device(device)
2639

@@ -54,29 +67,33 @@ def __call__(self, file: str) -> str:
5467
pred_ids = torch.argmax(logits, dim=-1)[0]
5568
if self.model_name == "airesearch/wav2vec2-large-xlsr-53-th":
5669
txt = self.processor.decode(pred_ids)
57-
else:
70+
elif self.lm:
5871
txt = self.processor.batch_decode(logits.detach().numpy()).text[0]
72+
else:
73+
txt = self.processor.decode(pred_ids)
5974
return txt
6075

6176
_model_name = "airesearch/wav2vec2-large-xlsr-53-th"
_model = None
# (model, lm, device) the cached ASR instance was built with; None until
# the first call to asr().
_model_config = None


def asr(
    file: str,
    model: str = _model_name,
    lm: bool = False,
    device: str = None,
) -> str:
    """
    Transcribe a sound file with a cached ASR instance.

    The ASR instance is kept in module-level state and is only rebuilt
    when the requested model, lm or device differs from the ones the
    cached instance was created with.

    :param str file: path of sound file
    :param str model: The ASR model name
    :param bool lm: Use language model (except *airesearch/wav2vec2-large-xlsr-53-th* model)
    :param str device: device
    :return: thai text from ASR
    :rtype: str

    **Options for model**
        * *airesearch/wav2vec2-large-xlsr-53-th* (default) - AI RESEARCH - PyThaiNLP model
        * *wannaphong/wav2vec2-large-xlsr-53-th-cv8-newmm* - Thai Wav2Vec2 with CommonVoice V8 (newmm tokenizer) (+ language model)
        * *wannaphong/wav2vec2-large-xlsr-53-th-cv8-deepcut* - Thai Wav2Vec2 with CommonVoice V8 (deepcut tokenizer) (+ language model)
    """
    global _model, _model_name, _model_config
    requested = (model, lm, device)
    # The previous check compared the model *name* with the cached ASR
    # *instance*, which is never equal, so the model was reloaded on every
    # call. Compare the requested configuration with the cached one instead.
    if _model is None or requested != _model_config:
        _model = ASR(model, lm=lm, device=device)
        _model_name = model
        _model_config = requested

    return _model(file=file)

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def read(*paths):
1010

1111
requirements = [
1212
'datasets',
13-
'transformers',
13+
'transformers<5.0',
1414
'torchaudio',
1515
'soundfile',
1616
'torch',
@@ -27,7 +27,7 @@ def read(*paths):
2727

2828
setup(
2929
name='pythaiasr',
30-
version='1.1.2',
30+
version='1.2.0',
3131
packages=['pythaiasr'],
3232
url='https://github.com/pythainlp/pythaiasr',
3333
license='Apache Software License 2.0',

0 commit comments

Comments
 (0)