Skip to content

Commit a305498

Browse files
authored
feature: adding support for DataCaptureConfig in endpoint config (#26)
1 parent 7f460e0 commit a305498

File tree

2 files changed

+33
-4
lines changed

2 files changed

+33
-4
lines changed

src/stepfunctions/steps/sagemaker.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from sagemaker.workflow.airflow import training_config, transform_config, model_config, tuning_config
2121
from sagemaker.model import Model, FrameworkModel
22-
22+
from sagemaker.model_monitor import DataCaptureConfig
2323

2424
class TrainingStep(Task):
2525

@@ -209,14 +209,17 @@ class EndpointConfigStep(Task):
209209
Creates a Task State to `create an endpoint configuration in SageMaker <https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateEndpointConfig.html>`_.
210210
"""
211211

212-
def __init__(self, state_id, endpoint_config_name, model_name, initial_instance_count, instance_type, tags=None, **kwargs):
212+
def __init__(self, state_id, endpoint_config_name, model_name, initial_instance_count, instance_type, data_capture_config=None, tags=None, **kwargs):
213213
"""
214214
Args:
215215
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
216216
endpoint_config_name (str or Placeholder): The name of the endpoint configuration to create. We recommend to use :py:class:`~stepfunctions.inputs.ExecutionInput` placeholder collection to pass the value dynamically in each execution.
217217
model_name (str or Placeholder): The name of the SageMaker model to attach to the endpoint configuration. We recommend to use :py:class:`~stepfunctions.inputs.ExecutionInput` placeholder collection to pass the value dynamically in each execution.
218218
initial_instance_count (int or Placeholder): The initial number of instances to run in the ``Endpoint`` created from this ``Model``.
219219
instance_type (str or Placeholder): The EC2 instance type to deploy this Model to. For example, 'ml.p2.xlarge'.
220+
data_capture_config (sagemaker.model_monitor.DataCaptureConfig, optional): Specifies
221+
configuration related to Endpoint data capture for use with
222+
Amazon SageMaker Model Monitoring. Default: None.
220223
tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
221224
"""
222225
parameters = {
@@ -229,6 +232,9 @@ def __init__(self, state_id, endpoint_config_name, model_name, initial_instance_
229232
}]
230233
}
231234

235+
if isinstance(data_capture_config, DataCaptureConfig):
236+
parameters['DataCaptureConfig'] = data_capture_config._to_request_dict()
237+
232238
if tags:
233239
parameters['Tags'] = tags_dict_to_kv_list(tags)
234240

tests/unit/test_sagemaker_steps.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from sagemaker.model import Model
2121
from sagemaker.tensorflow import TensorFlow
2222
from sagemaker.pipeline import PipelineModel
23+
from sagemaker.model_monitor import DataCaptureConfig
2324

2425
from unittest.mock import MagicMock, patch
2526
from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep
@@ -366,7 +367,16 @@ def test_model_step_creation(pca_model):
366367
}
367368

368369
def test_endpoint_config_step_creation(pca_model):
369-
step = EndpointConfigStep('Endpoint Config', endpoint_config_name='MyEndpointConfig', model_name='pca-model', initial_instance_count=1, instance_type='ml.p2.xlarge')
370+
data_capture_config = DataCaptureConfig(
371+
enable_capture=True,
372+
sampling_percentage=100,
373+
destination_s3_uri='s3://sagemaker/datacapture')
374+
step = EndpointConfigStep('Endpoint Config',
375+
endpoint_config_name='MyEndpointConfig',
376+
model_name='pca-model',
377+
initial_instance_count=1,
378+
instance_type='ml.p2.xlarge',
379+
data_capture_config=data_capture_config)
370380
assert step.to_dict() == {
371381
'Type': 'Task',
372382
'Parameters': {
@@ -376,7 +386,20 @@ def test_endpoint_config_step_creation(pca_model):
376386
'InstanceType': 'ml.p2.xlarge',
377387
'ModelName': 'pca-model',
378388
'VariantName': 'AllTraffic'
379-
}]
389+
}],
390+
'DataCaptureConfig': {
391+
'EnableCapture': True,
392+
'InitialSamplingPercentage': 100,
393+
'DestinationS3Uri': 's3://sagemaker/datacapture',
394+
'CaptureOptions': [
395+
{'CaptureMode': 'Input'},
396+
{'CaptureMode': 'Output'}
397+
],
398+
'CaptureContentTypeHeader': {
399+
'CsvContentTypes': ['text/csv'],
400+
'JsonContentTypes': ['application/json']
401+
}
402+
}
380403
},
381404
'Resource': 'arn:aws:states:::sagemaker:createEndpointConfig',
382405
'End': True

0 commit comments

Comments
 (0)