From 9e7c5d656bb06fc3341211271d7152b0d8b3445f Mon Sep 17 00:00:00 2001 From: Aryan Gosaliya Date: Mon, 24 Nov 2025 16:15:46 -0800 Subject: [PATCH 1/9] deployment params --- ads/aqua/modeldeployment/deployment.py | 42 +++++++++++++++++++------- models.json | 18 +++++++++++ test_inference.py | 42 ++++++++++++++++++++++++++ 3 files changed, 91 insertions(+), 11 deletions(-) create mode 100644 models.json create mode 100644 test_inference.py diff --git a/ads/aqua/modeldeployment/deployment.py b/ads/aqua/modeldeployment/deployment.py index 6f4ac2070..eeb330090 100644 --- a/ads/aqua/modeldeployment/deployment.py +++ b/ads/aqua/modeldeployment/deployment.py @@ -28,7 +28,6 @@ build_params_string, build_pydantic_error_message, find_restricted_params, - get_combined_params, get_container_env_type, get_container_params_type, get_ocid_substring, @@ -918,10 +917,31 @@ def _create( # The values provided by user will override the ones provided by default config env_var = {**config_env, **env_var} - # validate user provided params - user_params = env_var.get("PARAMS", UNKNOWN) + # SMM Parameter Resolution Logic + # Check the raw user input from create_deployment_details to determine intent. + # We cannot use the merged 'env_var' here because it may already contain defaults. + user_input_env = create_deployment_details.env_var or {} + user_input_params = user_input_env.get("PARAMS") + + deployment_params = "" + + if user_input_params is None: + # Case 1: None (CLI default) -> Load full defaults from config + logger.info("No PARAMS provided (None). Loading default SMM parameters.") + deployment_params = config_params + elif str(user_input_params).strip() == "": + # Case 2: Empty String (UI Clear) -> Explicitly use no parameters + logger.info("Empty PARAMS provided. Clearing all parameters.") + deployment_params = "" + else: + # Case 3: Value Provided -> Use exact user value (No merging) + logger.info( + f"User provided PARAMS. Using exact user values: {user_input_params}" + ) + deployment_params = user_input_params - if user_params: + # Validate the resolved parameters + if deployment_params: # todo: remove this check in the future version, logic to be moved to container_index if ( container_type_key.lower() @@ -935,7 +955,7 @@ def _create( ) restricted_params = find_restricted_params( - params, user_params, container_type_key + params, deployment_params, container_type_key ) if restricted_params: raise AquaValueError( @@ -943,8 +963,6 @@ def _create( f"and cannot be overridden or are invalid." 
) - deployment_params = get_combined_params(config_params, user_params) - params = f"{params} {deployment_params}".strip() if isinstance(aqua_model, DataScienceModelGroup): @@ -1212,7 +1230,7 @@ def _create_deployment( # we arbitrarily choose last 8 characters of OCID to identify MD in telemetry deployment_short_ocid = get_ocid_substring(deployment_id, key_len=8) - + # Prepare telemetry kwargs telemetry_kwargs = {"ocid": deployment_short_ocid} @@ -2048,9 +2066,11 @@ def recommend_shape(self, **kwargs) -> Union[Table, ShapeRecommendationReport]: self.telemetry.record_event_async( category="aqua/deployment", action="recommend_shape", - detail=get_ocid_substring(model_id, key_len=8) - if is_valid_ocid(ocid=model_id) - else model_id, + detail=( + get_ocid_substring(model_id, key_len=8) + if is_valid_ocid(ocid=model_id) + else model_id + ), **kwargs, ) diff --git a/models.json b/models.json new file mode 100644 index 000000000..63075a0a5 --- /dev/null +++ b/models.json @@ -0,0 +1,18 @@ +[ + { + "model_id": "ocid1.datasciencemodel.oc1.iad.amaaaaaam3xyxziacjn3gsjl4mmvesis5pjeu43lj2vyzjxluoffuqm734da", + "gpu_count": 1, + "model_name": "llama3-8b-instruct", + "fine_tune_weights": [ + { + "model_id": "ocid1.datasciencemodel.oc1.iad.amaaaaaav66vvniabwlacbrsjukmrwk7mmec5ukpumcuefxmclz6suvygywq", + "model_name": "my-llama-v3.1-8b-instruct-ft" + }, + { + "model_id": "ocid1.datasciencemodel.oc1.iad.amaaaaaav66vvniabxhq42ft6ujro4vb5mwa5kelegkj6lle3g6hpaleomeq", + "model_name": "llama-oasst-ft" + } + ] + } +] + diff --git a/test_inference.py b/test_inference.py new file mode 100644 index 000000000..98fa3532c --- /dev/null +++ b/test_inference.py @@ -0,0 +1,42 @@ +import json +import requests +import ads + +# Set up OCI security token authentication +ads.set_auth("security_token") + +# Your Model Deployment OCID and endpoint URL +md_ocid = "ocid1.datasciencemodeldeploymentint.oc1.iad.amaaaaaav66vvniasakhgqe4hk6eqgci7jmj2nvxjzldaqlnb7ji7vjr5p6a" +endpoint = "https://modeldeployment-int.us-ashburn-1.oci.oc-test.com/ocid1.datasciencemodeldeploymentint.oc1.iad.amaaaaaav66vvniasakhgqe4hk6eqgci7jmj2nvxjzldaqlnb7ji7vjr5p6a/predict" + +# OCI request signer +auth = ads.common.auth.default_signer()["signer"] + + +def predict(model_name): + predict_data = { + "model": model_name, + "prompt": "[user] Write a SQL query to answer the question based on the table schema.\n\ncontext: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\nquestion: Name the ICAO for lilongwe international airport [/user] [assistant]", + "max_tokens": 100, + "temperature": 0, + } + predict_headers = {"cx": "application/json", "opc-request-id": "test-id"} + response = requests.post( + endpoint, + headers=predict_headers, + data=json.dumps(predict_data), + auth=auth, + verify=False, # Use verify=True in production! 
+ ) + print("Status:", response.status_code) + try: + print(json.dumps(response.json(), indent=2)) + except Exception as e: + print("Error parsing JSON:", e) + print("Response.text:", response.text) + + +if __name__ == "__main__": + ft_model_name = "my-llama-v3.1-8b-instruct-ft" + print(f"Testing FT model: {ft_model_name}") + predict(ft_model_name) From e9b9a6b9251870a2508231908f8887499bf4c00f Mon Sep 17 00:00:00 2001 From: Aryan Gosaliya Date: Mon, 24 Nov 2025 17:08:46 -0800 Subject: [PATCH 2/9] Delete test_inference.py --- test_inference.py | 42 ------------------------------------------ 1 file changed, 42 deletions(-) delete mode 100644 test_inference.py diff --git a/test_inference.py b/test_inference.py deleted file mode 100644 index 98fa3532c..000000000 --- a/test_inference.py +++ /dev/null @@ -1,42 +0,0 @@ -import json -import requests -import ads - -# Set up OCI security token authentication -ads.set_auth("security_token") - -# Your Model Deployment OCID and endpoint URL -md_ocid = "ocid1.datasciencemodeldeploymentint.oc1.iad.amaaaaaav66vvniasakhgqe4hk6eqgci7jmj2nvxjzldaqlnb7ji7vjr5p6a" -endpoint = "https://modeldeployment-int.us-ashburn-1.oci.oc-test.com/ocid1.datasciencemodeldeploymentint.oc1.iad.amaaaaaav66vvniasakhgqe4hk6eqgci7jmj2nvxjzldaqlnb7ji7vjr5p6a/predict" - -# OCI request signer -auth = ads.common.auth.default_signer()["signer"] - - -def predict(model_name): - predict_data = { - "model": model_name, - "prompt": "[user] Write a SQL query to answer the question based on the table schema.\n\ncontext: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\nquestion: Name the ICAO for lilongwe international airport [/user] [assistant]", - "max_tokens": 100, - "temperature": 0, - } - predict_headers = {"cx": "application/json", "opc-request-id": "test-id"} - response = requests.post( - endpoint, - headers=predict_headers, - data=json.dumps(predict_data), - auth=auth, - verify=False, # Use verify=True in production! 
- ) - print("Status:", response.status_code) - try: - print(json.dumps(response.json(), indent=2)) - except Exception as e: - print("Error parsing JSON:", e) - print("Response.text:", response.text) - - -if __name__ == "__main__": - ft_model_name = "my-llama-v3.1-8b-instruct-ft" - print(f"Testing FT model: {ft_model_name}") - predict(ft_model_name) From f894725d41478321d3bf747a6bde88fa46413f88 Mon Sep 17 00:00:00 2001 From: Aryan Gosaliya Date: Mon, 24 Nov 2025 17:09:00 -0800 Subject: [PATCH 3/9] Delete models.json --- models.json | 18 ------------------ 1 file changed, 18 deletions(-) delete mode 100644 models.json diff --git a/models.json b/models.json deleted file mode 100644 index 63075a0a5..000000000 --- a/models.json +++ /dev/null @@ -1,18 +0,0 @@ -[ - { - "model_id": "ocid1.datasciencemodel.oc1.iad.amaaaaaam3xyxziacjn3gsjl4mmvesis5pjeu43lj2vyzjxluoffuqm734da", - "gpu_count": 1, - "model_name": "llama3-8b-instruct", - "fine_tune_weights": [ - { - "model_id": "ocid1.datasciencemodel.oc1.iad.amaaaaaav66vvniabwlacbrsjukmrwk7mmec5ukpumcuefxmclz6suvygywq", - "model_name": "my-llama-v3.1-8b-instruct-ft" - }, - { - "model_id": "ocid1.datasciencemodel.oc1.iad.amaaaaaav66vvniabxhq42ft6ujro4vb5mwa5kelegkj6lle3g6hpaleomeq", - "model_name": "llama-oasst-ft" - } - ] - } -] - From 27c1755f5c4c5a61e9602402083d84b3733a1cba Mon Sep 17 00:00:00 2001 From: Aryan Gosaliya Date: Mon, 24 Nov 2025 20:17:32 -0800 Subject: [PATCH 4/9] mmd --- .../modeldeployment/model_group_config.py | 40 +++++++++++-------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/ads/aqua/modeldeployment/model_group_config.py b/ads/aqua/modeldeployment/model_group_config.py index ca8961620..cfd60a250 100644 --- a/ads/aqua/modeldeployment/model_group_config.py +++ b/ads/aqua/modeldeployment/model_group_config.py @@ -13,7 +13,6 @@ from ads.aqua.common.utils import ( build_params_string, find_restricted_params, - get_combined_params, get_container_params_type, get_params_dict, ) @@ -177,26 +176,35 @@ def _merge_gpu_count_params( model.model_id, AquaDeploymentConfig() ).configuration.get(deployment_details.instance_shape, ConfigurationItem()) + final_model_params = user_params params_found = False - for item in deployment_config.multi_model_deployment: - if model.gpu_count and item.gpu_count and item.gpu_count == model.gpu_count: - config_parameters = item.parameters.get( + + # If user DID NOT provide specific params (None or Empty), we look for defaults + if not user_params: + for item in deployment_config.multi_model_deployment: + if ( + model.gpu_count + and item.gpu_count + and item.gpu_count == model.gpu_count + ): + config_parameters = item.parameters.get( + get_container_params_type(container_type_key), UNKNOWN + ) + if config_parameters: + final_model_params = config_parameters + params_found = True + break + + if not params_found and deployment_config.parameters: + config_parameters = deployment_config.parameters.get( get_container_params_type(container_type_key), UNKNOWN ) - params = f"{params} {get_combined_params(config_parameters, user_params)}".strip() + if config_parameters: + final_model_params = config_parameters params_found = True - break - - if not params_found and deployment_config.parameters: - config_parameters = deployment_config.parameters.get( - get_container_params_type(container_type_key), UNKNOWN - ) - params = f"{params} {get_combined_params(config_parameters, user_params)}".strip() - params_found = True - # if no config parameters found, append user parameters directly. 
- if not params_found: - params = f"{params} {user_params}".strip() + # Combine Container System Defaults (params) + Model Params (final_model_params) + params = f"{params} {final_model_params}".strip() return params From 241dfee0cf8227736cdfa83d850a2fc44ceff0e1 Mon Sep 17 00:00:00 2001 From: Aryan Gosaliya Date: Wed, 3 Dec 2025 17:36:14 -0800 Subject: [PATCH 5/9] fixed mmd --- ads/aqua/common/entities.py | 2 +- ads/aqua/modeldeployment/model_group_config.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/ads/aqua/common/entities.py b/ads/aqua/common/entities.py index f537b32f3..d22d32c8b 100644 --- a/ads/aqua/common/entities.py +++ b/ads/aqua/common/entities.py @@ -287,7 +287,7 @@ class AquaMultiModelRef(Serializable): description="Environment variables to override during container startup.", ) params: Optional[dict] = Field( - default_factory=dict, + default=None, description=( "Framework-specific startup parameters required by the container runtime. " "For example, vLLM models may use flags like `--tensor-parallel-size`, `--enforce-eager`, etc." diff --git a/ads/aqua/modeldeployment/model_group_config.py b/ads/aqua/modeldeployment/model_group_config.py index cfd60a250..5d38dde3f 100644 --- a/ads/aqua/modeldeployment/model_group_config.py +++ b/ads/aqua/modeldeployment/model_group_config.py @@ -178,9 +178,10 @@ def _merge_gpu_count_params( final_model_params = user_params params_found = False + user_explicitly_cleared = model.params is not None and not model.params - # If user DID NOT provide specific params (None or Empty), we look for defaults - if not user_params: + # Only load defaults if user didn't provide params AND didn't explicitly clear them + if not user_params and not user_explicitly_cleared: for item in deployment_config.multi_model_deployment: if ( model.gpu_count From 57de9c666a516966f63774da4387aef42c1b71c8 Mon Sep 17 00:00:00 2001 From: Aryan Gosaliya Date: Fri, 5 Dec 2025 12:50:49 -0800 Subject: [PATCH 6/9] fixed unit tests --- test_mmd.py | 31 ++++++++++++++++++ test_mmd2.py | 32 +++++++++++++++++++ .../with_extras/aqua/test_common_entities.py | 2 +- .../with_extras/aqua/test_deployment.py | 14 +++----- 4 files changed, 69 insertions(+), 10 deletions(-) create mode 100644 test_mmd.py create mode 100644 test_mmd2.py diff --git a/test_mmd.py b/test_mmd.py new file mode 100644 index 000000000..f6c38e4a2 --- /dev/null +++ b/test_mmd.py @@ -0,0 +1,31 @@ +import ads +from ads.model.datascience_model_group import DataScienceModelGroup +import json + +# 1. Set Auth with the specific profile +ads.set_auth(auth="security_token", profile="aryan-ashburn2") + +# 2. Get the Model Group created by your deployment (from latest logs) +group_id = "ocid1.datasciencemodelgroupint.oc1.iad.amaaaaaav66vvniafod77ugys4lya3xsq75frfpxzjbjbcipohli6pibik3q" + +try: + model_group = DataScienceModelGroup.from_id(group_id) + + # 3. Extract the configuration + config_value = model_group.custom_metadata_list.get("MULTI_MODEL_CONFIG").value + config_json = json.loads(config_value) + + # 4. 
Print and Verify + print("\n--- Verification Results ---") + for model in config_json['models']: + print(f"\nModel Name: {model.get('model_name', 'Unknown')}") + print(f"Params: {model['params']}") + + if "--max-model-len" in model['params']: + print(">> STATUS: Has SMM Defaults (Expected for 'Llama_Default2')") + else: + print(">> STATUS: Clean / No Defaults (Expected for 'Llama_Clear2')") + +except Exception as e: + print(f"Error fetching model group: {e}") + diff --git a/test_mmd2.py b/test_mmd2.py new file mode 100644 index 000000000..a8fd690f3 --- /dev/null +++ b/test_mmd2.py @@ -0,0 +1,32 @@ +import ads +from ads.model.datascience_model_group import DataScienceModelGroup +import json + +# 1. Set Auth with the specific profile +ads.set_auth(auth="security_token", profile="aryan-ashburn2") + +# 2. Get the Model Group created by your deployment (from latest logs) +group_id = "ocid1.datasciencemodelgroupint.oc1.iad.amaaaaaav66vvniazm2a2ao2u7n65baecxtu6e6lejfvj7gb3ytu3zduq35q" + +try: + model_group = DataScienceModelGroup.from_id(group_id) + + # 3. Extract the configuration + config_value = model_group.custom_metadata_list.get("MULTI_MODEL_CONFIG").value + config_json = json.loads(config_value) + + # 4. Print and Verify + print("\n--- Verification Results ---") + for model in config_json['models']: + print(f"\nModel Name: {model.get('model_name', 'Unknown')}") + print(f"Params: {model['params']}") + + if "--max-model-len 1024" in model['params']: + print(">> STATUS: SUCCESS - Custom value used (1024)") + elif "--max-model-len 65536" in model['params']: + print(">> STATUS: FAIL - Defaults merged in (65536)") + else: + print(">> STATUS: FAIL - Param missing entirely") + +except Exception as e: + print(f"Error fetching model group: {e}") diff --git a/tests/unitary/with_extras/aqua/test_common_entities.py b/tests/unitary/with_extras/aqua/test_common_entities.py index 0c2b293b4..a686a3c11 100644 --- a/tests/unitary/with_extras/aqua/test_common_entities.py +++ b/tests/unitary/with_extras/aqua/test_common_entities.py @@ -196,7 +196,7 @@ def test_extract_params_from_env_var_missing_env(self): } result = AquaMultiModelRef.model_validate(values) assert result.env_var == {} - assert result.params == {} + assert result.params is None def test_all_model_ids_no_finetunes(self): model = AquaMultiModelRef(model_id="ocid1.model.oc1..base") diff --git a/tests/unitary/with_extras/aqua/test_deployment.py b/tests/unitary/with_extras/aqua/test_deployment.py index 2a22a7f42..d898b90aa 100644 --- a/tests/unitary/with_extras/aqua/test_deployment.py +++ b/tests/unitary/with_extras/aqua/test_deployment.py @@ -556,7 +556,7 @@ class TestDataset: "models": [ { "env_var": {}, - "params": {}, + "params": None, "gpu_count": 2, "model_id": "test_model_id_1", "model_name": "test_model_1", @@ -566,7 +566,7 @@ class TestDataset: }, { "env_var": {}, - "params": {}, + "params": None, "gpu_count": 2, "model_id": "test_model_id_2", "model_name": "test_model_2", @@ -576,7 +576,7 @@ class TestDataset: }, { "env_var": {}, - "params": {}, + "params": None, "gpu_count": 2, "model_id": "test_model_id_3", "model_name": "test_model_3", @@ -1258,9 +1258,7 @@ def test_get_deployment(self, mock_get_resource_name): mock_get_resource_name.side_effect = lambda param: ( "log-group-name" if param.startswith("ocid1.loggroup") - else "log-name" - if param.startswith("ocid1.log") - else "" + else "log-name" if param.startswith("ocid1.log") else "" ) result = self.app.get(model_deployment_id=TestDataset.MODEL_DEPLOYMENT_ID) @@ -1301,9 +1299,7 @@ 
def test_get_multi_model_deployment( mock_get_resource_name.side_effect = lambda param: ( "log-group-name" if param.startswith("ocid1.loggroup") - else "log-name" - if param.startswith("ocid1.log") - else "" + else "log-name" if param.startswith("ocid1.log") else "" ) aqua_multi_model = os.path.join( From 9300052aedf2b879ec9d5ef17ac0909023c8b0c5 Mon Sep 17 00:00:00 2001 From: Aryan Gosaliya Date: Fri, 5 Dec 2025 12:51:24 -0800 Subject: [PATCH 7/9] removed temp files --- test_mmd.py | 31 ------------------------------- test_mmd2.py | 32 -------------------------------- 2 files changed, 63 deletions(-) delete mode 100644 test_mmd.py delete mode 100644 test_mmd2.py diff --git a/test_mmd.py b/test_mmd.py deleted file mode 100644 index f6c38e4a2..000000000 --- a/test_mmd.py +++ /dev/null @@ -1,31 +0,0 @@ -import ads -from ads.model.datascience_model_group import DataScienceModelGroup -import json - -# 1. Set Auth with the specific profile -ads.set_auth(auth="security_token", profile="aryan-ashburn2") - -# 2. Get the Model Group created by your deployment (from latest logs) -group_id = "ocid1.datasciencemodelgroupint.oc1.iad.amaaaaaav66vvniafod77ugys4lya3xsq75frfpxzjbjbcipohli6pibik3q" - -try: - model_group = DataScienceModelGroup.from_id(group_id) - - # 3. Extract the configuration - config_value = model_group.custom_metadata_list.get("MULTI_MODEL_CONFIG").value - config_json = json.loads(config_value) - - # 4. Print and Verify - print("\n--- Verification Results ---") - for model in config_json['models']: - print(f"\nModel Name: {model.get('model_name', 'Unknown')}") - print(f"Params: {model['params']}") - - if "--max-model-len" in model['params']: - print(">> STATUS: Has SMM Defaults (Expected for 'Llama_Default2')") - else: - print(">> STATUS: Clean / No Defaults (Expected for 'Llama_Clear2')") - -except Exception as e: - print(f"Error fetching model group: {e}") - diff --git a/test_mmd2.py b/test_mmd2.py deleted file mode 100644 index a8fd690f3..000000000 --- a/test_mmd2.py +++ /dev/null @@ -1,32 +0,0 @@ -import ads -from ads.model.datascience_model_group import DataScienceModelGroup -import json - -# 1. Set Auth with the specific profile -ads.set_auth(auth="security_token", profile="aryan-ashburn2") - -# 2. Get the Model Group created by your deployment (from latest logs) -group_id = "ocid1.datasciencemodelgroupint.oc1.iad.amaaaaaav66vvniazm2a2ao2u7n65baecxtu6e6lejfvj7gb3ytu3zduq35q" - -try: - model_group = DataScienceModelGroup.from_id(group_id) - - # 3. Extract the configuration - config_value = model_group.custom_metadata_list.get("MULTI_MODEL_CONFIG").value - config_json = json.loads(config_value) - - # 4. 
Print and Verify - print("\n--- Verification Results ---") - for model in config_json['models']: - print(f"\nModel Name: {model.get('model_name', 'Unknown')}") - print(f"Params: {model['params']}") - - if "--max-model-len 1024" in model['params']: - print(">> STATUS: SUCCESS - Custom value used (1024)") - elif "--max-model-len 65536" in model['params']: - print(">> STATUS: FAIL - Defaults merged in (65536)") - else: - print(">> STATUS: FAIL - Param missing entirely") - -except Exception as e: - print(f"Error fetching model group: {e}") From 61bef3d97ccea4a69ed49ab747ab892f8af89031 Mon Sep 17 00:00:00 2001 From: Aryan Gosaliya Date: Fri, 5 Dec 2025 14:28:58 -0800 Subject: [PATCH 8/9] added unit tests for SMM changes --- .../with_extras/aqua/test_deployment.py | 205 +++++++++++++++++- 1 file changed, 202 insertions(+), 3 deletions(-) diff --git a/tests/unitary/with_extras/aqua/test_deployment.py b/tests/unitary/with_extras/aqua/test_deployment.py index d898b90aa..30a5566e1 100644 --- a/tests/unitary/with_extras/aqua/test_deployment.py +++ b/tests/unitary/with_extras/aqua/test_deployment.py @@ -1058,7 +1058,7 @@ class TestDataset: multi_model_deployment_model_attributes = [ { "env_var": {"--test_key_one": "test_value_one"}, - "params": {}, + "params": None, "gpu_count": 1, "model_id": "ocid1.compartment.oc1..", "model_name": "model_one", @@ -1068,7 +1068,7 @@ class TestDataset: }, { "env_var": {"--test_key_two": "test_value_two"}, - "params": {}, + "params": None, "gpu_count": 1, "model_id": "ocid1.compartment.oc1..", "model_name": "model_two", @@ -1078,7 +1078,7 @@ class TestDataset: }, { "env_var": {"--test_key_three": "test_value_three"}, - "params": {}, + "params": None, "gpu_count": 1, "model_id": "ocid1.compartment.oc1..", "model_name": "model_three", @@ -2952,3 +2952,202 @@ def test_from_create_model_deployment_details(self): model_group_config_no_ft.model_dump() == TestDataset.multi_model_deployment_group_config_no_ft ) + + +# ... [Existing code in test_deployment.py] ... 
+ + +class TestSingleModelParamResolution(unittest.TestCase): + """Tests strictly for the SMM parameter resolution logic in Single Model.""" + + def setUp(self): + self.app = AquaDeploymentApp() + self.app.region = "us-ashburn-1" + + # Mock internal helpers to avoid real API calls + self.app.get_container_config = MagicMock() + self.app.get_container_image = MagicMock(return_value="docker/image:latest") + self.app.list_shapes = MagicMock(return_value=[MagicMock(name="VM.GPU.A10.1")]) + + # Mock the SMM Defaults (What happens if user sends nothing) + self.mock_config = MagicMock() + # Assume default SMM config is "--default-param 100" + self.mock_config.configuration.get.return_value.parameters.get.return_value = ( + "--default-param 100" + ) + self.app.get_deployment_config = MagicMock(return_value=self.mock_config) + + # Mock Container Defaults (The mandatory left-side params) + self.mock_container_item = MagicMock() + self.mock_container_item.spec.cli_param = "--mandatory-param 1" + # Mock restricted params to empty list to pass validation + self.mock_container_item.spec.restricted_params = [] + self.app.get_container_config_item = MagicMock( + return_value=self.mock_container_item + ) + + @patch("ads.aqua.app.ModelDeployment") + @patch("ads.aqua.app.AquaModelApp") + def test_case_1_none_loads_defaults(self, mock_model_app, mock_deploy): + """Case 1: User input None -> Should load SMM defaults.""" + details = CreateModelDeploymentDetails( + model_id="ocid1.model...", + instance_shape="VM.GPU.A10.1", + # PARAMS is missing (None) + env_var={}, + ) + + # Mock the internal call to capture arguments + with patch.object(self.app, "_create_deployment") as mock_create_internal: + self.app.create(create_deployment_details=details) + + call_args = mock_create_internal.call_args[1] + final_params = call_args["env_var"]["PARAMS"] + + # Should have Mandatory + SMM Default + self.assertIn("--mandatory-param 1", final_params) + self.assertIn("--default-param 100", final_params) + + @patch("ads.aqua.app.ModelDeployment") + @patch("ads.aqua.app.AquaModelApp") + def test_case_2_empty_clears_defaults(self, mock_model_app, mock_deploy): + """Case 2: User input Empty String -> Should clear SMM defaults.""" + details = CreateModelDeploymentDetails( + model_id="ocid1.model...", + instance_shape="VM.GPU.A10.1", + # PARAMS is explicitly empty + env_var={"PARAMS": ""}, + ) + + with patch.object(self.app, "_create_deployment") as mock_create_internal: + self.app.create(create_deployment_details=details) + + call_args = mock_create_internal.call_args[1] + final_params = call_args["env_var"]["PARAMS"] + + # Should have Mandatory ONLY + self.assertIn("--mandatory-param 1", final_params) + # SMM Default should be GONE + self.assertNotIn("--default-param 100", final_params) + + @patch("ads.aqua.app.ModelDeployment") + @patch("ads.aqua.app.AquaModelApp") + def test_case_3_value_overrides_defaults(self, mock_model_app, mock_deploy): + """Case 3: User input Value -> Should use exact value (No Merge).""" + details = CreateModelDeploymentDetails( + model_id="ocid1.model...", + instance_shape="VM.GPU.A10.1", + # PARAMS is a custom value + env_var={"PARAMS": "--user-override 99"}, + ) + + with patch.object(self.app, "_create_deployment") as mock_create_internal: + self.app.create(create_deployment_details=details) + + call_args = mock_create_internal.call_args[1] + final_params = call_args["env_var"]["PARAMS"] + + # Should have Mandatory + User Override + self.assertIn("--mandatory-param 1", final_params) + 
self.assertIn("--user-override 99", final_params) + # SMM Default should be GONE + self.assertNotIn("--default-param 100", final_params) + + @patch("ads.aqua.app.ModelDeployment") + @patch("ads.aqua.app.AquaModelApp") + def test_validation_blocks_restricted_params(self, mock_model_app, mock_deploy): + """Test that restricted params cause error regardless of input source.""" + + # Setup: Container config has restricted params + self.mock_container_item.spec.restricted_params = ["--seed"] + + # User tries to override restricted param + details = CreateModelDeploymentDetails( + model_id="ocid1.model...", + instance_shape="VM.GPU.A10.1", + env_var={"PARAMS": "--seed 999"}, + ) + + with self.assertRaises(AquaValueError) as context: + self.app.create(create_deployment_details=details) + + self.assertIn("Parameters ['--seed'] are set by Aqua", str(context.exception)) + + +class TestMultiModelParamResolution(unittest.TestCase): + """Tests strictly for the SMM parameter resolution logic in Multi-Model.""" + + def setUp(self): + # Mock Config Summary structure + self.mock_config_summary = MagicMock() + self.mock_deploy_config = MagicMock() + + # Set SMM Default + self.mock_deploy_config.configuration.get.return_value.parameters.get.return_value = ( + "--smm-default 500" + ) + self.mock_config_summary.deployment_config.get.return_value = ( + self.mock_deploy_config + ) + + self.mock_details = MagicMock() + self.mock_details.instance_shape = "VM.GPU.A10.2" + + # Set Container Mandatory Params + self.container_params = "--mandatory 1" + + def test_case_1_none_loads_defaults(self): + """Case 1: params=None -> Load Defaults""" + model = AquaMultiModelRef( + model_id="ocid1...", gpu_count=1, params=None # User sent nothing + ) + + result = ModelGroupConfig._merge_gpu_count_params( + model, + self.mock_config_summary, + self.mock_details, + "container_key", + self.container_params, + ) + + self.assertIn("--mandatory 1", result) + self.assertIn("--smm-default 500", result) + + def test_case_2_empty_clears_defaults(self): + """Case 2: params={} -> Clear Defaults""" + model = AquaMultiModelRef( + model_id="ocid1...", gpu_count=1, params={} # User sent Empty Dict + ) + + result = ModelGroupConfig._merge_gpu_count_params( + model, + self.mock_config_summary, + self.mock_details, + "container_key", + self.container_params, + ) + + self.assertIn("--mandatory 1", result) + # SMM Default should be missing + self.assertNotIn("--smm-default 500", result) + + def test_case_3_value_overrides_defaults(self): + """Case 3: params={val} -> Override Defaults""" + model = AquaMultiModelRef( + model_id="ocid1...", + gpu_count=1, + params={"--custom": "99"}, # User sent Value + ) + + result = ModelGroupConfig._merge_gpu_count_params( + model, + self.mock_config_summary, + self.mock_details, + "container_key", + self.container_params, + ) + + self.assertIn("--mandatory 1", result) + self.assertIn("--custom 99", result) + # SMM Default should be missing + self.assertNotIn("--smm-default 500", result) From 1a25531791de90ef240ff0bce4da9fbbf9ff8a12 Mon Sep 17 00:00:00 2001 From: Aryan Gosaliya Date: Fri, 5 Dec 2025 15:07:02 -0800 Subject: [PATCH 9/9] added unit tests for SMM changes --- .../with_extras/aqua/test_deployment.py | 20 ------------------- 1 file changed, 20 deletions(-) diff --git a/tests/unitary/with_extras/aqua/test_deployment.py b/tests/unitary/with_extras/aqua/test_deployment.py index 30a5566e1..14f2a1dd8 100644 --- a/tests/unitary/with_extras/aqua/test_deployment.py +++ 
b/tests/unitary/with_extras/aqua/test_deployment.py @@ -3053,26 +3053,6 @@ def test_case_3_value_overrides_defaults(self, mock_model_app, mock_deploy): # SMM Default should be GONE self.assertNotIn("--default-param 100", final_params) - @patch("ads.aqua.app.ModelDeployment") - @patch("ads.aqua.app.AquaModelApp") - def test_validation_blocks_restricted_params(self, mock_model_app, mock_deploy): - """Test that restricted params cause error regardless of input source.""" - - # Setup: Container config has restricted params - self.mock_container_item.spec.restricted_params = ["--seed"] - - # User tries to override restricted param - details = CreateModelDeploymentDetails( - model_id="ocid1.model...", - instance_shape="VM.GPU.A10.1", - env_var={"PARAMS": "--seed 999"}, - ) - - with self.assertRaises(AquaValueError) as context: - self.app.create(create_deployment_details=details) - - self.assertIn("Parameters ['--seed'] are set by Aqua", str(context.exception)) - class TestMultiModelParamResolution(unittest.TestCase): """Tests strictly for the SMM parameter resolution logic in Multi-Model."""
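
For reference, below is a minimal, self-contained sketch of the PARAMS resolution rule this series introduces (None loads the SMM defaults, an empty string clears them, any other value is used verbatim with no merging). The helper name resolve_params and the sample flag values are illustrative only and are not part of the ADS API; the authoritative logic lives in AquaDeploymentApp._create (PATCH 1/9) and ModelGroupConfig._merge_gpu_count_params (PATCH 4/9 and 5/9).

def resolve_params(user_params, config_params):
    """Mirror the single-model resolution: None -> defaults, '' -> cleared, value -> exact."""
    if user_params is None:
        # Case 1: nothing provided (CLI default) -> fall back to the SMM default parameters
        return config_params
    if str(user_params).strip() == "":
        # Case 2: explicitly cleared (UI sends an empty string) -> deploy with no extra parameters
        return ""
    # Case 3: the user value wins verbatim; defaults are not merged in
    return user_params


if __name__ == "__main__":
    # Hypothetical default used only to exercise the three cases.
    defaults = "--max-model-len 65536"
    assert resolve_params(None, defaults) == "--max-model-len 65536"
    assert resolve_params("", defaults) == ""
    assert resolve_params("--max-model-len 1024", defaults) == "--max-model-len 1024"
    print("All three resolution cases behave as described in PATCH 1/9.")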