diff --git a/src/libtorchaudio/CMakeLists.txt b/src/libtorchaudio/CMakeLists.txt index 85bc227cd6..20ad792b32 100644 --- a/src/libtorchaudio/CMakeLists.txt +++ b/src/libtorchaudio/CMakeLists.txt @@ -6,6 +6,7 @@ set( lfilter.cpp overdrive.cpp utils.cpp + accessor_tests.cpp ) set( diff --git a/src/libtorchaudio/accessor.h b/src/libtorchaudio/accessor.h new file mode 100644 index 0000000000..6211acff3d --- /dev/null +++ b/src/libtorchaudio/accessor.h @@ -0,0 +1,53 @@ +#pragma once + +#include +#include +#include + +using torch::stable::Tensor; + +template +class Accessor { + int64_t strides[k]; + int64_t sizes[k]; + T *data; + +public: + using tensor_type = typename std::conditional::type; + + Accessor(tensor_type tensor) { + auto raw_ptr = tensor.data_ptr(); + data = static_cast(raw_ptr); + for (unsigned int i = 0; i < k; i++) { + strides[i] = tensor.stride(i); + sizes[i] = tensor.size(i); + } + } + + T index(...) { + va_list args; + va_start(args, k); + int64_t ix = 0; + for (unsigned int i = 0; i < k; i++) { + ix += strides[i] * va_arg(args, int); + } + va_end(args); + return data[ix]; + } + + int64_t size(int dim) { + return sizes[dim]; + } + + template + typename std::enable_if::type set_index(T value, ...) { + va_list args; + va_start(args, value); + int64_t ix = 0; + for (unsigned int i = 0; i < k; i++) { + ix += strides[i] * va_arg(args, int); + } + va_end(args); + data[ix] = value; + } +}; diff --git a/src/libtorchaudio/accessor_tests.cpp b/src/libtorchaudio/accessor_tests.cpp new file mode 100644 index 0000000000..45312a8408 --- /dev/null +++ b/src/libtorchaudio/accessor_tests.cpp @@ -0,0 +1,44 @@ +#include +#include +#include + +namespace torchaudio { + +namespace accessor_tests { + +using namespace std; +using torch::stable::Tensor; + +bool test_accessor(const Tensor tensor) { + int64_t* data_ptr = (int64_t*)tensor.data_ptr(); + auto accessor = Accessor<3, int64_t>(tensor); + for (unsigned int i = 0; i < tensor.size(0); i++) { + for (unsigned int j = 0; j < tensor.size(1); j++) { + for (unsigned int k = 0; k < tensor.size(2); k++) { + auto check = *(data_ptr++) == accessor.index(i, j, k); + if (!check) { + return false; + } + } + } + } + return true; +} + +void boxed_test_accessor(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + Tensor t1(to(stack[0])); + auto result = test_accessor(std::move(t1)); + stack[0] = from(result); +} + +STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) { + m.def( + "_test_accessor(Tensor log_probs) -> bool"); +} + +STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { + m.impl("torchaudio::_test_accessor", &boxed_test_accessor); +} + +} +} diff --git a/src/libtorchaudio/forced_align/cpu/compute.cpp b/src/libtorchaudio/forced_align/cpu/compute.cpp index 0ddd21b126..4b87b20e5f 100644 --- a/src/libtorchaudio/forced_align/cpu/compute.cpp +++ b/src/libtorchaudio/forced_align/cpu/compute.cpp @@ -1,8 +1,5 @@ #include #include -#include -#include -#include #include using namespace std; diff --git a/test/torchaudio_unittest/accessor_test.py b/test/torchaudio_unittest/accessor_test.py new file mode 100644 index 0000000000..db14258dc6 --- /dev/null +++ b/test/torchaudio_unittest/accessor_test.py @@ -0,0 +1,7 @@ +import torch +from torchaudio._extension import _IS_TORCHAUDIO_EXT_AVAILABLE + +if _IS_TORCHAUDIO_EXT_AVAILABLE: + def test_accessor(): + tensor = torch.randint(1000, (5,4,3)) + assert torch.ops.torchaudio._test_accessor(tensor)