
Commit 120d30f

fix: support attention where the q, k, v shapes differ
1 parent 61789f7 commit 120d30f

4 files changed: 25 additions, 15 deletions

4 files changed

+25
-15
lines changed

operators/src/attention/args.rs

Lines changed: 12 additions & 7 deletions
@@ -30,6 +30,7 @@ pub(super) struct Meta {
     pub seq: MaybeDyn<usize>,
     pub att: MaybeDyn<usize>,
     pub dh: MaybeDyn<usize>,
+    pub dv: MaybeDyn<usize>,
 }
 
 impl<H: Hardware> Args<H> {
@@ -41,17 +42,20 @@ impl<H: Hardware> Args<H> {
         seq: MaybeDyn<usize>,
         att: MaybeDyn<usize>,
         dh: MaybeDyn<usize>,
+        dv: MaybeDyn<usize>,
     ) -> Self {
-        let qo_layout = TensorLayout::new_dyn(dt, &[nh, seq, dh], &[dyn_(); 3]);
-        let kv_layout = TensorLayout::new_dyn(dt, &[nkvh, att, dh], &[dyn_(); 3]);
+        let q_layout = TensorLayout::new_dyn(dt, &[nh, seq, dh], &[dyn_(); 3]);
+        let k_layout = TensorLayout::new_dyn(dt, &[nkvh, seq, dh], &[dyn_(); 3]);
+        let v_layout = TensorLayout::new_dyn(dt, &[nkvh, att, dv], &[dyn_(); 3]);
+        let o_layout = TensorLayout::new_dyn(dt, &[nkvh, att, dh], &[dyn_(); 3]);
         Self {
-            q_layout: qo_layout.clone(),
+            q_layout: q_layout.clone(),
             q_base: null_mut(),
-            k_layout: kv_layout.clone(),
+            k_layout: k_layout.clone(),
             k_base: null(),
-            v_layout: kv_layout,
+            v_layout: v_layout,
             v_base: null(),
-            o_layout: qo_layout,
+            o_layout: o_layout,
             o_base: null_mut(),
             mask,
         }
@@ -85,7 +89,8 @@ impl<H: Hardware> Args<H> {
             nkvh: dim_distinct(&[nkvh_k, nkvh_v])?,
             seq: dim_distinct(&[seq_q, seq_o])?,
             att: dim_distinct(&[att_k, att_v])?,
-            dh: dim_distinct(&[dh_q, dh_k, dh_v, dh_o])?,
+            dh: dim_distinct(&[dh_q, dh_k])?,
+            dv: dim_distinct(&[dh_v, dh_o])?,
         })
     }
 }
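The new `dv` field lets the value/output head dimension differ from the query/key head dimension `dh`. A minimal, dependency-free sketch of the shape rule this encodes (illustrative names only, not the crate's API; GQA grouping via `nkvh` is omitted for brevity):

    // Attention shapes when the q/k head dim (dh) differs from the v/o head dim (dv).
    fn attention_out_shape(nh: usize, seq: usize, att: usize, dh: usize, dv: usize) -> [usize; 3] {
        let q = [nh, seq, dh]; // queries
        let k = [nh, att, dh]; // keys: must share dh with q   (dh = dim_distinct(&[dh_q, dh_k]))
        let v = [nh, att, dv]; // values: head dim may differ  (dv = dim_distinct(&[dh_v, dh_o]))
        assert_eq!(q[2], k[2]);
        let score = [nh, seq, att]; // score = q . k^T
        assert_eq!(score[2], v[1]); // att = dim_distinct(&[att_k, att_v])
        [nh, seq, v[2]]             // out = softmax(score) . v, so the output head dim is dv
    }

    fn main() {
        // e.g. a layout whose value heads are narrower than its query/key heads
        assert_eq!(attention_out_shape(32, 4, 128, 192, 128), [32, 4, 128]);
    }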

operators/src/attention/cuda.rs

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ mod test {
             seq.into(),
             att.into(),
             dyn_(),
+            dyn_(),
         )
     }
 
operators/src/attention/operator.rs

Lines changed: 10 additions & 6 deletions
@@ -1,4 +1,4 @@
-use super::{args::Meta, Args, Attention};
+use super::{args::Meta, Args, Attention};
 use crate::{
     dyn_, fuesd_softmax, get_static, mat_mul, rearrange, ByteOf, Hardware, LaunchError, QueueAlloc,
     SchemeError, TensorLayout, Workspace, WorkspaceCollector,
@@ -53,6 +53,7 @@ where
             seq,
             att,
             dh,
+            dv,
             ..
         } = args.meta()?;
         let Args {
@@ -64,11 +65,12 @@ where
         } = args;
 
         // If nh, seq, att, dh are not guaranteed to be known, initialize the operator with arbitrary values
-        let (Some(&nh), Some(&seq), Some(&att), Some(&dh)) = (
+        let (Some(&nh), Some(&seq), Some(&att), Some(&dh), Some(&dv)) = (
             nh.get_static(),
             seq.get_static(),
             att.get_static(),
             dh.get_static(),
+            dv.get_static(),
         ) else {
             let mut wc = WorkspaceCollector::new();
 
@@ -149,6 +151,7 @@ where
             seq,
             att,
             dh,
+            dv,
         } = args.meta()?;
         let Args {
             mask,
@@ -172,8 +175,8 @@ where
         let ele = dt.nbytes();
         get_static! {
             nh seq dh
-            nh_sq seq_sq dh_sq
-            nkvh att
+            dv seq_sq dh_sq
+            nkvh att nh_sq
             nkvh_sk att_sk dh_sk
         };
 
@@ -219,6 +222,7 @@ where
         let k_layout = TensorLayout::new(dt, k_layout.shape(), k_layout.strides());
         let att_mat_mul = TensorLayout::new_contiguous(dt, &[nkvh, head_group * seq, att]);
         let att_softmax = TensorLayout::new_contiguous(dt, &[nh, seq, att]);
+        let att_result = TensorLayout::new_contiguous(dt, &[nkvh, head_group * seq, dv]);
 
         // att = q . k^T
         self.mat_mul.launch(
@@ -248,7 +252,7 @@ where
         // q = att . v
         self.mat_mul.launch(
             &mat_mul::Args {
-                c_layout: qx_layout.clone(),
+                c_layout: att_result.clone(),
                 c_base: q_base,
                 beta: 0.,
                 a_layout: att_mat_mul,
@@ -266,7 +270,7 @@ where
             &rearrange::Args {
                 dst_layout: o_layout.clone(),
                 dst_base: *o_base,
-                src_layout: q_layout.clone(),
+                src_layout: o_layout.clone(),
                 src_base: q_base,
             },
             workspace,
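A quick sanity check of why the second GEMM gets its own `att_result` layout: the softmaxed score tensor is [nkvh, head_group * seq, att] and v is [nkvh, att, dv], so their product is [nkvh, head_group * seq, dv]; reusing q's layout (trailing dim dh) is only correct when dv == dh. A small self-contained sketch of that shape rule (illustrative, not the crate's mat_mul API):

    // Batched GEMM output shape: a: [b, m, k] . b: [b, k, n] -> c: [b, m, n].
    fn batched_matmul_shape(a: [usize; 3], b: [usize; 3]) -> [usize; 3] {
        assert_eq!(a[0], b[0], "batch dims must match (nkvh)");
        assert_eq!(a[2], b[1], "inner dims must match (att)");
        [a[0], a[1], b[2]]
    }

    fn main() {
        let (nkvh, head_group, seq, att, dv) = (8, 4, 1, 128, 128);
        let score = [nkvh, head_group * seq, att]; // softmax(q . k^T)
        let value = [nkvh, att, dv];
        // matches the new att_result layout, whose trailing dim is dv rather than dh
        assert_eq!(batched_matmul_shape(score, value), [nkvh, head_group * seq, dv]);
    }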

operators/src/attention_kv_cached/operator.rs

Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
-use super::{args::Meta, Args, AttnKVCached};
+use super::{args::Meta, Args, AttnKVCached};
 use crate::{
     attention, dyn_, get_static, rearrange, shape_mismatch, ByteOf, Hardware, LaunchError,
     MaybeDyn, QueueAlloc, TensorLayout, WorkspaceCollector,
@@ -66,7 +66,7 @@ where
         };
 
         wc.push_sub(self.attention.scheme(
-            &attention::Args::new_null(args.mask, dt, nh, nkvh, seq, att, dh),
+            &attention::Args::new_null(args.mask, dt, nh, nkvh, seq, att, dh, dh),
             max_workspace_size,
         )?);
 
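The KV-cached wrapper keeps its cached K and V at the same head dimension, which is why it forwards `dh` for both of the constructor's trailing head-dim arguments. A caller that did want a narrower value head dim would pass the two separately; a hypothetical one-line sketch mirroring the argument order shown in the args.rs diff (not a call site that exists in this commit):

    // Hypothetical: distinct q/k head dim (dh) and v/o head dim (dv).
    let attn_args = attention::Args::new_null(mask, dt, nh, nkvh, seq, att, dh, dv);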