@@ -88,28 +88,37 @@ def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_inp
8888 inc .left_right_halo_exchange_inplace (self .handle , self .left_rank , self .right_rank , left_output_halo , right_output_halo , left_input_halo , right_input_halo )
8989
9090class HaloExchangerPeer (HaloExchanger ):
91- def __init__ (self , ranks , rank_in_group , peer_pool , explicit_nhwc , numSM = 1 ):
91+ def __init__ (self , ranks , rank_in_group , peer_pool , explicit_nhwc , numSM = 0 ):
9292 super (HaloExchangerPeer , self ).__init__ (ranks , rank_in_group )
9393 self .diagnostics = False
9494 self .explicit_nhwc = explicit_nhwc
9595 self .numSM = numSM
9696 self .peer_pool = peer_pool
97- self .signals = peer_pool .allocate_peer_tensors ([2 ,4 ], torch .int32 , False , False )
98- self .signals [self .rank_in_group ].zero_ ()
97+
98+ def _allocate_peer_tensor (self , halo ):
99+
100+ # Compute size in bytes
101+ # Note: Pad buffer so each CUDA block gets required buffer size
102+ size = 4 * halo .numel () * halo .element_size ()
103+ size_per_block = 128 * 2 * 16 # 128 threads each require two 128b buffers
104+ size = (size + size_per_block - 1 ) // size_per_block * size_per_block
105+
106+ # Construct dtype peer buffer with desired size
107+ shape = [1 , 1 , 1 , size // halo .element_size ()]
108+ return self .peer_pool .allocate_peer_tensors (shape , halo .dtype , False , True )
99109
100110 def left_right_halo_exchange (self , left_output_halo , right_output_halo , left_input_halo = None , right_input_halo = None ):
101111 inplace = False if left_input_halo is None and right_input_halo is None else True
102112 if not inplace :
103113 left_input_halo = torch .empty_like (right_output_halo )
104114 right_input_halo = torch .empty_like (left_output_halo )
105115 channels_last = left_output_halo .is_contiguous (memory_format = torch .channels_last ) and not self .explicit_nhwc
106- left_tx = self .peer_pool . allocate_peer_tensors ( list ( left_output_halo . shape ), left_output_halo . dtype , channels_last , True )
107- right_tx = self .peer_pool . allocate_peer_tensors ( list ( right_output_halo . shape ), right_output_halo . dtype , channels_last , True )
116+ left_tx = self ._allocate_peer_tensor ( left_input_halo )
117+ right_tx = self ._allocate_peer_tensor ( right_input_halo )
108118 pm .push_pull_halos_1d (
109- self .diagnostics , self .explicit_nhwc , self .numSM ,
119+ self .diagnostics , self .explicit_nhwc , self .numSM , self . rank_in_group ,
110120 self .left_zero , left_output_halo , left_tx [self .rank_in_group ], right_tx [self .wrap_around_left_rank_in_group ], left_input_halo ,
111121 self .right_zero , right_output_halo , right_tx [self .rank_in_group ], left_tx [self .wrap_around_right_rank_in_group ], right_input_halo ,
112- self .signals [self .wrap_around_left_rank_in_group ], self .signals [self .wrap_around_right_rank_in_group ], self .signals [self .rank_in_group ]
113122 )
114123 if not inplace :
115124 return left_input_halo , right_input_halo
0 commit comments