Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit b877f92

Browse files
author
zkh2016
committed
Merge branch 'develop' of https://github.com/zkh2016/PaddleNLP into develop
2 parents 1c4921e + 2994379 commit b877f92

3 files changed

Lines changed: 141 additions & 125 deletions

File tree

examples/experimental/faster_bert/run_glue.py

Lines changed: 119 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from paddlenlp.transformers import ErnieForSequenceClassification, ErnieTokenizer
3838
from paddlenlp.transformers import LinearDecayWithWarmup
3939
from paddlenlp.metrics import AccuracyAndF1, Mcc, PearsonAndSpearman
40+
from static.model_convert_util import convert_base_to_fused
4041

4142
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
4243
logging.basicConfig(level=logging.INFO, format=FORMAT)
@@ -246,123 +247,123 @@ def convert_example(example,
246247
else:
247248
return example['input_ids'], example['token_type_ids']
248249

249-
def fused_weight(weight, num_head):
250-
a = paddle.transpose(weight, perm=[1, 0])
251-
return paddle.reshape(a, shape=[1, num_head, int(a.shape[0]/num_head), a.shape[1]])
252-
253-
def fused_qkv(qkv_weight, num_head):
254-
q = qkv_weight['q']
255-
k = qkv_weight['k']
256-
v = qkv_weight['v']
257-
258-
fq = fused_weight(q, num_head)
259-
fk = fused_weight(k, num_head)
260-
fv = fused_weight(v, num_head)
261-
a = paddle.concat(x=[fq, fk, fv], axis=0)
262-
return a
263-
264-
def convert_base_to_fused(state_to_load):
265-
base_to_fused = dict()
266-
base_to_fused["weight"] = "scale"
267-
base_to_fused["bias"] = "bias"
268-
269-
fused_state_to_load = dict()
270-
qkv_weight = dict()
271-
qkv_bias = dict()
272-
qkv_count = 0
273-
num_head = 16
274-
layer_index = 0
275-
for key, value in state_to_load.items():
276-
array = key.split('.')
277-
fused_array = list(array)
278-
if len(array) == 6:#linear or layer_norm
279-
if 'linear' in array[4]:
280-
#linear1.weight -> ffn._linear1_weight
281-
#linear1.bias -> ffn._linear1_bias
282-
fused_array[5] = "_" + array[4] + "_" + array[5]
283-
fused_array[4] = "ffn"
284-
fused_key = '.'.join(fused_array)
285-
fused_state_to_load[fused_key] = value
286-
#print(key, fused_key)
287-
#if array[3] == "0":
288-
# np.savetxt(key+".txt", value)
289-
290-
elif 'norm' in array[4]:
291-
if array[4][-1] == '1':
292-
#norm1.weight -> fused_atten.pre_ln_scale
293-
#norm2.weight -> fused_atten.ln_scale
294-
fused_array[4] = "fused_attn"
295-
fused_array[5] = "ln_" + base_to_fused[array[5]]
296-
fused_key = '.'.join(fused_array)
297-
fused_state_to_load[fused_key] = value
298-
#print(key, fused_key)
299-
#if array[3] == "0":
300-
# np.savetxt(key+".txt", value)
301-
else:
302-
#norm1.weight -> ffn._ln1_scale
303-
fused_array[4] = "ffn"
304-
fused_array[5] = "_ln" + array[4][-1] + "_" + base_to_fused[array[5]]
305-
fused_key = '.'.join(fused_array)
306-
fused_state_to_load[fused_key] = value
307-
#print(key, fused_key)
308-
#if array[3] == "0":
309-
# np.savetxt(key+".txt", value)
310-
elif len(array) == 7:#self_atten
311-
if 'q' in array[5]:
312-
if array[6] == "weight":
313-
qkv_weight['q'] = value
314-
else:
315-
qkv_bias['q'] = value
316-
qkv_count += 1
317-
elif 'k' in array[5]:
318-
if array[6] == "weight":
319-
qkv_weight['k'] = value
320-
else:
321-
qkv_bias['k'] = value
322-
qkv_count += 1
323-
elif 'v' in array[5]:
324-
if array[6] == "weight":
325-
qkv_weight['v'] = value
326-
else:
327-
qkv_bias['v'] = value
328-
qkv_count += 1
329-
else:
330-
fused_array.pop()
331-
fused_array[4] = "fused_attn"
332-
if array[6] == "weight":
333-
fused_array[5] = "linear_weight"
334-
else:
335-
fused_array[5] = "linear_bias"
336-
fused_key = '.'.join(fused_array)
337-
fused_state_to_load[fused_key] = value
338-
#print(key, fused_key)
339-
#if array[3] == "0":
340-
# np.savetxt(key+".txt", value)
341-
342-
if qkv_count == 6:
343-
qkv_count = 0
344-
fused_array.pop()
345-
346-
fused_array[4] = "fused_attn"
347-
fused_array[5] = "qkv_weight"
348-
fused_key = '.'.join(fused_array)
349-
fused_state_to_load[fused_key] = fused_qkv(qkv_weight, num_head)
350-
#print(key, fused_key)
351-
352-
fused_array[4] = "fused_attn"
353-
fused_array[5] = "qkv_bias"
354-
fused_key = '.'.join(fused_array)
355-
a = paddle.concat(x=[qkv_bias['q'], qkv_bias['k'], qkv_bias['v']], axis=0)
356-
tmp_bias = paddle.reshape(a, shape=[3, num_head, int(a.shape[0]/3/num_head)])
357-
fused_state_to_load[fused_key] = tmp_bias
358-
#print(key, fused_key, tmp_bias.numpy().shape)
359-
#if array[3] == "0":
360-
# np.savetxt("fused_bias.txt", tmp_bias.numpy().flatten())
361-
#if array[3] == "0":
362-
363-
else:
364-
fused_state_to_load[key] = value
365-
return fused_state_to_load
250+
#def fused_weight(weight, num_head):
251+
# a = paddle.transpose(weight, perm=[1, 0])
252+
# return paddle.reshape(a, shape=[1, num_head, int(a.shape[0]/num_head), a.shape[1]])
253+
#
254+
#def fused_qkv(qkv_weight, num_head):
255+
# q = qkv_weight['q']
256+
# k = qkv_weight['k']
257+
# v = qkv_weight['v']
258+
#
259+
# fq = fused_weight(q, num_head)
260+
# fk = fused_weight(k, num_head)
261+
# fv = fused_weight(v, num_head)
262+
# a = paddle.concat(x=[fq, fk, fv], axis=0)
263+
# return a
264+
#
265+
#def convert_base_to_fused(state_to_load):
266+
# base_to_fused = dict()
267+
# base_to_fused["weight"] = "scale"
268+
# base_to_fused["bias"] = "bias"
269+
#
270+
# fused_state_to_load = dict()
271+
# qkv_weight = dict()
272+
# qkv_bias = dict()
273+
# qkv_count = 0
274+
# num_head = 16
275+
# layer_index = 0
276+
# for key, value in state_to_load.items():
277+
# array = key.split('.')
278+
# fused_array = list(array)
279+
# if len(array) == 6:#linear or layer_norm
280+
# if 'linear' in array[4]:
281+
# #linear1.weight -> ffn._linear1_weight
282+
# #linear1.bias -> ffn._linear1_bias
283+
# fused_array[5] = "_" + array[4] + "_" + array[5]
284+
# fused_array[4] = "ffn"
285+
# fused_key = '.'.join(fused_array)
286+
# fused_state_to_load[fused_key] = value
287+
# #print(key, fused_key)
288+
# #if array[3] == "0":
289+
# # np.savetxt(key+".txt", value)
290+
#
291+
# elif 'norm' in array[4]:
292+
# if array[4][-1] == '1':
293+
# #norm1.weight -> fused_atten.pre_ln_scale
294+
# #norm2.weight -> fused_atten.ln_scale
295+
# fused_array[4] = "fused_attn"
296+
# fused_array[5] = "ln_" + base_to_fused[array[5]]
297+
# fused_key = '.'.join(fused_array)
298+
# fused_state_to_load[fused_key] = value
299+
# #print(key, fused_key)
300+
# #if array[3] == "0":
301+
# # np.savetxt(key+".txt", value)
302+
# else:
303+
# #norm1.weight -> ffn._ln1_scale
304+
# fused_array[4] = "ffn"
305+
# fused_array[5] = "_ln" + array[4][-1] + "_" + base_to_fused[array[5]]
306+
# fused_key = '.'.join(fused_array)
307+
# fused_state_to_load[fused_key] = value
308+
# #print(key, fused_key)
309+
# #if array[3] == "0":
310+
# # np.savetxt(key+".txt", value)
311+
# elif len(array) == 7:#self_atten
312+
# if 'q' in array[5]:
313+
# if array[6] == "weight":
314+
# qkv_weight['q'] = value
315+
# else:
316+
# qkv_bias['q'] = value
317+
# qkv_count += 1
318+
# elif 'k' in array[5]:
319+
# if array[6] == "weight":
320+
# qkv_weight['k'] = value
321+
# else:
322+
# qkv_bias['k'] = value
323+
# qkv_count += 1
324+
# elif 'v' in array[5]:
325+
# if array[6] == "weight":
326+
# qkv_weight['v'] = value
327+
# else:
328+
# qkv_bias['v'] = value
329+
# qkv_count += 1
330+
# else:
331+
# fused_array.pop()
332+
# fused_array[4] = "fused_attn"
333+
# if array[6] == "weight":
334+
# fused_array[5] = "linear_weight"
335+
# else:
336+
# fused_array[5] = "linear_bias"
337+
# fused_key = '.'.join(fused_array)
338+
# fused_state_to_load[fused_key] = value
339+
# #print(key, fused_key)
340+
# #if array[3] == "0":
341+
# # np.savetxt(key+".txt", value)
342+
#
343+
# if qkv_count == 6:
344+
# qkv_count = 0
345+
# fused_array.pop()
346+
#
347+
# fused_array[4] = "fused_attn"
348+
# fused_array[5] = "qkv_weight"
349+
# fused_key = '.'.join(fused_array)
350+
# fused_state_to_load[fused_key] = fused_qkv(qkv_weight, num_head)
351+
# #print(key, fused_key)
352+
#
353+
# fused_array[4] = "fused_attn"
354+
# fused_array[5] = "qkv_bias"
355+
# fused_key = '.'.join(fused_array)
356+
# a = paddle.concat(x=[qkv_bias['q'], qkv_bias['k'], qkv_bias['v']], axis=0)
357+
# tmp_bias = paddle.reshape(a, shape=[3, num_head, int(a.shape[0]/3/num_head)])
358+
# fused_state_to_load[fused_key] = tmp_bias
359+
# #print(key, fused_key, tmp_bias.numpy().shape)
360+
# #if array[3] == "0":
361+
# # np.savetxt("fused_bias.txt", tmp_bias.numpy().flatten())
362+
# #if array[3] == "0":
363+
#
364+
# else:
365+
# fused_state_to_load[key] = value
366+
# return fused_state_to_load
366367

367368

368369

@@ -445,7 +446,7 @@ def do_train(args):
445446
####convert model to fused model
446447
model = fused_model
447448
#model = base_model
448-
#model.set_state_dict(state_to_load)
449+
#model.set_state_dict(base_state_to_load)
449450

450451
if paddle.distributed.get_world_size() > 1:
451452
model = paddle.DataParallel(model)

examples/experimental/faster_bert/static/modeling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

examples/experimental/faster_bert/static/run_glue.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,15 @@
2828
from paddle.metric import Accuracy
2929
from paddlenlp.data import Stack, Tuple, Pad
3030
from paddlenlp.data.sampler import SamplerHelper
31-
from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer
31+
from paddlenlp.transformers import BertTokenizer
32+
from modeling import BertForSequenceClassification
3233
from paddlenlp.transformers import ErnieForSequenceClassification, ErnieTokenizer
3334
from paddlenlp.transformers import LinearDecayWithWarmup
3435
from paddlenlp.metrics import Mcc, PearsonAndSpearman
3536
from paddlenlp.utils.log import logger
3637

38+
from model_convert_util import convert_base_to_fused
39+
3740
METRIC_CLASSES = {
3841
"cola": Mcc,
3942
"sst-2": Accuracy,
@@ -168,14 +171,15 @@ def create_data_holder(task_name):
168171

169172
def reset_program_state_dict(args, model, state_dict, pretrained_state_dict):
170173
"""
171-
Initialize the parameter from the bert config, and set the parameter by
174+
Initialize the parameter from the bert config, and set the parameter by
172175
reseting the state dict."
173176
"""
174177
reset_state_dict = {}
175178
scale = model.initializer_range if hasattr(model, "initializer_range")\
176179
else getattr(model, args.model_type).config["initializer_range"]
177180
reset_parameter_names = []
178181
for n, p in state_dict.items():
182+
print(n)
179183
if n in pretrained_state_dict:
180184
reset_state_dict[p.name] = np.array(pretrained_state_dict[n])
181185
reset_parameter_names.append(n)
@@ -208,7 +212,7 @@ def set_seed(args):
208212
def evaluate(exe, metric, loss, correct, dev_program, data_loader,
209213
phase="eval"):
210214
"""
211-
The evaluate process, calcluate the eval loss and metric.
215+
The evaluate process, calcluate the eval loss and metric.
212216
"""
213217
metric.reset()
214218
returns = [loss]
@@ -295,7 +299,7 @@ def do_train(args):
295299

296300
batchify_fn = lambda samples, fn=Tuple(
297301
Pad(axis=0, pad_val=tokenizer.pad_token_id), # input
298-
Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # token_type
302+
Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # token_type
299303
Stack(dtype="int64" if train_ds.label_list else "float32") # label
300304
): fn(samples)
301305

@@ -357,8 +361,14 @@ def do_train(args):
357361
with paddle.static.program_guard(main_program, startup_program):
358362
num_class = 1 if train_ds.label_list is None else len(
359363
train_ds.label_list)
360-
model, pretrained_state_dict = model_class.from_pretrained(
364+
base_model, pretrained_state_dict = model_class.from_pretrained(
365+
args.model_name_or_path, num_classes=num_class)
366+
367+
fused_model, fused_pretrained_state_dict = model_class.from_pretrained(
361368
args.model_name_or_path, num_classes=num_class)
369+
370+
model = fused_model
371+
362372
loss_fct = paddle.nn.loss.CrossEntropyLoss(
363373
) if train_ds.label_list else paddle.nn.loss.MSELoss()
364374
logits = model(input_ids, token_type_ids)
@@ -395,11 +405,16 @@ def do_train(args):
395405
# Initialize the fine-tuning parameter, we will load the parameters in
396406
# pre-training model. And initialize the parameter which not in pre-training model
397407
# by the normal distribution.
408+
409+
####convert model to fused model
410+
fused_pretrained_state_dict = convert_base_to_fused(pretrained_state_dict)
411+
####convert model to fused model
412+
398413
exe = paddle.static.Executor(place)
399414
exe.run(startup_program)
400415
state_dict = model.state_dict()
401416
reset_state_dict = reset_program_state_dict(args, model, state_dict,
402-
pretrained_state_dict)
417+
fused_pretrained_state_dict)
403418
paddle.static.set_program_state(main_program, reset_state_dict)
404419

405420
global_step = 0

0 commit comments

Comments
 (0)