Skip to content

Commit 2091850

Browse files
authored
feat: Support Placeholders with ModelStep (#175)
1 parent 19761b0 commit 2091850

File tree

3 files changed

+108
-20
lines changed

3 files changed

+108
-20
lines changed

src/stepfunctions/steps/sagemaker.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
7676
mini_batch_size (int): Specify this argument only when estimator is a built-in estimator of an Amazon algorithm. For other estimators, batch size should be specified in the estimator.
7777
experiment_config (dict or Placeholder, optional): Specify the experiment config for the training. (Default: None)
7878
wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the training job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the training job and proceed to the next step. (default: True)
79-
tags (list[dict] or Placeholder, optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
79+
tags (list[dict] or Placeholder, optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
8080
output_data_config_path (str or Placeholder, optional): S3 location for saving the training result (model
8181
artifacts and output files). If specified, it overrides the `output_path` property of `estimator`.
8282
parameters (dict, optional): The value of this field is merged with other arguments to become the request payload for SageMaker `CreateTrainingJob <https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTrainingJob.html>`_. (Default: None)
@@ -220,7 +220,7 @@ def __init__(self, state_id, transformer, job_name, model_name, data, data_type=
220220
split_type (str or Placeholder): The record delimiter for the input object (default: 'None'). Valid values: 'None', 'Line', 'RecordIO', and 'TFRecord'.
221221
experiment_config (dict or Placeholder, optional): Specify the experiment config for the transform. (Default: None)
222222
wait_for_completion(bool, optional): Boolean value set to `True` if the Task state should wait for the transform job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the transform job and proceed to the next step. (default: True)
223-
tags (list[dict] or Placeholder, optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
223+
tags (list[dict] or Placeholder, optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
224224
input_filter (str or Placeholder): A JSONPath to select a portion of the input to pass to the algorithm container for inference. If you omit the field, it gets the value ‘$’, representing the entire input. For CSV data, each row is taken as a JSON array, so only index-based JSONPaths can be applied, e.g. $[0], $[1:]. CSV data should follow the RFC format. See Supported JSONPath Operators for a table of supported JSONPath operators. For more information, see the SageMaker API documentation for CreateTransformJob. Some examples: “$[1:]”, “$.features” (default: None).
225225
output_filter (str or Placeholder): A JSONPath to select a portion of the joined/original output to return as the output. For more information, see the SageMaker API documentation for CreateTransformJob. Some examples: “$[1:]”, “$.prediction” (default: None).
226226
join_source (str or Placeholder): The source of data to be joined to the transform output. It can be set to ‘Input’ meaning the entire input record will be joined to the inference result. You can use OutputFilter to select the useful portion before uploading to S3. (default: None). Valid values: Input, None.
@@ -302,14 +302,16 @@ def __init__(self, state_id, model, model_name=None, instance_type=None, tags=No
302302
model (sagemaker.model.Model): The SageMaker model to use in the ModelStep. If :py:class:`TrainingStep` was used to train the model and saving the model is the next step in the workflow, the output of :py:func:`TrainingStep.get_expected_model()` can be passed here.
303303
model_name (str or Placeholder, optional): Specify a model name; this is required for creating the model. We recommend using the :py:class:`~stepfunctions.inputs.ExecutionInput` placeholder collection to pass the value dynamically in each execution.
304304
instance_type (str, optional): The EC2 instance type to deploy this Model to. For example, 'ml.p2.xlarge'.
305-
tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
305+
tags (list[dict] or Placeholder, optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
306+
parameters (dict, optional): The value of this field is merged with other arguments to become the request payload for SageMaker `CreateModel <https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateModel.html>`_. (Default: None)
307+
You can use `parameters` to override the value provided by other arguments and specify any field's value dynamically using `Placeholders <https://aws-step-functions-data-science-sdk.readthedocs.io/en/stable/placeholders.html?highlight=placeholder#stepfunctions.inputs.Placeholder>`_.
306308
"""
307309
if isinstance(model, FrameworkModel):
308-
parameters = model_config(model=model, instance_type=instance_type, role=model.role, image_uri=model.image_uri)
310+
model_parameters = model_config(model=model, instance_type=instance_type, role=model.role, image_uri=model.image_uri)
309311
if model_name:
310-
parameters['ModelName'] = model_name
312+
model_parameters['ModelName'] = model_name
311313
elif isinstance(model, Model):
312-
parameters = {
314+
model_parameters = {
313315
'ExecutionRoleArn': model.role,
314316
'ModelName': model_name or model.name,
315317
'PrimaryContainer': {
@@ -321,13 +323,17 @@ def __init__(self, state_id, model, model_name=None, instance_type=None, tags=No
321323
else:
322324
raise ValueError("Expected 'model' parameter to be of type 'sagemaker.model.Model', but received type '{}'".format(type(model).__name__))
323325

324-
if 'S3Operations' in parameters:
325-
del parameters['S3Operations']
326+
if 'S3Operations' in model_parameters:
327+
del model_parameters['S3Operations']
326328

327329
if tags:
328-
parameters['Tags'] = tags_dict_to_kv_list(tags)
330+
model_parameters['Tags'] = tags if isinstance(tags, Placeholder) else tags_dict_to_kv_list(tags)
329331

330-
kwargs[Field.Parameters.value] = parameters
332+
if Field.Parameters.value in kwargs and isinstance(kwargs[Field.Parameters.value], dict):
333+
# Update model parameters with input parameters
334+
merge_dicts(model_parameters, kwargs[Field.Parameters.value])
335+
336+
kwargs[Field.Parameters.value] = model_parameters
331337

332338
"""
333339
Example resource arn: arn:aws:states:::sagemaker:createModel
@@ -357,7 +363,7 @@ def __init__(self, state_id, endpoint_config_name, model_name, initial_instance_
357363
data_capture_config (sagemaker.model_monitor.DataCaptureConfig, optional): Specifies
358364
configuration related to Endpoint data capture for use with
359365
Amazon SageMaker Model Monitoring. Default: None.
360-
tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
366+
tags (list[dict], optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
361367
"""
362368
parameters = {
363369
'EndpointConfigName': endpoint_config_name,
@@ -399,9 +405,8 @@ def __init__(self, state_id, endpoint_name, endpoint_config_name, tags=None, upd
399405
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
400406
endpoint_name (str or Placeholder): The name of the endpoint to create. We recommend using the :py:class:`~stepfunctions.inputs.ExecutionInput` placeholder collection to pass the value dynamically in each execution.
401407
endpoint_config_name (str or Placeholder): The name of the endpoint configuration to use for the endpoint. We recommend using the :py:class:`~stepfunctions.inputs.ExecutionInput` placeholder collection to pass the value dynamically in each execution.
402-
tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
403408
update (bool, optional): Boolean flag set to `True` if the endpoint must be updated. Set to `False` if a new endpoint must be created. (default: False)
404-
tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
409+
tags (list[dict], optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
405410
"""
406411

407412
parameters = {
@@ -460,7 +465,7 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta
460465
:class:`sagemaker.amazon.amazon_estimator.RecordSet` objects,
461466
where each instance is a different channel of training data.
462467
wait_for_completion(bool, optional): Boolean value set to `True` if the Task state should wait for the tuning job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the tuning job and proceed to the next step. (default: True)
463-
tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
468+
tags (list[dict], optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
464469
"""
465470
if wait_for_completion:
466471
"""
@@ -522,7 +527,7 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp
522527
ARN of a KMS key, alias of a KMS key, or alias of a KMS key.
523528
The KmsKeyId is applied to all outputs.
524529
wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the processing job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the processing job and proceed to the next step. (default: True)
525-
tags (list[dict] or Placeholder, optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
530+
tags (list[dict] or Placeholder, optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
526531
parameters (dict, optional): The value of this field is merged with other arguments to become the request payload for SageMaker `CreateProcessingJob <https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateProcessingJob.html>`_.
527532
You can use `parameters` to override the value provided by other arguments and specify any field's value dynamically using `Placeholders <https://aws-step-functions-data-science-sdk.readthedocs.io/en/stable/placeholders.html?highlight=placeholder#stepfunctions.inputs.Placeholder>`_.
528533

tests/integ/test_sagemaker_steps.py

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,59 @@ def test_model_step(trained_estimator, sfn_client, sagemaker_session, sfn_role_a
195195
delete_sagemaker_model(model_name, sagemaker_session)
196196
# End of Cleanup
197197

198+
199+
def test_model_step_with_placeholders(trained_estimator, sfn_client, sagemaker_session, sfn_role_arn):
    """Run a one-step ModelStep workflow whose name, container mode and tags are placeholders.

    The model name, the primary container ``Mode`` and the ``Tags`` are all
    taken from the execution input, so their concrete values are only
    supplied when the workflow is executed. The test passes when the
    execution output carries a ``ModelArn`` and an HTTP 200 status, after
    which the state machine and the created model are deleted.
    """
    # Build workflow definition
    input_schema = {
        'ModelName': str,
        'Mode': str,
        'Tags': list
    }
    execution_input = ExecutionInput(schema=input_schema)

    # Fields listed here are handed to ModelStep as `parameters`.
    placeholder_parameters = {
        'PrimaryContainer': {
            'Mode': execution_input['Mode']
        },
        'Tags': execution_input['Tags']
    }

    model_step = ModelStep(
        'create_model_step',
        model=trained_estimator.create_model(),
        model_name=execution_input['ModelName'],
        parameters=placeholder_parameters
    )
    model_step.add_retry(SAGEMAKER_RETRY_STRATEGY)
    workflow_graph = Chain([model_step])

    with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
        # Create workflow and check definition
        workflow = create_workflow_and_check_definition(
            workflow_graph=workflow_graph,
            workflow_name=unique_name_from_base("integ-test-model-step-workflow"),
            sfn_client=sfn_client,
            sfn_role_arn=sfn_role_arn
        )

        # Concrete values substituted for the placeholders at execution time.
        environment_tag = {'Key': 'Environment', 'Value': 'test'}
        inputs = {
            'ModelName': generate_job_name(),
            'Mode': 'SingleModel',
            'Tags': [environment_tag]
        }

        # Execute workflow
        execution = workflow.execute(inputs=inputs)
        execution_output = execution.get_output(wait=True)

        # Check workflow output
        model_arn = execution_output.get("ModelArn")
        assert model_arn is not None
        assert execution_output["SdkHttpMetadata"]["HttpStatusCode"] == 200

        # Cleanup
        state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
        resource_name = get_resource_name_from_arn(execution_output.get("ModelArn"))
        model_name = resource_name.split("/")[1]
        delete_sagemaker_model(model_name, sagemaker_session)
249+
250+
198251
def test_transform_step(trained_estimator, sfn_client, sfn_role_arn):
199252
# Create transformer from previously created estimator
200253
job_name = generate_job_name()
@@ -351,7 +404,7 @@ def test_endpoint_config_step(trained_estimator, sfn_client, sagemaker_session,
351404
# Execute workflow
352405
execution = workflow.execute()
353406
execution_output = execution.get_output(wait=True)
354-
407+
355408
# Check workflow output
356409
assert execution_output.get("EndpointConfigArn") is not None
357410
assert execution_output["SdkHttpMetadata"]["HttpStatusCode"] == 200
@@ -392,7 +445,7 @@ def test_create_endpoint_step(trained_estimator, record_set_fixture, sfn_client,
392445
# Execute workflow
393446
execution = workflow.execute()
394447
execution_output = execution.get_output(wait=True)
395-
448+
396449
# Check workflow output
397450
endpoint_arn = execution_output.get("EndpointArn")
398451
assert execution_output.get("EndpointArn") is not None
@@ -430,7 +483,7 @@ def test_tuning_step(sfn_client, record_set_for_hyperparameter_tuning, sagemaker
430483
max_jobs=2,
431484
max_parallel_jobs=2,
432485
)
433-
486+
434487
# Build workflow definition
435488
tuning_step = TuningStep('Tuning', tuner=tuner, job_name=job_name, data=record_set_for_hyperparameter_tuning)
436489
tuning_step.add_retry(SAGEMAKER_RETRY_STRATEGY)
@@ -448,7 +501,7 @@ def test_tuning_step(sfn_client, record_set_for_hyperparameter_tuning, sagemaker
448501
# Execute workflow
449502
execution = workflow.execute()
450503
execution_output = execution.get_output(wait=True)
451-
504+
452505
# Check workflow output
453506
assert execution_output.get("HyperParameterTuningJobStatus") == "Completed"
454507

@@ -498,7 +551,7 @@ def test_processing_step(sklearn_processor_fixture, sagemaker_session, sfn_clien
498551
sfn_client=sfn_client,
499552
sfn_role_arn=sfn_role_arn
500553
)
501-
554+
502555
# Execute workflow
503556
execution = workflow.execute()
504557
execution_output = execution.get_output(wait=True)

tests/unit/test_sagemaker_steps.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1144,6 +1144,36 @@ def test_model_step_creation(pca_model):
11441144
}
11451145

11461146

1147+
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
def test_model_step_creation_with_placeholders(pca_model):
    """ModelStep renders placeholder arguments as JSONPath (``.$``) fields.

    Covers three placeholder entry points at once: ``model_name`` from the
    step input, ``tags`` from the execution input, and a nested
    ``PrimaryContainer`` field supplied through ``parameters``. The expected
    dict also checks that the ``parameters`` entry is merged into the
    model-derived ``PrimaryContainer`` rather than replacing it.
    """
    execution_input = ExecutionInput(schema={
        'Environment': str,
        # Tags are a list of key/value dicts, so the schema declares `list`,
        # matching the integration test for the same step.
        'Tags': list
    })

    step_input = StepInput(schema={
        'ModelName': str
    })

    parameters = {
        'PrimaryContainer': {
            'Environment': execution_input['Environment']
        }
    }
    step = ModelStep('Create model', model=pca_model, model_name=step_input['ModelName'], tags=execution_input['Tags'],
                     parameters=parameters)
    assert step.to_dict()['Parameters'] == {
        'ExecutionRoleArn': EXECUTION_ROLE,
        'ModelName.$': "$['ModelName']",
        'PrimaryContainer': {
            'Environment.$': "$$.Execution.Input['Environment']",
            'Image': pca_model.image_uri,
            'ModelDataUrl': pca_model.model_data
        },
        'Tags.$': "$$.Execution.Input['Tags']"
    }
1175+
1176+
11471177
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
11481178
def test_model_step_creation_with_env(pca_model_with_env):
11491179
step = ModelStep('Create model', model=pca_model_with_env, model_name='pca-model', tags=DEFAULT_TAGS)

0 commit comments

Comments
 (0)