Commit 15d0bed

[LLM] Add ChatGLM (PaddlePaddle#5543)
* [chatglm] init version
* [chatglm] model architecture
* [update] align commit 9324de7
* [chatglm] forward and generation aligned
* [chatglm] remove resource file
* [chatglm] add model state
* [chatglm] adapt for mp
* [chatglm] aligned mp forward
* [chatglm] add finetune and checkpoint to check generation
* [chatglm] fix generation shape error
* [chatglm] add prediction script
* [chatglm] fix recompute
* [chatglm] feasible export
* [chatglm] export ok, infer dtype error
* [chatglm] use dtype instead of paddle_dtype
* [chatglm] fix shape error during inference
* [chatglm] standarize modeling (to check)
* [chatglm] check succeed
* [chatglm] rename finetune script
* [chatglm] remove external type cast
* [chatglm] move modeling and tokenizer to transformers
* [chatglm] aligned with 4e8efe
* [chatglm] update readme
* [chatglm] update readme
* [chatglm] update readme
* [chatglm] update readme
1 parent 8256c7a commit 15d0bed

14 files changed

Lines changed: 2472 additions & 0 deletions

Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
# ChatGLM

ChatGLM-6B is an open-source dialogue language model that supports bilingual (Chinese and English) question answering. It is based on the [General Language Model (GLM)](https://arxiv.org/abs/2103.10360) architecture and has 6.2 billion parameters. ChatGLM-6B uses the same technology as ChatGLM and is optimized for Chinese question answering and dialogue. After bilingual training on about 1T tokens, supplemented by supervised fine-tuning, feedback bootstrapping, and reinforcement learning from human feedback, the 6.2-billion-parameter ChatGLM-6B can already generate answers that align well with human preferences.

This example provides a fine-tuning workflow for generation tasks with the ChatGLM model, applicable to the THUDM/chatglm-6b model.

## Environment Dependencies
The current version supports many features, so we recommend the paddlepaddle develop build for a better experience. The command below installs Paddle for CUDA 11.2. For other versions, see the [official website](https://www.paddlepaddle.org.cn/).
```
python -m pip install paddlepaddle-gpu==0.0.0.post112 -f https://www.paddlepaddle.org.cn/whl/linux/gpu/develop.html
```

## AdvertiseGen Advertisement Generation Task

This example is based on the AdvertiseGen advertisement generation dataset. The input is a set of keywords describing a piece of clothing, and the output is the corresponding advertising copy. The dataset can be downloaded [here](https://paddlenlp.bj.bcebos.com/datasets/examples/AdvertiseGen.tar.gz).

### Multi-GPU Training Script (Model Parallel Strategy)

```
python -m paddle.distributed.launch --gpus "0,1,2,3" finetune_generation.py \
    --model_name_or_path THUDM/chatglm-6b \
    --task_path AdvertiseGen/ \
    --max_steps 3000 \
    --learning_rate 3e-5 \
    --warmup_steps 20 \
    --eval_steps 100 \
    --logging_steps 1 \
    --save_steps 1000 \
    --save_total_limit 1 \
    --output_dir ./checkpoints/chatglm-6b \
    --src_length 64 \
    --tgt_length 64 \
    --per_device_eval_batch_size 16 \
    --per_device_train_batch_size 16 \
    --gradient_accumulation_steps 8 \
    --fp16 \
    --fp16_opt_level O2 \
    --recompute True \
    --do_train \
    --do_eval \
    --tensor_parallel_degree 2
```

The parameters are explained below:

- `model_name_or_path`: Built-in name of the pretrained model, or the directory containing the model. Defaults to `THUDM/chatglm-6b`.
- `task_path`: Directory where the dataset is stored.
- `max_steps`: Number of training steps.
- `learning_rate`: Learning rate for parameter updates.
- `warmup_steps`: Number of learning-rate warmup steps.
- `eval_steps`: Interval, in steps, between evaluations.
- `logging_steps`: Interval, in steps, between training log outputs.
- `save_steps`: Interval, in steps, between checkpoint saves.
- `save_total_limit`: Number of model checkpoints to keep.
- `output_dir`: Directory where model parameters are saved.
- `src_length`: Maximum input length of the context. Defaults to 128.
- `tgt_length`: Maximum length of the generated text. Defaults to 160.
- `gradient_accumulation_steps`: Number of steps over which gradients are accumulated; can be used to enlarge the batch size. The actual batch_size = per_device_train_batch_size * gradient_accumulation_steps.
- `fp16`: Use float16 precision for model training and inference.
- `fp16_opt_level`: float16 training mode; `O2` means pure float16 training.
- `recompute`: Use the recomputation strategy; enabling it saves GPU memory during training.
- `do_train`: Whether to train the model.
- `do_eval`: Whether to evaluate the model.
- `tensor_parallel_degree`: Degree of model parallelism.

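The batch-size relation given for `gradient_accumulation_steps` can be checked with a little arithmetic. The sketch below plugs in the values from the fine-tuning command. The data-parallel degree (GPU count divided by `tensor_parallel_degree`) and the resulting global batch size are assumptions about how the four GPUs are split, not something this README states.

```python
# Values taken from the fine-tuning command in this README.
per_device_train_batch_size = 16
gradient_accumulation_steps = 8
num_gpus = 4
tensor_parallel_degree = 2

# Batch size per optimizer step, using the formula given in the README.
effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps

# ASSUMPTION: GPUs not consumed by tensor parallelism act as data-parallel replicas.
data_parallel_degree = num_gpus // tensor_parallel_degree
global_batch_size = effective_batch_size * data_parallel_degree

print(effective_batch_size)  # 128
print(global_batch_size)     # 256
```
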
## Model Prediction

You can compare the model's Python forward pass against the inference results:

```
python predict_generation.py \
    --model_name_or_path ./checkpoints/chatglm-6b
```

When a checkpoint was saved with `tensor parallel` in a multi-shard format, this script can still be used for prediction, or the shards can be merged into a single set of weights. For example, the model below was saved as four shards:

```
(base) root@localhost glm $ ll ./checkpoints/chatglm-6b/checkpoint-100/
total 82G
drwxr-xr-x 2 root root 4.0K Apr 16 22:41 ./
drwxr-xr-x 4 root root 4.0K Apr 16 22:41 ../
-rw-r--r-- 1 root root 811 Apr 16 22:40 config.json
-rw-r--r-- 1 root root 2.6M Apr 16 22:40 ice_text.model
-rw-r--r-- 1 root root 3.2G Apr 16 22:40 model_state.tp00.pdparams
-rw-r--r-- 1 root root 3.2G Apr 16 22:40 model_state.tp01.pdparams
-rw-r--r-- 1 root root 3.2G Apr 16 22:40 model_state.tp02.pdparams
-rw-r--r-- 1 root root 3.2G Apr 16 22:40 model_state.tp03.pdparams
```

Run the following command to merge the model into a single shard and save it.

```
python -m paddle.distributed.launch --gpus 0,1,2,3 predict_generation.py \
    --model_name_or_path ./checkpoints/chatglm-6b/checkpoint-100/ \
    --merge_tensor_parallel_path ./checkpoints/chatglm-merged
```

The `merge_tensor_parallel_path` parameter specifies where the merged model parameters are stored. If it is not set, the script only runs the forward pass.

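Before merging, it can help to confirm how many tensor-parallel shards a checkpoint directory holds, since the launch command uses one GPU per shard. The sketch below is a hypothetical helper, not part of `predict_generation.py`; the filename pattern is inferred from the directory listing in this README and should be treated as an assumption.

```python
import re

# ASSUMPTION: shard names follow the pattern seen in the listing in this README.
TP_SHARD_PATTERN = re.compile(r"^model_state\.tp(\d+)\.pdparams$")


def find_tp_shards(filenames):
    """Return tensor-parallel shard file names sorted by rank (hypothetical helper)."""
    shards = []
    for name in filenames:
        match = TP_SHARD_PATTERN.match(name)
        if match:
            shards.append((int(match.group(1)), name))
    return [name for _, name in sorted(shards)]


# File names copied from the checkpoint-100 listing, deliberately out of order.
files = [
    "config.json",
    "ice_text.model",
    "model_state.tp01.pdparams",
    "model_state.tp00.pdparams",
    "model_state.tp03.pdparams",
    "model_state.tp02.pdparams",
]
shards = find_tp_shards(files)
print(len(shards))  # 4
print(shards[0])    # model_state.tp00.pdparams
```

In practice the file names would come from `os.listdir` on the checkpoint directory.
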
## Model Export

After training, the following script exports the model parameters as a static graph for inference.

```
python export_generation_model.py \
    --model_name_or_path ./checkpoints/chatglm-6b \
    --output_path ./checkpoints/infer/chatglm \
    --dtype "float32"
```

The parameters are defined as follows:

- `model_name_or_path`: Built-in name of the pretrained model, or the directory containing the model.
- `output_path`: Directory and file prefix for the exported model. In the example, the export directory is `./checkpoints/infer` and the model prefix is `chatglm`.
- `dtype`: Data type of the model parameters. Defaults to `float32`; valid values are `float16` and `float32`.

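The way `output_path` packs a directory and a file prefix into one value can be shown with the standard library. This is a minimal sketch using the path from the export command above; the variable names `export_dir` and `model_prefix` are illustrative only.

```python
import os

# Value passed as --output_path in the export command.
output_path = "./checkpoints/infer/chatglm"

# Directory part: where the exported files (and the tokenizer) end up.
export_dir = os.path.dirname(output_path)
# Last path component: the file prefix of the exported model.
model_prefix = os.path.basename(output_path)

print(export_dir)    # ./checkpoints/infer
print(model_prefix)  # chatglm
```

These two parts match the `--model_path` and `--model_prefix` values passed to `infer_generation.py` in the inference command.
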
## Model Inference (C++ Inference)

**Environment Dependencies**

Model inference depends on the latest version of FastDeploy, which can be installed with the following command:

```
# GPU installation
pip install fastdeploy-gpu-python==0.0.0 -f https://www.paddlepaddle.org.cn/whl/fastdeploy_nightly_build.html
```

Run the following command to perform inference with the static graph.

```
python infer_generation.py \
    --model_path ./checkpoints/infer \
    --model_prefix chatglm
```
Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json

import numpy as np


def read_local_dataset(path):
    # Each line of the file is one JSON-encoded example.
    with open(path, "r", encoding="utf-8") as fp:
        for line in fp:
            yield json.loads(line.strip())


def convert_example(example, tokenizer, data_args, is_test=True):
    query = example["content"]
    response = example["summary"]
    history = example.get("history", None)

    # Build the prompt; multi-turn history is flattened into "[Round i]" blocks
    # ("问" = question, "答" = answer).
    if history is None or len(history) == 0:
        prompt = query
    else:
        prompt = ""
        for i, (old_query, old_response) in enumerate(history):
            prompt += "[Round {}]\n问：{}\n答：{}\n".format(i, old_query, old_response)
        prompt += "[Round {}]\n问：{}\n答：".format(len(history), query)

    # dataset for evaluation
    if is_test:
        inputs = {
            **tokenizer(prompt, max_length=data_args.src_length, truncation=True, padding="max_length"),
            "labels": tokenizer(response, max_length=data_args.tgt_length, truncation=True, padding="max_length")[
                "input_ids"
            ],
        }
    # dataset for training
    else:
        src_ids = tokenizer(
            prompt,
            add_special_tokens=False,
            max_length=data_args.src_length - 1,
            truncation=True,
            truncation_side="left",
        )["input_ids"]
        tgt_ids = tokenizer(
            response,
            add_special_tokens=False,
            max_length=data_args.tgt_length - 2,
            truncation=True,
            truncation_side="right",
        )["input_ids"]

        input_ids = tokenizer.build_inputs_with_special_tokens(src_ids, tgt_ids)

        # Everything before the BOS token is treated as context (the prompt).
        context_length = input_ids.index(tokenizer.bos_token_id)
        mask_position = context_length - 1

        # Causal mask, except that every position may attend to the full context.
        attention_mask = np.tri(len(input_ids), len(input_ids))
        attention_mask[:, :context_length] = 1
        attention_mask = attention_mask[None, :, :]
        # After inversion, 1 marks a position that must not be attended to.
        attention_mask = (attention_mask < 0.5).astype("int64")

        # Context positions get -100 so they can be excluded from the loss.
        labels = [-100] * context_length + input_ids[mask_position + 1 :]

        inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }
    return inputs
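
The attention mask and labels built in the training branch of `convert_example` can be checked on a toy sequence. The sketch below repeats the same NumPy steps outside the tokenizer; the token ids and the helper name `build_mask_and_labels` are hypothetical, and in the real code the layout of `input_ids` comes from `tokenizer.build_inputs_with_special_tokens`.

```python
import numpy as np


def build_mask_and_labels(input_ids, bos_token_id):
    """Repeat the mask/label steps of the training branch on plain lists."""
    context_length = input_ids.index(bos_token_id)
    seq_len = len(input_ids)

    # Lower-triangular matrix (diagonal included): 1 where attention is allowed.
    allowed = np.tri(seq_len, seq_len)
    # Every position may attend to the whole context.
    allowed[:, :context_length] = 1
    # Add a leading axis, then invert so that 1 marks a blocked position.
    attention_mask = (allowed[None, :, :] < 0.5).astype("int64")

    # Context positions are set to -100; the rest keep their token ids.
    labels = [-100] * context_length + input_ids[context_length:]
    return attention_mask, labels


# Hypothetical ids: two context tokens, BOS (id 9), then two response tokens.
toy_ids = [11, 12, 9, 21, 22]
mask, labels = build_mask_and_labels(toy_ids, bos_token_id=9)
print(mask[0])
# [[0 0 1 1 1]
#  [0 0 1 1 1]
#  [0 0 0 1 1]
#  [0 0 0 0 1]
#  [0 0 0 0 0]]
print(labels)  # [-100, -100, 9, 21, 22]
```

The first two columns are all zeros because the context is visible to every position, while the remaining block is causal.
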
Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import paddle

from paddlenlp.transformers import ChatGLMForConditionalGeneration, ChatGLMTokenizer


def parse_args():
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--model_name_or_path",
        default="THUDM/chatglm-6b",
        type=str,
        # required=True,
        help="Path of the trained model to be exported.",
    )
    parser.add_argument(
        "--output_path",
        default="inference/chatglm",
        type=str,
        # required=True,
        help="The output file prefix used to save the exported inference model.",
    )
    parser.add_argument("--dtype", default="float32", type=str, help="The data type of exported model")
    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    paddle.set_default_dtype(args.dtype)

    tokenizer = ChatGLMTokenizer.from_pretrained(args.model_name_or_path)
    model = ChatGLMForConditionalGeneration.from_pretrained(
        args.model_name_or_path, load_state_as_np=True, dtype=args.dtype
    )

    model.eval()
    input_spec = [
        paddle.static.InputSpec(shape=[None, None], dtype="int64"),  # input_ids
        paddle.static.InputSpec(shape=[None, None, None, None], dtype="int64"),  # attention_mask
        paddle.static.InputSpec(shape=[None, None, None], dtype="int64"),  # position_ids
        # max_length
        128,
        # min_length
        0,
        # decode_strategy
        "sampling",
        # temperature
        1.0,
        # top_k
        1,
        # top_p
        1.0,
        # repetition_penalty
        1,
        # num_beams
        1,
        # num_beam_groups
        1,
        # length_penalty
        0.0,
        # early_stopping
        False,
        # bos_token_id
        tokenizer.eos_token_id,
        # eos_token_id
        tokenizer.end_token_id,
        # pad_token_id
        tokenizer.pad_token_id,
        # decoder_start_token_id
        None,
        # forced_bos_token_id
        None,
        # forced_eos_token_id
        None,
        # no_repeat_ngram_size
        None,
        # num_return_sequences
        1,
        # diversity_rate
        0.0,
        # use_cache
        True,
    ]
    model = paddle.jit.to_static(model.generate, input_spec=input_spec)

    print("jit.to_static")
    # Save converted static graph model
    paddle.jit.save(model, args.output_path)
    print("jit.save")
    # Also save tokenizer for inference usage
    tokenizer.save_pretrained(os.path.dirname(args.output_path))
    print("save_pretrained")


if __name__ == "__main__":
    main()
