Commit d1a74cc

dpo train
1 parent f23a46a commit d1a74cc

2 files changed: +466 -0 lines changed

dpo_train.py

Lines changed: 230 additions & 0 deletions
@@ -0,0 +1,230 @@
# https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
# https://huggingface.co/docs/trl/main/en/dpo_trainer
# https://huggingface.co/datasets/lvwerra/stack-exchange-paired
# https://huggingface.co/blog/zh/dpo-trl

# https://github.dev/RUCAIBox/LLMBox

# 0. imports
import os
from dataclasses import dataclass, field
from typing import Optional

import torch
from peft import LoraConfig  # only needed if the commented-out LoRA config below is re-enabled
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments, set_seed
from trl import DPOTrainer

from tinyllm_dataset import load_dpo_dataset

# Define and parse arguments.
@dataclass
class ScriptArguments:
    """
    The arguments for the DPO training script.
    """

    # data parameters
    beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})

    # training parameters
    model_name: Optional[str] = field(default="", metadata={"help": "the model name (also used as the run name for logging)"})
    dataset_dir_or_path: Optional[str] = field(default="", metadata={"help": "the location of the DPO training dataset"})
    eval_dataset_dir_or_path: Optional[str] = field(default="", metadata={"help": "the location of the DPO evaluation dataset (empty string disables evaluation)"})
    resume: Optional[bool] = field(default=False, metadata={"help": "whether to resume training from the last checkpoint"})
    base_model_path: Optional[str] = field(default="", metadata={"help": "the location of the SFT model name or path"})

    learning_rate: Optional[float] = field(default=5e-4, metadata={"help": "optimizer learning rate"})
    lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "the lr scheduler type"})
    warmup_ratio: Optional[float] = field(default=0.01, metadata={"help": "the warmup ratio"})
    weight_decay: Optional[float] = field(default=0.05, metadata={"help": "the weight decay"})
    optimizer_type: Optional[str] = field(default="adamw_torch", metadata={"help": "the optimizer type"})

    per_device_train_batch_size: Optional[int] = field(default=1, metadata={"help": "train batch size per device"})
    per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "eval batch size per device"})
    gradient_accumulation_steps: Optional[int] = field(
        default=4, metadata={"help": "the number of gradient accumulation steps"}
    )
    gradient_checkpointing: Optional[bool] = field(
        default=True, metadata={"help": "whether to use gradient checkpointing"}
    )

    gradient_checkpointing_use_reentrant: Optional[bool] = field(
        default=False, metadata={"help": "whether to use reentrant for gradient checkpointing"}
    )

    bf16: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA"
                " architecture or using CPU (use_cpu) or Ascend NPU. This is an experimental API and it may change."
            )
        },
    )
    fp16: bool = field(
        default=False,
        metadata={"help": "Whether to use fp16 (mixed) precision instead of 32-bit"},
    )

    max_prompt_length: Optional[int] = field(default=512, metadata={"help": "the maximum prompt length"})
    max_length: Optional[int] = field(default=1024, metadata={"help": "the maximum sequence length"})
    num_train_epochs: Optional[int] = field(default=5, metadata={"help": "the number of training epochs"})
    logging_strategy: Optional[str] = field(default="steps", metadata={"help": "the logging strategy"})
    logging_dir: Optional[str] = field(default="", metadata={"help": "the logging directory"})
    logging_steps: Optional[int] = field(default=10, metadata={"help": "the logging frequency"})
    save_steps: Optional[int] = field(default=100, metadata={"help": "the saving frequency"})
    eval_steps: Optional[int] = field(default=100, metadata={"help": "the evaluation frequency"})

    output_dir: Optional[str] = field(default="./results", metadata={"help": "the output directory"})

    # instrumentation
    sanity_check: Optional[bool] = field(default=False, metadata={"help": "only train on 1000 samples"})
    report_to: Optional[str] = field(
        default="tensorboard",
        metadata={
            "help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,'
            ' `"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`, `"clearml"` and `"wandb"`. '
            'Use `"all"` to report to all integrations installed, `"none"` for no integrations.'
        },
    )
    # debug argument for distributed training
    ignore_bias_buffers: Optional[bool] = field(
        default=False,
        metadata={
            "help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type, inplace operation. See"
            " https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
        },
    )
    seed: Optional[int] = field(
        default=0, metadata={"help": "Random seed that will be set at the beginning of training."}
    )


if __name__ == "__main__":
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]

    set_seed(script_args.seed)

    # 1. load a pretrained model
    model = AutoModelForCausalLM.from_pretrained(
        script_args.base_model_path,
        trust_remote_code=True,
    )
    model.config.use_cache = False  # the KV cache is not needed during training and conflicts with gradient checkpointing

    if script_args.ignore_bias_buffers:
        # torch distributed hack
        model._ddp_params_and_buffers_to_ignore = [
            name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
        ]

    tokenizer = AutoTokenizer.from_pretrained(script_args.base_model_path, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token  # reuse EOS for padding since the SFT tokenizer defines no pad token
    # also register EOS as the BOS token so prompt tokenization has a valid bos_token_id
    tokenizer.add_special_tokens({"bos_token": tokenizer.eos_token})
    tokenizer.bos_token_id = tokenizer.eos_token_id

    # 2. Load the paired preference dataset
    # dpo_dataset = get_stack_exchange_paired(data_dir="data/rl", sanity_check=script_args.sanity_check)
    # dpo_dataset = dpo_dataset.filter(
    #     lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
    #     and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length
    # )

    # data_path = "/mnt/cephfs-xiongzhuang/wangdongnian/tiny-llm-zh/data/rm_train/rm_data.jsonl"  # leftover example path; the actual path comes from --dataset_dir_or_path
    dpo_dataset = load_dpo_dataset(script_args.dataset_dir_or_path, max_length=script_args.max_length, sanity_check=script_args.sanity_check)

    # quick sanity check: peek at one raw batch before handing the dataset to the trainer
    train_loader = torch.utils.data.DataLoader(
        dpo_dataset,
        batch_size=2,
        pin_memory=False,
        drop_last=False,
        shuffle=False,
        num_workers=8,
    )
    for i, item in enumerate(train_loader):
        print(item)
        break

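    # NOTE: `load_dpo_dataset` lives in tinyllm_dataset, which is not part of this diff, so the
    # exact record format is an assumption here. DPOTrainer expects preference pairs with
    # "prompt", "chosen" and "rejected" text fields (per the TRL docs), so a minimal hand-built
    # dataset of the same shape would look like this hypothetical sketch:
    #
    #   from datasets import Dataset
    #   toy_dpo_dataset = Dataset.from_dict({
    #       "prompt": ["Question: What is 1 + 1?\nAnswer: "],
    #       "chosen": ["1 + 1 equals 2."],
    #       "rejected": ["1 + 1 equals 3."],
    #   })
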
    # 3. Load the evaluation dataset (optional)
    if script_args.eval_dataset_dir_or_path == "":
        evaluation_strategy = "no"
        eval_dataset = None  # no eval set was given, so pass None to the trainer below
    else:
        evaluation_strategy = "steps"
        eval_dataset = load_dpo_dataset(script_args.eval_dataset_dir_or_path, max_length=script_args.max_length, sanity_check=script_args.sanity_check)


    # 4. initialize training arguments:
    training_args = TrainingArguments(
        per_device_train_batch_size=script_args.per_device_train_batch_size,
        per_device_eval_batch_size=script_args.per_device_eval_batch_size,
        num_train_epochs=script_args.num_train_epochs,
        logging_dir=script_args.logging_dir,
        logging_strategy=script_args.logging_strategy,
        logging_steps=script_args.logging_steps,
        save_steps=script_args.save_steps,
        gradient_accumulation_steps=script_args.gradient_accumulation_steps,
        gradient_checkpointing=script_args.gradient_checkpointing,
        learning_rate=script_args.learning_rate,
        weight_decay=script_args.weight_decay,
        evaluation_strategy=evaluation_strategy,  # "steps" only when an eval dataset was provided
        eval_steps=script_args.eval_steps,
        output_dir=script_args.output_dir,
        report_to=script_args.report_to,
        lr_scheduler_type=script_args.lr_scheduler_type,
        warmup_ratio=script_args.warmup_ratio,
        optim=script_args.optimizer_type,
        bf16=script_args.bf16,
        fp16=script_args.fp16,
        remove_unused_columns=False,
        run_name=script_args.model_name,
        gradient_checkpointing_kwargs=dict(use_reentrant=script_args.gradient_checkpointing_use_reentrant),
        seed=script_args.seed,
        # project_kwargs={"logging_dir": script_args.output_dir},
    )

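    # A hedged note on the effective batch size: with the defaults above on a single GPU it is
    # per_device_train_batch_size (1) * gradient_accumulation_steps (4) * n_gpus (1) = 4
    # preference pairs per optimizer step; multiply by the number of processes for multi-GPU runs.
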
    # peft_config = LoraConfig(
    #     r=script_args.lora_r,
    #     lora_alpha=script_args.lora_alpha,
    #     lora_dropout=script_args.lora_dropout,
    #     target_modules=[
    #         "q_proj",
    #         "v_proj",
    #         "k_proj",
    #         "out_proj",
    #         "fc_in",
    #         "fc_out",
    #         "wte",
    #     ],
    #     bias="none",
    #     task_type="CAUSAL_LM",
    # )

    # 5. initialize the DPO trainer
    dpo_trainer = DPOTrainer(
        model,
        ref_model=None,
        args=training_args,
        beta=script_args.beta,
        train_dataset=dpo_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        # peft_config=peft_config,
        max_prompt_length=script_args.max_prompt_length,
        max_length=script_args.max_length,
        # data_collator=collator_fn,
    )

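    # NOTE: a hedged reference sketch of the objective DPOTrainer optimizes with the `beta`
    # passed above (per the DPO paper and the TRL docs), shown here only as a comment:
    #
    #   loss = -log sigmoid( beta * ( (log pi(y_w | x) - log pi_ref(y_w | x))
    #                               - (log pi(y_l | x) - log pi_ref(y_l | x)) ) )
    #
    # where y_w / y_l are the chosen / rejected responses. With ref_model=None and no PEFT
    # config, TRL builds the frozen reference policy pi_ref from a copy of `model`.
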
    # 6. train
    dpo_trainer.train(resume_from_checkpoint=script_args.resume)

    # 7. save
    output_dir = os.path.join(script_args.output_dir, "last_dpo_model")
    dpo_trainer.save_model(output_dir)
    # dpo_trainer.model.save_pretrained(output_dir)
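
# Example launch (a hedged sketch, not part of this commit): the flag names come from the
# ScriptArguments fields above via HfArgumentParser, the paths are placeholders, and a
# multi-GPU run would typically be wrapped with `torchrun` or `accelerate launch` instead.
#
#   python dpo_train.py \
#       --base_model_path ./sft_model \
#       --dataset_dir_or_path ./data/dpo_train.jsonl \
#       --output_dir ./results/dpo \
#       --bf16 True \
#       --per_device_train_batch_size 1 \
#       --gradient_accumulation_steps 4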
