-
Notifications
You must be signed in to change notification settings - Fork 175
feat(Ascend): Add some new Ascend Ops #621
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -34,9 +34,6 @@ void AscendAddOp::forward(const std::vector<Tensor>& inputs, std::vector<Tensor> | |
| if (x.dtype() != y.dtype() || x.dtype() != z.dtype()) { | ||
| NYI("AscendAddOp currently requires x/y/z have same dtype"); | ||
| } | ||
| if (x.numel() != y.numel() || x.numel() != z.numel()) { | ||
| NYI("AscendAddOp demo only supports no-broadcast case (numel equal)"); | ||
| } | ||
|
|
||
| atb::infer::ElewiseParam addParam; | ||
| addParam.elewiseType = atb::infer::ElewiseParam::ELEWISE_ADD; | ||
|
|
@@ -106,4 +103,174 @@ void AscendAddOp::forward(const std::vector<Tensor>& inputs, std::vector<Tensor> | |
| atb::DestroyOperation(op); | ||
| } | ||
|
|
||
// Construct the Ascend subtraction op; all configuration lives in the shared
// aops::SubOpOptions, which is forwarded unchanged to the base class.
AscendSubOp::AscendSubOp(const aops::SubOpOptions& options) : aops::SubOp(options) {}
|
|
||
// Setup delegates entirely to BaseOp::setup — presumably output shape/metadata
// preparation; see BaseOp for the actual contract (TODO confirm against BaseOp).
void AscendSubOp::setup(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) {
  BaseOp::setup(inputs, outputs);
}
|
|
||
| void AscendSubOp::forward(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) { | ||
| MLLM_RT_ASSERT_EQ(inputs.size(), 2); | ||
| MLLM_RT_ASSERT_EQ(outputs.size(), 1); | ||
|
|
||
| const auto& x = inputs[0]; | ||
| const auto& y = inputs[1]; | ||
| auto& z = outputs[0]; | ||
|
|
||
| if (x.dtype() != y.dtype() || x.dtype() != z.dtype()) { | ||
| NYI("AscendSubOp currently requires x/y/z have same dtype"); | ||
| } | ||
|
|
||
| atb::infer::ElewiseParam subParam; | ||
| subParam.elewiseType = atb::infer::ElewiseParam::ELEWISE_SUB; | ||
|
|
||
| atb::Operation* op = nullptr; | ||
| auto st = atb::CreateOperation(subParam, &op); | ||
| if (st != atb::NO_ERROR || op == nullptr) { | ||
| MLLM_ERROR_EXIT(ExitCode::kAscendError, "ATB CreateOperation(ELEWISE_SUB) failed, status={}", static_cast<int>(st)); | ||
| } | ||
|
|
||
| atb::Context* atb_ctx = getGlobalAtbContext(); | ||
|
|
||
| atb::Tensor atb_x; | ||
| atb::Tensor atb_y; | ||
| atb::Tensor atb_z; | ||
|
|
||
| fillAtbTensorDesc(x, atb_x.desc); | ||
| fillAtbTensorDesc(y, atb_y.desc); | ||
| fillAtbTensorDesc(z, atb_z.desc); | ||
|
|
||
| atb_x.deviceData = reinterpret_cast<uint8_t*>(x.ptr<void>()); | ||
| atb_x.dataSize = x.bytes(); | ||
| atb_y.deviceData = reinterpret_cast<uint8_t*>(y.ptr<void>()); | ||
| atb_y.dataSize = y.bytes(); | ||
| atb_z.deviceData = reinterpret_cast<uint8_t*>(z.ptr<void>()); | ||
| atb_z.dataSize = z.bytes(); | ||
|
|
||
| atb::SVector<atb::Tensor> inTensors; | ||
| atb::SVector<atb::Tensor> outTensors; | ||
| inTensors.push_back(atb_x); | ||
| inTensors.push_back(atb_y); | ||
| outTensors.push_back(atb_z); | ||
|
|
||
| atb::VariantPack vp; | ||
| vp.inTensors = inTensors; | ||
| vp.outTensors = outTensors; | ||
|
|
||
| uint64_t workspaceSize = 0; | ||
| st = op->Setup(vp, workspaceSize, atb_ctx); | ||
| if (st != atb::NO_ERROR) { | ||
| MLLM_ERROR_EXIT(ExitCode::kAscendError, "ATB SubOp Setup failed, status={}", static_cast<int>(st)); | ||
| } | ||
|
|
||
| void* workspace = nullptr; | ||
| int workspace_block_id = -1; | ||
| if (workspaceSize > 0) { | ||
| auto& mem_mgr = getAscendMemoryManager(); | ||
| mem_mgr.allocateBlock(static_cast<uint32_t>(workspaceSize), workspace_block_id); | ||
| mem_mgr.getBlockPtr(workspace_block_id, workspace); | ||
| } | ||
|
Comment on lines
+168
to
+172
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same Both Also applies to: 253-257 🤖 Prompt for AI Agents |
||
| { | ||
| ASCEND_TIME_SCOPE("AscendSubOp::forward"); | ||
| st = op->Execute(vp, reinterpret_cast<uint8_t*>(workspace), workspaceSize, atb_ctx); | ||
| } | ||
| if (st != atb::NO_ERROR) { | ||
| MLLM_ERROR_EXIT(ExitCode::kAscendError, "ATB SubOp Execute failed, status={}", static_cast<int>(st)); | ||
| } | ||
|
|
||
| syncGlobalAtbStream(); | ||
|
|
||
| if (workspace_block_id != -1) { | ||
| auto& mem_mgr = getAscendMemoryManager(); | ||
| mem_mgr.freeBlock(workspace_block_id); | ||
| } | ||
|
|
||
| atb::DestroyOperation(op); | ||
| } | ||
|
|
||
// Construct the Ascend multiplication op; all configuration lives in the
// shared aops::MulOpOptions, which is forwarded unchanged to the base class.
AscendMulOp::AscendMulOp(const aops::MulOpOptions& options) : aops::MulOp(options) {}
|
|
||
// Setup delegates entirely to BaseOp::setup — presumably output shape/metadata
// preparation; see BaseOp for the actual contract (TODO confirm against BaseOp).
void AscendMulOp::setup(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) {
  BaseOp::setup(inputs, outputs);
}
|
|
||
| void AscendMulOp::forward(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) { | ||
| MLLM_RT_ASSERT_EQ(inputs.size(), 2); | ||
| MLLM_RT_ASSERT_EQ(outputs.size(), 1); | ||
|
|
||
| const auto& x = inputs[0]; | ||
| const auto& y = inputs[1]; | ||
| auto& z = outputs[0]; | ||
|
|
||
| if (x.dtype() != y.dtype() || x.dtype() != z.dtype()) { | ||
| NYI("AscendMulOp currently requires x/y/z have same dtype"); | ||
| } | ||
|
|
||
| atb::infer::ElewiseParam mulParam; | ||
| mulParam.elewiseType = atb::infer::ElewiseParam::ELEWISE_MUL; | ||
|
|
||
| atb::Operation* op = nullptr; | ||
| auto st = atb::CreateOperation(mulParam, &op); | ||
| if (st != atb::NO_ERROR || op == nullptr) { | ||
| MLLM_ERROR_EXIT(ExitCode::kAscendError, "ATB CreateOperation(ELEWISE_MUL) failed, status={}", static_cast<int>(st)); | ||
| } | ||
|
|
||
| atb::Context* atb_ctx = getGlobalAtbContext(); | ||
|
|
||
| atb::Tensor atb_x; | ||
| atb::Tensor atb_y; | ||
| atb::Tensor atb_z; | ||
|
|
||
| fillAtbTensorDesc(x, atb_x.desc); | ||
| fillAtbTensorDesc(y, atb_y.desc); | ||
| fillAtbTensorDesc(z, atb_z.desc); | ||
|
|
||
| atb_x.deviceData = reinterpret_cast<uint8_t*>(x.ptr<void>()); | ||
| atb_x.dataSize = x.bytes(); | ||
| atb_y.deviceData = reinterpret_cast<uint8_t*>(y.ptr<void>()); | ||
| atb_y.dataSize = y.bytes(); | ||
| atb_z.deviceData = reinterpret_cast<uint8_t*>(z.ptr<void>()); | ||
| atb_z.dataSize = z.bytes(); | ||
|
|
||
| atb::SVector<atb::Tensor> inTensors; | ||
| atb::SVector<atb::Tensor> outTensors; | ||
| inTensors.push_back(atb_x); | ||
| inTensors.push_back(atb_y); | ||
| outTensors.push_back(atb_z); | ||
|
|
||
| atb::VariantPack vp; | ||
| vp.inTensors = inTensors; | ||
| vp.outTensors = outTensors; | ||
|
|
||
| uint64_t workspaceSize = 0; | ||
| st = op->Setup(vp, workspaceSize, atb_ctx); | ||
| if (st != atb::NO_ERROR) { | ||
| MLLM_ERROR_EXIT(ExitCode::kAscendError, "ATB MulOp Setup failed, status={}", static_cast<int>(st)); | ||
| } | ||
|
|
||
| void* workspace = nullptr; | ||
| int workspace_block_id = -1; | ||
| if (workspaceSize > 0) { | ||
| auto& mem_mgr = getAscendMemoryManager(); | ||
| mem_mgr.allocateBlock(static_cast<uint32_t>(workspaceSize), workspace_block_id); | ||
| mem_mgr.getBlockPtr(workspace_block_id, workspace); | ||
| } | ||
| { | ||
| ASCEND_TIME_SCOPE("AscendMulOp::forward"); | ||
| st = op->Execute(vp, reinterpret_cast<uint8_t*>(workspace), workspaceSize, atb_ctx); | ||
| } | ||
| if (st != atb::NO_ERROR) { | ||
| MLLM_ERROR_EXIT(ExitCode::kAscendError, "ATB MulOp Execute failed, status={}", static_cast<int>(st)); | ||
| } | ||
|
|
||
| syncGlobalAtbStream(); | ||
|
|
||
| if (workspace_block_id != -1) { | ||
| auto& mem_mgr = getAscendMemoryManager(); | ||
| mem_mgr.freeBlock(workspace_block_id); | ||
| } | ||
|
|
||
| atb::DestroyOperation(op); | ||
| } | ||
|
|
||
| } // namespace mllm::ascend | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
Repository: UbiquitousLearning/mllm
Length of output: 721
🏁 Script executed:
Repository: UbiquitousLearning/mllm
Length of output: 4216
🏁 Script executed:
Repository: UbiquitousLearning/mllm
Length of output: 2581
🏁 Script executed:
Repository: UbiquitousLearning/mllm
Length of output: 94
🏁 Script executed:
Repository: UbiquitousLearning/mllm
Length of output: 814
🏁 Script executed:
Repository: UbiquitousLearning/mllm
Length of output: 9863
🏁 Script executed:
Repository: UbiquitousLearning/mllm
Length of output: 2524
🏁 Script executed:
Repository: UbiquitousLearning/mllm
Length of output: 49
🏁 Script executed:
Repository: UbiquitousLearning/mllm
Length of output: 3824
🏁 Script executed:
Repository: UbiquitousLearning/mllm
Length of output: 2143
🏁 Script executed:
Repository: UbiquitousLearning/mllm
Length of output: 830
🏁 Script executed:
Repository: UbiquitousLearning/mllm
Length of output: 989
🏁 Script executed:
Repository: UbiquitousLearning/mllm
Length of output: 2800
🏁 Script executed:
Repository: UbiquitousLearning/mllm
Length of output: 281
🏁 Script executed:
Repository: UbiquitousLearning/mllm
Length of output: 4524
🏁 Script executed:
Repository: UbiquitousLearning/mllm
Length of output: 354
Add input validation to `fillAtbTensor`. `fillAtbTensor` hardcodes FP16 in the descriptor and assumes a device pointer without validating the input tensor. While some call sites (SoftmaxOp, LinearOp) validate dtype before calling this function, others (RMSNormOp) do not. Add runtime checks to ensure tensors are FP16 and on the Ascend device, providing defense-in-depth against invalid tensor usage.

Suggested validation:
void fillAtbTensor(const Tensor& t, atb::Tensor& atb_tensor) { + MLLM_RT_ASSERT(t.device() == kAscend); + MLLM_RT_ASSERT(t.dtype() == MLLM_TYPE_F16); + MLLM_RT_ASSERT(!t.isNil()); fillAtbTensorDesc(t, atb_tensor.desc); - atb_tensor.deviceData = reinterpret_cast<uint8_t*>(t.ptr<void>()); + auto* ptr = t.ptr<void>(); + MLLM_RT_ASSERT(ptr != nullptr); + atb_tensor.deviceData = reinterpret_cast<uint8_t*>(ptr); // Use MLLM tensor's actual bytes as dataSize to match allocated memory atb_tensor.dataSize = t.bytes(); }🤖 Prompt for AI Agents