Commit c818124

Author: Marcin Kardas
Merge pull request #3 from paperswithcode/feature/path_or_model
Feature/path or model
2 parents cdcd7c9 + 7ef4dfc commit c818124

File tree

3 files changed: 120 additions, 23 deletions


galai/__init__.py

Lines changed: 12 additions & 5 deletions
@@ -4,7 +4,7 @@
 from galai.utils import ModelInfo
 import torch
 import warnings
-
+from pathlib import Path
 
 HF_MAPPING = {
     "mini": ("facebook/galactica-125m", torch.float32),
@@ -54,12 +54,19 @@ def load_model(
     Model - model object
     """
 
-    if name not in HF_MAPPING:
+    if name in HF_MAPPING:
+        hf_model, default_dtype = HF_MAPPING[name]
+        galai_model = True
+    elif Path(name).exists():
+        hf_model = name
+        default_dtype = torch.float32
+        galai_model = False
+    else:
         raise ValueError(
-            "Invalid model name. Must be one of 'mini', 'base', 'standard', 'large', 'huge'."
+            "Invalid model name. Must be one of 'mini', 'base', 'standard', 'large', 'huge', " +
+            "a path to a local checkpoint dir, or a model name available on HuggingFace hub."
        )
 
-    hf_model, default_dtype = HF_MAPPING[name]
     if dtype is None:
         dtype = default_dtype
 
@@ -99,7 +106,7 @@ def load_model(
            UserWarning
        )
        num_gpus = available
-    if num_gpus > 1 and parallelize:
+    if num_gpus > 1 and parallelize and galai_model:
        mi = ModelInfo.by_name(name)
        if mi.num_heads % num_gpus != 0:
            raise ValueError(
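
A minimal usage sketch of the new behaviour, assuming load_model keeps name as its first argument; the local path below is only a placeholder:

import galai

# Built-in alias: resolved through HF_MAPPING to facebook/galactica-125m
# with the default dtype listed there.
model = galai.load_model("mini")

# New in this change: a local checkpoint directory (or a HuggingFace hub model
# name) is also accepted; it defaults to torch.float32 and sets galai_model to
# False, so the galai-specific tensor-parallel branch is skipped.
# model = galai.load_model("/path/to/local/checkpoint")  # placeholder path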

galai/model.py

Lines changed: 48 additions & 18 deletions
@@ -3,8 +3,7 @@
 
 import torch
 
-from tokenizers import Tokenizer
-from transformers import OPTForCausalLM, StoppingCriteriaList, StoppingCriteria
+from transformers import AutoTokenizer, OPTForCausalLM, StoppingCriteriaList, StoppingCriteria
 from parallelformers import parallelize
 import psutil
 
@@ -80,6 +79,7 @@ def __init__(
        self.is_loaded = False
        self.num_gpus = num_gpus
        self.tensor_parallel = tensor_parallel
+        self.max_input_length = 2020
        self._master_port = None
 
    def _load_checkpoint(self, checkpoint_path: str):
@@ -129,9 +129,15 @@ def _parallelize(self) -> None:
 
        self._master_port = 13000 + (id(self.model) % 32749)
 
+        custom_policies = None
+        if self.model.config.model_type == "opt" and not self.model.config.enable_bias:
+            from galai.parallel_policy import OPTDecoderLayerPolicyNoBias
+            custom_policies = [OPTDecoderLayerPolicyNoBias]
+
        parallelize(
            self.model, num_gpus=self.num_gpus, fp16=self.dtype == torch.float16,
            master_port=self._master_port,
+            custom_policies=custom_policies,
        )
 
    def _set_tokenizer(self, tokenizer_path: str):
@@ -143,9 +149,27 @@ def _set_tokenizer(self, tokenizer_path: str):
        tokenizer_path : str
            Path for the tokenizer (str)
        """
-        self.tokenizer = Tokenizer.from_pretrained(tokenizer_path)
-        self.tokenizer.enable_padding(direction="left", pad_id=1, pad_type_id=0, pad_token="[PAD]")
-        self.tokenizer.enable_truncation(max_length=2020, direction="left")
+        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+
+        # setup padding
+        tokenizer.pad_token_id = 1
+        tokenizer.pad_token = "<pad>"
+        tokenizer.padding_side = "left"
+
+        # setup truncation
+        tokenizer.truncation_side = "left"
+
+        # setup special tokens
+        tokenizer.bos_token_id = 0
+        tokenizer.bos_token = "<s>"
+
+        tokenizer.eos_token_id = 2
+        tokenizer.eos_token = "</s>"
+
+        tokenizer.unk_token = "<unk>"
+        tokenizer.unk_token_id = 3
+
+        self.tokenizer = tokenizer
 
    def _tokenize(self, input_text: List[str], new_doc: bool) -> torch.LongTensor:
        """
@@ -164,24 +188,27 @@ def _tokenize(self, input_text: List[str], new_doc: bool) -> torch.LongTensor:
            text = escape_custom_split_sequence(text)
            if not text:
                warnings.warn(
-                    "Found an empty input text. Chainging to end-of-document token instead.",
+                    "Found an empty input text. Changing to end-of-document token instead.",
                    UserWarning
                )
-                text = "</s>"
+                text = self.tokenizer.eos_token
            texts.append(text)
 
        if new_doc:
-            pad_id = self.tokenizer.padding["pad_id"]
-            pad_token = self.tokenizer.id_to_token(pad_id)
+            pad_token = self.tokenizer.pad_token
            texts = [pad_token + t for t in texts]
 
-        list_encoded = self.tokenizer.encode_batch(texts)
-        context_tokens = [encoded.ids for encoded in list_encoded]
+        encoded = self.tokenizer(
+            texts,
+            padding="longest",
+            max_length=self.max_input_length,
+            truncation=True
+        )
+        context_tokens = encoded["input_ids"]
        input_v = torch.LongTensor(context_tokens).to(self.model.device)
 
        if new_doc:
-            eos_id = self.tokenizer.token_to_id("</s>")
-            input_v[input_v[:, 0] == pad_id, 0] = eos_id
+            input_v[input_v[:, 0] == self.tokenizer.pad_token_id, 0] = self.tokenizer.eos_token_id
        return input_v
 
    @torch.inference_mode()
@@ -275,9 +302,12 @@ def generate(
        )
 
        # we keep special tokens such as [START_REF] or <work>
-        decoded = self.tokenizer.decode_batch(out['sequences'].tolist(), skip_special_tokens=False)
+        decoded = self.tokenizer.batch_decode(out['sequences'], skip_special_tokens=False)
        # so we manually remove </s> and <pad>
-        decoded = [text.replace("</s>", "").replace("<pad>", "") for text in decoded]
+        decoded = [
+            text.replace(self.tokenizer.eos_token, "").replace(self.tokenizer.pad_token, "")
+            for text in decoded
+        ]
 
        if num_return_sequences == 1:
            return decoded[0] if isinstance(input_text, str) else decoded
@@ -363,7 +393,7 @@ def generate_reference(
        prompt_length = input_v.shape[1]
        finished_reference_criteria = FinishedReferenceCriteria(
            prompt_length=prompt_length,
-            end_ref_id=self.tokenizer.token_to_id("[END_REF]"),
+            end_ref_id=self.tokenizer.convert_tokens_to_ids("[END_REF]"),
        )
 
        if max_new_tokens is None and max_length is None:
@@ -396,8 +426,8 @@ def generate_reference(
            stopping_criteria=stopping_criteria,
        )
        # cut-off the prompts
-        generated_tokens = out["sequences"][:, prompt_length:].tolist()
-        decoded = self.tokenizer.decode_batch(generated_tokens, skip_special_tokens=False)
+        generated_tokens = out["sequences"][:, prompt_length:]
+        decoded = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)
        references = []
        unfinished_generation = False
        for text in decoded:
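
And a small sketch of the decode-side switch: decode_batch on a tokenizers.Tokenizer needs plain lists of ids, whereas batch_decode on a transformers tokenizer accepts the id tensor from generate() directly. The sample text is arbitrary and the tokenizer is the mini checkpoint's, as in HF_MAPPING:

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/galactica-125m")
tokenizer.pad_token_id, tokenizer.pad_token = 1, "<pad>"
tokenizer.eos_token_id, tokenizer.eos_token = 2, "</s>"

# Stand-in for out["sequences"]: generate() returns a 2-D LongTensor of ids.
sequences = torch.LongTensor(tokenizer(["The Transformer architecture"])["input_ids"])

decoded = tokenizer.batch_decode(sequences, skip_special_tokens=False)
cleaned = [
    text.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "")
    for text in decoded
]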

galai/parallel_policy.py

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+from parallelformers.policies.base import Layer, Policy
+from parallelformers.utils.dist_utils import AllReduceLinear
+
+from transformers.models.opt.modeling_opt import OPTDecoderLayer
+
+
+__all__ = ["OPTDecoderLayerPolicyNoBias"]
+
+
+class OPTDecoderLayerPolicyNoBias(Policy):
+    @staticmethod
+    def replace_arguments(config, world_size):
+        return {
+            "self_attn.embed_dim": config.hidden_size // world_size,
+            "self_attn.num_heads": config.num_attention_heads // world_size,
+        }
+
+    @staticmethod
+    def attn_qkv():
+        return [
+            Layer(
+                weight="self_attn.q_proj.weight",
+            ),
+            Layer(
+                weight="self_attn.k_proj.weight",
+            ),
+            Layer(
+                weight="self_attn.v_proj.weight",
+            ),
+        ]
+
+    @staticmethod
+    def attn_out():
+        return [
+            Layer(
+                weight="self_attn.out_proj.weight",
+                replace=AllReduceLinear,
+            ),
+        ]
+
+    @staticmethod
+    def mlp_in():
+        return [
+            Layer(
+                weight="fc1.weight",
+            ),
+        ]
+
+    @staticmethod
+    def mlp_out():
+        return [
+            Layer(
+                weight="fc2.weight",
+                replace=AllReduceLinear,
+            ),
+        ]
+
+    @staticmethod
+    def original_layer_class():
+        return OPTDecoderLayer
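
For illustration, a rough sketch of how this policy is intended to be passed to parallelformers, mirroring the Model._parallelize change above. The checkpoint name, GPU count, and port are placeholders, and it assumes a transformers version whose OPTConfig exposes enable_bias:

import torch
from transformers import OPTForCausalLM
from parallelformers import parallelize
from galai.parallel_policy import OPTDecoderLayerPolicyNoBias

model = OPTForCausalLM.from_pretrained("facebook/galactica-125m", torch_dtype=torch.float16)

# Galactica checkpoints are expected to set enable_bias=False, in which case
# only the weight tensors are sharded (note the policy above lists no bias entries).
custom_policies = None
if model.config.model_type == "opt" and not model.config.enable_bias:
    custom_policies = [OPTDecoderLayerPolicyNoBias]

parallelize(
    model, num_gpus=2, fp16=True, master_port=13000,
    custom_policies=custom_policies,
)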
