diff --git a/src/libtorchaudio/lfilter.cpp b/src/libtorchaudio/lfilter.cpp index 9d9b05c7d8..acb5f9607a 100644 --- a/src/libtorchaudio/lfilter.cpp +++ b/src/libtorchaudio/lfilter.cpp @@ -73,27 +73,6 @@ void cpu_lfilter_core_loop( }); } -void lfilter_core_generic_loop( - const torch::Tensor& input_signal_windows, - const torch::Tensor& a_coeff_flipped, - torch::Tensor& padded_output_waveform) { - int64_t n_samples_input = input_signal_windows.size(2); - int64_t n_order = a_coeff_flipped.size(1); - auto coeff = a_coeff_flipped.unsqueeze(2); - for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) { - auto windowed_output_signal = - torch::narrow(padded_output_waveform, 2, i_sample, i_sample + n_order) - .transpose(0, 1); - auto o0 = torch::select(input_signal_windows, 2, i_sample) - - at::matmul(windowed_output_signal, coeff).squeeze(2).transpose(0, 1); - padded_output_waveform.index_put_( - {torch::indexing::Slice(), - torch::indexing::Slice(), - i_sample + n_order - 1}, - o0); - } -} - } // namespace TORCH_LIBRARY(torchaudio, m) { @@ -110,7 +89,3 @@ TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { m.impl("torchaudio::_lfilter_core_loop", &cuda_lfilter_core_loop); } #endif - -TORCH_LIBRARY_IMPL(torchaudio, CompositeExplicitAutograd, m) { - m.impl("torchaudio::_lfilter_core_loop", &lfilter_core_generic_loop); -} diff --git a/src/torchaudio/functional/filtering.py b/src/torchaudio/functional/filtering.py index 1a7aa3e37e..f302e602bd 100644 --- a/src/torchaudio/functional/filtering.py +++ b/src/torchaudio/functional/filtering.py @@ -932,8 +932,17 @@ def _lfilter_core_generic_loop(input_signal_windows: Tensor, a_coeffs_flipped: T padded_output_waveform[:, :, i_sample + n_order - 1] = o0 +def _lfilter_core_loop_dispatcher( + input_signal_windows: Tensor, a_coeffs_flipped: Tensor, padded_output_waveform: Tensor +): + if input_signal_windows.is_cuda or input_signal_windows.is_cpu: + return torch.ops.torchaudio._lfilter_core_loop(input_signal_windows, a_coeffs_flipped, padded_output_waveform) + else: + return _lfilter_core_generic_loop(input_signal_windows, a_coeffs_flipped, padded_output_waveform) + + if _IS_TORCHAUDIO_EXT_AVAILABLE: - _lfilter_core_loop = torch.ops.torchaudio._lfilter_core_loop + _lfilter_core_loop = _lfilter_core_loop_dispatcher else: _lfilter_core_loop = _lfilter_core_generic_loop