Thanks to visit codestin.com
Credit goes to github.com

Skip to content

[MoE] Fix misuse of num_experts as expert parallel group size (ep_size)#7537

Closed
Flakes342 wants to merge 0 commit into
deepspeedai:masterfrom
Flakes342:master
Closed

[MoE] Fix misuse of num_experts as expert parallel group size (ep_size)#7537
Flakes342 wants to merge 0 commit into
deepspeedai:masterfrom
Flakes342:master

Conversation

@Flakes342
Copy link
Copy Markdown
Contributor

Fixes #7535

Description

This PR fixes a bug in inference/engine.py where num_experts (moe_experts) was incorrectly passed as the expert parallel group size (ep_size) when creating expert parallel groups.

Currently:

if moe and dist.get_world_size() > 1:
    self._create_ep_parallel_group(config.moe.moe_experts)

This causes invalid behavior whenever num_experts > world_size, because _create_ep_parallel_group expects a group size, not the total number of experts as pointed out by @Arnoochka

Root Cause

num_experts = number of experts inside the MoE layer.

ep_size = how many GPUs to group together for expert parallelism.

These were mixed up in the code.

##Fix

Replaced the incorrect call with the proper ep_size argument:

if moe and dist.get_world_size() > 1:
    self._create_ep_parallel_group(config.moe.ep_size)

Additionally, added a safety check in _create_ep_parallel_group to catch invalid configurations:

num_ep_groups = dist.get_world_size() // moe_ep_size
if num_ep_groups == 0:
    raise ValueError(
        f"Invalid ep_size={moe_ep_size} for world_size={dist.get_world_size()}"
    )

Backward compatibility

  • If a user was already running with ep_size >= num_experts, the old code worked fine which would still work fine.
  • Only the previously broken case (num_experts > world_size) now works correctly.

@Flakes342 Flakes342 changed the title [MoE] Fixed misuse of num_experts as expert parallel group size (ep_size) [MoE] Fix misuse of num_experts as expert parallel group size (ep_size) Sep 3, 2025
Copy link
Copy Markdown
Collaborator

@tohtana tohtana left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Flakes342 Great catch, thank you for the fix!

@tohtana
Copy link
Copy Markdown
Collaborator

tohtana commented Sep 9, 2025

@Flakes342 nv-mii raised an error, but it is not related to this PR. Let's merge once you fix the formatting and DCO.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[BUG] InferenceEngine._create_ep_parallel_group uses num_experts instead of ep_size, causing incorrect behavior

2 participants