@@ -4,12 +4,13 @@ use crate::exec::output_head::OutputHead;
44use crate :: exec:: sample_manager:: SampleManager ;
55use crate :: memory:: MemPages ;
66use crate :: op:: random_sample:: { KVPair , SampleArgs } ;
7- use crate :: utils:: { self , meta} ;
7+ use crate :: utils:: meta;
88use crate :: { handle:: Handle , model:: map_files} ;
99use cuda:: Device ;
1010use ggus:: GGufMetaMapExt ;
1111use nn:: Distribution ;
1212use std:: env;
13+ use std:: time:: Instant ;
1314use tokeneer:: Bpe ;
1415
1516#[ allow( dead_code) ]
@@ -77,6 +78,7 @@ pub fn mamba_infer(
7778 let mut pages = MemPages :: new ( device) ;
7879 let mut mcache = MambaCache :: new ( n_layer, d_inner, d_conv, d_state, & mut pages) ;
7980
81+ let start = Instant :: now ( ) ;
8082 // Prefill
8183 let ( key, _tok_buf) =
8284 models. load_inputs_mamba_prefill ( & mut handle, tokens. len ( ) , & tokens, & stream) ;
@@ -86,8 +88,13 @@ pub fn mamba_infer(
8688 let last_idx: [ tokeneer:: utok ; 1 ] = [ ( tokens. len ( ) - 1 ) as tokeneer:: utok ] ;
8789 let logits_prefill_last = output_head. launch ( x. clone ( ) , & last_idx, & mut handle, & stream) ;
8890
89- let logits_prefill_last_vir = logits_prefill_last. as_ref ( ) . map ( |mem| mem. as_ptr ( ) . cast ( ) ) ;
90- utils:: fmt ( & logits_prefill_last_vir, stream. ctx ( ) ) ;
91+ let prefill_time = start. elapsed ( ) ;
92+ println ! ( "prefill time = {:.2}" , prefill_time. as_secs_f32( ) ) ;
93+
94+ // let logits_prefill_last_vir = logits_prefill_last
95+ // .as_ref()
96+ // .map(|mem| mem.as_ptr().cast::<VirByte>());
97+ // utils::fmt(&logits_prefill_last_vir, stream.ctx());
9198 // check prefill logits
9299
93100 let mut next_id: tokeneer:: utok ;
@@ -97,7 +104,7 @@ pub fn mamba_infer(
97104 let cfg0 = vec ! [ (
98105 crate :: batch:: SessionId ( 0 ) ,
99106 crate :: batch:: SampleInfo {
100- args: SampleArgs :: new( 0.8 , 0.95 , 50 , 1.2 ) . unwrap( ) ,
107+ args: SampleArgs :: new( 0.8 , 0.95 , 50 , 1.3 ) . unwrap( ) ,
101108 input_idx: tokens. len( ) ,
102109 decode_len: tokens. len( ) ,
103110 } ,
@@ -119,7 +126,8 @@ pub fn mamba_infer(
119126 let max_decode_steps: usize = env:: var ( "MAMBA_STEPS" )
120127 . ok ( )
121128 . and_then ( |s| s. parse ( ) . ok ( ) )
122- . unwrap_or ( 100 ) ;
129+ . unwrap_or ( 200 ) ;
130+ println ! ( "max steps = {}" , max_decode_steps) ;
123131 for _step in 1 ..max_decode_steps {
124132 let out_idx: [ tokeneer:: utok ; 1 ] = [ 0 ] ;
125133
@@ -130,7 +138,7 @@ pub fn mamba_infer(
130138 let cfg = vec ! [ (
131139 crate :: batch:: SessionId ( 0 ) ,
132140 crate :: batch:: SampleInfo {
133- args: SampleArgs :: new( 0.8 , 0.95 , 50 , 1.2 ) . unwrap( ) ,
141+ args: SampleArgs :: new( 0.8 , 0.95 , 50 , 1.3 ) . unwrap( ) ,
134142 input_idx: tokens. len( ) ,
135143 decode_len: tokens. len( ) ,
136144 } ,
@@ -151,7 +159,9 @@ pub fn mamba_infer(
151159 x = models. launch_mamba ( key, & mut mcache, & mut handle, & stream) ;
152160 }
153161
154- println ! ( "tokens = {:?}" , tokens) ;
162+ let decode_time = start. elapsed ( ) - prefill_time;
163+ println ! ( "decode time = {:.2}" , decode_time. as_secs_f64( ) ) ;
164+ // println!("tokens = {:?}", tokens);
155165 let mut text_buf = tokeneer:: TextBuf :: new ( ) ;
156166 let s = tokenizer. decode ( & generated, & mut text_buf) ;
157167 let text = String :: from_utf8_lossy ( & s. into_bytes ( ) ) . to_string ( ) ;
@@ -160,22 +170,23 @@ pub fn mamba_infer(
160170 } )
161171}
162172
163- // #[cfg(test)]
164- // mod tests {
165- // use super::*;
166- // use std::{path::PathBuf, time::Instant};
167-
168- // #[test]
169- // fn test_mamba_infer_decode() {
170- // let start = Instant::now();
171- // let model = PathBuf::from("/home/cearx/qy/model/Mamba_adf32-2.8B-hf-v1.0-F16.gguf");
172- // let prompt = "Once upon a time,";
173- // let (text, len) = mamba_infer(model, prompt, false);
174- // let end = Instant::now();
175- // let tokens_per_second = len as f64 / (end - start).as_secs_f64();
176- // println!("infer time = {:?}", end - start);
177- // println!("tokens/s = {}", tokens_per_second);
178- // println!("prompt = {}", prompt);
179- // println!("mamba infer text = {}", text);
180- // }
181- // }
#[cfg(test)]
mod tests {
    use super::*;
    use std::{env, path::PathBuf, time::Instant};

    /// Environment variable that overrides the GGUF model location for this test.
    const MODEL_ENV_VAR: &str = "MAMBA_MODEL";

    /// Model path used when [`MODEL_ENV_VAR`] is not set.
    const DEFAULT_MODEL_PATH: &str = "/home/cearx/Mamba-2.8B-hf-v1.0-F16.gguf";

    /// End-to-end smoke test: runs `mamba_infer` on a short prompt and prints
    /// the wall-clock time, the throughput and the generated text.
    ///
    /// The model file is taken from the `MAMBA_MODEL` environment variable and
    /// falls back to `DEFAULT_MODEL_PATH`. When no file exists at the resolved
    /// path the test prints a notice and returns early, so machines without the
    /// model do not report a spurious failure.
    #[test]
    fn test_mamba_infer_decode() {
        let model = env::var_os(MODEL_ENV_VAR)
            .map(PathBuf::from)
            .unwrap_or_else(|| PathBuf::from(DEFAULT_MODEL_PATH));
        if !model.is_file() {
            eprintln!(
                "skipping test_mamba_infer_decode: model file {} not found (set {} to override)",
                model.display(),
                MODEL_ENV_VAR
            );
            return;
        }

        let prompt = "Once upon a time,";

        // The timer covers the whole `mamba_infer` call, so the reported
        // throughput includes everything that call does, not only decoding.
        let start = Instant::now();
        let (text, len) = mamba_infer(model, prompt, false);
        let infer_time = start.elapsed();

        // NOTE(review): `len` is presumably the number of generated tokens;
        // confirm against the return value of `mamba_infer`.
        let tokens_per_second = len as f64 / infer_time.as_secs_f64();

        println!("infer time = {:.2} s", infer_time.as_secs_f64());
        println!("tokens/s = {:.2}", tokens_per_second);
        println!("prompt = {}", prompt);
        println!("output text = {}", text);
    }
}
0 commit comments