Skip to content

Commit 927de0e

Browse files
authored
Omni: Add benchmark for wan2.2 and flux.1-dev (#94)
Omni: Add benchmark for wan2.2 and flux.1-dev
1 parent 9e4d74c commit 927de0e

File tree

3 files changed

+510
-0
lines changed

3 files changed

+510
-0
lines changed
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
"""Latency benchmark for text-to-image models served by Xinference.

Launches each model through the Xinference client, then sends
image-generation requests via the OpenAI-compatible endpoint and
reports per-request latency statistics.
"""
import statistics
import sys
import time

import openai

from xinference.client import Client

# --- Service endpoints ---
BASE_URL = "http://localhost:9997"

# Xinference control-plane client, used to launch/terminate models.
x_client = Client(BASE_URL)

# --- Benchmark configuration ---
PROMPT = "A Shiba Inu chasing butterflies on a sunny grassy field, cartoon style, with vibrant colors."
IMAGE_SIZE = "1024x1024"
NUM_IMAGES = 1  # images generated per request

NUM_WARMUP = 3  # untimed requests to let the service initialize
NUM_TRIALS = 5  # timed requests used for the reported statistics

# --- Initialize OpenAI Client ---
# The local service ignores the key, but the client constructor requires one.
API_KEY = "not-really-a-secret-key"
try:
    OPENAI_URL = BASE_URL + "/v1"
    client = openai.Client(api_key=API_KEY, base_url=OPENAI_URL)
    print(f"Initialized OpenAI client for base URL: {OPENAI_URL}")
except Exception as e:
    print(f"Error initializing OpenAI client: {e}")
    print("Please ensure your 'openai' package is up to date and base_url is correct.")
    sys.exit(1)
def _generate_once(model_uid):
    """Send one image-generation request and return its latency in seconds.

    Propagates any exception raised by the OpenAI client so callers can
    decide how to handle warm-up vs. trial failures.
    """
    start_time = time.perf_counter()
    client.images.generate(
        model=model_uid,
        prompt=PROMPT,
        size=IMAGE_SIZE,
        n=NUM_IMAGES,
    )
    return time.perf_counter() - start_time


def benchmark_t2i_model(model_uid, model_path=None, relaunch=True):
    """Benchmark a single text-to-image model and print latency statistics.

    Runs NUM_WARMUP untimed warm-up requests (exits the process on
    failure), then NUM_TRIALS timed requests (stops early on failure),
    and prints min/avg/stdev statistics for the successful trials.

    Args:
        model_uid: Model UID (also used as the model name when launching).
        model_path: Optional local path forwarded to ``launch_model``.
        relaunch: If True, launch the model first and terminate it at the end.

    Returns:
        Average latency in seconds over successful trials, or ``None``
        when no trial completed successfully.
    """
    if relaunch:
        model_uid = x_client.launch_model(
            model_uid=model_uid,
            model_name=model_uid,
            model_engine="transformers",
            model_type="image",
            model_path=model_path,
        )

    # --- Warm-up requests (not timed for results) ---
    print(f"\n--- Starting Warm-up ({NUM_WARMUP} requests) ---")
    print("These requests are to let the service initialize and are not timed for results.")
    for i in range(NUM_WARMUP):
        print(f"Warm-up request {i+1}/{NUM_WARMUP}...")
        try:
            duration = _generate_once(model_uid)
            print(f" Warm-up request successful. Latency: {duration:.4f} seconds.")
        except openai.APIError as e:
            # A failing warm-up means the service/model is unusable; abort.
            print(f" Warm-up request failed with API error: {e}")
            print(
                " Please check if the local service is running and configured correctly."
            )
            print(
                f" Model '{model_uid}' might not be available or parameters like '{IMAGE_SIZE}'/'steps' are unsupported."
            )
            sys.exit(1)
        except Exception as e:
            print(f" Warm-up request failed with a general error: {e}")
            print(
                " Perhaps the local service is not reachable or there's a network issue."
            )
            sys.exit(1)

    # --- Benchmark Trials ---
    print(f"\n--- Starting Benchmark Trials ({NUM_TRIALS} requests) ---")
    latencies = []
    for i in range(NUM_TRIALS):
        print(f"Trial request {i+1}/{NUM_TRIALS}...")
        try:
            duration = _generate_once(model_uid)
            latencies.append(duration)
            print(f" Trial successful. Latency: {duration:.4f} seconds.")
        except openai.APIError as e:
            print(f" Trial request failed with API error: {e}")
            print(" Encountered an error during trials. Stopping benchmark.")
            break  # Stop if an error occurs to avoid skewing results
        except Exception as e:
            print(f" Trial request failed with a general error: {e}")
            print(" Encountered an error during trials. Stopping benchmark.")
            break

    # --- Results ---
    print("\n--- Benchmark Results ---")
    # Defined up front so the return below is safe even with zero
    # successful trials (previously this raised NameError).
    avg_latency = None
    if not latencies:
        print("No successful trials completed to report statistics.")
    else:
        print(f"Total successful trials: {len(latencies)}")
        print(f"Configuration:")
        print(f" Model: {model_uid}")
        print(f" Prompt: '{PROMPT}'")
        print(f" Image Size: {IMAGE_SIZE}")
        print(f" Number of Images per request: {NUM_IMAGES}")

        print(f" Warm-up requests: {NUM_WARMUP}")
        print(f" Trial requests: {NUM_TRIALS}")

        min_latency = min(latencies)
        max_latency = max(latencies)
        avg_latency = statistics.mean(latencies)

        # Standard deviation requires at least 2 data points
        if len(latencies) > 1:
            std_dev_latency = statistics.stdev(latencies)
            # print(f"Minimum Latency: {min_latency:.4f} seconds")
            # print(f"Maximum Latency: {max_latency:.4f} seconds")
            print(f"Average Latency: {avg_latency:.4f} seconds")
            print(f"Std Deviation: {std_dev_latency:.4f} seconds")
        else:
            print(
                f"Latency: {avg_latency:.4f} seconds (only one trial completed)"
            )

    if relaunch:
        x_client.terminate_model(model_uid=model_uid)
    print(f"\nBenchmark for {model_uid} finished.")

    return avg_latency
def main():
    """Benchmark each configured model in sequence."""
    # (model_uid, local model path) pairs to benchmark.
    models = [
        ("sd3.5-medium", "/llm/models/stable-diffusion-3.5-medium/"),
        ("FLUX.1-dev", "/llm/models/FLUX.1-dev/"),
        ("HunyuanDiT-v1.2", "/llm/models/HunyuanDiT-v1.2-Diffusers/"),
    ]
    for model_uid, model_path in models:
        benchmark_t2i_model(model_uid=model_uid, model_path=model_path)


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)