diff --git a/adaptdl/adaptdl/retry.py b/adaptdl/adaptdl/retry.py
new file mode 100644
index 00000000..31ae7ca5
--- /dev/null
+++ b/adaptdl/adaptdl/retry.py
@@ -0,0 +1,64 @@
+# Copyright 2020 Petuum, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import functools
+import logging
+
+import adaptdl.checkpoint
+import adaptdl.env
+from adaptdl.torch._metrics import _report_sched_hints
+from adaptdl.torch.data import current_dataloader
+
+logging.basicConfig(level=logging.INFO)
+LOG = logging.getLogger(__name__)
+LOG.setLevel(logging.INFO)
+
+
+# Percentage of current_local_bsz used to decide upper bound on
+# local_bsz_bounds after OOM
+LOCAL_BSZ_CUTOFF_PCT = 0.1
+
+
+def cudaoom(e):
+    return "CUDA out of memory" in str(e)
+
+
+def retry(func):
+    @functools.wraps(func)
+    def inner(*args, **kwargs):
+        try:
+            func(*args, **kwargs)
+        except RuntimeError as e:
+            LOG.info(f"{e}")
+            dataloader = current_dataloader()
+            if (dataloader is not None and
+                    dataloader.local_bsz_bounds is not None and
+                    cudaoom(e)):
+                current_local_bsz = dataloader.current_local_bsz
+                low, high = dataloader.local_bsz_bounds
+                assert current_local_bsz <= high
+                new_high = int((1. - LOCAL_BSZ_CUTOFF_PCT) * current_local_bsz)
+                if new_high < low:
+                    raise e
+                dataloader.local_bsz_bounds = (low, new_high)
+                LOG.info(f"Local batch size bounds changed to "
+                         f"{dataloader.local_bsz_bounds}")
+                if adaptdl.env.replica_rank() == 0:
+                    _report_sched_hints()
+                adaptdl.checkpoint.save_all_states()
+                exit(143)
+            else:
+                raise e
+    return inner
diff --git a/adaptdl/adaptdl/torch/data.py b/adaptdl/adaptdl/torch/data.py
index 90a8767a..bc8e8a7a 100644
--- a/adaptdl/adaptdl/torch/data.py
+++ b/adaptdl/adaptdl/torch/data.py
@@ -142,7 +142,6 @@ class AdaptiveDataLoaderHelper(object):
     def __init__(self, batch_size=1):
         # Autoscale batch size fields.
         self._max_batch_size = None
-        self._local_bsz_bounds = None
         # Create and load state.
         self._state = _AdaptiveDataLoaderState()
         adaptdl.checkpoint.load_state(self._state)
@@ -198,7 +197,11 @@ def local_bsz_bounds(self):
         The local batch size bounds on each replica. A pair of integers,
         (min_local_bsz, max_local_bsz).
""" - return self._local_bsz_bounds + return self._state.local_bsz_bounds + + @local_bsz_bounds.setter + def local_bsz_bounds(self, bounds): + self._state.local_bsz_bounds = bounds @property def current_local_bsz(self): @@ -263,7 +266,8 @@ def autoscale_batch_size(self, max_batch_size, local_bsz_bounds=None, local_bsz_bounds[1] < self.batch_size): raise ValueError("invalid local_bsz_bounds") self._max_batch_size = max_batch_size - self._local_bsz_bounds = local_bsz_bounds + if self.local_bsz_bounds is None: + self.local_bsz_bounds = local_bsz_bounds self._gradient_accumulation = gradient_accumulation self.train() @@ -279,7 +283,7 @@ def _sync_local_bsz(self): _, atomic_bsz, accum_steps = goodput_fn.optimize( adaptdl.env.num_nodes(), adaptdl.env.num_replicas(), max_batch_size=self._max_batch_size, - atomic_bsz_range=self._local_bsz_bounds, + atomic_bsz_range=self.local_bsz_bounds, accumulation=self._gradient_accumulation) self._state.current_local_bsz = atomic_bsz self._state.accumulation_steps = accum_steps @@ -288,7 +292,7 @@ def _sync_local_bsz(self): suggest_goodput, atomic_bsz, accum_steps = goodput_fn.optimize( adaptdl.env.num_nodes(), adaptdl.env.num_replicas(), max_batch_size=self._max_batch_size, - atomic_bsz_range=self._local_bsz_bounds, + atomic_bsz_range=self.local_bsz_bounds, accumulation=self._gradient_accumulation) # get current goodput current_goodput = goodput_fn( @@ -299,6 +303,7 @@ def _sync_local_bsz(self): if speedup > self._speedup_threshold: self._state.current_local_bsz = atomic_bsz self._state.accumulation_steps = accum_steps + self._state.current_local_bsz, self._state.accumulation_steps = \ adaptdl.collective.broadcast((self._state.current_local_bsz, self._state.accumulation_steps)) @@ -340,18 +345,23 @@ def context(self): proper cleanup of elastic context at the end of each epoch. """ epoch = current_epoch() + exception = False try: if AdaptiveDataLoaderHelper._current is not None: raise RuntimeError("overlapping dataloader \ iterations detected") AdaptiveDataLoaderHelper._current = self yield + except GeneratorExit: + # Generic Exception outside of the dataloader + exception = True finally: - self._state.current_index = 0 - self._state.end_index = 0 - self._state.last_position[epoch] = self._position[epoch] - self._position[epoch] += 1 - AdaptiveDataLoaderHelper._current = None + if not exception: + self._state.current_index = 0 + self._state.end_index = 0 + self._state.last_position[epoch] = self._position[epoch] + self._position[epoch] += 1 + AdaptiveDataLoaderHelper._current = None @property def current_batch_size(self): @@ -490,6 +500,7 @@ class AdaptiveDataLoader(DataLoader, AdaptiveDataLoaderMixin): .. automethod:: __iter__ """ + def __init__(self, dataset, batch_size=1, shuffle=False, **kwargs): if kwargs.get("batch_sampler") is not None \ or kwargs.get("sampler") is not None: @@ -563,13 +574,14 @@ def __init__(self): self.current_index = 0 # Index within the current dataloader loop. self.end_index = 0 # End index of the current DataLoader loop. self.last_position = {} # Epoch -> position of last completed loop. 
+        self.local_bsz_bounds = None
         self.current_local_bsz = 0
         self.accumulation_steps = 0
 
     def save(self, fileobj):
         pickle.dump((self.current_index, self.end_index,
-                     self.last_position), fileobj)
+                     self.last_position, self.local_bsz_bounds), fileobj)
 
     def load(self, fileobj):
-        self.current_index, self.end_index, self.last_position = \
-            pickle.load(fileobj)
+        self.current_index, self.end_index, self.last_position, \
+            self.local_bsz_bounds = pickle.load(fileobj)
diff --git a/examples/pytorch-cifar/main.py b/examples/pytorch-cifar/main.py
index 9f10e7cb..ab3ab066 100644
--- a/examples/pytorch-cifar/main.py
+++ b/examples/pytorch-cifar/main.py
@@ -30,6 +30,7 @@
 
 import adaptdl
 import adaptdl.torch as adl
+from adaptdl.retry import retry
 
 from torch.optim.lr_scheduler import MultiStepLR
 from torch.utils.tensorboard import SummaryWriter
@@ -99,6 +100,7 @@
 
 
 # Training
+@retry
 def train(epoch):
     print('\nEpoch: %d' % epoch)
     net.train()
diff --git a/sched/adaptdl_sched/controller.py b/sched/adaptdl_sched/controller.py
index 483aafe7..aeb981dd 100644
--- a/sched/adaptdl_sched/controller.py
+++ b/sched/adaptdl_sched/controller.py
@@ -114,8 +114,10 @@ async def _sync_job(self, namespace, job_name):
         replicas = job["status"].get("replicas", 0)
         preemptible = job["spec"].get("preemptible", True)
         if (completion_status := self._detect_completion(pods, preemptible)):
-            # Job is already completed.
             job["status"].update(completion_status)
+            phase = job["status"]["phase"]
+        if phase in ("Succeeded", "Failed"):
+            # Job is already completed.
             job["status"].setdefault("completionTimestamp", current_ts)
             job["status"]["allocation"] = allocation = []
             await self._delete_pods(  # Keep failed pods for debug purposes.
@@ -294,14 +296,17 @@ def any143(pod):
                # resources before this pod could bind to that node.
                LOG.warning("UnexpectedAdmissionError for pod %s: %s",
                            pod.metadata.name, pod.status.message)
+                return {"phase": "Stopping"}
            elif str(pod.status.reason).startswith("Outof"):
                # we might be temporarily out of pods on this node
                LOG.warning(f"Pod {pod.metadata.name} is {pod.status.reason} "
                            f"on {pod.spec.node_name}")
+                return {"phase": "Stopping"}
            elif preemptible and (pod.metadata.deletion_timestamp is not None
                                  or any143(pod)):
                # This pod was intentionally terminated.
                LOG.warning(f"Pod {pod.metadata.name} terminated")
+                return {"phase": "Stopping"}
            else:
                return {"phase": "Failed", "reason": "PodFailure",
                        "message": f"{pod.metadata.name} {pod.status.phase}"}