@@ -7,11 +7,14 @@
 import json
 import logging
 import os
+import random
 import re
 import sys
+import threading
 import time
 import uuid
 from collections import Counter
+from concurrent.futures import Future, ThreadPoolExecutor, wait
 from datetime import datetime
 from itertools import islice
 from multiprocessing.pool import ThreadPool
@@ -30,6 +33,7 @@
     Union,
 )

+import requests
 from datasets import Dataset, DatasetDict, Image
 from tqdm import tqdm, trange
 from tqdm.asyncio import tqdm_asyncio
@@ -276,7 +280,7 @@ def infer(
                             if prediction is None:
                                 continue
                             cache_key = self._get_cache_key(item)
-                            self._cache[cache_key] = prediction
+                            self.store_in_cache(cache_key, prediction)
                     else:
                         inferred_results = []
                     # Combine cached and inferred results in original order
@@ -286,6 +290,9 @@ def infer(
                     result.extend(batch_predictions)
             else:
                 result = self._infer(dataset, return_meta_data)
+
+        result = self.post_process_results(result)
+
         return ListWithMetadata(
             result,
             metadata={
@@ -295,6 +302,12 @@ def infer(
             },
         )

+    def store_in_cache(self, cache_key, prediction):
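+        """Stores a prediction in the cache; a hook that subclasses may override."""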
+        self._cache[cache_key] = prediction
+
+    def post_process_results(self, result):
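+        """Hook for transforming raw results before they are returned; identity by default."""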
+        return result
+
     def _mock_infer(
         self,
         dataset: Union[List[Dict[str, Any]], Dataset],
@@ -1957,7 +1970,7 @@ def prepare_engine(self):
     @staticmethod
     def get_base_url_from_model_name(model_name: str):
         base_url_template = (
-            "https://inference-3scale-apicast-production.apps.rits.fmaas.res.ibm.com/{}"
+            "http://localhost:5000/{}"
         )
         return base_url_template.format(
             RITSInferenceEngine._get_model_name_for_endpoint(model_name)
@@ -3546,10 +3559,9 @@ def _infer(
         dataset: Union[List[Dict[str, Any]], Dataset],
         return_meta_data: bool = False,
     ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
-        if return_meta_data and not hasattr(self.engine, "get_return_object"):
+        if return_meta_data:
             raise NotImplementedError(
-                f"Inference engine {self.engine.__class__.__name__} does not support return_meta_data as it "
-                f"does not contain a 'get_return_object' method. Please set return_meta_data=False."
+                f"Inference engine {self.engine.__class__.__name__} does not support return_meta_data."
             )

         inputs = []
@@ -3576,3 +3588,189 @@ def _infer(
             predictions.append(options_scores.most_common(1)[0][0])

         return predictions
+
+
+class MultiServersInferenceEngine(OpenAiInferenceEngine, HFGenerationParamsMixin):
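+    """Fans inference out across a pool of OpenAI-compatible worker servers.
+
+    Workers are initialized over HTTP and tracked in a lock-guarded state
+    table; each batch is routed to a free worker, and results come back as
+    futures so batches can overlap.
+    """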
+
+    workers_url: List[str]
+
+    def post_server(self, server_url, endpoint, data):
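+        """POSTs JSON data to an endpoint on a worker server and returns the decoded reply."""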
+        headers = {"Content-Type": "application/json"}
+        response = requests.post(url=f"{server_url}/{endpoint}", json=data, headers=headers)
+        response.raise_for_status()
+        return response.json()
+
+    def prepare_engine(self):
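+        """Initializes each worker server over HTTP and registers an OpenAI client for it."""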
+        from openai import OpenAI
+
+        self.lock = threading.Lock()
+        self.workers_state = {}
+        credentials = self._prepare_credentials()
+        for url in self.workers_url:
+            init_result = self.post_server(
+                server_url=url,
+                endpoint="init_server",
+                data={**self.to_dict([HFGenerationParamsMixin]), "model_name": self.model_name},
+            )
+            if init_result == "Accepted":
+                self.add_worker(url, client=OpenAI(
+                    api_key=credentials["api_key"],
+                    base_url=f"{url}/{self.model_name}/v1",
+                    default_headers=self.get_default_headers(),
+                ))
+
+    def add_worker(self, url, client):
+        with self.lock:
+            self.workers_state[url] = {"status": "ready", "client": client}
+
+    def release_worker(self, url):
+        with self.lock:
+            self.workers_state[url]["status"] = "ready"
+
+    def assign_worker(self):
+        # Poll until a worker is free. The lock is taken only while scanning the
+        # state table; holding it across the sleep would prevent release_worker()
+        # from ever running and deadlock the pool.
+        while True:
+            with self.lock:
+                for url, rec in self.workers_state.items():
+                    if rec["status"] == "ready":
+                        rec["status"] = "assigned"
+                        return url, rec["client"]
+            time.sleep(random.uniform(0, 1))
+
+    def _prepare_credentials(self) -> CredentialsOpenAi:
+        # The workers do not check credentials, but the OpenAI client requires a key
+        return {"api_key": "no-api-key"}
+
+    def _infer(
+        self,
+        dataset: Union[List[Dict[str, Any]], Dataset],
+        return_meta_data: bool = False,
+    ) -> List[Any]:  # a list holding one Future per dataset element
+        """Runs inference in parallel, returning a future for each element."""
+        # Lazy-initialize the executor, one thread per worker
+        if not hasattr(self, "_executor"):
+            self._executor = ThreadPoolExecutor(max_workers=len(self.workers_state))
+
+        # Submit the whole dataset as a single batch job
+        batch_future = self._executor.submit(self._run_batch, dataset, return_meta_data)
+
+        # Create individual futures that resolve when batch_future is done
+        element_futures = [Future() for _ in dataset]
+
+        def set_results(batch_fut: Future):
+            """Callback to set individual results once the batch computation is done."""
+            try:
+                results = batch_fut.result()
+                for i, res in enumerate(results):
+                    element_futures[i].set_result(res)
+            except Exception as e:
+                # Propagate any exception to every element future
+                for f in element_futures:
+                    f.set_exception(e)
+
+        batch_future.add_done_callback(set_results)
+
+        return element_futures
+
+    def _run_batch(self, batch, return_meta_data):
+        """Helper that processes one batch on a free worker inside a thread."""
+        url, client = self.assign_worker()
+        logger.info(f"Worker {url} assigned; state: {self.workers_state}")
+        # One conversation per instance; the worker server is assumed to accept
+        # a batch of conversations in a single request
+        messages = [self.to_messages(instance) for instance in batch]
+        try:
+            response = client.chat.completions.create(
+                messages=messages,
+                model=self.model_name,
+                **self._get_completion_kwargs(),
+            )
+            predictions = [r.message.content for r in response.choices]
+            result = [self.get_return_object(p, response, return_meta_data) for p in predictions]
+        finally:
+            self.release_worker(url)
+            logger.info(f"Worker {url} released; state: {self.workers_state}")
+        return result
+
+    def post_process_results(self, result):
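+        """Waits for any pending futures in the results and unwraps them to plain values."""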
+        futures = [r for r in result if isinstance(r, Future)]
+        if futures:
+            wait(futures)
+        return [r.result() if isinstance(r, Future) else r for r in result]
+
+    def store_in_cache(self, cache_key, prediction):
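+        """Caches the prediction, deferring storage until it resolves if it is a Future."""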
+        if isinstance(prediction, Future):
+            def store_when_done(future, key):
+                value = future.result()
+                if value is not None:
+                    self._cache[key] = value
+
+            prediction.add_done_callback(lambda f, key=cache_key: store_when_done(f, key))
+        else:
+            self._cache[cache_key] = prediction
+
+
+class CCCInferenceEngine(MultiServersInferenceEngine):
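+    """Launches the worker servers as jbsub batch jobs on a CCC cluster over
+    SSH, then reuses the MultiServersInferenceEngine machinery to drive them."""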
+
+    ccc_host: str
+    ccc_user: str
+    ccc_path: str
+    ccc_python: str
+    server_port: str = "5000"
+    num_of_workers: int = 5
+    workers_url: List[str] = []
+
+    def prepare_engine(self):
+        assert not self.workers_url, "CCCInferenceEngine doesn't support explicit setting of workers_url"
+        self.start_ccc_servers()
+        super().prepare_engine()  # initialize the workers started above
+
+    def start_ccc_servers(self):
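+        """Submits one worker-server job per requested worker over SSH and
+        starts a daemon thread that tracks job status via jbinfo."""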
+        import paramiko
+
+        ssh = paramiko.SSHClient()
+        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+        ssh.connect(self.ccc_host, username=self.ccc_user)
+        ssh.exec_command(f"mkdir -p {self.ccc_path}")
+        self.ccc_jobs = {}
+        for i in range(self.num_of_workers):
+            command = f"bash -l -c 'jbsub -queue x86_6h -cores 4+1 -require v100 -mem 24G -out ~/server{i}.log {self.ccc_python} /dccstor/fuse/unitxt/ccc_worker_server.py {self.server_port}'"
+            stdin, stdout, stderr = ssh.exec_command(command)
+            job_output = stdout.read().decode().strip()
+            job_error = stderr.read().decode().strip()
+            match = re.search(r"Job <(\d+)> is submitted", job_output)
+            if match:
+                job_id = match.group(1)
+                logger.info(f"Started job ID: {job_id}")
+                self.ccc_jobs[job_id] = {"status": "AVAIL", "log_id": i}
+            else:
+                raise RuntimeError(
+                    f"Failed to run jbsub on host {self.ccc_host}.\nstdout: {job_output}.\nstderr: {job_error}"
+                )
+
+        def run_monitor_ccc_jobs(ssh, sample_every):
+            # Poll jbinfo and track each job's status transitions
+            while True:
+                command = "bash -l -c 'jbinfo'"
+                stdin, stdout, stderr = ssh.exec_command(command)
+                output = stdout.read().decode().strip()
+                for job_id in self.ccc_jobs.keys():
+                    match = re.search(rf"^{job_id}\s+\S+\s+(\w+)", output, re.MULTILINE)
+                    if match:
+                        status = match.group(1)
+                        if status != self.ccc_jobs[job_id]["status"]:
+                            if self.ccc_jobs[job_id]["status"] == "RUN":
+                                pass  # server went down: remove it from the workers list; consider fetching its log
+                            elif status == "RUN":
+                                pass  # server came up: add it to the workers list
+                            self.ccc_jobs[job_id]["status"] = status
+                            logger.info(f"Status changed: {job_id} - {status}")
+                time.sleep(sample_every)
+
+        thread = threading.Thread(target=run_monitor_ccc_jobs, args=(ssh, 10))
+        thread.daemon = True  # don't let the monitor thread block interpreter exit
+        thread.start()
+
+        time.sleep(200)  # crude wait for the CCC jobs to start and the worker servers to come up
+