1616
1717import os
1818import subprocess
19+ import threading
1920from packaging .version import parse , Version
20- from typing import List , Set
2121import warnings
2222
2323from setuptools import setup , find_packages
@@ -179,6 +179,44 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
179179)
180180ext_modules .append (fused_extension )
181181
182+
def parse_ext_parallel(raw):
    """Convert the raw ``EXT_PARALLEL`` setting into a job count.

    Args:
        raw: The environment variable's value, or ``None`` when it is unset.

    Returns:
        The value parsed with ``int()``, or ``None`` when ``raw`` is ``None``,
        blank, or not an integer literal.
    """
    if raw is None:
        return None
    text = raw.strip()
    if not text:
        # An empty assignment (``EXT_PARALLEL=``) is treated as "not set".
        return None
    try:
        return int(text)
    except ValueError:
        # The previous ``try``/``finally: pass`` caught nothing, so a typo in
        # the variable aborted the whole setup run with a raw ValueError.
        # Warn and fall back to "no override" instead.
        warnings.warn(
            f"Ignoring EXT_PARALLEL={raw!r}: expected an integer",
            stacklevel=2,
        )
        return None


# Optional override, read from the EXT_PARALLEL environment variable, for the
# ``parallel`` option of the build_ext command; None means "no override".
parallel = parse_ext_parallel(os.environ.get('EXT_PARALLEL'))
189+
190+
191+ # Prevent file conflicts when multiple extensions are compiled simultaneously
class BuildExtensionSeparateDir(BuildExtension):
    """``build_ext`` command that compiles each extension into its own sub-directory.

    The compiler's ``compile`` method is wrapped once so that the
    ``output_dir`` it receives is extended with the name of the extension
    being built on the calling thread. This keeps the outputs of extensions
    that are compiled at the same time from landing in the same directory.
    """

    # Guards the one-time wrapping of ``self.compiler.compile``.
    build_extension_patch_lock = threading.Lock()
    # Maps a thread ident to the name of the extension that thread is building.
    thread_ext_name_map = {}

    def finalize_options(self):
        """Apply the module-level ``parallel`` override, then finalize as usual."""
        if parallel is not None:
            self.parallel = parallel
        super().finalize_options()

    def build_extension(self, ext):
        """Build ``ext`` with its compile output redirected to a per-extension directory.

        Returns whatever the parent class's ``build_extension`` returns.
        """
        with self.build_extension_patch_lock:
            already_wrapped = getattr(
                self.compiler, "_compile_separate_output_dir", False)
            if not already_wrapped:
                wrapped_compile = self.compiler.compile

                def compile_in_ext_dir(*args, **kwargs):
                    # ``output_dir`` must arrive as a keyword argument; the
                    # extension name is looked up for the calling thread.
                    base_dir = kwargs["output_dir"]
                    ext_name = self.thread_ext_name_map[
                        threading.current_thread().ident]
                    redirected = dict(kwargs)
                    redirected["output_dir"] = os.path.join(base_dir, ext_name)
                    return wrapped_compile(*args, **redirected)

                self.compiler.compile = compile_in_ext_dir
                # Marker attribute so the compiler is only wrapped once.
                self.compiler._compile_separate_output_dir = True

        # Record which extension this thread is building before delegating, so
        # the wrapped compile can find it.
        self.thread_ext_name_map[threading.current_thread().ident] = ext.name
        return super().build_extension(ext)
218+
219+
182220setup (
183221 name = 'sageattention' ,
184222 version = '2.2.0' ,
@@ -191,5 +229,5 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
191229 packages = find_packages (),
192230 python_requires = '>=3.9' ,
193231 ext_modules = ext_modules ,
194- cmdclass = {"build_ext" : BuildExtension },
232+ cmdclass = {"build_ext" : BuildExtensionSeparateDir } if ext_modules else { },
195233)
0 commit comments