-
Notifications
You must be signed in to change notification settings - Fork 175
[Ascend] Implement Concat and Slice operators #629
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,129 @@ | ||
| // Copyright (c) MLLM Team. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #include "mllm/backends/ascend/ops/AscendConcatOp.hpp" | ||
|
|
||
| #include <iostream> | ||
| #include <acl/acl.h> | ||
| #include <atb/atb_infer.h> | ||
| #include <atb/types.h> | ||
| #include <atb/utils.h> | ||
| #include <atb/infer_op_params.h> | ||
|
|
||
| #include "mllm/utils/Common.hpp" | ||
| #include "mllm/core/DataTypes.hpp" | ||
| #include "mllm/core/Tensor.hpp" | ||
| #include "mllm/backends/ascend/memory/AscendMemoryManager.hpp" | ||
| #include "mllm/backends/ascend/AscendCommon.hpp" | ||
|
|
||
| namespace mllm::ascend { | ||
|
|
||
| AscendConcatOp::AscendConcatOp(const aops::ConcatOpOptions& options) : aops::ConcatOp(options) {} | ||
|
|
||
| void AscendConcatOp::setup(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) { | ||
| BaseOp::setup(inputs, outputs); | ||
| } | ||
|
|
||
| void AscendConcatOp::forward(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) { | ||
| MLLM_RT_ASSERT(inputs.size() >= 1); | ||
| MLLM_RT_ASSERT_EQ(outputs.size(), 1); | ||
|
|
||
| if (inputs.size() == 1) { | ||
| const size_t data_size = inputs[0].bytes(); | ||
| const void* src_data = inputs[0].ptr<void>(); | ||
| void* dst_data = outputs[0].ptr<void>(); | ||
|
|
||
| if (src_data != dst_data) { | ||
| auto ret = aclrtMemcpy(dst_data, data_size, src_data, data_size, ACL_MEMCPY_DEVICE_TO_DEVICE); | ||
| if (ret != ACL_SUCCESS) { | ||
| MLLM_ACL_CHECK(ret); | ||
| } | ||
| syncGlobalAtbStream(); | ||
| } | ||
| return; | ||
| } | ||
|
|
||
| int32_t concat_dim = options().dim; | ||
| if (concat_dim < 0) { | ||
| concat_dim += static_cast<int32_t>(inputs[0].rank()); | ||
| } | ||
|
|
||
| auto run_concat = [&](const Tensor& left, const Tensor& right, Tensor& out) { | ||
| atb::infer::ConcatParam param; | ||
| param.concatDim = concat_dim; | ||
|
|
||
| atb::Operation* op = nullptr; | ||
| auto st = atb::CreateOperation(param, &op); | ||
| if (st != atb::NO_ERROR || op == nullptr) { | ||
| MLLM_ERROR_EXIT(ExitCode::kAscendError, "ATB CreateOperation(Concat) failed, status={}", static_cast<int>(st)); | ||
| } | ||
|
|
||
| atb::Context* atb_ctx = getGlobalAtbContext(); | ||
|
|
||
| atb::SVector<atb::Tensor> inTensors; | ||
| atb::Tensor atb_left; | ||
| atb::Tensor atb_right; | ||
| fillAtbTensor(left, atb_left); | ||
| fillAtbTensor(right, atb_right); | ||
| inTensors.push_back(atb_left); | ||
| inTensors.push_back(atb_right); | ||
|
|
||
| atb::Tensor atb_out; | ||
| fillAtbTensor(out, atb_out); | ||
| atb::SVector<atb::Tensor> outTensors; | ||
| outTensors.push_back(atb_out); | ||
|
|
||
| atb::VariantPack vp; | ||
| vp.inTensors = inTensors; | ||
| vp.outTensors = outTensors; | ||
|
|
||
| uint64_t workspaceSize = 0; | ||
| st = op->Setup(vp, workspaceSize, atb_ctx); | ||
| if (st != atb::NO_ERROR) { | ||
| MLLM_ERROR_EXIT(ExitCode::kAscendError, "ATB ConcatOp Setup failed, status={}", static_cast<int>(st)); | ||
| } | ||
|
|
||
| void* workspace = nullptr; | ||
| int workspace_block_id = -1; | ||
| if (workspaceSize > 0) { | ||
| auto& mem_mgr = getAscendMemoryManager(); | ||
| mem_mgr.allocateBlock(static_cast<uint32_t>(workspaceSize), workspace_block_id); | ||
| mem_mgr.getBlockPtr(workspace_block_id, workspace); | ||
| } | ||
|
|
||
| { | ||
| ASCEND_TIME_SCOPE("AscendConcatOp::forward"); | ||
| st = op->Execute(vp, reinterpret_cast<uint8_t*>(workspace), workspaceSize, atb_ctx); | ||
| } | ||
|
|
||
| if (st != atb::NO_ERROR) { | ||
| MLLM_ERROR_EXIT(ExitCode::kAscendError, "ATB ConcatOp Execute failed, status={}", static_cast<int>(st)); | ||
| } | ||
|
|
||
| syncGlobalAtbStream(); | ||
|
|
||
| if (workspace_block_id != -1) { | ||
| auto& mem_mgr = getAscendMemoryManager(); | ||
| mem_mgr.freeBlock(workspace_block_id); | ||
| } | ||
|
|
||
| atb::DestroyOperation(op); | ||
| }; | ||
|
|
||
| std::vector<int32_t> current_shape = inputs[0].shape(); | ||
| Tensor current = inputs[0]; | ||
|
|
||
| for (size_t i = 1; i < inputs.size(); ++i) { | ||
| current_shape[concat_dim] += inputs[i].shape()[concat_dim]; | ||
|
|
||
| if (i == inputs.size() - 1) { | ||
| run_concat(current, inputs[i], outputs[0]); | ||
| } else { | ||
| Tensor temp = Tensor::empty(current_shape, outputs[0].dtype(), outputs[0].device()).alloc(); | ||
| run_concat(current, inputs[i], temp); | ||
| current = temp; | ||
|
Comment on lines
+46
to
+124
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Validate concat_dim and non-concat shapes before indexing.
✅ Suggested bounds + shape validation- int32_t concat_dim = options().dim;
- if (concat_dim < 0) {
- concat_dim += static_cast<int32_t>(inputs[0].rank());
- }
+ int32_t concat_dim = options().dim;
+ const int32_t rank = static_cast<int32_t>(inputs[0].rank());
+ if (concat_dim < 0) {
+ concat_dim += rank;
+ }
+ MLLM_RT_ASSERT(concat_dim >= 0 && concat_dim < rank);
- std::vector<int32_t> current_shape = inputs[0].shape();
+ const auto base_shape = inputs[0].shape();
+ std::vector<int32_t> current_shape = base_shape;
Tensor current = inputs[0];
for (size_t i = 1; i < inputs.size(); ++i) {
- current_shape[concat_dim] += inputs[i].shape()[concat_dim];
+ MLLM_RT_ASSERT_EQ(static_cast<int32_t>(inputs[i].rank()), rank);
+ for (int32_t d = 0; d < rank; ++d) {
+ if (d == concat_dim) continue;
+ MLLM_RT_ASSERT_EQ(inputs[i].shape()[d], base_shape[d]);
+ }
+ current_shape[concat_dim] += inputs[i].shape()[concat_dim];🤖 Prompt for AI Agents |
||
| } | ||
| } | ||
| } | ||
|
|
||
| } // namespace mllm::ascend | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,27 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Copyright (c) MLLM Team. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Licensed under the MIT License. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #pragma once | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include "mllm/core/BaseOp.hpp" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include "mllm/core/aops/ConcatOp.hpp" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include "mllm/core/OpTypes.hpp" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| namespace mllm::ascend { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class AscendConcatOp final : public aops::ConcatOp { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| public: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| explicit AscendConcatOp(const aops::ConcatOpOptions& options); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| void setup(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) override; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| void forward(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) override; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class AscendConcatOpFactory final : public TypedOpFactory<OpTypes::kConcat, aops::ConcatOpOptions> { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| public: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::shared_ptr<BaseOp> createOpImpl(const aops::ConcatOpOptions& options) override { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return std::make_shared<AscendConcatOp>(options); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+12
to
+24
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add brief API docs for AscendConcatOp and its factory. These new public classes/methods are missing doc comments; please add short Doxygen-style descriptions for purpose and parameters. 📝 Suggested doc comments-class AscendConcatOp final : public aops::ConcatOp {
+/// Ascend backend implementation of Concat.
+class AscendConcatOp final : public aops::ConcatOp {
public:
- explicit AscendConcatOp(const aops::ConcatOpOptions& options);
+ /// Constructs the op with Concat options.
+ explicit AscendConcatOp(const aops::ConcatOpOptions& options);
- void setup(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) override;
- void forward(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) override;
+ /// Validates inputs/outputs and prepares resources.
+ void setup(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) override;
+ /// Executes the concat on Ascend.
+ void forward(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) override;
};
-class AscendConcatOpFactory final : public TypedOpFactory<OpTypes::kConcat, aops::ConcatOpOptions> {
+/// Factory for creating AscendConcatOp instances.
+class AscendConcatOpFactory final : public TypedOpFactory<OpTypes::kConcat, aops::ConcatOpOptions> {
public:
+ /// Creates an AscendConcatOp for the given options.
std::shared_ptr<BaseOp> createOpImpl(const aops::ConcatOpOptions& options) override {
return std::make_shared<AscendConcatOp>(options);
}
};As per coding guidelines: Ensure public APIs, classes, and functions have clear docstrings or comments explaining purpose, parameters, returns, and errors. 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } // namespace mllm::ascend | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,136 @@ | ||
| // Copyright (c) MLLM Team. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #include "mllm/backends/ascend/ops/AscendSliceOp.hpp" | ||
|
|
||
| #include <acl/acl.h> | ||
| #include <atb/atb_infer.h> | ||
| #include <atb/types.h> | ||
| #include <atb/utils.h> | ||
| #include <atb/infer_op_params.h> | ||
|
|
||
| #include "mllm/utils/Common.hpp" | ||
| #include "mllm/core/DataTypes.hpp" | ||
| #include "mllm/core/Tensor.hpp" | ||
| #include "mllm/backends/ascend/memory/AscendMemoryManager.hpp" | ||
| #include "mllm/backends/ascend/AscendCommon.hpp" | ||
|
|
||
| namespace mllm::ascend { | ||
|
|
||
| AscendSliceOp::AscendSliceOp(const aops::SliceOpOptions& options) : aops::SliceOp(options) {} | ||
|
|
||
| void AscendSliceOp::setup(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) { | ||
| BaseOp::setup(inputs, outputs); | ||
| } | ||
|
|
||
| void AscendSliceOp::reshape(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) { | ||
| auto& input = inputs[0]; | ||
| auto shape = input.shape(); | ||
| auto slice_index = options().indices_; | ||
|
|
||
| MLLM_RT_ASSERT_EQ(slice_index.size(), shape.size()); | ||
|
|
||
| std::vector<int> out_shape; | ||
| for (size_t i = 0; i < shape.size(); ++i) { | ||
| const auto& pair = slice_index[i]; | ||
| int32_t start = pair.start_; | ||
| int32_t end = pair.end_; | ||
|
|
||
| if (start == kAll) { start = 0; } | ||
| if (end == kAll) { end = shape[i]; } | ||
|
|
||
| if (start < 0) { start = start + shape[i]; } | ||
| if (end < 0) { end = end + shape[i]; } | ||
|
|
||
| start = std::max(0, std::min(start, static_cast<int>(shape[i]))); | ||
| end = std::max(0, std::min(end, static_cast<int>(shape[i]))); | ||
|
|
||
| int len = std::max(0, end - start); | ||
| out_shape.push_back(len); | ||
| } | ||
|
|
||
| outputs.emplace_back(Tensor::empty(out_shape, input.dtype(), input.device())); | ||
| } | ||
|
|
||
| void AscendSliceOp::forward(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) { | ||
| atb::infer::SliceParam param; | ||
| auto& input = inputs[0]; | ||
| auto shape = input.shape(); | ||
| auto slice_index = options().indices_; | ||
|
|
||
| for(size_t i=0; i<shape.size(); ++i) { | ||
| int32_t start = slice_index[i].start_; | ||
| int32_t end = slice_index[i].end_; | ||
| int32_t dim_size = shape[i]; | ||
|
|
||
| if (start == kAll) start = 0; | ||
| if (end == kAll) end = dim_size; | ||
|
|
||
| if (start < 0) start += dim_size; | ||
| if (end < 0) end += dim_size; | ||
|
|
||
| start = std::max(0, std::min(start, dim_size)); | ||
| end = std::max(0, std::min(end, dim_size)); | ||
|
|
||
| param.offsets.push_back(start); | ||
| param.size.push_back(std::max(0, end - start)); | ||
| } | ||
|
Comment on lines
+55
to
+77
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add slice index size validation in
🛠️ Proposed fix void AscendSliceOp::forward(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) {
atb::infer::SliceParam param;
auto& input = inputs[0];
auto shape = input.shape();
auto slice_index = options().indices_;
+ MLLM_RT_ASSERT_EQ(slice_index.size(), shape.size());
for(size_t i=0; i<shape.size(); ++i) {🤖 Prompt for AI Agents |
||
|
|
||
| atb::Operation* op = nullptr; | ||
| auto st = atb::CreateOperation(param, &op); | ||
| if (st != atb::NO_ERROR || op == nullptr) { | ||
| MLLM_ERROR_EXIT(ExitCode::kAscendError, "ATB CreateOperation(Slice) failed, status={}", static_cast<int>(st)); | ||
| } | ||
|
|
||
| atb::Context* atb_ctx = getGlobalAtbContext(); | ||
|
|
||
| atb::SVector<atb::Tensor> inTensors; | ||
| std::vector<atb::Tensor> atb_inputs(inputs.size()); | ||
| for (size_t i = 0; i < inputs.size(); ++i) { | ||
| fillAtbTensor(inputs[i], atb_inputs[i]); | ||
| inTensors.push_back(atb_inputs[i]); | ||
| } | ||
|
|
||
| atb::Tensor atb_output; | ||
| fillAtbTensor(outputs[0], atb_output); | ||
| atb::SVector<atb::Tensor> outTensors; | ||
| outTensors.push_back(atb_output); | ||
|
|
||
| atb::VariantPack vp; | ||
| vp.inTensors = inTensors; | ||
| vp.outTensors = outTensors; | ||
|
|
||
| uint64_t workspaceSize = 0; | ||
| st = op->Setup(vp, workspaceSize, atb_ctx); | ||
| if (st != atb::NO_ERROR) { | ||
| MLLM_ERROR_EXIT(ExitCode::kAscendError, "ATB SliceOp Setup failed, status={}", static_cast<int>(st)); | ||
| } | ||
|
|
||
| void* workspace = nullptr; | ||
| int workspace_block_id = -1; | ||
| if (workspaceSize > 0) { | ||
| auto& mem_mgr = getAscendMemoryManager(); | ||
| mem_mgr.allocateBlock(static_cast<uint32_t>(workspaceSize), workspace_block_id); | ||
| mem_mgr.getBlockPtr(workspace_block_id, workspace); | ||
| } | ||
|
Comment on lines
+103
to
+115
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Guard against If ATB returns a workspace larger than 4GB, the cast will truncate and can under-allocate, risking memory corruption. 🛠️ Proposed fix+#include <limits>
@@
uint64_t workspaceSize = 0;
st = op->Setup(vp, workspaceSize, atb_ctx);
@@
void* workspace = nullptr;
int workspace_block_id = -1;
if (workspaceSize > 0) {
+ if (workspaceSize > std::numeric_limits<uint32_t>::max()) {
+ MLLM_ERROR_EXIT(ExitCode::kAscendError, "ATB SliceOp workspace too large, size={}", workspaceSize);
+ }
auto& mem_mgr = getAscendMemoryManager();
mem_mgr.allocateBlock(static_cast<uint32_t>(workspaceSize), workspace_block_id);
mem_mgr.getBlockPtr(workspace_block_id, workspace);
}🤖 Prompt for AI Agents |
||
|
|
||
| { | ||
| ASCEND_TIME_SCOPE("AscendSliceOp::forward"); | ||
| st = op->Execute(vp, reinterpret_cast<uint8_t*>(workspace), workspaceSize, atb_ctx); | ||
| } | ||
|
|
||
| if (st != atb::NO_ERROR) { | ||
| MLLM_ERROR_EXIT(ExitCode::kAscendError, "ATB SliceOp Execute failed, status={}", static_cast<int>(st)); | ||
| } | ||
|
|
||
| syncGlobalAtbStream(); | ||
|
|
||
| if (workspace_block_id != -1) { | ||
| auto& mem_mgr = getAscendMemoryManager(); | ||
| mem_mgr.freeBlock(workspace_block_id); | ||
| } | ||
|
|
||
| atb::DestroyOperation(op); | ||
| } | ||
|
|
||
| } // namespace mllm::ascend | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,28 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Copyright (c) MLLM Team. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Licensed under the MIT License. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #pragma once | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include "mllm/core/BaseOp.hpp" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include "mllm/core/aops/SliceOp.hpp" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include "mllm/core/OpTypes.hpp" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| namespace mllm::ascend { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class AscendSliceOp final : public aops::SliceOp { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| public: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| explicit AscendSliceOp(const aops::SliceOpOptions& options); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| void setup(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) override; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| void reshape(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) override; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| void forward(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) override; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class AscendSliceOpFactory final : public TypedOpFactory<OpTypes::kSlice, aops::SliceOpOptions> { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| public: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::shared_ptr<BaseOp> createOpImpl(const aops::SliceOpOptions& options) override { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return std::make_shared<AscendSliceOp>(options); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+12
to
+25
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add brief API docs for AscendSliceOp and its factory. These new public classes/methods are missing doc comments; please add short Doxygen-style descriptions for purpose and parameters. 📝 Suggested doc comments-class AscendSliceOp final : public aops::SliceOp {
+/// Ascend backend implementation of Slice.
+class AscendSliceOp final : public aops::SliceOp {
public:
- explicit AscendSliceOp(const aops::SliceOpOptions& options);
+ /// Constructs the op with Slice options.
+ explicit AscendSliceOp(const aops::SliceOpOptions& options);
- void setup(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) override;
- void reshape(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) override;
- void forward(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) override;
+ /// Validates inputs/outputs and prepares resources.
+ void setup(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) override;
+ /// Infers output shapes from inputs.
+ void reshape(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) override;
+ /// Executes the slice on Ascend.
+ void forward(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) override;
};
-class AscendSliceOpFactory final : public TypedOpFactory<OpTypes::kSlice, aops::SliceOpOptions> {
+/// Factory for creating AscendSliceOp instances.
+class AscendSliceOpFactory final : public TypedOpFactory<OpTypes::kSlice, aops::SliceOpOptions> {
public:
+ /// Creates an AscendSliceOp for the given options.
std::shared_ptr<BaseOp> createOpImpl(const aops::SliceOpOptions& options) override {
return std::make_shared<AscendSliceOp>(options);
}
};As per coding guidelines: Ensure public APIs, classes, and functions have clear docstrings or comments explaining purpose, parameters, returns, and errors. 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } // namespace mllm::ascend | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,41 @@ | ||
| // Copyright (c) MLLM Team. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #pragma once | ||
|
|
||
| #include "mllm/mllm.hpp" | ||
| #include "mllm/core/Tensor.hpp" | ||
| #include "mllm/nn/Functional.hpp" | ||
| #include "KernelTestHelper.hpp" // Has KernelTest base class | ||
|
|
||
| class AscendConcatKernelTest : public KernelTest { | ||
| public: | ||
| bool ConcatFloat16Test(const std::vector<mllm::Tensor::shape_t>& input_shapes, int dim) { | ||
| using namespace mllm; | ||
|
|
||
| std::vector<Tensor> inputs_cpu; | ||
| for (const auto& shape : input_shapes) { | ||
| inputs_cpu.push_back(Tensor::random(shape, -1.0, 1.0, kFloat16, kCPU)); | ||
| } | ||
|
|
||
| // CPU Reference | ||
| auto out_cpu = nn::functional::concat(inputs_cpu, dim); | ||
|
|
||
| // Ascend | ||
| std::vector<Tensor> inputs_ascend; | ||
| for (auto& t : inputs_cpu) { | ||
| inputs_ascend.push_back(t.to(kAscend)); | ||
| } | ||
|
|
||
| auto out_ascend = nn::functional::concat(inputs_ascend, dim); | ||
| auto out_back = out_ascend.to(kCPU); | ||
|
|
||
| auto result = test::allClose(out_back, out_cpu, 1e-2, 1e-2); | ||
| if (!result.is_close) { | ||
| std::cout << "[ConcatTest] FAILED! dim=" << dim << std::endl; | ||
| return false; | ||
| } | ||
| std::cout << "[ConcatTest] PASSED dim=" << dim << std::endl; | ||
| return true; | ||
| } | ||
| }; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
Repository: UbiquitousLearning/mllm
Length of output: 5388
Add overflow guard for workspace size cast to uint32_t.
workspaceSize(uint64_t) is cast touint32_tforallocateBlock(). If ATB returns a workspace > 4 GiB, silent truncation will cause under-allocation and potential memory corruption. Add a bounds check before casting:int workspace_block_id = -1; if (workspaceSize > 0) { + if (workspaceSize > std::numeric_limits<uint32_t>::max()) { + MLLM_ERROR_EXIT(ExitCode::kAscendError, + "ATB ConcatOp workspace size exceeds uint32_t: {}", + workspaceSize); + } auto& mem_mgr = getAscendMemoryManager(); mem_mgr.allocateBlock(static_cast<uint32_t>(workspaceSize), workspace_block_id); mem_mgr.getBlockPtr(workspace_block_id, workspace); }Also include
#include <limits>at the top.🤖 Prompt for AI Agents