|
37 | 37 | from paddlenlp.transformers import ErnieForSequenceClassification, ErnieTokenizer |
38 | 38 | from paddlenlp.transformers import LinearDecayWithWarmup |
39 | 39 | from paddlenlp.metrics import AccuracyAndF1, Mcc, PearsonAndSpearman |
| 40 | +from static.model_convert_util import convert_base_to_fused |
40 | 41 |
|
41 | 42 | FORMAT = '%(asctime)s-%(levelname)s: %(message)s' |
42 | 43 | logging.basicConfig(level=logging.INFO, format=FORMAT) |
@@ -246,123 +247,123 @@ def convert_example(example, |
246 | 247 | else: |
247 | 248 | return example['input_ids'], example['token_type_ids'] |
248 | 249 |
|
249 | | -def fused_weight(weight, num_head): |
250 | | - a = paddle.transpose(weight, perm=[1, 0]) |
251 | | - return paddle.reshape(a, shape=[1, num_head, int(a.shape[0]/num_head), a.shape[1]]) |
252 | | - |
253 | | -def fused_qkv(qkv_weight, num_head): |
254 | | - q = qkv_weight['q'] |
255 | | - k = qkv_weight['k'] |
256 | | - v = qkv_weight['v'] |
257 | | - |
258 | | - fq = fused_weight(q, num_head) |
259 | | - fk = fused_weight(k, num_head) |
260 | | - fv = fused_weight(v, num_head) |
261 | | - a = paddle.concat(x=[fq, fk, fv], axis=0) |
262 | | - return a |
263 | | - |
264 | | -def convert_base_to_fused(state_to_load): |
265 | | - base_to_fused = dict() |
266 | | - base_to_fused["weight"] = "scale" |
267 | | - base_to_fused["bias"] = "bias" |
268 | | - |
269 | | - fused_state_to_load = dict() |
270 | | - qkv_weight = dict() |
271 | | - qkv_bias = dict() |
272 | | - qkv_count = 0 |
273 | | - num_head = 16 |
274 | | - layer_index = 0 |
275 | | - for key, value in state_to_load.items(): |
276 | | - array = key.split('.') |
277 | | - fused_array = list(array) |
278 | | - if len(array) == 6:#linear or layer_norm |
279 | | - if 'linear' in array[4]: |
280 | | - #linear1.weight -> ffn._linear1_weight |
281 | | - #linear1.bias -> ffn._linear1_bias |
282 | | - fused_array[5] = "_" + array[4] + "_" + array[5] |
283 | | - fused_array[4] = "ffn" |
284 | | - fused_key = '.'.join(fused_array) |
285 | | - fused_state_to_load[fused_key] = value |
286 | | - #print(key, fused_key) |
287 | | - #if array[3] == "0": |
288 | | - # np.savetxt(key+".txt", value) |
289 | | - |
290 | | - elif 'norm' in array[4]: |
291 | | - if array[4][-1] == '1': |
292 | | - #norm1.weight -> fused_atten.pre_ln_scale |
293 | | - #norm2.weight -> fused_atten.ln_scale |
294 | | - fused_array[4] = "fused_attn" |
295 | | - fused_array[5] = "ln_" + base_to_fused[array[5]] |
296 | | - fused_key = '.'.join(fused_array) |
297 | | - fused_state_to_load[fused_key] = value |
298 | | - #print(key, fused_key) |
299 | | - #if array[3] == "0": |
300 | | - # np.savetxt(key+".txt", value) |
301 | | - else: |
302 | | - #norm1.weight -> ffn._ln1_scale |
303 | | - fused_array[4] = "ffn" |
304 | | - fused_array[5] = "_ln" + array[4][-1] + "_" + base_to_fused[array[5]] |
305 | | - fused_key = '.'.join(fused_array) |
306 | | - fused_state_to_load[fused_key] = value |
307 | | - #print(key, fused_key) |
308 | | - #if array[3] == "0": |
309 | | - # np.savetxt(key+".txt", value) |
310 | | - elif len(array) == 7:#self_atten |
311 | | - if 'q' in array[5]: |
312 | | - if array[6] == "weight": |
313 | | - qkv_weight['q'] = value |
314 | | - else: |
315 | | - qkv_bias['q'] = value |
316 | | - qkv_count += 1 |
317 | | - elif 'k' in array[5]: |
318 | | - if array[6] == "weight": |
319 | | - qkv_weight['k'] = value |
320 | | - else: |
321 | | - qkv_bias['k'] = value |
322 | | - qkv_count += 1 |
323 | | - elif 'v' in array[5]: |
324 | | - if array[6] == "weight": |
325 | | - qkv_weight['v'] = value |
326 | | - else: |
327 | | - qkv_bias['v'] = value |
328 | | - qkv_count += 1 |
329 | | - else: |
330 | | - fused_array.pop() |
331 | | - fused_array[4] = "fused_attn" |
332 | | - if array[6] == "weight": |
333 | | - fused_array[5] = "linear_weight" |
334 | | - else: |
335 | | - fused_array[5] = "linear_bias" |
336 | | - fused_key = '.'.join(fused_array) |
337 | | - fused_state_to_load[fused_key] = value |
338 | | - #print(key, fused_key) |
339 | | - #if array[3] == "0": |
340 | | - # np.savetxt(key+".txt", value) |
341 | | - |
342 | | - if qkv_count == 6: |
343 | | - qkv_count = 0 |
344 | | - fused_array.pop() |
345 | | - |
346 | | - fused_array[4] = "fused_attn" |
347 | | - fused_array[5] = "qkv_weight" |
348 | | - fused_key = '.'.join(fused_array) |
349 | | - fused_state_to_load[fused_key] = fused_qkv(qkv_weight, num_head) |
350 | | - #print(key, fused_key) |
351 | | - |
352 | | - fused_array[4] = "fused_attn" |
353 | | - fused_array[5] = "qkv_bias" |
354 | | - fused_key = '.'.join(fused_array) |
355 | | - a = paddle.concat(x=[qkv_bias['q'], qkv_bias['k'], qkv_bias['v']], axis=0) |
356 | | - tmp_bias = paddle.reshape(a, shape=[3, num_head, int(a.shape[0]/3/num_head)]) |
357 | | - fused_state_to_load[fused_key] = tmp_bias |
358 | | - #print(key, fused_key, tmp_bias.numpy().shape) |
359 | | - #if array[3] == "0": |
360 | | - # np.savetxt("fused_bias.txt", tmp_bias.numpy().flatten()) |
361 | | - #if array[3] == "0": |
362 | | - |
363 | | - else: |
364 | | - fused_state_to_load[key] = value |
365 | | - return fused_state_to_load |
| 250 | +#def fused_weight(weight, num_head): |
| 251 | +# a = paddle.transpose(weight, perm=[1, 0]) |
| 252 | +# return paddle.reshape(a, shape=[1, num_head, int(a.shape[0]/num_head), a.shape[1]]) |
| 253 | +# |
| 254 | +#def fused_qkv(qkv_weight, num_head): |
| 255 | +# q = qkv_weight['q'] |
| 256 | +# k = qkv_weight['k'] |
| 257 | +# v = qkv_weight['v'] |
| 258 | +# |
| 259 | +# fq = fused_weight(q, num_head) |
| 260 | +# fk = fused_weight(k, num_head) |
| 261 | +# fv = fused_weight(v, num_head) |
| 262 | +# a = paddle.concat(x=[fq, fk, fv], axis=0) |
| 263 | +# return a |
| 264 | +# |
| 265 | +#def convert_base_to_fused(state_to_load): |
| 266 | +# base_to_fused = dict() |
| 267 | +# base_to_fused["weight"] = "scale" |
| 268 | +# base_to_fused["bias"] = "bias" |
| 269 | +# |
| 270 | +# fused_state_to_load = dict() |
| 271 | +# qkv_weight = dict() |
| 272 | +# qkv_bias = dict() |
| 273 | +# qkv_count = 0 |
| 274 | +# num_head = 16 |
| 275 | +# layer_index = 0 |
| 276 | +# for key, value in state_to_load.items(): |
| 277 | +# array = key.split('.') |
| 278 | +# fused_array = list(array) |
| 279 | +# if len(array) == 6:#linear or layer_norm |
| 280 | +# if 'linear' in array[4]: |
| 281 | +# #linear1.weight -> ffn._linear1_weight |
| 282 | +# #linear1.bias -> ffn._linear1_bias |
| 283 | +# fused_array[5] = "_" + array[4] + "_" + array[5] |
| 284 | +# fused_array[4] = "ffn" |
| 285 | +# fused_key = '.'.join(fused_array) |
| 286 | +# fused_state_to_load[fused_key] = value |
| 287 | +# #print(key, fused_key) |
| 288 | +# #if array[3] == "0": |
| 289 | +# # np.savetxt(key+".txt", value) |
| 290 | +# |
| 291 | +# elif 'norm' in array[4]: |
| 292 | +# if array[4][-1] == '1': |
| 293 | +# #norm1.weight -> fused_atten.pre_ln_scale |
| 294 | +# #norm2.weight -> fused_atten.ln_scale |
| 295 | +# fused_array[4] = "fused_attn" |
| 296 | +# fused_array[5] = "ln_" + base_to_fused[array[5]] |
| 297 | +# fused_key = '.'.join(fused_array) |
| 298 | +# fused_state_to_load[fused_key] = value |
| 299 | +# #print(key, fused_key) |
| 300 | +# #if array[3] == "0": |
| 301 | +# # np.savetxt(key+".txt", value) |
| 302 | +# else: |
| 303 | +# #norm1.weight -> ffn._ln1_scale |
| 304 | +# fused_array[4] = "ffn" |
| 305 | +# fused_array[5] = "_ln" + array[4][-1] + "_" + base_to_fused[array[5]] |
| 306 | +# fused_key = '.'.join(fused_array) |
| 307 | +# fused_state_to_load[fused_key] = value |
| 308 | +# #print(key, fused_key) |
| 309 | +# #if array[3] == "0": |
| 310 | +# # np.savetxt(key+".txt", value) |
| 311 | +# elif len(array) == 7:#self_atten |
| 312 | +# if 'q' in array[5]: |
| 313 | +# if array[6] == "weight": |
| 314 | +# qkv_weight['q'] = value |
| 315 | +# else: |
| 316 | +# qkv_bias['q'] = value |
| 317 | +# qkv_count += 1 |
| 318 | +# elif 'k' in array[5]: |
| 319 | +# if array[6] == "weight": |
| 320 | +# qkv_weight['k'] = value |
| 321 | +# else: |
| 322 | +# qkv_bias['k'] = value |
| 323 | +# qkv_count += 1 |
| 324 | +# elif 'v' in array[5]: |
| 325 | +# if array[6] == "weight": |
| 326 | +# qkv_weight['v'] = value |
| 327 | +# else: |
| 328 | +# qkv_bias['v'] = value |
| 329 | +# qkv_count += 1 |
| 330 | +# else: |
| 331 | +# fused_array.pop() |
| 332 | +# fused_array[4] = "fused_attn" |
| 333 | +# if array[6] == "weight": |
| 334 | +# fused_array[5] = "linear_weight" |
| 335 | +# else: |
| 336 | +# fused_array[5] = "linear_bias" |
| 337 | +# fused_key = '.'.join(fused_array) |
| 338 | +# fused_state_to_load[fused_key] = value |
| 339 | +# #print(key, fused_key) |
| 340 | +# #if array[3] == "0": |
| 341 | +# # np.savetxt(key+".txt", value) |
| 342 | +# |
| 343 | +# if qkv_count == 6: |
| 344 | +# qkv_count = 0 |
| 345 | +# fused_array.pop() |
| 346 | +# |
| 347 | +# fused_array[4] = "fused_attn" |
| 348 | +# fused_array[5] = "qkv_weight" |
| 349 | +# fused_key = '.'.join(fused_array) |
| 350 | +# fused_state_to_load[fused_key] = fused_qkv(qkv_weight, num_head) |
| 351 | +# #print(key, fused_key) |
| 352 | +# |
| 353 | +# fused_array[4] = "fused_attn" |
| 354 | +# fused_array[5] = "qkv_bias" |
| 355 | +# fused_key = '.'.join(fused_array) |
| 356 | +# a = paddle.concat(x=[qkv_bias['q'], qkv_bias['k'], qkv_bias['v']], axis=0) |
| 357 | +# tmp_bias = paddle.reshape(a, shape=[3, num_head, int(a.shape[0]/3/num_head)]) |
| 358 | +# fused_state_to_load[fused_key] = tmp_bias |
| 359 | +# #print(key, fused_key, tmp_bias.numpy().shape) |
| 360 | +# #if array[3] == "0": |
| 361 | +# # np.savetxt("fused_bias.txt", tmp_bias.numpy().flatten()) |
| 362 | +# #if array[3] == "0": |
| 363 | +# |
| 364 | +# else: |
| 365 | +# fused_state_to_load[key] = value |
| 366 | +# return fused_state_to_load |
366 | 367 |
|
367 | 368 |
|
368 | 369 |
|
@@ -445,7 +446,7 @@ def do_train(args): |
445 | 446 | ####convert model to fused model |
446 | 447 | model = fused_model |
447 | 448 | #model = base_model |
448 | | - #model.set_state_dict(state_to_load) |
| 449 | + #model.set_state_dict(base_state_to_load) |
449 | 450 |
|
450 | 451 | if paddle.distributed.get_world_size() > 1: |
451 | 452 | model = paddle.DataParallel(model) |
|
0 commit comments