From 77d44213eaed6440f13f187852fbd28172982d51 Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Fri, 11 Jul 2025 16:01:53 +0000 Subject: [PATCH 01/31] Port autograd code for rnnt --- src/libtorchaudio/rnnt/autograd.cpp | 51 ++----------------------- src/torchaudio/functional/functional.py | 17 ++++++++- 2 files changed, 19 insertions(+), 49 deletions(-) diff --git a/src/libtorchaudio/rnnt/autograd.cpp b/src/libtorchaudio/rnnt/autograd.cpp index dcf68409ed..5ba545cb99 100644 --- a/src/libtorchaudio/rnnt/autograd.cpp +++ b/src/libtorchaudio/rnnt/autograd.cpp @@ -3,31 +3,7 @@ namespace torchaudio { namespace rnnt { -class RNNTLossFunction : public torch::autograd::Function { - public: - static torch::autograd::tensor_list forward( - torch::autograd::AutogradContext* ctx, - torch::Tensor& logits, - const torch::Tensor& targets, - const torch::Tensor& logit_lengths, - const torch::Tensor& target_lengths, - int64_t blank, - double clamp, - bool fused_log_softmax = true) { - torch::Tensor undef; - auto result = rnnt_loss( - logits, - targets, - logit_lengths, - target_lengths, - blank, - clamp, - fused_log_softmax); - auto costs = std::get<0>(result); - auto grads = std::get<1>(result).value_or(undef); - ctx->save_for_backward({grads}); - return {costs, grads}; - } + static torch::autograd::tensor_list backward( torch::autograd::AutogradContext* ctx, @@ -39,31 +15,10 @@ class RNNTLossFunction : public torch::autograd::Function { torch::Tensor undef; return {result, undef, undef, undef, undef, undef, undef, undef}; } -}; - -std::tuple> rnnt_loss_autograd( - torch::Tensor& logits, - const torch::Tensor& targets, - const torch::Tensor& logit_lengths, - const torch::Tensor& target_lengths, - int64_t blank, - double clamp, - bool fused_log_softmax = true) { - at::AutoDispatchBelowADInplaceOrView guard; - auto results = RNNTLossFunction::apply( - logits, - targets, - logit_lengths, - target_lengths, - blank, - clamp, - fused_log_softmax); - return std::make_tuple(results[0], 
results[1]); } -TORCH_LIBRARY_IMPL(torchaudio, Autograd, m) { - m.impl("rnnt_loss", rnnt_loss_autograd); +TORCH_LIBRARY_FRAGMENT(torchaudio, m) { + m.def("torchaudio::rnnt_loss_forward", &rnnt_loss); } -} // namespace rnnt } // namespace torchaudio diff --git a/src/torchaudio/functional/functional.py b/src/torchaudio/functional/functional.py index 42dde06814..e25194dbd5 100644 --- a/src/torchaudio/functional/functional.py +++ b/src/torchaudio/functional/functional.py @@ -1760,6 +1760,21 @@ def _fix_waveform_shape( waveform_shift = waveform_shift.view(shape[:-1] + waveform_shift.shape[-1:]) return waveform_shift +class RnntLoss(torch.autograd.Function): + @staticmethod + def forward(ctx, *args): + output, saved = torch.ops.torchaudio.rnnt_loss_forward(*args) + ctx.save_for_backward(saved) + return output + + @staticmethod + def backward(ctx, dy): + grad = ctx.saved_tensors[0] + grad_out = dy.view((-1, 1, 1, 1)) + result = grad * grad_out; + return (result, None, None, None, None, None, None, None) + +torch.ops.torchaudio.rnnt_loss_forward def _rnnt_loss( logits: Tensor, @@ -1803,7 +1818,7 @@ def _rnnt_loss( if blank < 0: # reinterpret blank index if blank < 0. 
blank = logits.shape[-1] + blank - costs, _ = torch.ops.torchaudio.rnnt_loss( + costs = RnntLoss.apply( logits=logits, targets=targets, logit_lengths=logit_lengths, From 725c74e9c579eb5d14b9c2f58375d7a3acf299c7 Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Fri, 11 Jul 2025 16:04:34 +0000 Subject: [PATCH 02/31] Correct rnnt calling arguments --- src/torchaudio/functional/functional.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/torchaudio/functional/functional.py b/src/torchaudio/functional/functional.py index e25194dbd5..8abd075546 100644 --- a/src/torchaudio/functional/functional.py +++ b/src/torchaudio/functional/functional.py @@ -1819,13 +1819,13 @@ def _rnnt_loss( blank = logits.shape[-1] + blank costs = RnntLoss.apply( - logits=logits, - targets=targets, - logit_lengths=logit_lengths, - target_lengths=target_lengths, - blank=blank, - clamp=clamp, - fused_log_softmax=fused_log_softmax, + logits, + targets, + logit_lengths, + target_lengths, + blank, + clamp, + fused_log_softmax ) if reduction == "mean": From 97176519c5935cde5558a6af32d268fa28637ea1 Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Fri, 11 Jul 2025 16:50:05 +0000 Subject: [PATCH 03/31] Disable torchscript checks --- .github/scripts/unittest-linux/run_test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/scripts/unittest-linux/run_test.sh b/.github/scripts/unittest-linux/run_test.sh index f311c8370e..dacde20bea 100755 --- a/.github/scripts/unittest-linux/run_test.sh +++ b/.github/scripts/unittest-linux/run_test.sh @@ -30,5 +30,5 @@ fi ( cd test - pytest torchaudio_unittest -k "not backend and not /io/ and not prototype and not sox and not ffmpeg and not fairseq and not hdemucs" + pytest torchaudio_unittest -k "not backend and not /io/ and not prototype and not sox and not ffmpeg and not fairseq and not hdemucs and not torchscript" ) From 2b882503158935fd99fb62dea63e387a2d8d3534 Mon Sep 17 00:00:00 2001 From: Sam 
Anklesaria Date: Fri, 11 Jul 2025 17:18:57 +0000 Subject: [PATCH 04/31] Restrict disabling of torchscript tests --- .github/scripts/unittest-linux/run_test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/scripts/unittest-linux/run_test.sh b/.github/scripts/unittest-linux/run_test.sh index dacde20bea..559b55437a 100755 --- a/.github/scripts/unittest-linux/run_test.sh +++ b/.github/scripts/unittest-linux/run_test.sh @@ -30,5 +30,5 @@ fi ( cd test - pytest torchaudio_unittest -k "not backend and not /io/ and not prototype and not sox and not ffmpeg and not fairseq and not hdemucs and not torchscript" + pytest torchaudio_unittest -k "not backend and not /io/ and not prototype and not sox and not ffmpeg and not fairseq and not hdemucs and not (torchscript and rnnt)" ) From 116de6f9a2778602b0f0462d3fdf67c6784d97ff Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Fri, 11 Jul 2025 17:30:34 +0000 Subject: [PATCH 05/31] Remove leftover line --- src/torchaudio/functional/functional.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/torchaudio/functional/functional.py b/src/torchaudio/functional/functional.py index 8abd075546..f955fe7840 100644 --- a/src/torchaudio/functional/functional.py +++ b/src/torchaudio/functional/functional.py @@ -1774,8 +1774,6 @@ def backward(ctx, dy): result = grad * grad_out; return (result, None, None, None, None, None, None, None) -torch.ops.torchaudio.rnnt_loss_forward - def _rnnt_loss( logits: Tensor, targets: Tensor, From 003b3a9d810c1cf926627b201efa9e17a3fb1838 Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Mon, 14 Jul 2025 14:36:49 +0000 Subject: [PATCH 06/31] Remove unnecessary backward code --- src/libtorchaudio/rnnt/autograd.cpp | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/src/libtorchaudio/rnnt/autograd.cpp b/src/libtorchaudio/rnnt/autograd.cpp index 5ba545cb99..05b767194d 100644 --- a/src/libtorchaudio/rnnt/autograd.cpp +++ b/src/libtorchaudio/rnnt/autograd.cpp @@ -1,22 
+1,8 @@ #include namespace torchaudio { -namespace rnnt { - - static torch::autograd::tensor_list backward( - torch::autograd::AutogradContext* ctx, - torch::autograd::tensor_list grad_outputs) { - auto saved = ctx->get_saved_variables(); - auto grad = saved[0]; - auto grad_out = grad_outputs[0].view({-1, 1, 1, 1}); - auto result = grad * grad_out; - torch::Tensor undef; - return {result, undef, undef, undef, undef, undef, undef, undef}; - } -} - TORCH_LIBRARY_FRAGMENT(torchaudio, m) { m.def("torchaudio::rnnt_loss_forward", &rnnt_loss); } From 7727ad773ed9f01e16e34164d2c4a742f77cdd6c Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Mon, 14 Jul 2025 14:50:59 +0000 Subject: [PATCH 07/31] Move rnnt_loss_forward to compute.cpp --- src/libtorchaudio/rnnt/autograd.cpp | 10 ---------- src/libtorchaudio/rnnt/compute.cpp | 1 + 2 files changed, 1 insertion(+), 10 deletions(-) delete mode 100644 src/libtorchaudio/rnnt/autograd.cpp diff --git a/src/libtorchaudio/rnnt/autograd.cpp b/src/libtorchaudio/rnnt/autograd.cpp deleted file mode 100644 index 05b767194d..0000000000 --- a/src/libtorchaudio/rnnt/autograd.cpp +++ /dev/null @@ -1,10 +0,0 @@ -#include - -namespace torchaudio { - - -TORCH_LIBRARY_FRAGMENT(torchaudio, m) { - m.def("torchaudio::rnnt_loss_forward", &rnnt_loss); -} - -} // namespace torchaudio diff --git a/src/libtorchaudio/rnnt/compute.cpp b/src/libtorchaudio/rnnt/compute.cpp index 567c9b5d4b..5aba334cee 100644 --- a/src/libtorchaudio/rnnt/compute.cpp +++ b/src/libtorchaudio/rnnt/compute.cpp @@ -30,4 +30,5 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) { "int blank," "float clamp," "bool fused_log_softmax) -> (Tensor, Tensor?)"); + m.def("torchaudio::rnnt_loss_forward", &rnnt_loss); } From 9b9dc2573f8736bed0c86c5a8ee271cbd11cfc1d Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Mon, 14 Jul 2025 16:09:07 +0000 Subject: [PATCH 08/31] Remove autograd rnnt in cmakelists --- src/libtorchaudio/CMakeLists.txt | 1 - 1 file changed, 1 deletion(-) diff --git 
a/src/libtorchaudio/CMakeLists.txt b/src/libtorchaudio/CMakeLists.txt index 713cb50533..85bc227cd6 100644 --- a/src/libtorchaudio/CMakeLists.txt +++ b/src/libtorchaudio/CMakeLists.txt @@ -28,7 +28,6 @@ if(BUILD_RNNT) rnnt/compute_alphas.cpp rnnt/compute_betas.cpp rnnt/compute.cpp - rnnt/autograd.cpp ) if (USE_CUDA) list( From d4dd7bdced278b1fff2549db042246331078f2c1 Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Tue, 15 Jul 2025 19:45:49 +0000 Subject: [PATCH 09/31] Convert cpu/compute_alphas to stable API --- src/libtorchaudio/rnnt/compute_alphas.cpp | 3 +- src/libtorchaudio/rnnt/cpu/compute_alphas.cpp | 133 +++++++++++++----- 2 files changed, 97 insertions(+), 39 deletions(-) diff --git a/src/libtorchaudio/rnnt/compute_alphas.cpp b/src/libtorchaudio/rnnt/compute_alphas.cpp index adbcc1c8e7..40c07a2115 100644 --- a/src/libtorchaudio/rnnt/compute_alphas.cpp +++ b/src/libtorchaudio/rnnt/compute_alphas.cpp @@ -1,6 +1,7 @@ #include +#include -TORCH_LIBRARY_FRAGMENT(torchaudio, m) { +STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) { m.def( "rnnt_loss_alphas(Tensor logits," "Tensor targets," diff --git a/src/libtorchaudio/rnnt/cpu/compute_alphas.cpp b/src/libtorchaudio/rnnt/cpu/compute_alphas.cpp index 6923cbe5d8..2174ab9b2a 100644 --- a/src/libtorchaudio/rnnt/cpu/compute_alphas.cpp +++ b/src/libtorchaudio/rnnt/cpu/compute_alphas.cpp @@ -1,68 +1,125 @@ #include #include +#include +#include +#include + +// TODO: +// Are the StableIValue AtenTensorHandles reference counted at all? +// Why do we call release() on returned arguments? 
namespace torchaudio { namespace rnnt { namespace cpu { -torch::Tensor compute_alphas( - const torch::Tensor& logits, - const torch::Tensor& targets, - const torch::Tensor& logit_lengths, - const torch::Tensor& target_lengths, +using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle; + +RAIIATH compute_alphas( + const RAIIATH logits, + const RAIIATH targets, + const RAIIATH logit_lengths, + const RAIIATH target_lengths, int64_t blank, double clamp) { Options options; - options.batchSize_ = logit_lengths.size(0); - options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0); - options.maxSrcLen_ = logits.size(1); - options.maxTgtLen_ = logits.size(2); - options.numTargets_ = logits.size(3); + int64_t tmp; + aoti_torch_get_size(logit_lengths.get(), 0, &tmp); + options.batchSize_ = (int)tmp; + aoti_torch_get_size(target_lengths.get(), 0, &tmp); + options.nHypos_ = (int)tmp; + options.nHypos_ /= options.batchSize_; + aoti_torch_get_size(logits.get(), 1, &tmp); + options.maxSrcLen_ = (int)tmp; + aoti_torch_get_size(logits.get(), 2, &tmp); + options.maxTgtLen_ = (int)tmp; + aoti_torch_get_size(logits.get(), 3, &tmp); + options.numTargets_ = (int)tmp; options.blank_ = blank; options.clamp_ = clamp; - TORCH_CHECK_EQ(logits.device().type(), torch::DeviceType::CPU); + // TORCH_CHECK_EQ(logits.device().type(), torch::DeviceType::CPU); options.device_ = CPU; - torch::Tensor alphas = torch::zeros( - {options.batchSize_ * options.nHypos_, - options.maxSrcLen_, - options.maxTgtLen_}, - torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); - torch::Tensor int_workspace = torch::empty( - IntWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Int)); + int32_t logits_device; + aoti_torch_get_device_type(logits.get(), &logits_device); + int32_t logits_device_index; + aoti_torch_get_device_index(logits.get(), &logits_device_index); + int32_t logits_dtype; + aoti_torch_get_dtype(logits.get(), 
&logits_dtype); + + int64_t param_sizes[3] = {options.batchSize_ * options.nHypos_, options.maxSrcLen_, options.maxTgtLen_}; + int64_t param_strides[3] = {options.maxSrcLen_ * options.maxTgtLen_, options.maxTgtLen_, 1}; + + AtenTensorHandle alphas; + aoti_torch_empty_strided(3, param_sizes, param_strides, logits_dtype, logits_device, logits_device_index, &alphas); + aoti_torch_zero_(alphas); - torch::Tensor float_workspace = torch::empty( - DtypeWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Float)); + AtenTensorHandle int_workspace; + int64_t sizes[1] = {IntWorkspace::ComputeSizeFromOptions(options)}; + int64_t strides[1] = {1}; + aoti_torch_empty_strided(1, sizes, strides, aoti_torch_dtype_int32(), logits_device, logits_device_index, &int_workspace); + + AtenTensorHandle float_workspace; + aoti_torch_empty_strided(1, sizes, strides, aoti_torch_dtype_float32(), logits_device, logits_device_index, &int_workspace); + + int64_t float_numel; + aoti_torch_get_numel(float_workspace, &float_numel); + void *int_workspace_ptr; + aoti_torch_get_data_ptr(int_workspace, &int_workspace_ptr); + void *float_workspace_ptr; + aoti_torch_get_data_ptr(float_workspace, &float_workspace_ptr); + int64_t int_numel; + aoti_torch_get_numel(int_workspace, &int_numel); Workspace workspace( /*options=*/options, - /*dtype_data=*/float_workspace.data_ptr(), - /*dtype_size=*/float_workspace.numel(), - /*int_data=*/int_workspace.data_ptr(), - /*int_size=*/int_workspace.numel()); + /*dtype_data=*/(float*)float_workspace_ptr, + /*dtype_size=*/float_numel, + /*int_data=*/(int*)int_workspace_ptr, + /*int_size=*/int_numel); + + void *logit_ptr; + aoti_torch_get_data_ptr(logits.get(), &logit_ptr); + + void *target_ptr; + aoti_torch_get_data_ptr(targets.get(), &target_ptr); + + void *logit_len_ptr; + aoti_torch_get_data_ptr(logit_lengths.get(), &logit_len_ptr); + + void *target_len_ptr; + 
aoti_torch_get_data_ptr(target_lengths.get(), &target_len_ptr); + + void *alpha_ptr; + aoti_torch_get_data_ptr(alphas, &alpha_ptr); // Only support float, this is mainly to enable easy // unit-testing ComputeAlphas( /*workspace=*/workspace, - /*logits=*/logits.data_ptr(), - /*targets=*/targets.data_ptr(), - /*logit_lengths=*/logit_lengths.data_ptr(), - /*target_lengths=*/target_lengths.data_ptr(), - /*alphas=*/alphas.data_ptr()); - return alphas; + /*logits=*/(float*)logit_ptr, + /*targets=*/(int*)target_ptr, + /*logit_lengths=*/(int*)logit_len_ptr, + /*target_lengths=*/(int*)target_len_ptr, + /*alphas=*/(float*)alpha_ptr); + return RAIIATH(alphas); +} + +void boxed_compute_alphas(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + RAIIATH t1(to(stack[0])); + RAIIATH t2(to(stack[1])); + RAIIATH t3(to(stack[2])); + RAIIATH t4(to(stack[3])); + int64_t blank = to(stack[4]); + double clamp = to(stack[5]); + RAIIATH result = compute_alphas(std::move(t1), std::move(t2), std::move(t3), std::move(t4), + blank, clamp); + stack[0] = from(result.release()); } -TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { - m.impl("rnnt_loss_alphas", &compute_alphas); +STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { + m.impl("rnnt_loss_alphas", &boxed_compute_alphas); } } // namespace cpu From 0ff3b57d87d75d37c75a5026c76f6266f30b80c7 Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Tue, 15 Jul 2025 20:07:25 +0000 Subject: [PATCH 10/31] Add back device type check --- src/libtorchaudio/rnnt/cpu/compute_alphas.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/libtorchaudio/rnnt/cpu/compute_alphas.cpp b/src/libtorchaudio/rnnt/cpu/compute_alphas.cpp index 2174ab9b2a..30ff4fa587 100644 --- a/src/libtorchaudio/rnnt/cpu/compute_alphas.cpp +++ b/src/libtorchaudio/rnnt/cpu/compute_alphas.cpp @@ -37,9 +37,11 @@ RAIIATH compute_alphas( options.blank_ = blank; options.clamp_ = clamp; - // TORCH_CHECK_EQ(logits.device().type(), torch::DeviceType::CPU); - 
options.device_ = CPU; + int32_t logits_device_type; + aoti_torch_get_device_type(logits.get(), &logits_device_type); + AOTI_TORCH_CHECK(logits_device_type == aoti_torch_device_type_cpu()); + options.device_ = CPU; int32_t logits_device; aoti_torch_get_device_type(logits.get(), &logits_device); From 696202e2ee7e43b78e0428cbe55213ca0467a729 Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Tue, 15 Jul 2025 20:24:04 +0000 Subject: [PATCH 11/31] Use stable ABI for compute_betas --- src/libtorchaudio/rnnt/compute_alphas.cpp | 1 - src/libtorchaudio/rnnt/compute_betas.cpp | 4 +- src/libtorchaudio/rnnt/cpu/compute_alphas.cpp | 3 +- src/libtorchaudio/rnnt/cpu/compute_betas.cpp | 141 ++++++++++++------ 4 files changed, 102 insertions(+), 47 deletions(-) diff --git a/src/libtorchaudio/rnnt/compute_alphas.cpp b/src/libtorchaudio/rnnt/compute_alphas.cpp index 40c07a2115..dd187f9777 100644 --- a/src/libtorchaudio/rnnt/compute_alphas.cpp +++ b/src/libtorchaudio/rnnt/compute_alphas.cpp @@ -1,4 +1,3 @@ -#include #include STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) { diff --git a/src/libtorchaudio/rnnt/compute_betas.cpp b/src/libtorchaudio/rnnt/compute_betas.cpp index 7728838137..b1cd379a66 100644 --- a/src/libtorchaudio/rnnt/compute_betas.cpp +++ b/src/libtorchaudio/rnnt/compute_betas.cpp @@ -1,6 +1,6 @@ -#include +#include -TORCH_LIBRARY_FRAGMENT(torchaudio, m) { +STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) { m.def( "rnnt_loss_betas(Tensor logits," "Tensor targets," diff --git a/src/libtorchaudio/rnnt/cpu/compute_alphas.cpp b/src/libtorchaudio/rnnt/cpu/compute_alphas.cpp index 30ff4fa587..40ed538175 100644 --- a/src/libtorchaudio/rnnt/cpu/compute_alphas.cpp +++ b/src/libtorchaudio/rnnt/cpu/compute_alphas.cpp @@ -1,5 +1,4 @@ #include -#include #include #include #include @@ -63,7 +62,7 @@ RAIIATH compute_alphas( aoti_torch_empty_strided(1, sizes, strides, aoti_torch_dtype_int32(), logits_device, logits_device_index, &int_workspace); AtenTensorHandle float_workspace; - 
aoti_torch_empty_strided(1, sizes, strides, aoti_torch_dtype_float32(), logits_device, logits_device_index, &int_workspace); + aoti_torch_empty_strided(1, sizes, strides, aoti_torch_dtype_float32(), logits_device, logits_device_index, &float_workspace); int64_t float_numel; aoti_torch_get_numel(float_workspace, &float_numel); diff --git a/src/libtorchaudio/rnnt/cpu/compute_betas.cpp b/src/libtorchaudio/rnnt/cpu/compute_betas.cpp index d812ef34c3..729e86a722 100644 --- a/src/libtorchaudio/rnnt/cpu/compute_betas.cpp +++ b/src/libtorchaudio/rnnt/cpu/compute_betas.cpp @@ -1,73 +1,130 @@ #include #include +#include +#include +#include namespace torchaudio { namespace rnnt { namespace cpu { -torch::Tensor compute_betas( - const torch::Tensor& logits, - const torch::Tensor& targets, - const torch::Tensor& logit_lengths, - const torch::Tensor& target_lengths, +using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle; + +RAIIATH compute_betas( + const RAIIATH logits, + const RAIIATH targets, + const RAIIATH logit_lengths, + const RAIIATH target_lengths, int64_t blank, double clamp) { Options options; - options.batchSize_ = logit_lengths.size(0); - options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0); - options.maxSrcLen_ = logits.size(1); - options.maxTgtLen_ = logits.size(2); - options.numTargets_ = logits.size(3); + int64_t tmp; + aoti_torch_get_size(logit_lengths.get(), 0, &tmp); + options.batchSize_ = (int)tmp; + aoti_torch_get_size(target_lengths.get(), 0, &tmp); + options.nHypos_ = (int)tmp; + options.nHypos_ /= options.batchSize_; + aoti_torch_get_size(logits.get(), 1, &tmp); + options.maxSrcLen_ = (int)tmp; + aoti_torch_get_size(logits.get(), 2, &tmp); + options.maxTgtLen_ = (int)tmp; + aoti_torch_get_size(logits.get(), 3, &tmp); + options.numTargets_ = (int)tmp; options.blank_ = blank; options.clamp_ = clamp; - TORCH_CHECK_EQ(logits.device().type(), torch::DeviceType::CPU); + int32_t logits_device_type; + aoti_torch_get_device_type(logits.get(), 
&logits_device_type); + AOTI_TORCH_CHECK(logits_device_type == aoti_torch_device_type_cpu()); + options.device_ = CPU; - torch::Tensor costs = torch::empty( - target_lengths.size(0), - torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); + int32_t logits_device; + aoti_torch_get_device_type(logits.get(), &logits_device); + int32_t logits_device_index; + aoti_torch_get_device_index(logits.get(), &logits_device_index); + int32_t logits_dtype; + aoti_torch_get_dtype(logits.get(), &logits_dtype); + + int64_t cost_sizes[1] = {options.batchSize_}; + int64_t stride1[1] = {1}; + AtenTensorHandle costs; + aoti_torch_empty_strided(1, cost_sizes, stride1, logits_dtype, logits_device, logits_device_index, &costs); + + int64_t betas_sizes[3] = {options.batchSize_ * options.nHypos_, options.maxSrcLen_, options.maxTgtLen_}; + int64_t betas_strides[3] = {options.maxSrcLen_ * options.maxTgtLen_, options.maxTgtLen_, 1}; + AtenTensorHandle betas; + aoti_torch_empty_strided(3, betas_sizes, betas_strides, logits_dtype, logits_device, logits_device_index, &betas); - torch::Tensor betas = torch::zeros( - {options.batchSize_ * options.nHypos_, - options.maxSrcLen_, - options.maxTgtLen_}, - torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); + AtenTensorHandle int_workspace; + int64_t w_sizes[1] = {IntWorkspace::ComputeSizeFromOptions(options)}; + aoti_torch_empty_strided(1, w_sizes, stride1, aoti_torch_dtype_int32(), logits_device, logits_device_index, &int_workspace); - torch::Tensor int_workspace = torch::empty( - IntWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Int)); + AtenTensorHandle float_workspace; + aoti_torch_empty_strided(1, w_sizes, stride1, aoti_torch_dtype_float32(), logits_device, logits_device_index, &float_workspace); - torch::Tensor float_workspace = torch::empty( - DtypeWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - 
.device(logits.device()) - .dtype(torch::ScalarType::Float)); + int64_t float_numel; + aoti_torch_get_numel(float_workspace, &float_numel); + void *int_workspace_ptr; + aoti_torch_get_data_ptr(int_workspace, &int_workspace_ptr); + void *float_workspace_ptr; + aoti_torch_get_data_ptr(float_workspace, &float_workspace_ptr); + int64_t int_numel; + aoti_torch_get_numel(int_workspace, &int_numel); Workspace workspace( /*options=*/options, - /*dtype_data=*/float_workspace.data_ptr(), - /*dtype_size=*/float_workspace.numel(), - /*int_data=*/int_workspace.data_ptr(), - /*int_size=*/int_workspace.numel()); + /*dtype_data=*/(float*)float_workspace_ptr, + /*dtype_size=*/float_numel, + /*int_data=*/(int*)int_workspace_ptr, + /*int_size=*/int_numel); + + void *logit_ptr; + aoti_torch_get_data_ptr(logits.get(), &logit_ptr); + + void *target_ptr; + aoti_torch_get_data_ptr(targets.get(), &target_ptr); + + void *logit_len_ptr; + aoti_torch_get_data_ptr(logit_lengths.get(), &logit_len_ptr); + + void *target_len_ptr; + aoti_torch_get_data_ptr(target_lengths.get(), &target_len_ptr); + + void *beta_ptr; + aoti_torch_get_data_ptr(betas, &beta_ptr); + + void *cost_ptr; + aoti_torch_get_data_ptr(costs, &cost_ptr); // Only support float, this is mainly to enable easy // unit-testing ComputeBetas( /*workspace=*/workspace, - /*logits=*/logits.data_ptr(), - /*targets=*/targets.data_ptr(), - /*logit_lengths=*/logit_lengths.data_ptr(), - /*target_lengths=*/target_lengths.data_ptr(), - /*costs=*/costs.data_ptr(), - /*betas=*/betas.data_ptr()); - return betas; + /*logits=*/(float*)logit_ptr, + /*targets=*/(int*)target_ptr, + /*logit_lengths=*/(int*)logit_len_ptr, + /*target_lengths=*/(int*)target_len_ptr, + /*costs=*/(float*)cost_ptr, + /*betas=*/(float*)beta_ptr); + return RAIIATH(betas); +} + + +void boxed_compute_betas(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + RAIIATH t1(to(stack[0])); + RAIIATH t2(to(stack[1])); + RAIIATH t3(to(stack[2])); + RAIIATH t4(to(stack[3])); + 
int64_t blank = to(stack[4]); + double clamp = to(stack[5]); + RAIIATH result = compute_betas(std::move(t1), std::move(t2), std::move(t3), std::move(t4), + blank, clamp); + stack[0] = from(result.release()); } -TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { - m.impl("rnnt_loss_betas", &compute_betas); +STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { + m.impl("rnnt_loss_betas", &boxed_compute_betas); } } // namespace cpu From 32c80da8a7734c08d1b4f8e8835c54613ffeb5bd Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Tue, 15 Jul 2025 20:34:26 +0000 Subject: [PATCH 12/31] Use stable ABI for cuda version of compute_alphas --- src/libtorchaudio/rnnt/gpu/compute_alphas.cu | 136 +++++++++++++------ 1 file changed, 95 insertions(+), 41 deletions(-) diff --git a/src/libtorchaudio/rnnt/gpu/compute_alphas.cu b/src/libtorchaudio/rnnt/gpu/compute_alphas.cu index bde40daa9f..89ce5d01dd 100644 --- a/src/libtorchaudio/rnnt/gpu/compute_alphas.cu +++ b/src/libtorchaudio/rnnt/gpu/compute_alphas.cu @@ -1,71 +1,125 @@ #include #include -#include +#include +#include +#include namespace torchaudio { namespace rnnt { namespace gpu { -torch::Tensor compute_alphas( - const torch::Tensor& logits, - const torch::Tensor& targets, - const torch::Tensor& logit_lengths, - const torch::Tensor& target_lengths, +using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle; + +RAIIATH compute_alphas( + const RAIIATH logits, + const RAIIATH targets, + const RAIIATH logit_lengths, + const RAIIATH target_lengths, int64_t blank, double clamp) { Options options; - options.batchSize_ = logit_lengths.size(0); - options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0); - options.maxSrcLen_ = logits.size(1); - options.maxTgtLen_ = logits.size(2); - options.numTargets_ = logits.size(3); + int64_t tmp; + aoti_torch_get_size(logit_lengths.get(), 0, &tmp); + options.batchSize_ = (int)tmp; + aoti_torch_get_size(target_lengths.get(), 0, &tmp); + options.nHypos_ = (int)tmp; + options.nHypos_ /= options.batchSize_; + 
aoti_torch_get_size(logits.get(), 1, &tmp); + options.maxSrcLen_ = (int)tmp; + aoti_torch_get_size(logits.get(), 2, &tmp); + options.maxTgtLen_ = (int)tmp; + aoti_torch_get_size(logits.get(), 3, &tmp); + options.numTargets_ = (int)tmp; options.blank_ = blank; options.clamp_ = clamp; - TORCH_CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA); - options.stream_ = at::cuda::getCurrentCUDAStream(); - cudaSetDevice(logits.get_device()); + int32_t logits_device_type; + aoti_torch_get_device_type(logits.get(), &logits_device_type); + AOTI_TORCH_CHECK(logits_device_type == aoti_torch_device_type_cuda()); + + int32_t logits_device; + aoti_torch_get_device_type(logits.get(), &logits_device); + int32_t logits_device_index; + aoti_torch_get_device_index(logits.get(), &logits_device_index); + int32_t logits_dtype; + aoti_torch_get_dtype(logits.get(), &logits_dtype); + + aoti_torch_get_current_cuda_stream(logits_device_index, &options.stream_); + cudaSetDevice(logits_device_index); options.device_ = GPU; - torch::Tensor alphas = torch::zeros( - {options.batchSize_ * options.nHypos_, - options.maxSrcLen_, - options.maxTgtLen_}, - torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); + int64_t param_sizes[3] = {options.batchSize_ * options.nHypos_, options.maxSrcLen_, options.maxTgtLen_}; + int64_t param_strides[3] = {options.maxSrcLen_ * options.maxTgtLen_, options.maxTgtLen_, 1}; - torch::Tensor int_workspace = torch::empty( - IntWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Int)); + AtenTensorHandle alphas; + aoti_torch_empty_strided(3, param_sizes, param_strides, logits_dtype, logits_device, logits_device_index, &alphas); + aoti_torch_zero_(alphas); - torch::Tensor float_workspace = torch::empty( - DtypeWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Float)); + AtenTensorHandle int_workspace; + int64_t 
sizes[1] = {IntWorkspace::ComputeSizeFromOptions(options)}; + int64_t strides[1] = {1}; + aoti_torch_empty_strided(1, sizes, strides, aoti_torch_dtype_int32(), logits_device, logits_device_index, &int_workspace); + + AtenTensorHandle float_workspace; + aoti_torch_empty_strided(1, sizes, strides, aoti_torch_dtype_float32(), logits_device, logits_device_index, &float_workspace); + + int64_t float_numel; + aoti_torch_get_numel(float_workspace, &float_numel); + void *int_workspace_ptr; + aoti_torch_get_data_ptr(int_workspace, &int_workspace_ptr); + void *float_workspace_ptr; + aoti_torch_get_data_ptr(float_workspace, &float_workspace_ptr); + int64_t int_numel; + aoti_torch_get_numel(int_workspace, &int_numel); Workspace workspace( /*options=*/options, - /*dtype_data=*/float_workspace.data_ptr(), - /*dtype_size=*/float_workspace.numel(), - /*int_data=*/int_workspace.data_ptr(), - /*int_size=*/int_workspace.numel()); + /*dtype_data=*/(float*)float_workspace_ptr, + /*dtype_size=*/float_numel, + /*int_data=*/(int*)int_workspace_ptr, + /*int_size=*/int_numel); + + void *logit_ptr; + aoti_torch_get_data_ptr(logits.get(), &logit_ptr); + + void *target_ptr; + aoti_torch_get_data_ptr(targets.get(), &target_ptr); + + void *logit_len_ptr; + aoti_torch_get_data_ptr(logit_lengths.get(), &logit_len_ptr); + + void *target_len_ptr; + aoti_torch_get_data_ptr(target_lengths.get(), &target_len_ptr); + + void *alpha_ptr; + aoti_torch_get_data_ptr(alphas, &alpha_ptr); // Only support float, this is mainly to enable easy // unit-testing ComputeAlphas( /*workspace=*/workspace, - /*logits=*/logits.data_ptr(), - /*targets=*/targets.data_ptr(), - /*logit_lengths=*/logit_lengths.data_ptr(), - /*target_lengths=*/target_lengths.data_ptr(), - /*alphas=*/alphas.data_ptr()); - return alphas; + /*logits=*/(float*)logit_ptr, + /*targets=*/(int*)target_ptr, + /*logit_lengths=*/(int*)logit_len_ptr, + /*target_lengths=*/(int*)target_len_ptr, + /*alphas=*/(float*)alpha_ptr); + return RAIIATH(alphas); +} + 
+void boxed_compute_alphas(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + RAIIATH t1(to(stack[0])); + RAIIATH t2(to(stack[1])); + RAIIATH t3(to(stack[2])); + RAIIATH t4(to(stack[3])); + int64_t blank = to(stack[4]); + double clamp = to(stack[5]); + RAIIATH result = compute_alphas(std::move(t1), std::move(t2), std::move(t3), std::move(t4), + blank, clamp); + stack[0] = from(result.release()); } -TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { - m.impl("rnnt_loss_alphas", &compute_alphas); +STABLE_TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { + m.impl("rnnt_loss_alphas", &boxed_compute_alphas); } } // namespace gpu From 5490cd3bca5b78e3b6cb164a17bad0a739161aff Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Tue, 15 Jul 2025 20:38:55 +0000 Subject: [PATCH 13/31] Use stable ABI for cuda version of compute_betas --- src/libtorchaudio/rnnt/gpu/compute_betas.cu | 155 +++++++++++++------- 1 file changed, 106 insertions(+), 49 deletions(-) diff --git a/src/libtorchaudio/rnnt/gpu/compute_betas.cu b/src/libtorchaudio/rnnt/gpu/compute_betas.cu index 18857c4388..25cf70dea3 100644 --- a/src/libtorchaudio/rnnt/gpu/compute_betas.cu +++ b/src/libtorchaudio/rnnt/gpu/compute_betas.cu @@ -1,76 +1,133 @@ #include #include -#include +#include +#include +#include namespace torchaudio { namespace rnnt { namespace gpu { -torch::Tensor compute_betas( - const torch::Tensor& logits, - const torch::Tensor& targets, - const torch::Tensor& logit_lengths, - const torch::Tensor& target_lengths, +using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle; + + +RAIIATH compute_betas( + const RAIIATH logits, + const RAIIATH targets, + const RAIIATH logit_lengths, + const RAIIATH target_lengths, int64_t blank, double clamp) { - Options options; - options.batchSize_ = logit_lengths.size(0); - options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0); - options.maxSrcLen_ = logits.size(1); - options.maxTgtLen_ = logits.size(2); - options.numTargets_ = logits.size(3); - options.blank_ 
= blank; - options.clamp_ = clamp; - - TORCH_CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA); - options.stream_ = at::cuda::getCurrentCUDAStream(); - cudaSetDevice(logits.get_device()); + Options options; + int64_t tmp; + aoti_torch_get_size(logit_lengths.get(), 0, &tmp); + options.batchSize_ = (int)tmp; + aoti_torch_get_size(target_lengths.get(), 0, &tmp); + options.nHypos_ = (int)tmp; + options.nHypos_ /= options.batchSize_; + aoti_torch_get_size(logits.get(), 1, &tmp); + options.maxSrcLen_ = (int)tmp; + aoti_torch_get_size(logits.get(), 2, &tmp); + options.maxTgtLen_ = (int)tmp; + aoti_torch_get_size(logits.get(), 3, &tmp); + options.numTargets_ = (int)tmp; + options.blank_ = blank; + options.clamp_ = clamp; + + int32_t logits_device_type; + aoti_torch_get_device_type(logits.get(), &logits_device_type); + AOTI_TORCH_CHECK(logits_device_type == aoti_torch_device_type_cuda()); + + + int32_t logits_device; + aoti_torch_get_device_type(logits.get(), &logits_device); + int32_t logits_device_index; + aoti_torch_get_device_index(logits.get(), &logits_device_index); + int32_t logits_dtype; + aoti_torch_get_dtype(logits.get(), &logits_dtype); + + aoti_torch_get_current_cuda_stream(logits_device_index, &options.stream_); + cudaSetDevice(logits_device) options.device_ = GPU; - torch::Tensor costs = torch::empty( - target_lengths.size(0), - torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); + int64_t cost_sizes[1] = {options.batchSize_}; + int64_t stride1[1] = {1}; + AtenTensorHandle costs; + aoti_torch_empty_strided(1, cost_sizes, stride1, logits_dtype, logits_device, logits_device_index, &costs); - torch::Tensor betas = torch::zeros( - {options.batchSize_ * options.nHypos_, - options.maxSrcLen_, - options.maxTgtLen_}, - torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); + int64_t betas_sizes[3] = {options.batchSize_ * options.nHypos_, options.maxSrcLen_, options.maxTgtLen_}; + int64_t betas_strides[3] = {options.maxSrcLen_ 
* options.maxTgtLen_, options.maxTgtLen_, 1}; + AtenTensorHandle betas; + aoti_torch_empty_strided(3, betas_sizes, betas_strides, logits_dtype, logits_device, logits_device_index, &betas); - torch::Tensor int_workspace = torch::empty( - IntWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Int)); + AtenTensorHandle int_workspace; + int64_t w_sizes[1] = {IntWorkspace::ComputeSizeFromOptions(options)}; + aoti_torch_empty_strided(1, w_sizes, stride1, aoti_torch_dtype_int32(), logits_device, logits_device_index, &int_workspace); - torch::Tensor float_workspace = torch::empty( - DtypeWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Float)); + AtenTensorHandle float_workspace; + aoti_torch_empty_strided(1, w_sizes, stride1, aoti_torch_dtype_float32(), logits_device, logits_device_index, &float_workspace); + + int64_t float_numel; + aoti_torch_get_numel(float_workspace, &float_numel); + void *int_workspace_ptr; + aoti_torch_get_data_ptr(int_workspace, &int_workspace_ptr); + void *float_workspace_ptr; + aoti_torch_get_data_ptr(float_workspace, &float_workspace_ptr); + int64_t int_numel; + aoti_torch_get_numel(int_workspace, &int_numel); Workspace workspace( /*options=*/options, - /*dtype_data=*/float_workspace.data_ptr(), - /*dtype_size=*/float_workspace.numel(), - /*int_data=*/int_workspace.data_ptr(), - /*int_size=*/int_workspace.numel()); + /*dtype_data=*/(float*)float_workspace_ptr, + /*dtype_size=*/float_numel, + /*int_data=*/(int*)int_workspace_ptr, + /*int_size=*/int_numel); + + void *logit_ptr; + aoti_torch_get_data_ptr(logits.get(), &logit_ptr); + + void *target_ptr; + aoti_torch_get_data_ptr(targets.get(), &target_ptr); + + void *logit_len_ptr; + aoti_torch_get_data_ptr(logit_lengths.get(), &logit_len_ptr); + + void *target_len_ptr; + aoti_torch_get_data_ptr(target_lengths.get(), &target_len_ptr); + + void *beta_ptr; 
+ aoti_torch_get_data_ptr(betas, &beta_ptr); + + void *cost_ptr; + aoti_torch_get_data_ptr(costs, &cost_ptr); // Only support float, this is mainly to enable easy // unit-testing ComputeBetas( /*workspace=*/workspace, - /*logits=*/logits.data_ptr(), - /*targets=*/targets.data_ptr(), - /*logit_lengths=*/logit_lengths.data_ptr(), - /*target_lengths=*/target_lengths.data_ptr(), - /*costs=*/costs.data_ptr(), - /*betas=*/betas.data_ptr()); - return betas; + /*logits=*/(float*)logit_ptr, + /*targets=*/(int*)target_ptr, + /*logit_lengths=*/(int*)logit_len_ptr, + /*target_lengths=*/(int*)target_len_ptr, + /*costs=*/(float*)cost_ptr, + /*betas=*/(float*)beta_ptr); + return RAIIATH(betas); +} + +void boxed_compute_betas(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + RAIIATH t1(to(stack[0])); + RAIIATH t2(to(stack[1])); + RAIIATH t3(to(stack[2])); + RAIIATH t4(to(stack[3])); + int64_t blank = to(stack[4]); + double clamp = to(stack[5]); + RAIIATH result = compute_betas(std::move(t1), std::move(t2), std::move(t3), std::move(t4), + blank, clamp); + stack[0] = from(result.release()); } -TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { - m.impl("rnnt_loss_betas", &compute_betas); +STABLE_TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { + m.impl("rnnt_loss_betas", &boxed_compute_betas); } } // namespace gpu From 89d480b16fad258f2e0cc7f2856c89d5a930c70f Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Thu, 17 Jul 2025 17:49:14 +0000 Subject: [PATCH 14/31] Add missing semicolon --- src/libtorchaudio/rnnt/gpu/compute_alphas.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/libtorchaudio/rnnt/gpu/compute_alphas.cu b/src/libtorchaudio/rnnt/gpu/compute_alphas.cu index 89ce5d01dd..90e421ab4a 100644 --- a/src/libtorchaudio/rnnt/gpu/compute_alphas.cu +++ b/src/libtorchaudio/rnnt/gpu/compute_alphas.cu @@ -44,8 +44,8 @@ RAIIATH compute_alphas( int32_t logits_dtype; aoti_torch_get_dtype(logits.get(), &logits_dtype); - 
aoti_torch_get_current_cuda_stream(logits_device_index, &options.stream_); - cudaSetDevice(logits_device) + aoti_torch_get_current_cuda_stream(logits_device_index, (void**)&options.stream_); + cudaSetDevice(logits_device); options.device_ = GPU; int64_t param_sizes[3] = {options.batchSize_ * options.nHypos_, options.maxSrcLen_, options.maxTgtLen_}; From 7f11d1d22f34893a58e64b18279b47fde12534bc Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Thu, 17 Jul 2025 17:57:19 +0000 Subject: [PATCH 15/31] Cast to void pointer pointer --- src/libtorchaudio/rnnt/gpu/compute_betas.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/libtorchaudio/rnnt/gpu/compute_betas.cu b/src/libtorchaudio/rnnt/gpu/compute_betas.cu index 25cf70dea3..7bed017b14 100644 --- a/src/libtorchaudio/rnnt/gpu/compute_betas.cu +++ b/src/libtorchaudio/rnnt/gpu/compute_betas.cu @@ -46,8 +46,8 @@ RAIIATH compute_betas( int32_t logits_dtype; aoti_torch_get_dtype(logits.get(), &logits_dtype); - aoti_torch_get_current_cuda_stream(logits_device_index, &options.stream_); - cudaSetDevice(logits_device) + aoti_torch_get_current_cuda_stream(logits_device_index, (void**)&options.stream_); + cudaSetDevice(logits_device); options.device_ = GPU; int64_t cost_sizes[1] = {options.batchSize_}; From cc592e0d9e2eb19995f9f59e1e393cca27651271 Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Thu, 17 Jul 2025 19:28:56 +0000 Subject: [PATCH 16/31] Use stable ABI for compute --- src/libtorchaudio/rnnt/compute.cpp | 28 +-- src/libtorchaudio/rnnt/compute.h | 12 - src/libtorchaudio/rnnt/cpu/compute.cpp | 294 ++++++++++++++--------- src/libtorchaudio/rnnt/gpu/compute.cu | 307 ++++++++++++++---------- src/torchaudio/functional/functional.py | 2 +- 5 files changed, 369 insertions(+), 274 deletions(-) delete mode 100644 src/libtorchaudio/rnnt/compute.h diff --git a/src/libtorchaudio/rnnt/compute.cpp b/src/libtorchaudio/rnnt/compute.cpp index 5aba334cee..5074cd0d32 100644 --- 
a/src/libtorchaudio/rnnt/compute.cpp +++ b/src/libtorchaudio/rnnt/compute.cpp @@ -1,34 +1,12 @@ -#include +#include -std::tuple> rnnt_loss( - torch::Tensor& logits, - const torch::Tensor& targets, - const torch::Tensor& logit_lengths, - const torch::Tensor& target_lengths, - int64_t blank, - double clamp, - bool fused_log_softmax = true) { - static auto op = torch::Dispatcher::singleton() - .findSchemaOrThrow("torchaudio::rnnt_loss", "") - .typed(); - return op.call( - logits, - targets, - logit_lengths, - target_lengths, - blank, - clamp, - fused_log_softmax); -} - -TORCH_LIBRARY_FRAGMENT(torchaudio, m) { +STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) { m.def( - "rnnt_loss(Tensor logits," + "torchaudio::rnnt_loss(Tensor logits," "Tensor targets," "Tensor logit_lengths," "Tensor target_lengths," "int blank," "float clamp," "bool fused_log_softmax) -> (Tensor, Tensor?)"); - m.def("torchaudio::rnnt_loss_forward", &rnnt_loss); } diff --git a/src/libtorchaudio/rnnt/compute.h b/src/libtorchaudio/rnnt/compute.h deleted file mode 100644 index ed2dd0c37e..0000000000 --- a/src/libtorchaudio/rnnt/compute.h +++ /dev/null @@ -1,12 +0,0 @@ -#pragma once - -#include - -std::tuple> rnnt_loss( - torch::Tensor& logits, - const torch::Tensor& targets, - const torch::Tensor& logit_lengths, - const torch::Tensor& target_lengths, - int64_t blank, - double clamp, - bool fused_log_softmax); diff --git a/src/libtorchaudio/rnnt/cpu/compute.cpp b/src/libtorchaudio/rnnt/cpu/compute.cpp index 097b4bd7e1..a9864e345c 100644 --- a/src/libtorchaudio/rnnt/cpu/compute.cpp +++ b/src/libtorchaudio/rnnt/cpu/compute.cpp @@ -1,148 +1,212 @@ #include -#include +#include +#include +#include namespace torchaudio { namespace rnnt { namespace cpu { +using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle; + // Entry point into RNNT Loss -std::tuple> compute( - torch::Tensor& logits, - const torch::Tensor& targets, - const torch::Tensor& logit_lengths, - const torch::Tensor& target_lengths, +std::tuple 
compute( + const RAIIATH logits, + const RAIIATH targets, + const RAIIATH logit_lengths, + const RAIIATH target_lengths, int64_t blank, double clamp, bool fused_log_softmax = true) { - TORCH_CHECK( - logits.device().type() == targets.device().type(), - "logits and targets must be on the same device"); - TORCH_CHECK( - logits.device().type() == logit_lengths.device().type(), - "logits and logit_lengths must be on the same device"); - TORCH_CHECK( - logits.device().type() == target_lengths.device().type(), - "logits and target_lengths must be on the same device"); - - TORCH_CHECK( - logits.dtype() == torch::kFloat32 || logits.dtype() == torch::kFloat16, - "logits must be float32 or float16 (half) type"); - TORCH_CHECK(targets.dtype() == torch::kInt32, "targets must be int32 type"); - TORCH_CHECK( - logit_lengths.dtype() == torch::kInt32, - "logit_lengths must be int32 type"); - TORCH_CHECK( - target_lengths.dtype() == torch::kInt32, - "target_lengths must be int32 type"); - - TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous"); - TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous"); - TORCH_CHECK( - logit_lengths.is_contiguous(), "logit_lengths must be contiguous"); - TORCH_CHECK( - target_lengths.is_contiguous(), "target_lengths must be contiguous"); - - TORCH_CHECK( - logits.dim() == 4, "logits must be 4-D (batch, time, target, class)"); - TORCH_CHECK( - targets.dim() == 2, "targets must be 2-D (batch, max target length)"); - TORCH_CHECK(logit_lengths.dim() == 1, "logit_lengths must be 1-D"); - TORCH_CHECK(target_lengths.dim() == 1, "target_lengths must be 1-D"); - - TORCH_CHECK( - logit_lengths.size(0) == logits.size(0), - "batch dimension mismatch between logits and logit_lengths"); - TORCH_CHECK( - target_lengths.size(0) == logits.size(0), - "batch dimension mismatch between logits and target_lengths"); - TORCH_CHECK( - targets.size(0) == logits.size(0), - "batch dimension mismatch between logits and targets"); - - TORCH_CHECK( - blank 
>= 0 && blank < logits.size(-1), - "blank must be within [0, logits.shape[-1])"); - - TORCH_CHECK( - logits.size(1) == at::max(logit_lengths).item().toInt(), - "input length mismatch"); - TORCH_CHECK( - logits.size(2) == at::max(target_lengths).item().toInt() + 1, - "output length mismatch"); - TORCH_CHECK( - targets.size(1) == at::max(target_lengths).item().toInt(), - "target length mismatch"); + + int32_t logits_device; + aoti_torch_get_device_type(logits.get(), &logits_device); + int32_t targets_device; + aoti_torch_get_device_type(targets.get(), &targets_device); + int32_t logit_lengths_device; + aoti_torch_get_device_type(logit_lengths.get(), &logit_lengths_device); + int32_t target_lengths_device; + aoti_torch_get_device_type(target_lengths.get(), &target_lengths_device); + + AOTI_TORCH_CHECK(logits_device == targets_device); + AOTI_TORCH_CHECK(logits_device == logit_lengths_device); + AOTI_TORCH_CHECK(logits_device == target_lengths_device); + + int32_t logits_dtype; + aoti_torch_get_dtype(logits.get(), &logits_dtype); + AOTI_TORCH_CHECK(logits_dtype == aoti_torch_dtype_float32() || + logits_dtype == aoti_torch_dtype_float16()); + + int32_t targets_dtype; + aoti_torch_get_dtype(targets.get(), &targets_dtype); + AOTI_TORCH_CHECK(targets_dtype == aoti_torch_dtype_int32() || + logits_dtype == aoti_torch_dtype_float16()); + + int32_t logit_lengths_dtype; + aoti_torch_get_dtype(logit_lengths.get(), &logit_lengths_dtype); + AOTI_TORCH_CHECK(logit_lengths_dtype == aoti_torch_dtype_int32() || + logit_lengths_dtype == aoti_torch_dtype_float16()); + + int32_t target_lengths_dtype; + aoti_torch_get_dtype(target_lengths.get(), &target_lengths_dtype); + AOTI_TORCH_CHECK(target_lengths_dtype == aoti_torch_dtype_int32() || + target_lengths_dtype == aoti_torch_dtype_float16()); + + bool bool_tmp; + aoti_torch_is_contiguous(logits.get(), &bool_tmp); + AOTI_TORCH_CHECK(bool_tmp); + aoti_torch_is_contiguous(targets.get(), &bool_tmp); + AOTI_TORCH_CHECK(bool_tmp); + 
aoti_torch_is_contiguous(logit_lengths.get(), &bool_tmp); + AOTI_TORCH_CHECK(bool_tmp); + aoti_torch_is_contiguous(target_lengths.get(), &bool_tmp); + + int64_t int_tmp; + aoti_torch_get_dim(logits.get(), &int_tmp); + AOTI_TORCH_CHECK(int_tmp == 4); + aoti_torch_get_dim(targets.get(), &int_tmp); + AOTI_TORCH_CHECK(int_tmp == 2); + aoti_torch_get_dim(logit_lengths.get(), &int_tmp); + AOTI_TORCH_CHECK(int_tmp == 1); + aoti_torch_get_dim(target_lengths.get(), &int_tmp); + AOTI_TORCH_CHECK(int_tmp == 1); + + int64_t logit_lengths_size; + aoti_torch_get_size(logit_lengths.get(), 0, &logit_lengths_size); + int64_t logits_size; + aoti_torch_get_size(logits.get(), 0, &logits_size); + AOTI_TORCH_CHECK(logit_lengths_size == logits_size); + int64_t target_lengths_size; + aoti_torch_get_size(target_lengths.get(), 0, &target_lengths_size); + AOTI_TORCH_CHECK(target_lengths_size == logits_size); + int64_t targets_size; + aoti_torch_get_size(targets.get(), 0, &targets_size); + AOTI_TORCH_CHECK(targets_size == logits_size); + + // TORCH_CHECK( + // blank >= 0 && blank < logits.size(-1), + // "blank must be within [0, logits.shape[-1])"); + + // TORCH_CHECK( + // logits.size(1) == at::max(logit_lengths).item().toInt(), + // "input length mismatch"); + // TORCH_CHECK( + // logits.size(2) == at::max(target_lengths).item().toInt() + 1, + // "output length mismatch"); + // TORCH_CHECK( + // targets.size(1) == at::max(target_lengths).item().toInt(), + // "target length mismatch"); Options options; - options.batchSize_ = logit_lengths.size(0); - options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0); - options.maxSrcLen_ = logits.size(1); - options.maxTgtLen_ = logits.size(2); - options.numTargets_ = logits.size(3); + options.batchSize_ = (int)logit_lengths_size; + options.nHypos_ = (int)target_lengths_size; + options.nHypos_ /= options.batchSize_; + aoti_torch_get_size(logits.get(), 1, &int_tmp); + options.maxSrcLen_ = (int)int_tmp; + aoti_torch_get_size(logits.get(), 2, 
&int_tmp); + options.maxTgtLen_ = (int)int_tmp; + aoti_torch_get_size(logits.get(), 3, &int_tmp); + options.numTargets_ = (int)int_tmp; options.blank_ = blank; options.clamp_ = clamp; options.fusedLogSmax_ = fused_log_softmax; - TORCH_CHECK_EQ(logits.device().type(), torch::DeviceType::CPU); + AOTI_TORCH_CHECK(logits_device == aoti_torch_device_type_cpu()); options.device_ = CPU; - torch::Tensor costs = torch::empty( - options.batchSize_ * options.nHypos_, - torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); - std::optional gradients = torch::zeros_like(logits); - - torch::Tensor int_workspace = torch::empty( - IntWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Int)); - - torch::Tensor float_workspace = torch::empty( - DtypeWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Float)); + int32_t logits_device_index; + aoti_torch_get_device_index(logits.get(), &logits_device_index); + int64_t cost_sizes[1] = {options.batchSize_ * options.nHypos_}; + int64_t stride1[1] = {1}; + AtenTensorHandle costs; + aoti_torch_empty_strided(1, cost_sizes, stride1, logits_dtype, logits_device, logits_device_index, &costs); + + AtenTensorHandle gradients; + aoti_torch_clone(logits.get(), &gradients); + aoti_torch_zero_(gradients); + + AtenTensorHandle int_workspace; + int64_t sizes[1] = {IntWorkspace::ComputeSizeFromOptions(options)}; + int64_t strides[1] = {1}; + aoti_torch_empty_strided(1, sizes, strides, aoti_torch_dtype_int32(), logits_device, logits_device_index, &int_workspace); + + AtenTensorHandle float_workspace; + aoti_torch_empty_strided(1, sizes, strides, aoti_torch_dtype_float32(), logits_device, logits_device_index, &float_workspace); + + int64_t float_numel; + aoti_torch_get_numel(float_workspace, &float_numel); + void *int_workspace_ptr; + aoti_torch_get_data_ptr(int_workspace, &int_workspace_ptr); + 
void *float_workspace_ptr; + aoti_torch_get_data_ptr(float_workspace, &float_workspace_ptr); + int64_t int_numel; + aoti_torch_get_numel(int_workspace, &int_numel); Workspace workspace( /*options=*/options, - /*dtype_data=*/float_workspace.data_ptr(), - /*dtype_size=*/float_workspace.numel(), - /*int_data=*/int_workspace.data_ptr(), - /*int_size=*/int_workspace.numel()); + /*dtype_data=*/(float*)float_workspace_ptr, + /*dtype_size=*/float_numel, + /*int_data=*/(int*)int_workspace_ptr, + /*int_size=*/int_numel); + + void *logit_ptr; + aoti_torch_get_data_ptr(logits.get(), &logit_ptr); + + void *target_ptr; + aoti_torch_get_data_ptr(targets.get(), &target_ptr); + + void *logit_len_ptr; + aoti_torch_get_data_ptr(logit_lengths.get(), &logit_len_ptr); - switch (logits.scalar_type()) { - case torch::ScalarType::Float: { + void *target_len_ptr; + aoti_torch_get_data_ptr(target_lengths.get(), &target_len_ptr); + + void *costs_ptr; + aoti_torch_get_data_ptr(costs, &costs_ptr); + + void *grads_ptr; + aoti_torch_get_data_ptr(gradients, &grads_ptr); + + if (logits_dtype == aoti_torch_dtype_float32()) { Compute( /*workspace=*/workspace, - /*logits=*/logits.data_ptr(), - /*targets=*/targets.data_ptr(), - /*logit_lengths=*/logit_lengths.data_ptr(), - /*target_lengths=*/target_lengths.data_ptr(), - /*costs=*/costs.data_ptr(), - /*gradients=*/gradients->data_ptr()); - break; - } - case torch::ScalarType::Half: { + /*logits=*/(float*)logit_ptr, + /*targets=*/(int*)target_ptr, + /*logit_lengths=*/(int*)logit_len_ptr, + /*target_lengths=*/(int*)target_len_ptr, + /*costs=*/(float*)costs_ptr, + /*gradients=*/(float*)grads_ptr); + } else { Compute( /*workspace=*/workspace, - /*logits=*/logits.data_ptr(), - /*targets=*/targets.data_ptr(), - /*logit_lengths=*/logit_lengths.data_ptr(), - /*target_lengths=*/target_lengths.data_ptr(), - /*costs=*/costs.data_ptr(), - /*gradients=*/gradients->data_ptr()); - break; + /*logits=*/(c10::Half*)logit_ptr, + /*targets=*/(int*)target_ptr, + 
/*logit_lengths=*/(int*)logit_len_ptr, + /*target_lengths=*/(int*)target_len_ptr, + /*costs=*/(c10::Half*)costs_ptr, + /*gradients=*/(c10::Half*)grads_ptr); } - default: { - break; - } - }; - return std::make_tuple(costs, gradients); + return std::make_tuple(RAIIATH(costs), RAIIATH(gradients)); +} + +void boxed_compute(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + RAIIATH t1(to(stack[0])); + RAIIATH t2(to(stack[1])); + RAIIATH t3(to(stack[2])); + RAIIATH t4(to(stack[3])); + int64_t blank = to(stack[4]); + double clamp = to(stack[5]); + bool fused_log_softmax = to(stack[6]); + auto result = compute( + std::move(t1), std::move(t2), std::move(t3), std::move(t4), + blank, clamp, fused_log_softmax); + stack[0] = from((std::get<0>(result)).release()); + stack[1] = from((std::get<1>(result)).release()); } -TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { - m.impl("rnnt_loss", &compute); +STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { + m.impl("torchaudio::rnnt_loss", &boxed_compute); } } // namespace cpu diff --git a/src/libtorchaudio/rnnt/gpu/compute.cu b/src/libtorchaudio/rnnt/gpu/compute.cu index 43dae68027..1073b18a81 100644 --- a/src/libtorchaudio/rnnt/gpu/compute.cu +++ b/src/libtorchaudio/rnnt/gpu/compute.cu @@ -1,151 +1,216 @@ #include #include -#include +#include +#include +#include namespace torchaudio { namespace rnnt { namespace gpu { +using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle; + // Entry point into RNNT Loss -std::tuple> compute( - torch::Tensor& logits, - const torch::Tensor& targets, - const torch::Tensor& logit_lengths, - const torch::Tensor& target_lengths, +std::tuple compute( + const RAIIATH logits, + const RAIIATH targets, + const RAIIATH logit_lengths, + const RAIIATH target_lengths, int64_t blank, double clamp, bool fused_log_softmax = true) { - TORCH_CHECK( - logits.device().type() == targets.device().type(), - "logits and targets must be on the same device"); - TORCH_CHECK( - logits.device().type() == 
logit_lengths.device().type(), - "logits and logit_lengths must be on the same device"); - TORCH_CHECK( - logits.device().type() == target_lengths.device().type(), - "logits and target_lengths must be on the same device"); - - TORCH_CHECK( - logits.dtype() == torch::kFloat32 || logits.dtype() == torch::kFloat16, - "logits must be float32 or float16 (half) type"); - TORCH_CHECK(targets.dtype() == torch::kInt32, "targets must be int32 type"); - TORCH_CHECK( - logit_lengths.dtype() == torch::kInt32, - "logit_lengths must be int32 type"); - TORCH_CHECK( - target_lengths.dtype() == torch::kInt32, - "target_lengths must be int32 type"); - - TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous"); - TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous"); - TORCH_CHECK( - logit_lengths.is_contiguous(), "logit_lengths must be contiguous"); - TORCH_CHECK( - target_lengths.is_contiguous(), "target_lengths must be contiguous"); - - TORCH_CHECK( - logits.dim() == 4, "logits must be 4-D (batch, time, target, class)"); - TORCH_CHECK( - targets.dim() == 2, "targets must be 2-D (batch, max target length)"); - TORCH_CHECK(logit_lengths.dim() == 1, "logit_lengths must be 1-D"); - TORCH_CHECK(target_lengths.dim() == 1, "target_lengths must be 1-D"); - - TORCH_CHECK( - logit_lengths.size(0) == logits.size(0), - "batch dimension mismatch between logits and logit_lengths"); - TORCH_CHECK( - target_lengths.size(0) == logits.size(0), - "batch dimension mismatch between logits and target_lengths"); - TORCH_CHECK( - targets.size(0) == logits.size(0), - "batch dimension mismatch between logits and targets"); - - TORCH_CHECK( - blank >= 0 && blank < logits.size(-1), - "blank must be within [0, logits.shape[-1])"); - - TORCH_CHECK( - logits.size(1) == at::max(logit_lengths).item().toInt(), - "input length mismatch"); - TORCH_CHECK( - logits.size(2) == at::max(target_lengths).item().toInt() + 1, - "output length mismatch"); - TORCH_CHECK( - targets.size(1) == 
at::max(target_lengths).item().toInt(), - "target length mismatch"); - - Options options; - options.batchSize_ = logit_lengths.size(0); - options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0); - options.maxSrcLen_ = logits.size(1); - options.maxTgtLen_ = logits.size(2); - options.numTargets_ = logits.size(3); - options.blank_ = blank; - options.clamp_ = clamp; - options.fusedLogSmax_ = fused_log_softmax; - - TORCH_CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA); - options.stream_ = at::cuda::getCurrentCUDAStream(); - cudaSetDevice(logits.get_device()); + + int32_t logits_device; + aoti_torch_get_device_type(logits.get(), &logits_device); + int32_t targets_device; + aoti_torch_get_device_type(targets.get(), &targets_device); + int32_t logit_lengths_device; + aoti_torch_get_device_type(logit_lengths.get(), &logit_lengths_device); + int32_t target_lengths_device; + aoti_torch_get_device_type(target_lengths.get(), &target_lengths_device); + + AOTI_TORCH_CHECK(logits_device == targets_device); + AOTI_TORCH_CHECK(logits_device == logit_lengths_device); + AOTI_TORCH_CHECK(logits_device == target_lengths_device); + + int32_t logits_dtype; + aoti_torch_get_dtype(logits.get(), &logits_dtype); + AOTI_TORCH_CHECK(logits_dtype == aoti_torch_dtype_float32() || + logits_dtype == aoti_torch_dtype_float16()); + + int32_t targets_dtype; + aoti_torch_get_dtype(targets.get(), &targets_dtype); + AOTI_TORCH_CHECK(targets_dtype == aoti_torch_dtype_int32() || + logits_dtype == aoti_torch_dtype_float16()); + + int32_t logit_lengths_dtype; + aoti_torch_get_dtype(logit_lengths.get(), &logit_lengths_dtype); + AOTI_TORCH_CHECK(logit_lengths_dtype == aoti_torch_dtype_int32() || + logit_lengths_dtype == aoti_torch_dtype_float16()); + + int32_t target_lengths_dtype; + aoti_torch_get_dtype(target_lengths.get(), &target_lengths_dtype); + AOTI_TORCH_CHECK(target_lengths_dtype == aoti_torch_dtype_int32() || + target_lengths_dtype == aoti_torch_dtype_float16()); + + bool bool_tmp; 
+ aoti_torch_is_contiguous(logits.get(), &bool_tmp); + AOTI_TORCH_CHECK(bool_tmp); + aoti_torch_is_contiguous(targets.get(), &bool_tmp); + AOTI_TORCH_CHECK(bool_tmp); + aoti_torch_is_contiguous(logit_lengths.get(), &bool_tmp); + AOTI_TORCH_CHECK(bool_tmp); + aoti_torch_is_contiguous(target_lengths.get(), &bool_tmp); + + int64_t int_tmp; + aoti_torch_get_dim(logits.get(), &int_tmp); + AOTI_TORCH_CHECK(int_tmp == 4); + aoti_torch_get_dim(targets.get(), &int_tmp); + AOTI_TORCH_CHECK(int_tmp == 2); + aoti_torch_get_dim(logit_lengths.get(), &int_tmp); + AOTI_TORCH_CHECK(int_tmp == 1); + aoti_torch_get_dim(target_lengths.get(), &int_tmp); + AOTI_TORCH_CHECK(int_tmp == 1); + + int64_t logit_lengths_size; + aoti_torch_get_size(logit_lengths.get(), 0, &logit_lengths_size); + int64_t logits_size; + aoti_torch_get_size(logits.get(), 0, &logits_size); + AOTI_TORCH_CHECK(logit_lengths_size == logits_size); + int64_t target_lengths_size; + aoti_torch_get_size(target_lengths.get(), 0, &target_lengths_size); + AOTI_TORCH_CHECK(target_lengths_size == logits_size); + int64_t targets_size; + aoti_torch_get_size(targets.get(), 0, &targets_size); + AOTI_TORCH_CHECK(targets_size == logits_size); + + // TORCH_CHECK( + // blank >= 0 && blank < logits.size(-1), + // "blank must be within [0, logits.shape[-1])"); + + // TORCH_CHECK( + // logits.size(1) == at::max(logit_lengths).item().toInt(), + // "input length mismatch"); + // TORCH_CHECK( + // logits.size(2) == at::max(target_lengths).item().toInt() + 1, + // "output length mismatch"); + // TORCH_CHECK( + // targets.size(1) == at::max(target_lengths).item().toInt(), + // "target length mismatch"); + + Options options; + options.batchSize_ = (int)logit_lengths_size; + options.nHypos_ = (int)target_lengths_size; + options.nHypos_ /= options.batchSize_; + aoti_torch_get_size(logits.get(), 1, &int_tmp); + options.maxSrcLen_ = (int)int_tmp; + aoti_torch_get_size(logits.get(), 2, &int_tmp); + options.maxTgtLen_ = (int)int_tmp; + 
aoti_torch_get_size(logits.get(), 3, &int_tmp); + options.numTargets_ = (int)int_tmp; + options.blank_ = blank; + options.clamp_ = clamp; + options.fusedLogSmax_ = fused_log_softmax; + + int32_t logits_device_index; + aoti_torch_get_device_index(logits.get(), &logits_device_index); + + TORCH_CHECK_EQ(logits_device, aoti_torch_device_type_cuda()); + aoti_torch_get_current_cuda_stream(logits_device_index, (void**)&options.stream_); + cudaSetDevice(logits_device); options.device_ = GPU; - torch::Tensor costs = torch::empty( - options.batchSize_ * options.nHypos_, - torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); - std::optional gradients = torch::zeros_like(logits); + int64_t cost_sizes[1] = {options.batchSize_ * options.nHypos_}; + int64_t stride1[1] = {1}; + AtenTensorHandle costs; + aoti_torch_empty_strided(1, cost_sizes, stride1, logits_dtype, logits_device, logits_device_index, &costs); + + AtenTensorHandle gradients; + aoti_torch_clone(logits.get(), &gradients); + aoti_torch_zero_(gradients); + + AtenTensorHandle int_workspace; + int64_t sizes[1] = {IntWorkspace::ComputeSizeFromOptions(options)}; + int64_t strides[1] = {1}; + aoti_torch_empty_strided(1, sizes, strides, aoti_torch_dtype_int32(), logits_device, logits_device_index, &int_workspace); - torch::Tensor int_workspace = torch::empty( - IntWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Int)); + AtenTensorHandle float_workspace; + aoti_torch_empty_strided(1, sizes, strides, aoti_torch_dtype_float32(), logits_device, logits_device_index, &float_workspace); - torch::Tensor float_workspace = torch::empty( - DtypeWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Float)); + int64_t float_numel; + aoti_torch_get_numel(float_workspace, &float_numel); + void *int_workspace_ptr; + aoti_torch_get_data_ptr(int_workspace, &int_workspace_ptr); + 
void *float_workspace_ptr; + aoti_torch_get_data_ptr(float_workspace, &float_workspace_ptr); + int64_t int_numel; + aoti_torch_get_numel(int_workspace, &int_numel); Workspace workspace( /*options=*/options, - /*dtype_data=*/float_workspace.data_ptr(), - /*dtype_size=*/float_workspace.numel(), - /*int_data=*/int_workspace.data_ptr(), - /*int_size=*/int_workspace.numel()); + /*dtype_data=*/(float*)float_workspace_ptr, + /*dtype_size=*/float_numel, + /*int_data=*/(int*)int_workspace_ptr, + /*int_size=*/int_numel); + + void *logit_ptr; + aoti_torch_get_data_ptr(logits.get(), &logit_ptr); + + void *target_ptr; + aoti_torch_get_data_ptr(targets.get(), &target_ptr); + + void *logit_len_ptr; + aoti_torch_get_data_ptr(logit_lengths.get(), &logit_len_ptr); - switch (logits.scalar_type()) { - case torch::ScalarType::Float: { + void *target_len_ptr; + aoti_torch_get_data_ptr(target_lengths.get(), &target_len_ptr); + + void *costs_ptr; + aoti_torch_get_data_ptr(costs, &costs_ptr); + + void *grads_ptr; + aoti_torch_get_data_ptr(gradients, &grads_ptr); + + if (logits_dtype == aoti_torch_dtype_float32()) { Compute( /*workspace=*/workspace, - /*logits=*/logits.data_ptr(), - /*targets=*/targets.data_ptr(), - /*logit_lengths=*/logit_lengths.data_ptr(), - /*target_lengths=*/target_lengths.data_ptr(), - /*costs=*/costs.data_ptr(), - /*gradients=*/gradients->data_ptr()); - break; - } - case torch::ScalarType::Half: { + /*logits=*/(float*)logit_ptr, + /*targets=*/(int*)target_ptr, + /*logit_lengths=*/(int*)logit_len_ptr, + /*target_lengths=*/(int*)target_len_ptr, + /*costs=*/(float*)costs_ptr, + /*gradients=*/(float*)grads_ptr); + } else { Compute( /*workspace=*/workspace, - /*logits=*/logits.data_ptr(), - /*targets=*/targets.data_ptr(), - /*logit_lengths=*/logit_lengths.data_ptr(), - /*target_lengths=*/target_lengths.data_ptr(), - /*costs=*/costs.data_ptr(), - /*gradients=*/gradients->data_ptr()); - break; + /*logits=*/(c10::Half*)logit_ptr, + /*targets=*/(int*)target_ptr, + 
/*logit_lengths=*/(int*)logit_len_ptr, + /*target_lengths=*/(int*)target_len_ptr, + /*costs=*/(c10::Half*)costs_ptr, + /*gradients=*/(c10::Half*)grads_ptr); } - default: { - break; - } - }; - return std::make_tuple(costs, gradients); + return std::make_tuple(RAIIATH(costs), RAIIATH(gradients)); +} + +void boxed_compute(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + RAIIATH t1(to(stack[0])); + RAIIATH t2(to(stack[1])); + RAIIATH t3(to(stack[2])); + RAIIATH t4(to(stack[3])); + int64_t blank = to(stack[4]); + double clamp = to(stack[5]); + bool fused_log_softmax = to(stack[6]); + auto result = compute( + std::move(t1), std::move(t2), std::move(t3), std::move(t4), + blank, clamp, fused_log_softmax); + stack[0] = from((std::get<0>(result)).release()); + stack[1] = from((std::get<1>(result)).release()); } -TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { - m.impl("rnnt_loss", &compute); +STABLE_TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { + m.impl("torchaudio::rnnt_loss", &boxed_compute); } } // namespace gpu diff --git a/src/torchaudio/functional/functional.py b/src/torchaudio/functional/functional.py index f955fe7840..06d6bbf8c3 100644 --- a/src/torchaudio/functional/functional.py +++ b/src/torchaudio/functional/functional.py @@ -1763,7 +1763,7 @@ def _fix_waveform_shape( class RnntLoss(torch.autograd.Function): @staticmethod def forward(ctx, *args): - output, saved = torch.ops.torchaudio.rnnt_loss_forward(*args) + output, saved = torch.ops.torchaudio.rnnt_loss(*args) ctx.save_for_backward(saved) return output From 180a393e2250d84c032c12685f10339b6a95e480 Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Thu, 17 Jul 2025 20:33:14 +0000 Subject: [PATCH 17/31] Attempt to fix stable ABI calls --- src/libtorchaudio/rnnt/compute.cpp | 6 ++++-- src/libtorchaudio/rnnt/cpu/compute.cpp | 2 +- src/torchaudio/functional/functional.py | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/libtorchaudio/rnnt/compute.cpp 
b/src/libtorchaudio/rnnt/compute.cpp index 5074cd0d32..867542e4e7 100644 --- a/src/libtorchaudio/rnnt/compute.cpp +++ b/src/libtorchaudio/rnnt/compute.cpp @@ -1,12 +1,14 @@ #include +#include + STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) { m.def( - "torchaudio::rnnt_loss(Tensor logits," + "rnnt_loss(Tensor logits," "Tensor targets," "Tensor logit_lengths," "Tensor target_lengths," "int blank," "float clamp," - "bool fused_log_softmax) -> (Tensor, Tensor?)"); + "bool fused_log_softmax) -> (Tensor, Tensor)"); } diff --git a/src/libtorchaudio/rnnt/cpu/compute.cpp b/src/libtorchaudio/rnnt/cpu/compute.cpp index a9864e345c..817f79ef99 100644 --- a/src/libtorchaudio/rnnt/cpu/compute.cpp +++ b/src/libtorchaudio/rnnt/cpu/compute.cpp @@ -206,7 +206,7 @@ void boxed_compute(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) } STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { - m.impl("torchaudio::rnnt_loss", &boxed_compute); + m.impl("rnnt_loss", &boxed_compute); } } // namespace cpu diff --git a/src/torchaudio/functional/functional.py b/src/torchaudio/functional/functional.py index 06d6bbf8c3..b278d96bd4 100644 --- a/src/torchaudio/functional/functional.py +++ b/src/torchaudio/functional/functional.py @@ -1763,7 +1763,7 @@ def _fix_waveform_shape( class RnntLoss(torch.autograd.Function): @staticmethod def forward(ctx, *args): - output, saved = torch.ops.torchaudio.rnnt_loss(*args) + output, saved = torch.ops.torchaudio.rnnt_loss.default(*args) ctx.save_for_backward(saved) return output From 577ff0cc5b999d421a597564f034ee513144e5e0 Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Thu, 31 Jul 2025 04:10:09 +0000 Subject: [PATCH 18/31] Use stable Tensor interface --- src/libtorchaudio/rnnt/cpu/compute.cpp | 32 +++++++++++-------- src/libtorchaudio/rnnt/gpu/compute.cu | 20 ++++++------ .../functional/functional_impl.py | 7 ++++ 3 files changed, 35 insertions(+), 24 deletions(-) diff --git a/src/libtorchaudio/rnnt/cpu/compute.cpp 
b/src/libtorchaudio/rnnt/cpu/compute.cpp index 817f79ef99..db84a0833b 100644 --- a/src/libtorchaudio/rnnt/cpu/compute.cpp +++ b/src/libtorchaudio/rnnt/cpu/compute.cpp @@ -1,20 +1,23 @@ #include #include -#include #include +#include + + +#include namespace torchaudio { namespace rnnt { namespace cpu { -using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle; +using torch::stable::Tensor; // Entry point into RNNT Loss -std::tuple compute( - const RAIIATH logits, - const RAIIATH targets, - const RAIIATH logit_lengths, - const RAIIATH target_lengths, +std::tuple compute( + const Tensor logits, + const Tensor targets, + const Tensor logit_lengths, + const Tensor target_lengths, int64_t blank, double clamp, bool fused_log_softmax = true) { @@ -187,24 +190,25 @@ std::tuple compute( /*gradients=*/(c10::Half*)grads_ptr); } - return std::make_tuple(RAIIATH(costs), RAIIATH(gradients)); + return std::make_tuple(Tensor(costs), Tensor(gradients)); } void boxed_compute(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - RAIIATH t1(to(stack[0])); - RAIIATH t2(to(stack[1])); - RAIIATH t3(to(stack[2])); - RAIIATH t4(to(stack[3])); + Tensor t1(to(stack[0])); + Tensor t2(to(stack[1])); + Tensor t3(to(stack[2])); + Tensor t4(to(stack[3])); int64_t blank = to(stack[4]); double clamp = to(stack[5]); bool fused_log_softmax = to(stack[6]); auto result = compute( std::move(t1), std::move(t2), std::move(t3), std::move(t4), blank, clamp, fused_log_softmax); - stack[0] = from((std::get<0>(result)).release()); - stack[1] = from((std::get<1>(result)).release()); + stack[0] = from(std::get<0>(result)); + stack[1] = from(std::get<1>(result)); } + STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { m.impl("rnnt_loss", &boxed_compute); } diff --git a/src/libtorchaudio/rnnt/gpu/compute.cu b/src/libtorchaudio/rnnt/gpu/compute.cu index 1073b18a81..4e7fb731a2 100644 --- a/src/libtorchaudio/rnnt/gpu/compute.cu +++ b/src/libtorchaudio/rnnt/gpu/compute.cu @@ -8,14 +8,14 @@ namespace torchaudio 
{ namespace rnnt { namespace gpu { -using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle; +using torch::stable::Tensor; // Entry point into RNNT Loss -std::tuple compute( - const RAIIATH logits, - const RAIIATH targets, - const RAIIATH logit_lengths, - const RAIIATH target_lengths, +std::tuple compute( + const Tensor logits, + const Tensor targets, + const Tensor logit_lengths, + const Tensor target_lengths, int64_t blank, double clamp, bool fused_log_softmax = true) { @@ -191,7 +191,7 @@ std::tuple compute( /*gradients=*/(c10::Half*)grads_ptr); } - return std::make_tuple(RAIIATH(costs), RAIIATH(gradients)); + return std::make_tuple(Tensor(costs), Tensor(gradients)); } void boxed_compute(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { @@ -205,12 +205,12 @@ void boxed_compute(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) auto result = compute( std::move(t1), std::move(t2), std::move(t3), std::move(t4), blank, clamp, fused_log_softmax); - stack[0] = from((std::get<0>(result)).release()); - stack[1] = from((std::get<1>(result)).release()); + stack[0] = from(std::get<0>(result)); + stack[1] = from(std::get<1>(result)); } STABLE_TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { - m.impl("torchaudio::rnnt_loss", &boxed_compute); + m.impl("rnnt_loss", &boxed_compute); } } // namespace gpu diff --git a/test/torchaudio_unittest/functional/functional_impl.py b/test/torchaudio_unittest/functional/functional_impl.py index b08b63256c..4844d9cf6a 100644 --- a/test/torchaudio_unittest/functional/functional_impl.py +++ b/test/torchaudio_unittest/functional/functional_impl.py @@ -622,6 +622,13 @@ def test_rnnt_loss_basic_backward(self): loss = F.rnnt_loss(logits, targets, logit_lengths, target_lengths) loss.backward() + def test_mytest(self): + print("Got here") + logits, targets, logit_lengths, target_lengths = rnnt_utils.get_basic_data(self.device) + result = F.rnnt_loss(logits, targets, logit_lengths, target_lengths) + print("DONE") + # 
result.sum().backward() + def test_rnnt_loss_basic_forward_no_grad(self): """In early stage, calls to `rnnt_loss` resulted in segmentation fault when `logits` have `requires_grad = False`. This test makes sure that this no longer From 926ca7d7f39cc5a67449771c9b47edb0c219a39d Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Thu, 31 Jul 2025 04:30:09 +0000 Subject: [PATCH 19/31] Correct use of stable Tensor --- src/libtorchaudio/rnnt/gpu/compute.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/libtorchaudio/rnnt/gpu/compute.cu b/src/libtorchaudio/rnnt/gpu/compute.cu index 4e7fb731a2..cce71a7005 100644 --- a/src/libtorchaudio/rnnt/gpu/compute.cu +++ b/src/libtorchaudio/rnnt/gpu/compute.cu @@ -195,10 +195,10 @@ std::tuple compute( } void boxed_compute(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - RAIIATH t1(to(stack[0])); - RAIIATH t2(to(stack[1])); - RAIIATH t3(to(stack[2])); - RAIIATH t4(to(stack[3])); + Tensor t1(to(stack[0])); + Tensor t2(to(stack[1])); + Tensor t3(to(stack[2])); + Tensor t4(to(stack[3])); int64_t blank = to(stack[4]); double clamp = to(stack[5]); bool fused_log_softmax = to(stack[6]); From 9f75cdfcba33fdaa3465b19ed13fb3494bebffa8 Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Thu, 31 Jul 2025 04:49:46 +0000 Subject: [PATCH 20/31] WIP --- src/libtorchaudio/rnnt/cpu/compute.cpp | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/libtorchaudio/rnnt/cpu/compute.cpp b/src/libtorchaudio/rnnt/cpu/compute.cpp index db84a0833b..071f68b043 100644 --- a/src/libtorchaudio/rnnt/cpu/compute.cpp +++ b/src/libtorchaudio/rnnt/cpu/compute.cpp @@ -102,14 +102,18 @@ std::tuple compute( Options options; options.batchSize_ = (int)logit_lengths_size; - options.nHypos_ = (int)target_lengths_size; - options.nHypos_ /= options.batchSize_; + options.nHypos_ = (int)(target_lengths_size / options.batchSize_); aoti_torch_get_size(logits.get(), 1, &int_tmp); options.maxSrcLen_ = 
(int)int_tmp; aoti_torch_get_size(logits.get(), 2, &int_tmp); options.maxTgtLen_ = (int)int_tmp; aoti_torch_get_size(logits.get(), 3, &int_tmp); options.numTargets_ = (int)int_tmp; + printf("src %d\n", options.maxSrcLen_); + printf("tgt %d\n", options.maxTgtLen_); + printf("nh %d\n", options.nHypos_); + printf("bs %d\n", options.batchSize_); + # TODO: check what these should be options.blank_ = blank; options.clamp_ = clamp; options.fusedLogSmax_ = fused_log_softmax; @@ -130,6 +134,7 @@ std::tuple compute( AtenTensorHandle int_workspace; int64_t sizes[1] = {IntWorkspace::ComputeSizeFromOptions(options)}; + printf("SIZES: %ld\n", sizes[0]); int64_t strides[1] = {1}; aoti_torch_empty_strided(1, sizes, strides, aoti_torch_dtype_int32(), logits_device, logits_device_index, &int_workspace); @@ -144,6 +149,7 @@ std::tuple compute( aoti_torch_get_data_ptr(float_workspace, &float_workspace_ptr); int64_t int_numel; aoti_torch_get_numel(int_workspace, &int_numel); + printf("Numel is %ld\n", int_numel); Workspace workspace( /*options=*/options, From 526e74da8abb1da40d60d0610b116c088a57efde Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Thu, 31 Jul 2025 04:59:36 +0000 Subject: [PATCH 21/31] Remove mytest --- src/libtorchaudio/rnnt/cpu/compute.cpp | 3 ++- test/torchaudio_unittest/functional/functional_impl.py | 7 ------- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/src/libtorchaudio/rnnt/cpu/compute.cpp b/src/libtorchaudio/rnnt/cpu/compute.cpp index 071f68b043..325ab1914d 100644 --- a/src/libtorchaudio/rnnt/cpu/compute.cpp +++ b/src/libtorchaudio/rnnt/cpu/compute.cpp @@ -113,7 +113,8 @@ std::tuple compute( printf("tgt %d\n", options.maxTgtLen_); printf("nh %d\n", options.nHypos_); printf("bs %d\n", options.batchSize_); - # TODO: check what these should be + // should be 2,3,1,1 + // TODO: It is! so why is sizes zero? Unless it IS zero and needed_size is zero too? 
options.blank_ = blank; options.clamp_ = clamp; options.fusedLogSmax_ = fused_log_softmax; diff --git a/test/torchaudio_unittest/functional/functional_impl.py b/test/torchaudio_unittest/functional/functional_impl.py index 4844d9cf6a..b08b63256c 100644 --- a/test/torchaudio_unittest/functional/functional_impl.py +++ b/test/torchaudio_unittest/functional/functional_impl.py @@ -622,13 +622,6 @@ def test_rnnt_loss_basic_backward(self): loss = F.rnnt_loss(logits, targets, logit_lengths, target_lengths) loss.backward() - def test_mytest(self): - print("Got here") - logits, targets, logit_lengths, target_lengths = rnnt_utils.get_basic_data(self.device) - result = F.rnnt_loss(logits, targets, logit_lengths, target_lengths) - print("DONE") - # result.sum().backward() - def test_rnnt_loss_basic_forward_no_grad(self): """In early stage, calls to `rnnt_loss` resulted in segmentation fault when `logits` have `requires_grad = False`. This test makes sure that this no longer From 1d5f9ef87b7c81ccd5d16200ee8a2cf2e26c8705 Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Thu, 31 Jul 2025 12:58:22 +0000 Subject: [PATCH 22/31] Fix float size calculation --- src/libtorchaudio/rnnt/cpu/compute.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/libtorchaudio/rnnt/cpu/compute.cpp b/src/libtorchaudio/rnnt/cpu/compute.cpp index 325ab1914d..63fa333ea2 100644 --- a/src/libtorchaudio/rnnt/cpu/compute.cpp +++ b/src/libtorchaudio/rnnt/cpu/compute.cpp @@ -134,13 +134,13 @@ std::tuple compute( aoti_torch_zero_(gradients); AtenTensorHandle int_workspace; - int64_t sizes[1] = {IntWorkspace::ComputeSizeFromOptions(options)}; - printf("SIZES: %ld\n", sizes[0]); + int64_t int_sizes[1] = {IntWorkspace::ComputeSizeFromOptions(options)}; int64_t strides[1] = {1}; - aoti_torch_empty_strided(1, sizes, strides, aoti_torch_dtype_int32(), logits_device, logits_device_index, &int_workspace); + aoti_torch_empty_strided(1, int_sizes, strides, aoti_torch_dtype_int32(), 
logits_device, logits_device_index, &int_workspace); AtenTensorHandle float_workspace; - aoti_torch_empty_strided(1, sizes, strides, aoti_torch_dtype_float32(), logits_device, logits_device_index, &float_workspace); + int64_t float_sizes[1] = {DtypeWorkspace::ComputeSizeFromOptions(options)}; + aoti_torch_empty_strided(1, float_sizes, strides, aoti_torch_dtype_float32(), logits_device, logits_device_index, &float_workspace); int64_t float_numel; aoti_torch_get_numel(float_workspace, &float_numel); From bad73095cd87bed5ef0027b00af0ba8c34328d5a Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Thu, 31 Jul 2025 15:17:10 +0000 Subject: [PATCH 23/31] Remove debugging printfs --- src/libtorchaudio/rnnt/cpu/compute.cpp | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/libtorchaudio/rnnt/cpu/compute.cpp b/src/libtorchaudio/rnnt/cpu/compute.cpp index 63fa333ea2..4150ea98f1 100644 --- a/src/libtorchaudio/rnnt/cpu/compute.cpp +++ b/src/libtorchaudio/rnnt/cpu/compute.cpp @@ -109,12 +109,6 @@ std::tuple compute( options.maxTgtLen_ = (int)int_tmp; aoti_torch_get_size(logits.get(), 3, &int_tmp); options.numTargets_ = (int)int_tmp; - printf("src %d\n", options.maxSrcLen_); - printf("tgt %d\n", options.maxTgtLen_); - printf("nh %d\n", options.nHypos_); - printf("bs %d\n", options.batchSize_); - // should be 2,3,1,1 - // TODO: It is! so why is sizes zero? Unless it IS zero and needed_size is zero too? 
options.blank_ = blank; options.clamp_ = clamp; options.fusedLogSmax_ = fused_log_softmax; @@ -150,7 +144,6 @@ std::tuple compute( aoti_torch_get_data_ptr(float_workspace, &float_workspace_ptr); int64_t int_numel; aoti_torch_get_numel(int_workspace, &int_numel); - printf("Numel is %ld\n", int_numel); Workspace workspace( /*options=*/options, From 317b96418a87502484a656853a517a2febf07123 Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Thu, 31 Jul 2025 15:50:52 +0000 Subject: [PATCH 24/31] Fix size bug for rnnt gpu --- src/libtorchaudio/rnnt/gpu/compute.cu | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/libtorchaudio/rnnt/gpu/compute.cu b/src/libtorchaudio/rnnt/gpu/compute.cu index cce71a7005..e56026515c 100644 --- a/src/libtorchaudio/rnnt/gpu/compute.cu +++ b/src/libtorchaudio/rnnt/gpu/compute.cu @@ -130,12 +130,13 @@ std::tuple compute( aoti_torch_zero_(gradients); AtenTensorHandle int_workspace; - int64_t sizes[1] = {IntWorkspace::ComputeSizeFromOptions(options)}; + int64_t int_sizes[1] = {IntWorkspace::ComputeSizeFromOptions(options)}; int64_t strides[1] = {1}; - aoti_torch_empty_strided(1, sizes, strides, aoti_torch_dtype_int32(), logits_device, logits_device_index, &int_workspace); + aoti_torch_empty_strided(1, int_sizes, strides, aoti_torch_dtype_int32(), logits_device, logits_device_index, &int_workspace); AtenTensorHandle float_workspace; - aoti_torch_empty_strided(1, sizes, strides, aoti_torch_dtype_float32(), logits_device, logits_device_index, &float_workspace); + int64_t float_sizes[1] = {DtypeWorkspace::ComputeSizeFromOptions(options)}; + aoti_torch_empty_strided(1, float_sizes, strides, aoti_torch_dtype_float32(), logits_device, logits_device_index, &float_workspace); int64_t float_numel; aoti_torch_get_numel(float_workspace, &float_numel); From 8bdac235cfa3da575594deacbbe1171e7d72b908 Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Thu, 31 Jul 2025 15:52:40 +0000 Subject: [PATCH 25/31] Remove alphas and betas for 
rnnt --- src/libtorchaudio/CMakeLists.txt | 6 - src/libtorchaudio/rnnt/compute_alphas.cpp | 11 -- src/libtorchaudio/rnnt/compute_betas.cpp | 11 -- src/libtorchaudio/rnnt/cpu/compute_alphas.cpp | 128 ----------------- src/libtorchaudio/rnnt/cpu/compute_betas.cpp | 132 ----------------- src/libtorchaudio/rnnt/gpu/compute_alphas.cu | 127 ---------------- src/libtorchaudio/rnnt/gpu/compute_betas.cu | 135 ------------------ 7 files changed, 550 deletions(-) delete mode 100644 src/libtorchaudio/rnnt/compute_alphas.cpp delete mode 100644 src/libtorchaudio/rnnt/compute_betas.cpp delete mode 100644 src/libtorchaudio/rnnt/cpu/compute_alphas.cpp delete mode 100644 src/libtorchaudio/rnnt/cpu/compute_betas.cpp delete mode 100644 src/libtorchaudio/rnnt/gpu/compute_alphas.cu delete mode 100644 src/libtorchaudio/rnnt/gpu/compute_betas.cu diff --git a/src/libtorchaudio/CMakeLists.txt b/src/libtorchaudio/CMakeLists.txt index 85bc227cd6..c7813d1222 100644 --- a/src/libtorchaudio/CMakeLists.txt +++ b/src/libtorchaudio/CMakeLists.txt @@ -22,19 +22,13 @@ if(BUILD_RNNT) list( APPEND sources - rnnt/cpu/compute_alphas.cpp - rnnt/cpu/compute_betas.cpp rnnt/cpu/compute.cpp - rnnt/compute_alphas.cpp - rnnt/compute_betas.cpp rnnt/compute.cpp ) if (USE_CUDA) list( APPEND sources - rnnt/gpu/compute_alphas.cu - rnnt/gpu/compute_betas.cu rnnt/gpu/compute.cu ) endif() diff --git a/src/libtorchaudio/rnnt/compute_alphas.cpp b/src/libtorchaudio/rnnt/compute_alphas.cpp deleted file mode 100644 index dd187f9777..0000000000 --- a/src/libtorchaudio/rnnt/compute_alphas.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include - -STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) { - m.def( - "rnnt_loss_alphas(Tensor logits," - "Tensor targets," - "Tensor logit_lengths," - "Tensor target_lengths," - "int blank," - "float clamp) -> Tensor"); -} diff --git a/src/libtorchaudio/rnnt/compute_betas.cpp b/src/libtorchaudio/rnnt/compute_betas.cpp deleted file mode 100644 index b1cd379a66..0000000000 --- 
a/src/libtorchaudio/rnnt/compute_betas.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include - -STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) { - m.def( - "rnnt_loss_betas(Tensor logits," - "Tensor targets," - "Tensor logit_lengths," - "Tensor target_lengths," - "int blank," - "float clamp) -> Tensor"); -} diff --git a/src/libtorchaudio/rnnt/cpu/compute_alphas.cpp b/src/libtorchaudio/rnnt/cpu/compute_alphas.cpp deleted file mode 100644 index 40ed538175..0000000000 --- a/src/libtorchaudio/rnnt/cpu/compute_alphas.cpp +++ /dev/null @@ -1,128 +0,0 @@ -#include -#include -#include -#include - -// TODO: -// Are the StableIValue AtenTensorHandles reference counted at all? -// Why do we call release() on returned arguments? - -namespace torchaudio { -namespace rnnt { -namespace cpu { - -using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle; - -RAIIATH compute_alphas( - const RAIIATH logits, - const RAIIATH targets, - const RAIIATH logit_lengths, - const RAIIATH target_lengths, - int64_t blank, - double clamp) { - Options options; - int64_t tmp; - aoti_torch_get_size(logit_lengths.get(), 0, &tmp); - options.batchSize_ = (int)tmp; - aoti_torch_get_size(target_lengths.get(), 0, &tmp); - options.nHypos_ = (int)tmp; - options.nHypos_ /= options.batchSize_; - aoti_torch_get_size(logits.get(), 1, &tmp); - options.maxSrcLen_ = (int)tmp; - aoti_torch_get_size(logits.get(), 2, &tmp); - options.maxTgtLen_ = (int)tmp; - aoti_torch_get_size(logits.get(), 3, &tmp); - options.numTargets_ = (int)tmp; - options.blank_ = blank; - options.clamp_ = clamp; - - int32_t logits_device_type; - aoti_torch_get_device_type(logits.get(), &logits_device_type); - AOTI_TORCH_CHECK(logits_device_type == aoti_torch_device_type_cpu()); - - options.device_ = CPU; - - int32_t logits_device; - aoti_torch_get_device_type(logits.get(), &logits_device); - int32_t logits_device_index; - aoti_torch_get_device_index(logits.get(), &logits_device_index); - int32_t logits_dtype; - aoti_torch_get_dtype(logits.get(), 
&logits_dtype); - - int64_t param_sizes[3] = {options.batchSize_ * options.nHypos_, options.maxSrcLen_, options.maxTgtLen_}; - int64_t param_strides[3] = {options.maxSrcLen_ * options.maxTgtLen_, options.maxTgtLen_, 1}; - - AtenTensorHandle alphas; - aoti_torch_empty_strided(3, param_sizes, param_strides, logits_dtype, logits_device, logits_device_index, &alphas); - aoti_torch_zero_(alphas); - - AtenTensorHandle int_workspace; - int64_t sizes[1] = {IntWorkspace::ComputeSizeFromOptions(options)}; - int64_t strides[1] = {1}; - aoti_torch_empty_strided(1, sizes, strides, aoti_torch_dtype_int32(), logits_device, logits_device_index, &int_workspace); - - AtenTensorHandle float_workspace; - aoti_torch_empty_strided(1, sizes, strides, aoti_torch_dtype_float32(), logits_device, logits_device_index, &float_workspace); - - int64_t float_numel; - aoti_torch_get_numel(float_workspace, &float_numel); - void *int_workspace_ptr; - aoti_torch_get_data_ptr(int_workspace, &int_workspace_ptr); - void *float_workspace_ptr; - aoti_torch_get_data_ptr(float_workspace, &float_workspace_ptr); - int64_t int_numel; - aoti_torch_get_numel(int_workspace, &int_numel); - - Workspace workspace( - /*options=*/options, - /*dtype_data=*/(float*)float_workspace_ptr, - /*dtype_size=*/float_numel, - /*int_data=*/(int*)int_workspace_ptr, - /*int_size=*/int_numel); - - void *logit_ptr; - aoti_torch_get_data_ptr(logits.get(), &logit_ptr); - - void *target_ptr; - aoti_torch_get_data_ptr(targets.get(), &target_ptr); - - void *logit_len_ptr; - aoti_torch_get_data_ptr(logit_lengths.get(), &logit_len_ptr); - - void *target_len_ptr; - aoti_torch_get_data_ptr(target_lengths.get(), &target_len_ptr); - - void *alpha_ptr; - aoti_torch_get_data_ptr(alphas, &alpha_ptr); - - // Only support float, this is mainly to enable easy - // unit-testing - ComputeAlphas( - /*workspace=*/workspace, - /*logits=*/(float*)logit_ptr, - /*targets=*/(int*)target_ptr, - /*logit_lengths=*/(int*)logit_len_ptr, - 
/*target_lengths=*/(int*)target_len_ptr, - /*alphas=*/(float*)alpha_ptr); - return RAIIATH(alphas); -} - -void boxed_compute_alphas(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - RAIIATH t1(to(stack[0])); - RAIIATH t2(to(stack[1])); - RAIIATH t3(to(stack[2])); - RAIIATH t4(to(stack[3])); - int64_t blank = to(stack[4]); - double clamp = to(stack[5]); - RAIIATH result = compute_alphas(std::move(t1), std::move(t2), std::move(t3), std::move(t4), - blank, clamp); - stack[0] = from(result.release()); -} - -STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { - m.impl("rnnt_loss_alphas", &boxed_compute_alphas); -} - -} // namespace cpu -} // namespace rnnt -} // namespace torchaudio diff --git a/src/libtorchaudio/rnnt/cpu/compute_betas.cpp b/src/libtorchaudio/rnnt/cpu/compute_betas.cpp deleted file mode 100644 index 729e86a722..0000000000 --- a/src/libtorchaudio/rnnt/cpu/compute_betas.cpp +++ /dev/null @@ -1,132 +0,0 @@ -#include -#include -#include -#include -#include - -namespace torchaudio { -namespace rnnt { -namespace cpu { - -using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle; - -RAIIATH compute_betas( - const RAIIATH logits, - const RAIIATH targets, - const RAIIATH logit_lengths, - const RAIIATH target_lengths, - int64_t blank, - double clamp) { - Options options; - int64_t tmp; - aoti_torch_get_size(logit_lengths.get(), 0, &tmp); - options.batchSize_ = (int)tmp; - aoti_torch_get_size(target_lengths.get(), 0, &tmp); - options.nHypos_ = (int)tmp; - options.nHypos_ /= options.batchSize_; - aoti_torch_get_size(logits.get(), 1, &tmp); - options.maxSrcLen_ = (int)tmp; - aoti_torch_get_size(logits.get(), 2, &tmp); - options.maxTgtLen_ = (int)tmp; - aoti_torch_get_size(logits.get(), 3, &tmp); - options.numTargets_ = (int)tmp; - options.blank_ = blank; - options.clamp_ = clamp; - - int32_t logits_device_type; - aoti_torch_get_device_type(logits.get(), &logits_device_type); - AOTI_TORCH_CHECK(logits_device_type == aoti_torch_device_type_cpu()); - - 
options.device_ = CPU; - - int32_t logits_device; - aoti_torch_get_device_type(logits.get(), &logits_device); - int32_t logits_device_index; - aoti_torch_get_device_index(logits.get(), &logits_device_index); - int32_t logits_dtype; - aoti_torch_get_dtype(logits.get(), &logits_dtype); - - int64_t cost_sizes[1] = {options.batchSize_}; - int64_t stride1[1] = {1}; - AtenTensorHandle costs; - aoti_torch_empty_strided(1, cost_sizes, stride1, logits_dtype, logits_device, logits_device_index, &costs); - - int64_t betas_sizes[3] = {options.batchSize_ * options.nHypos_, options.maxSrcLen_, options.maxTgtLen_}; - int64_t betas_strides[3] = {options.maxSrcLen_ * options.maxTgtLen_, options.maxTgtLen_, 1}; - AtenTensorHandle betas; - aoti_torch_empty_strided(3, betas_sizes, betas_strides, logits_dtype, logits_device, logits_device_index, &betas); - - AtenTensorHandle int_workspace; - int64_t w_sizes[1] = {IntWorkspace::ComputeSizeFromOptions(options)}; - aoti_torch_empty_strided(1, w_sizes, stride1, aoti_torch_dtype_int32(), logits_device, logits_device_index, &int_workspace); - - AtenTensorHandle float_workspace; - aoti_torch_empty_strided(1, w_sizes, stride1, aoti_torch_dtype_float32(), logits_device, logits_device_index, &float_workspace); - - int64_t float_numel; - aoti_torch_get_numel(float_workspace, &float_numel); - void *int_workspace_ptr; - aoti_torch_get_data_ptr(int_workspace, &int_workspace_ptr); - void *float_workspace_ptr; - aoti_torch_get_data_ptr(float_workspace, &float_workspace_ptr); - int64_t int_numel; - aoti_torch_get_numel(int_workspace, &int_numel); - - Workspace workspace( - /*options=*/options, - /*dtype_data=*/(float*)float_workspace_ptr, - /*dtype_size=*/float_numel, - /*int_data=*/(int*)int_workspace_ptr, - /*int_size=*/int_numel); - - void *logit_ptr; - aoti_torch_get_data_ptr(logits.get(), &logit_ptr); - - void *target_ptr; - aoti_torch_get_data_ptr(targets.get(), &target_ptr); - - void *logit_len_ptr; - aoti_torch_get_data_ptr(logit_lengths.get(), 
&logit_len_ptr); - - void *target_len_ptr; - aoti_torch_get_data_ptr(target_lengths.get(), &target_len_ptr); - - void *beta_ptr; - aoti_torch_get_data_ptr(betas, &beta_ptr); - - void *cost_ptr; - aoti_torch_get_data_ptr(costs, &cost_ptr); - - // Only support float, this is mainly to enable easy - // unit-testing - ComputeBetas( - /*workspace=*/workspace, - /*logits=*/(float*)logit_ptr, - /*targets=*/(int*)target_ptr, - /*logit_lengths=*/(int*)logit_len_ptr, - /*target_lengths=*/(int*)target_len_ptr, - /*costs=*/(float*)cost_ptr, - /*betas=*/(float*)beta_ptr); - return RAIIATH(betas); -} - - -void boxed_compute_betas(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - RAIIATH t1(to(stack[0])); - RAIIATH t2(to(stack[1])); - RAIIATH t3(to(stack[2])); - RAIIATH t4(to(stack[3])); - int64_t blank = to(stack[4]); - double clamp = to(stack[5]); - RAIIATH result = compute_betas(std::move(t1), std::move(t2), std::move(t3), std::move(t4), - blank, clamp); - stack[0] = from(result.release()); -} - -STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { - m.impl("rnnt_loss_betas", &boxed_compute_betas); -} - -} // namespace cpu -} // namespace rnnt -} // namespace torchaudio diff --git a/src/libtorchaudio/rnnt/gpu/compute_alphas.cu b/src/libtorchaudio/rnnt/gpu/compute_alphas.cu deleted file mode 100644 index 90e421ab4a..0000000000 --- a/src/libtorchaudio/rnnt/gpu/compute_alphas.cu +++ /dev/null @@ -1,127 +0,0 @@ -#include -#include -#include -#include -#include - -namespace torchaudio { -namespace rnnt { -namespace gpu { - -using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle; - -RAIIATH compute_alphas( - const RAIIATH logits, - const RAIIATH targets, - const RAIIATH logit_lengths, - const RAIIATH target_lengths, - int64_t blank, - double clamp) { - Options options; - int64_t tmp; - aoti_torch_get_size(logit_lengths.get(), 0, &tmp); - options.batchSize_ = (int)tmp; - aoti_torch_get_size(target_lengths.get(), 0, &tmp); - options.nHypos_ = (int)tmp; - options.nHypos_ 
/= options.batchSize_; - aoti_torch_get_size(logits.get(), 1, &tmp); - options.maxSrcLen_ = (int)tmp; - aoti_torch_get_size(logits.get(), 2, &tmp); - options.maxTgtLen_ = (int)tmp; - aoti_torch_get_size(logits.get(), 3, &tmp); - options.numTargets_ = (int)tmp; - options.blank_ = blank; - options.clamp_ = clamp; - - int32_t logits_device_type; - aoti_torch_get_device_type(logits.get(), &logits_device_type); - AOTI_TORCH_CHECK(logits_device_type == aoti_torch_device_type_cuda()); - - int32_t logits_device; - aoti_torch_get_device_type(logits.get(), &logits_device); - int32_t logits_device_index; - aoti_torch_get_device_index(logits.get(), &logits_device_index); - int32_t logits_dtype; - aoti_torch_get_dtype(logits.get(), &logits_dtype); - - aoti_torch_get_current_cuda_stream(logits_device_index, (void**)&options.stream_); - cudaSetDevice(logits_device); - options.device_ = GPU; - - int64_t param_sizes[3] = {options.batchSize_ * options.nHypos_, options.maxSrcLen_, options.maxTgtLen_}; - int64_t param_strides[3] = {options.maxSrcLen_ * options.maxTgtLen_, options.maxTgtLen_, 1}; - - AtenTensorHandle alphas; - aoti_torch_empty_strided(3, param_sizes, param_strides, logits_dtype, logits_device, logits_device_index, &alphas); - aoti_torch_zero_(alphas); - - AtenTensorHandle int_workspace; - int64_t sizes[1] = {IntWorkspace::ComputeSizeFromOptions(options)}; - int64_t strides[1] = {1}; - aoti_torch_empty_strided(1, sizes, strides, aoti_torch_dtype_int32(), logits_device, logits_device_index, &int_workspace); - - AtenTensorHandle float_workspace; - aoti_torch_empty_strided(1, sizes, strides, aoti_torch_dtype_float32(), logits_device, logits_device_index, &float_workspace); - - int64_t float_numel; - aoti_torch_get_numel(float_workspace, &float_numel); - void *int_workspace_ptr; - aoti_torch_get_data_ptr(int_workspace, &int_workspace_ptr); - void *float_workspace_ptr; - aoti_torch_get_data_ptr(float_workspace, &float_workspace_ptr); - int64_t int_numel; - 
aoti_torch_get_numel(int_workspace, &int_numel); - - Workspace workspace( - /*options=*/options, - /*dtype_data=*/(float*)float_workspace_ptr, - /*dtype_size=*/float_numel, - /*int_data=*/(int*)int_workspace_ptr, - /*int_size=*/int_numel); - - void *logit_ptr; - aoti_torch_get_data_ptr(logits.get(), &logit_ptr); - - void *target_ptr; - aoti_torch_get_data_ptr(targets.get(), &target_ptr); - - void *logit_len_ptr; - aoti_torch_get_data_ptr(logit_lengths.get(), &logit_len_ptr); - - void *target_len_ptr; - aoti_torch_get_data_ptr(target_lengths.get(), &target_len_ptr); - - void *alpha_ptr; - aoti_torch_get_data_ptr(alphas, &alpha_ptr); - - // Only support float, this is mainly to enable easy - // unit-testing - ComputeAlphas( - /*workspace=*/workspace, - /*logits=*/(float*)logit_ptr, - /*targets=*/(int*)target_ptr, - /*logit_lengths=*/(int*)logit_len_ptr, - /*target_lengths=*/(int*)target_len_ptr, - /*alphas=*/(float*)alpha_ptr); - return RAIIATH(alphas); -} - -void boxed_compute_alphas(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - RAIIATH t1(to(stack[0])); - RAIIATH t2(to(stack[1])); - RAIIATH t3(to(stack[2])); - RAIIATH t4(to(stack[3])); - int64_t blank = to(stack[4]); - double clamp = to(stack[5]); - RAIIATH result = compute_alphas(std::move(t1), std::move(t2), std::move(t3), std::move(t4), - blank, clamp); - stack[0] = from(result.release()); -} - -STABLE_TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { - m.impl("rnnt_loss_alphas", &boxed_compute_alphas); -} - -} // namespace gpu -} // namespace rnnt -} // namespace torchaudio diff --git a/src/libtorchaudio/rnnt/gpu/compute_betas.cu b/src/libtorchaudio/rnnt/gpu/compute_betas.cu deleted file mode 100644 index 7bed017b14..0000000000 --- a/src/libtorchaudio/rnnt/gpu/compute_betas.cu +++ /dev/null @@ -1,135 +0,0 @@ -#include -#include -#include -#include -#include - -namespace torchaudio { -namespace rnnt { -namespace gpu { - -using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle; - - -RAIIATH 
compute_betas( - const RAIIATH logits, - const RAIIATH targets, - const RAIIATH logit_lengths, - const RAIIATH target_lengths, - int64_t blank, - double clamp) { - Options options; - int64_t tmp; - aoti_torch_get_size(logit_lengths.get(), 0, &tmp); - options.batchSize_ = (int)tmp; - aoti_torch_get_size(target_lengths.get(), 0, &tmp); - options.nHypos_ = (int)tmp; - options.nHypos_ /= options.batchSize_; - aoti_torch_get_size(logits.get(), 1, &tmp); - options.maxSrcLen_ = (int)tmp; - aoti_torch_get_size(logits.get(), 2, &tmp); - options.maxTgtLen_ = (int)tmp; - aoti_torch_get_size(logits.get(), 3, &tmp); - options.numTargets_ = (int)tmp; - options.blank_ = blank; - options.clamp_ = clamp; - - int32_t logits_device_type; - aoti_torch_get_device_type(logits.get(), &logits_device_type); - AOTI_TORCH_CHECK(logits_device_type == aoti_torch_device_type_cuda()); - - - int32_t logits_device; - aoti_torch_get_device_type(logits.get(), &logits_device); - int32_t logits_device_index; - aoti_torch_get_device_index(logits.get(), &logits_device_index); - int32_t logits_dtype; - aoti_torch_get_dtype(logits.get(), &logits_dtype); - - aoti_torch_get_current_cuda_stream(logits_device_index, (void**)&options.stream_); - cudaSetDevice(logits_device); - options.device_ = GPU; - - int64_t cost_sizes[1] = {options.batchSize_}; - int64_t stride1[1] = {1}; - AtenTensorHandle costs; - aoti_torch_empty_strided(1, cost_sizes, stride1, logits_dtype, logits_device, logits_device_index, &costs); - - int64_t betas_sizes[3] = {options.batchSize_ * options.nHypos_, options.maxSrcLen_, options.maxTgtLen_}; - int64_t betas_strides[3] = {options.maxSrcLen_ * options.maxTgtLen_, options.maxTgtLen_, 1}; - AtenTensorHandle betas; - aoti_torch_empty_strided(3, betas_sizes, betas_strides, logits_dtype, logits_device, logits_device_index, &betas); - - AtenTensorHandle int_workspace; - int64_t w_sizes[1] = {IntWorkspace::ComputeSizeFromOptions(options)}; - aoti_torch_empty_strided(1, w_sizes, stride1, 
aoti_torch_dtype_int32(), logits_device, logits_device_index, &int_workspace); - - AtenTensorHandle float_workspace; - aoti_torch_empty_strided(1, w_sizes, stride1, aoti_torch_dtype_float32(), logits_device, logits_device_index, &float_workspace); - - int64_t float_numel; - aoti_torch_get_numel(float_workspace, &float_numel); - void *int_workspace_ptr; - aoti_torch_get_data_ptr(int_workspace, &int_workspace_ptr); - void *float_workspace_ptr; - aoti_torch_get_data_ptr(float_workspace, &float_workspace_ptr); - int64_t int_numel; - aoti_torch_get_numel(int_workspace, &int_numel); - - Workspace workspace( - /*options=*/options, - /*dtype_data=*/(float*)float_workspace_ptr, - /*dtype_size=*/float_numel, - /*int_data=*/(int*)int_workspace_ptr, - /*int_size=*/int_numel); - - void *logit_ptr; - aoti_torch_get_data_ptr(logits.get(), &logit_ptr); - - void *target_ptr; - aoti_torch_get_data_ptr(targets.get(), &target_ptr); - - void *logit_len_ptr; - aoti_torch_get_data_ptr(logit_lengths.get(), &logit_len_ptr); - - void *target_len_ptr; - aoti_torch_get_data_ptr(target_lengths.get(), &target_len_ptr); - - void *beta_ptr; - aoti_torch_get_data_ptr(betas, &beta_ptr); - - void *cost_ptr; - aoti_torch_get_data_ptr(costs, &cost_ptr); - - // Only support float, this is mainly to enable easy - // unit-testing - ComputeBetas( - /*workspace=*/workspace, - /*logits=*/(float*)logit_ptr, - /*targets=*/(int*)target_ptr, - /*logit_lengths=*/(int*)logit_len_ptr, - /*target_lengths=*/(int*)target_len_ptr, - /*costs=*/(float*)cost_ptr, - /*betas=*/(float*)beta_ptr); - return RAIIATH(betas); -} - -void boxed_compute_betas(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - RAIIATH t1(to(stack[0])); - RAIIATH t2(to(stack[1])); - RAIIATH t3(to(stack[2])); - RAIIATH t4(to(stack[3])); - int64_t blank = to(stack[4]); - double clamp = to(stack[5]); - RAIIATH result = compute_betas(std::move(t1), std::move(t2), std::move(t3), std::move(t4), - blank, clamp); - stack[0] = 
from(result.release()); -} - -STABLE_TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { - m.impl("rnnt_loss_betas", &boxed_compute_betas); -} - -} // namespace gpu -} // namespace rnnt -} // namespace torchaudio From 8964b268fb1874f2649cfca1b238c4b22234efd3 Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Tue, 5 Aug 2025 15:14:56 +0000 Subject: [PATCH 26/31] WIP --- src/libtorchaudio/rnnt/gpu/compute.cu | 39 ++++++++++++++------------- src/libtorchaudio/rnnt/workspace.h | 8 +++--- 2 files changed, 24 insertions(+), 23 deletions(-) diff --git a/src/libtorchaudio/rnnt/gpu/compute.cu b/src/libtorchaudio/rnnt/gpu/compute.cu index e56026515c..f20e52f0dc 100644 --- a/src/libtorchaudio/rnnt/gpu/compute.cu +++ b/src/libtorchaudio/rnnt/gpu/compute.cu @@ -124,6 +124,7 @@ std::tuple compute( int64_t stride1[1] = {1}; AtenTensorHandle costs; aoti_torch_empty_strided(1, cost_sizes, stride1, logits_dtype, logits_device, logits_device_index, &costs); + aoti_torch_zero_(costs); AtenTensorHandle gradients; aoti_torch_clone(logits.get(), &gradients); @@ -172,25 +173,25 @@ std::tuple compute( void *grads_ptr; aoti_torch_get_data_ptr(gradients, &grads_ptr); - if (logits_dtype == aoti_torch_dtype_float32()) { - Compute( - /*workspace=*/workspace, - /*logits=*/(float*)logit_ptr, - /*targets=*/(int*)target_ptr, - /*logit_lengths=*/(int*)logit_len_ptr, - /*target_lengths=*/(int*)target_len_ptr, - /*costs=*/(float*)costs_ptr, - /*gradients=*/(float*)grads_ptr); - } else { - Compute( - /*workspace=*/workspace, - /*logits=*/(c10::Half*)logit_ptr, - /*targets=*/(int*)target_ptr, - /*logit_lengths=*/(int*)logit_len_ptr, - /*target_lengths=*/(int*)target_len_ptr, - /*costs=*/(c10::Half*)costs_ptr, - /*gradients=*/(c10::Half*)grads_ptr); - } + // if (logits_dtype == aoti_torch_dtype_float32()) { + // Compute( + // /*workspace=*/workspace, + // /*logits=*/(float*)logit_ptr, + // /*targets=*/(int*)target_ptr, + // /*logit_lengths=*/(int*)logit_len_ptr, + // /*target_lengths=*/(int*)target_len_ptr, + // 
/*costs=*/(float*)costs_ptr, + // /*gradients=*/(float*)grads_ptr); + // } else { + // Compute( + // /*workspace=*/workspace, + // /*logits=*/(c10::Half*)logit_ptr, + // /*targets=*/(int*)target_ptr, + // /*logit_lengths=*/(int*)logit_len_ptr, + // /*target_lengths=*/(int*)target_len_ptr, + // /*costs=*/(c10::Half*)costs_ptr, + // /*gradients=*/(c10::Half*)grads_ptr); + // } return std::make_tuple(Tensor(costs), Tensor(gradients)); } diff --git a/src/libtorchaudio/rnnt/workspace.h b/src/libtorchaudio/rnnt/workspace.h index b4bbb30a43..4a389c4fa5 100644 --- a/src/libtorchaudio/rnnt/workspace.h +++ b/src/libtorchaudio/rnnt/workspace.h @@ -123,10 +123,10 @@ class IntWorkspace { inline void ResetAlphaBetaCounters() { #ifdef USE_CUDA if (data_ != nullptr && options_.device_ == GPU) { - cudaMemset( - GetPointerToAlphaCounters(), - 0, - ComputeSizeForAlphaCounters(options_) * sizeof(int)); + // cudaMemset( + // GetPointerToAlphaCounters(), + // 0, + // ComputeSizeForAlphaCounters(options_) * sizeof(int)); cudaMemset( GetPointerToBetaCounters(), 0, From 1c02964de3d706c25597c38a741706267a1c7cea Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Tue, 12 Aug 2025 19:11:25 +0000 Subject: [PATCH 27/31] WIP --- src/libtorchaudio/rnnt/gpu/compute.cu | 26 ++++++++++++++------------ src/libtorchaudio/rnnt/workspace.h | 8 ++++---- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/src/libtorchaudio/rnnt/gpu/compute.cu b/src/libtorchaudio/rnnt/gpu/compute.cu index f20e52f0dc..ddbb357fb0 100644 --- a/src/libtorchaudio/rnnt/gpu/compute.cu +++ b/src/libtorchaudio/rnnt/gpu/compute.cu @@ -125,6 +125,7 @@ std::tuple compute( AtenTensorHandle costs; aoti_torch_empty_strided(1, cost_sizes, stride1, logits_dtype, logits_device, logits_device_index, &costs); aoti_torch_zero_(costs); + c10::cuda::device_synchronize(); AtenTensorHandle gradients; aoti_torch_clone(logits.get(), &gradients); @@ -148,6 +149,7 @@ std::tuple compute( int64_t int_numel; 
aoti_torch_get_numel(int_workspace, &int_numel); + c10::cuda::device_synchronize(); Workspace workspace( /*options=*/options, /*dtype_data=*/(float*)float_workspace_ptr, @@ -155,23 +157,23 @@ std::tuple compute( /*int_data=*/(int*)int_workspace_ptr, /*int_size=*/int_numel); - void *logit_ptr; - aoti_torch_get_data_ptr(logits.get(), &logit_ptr); + // void *logit_ptr; + // aoti_torch_get_data_ptr(logits.get(), &logit_ptr); - void *target_ptr; - aoti_torch_get_data_ptr(targets.get(), &target_ptr); + // void *target_ptr; + // aoti_torch_get_data_ptr(targets.get(), &target_ptr); - void *logit_len_ptr; - aoti_torch_get_data_ptr(logit_lengths.get(), &logit_len_ptr); + // void *logit_len_ptr; + // aoti_torch_get_data_ptr(logit_lengths.get(), &logit_len_ptr); - void *target_len_ptr; - aoti_torch_get_data_ptr(target_lengths.get(), &target_len_ptr); + // void *target_len_ptr; + // aoti_torch_get_data_ptr(target_lengths.get(), &target_len_ptr); - void *costs_ptr; - aoti_torch_get_data_ptr(costs, &costs_ptr); + // void *costs_ptr; + // aoti_torch_get_data_ptr(costs, &costs_ptr); - void *grads_ptr; - aoti_torch_get_data_ptr(gradients, &grads_ptr); + // void *grads_ptr; + // aoti_torch_get_data_ptr(gradients, &grads_ptr); // if (logits_dtype == aoti_torch_dtype_float32()) { // Compute( diff --git a/src/libtorchaudio/rnnt/workspace.h b/src/libtorchaudio/rnnt/workspace.h index 4a389c4fa5..6d6f126719 100644 --- a/src/libtorchaudio/rnnt/workspace.h +++ b/src/libtorchaudio/rnnt/workspace.h @@ -127,10 +127,10 @@ class IntWorkspace { // GetPointerToAlphaCounters(), // 0, // ComputeSizeForAlphaCounters(options_) * sizeof(int)); - cudaMemset( - GetPointerToBetaCounters(), - 0, - ComputeSizeForBetaCounters(options_) * sizeof(int)); + // cudaMemset( + // GetPointerToBetaCounters(), + // 0, + // ComputeSizeForBetaCounters(options_) * sizeof(int)); } #endif // USE_CUDA } From 9c82f2ca4353b26d3f94ef5a81d62b18a45038b0 Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Tue, 12 Aug 2025 21:27:35 
+0000 Subject: [PATCH 28/31] WIP --- src/libtorchaudio/rnnt/gpu/compute.cu | 85 +++++++++++++++++---------- src/libtorchaudio/rnnt/workspace.h | 10 ++-- 2 files changed, 59 insertions(+), 36 deletions(-) diff --git a/src/libtorchaudio/rnnt/gpu/compute.cu b/src/libtorchaudio/rnnt/gpu/compute.cu index ddbb357fb0..17c2db3f9a 100644 --- a/src/libtorchaudio/rnnt/gpu/compute.cu +++ b/src/libtorchaudio/rnnt/gpu/compute.cu @@ -3,6 +3,8 @@ #include #include #include +#include +#include namespace torchaudio { namespace rnnt { @@ -21,46 +23,46 @@ std::tuple compute( bool fused_log_softmax = true) { int32_t logits_device; - aoti_torch_get_device_type(logits.get(), &logits_device); + TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(logits.get(), &logits_device)); int32_t targets_device; - aoti_torch_get_device_type(targets.get(), &targets_device); + TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(targets.get(), &targets_device)); int32_t logit_lengths_device; - aoti_torch_get_device_type(logit_lengths.get(), &logit_lengths_device); + TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(logit_lengths.get(), &logit_lengths_device)); int32_t target_lengths_device; - aoti_torch_get_device_type(target_lengths.get(), &target_lengths_device); + TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(target_lengths.get(), &target_lengths_device)); AOTI_TORCH_CHECK(logits_device == targets_device); AOTI_TORCH_CHECK(logits_device == logit_lengths_device); AOTI_TORCH_CHECK(logits_device == target_lengths_device); int32_t logits_dtype; - aoti_torch_get_dtype(logits.get(), &logits_dtype); + TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(logits.get(), &logits_dtype)); AOTI_TORCH_CHECK(logits_dtype == aoti_torch_dtype_float32() || logits_dtype == aoti_torch_dtype_float16()); int32_t targets_dtype; - aoti_torch_get_dtype(targets.get(), &targets_dtype); + TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(targets.get(), &targets_dtype)); AOTI_TORCH_CHECK(targets_dtype == aoti_torch_dtype_int32() || 
logits_dtype == aoti_torch_dtype_float16()); int32_t logit_lengths_dtype; - aoti_torch_get_dtype(logit_lengths.get(), &logit_lengths_dtype); + TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(logit_lengths.get(), &logit_lengths_dtype)); AOTI_TORCH_CHECK(logit_lengths_dtype == aoti_torch_dtype_int32() || logit_lengths_dtype == aoti_torch_dtype_float16()); int32_t target_lengths_dtype; - aoti_torch_get_dtype(target_lengths.get(), &target_lengths_dtype); + TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(target_lengths.get(), &target_lengths_dtype)); AOTI_TORCH_CHECK(target_lengths_dtype == aoti_torch_dtype_int32() || target_lengths_dtype == aoti_torch_dtype_float16()); bool bool_tmp; - aoti_torch_is_contiguous(logits.get(), &bool_tmp); + TORCH_ERROR_CODE_CHECK(aoti_torch_is_contiguous(logits.get(), &bool_tmp)); AOTI_TORCH_CHECK(bool_tmp); - aoti_torch_is_contiguous(targets.get(), &bool_tmp); + TORCH_ERROR_CODE_CHECK(aoti_torch_is_contiguous(targets.get(), &bool_tmp)); AOTI_TORCH_CHECK(bool_tmp); - aoti_torch_is_contiguous(logit_lengths.get(), &bool_tmp); + TORCH_ERROR_CODE_CHECK(aoti_torch_is_contiguous(logit_lengths.get(), &bool_tmp)); AOTI_TORCH_CHECK(bool_tmp); - aoti_torch_is_contiguous(target_lengths.get(), &bool_tmp); + TORCH_ERROR_CODE_CHECK(aoti_torch_is_contiguous(target_lengths.get(), &bool_tmp)); int64_t int_tmp; aoti_torch_get_dim(logits.get(), &int_tmp); @@ -73,15 +75,15 @@ std::tuple compute( AOTI_TORCH_CHECK(int_tmp == 1); int64_t logit_lengths_size; - aoti_torch_get_size(logit_lengths.get(), 0, &logit_lengths_size); + TORCH_ERROR_CODE_CHECK(aoti_torch_get_size(logit_lengths.get(), 0, &logit_lengths_size)); int64_t logits_size; - aoti_torch_get_size(logits.get(), 0, &logits_size); + TORCH_ERROR_CODE_CHECK(aoti_torch_get_size(logits.get(), 0, &logits_size)); AOTI_TORCH_CHECK(logit_lengths_size == logits_size); int64_t target_lengths_size; - aoti_torch_get_size(target_lengths.get(), 0, &target_lengths_size); + 
TORCH_ERROR_CODE_CHECK(aoti_torch_get_size(target_lengths.get(), 0, &target_lengths_size)); AOTI_TORCH_CHECK(target_lengths_size == logits_size); int64_t targets_size; - aoti_torch_get_size(targets.get(), 0, &targets_size); + TORCH_ERROR_CODE_CHECK(aoti_torch_get_size(targets.get(), 0, &targets_size)); AOTI_TORCH_CHECK(targets_size == logits_size); // TORCH_CHECK( @@ -116,47 +118,66 @@ std::tuple compute( aoti_torch_get_device_index(logits.get(), &logits_device_index); TORCH_CHECK_EQ(logits_device, aoti_torch_device_type_cuda()); - aoti_torch_get_current_cuda_stream(logits_device_index, (void**)&options.stream_); + + options.stream_ = at::cuda::getCurrentCUDAStream(); + // aoti_torch_get_current_cuda_stream(logits_device_index, (void**)&options.stream_); cudaSetDevice(logits_device); options.device_ = GPU; int64_t cost_sizes[1] = {options.batchSize_ * options.nHypos_}; int64_t stride1[1] = {1}; AtenTensorHandle costs; - aoti_torch_empty_strided(1, cost_sizes, stride1, logits_dtype, logits_device, logits_device_index, &costs); + TORCH_ERROR_CODE_CHECK( + aoti_torch_empty_strided(1, cost_sizes, stride1, logits_dtype, logits_device, logits_device_index, &costs)); aoti_torch_zero_(costs); - c10::cuda::device_synchronize(); + + + auto stream = at::cuda::getCurrentCUDAStream(); + at::cuda::stream_synchronize(stream); AtenTensorHandle gradients; aoti_torch_clone(logits.get(), &gradients); aoti_torch_zero_(gradients); - AtenTensorHandle int_workspace; - int64_t int_sizes[1] = {IntWorkspace::ComputeSizeFromOptions(options)}; + // AtenTensorHandle int_workspace; + // int64_t int_sizes[1] = {IntWorkspace::ComputeSizeFromOptions(options)}; + // printf("INT SIZE IS %ld\n", int_sizes[0]); int64_t strides[1] = {1}; - aoti_torch_empty_strided(1, int_sizes, strides, aoti_torch_dtype_int32(), logits_device, logits_device_index, &int_workspace); + // TORCH_ERROR_CODE_CHECK( + // aoti_torch_empty_strided(1, int_sizes, strides, aoti_torch_dtype_int32(), logits_device, 
logits_device_index, &int_workspace)); + + torch::Tensor int_workspace = torch::empty( + IntWorkspace::ComputeSizeFromOptions(options), + torch::TensorOptions() + .device(torch::aot_inductor::tensor_handle_to_tensor_pointer(logits.get())->device()) + .dtype(torch::ScalarType::Int)); AtenTensorHandle float_workspace; int64_t float_sizes[1] = {DtypeWorkspace::ComputeSizeFromOptions(options)}; - aoti_torch_empty_strided(1, float_sizes, strides, aoti_torch_dtype_float32(), logits_device, logits_device_index, &float_workspace); + TORCH_ERROR_CODE_CHECK( + aoti_torch_empty_strided(1, float_sizes, strides, aoti_torch_dtype_float32(), logits_device, logits_device_index, &float_workspace)); int64_t float_numel; aoti_torch_get_numel(float_workspace, &float_numel); - void *int_workspace_ptr; - aoti_torch_get_data_ptr(int_workspace, &int_workspace_ptr); + // void *int_workspace_ptr; + // TORCH_ERROR_CODE_CHECK( + // aoti_torch_get_data_ptr(int_workspace, &int_workspace_ptr)); void *float_workspace_ptr; - aoti_torch_get_data_ptr(float_workspace, &float_workspace_ptr); - int64_t int_numel; - aoti_torch_get_numel(int_workspace, &int_numel); + TORCH_ERROR_CODE_CHECK( + aoti_torch_get_data_ptr(float_workspace, &float_workspace_ptr)); + // int64_t int_numel; + // TORCH_ERROR_CODE_CHECK( + // aoti_torch_get_numel(int_workspace, &int_numel)); - c10::cuda::device_synchronize(); + at::cuda::stream_synchronize(stream); Workspace workspace( /*options=*/options, /*dtype_data=*/(float*)float_workspace_ptr, /*dtype_size=*/float_numel, - /*int_data=*/(int*)int_workspace_ptr, - /*int_size=*/int_numel); - + /*int_data=*/int_workspace.data_ptr(), // (int*)int_workspace_ptr, + /*int_size=*/int_workspace.numel() // int_numel); + ); + at::cuda::stream_synchronize(stream); // void *logit_ptr; // aoti_torch_get_data_ptr(logits.get(), &logit_ptr); diff --git a/src/libtorchaudio/rnnt/workspace.h b/src/libtorchaudio/rnnt/workspace.h index 6d6f126719..cbcdcfabf3 100644 --- 
a/src/libtorchaudio/rnnt/workspace.h +++ b/src/libtorchaudio/rnnt/workspace.h @@ -123,10 +123,12 @@ class IntWorkspace { inline void ResetAlphaBetaCounters() { #ifdef USE_CUDA if (data_ != nullptr && options_.device_ == GPU) { - // cudaMemset( - // GetPointerToAlphaCounters(), - // 0, - // ComputeSizeForAlphaCounters(options_) * sizeof(int)); + printf("ALPHA COUNTER SIZE IS %ld", ComputeSizeForAlphaCounters(options_)); + fflush(stdout); + cudaMemset( + GetPointerToAlphaCounters(), + 0, + ComputeSizeForAlphaCounters(options_) * sizeof(int32_t)); // cudaMemset( // GetPointerToBetaCounters(), // 0, From 7cfb701e3c79b7cec63064eedf1c0ecae25fcaf8 Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Thu, 14 Aug 2025 19:12:01 +0000 Subject: [PATCH 29/31] WIP --- src/libtorchaudio/rnnt/gpu/compute.cu | 38 +++++++++++++-------------- src/libtorchaudio/rnnt/workspace.h | 10 +++++++ 2 files changed, 28 insertions(+), 20 deletions(-) diff --git a/src/libtorchaudio/rnnt/gpu/compute.cu b/src/libtorchaudio/rnnt/gpu/compute.cu index 17c2db3f9a..71632e2f19 100644 --- a/src/libtorchaudio/rnnt/gpu/compute.cu +++ b/src/libtorchaudio/rnnt/gpu/compute.cu @@ -121,7 +121,7 @@ std::tuple compute( options.stream_ = at::cuda::getCurrentCUDAStream(); // aoti_torch_get_current_cuda_stream(logits_device_index, (void**)&options.stream_); - cudaSetDevice(logits_device); + TORCH_CHECK_EQ(cudaSetDevice(logits_device_index), cudaSuccess); options.device_ = GPU; int64_t cost_sizes[1] = {options.batchSize_ * options.nHypos_}; @@ -132,25 +132,23 @@ std::tuple compute( aoti_torch_zero_(costs); - auto stream = at::cuda::getCurrentCUDAStream(); - at::cuda::stream_synchronize(stream); + at::cuda::stream_synchronize(options.stream_); AtenTensorHandle gradients; aoti_torch_clone(logits.get(), &gradients); aoti_torch_zero_(gradients); - // AtenTensorHandle int_workspace; - // int64_t int_sizes[1] = {IntWorkspace::ComputeSizeFromOptions(options)}; - // printf("INT SIZE IS %ld\n", int_sizes[0]); + 
AtenTensorHandle int_workspace; + int64_t int_sizes[1] = {IntWorkspace::ComputeSizeFromOptions(options)}; int64_t strides[1] = {1}; - // TORCH_ERROR_CODE_CHECK( - // aoti_torch_empty_strided(1, int_sizes, strides, aoti_torch_dtype_int32(), logits_device, logits_device_index, &int_workspace)); + TORCH_ERROR_CODE_CHECK( + aoti_torch_empty_strided(1, int_sizes, strides, aoti_torch_dtype_int32(), logits_device, logits_device_index, &int_workspace)); - torch::Tensor int_workspace = torch::empty( - IntWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(torch::aot_inductor::tensor_handle_to_tensor_pointer(logits.get())->device()) - .dtype(torch::ScalarType::Int)); + // torch::Tensor int_workspace = torch::empty( + // IntWorkspace::ComputeSizeFromOptions(options), + // torch::TensorOptions() + // .device(torch::aot_inductor::tensor_handle_to_tensor_pointer(logits.get())->device()) + // .dtype(torch::ScalarType::Int)); AtenTensorHandle float_workspace; int64_t float_sizes[1] = {DtypeWorkspace::ComputeSizeFromOptions(options)}; @@ -165,19 +163,19 @@ std::tuple compute( void *float_workspace_ptr; TORCH_ERROR_CODE_CHECK( aoti_torch_get_data_ptr(float_workspace, &float_workspace_ptr)); - // int64_t int_numel; - // TORCH_ERROR_CODE_CHECK( - // aoti_torch_get_numel(int_workspace, &int_numel)); + int64_t int_numel; + TORCH_ERROR_CODE_CHECK( + aoti_torch_get_numel(int_workspace, &int_numel)); - at::cuda::stream_synchronize(stream); + at::cuda::stream_synchronize(options.stream_); Workspace workspace( /*options=*/options, /*dtype_data=*/(float*)float_workspace_ptr, /*dtype_size=*/float_numel, - /*int_data=*/int_workspace.data_ptr(), // (int*)int_workspace_ptr, - /*int_size=*/int_workspace.numel() // int_numel); + /*int_data=*/(int*)int_workspace_ptr, + /*int_size=*/int_numel ); - at::cuda::stream_synchronize(stream); + at::cuda::stream_synchronize(options.stream_); // void *logit_ptr; // aoti_torch_get_data_ptr(logits.get(), &logit_ptr); diff --git 
a/src/libtorchaudio/rnnt/workspace.h b/src/libtorchaudio/rnnt/workspace.h index cbcdcfabf3..bed10e284a 100644 --- a/src/libtorchaudio/rnnt/workspace.h +++ b/src/libtorchaudio/rnnt/workspace.h @@ -6,6 +6,9 @@ #include #include +#ifdef USE_CUDA +#include +#endif namespace torchaudio { namespace rnnt { @@ -124,6 +127,13 @@ class IntWorkspace { #ifdef USE_CUDA if (data_ != nullptr && options_.device_ == GPU) { printf("ALPHA COUNTER SIZE IS %ld", ComputeSizeForAlphaCounters(options_)); + + // Use cudaPointerGetAttributes here to check that the device is cudaMemoryTypeDevice + cudaPointerAttributes attributes; + cudaError_t error = cudaPointerGetAttributes(&attributes, data_); + TORCH_CHECK_EQ(error, cudaSuccess); + TORCH_CHECK_EQ(attributes.device, cudaMemoryTypeDevice); + fflush(stdout); cudaMemset( GetPointerToAlphaCounters(), From 41d0f97432e18ecfb21e21a45b1ed7c148e888b9 Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Thu, 14 Aug 2025 19:44:11 +0000 Subject: [PATCH 30/31] Remove debugging statements --- src/libtorchaudio/rnnt/gpu/compute.cu | 90 ++++++++++++--------------- src/libtorchaudio/rnnt/workspace.h | 20 ++---- 2 files changed, 45 insertions(+), 65 deletions(-) diff --git a/src/libtorchaudio/rnnt/gpu/compute.cu b/src/libtorchaudio/rnnt/gpu/compute.cu index 71632e2f19..10b9fc1706 100644 --- a/src/libtorchaudio/rnnt/gpu/compute.cu +++ b/src/libtorchaudio/rnnt/gpu/compute.cu @@ -1,10 +1,8 @@ #include #include #include -#include #include #include -#include namespace torchaudio { namespace rnnt { @@ -115,7 +113,7 @@ std::tuple compute( options.fusedLogSmax_ = fused_log_softmax; int32_t logits_device_index; - aoti_torch_get_device_index(logits.get(), &logits_device_index); + TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_index(logits.get(), &logits_device_index)); TORCH_CHECK_EQ(logits_device, aoti_torch_device_type_cuda()); @@ -144,12 +142,6 @@ std::tuple compute( TORCH_ERROR_CODE_CHECK( aoti_torch_empty_strided(1, int_sizes, strides, aoti_torch_dtype_int32(), 
logits_device, logits_device_index, &int_workspace)); - // torch::Tensor int_workspace = torch::empty( - // IntWorkspace::ComputeSizeFromOptions(options), - // torch::TensorOptions() - // .device(torch::aot_inductor::tensor_handle_to_tensor_pointer(logits.get())->device()) - // .dtype(torch::ScalarType::Int)); - AtenTensorHandle float_workspace; int64_t float_sizes[1] = {DtypeWorkspace::ComputeSizeFromOptions(options)}; TORCH_ERROR_CODE_CHECK( @@ -157,9 +149,9 @@ std::tuple compute( int64_t float_numel; aoti_torch_get_numel(float_workspace, &float_numel); - // void *int_workspace_ptr; - // TORCH_ERROR_CODE_CHECK( - // aoti_torch_get_data_ptr(int_workspace, &int_workspace_ptr)); + void *int_workspace_ptr; + TORCH_ERROR_CODE_CHECK( + aoti_torch_get_data_ptr(int_workspace, &int_workspace_ptr)); void *float_workspace_ptr; TORCH_ERROR_CODE_CHECK( aoti_torch_get_data_ptr(float_workspace, &float_workspace_ptr)); @@ -176,43 +168,43 @@ std::tuple compute( /*int_size=*/int_numel ); at::cuda::stream_synchronize(options.stream_); - // void *logit_ptr; - // aoti_torch_get_data_ptr(logits.get(), &logit_ptr); - - // void *target_ptr; - // aoti_torch_get_data_ptr(targets.get(), &target_ptr); - - // void *logit_len_ptr; - // aoti_torch_get_data_ptr(logit_lengths.get(), &logit_len_ptr); - - // void *target_len_ptr; - // aoti_torch_get_data_ptr(target_lengths.get(), &target_len_ptr); - - // void *costs_ptr; - // aoti_torch_get_data_ptr(costs, &costs_ptr); - - // void *grads_ptr; - // aoti_torch_get_data_ptr(gradients, &grads_ptr); - - // if (logits_dtype == aoti_torch_dtype_float32()) { - // Compute( - // /*workspace=*/workspace, - // /*logits=*/(float*)logit_ptr, - // /*targets=*/(int*)target_ptr, - // /*logit_lengths=*/(int*)logit_len_ptr, - // /*target_lengths=*/(int*)target_len_ptr, - // /*costs=*/(float*)costs_ptr, - // /*gradients=*/(float*)grads_ptr); - // } else { - // Compute( - // /*workspace=*/workspace, - // /*logits=*/(c10::Half*)logit_ptr, - // 
/*targets=*/(int*)target_ptr, - // /*logit_lengths=*/(int*)logit_len_ptr, - // /*target_lengths=*/(int*)target_len_ptr, - // /*costs=*/(c10::Half*)costs_ptr, - // /*gradients=*/(c10::Half*)grads_ptr); - // } + void *logit_ptr; + aoti_torch_get_data_ptr(logits.get(), &logit_ptr); + + void *target_ptr; + aoti_torch_get_data_ptr(targets.get(), &target_ptr); + + void *logit_len_ptr; + aoti_torch_get_data_ptr(logit_lengths.get(), &logit_len_ptr); + + void *target_len_ptr; + aoti_torch_get_data_ptr(target_lengths.get(), &target_len_ptr); + + void *costs_ptr; + aoti_torch_get_data_ptr(costs, &costs_ptr); + + void *grads_ptr; + aoti_torch_get_data_ptr(gradients, &grads_ptr); + + if (logits_dtype == aoti_torch_dtype_float32()) { + Compute( + /*workspace=*/workspace, + /*logits=*/(float*)logit_ptr, + /*targets=*/(int*)target_ptr, + /*logit_lengths=*/(int*)logit_len_ptr, + /*target_lengths=*/(int*)target_len_ptr, + /*costs=*/(float*)costs_ptr, + /*gradients=*/(float*)grads_ptr); + } else { + Compute( + /*workspace=*/workspace, + /*logits=*/(c10::Half*)logit_ptr, + /*targets=*/(int*)target_ptr, + /*logit_lengths=*/(int*)logit_len_ptr, + /*target_lengths=*/(int*)target_len_ptr, + /*costs=*/(c10::Half*)costs_ptr, + /*gradients=*/(c10::Half*)grads_ptr); + } return std::make_tuple(Tensor(costs), Tensor(gradients)); } diff --git a/src/libtorchaudio/rnnt/workspace.h b/src/libtorchaudio/rnnt/workspace.h index bed10e284a..9edfa33426 100644 --- a/src/libtorchaudio/rnnt/workspace.h +++ b/src/libtorchaudio/rnnt/workspace.h @@ -6,9 +6,6 @@ #include #include -#ifdef USE_CUDA -#include -#endif namespace torchaudio { namespace rnnt { @@ -126,23 +123,14 @@ class IntWorkspace { inline void ResetAlphaBetaCounters() { #ifdef USE_CUDA if (data_ != nullptr && options_.device_ == GPU) { - printf("ALPHA COUNTER SIZE IS %ld", ComputeSizeForAlphaCounters(options_)); - - // Use cudaPointerGetAttributes here to check that the device is cudaMemoryTypeDevice - cudaPointerAttributes attributes; - 
cudaError_t error = cudaPointerGetAttributes(&attributes, data_); - TORCH_CHECK_EQ(error, cudaSuccess); - TORCH_CHECK_EQ(attributes.device, cudaMemoryTypeDevice); - - fflush(stdout); cudaMemset( GetPointerToAlphaCounters(), 0, ComputeSizeForAlphaCounters(options_) * sizeof(int32_t)); - // cudaMemset( - // GetPointerToBetaCounters(), - // 0, - // ComputeSizeForBetaCounters(options_) * sizeof(int)); + cudaMemset( + GetPointerToBetaCounters(), + 0, + ComputeSizeForBetaCounters(options_) * sizeof(int)); } #endif // USE_CUDA } From 82a10c9ed6235fba7f4b6ee0a1d2ec3110369a79 Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Thu, 14 Aug 2025 21:20:45 +0000 Subject: [PATCH 31/31] Add length check --- src/libtorchaudio/rnnt/cpu/compute.cpp | 9 ++++++--- src/libtorchaudio/rnnt/gpu/compute.cu | 6 +++--- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/libtorchaudio/rnnt/cpu/compute.cpp b/src/libtorchaudio/rnnt/cpu/compute.cpp index 4150ea98f1..95a74e06d6 100644 --- a/src/libtorchaudio/rnnt/cpu/compute.cpp +++ b/src/libtorchaudio/rnnt/cpu/compute.cpp @@ -86,10 +86,13 @@ std::tuple compute( aoti_torch_get_size(targets.get(), 0, &targets_size); AOTI_TORCH_CHECK(targets_size == logits_size); - // TORCH_CHECK( - // blank >= 0 && blank < logits.size(-1), - // "blank must be within [0, logits.shape[-1])"); + AOTI_TORCH_CHECK( + blank >= 0 && blank < logits.size(-1), + "blank must be within [0, logits.shape[-1])"); + // "Max" is not ABI stable yet, but no tests check + // for this error behavior, so it's okay to merge in for now. 
+ // // TORCH_CHECK( // logits.size(1) == at::max(logit_lengths).item().toInt(), // "input length mismatch"); diff --git a/src/libtorchaudio/rnnt/gpu/compute.cu b/src/libtorchaudio/rnnt/gpu/compute.cu index 10b9fc1706..fecf049aa7 100644 --- a/src/libtorchaudio/rnnt/gpu/compute.cu +++ b/src/libtorchaudio/rnnt/gpu/compute.cu @@ -84,9 +84,9 @@ std::tuple compute( TORCH_ERROR_CODE_CHECK(aoti_torch_get_size(targets.get(), 0, &targets_size)); AOTI_TORCH_CHECK(targets_size == logits_size); - // TORCH_CHECK( - // blank >= 0 && blank < logits.size(-1), - // "blank must be within [0, logits.shape[-1])"); + AOTI_TORCH_CHECK( + blank >= 0 && blank < logits.size(-1), + "blank must be within [0, logits.shape[-1])"); // TORCH_CHECK( // logits.size(1) == at::max(logit_lengths).item().toInt(),