Skip to content

Commit 4979d74

Browse files
authored
feature: adding support for ExperimentConfig in training and transform steps (#23)
1 parent 5b35e09 commit 4979d74

File tree

2 files changed

+34
-4
lines changed

2 files changed

+34
-4
lines changed

src/stepfunctions/steps/sagemaker.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class TrainingStep(Task):
2727
Creates a Task State to execute a `SageMaker Training Job <https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTrainingJob.html>`_. The TrainingStep will also create a model by default, and the model shares the same name as the training job.
2828
"""
2929

30-
def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=None, mini_batch_size=None, wait_for_completion=True, **kwargs):
30+
def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=None, mini_batch_size=None, experiment_config=None, wait_for_completion=True, **kwargs):
3131
"""
3232
Args:
3333
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
@@ -50,6 +50,7 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
5050
where each instance is a different channel of training data.
5151
hyperparameters (dict, optional): Specify the hyperparameters for the training. (Default: None)
5252
mini_batch_size (int): Specify this argument only when estimator is a built-in estimator of an Amazon algorithm. For other estimators, batch size should be specified in the estimator.
53+
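experiment_config (dict, optional): Specify the experiment config to associate with the training job, e.g. a dict with the keys `ExperimentName`, `TrialName` and `TrialComponentDisplayName`. It is passed through unchanged as the `ExperimentConfig` parameter. (Default: None)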
5354
wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the training job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the training job and proceed to the next step. (default: True)
5455
"""
5556
self.estimator = estimator
@@ -71,6 +72,9 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
7172
if hyperparameters is not None:
7273
parameters['HyperParameters'] = hyperparameters
7374

75+
if experiment_config is not None:
76+
parameters['ExperimentConfig'] = experiment_config
77+
7478
if 'S3Operations' in parameters:
7579
del parameters['S3Operations']
7680

@@ -101,7 +105,7 @@ class TransformStep(Task):
101105
Creates a Task State to execute a `SageMaker Transform Job <https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTransformJob.html>`_.
102106
"""
103107

104-
def __init__(self, state_id, transformer, job_name, model_name, data, data_type='S3Prefix', content_type=None, compression_type=None, split_type=None, wait_for_completion=True, **kwargs):
108+
def __init__(self, state_id, transformer, job_name, model_name, data, data_type='S3Prefix', content_type=None, compression_type=None, split_type=None, experiment_config=None, wait_for_completion=True, **kwargs):
105109
"""
106110
Args:
107111
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
@@ -119,6 +123,7 @@ def __init__(self, state_id, transformer, job_name, model_name, data, data_type=
119123
content_type (str): MIME type of the input data (default: None).
120124
compression_type (str): Compression type of the input data, if compressed (default: None). Valid values: 'Gzip', None.
121125
split_type (str): The record delimiter for the input object (default: 'None'). Valid values: 'None', 'Line', 'RecordIO', and 'TFRecord'.
126+
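experiment_config (dict, optional): Specify the experiment config to associate with the transform job, e.g. a dict with the keys `ExperimentName`, `TrialName` and `TrialComponentDisplayName`. It is passed through unchanged as the `ExperimentConfig` parameter. (Default: None)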
122127
wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the transform job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the transform job and proceed to the next step. (default: True)
123128
"""
124129
if wait_for_completion:
@@ -151,6 +156,9 @@ def __init__(self, state_id, transformer, job_name, model_name, data, data_type=
151156

152157
parameters['ModelName'] = model_name
153158

159+
if experiment_config is not None:
160+
parameters['ExperimentConfig'] = experiment_config
161+
154162
kwargs[Field.Parameters.value] = parameters
155163
super(TransformStep, self).__init__(state_id, **kwargs)
156164

tests/unit/test_sagemaker_steps.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,14 @@ def tensorflow_estimator():
9898

9999
@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
100100
def test_training_step_creation(pca_estimator):
101-
step = TrainingStep('Training', estimator=pca_estimator, job_name='TrainingJob')
101+
step = TrainingStep('Training',
102+
estimator=pca_estimator,
103+
job_name='TrainingJob',
104+
experiment_config={
105+
'ExperimentName': 'pca_experiment',
106+
'TrialName': 'pca_trial',
107+
'TrialComponentDisplayName': 'Training'
108+
})
102109
assert step.to_dict() == {
103110
'Type': 'Task',
104111
'Parameters': {
@@ -125,6 +132,11 @@ def test_training_step_creation(pca_estimator):
125132
'algorithm_mode': 'randomized',
126133
'mini_batch_size': '200'
127134
},
135+
'ExperimentConfig': {
136+
'ExperimentName': 'pca_experiment',
137+
'TrialName': 'pca_trial',
138+
'TrialComponentDisplayName': 'Training'
139+
},
128140
'TrainingJobName': 'TrainingJob'
129141
},
130142
'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync',
@@ -243,7 +255,12 @@ def test_transform_step_creation(pca_transformer):
243255
transformer=pca_transformer,
244256
data='s3://sagemaker/inference',
245257
job_name='transform-job',
246-
model_name='pca-model'
258+
model_name='pca-model',
259+
experiment_config={
260+
'ExperimentName': 'pca_experiment',
261+
'TrialName': 'pca_trial',
262+
'TrialComponentDisplayName': 'Transform'
263+
}
247264
)
248265
assert step.to_dict() == {
249266
'Type': 'Task',
@@ -264,6 +281,11 @@ def test_transform_step_creation(pca_transformer):
264281
'TransformResources': {
265282
'InstanceCount': 1,
266283
'InstanceType': 'ml.c4.xlarge'
284+
},
285+
'ExperimentConfig': {
286+
'ExperimentName': 'pca_experiment',
287+
'TrialName': 'pca_trial',
288+
'TrialComponentDisplayName': 'Transform'
267289
}
268290
},
269291
'Resource': 'arn:aws:states:::sagemaker:createTransformJob.sync',

0 commit comments

Comments
 (0)