From d1a8806f1948c6dd081bfd7ac54733d7bedcfc54 Mon Sep 17 00:00:00 2001 From: yuerqiqi <2500526025@qq.com> Date: Tue, 3 Feb 2026 11:25:41 +0800 Subject: [PATCH 1/2] [Ascend] Implement Concat and Slice operators --- mllm/backends/ascend/AscendBackend.cpp | 5 +- mllm/backends/ascend/ops/AscendConcatOp.cpp | 129 +++++++++++++++++++ mllm/backends/ascend/ops/AscendConcatOp.hpp | 27 ++++ mllm/backends/ascend/ops/AscendSliceOp.cpp | 136 ++++++++++++++++++++ mllm/backends/ascend/ops/AscendSliceOp.hpp | 28 ++++ tests/ascend/AscendConcatKernelTest.hpp | 41 ++++++ tests/ascend/AscendSliceKernelTest.hpp | 38 ++++++ tests/ascend/KernelTest.cpp | 27 ++++ 8 files changed, 430 insertions(+), 1 deletion(-) create mode 100644 mllm/backends/ascend/ops/AscendConcatOp.cpp create mode 100644 mllm/backends/ascend/ops/AscendConcatOp.hpp create mode 100644 mllm/backends/ascend/ops/AscendSliceOp.cpp create mode 100644 mllm/backends/ascend/ops/AscendSliceOp.hpp create mode 100644 tests/ascend/AscendConcatKernelTest.hpp create mode 100644 tests/ascend/AscendSliceKernelTest.hpp diff --git a/mllm/backends/ascend/AscendBackend.cpp b/mllm/backends/ascend/AscendBackend.cpp index 6c17774b..838ca20f 100644 --- a/mllm/backends/ascend/AscendBackend.cpp +++ b/mllm/backends/ascend/AscendBackend.cpp @@ -14,12 +14,15 @@ #include "mllm/backends/ascend/ops/AscendViewOp.hpp" #include "mllm/backends/ascend/ops/AscendMatMulOp.hpp" #include "mllm/backends/ascend/ops/AscendSoftmaxOp.hpp" +#include "mllm/backends/ascend/ops/AscendConcatOp.hpp" +#include "mllm/backends/ascend/ops/AscendSliceOp.hpp" namespace mllm::ascend { AscendBackend::AscendBackend() : Backend(kAscend, createAscendAllocator()) { regOpFactory(); + AscendLinearOpFactory,AscendRMSNormOpFactory,AscendViewOpFactory,AscendMatMulOpFactory,AscendSoftmaxOpFactory, + AscendConcatOpFactory, AscendSliceOpFactory>(); auto& devices = AscendDeviceMetaInfo::instance().devices; for (const auto& device : devices) { const auto bytes_to_mb = [](size_t bytes) { return bytes / (1024.0 * 1024.0); }; diff --git a/mllm/backends/ascend/ops/AscendConcatOp.cpp b/mllm/backends/ascend/ops/AscendConcatOp.cpp new file mode 100644 index 00000000..1c4f1904 --- /dev/null +++ b/mllm/backends/ascend/ops/AscendConcatOp.cpp @@ -0,0 +1,129 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include "mllm/backends/ascend/ops/AscendConcatOp.hpp" + +#include +#include +#include +#include +#include +#include + +#include "mllm/utils/Common.hpp" +#include "mllm/core/DataTypes.hpp" +#include "mllm/core/Tensor.hpp" +#include "mllm/backends/ascend/memory/AscendMemoryManager.hpp" +#include "mllm/backends/ascend/AscendCommon.hpp" + +namespace mllm::ascend { + +AscendConcatOp::AscendConcatOp(const aops::ConcatOpOptions& options) : aops::ConcatOp(options) {} + +void AscendConcatOp::setup(const std::vector& inputs, std::vector& outputs) { + BaseOp::setup(inputs, outputs); +} + +void AscendConcatOp::forward(const std::vector& inputs, std::vector& outputs) { + MLLM_RT_ASSERT(inputs.size() >= 1); + MLLM_RT_ASSERT_EQ(outputs.size(), 1); + + if (inputs.size() == 1) { + const size_t data_size = inputs[0].bytes(); + const void* src_data = inputs[0].ptr(); + void* dst_data = outputs[0].ptr(); + + if (src_data != dst_data) { + auto ret = aclrtMemcpy(dst_data, data_size, src_data, data_size, ACL_MEMCPY_DEVICE_TO_DEVICE); + if (ret != ACL_SUCCESS) { + MLLM_ACL_CHECK(ret); + } + syncGlobalAtbStream(); + } + return; + } + + int32_t concat_dim = options().dim; + if (concat_dim < 0) { + concat_dim += static_cast(inputs[0].rank()); + } + + auto run_concat = [&](const Tensor& left, const Tensor& right, Tensor& out) { + atb::infer::ConcatParam param; + param.concatDim = concat_dim; + + atb::Operation* op = nullptr; + auto st = atb::CreateOperation(param, &op); + if (st != atb::NO_ERROR || op == nullptr) { + MLLM_ERROR_EXIT(ExitCode::kAscendError, "ATB CreateOperation(Concat) failed, status={}", static_cast(st)); + } + + atb::Context* atb_ctx = getGlobalAtbContext(); + + atb::SVector inTensors; + atb::Tensor atb_left; + atb::Tensor atb_right; + fillAtbTensor(left, atb_left); + fillAtbTensor(right, atb_right); + inTensors.push_back(atb_left); + inTensors.push_back(atb_right); + + atb::Tensor atb_out; + fillAtbTensor(out, atb_out); + atb::SVector outTensors; + outTensors.push_back(atb_out); + + atb::VariantPack vp; + vp.inTensors = inTensors; + vp.outTensors = outTensors; + + uint64_t workspaceSize = 0; + st = op->Setup(vp, workspaceSize, atb_ctx); + if (st != atb::NO_ERROR) { + MLLM_ERROR_EXIT(ExitCode::kAscendError, "ATB ConcatOp Setup failed, status={}", static_cast(st)); + } + + void* workspace = nullptr; + int workspace_block_id = -1; + if (workspaceSize > 0) { + auto& mem_mgr = getAscendMemoryManager(); + mem_mgr.allocateBlock(static_cast(workspaceSize), workspace_block_id); + mem_mgr.getBlockPtr(workspace_block_id, workspace); + } + + { + ASCEND_TIME_SCOPE("AscendConcatOp::forward"); + st = op->Execute(vp, reinterpret_cast(workspace), workspaceSize, atb_ctx); + } + + if (st != atb::NO_ERROR) { + MLLM_ERROR_EXIT(ExitCode::kAscendError, "ATB ConcatOp Execute failed, status={}", static_cast(st)); + } + + syncGlobalAtbStream(); + + if (workspace_block_id != -1) { + auto& mem_mgr = getAscendMemoryManager(); + mem_mgr.freeBlock(workspace_block_id); + } + + atb::DestroyOperation(op); + }; + + std::vector current_shape = inputs[0].shape(); + Tensor current = inputs[0]; + + for (size_t i = 1; i < inputs.size(); ++i) { + current_shape[concat_dim] += inputs[i].shape()[concat_dim]; + + if (i == inputs.size() - 1) { + run_concat(current, inputs[i], outputs[0]); + } else { + Tensor temp = Tensor::empty(current_shape, outputs[0].dtype(), outputs[0].device()).alloc(); + run_concat(current, inputs[i], temp); + current = temp; + } + } +} + +} // namespace mllm::ascend diff --git a/mllm/backends/ascend/ops/AscendConcatOp.hpp b/mllm/backends/ascend/ops/AscendConcatOp.hpp new file mode 100644 index 00000000..7410cfa7 --- /dev/null +++ b/mllm/backends/ascend/ops/AscendConcatOp.hpp @@ -0,0 +1,27 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/aops/ConcatOp.hpp" +#include "mllm/core/OpTypes.hpp" + +namespace mllm::ascend { + +class AscendConcatOp final : public aops::ConcatOp { + public: + explicit AscendConcatOp(const aops::ConcatOpOptions& options); + + void setup(const std::vector& inputs, std::vector& outputs) override; + void forward(const std::vector& inputs, std::vector& outputs) override; +}; + +class AscendConcatOpFactory final : public TypedOpFactory { + public: + std::shared_ptr createOpImpl(const aops::ConcatOpOptions& options) override { + return std::make_shared(options); + } +}; + +} // namespace mllm::ascend diff --git a/mllm/backends/ascend/ops/AscendSliceOp.cpp b/mllm/backends/ascend/ops/AscendSliceOp.cpp new file mode 100644 index 00000000..039f8bbf --- /dev/null +++ b/mllm/backends/ascend/ops/AscendSliceOp.cpp @@ -0,0 +1,136 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include "mllm/backends/ascend/ops/AscendSliceOp.hpp" + +#include +#include +#include +#include +#include + +#include "mllm/utils/Common.hpp" +#include "mllm/core/DataTypes.hpp" +#include "mllm/core/Tensor.hpp" +#include "mllm/backends/ascend/memory/AscendMemoryManager.hpp" +#include "mllm/backends/ascend/AscendCommon.hpp" + +namespace mllm::ascend { + +AscendSliceOp::AscendSliceOp(const aops::SliceOpOptions& options) : aops::SliceOp(options) {} + +void AscendSliceOp::setup(const std::vector& inputs, std::vector& outputs) { + BaseOp::setup(inputs, outputs); +} + +void AscendSliceOp::reshape(const std::vector& inputs, std::vector& outputs) { + auto& input = inputs[0]; + auto shape = input.shape(); + auto slice_index = options().indices_; + + MLLM_RT_ASSERT_EQ(slice_index.size(), shape.size()); + + std::vector out_shape; + for (size_t i = 0; i < shape.size(); ++i) { + const auto& pair = slice_index[i]; + int32_t start = pair.start_; + int32_t end = pair.end_; + + if (start == kAll) { start = 0; } + if (end == kAll) { end = shape[i]; } + + if (start < 0) { start = start + shape[i]; } + if (end < 0) { end = end + shape[i]; } + + start = std::max(0, std::min(start, static_cast(shape[i]))); + end = std::max(0, std::min(end, static_cast(shape[i]))); + + int len = std::max(0, end - start); + out_shape.push_back(len); + } + + outputs.emplace_back(Tensor::empty(out_shape, input.dtype(), input.device())); +} + +void AscendSliceOp::forward(const std::vector& inputs, std::vector& outputs) { + atb::infer::SliceParam param; + auto& input = inputs[0]; + auto shape = input.shape(); + auto slice_index = options().indices_; + + for(size_t i=0; i(st)); + } + + atb::Context* atb_ctx = getGlobalAtbContext(); + + atb::SVector inTensors; + std::vector atb_inputs(inputs.size()); + for (size_t i = 0; i < inputs.size(); ++i) { + fillAtbTensor(inputs[i], atb_inputs[i]); + inTensors.push_back(atb_inputs[i]); + } + + atb::Tensor atb_output; + fillAtbTensor(outputs[0], atb_output); + atb::SVector outTensors; + outTensors.push_back(atb_output); + + atb::VariantPack vp; + vp.inTensors = inTensors; + vp.outTensors = outTensors; + + uint64_t workspaceSize = 0; + st = op->Setup(vp, workspaceSize, atb_ctx); + if (st != atb::NO_ERROR) { + MLLM_ERROR_EXIT(ExitCode::kAscendError, "ATB SliceOp Setup failed, status={}", static_cast(st)); + } + + void* workspace = nullptr; + int workspace_block_id = -1; + if (workspaceSize > 0) { + auto& mem_mgr = getAscendMemoryManager(); + mem_mgr.allocateBlock(static_cast(workspaceSize), workspace_block_id); + mem_mgr.getBlockPtr(workspace_block_id, workspace); + } + + { + ASCEND_TIME_SCOPE("AscendSliceOp::forward"); + st = op->Execute(vp, reinterpret_cast(workspace), workspaceSize, atb_ctx); + } + + if (st != atb::NO_ERROR) { + MLLM_ERROR_EXIT(ExitCode::kAscendError, "ATB SliceOp Execute failed, status={}", static_cast(st)); + } + + syncGlobalAtbStream(); + + if (workspace_block_id != -1) { + auto& mem_mgr = getAscendMemoryManager(); + mem_mgr.freeBlock(workspace_block_id); + } + + atb::DestroyOperation(op); +} + +} // namespace mllm::ascend diff --git a/mllm/backends/ascend/ops/AscendSliceOp.hpp b/mllm/backends/ascend/ops/AscendSliceOp.hpp new file mode 100644 index 00000000..fcec7d19 --- /dev/null +++ b/mllm/backends/ascend/ops/AscendSliceOp.hpp @@ -0,0 +1,28 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/aops/SliceOp.hpp" +#include "mllm/core/OpTypes.hpp" + +namespace mllm::ascend { + +class AscendSliceOp final : public aops::SliceOp { + public: + explicit AscendSliceOp(const aops::SliceOpOptions& options); + + void setup(const std::vector& inputs, std::vector& outputs) override; + void reshape(const std::vector& inputs, std::vector& outputs) override; + void forward(const std::vector& inputs, std::vector& outputs) override; +}; + +class AscendSliceOpFactory final : public TypedOpFactory { + public: + std::shared_ptr createOpImpl(const aops::SliceOpOptions& options) override { + return std::make_shared(options); + } +}; + +} // namespace mllm::ascend diff --git a/tests/ascend/AscendConcatKernelTest.hpp b/tests/ascend/AscendConcatKernelTest.hpp new file mode 100644 index 00000000..880c150b --- /dev/null +++ b/tests/ascend/AscendConcatKernelTest.hpp @@ -0,0 +1,41 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include "mllm/mllm.hpp" +#include "mllm/core/Tensor.hpp" +#include "mllm/nn/Functional.hpp" +#include "KernelTestHelper.hpp" // Has KernelTest base class + +class AscendConcatKernelTest : public KernelTest { + public: + bool ConcatFloat16Test(const std::vector& input_shapes, int dim) { + using namespace mllm; + + std::vector inputs_cpu; + for (const auto& shape : input_shapes) { + inputs_cpu.push_back(Tensor::random(shape, -1.0, 1.0, kFloat16, kCPU)); + } + + // CPU Reference + auto out_cpu = nn::functional::concat(inputs_cpu, dim); + + // Ascend + std::vector inputs_ascend; + for (auto& t : inputs_cpu) { + inputs_ascend.push_back(t.to(kAscend)); + } + + auto out_ascend = nn::functional::concat(inputs_ascend, dim); + auto out_back = out_ascend.to(kCPU); + + auto result = test::allClose(out_back, out_cpu, 1e-2, 1e-2); + if (!result.is_close) { + std::cout << "[ConcatTest] FAILED! dim=" << dim << std::endl; + return false; + } + std::cout << "[ConcatTest] PASSED dim=" << dim << std::endl; + return true; + } +}; diff --git a/tests/ascend/AscendSliceKernelTest.hpp b/tests/ascend/AscendSliceKernelTest.hpp new file mode 100644 index 00000000..aba4430f --- /dev/null +++ b/tests/ascend/AscendSliceKernelTest.hpp @@ -0,0 +1,38 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include "mllm/mllm.hpp" +#include "mllm/core/Tensor.hpp" +#include "mllm/nn/Functional.hpp" +#include "KernelTestHelper.hpp" + +class AscendSliceKernelTest : public KernelTest { + public: + bool SliceFloat16Test(mllm::Tensor::shape_t input_shape, mllm::SliceIndices indices) { + using namespace mllm; + + Tensor in_cpu = Tensor::random(input_shape, -1.0, 1.0, kFloat16, kCPU); + + // CPU Reference (View) + Tensor out_cpu = in_cpu[indices]; + + // Ascend (Copy/Kernel) + Tensor in_ascend = in_cpu.to(kAscend); + Tensor out_ascend = in_ascend[indices]; + Tensor out_back = out_ascend.to(kCPU); + + // The output from Ascend should match the view on CPU + // We compare them. Note: out_cpu might be non-contiguous, allClose should handle or we make it contiguous. + Tensor out_cpu_cont = out_cpu.contiguous(); + + auto result = test::allClose(out_back, out_cpu_cont, 1e-2, 1e-2); + if (!result.is_close) { + std::cout << "[SliceTest] FAILED!" << std::endl; + return false; + } + std::cout << "[SliceTest] PASSED" << std::endl; + return true; + } +}; diff --git a/tests/ascend/KernelTest.cpp b/tests/ascend/KernelTest.cpp index 4e1747e8..1ddfd1d0 100644 --- a/tests/ascend/KernelTest.cpp +++ b/tests/ascend/KernelTest.cpp @@ -223,6 +223,33 @@ TEST_F(AscendAttentionKernelTest, GroupedQueryAttentionFloat16) { true); } +//===----------------------------------------------------------------------===// +// Concat +//===----------------------------------------------------------------------===// +#include "AscendConcatKernelTest.hpp" +TEST_F(AscendConcatKernelTest, ConcatFloat16) { + EXPECT_EQ(ConcatFloat16Test({{2, 3}, {2, 3}}, 0), true); + EXPECT_EQ(ConcatFloat16Test({{1, 8}, {1, 8}}, 1), true); + EXPECT_EQ(ConcatFloat16Test({{4, 16}, {4, 16}, {4, 16}}, 0), true); + EXPECT_EQ(ConcatFloat16Test({{2, 3, 4}, {2, 3, 5}}, 2), true); + EXPECT_EQ(ConcatFloat16Test({{2, 3, 4}, {2, 3, 6}}, -1), true); + EXPECT_EQ(ConcatFloat16Test({{2, 7}}, 0), true); +} + +//===----------------------------------------------------------------------===// +// Slice +//===----------------------------------------------------------------------===// +#include "AscendSliceKernelTest.hpp" +TEST_F(AscendSliceKernelTest, SliceFloat16) { + using namespace mllm; + // SliceIndicesPair(start, end) + EXPECT_EQ(SliceFloat16Test({4, 4}, {SliceIndicesPair(0, 2), SliceIndicesPair(0, 4)}), true); + EXPECT_EQ(SliceFloat16Test({4, 8}, {SliceIndicesPair(1, 3), SliceIndicesPair(2, 6)}), true); + EXPECT_EQ(SliceFloat16Test({2, 16}, {SliceIndicesPair(0, 1), SliceIndicesPair(0, 8)}), true); + EXPECT_EQ(SliceFloat16Test({5, 4}, {SliceIndicesPair(-3, -1), SliceIndicesPair(0, 4)}), true); + EXPECT_EQ(SliceFloat16Test({3, 4, 5}, {SliceIndicesPair(kAll, kAll), SliceIndicesPair(1, 3), SliceIndicesPair(0, 5)}), true); +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); From a562bbd47bd80f38694de29be4bee9f810b72398 Mon Sep 17 00:00:00 2001 From: yuerqiqi <2500526025@qq.com> Date: Tue, 3 Feb 2026 12:11:09 +0800 Subject: [PATCH 2/2] Update mllm/backends/ascend/ops/AscendSliceOp.cpp Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- mllm/backends/ascend/ops/AscendSliceOp.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllm/backends/ascend/ops/AscendSliceOp.cpp b/mllm/backends/ascend/ops/AscendSliceOp.cpp index 039f8bbf..15571d67 100644 --- a/mllm/backends/ascend/ops/AscendSliceOp.cpp +++ b/mllm/backends/ascend/ops/AscendSliceOp.cpp @@ -73,7 +73,7 @@ void AscendSliceOp::forward(const std::vector& inputs, std::vector