Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 32 additions & 26 deletions pystoi/stoi.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ def stoi(x, y, fs_sig, extended=False):
Computes the STOI (See [1][2]) of a denoised signal compared to a clean
signal. The output is expected to have a monotonic relation with the
subjective speech-intelligibility, where a higher score denotes better
speech intelligibility.
speech intelligibility. Accepts either a single waveform or a batch of
waveforms stored in a numpy array.

# Arguments
x (np.ndarray): clean original speech
Expand All @@ -28,8 +29,9 @@ def stoi(x, y, fs_sig, extended=False):
extended (bool): Boolean, whether to use the extended STOI described in [3]

# Returns
float: Short time objective intelligibility measure between clean and
denoised speech
float or np.ndarray: Short time objective intelligibility measure between
clean and denoised speech. Returns float if called with a single waveform,
np.ndarray if called with a batch of waveforms

# Raises
Exception : if x and y have different shapes
Expand All @@ -46,9 +48,14 @@ def stoi(x, y, fs_sig, extended=False):
IEEE Transactions on Audio, Speech and Language Processing, 2016.
"""
if x.shape != y.shape:
raise Exception('x and y should have the same length,' +
raise Exception('x and y should have the same shape,' +
'found {} and {}'.format(x.shape, y.shape))

out_shape = x.shape[:-1]
if len(x.shape) == 1: # Add a batch size if missing, shape (batch, num_samples)
x = x[None, :]
y = y[None, :]

# Resample if fs_sig is different from FS
if fs_sig != FS:
x = utils.resample_oct(x, FS, fs_sig)
Expand All @@ -57,32 +64,33 @@ def stoi(x, y, fs_sig, extended=False):
# Remove silent frames
x, y = utils.remove_silent_frames(x, y, DYN_RANGE, N_FRAME, int(N_FRAME/2))

# Take STFT
x_spec = utils.stft(x, N_FRAME, NFFT, overlap=2).transpose()
y_spec = utils.stft(y, N_FRAME, NFFT, overlap=2).transpose()
# Take STFT, shape (batch, num_frames, n_fft//2 + 1)
x_spec = utils.stft(x, N_FRAME, NFFT, overlap=2)
y_spec = utils.stft(y, N_FRAME, NFFT, overlap=2)

# Ensure at least 30 frames for intermediate intelligibility
if x_spec.shape[-1] < N:
mask = ~np.all(x_spec == 0, axis=-1) # Check for frames with non-zero values
if np.any(np.sum(mask, axis=-1) < N):
warnings.warn('Not enough STFT frames to compute intermediate '
'intelligibility measure after removing silent '
'frames. Returning 1e-5. Please check you wav files',
RuntimeWarning)
return 1e-5
return np.array([1e-5 for _ in range(x.shape[0])]).reshape(out_shape)

# Apply OB matrix to the spectrograms as in Eq. (1)
x_tob = np.sqrt(np.matmul(OBM, np.square(np.abs(x_spec))))
y_tob = np.sqrt(np.matmul(OBM, np.square(np.abs(y_spec))))
# Apply OB matrix to the spectrograms as in Eq. (1), shape (batch, frames, bands)
x_tob = np.sqrt(np.matmul(np.square(np.abs(x_spec)), OBM.T))
y_tob = np.sqrt(np.matmul(np.square(np.abs(y_spec)), OBM.T))

# Take segments of x_tob, y_tob
x_segments = np.array(
[x_tob[:, m - N:m] for m in range(N, x_tob.shape[1] + 1)])
y_segments = np.array(
[y_tob[:, m - N:m] for m in range(N, x_tob.shape[1] + 1)])
x_segments = utils.segment_frames(x_tob, mask, N)
y_segments = utils.segment_frames(y_tob, mask, N)

# From now on the shape is always (batch, num_segments, seg_size, bands)
if extended:
x_n = utils.row_col_normalize(x_segments)
y_n = utils.row_col_normalize(y_segments)
return np.sum(x_n * y_n / N) / x_n.shape[0]
d_n = np.mean(np.sum(x_n * y_n, axis=3), axis=2)
return np.mean(d_n, axis=1).reshape(out_shape)

else:
# Find normalization constants and normalize
Expand All @@ -93,8 +101,7 @@ def stoi(x, y, fs_sig, extended=False):

# Clip as described in [1]
clip_value = 10 ** (-BETA / 20)
y_primes = np.minimum(
y_segments_normalized, x_segments * (1 + clip_value))
y_primes = np.minimum(y_segments_normalized, x_segments * (1 + clip_value))

# Subtract mean vectors
y_primes = y_primes - np.mean(y_primes, axis=2, keepdims=True)
Expand All @@ -104,12 +111,11 @@ def stoi(x, y, fs_sig, extended=False):
y_primes /= (np.linalg.norm(y_primes, axis=2, keepdims=True) + utils.EPS)
x_segments /= (np.linalg.norm(x_segments, axis=2, keepdims=True) + utils.EPS)
# Find a matrix with entries summing to sum of correlations of vectors
correlations_components = y_primes * x_segments

# J, M as in [1], eq.6
J = x_segments.shape[0]
M = x_segments.shape[1]
correlations_components = np.sum(y_primes * x_segments, axis=-2, keepdims=True)

# Find the mean of all correlations
d = np.sum(correlations_components) / (J * M)
return d
d = np.mean(correlations_components, axis=(1, 3), keepdims=True)
# Exclude the contribution of silent frames from the calculation of the mean
d *= np.mean(mask, axis=1, keepdims=True)[..., None, None]
# Return just a float if stoi was called with a single waveform
return np.squeeze(d).reshape(out_shape)
96 changes: 49 additions & 47 deletions pystoi/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def resample_oct(x, p, q):
"""Resampler that is compatible with Octave"""
h = _resample_window_oct(p, q)
window = h / np.sum(h)
return resample_poly(x, p, q, window=window)
return resample_poly(x, p, q, axis=-1, window=window)


@functools.lru_cache(maxsize=None)
Expand Down Expand Up @@ -95,39 +95,51 @@ def stft(x, win_size, fft_size, overlap=4):
"""
hop = int(win_size / overlap)
w = np.hanning(win_size + 2)[1: -1] # = matlab.hanning(win_size)
stft_out = np.array([np.fft.rfft(w * x[i:i + win_size], n=fft_size)
for i in range(0, len(x) - win_size, hop)])
return stft_out
stft_out = np.array([np.fft.rfft(w * x[:, i:i + win_size], n=fft_size)
for i in range(0, x.shape[-1] - win_size, hop)])
return stft_out.transpose([1, 0, 2])


def _overlap_and_add(x_frames, hop):
num_frames, framelen = x_frames.shape
batch_size, num_frames, framelen = x_frames.shape
# Compute the number of segments, per frame.
segments = -(-framelen // hop) # Divide and round up.

# Pad the framelen dimension to segments * hop and add n=segments frames
signal = np.pad(x_frames, ((0, segments), (0, segments * hop - framelen)))

# Reshape to a 3D tensor, splitting the framelen dimension in two
signal = signal.reshape((num_frames + segments, segments, hop))
# Transpose dimensions so that signal.shape = (segments, frame+segments, hop)
signal = np.transpose(signal, [1, 0, 2])
# Reshape so that signal.shape = (segments * (frame+segments), hop)
signal = signal.reshape((-1, hop))

# Now behold the magic!! Remove the last n=segments elements from the first axis
signal = signal[:-segments]
# Reshape to (segments, frame+segments-1, hop)
signal = signal.reshape((segments, num_frames + segments - 1, hop))
signal = np.pad(x_frames, ((0, 0), (0, segments), (0, segments * hop - framelen)))

# Reshape to a 4D tensor, splitting the framelen dimension in two
signal = signal.reshape((batch_size, num_frames + segments, segments, hop))
# Transpose dimensions so that signal.shape = (batch, segments, frame+segments, hop)
signal = np.transpose(signal, [0, 2, 1, 3])
# Reshape so that signal.shape = (batch, segments * (frame+segments), hop)
signal = signal.reshape((batch_size, -1, hop))

# Now behold the magic!! Remove the last n=segments elements from the second axis
signal = signal[:, :-segments]
# Reshape to (batch, segments, frame+segments-1, hop)
signal = signal.reshape((batch_size, segments, num_frames + segments - 1, hop))
# This has introduced a shift by one in all rows

# Now, reduce over the columns and flatten the array to achieve the result
signal = np.sum(signal, axis=0)
end = (len(x_frames) - 1) * hop + framelen
signal = signal.reshape(-1)[:end]
signal = np.sum(signal, axis=1)
end = (num_frames - 1) * hop + framelen
signal = signal.reshape((batch_size, -1))[:end]
return signal


def segment_frames(x, mask, seg_size):
    """ Cut a batch of frame sequences into overlapping segments of seg_size frames

    # Arguments
        x : array of shape (batch, num_frames, bands)
        mask : boolean array of shape (batch, num_frames)
        seg_size : number of consecutive frames per segment
    # Returns
        Array of shape (batch, num_segments, seg_size, bands), one segment
        per end frame, multiplied by the mask value of its last frame
    """
    num_frames = x.shape[1]
    windows = [x[:, stop - seg_size:stop] for stop in range(seg_size, num_frames + 1)]
    # Stack along a new axis 1 so that the batch dimension stays first
    stacked = np.stack(windows, axis=1)
    # The segment ending at frame m is scaled by mask[:, m]
    segment_mask = mask[:, seg_size - 1:, None, None]
    return stacked * segment_mask


def _mask_audio(x, mask):
return np.array([
np.pad(xi[mi], ((0, len(xi) - np.sum(mi)), (0, 0))) for xi, mi in zip(x, mask)
])


def remove_silent_frames(x, y, dyn_range, framelen, hop):
""" Remove silent frames of x and y based on x
A frame is excluded if its energy is lower than max(energy) - dyn_range
Expand All @@ -139,53 +151,43 @@ def remove_silent_frames(x, y, dyn_range, framelen, hop):
framelen : Window size for energy evaluation
hop : Hop size for energy evaluation
# Returns :
x without the silent frames
y without the silent frames (aligned to x)
x without the silent frames, zero-padded to the original length
y without the silent frames, zero-padded to the original length (aligned to x)
"""
# Compute Mask
w = np.hanning(framelen + 2)[1:-1]

x_frames = np.array(
[w * x[i:i + framelen] for i in range(0, len(x) - framelen, hop)])
[w * x[..., i : i + framelen] for i in range(0, x.shape[-1] - framelen, hop)]
).transpose([1, 0, 2])
y_frames = np.array(
[w * y[i:i + framelen] for i in range(0, len(x) - framelen, hop)])
[w * y[..., i : i + framelen] for i in range(0, x.shape[-1] - framelen, hop)]
).transpose([1, 0, 2])

# Compute energies in dB
x_energies = 20 * np.log10(np.linalg.norm(x_frames, axis=1) + EPS)
x_energies = 20 * np.log10(np.linalg.norm(x_frames, axis=-1) + EPS)

# Find boolean mask of energies lower than dynamic_range dB
# with respect to maximum clean speech energy frame
mask = (np.max(x_energies) - dyn_range - x_energies) < 0

# Remove silent frames by masking
x_frames = x_frames[mask]
y_frames = y_frames[mask]
# Remove silent frames and pad with zeroes
x_frames = _mask_audio(x_frames, mask)
y_frames = _mask_audio(y_frames, mask)

x_sil = _overlap_and_add(x_frames, hop)
y_sil = _overlap_and_add(y_frames, hop)

return x_sil, y_sil


def vect_two_norm(x, axis=-1):
    """ Returns the sum of squared entries of x along `axis` (the squared
    two-norm), keeping the reduced axis with length one """
    squared = np.square(x)
    return squared.sum(axis=axis, keepdims=True)


def row_col_normalize(x):
    """ Row and column mean and variance normalize an array of 2D segments

    # Arguments
        x : array of shape (batch, num_segments, seg_size, bands)
    # Returns
        Array of the same shape, centered and scaled to unit norm first along
        the seg_size axis and then along the bands axis
    """
    def _center_and_scale(values, axis):
        # Remove the mean, then divide by the norm (EPS guards against 0/0)
        centered = values - np.mean(values, axis=axis, keepdims=True)
        return centered / (np.linalg.norm(centered, axis=axis, keepdims=True) + EPS)

    # Row normalization -- axis: seg_size
    row_normed = _center_and_scale(x, axis=-2)
    # Column normalization -- axis: bands
    return _center_and_scale(row_normed, axis=-1)
65 changes: 59 additions & 6 deletions tests/test_overlap_and_add.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import numpy as np
import pytest
from numpy.testing import assert_allclose

from pystoi.stoi import N_FRAME
from pystoi.stoi import N_FRAME, stoi, FS
from pystoi.utils import _overlap_and_add


Expand All @@ -15,12 +16,64 @@ def old_overlap_and_app(x_frames, hop):
x_sil[range(i * hop, i * hop + framelen)] += x_frames[i, :]
return x_sil

batch_size = 4
# Initialize
x = np.random.randn(1000 * N_FRAME)
x = np.random.randn(batch_size, 1000 * N_FRAME)
# Add silence segment
silence = np.zeros(10 * N_FRAME)
x = np.concatenate([x[: 500 * N_FRAME], silence, x[500 * N_FRAME :]])
x = x.reshape([-1, N_FRAME])
xs = old_overlap_and_app(x, N_FRAME // 2)
silence = np.zeros((batch_size, 10 * N_FRAME))
x = np.concatenate([x[:, : 500 * N_FRAME], silence, x[:, 500 * N_FRAME :]], axis=1)
x = x.reshape([batch_size, -1, N_FRAME])
xs = [old_overlap_and_app(xi, N_FRAME // 2) for xi in x]
xs_vectorise = _overlap_and_add(x, N_FRAME // 2)
assert_allclose(xs, xs_vectorise)


@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("fs", [10000, 16000])
@pytest.mark.parametrize("extended", [True, False])
def test_pystoi_run(batch_size, fs, extended):
    """Smoke test: stoi runs on a batch and returns one score per waveform."""
    num_samples = 4 * fs  # 4 seconds of random audio
    signal = np.random.randn(batch_size, num_samples)
    score = stoi(signal, signal, fs, extended)
    print(batch_size, fs, extended, score)
    expected_shape = signal.shape[:-1]
    assert score.shape == expected_shape


@pytest.mark.parametrize("extended", [True, False])
@pytest.mark.parametrize("batch_size", [1, 4])
def test_pystoi_complete_silence(batch_size, extended):
    """stoi on all-zero input still returns one score per waveform."""
    fs = 16000
    num_samples = 4 * fs  # 4 seconds of digital silence
    silent = np.zeros((batch_size, num_samples))
    score = stoi(silent, silent, fs, extended)
    print(batch_size, fs, extended, score)
    expected_shape = silent.shape[:-1]
    assert score.shape == expected_shape


@pytest.mark.parametrize("extended", [True, False])
def test_pystoi_silence(extended):
    """stoi handles a silent gap placed at a different offset in each waveform.

    All random draws come from one seeded generator so the test is
    reproducible.
    """
    rng = np.random.default_rng(seed=0)
    batch_size = 4
    fs = 16000
    num_samples = fs * 4  # 4 seconds of random audio
    x = rng.standard_normal((batch_size, num_samples))
    # Real zeros, not noise, so the gap is actually silent
    silence = np.zeros(num_samples // 7)
    audio = []
    for waveform in x:
        # Insert the gap at a random position, different for each waveform
        t = int(rng.integers(num_samples))
        audio.append(np.concatenate([waveform[:t], silence, waveform[t:]]))
    audio = np.array(audio)
    res = stoi(audio, audio, fs, extended)
    print(batch_size, fs, extended, res)
    assert res.shape == audio.shape[:-1]


def test_vectorisation():
    """Batched stoi must agree with scoring each waveform on its own."""
    batch_size = 4
    num_samples = 100 * N_FRAME
    clean = np.random.random((batch_size, num_samples))
    noisy = np.random.random((batch_size, num_samples))

    # Reference: one call per waveform pair
    per_item = []
    for clean_i, noisy_i in zip(clean, noisy):
        per_item.append(stoi(clean_i, noisy_i, FS))
    per_item = np.array(per_item)

    # Candidate: a single call on the whole batch
    batched = stoi(clean, noisy, FS)

    assert batched.shape == clean.shape[:-1]
    assert np.allclose(per_item, batched)