|
| 1 | +mod gguf; |
| 2 | + |
| 3 | +use digit_layout::DigitLayout; |
| 4 | +use gguf::{Content, GGufTensor}; |
| 5 | +use ggus::{GGufMetaDataValueType, GGufMetaMap, GGufMetaMapExt}; |
| 6 | +use memmap2::Mmap; |
| 7 | +use ndarray_layout::{ArrayLayout, Endian::BigEndian}; |
| 8 | +use std::{collections::HashMap, env::var_os, fs::File}; |
| 9 | + |
| 10 | +/// 测例数据。 |
| 11 | +pub struct TestCase<'a> { |
| 12 | + /// 测例在文件中的序号。 |
| 13 | + pub index: usize, |
| 14 | + /// 测例元信息,即传递给算子非张量参数。 |
| 15 | + pub attributes: HashMap<String, MetaValue<'a>>, |
| 16 | + /// 测例张量,包括算子的输入和正确答案。 |
| 17 | + pub tensors: HashMap<String, Tensor<'a>>, |
| 18 | +} |
| 19 | + |
| 20 | +/// 元信息键值对。 |
| 21 | +pub struct MetaValue<'a> { |
| 22 | + pub ty: GGufMetaDataValueType, |
| 23 | + pub value: &'a [u8], |
| 24 | +} |
| 25 | + |
| 26 | +/// 测例张量。 |
| 27 | +pub struct Tensor<'a> { |
| 28 | + pub ty: DigitLayout, |
| 29 | + pub layout: ArrayLayout<4>, |
| 30 | + pub data: &'a [u8], |
| 31 | +} |
| 32 | + |
| 33 | +impl GGufMetaMap for TestCase<'_> { |
| 34 | + fn get(&self, key: &str) -> Option<(GGufMetaDataValueType, &[u8])> { |
| 35 | + self.attributes.get(key).map(|v| (v.ty, &*v.value)) |
| 36 | + } |
| 37 | +} |
| 38 | + |
| 39 | +impl<'a> Content<'a> { |
| 40 | + pub fn into_cases(self) -> HashMap<String, Vec<TestCase<'a>>> { |
| 41 | + assert_eq!(self.general_architecture().unwrap(), "infiniop-test"); |
| 42 | + let mut ans = HashMap::new(); |
| 43 | + |
| 44 | + let ntest = self.get_usize("test_count").unwrap(); |
| 45 | + let Self { |
| 46 | + mut meta_kvs, |
| 47 | + mut tensors, |
| 48 | + } = self; |
| 49 | + for i in 0..ntest { |
| 50 | + let prefix = format!("test.{i}."); |
| 51 | + |
| 52 | + let mut meta_kvs = meta_kvs |
| 53 | + .split_by_prefix(&prefix) |
| 54 | + .into_iter() |
| 55 | + .map(|(k, v)| (k[prefix.len()..].to_string(), v)) |
| 56 | + .collect::<HashMap<_, _>>(); |
| 57 | + let tensors = tensors |
| 58 | + .split_by_prefix(&prefix) |
| 59 | + .into_iter() |
| 60 | + .map(|(k, v)| { |
| 61 | + let GGufTensor { |
| 62 | + ty, |
| 63 | + mut shape, |
| 64 | + data, |
| 65 | + } = v; |
| 66 | + shape.reverse(); |
| 67 | + |
| 68 | + let k = k[prefix.len()..].to_string(); |
| 69 | + let element_size = ty.nbytes(); |
| 70 | + let layout = if let Some(strides) = meta_kvs.remove(&format!("{k}.strides")) { |
| 71 | + let mut strides = strides.to_vec_isize().unwrap(); |
| 72 | + for x in &mut strides { |
| 73 | + *x *= element_size as isize |
| 74 | + } |
| 75 | + strides.reverse(); |
| 76 | + |
| 77 | + ArrayLayout::<4>::new(&shape, &strides, 0) |
| 78 | + } else { |
| 79 | + ArrayLayout::<4>::new_contiguous(&shape, BigEndian, element_size) |
| 80 | + }; |
| 81 | + (k, Tensor { ty, layout, data }) |
| 82 | + }) |
| 83 | + .collect(); |
| 84 | + |
| 85 | + let case = TestCase { |
| 86 | + index: i, |
| 87 | + attributes: meta_kvs, |
| 88 | + tensors, |
| 89 | + }; |
| 90 | + |
| 91 | + let op_name = case.get_str("op_name").unwrap().to_string(); |
| 92 | + ans.entry(op_name).or_insert_with(Vec::new).push(case); |
| 93 | + } |
| 94 | + |
| 95 | + ans |
| 96 | + } |
| 97 | +} |
| 98 | + |
| 99 | +#[test] |
| 100 | +fn test() { |
| 101 | + let Some(name) = var_os("TEST_CASES") else { |
| 102 | + eprintln!("TEST_CASES not set"); |
| 103 | + return; |
| 104 | + }; |
| 105 | + let Ok(file) = File::open(&name) else { |
| 106 | + eprintln!("Failed to open {}", name.to_string_lossy()); |
| 107 | + return; |
| 108 | + }; |
| 109 | + let mmap = unsafe { Mmap::map(&file).unwrap() }; |
| 110 | + let cases = Content::new(&[&*mmap]).unwrap().into_cases(); |
| 111 | + |
| 112 | + for (op_name, cases) in cases { |
| 113 | + for case in cases { |
| 114 | + println!("Test case {}: {op_name}", case.index); |
| 115 | + for (k, v) in &case.attributes { |
| 116 | + println!(" {k}: {:?}", v.ty) |
| 117 | + } |
| 118 | + for (k, v) in &case.tensors { |
| 119 | + println!( |
| 120 | + " {k}: {} {:?} / {:?} {}", |
| 121 | + v.ty, |
| 122 | + v.layout.shape(), |
| 123 | + v.layout.strides(), |
| 124 | + v.data.len() |
| 125 | + ) |
| 126 | + } |
| 127 | + } |
| 128 | + } |
| 129 | +} |
0 commit comments