-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsetup.py
More file actions
55 lines (49 loc) · 1.46 KB
/
setup.py
File metadata and controls
55 lines (49 loc) · 1.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import os
import subprocess

from pybind11.setup_helpers import Pybind11Extension, build_ext
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
def cuda_arch_from_compute_caps(text, default="sm_80"):
    """Convert ``nvidia-smi --query-gpu=compute_cap`` output into an nvcc arch.

    ``text`` is expected to hold one ``major.minor`` compute capability per
    line, one line per GPU (e.g. ``"8.6\\n8.0\\n"``). Lines that do not parse
    as ``major.minor`` integers (such as ``N/A``) are skipped. The lowest
    capability wins, compared numerically so that ``10.0`` ranks above
    ``8.9``. nvcc's ``-arch=sm_XY`` shorthand also embeds PTX, which the
    driver can JIT-compile for newer GPUs, so targeting the lowest capability
    keeps the build usable on every visible device.

    Args:
        text: Raw stdout of the ``nvidia-smi`` query.
        default: Value returned when no valid capability is found.

    Returns:
        An arch string such as ``"sm_80"``.
    """
    caps = []
    for line in text.splitlines():
        major, sep, minor = line.strip().partition(".")
        if sep and major.isdecimal() and minor.isdecimal():
            caps.append((int(major), int(minor)))
    if not caps:
        return default
    major, minor = min(caps)
    return f"sm_{major}{minor}"


def get_cuda_arch(default="sm_80"):
    """Detect the ``-arch`` value to pass to nvcc for this machine's GPUs.

    Runs ``nvidia-smi`` to query each GPU's compute capability and returns
    the lowest one as ``sm_XY``. Falls back to ``default`` in these cases:
    ``nvidia-smi`` cannot be started, it exits with an error, it times out,
    or it prints nothing parseable (e.g. on a build machine without a GPU).

    Args:
        default: Arch string used when detection fails.

    Returns:
        An arch string such as ``"sm_80"``.
    """
    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader"],
            capture_output=True,
            text=True,
            timeout=30,  # don't hang the build on an unresponsive driver
            check=False,
        )
    except (OSError, subprocess.SubprocessError):
        # OSError: nvidia-smi missing or not executable.
        # SubprocessError: e.g. TimeoutExpired.
        return default
    if result.returncode != 0:
        # Output of a failed query is not a capability list; don't parse it.
        return default
    return cuda_arch_from_compute_caps(result.stdout, default)
# nvcc target architecture, detected from the local GPUs (or a fallback).
cuda_arch = get_cuda_arch()

# The extension mixes host C++ sources with CUDA (.cu) kernels, and its
# compile flags are split per compiler: "cxx" for the host compiler, "nvcc"
# for CUDA. setuptools (and pybind11's Pybind11Extension/build_ext on top of
# it) can neither compile .cu files nor accept a per-compiler flag dict, so
# the extension is built with PyTorch's CUDAExtension + BuildExtension, which
# support both. torch must therefore be importable while this script runs
# (e.g. `pip install --no-build-isolation .`, or list torch in the
# pyproject build requirements).
ext_modules = [
    CUDAExtension(
        name="sloth._sloth",
        sources=[
            "common/cpp/bindings.cpp",
            "common/cpp/utils/timer.cpp",
            "common/cpp/utils/checker.cpp",
            "operators/01_elementwise/add/cute/add_kernel.cu",
            "operators/01_elementwise/add/cutlass/add_kernel.cu",
        ],
        include_dirs=[
            "common/cpp/include",
            "operators",
            # CUTLASS/CuTe headers; override the checkout with CUTLASS_PATH.
            os.path.join(os.environ.get("CUTLASS_PATH", "third_party/cutlass"), "include"),
        ],
        extra_compile_args={
            "cxx": ["-O3", "-std=c++17"],
            # An explicit -arch flag makes BuildExtension skip its own arch
            # detection, so the value from get_cuda_arch() is what gets built.
            "nvcc": ["-O3", "-std=c++17", f"-arch={cuda_arch}", "--expt-relaxed-constexpr"],
        },
        libraries=["cudart"],
    ),
]

setup(
    name="sloth",
    version="0.1.0",
    ext_modules=ext_modules,
    # BuildExtension routes .cu sources to nvcc and picks the matching
    # "cxx"/"nvcc" entry of extra_compile_args for each source.
    cmdclass={"build_ext": BuildExtension},
    packages=["sloth"],
    package_dir={"": "common/python"},
    install_requires=[
        "torch>=2.0.0",
        "numpy>=1.24.0",
    ],
    zip_safe=False,
)