Unified Optimizers and Transposed LM Head#200
Conversation
|
Wow big upgrade on clarity of code! I'll retime later after merging in the open triton kernel one, which shouldn't affect loss. On I had originally picked 2/3 to split the embed because it prevented an extra compile. I see that's no longer an issue. That's interesting you found -50 worked better. Maybe there is more opportunity here. Some options, maybe you already explored:
On Blending / Switching Muon and Adam- I had some ideas a while back to relax Muon at the end of training. The challenge is that the SGD update is terrible because it gets dominated by a couple massive activations coming from the residual stream. Adam relies on having accurate EMA buffers. Even though we could compute EMA buffers in the background, they would be inaccurate since they are coming from the Muon training trajectory. Would need some way to approach off-policy optimizer momentum learning, or use some rough heuristics. Perhaps there is some existing literature on transitioning training optimizers mid-training. On the topic of taking effective steps at the end of training- I explored taking a half-trained model and observing how quickly I could curve-fit a single batch if stepping on it repeatedly, with Muon vs SGD. The main problem that SGD had was that its updates to layer n would completely change the activations entering layer n+1, which made the gradients on layer n+1 inaccurate. An interesting question in gradient descent of loss = F(A,B) is how updating parameter A impacts the accuracy of the gradient estimate for parameter B. For instance, I found that the value projection was the most sensitive part. Freezing the value projection let me crank up the LR on Muon and descend much faster when curve fitting a single batch. However at the time I didn't find a way to translate this to better validation loss. Unless something gives a big boost I tend to drop it because its easy to get distracted by noise. |
Good call, I'll remove that. I think there may be a couple other spots like that, I'll review.
Yeah, my rationale there was that, if you're going to zero the optimizer state, doing it right at the batch size increase seems like the worst place to do it. So 50 steps was to give it time to adjust, while also minimizing the penalty of starting the extra communication of embedding gradients early. If you take away the "zeroing the optimizer state" problem, I don't know that untying 50 steps back is necessarily any better, loss-wise.
This is a very sane idea, I'm trying it out now. Zeroing the state is the only thing that changed about the math compared to the baseline, and this would address that. The '50 steps back' approach costs 50ms, so I have to imagine this would be cheaper. Great thoughts on the optimizer mixing, those sound like some interesting experiments.
That is some valuable wisdom that I am going to try and hang on to. :) |
is this because comms is still dominating compute in terms of time? |
|
@chrisjmccormick please let me know when you want to consider this one finalized and I’ll time it and explore some other random ideas when I have the GPUs up. |
|
@varunneal Yeah, we're definitely still very comms-bound. @ClassicLarry Thanks Larry. The all_reduce on the optimizer state was the right answer--faster and it brought the loss back down. I re-ran the timings; I'm sitting down now to update the commit with new logs and updated README. I also played with re-adding attention layer 6 to the model using one of the spare MLP matrices in the parameter bank--didn't get a record out of it, but I'll share what I tried. |
|
A little more detail on the success of the
Probably didn't need 10 runs, that's just how many iterations I put in |
|
great, I'll time once commit updated. |
|
Will merge later at 105.9-1.0 = 104.9s = 1.748 minutes. Super cool write up. Image links may need to be repointed. |
|
Thanks! Yeah, I had the images working in the README, but was struggling to figure out how to reference them properly in the pull request. Just got it. |
|
Just a fyi you have |
Hmm, that's not good--it's not in the log file. Digging into it now. Update: Ok, I think we're good. Only difference between I'll remove the duplicate code from |
|
ah looks like when I resolved the merge conflict it brought back in those kernel functions on accident. |
|
Ah, ok, phew. :) Unrelated, but I'm wondering if we need to add back in more warmup steps. I was diffing log files and noticed something:
Something to try next time I have a machine going. |
The data loader has logic to wait on the full first shard around step 6. Maybe things are fast enough now that this needs to get pushed out to step 10 or so. On warmup, would be nice if we can make it faster, would speed up testing a lot. Would need deep dive on where compile bottleneck is. |
|
On Leveraging the Spare MLP Weights- Something that might be more promising is to take the 7687684*2 = ~5 million weights and give each token an ~8 parameter embedding, that gets used somewhere. The results from adding an attn layer are interesting. My hunch here is that we will end up at a flat 10 layers as more efficient operations keep getting added around the standard attn and mlps. |
Unified, Zero-Copy Optimizer and Transposed LM Head
The main contributions of this record are:
NorMuonAndAdam, oriented towards per-parameter rather than per-optimizer configuration.I did my development against record 59. After integrating record 60 (logit softcap kernel) and re-timing I was able to remove another 5 steps (-10 steps total):
Using record 59 (the MLP kernel) as baseline and reducing 5 steps:
Notes:
triton_kernels.pychange, and added a line at the top of train_gpt.py to pull the kernel code into the log file.1. Unified, Zero-Copy Optimizer
1.1. Stacking vs. Sharding
One of the core changes in this PR is the consolidation of the attention and mlp weight matrices into two "parameter banks", which can be sharded across the GPUs, eliminating the need for stacking and unstacking during the optimizer window.
Current Approach
The approach to communicating the Muon parameters has been to:
This approach has a few drawbacks:
This illustration is out-of-date, so ignore what it says about comms and compute delays, but it highlights the memcpys I'm referring to, and shows how time-consuming they are.
Parameter Banks
This PR utilitizes a "parameter bank" strategy that was introduced in a previous PR for the attention gates and ve gates, and I've extended it to the attention and mlp projection matrices.
The QKVO weight matrices from all 10 attention modules are stored as a single parameter of shape
(10, 4×768, 768). This can be 'reshaped' without memory movement to(40, 768, 768), and then sharded evenly across the 8 GPUs (so that each gets five attention matrices).The MLP weight matrices (11 modules, 2 matrices per module) require padding to be sharded evenly. Their parameter bank is
(12, 2, 3072), and this is "reshaped" to(24, 3072)so that each GPU receives 3 matrices to process.Note that means we're paying the communication and optimizer overhead for an entire additional MLP, so we may be able to find a use for those weights. There are some notes on what I tried under the Ideas section.
Now, both passes have balanced workloads, and:
(5, 768, 768)each.1.2. Individual Parameters
An additional benefit of the parameter banks is that it helps eliminate the need for the concept of "parameter groups".
Optimizer groups exist to handle the fact that we typically have multiple instances of the same functional matrix--e.g., we have 10 layers worth of attention weights--which all share settings like learning rate.
Groups add an additional layer of abstraction and complexity to the optimizer code. With the parameter banks, our model no longer has per-layer parameters, and instead has "only" the 13 distinct parameters in the table below (defined in the
TrainingManagerinit function).The one remaining set of components that could arguably still be grouped are the value embedding tables. Instead, I've individually labeled them as 've0', 've1', and 've2' to treat them as distinct parameters.
1.3. Replacing Hooks with Explicit Ordering
The most substantial benefit we were getting from registering gradient accumulation hooks was to overlap communication with the lengthy lm_head gradient accumulation kernel. With that removed (as described further down), there's less benefit to overlapping the smaller kernels, and I chose to remove the feature for simplicity and added flexibility in ordering the communication and workload.
The order of the parameters in the above
param_tabledictates the order in which we issue the scatter or all_reduce operations, and then a separate list specifies the order that we work on them:This ordering was chosen by Claude Opus, and the comments are its' rationale. This strategy outperforms the prior hook approach and appears to be responsible for the reduction in validation loss that allowed me to remove 5 steps.
I think the rationale for the work order makes sense, but note how its roughly the opposite of the communication order--that part seems counterintuitive to me.
1.4. Debugging Comms
I've been having issues lately with getting good trace data. The profiler is somehow slowing things down much more than usual, and the communication data in particular doesn't look right. This has made it hard to arrive on a deliberate strategy and validate it.
In other words, I can't rationalize / confirm that the current communication and work strategy is optimal. It would even make sense for there to exist a hook-based solution that outperforms this one. This PR is very much "for whatever reason, this works better than what we had".
Just for future reference, I've printed out the order that the hooks were firing in, which tells us what order gradient accumulation happens in, and which in turn may have some value in understanding communication behavior.
Note: To include the NorMuon parameters as well, I added hooks to them and manually set the gradient accumulation steps to 2 so that they would trigger.
Here's the order they're called in:
Not very intuitive!
2. Transposed LM-Head
Another long-standing communications issue has been the slow gradient accumulation kernel for the LM head.
The below illustrations are from an older baseline, but highlight the problem well.
In the Adam plus Muon steps:
In the Muon-only steps:
I've resolved this by transposing the memory layout of the LM head so that it now has shape (768, 50304), and updated the FP8 implementation to support this. From test runs and looking at single GPU trace files, it appears to have identical forward/backward performance to the current implementation.
Here's what was behind the slow kernel, and how transposing addresses it.
2.1. Mismatched Memory Layouts
Currently, the backward pass produces gradients for the LM head with layout (768, 50304), but the head is stored as (50304, 768).
I've learned that, while CUDA can often hide the cost of mismatched memory layouts when performing matrix multiplications, it's not possible to hide this problem for element-wise operations.
This is because values can be read once and used multiple times in matmul, but for element-wise operations each value is read and used exactly once. When the memory layouts don't match, it has to stride through one of the matrices, resulting in kernels that run substantially slower.
This is the reason for the slowness of the current gradient accumulation kernel, which takes maybe 4x longer than the other embedding tables despite being the same size and shape.
Transposing the head brings the weights and gradients into alignment. This has the biggest impact for the NorMuon-only steps, where we've never been able to overlap any communication with the slower kernel.
2.2. Tied Embeddings
This change was more clearly beneficial back before we re-tied the embeddings. The tied embeddings create a problem--the input embeddings still need to have their current shape of
(50304, 768)so that the embeddings are laid out consecutively in memory and can be selected efficiently.To resolve this, I "untied the embeddings" in the sense that there are now always two separate matrices involved, but they're kept effectively tied by managing their gradients and updates manually.
This means manually combining the LM head and input embedding gradients before scattering them, and then copying the received updates into both the
embedandlm_headweights.This actually brings us back where we started--expensive elementwise operations on two misaligned matrices--but now it's inside of the optimizer step code where there's plenty of compute available and we have the ability to overlap it with comms.
2.3. Optimizer State After Untying
An additional challenge was setting up the optimizer state for the
embedmatrix at the untie point. Because of the difference in orientation, each GPU handles a(96, 50304)shard oflm_headand a(6288, 768)shard ofembed. That means we can't simply copy the Adam momentum buffers from one to the other.Thank you to @ClassicLarry for suggesting a more straightforward solution to this than what I had been trying. We simply perform an explicit communication step to make sure every GPU has the correct momentum buffer state for
embed. This preserves the behavior and accuracy of the existing implementation.Ideas
With the Muon memcpys gone, we have quite a lot of compute available during the optimizer window. Here are a few things I played with that might deserve exploring more, or might spark other ideas.
Interpolated Embedding Updates
Because we're now manually combining the gradient updates of the input embeddings and LM heads to tie them, we have the opportunity to "mix" them differently; i.e., we could scale the input embedding gradients up or down relative to the lm_head before combining them.
Blending / Switching Muon and Adam
Adam is remarkably fast, and I think we can afford to run both NorMuon and Adam on the MLP and Attention weights at the same time. I tried some experiments blending their updates over the course of training, but nothing looked promising.
What I'd still be interested to try, though, is switching the projection matrices from Muon to Adam towards the very end. Compared to Muon's orthogonalization constraint, I'm wondering if the per-parameter freedom of Adam's updates might allow it to settle into a better local optimum (or find one faster) at the end of training.
(Larry has some interesting notes on this idea in the comments of the PR)
Starting
forwardWithinstepIt's not a giant time savings, but inside of the optimizer window, while still waiting for the attention and MLP gathers to complete, we can start selecting the next input embeddings, normalizing them, applying the smear gate, and perhaps even running the first attention layer (since those weights arrive before the MLPs).
I started on this, but set it aside when I realized how complicated / impossible it would be when doing gradient accumulation steps.
I think the solution there would probably be to define a separate training loop that's only for the 8x gpu setup (with no gradient accumulation support), where we could apply this.
Leveraging the Spare MLP Weights
We are now paying the communication and optimizer overhead for another MLP's worth of weights.
One idea I tried for taking advantage of this was to use one of those spare matrices to serve as the QKVO weights to restore attention layer 6.
Because this matrix is optimized with the MLP weights, Muon treats it as one big matrix instead of four separate ones. Not ideal, but it still works.
It seemed promising, with some very low loss values:
[3.2740, 3.2722, 3.2733, 3.2758]However, it takes about 50 steps to break even, and the loss was far too high (row 4 below).
I tried both with and without the skip connection (which seems like it may be serving some of the role of the missing attention layer?):
Maybe with the right adjustments to the skip connection, or other attention attributes, it could work?