From ad5408ff1b61e83bd409a212a806ec1a24c5fece Mon Sep 17 00:00:00 2001 From: Christy Li Date: Tue, 15 Dec 2020 13:08:21 -0500 Subject: [PATCH 1/4] add simulator code --- autodist/simulator/__init__.py | 0 autodist/simulator/config.py | 96 ++++ autodist/simulator/models/__init__.py | 0 autodist/simulator/models/base.py | 417 ++++++++++++++++++ .../simulator/models/predefined_simulator.py | 379 ++++++++++++++++ autodist/simulator/reorganize_data.py | 18 + autodist/simulator/simulate.py | 35 ++ .../train_predefined_similator_clean.py | 361 +++++++++++++++ autodist/simulator/utils.py | 135 ++++++ 9 files changed, 1441 insertions(+) create mode 100644 autodist/simulator/__init__.py create mode 100644 autodist/simulator/config.py create mode 100644 autodist/simulator/models/__init__.py create mode 100644 autodist/simulator/models/base.py create mode 100644 autodist/simulator/models/predefined_simulator.py create mode 100644 autodist/simulator/reorganize_data.py create mode 100644 autodist/simulator/simulate.py create mode 100644 autodist/simulator/train_predefined_similator_clean.py create mode 100644 autodist/simulator/utils.py diff --git a/autodist/simulator/__init__.py b/autodist/simulator/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/autodist/simulator/config.py b/autodist/simulator/config.py new file mode 100644 index 0000000..5e9b9ff --- /dev/null +++ b/autodist/simulator/config.py @@ -0,0 +1,96 @@ +from pathlib import Path + + +GRAPH_ITEM_DIR = f'{str(Path.home())}/graph_items' +SIMULATION_DATA_DIR = f'{str(Path.home())}/autosync_dataset_release' +CHECKPOINT_DIR = f'{str(Path.home())}' + + +simulation_params = { + 'ncf_large_adam_dense': { + 'model_batch_size': 256, + 'model_seq_len': 1, + 'data_dir': [ + f'{SIMULATION_DATA_DIR}/cluster1/ncf_large_adam_dense_g3.4.25.1', + f'{SIMULATION_DATA_DIR}/cluster1/ncf_large_adam_dense_g3.4.25.1_g3.4.25.2', + f'{SIMULATION_DATA_DIR}/cluster1/ncf_large_adam_dense_g3.4.25.1_g3.4.25.2_2', + f'{SIMULATION_DATA_DIR}/cluster1/ncf_large_adam_dense_g3.4.25.1_g3.4.25.2_g3.4.25.3_g3.4.25.4_3.4.25.6_g3.4.25.7_g3.4.25.8_g3.4.25.9', + f'{SIMULATION_DATA_DIR}/cluster1/ncf_large_adam_dense_g3.4.25.6_g3.4.25.7_g3.4.25.8_g3.4.25.9', + ], + 'original_graph_item_path': f'{GRAPH_ITEM_DIR}/ncf_original_graph_item', + 'save_dir': f'{CHECKPOINT_DIR}/ncf_predefined_checkpoints', + 'save_prefix': 'ckpV1_ncf_large_adam_dense', + 'baseline': 0.15, + 'scale': 0.5, + 'learning_rate': 0.01, + 'list_size': 2, + 'batch_size': 100, + 'ranking_loss_key': 'pairwise_logistic_loss', + 'model_version': 'v1', + 'do_train': False, + 'do_test': True, + 'checkpoint': f'{CHECKPOINT_DIR}/ncf/predefined_checkpoints/ckpV1_ncf_large_adam_dense_orca_all_600_0.83249_0.84517', + }, + 'bert': { + 'model_batch_size': 32, + 'model_seq_len': 128, + 'data_dir': [ + f'{SIMULATION_DATA_DIR}/cluster2/bert_large_random_orca_11', + f'{SIMULATION_DATA_DIR}/cluster2/bert_large_orca_11_random_rej-13_trial-100-_expolre-3000_model-on-bert-orca_embedding_sim-weight-0.3_max-par-20_if-partition-lb-100000-zhijie', + f'{SIMULATION_DATA_DIR}/cluster2/bert_large_orca_11_test_run', + f'{SIMULATION_DATA_DIR}/cluster2/bert_large_orca_11_random_rej-3.5_trial-50_expolre-1000_model-new-2_embedding_sim-1.0', + f'{SIMULATION_DATA_DIR}/cluster2/bert_large_orca_11_random_rej-8_trial-30_expolre-1000_model-new-3_embedding_sim-0.2_ps-only', + f'{SIMULATION_DATA_DIR}/cluster2/bert_large_orca_11_random_rej-8_trial-50_expolre-1000_model-new-3_embedding_sim-0.4_ps-only', + 
f'{SIMULATION_DATA_DIR}/cluster2/bert_large_orca_11_random_rej-13_trial-20-_expolre-100_model-on-bert-orca_embedding_sim-weight-0.3_max-par-20_if-partition-lb-100000', + f'{SIMULATION_DATA_DIR}/bert/bert-aws/bert-large-aws4g4', + f'{SIMULATION_DATA_DIR}/bert/bert-aws/bert_large_random_search_aws_4_ps_only', + ], + 'original_graph_item_path': f'{GRAPH_ITEM_DIR}/bert_original_graph_item_large', + 'save_dir': f'{CHECKPOINT_DIR}/bert_predefined_checkpoints', + 'save_prefix': 'ckpV1_bert_orca', + 'baseline': 0.04, + 'scale': 0.5, + 'learning_rate': 0.01, + 'list_size': 2, + 'batch_size': 100, + 'ranking_loss_key': 'pairwise_logistic_loss', + 'do_train': True, + 'do_test': True, + 'model_version': 'v1', + 'checkpoint': f'{CHECKPOINT_DIR}//bert_predefined_checkpoints/ckpV1_bert_orca_100_0.67000_0.50000', + }, + 'vgg16': { + 'model_batch_size': 32, + 'model_seq_len': 1, + 'data_dir': [ + f'{SIMULATION_DATA_DIR}/cluster1/vgg16_aws4_from_vgg16-orca2aws-421_explore3000', + f'{SIMULATION_DATA_DIR}/cluster1/vgg16_aws-4_model-aws-new_rejection-4_explore-3000_sim-weight-0.75', + f'{SIMULATION_DATA_DIR}/cluster1/vgg16_aws-4_model-aws-only_rejection-8_explore-3000_sim-weight-0.3', + f'{SIMULATION_DATA_DIR}/cluster1/vgg16_aws_4_pure_random', + ], + 'original_graph_item_path': f'{GRAPH_ITEM_DIR}/vgg16_original_graph_item', + 'save_dir': f'{CHECKPOINT_DIR}/vgg16_predefined_checkpoints', + 'save_prefix': 'ckpV1_vgg_aws', + 'baseline': 0.0, + 'scale': 0.5, + 'do_train': True, + 'do_test': True, + 'model_version': 'v1', + 'learning_rate': 0.01, + 'list_size': 2, + 'batch_size': 100, + 'ranking_loss_key': 'pairwise_logistic_loss', + 'checkpoint': '', + }, + 'resnet101': { + 'model_batch_size': 32, + 'model_seq_len': 1, + 'baseline': 0.5, + 'scale': 0.5, + 'data_dir': '', + 'learning_rate': 0.01, + 'list_size': 2, + 'batch_size': 100, + 'ranking_loss_key': 'pairwise_logistic_loss', + }, +} diff --git a/autodist/simulator/models/__init__.py b/autodist/simulator/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/autodist/simulator/models/base.py b/autodist/simulator/models/base.py new file mode 100644 index 0000000..ecd1621 --- /dev/null +++ b/autodist/simulator/models/base.py @@ -0,0 +1,417 @@ +"""Strategy Simulator.""" +from collections import defaultdict +import numpy as np +from enum import Enum + +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape + +from autodist.cluster import SSHCluster +from autodist.resource_spec import ResourceSpec +from autodist.kernel.device.resolver import DeviceResolver +from autodist.graph_item import GraphItem +from autodist.kernel.partitioner import PartitionerConfig +from autodist.proto.synchronizers_pb2 import AllReduceSynchronizer +from autodist.strategy.base import Strategy +from autodist.kernel.common.utils import get_op_name, get_consumers + +from utils import resolve_device_address, get_max_num_local_replica, get_num_local_replica + +GIGABITS = np.float(1e+9) +INFINITY = 1e+9 +NUM_RUNS = 500 + + + +class Var: + def __init__(self, + name=None, + is_sparse=False, + synchronizer=None, + shape=None, + dtype=None, + device=None, + compressor=None): + self.name = name + self.is_sparse = is_sparse + self.synchronizer = synchronizer + self.shape = shape + self.dtype = dtype + self.device = device + self.compressor = compressor + self.device = device + self.is_partition = False + + self.original_shape = self.shape + + @property + def var_size(self): + size = 1 + if self.shape: + for s in self.shape: + size *= s + 
return size + + @property + def original_var_size(self): + size = 1 + if self.original_shape: + for s in self.original_shape: + size *= s + return size + + def size_to_transfer(self, batch_size_per_gpu=1, seq_len=1): + if not self.is_sparse: + return self.var_size + else: + if not self.shape: # scalar + return 1 + + emb_size = 1 + if len(self.shape) > 1: + for i in range(1, len(self.original_shape)): + emb_size = emb_size * self.original_shape[i] + + sparse_data_size = batch_size_per_gpu * seq_len * emb_size + + # estimate the embedding of this partition simply using a proportional formula + ret = sparse_data_size * self.var_size / self.original_var_size + return ret + +class Partition(Var): + def __init__(self, + name=None, + is_sparse=False, + synchronizer=None, + shape=None, + dtype=None, + device=None, + compressor=None, + part_id=0, + original_shape=None, + partition_str=None, + num_shards=1): + super(Partition, self).__init__(name, is_sparse, synchronizer, shape, dtype, device, compressor) + self.is_partition = True + self.part_id = part_id + self.partition_str = partition_str + self.original_shape = original_shape + self.num_shards = num_shards + +class Resource: + def __init__(self, cluster, device_resolver, graph_replicas, network_bandwidth, cpu_worker_list, + gpu_worker_list, max_num_local_replica, total_num_local_replica, worker_num_replicas): + self.cluster=cluster + self.device_resolver=device_resolver + self.graph_replicas=graph_replicas + self.network_bandwidth=network_bandwidth + self.cpu_worker_list=cpu_worker_list + self.gpu_worker_list=gpu_worker_list + self.max_num_local_replica=max_num_local_replica + self.total_num_local_replica=total_num_local_replica + self.worker_num_replicas=worker_num_replicas + +class VarType(Enum): + SPARSE = 0 + DENSE = 1 + +def byte_size_load_fn(op): + """ + Load function that computes the byte size of a single-output `Operation`. + + Copied (with modifications) from tensorflow.contrib.training.python.training.device_setter. + + This is intended to be used with `"Variable"` ops, which have a single + `Tensor` output with the contents of the variable. However, it can also be + used for calculating the size of any op that has a single output. + + Intended to be used with `GreedyLoadBalancingStrategy`. + + Args: + op: An `Operation` with a single output, typically a "Variable" op. + + Returns: + The number of bytes in the output `Tensor`. + + Raises: + ValueError: if `op` does not have a single output, or if the shape of the + single output is not fully-defined. + """ + elem_size = op.dtype.size + shape = op.get_shape() + if not shape.is_fully_defined(): + # Due to legacy behavior, scalar "Variable" ops have output Tensors that + # have unknown shape when the op is created (and hence passed to this + # load function for placement), even though the scalar shape is set + # explicitly immediately afterward. 
+ shape = tensor_shape.TensorShape(op.get_attr("shape")) + shape.assert_is_fully_defined() + return shape.num_elements() * elem_size + + +class VariableHelper: + def __init__(self, var, graph_item): + self.var = var + self.graph_item = graph_item + self._var_op_name = get_op_name(var.name) + self._grad = graph_item.var_op_name_to_grad_info[self._var_op_name][0] + + @property + def var_type(self): + return VarType.DENSE if isinstance(self._grad, ops.Tensor) else VarType.SPARSE + + @property + def is_sparse(self): + return True if self.var_type == VarType.SPARSE else False + + @property + def is_embedding(self): + for op in get_consumers(self.var.op): + if op.type == "ResourceGather": + return True + # op = new_graph_item.graph.get_operation_by_name( + # ops.prepend_name_scope(op.name, ARION_TO_DELETE_SCOPE) + # ) + return False + + @property + def shape(self): + if self.var.initial_value.shape.ndims: + return self.var.initial_value.shape.as_list() + else: + return None + + @property + def partitionable_axis(self): + valid_axis = [] + if not self.shape: + return valid_axis + # Sparse variable can only be partition along the 0th axis + # only sample axis for dense variables + if self.is_sparse or self.is_embedding: + valid_axis = [0] + return valid_axis + for idx, dim in enumerate(self.shape): + if dim > 1: + valid_axis.append(idx) + return valid_axis + + @property + def byte_size(self): + return float(byte_size_load_fn(self.var)) + + @property + def dtype(self): + return self.var.dtype + + +class PartHelper: + def __init__(self, part_idx, var, pc): + self.var = var + self.part_idx = part_idx + self.pc = pc + + @property + def shape(self): + shape = self.var.initial_value.shape.as_list() + dim_size = shape[self.pc.axis] // self.pc.num_shards + extras = shape[self.pc.axis] % self.pc.num_shards + if self.part_idx < extras: + dim_size += 1 + shape[self.pc.axis] = dim_size + return shape + + @property + def var_shape(self): + return self.var.initial_value.shape.as_list() + + @property + def byte_size(self): + return float(byte_size_load_fn(self.var)) \ + * float(self.shape[self.pc.axis]) / float(self.var_shape[self.pc.axis]) + + + +class SimulatorBase: + """Simulates strategies for a given graph and resource spec.""" + + def __init__(self, original_graph_item_path): + self._original_graph_item_path = original_graph_item_path + self._original_graph_item = GraphItem.deserialize(original_graph_item_path) + + def simulate(self, strategy: Strategy, resource_spec: ResourceSpec, checkpoint: str): + """Return simulated runtime value by feeding features to the cost model.""" + raise NotImplementedError() + + def inference(self, inputs, checkpoint): + raise NotImplementedError() + + def load_checkpoint(self, checkpoint): + raise NotImplementedError() + + def save_checkpoint(self, model, checkpoint): + raise NotImplementedError() + + def create_features(self, strategy: Strategy, resource_spec: ResourceSpec): + raise NotImplementedError() + + def extract_pre_feature(self, strategy: Strategy, resource_spec: ResourceSpec, cluster: SSHCluster, + device_resolver: DeviceResolver): + resource = self.setup_resource(resource_spec, cluster, device_resolver) + + name2var = {var.name: var for var_op, var in self._original_graph_item.trainable_var_op_to_var.items()} + + meta = defaultdict() + for node in strategy.node_config: + var_name = node.var_name + # for var_op, var in self._original_graph_item.trainable_var_op_to_var.items(): + # if var.name == var_name: + # break + var = name2var[var_name] + var_helper = 
VariableHelper(var, self._original_graph_item) + + if node.partitioner: + pc = PartitionerConfig(partition_str=node.partitioner) + for i, part in enumerate(node.part_config): + part_helper = PartHelper(i, var, pc) + synchronizer = getattr(part, part.WhichOneof('synchronizer')) + compressor = getattr(synchronizer, 'compressor', None) + reduction_destination = getattr(synchronizer, 'reduction_destination', None) + device = resolve_device_address(reduction_destination if reduction_destination else var.device, + resource.device_resolver) + + part_meta = Partition(name=part.var_name, + is_sparse=var_helper.is_sparse, + shape=part_helper.shape, + dtype=var_helper.dtype, + synchronizer=synchronizer, + part_id=i, + num_shards=pc.num_shards, + partition_str=pc.partition_str, + original_shape=var_helper.shape, + compressor=compressor, + device=device) + meta[part_meta.name] = part_meta + else: + synchronizer = getattr(node, node.WhichOneof('synchronizer')) + compressor = getattr(synchronizer, 'compressor', None) + reduction_destination = getattr(synchronizer, 'reduction_destination', None) + device = resolve_device_address(reduction_destination if reduction_destination else var.device, + resource.device_resolver) + + var_meta = Var(name=var_name, + is_sparse=var_helper.is_sparse, + shape=var_helper.shape, + dtype=var_helper.dtype, + synchronizer=synchronizer, + compressor=compressor, + device=device) + meta[var_meta.name] = var_meta + return meta, resource + + def extract_pre_feature_legacy(self, strategy): + """Don't use now!!!""" + meta = defaultdict() + for node in strategy.node_config: + var_name = node.var_name + for var_op, var in self._original_graph_item.trainable_var_op_to_var.items(): + if var.name == var_name: + break + var_op_name = var_op.name + var_helper = VariableHelper(var, self._original_graph_item) + synchronizer = getattr(node, node.WhichOneof('synchronizer')) + compressor = getattr(synchronizer, 'compressor', None) + if compressor is not None: + compressor = AllReduceSynchronizer.Compressor.Name(compressor) + reduction_destinations = getattr(synchronizer, 'reduction_destinations', None) + if not reduction_destinations or len(reduction_destinations) <= 1: + # this variable is not partitioned + device = reduction_destinations[0] if reduction_destinations else var.device + var_meta = Var(name=var_name, + is_sparse=var_helper.is_sparse, + shape=var_helper.shape, + dtype=var_helper.dtype, + synchronizer=synchronizer, + compressor=compressor, + device=device) + meta[var_meta.name] = var_meta + else: + # this variable is partitioned + num_partitions = len(reduction_destinations) + partition_list = [1] * len(var_helper.shape) + partition_list[0] = num_partitions + pc = PartitionerConfig(partition_list=partition_list) + for i, device in enumerate(reduction_destinations): + part_helper = PartHelper(i, var, pc) + part_meta = Partition(name='{}/part_{}:0'.format(var_op_name, i), + is_sparse=var_helper.is_sparse, + shape=part_helper.shape, + dtype=var_helper.dtype, + synchronizer=synchronizer, + part_id=i, + partition_str=pc.partition_str, + original_shape=var_helper.shape, + compressor=compressor, + device=device) + meta[part_meta.name] = part_meta + return meta + + def setup_resource(self, resource_spec: ResourceSpec, cluster: SSHCluster, device_resolver: DeviceResolver): + graph_replicas = [resolve_device_address(k, device_resolver) for k, v in resource_spec.gpu_devices] + # bandwidth + network_bandwidth = self.network_bandwidth(resource_spec, device_resolver) + # Other information + 
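+        # The worker device lists and per-worker replica counts gathered below are bundled
+        # into a Resource object; the PS and AllReduce time estimates (var_ps_time /
+        # var_ar_time) read them when computing synchronization cost.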
cpu_worker_list = [resolve_device_address(device, device_resolver) for device, _ in resource_spec.cpu_devices] + gpu_worker_list = [resolve_device_address(device, device_resolver) for device, _ in resource_spec.gpu_devices] + max_num_local_replica = get_max_num_local_replica(graph_replicas, cluster) + total_num_local_replica = len(graph_replicas) + worker_num_replicas = [get_num_local_replica(cpu_worker, graph_replicas, cluster) + for cpu_worker in cpu_worker_list] + resource = Resource(cluster=cluster, + device_resolver=device_resolver, + graph_replicas=graph_replicas, + network_bandwidth=network_bandwidth, + cpu_worker_list=cpu_worker_list, + gpu_worker_list=gpu_worker_list, + max_num_local_replica=max_num_local_replica, + total_num_local_replica=total_num_local_replica, + worker_num_replicas=worker_num_replicas) + return resource + + @staticmethod + def network_bandwidth(resource_spec: ResourceSpec, device_resolver: DeviceResolver): + """Calculates all P2P network bandwidths between nodes in the cluster.""" + devices = [device for device, _ in resource_spec.devices] + resolved_devices = [resolve_device_address(device, device_resolver) for device, _ in resource_spec.devices] + gpu_cpu_bw = 10000. # hardcode for now + network_bandwidth = {} # key: + for i in range(len(devices)): + if resolved_devices[i] not in network_bandwidth: + network_bandwidth[resolved_devices[i]] = {} + for j in range(i, len(devices)): + if resolved_devices[j] not in network_bandwidth: + network_bandwidth[resolved_devices[j]] = {} + ip_i = devices[i].split(':')[0] + ip_j = devices[j].split(':')[0] + if ip_i != ip_j: + network_bandwidth[resolved_devices[i]][resolved_devices[j]] \ + = GIGABITS * 20 # resource_spec.network_bandwidth[ip_i] # todo: solve. + network_bandwidth[resolved_devices[j]][resolved_devices[i]] \ + = GIGABITS * 20 # resource_spec.network_bandwidth[ip_j] + else: + network_bandwidth[resolved_devices[i]][resolved_devices[j]] = GIGABITS * gpu_cpu_bw + network_bandwidth[resolved_devices[j]][resolved_devices[i]] = GIGABITS * gpu_cpu_bw + + return network_bandwidth + + @staticmethod + def min_bandwitdh(worker_list, bandwidth): + min_bandwidth = INFINITY + num_workers = len(worker_list) + for i in range(num_workers): + for j in range(i, num_workers): + min_bandwidth = min(min_bandwidth, bandwidth[worker_list[j]][worker_list[i]]) + + @property + def original_graph_item_path(self): + return self._original_graph_item_path diff --git a/autodist/simulator/models/predefined_simulator.py b/autodist/simulator/models/predefined_simulator.py new file mode 100644 index 0000000..1d29146 --- /dev/null +++ b/autodist/simulator/models/predefined_simulator.py @@ -0,0 +1,379 @@ +"""Strategy Simulator.""" + +import pickle as pkl + +import tensorflow as tf +from tensorflow.python.eager import context + +from autodist.strategy.base import Strategy +from autodist.resource_spec import ResourceSpec +from autodist.proto.synchronizers_pb2 import PSSynchronizer, AllReduceSynchronizer +from autodist.kernel.device.resolver import DeviceResolver +from autodist.cluster import SSHCluster + +from models.base import SimulatorBase +from utils import resolved_devices_on_diff_machine, get_sparse_var_bits, get_dense_var_bits + +class PredefinedSimulator(SimulatorBase): + """Simulates strategies for a given graph and resource spec.""" + + def __init__(self, + original_graph_item_path, + fetches=None, + batch_size=1, + seq_len=1, + get_coef=True, + checkpoint=None): + + super(PredefinedSimulator, 
self).__init__(original_graph_item_path=original_graph_item_path) + + print("It's using predefined simulator. batch_size_per_gpu is {}".format(batch_size)) + self._fetches = fetches + self._batch_size_per_gpu = batch_size + self._seq_len = seq_len + self._get_coef = get_coef + self._checkpoint = checkpoint + self._weights = None + with context.eager_mode(): + if self._checkpoint: + self._weights = self.load_checkpoint(self._checkpoint) + + def simulate(self, strategy: Strategy, resource_spec: ResourceSpec, checkpoint=None): + """Return simulated runtime value.""" + cluster = SSHCluster(resource_spec) + device_resolver = DeviceResolver(cluster) + inputs = self.create_features(strategy, resource_spec, cluster, device_resolver) + with context.eager_mode(): + cost = self.inference(inputs, checkpoint) + return cost + + def inference(self, inputs, checkpoint=None): + if checkpoint is not None: + weights = self.load_checkpoint(checkpoint) + elif self._weights is not None: + weights = self._weights + else: + raise ValueError("No checkpoint provided in either initialization or inference.") + + if not isinstance(inputs, tf.Tensor): + inputs = tf.reshape(tf.convert_to_tensor(inputs), [1, len(inputs)]) + + if len(weights) == 4: + W0, b0, W, b = weights + inputs = tf.nn.elu(tf.matmul(inputs, W0) + b0) + cost = tf.matmul(inputs, W) + b + elif len(weights) == 2: + W, b = weights + cost = tf.matmul(inputs, W) + b + else: + raise ValueError + return cost + + def load_checkpoint(self, checkpoint=None): + if checkpoint is None: + if self._checkpoint is not None: + checkpoint = self._checkpoint + else: + raise ValueError("checkpoint is None: {}".format(checkpoint)) + self._weights = pkl.load(open(checkpoint, 'rb')) + # self._weights = json.load(open(checkpoint, 'r')) + print("Loaded checkpoint.") + # print(self._weights) + return self._weights + + def save_checkpoint(self, model, checkpoint): + pkl.dump(model, open(checkpoint, 'wb')) + self._checkpoint = checkpoint + self._weights = model + + def create_features_v0(self, strategy: Strategy, resource_spec: ResourceSpec, cluster: SSHCluster, + device_resolver: DeviceResolver): + var_sync_time, vars, resource = self.predefined_sync_time(strategy, resource_spec) + + # Add up sync time per device to find the slowest server time. 
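+        # The feature vector built below has 9 entries: the per-server max and the sum of
+        # the PS sync times, plus the sum over all-reduce variables, each split into the
+        # three components listed in feature_keys.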
+ feature_keys = ['transmission', 'network_overhead', 'gpu_kernel_memory_latency'] + device_ps_sync_time = {} + var_ar_sync_time = {} + for var_name, sync_time in var_sync_time.items(): + if isinstance(vars[var_name].synchronizer, PSSynchronizer): + device = vars[var_name].device + if device not in device_ps_sync_time: + device_ps_sync_time[device] = {key: 0.0 for key in feature_keys} + for key in feature_keys: + device_ps_sync_time[device][key] += sync_time[0][key] + sync_time[1][key] + + else: # AllReduce + if var_name not in var_ar_sync_time: + var_ar_sync_time[var_name] = {key: 0.0 for key in feature_keys} + for key in feature_keys: + var_ar_sync_time[var_name][key] += sync_time[key] + + max_device_ps_sync_time = {key: 0.0 for key in feature_keys} + sum_device_ps_sync_time = {key: 0.0 for key in feature_keys} + sum_var_ar_sync_time = {key: 0.0 for key in feature_keys} + for key in feature_keys: + max_device_ps_sync_time[key] = max([d[key] for d in device_ps_sync_time.values()] or [0.0]) + sum_device_ps_sync_time[key] = sum([d[key] for d in device_ps_sync_time.values()] or [0.0]) + sum_var_ar_sync_time[key] = sum([d[key] for d in var_ar_sync_time.values()] or [0.0]) + + feat = [max_device_ps_sync_time[key] for key in feature_keys] \ + + [sum_device_ps_sync_time[key] for key in feature_keys] \ + + [sum_var_ar_sync_time[key] for key in feature_keys] + + return feat + + def create_features(self, strategy: Strategy, resource_spec: ResourceSpec, cluster: SSHCluster, + device_resolver: DeviceResolver): + # var_sync_time, vars, resource = self.predefined_sync_time(strategy, resource_spec) + + vars, resource = self.extract_pre_feature(strategy=strategy, resource_spec=resource_spec, + cluster=cluster, device_resolver=device_resolver) + + feature_keys = ['transmission', 'network_overhead', 'gpu_kernel_memory_latency'] + device_ps_sync_time = {} + group_ar_sync_time = {} + + for var_name, var in vars.items(): + if isinstance(var.synchronizer, PSSynchronizer): + sync_time = self.var_ps_time(var, resource) + device = vars[var_name].device + if device not in device_ps_sync_time: + device_ps_sync_time[device] = {key: 0.0 for key in feature_keys} + for key in feature_keys: + device_ps_sync_time[device][key] += sync_time[0][key] + sync_time[1][key] + elif isinstance(var.synchronizer, AllReduceSynchronizer): + sync_time = self.var_ar_time(var, resource) + var_group = sync_time['group'] + if var_group not in group_ar_sync_time: + group_ar_sync_time[var_group] = {key: 0.0 for key in feature_keys} + for key in feature_keys: + group_ar_sync_time[var_group][key] += sync_time[key] + else: + raise ValueError('{}'.format(type(var.synchronizer))) + + max_device_ps_sync_time = {key: 0.0 for key in feature_keys} + sum_device_ps_sync_time = {key: 0.0 for key in feature_keys} + max_group_ar_sync_time = {key: 0.0 for key in feature_keys} + sum_group_ar_sync_time = {key: 0.0 for key in feature_keys} + for key in feature_keys: + max_device_ps_sync_time[key] = max([d[key] for d in device_ps_sync_time.values()] or [0.0]) + sum_device_ps_sync_time[key] = sum([d[key] for d in device_ps_sync_time.values()] or [0.0]) + max_group_ar_sync_time[key] = max([d[key] for d in group_ar_sync_time.values()] or [0.0]) + sum_group_ar_sync_time[key] = sum([d[key] for d in group_ar_sync_time.values()] or [0.0]) + + feat = [max_device_ps_sync_time[key] for key in feature_keys] \ + + [sum_device_ps_sync_time[key] for key in feature_keys] \ + + [max_group_ar_sync_time[key] for key in feature_keys] \ + + [sum_group_ar_sync_time[key] for key in 
feature_keys] + + return feat + + def predefined_sync_time(self, strategy, resource_spec): + """ graph_item: transformed graph item """ + vars, resource = self.extract_pre_feature(strategy=strategy, resource_spec=resource_spec) + # Compute synchronization time for every var + var_sync_time = {} + for var_name, var in vars.items(): + if isinstance(var.synchronizer, PSSynchronizer): + var_sync_time[var_name] = self.var_ps_time(var, resource) + elif isinstance(var.synchronizer, AllReduceSynchronizer): + var_sync_time[var_name] = self.var_ar_time(var, resource) + else: + raise ValueError('{}'.format(type(var.synchronizer))) + return var_sync_time, vars, resource + + def var_ps_time(self, var, resource, network_overhead=0.0, gpu_kernel_memory_latency=0.0): + """Compute synchronization time of a variable in PS strategy.""" + def _helper(worker_list, worker_num_replicas=None): + if worker_num_replicas is None: + worker_num_replicas = [1.0] * len(worker_list) + + this_server_time = 0 + # network transfer: sum up all workers time. equals to the time cost of this server. + # TODO(Hao): didn't consider any parallelization among partitions + for k, worker in enumerate(worker_list): + if resolved_devices_on_diff_machine(var.device, worker): + if var.is_sparse: + this_worker_size = get_sparse_var_bits(var_size_to_transfer) * worker_num_replicas[k] + else: + this_worker_size = get_dense_var_bits(var_size_to_transfer, var.dtype) + this_server_time += this_worker_size / resource.network_bandwidth[var.device][worker] + + if self._get_coef: + return { + 'transmission': this_server_time, + 'network_overhead': len(worker_list), + 'gpu_kernel_memory_latency': resource.max_num_local_replica, + 'constant': 1.0, + # possible affecting factors. + 'var_name': var.name, + 'strategy': 'ps', + 'local_proxy': var.synchronizer.local_replication, + 'is_sparse': var.is_sparse, + 'size_to_transfer': var_size_to_transfer, + 'dtype': str(var.dtype), + # 'server_list': [partition.to_dict() for partition in server_list], + 'worker_list': worker_list, + 'cpu_worker_list': resource.cpu_worker_list, + 'gpu_worker_list': resource.gpu_worker_list, + 'worker_num_replicas': worker_num_replicas, + 'max_num_local_replica': resource.max_num_local_replica, + 'is_ps': True, + } + else: + return this_server_time + len(worker_list) * network_overhead + \ + gpu_kernel_memory_latency * resource.max_num_local_replica + + var_size_to_transfer = var.size_to_transfer(batch_size_per_gpu=self._batch_size_per_gpu, + seq_len=self._seq_len) + + if var.is_sparse: + send_time = _helper(resource.cpu_worker_list, worker_num_replicas=resource.worker_num_replicas) + receive_time = _helper(resource.gpu_worker_list) + else: + send_time = _helper(resource.cpu_worker_list) + if var.synchronizer.local_replication: + receive_time = _helper(resource.cpu_worker_list) + else: + receive_time = _helper(resource.gpu_worker_list) + + return send_time, receive_time + + def var_ar_time(self, var, resource, network_overhead=0.0, gpu_kernel_memory_latency=0.0): + """Compute synchronization time of a variable in AR strategy.""" + worker_list = resource.cpu_worker_list + num_workers = len(worker_list) + min_bandwidth = None + for i in range(num_workers): + for j in range(i, num_workers): + if min_bandwidth is None: + min_bandwidth = resource.network_bandwidth[worker_list[j]][worker_list[i]] + else: + min_bandwidth = min(min_bandwidth, resource.network_bandwidth[worker_list[j]][worker_list[i]]) + + # Compressor + if var.compressor == "PowerSGDCompressor" or var.compressor == 3: 
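+            # PowerSGD approximates an (n, m) gradient with two low-rank factors of shapes
+            # (n, rank) and (m, rank), so roughly n * rank + m * rank float32 values are
+            # transferred instead of n * m.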
+ rank = 10 # currently using default value. So hardcode here. # todo: confirm + # assume var must be a dense variable. + og_shape = var.shape + ndims = len(og_shape) + if ndims <= 1: # no compress + size_to_transfer = var.size_to_transfer(batch_size_per_gpu=self._batch_size_per_gpu, + seq_len=self._seq_len) + else: + if ndims > 2: + n = og_shape[0] + m = 1 + for s in og_shape[1:]: + m *= s # tensor's shape (n, m) + else: + n, m = og_shape[0], og_shape[1] + size_to_transfer = n * rank + m * rank + dtype = tf.float32 + elif var.compressor == "HorovodCompressorEF" or var.compressor == "HorovodCompressor" \ + or var.compressor == 2 or var.compressor == 1: + size_to_transfer = var.size_to_transfer(batch_size_per_gpu=self._batch_size_per_gpu, + seq_len=self._seq_len) + dtype = tf.float32 + elif var.compressor == "NoneCompressor" or var.compressor == 0: + size_to_transfer = var.size_to_transfer(batch_size_per_gpu=self._batch_size_per_gpu, + seq_len=self._seq_len) + dtype = var.dtype + else: + raise ValueError('Compressor does not exist: {}'.format(var.compressor)) + + # todo: chunk_size + # AllReduce communication time + # time = 2 * (num_workers - 1) * get_dense_var_bits(size_to_transfer, dtype) / (min_bandwidth * num_workers) + time = get_dense_var_bits(size_to_transfer, dtype) / min_bandwidth + + if self._get_coef: + return { + 'transmission': time, + 'network_overhead': 1, # len(worker_list), + 'gpu_kernel_memory_latency': resource.max_num_local_replica, + 'constant': 1.0, + # possible affecting factors. + 'var_name': var.name, + 'group': var.synchronizer.group, + 'strategy': 'allreduce', + 'is_sparse': False, + # 'chunk_size': chunk_size, + 'spec': 'NCCL', # default + 'compressor': var.compressor, + 'worker_list': worker_list, + 'num_workers': num_workers, + 'size_to_transfer': size_to_transfer, + 'dtype': str(dtype), + 'min_bandwidth': min_bandwidth, + 'max_num_local_replica': resource.max_num_local_replica, + 'is_ps': False, + } + else: + return time + network_overhead * len(worker_list) \ + + gpu_kernel_memory_latency * resource.max_num_local_replica + + + + # @staticmethod + # def var_ps_time(var_name, is_sparse, local_proxy, server_list, cpu_worker_list, gpu_worker_list, + # max_num_local_replica, worker_num_replicas, network_bandwidth, get_coef, + # network_overhead=0.0, gpu_kernel_memory_latency=0.0): + # """Compute synchrinzation time of a variable in PS strategy.""" + # + # def _helper(worker_list, worker_num_replicas=None): + # if worker_num_replicas is None: + # worker_num_replicas = [1.0] * len(worker_list) + # # Compute the slowest server + # slowest_server_time = 0 + # for j, server in enumerate(server_list): + # if server.size_to_transfer == 0: + # continue + # # network transfer: sum up all workers time. equals to the time cost of this server. + # this_server_time = 0 + # for k, worker in enumerate(worker_list): + # if _resolved_devices_on_diff_machine(server.device, worker): + # if is_sparse: + # this_worker_size = get_sparse_var_bits(server.size_to_transfer) * worker_num_replicas[k] + # else: + # this_worker_size = get_dense_var_bits(server.size_to_transfer, server.dtype) + # this_server_time += this_worker_size / network_bandwidth[server.device][worker] + # slowest_server_time = max(slowest_server_time, this_server_time) + # + # if get_coef: + # return { + # 'transmission': slowest_server_time, + # 'network_overhead': len(worker_list), + # 'gpu_kernel_memory_latency': max_num_local_replica, + # 'constant': 1.0, + # # possible affecting factors. 
+ # 'var_name': var_name, + # 'strategy': 'ps', + # 'local_proxy': local_proxy, + # 'is_sparse': is_sparse, + # 'server_list': [partition.to_dict() for partition in server_list], + # 'worker_list': worker_list, + # 'cpu_worker_list': cpu_worker_list, + # 'gpu_worker_list': gpu_worker_list, + # 'worker_num_replicas': worker_num_replicas, + # 'max_num_local_replica': max_num_local_replica, + # } + # else: + # return slowest_server_time + len(worker_list) * network_overhead + \ + # gpu_kernel_memory_latency * max_num_local_replica + # + # if is_sparse: + # send_time = _helper(cpu_worker_list, worker_num_replicas=worker_num_replicas) + # receive_time = _helper(gpu_worker_list) + # else: + # send_time = _helper(cpu_worker_list) + # if local_proxy: + # receive_time = _helper(cpu_worker_list) + # else: + # receive_time = _helper(gpu_worker_list) + # + # if get_coef: + # # return {key: send_time[key]+receive_time[key] for key in send_time.keys()} + # return send_time, receive_time + # else: + # return send_time, receive_time diff --git a/autodist/simulator/reorganize_data.py b/autodist/simulator/reorganize_data.py new file mode 100644 index 0000000..7f8d8f1 --- /dev/null +++ b/autodist/simulator/reorganize_data.py @@ -0,0 +1,18 @@ +import glob +import os +import shutil + +data_dir = '/home/christy.li/autosync_dataset_release' +for cluster in ['cluster1', 'cluster2']: + folders = glob.glob(os.path.join(data_dir, cluster, '*')) + print('\n{} folders in {}'.format(len(folders), cluster)) + for folder in folders: + print('\nfolder', folder) + resource_folder = os.path.join(folder, 'resource_specs') + if os.path.exists(resource_folder): + resource_files = glob.glob(os.path.join(resource_folder, '*')) + shutil.copy(resource_files[0], folder) + os.rename('/'.join(resource_files[0].split('/')[:-2]+[resource_files[0].split('/')[-1]]), + os.path.join(folder, 'resource_spec.yml')) + shutil.rmtree(resource_folder) + diff --git a/autodist/simulator/simulate.py b/autodist/simulator/simulate.py new file mode 100644 index 0000000..8eedcd1 --- /dev/null +++ b/autodist/simulator/simulate.py @@ -0,0 +1,35 @@ +import glob + +from tensorflow.python.eager import context + +from models.predefined_simulator import PredefinedSimulator +from autodist.strategy import base +from autodist.resource_spec import ResourceSpec + +from pathlib import Path + +GRAPH_ITEM_DIR = f'{str(Path.home())}/graph_items' +SIMULATION_DATA_DIR = f'{str(Path.home())}/autosync_dataset_release' +CHECKPOINT_DIR = f'{str(Path.home())}' + + +resource_spec_file = f'{SIMULATION_DATA_DIR}/cluster1/bert12l_aws4_from_bert3l_aws4_2/resource_spec.yml' +original_graph_item_path = f'{GRAPH_ITEM_DIR}/bert_original_graph_item_large' +checkpoint_path = f'{CHECKPOINT_DIR}/bert_predefined_checkpoints/ckpV1_bert_orca_100_0.67000_0.50000' +strategy_dir = f'{SIMULATION_DATA_DIR}/cluster1/bert12l_aws4_from_bert3l_aws4_2/strategies' +strategy_files = glob.glob(f'{strategy_dir}/*') +strategy_file = strategy_files[0] + + +with context.graph_mode(): + + strategy = base.Strategy.deserialize(strategy_file) + + simulator = PredefinedSimulator(original_graph_item_path=original_graph_item_path) + + cost = simulator.simulate(strategy=strategy, resource_spec=ResourceSpec(resource_spec_file), checkpoint=checkpoint_path) + + print(f"strategy_file: {strategy_file}, cost: {cost}") + + +print('finished') diff --git a/autodist/simulator/train_predefined_similator_clean.py b/autodist/simulator/train_predefined_similator_clean.py new file mode 100644 index 0000000..3895c35 --- /dev/null 
+++ b/autodist/simulator/train_predefined_similator_clean.py @@ -0,0 +1,361 @@ +import sys +import os +import numpy as np +import tensorflow as tf +from os.path import expanduser +import tqdm +import os +os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152 +os.environ["CUDA_VISIBLE_DEVICES"]="1" +import argparse +import importlib +import glob +import json + +from tensorflow.python.eager import context +import tensorflow_ranking as tfr + +from autodist.strategy.base import Strategy +from models.predefined_simulator import PredefinedSimulator +from autodist.cluster import SSHCluster +from autodist.resource_spec import ResourceSpec +from autodist.kernel.device.resolver import DeviceResolver +np.random.seed(110) + + + +RankingLossKeys = { + # Names for the ranking based loss functions. + 'pairwise_hinge_loss': tfr.losses.RankingLossKey.PAIRWISE_HINGE_LOSS, + 'pairwise_logistic_loss': tfr.losses.RankingLossKey.PAIRWISE_LOGISTIC_LOSS, + 'pairwise_soft_zero_one_loss': tfr.losses.RankingLossKey.PAIRWISE_SOFT_ZERO_ONE_LOSS, + 'softmax_loss': tfr.losses.RankingLossKey.SOFTMAX_LOSS, + 'sigmoid_cross_entropy_loss': tfr.losses.RankingLossKey.SIGMOID_CROSS_ENTROPY_LOSS, + 'mean_squared_loss': tfr.losses.RankingLossKey.MEAN_SQUARED_LOSS, + 'list_mle_loss': tfr.losses.RankingLossKey.LIST_MLE_LOSS, + 'approx_ndcg_loss': tfr.losses.RankingLossKey.APPROX_NDCG_LOSS, +} + + +def split_dataset(inputs, shuffle=True, train_ratio=0.7, test_ratio=0.15): + assert isinstance(inputs, list) + nb_elements = len(inputs) + nb_samples = len(inputs[0]) + n_train = int(nb_samples * train_ratio) + n_test = int(nb_samples * test_ratio) + shuffled = [] + train = [] + valid = [] + test = [] + + if shuffle: + random_indices = np.random.permutation(list(range(nb_samples))) + for i in range(nb_elements): + shuffled_i = [inputs[i][j] for j in random_indices] + train.append(shuffled_i[:n_train]) + valid.append(shuffled_i[n_train:-n_test]) + test.append(shuffled_i[-n_test:]) + else: + for i in range(nb_elements): + train.append(inputs[i][:n_train]) + valid.append(inputs[i][n_train:-n_test]) + test.append(inputs[i][-n_test:]) + + return train, valid, test + + +class TFRIterator: + def __init__(self, X, Y, list_size, batch_size, split, baseline=0.0, scale=1.0): + assert len(X) > 0, 'data: {}'.format(len(X)) + self.X = X + self.Y = Y + self.list_size = list_size + self.baseline = baseline + self.scale = scale + self.batch_size = batch_size + self.split = split + self.n = len(X) + self.num_examples = self.get_num_examples() + print('Split: {},\tnumber of samples: {},\tnumber of examples: {},\tmin of y: {}'.format( + split, len(X), self.num_examples, self.get_min_y())) + + def get_min_y(self): + return np.min(self.Y) + + def get_num_examples(self): + n_examples = 1 + for i in range(self.list_size): + n_examples *= (len(self.X) -1) + return n_examples + + def get_next(self): + xs = [[] for _ in range(self.list_size)] + ys = [] + for i in range(self.batch_size): + y =[] + for j in range(self.list_size): + ri = np.random.randint(self.n) + rx = self.X[ri] + ry = self.Y[ri] + xs[j].append(np.array(rx, dtype=np.float32)) + y.append(ry) + assert ry * self.scale - self.baseline > 0, '{}, {}, {}'.format(ry, self.scale, self.baseline) + ys.append(y) + xs = [np.array(xx, dtype=np.float32) for xx in xs] + ys = np.array(ys, dtype=np.float32) + if self.split == 'train': # normalize y as its used for loss weights. 
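+            # baseline and scale are taken from the simulation config; the assert above
+            # guarantees the shifted labels stay strictly positive.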
+ ys = (ys * self.scale - self.baseline) + + return xs + [ys] + + +def load_from_folders_offline(simulation_dirs): + print('simulation_dirs', simulation_dirs) + X = [] + Y = [] + for simulation_dir in simulation_dirs: + x, y = load_from_one_folder_offline(simulation_dir) + if len(x) == 0: + print('Simulation folder does not have files: {}, skipping it.'.format(simulation_dir)) + continue + Y.append(y) + X.extend(x) + + Y = np.concatenate(Y, axis=0) + miny = np.min(Y) + assert len(X) == len(Y) + return X, Y + + +def load_from_one_folder_offline(simulation_dir): + runtime_files = glob.glob(os.path.join(simulation_dir, 'runtimes/*'), recursive=False) + resource_file = os.path.join(simulation_dir, 'resource_spec.yml') + + print("Searched runtime files: {}".format(len(runtime_files))) + X = [] + Y = [] + for runtime_file in runtime_files: + strategy_file = runtime_file.replace("runtimes/", 'strategies/') + if not os.path.exists(strategy_file) or not os.path.isfile(strategy_file): + print('strategy_file does not exist: {}.'.format(strategy_file)) + continue + X.append((strategy_file, resource_file)) + runtime = json.load(open(runtime_file, 'r')) + y = runtime['average'] + Y.append(y) + Y = np.array(Y, dtype=np.float) + print('Data points:{}, simulation_dir: {}'.format(len(X), simulation_dir)) + return X, Y + + +def main(args, sim_model_params): + + data_dir = sim_model_params['data_dir'] + original_graph_item_path = sim_model_params['original_graph_item_path'] + batch_size = sim_model_params['batch_size'] + ranking_loss_key = sim_model_params['ranking_loss_key'] + learning_rate = sim_model_params['learning_rate'] + list_size = sim_model_params['list_size'] + baseline = sim_model_params['baseline'] + scale = sim_model_params['scale'] + save_dir = sim_model_params['save_dir'] + save_prefix = sim_model_params['save_prefix'] + do_train = sim_model_params['do_train'] + do_test = sim_model_params['do_test'] + checkpoint = sim_model_params['checkpoint'] + model_version = sim_model_params['model_version'] + + # Create simulator + simulator = PredefinedSimulator(original_graph_item_path, batch_size=sim_model_params['model_batch_size'], + seq_len=sim_model_params['model_seq_len']) + + # Create features + strategy_resource_files, Y = load_from_folders_offline(data_dir) + print("Createing features...") + X = [] + prev_resource_file = None + with context.graph_mode(): + for strategy_file, resource_file in tqdm.tqdm(strategy_resource_files): + # For one folder with a common resource spec file, we load it only once to avoid costly computation. 
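+            # Rebuilding ResourceSpec, SSHCluster and DeviceResolver is comparatively costly,
+            # so they are reconstructed only when the resource_spec.yml path changes between
+            # consecutive samples.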
+ if prev_resource_file is None or resource_file != prev_resource_file: + prev_resource_file = resource_file + resource_spec = ResourceSpec(resource_file) + cluster = SSHCluster(resource_spec) + device_resolver = DeviceResolver(cluster) + # x = simulator.create_features(Strategy.deserialize(strategy_file), ResourceSpec(resource_file)) + x = simulator.create_features(Strategy.deserialize(strategy_file), resource_spec, cluster, device_resolver) + X.append(x) + X = np.array(X, dtype=np.float) + print("Finished createing features.") + + # Create model + W = tf.Variable(tf.random.uniform([args.hidden_dim, 1]), name='W', dtype=tf.float32) + b = tf.Variable(0.0, name='b', dtype=tf.float32) + if model_version == 'v2': + W0 = tf.Variable(tf.random.uniform([args.hidden_dim, args.hidden_dim]), name='W0', dtype=tf.float32) + b0 = tf.Variable(0.0, name='b0', dtype=tf.float32) + loss_fn = tfr.losses.make_loss_fn(RankingLossKeys[ranking_loss_key]) + major_version, _, _ = tf.version.VERSION.split('.') + if major_version == '1': + optimizer = tf.train.GradientDescentOptimizer(learning_rate) + else: + optimizer = tf.optimizers.Adam(learning_rate) + + def forward(xs): + rs = [] + for x in xs: + if model_version == 'v2': + x = tf.nn.elu(tf.matmul(x, W0) + b0) + r = tf.matmul(x, W) + b + rs.append(r) + r = tf.concat(rs, axis=1, name='logits') + return r + + @tf.function + def train_steps(inputs_iterator, total_steps, loss_fn): + + def train_step(input, loss_fn): + with tf.GradientTape() as tape: + logits = forward(input[:-1]) + loss = loss_fn(labels=input[-1], logits=logits, features={}) + vs = [W0, b0, W, b] if model_version == 'v2' else [W, b] + gradients = tape.gradient(loss, vs) + train_op = optimizer.apply_gradients(zip(gradients, vs)) + pred = tf.squeeze(tf.argmax(logits, axis=1)) + labels = tf.squeeze(tf.argmax(input[-1], axis=1)) + acc = tf.equal(pred, labels) + return loss, acc + + losses = [] + accs = [] + for step in range(total_steps): + l, a = train_step(inputs_iterator.get_next(), loss_fn) + losses.append(l) + accs.append(a) + return losses, accs + + @tf.function + def eval_step(input): + logits = forward(input[:-1]) + preds = tf.squeeze(tf.argmax(logits, axis=1)) + labels = tf.squeeze(tf.argmax(input[-1], axis=1)) + acc = tf.equal(preds, labels) + return acc, labels, preds, input[-1], logits + + def eval_steps(iterator, total_test_steps): + test_acc = [] + test_preds = [] + test_labels = [] + test_logits = [] + test_scores = [] + for step in range(total_test_steps): + acc, labels, preds, scores, logits = eval_step(iterator.get_next()) + test_acc.append(acc) + test_labels.append(labels) + test_preds.append(preds) + test_scores.append(scores) + test_logits.append(logits) + test_acc = tf.concat(test_acc, axis=0) + test_acc = tf.cast(test_acc, tf.float32) + avg_test_acc = tf.math.reduce_mean(test_acc) + test_labels = tf.concat(test_labels, axis=0) + test_preds = tf.concat(test_preds, axis=0) + test_scores = tf.concat(test_scores, axis=0) + test_logits = tf.concat(test_logits, axis=0) + return avg_test_acc, test_acc, test_labels, test_preds, test_scores, test_logits + + if do_train: + train_set, valid_set, test_set = split_dataset([X, Y], shuffle=True, train_ratio=0.7, test_ratio=0.15) + X_train, Y_train = train_set + X_valid, Y_valid = valid_set + X_test, Y_test = test_set + inputs_iterator = TFRIterator(X=X_train, Y=Y_train, list_size=list_size, batch_size=batch_size, split='train', + baseline=baseline, scale=scale) + valid_iterator = TFRIterator(X=X_valid, Y=Y_valid, list_size=list_size, 
batch_size=batch_size, split='valid') + test_iterator = TFRIterator(X=X_test, Y=Y_test, list_size=list_size, batch_size=batch_size, split='test') + total_train_steps = max(1, min(inputs_iterator.get_num_examples() // batch_size, 100)) + total_valid_steps = max(1, valid_iterator.get_num_examples() // batch_size) + total_test_steps = max(1, test_iterator.get_num_examples() // batch_size) + print("Total train steps per epoch: {}".format(total_train_steps)) + print("Total valid steps per epoch: {}".format(total_valid_steps)) + print("Total test steps: {}".format(total_test_steps)) + + + print("\nTrain model...") + losses = [] + for epoch in range(args.epochs): + loss, acc = train_steps(inputs_iterator, total_train_steps, loss_fn) + losses.extend(loss) + avgloss = sum(losses) / float(len(losses)) + print('Step: {}, avgloss: {:.5f}'.format(epoch, avgloss)) + if (epoch+1) % args.eval_every_epochs == 0: + print("\nEvaluate on valid set...") + avg_valid_acc, *_= eval_steps(valid_iterator, total_valid_steps) + print('avg_valid_acc: {}'.format(avg_valid_acc.numpy())) + print("Evaluate on test set...") + avg_test_acc, *_= eval_steps(test_iterator, total_test_steps) + print('avg_test_acc: {}\n'.format(avg_test_acc.numpy())) + print('W', W.numpy()) + print('b', b.numpy()) + + if (epoch+1) % args.save_every_epochs == 0: + if not os.path.exists(save_dir): + os.mkdir(save_dir) + checkpoint = '{}/{}_{}_{:.5f}_{:.5f}'.format(save_dir, save_prefix, epoch+1, + avg_valid_acc, avg_test_acc) + print("Save to {}".format(checkpoint)) + simulator.save_checkpoint([W0, b0, W, b] if model_version == 'v2' else [W, b], checkpoint) + + elif do_test: + print("Load from {}".format(checkpoint)) + weights = simulator.load_checkpoint(checkpoint) + if model_version == 'v2' and len(weights) == 4: + W0, b0, W, b = weights + elif model_version == 'v1' and len(weights) == 2: + W, b = weights + else: + raise ValueError + + test_iterator = TFRIterator(X=X, Y=Y, list_size=list_size, batch_size=batch_size, split='test') + total_test_steps = max(1, test_iterator.get_num_examples() // batch_size) + print("\nEvaluate on test set...") + avg_test_acc, test_acc, test_labels, test_preds, test_scores, test_logits = eval_steps(test_iterator, total_test_steps) + for i, labels, preds, scores, logits in zip(range(100), test_labels, test_preds, test_scores, test_logits): + print('labels', labels.numpy(), 'preds', preds.numpy(), 'scores', scores.numpy(), 'logits', logits.numpy()) + print('avg_test_acc', avg_test_acc.numpy()) + + test_iterator_single = TFRIterator(X=X, Y=Y, list_size=1, batch_size=len(X), split='test') + print("\nEvaluate each example in test set...") + avg_test_acc, test_acc, test_labels, test_preds, test_scores, test_logits = eval_steps(test_iterator_single, 1) + for i, labels, preds, scores, logits in zip(range(100), test_labels, test_preds, test_scores, test_logits): + print('labels', labels.numpy(), 'preds', preds.numpy(), 'scores', scores.numpy(), 'logits', logits.numpy()) + test_logits = sorted(list(test_logits.numpy())) + top_10_persent = test_logits[:int(len(test_logits)*0.1)] + print('top_10_persent', top_10_persent) + print('top_10_persent threshold', top_10_persent[-1]) + print('test_logits', test_logits) + + + + +def get_args_parser(): + parser = argparse.ArgumentParser(add_help=False) + parser.add_argument("-ms", "--model_to_sim", default='bert', type=str, help="") + parser.add_argument("-sc", "--simulation_config", default='config', type=str, help="") + parser.add_argument("-hd", "--hidden_dim", default=12, type=int, 
help="") + parser.add_argument("-es", "--epochs", default=100, type=int, help="") + parser.add_argument("-ee", "--eval_every_epochs", default=10, type=int, help="") + parser.add_argument("-se", "--save_every_epochs", default=100, type=int, help="") + return parser + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("Predefined simularot training script", parents=[get_args_parser()]) + args = parser.parse_args() + + module = importlib.import_module(args.simulation_config) # import module from str + simulation_params = getattr(module, "simulation_params") + + main(args, simulation_params[args.model_to_sim]) + + diff --git a/autodist/simulator/utils.py b/autodist/simulator/utils.py new file mode 100644 index 0000000..c476e6e --- /dev/null +++ b/autodist/simulator/utils.py @@ -0,0 +1,135 @@ +import tensorflow as tf +from tensorflow.python.framework import device_spec + +from autodist.kernel.device.resolver import DeviceResolver + + +DTYPE2BITS = { + tf.float16: 16, + "tf.float16": 16, + "": 16, + tf.float32: 32, + 'tf.float32': 32, + "": 32, + "": 32, + tf.float64: 64, + 'tf.float64': 64, + "": 64, + tf.bfloat16: 16, + 'tf.bfloat16': 16, + "": 16, + tf.complex64: 64, + 'tf.complex64': 64, + "": 64, + tf.complex128: 128, + 'tf.complex128': 128, + "": 128, + tf.int8: 8, + 'tf.int8': 8, + "": 8, + tf.uint8: 8, + 'tf.uint8': 8, + "": 8, + tf.uint16: 16, + 'tf.uint16': 16, + "": 16, + tf.uint32: 32, + 'tf.uint32': 32, + "": 32, + tf.uint64: 64, + 'tf.uint64': 64, + "": 64, + tf.int16: 16, + 'tf.int16': 16, + "": 16, + tf.int32: 32, + 'tf.int32': 32, + "": 32, + tf.int64: 64, + 'tf.int64': 64, + "": 64, + tf.bool: 1, + 'tf.bool': 1, + "": 1, + tf.string: 1, # todo: confirm + 'tf.string': 1, # todo: confirm + "": 1, # todo: confirm + tf.qint8: 8, + 'tf.qint8': 8, + "": 8, + tf.quint8: 8, + 'tf.quint8': 8, + "": 8, + tf.qint16: 16, + 'tf.qint16': 16, + "": 16, + tf.quint16: 16, + 'tf.quint16': 16, + "": 16, + tf.qint32: 32, + 'tf.qint32': 32, + "": 32, + tf.resource: 0, # its tensor shape is either [] or [None] todo: confirm + 'tf.resource': 0, # its tensor shape is either [] or [None] todo: confirm + "": 0, # its tensor shape is either [] or [None] todo: confirm +} + +def get_dtype_bits(dtype): + return DTYPE2BITS[dtype] if dtype in DTYPE2BITS else DTYPE2BITS[str(dtype)] + + +def get_dense_var_bits(size, dtype): + return size * get_dtype_bits(dtype) + + + +def get_sparse_var_bits(size): + # same size of values, indices, dense_shape + return size * (get_dtype_bits(tf.float32) + 2 * get_dtype_bits(tf.int64)) \ + + 2 * get_dtype_bits(tf.int64) + + + +def resolve_device_address(device: str, device_resolver: DeviceResolver): + # change real ip address to /job:worker/task:0 + if not device: + return device + parts = device.split(':') + if parts and parts[0] in device_resolver._address_to_tasks: + resolved_device = device_resolver._address_to_tasks[parts[0]][0] + resolved = '/job:{}/task:{}/device:'.format(resolved_device['job'], resolved_device['task']) + resolved = resolved + ':'.join(parts[-2:]) + return resolved + else: + raise ValueError("cannot resolve device: {} using device_resolver: {}".format( + device, device_resolver._address_to_tasks)) + + +def resolved_devices_on_diff_machine(device1, device2): + # e.g., '/job:worker/task:1/device:CPU:0', '/job:worker/task:1/GPU:0' + node1 = ':'.join(device1.split('/')[:-1]) + node2 = ':'.join(device2.split('/')[:-1]) + return node1 != node2 + + +def get_max_num_local_replica(replicas, cluster): + replica_devices = 
{device_spec.DeviceSpecV2.from_string(r) for r in replicas}
+    replica_hosts = {cluster.get_address_from_task(d.job, d.task) for d in replica_devices}
+    max_num_local_replica = 0
+    for host in replica_hosts:
+        num_local_replica = sum(1 for d in replica_devices
+                                if cluster.get_address_from_task(d.job, d.task) == host)
+        max_num_local_replica = max(max_num_local_replica, num_local_replica)
+    return max_num_local_replica
+
+
+def get_num_local_replica(host, replicas, cluster):
+    # host: e.g., '/job:worker/task:0/device:CPU:0'
+    replica_devices = {device_spec.DeviceSpecV2.from_string(r) for r in replicas}
+    host_device = device_spec.DeviceSpecV2.from_string(host)
+    num_local_replica = sum(1 for d in replica_devices
+                            if cluster.get_address_from_task(d.job, d.task) ==
+                            cluster.get_address_from_task(host_device.job, host_device.task))
+    return num_local_replica
+

From 2f0ad2b80f09fa94438288c1835930aa8e82d067 Mon Sep 17 00:00:00 2001
From: Christy Li
Date: Tue, 15 Dec 2020 13:11:09 -0500
Subject: [PATCH 2/4] update

---
 autodist/simulator/models/base.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/autodist/simulator/models/base.py b/autodist/simulator/models/base.py
index ecd1621..9ca67ef 100644
--- a/autodist/simulator/models/base.py
+++ b/autodist/simulator/models/base.py
@@ -395,9 +395,9 @@ def network_bandwidth(resource_spec: ResourceSpec, device_resolver: DeviceResolv
             ip_j = devices[j].split(':')[0]
             if ip_i != ip_j:
                 network_bandwidth[resolved_devices[i]][resolved_devices[j]] \
-                    = GIGABITS * 20  # resource_spec.network_bandwidth[ip_i] # todo: solve.
+                    = GIGABITS * resource_spec.network_bandwidth[ip_i]  # todo: solve.
                 network_bandwidth[resolved_devices[j]][resolved_devices[i]] \
-                    = GIGABITS * 20  # resource_spec.network_bandwidth[ip_j]
+                    = GIGABITS * resource_spec.network_bandwidth[ip_j]
             else:
                 network_bandwidth[resolved_devices[i]][resolved_devices[j]] = GIGABITS * gpu_cpu_bw
                 network_bandwidth[resolved_devices[j]][resolved_devices[i]] = GIGABITS * gpu_cpu_bw

From 0e70c5b9557640a48df4b5c3ec624b74c3f0b132 Mon Sep 17 00:00:00 2001
From: Christy Li
Date: Tue, 15 Dec 2020 14:49:39 -0500
Subject: [PATCH 3/4] add readme

---
 autodist/simulator/README.md                  | 50 +++++++++++++++++++
 .../simulator/models/predefined_simulator.py  |  1 -
 autodist/simulator/reorganize_data.py         | 18 -------
 3 files changed, 50 insertions(+), 19 deletions(-)
 create mode 100644 autodist/simulator/README.md
 delete mode 100644 autodist/simulator/reorganize_data.py

diff --git a/autodist/simulator/README.md b/autodist/simulator/README.md
new file mode 100644
index 0000000..f600e9b
--- /dev/null
+++ b/autodist/simulator/README.md
@@ -0,0 +1,50 @@
+
+The ``simulator`` folder implements the predefined simulator of AutoSync, proposed in [AutoSync: Learning to Synchronize for Data-Parallel Distributed Deep Learning](https://papers.nips.cc/paper/2020/hash/0a2298a72858d90d5c4b4fee954b6896-Abstract.html).
+
+## Download Data
+Download the data from https://drive.google.com/file/d/1CTtIVORxzF_wOmxrsusbAhNC3bwmxuD8/view?usp=sharing.
+
+The data folder is organized by ML model category. For each ML model, simulation data was collected on two kinds of clusters (AWS and an in-house cluster). Each data sample is a pair of a synchronization strategy and its measured runtime; a single resource specification file applies to all runtimes and strategies stored in the corresponding ``runtime`` and ``strategies`` folders.
+The detailed data organization is:
+
+    Model-1/                  (e.g., BERT-large)
+        Cluster-1/            (e.g., AWS-4-g4)
+            resource_spec.yml
+            runtime/
+                <id>.yml
+            strategies/
+
+        Cluster-2/            (e.g., In-house-11-nodes)
+            resource_spec.yml
+            runtime/
+                <id>.yml
+            strategies/
+
+    Model-2/
+        ......
+
+    Model-3/
+        ......
+
+
+## Train a predefined simulator
+
+Inside the ``autodist/simulator`` folder:
+
+Define the configuration in ``config.py``, including the model to simulate and the data folders (samples) to use.
+
+Run: ``python <absolute_dir>/train_predefined_similator_clean.py``
+
+
+## Simulate (infer) a strategy
+
+Inside the ``autodist/simulator`` folder:
+
+Define the strategy to simulate and the checkpoint to load in ``simulate.py``.
+
+Run: ``python simulate.py``.
+
+
+## Read a strategy
+
+Use ``strategy = base.Strategy.deserialize(strategy_file)`` to read a strategy stored as ``strategy_file``.
+
diff --git a/autodist/simulator/models/predefined_simulator.py b/autodist/simulator/models/predefined_simulator.py
index 1d29146..5afca9d 100644
--- a/autodist/simulator/models/predefined_simulator.py
+++ b/autodist/simulator/models/predefined_simulator.py
@@ -124,7 +124,6 @@ def create_features_v0(self, strategy: Strategy, resource_spec: ResourceSpec, cl
 
     def create_features(self, strategy: Strategy, resource_spec: ResourceSpec, cluster: SSHCluster,
                         device_resolver: DeviceResolver):
-        # var_sync_time, vars, resource = self.predefined_sync_time(strategy, resource_spec)
 
         vars, resource = self.extract_pre_feature(strategy=strategy, resource_spec=resource_spec,
                                                   cluster=cluster, device_resolver=device_resolver)
diff --git a/autodist/simulator/reorganize_data.py b/autodist/simulator/reorganize_data.py
deleted file mode 100644
index 7f8d8f1..0000000
--- a/autodist/simulator/reorganize_data.py
+++ /dev/null
@@ -1,18 +0,0 @@
-import glob
-import os
-import shutil
-
-data_dir = '/home/christy.li/autosync_dataset_release'
-for cluster in ['cluster1', 'cluster2']:
-    folders = glob.glob(os.path.join(data_dir, cluster, '*'))
-    print('\n{} folders in {}'.format(len(folders), cluster))
-    for folder in folders:
-        print('\nfolder', folder)
-        resource_folder = os.path.join(folder, 'resource_specs')
-        if os.path.exists(resource_folder):
-            resource_files = glob.glob(os.path.join(resource_folder, '*'))
-            shutil.copy(resource_files[0], folder)
-            os.rename('/'.join(resource_files[0].split('/')[:-2]+[resource_files[0].split('/')[-1]]),
-                      os.path.join(folder, 'resource_spec.yml'))
-            shutil.rmtree(resource_folder)
-

From 3c5e7f17943c84d5cf4236cb9a4b422d4e87d8ef Mon Sep 17 00:00:00 2001
From: Christy Li
Date: Tue, 15 Dec 2020 14:59:49 -0500
Subject: [PATCH 4/4] update readme

---
 autodist/simulator/README.md | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/autodist/simulator/README.md b/autodist/simulator/README.md
index f600e9b..ec22291 100644
--- a/autodist/simulator/README.md
+++ b/autodist/simulator/README.md
@@ -46,5 +46,11 @@ Run: ``python simulate.py``.
 
 ## Read a strategy
 
-Use ``strategy = base.Strategy.deserialize(strategy_file)`` to read a strategy stored as ``strategy_file``.
+Use
+
+``from autodist.strategy import base``
+
+``strategy = base.Strategy.deserialize(strategy_file)``
+
+to read a strategy from ``strategy_file``.
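To score a strategy programmatically, follow the same flow as ``simulate.py``: deserialize the strategy, wrap the cluster description in a ``ResourceSpec``, and pass both together with a trained checkpoint to ``PredefinedSimulator.simulate``. The sketch below only mirrors ``simulate.py``; every path in it is a placeholder for your own graph item, simulator checkpoint, strategy file, and resource spec.

```python
# Minimal usage sketch mirroring autodist/simulator/simulate.py.
# All paths below are placeholders; point them at your own files.
from tensorflow.python.eager import context

from autodist.resource_spec import ResourceSpec
from autodist.strategy import base
from models.predefined_simulator import PredefinedSimulator  # run from inside autodist/simulator

original_graph_item_path = '/path/to/original_graph_item'   # serialized GraphItem of the model
checkpoint_path = '/path/to/trained_simulator_checkpoint'   # weights saved by the training script
strategy_file = '/path/to/strategies/<strategy_id>'         # a serialized Strategy
resource_spec_file = '/path/to/resource_spec.yml'           # cluster description

with context.graph_mode():
    # Deserialize the strategy and build the simulator around the original graph item.
    strategy = base.Strategy.deserialize(strategy_file)
    simulator = PredefinedSimulator(original_graph_item_path=original_graph_item_path)

    # Feed the strategy, resource spec, and trained cost-model checkpoint to get a cost estimate.
    cost = simulator.simulate(strategy=strategy,
                              resource_spec=ResourceSpec(resource_spec_file),
                              checkpoint=checkpoint_path)
    print(f'strategy: {strategy_file}, estimated cost: {cost}')
```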