From f200b65ac579d6b7a8448e3b1470fd9b9532cb35 Mon Sep 17 00:00:00 2001 From: OHTAKE Tomohiro Date: Thu, 17 Jan 2019 10:46:56 +1100 Subject: [PATCH] Fix model loading on multi-GPU environment --- lib/utils/model_loader.py | 55 ++++++++++++++++++++------------------- tools/extract_features.py | 2 +- tools/test_net.py | 2 +- tools/train_net.py | 2 +- 4 files changed, 31 insertions(+), 30 deletions(-) diff --git a/lib/utils/model_loader.py b/lib/utils/model_loader.py index 5ec6f67..b6329d5 100644 --- a/lib/utils/model_loader.py +++ b/lib/utils/model_loader.py @@ -70,7 +70,7 @@ def LoadModelFromPickleFile( model, pkl_file, use_gpu=True, - root_gpu_id=0, + gpu_ids=[0], bgr2rgb=False, ): @@ -89,30 +89,31 @@ def LoadModelFromPickleFile( else: device_opt = caffe2_pb2.CPU - with core.NameScope('gpu_{}'.format(root_gpu_id)): - with core.DeviceScope(core.DeviceOption(device_opt, root_gpu_id)): - for unscoped_blob_name in unscoped_blob_names.keys(): - scoped_blob_name = scoped_name(unscoped_blob_name) - if unscoped_blob_name not in blobs: - log.info('{} not found'.format(unscoped_blob_name)) - continue - if scoped_blob_name in ws_blobs: - ws_blob = workspace.FetchBlob(scoped_blob_name) - target_shape = ws_blob.shape - if target_shape == blobs[unscoped_blob_name].shape: - log.info('copying {} to {}'.format( - unscoped_blob_name, scoped_blob_name)) - if bgr2rgb and unscoped_blob_name == 'conv1_w': - feeding_blob = FlipBGR2RGB( - blobs[unscoped_blob_name] - ) + for gpu_id in gpu_ids: + with core.NameScope('gpu_{}'.format(gpu_id)): + with core.DeviceScope(core.DeviceOption(device_opt, gpu_id)): + for unscoped_blob_name in unscoped_blob_names.keys(): + scoped_blob_name = scoped_name(unscoped_blob_name) + if unscoped_blob_name not in blobs: + log.info('{} not found'.format(unscoped_blob_name)) + continue + if scoped_blob_name in ws_blobs: + ws_blob = workspace.FetchBlob(scoped_blob_name) + target_shape = ws_blob.shape + if target_shape == blobs[unscoped_blob_name].shape: + log.info('copying {} to {}'.format( + unscoped_blob_name, scoped_blob_name)) + if bgr2rgb and unscoped_blob_name == 'conv1_w': + feeding_blob = FlipBGR2RGB( + blobs[unscoped_blob_name] + ) + else: + feeding_blob = blobs[unscoped_blob_name] + else: - feeding_blob = blobs[unscoped_blob_name] - - else: - log.info('found {} but blob shape do not match'.format( - unscoped_blob_name)) - workspace.FeedBlob( - scoped_blob_name, - feeding_blob.astype(np.float32, copy=False) - ) + log.info('found {} but blob shape do not match'.format( + unscoped_blob_name)) + workspace.FeedBlob( + scoped_blob_name, + feeding_blob.astype(np.float32, copy=False) + ) diff --git a/tools/extract_features.py b/tools/extract_features.py index 594cc8a..0c3b549 100644 --- a/tools/extract_features.py +++ b/tools/extract_features.py @@ -150,7 +150,7 @@ def create_model_ops(model, loss_scale): model, args.load_model_path, use_gpu=True, - root_gpu_id=gpus[0] + gpu_ids=gpus ) else: model_loader.LoadModelFromPickleFile( diff --git a/tools/test_net.py b/tools/test_net.py index 63c7a60..9d54073 100644 --- a/tools/test_net.py +++ b/tools/test_net.py @@ -163,7 +163,7 @@ def test_input_fn(model): test_model, args.load_model_path, use_gpu=True, - root_gpu_id=gpus[0] + gpu_ids=gpus ) data_parallel_model.FinalizeAfterCheckpoint(test_model) else: diff --git a/tools/train_net.py b/tools/train_net.py index f93e005..17f7f8b 100644 --- a/tools/train_net.py +++ b/tools/train_net.py @@ -348,7 +348,7 @@ def test_input_fn(model): model_loader.LoadModelFromPickleFile( train_model, args.pretrained_model, - root_gpu_id=gpus[0] + gpu_ids=gpus ) data_parallel_model.FinalizeAfterCheckpoint(