Skip to content

Commit 7a001c3

Browse files
kaixuanliusayakpaulDN6
authored
adjust unit tests for test_save_load_float16 (#12500)
* adjust unit tests for wan pipeline Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com> * update code Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com> * avoid adjusting common `get_dummy_components` API Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com> * use `form_pretrained` to `transformer` and `transformer_2` Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com> * update code Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com> * update Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com> --------- Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
1 parent d8e4805 commit 7a001c3

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

tests/pipelines/test_pipelines_common.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1422,7 +1422,18 @@ def test_float16_inference(self, expected_max_diff=5e-2):
14221422
def test_save_load_float16(self, expected_max_diff=1e-2):
14231423
components = self.get_dummy_components()
14241424
for name, module in components.items():
1425-
if hasattr(module, "half"):
1425+
# Account for components with _keep_in_fp32_modules
1426+
if hasattr(module, "_keep_in_fp32_modules") and module._keep_in_fp32_modules is not None:
1427+
for name, param in module.named_parameters():
1428+
if any(
1429+
module_to_keep_in_fp32 in name.split(".")
1430+
for module_to_keep_in_fp32 in module._keep_in_fp32_modules
1431+
):
1432+
param.data = param.data.to(torch_device).to(torch.float32)
1433+
else:
1434+
param.data = param.data.to(torch_device).to(torch.float16)
1435+
1436+
elif hasattr(module, "half"):
14261437
components[name] = module.to(torch_device).half()
14271438

14281439
pipe = self.pipeline_class(**components)

0 commit comments

Comments
 (0)