From 56e271397b4aebfeb7031dd6ea5c29ac4994ccfc Mon Sep 17 00:00:00 2001 From: Mifeet Date: Fri, 18 Aug 2023 18:52:05 +0100 Subject: [PATCH 1/2] Support SSE-C for S3 --- tensorizer/stream_io.py | 44 +++++++++++++++++++++++++++++++++++++---- 1 file changed, 40 insertions(+), 4 deletions(-) diff --git a/tensorizer/stream_io.py b/tensorizer/stream_io.py index d5402a69..2df628cc 100644 --- a/tensorizer/stream_io.py +++ b/tensorizer/stream_io.py @@ -1,4 +1,6 @@ +import base64 import functools +import hashlib import io import logging import os @@ -433,10 +435,17 @@ def s3_upload( s3_access_key_id: str, s3_secret_access_key: str, s3_endpoint: str = default_s3_write_endpoint, + s3_sse_customer_key: Optional[bytes] = None, + s3_sse_customer_algorithm: Optional[str] = None, ): bucket, key = _parse_s3_uri(target_uri) client = _new_s3_client(s3_access_key_id, s3_secret_access_key, s3_endpoint) - client.upload_file(path, bucket, key) + extra_args = {} + if s3_sse_customer_key is not None: + extra_args["SSECustomerAlgorithm"] = s3_sse_customer_algorithm + if s3_sse_customer_algorithm is not None: + extra_args["SSECustomerKey"] = s3_sse_customer_key + client.upload_file(path, bucket, key, ExtraArgs=extra_args) def s3_download( @@ -444,15 +453,29 @@ def s3_download( s3_access_key_id: str, s3_secret_access_key: str, s3_endpoint: str = default_s3_read_endpoint, + s3_sse_customer_key: Optional[bytes] = None, + s3_sse_customer_algorithm: Optional[str] = None, ) -> CURLStreamFile: bucket, key = _parse_s3_uri(path_uri) client = _new_s3_client(s3_access_key_id, s3_secret_access_key, s3_endpoint) + encryption_params = {} + if s3_sse_customer_key is not None: + encryption_params["SSECustomerAlgorithm"] = s3_sse_customer_algorithm + if s3_sse_customer_algorithm is not None: + encryption_params["SSECustomerKey"] = s3_sse_customer_key url = client.generate_presigned_url( ClientMethod="get_object", - Params={"Bucket": bucket, "Key": key}, + Params={"Bucket": bucket, "Key": key, 
**encryption_params}, ExpiresIn=300, ) - return CURLStreamFile(url) + request_headers = {} + if s3_sse_customer_algorithm is not None: + request_headers['x-amz-server-side-encryption-customer-algorithm'] = s3_sse_customer_algorithm + if s3_sse_customer_key is not None: + request_headers['x-amz-server-side-encryption-customer-key'] = base64.b64encode(s3_sse_customer_key).decode() + key_md5 = hashlib.md5(s3_sse_customer_key).digest() + request_headers['x-amz-server-side-encryption-customer-key-MD5'] = base64.b64encode(key_md5).decode() + return CURLStreamFile(url, headers=request_headers) def _infer_credentials( @@ -601,6 +624,8 @@ def open_stream( s3_secret_access_key: Optional[str] = None, s3_endpoint: Optional[str] = None, s3_config_path: Optional[Union[str, bytes, os.PathLike]] = None, + s3_sse_customer_key: Optional[bytes] = None, + s3_sse_customer_algorithm: Optional[str] = None, ) -> Union[CURLStreamFile, typing.BinaryIO]: """ Open a file path, http(s):// URL, or s3:// URI. @@ -638,6 +663,10 @@ def open_stream( s3_config_path: An explicit path to the `~/.s3cfg` config file to be parsed if full credentials are not provided. If None, platform-specific default paths are used. + s3_sse_customer_key: Specifies the customer-provided encryption + key for Amazon S3 to use in encrypting data. + s3_sse_customer_algorithm: Specifies the algorithm to use + when encrypting the object (for example, AES256). Returns: An opened file-like object representing the target resource. 
@@ -753,13 +782,20 @@ def open_stream( s3_access_key_id, s3_secret_access_key, s3_endpoint, + s3_sse_customer_key, + s3_sse_customer_algorithm, ) temp_file.close = guaranteed_closer return temp_file else: s3_endpoint = s3_endpoint or default_s3_read_endpoint curl_stream_file = s3_download( - path_uri, s3_access_key_id, s3_secret_access_key, s3_endpoint + path_uri, + s3_access_key_id=s3_access_key_id, + s3_secret_access_key=s3_secret_access_key, + s3_endpoint=s3_endpoint, + s3_sse_customer_key=s3_sse_customer_key, + s3_sse_customer_algorithm=s3_sse_customer_algorithm, ) if error_context: curl_stream_file.register_error_context(error_context) From 260d7ec478db01bb6915baaf97f4b73b640227c2 Mon Sep 17 00:00:00 2001 From: Mifeet Date: Fri, 18 Aug 2023 18:52:36 +0100 Subject: [PATCH 2/2] Add test for sse-c --- tests/test_stream_io.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_stream_io.py b/tests/test_stream_io.py index dfde20e6..ad19e827 100644 --- a/tests/test_stream_io.py +++ b/tests/test_stream_io.py @@ -163,6 +163,8 @@ def test_upload(self): s3_access_key_id=self.ACCESS_KEY, s3_secret_access_key=self.SECRET_KEY, s3_endpoint=self.endpoint, + s3_sse_customer_key=os.urandom(32), + s3_sse_customer_algorithm="AES256", ) long_string = b"Hello" * 1024 s.write(long_string) @@ -184,5 +186,7 @@ def test_download(self): s3_access_key_id="X", s3_secret_access_key="X", s3_endpoint=endpoint, + s3_sse_customer_key=os.urandom(32), + s3_sse_customer_algorithm="AES256", ) as s: self.assertEqual(s.read(), long_string)