-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathsetup.py
More file actions
492 lines (395 loc) · 16 KB
/
setup.py
File metadata and controls
492 lines (395 loc) · 16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
import os
import subprocess
import sys
from setuptools import setup
# Set DISTUTILS_USE_SDK for Windows builds to avoid VC env issues
if sys.platform == "win32":
os.environ.setdefault("DISTUTILS_USE_SDK", "1")
# Fix PATH to use MSVC linker instead of Git's link.exe
# Git installs a Unix-style 'link' command that conflicts with MSVC's linker
path = os.environ.get("PATH", "")
path_parts = path.split(os.pathsep)
filtered_parts = [p for p in path_parts if "Git\\usr\\bin" not in p and "Git/usr/bin" not in p]
os.environ["PATH"] = os.pathsep.join(filtered_parts)
# =============================================================================
# Multi-CUDA Package Configuration
# =============================================================================
# FUSSIM_CUDA_VARIANT controls which CUDA extension to build:
# - "cu118", "cu121", "cu124", "cu126", "cu128": Build specific CUDA version
# - "all": Build all CUDA variants (fat wheel)
# - Not set: Auto-detect from installed PyTorch/CUDA
#
# Extension naming: fussim._cuda_cu{version} (e.g., fussim._cuda_cu126)
SUPPORTED_CUDA_VARIANTS = ["cu118", "cu121", "cu124", "cu126", "cu128", "cu130"]
def _cuda_variant_to_tuple(variant):
"""Convert a variant string like 'cu128' to a comparable tuple (12, 8)."""
nums = variant[2:] # strip "cu" prefix
if len(nums) == 3:
return (int(nums[0:2]), int(nums[2]))
elif len(nums) == 2:
return (int(nums[0]), int(nums[1]))
return (0, 0)
def _is_probable_pip_build_isolation():
"""Best-effort detection for pip's isolated build env.
PEP 517 build isolation installs build requirements in a temporary
environment that may contain a different PyTorch build than the user's
target environment. For PyTorch CUDA extensions, that can silently produce
ABI-mismatched source builds.
"""
markers = [
os.environ.get("PIP_BUILD_TRACKER", ""),
sys.prefix,
sys.executable,
*sys.path,
]
return any("pip-build-env" in str(marker).lower() for marker in markers if marker)
def _get_torch_build_info():
"""Return the torch version visible to the build env, if any."""
try:
import torch
return torch.__version__, torch.version.cuda
except Exception:
return None, None
def _raise_build_isolation_error():
"""Abort ambiguous source builds with a prescriptive message."""
torch_version, torch_cuda = _get_torch_build_info()
torch_summary = "not installed in the build env"
if torch_version:
torch_summary = f"torch {torch_version} (CUDA {torch_cuda or 'none'}) in the build env"
raise RuntimeError(
"fussim source builds are not supported under pip build isolation because "
"PyTorch CUDA extensions must compile against the target environment's "
f"Torch, but pip is using {torch_summary}.\n\n"
"Install the intended torch build in your environment first, then build "
"without isolation:\n"
" pip install torch==... --index-url https://download.pytorch.org/whl/<cuda>\n"
" pip install --no-build-isolation --no-binary fussim fussim\n\n"
"For a local checkout, use:\n"
" pip install --no-build-isolation .\n\n"
"If you intentionally prepared a matching isolated build env, set "
"FUSSIM_ALLOW_BUILD_ISOLATION=1 to override this safeguard."
)
def get_cuda_variant_from_torch():
"""Detect CUDA version from PyTorch.
Finds the best compatible variant for the runtime CUDA version:
1. Exact match if available
2. Otherwise, the highest supported variant <= the runtime version
"""
try:
import torch
cuda_version = torch.version.cuda
if cuda_version:
major, minor = cuda_version.split(".")[:2]
runtime = (int(major), int(minor))
# Try exact match first
for variant in SUPPORTED_CUDA_VARIANTS:
if _cuda_variant_to_tuple(variant) == runtime:
return variant
# Find the highest supported variant <= runtime version
# Constrain to same major version (no ABI compat across majors)
best = None
for variant in SUPPORTED_CUDA_VARIANTS:
vt = _cuda_variant_to_tuple(variant)
if vt[0] == runtime[0] and vt <= runtime:
best = variant
return best
except Exception:
pass
return None
# Check for CUDA availability before importing torch extensions
def get_cuda_version():
"""Get CUDA version from nvcc or return None if not available."""
try:
output = subprocess.check_output(["nvcc", "--version"], stderr=subprocess.STDOUT)
output = output.decode("utf-8")
# Parse version from output like "release 12.1"
import re
match = re.search(r"release (\d+\.\d+)", output)
if match:
return match.group(1)
except (subprocess.CalledProcessError, FileNotFoundError, OSError):
pass
return None
def find_cuda_home():
"""Find CUDA home directory from various sources."""
# First check environment variables
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
if cuda_home and os.path.isdir(cuda_home):
return cuda_home
# Common Windows paths - check these first as they're most reliable
if sys.platform == "win32":
for cuda_ver in ["13.0", "12.8", "12.6", "12.4", "12.1", "11.8"]:
path = f"C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v{cuda_ver}"
if os.path.isdir(path):
return path
# Common Linux paths
for path in ["/usr/local/cuda", "/usr/cuda"]:
if os.path.isdir(path):
return path
# Try to find from nvcc path (may not work in isolated build environments)
try:
import shutil
nvcc_path = shutil.which("nvcc")
if nvcc_path:
# nvcc is typically at CUDA_HOME/bin/nvcc
cuda_home = os.path.dirname(os.path.dirname(nvcc_path))
if os.path.isdir(cuda_home):
return cuda_home
except Exception:
pass
return None
def check_cuda_available():
"""Check if CUDA is available for building."""
cuda_version = get_cuda_version()
if cuda_version:
return True
cuda_home = find_cuda_home()
if cuda_home:
return True
return False
# Check if we're building an sdist (source distribution) - CUDA not required for sdist
_building_sdist = "sdist" in sys.argv or "egg_info" in sys.argv
if not _building_sdist:
if (
os.environ.get("FUSSIM_ALLOW_BUILD_ISOLATION") != "1"
and _is_probable_pip_build_isolation()
):
_raise_build_isolation_error()
# Auto-detect and set CUDA_HOME if not already set
_cuda_home = find_cuda_home()
if _cuda_home:
if not os.environ.get("CUDA_HOME"):
os.environ["CUDA_HOME"] = _cuda_home
if not os.environ.get("CUDA_PATH"):
os.environ["CUDA_PATH"] = _cuda_home
if not check_cuda_available():
raise RuntimeError(
"CUDA is required to build fussim but was not found.\n"
"Please ensure:\n"
" 1. NVIDIA CUDA Toolkit is installed\n"
" 2. nvcc is in your PATH, or\n"
" 3. CUDA_HOME environment variable is set\n"
"\n"
"Install CUDA Toolkit from: https://developer.nvidia.com/cuda-downloads"
)
import torch.utils.cpp_extension as cpp_ext # noqa: E402
from torch.utils.cpp_extension import BuildExtension, CUDAExtension # noqa: E402
# If torch didn't find CUDA_HOME but we have it, patch the module
if cpp_ext.CUDA_HOME is None and _cuda_home:
cpp_ext.CUDA_HOME = _cuda_home
else:
# For sdist, we don't need CUDA extensions
_cuda_home = None
# Platform-specific compiler flags (only needed for wheel builds)
IS_WINDOWS = sys.platform == "win32"
if not _building_sdist:
# C++ compiler flags
if IS_WINDOWS:
cxx_flags = ["/O2", "/std:c++17"]
else:
cxx_flags = ["-O3", "-std=c++17"]
def get_cuda_arch_flags():
"""
Get CUDA architecture flags for compilation.
Priority:
1. TORCH_CUDA_ARCH_LIST environment variable
2. Auto-detect from available GPUs
3. Default to common architectures
"""
# Check for user-specified architectures
arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", "").strip()
if arch_list:
# User specified architectures - parse and convert to nvcc flags
return parse_cuda_arch_list(arch_list)
# Try to auto-detect GPU architectures
detected = detect_gpu_architectures()
if detected:
return detected
# Fallback to common architectures with PTX for forward compatibility
# This covers: Turing (7.5), Ampere (8.0, 8.6), Ada Lovelace (8.9), Hopper (9.0)
return get_default_arch_flags()
def parse_cuda_arch_list(arch_list):
"""Parse TORCH_CUDA_ARCH_LIST format into nvcc flags."""
flags = []
# Handle semicolon or space separated architectures
archs = arch_list.replace(" ", ";").split(";")
archs = [a.strip() for a in archs if a.strip()]
for arch in archs:
# Handle PTX suffix
has_ptx = "+PTX" in arch.upper()
arch_clean = arch.replace("+PTX", "").replace("+ptx", "").strip()
# Convert decimal format (e.g., "8.9") to nvcc format (e.g., "89")
if "." in arch_clean:
major, minor = arch_clean.split(".")[:2]
arch_num = f"{major}{minor}"
else:
arch_num = arch_clean
# Add gencode flags for better multi-arch support
flags.append(f"-gencode=arch=compute_{arch_num},code=sm_{arch_num}")
if has_ptx:
flags.append(f"-gencode=arch=compute_{arch_num},code=compute_{arch_num}")
return flags
def detect_gpu_architectures():
"""Detect GPU architectures from available CUDA devices."""
try:
import torch
if not torch.cuda.is_available():
return None
archs = set()
for i in range(torch.cuda.device_count()):
props = torch.cuda.get_device_properties(i)
arch = f"{props.major}{props.minor}"
archs.add(arch)
if not archs:
return None
flags = []
for arch in sorted(archs):
flags.append(f"-gencode=arch=compute_{arch},code=sm_{arch}")
# Add PTX for newest architecture for forward compatibility
newest = max(archs)
flags.append(f"-gencode=arch=compute_{newest},code=compute_{newest}")
return flags
except Exception:
return None
def get_supported_archs(archs):
"""Filter architectures to only those supported by the current CUDA toolkit."""
cuda_version_num = None
try:
# Try to get CUDA version from torch
import torch
cuda_version = torch.version.cuda
if cuda_version:
major, minor = map(int, cuda_version.split(".")[:2])
cuda_version_num = major * 10 + minor
except Exception:
pass
if cuda_version_num is None:
# Try to get from nvcc
cuda_ver = get_cuda_version()
if cuda_ver:
try:
major, minor = map(int, cuda_ver.split(".")[:2])
cuda_version_num = major * 10 + minor
except ValueError:
cuda_version_num = None
if cuda_version_num is None:
# Can't determine version, return conservative subset
return ["75", "80", "86"]
# Map CUDA versions to max supported architecture
# https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html
max_arch_by_cuda = {
110: "80", # CUDA 11.0
111: "86", # CUDA 11.1+
118: "89", # CUDA 11.8
120: "90", # CUDA 12.0
128: "100", # CUDA 12.8 (Blackwell GB100)
129: "120", # CUDA 12.9+ (Blackwell GB20x - RTX 5090)
130: "120", # CUDA 13.0 (Blackwell GB20x - RTX 5090)
}
# Find max supported architecture for this CUDA version
max_arch = "80" # Conservative default
for cuda_ver, arch in sorted(max_arch_by_cuda.items()):
if cuda_version_num >= cuda_ver:
max_arch = arch
# Filter to supported architectures
return [a for a in archs if int(a) <= int(max_arch)]
def get_default_arch_flags():
"""Get default architecture flags covering common GPUs."""
# Default architectures: Turing, Ampere, Ada Lovelace, Hopper, Blackwell
# Using gencode for binary compatibility + PTX for forward compatibility
# Note: Different CUDA toolkit versions support different max architectures:
# - CUDA 11.x: up to sm_86 (Ampere)
# - CUDA 12.0+: up to sm_90 (Hopper)
# - CUDA 12.8+: up to sm_120 (Blackwell)
default_archs = ["75", "80", "86", "89", "90", "100", "120"]
# Filter architectures based on what the current CUDA toolkit supports
supported_archs = get_supported_archs(default_archs)
flags = []
for arch in supported_archs:
flags.append(f"-gencode=arch=compute_{arch},code=sm_{arch}")
# Add PTX for newest for forward compatibility with future GPUs
if supported_archs:
newest = supported_archs[-1]
flags.append(f"-gencode=arch=compute_{newest},code=compute_{newest}")
return flags
if not _building_sdist:
# NVCC compiler flags
nvcc_flags = [
"-O3",
"--maxrregcount=32",
"--use_fast_math",
"-std=c++17",
]
# Add architecture flags
nvcc_flags.extend(get_cuda_arch_flags())
# Windows-specific NVCC flags
if IS_WINDOWS:
nvcc_flags.extend(
[
"--allow-unsupported-compiler", # Allow newer MSVC versions
"-D_ALLOW_COMPILER_AND_STL_VERSION_MISMATCH", # Fix STL1002 error with old CUDA
"-Xcompiler",
"/wd4819", # Suppress Unicode warning
"-Xcompiler",
"/wd4251", # Suppress DLL interface warning
]
)
def get_cuda_variants_to_build():
"""Determine which CUDA variants to build based on environment."""
variant_env = os.environ.get("FUSSIM_CUDA_VARIANT", "").strip().lower()
if variant_env == "all":
# Fat wheel: build all variants
return SUPPORTED_CUDA_VARIANTS
elif variant_env in SUPPORTED_CUDA_VARIANTS:
# Specific variant requested
return [variant_env]
else:
# Auto-detect from PyTorch
detected = get_cuda_variant_from_torch()
if detected:
return [detected]
# Fallback: try to detect from nvcc and build single variant
cuda_ver = get_cuda_version()
if cuda_ver:
try:
major, minor = map(int, cuda_ver.split(".")[:2])
runtime = (major, minor)
except ValueError:
runtime = None
if runtime is not None:
best = None
for variant in SUPPORTED_CUDA_VARIANTS:
vt = _cuda_variant_to_tuple(variant)
if vt[0] == runtime[0] and vt <= runtime:
best = variant
if best is not None:
return [best]
# Last resort: build cu130 (most recent)
return ["cu130"]
def get_extensions():
"""Build the list of extension modules."""
if _building_sdist:
return []
variants = get_cuda_variants_to_build()
ext_modules = []
for variant in variants:
ext_modules.append(
CUDAExtension(
name=f"fussim._cuda_{variant}",
sources=["csrc/ssim.cu", "csrc/ssim_fp16.cu", "csrc/ext.cpp"],
include_dirs=["csrc"],
extra_compile_args={
"cxx": cxx_flags,
"nvcc": nvcc_flags,
},
)
)
return ext_modules
if _building_sdist:
# For sdist, no extensions or custom build commands
setup()
else:
setup(
ext_modules=get_extensions(),
cmdclass={"build_ext": BuildExtension.with_options(use_ninja=False)},
)