
GLM4.5V trains slowly, and training errors out under flash_attn2 #149

@ZhaozwTD

Description

Training with LLaMA-Factory starts up normally, but GPU utilization is low and training is slow.

After adding the option flash_attn: fa2, training fails with the error below:

[rank5]:   File "/mnt/bn/codebase/glm4.5V/LLaMA-Factory/src/llamafactory/train/sft/workflow.py", line 96, in run_sft                                                           
[rank5]:     train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)                                                                                                  
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/transformers/trainer.py", line 2315, in train                                                                                           
[rank5]:     return inner_training_loop(                                                                                                                                                                
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/transformers/trainer.py", line 2659, in _inner_training_loop                                                                            
[rank5]:     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)                                                                                                                       
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/transformers/trainer.py", line 3873, in training_step                                                                                   
[rank5]:     loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
[rank5]:   File "/mnt/bn/codebase/glm4.5V/LLaMA-Factory/src/llamafactory/train/sft/trainer.py", line 108, in compute_loss
[rank5]:     return super().compute_loss(model, inputs, *args, **kwargs)
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/transformers/trainer.py", line 3961, in compute_loss
[rank5]:     outputs = model(**inputs)
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank5]:     return self._call_impl(*args, **kwargs)
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank5]:     return forward_call(*args, **kwargs)
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
[rank5]:     ret_val = func(*args, **kwargs)
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/deepspeed/runtime/engine.py", line 1846, in forward
[rank5]:     loss = self.module(*inputs, **kwargs) 
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank5]:     return self._call_impl(*args, **kwargs)
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
[rank5]:     return inner()
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1790, in inner
[rank5]:     result = forward_call(*args, **kwargs)
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/peft/peft_model.py", line 1719, in forward
[rank5]:     return self.base_model(
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank5]:     return self._call_impl(*args, **kwargs)
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
[rank5]:     return inner()
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1790, in inner
[rank5]:     result = forward_call(*args, **kwargs)
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/peft/tuners/tuners_utils.py", line 197, in forward
[rank5]:     return self.model.forward(*args, **kwargs)
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/transformers/utils/generic.py", line 959, in wrapper
[rank5]:     output = func(self, *args, **kwargs)
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/transformers/models/glm4v_moe/modeling_glm4v_moe.py", line 1546, in forward
[rank5]:     outputs = self.model(
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank5]:     return self._call_impl(*args, **kwargs)
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
[rank5]:     return inner()
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1790, in inner                                                                                        
[rank5]:     result = forward_call(*args, **kwargs)                                                                                                                                                     
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/transformers/utils/generic.py", line 1083, in wrapper                                                                                   
[rank5]:     outputs = func(self, *args, **kwargs)                                                                                                                                                      
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/transformers/models/glm4v_moe/modeling_glm4v_moe.py", line 989, in forward                                                              
[rank5]:     layer_outputs = decoder_layer(                                                                                                                                                             
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/transformers/modeling_layers.py", line 93, in __call__                                                                                  
[rank5]:     return self._gradient_checkpointing_func(partial(super().__call__, **kwargs), *args)
[rank5]:   File "/mnt/bn/codebase/glm4.5V/LLaMA-Factory/src/llamafactory/model/model_utils/checkpointing.py", line 97, in custom_gradient_checkpointing_func
[rank5]:     return gradient_checkpointing_func(func, *args, **kwargs)
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/torch/_compile.py", line 32, in inner
[rank5]:     return disable_fn(*args, **kwargs)
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
[rank5]:     return fn(*args, **kwargs)
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/torch/utils/checkpoint.py", line 489, in checkpoint
[rank5]:     return CheckpointFunction.apply(function, preserve, *args)
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/torch/autograd/function.py", line 575, in apply
[rank5]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/torch/utils/checkpoint.py", line 264, in forward
[rank5]:     outputs = run_function(*args)
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank5]:     return self._call_impl(*args, **kwargs)
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
[rank5]:     return inner()
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1790, in inner
[rank5]:     result = forward_call(*args, **kwargs)
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
[rank5]:     return func(*args, **kwargs)
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/transformers/models/glm4v_moe/modeling_glm4v_moe.py", line 415, in forward
[rank5]:     hidden_states, _ = self.self_attn(
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank5]:     return self._call_impl(*args, **kwargs)
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
[rank5]:     return inner()
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1790, in inner
[rank5]:     result = forward_call(*args, **kwargs)
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
[rank5]:     return func(*args, **kwargs)
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/transformers/models/glm4v_moe/modeling_glm4v_moe.py", line 229, in forward
[rank5]:     attn_output, attn_weights = attention_interface(
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/transformers/integrations/flash_attention.py", line 66, in flash_attention_forward
[rank5]:     attn_output = _flash_attention_forward(
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/transformers/modeling_flash_attention_utils.py", line 672, in _flash_attention_forward
[rank5]:     out = flash_varlen_fn(
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py", line 1448, in flash_attn_varlen_func
[rank5]:     return FlashAttnVarlenFunc.apply(
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/torch/autograd/function.py", line 575, in apply
[rank5]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py", line 930, in forward
[rank5]:     out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/torch/_ops.py", line 1116, in __call__
[rank5]:     return self._op(*args, **(kwargs or {}))
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/torch/_library/autograd.py", line 113, in autograd_impl
[rank5]:     result = forward_no_grad(*args, Metadata(keyset, keyword_only_args))
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/torch/_library/autograd.py", line 40, in forward_no_grad
[rank5]:     result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/torch/_ops.py", line 721, in redispatch
[rank5]:     return self._handle.redispatch_boxed(keyset, *args, **kwargs)
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/torch/_library/custom_ops.py", line 324, in backend_impl
[rank5]:     result = self._backend_fns[device_type](*args, **kwargs)
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/torch/_compile.py", line 32, in inner
[rank5]:     return disable_fn(*args, **kwargs)
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
[rank5]:     return fn(*args, **kwargs)
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/torch/_library/custom_ops.py", line 367, in wrapped_fn
[rank5]:     return fn(*args, **kwargs)
[rank5]:   File "/home/tiger/.local/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py", line 170, in _flash_attn_varlen_forward
[rank5]:     out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd(
[rank5]: RuntimeError: cu_seqlens_q must have dtype int32
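For context on the RuntimeError above: flash-attn's varlen kernels take cu_seqlens_q / cu_seqlens_k, the prefix-sum arrays of per-sequence lengths used when sequences are packed into one batch, and require them to be int32 tensors. If the caller builds them as int64 (torch's default integer dtype), the kernel rejects them with exactly this message. A minimal sketch of the layout, where build_cu_seqlens is a hypothetical helper and not part of either library:

```python
# Sketch of the cu_seqlens layout flash-attn's varlen kernels expect:
# a prefix sum of per-sequence lengths, starting at 0. The kernel
# additionally requires the tensor holding it to be int32, not int64.
from itertools import accumulate

def build_cu_seqlens(seq_lens):
    """Cumulative sequence lengths [0, l0, l0+l1, ...] (hypothetical helper)."""
    return [0, *accumulate(seq_lens)]

# Three packed sequences of lengths 4, 2 and 5:
print(build_cu_seqlens([4, 2, 5]))  # [0, 4, 6, 11]
```

In torch terms a likely one-line fix is casting, e.g. cu_seqlens = cu_seqlens.to(torch.int32), before the tensor reaches flash_attn_varlen_func; and since the transformers build pinned below is a dev snapshot (4.56.0.dev0), updating it may also pick up an upstream fix for this code path.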

Training yaml:

# wandb settings
report_to: wandb
run_name: glm45v-test-1

### model
model_name_or_path: GLM-4.5V

### method
stage: sft
flash_attn: fa2
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
deepspeed: examples/deepspeed/ds_z3_offload_config.json
ddp_find_unused_parameters: false 
gradient_checkpointing: true
freeze_vision_tower: true
enable_thinking: false 

### dataset
dataset: demo_nothink, demo_think_500
template: glm4v_moe
cutoff_len: 30000
# max_samples: 10000
overwrite_cache: true
preprocessing_num_workers: 96
image_max_pixels: 262144

### output
output_dir: /mnt/bn/glm45v-test-1
logging_steps: 1
# save_steps: 130
save_strategy: epoch
save_total_limit: 2
overwrite_output_dir: true 

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-5
num_train_epochs: 1
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
trust_remote_code: true
ddp_timeout: 180000000
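If fa2 keeps failing on this stack, one hedged workaround (assuming your LLaMA-Factory version accepts it; flash_attn also takes values like auto and disabled) is to fall back to PyTorch's SDPA attention, which avoids the flash-attn varlen int32 path while remaining faster than eager attention:

```yaml
### method (hypothetical fallback: replaces the flash_attn: fa2 line above)
flash_attn: sdpa  # use torch.nn.functional.scaled_dot_product_attention
```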

Installed package versions:

transformers==4.56.0.dev0
flash-attn==2.7.3
deepspeed==0.14.4
torch==2.5.1
