
Commit 7aa6af1

[Refactor] Move testing utils out of src (#12238)
* update
* update
* update
* update
* update
* merge main
* Revert "merge main"

  This reverts commit 65efbce.
1 parent 87b800e commit 7aa6af1
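
In practical terms, this commit moves test-only helpers out of the installed package: shared pytest utilities are now imported from the repository's tests package, while the device and determinism helpers become regular library code in diffusers.utils.torch_utils, with deprecation shims left behind in diffusers.utils.testing_utils. A minimal before/after sketch of the imports implied by the diffs below (the grouping of names onto single lines is illustrative):

# Before this commit: device helpers and test decorators lived in the installed package.
from diffusers.utils.testing_utils import backend_empty_cache, enable_full_determinism, require_timm

# After this commit: device and determinism helpers are ordinary library code ...
from diffusers.utils.torch_utils import backend_empty_cache, enable_full_determinism

# ... and test-only helpers are imported from the repository's tests package
# (only resolvable from a source checkout with the repo root on sys.path).
from tests.testing_utils import pytest_addoption_shared, require_timm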

File tree

312 files changed: +2360, -554 lines

Some content is hidden: large commits have some content hidden by default, so only a subset of the 312 changed files appears below.


examples/conftest.py

Lines changed: 7 additions & 2 deletions
@@ -25,20 +25,25 @@
 git_repo_path = abspath(join(dirname(dirname(dirname(__file__))), "src"))
 sys.path.insert(1, git_repo_path)
 
+# Add parent directory to path so we can import from tests
+repo_root = abspath(dirname(dirname(__file__)))
+if repo_root not in sys.path:
+    sys.path.insert(0, repo_root)
+
 
 # silence FutureWarning warnings in tests since often we can't act on them until
 # they become normal warnings - i.e. the tests still need to test the current functionality
 warnings.simplefilter(action="ignore", category=FutureWarning)
 
 
 def pytest_addoption(parser):
-    from diffusers.utils.testing_utils import pytest_addoption_shared
+    from tests.testing_utils import pytest_addoption_shared
 
     pytest_addoption_shared(parser)
 
 
 def pytest_terminal_summary(terminalreporter):
-    from diffusers.utils.testing_utils import pytest_terminal_summary_main
+    from tests.testing_utils import pytest_terminal_summary_main
 
     make_reports = terminalreporter.config.getoption("--make-reports")
     if make_reports:
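
The effect of the conftest change is that the repository root sits on sys.path before pytest collects the examples suite, so tests.testing_utils resolves as a package and both conftests can share the same option and summary hooks. A self-contained restatement of that bootstrap (paths mirror the hunk above; the trailing import only checks that the package resolves):

# Self-contained restatement of the bootstrap added above: put the repository
# root (the parent of examples/) at the front of sys.path so that the `tests`
# package resolves when pytest imports this conftest.
import sys
from os.path import abspath, dirname

repo_root = abspath(dirname(dirname(__file__)))  # examples/conftest.py -> repo root
if repo_root not in sys.path:
    sys.path.insert(0, repo_root)

from tests.testing_utils import pytest_addoption_shared  # noqa: E402 - now importable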

examples/controlnet/train_controlnet_sd3.py

Lines changed: 3 additions & 2 deletions
@@ -24,6 +24,8 @@
 import os
 import random
 import shutil
+
+# Add repo root to path to import from tests
 from pathlib import Path
 
 import accelerate
@@ -54,8 +56,7 @@
 from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory
 from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
-from diffusers.utils.testing_utils import backend_empty_cache
-from diffusers.utils.torch_utils import is_compiled_module
+from diffusers.utils.torch_utils import backend_empty_cache, is_compiled_module
 
 
 if is_wandb_available():
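
The training script keeps calling backend_empty_cache; only the import moves from the deprecated testing module to torch_utils. The call sites are outside this hunk, so the snippet below is a hypothetical sketch of how such a helper is typically used between validation runs (the function name free_accelerator_memory is invented for illustration):

# Hypothetical usage sketch -- not quoted from train_controlnet_sd3.py.
# backend_empty_cache dispatches to the matching backend's empty_cache()
# (torch.cuda / torch.xpu / torch.mps) for the given device string.
import torch

from diffusers.utils.torch_utils import backend_empty_cache


def free_accelerator_memory(device_type: str) -> None:
    # Drop cached allocator blocks after a validation pass so training resumes
    # with as much free device memory as possible; a no-op on "cpu".
    backend_empty_cache(device_type)


free_accelerator_memory("cuda" if torch.cuda.is_available() else "cpu")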

examples/vqgan/test_vqgan.py

Lines changed: 7 additions & 1 deletion
@@ -24,12 +24,18 @@
 import torch
 
 from diffusers import VQModel
-from diffusers.utils.testing_utils import require_timm
 
 
+# Add parent directories to path to import from tests
 sys.path.append("..")
+repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
+if repo_root not in sys.path:
+    sys.path.insert(0, repo_root)
+
 from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402
 
+from tests.testing_utils import require_timm  # noqa
+
 
 logging.basicConfig(level=logging.DEBUG)
 
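
The require_timm decorator itself is unchanged; it is simply imported from its new home in tests/testing_utils.py. In the diffusers test suite it skips a test when the timm package is not installed (our reading of the existing helper, not something shown in this hunk), so a typical, purely illustrative use looks like this:

# Illustrative only: the class and test names below are hypothetical.
import unittest

from tests.testing_utils import require_timm  # requires the repo root on sys.path


@require_timm
class TimmBackedVQModelSmokeTest(unittest.TestCase):
    def test_can_construct(self):
        # Skipped automatically when timm is unavailable.
        self.assertTrue(True)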

src/diffusers/utils/testing_utils.py

Lines changed: 73 additions & 29 deletions
@@ -66,7 +66,10 @@
 global_rng = random.Random()
 
 logger = get_logger(__name__)
-
+logger.warning(
+    "diffusers.utils.testing_utils' is deprecated and will be removed in a future version. "
+    "Determinism and device backend utilities have been moved to `diffusers.utils.torch_utils`. "
+)
 _required_peft_version = is_peft_available() and version.parse(
     version.parse(importlib.metadata.version("peft")).base_version
 ) > version.parse("0.5")

@@ -801,10 +804,9 @@ def export_to_ply(mesh, output_ply_path: str = None):
             f.write(format.pack(*vertex))
 
     if faces is not None:
-        format = struct.Struct("<B3I")
         for tri in faces.tolist():
             f.write(format.pack(len(tri), *tri))
-
+        format = struct.Struct("<B3I")
     return output_ply_path

@@ -1144,23 +1146,23 @@ def enable_full_determinism():
     Helper function for reproducible behavior during distributed training. See
     - https://pytorch.org/docs/stable/notes/randomness.html for pytorch
     """
-    # Enable PyTorch deterministic mode. This potentially requires either the environment
-    # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
-    # depending on the CUDA version, so we set them both here
-    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
-    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
-    torch.use_deterministic_algorithms(True)
+    from .torch_utils import enable_full_determinism as _enable_full_determinism
 
-    # Enable CUDNN deterministic mode
-    torch.backends.cudnn.deterministic = True
-    torch.backends.cudnn.benchmark = False
-    torch.backends.cuda.matmul.allow_tf32 = False
+    logger.warning(
+        "enable_full_determinism has been moved to diffusers.utils.torch_utils. "
+        "Importing from diffusers.utils.testing_utils is deprecated and will be removed in a future version."
+    )
+    return _enable_full_determinism()
 
 
 def disable_full_determinism():
-    os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
-    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ""
-    torch.use_deterministic_algorithms(False)
+    from .torch_utils import disable_full_determinism as _disable_full_determinism
+
+    logger.warning(
+        "disable_full_determinism has been moved to diffusers.utils.torch_utils. "
+        "Importing from diffusers.utils.testing_utils is deprecated and will be removed in a future version."
+    )
+    return _disable_full_determinism()
 
 
 # Utils for custom and alternative accelerator devices

@@ -1282,43 +1284,85 @@ def _device_agnostic_dispatch(device: str, dispatch_table: Dict[str, Callable],
 
 # These are callables which automatically dispatch the function specific to the accelerator
 def backend_manual_seed(device: str, seed: int):
-    return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed)
+    from .torch_utils import backend_manual_seed as _backend_manual_seed
+
+    logger.warning(
+        "backend_manual_seed has been moved to diffusers.utils.torch_utils. "
+        "diffusers.utils.testing_utils is deprecated and will be removed in a future version."
+    )
+    return _backend_manual_seed(device, seed)
 
 
 def backend_synchronize(device: str):
-    return _device_agnostic_dispatch(device, BACKEND_SYNCHRONIZE)
+    from .torch_utils import backend_synchronize as _backend_synchronize
+
+    logger.warning(
+        "backend_synchronize has been moved to diffusers.utils.torch_utils. "
+        "diffusers.utils.testing_utils is deprecated and will be removed in a future version."
+    )
+    return _backend_synchronize(device)
 
 
 def backend_empty_cache(device: str):
-    return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE)
+    from .torch_utils import backend_empty_cache as _backend_empty_cache
+
+    logger.warning(
+        "backend_empty_cache has been moved to diffusers.utils.torch_utils. "
+        "diffusers.utils.testing_utils is deprecated and will be removed in a future version."
+    )
+    return _backend_empty_cache(device)
 
 
 def backend_device_count(device: str):
-    return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)
+    from .torch_utils import backend_device_count as _backend_device_count
+
+    logger.warning(
+        "backend_device_count has been moved to diffusers.utils.torch_utils. "
+        "diffusers.utils.testing_utils is deprecated and will be removed in a future version."
+    )
+    return _backend_device_count(device)
 
 
 def backend_reset_peak_memory_stats(device: str):
-    return _device_agnostic_dispatch(device, BACKEND_RESET_PEAK_MEMORY_STATS)
+    from .torch_utils import backend_reset_peak_memory_stats as _backend_reset_peak_memory_stats
+
+    logger.warning(
+        "backend_reset_peak_memory_stats has been moved to diffusers.utils.torch_utils. "
+        "diffusers.utils.testing_utils is deprecated and will be removed in a future version."
+    )
+    return _backend_reset_peak_memory_stats(device)
 
 
 def backend_reset_max_memory_allocated(device: str):
-    return _device_agnostic_dispatch(device, BACKEND_RESET_MAX_MEMORY_ALLOCATED)
+    from .torch_utils import backend_reset_max_memory_allocated as _backend_reset_max_memory_allocated
+
+    logger.warning(
+        "backend_reset_max_memory_allocated has been moved to diffusers.utils.torch_utils. "
+        "diffusers.utils.testing_utils is deprecated and will be removed in a future version."
+    )
+    return _backend_reset_max_memory_allocated(device)
 
 
 def backend_max_memory_allocated(device: str):
-    return _device_agnostic_dispatch(device, BACKEND_MAX_MEMORY_ALLOCATED)
+    from .torch_utils import backend_max_memory_allocated as _backend_max_memory_allocated
+
+    logger.warning(
+        "backend_max_memory_allocated has been moved to diffusers.utils.torch_utils. "
+        "diffusers.utils.testing_utils is deprecated and will be removed in a future version."
+    )
+    return _backend_max_memory_allocated(device)
 
 
 # These are callables which return boolean behaviour flags and can be used to specify some
 # device agnostic alternative where the feature is unsupported.
 def backend_supports_training(device: str):
-    if not is_torch_available():
-        return False
-
-    if device not in BACKEND_SUPPORTS_TRAINING:
-        device = "default"
+    from .torch_utils import backend_supports_training as _backend_supports_training
 
-    return BACKEND_SUPPORTS_TRAINING[device]
+    logger.warning(
+        "backend_supports_training has been moved to diffusers.utils.torch_utils. "
+        "diffusers.utils.testing_utils is deprecated and will be removed in a future version."
+    )
+    return _backend_supports_training(device)
 
 
 # Guard for when Torch is not available
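
Every helper kept in this module now follows the same shim shape: lazily import the relocated implementation from torch_utils, emit a warning at call time, and forward the arguments. A generic restatement of that pattern, written only to make the shape explicit (the _moved helper is ours, not part of the commit, and it warns via the warnings module where the commit uses the library logger):

# Generic restatement of the shim pattern above; `_moved` is our own name.
import functools
import warnings


def _moved(new_func, old_module="diffusers.utils.testing_utils", new_module="diffusers.utils.torch_utils"):
    @functools.wraps(new_func)
    def wrapper(*args, **kwargs):
        warnings.warn(
            f"{new_func.__name__} has been moved to {new_module}. "
            f"Importing it from {old_module} is deprecated and will be removed in a future version.",
            FutureWarning,
            stacklevel=2,
        )
        return new_func(*args, **kwargs)

    return wrapper


# Example: rebuild one deprecated alias from the relocated implementation.
# from diffusers.utils.torch_utils import backend_empty_cache as _backend_empty_cache
# backend_empty_cache = _moved(_backend_empty_cache)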

src/diffusers/utils/torch_utils.py

Lines changed: 136 additions & 1 deletion
@@ -16,7 +16,8 @@
 """
 
 import functools
-from typing import List, Optional, Tuple, Union
+import os
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 from . import logging
 from .import_utils import is_torch_available, is_torch_npu_available, is_torch_version

@@ -26,6 +27,56 @@
     import torch
     from torch.fft import fftn, fftshift, ifftn, ifftshift
 
+    BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "default": True}
+    BACKEND_EMPTY_CACHE = {
+        "cuda": torch.cuda.empty_cache,
+        "xpu": torch.xpu.empty_cache,
+        "cpu": None,
+        "mps": torch.mps.empty_cache,
+        "default": None,
+    }
+    BACKEND_DEVICE_COUNT = {
+        "cuda": torch.cuda.device_count,
+        "xpu": torch.xpu.device_count,
+        "cpu": lambda: 0,
+        "mps": lambda: 0,
+        "default": 0,
+    }
+    BACKEND_MANUAL_SEED = {
+        "cuda": torch.cuda.manual_seed,
+        "xpu": torch.xpu.manual_seed,
+        "cpu": torch.manual_seed,
+        "mps": torch.mps.manual_seed,
+        "default": torch.manual_seed,
+    }
+    BACKEND_RESET_PEAK_MEMORY_STATS = {
+        "cuda": torch.cuda.reset_peak_memory_stats,
+        "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None),
+        "cpu": None,
+        "mps": None,
+        "default": None,
+    }
+    BACKEND_RESET_MAX_MEMORY_ALLOCATED = {
+        "cuda": torch.cuda.reset_max_memory_allocated,
+        "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None),
+        "cpu": None,
+        "mps": None,
+        "default": None,
+    }
+    BACKEND_MAX_MEMORY_ALLOCATED = {
+        "cuda": torch.cuda.max_memory_allocated,
+        "xpu": getattr(torch.xpu, "max_memory_allocated", None),
+        "cpu": 0,
+        "mps": 0,
+        "default": 0,
+    }
+    BACKEND_SYNCHRONIZE = {
+        "cuda": torch.cuda.synchronize,
+        "xpu": getattr(torch.xpu, "synchronize", None),
+        "cpu": None,
+        "mps": None,
+        "default": None,
+    }
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 try:

@@ -36,6 +87,62 @@ def maybe_allow_in_graph(cls):
         return cls
 
 
+# This dispatches a defined function according to the accelerator from the function definitions.
+def _device_agnostic_dispatch(device: str, dispatch_table: Dict[str, Callable], *args, **kwargs):
+    if device not in dispatch_table:
+        return dispatch_table["default"](*args, **kwargs)
+
+    fn = dispatch_table[device]
+
+    # Some device agnostic functions return values. Need to guard against 'None' instead at
+    # user level
+    if not callable(fn):
+        return fn
+
+    return fn(*args, **kwargs)
+
+
+# These are callables which automatically dispatch the function specific to the accelerator
+def backend_manual_seed(device: str, seed: int):
+    return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed)
+
+
+def backend_synchronize(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_SYNCHRONIZE)
+
+
+def backend_empty_cache(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE)
+
+
+def backend_device_count(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)
+
+
+def backend_reset_peak_memory_stats(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_RESET_PEAK_MEMORY_STATS)
+
+
+def backend_reset_max_memory_allocated(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_RESET_MAX_MEMORY_ALLOCATED)
+
+
+def backend_max_memory_allocated(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_MAX_MEMORY_ALLOCATED)
+
+
+# These are callables which return boolean behaviour flags and can be used to specify some
+# device agnostic alternative where the feature is unsupported.
+def backend_supports_training(device: str):
+    if not is_torch_available():
+        return False
+
+    if device not in BACKEND_SUPPORTS_TRAINING:
+        device = "default"
+
+    return BACKEND_SUPPORTS_TRAINING[device]
+
+
 def randn_tensor(
     shape: Union[Tuple, List],
     generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,

@@ -197,3 +304,31 @@ def device_synchronize(device_type: Optional[str] = None):
         device_type = get_device()
     device_mod = getattr(torch, device_type, torch.cuda)
     device_mod.synchronize()
+
+
+def enable_full_determinism():
+    """
+    Helper function for reproducible behavior during distributed training. See
+    - https://pytorch.org/docs/stable/notes/randomness.html for pytorch
+    """
+    # Enable PyTorch deterministic mode. This potentially requires either the environment
+    # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
+    # depending on the CUDA version, so we set them both here
+    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+    torch.use_deterministic_algorithms(True)
+
+    # Enable CUDNN deterministic mode
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+    torch.backends.cuda.matmul.allow_tf32 = False
+
+
+def disable_full_determinism():
+    os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
+    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ""
+    torch.use_deterministic_algorithms(False)
+
+
+if is_torch_available():
+    torch_device = get_device()
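
With the dispatch tables and determinism helpers now in torch_utils, they behave like ordinary library utilities rather than test-only code. A short usage sketch, assuming a diffusers install that includes this commit (get_device is referenced by device_synchronize in this file, so we assume it is importable from the same module):

# Usage sketch for the relocated helpers (assumes this commit's diffusers is installed).
from diffusers.utils.torch_utils import (
    backend_device_count,
    backend_empty_cache,
    backend_manual_seed,
    backend_supports_training,
    enable_full_determinism,
    get_device,
)

device = get_device()                     # e.g. "cuda", "mps", or "cpu"
backend_manual_seed(device, 0)            # dispatches to torch.<backend>.manual_seed
print(backend_device_count(device))       # 0 on cpu/mps, torch.cuda.device_count() on cuda
print(backend_supports_training(device))  # False only for "mps" in the table above

enable_full_determinism()                 # sets CUBLAS_WORKSPACE_CONFIG and deterministic algorithms
backend_empty_cache(device)               # no-op on "cpu", empty_cache on cuda/xpu/mps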
