2727import paddle .nn .functional as F
2828
2929from paddlenlp .transformers import BertModel , BertForSequenceClassification , BertTokenizer
30+ from paddlenlp .transformers import TinyBertModel , TinyBertForSequenceClassification , TinyBertTokenizer
31+ from paddlenlp .transformers import TinyBertForSequenceClassification , TinyBertTokenizer
32+ from paddlenlp .transformers import RobertaForSequenceClassification , RobertaTokenizer
3033from paddlenlp .utils .log import logger
3134from paddleslim .nas .ofa import OFA , utils
3235from paddleslim .nas .ofa .convert_super import Convert , supernet
3336from paddleslim .nas .ofa .layers import BaseBlock
3437
# Registry keyed by model-type string; each value is the pair
# (sequence-classification model class, tokenizer class).
MODEL_CLASSES = dict(
    bert=(BertForSequenceClassification, BertTokenizer),
    roberta=(RobertaForSequenceClassification, RobertaTokenizer),
    tinybert=(TinyBertForSequenceClassification, TinyBertTokenizer),
)
43+
44+
def tinybert_forward(self, input_ids, token_type_ids=None, attention_mask=None):
    """Run embeddings, encoder and pooler and return both outputs.

    Args:
        self: Model instance exposing ``embeddings``, ``encoder``,
            ``pooler`` and ``pad_token_id``.
        input_ids: Token ids passed to the embedding layer.
        token_type_ids: Optional segment ids forwarded to the embeddings.
        attention_mask: Optional mask handed to the encoder as-is. When
            ``None``, a mask is derived from the positions of ``input_ids``
            that equal ``self.pad_token_id``.

    Returns:
        Tuple ``(encoded_layer, pooled_output)``.
    """
    dense = self.pooler.dense
    # The dense layer may keep its weight under a `.fn` wrapper
    # (presumably after the supernet conversion -- TODO confirm); use
    # whichever weight exists to pick the dtype of the generated mask.
    if hasattr(dense, 'fn'):
        wtype = dense.fn.weight.dtype
    else:
        wtype = dense.weight.dtype

    if attention_mask is None:
        # Padding positions become -1e9, every other position 0; two
        # singleton axes are inserted at positions 1 and 2.
        pad_positions = (input_ids == self.pad_token_id).astype(wtype)
        attention_mask = paddle.unsqueeze(pad_positions * -1e9, axis=[1, 2])

    encoded_layer = self.encoder(
        self.embeddings(input_ids, token_type_ids), attention_mask)
    return encoded_layer, self.pooler(encoded_layer)
56+
57+
# Rebind TinyBertModel's forward at class level so every instance uses
# tinybert_forward, which returns (encoded_layer, pooled_output).
setattr(TinyBertModel, 'forward', tinybert_forward)
3659
3760
3861def parse_args ():
@@ -113,14 +136,15 @@ def do_train(args):
113136 config_path = os .path .join (args .model_name_or_path , 'model_config.json' )
114137 cfg_dict = dict (json .loads (open (config_path ).read ()))
115138
139+ kept_layers_index = {}
116140 if args .depth_mult < 1.0 :
117- depth = round (cfg_dict ["init_args" ][0 ]['num_hidden_layers' ] * args . depth_mult )
118- cfg_dict [ "init_args" ][ 0 ][ 'num_hidden_layers' ] = depth
119- kept_layers_index = {}
120- for idx , i in enumerate (range (1 , depth + 1 )):
141+ depth = round (cfg_dict ["init_args" ][0 ]['num_hidden_layers' ] *
142+ args . depth_mult )
143+ cfg_dict [ "init_args" ][ 0 ][ 'num_hidden_layers' ] = depth
144+ for idx , i in enumerate (range (1 , depth + 1 )):
121145 kept_layers_index [idx ] = math .floor (i / args .depth_mult ) - 1
122146
123- os .rename (config_path , config_path + '_bak' )
147+ os .rename (config_path , config_path + '_bak' )
124148 with open (config_path , "w" , encoding = "utf-8" ) as f :
125149 f .write (json .dumps (cfg_dict , ensure_ascii = False ))
126150
@@ -132,7 +156,7 @@ def do_train(args):
132156 origin_model = model_class .from_pretrained (
133157 args .model_name_or_path , num_classes = num_labels )
134158
135- os .rename (config_path + '_bak' , config_path )
159+ os .rename (config_path + '_bak' , config_path )
136160
137161 sp_config = supernet (expand_ratio = [1.0 , args .width_mult ])
138162 model = Convert (sp_config ).convert (model )
@@ -142,15 +166,24 @@ def do_train(args):
142166 sd = paddle .load (
143167 os .path .join (args .model_name_or_path , 'model_state.pdparams' ))
144168
145- for name , params in ofa_model .model .named_parameters ():
146- if 'encoder' not in name :
147- params .set_value (sd [name ])
148- else :
149- idx = int (name .strip ().split ('.' )[3 ])
150- mapping_name = name .replace ('.' + str (idx )+ '.' , '.' + str (kept_layers_index [idx ])+ '.' )
151- params .set_value (sd [mapping_name ])
169+ if len (kept_layers_index ) == 0 :
170+ ofa_model .model .set_state_dict (sd )
171+ else :
172+ for name , params in ofa_model .model .named_parameters ():
173+ if 'encoder' not in name :
174+ params .set_value (sd [name ])
175+ else :
176+ idx = int (name .strip ().split ('.' )[3 ])
177+ mapping_name = name .replace (
178+ '.' + str (idx ) + '.' ,
179+ '.' + str (kept_layers_index [idx ]) + '.' )
180+ params .set_value (sd [mapping_name ])
152181
153182 best_config = utils .dynabert_config (ofa_model , args .width_mult )
183+ for name , sublayer in ofa_model .model .named_sublayers ():
184+ if isinstance (sublayer , paddle .nn .MultiHeadAttention ):
185+ sublayer .num_heads = int (args .width_mult * sublayer .num_heads )
186+
154187 ofa_model .export (
155188 best_config ,
156189 input_shapes = [[1 , args .max_seq_length ], [1 , args .max_seq_length ]],
0 commit comments