Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 6a3a14b

Browse files
authored
[GLM/GPT3] Fix bugs in modeling of GLM/GPT3 (PaddlePaddle#5607)
* fix glm modeling * update * update * update
1 parent ef1d1f9 commit 6a3a14b

3 files changed

Lines changed: 30 additions & 8 deletions

File tree

โ€Žexamples/language_model/gpt-3/dygraph/modeling.pyโ€Ž

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,7 +1113,7 @@ def __init__(
11131113
hidden_dropout_prob=hidden_dropout_prob,
11141114
max_position_embeddings=max_position_embeddings,
11151115
type_vocab_size=type_vocab_size,
1116-
initializer_range=0.02,
1116+
initializer_range=initializer_range,
11171117
)
11181118
)
11191119

@@ -1151,7 +1151,7 @@ def _logits_helper(embedding, output):
11511151
hidden_dropout_prob=hidden_dropout_prob,
11521152
max_position_embeddings=max_position_embeddings,
11531153
type_vocab_size=type_vocab_size,
1154-
initializer_range=0.02,
1154+
initializer_range=initializer_range,
11551155
)
11561156
)
11571157

โ€Žmodel_zoo/gpt-3/ppfleetx/models/language_model/gpt/dygraph/hybrid_model.pyโ€Ž

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,7 +1065,7 @@ def __init__(
10651065
hidden_dropout_prob=hidden_dropout_prob,
10661066
max_position_embeddings=max_position_embeddings,
10671067
type_vocab_size=type_vocab_size,
1068-
initializer_range=0.02,
1068+
initializer_range=initializer_range,
10691069
sequence_parallel=sequence_parallel,
10701070
)
10711071
)
@@ -1118,7 +1118,7 @@ def _logits_helper(embedding, output):
11181118
hidden_dropout_prob=hidden_dropout_prob,
11191119
max_position_embeddings=max_position_embeddings,
11201120
type_vocab_size=type_vocab_size,
1121-
initializer_range=0.02,
1121+
initializer_range=initializer_range,
11221122
)
11231123
)
11241124

โ€Žpaddlenlp/transformers/glm/modeling.pyโ€Ž

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import paddle.nn.functional as F
2424
from paddle import Tensor
2525
from paddle.distributed import fleet
26+
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
2627
from paddle.distributed.fleet.utils import recompute
2728

2829
from ...utils.converter import StateDictNameMapping, init_name_mappings
@@ -183,7 +184,12 @@ def forward(self, hidden_states: Tensor, ltor_mask: Tensor, cache: Tensor = None
183184

184185
attention_scores = attention_scores + (-65504.0) * (1.0 - ltor_mask)
185186
attention_probs = F.softmax(attention_scores, axis=-1)
186-
attention_probs = self.attention_dropout(attention_probs)
187+
188+
if "local_seed" in get_rng_state_tracker().states_:
189+
with get_rng_state_tracker().rng_state("local_seed"):
190+
attention_probs = self.attention_dropout(attention_probs)
191+
else:
192+
attention_probs = self.attention_dropout(attention_probs)
187193

188194
# [bs, num_head, seq_len, seq_len(+cache_len)] * [bs, num_head, seq_len(+cache_len), head_dim]
189195
# [bs, num_head, seq_len, head_dim]
@@ -194,7 +200,12 @@ def forward(self, hidden_states: Tensor, ltor_mask: Tensor, cache: Tensor = None
194200
new_context_shape = context_layer.shape[:-2] + [self.num_attention_heads * self.attention_head_size]
195201
context_layer = context_layer.reshape(new_context_shape)
196202
output = self.dense(context_layer)
197-
output = self.output_dropout(output)
203+
204+
if "global_seed" in get_rng_state_tracker().states_:
205+
with get_rng_state_tracker().rng_state("global_seed"):
206+
output = self.output_dropout(output)
207+
else:
208+
output = self.output_dropout(output)
198209

199210
return output
200211

@@ -257,7 +268,13 @@ def forward(self, hidden_states):
257268

258269
# [batch_size, sequence_length, h]
259270
output = self.dense_4h_to_h(intermediate_parallel)
260-
output = self.dropout(output)
271+
272+
if "global_seed" in get_rng_state_tracker().states_:
273+
with get_rng_state_tracker().rng_state("global_seed"):
274+
output = self.dropout(output)
275+
else:
276+
output = self.dropout(output)
277+
261278
return output
262279

263280

@@ -359,7 +376,12 @@ def build_mask_matrix(seq_length, sep, memory_length=0):
359376
if self.block_position_encoding:
360377
block_position_embeddings = self.block_position_embeddings(block_position_ids)
361378
hidden_states = hidden_states + block_position_embeddings
362-
hidden_states = self.embedding_dropout(hidden_states)
379+
380+
if "local_seed" in get_rng_state_tracker().states_:
381+
with get_rng_state_tracker().rng_state("local_seed"):
382+
hidden_states = self.embedding_dropout(hidden_states)
383+
else:
384+
hidden_states = self.embedding_dropout(hidden_states)
363385

364386
all_hidden_states = [hidden_states.detach()]
365387
for i, layer in enumerate(self.layers):

0 commit comments

Comments
ย (0)