diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py
new file mode 100644
index 000000000..aac3ac336
--- /dev/null
+++ b/src/databricks/sql/cloudfetch/download_manager.py
@@ -0,0 +1,166 @@
+import logging
+
+from concurrent.futures import ThreadPoolExecutor
+from dataclasses import dataclass
+from typing import List, Union
+
+from databricks.sql.cloudfetch.downloader import (
+    ResultSetDownloadHandler,
+    DownloadableResultSettings,
+)
+from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class DownloadedFile:
+    """
+    Class for the result file and metadata.
+
+    Attributes:
+        file_bytes (bytes): Downloaded file in bytes.
+        start_row_offset (int): The offset of the starting row in relation to the full result.
+        row_count (int): Number of rows the file represents in the result.
+    """
+
+    file_bytes: bytes
+    start_row_offset: int
+    row_count: int
+
+
+class ResultFileDownloadManager:
+    def __init__(self, max_download_threads: int, lz4_compressed: bool):
+        self.download_handlers: List[ResultSetDownloadHandler] = []
+        self.thread_pool = ThreadPoolExecutor(max_workers=max_download_threads + 1)
+        self.downloadable_result_settings = DownloadableResultSettings(lz4_compressed)
+        self.fetch_need_retry = False
+        self.num_consecutive_result_file_download_retries = 0
+
+    def add_file_links(
+        self, t_spark_arrow_result_links: List[TSparkArrowResultLink]
+    ) -> None:
+        """
+        Create a download handler for each cloud fetch link.
+
+        Args:
+            t_spark_arrow_result_links: List of cloud fetch links consisting of file URL and metadata.
+        """
+        for link in t_spark_arrow_result_links:
+            if link.rowCount <= 0:
+                continue
+            self.download_handlers.append(
+                ResultSetDownloadHandler(self.downloadable_result_settings, link)
+            )
+
+    def get_next_downloaded_file(
+        self, next_row_offset: int
+    ) -> Union[DownloadedFile, None]:
+        """
+        Get the next file that starts at the given offset.
+
+        This function gets the next downloaded file whose rows start at the specified next_row_offset
+        relative to the full result. File downloads are scheduled if they have not been already, and once
+        the correct download handler is located, the function waits for the download to finish and returns
+        the resulting file. If there are no more downloads, a download was not successful, or the correct
+        file could not be located, this function shuts down the thread pool and returns None.
+
+        Args:
+            next_row_offset (int): The offset of the starting row of the next file we want data from.
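+
+        Returns:
+            The next DownloadedFile in row order, or None if it could not be
+            retrieved (in which case the manager has been shut down).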
+ """ + # No more files to download from this batch of links + if not self.download_handlers: + self._shutdown_manager() + return None + + # Remove handlers we don't need anymore + self._remove_past_handlers(next_row_offset) + + # Schedule the downloads + self._schedule_downloads() + + # Find next file + idx = self._find_next_file_index(next_row_offset) + if idx is None: + self._shutdown_manager() + return None + handler = self.download_handlers[idx] + + # Check (and wait) for download status + if self._check_if_download_successful(handler): + # Buffer should be empty so set buffer to new ArrowQueue with result_file + result = DownloadedFile( + handler.result_file, + handler.result_link.startRowOffset, + handler.result_link.rowCount, + ) + self.download_handlers.pop(idx) + # Return True upon successful download to continue loop and not force a retry + return result + # Download was not successful for next download item, force a retry + self._shutdown_manager() + return None + + def _remove_past_handlers(self, next_row_offset: int): + # Any link in which its start to end range doesn't include the next row to be fetched does not need downloading + i = 0 + while i < len(self.download_handlers): + result_link = self.download_handlers[i].result_link + if result_link.startRowOffset + result_link.rowCount > next_row_offset: + i += 1 + continue + self.download_handlers.pop(i) + + def _schedule_downloads(self): + # Schedule downloads for all download handlers if not already scheduled. + for handler in self.download_handlers: + if handler.is_download_scheduled: + continue + try: + self.thread_pool.submit(handler.run) + except Exception as e: + logger.error(e) + break + handler.is_download_scheduled = True + + def _find_next_file_index(self, next_row_offset: int): + # Get the handler index of the next file in order + next_indices = [ + i + for i, handler in enumerate(self.download_handlers) + if handler.is_download_scheduled + and handler.result_link.startRowOffset == next_row_offset + ] + return next_indices[0] if len(next_indices) > 0 else None + + def _check_if_download_successful(self, handler: ResultSetDownloadHandler): + # Check (and wait until download finishes) if download was successful + if not handler.is_file_download_successful(): + if handler.is_link_expired: + self.fetch_need_retry = True + return False + elif handler.is_download_timedout: + # Consecutive file retries should not exceed threshold in settings + if ( + self.num_consecutive_result_file_download_retries + >= self.downloadable_result_settings.max_consecutive_file_download_retries + ): + self.fetch_need_retry = True + return False + self.num_consecutive_result_file_download_retries += 1 + + # Re-submit handler run to thread pool and recursively check download status + self.thread_pool.submit(handler.run) + return self._check_if_download_successful(handler) + else: + self.fetch_need_retry = True + return False + + self.num_consecutive_result_file_download_retries = 0 + self.fetch_need_retry = False + return True + + def _shutdown_manager(self): + # Clear download handlers and shutdown the thread pool to cancel pending futures + self.download_handlers = [] + self.thread_pool.shutdown(wait=False, cancel_futures=True) diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py index d3c4a480f..019c4ef92 100644 --- a/src/databricks/sql/cloudfetch/downloader.py +++ b/src/databricks/sql/cloudfetch/downloader.py @@ -1,4 +1,5 @@ import logging +from dataclasses import dataclass import requests 
diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py
index d3c4a480f..019c4ef92 100644
--- a/src/databricks/sql/cloudfetch/downloader.py
+++ b/src/databricks/sql/cloudfetch/downloader.py
@@ -1,4 +1,5 @@
 import logging
+from dataclasses import dataclass
 
 import requests
 import lz4.frame
@@ -10,10 +11,28 @@
 logger = logging.getLogger(__name__)
 
 
+@dataclass
+class DownloadableResultSettings:
+    """
+    Class for settings common to each download handler.
+
+    Attributes:
+        is_lz4_compressed (bool): Whether the file is expected to be lz4 compressed.
+        link_expiry_buffer_secs (int): Time in seconds to prevent download of a link before it expires. Default 0 secs.
+        download_timeout (int): Timeout for download requests. Default 60 secs.
+        max_consecutive_file_download_retries (int): Number of consecutive download retries before shutting down.
+    """
+
+    is_lz4_compressed: bool
+    link_expiry_buffer_secs: int = 0
+    download_timeout: int = 60
+    max_consecutive_file_download_retries: int = 0
+
+
 class ResultSetDownloadHandler(threading.Thread):
     def __init__(
         self,
-        downloadable_result_settings,
+        downloadable_result_settings: DownloadableResultSettings,
         t_spark_arrow_result_link: TSparkArrowResultLink,
     ):
         super().__init__()
@@ -32,8 +51,11 @@ def is_file_download_successful(self) -> bool:
 
         This function will block until a file download finishes or until a timeout.
         """
-        timeout = self.settings.download_timeout
-        timeout = timeout if timeout and timeout > 0 else None
+        timeout = (
+            self.settings.download_timeout
+            if self.settings.download_timeout > 0
+            else None
+        )
         try:
             if not self.is_download_finished.wait(timeout=timeout):
                 self.is_download_timedout = True
diff --git a/tests/unit/test_download_manager.py b/tests/unit/test_download_manager.py
new file mode 100644
index 000000000..97bf407aa
--- /dev/null
+++ b/tests/unit/test_download_manager.py
@@ -0,0 +1,207 @@
+import unittest
+from unittest.mock import patch, MagicMock
+
+import databricks.sql.cloudfetch.download_manager as download_manager
+import databricks.sql.cloudfetch.downloader as downloader
+from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
+
+
+class DownloadManagerTests(unittest.TestCase):
+    """
+    Unit tests for checking download manager logic.
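+
+    Most tests build ten sequential links of 8,000 rows each, so handler i covers
+    rows [8000 * i, 8000 * (i + 1)) of the full result.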
+ """ + + def create_download_manager(self): + max_download_threads = 10 + lz4_compressed = True + return download_manager.ResultFileDownloadManager(max_download_threads, lz4_compressed) + + def create_result_link( + self, + file_link: str = "fileLink", + start_row_offset: int = 0, + row_count: int = 8000, + bytes_num: int = 20971520 + ): + return TSparkArrowResultLink(file_link, None, start_row_offset, row_count, bytes_num) + + def create_result_links(self, num_files: int, start_row_offset: int = 0): + result_links = [] + for i in range(num_files): + file_link = "fileLink_" + str(i) + result_link = self.create_result_link(file_link=file_link, start_row_offset=start_row_offset) + result_links.append(result_link) + start_row_offset += result_link.rowCount + return result_links + + def test_add_file_links_zero_row_count(self): + links = [self.create_result_link(row_count=0, bytes_num=0)] + manager = self.create_download_manager() + manager.add_file_links(links) + + assert not manager.download_handlers + + def test_add_file_links_success(self): + links = self.create_result_links(num_files=10) + manager = self.create_download_manager() + manager.add_file_links(links) + + assert len(manager.download_handlers) == 10 + + def test_remove_past_handlers_one(self): + links = self.create_result_links(num_files=10) + manager = self.create_download_manager() + manager.add_file_links(links) + + manager._remove_past_handlers(8000) + assert len(manager.download_handlers) == 9 + + def test_remove_past_handlers_all(self): + links = self.create_result_links(num_files=10) + manager = self.create_download_manager() + manager.add_file_links(links) + + manager._remove_past_handlers(8000*10) + assert len(manager.download_handlers) == 0 + + @patch("concurrent.futures.ThreadPoolExecutor.submit") + def test_schedule_downloads_partial_already_scheduled(self, mock_submit): + links = self.create_result_links(num_files=10) + manager = self.create_download_manager() + manager.add_file_links(links) + + for i in range(5): + manager.download_handlers[i].is_download_scheduled = True + + manager._schedule_downloads() + assert mock_submit.call_count == 5 + assert sum([1 if handler.is_download_scheduled else 0 for handler in manager.download_handlers]) == 10 + + @patch("concurrent.futures.ThreadPoolExecutor.submit") + def test_schedule_downloads_will_not_schedule_twice(self, mock_submit): + links = self.create_result_links(num_files=10) + manager = self.create_download_manager() + manager.add_file_links(links) + + for i in range(5): + manager.download_handlers[i].is_download_scheduled = True + + manager._schedule_downloads() + assert mock_submit.call_count == 5 + assert sum([1 if handler.is_download_scheduled else 0 for handler in manager.download_handlers]) == 10 + + manager._schedule_downloads() + assert mock_submit.call_count == 5 + + @patch("concurrent.futures.ThreadPoolExecutor.submit", side_effect=[True, KeyError("foo")]) + def test_schedule_downloads_submit_fails(self, mock_submit): + links = self.create_result_links(num_files=10) + manager = self.create_download_manager() + manager.add_file_links(links) + + manager._schedule_downloads() + assert mock_submit.call_count == 2 + assert sum([1 if handler.is_download_scheduled else 0 for handler in manager.download_handlers]) == 1 + + @patch("concurrent.futures.ThreadPoolExecutor.submit") + def test_find_next_file_index_all_scheduled_next_row_0(self, mock_submit): + links = self.create_result_links(num_files=10) + manager = self.create_download_manager() + 
+
+    @patch("concurrent.futures.ThreadPoolExecutor.submit")
+    def test_find_next_file_index_all_scheduled_next_row_0(self, mock_submit):
+        links = self.create_result_links(num_files=10)
+        manager = self.create_download_manager()
+        manager.add_file_links(links)
+        manager._schedule_downloads()
+
+        assert manager._find_next_file_index(0) == 0
+
+    @patch("concurrent.futures.ThreadPoolExecutor.submit")
+    def test_find_next_file_index_all_scheduled_next_row_7999(self, mock_submit):
+        links = self.create_result_links(num_files=10)
+        manager = self.create_download_manager()
+        manager.add_file_links(links)
+        manager._schedule_downloads()
+
+        assert manager._find_next_file_index(7999) is None
+
+    @patch("concurrent.futures.ThreadPoolExecutor.submit")
+    def test_find_next_file_index_all_scheduled_next_row_8000(self, mock_submit):
+        links = self.create_result_links(num_files=10)
+        manager = self.create_download_manager()
+        manager.add_file_links(links)
+        manager._schedule_downloads()
+
+        assert manager._find_next_file_index(8000) == 1
+
+    @patch("concurrent.futures.ThreadPoolExecutor.submit", side_effect=[True, KeyError("foo")])
+    def test_find_next_file_index_one_scheduled_next_row_8000(self, mock_submit):
+        links = self.create_result_links(num_files=10)
+        manager = self.create_download_manager()
+        manager.add_file_links(links)
+        manager._schedule_downloads()
+
+        assert manager._find_next_file_index(8000) is None
+
+    @patch("databricks.sql.cloudfetch.downloader.ResultSetDownloadHandler.is_file_download_successful",
+           return_value=True)
+    @patch("concurrent.futures.ThreadPoolExecutor.submit")
+    def test_check_if_download_successful_happy(self, mock_submit, mock_is_file_download_successful):
+        links = self.create_result_links(num_files=10)
+        manager = self.create_download_manager()
+        manager.add_file_links(links)
+        manager._schedule_downloads()
+
+        status = manager._check_if_download_successful(manager.download_handlers[0])
+        assert status
+        assert manager.num_consecutive_result_file_download_retries == 0
+
+    @patch("databricks.sql.cloudfetch.downloader.ResultSetDownloadHandler.is_file_download_successful",
+           return_value=False)
+    def test_check_if_download_successful_link_expired(self, mock_is_file_download_successful):
+        manager = self.create_download_manager()
+        handler = downloader.ResultSetDownloadHandler(manager.downloadable_result_settings, self.create_result_link())
+        handler.is_link_expired = True
+
+        status = manager._check_if_download_successful(handler)
+        mock_is_file_download_successful.assert_called()
+        assert not status
+        assert manager.fetch_need_retry
+
+    @patch("databricks.sql.cloudfetch.downloader.ResultSetDownloadHandler.is_file_download_successful",
+           return_value=False)
+    def test_check_if_download_successful_download_timed_out_no_retries(self, mock_is_file_download_successful):
+        manager = self.create_download_manager()
+        handler = downloader.ResultSetDownloadHandler(manager.downloadable_result_settings, self.create_result_link())
+        handler.is_download_timedout = True
+
+        status = manager._check_if_download_successful(handler)
+        mock_is_file_download_successful.assert_called()
+        assert not status
+        assert manager.fetch_need_retry
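+
+    # A timed-out download re-submits the handler and re-checks recursively, so with
+    # max_consecutive_file_download_retries=1 the test below expects exactly one
+    # extra submit and two success checks before giving up.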
+
+    @patch("concurrent.futures.ThreadPoolExecutor.submit")
+    @patch("databricks.sql.cloudfetch.downloader.ResultSetDownloadHandler.is_file_download_successful",
+           return_value=False)
+    def test_check_if_download_successful_download_timed_out_1_retry(self, mock_is_file_download_successful, mock_submit):
+        manager = self.create_download_manager()
+        manager.downloadable_result_settings = download_manager.DownloadableResultSettings(
+            is_lz4_compressed=True,
+            download_timeout=0,
+            max_consecutive_file_download_retries=1,
+        )
+        handler = downloader.ResultSetDownloadHandler(manager.downloadable_result_settings, self.create_result_link())
+        handler.is_download_timedout = True
+
+        status = manager._check_if_download_successful(handler)
+        assert mock_is_file_download_successful.call_count == 2
+        assert mock_submit.call_count == 1
+        assert not status
+        assert manager.fetch_need_retry
+
+    @patch("databricks.sql.cloudfetch.downloader.ResultSetDownloadHandler.is_file_download_successful",
+           return_value=False)
+    def test_check_if_download_successful_other_reason(self, mock_is_file_download_successful):
+        manager = self.create_download_manager()
+        handler = downloader.ResultSetDownloadHandler(manager.downloadable_result_settings, self.create_result_link())
+
+        status = manager._check_if_download_successful(handler)
+        mock_is_file_download_successful.assert_called()
+        assert not status
+        assert manager.fetch_need_retry
diff --git a/tests/unit/test_downloader.py b/tests/unit/test_downloader.py
index cee3a83c7..6e13c9496 100644
--- a/tests/unit/test_downloader.py
+++ b/tests/unit/test_downloader.py
@@ -136,7 +136,7 @@ def test_download_timeout(self, mock_time, mock_session):
 
     @patch("threading.Event.wait", return_value=True)
     def test_is_file_download_successful_has_finished(self, mock_wait):
-        for timeout in [None, 0, 1]:
+        for timeout in [0, 1]:
             with self.subTest(timeout=timeout):
                 settings = Mock(download_timeout=timeout)
                 result_link = Mock()