1- use super :: { args:: Meta , fill_pos, Args , Rope , Seq , SinCosTable } ;
1+ use std:: ptr:: null;
2+
3+ use super :: { args:: Meta , args:: RopeType as R , fill_pos, Args , Rope , Seq , SinCosTable } ;
24use crate :: {
35 common_cpu:: Cpu , get_static, strides_not_support, ByteOf , LaunchError , QueueAlloc , SchemeError ,
46 Unsigned ,
57} ;
68use digit_layout:: { types as ty, DigitLayout } ;
79use half:: f16;
8-
/// Selects which sin/cos formula is applied to each pair during the rope computation.
///
/// The raw pointers in [`SchemeType::Long`] are borrowed from the caller's arguments;
/// this type neither owns nor frees them.
#[derive(Copy, Clone)]
enum SchemeType<T> {
    /// Plain rotary embedding that only needs the position and `theta`.
    ///
    /// The launch code also uses this variant for the `Dyn` and `Ntk` rope types,
    /// after folding their scaling into `theta`.
    Rope,
    /// Long-context variant driven by two externally supplied factor tables.
    Long {
        /// Factor table used when the position is not below `origin_pos`.
        long: *const T,
        /// Factor table used when the position is below `origin_pos`.
        short: *const T,
        /// Scale applied to both the sine and the cosine.
        s: f32,
        /// Position threshold that switches between `short` and `long`.
        origin_pos: u32,
    },
    /// Placeholder: the visible computation path hits `todo!()` for this variant.
    Yarn,
}
/// CPU rope operator: a field-less type that carries the `Rope<Cpu>` implementation.
pub struct Operator;
1022
1123impl Rope < Cpu > for Operator {
@@ -78,6 +90,7 @@ impl crate::Operator for Operator {
7890 p_layout,
7991 p_base,
8092 theta,
93+ rope_type,
8194 ..
8295 } = args;
8396 let & [ _, nh, dh] = t_layout. shape ( ) else {
@@ -99,33 +112,89 @@ impl crate::Operator for Operator {
99112 return Err ( strides_not_support ( "" ) . into ( ) ) ;
100113 }
101114
102- macro_rules! calculate {
103- ( $t: ty, $p: ty) => {
104- Scheme :: <$t, $p> {
105- nt,
106- nh,
107- dh,
108- st,
109- sh,
110- sp,
111- theta: * theta,
112- t_base: t_base. cast( ) ,
113- p_base: p_base. cast( ) ,
115+ match rope_type {
116+ R :: Rope | R :: Dyn { .. } | R :: Ntk { .. } => {
117+ let theta = match rope_type {
118+ R :: Rope => * theta,
119+ R :: Dyn { s, a } => theta * ( a * s - a + 1. ) ,
120+ R :: Ntk { s } => theta * s,
121+ _ => unreachable ! ( ) ,
122+ } ;
123+ macro_rules! calculate {
124+ ( $t: ty, $p: ty) => {
125+ Scheme :: <$t, $p> {
126+ nt,
127+ nh,
128+ dh,
129+ st,
130+ sh,
131+ sp,
132+ theta,
133+ t_base: t_base. cast( ) ,
134+ p_base: p_base. cast( ) ,
135+ scheme_type: SchemeType :: Rope ,
136+ }
137+ . calculate( )
138+ } ;
114139 }
115- . calculate( )
116- } ;
117- }
118140
119- use digit_layout:: types as ty;
120- match ( dt_t, dt_p) {
121- ( ty:: F16 , ty:: U32 ) => calculate ! ( f16, u32 ) ,
122- ( ty:: F16 , ty:: U64 ) => calculate ! ( f16, u64 ) ,
123- ( ty:: F32 , ty:: U32 ) => calculate ! ( f32 , u32 ) ,
124- ( ty:: F32 , ty:: U64 ) => calculate ! ( f32 , u64 ) ,
125- ( ty:: F64 , ty:: U32 ) => calculate ! ( f64 , u32 ) ,
126- ( ty:: F64 , ty:: U64 ) => calculate ! ( f64 , u64 ) ,
127- _ => todo ! ( ) ,
141+ use digit_layout:: types as ty;
142+ match ( dt_t, dt_p) {
143+ ( ty:: F16 , ty:: U32 ) => calculate ! ( f16, u32 ) ,
144+ ( ty:: F16 , ty:: U64 ) => calculate ! ( f16, u64 ) ,
145+ ( ty:: F32 , ty:: U32 ) => calculate ! ( f32 , u32 ) ,
146+ ( ty:: F32 , ty:: U64 ) => calculate ! ( f32 , u64 ) ,
147+ ( ty:: F64 , ty:: U32 ) => calculate ! ( f64 , u32 ) ,
148+ ( ty:: F64 , ty:: U64 ) => calculate ! ( f64 , u64 ) ,
149+ _ => todo ! ( ) ,
150+ }
151+ }
152+ R :: Long {
153+ long,
154+ short,
155+ max_pos,
156+ origin_pos,
157+ } => {
158+ let s = 1.0
159+ + ( ( * max_pos as f32 / * origin_pos as f32 ) . ln ( ) / ( * origin_pos as f32 ) . ln ( ) )
160+ . sqrt ( ) ;
161+ macro_rules! calculate {
162+ ( $t: ty, $p: ty) => {
163+ Scheme :: <$t, $p> {
164+ nt,
165+ nh,
166+ dh,
167+ st,
168+ sh,
169+ sp,
170+ theta: * theta,
171+ t_base: t_base. cast( ) ,
172+ p_base: p_base. cast( ) ,
173+ scheme_type: SchemeType :: Long {
174+ long: long. cast( ) ,
175+ short: short. cast( ) ,
176+ s,
177+ origin_pos: * origin_pos,
178+ } ,
179+ }
180+ . calculate( )
181+ } ;
182+ }
183+
184+ use digit_layout:: types as ty;
185+ match ( dt_t, dt_p) {
186+ ( ty:: F16 , ty:: U32 ) => calculate ! ( f16, u32 ) ,
187+ ( ty:: F16 , ty:: U64 ) => calculate ! ( f16, u64 ) ,
188+ ( ty:: F32 , ty:: U32 ) => calculate ! ( f32 , u32 ) ,
189+ ( ty:: F32 , ty:: U64 ) => calculate ! ( f32 , u64 ) ,
190+ ( ty:: F64 , ty:: U32 ) => calculate ! ( f64 , u32 ) ,
191+ ( ty:: F64 , ty:: U64 ) => calculate ! ( f64 , u64 ) ,
192+ _ => todo ! ( ) ,
193+ }
194+ }
195+ _ => { }
128196 }
197+
129198 Ok ( ( ) )
130199 }
131200}
@@ -142,15 +211,15 @@ struct Scheme<A, P> {
142211 theta : f32 ,
143212 t_base : * mut A ,
144213 p_base : * const P ,
214+ scheme_type : SchemeType < A > ,
145215}
146216
147217unsafe impl < A , P > Send for Scheme < A , P > { }
148218unsafe impl < A , P > Sync for Scheme < A , P > { }
149-
/// Activation value.
trait Activation: Sized {
    /// The activation type determines the calculation type.
    ///
    /// It must be `Copy` so the same sin/cos values can be passed around by value.
    type Calculation: Copy;
    /// Calculation procedure: receives one pair of activations mutably, together with
    /// the sine and cosine to apply to it.
    fn calculate(pair: &mut [Self; 2], sin: Self::Calculation, cos: Self::Calculation);
}
@@ -187,16 +256,37 @@ impl Activation for f64 {
187256}
188257
/// A position value that can produce the `(sin, cos)` pair for one frequency index,
/// expressed in the `Calculation` numeric type.
trait Position<Calculation> {
    /// Returns `(sin, cos)` for frequency index `k` out of `dh`, with base `theta`
    /// (plain rope scheme).
    fn freq_sin_cos_rope(self, k: isize, dh: isize, theta: f32) -> (Calculation, Calculation);
    /// Returns `(sin, cos)` for frequency index `k` out of `dh`, with base `theta`
    /// (long scheme).
    ///
    /// `factor` is the per-call frequency factor and `s` the scale for the returned pair.
    fn freq_sin_cos_long(
        self,
        k: isize,
        dh: isize,
        theta: f32,
        factor: Calculation,
        s: f32,
    ) -> (Calculation, Calculation);
}
192269
/// Implements [`Position`] with calculation type `$a` for every `Unsigned` position type.
macro_rules! impl_position {
    ($a:ty) => {
        impl<T: Unsigned> Position<$a> for T {
            /// `sin_cos` of `pos / theta^(k / dh)`.
            #[inline]
            fn freq_sin_cos_rope(self, k: isize, dh: isize, theta: f32) -> ($a, $a) {
                (self.val() as $a / (theta as $a).powf(k as $a / dh as $a)).sin_cos()
            }
            /// `sin_cos` of `pos / theta^(k / dh) * factor`, with both results scaled by `s`.
            #[inline]
            fn freq_sin_cos_long(
                self,
                k: isize,
                dh: isize,
                theta: f32,
                factor: $a,
                s: f32,
            ) -> ($a, $a) {
                // NOTE(review): the angle is MULTIPLIED by `factor` here. Reference LongRoPE
                // implementations divide the frequency by the factor
                // (`pos / (factor * theta^(k / dh))`) — confirm which convention the
                // supplied factor tables follow.
                let (sin, cos) =
                    (self.val() as $a / (theta as $a).powf(k as $a / dh as $a) * factor).sin_cos();
                (sin * s as $a, cos * s as $a)
            }
        }
    };
}
@@ -206,8 +296,8 @@ impl_position!(f64);
206296
207297impl < A , P > Scheme < A , P >
208298where
209- A : Activation ,
210- P : Position < A :: Calculation > + Sync + Copy ,
299+ A : Activation + Copy ,
300+ P : Position < A :: Calculation > + Sync + Copy + Unsigned ,
211301{
212302 fn calculate ( & self ) {
213303 let & Self {
@@ -220,6 +310,7 @@ where
220310 theta,
221311 t_base,
222312 p_base,
313+ scheme_type,
223314 } = self ;
224315 let nt = nt as isize ;
225316 let nh = nh as isize ;
@@ -229,10 +320,39 @@ where
229320 for i in 0 ..nt {
230321 let t = unsafe { t_base. byte_offset ( i * st) . cast :: < [ A ; 2 ] > ( ) } ;
231322 let p = unsafe { * p_base. byte_offset ( i * sp) } ;
323+ let factor = match scheme_type {
324+ SchemeType :: Long {
325+ long,
326+ short,
327+ origin_pos,
328+ ..
329+ } => unsafe {
330+ if p. val ( ) < origin_pos as usize {
331+ short. byte_offset ( i * st) . cast ( )
332+ } else {
333+ long. byte_offset ( i * st) . cast ( )
334+ }
335+ } ,
336+ _ => null ( ) ,
337+ } ;
232338 for j in 0 ..nh {
233339 for k in 0 ..dh {
234340 let pair = unsafe { & mut * t. byte_offset ( j * sh + k * sd) } ;
235- let ( sin, cos) = p. freq_sin_cos ( k, dh, theta) ;
341+ let ( sin, cos) = match scheme_type {
342+ SchemeType :: Rope => p. freq_sin_cos_rope ( k, dh, theta) ,
343+ SchemeType :: Long {
344+ long,
345+ short,
346+ origin_pos,
347+ s,
348+ } => {
349+ let factor = unsafe { * factor } ;
350+ p. freq_sin_cos_long ( k, dh, theta, factor, s)
351+ }
352+ _ => {
353+ todo ! ( )
354+ }
355+ } ;
236356 A :: calculate ( pair, sin, cos)
237357 }
238358 }
0 commit comments