Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 224 additions & 6 deletions .github/workflows/pytorchsim_test.yml

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion Dockerfile.base
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ RUN wget https://github.com/riscv-collab/riscv-gnu-toolchain/releases/download/2
rm *.tar.gz

# Install torchsim dependency
RUN apt install ninja-build && pip install onnx matplotlib && pip install --user conan==1.56.0
RUN apt install ninja-build && pip install onnx matplotlib && pip install --user conan==1.56.0 && pip install "transformers<4.44" && pip install diffusers==0.34.0

ENV RISCV=/workspace/riscv
ENV PATH=$RISCV/bin:$PATH
Expand Down
59 changes: 45 additions & 14 deletions PyTorchSimFrontend/extension_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <ATen/native/cpu/Loops.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/Resize.h>
#include <ATen/native/TensorFactories.h>
#include <ATen/EmptyTensor.h>
#include <ATen/core/GeneratorForPrivateuseone.h>
#include <ATen/NativeFunctions.h>
Expand Down Expand Up @@ -204,19 +205,25 @@ at::Tensor custom__copy_from(const at::Tensor& self, const at::Tensor& dst, bool

// Some dummy asserts for the basic use case: inputs are the same size / dtype, all contiguous.
TORCH_CHECK(self.sizes() == dst.sizes());
TORCH_CHECK(self.scalar_type() == dst.scalar_type());

if (self.is_contiguous() && dst.is_contiguous()) {
const bool same_dtype = (self.scalar_type() == dst.scalar_type());
const bool both_contig = self.is_contiguous() && dst.is_contiguous();

// 1) fast path
if (same_dtype && both_contig) {
std::memcpy(dst.mutable_data_ptr(),
self.data_ptr(),
dst.storage().nbytes());
} else {
// Using cpu tensor to accomplishment stride copy.
at::Tensor cpu_self = unsafe_create_cpu_tensor_from_dummy_tensor(self);
at::Tensor cpu_dst = unsafe_create_cpu_tensor_from_dummy_tensor(dst);
cpu_dst.copy_(cpu_self);
return dst;
}

// 2) slow path
at::Tensor cpu_self = unsafe_create_cpu_tensor_from_dummy_tensor(self);
at::Tensor cpu_dst = unsafe_create_cpu_tensor_from_dummy_tensor(dst);
if (!same_dtype) {
cpu_self = cpu_self.to(cpu_dst.scalar_type(), /*non_blocking=*/false, /*copy=*/true);
}
cpu_dst.copy_(cpu_self);
return dst;
}

Expand All @@ -230,7 +237,6 @@ at::Tensor& custom_abs_out(const at::Tensor& self, at::Tensor& out) {

at::Tensor custom_empty_strided(c10::IntArrayRef size, c10::IntArrayRef stride, c10::optional<at::ScalarType> dtype_opt, c10::optional<at::Layout> layout_opt, c10::optional<at::Device> device_opt, c10::optional<bool> pin_memory_opt) {
op_counter += 1;

constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1);
auto dtype = c10::dtype_or_default(dtype_opt);
return at::detail::empty_strided_generic(size, stride, &global_custom_alloc, private_use_ks, dtype);
Expand All @@ -244,7 +250,23 @@ at::Tensor custom_empty(c10::IntArrayRef size, c10::optional<at::ScalarType> dty
return at::detail::empty_generic(size, &global_custom_alloc, private_use_ks, dtype, optional_memory_format);
}

// This macro does the heavy lifting.
at::Tensor& custom_arange_start_out_impl(
const c10::Scalar& start,
const c10::Scalar& end,
const c10::Scalar& step,
at::Tensor& out) {
//const int64_t n = arange_len(start.toDouble(), end.toDouble(), step.toDouble());
//at::native::resize_output(out, {n});
return out;
}

static at::Tensor custom_to_dtype_impl(const at::Tensor& self,
c10::ScalarType dtype,
bool non_blocking, bool copy,
c10::optional<c10::MemoryFormat> memory_format) {
return at::native::to(self, dtype, non_blocking, copy, memory_format);
}

// With TORCH_LIBRARY_IMPL, you can register custom kernels for your backend.
// For open registration, we're registering all of our kernels to the PrivateUse1 dispatch key.
// Later in this file, we map a custom device to the PrivateUse1 device type,
Expand All @@ -255,21 +277,27 @@ at::Tensor custom_empty(c10::IntArrayRef size, c10::optional<at::ScalarType> dty
// More details on the dispatcher can be found at http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/.
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
m.impl("to.Device", &custom_to_device);
m.impl("to.dtype", &custom_to_dtype_impl);
m.impl("fill_.Scalar", &custom_fill__scalar);
m.impl("_copy_from", &custom__copy_from);
m.impl("_copy_from_and_resize", &custom__copy_from_and_resize);
m.impl("empty_strided", &custom_empty_strided);
m.impl("empty.memory_format", &custom_empty);
m.impl("as_strided", at::native::as_strided_tensorimpl);
m.impl("view", at::native::view);
m.impl("arange.start_out", &custom_arange_start_out_impl);
}

TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) {
m.impl("to.dtype", &custom_to_dtype_impl);
}

TORCH_LIBRARY_FRAGMENT(aten, m) {
m.def(
"_reinterpret_tensor(Tensor self, int[] size, int[] stride, int offset_increment=0) -> Tensor",
torch::dispatch(
c10::DispatchKey::AutogradPrivateUse1, _reinterpret_tensor),
{at::Tag::pt2_compliant_tag});
m.def(
"_reinterpret_tensor(Tensor self, int[] size, int[] stride, int offset_increment=0) -> Tensor",
torch::dispatch(
c10::DispatchKey::AutogradPrivateUse1, _reinterpret_tensor),
{at::Tag::pt2_compliant_tag});
}

void custom_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
Expand Down Expand Up @@ -307,6 +335,9 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
m.impl("_foreach_addcdiv_.ScalarList", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
m.impl("_foreach_add_.List", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
m.impl("cat.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
m.impl("_native_multi_head_attention", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
m.impl("resize_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
m.impl("exp.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
}

// This basic implementation doesn't bother dealing with different device indices
Expand Down
10 changes: 8 additions & 2 deletions PyTorchSimFrontend/mlir/mlir_bmm_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,9 @@ def render(self,

W_tensor = empty_strided(W.layout.size, W.layout.stride)
X_tensor = empty_strided(X.layout.size, X.layout.stride)
if len(W_tensor.size()) > 3:
if len(W_tensor.size()) > 3 or len(W_tensor.size()) == 2:
W_tensor = W_tensor.view([-1, W_tensor.shape[-2], W_tensor.shape[-1]])
if len(X_tensor.size()) > 3:
if len(X_tensor.size()) > 3 or len(X_tensor.size()) == 2:
X_tensor = X_tensor.view([-1, X_tensor.shape[-2], X_tensor.shape[-1]])
B, M, N, K = X_tensor.size()[0], X_tensor.size()[1], W_tensor.size()[2], X_tensor.size()[2]

Expand Down Expand Up @@ -217,6 +217,7 @@ def render(self,
X_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride)
X_tile_desc.set_tile_size_stride(X_tile_size, X_tile_stride)
X_tile_desc.set_name("X_buffer")
X_tile_desc.offset = X.get_layout().offset
X_stride = X_tensor.stride()
X_idx = [loop_dim[0]*X_stride[0], loop_dim[1]*X_stride[1], loop_dim[3]*X_stride[2]] # To keep index arguemnt order, we used index_list

Expand All @@ -225,6 +226,7 @@ def render(self,
W_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride)
W_tile_desc.set_tile_size_stride(W_tile_size, W_tile_stride)
W_tile_desc.set_name("W_buffer")
W_tile_desc.offset = W.get_layout().offset
W_stride = W_tensor.stride()
W_idx = [loop_dim[0]*W_stride[0], loop_dim[3]*W_stride[1], loop_dim[2]*W_stride[2]]

Expand All @@ -241,8 +243,12 @@ def render(self,
Y_idx = [loop_dim[0]*Y_stride[0], loop_dim[2]*Y_stride[2], loop_dim[1]*Y_stride[1]]

# Extract Bias info
Bias_tile_desc = mlir_common.MLIRMultiDimTile(Y_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride)
Bias_tile_desc.set_tile_size_stride(Y_tile_size, Y_tile_stride)
Bias_tile_desc.set_name("Y_buffer")
if Bias is not None:
Bias_stride = Bias.get_layout().stride
Bias_tile_desc.offset = Bias.get_layout().offset
if nr_rdim == 0:
Bias_idx = [loop_dim[0]*Bias_stride[0], loop_dim[1]*Bias_stride[1], loop_dim[2]*Bias_stride[2]]
else:
Expand Down
Loading