
Commit 240e49c

fix gradient based criterion bug (#351)
Signed-off-by: wenhuach21 <wenhua.cheng@intel.com>
Signed-off-by: Zhang, Weiwei1 <weiwei1.zhang@intel.com>
1 parent 20559d2 commit 240e49c

File tree

3 files changed: +36 −22 lines changed

neural_compressor/experimental/pruning_v2.py

Lines changed: 25 additions & 17 deletions

@@ -1,5 +1,5 @@
 """pruning module."""
-#!/usr/bin/env python
+# !/usr/bin/env python
 # -*- coding: utf-8 -*-
 #
 # Copyright (c) 2021 Intel Corporation
@@ -30,13 +30,15 @@
 from ..utils.utility import LazyImport
 from ..pruner.pruners import get_pruner
 from ..conf.pythonic_config import Config
+
 LazyImport('torch.nn')
 torch = LazyImport('torch')
 
 from deprecated import deprecated
 import importlib
 import re
 
+
 class Pruning(Component):
     """This is base class of pruning object.
 
@@ -71,12 +73,12 @@ def __init__(self, conf_fname_or_obj=None):
             # yaml file
             raise NotImplementedError("Only WeightPruningConfig config is supported currently.")
         self.pruners_info = process_config(self.conf)
-        # self.model = None # here skip
+        # self.model = None  # here skip
         # align with old Component based API
         # self._init_with_conf()
         self.callbacks = dict(tf_pruning=TfPruningCallback)
         self.pruners = []
-        self.generate_hooks() # place generate hooks here, to get rid of prepare() function.
+        self.generate_hooks()  # place generate hooks here, to get rid of prepare() function.
 
     def update_config(self, *args, **kwargs):
         """Add user-defined arguments to the original configurations.
@@ -134,6 +136,11 @@ def get_sparsity_ratio(self):
         elementwise_over_all = float(
             element_sparsity_cnt) / param_cnt
 
+        logger.info(
+            f"elementwise_over_matmul_gemm_conv:{elementwise_over_matmul_gemm_conv},"
+            f" elementwise_over_all:{elementwise_over_all},"
+            f"blockwise_over_matmul_gemm_conv:{blockwise_over_matmul_gemm_conv}")
+
         return elementwise_over_matmul_gemm_conv, elementwise_over_all, blockwise_over_matmul_gemm_conv
 
     def _on_train_begin(self, dataloader=None):
@@ -188,6 +195,7 @@ def _on_train_end(self):
         """Functions called after training."""
         for pruner in self.pruners:
             pruner.on_train_end()
+        self.get_sparsity_ratio()
 
     def _on_before_eval(self):
         """Implement at the beginning of evaluation phase."""
@@ -227,16 +235,16 @@ def pre_process(self):
         if self._train_dataloader is None and self._train_func is None:
             train_dataloader_cfg = self.cfg.pruning.train.dataloader
             assert train_dataloader_cfg is not None, \
-                   'dataloader field of train field of pruning section ' \
-                   'in yaml file should be configured as train_dataloader property is NOT set!'
+                'dataloader field of train field of pruning section ' \
+                'in yaml file should be configured as train_dataloader property is NOT set!'
             train_dataloader_cfg.distributed = self.train_distributed
             self._train_dataloader = create_dataloader(self.framework, train_dataloader_cfg)
 
         if self._eval_dataloader is None and self._eval_func is None:
             eval_dataloader_cfg = self.cfg.evaluation.accuracy.dataloader
             assert eval_dataloader_cfg is not None, \
-                   'dataloader field of evaluation ' \
-                   'in yaml file should be configured as eval_dataloader property is NOT set!'
+                'dataloader field of evaluation ' \
+                'in yaml file should be configured as eval_dataloader property is NOT set!'
             eval_dataloader_cfg.distributed = self.evaluation_distributed
             self._eval_dataloader = create_dataloader(self.framework, eval_dataloader_cfg)
 
@@ -246,22 +254,22 @@ def pre_process(self):
             assert train_cfg, "train field of pruning section in yaml file must " \
                               "be configured for pruning if pruning_func is NOT set."
             self._train_func = create_train_func(self.framework, \
-                                                 self.train_dataloader, \
-                                                 self.adaptor, \
-                                                 train_cfg, \
-                                                 hooks=self.hooks, \
-                                                 callbacks=self.callbacks)
+                                self.train_dataloader, \
+                                self.adaptor, \
+                                train_cfg, \
+                                hooks=self.hooks, \
+                                callbacks=self.callbacks)
         if self._eval_func is None:
             # eval section in yaml file should be configured.
             eval_cfg = self.cfg.evaluation
             assert eval_cfg, "eval field of pruning section in yaml file must " \
-                             "be configured for pruning if eval_func is NOT set."
+                "be configured for pruning if eval_func is NOT set."
             self._eval_func = create_eval_func(self.framework, \
                                                self.eval_dataloader, \
                                                self.adaptor, \
                                                eval_cfg.accuracy.metric, \
                                                eval_cfg.accuracy.postprocess, \
-                                               fp32_baseline = False)
+                                               fp32_baseline=False)
         if getattr(self.train_dataloader, 'distributed', False):
             self.register_hook('on_train_begin', self.adaptor._pre_hook_for_hvd)
 
@@ -272,14 +280,14 @@ def execute(self):
         """
         logger.info("Start to get the baseline model's score before pruning.")
         self.baseline_score = self._eval_func(self._model if getattr(self._eval_func, 'builtin', None) \
-                                              else self._model.model)
+            else self._model.model)
         logger.info("Baseline model's score is {}.".format(str(self.baseline_score)))
         logger.info("Model pruning begins.")
         self._train_func(self._model if getattr(self._train_func, 'builtin', None) \
-                         else self._model.model)
+            else self._model.model)
         logger.info("Model pruning is done. Start to evaluate the pruned model.")
         self.last_score = self._eval_func(self._model if getattr(self._eval_func, 'builtin', None) \
-                                          else self._model.model)
+            else self._model.model)
         logger.info("Pruned model score is {}.".format(str(self.last_score)))
         return self._model
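To make the metrics logged by get_sparsity_ratio concrete, here is a minimal sketch of an element-wise sparsity ratio (zero weights over total weights) across matmul/gemm/conv-style layers of a toy PyTorch model. The toy model and the manual zeroing are hypothetical illustrations, not the library's implementation.

import torch
import torch.nn as nn

# Hypothetical toy model standing in for a pruned network.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

# Imitate pruning by zeroing half of the first layer's weight columns.
with torch.no_grad():
    model[0].weight[:, ::2] = 0.0

zero_cnt, param_cnt = 0, 0
for module in model.modules():
    if isinstance(module, (nn.Linear, nn.Conv2d)):  # matmul/gemm/conv-style layers
        zero_cnt += torch.sum(module.weight == 0).item()
        param_cnt += module.weight.numel()

elementwise_over_matmul_gemm_conv = float(zero_cnt) / param_cnt
print(f"elementwise_over_matmul_gemm_conv:{elementwise_over_matmul_gemm_conv}")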

neural_compressor/pruner/criteria.py

Lines changed: 7 additions & 3 deletions

@@ -61,6 +61,10 @@ def on_step_begin(self):
         """Calculate and store the pruning scores of pruning modules at the beginning of a step."""
         pass
 
+    def on_before_optimizer_step(self):
+        """Calculate and store the pruning scores of pruning modules before the optimizer step."""
+        pass
+
     def on_after_optimizer_step(self):
         """Calculate and store the pruning scores of pruning modules after the optimizer step."""
         pass
@@ -113,7 +117,7 @@ def __init__(self, modules, config):
         super(GradientCriterion, self).__init__(modules, config)
         assert self.config.end_step > 0, "please set end_step > 0 for gradient based criterion"
 
-    def on_after_optimizer_step(self):
+    def on_before_optimizer_step(self):
         """Calculate and store the pruning scores based on gradient criterion."""
         with torch.no_grad():
             for key in self.modules.keys():
@@ -143,7 +147,7 @@ def __init__(self, modules, config):
         super(SnipCriterion, self).__init__(modules, config)
         assert self.config.end_step > 0, "please set end_step > 0 for gradient based criterion"
 
-    def on_after_optimizer_step(self):
+    def on_before_optimizer_step(self):
         """Calculate and store the pruning scores based on snip criterion."""
         ##self.mask_weights()
         with torch.no_grad():
@@ -180,7 +184,7 @@ def __init__(self, modules, config):
         self.alpha = 0.9
         self.beta = 1.0
 
-    def on_after_optimizer_step(self):
+    def on_before_optimizer_step(self):
         """Calculate and store the pruning scores based on snip_momentum criterion."""
         with torch.no_grad():
             for key in self.modules.keys():
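The rename above is the core of the fix: gradient-based criteria (gradient, snip, snip_momentum) read param.grad, which holds the current step's gradients only between backward() and optimizer.step(); once the step and the usual zero_grad() have run, the scores would be computed from stale or cleared gradients. A minimal, hypothetical timing sketch with a single nn.Linear and a snip-style |weight * grad| score (not the library's code):

import torch
import torch.nn as nn

# Hypothetical one-layer setup; only the hook timing matters here.
layer = nn.Linear(8, 2)
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)

x, y = torch.randn(4, 8), torch.randn(4, 2)
loss = nn.functional.mse_loss(layer(x), y)
loss.backward()

# "on_before_optimizer_step": gradients from backward() are still intact,
# so a snip-style score |weight * grad| ranks weights by this step's sensitivity.
with torch.no_grad():
    score = (layer.weight * layer.weight.grad).abs()

optimizer.step()
optimizer.zero_grad()
# "on_after_optimizer_step": the gradients have been consumed and cleared,
# so computing the score here (the old behavior) would no longer see them.
print(score.shape)  # torch.Size([2, 8])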

neural_compressor/pruner/pruners.py

Lines changed: 4 additions & 2 deletions

@@ -323,14 +323,15 @@ def update_masks(self, local_step):
     def on_before_optimizer_step(self):
         """Implement before optimizer.step()."""
         self.reg.on_before_optimizer_step()
+        self.criterion.on_before_optimizer_step()
 
     def on_after_optimizer_step(self):
         """Prune the model after optimization."""
         ##the order of the following three lines can't not be exchanged
         if self.global_step >= self.start_step and self.global_step <= self.end_step:
             self.reg.on_after_optimizer_step()
             self.mask_weights()
-            self.criterion.on_after_optimizer_step()
+
         self.global_step += 1
 
 
@@ -563,6 +564,7 @@ def on_step_begin(self, local_step):
     def on_before_optimizer_step(self):
         """Implement before optimizer.step()."""
         self.reg.on_before_optimizer_step()
+        self.criterion.on_before_optimizer_step()
 
     def on_after_optimizer_step(self):
         """Prune the model after optimization."""
@@ -573,7 +575,7 @@ def on_after_optimizer_step(self):
             self.mask_weights()
         else:
             self.mask_weights_general(self.progressive_masks)
-        self.criterion.on_after_optimizer_step()
+
         self.global_step += 1
 
     def print_progressive_sparsity(self):
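For context, a minimal, self-contained sketch of the hook order these pruner methods assume in a training step. The ToyPruner below only records call order and is purely illustrative; in the library the pruner also updates masks, computes criterion scores, and re-applies masks inside these hooks.

import torch
import torch.nn as nn

# Hypothetical stand-ins; real usage would obtain the pruner from neural_compressor's pruning API.
model = nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data = [(torch.randn(4, 8), torch.randn(4, 2)) for _ in range(2)]

class ToyPruner:
    """Illustrative pruner that only records the order in which hooks fire."""
    def __init__(self):
        self.calls = []

    def on_step_begin(self, step):
        self.calls.append("on_step_begin")              # masks may be updated here

    def on_before_optimizer_step(self):
        self.calls.append("on_before_optimizer_step")   # criterion scores read fresh gradients (this fix)

    def on_after_optimizer_step(self):
        self.calls.append("on_after_optimizer_step")    # masks are re-applied to the updated weights

pruner = ToyPruner()
for step, (inputs, targets) in enumerate(data):
    pruner.on_step_begin(step)
    loss = nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    pruner.on_before_optimizer_step()
    optimizer.step()
    pruner.on_after_optimizer_step()
    optimizer.zero_grad()

print(pruner.calls[:3])  # ['on_step_begin', 'on_before_optimizer_step', 'on_after_optimizer_step']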

0 commit comments