 import torch

-from tokenizers import Tokenizer
-from transformers import OPTForCausalLM, StoppingCriteriaList, StoppingCriteria
+from transformers import AutoTokenizer, OPTForCausalLM, StoppingCriteriaList, StoppingCriteria
 from parallelformers import parallelize
 import psutil
@@ -80,6 +79,7 @@ def __init__(
         self.is_loaded = False
         self.num_gpus = num_gpus
         self.tensor_parallel = tensor_parallel
+        self.max_input_length = 2020
         self._master_port = None

     def _load_checkpoint(self, checkpoint_path: str):
@@ -129,9 +129,15 @@ def _parallelize(self) -> None:

         self._master_port = 13000 + (id(self.model) % 32749)

+        custom_policies = None
+        if self.model.config.model_type == "opt" and not self.model.config.enable_bias:
+            from galai.parallel_policy import OPTDecoderLayerPolicyNoBias
+            custom_policies = [OPTDecoderLayerPolicyNoBias]
+
         parallelize(
             self.model, num_gpus=self.num_gpus, fp16=self.dtype == torch.float16,
             master_port=self._master_port,
+            custom_policies=custom_policies,
         )

     def _set_tokenizer(self, tokenizer_path: str):
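For orientation, here is a minimal standalone sketch of the parallelization path this hunk takes for Galactica-style OPT checkpoints. The checkpoint name, GPU count, and port are placeholders, and it assumes a transformers version whose OPTConfig exposes enable_bias; the custom_policies keyword mirrors the call above.

import torch
from transformers import OPTForCausalLM
from parallelformers import parallelize

# placeholder checkpoint, for illustration only
model = OPTForCausalLM.from_pretrained("facebook/galactica-1.3b", torch_dtype=torch.float16)

custom_policies = None
if model.config.model_type == "opt" and not model.config.enable_bias:
    # Galactica OPT layers ship without biases, hence the bias-free layer policy
    from galai.parallel_policy import OPTDecoderLayerPolicyNoBias
    custom_policies = [OPTDecoderLayerPolicyNoBias]

parallelize(
    model,
    num_gpus=2,          # placeholder
    fp16=True,
    master_port=13000,   # any otherwise unused port
    custom_policies=custom_policies,
)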
@@ -143,9 +149,27 @@ def _set_tokenizer(self, tokenizer_path: str):
         tokenizer_path : str
             Path for the tokenizer (str)
         """
-        self.tokenizer = Tokenizer.from_pretrained(tokenizer_path)
-        self.tokenizer.enable_padding(direction="left", pad_id=1, pad_type_id=0, pad_token="[PAD]")
-        self.tokenizer.enable_truncation(max_length=2020, direction="left")
+        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+
+        # setup padding
+        tokenizer.pad_token_id = 1
+        tokenizer.pad_token = "<pad>"
+        tokenizer.padding_side = "left"
+
+        # setup truncation
+        tokenizer.truncation_side = "left"
+
+        # setup special tokens
+        tokenizer.bos_token_id = 0
+        tokenizer.bos_token = "<s>"
+
+        tokenizer.eos_token_id = 2
+        tokenizer.eos_token = "</s>"
+
+        tokenizer.unk_token = "<unk>"
+        tokenizer.unk_token_id = 3
+
+        self.tokenizer = tokenizer

     def _tokenize(self, input_text: List[str], new_doc: bool) -> torch.LongTensor:
         """
@@ -164,24 +188,27 @@ def _tokenize(self, input_text: List[str], new_doc: bool) -> torch.LongTensor:
             text = escape_custom_split_sequence(text)
             if not text:
                 warnings.warn(
-                    "Found an empty input text. Chainging to end-of-document token instead.",
+                    "Found an empty input text. Changing to end-of-document token instead.",
                     UserWarning
                 )
-                text = "</s>"
+                text = self.tokenizer.eos_token
             texts.append(text)

         if new_doc:
-            pad_id = self.tokenizer.padding["pad_id"]
-            pad_token = self.tokenizer.id_to_token(pad_id)
+            pad_token = self.tokenizer.pad_token
             texts = [pad_token + t for t in texts]

-        list_encoded = self.tokenizer.encode_batch(texts)
-        context_tokens = [encoded.ids for encoded in list_encoded]
+        encoded = self.tokenizer(
+            texts,
+            padding="longest",
+            max_length=self.max_input_length,
+            truncation=True
+        )
+        context_tokens = encoded["input_ids"]
         input_v = torch.LongTensor(context_tokens).to(self.model.device)

         if new_doc:
-            eos_id = self.tokenizer.token_to_id("</s>")
-            input_v[input_v[:, 0] == pad_id, 0] = eos_id
+            input_v[input_v[:, 0] == self.tokenizer.pad_token_id, 0] = self.tokenizer.eos_token_id
         return input_v

     @torch.inference_mode()
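The new_doc branch relies on the left padding configured above: each text gets a leading <pad>, and after batching, that first column is rewritten to the end-of-document token. A minimal sketch of that index trick with hand-written ids (pad=1, eos=2, matching _set_tokenizer):

import torch

pad_token_id, eos_token_id = 1, 2  # ids assigned in _set_tokenizer above
context_tokens = [
    [1, 1, 7, 8, 9],   # shorter prompt: extra left padding before the prepended <pad>
    [1, 4, 5, 6, 7],   # longest prompt: only the prepended <pad> in column 0
]
input_v = torch.LongTensor(context_tokens)

# rewrite the prepended <pad> in column 0 to </s>, marking a fresh document
input_v[input_v[:, 0] == pad_token_id, 0] = eos_token_id
# input_v is now [[2, 1, 7, 8, 9], [2, 4, 5, 6, 7]]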
@@ -275,9 +302,12 @@ def generate(
         )

         # we keep special tokens such as [START_REF] or <work>
-        decoded = self.tokenizer.decode_batch(out['sequences'].tolist(), skip_special_tokens=False)
+        decoded = self.tokenizer.batch_decode(out['sequences'], skip_special_tokens=False)
         # so we manually remove </s> and <pad>
-        decoded = [text.replace("</s>", "").replace("<pad>", "") for text in decoded]
+        decoded = [
+            text.replace(self.tokenizer.eos_token, "").replace(self.tokenizer.pad_token, "")
+            for text in decoded
+        ]

         if num_return_sequences == 1:
             return decoded[0] if isinstance(input_text, str) else decoded
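One practical difference in this hunk: transformers' batch_decode accepts the tensor returned by generate directly, so the .tolist() call disappears, and the manual cleanup keeps task markers intact. A tiny sketch with a made-up decoded string:

# made-up output string, only to show what the cleanup keeps and drops
decoded = ["<pad></s>Scaled dot-product attention [START_REF] Attention Is All You Need, Vaswani[END_REF]</s>"]
cleaned = [t.replace("</s>", "").replace("<pad>", "") for t in decoded]
# [START_REF]...[END_REF] survives; only </s> and <pad> are removed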
@@ -363,7 +393,7 @@ def generate_reference(
         prompt_length = input_v.shape[1]
         finished_reference_criteria = FinishedReferenceCriteria(
             prompt_length=prompt_length,
-            end_ref_id=self.tokenizer.token_to_id("[END_REF]"),
+            end_ref_id=self.tokenizer.convert_tokens_to_ids("[END_REF]"),
         )

         if max_new_tokens is None and max_length is None:
@@ -396,8 +426,8 @@ def generate_reference(
             stopping_criteria=stopping_criteria,
         )
         # cut-off the prompts
-        generated_tokens = out["sequences"][:, prompt_length:].tolist()
-        decoded = self.tokenizer.decode_batch(generated_tokens, skip_special_tokens=False)
+        generated_tokens = out["sequences"][:, prompt_length:]
+        decoded = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)
         references = []
         unfinished_generation = False
         for text in decoded:
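FinishedReferenceCriteria itself is defined elsewhere in galai; the only change here is how the [END_REF] id is looked up. A rough sketch of a stopping criterion with this shape, purely to illustrate how the prompt_length and end_ref_id arguments would be used (galai's actual implementation may differ):

import torch
from transformers import StoppingCriteria

class EndRefCriteria(StoppingCriteria):
    """Stop generation once every sequence has emitted [END_REF] after the prompt."""

    def __init__(self, prompt_length: int, end_ref_id: int):
        self.prompt_length = prompt_length
        self.end_ref_id = end_ref_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        generated = input_ids[:, self.prompt_length:]
        return bool((generated == self.end_ref_id).any(dim=-1).all())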