Deep-dive into MuonClip: Fixing Attention Score Explosions in Transformer Training

Interactive visualization for MuonClip, brought to you by Fireworks.ai

With the release of Kimi K2, a state-of-the-art tool-calling and instruction-following model, the Kimi team also described how they scaled up pre-training with a new optimizer, MuonClip. Honestly, we don't see new optimizers that often, so let's dive into this a little more to understand how it helped the Kimi team scale their training. Specifically, we're looking at the part of the blog https://moonshotai.github.io/Kimi-K2/ related to MuonClip.

The MuonClip section of the Kimi K2 blog post

So, for people who are bad at math like me: what are they talking about, and how exactly does this solve their scaling problem?

The Attention Mechanism: A Quick Refresher

Before we hit the problem, let's recall how attention works in transformers (the backbone of most LLMs like GPT or Llama). Attention lets the model "focus" on relevant parts of the input sequence. It does this by computing query (Q), key (K), and value (V) projections from the input embeddings.

The magic happens in the attention scores (often called "logits" in this context, but we'll call them "QK scores" to avoid confusion with the model's final output logits). These are dot products between queries and keys, scaled by the square root of the head dimension for stability:

scores = Q Kᵀ / √d_k

High scores mean the model pays more attention to that key when aggregating values. But if these scores blow up to extreme values during training, things go haywire—leading to NaNs, gradients vanishing or exploding, and your entire run crashing.
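To make that concrete, here's a minimal PyTorch sketch of the QK-score computation. The names and sizes are illustrative, not taken from any real model:

import torch

d = 64                          # head dimension
x = torch.randn(8, d)           # 8 token embeddings
W_q = torch.randn(d, d) * 0.02  # query projection
W_k = torch.randn(d, d) * 0.02  # key projection

q, k = x @ W_q, x @ W_k
scores = (q @ k.T) / d ** 0.5   # the QK scores (pre-softmax)
attn = torch.softmax(scores, dim=-1)
# If W_q or W_k grow large during training, scores.max() grows with them,
# and the softmax saturates or overflows: that's the explosion at issue here.
print(scores.max())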

The Scaling Challenge: Why Do QK Scores Explode?

As you scale LLMs to billions of parameters and trillions of tokens (like Kimi K2's 15.5T-token pretraining), instabilities creep in. Moonshot AI noticed this especially when using the Muon optimizer—a high-efficiency alternative to the trusty AdamW that's great for speeding up training but a bit more aggressive.

If you are interested in learning more about Muon, you can read Kimi's paper on Muon, and this blog post from Keller Jordan on the topic.

Existing fixes for the QK score explosion problem? Things like logit soft-capping (clamping scores to a max value) or query-key normalization (normalizing Q and K vectors) sound promising, but Moonshot found them lacking. Soft-capping can distort the attention distribution unnaturally, while normalization might not address the root cause in the weights themselves.
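For intuition, here are hedged sketches of those two fixes (simplified; the function names are mine, not from any particular codebase):

import torch
import torch.nn.functional as F

def soft_cap(scores, cap=50.0):
    # Logit soft-capping: squash scores smoothly into (-cap, cap) with tanh.
    # Bounded, but it warps the attention distribution as scores approach the cap.
    return cap * torch.tanh(scores / cap)

def qk_norm(q, k):
    # Query-key normalization: L2-normalize q and k before the dot product.
    # Scores stay bounded, but W_q and W_k themselves can keep growing.
    return F.normalize(q, dim=-1), F.normalize(k, dim=-1)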

Enter MuonClip, Moonshot's upgrade to Muon that tackles this head-on with a technique called qk-clip.

MuonClip to the Rescue: Rescaling at the Source

MuonClip keeps Muon's speed advantages but adds a post-update safety net. After each Muon step (which orthogonalizes updates for balance; more on that in a sec), qk-clip checks the potential QK scores. If the max score exceeds a threshold t (100 in our demo, matching Kimi K2), it rescales W_q and W_k directly:

  • Compute η = t / max_score (so η < 1 if clipping).
  • Scale W_q by η^α and W_k by η^(1-α), where α (around 0.5) balances the adjustment between query and key.

This bounds scores at t without messing with the update direction—it's like gently shrinking the weights to prevent overflow, right at the source. No distortion in the attention probs, just controlled scales.
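In code, the clip step is tiny. Here's a minimal in-place sketch (a simplified version of the apply_clip function in the full code below); since the scores are bilinear in W_q and W_k, scaling them by η^α and η^(1-α) multiplies every score by η, capping the max at t:

import torch

def qk_clip_(W_q, W_k, max_score, t=100.0, alpha=0.5):
    # Only rescale when the observed max QK score exceeds the threshold t.
    if max_score > t:
        eta = t / max_score        # eta < 1 when clipping
        W_q *= eta ** alpha        # shrink the query projection...
        W_k *= eta ** (1 - alpha)  # ...and the key projection, split by alpha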

Why does this work better? My guess is that it preserves Muon's efficient, balanced updates while ensuring stability for massive datasets. In Kimi K2's case, it enabled smooth training without the crashes that plagued plain Muon. Jianlin's blog https://kexue.fm/archives/11126 has way more details if you want to go deeper.

Visualizing qk-clip in Action: A Toy Example

To make this tangible, let's look at a simplified visualization. Imagine W_q and W_k as small matrices (e.g., 4x4, simulating a tiny attention head). We apply a Muon-like update (orthogonalizing for balance), then optionally clip.

Here's what happens when we simulate the "explosion" setup early in training:

  • Pre-Clip (Top Row): After the update, W_q might have large values, leading to spiky QK scores (dot products). The heatmap shows scores potentially exceeding t, risking explosion.
  • Post-Clip (Bottom Row): qk-clip scales the matrices, capping the max score at t. Notice how the heatmaps use the same color scale—post-clip values are muted (closer to zero), but the relative patterns stay intact. Row norms (bar plots) shrink minimally, balanced by α.

Here is a visualization to help you understand the effect of clipping on the QK scores after the W_q update (again, we avoid the word "logit" so it's not confused with the final logits of the LLM output):

Effect of clipping on the QK score

You can try out the interactive example here: https://muon-clip-app-644257448872.us-central1.run.app. This toy setup mimics real LLM behavior: aggressive updates (high scale) cause "explosions," but the clip reins them in. The values of 100 and 70k steps here are not an accident; they match the setup Jianlin shared about Kimi K2's training (https://kexue.fm/archives/11126), where the max logits get capped at 100 by the clip.

Max logits capped at 100 by the clip

We also simulate the max QK score later in training, at around 30, where clipping is no longer active, and you can see the weights are left unchanged.

Training instability isn't just an academic headache; it's lost GPU hours and delayed launches. Tools like MuonClip democratize big-model training, letting startups like Moonshot AI punch above their weight. The full details of qk-clip are actually more complicated than what I've covered here, so please check out this blog https://kexue.fm/archives/11126 if you want to dig into the rabbit hole (the comment section in particular has a lot of back and forth between Jianlin and other researchers), or check out Jianlin Su's tweet here: https://x.com/Jianlin_S/status/1943920839487107372?t=1hmbwDoWTACe8AUY1ZSWWA&s=09

Here is the code for the visualization in case you are interested.

import streamlit as st
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
    # Newton-Schulz iteration used by Muon to approximately orthogonalize the update.
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G / (G.norm() + eps)
    transpose = G.size(0) > G.size(1)
    if transpose:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    if transpose:
        X = X.T
    return X

def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True):
    # Muon step: momentum update, then orthogonalize and rescale by matrix shape.
    momentum.lerp_(grad, 1 - beta)
    if nesterov:
        update = grad.clone().lerp_(momentum, beta)
    else:
        update = momentum.clone()
    update = zeropower_via_newtonschulz5(update, steps=ns_steps)
    scale_factor = max(1, grad.size(-2) / grad.size(-1)) ** 0.5
    update *= scale_factor
    return update

def apply_clip(W_q, W_k, alpha, t, x, eps=1e-7):
    # qk-clip: rescale W_q and W_k in place if the max QK score exceeds t.
    q = x @ W_q.T  # (batch, seq, fan_out)
    k = x @ W_k.T
    scores = torch.einsum('bid,bjd->bij', q, k) / np.sqrt(W_q.size(0))  # / sqrt(fan_out)
    max_score = scores.max()
    if max_score > t:
        eta = t / (max_score + eps)
        scale_q = eta ** alpha
        scale_k = eta ** (1 - alpha)
        W_q *= scale_q
        W_k *= scale_k
        return True, max_score.item()
    return False, max_score.item()

st.title('MuonClip Clipping Visualization - Early vs Late Training')

fan_out = st.slider('Output Dim (fan_out, e.g., d_head)', min_value=2, max_value=10, value=4, step=1)
fan_in = st.slider('Input Dim (fan_in, e.g., d_model)', min_value=2, max_value=10, value=4, step=1)
seq_len = st.slider('Sequence Length for Simulation', min_value=2, max_value=10, value=4, step=1)
seed = st.slider('Random Seed', min_value=0, max_value=100, value=42, step=1)
beta = st.slider('Momentum Beta', min_value=0.5, max_value=0.99, value=0.95, step=0.01)
nesterov = st.checkbox('Use Nesterov Momentum', value=True)
alpha = st.slider('Alpha for Clip', min_value=0.0, max_value=1.0, value=0.5, step=0.05)
lr_early = st.slider('LR for Early Training (high to simulate explosion)', min_value=0.1, max_value=200.0, value=50.0, step=1.0)
lr_late = st.slider('LR for Late Training (low for stable)', min_value=0.1, max_value=200.0, value=5.0, step=1.0)
t = st.slider('Clip Threshold t', min_value=10.0, max_value=200.0, value=100.0, step=10.0)

torch.manual_seed(seed)

# Initial weights small
init_scale = 0.1
W_q = torch.randn(fan_out, fan_in) * init_scale
W_k = torch.randn(fan_out, fan_in) * init_scale

# Simulate a forward pass input for computing scores
x = torch.randn(1, seq_len, fan_in)

def compute_scenario(lr, t, scenario_name):
    # Fix grad_scale to 1, as Muon's normalization makes the scale irrelevant for magnitude
    grad_scale = 1.0
    # Simulate grads
    grad_q = torch.randn(fan_out, fan_in) * grad_scale
    grad_k = torch.randn(fan_out, fan_in) * grad_scale
    # Momentum buffers
    momentum_q = torch.zeros_like(grad_q)
    momentum_k = torch.zeros_like(grad_k)
    # Compute Muon updates
    update_q = muon_update(grad_q, momentum_q, beta=beta, ns_steps=5, nesterov=nesterov)
    update_k = muon_update(grad_k, momentum_k, beta=beta, ns_steps=5, nesterov=nesterov)
    # Apply updates to weights with LR scaling
    W_q_orth = W_q + lr * update_q
    W_k_orth = W_k + lr * update_k
    # Pre-clip scores
    q_pre = x @ W_q_orth.T
    k_pre = x @ W_k_orth.T
    scores_pre = torch.einsum('bid,bjd->bij', q_pre, k_pre) / np.sqrt(fan_out)
    max_pre = scores_pre.max().item()
    # Apply clip to copies so we can compare pre vs post
    W_q_clip = W_q_orth.clone()
    W_k_clip = W_k_orth.clone()
    clipped, max_post = apply_clip(W_q_clip, W_k_clip, alpha, t, x)
    q_post = x @ W_q_clip.T
    k_post = x @ W_k_clip.T
    scores_post = torch.einsum('bid,bjd->bij', q_post, k_post) / np.sqrt(fan_out)
    # Shared vmin/vmax so pre/post heatmaps are directly comparable
    w_min = min(W_q_orth.min().item(), W_q_clip.min().item())
    w_max = max(W_q_orth.max().item(), W_q_clip.max().item())
    s_min = min(scores_pre.min().item(), scores_post.min().item())
    s_max = max(scores_pre.max().item(), scores_post.max().item())
    # Plot
    fig, axs = plt.subplots(2, 3, figsize=(18, 10))
    sns.heatmap(W_q_orth.numpy(), ax=axs[0, 0], cmap='coolwarm', annot=True, fmt=".2f", vmin=w_min, vmax=w_max)
    axs[0, 0].set_title('W_q After Muon Update (Pre-Clip)')
    sns.heatmap(W_q_clip.numpy(), ax=axs[1, 0], cmap='coolwarm', annot=True, fmt=".2f", vmin=w_min, vmax=w_max)
    axs[1, 0].set_title('W_q After Clip')
    sns.heatmap(scores_pre[0].numpy(), ax=axs[0, 1], cmap='viridis', annot=True, fmt=".2f", vmin=s_min, vmax=s_max)
    axs[0, 1].set_title(f'Pre-Clip QK Scores, Max: {max_pre:.2f}')
    sns.heatmap(scores_post[0].numpy(), ax=axs[1, 1], cmap='viridis', annot=True, fmt=".2f", vmin=s_min, vmax=s_max)
    axs[1, 1].set_title(f'Post-Clip QK Scores, Max: {max_post:.2f}')
    norms_pre = torch.norm(W_q_orth, dim=1).numpy()
    norms_post = torch.norm(W_q_clip, dim=1).numpy()
    y_max = max(norms_pre.max(), norms_post.max()) * 1.1  # shared y-limit for fair comparison
    axs[0, 2].bar(range(fan_out), norms_pre, color='orange')
    axs[0, 2].set_title('W_q Row Norms Pre-Clip')
    axs[0, 2].set_ylim(0, y_max)
    axs[1, 2].bar(range(fan_out), norms_post, color='purple')
    axs[1, 2].set_title('W_q Row Norms Post-Clip')
    axs[1, 2].set_ylim(0, y_max)
    plt.tight_layout()
    st.subheader(scenario_name)
    st.pyplot(fig)
    if clipped:
        st.text(f"Clip triggered! Max score reduced from {max_pre:.2f} to {max_post:.2f}.")
    else:
        st.text("No clip triggered.")

# Early Training
compute_scenario(lr_early, t, "Early Training (Before 70k steps, clip at ~100)")
# Late Training
compute_scenario(lr_late, t, "Late Training (After 70k steps, natural ~30)")

st.markdown("""
### Quick Intuition:
- **Early Training**: High LR simulates aggressive updates causing high QK scores (>100), triggering the clip.
- **Late Training**: Low LR for stable updates, scores naturally lower (~30), no clip.
- Adjust LR sliders to control update magnitude (grad_scale has no effect due to Muon's normalization; LR scales the post-norm update).
- Experiment with the seed if the max is negative/low (randn can vary; try seeds for positive spikes).
""")

What do you think? Drop a comment on X/Twitter. If this sparked ideas, stay tuned for more content like this.