JACCL refactor and small update #3174

Merged: angeloskath merged 5 commits into main from jaccl on Feb 26, 2026

@angeloskath
Member

This PR moves the communication implementations out of the group classes so that the mesh group can use the more efficient ring all-reduce for large message sizes.

Comparison of all-reduce performance on 4 M3 Ultras:
[Figure: jaccl-mesh-update]

@nastya236
Collaborator

Looks awesome!! Thanks for adding me as a reviewer, I will look at it today!

@nastya236 left a comment

This looks awesome! This PR very nicely:

  • Separates algorithm implementation from the group classes by introducing MeshImpl and RingImpl objects
  • With the above, we can adaptively choose which all_reduce to use: if there are more than 2 nodes and the message is large enough, we pick the bidirectional ring all_reduce (which is better for larger messages, since each node only sends/receives 1/N of the data per step)
  • So for the mesh type of communication we now allocate buffers for both the mesh and the ring, and initialize both MeshImpl and RingImpl

I left a couple of comments, mostly for my own understanding.

for (int j = 0; j < size_; j++) {
  buffers_.emplace_back(FRAME_SIZE * (1 << k));
}
// Ring buffers (1 for each direction)

@nastya236
Collaborator

Nit, and just for me to understand better: when size_ = 2, mesh and ring should be identical and we would never fall into the ring all_reduce. I am wondering if it makes sense to allocate the ring buffers and init RingImpl only if size_ > 2? But the extra complexity probably does not justify the small memory saving.

@angeloskath
Member Author

Yeah that isn't a bad idea. The memory is 4MB so not important but perhaps there are other reasons to not register unnecessary buffers.

We might still want to use a ring even with 2 nodes because the process is different. The mesh sends all at once and reduces all. The ring sends half, reduces half and gathers half. If the reduction is very expensive then the ring will be faster.
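
To make the two-node argument concrete, here is a rough sketch that counts how many elements each node has to reduce under both schemes. It assumes the textbook formulations (mesh: every node receives the full buffer from each peer and reduces all of it; ring: reduce-scatter followed by all-gather, so each node reduces one chunk of size/N per step), not the exact JACCL code paths. The function names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helpers for illustration only; not part of JACCL.

// Mesh all-reduce: each node receives the full buffer from every peer
// and reduces all of it locally.
std::size_t mesh_elements_reduced(std::size_t n_nodes, std::size_t size) {
  assert(n_nodes >= 1);  // guard against unsigned underflow below
  return (n_nodes - 1) * size;
}

// Ring all-reduce (reduce-scatter + all-gather): each node reduces one
// chunk of size / n_nodes per reduce-scatter step, for n_nodes - 1 steps.
// Assumes size is divisible by n_nodes to keep the arithmetic exact.
std::size_t ring_elements_reduced(std::size_t n_nodes, std::size_t size) {
  assert(n_nodes >= 1);
  return (n_nodes - 1) * (size / n_nodes);
}
```

Under these assumptions, with 2 nodes the mesh reduces the whole buffer on each node while the ring reduces only half of it, which is why an expensive reduction could favor the ring even at size_ = 2.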

}
encoder.dispatch([in_ptr, out_ptr, size, this, reduce_op]() {
if (size_ > 2 &&
((std::is_same_v<T, bfloat16_t> && size > 65536) ||

@nastya236
Collaborator

Interesting! Just for my understanding: why is the cutover for bfloat16 at 65,536 elements (128KB), while for other 2-byte types (like float16) it is 8MB / 2 = 4M elements (8MB)?

@angeloskath
Member Author

Bfloat summation is slow on the M3s (we should put in a TODO to make it faster), so the ring (which sums less) is more efficient. 8MB is where the ring surpasses the mesh with 4 nodes.
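
The selection rule discussed in this thread could be sketched roughly as follows. This is a hypothetical reconstruction, not the merged code: the 65536-element bfloat16 bound and the size_ > 2 check come from the diff excerpt, the 8MB bound comes from the discussion above, and the function name, the enum, and the exact comparison used for the 8MB bound are assumptions.

```cpp
#include <cstddef>

// Hypothetical sketch; names are illustrative and not part of JACCL.
enum class AllReduceAlgo { Mesh, Ring };

// From the diff excerpt: bfloat16 switches to the ring above this many elements.
constexpr std::size_t kBf16RingMinElements = 65536;
// From the review discussion: other types switch around 8MB.
constexpr std::size_t kRingMinBytes = 8 * 1024 * 1024;

AllReduceAlgo pick_all_reduce(
    int n_nodes,
    std::size_t n_elements,
    std::size_t elem_bytes,
    bool is_bfloat16) {
  // With 2 nodes or fewer the mesh path is always used.
  if (n_nodes <= 2) {
    return AllReduceAlgo::Mesh;
  }
  // bfloat16 summation is slow, so prefer the ring (fewer sums) early.
  if (is_bfloat16 && n_elements > kBf16RingMinElements) {
    return AllReduceAlgo::Ring;
  }
  // Assumption: inclusive comparison at the 8MB crossover.
  if (n_elements * elem_bytes >= kRingMinBytes) {
    return AllReduceAlgo::Ring;
  }
  return AllReduceAlgo::Mesh;
}
```

For example, in this sketch a 4-node float16 all_reduce of 1M elements (2MB) stays on the mesh, while the same call with bfloat16 goes to the ring.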

}
}
}
mesh_.all_gather(in_ptr, out_ptr, n_bytes);

@nastya236
Collaborator

Also, just for me to understand better: would it make sense to have a similar conditional all_gather here, as for all_reduce, when the message is large enough? What do you think?

@angeloskath
Member Author

Good point. It doesn't, because there is no bandwidth benefit for all gather as there is for all reduce. All reduce benefits because the summations are shared among the nodes, so we never send all the data.

So tl;dr: all gather will be quite a bit slower via the ring than via the mesh.
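
A back-of-the-envelope model of this argument, assuming textbook unidirectional algorithms and ignoring latency and link topology: it counts the bytes each node receives and the number of sequential steps. The struct and function names are hypothetical, and `s` is the per-node buffer size in bytes, assumed divisible by the node count.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical traffic model for illustration only; not part of JACCL.
struct Traffic {
  std::size_t bytes_received;  // per node, over the whole collective
  std::size_t steps;           // sequential communication steps
};

// All-gather: every node must end up with all n buffers, so it receives
// (n - 1) * s bytes no matter which algorithm is used.
Traffic mesh_all_gather(std::size_t n, std::size_t s) {
  assert(n >= 1);
  return {(n - 1) * s, 1};
}

Traffic ring_all_gather(std::size_t n, std::size_t s) {
  assert(n >= 1);
  return {(n - 1) * s, n - 1};  // same volume, more sequential steps
}

// All-reduce on the mesh: receive every peer's full buffer, then reduce.
Traffic mesh_all_reduce(std::size_t n, std::size_t s) {
  assert(n >= 1);
  return {(n - 1) * s, 1};
}

// All-reduce on the ring: reduce-scatter then all-gather, each moving
// n - 1 chunks of s / n bytes per node.
Traffic ring_all_reduce(std::size_t n, std::size_t s) {
  assert(n >= 1);
  return {2 * (n - 1) * (s / n), 2 * (n - 1)};
}
```

In this model, with 4 nodes the ring halves the bytes received for all_reduce (1.5 * s versus 3 * s) but moves exactly the same volume as the mesh for all_gather while taking more steps, which matches the conclusion above.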

@angeloskath merged commit 5c4abd2 into main on Feb 26, 2026 (16 checks passed)
@angeloskath deleted the jaccl branch on February 26, 2026 21:56
2 participants