diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py index 7649f0d0..d9b3d26e 100644 --- a/src/compressed_tensors/utils/helpers.py +++ b/src/compressed_tensors/utils/helpers.py @@ -16,7 +16,17 @@ import warnings from functools import wraps from types import MappingProxyType -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + List, + Mapping, + Optional, + TypeVar, +) import numpy import torch @@ -44,6 +54,7 @@ "pack_bitmasks", "unpack_bitmasks", "patch_attr", + "patch_attrs", "ParameterizedDefaultDict", "get_num_attn_heads", "get_num_kv_heads", @@ -368,6 +379,34 @@ def patch_attr(base: object, attr: str, value: Any): delattr(base, attr) +@contextlib.contextmanager +def patch_attrs(bases: Iterable[Any], attr: str, values: Iterable[Any]): + """ + Same as `patch_attr` but for a list of objects to patch + Patch attribute for a list of objects with list of values. + Original values are restored upon exit + + :param bases: objects which has the attribute to patch + :param attr: name of the the attribute to patch + :param values: used to replace original values. Must be same + length as bases + + Usage: + >>> from types import SimpleNamespace + >>> obj1 = SimpleNamespace() + >>> obj2 = SimpleNamespace() + >>> with patch_attr([obj1, obj2], "attribute", ["value1", "value2"]): + ... assert obj1.attribute == "value1" + ... assert obj2.attribute == "value2" + >>> assert not hasattr(obj1, "attribute") + >>> assert not hasattr(obj2, "attribute") + """ + with contextlib.ExitStack() as stack: + for base, value in zip(bases, values): + stack.enter_context(patch_attr(base, attr, value)) + yield + + class ParameterizedDefaultDict(dict): """ Similar to `collections.DefaultDict`, but upon fetching a key which is missing, diff --git a/tests/test_utils/test_helpers.py b/tests/test_utils/test_helpers.py index 1c0aed95..eccb7b80 100644 --- a/tests/test_utils/test_helpers.py +++ b/tests/test_utils/test_helpers.py @@ -21,6 +21,7 @@ ParameterizedDefaultDict, load_compressed, patch_attr, + patch_attrs, save_compressed, save_compressed_model, ) @@ -176,6 +177,23 @@ def test_patch_attr(): assert not hasattr(obj, "attribute") +def test_patch_attrs(): + num_objs = 4 + objs = [SimpleNamespace() for _ in range(num_objs)] + for idx, obj in enumerate(objs): + if idx % 2 == 0: + obj.attribute = f"original_{idx}" + with patch_attrs(objs, "attribute", [f"patched_{idx}" for idx in range(num_objs)]): + for idx, obj in enumerate(objs): + assert obj.attribute == f"patched_{idx}" + obj.attribute = "modified" + for idx, obj in enumerate(objs): + if idx % 2 == 0: + assert obj.attribute == f"original_{idx}" + else: + assert not hasattr(obj, "attribute") + + def test_parameterized_default_dict(): def add_one(value): return value + 1