diff --git a/dagshub/common/config.py b/dagshub/common/config.py
index e82b0063..9a2e75f6 100644
--- a/dagshub/common/config.py
+++ b/dagshub/common/config.py
@@ -58,7 +58,32 @@ def set_host(new_host: str):
 recommended_annotate_limit = int(os.environ.get(RECOMMENDED_ANNOTATE_LIMIT_KEY, 1e5))
 
 DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_KEY = "DAGSHUB_DE_METADATA_UPLOAD_BATCH_SIZE"
-dataengine_metadata_upload_batch_size = int(os.environ.get(DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_KEY, 15000))
+DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MAX_KEY = "DAGSHUB_DE_METADATA_UPLOAD_BATCH_SIZE_MAX"
+dataengine_metadata_upload_batch_size = int(
+    os.environ.get(
+        DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MAX_KEY,
+        os.environ.get(DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_KEY, 15000),
+    )
+)
+
+DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MIN_KEY = "DAGSHUB_DE_METADATA_UPLOAD_BATCH_SIZE_MIN"
+dataengine_metadata_upload_batch_size_min = int(os.environ.get(DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MIN_KEY, 1))
+
+DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_INITIAL_KEY = "DAGSHUB_DE_METADATA_UPLOAD_BATCH_SIZE_INITIAL"
+dataengine_metadata_upload_batch_size_initial = int(
+    os.environ.get(DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_INITIAL_KEY, dataengine_metadata_upload_batch_size_min)
+)
+
+DATAENGINE_METADATA_UPLOAD_TARGET_BATCH_TIME_KEY = "DAGSHUB_DE_METADATA_UPLOAD_TARGET_BATCH_TIME"
+DATAENGINE_METADATA_UPLOAD_TARGET_BATCH_TIME_SECONDS_KEY = "DAGSHUB_DE_METADATA_UPLOAD_TARGET_BATCH_TIME_SECONDS"
+dataengine_metadata_upload_target_batch_time_seconds = float(
+    os.environ.get(
+        DATAENGINE_METADATA_UPLOAD_TARGET_BATCH_TIME_SECONDS_KEY,
+        os.environ.get(DATAENGINE_METADATA_UPLOAD_TARGET_BATCH_TIME_KEY, 5.0),
+    )
+)
+# Backwards compatibility for code that imports the old module attribute name.
+dataengine_metadata_upload_target_batch_time = dataengine_metadata_upload_target_batch_time_seconds
 
 DISABLE_ANALYTICS_KEY = "DAGSHUB_DISABLE_ANALYTICS"
 disable_analytics = "DAGSHUB_DISABLE_ANALYTICS" in os.environ
diff --git a/dagshub/data_engine/model/datasource.py b/dagshub/data_engine/model/datasource.py
index 255bb76d..447f58ab 100644
--- a/dagshub/data_engine/model/datasource.py
+++ b/dagshub/data_engine/model/datasource.py
@@ -53,6 +53,13 @@
     run_preupload_transforms,
     validate_uploading_metadata,
 )
+from dagshub.data_engine.model.metadata.upload_batching import (
+    AdaptiveUploadBatchConfig,
+    get_retry_delay_seconds,
+    is_retryable_metadata_upload_error,
+    next_batch_after_retryable_failure,
+    next_batch_after_success,
+)
 from dagshub.data_engine.model.metadata.dtypes import DatapointMetadataUpdateEntry
 from dagshub.data_engine.model.metadata.transforms import DatasourceFieldInfo, _add_metadata
 from dagshub.data_engine.model.metadata_field_builder import MetadataFieldBuilder
@@ -755,16 +762,92 @@ def _upload_metadata(self, metadata_entries: List[DatapointMetadataUpdateEntry])
 
         progress = get_rich_progress(rich.progress.MofNCompleteColumn())
 
-        upload_batch_size = dagshub.common.config.dataengine_metadata_upload_batch_size
+        batch_config = AdaptiveUploadBatchConfig.from_values(
+            max_batch_size=dagshub.common.config.dataengine_metadata_upload_batch_size,
+            min_batch_size=dagshub.common.config.dataengine_metadata_upload_batch_size_min,
+            initial_batch_size=dagshub.common.config.dataengine_metadata_upload_batch_size_initial,
+            target_batch_time_seconds=dagshub.common.config.dataengine_metadata_upload_target_batch_time_seconds,
+        )
+        current_batch_size = batch_config.initial_batch_size
+
         total_entries = len(metadata_entries)
-        total_task = progress.add_task(f"Uploading metadata (batch size {upload_batch_size})...", total=total_entries)
+        total_task = progress.add_task(
+            f"Uploading metadata (adaptive batch {batch_config.min_batch_size}-{batch_config.max_batch_size})...",
+            total=total_entries,
+        )
+
+        last_good_batch_size: Optional[int] = None
+        last_bad_batch_size: Optional[int] = None
+        consecutive_retryable_failures = 0
 
         with progress:
-            for start in range(0, total_entries, upload_batch_size):
-                entries = metadata_entries[start : start + upload_batch_size]
-                logger.debug(f"Uploading {len(entries)} metadata entries...")
-                self.source.client.update_metadata(self, entries)
-                progress.update(total_task, advance=upload_batch_size)
+            start = 0
+            while start < total_entries:
+                entries_left = total_entries - start
+                batch_size = min(current_batch_size, entries_left)
+                entries = metadata_entries[start : start + batch_size]
+
+                progress.update(
+                    total_task,
+                    description=f"Uploading metadata (batch size {batch_size})...",
+                )
+                logger.debug(f"Uploading {batch_size} metadata entries...")
+
+                start_time = time.monotonic()
+                try:
+                    self.source.client.update_metadata(self, entries)
+                except Exception as exc:
+                    if not is_retryable_metadata_upload_error(exc):
+                        logger.error("Metadata upload failed with a non-retryable error; aborting.", exc_info=True)
+                        raise
+
+                    if batch_size <= 1:
+                        logger.error(
+                            f"Metadata upload failed at minimum batch size ({batch_size}); aborting.",
+                            exc_info=True,
+                        )
+                        raise
+
+                    consecutive_retryable_failures += 1
+                    retry_delay_sec = get_retry_delay_seconds(consecutive_retryable_failures)
+                    time.sleep(retry_delay_sec)
+
+                    last_bad_batch_size = (
+                        batch_size if last_bad_batch_size is None else min(last_bad_batch_size, batch_size)
+                    )
+                    current_batch_size = next_batch_after_retryable_failure(
+                        batch_size,
+                        batch_config,
+                        last_good_batch_size,
+                        last_bad_batch_size,
+                    )
+                    logger.warning(
+                        f"Metadata upload failed for batch size {batch_size} "
+                        f"({exc.__class__.__name__}: {exc}). Retrying with batch size {current_batch_size}."
+                    )
+                    continue
+
+                elapsed = time.monotonic() - start_time
+                consecutive_retryable_failures = 0
+                start += batch_size
+                progress.update(total_task, advance=batch_size)
+
+                if elapsed <= batch_config.target_batch_time_seconds:
+                    last_good_batch_size = (
+                        batch_size if last_good_batch_size is None else max(last_good_batch_size, batch_size)
+                    )
+                    current_batch_size = next_batch_after_success(batch_size, batch_config, last_bad_batch_size)
+                else:
+                    last_bad_batch_size = (
+                        batch_size if last_bad_batch_size is None else min(last_bad_batch_size, batch_size)
+                    )
+                    current_batch_size = next_batch_after_retryable_failure(
+                        batch_size,
+                        batch_config,
+                        last_good_batch_size,
+                        last_bad_batch_size,
+                    )
+
             progress.update(total_task, completed=total_entries, refresh=True)
 
         # Update the status from dagshub, so we get back the new metadata columns
diff --git a/dagshub/data_engine/model/metadata/upload_batching.py b/dagshub/data_engine/model/metadata/upload_batching.py
new file mode 100644
index 00000000..6394ff03
--- /dev/null
+++ b/dagshub/data_engine/model/metadata/upload_batching.py
@@ -0,0 +1,116 @@
+from dataclasses import dataclass
+from types import SimpleNamespace
+from typing import Optional
+
+from gql.transport.exceptions import TransportConnectionFailed, TransportServerError
+from requests.exceptions import ConnectionError as RequestsConnectionError
+from requests.exceptions import Timeout as RequestsTimeout
+from tenacity import wait_exponential
+
+from dagshub.data_engine.model.errors import DataEngineGqlError
+
+MIN_TARGET_BATCH_TIME_SECONDS = 0.01
+BATCH_GROWTH_FACTOR = 10
+RETRY_BACKOFF_BASE_SECONDS = 0.25
+RETRY_BACKOFF_MAX_SECONDS = 4.0
+
+_retry_delay_strategy = wait_exponential(
+    multiplier=RETRY_BACKOFF_BASE_SECONDS,
+    min=RETRY_BACKOFF_BASE_SECONDS,
+    max=RETRY_BACKOFF_MAX_SECONDS,
+)
+
+
+@dataclass(frozen=True)
+class AdaptiveUploadBatchConfig:
+    max_batch_size: int
+    min_batch_size: int
+    initial_batch_size: int
+    target_batch_time_seconds: float
+
+    @classmethod
+    def from_values(
+        cls,
+        max_batch_size: int,
+        min_batch_size: int,
+        initial_batch_size: int,
+        target_batch_time_seconds: float,
+    ) -> "AdaptiveUploadBatchConfig":
+        normalized_max_batch_size = max(1, max_batch_size)
+        normalized_min_batch_size = max(1, min(min_batch_size, normalized_max_batch_size))
+        normalized_initial_batch_size = max(
+            normalized_min_batch_size,
+            min(initial_batch_size, normalized_max_batch_size),
+        )
+        normalized_target_batch_time_seconds = max(target_batch_time_seconds, MIN_TARGET_BATCH_TIME_SECONDS)
+        return cls(
+            max_batch_size=normalized_max_batch_size,
+            min_batch_size=normalized_min_batch_size,
+            initial_batch_size=normalized_initial_batch_size,
+            target_batch_time_seconds=normalized_target_batch_time_seconds,
+        )
+
+
+def _midpoint(lower_bound: int, upper_bound: int) -> int:
+    return lower_bound + max(1, (upper_bound - lower_bound) // 2)
+
+
+def next_batch_after_success(
+    batch_size: int,
+    config: AdaptiveUploadBatchConfig,
+    bad_batch_size: Optional[int],
+) -> int:
+    if bad_batch_size is not None and batch_size < bad_batch_size:
+        next_batch_size = _midpoint(batch_size, bad_batch_size)
+        next_batch_size = min(next_batch_size, bad_batch_size - 1)
+    else:
+        next_batch_size = batch_size * BATCH_GROWTH_FACTOR
+
+    next_batch_size = min(config.max_batch_size, next_batch_size)
+    if next_batch_size <= batch_size and batch_size < config.max_batch_size:
+        next_batch_size = min(config.max_batch_size, batch_size + 1)
+    if bad_batch_size is not None:
+        next_batch_size = min(next_batch_size, bad_batch_size - 1)
+
+    return max(config.min_batch_size, next_batch_size)
+
+
+def next_batch_after_retryable_failure(
+    batch_size: int,
+    config: AdaptiveUploadBatchConfig,
+    good_batch_size: Optional[int],
+    bad_batch_size: Optional[int],
+) -> int:
+    if batch_size <= 1:
+        return 1
+
+    upper_bound = min(batch_size, bad_batch_size) if bad_batch_size is not None else batch_size
+    if good_batch_size is not None and good_batch_size < upper_bound:
+        next_batch_size = _midpoint(good_batch_size, upper_bound)
+    else:
+        next_batch_size = batch_size // 2
+
+    next_batch_size = min(next_batch_size, upper_bound - 1, batch_size - 1, config.max_batch_size)
+    return max(1, next_batch_size)
+
+
+def is_retryable_metadata_upload_error(exc: Exception) -> bool:
+    if isinstance(exc, DataEngineGqlError):
+        return isinstance(exc.original_exception, (TransportServerError, TransportConnectionFailed))
+
+    return isinstance(
+        exc,
+        (
+            TransportServerError,
+            TransportConnectionFailed,
+            TimeoutError,
+            ConnectionError,
+            RequestsConnectionError,
+            RequestsTimeout,
+        ),
+    )
+
+
+def get_retry_delay_seconds(consecutive_retryable_failures: int) -> float:
+    retry_state = SimpleNamespace(attempt_number=max(1, consecutive_retryable_failures))
+    return float(_retry_delay_strategy(retry_state))
diff --git a/tests/data_engine/test_datasource.py b/tests/data_engine/test_datasource.py
index bd1f1912..5fe2ea00 100644
--- a/tests/data_engine/test_datasource.py
+++ b/tests/data_engine/test_datasource.py
@@ -8,6 +8,7 @@
 import pandas as pd
 import pytest
 
+import dagshub.common.config
 from dagshub.common.util import wrap_bytes
 from dagshub.data_engine.annotation import MetadataAnnotations
 from dagshub.data_engine.client.models import MetadataFieldSchema
@@ -19,6 +20,10 @@
 from tests.data_engine.util import add_string_fields, add_document_fields, add_annotation_fields
 
 
+def _uploaded_batch_sizes(ds: Datasource):
+    return [len(call.args[1]) for call in ds.source.client.update_metadata.call_args_list]
+
+
 @pytest.fixture
 def metadata_df():
     data_dict = {
@@ -142,6 +147,176 @@ def test_uploading_to_document_turns_into_blob(ds):
     client_mock.update_metadata.assert_called_with(ds, expected_data_upload)
 
 
+def test_upload_metadata_starts_small_and_grows(ds, mocker):
+    entries = [
+        DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(14)
+    ]
+
+    mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size", 16)
+    mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2)
+    mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 2)
+    mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time_seconds", 1000.0)
+
+    ds._upload_metadata(entries)
+
+    assert _uploaded_batch_sizes(ds) == [2, 12]
+
+
+def test_upload_metadata_retries_with_smaller_batch_after_failure(ds, mocker):
+    entries = [
+        DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(10)
+    ]
+
+    mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size", 8)
+    mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2)
+    mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 8)
+    mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time_seconds", 1000.0)
+    mocker.patch("dagshub.data_engine.model.datasource.time.sleep", return_value=None)
+
+    has_failed = {"value": False}
+
+    def _flaky_upload(_ds, upload_entries):
+        if len(upload_entries) == 8 and not has_failed["value"]:
+            has_failed["value"] = True
+            raise TimeoutError("simulated timeout")
+
+    ds.source.client.update_metadata.side_effect = _flaky_upload
+
+    ds._upload_metadata(entries)
+
+    assert has_failed["value"]
+    assert _uploaded_batch_sizes(ds) == [8, 4, 6]
+
+
+def test_upload_metadata_does_not_retry_known_bad_batch_size(ds, mocker):
+    entries = [
+        DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(32)
+    ]
+
+    mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size", 16)
+    mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2)
+    mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 8)
+    mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time_seconds", 1000.0)
+    mocker.patch("dagshub.data_engine.model.datasource.time.sleep", return_value=None)
+
+    has_failed = {"value": False}
+
+    def _flaky_upload(_ds, upload_entries):
+        if len(upload_entries) == 8 and not has_failed["value"]:
+            has_failed["value"] = True
+            raise TimeoutError("simulated timeout")
+
+    ds.source.client.update_metadata.side_effect = _flaky_upload
+
+    ds._upload_metadata(entries)
+
+    assert has_failed["value"]
+    assert _uploaded_batch_sizes(ds) == [8, 4, 6, 7, 7, 7, 1]
+
+
+def test_upload_metadata_slow_success_reduces_batch_size(ds, mocker):
+    entries = [
+        DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(12)
+    ]
+
+    mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size", 8)
+    mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2)
+    mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 8)
+    mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time_seconds", 1.0)
+    mocker.patch("dagshub.data_engine.model.datasource.time.monotonic", side_effect=[0.0, 2.0, 3.0, 3.1])
+
+    ds._upload_metadata(entries)
+
+    assert _uploaded_batch_sizes(ds) == [8, 4]
+
+
+def test_upload_metadata_non_retryable_error_does_not_retry(ds, mocker):
+    entries = [
+        DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(10)
+    ]
+
+    mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size", 8)
+    mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2)
+    mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 8)
+    mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time_seconds", 1000.0)
+    ds.source.client.update_metadata.side_effect = ValueError("simulated validation error")
+
+    with pytest.raises(ValueError, match="simulated validation error"):
+        ds._upload_metadata(entries)
+
+    assert _uploaded_batch_sizes(ds) == [8]
+
+
+def test_upload_metadata_retries_partial_batch_below_min(ds, mocker):
+    entries = [
+        DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(10)
+    ]
+
+    mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size", 8)
+    mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 4)
+    mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 8)
+    mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time_seconds", 1000.0)
+    mocker.patch("dagshub.data_engine.model.datasource.time.sleep", return_value=None)
+
+    has_failed = {"value": False}
+
+    def _flaky_upload(_ds, upload_entries):
+        if len(upload_entries) == 2 and not has_failed["value"]:
+            has_failed["value"] = True
+            raise TimeoutError("simulated timeout")
+
+    ds.source.client.update_metadata.side_effect = _flaky_upload
+
+    ds._upload_metadata(entries)
+
+    assert has_failed["value"]
+    assert _uploaded_batch_sizes(ds) == [8, 2, 1, 1]
+
+
+def test_upload_metadata_backoff_resets_after_success(ds, mocker):
+    entries = [
+        DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(12)
+    ]
+
+    mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size", 8)
+    mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2)
+    mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 8)
+    mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time_seconds", 1000.0)
+    sleep_mock = mocker.patch("dagshub.data_engine.model.datasource.time.sleep")
+
+    call_idx = {"value": 0}
+
+    def _flaky_upload(_ds, _upload_entries):
+        call_idx["value"] += 1
+        if call_idx["value"] in {1, 3}:
+            raise TimeoutError("simulated timeout")
+
+    ds.source.client.update_metadata.side_effect = _flaky_upload
+
+    ds._upload_metadata(entries)
+
+    assert [c.args[0] for c in sleep_mock.call_args_list] == [0.25, 0.25]
+
+
+def test_upload_metadata_retries_below_configured_min_before_aborting(ds, mocker):
+    entries = [
+        DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(6)
+    ]
+
+    mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size", 8)
+    mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2)
+    mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 2)
+    mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time_seconds", 1000.0)
+    sleep_mock = mocker.patch("dagshub.data_engine.model.datasource.time.sleep")
+    ds.source.client.update_metadata.side_effect = TimeoutError("simulated timeout")
+
+    with pytest.raises(TimeoutError, match="simulated timeout"):
+        ds._upload_metadata(entries)
+
+    assert _uploaded_batch_sizes(ds) == [2, 1]
+    assert [c.args[0] for c in sleep_mock.call_args_list] == [0.25]
+
+
 def test_pandas_timestamp(ds):
     data_dict = {
         "file": ["test1", "test2"],