@@ -541,15 +541,32 @@ def __init__(self, cli_executable=None, cli_options=None, llvmlite_engine=None):
541541 backing_mod = llvm .parse_assembly ("" )
542542 llvmlite_engine = llvm .create_mcjit_compiler (backing_mod , target_machine )
543543 self ._engine = llvmlite_engine
544- self .tmp_dir = tempfile .TemporaryDirectory ()
544+ self .default_profile_dir = tempfile .TemporaryDirectory ()
545+ self .current_profile_dir : Optional [str ] = None
545546 self .c_compiler = distutils .ccompiler .new_compiler ()
546547 self ._cli = MlirOptCli (cli_executable , cli_options )
547548 self .name_to_callable : Dict [str , Callable ] = {}
548549 return
549550
550- def profiled_function (self , main_callable : Callable ) -> Callable :
551+ @property
552+ def profile_dir_name (self ) -> str :
553+ if self .current_profile_dir is not None :
554+ # An explicitly set self.current_profile_dir takes precedence over the default temporary directory. TODO: consider making a context manager for setting self.current_profile_dir.
555+ return self .current_profile_dir
556+ return self .default_profile_dir .name
557+
558+ def profiled_function (
559+ self , main_callable : Callable , symbol_to_profile : str
560+ ) -> Callable :
551561 """Decorator to profile a function via Linux's perf tool."""
552562
563+ # set this at the time that decorated_func is created and
564+ # not at the time it is called as the value of
565+ # self.profile_dir_name may change.
566+ perf_data_file_name = os .path .join (
567+ self .profile_dir_name , f"perf-{ uuid .uuid4 ()} .data"
568+ )
569+
553570 def mp_func (queue : mp .SimpleQueue , * args ):
554571 # send pid to decorated_func
555572 queue .put (os .getpid ())
@@ -576,12 +593,11 @@ def decorated_func(*args) -> Any:
576593 stdout = subprocess .DEVNULL ,
577594 stderr = subprocess .PIPE ,
578595 )
579- record_process .stdin .write (
580- f"""
581- perf record -p { execution_process_id }
596+ record_command = f"""
597+ perf record -p { execution_process_id } --output={ perf_data_file_name }
582598 exit
583- """ . encode ()
584- )
599+ """
600+ record_process . stdin . write ( record_command . encode () )
585601 record_process .stdin .flush ()
586602
587603 # wait for profiling to initialize
@@ -600,14 +616,14 @@ def decorated_func(*args) -> Any:
600616 pass
601617
602618 # gather profiling results
603- report_command = f"perf report --pid= { execution_process_id } | cat "
604- report_process = subprocess .Popen (
605- report_command ,
619+ annotate_command = f"perf annotate --stdio --symbol { symbol_to_profile } -l --input= { perf_data_file_name } "
620+ annotate_process = subprocess .Popen (
621+ annotate_command ,
606622 shell = True ,
607623 stdout = subprocess .PIPE ,
608624 stderr = subprocess .STDOUT ,
609625 ) # combines stdout and stderr
610- stdout_string , _ = report_process .communicate ()
626+ stdout_string , _ = annotate_process .communicate ()
611627 stdout_string = stdout_string .decode ()
612628
613629 # print results
@@ -645,14 +661,14 @@ def _add_llvm_module(
645661 assert len (o_file_bytes_list ) in (1 , 2 )
646662 o_file_bytes = o_file_bytes_list [- 1 ]
647663 o_file_name = f"mod-{ uuid .uuid4 ()} .o"
648- o_file_name = os .path .join (self .tmp_dir . name , o_file_name )
664+ o_file_name = os .path .join (self .profile_dir_name , o_file_name )
649665 with open (o_file_name , "wb" ) as f :
650666 f .write (o_file_bytes_list [- 1 ])
651667 files_to_link .append (o_file_name )
652668 if isinstance (profile , ctypes .CDLL ):
653669 files_to_link .append (profile ._name )
654670 so_file_name = f"shared-{ uuid .uuid4 ()} .so"
655- so_file_name = os .path .join (self .tmp_dir . name , so_file_name )
671+ so_file_name = os .path .join (self .profile_dir_name , so_file_name )
656672 self .c_compiler .link_shared_object (files_to_link , so_file_name )
657673 ctypes .cdll .LoadLibrary (so_file_name )
658674 shared_lib = ctypes .CDLL (so_file_name )
@@ -753,7 +769,7 @@ def python_callable(mlir_function, encoders, c_callable, decoder, *args):
753769 python_callable , mlir_function , encoders , c_callable , decoder
754770 )
755771 if shared_lib is not None :
756- bound_func = self .profiled_function (bound_func )
772+ bound_func = self .profiled_function (bound_func , name )
757773 name_to_callable [name ] = bound_func
758774
759775 return name_to_callable
@@ -791,7 +807,7 @@ def _lower_types_to_strings(
791807
792808 def _generate_mlir_string_for_multivalued_functions (
793809 self , mlir_functions : Iterable [mlir .astnodes .Function ], passes : List [str ]
794- ) -> Tuple [str , str ]:
810+ ) -> Tuple [str , List [ str ], List [ str ] ]:
795811
796812 result_type_name_to_lowered_result_type_name = self ._lower_types_to_strings (
797813 sum ((mlir_function .result_types for mlir_function in mlir_functions ), []),
@@ -800,9 +816,8 @@ def _generate_mlir_string_for_multivalued_functions(
800816
801817 # Generate conglomerate MLIR string for all wrappers
802818 mlir_wrapper_texts : List [str ] = []
803- wrapper_names = [
804- mlir_function .name .value + "wrapper" for mlir_function in mlir_functions
805- ]
819+ names = [mlir_function .name .value for mlir_function in mlir_functions ]
820+ wrapper_names = [name + "wrapper" for name in names ]
806821 for mlir_function , wrapper_name in zip (mlir_functions , wrapper_names ):
807822 lowered_result_type_names = [
808823 result_type_name_to_lowered_result_type_name [result_type .dump ()]
@@ -860,7 +875,7 @@ def _generate_mlir_string_for_multivalued_functions(
860875 mlir_wrapper_texts .append (mlir_wrapper_text )
861876
862877 mlir_text = "\n " .join (mlir_wrapper_texts )
863- return mlir_text , wrapper_names
878+ return mlir_text , names , wrapper_names
864879
865880 def _generate_multivalued_functions (
866881 self ,
@@ -870,9 +885,11 @@ def _generate_multivalued_functions(
870885 ) -> Dict [str , Callable ]:
871886 name_to_callable : Dict [str , Callable ] = {}
872887
873- mlir_text , wrapper_names = self ._generate_mlir_string_for_multivalued_functions (
874- mlir_functions , passes
875- )
888+ (
889+ mlir_text ,
890+ names ,
891+ wrapper_names ,
892+ ) = self ._generate_mlir_string_for_multivalued_functions (mlir_functions , passes )
876893
877894 # this is guaranteed to not raise exceptions since the user-provided
878895 # code was already added (failures would occur then)
@@ -882,7 +899,9 @@ def _generate_multivalued_functions(
882899 assert bool (wrapper_shared_lib ) == bool (internal_shared_lib )
883900
884901 # Generate callables
885- for mlir_function , wrapper_name in zip (mlir_functions , wrapper_names ):
902+ for mlir_function , name , wrapper_name in zip (
903+ mlir_functions , names , wrapper_names
904+ ):
886905 ctypes_input_types , input_encoders = mlir_function_input_encoders (
887906 mlir_function
888907 )
@@ -954,7 +973,7 @@ def python_callable(
954973 decoders ,
955974 )
956975 if wrapper_shared_lib is not None :
957- bound_func = self .profiled_function (bound_func )
976+ bound_func = self .profiled_function (bound_func , name )
958977 name_to_callable [mlir_function .name .value ] = bound_func
959978
960979 return name_to_callable
@@ -975,10 +994,18 @@ def add(
975994 mlir_text : Union [str , bytes ],
976995 passes : Tuple [str ],
977996 * ,
978- debug = False ,
979- profile = False ,
997+ debug : bool = False ,
998+ profile : bool = False ,
999+ profile_result_directory : Optional [str ] = None ,
9801000 ) -> Union [List [str ], DebugResult ]:
9811001 """List of new function names added."""
1002+ if profile_result_directory is not None :
1003+ if not profile :
1004+ raise ValueError (
1005+ "Cannot specify a profile result directory without also enabling profiling."
1006+ )
1007+ self .current_profile_dir = profile_result_directory
1008+
9821009 if profile :
9831010 if not sys .platform .startswith ("linux" ):
9841011 raise NotImplementedError ("Profiling only supported on linux." )
@@ -1054,6 +1081,9 @@ def add(
10541081
10551082 self .name_to_callable [name ] = python_callable
10561083
1084+ if profile_result_directory is not None :
1085+ self .current_profile_dir = None
1086+
10571087 return function_names
10581088
10591089 def __getitem__ (self , func_name : str ) -> Callable :
0 commit comments