# https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
# https://huggingface.co/docs/trl/main/en/dpo_trainer
# https://huggingface.co/datasets/lvwerra/stack-exchange-paired
# https://huggingface.co/blog/zh/dpo-trl

# https://github.dev/RUCAIBox/LLMBox
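# This script runs Direct Preference Optimization (DPO) on top of an SFT model: it loads
# the SFT checkpoint and a pairwise (chosen/rejected) preference dataset, fine-tunes with
# TRL's DPOTrainer, and saves the resulting model.
#
# Example launch (script name and paths are placeholders, not the author's actual ones):
#   python dpo_train.py --base_model_path ./sft_model \
#       --dataset_dir_or_path ./data/dpo_train.jsonl --output_dir ./results --bf16 True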
# 0. imports
import os
from dataclasses import dataclass, field
from typing import Dict, Optional

import torch
from accelerate import Accelerator
from datasets import Dataset, load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments, set_seed
from torch.utils.data import DataLoader  # only DataLoader is needed; re-importing Dataset would shadow datasets.Dataset
from trl import DPOTrainer
from tinyllm_dataset import load_dpo_dataset


# Define and parse arguments.
@dataclass
class ScriptArguments:
    """
    The arguments for the DPO training script.
    """

    # data parameters
    beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
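    # beta is the temperature of the implicit DPO reward: the loss is roughly
    # -log sigmoid(beta * (log pi(y_w|x)/pi_ref(y_w|x) - log pi(y_l|x)/pi_ref(y_l|x))),
    # so a larger beta keeps the trained policy closer to the reference (SFT) model.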
32
+
33
+ # training parameters
34
+ model_name : Optional [str ] = field (default = "" ,metadata = {"help" : "the location of the SFT model name or path" })
35
+ dataset_dir_or_path : Optional [str ] = field (default = "" ,metadata = {"help" : "the location of the SFT model name or path" })
36
+ eval_dataset_dir_or_path : Optional [str ] = field (default = "" ,metadata = {"help" : "the location of the SFT model name or path" })
37
+ resume : Optional [bool ] = field (default = False ,metadata = {"help" : "the location of the SFT model name or path" })
38
+ base_model_path : Optional [str ] = field (default = "" ,metadata = {"help" : "the location of the SFT model name or path" })
39
+
40
+ learning_rate : Optional [float ] = field (default = 5e-4 , metadata = {"help" : "optimizer learning rate" })
41
+ lr_scheduler_type : Optional [str ] = field (default = "cosine" , metadata = {"help" : "the lr scheduler type" })
42
+ warmup_ratio : Optional [float ] = field (default = 0.01 , metadata = {"help" : "the number of warmup steps" })
43
+ weight_decay : Optional [float ] = field (default = 0.05 , metadata = {"help" : "the weight decay" })
44
+ optimizer_type : Optional [str ] = field (default = "adamw_torch" , metadata = {"help" : "the optimizer type" })
45
+
    per_device_train_batch_size: Optional[int] = field(default=1, metadata={"help": "train batch size per device"})
    per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "eval batch size per device"})
    gradient_accumulation_steps: Optional[int] = field(
        default=4, metadata={"help": "the number of gradient accumulation steps"}
    )
    gradient_checkpointing: Optional[bool] = field(
        default=True, metadata={"help": "whether to use gradient checkpointing"}
    )

    gradient_checkpointing_use_reentrant: Optional[bool] = field(
        default=False, metadata={"help": "whether to use reentrant for gradient checkpointing"}
    )

    bf16: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA"
                " architecture or using CPU (use_cpu) or Ascend NPU. This is an experimental API and it may change."
            )
        },
    )
    fp16: bool = field(
        default=False,
        metadata={"help": "Whether to use fp16 (mixed) precision instead of 32-bit"},
    )

    max_prompt_length: Optional[int] = field(default=512, metadata={"help": "the maximum prompt length"})
    max_length: Optional[int] = field(default=1024, metadata={"help": "the maximum sequence length"})
    num_train_epochs: Optional[int] = field(default=5, metadata={"help": "the number of training epochs"})
    logging_strategy: Optional[str] = field(default="steps", metadata={"help": "the logging strategy"})
    logging_dir: Optional[str] = field(default="", metadata={"help": "the logging directory"})
    logging_steps: Optional[int] = field(default=10, metadata={"help": "the logging frequency"})
    save_steps: Optional[int] = field(default=100, metadata={"help": "the saving frequency"})
    eval_steps: Optional[int] = field(default=100, metadata={"help": "the evaluation frequency"})

    output_dir: Optional[str] = field(default="./results", metadata={"help": "the output directory"})

    # instrumentation
    sanity_check: Optional[bool] = field(default=False, metadata={"help": "only train on 1000 samples"})
    report_to: Optional[str] = field(
        default="tensorboard",
        metadata={
            "help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,'
            ' `"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`, `"clearml"` and `"wandb"`. '
            'Use `"all"` to report to all integrations installed, `"none"` for no integrations.'
        },
    )
    # debug argument for distributed training
    ignore_bias_buffers: Optional[bool] = field(
        default=False,
        metadata={
            "help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type, inplace operation. See"
            " https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
        },
    )
    seed: Optional[int] = field(
        default=0, metadata={"help": "Random seed that will be set at the beginning of training."}
    )


if __name__ == "__main__":
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]

    set_seed(script_args.seed)

    # 1. load a pretrained model
    model = AutoModelForCausalLM.from_pretrained(
        script_args.base_model_path,
        trust_remote_code=True,
    )
    # disable the KV cache: it is not needed for training and conflicts with gradient checkpointing
    model.config.use_cache = False

    if script_args.ignore_bias_buffers:
        # torch distributed hack
        model._ddp_params_and_buffers_to_ignore = [
            name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
        ]

    tokenizer = AutoTokenizer.from_pretrained(script_args.base_model_path, trust_remote_code=True)
    # presumably the tokenizer has no dedicated pad/bos tokens, so the eos token is reused for both
    tokenizer.pad_token = tokenizer.eos_token

    tokenizer.add_special_tokens({"bos_token": tokenizer.eos_token})
    tokenizer.bos_token_id = tokenizer.eos_token_id

    # 2. Load the pairwise (chosen/rejected) DPO dataset
    # (the original Stack-Exchange paired loader is kept below for reference)
    # dpo_dataset = get_stack_exchange_paired(data_dir="data/rl", sanity_check=script_args.sanity_check)
    # dpo_dataset = dpo_dataset.filter(
    #     lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
    #     and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length
    # )

    # note: this hard-coded path is not used below; the dataset is loaded from --dataset_dir_or_path
    data_path = "/mnt/cephfs-xiongzhuang/wangdongnian/tiny-llm-zh/data/rm_train/rm_data.jsonl"
    dpo_dataset = load_dpo_dataset(script_args.dataset_dir_or_path, max_length=script_args.max_length, sanity_check=script_args.sanity_check)

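    # Quick sanity check: build a DataLoader, print one batch, and stop. Each example is
    # assumed to carry the "prompt", "chosen" and "rejected" text fields that DPOTrainer
    # expects (this depends on what tinyllm_dataset.load_dpo_dataset returns).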
    train_loader = torch.utils.data.DataLoader(
        dpo_dataset,
        batch_size=2,
        pin_memory=False,
        drop_last=False,
        shuffle=False,
        num_workers=8,
    )
    for i, item in enumerate(train_loader):
        print(item)
        break


    # 3. Load evaluation dataset
    if script_args.eval_dataset_dir_or_path == "":
        evaluation_strategy = "no"
        eval_dataset = None
    else:
        evaluation_strategy = "steps"
        eval_dataset = load_dpo_dataset(script_args.eval_dataset_dir_or_path, max_length=script_args.max_length, sanity_check=script_args.sanity_check)


    # 4. initialize training arguments:
    training_args = TrainingArguments(
        per_device_train_batch_size=script_args.per_device_train_batch_size,
        per_device_eval_batch_size=script_args.per_device_eval_batch_size,
        num_train_epochs=script_args.num_train_epochs,
        logging_dir=script_args.logging_dir,
        logging_strategy=script_args.logging_strategy,
        logging_steps=script_args.logging_steps,
        save_steps=script_args.save_steps,
        gradient_accumulation_steps=script_args.gradient_accumulation_steps,
        gradient_checkpointing=script_args.gradient_checkpointing,
        learning_rate=script_args.learning_rate,
        evaluation_strategy=evaluation_strategy,  # "steps" only when an eval dataset was provided
        eval_steps=script_args.eval_steps,
        output_dir=script_args.output_dir,
        report_to=script_args.report_to,
        lr_scheduler_type=script_args.lr_scheduler_type,
        warmup_ratio=script_args.warmup_ratio,
        optim=script_args.optimizer_type,
        bf16=script_args.bf16,
        fp16=script_args.fp16,
        remove_unused_columns=False,
        run_name=script_args.model_name,
        gradient_checkpointing_kwargs=dict(use_reentrant=script_args.gradient_checkpointing_use_reentrant),
        seed=script_args.seed,
        # project_kwargs={"logging_dir": script_args.output_dir},
    )

    # peft_config = LoraConfig(
    #     r=script_args.lora_r,
    #     lora_alpha=script_args.lora_alpha,
    #     lora_dropout=script_args.lora_dropout,
    #     target_modules=[
    #         "q_proj",
    #         "v_proj",
    #         "k_proj",
    #         "out_proj",
    #         "fc_in",
    #         "fc_out",
    #         "wte",
    #     ],
    #     bias="none",
    #     task_type="CAUSAL_LM",
    # )
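    # Note: the LoRA config above is left commented out, so this DPO stage updates all
    # model parameters; to train adapters instead, define peft_config (and the lora_*
    # script arguments it refers to) and pass peft_config to DPOTrainer below.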

    # 5. initialize the DPO trainer
    dpo_trainer = DPOTrainer(
        model,
        ref_model=None,
        args=training_args,
        beta=script_args.beta,
        train_dataset=dpo_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        # peft_config=peft_config,
        max_prompt_length=script_args.max_prompt_length,
        max_length=script_args.max_length,
        # data_collator=collator_fn,
    )
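    # With ref_model=None and no peft_config, DPOTrainer creates its own frozen copy of
    # `model` to serve as the reference policy for the DPO loss.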

    # 6. train
    dpo_trainer.train(resume_from_checkpoint=script_args.resume)


    # 7. save
    output_dir = os.path.join(script_args.output_dir, "last_dpo_model")
    dpo_trainer.save_model(output_dir)
    # dpo_trainer.model.save_pretrained(output_dir)
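    # Optional addition (not in the original script): saving the tokenizer as well makes the
    # exported directory loadable directly with AutoTokenizer.from_pretrained.
    tokenizer.save_pretrained(output_dir)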