- 🚀 Differentiated Fine-Tuning (DFT) Loss: An innovative custom loss function that dynamically adjusts loss weights based on the model's confidence (`p_correct`) in predicting the correct token. This lets the model focus on consolidating "well-learned" knowledge while avoiding being misled by "overly difficult" samples, resulting in more stable and efficient convergence.
- ⚡ High Performance and Efficiency:
  - Distributed Training: Deep integration with DeepSpeed ZeRO-3, enabling large-scale model training on multiple GPUs while optimizing memory usage.
  - Flash Attention 2: Built-in support for `flash_attention_2` significantly boosts training speed and efficiency for long sequences (e.g., 8K+ tokens).
  - Gradient Checkpointing: Effectively reduces memory consumption during training.
  - BF16/FP16 Mixed Precision: Accelerates training while maintaining model performance.
- 📦 Robust Data Processing:
  - ChatML Format: Designed specifically for handling dialogue data in the ChatML format.
  - Multi-Source Data Fusion: Automatically loads, validates, and merges training data from multiple JSON files, handling schema inconsistencies.
  - Efficient Preprocessing: Supports multi-process data preprocessing for faster data preparation.
- 📝 Comprehensive Logging and Monitoring:
  - Distributed Logging: A custom distributed logger ensures clear, non-redundant logs in multi-GPU settings.
  - Training Metrics Monitoring: Deep integration with WandB or SwanLab for real-time tracking of `loss`, `grad_norm`, and custom metrics like `avg_p_correct`.
- 🔧 Flexible Configuration: All training parameters (model, data, DFT parameters, training settings, etc.) are configured via the command line for clarity and ease of use.
The standard Supervised Fine-Tuning (SFT) loss is the token-level cross-entropy, taken in expectation over expert data pairs $(x, y^*) \sim \mathcal{D}$:

$$\mathcal{L}_{\text{SFT}}(\theta) = -\,\mathbb{E}_{(x,\, y^*) \sim \mathcal{D}}\left[\log \pi_\theta(y^* \mid x)\right]$$

- $x$: Input (e.g., question, instruction)
- $y^*$: Expert answer (ground-truth label)
- $\pi_\theta(y^* \mid x)$: Probability of output $y^*$ under model parameters $\theta$
RL aims to maximize the expected reward:

$$J(\theta) = \mathbb{E}_{x \sim \mathcal{D},\; y \sim \pi_\theta(\cdot \mid x)}\left[r(x, y)\right]$$

The policy gradient is:

$$\nabla_\theta J(\theta) = \mathbb{E}_{x \sim \mathcal{D},\; y \sim \pi_\theta(\cdot \mid x)}\left[r(x, y)\, \nabla_\theta \log \pi_\theta(y \mid x)\right]$$
SFT's expectation is over $(x, y^*) \sim \mathcal{D}$, while RL's is over $y \sim \pi_\theta(\cdot \mid x)$. We aim to write SFT's gradient as "sampled from $\pi_\theta$ with a weighting".

Key trick: rewrite the expectation using importance sampling:

$$\mathbb{E}_{y \sim p(\cdot \mid x)}\left[f(y)\right] = \mathbb{E}_{y \sim \pi_\theta(\cdot \mid x)}\left[\frac{p(y \mid x)}{\pi_\theta(y \mid x)}\, f(y)\right]$$

In SFT, $p$ is the "expert distribution". For the discrete dataset $\mathcal{D}$:

$$p(y \mid x) = \mathbb{1}[y = y^*]$$

Thus:

$$\nabla_\theta \mathcal{L}_{\text{SFT}}(\theta) = -\,\mathbb{E}_{x \sim \mathcal{D},\; y \sim \pi_\theta(\cdot \mid x)}\left[\frac{\mathbb{1}[y = y^*]}{\pi_\theta(y \mid x)}\, \nabla_\theta \log \pi_\theta(y \mid x)\right]$$

Define:

$$r(x, y) = \mathbb{1}[y = y^*], \qquad w(x, y) = \frac{1}{\pi_\theta(y \mid x)}$$

So the above is equivalent to:

$$\nabla_\theta \mathcal{L}_{\text{SFT}}(\theta) = -\,\mathbb{E}_{x \sim \mathcal{D},\; y \sim \pi_\theta(\cdot \mid x)}\left[w(x, y)\, r(x, y)\, \nabla_\theta \log \pi_\theta(y \mid x)\right]$$

This rewrites the SFT gradient in the RL policy-gradient form, differing only in the definitions of $r$ and $w$.
Note: only when the model generates the expert answer $y^*$ does it receive reward 1; otherwise the reward is 0. However, this reward is amplified by the factor $\frac{1}{\pi_\theta(y^* \mid x)}$.

- If $\pi_\theta(y^* \mid x)$ is small, then $\frac{1}{\pi_\theta(y^* \mid x)}$ is large, leading to gradient explosion, unstable optimization, and poor generalization.
Core idea: since the $\frac{1}{\pi_\theta(y^* \mid x)}$ factor causes instability, directly multiply the loss/gradient by $\pi_\theta(y^* \mid x)$ to "cancel" its effect.

Let the corrected (DFT) gradient be the SFT gradient rescaled by the model's own probability of the expert answer:

$$\nabla_\theta \mathcal{L}_{\text{DFT}}(\theta) = -\,\mathbb{E}_{(x,\, y^*) \sim \mathcal{D}}\left[\pi_\theta(y^* \mid x)\, \nabla_\theta \log \pi_\theta(y^* \mid x)\right]$$

Expanded in the RL form above, the implicit weight becomes $\pi_\theta(y^* \mid x) \cdot \frac{1}{\pi_\theta(y^* \mid x)} = 1$, so the amplification is exactly neutralized.

Since the rescaling factor must be treated as a constant with respect to $\theta$ (a stop-gradient, written $\mathrm{sg}[\cdot]$), which gives $\nabla_\theta\!\left(\mathrm{sg}[\pi_\theta]\, \log \pi_\theta\right) = \mathrm{sg}[\pi_\theta]\, \nabla_\theta \log \pi_\theta$, the DFT loss is:

$$\mathcal{L}_{\text{DFT}}(\theta) = -\,\mathbb{E}_{(x,\, y^*) \sim \mathcal{D}}\left[\mathrm{sg}\!\left[\pi_\theta(y^* \mid x)\right] \log \pi_\theta(y^* \mid x)\right]$$

For NLP, DFT is extended to the token level as:

$$\mathcal{L}_{\text{DFT}}(\theta) = -\,\mathbb{E}_{(x,\, y^*) \sim \mathcal{D}}\left[\sum_{t=1}^{|y^*|} \mathrm{sg}\!\left[\pi_\theta\!\left(y^*_t \mid y^*_{<t},\, x\right)\right] \log \pi_\theta\!\left(y^*_t \mid y^*_{<t},\, x\right)\right]$$
In summary:

- Start from the SFT cross-entropy loss and its gradient.
- Rewrite the SFT gradient as an expectation over the model's own policy distribution using importance sampling.
- Recognize that SFT is equivalent to RL with a sparse 0/1 reward, amplified by $\frac{1}{\pi_\theta(y^* \mid x)}$.
- Analyze the instability and generalization issues caused by this amplification.
- Propose DFT: multiply by $\pi_\theta(y^* \mid x)$ to neutralize the amplification and correct the loss.
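To make the stop-gradient weighting concrete, here is a minimal, self-contained PyTorch sketch of the token-level DFT loss exactly as written above, with $\mathrm{sg}[\cdot]$ realized via `torch.no_grad()`/detached probabilities. This is an illustration of the formula, not necessarily this repository's exact implementation; the trainer in this repo additionally smooths the weight with a `dft_alpha` parameter, described in the next section.

```python
import torch
import torch.nn.functional as F

def dft_loss(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
    """Token-level DFT loss: per-token cross-entropy weighted by the detached
    probability of the correct token (the sg[pi] factor in the derivation)."""
    # Shift so that position t predicts token t+1
    logits = logits[:, :-1, :].contiguous()
    labels = labels[:, 1:].contiguous()

    vocab_size = logits.size(-1)
    flat_logits = logits.view(-1, vocab_size)
    flat_labels = labels.view(-1)
    valid = flat_labels != ignore_index

    # -log pi_theta(y*_t | y*_<t, x), unreduced
    ce = F.cross_entropy(flat_logits, flat_labels.clamp(min=0), reduction="none")

    # sg[pi_theta(y*_t | ...)]: detached probability of the correct token
    with torch.no_grad():
        probs = F.softmax(flat_logits, dim=-1)
        p_correct = probs.gather(1, flat_labels.clamp(min=0).unsqueeze(-1)).squeeze(-1)

    # Weight each token's loss by its (detached) correctness probability
    return (p_correct * ce)[valid].sum() / valid.sum().clamp(min=1)
```

Using `p_correct` directly as the weight recovers the pure DFT loss; the `dft_alpha` interpolation used by this project (below) is a softened variant of the same idea.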
The core idea of the DFT loss: let the model focus more on the tokens it is already learning correctly and confidently.

Traditional cross-entropy treats all tokens equally. DFT adjusts this behavior via a `dft_alpha` parameter:
```python
# 1. Compute the model's probability for the correct token, p_correct
with torch.no_grad():
    probs = F.softmax(shift_logits_flat, dim=-1)
    p_correct = probs.gather(1, correct_labels.unsqueeze(-1)).squeeze(-1)

# 2. Compute the loss weight from p_correct and dft_alpha
#    If p_correct -> 1 (very confident),  dft_weight -> 1
#    If p_correct -> 0 (not confident),   dft_weight -> (1 - dft_alpha)
#    High p_correct => high weight => reinforce "easy / already learned" tokens
#    (like self-paced learning, but preferring easy samples)
#    Low p_correct  => low weight  => downweight hard samples' gradients
dft_weight = p_correct * self.dft_alpha + (1 - self.dft_alpha)

# 3. Apply the weight to the original per-token loss
loss_flat = original_loss_flat * dft_weight

# 4. Compute the final mean loss over valid tokens
loss = loss_flat.sum() / num_valid_tokens
```

In this way, the less confident (hard) samples contribute less to the loss, preventing them from dominating the gradients and disrupting convergence.
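The snippet above is a fragment: names such as `shift_logits_flat`, `correct_labels`, `original_loss_flat`, and `num_valid_tokens` come from the surrounding trainer code. Below is a hypothetical sketch of how these pieces could be wired into a `transformers.Trainer` subclass; the class name `DFTWeightedTrainer` and the exact wiring are illustrative assumptions, not the actual contents of `train_dft_fixed.py`.

```python
import torch
import torch.nn.functional as F
from transformers import Trainer

class DFTWeightedTrainer(Trainer):
    """Illustrative Trainer subclass applying the alpha-smoothed DFT weight."""

    def __init__(self, *args, dft_alpha: float = 0.7, **kwargs):
        super().__init__(*args, **kwargs)
        self.dft_alpha = dft_alpha

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits

        # Shift so position t predicts token t+1, then flatten
        shift_logits_flat = logits[:, :-1, :].contiguous().view(-1, logits.size(-1))
        shift_labels_flat = labels[:, 1:].contiguous().view(-1)

        valid_mask = shift_labels_flat != -100
        correct_labels = shift_labels_flat.clamp(min=0)

        # Per-token cross-entropy (unreduced), zeroed on ignored positions
        original_loss_flat = F.cross_entropy(
            shift_logits_flat, correct_labels, reduction="none"
        ) * valid_mask

        # DFT weight: interpolate between 1 and p_correct via dft_alpha
        with torch.no_grad():
            probs = F.softmax(shift_logits_flat, dim=-1)
            p_correct = probs.gather(1, correct_labels.unsqueeze(-1)).squeeze(-1)
        dft_weight = p_correct * self.dft_alpha + (1 - self.dft_alpha)

        loss_flat = original_loss_flat * dft_weight
        num_valid_tokens = valid_mask.sum().clamp(min=1)
        loss = loss_flat.sum() / num_valid_tokens
        return (loss, outputs) if return_outputs else loss
```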
- Clone the repository:

  ```bash
  git clone https://github.com/your-username/your-repo-name.git
  cd your-repo-name
  ```

- (Recommended) Create a conda virtual environment:

  ```bash
  conda create -n dft_trainer python=3.10
  conda activate dft_trainer
  ```

- Install the dependencies. Make sure your environment has a PyTorch build matching your CUDA version.

  ```text
  # requirements.txt

  # Core dependencies
  torch>=2.1.0
  transformers>=4.40.0
  datasets>=2.18.0
  deepspeed>=0.14.0

  # Acceleration and efficiency
  accelerate>=0.29.0
  flash-attn>=2.5.0   # may need: pip install flash-attn --no-build-isolation

  # Utilities
  sentencepiece   # for tokenization
  protobuf        # for tokenization

  # Experiment tracking (optional)
  swanlab
  wandb
  ```

  To install:

  ```bash
  pip install -r requirements.txt
  ```
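Optionally, a quick sanity check (a small illustrative script, not part of the repository) can confirm that the core packages import and that CUDA is visible:

```python
# Quick environment check: verify core packages import and report versions.
import torch
import transformers
import deepspeed
import datasets

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("transformers:", transformers.__version__)
print("deepspeed:", deepspeed.__version__)
print("datasets:", datasets.__version__)
```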
This project uses standard JSONL files, where each line is a JSON object. Each object must have a `messages` field in ChatML format.

Data Example (`data.jsonl`):

```jsonl
{"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello, please introduce yourself."}, {"role": "assistant", "content": "Hello! I am a large language model, happy to assist you."}]}
{"messages": [{"role": "user", "content": "Write me a poem about spring."}, {"role": "assistant", "content": "Sure. Spring breeze caresses sprouting green, soft rain moistens all unseen. Fields are fragrant, bees and butterflies dance, the land a painting, inviting all to glance."}]}
```

Key points:

- `messages` is a list of dialogue turns.
- Each turn is a dict with `role` and `content`.
- There must be at least one `assistant` turn in `messages` to compute the loss (only `assistant` replies are included in loss computation).
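Before launching a long training run, it can help to validate the data files against these rules. A small illustrative script (not part of the repository; the file path is a placeholder) might look like this:

```python
import json

def validate_jsonl(path: str) -> None:
    """Check that every line is valid JSON with a ChatML-style `messages` list
    containing at least one assistant turn."""
    allowed_roles = {"system", "user", "assistant"}
    with open(path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f, start=1):
            record = json.loads(line)
            messages = record.get("messages")
            assert isinstance(messages, list) and messages, f"line {i}: missing `messages`"
            for turn in messages:
                assert turn.get("role") in allowed_roles, f"line {i}: bad role {turn.get('role')!r}"
                assert isinstance(turn.get("content"), str), f"line {i}: `content` must be a string"
            assert any(t["role"] == "assistant" for t in messages), f"line {i}: no assistant turn"
    print(f"{path}: OK")

validate_jsonl("data.jsonl")
```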
Training is managed via a launch script that configures all necessary parameters.
Prepare a DeepSpeed config file, e.g. `ds_config/zero3.json`, for ZeRO-3:

```json
{
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 2000,
"wall_clock_breakdown": false,
"bf16": {
"enabled": true
},
"zero_optimization": {
"stage": 3,
"offload_param": {
"device": "none"
},
"offload_optimizer": {
"device": "none"
},
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_prefetch_bucket_size": 1e7,
"contiguous_gradients": true,
"overlap_comm": true
}
}
```

The `"auto"` values are filled in from the corresponding HuggingFace `TrainingArguments` at launch time. Next, create a `train.sh` script to launch training:

```bash
#!/bin/bash
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export TOKENIZERS_PARALLELISM=false
# List of training data files
DATA_FILES=(
"/path/to/your/data_part1.json"
"/path/to/your/data_part2.json"
# ... more data files
)
# Output and log directories
OUTPUT_DIR="output_model"
LOG_DIR="logs"
mkdir -p ${LOG_DIR}
LOG_FILE="${LOG_DIR}/train_$(date +%F_%H%M%S).log"
nohup deepspeed --num_gpus=8 train_dft_fixed.py \
--model_name_or_path /path/to/your/base_model \
--torch_dtype bfloat16 \
--attn_implementation flash_attention_2 \
--trust_remote_code True \
--data_files "${DATA_FILES[@]}" \
--max_length 8192 \
--preprocessing_num_workers 8 \
--validation_split_percentage 2.0 \
--enable_gradient_checkpointing True \
--dft_alpha 0.7 \
--bf16 True \
--num_train_epochs 1 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--learning_rate 2e-6 \
--lr_scheduler_type cosine \
--weight_decay 0.01 \
--gradient_accumulation_steps 16 \
--eval_strategy steps \
--eval_steps 200 \
--save_strategy steps \
--save_steps 200 \
--save_total_limit 3 \
--save_only_model True \
--report_to swanlab \
--logging_steps 10 \
--warmup_ratio 0.05 \
--deepspeed ./ds_config/zero3.json \
--output_dir ${OUTPUT_DIR} \
--logging_dir ${LOG_DIR} \
--remove_unused_columns False \
--ddp_find_unused_parameters False > ${LOG_FILE} 2>&1 &
echo "Training started in background. Log file: ${LOG_FILE}"Make the script executable and run:
chmod +x train.sh
nohup ./train.sh > ${LOG_FILE} 2>&1 &To monitor logs in real-time:
tail -f ${LOG_FILE}- Console Logs: Training progress and loss are output to your specified log file.
- Experiment Tracking: If `report_to` is set to `swanlab` or `wandb`, you can view all metrics and charts in the platform UI, including:
  - `train/loss`: Training loss (should decrease steadily).
  - `eval/loss`: Validation loss (key for generalization).
  - `train/grad_norm`: Gradient norm (for training stability).
  - `train/train/avg_p_correct`: The DFT core metric; average model confidence on correct tokens (should increase steadily).
  - `train/train/dft_alpha`: The hyperparameter you set (useful for verifying the configuration).
Contributions are welcome! If you have ideas, suggestions, or find bugs, feel free to submit a Pull Request or open an Issue.
This project is licensed under the Apache 2.0 License.