diff --git a/pystoi/stoi.py b/pystoi/stoi.py index e955afb..920871b 100644 --- a/pystoi/stoi.py +++ b/pystoi/stoi.py @@ -19,7 +19,8 @@ def stoi(x, y, fs_sig, extended=False): Computes the STOI (See [1][2]) of a denoised signal compared to a clean signal, The output is expected to have a monotonic relation with the subjective speech-intelligibility, where a higher score denotes better - speech intelligibility. + speech intelligibility. Accepts either a single waveform or a batch of + waveforms stored in a numpy array. # Arguments x (np.ndarray): clean original speech @@ -28,8 +29,9 @@ def stoi(x, y, fs_sig, extended=False): extended (bool): Boolean, whether to use the extended STOI described in [3] # Returns - float: Short time objective intelligibility measure between clean and - denoised speech + float or np.ndarray: Short time objective intelligibility measure between + clean and denoised speech. Returns float if called with a single waveform, + np.ndarray if called with a batch of waveforms # Raises AssertionError : if x and y have different lengths @@ -46,9 +48,14 @@ def stoi(x, y, fs_sig, extended=False): IEEE Transactions on Audio, Speech and Language Processing, 2016. """ if x.shape != y.shape: - raise Exception('x and y should have the same length,' + + raise Exception('x and y should have the same shape,' + 'found {} and {}'.format(x.shape, y.shape)) + out_shape = x.shape[:-1] + if len(x.shape) == 1: # Add a batch size if missing, shape (batch, num_samples) + x = x[None, :] + y = y[None, :] + # Resample is fs_sig is different than fs if fs_sig != FS: x = utils.resample_oct(x, FS, fs_sig) @@ -57,32 +64,33 @@ def stoi(x, y, fs_sig, extended=False): # Remove silent frames x, y = utils.remove_silent_frames(x, y, DYN_RANGE, N_FRAME, int(N_FRAME/2)) - # Take STFT - x_spec = utils.stft(x, N_FRAME, NFFT, overlap=2).transpose() - y_spec = utils.stft(y, N_FRAME, NFFT, overlap=2).transpose() + # Take STFT, shape (batch, num_frames, n_fft//2 + 1) + x_spec = utils.stft(x, N_FRAME, NFFT, overlap=2) + y_spec = utils.stft(y, N_FRAME, NFFT, overlap=2) # Ensure at least 30 frames for intermediate intelligibility - if x_spec.shape[-1] < N: + mask = ~np.all(x_spec == 0, axis=-1) # Check for frames with non-zero values + if np.any(np.sum(mask, axis=-1) < N): warnings.warn('Not enough STFT frames to compute intermediate ' 'intelligibility measure after removing silent ' 'frames. Returning 1e-5. Please check you wav files', RuntimeWarning) - return 1e-5 + return np.array([1e-5 for _ in range(x.shape[0])]).reshape(out_shape) - # Apply OB matrix to the spectrograms as in Eq. (1) - x_tob = np.sqrt(np.matmul(OBM, np.square(np.abs(x_spec)))) - y_tob = np.sqrt(np.matmul(OBM, np.square(np.abs(y_spec)))) + # Apply OB matrix to the spectrograms as in Eq. (1), shape (batch, frames, bands) + x_tob = np.sqrt(np.matmul(np.square(np.abs(x_spec)), OBM.T)) + y_tob = np.sqrt(np.matmul(np.square(np.abs(y_spec)), OBM.T)) # Take segments of x_tob, y_tob - x_segments = np.array( - [x_tob[:, m - N:m] for m in range(N, x_tob.shape[1] + 1)]) - y_segments = np.array( - [y_tob[:, m - N:m] for m in range(N, x_tob.shape[1] + 1)]) + x_segments = utils.segment_frames(x_tob, mask, N) + y_segments = utils.segment_frames(y_tob, mask, N) + # From now on the shape is always (batch, num_segments, seg_size, bands) if extended: x_n = utils.row_col_normalize(x_segments) y_n = utils.row_col_normalize(y_segments) - return np.sum(x_n * y_n / N) / x_n.shape[0] + d_n = np.mean(np.sum(x_n * y_n, axis=3), axis=2) + return np.mean(d_n, axis=1).reshape(out_shape) else: # Find normalization constants and normalize @@ -93,8 +101,7 @@ def stoi(x, y, fs_sig, extended=False): # Clip as described in [1] clip_value = 10 ** (-BETA / 20) - y_primes = np.minimum( - y_segments_normalized, x_segments * (1 + clip_value)) + y_primes = np.minimum(y_segments_normalized, x_segments * (1 + clip_value)) # Subtract mean vectors y_primes = y_primes - np.mean(y_primes, axis=2, keepdims=True) @@ -104,12 +111,11 @@ def stoi(x, y, fs_sig, extended=False): y_primes /= (np.linalg.norm(y_primes, axis=2, keepdims=True) + utils.EPS) x_segments /= (np.linalg.norm(x_segments, axis=2, keepdims=True) + utils.EPS) # Find a matrix with entries summing to sum of correlations of vectors - correlations_components = y_primes * x_segments - - # J, M as in [1], eq.6 - J = x_segments.shape[0] - M = x_segments.shape[1] + correlations_components = np.sum(y_primes * x_segments, axis=-2, keepdims=True) # Find the mean of all correlations - d = np.sum(correlations_components) / (J * M) - return d + d = np.mean(correlations_components, axis=(1, 3), keepdims=True) + # Exclude the contribution of silent frames from the calculation of the mean + d *= np.mean(mask, axis=1, keepdims=True)[..., None, None] + # Return just a float if stoi was called with a single waveform + return np.squeeze(d).reshape(out_shape) diff --git a/pystoi/utils.py b/pystoi/utils.py index c1fdb3e..abc7660 100644 --- a/pystoi/utils.py +++ b/pystoi/utils.py @@ -47,7 +47,7 @@ def resample_oct(x, p, q): """Resampler that is compatible with Octave""" h = _resample_window_oct(p, q) window = h / np.sum(h) - return resample_poly(x, p, q, window=window) + return resample_poly(x, p, q, axis=-1, window=window) @functools.lru_cache(maxsize=None) @@ -95,39 +95,51 @@ def stft(x, win_size, fft_size, overlap=4): """ hop = int(win_size / overlap) w = np.hanning(win_size + 2)[1: -1] # = matlab.hanning(win_size) - stft_out = np.array([np.fft.rfft(w * x[i:i + win_size], n=fft_size) - for i in range(0, len(x) - win_size, hop)]) - return stft_out + stft_out = np.array([np.fft.rfft(w * x[:, i:i + win_size], n=fft_size) + for i in range(0, x.shape[-1] - win_size, hop)]) + return stft_out.transpose([1, 0, 2]) def _overlap_and_add(x_frames, hop): - num_frames, framelen = x_frames.shape + batch_size, num_frames, framelen = x_frames.shape # Compute the number of segments, per frame. segments = -(-framelen // hop) # Divide and round up. # Pad the framelen dimension to segments * hop and add n=segments frames - signal = np.pad(x_frames, ((0, segments), (0, segments * hop - framelen))) - - # Reshape to a 3D tensor, splitting the framelen dimension in two - signal = signal.reshape((num_frames + segments, segments, hop)) - # Transpose dimensions so that signal.shape = (segments, frame+segments, hop) - signal = np.transpose(signal, [1, 0, 2]) - # Reshape so that signal.shape = (segments * (frame+segments), hop) - signal = signal.reshape((-1, hop)) - - # Now behold the magic!! Remove the last n=segments elements from the first axis - signal = signal[:-segments] - # Reshape to (segments, frame+segments-1, hop) - signal = signal.reshape((segments, num_frames + segments - 1, hop)) + signal = np.pad(x_frames, ((0, 0), (0, segments), (0, segments * hop - framelen))) + + # Reshape to a 4D tensor, splitting the framelen dimension in two + signal = signal.reshape((batch_size, num_frames + segments, segments, hop)) + # Transpose dimensions so that signal.shape = (batch, segments, frame+segments, hop) + signal = np.transpose(signal, [0, 2, 1, 3]) + # Reshape so that signal.shape = (batch, segments * (frame+segments), hop) + signal = signal.reshape((batch_size, -1, hop)) + + # Now behold the magic!! Remove the last n=segments elements from the second axis + signal = signal[:, :-segments] + # Reshape to (batch, segments, frame+segments-1, hop) + signal = signal.reshape((batch_size, segments, num_frames + segments - 1, hop)) # This has introduced a shift by one in all rows # Now, reduce over the columns and flatten the array to achieve the result - signal = np.sum(signal, axis=0) - end = (len(x_frames) - 1) * hop + framelen - signal = signal.reshape(-1)[:end] + signal = np.sum(signal, axis=1) + end = (num_frames - 1) * hop + framelen + signal = signal.reshape((batch_size, -1))[:end] return signal +def segment_frames(x, mask, seg_size): + segments = np.array([x[:, m - seg_size: m] for m in range(seg_size, x.shape[1] + 1)]) + segments = segments.transpose([1, 0, 2, 3]) # put back batch in the first dimension + return segments * mask[:, seg_size - 1:, None, None] + + +def _mask_audio(x, mask): + return np.array([ + np.pad(xi[mi], ((0, len(xi) - np.sum(mi)), (0, 0))) for xi, mi in zip(x, mask) + ]) + + def remove_silent_frames(x, y, dyn_range, framelen, hop): """ Remove silent frames of x and y based on x A frame is excluded if its energy is lower than max(energy) - dyn_range @@ -139,27 +151,29 @@ def remove_silent_frames(x, y, dyn_range, framelen, hop): framelen : Window size for energy evaluation hop : Hop size for energy evaluation # Returns : - x without the silent frames - y without the silent frames (aligned to x) + x without the silent frames, zero-padded to the original length + y without the silent frames, zero-padded to the original length (aligned to x) """ # Compute Mask w = np.hanning(framelen + 2)[1:-1] x_frames = np.array( - [w * x[i:i + framelen] for i in range(0, len(x) - framelen, hop)]) + [w * x[..., i : i + framelen] for i in range(0, x.shape[-1] - framelen, hop)] + ).transpose([1, 0, 2]) y_frames = np.array( - [w * y[i:i + framelen] for i in range(0, len(x) - framelen, hop)]) + [w * y[..., i : i + framelen] for i in range(0, x.shape[-1] - framelen, hop)] + ).transpose([1, 0, 2]) # Compute energies in dB - x_energies = 20 * np.log10(np.linalg.norm(x_frames, axis=1) + EPS) + x_energies = 20 * np.log10(np.linalg.norm(x_frames, axis=-1) + EPS) # Find boolean mask of energies lower than dynamic_range dB # with respect to maximum clean speech energy frame mask = (np.max(x_energies) - dyn_range - x_energies) < 0 - # Remove silent frames by masking - x_frames = x_frames[mask] - y_frames = y_frames[mask] + # Remove silent frames and pad with zeroes + x_frames = _mask_audio(x_frames, mask) + y_frames = _mask_audio(y_frames, mask) x_sil = _overlap_and_add(x_frames, hop) y_sil = _overlap_and_add(y_frames, hop) @@ -167,25 +181,13 @@ def remove_silent_frames(x, y, dyn_range, framelen, hop): return x_sil, y_sil -def vect_two_norm(x, axis=-1): - """ Returns an array of vectors of norms of the rows of matrices from 3D array """ - return np.sum(np.square(x), axis=axis, keepdims=True) - - def row_col_normalize(x): """ Row and column mean and variance normalize an array of 2D segments """ - # Row mean and variance normalization - x_normed = x + EPS * np.random.standard_normal(x.shape) + # input shape (batch, num_segments, seg_size, bands) + # Row mean and variance normalization -- axis: seg_size + x_normed = x - np.mean(x, axis=-2, keepdims=True) + x_normed = x_normed / (np.linalg.norm(x_normed, axis=-2, keepdims=True) + EPS) + # Column mean and variance normalization -- axis: bands x_normed -= np.mean(x_normed, axis=-1, keepdims=True) - x_inv = 1. / np.sqrt(vect_two_norm(x_normed)) - x_diags = np.array( - [np.diag(x_inv[i].reshape(-1)) for i in range(x_inv.shape[0])]) - x_normed = np.matmul(x_diags, x_normed) - # Column mean and variance normalization - x_normed += + EPS * np.random.standard_normal(x_normed.shape) - x_normed -= np.mean(x_normed, axis=1, keepdims=True) - x_inv = 1. / np.sqrt(vect_two_norm(x_normed, axis=1)) - x_diags = np.array( - [np.diag(x_inv[i].reshape(-1)) for i in range(x_inv.shape[0])]) - x_normed = np.matmul(x_normed, x_diags) + x_normed = x_normed / (np.linalg.norm(x_normed, axis=-1, keepdims=True) + EPS) return x_normed diff --git a/tests/test_overlap_and_add.py b/tests/test_overlap_and_add.py index c6137d6..b86bd20 100644 --- a/tests/test_overlap_and_add.py +++ b/tests/test_overlap_and_add.py @@ -1,7 +1,8 @@ import numpy as np +import pytest from numpy.testing import assert_allclose -from pystoi.stoi import N_FRAME +from pystoi.stoi import N_FRAME, stoi, FS from pystoi.utils import _overlap_and_add @@ -15,12 +16,64 @@ def old_overlap_and_app(x_frames, hop): x_sil[range(i * hop, i * hop + framelen)] += x_frames[i, :] return x_sil + batch_size = 4 # Initialize - x = np.random.randn(1000 * N_FRAME) + x = np.random.randn(batch_size, 1000 * N_FRAME) # Add silence segment - silence = np.zeros(10 * N_FRAME) - x = np.concatenate([x[: 500 * N_FRAME], silence, x[500 * N_FRAME :]]) - x = x.reshape([-1, N_FRAME]) - xs = old_overlap_and_app(x, N_FRAME // 2) + silence = np.zeros((batch_size, 10 * N_FRAME)) + x = np.concatenate([x[:, : 500 * N_FRAME], silence, x[:, 500 * N_FRAME :]], axis=1) + x = x.reshape([batch_size, -1, N_FRAME]) + xs = [old_overlap_and_app(xi, N_FRAME // 2) for xi in x] xs_vectorise = _overlap_and_add(x, N_FRAME // 2) assert_allclose(xs, xs_vectorise) + + +@pytest.mark.parametrize("batch_size", [1, 4]) +@pytest.mark.parametrize("fs", [10000, 16000]) +@pytest.mark.parametrize("extended", [True, False]) +def test_pystoi_run(batch_size, fs, extended): + N = fs * 4 # 4 seconds of random audio + x = np.random.randn(batch_size, N) + res = stoi(x, x, fs, extended) + print(batch_size, fs, extended, res) + assert res.shape == x.shape[:-1] + + +@pytest.mark.parametrize("extended", [True, False]) +@pytest.mark.parametrize("batch_size", [1, 4]) +def test_pystoi_complete_silence(batch_size, extended): + fs = 16000 + N = fs * 4 # 4 seconds of random audio + x = np.zeros((batch_size, N)) + res = stoi(x, x, fs, extended) + print(batch_size, fs, extended, res) + assert res.shape == x.shape[:-1] + + +@pytest.mark.parametrize("extended", [True, False]) +def test_pystoi_silence(extended): + rng = np.random.default_rng(seed=0) + batch_size = 4 + fs = 16000 + N = fs * 4 # 4 seconds of random audio + x = np.random.randn(batch_size, N) + silence = np.random.randn(int(N / 7)) + audio = [] + for i in range(batch_size): + t = int(rng.random() * N) + audio.append(np.concatenate([x[i, :t], silence, x[i, t:]])) + audio = np.array(audio) + res = stoi(audio, audio, fs, extended) + print(batch_size, fs, extended, res) + assert res.shape == x.shape[:-1] + + +def test_vectorisation(): + # Initialize batch of data + batch_size = 4 + x = np.random.random((batch_size, 100 * N_FRAME)) + y = np.random.random((batch_size, 100 * N_FRAME)) + res = np.array([stoi(xi, yi, FS) for xi, yi in zip(x, y)]) + res_vec = stoi(x, y, FS) + assert res_vec.shape == x.shape[:-1] + assert np.allclose(res, res_vec)