Commit 86bdbc9

Switch to AutoTokenizer

Authored and committed by Marcin Kardas
1 parent 8c5602e, commit 86bdbc9

2 files changed: +49 -31 lines changed

galai/__init__.py

Lines changed: 6 additions & 9 deletions
@@ -56,18 +56,15 @@ def load_model(
 
     if name in HF_MAPPING:
         hf_model, default_dtype = HF_MAPPING[name]
-        tokenizer_path = hf_model
-        from_file=False
-
+        galai_model = True
     elif Path(name).exists():
         hf_model = name
         default_dtype = torch.float32
-        # tokenizer_path = "facebook/galactica-1.3b"
-        tokenizer_path = name + "/tokenizer.json"
-        from_file=True
+        galai_model = False
     else:
         raise ValueError(
-            "Invalid model name. Must be one of 'mini', 'base', 'standard', 'large', 'huge'."
+            "Invalid model name. Must be one of 'mini', 'base', 'standard', 'large', 'huge', " +
+            "a path to a local checkpoint dir, or a model name available on HuggingFace hub."
         )
 
     if dtype is None:
@@ -109,7 +106,7 @@ def load_model(
             UserWarning
         )
         num_gpus = available
-    if num_gpus > 1 and parallelize:
+    if num_gpus > 1 and parallelize and galai_model:
         mi = ModelInfo.by_name(name)
         if mi.num_heads % num_gpus != 0:
             raise ValueError(
@@ -130,7 +127,7 @@ def load_model(
         num_gpus=num_gpus,
         tensor_parallel=parallelize,
     )
-    model._set_tokenizer(tokenizer_path, from_file=from_file)
+    model._set_tokenizer(hf_model)
    model._load_checkpoint(checkpoint_path=hf_model)
 
    return model
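
For orientation, a minimal usage sketch of load_model after this change. The accepted name forms follow the new error message above; the prompt string and local path are illustrative only, not part of the diff:

import galai as gal

# Built-in names still resolve through HF_MAPPING (galai_model = True).
model = gal.load_model("mini")

# A local checkpoint directory or a HuggingFace Hub model name is now also
# accepted (galai_model = False); its tokenizer is loaded with AutoTokenizer
# from the same path instead of a standalone tokenizer.json file.
# model = gal.load_model("/path/to/checkpoint_dir")  # hypothetical path

print(model.generate("The dominant sequence transduction models"))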

galai/model.py

Lines changed: 43 additions & 22 deletions
@@ -3,8 +3,7 @@
 
 import torch
 
-from tokenizers import Tokenizer
-from transformers import OPTForCausalLM, StoppingCriteriaList, StoppingCriteria
+from transformers import AutoTokenizer, OPTForCausalLM, StoppingCriteriaList, StoppingCriteria
 from parallelformers import parallelize
 import psutil
 
@@ -80,6 +79,7 @@ def __init__(
         self.is_loaded = False
         self.num_gpus = num_gpus
         self.tensor_parallel = tensor_parallel
+        self.max_input_length = 2020
         self._master_port = None
 
     def _load_checkpoint(self, checkpoint_path: str):
@@ -134,7 +134,7 @@ def _parallelize(self) -> None:
             master_port=self._master_port,
         )
 
-    def _set_tokenizer(self, tokenizer_path: str, from_file=False):
+    def _set_tokenizer(self, tokenizer_path: str):
         """
         Configures the tokenizer for the model
 
@@ -143,12 +143,27 @@ def _set_tokenizer(self, tokenizer_path: str, from_file=False):
         tokenizer_path : str
             Path for the tokenizer (str)
         """
-        if from_file:
-            self.tokenizer = Tokenizer.from_file(tokenizer_path)
-        else:
-            self.tokenizer = Tokenizer.from_pretrained(tokenizer_path)
-        self.tokenizer.enable_padding(direction="left", pad_id=1, pad_type_id=0, pad_token="[PAD]")
-        self.tokenizer.enable_truncation(max_length=2020, direction="left")
+        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+
+        # setup padding
+        tokenizer.pad_token_id = 1
+        tokenizer.pad_token = "<pad>"
+        tokenizer.padding_side = "left"
+
+        # setup truncation
+        tokenizer.truncation_side = "left"
+
+        # setup special tokens
+        tokenizer.bos_token_id = 0
+        tokenizer.bos_token = "<s>"
+
+        tokenizer.eos_token_id = 2
+        tokenizer.eos_token = "</s>"
+
+        tokenizer.unk_token = "<unk>"
+        tokenizer.unk_token_id = 3
+
+        self.tokenizer = tokenizer
 
     def _tokenize(self, input_text: List[str], new_doc: bool) -> torch.LongTensor:
         """
@@ -167,24 +182,27 @@ def _tokenize(self, input_text: List[str], new_doc: bool) -> torch.LongTensor:
             text = escape_custom_split_sequence(text)
             if not text:
                 warnings.warn(
-                    "Found an empty input text. Chainging to end-of-document token instead.",
+                    "Found an empty input text. Changing to end-of-document token instead.",
                     UserWarning
                 )
-                text = "</s>"
+                text = self.tokenizer.eos_token
             texts.append(text)
 
         if new_doc:
-            pad_id = self.tokenizer.padding["pad_id"]
-            pad_token = self.tokenizer.id_to_token(pad_id)
+            pad_token = self.tokenizer.pad_token
             texts = [pad_token + t for t in texts]
 
-        list_encoded = self.tokenizer.encode_batch(texts)
-        context_tokens = [encoded.ids for encoded in list_encoded]
+        encoded = self.tokenizer(
+            texts,
+            padding="longest",
+            max_length=self.max_input_length,
+            truncation=True
+        )
+        context_tokens = encoded["input_ids"]
         input_v = torch.LongTensor(context_tokens).to(self.model.device)
 
         if new_doc:
-            eos_id = self.tokenizer.token_to_id("</s>")
-            input_v[input_v[:, 0] == pad_id, 0] = eos_id
+            input_v[input_v[:, 0] == self.tokenizer.pad_token_id, 0] = self.tokenizer.eos_token_id
         return input_v
 
     @torch.inference_mode()
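
A short sketch of the batch-encoding behaviour _tokenize now relies on, with a tokenizer configured the same way as in _set_tokenizer; the input texts are illustrative:

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/galactica-1.3b")
tokenizer.pad_token, tokenizer.pad_token_id = "<pad>", 1
tokenizer.eos_token, tokenizer.eos_token_id = "</s>", 2
tokenizer.padding_side = "left"
tokenizer.truncation_side = "left"

encoded = tokenizer(
    ["a longer prompt about attention mechanisms", "short"],
    padding="longest",      # pad the batch to its longest member
    max_length=2020,        # matches self.max_input_length
    truncation=True,
)
input_ids = encoded["input_ids"]        # plain lists of ints, no Encoding objects
input_v = torch.LongTensor(input_ids)   # rectangular thanks to left padding

# With new_doc=True each text is prefixed with <pad>; the first column is then
# rewritten to </s> wherever it still holds the pad id, as in _tokenize above.
input_v[input_v[:, 0] == tokenizer.pad_token_id, 0] = tokenizer.eos_token_id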
@@ -278,9 +296,12 @@ def generate(
         )
 
         # we keep special tokens such as [START_REF] or <work>
-        decoded = self.tokenizer.decode_batch(out['sequences'].tolist(), skip_special_tokens=False)
+        decoded = self.tokenizer.batch_decode(out['sequences'], skip_special_tokens=False)
         # so we manually remove </s> and <pad>
-        decoded = [text.replace("</s>", "").replace("<pad>", "") for text in decoded]
+        decoded = [
+            text.replace(self.tokenizer.eos_token, "").replace(self.tokenizer.pad_token, "")
+            for text in decoded
+        ]
 
         if num_return_sequences == 1:
             return decoded[0] if isinstance(input_text, str) else decoded
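
The decoding change is mechanical but easy to miss: decode_batch on a fast tokenizer needs plain Python lists, while batch_decode on a transformers tokenizer accepts the tensor of generated ids directly. A small sketch with toy ids (not taken from a real generation run):

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/galactica-1.3b")
tokenizer.pad_token, tokenizer.pad_token_id = "<pad>", 1
tokenizer.eos_token, tokenizer.eos_token_id = "</s>", 2

sequences = torch.tensor([[1, 2, 100, 101], [2, 100, 101, 102]])  # toy ids

# Old: tokenizer.decode_batch(sequences.tolist(), skip_special_tokens=False)
decoded = tokenizer.batch_decode(sequences, skip_special_tokens=False)

# Special tokens such as [START_REF] or <work> are kept on purpose, so only
# </s> and <pad> are stripped by hand, as in generate() above.
decoded = [
    t.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "")
    for t in decoded
]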
@@ -366,7 +387,7 @@ def generate_reference(
         prompt_length = input_v.shape[1]
         finished_reference_criteria = FinishedReferenceCriteria(
             prompt_length=prompt_length,
-            end_ref_id=self.tokenizer.token_to_id("[END_REF]"),
+            end_ref_id=self.tokenizer.convert_tokens_to_ids("[END_REF]"),
         )
 
         if max_new_tokens is None and max_length is None:
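
Looking up a special token's id also changes name: token_to_id on the fast tokenizer becomes convert_tokens_to_ids on the transformers tokenizer. A one-line sketch, again using the example checkpoint name:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/galactica-1.3b")

# Old: tokenizer.token_to_id("[END_REF]")
end_ref_id = tokenizer.convert_tokens_to_ids("[END_REF]")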
@@ -399,8 +420,8 @@
             stopping_criteria=stopping_criteria,
         )
         # cut-off the prompts
-        generated_tokens = out["sequences"][:, prompt_length:].tolist()
-        decoded = self.tokenizer.decode_batch(generated_tokens, skip_special_tokens=False)
+        generated_tokens = out["sequences"][:, prompt_length:]
+        decoded = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)
         references = []
         unfinished_generation = False
         for text in decoded:
