GAD-cell/muon-clip

Muon Optimizer 2.0

This repository presents an implementation of the Muon optimizer, enhanced with the QK-Clipping technique introduced in Kimi K2 and an improved Newton-Schulz orthogonalization.
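Muon orthogonalizes each 2D momentum matrix, replacing it with an approximation of U V^T from its SVD, using a few Newton-Schulz iterations (the ns_steps parameter below). The following is a minimal illustrative sketch of the standard quintic iteration with the coefficients used in the original Muon implementation. It is not this repository's code, the function name is hypothetical, and it does not implement the CANS variant.

```python
import torch

def newton_schulz(G: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    """Approximately orthogonalize a 2D matrix G (push its singular values toward 1)."""
    assert G.ndim == 2
    a, b, c = 3.4445, -4.7750, 2.0315  # quintic coefficients from the original Muon implementation
    X = G.float()
    transposed = X.size(0) > X.size(1)
    if transposed:  # work on the wide orientation so X @ X.T is the smaller Gram matrix
        X = X.T
    X = X / (X.norm() + eps)  # Frobenius normalization puts every singular value in [0, 1]
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X  # applies a*s + b*s^3 + c*s^5 to each singular value s
    if transposed:
        X = X.T
    return X.to(G.dtype)
```

Because these coefficients favor fast convergence over exactness, the singular values of the result end up near 1 (roughly in the 0.7 to 1.2 range for typical inputs) rather than exactly at 1.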

Key Features

  • QK-Clipping: Stabilizes training by clipping the attention logits of each head.
  • Corrected RMS: Rescales Muon's update RMS so that Muon and Adam can share a compatible learning rate.
  • Easy to use: Integrates seamlessly with existing Transformer and PyTorch architectures and is used like a regular PyTorch optimizer.
  • Scalability: Optimized for large-scale training and implemented for DDP training.
  • Efficient orthogonalization: Improves gradient orthogonalization via the CANS method, an improved Newton-Schulz iteration based on eigenvalue-interval estimation and Chebyshev polynomials. (Experimental)
  • Metrics logging: Use W&B or TensorBoard to monitor QK-Clipping.
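QK-Clipping, as described for Kimi K2, is applied after the optimizer step: for each head whose largest observed logit S_max exceeds the threshold τ, it computes γ = τ / S_max and rescales that head's query weights by γ^α and key weights by γ^(1 - α), so the head's logits shrink by exactly γ. The sketch below assumes that clipping_threshold and clipping_alpha play the roles of τ and α, uses standard multi-head attention without GQA, and takes nn.Linear-style weights of shape (num_heads * head_dim, hidden_size). The function qk_clip_ is hypothetical and not part of this package's API.

```python
import torch

@torch.no_grad()
def qk_clip_(w_q: torch.Tensor, w_k: torch.Tensor, max_logits: torch.Tensor,
             head_dim: int, threshold: float = 50.0, alpha: float = 0.5) -> torch.Tensor:
    """Hypothetical in-place per-head QK-Clip (assumes plain multi-head attention, no GQA).

    w_q, w_k:   projection weights of shape (num_heads * head_dim, hidden_size).
    max_logits: (num_heads,) largest pre-softmax logit observed for each head.
    Returns the per-head factor gamma (1.0 for heads at or below the threshold).
    """
    assert w_q.shape[0] == w_k.shape[0] == max_logits.numel() * head_dim
    over = max_logits > threshold
    # Only heads above the threshold are shrunk; the others keep gamma = 1.
    gamma = torch.where(over, threshold / max_logits, torch.ones_like(max_logits))
    # Expand one factor per head into one factor per output row of the projection.
    q_scale = gamma.pow(alpha).repeat_interleave(head_dim).unsqueeze(1)
    k_scale = gamma.pow(1.0 - alpha).repeat_interleave(head_dim).unsqueeze(1)
    w_q.mul_(q_scale)  # q . k is multiplied by gamma^alpha * gamma^(1 - alpha) = gamma
    w_k.mul_(k_scale)
    return gamma
```

Splitting γ between the query and key sides (α = 0.5 gives √γ to each) keeps both projections at comparable scales instead of shrinking only one of them.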

How to use

Here's a basic example:

from muon import MuonClip, MuonConfig
from transformers import AutoConfig, AutoModelForCausalLM

# The model config can also be a dict with at least the num_key_value_heads, num_attention_heads and head_dim keys
model_config = AutoConfig.from_pretrained("{hf_model}")
model = AutoModelForCausalLM.from_pretrained("{hf_model}")  # e.g. a Hugging Face causal LM

muon_config = MuonConfig(
    unified_lr=True,  # If True, use the same learning rate for both the Muon and Adam parts of the optimizer
    lr=1e-5,
    lr_muon=1e-4,  # Only used if unified_lr is False
    lr_adam=1e-4,

    muon_beta=0.95,
    muon_decay=0.0,
    ns_steps=5,  # Number of Newton-Schulz iterations; increase for more precise orthogonalization

    adam_betas=(0.9, 0.95),
    adam_decay=0.0,
    adam_eps=1e-10,

    enable_clipping=True,
    # For models with non-standard query/key projection names, change the values to the names your model uses
    clipping_layers_mapping={"q_proj": "q_proj", "k_proj": "k_proj"},
    clipping_threshold=50.0,
    clipping_alpha=0.5,

    log_max_logits=True,
    log_dir="./logs",
    cans_ortho=False,  # Experimental: use CANS orthogonalization (recommended to keep disabled for now)
    estimate_lower_bound=False,
)

optimizer = MuonClip(model, model_config, muon_config)

model.train()  # Call model.train() after creating the optimizer so that the hooks are registered correctly
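The unified_lr option relies on Muon and Adam updates having similar magnitudes. An exactly orthogonal A x B update has RMS 1/sqrt(max(A, B)), usually much smaller than an Adam update. One published correction (Moonshot AI's "Muon is Scalable for LLM Training") multiplies the orthogonalized update by 0.2 * sqrt(max(A, B)) so that its RMS is about 0.2. The sketch below shows that scaling; the helper names and the 0.2 target are assumptions, and this repository's exact RMS correction may differ.

```python
import math
import torch

def rms(x: torch.Tensor) -> float:
    """Root-mean-square of all entries."""
    return x.pow(2).mean().sqrt().item()

def matched_update(ortho: torch.Tensor, target_rms: float = 0.2) -> torch.Tensor:
    """Hypothetical helper: rescale an (approximately) orthogonal update to RMS ~ target_rms."""
    rows, cols = ortho.shape
    # An exactly orthogonal rows x cols matrix has min(rows, cols) unit singular values,
    # so its RMS is sqrt(min(rows, cols) / (rows * cols)) = 1 / sqrt(max(rows, cols)).
    return ortho * target_rms * math.sqrt(max(rows, cols))
```

In an update step this would be used as weight -= lr * matched_update(ortho), which lets the same lr be reasonable for Adam-managed parameters.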

Demo

Below is a training run with and without clipping. Notice how the logits are clipped once they reach clipping_threshold.

[Figure: training max_logits]
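The max_logits curve tracks, for each attention head, the largest pre-softmax logit seen during a step, which is the quantity compared against clipping_threshold. The sketch below computes it directly from hidden states and the q_proj/k_proj weights, assuming logits include the usual 1/sqrt(head_dim) scaling and repeating key heads for grouped-query attention when num_key_value_heads < num_attention_heads. It ignores RoPE and causal masking, so its numbers will not match a real model exactly; per_head_max_logits is a hypothetical helper, and the repository appears to gather these values through hooks registered on the model instead.

```python
import torch

@torch.no_grad()
def per_head_max_logits(hidden: torch.Tensor, w_q: torch.Tensor, w_k: torch.Tensor,
                        num_heads: int, num_kv_heads: int, head_dim: int) -> torch.Tensor:
    """Largest scaled pre-softmax logit per query head (ignores RoPE and masking)."""
    bsz, seq, _ = hidden.shape
    q = (hidden @ w_q.T).view(bsz, seq, num_heads, head_dim).transpose(1, 2)     # (B, H, T, D)
    k = (hidden @ w_k.T).view(bsz, seq, num_kv_heads, head_dim).transpose(1, 2)  # (B, H_kv, T, D)
    # GQA: each key head is shared by a contiguous group of num_heads // num_kv_heads query heads.
    k = k.repeat_interleave(num_heads // num_kv_heads, dim=1)                    # (B, H, T, D)
    logits = q @ k.transpose(-1, -2) / head_dim ** 0.5                           # (B, H, T, T)
    return logits.amax(dim=(0, 2, 3))                                            # (H,)
```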

Installation

To install muon-clip, run:

pip install git+https://github.com/GAD-cell/muon-clip.git@main

Coming soon

- A "ZeRO stage 1"-like optimization based on distributed Muon
- Notebooks for training and distributed training with MuonClip

About

An implementation of the Muon optimizer in PyTorch featuring the latest research improvements.
