Skip to content

Commit 05e0f9f

Browse files
committed
temp for test
1 parent dcdde0a commit 05e0f9f

File tree

6 files changed

+207
-250
lines changed

6 files changed

+207
-250
lines changed

llama.cu/src/exec/mamba.rs

Lines changed: 37 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@ use crate::exec::output_head::OutputHead;
44
use crate::exec::sample_manager::SampleManager;
55
use crate::memory::MemPages;
66
use crate::op::random_sample::{KVPair, SampleArgs};
7-
use crate::utils::{self, meta};
7+
use crate::utils::meta;
88
use crate::{handle::Handle, model::map_files};
99
use cuda::Device;
1010
use ggus::GGufMetaMapExt;
1111
use nn::Distribution;
1212
use std::env;
13+
use std::time::Instant;
1314
use tokeneer::Bpe;
1415

1516
#[allow(dead_code)]
@@ -77,6 +78,7 @@ pub fn mamba_infer(
7778
let mut pages = MemPages::new(device);
7879
let mut mcache = MambaCache::new(n_layer, d_inner, d_conv, d_state, &mut pages);
7980

81+
let start = Instant::now();
8082
// Prefill
8183
let (key, _tok_buf) =
8284
models.load_inputs_mamba_prefill(&mut handle, tokens.len(), &tokens, &stream);
@@ -86,8 +88,13 @@ pub fn mamba_infer(
8688
let last_idx: [tokeneer::utok; 1] = [(tokens.len() - 1) as tokeneer::utok];
8789
let logits_prefill_last = output_head.launch(x.clone(), &last_idx, &mut handle, &stream);
8890

89-
let logits_prefill_last_vir = logits_prefill_last.as_ref().map(|mem| mem.as_ptr().cast());
90-
utils::fmt(&logits_prefill_last_vir, stream.ctx());
91+
let prefill_time = start.elapsed();
92+
println!("prefill time = {:.2}", prefill_time.as_secs_f32());
93+
94+
// let logits_prefill_last_vir = logits_prefill_last
95+
// .as_ref()
96+
// .map(|mem| mem.as_ptr().cast::<VirByte>());
97+
// utils::fmt(&logits_prefill_last_vir, stream.ctx());
9198
// check prefill logits
9299

93100
let mut next_id: tokeneer::utok;
@@ -97,7 +104,7 @@ pub fn mamba_infer(
97104
let cfg0 = vec![(
98105
crate::batch::SessionId(0),
99106
crate::batch::SampleInfo {
100-
args: SampleArgs::new(0.8, 0.95, 50, 1.2).unwrap(),
107+
args: SampleArgs::new(0.8, 0.95, 50, 1.3).unwrap(),
101108
input_idx: tokens.len(),
102109
decode_len: tokens.len(),
103110
},
@@ -119,7 +126,8 @@ pub fn mamba_infer(
119126
let max_decode_steps: usize = env::var("MAMBA_STEPS")
120127
.ok()
121128
.and_then(|s| s.parse().ok())
122-
.unwrap_or(100);
129+
.unwrap_or(200);
130+
println!("max steps = {}", max_decode_steps);
123131
for _step in 1..max_decode_steps {
124132
let out_idx: [tokeneer::utok; 1] = [0];
125133

@@ -130,7 +138,7 @@ pub fn mamba_infer(
130138
let cfg = vec![(
131139
crate::batch::SessionId(0),
132140
crate::batch::SampleInfo {
133-
args: SampleArgs::new(0.8, 0.95, 50, 1.2).unwrap(),
141+
args: SampleArgs::new(0.8, 0.95, 50, 1.3).unwrap(),
134142
input_idx: tokens.len(),
135143
decode_len: tokens.len(),
136144
},
@@ -151,7 +159,9 @@ pub fn mamba_infer(
151159
x = models.launch_mamba(key, &mut mcache, &mut handle, &stream);
152160
}
153161

154-
println!("tokens = {:?}", tokens);
162+
let decode_time = start.elapsed() - prefill_time;
163+
println!("decode time = {:.2}", decode_time.as_secs_f64());
164+
// println!("tokens = {:?}", tokens);
155165
let mut text_buf = tokeneer::TextBuf::new();
156166
let s = tokenizer.decode(&generated, &mut text_buf);
157167
let text = String::from_utf8_lossy(&s.into_bytes()).to_string();
@@ -160,22 +170,23 @@ pub fn mamba_infer(
160170
})
161171
}
162172

163-
// #[cfg(test)]
164-
// mod tests {
165-
// use super::*;
166-
// use std::{path::PathBuf, time::Instant};
167-
168-
// #[test]
169-
// fn test_mamba_infer_decode() {
170-
// let start = Instant::now();
171-
// let model = PathBuf::from("/home/cearx/qy/model/Mamba_adf32-2.8B-hf-v1.0-F16.gguf");
172-
// let prompt = "Once upon a time,";
173-
// let (text, len) = mamba_infer(model, prompt, false);
174-
// let end = Instant::now();
175-
// let tokens_per_second = len as f64 / (end - start).as_secs_f64();
176-
// println!("infer time = {:?}", end - start);
177-
// println!("tokens/s = {}", tokens_per_second);
178-
// println!("prompt = {}", prompt);
179-
// println!("mamba infer text = {}", text);
180-
// }
181-
// }
173+
#[cfg(test)]
mod tests {
    use super::*;
    use std::{path::PathBuf, time::Instant};

    /// End-to-end smoke test: prefill + decode through `mamba_infer`,
    /// reporting wall-clock time and tokens/s.
    ///
    /// The model location defaults to the original developer-local path but
    /// can be overridden with the `MAMBA_MODEL` environment variable, so the
    /// test is runnable on other machines. This mirrors the `MAMBA_STEPS`
    /// env override already used inside `mamba_infer`.
    #[test]
    fn test_mamba_infer_decode() {
        let start = Instant::now();

        // Hard-coded absolute path made overridable; default preserves the
        // previous behavior for the original author's machine.
        let model = PathBuf::from(
            std::env::var("MAMBA_MODEL")
                .unwrap_or_else(|_| "/home/cearx/Mamba-2.8B-hf-v1.0-F16.gguf".to_string()),
        );
        let prompt = "Once upon a time,";

        let (text, len) = mamba_infer(model, prompt, false);

        // `elapsed()` is the idiomatic form of `Instant::now() - start`.
        let infer_time = start.elapsed();
        let tokens_per_second = len as f64 / infer_time.as_secs_f64();

        println!("infer time = {:.2} s", infer_time.as_secs_f64());
        println!("tokens/s = {:.2}", tokens_per_second);
        println!("prompt = {}", prompt);
        println!("output text = {}", text);
    }
}

0 commit comments

Comments
 (0)