diff --git a/comfy/model_management.py b/comfy/model_management.py
index 709ebc40..c43e8eab 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -148,6 +148,90 @@ def is_intel_xpu():
             return True
     return False
 
+import os
+if is_intel_xpu() and os.environ.get("_LLM_SCALER_DISABLE_INTERPOLATE_FIX") != "1":
+    import torch
+    import torch.nn.functional as F
+    import functools  # Used to preserve function metadata like docstrings
+
+    # Global variables to store the original function and patch status
+    _original_interpolate_func = None
+    _is_interpolate_patched = False
+
+
+    def patch_xpu_interpolate_to_cpu():
| 21 | ++ """ |
| 22 | ++ patches torch.nn.functional.interpolate. If an input tensor is on an XPU device, |
| 23 | ++ it will be moved to CPU for interpolation, and the result will be moved back |
| 24 | ++ to the original XPU device. |
| 25 | ++ """ |
+        global _original_interpolate_func, _is_interpolate_patched
+
+        if _is_interpolate_patched:
+            print("torch.nn.functional.interpolate is already patched for XPU. Skipping.")
+            return
+
+        # Store the original function
+        _original_interpolate_func = F.interpolate
+
+        @functools.wraps(_original_interpolate_func)
+        def _custom_interpolate(input_tensor, *args, **kwargs):
+            """
+            Custom wrapper for interpolate. Moves XPU tensors to CPU for computation.
+            """
+
+            if input_tensor.device.type == "xpu":
+                # print(
+                #     f"Intercepted interpolate call for XPU tensor at device {input_tensor.device}. Moving to CPU for computation."
+                # )
+                original_device = input_tensor.device
+
+                # Move input to CPU
+                input_on_cpu = input_tensor.to("cpu")
+
+                # Call the original interpolate function on CPU
+                result_on_cpu = _original_interpolate_func(input_on_cpu, *args, **kwargs)
+
+                # Move the result back to the original XPU device
+                result_on_xpu = result_on_cpu.to(original_device)
+                # print(
+                #     f"Interpolation completed on CPU, result moved back to {original_device}."
+                # )
+                return result_on_xpu
+            else:
+                # If not an XPU tensor, just call the original function directly
+                return _original_interpolate_func(input_tensor, *args, **kwargs)
+
+        # Replace the original function with our custom one
+        F.interpolate = _custom_interpolate
+        _is_interpolate_patched = True
+        print(
+            "Successfully patched torch.nn.functional.interpolate to handle XPU tensors on CPU."
+        )
+
+
+    def unpatch_xpu_interpolate_to_cpu():
+        """
+        Restores the original torch.nn.functional.interpolate function if it was patched.
+        """
+        global _original_interpolate_func, _is_interpolate_patched
+
+        if not _is_interpolate_patched:
+            print(
+                "torch.nn.functional.interpolate is not currently patched. Skipping unpatch."
+            )
+            return
+
+        if _original_interpolate_func is not None:
+            F.interpolate = _original_interpolate_func
+            _original_interpolate_func = None
+            _is_interpolate_patched = False
+            print("Successfully unpatched torch.nn.functional.interpolate.")
+        else:
+            print("Error: Could not unpatch. Original function reference missing.")
+
+
+    patch_xpu_interpolate_to_cpu()
 def is_ascend_npu():
     global npu_available
     if npu_available:
@@ -720,7 +804,6 @@ def cleanup_models_gc():
                 logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__))
 
 
-
 def cleanup_models():
     to_delete = []
     for i in range(len(current_loaded_models)):
@@ -1399,7 +1482,7 @@ def unload_all_models():
     free_memory(1e30, get_torch_device())
 
 
-#TODO: might be cleaner to put this somewhere else
+# TODO: might be cleaner to put this somewhere else
 import threading
 
 class InterruptProcessingException(Exception):
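
Note: the workaround above is gated on is_intel_xpu() and can be skipped by setting the _LLM_SCALER_DISABLE_INTERPOLATE_FIX environment variable to "1". For readers who want to try the same monkey-patching technique outside this diff, here is a minimal standalone sketch; the name _cpu_detour_interpolate is illustrative and not part of the patch. It runs on any machine with PyTorch installed, and without an XPU the wrapper simply falls through to the original implementation.

import functools

import torch
import torch.nn.functional as F

_original_interpolate = F.interpolate  # keep a reference so the patch is reversible


@functools.wraps(_original_interpolate)
def _cpu_detour_interpolate(input_tensor, *args, **kwargs):
    if input_tensor.device.type == "xpu":
        # Compute on CPU, then move the result back to the source device.
        result = _original_interpolate(input_tensor.to("cpu"), *args, **kwargs)
        return result.to(input_tensor.device)
    # Non-XPU tensors take the fast path through the original function.
    return _original_interpolate(input_tensor, *args, **kwargs)


F.interpolate = _cpu_detour_interpolate

# CPU tensors are unaffected by the detour.
x = torch.randn(1, 3, 8, 8)
y = F.interpolate(x, scale_factor=2, mode="nearest")
print(y.shape)  # torch.Size([1, 3, 16, 16])

# Restore the original function when the workaround is no longer needed.
F.interpolate = _original_interpolate

Keeping a module-level reference to the original function is what makes the patch reversible, which is why the diff stores it in _original_interpolate_func before overwriting F.interpolate.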