[docs] Models #12248
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks! Left some comments, LMK if they are unclear.
| `"cuda"` | places model or pipeline on CUDA device |
| `"balanced"` | evenly distributes model or pipeline on all GPUs |
| `"auto"` | distribute model from fastest device first to slowest |
| `"cuda"` | places pipeline on CUDA device |
`"cuda"` is just an example. If someone wants to do it for any other supported accelerator, I believe they pass it by name 👀
Suggested change:

```diff
-| `"cuda"` | places pipeline on CUDA device |
+| `"cuda"` | places pipeline on CUDA (or supported accelerator) device |
```
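To make the reviewer's point concrete, here is a minimal sketch (plain PyTorch, not diffusers) showing that accelerator strings like `"cuda"` are ordinary torch device names, and other backends use their own names. The backend list is illustrative; actual hardware availability depends on the PyTorch build.

```python
import torch

# "cuda" is just one accelerator name; other backends have their own,
# e.g. "mps" for Apple silicon. Constructing a torch.device only parses
# the name; it does not require the hardware to be present.
for name in ("cuda", "cpu", "mps"):
    device = torch.device(name)
    print(name, "->", device.type)
```

A string passed as `device_map` would name one of these backends in the same way.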
```python
model = AutoModel.from_pretrained(
    "Qwen/Qwen-Image",
    subfolder="transformer"
)
```
Suggested change:

```diff
-model = AutoModel.from_pretrained(
-    "Qwen/Qwen-Image",
-    subfolder="transformer"
-)
+model = AutoModel.from_pretrained(
+    "Qwen/Qwen-Image", subfolder="transformer"
+)
```
```python
    "Qwen/Qwen-Image",
    subfolder="transformer"
    torch_dtype=torch.float16
)
```

Suggested change (adds the missing comma and switches to `bfloat16`):

```diff
-    "Qwen/Qwen-Image",
-    subfolder="transformer"
-    torch_dtype=torch.float16
+    "Qwen/Qwen-Image",
+    subfolder="transformer",
+    torch_dtype=torch.bfloat16
```
[torch.Tensor.to](https://docs.pytorch.org/docs/stable/generated/torch.Tensor.to.html) can also convert to a specific data type on the fly. However, it converts *all* weights to the requested data type, unlike `torch_dtype` which respects `_keep_in_fp32_modules`. This argument preserves layers in `torch.float32` for numerical stability and best generation quality (see the example [_keep_in_fp32_modules](https://github.com/huggingface/diffusers/blob/f864a9a352fa4a220d860bfdd1782e3e5af96382/src/diffusers/models/transformers/transformer_wan.py#L374)).
Shouldn't it be `nn.Module.to()`?
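For illustration, a hedged sketch (plain PyTorch, toy module, not diffusers code) of the behavior the docs paragraph describes: `nn.Module.to()` casts every floating-point parameter, with no `_keep_in_fp32_modules`-style carve-out.

```python
import torch
import torch.nn as nn

# Toy stand-in for a model whose norm layers a _keep_in_fp32_modules
# list would normally protect; the architecture is illustrative.
model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))

# .to(dtype) converts *all* floating-point parameters, including the
# LayerNorm weights, unlike loading with torch_dtype, which can keep
# listed modules in float32.
model = model.to(torch.float16)

dtypes = {p.dtype for p in model.parameters()}
print(dtypes)  # every parameter is now torch.float16
```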
```python
from diffusers import QwenImageTransformer2DModel

model = QwenImageTransformer2DModel.from_pretrained(
    "Qwen/Qwen-Image",,
```
Suggested change:

```diff
-    "Qwen/Qwen-Image",,
+    "Qwen/Qwen-Image",
```
```python
import torch
from diffusers import QwenImageTransformer2DModel

max_memory = {0: "16GB", 1: "16GB"}
```
Umm, what would 0 and 1 denote in this case, though? I think this form of `max_memory` dict is reserved for the pipelines. For models, you probably want to specify module names (regex should work, too). Cc: @SunMarc
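For comparison, a sketch of the two dict shapes in question. In accelerate-style usage, integer keys in `max_memory` are GPU indices with per-device memory caps; the module names in the second dict are purely illustrative, not taken from the Qwen-Image model.

```python
# Per-device memory caps: keys 0 and 1 are GPU indices.
max_memory = {0: "16GB", 1: "16GB"}

# Module-level placement for a single model, as the review suggests:
# keys are (illustrative) submodule names, values are device indices.
device_map = {"transformer_blocks": 0, "norm_out": 1}

print(sorted(max_memory), sorted(device_map))
```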
Splits off the *Models* section from *Load schedulers and models* and creates a dedicated section for models covering device placement, torch dtype, the `AutoModel` API, and saving as shards.