Skip to content

Commit 25a1da6

Browse files
committed
refactor: 移除 MaybeDyn 和复杂的 TensroLayout
Signed-off-by: YdrMaster <ydrml@hotmail.com>
1 parent 4f0fa2a commit 25a1da6

File tree

64 files changed

+352
-762
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

64 files changed

+352
-762
lines changed

operators/src/.clang-format

Lines changed: 16 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,66 +1,29 @@
1-
# Generated from CLion C/C++ Code Style settings
1+
---
22
BasedOnStyle: LLVM
3-
AccessModifierOffset: -4
4-
AlignAfterOpenBracket: Align
5-
# AlignConsecutiveAssignments: None
6-
AlignOperands: Align
7-
AllowAllArgumentsOnNextLine: false
8-
AllowAllConstructorInitializersOnNextLine: false
9-
AllowAllParametersOfDeclarationOnNextLine: false
10-
AllowShortBlocksOnASingleLine: Always
11-
AllowShortCaseLabelsOnASingleLine: false
12-
AllowShortFunctionsOnASingleLine: All
13-
AllowShortIfStatementsOnASingleLine: Always
14-
AllowShortLambdasOnASingleLine: All
15-
AllowShortLoopsOnASingleLine: true
16-
AlwaysBreakAfterReturnType: None
17-
AlwaysBreakTemplateDeclarations: No
18-
BreakBeforeBraces: Custom
3+
IndentWidth: 4 # 缩进宽度,LLVM 默认值为 2,改为 4
4+
AccessModifierOffset: -4 # public/protected/private 访问控制符相对成员的偏移,与 IndentWidth 配合,LLVM 默认值为 -2
5+
AlignOperands: AlignAfterOperator # 双目运算符的行间对齐,LLVM 默认值为 Align,改为带符号一起换行
6+
ColumnLimit: 0 # 列宽限制,LLVM 默认值为 80,改为不限制
7+
AllowShortBlocksOnASingleLine: Always # 是否允许短块(单个语句的块)不换行,LLVM 默认值为 Never,改为允许
8+
AllowShortLoopsOnASingleLine: true # 是否允许短循环不换行,LLVM 默认值为 false,改为允许
9+
InsertBraces: true # 是否在 if/for/while/switch 等语句后插入大括号,LLVM 默认值为 false,改为允许
10+
BreakBeforeBraces: Custom # 大括号换行配置,LLVM 默认值为 LLVM,改为自定义以使 BraceWrapping 生效
1911
BraceWrapping:
2012
AfterCaseLabel: false
2113
AfterClass: false
2214
AfterControlStatement: Never
2315
AfterEnum: false
2416
AfterFunction: false
2517
AfterNamespace: false
18+
AfterObjCDeclaration: false
19+
AfterStruct: false
2620
AfterUnion: false
21+
AfterExternBlock: false
2722
BeforeCatch: false
2823
BeforeElse: false
24+
BeforeLambdaBody: false
25+
BeforeWhile: false
2926
IndentBraces: false
30-
SplitEmptyFunction: false
27+
SplitEmptyFunction: true
3128
SplitEmptyRecord: true
32-
BreakBeforeBinaryOperators: None
33-
BreakBeforeTernaryOperators: true
34-
BreakConstructorInitializers: BeforeColon
35-
BreakInheritanceList: BeforeColon
36-
ColumnLimit: 0
37-
CompactNamespaces: true
38-
ContinuationIndentWidth: 4
39-
IndentCaseLabels: true
40-
IndentPPDirectives: None
41-
IndentWidth: 4
42-
KeepEmptyLinesAtTheStartOfBlocks: true
43-
MaxEmptyLinesToKeep: 2
44-
NamespaceIndentation: All
45-
ObjCSpaceAfterProperty: false
46-
ObjCSpaceBeforeProtocolList: true
47-
PointerAlignment: Right
48-
ReflowComments: false
49-
SpaceAfterCStyleCast: true
50-
SpaceAfterLogicalNot: false
51-
SpaceAfterTemplateKeyword: false
52-
SpaceBeforeAssignmentOperators: true
53-
SpaceBeforeCpp11BracedList: false
54-
SpaceBeforeCtorInitializerColon: true
55-
SpaceBeforeInheritanceColon: true
56-
SpaceBeforeParens: ControlStatements
57-
SpaceBeforeRangeBasedForLoopColon: true
58-
SpaceInEmptyParentheses: false
59-
SpacesBeforeTrailingComments: 0
60-
SpacesInAngles: false
61-
SpacesInCStyleCastParentheses: false
62-
SpacesInContainerLiterals: false
63-
SpacesInParentheses: false
64-
SpacesInSquareBrackets: false
65-
TabWidth: 4
66-
UseTab: Never
29+
SplitEmptyNamespace: true

operators/src/add/args.rs

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use crate::{
2-
get_static, rank_mismatch, shape_mismatch, shape_not_support, utils::type_distinct, ConstPtr,
3-
Hardware, LaunchError, MutPtr, TensorLayout,
2+
rank_mismatch, shape_mismatch, shape_not_support, utils::type_distinct, ConstPtr, Hardware,
3+
LaunchError, MutPtr, TensorLayout,
44
};
55
use digit_layout::DigitLayout;
66
use itertools::izip;
@@ -48,7 +48,7 @@ impl Scheme {
4848
..
4949
} = args;
5050
// # 检查基本属性
51-
let dt = type_distinct(&[c.dt(), a.dt(), b.dt()])?;
51+
let dt = type_distinct(&[c.dt, a.dt, b.dt])?;
5252
let ndim = c.ndim();
5353
if a.ndim() != ndim || b.ndim() != ndim {
5454
return Err(rank_mismatch(format!(
@@ -68,17 +68,13 @@ impl Scheme {
6868
}
6969
let mut dims = Vec::with_capacity(ndim);
7070
for (&d, &da, &db, &sc, &sa, &sb) in izip!(
71-
c.shape(),
72-
a.shape(),
73-
b.shape(),
71+
c.shape_group(),
72+
a.shape_group(),
73+
b.shape_group(),
7474
c.strides(),
7575
a.strides(),
76-
b.strides()
76+
b.strides(),
7777
) {
78-
get_static! {
79-
d da db
80-
sc sa sb
81-
}
8278
if da != d || db != d {
8379
return Err(shape_mismatch(format!(
8480
"c: {:?}, a: {:?}, b: {:?}",

operators/src/add/cuda/add.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
template<class Tdata>
1+
template <class Tdata>
22
static __device__ void _add(
33
Tdata *__restrict__ c,
44
Tdata const *__restrict__ a,

operators/src/add_rows/args.rs

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use crate::{
22
type_not_support,
33
utils::{dim_distinct, rank_error, type_distinct},
4-
ConstPtr, Hardware, LaunchError, MaybeDyn, MutPtr, TensorLayout,
4+
ConstPtr, Hardware, LaunchError, MutPtr, TensorLayout,
55
};
66
use digit_layout::{DigitLayout, LayoutContent::Unsigned};
77
use std::ptr::{null, null_mut};
@@ -37,10 +37,10 @@ impl<H: Hardware> Args<H> {
3737
pub(super) struct Meta {
3838
pub dt: DigitLayout,
3939
pub dt_idx: DigitLayout,
40-
pub batch: MaybeDyn<usize>,
41-
pub m: MaybeDyn<usize>,
42-
pub n: MaybeDyn<usize>,
43-
pub k: MaybeDyn<usize>,
40+
pub batch: usize,
41+
pub m: usize,
42+
pub n: usize,
43+
pub k: usize,
4444
}
4545

4646
impl<H: Hardware> Args<H> {
@@ -52,30 +52,30 @@ impl<H: Hardware> Args<H> {
5252
..
5353
} = self;
5454

55-
let dt = type_distinct(&[dst.dt(), src.dt()])?;
56-
let dt_idx = idx.dt();
55+
let dt = type_distinct(&[dst.dt, src.dt])?;
56+
let dt_idx = idx.dt;
5757
if !matches!(dt_idx.decode(), Unsigned { .. }) {
5858
return Err(type_not_support(format!(
5959
"data type {dt_idx} is not supported, must be unsigned integers"
6060
)));
6161
}
6262

63-
let &[batch, m, n] = dst.shape() else {
63+
let &[batch, m, n] = &*dst.shape() else {
6464
return Err(rank_error("dst", 3, dst.ndim()));
6565
};
66-
let &[k, n_] = src.shape() else {
66+
let &[k, n_] = &*src.shape() else {
6767
return Err(rank_error("src", 2, src.ndim()));
6868
};
69-
let &[batch_, m_] = idx.shape() else {
69+
let &[batch_, m_] = &*idx.shape() else {
7070
return Err(rank_error("idx", 2, idx.ndim()));
7171
};
7272

7373
Ok(Meta {
7474
dt,
7575
dt_idx,
76-
batch: dim_distinct(&[batch, batch_])?,
77-
m: dim_distinct(&[m, m_])?,
78-
n: dim_distinct(&[n, n_])?,
76+
batch: dim_distinct(&[batch, batch_]).expect("batch mismatch"),
77+
m: dim_distinct(&[m, m_]).expect("m mismatch"),
78+
n: dim_distinct(&[n, n_]).expect("n mismatch"),
7979
k,
8080
})
8181
}

operators/src/add_rows/common_cpu/mod.rs

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use super::{args::Meta, AddRows, Args};
2-
use crate::{common_cpu::Cpu, get_static, ByteOf, LaunchError, QueueAlloc, Unsigned};
2+
use crate::{common_cpu::Cpu, ByteOf, LaunchError, QueueAlloc, Unsigned};
33
use digit_layout::types as ty;
44
use half::f16;
55
use rayon::iter::{IntoParallelIterator, ParallelIterator};
@@ -54,12 +54,6 @@ impl crate::Operator for Operator {
5454
unreachable!()
5555
};
5656

57-
get_static! {
58-
b m n k
59-
bsd msd nsd
60-
bsi msi nss kss
61-
}
62-
6357
let dst = *dst_base as usize;
6458
let src = *src_base as usize;
6559
let idx = *idx_base as usize;

operators/src/add_rows/cuda/add_rows.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
template<class Tdata, class Tidx>
1+
template <class Tdata, class Tidx>
22
static __device__ void add_rows(
33
Tdata *__restrict__ dst,
44
Tdata const *__restrict__ src,

operators/src/add_rows/cuda/mod.rs

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use super::{AddRows, Args};
22
use crate::{
33
add_rows::args::Meta,
44
cuda::{dt_name, Gpu, Handle, ModuleBox},
5-
get_static, strides_not_support,
5+
strides_not_support,
66
utils::gcd,
77
ByteOf, LaunchError, QueueAlloc, SchemeDiversity,
88
};
@@ -63,13 +63,8 @@ impl crate::Operator for Operator {
6363
unreachable!()
6464
};
6565

66-
get_static! {
67-
b n m
68-
bsd msd nsd
69-
bsi msi nss kss
70-
}
71-
let unit_dst = dst_layout.dt().nbytes() as isize;
72-
let unit_idx = idx_layout.dt().nbytes() as isize;
66+
let unit_dst = dst_layout.dt.nbytes() as isize;
67+
let unit_idx = idx_layout.dt.nbytes() as isize;
7368
if nsd != unit_dst || nss != unit_dst || msi != unit_idx {
7469
return Err(strides_not_support(""));
7570
};
@@ -85,9 +80,7 @@ impl crate::Operator for Operator {
8580
let params = cuda::params![dst_base, src_base, idx_base, bsd, msd, kss, bsi];
8681
let block = gcd(self.max_threads_block, n);
8782
let dimx = n.div_ceil(block);
88-
let key = SchemeKey {
89-
dt: dst_layout.dt(),
90-
};
83+
let key = SchemeKey { dt: dst_layout.dt };
9184
let scheme = self
9285
.schemes
9386
.lock()
@@ -160,6 +153,7 @@ mod test {
160153
};
161154
use half::f16;
162155

156+
#[allow(clippy::too_many_arguments)]
163157
fn args<H: Hardware>(
164158
dt: DigitLayout,
165159
b: usize,

operators/src/all_reduce/args.rs

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
use super::ReduceOp;
22
use crate::{
3-
dyn_not_support, rearrange, shape_mismatch, strides_not_support, utils::type_distinct,
4-
Hardware, LaunchError, MaybeDyn,
3+
rearrange, shape_mismatch, strides_not_support, utils::type_distinct, Hardware, LaunchError,
54
};
65
use digit_layout::DigitLayout;
7-
use ndarray_layout::ArrayLayout;
86

97
pub struct Args<H: Hardware> {
108
pub pair: rearrange::Args<H>,
@@ -35,15 +33,9 @@ impl<H: Hardware> Args<H> {
3533
..
3634
} = self;
3735

38-
let dt = type_distinct(&[dst_layout.dt(), src_layout.dt()])?;
36+
let dt = type_distinct(&[dst_layout.dt, src_layout.dt])?;
3937

40-
let Some(shape) = MaybeDyn::get_all(dst_layout.shape()) else {
41-
return Err(dyn_not_support(""));
42-
};
43-
let Some(strides) = MaybeDyn::get_all(dst_layout.strides()) else {
44-
return Err(dyn_not_support(""));
45-
};
46-
let dst = ArrayLayout::<2>::new(shape, strides, 0);
38+
let dst = &dst_layout.layout;
4739
let &[dst] = dst
4840
.merge_be(0, dst.ndim())
4941
.ok_or(strides_not_support(""))?
@@ -52,13 +44,7 @@ impl<H: Hardware> Args<H> {
5244
unreachable!()
5345
};
5446

55-
let Some(shape) = MaybeDyn::get_all(src_layout.shape()) else {
56-
return Err(dyn_not_support(""));
57-
};
58-
let Some(strides) = MaybeDyn::get_all(src_layout.strides()) else {
59-
return Err(dyn_not_support(""));
60-
};
61-
let src = ArrayLayout::<2>::new(shape, strides, 0);
47+
let src = &src_layout.layout;
6248
let &[src] = src
6349
.merge_be(0, src.ndim())
6450
.ok_or(strides_not_support(""))?

operators/src/attention/args.rs

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use crate::{
22
fuesd_softmax::AttnMask,
33
utils::{dim_distinct, rank_error, type_distinct},
4-
ConstPtr, Hardware, LaunchError, MaybeDyn, MutPtr, TensorLayout,
4+
ConstPtr, Hardware, LaunchError, MutPtr, TensorLayout,
55
};
66
use digit_layout::DigitLayout;
77

@@ -23,11 +23,11 @@ pub struct Args<H: Hardware> {
2323

2424
pub(super) struct Meta {
2525
pub dt: DigitLayout,
26-
pub nh: MaybeDyn<usize>,
27-
pub nkvh: MaybeDyn<usize>,
28-
pub seq: MaybeDyn<usize>,
29-
pub att: MaybeDyn<usize>,
30-
pub dh: MaybeDyn<usize>,
26+
pub nh: usize,
27+
pub nkvh: usize,
28+
pub seq: usize,
29+
pub att: usize,
30+
pub dh: usize,
3131
}
3232

3333
impl<H: Hardware> Args<H> {
@@ -40,26 +40,26 @@ impl<H: Hardware> Args<H> {
4040
..
4141
} = self;
4242

43-
let &[nh_q, seq_q, dh_q] = q_layout.shape() else {
43+
let &[nh_q, seq_q, dh_q] = &*q_layout.shape() else {
4444
return Err(rank_error("q", 3, q_layout.ndim()));
4545
};
46-
let &[nkvh_k, att_k, dh_k] = k_layout.shape() else {
46+
let &[nkvh_k, att_k, dh_k] = &*k_layout.shape() else {
4747
return Err(rank_error("k", 3, k_layout.ndim()));
4848
};
49-
let &[nkvh_v, att_v, dh_v] = v_layout.shape() else {
49+
let &[nkvh_v, att_v, dh_v] = &*v_layout.shape() else {
5050
return Err(rank_error("v", 3, v_layout.ndim()));
5151
};
52-
let &[nh_o, seq_o, dh_o] = o_layout.shape() else {
52+
let &[nh_o, seq_o, dh_o] = &*o_layout.shape() else {
5353
return Err(rank_error("o", 3, o_layout.ndim()));
5454
};
5555

5656
Ok(Meta {
57-
dt: type_distinct(&[q_layout.dt(), k_layout.dt(), v_layout.dt(), o_layout.dt()])?,
58-
nh: dim_distinct(&[nh_q, nh_o])?,
59-
nkvh: dim_distinct(&[nkvh_k, nkvh_v])?,
60-
seq: dim_distinct(&[seq_q, seq_o])?,
61-
att: dim_distinct(&[att_k, att_v])?,
62-
dh: dim_distinct(&[dh_q, dh_k, dh_v, dh_o])?,
57+
dt: type_distinct(&[q_layout.dt, k_layout.dt, v_layout.dt, o_layout.dt])?,
58+
nh: dim_distinct(&[nh_q, nh_o]).expect("nh mismatch"),
59+
nkvh: dim_distinct(&[nkvh_k, nkvh_v]).expect("nkvh mismatch"),
60+
seq: dim_distinct(&[seq_q, seq_o]).expect("seq mismatch"),
61+
att: dim_distinct(&[att_k, att_v]).expect("att mismatch"),
62+
dh: dim_distinct(&[dh_q, dh_k, dh_v, dh_o]).expect("dh mismatch"),
6363
})
6464
}
6565
}

operators/src/attention/cuda.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ mod test {
66
use crate::{cuda::Gpu, ByteOf, Hardware, Operator as _, TensorLayout};
77
use digit_layout::{types as ty, DigitLayout};
88

9+
#[allow(clippy::too_many_arguments)]
910
fn args<H: Hardware>(
1011
dt: DigitLayout,
1112
nh: usize,

0 commit comments

Comments
 (0)