1717"""Model and data parallel groups."""
1818from typing import Tuple , Optional
1919import warnings
20-
20+ import os
2121import torch
2222
2323from apex .transformer .log_util import get_transformer_logger
@@ -80,6 +80,77 @@ def is_unitialized():
8080 """Useful for code segments that may be accessed with or without mpu initialization"""
8181 return _DATA_PARALLEL_GROUP is None
8282
def set_nccl_socket_envs():
    """Select the NCCL socket transport by setting ``NCCL_NET=Socket``.

    Raises:
        RuntimeError: if ``NCCL_SOCKET_IFNAME`` is not present in the
            environment; in that case ``NCCL_NET`` is left untouched.
    """
    # The interface name must be supplied by the user before the socket
    # transport is selected, so check it first and fail loudly.
    if "NCCL_SOCKET_IFNAME" not in os.environ:
        raise RuntimeError("NCCL_SOCKET_IFNAME was not set")
    os.environ.update(NCCL_NET="Socket")
87+
def set_nccl_ib_envs():
    """Select the NCCL InfiniBand transport by setting ``NCCL_NET=IB``."""
    os.environ.update(NCCL_NET="IB")
90+
def init_nccl_net(group):
    """Run a one-element all-reduce on ``group`` and wait for the GPU to finish.

    Args:
        group: process group handle passed to ``torch.distributed.all_reduce``.

    NOTE(review): presumably this forces NCCL to set up the group's
    communicator right away, while ``NCCL_NET`` still holds the value chosen
    for this group -- confirm against NCCL's initialization behavior.
    """
    warmup = torch.ones(1, device="cuda")
    torch.distributed.all_reduce(warmup, group=group)
    # Block until the collective has actually completed on the device.
    torch.cuda.synchronize()
95+
def new_nccl_socket_group(ranks):
    """Create an NCCL process group for ``ranks`` with ``NCCL_NET=Socket``.

    The environment is switched to the socket transport first, then the group
    is created and immediately exercised through ``init_nccl_net``.

    Args:
        ranks: ranks forwarded to ``torch.distributed.new_group``.

    Returns:
        The group returned by ``torch.distributed.new_group``.

    Raises:
        RuntimeError: if ``NCCL_SOCKET_IFNAME`` is not set (raised by
            ``set_nccl_socket_envs``).
    """
    set_nccl_socket_envs()
    socket_group = torch.distributed.new_group(ranks, backend="nccl")
    init_nccl_net(group=socket_group)
    return socket_group
101+
def new_nccl_ib_group(ranks):
    """Create an NCCL process group for ``ranks`` with ``NCCL_NET=IB``.

    The environment is switched to the IB transport first, then the group is
    created and immediately exercised through ``init_nccl_net``.

    Args:
        ranks: ranks forwarded to ``torch.distributed.new_group``.

    Returns:
        The group returned by ``torch.distributed.new_group``.
    """
    set_nccl_ib_envs()
    ib_group = torch.distributed.new_group(ranks, backend="nccl")
    init_nccl_net(group=ib_group)
    return ib_group
107+
def new_process_group(ranks, backend):
    """Create a process group, picking the NCCL transport per group if configured.

    In addition to simply creating the process group, this initializes NCCL
    for a hybrid IB/Socket network like in the following diagram:

                            ____________
      [GPU Node 0]---TCP---|            |---TCP---[GPU Node 2]
         |                 |            |                 |
         |                 |            |                 |
        IB                 | IP Network |                IB
         |                 |            |                 |
         |                 |            |                 |
      [GPU Node 1]---TCP---|____________|---TCP---[GPU Node 3]

    If the environment variable NUM_GPUS_PER_IB_BLOCK is defined and the
    backend is "nccl", every rank is mapped to the block
    ``rank // NUM_GPUS_PER_IB_BLOCK``. If all ranks land in the same block the
    group is created with NCCL_NET=IB, otherwise with NCCL_NET=Socket.
    In every other case the call is forwarded unchanged to
    ``torch.distributed.new_group``.

    If NCCL_NET=Socket is ever to be used, the user must set
    NCCL_SOCKET_IFNAME. Additionally, it is recommended to set
    NCCL_SOCKET_NTHREADS and NCCL_NSOCKS_PERTHREAD before running the job.
    See: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html
    for more info.

    The core assumption for this functionality is that the ranks are evenly
    divided into IB blocks and all these IB blocks are of the same size.

    Args:
        ranks: iterable of integer ranks that form the group.
        backend: backend name; ``None`` is treated as ``"nccl"``.

    Returns:
        The process group produced by ``torch.distributed.new_group`` (directly
        or through ``new_nccl_ib_group`` / ``new_nccl_socket_group``).

    Raises:
        ValueError: if NUM_GPUS_PER_IB_BLOCK is consulted but is not a
            positive integer.
        RuntimeError: if the socket transport is selected while
            NCCL_SOCKET_IFNAME is not set.
    """
    if backend is None:
        backend = "nccl"

    raw_block_size = os.getenv("NUM_GPUS_PER_IB_BLOCK")
    # The IB/Socket split only applies to NCCL and only when the user opted in.
    if backend != "nccl" or raw_block_size is None:
        return torch.distributed.new_group(ranks, backend=backend)

    try:
        compute_block_size = int(raw_block_size)
    except ValueError as err:
        raise ValueError(
            "NUM_GPUS_PER_IB_BLOCK must be a positive integer, "
            f"got {raw_block_size!r}"
        ) from err
    # Zero would divide by zero below; a negative size yields meaningless
    # block ids. Reject both up front with an actionable message.
    if compute_block_size <= 0:
        raise ValueError(
            "NUM_GPUS_PER_IB_BLOCK must be a positive integer, "
            f"got {raw_block_size!r}"
        )

    # Distinct IB blocks touched by this group; zero or one distinct block
    # means every member can be reached over IB.
    blocks = {rank // compute_block_size for rank in ranks}
    if len(blocks) <= 1:
        return new_nccl_ib_group(ranks)
    return new_nccl_socket_group(ranks)
83154
84155def initialize_model_parallel (
85156 tensor_model_parallel_size_ : int = 1 ,
@@ -139,6 +210,9 @@ def initialize_model_parallel(
139210 if default_backend == "ucc" :
140211 warnings .warn ("The UCC's functionality as `default_backend` is not well verified" , ExperimentalWarning )
141212
213+ # Saving the NCCL_NET type for reusing it at the epilogue
214+ default_nccl_net = os .getenv ("NCCL_NET" )
215+
142216 world_size : int = torch .distributed .get_world_size ()
143217 tensor_model_parallel_size : int = min (tensor_model_parallel_size_ , world_size )
144218 pipeline_model_parallel_size : int = min (pipeline_model_parallel_size_ , world_size )
@@ -199,7 +273,7 @@ def initialize_model_parallel(
199273 for j in range (tensor_model_parallel_size ):
200274 ranks = range (start_rank + j , end_rank , tensor_model_parallel_size )
201275 all_data_parallel_group_ranks .append (list (ranks ))
202- group = torch . distributed . new_group (ranks , backend = default_backend )
276+ group = new_process_group (ranks , backend = default_backend )
203277 if rank in ranks :
204278 _DATA_PARALLEL_GROUP = group
205279
@@ -225,7 +299,7 @@ def initialize_model_parallel(
225299 data_parallel_group_ranks [i ]
226300 for data_parallel_group_ranks in all_data_parallel_group_ranks
227301 ]
228- group = torch . distributed . new_group (ranks , backend = default_backend )
302+ group = new_process_group (ranks , backend = default_backend )
229303 if rank in ranks :
230304 _MODEL_PARALLEL_GROUP = group
231305
@@ -238,7 +312,7 @@ def initialize_model_parallel(
238312 ranks = list (
239313 range (i * tensor_model_parallel_size , (i + 1 ) * tensor_model_parallel_size )
240314 )
241- group = torch . distributed . new_group (ranks , backend = default_backend )
315+ group = new_process_group (ranks , backend = default_backend )
242316 if rank in ranks :
243317 _TENSOR_MODEL_PARALLEL_GROUP = group
244318
@@ -266,7 +340,7 @@ def initialize_model_parallel(
266340 'relative position embedding group is already initialized'
267341 for i in range (num_pipeline_model_parallel_groups ):
268342 ranks = range (i , world_size , num_pipeline_model_parallel_groups )
269- group = torch . distributed . new_group (ranks , backend = p2p_backend )
343+ group = new_process_group (ranks , backend = p2p_backend )
270344 if rank in ranks :
271345 _PIPELINE_MODEL_PARALLEL_GROUP = group
272346 _PIPELINE_GLOBAL_RANKS = ranks
@@ -304,28 +378,28 @@ def initialize_model_parallel(
304378 encoder_relative_position_embedding_ranks = ranks
305379 decoder_relative_position_embedding_ranks = ranks
306380
307- group = torch . distributed . new_group (embedding_ranks , backend = default_backend )
381+ group = new_process_group (embedding_ranks , backend = p2p_backend )
308382 if rank in embedding_ranks :
309383 _EMBEDDING_GROUP = group
310384 if rank in ranks :
311385 _EMBEDDING_GLOBAL_RANKS = embedding_ranks
312386
313- group = torch . distributed . new_group (position_embedding_ranks , backend = default_backend )
387+ group = new_process_group (position_embedding_ranks , backend = p2p_backend )
314388 if rank in position_embedding_ranks :
315389 _POSITION_EMBEDDING_GROUP = group
316390 if rank in ranks :
317391 _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks
318392
319393 if encoder_relative_position_embedding_ranks :
320- group = torch . distributed . new_group (encoder_relative_position_embedding_ranks )
394+ group = new_process_group (encoder_relative_position_embedding_ranks , backend = p2p_backend )
321395 if rank in encoder_relative_position_embedding_ranks :
322396 _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP = group
323397 if rank in ranks :
324398 _ENCODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS = \
325399 encoder_relative_position_embedding_ranks
326400
327401 if decoder_relative_position_embedding_ranks :
328- group = torch . distributed . new_group (decoder_relative_position_embedding_ranks )
402+ group = new_process_group (decoder_relative_position_embedding_ranks , backend = p2p_backend )
329403 if rank in decoder_relative_position_embedding_ranks :
330404 _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP = group
331405 if rank in ranks :
@@ -335,6 +409,14 @@ def initialize_model_parallel(
335409 if init_mpi_proc_group :
336410 torch .distributed .new_group (backend = 'mpi' )
337411
412+ if default_nccl_net == "Socket" :
413+ set_nccl_socket_envs ()
414+ elif default_nccl_net == "IB" :
415+ set_nccl_ib_envs ()
416+ elif default_nccl_net is None :
417+ os .unsetenv ("NCCL_NET" )
418+ else :
419+ os .environ ["NCCL_NET" ] = default_nccl_net
338420
339421def get_rank_info () -> Tuple [int , int , int ]:
340422 """Returns a tuple of (data, tensor, pipeline, virtual pipeline)-parallel-rank for logger."""
0 commit comments