Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 0124107

Browse files
oralubenforrestl111
authored andcommitted
Build extensions in parallel (NVIDIA/apex#1882)
1 parent 963e14e commit 0124107

1 file changed

Lines changed: 40 additions & 2 deletions

File tree

setup.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616

1717
import os
1818
import subprocess
19+
import threading
1920
from packaging.version import parse, Version
20-
from typing import List, Set
2121
import warnings
2222

2323
from setuptools import setup, find_packages
@@ -179,6 +179,44 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
179179
)
180180
ext_modules.append(fused_extension)
181181

182+
183+
parallel = None
184+
if 'EXT_PARALLEL' in os.environ:
185+
try:
186+
parallel = int(os.getenv('EXT_PARALLEL'))
187+
finally:
188+
pass
189+
190+
191+
# Prevent file conflicts when multiple extensions are compiled simultaneously
192+
class BuildExtensionSeparateDir(BuildExtension):
193+
build_extension_patch_lock = threading.Lock()
194+
thread_ext_name_map = {}
195+
196+
def finalize_options(self):
197+
if parallel is not None:
198+
self.parallel = parallel
199+
super().finalize_options()
200+
201+
def build_extension(self, ext):
202+
with self.build_extension_patch_lock:
203+
if not getattr(self.compiler, "_compile_separate_output_dir", False):
204+
compile_orig = self.compiler.compile
205+
206+
def compile_new(*args, **kwargs):
207+
return compile_orig(*args, **{
208+
**kwargs,
209+
"output_dir": os.path.join(
210+
kwargs["output_dir"],
211+
self.thread_ext_name_map[threading.current_thread().ident]),
212+
})
213+
self.compiler.compile = compile_new
214+
self.compiler._compile_separate_output_dir = True
215+
self.thread_ext_name_map[threading.current_thread().ident] = ext.name
216+
objects = super().build_extension(ext)
217+
return objects
218+
219+
182220
setup(
183221
name='sageattention',
184222
version='2.2.0',
@@ -191,5 +229,5 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
191229
packages=find_packages(),
192230
python_requires='>=3.9',
193231
ext_modules=ext_modules,
194-
cmdclass={"build_ext": BuildExtension},
232+
cmdclass={"build_ext": BuildExtensionSeparateDir} if ext_modules else {},
195233
)

0 commit comments

Comments
 (0)