 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# This tool supports the QC internal QA pipeline by quantizing, compiling,
+# and executing models under various configuration flags.
+
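+# Illustrative invocations (the subcommand and flag spellings here are
+# assumptions for readability, not verified against the argparse setup below):
+#   python cli.py quantize --artifact model.pt2 --input_list input_list.txt --output_folder out
+#   python cli.py compile  --artifact out/model_quantized.pt2 --output_folder out
+#   python cli.py execute  --artifact out/model.pte --input_list input_list.txt --output_folder out
+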
 import argparse
 import importlib
 import logging
 import os
 import re
+import shutil
 from pathlib import Path
 
 import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManagerAdaptor
     to_edge_transform_and_lower_to_qnn,
 )
 from executorch.examples.qualcomm.qaihub_scripts.utils.utils import preprocess_binary
-from executorch.examples.qualcomm.utils import (
-    make_output_dir,
-    make_quantizer,
-    SimpleADB,
-)
+from executorch.examples.qualcomm.utils import make_quantizer, SimpleADB
 from executorch.exir import ExecutorchBackendConfig
 from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
 from torchao.quantization import pt2e
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
 
+INPUT_ORDER = "input_order"
+
 
 def get_logger():
     logger = logging.getLogger("examples.qualcomm.util_scripts.cli")
@@ -74,6 +76,7 @@ def fill_tensor_info(info, qnn_tensors, category):
             "offset": encoding.data["offset"].tolist(),
             "axis": encoding.axis,
         }
+
         info[category].append(
             {
                 "name": tensor.GetName(),
@@ -106,6 +109,26 @@ def fill_tensor_info(info, qnn_tensors, category):
     return tensor_info
 
 
+class InputListParser:
+    def __init__(self, input_list):
+        self.input_list = input_list
+
+    def __iter__(self):
+        with open(self.input_list, "r") as f:
+            for line in re.split(r"\r?\n", f.read()):
+                if not line:
+                    continue
+                split_line = line.strip().split(" ")
+                inputs = {}
+                if ":=" in line:
+                    for input_assignment in split_line:
+                        name, path = input_assignment.split(":=")
+                        inputs[name] = torch.load(path, weights_only=True)
+                else:
+                    inputs = [torch.load(t, weights_only=True) for t in split_line]
+                yield inputs
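+# Example input_list contents accepted by InputListParser (file names are
+# illustrative): each non-empty line describes one inference, given either as
+# space-separated tensor paths or as name:=path assignments, e.g.
+#   input_0_0.pt input_0_1.pt
+#   x:=input_0_0.pt y:=input_0_1.pt
+# The named form is matched to graph input names downstream instead of relying
+# on positional order.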
+
+
 def quantize(args):
     logger = get_logger()
 
@@ -131,15 +154,21 @@ def quantize(args):
     ep_prepared = prepare_pt2e(ep.module(), quantizer)
     logger.info(f"perform calibration on {args.artifact}")
     # step 2: perform calibration
-    with open(args.input_list, "r") as f:
-        for line in f.read().split("\n")[:-1]:
-            inputs = [torch.load(t, weights_only=True) for t in line.split(" ")]
-            ep_prepared(*inputs)
+    input_list_parser = InputListParser(args.input_list)
+    graph_input_names = [
+        spec.arg.name
+        for spec in ep.graph_signature.input_specs
+        if spec.kind.name == "USER_INPUT"
+    ]
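+    # named entries from the input_list are reordered below to match the exported
+    # graph's USER_INPUT order before being fed to the prepared module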
+    for inputs in input_list_parser:
+        if isinstance(inputs, dict):
+            inputs = [inputs[name] for name in graph_input_names]
+        ep_prepared(*inputs)
     # step 3: use convert_pt2e to fix encodings of QDQ pairs
     logger.info(f"saving calibrated model for {args.artifact}")
     ep_converted = convert_pt2e(ep_prepared)
     ep_quantized = torch.export.export(ep_converted, tuple(inputs))
-    make_output_dir(args.output_folder)
+    os.makedirs(args.output_folder, exist_ok=True)
     torch.export.save(
         ep_quantized, f"{args.output_folder}/{Path(args.artifact).stem}_quantized.pt2"
     )
@@ -155,7 +184,7 @@ def compile(args):
     )
 
     file_name, extension = Path(args.artifact).stem, Path(args.artifact).suffix
-    make_output_dir(args.output_folder)
+    os.makedirs(args.output_folder, exist_ok=True)
     # setup compiler spec dedicated to QNN HTP backend
     backend_options = generate_htp_compiler_spec(use_fp16=True)
     # setup general compiler spec for QNN
@@ -201,12 +230,13 @@ def compile(args):
 
     for user_pass in user_passes:
         passes[user_pass][QCOM_PASS_ACTIVATE_KEY] = True
-
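+    # record the graph's user input order in the .pte as a constant method
+    # (named by INPUT_ORDER) so execute() can later query it and reorder
+    # named inputs accordingly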
+    input_order = {INPUT_ORDER: ep.graph_signature.user_inputs}
     edge_prog_mgr = to_edge_transform_and_lower_to_qnn(
         module=ep.module(),
         inputs=sample_inputs,
         compiler_specs=compiler_specs,
         passes_job=passes,
+        constant_methods=input_order,
     )
     # step 2: write pte files and store final graph
     logger.info(f"exporting {file_name}.pte")
@@ -227,15 +257,30 @@ def execute(args):
 
     pte_name = Path(args.artifact).stem
 
+    # get input order
+    from executorch.runtime import Runtime, Verification
+
+    et_runtime = Runtime.get()
+    program = et_runtime.load_program(
+        args.artifact,
+        verification=Verification.Minimal,
+    )
+    input_order_func = program.load_method(INPUT_ORDER)
+    input_order = input_order_func.execute([])
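+    # input_order now holds the graph's user input names in their positional
+    # order, as recorded by compile() via constant_methods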
+
     # load input files
     logger.info("loading user inputs")
+    input_list_parser = InputListParser(args.input_list)
     user_inputs = []
-    with open(args.input_list, "r") as f:
-        for line in f.read().split("\n")[:-1]:
-            inputs, input_names = [], ""
-            for data in line.split(" "):
-                input_names += f"{Path(data).stem}.raw "
-                inputs.append(torch.load(data, weights_only=True))
+    for inputs in input_list_parser:
+        if isinstance(inputs, dict):
+            ordered_inputs = []
+            # input_order lists the graph inputs in positional order (recorded
+            # at compile time); use it to reorder the named input assignments
+            for name in input_order:
+                ordered_inputs.append(inputs[name])
+            user_inputs.append(ordered_inputs)
+        else:
             user_inputs.append(inputs)
 
     logger.info("retrieving graph I/O")
@@ -247,7 +292,6 @@ def execute(args):
         backend_options=backend_options,
     )
     io_info = get_io_info(args.artifact, compiler_specs)
-
     logger.info("preparing ADB connection")
     # leverage SimpleADB for e2e inference
     adb = SimpleADB(
@@ -263,11 +307,16 @@ def execute(args):
     )
 
     logger.info("pushing QNN libraries & other artifacts")
+
     adb.push(inputs=user_inputs)
 
     logger.info("starting inference")
     adb.execute()
 
+    tmp_dir = f"{args.output_folder}/tmp_outputs"
+    os.makedirs(tmp_dir, exist_ok=True)
+    os.makedirs(args.output_folder, exist_ok=True)
+
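+    # post_process regroups the raw outputs pulled from the device (one
+    # output_<data>_<output>.raw file per result) into per-sample folders
+    # <output_folder>/Result_<data_index>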
     def post_process():
         torch_to_numpy_dtype_dict = {
             torch.bool: np.dtype("bool"),
@@ -283,11 +332,14 @@ def post_process():
             torch.complex128: np.dtype("complex128"),
         }
         output_info = io_info["outputs"]
-        output_folder = f"{args.output_folder}/outputs"
-        for _, f in enumerate(os.listdir(output_folder)):
-            filename = os.path.join(output_folder, f)
-            match_res = re.match(r".*([0-9]+)_([0-9]+)\.raw$", filename)
+        tmp_output_folder = f"{tmp_dir}/outputs"
+        for _, f in enumerate(os.listdir(tmp_output_folder)):
+            filename = os.path.join(tmp_output_folder, f)
+            match_res = re.match(r".*output_([0-9]+)_([0-9]+)\.raw$", filename)
             data_index, output_index = int(match_res.group(1)), int(match_res.group(2))
+
+            output_result_folder = f"{args.output_folder}/Result_{data_index}"
+            os.makedirs(output_result_folder, exist_ok=True)
             output = np.fromfile(
                 filename,
                 dtype=eval(
@@ -297,13 +349,11 @@ def post_process():
             output = torch.from_numpy(
                 output.reshape(output_info[output_index]["shape"])
             )
-            torch.save(
-                output, f"{args.output_folder}/output_{data_index}_{output_index}.pt"
-            )
+            torch.save(output, f"{output_result_folder}/output_{output_index}.pt")
 
     logger.info("collecting output data")
-    make_output_dir(args.output_folder)
-    adb.pull(args.output_folder, post_process)
+    adb.pull(tmp_dir, post_process)
+    shutil.rmtree(tmp_dir)
     logger.info(f"execution finished, please check {args.output_folder} for results")
 
 