1 change: 1 addition & 0 deletions autodist/kernel/synchronization/all_reduce_synchronizer.py
@@ -18,6 +18,7 @@
from tensorflow.python import ops
from tensorflow.python.framework import device_spec
from tensorflow.python.ops import collective_ops
from tensorflow.python.framework.ops import Tensor # pylint: disable=unused-import

import autodist
from autodist.const import ENV
189 changes: 108 additions & 81 deletions autodist/kernel/synchronization/compressor.py
@@ -13,13 +13,12 @@
# limitations under the License.

"""Gradient Compressors for All-Reduce."""
import copy
from abc import ABC, abstractmethod
from tensorflow.python.framework import dtypes
from tensorflow.python.framework.ops import Tensor
from tensorflow.python.ops import collective_ops, math_ops

#from tensorflow.python.ops import array_ops, collective_ops, linalg_ops, math_ops, random_ops
#from autodist.kernel.synchronization.collective_key import get_collective_keys
from tensorflow.python.ops import collective_ops, math_ops, random_ops, array_ops, linalg_ops
from autodist.kernel.synchronization.collective_key import get_collective_keys
#from autodist.utils import logging


@@ -205,80 +204,108 @@ class HorovodCompressorEF(CompressorEF, HorovodCompressor):
    """Horovod's Compression but with Error Feedback."""


# class PowerSGDCompressor(CompressorEF):
#     """An implementation of the PowerSGD compression algorithm (arxiv.org/abs/1905.13727)."""

#     def __init__(self, var_op_name, rank=1):
#         self.rank = rank
#         self.og_shape, self.ndims, self.new_shape, self.compressor = None, None, None, None
#         super().__init__(var_op_name)

#     def reduce(self, tensor: Tensor, conf: CollectiveOpsConfig):
#         """
#         Compress, reduce, and decompress a given tensor.

#         Args:
#             tensor (Tensor): the Tensor to reduce.
#             conf (CollectiveOpsConfig): the config for Collective Ops.

#         Returns:
#             Reduced Tensor
#         """
#         if self.og_shape is None:
#             self.og_shape = tensor.shape
#             self.ndims = len(self.og_shape)

#         # Check if rank 1 tensor (this shouldn't be called with sparse tensors)
#         # Just reduce it if it is, no need to compress
#         if self._is_1d:
#             return self._all_reduce(tensor, conf)

#         logging.info(f"Compressing tensor {tensor.name} (var {self.var_op_name}) with shape {tensor.shape}")
#         if self.ndims > 2:
#             tensor = array_ops.reshape(tensor, [self.og_shape[0], -1])

#         if self.compressor is None:
#             self.new_shape = array_ops.shape_v2(tensor)
#             self.compressor = random_ops.random_normal([self.new_shape[1], self.rank])

#         if self.error is not None:
#             tensor += self.error

#         compressed_tensor = self._compress(tensor)
#         self.error = tensor - self._decompress(compressed_tensor)

#         # all reduce mean p
#         reduced = self._all_reduce(compressed_tensor, conf)
#         reduced = self._orthogonalize(reduced)

#         # update compressor
#         self.compressor = math_ops.matmul(tensor, reduced, transpose_a=True)
#         conf.instance_key = get_collective_keys().get_instance_key(self.var_op_name + "/compressor")
#         self.compressor = self._all_reduce(self.compressor, conf)
#         return array_ops.reshape(self._decompress(reduced), self.og_shape) \
#             if self.ndims > 2 else self._decompress(reduced)

#     def _compress(self, tensor: Tensor):
#         return math_ops.matmul(tensor, self.compressor)

#     def _decompress(self, compressed_tensor: Tensor):
#         return math_ops.matmul(compressed_tensor, self.compressor, transpose_b=True)

#     @property
#     def _is_1d(self):
#         return self.ndims <= 1 or (
#             self.ndims == 2 and any(d == 1 for d in self.og_shape)
#         )

#     @staticmethod
#     def _orthogonalize(matrix):
#         _, m = matrix.shape
#         for i in range(m):
#             v = matrix[:, i]
#             v /= linalg_ops.norm_v2(v)
#             v = array_ops.expand_dims_v2(v, 1)

#             begin, rest = matrix[:, :i], matrix[:, (i + 1):]
#             rest -= math_ops.matmul(v, rest, transpose_a=True) * v
#             matrix = array_ops.concat([begin, v, rest], 1)
#         return matrix
class PowerSGDCompressor(CompressorEF):
    """An implementation of the PowerSGD compression algorithm (arxiv.org/abs/1905.13727)."""

    def __init__(self, var_op_name, rank=1):
        self.rank = rank
        self.og_shape, self.ndims = None, None
        self.compressor, self.compressor_conf = None, None  # compressor is the Q matrix in the paper
        self.var_op_name = var_op_name
        super().__init__(var_op_name)

    def reduce(self, tensor: Tensor, conf: CollectiveOpsConfig):
        """
        Compress, reduce, and decompress a given tensor.

        Args:
            tensor (Tensor): the Tensor to reduce.
            conf (CollectiveOpsConfig): the config for Collective Ops.

        Returns:
            Reduced Tensor
        """
        if self.og_shape is None:
            self.og_shape = array_ops.shape_v2(tensor)
            # Static rank of the tensor; 0 when it cannot be inferred.
            if self.og_shape.shape[0] is None:
                self.ndims = 0
            else:
                self.ndims = self.og_shape.shape[0]

        # Rank-0/1 tensors (or 2-D tensors with a singleton dimension) are
        # effectively vectors; reduce them directly without compression.
        if self.ndims <= 1 or (self.ndims == 2 and any(i == 1 for i in tensor.get_shape().as_list())):
            return self._all_reduce(tensor, conf)

        # PowerSGD operates on a 2-D float32 matrix: flatten trailing
        # dimensions and cast, undoing both after decompression.
        og_dtype = tensor.dtype
        tensor = array_ops.reshape(math_ops.cast(tensor, dtypes.float32), [self.og_shape[0], -1])

        # One-time init: Q is sampled with a fixed seed so every worker starts
        # from the same compressor, and it gets its own instance key for the
        # collective op that averages it.
        if self.compressor is None:
            self.compressor = random_ops.random_normal([array_ops.shape_v2(tensor)[1], self.rank],
                                                       seed=1000, dtype=dtypes.float32)

            self.compressor_conf = copy.copy(conf)
            self.compressor_conf.instance_key = get_collective_keys().get_instance_key(self.var_op_name + '/compressor')

        if self.error is not None:
            tensor += self.error

        compressed_tensor = self._compress(tensor)
        self.error = tensor - self._decompress(compressed_tensor)

        # all-reduce (mean) the compressed tensor P
        reduced_tensor = self._all_reduce(compressed_tensor, conf)

        orthonormal_reduced_tensor = self._modified_gram_schmidt(reduced_tensor)

        # update the compressor Q
        self.compressor = math_ops.matmul(tensor, orthonormal_reduced_tensor, transpose_a=True)  # mxn * nxr => mxr

        # all-reduce (mean) the updated compressor across workers
        self.compressor = self._all_reduce(self.compressor, self.compressor_conf)

        return math_ops.cast(array_ops.reshape(self._decompress(orthonormal_reduced_tensor), self.og_shape), og_dtype)

    def _compress(self, tensor: Tensor):
        """
        Compress a given tensor.

        Args:
            tensor (Tensor): the Tensor to compress.

        Returns:
            Tensor
        """
        return math_ops.matmul(tensor, self.compressor)  # nxm * mxr => nxr

    def _decompress(self, compressed_tensor: Tensor):
        """
        Decompress a given tensor.

        Args:
            compressed_tensor (Tensor): the Tensor to decompress.

        Returns:
            Tensor
        """
        return math_ops.matmul(compressed_tensor, self.compressor, transpose_b=True)  # nxr * rxm => nxm

    @staticmethod
    def _modified_gram_schmidt(matrix):
        """
        Orthonormalize the columns of a matrix via the modified Gram-Schmidt procedure.

        Args:
            matrix (Tensor): the matrix whose columns to orthonormalize.

        Returns:
            Tensor: the column-orthonormal matrix.
        """
        _, m = matrix.shape

        for i in range(m):
            v = matrix[:, i:(i + 1)]
            v /= linalg_ops.norm_v2(v, axis=0)

            rest = matrix[:, (i + 1):]
            rest -= math_ops.reduce_sum_v1(v * rest, axis=0, keepdims=True) * v
            matrix = array_ops.concat([matrix[:, :i], v, rest], 1)
        return matrix
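To make the data flow above concrete, here is a minimal single-process NumPy sketch of one PowerSGD round. It is an illustration, not part of the library: the two collective all-reduce calls are omitted (with a single worker the mean is the tensor itself), `np.linalg.qr` stands in for `_modified_gram_schmidt`, the residual is computed from the decompressed value that is actually applied (the paper's formulation), and `power_sgd_round` and its arguments are hypothetical names.

```python
import numpy as np

def power_sgd_round(grad, q, error):
    """One PowerSGD round: error feedback, compress, orthonormalize, update Q, decompress."""
    m = grad + error               # fold the residual from the previous round back in
    p = m @ q                      # compress: (n x m) @ (m x r) -> (n x r); P is what gets all-reduced
    p_hat, _ = np.linalg.qr(p)     # orthonormalize the columns of P
    q_new = m.T @ p_hat            # update the compressor: (m x n) @ (n x r) -> (m x r)
    approx = p_hat @ q_new.T       # decompress: (n x r) @ (r x m) -> (n x m)
    return approx, q_new, m - approx

rng = np.random.default_rng(1000)
grad = rng.normal(size=(64, 32))   # a dense 2-D "gradient"
q = rng.normal(size=(32, 1))       # rank-1 compressor, matching the rank=1 default above
error = np.zeros_like(grad)
for _ in range(3):
    approx, q, error = power_sgd_round(grad, q, error)
    # A rank-1 approximation only carries the dominant component each round;
    # the residual is fed back rather than lost.
    print(np.linalg.norm(error) / np.linalg.norm(grad))
```

Per round, only P (n·r values) and the updated Q (m·r values) cross the network instead of the full n·m gradient, which is the bandwidth saving PowerSGD targets.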
2 changes: 1 addition & 1 deletion autodist/proto/synchronizers.proto
@@ -47,7 +47,7 @@ message AllReduceSynchronizer {
    NoneCompressor = 0; // No compression
    HorovodCompressor = 1; // Horovod's Compression
    HorovodCompressorEF = 2; // Horovod's Compression but with Error Feedback.
    PowerSGDCompressor = 3; // PowerSGD compression algorithm (arxiv.org/abs/1905.13727)
  }

  Compressor compressor = 2; // One of the compressors to choose
1 change: 1 addition & 0 deletions tests/integration/test_all.py
@@ -38,6 +38,7 @@
    AllReduce(chunk_size=1, all_reduce_spec='NCCL', compressor='NoneCompressor'),
    AllReduce(chunk_size=1, all_reduce_spec='NCCL', compressor='HorovodCompressor'),
    AllReduce(chunk_size=1, all_reduce_spec='RING', compressor='HorovodCompressorEF'),
    AllReduce(chunk_size=1, all_reduce_spec='RING', compressor='PowerSGDCompressor'),
    PSLoadBalancing(local_proxy_variable=True),
    Parallax(local_proxy_variable=True),
    PartitionedAR(),
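For context, any strategy builder in this list is handed to AutoDist in the usual way. A minimal sketch, assuming AutoDist's documented entry point; the resource-spec path and the graph-building step are illustrative:

```python
from autodist import AutoDist
from autodist.strategy import AllReduce

# 'resource_spec.yml' is an illustrative path to a cluster resource spec.
autodist = AutoDist(
    resource_spec_file='resource_spec.yml',
    strategy_builder=AllReduce(chunk_size=1, all_reduce_spec='RING',
                               compressor='PowerSGDCompressor'),
)
with autodist.scope():
    ...  # build the TensorFlow graph; gradients are all-reduced with PowerSGD compression
```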