Add an SSH port to the partition-copy helpers in tools/copy_partitions.py.

copy_file() and exec_cmd() gain a `port` argument that is passed to ssh via
`-p`, and main() now reads four space-separated fields from each ip_config
line, keeping an [ip, port] pair per host instead of the bare ip.

NOTE(review): the line breaks and indentation below were reconstructed from a
whitespace-collapsed copy of this patch. Blank context lines were placed so
that every hunk matches the line counts in its @@ header. Two details could
not be derived from the patch text and should be verified against the target
file before applying: the blank line before `part_files = ...`, and the
indentation of the `json.dump(...)` context line.

diff --git a/tools/copy_partitions.py b/tools/copy_partitions.py
index 7a40a5791e2d..e1df005eb009 100644
--- a/tools/copy_partitions.py
+++ b/tools/copy_partitions.py
@@ -9,13 +9,13 @@
 import json
 import copy
 
-def copy_file(file_name, ip, workspace):
+def copy_file(file_name, ip, port, workspace):
     print('copy {} to {}'.format(file_name, ip + ':' + workspace + '/'))
-    cmd = 'rsync -e \"ssh -o StrictHostKeyChecking=no\" -arvc ' + file_name + ' ' + ip + ':' + workspace + '/'
+    cmd = 'rsync -e \"ssh -o StrictHostKeyChecking=no -p ' + port + ' \" -arvc ' + file_name + ' ' + ip + ':' + workspace + '/'
     subprocess.check_call(cmd, shell = True)
 
-def exec_cmd(ip, cmd):
-    cmd = 'ssh -o StrictHostKeyChecking=no ' + ip + ' \'' + cmd + '\''
+def exec_cmd(ip, port, cmd):
+    cmd = 'ssh -o StrictHostKeyChecking=no ' + ip + ' -p ' + port + ' \'' + cmd + '\''
     subprocess.check_call(cmd, shell = True)
 
 def main():
@@ -36,8 +36,8 @@ def main():
     hosts = []
     with open(args.ip_config) as f:
         for line in f:
-            ip, _, _ = line.strip().split(' ')
-            hosts.append(ip)
+            ip, port, _, _ = line.strip().split(' ')
+            hosts.append([ip, port])
 
 
     # We need to update the partition config file so that the paths are relative to
@@ -70,25 +70,25 @@ def main():
         json.dump(tmp_part_metadata, outfile, sort_keys=True, indent=4)
 
     # Copy ip config.
-    for part_id, ip in enumerate(hosts):
+    for part_id, (ip, port) in enumerate(hosts):
         remote_path = '{}/{}'.format(args.workspace, args.rel_data_path)
-        exec_cmd(ip, 'mkdir -p {}'.format(remote_path))
+        exec_cmd(ip, port, 'mkdir -p {}'.format(remote_path))
 
-        copy_file(args.ip_config, ip, args.workspace)
-        copy_file(tmp_part_config, ip, '{}/{}'.format(args.workspace, args.rel_data_path))
+        copy_file(args.ip_config, ip, port, args.workspace)
+        copy_file(tmp_part_config, ip, port, '{}/{}'.format(args.workspace, args.rel_data_path))
         node_map = part_metadata['node_map']
         edge_map = part_metadata['edge_map']
         if not isinstance(node_map, list):
-            copy_file(node_map, ip, tmp_part_metadata['node_map'])
+            copy_file(node_map, ip, port, tmp_part_metadata['node_map'])
         if not isinstance(edge_map, list):
-            copy_file(edge_map, ip, tmp_part_metadata['edge_map'])
+            copy_file(edge_map, ip, port, tmp_part_metadata['edge_map'])
         remote_path = '{}/{}/part{}'.format(args.workspace, args.rel_data_path, part_id)
-        exec_cmd(ip, 'mkdir -p {}'.format(remote_path))
+        exec_cmd(ip, port, 'mkdir -p {}'.format(remote_path))
 
         part_files = part_metadata['part-{}'.format(part_id)]
-        copy_file(part_files['node_feats'], ip, remote_path)
-        copy_file(part_files['edge_feats'], ip, remote_path)
-        copy_file(part_files['part_graph'], ip, remote_path)
+        copy_file(part_files['node_feats'], ip, port, remote_path)
+        copy_file(part_files['edge_feats'], ip, port, remote_path)
+        copy_file(part_files['part_graph'], ip, port, remote_path)
 
 
 def signal_handler(signal, frame):