From 6567ce57642545790e14354613ca7065b53b6125 Mon Sep 17 00:00:00 2001
From: xibei chen
Date: Thu, 25 Sep 2025 16:15:49 -0700
Subject: [PATCH] feature: add model_type hyperparameter support for Nova
 recipes

---
Note: model_type is only added to the hyperparameters when the recipe's run
section defines it, matching the guard used for model_name_or_path. The hunk
counts and diffstat below include that extra guard line.

 src/sagemaker/pytorch/estimator.py |  4 ++++
 tests/unit/test_pytorch_nova.py    | 37 ++++++++++++++++++++++++++++++
 2 files changed, 41 insertions(+)

diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py
index 9f41b5b2b9..9e2f0f0dd4 100644
--- a/src/sagemaker/pytorch/estimator.py
+++ b/src/sagemaker/pytorch/estimator.py
@@ -1180,6 +1180,10 @@ def _setup_for_nova_recipe(
         # Set up Nova-specific configuration
         run_config = recipe.get("run", {})
         model_name_or_path = run_config.get("model_name_or_path")
+        # Set the model_type hyperparameter only when the recipe provides one
+        model_type = run_config.get("model_type")
+        if model_type:
+            args["hyperparameters"]["model_type"] = model_type
 
         # Set hyperparameters based on model_name_or_path
         if model_name_or_path:
diff --git a/tests/unit/test_pytorch_nova.py b/tests/unit/test_pytorch_nova.py
index 46d526f22e..b8604c2ef2 100644
--- a/tests/unit/test_pytorch_nova.py
+++ b/tests/unit/test_pytorch_nova.py
@@ -795,3 +795,40 @@ def test_setup_for_nova_recipe_with_distillation(mock_resolve_save, sagemaker_se
             pytorch._hyperparameters.get("role_arn")
             == "arn:aws:iam::123456789012:role/SageMakerRole"
         )
+
+
+@patch("sagemaker.pytorch.estimator.PyTorch._recipe_resolve_and_save")
+def test_setup_for_nova_recipe_sets_model_type(mock_resolve_save, sagemaker_session):
+    """Test that _setup_for_nova_recipe correctly sets model_type hyperparameter."""
+    # Create a mock nova recipe with model_type
+    recipe = OmegaConf.create(
+        {
+            "run": {
+                "model_type": "amazon.nova.llama-2-7b",
+                "model_name_or_path": "llama/llama-2-7b",
+                "replicas": 1,
+            }
+        }
+    )
+
+    with patch(
+        "sagemaker.pytorch.estimator.PyTorch._recipe_load", return_value=("nova_recipe", recipe)
+    ):
+        mock_resolve_save.return_value = recipe
+
+        pytorch = PyTorch(
+            training_recipe="nova_recipe",
+            role=ROLE,
+            sagemaker_session=sagemaker_session,
+            instance_count=INSTANCE_COUNT,
+            instance_type=INSTANCE_TYPE_GPU,
+            image_uri=IMAGE_URI,
+            framework_version="1.13.1",
+            py_version="py3",
+        )
+
+        # Check that the Nova recipe was correctly identified
+        assert pytorch.is_nova_recipe is True
+
+        # Verify that model_type hyperparameter was set correctly
+        assert pytorch._hyperparameters.get("model_type") == "amazon.nova.llama-2-7b"