@@ -562,7 +562,23 @@ def __init__(
         self.opset_version = _target(opset_version) if opset_version is not None else None
         self._prog = mil.Program()
 
+        self.src_model_has_all_fp16_weights = False
+
         if isinstance(loaded_model, torch.jit.ScriptModule):
+            # src_model_has_all_fp16_weights will be True
+            # if the model has at least one trainable parameter
+            # and all trainable parameters have the fp16 dtype,
+            # e.g. if pytorch_model.half() has been explicitly called.
+            num_trainable_layers = 0
+            num_trainable_fp16_layers = 0
+            for param in loaded_model.parameters():
+                if param.requires_grad:
+                    num_trainable_layers += 1
+                    if param.dtype == torch.float16:
+                        num_trainable_fp16_layers += 1
+            if num_trainable_layers > 0:
+                self.src_model_has_all_fp16_weights = num_trainable_layers == num_trainable_fp16_layers
+
             self.context = TranscriptionContext(frontend=TorchFrontend.TORCHSCRIPT)
             self.graph = InternalTorchIRGraph.from_torchscript(
                 torchscript=loaded_model, inputs=self.inputs, cut_at_symbols=cut_at_symbols
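
The detection added in this first hunk can be reproduced outside the converter. Below is a minimal sketch of the same check; the helper name `has_all_fp16_trainable_weights` is only for illustration and is not part of coremltools:

```python
import torch

def has_all_fp16_trainable_weights(model: torch.nn.Module) -> bool:
    # Same rule as the converter flag: at least one trainable parameter,
    # and every trainable parameter stored in fp16.
    num_trainable = 0
    num_trainable_fp16 = 0
    for param in model.parameters():
        if param.requires_grad:
            num_trainable += 1
            if param.dtype == torch.float16:
                num_trainable_fp16 += 1
    return num_trainable > 0 and num_trainable == num_trainable_fp16

model = torch.nn.Linear(4, 2)
print(has_all_fp16_trainable_weights(model))         # False: default fp32 weights
print(has_all_fp16_trainable_weights(model.half()))  # True: .half() casts all params to fp16
```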
@@ -1140,6 +1156,11 @@ def convert(self) -> Program:
             user_names = list(ssa_func_inputs.keys())
             internal_names = list(self.graph.inputs.keys())
             internal_names.extend(user_names[len(internal_names) :])
+            input_dtypes = []
+            for torch_name, ssa_name in zip(internal_names, user_names):
+                input_var = ssa_func.inputs[ssa_name]
+                input_dtypes.append(input_var.dtype)
+            all_fp16_inputs = all(x == types.fp16 for x in input_dtypes)
             for torch_name, ssa_name in zip(internal_names, user_names):
                 input_var = ssa_func.inputs[ssa_name]
                 if self.context.frontend == TorchFrontend.TORCHSCRIPT:
@@ -1151,7 +1172,7 @@ def convert(self) -> Program:
                     # So here we perform the "cast input to fp32" step
                     if (
                         types.is_tensor(input_var.sym_type) or types.is_scalar(input_var.sym_type)
-                    ) and input_var.dtype == types.fp16:
+                    ) and input_var.dtype == types.fp16 and not (all_fp16_inputs and self.src_model_has_all_fp16_weights):
                         # This cast should have placeholder scope
                         with mb.scope(
                             ScopeInfo(
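
Taken together, the hunks skip the input fp16 -> fp32 cast only when the traced model was explicitly converted with `.half()` and every model input is declared as fp16. A hedged sketch of the kind of conversion call this path targets follows; the toy model, shapes, and deployment target are assumptions, fp16 I/O requires iOS16 or newer, and tracing a `.half()` model needs a PyTorch build with fp16 CPU kernels:

```python
import coremltools as ct
import numpy as np
import torch

# A toy model cast to fp16 weights via .half(), so the new
# src_model_has_all_fp16_weights flag is True for the traced module.
model = torch.nn.Sequential(torch.nn.Linear(3, 2)).eval().half()
example = torch.rand(1, 3).half()
traced = torch.jit.trace(model, example)

# Declaring every input as fp16 makes all_fp16_inputs True, so the converter
# no longer inserts an fp16 -> fp32 cast on the placeholder inputs.
mlmodel = ct.convert(
    traced,
    inputs=[ct.TensorType(name="x", shape=(1, 3), dtype=np.float16)],
    outputs=[ct.TensorType(dtype=np.float16)],
    compute_precision=ct.precision.FLOAT16,
    minimum_deployment_target=ct.target.iOS16,
)
```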