14 changes: 10 additions & 4 deletions JCBO.py
@@ -276,8 +276,12 @@ def _load_model_shared(model_hf_id: str, quantization_mode: str, target_device:
 
     if LIGER_KERNEL_AVAILABLE and enable_liger and "cuda" in str(model.device).lower(): # Check actual model device for LIGER
         try:
-            print(f"JoyCaptionBetaOne (Shared): Applying LIGER kernel to {model_hf_id} on {model.device}...")
-            apply_liger_kernel_to_llama(model=model.language_model)
+            print(f"JoyCaptionBetaOne (Shared): Applying LIGER kernel to {model_hf_id} on {model.device}...{type(model.language_model)}")
+            from transformers import LlamaForCausalLM
+            config = model.language_model.config
+            wrapped_model = LlamaForCausalLM(config)
+            wrapped_model.model = model.language_model
+            apply_liger_kernel_to_llama(model=wrapped_model)
             CACHED_LIGER_ENABLED = True
         except Exception as e: print(f"JoyCaptionBetaOne (Shared): LIGER kernel apply failed for {model_hf_id}: {e}"); CACHED_LIGER_ENABLED = False
     else: CACHED_LIGER_ENABLED = False
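The hunk above works around the fact that apply_liger_kernel_to_llama performs instance-level patching against a LlamaForCausalLM, while this model exposes its language tower as a bare LlamaModel under model.language_model. A minimal standalone sketch of the same wrapping trick, assuming that shape for language_model (the helper name is illustrative, not part of the PR):

from transformers import LlamaForCausalLM, LlamaModel
from liger_kernel.transformers import apply_liger_kernel_to_llama

def patch_bare_llama_with_liger(language_model: LlamaModel) -> None:
    # Illustrative helper, not part of the PR. apply_liger_kernel_to_llama
    # patches a LlamaForCausalLM instance, so wrap the bare decoder in a
    # throwaway causal-LM shell whose .model points at the loaded weights.
    shell = LlamaForCausalLM(language_model.config)
    # The shell's freshly initialized decoder is discarded here; only its
    # lm_head stays allocated, which is the memory cost of this workaround.
    shell.model = language_model
    # Instance-level patching walks shell.model, so the original modules
    # (still shared with the multimodal model) are modified in place.
    apply_liger_kernel_to_llama(model=shell)

The sketch mirrors what the diff does; whether the shell is strictly required may depend on the installed liger-kernel version.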
@@ -386,7 +390,8 @@ def caption_image(self, image: torch.Tensor, caption_type: str, caption_length:
             with torch.cuda.amp.autocast(enabled=("cuda" in str(model_device).lower() and model.dtype != torch.float32)):
                 generate_ids = model.generate(**inputs_on_device, max_new_tokens=max_new_tokens, do_sample=(temperature > 0), temperature=temperature if temperature > 0 else None, top_p=top_p if temperature > 0 else None, use_cache=True)
         except Exception as e:
-            print(f"{self.NODE_NAME}: Generation error: {e}")
+            import traceback
+            print(f"{self.NODE_NAME}: Generation error: {e}\n{traceback.format_exc()}")
             if "out of memory" in str(e).lower() and "cuda" in str(model_device).lower():
                 print(f"{self.NODE_NAME}: OOM error detected. Clearing model cache."); _free_model_memory_shared()
             return ([f"Error generating caption: {e}"],)
@@ -515,7 +520,8 @@ def caption_image_simple(self, image: torch.Tensor, caption_type: str, caption_l
             with torch.cuda.amp.autocast(enabled=("cuda" in str(model_device).lower() and model.dtype != torch.float32)):
                 generate_ids = model.generate(**inputs_on_device, max_new_tokens=max_new_tokens, do_sample=(temperature > 0), temperature=temperature if temperature > 0 else None, top_p=top_p if temperature > 0 else None, use_cache=True)
         except Exception as e:
-            print(f"{self.NODE_NAME}: Generation error: {e}")
+            import traceback
+            print(f"{self.NODE_NAME}: Generation error: {e}\n{traceback.format_exc()}")
             if "out of memory" in str(e).lower() and "cuda" in str(model_device).lower():
                 print(f"{self.NODE_NAME}: OOM error detected. Clearing model cache."); _free_model_memory_shared()
             return ([f"Error generating caption: {e}"],)