Skip to content

Commit ba605ee

Browse files
committed
addressed comments, fixed rich diff table
1 parent fbfdb91 commit ba605ee

File tree

12 files changed

+253
-135
lines changed

12 files changed

+253
-135
lines changed

ads/aqua/common/entities.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,10 @@ class ComputeShapeSummary(Serializable):
101101
including CPU, memory, and optional GPU characteristics.
102102
"""
103103

104+
available: Optional[bool] = Field(
105+
default = False,
106+
description="True if shape is available on user tenancy, "
107+
)
104108
core_count: Optional[int] = Field(
105109
default=None,
106110
description="Total number of CPU cores available for the compute shape.",

ads/aqua/common/utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1287,7 +1287,6 @@ def load_gpu_shapes_index(
12871287

12881288
# Merge: remote shapes override local
12891289
local_shapes = local_data.get("shapes", {})
1290-
remote_data = {}
12911290
remote_shapes = remote_data.get("shapes", {})
12921291
merged_shapes = {**local_shapes, **remote_shapes}
12931292

ads/aqua/extension/recommend_handler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from ads.aqua.extension.base_handler import AquaAPIhandler
55
from ads.aqua.extension.errors import Errors
66
from ads.aqua.shaperecommend.recommend import AquaRecommendApp
7+
from ads.config import COMPARTMENT_OCID
78

89

910
class AquaRecommendHandler(AquaAPIhandler):

ads/aqua/shaperecommend/constants.py

Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,43 +14,56 @@
1414
1515
NEXT_QUANT suggests the next quantization level based on the current quantization (if applied) or the model weights (if no quantization yet)
1616
"""
17+
1718
LLAMA_REQUIRED_FIELDS = [
18-
"num_hidden_layers", "hidden_size", "num_attention_heads",
19-
"num_key_value_heads", "head_dim", "intermediate_size", "vocab_size"
19+
"num_hidden_layers",
20+
"hidden_size",
21+
"num_attention_heads",
22+
"num_key_value_heads",
23+
"head_dim",
24+
"intermediate_size",
25+
"vocab_size",
2026
]
2127

22-
MOE_REQUIRED_FIELDS = LLAMA_REQUIRED_FIELDS + [
23-
"num_local_experts", "intermediate_size"
24-
]
28+
MOE_REQUIRED_FIELDS = LLAMA_REQUIRED_FIELDS + ["num_local_experts", "intermediate_size"]
2529

2630
NEXT_QUANT = {
27-
"float32": ["8bit", "4bit"], # bits and bytes does not support bfloat16, pytorch responsibility
28-
"bfloat16": ["8bit", "4bit"],
29-
"float16": ["8bit", "4bit"],
31+
"float32": ["4bit"], # vLLM only supports 4bit in-flight-quantization
32+
"bfloat16": ["4bit"],
33+
"float16": ["4bit"],
3034
"int8": ["4bit"],
31-
"fp8": ["4bit"],
35+
"fp8": ["4bit"],
3236
"8bit": ["4bit"],
3337
"int4": ["No smaller quantization available"],
34-
"4bit": ["No smaller quantization available"]
38+
"4bit": ["No smaller quantization available"],
3539
}
3640

3741
TEXT_GENERATION = "text_generation"
3842
SAFETENSORS = "safetensors"
3943

44+
IN_FLIGHT_QUANTIZATION = {"4bit"}
45+
4046
TROUBLESHOOT_MSG = "The selected model is too large to fit on standard GPU shapes with the current configuration.\nAs troubleshooting, we have suggested the two largest available GPU shapes using the smallest quantization level ('4bit') to maximize chances of fitting the model. "
4147

48+
VLLM_PARAMS = {
49+
"max_model_len": "--max-model-len",
50+
"in_flight_quant": "--quantization bitsandbytes --load-format bitsandbytes",
51+
}
4252

43-
QUANT_MAPPING = {
44-
"float32": 4,
45-
"bfloat16": 2,
46-
"float16": 2,
47-
"fp16": 2,
48-
"half": 2,
49-
"int8": 1,
50-
"fp8": 1,
51-
"8bit": 1,
52-
"4bit": 0.5,
53-
"int4": 0.5,
54-
}
53+
DEFAULT_WEIGHT_SIZE = "float32"
5554

55+
BITS_AND_BYTES_8BIT = "8bit"
56+
BITS_AND_BYTES_4BIT = "4bit"
5657

58+
QUANT_MAPPING = {
59+
"float32": 4,
60+
"bfloat16": 2,
61+
"float16": 2,
62+
"fp16": 2,
63+
"half": 2,
64+
"int8": 1,
65+
"fp8": 1,
66+
"8bit": 1,
67+
"4bit": 0.5,
68+
"int4": 0.5,
69+
}

ads/aqua/shaperecommend/estimator.py

Lines changed: 45 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,12 @@
77

88
from ads.aqua.app import logger
99
from ads.aqua.shaperecommend.constants import (
10+
IN_FLIGHT_QUANTIZATION,
1011
LLAMA_REQUIRED_FIELDS,
1112
MOE_REQUIRED_FIELDS,
1213
NEXT_QUANT,
1314
QUANT_MAPPING,
15+
VLLM_PARAMS,
1416
)
1517
from ads.aqua.shaperecommend.llm_config import LLMConfig
1618

@@ -47,7 +49,7 @@ def kv_cache_memory(self) -> float:
4749
c = self.llm_config
4850
kv_cache_dtype_bytes = QUANT_MAPPING.get(
4951
c.weight_dtype, 2
50-
) # vLLM uses model's weight/quantization applied to KV cache
52+
) # vLLM applies the model's weight dtype to the KV cache
5153

5254
total_bytes = (
5355
self.batch_size
@@ -84,7 +86,9 @@ def total_memory(self) -> float:
8486
"""
8587
return self.model_memory + self.kv_cache_memory
8688

87-
def validate_shape(self, allowed_gpu_memory: float, gpu_utilization: float = 0.9) -> bool:
89+
def validate_shape(
90+
self, allowed_gpu_memory: float, gpu_utilization: float = 0.9
91+
) -> bool:
8892
"""
8993
Validates if a given model estimator fits within the allowed GPU memory budget, using a fixed utilization margin.
9094
@@ -102,6 +106,30 @@ def validate_shape(self, allowed_gpu_memory: float, gpu_utilization: float = 0.9
102106
"""
103107
return (allowed_gpu_memory * gpu_utilization) > self.total_memory
104108

109+
def construct_deployment_params(self) -> str:
    """
    Construct the deployment parameter string for the model.

    Assembles the runtime flags to be passed at model deployment:

    - Adds a max-model-length override when ``self.seq_len`` is shorter
      than the model's configured ``max_seq_len``.
    - Adds the in-flight quantization flags **only if the model is
      unquantized** and the requested in-flight quantization is one of
      ``IN_FLIGHT_QUANTIZATION``.

    Returns:
        str: Space-separated parameter string with no leading or trailing
        whitespace; an empty string when no override is needed.
    """
    c = self.llm_config
    params = []

    # max_seq_len is declared Optional on the config; with no known maximum
    # there is nothing to compare against, so no override is emitted.
    if c.max_seq_len is not None and self.seq_len < c.max_seq_len:
        params.append(f"{VLLM_PARAMS['max_model_len']} {self.seq_len}")

    # Only suggest in-flight quantization for unquantized models when such
    # quantization is requested.
    if not c.quantization and c.in_flight_quantization in IN_FLIGHT_QUANTIZATION:
        params.append(VLLM_PARAMS["in_flight_quant"])

    # Joining avoids the stray leading space that string concatenation
    # produced when only the quantization flags were emitted.
    return " ".join(params)
132+
105133
def suggest_param_advice(self, allowed: float) -> str:
106134
"""
107135
Suggests parameter modifications to help a model fit within GPU memory limits.
@@ -126,12 +154,12 @@ def suggest_param_advice(self, allowed: float) -> str:
126154
config = self.llm_config
127155

128156
suggested_quant_msg = None
129-
quant_advice = ", ".join(getattr(config, "suggested_quantizations", []))
157+
quant_advice = ", ".join(config.suggested_quantizations)
130158
quantization = getattr(config, "quantization", None)
131159

132160
advice = []
133161

134-
if getattr(config, "suggested_quantizations", []):
162+
if config.suggested_quantizations:
135163
to_do = f", which is smaller than the current {quantization if quantization in NEXT_QUANT else weight_size} format."
136164
if "No" in quant_advice:
137165
suggested_quant_msg = "No smaller quantized version exists. Use a model with fewer parameters."
@@ -142,37 +170,36 @@ def suggest_param_advice(self, allowed: float) -> str:
142170
)
143171
else:
144172
suggested_quant_msg = (
145-
f"Use a model with or apply in-flight {quant_advice} quantization" + to_do
173+
f"Either use a pre-quantized model at {quant_advice}, or apply in-flight {quant_advice} quantization"
174+
+ to_do
146175
)
147176

148-
kv_advice = [
149-
f"Reduce maximum context length (set --max-model-len < {seq_len})"
150-
]
177+
kv_advice = [f"Reduce maximum context length (set --max-model-len < {seq_len})"]
151178

152179
if batch_size != 1:
153180
kv_advice.append(f"Reduce batch size to less than {batch_size}.")
154181

155182
wt_advice = [
156183
"Use a model with fewer parameters.",
157-
f"{suggested_quant_msg}"
158-
if suggested_quant_msg
159-
else ""
184+
f"{suggested_quant_msg}" if suggested_quant_msg else "",
160185
]
161186

162187
if kv_gb > wt_gb and kv_gb > allowed * 0.5:
163-
main = "KV cache memory usage is the main limiting factor."
188+
main = "KV cache memory usage is the main limiting factor"
164189
advice = kv_advice
165190
elif wt_gb > kv_gb and wt_gb > allowed * 0.5:
166-
main = "Model weights are the main limiting factor."
191+
main = "Model weights are the main limiting factor"
167192
advice = wt_advice
168193
else:
169-
main = "Both model weights and KV cache are significant contributors to memory use."
194+
main = "Both model weights and KV cache are significant contributors to memory use"
170195
advice = kv_advice
171196
advice.extend(wt_advice)
172197

173198
advice_str = "\n".join(f"{i}. {item}" for i, item in enumerate(advice, 1))
174199

175-
return f"{advice_str}\n\n{main} (KV cache: {kv_gb:.1f}GB, Weights: {wt_gb:.1f}GB)."
200+
return (
201+
f"{advice_str}\n\n{main} (KV cache: {kv_gb:.1f}GB, Weights: {wt_gb:.1f}GB)."
202+
)
176203

177204
def limiting_factor(
178205
self, allowed_gpu_memory: float, warn_delta: float = 0.85
@@ -202,8 +229,7 @@ def limiting_factor(
202229
advice = (
203230
f"While the selected compute shape is estimated to work "
204231
f"({required:.1f}GB used / {allowed_gpu_memory:.1f}GB allowed), "
205-
f"the model configuration is close to the GPU memory limit. "
206-
"This estimation is theoretical; actual memory usage may vary at runtime.\n\n"
232+
f"the model configuration is close to the GPU memory limit.\n\n"
207233
"If you encounter issues with this shape, consider the following options to reduce memory usage:\n\n"
208234
f"{model_params.lstrip()}"
209235
)
@@ -216,7 +242,7 @@ def limiting_factor(
216242
)
217243
else:
218244
advice = (
219-
f"Model fits well within the allowed compute shape "
245+
f"No override PARAMS needed. \n\nModel fits well within the allowed compute shape "
220246
f"({required:.1f}GB used / {allowed_gpu_memory:.1f}GB allowed)."
221247
)
222248
return advice
@@ -252,6 +278,7 @@ def model_memory(self) -> float:
252278
layer_params = attn_params + mlp_params
253279
# Total params
254280
num_params = c.num_hidden_layers * layer_params + embedding_params
281+
255282
return num_params * c.bytes_per_parameter / 1e9
256283

257284
@property

ads/aqua/shaperecommend/llm_config.py

Lines changed: 40 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,13 @@
88
from pydantic import BaseModel, Field
99

1010
from ads.aqua.common.errors import AquaRecommendationError
11-
from ads.aqua.shaperecommend.constants import NEXT_QUANT, QUANT_MAPPING
11+
from ads.aqua.shaperecommend.constants import (
12+
BITS_AND_BYTES_4BIT,
13+
BITS_AND_BYTES_8BIT,
14+
DEFAULT_WEIGHT_SIZE,
15+
NEXT_QUANT,
16+
QUANT_MAPPING,
17+
)
1218

1319

1420
class LLMConfig(BaseModel):
@@ -35,10 +41,11 @@ class LLMConfig(BaseModel):
3541
description="Dimension of each attention head. Typically hidden_size // num_attention_heads.",
3642
)
3743
max_seq_len: Optional[int] = Field(
38-
8192, description="Maximum input sequence length (context window)."
44+
4096, description="Maximum input sequence length (context window)."
3945
)
4046
weight_dtype: Optional[str] = Field(
41-
"float32", description="Parameter data type: 'float32', 'float16', etc."
47+
DEFAULT_WEIGHT_SIZE,
48+
description="Parameter data type: 'float32', 'float16', etc.",
4249
)
4350
quantization: Optional[str] = Field(
4451
None,
@@ -49,6 +56,11 @@ class LLMConfig(BaseModel):
4956
description="Quantization method (e.g., '8bit', '4bit', 'gptq', 'awq') or None if unquantized.",
5057
)
5158

59+
in_flight_quantization: Optional[str] = Field(
60+
None,
61+
description="By setting this, enables recalculation of model footprint using 4bit in-flight quantization",
62+
)
63+
5264
num_key_value_heads: Optional[int] = Field(
5365
None,
5466
description="Number of key/value heads (for GQA architectures: Llama, Mistral, Falcon, Qwen, etc.). Used to determine KV cache size",
@@ -82,9 +94,13 @@ def bytes_per_parameter(self) -> float:
8294
bits = int(m[1])
8395
return bits / 8 # bytes per parameter
8496

97+
# consider in-flight quantization
98+
if self.in_flight_quantization in QUANT_MAPPING:
99+
return QUANT_MAPPING[self.in_flight_quantization]
100+
85101
# Fallback to dtype mapping
86-
dtype = (self.weight_dtype or "float32").lower()
87-
return QUANT_MAPPING.get(dtype, QUANT_MAPPING["float32"])
102+
dtype = (self.weight_dtype or DEFAULT_WEIGHT_SIZE).lower()
103+
return QUANT_MAPPING.get(dtype, QUANT_MAPPING[DEFAULT_WEIGHT_SIZE])
88104

89105
@classmethod
90106
def detect_quantization_type(cls, raw: dict) -> Optional[str]:
@@ -114,9 +130,9 @@ def detect_quantization_bits(cls, raw: dict) -> Optional[str]:
114130
Detects quantization bit-width as a string (e.g., '4bit', '8bit') from Hugging Face config dict.
115131
"""
116132
if raw.get("load_in_8bit"):
117-
return "8bit"
133+
return BITS_AND_BYTES_8BIT
118134
if raw.get("load_in_4bit"):
119-
return "4bit"
135+
return BITS_AND_BYTES_4BIT
120136
if "quantization_config" in raw:
121137
qcfg = raw["quantization_config"]
122138
bits = qcfg.get("bits") or qcfg.get("wbits")
@@ -132,7 +148,12 @@ def suggested_quantizations(self):
132148
If model is un-quantized, uses the weight size.
133149
If model is pre-quantized, uses the quantization level.
134150
"""
135-
key = (self.quantization or self.weight_dtype or "float32").lower()
151+
key = (
152+
self.quantization
153+
or self.in_flight_quantization
154+
or self.weight_dtype
155+
or DEFAULT_WEIGHT_SIZE
156+
).lower()
136157
return NEXT_QUANT.get(key, [])
137158

138159
def calculate_possible_seq_len(self, min_len=2048):
@@ -142,22 +163,21 @@ def calculate_possible_seq_len(self, min_len=2048):
142163
"""
143164
vals = []
144165
curr = min_len
145-
max_seq_len = 16384 if not self.max_seq_len else self.max_seq_len
146-
while curr <= max_seq_len:
166+
while curr <= self.max_seq_len:
147167
vals.append(curr)
148168
curr *= 2
149-
if vals and vals[-1] != max_seq_len:
150-
vals.append(max_seq_len)
169+
if vals and vals[-1] != self.max_seq_len:
170+
vals.append(self.max_seq_len)
151171
return vals
152172

153173
def optimal_config(self):
154174
"""
155175
Builds a list of optimal configuration parameters (sorted descending). Combination of:
156-
- Quantization / weight sizes: bfloat16 weight size -> 8bit -> 4bit
176+
- Quantization / weight sizes: bfloat16 weight size -> 4bit
157177
- max-model-len: power-of-two model lengths from max length (config.json of model) to 2048 tokens.
158178
159179
Example:
160-
[('bfloat16', max_model_len supported by model) ('bfloat16', 1/2 of max_model_len) ... ('int8', 2048), ('int4', 4096), ('int4', 2048)]
180+
[('bfloat16', max_model_len supported by model), ('bfloat16', 1/2 of max_model_len) ... ('int4', 4096), ('int4', 2048)]
161181
162182
"""
163183
# Create a copy of the suggested_quantizations list
@@ -183,9 +203,11 @@ def validate_model_support(cls, raw: dict) -> ValueError:
183203
"""
184204
excluded_models = {"t5", "gemma", "bart", "bert", "roberta", "albert"}
185205
if (
186-
raw.get("is_encoder_decoder", False) # exclude encoder-decoder models
187-
or (raw.get("is_decoder") is False) # exclude explicit encoder-only models (altho no text-generation task ones, just dbl check)
188-
or raw.get("model_type", "").lower() # exclude by known model types
206+
raw.get("is_encoder_decoder", False) # exclude encoder-decoder models
207+
or (
208+
raw.get("is_decoder") is False
209+
) # exclude explicit encoder-only models (although none should be text-generation models; this is just a double check)
210+
or raw.get("model_type", "").lower() # exclude by known model types
189211
in excluded_models
190212
):
191213
raise AquaRecommendationError(
@@ -207,7 +229,7 @@ def from_raw_config(cls, raw: dict) -> "LLMConfig":
207229
)
208230
hidden_size = raw.get("hidden_size") or raw.get("n_embd") or raw.get("d_model")
209231
vocab_size = raw.get("vocab_size")
210-
weight_dtype = str(raw.get("torch_dtype", "float32"))
232+
weight_dtype = str(raw.get("torch_dtype", DEFAULT_WEIGHT_SIZE))
211233
quantization = cls.detect_quantization_bits(raw)
212234
quantization_type = cls.detect_quantization_type(raw)
213235

0 commit comments

Comments
 (0)