@@ -1254,15 +1254,59 @@ std::map<ggml_type, uint32_t> ModelLoader::get_vae_wtype_stat() {
     return wtype_stat;
 }
 
-void ModelLoader::set_wtype_override(ggml_type wtype, std::string prefix) {
+static std::vector<std::pair<std::string, ggml_type>> parse_tensor_type_rules(const std::string& tensor_type_rules) {
+    std::vector<std::pair<std::string, ggml_type>> result;
+    for (const auto& item : split_string(tensor_type_rules, ',')) {
+        if (item.size() == 0)
+            continue;
+        std::string::size_type pos = item.find('=');
+        if (pos == std::string::npos) {
+            LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
+            continue;
+        }
+        std::string tensor_pattern = item.substr(0, pos);
+        std::string type_name      = item.substr(pos + 1);
+
+        ggml_type tensor_type = GGML_TYPE_COUNT;
+
+        if (type_name == "f32") {
+            tensor_type = GGML_TYPE_F32;
+        } else {
+            for (size_t i = 0; i < GGML_TYPE_COUNT; i++) {
+                auto trait = ggml_get_type_traits((ggml_type)i);
+                if (trait->to_float && trait->type_size && type_name == trait->type_name) {
+                    tensor_type = (ggml_type)i;
+                }
+            }
+        }
+
+        if (tensor_type != GGML_TYPE_COUNT) {
+            result.emplace_back(tensor_pattern, tensor_type);
+        } else {
+            LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
+        }
+    }
+    return result;
+}
+
+void ModelLoader::set_wtype_override(ggml_type wtype, std::string tensor_type_rules) {
+    auto map_rules = parse_tensor_type_rules(tensor_type_rules);
     for (auto& [name, tensor_storage] : tensor_storage_map) {
-        if (!starts_with(name, prefix)) {
+        ggml_type dst_type = wtype;
+        for (const auto& tensor_type_rule : map_rules) {
+            std::regex pattern(tensor_type_rule.first);
+            if (std::regex_search(name, pattern)) {
+                dst_type = tensor_type_rule.second;
+                break;
+            }
+        }
+        if (dst_type == GGML_TYPE_COUNT) {
             continue;
         }
-        if (!tensor_should_be_converted(tensor_storage, wtype)) {
+        if (!tensor_should_be_converted(tensor_storage, dst_type)) {
             continue;
         }
-        tensor_storage.expected_type = wtype;
+        tensor_storage.expected_type = dst_type;
     }
 }
 
@@ -1603,41 +1647,6 @@ bool ModelLoader::load_tensors(std::map<std::string, struct ggml_tensor*>& tenso
     return true;
 }
 
-std::vector<std::pair<std::string, ggml_type>> parse_tensor_type_rules(const std::string& tensor_type_rules) {
-    std::vector<std::pair<std::string, ggml_type>> result;
-    for (const auto& item : split_string(tensor_type_rules, ',')) {
-        if (item.size() == 0)
-            continue;
-        std::string::size_type pos = item.find('=');
-        if (pos == std::string::npos) {
-            LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
-            continue;
-        }
-        std::string tensor_pattern = item.substr(0, pos);
-        std::string type_name      = item.substr(pos + 1);
-
-        ggml_type tensor_type = GGML_TYPE_COUNT;
-
-        if (type_name == "f32") {
-            tensor_type = GGML_TYPE_F32;
-        } else {
-            for (size_t i = 0; i < GGML_TYPE_COUNT; i++) {
-                auto trait = ggml_get_type_traits((ggml_type)i);
-                if (trait->to_float && trait->type_size && type_name == trait->type_name) {
-                    tensor_type = (ggml_type)i;
-                }
-            }
-        }
-
-        if (tensor_type != GGML_TYPE_COUNT) {
-            result.emplace_back(tensor_pattern, tensor_type);
-        } else {
-            LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
-        }
-    }
-    return result;
-}
-
 bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type) {
     const std::string& name = tensor_storage.name;
     if (type != GGML_TYPE_COUNT) {
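As the diff reads, the rule string accepted by the new set_wtype_override is a comma-separated list of pattern=type pairs: each pattern is compiled as a std::regex and tested against tensor names with std::regex_search, the first matching rule wins, and tensors with no match fall back to the default wtype. A minimal sketch of a call site, assuming a ModelLoader instance named loader whose tensor_storage_map has already been populated; the patterns and type choices below are illustrative, not taken from the commit:

    ModelLoader loader;
    // ... model files scanned into loader.tensor_storage_map beforehand ...

    // Default convertible tensors to q8_0, but keep tensors whose names
    // match "attn" in f16 (first matching rule wins).
    loader.set_wtype_override(GGML_TYPE_Q8_0, "attn=f16");

    // Passing GGML_TYPE_COUNT as the default leaves unmatched tensors
    // untouched, so only rule-matched tensors receive an override here.
    loader.set_wtype_override(GGML_TYPE_COUNT, "ffn_.*=q4_0");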