@@ -512,31 +512,36 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
512512 )
513513
514514 # CUDA group norm V2 is tested on SM100
515- if bare_metal_version >= Version ("12.8" ):
516- arch_flags = ["-gencode=arch=compute_100,code=sm_100" ]
517- else :
518- arch_flags = ["-gencode=arch=compute_90,code=compute_90" ]
515+ if bare_metal_version >= Version ("12.4" ):
516+ if bare_metal_version >= Version ("12.8" ):
517+ arch_flags = [
518+ "-gencode=arch=compute_90,code=sm_90" ,
519+ "-gencode=arch=compute_100,code=sm_100" ,
520+ "-gencode=arch=compute_120,code=compute_120" ,
521+ ]
522+ else :
523+ arch_flags = ["-gencode=arch=compute_90,code=compute_90" ]
519524
520- ext_modules .append (
521- CUDAExtension (
522- name = "group_norm_v2_cuda" ,
523- sources = [
524- "apex/contrib/csrc/group_norm_v2/gn.cpp" ,
525- "apex/contrib/csrc/group_norm_v2/gn_cuda.cu" ,
526- "apex/contrib/csrc/group_norm_v2/gn_utils.cpp" ,
527- ] + glob .glob ("apex/contrib/csrc/group_norm_v2/gn_cuda_inst_*.cu" ),
528- extra_compile_args = {
529- "cxx" : ["-O2" ],
530- "nvcc" : [
531- "-O2" , "--use_fast_math" , "--ftz=false" ,
532- "-U__CUDA_NO_HALF_CONVERSIONS__" ,
533- "-U__CUDA_NO_HALF_OPERATORS__" ,
534- "-U__CUDA_NO_BFLOAT16_CONVERSIONS__" ,
535- "-U__CUDA_NO_BFLOAT16_OPERATORS__" ,
536- ] + arch_flags ,
537- },
525+ ext_modules .append (
526+ CUDAExtension (
527+ name = "group_norm_v2_cuda" ,
528+ sources = [
529+ "apex/contrib/csrc/group_norm_v2/gn.cpp" ,
530+ "apex/contrib/csrc/group_norm_v2/gn_cuda.cu" ,
531+ "apex/contrib/csrc/group_norm_v2/gn_utils.cpp" ,
532+ ] + glob .glob ("apex/contrib/csrc/group_norm_v2/gn_cuda_inst_*.cu" ),
533+ extra_compile_args = {
534+ "cxx" : ["-O2" ],
535+ "nvcc" : [
536+ "-O2" , "--use_fast_math" , "--ftz=false" ,
537+ "-U__CUDA_NO_HALF_CONVERSIONS__" ,
538+ "-U__CUDA_NO_HALF_OPERATORS__" ,
539+ "-U__CUDA_NO_BFLOAT16_CONVERSIONS__" ,
540+ "-U__CUDA_NO_BFLOAT16_OPERATORS__" ,
541+ ] + arch_flags ,
542+ },
543+ )
538544 )
539- )
540545
541546if has_flag ("--index_mul_2d" , "APEX_INDEX_MUL_2D" ):
542547 if "--index_mul_2d" in sys .argv :
0 commit comments