From c43900efffba7cb809a61cf95119ebf8599ca500 Mon Sep 17 00:00:00 2001 From: Omkar Pangarkar Date: Wed, 16 Dec 2020 18:34:12 -0500 Subject: [PATCH 01/13] first shot --- adaptdl/adaptdl/goodput.py | 4 +-- adaptdl/adaptdl/torch/_metrics.py | 3 +- adaptdl/adaptdl/torch/data.py | 53 +++++++++++++++++++++++++------ 3 files changed, 47 insertions(+), 13 deletions(-) diff --git a/adaptdl/adaptdl/goodput.py b/adaptdl/adaptdl/goodput.py index c889aaa6..f7b75383 100644 --- a/adaptdl/adaptdl/goodput.py +++ b/adaptdl/adaptdl/goodput.py @@ -18,7 +18,7 @@ import collections import scipy.optimize import scipy.stats - +import ipdb # Parameters for a performance model which predicts the per-step time of # distributed SGD using all-reduce. At a high level, models compute time and @@ -98,7 +98,7 @@ def optimize(self, num_nodes, num_replicas, max_batch_size=None, num_nodes = np.broadcast_to(num_nodes, output_shape).flatten() num_replicas = np.broadcast_to(num_replicas, output_shape).flatten() # Samples 50 different total batch sizes in geometric space. - min_batch_size = np.maximum(self._init_batch_size, + min_batch_size = np.minimum(self._init_batch_size, min_atomic_bsz * num_replicas) batch_size = np.geomspace(min_batch_size, max_batch_size) local_bsz = batch_size / num_replicas diff --git a/adaptdl/adaptdl/torch/_metrics.py b/adaptdl/adaptdl/torch/_metrics.py index 64db2206..1bb20de7 100644 --- a/adaptdl/adaptdl/torch/_metrics.py +++ b/adaptdl/adaptdl/torch/_metrics.py @@ -62,7 +62,7 @@ def profile_step_commit(accumulation_step=False): if not accumulation_step: if _PREV_REPORT is None: _PREV_REPORT = time.time() - if adaptdl.env.replica_rank() == 0 and time.time() - _PREV_REPORT > 30: + if adaptdl.env.replica_rank() == 0: _fit_perf_params() _report_sched_hints() _PREV_REPORT = time.time() @@ -97,6 +97,7 @@ def set_batch_size(init_batch_size, max_batch_size, local_bsz_bounds, def get_goodput_fn(): state = _metrics_state() + print(state.grad_params, state.perf_params) if state.grad_params is None or state.perf_params is None: return None return GoodputFunction(state.perf_params, state.grad_params, diff --git a/adaptdl/adaptdl/torch/data.py b/adaptdl/adaptdl/torch/data.py index 8f6f3a68..f69cb8b2 100644 --- a/adaptdl/adaptdl/torch/data.py +++ b/adaptdl/adaptdl/torch/data.py @@ -205,6 +205,15 @@ def current_local_bsz(self): """ return self._state.current_local_bsz + @property + def previous_local_bsz(self): + """ + The current logical local batch size used by the dataloader. 
+ The batch size returned by the dataloader may be smaller if + gradient accumulation is used + """ + return self._state.previous_local_bsz + @property def accumulation_steps(self): """ @@ -266,6 +275,7 @@ def autoscale_batch_size(self, max_batch_size, local_bsz_bounds=None, self._gradient_accumulation = gradient_accumulation self.train() + def _sync_local_bsz(self): goodput_fn = get_goodput_fn() if self.max_batch_size is None or goodput_fn is None: @@ -297,9 +307,9 @@ def _sync_local_bsz(self): self.current_local_bsz, self.accumulation_steps) # use only if speedup is significant speedup = suggest_goodput / max(current_goodput, 1e-8) - if speedup > self._speedup_threshold: - self._state.current_local_bsz = atomic_bsz - self._state.accumulation_steps = accum_steps + self._state.current_local_bsz = atomic_bsz + self._state.accumulation_steps = accum_steps + self._state.current_local_bsz, self._state.accumulation_steps = \ adaptdl.collective.broadcast((self._state.current_local_bsz, self._state.accumulation_steps)) @@ -338,17 +348,22 @@ def context(self): proper cleanup of elastic context at the end of each epoch. """ epoch = current_epoch() + exception = False try: if AdaptiveDataLoaderHelper._current is not None: raise RuntimeError("overlapping dataloader \ iterations detected") AdaptiveDataLoaderHelper._current = self yield + except GeneratorExit as e: + LOG.info(f"GeneratorExit") + exception = True finally: - self._state.current_index = 0 - self._state.end_index = 0 - self._state.last_position[epoch] = self._position[epoch] - self._position[epoch] += 1 + if not exception: + self._state.current_index = 0 + self._state.end_index = 0 + self._state.last_position[epoch] = self._position[epoch] + self._position[epoch] += 1 AdaptiveDataLoaderHelper._current = None @property @@ -467,6 +482,7 @@ class AdaptiveDataLoader(DataLoader, AdaptiveDataLoaderMixin): .. automethod:: __iter__ """ + def __init__(self, dataset, batch_size=1, shuffle=False, **kwargs): if kwargs.get("batch_sampler") is not None \ or kwargs.get("sampler") is not None: @@ -535,15 +551,32 @@ def __init__(self): self.current_index = 0 # Index within the current dataloader loop. self.end_index = 0 # End index of the current DataLoader loop. self.last_position = {} # Epoch -> position of last completed loop. 
- self.current_local_bsz = 0 + self._current_local_bsz = [0, 0] # (previous, current) local bsize self.accumulation_steps = 0 + @property + def current_local_bsz(self): + return self._current_local_bsz[1] + + @property + def previous_local_bsz(self): + return self._current_local_bsz[0] + + @current_local_bsz.setter + def current_local_bsz(self, batch_size): + if self._current_local_bsz == [0, 0]: + self._current_local_bsz = [batch_size, batch_size] + elif self._current_local_bsz[1] != batch_size: + self._current_local_bsz[0] = self._current_local_bsz[1] + self._current_local_bsz[1] = batch_size + + def save(self, fileobj): pickle.dump((self.current_index, self.end_index, - self.last_position, self.current_local_bsz, + self.last_position, self._current_local_bsz, self.accumulation_steps), fileobj) def load(self, fileobj): self.current_index, self.end_index, self.last_position, \ - self.current_local_bsz, self.accumulation_steps = \ + self._current_local_bsz, self.accumulation_steps = \ pickle.load(fileobj) From 8ccc3f3a5efaf97f2dde252242e434d83620eea7 Mon Sep 17 00:00:00 2001 From: Omkar Pangarkar Date: Fri, 18 Dec 2020 03:20:52 -0500 Subject: [PATCH 02/13] introduce retry --- adaptdl/adaptdl/retry.py | 48 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 adaptdl/adaptdl/retry.py diff --git a/adaptdl/adaptdl/retry.py b/adaptdl/adaptdl/retry.py new file mode 100644 index 00000000..1999179e --- /dev/null +++ b/adaptdl/adaptdl/retry.py @@ -0,0 +1,48 @@ +# Copyright 2020 Petuum, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import functools +import logging + +logging.basicConfig(level=logging.INFO) +LOG = logging.getLogger(__name__) +LOG.setLevel(logging.INFO) + +def cudaoom(e): + return "RuntimeError: CUDA out of memory" in str(e) + +def retry(dataloader): + def deco(func): + @functools.wraps(func) + def inner(*args, **kwargs): + for _ in range(2): # we only try once + try: + func(*args, **kwargs) + break + except RuntimeError as e: + LOG.info(f"-------------- {e} ---------------") + if dataloader._elastic.local_bsz_bounds and cudaoom(e): + low, high = dataloader._elastic.local_bsz_bounds + max_batch_size = dataloader._elastic.max_batch_size + previous_local_bsz = dataloader._elastic.previous_local_bsz + print(max_batch_size, high, previous_local_bsz) + if high > previous_local_bsz: + local_bsz_bounds = (low, previous_local_bsz) + dataloader.autoscale_batch_size(max_batch_size=max_batch_size, + local_bsz_bounds=local_bsz_bounds) + else: raise e + else: raise e + return inner + return deco From 95bf28d0b45d3c2df661c8675b4ead33e1794138 Mon Sep 17 00:00:00 2001 From: Omkar Pangarkar Date: Sat, 19 Dec 2020 01:31:30 -0500 Subject: [PATCH 03/13] fix cudaoom --- adaptdl/adaptdl/retry.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/adaptdl/adaptdl/retry.py b/adaptdl/adaptdl/retry.py index 1999179e..7b56d9ab 100644 --- a/adaptdl/adaptdl/retry.py +++ b/adaptdl/adaptdl/retry.py @@ -21,7 +21,7 @@ LOG.setLevel(logging.INFO) def cudaoom(e): - return "RuntimeError: CUDA out of memory" in str(e) + return "CUDA out of memory" in str(e) def retry(dataloader): def deco(func): @@ -32,7 +32,7 @@ def inner(*args, **kwargs): func(*args, **kwargs) break except RuntimeError as e: - LOG.info(f"-------------- {e} ---------------") + LOG.info(f"{e}") if dataloader._elastic.local_bsz_bounds and cudaoom(e): low, high = dataloader._elastic.local_bsz_bounds max_batch_size = dataloader._elastic.max_batch_size From 7c0e24c944185e38bed5c1075e133335dbb1e1db Mon Sep 17 00:00:00 2001 From: Omkar Pangarkar Date: Sat, 19 Dec 2020 02:06:45 -0500 Subject: [PATCH 04/13] revert silly changes --- adaptdl/adaptdl/goodput.py | 2 +- adaptdl/adaptdl/retry.py | 2 +- adaptdl/adaptdl/torch/_metrics.py | 3 +-- adaptdl/adaptdl/torch/data.py | 5 +++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/adaptdl/adaptdl/goodput.py b/adaptdl/adaptdl/goodput.py index 33667df3..9e6a29ee 100644 --- a/adaptdl/adaptdl/goodput.py +++ b/adaptdl/adaptdl/goodput.py @@ -102,7 +102,7 @@ def optimize(self, num_nodes, num_replicas, max_batch_size=None, num_nodes = np.broadcast_to(num_nodes, output_shape).flatten() num_replicas = np.broadcast_to(num_replicas, output_shape).flatten() # Samples 50 different total batch sizes in geometric space. 
- min_batch_size = np.minimum(self._init_batch_size, + min_batch_size = np.maximum(self._init_batch_size, min_atomic_bsz * num_replicas) batch_size = np.geomspace(min_batch_size, max_batch_size) local_bsz = batch_size / num_replicas diff --git a/adaptdl/adaptdl/retry.py b/adaptdl/adaptdl/retry.py index 7b56d9ab..daf67484 100644 --- a/adaptdl/adaptdl/retry.py +++ b/adaptdl/adaptdl/retry.py @@ -37,11 +37,11 @@ def inner(*args, **kwargs): low, high = dataloader._elastic.local_bsz_bounds max_batch_size = dataloader._elastic.max_batch_size previous_local_bsz = dataloader._elastic.previous_local_bsz - print(max_batch_size, high, previous_local_bsz) if high > previous_local_bsz: local_bsz_bounds = (low, previous_local_bsz) dataloader.autoscale_batch_size(max_batch_size=max_batch_size, local_bsz_bounds=local_bsz_bounds) + LOG.info(f"Local batch size bounds changed to {local_bsz_bounds}") else: raise e else: raise e return inner diff --git a/adaptdl/adaptdl/torch/_metrics.py b/adaptdl/adaptdl/torch/_metrics.py index 209251d7..a38cc2a2 100644 --- a/adaptdl/adaptdl/torch/_metrics.py +++ b/adaptdl/adaptdl/torch/_metrics.py @@ -60,7 +60,7 @@ def profile_step_commit(accumulation_step=False): if not accumulation_step: if _PREV_REPORT is None: _PREV_REPORT = time.time() - if adaptdl.env.replica_rank() == 0: + if adaptdl.env.replica_rank() == 0 and time.time() - _PREV_REPORT > 30: _fit_perf_params() _report_sched_hints() _PREV_REPORT = time.time() @@ -95,7 +95,6 @@ def set_batch_size(init_batch_size, max_batch_size, local_bsz_bounds, def get_goodput_fn(): state = _metrics_state() - print(state.grad_params, state.perf_params) if state.grad_params is None or state.perf_params is None: return None return GoodputFunction(state.perf_params, state.grad_params, diff --git a/adaptdl/adaptdl/torch/data.py b/adaptdl/adaptdl/torch/data.py index 14d0d29f..751a6a0d 100644 --- a/adaptdl/adaptdl/torch/data.py +++ b/adaptdl/adaptdl/torch/data.py @@ -306,8 +306,9 @@ def _sync_local_bsz(self): self.current_local_bsz, self.accumulation_steps) # use only if speedup is significant speedup = suggest_goodput / max(current_goodput, 1e-8) - self._state.current_local_bsz = atomic_bsz - self._state.accumulation_steps = accum_steps + if speedup > self._speedup_threshold: + self._state.current_local_bsz = atomic_bsz + self._state.accumulation_steps = accum_steps self._state.current_local_bsz, self._state.accumulation_steps = \ adaptdl.collective.broadcast((self._state.current_local_bsz, From 5aa108443b56aeee2c1b866f8bd7160bb1b2c867 Mon Sep 17 00:00:00 2001 From: Omkar Pangarkar Date: Tue, 5 Jan 2021 02:37:15 -0500 Subject: [PATCH 05/13] fixes --- adaptdl/adaptdl/retry.py | 29 ++++++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/adaptdl/adaptdl/retry.py b/adaptdl/adaptdl/retry.py index daf67484..42b219f3 100644 --- a/adaptdl/adaptdl/retry.py +++ b/adaptdl/adaptdl/retry.py @@ -15,14 +15,21 @@ import functools import logging +import torch + +import adaptdl.checkpoint +import adaptdl.env +from adaptdl.torch._metrics import _report_sched_hints logging.basicConfig(level=logging.INFO) LOG = logging.getLogger(__name__) LOG.setLevel(logging.INFO) + def cudaoom(e): return "CUDA out of memory" in str(e) + def retry(dataloader): def deco(func): @functools.wraps(func) @@ -33,16 +40,28 @@ def inner(*args, **kwargs): break except RuntimeError as e: LOG.info(f"{e}") - if dataloader._elastic.local_bsz_bounds and cudaoom(e): + if (dataloader is not None and + dataloader._elastic.local_bsz_bounds is not None and + 
cudaoom(e)): + if torch.cuda.is_available(): + torch.cuda.empty_cache() + LOG.info("**\n" + torch.cuda.memory_summary()) low, high = dataloader._elastic.local_bsz_bounds max_batch_size = dataloader._elastic.max_batch_size previous_local_bsz = dataloader._elastic.previous_local_bsz if high > previous_local_bsz: local_bsz_bounds = (low, previous_local_bsz) dataloader.autoscale_batch_size(max_batch_size=max_batch_size, - local_bsz_bounds=local_bsz_bounds) - LOG.info(f"Local batch size bounds changed to {local_bsz_bounds}") - else: raise e - else: raise e + local_bsz_bounds=local_bsz_bounds) + LOG.info( + f"Local batch size bounds changed to {local_bsz_bounds}") + if adaptdl.env.replica_rank() == 0: + _report_sched_hints() + adaptdl.checkpoint.save_all_states() + exit(143) + else: + raise e + else: + raise e return inner return deco From 3689d336b441bc3b8fb08f686f6ff433b6690227 Mon Sep 17 00:00:00 2001 From: Omkar Pangarkar Date: Mon, 11 Jan 2021 12:00:01 -0500 Subject: [PATCH 06/13] directly reset the bounds --- adaptdl/adaptdl/retry.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/adaptdl/adaptdl/retry.py b/adaptdl/adaptdl/retry.py index 42b219f3..ddfdf2dc 100644 --- a/adaptdl/adaptdl/retry.py +++ b/adaptdl/adaptdl/retry.py @@ -26,6 +26,8 @@ LOG.setLevel(logging.INFO) +GPU_MEM_CUTOFF_PCT = 0.2 + def cudaoom(e): return "CUDA out of memory" in str(e) @@ -46,21 +48,18 @@ def inner(*args, **kwargs): if torch.cuda.is_available(): torch.cuda.empty_cache() LOG.info("**\n" + torch.cuda.memory_summary()) + current_local_bsz = dataloader._elastic.current_local_bsz low, high = dataloader._elastic.local_bsz_bounds - max_batch_size = dataloader._elastic.max_batch_size - previous_local_bsz = dataloader._elastic.previous_local_bsz - if high > previous_local_bsz: - local_bsz_bounds = (low, previous_local_bsz) - dataloader.autoscale_batch_size(max_batch_size=max_batch_size, - local_bsz_bounds=local_bsz_bounds) - LOG.info( - f"Local batch size bounds changed to {local_bsz_bounds}") - if adaptdl.env.replica_rank() == 0: - _report_sched_hints() - adaptdl.checkpoint.save_all_states() - exit(143) - else: - raise e + LOG.info(f"current_local_bsz is {current_local_bsz} local_bsz_bounds ({low}, {high})") + assert current_local_bsz <= high + new_high = int((1. 
- GPU_MEM_CUTOFF_PCT) * current_local_bsz) + dataloader._elastic.local_bsz_bounds = [low, new_high] + LOG.info( + f"Local batch size bounds changed to ({low}, {new_high})") + if adaptdl.env.replica_rank() == 0: + _report_sched_hints() + adaptdl.checkpoint.save_all_states() + exit(143) else: raise e return inner From 8a64bd9e82370bd95686c86360ae423119670c90 Mon Sep 17 00:00:00 2001 From: Omkar Pangarkar Date: Mon, 11 Jan 2021 16:03:47 -0500 Subject: [PATCH 07/13] remove dataloader argument --- adaptdl/adaptdl/retry.py | 66 ++++++++++++++++++++-------------------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/adaptdl/adaptdl/retry.py b/adaptdl/adaptdl/retry.py index ddfdf2dc..078e4b8e 100644 --- a/adaptdl/adaptdl/retry.py +++ b/adaptdl/adaptdl/retry.py @@ -20,47 +20,47 @@ import adaptdl.checkpoint import adaptdl.env from adaptdl.torch._metrics import _report_sched_hints +from adaptdl.torch.data import current_dataloader logging.basicConfig(level=logging.INFO) LOG = logging.getLogger(__name__) LOG.setLevel(logging.INFO) -GPU_MEM_CUTOFF_PCT = 0.2 +GPU_MEM_CUTOFF_PCT = 0.1 def cudaoom(e): return "CUDA out of memory" in str(e) -def retry(dataloader): - def deco(func): - @functools.wraps(func) - def inner(*args, **kwargs): - for _ in range(2): # we only try once - try: - func(*args, **kwargs) - break - except RuntimeError as e: - LOG.info(f"{e}") - if (dataloader is not None and - dataloader._elastic.local_bsz_bounds is not None and - cudaoom(e)): - if torch.cuda.is_available(): - torch.cuda.empty_cache() - LOG.info("**\n" + torch.cuda.memory_summary()) - current_local_bsz = dataloader._elastic.current_local_bsz - low, high = dataloader._elastic.local_bsz_bounds - LOG.info(f"current_local_bsz is {current_local_bsz} local_bsz_bounds ({low}, {high})") - assert current_local_bsz <= high - new_high = int((1. - GPU_MEM_CUTOFF_PCT) * current_local_bsz) - dataloader._elastic.local_bsz_bounds = [low, new_high] - LOG.info( - f"Local batch size bounds changed to ({low}, {new_high})") - if adaptdl.env.replica_rank() == 0: - _report_sched_hints() - adaptdl.checkpoint.save_all_states() - exit(143) - else: - raise e - return inner - return deco +def retry(func): + @functools.wraps(func) + def inner(*args, **kwargs): + for _ in range(2): # we only try once + try: + func(*args, **kwargs) + break + except RuntimeError as e: + LOG.info(f"{e}") + dataloader = current_dataloader() + if (dataloader is not None and + dataloader.local_bsz_bounds is not None and + cudaoom(e)): + if torch.cuda.is_available(): + torch.cuda.empty_cache() + LOG.info("\n" + torch.cuda.memory_summary()) + current_local_bsz = dataloader.current_local_bsz + low, high = dataloader.local_bsz_bounds + LOG.info(f"current_local_bsz is {current_local_bsz} local_bsz_bounds ({low}, {high})") + assert current_local_bsz <= high + new_high = int((1. 
- GPU_MEM_CUTOFF_PCT) * current_local_bsz) + dataloader.local_bsz_bounds = [low, new_high] + LOG.info( + f"Local batch size bounds changed to ({low}, {new_high})") + if adaptdl.env.replica_rank() == 0: + _report_sched_hints() + adaptdl.checkpoint.save_all_states() + exit(143) + else: + raise e + return inner From f5e797eb507831061a80a933083cd6cc8ba2c0c0 Mon Sep 17 00:00:00 2001 From: Omkar Pangarkar Date: Mon, 11 Jan 2021 16:26:20 -0500 Subject: [PATCH 08/13] cleanups --- adaptdl/adaptdl/retry.py | 52 +++++++++++++++++------------------ adaptdl/adaptdl/torch/data.py | 50 ++++++++++----------------------- 2 files changed, 40 insertions(+), 62 deletions(-) diff --git a/adaptdl/adaptdl/retry.py b/adaptdl/adaptdl/retry.py index 078e4b8e..c04a6ce0 100644 --- a/adaptdl/adaptdl/retry.py +++ b/adaptdl/adaptdl/retry.py @@ -36,31 +36,29 @@ def cudaoom(e): def retry(func): @functools.wraps(func) def inner(*args, **kwargs): - for _ in range(2): # we only try once - try: - func(*args, **kwargs) - break - except RuntimeError as e: - LOG.info(f"{e}") - dataloader = current_dataloader() - if (dataloader is not None and - dataloader.local_bsz_bounds is not None and - cudaoom(e)): - if torch.cuda.is_available(): - torch.cuda.empty_cache() - LOG.info("\n" + torch.cuda.memory_summary()) - current_local_bsz = dataloader.current_local_bsz - low, high = dataloader.local_bsz_bounds - LOG.info(f"current_local_bsz is {current_local_bsz} local_bsz_bounds ({low}, {high})") - assert current_local_bsz <= high - new_high = int((1. - GPU_MEM_CUTOFF_PCT) * current_local_bsz) - dataloader.local_bsz_bounds = [low, new_high] - LOG.info( - f"Local batch size bounds changed to ({low}, {new_high})") - if adaptdl.env.replica_rank() == 0: - _report_sched_hints() - adaptdl.checkpoint.save_all_states() - exit(143) - else: - raise e + try: + func(*args, **kwargs) + except RuntimeError as e: + LOG.info(f"{e}") + dataloader = current_dataloader() + if (dataloader is not None and + dataloader.local_bsz_bounds is not None and + cudaoom(e)): + if torch.cuda.is_available(): + torch.cuda.empty_cache() + LOG.info("\n" + torch.cuda.memory_summary()) + current_local_bsz = dataloader.current_local_bsz + low, high = dataloader.local_bsz_bounds + LOG.info(f"current_local_bsz is {current_local_bsz} local_bsz_bounds ({low}, {high})") + assert current_local_bsz <= high + new_high = int((1. - GPU_MEM_CUTOFF_PCT) * current_local_bsz) + dataloader.local_bsz_bounds = [low, new_high] + LOG.info( + f"Local batch size bounds changed to ({low}, {new_high})") + if adaptdl.env.replica_rank() == 0: + _report_sched_hints() + adaptdl.checkpoint.save_all_states() + exit(143) + else: + raise e return inner diff --git a/adaptdl/adaptdl/torch/data.py b/adaptdl/adaptdl/torch/data.py index 751a6a0d..51e3ddc2 100644 --- a/adaptdl/adaptdl/torch/data.py +++ b/adaptdl/adaptdl/torch/data.py @@ -142,7 +142,6 @@ class AdaptiveDataLoaderHelper(object): def __init__(self, batch_size=1): # Autoscale batch size fields. self._max_batch_size = None - self._local_bsz_bounds = None # Create and load state. self._state = _AdaptiveDataLoaderState() adaptdl.checkpoint.load_state(self._state) @@ -198,7 +197,11 @@ def local_bsz_bounds(self): The local batch size bounds on each replica. A pair of integers, (min_local_bsz, max_local_bsz). 
""" - return self._local_bsz_bounds + return self._state.local_bsz_bounds + + @local_bsz_bounds.setter + def local_bsz_bounds(self, bounds): + self._state.local_bsz_bounds = bounds @property def current_local_bsz(self): @@ -209,15 +212,6 @@ def current_local_bsz(self): """ return self._state.current_local_bsz - @property - def previous_local_bsz(self): - """ - The current logical local batch size used by the dataloader. - The batch size returned by the dataloader may be smaller if - gradient accumulation is used - """ - return self._state.previous_local_bsz - @property def accumulation_steps(self): """ @@ -272,7 +266,8 @@ def autoscale_batch_size(self, max_batch_size, local_bsz_bounds=None, local_bsz_bounds[1] < self.batch_size): raise ValueError("invalid local_bsz_bounds") self._max_batch_size = max_batch_size - self._local_bsz_bounds = local_bsz_bounds + if self.local_bsz_bounds is None: + self.local_bsz_bounds = local_bsz_bounds self._gradient_accumulation = gradient_accumulation self.train() @@ -289,7 +284,7 @@ def _sync_local_bsz(self): _, atomic_bsz, accum_steps = goodput_fn.optimize( adaptdl.env.num_nodes(), adaptdl.env.num_replicas(), max_batch_size=self._max_batch_size, - atomic_bsz_range=self._local_bsz_bounds, + atomic_bsz_range=self.local_bsz_bounds, accumulation=self._gradient_accumulation) self._state.current_local_bsz = atomic_bsz self._state.accumulation_steps = accum_steps @@ -298,7 +293,7 @@ def _sync_local_bsz(self): suggest_goodput, atomic_bsz, accum_steps = goodput_fn.optimize( adaptdl.env.num_nodes(), adaptdl.env.num_replicas(), max_batch_size=self._max_batch_size, - atomic_bsz_range=self._local_bsz_bounds, + atomic_bsz_range=self.local_bsz_bounds, accumulation=self._gradient_accumulation) # get current goodput current_goodput = goodput_fn( @@ -367,7 +362,7 @@ def context(self): self._state.end_index = 0 self._state.last_position[epoch] = self._position[epoch] self._position[epoch] += 1 - AdaptiveDataLoaderHelper._current = None + AdaptiveDataLoaderHelper._current = None @property def current_batch_size(self): @@ -580,29 +575,14 @@ def __init__(self): self.current_index = 0 # Index within the current dataloader loop. self.end_index = 0 # End index of the current DataLoader loop. self.last_position = {} # Epoch -> position of last completed loop. 
- self._current_local_bsz = [0, 0] # (previous, current) local bsize + self.local_bsz_bounds = None + self.current_local_bsz = 0 self.accumulation_steps = 0 - @property - def current_local_bsz(self): - return self._current_local_bsz[1] - - @property - def previous_local_bsz(self): - return self._current_local_bsz[0] - - @current_local_bsz.setter - def current_local_bsz(self, batch_size): - if self._current_local_bsz == [0, 0]: - self._current_local_bsz = [batch_size, batch_size] - elif self._current_local_bsz[1] != batch_size: - self._current_local_bsz[0] = self._current_local_bsz[1] - self._current_local_bsz[1] = batch_size - def save(self, fileobj): pickle.dump((self.current_index, self.end_index, - self.last_position), fileobj) + self.last_position, self.local_bsz_bounds), fileobj) def load(self, fileobj): - self.current_index, self.end_index, self.last_position = \ - pickle.load(fileobj) + self.current_index, self.end_index, self.last_position, \ + self.local_bsz_bounds = pickle.load(fileobj) From af06f5fa28521fdad18d75db1f98d2e46d6613da Mon Sep 17 00:00:00 2001 From: Omkar Pangarkar Date: Mon, 11 Jan 2021 16:48:32 -0500 Subject: [PATCH 09/13] more cleanups --- adaptdl/adaptdl/retry.py | 8 ++------ adaptdl/adaptdl/torch/data.py | 2 +- examples/pytorch-cifar/main.py | 2 ++ 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/adaptdl/adaptdl/retry.py b/adaptdl/adaptdl/retry.py index c04a6ce0..82608baf 100644 --- a/adaptdl/adaptdl/retry.py +++ b/adaptdl/adaptdl/retry.py @@ -44,17 +44,13 @@ def inner(*args, **kwargs): if (dataloader is not None and dataloader.local_bsz_bounds is not None and cudaoom(e)): - if torch.cuda.is_available(): - torch.cuda.empty_cache() - LOG.info("\n" + torch.cuda.memory_summary()) current_local_bsz = dataloader.current_local_bsz low, high = dataloader.local_bsz_bounds - LOG.info(f"current_local_bsz is {current_local_bsz} local_bsz_bounds ({low}, {high})") assert current_local_bsz <= high new_high = int((1. 
- GPU_MEM_CUTOFF_PCT) * current_local_bsz) - dataloader.local_bsz_bounds = [low, new_high] + dataloader.local_bsz_bounds = (low, new_high) LOG.info( - f"Local batch size bounds changed to ({low}, {new_high})") + f"Local batch size bounds changed to {dataloader.local_bsz_bounds}") if adaptdl.env.replica_rank() == 0: _report_sched_hints() adaptdl.checkpoint.save_all_states() diff --git a/adaptdl/adaptdl/torch/data.py b/adaptdl/adaptdl/torch/data.py index 51e3ddc2..3f4e9c58 100644 --- a/adaptdl/adaptdl/torch/data.py +++ b/adaptdl/adaptdl/torch/data.py @@ -354,7 +354,7 @@ def context(self): AdaptiveDataLoaderHelper._current = self yield except GeneratorExit as e: - LOG.info(f"GeneratorExit") + # Generic Exception outside of the dataloader exception = True finally: if not exception: diff --git a/examples/pytorch-cifar/main.py b/examples/pytorch-cifar/main.py index c5b8f5f5..0ea818e1 100644 --- a/examples/pytorch-cifar/main.py +++ b/examples/pytorch-cifar/main.py @@ -31,6 +31,7 @@ import adaptdl import adaptdl.torch as adl +from adaptdl.retry import retry from torch.optim.lr_scheduler import MultiStepLR from torch.utils.tensorboard import SummaryWriter @@ -104,6 +105,7 @@ net = adl.AdaptiveDataParallel(net, optimizer, lr_scheduler) # Training +@retry def train(epoch): print('\nEpoch: %d' % epoch) net.train() From 116f2bcd38875c1b51ad273126ac9c8e93a5bc42 Mon Sep 17 00:00:00 2001 From: Omkar Pangarkar Date: Mon, 11 Jan 2021 16:53:14 -0500 Subject: [PATCH 10/13] controller fix --- sched/adaptdl_sched/controller.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/sched/adaptdl_sched/controller.py b/sched/adaptdl_sched/controller.py index 5d120db8..e9e1bcc2 100644 --- a/sched/adaptdl_sched/controller.py +++ b/sched/adaptdl_sched/controller.py @@ -114,8 +114,10 @@ async def _sync_job(self, namespace, job_name): replicas = job["status"].get("replicas", 0) preemptible = job["spec"].get("preemptible", True) if (completion_status := self._detect_completion(pods, preemptible)): - # Job is already completed. job["status"].update(completion_status) + phase = job["status"]["phase"] + if phase in ("Succeeded", "Failed"): + # Job is already completed. job["status"].setdefault("completionTimestamp", current_ts) job["status"]["allocation"] = allocation = [] await self._delete_pods( # Keep failed pods for debug purposes. @@ -284,14 +286,17 @@ def any143(pod): # resources before this pod could bind to that node. LOG.warning("UnexpectedAdmissionError for pod %s: %s", pod.metadata.name, pod.status.message) + return {"phase": "Stopping"} elif str(pod.status.reason).startswith("Outof"): # we might be temporarily out of pods on this node LOG.warning(f"Pod {pod.metadata.name} is {pod.status.reason} " f"on {pod.spec.node_name}") + return {"phase": "Stopping"} elif preemptible and (pod.metadata.deletion_timestamp is not None or any143(pod)): # This pod was intentionally terminated. 
LOG.warning(f"Pod {pod.metadata.name} terminated") + return {"phase": "Stopping"} else: return {"phase": "Failed", "reason": "PodFailure", "message": f"{pod.metadata.name} {pod.status.phase}"} From 1cd1bc55d46585d7619b5a2daf8fd3a5ff13d94a Mon Sep 17 00:00:00 2001 From: Omkar Pangarkar Date: Mon, 11 Jan 2021 19:38:14 -0500 Subject: [PATCH 11/13] cleanup --- adaptdl/adaptdl/goodput.py | 1 - 1 file changed, 1 deletion(-) diff --git a/adaptdl/adaptdl/goodput.py b/adaptdl/adaptdl/goodput.py index 9e6a29ee..469fa8e8 100644 --- a/adaptdl/adaptdl/goodput.py +++ b/adaptdl/adaptdl/goodput.py @@ -19,7 +19,6 @@ import scipy.optimize import scipy.stats - # Parameters for a performance model which predicts the per-step time of # distributed SGD using all-reduce. At a high level, models compute time and # network time separately, and combines them with some degree of overlap. From e535ce9bec7b7dd72611ff2657d19340135524df Mon Sep 17 00:00:00 2001 From: Omkar Pangarkar Date: Mon, 11 Jan 2021 20:31:05 -0500 Subject: [PATCH 12/13] lint --- adaptdl/adaptdl/retry.py | 8 ++++---- adaptdl/adaptdl/torch/data.py | 5 ++--- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/adaptdl/adaptdl/retry.py b/adaptdl/adaptdl/retry.py index 82608baf..e760b274 100644 --- a/adaptdl/adaptdl/retry.py +++ b/adaptdl/adaptdl/retry.py @@ -15,7 +15,6 @@ import functools import logging -import torch import adaptdl.checkpoint import adaptdl.env @@ -29,6 +28,7 @@ GPU_MEM_CUTOFF_PCT = 0.1 + def cudaoom(e): return "CUDA out of memory" in str(e) @@ -43,14 +43,14 @@ def inner(*args, **kwargs): dataloader = current_dataloader() if (dataloader is not None and dataloader.local_bsz_bounds is not None and - cudaoom(e)): + cudaoom(e)): current_local_bsz = dataloader.current_local_bsz low, high = dataloader.local_bsz_bounds assert current_local_bsz <= high new_high = int((1. 
- GPU_MEM_CUTOFF_PCT) * current_local_bsz) dataloader.local_bsz_bounds = (low, new_high) - LOG.info( - f"Local batch size bounds changed to {dataloader.local_bsz_bounds}") + LOG.info(f"Local batch size bounds changed to " + f"{dataloader.local_bsz_bounds}") if adaptdl.env.replica_rank() == 0: _report_sched_hints() adaptdl.checkpoint.save_all_states() diff --git a/adaptdl/adaptdl/torch/data.py b/adaptdl/adaptdl/torch/data.py index 3f4e9c58..bc8e8a7a 100644 --- a/adaptdl/adaptdl/torch/data.py +++ b/adaptdl/adaptdl/torch/data.py @@ -271,7 +271,6 @@ def autoscale_batch_size(self, max_batch_size, local_bsz_bounds=None, self._gradient_accumulation = gradient_accumulation self.train() - def _sync_local_bsz(self): goodput_fn = get_goodput_fn() if self.max_batch_size is None or goodput_fn is None: @@ -353,7 +352,7 @@ def context(self): iterations detected") AdaptiveDataLoaderHelper._current = self yield - except GeneratorExit as e: + except GeneratorExit: # Generic Exception outside of the dataloader exception = True finally: @@ -581,7 +580,7 @@ def __init__(self): def save(self, fileobj): pickle.dump((self.current_index, self.end_index, - self.last_position, self.local_bsz_bounds), fileobj) + self.last_position, self.local_bsz_bounds), fileobj) def load(self, fileobj): self.current_index, self.end_index, self.last_position, \ From cf3d41b395a187076b416aa9650a9324299f9105 Mon Sep 17 00:00:00 2001 From: Omkar Pangarkar Date: Mon, 11 Jan 2021 21:56:09 -0500 Subject: [PATCH 13/13] lint --- adaptdl/adaptdl/retry.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/adaptdl/adaptdl/retry.py b/adaptdl/adaptdl/retry.py index e760b274..31ae7ca5 100644 --- a/adaptdl/adaptdl/retry.py +++ b/adaptdl/adaptdl/retry.py @@ -26,7 +26,9 @@ LOG.setLevel(logging.INFO) -GPU_MEM_CUTOFF_PCT = 0.1 +# Percentage of current_local_bsz used to decide upper bound on +# local_bsz_bounds after OOM +LOCAL_BSZ_CUTOFF_PCT = 0.1 def cudaoom(e): @@ -47,7 +49,9 @@ def inner(*args, **kwargs): current_local_bsz = dataloader.current_local_bsz low, high = dataloader.local_bsz_bounds assert current_local_bsz <= high - new_high = int((1. - GPU_MEM_CUTOFF_PCT) * current_local_bsz) + new_high = int((1. - LOCAL_BSZ_CUTOFF_PCT) * current_local_bsz) + if new_high < low: + raise e dataloader.local_bsz_bounds = (low, new_high) LOG.info(f"Local batch size bounds changed to " f"{dataloader.local_bsz_bounds}")
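
Usage sketch for the series as a whole: the snippet below mirrors the `@retry` wiring added to `examples/pytorch-cifar/main.py` in PATCH 09, with the dataset, model, criterion and optimizer left as placeholders (`make_dataset()` and friends are hypothetical helpers, not part of AdaptDL). On a CUDA OOM raised inside the decorated function, the final `retry.py` lowers the upper local batch-size bound by `LOCAL_BSZ_CUTOFF_PCT` of the current local batch size, reports the new scheduling hints from rank 0, checkpoints all states, and exits with code 143 so the scheduler restarts the job under the tighter bounds.

    import torch
    import adaptdl.torch as adl
    from adaptdl.retry import retry

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Hypothetical helpers standing in for the real dataset/model setup.
    dataset = make_dataset()
    net = make_model().to(device)
    criterion = make_criterion()
    optimizer = make_optimizer(net.parameters())

    adl.init_process_group("nccl" if torch.cuda.is_available() else "gloo")
    net = adl.AdaptiveDataParallel(net, optimizer)

    trainloader = adl.AdaptiveDataLoader(dataset, batch_size=128, shuffle=True)
    # Let AdaptDL choose the atomic batch size and accumulation steps within
    # these per-replica bounds; retry() tightens the upper bound after an OOM.
    trainloader.autoscale_batch_size(4096, local_bsz_bounds=(64, 1024),
                                     gradient_accumulation=True)

    @retry  # on CUDA OOM: shrink local_bsz_bounds, checkpoint, exit(143)
    def train(epoch):
        net.train()
        for inputs, targets in trainloader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            loss = criterion(net(inputs), targets)
            loss.backward()
            optimizer.step()

    for epoch in adl.remaining_epochs_until(90):
        train(epoch)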