Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions mllm/backends/ascend/AscendBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,18 @@

#include "mllm/backends/ascend/ops/AscendElewiseOps.hpp"
#include "mllm/backends/ascend/ops/AscendX2XOp.hpp"
#include "mllm/backends/ascend/ops/AscendSiLUOp.hpp"
#include "mllm/backends/ascend/ops/AscendLinearOp.hpp"
#include "mllm/backends/ascend/ops/AscendRMSNormOp.hpp"
#include "mllm/backends/ascend/ops/AscendViewOp.hpp"
#include "mllm/backends/ascend/ops/AscendMatMulOp.hpp"
#include "mllm/backends/ascend/ops/AscendSoftmaxOp.hpp"

namespace mllm::ascend {

AscendBackend::AscendBackend() : Backend(kAscend, createAscendAllocator()) {
regOpFactory<AscendAddOpFactory>();
regOpFactory<AscendX2XOpFactory>();
regOpFactory<AscendAddOpFactory,AscendSubOpFactory,AscendMulOpFactory,AscendX2XOpFactory,AscendSiLUOpFactory,
AscendLinearOpFactory,AscendRMSNormOpFactory,AscendViewOpFactory,AscendMatMulOpFactory,AscendSoftmaxOpFactory>();
auto& devices = AscendDeviceMetaInfo::instance().devices;
for (const auto& device : devices) {
const auto bytes_to_mb = [](size_t bytes) { return bytes / (1024.0 * 1024.0); };
Expand Down
10 changes: 7 additions & 3 deletions mllm/backends/ascend/AscendCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,13 @@ void fillAtbTensorDesc(const Tensor& t, atb::TensorDesc& desc) {
}
}

void fillAtbTensor(const Tensor& t, atb::Tensor& atb_tensor) {
fillAtbTensorDesc(t, atb_tensor.desc);
atb_tensor.deviceData = reinterpret_cast<uint8_t*>(t.ptr<void>());
// Use MLLM tensor's actual bytes as dataSize to match allocated memory
atb_tensor.dataSize = t.bytes();
}
Comment on lines +220 to +225
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

cat -n mllm/backends/ascend/AscendCommon.cpp | sed -n '215,230p'

Repository: UbiquitousLearning/mllm

Length of output: 721


🏁 Script executed:

# Search for fillAtbTensorDesc definition to understand what it does
rg -n "fillAtbTensorDesc" mllm/backends/ascend/ -A 5

Repository: UbiquitousLearning/mllm

Length of output: 4216


🏁 Script executed:

# Find all call sites of fillAtbTensor
rg -n "fillAtbTensor\(" mllm/backends/ascend/ -B 2 -A 2

Repository: UbiquitousLearning/mllm

Length of output: 2581


🏁 Script executed:

# Check if there's a Tensor class definition to understand dtype() and device() methods
rg -n "class Tensor" -t h -t hpp

Repository: UbiquitousLearning/mllm

Length of output: 94


🏁 Script executed:

# Search for Tensor class definition
rg -n "class Tensor" -t c -t h --max-count=5

Repository: UbiquitousLearning/mllm

Length of output: 814


🏁 Script executed:

# Search for MLLM_RT_ASSERT macro definition
rg -n "MLLM_RT_ASSERT" -t c -t h --max-count=5

Repository: UbiquitousLearning/mllm

Length of output: 9863


🏁 Script executed:

# Check if Tensor has device() and dtype() methods
rg -n "\.device\(\)|\.dtype\(\)" mllm/backends/ascend/ -B 1 -A 1 | head -30

Repository: UbiquitousLearning/mllm

Length of output: 2524


🏁 Script executed:

# Look for Tensor definition more broadly
fd "Tensor" mllm --type f -e h | head -5

Repository: UbiquitousLearning/mllm

Length of output: 49


🏁 Script executed:

# Check Tensor class definition for device() and dtype() methods
cat -n mllm/core/Tensor.hpp | sed -n '53,150p'

Repository: UbiquitousLearning/mllm

Length of output: 3824


🏁 Script executed:

# Look at one of the call sites more closely to see if there's validation before fillAtbTensor
cat -n mllm/backends/ascend/ops/AscendSoftmaxOp.cpp | sed -n '30,85p'

Repository: UbiquitousLearning/mllm

Length of output: 2143


🏁 Script executed:

# Check AscendRMSNormOp for dtype/device validation before fillAtbTensor
cat -n mllm/backends/ascend/ops/AscendRMSNormOp.cpp | sed -n '50,70p'

Repository: UbiquitousLearning/mllm

Length of output: 830


🏁 Script executed:

# Check AscendLinearOp for dtype validation before fillAtbTensor
cat -n mllm/backends/ascend/ops/AscendLinearOp.cpp | sed -n '100,125p'

Repository: UbiquitousLearning/mllm

Length of output: 989


🏁 Script executed:

# Check AscendRMSNormOp for any validation
cat -n mllm/backends/ascend/ops/AscendRMSNormOp.cpp | sed -n '1,70p'

Repository: UbiquitousLearning/mllm

Length of output: 2800


🏁 Script executed:

# Search for dtype() and device() method definitions in Tensor class
cat -n mllm/core/Tensor.hpp | grep -A 1 "dtype()\|device()" | head -20

Repository: UbiquitousLearning/mllm

Length of output: 281


🏁 Script executed:

# Check AscendLinearOp more thoroughly for dtype validation
cat -n mllm/backends/ascend/ops/AscendLinearOp.cpp | sed -n '1,100p'

Repository: UbiquitousLearning/mllm

Length of output: 4524


🏁 Script executed:

# Check what MLLM_TYPE_F16 is and its relationship with kFloat16
rg -n "MLLM_TYPE_F16|kFloat16" mllm/core/DataTypes.hpp -A 1 | head -20

Repository: UbiquitousLearning/mllm

Length of output: 354


Add input validation to fillAtbTensor.

fillAtbTensor hardcodes FP16 in the descriptor and assumes a device pointer without validating the input tensor. While some call sites (SoftmaxOp, LinearOp) validate dtype before calling this function, others (RMSNormOp) do not. Add runtime checks to ensure tensors are FP16 and on the Ascend device, providing defense-in-depth against invalid tensor usage.

Suggested validation
 void fillAtbTensor(const Tensor& t, atb::Tensor& atb_tensor) {
+  MLLM_RT_ASSERT(t.device() == kAscend);
+  MLLM_RT_ASSERT(t.dtype() == MLLM_TYPE_F16);
+  MLLM_RT_ASSERT(!t.isNil());
   fillAtbTensorDesc(t, atb_tensor.desc);
-  atb_tensor.deviceData = reinterpret_cast<uint8_t*>(t.ptr<void>());
+  auto* ptr = t.ptr<void>();
+  MLLM_RT_ASSERT(ptr != nullptr);
+  atb_tensor.deviceData = reinterpret_cast<uint8_t*>(ptr);
   // Use MLLM tensor's actual bytes as dataSize to match allocated memory
   atb_tensor.dataSize = t.bytes();
 }
🤖 Prompt for AI Agents
In `@mllm/backends/ascend/AscendCommon.cpp` around lines 220 - 225, Add defensive
runtime validation at the start of fillAtbTensor: verify the input Tensor t has
FP16 dtype and is on the Ascend device (e.g., check t.dtype() == DType::FP16 and
t.device/type indicates Ascend), and ensure t.ptr<void>() is non-null; if any
check fails, throw or return a clear error (e.g., throw std::invalid_argument
with a message identifying fillAtbTensor), then proceed to call
fillAtbTensorDesc(t, atb_tensor.desc) and set atb_tensor.deviceData =
reinterpret_cast<uint8_t*>(t.ptr<void>()) and atb_tensor.dataSize = t.bytes();
this adds defense-in-depth for callers like RMSNormOp that may pass invalid
tensors.


AscendDeviceMetaInfo::AscendDeviceMetaInfo() {
#ifndef ASCENDC_CPU_DEBUG
// Initialize ACL to query devices
Expand All @@ -231,7 +238,6 @@ AscendDeviceMetaInfo::AscendDeviceMetaInfo() {
ret = aclrtGetDeviceCount(&device_count);
if (ret != ACL_SUCCESS) {
MLLM_ERROR("Failed to get Ascend device count: {}", ret);
aclFinalize();
return;
}

Expand Down Expand Up @@ -266,8 +272,6 @@ AscendDeviceMetaInfo::AscendDeviceMetaInfo() {
devices.push_back(info);
}

// Finalize ACL after enumeration
aclFinalize();
#else
// In CPU debug mode, add a dummy device
AscendDeviceInfo info;
Expand Down
3 changes: 3 additions & 0 deletions mllm/backends/ascend/AscendCommon.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ void syncGlobalAtbStream();
// Convert MLLM Tensor metadata to ATB TensorDesc
void fillAtbTensorDesc(const Tensor& t, atb::TensorDesc& desc);

// Fill an ATB Tensor (desc, device pointer, dataSize) from an MLLM Tensor;
// dataSize is taken from the MLLM tensor's allocated byte count
void fillAtbTensor(const Tensor& t, atb::Tensor& atb_tensor);

// Ascend device information structure
struct AscendDeviceInfo {
std::string name;
Expand Down
173 changes: 170 additions & 3 deletions mllm/backends/ascend/ops/AscendElewiseOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@ void AscendAddOp::forward(const std::vector<Tensor>& inputs, std::vector<Tensor>
if (x.dtype() != y.dtype() || x.dtype() != z.dtype()) {
NYI("AscendAddOp currently requires x/y/z have same dtype");
}
if (x.numel() != y.numel() || x.numel() != z.numel()) {
NYI("AscendAddOp demo only supports no-broadcast case (numel equal)");
}

atb::infer::ElewiseParam addParam;
addParam.elewiseType = atb::infer::ElewiseParam::ELEWISE_ADD;
Expand Down Expand Up @@ -106,4 +103,174 @@ void AscendAddOp::forward(const std::vector<Tensor>& inputs, std::vector<Tensor>
atb::DestroyOperation(op);
}

// Forwards the options to the generic aops::SubOp base; all Ascend/ATB-specific
// work happens in setup()/forward().
AscendSubOp::AscendSubOp(const aops::SubOpOptions& options) : aops::SubOp(options) {}

// Delegates output shape/dtype inference and allocation to the framework's
// default BaseOp::setup; no Ascend-specific setup is needed.
void AscendSubOp::setup(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) {
BaseOp::setup(inputs, outputs);
}

void AscendSubOp::forward(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) {
MLLM_RT_ASSERT_EQ(inputs.size(), 2);
MLLM_RT_ASSERT_EQ(outputs.size(), 1);

const auto& x = inputs[0];
const auto& y = inputs[1];
auto& z = outputs[0];

if (x.dtype() != y.dtype() || x.dtype() != z.dtype()) {
NYI("AscendSubOp currently requires x/y/z have same dtype");
}

atb::infer::ElewiseParam subParam;
subParam.elewiseType = atb::infer::ElewiseParam::ELEWISE_SUB;

atb::Operation* op = nullptr;
auto st = atb::CreateOperation(subParam, &op);
if (st != atb::NO_ERROR || op == nullptr) {
MLLM_ERROR_EXIT(ExitCode::kAscendError, "ATB CreateOperation(ELEWISE_SUB) failed, status={}", static_cast<int>(st));
}

atb::Context* atb_ctx = getGlobalAtbContext();

atb::Tensor atb_x;
atb::Tensor atb_y;
atb::Tensor atb_z;

fillAtbTensorDesc(x, atb_x.desc);
fillAtbTensorDesc(y, atb_y.desc);
fillAtbTensorDesc(z, atb_z.desc);

atb_x.deviceData = reinterpret_cast<uint8_t*>(x.ptr<void>());
atb_x.dataSize = x.bytes();
atb_y.deviceData = reinterpret_cast<uint8_t*>(y.ptr<void>());
atb_y.dataSize = y.bytes();
atb_z.deviceData = reinterpret_cast<uint8_t*>(z.ptr<void>());
atb_z.dataSize = z.bytes();

atb::SVector<atb::Tensor> inTensors;
atb::SVector<atb::Tensor> outTensors;
inTensors.push_back(atb_x);
inTensors.push_back(atb_y);
outTensors.push_back(atb_z);

atb::VariantPack vp;
vp.inTensors = inTensors;
vp.outTensors = outTensors;

uint64_t workspaceSize = 0;
st = op->Setup(vp, workspaceSize, atb_ctx);
if (st != atb::NO_ERROR) {
MLLM_ERROR_EXIT(ExitCode::kAscendError, "ATB SubOp Setup failed, status={}", static_cast<int>(st));
}

void* workspace = nullptr;
int workspace_block_id = -1;
if (workspaceSize > 0) {
auto& mem_mgr = getAscendMemoryManager();
mem_mgr.allocateBlock(static_cast<uint32_t>(workspaceSize), workspace_block_id);
mem_mgr.getBlockPtr(workspace_block_id, workspace);
}
Comment on lines +168 to +172
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Same workspaceSize truncation concern as in AscendMatMulOp.

Both AscendSubOp and AscendMulOp cast uint64_t workspaceSize to uint32_t when calling allocateBlock. Consider adding the same guard suggested for AscendMatMulOp for consistency across all Ascend ops.

Also applies to: 253-257

🤖 Prompt for AI Agents
In `@mllm/backends/ascend/ops/AscendElewiseOps.cpp` around lines 168 - 172,
AscendSubOp and AscendMulOp are truncating uint64_t workspaceSize to uint32_t
when calling getAscendMemoryManager().allocateBlock; add the same guard used in
AscendMatMulOp to check if workspaceSize > UINT32_MAX (or otherwise unsafe to
cast) before allocating, and handle the error path (log/throw or set an error
status) instead of silently truncating; update the allocation calls to cast only
after the guard and reference workspace_block_id, getAscendMemoryManager(), and
allocateBlock to locate the change.

{
ASCEND_TIME_SCOPE("AscendSubOp::forward");
st = op->Execute(vp, reinterpret_cast<uint8_t*>(workspace), workspaceSize, atb_ctx);
}
if (st != atb::NO_ERROR) {
MLLM_ERROR_EXIT(ExitCode::kAscendError, "ATB SubOp Execute failed, status={}", static_cast<int>(st));
}

syncGlobalAtbStream();

if (workspace_block_id != -1) {
auto& mem_mgr = getAscendMemoryManager();
mem_mgr.freeBlock(workspace_block_id);
}

atb::DestroyOperation(op);
}

// Forwards the options to the generic aops::MulOp base; all Ascend/ATB-specific
// work happens in setup()/forward().
AscendMulOp::AscendMulOp(const aops::MulOpOptions& options) : aops::MulOp(options) {}

// Delegates output shape/dtype inference and allocation to the framework's
// default BaseOp::setup; no Ascend-specific setup is needed.
void AscendMulOp::setup(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) {
BaseOp::setup(inputs, outputs);
}

void AscendMulOp::forward(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) {
MLLM_RT_ASSERT_EQ(inputs.size(), 2);
MLLM_RT_ASSERT_EQ(outputs.size(), 1);

const auto& x = inputs[0];
const auto& y = inputs[1];
auto& z = outputs[0];

if (x.dtype() != y.dtype() || x.dtype() != z.dtype()) {
NYI("AscendMulOp currently requires x/y/z have same dtype");
}

atb::infer::ElewiseParam mulParam;
mulParam.elewiseType = atb::infer::ElewiseParam::ELEWISE_MUL;

atb::Operation* op = nullptr;
auto st = atb::CreateOperation(mulParam, &op);
if (st != atb::NO_ERROR || op == nullptr) {
MLLM_ERROR_EXIT(ExitCode::kAscendError, "ATB CreateOperation(ELEWISE_MUL) failed, status={}", static_cast<int>(st));
}

atb::Context* atb_ctx = getGlobalAtbContext();

atb::Tensor atb_x;
atb::Tensor atb_y;
atb::Tensor atb_z;

fillAtbTensorDesc(x, atb_x.desc);
fillAtbTensorDesc(y, atb_y.desc);
fillAtbTensorDesc(z, atb_z.desc);

atb_x.deviceData = reinterpret_cast<uint8_t*>(x.ptr<void>());
atb_x.dataSize = x.bytes();
atb_y.deviceData = reinterpret_cast<uint8_t*>(y.ptr<void>());
atb_y.dataSize = y.bytes();
atb_z.deviceData = reinterpret_cast<uint8_t*>(z.ptr<void>());
atb_z.dataSize = z.bytes();

atb::SVector<atb::Tensor> inTensors;
atb::SVector<atb::Tensor> outTensors;
inTensors.push_back(atb_x);
inTensors.push_back(atb_y);
outTensors.push_back(atb_z);

atb::VariantPack vp;
vp.inTensors = inTensors;
vp.outTensors = outTensors;

uint64_t workspaceSize = 0;
st = op->Setup(vp, workspaceSize, atb_ctx);
if (st != atb::NO_ERROR) {
MLLM_ERROR_EXIT(ExitCode::kAscendError, "ATB MulOp Setup failed, status={}", static_cast<int>(st));
}

void* workspace = nullptr;
int workspace_block_id = -1;
if (workspaceSize > 0) {
auto& mem_mgr = getAscendMemoryManager();
mem_mgr.allocateBlock(static_cast<uint32_t>(workspaceSize), workspace_block_id);
mem_mgr.getBlockPtr(workspace_block_id, workspace);
}
{
ASCEND_TIME_SCOPE("AscendMulOp::forward");
st = op->Execute(vp, reinterpret_cast<uint8_t*>(workspace), workspaceSize, atb_ctx);
}
if (st != atb::NO_ERROR) {
MLLM_ERROR_EXIT(ExitCode::kAscendError, "ATB MulOp Execute failed, status={}", static_cast<int>(st));
}

syncGlobalAtbStream();

if (workspace_block_id != -1) {
auto& mem_mgr = getAscendMemoryManager();
mem_mgr.freeBlock(workspace_block_id);
}

atb::DestroyOperation(op);
}

} // namespace mllm::ascend
30 changes: 30 additions & 0 deletions mllm/backends/ascend/ops/AscendElewiseOps.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,34 @@ class AscendAddOpFactory final : public TypedOpFactory<OpTypes::kAdd, aops::AddO
}
};

// Element-wise subtraction op for the Ascend backend, implemented on top of
// the generic aops::SubOp and dispatched to an ATB kernel in forward().
class AscendSubOp final : public aops::SubOp {
public:
explicit AscendSubOp(const aops::SubOpOptions& options);

void setup(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) override;
void forward(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) override;
};

// Factory registered with the Ascend backend to construct AscendSubOp
// instances for OpTypes::kSub.
class AscendSubOpFactory final : public TypedOpFactory<OpTypes::kSub, aops::SubOpOptions> {
public:
std::shared_ptr<BaseOp> createOpImpl(const aops::SubOpOptions& options) override {
return std::make_shared<AscendSubOp>(options);
}
};

// Element-wise multiplication op for the Ascend backend, implemented on top of
// the generic aops::MulOp and dispatched to an ATB kernel in forward().
class AscendMulOp final : public aops::MulOp {
public:
explicit AscendMulOp(const aops::MulOpOptions& options);

void setup(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) override;
void forward(const std::vector<Tensor>& inputs, std::vector<Tensor>& outputs) override;
};

// Factory registered with the Ascend backend to construct AscendMulOp
// instances for OpTypes::kMul.
class AscendMulOpFactory final : public TypedOpFactory<OpTypes::kMul, aops::MulOpOptions> {
public:
std::shared_ptr<BaseOp> createOpImpl(const aops::MulOpOptions& options) override {
return std::make_shared<AscendMulOp>(options);
}
};

} // namespace mllm::ascend
Loading