diff --git a/docs/source/en/training/distributed_inference.md b/docs/source/en/training/distributed_inference.md index f9756e1a67aa..cfd397f4aba3 100644 --- a/docs/source/en/training/distributed_inference.md +++ b/docs/source/en/training/distributed_inference.md @@ -253,7 +253,7 @@ try: device = torch.device("cuda", rank % torch.cuda.device_count()) torch.cuda.set_device(device) - transformer = AutoModel.from_pretrained("Qwen/Qwen-Image", subfolder="transformer", torch_dtype=torch.bfloat16, parallel_config=ContextParallelConfig(ring_degree=2)) + transformer = AutoModel.from_pretrained("Qwen/Qwen-Image", subfolder="transformer", torch_dtype=torch.bfloat16, parallel_config=ContextParallelConfig(ring_degree=2), device_map="cuda") pipeline = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", transformer=transformer, torch_dtype=torch.bfloat16, device_map="cuda") pipeline.transformer.set_attention_backend("flash") @@ -289,4 +289,4 @@ Pass the [`ContextParallelConfig`] to [`~ModelMixin.enable_parallelism`]. ```py pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2)) -``` \ No newline at end of file +```