Skip to content

Commit ec21c79

Browse files
committed
fix: 添加更多rope
1 parent f4a83f7 commit ec21c79

File tree

5 files changed

+195
-38
lines changed

5 files changed

+195
-38
lines changed

operators/src/attention/args.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,9 @@ impl<H: Hardware> Args<H> {
5353
q_base: null_mut(),
5454
k_layout: k_layout.clone(),
5555
k_base: null(),
56-
v_layout: v_layout,
56+
v_layout,
5757
v_base: null(),
58-
o_layout: o_layout,
58+
o_layout,
5959
o_base: null_mut(),
6060
mask,
6161
}

operators/src/rope/args.rs

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,44 @@
1-
use crate::{
1+
use crate::{
22
type_not_support,
33
utils::{dim_distinct, rank_error},
44
ConstPtr, Hardware, MaybeDyn, MutPtr, SchemeError, TensorLayout,
55
};
66
use digit_layout::DigitLayout;
77

8+
/// Selects which RoPE variant an operator applies; carried in `Args::rope_type`.
///
/// NOTE(review): the per-variant notes below describe how the CPU backend in
/// this commit consumes each variant. Other backends are not visible here, so
/// confirm before relying on these notes for them.
pub enum RopeType<H: Hardware> {
9+
// The following variants share one Scheme.
10+
/// Plain RoPE: the CPU backend uses `theta` unchanged.
Rope,
11+
// Pi {
12+
// s: f32,
13+
// },
14+
/// The CPU backend multiplies `theta` by `s`.
Ntk {
15+
s: f32,
16+
},
17+
/// The CPU backend multiplies `theta` by `a * s - a + 1`.
Dyn {
18+
s: f32,
19+
a: f32,
20+
},
21+
// The following variants share one Scheme.
22+
/// `long` and `short` point to factor data. The CPU backend reads from
/// `short` when a position is below `origin_pos` and from `long` otherwise,
/// and derives a sin/cos scale from `max_pos` and `origin_pos`.
Long {
23+
long: ConstPtr<H>,
24+
short: ConstPtr<H>,
25+
max_pos: u32,
26+
origin_pos: u32,
27+
},
28+
/// NOTE(review): not handled by the CPU backend in this commit; `launch`
/// falls through to an empty `_ => {}` arm for this variant.
NtkParts {
29+
alpha: f32,
30+
beta: f32,
31+
l0: f32,
32+
s: f32,
33+
},
34+
/// NOTE(review): not handled by the CPU backend in this commit; `launch`
/// falls through to an empty `_ => {}` arm for this variant.
Yarn {
35+
alpha: f32,
36+
beta: f32,
37+
l0: f32,
38+
s: f32,
39+
},
40+
}
41+
842
pub struct Args<H: Hardware> {
943
pub t_layout: TensorLayout,
1044
pub t_base: MutPtr<H>,
@@ -15,6 +49,7 @@ pub struct Args<H: Hardware> {
1549
pub cos_layout: TensorLayout,
1650
pub cos_base: ConstPtr<H>,
1751
pub theta: f32,
52+
pub rope_type: RopeType<H>,
1853
}
1954

2055
pub(super) struct Meta {

operators/src/rope/common_cpu/mod.rs

Lines changed: 153 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,23 @@
1-
use super::{args::Meta, fill_pos, Args, Rope, Seq, SinCosTable};
1+
use std::ptr::null;
2+
3+
use super::{args::Meta, args::RopeType as R, fill_pos, Args, Rope, Seq, SinCosTable};
24
use crate::{
35
common_cpu::Cpu, get_static, strides_not_support, ByteOf, LaunchError, QueueAlloc, SchemeError,
46
Unsigned,
57
};
68
use digit_layout::{types as ty, DigitLayout};
79
use half::f16;
8-
10+
/// CPU-side form of the RoPE variant stored in `Scheme::scheme_type`, with
/// factor pointers already cast to the activation element type `T`.
#[derive(Copy, Clone)]
11+
enum SchemeType<T> {
12+
/// Used for `RopeType::Rope`, `Ntk` and `Dyn`; those differ only in the
/// `theta` value passed to the scheme.
Rope,
13+
/// Built from `RopeType::Long`.
Long {
14+
// Factor data selected when the position is not below `origin_pos`.
long: *const T,
15+
// Factor data selected when the position is below `origin_pos`.
short: *const T,
16+
// Multiplier applied to both sin and cos.
s: f32,
17+
// Position threshold that chooses between `short` and `long`.
origin_pos: u32,
18+
},
19+
/// Not implemented yet: `Scheme::calculate` reaches `todo!()` for this variant.
Yarn,
20+
}
921
pub struct Operator;
1022

1123
impl Rope<Cpu> for Operator {
@@ -78,6 +90,7 @@ impl crate::Operator for Operator {
7890
p_layout,
7991
p_base,
8092
theta,
93+
rope_type,
8194
..
8295
} = args;
8396
let &[_, nh, dh] = t_layout.shape() else {
@@ -99,33 +112,89 @@ impl crate::Operator for Operator {
99112
return Err(strides_not_support("").into());
100113
}
101114

102-
macro_rules! calculate {
103-
($t:ty, $p:ty) => {
104-
Scheme::<$t, $p> {
105-
nt,
106-
nh,
107-
dh,
108-
st,
109-
sh,
110-
sp,
111-
theta: *theta,
112-
t_base: t_base.cast(),
113-
p_base: p_base.cast(),
115+
match rope_type {
116+
R::Rope | R::Dyn { .. } | R::Ntk { .. } => {
117+
let theta = match rope_type {
118+
R::Rope => *theta,
119+
R::Dyn { s, a } => theta * (a * s - a + 1.),
120+
R::Ntk { s } => theta * s,
121+
_ => unreachable!(),
122+
};
123+
macro_rules! calculate {
124+
($t:ty, $p:ty) => {
125+
Scheme::<$t, $p> {
126+
nt,
127+
nh,
128+
dh,
129+
st,
130+
sh,
131+
sp,
132+
theta,
133+
t_base: t_base.cast(),
134+
p_base: p_base.cast(),
135+
scheme_type: SchemeType::Rope,
136+
}
137+
.calculate()
138+
};
114139
}
115-
.calculate()
116-
};
117-
}
118140

119-
use digit_layout::types as ty;
120-
match (dt_t, dt_p) {
121-
(ty::F16, ty::U32) => calculate!(f16, u32),
122-
(ty::F16, ty::U64) => calculate!(f16, u64),
123-
(ty::F32, ty::U32) => calculate!(f32, u32),
124-
(ty::F32, ty::U64) => calculate!(f32, u64),
125-
(ty::F64, ty::U32) => calculate!(f64, u32),
126-
(ty::F64, ty::U64) => calculate!(f64, u64),
127-
_ => todo!(),
141+
use digit_layout::types as ty;
142+
match (dt_t, dt_p) {
143+
(ty::F16, ty::U32) => calculate!(f16, u32),
144+
(ty::F16, ty::U64) => calculate!(f16, u64),
145+
(ty::F32, ty::U32) => calculate!(f32, u32),
146+
(ty::F32, ty::U64) => calculate!(f32, u64),
147+
(ty::F64, ty::U32) => calculate!(f64, u32),
148+
(ty::F64, ty::U64) => calculate!(f64, u64),
149+
_ => todo!(),
150+
}
151+
}
152+
R::Long {
153+
long,
154+
short,
155+
max_pos,
156+
origin_pos,
157+
} => {
158+
let s = 1.0
159+
+ ((*max_pos as f32 / *origin_pos as f32).ln() / (*origin_pos as f32).ln())
160+
.sqrt();
161+
macro_rules! calculate {
162+
($t:ty, $p:ty) => {
163+
Scheme::<$t, $p> {
164+
nt,
165+
nh,
166+
dh,
167+
st,
168+
sh,
169+
sp,
170+
theta: *theta,
171+
t_base: t_base.cast(),
172+
p_base: p_base.cast(),
173+
scheme_type: SchemeType::Long {
174+
long: long.cast(),
175+
short: short.cast(),
176+
s,
177+
origin_pos: *origin_pos,
178+
},
179+
}
180+
.calculate()
181+
};
182+
}
183+
184+
use digit_layout::types as ty;
185+
match (dt_t, dt_p) {
186+
(ty::F16, ty::U32) => calculate!(f16, u32),
187+
(ty::F16, ty::U64) => calculate!(f16, u64),
188+
(ty::F32, ty::U32) => calculate!(f32, u32),
189+
(ty::F32, ty::U64) => calculate!(f32, u64),
190+
(ty::F64, ty::U32) => calculate!(f64, u32),
191+
(ty::F64, ty::U64) => calculate!(f64, u64),
192+
_ => todo!(),
193+
}
194+
}
195+
_ => {}
128196
}
197+
129198
Ok(())
130199
}
131200
}
@@ -142,15 +211,15 @@ struct Scheme<A, P> {
142211
theta: f32,
143212
t_base: *mut A,
144213
p_base: *const P,
214+
scheme_type: SchemeType<A>,
145215
}
146216

147217
// `Scheme` stores raw pointers (`t_base`, `p_base`), so it is not `Send`/`Sync`
// automatically; these impls opt in manually.
// NOTE(review): soundness depends on callers keeping the pointed-to buffers
// valid and free of conflicting access while a `Scheme` is shared — confirm.
unsafe impl<A, P> Send for Scheme<A, P> {}
148218
unsafe impl<A, P> Sync for Scheme<A, P> {}
149-
150219
/// Activation value.
151220
trait Activation: Sized {
152221
/// The activation type determines the type used for the calculation.
153-
type Calculation;
222+
type Calculation: Copy;
154223
/// Calculation procedure: receives one `pair` of activations by mutable
/// reference together with the `sin` and `cos` values to apply to it.
155224
fn calculate(pair: &mut [Self; 2], sin: Self::Calculation, cos: Self::Calculation);
156225
}
@@ -187,16 +256,37 @@ impl Activation for f64 {
187256
}
188257

189258
/// A position value that can produce the `(sin, cos)` pair for index `k` of a
/// head of size `dh`, in the calculation type `Calculation`.
///
/// This diff replaces the single `freq_sin_cos` method with one method per
/// scheme (`_rope` and `_long`).
trait Position<Calculation> {
190-
fn freq_sin_cos(self, k: isize, dh: isize, theta: f32) -> (Calculation, Calculation);
259+
/// Returns `(sin, cos)` of `pos / theta^(k / dh)` (see the blanket impl
/// generated by `impl_position!`).
fn freq_sin_cos_rope(self, k: isize, dh: isize, theta: f32) -> (Calculation, Calculation);
260+
/// Returns `(sin, cos)` of `pos / theta^(k / dh) * factor`, each multiplied
/// by `s` (see the blanket impl generated by `impl_position!`).
fn freq_sin_cos_long(
261+
self,
262+
k: isize,
263+
dh: isize,
264+
theta: f32,
265+
factor: Calculation,
266+
s: f32,
267+
) -> (Calculation, Calculation);
191268
}
192269

193270
// Implements `Position<$a>` for every `T: Unsigned`, where `$a` is the
// floating-point calculation type; `self.val()` supplies the position value.
macro_rules! impl_position {
194271
($a:ty) => {
195272
impl<T: Unsigned> Position<$a> for T {
196273
#[inline]
197-
fn freq_sin_cos(self, k: isize, dh: isize, theta: f32) -> ($a, $a) {
274+
// Angle is pos / theta^(k / dh); returns its (sin, cos).
fn freq_sin_cos_rope(self, k: isize, dh: isize, theta: f32) -> ($a, $a) {
198275
(self.val() as $a / (theta as $a).powf(k as $a / dh as $a)).sin_cos()
199276
}
277+
#[inline]
278+
// Same angle as the rope variant, multiplied by `factor`; the resulting
// sin and cos are then both multiplied by `s`.
fn freq_sin_cos_long(
279+
self,
280+
k: isize,
281+
dh: isize,
282+
theta: f32,
283+
factor: $a,
284+
s: f32,
285+
) -> ($a, $a) {
286+
let (sin, cos) =
287+
(self.val() as $a / (theta as $a).powf(k as $a / dh as $a) * factor).sin_cos();
288+
(sin * s as $a, cos * s as $a)
289+
}
200290
}
201291
};
202292
}
@@ -206,8 +296,8 @@ impl_position!(f64);
206296

207297
impl<A, P> Scheme<A, P>
208298
where
209-
A: Activation,
210-
P: Position<A::Calculation> + Sync + Copy,
299+
A: Activation + Copy,
300+
P: Position<A::Calculation> + Sync + Copy + Unsigned,
211301
{
212302
fn calculate(&self) {
213303
let &Self {
@@ -220,6 +310,7 @@ where
220310
theta,
221311
t_base,
222312
p_base,
313+
scheme_type,
223314
} = self;
224315
let nt = nt as isize;
225316
let nh = nh as isize;
@@ -229,10 +320,39 @@ where
229320
for i in 0..nt {
230321
let t = unsafe { t_base.byte_offset(i * st).cast::<[A; 2]>() };
231322
let p = unsafe { *p_base.byte_offset(i * sp) };
323+
let factor = match scheme_type {
324+
SchemeType::Long {
325+
long,
326+
short,
327+
origin_pos,
328+
..
329+
} => unsafe {
330+
if p.val() < origin_pos as usize {
331+
short.byte_offset(i * st).cast()
332+
} else {
333+
long.byte_offset(i * st).cast()
334+
}
335+
},
336+
_ => null(),
337+
};
232338
for j in 0..nh {
233339
for k in 0..dh {
234340
let pair = unsafe { &mut *t.byte_offset(j * sh + k * sd) };
235-
let (sin, cos) = p.freq_sin_cos(k, dh, theta);
341+
let (sin, cos) = match scheme_type {
342+
SchemeType::Rope => p.freq_sin_cos_rope(k, dh, theta),
343+
SchemeType::Long {
344+
long,
345+
short,
346+
origin_pos,
347+
s,
348+
} => {
349+
let factor = unsafe { *factor };
350+
p.freq_sin_cos_long(k, dh, theta, factor, s)
351+
}
352+
_ => {
353+
todo!()
354+
}
355+
};
236356
A::calculate(pair, sin, cos)
237357
}
238358
}

operators/src/rope/cuda/mod.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ extern "C" __global__ void {POS_U64}(
184184
#[cfg(test)]
185185
mod test {
186186
use super::{Args, Gpu, Operator, POS_U32, POS_U64};
187-
use crate::{Hardware, Operator as _, TensorLayout};
187+
use crate::{rope::args, Hardware, Operator as _, TensorLayout};
188188
use digit_layout::{
189189
types::{F16, F64, U32},
190190
DigitLayout,
@@ -203,6 +203,7 @@ mod test {
203203
cos_layout: TensorLayout::new_dyn(dt_t, &[dyn_(); 2], &[dyn_(); 2]),
204204
cos_base: null(),
205205
theta: 0.,
206+
rope_type: args::RopeType::Rope,
206207
}
207208
}
208209

@@ -227,6 +228,7 @@ mod test {
227228
cos_layout: TensorLayout::new_contiguous(dt_t, &[0, dh]),
228229
cos_base: null(),
229230
theta,
231+
rope_type: args::RopeType::Rope,
230232
}
231233
}
232234

operators/src/rope/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ pub mod infini;
88
pub mod opencl;
99

1010
mod args;
11-
pub use args::Args;
11+
pub use args::{Args,RopeType};
1212

1313
crate::op_trait! { Rope
1414
/// 生成 sincos 表([2, n, dh])。

0 commit comments

Comments
 (0)