This repository was archived by the owner on May 6, 2022. It is now read-only.

Commit 4f11ed4

will-rice authored and brentspell committed
Addition of Keyword Recognizer
The keyword recognizer is similar to the wakeword models in composition, but its use case is closer to ASR: the model classifies a keyword from a stream of audio and adds that keyword to the context's transcript. The additions are:
- KeywordRecognizer
- Tests
1 parent 55e61ba commit 4f11ed4

File tree

6 files changed (+281, -2 lines)

spokestack/asr/keyword/__init__.py

Whitespace-only changes.

spokestack/asr/keyword/tflite.py

Lines changed: 191 additions & 0 deletions
@@ -0,0 +1,191 @@
"""
This module contains the Spokestack KeywordRecognizer, which identifies multiple
keywords in an audio stream.
"""
import os
from typing import List

import numpy as np  # type: ignore

from spokestack.context import SpeechContext
from spokestack.models.tensorflow import TFLiteModel
from spokestack.ring_buffer import RingBuffer


class KeywordRecognizer:
    """Recognizes keywords in an audio stream.

    Args:
        classes (List[str]): Keyword labels
        pre_emphasis (float): Coefficient of the pre-emphasis filter
        sample_rate (int): Number of audio samples per second (Hz)
        fft_window_type (str): Type of FFT window (only "hann" is supported)
        fft_hop_length (int): Hop between successive STFT windows (ms)
        model_dir (str): Path to the directory containing the .tflite models
        posterior_threshold (float): Probability threshold for detection
    """

    def __init__(
        self,
        classes: List[str],
        pre_emphasis: float = 0.97,
        sample_rate: int = 16000,
        fft_window_type: str = "hann",
        fft_hop_length: int = 10,
        model_dir: str = "",
        posterior_threshold: float = 0.5,
        **kwargs
    ) -> None:

        self.classes = classes
        self.pre_emphasis: float = pre_emphasis
        self.hop_length: int = int(fft_hop_length * sample_rate / 1000)

        if fft_window_type != "hann":
            raise ValueError("Invalid fft_window_type")

        self.filter_model: TFLiteModel = TFLiteModel(
            model_path=os.path.join(model_dir, "filter.tflite")
        )
        self.encode_model: TFLiteModel = TFLiteModel(
            model_path=os.path.join(model_dir, "encode.tflite")
        )
        self.detect_model: TFLiteModel = TFLiteModel(
            model_path=os.path.join(model_dir, "detect.tflite")
        )

        if len(classes) != self.detect_model.output_details[0]["shape"][-1]:
            raise ValueError("Invalid number of classes")

        # window size is derived from the fft: the filter input holds
        # fft_size / 2 + 1 bins, so the window size is (filter_input_size - 1) * 2
        self._window_size = (self.filter_model.input_details[0]["shape"][-1] - 1) * 2
        self._fft_window = np.hanning(self._window_size)
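        # for example, a 257-bin filter input implies a 512-sample window
        # (32 ms at 16 kHz), and the default 10 ms hop is 160 samples,
        # matching the shapes mocked in the tests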

        # retrieve the mel_length and mel_width from the encoder model metadata;
        # these allocate the buffers to the correct size
        self.mel_length: int = self.encode_model.input_details[0]["shape"][1]
        self.mel_width: int = self.encode_model.input_details[0]["shape"][-1]

        # initialize the first state input for the autoregressive encoder model.
        # encode_length and encode_width come from the detect_model metadata:
        # the encode_model runs autoregressively and outputs a single encoded
        # sample, while the detect_model input is a collection of these samples.
        self.state = np.zeros(self.encode_model.input_details[1]["shape"], np.float32)
        self.encode_length: int = self.detect_model.input_details[0]["shape"][1]
        self.encode_width: int = self.detect_model.input_details[0]["shape"][-1]

        self.sample_window: RingBuffer = RingBuffer(
            shape=[self._window_size],
        )
        self.frame_window: RingBuffer = RingBuffer(
            shape=[self.mel_length, self.mel_width]
        )
        self.encode_window: RingBuffer = RingBuffer(
            shape=[self.encode_length, self.encode_width]
        )

        # initialize the frame and encode windows with constant values
        # to minimize the delay caused by filling the buffers
        self.frame_window.fill(0.0)
        self.encode_window.fill(-1.0)

        self._posterior_threshold: float = posterior_threshold
        self._prev_sample: float = 0.0
        self._is_active = False

    def __call__(self, context: SpeechContext, frame) -> None:

        self._sample(context, frame)

        if not context.is_active and self._is_active:
            self._detect(context)

        self._is_active = context.is_active
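        # note: _detect runs on the falling edge of context.is_active,
        # i.e. at most once per utterance, just after the pipeline deactivates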

    def _sample(self, context: SpeechContext, frame) -> None:
        # convert the PCM-16 audio to float32 in (-1.0, 1.0)
        frame = frame.astype(np.float32) / (2 ** 15 - 1)
        frame = np.clip(frame, -1.0, 1.0)

        # pull the last value out of the frame, apply pre-emphasis against
        # the previous sample, and cache that value for the next iteration
        prev_sample = frame[-1]
        frame -= self.pre_emphasis * np.append(self._prev_sample, frame[:-1])
        self._prev_sample = prev_sample
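        # i.e. each sample becomes y[t] = x[t] - pre_emphasis * x[t - 1],
        # a first-order filter that boosts the higher frequencies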

        # fill the sample window to analyze speech-containing samples;
        # after each fill, the buffer advances by the hop length
        # to produce overlapping windows
        for sample in frame:
            self.sample_window.write(sample)
            if self.sample_window.is_full:
                if context.is_active:
                    self._analyze(context)
                self.sample_window.rewind().seek(self.hop_length)

    def _analyze(self, context: SpeechContext) -> None:
        # read the full contents of the sample window to calculate a single
        # frame of the STFT: apply the DFT to the real-valued input and
        # take the magnitude of the complex result
        frame = self.sample_window.read_all()
        frame = np.fft.rfft(frame * self._fft_window, n=self._window_size)
        frame = np.abs(frame).astype(np.float32)

        # compute the mel spectrogram
        self._filter(context, frame)

    def _filter(self, context: SpeechContext, frame) -> None:
        # add the batch dimension and compute the mel spectrogram with the
        # filter model
        frame = np.expand_dims(frame, 0)
        frame = self.filter_model(frame)[0]

        # advance the window by 1 and write the mel frame to the frame buffer
        self.frame_window.rewind().seek(1)
        self.frame_window.write(frame)

        # encode the mel spectrogram
        self._encode(context)

    def _encode(self, context: SpeechContext) -> None:
        # read the full contents of the frame window, add the batch dimension,
        # then run the encoder and save the output state for autoregression
        frame = self.frame_window.read_all()
        frame = np.expand_dims(frame, 0)
        frame, self.state = self.encode_model(frame, self.state)

        # accumulate encoded samples until the detection window is full
        self.encode_window.rewind().seek(1)
        self.encode_window.write(frame)
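        # once full, the window holds encode_length encoded frames
        # (92 in the shapes mocked by the tests) for _detect to classify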

    def _detect(self, context: SpeechContext) -> None:
        # read the full contents of the encode window, add the batch dimension,
        # and compute the keyword class posteriors with the detect model
        frame = self.encode_window.read_all()
        frame = np.expand_dims(frame, 0)
        posterior = self.detect_model(frame)[0][0]
        class_index = np.argmax(posterior)
        confidence = posterior[class_index]
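        # e.g. the tests mock a posterior of [0.8, 0.1, 0.3], which yields
        # class_index 0 ("one") and confidence 0.8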

        if confidence >= self._posterior_threshold:
            context.transcript = self.classes[class_index]
            context.confidence = confidence
            context.event("recognize")
        else:
            context.event("timeout")
        self.reset()

    def reset(self) -> None:
        """Resets the current KeywordRecognizer state"""
        self.sample_window.reset()
        self.frame_window.reset().fill(0.0)
        self.encode_window.reset().fill(-1.0)
        self.state[:] = 0.0

    def close(self) -> None:
        """Close interface for use in the SpeechPipeline"""
        self.reset()
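
In practice the recognizer is driven one frame at a time and only classifies once the speech context deactivates, mirroring the tests below. A minimal sketch of that flow, assuming trained filter/encode/detect models in model_dir and a hypothetical pcm_frames iterable of 16 kHz PCM-16 audio (in a real pipeline, a VAD or trigger component would toggle context.is_active):

import numpy as np

from spokestack.asr.keyword.tflite import KeywordRecognizer
from spokestack.context import SpeechContext

context = SpeechContext()
recognizer = KeywordRecognizer(
    classes=["one", "two", "three"], model_dir="path/to/keyword_models"
)

# feed audio while the context is active; each frame is a np.int16 array
context.is_active = True
for frame in pcm_frames:  # hypothetical audio source
    recognizer(context, frame)

# deactivating the context triggers classification on the next frame
context.is_active = False
recognizer(context, np.zeros(160, np.int16))

if context.transcript:
    print(context.transcript)  # e.g. "one"
recognizer.close()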

spokestack/wakeword/tflite.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@
 
 from spokestack.context import SpeechContext
 from spokestack.models.tensorflow import TFLiteModel
-from spokestack.wakeword.ring_buffer import RingBuffer
+from spokestack.ring_buffer import RingBuffer
 
 
 _LOG = logging.getLogger(__name__)
tests/asr/keyword/test_tflite.py

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
"""
This module contains tests for the keyword recognizer.
"""
from unittest import mock

import numpy as np  # type: ignore
import pytest  # type: ignore

from spokestack.asr.keyword.tflite import KeywordRecognizer
from spokestack.context import SpeechContext


class ModelFactory(mock.MagicMock):
    def __call__(self, model_path):
        model = mock.MagicMock()
        if model_path.endswith("filter.tflite"):
            model.input_details = [{"shape": [1, 257]}]
            model.output_details = [{"shape": [1, 40]}]
            model.return_value = [np.zeros((1, 40))]
        elif model_path.endswith("encode.tflite"):
            model.input_details = [{"shape": [1, 1, 40]}, {"shape": [1, 128]}]
            model.output_details = [{"shape": [1, 128]}, {"shape": [1, 128]}]
            model.return_value = [np.zeros((1, 128)), np.zeros((1, 128))]
        elif model_path.endswith("detect.tflite"):
            model.input_details = [{"shape": [1, 92, 128]}]
            model.output_details = [{"shape": [1, 3]}]
            model.return_value = [[[0.8, 0.1, 0.3]]]
        return model
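
# note: the shapes mocked above mirror the chain the recognizer expects:
# 257 STFT bins -> 40 mel bins -> 128-dim encodings -> a posterior over 3 classes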


@mock.patch("spokestack.asr.keyword.tflite.TFLiteModel", new_callable=ModelFactory)
def test_invalid_args(*args):
    with pytest.raises(ValueError):
        _ = KeywordRecognizer(
            classes=["one", "two", "three"], fft_window_type="hamming"
        )


@mock.patch("spokestack.asr.keyword.tflite.TFLiteModel", new_callable=ModelFactory)
def test_invalid_classes(*args):
    with pytest.raises(ValueError):
        _ = KeywordRecognizer(classes=["one", "two", "three", "four"])

    with pytest.raises(ValueError):
        _ = KeywordRecognizer(classes=["one", "two"])


@mock.patch("spokestack.asr.keyword.tflite.TFLiteModel", new_callable=ModelFactory)
def test_recognize(*args):
    context = SpeechContext()
    recognizer = KeywordRecognizer(classes=["one", "two", "three"])

    test_frame = np.random.rand(
        160,
    ).astype(np.float32)

    context.is_active = True
    for i in range(10):
        recognizer(context, test_frame)
    recognizer(context, test_frame)

    context.is_active = False
    recognizer(context, test_frame)
    assert context.transcript == "one"

    recognizer.close()


@mock.patch("spokestack.asr.keyword.tflite.TFLiteModel", new_callable=ModelFactory)
def test_timeout(*args):
    context = SpeechContext()
    recognizer = KeywordRecognizer(classes=["one", "two", "three"])
    recognizer.detect_model.return_value = [[[0.0, 0.0, 0.0]]]

    test_frame = np.random.rand(
        160,
    ).astype(np.float32)

    context.is_active = True
    for i in range(10):
        recognizer(context, test_frame)
    recognizer(context, test_frame)

    context.is_active = False
    recognizer(context, test_frame)
    assert not context.transcript

    recognizer.close()

tests/wakeword/test_ring_buffer.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 import numpy as np
 import pytest
 
-from spokestack.wakeword.ring_buffer import RingBuffer
+from spokestack.ring_buffer import RingBuffer
 
 
 def test_ring_buffer():
