11"""pruning module."""
2- #!/usr/bin/env python
2+ # !/usr/bin/env python
33# -*- coding: utf-8 -*-
44#
55# Copyright (c) 2021 Intel Corporation
3030from ..utils .utility import LazyImport
3131from ..pruner .pruners import get_pruner
3232from ..conf .pythonic_config import Config
33+
3334LazyImport ('torch.nn' )
3435torch = LazyImport ('torch' )
3536
3637from deprecated import deprecated
3738import importlib
3839import re
3940
41+
4042class Pruning (Component ):
4143 """This is base class of pruning object.
4244
@@ -71,12 +73,12 @@ def __init__(self, conf_fname_or_obj=None):
7173 # yaml file
7274 raise NotImplementedError ("Only WeightPruningConfig config is supported currently." )
7375 self .pruners_info = process_config (self .conf )
74- # self.model = None # here skip
76+ # self.model = None # here skip
7577 # align with old Component based API
7678 # self._init_with_conf()
7779 self .callbacks = dict (tf_pruning = TfPruningCallback )
7880 self .pruners = []
79- self .generate_hooks () # place generate hooks here, to get rid of prepare() function.
81+ self .generate_hooks () # place generate hooks here, to get rid of prepare() function.
8082
8183 def update_config (self , * args , ** kwargs ):
8284 """Add user-defined arguments to the original configurations.
@@ -134,6 +136,11 @@ def get_sparsity_ratio(self):
134136 elementwise_over_all = float (
135137 element_sparsity_cnt ) / param_cnt
136138
139+ logger .info (
140+ f"elementwise_over_matmul_gemm_conv:{ elementwise_over_matmul_gemm_conv } ,"
141+ f" elementwise_over_all:{ elementwise_over_all } ,"
142+ f"blockwise_over_matmul_gemm_conv:{ blockwise_over_matmul_gemm_conv } " )
143+
137144 return elementwise_over_matmul_gemm_conv , elementwise_over_all , blockwise_over_matmul_gemm_conv
138145
139146 def _on_train_begin (self , dataloader = None ):
@@ -188,6 +195,7 @@ def _on_train_end(self):
188195 """Functions called after training."""
189196 for pruner in self .pruners :
190197 pruner .on_train_end ()
198+ self .get_sparsity_ratio ()
191199
192200 def _on_before_eval (self ):
193201 """Implement at the beginning of evaluation phase."""
@@ -227,16 +235,16 @@ def pre_process(self):
227235 if self ._train_dataloader is None and self ._train_func is None :
228236 train_dataloader_cfg = self .cfg .pruning .train .dataloader
229237 assert train_dataloader_cfg is not None , \
230- 'dataloader field of train field of pruning section ' \
231- 'in yaml file should be configured as train_dataloader property is NOT set!'
238+ 'dataloader field of train field of pruning section ' \
239+ 'in yaml file should be configured as train_dataloader property is NOT set!'
232240 train_dataloader_cfg .distributed = self .train_distributed
233241 self ._train_dataloader = create_dataloader (self .framework , train_dataloader_cfg )
234242
235243 if self ._eval_dataloader is None and self ._eval_func is None :
236244 eval_dataloader_cfg = self .cfg .evaluation .accuracy .dataloader
237245 assert eval_dataloader_cfg is not None , \
238- 'dataloader field of evaluation ' \
239- 'in yaml file should be configured as eval_dataloader property is NOT set!'
246+ 'dataloader field of evaluation ' \
247+ 'in yaml file should be configured as eval_dataloader property is NOT set!'
240248 eval_dataloader_cfg .distributed = self .evaluation_distributed
241249 self ._eval_dataloader = create_dataloader (self .framework , eval_dataloader_cfg )
242250
@@ -246,22 +254,22 @@ def pre_process(self):
246254 assert train_cfg , "train field of pruning section in yaml file must " \
247255 "be configured for pruning if pruning_func is NOT set."
248256 self ._train_func = create_train_func (self .framework , \
249- self .train_dataloader , \
250- self .adaptor , \
251- train_cfg , \
252- hooks = self .hooks , \
253- callbacks = self .callbacks )
257+ self .train_dataloader , \
258+ self .adaptor , \
259+ train_cfg , \
260+ hooks = self .hooks , \
261+ callbacks = self .callbacks )
254262 if self ._eval_func is None :
255263 # eval section in yaml file should be configured.
256264 eval_cfg = self .cfg .evaluation
257265 assert eval_cfg , "eval field of pruning section in yaml file must " \
258- "be configured for pruning if eval_func is NOT set."
266+ "be configured for pruning if eval_func is NOT set."
259267 self ._eval_func = create_eval_func (self .framework , \
260268 self .eval_dataloader , \
261269 self .adaptor , \
262270 eval_cfg .accuracy .metric , \
263271 eval_cfg .accuracy .postprocess , \
264- fp32_baseline = False )
272+ fp32_baseline = False )
265273 if getattr (self .train_dataloader , 'distributed' , False ):
266274 self .register_hook ('on_train_begin' , self .adaptor ._pre_hook_for_hvd )
267275
@@ -272,14 +280,14 @@ def execute(self):
272280 """
273281 logger .info ("Start to get the baseline model's score before pruning." )
274282 self .baseline_score = self ._eval_func (self ._model if getattr (self ._eval_func , 'builtin' , None ) \
275- else self ._model .model )
283+ else self ._model .model )
276284 logger .info ("Baseline model's score is {}." .format (str (self .baseline_score )))
277285 logger .info ("Model pruning begins." )
278286 self ._train_func (self ._model if getattr (self ._train_func , 'builtin' , None ) \
279- else self ._model .model )
287+ else self ._model .model )
280288 logger .info ("Model pruning is done. Start to evaluate the pruned model." )
281289 self .last_score = self ._eval_func (self ._model if getattr (self ._eval_func , 'builtin' , None ) \
282- else self ._model .model )
290+ else self ._model .model )
283291 logger .info ("Pruned model score is {}." .format (str (self .last_score )))
284292 return self ._model
285293
0 commit comments