diff --git a/JC2.py b/JC2.py index 269cb1f..10aedca 100644 --- a/JC2.py +++ b/JC2.py @@ -111,6 +111,8 @@ def load_models(model_path, dtype, device="cuda", max_memory=None): try: if dtype == "nf4": + assert torch.backends.mps.is_available() == False, "NF4 is not currently supported on MPS/Apple Silicon" + from transformers import BitsAndBytesConfig nf4_config = BitsAndBytesConfig( load_in_4bit=True, @@ -138,8 +140,8 @@ def load_models(model_path, dtype, device="cuda", max_memory=None): text_model = AutoModelForCausalLM.from_pretrained( model_path, quantization_config=nf4_config, - device_map=device if device == "cuda" else {"": device}, - torch_dtype=torch.bfloat16, + device_map=device if device == "mps" or device == "cuda" else {"": device}, + torch_dtype=torch.float16 if device == "mps" else torch.bfloat16, max_memory=max_memory # 添加 max_memory 参数 ).eval() @@ -148,7 +150,7 @@ def load_models(model_path, dtype, device="cuda", max_memory=None): text_model = PeftModel.from_pretrained( model=text_model, model_id=LORA_PATH, - device_map=device if device == "cuda" else {"": device}, + device_map=device if device == "mps" or device == "cuda" else {"": device}, quantization_config=nf4_config ) text_model = text_model.merge_and_unload( @@ -169,7 +171,12 @@ def load_models(model_path, dtype, device="cuda", max_memory=None): ) image_adapter.eval().to(device) else: # bf16 - print("Loading in bfloat16") + if torch.backends.mps.is_available(): + print("MPS Detected (Apple Silicon), fallback to fp16") + print("Loading in float16 (without AMP)") + else: + print("Loading in bfloat16") + print("Loading CLIP") clip_processor = AutoProcessor.from_pretrained(CLIP_PATH) clip_model = AutoModel.from_pretrained(CLIP_PATH).vision_model @@ -188,8 +195,8 @@ def load_models(model_path, dtype, device="cuda", max_memory=None): print(f"Loading LLM: {model_path}") text_model = AutoModelForCausalLM.from_pretrained( model_path, - device_map="auto", - torch_dtype=torch.bfloat16, + device_map=device if device == "mps" or device == "cuda" else {"": device}, + torch_dtype=torch.float16 if device == "mps" else torch.bfloat16, max_memory=max_memory # 添加 max_memory 参数 ).eval() @@ -198,7 +205,7 @@ def load_models(model_path, dtype, device="cuda", max_memory=None): text_model = PeftModel.from_pretrained( model=text_model, model_id=LORA_PATH, - device_map=device if device == "cuda" else {"": device} + device_map=device if device == "mps" or device == "cuda" else {"": device}, ) text_model = text_model.merge_and_unload( safe_merge=True @@ -329,11 +336,19 @@ def stream_chat(input_images: List[Image.Image], caption_type: str, caption_leng continue # Embed image - with torch.amp.autocast_mode.autocast(device, enabled=True): + if torch.backends.mps.is_available(): + model.clip_model.to(torch.float16) + model.image_adapter.to(torch.float16) + pixel_values = pixel_values.to(torch.float16) vision_outputs = model.clip_model(pixel_values=pixel_values, output_hidden_states=True) image_features = vision_outputs.hidden_states embedded_images = model.image_adapter(image_features).to(device) - + else: + with torch.amp.autocast_mode.autocast(device, enabled=True): + vision_outputs = model.clip_model(pixel_values=pixel_values, output_hidden_states=True) + image_features = vision_outputs.hidden_states + embedded_images = model.image_adapter(image_features).to(device) + # Build the conversation convo = [ { @@ -512,7 +527,7 @@ def joycaption2( llm_model_path_cache = os.path.join(comfy_model_dir, "cache--" + sanitized_model_name) # 初始设备设置为 'cuda' - device = 'cuda' if torch.cuda.is_available() else 'cpu' + device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu' model_loaded_on = device # 跟踪模型加载在哪个设备上 try: @@ -757,7 +772,7 @@ def joycaption2_simple( llm_model_path_cache = os.path.join(comfy_model_dir, "cache--" + sanitized_model_name) # 初始设备设置为 'cuda' - device = 'cuda' if torch.cuda.is_available() else 'cpu' + device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu' model_loaded_on = device # 跟踪模型加载在哪个设备上 try: