@@ -3,8 +3,7 @@

 import torch

-from tokenizers import Tokenizer
-from transformers import OPTForCausalLM, StoppingCriteriaList, StoppingCriteria
+from transformers import AutoTokenizer, OPTForCausalLM, StoppingCriteriaList, StoppingCriteria
 from parallelformers import parallelize
 import psutil

@@ -80,6 +79,7 @@ def __init__(
         self.is_loaded = False
         self.num_gpus = num_gpus
         self.tensor_parallel = tensor_parallel
+        self.max_input_length = 2020
         self._master_port = None

     def _load_checkpoint(self, checkpoint_path: str):
@@ -134,7 +134,7 @@ def _parallelize(self) -> None:
             master_port=self._master_port,
         )

-    def _set_tokenizer(self, tokenizer_path: str, from_file=False):
+    def _set_tokenizer(self, tokenizer_path: str):
         """
         Configures the tokenizer for the model

@@ -143,12 +143,27 @@ def _set_tokenizer(self, tokenizer_path: str, from_file=False):
         tokenizer_path : str
             Path for the tokenizer (str)
         """
-        if from_file:
-            self.tokenizer = Tokenizer.from_file(tokenizer_path)
-        else:
-            self.tokenizer = Tokenizer.from_pretrained(tokenizer_path)
-        self.tokenizer.enable_padding(direction="left", pad_id=1, pad_type_id=0, pad_token="[PAD]")
-        self.tokenizer.enable_truncation(max_length=2020, direction="left")
+        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+
+        # setup padding
+        tokenizer.pad_token_id = 1
+        tokenizer.pad_token = "<pad>"
+        tokenizer.padding_side = "left"
+
+        # setup truncation
+        tokenizer.truncation_side = "left"
+
+        # setup special tokens
+        tokenizer.bos_token_id = 0
+        tokenizer.bos_token = "<s>"
+
+        tokenizer.eos_token_id = 2
+        tokenizer.eos_token = "</s>"
+
+        tokenizer.unk_token = "<unk>"
+        tokenizer.unk_token_id = 3
+
+        self.tokenizer = tokenizer

     def _tokenize(self, input_text: List[str], new_doc: bool) -> torch.LongTensor:
         """
@@ -167,24 +182,27 @@ def _tokenize(self, input_text: List[str], new_doc: bool) -> torch.LongTensor:
             text = escape_custom_split_sequence(text)
             if not text:
                 warnings.warn(
-                    "Found an empty input text. Chainging to end-of-document token instead.",
+                    "Found an empty input text. Changing to end-of-document token instead.",
                     UserWarning
                 )
-                text = "</s>"
+                text = self.tokenizer.eos_token
             texts.append(text)

         if new_doc:
-            pad_id = self.tokenizer.padding["pad_id"]
-            pad_token = self.tokenizer.id_to_token(pad_id)
+            pad_token = self.tokenizer.pad_token
             texts = [pad_token + t for t in texts]

-        list_encoded = self.tokenizer.encode_batch(texts)
-        context_tokens = [encoded.ids for encoded in list_encoded]
+        encoded = self.tokenizer(
+            texts,
+            padding="longest",
+            max_length=self.max_input_length,
+            truncation=True
+        )
+        context_tokens = encoded["input_ids"]
         input_v = torch.LongTensor(context_tokens).to(self.model.device)

         if new_doc:
-            eos_id = self.tokenizer.token_to_id("</s>")
-            input_v[input_v[:, 0] == pad_id, 0] = eos_id
+            input_v[input_v[:, 0] == self.tokenizer.pad_token_id, 0] = self.tokenizer.eos_token_id
         return input_v

     @torch.inference_mode()
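The final boolean-mask assignment in `_tokenize` swaps a leading pad token for the end-of-document token only on rows that actually start with padding. A standalone illustration with toy ids:

```python
import torch

pad_id, eos_id = 1, 2  # matches the pad/eos ids set in _set_tokenizer
input_v = torch.LongTensor([
    [1, 5, 6],  # left-padded row: first token is <pad>
    [7, 8, 9],  # full-length row: left untouched
])
input_v[input_v[:, 0] == pad_id, 0] = eos_id
print(input_v.tolist())  # [[2, 5, 6], [7, 8, 9]]
```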
@@ -278,9 +296,12 @@ def generate(
         )

         # we keep special tokens such as [START_REF] or <work>
-        decoded = self.tokenizer.decode_batch(out['sequences'].tolist(), skip_special_tokens=False)
+        decoded = self.tokenizer.batch_decode(out['sequences'], skip_special_tokens=False)
         # so we manually remove </s> and <pad>
-        decoded = [text.replace("</s>", "").replace("<pad>", "") for text in decoded]
+        decoded = [
+            text.replace(self.tokenizer.eos_token, "").replace(self.tokenizer.pad_token, "")
+            for text in decoded
+        ]

         if num_return_sequences == 1:
             return decoded[0] if isinstance(input_text, str) else decoded
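`batch_decode` accepts the sequence tensor directly, which is why the `.tolist()` call disappears. Decoding keeps every special token so markers like `[START_REF]` survive, and only the pad/eos strings are stripped afterwards. A toy check of that stripping step, using the literal token strings configured above:

```python
eos_token, pad_token = "</s>", "<pad>"
decoded = ["<pad></s>[START_REF] Some Paper[END_REF]"]
cleaned = [
    text.replace(eos_token, "").replace(pad_token, "")
    for text in decoded
]
assert cleaned == ["[START_REF] Some Paper[END_REF]"]
```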
@@ -366,7 +387,7 @@ def generate_reference(
         prompt_length = input_v.shape[1]
         finished_reference_criteria = FinishedReferenceCriteria(
             prompt_length=prompt_length,
-            end_ref_id=self.tokenizer.token_to_id("[END_REF]"),
+            end_ref_id=self.tokenizer.convert_tokens_to_ids("[END_REF]"),
         )

         if max_new_tokens is None and max_length is None:
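`convert_tokens_to_ids` is the `transformers` counterpart of the old `token_to_id`. `FinishedReferenceCriteria` itself is defined elsewhere in this file; purely as a hypothetical sketch, a criteria in the same spirit could stop once every row has produced `[END_REF]`:

```python
import torch
from transformers import StoppingCriteria

class EndRefCriteria(StoppingCriteria):
    """Hypothetical stand-in for FinishedReferenceCriteria."""

    def __init__(self, prompt_length: int, end_ref_id: int):
        self.prompt_length = prompt_length
        self.end_ref_id = end_ref_id

    def __call__(self, input_ids: torch.LongTensor, scores, **kwargs) -> bool:
        # Look only at tokens generated after the prompt.
        generated = input_ids[:, self.prompt_length:]
        # Stop when every sequence contains the [END_REF] id.
        return bool((generated == self.end_ref_id).any(dim=-1).all())
```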
@@ -399,8 +420,8 @@ def generate_reference(
             stopping_criteria=stopping_criteria,
         )
         # cut-off the prompts
-        generated_tokens = out["sequences"][:, prompt_length:].tolist()
-        decoded = self.tokenizer.decode_batch(generated_tokens, skip_special_tokens=False)
+        generated_tokens = out["sequences"][:, prompt_length:]
+        decoded = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)
         references = []
         unfinished_generation = False
         for text in decoded:
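`generate` returns the prompt concatenated with the continuation, so slicing at `prompt_length` before decoding keeps only the newly generated reference text. A toy illustration:

```python
import torch

prompt_length = 3
# Prompt ids followed by generated ids (toy values; 2 is eos).
sequences = torch.LongTensor([[10, 11, 12, 42, 43, 2]])
generated_tokens = sequences[:, prompt_length:]
print(generated_tokens.tolist())  # [[42, 43, 2]]
```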