Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit f0df092

Browse files
committed
build extensions in parallel
1 parent b216eee commit f0df092

2 files changed

Lines changed: 36 additions & 2 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ CUDA and C++ extensions via
130130
git clone https://github.com/NVIDIA/apex
131131
cd apex
132132
# if pip >= 23.1 (ref: https://pip.pypa.io/en/stable/news/#v23-1) which supports multiple `--config-settings` with the same key...
133-
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
133+
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext --cuda_ext --parallel 4" ./
134134
# otherwise
135135
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --global-option="--cpp_ext" --global-option="--cuda_ext" ./
136136
```

setup.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import sys
22
import warnings
33
import os
4+
import threading
45
import glob
56
from packaging.version import parse, Version
67

@@ -859,6 +860,39 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
859860
)
860861

861862

863+
# Patch because `setup.py bdist_wheel` does not accept the `parallel` option
864+
parallel = None
865+
if "--parallel" in sys.argv:
866+
idx = sys.argv.index("--parallel")
867+
parallel = int(sys.argv[idx + 1])
868+
sys.argv.pop(idx + 1)
869+
sys.argv.pop(idx)
870+
871+
872+
# Prevent file conflicts when multiple extensions are compiled simultaneously
873+
class BuildExtensionSeparateDir(BuildExtension):
874+
build_extension_patch_lock = threading.Lock()
875+
thread_ext_name_map = {}
876+
877+
def build_extension(self, ext):
878+
with self.build_extension_patch_lock:
879+
if not getattr(self.compiler, "_compile_separate_output_dir", False):
880+
compile_orig = self.compiler.compile
881+
882+
def compile_new(*args, **kwargs):
883+
return compile_orig(*args, **{
884+
**kwargs,
885+
"output_dir": os.path.join(
886+
kwargs["output_dir"],
887+
self.thread_ext_name_map[threading.current_thread().ident]),
888+
})
889+
self.compiler.compile = compile_new
890+
self.compiler._compile_separate_output_dir = True
891+
self.thread_ext_name_map[threading.current_thread().ident] = ext.name
892+
objects = super().build_extension(ext)
893+
return objects
894+
895+
862896
setup(
863897
name="apex",
864898
version="0.1",
@@ -868,6 +902,6 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
868902
install_requires=["packaging>20.6"],
869903
description="PyTorch Extensions written by NVIDIA",
870904
ext_modules=ext_modules,
871-
cmdclass={"build_ext": BuildExtension} if ext_modules else {},
905+
cmdclass={"build_ext": BuildExtensionSeparateDir.with_options(parallel=parallel)} if ext_modules else {},
872906
extras_require=extras,
873907
)

0 commit comments

Comments
 (0)