 import logging
 import os
 import re
+import shutil
 from pathlib import Path

 import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManagerAdaptor
     to_edge_transform_and_lower_to_qnn,
 )
 from executorch.examples.qualcomm.qaihub_scripts.utils.utils import preprocess_binary
-from executorch.examples.qualcomm.utils import (
-    make_output_dir,
-    make_quantizer,
-    SimpleADB,
-)
+from executorch.examples.qualcomm.utils import make_quantizer, SimpleADB
 from executorch.exir import ExecutorchBackendConfig
 from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
 from torchao.quantization import pt2e
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

+INPUT_ORDER = "input_order"
+

 def get_logger():
     logger = logging.getLogger("examples.qualcomm.util_scripts.cli")
@@ -74,6 +73,7 @@ def fill_tensor_info(info, qnn_tensors, category):
7473 "offset" : encoding .data ["offset" ].tolist (),
7574 "axis" : encoding .axis ,
7675 }
76+
7777 info [category ].append (
7878 {
7979 "name" : tensor .GetName (),
@@ -106,6 +106,26 @@ def fill_tensor_info(info, qnn_tensors, category):
     return tensor_info


+class InputListParser:
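+    """Iterates over an input_list file where each non-empty line holds either
+    space-separated tensor paths or `name:=path` assignments (loaded via torch.load)."""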
+    def __init__(self, input_list):
+        self.input_list = input_list
+
+    def __iter__(self):
+        with open(self.input_list, "r") as f:
+            for line in re.split(r"\r?\n", f.read()):
+                if not line:
+                    continue
+                split_line = line.strip().split(" ")
+                inputs = {}
+                if ":=" in line:
+                    for input_assignment in split_line:
+                        name, path = input_assignment.split(":=")
+                        inputs[name] = torch.load(path, weights_only=True)
+                else:
+                    inputs = [torch.load(t, weights_only=True) for t in split_line]
+                yield inputs
+
+
 def quantize(args):
     logger = get_logger()

@@ -131,15 +151,21 @@ def quantize(args):
     ep_prepared = prepare_pt2e(ep.module(), quantizer)
     logger.info(f"perform calibration on {args.artifact}")
     # step 2: perform calibration
-    with open(args.input_list, "r") as f:
-        for line in f.read().split("\n")[:-1]:
-            inputs = [torch.load(t, weights_only=True) for t in line.split(" ")]
-            ep_prepared(*inputs)
+    input_list_parser = InputListParser(args.input_list)
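+    # named (`name:=path`) assignments must be reordered to match the graph's
+    # positional user-input order before calibration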
+    graph_input_names = [
+        spec.arg.name
+        for spec in ep.graph_signature.input_specs
+        if spec.kind.name == "USER_INPUT"
+    ]
+    for inputs in input_list_parser:
+        if isinstance(inputs, dict):
+            inputs = [inputs[name] for name in graph_input_names]
+        ep_prepared(*inputs)
     # step 3: use convert_pt2e to fix encodings of QDQ pairs
     logger.info(f"saving calibrated model for {args.artifact}")
     ep_converted = convert_pt2e(ep_prepared)
     ep_quantized = torch.export.export(ep_converted, tuple(inputs))
-    make_output_dir(args.output_folder)
+    os.makedirs(args.output_folder, exist_ok=True)
     torch.export.save(
         ep_quantized, f"{args.output_folder}/{Path(args.artifact).stem}_quantized.pt2"
     )
@@ -155,7 +181,7 @@ def compile(args):
     )

     file_name, extension = Path(args.artifact).stem, Path(args.artifact).suffix
-    make_output_dir(args.output_folder)
+    os.makedirs(args.output_folder, exist_ok=True)
     # setup compiler spec dedicated to QNN HTP backend
     backend_options = generate_htp_compiler_spec(use_fp16=True)
     # setup general compiler spec for QNN
@@ -201,12 +227,13 @@ def compile(args):

     for user_pass in user_passes:
         passes[user_pass][QCOM_PASS_ACTIVATE_KEY] = True
-
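+    # record the graph's user input order as a constant method so that
+    # execute() can recover it from the .pte and reorder named inputs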
+    input_order = {INPUT_ORDER: ep.graph_signature.user_inputs}
     edge_prog_mgr = to_edge_transform_and_lower_to_qnn(
         module=ep.module(),
         inputs=sample_inputs,
         compiler_specs=compiler_specs,
         passes_job=passes,
+        constant_methods=input_order,
     )
     # step 2: write pte files and store final graph
     logger.info(f"exporting {file_name}.pte")
@@ -227,15 +254,30 @@ def execute(args):

     pte_name = Path(args.artifact).stem

+    # get input order
+    from executorch.runtime import Runtime, Verification
+
+    et_runtime = Runtime.get()
+    program = et_runtime.load_program(
+        args.artifact,
+        verification=Verification.Minimal,
+    )
+    input_order_func = program.load_method(INPUT_ORDER)
+    input_order = input_order_func.execute([])
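+    # the constant method returns the graph's user input names in positional order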
+
     # load input files
     logger.info("loading user inputs")
+    input_list_parser = InputListParser(args.input_list)
     user_inputs = []
-    with open(args.input_list, "r") as f:
-        for line in f.read().split("\n")[:-1]:
-            inputs, input_names = [], ""
-            for data in line.split(" "):
-                input_names += f"{Path(data).stem}.raw "
-                inputs.append(torch.load(data, weights_only=True))
+    for inputs in input_list_parser:
+        if isinstance(inputs, dict):
+            ordered_inputs = []
+            # input_order (recorded at compile time) gives the positional order
+            # of the graph inputs; use it to reorder named input assignments
+            for name in input_order:
+                ordered_inputs.append(inputs[name])
+            user_inputs.append(ordered_inputs)
+        else:
             user_inputs.append(inputs)

     logger.info("retrieving graph I/O")
@@ -247,7 +289,6 @@ def execute(args):
         backend_options=backend_options,
     )
     io_info = get_io_info(args.artifact, compiler_specs)
-
     logger.info("preparing ADB connection")
     # leverage SimpleADB for e2e inference
     adb = SimpleADB(
@@ -263,11 +304,16 @@ def execute(args):
     )

     logger.info("pushing QNN libraries & other artifacts")
+
     adb.push(inputs=user_inputs)

     logger.info("starting inference")
     adb.execute()

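+    # stage device outputs in a temporary folder; post_process() reorganizes them
+    # into per-sample Result_<data_index> folders under output_folder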
+    tmp_dir = f"{args.output_folder}/tmp_outputs"
+    os.makedirs(tmp_dir, exist_ok=True)
+    os.makedirs(args.output_folder, exist_ok=True)
+
     def post_process():
         torch_to_numpy_dtype_dict = {
             torch.bool: np.dtype("bool"),
@@ -283,11 +329,14 @@ def post_process():
             torch.complex128: np.dtype("complex128"),
         }
         output_info = io_info["outputs"]
-        output_folder = f"{args.output_folder}/outputs"
-        for _, f in enumerate(os.listdir(output_folder)):
-            filename = os.path.join(output_folder, f)
-            match_res = re.match(r".*([0-9]+)_([0-9]+)\.raw$", filename)
+        tmp_output_folder = f"{tmp_dir}/outputs"
+        for _, f in enumerate(os.listdir(tmp_output_folder)):
+            filename = os.path.join(tmp_output_folder, f)
+            match_res = re.match(r".*output_([0-9]+)_([0-9]+)\.raw$", filename)
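+            # pulled raw files are expected to be named output_<data_index>_<output_index>.raw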
             data_index, output_index = int(match_res.group(1)), int(match_res.group(2))
+
+            output_result_folder = f"{args.output_folder}/Result_{data_index}"
+            os.makedirs(output_result_folder, exist_ok=True)
             output = np.fromfile(
                 filename,
                 dtype=eval(
@@ -297,13 +346,11 @@ def post_process():
             output = torch.from_numpy(
                 output.reshape(output_info[output_index]["shape"])
             )
-            torch.save(
-                output, f"{args.output_folder}/output_{data_index}_{output_index}.pt"
-            )
+            torch.save(output, f"{output_result_folder}/output_{output_index}.pt")

     logger.info("collecting output data")
-    make_output_dir(args.output_folder)
-    adb.pull(args.output_folder, post_process)
+    adb.pull(tmp_dir, post_process)
+    shutil.rmtree(tmp_dir)
     logger.info(f"execution finished, please check {args.output_folder} for results")
