1 change: 0 additions & 1 deletion autodist/graph_item.py
@@ -72,7 +72,6 @@ def get_default_graph_item():

def wrap_optimizer_init(fn: Callable):
"""Wraps the __init__ function of OptimizerV2 objects and stores the info in the default GraphItem."""

def wrapper(*args, **kwargs):
# args[0] should be `self`, which is an object of type == optimizer class
containing_class = type(args[0])
1 change: 1 addition & 0 deletions autodist/kernel/__init__.py
@@ -15,3 +15,4 @@
# So Strategy.create() works
from autodist.kernel.synchronization.ps_synchronizer import PSSynchronizer
from autodist.kernel.synchronization.all_reduce_synchronizer import AllReduceSynchronizer
from autodist.kernel.synchronization.sfb_synchronizer import SFBSynchronizer
4 changes: 2 additions & 2 deletions autodist/kernel/graph_transformer.py
@@ -99,11 +99,11 @@ def _initialize_synchronizers(self):
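# NOTE: Synchronizer.create is now handed the whole strategy node (or partition
# config) rather than only its synchronizer sub-message, so a synchronizer can
# also read sibling per-variable fields such as the compressor type
# (see AllReduceSynchronizer.__init__ in all_reduce_synchronizer.py).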
for part in node.part_config:
self._synchronizers[part.var_name] = \
Synchronizer.create(part.WhichOneof('synchronizer'),
getattr(part, part.WhichOneof('synchronizer')))
part)
else:
self._synchronizers[node.var_name] = \
Synchronizer.create(node.WhichOneof('synchronizer'),
getattr(node, node.WhichOneof('synchronizer')))
node)

config = self._strategy.graph_config.replicas
replica_devices = {device_spec.DeviceSpecV2.from_string(s) for s in config}
50 changes: 42 additions & 8 deletions autodist/kernel/synchronization/all_reduce_synchronizer.py
@@ -18,19 +18,31 @@
from tensorflow.python import ops
from tensorflow.python.framework import device_spec
from tensorflow.python.ops import collective_ops
from tensorflow.python.framework.ops import Tensor

import autodist
from autodist.const import ENV
from autodist.kernel.common.utils import get_consumers, update_consumers, \
replica_prefix, get_control_consumers, update_control_consumers
from autodist.kernel.common.utils import get_op_name
from autodist.kernel.synchronization.collective_key import get_collective_keys
from autodist.kernel.synchronization.compressor import Compressor, CollectiveOpsConfig
from autodist.kernel.synchronization.compressor import Compressor
from autodist.kernel.synchronization.synchronizer import Synchronizer
from autodist.proto import synchronizers_pb2
from autodist.proto import synchronizers_pb2, compressor_pb2, strategy_pb2
from autodist.utils import logging


class CollectiveOpsConfig:
"""Config for using Collective Ops."""

group_size: int
group_key: str
instance_key: str
merge_op: str
final_op: str
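# A minimal usage sketch (not part of this diff; the key values are hypothetical).
# The attribute names mirror the keyword arguments of
# tensorflow.python.ops.collective_ops.all_reduce, so conf.__dict__ can be
# splatted straight into that call (see _all_reduce below). Real code would
# derive the keys from get_collective_keys(), imported above.
#
#   conf = CollectiveOpsConfig()
#   conf.group_size = 2       # number of replicas participating in the reduce
#   conf.group_key = 1        # shared by all replicas reducing the same group
#   conf.instance_key = 1001  # unique per variable / collective instance
#   conf.merge_op = 'Add'     # element-wise sum across replicas
#   conf.final_op = 'Id'      # keep the sum as-is ('Div' would average it)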


class AllReduceSynchronizer(Synchronizer):
"""
AllReduce Synchronizer.
@@ -50,21 +62,39 @@ class AllReduceSynchronizer(Synchronizer):
2. any other types of hybrid reduction of PS and AllReduce.
"""

def __init__(self, config: synchronizers_pb2.AllReduceSynchronizer):
self._spec = synchronizers_pb2.AllReduceSynchronizer.Spec.Name(config.spec)
def __init__(self, config: strategy_pb2.Strategy.Node):
compressor_value = config.compressor.type
syncer_config = getattr(config, config.WhichOneof('synchronizer'))
self._spec = synchronizers_pb2.AllReduceSynchronizer.Spec.Name(syncer_config.spec)
if autodist.float_major_minor_tf_version < 1.15 or 2.0 <= autodist.float_major_minor_tf_version < 2.1:
logging.warning('Collective synchronizer spec "{}" (a.k.a. communication_hint) has no effect '
'before tensorflow-gpu 1.15 (1.x) or 2.1 (2.x) and may currently cause errors; ignoring it.'
.format(self._spec))
self._spec = None

self._compressor_type = synchronizers_pb2.AllReduceSynchronizer.Compressor.Name(config.compressor)

# Collective ops within the same group will be merged by the scoped optimizer.
# Normally the group index should be smaller than the number of variables in the graph; this kernel assumes
# the strategy has already validated that the group assignments are legitimate.
self._group = config.group
self._group = syncer_config.group
super().__init__()
if compressor_value is not None:
self._compressor_type = compressor_pb2.Compressor.Type.Name(compressor_value)
logging.debug('Using compressor type: {}'.format(self._compressor_type))

@staticmethod
def _all_reduce(tensor: Tensor, conf: CollectiveOpsConfig):
"""
Using CollectiveOps, AllReduce the given tensor.

Args:
tensor (Tensor): the tensor to all-reduce
conf (CollectiveOpsConfig): the config for CollectiveOps

Returns:
The All-Reduced Tensor
"""
return collective_ops.all_reduce(tensor, **conf.__dict__)
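# Usage sketch (assumes a running multi-worker collective setup and a populated
# CollectiveOpsConfig such as the one sketched above):
#
#   summed = AllReduceSynchronizer._all_reduce(grad, conf)
#
# With merge_op='Add' and final_op='Id', `summed` is the element-wise sum of
# `grad` across all conf.group_size replicas; final_op='Div' would divide the
# merged result by the group size to produce an average.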

def in_graph_apply(self, graph_item, var_name):
"""
@@ -124,8 +154,12 @@ def _collect_dense_gradients(self, graph_item, var_op_name):

# "\/" is added for name scope reuse
with ops.name_scope(replica_prefix(i) + "/collective-group-{}/".format(self._group)):
with ops.colocate_with(grad.op):
reduced_grad = compressors[i].reduce(grad, conf)
compressed_grad = compressors[i].compress(grad)
reduced = self._all_reduce(compressed_grad, conf)
reduced_grad = compressors[i].decompress(reduced)
update_consumers(grad_consumers, grad, reduced_grad)
# TODO(Hao): update grad, target pair here or not?
