@@ -1078,7 +1078,7 @@ def _test_output(self, base_model, x, kind_in_graph=None, kind_not_in_graph=None
         if kind_not_in_graph is not None:
             self.assertTrue(all(n.kind() != kind_not_in_graph for n in trace_graph.nodes()))
 
-    def _test_mkl_fp32(self, model, input, kind_in_graph=None, prec=5e-3):
+    def _test_blas_backend_fp32(self, model, input, kind_in_graph=None, prec=5e-3):
         model = model.eval()
         model = ipex.optimize(model, dtype=torch.float32, auto_kernel_selection=True)
         with torch.no_grad():
@@ -1210,19 +1210,23 @@ def check_op_count(graph_str, op_names=[]):
         linear_count_ori = check_op_count(graph_opt, ["aten::linear"])
         self.assertEqual(linear_count_ori, 2)
         # call prepack mkl path (fp32)
+        ipex._set_blas_backend("mkl")
         model = ipex.optimize(origin_model, dtype=torch.float32, auto_kernel_selection=True)
         ori_res = model(test_val1)
         with torch.no_grad():
             model_jit = torch.jit.trace(model, (test_val1))
             graph_ori = str(model_jit.graph_for(test_val1))
             linear_count_ori = check_op_count(graph_ori, ["torch_ipex::ipex_MKLSGEMM"])
             self.assertEqual(linear_count_ori, 4)
+
             model_jit = torch.jit.freeze(model_jit)
             jit_res = model_jit(test_val1)
             self.assertEqual(ori_res, jit_res)
+
             graph_opt = str(model_jit.graph_for(test_val1))
             linear_count_ori = check_op_count(graph_opt, ["ipex_prepack::mkl_sgemm_run"])
             self.assertEqual(linear_count_ori, 2)
+        ipex._set_blas_backend("dnnl")
 
         model = ipex.optimize(origin_model, dtype=torch.bfloat16)
         test_val1 = test_val1.bfloat16()
@@ -1286,7 +1290,7 @@ def test_add_layernorm(self):
         c = torch.randn(bs, seq_len, dim)
         jit_model = torch.jit.trace(model, (a, b, c))
         trace_graph = jit_model.graph_for(a, b, c)
-
+
         jit_res = jit_model(a, b, c)
         ori_res = model(a, b, c)
         self.assertEqual(jit_res, ori_res)
@@ -1495,7 +1499,7 @@ def _test_pure_bf16_parts(model, trace_model, qk, mask, prec=3e-2):
             res_jit = trace_model(qk_bf16, mask_bf16)
             self.assertEqual(res_ref, res_jit, prec=prec)
             _check_match_mha_parts(trace_model, qk_bf16, mask)
-
+
         for sequance_length in [128, 100]:
             mat1 = torch.randn(56, 12, sequance_length, sequance_length)
             mat2 = torch.randn(56, 12, sequance_length, sequance_length)
@@ -2618,8 +2622,9 @@ def test_conv_transpose_sigmoid_mul(self):
 
     def test_linear_auto_kernel_selection_fp32(self):
         x = torch.rand(32, 3)
-        options = itertools.product(['O0', 'O1'], [True, False])
-        for level, auto_select_kernel in options:
+        options = itertools.product(['O0', 'O1'], [True, False], ["mkl", "dnnl"])
+        for level, auto_select_kernel, blas_backend in options:
+            ipex._set_blas_backend(blas_backend)
             model = LinearRelu(3, 32, bias=True).eval()
             model = ipex.optimize(model, dtype=torch.float32, level=level, auto_kernel_selection=auto_select_kernel)
             with torch.no_grad():
@@ -2629,10 +2634,11 @@ def test_linear_auto_kernel_selection_fp32(self):
                 trace_graph = traced_model.graph_for(x)
 
                 if auto_select_kernel and level == 'O1':
-                    # when auto_select_kernel is True and level is O1, we will use ipex prepacked MKL linear
-                    self.assertTrue(any(n.kind() == 'ipex_prepack::mkl_sgemm_run' for n in trace_graph.nodes()))
+                    if ipex._is_mkl_blas_backend():
+                        self.assertTrue(any(n.kind() == 'ipex_prepack::mkl_sgemm_run' for n in trace_graph.nodes()))
+                    else:
+                        self.assertTrue(any(n.kind() == 'ipex_prepack::linear_relu_run' for n in trace_graph.nodes()))
                 else:
-                    # when auto_select_kernel is False, we will use mkl linear
                     self.assertTrue(any(n.kind() == 'aten::linear' for n in trace_graph.nodes()))
 
     def test_linear_auto_kernel_selection_bf16(self):
@@ -2788,10 +2794,22 @@ def _test_linear_unary_fusion(self, op_list, seed=None):
                 m,
                 x,
                 kind_in_graph="aten::linear")
-            self._test_mkl_fp32(
+
+            blas_backend = {"mkl": "ipex_prepack::mkl_sgemm_run"}
+            for _blas in blas_backend.keys():
+                ipex._set_blas_backend(_blas)
+                self._test_blas_backend_fp32(
+                    m,
+                    x,
+                    kind_in_graph=blas_backend[_blas])
+
+            ipex._set_blas_backend("dnnl")
+            self._test_blas_backend_fp32(
                 m,
                 x,
-                kind_in_graph="ipex_prepack::mkl_sgemm_run")
+                kind_in_graph="ipex_prepack::linear_%s_run" % ipex_eltwise_op,
+                prec=prec)
+
             if bf16_supported:
                 self._test_output_bf16(
                     m,
@@ -2836,10 +2854,16 @@ def test_output_linear_add(self):
             LinearAdd(3, 32, bias=True),
             torch.rand(32, 3),
             kind_in_graph="aten::linear")
-        self._test_mkl_fp32(
-            LinearAdd(3, 32, bias=True),
-            torch.rand(32, 3),
-            kind_in_graph="ipex_prepack::mkl_sgemm_run")
+
+        blas_backend = {"mkl": "ipex_prepack::mkl_sgemm_run", "dnnl": "ipex_prepack::linear_run"}
+        for _blas in blas_backend.keys():
+            ipex._set_blas_backend(_blas)
+            self._test_blas_backend_fp32(
+                LinearAdd(3, 32, bias=True),
+                torch.rand(32, 3),
+                kind_in_graph=blas_backend[_blas])
+        ipex._set_blas_backend("dnnl")
+
         self._test_output_bf16(
             LinearAdd(3, 32, bias=True),
             torch.rand(32, 3),
@@ -2855,10 +2879,16 @@ def test_output_linear_add_relu(self):
             m,
             x,
             kind_in_graph="aten::linear")
-        self._test_mkl_fp32(
-            m,
-            x,
-            kind_in_graph="ipex_prepack::mkl_sgemm_run")
+
+        blas_backend = {"mkl": "ipex_prepack::mkl_sgemm_run", "dnnl": "ipex_prepack::linear_run"}
+        for _blas in blas_backend.keys():
+            ipex._set_blas_backend(_blas)
+            self._test_blas_backend_fp32(
+                m,
+                x,
+                kind_in_graph=blas_backend[_blas])
+        ipex._set_blas_backend("dnnl")
+
         self._test_output_bf16(
             m,
             x,
@@ -2885,14 +2915,21 @@ def test_output_linear_reshape_bn(self):
             kind_in_graph="aten::linear")
 
     def test_output_linear_swish(self):
-        self._test_mkl_fp32(
-            LinearSigmoidMul(3, 32, bias=True),
-            torch.rand(32, 3),
-            kind_in_graph="ipex_prepack::mkl_sgemm_run")
-        self._test_mkl_fp32(
-            LinearSigmoidMul(3, 32, bias=False),
-            torch.rand(32, 3),
-            kind_in_graph="ipex_prepack::mkl_sgemm_run")
+
+        blas_backend = {"mkl": "ipex_prepack::mkl_sgemm_run", "dnnl": "ipex_prepack::linear_swish_run"}
+        for _blas in blas_backend.keys():
+            ipex._set_blas_backend(_blas)
+
+            self._test_blas_backend_fp32(
+                LinearSigmoidMul(3, 32, bias=True),
+                torch.rand(32, 3),
+                kind_in_graph=blas_backend[_blas])
+            self._test_blas_backend_fp32(
+                LinearSigmoidMul(3, 32, bias=False),
+                torch.rand(32, 3),
+                kind_in_graph=blas_backend[_blas])
+        ipex._set_blas_backend("dnnl")
+
         self._test_output_bf16(
             LinearSigmoidMul(3, 32, bias=True),
             torch.rand(32, 3),
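Each of these tests repeats the same set-and-restore dance around `ipex._set_blas_backend(...)`. A minimal sketch of a context-manager helper for that pattern, assuming the private `ipex._set_blas_backend("mkl"/"dnnl")` hook behaves as the diff uses it and that `"dnnl"` is the default worth restoring; the `blas_backend_scope` helper and the toy `Linear` model are illustrative, not part of IPEX:

```python
import contextlib

import torch
import intel_extension_for_pytorch as ipex


# Hypothetical helper: wraps the manual set/restore pattern the tests above
# perform. Assumes ipex._set_blas_backend accepts "mkl" or "dnnl" and that
# "dnnl" is the default, as the diff restores it after each backend loop.
@contextlib.contextmanager
def blas_backend_scope(backend):
    ipex._set_blas_backend(backend)
    try:
        yield
    finally:
        ipex._set_blas_backend("dnnl")  # restore the default backend


# Usage mirroring test_output_linear_add: each backend maps to the fused op
# kind the diff asserts in the traced graph.
expected_kind = {
    "mkl": "ipex_prepack::mkl_sgemm_run",
    "dnnl": "ipex_prepack::linear_run",
}
model = torch.nn.Linear(3, 32).eval()
x = torch.rand(32, 3)
for backend, kind in expected_kind.items():
    with blas_backend_scope(backend):
        opt = ipex.optimize(model, dtype=torch.float32, auto_kernel_selection=True)
        with torch.no_grad():
            traced = torch.jit.trace(opt, x)
            traced = torch.jit.freeze(traced)
            graph = traced.graph_for(x)
        assert any(n.kind() == kind for n in graph.nodes()), kind
```

A helper like this would also keep a failing assertion from leaking the "mkl" backend into later tests, which the hand-written set/restore pairs above do not guard against.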