Skip to content
Open
27 changes: 26 additions & 1 deletion dagshub/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,32 @@ def set_host(new_host: str):
recommended_annotate_limit = int(os.environ.get(RECOMMENDED_ANNOTATE_LIMIT_KEY, 1e5))

DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_KEY = "DAGSHUB_DE_METADATA_UPLOAD_BATCH_SIZE"
DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MAX_KEY = "DAGSHUB_DE_METADATA_UPLOAD_BATCH_SIZE_MAX"
# Maximum number of entries in a single metadata upload batch.
# The *_MAX key takes precedence; the older unsuffixed key is kept as a
# fallback for backwards compatibility with existing deployments.
dataengine_metadata_upload_batch_size = int(
    os.environ.get(
        DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MAX_KEY,
        os.environ.get(DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_KEY, 15000),
    )
)

DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MIN_KEY = "DAGSHUB_DE_METADATA_UPLOAD_BATCH_SIZE_MIN"
# Lower bound for the adaptively sized upload batch.
dataengine_metadata_upload_batch_size_min = int(os.environ.get(DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MIN_KEY, 1))

DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_INITIAL_KEY = "DAGSHUB_DE_METADATA_UPLOAD_BATCH_SIZE_INITIAL"
# Size of the first batch attempted; defaults to the configured minimum so
# uploads start conservatively and grow from there.
dataengine_metadata_upload_batch_size_initial = int(
    os.environ.get(DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_INITIAL_KEY, dataengine_metadata_upload_batch_size_min)
)

DATAENGINE_METADATA_UPLOAD_TARGET_BATCH_TIME_KEY = "DAGSHUB_DE_METADATA_UPLOAD_TARGET_BATCH_TIME"
DATAENGINE_METADATA_UPLOAD_TARGET_BATCH_TIME_SECONDS_KEY = "DAGSHUB_DE_METADATA_UPLOAD_TARGET_BATCH_TIME_SECONDS"
# Desired wall-clock duration of a single batch upload, in seconds.
# The *_SECONDS key takes precedence; the older key is a fallback.
dataengine_metadata_upload_target_batch_time_seconds = float(
    os.environ.get(
        DATAENGINE_METADATA_UPLOAD_TARGET_BATCH_TIME_SECONDS_KEY,
        os.environ.get(DATAENGINE_METADATA_UPLOAD_TARGET_BATCH_TIME_KEY, 5.0),
    )
)
# Backwards compatibility for code that imports the old module attribute name.
dataengine_metadata_upload_target_batch_time = dataengine_metadata_upload_target_batch_time_seconds

DISABLE_ANALYTICS_KEY = "DAGSHUB_DISABLE_ANALYTICS"
# Use the key constant instead of repeating the literal (same value as before).
disable_analytics = DISABLE_ANALYTICS_KEY in os.environ
Expand Down
97 changes: 90 additions & 7 deletions dagshub/data_engine/model/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@
run_preupload_transforms,
validate_uploading_metadata,
)
from dagshub.data_engine.model.metadata.upload_batching import (
AdaptiveUploadBatchConfig,
get_retry_delay_seconds,
is_retryable_metadata_upload_error,
next_batch_after_retryable_failure,
next_batch_after_success,
)
from dagshub.data_engine.model.metadata.dtypes import DatapointMetadataUpdateEntry
from dagshub.data_engine.model.metadata.transforms import DatasourceFieldInfo, _add_metadata
from dagshub.data_engine.model.metadata_field_builder import MetadataFieldBuilder
Expand Down Expand Up @@ -755,16 +762,92 @@ def _upload_metadata(self, metadata_entries: List[DatapointMetadataUpdateEntry])

progress = get_rich_progress(rich.progress.MofNCompleteColumn())

upload_batch_size = dagshub.common.config.dataengine_metadata_upload_batch_size
batch_config = AdaptiveUploadBatchConfig.from_values(
max_batch_size=dagshub.common.config.dataengine_metadata_upload_batch_size,
min_batch_size=dagshub.common.config.dataengine_metadata_upload_batch_size_min,
initial_batch_size=dagshub.common.config.dataengine_metadata_upload_batch_size_initial,
target_batch_time_seconds=dagshub.common.config.dataengine_metadata_upload_target_batch_time_seconds,
)
current_batch_size = batch_config.initial_batch_size

total_entries = len(metadata_entries)
total_task = progress.add_task(f"Uploading metadata (batch size {upload_batch_size})...", total=total_entries)
total_task = progress.add_task(
f"Uploading metadata (adaptive batch {batch_config.min_batch_size}-{batch_config.max_batch_size})...",
total=total_entries,
)

last_good_batch_size: Optional[int] = None
last_bad_batch_size: Optional[int] = None
consecutive_retryable_failures = 0

with progress:
for start in range(0, total_entries, upload_batch_size):
entries = metadata_entries[start : start + upload_batch_size]
logger.debug(f"Uploading {len(entries)} metadata entries...")
self.source.client.update_metadata(self, entries)
progress.update(total_task, advance=upload_batch_size)
start = 0
while start < total_entries:
entries_left = total_entries - start
batch_size = min(current_batch_size, entries_left)
entries = metadata_entries[start : start + batch_size]

progress.update(
total_task,
description=f"Uploading metadata (batch size {batch_size})...",
)
logger.debug(f"Uploading {batch_size} metadata entries...")

start_time = time.monotonic()
try:
self.source.client.update_metadata(self, entries)
except Exception as exc:
if not is_retryable_metadata_upload_error(exc):
logger.error("Metadata upload failed with a non-retryable error; aborting.", exc_info=True)
raise

if batch_size <= 1:
logger.error(
f"Metadata upload failed at minimum batch size ({batch_size}); aborting.",
exc_info=True,
)
raise

consecutive_retryable_failures += 1
retry_delay_sec = get_retry_delay_seconds(consecutive_retryable_failures)
time.sleep(retry_delay_sec)

last_bad_batch_size = (
batch_size if last_bad_batch_size is None else min(last_bad_batch_size, batch_size)
)
current_batch_size = next_batch_after_retryable_failure(
batch_size,
batch_config,
last_good_batch_size,
last_bad_batch_size,
)
logger.warning(
f"Metadata upload failed for batch size {batch_size} "
f"({exc.__class__.__name__}: {exc}). Retrying with batch size {current_batch_size}."
)
continue

elapsed = time.monotonic() - start_time
consecutive_retryable_failures = 0
start += batch_size
progress.update(total_task, advance=batch_size)

if elapsed <= batch_config.target_batch_time_seconds:
last_good_batch_size = (
batch_size if last_good_batch_size is None else max(last_good_batch_size, batch_size)
)
current_batch_size = next_batch_after_success(batch_size, batch_config, last_bad_batch_size)
else:
last_bad_batch_size = (
batch_size if last_bad_batch_size is None else min(last_bad_batch_size, batch_size)
)
current_batch_size = next_batch_after_retryable_failure(
batch_size,
batch_config,
last_good_batch_size,
last_bad_batch_size,
)

progress.update(total_task, completed=total_entries, refresh=True)

# Update the status from dagshub, so we get back the new metadata columns
Expand Down
116 changes: 116 additions & 0 deletions dagshub/data_engine/model/metadata/upload_batching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Optional

from gql.transport.exceptions import TransportConnectionFailed, TransportServerError
from requests.exceptions import ConnectionError as RequestsConnectionError
from requests.exceptions import Timeout as RequestsTimeout
from tenacity import wait_exponential

from dagshub.data_engine.model.errors import DataEngineGqlError

# Floor for the target per-batch upload time; guards against zero/negative config.
MIN_TARGET_BATCH_TIME_SECONDS = 0.01
# Multiplicative growth applied to the batch size after a fast, successful batch.
BATCH_GROWTH_FACTOR = 10
# Bounds (seconds) for the exponential backoff between retried uploads.
RETRY_BACKOFF_BASE_SECONDS = 0.25
RETRY_BACKOFF_MAX_SECONDS = 4.0

# Shared tenacity wait strategy: delay grows with the attempt number, clamped
# into [RETRY_BACKOFF_BASE_SECONDS, RETRY_BACKOFF_MAX_SECONDS].
_retry_delay_strategy = wait_exponential(
    multiplier=RETRY_BACKOFF_BASE_SECONDS,
    min=RETRY_BACKOFF_BASE_SECONDS,
    max=RETRY_BACKOFF_MAX_SECONDS,
)


@dataclass(frozen=True)
class AdaptiveUploadBatchConfig:
    """Normalized settings that drive adaptive sizing of metadata upload batches."""

    # Hard ceiling for a single batch (>= 1 after normalization).
    max_batch_size: int
    # Floor for the adaptive size (1 <= min <= max after normalization).
    min_batch_size: int
    # Size of the first batch attempted (min <= initial <= max after normalization).
    initial_batch_size: int
    # Desired wall-clock duration of one batch upload, in seconds.
    target_batch_time_seconds: float

    @classmethod
    def from_values(
        cls,
        max_batch_size: int,
        min_batch_size: int,
        initial_batch_size: int,
        target_batch_time_seconds: float,
    ) -> "AdaptiveUploadBatchConfig":
        """Build a config, clamping all values into a mutually consistent range."""
        ceiling = max_batch_size if max_batch_size > 1 else 1
        floor = min_batch_size
        if floor > ceiling:
            floor = ceiling
        if floor < 1:
            floor = 1
        first = initial_batch_size
        if first > ceiling:
            first = ceiling
        if first < floor:
            first = floor
        target = target_batch_time_seconds
        if target < MIN_TARGET_BATCH_TIME_SECONDS:
            target = MIN_TARGET_BATCH_TIME_SECONDS
        return cls(
            max_batch_size=ceiling,
            min_batch_size=floor,
            initial_batch_size=first,
            target_batch_time_seconds=target,
        )


def _midpoint(lower_bound: int, upper_bound: int) -> int:
return lower_bound + max(1, (upper_bound - lower_bound) // 2)


def next_batch_after_success(
    batch_size: int,
    config: AdaptiveUploadBatchConfig,
    bad_batch_size: Optional[int],
) -> int:
    """Return the batch size to try after ``batch_size`` entries uploaded successfully.

    Grows multiplicatively while no failing size is known; once a failing size
    exists, probes toward it (staying strictly below it) like a binary search.
    The result is clamped into ``[config.min_batch_size, config.max_batch_size]``.
    """
    if bad_batch_size is not None and batch_size < bad_batch_size:
        # A failing size is known: probe the midpoint between the current
        # (working) size and the known-bad size, staying strictly below it.
        next_batch_size = _midpoint(batch_size, bad_batch_size)
        next_batch_size = min(next_batch_size, bad_batch_size - 1)
    else:
        # No known-bad ceiling yet: grow aggressively.
        next_batch_size = batch_size * BATCH_GROWTH_FACTOR

    next_batch_size = min(config.max_batch_size, next_batch_size)
    if next_batch_size <= batch_size and batch_size < config.max_batch_size:
        # Guarantee forward progress: always try at least one entry more than
        # the size that just succeeded, as long as the ceiling allows it.
        next_batch_size = min(config.max_batch_size, batch_size + 1)
        if bad_batch_size is not None:
            # ...but never step back up to a size that is known to fail.
            next_batch_size = min(next_batch_size, bad_batch_size - 1)

    # NOTE(review): this final clamp can raise the result back to
    # min_batch_size even when min_batch_size >= bad_batch_size — confirm
    # that retrying a known-bad size in that corner case is intended.
    return max(config.min_batch_size, next_batch_size)


def next_batch_after_retryable_failure(
    batch_size: int,
    config: AdaptiveUploadBatchConfig,
    good_batch_size: Optional[int],
    bad_batch_size: Optional[int],
) -> int:
    """Return the (strictly smaller) batch size to retry with after a retryable failure.

    Bisects between the largest known-good size and the failing size when both
    are known; otherwise simply halves. The result is always >= 1 and strictly
    below both the failed size and any known-bad size.
    """
    if batch_size <= 1:
        return 1

    # Ceiling of the search window: the failed size, tightened by any
    # previously recorded bad size.
    ceiling = batch_size
    if bad_batch_size is not None and bad_batch_size < ceiling:
        ceiling = bad_batch_size

    if good_batch_size is not None and good_batch_size < ceiling:
        # Bisect between the known-good floor and the ceiling.
        candidate = _midpoint(good_batch_size, ceiling)
    else:
        # No usable floor: halve the failed size.
        candidate = batch_size // 2

    # Stay strictly below the window ceiling and the failed size, and within
    # the configured maximum.
    for cap in (ceiling - 1, batch_size - 1, config.max_batch_size):
        if candidate > cap:
            candidate = cap
    return candidate if candidate > 1 else 1


def is_retryable_metadata_upload_error(exc: Exception) -> bool:
    """Return True when ``exc`` represents a transient failure worth retrying.

    ``DataEngineGqlError`` wraps the underlying transport exception, so the
    wrapped exception is classified recursively. Previously only two wrapped
    transport types were recognized, which made wrapped timeout/connection
    errors (retryable when raised directly) abort the upload instead of
    retrying.
    """
    if isinstance(exc, DataEngineGqlError):
        original_exception = exc.original_exception
        # Recurse so every retryable type below is also honored when wrapped.
        return isinstance(original_exception, Exception) and is_retryable_metadata_upload_error(original_exception)

    return isinstance(
        exc,
        (
            TransportServerError,
            TransportConnectionFailed,
            TimeoutError,
            ConnectionError,
            RequestsConnectionError,
            RequestsTimeout,
        ),
    )


def get_retry_delay_seconds(consecutive_retryable_failures: int) -> float:
    """Return the backoff delay (seconds) for the given streak of retryable failures."""
    # tenacity wait strategies expect a retry_state with an attempt_number;
    # a minimal stand-in object is enough. Attempt numbers start at 1.
    attempt = consecutive_retryable_failures if consecutive_retryable_failures > 1 else 1
    fake_retry_state = SimpleNamespace(attempt_number=attempt)
    delay = _retry_delay_strategy(fake_retry_state)
    return float(delay)
Loading