Commit 4a18b22

Succeeded in creating a simple server and interacting with it

Add auth endpoint
Working with a single server
Multi-server succeeds in a simple scenario
Works for the case where there are more batches than workers
Mimic server errors
Auto test, and store in cache
Add shutdown mechanism
Start working on CCC

Signed-off-by: Elad Venezian <eladv@il.ibm.com>
1 parent 3fd80f1 commit 4a18b22

5 files changed: +520, -41 lines

ccc_worker_server.py

Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
import logging
import os
import random
import sys
import threading
import time

import requests
from flask import Flask, jsonify, request
from unitxt.inference import HFPipelineBasedInferenceEngine

logging.basicConfig(level=logging.INFO)

app = Flask(__name__)
PORT = None


class Server:
    def __init__(self):
        self.inference_engine = None
        self.inactivity_timeout = 600
        self.monitor_thread = threading.Thread(target=self.monitor_activity, daemon=True)
        self.last_request_time = time.time()
        self.shutdown_flag = False
        self.monitor_thread.start()

    def update_last_request_time(self):
        self.last_request_time = time.time()

    def monitor_activity(self):
        while not self.shutdown_flag:
            time.sleep(5)
            if time.time() - self.last_request_time > self.inactivity_timeout:
                app.logger.info(f"No requests for {self.inactivity_timeout} seconds. Shutting down server...")
                try:
                    requests.post(f"http://localhost:{PORT}/shutdown", timeout=5)
                except Exception:
                    pass
            else:
                app.logger.info(
                    f"{int(self.inactivity_timeout - (time.time() - self.last_request_time))} seconds till shutdown...")

    def shutdown_server(self):
        self.shutdown_flag = True
        app.logger.info("Server shutting down...")
        shutdown_func = request.environ.get("werkzeug.server.shutdown")
        if shutdown_func:
            shutdown_func()
        # Allow the shutdown process to complete, then force exit the program
        time.sleep(1)
        os._exit(0)  # This immediately stops the program

    def init_server(self, **kwargs):
        kwargs["use_cache"] = True
        self.inference_engine = HFPipelineBasedInferenceEngine(**kwargs)

    def infer(self, inputs):
        return self.inference_engine(inputs)


server = Server()


@app.before_request
def update_activity():
    server.update_last_request_time()


@app.route("/shutdown", methods=["POST"])
def shutdown():
    app.logger.info("Received shutdown request")
    server.shutdown_server()
    return jsonify({"message": "Shutting down server..."}), 200


@app.route("/init_server", methods=["POST"])
def init_server():
    kwargs = request.get_json()
    server.init_server(**kwargs)
    return jsonify("Accepted")


@app.route("/<model>/v1/chat/completions", methods=["POST"])
@app.route("/<model_prefix>/<model>/v1/chat/completions", methods=["POST"])
def completions(model: str, model_prefix: str = "None"):
    # Error-injection switch: raise the threshold (e.g. to 0.1) to mimic a failing server.
    if random.random() < 0:
        logging.error("Bad luck! Returning 500 with an error message.")
        app.logger.info("Server shutting down...")
        shutdown_func = request.environ.get("werkzeug.server.shutdown")
        if shutdown_func:
            shutdown_func()
        # Allow the shutdown process to complete, then force exit the program
        time.sleep(1)
        os._exit(0)  # This immediately stops the program
        return jsonify({"error": "Bad luck, something went wrong!"}), 500

    body = request.get_json()
    # Validate that the request parameters match the model config. Print warnings if not.
    for k, v in body.items():
        if k == "messages":
            continue
        k = "model_name" if k == "model" else k
        attr = getattr(server.inference_engine, k, None)
        if attr is None:
            logging.warning(f"Warning: {k} is not an attribute in inference_engine")
        elif attr != v:
            logging.warning(f"Warning: {k} value in body ({v}) is different from value in inference engine ({attr})")
    texts = [{"source": m[0]["content"]} for m in body["messages"]]
    predictions = server.inference_engine(texts)
    return jsonify({
        "choices": [{"message": {"role": "assistant", "content": p}} for p in predictions],
    })


@app.route("/status", methods=["GET"])
def status():
    return "up", 200


if __name__ == "__main__":
    PORT = int(sys.argv[1])
    app.run(host="0.0.0.0", port=PORT, debug=True)
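For context, a minimal sketch of how a client could exercise this worker server once it is running (it assumes the server was started with "python ccc_worker_server.py 5000"; the model name and max_new_tokens value are illustrative, not taken from the commit):

import requests

BASE = "http://localhost:5000"  # assumes: python ccc_worker_server.py 5000

# Check the server is up.
assert requests.get(f"{BASE}/status").text == "up"

# Load the HF pipeline on the worker; kwargs are forwarded to HFPipelineBasedInferenceEngine.
# The model name and generation parameter below are illustrative values.
init = requests.post(f"{BASE}/init_server", json={
    "model_name": "google/flan-t5-small",
    "max_new_tokens": 32,
})
init.raise_for_status()

# The mock completions endpoint expects "messages" to be a list of message lists,
# and reads only the first message of each list (m[0]["content"]).
resp = requests.post(
    f"{BASE}/google/flan-t5-small/v1/chat/completions",
    json={"messages": [[{"role": "user", "content": "What is the capital of France?"}]]},
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])

The /shutdown endpoint can be posted to in the same way; otherwise the server shuts itself down after the inactivity timeout.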

src/unitxt/inference.py

Lines changed: 203 additions & 5 deletions
@@ -7,11 +7,14 @@
 import json
 import logging
 import os
+import random
 import re
 import sys
+import threading
 import time
 import uuid
 from collections import Counter
+from concurrent.futures import Future, ThreadPoolExecutor, wait
 from datetime import datetime
 from itertools import islice
 from multiprocessing.pool import ThreadPool
@@ -30,6 +33,7 @@
     Union,
 )

+import requests
 from datasets import Dataset, DatasetDict, Image
 from tqdm import tqdm, trange
 from tqdm.asyncio import tqdm_asyncio
@@ -276,7 +280,7 @@ def infer(
                     if prediction is None:
                         continue
                     cache_key = self._get_cache_key(item)
-                    self._cache[cache_key] = prediction
+                    self.store_in_cache(cache_key, prediction)
             else:
                 inferred_results = []
             # Combine cached and inferred results in original order
@@ -286,6 +290,9 @@ def infer(
                 result.extend(batch_predictions)
         else:
             result = self._infer(dataset, return_meta_data)
+
+        result = self.post_process_results(result)
+
         return ListWithMetadata(
             result,
             metadata={
@@ -295,6 +302,12 @@
             },
         )

+    def store_in_cache(self, cache_key, prediction):
+        self._cache[cache_key] = prediction
+
+    def post_process_results(self, result):
+        return result
+
     def _mock_infer(
         self,
         dataset: Union[List[Dict[str, Any]], Dataset],
@@ -1957,7 +1970,7 @@ def prepare_engine(self):
     @staticmethod
     def get_base_url_from_model_name(model_name: str):
         base_url_template = (
-            "https://inference-3scale-apicast-production.apps.rits.fmaas.res.ibm.com/{}"
+            "http://localhost:5000/{}"
         )
         return base_url_template.format(
             RITSInferenceEngine._get_model_name_for_endpoint(model_name)
@@ -3546,10 +3559,9 @@ def _infer(
         dataset: Union[List[Dict[str, Any]], Dataset],
         return_meta_data: bool = False,
     ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
-        if return_meta_data and not hasattr(self.engine, "get_return_object"):
+        if return_meta_data:
             raise NotImplementedError(
-                f"Inference engine {self.engine.__class__.__name__} does not support return_meta_data as it "
-                f"does not contain a 'get_return_object' method. Please set return_meta_data=False."
+                f"Inference engine {self.engine.__class__.__name__} does not support return_meta_data."
             )

         inputs = []
@@ -3576,3 +3588,189 @@ def _infer(
            predictions.append(options_scores.most_common(1)[0][0])

        return predictions


class MultiServersInferenceEngine(OpenAiInferenceEngine, HFGenerationParamsMixin):

    workers_url: List[str]

    def post_server(self, server_url, endpoint, data):
        headers = {"Content-Type": "application/json"}
        response = requests.post(url=f"{server_url}/{endpoint}", json=data, headers=headers)
        response.raise_for_status()
        return response.json()

    def prepare_engine(self):
        from openai import OpenAI

        self.lock = threading.Lock()
        self.workers_state = {}
        credentials = self._prepare_credentials()
        for url in self.workers_url:
            init_result = self.post_server(
                endpoint="init_server",
                server_url=url,
                data={**self.to_dict([HFGenerationParamsMixin]), "model_name": self.model_name},
            )
            if init_result == "Accepted":
                self.add_worker(url, client=OpenAI(
                    api_key=credentials["api_key"],
                    base_url=f"{url}/{self.model_name}/v1",
                    default_headers=self.get_default_headers(),
                ))

        # def init_server_and_add_to_workers_list

    def add_worker(self, url, client):
        with self.lock:
            self.workers_state[url] = {"status": "ready", "client": client}

    def release_worker(self, url):
        with self.lock:
            self.workers_state[url]["status"] = "ready"

    def assign_worker(self):
        # Poll until a worker is free. The lock is held only while scanning the state
        # dict, so release_worker() can mark workers ready between polls.
        while True:
            with self.lock:
                for url, rec in self.workers_state.items():
                    if rec["status"] == "ready":
                        rec["status"] = "assigned"
                        return url, rec["client"]
            time.sleep(random.uniform(0, 1))

    def _prepare_credentials(self) -> CredentialsOpenAi:
        return {"api_key": "no-api-key"}

    def _infer(
        self,
        dataset: Union[List[Dict[str, Any]], Dataset],
        return_meta_data: bool = False,
    ) -> List[Any]:  # Returns a list of Future objects
        """Runs inference in parallel, returning futures for each batch element."""
        # Lazy-initialize the executor if not already created
        if not hasattr(self, "_executor"):
            self._executor = ThreadPoolExecutor(max_workers=len(self.workers_state))

        # Submit the batch job
        batch_future = self._executor.submit(self._run_batch, dataset, return_meta_data)

        # Create individual futures that resolve when batch_future is done
        element_futures = [Future() for _ in dataset]

        def set_results(batch_fut: Future):
            """Callback to set individual results once the batch computation is done."""
            try:
                results = batch_fut.result()  # Get the batch results
                for i, res in enumerate(results):
                    element_futures[i].set_result(res)  # Set each individual future
            except Exception as e:
                for f in element_futures:
                    f.set_exception(e)  # Propagate any exception

        # Attach the callback to the batch future
        batch_future.add_done_callback(set_results)

        return element_futures  # Return a list of futures

    def _run_batch(self, batch, return_meta_data):
        """Helper function to process a batch inside a thread."""
        logger.info(f"Trying to get assigned: {self.workers_state}")
        url, client = self.assign_worker()
        logger.info(f"Thread {url} processing batch: {self.workers_state}")
        messages = [self.to_messages(instance) for instance in batch]
        logger.info(f"Sending batch to {url}")
        try:
            response = client.chat.completions.create(
                messages=messages,
                model=self.model_name,
                **self._get_completion_kwargs(),
            )
            logger.info(f"response: {response}")
            predictions = [r.message.content for r in response.choices]
            result = [self.get_return_object(p, response, return_meta_data) for p in predictions]
        finally:
            logger.info(f"Thread {url} release state:")
            self.release_worker(url)
            logger.info(f"Thread {url} release state done: {self.workers_state}")
        return result

    def post_process_results(self, result):
        futures = [r for r in result if isinstance(r, Future)]
        if futures:
            wait(futures)

        return [r.result() if isinstance(r, Future) else r for r in result]

    def store_in_cache(self, cache_key, prediction):
        if isinstance(prediction, Future):
            def store_after_unpack_in_cache(future, cache_key):
                prediction = future.result()
                if prediction is not None:
                    self._cache[cache_key] = prediction

            prediction.add_done_callback(lambda f, key=cache_key: store_after_unpack_in_cache(f, key))
        else:
            self._cache[cache_key] = prediction


class CCCInferenceEngine(MultiServersInferenceEngine):
    ccc_host: str
    ccc_user: str
    ccc_path: str
    ccc_python: str
    server_port: str = "5000"
    num_of_workers: int = 5
    workers_url: List[str] = []

    def prepare_engine(self):
        assert not self.workers_url, "CCCInferenceEngine doesn't support explicit setting of workers_url"
        self.start_ccc_servers()
        super().prepare_engine()  # connect to the workers started on CCC

    def start_ccc_servers(self):
        import paramiko

        ssh = paramiko.SSHClient()
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        ssh.connect(self.ccc_host, username=self.ccc_user)
        ssh.exec_command(f"mkdir -p {self.ccc_path}")
        self.ccc_jobs = {}
        for i in range(self.num_of_workers):
            command = f"bash -l -c 'jbsub -queue x86_6h -cores 4+1 -require v100 -mem 24G -out ~/server{i}.log {self.ccc_python} /dccstor/fuse/unitxt/ccc_worker_server.py {self.server_port}'"
            stdin, stdout, stderr = ssh.exec_command(command)
            job_output = stdout.read().decode().strip()
            job_error = stderr.read().decode().strip()
            match = re.search(r"Job <(\d+)> is submitted", job_output)
            if match:
                job_id = match.group(1)
                logger.info(f"Start job ID: {job_id}")
                self.ccc_jobs[job_id] = {"status": "AVAIL", "log_id": i}
            else:
                raise RuntimeError(f"Failed to run jbsub on host {self.ccc_host}.\nstdout: {job_output}.\nstderr: {job_error}")

        def run_monitor_ccc_jobs(ssh, sample_every):
            while True:
                command = "bash -l -c 'jbinfo'"
                stdin, stdout, stderr = ssh.exec_command(command)
                output = stdout.read().decode().strip()
                for job_id in self.ccc_jobs.keys():
                    match = re.search(rf"^{job_id}\s+\S+\s+(\w+)", output, re.MULTILINE)
                    if match:
                        status = match.group(1)
                        if status != self.ccc_jobs[job_id]["status"]:
                            if status == "RUN":
                                pass  # TODO: add the worker server to the workers list
                            elif self.ccc_jobs[job_id]["status"] == "RUN":
                                pass  # TODO: remove the worker from the list; consider fetching the server log
                            self.ccc_jobs[job_id]["status"] = status
                            logger.info(f"status has been changed: {job_id} - {status}")

                time.sleep(sample_every)

        thread = threading.Thread(target=run_monitor_ccc_jobs, args=(ssh, 10))
        thread.daemon = True
        thread.start()

        time.sleep(200)  # Keeps the main thread alive so the background thread can continue
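To tie the two files together, a rough usage sketch (the model name, generation parameter, and dataset contents below are illustrative; it assumes two worker servers were already launched with ccc_worker_server.py on ports 5000 and 5001):

from unitxt.inference import MultiServersInferenceEngine

# Assumes two ccc_worker_server.py instances are already listening on these URLs.
engine = MultiServersInferenceEngine(
    model_name="google/flan-t5-small",   # illustrative model
    max_new_tokens=32,                   # illustrative HF generation parameter
    workers_url=["http://localhost:5000", "http://localhost:5001"],
)

# _infer() hands each batch to a free worker and returns Futures;
# post_process_results() waits on them, so infer() still yields plain predictions.
dataset = [{"source": "What is the capital of France?"} for _ in range(8)]
predictions = engine.infer(dataset)
print(predictions)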
