From d6ca0060a845fa3af9aabd8e998e5f5672792493 Mon Sep 17 00:00:00 2001 From: Alejandro Gaston Alvarez Franceschi Date: Tue, 24 Oct 2023 15:43:50 +0200 Subject: [PATCH 01/24] Add windowing function --- .../passes/defs/lower_complex_dialect_ops.py | 28 ++++++++++++------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py index ed36d87f3..18e4381c2 100644 --- a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py +++ b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py @@ -345,16 +345,7 @@ def _stft( # create a window of centered 1s of the requested size if win_length: - n_left = (n_fft.val - win_length.val) // 2 - n_right = n_fft.val - win_length.val - n_left - - left = mb.fill(shape=(n_left,), value=0., before_op=before_op) - if not window: - window = mb.fill(shape=(win_length.val,), value=1., before_op=before_op) - right = mb.fill(shape=(n_right,), value=0., before_op=before_op) - - # concatenate - window = mb.concat(values=(left, window, right), axis=0, before_op=before_op) + window = _get_window(win_length=win_length, n_fft=n_fft, before_op=before_op) # apply time window if window: @@ -397,6 +388,23 @@ def _stft( return real_result, imag_result +def _get_window( + win_length: Var, + n_fft: Var, + before_op: Operation, +) -> Var: + n_left = (n_fft.val - win_length.val) // 2 + n_right = n_fft.val - win_length.val - n_left + + left = mb.fill(shape=(n_left,), value=0., before_op=before_op) + if not window: + window = mb.fill(shape=(win_length.val,), value=1., before_op=before_op) + right = mb.fill(shape=(n_right,), value=0., before_op=before_op) + + # concatenate + return mb.concat(values=(left, window, right), axis=0, before_op=before_op) + + def _wrap_complex_output(original_output: Var, real_data: Var, imag_data: Var) -> ComplexVar: return ComplexVar( name=original_output.name + "_lowered", From b101282b398d77a00f1db2d6b70e9e943afbf1b9 Mon Sep 17 00:00:00 2001 From: Alejandro Gaston Alvarez Franceschi Date: Tue, 24 Oct 2023 15:44:18 +0200 Subject: [PATCH 02/24] Add ovelap-add --- .../passes/defs/lower_complex_dialect_ops.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py index 18e4381c2..dc838168a 100644 --- a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py +++ b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py @@ -376,6 +376,10 @@ def _stft( real_result = cos_windows_real imag_result = sin_windows_real + # Overlap-add + real_result = _overlap_add(x=real_result, n_fft=n_fft, hop_length=hop_length, before_op=before_op) + imag_result = _overlap_add(x=imag_result, n_fft=n_fft, hop_length=hop_length, before_op=before_op) + # reduce the rank of the output if should_increase_rank: real_result = mb.squeeze(x=real_result, axes=(0,), before_op=before_op) @@ -388,6 +392,23 @@ def _stft( return real_result, imag_result +def _overlap_add( + x: Var, + n_fft: Var, + hop_length: Var, + before_op: Operation, +) -> Var: + n_frames = mb.shape(x=x, before_op=before_op)[1] + output = mb.fill(shape=(n_fft + hop_length * (n_frames - 1)), value=0., before_op=before_op) + signal_frames = mb.split(x=x, num_splits=n_frames, axis=1, before_op=before_op) + local_idx = mb.range_1d(start=0, end=n_fft, step=1, before_op=before_op) + + for frame_num, frame in enumerate(signal_frames): + global_idx = mb.add(x=local_idx , y=frame_num*hop_length.val, before_op=before_op) + output = mb.scatter_nd(data=output, indices=global_idx, updates=frame, before_op=before_op) + + return output + def _get_window( win_length: Var, n_fft: Var, From 18fdef4d26e03d30de0ebcab73223733cbd9e8d9 Mon Sep 17 00:00:00 2001 From: Alejandro Gaston Alvarez Franceschi Date: Tue, 24 Oct 2023 15:48:48 +0200 Subject: [PATCH 03/24] Add ISTFT --- .../mil/mil/ops/defs/complex_dialect_ops.py | 82 +++++++++++++++++ .../passes/defs/lower_complex_dialect_ops.py | 92 +++++++++++++++++++ 2 files changed, 174 insertions(+) diff --git a/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py b/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py index e19bf1757..44d262d25 100644 --- a/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py +++ b/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py @@ -861,3 +861,85 @@ def type_inference(self): return types.tensor(output_type, tuple(output_shape)) +@register_op(namespace="complex") +class complex_istft(Operation): + """ + Dialect op for 1-D ISTFT. + + Parameters + ---------- + input: tensor<\*V, complex64> (Required) + * A complex tensor where real and imag parts have the same shape. + n_fft: const i32 (Required) + * Size of the fourier transform. + hop_length: const i32 (Optional) + * Stride between window frames of the input tensor. + win_length: const i32 (optional) + * The size of the window frame. + window: tensor<1, win_length> (optional) + * The window to apply to the input signal before performing the fourier transform. + normalized: const bool (optional, Default=``false``) + * Whether to normalize the results of the STFT + onesided: const bool (optional, Default=``true``) + * Whether the STFT was onesieded + length: const i32 (Required) + * Output fixed length, which will be zeropadded + + + Returns + ------- + tensor<\*D, T> + * The output tensor + + Attributes + ---------- + T: fp32, complex64 + + References + ---------- + See `torch.istft `_. + """ + + input_spec = InputSpec( + input=TensorInputType(type_domain="T"), + n_fft=TensorInputType(const=True, type_domain=types.int32), + hop_length=TensorInputType(const=True, optional=True, type_domain=types.int32), + win_length=TensorInputType(const=True, optional=True, type_domain=types.int32), + window=TensorInputType(const=True, optional=True, type_domain=types.fp32), + normalized=TensorInputType(const=True, optional=True, type_domain=types.bool), + onesided=TensorInputType(const=True, optional=True, type_domain=types.bool), + length=TensorInputType(const=True, optional=True, type_domain=types.int32), + ) + + type_domains = { + "T": (types.fp32, types.complex64), + } + + def default_inputs(self): + return DefaultInputs( + hop_length = None, + win_length = None, + window = None, + normalized = False, + onesided = True, + length = None + ) + + def type_inference(self): + output_type = (types.fp32) + output_shape = [] + + # add back rank if needed + if self.input.rank == 2: + output_shape += [self.input.shape[0]] + + if self.length: + output_shape += [self.length] + return types.tensor(output_type, tuple(output_shape)) + + + n_frames = self.input.shape[-1] + output_shape = self.n_fft.val + self.hop_length.val * (n_frames - 1) + + return types.tensor(output_type, tuple(output_shape)) + diff --git a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py index dc838168a..9ef0e1478 100644 --- a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py +++ b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py @@ -376,6 +376,80 @@ def _stft( real_result = cos_windows_real imag_result = sin_windows_real +def _istft( + input_real: Var, + input_imaginary: Var, + n_fft: Var, + hop_length: Optional[Var], + win_length: Optional[Var], + window: Optional[Var], + normalized: Optional[Var], + onesided: Optional[Var], + before_op: Operation, +) -> Tuple[Var, Var]: + """ + We can write ISTFT in terms of convolutions with a DFT kernel. + At the end: + * The real part output is: cos_base * input_real + sin_base * input_imag + * The imaginary part output is: - (sin_base * input_real - cos_base * input_imag) + Adapted from: https://github.com/adobe-research/convmelspec/blob/main/convmelspec/mil.py + """ + # Set the default hop, if it's not already specified + hop_length = hop_length or mb.floor_div(x=n_fft, y=4, before_op=before_op) + + # By default, use the entire frame + win_length = win_length or n_fft + + # input should always be 2D + should_increase_rank = input_real.rank == 1 + if should_increase_rank: + input_real = mb.expand_dims(x=input_real, axes=(0,), before_op=before_op) + if input_imaginary: + input_imaginary = mb.expand_dims(x=input_imaginary, axes=(0,), before_op=before_op) + + is_onesided = onesided and onesided.val + cos_base, sin_base = _calculate_dft_matrix(n_fft, onesided=is_onesided, before_op=before_op) + + # create a window of centered 1s of the requested size + if win_length: + window = _get_window(win_length=win_length, n_fft=n_fft, before_op=before_op) + + # apply time window + if window: + cos_base = mb.mul(x=window, y=cos_base, before_op=before_op) + sin_base = mb.mul(x=window, y=sin_base, before_op=before_op) + + # The DFT matrix is obtained with the equation e^(2pi/N i), which is what we want but we actually need the conjuate => e^(-2pi/N i) + # or in terms of cos and sin => cos+i*sin cos-i*sin + sin_base = mb.sub(x=0., ysin_base, before_op=before_op) + + cos_base = mb.expand_dims(x=cos_base, axes=(1,), before_op=before_op) + sin_base = mb.expand_dims(x=sin_base, axes=(1,), before_op=before_op) + hop_size = mb.expand_dims(x=hop_length, axes=(0,), before_op=before_op) + + signal_real = mb.expand_dims(x=input_real, axes=(1,), before_op=before_op) + signal_imaginary = mb.expand_dims(x=input_imaginary, axes=(1,), before_op=before_op) + + # Conv with DFT kernel across the input signal + # We can describe the IDFT in terms of DFT just by swapping the input and output + # ref: https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Expressing_the_inverse_DFT_in_terms_of_the_DFT + # So IDFT(x) = (1/N) * swap(DFT(swap(x))) + # DFT(x) => X[k] = Σx[n]*e^(-2kpi/N i) + # If x is complex then x[n]=(a+i*b) + # So the real part = (1/N)*Σ(a*cos(2kpi/N)-b*sin(2kpi/N)) + # So the imag part = (1/N)*Σ(b*cos(2kpi/N)+a*sin(2kpi/N)) + cos_windows_real = mb.conv(x=signal_real, weight=cos_base, strides=hop_size, pad_type='valid', before_op=before_op) + sin_windows_real = mb.conv(x=signal_real, weight=sin_base, strides=hop_size, pad_type='valid', before_op=before_op) + cos_windows_imag = mb.conv(x=signal_imaginary, weight=cos_base, strides=hop_size, pad_type='valid', before_op=before_op) + sin_windows_imag = mb.conv(x=signal_imaginary, weight=sin_base, strides=hop_size, pad_type='valid', before_op=before_op) + + real_result = mb.sub(x=cos_windows_real, y=sin_windows_imag, before_op=before_op) + imag_result = mb.add(x=cos_windows_imag, y=sin_windows_real, before_op=before_op) + + # Divide by N + real_result = mb.real_div(x=real_result, y=n_fft, before_op=before_op) + imag_result = mb.real_div(x=imag_result, y=n_fft, before_op=before_op) + # Overlap-add real_result = _overlap_add(x=real_result, n_fft=n_fft, hop_length=hop_length, before_op=before_op) imag_result = _overlap_add(x=imag_result, n_fft=n_fft, hop_length=hop_length, before_op=before_op) @@ -638,6 +712,24 @@ def _lower_complex_stft(op: Operation): return _wrap_complex_output(op.outputs[0], real, imag) +@LowerComplex.register_lower_func(op_type="complex_istft") +def _lower_complex_istft(op: Operation): + is_complex = types.is_complex(op.input.dtype) + + # check parameters for validity + if op.win_length and op.win_length.val > op.n_fft.val: + raise ValueError("Window length must be less than or equal to n_fft") + if is_complex and op.onesided and op.onesided.val: + raise ValueError("Onesided is only valid for real inputs") + + real, imag = _istft( + op.input.real if is_complex else op.input, + op.input.imag if is_complex else None, + op.n_fft, op.hop_length, op.win_length, op.window, op.normalized, op.onesided, before_op=op) + + return _wrap_complex_output(op.outputs[0], real, imag) + + @LowerComplex.register_lower_func(op_type="complex_shape") def _lower_complex_shape(op: Operation): return mb.shape(x=op.data.real, before_op=op) From 67b5db5036086f98ef8f3dfe1c298e95b1906e5a Mon Sep 17 00:00:00 2001 From: Alejandro Gaston Alvarez Franceschi Date: Tue, 24 Oct 2023 16:01:55 +0200 Subject: [PATCH 04/24] Normalize by window square --- .../mil/passes/defs/lower_complex_dialect_ops.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py index 9ef0e1478..35cbad1c1 100644 --- a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py +++ b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py @@ -454,16 +454,20 @@ def _istft( real_result = _overlap_add(x=real_result, n_fft=n_fft, hop_length=hop_length, before_op=before_op) imag_result = _overlap_add(x=imag_result, n_fft=n_fft, hop_length=hop_length, before_op=before_op) + # Normalize by the window square + n_frames = mb.shape(x=real_result, before_op=before_op)[1] + window_square = mb.mul(x=window, y=window, before_op=before_op) + window_mtx = mb.stack(values=[window_square] * n_frames, axis=1) + normalization_factor = _overlap_add(x=window_mtx, n_fft=n_fft, hop_length=hop_length, before_op=before_op) + + real_result = mb.real_div(x=real_result, y=normalization_factor, before_op=before_op) + imag_result = mb.real_div(x=imag_result, y=normalization_factor, before_op=before_op) + # reduce the rank of the output if should_increase_rank: real_result = mb.squeeze(x=real_result, axes=(0,), before_op=before_op) imag_result = mb.squeeze(x=imag_result, axes=(0,), before_op=before_op) - if normalized and normalized.val: - divisor = mb.sqrt(x=mb.cast(x=n_fft, dtype="fp32", before_op=before_op), before_op=before_op) - real_result = mb.real_div(x=real_result, y=divisor, before_op=before_op) - imag_result = mb.real_div(x=imag_result, y=divisor, before_op=before_op) - return real_result, imag_result def _overlap_add( @@ -473,7 +477,7 @@ def _overlap_add( before_op: Operation, ) -> Var: n_frames = mb.shape(x=x, before_op=before_op)[1] - output = mb.fill(shape=(n_fft + hop_length * (n_frames - 1)), value=0., before_op=before_op) + output = mb.fill(shape=(n_fft.val + hop_length.val * (n_frames - 1)), value=0., before_op=before_op) signal_frames = mb.split(x=x, num_splits=n_frames, axis=1, before_op=before_op) local_idx = mb.range_1d(start=0, end=n_fft, step=1, before_op=before_op) From 304b1ca40c63a24f81cb101c54e06f952bdbf0a7 Mon Sep 17 00:00:00 2001 From: Alejandro Gaston Alvarez Franceschi Date: Tue, 24 Oct 2023 19:03:14 +0200 Subject: [PATCH 05/24] Fix --- .../converters/mil/mil/passes/defs/lower_complex_dialect_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py index 35cbad1c1..3e94bd1f9 100644 --- a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py +++ b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py @@ -421,7 +421,7 @@ def _istft( # The DFT matrix is obtained with the equation e^(2pi/N i), which is what we want but we actually need the conjuate => e^(-2pi/N i) # or in terms of cos and sin => cos+i*sin cos-i*sin - sin_base = mb.sub(x=0., ysin_base, before_op=before_op) + sin_base = mb.sub(x=0., y=sin_base, before_op=before_op) cos_base = mb.expand_dims(x=cos_base, axes=(1,), before_op=before_op) sin_base = mb.expand_dims(x=sin_base, axes=(1,), before_op=before_op) From ec1c232633bebbea6dd55c9b6471aa25f6e5538e Mon Sep 17 00:00:00 2001 From: Alejandro Gaston Alvarez Franceschi Date: Tue, 24 Oct 2023 19:03:27 +0200 Subject: [PATCH 06/24] Add test --- .../mil/frontend/torch/test/test_torch_ops.py | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py index 242ec8740..3efd9df22 100644 --- a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py +++ b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py @@ -9588,6 +9588,54 @@ def forward(self, x): compute_unit=compute_unit ) +class TestISTFT(TorchBaseTest): + @pytest.mark.slow + @pytest.mark.parametrize( + "compute_unit, backend, input_shape, complex, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided, length", + itertools.product( + compute_units, + backends, + [(1, 32, 9), (32, 9), (3, 32, 9)], # input shape + [False, True], # complex + [16], # n_fft + [None, 4, 5], # hop_length + [None, 16, 9], # win_length + [None, torch.hann_window], # window + [None, False, True], # center + ["constant", "reflect", "replicate"], # pad mode + [False, True], # normalized + [None, False, True], # onesided + [None, 60], # length + ) + ) + def test_istft(self, compute_unit, backend, input_shape, complex, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided): + if complex and onesided: + pytest.skip("Onesided stft not possible for complex inputs") + + class ISTFTModel(torch.nn.Module): + def forward(self, x): + applied_window = window(win_length) if window and win_length else None + x = torch.complex(x, x) + x = torch.istft( + x, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=applied_window, + center=center, + normalized=normalized, + onesided=onesided, + length=length, + return_complex=True) + x = torch.stack([torch.real(x), torch.imag(x)], dim=0) + return x + + TorchBaseTest.run_compare_torch( + input_shape, + ISTFTModel(), + backend=backend, + compute_unit=compute_unit + ) if _HAS_TORCH_AUDIO: From a9e49fda2effec44b50512f3b7afa0ba3cd0d8c5 Mon Sep 17 00:00:00 2001 From: Alejandro Gaston Alvarez Franceschi Date: Tue, 24 Oct 2023 19:07:05 +0200 Subject: [PATCH 07/24] Simplify stft --- .../passes/defs/lower_complex_dialect_ops.py | 40 +++++++++++++------ 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py index 3e94bd1f9..30b6b4ddc 100644 --- a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py +++ b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py @@ -338,10 +338,7 @@ def _stft( input_imaginary = mb.expand_dims(x=input_imaginary, axes=(0,), before_op=before_op) is_onesided = onesided and onesided.val - cos_base, sin_base = _calculate_dft_matrix( - n_fft, - onesided=is_onesided, - before_op=before_op) + cos_base, sin_base = _calculate_dft_matrix(n_fft, onesided=is_onesided, before_op=before_op) # create a window of centered 1s of the requested size if win_length: @@ -352,29 +349,46 @@ def _stft( cos_base = mb.mul(x=window, y=cos_base, before_op=before_op) sin_base = mb.mul(x=window, y=sin_base, before_op=before_op) - # conv with DFT kernel across the input signal - sin_base = mb.sub(x=0., y=sin_base, before_op=before_op) + + # Expand cos_base = mb.expand_dims(x=cos_base, axes=(1,), before_op=before_op) sin_base = mb.expand_dims(x=sin_base, axes=(1,), before_op=before_op) hop_size = mb.expand_dims(x=hop_length, axes=(0,), before_op=before_op) - signal_real = mb.expand_dims(x=input_real, axes=(1,), before_op=before_op) + if input_imaginary: + signal_imaginary = mb.expand_dims(x=input_imaginary, axes=(1,), before_op=before_op) + + # conv with DFT kernel across the input signal + # The DFT matrix is obtained with the equation e^(2pi/N i) but the definition is: + # DFT(x) => X[k] = Σx[n]*e^(-2kpi/N i) + # If x is complex then x[n]=(a+i*b) + # So the real part = Σ(a*cos(2kpi/N)+b*sin(2kpi/N)) + # So the imag part = Σ(b*cos(2kpi/N)-a*sin(2kpi/N)) cos_windows_real = mb.conv(x=signal_real, weight=cos_base, strides=hop_size, pad_type='valid', before_op=before_op) sin_windows_real = mb.conv(x=signal_real, weight=sin_base, strides=hop_size, pad_type='valid', before_op=before_op) - if input_imaginary: - signal_imaginary = mb.expand_dims(x=input_imaginary, axes=(1,), before_op=before_op) cos_windows_imag = mb.conv(x=signal_imaginary, weight=cos_base, strides=hop_size, pad_type='valid', before_op=before_op) sin_windows_imag = mb.conv(x=signal_imaginary, weight=sin_base, strides=hop_size, pad_type='valid', before_op=before_op) # add everything together if input_imaginary: - # sin base is already negative so subtract - real_result = mb.sub(x=cos_windows_real, y=sin_windows_imag, before_op=before_op) - imag_result = mb.add(x=sin_windows_real, y=cos_windows_imag, before_op=before_op) + real_result = mb.add(x=cos_windows_real, y=sin_windows_imag, before_op=before_op) + imag_result = mb.sub(x=cos_windows_imag, y=sin_windows_real, before_op=before_op) else: real_result = cos_windows_real - imag_result = sin_windows_real + imag_result = mb.sub(x=0., y=sin_windows_real, before_op=before_op) + + # reduce the rank of the output + if should_increase_rank: + real_result = mb.squeeze(x=real_result, axes=(0,), before_op=before_op) + imag_result = mb.squeeze(x=imag_result, axes=(0,), before_op=before_op) + + if normalized and normalized.val: + divisor = mb.sqrt(x=mb.cast(x=n_fft, dtype="fp32", before_op=before_op), before_op=before_op) + real_result = mb.real_div(x=real_result, y=divisor, before_op=before_op) + imag_result = mb.real_div(x=imag_result, y=divisor, before_op=before_op) + + return real_result, imag_result def _istft( input_real: Var, From 38aa86b264a666522d932c615732891d48bc882c Mon Sep 17 00:00:00 2001 From: Alejandro Gaston Alvarez Franceschi Date: Tue, 24 Oct 2023 20:36:37 +0200 Subject: [PATCH 08/24] updates --- .../passes/defs/lower_complex_dialect_ops.py | 63 +++++++++++-------- 1 file changed, 36 insertions(+), 27 deletions(-) diff --git a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py index 30b6b4ddc..4313ef53b 100644 --- a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py +++ b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py @@ -403,10 +403,12 @@ def _istft( ) -> Tuple[Var, Var]: """ We can write ISTFT in terms of convolutions with a DFT kernel. - At the end: - * The real part output is: cos_base * input_real + sin_base * input_imag - * The imaginary part output is: - (sin_base * input_real - cos_base * input_imag) - Adapted from: https://github.com/adobe-research/convmelspec/blob/main/convmelspec/mil.py + + The input has shape (channels, fft_size, n_frames) + + References: + H. Zhivomirov, “On the Development of STFT-analysis and ISTFT-synthesis Routines and their Practical Implementation,” TEM Journal, vol. 8, no. 1, pp. 56–64, 2019. + https://en.wikipedia.org/wiki/Discrete_Fourier_transform """ # Set the default hop, if it's not already specified hop_length = hop_length or mb.floor_div(x=n_fft, y=4, before_op=before_op) @@ -414,14 +416,12 @@ def _istft( # By default, use the entire frame win_length = win_length or n_fft - # input should always be 2D - should_increase_rank = input_real.rank == 1 - if should_increase_rank: - input_real = mb.expand_dims(x=input_real, axes=(0,), before_op=before_op) - if input_imaginary: - input_imaginary = mb.expand_dims(x=input_imaginary, axes=(0,), before_op=before_op) + input_shape = mb.shape(x=x, before_op=before_op) + n_frames = input_shape.val[-1] + fft_size = input_shape.val[-2] + expected_output_signal_len = n_fft.val + hop_length.val * (n_frames - 1) - is_onesided = onesided and onesided.val + is_onesided = onesided.val if onesided else fft_size != n_fft cos_base, sin_base = _calculate_dft_matrix(n_fft, onesided=is_onesided, before_op=before_op) # create a window of centered 1s of the requested size @@ -433,10 +433,6 @@ def _istft( cos_base = mb.mul(x=window, y=cos_base, before_op=before_op) sin_base = mb.mul(x=window, y=sin_base, before_op=before_op) - # The DFT matrix is obtained with the equation e^(2pi/N i), which is what we want but we actually need the conjuate => e^(-2pi/N i) - # or in terms of cos and sin => cos+i*sin cos-i*sin - sin_base = mb.sub(x=0., y=sin_base, before_op=before_op) - cos_base = mb.expand_dims(x=cos_base, axes=(1,), before_op=before_op) sin_base = mb.expand_dims(x=sin_base, axes=(1,), before_op=before_op) hop_size = mb.expand_dims(x=hop_length, axes=(0,), before_op=before_op) @@ -444,21 +440,27 @@ def _istft( signal_real = mb.expand_dims(x=input_real, axes=(1,), before_op=before_op) signal_imaginary = mb.expand_dims(x=input_imaginary, axes=(1,), before_op=before_op) + # De-normalized signal before applying the IFT + if normalized and normalized.val: + multiplier = mb.sqrt(x=mb.cast(x=n_fft, dtype="fp32", before_op=before_op), before_op=before_op) + signal_real = mb.mul(x=signal_real, y=multiplier, before_op=before_op) + signal_imaginary = mb.mul(x=signal_imaginary, y=multiplier, before_op=before_op) + # Conv with DFT kernel across the input signal # We can describe the IDFT in terms of DFT just by swapping the input and output # ref: https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Expressing_the_inverse_DFT_in_terms_of_the_DFT # So IDFT(x) = (1/N) * swap(DFT(swap(x))) - # DFT(x) => X[k] = Σx[n]*e^(-2kpi/N i) + # and DFT(x) = X[k] = Σx[n]*e^(-2kpi/N i) but we are using the conjugate e^(2kpi/N i) # If x is complex then x[n]=(a+i*b) - # So the real part = (1/N)*Σ(a*cos(2kpi/N)-b*sin(2kpi/N)) - # So the imag part = (1/N)*Σ(b*cos(2kpi/N)+a*sin(2kpi/N)) + # then real part = (1/N)*Σ(a*cos(2kpi/N)+b*sin(2kpi/N)) + # then imag part = (1/N)*Σ(b*cos(2kpi/N)-a*sin(2kpi/N)) cos_windows_real = mb.conv(x=signal_real, weight=cos_base, strides=hop_size, pad_type='valid', before_op=before_op) sin_windows_real = mb.conv(x=signal_real, weight=sin_base, strides=hop_size, pad_type='valid', before_op=before_op) cos_windows_imag = mb.conv(x=signal_imaginary, weight=cos_base, strides=hop_size, pad_type='valid', before_op=before_op) sin_windows_imag = mb.conv(x=signal_imaginary, weight=sin_base, strides=hop_size, pad_type='valid', before_op=before_op) - real_result = mb.sub(x=cos_windows_real, y=sin_windows_imag, before_op=before_op) - imag_result = mb.add(x=cos_windows_imag, y=sin_windows_real, before_op=before_op) + real_result = mb.add(x=cos_windows_real, y=sin_windows_imag, before_op=before_op) + imag_result = mb.sub(x=cos_windows_imag, y=sin_windows_real, before_op=before_op) # Divide by N real_result = mb.real_div(x=real_result, y=n_fft, before_op=before_op) @@ -472,10 +474,9 @@ def _istft( n_frames = mb.shape(x=real_result, before_op=before_op)[1] window_square = mb.mul(x=window, y=window, before_op=before_op) window_mtx = mb.stack(values=[window_square] * n_frames, axis=1) - normalization_factor = _overlap_add(x=window_mtx, n_fft=n_fft, hop_length=hop_length, before_op=before_op) - - real_result = mb.real_div(x=real_result, y=normalization_factor, before_op=before_op) - imag_result = mb.real_div(x=imag_result, y=normalization_factor, before_op=before_op) + window_envelope = _overlap_add(x=window_mtx, n_fft=n_fft, hop_length=hop_length, before_op=before_op) + real_result = mb.real_div(x=real_result, y=window_envelope, before_op=before_op) + imag_result = mb.real_div(x=imag_result, y=window_envelope, before_op=before_op) # reduce the rank of the output if should_increase_rank: @@ -490,13 +491,21 @@ def _overlap_add( hop_length: Var, before_op: Operation, ) -> Var: - n_frames = mb.shape(x=x, before_op=before_op)[1] - output = mb.fill(shape=(n_fft.val + hop_length.val * (n_frames - 1)), value=0., before_op=before_op) - signal_frames = mb.split(x=x, num_splits=n_frames, axis=1, before_op=before_op) + """ + The input has shape (channels, fft_size, n_frames) + """ + input_shape = mb.shape(x=x, before_op=before_op) + channels = input_shape.val[0] + n_frames = input_shape.val[2] + + output = mb.fill(shape=(channels, n_fft.val + hop_length.val * (n_frames - 1)), value=0., before_op=before_op) + signal_frames = mb.split(x=x, num_splits=n_frames, axis=2, before_op=before_op) local_idx = mb.range_1d(start=0, end=n_fft, step=1, before_op=before_op) for frame_num, frame in enumerate(signal_frames): global_idx = mb.add(x=local_idx , y=frame_num*hop_length.val, before_op=before_op) + global_idx = mb.expand_dims(x=global_idx, axes=(0,), before_op=before_op) + global_idx = mb.stack(values=[global_idx] * channels, axis=0) output = mb.scatter_nd(data=output, indices=global_idx, updates=frame, before_op=before_op) return output From 82eefff7c2904c4e11e542f56ae36490e42079f9 Mon Sep 17 00:00:00 2001 From: Alejandro Gaston Alvarez Franceschi Date: Tue, 24 Oct 2023 20:55:15 +0200 Subject: [PATCH 09/24] Try adding length --- .../mil/passes/defs/lower_complex_dialect_ops.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py index 4313ef53b..5fcac5ba2 100644 --- a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py +++ b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py @@ -399,6 +399,7 @@ def _istft( window: Optional[Var], normalized: Optional[Var], onesided: Optional[Var], + length: Optional[Var], before_op: Operation, ) -> Tuple[Var, Var]: """ @@ -419,7 +420,7 @@ def _istft( input_shape = mb.shape(x=x, before_op=before_op) n_frames = input_shape.val[-1] fft_size = input_shape.val[-2] - expected_output_signal_len = n_fft.val + hop_length.val * (n_frames - 1) + # expected_output_signal_len = n_fft.val + hop_length.val * (n_frames - 1) is_onesided = onesided.val if onesided else fft_size != n_fft cos_base, sin_base = _calculate_dft_matrix(n_fft, onesided=is_onesided, before_op=before_op) @@ -478,10 +479,14 @@ def _istft( real_result = mb.real_div(x=real_result, y=window_envelope, before_op=before_op) imag_result = mb.real_div(x=imag_result, y=window_envelope, before_op=before_op) - # reduce the rank of the output - if should_increase_rank: - real_result = mb.squeeze(x=real_result, axes=(0,), before_op=before_op) - imag_result = mb.squeeze(x=imag_result, axes=(0,), before_op=before_op) + # We need to adapt last dimension + if length is not None: + if length > expected_output_signal_len: + real_result = mb.pad(x=real_result, pad=, mode="constant", constant_val=0, before_op=before_op) + imag_result = mb.pad(x=imag_result, pad=, mode="constant", constant_val=0, before_op=before_op) + elif length < expected_output_signal_len: + real_result = mb.slice_by_size(x=real_result, begin=[0], size=[length], before_op=before_op) + imag_result = mb.slice_by_size(x=imag_result, begin=[0], size=[length], before_op=before_op) return real_result, imag_result From 667be1dc3099344436982d7f90e531af2b1beaef Mon Sep 17 00:00:00 2001 From: Alejandro Gaston Alvarez Franceschi Date: Tue, 31 Oct 2023 15:52:42 +0100 Subject: [PATCH 10/24] Fix padding --- .../mil/mil/passes/defs/lower_complex_dialect_ops.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py index 5fcac5ba2..2c37c9b1e 100644 --- a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py +++ b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py @@ -482,8 +482,9 @@ def _istft( # We need to adapt last dimension if length is not None: if length > expected_output_signal_len: - real_result = mb.pad(x=real_result, pad=, mode="constant", constant_val=0, before_op=before_op) - imag_result = mb.pad(x=imag_result, pad=, mode="constant", constant_val=0, before_op=before_op) + right_pad = mb.fill(shape=(channels, expected_output_signal_len - length), value=0., before_op=before_op) + real_result = mb.stack(x=(real_result, right_pad), axis=1, before_op=before_op) + imag_result = mb.stack(x=(imag_result, right_pad), axis=1, before_op=before_op) elif length < expected_output_signal_len: real_result = mb.slice_by_size(x=real_result, begin=[0], size=[length], before_op=before_op) imag_result = mb.slice_by_size(x=imag_result, begin=[0], size=[length], before_op=before_op) From 0bfee9d43db395d58f98536d95a88fb46e057439 Mon Sep 17 00:00:00 2001 From: Alejandro Gaston Alvarez Franceschi Date: Wed, 15 Nov 2023 16:39:57 +0100 Subject: [PATCH 11/24] Fixes --- .../mil/frontend/torch/test/test_torch_ops.py | 20 +++--- .../mil/mil/ops/defs/complex_dialect_ops.py | 6 +- .../passes/defs/lower_complex_dialect_ops.py | 65 ++++++++++--------- 3 files changed, 47 insertions(+), 44 deletions(-) diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py index 3efd9df22..f3b4ef734 100644 --- a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py +++ b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py @@ -9566,9 +9566,8 @@ def test_stft(self, compute_unit, backend, input_shape, complex, n_fft, hop_leng class STFTModel(torch.nn.Module): def forward(self, x): applied_window = window(win_length) if window and win_length else None - x = torch.complex(x, x) if complex else x x = torch.stft( - x, + torch.complex(x, x) if complex else x, n_fft=n_fft, hop_length=hop_length, win_length=win_length, @@ -9596,28 +9595,26 @@ class TestISTFT(TorchBaseTest): compute_units, backends, [(1, 32, 9), (32, 9), (3, 32, 9)], # input shape - [False, True], # complex [16], # n_fft [None, 4, 5], # hop_length [None, 16, 9], # win_length [None, torch.hann_window], # window [None, False, True], # center - ["constant", "reflect", "replicate"], # pad mode [False, True], # normalized [None, False, True], # onesided [None, 60], # length + [False, True], # return_complex ) ) - def test_istft(self, compute_unit, backend, input_shape, complex, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided): - if complex and onesided: - pytest.skip("Onesided stft not possible for complex inputs") + def test_istft(self, compute_unit, backend, input_shape, n_fft, hop_length, win_length, window, center, normalized, onesided, length, return_complex): + if return_complex and onesided: + pytest.skip("Complex output is incompatible with onesided") class ISTFTModel(torch.nn.Module): def forward(self, x): applied_window = window(win_length) if window and win_length else None - x = torch.complex(x, x) x = torch.istft( - x, + torch.complex(x, x), n_fft=n_fft, hop_length=hop_length, win_length=win_length, @@ -9626,8 +9623,9 @@ def forward(self, x): normalized=normalized, onesided=onesided, length=length, - return_complex=True) - x = torch.stack([torch.real(x), torch.imag(x)], dim=0) + return_complex=return_complex) + if return_complex: + x = torch.stack([torch.real(x), torch.imag(x)], dim=0) return x TorchBaseTest.run_compare_torch( diff --git a/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py b/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py index 44d262d25..ea9c13ce4 100644 --- a/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py +++ b/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py @@ -893,6 +893,7 @@ class complex_istft(Operation): Attributes ---------- + V: complex64 T: fp32, complex64 References @@ -901,7 +902,7 @@ class complex_istft(Operation): """ input_spec = InputSpec( - input=TensorInputType(type_domain="T"), + input=TensorInputType(type_domain="V"), n_fft=TensorInputType(const=True, type_domain=types.int32), hop_length=TensorInputType(const=True, optional=True, type_domain=types.int32), win_length=TensorInputType(const=True, optional=True, type_domain=types.int32), @@ -912,7 +913,7 @@ class complex_istft(Operation): ) type_domains = { - "T": (types.fp32, types.complex64), + "V": types.complex64, } def default_inputs(self): @@ -937,7 +938,6 @@ def type_inference(self): output_shape += [self.length] return types.tensor(output_type, tuple(output_shape)) - n_frames = self.input.shape[-1] output_shape = self.n_fft.val + self.hop_length.val * (n_frames - 1) diff --git a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py index 2c37c9b1e..b66f3eabc 100644 --- a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py +++ b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py @@ -325,7 +325,7 @@ def _stft( We can write STFT in terms of convolutions with a DFT kernel. At the end: * The real part output is: cos_base * input_real + sin_base * input_imag - * The imaginary part output is: - (sin_base * input_real - cos_base * input_imag) + * The imaginary part output is: cos_base * input_imag - sin_base * input_real Adapted from: https://github.com/adobe-research/convmelspec/blob/main/convmelspec/mil.py """ hop_length = hop_length or mb.floor_div(x=n_fft, y=4, before_op=before_op) @@ -342,14 +342,13 @@ def _stft( # create a window of centered 1s of the requested size if win_length: - window = _get_window(win_length=win_length, n_fft=n_fft, before_op=before_op) + window = _get_window(win_length=win_length, n_fft=n_fft, window=window, before_op=before_op) # apply time window if window: cos_base = mb.mul(x=window, y=cos_base, before_op=before_op) sin_base = mb.mul(x=window, y=sin_base, before_op=before_op) - # Expand cos_base = mb.expand_dims(x=cos_base, axes=(1,), before_op=before_op) sin_base = mb.expand_dims(x=sin_base, axes=(1,), before_op=before_op) @@ -358,12 +357,13 @@ def _stft( if input_imaginary: signal_imaginary = mb.expand_dims(x=input_imaginary, axes=(1,), before_op=before_op) - # conv with DFT kernel across the input signal - # The DFT matrix is obtained with the equation e^(2pi/N i) but the definition is: - # DFT(x) => X[k] = Σx[n]*e^(-2kpi/N i) - # If x is complex then x[n]=(a+i*b) - # So the real part = Σ(a*cos(2kpi/N)+b*sin(2kpi/N)) - # So the imag part = Σ(b*cos(2kpi/N)-a*sin(2kpi/N)) + # Convolve the DFT kernel with the input signal + # DFT(x[n]) --> X[k] = Σx[n]*e^(-2π*n/N*k), then if x is complex x[n]=(a[n]+i*b[n]) + # real(X[k]) = Σ(a[n]*cos(2π*n/N*k)+b[n]*sin(2π*n/N*k)) + # imag(X[k]) = Σ(b[n]*cos(2π*n/N*k)-a[n]*sin(2π*n/N*k)) + # But because our DFT matrix is obtained with the conjugate --> e^(2π*n/N*k): + # real(X[k]) = Σ(a[n]*cos(2π*n/N*k)-b[n]*sin(2π*n/N*k)) + # imag(X[k]) = Σ(b[n]*cos(2π*n/N*k)+a[n]*sin(2π*n/N*k)) cos_windows_real = mb.conv(x=signal_real, weight=cos_base, strides=hop_size, pad_type='valid', before_op=before_op) sin_windows_real = mb.conv(x=signal_real, weight=sin_base, strides=hop_size, pad_type='valid', before_op=before_op) if input_imaginary: @@ -372,11 +372,11 @@ def _stft( # add everything together if input_imaginary: - real_result = mb.add(x=cos_windows_real, y=sin_windows_imag, before_op=before_op) - imag_result = mb.sub(x=cos_windows_imag, y=sin_windows_real, before_op=before_op) + real_result = mb.sub(x=cos_windows_real, y=sin_windows_imag, before_op=before_op) + imag_result = mb.add(x=cos_windows_imag, y=sin_windows_real, before_op=before_op) else: real_result = cos_windows_real - imag_result = mb.sub(x=0., y=sin_windows_real, before_op=before_op) + imag_result = sin_windows_real # reduce the rank of the output if should_increase_rank: @@ -417,17 +417,18 @@ def _istft( # By default, use the entire frame win_length = win_length or n_fft - input_shape = mb.shape(x=x, before_op=before_op) - n_frames = input_shape.val[-1] - fft_size = input_shape.val[-2] - # expected_output_signal_len = n_fft.val + hop_length.val * (n_frames - 1) + input_shape = mb.shape(x=input_real, before_op=before_op) + channels = input_shape.val[0] + fft_size = input_shape.val[1] + n_frames = input_shape.val[2] + expected_output_signal_len = n_fft.val + hop_length.val * (n_frames - 1) is_onesided = onesided.val if onesided else fft_size != n_fft cos_base, sin_base = _calculate_dft_matrix(n_fft, onesided=is_onesided, before_op=before_op) # create a window of centered 1s of the requested size if win_length: - window = _get_window(win_length=win_length, n_fft=n_fft, before_op=before_op) + window = _get_window(win_length=win_length, n_fft=n_fft, window=window, before_op=before_op) # apply time window if window: @@ -447,14 +448,13 @@ def _istft( signal_real = mb.mul(x=signal_real, y=multiplier, before_op=before_op) signal_imaginary = mb.mul(x=signal_imaginary, y=multiplier, before_op=before_op) - # Conv with DFT kernel across the input signal - # We can describe the IDFT in terms of DFT just by swapping the input and output + # Convolve the DFT kernel with the input signal + # We can describe the IDFT in terms of DFT just by swapping the input and output. # ref: https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Expressing_the_inverse_DFT_in_terms_of_the_DFT - # So IDFT(x) = (1/N) * swap(DFT(swap(x))) - # and DFT(x) = X[k] = Σx[n]*e^(-2kpi/N i) but we are using the conjugate e^(2kpi/N i) - # If x is complex then x[n]=(a+i*b) - # then real part = (1/N)*Σ(a*cos(2kpi/N)+b*sin(2kpi/N)) - # then imag part = (1/N)*Σ(b*cos(2kpi/N)-a*sin(2kpi/N)) + # IDFT(X[K]) --> x[n]=(1/N)*swap(DFT(swap(X[k]))), and K=N + # So using the definition in stft function, we get: + # real(x[n]) = Σ(a[k]*cos(2π*k/K*n)+b[k]*sin(2π*k/K*n)) + # imag(x[n]) = Σ(b[k]*cos(2π*k/K*n)-a[k]*sin(2π*k/K*n)) cos_windows_real = mb.conv(x=signal_real, weight=cos_base, strides=hop_size, pad_type='valid', before_op=before_op) sin_windows_real = mb.conv(x=signal_real, weight=sin_base, strides=hop_size, pad_type='valid', before_op=before_op) cos_windows_imag = mb.conv(x=signal_imaginary, weight=cos_base, strides=hop_size, pad_type='valid', before_op=before_op) @@ -519,6 +519,7 @@ def _overlap_add( def _get_window( win_length: Var, n_fft: Var, + window: Optional[Var], before_op: Operation, ) -> Var: n_left = (n_fft.val - win_length.val) // 2 @@ -750,17 +751,21 @@ def _lower_complex_istft(op: Operation): is_complex = types.is_complex(op.input.dtype) # check parameters for validity + if is_complex: + raise ValueError("Only complex inputs are allowed") if op.win_length and op.win_length.val > op.n_fft.val: raise ValueError("Window length must be less than or equal to n_fft") - if is_complex and op.onesided and op.onesided.val: - raise ValueError("Onesided is only valid for real inputs") + if op.return_complex and op.onesided and op.onesided.val: + raise ValueError("Complex output is not compatible with onesided") real, imag = _istft( - op.input.real if is_complex else op.input, - op.input.imag if is_complex else None, - op.n_fft, op.hop_length, op.win_length, op.window, op.normalized, op.onesided, before_op=op) + op.input.real, op.input.imag, + op.n_fft, op.hop_length, op.win_length, op.window, op.normalized, op.onesided, op.length, before_op=op) - return _wrap_complex_output(op.outputs[0], real, imag) + if op.return_complex: + return _wrap_complex_output(op.outputs[0], real, imag) + else + return real @LowerComplex.register_lower_func(op_type="complex_shape") From 2605f55936042bbc550870d1e93f1a13a5196fc5 Mon Sep 17 00:00:00 2001 From: Alejandro Gaston Alvarez Franceschi Date: Wed, 15 Nov 2023 17:05:03 +0100 Subject: [PATCH 12/24] Minor --- .../converters/mil/mil/passes/defs/lower_complex_dialect_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py index b66f3eabc..80eda5169 100644 --- a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py +++ b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py @@ -764,7 +764,7 @@ def _lower_complex_istft(op: Operation): if op.return_complex: return _wrap_complex_output(op.outputs[0], real, imag) - else + else: return real From fe0ad056e5704b1655b580e7fd7dea26b98c071f Mon Sep 17 00:00:00 2001 From: Alejandro Gaston Alvarez Franceschi Date: Tue, 5 Dec 2023 13:55:59 +0100 Subject: [PATCH 13/24] Remove ISTFT test class and complex parameter --- .../converters/mil/frontend/torch/test/test_torch_ops.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py index f3b4ef734..056e99c92 100644 --- a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py +++ b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py @@ -9587,10 +9587,9 @@ def forward(self, x): compute_unit=compute_unit ) -class TestISTFT(TorchBaseTest): @pytest.mark.slow @pytest.mark.parametrize( - "compute_unit, backend, input_shape, complex, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided, length", + "compute_unit, backend, input_shape, n_fft, hop_length, win_length, window, center, normalized, onesided, length, return_complex", itertools.product( compute_units, backends, From e0cfffd0186f807becd943a4b6b889d3f7f0c2bb Mon Sep 17 00:00:00 2001 From: Alejandro Gaston Alvarez Franceschi Date: Mon, 1 Jan 2024 19:40:43 +0100 Subject: [PATCH 14/24] Fixes --- .../converters/mil/frontend/torch/ops.py | 26 +++++++++- .../mil/frontend/torch/test/test_torch_ops.py | 7 +-- .../mil/mil/ops/defs/complex_dialect_ops.py | 29 +++++------ .../passes/defs/lower_complex_dialect_ops.py | 49 +++++++++++-------- 4 files changed, 70 insertions(+), 41 deletions(-) diff --git a/coremltools/converters/mil/frontend/torch/ops.py b/coremltools/converters/mil/frontend/torch/ops.py index 0755d4d9d..5781e9689 100644 --- a/coremltools/converters/mil/frontend/torch/ops.py +++ b/coremltools/converters/mil/frontend/torch/ops.py @@ -6362,6 +6362,7 @@ def stft(context, node): Lowers torch.stft with the dialect op `complex_stft` from complex_dialect_ops.py """ input_data, n_fft, hop_length, win_length, window, normalized, onesided, _ = _get_inputs(context, node, min_expected=2) + if types.is_complex(input_data.dtype): onesided = False # pytorch defaults onesided to False for complex inputs stft_res = mb.complex_stft( @@ -6371,9 +6372,32 @@ def stft(context, node): win_length=win_length, window=window, normalized=normalized, - onesided=onesided) + onesided=onesided + ) context.add(stft_res, node.name) +@register_torch_op +def istft(context, node): + """ + Lowers torch.istft with the dialect op `complex_istft` from complex_dialect_ops.py + """ + input_data, n_fft, hop_length, win_length, window, center, normalized, onesided, length, _ = _get_inputs(context, node, min_expected=2) + + if types.is_complex(input_data.dtype): + onesided = False # pytorch defaults onesided to False for complex inputs + istft_res = mb.complex_istft( + input=input_data, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + center=center, + normalized=normalized, + onesided=onesided, + length=length, + ) + context.add(istft_res, node.name) + @register_torch_op(torch_alias=["torchvision::nms"]) def torchvision_nms(context, node): inputs = _get_inputs(context, node, expected=3) diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py index 056e99c92..ca4259f73 100644 --- a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py +++ b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py @@ -9589,12 +9589,11 @@ def forward(self, x): @pytest.mark.slow @pytest.mark.parametrize( - "compute_unit, backend, input_shape, n_fft, hop_length, win_length, window, center, normalized, onesided, length, return_complex", + "compute_unit, backend, input_shape, hop_length, win_length, window, center, normalized, onesided, length, return_complex", itertools.product( compute_units, backends, [(1, 32, 9), (32, 9), (3, 32, 9)], # input shape - [16], # n_fft [None, 4, 5], # hop_length [None, 16, 9], # win_length [None, torch.hann_window], # window @@ -9605,10 +9604,12 @@ def forward(self, x): [False, True], # return_complex ) ) - def test_istft(self, compute_unit, backend, input_shape, n_fft, hop_length, win_length, window, center, normalized, onesided, length, return_complex): + def test_istft(self, compute_unit, backend, input_shape, hop_length, win_length, window, center, normalized, onesided, length, return_complex): if return_complex and onesided: pytest.skip("Complex output is incompatible with onesided") + n_fft = input_shape[1] + class ISTFTModel(torch.nn.Module): def forward(self, x): applied_window = window(win_length) if window and win_length else None diff --git a/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py b/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py index ea9c13ce4..5b0a5008a 100644 --- a/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py +++ b/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py @@ -893,7 +893,6 @@ class complex_istft(Operation): Attributes ---------- - V: complex64 T: fp32, complex64 References @@ -902,20 +901,18 @@ class complex_istft(Operation): """ input_spec = InputSpec( - input=TensorInputType(type_domain="V"), + input=TensorInputType(type_domain=types.complex), n_fft=TensorInputType(const=True, type_domain=types.int32), hop_length=TensorInputType(const=True, optional=True, type_domain=types.int32), win_length=TensorInputType(const=True, optional=True, type_domain=types.int32), window=TensorInputType(const=True, optional=True, type_domain=types.fp32), - normalized=TensorInputType(const=True, optional=True, type_domain=types.bool), + center=TensorInputType(const=True, type_domain=types.bool), + normalized=TensorInputType(const=True, optional=False, type_domain=types.bool), onesided=TensorInputType(const=True, optional=True, type_domain=types.bool), length=TensorInputType(const=True, optional=True, type_domain=types.int32), + return_complex=TensorInputType(const=True, optional=True, type_domain=types.bool), ) - type_domains = { - "V": types.complex64, - } - def default_inputs(self): return DefaultInputs( hop_length = None, @@ -923,23 +920,21 @@ def default_inputs(self): window = None, normalized = False, onesided = True, - length = None + length = None, + return_complex = True, ) def type_inference(self): - output_type = (types.fp32) - output_shape = [] + output_type = (types.complex64) if self.return_complex else (types.fp32) - # add back rank if needed - if self.input.rank == 2: - output_shape += [self.input.shape[0]] + # add batch size if given + output_shape = [self.input.shape[0] if self.input.rank == 3 else 1] if self.length: output_shape += [self.length] - return types.tensor(output_type, tuple(output_shape)) + else: + n_frames = self.input.shape[-1] + output_shape += [self.n_fft.val + self.hop_length.val * (n_frames - 1)] - n_frames = self.input.shape[-1] - output_shape = self.n_fft.val + self.hop_length.val * (n_frames - 1) return types.tensor(output_type, tuple(output_shape)) - diff --git a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py index 80eda5169..4b824a92e 100644 --- a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py +++ b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py @@ -397,6 +397,7 @@ def _istft( hop_length: Optional[Var], win_length: Optional[Var], window: Optional[Var], + center: Optional[Var], normalized: Optional[Var], onesided: Optional[Var], length: Optional[Var], @@ -435,12 +436,10 @@ def _istft( cos_base = mb.mul(x=window, y=cos_base, before_op=before_op) sin_base = mb.mul(x=window, y=sin_base, before_op=before_op) - cos_base = mb.expand_dims(x=cos_base, axes=(1,), before_op=before_op) - sin_base = mb.expand_dims(x=sin_base, axes=(1,), before_op=before_op) hop_size = mb.expand_dims(x=hop_length, axes=(0,), before_op=before_op) - signal_real = mb.expand_dims(x=input_real, axes=(1,), before_op=before_op) - signal_imaginary = mb.expand_dims(x=input_imaginary, axes=(1,), before_op=before_op) + signal_real = input_real + signal_imaginary = input_imaginary # De-normalized signal before applying the IFT if normalized and normalized.val: @@ -455,15 +454,16 @@ def _istft( # So using the definition in stft function, we get: # real(x[n]) = Σ(a[k]*cos(2π*k/K*n)+b[k]*sin(2π*k/K*n)) # imag(x[n]) = Σ(b[k]*cos(2π*k/K*n)-a[k]*sin(2π*k/K*n)) - cos_windows_real = mb.conv(x=signal_real, weight=cos_base, strides=hop_size, pad_type='valid', before_op=before_op) - sin_windows_real = mb.conv(x=signal_real, weight=sin_base, strides=hop_size, pad_type='valid', before_op=before_op) - cos_windows_imag = mb.conv(x=signal_imaginary, weight=cos_base, strides=hop_size, pad_type='valid', before_op=before_op) - sin_windows_imag = mb.conv(x=signal_imaginary, weight=sin_base, strides=hop_size, pad_type='valid', before_op=before_op) + cos_windows_real = mb.matmul(x=signal_real, y=cos_base, transpose_x=True, before_op=before_op) + sin_windows_real = mb.matmul(x=signal_real, y=sin_base, transpose_x=True, before_op=before_op) + cos_windows_imag = mb.matmul(x=signal_imaginary, y=cos_base, transpose_x=True, before_op=before_op) + sin_windows_imag = mb.matmul(x=signal_imaginary, y=sin_base, transpose_x=True, before_op=before_op) real_result = mb.add(x=cos_windows_real, y=sin_windows_imag, before_op=before_op) imag_result = mb.sub(x=cos_windows_imag, y=sin_windows_real, before_op=before_op) # Divide by N + n_fft = mb.cast(x=n_fft, dtype="fp32", before_op=before_op) real_result = mb.real_div(x=real_result, y=n_fft, before_op=before_op) imag_result = mb.real_div(x=imag_result, y=n_fft, before_op=before_op) @@ -472,9 +472,9 @@ def _istft( imag_result = _overlap_add(x=imag_result, n_fft=n_fft, hop_length=hop_length, before_op=before_op) # Normalize by the window square - n_frames = mb.shape(x=real_result, before_op=before_op)[1] window_square = mb.mul(x=window, y=window, before_op=before_op) - window_mtx = mb.stack(values=[window_square] * n_frames, axis=1) + window_mtx = mb.stack(values=[window_square] * n_frames, axis=0, before_op=before_op) + window_mtx = mb.expand_dims(x=window_mtx, axes=(0,), before_op=before_op) window_envelope = _overlap_add(x=window_mtx, n_fft=n_fft, hop_length=hop_length, before_op=before_op) real_result = mb.real_div(x=real_result, y=window_envelope, before_op=before_op) imag_result = mb.real_div(x=imag_result, y=window_envelope, before_op=before_op) @@ -502,17 +502,27 @@ def _overlap_add( """ input_shape = mb.shape(x=x, before_op=before_op) channels = input_shape.val[0] - n_frames = input_shape.val[2] + n_frames = input_shape.val[1] + + # Create empty output with final shape + output = mb.fill(shape=(channels, int(n_fft.val + hop_length.val * (n_frames - 1))), value=0., before_op=before_op) - output = mb.fill(shape=(channels, n_fft.val + hop_length.val * (n_frames - 1)), value=0., before_op=before_op) - signal_frames = mb.split(x=x, num_splits=n_frames, axis=2, before_op=before_op) + # Create an index used later on overlap add + n_fft = mb.cast(x=n_fft, dtype="int32", before_op=before_op) local_idx = mb.range_1d(start=0, end=n_fft, step=1, before_op=before_op) + # Split data into frames and iterate + signal_frames = mb.split(x=x, num_splits=n_frames, axis=1, before_op=before_op) + for frame_num, frame in enumerate(signal_frames): + frame = mb.squeeze(x=frame, axes=[1], before_op=before_op) + + # Create index to align data frames global_idx = mb.add(x=local_idx , y=frame_num*hop_length.val, before_op=before_op) - global_idx = mb.expand_dims(x=global_idx, axes=(0,), before_op=before_op) - global_idx = mb.stack(values=[global_idx] * channels, axis=0) - output = mb.scatter_nd(data=output, indices=global_idx, updates=frame, before_op=before_op) + global_idx = mb.stack(values=[global_idx] * channels, axis=0, before_op=before_op) + + # Add data frame + output = mb.scatter_along_axis(data=output, indices=global_idx, updates=frame, axis=1, mode="add", before_op=before_op) return output @@ -748,11 +758,10 @@ def _lower_complex_stft(op: Operation): @LowerComplex.register_lower_func(op_type="complex_istft") def _lower_complex_istft(op: Operation): - is_complex = types.is_complex(op.input.dtype) # check parameters for validity - if is_complex: - raise ValueError("Only complex inputs are allowed") + if not types.is_complex(op.input.dtype): + raise TypeError("Input type must be complex") if op.win_length and op.win_length.val > op.n_fft.val: raise ValueError("Window length must be less than or equal to n_fft") if op.return_complex and op.onesided and op.onesided.val: @@ -760,7 +769,7 @@ def _lower_complex_istft(op: Operation): real, imag = _istft( op.input.real, op.input.imag, - op.n_fft, op.hop_length, op.win_length, op.window, op.normalized, op.onesided, op.length, before_op=op) + op.n_fft, op.hop_length, op.win_length, op.window, op.center, op.normalized, op.onesided, op.length, before_op=op) if op.return_complex: return _wrap_complex_output(op.outputs[0], real, imag) From c2686c8199e6de65091cd079eb335a4febc27f9b Mon Sep 17 00:00:00 2001 From: Alejandro Gaston Alvarez Franceschi Date: Mon, 1 Jan 2024 19:50:37 +0100 Subject: [PATCH 15/24] More fixes --- .../converters/mil/frontend/torch/test/test_torch_ops.py | 2 +- .../converters/mil/mil/ops/defs/complex_dialect_ops.py | 4 +++- .../mil/mil/passes/defs/lower_complex_dialect_ops.py | 4 ++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py index ca4259f73..1157fe24e 100644 --- a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py +++ b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py @@ -9597,7 +9597,7 @@ def forward(self, x): [None, 4, 5], # hop_length [None, 16, 9], # win_length [None, torch.hann_window], # window - [None, False, True], # center + [False, True], # center [False, True], # normalized [None, False, True], # onesided [None, 60], # length diff --git a/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py b/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py index 5b0a5008a..a5f3bdf52 100644 --- a/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py +++ b/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py @@ -934,7 +934,9 @@ def type_inference(self): output_shape += [self.length] else: n_frames = self.input.shape[-1] - output_shape += [self.n_fft.val + self.hop_length.val * (n_frames - 1)] + + hop_length = self.hop_length.val if self.hop_length else self.n_fft.val // 4 + output_shape += [self.n_fft.val + hop_length * (n_frames - 1)] return types.tensor(output_type, tuple(output_shape)) diff --git a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py index 4b824a92e..d91901cb0 100644 --- a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py +++ b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py @@ -481,11 +481,11 @@ def _istft( # We need to adapt last dimension if length is not None: - if length > expected_output_signal_len: + if length.val > expected_output_signal_len: right_pad = mb.fill(shape=(channels, expected_output_signal_len - length), value=0., before_op=before_op) real_result = mb.stack(x=(real_result, right_pad), axis=1, before_op=before_op) imag_result = mb.stack(x=(imag_result, right_pad), axis=1, before_op=before_op) - elif length < expected_output_signal_len: + elif length.val < expected_output_signal_len: real_result = mb.slice_by_size(x=real_result, begin=[0], size=[length], before_op=before_op) imag_result = mb.slice_by_size(x=imag_result, begin=[0], size=[length], before_op=before_op) From 0d3238aa85637ceade487b51b3d44134eec4f069 Mon Sep 17 00:00:00 2001 From: Alejandro Gaston Alvarez Franceschi Date: Mon, 1 Jan 2024 20:20:42 +0100 Subject: [PATCH 16/24] More fixes --- .../mil/frontend/torch/test/test_torch_ops.py | 11 +++--- .../mil/mil/ops/defs/complex_dialect_ops.py | 1 - .../passes/defs/lower_complex_dialect_ops.py | 34 +++++++++++++------ 3 files changed, 30 insertions(+), 16 deletions(-) diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py index 1157fe24e..50e10fa15 100644 --- a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py +++ b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py @@ -9589,11 +9589,13 @@ def forward(self, x): @pytest.mark.slow @pytest.mark.parametrize( - "compute_unit, backend, input_shape, hop_length, win_length, window, center, normalized, onesided, length, return_complex", + "compute_unit, backend, channels, n_fft, num_frames, hop_length, win_length, window, center, normalized, onesided, length, return_complex", itertools.product( compute_units, backends, - [(1, 32, 9), (32, 9), (3, 32, 9)], # input shape + [None, 1, 3], # channels + [16, 32], # n_fft + [5, 9], # num_frames [None, 4, 5], # hop_length [None, 16, 9], # win_length [None, torch.hann_window], # window @@ -9604,11 +9606,12 @@ def forward(self, x): [False, True], # return_complex ) ) - def test_istft(self, compute_unit, backend, input_shape, hop_length, win_length, window, center, normalized, onesided, length, return_complex): + def test_istft(self, compute_unit, backend, channels, n_fft, num_frames, hop_length, win_length, window, center, normalized, onesided, length, return_complex): if return_complex and onesided: pytest.skip("Complex output is incompatible with onesided") - n_fft = input_shape[1] + freq = n_fft*2+1 if onesided else n_fft + input_shape = (channels, freq, num_frames) if channels else (freq, num_frames) class ISTFTModel(torch.nn.Module): def forward(self, x): diff --git a/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py b/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py index a5f3bdf52..547c71633 100644 --- a/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py +++ b/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py @@ -938,5 +938,4 @@ def type_inference(self): hop_length = self.hop_length.val if self.hop_length else self.n_fft.val // 4 output_shape += [self.n_fft.val + hop_length * (n_frames - 1)] - return types.tensor(output_type, tuple(output_shape)) diff --git a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py index d91901cb0..d9443cdee 100644 --- a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py +++ b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py @@ -419,9 +419,12 @@ def _istft( win_length = win_length or n_fft input_shape = mb.shape(x=input_real, before_op=before_op) - channels = input_shape.val[0] - fft_size = input_shape.val[1] - n_frames = input_shape.val[2] + if input_shape.rank == 3: + channels, fft_size, n_frames = input_shape.val + else: + channels = None + fft_size, n_frames = input_shape.val + expected_output_signal_len = n_fft.val + hop_length.val * (n_frames - 1) is_onesided = onesided.val if onesided else fft_size != n_fft @@ -482,12 +485,16 @@ def _istft( # We need to adapt last dimension if length is not None: if length.val > expected_output_signal_len: - right_pad = mb.fill(shape=(channels, expected_output_signal_len - length), value=0., before_op=before_op) + if channels: + right_pad = mb.fill(shape=(channels, expected_output_signal_len - length), value=0., before_op=before_op) + else: + right_pad = mb.fill(shape=(expected_output_signal_len - length,), value=0., before_op=before_op) + real_result = mb.stack(x=(real_result, right_pad), axis=1, before_op=before_op) imag_result = mb.stack(x=(imag_result, right_pad), axis=1, before_op=before_op) elif length.val < expected_output_signal_len: - real_result = mb.slice_by_size(x=real_result, begin=[0], size=[length], before_op=before_op) - imag_result = mb.slice_by_size(x=imag_result, begin=[0], size=[length], before_op=before_op) + real_result = mb.slice_by_size(x=real_result, begin=[0], size=[length.val], before_op=before_op) + imag_result = mb.slice_by_size(x=imag_result, begin=[0], size=[length.val], before_op=before_op) return real_result, imag_result @@ -498,14 +505,18 @@ def _overlap_add( before_op: Operation, ) -> Var: """ - The input has shape (channels, fft_size, n_frames) + The input has shape (channels, n_frames, fft_size) """ input_shape = mb.shape(x=x, before_op=before_op) - channels = input_shape.val[0] - n_frames = input_shape.val[1] # Create empty output with final shape - output = mb.fill(shape=(channels, int(n_fft.val + hop_length.val * (n_frames - 1))), value=0., before_op=before_op) + if input_shape.rank == 3: + channels, n_frames = input_shape.val + output = mb.fill(shape=(channels, int(n_fft.val + hop_length.val * (n_frames - 1))), value=0., before_op=before_op) + else: + channels = None + n_frames= input_shape.val + output = mb.fill(shape=(int(n_fft.val + hop_length.val * (n_frames - 1)),), value=0., before_op=before_op) # Create an index used later on overlap add n_fft = mb.cast(x=n_fft, dtype="int32", before_op=before_op) @@ -519,7 +530,8 @@ def _overlap_add( # Create index to align data frames global_idx = mb.add(x=local_idx , y=frame_num*hop_length.val, before_op=before_op) - global_idx = mb.stack(values=[global_idx] * channels, axis=0, before_op=before_op) + if channels: + global_idx = mb.stack(values=[global_idx] * channels, axis=0, before_op=before_op) # Add data frame output = mb.scatter_along_axis(data=output, indices=global_idx, updates=frame, axis=1, mode="add", before_op=before_op) From 4edfd202831f6148f7814e3ba95fdf704b1b982a Mon Sep 17 00:00:00 2001 From: Alejandro Gaston Alvarez Franceschi Date: Mon, 1 Jan 2024 20:44:25 +0100 Subject: [PATCH 17/24] More fixes --- .../mil/mil/passes/defs/lower_complex_dialect_ops.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py index d9443cdee..c2e7da9d4 100644 --- a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py +++ b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py @@ -511,11 +511,11 @@ def _overlap_add( # Create empty output with final shape if input_shape.rank == 3: - channels, n_frames = input_shape.val - output = mb.fill(shape=(channels, int(n_fft.val + hop_length.val * (n_frames - 1))), value=0., before_op=before_op) + channels, n_frames, _= input_shape.val + output = mb.fill(shape=(channels, int(n_fft.val + hop_length.val * (n_frames - 1)),), value=0., before_op=before_op) else: channels = None - n_frames= input_shape.val + n_frames, _ = input_shape.val output = mb.fill(shape=(int(n_fft.val + hop_length.val * (n_frames - 1)),), value=0., before_op=before_op) # Create an index used later on overlap add From 729d5d4d31ff53b92ba0b567c0c0340d42f63b30 Mon Sep 17 00:00:00 2001 From: Alejandro Gaston Alvarez Franceschi Date: Mon, 1 Jan 2024 21:39:22 +0100 Subject: [PATCH 18/24] More fixes --- .../converters/mil/mil/ops/defs/complex_dialect_ops.py | 1 - .../mil/mil/passes/defs/lower_complex_dialect_ops.py | 10 +++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py b/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py index 547c71633..75c935bc7 100644 --- a/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py +++ b/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py @@ -934,7 +934,6 @@ def type_inference(self): output_shape += [self.length] else: n_frames = self.input.shape[-1] - hop_length = self.hop_length.val if self.hop_length else self.n_fft.val // 4 output_shape += [self.n_fft.val + hop_length * (n_frames - 1)] diff --git a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py index c2e7da9d4..f86937fdb 100644 --- a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py +++ b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py @@ -419,7 +419,7 @@ def _istft( win_length = win_length or n_fft input_shape = mb.shape(x=input_real, before_op=before_op) - if input_shape.rank == 3: + if input_real.rank == 3: channels, fft_size, n_frames = input_shape.val else: channels = None @@ -510,7 +510,7 @@ def _overlap_add( input_shape = mb.shape(x=x, before_op=before_op) # Create empty output with final shape - if input_shape.rank == 3: + if x.rank == 3: channels, n_frames, _= input_shape.val output = mb.fill(shape=(channels, int(n_fft.val + hop_length.val * (n_frames - 1)),), value=0., before_op=before_op) else: @@ -523,10 +523,10 @@ def _overlap_add( local_idx = mb.range_1d(start=0, end=n_fft, step=1, before_op=before_op) # Split data into frames and iterate - signal_frames = mb.split(x=x, num_splits=n_frames, axis=1, before_op=before_op) + signal_frames = mb.split(x=x, num_splits=n_frames, axis=1 if channels else 0, before_op=before_op) for frame_num, frame in enumerate(signal_frames): - frame = mb.squeeze(x=frame, axes=[1], before_op=before_op) + frame = mb.squeeze(x=frame, axes=[1] if channels else [0], before_op=before_op) # Create index to align data frames global_idx = mb.add(x=local_idx , y=frame_num*hop_length.val, before_op=before_op) @@ -534,7 +534,7 @@ def _overlap_add( global_idx = mb.stack(values=[global_idx] * channels, axis=0, before_op=before_op) # Add data frame - output = mb.scatter_along_axis(data=output, indices=global_idx, updates=frame, axis=1, mode="add", before_op=before_op) + output = mb.scatter_along_axis(data=output, indices=global_idx, updates=frame, axis=1 if channels else 0, mode="add", before_op=before_op) return output From 866d61e16ddaa2b1ba4e03376f6de2b08e2797fe Mon Sep 17 00:00:00 2001 From: Alejandro Gaston Alvarez Franceschi Date: Tue, 2 Jan 2024 01:43:33 +0100 Subject: [PATCH 19/24] Fixes --- .../mil/frontend/torch/test/test_torch_ops.py | 36 +++++++++++++++---- .../passes/defs/lower_complex_dialect_ops.py | 8 ++--- 2 files changed, 33 insertions(+), 11 deletions(-) diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py index 50e10fa15..ff06ab289 100644 --- a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py +++ b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py @@ -9610,7 +9610,7 @@ def test_istft(self, compute_unit, backend, channels, n_fft, num_frames, hop_len if return_complex and onesided: pytest.skip("Complex output is incompatible with onesided") - freq = n_fft*2+1 if onesided else n_fft + freq = n_fft//2+1 if onesided else n_fft input_shape = (channels, freq, num_frames) if channels else (freq, num_frames) class ISTFTModel(torch.nn.Module): @@ -9631,12 +9631,34 @@ def forward(self, x): x = torch.stack([torch.real(x), torch.imag(x)], dim=0) return x - TorchBaseTest.run_compare_torch( - input_shape, - ISTFTModel(), - backend=backend, - compute_unit=compute_unit - ) + if length is not None or center is False: + # For some reason Pytorch raises an error https://github.com/pytorch/audio/issues/427#issuecomment-1829593033 + with pytest.raises( + RuntimeError, match="istft\(.*\) window overlap add min: 1" + ): + TorchBaseTest.run_compare_torch( + input_shape, + ISTFTModel(), + backend=backend, + compute_unit=compute_unit + ) + elif return_complex is False: + with pytest.raises( + ValueError, match="MIL doesn't support complex data as model's output" + ): + TorchBaseTest.run_compare_torch( + input_shape, + ISTFTModel(), + backend=backend, + compute_unit=compute_unit + ) + else: + TorchBaseTest.run_compare_torch( + input_shape, + ISTFTModel(), + backend=backend, + compute_unit=compute_unit + ) if _HAS_TORCH_AUDIO: diff --git a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py index f86937fdb..135b596c7 100644 --- a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py +++ b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py @@ -486,12 +486,12 @@ def _istft( if length is not None: if length.val > expected_output_signal_len: if channels: - right_pad = mb.fill(shape=(channels, expected_output_signal_len - length), value=0., before_op=before_op) + right_pad = mb.fill(shape=(channels, length.val - expected_output_signal_len ), value=0., before_op=before_op) else: - right_pad = mb.fill(shape=(expected_output_signal_len - length,), value=0., before_op=before_op) + right_pad = mb.fill(shape=(length.val - expected_output_signal_len,), value=0., before_op=before_op) - real_result = mb.stack(x=(real_result, right_pad), axis=1, before_op=before_op) - imag_result = mb.stack(x=(imag_result, right_pad), axis=1, before_op=before_op) + real_result = mb.stack(values=(real_result, right_pad), axis=1, before_op=before_op) + imag_result = mb.stack(values=(imag_result, right_pad), axis=1, before_op=before_op) elif length.val < expected_output_signal_len: real_result = mb.slice_by_size(x=real_result, begin=[0], size=[length.val], before_op=before_op) imag_result = mb.slice_by_size(x=imag_result, begin=[0], size=[length.val], before_op=before_op) From 186110c844ff929dd5c4a3bc5ac8696005e6bb15 Mon Sep 17 00:00:00 2001 From: Alejandro Gaston Alvarez Franceschi Date: Tue, 2 Jan 2024 12:52:48 +0100 Subject: [PATCH 20/24] Fixes --- .../mil/frontend/torch/test/test_torch_ops.py | 26 +++++++------------ .../passes/defs/lower_complex_dialect_ops.py | 18 ++++++------- 2 files changed, 17 insertions(+), 27 deletions(-) diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py index ff06ab289..352f29a47 100644 --- a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py +++ b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py @@ -9597,7 +9597,7 @@ def forward(self, x): [16, 32], # n_fft [5, 9], # num_frames [None, 4, 5], # hop_length - [None, 16, 9], # win_length + [None, 10, 8], # win_length [None, torch.hann_window], # window [False, True], # center [False, True], # normalized @@ -9610,6 +9610,9 @@ def test_istft(self, compute_unit, backend, channels, n_fft, num_frames, hop_len if return_complex and onesided: pytest.skip("Complex output is incompatible with onesided") + if hop_length is None and win_length is not None: + pytest.skip("If win_length is set then we must set hop_length and 0 < hop_length <= win_length") + freq = n_fft//2+1 if onesided else n_fft input_shape = (channels, freq, num_frames) if channels else (freq, num_frames) @@ -9628,24 +9631,13 @@ def forward(self, x): length=length, return_complex=return_complex) if return_complex: - x = torch.stack([torch.real(x), torch.imag(x)], dim=0) - return x + return torch.stack([torch.real(x), torch.imag(x)], dim=0) + else: + return torch.real(x) - if length is not None or center is False: + if win_length and center is False: # For some reason Pytorch raises an error https://github.com/pytorch/audio/issues/427#issuecomment-1829593033 - with pytest.raises( - RuntimeError, match="istft\(.*\) window overlap add min: 1" - ): - TorchBaseTest.run_compare_torch( - input_shape, - ISTFTModel(), - backend=backend, - compute_unit=compute_unit - ) - elif return_complex is False: - with pytest.raises( - ValueError, match="MIL doesn't support complex data as model's output" - ): + with pytest.raises(RuntimeError, match="istft\(.*\) window overlap add min: 1"): TorchBaseTest.run_compare_torch( input_shape, ISTFTModel(), diff --git a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py index 135b596c7..8438d56b5 100644 --- a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py +++ b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py @@ -427,7 +427,7 @@ def _istft( expected_output_signal_len = n_fft.val + hop_length.val * (n_frames - 1) - is_onesided = onesided.val if onesided else fft_size != n_fft + is_onesided = True if fft_size != n_fft.val else onesided and onesided.val cos_base, sin_base = _calculate_dft_matrix(n_fft, onesided=is_onesided, before_op=before_op) # create a window of centered 1s of the requested size @@ -481,20 +481,18 @@ def _istft( window_envelope = _overlap_add(x=window_mtx, n_fft=n_fft, hop_length=hop_length, before_op=before_op) real_result = mb.real_div(x=real_result, y=window_envelope, before_op=before_op) imag_result = mb.real_div(x=imag_result, y=window_envelope, before_op=before_op) - # We need to adapt last dimension if length is not None: if length.val > expected_output_signal_len: + real_result = mb.pad(x=real_result, pad=(0, length.val - expected_output_signal_len), before_op=before_op) + imag_result = mb.pad(x=imag_result, pad=(0, length.val - expected_output_signal_len), before_op=before_op) + elif length.val < expected_output_signal_len: if channels: - right_pad = mb.fill(shape=(channels, length.val - expected_output_signal_len ), value=0., before_op=before_op) + real_result = mb.slice_by_size(x=real_result, begin=[0,0], size=[-1, length.val], before_op=before_op) + imag_result = mb.slice_by_size(x=imag_result, begin=[0,0], size=[-1, length.val], before_op=before_op) else: - right_pad = mb.fill(shape=(length.val - expected_output_signal_len,), value=0., before_op=before_op) - - real_result = mb.stack(values=(real_result, right_pad), axis=1, before_op=before_op) - imag_result = mb.stack(values=(imag_result, right_pad), axis=1, before_op=before_op) - elif length.val < expected_output_signal_len: - real_result = mb.slice_by_size(x=real_result, begin=[0], size=[length.val], before_op=before_op) - imag_result = mb.slice_by_size(x=imag_result, begin=[0], size=[length.val], before_op=before_op) + real_result = mb.slice_by_size(x=real_result, begin=[0], size=[length.val], before_op=before_op) + imag_result = mb.slice_by_size(x=imag_result, begin=[0], size=[length.val], before_op=before_op) return real_result, imag_result From 8cb63ed986339fa1de3e0c00a9ebbdf2e1b63618 Mon Sep 17 00:00:00 2001 From: Alejandro Gaston Alvarez Franceschi Date: Tue, 2 Jan 2024 13:52:59 +0100 Subject: [PATCH 21/24] More fixes --- .../mil/frontend/torch/test/test_torch_ops.py | 10 +++++++++- .../mil/mil/passes/defs/lower_complex_dialect_ops.py | 10 ++++------ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py index 352f29a47..66b55c277 100644 --- a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py +++ b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py @@ -9602,7 +9602,7 @@ def forward(self, x): [False, True], # center [False, True], # normalized [None, False, True], # onesided - [None, 60], # length + [None, 30, 40], # length [False, True], # return_complex ) ) @@ -9644,6 +9644,14 @@ def forward(self, x): backend=backend, compute_unit=compute_unit ) + elif length is not None and return_complex is True: + with pytest.raises(ValueError, match="New var type `.tensor'>` not a subtype of existing var type `.tensor'>`"): + TorchBaseTest.run_compare_torch( + input_shape, + ISTFTModel(), + backend=backend, + compute_unit=compute_unit + ) else: TorchBaseTest.run_compare_torch( input_shape, diff --git a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py index 8438d56b5..5caf3b673 100644 --- a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py +++ b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py @@ -479,6 +479,8 @@ def _istft( window_mtx = mb.stack(values=[window_square] * n_frames, axis=0, before_op=before_op) window_mtx = mb.expand_dims(x=window_mtx, axes=(0,), before_op=before_op) window_envelope = _overlap_add(x=window_mtx, n_fft=n_fft, hop_length=hop_length, before_op=before_op) + + # After this operation if it didn't have any channels dimention it adds one real_result = mb.real_div(x=real_result, y=window_envelope, before_op=before_op) imag_result = mb.real_div(x=imag_result, y=window_envelope, before_op=before_op) # We need to adapt last dimension @@ -487,12 +489,8 @@ def _istft( real_result = mb.pad(x=real_result, pad=(0, length.val - expected_output_signal_len), before_op=before_op) imag_result = mb.pad(x=imag_result, pad=(0, length.val - expected_output_signal_len), before_op=before_op) elif length.val < expected_output_signal_len: - if channels: - real_result = mb.slice_by_size(x=real_result, begin=[0,0], size=[-1, length.val], before_op=before_op) - imag_result = mb.slice_by_size(x=imag_result, begin=[0,0], size=[-1, length.val], before_op=before_op) - else: - real_result = mb.slice_by_size(x=real_result, begin=[0], size=[length.val], before_op=before_op) - imag_result = mb.slice_by_size(x=imag_result, begin=[0], size=[length.val], before_op=before_op) + real_result = mb.slice_by_size(x=real_result, begin=[0,0], size=[-1, length.val], before_op=before_op) + imag_result = mb.slice_by_size(x=imag_result, begin=[0,0], size=[-1, length.val], before_op=before_op) return real_result, imag_result From e016e68afa1e4ba0ccab6c8b374bbeababe4db80 Mon Sep 17 00:00:00 2001 From: Alejandro Gaston Alvarez Franceschi Date: Tue, 2 Jan 2024 20:12:10 +0100 Subject: [PATCH 22/24] More fixes --- .../mil/frontend/torch/test/test_torch_ops.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py index 66b55c277..153a5d4c0 100644 --- a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py +++ b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py @@ -9596,13 +9596,13 @@ def forward(self, x): [None, 1, 3], # channels [16, 32], # n_fft [5, 9], # num_frames - [None, 4, 5], # hop_length + [None, 5], # hop_length [None, 10, 8], # win_length [None, torch.hann_window], # window [False, True], # center [False, True], # normalized [None, False, True], # onesided - [None, 30, 40], # length + [None, "shorter", "larger"], # length [False, True], # return_complex ) ) @@ -9613,9 +9613,19 @@ def test_istft(self, compute_unit, backend, channels, n_fft, num_frames, hop_len if hop_length is None and win_length is not None: pytest.skip("If win_length is set then we must set hop_length and 0 < hop_length <= win_length") + # Compute input_shape to generate test case freq = n_fft//2+1 if onesided else n_fft input_shape = (channels, freq, num_frames) if channels else (freq, num_frames) + # If not set,c ompute hop_length for capturing errors + if hop_length is None: + hop_length = n_fft // 4 + + if length == "shorter": + length = n_fft//2 + hop_length * (num_frames - 1) + elif length == "larger": + length = n_fft*3//2 + hop_length * (num_frames - 1) + class ISTFTModel(torch.nn.Module): def forward(self, x): applied_window = window(win_length) if window and win_length else None @@ -9635,7 +9645,7 @@ def forward(self, x): else: return torch.real(x) - if win_length and center is False: + if (center is False and win_length) or (center and win_length and length): # For some reason Pytorch raises an error https://github.com/pytorch/audio/issues/427#issuecomment-1829593033 with pytest.raises(RuntimeError, match="istft\(.*\) window overlap add min: 1"): TorchBaseTest.run_compare_torch( @@ -9644,7 +9654,7 @@ def forward(self, x): backend=backend, compute_unit=compute_unit ) - elif length is not None and return_complex is True: + elif length and return_complex: with pytest.raises(ValueError, match="New var type `.tensor'>` not a subtype of existing var type `.tensor'>`"): TorchBaseTest.run_compare_torch( input_shape, From c1cf5ec1e28ba9543d0aa9b2d07f3d53ff5ecd02 Mon Sep 17 00:00:00 2001 From: Alejandro Gaston Alvarez Franceschi Date: Tue, 9 Jan 2024 13:54:49 +0100 Subject: [PATCH 23/24] Seems DFT is the correct one and not the conjugate --- .../mil/mil/passes/defs/lower_complex_dialect_ops.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py index 5caf3b673..b3ab30bf6 100644 --- a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py +++ b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py @@ -361,9 +361,6 @@ def _stft( # DFT(x[n]) --> X[k] = Σx[n]*e^(-2π*n/N*k), then if x is complex x[n]=(a[n]+i*b[n]) # real(X[k]) = Σ(a[n]*cos(2π*n/N*k)+b[n]*sin(2π*n/N*k)) # imag(X[k]) = Σ(b[n]*cos(2π*n/N*k)-a[n]*sin(2π*n/N*k)) - # But because our DFT matrix is obtained with the conjugate --> e^(2π*n/N*k): - # real(X[k]) = Σ(a[n]*cos(2π*n/N*k)-b[n]*sin(2π*n/N*k)) - # imag(X[k]) = Σ(b[n]*cos(2π*n/N*k)+a[n]*sin(2π*n/N*k)) cos_windows_real = mb.conv(x=signal_real, weight=cos_base, strides=hop_size, pad_type='valid', before_op=before_op) sin_windows_real = mb.conv(x=signal_real, weight=sin_base, strides=hop_size, pad_type='valid', before_op=before_op) if input_imaginary: @@ -372,11 +369,11 @@ def _stft( # add everything together if input_imaginary: - real_result = mb.sub(x=cos_windows_real, y=sin_windows_imag, before_op=before_op) - imag_result = mb.add(x=cos_windows_imag, y=sin_windows_real, before_op=before_op) + real_result = mb.add(x=cos_windows_real, y=sin_windows_imag, before_op=before_op) + imag_result = mb.sub(x=cos_windows_imag, y=sin_windows_real, before_op=before_op) else: real_result = cos_windows_real - imag_result = sin_windows_real + imag_result = mb.sub(x=0., y=sin_windows_real, before_op=before_op) # reduce the rank of the output if should_increase_rank: From 6738aba7b978ee095958c23d754903afa0cdff39 Mon Sep 17 00:00:00 2001 From: Junpei Zhou Date: Wed, 10 Jan 2024 14:44:49 -0800 Subject: [PATCH 24/24] Removes pytest.mark.slow decorator to run STFT related tests in CI. --- .../converters/mil/frontend/torch/test/test_torch_ops.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py index 153a5d4c0..279846c2a 100644 --- a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py +++ b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py @@ -9540,8 +9540,8 @@ def forward(self, x): (2, 3, 4), FftnModel(), backend=backend, compute_unit=compute_unit ) + class TestSTFT(TorchBaseTest): - @pytest.mark.slow @pytest.mark.parametrize( "compute_unit, backend, input_shape, complex, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided", itertools.product( @@ -9587,7 +9587,6 @@ def forward(self, x): compute_unit=compute_unit ) - @pytest.mark.slow @pytest.mark.parametrize( "compute_unit, backend, channels, n_fft, num_frames, hop_length, win_length, window, center, normalized, onesided, length, return_complex", itertools.product(