2323import paddle .nn .functional as F
2424from paddle import Tensor
2525from paddle .distributed import fleet
26+ from paddle .distributed .fleet .meta_parallel import get_rng_state_tracker
2627from paddle .distributed .fleet .utils import recompute
2728
2829from ...utils .converter import StateDictNameMapping , init_name_mappings
@@ -183,7 +184,12 @@ def forward(self, hidden_states: Tensor, ltor_mask: Tensor, cache: Tensor = None
183184
184185 attention_scores = attention_scores + (- 65504.0 ) * (1.0 - ltor_mask )
185186 attention_probs = F .softmax (attention_scores , axis = - 1 )
186- attention_probs = self .attention_dropout (attention_probs )
187+
188+ if "local_seed" in get_rng_state_tracker ().states_ :
189+ with get_rng_state_tracker ().rng_state ("local_seed" ):
190+ attention_probs = self .attention_dropout (attention_probs )
191+ else :
192+ attention_probs = self .attention_dropout (attention_probs )
187193
188194 # [bs, num_head, seq_len, seq_len(+cache_len)] * [bs, num_head, seq_len(+cache_len), head_dim]
189195 # [bs, num_head, seq_len, head_dim]
@@ -194,7 +200,12 @@ def forward(self, hidden_states: Tensor, ltor_mask: Tensor, cache: Tensor = None
194200 new_context_shape = context_layer .shape [:- 2 ] + [self .num_attention_heads * self .attention_head_size ]
195201 context_layer = context_layer .reshape (new_context_shape )
196202 output = self .dense (context_layer )
197- output = self .output_dropout (output )
203+
204+ if "global_seed" in get_rng_state_tracker ().states_ :
205+ with get_rng_state_tracker ().rng_state ("global_seed" ):
206+ output = self .output_dropout (output )
207+ else :
208+ output = self .output_dropout (output )
198209
199210 return output
200211
@@ -257,7 +268,13 @@ def forward(self, hidden_states):
257268
258269 # [batch_size, sequence_length, h]
259270 output = self .dense_4h_to_h (intermediate_parallel )
260- output = self .dropout (output )
271+
272+ if "global_seed" in get_rng_state_tracker ().states_ :
273+ with get_rng_state_tracker ().rng_state ("global_seed" ):
274+ output = self .dropout (output )
275+ else :
276+ output = self .dropout (output )
277+
261278 return output
262279
263280
@@ -359,7 +376,12 @@ def build_mask_matrix(seq_length, sep, memory_length=0):
359376 if self .block_position_encoding :
360377 block_position_embeddings = self .block_position_embeddings (block_position_ids )
361378 hidden_states = hidden_states + block_position_embeddings
362- hidden_states = self .embedding_dropout (hidden_states )
379+
380+ if "local_seed" in get_rng_state_tracker ().states_ :
381+ with get_rng_state_tracker ().rng_state ("local_seed" ):
382+ hidden_states = self .embedding_dropout (hidden_states )
383+ else :
384+ hidden_states = self .embedding_dropout (hidden_states )
363385
364386 all_hidden_states = [hidden_states .detach ()]
365387 for i , layer in enumerate (self .layers ):
0 commit comments