diff --git a/evaluate.py b/evaluate.py index 7df26d379..53052143e 100644 --- a/evaluate.py +++ b/evaluate.py @@ -15,9 +15,12 @@ # specific language governing permissions and limitations # under the License. +import gc import os +import time import numpy as np + #import mxnet.gluon as gluon import tvm from tvm import relay @@ -25,6 +28,7 @@ from tvm import autotvm from tvm.contrib import utils, ndk from tvm.topi import testing +from memory_profiler import memory_usage # DEELVIN-207 # from tvm.relay.op import register_mixed_precision_conversion @@ -402,9 +406,10 @@ def get_args(): args.rpc_key, host=args.rpc_tracker_host, port=args.rpc_tracker_port, - number=50, + number=1, + repeat=3, timeout=15, - #min_repeat_ms=150, + min_repeat_ms=200, #cooldown_interval=150 ), ), @@ -416,6 +421,7 @@ def get_args(): def main(): + print(args) # ICE TODO if "opencl" in args.target: executor = Executor(use_tracker="android") else: @@ -1178,9 +1184,21 @@ def bench(): self.benchmarks.append(bench) def tune(apply_previous_tune=False, options=args.tuning_options): - print("Extracting tasks") - tasks = autotvm.task.extract_from_program( - mod, target=target, target_host=self.host_target, params=params + tasks = [] + print("Extracting tasks...") + st = time.time() + mem_usage = memory_usage( + lambda: + tasks.append( + autotvm.task.extract_from_program( + mod, target=target, target_host=self.host_target, params=params + ) + ) + ) + elapsed_time = time.time() - st + tasks = tasks[0] + print('Extracting tasks: maximum memory usage: {} MiB, execution time: {} seconds'.format( + round(max(mem_usage), 2), round(elapsed_time, 2)) ) if apply_previous_tune == False: print("Tuning kernels") @@ -1213,9 +1231,9 @@ def tune_tasks( tmp_log_file = log_filename + ".tmp" #if os.path.exists(tmp_log_file) and use_transfer_learning == False: # os.remove(tmp_log_file) - - for i, tsk in enumerate(reversed(tasks)): - print("Task: ", tsk) + def tune_task(i, tsk, info): + print("Task: ", tsk.name, tsk.args) + print("tune with {} tuner".format(tuner)) prefix = "[Task %2d/%2d] " % (i + 1, len(tasks)) if tuner == "xgb" or tuner == "xgb-rank": tuner_obj = XGBTuner(tsk, loss_type="rank") @@ -1234,7 +1252,7 @@ def tune_tasks( if os.path.isfile(tmp_log_file): tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file)) - tsk_trial = min(n_trial, len(tsk.config_space)) + tsk_trial = min(n_trial, tsk.config_space.filtered_length) tuner_obj.tune( n_trial=tsk_trial, early_stopping=early_stopping, @@ -1244,6 +1262,60 @@ def tune_tasks( autotvm.callback.log_to_file(tmp_log_file), ], ) + info["best_flops"] = tuner_obj.best_flops + info["tsk_trial"] = tsk_trial + info["filtered_length"] = tsk.config_space.filtered_length + info["total_length"] = tsk.config_space.total_length + + # Memory cleaning + tsk.config_space.clear_shared_filter_cash() + del tuner_obj.cost_model.feature_cache + tuner_obj.cost_model.feature_cache = {} + gc.collect() + + for i, tsk in enumerate(reversed(tasks)): + info = {} + st = time.time() + mem_usage = memory_usage((tune_task, (i, tsk, info))) + elapsed_time = time.time() - st + + task_res_info = { + "INDEX": i+1, + "WORKLOAD": tsk.workload, + "BEST_FLOPS": round(info["best_flops"] / 1e9, 2), + "MAX_MEM_USAGE": round(max(mem_usage), 2), + "ELAPSED_TIME": round(elapsed_time, 2), + "TRIALS": info["tsk_trial"], + "FILTRED_SPACE": info["filtered_length"], + "ALL_SPACE": info["total_length"] + } + + print("Tune finished with:\n" + "index: {INDEX}\n" + "workload: {WORKLOAD}\n" + "trials: {TRIALS}\n" + "all_space: {ALL_SPACE}\n" + "filtred_space: {FILTRED_SPACE}\n" + "best_flops: {BEST_FLOPS} GFLOPS\n" + "elapsed_time: {ELAPSED_TIME} sec\n" + "max_mem_usage: {MAX_MEM_USAGE} MiB\n" + .format(**task_res_info) + ) + + tasks_result_info_file = log_filename + ".tasks_info" + + with open(tasks_result_info_file, "a") as f: + f.write( + "{INDEX}\t" + "{WORKLOAD}\t" + "{TRIALS}\t" + "{ALL_SPACE}\t" + "{FILTRED_SPACE}\t" + "{BEST_FLOPS}\t" + "{ELAPSED_TIME}\t" + "{MAX_MEM_USAGE}\n" + .format(**task_res_info) + ) autotvm.record.pick_best(tmp_log_file, log_filename) # os.remove(tmp_log_file)