-
Notifications
You must be signed in to change notification settings - Fork 25
feat(codegen): Add 910B PTO backend op support for paged attention #195
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -112,7 +112,19 @@ static std::string MakeTernaryTileTileCodegenPTO(const std::string& pto_op_name, | |||||||||||||||||||||||||||||||||||||||||||
| return ""; | ||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| // Helper function for binary Tile-Scalar operations | ||||||||||||||||||||||||||||||||||||||||||||
| // Helper function for full op | ||||||||||||||||||||||||||||||||||||||||||||
| static std::string MakeFullCodegenPTO(const std::string& pto_op_name, const CallPtr& op, | ||||||||||||||||||||||||||||||||||||||||||||
| codegen::CodegenBase& codegen_base) { | ||||||||||||||||||||||||||||||||||||||||||||
| auto& codegen = dynamic_cast<codegen::PTOCodegen&>(codegen_base); | ||||||||||||||||||||||||||||||||||||||||||||
| CHECK(op->args_.size() == 2) << "full op requires 3 arguments." | ||||||||||||||||||||||||||||||||||||||||||||
| << op->args_.size(); // Actually 2 args, two of them are conbined! | ||||||||||||||||||||||||||||||||||||||||||||
| std::string scalar = codegen.GetExprAsCode(op->args_[1]); | ||||||||||||||||||||||||||||||||||||||||||||
| std::string dst = codegen.GetCurrentResultTarget(); | ||||||||||||||||||||||||||||||||||||||||||||
| codegen.Emit(pto_op_name + " " + "ins(" + scalar + ") outs(" + dst + ")"); | ||||||||||||||||||||||||||||||||||||||||||||
| return ""; | ||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+115
to
+125
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix misleading argument-count message in The CHECK message says “3 arguments” even though the code expects 2, which will confuse debugging. ✏️ Suggested fix- CHECK(op->args_.size() == 2) << "full op requires 3 arguments."
- << op->args_.size(); // Actually 2 args, two of them are conbined!
+ CHECK(op->args_.size() == 2) << "full op requires 2 arguments, got " << op->args_.size();📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| // Helper function for Binary Tile-Scalar operations | ||||||||||||||||||||||||||||||||||||||||||||
| static std::string MakeBinaryTileScalarCodegenPTO(const std::string& pto_op_name, const CallPtr& op, | ||||||||||||||||||||||||||||||||||||||||||||
| codegen::CodegenBase& codegen_base) { | ||||||||||||||||||||||||||||||||||||||||||||
| auto& codegen = dynamic_cast<codegen::PTOCodegen&>(codegen_base); | ||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -175,6 +187,33 @@ static std::string MakeTernaryGEMVCodegenPTO(const std::string& pto_op_name, con | |||||||||||||||||||||||||||||||||||||||||||
| return ""; | ||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| // Helper function for padding operations | ||||||||||||||||||||||||||||||||||||||||||||
| static std::string MakeFillPadCodegenPTO(const std::string& pto_op_name, const CallPtr& op, | ||||||||||||||||||||||||||||||||||||||||||||
| codegen::CodegenBase& codegen_base) { | ||||||||||||||||||||||||||||||||||||||||||||
| auto& codegen = dynamic_cast<codegen::PTOCodegen&>(codegen_base); | ||||||||||||||||||||||||||||||||||||||||||||
| CHECK(op->args_.size() == 1) << "Fill pad op requires 1 argument."; | ||||||||||||||||||||||||||||||||||||||||||||
| codegen.Emit(pto_op_name + " " + GenerateInsOutsClause(op, codegen)); | ||||||||||||||||||||||||||||||||||||||||||||
| return ""; | ||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| // Helper function for Ternary Data Movement/Layout operations | ||||||||||||||||||||||||||||||||||||||||||||
| static std::string MakeTernaryDataMoveLayoutCodegenPTO(const std::string& pto_op_name, const CallPtr& op, | ||||||||||||||||||||||||||||||||||||||||||||
| codegen::CodegenBase& codegen_base) { | ||||||||||||||||||||||||||||||||||||||||||||
| auto& codegen = dynamic_cast<codegen::PTOCodegen&>(codegen_base); | ||||||||||||||||||||||||||||||||||||||||||||
| CHECK(op->args_.size() == 3) << "Ternary move/layout op requires 3 arguments."; | ||||||||||||||||||||||||||||||||||||||||||||
| codegen.Emit(pto_op_name + " " + GenerateInsOutsClause(op, codegen)); | ||||||||||||||||||||||||||||||||||||||||||||
| return ""; | ||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| // Helper function for Binary Axis Reduction/Expansion operations | ||||||||||||||||||||||||||||||||||||||||||||
| static std::string MakeBinaryAxisCodegenPTO(const std::string& pto_op_name, const CallPtr& op, | ||||||||||||||||||||||||||||||||||||||||||||
| codegen::CodegenBase& codegen_base) { | ||||||||||||||||||||||||||||||||||||||||||||
| auto& codegen = dynamic_cast<codegen::PTOCodegen&>(codegen_base); | ||||||||||||||||||||||||||||||||||||||||||||
| CHECK(op->args_.size() == 2) << "Binary Axis op requires 2 arguments."; | ||||||||||||||||||||||||||||||||||||||||||||
| codegen.Emit(pto_op_name + " " + GenerateInsOutsClause(op, codegen)); | ||||||||||||||||||||||||||||||||||||||||||||
| return ""; | ||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| // block.load: emit pto.subview + pto.tload (same format as original IR layer codegen) | ||||||||||||||||||||||||||||||||||||||||||||
| static std::string MakeBlockLoadCodegenPTO(const CallPtr& op, codegen::CodegenBase& codegen_base) { | ||||||||||||||||||||||||||||||||||||||||||||
| auto& codegen = dynamic_cast<codegen::PTOCodegen&>(codegen_base); | ||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -524,7 +563,13 @@ REGISTER_BACKEND_OP(Backend910B_PTO, "block.mins") | |||||||||||||||||||||||||||||||||||||||||||
| return MakeBinaryTileScalarCodegenPTO("pto.tmins", op, codegen); | ||||||||||||||||||||||||||||||||||||||||||||
| }); | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| // Not Implemented: tlrelu tcmps taddsc tsubsc tsels texpands | ||||||||||||||||||||||||||||||||||||||||||||
| REGISTER_BACKEND_OP(Backend910B_PTO, "block.full") | ||||||||||||||||||||||||||||||||||||||||||||
| .set_pipe(ir::PipeType::V) | ||||||||||||||||||||||||||||||||||||||||||||
| .f_codegen([](const ir::CallPtr& op, codegen::CodegenBase& codegen) { | ||||||||||||||||||||||||||||||||||||||||||||
| return MakeFullCodegenPTO("pto.texpands", op, codegen); | ||||||||||||||||||||||||||||||||||||||||||||
| }); | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| // Not Implemented: tlrelu tcmps taddsc tsubsc tsels | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| // ============================================================================ | ||||||||||||||||||||||||||||||||||||||||||||
| // Matrix Multiplication Operations | ||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -584,6 +629,66 @@ REGISTER_BACKEND_OP(Backend910B_PTO, "block.gemv_bias") | |||||||||||||||||||||||||||||||||||||||||||
| return MakeTernaryGEMVCodegenPTO("pto.tgemv.bias", op, codegen); | ||||||||||||||||||||||||||||||||||||||||||||
| }); | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| // ============================================================================ | ||||||||||||||||||||||||||||||||||||||||||||
| // Data Movement/Layout Operations | ||||||||||||||||||||||||||||||||||||||||||||
| // ============================================================================ | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| REGISTER_BACKEND_OP(Backend910B_PTO, "block.transpose") | ||||||||||||||||||||||||||||||||||||||||||||
| .set_pipe(ir::PipeType::V) | ||||||||||||||||||||||||||||||||||||||||||||
| .f_codegen([](const ir::CallPtr& op, codegen::CodegenBase& codegen) { | ||||||||||||||||||||||||||||||||||||||||||||
| return MakeTernaryDataMoveLayoutCodegenPTO("pto.ttrans", op, codegen); | ||||||||||||||||||||||||||||||||||||||||||||
| }); | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| // ============================================================================ | ||||||||||||||||||||||||||||||||||||||||||||
| // Axis reduction/expansion Operations | ||||||||||||||||||||||||||||||||||||||||||||
| // ============================================================================ | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| REGISTER_BACKEND_OP(Backend910B_PTO, "block.row_sum") | ||||||||||||||||||||||||||||||||||||||||||||
| .set_pipe(ir::PipeType::V) | ||||||||||||||||||||||||||||||||||||||||||||
| .f_codegen([](const ir::CallPtr& op, codegen::CodegenBase& codegen) { | ||||||||||||||||||||||||||||||||||||||||||||
| return MakeBinaryAxisCodegenPTO("pto.trowsum", op, codegen); | ||||||||||||||||||||||||||||||||||||||||||||
| }); | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| REGISTER_BACKEND_OP(Backend910B_PTO, "block.row_max") | ||||||||||||||||||||||||||||||||||||||||||||
| .set_pipe(ir::PipeType::V) | ||||||||||||||||||||||||||||||||||||||||||||
| .f_codegen([](const ir::CallPtr& op, codegen::CodegenBase& codegen) { | ||||||||||||||||||||||||||||||||||||||||||||
| return MakeBinaryAxisCodegenPTO("pto.trowmax", op, codegen); | ||||||||||||||||||||||||||||||||||||||||||||
| }); | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| REGISTER_BACKEND_OP(Backend910B_PTO, "block.row_min") | ||||||||||||||||||||||||||||||||||||||||||||
| .set_pipe(ir::PipeType::V) | ||||||||||||||||||||||||||||||||||||||||||||
| .f_codegen([](const ir::CallPtr& op, codegen::CodegenBase& codegen) { | ||||||||||||||||||||||||||||||||||||||||||||
| return MakeBinaryAxisCodegenPTO("pto.trowmin", op, codegen); | ||||||||||||||||||||||||||||||||||||||||||||
| }); | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| REGISTER_BACKEND_OP(Backend910B_PTO, "block.row_expand_div") | ||||||||||||||||||||||||||||||||||||||||||||
| .set_pipe(ir::PipeType::V) | ||||||||||||||||||||||||||||||||||||||||||||
| .f_codegen([](const ir::CallPtr& op, codegen::CodegenBase& codegen) { | ||||||||||||||||||||||||||||||||||||||||||||
| return MakeBinaryAxisCodegenPTO("pto.trowexpanddiv", op, codegen); | ||||||||||||||||||||||||||||||||||||||||||||
| }); | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| REGISTER_BACKEND_OP(Backend910B_PTO, "block.row_expand_mul") | ||||||||||||||||||||||||||||||||||||||||||||
| .set_pipe(ir::PipeType::V) | ||||||||||||||||||||||||||||||||||||||||||||
| .f_codegen([](const ir::CallPtr& op, codegen::CodegenBase& codegen) { | ||||||||||||||||||||||||||||||||||||||||||||
| return MakeBinaryAxisCodegenPTO("pto.trowexpandmul", op, codegen); | ||||||||||||||||||||||||||||||||||||||||||||
| }); | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| REGISTER_BACKEND_OP(Backend910B_PTO, "block.row_expand_sub") | ||||||||||||||||||||||||||||||||||||||||||||
| .set_pipe(ir::PipeType::V) | ||||||||||||||||||||||||||||||||||||||||||||
| .f_codegen([](const ir::CallPtr& op, codegen::CodegenBase& codegen) { | ||||||||||||||||||||||||||||||||||||||||||||
| return MakeBinaryAxisCodegenPTO("pto.trowexpandsub", op, codegen); | ||||||||||||||||||||||||||||||||||||||||||||
| }); | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| // ============================================================================ | ||||||||||||||||||||||||||||||||||||||||||||
| // Padding Operations | ||||||||||||||||||||||||||||||||||||||||||||
| // ============================================================================ | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| REGISTER_BACKEND_OP(Backend910B_PTO, "block.fillpad") | ||||||||||||||||||||||||||||||||||||||||||||
| .set_pipe(ir::PipeType::V) | ||||||||||||||||||||||||||||||||||||||||||||
| .f_codegen([](const ir::CallPtr& op, codegen::CodegenBase& codegen) { | ||||||||||||||||||||||||||||||||||||||||||||
| return MakeFillPadCodegenPTO("pto.tfillpad", op, codegen); | ||||||||||||||||||||||||||||||||||||||||||||
| }); | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| // ============================================================================ | ||||||||||||||||||||||||||||||||||||||||||||
| // Memory Operations | ||||||||||||||||||||||||||||||||||||||||||||
| // ============================================================================ | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The error message in this
CHECKis confusing. It states thatfull op requires 3 arguments, but the check is forop->args_.size() == 2. The comment also clarifies there are 2 arguments. The error message should be updated to reflect that 2 arguments are expected.