forked from PaddlePaddle/PaddleFleet
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsetup.py
More file actions
122 lines (105 loc) · 3.68 KB
/
setup.py
File metadata and controls
122 lines (105 loc) · 3.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
import os
import subprocess
def get_version_from_txt():
    """Return the base version string stored in ``version.txt``.

    The file is looked up in the directory named by
    ``os.path.dirname(__file__)``.

    Returns:
        str: Contents of ``version.txt`` with leading and trailing
        whitespace removed.

    Raises:
        OSError: If ``version.txt`` cannot be opened.
    """
    version_file = os.path.join(os.path.dirname(__file__), "version.txt")
    # Pin the encoding so the result does not vary with the host's locale.
    with open(version_file, "r", encoding="utf-8") as f:
        return f.read().strip()
def custom_version_scheme(version):
    """Build a development version string stamped with the last commit date.

    Args:
        version: Value passed in by the caller; it is not consulted.

    Returns:
        str: ``"<base>.dev<YYYYMMDD>"`` where ``<base>`` is the result of
        ``get_version_from_txt()`` and the date is what ``git log -1``
        prints for the most recent commit in the current working directory.

    Raises:
        subprocess.CalledProcessError: If the ``git`` command exits with a
            non-zero status.
    """
    release = get_version_from_txt()

    # Ask git for the date of the latest commit, already formatted as YYYYMMDD.
    git_cmd = ["git", "log", "-1", "--format=%cd", "--date=format:%Y%m%d"]
    raw_output = subprocess.check_output(git_cmd)
    commit_date = raw_output.decode().strip()

    return f"{release}.dev{commit_date}"
def no_local_scheme(version):
    """Return an empty local-version segment.

    Args:
        version: Ignored.

    Returns:
        str: Always the empty string.
    """
    del version  # intentionally unused
    return ""
def change_pwd():
    """Switch the working directory to the one that holds this file.

    Nothing happens when ``os.path.dirname(__file__)`` is empty, i.e. when
    ``__file__`` carries no directory component.
    """
    script_dir = os.path.dirname(__file__)
    if not script_dir:
        return
    os.chdir(script_dir)
def setup_ops_extension():
    """Configure and build the ``paddlefleet._extensions.ops`` CUDA extension.

    Steps performed:
      1. Assemble the NVCC flags and drop the ``compute_100`` target when the
         detected CUDA toolkit is 12.x with a minor version below 8.
      2. Change into this file's directory so that every relative path used
         below resolves against the project tree.
      3. Describe the extension sources and hand them to Paddle's ``setup``,
         with the version derived from ``custom_version_scheme`` and
         ``no_local_scheme``.

    Raises:
        ValueError: If the detected CUDA major version is lower than 12.
    """
    # Imported lazily: both modules are only needed when actually building.
    from paddle.utils.cpp_extension import CUDAExtension, setup

    from build_utils import get_cuda_version

    # Define the NVCC compile arguments.
    nvcc_args = [
        "-O3",
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT16_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT162_OPERATORS__",
        "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "-maxrregcount=32",
        "-lineinfo",
        "-DCUTLASS_DEBUG_TRACE_LEVEL=0",
        "-gencode=arch=compute_80,code=sm_80",
        "-gencode=arch=compute_90a,code=sm_90a",
        "-gencode=arch=compute_100,code=sm_100",
        "-DNDEBUG",
    ]

    cuda_major, cuda_minor = get_cuda_version()
    if cuda_major < 12:
        raise ValueError(
            f"CUDA version must be >= 12. Detected version: {cuda_major}.{cuda_minor}"
        )
    if cuda_major == 12 and cuda_minor < 8:
        # The compute_100 / sm_100 target is only kept for CUDA 12.8 and newer.
        nvcc_args = [arg for arg in nvcc_args if "compute_100" not in arg]

    # Enter the project directory BEFORE deriving any path from os.getcwd().
    # Previously this happened after the include directory had been computed,
    # so invoking the script from another directory produced a wrong
    # include path while the relative source paths pointed elsewhere.
    change_pwd()

    ext_module = CUDAExtension(
        sources=[
            # cuda files
            "./src/paddlefleet/_extensions/tokens_stable_unzip.cu",
            "./src/paddlefleet/_extensions/tokens_unzip_gather.cu",
            "./src/paddlefleet/_extensions/tokens_zip_unique_add.cu",
            "./src/paddlefleet/_extensions/tokens_zip_prob.cu",
            "./src/paddlefleet/_extensions/merge_subbatch_cast.cu",
            "./src/paddlefleet/_extensions/tokens_unzip_slice.cu",
            "./src/paddlefleet/_extensions/fuse_swiglu_scale.cu",
            "./src/paddlefleet/_extensions/swiglu_kernel.cu",
        ],
        include_dirs=[
            os.path.join(os.getcwd(), "src/paddlefleet/_extensions"),
        ],
        extra_compile_args={
            "cxx": [
                "-O3",
                "-w",
                "-Wno-abi",
                "-fPIC",
                "-std=c++17",
            ],
            "nvcc": nvcc_args,
        },
    )

    setup(
        name="paddlefleet._extensions.ops",
        ext_modules=[ext_module],
        use_scm_version={
            "version_scheme": custom_version_scheme,
            "local_scheme": no_local_scheme,
        },
        setup_requires=["setuptools_scm"],
    )
# Script entry point: configure and build the extension when this file runs.
setup_ops_extension()