Skip to content

Commit 36e0b30

Browse files
Fix bf16 grad issue (#1070)
* Fix bf16 grad issue * revise error message for auto_kernel_selection Co-authored-by: Wei-Lin-Intel <wei2.lin@intel.com>
1 parent 42ecc88 commit 36e0b30

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite_linear.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,18 @@ void replaceFrozenIPEXLinearWithAtenLinear(
2222
replaceFrozenIPEXLinearWithAtenLinear(
2323
block, get_data_handle_nodes, use_mkl_sgemm);
2424
}
25-
if (n->kind() == Symbol::fromQualString("torch_ipex::ipex_linear") ||
26-
n->kind() == Symbol::fromQualString("torch_ipex::ipex_MKLSGEMM")) {
25+
26+
bool is_ipex_linear =
27+
n->kind() == Symbol::fromQualString("torch_ipex::ipex_linear");
28+
bool is_mkl_sgemm =
29+
n->kind() == Symbol::fromQualString("torch_ipex::ipex_MKLSGEMM");
30+
if (is_ipex_linear || is_mkl_sgemm) {
31+
// mkl sgemm does not support grad mode
32+
bool mkl_sgemm_and_grad_mode =
33+
is_mkl_sgemm && c10::GradMode::is_enabled();
2734
TORCH_CHECK(
28-
!(c10::GradMode::is_enabled()),
29-
"Detect the Grad Mode! Please make sure torch.no_grad() is set priori to JIT trace");
35+
!mkl_sgemm_and_grad_mode,
36+
"Currently the auto_kernel_selection does not support the grad mode! Please add torch.no_grad() before the inference runtime.");
3037
if (!(constant_as<at::Tensor>(n->namedInput("weight")).has_value())) {
3138
continue;
3239
}

0 commit comments

Comments
 (0)