Skip to content

Commit b3a5a0e

Browse files
committed
[STABLE ABI] Port forced_align
1 parent 7d4c0bf commit b3a5a0e

File tree

4 files changed

+261
-151
lines changed

4 files changed

+261
-151
lines changed

src/libtorchaudio/forced_align/cpu/compute.cpp

Lines changed: 92 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,24 @@
1-
#include <torch/script.h>
2-
#include <torch/torch.h>
1+
#include <libtorchaudio/utils.h>
32
#include <torch/csrc/stable/library.h>
4-
#include <torch/csrc/stable/tensor.h>
5-
#include <torch/csrc/stable/ops.h>
6-
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
7-
8-
using namespace std;
93

104
namespace torchaudio {
115
namespace alignment {
126
namespace cpu {
7+
8+
using torch::stable::Tensor;
9+
using torch::headeronly::ScalarType;
10+
1311
// Inspired from
1412
// https://github.com/flashlight/sequence/blob/main/flashlight/lib/sequence/criterion/cpu/ConnectionistTemporalClassificationCriterion.cpp
15-
template <typename scalar_t, at::ScalarType target_scalar_type>
13+
template <typename scalar_t, ScalarType target_scalar_type>
1614
void forced_align_impl(
17-
const torch::Tensor& logProbs,
18-
const torch::Tensor& targets,
15+
const Tensor& logProbs,
16+
const Tensor& targets,
1917
const int64_t blank,
20-
torch::Tensor& paths) {
18+
Tensor& paths) {
2119
const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
2220
using target_t = typename std::
23-
conditional<target_scalar_type == torch::kInt, int, int64_t>::type;
21+
conditional<target_scalar_type == ScalarType::Int, int, int64_t>::type;
2422
const auto batchIndex =
2523
0; // TODO: support batch version and use the real batch index
2624
const auto T = logProbs.size(1);
@@ -132,79 +130,111 @@ void forced_align_impl(
132130
delete[] backPtr_a;
133131
}
134132

135-
std::tuple<torch::Tensor, torch::Tensor> compute(
136-
const torch::Tensor& logProbs,
137-
const torch::Tensor& targets,
138-
const torch::Tensor& inputLengths,
139-
const torch::Tensor& targetLengths,
133+
std::tuple<Tensor, Tensor> compute(
134+
const Tensor& logProbs,
135+
const Tensor& targets,
136+
const Tensor& inputLengths,
137+
const Tensor& targetLengths,
140138
const int64_t blank) {
141-
TORCH_CHECK(logProbs.is_cpu(), "log_probs must be a CPU tensor");
142-
TORCH_CHECK(targets.is_cpu(), "targets must be a CPU tensor");
143-
TORCH_CHECK(
144-
logProbs.device() == targets.device(),
145-
"log_probs and targets need to be on the same device");
146-
TORCH_CHECK(
147-
logProbs.dtype() == torch::kFloat64 ||
148-
logProbs.dtype() == torch::kFloat32 ||
149-
logProbs.dtype() == torch::kFloat16,
139+
STD_TORCH_CHECK(logProbs.is_cpu(), "log_probs must be a CPU tensor");
140+
STD_TORCH_CHECK(targets.is_cpu(), "targets must be a CPU tensor");
141+
STD_TORCH_CHECK(inputLengths.is_cpu(), "input_lengths must be a CPU tensor");
142+
STD_TORCH_CHECK(targetLengths.is_cpu(), "target_lengths must be a CPU tensor");
143+
STD_TORCH_CHECK(
144+
logProbs.scalar_type() == ScalarType::Double ||
145+
logProbs.scalar_type() == ScalarType::Float ||
146+
logProbs.scalar_type() == ScalarType::Half,
150147
"log_probs must be float64, float32 or float16 (half) type");
151-
TORCH_CHECK(
152-
targets.dtype() == torch::kInt32 || targets.dtype() == torch::kInt64,
148+
STD_TORCH_CHECK(
149+
targets.scalar_type() == ScalarType::Int || targets.scalar_type() == ScalarType::Long,
153150
"targets must be int32 or int64 type");
154-
TORCH_CHECK(logProbs.is_contiguous(), "log_probs must be contiguous");
155-
TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
156-
TORCH_CHECK(
151+
STD_TORCH_CHECK(logProbs.is_contiguous(), "log_probs must be contiguous");
152+
STD_TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
153+
STD_TORCH_CHECK(
157154
logProbs.dim() == 3,
158155
"log_probs must be 3-D (batch_size, input length, num classes)");
159-
TORCH_CHECK(
156+
STD_TORCH_CHECK(
160157
targets.dim() == 2, "targets must be 2-D (batch_size, target length,)");
161-
TORCH_CHECK(
158+
STD_TORCH_CHECK(
162159
inputLengths.dim() == 1, "input_lengths must be 1-D (batch_size,)");
163-
TORCH_CHECK(
160+
STD_TORCH_CHECK(
164161
targetLengths.dim() == 1, "target_lengths must be 1-D (batch_size,)");
165-
TORCH_CHECK(
162+
STD_TORCH_CHECK(
166163
logProbs.size(0) == 1,
167164
"The batch dimension for log_probs must be 1 at the current version.")
168-
TORCH_CHECK(
165+
STD_TORCH_CHECK(
169166
targets.size(0) == 1,
170167
"The batch dimension for targets must be 1 at the current version.")
171-
TORCH_CHECK(
168+
STD_TORCH_CHECK(
172169
blank >= 0 && blank < logProbs.size(-1),
173170
"blank must be within [0, num classes)");
174171

175-
TORCH_CHECK(
176-
logProbs.size(1) == at::max(inputLengths).item().toInt(),
172+
STD_TORCH_CHECK(
173+
logProbs.size(1) == torchaudio::util::max<int>(inputLengths),
177174
"input length mismatch");
178-
TORCH_CHECK(
179-
targets.size(1) == at::max(targetLengths).item().toInt(),
175+
STD_TORCH_CHECK(
176+
targets.size(1) == torchaudio::util::max<int>(targetLengths),
180177
"target length mismatch");
181178

182179
const auto B = logProbs.size(0);
183180
const auto T = logProbs.size(1);
184-
auto paths = torch::zeros(
185-
{B, T},
186-
torch::TensorOptions().device(targets.device()).dtype(targets.dtype()));
187-
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
188-
logProbs.scalar_type(), "forced_align_impl", [&] {
189-
if (targets.scalar_type() == torch::kInt64) {
190-
forced_align_impl<scalar_t, torch::kInt64>(
191-
logProbs, targets, blank, paths);
192-
} else {
193-
forced_align_impl<scalar_t, torch::kInt32>(
194-
logProbs, targets, blank, paths);
195-
}
196-
});
197-
return std::make_tuple(
198-
paths,
199-
logProbs
200-
);
201-
}
202-
181+
Tensor paths = torch::stable::new_empty(targets, {B, T});
182+
torch::stable::zero_(paths);
183+
184+
switch (logProbs.scalar_type()) {
185+
case ScalarType::Double: {
186+
if (targets.scalar_type() == ScalarType::Long) {
187+
forced_align_impl<double, ScalarType::Long>(logProbs, targets, blank, paths);
188+
} else if (targets.scalar_type() == ScalarType::Int) {
189+
forced_align_impl<double, ScalarType::Int>(logProbs, targets, blank, paths);
190+
} else {
191+
STD_TORCH_CHECK(false, "unreachable");
192+
}
193+
break;
194+
}
195+
case ScalarType::Float: {
196+
if (targets.scalar_type() == ScalarType::Long) {
197+
forced_align_impl<float, ScalarType::Long>(logProbs, targets, blank, paths);
198+
} else if (targets.scalar_type() == ScalarType::Int) {
199+
forced_align_impl<float, ScalarType::Int>(logProbs, targets, blank, paths);
200+
} else {
201+
STD_TORCH_CHECK(false, "unreachable");
202+
}
203+
break;
204+
}
205+
case ScalarType::Half: {
206+
if (targets.scalar_type() == ScalarType::Long) {
207+
forced_align_impl<c10::Half, ScalarType::Long>(logProbs, targets, blank, paths);
208+
} else if (targets.scalar_type() == ScalarType::Int) {
209+
forced_align_impl<c10::Half, ScalarType::Int>(logProbs, targets, blank, paths);
210+
} else {
211+
STD_TORCH_CHECK(false, "unreachable");
212+
}
213+
break;
214+
}
215+
default: {
216+
STD_TORCH_CHECK(false, "unreachable");
217+
}
218+
};
203219

220+
return std::make_tuple(paths, logProbs);
221+
}
204222

223+
void boxed_forced_align_cpu(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
224+
STD_TORCH_CHECK(num_args == 5, "num_args must be 5");
225+
STD_TORCH_CHECK(num_outputs == 2, "num_outputs must be 2");
226+
std::tuple<Tensor, Tensor> res = compute(
227+
/*logProbs*/to<Tensor>(stack[0]),
228+
/*targets*/to<Tensor>(stack[1]),
229+
/*logit_lengths*/to<Tensor>(stack[2]),
230+
/*target_lengths*/to<Tensor>(stack[3]),
231+
/*blank*/float(to<int64_t>(stack[4])));
232+
stack[0] = from(std::get<0>(res));
233+
stack[1] = from(std::get<1>(res));
234+
}
205235

206-
TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
207-
m.impl("forced_align", &compute);
236+
STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
237+
m.impl("forced_align", &boxed_forced_align_cpu);
208238
}
209239

210240
} // namespace cpu

0 commit comments

Comments
 (0)