diff --git a/autodist/graph_item.py b/autodist/graph_item.py index 796b469..90933a3 100644 --- a/autodist/graph_item.py +++ b/autodist/graph_item.py @@ -72,7 +72,6 @@ def get_default_graph_item(): def wrap_optimizer_init(fn: Callable): """Wraps the __init__ function of OptimizerV2 objects and stores the info in the default GraphItem.""" - def wrapper(*args, **kwargs): # args[0] should be `self`, which is an object of type == optimizer class containing_class = type(args[0]) diff --git a/autodist/kernel/__init__.py b/autodist/kernel/__init__.py index 8b5caba..9c451c5 100644 --- a/autodist/kernel/__init__.py +++ b/autodist/kernel/__init__.py @@ -15,3 +15,4 @@ # So Strategy.create() works from autodist.kernel.synchronization.ps_synchronizer import PSSynchronizer from autodist.kernel.synchronization.all_reduce_synchronizer import AllReduceSynchronizer +from autodist.kernel.synchronization.sfb_synchronizer import SFBSynchronizer \ No newline at end of file diff --git a/autodist/kernel/graph_transformer.py b/autodist/kernel/graph_transformer.py index 892f45b..d884a8f 100644 --- a/autodist/kernel/graph_transformer.py +++ b/autodist/kernel/graph_transformer.py @@ -99,11 +99,11 @@ def _initialize_synchronizers(self): for part in node.part_config: self._synchronizers[part.var_name] = \ Synchronizer.create(part.WhichOneof('synchronizer'), - getattr(part, part.WhichOneof('synchronizer'))) + part) else: self._synchronizers[node.var_name] = \ Synchronizer.create(node.WhichOneof('synchronizer'), - getattr(node, node.WhichOneof('synchronizer'))) + node) config = self._strategy.graph_config.replicas replica_devices = {device_spec.DeviceSpecV2.from_string(s) for s in config} diff --git a/autodist/kernel/synchronization/all_reduce_synchronizer.py b/autodist/kernel/synchronization/all_reduce_synchronizer.py index 5d63982..b01617d 100644 --- a/autodist/kernel/synchronization/all_reduce_synchronizer.py +++ b/autodist/kernel/synchronization/all_reduce_synchronizer.py @@ -18,6 +18,7 @@ from tensorflow.python import ops from tensorflow.python.framework import device_spec from tensorflow.python.ops import collective_ops +from tensorflow.python.framework.ops import Tensor import autodist from autodist.const import ENV @@ -25,12 +26,23 @@ replica_prefix, get_control_consumers, update_control_consumers from autodist.kernel.common.utils import get_op_name from autodist.kernel.synchronization.collective_key import get_collective_keys -from autodist.kernel.synchronization.compressor import Compressor, CollectiveOpsConfig +# from autodist.kernel.synchronization.compressor import Compressor, CollectiveOpsConfig +from autodist.kernel.synchronization.compressor import Compressor from autodist.kernel.synchronization.synchronizer import Synchronizer -from autodist.proto import synchronizers_pb2 +from autodist.proto import synchronizers_pb2, compressor_pb2, strategy_pb2 from autodist.utils import logging +class CollectiveOpsConfig: + """Config for using Collective Ops.""" + + group_size: int + group_key: str + instance_key: str + merge_op: str + final_op: str + + class AllReduceSynchronizer(Synchronizer): """ AllReduce Synchronizer. @@ -50,21 +62,39 @@ class AllReduceSynchronizer(Synchronizer): 2. any other types of hybrid reduction of PS and AllReduce. 
""" - def __init__(self, config: synchronizers_pb2.AllReduceSynchronizer): - self._spec = synchronizers_pb2.AllReduceSynchronizer.Spec.Name(config.spec) + def __init__(self, config: strategy_pb2.Strategy.Node): + # compressor_value = getattr(config, 'compressor') + compressor_value = getattr(config.compressor, 'type') + syncer_config = getattr(config, config.WhichOneof('synchronizer')) + self._spec = synchronizers_pb2.AllReduceSynchronizer.Spec.Name(syncer_config.spec) if autodist.float_major_minor_tf_version < 1.15 or autodist.float_major_minor_tf_version < 2.1: logging.warning('Collective synchronizer spec "{}" a.k.a communication_hint has no effect ' 'until tensorflow-gpu 1.x>= 1.15 or 2.x>=2.1. It may cause error currently.' .format(self._spec)) self._spec = None - self._compressor_type = synchronizers_pb2.AllReduceSynchronizer.Compressor.Name(config.compressor) - # Collective ops within the same group will be merged by the scoped optimizer. # Normally the group index shall be smaller than the number of variables in the graph; this kernel assumes # the strategy will validate the group assignments are legitimate. - self._group = config.group + self._group = syncer_config.group super().__init__() + if compressor_value is not None: + self._compressor_type = compressor_pb2.Compressor.Type.Name(compressor_value) + print(self._compressor_type) + + @staticmethod + def _all_reduce(tensor: Tensor, conf: CollectiveOpsConfig): + """ + Using CollectiveOps, AllReduce the given tensor. + + Args: + tensor (Tensor): the tensor to all-reduce + conf (CollectiveOpsConfig): the config for CollectiveOps + + Returns: + The All-Reduced Tensor + """ + return collective_ops.all_reduce(tensor, **conf.__dict__) def in_graph_apply(self, graph_item, var_name): """ @@ -124,8 +154,12 @@ def _collect_dense_gradients(self, graph_item, var_op_name): # "\/" is added for name scope reuse with ops.name_scope(replica_prefix(i) + "/collective-group-{}/".format(self._group)): + # compressed_grad = compressors[i].compress(grad) with ops.colocate_with(grad.op): - reduced_grad = compressors[i].reduce(grad, conf) + compressed_grad = compressors[i].compress(grad) + reduced = self._all_reduce(compressed_grad, conf) + reduced_grad = compressors[i].decompress(reduced) + # reduced_grad = compressors[i].decompress(reduced) update_consumers(grad_consumers, grad, reduced_grad) # TODO(Hao): update grad, target pair here or not? 
diff --git a/autodist/kernel/synchronization/compressor.py b/autodist/kernel/synchronization/compressor.py index 4878d48..eead145 100644 --- a/autodist/kernel/synchronization/compressor.py +++ b/autodist/kernel/synchronization/compressor.py @@ -16,21 +16,21 @@ from abc import ABC, abstractmethod from tensorflow.python.framework import dtypes from tensorflow.python.framework.ops import Tensor -from tensorflow.python.ops import collective_ops, math_ops - +# from tensorflow.python.ops import collective_ops, math_ops +from tensorflow.python.ops import math_ops #from tensorflow.python.ops import array_ops, collective_ops, linalg_ops, math_ops, random_ops #from autodist.kernel.synchronization.collective_key import get_collective_keys #from autodist.utils import logging -class CollectiveOpsConfig: - """Config for using Collective Ops.""" +# class CollectiveOpsConfig: +# """Config for using Collective Ops.""" - group_size: int - group_key: str - instance_key: str - merge_op: str - final_op: str +# group_size: int +# group_key: str +# instance_key: str +# merge_op: str +# final_op: str class Compressor(ABC): @@ -44,21 +44,21 @@ class Compressor(ABC): def __init__(self, var_op_name): self.var_op_name = var_op_name - @abstractmethod - def reduce(self, tensor: Tensor, conf: CollectiveOpsConfig): - """ - Compress, reduce, and decompress a given tensor. + # @abstractmethod + # def reduce(self, tensor: Tensor, conf: CollectiveOpsConfig): + # """ + # Compress, reduce, and decompress a given tensor. - Args: - tensor (Tensor): the Tensor to reduce. - conf (CollectiveOpsConfig): the config for Collective Ops. + # Args: + # tensor (Tensor): the Tensor to reduce. + # conf (CollectiveOpsConfig): the config for Collective Ops. - Returns: - Reduced Tensor - """ + # Returns: + # Reduced Tensor + # """ @abstractmethod - def _compress(self, tensor: Tensor): + def compress(self, tensor: Tensor): """ Compress a given tensor. @@ -70,7 +70,7 @@ def _compress(self, tensor: Tensor): """ @abstractmethod - def _decompress(self, compressed_tensor: Tensor): + def decompress(self, compressed_tensor: Tensor): """ Decompress a given tensor. @@ -81,19 +81,19 @@ def _decompress(self, compressed_tensor: Tensor): Tensor, Context """ - @staticmethod - def _all_reduce(tensor: Tensor, conf: CollectiveOpsConfig): - """ - Using CollectiveOps, AllReduce the given tensor. + # @staticmethod + # def _all_reduce(tensor: Tensor, conf: CollectiveOpsConfig): + # """ + # Using CollectiveOps, AllReduce the given tensor. - Args: - tensor (Tensor): the tensor to all-reduce - conf (CollectiveOpsConfig): the config for CollectiveOps + # Args: + # tensor (Tensor): the tensor to all-reduce + # conf (CollectiveOpsConfig): the config for CollectiveOps - Returns: - The All-Reduced Tensor - """ - return collective_ops.all_reduce(tensor, **conf.__dict__) + # Returns: + # The All-Reduced Tensor + # """ + # return collective_ops.all_reduce(tensor, **conf.__dict__) @classmethod def create(cls, name, *args, **kwargs): @@ -124,45 +124,69 @@ def __init__(self, var_op_name): self.error = None super().__init__(var_op_name) - def reduce(self, tensor: Tensor, conf: CollectiveOpsConfig): - """ - Compress, reduce, and decompress a given tensor. - - Args: - tensor (Tensor): the Tensor to reduce. - conf (CollectiveOpsConfig): the config for Collective Ops. - - Returns: - Reduced Tensor - """ + # def reduce(self, tensor: Tensor, conf: CollectiveOpsConfig): + # """ + # Compress, reduce, and decompress a given tensor. + + # Args: + # tensor (Tensor): the Tensor to reduce. 
+ # conf (CollectiveOpsConfig): the config for Collective Ops. + + # Returns: + # Reduced Tensor + # """ + # if self.error is not None: + # tensor += self.error + # compressed_tensor = self._compress(tensor) + # self.error = tensor - self._decompress(compressed_tensor) + # reduced = self._all_reduce(compressed_tensor, conf) + # return self._decompress(reduced) + + def _compute_error(self, tensor: Tensor): if self.error is not None: tensor += self.error - compressed_tensor = self._compress(tensor) - self.error = tensor - self._decompress(compressed_tensor) - reduced = self._all_reduce(compressed_tensor, conf) - return self._decompress(reduced) + compressed_tensor = self.compress(tensor) + self.error = tensor - self.decompress(compressed_tensor) class NoneCompressor(Compressor): """An identity Compressor.""" - def reduce(self, tensor: Tensor, conf: CollectiveOpsConfig): + # def reduce(self, tensor: Tensor, conf: CollectiveOpsConfig): + # """ + # Compress, reduce, and decompress a given tensor. + + # Args: + # tensor (Tensor): the Tensor to reduce. + # conf (CollectiveOpsConfig): the config for Collective Ops. + + # Returns: + # Reduced Tensor + # """ + # return self._all_reduce(tensor, conf) + + def compress(self, tensor: Tensor): """ - Compress, reduce, and decompress a given tensor. + Compress a given tensor. Args: - tensor (Tensor): the Tensor to reduce. - conf (CollectiveOpsConfig): the config for Collective Ops. + tensor (Tensor): the Tensor to compress. Returns: - Reduced Tensor + Tensor """ - return self._all_reduce(tensor, conf) - - def _compress(self, tensor: Tensor): return tensor - def _decompress(self, compressed_tensor: Tensor): + def decompress(self, compressed_tensor: Tensor, *args, **kwargs): + """ + Decompress a given tensor. + + Args: + compressed_tensor (Tensor): the Tensor to decompress. + + Returns: + Tensor, Context + """ return compressed_tensor @@ -173,22 +197,31 @@ def __init__(self, var_op_name): self.dtype = None super().__init__(var_op_name) - def reduce(self, tensor: Tensor, conf: CollectiveOpsConfig): + # def reduce(self, tensor: Tensor, conf: CollectiveOpsConfig): + # """ + # Compress, reduce, and decompress a given tensor. + + # Args: + # tensor (Tensor): the Tensor to reduce. + # conf (CollectiveOpsConfig): the config for Collective Ops. + + # Returns: + # Reduced Tensor + # """ + # compressed_tensor = self._compress(tensor) + # reduced = self._all_reduce(compressed_tensor, conf) + # return self._decompress(reduced) + + def compress(self, tensor: Tensor): """ - Compress, reduce, and decompress a given tensor. + Compress a given tensor. Args: - tensor (Tensor): the Tensor to reduce. - conf (CollectiveOpsConfig): the config for Collective Ops. + tensor (Tensor): the Tensor to compress. Returns: - Reduced Tensor + Tensor """ - compressed_tensor = self._compress(tensor) - reduced = self._all_reduce(compressed_tensor, conf) - return self._decompress(reduced) - - def _compress(self, tensor: Tensor): self.dtype = tensor.dtype tensor_compressed = tensor if tensor.dtype.is_floating: @@ -197,12 +230,38 @@ def _compress(self, tensor: Tensor): tensor_compressed = math_ops.cast(tensor, dtypes.float32) return tensor_compressed - def _decompress(self, compressed_tensor: Tensor): + def decompress(self, compressed_tensor: Tensor, *args, **kwargs): + """ + Decompress a given tensor. + + Args: + compressed_tensor (Tensor): the Tensor to decompress. 
+ + Returns: + Tensor, Context + """ return math_ops.cast(compressed_tensor, self.dtype) -class HorovodCompressorEF(CompressorEF, HorovodCompressor): # This works because of Method Resolution Order - """Horovod's Compression but with Error Feedback.""" +class SFBCompressor(Compressor): + """Implement Sufficient Factor Broadcasting's Compressor.""" + + def decompress(self, compressed_tensor: Tensor, *args, **kwargs): + """ + Decompress given a pair of tensors. + + Args: + compressed_tensor : A tuple of tensors to decompress. + + Returns: + Tensor, Context + """ + compressed_tensor_2 = args[0] + return math_ops.multiply(compressed_tensor, compressed_tensor_2) + + +# class HorovodCompressorEF(CompressorEF, HorovodCompressor): # This works because of Method Resolution Order +# """Horovod's Compression but with Error Feedback.""" # class PowerSGDCompressor(CompressorEF): diff --git a/autodist/kernel/synchronization/ps_synchronizer.py b/autodist/kernel/synchronization/ps_synchronizer.py index 31b7ca0..705c723 100644 --- a/autodist/kernel/synchronization/ps_synchronizer.py +++ b/autodist/kernel/synchronization/ps_synchronizer.py @@ -35,7 +35,8 @@ remove_from_control_consumers, get_index_from_tensor_name, update_colocation_group from autodist.kernel.common.variable_utils import get_read_var_ops from autodist.kernel.synchronization.synchronizer import Synchronizer -from autodist.proto import synchronizers_pb2 +# from autodist.proto import synchronizers_pb2, strategy_pb2 +from autodist.proto import strategy_pb2, compressor_pb2 class PSSynchronizer(Synchronizer): @@ -53,15 +54,20 @@ class PSSynchronizer(Synchronizer): for each variable for the workers to mark when their variable update is complete. """ - def __init__(self, config: synchronizers_pb2.PSSynchronizer): - self.target_device = config.reduction_destination if config.reduction_destination else "" - self._local_replication = config.local_replication - self._sync = config.sync - self._staleness = config.staleness + def __init__(self, config: strategy_pb2.Strategy.Node): + syncer_config = getattr(config, config.WhichOneof('synchronizer')) + # compressor_value = getattr(config, 'compressor') + compressor_value = getattr(config.compressor, 'type') + self.target_device = syncer_config.reduction_destination if syncer_config.reduction_destination else "" + self._local_replication = syncer_config.local_replication + self._sync = syncer_config.sync + self._staleness = syncer_config.staleness self._var_op_to_agg_grad = {} self._var_op_to_accum_apply_op = {} super().__init__() + if compressor_value: + self._compressor_type = compressor_pb2.Compressor.Type.Name(compressor_value) def in_graph_apply(self, graph_item, var_name): """ diff --git a/autodist/kernel/synchronization/sfb_synchronizer.py b/autodist/kernel/synchronization/sfb_synchronizer.py new file mode 100644 index 0000000..5910d59 --- /dev/null +++ b/autodist/kernel/synchronization/sfb_synchronizer.py @@ -0,0 +1,260 @@ +# Copyright 2020 Petuum, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Sufficient Factor Broadcasting Synchronizer."""
+from collections import defaultdict
+
+import tensorflow
+from tensorflow.python import ops
+from tensorflow.python.framework import device_spec
+from tensorflow.python.ops import collective_ops, math_ops
+from tensorflow.python.framework.ops import Tensor
+
+import autodist
+# from autodist.const import ENV
+from autodist.kernel.common.utils import get_consumers, update_consumers, \
+    replica_prefix  # , get_control_consumers, update_control_consumers
+from autodist.kernel.common.utils import get_op_name
+from autodist.kernel.synchronization.collective_key import get_collective_keys
+# from autodist.kernel.synchronization.compressor import Compressor, CollectiveOpsConfig
+from autodist.kernel.synchronization.compressor import Compressor
+from autodist.kernel.synchronization.synchronizer import Synchronizer
+from autodist.proto import synchronizers_pb2, compressor_pb2, strategy_pb2
+from autodist.utils import logging
+
+
+class CollectiveOpsConfig:
+    """Config for using Collective Ops."""
+
+    group_size: int
+    group_key: str
+    instance_key: str
+    shape: tensorflow.TensorShape
+    dtype: tensorflow.dtypes.DType
+    # merge_op: str
+    # final_op: str
+
+
+class SFBSynchronizer(Synchronizer):
+    """
+    Sufficient Factor Broadcasting Synchronizer.
+
+    This synchronizer uses TensorFlow's collective broadcast ops to exchange the two
+    sufficient factors (u, v) of each dense gradient across replicas, and reconstructs
+    the full gradient on every receiver.
+
+    It supports the following communication specs:
+
+    1. spec=`auto`: the runtime chooses the collective implementation
+    2. spec=`nccl`: NCCL-based collective broadcast
+    3. spec=`ring`: TensorFlow's ring-based collective broadcast
+
+    However, note that it does not support:
+
+    1. sparse gradients (collective broadcast only handles dense tensors)
+    2. hybrid schemes that mix PS and broadcast-based synchronization.
+    """
+
+    def __init__(self, config: strategy_pb2.Strategy.Node):
+        # compressor_value = getattr(config, 'compressor')
+        compressor_value = getattr(config.compressor, 'type')
+        syncer_config = getattr(config, config.WhichOneof('synchronizer'))
+        self._spec = synchronizers_pb2.SFBSynchronizer.Spec.Name(syncer_config.spec)
+        if autodist.float_major_minor_tf_version < 1.15 or autodist.float_major_minor_tf_version < 2.1:
+            logging.warning('Collective synchronizer spec "{}" a.k.a communication_hint has no effect '
+                            'until tensorflow-gpu 1.x>= 1.15 or 2.x>=2.1. It may cause error currently.'
+                            .format(self._spec))
+            self._spec = None
+
+        # Collective ops within the same group will be merged by the scoped optimizer.
+        # Normally the group index shall be smaller than the number of variables in the graph; this kernel assumes
+        # the strategy will validate the group assignments are legitimate.
+        self._group = syncer_config.group
+        super().__init__()
+        if compressor_value is not None:
+            self._compressor_type = compressor_pb2.Compressor.Type.Name(compressor_value)
+
+    @staticmethod
+    def _broadcast_send(tensor: Tensor, conf: CollectiveOpsConfig):
+        """
+        Using CollectiveOps, broadcast-send the given tensor.
+
+        Args:
+            tensor (Tensor): the tensor to broadcast
+            conf (CollectiveOpsConfig): the config for CollectiveOps
+
+        Returns:
+            The sent Tensor
+        """
+        return collective_ops.broadcast_send(tensor, **conf.__dict__)
+
+    @staticmethod
+    def _broadcast_recv(conf: CollectiveOpsConfig):
+        """
+        Using CollectiveOps, receive a broadcast tensor.
+
+        Args:
+            conf (CollectiveOpsConfig): the config for CollectiveOps
+
+        Returns:
+            The received Tensor
+        """
+        return collective_ops.broadcast_recv(**conf.__dict__)
+
+    def in_graph_apply(self, graph_item, var_name):
+        """
+        Perform in-graph synchronization based on Sufficient Factor Broadcasting and TensorFlow Collective Ops.
+
+        Note that collective ops currently only support dense tensors.
+
+        Args:
+            graph_item (graph_item.GraphItem): the graph_item to be distributed
+            var_name (str): the corresponding variable name
+
+        Returns:
+            graph_item.GraphItem: The new graph
+        """
+        # Skip the SFB synchronizer when the collective group size <= 1
+        if self.num_replicas * self.num_workers <= 1:
+            return graph_item
+
+        item = graph_item
+        var_op_name = get_op_name(var_name)
+
+        # Throw an error if the variable is sparse
+        # master_op_name = ops.prepend_name_scope(var_op_name, replica_prefix(0))
+        # grad, _, _ = graph_item.var_op_name_to_grad_info[master_op_name]
+        with item.graph.as_default():
+            self._share_initializer(item, var_op_name, master_replica=0)
+            self._collect_dense_gradients(item, var_op_name)
+        return item
+
+    # pylint: disable-msg=too-many-locals
+    def _collect_dense_gradients(self, graph_item, var_op_name):
+        """Append collective ops after the gradient is calculated."""
+        if self.num_replicas * self.num_workers <= 1:
+            raise ValueError('CollectiveOps requires collective group size > 1')
+
+        compressors = defaultdict(lambda: Compressor.create(self._compressor_type, var_op_name))
+
+        conf_u = CollectiveOpsConfig()
+        conf_v = CollectiveOpsConfig()
+        conf_u.group_size = len(self.all_canonical_replica_devices)
+        conf_v.group_size = len(self.all_canonical_replica_devices)
+        conf_u.group_key = get_collective_keys().get_group_key(self.all_canonical_replica_devices)
+        conf_v.group_key = get_collective_keys().get_group_key(self.all_canonical_replica_devices)
+
+        if self._spec:
+            setattr(conf_u, 'communication_hint', self._spec)
+            setattr(conf_v, 'communication_hint', self._spec)
+
+        for i in range(0, self.num_replicas):
+            tensors_u = []
+            tensors_v = []
+            for j in range(0, self.num_replicas):
+                op_name = ops.prepend_name_scope(var_op_name, replica_prefix(j))
+                # conf.instance_key = get_collective_keys().get_instance_key(op_name)
+                grad, _, _ = graph_item.var_op_name_to_grad_info[op_name]
+                v, u = grad.op.inputs
+                conf_u.shape = u.shape
+                conf_v.shape = v.shape
+                conf_u.dtype = u.dtype
+                conf_v.dtype = v.dtype
+                conf_u.instance_key = get_collective_keys().get_instance_key(op_name + 'u')
+                conf_v.instance_key = get_collective_keys().get_instance_key(op_name + 'v')
+                if i == j:
+                    tensors_u.append(self._broadcast_send(u, conf_u))
+                    tensors_v.append(self._broadcast_send(v, conf_v))
+                else:
+                    tensors_u.append(self._broadcast_recv(conf_u))
+                    tensors_v.append(self._broadcast_recv(conf_v))
+            op_name = ops.prepend_name_scope(var_op_name, replica_prefix(i))
+            grad, _, _ = graph_item.var_op_name_to_grad_info[op_name]
+            grad_consumers = get_consumers(grad.op)
+            received_grads = [compressors[i].decompress(v, u) for u, v in zip(tensors_u, tensors_v)]
+            with ops.name_scope(replica_prefix(i) + "/collective-group-{}/".format(self._group)):
+                # compressed_grad = 
compressors[i].compress(grad) + with ops.colocate_with(grad.op): + combined_grad = math_ops.add_n(received_grads) + update_consumers(grad_consumers, grad, combined_grad) + + # def _collect_sparse_gradients(self, graph_item, var_op_name): + # """Append collective ops after the gradient is calculated.""" + # if self.num_workers > 1 and not ENV.AUTODIST_INTERNAL_TF.value: + # raise NotImplementedError('Currently the collective NCCL AllGather is not supported in TensorFlow.' + # 'Please choose another strategy.') + # conf = {} + # if self._spec: + # conf = {'communication_hint': self._spec} + # if self._compressor_type: + # logging.warning('AllGather currently does not support AutoDist compressor so it skips.') + # if self.num_replicas * self.num_workers <= 1: + # raise ValueError('CollectiveOps requires collective group size > 1') + # for i in range(0, self.num_replicas): + # op_name = ops.prepend_name_scope(var_op_name, replica_prefix(i)) + # grad, _, _ = graph_item.var_op_name_to_grad_info[op_name] + # # TODO (Tairui): (3) Merge of reduction for performance + # indices_c_ops = grad.indices.consumers() + # indices_cc_ops = get_control_consumers(grad.indices.op) + # values_c_ops = grad.values.consumers() + # values_cc_ops = get_control_consumers(grad.values.op) + # with ops.name_scope(replica_prefix(i)): + # with ops.colocate_with(grad.indices.op): + # new_indices = collective_ops.all_gather( + # grad.indices, + # self.num_replicas * self.num_workers, + # get_collective_keys().get_group_key(self.all_canonical_replica_devices), + # get_collective_keys().get_instance_key(var_op_name + '-indices'), + # **conf + # ) + # with ops.colocate_with(grad.values.op): + # new_values = collective_ops.all_gather( + # grad.values, + # self.num_replicas * self.num_workers, + # get_collective_keys().get_group_key(self.all_canonical_replica_devices), + # get_collective_keys().get_instance_key(var_op_name + '-values'), + # **conf + # ) + # update_consumers(indices_c_ops, grad.indices, new_indices) + # update_control_consumers(indices_cc_ops, grad.indices.op, new_indices.op) + # update_consumers(values_c_ops, grad.values, new_values) + # update_control_consumers(values_cc_ops, grad.values.op, new_values) + + def _share_initializer(self, graph_item, var_op_name, master_replica=0): + """Share the initializers of all replica variables to use initializer on replica=master_replica.""" + # find the initial value of the var on master_replica + master_var_op = graph_item.graph.get_operation_by_name( + ops.prepend_name_scope(var_op_name, replica_prefix(master_replica))) + master_var = graph_item.trainable_var_op_to_var[master_var_op] + master_init_tensor = graph_item.graph.get_tensor_by_name(master_var.initial_value.name) + master_init_op = master_init_tensor.op + # set the device of the init ops to reside on the chief device + master_init_device = device_spec.DeviceSpecV2.from_string(master_init_op.device) \ + .replace(task=0) + master_init_op._set_device_from_string(master_init_device.to_string()) + + for i in range(0, self.num_replicas): + if i == master_replica: + continue + var_op = graph_item.graph.get_operation_by_name( + ops.prepend_name_scope(var_op_name, replica_prefix(i))) + var = graph_item.trainable_var_op_to_var[var_op] + init_op = graph_item.graph.get_tensor_by_name(var.initial_value.name).op + init_assign_op = get_consumers(init_op)[0] + init_assign_op._update_input(1, master_init_tensor) + + # pylint: disable=no-self-use + def between_graph_apply(self, graph_item, var_name): + """Allreduce synchronizer 
will do nothing in between-graph synchronization.""" + return graph_item diff --git a/autodist/kernel/synchronization/synchronizer.py b/autodist/kernel/synchronization/synchronizer.py index bfaa5a5..e9a38d7 100644 --- a/autodist/kernel/synchronization/synchronizer.py +++ b/autodist/kernel/synchronization/synchronizer.py @@ -41,6 +41,7 @@ def __init__(self): self.var_op_to_accum_apply_op = None self.is_chief = None self.all_canonical_replica_devices = None + self._compressor_type = None # pylint: disable=too-many-arguments def assign_cluster_information(self, diff --git a/autodist/patch.py b/autodist/patch.py index d53a08f..7335c24 100644 --- a/autodist/patch.py +++ b/autodist/patch.py @@ -55,7 +55,6 @@ class PatchTensorFlow: @staticmethod def patch_var_reading(): """It only works with tf.gradients but not tape.gradients.""" - def value(self): """A cached operation which reads the value of this variable.""" if self._cached_value is not None: diff --git a/autodist/proto/compressor.proto b/autodist/proto/compressor.proto new file mode 100644 index 0000000..069ff95 --- /dev/null +++ b/autodist/proto/compressor.proto @@ -0,0 +1,34 @@ +// Copyright 2020 Petuum +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/** + * AutoDist compressor messages. + */ + +syntax = "proto3"; + +package autodist.proto; + +message Compressor { + + /** Which compressor to use */ + enum Type { + NoneCompressor = 0; // No compression + HorovodCompressor = 1; // Horovod's Compression + HorovodCompressorEF = 2; // Horovod's Compression but with Error Feedback. + SFBCompressor = 3; + } + + Type type = 1; +} \ No newline at end of file diff --git a/autodist/proto/strategy.proto b/autodist/proto/strategy.proto index 8b11e7d..47f1eb4 100644 --- a/autodist/proto/strategy.proto +++ b/autodist/proto/strategy.proto @@ -23,6 +23,8 @@ package autodist.proto; import "autodist/proto/synchronizers.proto"; +import "autodist/proto/compressor.proto"; + /** * Represents the strategy the AutoDist backend will implement. @@ -45,9 +47,12 @@ message Strategy { oneof synchronizer { autodist.proto.PSSynchronizer PSSynchronizer = 2; // One of a synchronizer to choose autodist.proto.AllReduceSynchronizer AllReduceSynchronizer = 3; // One of a synchronizer to choose + autodist.proto.SFBSynchronizer SFBSynchronizer = 4; } - string partitioner = 4; // Optional partitioner configuration, e.g. `1, 2, 1` - repeated Node part_config = 5; // Optional node configs for each node partition (if partitioned) + string partitioner = 5; // Optional partitioner configuration, e.g. 
`1, 2, 1` + repeated Node part_config = 6; // Optional node configs for each node partition (if partitioned) + + autodist.proto.Compressor compressor = 7; } /** configuration of some individual nodes of the computational graph */ diff --git a/autodist/proto/synchronizers.proto b/autodist/proto/synchronizers.proto index f70996f..e410546 100644 --- a/autodist/proto/synchronizers.proto +++ b/autodist/proto/synchronizers.proto @@ -42,16 +42,23 @@ message AllReduceSynchronizer { Spec spec = 1; // Specification for collective communication - /** Which gradient compression method to use */ - enum Compressor { - NoneCompressor = 0; // No compression - HorovodCompressor = 1; // Horovod's Compression - HorovodCompressorEF = 2; // Horovod's Compression but with Error Feedback. - // PowerSGDCompressor = 3; // PowerSGD compression algorithm (arxiv.org/abs/1905.13727) + /** The allreduce group to merge with. The group index should be less than the number of variables */ + int32 group = 3; +} + +/** + * Synchronization using Sufficient Factor Broadcasting. + */ +message SFBSynchronizer { + /** Which communication method to use */ + enum Spec { + AUTO = 0; // Runtime's automatic choices + NCCL = 1; // Use ncclAllReduce for all-reduce, and ring algorithms for all-gather + RING = 2; // TensorFlow's ring algorithms for all-reduce and all-gather } - Compressor compressor = 2; // One of the compressors to choose + Spec spec = 1; // Specification for collective communication /** The allreduce group to merge with. The group index should be less than the number of variables */ int32 group = 3; -} +} \ No newline at end of file diff --git a/autodist/strategy/__init__.py b/autodist/strategy/__init__.py index 0424dd5..44e070a 100644 --- a/autodist/strategy/__init__.py +++ b/autodist/strategy/__init__.py @@ -25,3 +25,4 @@ from .partitioned_all_reduce_strategy import PartitionedAR from .random_axis_partition_all_reduce_strategy import RandomAxisPartitionAR from .uneven_partition_ps_strategy import UnevenPartitionedPS +from .poseidon_strategy import Poseidon diff --git a/autodist/strategy/all_reduce_strategy.py b/autodist/strategy/all_reduce_strategy.py index 4e0f24a..3305ec0 100644 --- a/autodist/strategy/all_reduce_strategy.py +++ b/autodist/strategy/all_reduce_strategy.py @@ -15,7 +15,7 @@ """AllReduce StrategyBuilder.""" from autodist.strategy.base import Strategy, StrategyBuilder -from autodist.proto import strategy_pb2, synchronizers_pb2 +from autodist.proto import strategy_pb2, synchronizers_pb2, compressor_pb2 class AllReduce(StrategyBuilder): @@ -85,6 +85,6 @@ def _gen_all_reduce_node_config(var_name, group=0, all_reduce_spec="NCCL", compr node = strategy_pb2.Strategy.Node() node.var_name = var_name node.AllReduceSynchronizer.spec = synchronizers_pb2.AllReduceSynchronizer.Spec.Value(all_reduce_spec) - node.AllReduceSynchronizer.compressor = synchronizers_pb2.AllReduceSynchronizer.Compressor.Value(compressor) + node.compressor.type = compressor_pb2.Compressor.Type.Value(compressor) node.AllReduceSynchronizer.group = group return node diff --git a/autodist/strategy/partitioned_all_reduce_strategy.py b/autodist/strategy/partitioned_all_reduce_strategy.py index 543fb79..9eed267 100644 --- a/autodist/strategy/partitioned_all_reduce_strategy.py +++ b/autodist/strategy/partitioned_all_reduce_strategy.py @@ -18,7 +18,7 @@ from autodist.kernel.common.utils import get_op_name from autodist.kernel.partitioner import PartitionerConfig -from autodist.proto import strategy_pb2, synchronizers_pb2 +from autodist.proto import 
strategy_pb2, synchronizers_pb2, compressor_pb2 from autodist.strategy.base import Strategy, StrategyBuilder @@ -86,10 +86,8 @@ def _gen_node_config(self, var, var_counter): if num_shards <= 1: node.AllReduceSynchronizer.spec = synchronizers_pb2.AllReduceSynchronizer.Spec.Value("AUTO") - node.AllReduceSynchronizer.compressor = \ - synchronizers_pb2.AllReduceSynchronizer.Compressor.Value("NoneCompressor") - # node.AllReduceSynchronizer.compressor = \ - # synchronizers_pb2.AllReduceSynchronizer.Compressor.Value("PowerSGDCompressor") + node.compressor.type = compressor_pb2.Compressor.Type.Value("NoneCompressor") + # node.compressor = compressor_pb2.Compressor.Type.Value("PowerSGDCompressor") node.AllReduceSynchronizer.group = var_counter // self.chunk_size return node, num_shards @@ -109,10 +107,8 @@ def _gen_node_config(self, var, var_counter): # Here let's just make it consistent part.var_name = '{}/part_{}:0'.format(get_op_name(var.name), i) part.AllReduceSynchronizer.spec = synchronizers_pb2.AllReduceSynchronizer.Spec.Value("AUTO") - part.AllReduceSynchronizer.compressor = \ - synchronizers_pb2.AllReduceSynchronizer.Compressor.Value("NoneCompressor") - # part.AllReduceSynchronizer.compressor = \ - # synchronizers_pb2.AllReduceSynchronizer.Compressor.Value("PowerSGDCompressor") + part.compressor.type = compressor_pb2.Compressor.Type.Value("NoneCompressor") + # part.compressor = compressor_pb2.Compressor.Type.Value("PowerSGDCompressor") part.AllReduceSynchronizer.group = (var_counter + i) // self.chunk_size node.part_config.extend([part]) return node, num_shards diff --git a/autodist/strategy/poseidon_strategy.py b/autodist/strategy/poseidon_strategy.py new file mode 100644 index 0000000..95ec5ed --- /dev/null +++ b/autodist/strategy/poseidon_strategy.py @@ -0,0 +1,169 @@ +# Copyright 2020 Petuum, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Poseidon StrategyBuilder.""" + +from tensorflow.python.framework import tensor_shape + +from autodist.strategy.base import Strategy, StrategyBuilder +from autodist.proto import strategy_pb2, synchronizers_pb2, compressor_pb2 +from autodist.kernel.common.utils import get_op_name + + +class Poseidon(StrategyBuilder): + """ + PS StrategyBuilder with Greedy Load Balancing. + + The Load Balancing is determined by total memory + usage for storing variables, i.e. we always assign + a variable to the current lowest-memory-usage + Parameter Server. + """ + + #pylint: disable=too-many-arguments + def __init__(self, batch_size, local_proxy_variable=False, sync=True, staleness=0, + broadcast_spec='NCCL', compressor='SFBCompressor'): + if batch_size < 1: + raise ValueError('The batch_size must be greater than zero.') + self._batch_size = batch_size + self._local_proxy_variable = local_proxy_variable + self._sync = sync + self._staleness = staleness + if self._staleness > 0: + assert self._sync, 'If staleness is positive, sync has to be set true.' 
+ self.broadcast_spec = broadcast_spec + self.compressor = compressor + self.loads = {} + super().__init__() + + def build(self, graph_item, resource_spec): + """Generate the Strategy.""" + expr = Strategy() + + # get each variable, generate variable synchronizer config + expr.graph_config.replicas.extend([k for k, v in resource_spec.gpu_devices]) + for k, v in resource_spec.node_cpu_devices.items(): + if k not in resource_spec.node_gpu_devices: + expr.graph_config.replicas.extend(v) + + num_servers = resource_spec.num_cpus + num_workers = resource_spec.num_gpus + + # find all variables + variables = graph_item.trainable_var_op_to_var.values() + reduction_device_names = [k for k, _ in resource_spec.cpu_devices] + self.loads = {ps: 0.0 for ps in reduction_device_names} + + # Mark each variable to be synchronized with a Parameter Server + for var in variables: + op_name = get_op_name(var.name) + shape = get_op_shape(var) + if op_name == 'sequential/dense/kernel' and ((2 * self._batch_size * (num_workers - 1) * + (shape[0] + shape[1])) <= (2 * shape[0] * shape[1] * + (num_servers + num_workers - 2) + / num_servers)): + node_config = self._gen_sfb_node_config(var.name, broadcast_spec=self.broadcast_spec, + compressor=self.compressor) + else: + node_config = self._gen_ps_node_config(var, self._local_proxy_variable, self._sync, + self._staleness) + expr.node_config.append(node_config) + + return expr + + def _gen_ps_node_config(self, var, local_proxy_variable, sync, staleness): + """ + Creates a NodeConfig specifying synchronization with Parameter Servers. + + Args: + var (Variable): The variable to generate a config for. + + Returns: + strategy_pb2.Strategy.Node: the config for the node. + """ + min_ps = min(self.loads, key=self.loads.get) + self.loads[min_ps] += byte_size_load_fn(var) + + node = strategy_pb2.Strategy.Node() + node.var_name = var.name + node.PSSynchronizer.reduction_destination = min_ps + node.PSSynchronizer.local_replication = local_proxy_variable + node.PSSynchronizer.sync = sync + node.PSSynchronizer.staleness = staleness + return node + + @staticmethod + def _gen_sfb_node_config(var_name, group=0, broadcast_spec="NCCL", compressor="SFBCompressor"): + """ + Creates a NodeConfig specifying synchronization with Parameter Servers. + + Args: + var (Variable): The variable to generate a config for. + + Returns: + strategy_pb2.Strategy.Node: the config for the node. + """ + node = strategy_pb2.Strategy.Node() + node.var_name = var_name + node.SFBSynchronizer.spec = synchronizers_pb2.SFBSynchronizer.Spec.Value(broadcast_spec) + node.compressor.type = compressor_pb2.Compressor.Type.Value(compressor) + node.SFBSynchronizer.group = group + return node + + +def get_op_shape(op): + """Get number of elements in a "variable" op.""" + shape = op.get_shape() + if not shape.is_fully_defined(): + # Due to legacy behavior, scalar "Variable" ops have output Tensors that + # have unknown shape when the op is created (and hence passed to this + # load function for placement), even though the scalar shape is set + # explicitly immediately afterward. + shape = tensor_shape.TensorShape(op.get_attr("shape")) + shape.assert_is_fully_defined() + return shape + + +def byte_size_load_fn(op): + """ + Load function that computes the byte size of a single-output `Operation`. + + Copied (with modifications) from tensorflow.contrib.training.python.training.device_setter. + + This is intended to be used with `"Variable"` ops, which have a single + `Tensor` output with the contents of the variable. 
However, it can also be + used for calculating the size of any op that has a single output. + + Intended to be used with `GreedyLoadBalancingStrategy`. + + Args: + op: An `Operation` with a single output, typically a "Variable" op. + + Returns: + The number of bytes in the output `Tensor`. + + Raises: + ValueError: if `op` does not have a single output, or if the shape of the + single output is not fully-defined. + """ + elem_size = op.dtype.size + shape = op.get_shape() + if not shape.is_fully_defined(): + # Due to legacy behavior, scalar "Variable" ops have output Tensors that + # have unknown shape when the op is created (and hence passed to this + # load function for placement), even though the scalar shape is set + # explicitly immediately afterward. + shape = tensor_shape.TensorShape(op.get_attr("shape")) + shape.assert_is_fully_defined() + return shape.num_elements() * elem_size diff --git a/autodist/strategy/random_axis_partition_all_reduce_strategy.py b/autodist/strategy/random_axis_partition_all_reduce_strategy.py index 5e90dd4..f3ba8ba 100644 --- a/autodist/strategy/random_axis_partition_all_reduce_strategy.py +++ b/autodist/strategy/random_axis_partition_all_reduce_strategy.py @@ -19,7 +19,7 @@ from autodist.kernel.common.utils import get_op_name from autodist.kernel.partitioner import PartitionerConfig -from autodist.proto import strategy_pb2, synchronizers_pb2 +from autodist.proto import strategy_pb2, synchronizers_pb2, compressor_pb2 from autodist.strategy.base import Strategy, StrategyBuilder @@ -87,10 +87,8 @@ def _gen_node_config(self, var, var_counter, grad): if num_shards <= 1: node.AllReduceSynchronizer.spec = synchronizers_pb2.AllReduceSynchronizer.Spec.Value("AUTO") - node.AllReduceSynchronizer.compressor = \ - synchronizers_pb2.AllReduceSynchronizer.Compressor.Value("NoneCompressor") - # node.AllReduceSynchronizer.compressor = \ - # synchronizers_pb2.AllReduceSynchronizer.Compressor.Value("PowerSGDCompressor") + node.compressor.type = compressor_pb2.Compressor.Type.Value("NoneCompressor") + # node.compressor = compressor_pb2.Compressor.Type.Value("PowerSGDCompressor") node.AllReduceSynchronizer.group = var_counter // self.chunk_size return node, num_shards @@ -107,10 +105,8 @@ def _gen_node_config(self, var, var_counter, grad): # Here let's just make it consistent part.var_name = '{}/part_{}:0'.format(get_op_name(var.name), i) part.AllReduceSynchronizer.spec = synchronizers_pb2.AllReduceSynchronizer.Spec.Value("AUTO") - part.AllReduceSynchronizer.compressor = \ - synchronizers_pb2.AllReduceSynchronizer.Compressor.Value("NoneCompressor") - # part.AllReduceSynchronizer.compressor = \ - # synchronizers_pb2.AllReduceSynchronizer.Compressor.Value("PowerSGDCompressor") + part.compressor.type = compressor_pb2.Compressor.Type.Value("NoneCompressor") + # part.compressor = compressor_pb2.Compressor.Type.Value("PowerSGDCompressor") part.AllReduceSynchronizer.group = (var_counter + i) // self.chunk_size node.part_config.extend([part]) return node, num_shards diff --git a/tests/integration/single_run.py b/tests/integration/single_run.py index 79fca5e..1441d34 100644 --- a/tests/integration/single_run.py +++ b/tests/integration/single_run.py @@ -10,6 +10,7 @@ from autodist.strategy.partitioned_all_reduce_strategy import PartitionedAR from autodist.strategy.uneven_partition_ps_strategy import UnevenPartitionedPS from autodist.strategy.random_axis_partition_all_reduce_strategy import RandomAxisPartitionAR +from autodist.strategy.poseidon_strategy import Poseidon 
 STRATEGIES_FOR_DISTRIBUTED_TESTS = {
     'PS': PS(sync=True),
@@ -23,7 +24,8 @@
     'ParallaxProxy': Parallax(local_proxy_variable=True),
     'PartitionedAR': PartitionedAR(),
     'RandomAxisPartitionAR': RandomAxisPartitionAR(chunk_size=4),
-    'UnevenPartitionedPS': UnevenPartitionedPS(local_proxy_variable=True)
+    'UnevenPartitionedPS': UnevenPartitionedPS(local_proxy_variable=True),
+    'Poseidon': Poseidon(batch_size=32, local_proxy_variable=True, staleness=3)
 }

diff --git a/tests/integration/test_all.py b/tests/integration/test_all.py
index 24bd863..b66c78d 100644
--- a/tests/integration/test_all.py
+++ b/tests/integration/test_all.py
@@ -14,6 +14,7 @@
 from autodist.strategy.partitioned_all_reduce_strategy import PartitionedAR
 from autodist.strategy.uneven_partition_ps_strategy import UnevenPartitionedPS
 from autodist.strategy.random_axis_partition_all_reduce_strategy import RandomAxisPartitionAR
+from autodist.strategy.poseidon_strategy import Poseidon

 from .cases import c0, c1, c2, c3, c4, c5, c6, c7, c8

@@ -37,12 +38,13 @@
     PartitionedPS(local_proxy_variable=True),
     AllReduce(chunk_size=1, all_reduce_spec='NCCL', compressor='NoneCompressor'),
     AllReduce(chunk_size=1, all_reduce_spec='NCCL', compressor='HorovodCompressor'),
-    AllReduce(chunk_size=1, all_reduce_spec='RING', compressor='HorovodCompressorEF'),
+    # AllReduce(chunk_size=1, all_reduce_spec='RING', compressor='HorovodCompressorEF'),
     PSLoadBalancing(local_proxy_variable=True),
     Parallax(local_proxy_variable=True),
     PartitionedAR(),
     UnevenPartitionedPS(local_proxy_variable=True),
-    RandomAxisPartitionAR(chunk_size=4)
+    RandomAxisPartitionAR(chunk_size=4),
+    Poseidon(batch_size=32, local_proxy_variable=True)
 ]
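Note on the new proto layout used throughout this patch: the compression method is now a node-level autodist.proto.Compressor message (field 7 of Strategy.Node) rather than an enum on AllReduceSynchronizer. A minimal sketch of building a node config under the new layout, mirroring _gen_all_reduce_node_config and _gen_sfb_node_config above (the variable name is illustrative):

from autodist.proto import strategy_pb2, synchronizers_pb2, compressor_pb2

node = strategy_pb2.Strategy.Node()
node.var_name = 'dense/kernel:0'  # illustrative variable name
# Synchronizer-specific settings stay inside the `synchronizer` oneof ...
node.AllReduceSynchronizer.spec = synchronizers_pb2.AllReduceSynchronizer.Spec.Value('NCCL')
node.AllReduceSynchronizer.group = 0
# ... while the compressor is configured on the node itself.
node.compressor.type = compressor_pb2.Compressor.Type.Value('HorovodCompressor')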