@@ -2835,24 +2835,7 @@ def __call__(
                     )
                 llama.eval(tokens)
             else:
-                image_bytes = self.load_image(value)
-                embed = self._embed_image_bytes(image_bytes, llama.context_params.n_threads_batch)
-                if llama.n_tokens + embed.contents.n_image_pos > llama.n_ctx():
-                    raise ValueError(
-                        f"Prompt exceeds n_ctx: {llama.n_tokens + embed.contents.n_image_pos} > {llama.n_ctx()}"
-                    )
-                n_past = ctypes.c_int(llama.n_tokens)
-                n_past_p = ctypes.pointer(n_past)
-                with suppress_stdout_stderr(disable=self.verbose):
-                    self._llava_cpp.llava_eval_image_embed(
-                        llama.ctx,
-                        embed,
-                        llama.n_batch,
-                        n_past_p,
-                    )
-                # Required to avoid issues with hf tokenizer
-                llama.input_ids[llama.n_tokens : n_past.value] = -1
-                llama.n_tokens = n_past.value
+                self.eval_image(llama, value)

         # Get prompt tokens to avoid a cache miss
         prompt = llama.input_ids[: llama.n_tokens].tolist()
@@ -2938,6 +2921,26 @@ def __call__(
             )
         return _convert_completion_to_chat(completion_or_chunks, stream=stream)

+    def eval_image(self, llama: llama.Llama, image_url: str):
+        image_bytes = self.load_image(image_url)
+        embed = self._embed_image_bytes(image_bytes, llama.context_params.n_threads_batch)
+        if llama.n_tokens + embed.contents.n_image_pos > llama.n_ctx():
+            raise ValueError(
+                f"Prompt exceeds n_ctx: {llama.n_tokens + embed.contents.n_image_pos} > {llama.n_ctx()}"
+            )
+        n_past = ctypes.c_int(llama.n_tokens)
+        n_past_p = ctypes.pointer(n_past)
+        with suppress_stdout_stderr(disable=self.verbose):
+            self._llava_cpp.llava_eval_image_embed(
+                llama.ctx,
+                embed,
+                llama.n_batch,
+                n_past_p,
+            )
+        # Required to avoid issues with hf tokenizer
+        llama.input_ids[llama.n_tokens : n_past.value] = -1
+        llama.n_tokens = n_past.value
+
     @staticmethod
     def _load_image(image_url: str) -> bytes:
         # TODO: Add Pillow support for other image formats beyond (jpg, png)
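The first two hunks are a pure refactor: the inline ctypes image-evaluation code in __call__ collapses into a single self.eval_image(llama, value) call, and the same body reappears unchanged as a new eval_image method on the handler. A minimal sketch of how the extracted helper composes with ordinary token evaluation, assuming the surrounding class is the LLaVA-style handler (Llava15ChatHandler) this __call__ belongs to, with hypothetical model paths and prompt text:

from llama_cpp import Llama
from llama_cpp.llama_chat_format import Llava15ChatHandler

# Hypothetical paths, for illustration only.
handler = Llava15ChatHandler(clip_model_path="mmproj-model-f16.gguf", verbose=False)
llm = Llama(model_path="llava-v1.5-7b.Q4_K_M.gguf", chat_handler=handler, n_ctx=4096)

llm.reset()
llm.eval(llm.tokenize(b"USER: ", add_bos=True, special=True))
# After the refactor, evaluating an image is one call instead of inline ctypes plumbing.
handler.eval_image(llm, "https://example.com/cat.png")
llm.eval(llm.tokenize(b"\nDescribe the image.\nASSISTANT:", add_bos=False, special=True))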
@@ -3435,10 +3438,10 @@ def split_text_on_image_urls(text: str, image_urls: List[str]):
             if pos != -1:
                 assert len(copied_urls) > 0
                 if pos > 0:
-                    split_text += [("text", remaining[:pos])]
-                split_text += [("text", "\n\n<start_of_image>")]
-                split_text += [("image_url", copied_urls.pop(0))]
-                split_text += [("text", "<end_of_image>\n\n")]
+                    split_text.append(("text", remaining[:pos]))
+                split_text.append(("text", "\n\n<start_of_image>"))
+                split_text.append(("image_url", copied_urls.pop(0)))
+                split_text.append(("text", "<end_of_image>\n\n"))
                 remaining = remaining[pos + len(image_placeholder):]
             else:
                 assert len(copied_urls) == 0
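This hunk only swaps list concatenation for append calls; the behaviour is unchanged. For reference, a self-contained sketch of what split_text_on_image_urls produces with the appended tuples, assuming the placeholder scanned for is the literal "<start_of_image>" string (the handler's actual placeholder constant is defined outside this hunk):

from typing import List, Tuple

def split_text_on_image_urls(
    text: str,
    image_urls: List[str],
    image_placeholder: str = "<start_of_image>",  # assumed value
) -> List[Tuple[str, str]]:
    # Interleave ("text", ...) and ("image_url", ...) chunks, wrapping each image
    # between <start_of_image> and <end_of_image> markers.
    split_text: List[Tuple[str, str]] = []
    copied_urls = image_urls[:]
    remaining = text
    while remaining:
        pos = remaining.find(image_placeholder)
        if pos != -1:
            assert len(copied_urls) > 0
            if pos > 0:
                split_text.append(("text", remaining[:pos]))
            split_text.append(("text", "\n\n<start_of_image>"))
            split_text.append(("image_url", copied_urls.pop(0)))
            split_text.append(("text", "<end_of_image>\n\n"))
            remaining = remaining[pos + len(image_placeholder):]
        else:
            assert len(copied_urls) == 0
            split_text.append(("text", remaining))
            remaining = ""
    return split_text

print(split_text_on_image_urls("Look: <start_of_image> what is this?",
                               ["https://example.com/cat.png"]))
# [('text', 'Look: '), ('text', '\n\n<start_of_image>'),
#  ('image_url', 'https://example.com/cat.png'),
#  ('text', '<end_of_image>\n\n'), ('text', ' what is this?')]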
@@ -3461,6 +3464,60 @@ def get_image_urls(messages: List[llama_types.ChatCompletionRequestMessage]):
                         image_urls.append(content["url"])
         return image_urls

+    def eval_image(self, llama: llama.Llama, image_url: str):
+        import llama_cpp
+
+        img_bytes = self.load_image(image_url)
+        img_u8_p = self._llava_cpp.clip_image_u8_init()
+        if not self._llava_cpp.clip_image_load_from_bytes(
+            ctypes.create_string_buffer(img_bytes, len(img_bytes)),
+            ctypes.c_size_t(len(img_bytes)),
+            img_u8_p,
+        ):
+            self._llava_cpp.clip_image_u8_free(img_u8_p)
+            raise ValueError("Failed to load image.")
+
+        img_f32 = self._llava_cpp.clip_image_f32_batch()
+        img_f32_p = ctypes.byref(img_f32)
+        if not self._llava_cpp.clip_image_preprocess(self.clip_ctx, img_u8_p, img_f32_p):
+            self._llava_cpp.clip_image_f32_batch_free(img_f32_p)
+            self._llava_cpp.clip_image_u8_free(img_u8_p)
+            raise ValueError("Failed to preprocess image.")
+
+        n_embd = llama_cpp.llama_model_n_embd(llama._model.model)
+        n_tokens = 256
+        embed = (ctypes.c_float * (n_tokens * n_embd))()
+        if not self._llava_cpp.clip_image_batch_encode(self.clip_ctx, llama.n_threads, img_f32_p, embed):
+            self._llava_cpp.clip_image_f32_batch_free(img_f32_p)
+            self._llava_cpp.clip_image_u8_free(img_u8_p)
+            raise ValueError("Failed to encode image.")
+
+        self._llava_cpp.clip_image_f32_batch_free(img_f32_p)
+        self._llava_cpp.clip_image_u8_free(img_u8_p)
+        llama_cpp.llama_set_causal_attn(llama.ctx, False)
+
+        seq_id_0 = (ctypes.c_int32 * 1)()
+        seq_ids = (ctypes.POINTER(ctypes.c_int32) * (n_tokens + 1))()
+        for i in range(n_tokens):
+            seq_ids[i] = seq_id_0
+
+        batch = llama_cpp.llama_batch()
+        batch.n_tokens = n_tokens
+        batch.token = None
+        batch.embd = embed
+        batch.pos = (ctypes.c_int32 * n_tokens)(*[i + llama.n_tokens for i in range(n_tokens)])
+        batch.seq_id = seq_ids
+        batch.n_seq_id = (ctypes.c_int32 * n_tokens)(*([1] * n_tokens))
+        batch.logits = (ctypes.c_int8 * n_tokens)()
+
+        if llama_cpp.llama_decode(llama.ctx, batch):
+            raise ValueError("Failed to decode image.")
+
+        llama_cpp.llama_set_causal_attn(llama.ctx, True)
+        # Required to avoid issues with hf tokenizer
+        llama.input_ids[llama.n_tokens : llama.n_tokens + n_tokens] = -1
+        llama.n_tokens += n_tokens
+

 @register_chat_completion_handler("chatml-function-calling")
 def chatml_function_calling(
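End to end, the eval_image added in the last hunk encodes the image through the CLIP context, temporarily disables causal attention, decodes the 256 embedding vectors via a single manually built llama_batch, then restores causal attention and advances llama.n_tokens. Nothing changes for callers of the public API; a minimal usage sketch with hypothetical file paths, shown with the existing Llava15ChatHandler as a stand-in since the handler added in this commit would be driven the same way:

from llama_cpp import Llama
from llama_cpp.llama_chat_format import Llava15ChatHandler

chat_handler = Llava15ChatHandler(clip_model_path="mmproj-model-f16.gguf")  # hypothetical path
llm = Llama(model_path="model.Q4_K_M.gguf", chat_handler=chat_handler, n_ctx=8192)  # hypothetical path

response = llm.create_chat_completion(
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
                {"type": "text", "text": "Describe this image in one sentence."},
            ],
        }
    ],
)
print(response["choices"][0]["message"]["content"])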