"""
This module contains the Spokestack KeywordRecognizer, which identifies multiple
keywords from an audio stream.
"""
import os
from typing import List

import numpy as np  # type: ignore

from spokestack.context import SpeechContext
from spokestack.models.tensorflow import TFLiteModel
from spokestack.ring_buffer import RingBuffer


class KeywordRecognizer:
    """Recognizes keywords in an audio stream.

    Args:
        classes (List[str]): Keyword labels
        pre_emphasis (float): The value of the pre-emphasis filter
        sample_rate (int): The number of audio samples per second of audio (Hz)
        fft_window_type (str): The type of FFT window (only "hann" is supported)
        fft_hop_length (int): Audio sliding window for STFT calculation (ms)
        model_dir (str): Path to the directory containing .tflite models
        posterior_threshold (float): Probability threshold for detection
    """

    def __init__(
        self,
        classes: List[str],
        pre_emphasis: float = 0.97,
        sample_rate: int = 16000,
        fft_window_type: str = "hann",
        fft_hop_length: int = 10,
        model_dir: str = "",
        posterior_threshold: float = 0.5,
        **kwargs
    ) -> None:

        self.classes = classes
        self.pre_emphasis: float = pre_emphasis
        self.hop_length: int = int(fft_hop_length * sample_rate / 1000)

        if fft_window_type != "hann":
            raise ValueError("Invalid fft_window_type")

        self.filter_model: TFLiteModel = TFLiteModel(
            model_path=os.path.join(model_dir, "filter.tflite")
        )
        self.encode_model: TFLiteModel = TFLiteModel(
            model_path=os.path.join(model_dir, "encode.tflite")
        )
        self.detect_model: TFLiteModel = TFLiteModel(
            model_path=os.path.join(model_dir, "detect.tflite")
        )

        if len(classes) != self.detect_model.output_details[0]["shape"][-1]:
            raise ValueError("Invalid number of classes")

        # the window size is derived from the filter model's input shape:
        # the filter consumes fft_size / 2 + 1 magnitude bins from the rfft,
        # which makes the window size (filter_input_size - 1) * 2
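        # e.g. a hypothetical filter input of 257 frequency bins implies a
        # (257 - 1) * 2 = 512 sample window (32 ms at 16 kHz)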
        self._window_size = (self.filter_model.input_details[0]["shape"][-1] - 1) * 2
        self._fft_window = np.hanning(self._window_size)

        # retrieve the mel_length and mel_width from the encoder model metadata;
        # these are used to allocate buffers of the correct size
        self.mel_length: int = self.encode_model.input_details[0]["shape"][1]
        self.mel_width: int = self.encode_model.input_details[0]["shape"][-1]

        # initialize the first state input for the autoregressive encoder model.
        # encode_length and encode_width come from the detect_model metadata
        # because the encode_model runs autoregressively and outputs a single
        # encoded sample at a time; the detect_model input is a collection
        # of these samples.
        self.state = np.zeros(self.encode_model.input_details[1]["shape"], np.float32)
        self.encode_length: int = self.detect_model.input_details[0]["shape"][1]
        self.encode_width: int = self.detect_model.input_details[0]["shape"][-1]

        self.sample_window: RingBuffer = RingBuffer(
            shape=[self._window_size],
        )
        self.frame_window: RingBuffer = RingBuffer(
            shape=[self.mel_length, self.mel_width]
        )
        self.encode_window: RingBuffer = RingBuffer(
            shape=[self.encode_length, self.encode_width]
        )

        # prefill the frame and encode windows; this minimizes the delay
        # caused by waiting for the buffers to fill with real samples
        self.frame_window.fill(0.0)
        self.encode_window.fill(-1.0)

        self._posterior_threshold: float = posterior_threshold
        self._prev_sample: float = 0.0
        self._is_active = False

    def __call__(self, context: SpeechContext, frame: np.ndarray) -> None:
        """Processes a single frame of audio. Keyword detection runs on the
        falling edge of activation, i.e. when the context transitions from
        active to inactive.

        Args:
            context (SpeechContext): Shared speech pipeline context
            frame (np.ndarray): Single frame of PCM-16 audio samples
        """
        self._sample(context, frame)

        if not context.is_active and self._is_active:
            self._detect(context)

        self._is_active = context.is_active

    def _sample(self, context: SpeechContext, frame: np.ndarray) -> None:
        # convert the PCM-16 audio to float32 in (-1.0, 1.0)
        frame = frame.astype(np.float32) / (2 ** 15 - 1)
        frame = np.clip(frame, -1.0, 1.0)

        # cache the last sample of this frame, apply pre-emphasis using the
        # previous frame's cached sample, then store the new value
        # to be used in the next iteration
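        # equivalently, this applies the first-order high-pass filter
        # y[n] = x[n] - a * x[n - 1] with a = pre_emphasis (0.97 by default)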
        prev_sample = frame[-1]
        frame -= self.pre_emphasis * np.append(self._prev_sample, frame[:-1])
        self._prev_sample = prev_sample

        # fill the sample window with speech samples; after each full window
        # the buffer advances by the hop length to produce overlapping windows
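        # e.g. with a 512 sample window and the default 10 ms hop (160 samples
        # at 16 kHz), consecutive windows overlap by 512 - 160 = 352 samples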
        for sample in frame:
            self.sample_window.write(sample)
            if self.sample_window.is_full:
                if context.is_active:
                    self._analyze(context)
                self.sample_window.rewind().seek(self.hop_length)

    def _analyze(self, context: SpeechContext) -> None:
        # read the full contents of the sample window to calculate a single
        # frame of the STFT by applying the DFT to a real-valued input and
        # taking the magnitude of the complex DFT
        frame = self.sample_window.read_all()
        frame = np.fft.rfft(frame * self._fft_window, n=self._window_size)
        frame = np.abs(frame).astype(np.float32)
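        # np.fft.rfft of a real window of length N yields N // 2 + 1 complex
        # bins, so the magnitude frame matches the filter model's input size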

        # compute the mel spectrogram
        self._filter(context, frame)

    def _filter(self, context: SpeechContext, frame: np.ndarray) -> None:
        # add the batch dimension and compute the mel spectrogram with the
        # filter model
        frame = np.expand_dims(frame, 0)
        frame = self.filter_model(frame)[0]

        # advance the window by 1 and write the mel frame to the frame buffer
        self.frame_window.rewind().seek(1)
        self.frame_window.write(frame)

        # encode the mel spectrogram
        self._encode(context)

    def _encode(self, context: SpeechContext) -> None:
        # read the full contents of the frame window, add the batch dimension,
        # run the encoder, and save the output state for autoregression
        frame = self.frame_window.read_all()
        frame = np.expand_dims(frame, 0)
        frame, self.state = self.encode_model(frame, self.state)

        # accumulate encoded samples until the detection window is full
        self.encode_window.rewind().seek(1)
        self.encode_window.write(frame)

    def _detect(self, context: SpeechContext) -> None:
        # read the full contents of the encode window, add the batch dimension,
        # and compute a posterior probability for each keyword class
        # with the detect model
        frame = self.encode_window.read_all()
        frame = np.expand_dims(frame, 0)
        posterior = self.detect_model(frame)[0][0]
        class_index = np.argmax(posterior)
        confidence = posterior[class_index]
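        # e.g. a posterior of [0.1, 0.7, 0.2] yields class_index = 1 and
        # confidence = 0.7 for the second label in self.classes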

        if confidence >= self._posterior_threshold:
            context.transcript = self.classes[class_index]
            context.confidence = confidence
            context.event("recognize")
        else:
            context.event("timeout")
        self.reset()

    def reset(self) -> None:
        """ Resets the current KeywordRecognizer state """
        self.sample_window.reset()
        self.frame_window.reset().fill(0.0)
        self.encode_window.reset().fill(-1.0)
        self.state[:] = 0.0

    def close(self) -> None:
        """ Close interface for use in the SpeechPipeline """
        self.reset()
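

# A minimal usage sketch, assuming the three .tflite models live in
# "path/to/model_dir" (a placeholder), that audio arrives as np.int16 frames,
# and that SpeechContext() is constructible with no arguments and exposes a
# settable `is_active` flag; in practice the SpeechPipeline manages activation
# (e.g. via a VAD or wakeword trigger) rather than setting it by hand.
if __name__ == "__main__":
    context = SpeechContext()
    recognizer = KeywordRecognizer(
        classes=["up", "down", "left", "right"],
        model_dir="path/to/model_dir",
    )
    # feed one second of synthetic silence: 100 frames of 10 ms at 16 kHz
    context.is_active = True  # stand-in for a VAD/wakeword activation
    for _ in range(100):
        recognizer(context, np.zeros(160, np.int16))
    # dropping the flag triggers detection on the falling edge of activation
    context.is_active = False
    recognizer(context, np.zeros(160, np.int16))
    print(context.transcript, context.confidence)
    recognizer.close()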