JC2.py: 37 changes (26 additions, 11 deletions)
@@ -111,6 +111,8 @@ def load_models(model_path, dtype, device="cuda", max_memory=None):

     try:
         if dtype == "nf4":
+            assert torch.backends.mps.is_available() == False, "NF4 is not currently supported on MPS/Apple Silicon"
+
             from transformers import BitsAndBytesConfig
             nf4_config = BitsAndBytesConfig(
                 load_in_4bit=True,
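Note: the new assert fails fast because the NF4 path depends on bitsandbytes, which did not ship an MPS backend at the time of this change. A minimal sketch of the same guard written as an explicit exception; the helper name is hypothetical and not part of this PR:

import torch

def ensure_nf4_supported() -> None:
    # bitsandbytes NF4 kernels target CUDA; there is no MPS backend,
    # so refuse early instead of failing deep inside model loading.
    if torch.backends.mps.is_available() and not torch.cuda.is_available():
        raise ValueError("NF4 is not currently supported on MPS/Apple Silicon; use the bf16 path instead.")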
@@ -138,8 +140,8 @@ def load_models(model_path, dtype, device="cuda", max_memory=None):
             text_model = AutoModelForCausalLM.from_pretrained(
                 model_path,
                 quantization_config=nf4_config,
-                device_map=device if device == "cuda" else {"": device},
-                torch_dtype=torch.bfloat16,
+                device_map=device if device == "mps" or device == "cuda" else {"": device},
+                torch_dtype=torch.float16 if device == "mps" else torch.bfloat16,
                 max_memory=max_memory  # added max_memory parameter
             ).eval()

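Both changed keyword arguments key off the runtime device: the PR falls back to float16 on MPS, presumably because bfloat16 kernel support there was incomplete on the PyTorch versions this targets, while CUDA keeps bfloat16. A sketch of the same selection factored into helpers; the function names are illustrative, not part of the PR:

import torch

def pick_dtype(device: str) -> torch.dtype:
    # fp16 on MPS (bf16 support is incomplete there); bf16 everywhere else.
    return torch.float16 if device == "mps" else torch.bfloat16

def pick_device_map(device: str):
    # transformers accepts a plain device string for "cuda"/"mps";
    # anything else is wrapped in the {"": device} dict form.
    return device if device in ("cuda", "mps") else {"": device}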
@@ -148,7 +150,7 @@ def load_models(model_path, dtype, device="cuda", max_memory=None):
             text_model = PeftModel.from_pretrained(
                 model=text_model,
                 model_id=LORA_PATH,
-                device_map=device if device == "cuda" else {"": device},
+                device_map=device if device == "mps" or device == "cuda" else {"": device},
                 quantization_config=nf4_config
             )
             text_model = text_model.merge_and_unload(
@@ -169,7 +171,12 @@ def load_models(model_path, dtype, device="cuda", max_memory=None):
             )
             image_adapter.eval().to(device)
         else: # bf16
-            print("Loading in bfloat16")
+            if torch.backends.mps.is_available():
+                print("MPS Detected (Apple Silicon), fallback to fp16")
+                print("Loading in float16 (without AMP)")
+            else:
+                print("Loading in bfloat16")
+
             print("Loading CLIP")
             clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
             clip_model = AutoModel.from_pretrained(CLIP_PATH).vision_model
@@ -188,8 +195,8 @@ def load_models(model_path, dtype, device="cuda", max_memory=None):
             print(f"Loading LLM: {model_path}")
             text_model = AutoModelForCausalLM.from_pretrained(
                 model_path,
-                device_map="auto",
-                torch_dtype=torch.bfloat16,
+                device_map=device if device == "mps" or device == "cuda" else {"": device},
+                torch_dtype=torch.float16 if device == "mps" else torch.bfloat16,
                 max_memory=max_memory  # added max_memory parameter
             ).eval()

@@ -198,7 +205,7 @@ def load_models(model_path, dtype, device="cuda", max_memory=None):
             text_model = PeftModel.from_pretrained(
                 model=text_model,
                 model_id=LORA_PATH,
-                device_map=device if device == "cuda" else {"": device}
+                device_map=device if device == "mps" or device == "cuda" else {"": device},
             )
             text_model = text_model.merge_and_unload(
                 safe_merge=True
@@ -329,11 +336,19 @@ def stream_chat(input_images: List[Image.Image], caption_type: str, caption_length
                 continue

             # Embed image
-            with torch.amp.autocast_mode.autocast(device, enabled=True):
-                vision_outputs = model.clip_model(pixel_values=pixel_values, output_hidden_states=True)
-                image_features = vision_outputs.hidden_states
-                embedded_images = model.image_adapter(image_features).to(device)
+            if torch.backends.mps.is_available():
+                model.clip_model.to(torch.float16)
+                model.image_adapter.to(torch.float16)
+                pixel_values = pixel_values.to(torch.float16)
+                vision_outputs = model.clip_model(pixel_values=pixel_values, output_hidden_states=True)
+                image_features = vision_outputs.hidden_states
+                embedded_images = model.image_adapter(image_features).to(device)
+
+            else:
+                with torch.amp.autocast_mode.autocast(device, enabled=True):
+                    vision_outputs = model.clip_model(pixel_values=pixel_values, output_hidden_states=True)
+                    image_features = vision_outputs.hidden_states
+                    embedded_images = model.image_adapter(image_features).to(device)

             # Build the conversation
             convo = [
                 {
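This hunk works around AMP on Apple Silicon: autocast for the "mps" device type was not reliably available on the PyTorch versions this code supports, so the MPS branch casts the CLIP tower, the adapter, and the input tensor to float16 by hand instead of wrapping the forward pass in autocast. A standalone sketch of the same pattern, with encoder and adapter as placeholder modules rather than the PR's actual objects:

import torch

def embed_images(encoder: torch.nn.Module, adapter: torch.nn.Module,
                 pixel_values: torch.Tensor, device: str) -> torch.Tensor:
    # On MPS, skip AMP and cast modules plus inputs to fp16 explicitly.
    if device == "mps":
        encoder.to(torch.float16)
        adapter.to(torch.float16)
        pixel_values = pixel_values.to(torch.float16)
        return adapter(encoder(pixel_values)).to(device)
    # On CUDA/CPU, let autocast pick the mixed-precision dtypes.
    with torch.amp.autocast(device, enabled=True):
        return adapter(encoder(pixel_values)).to(device)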
@@ -512,7 +527,7 @@ def joycaption2(
     llm_model_path_cache = os.path.join(comfy_model_dir, "cache--" + sanitized_model_name)

     # Initially set the device to 'cuda'
-    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
     model_loaded_on = device  # track which device the model was loaded on

     try:
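Both node entry points now pick the device with the same chained conditional. The precedence (CUDA first, then MPS, then CPU) reads more clearly factored out; a small sketch assuming only torch, with a hypothetical helper name:

import torch

def pick_device() -> str:
    # Prefer CUDA, then Apple-Silicon MPS, then plain CPU.
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

device = pick_device() would then replace the inline expression in both joycaption2 and joycaption2_simple.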
@@ -757,7 +772,7 @@ def joycaption2_simple(
     llm_model_path_cache = os.path.join(comfy_model_dir, "cache--" + sanitized_model_name)

     # Initially set the device to 'cuda'
-    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
     model_loaded_on = device  # track which device the model was loaded on

     try: