From b46993fa8de70b4d0cb2dc09ee88d99d136101da Mon Sep 17 00:00:00 2001 From: Yogurt Wang Date: Thu, 28 Aug 2025 10:55:02 +0800 Subject: [PATCH] fix LIGER kernel compatibility and error handling Signed-off-by: Yogurt Wang --- JCBO.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/JCBO.py b/JCBO.py index e3b99a0..b5c8989 100644 --- a/JCBO.py +++ b/JCBO.py @@ -276,8 +276,12 @@ def _load_model_shared(model_hf_id: str, quantization_mode: str, target_device: if LIGER_KERNEL_AVAILABLE and enable_liger and "cuda" in str(model.device).lower(): # Check actual model device for LIGER try: - print(f"JoyCaptionBetaOne (Shared): Applying LIGER kernel to {model_hf_id} on {model.device}...") - apply_liger_kernel_to_llama(model=model.language_model) + print(f"JoyCaptionBetaOne (Shared): Applying LIGER kernel to {model_hf_id} on {model.device}...{type(model.language_model)}") + from transformers import LlamaForCausalLM + config = model.language_model.config + wrapped_model = LlamaForCausalLM(config) + wrapped_model.model = model.language_model + apply_liger_kernel_to_llama(model=wrapped_model) CACHED_LIGER_ENABLED = True except Exception as e: print(f"JoyCaptionBetaOne (Shared): LIGER kernel apply failed for {model_hf_id}: {e}"); CACHED_LIGER_ENABLED = False else: CACHED_LIGER_ENABLED = False @@ -386,7 +390,8 @@ def caption_image(self, image: torch.Tensor, caption_type: str, caption_length: with torch.cuda.amp.autocast(enabled=("cuda" in str(model_device).lower() and model.dtype != torch.float32)): generate_ids = model.generate(**inputs_on_device, max_new_tokens=max_new_tokens, do_sample=(temperature > 0), temperature=temperature if temperature > 0 else None, top_p=top_p if temperature > 0 else None, use_cache=True) except Exception as e: - print(f"{self.NODE_NAME}: Generation error: {e}") + import traceback + print(f"{self.NODE_NAME}: Generation error: {e}\n{traceback.format_exc()}") if "out of memory" in str(e).lower() and "cuda" in str(model_device).lower(): print(f"{self.NODE_NAME}: OOM error detected. Clearing model cache."); _free_model_memory_shared() return ([f"Error generating caption: {e}"],) @@ -515,7 +520,8 @@ def caption_image_simple(self, image: torch.Tensor, caption_type: str, caption_l with torch.cuda.amp.autocast(enabled=("cuda" in str(model_device).lower() and model.dtype != torch.float32)): generate_ids = model.generate(**inputs_on_device, max_new_tokens=max_new_tokens, do_sample=(temperature > 0), temperature=temperature if temperature > 0 else None, top_p=top_p if temperature > 0 else None, use_cache=True) except Exception as e: - print(f"{self.NODE_NAME}: Generation error: {e}") + import traceback + print(f"{self.NODE_NAME}: Generation error: {e}\n{traceback.format_exc()}") if "out of memory" in str(e).lower() and "cuda" in str(model_device).lower(): print(f"{self.NODE_NAME}: OOM error detected. Clearing model cache."); _free_model_memory_shared() return ([f"Error generating caption: {e}"],)