Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 58db4d0

Browse files
authored
GPT. use ClipGradByGlobalNorm, update unittest (PaddlePaddle#693)
1 parent 8865c32 commit 58db4d0

4 files changed

Lines changed: 56 additions & 19 deletions

File tree

examples/language_model/gpt/args.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -244,7 +244,12 @@ def parse_args(MODEL_CLASSES):
244244
default="gpu",
245245
choices=["cpu", "gpu", "xpu"],
246246
help="select cpu, gpu, xpu devices.")
247-
247+
parser.add_argument(
248+
"--lr_decay_style",
249+
type=str,
250+
default="cosine",
251+
choices=["cosine", "none"],
252+
help="Learning rate decay style.")
248253
args = parser.parse_args()
249254
args.test_iters = args.eval_iters * 10
250255

examples/language_model/gpt/run_pretrain.py

Lines changed: 15 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -125,15 +125,21 @@ def do_train(args):
125125
if args.decay_steps is None:
126126
args.decay_steps = args.max_steps
127127
warmup_step = args.warmup_rate * args.decay_steps
128-
lr_scheduler = lr.CosineAnnealingWithWarmupDecay(
129-
max_lr=args.max_lr,
130-
min_lr=args.min_lr,
131-
warmup_step=warmup_step,
132-
decay_step=args.decay_steps)
128+
129+
lr_scheduler = None
130+
131+
if args.lr_decay_style == "none":
132+
lr_scheduler = None
133+
elif args.lr_decay_style == "cosine":
134+
lr_scheduler = lr.CosineAnnealingWithWarmupDecay(
135+
max_lr=args.max_lr,
136+
min_lr=args.min_lr,
137+
warmup_step=warmup_step,
138+
decay_step=args.decay_steps)
133139

134140
clip = None
135141
if args.grad_clip > 0:
136-
clip = paddle.nn.ClipGradByNorm(clip_norm=args.grad_clip)
142+
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=args.grad_clip)
137143

138144
# Generate parameter names needed to perform weight decay.
139145
# All bias and LayerNorm parameters are excluded.
@@ -142,7 +148,7 @@ def do_train(args):
142148
if not any(nd in n for nd in ["bias", "norm"])
143149
]
144150
optimizer = paddle.optimizer.AdamW(
145-
learning_rate=lr_scheduler,
151+
learning_rate=lr_scheduler if lr_scheduler is not None else args.max_lr,
146152
beta1=args.adam_beta1,
147153
beta2=args.adam_beta2,
148154
epsilon=args.adam_epsilon,
@@ -206,7 +212,8 @@ def do_train(args):
206212
tic_train = time.time()
207213
loss.backward()
208214
optimizer.step()
209-
lr_scheduler.step()
215+
if lr_scheduler is not None:
216+
lr_scheduler.step()
210217
optimizer.clear_grad()
211218

212219
if args.check_accuracy:

examples/language_model/gpt/run_pretrain_static.py

Lines changed: 18 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -86,7 +86,12 @@ def dist_optimizer(args, topo):
8686
if args.use_amp:
8787
dist_strategy.amp = True
8888
dist_strategy.amp_configs = {
89-
"custom_white_list": ['softmax', 'layer_norm', 'gelu'],
89+
"custom_white_list": [
90+
'softmax',
91+
'layer_norm',
92+
'gelu',
93+
],
94+
"custom_black_list": ['c_softmax_with_cross_entropy'],
9095
"init_loss_scaling": 32768,
9196
"use_dynamic_loss_scaling": True,
9297
}
@@ -282,8 +287,7 @@ def do_train(args):
282287

283288
clip = None
284289
if args.grad_clip > 0:
285-
# TODO @ZHUI Use nn.ClipGradByNorm
286-
clip = paddle.fluid.clip.GradientClipByNorm(
290+
clip = paddle.fluid.clip.GradientClipByGlobalNorm(
287291
clip_norm=args.grad_clip)
288292

289293
decay_param = [
@@ -292,6 +296,7 @@ def do_train(args):
292296
]
293297
# TODO @ZHUI Use paddle.optimizer.AdamW
294298
if ops.optimizer._jit_compile():
299+
logger.info("Using paddlenlp custom AdamW optimizer.")
295300
optimizer = ops.optimizer.AdamwOptimizer(
296301
learning_rate=lr_scheduler,
297302
beta1=args.adam_beta1,
@@ -305,6 +310,7 @@ def do_train(args):
305310
raise ValueError(
306311
"The paddle.optimizer.AdamW not compatible with Sharding!"
307312
)
313+
logger.info("Using paddle.optimizer.AdamW.")
308314
optimizer = paddle.optimizer.AdamW(
309315
learning_rate=lr_scheduler,
310316
beta1=args.adam_beta1,
@@ -313,6 +319,8 @@ def do_train(args):
313319
grad_clip=clip,
314320
weight_decay=args.weight_decay,
315321
apply_decay_param_fun=lambda x: x in decay_param)
322+
# alias
323+
optimizer.apply_optimize = optimizer._apply_optimize
316324

317325
if args.use_recompute:
318326
dist_strategy.recompute = True
@@ -357,20 +365,23 @@ def do_train(args):
357365
if args.mp_degree > 1:
358366
logger.warning("MP should init with dygraph params")
359367
else:
368+
logger.info("Loading parameters from %s" % static_path)
360369
paddle.static.load(main_program, static_path, exe)
361370
flag_loaded = True
362371

363-
if os.path.exists(dygrah_path):
372+
if not flag_loaded and os.path.exists(dygrah_path):
364373
if args.sharding_degree > 1:
365374
logger.warning("Sharding should init with static vars")
366375
else:
376+
logger.info("Loading parameters from %s" % dygrah_path)
367377
init_static_with_params(
368378
model,
369379
paddle.load(
370380
dygrah_path, return_numpy=True),
371381
topo,
372382
main_program)
373383
flag_loaded = True
384+
374385
if not flag_loaded:
375386
logger.error("No checkpoint load.")
376387

@@ -435,7 +446,9 @@ def do_train(args):
435446
save_persistables(exe,
436447
os.path.join(output_dir, "static_vars"),
437448
main_program)
438-
model.init_config["init_args"][0].init_config.pop("topo", None)
449+
if global_step == args.save_steps:
450+
model.init_config["init_args"][0].init_config.pop("topo",
451+
None)
439452
model.save_pretrained(output_dir)
440453
tokenizer.save_pretrained(output_dir)
441454
tic_train = time.time()

examples/language_model/gpt/tests/test_accuracy.py

Lines changed: 17 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -29,10 +29,10 @@ def check_init_checkpoint():
2929
def get_groundtruth():
3030
res = {
3131
1: {
32-
"loss": 11.043229103
32+
"loss": 11.008564949
3333
},
3434
20: {
35-
"loss": 10.904897690
35+
"loss": 10.876321793
3636
},
3737
}
3838
return res
@@ -60,6 +60,11 @@ def parse_log(path=None):
6060
return res
6161

6262

63+
def print_test_results(name):
64+
print("\n" * 5)
65+
print("---- This is test reports for %s task: ----" % name)
66+
67+
6368
class GPTAccuarcy(unittest.TestCase):
6469
"""
6570
Train accuarcy test for GPT
@@ -77,11 +82,13 @@ def test_acc_single_card(self):
7782
gt = get_groundtruth()
7883
res = parse_log("./output/gpt-%s/log/workerlog.0" %
7984
task_name.replace("_", "-"))
85+
print_test_results(task_name)
8086
for k in gt.keys():
8187
print("%s step: %d, gt:%.9f res:%.9f " %
8288
(task_name, k, gt[k]["loss"], res[k]["loss"]))
8389
self.assertAlmostEqual(
8490
gt[k]["loss"], res[k]["loss"], delta=1e-6)
91+
print("\n" * 5)
8592

8693
def test_acc_dp(self):
8794
check_dataset()
@@ -98,11 +105,13 @@ def test_acc_dp(self):
98105
res2 = parse_log("./output/gpt-%s/log/workerlog.1" %
99106
task_name.replace("_", "-"))
100107

108+
print_test_results(task_name)
101109
for k in gt.keys():
102110
mean = (res1[k]["loss"] + res2[k]["loss"]) / 2
103111
print("%s step: %d, gt:%.9f res:%.9f " %
104112
(task_name, k, gt[k]["loss"], mean))
105113
self.assertAlmostEqual(gt[k]["loss"], mean, delta=5e-6)
114+
print("\n" * 5)
106115

107116
@unittest.skipIf(not paddlenlp.ops.optimizer._jit_compile(),
108117
"The paddle.optimizer.AdamW not compatible with Sharding")
@@ -122,11 +131,13 @@ def test_acc_sharding_static(self):
122131
res2 = parse_log("./output/gpt-%s/log/workerlog.1" %
123132
task_name.replace("_", "-"))
124133

134+
print_test_results(task_name)
125135
for k in gt.keys():
126136
mean = (res1[k]["loss"] + res2[k]["loss"]) / 2
127137
print("%s step: %d, gt:%.9f res:%.9f " %
128138
(task_name, k, gt[k]["loss"], mean))
129139
self.assertAlmostEqual(gt[k]["loss"], mean, delta=5e-6)
140+
print("\n" * 5)
130141

131142
def test_acc_mp_static(self):
132143
check_dataset()
@@ -144,15 +155,16 @@ def test_acc_mp_static(self):
144155
res2 = parse_log("./output/gpt-%s/log/workerlog.1" %
145156
task_name.replace("_", "-"))
146157

158+
print_test_results(task_name)
147159
for k in gt.keys():
148160
self.assertAlmostEqual(
149161
res1[k]["loss"], res2[k]["loss"], delta=1e-7)
150162
mean = (res1[k]["loss"] + res2[k]["loss"]) / 2
151163
print("%s step: %d, gt:%.9f res:%.9f " %
152164
(task_name, k, gt[k]["loss"], mean))
153-
if k == 1:
154-
self.assertAlmostEqual(
155-
gt[k]["loss"], res1[k]["loss"], delta=1e-7)
165+
self.assertAlmostEqual(
166+
gt[k]["loss"], res1[k]["loss"], delta=1e-7)
167+
print("\n" * 5)
156168

157169

158170
if __name__ == "__main__":

0 commit comments

Comments
 (0)