Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 5c8860f

Browse files
authored
fix compile group norm [v1/v2] (#1927)
* fix group norm
* merge upstream code
1 parent eb1e7d9 commit 5c8860f

1 file changed

Lines changed: 28 additions & 23 deletions

File tree

setup.py

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -512,31 +512,36 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
512512
)
513513

514514
# CUDA group norm V2 is tested on SM100
515-
if bare_metal_version >= Version("12.8"):
516-
arch_flags = ["-gencode=arch=compute_100,code=sm_100"]
517-
else:
518-
arch_flags = ["-gencode=arch=compute_90,code=compute_90"]
515+
if bare_metal_version >= Version("12.4"):
516+
if bare_metal_version >= Version("12.8"):
517+
arch_flags = [
518+
"-gencode=arch=compute_90,code=sm_90",
519+
"-gencode=arch=compute_100,code=sm_100",
520+
"-gencode=arch=compute_120,code=compute_120",
521+
]
522+
else:
523+
arch_flags = ["-gencode=arch=compute_90,code=compute_90"]
519524

520-
ext_modules.append(
521-
CUDAExtension(
522-
name="group_norm_v2_cuda",
523-
sources=[
524-
"apex/contrib/csrc/group_norm_v2/gn.cpp",
525-
"apex/contrib/csrc/group_norm_v2/gn_cuda.cu",
526-
"apex/contrib/csrc/group_norm_v2/gn_utils.cpp",
527-
] + glob.glob("apex/contrib/csrc/group_norm_v2/gn_cuda_inst_*.cu"),
528-
extra_compile_args={
529-
"cxx": ["-O2"],
530-
"nvcc": [
531-
"-O2", "--use_fast_math", "--ftz=false",
532-
"-U__CUDA_NO_HALF_CONVERSIONS__",
533-
"-U__CUDA_NO_HALF_OPERATORS__",
534-
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
535-
"-U__CUDA_NO_BFLOAT16_OPERATORS__",
536-
] + arch_flags,
537-
},
525+
ext_modules.append(
526+
CUDAExtension(
527+
name="group_norm_v2_cuda",
528+
sources=[
529+
"apex/contrib/csrc/group_norm_v2/gn.cpp",
530+
"apex/contrib/csrc/group_norm_v2/gn_cuda.cu",
531+
"apex/contrib/csrc/group_norm_v2/gn_utils.cpp",
532+
] + glob.glob("apex/contrib/csrc/group_norm_v2/gn_cuda_inst_*.cu"),
533+
extra_compile_args={
534+
"cxx": ["-O2"],
535+
"nvcc": [
536+
"-O2", "--use_fast_math", "--ftz=false",
537+
"-U__CUDA_NO_HALF_CONVERSIONS__",
538+
"-U__CUDA_NO_HALF_OPERATORS__",
539+
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
540+
"-U__CUDA_NO_BFLOAT16_OPERATORS__",
541+
] + arch_flags,
542+
},
543+
)
538544
)
539-
)
540545

541546
if has_flag("--index_mul_2d", "APEX_INDEX_MUL_2D"):
542547
if "--index_mul_2d" in sys.argv:

0 commit comments

Comments
 (0)