File tree Expand file tree Collapse file tree 1 file changed +11
-4
lines changed
intel_extension_for_pytorch/csrc/jit/cpu/passes Expand file tree Collapse file tree 1 file changed +11
-4
lines changed Original file line number Diff line number Diff line change @@ -22,11 +22,18 @@ void replaceFrozenIPEXLinearWithAtenLinear(
2222 replaceFrozenIPEXLinearWithAtenLinear (
2323 block, get_data_handle_nodes, use_mkl_sgemm);
2424 }
25- if (n->kind () == Symbol::fromQualString (" torch_ipex::ipex_linear" ) ||
26- n->kind () == Symbol::fromQualString (" torch_ipex::ipex_MKLSGEMM" )) {
25+
26+ bool is_ipex_linear =
27+ n->kind () == Symbol::fromQualString (" torch_ipex::ipex_linear" );
28+ bool is_mkl_sgemm =
29+ n->kind () == Symbol::fromQualString (" torch_ipex::ipex_MKLSGEMM" );
30+ if (is_ipex_linear || is_mkl_sgemm) {
31+ // mkl sgemm does not support grad mode
32+ bool mkl_sgemm_and_grad_mode =
33+ is_mkl_sgemm && c10::GradMode::is_enabled ();
2734 TORCH_CHECK (
28- !( c10::GradMode::is_enabled ()) ,
29- " Detect the Grad Mode ! Please make sure torch.no_grad() is set priori to JIT trace " );
35+ !mkl_sgemm_and_grad_mode ,
36+ " Currently the auto_kernel_selection does not support the grad mode ! Please add torch.no_grad() before the inference runtime. " );
3037 if (!(constant_as<at::Tensor>(n->namedInput (" weight" )).has_value ())) {
3138 continue ;
3239 }
You can’t perform that action at this time.
0 commit comments