
A Deep Dive into the MLA Training/Inference Difference, and Why Kimi's QK-Clip Is Such an Elegant Idea


Today, we're unpacking a clever insight from the researchers behind Kimi K2, a powerful LLM from Moonshot AI. This all started from a fascinating exchange in the comment section of a technical blog post.

We'll break it down step by step, with real math to appreciate the elegance, but I'll explain it like we're chatting over coffee. By the end, you'll see why this "QK-Clip" trick is so smart and how it makes models like Kimi more reliable for your apps.

Anecdotally, I have heard whispers on the street that there are quality trade-offs with using MLA, and this may be the secret ingredient that some of the top labs have been missing, paving the way for more efficient inference across the board.

1. A Comment Exchange on the Kimi Blog

Our story begins in the comment section of a blog post by Su Jianlin (苏剑林) at https://kexue.fm/archives/11126.

  • A reader asked why “during decoding, you cannot fully materialize the k you get during training.”
  • Jianlin explained that this comes from a structural difference in how Keys are computed in Multi-Head Latent Attention (MLA), a memory-efficient attention variant used in models like Kimi K2 (inspired by DeepSeek-V2). In training, Keys are fully "materialized" (computed and structured in a way that allows normalization), but in decoding (inference), a key component is missing, which breaks techniques like RMSNorm.

This exchange fascinated me because it highlights a real-world engineering challenge in scaling LLMs. As app devs, we often treat models as black boxes, but peeking inside reveals why innovations like QK-Clip are crucial for stable performance. Inspired by this, I created an animated visualization to make the concept accessible.

But let me explain, in case you don’t know the details about MLA.

2. Background: LLMs, Attention, and Why Efficiency Matters

If you're building GenAI apps (say, a chatbot or a text generator), you're likely using LLMs like GPT models or Kimi models via APIs. At their core, LLMs are giant neural networks that predict the next token (word or subword) in a sequence. They do this by processing inputs through layers of "attention" mechanisms.

Attention is the secret sauce: It lets the model weigh the importance of different parts of the input. In standard Multi-Head Attention (MHA), for each position in the sequence, we compute:

  • Queries (Q): What the current token is "asking" about.
  • Keys (K): Representations of past tokens to match against Q.
  • Values (V): The actual info to retrieve based on Q-K matches.

The attention weights are computed as $\text{Softmax}(Q K^T / \sqrt{d})$, where $d$ is the per-head key dimension, and these weights are then multiplied by V.
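To make the formula concrete, here is a minimal NumPy sketch of single-head scaled dot-product attention. It assumes unbatched, unmasked inputs and random stand-in matrices; real models add causal masking, many heads, and batching.

```python
import numpy as np

def softmax(z: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtract the row max before exponentiating for numerical stability
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray) -> np.ndarray:
    """Single-head attention. Q: (n_q, d), K: (n_k, d), V: (n_k, d_v) -> (n_q, d_v)."""
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)        # (n_q, n_k) pre-Softmax logits
    weights = softmax(scores, axis=-1)   # each row sums to 1
    return weights @ V                   # weighted mix of value vectors

rng = np.random.default_rng(0)
Q = rng.standard_normal((4, 8))    # 4 query tokens, d = 8
K = rng.standard_normal((6, 8))    # 6 key tokens
V = rng.standard_normal((6, 16))   # values with d_v = 16
out = attention(Q, K, V)
print(out.shape)  # (4, 16)
```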

But here's the catch for large models: During inference (decoding), especially for long conversations, storing all K and V (the "KV cache") eats up memory. Models like Kimi K2 use Multi-Head Latent Attention (MLA) to compress this.

In MLA, Keys and Values are projected into a lower-dimensional "latent" space (e.g., from 5120 dims to 512), saving memory without losing much power. This is genius for apps handling long contexts, like summarizing documents or maintaining chat history.
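A back-of-the-envelope sketch of why this matters for the KV cache. It counts cached elements per token per layer, assuming a DeepSeek-V2-like layout (128 heads of dimension 128, $d_c = 512$, plus a 64-dim RoPE key part shared across heads, which is explained below). The 60-layer, 32K-token, 2-bytes-per-element totals are hypothetical and ignore other serving overheads.

```python
def mha_cache_elems(n_heads: int, d_head: int) -> int:
    """Standard MHA, per token per layer: one K and one V vector for every head."""
    return 2 * n_heads * d_head

def mla_cache_elems(d_c: int, d_r: int) -> int:
    """MLA, per token per layer: one shared latent c_i plus one shared RoPE key part."""
    return d_c + d_r

def cache_gib(elems: int, n_layers: int, n_tokens: int, bytes_per_elem: int = 2) -> float:
    """Total cache size in GiB (2 bytes per element, as with bf16/fp16)."""
    return elems * n_layers * n_tokens * bytes_per_elem / 2**30

mha = mha_cache_elems(n_heads=128, d_head=128)   # assumed head layout
mla = mla_cache_elems(d_c=512, d_r=64)
print(mha, mla, round(mha / mla, 1))             # 32768 576 56.9
# Hypothetical deployment: 60 layers, 32K-token context
print(cache_gib(mha, 60, 32768), cache_gib(mla, 60, 32768))  # 120.0 2.109375
```

The trade-off is that MLA spends extra compute on projections to save this memory, which is usually the better deal for long-context decoding.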

https://arxiv.org/pdf/2405.04434 from the original DeepSeek V2 paper

However, as the comment revealed, MLA introduces a subtle difference between training (where we process the whole sequence at once) and inference (where we generate one token at a time). That difference rules out a common stabilization trick, QK-Norm, which leaves the model exposed to instability such as exploding attention scores.

3. The Problem and Math Deep Dive: Training vs. Inference in Multi-Head Latent Attention (MLA)

If you're an app developer who's used LLMs but never really dug into how they compute attention under the hood, this section is for you. We'll merge the problem explanation with a detailed math breakdown, stepping through everything one piece at a time. Imagine we're walking through a recipe: I'll define each ingredient (variable), show how they're mixed (the formulas), and explain what goes wrong if you skip a step.

Along the way, I'll address two common follow-up questions proactively:

(1) Why not just compute the missing projection during inference? Isn't skipping it an inconsistency?

(2) Does this breakage only affect norms in one spot, or can you norm elsewhere to fix explosions?

By the end, you'll see exactly why training and inference differ in MLA, why that breaks normalization tricks like RMSNorm, and how it sets up Kimi's clever fix.

First, Recap the Big Idea in Simple Terms

In attention mechanisms (the part of LLMs that decides what to "pay attention to" in a sentence), we need to create Keys (K) for each token. These Keys are like searchable tags that help the model match the current query to past context.

In Multi-Head Latent Attention (MLA), used in efficient models like Kimi K2 to save memory, Keys are built in a compressed way. During training (when the model learns from data, processing entire sequences at once), Keys are fully built with all their parts.

But during inference (or "decoding," when your app generates text one token at a time), we simplify the process to be faster and use less memory. This simplification skips a key step, which is fine for basic computation but breaks add-on techniques like normalization (which keeps values from exploding).

The result? Without careful handling, attention scores can go haywire in inference, leading to weird outputs or crashes in your app. Now, let's unpack the math to see why.

Step 1: Understanding the Input and Basic Setup

Every token in your input (like a word in a prompt) starts as a high-dimensional vector called the embedding, denoted $\boldsymbol{x}_i$. Here:

  • $i$ is the position of the token in the sequence.
  • To make attention efficient, MLA first compresses part of this into a "latent" space: a smaller vector that saves memory.
  • We multiply $\boldsymbol{x}_i$ by a weight matrix $\boldsymbol{W}_c \in \mathbb{R}^{d_m \times d_c}$, where $d_c$ is much smaller than $d_m$ (e.g., $d_c = 512$).
  • Result: $\boldsymbol{c}_i = \boldsymbol{x}_i \boldsymbol{W}_c \in \mathbb{R}^{d_c}$.

This $\boldsymbol{c}_i$ is the compressed "latent" version of the token. It's like summarizing a long article into key points, which makes it efficient to store in the KV cache during long chats.
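Here is the compression step as a minimal NumPy sketch. The weights are random stand-ins (a trained model would use its learned $\boldsymbol{W}_c$), the sizes match the examples above, and the variable names are illustrative only. Stacking several tokens' embeddings as rows shows what gets cached: one $d_c$-dimensional row per token.

```python
import numpy as np

d_m, d_c = 5120, 512                                  # example sizes from the text
rng = np.random.default_rng(0)
W_c = rng.standard_normal((d_m, d_c)) / np.sqrt(d_m)  # random stand-in for learned W_c

x_i = rng.standard_normal(d_m)    # one token embedding (row vector)
c_i = x_i @ W_c                   # its compressed latent

# For a sequence, stack embeddings as rows; the cache keeps one latent row per token
X = rng.standard_normal((10, d_m))   # 10 hypothetical tokens
C_cache = X @ W_c
print(c_i.shape, C_cache.shape)      # (512,) (10, 512)
```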

Step 2: Building the Key in Training (Full Version)

In training, we build the Key $\boldsymbol{k}_i^{(s)}$ for each attention head $s$ (models have multiple "heads" to focus on different aspects). The Key is concatenated from two parts:

  • Part 1: The Projected Latent (Head-Specific)
    • Take the latent $\boldsymbol{c}_i$.
    • Multiply it by a head-specific weight matrix $\boldsymbol{W}_{kc}^{(s)} \in \mathbb{R}^{d_c \times d_k}$, where $d_k$ is the per-head key dimension (e.g., 128).
    • Result: $\boldsymbol{c}_i \boldsymbol{W}_{kc}^{(s)} \in \mathbb{R}^{d_k}$.
  • Part 2: The Rotary Path (Position-Aware)
    • Take the original $\boldsymbol{x}_i$.
    • Multiply it by another weight matrix $\boldsymbol{W}_{kr} \in \mathbb{R}^{d_m \times d_r}$ (shared by all heads), where $d_r$ is small (e.g., 64).
    • Then apply Rotary Position Embedding (RoPE), $\boldsymbol{\mathcal{R}}_i$, which rotates the vector based on position $i$ to help the model understand order.
    • Result: $\boldsymbol{x}_i \boldsymbol{W}_{kr} \boldsymbol{\mathcal{R}}_i \in \mathbb{R}^{d_r}$.
  • Concatenate Them: Join the two vectors into a single Key:

$$\boldsymbol{k}_i^{(s)} = \begin{bmatrix} \boldsymbol{c}_i \boldsymbol{W}_{kc}^{(s)} \\ \boldsymbol{x}_i \boldsymbol{W}_{kr} \boldsymbol{\mathcal{R}}_i \end{bmatrix} \in \mathbb{R}^{d_k + d_r}$$

(E.g., 128 + 64 = 192 dimensions.)

This full Key is "materialized": fully computed and structured, ready for anything, like applying norms. Here is the MLA construction from the original DeepSeek-V2 paper; I circled the concatenation in orange.

https://arxiv.org/pdf/2405.04434 from the original DeepSeek V2 paper
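The training-time Key construction can be sketched in NumPy as follows. Assumptions: random stand-in weights, a single head $s$, and one common RoPE formulation (rotating interleaved pairs of dimensions with base 10000); production models differ in RoPE details. The helper names `rope` and `training_key` are hypothetical.

```python
import numpy as np

def rope(v, pos, base=10000.0):
    """Simple RoPE: rotate interleaved pairs (v0, v1), (v2, v3), ... by position-dependent angles."""
    half = v.shape[-1] // 2
    angles = pos * base ** (-2.0 * np.arange(half) / v.shape[-1])
    cos, sin = np.cos(angles), np.sin(angles)
    out = np.empty_like(v)
    out[0::2] = v[0::2] * cos - v[1::2] * sin
    out[1::2] = v[0::2] * sin + v[1::2] * cos
    return out

def training_key(x_i, pos, W_c, W_kc_s, W_kr):
    """Materialized per-head key: concat(c_i W_kc^(s), RoPE(x_i W_kr))."""
    c_i = x_i @ W_c                  # compressed latent, (d_c,)
    k_content = c_i @ W_kc_s         # head-specific projection, (d_k,)
    k_rope = rope(x_i @ W_kr, pos)   # position-aware part, (d_r,)
    return np.concatenate([k_content, k_rope])

d_m, d_c, d_k, d_r = 5120, 512, 128, 64
rng = np.random.default_rng(0)
x_i = rng.standard_normal(d_m)
W_c = rng.standard_normal((d_m, d_c)) / np.sqrt(d_m)
W_kc_s = rng.standard_normal((d_c, d_k)) / np.sqrt(d_c)   # one head's projection
W_kr = rng.standard_normal((d_m, d_r)) / np.sqrt(d_m)     # shared by all heads

k = training_key(x_i, pos=7, W_c=W_c, W_kc_s=W_kc_s, W_kr=W_kr)
print(k.shape)  # (192,)
```

Because RoPE is a rotation, it changes the direction of the rotary part but never its length.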

Step 3: Building the Key in Inference (Simplified Version) – And Why We Skip the Projection

In inference, we optimize for speed: we cache the latent $\boldsymbol{c}_i$ once and reuse it across heads, avoiding recomputing heavy projections. So the Key $\boldsymbol{k}_i$ (now shared, not per-head) becomes:

  • Part 1: Direct Latent (No Projection)
    • Just use $\boldsymbol{c}_i \in \mathbb{R}^{d_c}$ directly, with no multiplication by $\boldsymbol{W}_{kc}^{(s)}$! This saves compute and memory, because $\boldsymbol{W}_{kc}^{(s)}$ is "absorbed" into other parts of the model (like the query or output weights).
      • Specifically, MLA's benefit is that some weight products can be precomputed once, after training, and reused at inference. Writing out the non-RoPE part of the logit between query position $t$ and key position $i$:
      • $\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)\top} = (\boldsymbol{x}_t \boldsymbol{W}_q^{(s)}) (\boldsymbol{c}_i \boldsymbol{W}_{kc}^{(s)})^{\top} = \boldsymbol{x}_t (\boldsymbol{W}_q^{(s)} \boldsymbol{W}_{kc}^{(s)\top}) \boldsymbol{c}_i^{\top}$
      • Here $(\boldsymbol{W}_q^{(s)} \boldsymbol{W}_{kc}^{(s)\top})$ can be precomputed for inference, so $\boldsymbol{W}_{kc}^{(s)}$ simply "disappears" at inference time.
  • Part 2: The Rotary Path
    • Same as training: $\boldsymbol{x}_i \boldsymbol{W}_{kr} \boldsymbol{\mathcal{R}}_i \in \mathbb{R}^{d_r}$.
  • Concatenate Them:

$$\boldsymbol{k}_i = \begin{bmatrix} \boldsymbol{c}_i \\ \boldsymbol{x}_i \boldsymbol{W}_{kr} \boldsymbol{\mathcal{R}}_i \end{bmatrix} \in \mathbb{R}^{d_c + d_r}$$

(E.g., 512 + 64 = 576 dimensions. Notice it's larger and structurally different, because $d_c > d_k$.)

Key difference: there is no $\boldsymbol{W}_{kc}^{(s)}$ in Part 1! This is efficient, but it means the Key isn't "fully materialized" like in training.

Now, addressing a natural follow-up: why not just apply $\boldsymbol{W}_{kc}^{(s)}$ during inference anyway? Isn't skipping it an inconsistency between training and inference?

It could be computed, but that would defeat MLA's main goal: massive efficiency in memory and speed.

Instead, the projection is "absorbed", or fused mathematically, into other weights: the key projection folds into the query weights $\boldsymbol{W}_q^{(s)}$, and the analogous value projection folds into the output weights $\boldsymbol{W}_o^{(s)}$. Outputs stay identical without the explicit computation. For example, the effect of $\boldsymbol{W}_{kc}^{(s)}$ is pre-multiplied into the queries, $Q'^{(s)} = \boldsymbol{x} \boldsymbol{W}_q^{(s)} \boldsymbol{W}_{kc}^{(s)\top}$, so you query the latent directly. This isn't an inconsistency in the core math (attention outputs match), but it does alter the internal structure, making the Key non-materializable for add-ons like norms.
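To check the absorption argument numerically, the sketch below computes the full attention logit two ways: the training way (materialized per-head key and query) and the decoding way (absorbed query against the cached latent). Assumptions: small hypothetical sizes, random weights, a single head, a simple interleaved-pair RoPE, and no query compression (DeepSeek-V2 also compresses queries, which doesn't change the argument).

```python
import numpy as np

def rope(v, pos, base=10000.0):
    """Simple RoPE: rotate interleaved pairs of v by position-dependent angles."""
    half = v.shape[-1] // 2
    angles = pos * base ** (-2.0 * np.arange(half) / v.shape[-1])
    cos, sin = np.cos(angles), np.sin(angles)
    out = np.empty_like(v)
    out[0::2] = v[0::2] * cos - v[1::2] * sin
    out[1::2] = v[0::2] * sin + v[1::2] * cos
    return out

rng = np.random.default_rng(1)
d_m, d_c, d_k, d_r = 256, 64, 32, 16                    # hypothetical small sizes
W_c = rng.standard_normal((d_m, d_c)) / np.sqrt(d_m)
W_kc = rng.standard_normal((d_c, d_k)) / np.sqrt(d_c)   # head s key projection
W_kr = rng.standard_normal((d_m, d_r)) / np.sqrt(d_m)   # shared RoPE key weight
W_qc = rng.standard_normal((d_m, d_k)) / np.sqrt(d_m)   # head s query, non-RoPE part
W_qr = rng.standard_normal((d_m, d_r)) / np.sqrt(d_m)   # head s query, RoPE part

x_t, x_i = rng.standard_normal(d_m), rng.standard_normal(d_m)
t, i = 9, 4                                   # query and key positions
c_i = x_i @ W_c                               # the per-token content that is cached
k_rope = rope(x_i @ W_kr, i)

# Training: materialize the per-head key and query (d_k + d_r dims each)
k_train = np.concatenate([c_i @ W_kc, k_rope])
q_train = np.concatenate([x_t @ W_qc, rope(x_t @ W_qr, t)])

# Decoding: fold W_kc into the query weights once; the key is the cached latent (d_c + d_r dims)
W_q_absorbed = W_qc @ W_kc.T                  # (d_m, d_c), precomputed after training
k_dec = np.concatenate([c_i, k_rope])
q_dec = np.concatenate([x_t @ W_q_absorbed, rope(x_t @ W_qr, t)])

print(k_train.shape, k_dec.shape)                     # (48,) (80,)
print(np.isclose(q_train @ k_train, q_dec @ k_dec))   # True: same logit either way
```

The logits agree even though the two keys don't even have the same length, which is exactly why there is no single "decode key" you could normalize the way training does.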

Step 4: Why This Breaks Normalization (Like RMSNorm) – And Where Else Norms Can (or Can't) Help

Normalization techniques, such as RMSNorm, are used to stabilize attention by scaling vectors to prevent huge values (explosions). RMSNorm for a vector $\boldsymbol{v} \in \mathbb{R}^d$ is:

$$\text{RMSNorm}(\boldsymbol{v}) = \boldsymbol{v} \oslash \sqrt{\frac{1}{d} \sum_{j=1}^{d} v_j^2 + \epsilon}$$

(where $\oslash$ is element-wise division, and $\epsilon$ is a tiny number to avoid division by zero).
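A minimal NumPy version of the formula. It also accepts an optional learned per-dimension gain, which many implementations multiply in afterwards (not part of the formula above). The final check shows the property that makes RMSNorm useful against explosions: scaling the input by 100 leaves the output essentially unchanged.

```python
import numpy as np

def rms_norm(v, eps=1e-6, gain=None):
    """RMSNorm over the last axis: v / sqrt(mean(v^2) + eps), optionally times a learned gain."""
    rms = np.sqrt(np.mean(v ** 2, axis=-1, keepdims=True) + eps)
    out = v / rms
    if gain is not None:
        out = out * gain   # per-dimension learned scale used by many implementations
    return out

v = np.array([3.0, -4.0, 12.0, 0.0])   # mean of squares = 169/4, so RMS = 6.5
print(rms_norm(v))                      # approximately v / 6.5
print(np.allclose(rms_norm(100 * v), rms_norm(v)))  # True: scale-invariant
```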

Here, RMSNorm is often part of "QK-Norm": Normalizing Queries and Keys before their dot product $Q K^T$ to tame explosions from unchecked weight growth.

  • In Training: We can apply RMSNorm to the full Key $\boldsymbol{k}_i^{(s)}$. The norm is computed over both parts, including the projected latent $\boldsymbol{c}_i \boldsymbol{W}_{kc}^{(s)}$. Everything is there, so it works: the model learns stable weights.
  • In Inference: We try to apply RMSNorm to $\boldsymbol{k}_i$, but the structure is different! The first part is the raw $\boldsymbol{c}_i$ (512 dims, with a different scale and distribution) rather than the projected $\boldsymbol{c}_i \boldsymbol{W}_{kc}^{(s)}$ (128 dims). Worse, since the projection is fused elsewhere, normalizing the raw latent would break the fused math and produce wrong outputs. The blog comment nails it: you can't compute the norm of the absent $\boldsymbol{c}_i \boldsymbol{W}_{kc}^{(s)}$.
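This sketch makes the breakage concrete for the non-RoPE part of one head. It applies QK-Norm the training way (normalizing the materialized $\boldsymbol{c}_i \boldsymbol{W}_{kc}^{(s)}$) and then a naive decode-time substitute (normalizing the cached latent $\boldsymbol{c}_i$ and using the absorbed query). The two logits differ by exactly the ratio $\text{rms}(\boldsymbol{c}_i \boldsymbol{W}_{kc}^{(s)}) / \text{rms}(\boldsymbol{c}_i)$, a per-head factor that depends on the very projection the decode path skips. Sizes and weights are hypothetical; the scale of `W_kc` is chosen so the mismatch is easy to see.

```python
import numpy as np

def rms(v, eps=1e-6):
    return np.sqrt(np.mean(v ** 2) + eps)

def rms_norm(v, eps=1e-6):
    return v / rms(v, eps)

rng = np.random.default_rng(2)
d_m, d_c, d_k = 256, 64, 32                      # hypothetical small sizes
W_c = rng.standard_normal((d_m, d_c)) / np.sqrt(d_m)
W_kc = 0.5 * rng.standard_normal((d_c, d_k))     # one head's key projection
W_qc = rng.standard_normal((d_m, d_k)) / np.sqrt(d_m)

x_t, x_i = rng.standard_normal(d_m), rng.standard_normal(d_m)
c_i = x_i @ W_c                  # what the decode cache holds
q = rms_norm(x_t @ W_qc)         # normalized query (QK-Norm on the query side)

# Training: normalize the materialized per-head key, then take the dot product
logit_train = q @ rms_norm(c_i @ W_kc)

# Naive decoding: normalize the cached latent and use the absorbed query q W_kc^T
logit_naive = (q @ W_kc.T) @ rms_norm(c_i)

print(logit_train, logit_naive)
print(np.isclose(logit_naive / logit_train, rms(c_i @ W_kc) / rms(c_i)))  # True
```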

This leads to another follow-up: Does this just break RMSNorm on that specific part (the Keys in attention)? Can you apply norms elsewhere to fix QK explosions?

Yes, it specifically breaks QK-Norm (RMSNorm on Keys/Queries in attention), because you can't normalize the projected part that is never materialized. You can (and do) apply norms elsewhere, such as the RMSNorm/LayerNorm layers around the attention and FFN blocks, which stabilize the overall hidden state without running into MLA's issues. But these don't target QK explosions directly: if $Q K^T$ blows up (e.g., max values in the thousands), the Softmax saturates into a near one-hot distribution (and naive exponentiation overflows), causing repetitive or garbage output. Norms elsewhere help the model overall but can't prevent that local issue in attention, the hotspot for explosions driven by weights like $\boldsymbol{W}_q$ and $\boldsymbol{W}_k$.
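A tiny illustration of why unbounded logits hurt, using made-up logits for a single query: as the scale grows, even a numerically stable Softmax collapses toward a one-hot distribution, and naive exponentiation of logits in the thousands overflows to `inf` in float64.

```python
import numpy as np

def softmax(z):
    z = z - z.max()   # stable: shift so the largest exponent is exp(0)
    e = np.exp(z)
    return e / e.sum()

logits = np.array([2.0, 1.5, 1.0, 0.5])   # made-up attention logits
for scale in (1, 10, 1000):
    print(scale, softmax(scale * logits).round(3))

# Without the max-shift, exp() of logits in the thousands overflows
with np.errstate(over='ignore'):
    print(np.exp(1000 * logits))
```

At scale 1 the weights are spread across all four tokens; at scale 1000 essentially all weight lands on one token, which is how runaway logits turn into repetitive output.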

4. Visualizing It All: Step-by-Step Through the Graph

MLA versus Standard Attention

To make this tangible, I built an animated flowchart using NetworkX and Matplotlib (the code is attached at the end of this post). It's a side-by-side view of Training (left) and Decoding (right), with nodes as variables/operations, edges as computations, and colors indicating success (light green) or failure (salmon red). The animation highlights the paths in orange, one step at a time.

Here's a walkthrough:

  • Overall Layout: At the top is the input token embedding $\boldsymbol{x}_i \in \mathbb{R}^{d_m}$ (e.g., $d_m = 5120$). It branches left (latent compression) and right (rotary path). The branches merge into the Key $\boldsymbol{k}_i$, which then feeds RMSNorm and the Output.
  • Step 1: Input Splitting
    • Both sides start with $\boldsymbol{x}_i$.
    • Left: Multiply by $\boldsymbol{W}_c \in \mathbb{R}^{d_m \times d_c}$ to get the compressed latent $\boldsymbol{c}_i \in \mathbb{R}^{d_c}$ (e.g., $d_c = 512$).
  • Step 2: Left Branch (Latent Projection)
    • Training: $\boldsymbol{c}_i$ is multiplied by the head-specific $\boldsymbol{W}_{kc}^{(s)} \in \mathbb{R}^{d_c \times d_k}$ (green node) to get $\boldsymbol{c}_i \boldsymbol{W}_{kc}^{(s)} \in \mathbb{R}^{d_k}$ (e.g., $d_k = 128$).
    • Decoding: This projection is missing (red node labeled "Missing Projection")! Instead, $\boldsymbol{c}_i$ is used directly, with no multiplication.
  • Step 3: Right Branch (Rotary Embedding)
    • Both: $\boldsymbol{x}_i$ is multiplied by $\boldsymbol{W}_{kr} \in \mathbb{R}^{d_m \times d_r}$ (e.g., $d_r = 64$), then RoPE ($\mathcal{R}_i$) is applied for position encoding, drawn as a curved arrow from the RoPE node.
    • Result: $\boldsymbol{x}_i \boldsymbol{W}_{kr} \mathcal{R}_i \in \mathbb{R}^{d_r}$.
  • Step 4: Concatenation to Key
    • Merge the branches: Training gets the full Key $\boldsymbol{k}_i^{(s)} \in \mathbb{R}^{d_k + d_r}$; Decoding gets the incomplete $\boldsymbol{k}_i \in \mathbb{R}^{d_c + d_r}$ (note $d_c > d_k$, so both the dimensions and the structure differ).
  • Step 5: Applying RMSNorm
    • Training: Norm works on the full vector (green).
    • Decoding: Norm fails due to missing component (red)—can't compute the norm of the absent projection!
  • Step 6: Output to Attention
    • Training: Stable output.
    • Decoding: Potential instability.

The animation reveals this flow sequentially, ending with a note on dimensions and QK-Clip's fix. Colors make the "missing piece" pop—training flows smoothly in green, decoding stalls in red.

5. The Genius of QK-Clip: Kimi's Smart Solution, and Why This Matters

Kimi researchers spotted that training with the Muon optimizer causes "MaxLogit explosions" (huge pre-Softmax values). Instead of QK-Norm (which breaks in MLA decoding), QK-Clip rescales the weights $\boldsymbol{W}_q$ and $\boldsymbol{W}_k$ during training whenever the MaxLogit exceeds a threshold $\tau$:

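The final version works per layer $l$ and head $h$, where $S_{\max}^{(l,h)}$ is the largest pre-Softmax logit that head produced on the current training batch.

If $S_{\max}^{(l,h)} > \tau$:

For the head-specific non-RoPE weights $\boldsymbol{W}_{qc}^{(l,h)}$ and $\boldsymbol{W}_{kc}^{(l,h)}$: $\boldsymbol{W} \leftarrow \boldsymbol{W} \times \sqrt{\tau / S_{\max}^{(l,h)}}$

For the head-specific RoPE query weight $\boldsymbol{W}_{qr}^{(l,h)}$: $\boldsymbol{W} \leftarrow \boldsymbol{W} \times \tau / S_{\max}^{(l,h)}$.

The shared RoPE key weight $\boldsymbol{W}_{kr}$ is left alone, since changing it would affect every head. The non-RoPE part of the logit picks up $\sqrt{\tau / S_{\max}^{(l,h)}}$ from both the query and the key side, and the RoPE part picks up $\tau / S_{\max}^{(l,h)}$ from the query side alone, so the whole logit shrinks by $\tau / S_{\max}^{(l,h)}$ and the head's max logit on that batch is pulled back to $\tau$.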
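The rescaling rule can be sketched for a single head as below. This is a hypothetical, simplified NumPy illustration, not Moonshot's implementation: it measures $S_{\max}$ over all token pairs of one small batch (no causal mask), includes the usual $1/\sqrt{d}$ factor, omits the RoPE rotation (a rotation is linear, so it doesn't change how a logit scales with the weights), and never touches the shared `W_kr`. The helper names `max_logit` and `qk_clip_head` are made up for this sketch.

```python
import numpy as np

def max_logit(X, head, W_c, W_kr):
    """Largest scaled logit of one head over all token pairs in X (no mask, no RoPE)."""
    C = X @ W_c                                                   # latents, one row per token
    Q = np.concatenate([X @ head['W_qc'], X @ head['W_qr']], axis=1)
    K = np.concatenate([C @ head['W_kc'], X @ W_kr], axis=1)
    return float((Q @ K.T).max() / np.sqrt(Q.shape[1]))

def qk_clip_head(head, s_max, tau):
    """Rescale one head's weights if its max logit exceeds tau."""
    if s_max <= tau:
        return head                              # below threshold: leave weights alone
    gamma = tau / s_max
    return {
        'W_qc': head['W_qc'] * np.sqrt(gamma),   # non-RoPE query: sqrt(gamma)
        'W_kc': head['W_kc'] * np.sqrt(gamma),   # non-RoPE key: sqrt(gamma)
        'W_qr': head['W_qr'] * gamma,            # RoPE query takes the full gamma
    }                                            # shared W_kr is not modified

rng = np.random.default_rng(3)
d_m, d_c, d_k, d_r = 64, 32, 16, 8               # hypothetical small sizes
X = rng.standard_normal((10, d_m))               # a batch of 10 tokens
W_c = rng.standard_normal((d_m, d_c))
W_kr = rng.standard_normal((d_m, d_r))
head = {                                         # large weights stand in for runaway growth
    'W_qc': 3.0 * rng.standard_normal((d_m, d_k)),
    'W_kc': 3.0 * rng.standard_normal((d_c, d_k)),
    'W_qr': 3.0 * rng.standard_normal((d_m, d_r)),
}

tau = 100.0
s_max = max_logit(X, head, W_c, W_kr)
clipped = qk_clip_head(head, s_max, tau)
print(s_max, max_logit(X, clipped, W_c, W_kr))   # second value lands on tau
```

Because every logit of the head shrinks by the same positive factor, the ranking of tokens is unchanged; only the magnitude is capped.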

This "absorbs" the fix into the weights permanently, so inference stays stable without structural changes. It's smart because it works per head, avoids over-clipping, and is compatible with MLA's latent tricks, helping Kimi K2 train stably at the scale of a trillion total parameters. For you as a dev, this means more reliable APIs: fewer crashes in long sessions and better outputs without custom hacks.

We started with a curious comment on a Kimi blog and ended with a visualization demystifying MLA's training-inference gap. This isn't just academic—understanding these internals helps you choose models wisely (e.g., Kimi for memory-efficient chats), and appreciate the engineering smarts behind them. Kimi's QK-Clip shows how targeted fixes can push LLMs further, making your GenAI apps more robust.

I have attached the code for the animation at the end. Watch the video—it's enlightening! Questions? Drop a comment. Happy building! 🚀

import networkx as nx
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import matplotlib.patches as patches

# --- GRAPH CREATION ---
# Bold symbols use \mathbf, which every Matplotlib mathtext version supports.
def create_graph(mode='training'):
    """Creates the graph for Training or Decoding."""
    G = nx.DiGraph()
    nodes = [
        (r'$x_i$' + '\n' + r'$\in \mathbb{R}^{d_m}$', 'x_i'),
        (r'$\mathbf{W}_c$' + '\n' + r'$\in \mathbb{R}^{d_m \times d_c}$', 'W_c'),
        (r'$\mathbf{c}_i$' + '\n' + r'$\in \mathbb{R}^{d_c}$', 'c_i'),
        (r'$\mathbf{W}_{kr}$' + '\n' + r'$\in \mathbb{R}^{d_m \times d_r}$', 'W_kr'),
        (r'RoPE' + '\n' + r'$\mathcal{R}_i$', 'R_i'),
        (r'$x_i \mathbf{W}_{kr} \mathcal{R}_i$' + '\n' + r'$\in \mathbb{R}^{d_r}$', 'x_W_kr_R'),
        (r'$\mathbf{W}_{kc}^{(s)}$' + '\n' + r'$\in \mathbb{R}^{d_c \times d_k}$' if mode == 'training'
         else r'$\mathbf{W}_{kc}^{(s)}$' + '\n' + '(Missing Projection)', 'W_kc^s'),
        (r'$\mathbf{c}_i \mathbf{W}_{kc}^{(s)}$' + '\n' + r'$\in \mathbb{R}^{d_k}$' if mode == 'training'
         else 'Direct Path' + '\n' + '(No Projection)', 'c_W_kc'),
        (r'Key: $\mathbf{k}_i^{(s)}$' + '\n' + r'$\in \mathbb{R}^{d_k + d_r}$' if mode == 'training'
         else r'Incomplete Key: $\mathbf{k}_i$' + '\n' + r'$\in \mathbb{R}^{d_c + d_r}$', 'k_i'),
        ('RMSNorm', 'RMSNorm'),
        ('Stable Attention' + '\n' + 'Output', 'Output'),
    ]
    for label, node_id in nodes:
        G.add_node(node_id, label=label)

    edges = [
        ('x_i', 'c_i', '+'), ('W_c', 'c_i', ''),
        ('x_i', 'x_W_kr_R', '+'), ('W_kr', 'x_W_kr_R', ''),
        ('c_i', 'c_W_kc', ''),
        ('W_kc^s', 'c_W_kc', ''),
        ('c_W_kc', 'k_i', 'Concat'),
        ('x_W_kr_R', 'k_i', 'Concat'),
        ('k_i', 'RMSNorm', ''),
        ('RMSNorm', 'Output', ''),
    ]
    for u, v, label in edges:
        G.add_edge(u, v, label=label)
    G.add_edge('R_i', 'x_W_kr_R', label=r'$\circlearrowleft \mathcal{R}_i$')

    # Light green = works; salmon / light coral = broken or mismatched in decoding
    color_map = {
        'x_i': 'skyblue', 'W_c': 'lightgray', 'c_i': 'skyblue',
        'W_kr': 'lightgray', 'R_i': 'lightgray', 'x_W_kr_R': 'skyblue',
        'W_kc^s': 'lightgreen' if mode == 'training' else 'salmon',
        'c_W_kc': 'lightgreen' if mode == 'training' else 'salmon',
        'k_i': 'skyblue' if mode == 'training' else 'lightcoral',  # Mismatch color
        'RMSNorm': 'lightgreen' if mode == 'training' else 'salmon',
        'Output': 'lightgreen' if mode == 'training' else 'salmon',
    }

    # Tag the vertically stacked components of each branch
    # (metadata only; the actual layout comes from `pos` below)
    # Left Stack
    G.nodes['W_c']['stack'] = 'left_stack'
    G.nodes['c_i']['stack'] = 'left_stack'
    G.nodes['W_kc^s']['stack'] = 'left_stack'
    G.nodes['c_W_kc']['stack'] = 'left_stack'
    # Right Stack
    G.nodes['W_kr']['stack'] = 'right_stack'
    G.nodes['x_W_kr_R']['stack'] = 'right_stack'
    return G, color_map

# --- LAYOUT ---
# Spaced out for clarity, especially the left and right branches
pos = {
    'x_i': (0, 10),
    # Left Branch (Compressed Path), stacked vertically
    'W_c': (-6, 8.5), 'c_i': (-6, 6.5), 'W_kc^s': (-6, 4.5), 'c_W_kc': (-6, 2.5),
    # Right Branch (Rotary Path), stacked vertically
    'W_kr': (6, 8.5), 'x_W_kr_R': (6, 6.5), 'R_i': (9.5, 6.5),
    # Merge and Final Output
    'k_i': (0, 0), 'RMSNorm': (0, -3), 'Output': (0, -5.5),
}

# --- DRAWING FUNCTION ---
def draw_graph(ax, G, color_map, title, annotation):
    ax.clear()
    ax.set_title(title, fontsize=16, weight='bold', y=1.05)
    ax.axis('off')
    ax.set_xlim(-12, 12)
    ax.set_ylim(-8, 11)

    # --- Draw Edges First (so nodes are drawn on top) ---
    # Draw regular edges with a slight curve
    regular_edges = [e for e in G.edges() if e not in [('R_i', 'x_W_kr_R')]]
    nx.draw_networkx_edges(G, pos, edgelist=regular_edges, ax=ax,
                           arrows=True, arrowstyle='-|>,head_width=0.6,head_length=1.0',
                           arrowsize=20, width=2.0, edge_color='gray',
                           connectionstyle='arc3,rad=0.1')
    # Draw the RoPE "edge" as a curved arrow coming in from the side
    rad = 0.8
    arrow = patches.FancyArrowPatch(pos['R_i'], pos['x_W_kr_R'],
                                    connectionstyle=f"arc3,rad={rad}",
                                    color="gray", arrowstyle='-|>,head_width=0.6,head_length=1.0',
                                    linewidth=2)
    ax.add_patch(arrow)

    # --- Draw Edge Labels ---
    edge_labels = nx.get_edge_attributes(G, 'label')
    # Use a semi-transparent bbox for a clean look
    bbox_props = dict(boxstyle="round,pad=0.2", fc="white", ec="none", alpha=0.8)
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=10,
                                 label_pos=0.5, ax=ax, bbox=bbox_props,
                                 font_color='black')

    # --- Draw Nodes with Auto-fitting Bounding Boxes ---
    node_labels = nx.get_node_attributes(G, 'label')
    for node, (x, y) in pos.items():
        label = node_labels[node]
        color = color_map[node]
        # Draw the text with a bounding box that fits it
        ax.text(x, y, label, ha='center', va='center', fontsize=10,
                bbox=dict(boxstyle='square,pad=0.7', fc=color, ec='black', lw=1.5))

    # --- Draw Annotation Below Graph ---
    ax.text(0, -7.5, annotation, ha='center', va='center', fontsize=11,
            bbox=dict(facecolor='whitesmoke', edgecolor='gray', boxstyle='round,pad=0.5'))

# --- FIGURE AND ANIMATION SETUP ---
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(22, 11))
fig.suptitle('MLA vs. Standard Attention: Key Formation During Training and Inference', fontsize=20, weight='bold')

# Create graphs
G_train, cmap_train = create_graph('training')
G_dec, cmap_dec = create_graph('decoding')

# Define animation paths
path_edges = [
    ('x_i', 'c_i'), ('x_i', 'x_W_kr_R'),
    ('c_i', 'c_W_kc'), ('x_W_kr_R', 'k_i'),
    ('c_W_kc', 'k_i'), ('k_i', 'RMSNorm'),
    ('RMSNorm', 'Output'),
]

# Animation function to highlight paths
def animate(frame):
    # Draw base graphs
    draw_graph(ax1, G_train, cmap_train, 'Training / Prefill Phase',
               'Full projection materializes the full key.\nRMSNorm is applied to a complete, stable vector.')
    draw_graph(ax2, G_dec, cmap_dec, 'Decoding / Inference Phase',
               'The head-specific projection matrix is missing.\nRMSNorm sees an unstable, incomplete key structure.')

    # Highlight edges sequentially
    highlight_edges_train = path_edges[:frame + 1]
    nx.draw_networkx_edges(G_train, pos, edgelist=highlight_edges_train, ax=ax1,
                           edge_color='orange', width=3.5,
                           arrows=True, arrowstyle='-|>,head_width=0.6,head_length=1.0',
                           connectionstyle='arc3,rad=0.1')
    highlight_edges_dec = path_edges[:frame + 1]
    nx.draw_networkx_edges(G_dec, pos, edgelist=highlight_edges_dec, ax=ax2,
                           edge_color='orange', width=3.5,
                           arrows=True, arrowstyle='-|>,head_width=0.6,head_length=1.0',
                           connectionstyle='arc3,rad=0.1')

    # Manually highlight the RoPE path when its turn comes
    if frame >= 1:  # Highlight after the first step
        rad = 0.8
        arrow_train = patches.FancyArrowPatch(pos['R_i'], pos['x_W_kr_R'],
                                              connectionstyle=f"arc3,rad={rad}",
                                              color="orange", arrowstyle='-|>,head_width=0.6,head_length=1.0',
                                              linewidth=3.5, zorder=10)
        ax1.add_patch(arrow_train)
        arrow_dec = patches.FancyArrowPatch(pos['R_i'], pos['x_W_kr_R'],
                                            connectionstyle=f"arc3,rad={rad}",
                                            color="orange", arrowstyle='-|>,head_width=0.6,head_length=1.0',
                                            linewidth=3.5, zorder=10)
        ax2.add_patch(arrow_dec)

# Run and save the animation, or just show the final static frame
num_frames = len(path_edges)

# To see the final static image:
animate(num_frames - 1)
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.savefig("mla_visualization_improved.png", dpi=150, bbox_inches='tight')
plt.show()

# To save the animation (requires ffmpeg), uncomment the lines below:
# ani = FuncAnimation(fig, animate, frames=num_frames, interval=1200, repeat=False)
# ani.save('mla_visualization_improved.mp4', writer='ffmpeg', dpi=150)