diff --git a/src/transformers/model_debugging_utils.py b/src/transformers/model_debugging_utils.py index 2c7b47c04fd5..6935722b4fb4 100644 --- a/src/transformers/model_debugging_utils.py +++ b/src/transformers/model_debugging_utils.py @@ -202,7 +202,11 @@ def is_layer_block(node): if not match or not node.get("children"): return False number = match.group(2) - return any(f".{number}." in child.get("module_path", "") for child in node["children"]) + search_str = f".{number}." + for child in node["children"]: + if search_str in child.get("module_path", ""): + return True + return False def prune_intermediate_layers(node):