Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/libtorchaudio/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ set(
lfilter.cpp
overdrive.cpp
utils.cpp
accessor_tests.cpp
)

set(
Expand Down
53 changes: 53 additions & 0 deletions src/libtorchaudio/accessor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#pragma once

#include <torch/csrc/stable/tensor_struct.h>
#include <type_traits>
#include <cstdarg>

using torch::stable::Tensor;

template<unsigned int k, typename T, bool IsConst = true>
class Accessor {
int64_t strides[k];
int64_t sizes[k];
T *data;

public:
using tensor_type = typename std::conditional<IsConst, const Tensor&, Tensor&>::type;

Accessor(tensor_type tensor) {
auto raw_ptr = tensor.data_ptr();
data = static_cast<T*>(raw_ptr);
for (unsigned int i = 0; i < k; i++) {
strides[i] = tensor.stride(i);
sizes[i] = tensor.size(i);
}
}

T index(...) {
va_list args;
va_start(args, k);
int64_t ix = 0;
for (unsigned int i = 0; i < k; i++) {
ix += strides[i] * va_arg(args, int);
}
va_end(args);
return data[ix];
}

int64_t size(int dim) {
return sizes[dim];
}

template<bool C = IsConst>
typename std::enable_if<!C, void>::type set_index(T value, ...) {
va_list args;
va_start(args, value);
int64_t ix = 0;
for (unsigned int i = 0; i < k; i++) {
ix += strides[i] * va_arg(args, int);
}
va_end(args);
data[ix] = value;
}
};
44 changes: 44 additions & 0 deletions src/libtorchaudio/accessor_tests.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#include <libtorchaudio/accessor.h>
#include <cstdint>
#include <torch/csrc/stable/library.h>

namespace torchaudio {

namespace accessor_tests {

using namespace std;
using torch::stable::Tensor;

bool test_accessor(const Tensor tensor) {
int64_t* data_ptr = (int64_t*)tensor.data_ptr();
auto accessor = Accessor<3, int64_t>(tensor);
for (unsigned int i = 0; i < tensor.size(0); i++) {
for (unsigned int j = 0; j < tensor.size(1); j++) {
for (unsigned int k = 0; k < tensor.size(2); k++) {
auto check = *(data_ptr++) == accessor.index(i, j, k);
if (!check) {
return false;
}
}
}
}
return true;
}

void boxed_test_accessor(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
Tensor t1(to<AtenTensorHandle>(stack[0]));
auto result = test_accessor(std::move(t1));
stack[0] = from(result);
}

STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def(
"_test_accessor(Tensor log_probs) -> bool");
}

STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
m.impl("torchaudio::_test_accessor", &boxed_test_accessor);
}

}
}
3 changes: 0 additions & 3 deletions src/libtorchaudio/forced_align/cpu/compute.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
#include <torch/script.h>
#include <torch/torch.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>

using namespace std;
Expand Down
7 changes: 7 additions & 0 deletions test/torchaudio_unittest/accessor_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import torch
from torchaudio._extension import _IS_TORCHAUDIO_EXT_AVAILABLE

if _IS_TORCHAUDIO_EXT_AVAILABLE:
def test_accessor():
tensor = torch.randint(1000, (5,4,3))
assert torch.ops.torchaudio._test_accessor(tensor)
Loading