Looks awesome!! Thanks for adding me as a reviewer, I will look at it today!
nastya236 left a comment:
This looks awesome! This PR very nicely:
- Separates the algorithm implementation from the group classes by introducing `MeshImpl` and `RingImpl` objects
- With the above, we can make an adaptive decision on which `all_reduce` to pick: if the number of nodes is more than 2 and the size of the message is big enough, we pick the bidirectional ring `all_reduce` (which is better for larger messages since each node only sends/receives 1/N of the data per step)
- So now for the mesh type of communication we allocate buffers for both a mesh and a ring and init both `MeshImpl` and `RingImpl`
I left a couple of comments, mostly for my personal understanding.
for (int j = 0; j < size_; j++) {
  buffers_.emplace_back(FRAME_SIZE * (1 << k));
}
// Ring buffers (1 for each direction)
Nit, and just for me to understand better: in case size_ = 2, mesh and ring should be identical and we would never fall into the ring all_reduce. I am wondering if it makes sense to allocate ring buffers and init RingImpl only if size_ > 2? But probably the extra complexity does not justify the small memory saving..
Yeah, that isn't a bad idea. The memory is 4MB, so not important, but perhaps there are other reasons not to register unnecessary buffers.
We might still want to use a ring even with 2 nodes because the process is different: the mesh sends all the data at once and reduces all of it, while the ring sends half, reduces half, and gathers half. If the reduction is very expensive, then the ring will be faster.
}
encoder.dispatch([in_ptr, out_ptr, size, this, reduce_op]() {
  if (size_ > 2 &&
      ((std::is_same_v<T, bfloat16_t> && size > 65536) ||
Interesting! Just for my understanding: why is the cutover for bfloat16 at 65,536 elements = 128KB, but for other 2-byte types (like float16) at 8MB / 2 bytes = 4M elements = 8MB?
Bfloat summation is slow on the M3s (we should put in a TODO to make it faster), so the ring (which sums less) becomes more efficient sooner. 8MB is where the ring surpasses the mesh with 4 nodes.
      }
    }
  }
  mesh_.all_gather(in_ptr, out_ptr, n_bytes);
Also just for me to understand better: I am wondering whether it makes sense to have a similar conditional all_gather here, as for all_reduce, if the message is large enough? What do you think?
Good point. It doesn't, because there is no bandwidth benefit for all_gather as there is for all_reduce. All_reduce benefits because the summations are shared among nodes, so we never send all the data.
So, tl;dr: all_gather will be quite a bit slower via the ring than via the mesh.
This PR refactors the communication implementations outside the groups so we can use the more efficient ring reduce in the mesh group for large sizes.
Comparison of all reduce performance on 4 M3 Ultras. [benchmark figure]