Skip to content

Commit 76de8f0

Browse files
committed
Postprocessing to share lm_head weights to embedding
1 parent 36cd2ca commit 76de8f0

File tree

1 file changed

+313
-0
lines changed

1 file changed

+313
-0
lines changed
Lines changed: 313 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,313 @@
1+
# -------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License. See License.txt in the project root for
4+
# license information.
5+
# --------------------------------------------------------------------------
6+
7+
import onnx
8+
import numpy as np
9+
from onnx import helper, numpy_helper, TensorProto
10+
from onnx.external_data_helper import load_external_data_for_model
11+
import argparse
12+
import os
13+
14+
def convert_gather_to_use_lm_head_weights_helper(graph, quant_weight_name, scales_name, zero_points_name, use_zero_points, hidden_size, scale_value_type):
15+
"""
16+
Replace the embed_tokens/Gather with operations that reuse the quantized lm_head weights
17+
"""
18+
# Find the Gather node for embeddings
19+
gather_node = None
20+
for node in graph.node:
21+
if node.name == "/model/embed_tokens/Gather":
22+
gather_node = node
23+
break
24+
25+
if gather_node is None:
26+
print("Warning: /model/embed_tokens/Gather not found, skipping weight tying optimization")
27+
return
28+
29+
# Save the original inputs and outputs of the Gather node
30+
embedding_weights_name = gather_node.input[0]
31+
input_ids = gather_node.input[1] # This is typically the input_ids tensor
32+
original_output = gather_node.output[0]
33+
34+
# Create new nodes to replace the Gather operation
35+
36+
# 1. Gather the quantized weights
37+
gathered_quant_weights = "gathered_quant_weights"
38+
gather_weights_node = helper.make_node(
39+
'Gather',
40+
inputs=[quant_weight_name, input_ids],
41+
outputs=[gathered_quant_weights],
42+
name='/model/embed_tokens/GatherQuantizedWeights',
43+
axis=0
44+
)
45+
46+
# 2. Gather the scales
47+
gathered_scales_raw = "gathered_scales_raw"
48+
gather_scales_node = helper.make_node(
49+
'Gather',
50+
inputs=[scales_name, input_ids],
51+
outputs=[gathered_scales_raw],
52+
name='/model/embed_tokens/GatherScales',
53+
axis=0
54+
)
55+
56+
# Reshape the scales to add an extra dimension for broadcasting
57+
unsqueeze_scales_node = helper.make_node(
58+
'Unsqueeze',
59+
inputs=[gathered_scales_raw, "scales_axes"],
60+
outputs=["gathered_scales"],
61+
name='/model/embed_tokens/UnsqueezeScales'
62+
)
63+
64+
# Create axes tensor for unsqueeze operation (adding dimension at axis 2)
65+
scales_axes = np.array([3], dtype=np.int64)
66+
scales_axes_name = "scales_axes"
67+
scales_axes_initializer = numpy_helper.from_array(scales_axes, scales_axes_name)
68+
graph.initializer.extend([scales_axes_initializer])
69+
70+
# Cast the quantized weights to floating point
71+
cast_weights_node = helper.make_node(
72+
'Cast',
73+
inputs=[gathered_quant_weights],
74+
outputs=["casted_quant_weights"],
75+
name='/model/embed_tokens/CastWeights',
76+
to=scale_value_type
77+
)
78+
79+
# Create a constant for the zero point (128 for symmetric quantization)
80+
zero_point_const = np.array([128], dtype=np.uint8)
81+
zero_point_const_name = "zero_offset_const"
82+
zero_point_initializer = numpy_helper.from_array(zero_point_const, zero_point_const_name)
83+
graph.initializer.extend([zero_point_initializer])
84+
85+
# Cast the zero point to the same type as weights
86+
cast_zp_node = helper.make_node(
87+
'Cast',
88+
inputs=[zero_point_const_name],
89+
outputs=["casted_zero_point"],
90+
name='/model/embed_tokens/CastZeroPoint',
91+
to=scale_value_type
92+
)
93+
94+
# Subtract zero point from casted weights
95+
sub_node = helper.make_node(
96+
'Sub',
97+
inputs=["casted_quant_weights", "casted_zero_point"],
98+
outputs=["centered_weights"],
99+
name='/model/embed_tokens/SubtractZeroPoint'
100+
)
101+
102+
# Multiply by scale
103+
dequantized_output = "dequantized_embeddings"
104+
mul_node = helper.make_node(
105+
'Mul',
106+
inputs=["centered_weights", "gathered_scales"],
107+
outputs=[dequantized_output],
108+
name='/model/embed_tokens/MultiplyByScale'
109+
)
110+
111+
# 4. Reshape to the final embedding shape
112+
# Get token shape
113+
shape_node = helper.make_node(
114+
'Shape',
115+
inputs=[input_ids],
116+
outputs=["token_shape"],
117+
name='/model/embed_tokens/GetTokenShape'
118+
)
119+
120+
# Add constant for hidden dimension
121+
const_hidden_size = np.array([hidden_size], dtype=np.int64)
122+
const_hidden_size_name = "const_hidden_size"
123+
hidden_size_initializer = numpy_helper.from_array(const_hidden_size, const_hidden_size_name)
124+
graph.initializer.extend([hidden_size_initializer])
125+
126+
# Concat to get final shape
127+
concat_final_shape = helper.make_node(
128+
'Concat',
129+
inputs=["token_shape", const_hidden_size_name],
130+
outputs=["final_shape"],
131+
name='/model/embed_tokens/ConcatFinalShape',
132+
axis=0
133+
)
134+
135+
# Final reshape to get the right output shape
136+
final_reshape_node = helper.make_node(
137+
'Reshape',
138+
inputs=[dequantized_output, "final_shape"],
139+
outputs=[original_output],
140+
name='/model/embed_tokens/FinalReshape'
141+
)
142+
143+
# Find and remove the original Gather node
144+
for i, node in enumerate(graph.node):
145+
if node.name == gather_node.name:
146+
del graph.node[i]
147+
break
148+
149+
# Remove the original embedding weights from initializers
150+
for i, initializer in enumerate(graph.initializer):
151+
if initializer.name == embedding_weights_name:
152+
print(f"Removing original embedding weights: {embedding_weights_name}")
153+
del graph.initializer[i]
154+
break
155+
156+
# Add all new nodes to the graph
157+
new_nodes = [
158+
gather_weights_node,
159+
gather_scales_node,
160+
unsqueeze_scales_node,
161+
cast_weights_node,
162+
cast_zp_node,
163+
sub_node,
164+
mul_node,
165+
shape_node,
166+
concat_final_shape,
167+
final_reshape_node
168+
]
169+
170+
# Modify this part to handle asymmetric quantization if needed
171+
if use_zero_points:
172+
# Gather the zero points
173+
gathered_zero_points = "gathered_zero_points"
174+
gather_zero_points_node = helper.make_node(
175+
'Gather',
176+
inputs=[zero_points_name, input_ids],
177+
outputs=[gathered_zero_points],
178+
name='/model/embed_tokens/GatherZeroPoints',
179+
axis=0
180+
)
181+
182+
# Unsqueeze zero points for broadcasting
183+
unsqueeze_zp_node = helper.make_node(
184+
'Unsqueeze',
185+
inputs=[gathered_zero_points, "scales_axes"],
186+
outputs=["unsqueezed_zero_points"],
187+
name='/model/embed_tokens/UnsqueezeZeroPoints'
188+
)
189+
190+
# Cast zero points to float
191+
cast_gathered_zp_node = helper.make_node(
192+
'Cast',
193+
inputs=["unsqueezed_zero_points"],
194+
outputs=["casted_gathered_zero_point"],
195+
name='/model/embed_tokens/CastGatheredZeroPoint',
196+
to=scale_value_type
197+
)
198+
199+
# Replace the standard zero_point subtraction with the gathered one
200+
sub_node.input[1] = "casted_gathered_zero_point"
201+
202+
# Insert the new nodes
203+
new_nodes.insert(2, gather_zero_points_node)
204+
new_nodes.insert(3, unsqueeze_zp_node)
205+
new_nodes.insert(6, cast_gathered_zp_node)
206+
207+
graph.node.extend(new_nodes)
208+
209+
print("Successfully tied embedding weights to quantized LM head weights using Cast+Mul operations")
210+
211+
212+
def get_node_attribute(node: onnx.NodeProto, attribute_name: str):
213+
for attr in node.attribute:
214+
if attr.name == attribute_name:
215+
value = onnx.helper.get_attribute_value(attr)
216+
return value
217+
return None
218+
219+
220+
def find_graph_input(graph, input_name):
221+
for input in graph.input:
222+
if input.name == input_name:
223+
return input
224+
return None
225+
226+
227+
def find_graph_output(graph, output_name):
228+
for output in graph.output:
229+
if output.name == output_name:
230+
return output
231+
return None
232+
233+
234+
def get_tensor_type_from_graph(graph, tensor_name: str):
235+
tensor_type_map = {obj.name: obj.type for obj in graph.value_info}
236+
237+
if tensor_name in tensor_type_map:
238+
return tensor_type_map[tensor_name].tensor_type
239+
240+
g_input = find_graph_input(graph, tensor_name)
241+
if g_input:
242+
return g_input.type.tensor_type
243+
244+
g_output = find_graph_output(graph, tensor_name)
245+
if g_output:
246+
return g_output.type.tensor_type
247+
248+
return None
249+
250+
251+
def convert_gather_to_use_lm_head_weights(model_path, output_path, load_external_data=True):
252+
# Load the ONNX model
253+
print(f"Loading model from {model_path}...")
254+
model_name = "model.onnx"
255+
model = onnx.load(model_path + model_name, load_external_data=False)
256+
if load_external_data:
257+
load_external_data_for_model(model, model_path)
258+
graph = model.graph
259+
260+
# Find the MatMul node
261+
matmul_node = None
262+
for node in graph.node:
263+
if node.name.startswith("/lm_head/MatMul"):
264+
if node.op_type == "MatMulNBits":
265+
matmul_node = node
266+
break
267+
else:
268+
raise ValueError("/lm_head/MatMul node type is not MatMulNBits")
269+
270+
if matmul_node is None:
271+
raise ValueError("/lm_head/MatMul node not found in the model")
272+
273+
# Inputs A and scale has the same type, but scale is in external data. So we can only get the type from A here.
274+
scale_value_type = get_tensor_type_from_graph(graph, matmul_node.input[0])
275+
if scale_value_type:
276+
scale_value_type = scale_value_type.elem_type
277+
else:
278+
raise ValueError("/lm_head/MatMul scale value type is None")
279+
280+
hidden_size = get_node_attribute(matmul_node, "K")
281+
282+
use_zero_points = len(matmul_node.input) > 3
283+
284+
# If embedding weight tying is enabled, replace the embedding Gather
285+
convert_gather_to_use_lm_head_weights_helper(
286+
graph,
287+
matmul_node.input[1], # B (quantized weights)
288+
matmul_node.input[2], # scales
289+
matmul_node.input[3] if use_zero_points else None, # zero_points
290+
use_zero_points,
291+
hidden_size,
292+
scale_value_type
293+
)
294+
295+
# Save the modified model
296+
print(f"Saving model to {output_path}...")
297+
data_file = os.path.basename(output_path) + model_name + ".data"
298+
onnx.save(model, output_path + model_name, save_as_external_data=True, location=data_file)
299+
300+
print(f"Saved to {output_path} with external data in {data_file}")
301+
302+
if __name__ == "__main__":
303+
parser = argparse.ArgumentParser(description="Tie MatMulNBits with Gather for LM head weights")
304+
parser.add_argument("--input_path", type=str, help="Path to the input ONNX model")
305+
parser.add_argument("--output_path", type=str, help="Path to save the modified ONNX model")
306+
parser.add_argument("--load_external_data", required=False, type=bool, default=True, help="Whether to load external data")
307+
args = parser.parse_args()
308+
309+
convert_gather_to_use_lm_head_weights(
310+
args.input_path,
311+
args.output_path,
312+
args.load_external_data
313+
)

0 commit comments

Comments
 (0)