Skip to content

Commit 64be9bc

Browse files
authored
Cloud Fetch download handler (#127)
* Cloud Fetch download handler Signed-off-by: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> * Issue fix: final result link compressed data has multiple LZ4 end-of-frame markers Signed-off-by: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> * Addressing PR comments - Linting - Type annotations - Use response.ok - Log exception - Remove semaphore and only use threading.event - reset() flags method - Fix tests after removing semaphore - Link expiry logic should be in secs - Decompress data static function - link_expiry_buffer and static public methods - Docstrings and comments Signed-off-by: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> * Changing logger.debug to remove url Signed-off-by: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> * _reset() comment to docstring Signed-off-by: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> * link_expiry_buffer -> link_expiry_buffer_secs Signed-off-by: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> --------- Signed-off-by: Matthew Kim <11141331+mattdeekay@users.noreply.github.com>
1 parent c351b57 commit 64be9bc

File tree

2 files changed

+306
-0
lines changed

2 files changed

+306
-0
lines changed
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
import logging
2+
3+
import requests
4+
import lz4.frame
5+
import threading
6+
import time
7+
8+
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
class ResultSetDownloadHandler(threading.Thread):
14+
def __init__(
15+
self,
16+
downloadable_result_settings,
17+
t_spark_arrow_result_link: TSparkArrowResultLink,
18+
):
19+
super().__init__()
20+
self.settings = downloadable_result_settings
21+
self.result_link = t_spark_arrow_result_link
22+
self.is_download_scheduled = False
23+
self.is_download_finished = threading.Event()
24+
self.is_file_downloaded_successfully = False
25+
self.is_link_expired = False
26+
self.is_download_timedout = False
27+
self.result_file = None
28+
29+
def is_file_download_successful(self) -> bool:
30+
"""
31+
Check and report if cloud fetch file downloaded successfully.
32+
33+
This function will block until a file download finishes or until a timeout.
34+
"""
35+
timeout = self.settings.download_timeout
36+
timeout = timeout if timeout and timeout > 0 else None
37+
try:
38+
if not self.is_download_finished.wait(timeout=timeout):
39+
self.is_download_timedout = True
40+
logger.debug(
41+
"Cloud fetch download timed out after {} seconds for link representing rows {} to {}".format(
42+
self.settings.download_timeout,
43+
self.result_link.startRowOffset,
44+
self.result_link.startRowOffset + self.result_link.rowCount,
45+
)
46+
)
47+
return False
48+
except Exception as e:
49+
logger.error(e)
50+
return False
51+
return self.is_file_downloaded_successfully
52+
53+
def run(self):
54+
"""
55+
Download the file described in the cloud fetch link.
56+
57+
This function checks if the link has or is expiring, gets the file via a requests session, decompresses the
58+
file, and signals to waiting threads that the download is finished and whether it was successful.
59+
"""
60+
self._reset()
61+
62+
# Check if link is already expired or is expiring
63+
if ResultSetDownloadHandler.check_link_expired(
64+
self.result_link, self.settings.link_expiry_buffer_secs
65+
):
66+
self.is_link_expired = True
67+
return
68+
69+
session = requests.Session()
70+
session.timeout = self.settings.download_timeout
71+
72+
try:
73+
# Get the file via HTTP request
74+
response = session.get(self.result_link.fileLink)
75+
76+
if not response.ok:
77+
self.is_file_downloaded_successfully = False
78+
return
79+
80+
# Save (and decompress if needed) the downloaded file
81+
compressed_data = response.content
82+
decompressed_data = (
83+
ResultSetDownloadHandler.decompress_data(compressed_data)
84+
if self.settings.is_lz4_compressed
85+
else compressed_data
86+
)
87+
self.result_file = decompressed_data
88+
89+
# The size of the downloaded file should match the size specified from TSparkArrowResultLink
90+
self.is_file_downloaded_successfully = (
91+
len(self.result_file) == self.result_link.bytesNum
92+
)
93+
except Exception as e:
94+
logger.error(e)
95+
self.is_file_downloaded_successfully = False
96+
97+
finally:
98+
session and session.close()
99+
# Awaken threads waiting for this to be true which signals the run is complete
100+
self.is_download_finished.set()
101+
102+
def _reset(self):
103+
"""
104+
Reset download-related flags for every retry of run()
105+
"""
106+
self.is_file_downloaded_successfully = False
107+
self.is_link_expired = False
108+
self.is_download_timedout = False
109+
self.is_download_finished = threading.Event()
110+
111+
@staticmethod
112+
def check_link_expired(
113+
link: TSparkArrowResultLink, expiry_buffer_secs: int
114+
) -> bool:
115+
"""
116+
Check if a link has expired or will expire.
117+
118+
Expiry buffer can be set to avoid downloading files that has not expired yet when the function is called,
119+
but may expire before the file has fully downloaded.
120+
"""
121+
current_time = int(time.time())
122+
if (
123+
link.expiryTime < current_time
124+
or link.expiryTime - current_time < expiry_buffer_secs
125+
):
126+
return True
127+
return False
128+
129+
@staticmethod
130+
def decompress_data(compressed_data: bytes) -> bytes:
131+
"""
132+
Decompress lz4 frame compressed data.
133+
134+
Decompresses data that has been lz4 compressed, either via the whole frame or by series of chunks.
135+
"""
136+
uncompressed_data, bytes_read = lz4.frame.decompress(
137+
compressed_data, return_bytes_read=True
138+
)
139+
# The last cloud fetch file of the entire result is commonly punctuated by frequent end-of-frame markers.
140+
# Full frame decompression above will short-circuit, so chunking is necessary
141+
if bytes_read < len(compressed_data):
142+
d_context = lz4.frame.create_decompression_context()
143+
start = 0
144+
uncompressed_data = bytearray()
145+
while start < len(compressed_data):
146+
data, num_bytes, is_end = lz4.frame.decompress_chunk(
147+
d_context, compressed_data[start:]
148+
)
149+
uncompressed_data += data
150+
start += num_bytes
151+
return uncompressed_data

tests/unit/test_downloader.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
import unittest
2+
from unittest.mock import Mock, patch, MagicMock
3+
4+
import databricks.sql.cloudfetch.downloader as downloader
5+
6+
7+
class DownloaderTests(unittest.TestCase):
8+
"""
9+
Unit tests for checking downloader logic.
10+
"""
11+
12+
@patch('time.time', return_value=1000)
13+
def test_run_link_expired(self, mock_time):
14+
settings = Mock()
15+
result_link = Mock()
16+
# Already expired
17+
result_link.expiryTime = 999
18+
d = downloader.ResultSetDownloadHandler(settings, result_link)
19+
assert not d.is_link_expired
20+
d.run()
21+
assert d.is_link_expired
22+
mock_time.assert_called_once()
23+
24+
@patch('time.time', return_value=1000)
25+
def test_run_link_past_expiry_buffer(self, mock_time):
26+
settings = Mock(link_expiry_buffer_secs=5)
27+
result_link = Mock()
28+
# Within the expiry buffer time
29+
result_link.expiryTime = 1004
30+
d = downloader.ResultSetDownloadHandler(settings, result_link)
31+
assert not d.is_link_expired
32+
d.run()
33+
assert d.is_link_expired
34+
mock_time.assert_called_once()
35+
36+
@patch('requests.Session', return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=False))))
37+
@patch('time.time', return_value=1000)
38+
def test_run_get_response_not_ok(self, mock_time, mock_session):
39+
settings = Mock(link_expiry_buffer_secs=0, download_timeout=0)
40+
settings.download_timeout = 0
41+
settings.use_proxy = False
42+
result_link = Mock(expiryTime=1001)
43+
44+
d = downloader.ResultSetDownloadHandler(settings, result_link)
45+
d.run()
46+
47+
assert not d.is_file_downloaded_successfully
48+
assert d.is_download_finished.is_set()
49+
50+
@patch('requests.Session',
51+
return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True, content=b"1234567890" * 9))))
52+
@patch('time.time', return_value=1000)
53+
def test_run_uncompressed_data_length_incorrect(self, mock_time, mock_session):
54+
settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False, is_lz4_compressed=False)
55+
result_link = Mock(bytesNum=100, expiryTime=1001)
56+
57+
d = downloader.ResultSetDownloadHandler(settings, result_link)
58+
d.run()
59+
60+
assert not d.is_file_downloaded_successfully
61+
assert d.is_download_finished.is_set()
62+
63+
@patch('requests.Session', return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True))))
64+
@patch('time.time', return_value=1000)
65+
def test_run_compressed_data_length_incorrect(self, mock_time, mock_session):
66+
settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False)
67+
settings.is_lz4_compressed = True
68+
result_link = Mock(bytesNum=100, expiryTime=1001)
69+
mock_session.return_value.get.return_value.content = \
70+
b'\x04"M\x18h@Z\x00\x00\x00\x00\x00\x00\x00\xec\x14\x00\x00\x00\xaf1234567890\n\x008P67890\x00\x00\x00\x00'
71+
72+
d = downloader.ResultSetDownloadHandler(settings, result_link)
73+
d.run()
74+
75+
assert not d.is_file_downloaded_successfully
76+
assert d.is_download_finished.is_set()
77+
78+
@patch('requests.Session',
79+
return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True, content=b"1234567890" * 10))))
80+
@patch('time.time', return_value=1000)
81+
def test_run_uncompressed_successful(self, mock_time, mock_session):
82+
settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False)
83+
settings.is_lz4_compressed = False
84+
result_link = Mock(bytesNum=100, expiryTime=1001)
85+
86+
d = downloader.ResultSetDownloadHandler(settings, result_link)
87+
d.run()
88+
89+
assert d.result_file == b"1234567890" * 10
90+
assert d.is_file_downloaded_successfully
91+
assert d.is_download_finished.is_set()
92+
93+
@patch('requests.Session', return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True))))
94+
@patch('time.time', return_value=1000)
95+
def test_run_compressed_successful(self, mock_time, mock_session):
96+
settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False)
97+
settings.is_lz4_compressed = True
98+
result_link = Mock(bytesNum=100, expiryTime=1001)
99+
mock_session.return_value.get.return_value.content = \
100+
b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'
101+
102+
d = downloader.ResultSetDownloadHandler(settings, result_link)
103+
d.run()
104+
105+
assert d.result_file == b"1234567890" * 10
106+
assert d.is_file_downloaded_successfully
107+
assert d.is_download_finished.is_set()
108+
109+
@patch('requests.Session.get', side_effect=ConnectionError('foo'))
110+
@patch('time.time', return_value=1000)
111+
def test_download_connection_error(self, mock_time, mock_session):
112+
settings = Mock(link_expiry_buffer_secs=0, use_proxy=False, is_lz4_compressed=True)
113+
result_link = Mock(bytesNum=100, expiryTime=1001)
114+
mock_session.return_value.get.return_value.content = \
115+
b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'
116+
117+
d = downloader.ResultSetDownloadHandler(settings, result_link)
118+
d.run()
119+
120+
assert not d.is_file_downloaded_successfully
121+
assert d.is_download_finished.is_set()
122+
123+
@patch('requests.Session.get', side_effect=TimeoutError('foo'))
124+
@patch('time.time', return_value=1000)
125+
def test_download_timeout(self, mock_time, mock_session):
126+
settings = Mock(link_expiry_buffer_secs=0, use_proxy=False, is_lz4_compressed=True)
127+
result_link = Mock(bytesNum=100, expiryTime=1001)
128+
mock_session.return_value.get.return_value.content = \
129+
b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'
130+
131+
d = downloader.ResultSetDownloadHandler(settings, result_link)
132+
d.run()
133+
134+
assert not d.is_file_downloaded_successfully
135+
assert d.is_download_finished.is_set()
136+
137+
@patch("threading.Event.wait", return_value=True)
138+
def test_is_file_download_successful_has_finished(self, mock_wait):
139+
for timeout in [None, 0, 1]:
140+
with self.subTest(timeout=timeout):
141+
settings = Mock(download_timeout=timeout)
142+
result_link = Mock()
143+
handler = downloader.ResultSetDownloadHandler(settings, result_link)
144+
145+
status = handler.is_file_download_successful()
146+
assert status == handler.is_file_downloaded_successfully
147+
148+
def test_is_file_download_successful_times_outs(self):
149+
settings = Mock(download_timeout=1)
150+
result_link = Mock()
151+
handler = downloader.ResultSetDownloadHandler(settings, result_link)
152+
153+
status = handler.is_file_download_successful()
154+
assert not status
155+
assert handler.is_download_timedout

0 commit comments

Comments
 (0)