From 3794aa981ca8dbd9d4c910272af80d26cd60a4f9 Mon Sep 17 00:00:00 2001 From: zhengzhonghui Date: Tue, 19 Aug 2025 17:22:40 +0800 Subject: [PATCH 01/14] Dsv3 sft (#10968) * update expert parallel init logic * fix flash_mask && MoEFlexTokenLayer experts && add some config * offload optimizer --------- Co-authored-by: blacksheep-Aristotle --- llm/config/deepseek-v3/sft_argument.json | 48 ++++ llm/deepseeek_sft.sh | 22 ++ llm/run_finetune.py | 27 +- paddlenlp/trainer/training_args.py | 248 ++++++++++++++---- paddlenlp/trainer/utils/offload_optimizer.py | 81 ++++++ .../transformers/deepseek_v2/modeling.py | 28 +- .../transformers/deepseek_v2/modeling_pp.py | 13 + paddlenlp/transformers/moe_layer.py | 16 +- paddlenlp/transformers/moe_utils.py | 16 +- paddlenlp/trl/model_config.py | 2 + paddlenlp/trl/sft_config.py | 2 +- 11 files changed, 430 insertions(+), 73 deletions(-) create mode 100644 llm/config/deepseek-v3/sft_argument.json create mode 100644 llm/deepseeek_sft.sh create mode 100644 paddlenlp/trainer/utils/offload_optimizer.py diff --git a/llm/config/deepseek-v3/sft_argument.json b/llm/config/deepseek-v3/sft_argument.json new file mode 100644 index 000000000000..166fa39a394f --- /dev/null +++ b/llm/config/deepseek-v3/sft_argument.json @@ -0,0 +1,48 @@ +{ + "model_name_or_path": "/root/paddlejob/new_disk/huggingface_model/huggingface/deepseek-ai/DeepSeek-V3-bf16/", + "dataset_name_or_path": "./data", + "output_dir": "./checkpoints/sft_ckpts", + "per_device_train_batch_size": 1, + "gradient_accumulation_steps": 16, + "per_device_eval_batch_size": 1, + "eval_accumulation_steps": 1, + "num_train_epochs": 3, + "learning_rate": 2.2e-04, + "warmup_steps": 30, + "logging_steps": 1, + "evaluation_strategy": "epoch", + "save_strategy": "epoch", + "src_length": 2048, + "max_length": 4097, + "bf16": true, + "fp16_opt_level": "O2", + "do_train": true, + "do_eval": true, + "disable_tqdm": true, + "use_expert_parallel": true, + "expert_parallel_degree": 8, + "continue_training": true, + "pipeline_parallel_config": "enable_delay_scale_loss disable_partial_send_recv disable_batch_p2p_comm", + "tensor_parallel_config": "enable_delay_scale_loss", + "load_best_model_at_end": true, + "eval_with_do_generation": false, + "metric_for_best_model": "loss", + "recompute": true, + "recompute_use_reentrant": true, + "recompute_granularity": "full", + "save_total_limit": 1, + "tensor_parallel_degree": 1, + "pipeline_parallel_degree": 8, + "sharding_parallel_degree": 8, + "sharding": "stage1", + "zero_padding": true, + "unified_checkpoint": true, + "use_flash_attention": true, + "flash_mask": true, + "using_fake_gate": true, + "using_flex_token": true, + "pre_alloc_memory": 60, + "tensorwise_offload_optimizer": true, + "autotuner_benchmark": true +} + diff --git a/llm/deepseeek_sft.sh b/llm/deepseeek_sft.sh new file mode 100644 index 000000000000..8d49c6c6c048 --- /dev/null +++ b/llm/deepseeek_sft.sh @@ -0,0 +1,22 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# llama 模型数据下载 +# mkdir -p data +# wget https://bj.bcebos.com/paddlenlp/models/transformers/llama/data/llama_openwebtext_100k.bin +# wget https://bj.bcebos.com/paddlenlp/models/transformers/llama/data/llama_openwebtext_100k.idx + +mpirun rm -rf output +nohup mpirun sh script/train_gpu.sh config/deepseek-v3/sft_argument.json > run.log 2>&1 & + diff --git a/llm/run_finetune.py b/llm/run_finetune.py index 31427a516f2d..1e829021efd9 100644 --- a/llm/run_finetune.py +++ b/llm/run_finetune.py @@ -107,6 +107,19 @@ def paddlenlp_verison_check(): ) +def mock_offload_optimizer(): + """ + mock offload optimizer + """ + try: + from paddlenlp.trainer.utils.offload_optimizer import hack_offload_optimizer + + hack_offload_optimizer() + logger.warning("hack_offload_optimizer called.") + except ImportError: + logger.warning("hack_offload_optimizer is not imported") + + def main(): paddlenlp_verison_check() parser = PdArgumentParser((GenerateArgument, ModelConfig, ReftArgument, DataConfig, SFTConfig)) @@ -119,10 +132,19 @@ def main(): else: gen_args, model_args, reft_args, data_args, training_args = parser.parse_args_into_dataclasses() + if training_args.tensorwise_offload_optimizer: + mock_offload_optimizer() + training_args.print_config(model_args, "Model") training_args.print_config(data_args, "Data") training_args.print_config(gen_args, "Generation") + if training_args.pre_alloc_memory > 0: + memory_size = int(training_args.pre_alloc_memory * 1024 * 1024 * 1024) + x = paddle.empty([memory_size], dtype=paddle.uint8) + logger.info(f"pre_alloc_memory size {x.shape}") + del x + # Setup GPU & distributed training paddle.set_device(training_args.device) set_seed(seed=training_args.seed) @@ -250,7 +272,9 @@ def main(): raise ValueError("Please set eval_with_do_generation to false in pipeline parallel mode.") model_class = AutoModelForCausalLMPipe - + model_config["using_flex_token"] = model_args.using_fake_gate + model_config.using_fake_gate = model_args.using_fake_gate + print("model_config ", model_config, flush=True) if model_args.continue_training and not training_args.autotuner_benchmark: model = model_class.from_pretrained( model_args.model_name_or_path, @@ -261,6 +285,7 @@ def main(): # NOTE(gongenlei): new add autotuner_benchmark model = model_class.from_config(model_config, dtype=dtype) + print("model:", model, flush=True) if model_args.flash_mask and (not data_args.zero_padding or not model.config.use_flash_attention): logger.warning("`flash_mask` must use with zero padding and flash attention.") data_args.zero_padding = True diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index 30a3e7b3dc62..9f3c89f6fa23 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -1081,6 +1081,11 @@ class TrainingArguments: nccl_comm_group_config: Optional[str] = field( default=None, metadata={"help": "NCCL中通信组的细粒度控制的配置文件路径, 默认值为None, 代表不启用此项配置"} ) + + pre_alloc_memory: int = field( + default=0, + metadata={"help": "pre allocate memory size GB"}, + ) def __post_init__(self): world_size = paddle.distributed.get_world_size() @@ -1180,7 +1185,125 @@ def __post_init__(self): if self.optim == OptimizerNames.ADAMW_MINI and self.tensor_parallel_degree > 1: raise ValueError("AdamW Mini currently doesn't support tensor parallelism.") - self._post_init_parallel_degree() + self.use_hybrid_parallel = False + + if isinstance(self.sharding, bool): + 
self.sharding = "stage1" if self.sharding else "" + if isinstance(self.sharding, str): + self.sharding = [ShardingOption(s) for s in self.sharding.split()] + if self.sharding == [ShardingOption.OFFLOAD]: + raise ValueError( + "`--sharding offload` can't work on its own. It needs to be added to `--sharding stage2` or " + '`--sharding stage3`. For example, `--sharding "stage2 offload"`.' + ) + elif len(self.sharding) > (ShardingOption.OFFLOAD in self.sharding) + 1: + raise ValueError("`--sharding` recived too many arguments.") + + if self.sharding_degree > 0: + warnings.warn("`sharding_degree` is deprecated, please use `sharding_parallel_degree`") + self.sharding_parallel_degree = max(self.sharding_degree, self.sharding_parallel_degree) + self.data_parallel_degree = 1 + + delattr(self, "sharding_degree") + + if len(self.sharding) == 0 and self.sharding_parallel_degree > 0: + warnings.warn("`--sharding_parallel_degree` is useful only when `--sharding` is specified.") + + world_size = paddle.distributed.get_world_size() + + if world_size > 1: + tensor_parallel_degree = max(self.tensor_parallel_degree, 1) + sep_parallel_degree = max(self.sep_parallel_degree, 1) + context_parallel_degree = max(self.context_parallel_degree, 1) + pipeline_parallel_degree = max(self.pipeline_parallel_degree, 1) + expert_parallel_degree = max(self.expert_parallel_degree, 1) + expert_tensor_parallel_degree = max(self.expert_tensor_parallel_degree, 1) + + # TODO(@gexiao): support expert_tensor_parallel_degree > 1 in the future + assert ( + expert_tensor_parallel_degree == 1 + ), f"Currently only support expert_tensor_parallel_degree=1, but got expert_tensor_parallel_degree of {expert_tensor_parallel_degree}" + + assert ( + world_size % (self.tensor_parallel_degree * self.pipeline_parallel_degree) == 0 + ), f"Total world_size:{world_size} shoule be devided by tensor_parallel_degree: {self.tensor_parallel_degree} and pipeline_parallel_degree: {self.pipeline_parallel_degree}." + + assert not ( + sep_parallel_degree > 1 and context_parallel_degree > 1 + ), f"sep parallel and context parallel cannot be used together, sep_parallel_degree:{sep_parallel_degree}, context_parallel_degree:{context_parallel_degree}." + + if self.sharding_parallel_degree == -1: + if len(self.sharding) > 0: + self.sharding_parallel_degree = world_size // ( + tensor_parallel_degree * sep_parallel_degree * pipeline_parallel_degree + ) + + sharding_parallel_degree = max(self.sharding_parallel_degree, 1) + if sharding_parallel_degree == 1 and len(self.sharding) > 0: + logger.warning("sharding_parallel_degree=1 means no sharding, please set sharding to empty!") + self.sharding = [] + + self.data_parallel_degree = world_size // ( + sharding_parallel_degree * tensor_parallel_degree * sep_parallel_degree * pipeline_parallel_degree + ) + + if expert_parallel_degree > 1: + moe_sharding_parallel_degree = world_size // (pipeline_parallel_degree * expert_parallel_degree) + assert ( + self.expert_tensor_parallel_degree <= 1 + ), "expert_tensor_parallel_degree > 1 is not supported when expert_parallel_degree > 1" + else: + moe_sharding_parallel_degree = 1 + moe_sharding_parallel_degree = max(moe_sharding_parallel_degree, 1) + if moe_sharding_parallel_degree > 1 and self.data_parallel_degree > 1: + raise NotImplementedError( + f"Currently only support use expert_data_parallel strategy together with sharding_parallel strategy, but not with data_parallel strategy. 
But got data_parallel_degree: {self.data_parallel_degree}, expert_parallel_degree: {expert_parallel_degree}, moe_sharding_parallel_degree: {moe_sharding_parallel_degree}." + ) + + if sharding_parallel_degree > 1 and moe_sharding_parallel_degree > 1: + assert ( + sharding_parallel_degree % moe_sharding_parallel_degree == 0 + ), f"sharding_parallel_degree should be divided by moe_sharding_parallel_degree, current sharding_parallel_degree: {sharding_parallel_degree}, moe_sharding_parallel_degree: {moe_sharding_parallel_degree}." + + assert not ( + self.data_parallel_degree > 1 and expert_parallel_degree > 1 + ), f"Currently only support use expert_data_parallel strategy together with sharding_parallel strategy, but not with data_parallel strategy. Currently data_parallel_degree is {self.data_parallel_degree}." + + if ( + sharding_parallel_degree > 1 + or tensor_parallel_degree > 1 + or pipeline_parallel_degree > 1 + or self.sep_parallel_degree > 1 + or self.context_parallel_degree > 1 + or expert_parallel_degree > 1 + or expert_tensor_parallel_degree > 1 + ): + self.use_hybrid_parallel = True + self.sharding_parallel_degree = sharding_parallel_degree + self.tensor_parallel_degree = tensor_parallel_degree + self.pipeline_parallel_degree = pipeline_parallel_degree + self.sep_parallel_degree = sep_parallel_degree + self.context_parallel_degree = context_parallel_degree + self.expert_parallel_degree = expert_parallel_degree + self.expert_tensor_parallel_degree = expert_tensor_parallel_degree + self.moe_sharding_parallel_degree = moe_sharding_parallel_degree + + if not self.use_hybrid_parallel: + self.sharding = [] + self.sharding_parallel_degree = -1 + self.tensor_parallel_degree = -1 + self.pipeline_parallel_degree = -1 + self.sep_parallel_degree = -1 + self.context_parallel_degree = -1 + self.expert_parallel_degree = -1 + self.expert_tensor_parallel_degree = -1 + + if self.hybrid_parallel_topo_order is None: + self.hybrid_parallel_topo_order = "sharding_first" + assert self.hybrid_parallel_topo_order in ["pp_first", "sharding_first"] + + if self.use_hybrid_parallel and self.enable_auto_parallel: + self.use_hybrid_parallel = False if self.to_static: assert world_size == 1 or self.enable_auto_parallel, ( @@ -1383,6 +1506,17 @@ def is_segment_parallel_supported(): logger.warning("segment parallel is not supported!!!, Ignore it.") return support_sep + def is_context_parallel_supported(): + import inspect + + members = [ + name for (name, date) in inspect.getmembers(fleet.base.topology.EPHybridCommunicateGroup) + ] + support_cp = "get_context_parallel_world_size" in members + if not support_cp: + logger.warning("context parallel is not supported!!! 
Ignore it.") + return support_cp + if self.hybrid_parallel_topo_order == "pp_first": if is_segment_parallel_supported(): order = ["dp", "pp", "sharding", "sep", "mp"] @@ -1394,17 +1528,31 @@ def is_segment_parallel_supported(): else: order = ["dp", "sharding", "pp", "mp"] if self.use_expert_parallel: - order = order[1:-1] + ["dp", "mp"] + if self.moe_sharding_parallel_degree >= 1 and self.expert_parallel_degree > 1: + if is_context_parallel_supported(): + order = ["sharding", "moe_sharding", "pp", "sep", "cp", "dp", "ep", "mp"] + else: + order = ["sharding", "moe_sharding", "pp", "sep", "dp", "ep", "mp"] + else: + order = ["sharding", "pp", "sep", "dp", "mp"] - if is_segment_parallel_supported(): + if is_context_parallel_supported(): hybrid_configs = { "dp_degree": self.data_parallel_degree, "mp_degree": self.tensor_parallel_degree, "pp_degree": self.pipeline_parallel_degree, "sharding_degree": self.sharding_parallel_degree, - "sep_degree": self.sep_parallel_degree - if self.sep_parallel_degree > 1 - else self.context_parallel_degree, + "sep_degree": self.sep_parallel_degree, + "cp_degree": self.context_parallel_degree, + "order": order, + } + elif is_segment_parallel_supported(): + hybrid_configs = { + "dp_degree": self.data_parallel_degree, + "mp_degree": self.tensor_parallel_degree, + "pp_degree": self.pipeline_parallel_degree, + "sharding_degree": self.sharding_parallel_degree, + "sep_degree": self.sep_parallel_degree, "order": order, } else: @@ -1416,6 +1564,13 @@ def is_segment_parallel_supported(): "order": order, } + if self.expert_parallel_degree > 1: + assert ( + self.use_expert_parallel is True and self.moe_sharding_parallel_degree >= 0 + ), f"invalid expert_parallel_degree {self.expert_parallel_degree} and use_expert_paralle:{self.use_expert_parallel}." 
+ hybrid_configs["ep_degree"] = self.expert_parallel_degree + hybrid_configs["moe_sharding_degree"] = self.moe_sharding_parallel_degree + try: if self.split_norm_comm: hybrid_configs["split_norm_comm"] = True @@ -2052,47 +2207,12 @@ def _post_init_parallel_degree(self): self.use_hybrid_parallel = False def add_moe_comm_group(self): - hybrid_configs = fleet.fleet._user_defined_strategy.hybrid_configs + # NOTE(zhangweilong):move init_moe_group logic to paddle fleet.init + moe_group = fleet.get_hybrid_communicate_group().get_expert_parallel_group() + moe_grad_group = fleet.get_hybrid_communicate_group().get_moe_sharding_parallel_group() hcg = fleet.get_hybrid_communicate_group() - topo = hcg._topo - sharding_parallel_groups = topo.get_comm_list("sharding") - experts_replicas = self.sharding_parallel_degree // self.expert_parallel_degree - - # init experts groups inside all sharding groups - for ranks_in_current_sharding_group in sharding_parallel_groups: - # init experts parallel groups (dispatch & combine) - for i in range(experts_replicas): - rank_indices = list(range(i * self.expert_parallel_degree, (i + 1) * self.expert_parallel_degree)) - ranks = [ranks_in_current_sharding_group[i] for i in rank_indices] - if message2nccl_config is not None and hybrid_configs.get("ep_configs", None) is not None: - group = dist.new_group( - ranks=ranks, nccl_config=message2nccl_config(hybrid_configs["ep_configs"].nccl_config, "ep") - ) - else: - group = dist.new_group(ranks=ranks) - if dist.get_rank() in ranks: - assert not hasattr(hcg, "expert_parallel_group"), "expert_parallel_group can not be set repeate" - setattr(hcg, "expert_parallel_group", group) - - # init experts gradients comm groups - for i in range(self.expert_parallel_degree): - rank_indices = list(range(i, self.sharding_parallel_degree, self.expert_parallel_degree)) - ranks = [ranks_in_current_sharding_group[i] for i in rank_indices] - if message2nccl_config is not None and hybrid_configs.get("ep_configs", None) is not None: - group = dist.new_group( - ranks=ranks, - nccl_config=message2nccl_config(hybrid_configs["ep_configs"].grad_nccl_config, "ep_grad"), - ) - else: - group = dist.new_group(ranks=ranks) - if dist.get_rank() in ranks: - assert not hasattr(hcg, "expert_grad_comm_group"), "expert_grad_comm_group can not be set repeate" - setattr(hcg, "expert_grad_comm_group", group) - - assert hasattr(hcg, "expert_parallel_group") and hasattr(hcg, "expert_grad_comm_group") - logger.info( - f"experts groups are created, expert_parallel_group: {hcg.expert_parallel_group}, expert_grad_comm_group: {hcg.expert_grad_comm_group}" - ) + setattr(hcg, "expert_parallel_group", moe_group) + setattr(hcg, "expert_grad_comm_group", moe_grad_group) def __str__(self): self_as_dict = asdict(self) @@ -2200,6 +2320,28 @@ def pipeline_parallel_rank(self): else: return 0 + @property + def expert_parallel_rank(self): + if self.use_hybrid_parallel: + hcg = fleet.get_hybrid_communicate_group() + if hasattr(hcg, "get_expert_parallel_rank"): + return max(hcg.get_expert_parallel_rank(), 0) + else: + return 0 + else: + return 0 + + @property + def context_parallel_rank(self): + if self.use_hybrid_parallel: + hcg = fleet.get_hybrid_communicate_group() + if hasattr(hcg, "get_context_parallel_rank"): + return max(hcg.get_context_parallel_rank(), 0) + else: + return 0 + else: + return 0 + def _format_name(self, prefix, rank, degree): size = 2 return f"{prefix}{rank:0>{size}d}" @@ -2214,7 +2356,7 @@ def optimizer_name_suffix(self): name.append(self._format_name("pp", 
self.pipeline_parallel_rank, self.pipeline_parallel_degree)) if self.sharding_parallel_degree > 1: name.append(self._format_name("shard", self.sharding_parallel_rank, self.sharding_parallel_degree)) - if self.use_expert_parallel: + if self.use_expert_parallel and self.expert_parallel_degree <= 1: name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree)) return "_".join(name) else: @@ -2230,7 +2372,7 @@ def weight_name_suffix(self): name.append(self._format_name("tp", self.tensor_parallel_rank, self.tensor_parallel_degree)) if self.pipeline_parallel_degree > 1: name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree)) - if self.use_expert_parallel: + if self.use_expert_parallel and self.expert_parallel_degree <= 1: name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree)) return "_".join(name) @@ -2239,7 +2381,9 @@ def weight_name_suffix(self): return self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree) return None - def sharded_name_suffix(self, shard_id=None, pp_id=None, moe_id=None): + def sharded_name_suffix(self, shard_id=None, pp_id=None, moe_id=None, sharding_parallel_degree=None): + if sharding_parallel_degree is None: + sharding_parallel_degree = self.sharding_parallel_degree if self.use_hybrid_parallel: name = [] if self.tensor_parallel_degree > 1: @@ -2249,12 +2393,12 @@ def sharded_name_suffix(self, shard_id=None, pp_id=None, moe_id=None): pp_id = self.pipeline_parallel_rank assert isinstance(pp_id, int) name.append(self._format_name("pp", pp_id, self.pipeline_parallel_degree)) - if self.sharding_parallel_degree > 1: + if sharding_parallel_degree > 1: if shard_id is None: shard_id = self.sharding_parallel_rank assert isinstance(shard_id, int) - name.append(self._format_name("shard", shard_id, self.sharding_parallel_degree)) - if self.use_expert_parallel: + name.append(self._format_name("shard", shard_id, sharding_parallel_degree)) + if self.use_expert_parallel and self.expert_parallel_degree <= 1: if moe_id is None: moe_id = self.data_parallel_rank assert isinstance(moe_id, int) diff --git a/paddlenlp/trainer/utils/offload_optimizer.py b/paddlenlp/trainer/utils/offload_optimizer.py new file mode 100644 index 000000000000..65f5b77e2e5d --- /dev/null +++ b/paddlenlp/trainer/utils/offload_optimizer.py @@ -0,0 +1,81 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
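+#
+# Tensor-wise optimizer-state offload (descriptive summary of the code below;
+# wording added for readability, not part of the original patch): the module
+# keeps optimizer state off the GPU between steps by patching three entry points:
+#   1. Optimizer._add_accumulator: each newly created accumulator is moved to
+#      CUDAPinnedPlace (or CPUPlace when CUDA is unavailable) right after creation;
+#   2. _C_ops.adam_ / _C_ops.adamw_: tensor arguments are reloaded to the device
+#      before the update, and every tensor argument except the parameter and
+#      gradient (argument index >= 2) is offloaded again afterwards;
+#   3. HybridParallelOptimizer._insert_sync: the synced tensor is reloaded before
+#      the sync and moved back to its original place once the sync completes.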
+ +import paddle +from paddle import _C_ops +from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.hybrid_parallel_optimizer import ( + HybridParallelOptimizer, +) +from paddle.optimizer import Optimizer + +from .sharding_io import to_device + + +def offload(tensor): + if paddle.is_compiled_with_cuda(): + place = paddle.CUDAPinnedPlace() + else: + place = paddle.CPUPlace() + + new_tensor = to_device(tensor, place) + assert new_tensor is tensor, "to_device must be inplace operation" + + +def reload(tensor): + new_tensor = to_device(tensor) + assert new_tensor is tensor, "to_device must be inplace operation" + + +def hack_offload_optimizer(): + # Step 1: mock _add_accumulator + origin_add_accumulator = getattr(Optimizer, "_add_accumulator") + + def new_add_accumulator(self, *args, **kwargs): + x = origin_add_accumulator(self, *args, **kwargs) + offload(x) + return x + + setattr(Optimizer, "_add_accumulator", new_add_accumulator) + + # Step 2: mock _C_ops.adamw_ and _C_ops.adamw + for name in ["adam_", "adamw_"]: + origin_op = getattr(_C_ops, name) + + def new_opt_op(*args): + for arg in args: + if isinstance(arg, paddle.Tensor): + reload(arg) + + ret = origin_op(*args) + + for i, arg in enumerate(args): + if i >= 2 and isinstance(arg, paddle.Tensor): # do not offload parameter and gradient + offload(arg) + return ret + + setattr(_C_ops, name, new_opt_op) + + # Step 3: mock _insert_sync + opt_type = HybridParallelOptimizer + origin_insert_sync = getattr(opt_type, "_insert_sync") + + def new_insert_sync(self, sync_var, *args, **kwargs): + origin_place = sync_var.place + reload(sync_var) + ret = origin_insert_sync(self, sync_var, *args, **kwargs) + new_sync_var = to_device(sync_var, origin_place) + assert new_sync_var is sync_var, "to_device must be inplace operation" + return ret + + setattr(opt_type, "_insert_sync", new_insert_sync) diff --git a/paddlenlp/transformers/deepseek_v2/modeling.py b/paddlenlp/transformers/deepseek_v2/modeling.py index ca71b478ee49..76389ab1a7ff 100644 --- a/paddlenlp/transformers/deepseek_v2/modeling.py +++ b/paddlenlp/transformers/deepseek_v2/modeling.py @@ -785,8 +785,10 @@ def __init__(self, config: DeepseekV2Config): ) # (LiuTing) only support either tp or ep. 
- moe_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group() + moe_group = fleet.get_hybrid_communicate_group().get_expert_parallel_group() expert_parallel_degree = dist.get_world_size(moe_group) + moe_grad_group = fleet.get_hybrid_communicate_group().get_moe_sharding_parallel_group() + expert_parallel_degree = 1 if expert_parallel_degree < 0 else expert_parallel_degree act_tp_shard = config.tensor_parallel_degree > 1 and expert_parallel_degree <= 1 super().__init__( @@ -800,7 +802,12 @@ def __init__(self, config: DeepseekV2Config): }, gate=gate, capacity=2.0, + moe_group="expert", ) + + for p in self.experts.parameters(): + setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group}) + self.alpha = config.aux_loss_alpha if config.n_shared_experts is not None: intermediate_size = config.moe_intermediate_size * config.n_shared_experts @@ -838,8 +845,8 @@ def __init__(self, config: DeepseekV2Config): ) hcg = fleet.get_hybrid_communicate_group() - moe_group = hcg.expert_parallel_group - moe_grad_group = hcg.expert_grad_comm_group + moe_group = hcg.get_expert_parallel_group() + moe_grad_group = hcg.get_moe_sharding_parallel_group() super().__init__( config=config, @@ -1467,7 +1474,7 @@ def get_tensor_parallel_split_mappings(num_layers): base_actions["layers.0.mlp.down_proj.weight.weight_scale_inv"] = partial(fn, is_column=False) # moe unit routed experts - moe_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group() + moe_group = fleet.get_hybrid_communicate_group().get_expert_parallel_group() expert_parallel_degree = dist.get_world_size(moe_group) if expert_parallel_degree <= 1: for e_i in range(config.n_routed_experts): @@ -1721,6 +1728,19 @@ def forward( attention_mask = attention_mask[ :, :, : -self.config.num_nextn_predict_layers, : -self.config.num_nextn_predict_layers ] + + # attn_mask_startend_row_indices: [b, num_head, seq_len] or [b, num_head, seq_len, C], C is 2 or 4 + if attn_mask_startend_row_indices is not None: + if attn_mask_startend_row_indices.ndim == 3: + attn_mask_startend_row_indices = attn_mask_startend_row_indices[ + :, :, : -self.config.num_nextn_predict_layers, + ] + elif attn_mask_startend_row_indices.ndim == 4: + attn_mask_startend_row_indices = attn_mask_startend_row_indices[ + :, :, : -self.config.num_nextn_predict_layers, : + ] + else: + raise ValueError("attn_mask_startend_row_indices must be 3D or 4D tensor") if self.enable_recompute and self.training: if use_cache: diff --git a/paddlenlp/transformers/deepseek_v2/modeling_pp.py b/paddlenlp/transformers/deepseek_v2/modeling_pp.py index 42b0e5de776d..4361023f6bbc 100644 --- a/paddlenlp/transformers/deepseek_v2/modeling_pp.py +++ b/paddlenlp/transformers/deepseek_v2/modeling_pp.py @@ -132,6 +132,19 @@ def forward(self, args): attention_mask = attention_mask[ :, :, : -self.config.num_nextn_predict_layers, : -self.config.num_nextn_predict_layers ] + + # attn_mask_startend_row_indices: [b, num_head, seq_len] or [b, num_head, seq_len, C], C is 2 or 4 + if attn_mask_startend_row_indices is not None: + if attn_mask_startend_row_indices.ndim == 3: + attn_mask_startend_row_indices = attn_mask_startend_row_indices[ + :, :, : -self.config.num_nextn_predict_layers, + ] + elif attn_mask_startend_row_indices.ndim == 4: + attn_mask_startend_row_indices = attn_mask_startend_row_indices[ + :, :, : -self.config.num_nextn_predict_layers, : + ] + else: + raise ValueError("attn_mask_startend_row_indices must be 3D or 4D tensor") if attention_mask is not None: assert ( diff --git 
a/paddlenlp/transformers/moe_layer.py b/paddlenlp/transformers/moe_layer.py index 340fba1f5245..040fb5d1f22a 100644 --- a/paddlenlp/transformers/moe_layer.py +++ b/paddlenlp/transformers/moe_layer.py @@ -176,12 +176,13 @@ def __init__( except AttributeError: is_fleet_init = False - if ( - is_fleet_init - and dist.fleet.get_hybrid_communicate_group().get_data_parallel_world_size() > 1 - and moe_group == "data" - ): - self.moe_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group() + if is_fleet_init and dist.get_world_size() > 1: + if moe_group == "data": + self.moe_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group() + elif moe_group == "expert": + self.moe_group = dist.fleet.get_hybrid_communicate_group().expert_parallel_group + else: + assert NotImplementedError("moe_group can only be data or expert, but given {}".format(self.moe_group)) self.moe_rank = dist.get_rank(self.moe_group) self.moe_rank = 0 if self.moe_rank < 0 else self.moe_rank self.expert_parallel_degree = dist.get_world_size(self.moe_group) @@ -350,8 +351,7 @@ def __init__(self, config, moe_num_experts, expert_class, expert_kwargs, gate, m self.token_dispatcher = MoEFlexTokenDispatcher( self.num_local_experts, self.moe_router_topk, self.moe_num_experts, moe_group ) - - self.experts = nn.LayerList([expert_class(**expert_kwargs)] * self.num_local_experts) + self.experts = nn.LayerList([expert_class(**expert_kwargs) for _ in range(self.num_local_experts)]) self.router = gate def expert_forward(self, dispatched_input, tokens_per_expert): diff --git a/paddlenlp/transformers/moe_utils.py b/paddlenlp/transformers/moe_utils.py index 466591b0638d..d82654f6375b 100644 --- a/paddlenlp/transformers/moe_utils.py +++ b/paddlenlp/transformers/moe_utils.py @@ -18,6 +18,11 @@ import paddle +try: + from paddle import scatter_add_ +except ImportError: + scatter_add_ = None + def permute( tokens, @@ -91,11 +96,8 @@ def unpermute( # Create an output tensor filled with zeros output_tokens = paddle.zeros(restore_shape, dtype=permuted_tokens.dtype) # Scatter add the permuted_input back to the original positions - output_tokens.put_along_axis_( - axis=0, - indices=sorted_indices.unsqueeze(1).expand([-1, hidden]), - values=permuted_tokens, - reduce="add", - include_self=True, - ) + if scatter_add_ is not None: + scatter_add_(output_tokens, sorted_indices, permuted_tokens) + else: + output_tokens.scatter_(index=sorted_indices, updates=permuted_tokens, overwrite=False) return output_tokens diff --git a/paddlenlp/trl/model_config.py b/paddlenlp/trl/model_config.py index 2e244d211158..e03f8eb7c399 100644 --- a/paddlenlp/trl/model_config.py +++ b/paddlenlp/trl/model_config.py @@ -145,3 +145,5 @@ class ModelConfig: ) actscale_moving_rate: float = field(default=0.01, metadata={"help": "EMA moving_rate for activation scale"}) fp8_format_type: str = field(default="hybrid", metadata={"help": "FP8 Format"}) + using_flex_token: bool = field(default=False, metadata={"help": "Whether to use deepep moe_layer"}) + using_fake_gate: bool = field(default=False, metadata={"help": "Whether to fake gate"}) diff --git a/paddlenlp/trl/sft_config.py b/paddlenlp/trl/sft_config.py index f759bc68a1aa..06e9843d2dad 100644 --- a/paddlenlp/trl/sft_config.py +++ b/paddlenlp/trl/sft_config.py @@ -80,7 +80,7 @@ def __post_init__(self): super().__post_init__() # NOTE(gongenlei): new add autotuner_benchmark if self.autotuner_benchmark: - self.max_steps = 5 + self.max_steps = 20 self.do_train = True self.do_export = False self.do_predict = 
False From f2e43b927fc9f1fc49dd2d550a8ddddac05812a3 Mon Sep 17 00:00:00 2001 From: zhengzhonghui Date: Wed, 20 Aug 2025 17:27:40 +0800 Subject: [PATCH 02/14] fix use_rms_norm && add subbatch_token_num config (#10974) Co-authored-by: Your Name --- llm/config/deepseek-v3/sft_argument.json | 1 + llm/run_finetune.py | 3 ++- .../transformers/deepseek_v2/modeling.py | 19 ++++++------------- paddlenlp/trl/model_config.py | 3 +++ 4 files changed, 12 insertions(+), 14 deletions(-) diff --git a/llm/config/deepseek-v3/sft_argument.json b/llm/config/deepseek-v3/sft_argument.json index 166fa39a394f..5703453e40a9 100644 --- a/llm/config/deepseek-v3/sft_argument.json +++ b/llm/config/deepseek-v3/sft_argument.json @@ -41,6 +41,7 @@ "flash_mask": true, "using_fake_gate": true, "using_flex_token": true, + "use_fused_rms_norm": true, "pre_alloc_memory": 60, "tensorwise_offload_optimizer": true, "autotuner_benchmark": true diff --git a/llm/run_finetune.py b/llm/run_finetune.py index 1e829021efd9..b129cf27df30 100644 --- a/llm/run_finetune.py +++ b/llm/run_finetune.py @@ -272,8 +272,9 @@ def main(): raise ValueError("Please set eval_with_do_generation to false in pipeline parallel mode.") model_class = AutoModelForCausalLMPipe - model_config["using_flex_token"] = model_args.using_fake_gate + model_config.using_flex_token = model_args.using_flex_token model_config.using_fake_gate = model_args.using_fake_gate + model_config.moe_subbatch_token_num = model_args.moe_subbatch_token_num print("model_config ", model_config, flush=True) if model_args.continue_training and not training_args.autotuner_benchmark: model = model_class.from_pretrained( diff --git a/paddlenlp/transformers/deepseek_v2/modeling.py b/paddlenlp/transformers/deepseek_v2/modeling.py index 76389ab1a7ff..fde4fd576200 100644 --- a/paddlenlp/transformers/deepseek_v2/modeling.py +++ b/paddlenlp/transformers/deepseek_v2/modeling.py @@ -330,17 +330,8 @@ def __init__(self, config: DeepseekV2Config, hidden_size=None, eps=1e-6, use_seq mark_as_sequence_parallel_parameter(self.weight) def forward(self, hidden_states): - if self.config.use_fused_rms_norm and get_env_device() == "xpu": - if self.weight.dtype != hidden_states.dtype: - hidden_states = paddle.cast(hidden_states, self.weight.dtype) - try: - import paddle_xpu_nn # noqa: F821 - - return paddle_xpu_nn.xpu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0] - except ImportError: - raise NotImplementedError( - f"Implementation of fused_rms_norm is not available on {get_env_device()}. 
Please install paddle_xpu to use this feature" - ) + if self.config.use_fused_rms_norm: + return fusion_ops.fusion_rms_norm(hidden_states, self.weight, self.variance_epsilon) with paddle.amp.auto_cast(False): hidden_states = hidden_states.astype("float32") @@ -1728,12 +1719,14 @@ def forward( attention_mask = attention_mask[ :, :, : -self.config.num_nextn_predict_layers, : -self.config.num_nextn_predict_layers ] - + # attn_mask_startend_row_indices: [b, num_head, seq_len] or [b, num_head, seq_len, C], C is 2 or 4 if attn_mask_startend_row_indices is not None: if attn_mask_startend_row_indices.ndim == 3: attn_mask_startend_row_indices = attn_mask_startend_row_indices[ - :, :, : -self.config.num_nextn_predict_layers, + :, + :, + : -self.config.num_nextn_predict_layers, ] elif attn_mask_startend_row_indices.ndim == 4: attn_mask_startend_row_indices = attn_mask_startend_row_indices[ diff --git a/paddlenlp/trl/model_config.py b/paddlenlp/trl/model_config.py index e03f8eb7c399..8a28bf4860c1 100644 --- a/paddlenlp/trl/model_config.py +++ b/paddlenlp/trl/model_config.py @@ -147,3 +147,6 @@ class ModelConfig: fp8_format_type: str = field(default="hybrid", metadata={"help": "FP8 Format"}) using_flex_token: bool = field(default=False, metadata={"help": "Whether to use deepep moe_layer"}) using_fake_gate: bool = field(default=False, metadata={"help": "Whether to fake gate"}) + moe_subbatch_token_num: int = field( + default=0, metadata={"help": "moelayer subbatch token num, The smaller the value, the smaller the peak memory"} + ) From 7d5eb9a86d93a498108db04c9a2420ca6ca728c6 Mon Sep 17 00:00:00 2001 From: zhengzhonghui Date: Fri, 22 Aug 2025 11:23:59 +0800 Subject: [PATCH 03/14] moelayer with subbatch to reduce memory (#10985) Co-authored-by: deepllz --- llm/config/deepseek-v3/sft_argument.json | 1 + .../transformers/deepseek_v2/modeling.py | 250 ++++++++++++++++-- .../transformers/deepseek_v2/modeling_pp.py | 16 +- paddlenlp/transformers/moe_layer.py | 1 + 4 files changed, 240 insertions(+), 28 deletions(-) diff --git a/llm/config/deepseek-v3/sft_argument.json b/llm/config/deepseek-v3/sft_argument.json index 5703453e40a9..260f509e6f7c 100644 --- a/llm/config/deepseek-v3/sft_argument.json +++ b/llm/config/deepseek-v3/sft_argument.json @@ -42,6 +42,7 @@ "using_fake_gate": true, "using_flex_token": true, "use_fused_rms_norm": true, + "moe_subbatch_token_num": 1024, "pre_alloc_memory": 60, "tensorwise_offload_optimizer": true, "autotuner_benchmark": true diff --git a/paddlenlp/transformers/deepseek_v2/modeling.py b/paddlenlp/transformers/deepseek_v2/modeling.py index fde4fd576200..66bbb459a639 100644 --- a/paddlenlp/transformers/deepseek_v2/modeling.py +++ b/paddlenlp/transformers/deepseek_v2/modeling.py @@ -1185,7 +1185,7 @@ def __init__(self, config: DeepseekV2Config, layer_idx: int, layerwise_recompute self.input_layernorm = DeepseekV2RMSNorm(config) self.post_attention_layernorm = DeepseekV2RMSNorm(config) - def forward( + def subbatch_recompute_forward( self, hidden_states: paddle.Tensor, position_ids: Optional[paddle.Tensor] = None, @@ -1194,26 +1194,67 @@ def forward( past_key_value: Optional[Tuple[paddle.Tensor]] = None, use_cache: Optional[bool] = False, attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, - **kwargs, ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]: - """ - Args: - hidden_states (`paddle.Tensor`): input to the layer of shape `(batch, seq_len, embed_axis)` - attention_mask (`paddle.Tensor`, *optional*): - attention mask of size `(batch_size, 
sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states - """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + offload_kwargs = {} + offload_kwargs["offload_indices"] = [0] + assert self.recompute_granularity != "full_attn" + attn_outputs = recompute( + self.attn, + hidden_states, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + attn_mask_startend_row_indices, + ) + + hidden_states = attn_outputs[0] + residual = attn_outputs[1] + self_attn_weights = attn_outputs[2] if output_attentions else None + present_key_value = attn_outputs[3] if use_cache else None + + assert len(hidden_states.shape) == 3 + sub_seq_len = self.config.moe_subbatch_token_num + seq_len = hidden_states.shape[1] + assert seq_len % sub_seq_len == 0 + num_chunks = seq_len // sub_seq_len + split_list = [sub_seq_len] * num_chunks + input_list = paddle.split(hidden_states, split_list, axis=1) + output_list = [] + for chunk in input_list: + offload_kwargs = {} + offload_kwargs["offload_indices"] = [0] + out = recompute( + self.mlp.forward, + chunk, + **offload_kwargs, ) + output_list.append(out) + hidden_states = paddle.concat(output_list, axis=1) + outputs = recompute( + self.post_process, + hidden_states, + residual, + output_attentions, + use_cache, + self_attn_weights, + present_key_value, + **offload_kwargs, + ) + return outputs + + def attn( + self, + hidden_states: paddle.Tensor, + position_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = False, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + use_cache: Optional[bool] = False, + attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, + **kwargs, + ): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -1254,18 +1295,32 @@ def forward( else: hidden_states = outputs + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + attn_outputs = (hidden_states, residual) + if output_attentions: self_attn_weights = outputs[1] + attn_outputs += (self_attn_weights,) if use_cache: present_key_value = outputs[2 if output_attentions else 1] + attn_outputs += (present_key_value,) - hidden_states = residual + hidden_states + return attn_outputs - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) + def post_process( + self, + hidden_states, + residual, + output_attentions=False, + use_cache=False, + self_attn_weights=None, + present_key_value=None, + ): hidden_states = residual + hidden_states outputs = (hidden_states,) @@ -1281,6 +1336,139 @@ def forward( return outputs + def forward( + self, + hidden_states: paddle.Tensor, + position_ids: Optional[paddle.Tensor] = None, + attention_mask: 
Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = False, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + use_cache: Optional[bool] = False, + attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, + **kwargs, + ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + attn_outputs = self.attn( + hidden_states, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + attn_mask_startend_row_indices, + **kwargs, + ) + hidden_states = attn_outputs[0] + residual = attn_outputs[1] + self_attn_weights = attn_outputs[2] if output_attentions else None + present_key_value = attn_outputs[3] if use_cache else None + + hidden_states = self.mlp(hidden_states) + outputs = self.post_process( + hidden_states, residual, output_attentions, use_cache, self_attn_weights, present_key_value + ) + return outputs + + # def forward( + # self, + # hidden_states: paddle.Tensor, + # position_ids: Optional[paddle.Tensor] = None, + # attention_mask: Optional[paddle.Tensor] = None, + # output_attentions: Optional[bool] = False, + # past_key_value: Optional[Tuple[paddle.Tensor]] = None, + # use_cache: Optional[bool] = False, + # attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, + # **kwargs, + # ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]: + # """ + # Args: + # hidden_states (`paddle.Tensor`): input to the layer of shape `(batch, seq_len, embed_axis)` + # attention_mask (`paddle.Tensor`, *optional*): + # attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + # query_sequence_length, key_sequence_length)` if default attention is used. + # output_attentions (`bool`, *optional*): + # Whether or not to return the attentions tensors of all attention layers. See `attentions` under + # returned tensors for more detail. + # use_cache (`bool`, *optional*): + # If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + # (see `past_key_values`). + # past_key_value (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states + # """ + # if "padding_mask" in kwargs: + # warnings.warn( + # "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" + # ) + # residual = hidden_states + + # hidden_states = self.input_layernorm(hidden_states) + + # # Self Attention + # has_gradient = not hidden_states.stop_gradient + # if ( + # self.enable_recompute + # and self.layerwise_recompute + # and has_gradient + # and self.recompute_granularity == "full_attn" + # ): + # outputs = recompute( + # self.self_attn, + # hidden_states=hidden_states, + # position_ids=position_ids, + # attention_mask=attention_mask, + # output_attentions=output_attentions, + # past_key_value=past_key_value, + # use_cache=use_cache, + # attn_mask_startend_row_indices=attn_mask_startend_row_indices, + # **kwargs, + # ) + # else: + # outputs = self.self_attn( + # hidden_states=hidden_states, + # position_ids=position_ids, + # attention_mask=attention_mask, + # output_attentions=output_attentions, + # past_key_value=past_key_value, + # use_cache=use_cache, + # attn_mask_startend_row_indices=attn_mask_startend_row_indices, + # **kwargs, + # ) + + # if type(outputs) is tuple: + # hidden_states = outputs[0] + # else: + # hidden_states = outputs + + # if output_attentions: + # self_attn_weights = outputs[1] + + # if use_cache: + # present_key_value = outputs[2 if output_attentions else 1] + + # hidden_states = residual + hidden_states + + # # Fully Connected + # residual = hidden_states + # hidden_states = self.post_attention_layernorm(hidden_states) + # hidden_states = self.mlp(hidden_states) + # hidden_states = residual + hidden_states + + # outputs = (hidden_states,) + + # if output_attentions: + # outputs += (self_attn_weights,) + + # if use_cache: + # outputs += (present_key_value,) + + # if type(outputs) is tuple and len(outputs) == 1: + # outputs = outputs[0] + + # return outputs + class DeepseekV2MTPLayer(DeepseekV2DecoderLayer): def __init__( @@ -1800,6 +1988,8 @@ def forward( next_decoder_cache = () if use_cache else None mtp_outputs = [] + moelayer_use_subbatch_recompute = self.config.moe_subbatch_token_num > 0 + for idx in range(self.config.num_hidden_layers): decoder_layer = self.layers[idx] @@ -1809,7 +1999,17 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None has_gradient = not hidden_states.stop_gradient - if ( + if moelayer_use_subbatch_recompute: + layer_outputs = decoder_layer.subbatch_recompute_forward( + hidden_states, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + attn_mask_startend_row_indices, + ) + elif ( self.enable_recompute and idx not in self.no_recompute_layers and has_gradient diff --git a/paddlenlp/transformers/deepseek_v2/modeling_pp.py b/paddlenlp/transformers/deepseek_v2/modeling_pp.py index 4361023f6bbc..1bb77a737683 100644 --- a/paddlenlp/transformers/deepseek_v2/modeling_pp.py +++ b/paddlenlp/transformers/deepseek_v2/modeling_pp.py @@ -132,12 +132,14 @@ def forward(self, args): attention_mask = attention_mask[ :, :, : -self.config.num_nextn_predict_layers, : -self.config.num_nextn_predict_layers ] - + # attn_mask_startend_row_indices: [b, num_head, seq_len] or [b, num_head, seq_len, C], C is 2 or 4 if attn_mask_startend_row_indices is not None: if attn_mask_startend_row_indices.ndim == 3: attn_mask_startend_row_indices = attn_mask_startend_row_indices[ - :, :, : -self.config.num_nextn_predict_layers, + :, + :, + : -self.config.num_nextn_predict_layers, ] elif attn_mask_startend_row_indices.ndim == 4: attn_mask_startend_row_indices = attn_mask_startend_row_indices[ @@ -221,7 +223,15 @@ def forward(self, args): 
elif attn_mask_startend_row_indices is not None and attn_mask_startend_row_indices.dtype == paddle.int64: attn_mask_startend_row_indices, position_ids = None, attn_mask_startend_row_indices - if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient: + moelayer_use_subbatch_recompute = self.config.moe_subbatch_token_num > 0 + if moelayer_use_subbatch_recompute: + hidden_states = super().subbatch_recompute_forward( + hidden_states, + position_ids=position_ids, + attention_mask=attention_mask, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + ) + elif self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient: if attention_mask is not None or attn_mask_startend_row_indices is not None: hidden_states = recompute( super().forward, diff --git a/paddlenlp/transformers/moe_layer.py b/paddlenlp/transformers/moe_layer.py index 040fb5d1f22a..ed3a705ab245 100644 --- a/paddlenlp/transformers/moe_layer.py +++ b/paddlenlp/transformers/moe_layer.py @@ -362,6 +362,7 @@ def expert_forward(self, dispatched_input, tokens_per_expert): for chunk, expert in zip(chunks, self.experts): chunk = chunk.contiguous() # assert chunk.shape[0] != 0, "Cannot dispatch empty input" + # print("expert token:", chunk.shape, flush=True) outputs += [expert(chunk)] return paddle.concat(outputs, axis=0) From 4493f192d832690b2caa39337f809c9820900f95 Mon Sep 17 00:00:00 2001 From: Yichen Zhang <32740647+pkuzyc@users.noreply.github.com> Date: Thu, 28 Aug 2025 14:45:14 +0800 Subject: [PATCH 04/14] support sequence parallel in deepseek v3 model * support sequence parallel in deepseek v3 * polish, remove 'print' command --- llm/config/deepseek-v3/sft_argument.json | 14 ++-- paddlenlp/trainer/training_args.py | 4 +- .../transformers/deepseek_v2/modeling.py | 74 +++++++++++++++---- .../transformers/deepseek_v2/modeling_pp.py | 11 ++- paddlenlp/transformers/moe_layer.py | 2 +- 5 files changed, 76 insertions(+), 29 deletions(-) diff --git a/llm/config/deepseek-v3/sft_argument.json b/llm/config/deepseek-v3/sft_argument.json index 260f509e6f7c..3d614a77ffc2 100644 --- a/llm/config/deepseek-v3/sft_argument.json +++ b/llm/config/deepseek-v3/sft_argument.json @@ -1,19 +1,19 @@ { - "model_name_or_path": "/root/paddlejob/new_disk/huggingface_model/huggingface/deepseek-ai/DeepSeek-V3-bf16/", + "model_name_or_path": "./dsv3_128k_config", "dataset_name_or_path": "./data", "output_dir": "./checkpoints/sft_ckpts", "per_device_train_batch_size": 1, "gradient_accumulation_steps": 16, "per_device_eval_batch_size": 1, "eval_accumulation_steps": 1, - "num_train_epochs": 3, + "num_train_epochs": 1, "learning_rate": 2.2e-04, "warmup_steps": 30, "logging_steps": 1, "evaluation_strategy": "epoch", "save_strategy": "epoch", "src_length": 2048, - "max_length": 4097, + "max_length": 131073, "bf16": true, "fp16_opt_level": "O2", "do_train": true, @@ -31,9 +31,9 @@ "recompute_use_reentrant": true, "recompute_granularity": "full", "save_total_limit": 1, - "tensor_parallel_degree": 1, + "tensor_parallel_degree": 8, "pipeline_parallel_degree": 8, - "sharding_parallel_degree": 8, + "sharding_parallel_degree": 2, "sharding": "stage1", "zero_padding": true, "unified_checkpoint": true, @@ -45,6 +45,8 @@ "moe_subbatch_token_num": 1024, "pre_alloc_memory": 60, "tensorwise_offload_optimizer": true, - "autotuner_benchmark": true + "autotuner_benchmark": true, + "sequence_parallel": true, + "tensor_parallel_output": true } diff --git a/paddlenlp/trainer/training_args.py 
b/paddlenlp/trainer/training_args.py index 9f3c89f6fa23..b0fa3d278a78 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -1699,8 +1699,8 @@ def is_context_parallel_supported(): fleet.init(is_collective=True, strategy=strategy) logger.info(strategy) - if self.expert_parallel_degree > 1: - self.add_moe_comm_group() + # if self.expert_parallel_degree > 1: + # self.add_moe_comm_group() elif self.enable_auto_parallel: self.tensor_parallel_degree = max(self.tensor_parallel_degree, 1) diff --git a/paddlenlp/transformers/deepseek_v2/modeling.py b/paddlenlp/transformers/deepseek_v2/modeling.py index 66bbb459a639..a08899111b1b 100644 --- a/paddlenlp/transformers/deepseek_v2/modeling.py +++ b/paddlenlp/transformers/deepseek_v2/modeling.py @@ -34,6 +34,7 @@ from paddle import Tensor, nn from paddle.distributed import fleet from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker +from paddle.distributed.communication.reduce import ReduceOp from paddle.distributed.fleet.recompute.recompute import recompute from paddle.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -892,6 +893,10 @@ def __init__(self, config: DeepseekV2Config, layerwise_recompute: bool = False): self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads + self.num_local_heads = self.num_heads + if config.tensor_parallel_degree > 1: + assert self.num_heads % config.tensor_parallel_degree == 0, f"Attention head num ({self.num_heads}) is not divisible by tensor_parallel_degree ({config.tensor_parallel_degree})." + self.num_local_heads = self.num_heads // config.tensor_parallel_degree self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta @@ -940,18 +945,18 @@ def linear_dtype_gaurd(): if self.q_lora_rank is None: with linear_dtype_gaurd(): - self.q_proj = ColumnParallelLinear(self.hidden_size, self.num_heads * self.q_head_dim, has_bias=False, gather_output=True) + self.q_proj = ColumnParallelLinear(self.hidden_size, self.num_heads * self.q_head_dim, has_bias=False, gather_output=False) else: with linear_dtype_gaurd(): self.q_a_proj = Linear(self.hidden_size, config.q_lora_rank, bias_attr=config.attention_bias) - self.q_b_proj = ColumnParallelLinear(config.q_lora_rank, self.num_heads * self.q_head_dim, has_bias=False, gather_output=True) - self.q_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.q_lora_rank, use_sequence_parallel=False) + self.q_b_proj = ColumnParallelLinear(config.q_lora_rank, self.num_heads * self.q_head_dim, has_bias=False, gather_output=False) + self.q_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.q_lora_rank) with linear_dtype_gaurd(): self.kv_a_proj_with_mqa = Linear(self.hidden_size, config.kv_lora_rank + config.qk_rope_head_dim, bias_attr=config.attention_bias) - self.kv_b_proj = ColumnParallelLinear(config.kv_lora_rank, self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), has_bias=False, gather_output=True) - self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, self.hidden_size, has_bias=config.attention_bias, input_is_parallel=False) - self.kv_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.kv_lora_rank, use_sequence_parallel=False) + self.kv_b_proj = ColumnParallelLinear(config.kv_lora_rank, self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), has_bias=False, gather_output=False) + self.o_proj = 
RowParallelLinear(self.num_heads * self.v_head_dim, self.hidden_size, has_bias=config.attention_bias, input_is_parallel=True) + self.kv_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.kv_lora_rank) else: # for without tensor parallel if self.q_lora_rank is None: @@ -1047,7 +1052,11 @@ def forward( warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" ) - bsz, q_len, _ = hidden_states.shape + ori_shape = hidden_states.shape + if self.config.sequence_parallel: + seq_len, bsz, _ = hidden_states.shape + else: + bsz, seq_len, _ = hidden_states.shape # DeepSeekV2 q_lora_rank=1536 # DeepSeekV2-lite q_lora_rank=None @@ -1057,8 +1066,8 @@ def forward( q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) if self.sequence_parallel: - target_query_shape = [-1, self.seq_length, self.num_heads, self.q_head_dim] - target_key_value_shape = [-1, self.seq_length, self.num_heads, self.qk_nope_head_dim + self.v_head_dim] + target_query_shape = [bsz, self.seq_length, self.num_local_heads, self.q_head_dim] + target_key_value_shape = [bsz, self.seq_length, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim] else: target_query_shape = [0, 0, self.num_heads, self.q_head_dim] target_key_value_shape = [0, 0, self.num_heads, self.qk_nope_head_dim + self.v_head_dim] @@ -1071,8 +1080,9 @@ def forward( compressed_kv, k_pe = paddle.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], axis=-1) if self.sequence_parallel: k_pe = GatherOp.apply(k_pe) - k_pe = k_pe.reshape([-1, q_len, 1, self.qk_rope_head_dim]).expand( - [-1, q_len, self.num_heads, self.qk_rope_head_dim] + k_pe = paddle.transpose(k_pe, [1, 0, 2]) + k_pe = k_pe.reshape([-1, self.seq_length, 1, self.qk_rope_head_dim]).expand( + [-1, self.seq_length, self.num_local_heads, self.qk_rope_head_dim] ) # self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim = 128+64 @@ -1140,6 +1150,9 @@ def forward( # if sequence_parallel is true, out shape are [q_len / n, bs, num_head * head_dim] # else their shape are [bs, q_len, num_head * head_dim], n is mp parallelism. 
attn_output = self.o_proj(attn_output) + if attn_output.shape != ori_shape: + attn_output = attn_output.reshape(ori_shape) + if not output_attentions: attn_weights = None @@ -1975,7 +1988,8 @@ def forward( if self.config.sequence_parallel: # [bs, seq_len, num_head * head_dim] -> [bs * seq_len, num_head * head_dim] bs, seq_len, hidden_size = inputs_embeds.shape - inputs_embeds = paddle.reshape(inputs_embeds, [bs * seq_len, hidden_size]) + inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2]) # [B, S, H] --> [S, B, H] + # inputs_embeds = paddle.reshape(inputs_embeds, [bs * seq_len, hidden_size]) # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism) inputs_embeds = ScatterOp.apply(inputs_embeds) @@ -2057,7 +2071,8 @@ def forward( if self.config.sequence_parallel: hidden_states = GatherOp.apply(hidden_states) - hidden_states = hidden_states.reshape([-1, seq_length, hidden_states.shape[-1]]) + hidden_states = paddle.transpose(hidden_states, [1, 0, 2]) # [S, B, H] --> [B, S, H] + # hidden_states = hidden_states.reshape([-1, seq_length, hidden_states.shape[-1]]) inputs_embeds_cur_depth = paddle.concat( [inputs_embeds_ori[:, (nextn + 1) :, :], inputs_embeds_extra[:, : (nextn + 1), :]], axis=1 @@ -2122,7 +2137,12 @@ def __init__(self, config: DeepseekV2Config): else: self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=self.ignore_index) + if self.config.sequence_parallel: + self.seq_para_scale = 1.0 / self.config.tensor_parallel_degree + self.mp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group() + def forward(self, prediction_scores, masked_lm_labels, router_loss=None, mtp_logits=None): + if self.enable_parallel_cross_entropy: if prediction_scores.shape[-1] == self.config.vocab_size: warnings.warn( @@ -2151,13 +2171,33 @@ def add_loss(main_loss, loss): masked_lm_labels_ori = masked_lm_labels masked_lm_labels = masked_lm_labels[:, : -self.config.num_nextn_predict_layers] seq_length = masked_lm_labels.shape[1] + + if self.config.sequence_parallel: + masked_lm_labels = masked_lm_labels.transpose([1, 0]) # [B, S] --> [S, B] + masked_lm_labels = ScatterOp.apply(masked_lm_labels) + loss = compute_loss(prediction_scores, masked_lm_labels) + if self.config.sequence_parallel: + loss = loss * self.seq_para_scale + dist.all_reduce(loss, op=ReduceOp.SUM, group=self.mp_group) + mtp_loss_res = [] for depth in range(self.config.num_nextn_predict_layers): prediction_scores_cur_depth = mtp_logits[depth] masked_lm_labels_cur_depth = masked_lm_labels_ori[:, (depth + 1) : (depth + 1 + seq_length)] + + if self.config.sequence_parallel: + masked_lm_labels_cur_depth = masked_lm_labels_cur_depth.transpose([1, 0]) # [B, S] --> [S, B] + masked_lm_labels_cur_depth = ScatterOp.apply(masked_lm_labels_cur_depth) + res_cur_depth = compute_loss(prediction_scores_cur_depth, masked_lm_labels_cur_depth) + + if self.config.sequence_parallel: + res_cur_depth = res_cur_depth * self.seq_para_scale + dist.all_reduce(res_cur_depth, op=ReduceOp.SUM, group=self.mp_group) + + mtp_loss_res.append(res_cur_depth) loss = add_loss(loss, self.config.num_nextn_predict_lambda * sum([x for x in mtp_loss_res]) / len(mtp_loss_res)) # fmt: skip @@ -2203,9 +2243,11 @@ def __init__(self, config: DeepseekV2Config): self.xpu_parallel_matmul = None def forward(self, hidden_states, tensor_parallel_output=None): - if self.config.sequence_parallel: - hidden_states = GatherOp.apply(hidden_states) - hidden_states = paddle.reshape_(hidden_states, [-1, self.seq_length, self.config.hidden_size]) + + # if 
self.config.sequence_parallel: + # hidden_states = GatherOp.apply(hidden_states) + # hidden_states = paddle.transpose(hidden_states, [1, 0, 2]) + # hidden_states = paddle.reshape_(hidden_states, [-1, self.seq_length, self.config.hidden_size]) if tensor_parallel_output is None: tensor_parallel_output = self.config.tensor_parallel_output diff --git a/paddlenlp/transformers/deepseek_v2/modeling_pp.py b/paddlenlp/transformers/deepseek_v2/modeling_pp.py index 1bb77a737683..cf0ddc363e75 100644 --- a/paddlenlp/transformers/deepseek_v2/modeling_pp.py +++ b/paddlenlp/transformers/deepseek_v2/modeling_pp.py @@ -170,8 +170,9 @@ def forward(self, args): batch_size, seq_length, _ = inputs_embeds.shape if self.sequence_parallel: + inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2]) # [B, S, H] --> [S, B, H] # [bs, seq_len, num_head * head_dim] -> [bs * seq_len, num_head * head_dim] - inputs_embeds = paddle.reshape(inputs_embeds, [-1, inputs_embeds.shape[-1]]) + # inputs_embeds = paddle.reshape(inputs_embeds, [-1, inputs_embeds.shape[-1]]) # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism) inputs_embeds = ScatterOp.apply(inputs_embeds) embeds_res = [inputs_embeds] @@ -184,7 +185,8 @@ def forward(self, args): axis=1, ) if self.sequence_parallel: - inputs_embeds_mtp = inputs_embeds_mtp.reshape([-1, inputs_embeds_mtp.shape[-1]]) + inputs_embeds_mtp = paddle.transpose(inputs_embeds_mtp, [1, 0, 2]) # [B, S, H] --> [S, B, H] + # inputs_embeds_mtp = inputs_embeds_mtp.reshape([-1, inputs_embeds_mtp.shape[-1]]) inputs_embeds_mtp = ScatterOp.apply(inputs_embeds_mtp) embeds_res.append(inputs_embeds_mtp) # if not self.sequence_parallel @@ -195,7 +197,8 @@ def forward(self, args): return return_args(inputs_embeds, attention_mask, attn_mask_startend_row_indices, position_ids) else: if self.sequence_parallel: - inputs_embeds = inputs_embeds.reshape([-1, inputs_embeds.shape[-1]]) + inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2]) # [B, S, H] --> [S, B, H] + # inputs_embeds = inputs_embeds.reshape([-1, inputs_embeds.shape[-1]]) inputs_embeds = ScatterOp.apply(inputs_embeds) return return_args(inputs_embeds, attention_mask, attn_mask_startend_row_indices, position_ids) @@ -205,7 +208,7 @@ def forward(self, args): hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args) if self.config.num_nextn_predict_layers > 0: - batch_size, _, hidden_size = hidden_states.shape + hidden_size = hidden_states.shape[-1] batch_size_mtp = hidden_size // (self.config.num_nextn_predict_layers + 1) inputs_embeds_mtp = hidden_states[..., -batch_size_mtp:] hidden_states = hidden_states[..., :batch_size_mtp] diff --git a/paddlenlp/transformers/moe_layer.py b/paddlenlp/transformers/moe_layer.py index ed3a705ab245..f65f702d2a3f 100644 --- a/paddlenlp/transformers/moe_layer.py +++ b/paddlenlp/transformers/moe_layer.py @@ -180,7 +180,7 @@ def __init__( if moe_group == "data": self.moe_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group() elif moe_group == "expert": - self.moe_group = dist.fleet.get_hybrid_communicate_group().expert_parallel_group + self.moe_group = dist.fleet.get_hybrid_communicate_group().expert_parallel_group() else: assert NotImplementedError("moe_group can only be data or expert, but given {}".format(self.moe_group)) self.moe_rank = dist.get_rank(self.moe_group) From 546e1cb73bf7f1de3fc6baf4f41b3bbe87792fbc Mon Sep 17 00:00:00 2001 From: zhengzhonghui Date: Thu, 28 Aug 2025 20:41:05 +0800 Subject: [PATCH 05/14] enable master_grad && stage1 V2 
opt and polish some code --- llm/config/deepseek-v3/sft_argument.json | 26 ++++++++++--------- paddlenlp/trainer/training_args.py | 4 +-- .../transformers/deepseek_v2/modeling_pp.py | 4 +-- 3 files changed, 18 insertions(+), 16 deletions(-) diff --git a/llm/config/deepseek-v3/sft_argument.json b/llm/config/deepseek-v3/sft_argument.json index 3d614a77ffc2..cd5cb98714d8 100644 --- a/llm/config/deepseek-v3/sft_argument.json +++ b/llm/config/deepseek-v3/sft_argument.json @@ -1,29 +1,32 @@ { - "model_name_or_path": "./dsv3_128k_config", + "model_name_or_path": "/root/paddlejob/tmpspace/huggingface_model/huggingface/deepseek-ai/DeepSeek-V3-bf16/", "dataset_name_or_path": "./data", "output_dir": "./checkpoints/sft_ckpts", "per_device_train_batch_size": 1, "gradient_accumulation_steps": 16, "per_device_eval_batch_size": 1, "eval_accumulation_steps": 1, + "max_steps": 20, + "amp_master_grad": true, "num_train_epochs": 1, "learning_rate": 2.2e-04, "warmup_steps": 30, "logging_steps": 1, - "evaluation_strategy": "epoch", - "save_strategy": "epoch", + "evaluation_strategy": "steps", + "save_strategy": "steps", "src_length": 2048, - "max_length": 131073, + "max_length": 4097, "bf16": true, "fp16_opt_level": "O2", "do_train": true, - "do_eval": true, + "do_eval": false, "disable_tqdm": true, "use_expert_parallel": true, "expert_parallel_degree": 8, - "continue_training": true, + "continue_training": false, "pipeline_parallel_config": "enable_delay_scale_loss disable_partial_send_recv disable_batch_p2p_comm", "tensor_parallel_config": "enable_delay_scale_loss", + "sharding_parallel_config": "split_param", "load_best_model_at_end": true, "eval_with_do_generation": false, "metric_for_best_model": "loss", @@ -31,9 +34,9 @@ "recompute_use_reentrant": true, "recompute_granularity": "full", "save_total_limit": 1, - "tensor_parallel_degree": 8, - "pipeline_parallel_degree": 8, - "sharding_parallel_degree": 2, + "tensor_parallel_degree": 1, + "pipeline_parallel_degree": 16, + "sharding_parallel_degree": 8, "sharding": "stage1", "zero_padding": true, "unified_checkpoint": true, @@ -42,11 +45,10 @@ "using_fake_gate": true, "using_flex_token": true, "use_fused_rms_norm": true, - "moe_subbatch_token_num": 1024, + "moe_subbatch_token_num": 0, "pre_alloc_memory": 60, "tensorwise_offload_optimizer": true, - "autotuner_benchmark": true, - "sequence_parallel": true, + "sequence_parallel": false, "tensor_parallel_output": true } diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index b0fa3d278a78..9f3c89f6fa23 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -1699,8 +1699,8 @@ def is_context_parallel_supported(): fleet.init(is_collective=True, strategy=strategy) logger.info(strategy) - # if self.expert_parallel_degree > 1: - # self.add_moe_comm_group() + if self.expert_parallel_degree > 1: + self.add_moe_comm_group() elif self.enable_auto_parallel: self.tensor_parallel_degree = max(self.tensor_parallel_degree, 1) diff --git a/paddlenlp/transformers/deepseek_v2/modeling_pp.py b/paddlenlp/transformers/deepseek_v2/modeling_pp.py index cf0ddc363e75..45da8559a078 100644 --- a/paddlenlp/transformers/deepseek_v2/modeling_pp.py +++ b/paddlenlp/transformers/deepseek_v2/modeling_pp.py @@ -242,7 +242,7 @@ def forward(self, args): position_ids=position_ids, attention_mask=attention_mask, attn_mask_startend_row_indices=attn_mask_startend_row_indices, - use_reentrant=False, + use_reentrant=self.config.recompute_use_reentrant, ) else: # for pretrain @@ -300,7 
+300,7 @@ def forward(self, args): position_ids=position_ids, attention_mask=attention_mask, attn_mask_startend_row_indices=attn_mask_startend_row_indices, - use_reentrant=False, + use_reentrant=self.config.recompute_use_reentrant, ) else: # for pretrain From adc2f365778b5c697ffdc7f2ec8b441b37a3df3c Mon Sep 17 00:00:00 2001 From: Difer Date: Fri, 29 Aug 2025 12:00:51 +0800 Subject: [PATCH 06/14] fix some bugs --- .../transformers/deepseek_v2/modeling.py | 41 +++++++++++-------- paddlenlp/transformers/moe_gate.py | 14 +++++-- paddlenlp/transformers/moe_layer.py | 29 +++++++++++-- 3 files changed, 61 insertions(+), 23 deletions(-) diff --git a/paddlenlp/transformers/deepseek_v2/modeling.py b/paddlenlp/transformers/deepseek_v2/modeling.py index a08899111b1b..6e9f17f39205 100644 --- a/paddlenlp/transformers/deepseek_v2/modeling.py +++ b/paddlenlp/transformers/deepseek_v2/modeling.py @@ -33,8 +33,8 @@ import paddle.nn.functional as F from paddle import Tensor, nn from paddle.distributed import fleet -from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker from paddle.distributed.communication.reduce import ReduceOp +from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker from paddle.distributed.fleet.recompute.recompute import recompute from paddle.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -799,7 +799,7 @@ def __init__(self, config: DeepseekV2Config): for p in self.experts.parameters(): setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group}) - + setattr(p, "is_moe_param", True) self.alpha = config.aux_loss_alpha if config.n_shared_experts is not None: intermediate_size = config.moe_intermediate_size * config.n_shared_experts @@ -851,6 +851,7 @@ def __init__(self, config: DeepseekV2Config): for p in self.experts.parameters(): setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group}) + setattr(p, "is_moe_param", True) self.alpha = config.aux_loss_alpha if config.n_shared_experts is not None: @@ -895,7 +896,9 @@ def __init__(self, config: DeepseekV2Config, layerwise_recompute: bool = False): self.num_heads = config.num_attention_heads self.num_local_heads = self.num_heads if config.tensor_parallel_degree > 1: - assert self.num_heads % config.tensor_parallel_degree == 0, f"Attention head num ({self.num_heads}) is not divisible by tensor_parallel_degree ({config.tensor_parallel_degree})." + assert ( + self.num_heads % config.tensor_parallel_degree == 0 + ), f"Attention head num ({self.num_heads}) is not divisible by tensor_parallel_degree ({config.tensor_parallel_degree})." 
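A quick numeric check of the divisibility requirement enforced by this assertion (the numbers are illustrative, not taken from the patch):

    num_attention_heads = 128          # e.g. the DeepSeek-V3 head count
    tensor_parallel_degree = 8
    assert num_attention_heads % tensor_parallel_degree == 0
    num_local_heads = num_attention_heads // tensor_parallel_degree  # 16 heads per mp rank
    # a degree such as 3 would trip the assertion, since 128 % 3 != 0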
self.num_local_heads = self.num_heads // config.tensor_parallel_degree self.max_position_embeddings = config.max_position_embeddings @@ -1067,7 +1070,12 @@ def forward( if self.sequence_parallel: target_query_shape = [bsz, self.seq_length, self.num_local_heads, self.q_head_dim] - target_key_value_shape = [bsz, self.seq_length, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim] + target_key_value_shape = [ + bsz, + self.seq_length, + self.num_local_heads, + self.qk_nope_head_dim + self.v_head_dim, + ] else: target_query_shape = [0, 0, self.num_heads, self.q_head_dim] target_key_value_shape = [0, 0, self.num_heads, self.qk_nope_head_dim + self.v_head_dim] @@ -1153,7 +1161,6 @@ def forward( if attn_output.shape != ori_shape: attn_output = attn_output.reshape(ori_shape) - if not output_attentions: attn_weights = None @@ -1511,7 +1518,7 @@ def forward( hidden_states = self.hnorm(hidden_states) nextn_hidden_state = self.enorm(nextn_hidden_state) - hidden_states = self.eh_proj(paddle.concat([hidden_states, nextn_hidden_state], axis=-1)) + hidden_states = self.eh_proj(paddle.concat([nextn_hidden_state, hidden_states], axis=-1)) layer_outputs = super(DeepseekV2MTPLayer, self).forward( hidden_states, @@ -1711,10 +1718,13 @@ def get_tensor_parallel_split_mappings(num_layers): return final_actions - mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers) + mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers + 2) return mappings + def get_tensor_parallel_mappings(self, is_split=True): + return type(self)._get_tensor_parallel_mappings(self.config, is_split) + def _init_weights(self, layer): return if self.config.tensor_parallel_degree > 1: @@ -1988,7 +1998,7 @@ def forward( if self.config.sequence_parallel: # [bs, seq_len, num_head * head_dim] -> [bs * seq_len, num_head * head_dim] bs, seq_len, hidden_size = inputs_embeds.shape - inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2]) # [B, S, H] --> [S, B, H] + inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2]) # [B, S, H] --> [S, B, H] # inputs_embeds = paddle.reshape(inputs_embeds, [bs * seq_len, hidden_size]) # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism) inputs_embeds = ScatterOp.apply(inputs_embeds) @@ -2071,7 +2081,7 @@ def forward( if self.config.sequence_parallel: hidden_states = GatherOp.apply(hidden_states) - hidden_states = paddle.transpose(hidden_states, [1, 0, 2]) # [S, B, H] --> [B, S, H] + hidden_states = paddle.transpose(hidden_states, [1, 0, 2]) # [S, B, H] --> [B, S, H] # hidden_states = hidden_states.reshape([-1, seq_length, hidden_states.shape[-1]]) inputs_embeds_cur_depth = paddle.concat( @@ -2173,7 +2183,7 @@ def add_loss(main_loss, loss): seq_length = masked_lm_labels.shape[1] if self.config.sequence_parallel: - masked_lm_labels = masked_lm_labels.transpose([1, 0]) # [B, S] --> [S, B] + masked_lm_labels = masked_lm_labels.transpose([1, 0]) # [B, S] --> [S, B] masked_lm_labels = ScatterOp.apply(masked_lm_labels) loss = compute_loss(prediction_scores, masked_lm_labels) @@ -2188,16 +2198,15 @@ def add_loss(main_loss, loss): masked_lm_labels_cur_depth = masked_lm_labels_ori[:, (depth + 1) : (depth + 1 + seq_length)] if self.config.sequence_parallel: - masked_lm_labels_cur_depth = masked_lm_labels_cur_depth.transpose([1, 0]) # [B, S] --> [S, B] + masked_lm_labels_cur_depth = masked_lm_labels_cur_depth.transpose([1, 0]) # [B, S] --> [S, B] masked_lm_labels_cur_depth = ScatterOp.apply(masked_lm_labels_cur_depth) res_cur_depth = 
compute_loss(prediction_scores_cur_depth, masked_lm_labels_cur_depth) - + if self.config.sequence_parallel: res_cur_depth = res_cur_depth * self.seq_para_scale dist.all_reduce(res_cur_depth, op=ReduceOp.SUM, group=self.mp_group) - mtp_loss_res.append(res_cur_depth) loss = add_loss(loss, self.config.num_nextn_predict_lambda * sum([x for x in mtp_loss_res]) / len(mtp_loss_res)) # fmt: skip @@ -2245,9 +2254,9 @@ def __init__(self, config: DeepseekV2Config): def forward(self, hidden_states, tensor_parallel_output=None): # if self.config.sequence_parallel: - # hidden_states = GatherOp.apply(hidden_states) - # hidden_states = paddle.transpose(hidden_states, [1, 0, 2]) - # hidden_states = paddle.reshape_(hidden_states, [-1, self.seq_length, self.config.hidden_size]) + # hidden_states = GatherOp.apply(hidden_states) + # hidden_states = paddle.transpose(hidden_states, [1, 0, 2]) + # hidden_states = paddle.reshape_(hidden_states, [-1, self.seq_length, self.config.hidden_size]) if tensor_parallel_output is None: tensor_parallel_output = self.config.tensor_parallel_output diff --git a/paddlenlp/transformers/moe_gate.py b/paddlenlp/transformers/moe_gate.py index 0ccd96cc1618..4a526feb6acb 100644 --- a/paddlenlp/transformers/moe_gate.py +++ b/paddlenlp/transformers/moe_gate.py @@ -578,11 +578,19 @@ def topkgating_nodrop(self, gates: paddle.Tensor): # get topk mask mask = paddle.zeros_like(gates).put_along_axis(top_idx, paddle.to_tensor(1.0), axis=1) + # hongyu fix start + gates_masked = gates * mask + gates_s = paddle.sum(gates_masked, axis=-1, keepdim=True) + denom_s = paddle.clip(gates_s, min=paddle.finfo(gates_masked.dtype).eps) + + if self.norm_topk_prob: + gates_masked = gates_masked / denom_s + gates_masked *= self.routed_scaling_factor + # hongyu fix end if hasattr(self.config, "seq_aux") and self.config.seq_aux: l_aux = self._cal_seq_aux_loss(gates_ori, self.top_k, top_idx) else: l_aux = self._cal_aux_loss(gates, mask) - exp_counts = paddle.sum(mask.cast(paddle.int64), axis=0) - topk_masked_gates = paddle.zeros_like(gates).put_along_axis(top_idx, top_gate, axis=1) - return topk_masked_gates, mask, exp_counts, l_aux, l_zloss + # topk_masked_gates = paddle.zeros_like(gates).put_along_axis(top_idx, top_gate, axis=1) + return gates_masked, mask, exp_counts, l_aux, l_zloss diff --git a/paddlenlp/transformers/moe_layer.py b/paddlenlp/transformers/moe_layer.py index f65f702d2a3f..723bf525d6df 100644 --- a/paddlenlp/transformers/moe_layer.py +++ b/paddlenlp/transformers/moe_layer.py @@ -175,7 +175,6 @@ def __init__( is_fleet_init = True except AttributeError: is_fleet_init = False - if is_fleet_init and dist.get_world_size() > 1: if moe_group == "data": self.moe_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group() @@ -198,7 +197,6 @@ def __init__( self.expert_parallel_degree = 1 self.moe_num_experts_per_device = self.moe_num_experts self.is_dummy_moe = True - self.all_to_all_dropout = all_to_all_dropout self.enable_recompute = False @@ -348,10 +346,21 @@ def __init__(self, config, moe_num_experts, expert_class, expert_kwargs, gate, m self.moe_router_topk = gate.top_k self.moe_num_experts = moe_num_experts self.num_local_experts = moe_num_experts // self.ep_size + self.moe_rank = dist.get_rank(self.moe_group) + self.moe_rank = 0 if self.moe_rank < 0 else self.moe_rank self.token_dispatcher = MoEFlexTokenDispatcher( self.num_local_experts, self.moe_router_topk, self.moe_num_experts, moe_group ) - self.experts = nn.LayerList([expert_class(**expert_kwargs) for _ in 
range(self.num_local_experts)]) + self.expert_parallel_degree = 1 if self.ep_size < 0 else self.ep_size + self.moe_num_experts_per_device = self._parse_moe_expert_parallel( + self.moe_num_experts, self.expert_parallel_degree + ) + self.experts = nn.LayerList([]) + for i in range(self.moe_num_experts): + if i // self.moe_num_experts_per_device == self.moe_rank: + self.experts.append(expert_class(**expert_kwargs)) + else: + self.experts.append(None) self.router = gate def expert_forward(self, dispatched_input, tokens_per_expert): @@ -359,10 +368,12 @@ def expert_forward(self, dispatched_input, tokens_per_expert): tokens_per_expert = tokens_per_expert.tolist() # print(f"all tokens: {sum(tokens_per_expert)}, detail: {tokens_per_expert}") chunks = paddle.split(dispatched_input, num_or_sections=tokens_per_expert, axis=0) - for chunk, expert in zip(chunks, self.experts): + for i, chunk in enumerate(chunks): chunk = chunk.contiguous() # assert chunk.shape[0] != 0, "Cannot dispatch empty input" # print("expert token:", chunk.shape, flush=True) + # assert chunk.shape[0] != 0, "Cannot dispatch empty input" + expert = self.experts[i + self.moe_rank * self.moe_num_experts_per_device] outputs += [expert(chunk)] return paddle.concat(outputs, axis=0) @@ -377,3 +388,13 @@ def forward(self, hidden_states: paddle.Tensor): expert_output = self.expert_forward(dispatched_input, tokens_per_expert) output, _ = self.token_dispatcher.token_unpermutation(expert_output, None) return output, l_aux, l_zloss + + def _parse_moe_expert_parallel(self, moe_num_experts, expert_parallel_degree): + assert ( + moe_num_experts >= expert_parallel_degree + ), f"expert moe_num_experts={moe_num_experts} >= moe_world_size={expert_parallel_degree}" + assert ( + moe_num_experts % expert_parallel_degree == 0 + ), f"expert moe_num_experts={moe_num_experts} % moe_world_size={expert_parallel_degree} == 0" + moe_num_experts_per_device = moe_num_experts // expert_parallel_degree + return moe_num_experts_per_device From d1a3d888795667fb40575a68118f0ac97ef0cbfd Mon Sep 17 00:00:00 2001 From: Difer <707065510@qq.com> Date: Fri, 29 Aug 2025 12:07:31 +0800 Subject: [PATCH 07/14] add warm load (#11029) --- paddlenlp/trainer/trainer.py | 28 ++ paddlenlp/trainer/training_args.py | 7 +- paddlenlp/trainer/utils/load_utils.py | 258 ++++++++++++++++++ .../transformers/deepseek_v2/modeling_pp.py | 10 +- 4 files changed, 298 insertions(+), 5 deletions(-) create mode 100644 paddlenlp/trainer/utils/load_utils.py diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index b48679dbf26a..1e3b34bc8b0c 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -189,6 +189,7 @@ nested_numpify, nested_truncate, ) +from .utils.load_utils import load_paddle_model_from_safetensors from .utils.sharding_io import ShardingIO DEFAULT_CALLBACKS = [DefaultFlowCallback] @@ -1108,6 +1109,13 @@ def _inner_training_loop( if self.args.ignore_data_skip: self.timers and self.timers("read-data").start() + if self.args.hf_ckpt_dir is not None: + print("Start loading the Hugging Face model with warm start") + weight_map_path = os.path.join(self.args.hf_ckpt_dir, "model.safetensors.index.json") + ckpt_pre = self.args.hf_ckpt_dir + + load_paddle_model_from_safetensors(model, weight_map_path, ckpt_pre, verbose=True) + for epoch in range(epochs_trained, num_train_epochs): if isinstance(train_dataloader, paddle.io.DataLoader) and isinstance( train_dataloader.batch_sampler, DistributedBatchSampler @@ -1343,8 +1351,28 @@ def 
fused_allreduce_gradients_no_sync(paramlist, hcg): f"optimizer not run, scale_before: {scale_before_value[0]}, scale_after: {scale_after_value[0]}" ) elif isinstance(self.optimizer, HybridParallelOptimizer): + # print("hack for moe grad") + # for p in parameters_list: + # if getattr(p, 'is_moe_param', False): + # if p.grad is not None: + # # print(p.name, p.grad) + # p.grad /= 8 + # if p.main_grad is not None: + # # print(p.name, p.main_grad) + # p.main_grad /= 8 + self.optimizer._step(parameters_list) else: + # print("hack for moe gradr") + # for p in parameters_list: + # if getattr(p, 'is_moe_param', False): + # if p.grad is not None: + # print(p.name, p.grad) + # p.grad /= 4 + # if p.main_grad is not None: + # print(p.name, p.main_grad) + # p.main_grad /= 4 + self.optimizer.step() if self.args.offload_optim: diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index 9f3c89f6fa23..31f982e26851 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -1081,12 +1081,17 @@ class TrainingArguments: nccl_comm_group_config: Optional[str] = field( default=None, metadata={"help": "NCCL中通信组的细粒度控制的配置文件路径, 默认值为None, 代表不启用此项配置"} ) - + pre_alloc_memory: int = field( default=0, metadata={"help": "pre allocate memory size GB"}, ) + hf_ckpt_dir: Optional[str] = field( + default=None, + metadata={"help": "huggingface checkpoint dir"}, + ) + def __post_init__(self): world_size = paddle.distributed.get_world_size() if in_auto_parallel_align_mode(): diff --git a/paddlenlp/trainer/utils/load_utils.py b/paddlenlp/trainer/utils/load_utils.py new file mode 100644 index 000000000000..312dae7a6755 --- /dev/null +++ b/paddlenlp/trainer/utils/load_utils.py @@ -0,0 +1,258 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import re +from collections import defaultdict +from typing import List, Optional + +import paddle +from paddle.distributed import fleet +from safetensors import safe_open + +# develop: "_layers.." 
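The name-mapping helpers defined below can be sanity-checked with a few examples; the layer numbers here are illustrative, and the expected values follow directly from the regexes and prefix rules that come next:

    # Expected behaviour of paddle_name_to_hf_names (illustrative layer indices):
    assert paddle_name_to_hf_names("_layers.0.embed_tokens.weight") == ["model.embed_tokens.weight"]
    assert paddle_name_to_hf_names("_layers.1.self_attn.q_a_proj.weight") == ["model.layers.0.self_attn.q_a_proj.weight"]
    assert paddle_name_to_hf_names("_layers.5.mlp.router.weight") == ["model.layers.4.mlp.gate.weight"]
    assert paddle_name_to_hf_names("_layers.5.mlp.experts.7.w1") == [
        "model.layers.4.mlp.experts.7.gate_proj.weight",
        "model.layers.4.mlp.experts.7.up_proj.weight",
    ]
    assert paddle_name_to_hf_names("_layers.64.weight") == ["lm_head.weight"]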
+_LAYER_RE = re.compile(r"^_layers\.(\d+)(?:\.(.*))?$")
+_EXPERT_W1_RE = re.compile(r"^mlp\.experts\.(\d+)\.w1(?:\.weight)?$")
+_EXPERT_W2_RE = re.compile(r"^mlp\.experts\.(\d+)\.w2(?:\.weight)?$")
+
+custom_name_map = {
+    "mlp.router.weight": "mlp.gate.weight",
+    "mlp.router.e_score_correction_bias": "mlp.gate.e_score_correction_bias",
+}
+
+
+def _layers_match(name: str):
+    return _LAYER_RE.match(name)
+
+
+def simple_safe_call(model, method_name, *args, **kwargs):
+    if hasattr(model, method_name):
+        return getattr(model, method_name)(*args, **kwargs)
+    if hasattr(model, "_layers") and hasattr(model._layers, method_name):
+        return getattr(model._layers, method_name)(*args, **kwargs)
+    raise AttributeError(f"{type(model).__name__} (or its wrapper) has no method {method_name}")
+
+
+def add_prefix_to_keys(d, prefix):
+    print("Input dict:", d)
+
+    mappings = {}
+    for key, value in d.items():
+        if key == "embed_tokens.weight":
+            new_key = "_layers.0.embed_tokens.weight"
+        elif key == "lm_head.weight":
+            new_key = "_layers.64.weight"
+        else:
+            new_key = f"{prefix}{key}"
+        mappings[new_key] = value
+    return mappings
+
+
+def _get_hf_prefix_develop(idx: int) -> str:
+    if idx == 0:
+        return "model"  # embedding
+    if idx == 63:
+        return "model"  # final norm
+    if idx == 64:
+        return "lm_head"  # lm_head
+    return f"model.layers.{idx - 1}"  # decoder layer
+
+
+def _handle_expert_weights(hf_prefix: str, rest: str) -> Optional[List[str]]:
+    if m := _EXPERT_W1_RE.match(rest):
+        expert_id = int(m.group(1))
+        return [
+            f"{hf_prefix}.mlp.experts.{expert_id}.gate_proj.weight",
+            f"{hf_prefix}.mlp.experts.{expert_id}.up_proj.weight",
+        ]
+    if m := _EXPERT_W2_RE.match(rest):
+        expert_id = int(m.group(1))
+        return [
+            f"{hf_prefix}.mlp.experts.{expert_id}.down_proj.weight",
+        ]
+    return None
+
+
+def _handle_mlp_weights(hf_prefix: str, rest: str) -> Optional[List[str]]:
+    if rest == "mlp.w1":
+        return [
+            f"{hf_prefix}.mlp.gate_proj.weight",
+            f"{hf_prefix}.mlp.up_proj.weight",
+        ]
+    if rest == "mlp.w2":
+        return [
+            f"{hf_prefix}.mlp.down_proj.weight",
+        ]
+    return None
+
+
+def paddle_name_to_hf_names(paddle_name: str) -> List[str]:
+    """
+    Map a Paddle (pipeline-parallel) parameter name to the corresponding Hugging Face parameter name(s).
+    """
+    m = _layers_match(paddle_name)
+    if not m:
+        return []
+    idx = int(m.group(1))
+    rest = m.group(2) or ""
+
+    hf_prefix = _get_hf_prefix_develop(idx)
+
+    # special-case renames
+    if rest in custom_name_map:
+        return [f"{hf_prefix}.{custom_name_map[rest]}"]
+
+    # legacy fused expert weights (w1/w2)
+    if expert_names := _handle_expert_weights(hf_prefix, rest):
+        return expert_names
+
+    # legacy fused mlp weights (w1/w2)
+    if mlp_names := _handle_mlp_weights(hf_prefix, rest):
+        return mlp_names
+
+    return [f"{hf_prefix}.{rest}"] if rest else [hf_prefix]
+
+
+def prepare_tensor(tensor, pd_param, tensor_parallel_mappings, mp_degree, dst_shape):
+    """
+    Convert a weight tensor to match the destination parameter shape, transposing,
+    concatenating and slicing as needed, and applying the tensor-parallel split when mp_degree > 1.
+    """
+
+    if isinstance(tensor, list):
+        tensor = paddle.concat(
+            [
+                paddle.transpose(tensor[0], perm=[1, 0]).contiguous(),
+                paddle.transpose(tensor[1], perm=[1, 0]).contiguous(),
+            ],
+            axis=-1,
+        )
+    # transpose if needed so that the tensor matches the destination layout
+    if len(tensor.shape) == 2:
+        if (tensor.shape[0] == dst_shape[1] or tensor.shape[1] == dst_shape[0]) and tensor.shape != dst_shape:
+            tensor = paddle.transpose(tensor, perm=[1, 0]).contiguous()
+            print(f"after transpose get hf tensor shape {tensor.shape}, paddle shape {dst_shape}")
+
+    if mp_degree > 1 and pd_param in tensor_parallel_mappings:
+        tensor = tensor_parallel_mappings[pd_param](tensor)
+    if tensor.shape == dst_shape:
+        return tensor
+    raise ValueError(f"Unexpected tensor shape: got {tensor.shape}, want {dst_shape}")
+
+
+def load_paddle_model_from_safetensors(
+    model,
+    weight_map_path: str,
+    ckpt_pre: str,
+    verbose: bool = True,
+):
+    """
+    Load safetensors weights into a Paddle model, using the weight map from model.safetensors.index.json.
+    """
+
+    tensor_parallel_mappings = {}
+    mp_degree = fleet.get_hybrid_communicate_group().get_model_parallel_world_size()
+    print("mp degree:", mp_degree)
+
+    if mp_degree > 1:
+        print("load with mp_degree:", mp_degree)
+        tensor_parallel_mappings = simple_safe_call(model, "get_tensor_parallel_mappings", is_split=True)
+        tensor_parallel_mappings = add_prefix_to_keys(tensor_parallel_mappings, "_")
+
+        for k, v in tensor_parallel_mappings.items():
+            print("tensor_parallel_mappings:", k, v)
+
+    with open(weight_map_path, "r") as f:
+        weight_map = json.load(f)["weight_map"]
+
+    required_files = set()
+    file_to_pd_param_name = defaultdict(list)
+    pd_param_name_to_file = defaultdict(list)
+
+    for pd_name, _ in model.named_parameters():
+        hf_names = paddle_name_to_hf_names(pd_name)
+        if verbose:
+            print(f"paddle_name_to_hf_names: {pd_name} -> {hf_names}")
+        if not hf_names:
+            if verbose:
+                print(f"Warning: {pd_name} cannot be mapped")
+            continue
+        for i, hf_name in enumerate(hf_names):
+            if hf_name in weight_map:
+                filename = weight_map[hf_name]
+                required_files.add(filename)
+                file_to_pd_param_name[filename].append(pd_name)
+                if filename not in pd_param_name_to_file[pd_name]:
+                    pd_param_name_to_file[pd_name].append(filename)
+            else:
+                if verbose:
+                    print(f"Warning: {pd_name} -> {hf_name} not found in weight map")
+
+    check_list = []
+    if verbose:
+        print("---- start load param ----")
+    for key, value in tensor_parallel_mappings.items():
+        print(key, value)
+    for filename in required_files:
+        try:
+            with safe_open(ckpt_pre + filename, framework="paddle", device="cpu") as f:
+                pd_params = file_to_pd_param_name[filename]
+                for pd_param in pd_params:
+                    if pd_param in check_list:
+                        continue
+                    if verbose:
+                        print("load for pd_param:", pd_param)
+                    hf_names = paddle_name_to_hf_names(pd_param)
+                    if not hf_names:
+                        continue
+                    if len(hf_names) == 1:
+                        tensor = f.get_tensor(hf_names[0])
+                        value = prepare_tensor(
+                            tensor, pd_param, tensor_parallel_mappings, mp_degree, model.state_dict()[pd_param].shape
+                        )
+
+                        model.state_dict()[pd_param].set_value(paddle.cast(value, model.state_dict()[pd_param].dtype))
+                    else:
+                        files = pd_param_name_to_file[pd_param]
+                        if len(files) == 1:
+                            tensor0 = f.get_tensor(hf_names[0])
+                            tensor1 = f.get_tensor(hf_names[1])
+                        else:
+                            if weight_map[hf_names[0]] == filename:
+                                tensor0 = f.get_tensor(hf_names[0])
+                                with safe_open(
+                                    ckpt_pre + weight_map[hf_names[1]], framework="paddle", device="cpu"
+                                ) as f2:
+                                    tensor1 = f2.get_tensor(hf_names[1])
+                            else:
+                                with safe_open(
+                                    ckpt_pre +
weight_map[hf_names[0]], framework="paddle", device="cpu" + ) as f2: + tensor0 = f2.get_tensor(hf_names[0]) + tensor1 = f.get_tensor(hf_names[1]) + value = prepare_tensor( + [tensor0, tensor1], + pd_param, + tensor_parallel_mappings, + mp_degree, + model.state_dict()[pd_param].shape, + ) + model.state_dict()[pd_param].set_value(value) + check_list.append(pd_param) + except Exception as e: + print(f"Error loading {filename}: {str(e)}") + raise + + if verbose: + print("All parameters loaded.") diff --git a/paddlenlp/transformers/deepseek_v2/modeling_pp.py b/paddlenlp/transformers/deepseek_v2/modeling_pp.py index 45da8559a078..3a178c9728eb 100644 --- a/paddlenlp/transformers/deepseek_v2/modeling_pp.py +++ b/paddlenlp/transformers/deepseek_v2/modeling_pp.py @@ -170,7 +170,7 @@ def forward(self, args): batch_size, seq_length, _ = inputs_embeds.shape if self.sequence_parallel: - inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2]) # [B, S, H] --> [S, B, H] + inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2]) # [B, S, H] --> [S, B, H] # [bs, seq_len, num_head * head_dim] -> [bs * seq_len, num_head * head_dim] # inputs_embeds = paddle.reshape(inputs_embeds, [-1, inputs_embeds.shape[-1]]) # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism) @@ -185,7 +185,7 @@ def forward(self, args): axis=1, ) if self.sequence_parallel: - inputs_embeds_mtp = paddle.transpose(inputs_embeds_mtp, [1, 0, 2]) # [B, S, H] --> [S, B, H] + inputs_embeds_mtp = paddle.transpose(inputs_embeds_mtp, [1, 0, 2]) # [B, S, H] --> [S, B, H] # inputs_embeds_mtp = inputs_embeds_mtp.reshape([-1, inputs_embeds_mtp.shape[-1]]) inputs_embeds_mtp = ScatterOp.apply(inputs_embeds_mtp) embeds_res.append(inputs_embeds_mtp) @@ -197,7 +197,7 @@ def forward(self, args): return return_args(inputs_embeds, attention_mask, attn_mask_startend_row_indices, position_ids) else: if self.sequence_parallel: - inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2]) # [B, S, H] --> [S, B, H] + inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2]) # [B, S, H] --> [S, B, H] # inputs_embeds = inputs_embeds.reshape([-1, inputs_embeds.shape[-1]]) inputs_embeds = ScatterOp.apply(inputs_embeds) return return_args(inputs_embeds, attention_mask, attn_mask_startend_row_indices, position_ids) @@ -270,7 +270,6 @@ def forward(self, args): class DeepseekV2MTPLayerPipe(DeepseekV2MTPLayer): def forward(self, args): hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args) - hidden_states_list = paddle.split(hidden_states, self.config.num_nextn_predict_layers + 1, axis=-1) hidden_states_main_model = hidden_states_list[0] inputs_embeds_cur_depth_list = hidden_states_list[1:] @@ -525,3 +524,6 @@ def get_hcg(): def get_loss_fn(self, config): return DeepseekV2PretrainingCriterionPipe(config) + + def get_tensor_parallel_mappings(self, is_split=True): + return type(self)._get_tensor_parallel_mappings(self.config, is_split) From 991a573528037fff82ced4d45580389e373af09b Mon Sep 17 00:00:00 2001 From: Yichen Zhang <32740647+pkuzyc@users.noreply.github.com> Date: Fri, 29 Aug 2025 19:00:36 +0800 Subject: [PATCH 08/14] fix bug in subbatch and add 128k sft config (#11032) --- llm/config/deepseek-v3/sft_128k_argument.json | 54 +++++++++++++++++++ .../transformers/deepseek_v2/modeling.py | 31 ++++++----- .../transformers/deepseek_v2/modeling_pp.py | 12 ++--- 3 files changed, 78 insertions(+), 19 deletions(-) create mode 100644 llm/config/deepseek-v3/sft_128k_argument.json diff --git 
a/llm/config/deepseek-v3/sft_128k_argument.json b/llm/config/deepseek-v3/sft_128k_argument.json new file mode 100644 index 000000000000..f7aa58cd5d0f --- /dev/null +++ b/llm/config/deepseek-v3/sft_128k_argument.json @@ -0,0 +1,54 @@ +{ + "model_name_or_path": "./dsv3_128k_config", + "dataset_name_or_path": "./data", + "output_dir": "./checkpoints/sft_ckpts", + "per_device_train_batch_size": 1, + "gradient_accumulation_steps": 16, + "per_device_eval_batch_size": 1, + "eval_accumulation_steps": 1, + "num_train_epochs": 1, + "max_steps": 20, + "learning_rate": 2.2e-04, + "warmup_steps": 30, + "logging_steps": 1, + "evaluation_strategy": "epoch", + "save_strategy": "epoch", + "src_length": 2048, + "max_length": 131073, + "bf16": true, + "fp16_opt_level": "O2", + "do_train": true, + "do_eval": true, + "disable_tqdm": true, + "use_expert_parallel": true, + "expert_parallel_degree": 8, + "continue_training": false, + "pipeline_parallel_config": "enable_delay_scale_loss disable_partial_send_recv disable_batch_p2p_comm", + "tensor_parallel_config": "enable_delay_scale_loss", + "load_best_model_at_end": true, + "eval_with_do_generation": false, + "metric_for_best_model": "loss", + "recompute": true, + "recompute_use_reentrant": true, + "recompute_granularity": "full", + "save_total_limit": 1, + "tensor_parallel_degree": 8, + "pipeline_parallel_degree": 16, + "sharding_parallel_degree": 1, + "sharding": "stage1", + "zero_padding": true, + "unified_checkpoint": true, + "use_flash_attention": true, + "flash_mask": true, + "using_fake_gate": true, + "using_flex_token": true, + "use_fused_rms_norm": true, + "moe_subbatch_token_num": 1024, + "pre_alloc_memory": 70, + "tensorwise_offload_optimizer": true, + "sequence_parallel": true, + "tensor_parallel_output": true, + "amp_master_grad": true, + "sharding_parallel_config": "split_param" +} + diff --git a/paddlenlp/transformers/deepseek_v2/modeling.py b/paddlenlp/transformers/deepseek_v2/modeling.py index 6e9f17f39205..39724e5f46c8 100644 --- a/paddlenlp/transformers/deepseek_v2/modeling.py +++ b/paddlenlp/transformers/deepseek_v2/modeling.py @@ -1227,6 +1227,7 @@ def subbatch_recompute_forward( past_key_value, use_cache, attn_mask_startend_row_indices, + **offload_kwargs, ) hidden_states = attn_outputs[0] @@ -1236,22 +1237,22 @@ def subbatch_recompute_forward( assert len(hidden_states.shape) == 3 sub_seq_len = self.config.moe_subbatch_token_num - seq_len = hidden_states.shape[1] + seq_axis = 0 if self.config.sequence_parallel else 1 + seq_len = hidden_states.shape[seq_axis] assert seq_len % sub_seq_len == 0 num_chunks = seq_len // sub_seq_len split_list = [sub_seq_len] * num_chunks - input_list = paddle.split(hidden_states, split_list, axis=1) + input_list = paddle.split(hidden_states, split_list, axis=seq_axis) output_list = [] + for chunk in input_list: - offload_kwargs = {} - offload_kwargs["offload_indices"] = [0] out = recompute( self.mlp.forward, chunk, **offload_kwargs, ) output_list.append(out) - hidden_states = paddle.concat(output_list, axis=1) + hidden_states = paddle.concat(output_list, axis=seq_axis) outputs = recompute( self.post_process, hidden_states, @@ -1929,7 +1930,7 @@ def forward( if attention_mask is not None: attention_mask = attention_mask[ :, :, : -self.config.num_nextn_predict_layers, : -self.config.num_nextn_predict_layers - ] + ].contiguous() # attn_mask_startend_row_indices: [b, num_head, seq_len] or [b, num_head, seq_len, C], C is 2 or 4 if attn_mask_startend_row_indices is not None: @@ -1938,11 +1939,11 @@ def forward( :, :, 
: -self.config.num_nextn_predict_layers, - ] + ].contiguous() elif attn_mask_startend_row_indices.ndim == 4: attn_mask_startend_row_indices = attn_mask_startend_row_indices[ :, :, : -self.config.num_nextn_predict_layers, : - ] + ].contiguous() else: raise ValueError("attn_mask_startend_row_indices must be 3D or 4D tensor") @@ -2004,7 +2005,7 @@ def forward( inputs_embeds = ScatterOp.apply(inputs_embeds) # embed positions - hidden_states = inputs_embeds + hidden_states = inputs_embeds.contiguous() # decoder layers all_hidden_states = () if output_hidden_states else None @@ -2167,10 +2168,18 @@ def compute_loss(preds, labels): masked_lm_loss > 0, paddle.ones_like(masked_lm_loss), paddle.zeros_like(masked_lm_loss) ) count = paddle.sum(binary_sequence) + + if self.config.sequence_parallel: + dist.all_reduce(count, op=ReduceOp.SUM, group=self.mp_group) + if count == 0: loss = paddle.sum(masked_lm_loss * binary_sequence) else: loss = paddle.sum(masked_lm_loss * binary_sequence) / count + + if self.config.sequence_parallel: + dist.all_reduce(loss, op=ReduceOp.SUM, group=self.mp_group) + return loss def add_loss(main_loss, loss): @@ -2188,10 +2197,6 @@ def add_loss(main_loss, loss): loss = compute_loss(prediction_scores, masked_lm_labels) - if self.config.sequence_parallel: - loss = loss * self.seq_para_scale - dist.all_reduce(loss, op=ReduceOp.SUM, group=self.mp_group) - mtp_loss_res = [] for depth in range(self.config.num_nextn_predict_layers): prediction_scores_cur_depth = mtp_logits[depth] diff --git a/paddlenlp/transformers/deepseek_v2/modeling_pp.py b/paddlenlp/transformers/deepseek_v2/modeling_pp.py index 3a178c9728eb..8e53c39d4c33 100644 --- a/paddlenlp/transformers/deepseek_v2/modeling_pp.py +++ b/paddlenlp/transformers/deepseek_v2/modeling_pp.py @@ -72,14 +72,14 @@ def parse_args(args): def return_args(hidden_states, attention_mask=None, attn_mask_startend_row_indices=None, position_ids=None): - ret = (hidden_states,) + ret = (hidden_states.contiguous(),) if attention_mask is not None: - ret += (attention_mask.clone(),) + ret += (attention_mask.contiguous().clone(),) if attn_mask_startend_row_indices is not None: - ret += (attn_mask_startend_row_indices.clone(),) + ret += (attn_mask_startend_row_indices.contiguous().clone(),) if position_ids is not None: - ret += (position_ids.clone(),) + ret += (position_ids.contiguous().clone(),) if len(ret) == 1: ret = ret[0] @@ -210,8 +210,8 @@ def forward(self, args): if self.config.num_nextn_predict_layers > 0: hidden_size = hidden_states.shape[-1] batch_size_mtp = hidden_size // (self.config.num_nextn_predict_layers + 1) - inputs_embeds_mtp = hidden_states[..., -batch_size_mtp:] - hidden_states = hidden_states[..., :batch_size_mtp] + inputs_embeds_mtp = hidden_states[..., -batch_size_mtp:].contiguous() + hidden_states = hidden_states[..., :batch_size_mtp].contiguous() has_gradient = not hidden_states.stop_gradient From 7adac11d86ae13e849b06adcfd42259deb474f8d Mon Sep 17 00:00:00 2001 From: deepllz Date: Mon, 1 Sep 2025 15:42:32 +0800 Subject: [PATCH 09/14] compatible with lastest paddle develop branch && update SFT train config to get better performance --- llm/config/deepseek-v3/sft_argument.json | 11 +++---- paddlenlp/trainer/utils/ckpt_converter.py | 39 +++++++++++++++++------ paddlenlp/transformers/moe_utils.py | 13 +++++--- 3 files changed, 43 insertions(+), 20 deletions(-) diff --git a/llm/config/deepseek-v3/sft_argument.json b/llm/config/deepseek-v3/sft_argument.json index cd5cb98714d8..766f4682ac01 100644 --- 
a/llm/config/deepseek-v3/sft_argument.json +++ b/llm/config/deepseek-v3/sft_argument.json @@ -12,8 +12,8 @@ "learning_rate": 2.2e-04, "warmup_steps": 30, "logging_steps": 1, - "evaluation_strategy": "steps", - "save_strategy": "steps", + "evaluation_strategy": "no", + "save_strategy": "no", "src_length": 2048, "max_length": 4097, "bf16": true, @@ -22,7 +22,7 @@ "do_eval": false, "disable_tqdm": true, "use_expert_parallel": true, - "expert_parallel_degree": 8, + "expert_parallel_degree": 16, "continue_training": false, "pipeline_parallel_config": "enable_delay_scale_loss disable_partial_send_recv disable_batch_p2p_comm", "tensor_parallel_config": "enable_delay_scale_loss", @@ -35,8 +35,8 @@ "recompute_granularity": "full", "save_total_limit": 1, "tensor_parallel_degree": 1, - "pipeline_parallel_degree": 16, - "sharding_parallel_degree": 8, + "pipeline_parallel_degree": 8, + "sharding_parallel_degree": 16, "sharding": "stage1", "zero_padding": true, "unified_checkpoint": true, @@ -51,4 +51,3 @@ "sequence_parallel": false, "tensor_parallel_output": true } - diff --git a/paddlenlp/trainer/utils/ckpt_converter.py b/paddlenlp/trainer/utils/ckpt_converter.py index 23f085e18f44..648dfc3b42dc 100644 --- a/paddlenlp/trainer/utils/ckpt_converter.py +++ b/paddlenlp/trainer/utils/ckpt_converter.py @@ -19,16 +19,35 @@ from typing import List, Union import paddle -from paddle.distributed.checkpoint.load_state_dict import ( - _load_state_dict, - get_rank_to_read_files, -) -from paddle.distributed.checkpoint.metadata import ( - LocalTensorIndex, - LocalTensorMetadata, - Metadata, -) -from paddle.distributed.checkpoint.utils import flatten_state_dict + +try: + from paddle.distributed.checkpoint.load_state_dict import ( + _load_state_dict, + get_rank_to_read_files, + ) +except ImportError: + from paddle.distributed.flex_checkpoint.dcp.load_state_dict import ( + _load_state_dict, + get_rank_to_read_files, + ) + +try: + from paddle.distributed.checkpoint.metadata import ( + LocalTensorIndex, + LocalTensorMetadata, + Metadata, + ) +except ImportError: + from paddle.distributed.flex_checkpoint.dcp.metadata import ( + LocalTensorIndex, + LocalTensorMetadata, + Metadata, + ) + +try: + from paddle.distributed.checkpoint.utils import flatten_state_dict +except ImportError: + from paddle.distributed.flex_checkpoint.dcp.utils import flatten_state_dict from paddle.distributed.fleet.utils.log_util import logger MODEL_WEIGHT_SUFFIX = ".pdparams" diff --git a/paddlenlp/transformers/moe_utils.py b/paddlenlp/transformers/moe_utils.py index d82654f6375b..4743c70e8149 100644 --- a/paddlenlp/transformers/moe_utils.py +++ b/paddlenlp/transformers/moe_utils.py @@ -96,8 +96,13 @@ def unpermute( # Create an output tensor filled with zeros output_tokens = paddle.zeros(restore_shape, dtype=permuted_tokens.dtype) # Scatter add the permuted_input back to the original positions - if scatter_add_ is not None: - scatter_add_(output_tokens, sorted_indices, permuted_tokens) - else: - output_tokens.scatter_(index=sorted_indices, updates=permuted_tokens, overwrite=False) + # if scatter_add_ is not None: + # # NOTE: this expand will cause a big memory usage, so disable this method + # sorted_indices = sorted_indices.unsqueeze(1).expand(-1, hidden) + # output_tokens.scatter_add_(0, sorted_indices, permuted_tokens) + # else: + # NOTE: Calling multiple times of scatter_ will not accumulate, + # Instead, it reset to zero and then accumulated again. + # so can't do subbatch here. 
+ output_tokens.scatter_(index=sorted_indices, updates=permuted_tokens, overwrite=False) return output_tokens From a0fcbdcd9a6ab5ca2035747b53a3666da59d183d Mon Sep 17 00:00:00 2001 From: zhengzhonghui Date: Mon, 1 Sep 2025 16:35:57 +0800 Subject: [PATCH 10/14] update dsv3 SFT 128K train config to get better performance (#11037) --- llm/config/deepseek-v3/sft_128k_argument.json | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/llm/config/deepseek-v3/sft_128k_argument.json b/llm/config/deepseek-v3/sft_128k_argument.json index f7aa58cd5d0f..5ef15ec31109 100644 --- a/llm/config/deepseek-v3/sft_128k_argument.json +++ b/llm/config/deepseek-v3/sft_128k_argument.json @@ -1,5 +1,5 @@ { - "model_name_or_path": "./dsv3_128k_config", + "model_name_or_path": "/root/paddlejob/tmpspace/huggingface_model/huggingface/deepseek-ai/DeepSeek-V3-bf16/", "dataset_name_or_path": "./data", "output_dir": "./checkpoints/sft_ckpts", "per_device_train_batch_size": 1, @@ -11,17 +11,17 @@ "learning_rate": 2.2e-04, "warmup_steps": 30, "logging_steps": 1, - "evaluation_strategy": "epoch", - "save_strategy": "epoch", + "evaluation_strategy": "no", + "save_strategy": "no", "src_length": 2048, "max_length": 131073, "bf16": true, "fp16_opt_level": "O2", "do_train": true, - "do_eval": true, + "do_eval": false, "disable_tqdm": true, "use_expert_parallel": true, - "expert_parallel_degree": 8, + "expert_parallel_degree": 16, "continue_training": false, "pipeline_parallel_config": "enable_delay_scale_loss disable_partial_send_recv disable_batch_p2p_comm", "tensor_parallel_config": "enable_delay_scale_loss", @@ -33,8 +33,8 @@ "recompute_granularity": "full", "save_total_limit": 1, "tensor_parallel_degree": 8, - "pipeline_parallel_degree": 16, - "sharding_parallel_degree": 1, + "pipeline_parallel_degree": 8, + "sharding_parallel_degree": 2, "sharding": "stage1", "zero_padding": true, "unified_checkpoint": true, @@ -44,7 +44,7 @@ "using_flex_token": true, "use_fused_rms_norm": true, "moe_subbatch_token_num": 1024, - "pre_alloc_memory": 70, + "pre_alloc_memory": 60, "tensorwise_offload_optimizer": true, "sequence_parallel": true, "tensor_parallel_output": true, From ad9e95bd7024297bc389fa4dec67fbc4000b4cb5 Mon Sep 17 00:00:00 2001 From: Yichen Zhang <32740647+pkuzyc@users.noreply.github.com> Date: Wed, 3 Sep 2025 17:20:43 +0800 Subject: [PATCH 11/14] adapt global_norm_clip for hybrid expert parallel (#11061) --- paddlenlp/optimizers/__init__.py | 15 + .../moe_hybrid_parallel_optimizer.py | 431 ++++++++++++++++++ paddlenlp/trainer/trainer.py | 30 +- paddlenlp/trainer/training_args.py | 3 - .../transformers/deepseek_v2/modeling.py | 19 +- 5 files changed, 490 insertions(+), 8 deletions(-) create mode 100644 paddlenlp/optimizers/__init__.py create mode 100644 paddlenlp/optimizers/moe_hybrid_parallel_optimizer.py diff --git a/paddlenlp/optimizers/__init__.py b/paddlenlp/optimizers/__init__.py new file mode 100644 index 000000000000..5091716a263e --- /dev/null +++ b/paddlenlp/optimizers/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .moe_hybrid_parallel_optimizer import MoEHybridParallelOptimizer diff --git a/paddlenlp/optimizers/moe_hybrid_parallel_optimizer.py b/paddlenlp/optimizers/moe_hybrid_parallel_optimizer.py new file mode 100644 index 000000000000..47dac4714624 --- /dev/null +++ b/paddlenlp/optimizers/moe_hybrid_parallel_optimizer.py @@ -0,0 +1,431 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logger +import paddle +import paddle.distributed as dist +from paddle.autograd import no_grad +from paddle.distributed.fleet.base.topology import ParallelMode +from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import ( + DygraphShardingOptimizer, + DygraphShardingOptimizerV2, +) +from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.hybrid_parallel_optimizer import ( + HybridParallelOptimizer as HPBase, +) +from paddle.distributed.fleet.utils import timer_helper as timer +from paddle.distributed.fleet.utils.hybrid_parallel_util import unwrap_optimizer +from paddle.distributed.fleet.utils.mix_precision_utils import MixPrecisionOptimizer +from paddle.framework import core +from paddle.nn import ClipGradByGlobalNorm, clip + +__all__ = [] + + +class MoEHybridParallelClipGrad: + def __init__(self, clip, hcg, timers=None): + self._clip = clip + self._hcg = hcg + if hasattr(hcg, "get_moe_sharding_parallel_world_size") and hcg.get_moe_sharding_parallel_world_size() > 0: + # hybrid expert parallel + self.moe_group = hcg.get_expert_parallel_group() + self.moe_sharding_group = hcg.get_moe_sharding_parallel_group() + + self.stat = {} # for logging + self._timers = timers + self.processed_steps = 0 + + def _global_norm( + self, global_norm_var_dist, global_norm_var_not_dist, global_norm_var_dist_moe, global_norm_var_not_dist_moe + ): + # sharding first + sharding_flag = self._hcg.get_sharding_parallel_world_size() > 1 + mp_flag = self._hcg.get_model_parallel_world_size() > 1 + pp_flag = self._hcg.get_pipe_parallel_world_size() > 1 + + """do comm""" + # logger.info( + # f"before reduce: dist-moe-grad-norm={global_norm_var_dist_moe.item()} " + # f"before reduce: non-dist-moe-grad-norm={global_norm_var_not_dist_moe.item()}" + # ) + + if self.moe_sharding_group: + dist.all_reduce( + global_norm_var_dist_moe, + op=dist.ReduceOp.SUM, + group=self.moe_sharding_group, + ) + dist.all_reduce( + global_norm_var_not_dist_moe, + op=dist.ReduceOp.SUM, + group=self.moe_sharding_group, + ) + + if self.moe_group: + dist.all_reduce( + global_norm_var_dist_moe, + op=dist.ReduceOp.SUM, + group=self.moe_group, + 
) + dist.all_reduce( + global_norm_var_not_dist_moe, + op=dist.ReduceOp.SUM, + group=self.moe_group, + ) + + if pp_flag: + paddle.distributed.all_reduce( + global_norm_var_dist_moe, + group=self._hcg.get_pipe_parallel_group(), + ) + paddle.distributed.all_reduce( + global_norm_var_not_dist_moe, + group=self._hcg.get_pipe_parallel_group(), + ) + + # logger.info( + # f"after reduce: dist-moe-grad-norm={global_norm_var_dist_moe.item()} " + # f"after reduce: non-dist-moe-grad-norm={global_norm_var_not_dist_moe.item()}" + # ) + + # logger.info( + # f"before reduce: dist-grad-norm={global_norm_var_dist.item()} " + # f"before reduce: non-dist-grad-norm={global_norm_var_not_dist.item()}" + # ) + # add all reduce to get global norm of distributed params_and_grads + if sharding_flag: + # norm of mp distributed variable + if mp_flag: + # dist should reduce among sharding group、mp group、pp group + paddle.distributed.all_reduce( + global_norm_var_dist, + group=self._hcg.get_sharding_parallel_group(), + ) + # not dist only reduce among sharding group and pp group later + paddle.distributed.all_reduce( + global_norm_var_not_dist, + group=self._hcg.get_sharding_parallel_group(), + ) + + # norm of mp distributed variable + if mp_flag: + # dist should reduce among sharding group、mp group、pp group + paddle.distributed.all_reduce( + global_norm_var_dist, + group=self._hcg.get_model_parallel_group(), + ) + if pp_flag: + paddle.distributed.all_reduce( + global_norm_var_dist, + group=self._hcg.get_pipe_parallel_group(), + ) + + # add all reduce to get global norm of non-distributed params_and_grads in groups of pp + if pp_flag: + paddle.distributed.all_reduce( + global_norm_var_not_dist, + group=self._hcg.get_pipe_parallel_group(), + ) + + # logger.info( + # f"after reduce: dist-grad-norm={global_norm_var_dist.item()} " + # f"after reduce: non-dist-grad-norm={global_norm_var_not_dist.item()}" + # ) + + @no_grad() + def _dygraph_clip(self, params_grads): + if self._timers: + self._timers("dygraph-clip").start() + sum_square_dist_fp16 = [] + sum_square_dist_bf16 = [] + sum_square_dist_fp32 = [] + + sum_square_dist_moe_fp16 = [] + sum_square_dist_moe_bf16 = [] + sum_square_dist_moe_fp32 = [] + + sum_square_not_dist_fp16 = [] + sum_square_not_dist_bf16 = [] + sum_square_not_dist_fp32 = [] + + sum_square_not_dist_moe_fp16 = [] + sum_square_not_dist_moe_bf16 = [] + sum_square_not_dist_moe_fp32 = [] + + for p, g in params_grads: + if g is None: + continue + if getattr(p, "need_clip", True) is False: + continue + merge_grad = g + if g.type == core.VarDesc.VarType.SELECTED_ROWS: + merge_grad = clip.merge_selected_rows(g) + merge_grad = clip.get_tensor_from_selected_rows(merge_grad) + sum_square = clip._squared_l2_norm(merge_grad) + + not_shared_enable = (not hasattr(p, "is_firstly_shared")) or ( + hasattr(p, "is_firstly_shared") and getattr(p, "is_firstly_shared", True) + ) + + is_moe_param = getattr(p, "is_moe_param", False) + print(f"p.name:{p.name}, is_moe_param:{is_moe_param}") + if is_moe_param: + assert 0 + if not_shared_enable: + if getattr(p, "no_sync", False): + if p.is_distributed: + if g.dtype == paddle.float16: + sum_square_dist_moe_fp16.append(sum_square) + elif g.dtype == paddle.bfloat16: + sum_square_dist_moe_bf16.append(sum_square) + elif g.dtype == paddle.float32: + sum_square_dist_moe_fp32.append(sum_square) + else: + if g.dtype == paddle.float16: + sum_square_not_dist_moe_fp16.append(sum_square) + elif g.dtype == paddle.bfloat16: + sum_square_not_dist_moe_bf16.append(sum_square) + elif g.dtype == 
paddle.float32: + sum_square_not_dist_moe_fp32.append(sum_square) + + elif p.is_distributed: + if g.dtype == paddle.float16: + sum_square_dist_fp16.append(sum_square) + elif g.dtype == paddle.bfloat16: + sum_square_dist_bf16.append(sum_square) + elif g.dtype == paddle.float32: + sum_square_dist_fp32.append(sum_square) + else: + assert not getattr( + p, "no_sync", False + ), f"moe param shoud be distributed, got: {p.name}, shape={p.shape}" + if g.dtype == paddle.float16: + sum_square_not_dist_fp16.append(sum_square) + if g.dtype == paddle.bfloat16: + sum_square_not_dist_bf16.append(sum_square) + elif g.dtype == paddle.float32: + sum_square_not_dist_fp32.append(sum_square) + else: + assert not getattr(p, "no_sync", False), "MoE cannot handle shared param" + + # assert ( + # sum_square_dist_moe_fp16 + # or sum_square_dist_moe_bf16 + # or sum_square_dist_moe_fp32 + # or sum_square_not_dist_moe_fp16 + # or sum_square_not_dist_moe_bf16 + # or sum_square_not_dist_moe_fp32 + # ), f"no moe param found" + + def add_n_list(tensor_list): + if not tensor_list: + return paddle.zeros((1,), dtype=paddle.float32) + return paddle.add_n(tensor_list).cast(paddle.float32) + + # moe global norm of distributed FP16 params_and_grads + global_norm_dist_moe_fp16 = add_n_list( + sum_square_dist_moe_fp16, + ) + global_norm_not_dist_moe_fp16 = add_n_list( + sum_square_not_dist_moe_fp16, + ) + global_norm_dist_fp16 = add_n_list( + sum_square_dist_fp16, + ) + global_norm_not_dist_fp16 = add_n_list( + sum_square_not_dist_fp16, + ) + + global_norm_dist_moe_bf16 = add_n_list( + sum_square_dist_moe_bf16, + ) + global_norm_not_dist_moe_bf16 = add_n_list( + sum_square_not_dist_moe_bf16, + ) + global_norm_dist_bf16 = add_n_list( + sum_square_dist_bf16, + ) + global_norm_not_dist_bf16 = add_n_list( + sum_square_not_dist_bf16, + ) + + global_norm_dist_moe_fp32 = add_n_list( + sum_square_dist_moe_fp32, + ) + global_norm_not_dist_moe_fp32 = add_n_list( + sum_square_not_dist_moe_fp32, + ) + global_norm_dist_fp32 = add_n_list( + sum_square_dist_fp32, + ) + global_norm_not_dist_fp32 = add_n_list( + sum_square_not_dist_fp32, + ) + + global_norm_var_dist_moe = global_norm_dist_moe_fp16 + global_norm_dist_moe_bf16 + global_norm_dist_moe_fp32 + + global_norm_var_not_dist_moe = ( + global_norm_not_dist_moe_fp16 + global_norm_not_dist_moe_bf16 + global_norm_not_dist_moe_fp32 + ) + + global_norm_var_dist = global_norm_dist_fp16 + global_norm_dist_bf16 + global_norm_dist_fp32 + global_norm_var_not_dist = global_norm_not_dist_fp16 + global_norm_not_dist_bf16 + global_norm_not_dist_fp32 + result = self._comm_and_clip( + params_grads, + global_norm_var_dist, + global_norm_var_not_dist, + global_norm_var_dist_moe, + global_norm_var_not_dist_moe, + ) + if self._timers: + self._timers("dygraph-clip").stop() + + return result + + def _comm_and_clip( + self, + params_grads, + global_norm_var_dist, + global_norm_var_not_dist, + global_norm_var_dist_moe, + global_norm_var_not_dist_moe, + ): + + self._global_norm( + global_norm_var_dist, global_norm_var_not_dist, global_norm_var_dist_moe, global_norm_var_not_dist_moe + ) + + global_norm_var_fp32 = paddle.sqrt( + global_norm_var_dist + global_norm_var_not_dist + global_norm_var_dist_moe + global_norm_var_not_dist_moe + ) + self.stat["global_grad_norm"] = global_norm_var_fp32.astype("float32").item() + + max_global_norm = paddle.full( + shape=[], + dtype=global_norm_var_fp32.dtype, + fill_value=self.clip_norm, + ) + clip_var = paddle.divide( + x=max_global_norm, + y=paddle.maximum(x=global_norm_var_fp32, 
y=max_global_norm) + + paddle.full(shape=[], dtype=paddle.float32, fill_value=1.0e-6), + ) + logger.info(f"hybrid-moe-clip, var={clip_var.item()}, global_norm:{global_norm_var_fp32.item()}") + clip_var_fp16 = paddle.cast(clip_var, paddle.float16) + + if ( + not isinstance(paddle.framework._current_expected_place(), paddle.CustomPlace) + or paddle.framework._current_expected_place().get_device_type() == "npu" + ): + clip_var_bf16 = paddle.cast(clip_var, paddle.bfloat16) + for p, g in params_grads: + if g is None: + continue + if getattr(p, "need_clip", True) is False: + continue + if g.dtype == paddle.float16: + g.multiply_(clip_var_fp16) + elif g.dtype == paddle.bfloat16: + if paddle.is_compiled_with_xpu(): + raise NotImplementedError("BF16 is not supported on XPU now") + g.multiply_(clip_var_bf16) + else: + g.multiply_(clip_var) + p._reset_grad_inplace_version(True) + + return params_grads + + def __getattr__(self, item): + return getattr(self._clip, item) + + def __call__(self, params_grads): + print("==== zyc debug in moe_hybrid_parallel_optimizer.py ====") + for p, g in params_grads: + has_moe_attr = hasattr(p, "is_moe_param") + is_moe_param = False + if has_moe_attr: + is_moe_param = p.is_moe_param + print(f"p.name:{p.name}, has_moe_attr:{has_moe_attr}, is_moe_param:{is_moe_param}") + return self._dygraph_clip(params_grads) + + +class MoEHybridParallelOptimizer(HPBase): + # adapter wrapper for optimizer + def __init__(self, optimizer, hcg, strategy): + # Note: Only sharding stage 1 is considered in HybridParallelOptimizer. + # The sharding stage2 and stage3 optimizers are invoked in other api. + print( + f"moe_sharding_degree:{hcg.get_moe_sharding_parallel_world_size()}, sharding_degree:{hcg.get_sharding_parallel_world_size()}, ep_degree:{hcg.get_expert_parallel_world_size()}" + ) + if hcg.get_moe_sharding_parallel_world_size() > 0: + split_param = strategy.hybrid_configs["sharding_configs"].split_param + assert ( + hcg.get_sharding_parallel_world_size() >= 1 and split_param is True + ), "Hybrid expert parallel only supports ShardingV2 now" + if hcg.get_sharding_parallel_world_size() > 1: + split_param = strategy.hybrid_configs["sharding_configs"].split_param + ShardingOptimizer = DygraphShardingOptimizerV2 if split_param else DygraphShardingOptimizer + optimizer = ShardingOptimizer(optimizer, hcg) + + self._enable_timer = strategy.hybrid_configs["enable_optimizer_timer"] + + if self._enable_timer: + if not timer.is_timer_initialized(): + timer.set_timers() + self._timers = timer.get_timers() + else: + self._timers = None + + self._inner_opt = optimizer + self._strategy = strategy + self._hcg = hcg + + self._use_dp_mode = self._hcg.get_parallel_mode() == ParallelMode.DATA_PARALLEL + + self._need_dp = self._hcg.get_data_parallel_world_size() > 1 + + self._dp_enable = not self._use_dp_mode and self._need_dp + + self._sharding_enable = self._hcg.get_sharding_parallel_world_size() > 1 + + self._sep_enable = self._hcg.get_sep_parallel_world_size() > 1 + + if isinstance(self._inner_opt._grad_clip, ClipGradByGlobalNorm) and not self._use_dp_mode: + logger.warning( + "While using ClipGradByGlobalNorm in TensorParallel, PipelineParallel " + "or Sharding, the grad clip of original optimizer will be changed." 
+ ) + + inner_opt = unwrap_optimizer( + self._inner_opt, + ( + MixPrecisionOptimizer, + DygraphShardingOptimizer, + DygraphShardingOptimizerV2, + ), + ) + + if ( + inner_opt._parameter_list + and not isinstance(inner_opt._parameter_list[0], dict) + and len([p for p in inner_opt._parameter_list if hasattr(p, "main_grad")]) > 0 + ): + inner_opt._grad_clip = MoEHybridParallelClipGrad(inner_opt._grad_clip, hcg, self._timers) + else: + inner_opt._grad_clip = MoEHybridParallelClipGrad(inner_opt._grad_clip, hcg, self._timers) + if inner_opt._parameter_list and isinstance(inner_opt._parameter_list[0], dict): + for item in inner_opt._param_groups: + if "grad_clip" in item.keys(): + item["grad_clip"] = MoEHybridParallelClipGrad(inner_opt._grad_clip, hcg, self._timers) + self.processed_steps = 0 diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index 1e3b34bc8b0c..be2dada70a3b 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -2196,6 +2196,30 @@ def _decorate_exclude_layers(self, model: nn.Layer): exclude_layers = [] return exclude_layers + def _wrap_distributed_optimizer(self, optimizer): + """ + In hybrid expert parallel, use customized optimizer and grad clip + """ + if ( + self.args.use_expert_parallel + and self.args.moe_sharding_parallel_degree >= 1 + and self.args.expert_parallel_degree > 1 + ): + from paddlenlp.optimizers import MoEHybridParallelOptimizer + + fleet_env = fleet.fleet + fleet_env.user_defined_optimizer = optimizer + hp_optim = MoEHybridParallelOptimizer(optimizer, fleet_env._hcg, fleet_env._user_defined_strategy) + + if fleet_env._user_defined_strategy.hybrid_configs["pp_configs"].dp_comm_overlap: + hp_optim._dp_enable = False + + if fleet_env._user_defined_strategy.hybrid_configs["pp_configs"].sharding_comm_overlap: + hp_optim._sharding_enable = False + return hp_optim + else: + return fleet.distributed_optimizer(optimizer) + def _wrap_model(self, model, training=True): # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again @@ -2311,7 +2335,7 @@ def get_expected_keys(inputs, keys): assert self.optimizer is not None, "Pipeline mode need decorate optimizer, pelease init optimizer." if self.args.amp_master_grad: self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer) - self.optimizer = fleet.distributed_optimizer(self.optimizer) + self.optimizer = self._wrap_distributed_optimizer(self.optimizer) if ( hasattr(self.args, "enable_sharding_comm_overlap") @@ -2342,7 +2366,7 @@ def get_expected_keys(inputs, keys): if self.args.amp_master_grad: self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer) - self.optimizer = fleet.distributed_optimizer(self.optimizer) + self.optimizer = self._wrap_distributed_optimizer(self.optimizer) else: cpu_offload = ShardingOption.OFFLOAD in self.args.sharding assert self.optimizer is not None, "optimizer is empty!" @@ -2400,7 +2424,7 @@ def get_expected_keys(inputs, keys): assert self.optimizer is not None, "Tensor parallel mode need decorate optimizer, pelease init optimizer." 
if self.args.amp_master_grad: self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer) - self.optimizer = fleet.distributed_optimizer(self.optimizer) + self.optimizer = self._wrap_distributed_optimizer(self.optimizer) # stage1 has v1 and v2 version if in_sharding_parallel_mode and ShardingOption.SHARD_OP in self.args.sharding: diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index 31f982e26851..c6971f0a8aeb 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -1704,9 +1704,6 @@ def is_context_parallel_supported(): fleet.init(is_collective=True, strategy=strategy) logger.info(strategy) - if self.expert_parallel_degree > 1: - self.add_moe_comm_group() - elif self.enable_auto_parallel: self.tensor_parallel_degree = max(self.tensor_parallel_degree, 1) self.sep_parallel_degree = max(self.sep_parallel_degree, 1) diff --git a/paddlenlp/transformers/deepseek_v2/modeling.py b/paddlenlp/transformers/deepseek_v2/modeling.py index 39724e5f46c8..c824871c9955 100644 --- a/paddlenlp/transformers/deepseek_v2/modeling.py +++ b/paddlenlp/transformers/deepseek_v2/modeling.py @@ -797,9 +797,17 @@ def __init__(self, config: DeepseekV2Config): moe_group="expert", ) + self.is_mp_moe = False + self.is_ep_moe = True for p in self.experts.parameters(): - setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group}) setattr(p, "is_moe_param", True) + setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group}) + p.no_sync = not self.is_mp_moe + p.expert = not self.is_mp_moe + logger.info(f"expert no-sync={p.no_sync}-{p.name}") + if self.is_mp_moe or self.is_ep_moe: + p.is_distributed = True + self.alpha = config.aux_loss_alpha if config.n_shared_experts is not None: intermediate_size = config.moe_intermediate_size * config.n_shared_experts @@ -849,9 +857,16 @@ def __init__(self, config: DeepseekV2Config): moe_group=moe_group, ) + self.is_mp_moe = False + self.is_ep_moe = True for p in self.experts.parameters(): - setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group}) setattr(p, "is_moe_param", True) + setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group}) + p.no_sync = not self.is_mp_moe + p.expert = not self.is_mp_moe + logger.info(f"expert no-sync={p.no_sync}-{p.name}") + if self.is_mp_moe or self.is_ep_moe: + p.is_distributed = True self.alpha = config.aux_loss_alpha if config.n_shared_experts is not None: From 7e317e6ca97448015b7b3cfa52b4184f13d9fff7 Mon Sep 17 00:00:00 2001 From: zhengzhonghui Date: Wed, 3 Sep 2025 19:42:07 +0800 Subject: [PATCH 12/14] fix aux_loss_alpha && lr value too big problem and add aux update callback and add mtp subatch_recompute (#11062) * fix ep grad * fix aux_loss_alpha && lr value too big problem and add aux update callback and add mtp subatch_recompute * fix logger error --- llm/config/deepseek-v3/sft_argument.json | 11 ++-- llm/run_finetune.py | 22 +++++-- .../moe_hybrid_parallel_optimizer.py | 2 +- paddlenlp/trainer/trainer.py | 35 +++++----- paddlenlp/trainer/trainer_callback.py | 65 +++++++++++++++++++ .../transformers/deepseek_v2/configuration.py | 2 +- .../transformers/deepseek_v2/modeling.py | 43 ++++++++++++ .../transformers/deepseek_v2/modeling_pp.py | 11 +++- paddlenlp/transformers/moe_gate.py | 2 +- paddlenlp/trl/model_config.py | 1 + 10 files changed, 160 insertions(+), 34 deletions(-) diff --git a/llm/config/deepseek-v3/sft_argument.json b/llm/config/deepseek-v3/sft_argument.json index 766f4682ac01..edc8452fe09e 100644 --- 
a/llm/config/deepseek-v3/sft_argument.json +++ b/llm/config/deepseek-v3/sft_argument.json @@ -1,15 +1,18 @@ { "model_name_or_path": "/root/paddlejob/tmpspace/huggingface_model/huggingface/deepseek-ai/DeepSeek-V3-bf16/", - "dataset_name_or_path": "./data", + "hf_ckpt_dir": "/root/paddlejob/tmpspace/huggingface_model/huggingface/deepseek-ai/DeepSeek-V3-bf16/", + "dataset_name_or_path": "/root/paddlejob/workspace/env_run/zhengzhonghui/data_100k/", "output_dir": "./checkpoints/sft_ckpts", "per_device_train_batch_size": 1, "gradient_accumulation_steps": 16, "per_device_eval_batch_size": 1, "eval_accumulation_steps": 1, - "max_steps": 20, + "max_steps": 100, + "max_grad_norm": 0, "amp_master_grad": true, "num_train_epochs": 1, - "learning_rate": 2.2e-04, + "learning_rate": 2.2e-05, + "aux_loss_alpha": 0.0001, "warmup_steps": 30, "logging_steps": 1, "evaluation_strategy": "no", @@ -42,7 +45,7 @@ "unified_checkpoint": true, "use_flash_attention": true, "flash_mask": true, - "using_fake_gate": true, + "using_fake_gate": false, "using_flex_token": true, "use_fused_rms_norm": true, "moe_subbatch_token_num": 0, diff --git a/llm/run_finetune.py b/llm/run_finetune.py index b129cf27df30..de40634ef379 100644 --- a/llm/run_finetune.py +++ b/llm/run_finetune.py @@ -46,7 +46,7 @@ ReFTModel, intervention_mapping, ) -from paddlenlp.trainer import PdArgumentParser, get_last_checkpoint, set_seed +from paddlenlp.trainer import PdArgumentParser, get_last_checkpoint, set_seed, MoECorrectionBiasAdjustCallback from paddlenlp.trainer.trainer_callback import TrainerState from paddlenlp.transformers import ( AutoConfig, @@ -262,6 +262,11 @@ def main(): if model_args.strategy_name == "YaRNScalingRotaryEmbedding": model_config.long_sequence_init_args["original_max_position_embeddings"] = data_args.max_length + + model_config.using_flex_token = model_args.using_flex_token + model_config.using_fake_gate = model_args.using_fake_gate + model_config.moe_subbatch_token_num = model_args.moe_subbatch_token_num + model_config.aux_loss_alpha = model_args.aux_loss_alpha logger.info(f"Final model config: {model_config}") logger.info("Creating model") @@ -272,10 +277,7 @@ def main(): raise ValueError("Please set eval_with_do_generation to false in pipeline parallel mode.") model_class = AutoModelForCausalLMPipe - model_config.using_flex_token = model_args.using_flex_token - model_config.using_fake_gate = model_args.using_fake_gate - model_config.moe_subbatch_token_num = model_args.moe_subbatch_token_num - print("model_config ", model_config, flush=True) + if model_args.continue_training and not training_args.autotuner_benchmark: model = model_class.from_pretrained( model_args.model_name_or_path, @@ -467,6 +469,14 @@ def compute_metrics_do_generation(eval_preds): return_attention_mask=not model_args.flash_mask, pad_to_multiple_of=data_args.pad_to_multiple_of, ) + callbacks = [] + if isinstance(train_ds, ZeroPaddingIterableDataset): + callbacks += [ZeroPaddingIterDatasetCallback()] + + if getattr(model_config, "topk_method", None) == "noaux_tc": + callbacks += [MoECorrectionBiasAdjustCallback(lr=0)] + + print("callbacks:", callbacks, flush=True) trainer = SFTTrainer( model=model, args=training_args, @@ -476,7 +486,7 @@ def compute_metrics_do_generation(eval_preds): compute_metrics=metrics, data_collator=data_collator_fn if not model_args.reft else ReftDataCollator(data_collator=data_collator_fn), do_generation=data_args.eval_with_do_generation, - callbacks=[ZeroPaddingIterDatasetCallback()] if isinstance(train_ds, 
ZeroPaddingIterableDataset) else None, + callbacks=callbacks, gen_args=gen_args, data_args=data_args, ) diff --git a/paddlenlp/optimizers/moe_hybrid_parallel_optimizer.py b/paddlenlp/optimizers/moe_hybrid_parallel_optimizer.py index 47dac4714624..b11263fe90e6 100644 --- a/paddlenlp/optimizers/moe_hybrid_parallel_optimizer.py +++ b/paddlenlp/optimizers/moe_hybrid_parallel_optimizer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logger +from paddle.distributed.fleet.utils.log_util import logger import paddle import paddle.distributed as dist from paddle.autograd import no_grad diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index be2dada70a3b..db23594b7bc1 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -1323,6 +1323,21 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg): elif p.grad is not None: p.grad.scale_(1.0 / self.args.gradient_accumulation_steps) + if os.environ.get("FIX_EP_GRAD", None): + param_count = 0 + for p in model._layers.parameters(): + if hasattr(p, "is_moe_param") and p.is_moe_param: + with paddle.no_grad(): + if hasattr(p, "main_grad") and p.main_grad is not None: + # print("main grad scale 1/ep") + p.main_grad.scale_(1.0 / self.args.expert_parallel_degree) + param_count += 1 + elif p.grad is not None: + # print("grad scale 1/ep") + p.grad.scale_(1.0 / self.args.expert_parallel_degree) + param_count += 1 + print("fix ep grad count:{}".format(param_count), flush=True) + # Optimizer step self.callback_handler.on_optimizer_begin( args, self.state, self.control, scaler=self.scaler if self.do_grad_scaling else None @@ -1351,28 +1366,8 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg): f"optimizer not run, scale_before: {scale_before_value[0]}, scale_after: {scale_after_value[0]}" ) elif isinstance(self.optimizer, HybridParallelOptimizer): - # print("hack for moe grad") - # for p in parameters_list: - # if getattr(p, 'is_moe_param', False): - # if p.grad is not None: - # # print(p.name, p.grad) - # p.grad /= 8 - # if p.main_grad is not None: - # # print(p.name, p.main_grad) - # p.main_grad /= 8 - self.optimizer._step(parameters_list) else: - # print("hack for moe gradr") - # for p in parameters_list: - # if getattr(p, 'is_moe_param', False): - # if p.grad is not None: - # print(p.name, p.grad) - # p.grad /= 4 - # if p.main_grad is not None: - # print(p.name, p.main_grad) - # p.main_grad /= 4 - self.optimizer.step() if self.args.offload_optim: diff --git a/paddlenlp/trainer/trainer_callback.py b/paddlenlp/trainer/trainer_callback.py index 67584dcd1c0e..8f9beaab813e 100644 --- a/paddlenlp/trainer/trainer_callback.py +++ b/paddlenlp/trainer/trainer_callback.py @@ -30,6 +30,11 @@ from .trainer_utils import IntervalStrategy, has_length from .training_args import TrainingArguments +import paddle +import paddle.distributed as dist +from paddle.distributed.fleet import fleet + +from paddlenlp.transformers.moe_gate import PretrainedMoEGate __all__ = [ "TrainerState", @@ -40,6 +45,7 @@ "ProgressCallback", "PrinterCallback", "EarlyStoppingCallback", + "MoECorrectionBiasAdjustCallback", ] @@ -609,3 +615,62 @@ def on_evaluate(self, args, state, control, metrics, **kwargs): self.check_metric_value(args, state, control, metric_value) if self.early_stopping_patience_counter >= self.early_stopping_patience: control.should_training_stop = True + + +class MoECorrectionBiasAdjustCallback(TrainerCallback): + """used for moe aux loss free 
balance""" + + def __init__(self, lr=0.001, use_mp=False): + super().__init__() + self.update_lr = lr + self.use_mp = use_mp + + def on_optimizer_end(self, args, state, control, **kwargs): + model = kwargs["model"] + + biases = [] + usages = [] + + def get_stat(layer): + if isinstance(layer, PretrainedMoEGate) and layer.topk_method == "noaux_tc": + biases.append(layer.e_score_correction_bias) + usages.append(layer.expert_usage) + + model.apply(get_stat) + + if not usages: + return + usages_tensor = paddle.stack(usages, 0) # [num_layers, num_local_experts] + if not hasattr(fleet, "_hcg"): + dist.all_reduce(usages_tensor) + return + + hcg = fleet.get_hybrid_communicate_group() + mp_group = hcg.get_model_parallel_group() + dp_group = hcg.get_data_parallel_group() + sd_group = hcg.get_sharding_parallel_group() + + if self.use_mp and mp_group.nranks > 1: + dist.all_reduce(usages_tensor, group=mp_group) + if dp_group.nranks > 1: + dist.all_reduce(usages_tensor, group=dp_group) + if sd_group.nranks > 1: + dist.all_reduce(usages_tensor, group=sd_group) + + usages_mean = usages_tensor.mean(-1, keepdim=True) + update = paddle.sign(usages_mean - usages_tensor) * self.update_lr + update = update.astype(paddle.float32) + update_list = list(update) + + # print('on_optimizer_end bias:', [bias.tolist() for bias in biases]) + # print('on_optimizer_end usage:', usages_tensor.tolist()) + # print('on_optimizer_end update:', update.tolist()) + + def update_bias(layer): + if isinstance(layer, PretrainedMoEGate) and layer.topk_method == "noaux_tc": + with paddle.no_grad(): + if not layer.weight.stop_gradient: + biases.pop(0).add_(update_list.pop(0)) + usages.pop(0).zero_() + + model.apply(update_bias) \ No newline at end of file diff --git a/paddlenlp/transformers/deepseek_v2/configuration.py b/paddlenlp/transformers/deepseek_v2/configuration.py index d21afc20780f..28422719b452 100644 --- a/paddlenlp/transformers/deepseek_v2/configuration.py +++ b/paddlenlp/transformers/deepseek_v2/configuration.py @@ -160,7 +160,7 @@ def __init__( first_k_dense_replace=0, norm_topk_prob=False, scoring_func="softmax", - aux_loss_alpha=0.001, + aux_loss_alpha=0.0001, seq_aux=True, hidden_act="silu", max_position_embeddings=2048, diff --git a/paddlenlp/transformers/deepseek_v2/modeling.py b/paddlenlp/transformers/deepseek_v2/modeling.py index c824871c9955..b301ff5da008 100644 --- a/paddlenlp/transformers/deepseek_v2/modeling.py +++ b/paddlenlp/transformers/deepseek_v2/modeling.py @@ -706,6 +706,12 @@ def __init__(self, config, num_experts, expert_hidden_size, **kwargs): default_initializer=nn.initializer.Constant(0.0), ) self.e_score_correction_bias.is_distributed = True + self.e_score_correction_bias.stop_gradient = True + self.expert_usage = paddle.zeros( + shape=[num_experts], + dtype=paddle.int64, + ) + self.expert_usage.stop_gradient = True self.using_flex_token = config.using_flex_token @@ -730,6 +736,8 @@ def forward(self, hidden_states): if self.using_flex_token: scores, routing_map, exp_counts, l_aux, l_zloss = self.topkgating_nodrop(scores) + with paddle.no_grad(): + self.expert_usage += exp_counts return scores, routing_map, l_aux, l_zloss capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss = self.topkgating(scores) @@ -1519,6 +1527,41 @@ def __init__( self.hnorm = DeepseekV2RMSNorm(config) self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size) + def subbatch_recompute_forward( + self, + hidden_states: paddle.Tensor, + nextn_hidden_state: paddle.Tensor, + position_ids: 
Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = False, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + use_cache: Optional[bool] = False, + attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, + **kwargs, + ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]: + hidden_states = self.hnorm(hidden_states) + nextn_hidden_state = self.enorm(nextn_hidden_state) + + hidden_states = self.eh_proj(paddle.concat([nextn_hidden_state, hidden_states], axis=-1)) + + layer_outputs = super(DeepseekV2MTPLayer, self).subbatch_recompute_forward( + hidden_states, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + attn_mask_startend_row_indices, + **kwargs, + ) + + if type(layer_outputs) is tuple: + hidden_states = layer_outputs[0] + else: + hidden_states = layer_outputs + + return hidden_states + def forward( self, hidden_states: paddle.Tensor, diff --git a/paddlenlp/transformers/deepseek_v2/modeling_pp.py b/paddlenlp/transformers/deepseek_v2/modeling_pp.py index 8e53c39d4c33..6aa5cfc3a6fa 100644 --- a/paddlenlp/transformers/deepseek_v2/modeling_pp.py +++ b/paddlenlp/transformers/deepseek_v2/modeling_pp.py @@ -290,7 +290,16 @@ def forward(self, args): hidden_states = hidden_states_main_model for depth in range(self.config.num_nextn_predict_layers): inputs_embeds_cur_depth = inputs_embeds_cur_depth_list[depth] - if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient: + + moelayer_use_subbatch_recompute = self.config.moe_subbatch_token_num > 0 + if moelayer_use_subbatch_recompute: + hidden_states = super().subbatch_recompute_forward( + hidden_states, + position_ids=position_ids, + attention_mask=attention_mask, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + ) + elif self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient: if attention_mask is not None or attn_mask_startend_row_indices is not None: hidden_states = recompute( super().forward, diff --git a/paddlenlp/transformers/moe_gate.py b/paddlenlp/transformers/moe_gate.py index 4a526feb6acb..d498c0dbfd2c 100644 --- a/paddlenlp/transformers/moe_gate.py +++ b/paddlenlp/transformers/moe_gate.py @@ -301,7 +301,7 @@ def _topk_noaux_tc( assert n_experts % n_group == 0, "n_experts must be divisible by n_groups" assert self.e_score_correction_bias is not None, "e_score_correction_bias is None" - scores_for_choice = scores.reshape([bsz_seq_len, -1]) + self.e_score_correction_bias.unsqueeze(0) + scores_for_choice = scores.reshape([bsz_seq_len, -1]) + self.e_score_correction_bias.detach().unsqueeze(0) group_scores = ( scores_for_choice.reshape([bsz_seq_len, self.n_group, -1]).topk(2, axis=-1)[0].sum(axis=-1) ) # fmt:skip [n, n_group] diff --git a/paddlenlp/trl/model_config.py b/paddlenlp/trl/model_config.py index 8a28bf4860c1..62e8f49d4ffb 100644 --- a/paddlenlp/trl/model_config.py +++ b/paddlenlp/trl/model_config.py @@ -150,3 +150,4 @@ class ModelConfig: moe_subbatch_token_num: int = field( default=0, metadata={"help": "moelayer subbatch token num, The smaller the value, the smaller the peak memory"} ) + aux_loss_alpha: float = field(default=0.0001, metadata={"help": "aux_loss_alpha"}) From 6e67781862460d8680670538fb137f1ccb1fa623 Mon Sep 17 00:00:00 2001 From: zhengzhonghui Date: Fri, 5 Sep 2025 17:35:18 +0800 Subject: [PATCH 13/14] fix ep grad bug (#11072) --- llm/config/deepseek-v3/sft_argument.json | 2 +- llm/run_finetune.py | 8 
+++- .../moe_hybrid_parallel_optimizer.py | 9 +---- paddlenlp/trainer/trainer.py | 15 ------- paddlenlp/trainer/trainer_callback.py | 40 ++++++++++++++++++- 5 files changed, 47 insertions(+), 27 deletions(-) diff --git a/llm/config/deepseek-v3/sft_argument.json b/llm/config/deepseek-v3/sft_argument.json index edc8452fe09e..7e713dd4b355 100644 --- a/llm/config/deepseek-v3/sft_argument.json +++ b/llm/config/deepseek-v3/sft_argument.json @@ -8,7 +8,7 @@ "per_device_eval_batch_size": 1, "eval_accumulation_steps": 1, "max_steps": 100, - "max_grad_norm": 0, + "max_grad_norm": 1.0, "amp_master_grad": true, "num_train_epochs": 1, "learning_rate": 2.2e-05, diff --git a/llm/run_finetune.py b/llm/run_finetune.py index de40634ef379..36e7b221729b 100644 --- a/llm/run_finetune.py +++ b/llm/run_finetune.py @@ -46,7 +46,7 @@ ReFTModel, intervention_mapping, ) -from paddlenlp.trainer import PdArgumentParser, get_last_checkpoint, set_seed, MoECorrectionBiasAdjustCallback +from paddlenlp.trainer import PdArgumentParser, get_last_checkpoint, set_seed, MoECorrectionBiasAdjustCallback, MoeExpertsGradScaleCallback from paddlenlp.trainer.trainer_callback import TrainerState from paddlenlp.transformers import ( AutoConfig, @@ -474,7 +474,11 @@ def compute_metrics_do_generation(eval_preds): callbacks += [ZeroPaddingIterDatasetCallback()] if getattr(model_config, "topk_method", None) == "noaux_tc": - callbacks += [MoECorrectionBiasAdjustCallback(lr=0)] + # deepseek_v3 finetune do not update the bias, so set lr to 0.0 + callbacks += [MoECorrectionBiasAdjustCallback(lr=0.0)] + + if training_args.use_expert_parallel: + callbacks += [MoeExpertsGradScaleCallback(training_args)] print("callbacks:", callbacks, flush=True) trainer = SFTTrainer( diff --git a/paddlenlp/optimizers/moe_hybrid_parallel_optimizer.py b/paddlenlp/optimizers/moe_hybrid_parallel_optimizer.py index b11263fe90e6..479aa9a44c05 100644 --- a/paddlenlp/optimizers/moe_hybrid_parallel_optimizer.py +++ b/paddlenlp/optimizers/moe_hybrid_parallel_optimizer.py @@ -179,7 +179,7 @@ def _dygraph_clip(self, params_grads): ) is_moe_param = getattr(p, "is_moe_param", False) - print(f"p.name:{p.name}, is_moe_param:{is_moe_param}") + if is_moe_param: assert 0 if not_shared_enable: @@ -350,13 +350,6 @@ def __getattr__(self, item): return getattr(self._clip, item) def __call__(self, params_grads): - print("==== zyc debug in moe_hybrid_parallel_optimizer.py ====") - for p, g in params_grads: - has_moe_attr = hasattr(p, "is_moe_param") - is_moe_param = False - if has_moe_attr: - is_moe_param = p.is_moe_param - print(f"p.name:{p.name}, has_moe_attr:{has_moe_attr}, is_moe_param:{is_moe_param}") return self._dygraph_clip(params_grads) diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index db23594b7bc1..c44eb5bf23c7 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -1323,21 +1323,6 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg): elif p.grad is not None: p.grad.scale_(1.0 / self.args.gradient_accumulation_steps) - if os.environ.get("FIX_EP_GRAD", None): - param_count = 0 - for p in model._layers.parameters(): - if hasattr(p, "is_moe_param") and p.is_moe_param: - with paddle.no_grad(): - if hasattr(p, "main_grad") and p.main_grad is not None: - # print("main grad scale 1/ep") - p.main_grad.scale_(1.0 / self.args.expert_parallel_degree) - param_count += 1 - elif p.grad is not None: - # print("grad scale 1/ep") - p.grad.scale_(1.0 / self.args.expert_parallel_degree) - param_count += 1 - print("fix ep grad 
count:{}".format(param_count), flush=True)
-
                 # Optimizer step
                 self.callback_handler.on_optimizer_begin(
                     args, self.state, self.control, scaler=self.scaler if self.do_grad_scaling else None
diff --git a/paddlenlp/trainer/trainer_callback.py b/paddlenlp/trainer/trainer_callback.py
index 8f9beaab813e..e4394bfde2ea 100644
--- a/paddlenlp/trainer/trainer_callback.py
+++ b/paddlenlp/trainer/trainer_callback.py
@@ -46,6 +46,7 @@
     "PrinterCallback",
     "EarlyStoppingCallback",
     "MoECorrectionBiasAdjustCallback",
+    "MoeExpertsGradScaleCallback",
 ]
@@ -673,4 +674,41 @@ def update_bias(layer):
                         biases.pop(0).add_(update_list.pop(0))
                         usages.pop(0).zero_()
-        model.apply(update_bias)
\ No newline at end of file
+        model.apply(update_bias)
+
+class MoeExpertsGradScaleCallback(TrainerCallback):
+    """
+    This hook corrects expert-parameter gradients that would otherwise be scaled up by a factor of N.
+    """
+
+    def __init__(self, args):
+        """Compute the factor used to rescale expert gradients under expert parallelism.
+
+        Args:
+            args (TrainingArguments): training arguments providing the parallel degrees.
+        """
+        if not args.use_expert_parallel:
+            raise ValueError("This callback should be used with expert parallel")
+        self.expert_gradient_scaling_factor = 1.0  # default so the attribute exists when expert_parallel_degree == 1
+        if args.expert_parallel_degree > 1:
+            self.expert_gradient_scaling_factor = 1.0 / args.expert_parallel_degree
+        if args.tensor_parallel_degree > 1:
+            self.expert_gradient_scaling_factor *= args.tensor_parallel_degree
+        logger.info(
+            f"EP-MoE is used, expert gradient scaling factor is set to {self.expert_gradient_scaling_factor}"
+        )
+
+    def on_optimizer_begin(self, args, state, control, **kwargs):
+        model = kwargs["model"]
+        param_count = 0
+        for p in model.parameters():
+            if not getattr(p, "no_sync", False):
+                continue
+            if hasattr(p, "is_moe_param") and p.is_moe_param:
+                with paddle.no_grad():
+                    if hasattr(p, "main_grad") and p.main_grad is not None:
+                        p.main_grad.scale_(self.expert_gradient_scaling_factor)
+                        param_count += 1
+                    elif p.grad is not None:
+                        p.grad.scale_(self.expert_gradient_scaling_factor)
+                        param_count += 1
+        logger.info("correct ep grad count:{}".format(param_count))
\ No newline at end of file
From 7adc457f19e413822a4adcb995b704b81b7f02ca Mon Sep 17 00:00:00 2001
From: zhengzhonghui
Date: Thu, 11 Sep 2025 14:56:32 +0800
Subject: [PATCH 14/14] [Bug Fix]reduce grad of kv_a_proj_with_mqa and q_a_proj to maintain correctness (#11085)

---
 llm/run_finetune.py                      |  1 +
 .../transformers/deepseek_v2/modeling.py | 20 +++++++++++++++++++
 2 files changed, 21 insertions(+)

diff --git a/llm/run_finetune.py b/llm/run_finetune.py
index 36e7b221729b..914f79ec939a 100644
--- a/llm/run_finetune.py
+++ b/llm/run_finetune.py
@@ -267,6 +267,11 @@ def main():
     model_config.using_fake_gate = model_args.using_fake_gate
     model_config.moe_subbatch_token_num = model_args.moe_subbatch_token_num
     model_config.aux_loss_alpha = model_args.aux_loss_alpha
+    model_config.gradient_accumulation_steps = training_args.gradient_accumulation_steps
     logger.info(f"Final model config: {model_config}")
     logger.info("Creating model")
diff --git a/paddlenlp/transformers/deepseek_v2/modeling.py b/paddlenlp/transformers/deepseek_v2/modeling.py
index b301ff5da008..375e4a3c8885 100644
--- a/paddlenlp/transformers/deepseek_v2/modeling.py
+++ b/paddlenlp/transformers/deepseek_v2/modeling.py
@@ -1002,6 +1002,26 @@ def linear_dtype_gaurd():
         # fmt: on
+        if self.config.tensor_parallel_degree > 1 and self.config.sequence_parallel:
+            def grad_allreduce_hook(param, accumulation_steps):
+                hcg = fleet.get_hybrid_communicate_group()
+                pg = hcg.get_model_parallel_group().process_group
+                step = [0]
+
+                @paddle.autograd.no_grad()
+                def __impl__():
+                    step[0] += 1
+                    if (step[0] % accumulation_steps) == 0:
+                        if hasattr(param, "main_grad"):
+                            pg.allreduce(param.main_grad).wait()
+                        else:
+                            pg.allreduce(param.grad).wait()
+
+                return __impl__
+            # kv_a_proj_with_mqa and q_a_proj grads need to be reduced across mp
+            self.kv_a_proj_with_mqa.weight._register_backward_hook(grad_allreduce_hook(self.kv_a_proj_with_mqa.weight, accumulation_steps=config.gradient_accumulation_steps))
+            self.q_a_proj.weight._register_backward_hook(grad_allreduce_hook(self.q_a_proj.weight, accumulation_steps=config.gradient_accumulation_steps))
+
         self._init_rope()

         self.softmax_scale = self.q_head_dim ** (-0.5)
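
The backward hook registered above on kv_a_proj_with_mqa.weight and q_a_proj.weight fires once per backward pass, so the closure keeps a step counter and only issues the model-parallel all-reduce on the last micro-step of each gradient-accumulation window. Below is a minimal, standalone sketch of that counting pattern; make_grad_hook and allreduce_fn are illustrative names that do not appear in the patch, and a plain Python callable stands in for pg.allreduce(...).wait() so the sketch runs without any distributed setup.

    # Sketch only: mirrors the "communicate every N-th call" logic of grad_allreduce_hook above.
    def make_grad_hook(param_name, accumulation_steps, allreduce_fn):
        step = [0]  # mutable cell so the closure can count backward passes

        def hook():
            step[0] += 1
            # fire the collective only on the last micro-step of each accumulation window
            if step[0] % accumulation_steps == 0:
                allreduce_fn(param_name)

        return hook

    if __name__ == "__main__":
        calls = []
        hook = make_grad_hook("q_a_proj.weight", accumulation_steps=4, allreduce_fn=calls.append)
        for _ in range(8):  # eight micro-batches -> two accumulation windows
            hook()
        assert calls == ["q_a_proj.weight", "q_a_proj.weight"]  # exactly one all-reduce per window

Deferring the collective this way keeps the extra communication for these two projection weights to one all-reduce per optimizer step instead of one per micro-batch.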