
Commit 52b3c1a

Add private API to set blas backend (#1050)
* Add private API to set blas backend
* Fix UT

1 parent: 7076524
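In short, this commit introduces a process-global switch for the BLAS backend used when prepacking fp32 torch.nn.Linear modules: a BlasBackend class in _weight_prepack.py (default "dnnl"), plus private wrappers _set_blas_backend, _is_mkl_blas_backend, and _is_dnnl_blas_backend exported from the frontend. The practical effect is that the MKL SGEMM prepack path, previously chosen automatically for fp32 inference under FP32 math mode, becomes opt-in; oneDNN is the default.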


5 files changed: +141, -39 lines


intel_extension_for_pytorch/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -29,4 +29,4 @@
 from . import autocast
 
 from .utils.verbose import verbose
-from .frontend import optimize, enable_onednn_fusion, set_fp32_math_mode, get_fp32_math_mode, FP32MathMode
+from .frontend import optimize, enable_onednn_fusion, set_fp32_math_mode, get_fp32_math_mode, FP32MathMode, _set_blas_backend, _is_mkl_blas_backend, _is_dnnl_blas_backend

intel_extension_for_pytorch/frontend.py

Lines changed: 12 additions & 3 deletions
@@ -159,7 +159,7 @@ def optimize(
             input data will impact the block format of packed weight. If not feed a sample
             input, Intel® Extension for PyTorch* will pack the weight per some predefined heuristics.
             If feed a sample input with real input shape, Intel® Extension for PyTorch* can get
-            best block format.
+            best block format.
         auto_kernel_selection (bool) [experimental]: Different backends may have
             different performances with different dtypes/shapes. Default value
             is False. Intel® Extension for PyTorch* will try to optimize the
@@ -241,7 +241,7 @@ def optimize(
     if fuse_update_step is not None:
         opt_properties.fuse_update_step = fuse_update_step
     if auto_kernel_selection is not None:
-        opt_properties.auto_kernel_selection = auto_kernel_selection
+        opt_properties.auto_kernel_selection = auto_kernel_selection
 
     if inplace:
         optimized_model = model
@@ -253,7 +253,7 @@ def optimize(
         if isinstance(sample_input, torch.Tensor):
             sample_input = (sample_input,)
         utils._weight_prepack.record_input_shape_for_prepack(optimized_model, sample_input)
-
+
     if not model.training:
         if opt_properties.conv_bn_folding:
             try:
@@ -384,3 +384,12 @@ def get_fp32_math_mode(device="cpu"):
     """
 
     return core.get_fp32_math_mode()
+
+def _set_blas_backend(backend="dnnl"):
+    utils._weight_prepack.BlasBackend.set_backend(backend)
+
+def _is_mkl_blas_backend():
+    return utils._weight_prepack.BlasBackend.is_mkl()
+
+def _is_dnnl_blas_backend():
+    return utils._weight_prepack.BlasBackend.is_dnnl()
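For orientation, a minimal usage sketch of the new private API, mirroring how the unit tests below exercise it (the underscore prefix marks these functions as private and subject to change):

import torch
import intel_extension_for_pytorch as ipex

model = torch.nn.Linear(3, 32).eval()

ipex._set_blas_backend("mkl")    # route fp32 Linear prepack through MKL SGEMM
assert ipex._is_mkl_blas_backend()
opt_model = ipex.optimize(model, dtype=torch.float32, auto_kernel_selection=True)

ipex._set_blas_backend("dnnl")   # restore the default oneDNN backend
assert ipex._is_dnnl_blas_backend()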

intel_extension_for_pytorch/nn/utils/_weight_prepack.py

Lines changed: 16 additions & 1 deletion
@@ -8,6 +8,21 @@
 
 logger = logging.getLogger(__name__)
 
+class BlasBackend:
+    _blas_backend = "dnnl"
+
+    @classmethod
+    def set_backend(cls, backend="dnnl"):
+        cls._blas_backend = backend
+
+    @classmethod
+    def is_mkl(cls):
+        return cls._blas_backend == "mkl"
+
+    @classmethod
+    def is_dnnl(cls):
+        return cls._blas_backend == "dnnl"
+
 class _IPEXConvNd(nn.Module):
     __constants__ = ['stride', 'padding', 'dilation', 'groups',
                      'out_channels', 'kernel_size']
@@ -302,7 +317,7 @@ def convert(m, optimizer, params_attr, auto_kernel_selection):
         if weight not in params_attr:
             params_attr[weight] = {}
         if type(m) is torch.nn.Linear:
-            if m.weight.dtype == torch.float32 and optimizer is None and frontend.get_fp32_math_mode(device="cpu") == frontend.FP32MathMode.FP32:
+            if BlasBackend.is_mkl() and m.weight.dtype == torch.float32 and optimizer is None and frontend.get_fp32_math_mode(device="cpu") == frontend.FP32MathMode.FP32:
                 new_m = IPEX_WEIGHT_PREPACK_MODULE[type(m)](m, use_dnnl = False)
             else:
                 new_m = IPEX_WEIGHT_PREPACK_MODULE[type(m)](m, use_dnnl = True)
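Note that BlasBackend stores a single class-level string, so the setting is process-wide rather than per-model or per-thread, and it is consulted at ipex.optimize() time. A quick round-trip sketch using the import path from the diff above (normal callers would go through the ipex._set_blas_backend wrapper instead):

from intel_extension_for_pytorch.nn.utils._weight_prepack import BlasBackend

BlasBackend.set_backend("mkl")
assert BlasBackend.is_mkl() and not BlasBackend.is_dnnl()
BlasBackend.set_backend("dnnl")  # restore the default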

tests/cpu/test_jit.py

Lines changed: 63 additions & 26 deletions
@@ -1078,7 +1078,7 @@ def _test_output(self, base_model, x, kind_in_graph=None, kind_not_in_graph=None
         if kind_not_in_graph is not None:
             self.assertTrue(all(n.kind() != kind_not_in_graph for n in trace_graph.nodes()))
 
-    def _test_mkl_fp32(self, model, input, kind_in_graph=None, prec=5e-3):
+    def _test_blas_backend_fp32(self, model, input, kind_in_graph=None, prec=5e-3):
         model = model.eval()
         model = ipex.optimize(model, dtype=torch.float32, auto_kernel_selection=True)
         with torch.no_grad():
@@ -1210,19 +1210,23 @@ def check_op_count(graph_str, op_names=[]):
         linear_count_ori = check_op_count(graph_opt, ["aten::linear"])
         self.assertEqual(linear_count_ori, 2)
         #call prepack mkl path(fp32)
+        ipex._set_blas_backend("mkl")
         model = ipex.optimize(origin_model, dtype=torch.float32, auto_kernel_selection=True)
         ori_res = model(test_val1)
         with torch.no_grad():
             model_jit = torch.jit.trace(model,(test_val1))
             graph_ori = str(model_jit.graph_for(test_val1))
             linear_count_ori = check_op_count(graph_ori, ["torch_ipex::ipex_MKLSGEMM"])
             self.assertEqual(linear_count_ori, 4)
+
             model_jit = torch.jit.freeze(model_jit)
             jit_res = model_jit(test_val1)
             self.assertEqual(ori_res, jit_res)
+
             graph_opt = str(model_jit.graph_for(test_val1))
             linear_count_ori = check_op_count(graph_opt, ["ipex_prepack::mkl_sgemm_run"])
             self.assertEqual(linear_count_ori, 2)
+        ipex._set_blas_backend("dnnl")
 
         model = ipex.optimize(origin_model, dtype=torch.bfloat16)
         test_val1 = test_val1.bfloat16()
@@ -1286,7 +1290,7 @@ def test_add_layernorm(self):
         c = torch.randn(bs, seq_len, dim)
         jit_model = torch.jit.trace(model,(a, b, c))
         trace_graph = jit_model.graph_for(a, b, c)
-
+
         jit_res = jit_model(a, b, c)
         ori_res = model(a, b, c)
         self.assertEqual(jit_res, ori_res)
@@ -1495,7 +1499,7 @@ def _test_pure_bf16_parts(model, trace_model, qk, mask, prec=3e-2):
             res_jit = trace_model(qk_bf16, mask_bf16)
             self.assertEqual(res_ref, res_jit, prec=prec)
             _check_match_mha_parts(trace_model, qk_bf16, mask)
-
+
         for sequance_length in [128, 100]:
             mat1 = torch.randn(56, 12, sequance_length, sequance_length)
             mat2 = torch.randn(56, 12, sequance_length, sequance_length)
@@ -2618,8 +2622,9 @@ def test_conv_transpose_sigmoid_mul(self):
 
     def test_linear_auto_kernel_selection_fp32(self):
         x = torch.rand(32, 3)
-        options = itertools.product(['O0', 'O1'], [True, False])
-        for level, auto_select_kernel in options:
+        options = itertools.product(['O0', 'O1'], [True, False], ["mkl", "dnnl"])
+        for level, auto_select_kernel, blas_backend in options:
+            ipex._set_blas_backend(blas_backend)
             model = LinearRelu(3, 32, bias=True).eval()
             model = ipex.optimize(model, dtype=torch.float32, level=level, auto_kernel_selection=auto_select_kernel)
             with torch.no_grad():
@@ -2629,10 +2634,11 @@ def test_linear_auto_kernel_selection_fp32(self):
             trace_graph = traced_model.graph_for(x)
 
             if auto_select_kernel and level == 'O1':
-                # for auto_select_kernel is True and level is O1, we will use ipex prepacked MKL linear
-                self.assertTrue(any(n.kind() == 'ipex_prepack::mkl_sgemm_run' for n in trace_graph.nodes()))
+                if ipex._is_mkl_blas_backend():
+                    self.assertTrue(any(n.kind() == 'ipex_prepack::mkl_sgemm_run' for n in trace_graph.nodes()))
+                else:
+                    self.assertTrue(any(n.kind() == 'ipex_prepack::linear_relu_run' for n in trace_graph.nodes()))
             else:
-                # auto_select_kernel is false, we will use mkl linear
                 self.assertTrue(any(n.kind() == 'aten::linear' for n in trace_graph.nodes()))
 
     def test_linear_auto_kernel_selection_bf16(self):
@@ -2788,10 +2794,22 @@ def _test_linear_unary_fusion(self, op_list, seed=None):
                 m,
                 x,
                 kind_in_graph="aten::linear")
-            self._test_mkl_fp32(
+
+            blas_backend = {"mkl":"ipex_prepack::mkl_sgemm_run"}
+            for _blas in blas_backend.keys():
+                ipex._set_blas_backend(_blas)
+                self._test_blas_backend_fp32(
+                    m,
+                    x,
+                    kind_in_graph=blas_backend[_blas])
+
+            ipex._set_blas_backend("dnnl")
+            self._test_blas_backend_fp32(
                 m,
                 x,
-                kind_in_graph="ipex_prepack::mkl_sgemm_run")
+                kind_in_graph="ipex_prepack::linear_%s_run" % ipex_eltwise_op,
+                prec=prec)
+
             if bf16_supported:
                 self._test_output_bf16(
                     m,
@@ -2836,10 +2854,16 @@ def test_output_linear_add(self):
             LinearAdd(3, 32, bias=True),
             torch.rand(32, 3),
             kind_in_graph="aten::linear")
-        self._test_mkl_fp32(
-            LinearAdd(3, 32, bias=True),
-            torch.rand(32, 3),
-            kind_in_graph="ipex_prepack::mkl_sgemm_run")
+
+        blas_backend = {"mkl":"ipex_prepack::mkl_sgemm_run", "dnnl":"ipex_prepack::linear_run"}
+        for _blas in blas_backend.keys():
+            ipex._set_blas_backend(_blas)
+            self._test_blas_backend_fp32(
+                LinearAdd(3, 32, bias=True),
+                torch.rand(32, 3),
+                kind_in_graph=blas_backend[_blas])
+        ipex._set_blas_backend("dnnl")
+
         self._test_output_bf16(
             LinearAdd(3, 32, bias=True),
             torch.rand(32, 3),
@@ -2855,10 +2879,16 @@ def test_output_linear_add_relu(self):
             m,
             x,
             kind_in_graph="aten::linear")
-        self._test_mkl_fp32(
-            m,
-            x,
-            kind_in_graph="ipex_prepack::mkl_sgemm_run")
+
+        blas_backend = {"mkl":"ipex_prepack::mkl_sgemm_run", "dnnl":"ipex_prepack::linear_run"}
+        for _blas in blas_backend.keys():
+            ipex._set_blas_backend(_blas)
+            self._test_blas_backend_fp32(
+                m,
+                x,
+                kind_in_graph=blas_backend[_blas])
+        ipex._set_blas_backend("dnnl")
+
         self._test_output_bf16(
             m,
             x,
@@ -2885,14 +2915,21 @@ def test_output_linear_reshape_bn(self):
             kind_in_graph="aten::linear")
 
     def test_output_linear_swish(self):
-        self._test_mkl_fp32(
-            LinearSigmoidMul(3, 32, bias=True),
-            torch.rand(32, 3),
-            kind_in_graph="ipex_prepack::mkl_sgemm_run")
-        self._test_mkl_fp32(
-            LinearSigmoidMul(3, 32, bias=False),
-            torch.rand(32, 3),
-            kind_in_graph="ipex_prepack::mkl_sgemm_run")
+
+        blas_backend = {"mkl":"ipex_prepack::mkl_sgemm_run", "dnnl":"ipex_prepack::linear_swish_run"}
+        for _blas in blas_backend.keys():
+            ipex._set_blas_backend(_blas)
+
+            self._test_blas_backend_fp32(
+                LinearSigmoidMul(3, 32, bias=True),
+                torch.rand(32, 3),
+                kind_in_graph=blas_backend[_blas])
+            self._test_blas_backend_fp32(
+                LinearSigmoidMul(3, 32, bias=False),
+                torch.rand(32, 3),
+                kind_in_graph=blas_backend[_blas])
+        ipex._set_blas_backend("dnnl")
+
         self._test_output_bf16(
             LinearSigmoidMul(3, 32, bias=True),
             torch.rand(32, 3),
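The tests above repeat a set-then-restore pattern around every "mkl" run, and a failing assertion in between would leave the global backend set to "mkl" for subsequent tests. A hypothetical context-manager helper (not part of this commit) would make the restore unconditional:

import contextlib
import intel_extension_for_pytorch as ipex

@contextlib.contextmanager
def blas_backend(backend):
    # Hypothetical helper, not in this commit: switch the global BLAS
    # backend for the duration of the block, then restore the default.
    ipex._set_blas_backend(backend)
    try:
        yield
    finally:
        ipex._set_blas_backend("dnnl")

With it, each MKL block would read: with blas_backend("mkl"): self._test_blas_backend_fp32(...).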

tests/cpu/test_weight_prepack.py

Lines changed: 49 additions & 8 deletions
@@ -182,7 +182,7 @@ def _test_convolution_training_base(self, dim, dtype, rtol=None, atol=None):
 
     def test_conv2d_training(self):
         self._test_convolution_training_base(dim=2, dtype=torch.float)
-        if core.onednn_has_bf16_support():
+        if core.onednn_has_bf16_support():
             self._test_convolution_training_base(dim=2, dtype=torch.bfloat16, rtol=1e-2, atol=1e-03)
 
     # TODO: add inference case.
@@ -436,6 +436,47 @@ def test_resnext50_32x4d(self):
         model = torchvision.models.resnet.resnext50_32x4d(pretrained=False)
         self._test_imagenet_model(model)
 
+    def test_blas_backend(self):
+        class L(torch.nn.Module):
+            def __init__(self, in_f, out_f, bias):
+                super(L, self).__init__()
+                self.linear = torch.nn.Linear(in_f, out_f, bias=bias)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        out_features = torch.randint(3, 10, (1,)).item()
+        in_features = torch.randint(3, 10, (1,)).item()
+
+        input_shape = (8, in_features)
+        x = torch.randn(input_shape, dtype=torch.float32)
+        model = L(in_features, out_features, True)
+        origin_model = copy.deepcopy(model).eval()
+
+        def test_dnnl():
+            self.assertTrue(ipex._is_dnnl_blas_backend())
+            ipex_model_dnnl = ipex.optimize(origin_model, dtype=torch.float32, level='O1', auto_kernel_selection=True)
+            with torch.no_grad():
+                dnnl_graph = torch.jit.trace(ipex_model_dnnl.eval(), x)
+                dnnl_graph = torch.jit.freeze(dnnl_graph)
+                dnnl_graph(x)
+                trace_graph = dnnl_graph.graph_for(x)
+                self.assertTrue(any(n.kind() == "ipex_prepack::linear_run" for n in trace_graph.nodes()))
+        test_dnnl()
+
+        ipex._set_blas_backend("mkl")
+        self.assertTrue(ipex._is_mkl_blas_backend())
+        ipex_model_dnnl = ipex.optimize(origin_model, dtype=torch.float32, level='O1', auto_kernel_selection=True)
+        with torch.no_grad():
+            dnnl_graph = torch.jit.trace(ipex_model_dnnl.eval(), x)
+            dnnl_graph = torch.jit.freeze(dnnl_graph)
+            dnnl_graph(x)
+            trace_graph = dnnl_graph.graph_for(x)
+            self.assertTrue(any(n.kind() == "ipex_prepack::mkl_sgemm_run" for n in trace_graph.nodes()))
+
+        ipex._set_blas_backend("dnnl")
+        test_dnnl()
+
     def test_linear_inference(self):
         class L(torch.nn.Module):
             def __init__(self, in_f, out_f, bias):
@@ -479,7 +520,7 @@ def test_linear_training(self):
         input_shapes = []
         for s in in_feature:
             input_shapes += [(128, s), (2, 64, s), (2, 2, 32, s)]
-
+
         options = itertools.product(out_feature, [True, False], input_shapes, [torch.bfloat16], [True, False])
         for out_features, bias, x_shape, dtype, feed_sample_input in options:
             in_features = x_shape[-1]
@@ -564,12 +605,12 @@ def _deconv_with_output_padding(self):
             "groups": 1,
             "dilation": 3,
         }
-
+
         params_list = []
 
         for key, value in params_dict.items():
             params_list.append(value)
-        return params_list
+        return params_list
 
     # mkldnn does not support the case where:
     # padding - output_padding + stride <= 0
@@ -594,7 +635,7 @@ def _deconv_fallback_shape(self):
 
         for key, value in params_dict.items():
             params_list.append(value)
-        return params_list
+        return params_list
 
     def _test_deconv(self, dims, inference):
         class Deconv2d(torch.nn.Module):
@@ -667,14 +708,14 @@ def forward(self, x):
                 ipex_model, ipex_optimizer = ipex.optimize(origin_model, dtype=dtype, optimizer=origin_optimizer, level='O1', sample_input=x)
             else:
                 ipex_model, ipex_optimizer = ipex.optimize(origin_model, dtype=dtype, optimizer=origin_optimizer, level='O1')
-
+
             if padding - output_padding + stride <= 0:
                 # unsupported in mkldnn, should not replace the original ConvTranspose module
                 self.assertTrue(module_found(ipex_model, torch.nn.ConvTranspose2d if dims == 2 else torch.nn.ConvTranspose3d))
                 continue
             else:
-                self.assertFalse(module_found(ipex_model, torch.nn.ConvTranspose2d if dims == 2 else torch.nn.ConvTranspose3d))
-
+                self.assertFalse(module_found(ipex_model, torch.nn.ConvTranspose2d if dims == 2 else torch.nn.ConvTranspose3d))
+
             x1 = x.clone().requires_grad_()
             x2 = x.clone().requires_grad_()