@@ -1,26 +1,24 @@
-#include <torch/script.h>
-#include <torch/torch.h>
+#include <libtorchaudio/utils.h>
 #include <torch/csrc/stable/library.h>
-#include <torch/csrc/stable/tensor.h>
-#include <torch/csrc/stable/ops.h>
-#include <torch/csrc/inductor/aoti_torch/c/shim.h>
-
-using namespace std;
 
 namespace torchaudio {
 namespace alignment {
 namespace cpu {
+
+using torch::stable::Tensor;
+using torch::headeronly::ScalarType;
+
 // Inspired from
 // https://github.com/flashlight/sequence/blob/main/flashlight/lib/sequence/criterion/cpu/ConnectionistTemporalClassificationCriterion.cpp
-template <typename scalar_t, at::ScalarType target_scalar_type>
+template <typename scalar_t, ScalarType target_scalar_type>
 void forced_align_impl(
-    const torch::Tensor& logProbs,
-    const torch::Tensor& targets,
+    const Tensor& logProbs,
+    const Tensor& targets,
     const int64_t blank,
-    torch::Tensor& paths) {
+    Tensor& paths) {
   const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
   using target_t = typename std::
-      conditional<target_scalar_type == torch::kInt, int, int64_t>::type;
+      conditional<target_scalar_type == ScalarType::Int, int, int64_t>::type;
   const auto batchIndex =
       0; // TODO: support batch version and use the real batch index
   const auto T = logProbs.size(1);
@@ -132,79 +130,111 @@ void forced_align_impl(
   delete[] backPtr_a;
 }
 
-std::tuple<torch::Tensor, torch::Tensor> compute(
-    const torch::Tensor& logProbs,
-    const torch::Tensor& targets,
-    const torch::Tensor& inputLengths,
-    const torch::Tensor& targetLengths,
+std::tuple<Tensor, Tensor> compute(
+    const Tensor& logProbs,
+    const Tensor& targets,
+    const Tensor& inputLengths,
+    const Tensor& targetLengths,
     const int64_t blank) {
-  TORCH_CHECK(logProbs.is_cpu(), "log_probs must be a CPU tensor");
-  TORCH_CHECK(targets.is_cpu(), "targets must be a CPU tensor");
-  TORCH_CHECK(
-      logProbs.device() == targets.device(),
-      "log_probs and targets need to be on the same device");
-  TORCH_CHECK(
-      logProbs.dtype() == torch::kFloat64 ||
-      logProbs.dtype() == torch::kFloat32 ||
-      logProbs.dtype() == torch::kFloat16,
+  STD_TORCH_CHECK(logProbs.is_cpu(), "log_probs must be a CPU tensor");
+  STD_TORCH_CHECK(targets.is_cpu(), "targets must be a CPU tensor");
+  STD_TORCH_CHECK(inputLengths.is_cpu(), "input_lengths must be a CPU tensor");
+  STD_TORCH_CHECK(targetLengths.is_cpu(), "target_lengths must be a CPU tensor");
+  STD_TORCH_CHECK(
+      logProbs.scalar_type() == ScalarType::Double ||
+      logProbs.scalar_type() == ScalarType::Float ||
+      logProbs.scalar_type() == ScalarType::Half,
       "log_probs must be float64, float32 or float16 (half) type");
-  TORCH_CHECK(
-      targets.dtype() == torch::kInt32 || targets.dtype() == torch::kInt64,
+  STD_TORCH_CHECK(
+      targets.scalar_type() == ScalarType::Int || targets.scalar_type() == ScalarType::Long,
       "targets must be int32 or int64 type");
-  TORCH_CHECK(logProbs.is_contiguous(), "log_probs must be contiguous");
-  TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
-  TORCH_CHECK(
+  STD_TORCH_CHECK(logProbs.is_contiguous(), "log_probs must be contiguous");
+  STD_TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
+  STD_TORCH_CHECK(
       logProbs.dim() == 3,
       "log_probs must be 3-D (batch_size, input length, num classes)");
-  TORCH_CHECK(
+  STD_TORCH_CHECK(
      targets.dim() == 2, "targets must be 2-D (batch_size, target length,)");
-  TORCH_CHECK(
+  STD_TORCH_CHECK(
      inputLengths.dim() == 1, "input_lengths must be 1-D (batch_size,)");
-  TORCH_CHECK(
+  STD_TORCH_CHECK(
      targetLengths.dim() == 1, "target_lengths must be 1-D (batch_size,)");
-  TORCH_CHECK(
+  STD_TORCH_CHECK(
      logProbs.size(0) == 1,
      "The batch dimension for log_probs must be 1 at the current version.")
-  TORCH_CHECK(
+  STD_TORCH_CHECK(
      targets.size(0) == 1,
      "The batch dimension for targets must be 1 at the current version.")
-  TORCH_CHECK(
+  STD_TORCH_CHECK(
      blank >= 0 && blank < logProbs.size(-1),
      "blank must be within [0, num classes)");
 
-  TORCH_CHECK(
-      logProbs.size(1) == at::max(inputLengths).item().toInt(),
+  STD_TORCH_CHECK(
+      logProbs.size(1) == torchaudio::util::max<int>(inputLengths),
      "input length mismatch");
-  TORCH_CHECK(
-      targets.size(1) == at::max(targetLengths).item().toInt(),
+  STD_TORCH_CHECK(
+      targets.size(1) == torchaudio::util::max<int>(targetLengths),
      "target length mismatch");
 
   const auto B = logProbs.size(0);
   const auto T = logProbs.size(1);
-  auto paths = torch::zeros(
-      {B, T},
-      torch::TensorOptions().device(targets.device()).dtype(targets.dtype()));
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      logProbs.scalar_type(), "forced_align_impl", [&] {
-        if (targets.scalar_type() == torch::kInt64) {
-          forced_align_impl<scalar_t, torch::kInt64>(
-              logProbs, targets, blank, paths);
-        } else {
-          forced_align_impl<scalar_t, torch::kInt32>(
-              logProbs, targets, blank, paths);
-        }
-      });
-  return std::make_tuple(
-      paths,
-      logProbs
-  );
-}
-
+  Tensor paths = torch::stable::new_empty(targets, {B, T});
+  torch::stable::zero_(paths);
+
+  switch (logProbs.scalar_type()) {
+    case ScalarType::Double: {
+      if (targets.scalar_type() == ScalarType::Long) {
+        forced_align_impl<double, ScalarType::Long>(logProbs, targets, blank, paths);
+      } else if (targets.scalar_type() == ScalarType::Int) {
+        forced_align_impl<double, ScalarType::Int>(logProbs, targets, blank, paths);
+      } else {
+        STD_TORCH_CHECK(false, "unreachable");
+      }
+      break;
+    }
+    case ScalarType::Float: {
+      if (targets.scalar_type() == ScalarType::Long) {
+        forced_align_impl<float, ScalarType::Long>(logProbs, targets, blank, paths);
+      } else if (targets.scalar_type() == ScalarType::Int) {
+        forced_align_impl<float, ScalarType::Int>(logProbs, targets, blank, paths);
+      } else {
+        STD_TORCH_CHECK(false, "unreachable");
+      }
+      break;
+    }
+    case ScalarType::Half: {
+      if (targets.scalar_type() == ScalarType::Long) {
+        forced_align_impl<c10::Half, ScalarType::Long>(logProbs, targets, blank, paths);
+      } else if (targets.scalar_type() == ScalarType::Int) {
+        forced_align_impl<c10::Half, ScalarType::Int>(logProbs, targets, blank, paths);
+      } else {
+        STD_TORCH_CHECK(false, "unreachable");
+      }
+      break;
+    }
+    default: {
+      STD_TORCH_CHECK(false, "unreachable");
+    }
+  };
 
+  return std::make_tuple(paths, logProbs);
+}
 
+void boxed_forced_align_cpu(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+  STD_TORCH_CHECK(num_args == 5, "num_args must be 5");
+  STD_TORCH_CHECK(num_outputs == 2, "num_outputs must be 2");
+  std::tuple<Tensor, Tensor> res = compute(
+      /*logProbs*/to<Tensor>(stack[0]),
+      /*targets*/to<Tensor>(stack[1]),
+      /*logit_lengths*/to<Tensor>(stack[2]),
+      /*target_lengths*/to<Tensor>(stack[3]),
+      /*blank*/float(to<int64_t>(stack[4])));
  stack[0] = from(std::get<0>(res));
  stack[1] = from(std::get<1>(res));
+}
 
-TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
-  m.impl("forced_align", &compute);
+STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
+  m.impl("forced_align", &boxed_forced_align_cpu);
 }
 
 } // namespace cpu
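The new length checks call torchaudio::util::max<int>(...) from the added libtorchaudio/utils.h include, replacing at::max(...).item().toInt(); that helper's implementation is not part of this diff. Below is a minimal, hypothetical sketch of what such a helper could look like against the stable Tensor API, assuming a non-empty, contiguous CPU lengths tensor of int32 or int64 and that torch::stable::Tensor exposes data_ptr(), numel(), and scalar_type() as used elsewhere in this file (header path, names, and API surface here are assumptions, not the actual torchaudio code):

// Hypothetical sketch only; the real helper lives in libtorchaudio/utils.h.
#include <algorithm>
#include <cstdint>
#include <limits>

#include <torch/csrc/stable/tensor.h>

namespace torchaudio {
namespace util {

// Returns the largest value in a 1-D CPU lengths tensor, cast to T.
// Assumes the tensor is non-empty, contiguous, and holds int32 or int64 data.
template <typename T>
T max(const torch::stable::Tensor& lengths) {
  T result = std::numeric_limits<T>::lowest();
  if (lengths.scalar_type() == torch::headeronly::ScalarType::Long) {
    const auto* data = static_cast<const int64_t*>(lengths.data_ptr());
    for (int64_t i = 0; i < lengths.numel(); ++i) {
      result = std::max(result, static_cast<T>(data[i]));
    }
  } else {
    const auto* data = static_cast<const int32_t*>(lengths.data_ptr());
    for (int64_t i = 0; i < lengths.numel(); ++i) {
      result = std::max(result, static_cast<T>(data[i]));
    }
  }
  return result;
}

} // namespace util
} // namespace torchaudio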