64 changes: 64 additions & 0 deletions adaptdl/adaptdl/retry.py
@@ -0,0 +1,64 @@
# Copyright 2020 Petuum, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import functools
import logging

import adaptdl.checkpoint
import adaptdl.env
from adaptdl.torch._metrics import _report_sched_hints
from adaptdl.torch.data import current_dataloader

logging.basicConfig(level=logging.INFO)
LOG = logging.getLogger(__name__)
LOG.setLevel(logging.INFO)


# Percentage of current_local_bsz used to decide upper bound on
# local_bsz_bounds after OOM
LOCAL_BSZ_CUTOFF_PCT = 0.1


def cudaoom(e):
    return "CUDA out of memory" in str(e)


def retry(func):
    @functools.wraps(func)
    def inner(*args, **kwargs):
        try:
            func(*args, **kwargs)
        except RuntimeError as e:
            LOG.info(f"{e}")
            dataloader = current_dataloader()
            if (dataloader is not None and
                    dataloader.local_bsz_bounds is not None and
                    cudaoom(e)):
                current_local_bsz = dataloader.current_local_bsz
                low, high = dataloader.local_bsz_bounds
                assert current_local_bsz <= high
                new_high = int((1. - LOCAL_BSZ_CUTOFF_PCT) * current_local_bsz)
                if new_high < low:
                    raise e
                dataloader.local_bsz_bounds = (low, new_high)
                LOG.info(f"Local batch size bounds changed to "
                         f"{dataloader.local_bsz_bounds}")
                if adaptdl.env.replica_rank() == 0:
                    _report_sched_hints()
                adaptdl.checkpoint.save_all_states()
                exit(143)
            else:
                raise e
    return inner
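
For a sense of the bound-tightening arithmetic above, here is a minimal, self-contained sketch; the helper name shrink_upper_bound is illustrative and not part of this pull request:

LOCAL_BSZ_CUTOFF_PCT = 0.1

def shrink_upper_bound(current_local_bsz, low, high):
    # Mirrors retry(): cap the new upper bound 10% below the local batch
    # size that just hit CUDA OOM; give up (re-raise) once the cap would
    # fall below the lower bound.
    assert current_local_bsz <= high
    new_high = int((1. - LOCAL_BSZ_CUTOFF_PCT) * current_local_bsz)
    if new_high < low:
        raise RuntimeError("local_bsz_bounds cannot be tightened further")
    return (low, new_high)

print(shrink_upper_bound(128, 32, 256))  # -> (32, 115)

After updating the bounds, the decorator reports scheduling hints from rank 0, checkpoints all states, and exits with code 143 so the job is restarted under the tighter bounds.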
38 changes: 25 additions & 13 deletions adaptdl/adaptdl/torch/data.py
@@ -142,7 +142,6 @@ class AdaptiveDataLoaderHelper(object):
     def __init__(self, batch_size=1):
         # Autoscale batch size fields.
         self._max_batch_size = None
-        self._local_bsz_bounds = None
         # Create and load state.
         self._state = _AdaptiveDataLoaderState()
         adaptdl.checkpoint.load_state(self._state)
@@ -198,7 +197,11 @@ def local_bsz_bounds(self):
         The local batch size bounds on each replica. A pair of integers,
         (min_local_bsz, max_local_bsz).
         """
-        return self._local_bsz_bounds
+        return self._state.local_bsz_bounds
+
+    @local_bsz_bounds.setter
+    def local_bsz_bounds(self, bounds):
+        self._state.local_bsz_bounds = bounds
 
     @property
     def current_local_bsz(self):
@@ -263,7 +266,8 @@ def autoscale_batch_size(self, max_batch_size, local_bsz_bounds=None,
                 local_bsz_bounds[1] < self.batch_size):
             raise ValueError("invalid local_bsz_bounds")
         self._max_batch_size = max_batch_size
-        self._local_bsz_bounds = local_bsz_bounds
+        if self.local_bsz_bounds is None:
+            self.local_bsz_bounds = local_bsz_bounds
         self._gradient_accumulation = gradient_accumulation
         self.train()
 
@@ -279,7 +283,7 @@ def _sync_local_bsz(self):
             _, atomic_bsz, accum_steps = goodput_fn.optimize(
                 adaptdl.env.num_nodes(), adaptdl.env.num_replicas(),
                 max_batch_size=self._max_batch_size,
-                atomic_bsz_range=self._local_bsz_bounds,
+                atomic_bsz_range=self.local_bsz_bounds,
                 accumulation=self._gradient_accumulation)
             self._state.current_local_bsz = atomic_bsz
             self._state.accumulation_steps = accum_steps
@@ -288,7 +292,7 @@ def _sync_local_bsz(self):
             suggest_goodput, atomic_bsz, accum_steps = goodput_fn.optimize(
                 adaptdl.env.num_nodes(), adaptdl.env.num_replicas(),
                 max_batch_size=self._max_batch_size,
-                atomic_bsz_range=self._local_bsz_bounds,
+                atomic_bsz_range=self.local_bsz_bounds,
                 accumulation=self._gradient_accumulation)
             # get current goodput
             current_goodput = goodput_fn(
@@ -299,6 +303,7 @@ def _sync_local_bsz(self):
             if speedup > self._speedup_threshold:
                 self._state.current_local_bsz = atomic_bsz
                 self._state.accumulation_steps = accum_steps
+
         self._state.current_local_bsz, self._state.accumulation_steps = \
             adaptdl.collective.broadcast((self._state.current_local_bsz,
                                           self._state.accumulation_steps))
@@ -340,18 +345,23 @@ def context(self):
         proper cleanup of elastic context at the end of each epoch.
         """
         epoch = current_epoch()
+        exception = False
         try:
             if AdaptiveDataLoaderHelper._current is not None:
                 raise RuntimeError("overlapping dataloader \
                                     iterations detected")
             AdaptiveDataLoaderHelper._current = self
             yield
+        except GeneratorExit:
+            # Generic Exception outside of the dataloader
+            exception = True
         finally:
-            self._state.current_index = 0
-            self._state.end_index = 0
-            self._state.last_position[epoch] = self._position[epoch]
-            self._position[epoch] += 1
-            AdaptiveDataLoaderHelper._current = None
+            if not exception:
+                self._state.current_index = 0
+                self._state.end_index = 0
+                self._state.last_position[epoch] = self._position[epoch]
+                self._position[epoch] += 1
+                AdaptiveDataLoaderHelper._current = None
 
     @property
     def current_batch_size(self):
@@ -490,6 +500,7 @@ class AdaptiveDataLoader(DataLoader, AdaptiveDataLoaderMixin):
 
     .. automethod:: __iter__
     """
+
     def __init__(self, dataset, batch_size=1, shuffle=False, **kwargs):
         if kwargs.get("batch_sampler") is not None \
                 or kwargs.get("sampler") is not None:
@@ -563,13 +574,14 @@ def __init__(self):
         self.current_index = 0  # Index within the current dataloader loop.
         self.end_index = 0  # End index of the current DataLoader loop.
         self.last_position = {}  # Epoch -> position of last completed loop.
+        self.local_bsz_bounds = None
         self.current_local_bsz = 0
         self.accumulation_steps = 0
 
     def save(self, fileobj):
         pickle.dump((self.current_index, self.end_index,
-                     self.last_position), fileobj)
+                     self.last_position, self.local_bsz_bounds), fileobj)
 
     def load(self, fileobj):
-        self.current_index, self.end_index, self.last_position = \
-            pickle.load(fileobj)
+        self.current_index, self.end_index, self.last_position, \
+            self.local_bsz_bounds = pickle.load(fileobj)
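
Moving local_bsz_bounds into _AdaptiveDataLoaderState means the tightened bounds are pickled with the checkpoint and therefore survive the exit(143)/restart cycle triggered by retry. A minimal stand-in (not the real class) showing the save/load round trip under that assumption:

import io
import pickle

class _StateSketch:
    # Illustrative stand-in for _AdaptiveDataLoaderState.
    def __init__(self):
        self.current_index = 0
        self.end_index = 0
        self.last_position = {}
        self.local_bsz_bounds = None

    def save(self, fileobj):
        pickle.dump((self.current_index, self.end_index,
                     self.last_position, self.local_bsz_bounds), fileobj)

    def load(self, fileobj):
        (self.current_index, self.end_index,
         self.last_position, self.local_bsz_bounds) = pickle.load(fileobj)

state = _StateSketch()
state.local_bsz_bounds = (32, 115)  # bounds tightened after an OOM
buf = io.BytesIO()
state.save(buf)
buf.seek(0)
restored = _StateSketch()
restored.load(buf)
assert restored.local_bsz_bounds == (32, 115)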
2 changes: 2 additions & 0 deletions examples/pytorch-cifar/main.py
@@ -30,6 +30,7 @@
 
 import adaptdl
 import adaptdl.torch as adl
+from adaptdl.retry import retry
 
 from torch.optim.lr_scheduler import MultiStepLR
 from torch.utils.tensorboard import SummaryWriter
@@ -99,6 +100,7 @@
 
 
 # Training
+@retry
 def train(epoch):
     print('\nEpoch: %d' % epoch)
     net.train()
7 changes: 6 additions & 1 deletion sched/adaptdl_sched/controller.py
@@ -114,8 +114,10 @@ async def _sync_job(self, namespace, job_name):
         replicas = job["status"].get("replicas", 0)
         preemptible = job["spec"].get("preemptible", True)
         if (completion_status := self._detect_completion(pods, preemptible)):
-            # Job is already completed.
             job["status"].update(completion_status)
+            phase = job["status"]["phase"]
+            if phase in ("Succeeded", "Failed"):
+                # Job is already completed.
                 job["status"].setdefault("completionTimestamp", current_ts)
             job["status"]["allocation"] = allocation = []
             await self._delete_pods(  # Keep failed pods for debug purposes.
@@ -294,14 +296,17 @@ def any143(pod):
                 # resources before this pod could bind to that node.
                 LOG.warning("UnexpectedAdmissionError for pod %s: %s",
                             pod.metadata.name, pod.status.message)
+                return {"phase": "Stopping"}
             elif str(pod.status.reason).startswith("Outof"):
                 # we might be temporarily out of pods on this node
                 LOG.warning(f"Pod {pod.metadata.name} is {pod.status.reason} "
                             f"on {pod.spec.node_name}")
+                return {"phase": "Stopping"}
             elif preemptible and (pod.metadata.deletion_timestamp is not None
                                   or any143(pod)):
                 # This pod was intentionally terminated.
                 LOG.warning(f"Pod {pod.metadata.name} terminated")
+                return {"phase": "Stopping"}
             else:
                 return {"phase": "Failed", "reason": "PodFailure",
                         "message": f"{pod.metadata.name} {pod.status.phase}"}