Deprecate torch_dtype
#964
Conversation
Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>
for more information, see https://pre-commit.ci
This change cannot be applied directly, since users running Transformers versions below 4.57 would no longer be supported.
as said in comment
Do we need to upgrade transformers in the requirements? Line 12 in 3c1a678
Will it bring other problems?
Not a good option for now.
1. One option is waiting another 2 or 3 months to upgrade transformers.
2. Another option is a decorator that maps legacy keyword arguments to their new names:

```python
import functools
import warnings


def support_legacy_keys(legacy_map: dict):
    """
    legacy_map: dict, key = new parameter name, value = legacy parameter name
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for new_key, old_key in legacy_map.items():
                if old_key in kwargs:
                    if new_key in kwargs:
                        raise ValueError(f"Cannot use both '{new_key}' and '{old_key}'")
                    warnings.warn(f"'{old_key}' is deprecated, use '{new_key}' instead", DeprecationWarning)
                    kwargs[new_key] = kwargs.pop(old_key)
            return func(*args, **kwargs)
        return wrapper
    return decorator
```

Usage example:

```python
@support_legacy_keys({"dtype": "torch_dtype"})
def quan(*, dtype=None):
    print(f"dtype = {dtype}")

# How to call it
quan(dtype="torch.float16")        # new parameter
quan(torch_dtype="torch.float16")  # legacy parameter, mapped automatically with a warning
```
For option 2, we would also need to add a check on the transformers version. I prefer option 1.
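For reference, the version check mentioned above could be sketched roughly as follows. This is a minimal illustration, not part of the PR: the `dtype_kwarg` helper name is hypothetical, and it assumes 4.57 is the Transformers version where `dtype` replaces `torch_dtype`, as stated earlier in this thread.

```python
# Hedged sketch: pick between the new `dtype` kwarg and the legacy
# `torch_dtype` kwarg based on the installed Transformers version.
# Assumptions: threshold 4.57 comes from the discussion above; the
# helper name `dtype_kwarg` is illustrative, not an existing API.

def _major_minor(v: str) -> tuple:
    """Parse a version string like '4.57.0' into a comparable (4, 57) tuple."""
    parts = v.split(".")
    return (int(parts[0]), int(parts[1]))


def dtype_kwarg(transformers_version: str, value) -> dict:
    """Return the keyword dict appropriate for the given Transformers version."""
    if _major_minor(transformers_version) >= (4, 57):
        return {"dtype": value}       # new-style keyword (>= 4.57)
    return {"torch_dtype": value}     # legacy keyword (< 4.57)
```

The caller would then splat the result, e.g. `model = AutoModel.from_pretrained(name, **dtype_kwarg(transformers.__version__, "float16"))`, keeping the version logic in one place.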
This should not depend on our examples or our perspective; we need to make the decision from the users' point of view.
torch_dtype -> dtype #776