
Commit f84d45b

Qualcomm AI Engine Direct - GLM1.5B (#15691)
### Summary

GLM Enablement

`python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s $DEVICE -m SM8750 --temperature 0 --model_mode kv --max_seq_len 128 --decoder_model glm-1_5b --prompt "Could you tell me about Facebook?"`

### Test plan

`python backends/qualcomm/tests/test_qnn_delegate.py -k TestExampleLLMScript.test_static_glm1_5b --model SM8750 --build_folder build-android/ --executorch_root . -s $DEVICE --artifact ./glm1_5b`
1 parent cec1834 commit f84d45b

15 files changed: +268, -9 lines changed


backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 3 additions & 0 deletions
@@ -5862,6 +5862,9 @@ def setUp(self):
     "gemma3-1b": TestExampleLLMScript.LlmSpecs(
         SM8650=70, SM8750=100, ppl=23, pte_size=1_200_000_000
     ),  # 1.2 GB
+    "glm-1_5b": TestExampleLLMScript.LlmSpecs(
+        SM8650=42, SM8750=52, ppl=21, pte_size=1_100_000_000
+    ),  # 1.1 GB
     "phi_4_mini": TestExampleLLMScript.LlmSpecs(
         SM8650=14, SM8750=19, ppl=12, pte_size=4_000_000_000
     ),  # 4GB

examples/models/glm/__init__.py

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from executorch.examples.models.glm.convert_weights import convert_weights
+from executorch.examples.models.llama.model import Llama2Model
+
+
+class GLMModel(Llama2Model):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+
+__all__ = [
+    "GLMModel",
+    "convert_weights",
+]

examples/models/glm/config/1_5b_config.json

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+{
+  "dim": 2048,
+  "ffn_dim_multiplier": 1,
+  "hidden_dim": 6144,
+  "n_heads": 16,
+  "head_dim": 128,
+  "n_kv_heads": 4,
+  "n_layers": 28,
+  "norm_eps": 1e-05,
+  "rope_theta": 10000.0,
+  "use_scaled_rope": false,
+  "vocab_size": 59264,
+  "use_hf_rope": true,
+  "attention_qkv_bias": false,
+  "use_qk_norm": false,
+  "model_architecture" : "GlmForCausalLM"
+}
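These parameters feed the shared llama example's `ModelArgs` (extended below with `model_architecture`). A minimal sketch of how such a params file is typically consumed, assuming the import path shown and that every key in the JSON maps to a `ModelArgs` field:

```python
import json

# Assumed import path; ModelArgs is the dataclass extended later in this commit.
from executorch.examples.models.llama.model_args import ModelArgs

with open("examples/models/glm/config/1_5b_config.json") as f:
    params = json.load(f)

# ModelArgs is a dataclass, so this only works if every key maps to a field.
args = ModelArgs(**params)
print(args.model_architecture)  # "GlmForCausalLM", used to pick the GLM feed-forward
```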

examples/models/glm/convert_weights.py

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
+import argparse
+import os
+from typing import Dict
+
+import torch
+from safetensors.torch import load_file
+from torchtune.models.convert_weights import get_mapped_key
+
+# Standard _FROM_META weight mapping of Meta weights to TorchTune + additional bias weight mappings.
+_GLM_FROM_META = {
+    "tok_embeddings.weight": "model.embed_tokens.weight",
+    "norm.weight": "model.norm.weight",
+    "output.weight": "lm_head.weight",
+    "layers.{}.attention.wk.weight": "model.layers.{}.self_attn.k_proj.weight",
+    "layers.{}.attention.wq.weight": "model.layers.{}.self_attn.q_proj.weight",
+    "layers.{}.attention.wv.weight": "model.layers.{}.self_attn.v_proj.weight",
+    "layers.{}.attention.wo.weight": "model.layers.{}.self_attn.o_proj.weight",
+    "layers.{}.attention_norm.weight": "model.layers.{}.input_layernorm.weight",
+    "layers.{}.ffn_norm.weight": "model.layers.{}.post_attention_layernorm.weight",
+    "layers.{}.feed_forward.gate_up_proj.weight": "model.layers.{}.mlp.gate_up_proj.weight",
+    "layers.{}.feed_forward.down_proj.weight": "model.layers.{}.mlp.down_proj.weight",
+}
+
+
+def glm_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+    """
+    Convert a state dict from torchtune's format to Meta's format. This function
+    doesn't handle any sharding or splitting of state dicts. It follows the
+    state_dict IN -> state_dict OUT pattern.
+
+    Args:
+        state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format.
+
+    Returns:
+        Dict[str, torch.Tensor]: State dict in Meta's format.
+    """
+    converted_state_dict = {}
+    inverted_mapping_dict = {v: k for k, v in _GLM_FROM_META.items()}
+
+    for key, value in state_dict.items():
+        new_key = get_mapped_key(key, inverted_mapping_dict)
+        converted_state_dict[new_key] = value
+
+    if "lm_head.weight" not in state_dict:
+        converted_state_dict["output.weight"] = converted_state_dict[
+            "tok_embeddings.weight"
+        ]
+
+    return converted_state_dict
+
+
+def convert_weights(input_dir: str, output_file: str) -> None:
+    pt_path = os.path.join(input_dir, "model.safetensors")
+    print("Loading checkpoint from file...")
+    sd = load_file(pt_path)
+
+    print("Converting checkpoint...")
+    sd = glm_tune_to_meta(sd)
+
+    print("Saving checkpoint...")
+    torch.save(sd, output_file)
+    print("Done.")
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Convert GLM weights to Meta format.")
+    parser.add_argument(
+        "input_dir",
+        type=str,
+        help="Path to directory containing checkpoint files",
+    )
+    parser.add_argument("output", type=str, help="Path to the output checkpoint")
+
+    args = parser.parse_args()
+    convert_weights(args.input_dir, args.output)
+
+
+if __name__ == "__main__":
+    main()
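A hedged usage sketch for the converter above; the module path assumes the file lands at `examples/models/glm/convert_weights.py` (matching the import added in `examples/qualcomm/oss_scripts/llama/__init__.py` below), and the input directory is expected to hold a single `model.safetensors`. Paths are illustrative placeholders:

```python
from executorch.examples.models.glm.convert_weights import convert_weights

# Maps HF-style keys back to Meta-style keys via the inverted _GLM_FROM_META table
# and ties output.weight to tok_embeddings.weight when lm_head.weight is absent.
convert_weights("glm-edge-1.5b-chat/", "glm_1_5b_meta.pth")
```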

examples/models/llama/model_args.py

Lines changed: 3 additions & 0 deletions
@@ -131,6 +131,9 @@ class ModelArgs:
     attention_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
     # Hybrid models can have layer types different from attention
     layer_types: Optional[list] = None
+    model_architecture: Optional[str] = (
+        None  # Architecture of model. For HF models, please refer to the HF model.config.architectures. This is used in QNN backend only for now.
+    )
 
     def __post_init__(self):
         if self.n_kv_heads is None:

examples/qualcomm/oss_scripts/llama/README.md

Lines changed: 16 additions & 6 deletions
@@ -9,6 +9,7 @@ This file provides you the instructions to run LLM Decoder model with different
 1. Codegen2 1B
 1. Gemma 2B
 1. Gemma3 1B
+1. GLM 1.5B
 1. Granite3.3 2B
 1. Phi4-mini-instruct
 1. QWEN2.5 0.5B / 1.5B
@@ -65,7 +66,10 @@ Follow the [instructions](https://www.llama.com/) to download models.
 At the end of this step, users should have the following files ready: `consolidated.00.pth`, `params.json`, and `tokenizer.model`.
 
 
-### Step3: Run default examples using hybrid mode for smaller models and kv mode for larger models.
+### Step3: Run default examples.
+#### Note:
+All example scripts below use hybrid mode, which is optimized for on-device performance. However, compiling a model in hybrid mode can consume a significant amount of memory on the host machine—sometimes up to ~100 GB. If your host machine has limited memory, it is highly recommended to switch from `--model_mode hybrid` to `--model_mode kv` and remove the `--prefill_ar_len` flag.
+
 #### LLAMA2
 ```bash
 python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --checkpoint stories110M.pt --params params.json --tokenizer_model tokenizer.model --tokenizer_bin tokenizer.bin --decoder_model stories110m --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "Once upon a time"
@@ -80,7 +84,7 @@ python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL
 #### LLAMA3.2 3B Instruct
 Default example using kv mode.
 ```bash
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2-3b_instruct --model_mode kv --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2-3b_instruct --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
 ```
 
 #### Codegen2
@@ -102,6 +106,12 @@ Default example using hybrid mode
 python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --decoder_model gemma3-1b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
 ```
 
+#### GLM 1.5B
+Default example using hybrid mode
+```bash
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --decoder_model glm-1_5b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
+```
+
 #### Granite3.3 2B
 Default example using hybrid mode
 ```bash
@@ -111,7 +121,7 @@ python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL
 #### Phi4-mini-instruct
 Default example using kv mode.
 ```bash
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model phi_4_mini --model_mode kv --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model phi_4_mini --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
 ```
 
 #### QWEN2.5 0.5B
@@ -123,7 +133,7 @@ python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL
 #### QWEN2.5 1.5B
 Default example using kv mode
 ```bash
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode kv --max_seq_len 1024 --decoder_model qwen2_5-1_5b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --decoder_model qwen2_5-1_5b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
 ```
 
 #### QWEN3 0.6B
@@ -135,7 +145,7 @@ python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL
 #### QWEN3 1.7B
 Default example using hybrid mode
 ```bash
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode kv --max_seq_len 1024 --decoder_model qwen3-1_7b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --decoder_model qwen3-1_7b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
 ```
 
 #### SmolLM2
@@ -147,7 +157,7 @@ python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL
 #### SmolLM3
 Default example using kv mode.
 ```bash
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model smollm3-3b --model_mode kv --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model smollm3-3b --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
 ```
 
 ### KV Cache update mechanism

examples/qualcomm/oss_scripts/llama/__init__.py

Lines changed: 23 additions & 0 deletions
@@ -16,6 +16,8 @@
 )
 from executorch.examples.models.gemma import convert_weights as convert_gemma_weights
 from executorch.examples.models.gemma3 import convert_weights as convert_gemma3_weights
+
+from executorch.examples.models.glm import convert_weights as convert_glm_weights
 from executorch.examples.models.granite import (
     convert_weights as convert_granite_weights,
 )
@@ -44,6 +46,7 @@
     CodegenQuantRecipe,
     Gemma3QuantRecipe,
     Gemma_2BQuantRecipe,
+    GLM_1_5B_InstructQuantRecipe,
     Granite_3_3_2B_InstructQuantRecipe,
     Llama3_1BQuantRecipe,
     Llama3_3BQuantRecipe,
@@ -293,6 +296,26 @@ class Gemma3(LLMModelConfig):
     quant_recipe = Gemma3QuantRecipe
 
 
+@register_llm_model("glm-1_5b")
+@dataclass(init=False, frozen=True)
+class GLM_1_5B(LLMModelConfig):
+    repo_id: str = "THUDM/glm-edge-1.5b-chat"
+    params_path: str = os.path.join(
+        BASE_DIR, "../../../models/glm/config/1_5b_config.json"
+    )
+    convert_weights = convert_glm_weights
+    transform_weight = True
+    instruct_model = True
+    num_sharding = 1
+    group_size = 32
+    masked_softmax = False
+    seq_mse_candidates = 0
+    r1 = False
+    r2 = False
+    r3 = False
+    quant_recipe = GLM_1_5B_InstructQuantRecipe
+
+
 @register_llm_model("granite_3_3-2b_instruct")
 @dataclass(init=False, frozen=True)
 class Granite_3_3_2b_Instruct(LLMModelConfig):
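For a rough sense of what the new registration provides, a hypothetical direct use of the config (the import path and direct attribute access are assumptions for this sketch; `llama.py` normally resolves the config from `--decoder_model glm-1_5b` via the registry):

```python
# Hypothetical sketch; attribute names come from the GLM_1_5B dataclass above.
from executorch.examples.qualcomm.oss_scripts.llama import GLM_1_5B  # assumed import path

print(GLM_1_5B.repo_id)      # "THUDM/glm-edge-1.5b-chat"
print(GLM_1_5B.params_path)  # resolves to examples/models/glm/config/1_5b_config.json
# convert_weights is the GLM converter added earlier in this commit; paths are placeholders.
GLM_1_5B.convert_weights("glm-edge-1.5b-chat/", "glm_1_5b_meta.pth")
```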

examples/qualcomm/oss_scripts/llama/decoder_constants.py

Lines changed: 1 addition & 0 deletions
@@ -27,4 +27,5 @@
     "smollm2_135m": "smollm2_135m",
     "smollm3-3b": "smollm3",
     "codegen2_1b": "codegen",
+    "glm-1_5b": "glm",
 }

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 10 additions & 2 deletions
@@ -1309,15 +1309,23 @@ def export_llama(args) -> None:
         # For Gemma, use tokenizer.model as it doesn't provide pre_tokenizer in tokenizer.json.
         runtime_tokenizer_path = tokenizer_artifacts[-3]
     else:
+        if args.decoder_model == "glm-1_5b":
+            with open(tokenizer_config, "r+") as file:
+                data = json.load(file)
+                # Verified with HF flow and it uses <|user|> as eos condition
+                data["bos_token"] = "<|user|>"
+                data["eos_token"] = "<|user|>"
+                file.seek(0)
+                json.dump(data, file, indent=4)
+                file.truncate()
         runtime_tokenizer_path = tokenizer_artifacts[-1]
+
     tokenizer = get_tokenizer(runtime_tokenizer_path, tokenizer_config)
 
     if args.decoder_model == "codegen2_1b":
         # Override the default BOS and EOS token IDs for codegen2_1b
         tokenizer.bos_id = 1
         tokenizer.eos_id = 2
-
-    # TODO: Remove this once error is resolved.
     elif args.decoder_model == "phi_4_mini":
         with open(runtime_tokenizer_path, "r+") as file:
             data = json.load(file)

examples/qualcomm/oss_scripts/llama/model/feed_forward.py

Lines changed: 51 additions & 0 deletions
@@ -88,3 +88,54 @@ def forward(self, x):
         hidden_states = self.act(hidden_states)
         hidden_states = self.fc_out(hidden_states)
         return hidden_states
+
+
+@register_feed_forward("GlmForCausalLM")
+class GLMFeedForward(FeedForwardBase):
+    """FeedForward with gate_up_proj and down_proj"""
+
+    def __init__(self, args: ModelArgs):  # in MLP: intermediate_size= 4 * embed_dim
+        super().__init__()
+
+        assert args.hidden_dim is not None
+        self.dim = args.dim
+        self.hidden_dim = args.hidden_dim
+
+        self.gate_up_proj = torch.nn.Linear(args.dim, 2 * args.hidden_dim, bias=False)
+        self.down_proj = torch.nn.Linear(args.hidden_dim, args.dim, bias=False)
+        self.activation_fn = args.act_fn.get_function()
+
+    def prepare_feedfoward_conv(self):
+        self.gate_up_proj_conv = torch.nn.Conv2d(
+            self.dim, 2 * self.hidden_dim, 1, bias=False
+        )
+        self.down_proj_conv = torch.nn.Conv2d(self.hidden_dim, self.dim, 1, bias=False)
+
+        self.forward_no_conv = self.forward
+        self.forward = self.forward_feedfoward_conv
+
+        self.gate_up_proj_conv.weight.data.copy_(
+            self.gate_up_proj.weight[:, :, None, None]
+        )
+        self.down_proj_conv.weight.data.copy_(self.down_proj.weight[:, :, None, None])
+
+        del self.gate_up_proj
+        del self.down_proj
+
+    def forward_feedfoward_conv(self, x):
+        bsz, _, _ = x.size()
+        x = torch.reshape(x, (bsz, -1, 1, self.dim))
+        x = x.transpose(1, 3)  # Transpose right before and after Conv
+        up_states = self.gate_up_proj_conv(x)
+        gate, up_states = up_states.chunk(2, dim=1)
+        up_states = up_states * self.activation_fn(gate)
+        x = self.down_proj_conv(up_states)
+        x = x.transpose(1, 3)
+        x = torch.reshape(x, (bsz, -1, self.dim))
+        return x
+
+    def forward(self, x):
+        up_states = self.gate_up_proj(x)
+        gate, up_states = up_states.chunk(2, dim=-1)
+        up_states = up_states * self.activation_fn(gate)
+        return self.down_proj(up_states)
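`prepare_feedfoward_conv` swaps the two `Linear` projections for equivalent 1x1 `Conv2d` layers by copying the weights with two trailing singleton dims. A self-contained sanity sketch of that equivalence in plain torch (SiLU is assumed for the activation; the ExecuTorch classes above are not imported):

```python
import torch
import torch.nn.functional as F

dim, hidden_dim, bsz, seqlen = 8, 16, 2, 4

# Linear projections as in GLMFeedForward.__init__ (bias-free, fused gate+up).
gate_up = torch.nn.Linear(dim, 2 * hidden_dim, bias=False)
down = torch.nn.Linear(hidden_dim, dim, bias=False)

# 1x1 convs with the Linear weights copied in, as prepare_feedfoward_conv does.
gate_up_conv = torch.nn.Conv2d(dim, 2 * hidden_dim, 1, bias=False)
down_conv = torch.nn.Conv2d(hidden_dim, dim, 1, bias=False)
gate_up_conv.weight.data.copy_(gate_up.weight[:, :, None, None])
down_conv.weight.data.copy_(down.weight[:, :, None, None])

x = torch.randn(bsz, seqlen, dim)

# Linear path: chunk into gate / up halves, SiLU-gate, project back down.
g, u = gate_up(x).chunk(2, dim=-1)
ref = down(u * F.silu(g))

# Conv path: move the feature dim into the channel axis before the 1x1 convs.
xc = x.reshape(bsz, -1, 1, dim).transpose(1, 3)
g2, u2 = gate_up_conv(xc).chunk(2, dim=1)
out = down_conv(u2 * F.silu(g2)).transpose(1, 3).reshape(bsz, -1, dim)

print(torch.allclose(ref, out, atol=1e-5))  # expected: True
```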
