dannigt/mid-align

Middle-Layer Representation Alignment for Cross-Lingual Transfer in Fine-Tuned LLMs


🧩 Models

| Base Model | Task | LoRA Adapters | Note |
|---|---|---|---|
| Meta-Llama-3-8B-Instruct | machine translation | MT-baseline-LoRA | only supervised FT |
| Meta-Llama-3-8B-Instruct | machine translation | MT-mid-align-LoRA | + mid-align |
| Meta-Llama-3-8B-Instruct | slot filling | SF-baseline-LoRA | only supervised FT |
| Meta-Llama-3-8B-Instruct | slot filling | SF-mid-align-LoRA | + mid-align |
| Qwen2.5-7B-Instruct | machine translation | MT-baseline-LoRA | only supervised FT |
| Qwen2.5-7B-Instruct | machine translation | MT-mid-align-LoRA | + mid-align |
| Qwen2.5-7B-Instruct | slot filling | SF-baseline-LoRA | only supervised FT |
| Qwen2.5-7B-Instruct | slot filling | SF-mid-align-LoRA | + mid-align |

πŸ“‹ Overview

This repository contains the implementation of our paper, which aims to improve cross-lingual transfer in fine-tuned large language models (LLMs). By aligning middle-layer representations between languages, our method improves knowledge transfer to unseen languages while maintaining performance on the supervised languages.

(Figure: method overview)

πŸ” Key Features

  • Alternate training (with LoRA) between task-specific and alignment objectives
  • Contrastive learning for multilingual representations in LLMs
  • Implementation for Llama 3 and Qwen2.5
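As a rough illustration of the alignment objective, below is a minimal sketch of an InfoNCE-style contrastive loss over mean-pooled representations of parallel sentences. The function name and tensor shapes are ours for illustration only; the actual loss is computed inside the modified model code (see Implementation Details below).

```python
import torch
import torch.nn.functional as F

def contrastive_alignment_loss(src_repr, tgt_repr, temperature=0.1):
    """InfoNCE-style loss pulling together representations of parallel
    sentences. Rows of src_repr and tgt_repr are aligned translation pairs.

    src_repr, tgt_repr: (batch, hidden) pooled hidden states.
    """
    # L2-normalize so that the dot product equals cosine similarity
    src = F.normalize(src_repr, dim=-1)
    tgt = F.normalize(tgt_repr, dim=-1)
    logits = src @ tgt.T / temperature
    # the matched pair (diagonal) is the positive; other batch items are negatives
    labels = torch.arange(src.size(0), device=src.device)
    return F.cross_entropy(logits, labels)
```

Lower temperatures make the softmax peakier, which corresponds to the --contrastive_loss_temperature parameter described in the Training section.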

πŸ› οΈ Installation

Our implementation builds upon Hugging Face Transformers v4.43.4.

# create and activate conda environment
conda create -n midalign python=3.9
conda activate midalign
# helper packages
pip install scikit-learn hf_mtask_trainer
# install package and dependencies
pip install -e .
# pytorch
pip install torch torchvision torchaudio
# deepspeed
pip install deepspeed
# flash attention
pip install flash-attn --no-build-isolation
# other huggingface packages
pip install datasets evaluate peft
# for evaluation
pip install seqeval levenshtein sacrebleu unbabel-comet

Note: Flash attention is optional. If installation is unsuccessful, remove this line to train without it.

For a complete list of packages in our environment, please see environment.yml.

πŸ”§ Implementation Details

We've made the following modifications to Hugging Face Transformers (v4.43.4):

  1. Model Architectures: src/transformers/models/llama/modeling_llama.py and src/transformers/models/qwen2/modeling_qwen2.py to enable contrastive learning at a selected layer
  2. Training Loop: src/transformers/trainer.py to support alternate training between two losses
  3. Configuration: src/transformers/configuration_utils.py to support reading in additional configuration parameters
  4. Tokenization: src/transformers/tokenization_utils_base.py to support taking a pair of parallel sentences at once
  5. Loss Tracking: src/transformers/modeling_utils.py to track two different losses
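Given per-layer hidden states (as returned by a Transformers model called with output_hidden_states=True), the representation at the selected layer can be mean-pooled over non-padding tokens. The helper below is our own sketch, not part of the repository; index 0 of hidden_states is the embedding output, so index `layer` is the output of decoder layer `layer`.

```python
import torch

def pool_layer(hidden_states, attention_mask, layer=16):
    """Select one layer from a tuple of per-layer hidden states and
    mean-pool it over non-padding tokens.

    hidden_states: tuple of (batch, seq, hidden) tensors
    attention_mask: (batch, seq) tensor of 0/1
    """
    h = hidden_states[layer]                        # (batch, seq, hidden)
    mask = attention_mask.unsqueeze(-1).to(h.dtype) # (batch, seq, 1)
    return (h * mask).sum(dim=1) / mask.sum(dim=1)  # (batch, hidden)
```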

πŸ“Š Data

The approach requires pairs of parallel sentences for aligning cross-lingual representations. Example data files:

  • ./data_example/train_baseline.json: For standard task-specific training
  • ./data_example/train_baseline_with_alignment.json: For training with alignment

The main experiment datasets (in .jsonl format) can be downloaded from this link.
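A minimal loader for such a .jsonl file might look as follows. The field names "src" and "tgt" are placeholders of ours; consult the example files in ./data_example for the actual schema.

```python
import json

def load_parallel_pairs(path):
    """Load sentence pairs from a JSON-lines file, one record per line.
    The field names "src" and "tgt" are illustrative placeholders."""
    pairs = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            record = json.loads(line)
            pairs.append((record["src"], record["tgt"]))
    return pairs
```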

πŸš€ Training

We have two training setups:

1. Task-specific Baseline Training

bash ./scripts/train_baseline.sh

This trains the model with standard task-specific objectives.

Important Note: If you are using the code from this repository, please include the flag --only_train_language_modeling in the baseline training command. This flag is not needed when using unmodified code from HuggingFace.

2. Training with Alignment

bash ./scripts/train_with_alignment.sh

Additional/changed parameters for alignment training:

  • --alternate_training: activate alternate training
  • --contrastive_data_mode 2: configures data format for alternate training
  • --additional_loss_layer 16: specifies the layer for contrastive loss (16 in our experiments)
  • --contrastive_loss_temperature 0.1: controls the peakiness of contrastive loss
  • --distance_function cosine: defines the distance metric for representation comparison
  • --num_train_epochs 10: doubled from baseline due to alternate training
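Schematically, alternate training switches between the two objectives on successive optimizer steps, which is why the epoch count is doubled. The sketch below is our own simplification; the real alternation logic lives in the modified src/transformers/trainer.py, and the actual schedule may differ.

```python
import torch

def alternating_training_step(step, loss_fns, optimizer):
    """One optimizer step that alternates between two objectives.

    loss_fns: dict with "task" and "alignment" callables, each
    returning a scalar loss tensor. Schematic only.
    """
    objective = "task" if step % 2 == 0 else "alignment"
    loss = loss_fns[objective]()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return objective, loss.item()
```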
The full alignment training script, for reference:
#!/bin/bash
export NCCL_DEBUG=INFO
#export NCCL_SOCKET_IFNAME=#eno2np1 #eth1
export NCCL_IB_GID_INDEX=3
export NCCL_IB_SL=3
export NCCL_NET_GDR_READ=1

export MASTER_ADDR="${CHIEF_IP:=localhost}"
export MASTER_PORT="${MASTER_PORT:=29501}"

module load compiler/gnu/12
module load devel/cuda/12.0

HOST_NUM=1
INDEX=0

model_path="meta-llama/Meta-Llama-3-8B-Instruct"
train_path="" # set to the training entry script invoked by torchrun below
train_files="./data_example/train_baseline.json" # replace by actual training data
valid_files="./data_example/train_baseline.json" # replace by actual validation data
train_bsz=32
eval_bsz=32
gradient_accumulation_steps=4
lora_config="./config/lora_config.json"
LR="5e-4"
OUTDIR="./test_run_outputs"
nproc_per_node=1 # number of GPUs used in training
loss_layer=16
loss_temperature=0.1
loss_distance_type="cosine"


torchrun --nnodes $HOST_NUM --node_rank $INDEX --nproc_per_node $nproc_per_node \
    --master_addr $MASTER_ADDR --master_port $MASTER_PORT  \
    ${train_path} \
    --deepspeed ./config/deepspeed_config.json \
    --bf16 True \
    --bf16_full_eval True \
    --model_name_or_path ${model_path} \
    --train_file $train_files \
    --validation_file $valid_files \
    --use_lora True \
    --lora_config ${lora_config} \
    --torch_dtype bfloat16 \
    --preprocessing_num_workers 16 \
    --dataloader_num_workers 1 \
    --dataloader_pin_memory True \
    --per_device_train_batch_size $train_bsz \
    --per_device_eval_batch_size $eval_bsz \
    --gradient_accumulation_steps $gradient_accumulation_steps \
    --num_train_epochs 10 \
    --save_strategy "steps" \
    --save_steps 200 \
    --save_total_limit 2 \
    --learning_rate $LR \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "inverse_sqrt" \
    --logging_steps 10 \
    --block_size 2048 \
    --do_train \
    --eval_strategy "steps" \
    --eval_steps 200 \
    --streaming \
    --ddp_timeout 3600 \
    --seed 1 \
    --gradient_checkpointing True \
    --load_best_model_at_end True \
    --metric_for_best_model "eval_loss" \
    --patience 5 \
    --output_dir $OUTDIR \
    --contrastive_data_mode 2 \
    --additional_loss_layer $loss_layer \
    --contrastive_loss_temperature $loss_temperature \
    --distance_function $loss_distance_type \
    --alternate_training \
    --disable_tqdm True --overwrite_output_dir 2>&1  | tee -a $OUTDIR/train.log

Our training scripts are adapted from ParroT.

The configuration files for DeepSpeed and LoRA are included in ./config.

πŸ“ˆ Inference and Evaluation

Slot Filling

basemodel="meta-llama/Meta-Llama-3-8B-Instruct"
path2peftmodel="danniliu/Meta-Llama-3-8B-Instruct-SF-baseline-LoRA" # or replace by local path to finetuned model
lang="en" # replace by other language codes (see massive_lang_map in scripts/utils.py)
path2outputjson="${lang}_massive.json" # replace by desired path for raw model outputs
path2output="${lang}_massive_predictions.txt" # replace by desired path for extracted predictions

# Run inference
python -m scripts.run_inference_massive --base-model-name $basemodel \
                                        --peft-model-id $path2peftmodel \
                                        --lang $lang \
                                        --partition "test" \
                                        --output-path $path2outputjson \
                                        --no-default-template

# Extract predictions from json outputs
python scripts/extract_json_output.py $path2outputjson $path2output

# Evaluate performance
python -m scripts.eval_massive --pred-slots-file $path2output \
                               --lang $lang \
                               --partition "test"

Machine Translation

basemodel="meta-llama/Meta-Llama-3-8B-Instruct"
path2peftmodel="danniliu/Meta-Llama-3-8B-Instruct-MT-baseline-LoRA" # or replace by local path to finetuned model
source_lang="cs" # input language; other language codes in translation_lang_map in scripts/utils.py
lang="en" # output language
path2outputjson=${source_lang}-${lang}_"wmt23.json"

# Run inference
python -m scripts.run_inference_eval_wmt23 --base-model-name $basemodel \
                                           --peft-model-id $path2peftmodel \
                                           --source-lang $source_lang \
                                           --lang $lang \
                                           --output-path $path2outputjson \
                                           --no-default-template

Note: for Llama 3, our training prompt template differs slightly from the default (no BOS token, and a line break after each end-of-turn token). Therefore, please add --no-default-template to reproduce the results from the paper.
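For illustration, a prompt built as just described (Llama 3 chat markup, no BOS token, a newline after each end-of-turn token) could be assembled as below. This is a hypothetical approximation of ours; see the inference scripts in this repository for the exact template.

```python
def build_prompt(messages):
    """Assemble a Llama-3-style chat prompt without a BOS token,
    inserting a newline after each end-of-turn token.
    Illustrative only; the actual template is in the inference scripts."""
    parts = []
    for m in messages:
        parts.append(f"<|start_header_id|>{m['role']}<|end_header_id|>\n\n"
                     f"{m['content']}<|eot_id|>\n")
    # leave the prompt open for the assistant's reply
    parts.append("<|start_header_id|>assistant<|end_header_id|>\n\n")
    return "".join(parts)
```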
