
Commit 199e675

feat: support for --tensor-type-rules on generation modes (#932)
1 parent 742a733 commit 199e675

5 files changed: +59 -49 lines changed

examples/cli/main.cpp

Lines changed: 1 addition & 4 deletions
@@ -1241,10 +1241,6 @@ void parse_args(int argc, const char** argv, SDParams& params) {
         exit(1);
     }
 
-    if (params.mode != CONVERT && params.tensor_type_rules.size() > 0) {
-        fprintf(stderr, "warning: --tensor-type-rules is currently supported only for conversion\n");
-    }
-
     if (params.mode == VID_GEN && params.video_frames <= 0) {
         fprintf(stderr, "warning: --video-frames must be at least 1\n");
         exit(1);
@@ -1756,6 +1752,7 @@ int main(int argc, const char* argv[]) {
         params.lora_model_dir.c_str(),
         params.embedding_dir.c_str(),
         params.photo_maker_path.c_str(),
+        params.tensor_type_rules.c_str(),
         vae_decode_only,
         true,
         params.n_threads,
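For reference, the value accepted by --tensor-type-rules is a comma-separated list of pattern=type pairs: each pattern is a regex matched against tensor names and each type is a ggml type name (or "f32"), as parsed in model.cpp below. An illustrative value for a generation run (the patterns and type choices here are made up) would be:

    --type q4_0 --tensor-type-rules "vae\.=f16,attn=q8_0"

which converts tensors whose names match "vae\." to f16 and tensors matching "attn" to q8_0 (subject to the existing tensor_should_be_converted checks), while everything else follows the global --type override.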

model.cpp

Lines changed: 48 additions & 39 deletions
@@ -1254,15 +1254,59 @@ std::map<ggml_type, uint32_t> ModelLoader::get_vae_wtype_stat() {
     return wtype_stat;
 }
 
-void ModelLoader::set_wtype_override(ggml_type wtype, std::string prefix) {
+static std::vector<std::pair<std::string, ggml_type>> parse_tensor_type_rules(const std::string& tensor_type_rules) {
+    std::vector<std::pair<std::string, ggml_type>> result;
+    for (const auto& item : split_string(tensor_type_rules, ',')) {
+        if (item.size() == 0)
+            continue;
+        std::string::size_type pos = item.find('=');
+        if (pos == std::string::npos) {
+            LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
+            continue;
+        }
+        std::string tensor_pattern = item.substr(0, pos);
+        std::string type_name = item.substr(pos + 1);
+
+        ggml_type tensor_type = GGML_TYPE_COUNT;
+
+        if (type_name == "f32") {
+            tensor_type = GGML_TYPE_F32;
+        } else {
+            for (size_t i = 0; i < GGML_TYPE_COUNT; i++) {
+                auto trait = ggml_get_type_traits((ggml_type)i);
+                if (trait->to_float && trait->type_size && type_name == trait->type_name) {
+                    tensor_type = (ggml_type)i;
+                }
+            }
+        }
+
+        if (tensor_type != GGML_TYPE_COUNT) {
+            result.emplace_back(tensor_pattern, tensor_type);
+        } else {
+            LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
+        }
+    }
+    return result;
+}
+
+void ModelLoader::set_wtype_override(ggml_type wtype, std::string tensor_type_rules) {
+    auto map_rules = parse_tensor_type_rules(tensor_type_rules);
     for (auto& [name, tensor_storage] : tensor_storage_map) {
-        if (!starts_with(name, prefix)) {
+        ggml_type dst_type = wtype;
+        for (const auto& tensor_type_rule : map_rules) {
+            std::regex pattern(tensor_type_rule.first);
+            if (std::regex_search(name, pattern)) {
+                dst_type = tensor_type_rule.second;
+                break;
+            }
+        }
+        if (dst_type == GGML_TYPE_COUNT) {
             continue;
         }
-        if (!tensor_should_be_converted(tensor_storage, wtype)) {
+        if (!tensor_should_be_converted(tensor_storage, dst_type)) {
            continue;
        }
-        tensor_storage.expected_type = wtype;
+        tensor_storage.expected_type = dst_type;
    }
 }
 
@@ -1603,41 +1647,6 @@ bool ModelLoader::load_tensors(std::map<std::string, struct ggml_tensor*>& tensors,
     return true;
 }
 
-std::vector<std::pair<std::string, ggml_type>> parse_tensor_type_rules(const std::string& tensor_type_rules) {
-    std::vector<std::pair<std::string, ggml_type>> result;
-    for (const auto& item : split_string(tensor_type_rules, ',')) {
-        if (item.size() == 0)
-            continue;
-        std::string::size_type pos = item.find('=');
-        if (pos == std::string::npos) {
-            LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
-            continue;
-        }
-        std::string tensor_pattern = item.substr(0, pos);
-        std::string type_name = item.substr(pos + 1);
-
-        ggml_type tensor_type = GGML_TYPE_COUNT;
-
-        if (type_name == "f32") {
-            tensor_type = GGML_TYPE_F32;
-        } else {
-            for (size_t i = 0; i < GGML_TYPE_COUNT; i++) {
-                auto trait = ggml_get_type_traits((ggml_type)i);
-                if (trait->to_float && trait->type_size && type_name == trait->type_name) {
-                    tensor_type = (ggml_type)i;
-                }
-            }
-        }
-
-        if (tensor_type != GGML_TYPE_COUNT) {
-            result.emplace_back(tensor_pattern, tensor_type);
-        } else {
-            LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
-        }
-    }
-    return result;
-}
-
 bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type) {
     const std::string& name = tensor_storage.name;
     if (type != GGML_TYPE_COUNT) {
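The selection in set_wtype_override is first-match-wins: each rule's pattern is compiled as a std::regex and tested against the tensor name with std::regex_search (so it matches anywhere in the name), and the global wtype applies only when no rule matches. A minimal standalone sketch of that behavior, with made-up tensor names and type-name strings standing in for ggml_type values (it does not use the ModelLoader internals):

    // rule_match_demo.cpp -- illustrative only; mirrors the first-match-wins
    // rule selection used by ModelLoader::set_wtype_override.
    #include <iostream>
    #include <regex>
    #include <string>
    #include <utility>
    #include <vector>

    int main() {
        // pattern -> type name, in the order the rules were written
        std::vector<std::pair<std::string, std::string>> rules = {
            {"vae\\.", "f16"},  // any tensor whose name contains "vae."
            {"attn", "q8_0"},   // any attention tensor
        };
        std::string fallback = "q4_0";  // plays the role of the global wtype

        std::vector<std::string> names = {
            "vae.decoder.conv_in.weight",
            "model.diffusion_model.middle_block.1.attn1.to_q.weight",
            "model.diffusion_model.input_blocks.0.0.weight",
        };

        for (const auto& name : names) {
            std::string chosen = fallback;
            for (const auto& rule : rules) {
                if (std::regex_search(name, std::regex(rule.first))) {
                    chosen = rule.second;  // first matching rule wins
                    break;
                }
            }
            std::cout << name << " -> " << chosen << "\n";
        }
    }

Because the patterns are unanchored, a broad rule like "attn" would also catch VAE attention tensors unless a more specific rule (here "vae\.") appears before it in the list.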

model.h

Lines changed: 1 addition & 1 deletion
@@ -292,7 +292,7 @@ class ModelLoader {
     std::map<ggml_type, uint32_t> get_diffusion_model_wtype_stat();
     std::map<ggml_type, uint32_t> get_vae_wtype_stat();
     String2TensorStorage& get_tensor_storage_map() { return tensor_storage_map; }
-    void set_wtype_override(ggml_type wtype, std::string prefix = "");
+    void set_wtype_override(ggml_type wtype, std::string tensor_type_rules = "");
     bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads = 0);
     bool load_tensors(std::map<std::string, struct ggml_tensor*>& tensors,
                       std::set<std::string> ignore_tensors = {},

stable-diffusion.cpp

Lines changed: 8 additions & 5 deletions
@@ -304,11 +304,12 @@ class StableDiffusionGGML {
         }
 
         LOG_INFO("Version: %s ", model_version_to_str[version]);
-        ggml_type wtype = (int)sd_ctx_params->wtype < std::min<int>(SD_TYPE_COUNT, GGML_TYPE_COUNT)
-                              ? (ggml_type)sd_ctx_params->wtype
-                              : GGML_TYPE_COUNT;
-        if (wtype != GGML_TYPE_COUNT) {
-            model_loader.set_wtype_override(wtype);
+        ggml_type wtype = (int)sd_ctx_params->wtype < std::min<int>(SD_TYPE_COUNT, GGML_TYPE_COUNT)
+                              ? (ggml_type)sd_ctx_params->wtype
+                              : GGML_TYPE_COUNT;
+        std::string tensor_type_rules = SAFE_STR(sd_ctx_params->tensor_type_rules);
+        if (wtype != GGML_TYPE_COUNT || tensor_type_rules.size() > 0) {
+            model_loader.set_wtype_override(wtype, tensor_type_rules);
         }
 
         std::map<ggml_type, uint32_t> wtype_stat = model_loader.get_wtype_stat();
@@ -2325,6 +2326,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
              "lora_model_dir: %s\n"
              "embedding_dir: %s\n"
              "photo_maker_path: %s\n"
+             "tensor_type_rules: %s\n"
              "vae_decode_only: %s\n"
              "free_params_immediately: %s\n"
              "n_threads: %d\n"
@@ -2354,6 +2356,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
              SAFE_STR(sd_ctx_params->lora_model_dir),
              SAFE_STR(sd_ctx_params->embedding_dir),
              SAFE_STR(sd_ctx_params->photo_maker_path),
+             SAFE_STR(sd_ctx_params->tensor_type_rules),
              BOOL_STR(sd_ctx_params->vae_decode_only),
              BOOL_STR(sd_ctx_params->free_params_immediately),
              sd_ctx_params->n_threads,

stable-diffusion.h

Lines changed: 1 addition & 0 deletions
@@ -167,6 +167,7 @@ typedef struct {
     const char* lora_model_dir;
     const char* embedding_dir;
     const char* photo_maker_path;
+    const char* tensor_type_rules;
     bool vae_decode_only;
     bool free_params_immediately;
     int n_threads;
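At the API level the new member is just another string on sd_ctx_params_t, read through SAFE_STR in stable-diffusion.cpp. A rough sketch of a library caller setting it alongside the global weight type; this assumes the sd_ctx_params_init()/new_sd_ctx()/free_sd_ctx() entry points and a model_path field as in the rest of the public header, and the path and type choices are placeholders:

    #include "stable-diffusion.h"

    int main() {
        sd_ctx_params_t params;
        sd_ctx_params_init(&params);                         // assumed default-initializer
        params.model_path        = "model.safetensors";      // placeholder path
        params.wtype             = SD_TYPE_Q4_0;             // global weight-type override
        params.tensor_type_rules = "vae\\.=f16,attn=q8_0";   // per-tensor exceptions

        sd_ctx_t* ctx = new_sd_ctx(&params);
        if (ctx == nullptr) {
            return 1;
        }
        // ... run generation, then clean up ...
        free_sd_ctx(ctx);
        return 0;
    }

Note that if only tensor_type_rules is set and wtype is left as "no override", set_wtype_override is still called, and tensors that match no rule keep their stored type (the dst_type == GGML_TYPE_COUNT branch skips them).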
