11use crate :: {
22 fuesd_softmax:: AttnMask ,
33 utils:: { dim_distinct, rank_error, type_distinct} ,
4- ConstPtr , Hardware , LaunchError , MaybeDyn , MutPtr , TensorLayout ,
4+ ConstPtr , Hardware , LaunchError , MutPtr , TensorLayout ,
55} ;
66use digit_layout:: DigitLayout ;
77
@@ -23,11 +23,11 @@ pub struct Args<H: Hardware> {
2323
2424pub ( super ) struct Meta {
2525 pub dt : DigitLayout ,
26- pub nh : MaybeDyn < usize > ,
27- pub nkvh : MaybeDyn < usize > ,
28- pub seq : MaybeDyn < usize > ,
29- pub att : MaybeDyn < usize > ,
30- pub dh : MaybeDyn < usize > ,
26+ pub nh : usize ,
27+ pub nkvh : usize ,
28+ pub seq : usize ,
29+ pub att : usize ,
30+ pub dh : usize ,
3131}
3232
3333impl < H : Hardware > Args < H > {
@@ -40,26 +40,26 @@ impl<H: Hardware> Args<H> {
4040 ..
4141 } = self ;
4242
43- let & [ nh_q, seq_q, dh_q] = q_layout. shape ( ) else {
43+ let & [ nh_q, seq_q, dh_q] = & * q_layout. shape ( ) else {
4444 return Err ( rank_error ( "q" , 3 , q_layout. ndim ( ) ) ) ;
4545 } ;
46- let & [ nkvh_k, att_k, dh_k] = k_layout. shape ( ) else {
46+ let & [ nkvh_k, att_k, dh_k] = & * k_layout. shape ( ) else {
4747 return Err ( rank_error ( "k" , 3 , k_layout. ndim ( ) ) ) ;
4848 } ;
49- let & [ nkvh_v, att_v, dh_v] = v_layout. shape ( ) else {
49+ let & [ nkvh_v, att_v, dh_v] = & * v_layout. shape ( ) else {
5050 return Err ( rank_error ( "v" , 3 , v_layout. ndim ( ) ) ) ;
5151 } ;
52- let & [ nh_o, seq_o, dh_o] = o_layout. shape ( ) else {
52+ let & [ nh_o, seq_o, dh_o] = & * o_layout. shape ( ) else {
5353 return Err ( rank_error ( "o" , 3 , o_layout. ndim ( ) ) ) ;
5454 } ;
5555
5656 Ok ( Meta {
57- dt : type_distinct ( & [ q_layout. dt ( ) , k_layout. dt ( ) , v_layout. dt ( ) , o_layout. dt ( ) ] ) ?,
58- nh : dim_distinct ( & [ nh_q, nh_o] ) ? ,
59- nkvh : dim_distinct ( & [ nkvh_k, nkvh_v] ) ? ,
60- seq : dim_distinct ( & [ seq_q, seq_o] ) ? ,
61- att : dim_distinct ( & [ att_k, att_v] ) ? ,
62- dh : dim_distinct ( & [ dh_q, dh_k, dh_v, dh_o] ) ? ,
57+ dt : type_distinct ( & [ q_layout. dt , k_layout. dt , v_layout. dt , o_layout. dt ] ) ?,
58+ nh : dim_distinct ( & [ nh_q, nh_o] ) . expect ( "nh mismatch" ) ,
59+ nkvh : dim_distinct ( & [ nkvh_k, nkvh_v] ) . expect ( "nkvh mismatch" ) ,
60+ seq : dim_distinct ( & [ seq_q, seq_o] ) . expect ( "seq mismatch" ) ,
61+ att : dim_distinct ( & [ att_k, att_v] ) . expect ( "att mismatch" ) ,
62+ dh : dim_distinct ( & [ dh_q, dh_k, dh_v, dh_o] ) . expect ( "dh mismatch" ) ,
6363 } )
6464 }
6565}
0 commit comments