11 Chapter 10: Attention Mechanisms & Modern Transformer Architectures
12 Introduction: The Attention Story
This document provides a comprehensive technical guide to attention mechanisms and modern transformer architectures, designed for ML/AI interview preparation and practical understanding.
Coverage:
Core attention mechanisms (Q/K/V, multi-head, RoPE, GQA)
Efficiency techniques (sparse attention, linear attention, Flash Attention)
Alternative architectures (SSMs/Mamba, RWKV, Delta attention)
Scaling methods (MoE, sparse architectures, low-rank compression)
Vision transformers (ViT, DeiT, Swin, CLIP)
Production systems (DeepSeek, Qwen, LLaMA, Kimi)
The Evolution:
Full Attention (2017): Vaswani’s transformer with \(O(n^2)\) complexity–breakthrough but limited by quadratic scaling
Complexity Crisis: Quadratic cost becomes bottleneck for long contexts (64K+ tokens)
Three Solution Paths:
Sparse/Structured Attention: Limit connectivity (sliding windows, grouped representatives)
Linear/Kernelized Attention: Factorize softmax kernel for \(O(n)\) complexity
State-Space Models: Replace attention with recurrent dynamics (Mamba, RWKV)
Capacity Scaling: Mixture-of-Experts (MoE) for massive models with sparse activation
Modern Hybrids (2024+): Combine multiple techniques, e.g., DeepSeek (MoE + sparse attention + MLA), Kimi (Delta attention), Qwen (GQA + dual-chunk)
Unifying Theme: All mechanisms can be viewed as content-based selection with relaxed gates–whether softmax attention over tokens, kernelized approximations, recurrent state updates, or expert routing.
How to Use This Document:
13 Core Attention Mechanism
13.1 Content-Based Lookup: The Q/K/V Paradigm
Given a query \(q \in \mathbb{R}^{d_k}\) and memory with keys \(k_j \in \mathbb{R}^{d_k}\) and values \(v_j \in \mathbb{R}^{d_v}\), attention computes a weighted retrieval: \[\begin{align} s_j & = q^\top k_j \quad \text{(similarity scores)} \\ \alpha_j & = \mathrm{softmax}\left(\frac{s_j}{\sqrt{d_k}}\right) = \frac{\exp(s_j / \sqrt{d_k})}{\sum_l \exp(s_l / \sqrt{d_k})} \quad \text{(softmax weights)} \\ y & = \sum_j \alpha_j v_j \quad \text{(weighted sum of values)} \end{align}\]
Why Q/K/V Factorization?
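Asymmetric roles: Query asks what to look for, key advertises what each position offers (the "address" that queries are matched against), value stores the content that is returned
Expressiveness: Learn a bilinear similarity \(x_i^\top W_Q^\top W_K x_j\) between inputs (a learned, low-rank matrix \(W_Q^\top W_K\)) instead of a fixed dot product
Multi-head decomposition: Different subspaces capture different aspects
Caching: In autoregressive generation, each past token's keys and values are computed once and reused by every later query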
13.2 Softmax as Relaxed Selector
Softmax approximates a hard argmax. With temperature \(T\): \[\alpha_j(T) = \frac{\exp(s_j / T)}{\sum_l \exp(s_l / T)}\]
\(T \to 0\): Concentrates on \(\operatorname*{arg\,max}_j s_j\) (hard selection)
\(T \to \infty\): Uniform distribution (no selection)
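Standard \(T=\sqrt{d_k}\): If the components of \(q\) and \(k\) are independent with zero mean and unit variance, \(q^\top k\) has variance \(d_k\); dividing by \(\sqrt{d_k}\) brings the logits back to unit variance, so the softmax neither saturates (vanishing gradients) nor flattens out (no selectivity)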
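A small NumPy sketch of both temperature limits and of the variance argument for \(T = \sqrt{d_k}\). The scores, the two temperatures, and the sample count are arbitrary illustrative choices.

```python
import numpy as np

def softmax_with_temperature(scores, T):
    """Softmax of scores / T, computed stably by subtracting the largest logit."""
    z = np.asarray(scores, dtype=float) / T
    z = z - z.max()              # shift-invariant: does not change the probabilities
    e = np.exp(z)
    return e / e.sum()

s = np.array([2.0, 1.0, 0.5])
sharp = softmax_with_temperature(s, 0.01)   # T -> 0: almost all mass on the argmax (index 0)
flat = softmax_with_temperature(s, 1e4)     # T -> infinity: almost uniform (about 1/3 each)

# Why T = sqrt(d_k): with independent unit-variance components, q^T k has variance d_k
rng = np.random.default_rng(0)
d_k = 64
q = rng.standard_normal((10_000, d_k))
k = rng.standard_normal((10_000, d_k))
dots = (q * k).sum(axis=1)                  # 10,000 sampled dot products
print(sharp.round(3), flat.round(3))
print(dots.var(), (dots / np.sqrt(d_k)).var())   # roughly 64, then roughly 1
```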
Key Insight: Softmax acts as a differentiable gate for selecting relevant information. This principle extends to:
MoE gating: softmax over expert slots
Linear attention: kernel approximation of softmax
Delta attention: gated recurrent updates
13.3 Scaled Dot-Product Attention
For sequence length \(n\), stack queries/keys/values into matrices: \[\begin{align} Q & = \begin{bmatrix} q_1^\top \\ \vdots \\ q_n^\top \end{bmatrix} \in \mathbb{R}^{n \times d_k}, \quad K = \begin{bmatrix} k_1^\top \\ \vdots \\ k_n^\top \end{bmatrix} \in \mathbb{R}^{n \times d_k}, \quad V = \begin{bmatrix} v_1^\top \\ \vdots \\ v_n^\top \end{bmatrix} \in \mathbb{R}^{n \times d_v} \end{align}\]
where queries, keys, and values are computed via learned linear projections from input embeddings \(X \in \mathbb{R}^{n \times d_{\text{model}}}\): \[\begin{align} Q & = XW_Q, \quad W_Q \in \mathbb{R}^{d_{\text{model}} \times d_k} \\ K & = XW_K, \quad W_K \in \mathbb{R}^{d_{\text{model}} \times d_k} \\ V & = XW_V, \quad W_V \in \mathbb{R}^{d_{\text{model}} \times d_v} \end{align}\]
\[\begin{equation} \boxed{\text{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V} \end{equation}\]
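The boxed formula translates almost line for line into NumPy. This is a minimal single-head sketch, assuming row-vector inputs as in \(Q = XW_Q\) and hypothetical random weights; it returns both the output and the attention weights so each row can be checked to sum to 1.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(Q, K, V, mask=None):
    """Q: (n, d_k), K: (m, d_k), V: (m, d_v) -> output (n, d_v), weights (n, m).
    mask: optional additive mask of shape (n, m), 0 = keep, -inf = block."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)            # (n, m) scaled similarity scores
    if mask is not None:
        scores = scores + mask
    A = softmax(scores, axis=-1)               # each row is a distribution over keys
    return A @ V, A

rng = np.random.default_rng(0)
n, d_model, d_k, d_v = 5, 16, 8, 8
X = rng.standard_normal((n, d_model))
W_Q, W_K = rng.standard_normal((2, d_model, d_k))   # hypothetical learned weights
W_V = rng.standard_normal((d_model, d_v))
Y, A = scaled_dot_product_attention(X @ W_Q, X @ W_K, X @ W_V)
print(Y.shape, A.sum(axis=1))
```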
Note: The full formula with dropout (applied during training) is: \[\text{Attention}(Q, K, V) = \text{Dropout}\left(\mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)\right) V\] Dropout is applied to attention weights (post-softmax) with typical rate \(p \in [0.1, 0.2]\).
What this means: After computing softmax probabilities \(\alpha_{ij}\), dropout randomly zeros out some elements:
Each \(\alpha_{ij}\) is set to 0 with probability \(p\) (typically 10-20%)
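Remaining values are scaled by \(1/(1-p)\) so each weight keeps its expected value (each row still sums to 1 in expectation)
This prevents over-reliance on specific attention patterns during training
We omit dropout in most formulas for clarity. It is common in training setups, although many recent LLMs (e.g., LLaMA) set the attention-dropout rate to 0.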
Typical Dimensions:
\(d_k = 64\) is standard (original Transformer: \(d_{\text{model}} = 512\), 8 heads \(\rightarrow\) \(d_k = 512/8 = 64\))
Modern LLMs often use \(d_k \in \{64, 80, 128\}\) (e.g., LLaMA: \(d_k = 128\))
Scaling factor: \(\sqrt{64} = 8\), \(\sqrt{128} \approx 11.3\)
Complexity:
Time: \(O(n^2 d_k)\) for computing \(QK^\top\), \(O(n^2 d_v)\) for weighted sum
Memory: \(O(n^2)\) to store attention weights
13.4 Q/K/V Attention Flow
Attention Computation Flow:
Input: Token embeddings \(x_1, x_2, \ldots, x_n \in \mathbb{R}^{d_{\text{model}}}\)
Project to Q/K/V:
Queries: \(q_i = W_Q x_i \in \mathbb{R}^{d_k}\) (what to look for)
Keys: \(k_j = W_K x_j \in \mathbb{R}^{d_k}\) (what is available)
Values: \(v_j = W_V x_j \in \mathbb{R}^{d_v}\) (content to retrieve)
Compute Scores: \(s_{ij} = q_i^\top k_j / \sqrt{d_k}\) for all \(j\)
Softmax: \(\alpha_{ij} = \frac{\exp(s_{ij})}{\sum_l \exp(s_{il})}\) (attention weights)
Dropout (training only): Randomly zero out attention weights:
Each \(\alpha_{ij}\) set to 0 with probability \(p \approx 0.1\)
Scale remaining by \(1/(1-p)\): ensures expected value unchanged
Result: \(\alpha'_{ij}\) with some connections dropped
Weighted Sum: \(y_i = \sum_j \alpha'_{ij} v_j\) (output for position \(i\))
Dropout Example: If softmax produces \(\alpha_i = [0.5, 0.3, 0.2]\) (attend to 3 tokens), dropout with \(p=0.1\) might zero the second element → \([0.5, 0, 0.2]\), then rescale → \([0.556, 0, 0.222]\) (scaled by \(1/0.9\)). Now position \(i\) only attends to tokens 1 and 3, forcing the model to learn robust attention patterns.
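The worked example can be reproduced with inverted dropout, where the keep mask is fixed to match the example's outcome. A second part, assuming random Bernoulli keep masks with \(p = 0.1\), checks that each weight is unchanged on average; the sample count is arbitrary.

```python
import numpy as np

def attention_dropout(alpha, keep_mask, p):
    """Inverted dropout on attention weights: zero dropped entries, rescale survivors by 1/(1-p)."""
    return alpha * keep_mask / (1.0 - p)

alpha = np.array([0.5, 0.3, 0.2])
keep = np.array([1.0, 0.0, 1.0])        # the example's outcome: the second weight is dropped
out = attention_dropout(alpha, keep, p=0.1)
print(out.round(3))                     # approx. [0.556, 0, 0.222]

# In expectation each weight is unchanged, because P(keep) = 1 - p
rng = np.random.default_rng(0)
p = 0.1
masks = (rng.random((200_000, 3)) >= p).astype(float)
print(attention_dropout(alpha, masks, p).mean(axis=0).round(2))
```

Note that a single dropped-out row no longer sums to 1 (here \(0.556 + 0.222 = 0.778\)); only its expectation does.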
Example: For token at position \(i\) in sentence “The cat sat on the mat”:
\(q_i\) (query) asks: “What is relevant context for this position?”
Compare with all \(k_j\) (keys): \(k_1\) (The), \(k_2\) (cat), \(k_3\) (sat), etc.
High similarity \(q_i^\top k_j\) means position \(j\) is relevant
Softmax normalizes: \(\alpha_{i2}\) might be large if “cat” is most relevant
Output \(y_i\) is weighted combination: mostly \(v_2\) (cat’s value) + some context from others
14 The Transformer Architecture (Vaswani et al., 2017)
14.1 Original Transformer: Encoder-Decoder
The original transformer combined:
Multi-head self-attention: \(h\) parallel attention heads with different \(W_Q^{(h)}, W_K^{(h)}, W_V^{(h)}\)
Position-wise FFN: \(\text{FFN}(x) = W_2 \sigma(W_1 x + b_1) + b_2\) applied independently to each position
Residual connections + Post-LayerNorm: \(\text{LN}(x + \text{Sublayer}(x))\) (post-norm)
Positional encodings: Sinusoidal functions \(\text{PE}_{pos, 2i} = \sin(pos / 10000^{2i/d})\) and \(\text{PE}_{pos, 2i+1} = \cos(pos / 10000^{2i/d})\), added to the input embeddings
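A short NumPy sketch of the sinusoidal table, with sine on even dimensions and cosine on odd dimensions; the 50 positions and \(d = 16\) are arbitrary.

```python
import numpy as np

def sinusoidal_positional_encoding(n_positions, d_model):
    """PE[pos, 2i] = sin(pos / 10000^(2i/d)), PE[pos, 2i+1] = cos(pos / 10000^(2i/d))."""
    assert d_model % 2 == 0
    pos = np.arange(n_positions)[:, None]          # (n, 1)
    two_i = np.arange(0, d_model, 2)[None, :]       # (1, d/2): the values 2i
    angles = pos / (10000.0 ** (two_i / d_model))   # (n, d/2)
    pe = np.zeros((n_positions, d_model))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe

pe = sinusoidal_positional_encoding(50, 16)
print(pe.shape)   # added element-wise to the (50, 16) token embeddings
```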
Post-Norm vs. Pre-Norm:
Original Transformer (Vaswani et al., 2017) – Post-Norm: \[x_{\text{out}} = \text{LayerNorm}(x + \text{Sublayer}(x))\] Normalization applied after residual addition.
Modern Transformers (GPT-2 onwards) – Pre-Norm: \[x_{\text{out}} = x + \text{Sublayer}(\text{LayerNorm}(x))\] Normalization applied before sublayer, residual wraps the entire block.
Why the change?
Post-norm: Stable for shallow models (6-12 layers), but gradient flow degrades in very deep networks
Pre-norm: The residual path bypasses normalization, giving better gradient flow; this makes very deep (50+ layer) stacks easier to train and reduces (sometimes removes) the need for learning-rate warmup
Timeline:
2017: Vaswani (original) – post-norm
2019: GPT-2 – switched to pre-norm
2020+: LLaMA, Qwen, DeepSeek – all use pre-norm (often with RMSNorm instead of LayerNorm)
14.2 Multi-Head Attention
\[\begin{equation} \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O \end{equation}\] where each head computes: \[\text{head}_i = \text{Attention}(Q W_Q^{(i)}, K W_K^{(i)}, V W_V^{(i)})\]
Intuition: Different heads attend to different aspects (syntactic vs. semantic relationships, local vs. global context).
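A NumPy sketch of multi-head self-attention. It assumes \(d_k = d_{\text{model}}/h\) and uses one \((d_{\text{model}} \times d_{\text{model}})\) matrix per projection, reshaped into \(h\) heads, which is equivalent to \(h\) separate \((d_{\text{model}} \times d_k)\) projections; all weights are hypothetical random values.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, W_Q, W_K, W_V, W_O, h):
    """Self-attention with h heads; each W is (d_model, d_model) and d_k = d_model // h."""
    n, d_model = X.shape
    d_k = d_model // h
    def split(M):                                  # (n, d_model) -> (h, n, d_k)
        return M.reshape(n, h, d_k).transpose(1, 0, 2)
    Q, K, V = split(X @ W_Q), split(X @ W_K), split(X @ W_V)
    A = softmax(Q @ K.transpose(0, 2, 1) / np.sqrt(d_k))    # (h, n, n): one pattern per head
    heads = A @ V                                           # (h, n, d_k)
    concat = heads.transpose(1, 0, 2).reshape(n, d_model)   # [head_1 || ... || head_h]
    return concat @ W_O                                     # W^O mixes information across heads

rng = np.random.default_rng(0)
n, d_model, h = 6, 32, 4
X = rng.standard_normal((n, d_model))
W_Q, W_K, W_V, W_O = rng.standard_normal((4, d_model, d_model)) / np.sqrt(d_model)
Z = multi_head_attention(X, W_Q, W_K, W_V, W_O, h)
print(Z.shape)
```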
14.3 Self-Attention vs Cross-Attention
Self-attention: Q, K, V all from same sequence: \(\text{attn}(X, X, X)\)
Cross-attention: Q from one sequence, K/V from another: \(\text{attn}(\text{decoder}, \text{encoder}, \text{encoder})\)
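PyTorch sketch (simplified BART-style post-norm decoder layer; assumes self_attn and cross_attn are nn.MultiheadAttention modules built with batch_first=True, and norm1/norm2/norm3 are nn.LayerNorm):
# nn.MultiheadAttention returns (attn_output, attn_weights); [0] keeps the output
# 1. Self-attention: decoder attends to past decoder tokens (causal mask), then residual + norm
x = self.norm1(x + self.self_attn(query=x, key=x, value=x, attn_mask=causal_mask)[0])
# 2. Cross-attention: decoder queries attend to encoder outputs (keys/values), then residual + norm
x = self.norm2(x + self.cross_attn(query=x, key=enc_out, value=enc_out)[0])
# 3. Position-wise feed-forward, then residual + norm
x = self.norm3(x + self.ffn(x))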
Key difference: Self-attn uses \((x,x,x)\); cross-attn uses \((decoder, encoder, encoder)\). Cross-attention is how encoder-decoder models (BART, T5) pass source information to the decoder. Decoder-only models (GPT) skip cross-attention entirely.
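Head Specialization (illustrative; which head learns what varies by model and layer):
One head may track subject-verb agreement (syntactic)
Another may link pronouns to their antecedents, i.e., coreference (semantic)
Another may attend mostly to the previous or adjacent tokens (local patterns)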
14.4 Complete Transformer Layer: Data Flow
Transformer Layer Architecture (one complete block):
Input: Sequence \(X \in \mathbb{R}^{n \times d_{\text{model}}}\) (n tokens, each \(d_{\text{model}}\) dimensional)
Block 1: Multi-Head Self-Attention
Split into h heads: \(d_k = d_v = d_{\text{model}} / h\) per head
Per-head attention (for \(i = 1, \ldots, h\)):
Project: \(Q^{(i)} = X W_Q^{(i)}\), \(K^{(i)} = X W_K^{(i)}\), \(V^{(i)} = X W_V^{(i)}\)
Each \(W_Q^{(i)}, W_K^{(i)}, W_V^{(i)} \in \mathbb{R}^{d_{\text{model}} \times d_k}\)
Compute: \(\text{head}_i = \text{Attention}(Q^{(i)}, K^{(i)}, V^{(i)}) \in \mathbb{R}^{n \times d_k}\)
Concatenate heads: \([\text{head}_1 \| \text{head}_2 \| \cdots \| \text{head}_h] \in \mathbb{R}^{n \times d_{\text{model}}}\)
Output projection: \(Z = \text{Concat}(\text{heads}) W^O\) where \(W^O \in \mathbb{R}^{d_{\text{model}} \times d_{\text{model}}}\)
Residual + Norm:
Post-norm: \(X' = \text{LayerNorm}(X + \text{Sublayer}(X))\) (original Transformer)
Pre-norm: \(X' = X + \text{Sublayer}(\text{LayerNorm}(X))\) (modern default, easier to train)
Block 2: Position-Wise Feed-Forward Network
Applied to each token independently (same weights for all positions):
For each row \(x_i \in \mathbb{R}^{d_{\text{model}}}\) of \(X'\):
\(\text{FFN}(x_i) = \sigma(x_i W_1 + b_1) W_2 + b_2\) (row-vector form, matching the shapes below)
\(W_1 \in \mathbb{R}^{d_{\text{model}} \times d_{\text{ff}}}\) (typically \(d_{\text{ff}} = 4 \cdot d_{\text{model}}\))
\(W_2 \in \mathbb{R}^{d_{\text{ff}} \times d_{\text{model}}}\)
\(\sigma\) is activation (ReLU originally, SwiGLU in modern LLMs)
Residual + Norm: \(X_{\text{out}} = \text{LayerNorm}(X' + \text{FFN}(X'))\) (post-norm; the pre-norm form is \(X_{\text{out}} = X' + \text{FFN}(\text{LayerNorm}(X'))\))
Key Points:
“Position-wise” means applied independently to each token (no cross-token interaction in FFN)
\(W^O\) is the output projection that mixes information across heads after concatenation
FFN acts like a 1×1 convolution: same weights applied to every token position
Parameter Count per Transformer Layer:
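\(\approx 12d_{\text{model}}^2\) including both attention and FFN (assuming \(h \cdot d_k = d_{\text{model}}\) and ignoring biases and normalization parameters). Breakdown: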
Multi-Head Attention:
\(W_Q, W_K, W_V\): \(3 \times d_{\text{model}} \times d_{\text{model}} = 3d_{\text{model}}^2\)
\(W^O\) (output projection): \(d_{\text{model}} \times d_{\text{model}} = d_{\text{model}}^2\)
Attention subtotal: \(4d_{\text{model}}^2\)
Feed-Forward Network:
\(W_1\): \(d_{\text{model}} \times d_{\text{ff}} = d_{\text{model}} \cdot d_{\text{ff}}\)
\(W_2\): \(d_{\text{ff}} \times d_{\text{model}} = d_{\text{ff}} \cdot d_{\text{model}}\)
Biases: \(d_{\text{ff}} + d_{\text{model}}\) (negligible)
FFN subtotal: \(2 \cdot d_{\text{model}} \cdot d_{\text{ff}}\)
Total per layer: \(4d_{\text{model}}^2 + 2 d_{\text{model}} d_{\text{ff}}\)
When \(d_{\text{ff}} = 4 d_{\text{model}}\) (common default): \[4d_{\text{model}}^2 + 2 d_{\text{model}} \cdot 4d_{\text{model}} = 4d_{\text{model}}^2 + 8d_{\text{model}}^2 = 12d_{\text{model}}^2\]
Key Insight: The FFN holds about two thirds of the layer's parameters (\(8d_{\text{model}}^2\) vs. \(4d_{\text{model}}^2\)) when \(d_{\text{ff}} = 4d_{\text{model}}\), even though attention does all the cross-token mixing.
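The count above can be checked with a few lines of arithmetic. This sketch assumes \(h \cdot d_k = d_{\text{model}}\) and ignores normalization parameters; bias terms are optional. \(d_{\text{model}} = 768\) is used only as a familiar example size.

```python
def transformer_layer_params(d_model, d_ff, include_biases=False):
    """Weights in one standard layer: 4 d^2 for W_Q, W_K, W_V, W^O plus 2 d d_ff for the FFN.
    Norm parameters (a few d_model-sized vectors) are ignored, as in the text."""
    attn = 4 * d_model * d_model
    ffn = 2 * d_model * d_ff
    if include_biases:
        attn += 4 * d_model            # biases of W_Q, W_K, W_V, W^O
        ffn += d_ff + d_model          # biases b_1, b_2
    return attn + ffn

d = 768                                # e.g. a BERT-base / GPT-2-small sized layer
print(transformer_layer_params(d, 4 * d))   # 12 * 768**2 = 7,077,888
```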
14.5 Causal Masking for Autoregressive Generation
For language modeling, prevent each query position \(i\) from attending to future key positions \(j\): \[\text{mask}_{ij} = \begin{cases} 0 & \text{if } j \leq i \\ -\infty & \text{if } j > i \end{cases}\] Applied before softmax: \(\mathrm{softmax}(QK^\top / \sqrt{d_k} + \text{mask})\)
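A NumPy sketch of the additive causal mask: entries strictly above the diagonal (\(j > i\)) get \(-\infty\), so after the softmax every row puts zero weight on future tokens. The random \(4 \times 4\) scores are placeholders.

```python
import numpy as np

def causal_mask(n):
    """Additive mask: 0 where key j <= query i (allowed), -inf where j > i (future)."""
    future = np.triu(np.ones((n, n), dtype=bool), k=1)   # True strictly above the diagonal
    return np.where(future, -np.inf, 0.0)

def masked_softmax(scores, mask):
    z = scores + mask
    z = z - z.max(axis=-1, keepdims=True)   # every row keeps its diagonal, so the max is finite
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
A = masked_softmax(rng.standard_normal((4, 4)), causal_mask(4))
print(np.round(A, 2))   # lower-triangular; row 0 is [1, 0, 0, 0]
```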
14.6 Complexity Analysis
Per layer with sequence length \(n\) and model dimension \(d\):
| Component | Time | Memory |
|---|---|---|
| Self-attention | \(O(n^2 d)\) | \(O(n^2)\) |
| Feed-forward | \(O(n d^2)\) | \(O(n d)\) |
| Total (L layers) | \(O(L(n^2 d + n d^2))\) | \(O(n^2 + n d)\) |
Bottleneck: For large \(n\) (long sequences), the \(O(n^2)\) attention dominates. For large \(d\) (wide models), FFN dominates. Modern models target both: efficient attention for long context, sparse/MoE for capacity.
Where does the \(O(n^2)\) cross-token mixing happen?
A common confusion: the projections (\(Q = XW_Q\), \(K = XW_K\), \(V = XW_V\)) and the FFN are tokenwise operations (each costs \(O(nd^2)\) and never mixes tokens), so where is the quadratic complexity?
Answer: The \(O(n^2)\) interaction happens in the attention computation itself:
Compute scores: \(S = QK^\top \in \mathbb{R}^{n \times n}\)
Entry \(S_{ij}\) = how much token \(i\) attends to token \(j\)
Creates an \(n \times n\) matrix: every token interacts with every other token
Cost: \(O(n^2 d_k)\) for matrix multiplication
Softmax: \(A = \text{softmax}(S / \sqrt{d_k})\) (rowwise)
Each row sums to 1: attention weights over all \(n\) tokens
Cost: \(O(n^2)\) to process all entries
Weighted sum: \(\text{Output} = AV \in \mathbb{R}^{n \times d_v}\)
Each output token is a mixture of all tokens’ values
Row \(i\) of output = \(\sum_{j=1}^n A_{ij} V_j\) (weighted sum over sequence)
Cost: \(O(n^2 d_v)\) for matrix multiplication
Key insight: The \(n \times n\) attention matrix \(A\) is where tokens communicate. Without this, each token would only see itself (no context). The FFN operates on already-mixed representations (post-attention), so it can be tokenwise.
Memory: Storing \(A \in \mathbb{R}^{n \times n}\) requires \(O(n^2)\) memory. For \(n=4096\), that's \(4096^2 \approx 16.8\)M floats (32 MiB in fp16) per head, per layer. This is why long-context models need efficient attention (Flash Attention, sparse patterns, etc.).
15 Evolution: From Transformers to Modern Variants
15.1 GPT, BERT, and the Decoder-Only Paradigm
GPT (Radford et al., 2018): Decoder-only transformer with causal masking, trained as language model.
BERT (Devlin et al., 2019): Encoder-only with bidirectional attention, trained with masked language modeling (MLM).
Modern LLMs (GPT-3, LLaMA, Qwen, etc.): Decoder-only architecture dominates due to:
Unified pre-training and inference (autoregressive)
Efficient KV caching in generation
Scaling laws favor decoder-only for next-token prediction
15.2 Architectural Improvements in Modern Transformers
LLaMA (Touvron et al., 2023) and descendants (Qwen, Mistral, DeepSeek):
RoPE (Rotary Position Embeddings): Apply rotation matrices to Q/K instead of additive positional encoding \[\begin{equation} q_m = R_m q, \quad k_n = R_n k, \quad \text{where } R_m^{(i)} = \begin{bmatrix} \cos(m\theta_i) & -\sin(m\theta_i) \\ \sin(m\theta_i) & \cos(m\theta_i) \end{bmatrix} \end{equation}\] Each dimension pair \((2i, 2i+1)\) has fixed frequency \(\theta_i = 10000^{-2i/d_k}\) (not learned). Lower dimensions rotate faster, higher dimensions rotate slower. Because the score \(q_m^\top k_n\) depends only on the offset \(n - m\), relative position is encoded directly; however, running well beyond the trained context length typically requires rescaling positions or frequencies (e.g., position interpolation, NTK-aware scaling, YaRN) rather than working out of the box.
Example: RoPE Data Flow (for token at position \(m\)):
Input: Token embedding \(x_m \in \mathbb{R}^{d_{\text{model}}}\) at position \(m\)
Project to query: \(q = W_Q x_m \in \mathbb{R}^{d_k}\) (standard attention projection)
Apply rotations per dimension pair:
For each pair \((q_{2i}, q_{2i+1})\) where \(i = 0, 1, \ldots, d_k/2 - 1\):
Compute frequency: \(\theta_i = 10000^{-2i/d_k}\)
Compute angle: \(\alpha_m^{(i)} = m \cdot \theta_i\) (position \(\times\) frequency)
Rotate: \(\begin{bmatrix} q_m^{(2i)} \\ q_m^{(2i+1)} \end{bmatrix} = \begin{bmatrix} \cos(\alpha_m^{(i)}) & -\sin(\alpha_m^{(i)}) \\ \sin(\alpha_m^{(i)}) & \cos(\alpha_m^{(i)}) \end{bmatrix} \begin{bmatrix} q_{2i} \\ q_{2i+1} \end{bmatrix}\)
Result: Rotated query \(q_m \in \mathbb{R}^{d_k}\) with position \(m\) encoded
Attention score: \(q_m^\top k_n = (R_m q)^\top (R_n k) = q^\top R_m^\top R_n k = q^\top R_{n-m} k\)
Key insight: Rotation property means attention score depends on relative position \((n-m)\)
No absolute position needed after rotation–relative position emerges naturally!
Concrete Example (\(d_k = 4\), position \(m=5\)):
Frequencies: \(\theta_0 = 1.0\), \(\theta_1 = 10000^{-2/4} \approx 0.01\)
Angles at position 5: \(\alpha_5^{(0)} = 5 \cdot 1.0 = 5\) rad, \(\alpha_5^{(1)} = 5 \cdot 0.01 = 0.05\) rad
Pair \((q_0, q_1)\) rotates by \(5\) radians (fast rotation)
Pair \((q_2, q_3)\) rotates by \(0.05\) radians (slow rotation)
Result: Low dimensions encode fine-grained local position, high dimensions encode coarse global position
SwiGLU activation (in FFN): \(\text{SwiGLU}(x) = \text{Swish}(W_1 x) \odot (W_2 x)\), followed by a down-projection back to \(d_{\text{model}}\); empirically improves over ReLU/GELU in feed-forward layers
RMSNorm: \(\text{RMSNorm}(x) = \frac{x}{\sqrt{\text{mean}(x^2) + \epsilon}} \cdot \gamma\) (no mean subtraction and no bias, so cheaper than LayerNorm)
Pre-normalization: Apply norm before sublayer (more stable training)
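A minimal NumPy sketch of RoPE, RMSNorm, and a SwiGLU FFN. It assumes the interleaved pairing \((2i, 2i+1)\) used above (some codebases pair dimension \(i\) with \(i + d_k/2\) instead), base 10000, row-vector inputs, and hypothetical weights W1, W2, W3. The check confirms that the rotated score depends only on the offset \(n - m\).

```python
import numpy as np

def rope(x, pos, base=10000.0):
    """Rotate pairs (x[2i], x[2i+1]) of a d_k-dim vector by angle pos * theta_i,
    theta_i = base^(-2i/d_k)."""
    d = x.shape[-1]
    theta = base ** (-np.arange(0, d, 2) / d)   # (d/2,) fixed frequencies
    cos, sin = np.cos(pos * theta), np.sin(pos * theta)
    x_even, x_odd = x[0::2], x[1::2]
    out = np.empty_like(x)
    out[0::2] = cos * x_even - sin * x_odd      # 2x2 rotation applied to each pair
    out[1::2] = sin * x_even + cos * x_odd
    return out

def rms_norm(x, gamma, eps=1e-6):
    """x / sqrt(mean(x^2) + eps) * gamma: no mean subtraction, no bias (unlike LayerNorm)."""
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * gamma

def swiglu_ffn(x, W1, W2, W3):
    """(Swish(x W1) * (x W2)) W3 with Swish(z) = z * sigmoid(z); W3 projects back to d_model."""
    a = x @ W1
    swish = a / (1.0 + np.exp(-a))
    return (swish * (x @ W2)) @ W3

rng = np.random.default_rng(0)
q, k = rng.standard_normal(8), rng.standard_normal(8)
# (m, n) = (5, 9) and (105, 109) have the same offset n - m = 4, so the scores match
s1 = rope(q, 5) @ rope(k, 9)
s2 = rope(q, 105) @ rope(k, 109)
print(np.isclose(s1, s2))   # True
```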
15.3 Grouped-Query Attention (GQA)
Problem: Multi-head attention caches separate K/V for each head. With \(h\) heads, sequence length \(n\), dimension \(d_k\): \[\text{KV cache size per layer} = 2 \times h \times n \times d_k\] For a LLaMA-70B-sized layer with full MHA (\(h=64\), \(d_k=128\), \(n=4096\)): \(2 \times 64 \times 4096 \times 128 \approx 67\)M values per layer! (LLaMA-2-70B itself uses GQA to avoid this.)
Multi-Query Attention (MQA): Use 1 shared K/V head for all Q heads:
Each token still has \(h\) query heads: \(Q_1, Q_2, \ldots, Q_h\)
But only 1 key head \(K\) and 1 value head \(V\) shared across all queries
Each Q head attends to the same \(K, V\): \(\text{head}_i = \text{Attention}(Q_i, K, V)\)
KV cache reduction: \(h\times\) smaller
Grouped-Query Attention (GQA): Middle ground between MHA and MQA:
Partition \(h\) query heads into \(g\) groups (\(1 < g < h\); \(g = h\) recovers MHA and \(g = 1\) recovers MQA)
Each group shares one K/V head
Example: 32 Q heads, 8 KV groups \(\rightarrow\) each KV head serves 4 Q heads
GQA Example: 32 query heads, 8 KV heads (4:1 grouping)
Architecture:
Query projections: \(Q_1, Q_2, \ldots, Q_{32}\) (32 separate \(W_Q\) matrices)
Key projections: \(K_1, K_2, \ldots, K_8\) (8 shared \(W_K\) matrices)
Value projections: \(V_1, V_2, \ldots, V_8\) (8 shared \(W_V\) matrices)
Grouping:
Heads \(Q_1, Q_2, Q_3, Q_4\) attend to \(K_1, V_1\)
Heads \(Q_5, Q_6, Q_7, Q_8\) attend to \(K_2, V_2\)
\(\vdots\)
Heads \(Q_{29}, Q_{30}, Q_{31}, Q_{32}\) attend to \(K_8, V_8\)
Computation per group: \[\text{head}_i = \text{Attention}(Q_i, K_{\lceil i/4 \rceil}, V_{\lceil i/4 \rceil})\]
KV Cache Reduction:
Standard MHA: \(2 \times 32 \times n \times d_k\) values
GQA (8 groups): \(2 \times 8 \times n \times d_k\) values
Reduction: \(4\times\) smaller cache
Comparison:
| Method | Q Heads | KV Heads | KV Cache Size |
|---|---|---|---|
| Multi-Head (MHA) | \(h\) | \(h\) | \(2hn d_k\) |
| Grouped-Query (GQA) | \(h\) | \(g\) | \(2gn d_k\) |
| Multi-Query (MQA) | \(h\) | \(1\) | \(2n d_k\) |
Empirical Findings:
MQA can hurt quality (all Q heads see identical K/V–less expressive)
GQA with \(g = h/4\) or \(h/8\) recovers most quality while maintaining cache savings
Used in: LLaMA-2-70B and LLaMA-3 (8 KV heads), Qwen2.5-7B (4 KV heads for 28 Q heads), Mistral 7B (8 KV heads for 32 Q heads)
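A NumPy sketch of GQA and its cache arithmetic. It assumes 0-indexed heads, so query head \(i\) uses KV head \(i \,//\, (h/g)\), the 0-indexed form of \(\lceil i/4 \rceil\) above. Repeating each KV head \(h/g\) times is one simple way to express the sharing; a real kernel would avoid materializing the copies.

```python
import numpy as np

def kv_cache_values(n_kv_heads, seq_len, d_k):
    """Cached K and V entries per layer: 2 * (#KV heads) * n * d_k."""
    return 2 * n_kv_heads * seq_len * d_k

def gqa_attention(Q, K, V):
    """Q: (h, n, d_k); K, V: (g, n, d_k) with h % g == 0.
    Query head i (0-indexed) attends with KV head i // (h // g)."""
    h, n, d_k = Q.shape
    group = h // K.shape[0]
    K_rep = np.repeat(K, group, axis=0)    # (h, n, d_k): KV head j serves `group` query heads
    V_rep = np.repeat(V, group, axis=0)
    scores = Q @ K_rep.transpose(0, 2, 1) / np.sqrt(d_k)
    scores -= scores.max(axis=-1, keepdims=True)
    A = np.exp(scores)
    A /= A.sum(axis=-1, keepdims=True)
    return A @ V_rep

print(kv_cache_values(64, 4096, 128))   # full MHA, 70B-sized layer: 67,108,864 values
print(kv_cache_values(8, 4096, 128))    # same layer with 8 KV heads: 8x smaller
```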
16 Structured and Sparse Attention
16.1 Motivation: Reducing Quadratic Cost
Instead of full \(n \times n\) attention, restrict connectivity:
Local/Sliding Window: Each token attends to window of size \(w\): \(O(nw)\)
Strided/Dilated: Attend every \(k\)-th position (like dilated convolutions)
Block-Sparse: Divide sequence into blocks, attend within block + global tokens
Longformer/BigBird: Combine local + global + random attention
16.2 Sliding Window Attention
Used in Mistral, Qwen, Kimi models: \[\begin{equation} \alpha_{ij} = \begin{cases} \frac{\exp(q_i^\top k_j / \sqrt{d_k})}{\sum_{l=\max(0, i-w)}^{i} \exp(q_i^\top k_l / \sqrt{d_k})} & \text{if } i - w \leq j \leq i \\ 0 & \text{otherwise} \end{cases} \end{equation}\]
Tradeoff:
+ Linear complexity \(O(nw d)\)
- Information flow limited: token at position \(i\) can’t directly attend to position \(j\) if \(|i-j| > w\)
Mitigation: Stack layers so \(L\) layers give effective receptive field \(L \cdot w\)
Why Stacking Layers Increases Receptive Field:
With window size \(w\), each layer can only directly attend within distance \(w\). But stacking layers allows indirect information flow:
Example: Window size \(w = 4\), token at position \(i = 10\)
Layer 1: Token 10 attends to positions [6, 7, 8, 9, 10] (window of 4 before + self)
Layer 2: Token 10’s representation now contains information from [6–10]
Token 10 again attends to [6–10], but token 6’s representation contains info from [2–6]
So token 10 indirectly accesses positions [2–10] via token 6
Layer 3: Effective receptive field extends to [0–10] (\(10 - 3 \times 4 = -2\), clamped at the first position, 0)
General Rule: With \(L\) layers and window \(w\), token at position \(i\) can access information from positions \([\max(0, i - L \cdot w), i]\).
Why This Works: Same principle as CNNs–each conv layer has small kernel (e.g., 3×3), but stacking layers increases receptive field. In attention, the "kernel" is the sliding window.
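The stacking argument can be checked numerically. The sketch builds the causal window pattern \(i - w \le j \le i\) as a 0/1 matrix and composes it once per layer, since a path of \(L\) hops means information can flow across \(L\) layers. The sequence length 16 is an arbitrary assumption; the window, position, and layer counts follow the worked example.

```python
import numpy as np

def sliding_window_allowed(n, w):
    """allowed[i, j] is True iff i - w <= j <= i (w previous tokens + self)."""
    i = np.arange(n)[:, None]
    j = np.arange(n)[None, :]
    return (j <= i) & (j >= i - w)

def receptive_field(n, w, num_layers, pos):
    """Input positions that can influence position `pos` after `num_layers` window layers."""
    hop = sliding_window_allowed(n, w).astype(int)
    reach = np.eye(n, dtype=int)
    for _ in range(num_layers):
        reach = ((hop @ reach) > 0).astype(int)   # one more attention hop
    return np.flatnonzero(reach[pos])

for L in (1, 2, 3):
    r = receptive_field(n=16, w=4, num_layers=L, pos=10)
    print(L, r.min(), r.max())   # 1: 6..10, 2: 2..10, 3: 0..10
```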
16.3 Flash Attention: IO-Aware Exact Attention
Flash Attention (Dao et al., 2022) computes exact attention but optimizes memory hierarchy:
Tile computation to fit in SRAM (on-chip cache)
Recompute attention scores in backward pass instead of storing full \(n \times n\) matrix
Achieves same \(O(n^2)\) FLOPs but 2-4\(\times\) speedup via reduced HBM (off-chip memory) access
Key Insight: Modern GPU bottleneck is memory bandwidth, not compute. Flash Attention is algorithm-hardware co-design.
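The tiling idea can be illustrated with an online (running) softmax over key/value blocks: for each block the running row maximum and denominator are updated and earlier partial sums are rescaled, so the result is exact even though the full \(n \times n\) matrix is never held. This is a NumPy sketch of the math only, assuming a single head and no masking; the actual speedup comes from a fused GPU kernel that keeps tiles in SRAM, which NumPy cannot show. The block size is arbitrary.

```python
import numpy as np

def tiled_attention(Q, K, V, block=4):
    """Exact softmax attention over key/value blocks using an online softmax.
    Only an (n, block) tile of scores exists at any time."""
    n, d_k = Q.shape
    out = np.zeros((n, V.shape[1]))
    row_max = np.full((n, 1), -np.inf)     # running max of scores per query row
    row_sum = np.zeros((n, 1))             # running softmax denominator
    for start in range(0, K.shape[0], block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        S = Q @ Kb.T / np.sqrt(d_k)                          # (n, block) tile of scores
        new_max = np.maximum(row_max, S.max(axis=1, keepdims=True))
        scale = np.exp(row_max - new_max)                    # rescale earlier partial results
        P = np.exp(S - new_max)
        row_sum = row_sum * scale + P.sum(axis=1, keepdims=True)
        out = out * scale + P @ Vb
        row_max = new_max
    return out / row_sum

rng = np.random.default_rng(0)
Q, K, V = rng.standard_normal((3, 10, 8))
print(tiled_attention(Q, K, V, block=3).shape)
```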
17 Linear and Kernelized Attention
17.1 Motivation: Factorizing the Softmax Kernel
Standard attention: \[y_i = \sum_j \frac{\exp(q_i^\top k_j / \sqrt{d_k})}{\sum_l \exp(q_i^\top k_l / \sqrt{d_k})} v_j\]
If we can approximate \(\exp(q^\top k) \approx \phi(q)^\top \phi(k)\) for some feature map \(\phi: \mathbb{R}^{d_k} \to \mathbb{R}^{d_\phi}\), then: \[\begin{align} y_i & \approx \frac{\phi(q_i)^\top \sum_j \phi(k_j) v_j^\top}{\phi(q_i)^\top \sum_j \phi(k_j)} \\ & = \frac{\phi(q_i)^\top S}{\phi(q_i)^\top z} \end{align}\] where \(S = \sum_j \phi(k_j) v_j^\top \in \mathbb{R}^{d_\phi \times d_v}\) and \(z = \sum_j \phi(k_j) \in \mathbb{R}^{d_\phi}\).
Complexity: Computing \(S\) and \(z\) requires \(O(n d_\phi d_v)\) time, then each query \(q_i\) takes \(O(d_\phi d_v)\). Total: \(O(n d_\phi d_v)\) vs. \(O(n^2 d)\).
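The saving comes purely from associativity: \((\phi(Q)\phi(K)^\top)V = \phi(Q)(\phi(K)^\top V)\). The sketch below assumes the feature map \(\phi(x) = \mathrm{elu}(x) + 1\) (one common positive choice; it defines its own kernel rather than approximating \(\exp(q^\top k)\)) and checks that the \(O(n)\) and \(O(n^2)\) orderings agree.

```python
import numpy as np

def elu_plus_one(x):
    """Assumed positive feature map phi(x) = elu(x) + 1."""
    return np.where(x > 0, x + 1.0, np.exp(x))

def linear_attention(Q, K, V, phi=elu_plus_one):
    """Non-causal linear attention: build S and z once, reuse them for every query."""
    PQ, PK = phi(Q), phi(K)            # (n, d_phi)
    S = PK.T @ V                       # sum_j phi(k_j) v_j^T -> (d_phi, d_v)
    z = PK.sum(axis=0)                 # sum_j phi(k_j)       -> (d_phi,)
    return (PQ @ S) / (PQ @ z)[:, None]

def quadratic_kernel_attention(Q, K, V, phi=elu_plus_one):
    """Same kernel evaluated the O(n^2) way: weights phi(q_i)^T phi(k_j), row-normalized."""
    W = phi(Q) @ phi(K).T              # (n, n)
    return (W / W.sum(axis=1, keepdims=True)) @ V

rng = np.random.default_rng(0)
n, d_k, d_v = 7, 4, 3
Q, K = rng.standard_normal((2, n, d_k))
V = rng.standard_normal((n, d_v))
print(np.allclose(linear_attention(Q, K, V), quadratic_kernel_attention(Q, K, V)))   # True
```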
17.2 Performer: Positive Random Features
Performer (Choromanski et al., 2021) uses positive random features: \[\phi(x) = \frac{e^{-\|x\|^2/2}}{\sqrt{m}} \begin{bmatrix} \exp(w_1^\top x) \\ \vdots \\ \exp(w_m^\top x) \end{bmatrix}, \quad w_i \sim \mathcal{N}(0, I)\] so that \(\phi(q)^\top \phi(k)\) is an unbiased estimate of the softmax kernel \(\exp(q^\top k)\). (In practice \(q\) and \(k\) are first scaled by \(d_k^{-1/4}\), so the kernel becomes \(\exp(q^\top k / \sqrt{d_k})\).)
Why Random Features Approximate the Softmax Kernel:
The softmax kernel is a Gaussian (RBF) kernel times per-vector factors: \[\exp(q^\top k) = e^{\|q\|^2/2}\, \exp\left(-\frac{\|q-k\|^2}{2}\right) e^{\|k\|^2/2}\]
Bochner’s Theorem: Any shift-invariant kernel \(k(x-y)\) can be expressed as: \[k(x - y) = \int p(w) e^{i w^\top (x-y)} dw = \mathbb{E}_{w \sim p}[\cos(w^\top x - w^\top y)]\]
For the Gaussian kernel \(\exp(-\|x-y\|^2/2)\), the frequency distribution \(p(w)\) is \(\mathcal{N}(0, I)\); sampling \(w\) and using \(\cos\)/\(\sin\) features gives the classical (trigonometric) random Fourier features.
Positive random features (Performer): Approximate an expectation with Monte Carlo sampling:
Sample \(m\) random vectors: \(w_1, \ldots, w_m \sim \mathcal{N}(0, I)\)
Define feature map: \(\phi(x) = \frac{e^{-\|x\|^2/2}}{\sqrt{m}} [\exp(w_1^\top x), \ldots, \exp(w_m^\top x)]^\top\)
Then: \(\phi(q)^\top \phi(k) = \frac{1}{m} \sum_{i=1}^m \exp\left(w_i^\top (q + k) - \frac{\|q\|^2 + \|k\|^2}{2}\right)\)
Why this works: For \(w \sim \mathcal{N}(0, I)\), the Gaussian moment-generating function gives \(\mathbb{E}_w[\exp(w^\top u)] = \exp(\|u\|^2/2)\). With \(u = q + k\), the expectation is \(\exp\left(\frac{\|q+k\|^2 - \|q\|^2 - \|k\|^2}{2}\right) = \exp(q^\top k)\) exactly, so the Monte Carlo average is unbiased and converges as \(m \to \infty\).
Practical Note: Trigonometric (cos/sin) features can produce negative kernel estimates and high variance when the true kernel value is small, which destabilizes the attention normalizer. Performer's FAVOR+ therefore uses the positive exponential features above, drawing the \(w_i\) as orthogonal random vectors to reduce variance.
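A NumPy check of the positive-random-feature estimator, including the \(e^{-\|x\|^2/2}\) factor that makes the estimate unbiased for \(\exp(q^\top k)\). It assumes i.i.d. Gaussian \(w_i\) (Performer additionally orthogonalizes them, omitted here), a small dimension, and small-norm \(q, k\) so the Monte Carlo variance stays low; the sample count \(m\) is arbitrary.

```python
import numpy as np

def positive_random_features(X, W):
    """phi(x) = exp(W x - ||x||^2 / 2) / sqrt(m); rows of W drawn i.i.d. from N(0, I).
    E[phi(q)^T phi(k)] = exp(q^T k), and every feature is positive."""
    m = W.shape[0]
    return np.exp(X @ W.T - 0.5 * np.sum(X * X, axis=-1, keepdims=True)) / np.sqrt(m)

rng = np.random.default_rng(0)
d, m = 8, 200_000
q, k = 0.2 * rng.standard_normal((2, d))
W = rng.standard_normal((m, d))                # plain i.i.d. Gaussian draws
est = positive_random_features(q[None, :], W) @ positive_random_features(k[None, :], W).T
print(est.item(), np.exp(q @ k))               # close for large m
```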
17.3 Linear Attention with Causal Masking
For autoregressive generation: \[\begin{align} S_t & = S_{t-1} + \phi(k_t) v_t^\top \\ z_t & = z_{t-1} + \phi(k_t) \\ y_t & = \frac{\phi(q_t)^\top S_t}{\phi(q_t)^\top z_t} \end{align}\]
Recurrent formulation: State \((S_t, z_t)\) can be updated in \(O(d_\phi d_v)\) per step–constant memory!
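A NumPy sketch showing that the token-by-token recurrence gives the same outputs as the \(O(n^2)\) causally masked form. The feature map \(\phi(x) = \mathrm{elu}(x) + 1\) is an assumption (any positive map works for this check), and the sizes are arbitrary.

```python
import numpy as np

def phi(x):
    """Assumed positive feature map: elu(x) + 1."""
    return np.where(x > 0, x + 1.0, np.exp(x))

def causal_linear_attention_recurrent(Q, K, V):
    """S_t = S_{t-1} + phi(k_t) v_t^T, z_t = z_{t-1} + phi(k_t),
    y_t = phi(q_t)^T S_t / phi(q_t)^T z_t. The state does not grow with t."""
    n, d_v = V.shape
    S, z = np.zeros((K.shape[1], d_v)), np.zeros(K.shape[1])
    Y = np.zeros((n, d_v))
    for t in range(n):
        fk, fq = phi(K[t]), phi(Q[t])
        S += np.outer(fk, V[t])
        z += fk
        Y[t] = fq @ S / (fq @ z)
    return Y

def causal_linear_attention_masked(Q, K, V):
    """Equivalent O(n^2) form: kernel weights phi(q_i)^T phi(k_j), with j > i zeroed."""
    W = np.tril(phi(Q) @ phi(K).T)
    return (W / W.sum(axis=1, keepdims=True)) @ V

rng = np.random.default_rng(0)
n, d_k, d_v = 9, 4, 3
Q, K = rng.standard_normal((2, n, d_k))
V = rng.standard_normal((n, d_v))
print(np.allclose(causal_linear_attention_recurrent(Q, K, V),
                  causal_linear_attention_masked(Q, K, V)))   # True
```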
17.4 Limitations of Linear Attention
- Approximation error: \(\phi(q)^\top \phi(k) \neq \exp(q^\top k)\) exactly
- Softmax selectivity lost: linear attention tends to be more diffuse (less peaked)
- Empirically underperforms full attention on complex tasks (NLU benchmarks)
How Good is Performer in Practice?
Speed: Significant speedup on long sequences (\(n > 2048\)). For \(n = 16K\), Performer is 3-4\(\times\) faster than standard attention.
Quality: Mixed results:
Good: Works well on protein sequence modeling, long-range dependencies where attention is diffuse
Okay: Competitive on ImageNet classification, some language modeling tasks (small quality drop)
Poor: Underperforms on tasks requiring sharp attention (question answering, GLUE benchmarks)
Why? Linear attention can’t model the peaked, selective attention patterns needed for complex reasoning. The random-feature approximation of the softmax kernel loses the "winner-take-all" behavior of softmax.
Current Status: Largely superseded by:
Flash Attention: Exact attention, 2-4\(\times\) faster via IO optimization (no quality loss)
Sliding window: Mistral/Qwen use \(O(nw)\) with small quality drop, better than linear attention
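Hybrid approaches: Interleave cheaper layers with exact attention, e.g., Kimi mixes Delta-attention (gated linear-recurrent) layers with full-attention layers; Qwen's Dual Chunk Attention instead remaps relative positions chunk-wise to extend the usable context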
Interview Takeaway: Performer was an important research direction (2021), but Flash Attention proved that optimizing exact attention is more practical than approximating it.
18 Delta Attention and Kimi
18.1 Kimi: Long-Context Chinese LLM
Kimi (Moonshot AI) targets very long context windows (200K+). Its Kimi Linear model uses Kimi Delta Attention, a hybrid between linear attention and recurrent state updates, interleaved with full-attention layers.
18.2 Delta Attention Mechanism
Instead of storing full KV cache, maintain incremental state: \[\begin{align} S_t & = \gamma_t \odot S_{t-1} + \beta_t \odot \phi(k_t) v_t^\top \\ z_t & = \gamma_t \odot z_{t-1} + \beta_t \odot \phi(k_t) \\ y_t & = \frac{\phi(q_t)^\top S_t}{\phi(q_t)^\top z_t} \end{align}\] where \(\gamma_t, \beta_t\) are learned gates (sigmoid outputs): \[\gamma_t = \sigma(W_\gamma q_t), \quad \beta_t = \sigma(W_\beta k_t)\]
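Setting \(\gamma_t = \beta_t = 1\) recovers exactly the causal linear-attention recurrence of Section 17.3; the gates add learned forgetting and input weighting. (The "delta" in published delta-rule models such as DeltaNet refers to an additional error-correcting update, \(S_t = S_{t-1} + \beta_t \phi(k_t)\big(v_t - S_{t-1}^\top \phi(k_t)\big)^\top\), which overwrites the value currently stored under key \(k_t\) instead of only adding to it.)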
Interpretation:
\(\gamma_t\): Decay rate for old information (forget gate)
\(\beta_t\): Importance of new key-value pair (input gate)
Analogous to GRU/LSTM gating mechanisms
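A NumPy sketch of the gated recurrence from Section 18.2. It assumes scalar per-step gates (the text's \(\gamma_t, \beta_t\) could also be vectors), hypothetical gate parameter vectors `w_gamma` and `w_beta`, and the feature map \(\phi(x) = \mathrm{elu}(x) + 1\). With both gates fixed to 1 it reduces to causal linear attention; with \(\gamma_t = 0\) only the current token survives.

```python
import numpy as np

def phi(x):
    """Assumed positive feature map: elu(x) + 1."""
    return np.where(x > 0, x + 1.0, np.exp(x))

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gated_linear_recurrence(Q, K, V, gamma, beta):
    """S_t = gamma_t S_{t-1} + beta_t phi(k_t) v_t^T;  z_t = gamma_t z_{t-1} + beta_t phi(k_t);
    y_t = phi(q_t)^T S_t / phi(q_t)^T z_t.  gamma, beta: (n,) scalar gates in [0, 1]."""
    n, d_v = V.shape
    d_phi = K.shape[1]                 # phi keeps the dimension, so d_phi = d_k here
    S, z = np.zeros((d_phi, d_v)), np.zeros(d_phi)
    Y = np.zeros((n, d_v))
    for t in range(n):
        fk, fq = phi(K[t]), phi(Q[t])
        S = gamma[t] * S + beta[t] * np.outer(fk, V[t])   # forget old, write new
        z = gamma[t] * z + beta[t] * fk
        Y[t] = fq @ S / (fq @ z)
    return Y

rng = np.random.default_rng(0)
n, d_k, d_v = 6, 4, 3
Q, K = rng.standard_normal((2, n, d_k))
V = rng.standard_normal((n, d_v))
w_gamma, w_beta = rng.standard_normal((2, d_k))           # hypothetical gate parameters
gamma, beta = sigmoid(Q @ w_gamma), sigmoid(K @ w_beta)
Y = gated_linear_recurrence(Q, K, V, gamma, beta)
```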
18.3 Advantages and Tradeoffs
Advantages:
+ Constant memory in sequence length: State size \(O(d_\phi d_v)\) does not grow with the context
+ Efficient long-context inference without full KV cache
+ Gating provides learned compression vs. fixed kernel in linear attention
Tradeoffs:
- Sequential dependency in training (can’t parallelize over time as easily)
- Lossy compression: Some long-range dependencies may decay
19 State-Space Models (SSMs) and Mamba
19.1 From Attention to Recurrence: SSMs
State-space models provide an alternative to attention by modeling sequences as dynamical systems. The classical discrete-time linear SSM is:
\[\begin{align} s_t & = A s_{t-1} + B u_t \\ y_t & = C s_t + D u_t \end{align}\]
where:
\(s_t \in \mathbb{R}^N\) is the hidden state at time \(t\)
\(u_t \in \mathbb{R}^{d}\) is the input token representation
\(y_t \in \mathbb{R}^{d}\) is the output
\(A \in \mathbb{R}^{N \times N}\) governs state evolution (transition matrix)
\(B \in \mathbb{R}^{N \times d}\) maps input to state update
\(C \in \mathbb{R}^{d \times N}\) maps state to output (readout matrix)
\(D \in \mathbb{R}^{d \times d}\) is the skip connection (often identity or zero)
Key Properties:
Recurrent formulation: State \(s_t\) compresses history, requires \(O(N)\) memory
Fixed parameters: Classical SSMs have constant \(A, B, C, D\) for all time steps
Linear complexity: \(O(n d N)\) time for sequence of length \(n\)
19.2 SSMs as Convolutional Models
Unrolling the recurrence for time-invariant \(A, B, C\) (with \(s_{-1} = 0\)): \[y_t = \sum_{j=0}^{t} C A^{t-j} B u_j + D u_t = \sum_{j=0}^{t} h_{t-j} u_j + D u_t\] where \(h_k = C A^{k} B\) is the impulse response kernel.
This is a structured 1D convolution–can be computed efficiently via FFT in \(O(n \log n)\) for fixed \(A, B, C\).
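A NumPy check that the recurrence \(s_t = A s_{t-1} + B u_t,\ y_t = C s_t + D u_t\) (with \(s_{-1} = 0\)) equals a convolution with kernel \(h_k = C A^k B\), summed over \(j = 0, \ldots, t\). The convolution is done directly in \(O(n^2)\) for clarity rather than via FFT; the random, scaled-down \(A\) is an arbitrary stable choice.

```python
import numpy as np

def ssm_recurrent(A, B, C, D, U):
    """s_t = A s_{t-1} + B u_t, y_t = C s_t + D u_t, starting from s_{-1} = 0. U: (n, d)."""
    s = np.zeros(A.shape[0])
    Y = []
    for u in U:
        s = A @ s + B @ u
        Y.append(C @ s + D @ u)
    return np.array(Y)

def ssm_convolution(A, B, C, D, U):
    """Same output via the kernel h_k = C A^k B: y_t = sum_{j<=t} h_{t-j} u_j + D u_t."""
    n = U.shape[0]
    h, Ak = [], np.eye(A.shape[0])
    for _ in range(n):
        h.append(C @ Ak @ B)           # h_k, shape (d, d)
        Ak = A @ Ak
    return np.array([sum(h[t - j] @ U[j] for j in range(t + 1)) + D @ U[t] for t in range(n)])

rng = np.random.default_rng(0)
N, d, n = 4, 2, 12
A = 0.5 * rng.standard_normal((N, N)) / np.sqrt(N)
B, C = rng.standard_normal((N, d)), rng.standard_normal((d, N))
D = rng.standard_normal((d, d))
U = rng.standard_normal((n, d))
print(np.allclose(ssm_recurrent(A, B, C, D, U), ssm_convolution(A, B, C, D, U)))   # True
```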
19.3 Selective SSMs: The Mamba Innovation
Classical SSMs have time-invariant \(A, B, C, D\). Mamba (Gu & Dao, 2023) makes them input-dependent:
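\[\begin{align} B_t & = W_B(u_t) \quad \text{(input-dependent input projection)} \\ C_t & = W_C(u_t) \quad \text{(input-dependent readout)} \\ A_t & = \exp(\Delta_t A), \quad \Delta_t = \mathrm{softplus}(W_\Delta(u_t)) \quad \text{(input-dependent transition via step size)} \end{align}\] In Mamba the continuous-time matrix \(A\) is a fixed, learned (diagonal) parameter; the transition becomes input-dependent through the per-token step size \(\Delta_t\).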
The state update becomes: \[\begin{align} s_t & = A_t s_{t-1} + B_t u_t \\ h_t & = C_t s_t \end{align}\]
Intuition:
\(B_t\): Controls what information from input \(u_t\) gets stored in state
\(C_t\): Controls what information from state \(s_t\) contributes to output
Input-dependent \(B_t, C_t\) enable content-based filtering (like attention’s Q/K matching)
Key Difference from Classical SSMs:
\(B_t, C_t\) adapt based on input–content-dependent filtering
No longer expressible as time-invariant convolution
Requires sequential scan, but GPU-optimized via parallel prefix-sum algorithms
Connection to Attention:
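Classical SSM with fixed \(B, C\) is like attention with fixed, content-independent weights: the kernel \(h_k\) depends only on the offset \(k\), not on what the tokens contain, so there is no content-based selection.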
Mamba’s input-dependent \(B_t, C_t\) provides content-based gating similar to attention’s softmax weights, but with:
\(O(N)\) state memory vs. \(O(nd)\) KV cache
Implicit history compression vs. explicit key-value storage
Linear updates vs. quadratic attention matrix
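A minimal NumPy sketch of a selective scan, assuming Mamba-style choices: a fixed negative diagonal \(A\) per channel, an input-dependent step \(\Delta_t = \mathrm{softplus}(u_t W_\Delta)\), input-dependent \(B_t, C_t\) shared across channels, and the simplified discretization \(\bar{B}_t \approx \Delta_t B_t\). All weight names are hypothetical, and this is a sequential loop, not Mamba's fused parallel-scan kernel.

```python
import numpy as np

def softplus(x):
    return np.log1p(np.exp(x))

def selective_scan(U, A_diag, W_B, W_C, W_delta):
    """U: (n, d) inputs; A_diag: (d, N) negative reals; W_B, W_C: (d, N); W_delta: (d, d).
    Per step: A_bar = exp(Delta_t A), s_t = A_bar * s_{t-1} + Delta_t B_t u_t, y_t = s_t C_t."""
    n, d = U.shape
    s = np.zeros_like(A_diag)                  # (d, N): an N-dim state per channel
    Y = np.zeros((n, d))
    for t in range(n):
        u = U[t]
        delta = softplus(u @ W_delta)          # (d,) input-dependent step size
        B_t, C_t = u @ W_B, u @ W_C            # (N,) input-dependent, shared by channels
        A_bar = np.exp(delta[:, None] * A_diag)                # (d, N), entries in (0, 1)
        s = A_bar * s + (delta[:, None] * B_t[None, :]) * u[:, None]
        Y[t] = s @ C_t                         # (d,) readout
    return Y

rng = np.random.default_rng(0)
n, d, N = 20, 4, 8
U = rng.standard_normal((n, d))
A_diag = -np.exp(rng.standard_normal((d, N)))  # negative -> decaying state
W_B, W_C = rng.standard_normal((2, d, N))
W_delta = 0.5 * rng.standard_normal((d, d))
Y = selective_scan(U, A_diag, W_B, W_C, W_delta)
print(Y.shape)
```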
19.4 Mamba Architecture
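Mamba block (simplified) replaces transformer attention with two branches computed from the input: \[\begin{align} \tilde{u}_t & = \mathrm{SiLU}\big(\mathrm{Conv1D}(W_{\text{in}} u)_t\big) \\ h_t & = \text{SSM}(\tilde{u}_t; A, B_t, C_t, \Delta_t) \\ o_t & = W_{\text{out}}\big(h_t \odot \mathrm{SiLU}(W_z u_t)\big) \end{align}\]
Notation: \(s_t\) is the SSM’s internal hidden state, \(h_t = C_t s_t\) is the SSM output, \(\mathrm{SiLU}(W_z u_t)\) is a multiplicative gate, and \(o_t\) is the final Mamba block output.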
Stacks Mamba blocks like transformer layers. No explicit Q/K/V or softmax.
19.5 Comparison: Attention vs. SSMs
| Property | Attention | Mamba SSM |
|---|---|---|
| Complexity (time) | \(O(n^2 d)\) | \(O(n d N)\) |
| Complexity (memory) | \(O(n^2)\) | \(O(N)\) |
| Parallelization | Full (training) | Parallel scan |
| Long-range dependencies | Explicit | Implicit (state) |
| Content-based routing | Softmax gate | Input-dependent \(B_t, C_t\) |
| Inference cache | KV cache \(O(n d)\) | State \(O(N)\) |
When SSMs Excel:
Ultra-long sequences (audio, genomics, time series)
Streaming inference with bounded memory
Uniform computation per token (predictable latency)
When Attention Excels:
Complex reasoning requiring explicit attention to distant tokens
In-context learning (ICL) with few-shot prompts
Retrieval-augmented generation (RAG) where queries attend to retrieved docs
20 RWKV: Receptance Weighted Key Value
20.1 Motivation: RNN Efficiency + Transformer Parallelism
RWKV (Peng et al., 2023) is a linear attention variant that:
At inference: Runs as RNN with \(O(1)\) per-token memory
At training: Parallelizes like transformer
20.2 RWKV Attention Mechanism
RWKV computes three learned projections per token position, plus a learned time decay: \[\begin{align} r_t & = W^R x_t \quad \text{(receptance, an output gate playing the role of the query)} \\ k_t & = W^K x_t \quad \text{(key)} \\ v_t & = W^V x_t \quad \text{(value)} \end{align}\] where \(x_t\) is the input token embedding, and \(w \in \mathbb{R}^d\) is a learnable time-decay parameter (one per channel, shared across time).
The name RWKV comes from the four components: \(\mathbf{R}\)eceptance, \(\mathbf{W}\)eight (decay), \(\mathbf{K}\)ey, \(\mathbf{V}\)alue.
Time-Weighted Attention: All operations below are elementwise per channel. Define attention weights with exponential decay based on learnable \(w\): \[\alpha_{ij} = \frac{\exp(k_j - (i - j) w)}{\sum_{l=1}^{i} \exp(k_l - (i - l) w)}\] where \(w > 0\) is a learnable decay rate (larger \(w\) means faster forgetting of distant tokens); because \(w \in \mathbb{R}^d\) is a vector, each channel has its own forgetting rate. Unlike softmax attention, the weight on token \(j\) depends only on its key and its distance \(i - j\), not on a query-key dot product; this is what makes the recurrence below possible. (The published RWKV-4 formula also gives the current token a separate learned bonus \(u\), omitted here.)
The output at position \(i\) is the receptance-gated weighted average: \[y_i = \sigma(r_i) \odot \sum_{j=1}^{i} \alpha_{ij} v_j = \sigma(r_i) \odot \frac{\sum_{j=1}^{i} \exp(k_j - (i-j)w) \odot v_j}{\sum_{j=1}^{i} \exp(k_j - (i-j)w)}\]
Recurrent Formulation: Define the (elementwise) numerator and denominator: \[\begin{align} s_t & = \sum_{j=1}^{t} e^{-(t-j)w} \odot e^{k_j} \odot v_j \\ n_t & = \sum_{j=1}^{t} e^{-(t-j)w} \odot e^{k_j} \end{align}\]
Then \(y_t = \sigma(r_t) \odot s_t / n_t\) matches the time-weighted attention formula, and both sums can be updated recursively: \[\begin{align} s_t & = e^{-w} \odot s_{t-1} + e^{k_t} \odot v_t \quad \text{(decay old state + add new)} \\ n_t & = e^{-w} \odot n_{t-1} + e^{k_t} \quad \text{(normalizer)} \\ y_t & = \sigma(r_t) \odot s_t / n_t \quad \text{(gated, normalized output)} \end{align}\]
Connection to Time-Weighted Attention:
At time \(t\), the numerator is: \[\begin{align*} s_t & = e^{-w} \odot s_{t-1} + e^{k_t} \odot v_t \\ & = e^{-w} \odot \left[\sum_{j=1}^{t-1} e^{-(t-1-j)w} \odot e^{k_j} \odot v_j\right] + e^{k_t} \odot v_t \\ & = \sum_{j=1}^{t-1} e^{-(t-j)w} \odot e^{k_j} \odot v_j + e^{k_t} \odot v_t \\ & = \sum_{j=1}^{t} e^{-(t-j)w} \odot e^{k_j} \odot v_j \end{align*}\]
This is exactly the numerator in the time-weighted attention formula; the denominator \(n_t\) follows the same pattern. The step from \(s_{t-1}\) to \(s_t\) works only because the weight \(e^{k_j}\) on token \(j\) does not depend on the current position: if the exponent contained a query-dependent term such as \(r_t^\top k_j\), the stored sum \(s_{t-1}\) could not be reused. So the recurrent update computes the same values as the attention formula, but in \(O(1)\) per step instead of \(O(t)\).
Parallel Formulation (Training): Compute prefix sums in \(O(\log n)\) depth via parallel scan.
Key Innovation:
Receptance \(R\): Like query in attention, determines what to retrieve
Weight \(W\): Controls time-decay (how quickly to forget past tokens)
Exponential decay \(e^{-w(i-j)}\) replaces softmax’s position-independent weighting
State \((s_t, n_t)\) compresses history with learnable decay, enabling \(O(d)\) memory inference
20.3 Why RWKV Enables Recursion (and Softmax Doesn’t)
The Computational Trick:
Standard softmax attention cannot be recursively computed because the normalization is query-dependent: \[\alpha_{ij} = \frac{\exp(q_i^\top k_j / \sqrt{d})}{\sum_{l=1}^{n} \exp(q_i^\top k_l / \sqrt{d})}\] Each query \(q_i\) has its own normalization constant \(\sum_{l} \exp(q_i^\top k_l)\), requiring \(O(n)\) computation per position → \(O(n^2)\) total.
RWKV’s formulation avoids this by making the exponential decay factorizable: \[\exp(r_t^\top k_j - (t-j)w) = e^{r_t^\top k_j} \cdot e^{-(t-j)w}\] The time decay \(e^{-(t-j)w}\) is **independent of the query \(r_t\)**, so it can be recursively accumulated: \[\begin{align*} e^{-(t-j)w} = e^{-w} \cdot e^{-(t-1-j)w} \end{align*}\] This allows the state update \(s_t = e^{-w} s_{t-1} + e^{r_t^\top k_t} v_t\) to work.
Why Softmax Can’t Do This:
If we tried to write softmax in this form, we’d need: \[y_i = \frac{\sum_j \exp(q_i^\top k_j) v_j}{\sum_l \exp(q_i^\top k_l)}\] The denominator \(\sum_l \exp(q_i^\top k_l)\) **changes for every query position \(i\)**–there’s no way to factor it out and recurse. The softmax normalization creates a coupling across all positions that breaks recursive structure.
20.4 The Selectivity Tradeoff
What RWKV Loses:
Softmax attention has position-independent content selection: \[\text{If } q_i^\top k_j \gg q_i^\top k_l, \text{ then } \alpha_{ij} \approx 1 \text{ regardless of } |i-j|\] A query can attend strongly to any key with high similarity, even if it’s far away.
RWKV applies mandatory exponential decay: \[\alpha_{tj} \propto e^{r_t^\top k_j} \cdot e^{-(t-j)w}\] Even if \(r_t^\top k_j\) is very large, the term \(e^{-(t-j)w}\) forces distant tokens to have exponentially smaller weights. For \(w=0.1\) and \(t-j=100\), the decay factor is \(e^{-10} \approx 0.000045\), making that token nearly invisible.
Concrete Example:
Consider a question-answering task: “The capital of France, established in the 5th century, is Paris. [100 tokens of other content]. What is the capital of France?”
Softmax attention: Query at “France?” can attend strongly to “Paris” 100 tokens back if \(q \cdot k_{\text{Paris}}\) is high
RWKV: Even if \(r_t \cdot k_{\text{Paris}}\) is high, \(e^{-100w}\) severely downweights it (recency bias)
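A small NumPy sketch can check the recurrence numerically and evaluate the decay factor from the Paris example. It follows RWKV's per-channel formulation: token \(j\) at time \(t\) gets weight \(\exp(k_j - (t-j)w)\), and the receptance is applied as a sigmoid gate on the output.

The sketch deliberately omits:

- token shift,
- the current-token bonus,
- the output projection,
- the running-max trick that production kernels use to keep \(\exp(k)\) from overflowing.

The function names are hypothetical.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def wkv_explicit(r, k, v, w):
    """Time-weighted attention computed directly: O(T^2 d). All ops are per channel."""
    T, d = k.shape
    y = np.zeros((T, d))
    for t in range(T):
        dist = (t - np.arange(t + 1))[:, None]           # (t+1, 1) distances t - j
        weights = np.exp(k[: t + 1] - dist * w)           # exp(k_j - (t - j) w)
        wkv = (weights * v[: t + 1]).sum(axis=0) / weights.sum(axis=0)
        y[t] = sigmoid(r[t]) * wkv                        # receptance gates the output
    return y


def wkv_recurrent(r, k, v, w):
    """Same outputs from an O(d) state (s, n) updated once per token."""
    T, d = k.shape
    s, n = np.zeros(d), np.zeros(d)
    decay = np.exp(-w)
    y = np.zeros((T, d))
    for t in range(T):
        s = decay * s + np.exp(k[t]) * v[t]               # s_t = e^{-w} s_{t-1} + e^{k_t} v_t
        n = decay * n + np.exp(k[t])                      # n_t = e^{-w} n_{t-1} + e^{k_t}
        y[t] = sigmoid(r[t]) * s / n
    return y


rng = np.random.default_rng(0)
T, d = 32, 8
r, k, v = (rng.normal(size=(T, d)) for _ in range(3))
w = rng.uniform(0.05, 0.5, size=d)                        # positive per-channel decay rates
print(np.allclose(wkv_explicit(r, k, v, w), wkv_recurrent(r, k, v, w)))  # True
print(np.exp(-100 * 0.1))                                 # ≈ 4.54e-05 (w = 0.1, 100 steps)
```

The explicit version touches every past token for every position. The recurrent version only ever stores `s` and `n`, which is where RWKV's constant-memory inference comes from.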
Why This Matters:
Tasks requiring random access (multi-hop reasoning, factual recall): Softmax wins
Tasks with strong locality (next-token prediction, code completion): RWKV competitive
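RWKV’s decay rates are learned but input-independent. The same \(w\) applies to every sequence and every query, so recency bias is a hard inductive bias. Unlike attention patterns, the model cannot switch it off per query.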
Efficiency-Quality Frontier:
| Model | Training Cost (length \(n\)) | Inference Cost per Token | Selectivity |
|---|---|---|---|
| Softmax Attention | \(O(n^2)\) | \(O(n)\) | Full content-based |
| RWKV | \(O(n)\) | \(O(1)\) | Content + mandatory time decay |
20.5 Comparison: RWKV vs. Mamba vs. Linear Attention
| Model | Mechanism | Inference Memory | Parallelizable? |
|---|---|---|---|
| Full Attention | Softmax Q/K/V | \(O(n d)\) | Yes |
| Linear Attention | Kernel \(\phi(q)^\top \phi(k)\) | \(O(d_\phi d_v)\) | Yes |
| Delta Attention | Gated state | \(O(d_\phi d_v)\) | Partial |
| RWKV | Exponential decay | \(O(d)\) | Yes (scan) |
| Mamba | Selective SSM | \(O(N)\) | Yes (scan) |
Where to Use RWKV:
RWKV is a **linear attention variant** with exponential time decay, positioned between full transformers and SSMs.
Does it replace full attention? Not in production. RWKV models (e.g., RWKV-4, RWKV-5) are research projects showing promising results but underperform GPT/LLaMA on most benchmarks.
Use Cases (research/niche):
Language modeling: Competitive on perplexity for long documents (e.g., books, code) where recency bias helps
Streaming inference: \(O(1)\) memory makes it viable for resource-constrained deployment (edge devices)
Long sequences: More efficient than transformers at 100K+ tokens, but generally lower quality than Mamba at the same lengths
Limitations vs. Full Attention:
Exponential decay bakes in recency bias: RWKV cannot attend equally to all positions the way transformers can
Weaker on tasks requiring random access (question answering, multi-hop reasoning)
No major frontier model is known to use RWKV (GPT-4, Claude and Gemini are understood to be attention-based)
RWKV vs. Mamba:
Mamba: Better quality on most benchmarks (especially audio, time-series, genomics). Selective SSM more flexible than fixed exponential decay.
RWKV: Simpler architecture, easier to understand/implement. Good for learning about linear attention.
Current Status (2026): RWKV is an **interesting research direction** but hasn’t displaced transformers. Most practical long-context systems use:
Sliding window attention (Mistral, Qwen) – better quality tradeoff
Flash Attention – optimizes exact attention, no approximation
Hybrid models (dual-chunk, Delta attention) – combine exact + approximate
Interview Takeaway: RWKV shows that linear attention is viable. In practice, however, optimizing exact attention (Flash Attention) or using hybrid approaches (dual-chunk attention, or interleaving Mamba and attention layers) works better.
21 Mixture-of-Experts (MoE)
21.1 Scaling Capacity Without Proportional Compute
Key Idea: MoE replaces the dense FFN sublayer with sparse expert selection, keeping attention unchanged.
MoE Transformer Layer Structure:
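```python
# Standard transformer layer (residual connections; normalization omitted for brevity)
def transformer_layer(x, attention, ffn):
    x = x + attention(x)  # multi-head attention (unchanged in MoE)
    x = x + ffn(x)        # dense FFN: all parameters used for every token
    return x

# MoE transformer layer
def moe_transformer_layer(x, attention, moe):
    x = x + attention(x)  # SAME attention with QK^T (n x n token mixing)
    x = x + moe(x)        # sparse expert routing replaces the dense FFN
    return x
```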
Critical point: MoE only replaces the FFN sublayer. The attention mechanism (including \(QK^\top\) and cross-token mixing) remains identical to standard transformers.
Why This Helps – Capacity vs. Computation Trade-off:
Problem: Dense FFN uses all parameters for every token
FFN with 4\(d\) hidden dim: \(2 \times d \times 4d = 8d^2\) parameters
All \(8d^2\) params activated per token → expensive for large models
MoE Solution: Split FFN into \(E\) experts, route each token to top-\(k\) (typically \(k=2\))
Total parameters: \(E \times 8d^2\) (e.g., 8 experts = \(8 \times\) capacity)
Active per token: \(k \times 8d^2 = 2 \times 8d^2\) (only 25% of experts)
Result: \(8\times\) more capacity, only \(2\times\) compute cost
Complexity Impact:
Dense FFN: \(O(nd^2)\) for \(n\) tokens, dimension \(d\)
MoE FFN: \(O(n \cdot k \cdot d^2)\) where \(k \ll E\) (sparse activation)
Example: 64 experts with top-2 routing give 64\(\times\) the FFN parameters at 2\(\times\) the FFN compute of a dense layer (32\(\times\) more parameters per unit of compute)
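The capacity and compute ratios above follow from simple parameter counting. Here is a minimal sketch, assuming two-matrix FFN experts with hidden size \(4d\) and ignoring biases, router weights and attention parameters. The ratio `total / active` is the "parameters per unit of compute" gain.

```python
def ffn_params(d, hidden_mult=4):
    """Two weight matrices: d x (4d) and (4d) x d; biases ignored."""
    return 2 * d * (hidden_mult * d)


def moe_ffn_params(d, num_experts, top_k):
    """Return (total, active-per-token) FFN parameters for an MoE layer of FFN experts."""
    per_expert = ffn_params(d)
    return num_experts * per_expert, top_k * per_expert


d = 4096
dense = ffn_params(d)                       # 8 d^2
for E, k in [(8, 2), (64, 2)]:
    total, active = moe_ffn_params(d, E, k)
    print(E, k, total / dense, active / dense, total / active)
# 8 2 8.0 2.0 4.0
# 64 2 64.0 2.0 32.0
```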
Why Not Replace Attention?
Attention is already \(O(n^2 d)\) – bottleneck is sequence length, not model capacity
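Attention must mix information across tokens. Sparsifying it means choosing which tokens each query sees, which is a separate problem from choosing which parameters a token uses.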
FFN is tokenwise (no cross-token dependency) → safe to make sparse
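Practical Example (GPT-3 scale, illustrative):
Dense model: 175B params, all active → about \(2 \times 175\)B = 350B FLOPs per token (forward pass, roughly 2 FLOPs per active parameter)
MoE variant: 1.7T params (10\(\times\)), with routing configured so that twice as many params are active per token → about 700B FLOPs per token (2\(\times\))
Net gain: 10\(\times\) parameters for 2\(\times\) compute, i.e. 5\(\times\) more parameters per FLOP. Extra parameters do not translate one-for-one into quality.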
MoE replaces dense feed-forward layers with many expert networks, routing each token to a small subset.
Basic MoE Layer: \[\begin{align} g & = W_g x \quad \text{(gating logits)} \\ p & = \mathrm{softmax}(g) = \mathrm{softmax}(W_g x) \quad \text{(probabilities)} \\ \mathcal{E} & = \text{TopK}(p, k) \quad \text{(select top-$k$ indices)} \\ y & = \sum_{e \in \mathcal{E}} p_e \cdot f_e(x) \end{align}\]
where:
\(x \in \mathbb{R}^d\) is the input token
\(W_g \in \mathbb{R}^{E \times d}\) produces \(E\) gating logits (one per expert)
\(p \in \mathbb{R}^E\) is the probability distribution over experts
\(\mathcal{E} \subseteq \{1, \ldots, E\}\) with \(|\mathcal{E}| = k\) are the selected expert indices
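\(p_e\) is the gate probability of selected expert \(e \in \mathcal{E}\). Many implementations, including the data flow below, first renormalize over the selected set, \(\tilde{p}_e = p_e / \sum_{e' \in \mathcal{E}} p_{e'}\), so that the \(k\) weights sum to 1.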
\(f_e: \mathbb{R}^d \to \mathbb{R}^d\) is expert \(e\)’s feed-forward network
MoE Data Flow (8 experts, top-2 routing):
Input: Token \(x \in \mathbb{R}^{512}\)
Compute gating logits: \(g = W_g x \in \mathbb{R}^8\) (one score per expert)
Softmax: \(p = \mathrm{softmax}(g) = [0.31, 0.02, 0.28, 0.05, 0.19, 0.08, 0.04, 0.03]\)
Select top-2: \(\mathcal{E} = \{1, 3\}\) (experts with highest probabilities)
Compute expert outputs:
\(f_1(x) = W_{2,1} \text{ReLU}(W_{1,1} x)\) (expert 1)
\(f_3(x) = W_{2,3} \text{ReLU}(W_{1,3} x)\) (expert 3)
Renormalize selected probabilities: \(\tilde{p}_1 = 0.31/(0.31+0.28) \approx 0.53\), \(\tilde{p}_3 \approx 0.47\)
Weighted combination: \(y = 0.53 \cdot f_1(x) + 0.47 \cdot f_3(x)\)
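The data flow above can be reproduced in a short NumPy sketch. `top_k_route` uses the given gate probabilities directly to recover the 0.53/0.47 weights. The toy layer uses hypothetical random ReLU experts and a random router, so its outputs are meaningless except for their shape.

```python
import numpy as np


def softmax(g):
    z = np.exp(g - g.max())
    return z / z.sum()


def top_k_route(p, k):
    """Pick the k most probable experts and renormalize their gate weights."""
    idx = np.argsort(p)[::-1][:k]            # indices of the k largest probabilities
    weights = p[idx] / p[idx].sum()          # renormalize over the selected experts
    return idx, weights


def moe_layer(x, W_g, experts, k=2):
    """y = sum over selected experts of (renormalized gate) * expert(x)."""
    p = softmax(W_g @ x)
    idx, weights = top_k_route(p, k)
    return sum(wt * experts[e](x) for e, wt in zip(idx, weights)), idx, weights


# Reproduce the routing numbers from the data flow (experts are 1-based in the text).
p = np.array([0.31, 0.02, 0.28, 0.05, 0.19, 0.08, 0.04, 0.03])
idx, weights = top_k_route(p, k=2)
print(idx + 1, weights.round(2))  # [1 3] [0.53 0.47]

# Toy layer with hypothetical random ReLU experts (d = 16, hidden 64, E = 8).
rng = np.random.default_rng(0)
d, E = 16, 8


def make_expert(W1, W2):
    return lambda x: W2 @ np.maximum(W1 @ x, 0.0)


experts = [make_expert(rng.normal(size=(4 * d, d)), rng.normal(size=(d, 4 * d))) for _ in range(E)]
y, chosen, gates = moe_layer(rng.normal(size=d), rng.normal(size=(E, d)), experts)
print(y.shape, chosen, gates.sum())
```

Only the `k` selected expert functions are ever called, which is where the compute saving comes from.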
Key Points:
Only 2 of 8 experts compute (25% density)
Total capacity = 8 experts, but compute cost = 2 experts
Experts can specialize (hypothetically, expert 1 on code and expert 3 on math), though learned specialization is usually less clean-cut than this
Analogy to Attention:
Gate \(W_g x\) plays role of \(q^\top k\) (content-based routing)
Softmax over experts = softmax over keys
Expert outputs \(f_e(x)\) = values \(v_j\)
Key difference: Routing to functions (experts) not tokens (keys)
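Gradient Flow Through TopK:
The TopK selection itself is discrete and has no gradient. Standard MoE layers still do not need a straight-through estimator: the router is trained through the gate values that scale the selected experts’ outputs.
Forward Pass: Hard TopK selection
\(\mathcal{E} = \text{TopK}(p, k)\): discrete selection, only \(k\) experts compute
\(y = \sum_{e \in \mathcal{E}} p_e \cdot f_e(x)\): sparse computation
Backward Pass: Differentiate through the selected gates
Selected experts receive gradients for their FFN weights; unselected experts receive none for this token
Router logits receive \(\frac{\partial \mathcal{L}}{\partial g_{e'}} = \sum_{e \in \mathcal{E}} \frac{\partial \mathcal{L}}{\partial p_e} \cdot \frac{\partial p_e}{\partial g_{e'}}\), with \(\frac{\partial p_e}{\partial g_{e'}} = p_e(\delta_{ee'} - p_{e'})\). This is nonzero for every \(e'\) because the softmax is taken over all \(E\) experts before selection.
So router weights \(W_g\) for unselected experts are still updated: their logits are pushed down when the selected experts help, and up when they hurt
Why This Works:
Forward: Sparse (only \(k\) of \(E\) experts compute), which saves computation
Backward: The softmax couples all logits, so the router learns relative preferences across all experts from every token
Limitation: An expert that is never selected never receives gradients for its own FFN weights. This is one reason for the load-balancing loss below; some systems also add noise to the gating logits to encourage exploration.
Implementation: No custom autograd function is needed. In PyTorch, compute the softmax, take the indices from `torch.topk`, and multiply the expert outputs by the gathered probabilities; autograd then produces the gradients above. If the softmax is applied only to the top-\(k\) logits, as in Mixtral, unselected logits get no gradient from the task loss.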
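A small finite-difference check makes the gradient argument concrete. It assumes the softmax-before-TopK gating of the Basic MoE Layer, fixed hypothetical expert outputs for a single token, and a squared-error loss. Only the router logits are differentiated, and the TopK selection is held fixed.

```python
import numpy as np


def softmax(g):
    z = np.exp(g - g.max())
    return z / z.sum()


def moe_loss(g, selected, expert_out, target):
    """Squared error of y = sum_{e in selected} p_e * f_e(x), with p = softmax over ALL experts."""
    p = softmax(g)
    y = sum(p[e] * expert_out[e] for e in selected)
    return 0.5 * np.sum((y - target) ** 2)


def router_grad(g, selected, expert_out, target):
    """Analytic dL/dg via the softmax Jacobian dp_e/dg_e' = p_e (delta_ee' - p_e')."""
    p = softmax(g)
    y = sum(p[e] * expert_out[e] for e in selected)
    dL_dp = np.zeros_like(p)
    for e in selected:                        # unselected experts: dL/dp_e = 0
        dL_dp[e] = (y - target) @ expert_out[e]
    jacobian = np.diag(p) - np.outer(p, p)    # symmetric E x E matrix
    return jacobian @ dL_dp


rng = np.random.default_rng(0)
E, m, k = 4, 3, 2
g = np.array([2.0, 0.5, 1.5, -1.0])           # router logits for one token
selected = np.argsort(g)[::-1][:k]            # hard TopK: experts 0 and 2
expert_out = rng.normal(size=(E, m))          # hypothetical outputs f_e(x)
target = rng.normal(size=m)

grad = router_grad(g, selected, expert_out, target)
eps = 1e-6
numeric = np.array([
    (moe_loss(g + eps * np.eye(E)[i], selected, expert_out, target)
     - moe_loss(g - eps * np.eye(E)[i], selected, expert_out, target)) / (2 * eps)
    for i in range(E)
])
print(np.allclose(grad, numeric, atol=1e-6))  # True
print(grad[[1, 3]])                           # nonzero: unselected logits still learn
```

If the softmax were applied only to the top-\(k\) logits (the Mixtral-style variant), the entries for experts 1 and 3 would be exactly zero.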
21.2 Load-Balancing: Preventing Expert Collapse
Without constraints, softmax routing exhibits rich-get-richer dynamics: popular experts get more tokens, improve faster and become more popular. MoE systems therefore add an auxiliary loss, shown here in the Switch Transformer form: \[\mathcal{L}_{\text{balance}} = E \sum_{e=1}^E r_e \cdot P_e\] where:
\(r_e = \frac{1}{n} \sum_{i=1}^n \mathbb{1}[e \in \mathcal{E}_i]\) is the fraction of tokens routed to expert \(e\)
\(P_e = \frac{1}{n} \sum_{i=1}^n p_{i,e}\) is the average gate probability for expert \(e\)
\(n\) is total number of tokens in batch
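Intuition: The sum \(\sum_e r_e P_e\) is minimized when routing is uniform. Because \(r_e\) comes from a hard TopK, it carries no gradient, so the loss acts through \(P_e\). It lowers each expert’s gate probability in proportion to the fraction of tokens that expert already receives. Overloaded experts are pushed down, which encourages uniform load while still allowing learned specialization.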
How Auxiliary Loss is Incorporated:
Added to the language modeling loss: \[\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{LM}} + \alpha \cdot \mathcal{L}_{\text{balance}}\]
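The Switch Transformer uses \(\alpha = 10^{-2}\). The coefficient is kept small so that balancing does not override the language modeling objective. Router gradients come from both the task loss and the balance loss.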
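Here is a minimal NumPy sketch of the auxiliary loss. It assumes top-1 routing and the Switch Transformer convention of multiplying by the number of experts \(E\), so perfectly balanced routing scores 1. In training only \(P_e\) would carry gradient; \(r_e\) comes from a hard argmax.

```python
import numpy as np


def balance_loss(probs):
    """Switch-style auxiliary loss E * sum_e r_e * P_e for top-1 routing.

    probs: (n_tokens, E) router softmax outputs for one batch.
    """
    n, E = probs.shape
    chosen = probs.argmax(axis=1)                    # top-1 expert per token
    r = np.bincount(chosen, minlength=E) / n         # fraction of tokens sent to each expert
    P = probs.mean(axis=0)                           # mean gate probability per expert
    return E * np.sum(r * P)


balanced = np.full((4, 4), 0.1) + 0.6 * np.eye(4)    # token i prefers expert i
collapsed = np.tile([0.7, 0.1, 0.1, 0.1], (4, 1))    # every token prefers expert 0
print(balance_loss(balanced))    # ≈ 1.0
print(balance_loss(collapsed))   # ≈ 2.8
```

The collapsed batch scores 2.8 times worse than the balanced one, which is the signal that pushes the router away from expert collapse.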
22 DeepSeek: Combining MoE with Sparse Attention
22.1 The Capacity vs. Context Length Problem
MoE solves capacity but not context length. DeepSeek augments MoE with a sparse, locality-aware attention pattern that reduces the \(O(n^2)\) cost while preserving the essence of content-based selection.
Instead of attending to all previous tokens, DeepSeek’s attention partitions the context into:
a local window: recent tokens (dense attention),
a grouped horizon: older tokens summarized at coarser granularity,
optional global tokens: syntactic/semantic anchors (e.g., BOS, instructions).
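This section calls the resulting pattern Deterministic Sparse Attention (DSA). DeepSeek’s own published designs build on related ideas: Native Sparse Attention (NSA), and DeepSeek Sparse Attention (also abbreviated DSA, introduced with DeepSeek-V3.2). Both choose which tokens to attend to using learned scores, whereas the pattern here is a simplified, deterministic version.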
22.2 Sparse Attention Pattern: Local Windows + Grouped Representatives
DeepSeek’s Deterministic Sparse Attention (DSA) partitions context into:
Local Window: Recent \(w\) tokens (dense, full attention)
Grouped Horizon: Older tokens summarized at coarser granularity
Global Tokens: Optional anchors (BOS, instruction markers)
Why This Works – Attention Score Decay: In practice, many heads place most of their attention mass on nearby tokens: \[\alpha_{tj} = \frac{\exp(q_t^\top k_j / \sqrt{d_k})}{\sum_{l=1}^t \exp(q_t^\top k_l / \sqrt{d_k})} \quad \text{is typically small for } j \ll t\]
Most attention mass concentrates on recent tokens \(j \in [t-w, t]\)
Distant tokens contribute negligibly individually but matter collectively
Solution: Group distant tokens, attend to representatives
Natural Multi-Resolution Memory: \[\text{Recent: individual tokens} \;|\; \text{Mid-range: pooled groups} \;|\; \text{Distant: coarse summaries}\]
This mirrors other architectures:
SSMs: Compressed long-term state + precise local updates
Delta attention: Gated forgetting of distant context
Convolutions: Receptive fields widen with depth
22.3 DeepSeek DSA Pattern (Simplified Form)
For token \(t\), DSA restricts attention to: \[\mathcal{N}(t) = \underbrace{\{t-w+1,\ldots,t\}}_{\text{local window ($w$ tokens)}} \;\cup\; \underbrace{\{g_1(t), g_2(t), \ldots\}}_{\text{group representatives}} \;\cup\; \underbrace{\text{global tokens}}_{\text{optional}}.\]
The attention update becomes: \[y_t = \sum_{j \in \mathcal{N}(t)} \frac{\exp(q_t^\top k_j / \sqrt{d_k})} {\sum_{l \in \mathcal{N}(t)} \exp(q_t^\top k_l / \sqrt{d_k})} \; v_j.\]
Key insight: Grouping allocates resolution by distance. Near tokens are kept individually, while distant tokens are seen only through representatives. This matches the observation that the attention mass on any single distant token tends to be small.
Concrete Example: Position \(t=1000\), window \(w=256\), stride \(s=64\)
Neighborhood \(\mathcal{N}(1000)\) includes:
Local window: tokens 745-1000 (256 tokens, dense attention to all)
Group 1: token 744 (newest token outside the window, representing tokens 681-744)
Group 2: token 680 (stride 64 back)
Group 3: token 616
\(\vdots\) (continue every 64th token down to token 40: \(\lceil 744/64 \rceil = 12\) representatives covering tokens 1-744)
Global: token 0 (BOS, always included)
Complexity Reduction:
Full attention: 1000 tokens \(\times\) 1000 queries = 1M attention scores
DSA: each query scores at most \(256 + \lceil 744/64 \rceil + 1 = 269\) keys, so at most \(269 \times 1000 = 269{,}000\) scores
Reduction: \(\sim\)73% fewer computations while preserving long-range access
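The neighborhood in the worked example can be generated and checked mechanically. The sketch below makes three assumptions:

- Representatives are single strided tokens; a real system might pool each group instead.
- Token 0 is the only global token.
- The function names are hypothetical.

It is illustrative, not DeepSeek's kernel.

```python
import numpy as np


def dsa_neighborhood(t, w, s, global_tokens=(0,)):
    """Simplified DSA neighborhood for query position t.

    Local window: the last w tokens (t-w+1 .. t).
    Representatives: the newest token outside the window, then every s-th token back (> 0).
    Global tokens: e.g. BOS at position 0.
    """
    local = range(t - w + 1, t + 1)
    reps = range(t - w, 0, -s)
    return sorted(set(local) | set(reps) | {g for g in global_tokens if g <= t})


def sparse_attention_step(q, K, V, neighborhood):
    """Softmax attention for one query restricted to the given key positions."""
    idx = np.asarray(neighborhood)
    scores = K[idx] @ q / np.sqrt(q.shape[0])
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return weights @ V[idx]


nbr = dsa_neighborhood(t=1000, w=256, s=64)
print(len(nbr), nbr[:3], nbr[-257:-255])  # 269 [0, 40, 104] [744, 745]

rng = np.random.default_rng(0)
d = 16
K = rng.normal(size=(1001, d))             # hypothetical keys for positions 0..1000
V = rng.normal(size=(1001, d))
q = rng.normal(size=d)
y = sparse_attention_step(q, K, V, nbr)
print(y.shape, round(1 - len(nbr) / 1000, 3))  # (16,) 0.731
```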
22.4 Why Groups Allow Efficient Long-Context Reasoning
Grouping creates a pyramid-like memory: \[\text{recent tokens: individual} \;\;|\;\; \text{mid-range: pooled} \;\;|\;\; \text{distant: highly pooled}.\]
This yields:
lower complexity: \(O(w + \text{groups})\) instead of \(O(n)\) per token,
adaptive resolution: attention can still “jump” to representative summary tokens,
MoE synergy: experts specialize on different ranges (local syntax vs. long-range discourse).
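22.5 DeepSeek-V2 and Successors: Integrating MoE and Sparse Attention
This part of the DeepSeek story combines three techniques. Fine-grained MoE and MLA are the core of DeepSeek-V2, which runs dense attention over MLA’s compressed cache. The sparse attention pattern (item 3) comes from DeepSeek’s later work.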
1. Fine-Grained MoE
160 routed experts with top-6 routing per token (3.75% density), plus 2 shared experts that every token uses. Fine-grained experts encourage specialization across granular linguistic functions.
2. Multi-Head Latent Attention (MLA)
Problem: Standard attention stores full KV cache: each token stores \(K_i \in \mathbb{R}^{d_k}\) and \(V_i \in \mathbb{R}^{d_v}\) per head. For \(h\) heads and sequence length \(n\): \(O(n \cdot h \cdot (d_k + d_v))\) memory.
MLA Solution: Compress K/V through a low-rank bottleneck (\(d_c \ll h \cdot d_k\)): \[\begin{align} c_i^K & = W_K^{\mathrm{down}} x_i \in \mathbb{R}^{d_c} \quad \text{(compressed key)} \\ K_i & = W_K^{\mathrm{up}} c_i^K \in \mathbb{R}^{h d_k} \quad \text{(keys for all $h$ heads, reconstructed for attention)} \\ c_i^V & = W_V^{\mathrm{down}} x_i \in \mathbb{R}^{d_c} \\ V_i & = W_V^{\mathrm{up}} c_i^V \in \mathbb{R}^{h d_v} \end{align}\] DeepSeek-V2 goes further and shares a single latent \(c_i^{KV}\) between keys and values, plus a small decoupled key that carries RoPE.
Key Insight – Why Not Fuse?
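\(W^{\mathrm{up}} W^{\mathrm{down}}\) is just a low-rank map, so mathematically the two factors could be fused into one matrix
BUT: During inference, we cache \(c_i^K, c_i^V\) (compressed), not \(K_i, V_i\) (full)
When query \(q_t\) arrives, we can reconstruct \(K_i = W^{\mathrm{up}} c_i^K\) on the fly. Alternatively, as DeepSeek-V2 does, we can absorb \(W^{\mathrm{up}}\) into the query side, since \(q_t^\top W^{\mathrm{up}} c_i^K = (W^{\mathrm{up}\top} q_t)^\top c_i^K\), and score directly against the cached latents
Cache size: \(n \times d_c\) instead of \(n \times h \times d_k\) (savings: \(h \cdot d_k / d_c\) factor)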
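Here is a toy single-head NumPy sketch of the cache-the-latent idea, with hypothetical sizes and random weights. It checks the identity \(q^\top (W^{\mathrm{up}} c_i) = (W^{\mathrm{up}\top} q)^\top c_i\) that lets an implementation skip reconstructing keys. It ignores RoPE. RoPE applies a position-dependent rotation between query and key, which breaks this absorption, and that is why DeepSeek-V2 adds a separate decoupled RoPE key.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_c, d_k, n = 64, 16, 32, 10           # toy sizes, single head for clarity

W_down = rng.normal(size=(d_c, d_model)) / np.sqrt(d_model)  # x -> latent c
W_up_k = rng.normal(size=(d_k, d_c)) / np.sqrt(d_c)          # latent c -> key

X = rng.normal(size=(n, d_model))               # hidden states of n past tokens
latent_cache = X @ W_down.T                     # (n, d_c): this is all that is cached
q = rng.normal(size=d_k)                        # current query

# Option 1: reconstruct full keys from the cache, then score.
K = latent_cache @ W_up_k.T                     # (n, d_k)
scores_reconstructed = K @ q

# Option 2: absorb W_up into the query once, then score against the latents directly.
q_latent = W_up_k.T @ q                         # (d_c,)
scores_absorbed = latent_cache @ q_latent

print(np.allclose(scores_reconstructed, scores_absorbed))  # True
print(latent_cache.shape, K.shape)              # (10, 16) (10, 32)
```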
Example: DeepSeek-V2 with \(h=128\) heads, \(d_k=128\), \(d_c=512\):
Standard KV cache per token (keys only): \(h \times d_k = 128 \times 128 = 16{,}384\) scalar values
MLA cache per token (keys only): \(d_c = 512\) scalar values (compressed \(c^K\))
Reduction for keys: \(32\times\) smaller cache!
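Note: The same calculation applies to values. DeepSeek-V2 goes further: K and V share one 512-dimensional latent plus a 64-dimensional decoupled RoPE key. Its cache is therefore 576 values per token per layer, versus \(2 \times 16{,}384 = 32{,}768\) for standard multi-head K and V (about \(57\times\) smaller).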
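The cache arithmetic is easy to script. This plain-Python sketch assumes DeepSeek-V2's reported configuration:

- 128 heads of dimension 128,
- a 512-dimensional shared K/V latent,
- a 64-dimensional decoupled RoPE key.

Counts are scalar values per token per layer, before multiplying by the number of layers, context length and bytes per value.

```python
def mha_cache_per_token(n_heads, d_head):
    """Standard multi-head attention caches K and V for every head."""
    return 2 * n_heads * d_head


def mla_cache_per_token(d_latent, d_rope):
    """MLA with a shared K/V latent plus a decoupled RoPE key."""
    return d_latent + d_rope


mha = mha_cache_per_token(n_heads=128, d_head=128)   # 32768 values per token per layer
mla = mla_cache_per_token(d_latent=512, d_rope=64)   # 576 values per token per layer
print(mha, mla, round(mha / mla, 1))                 # 32768 576 56.9
print(128 * 128 // 512)                              # 32: keys-only ratio with a separate key latent
```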
Tradeoff:
+ Massive KV cache reduction for long contexts
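+ One latent is shared by all heads, so the cache does not grow with \(h\)
- Extra up-projection \(W^{\mathrm{up}} c_i\) during attention unless it is absorbed into the query/output projections. Long-context decoding is usually limited by memory bandwidth, so trading a little compute for a much smaller cache pays off.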
3. Deterministic Sparse Attention
Local window + grouped summaries. This creates a structure analogous to: \[\text{attention} \approx \text{local SSM-like recurrence} + \text{global sparse lookup}.\]
Unified View:
MoE increases depth of representation (specialized functions).
Sparse attention increases breadth of context (efficient long-range access).
Together, they behave like a “large-capacity neural information system with multiresolution access.”
Outcome: 236B total parameters, 21B active per token: massive functional capacity, efficient long-context behavior, and practical KV-cache requirements for deployment.
23 Qwen: Gated Architectures and Hybrid Attention
23.1 Qwen2.5 Architecture
Qwen2.5 (Alibaba) uses:
Grouped-Query Attention (GQA): 28 Q heads and 4 K/V heads (7:1 ratio) in Qwen2.5-7B; other model sizes use different head counts
SwiGLU gating in FFN: \(f(x) = W_2 \left(\mathrm{SiLU}(W_{1a} x) \odot (W_{1b} x)\right)\)
Dual-chunk attention (Qwen2.5-Turbo): Hybrid approach for long contexts
YaRN positional interpolation: Extends RoPE to 128K context via frequency scaling
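Grouped-query attention amounts to repeating each K/V head across its group of query heads. Here is a minimal NumPy sketch with the 28:4 head layout, assuming a single non-causal sequence and random inputs. Real implementations avoid materializing the repeated tensors. The K/V cache stores only the 4 K/V heads, a 7\(\times\) reduction versus 28 full heads.

```python
import numpy as np


def gqa_attention(Q, K, V):
    """Grouped-query attention for one sequence (no causal mask, for brevity).

    Q: (n_q_heads, T, d); K, V: (n_kv_heads, T, d); n_q_heads must be a multiple of n_kv_heads.
    """
    n_q, T, d = Q.shape
    n_kv = K.shape[0]
    group = n_q // n_kv                          # e.g. 28 // 4 = 7 query heads per K/V head
    K_rep = np.repeat(K, group, axis=0)          # query head h reads K/V head h // group
    V_rep = np.repeat(V, group, axis=0)
    scores = Q @ K_rep.transpose(0, 2, 1) / np.sqrt(d)     # (n_q, T, T)
    scores -= scores.max(axis=-1, keepdims=True)
    attn = np.exp(scores)
    attn /= attn.sum(axis=-1, keepdims=True)
    return attn @ V_rep                          # (n_q, T, d)


rng = np.random.default_rng(0)
T, d = 5, 16
K = rng.normal(size=(4, T, d))                   # only 4 K/V heads are cached
V = rng.normal(size=(4, T, d))
Q = np.repeat(rng.normal(size=(1, T, d)), 28, axis=0)   # same query in all 28 heads
out = gqa_attention(Q, K, V)
print(out.shape)                                  # (28, 5, 16)
print(np.allclose(out[0], out[6]), np.allclose(out[0], out[7]))  # True False
```

Heads 0 to 6 read the same K/V head, so identical queries produce identical outputs there. Head 7 reads the next K/V head and differs.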
23.2 Dual-Chunk Attention Mechanism
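Qwen2.5-Turbo uses dual-chunk attention to balance local precision and global context at 128K+ sequences. The published Dual Chunk Attention (DCA) used by Qwen’s long-context models splits the sequence into chunks and remaps relative positions (intra-chunk, inter-chunk and successive-chunk). This keeps RoPE distances within the range seen in training. The local + strided gated blend described here is a simplified way to reason about the same local/global tradeoff, not a transcription of Qwen’s implementation.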
Problem: Standard sliding window attention loses global context; full attention is too expensive for 128K tokens.
Solution: Compute two parallel attention outputs and blend them:
Local chunk (dense window):
Full attention within sliding window of size \(w\) (e.g., \(w=2048\))
Captures fine-grained, token-level dependencies
\(o_{\text{local}} = \text{Attention}(Q, K_{\text{window}}, V_{\text{window}})\)
Global chunk (sparse stride):
Sparse attention to every \(s\)-th token (e.g., \(s=64\)) across full context
Captures long-range dependencies efficiently
For 128K context: attend to \(128000/64 = 2000\) representative tokens
\(o_{\text{stride}} = \text{Attention}(Q, K_{\text{stride}}, V_{\text{stride}})\)
Gated fusion:
Learned gate \(\lambda = \sigma(W_\lambda x)\) blends outputs: \[y = x + \lambda \odot o_{\text{local}} + (1-\lambda) \odot o_{\text{stride}}\]
Gate learns when to prioritize local vs. global context
Illustratively, early layers might favor local context (\(\lambda \approx 0.8\)) while deeper layers balance both
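The gated local/strided blend described above can be sketched for a single query. This follows the simplified scheme in this section, not Qwen's production kernel. The window, stride, gate weights `W_lam` and function names are hypothetical, and the keys and values are random.

```python
import numpy as np


def softmax_attend(q, K, V):
    scores = K @ q / np.sqrt(q.shape[0])
    weights = np.exp(scores - scores.max())
    return (weights / weights.sum()) @ V


def local_stride_blend(x_t, q, K, V, t, window, stride, W_lam):
    """Illustrative gated local + strided attention for one query at position t.

    K, V hold keys/values for positions 0..t. W_lam is a hypothetical gate weight matrix.
    """
    local_idx = np.arange(max(0, t - window + 1), t + 1)   # dense recent window
    stride_idx = np.arange(0, t + 1, stride)                # every stride-th token
    o_local = softmax_attend(q, K[local_idx], V[local_idx])
    o_stride = softmax_attend(q, K[stride_idx], V[stride_idx])
    lam = 1.0 / (1.0 + np.exp(-(W_lam @ x_t)))             # per-channel gate in (0, 1)
    return x_t + lam * o_local + (1.0 - lam) * o_stride, len(local_idx), len(stride_idx)


rng = np.random.default_rng(0)
d, t = 16, 8191
K = rng.normal(size=(t + 1, d))
V = rng.normal(size=(t + 1, d))
x_t = rng.normal(size=d)
q = rng.normal(size=d)
y, n_local, n_stride = local_stride_blend(x_t, q, K, V, t, window=2048, stride=64,
                                          W_lam=rng.normal(size=(d, d)))
print(y.shape, n_local, n_stride)   # (16,) 2048 128
print(len(range(0, 128_000, 64)))   # 2000 strided keys at 128K context
```

The local branch costs a fixed `window` keys per query. The strided branch grows as `t / stride`, which is the \(O(n^2/s)\) term in the complexity analysis below.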
Complexity Analysis:
Local: \(O(n \cdot w)\) where \(w=2048\) → \(O(n)\) for fixed window
Global: \(O(n \cdot n/s)\) where \(s=64\) → \(O(n^2/64)\)
Total: \(O(n)\) local + \(O(n^2/s)\) sparse → practical for 128K contexts
Compare: Full attention at \(O(n^2)\) is expensive at 128K tokens; pure sliding window is \(O(n)\) but loses global context
Why This Works:
Most information is local (within 2K tokens)
Long-range dependencies captured by sampling representative tokens
Gating allows model to learn task-specific local/global balance
Similar to Longformer + BigBird hybrid but with learned blending
23.3 Qwen Gated Attention Block (practical notes)
Pre-norm + RoPE on Q/K: RMSNorm before projections; RoPE applied to \(Q, K\) maintains phase for long context.
Head scaling and GQA: Q heads are grouped to reuse K/V heads; per-head learned scale \(s_h\) rescales logits \(\frac{s_h}{\sqrt{d_k}} QK^\top\) to stabilize mixed-precision training.
Chunk-gated fusion: Dual-chunk attention produces \(o_{\text{local}}\) (dense window) and \(o_{\text{stride}}\) (sparse/global). A learned gate blends them: \[\lambda = \sigma(W_\lambda x), \quad y = x + \lambda \odot o_{\text{local}} + (1-\lambda) \odot o_{\text{stride}}.\] Gate biases favor local context early in training, relaxing toward global reads later.
Gated FFN (SwiGLU): Same gate form as above controls channel-wise flow, improving stability for high width-to-depth ratios used in Qwen2.x.
23.4 Qwen-MoE Variants
Qwen1.5-MoE:
14.3B total params, 2.7B active
64 experts, activate top-4 per token
Expert specialization emerges: domain-specific routing (code vs. language vs. math)
24 Vision Transformers (ViT)
24.1 From Convolutions to Patches: ViT Architecture
Vision Transformers (Dosovitskiy et al., 2020) apply standard transformers to images by treating image patches as tokens:
Input Processing:
Patchify: Split image \(H \times W \times 3\) into \(N\) non-overlapping patches of size \(P \times P\) \[N = \frac{H \cdot W}{P^2} \quad \text{(e.g., 224$\times$224 image, 16$\times$16 patches $\rightarrow$ 196 patches)}\]
Linear embedding: Flatten each patch to \(P^2 \cdot 3\) vector, project to \(d\)-dimensional embedding: \[x_i = W_{\text{patch}} \cdot \text{flatten}(\text{patch}_i) \quad \in \mathbb{R}^d\]
Prepend [CLS] token: Add learnable class token \(x_{\text{CLS}}\) at position 0 (like BERT)
Add positional embeddings: Learnable position embeddings (not sinusoidal): \[z_i = x_i + E_{\text{pos}}[i] \quad \text{for } i = 0, 1, \ldots, N\]
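The patchify and embedding steps map directly onto array reshapes. Here is a minimal NumPy sketch for a 224\(\times\)224 RGB image and 16\(\times\)16 patches. The projection weights are random placeholders. The [CLS] and position embeddings are zero-initialized here only for brevity; in ViT they are learned.

```python
import numpy as np


def patchify(img, P):
    """Split an (H, W, C) image into non-overlapping P x P patches, flattened row-major."""
    H, W, C = img.shape
    assert H % P == 0 and W % P == 0, "image size must be divisible by the patch size"
    patches = img.reshape(H // P, P, W // P, P, C).transpose(0, 2, 1, 3, 4)
    return patches.reshape(-1, P * P * C)          # (N, P^2 * C)


def vit_embed(img, P, W_patch, cls_token, pos_embed):
    """Linear patch embedding, prepend [CLS], add position embeddings."""
    x = patchify(img, P) @ W_patch.T                # (N, d)
    x = np.vstack([cls_token[None, :], x])          # (N + 1, d), [CLS] at position 0
    return x + pos_embed                            # z_i = x_i + E_pos[i]


rng = np.random.default_rng(0)
img = rng.normal(size=(224, 224, 3))
P, d = 16, 768
patches = patchify(img, P)
print(patches.shape)                                # (196, 768)
W_patch = rng.normal(size=(d, P * P * 3)) * 0.02    # hypothetical random init
z = vit_embed(img, P, W_patch, np.zeros(d), np.zeros((patches.shape[0] + 1, d)))
print(z.shape)                                      # (197, 768)
```

For ViT-B/16 the flattened patch length \(16 \cdot 16 \cdot 3 = 768\) happens to equal \(d\). In general the projection maps \(P^2 \cdot 3\) to \(d\).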
Transformer Encoder:
Standard transformer layers (multi-head self-attention + FFN)
No causal masking–full bidirectional attention across all patches
Output: Use final [CLS] token representation for classification
24.2 Key ViT Variants and Practices
ViT-Base/Large/Huge:
ViT-Base: 12 layers, \(d=768\), 12 heads, 86M params
ViT-Large: 24 layers, \(d=1024\), 16 heads, 307M params
ViT-Huge: 32 layers, \(d=1280\), 16 heads, 632M params
DeiT (Data-efficient ViT, Touvron et al., 2021):
Trains ViT on ImageNet-1K alone, without massive datasets such as JFT-300M, via distillation from a CNN teacher
Distillation token: Add \(x_{\text{dist}}\) alongside [CLS], match teacher’s predictions
Strong data augmentation (RandAugment, Mixup, CutMix)
Swin Transformer (Liu et al., 2021):
Hierarchical structure: start with small patches, merge into larger patches (like CNN pyramid)
Shifted windows: Local attention within windows, shift windows between layers for cross-window connections
Reduces complexity from \(O((HW)^2)\) to \(O(HW \cdot M^2)\), where \(HW\) is the number of patch tokens and \(M \times M\) is the window size in patches
Better for dense prediction (detection, segmentation)
MAE (Masked Autoencoders, He et al., 2022):
Self-supervised pre-training: mask 75% of patches, reconstruct pixel values
Asymmetric encoder-decoder: lightweight decoder for reconstruction
Strong transfer to downstream tasks (classification, detection)
CLIP (Radford et al., 2021):
Contrastive pre-training on image-text pairs (400M from web)
Image encoder: ResNet and ViT variants; the strongest released model uses ViT-L/14 (patch size 14\(\times\)14)
Text encoder: Transformer with causal masking
Loss: contrastive (match image/text embeddings), zero-shot transfer to classification
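CLIP's training objective and its zero-shot classifier both reduce to cosine similarities between normalized embeddings. The NumPy sketch below assumes the encoders have already produced embeddings; random stand-ins are used here. The value 0.07 is CLIP's initial temperature, which is learned during training.

```python
import numpy as np


def l2_normalize(x):
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


def log_softmax(z):
    z = z - z.max(axis=1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))


def clip_loss(img_emb, txt_emb, temperature=0.07):
    """Symmetric contrastive loss: image i should match text i within the batch."""
    img, txt = l2_normalize(img_emb), l2_normalize(txt_emb)
    logits = img @ txt.T / temperature               # (B, B) scaled cosine similarities
    diag = np.arange(len(img))
    loss_img_to_txt = -log_softmax(logits)[diag, diag].mean()
    loss_txt_to_img = -log_softmax(logits.T)[diag, diag].mean()
    return 0.5 * (loss_img_to_txt + loss_txt_to_img)


def zero_shot_predict(img_emb, class_prompt_emb):
    """Pick the class whose prompt embedding has the highest cosine similarity."""
    sims = l2_normalize(img_emb) @ l2_normalize(class_prompt_emb).T
    return sims.argmax(axis=1)


rng = np.random.default_rng(0)
img = rng.normal(size=(8, 32))                       # hypothetical encoder outputs
txt = img + 0.1 * rng.normal(size=(8, 32))           # matching captions: nearby embeddings
print(clip_loss(img, txt) < clip_loss(img, txt[::-1]))   # True: matched pairs score better
print(zero_shot_predict(img[:3], txt))               # [0 1 2]
```

For zero-shot classification, the `class_prompt_emb` rows would be text embeddings of prompts such as "a photo of a [class]".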
DINO/DINOv2 (Caron et al., 2021; Oquab et al., 2023):
Self-supervised via self-distillation (student-teacher with momentum update)
DINOv2: trained on 142M curated images, strong semantic features
Works as universal visual backbone for downstream tasks
24.3 Practical Considerations
When to use ViT vs. CNNs:
ViT: Better with large-scale pre-training, global context modeling, transfer learning
CNNs: Better inductive bias for small datasets, local spatial structure, efficiency on edge
Hybrid: ConvNeXt (modernized CNN matching ViT), or ViT with convolutional stem
Position Embeddings:
Original ViT: Learnable absolute positions
Interpolation for different resolutions: resize position embeddings at fine-tune time
Relative positional encodings (Swin): better for variable-size inputs
Compute Optimization:
Attention is \(O(N^2)\) where \(N = (H/P) \times (W/P)\)–large patch size \(P\) reduces tokens
Flash Attention for memory efficiency (same as LLMs)
Hierarchical designs (Swin, PVT) reduce attention cost at high resolution
Interview-Ready Facts:
ViT patch size 16 on ImageNet-224: \(14 \times 14 = 196\) patches
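[CLS] token aggregates global information (alternative: global average pooling)
ViT trained with the original recipe needs pre-training on large datasets (ImageNet-21K or JFT-300M) to match CNNs; DeiT narrows the gap on ImageNet-1K with distillation and strong augmentation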
CLIP’s zero-shot: encode text prompts "a photo of a [class]" and match cosine similarity
Swin’s shifted windows enable cross-window information flow while keeping local complexity
25 Unifying View: Content-Based Selection
All attention mechanisms discussed share the core principle of content-based selection with soft gates:
| Mechanism | “Query” | “Keys/Memory” | Gate Function |
|---|---|---|---|
| Full Attention | \(q_i\) | \(k_j, v_j\) | \(\mathrm{softmax}(q^\top k)\) |
| Linear Attention | \(\phi(q_i)\) | \(\phi(k_j), v_j\) | Kernel factorization |
| Delta Attention | \(q_t\) | State \(S_t\) | Gated update \(\gamma, \beta\) |
| SSM/Mamba | \(C_t\) | State \(s_t\) | Dynamics \(A_t, B_t, C_t\) |
| RWKV | \(r_t\) | Decayed history \((s_t, n_t)\) | Exponential time-weighting |
Fundamental Question:
How do we perform content-based lookup and aggregation over long sequences while controlling computational and memory costs?
Two Main Strategies:
Approximate attention: Linearize/kernelize softmax (Performer, linear attention)
Replace with recurrence: Compress history into fixed-size state (Mamba, RWKV, Delta)
26 Interview Cheat Phrases
26.1 Core Attention
“Attention is content-based lookup: queries retrieve relevant values via softmax-weighted keys.”
“Q/K/V factorization enables asymmetric roles–query asks, key describes, value stores–and multi-head decomposition.”
“Softmax acts as a relaxed selector gate, interpolating between hard argmax (low temp) and uniform (high temp).”
26.2 Transformers
“Original transformer has \(O(n^2 d)\) attention complexity–quadratic in sequence length is the bottleneck for long context.”
“Modern improvements: RoPE for positional encoding, SwiGLU for activation, GQA to reduce KV cache memory.”
26.3 Sparse/Structured Attention
“Sliding window attention reduces complexity to \(O(nwd)\) but limits direct long-range connections–stacking layers expands receptive field.”
“Flash Attention achieves 2-4\(\times\) speedup via IO-aware tiling, not by reducing FLOPs–hardware-algorithm co-design matters.”
26.4 Linear/Kernelized Attention
“Linear attention factorizes \(\exp(q^\top k) \approx \phi(q)^\top \phi(k)\), enabling recurrent state updates with constant memory per token.”
“Tradeoff: \(O(n)\) complexity but loses softmax’s selectivity–tends to underperform on complex reasoning tasks.”
26.5 Delta Attention
“Delta attention uses gated recurrent updates to compress KV cache into fixed-size state–forget gate \(\gamma\) decays old info, input gate \(\beta\) controls new contributions.”
“Kimi leverages Delta attention for 200K+ context with sub-linear memory.”
26.6 SSMs and Mamba
“State-space models replace attention with learned dynamical systems–classical SSMs have fixed dynamics, Mamba makes them input-dependent.”
“Mamba achieves \(O(nd N)\) complexity with state size \(N\)–trades explicit attention for implicit state compression.”
“SSMs excel at ultra-long sequences and streaming inference; attention excels at complex reasoning and in-context learning.”
26.7 RWKV
“RWKV is linear attention with exponential time-decay: RNN-like inference with \(O(1)\) memory per token, transformer-like parallel training.”
26.8 MoE
“MoE is attention over expert slots: softmax gate routes tokens to top-\(k\) of \(E\) experts–scales capacity without proportional compute.”
“Load balancing is critical: auxiliary loss encourages uniform expert utilization to prevent collapse.”
“DeepSeek-V2: 236B total params, 21B active–massive capacity via fine-grained sparse MoE.”
26.9 Modern LLMs
“LLaMA lineage (Mistral, Qwen, DeepSeek): decoder-only, RoPE, GQA, SwiGLU–standard stack for modern pre-training.”
“Qwen2.5: 28 Q heads, 4 K/V heads (GQA 7:1), dual-chunk attention for 128K context.”
“DeepSeek: MoE + multi-head latent attention (low-rank KV compression) + sparse attention.”
27 Summary: The Attention Evolution
Historical Timeline and Key Milestones:
2017: Original Transformer–Full \(O(n^2)\) attention, multi-head self-attention, position-wise FFN. Established Q/K/V paradigm. "Attention Is All You Need" (Vaswani et al.).
2018-2019: GPT & BERT–Demonstrated pre-training + fine-tuning. Decoder-only (GPT) vs. encoder-only (BERT) split.
2020-2021: Scaling Laws–LLMs benefit from scale (GPT-3). Sparse/structured attention (Longformer, BigBird) to handle longer sequences.
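2021-2022: Efficient Attention–Performer (linear attention via positive random features, FAVOR+), Flash Attention (IO-aware exact attention).
2022-2023: Modern LLM Stack–LLaMA popularizes the combination of RoPE, SwiGLU and pre-norm (RMSNorm); LLaMA 2 adds GQA in its larger models. This stack became standard for Mistral, Qwen and DeepSeek.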
2023: State-Space Models–Mamba (selective SSMs), RWKV (linear RNN-attention hybrid). Challenge attention’s dominance for ultra-long context.
2023-2024: MoE Scaling–DeepSeek-V2, Qwen-MoE, Mixtral. Sparse expert activation scales capacity without compute explosion.
2024: Hybrid Architectures–Delta attention (Kimi), multi-head latent attention (DeepSeek), dual-chunk attention (Qwen Turbo). Combine ideas for production systems.
Current Frontier (2024-2025):
Long-context efficiency: Delta attention, Mamba, RWKV variants
MoE with thousands of experts + learned routing policies
Hybrid attention-SSM architectures (e.g., interleaving Mamba and attention layers)
Hardware co-design: Flash Attention 2/3, tensor cores, custom kernels
Related Topics (Beyond This Document):
Inference Optimization: KV cache management, paged attention (vLLM), speculative decoding, continuous batching, quantization (INT8/FP8)
Post-Training: RLHF, RLAIF, DPO, reward modeling, preference optimization, safety fine-tuning
Hardware-Aware Design: GPU/TPU memory hierarchy, tensor cores, operator fusion, kernel optimization
Retrieval-Augmented Generation: RAG architectures, vector databases, embedding models, reranking
Multimodal Models: Vision-language models (Flamingo, GPT-4V), audio-text (Whisper), unified encoders
For questions, corrections, or suggestions: peymanr@gmail.com