Skip to content
Empty file added autodist/autosync/__init__.py
Empty file.
Empty file.
143 changes: 143 additions & 0 deletions autodist/autosync/simulator/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
# Copyright 2020 Petuum. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Simulator base class."""
from collections import OrderedDict

import os

from autodist.graph_item import GraphItem
from autodist.kernel.partitioner import PartitionerConfig
from autodist.resource_spec import ResourceSpec
from autodist.strategy.auto.item import VariableItem, PartItem, ResourceItem


class SimulatorBase:
    """Common base of all simulators.

    A simulator keeps a default graph item and a default resource spec, and
    exposes hooks (``simulate``, ``inference``, ``load_checkpoint``) that
    concrete simulators are expected to override.
    """

    def __init__(self,
                 graph_item=None,
                 resource_spec=None):
        """
        Bind the simulator to a default graph item and resource spec.

        Args:
            graph_item: a GraphItem object, or a path to a serialized GraphItem object.
            resource_spec: a ResourceSpec object, or a path to a resource file.

        Raises:
            ValueError: if an argument is neither an instance of the expected
                class nor a string naming an existing path (this includes the
                default ``None``).
        """
        self._graph_item = None
        if isinstance(graph_item, GraphItem):
            self._graph_item = graph_item
        else:
            # Anything that is not already a GraphItem must be an existing path.
            graph_is_path = isinstance(graph_item, str) and os.path.exists(graph_item)
            if not graph_is_path:
                raise ValueError("Invalid graph_item: {}".format(graph_item))
            self._graph_item = GraphItem.deserialize(graph_item)

        self._resource_spec = None
        if isinstance(resource_spec, ResourceSpec):
            self._resource_spec = resource_spec
        else:
            # Same rule as above: fall back to treating the value as a file path.
            spec_is_path = isinstance(resource_spec, str) and os.path.exists(resource_spec)
            if not spec_is_path:
                raise ValueError("Invalid resource_spec: {}".format(resource_spec))
            self._resource_spec = ResourceSpec(resource_spec)

    def update_graph_item(self, graph_item):
        """Replace the default graph_item bundled with this simulator.

        Raises:
            ValueError: if ``graph_item`` is falsy.
        """
        if graph_item:
            self._graph_item = graph_item
            return
        raise ValueError('Empty graph item.')

    def update_resource_spec(self, resource_spec):
        """Replace the default resource_spec bundled with this simulator.

        Raises:
            ValueError: if ``resource_spec`` is falsy.
        """
        if resource_spec:
            self._resource_spec = resource_spec
            return
        raise ValueError('Empty resource spec.')

    def simulate(self,
                 strategy,
                 graph_item=None,
                 resource_spec=None,
                 *args,
                 **kwargs):
        """Return simulated runtime cost given (strategy, graph_item, resource_spec) tuple.

        Abstract hook: the base implementation always raises NotImplementedError.
        """
        raise NotImplementedError()

    def inference(self, *args, **kwargs):
        """Run the simulator's model; abstract hook, always raises NotImplementedError here."""
        raise NotImplementedError()

    def load_checkpoint(self, checkpoint):
        """
        Load a checkpoint file as weights of the simulator.

        Abstract hook: the base implementation always raises NotImplementedError.

        Args:
            checkpoint: path to a checkpoint file.
        """
        raise NotImplementedError()

    # def save_checkpoint(self, model, checkpoint):
    #     """
    #     Save a trained weight as a checkpoint file.
    #
    #     Args:
    #         model: trained model.
    #         checkpoint: path where to save the checkpoint.
    #     """
    #     raise NotImplementedError()

    def preprocess(self,
                   strategy,
                   graph_item=None,
                   resource_spec=None):
        """
        Turn a (strategy, graph_item, resource_spec) tuple into pre-features.

        Args:
            strategy: a distribution strategy
            graph_item: optional graph_item, if not provided, the default one bundled with simulator will be used.
            resource_spec: optional resource_spec, if not provided, the default one bundled with simulator will be used.

        Returns:
            OrderedDict: variable/part name to variable/part items, in the order
                the nodes appear in ``strategy.node_config``.
            ResourceItem: the item built from the effective resource spec.

        Raises:
            ValueError: if no graph item, resource spec or strategy is available.
        """
        # Per-call arguments win; otherwise fall back to the bundled defaults.
        graph_item = graph_item or self._graph_item
        if not graph_item:
            raise ValueError('No graph item provided.')
        resource_spec = resource_spec or self._resource_spec
        if not resource_spec:
            raise ValueError('No resource spec provided.')
        if not strategy:
            raise ValueError('No strategy provided.')

        resource_item = ResourceItem(resource_spec)
        vars_by_name = {var.name: var for var in graph_item.trainable_var_op_to_var.values()}

        items_by_name = OrderedDict()
        for node in strategy.node_config:
            variable = vars_by_name[node.var_name]
            if not node.partitioner:
                # Unpartitioned variable: a single item describes it.
                item = VariableItem(variable, graph_item, node)
                items_by_name[item.name] = item
                continue
            # Partitioned variable: one item per entry of the partition config.
            partitioner_config = PartitionerConfig(partition_str=node.partitioner)
            for index, part in enumerate(node.part_config):
                item = PartItem(variable, graph_item, index, partitioner_config, part)
                items_by_name[item.name] = item
        return items_by_name, resource_item
193 changes: 193 additions & 0 deletions autodist/autosync/simulator/linear_simulator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
# Copyright 2020 Petuum Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Predefined simulator with linear model."""
import os
import pickle as pkl

import tensorflow as tf
import numpy as np

from autodist.autosync.simulator.predefined_simulator import PredefinedSimulator
from autodist.proto.synchronizers_pb2 import PSSynchronizer, AllReduceSynchronizer
from autodist.utils import logging


class LinearSimulator(PredefinedSimulator):
    """Simulates strategies for a given graph and resource spec.

    The cost is predicted by a linear model ``W . x + b`` applied to a
    12-dimensional feature vector built by ``_extract_feature``.
    """

    def __init__(self,
                 graph_item=None,
                 resource_spec=None,
                 batch_size=1,
                 seq_len=1,
                 checkpoint=None):
        """
        Build a linear simulator.

        Args:
            graph_item: a GraphItem object, or a path to a serialized GraphItem object.
            resource_spec: a ResourceSpec object, or a path to a resource file.
            batch_size: stored as the per-GPU batch size.
            seq_len: stored as the sequence length.
            checkpoint: optional weights to load eagerly; anything accepted by
                ``load_checkpoint``. An invalid checkpoint only logs a warning.
        """
        # NOTE(review): passing PredefinedSimulator as the first argument makes
        # super() start *after* it in the MRO, so PredefinedSimulator.__init__
        # is skipped. Left unchanged because that constructor is not visible
        # from here -- confirm this is intentional.
        super(PredefinedSimulator, self).__init__(graph_item, resource_spec)
        logging.debug('A LinearSimulator is instantiated: batch_size_per_gpu is {}'.format(batch_size))

        self._batch_size_per_gpu = batch_size
        self._seq_len = seq_len

        # Weights of the linear model. The attribute is always defined so that
        # simulate() can fall through the lookup chain without AttributeError.
        self._checkpoint = checkpoint
        self._weights = None
        if self._checkpoint:
            try:
                self._weights = self.load_checkpoint(checkpoint)
            except ValueError:
                logging.warning('self._checkpoint is invalid')

        # TODO(Hao): add the default weights here.
        self._default_weights = None

    def simulate(self,
                 strategy,
                 graph_item=None,
                 resource_spec=None,
                 checkpoint=None,
                 *args,
                 **kwargs):
        """Return simulated runtime cost given (strategy, graph_item, resource_spec) tuple.

        Args:
            strategy: the strategy to simulate.
            graph_item: the graph_item this strategy is generated on.
            resource_spec: the resource_spec this strategy is on.
            checkpoint: the checkpoint to perform inference (in place of the default weight).

        Returns:
            float: the estimated cost (lower is better).

        Raises:
            ValueError: if the strategy, graph item, resource spec or model
                weights are missing, or the checkpoint cannot be loaded.
        """
        if not strategy:
            raise ValueError('strategy is None.')
        if not graph_item:
            if not self._graph_item:
                raise ValueError('No graph item provided.')
            graph_item = self._graph_item
        if not resource_spec:
            if not self._resource_spec:
                raise ValueError('No resource spec provided.')
            resource_spec = self._resource_spec

        x = self._extract_feature(strategy, graph_item, resource_spec)

        # Weight lookup priority:
        # simulate(checkpoint) > self._weights > self._default_weights
        if checkpoint:
            weights = self.load_checkpoint(checkpoint)
        elif self._weights is not None:
            weights = self._weights
        else:
            weights = self._default_weights
        if weights is None:
            raise ValueError('No weights available: pass a checkpoint to simulate() or to the constructor.')

        return self.inference(np.array(x), weights)

    def inference(self, x, weights):
        """
        Score a feature vector with the linear model.

        Args:
            x: features extracted from a (strategy, graph_item, resource_spec).
            weights: trained linear model weights, a pair ``(W, b)``.

        Returns:
            float: ranking score ``W . x + b`` (an array is returned only when
                ``b`` holds more than one element).

        Raises:
            ValueError: if ``weights`` is not a pair, or ``W`` and ``x`` differ in size.
        """
        if len(weights) != 2:
            raise ValueError('weights must be a pair (W, b), got {} element(s)'.format(len(weights)))
        coefficients, bias = weights
        coefficients = np.ravel(np.asarray(coefficients, dtype=float))
        features = np.ravel(np.asarray(x, dtype=float))
        if coefficients.size != features.size:
            raise ValueError('W has {} element(s) but the feature vector has {}'.format(
                coefficients.size, features.size))
        # A linear model is an inner product; an elementwise `*` would return a
        # per-feature vector instead of a single score.
        cost = np.asarray(np.dot(coefficients, features) + np.asarray(bias, dtype=float))
        return cost.item() if cost.size == 1 else cost

    def load_checkpoint(self, checkpoint):
        """
        Load a trained weight from a checkpoint.

        Args:
            checkpoint: the file path to a npz archive with entries ``W`` and ``b``,
                or a list/tuple that is either ``[W, b]`` or 13 flat numbers
                (12 coefficients followed by the bias).

        Returns:
            tuple: loaded weights ``(W, b)``.

        Raises:
            ValueError: if the checkpoint has an unsupported type or length, or
                the path does not name a file.
        """
        logging.info('Loading checkpoint: {}'.format(checkpoint))
        if isinstance(checkpoint, (list, tuple)):
            if len(checkpoint) == 2:
                return checkpoint[0], checkpoint[1]
            if len(checkpoint) == 13:
                # 12 feature coefficients (see _extract_feature) plus one bias term.
                return checkpoint[:12], checkpoint[12]
            raise ValueError('Unable to load the checkpoint: expected 2 or 13 elements, got {}'.format(
                len(checkpoint)))
        if isinstance(checkpoint, str) and os.path.isfile(checkpoint):
            # The archive returned for .npz files holds an open file handle;
            # the context manager closes it once both arrays are read.
            with np.load(checkpoint) as archive:
                return archive['W'], archive['b']
        raise ValueError('Unable to load the checkpoint: {}'.format(checkpoint))

    def _extract_feature(self,
                         strategy,
                         graph_item,
                         resource_spec):
        """Get the feature vector as input to the linear model.

        Returns:
            list: 12 floats -- for each of (max over PS servers, sum over PS
                servers, max over all-reduce groups, sum over all-reduce groups)
                the three values keyed by ``feature_keys``, in that order.

        Raises:
            ValueError: if a variable uses an unrecognized synchronizer type.
        """
        var_name_to_items, resource_item, var_name_to_sync_time = \
            self.extract_prefeature(strategy, graph_item, resource_spec)

        feature_keys = ['transmission', 'network_overhead', 'gpu_kernel_memory_latency']
        ps_server_sync_time = {}
        cc_group_sync_time = {}

        for var_name, var_item in var_name_to_items.items():
            sync_time = var_name_to_sync_time[var_name]

            # Extract per-server and per-group sync time.
            if isinstance(var_item.synchronizer, PSSynchronizer):
                server = var_item.device
                if server not in ps_server_sync_time:
                    ps_server_sync_time[server] = {key: 0.0 for key in feature_keys}
                # For PS variables the sync time is indexed as a pair whose two
                # entries are added together (presumably the two transfer
                # directions -- TODO confirm against extract_prefeature).
                for key in feature_keys:
                    ps_server_sync_time[server][key] += sync_time[0][key] + sync_time[1][key]
            elif isinstance(var_item.synchronizer, AllReduceSynchronizer):
                group = var_item.group
                if group not in cc_group_sync_time:
                    cc_group_sync_time[group] = {key: 0.0 for key in feature_keys}
                for key in feature_keys:
                    cc_group_sync_time[group][key] += sync_time[key]
            else:
                raise ValueError('Unrecognized type of synchronizer: {}'.format(type(var_item.synchronizer)))

        # Different from predefined modeling, we transform these into feature vectors in this simulator.
        # We care about the sum time of all servers/groups, or the slowest (max) server/group.
        # Each aggregate is 0.0 when there is no server/group of that kind.
        x = []
        for per_target in (ps_server_sync_time, cc_group_sync_time):
            times = list(per_target.values())
            x.extend(max((t[key] for t in times), default=0.0) for key in feature_keys)
            x.extend(sum((t[key] for t in times), 0.0) for key in feature_keys)
        return x
Loading