Skip to content
This repository was archived by the owner on Jun 17, 2025. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 28 additions & 27 deletions lib/utils/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def LoadModelFromPickleFile(
model,
pkl_file,
use_gpu=True,
root_gpu_id=0,
gpu_ids=[0],
bgr2rgb=False,
):

Expand All @@ -89,30 +89,31 @@ def LoadModelFromPickleFile(
else:
device_opt = caffe2_pb2.CPU

with core.NameScope('gpu_{}'.format(root_gpu_id)):
with core.DeviceScope(core.DeviceOption(device_opt, root_gpu_id)):
for unscoped_blob_name in unscoped_blob_names.keys():
scoped_blob_name = scoped_name(unscoped_blob_name)
if unscoped_blob_name not in blobs:
log.info('{} not found'.format(unscoped_blob_name))
continue
if scoped_blob_name in ws_blobs:
ws_blob = workspace.FetchBlob(scoped_blob_name)
target_shape = ws_blob.shape
if target_shape == blobs[unscoped_blob_name].shape:
log.info('copying {} to {}'.format(
unscoped_blob_name, scoped_blob_name))
if bgr2rgb and unscoped_blob_name == 'conv1_w':
feeding_blob = FlipBGR2RGB(
blobs[unscoped_blob_name]
)
for gpu_id in gpu_ids:
with core.NameScope('gpu_{}'.format(gpu_id)):
with core.DeviceScope(core.DeviceOption(device_opt, gpu_id)):
for unscoped_blob_name in unscoped_blob_names.keys():
scoped_blob_name = scoped_name(unscoped_blob_name)
if unscoped_blob_name not in blobs:
log.info('{} not found'.format(unscoped_blob_name))
continue
if scoped_blob_name in ws_blobs:
ws_blob = workspace.FetchBlob(scoped_blob_name)
target_shape = ws_blob.shape
if target_shape == blobs[unscoped_blob_name].shape:
log.info('copying {} to {}'.format(
unscoped_blob_name, scoped_blob_name))
if bgr2rgb and unscoped_blob_name == 'conv1_w':
feeding_blob = FlipBGR2RGB(
blobs[unscoped_blob_name]
)
else:
feeding_blob = blobs[unscoped_blob_name]

else:
feeding_blob = blobs[unscoped_blob_name]

else:
log.info('found {} but blob shape do not match'.format(
unscoped_blob_name))
workspace.FeedBlob(
scoped_blob_name,
feeding_blob.astype(np.float32, copy=False)
)
log.info('found {} but blob shape do not match'.format(
unscoped_blob_name))
workspace.FeedBlob(
scoped_blob_name,
feeding_blob.astype(np.float32, copy=False)
)
2 changes: 1 addition & 1 deletion tools/extract_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def create_model_ops(model, loss_scale):
model,
args.load_model_path,
use_gpu=True,
root_gpu_id=gpus[0]
gpu_ids=gpus
)
else:
model_loader.LoadModelFromPickleFile(
Expand Down
2 changes: 1 addition & 1 deletion tools/test_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def test_input_fn(model):
test_model,
args.load_model_path,
use_gpu=True,
root_gpu_id=gpus[0]
gpu_ids=gpus
)
data_parallel_model.FinalizeAfterCheckpoint(test_model)
else:
Expand Down
2 changes: 1 addition & 1 deletion tools/train_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ def test_input_fn(model):
model_loader.LoadModelFromPickleFile(
train_model,
args.pretrained_model,
root_gpu_id=gpus[0]
gpu_ids=gpus
)

data_parallel_model.FinalizeAfterCheckpoint(
Expand Down