18 Chapter 17: Constrained LLM Inference
19 Introduction
On-device LLM inference faces harsh constraints: limited memory (4-16GB), power budgets (5-15W mobile, sub-1W for wearables), thermal throttling, and offline operation. Success requires quantization (compress weights/activations), distillation (create efficient student models), and speculative decoding (accelerate generation with draft models).
This Document Covers:
Quantization formats (FP8, INT8, INT4, INT2, and role of BF16/FP16) and algorithms (AWQ, GPTQ, SmoothQuant)
Layer-wise precision strategies–which layers tolerate INT2/INT4 vs require FP16/BF16
Hardware considerations (GPUs, Apple Silicon, ARM SoCs, x86 edge)
Knowledge distillation for creating efficient draft models
Speculative decoding mechanics with importance sampling
KV cache management and eviction policies
Production deployment recipes
Target Audience:
ML engineers deploying LLMs to mobile/edge devices, researchers optimizing inference, and interviewees preparing for systems+ML roles. Assumes familiarity with transformers and basic numerical formats.
20 Decoding Strategies: Temperature and Sampling
20.1 Temperature Scaling
Temperature is applied to **final output logits** before sampling the next token, NOT in attention softmax (which uses fixed \(T=\sqrt{d_k}\) scaling).
Where Temperature Applies:
Transformer outputs logits \(z \in \mathbb{R}^{|V|}\) (one score per vocabulary token)
Scale by temperature: \(z' = z / T\)
Apply softmax: \(p = \text{softmax}(z') = \frac{\exp(z_i/T)}{\sum_j \exp(z_j/T)}\)
Sample next token from distribution \(p\)
Effect on Distribution: \[p_i(T) = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}\]
| Temperature | Distribution Shape | Behavior |
|---|---|---|
| \(T \to 0\) | Peaked (almost one-hot) | Deterministic, picks highest logit (greedy) |
| \(T = 1.0\) | Standard softmax | Default, uses model’s learned probabilities |
| \(T \in (0.7, 0.9)\) | Sharpened (higher confidence) | More focused, less random, common for factual tasks |
| \(T \in (1.1, 1.5)\) | Flattened (lower confidence) | More diverse/creative, common for creative writing |
| \(T \to \infty\) | Uniform | Random, ignores model’s learned preferences |
Concrete Example:
Logits: \(z = [5.0, 3.0, 1.0, 0.5]\) (4 tokens)
\(T=0.1\) (near-greedy): \(z' = [50, 30, 10, 5]\) → \(p \approx [1.0, 2 \times 10^{-9}, \approx 0, \approx 0]\) (effectively always token 0)
\(T=1.0\) (standard): \(z' = [5.0, 3.0, 1.0, 0.5]\) → \(p \approx [0.86, 0.12, 0.02, 0.01]\) (token 0 picked 86% of the time)
\(T=2.0\) (creative): \(z' = [2.5, 1.5, 0.5, 0.25]\) → \(p \approx [0.62, 0.23, 0.08, 0.07]\) (more mass on tokens 1-3)
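The probabilities above can be reproduced with a few lines of stdlib Python. This is a minimal sketch, and the function name `softmax_with_temperature` is our own. Subtracting the max logit before exponentiating does not change the result, but it prevents overflow when \(T\) is small.

```python
import math


def softmax_with_temperature(logits, temperature):
    """Return softmax(z / T) as a list of probabilities."""
    if temperature <= 0:
        raise ValueError("temperature must be > 0; use argmax for greedy decoding")
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [5.0, 3.0, 1.0, 0.5]
for T in (0.1, 1.0, 2.0):
    print(T, [round(p, 2) for p in softmax_with_temperature(logits, T)])
# T=1.0 -> [0.86, 0.12, 0.02, 0.01]; T=2.0 -> [0.62, 0.23, 0.08, 0.07]
```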
20.2 Sampling Strategies
Greedy Decoding (\(T \approx 0\)):
Pick \(\operatorname*{arg\,max}_i z_i\) at each step (deterministic)
Use for: factual Q&A, code generation, translation
Problem: repetitive, no exploration
Nucleus Sampling (Top-p):
Apply temperature: \(z' = z / T\), compute \(p = \text{softmax}(z')\)
Sort tokens by probability
Keep smallest set where cumulative probability \(\geq p\) (e.g., \(p=0.9\))
Renormalize and sample from filtered distribution
Top-k Sampling:
Keep only top-k tokens by probability (e.g., \(k=50\)), renormalize and sample
Simpler than top-p but less adaptive (always keeps k tokens regardless of distribution shape)
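The two filters can be sketched as follows. They act on an already temperature-scaled probability list. This is an illustrative stdlib implementation; the helper names are hypothetical and are not taken from any particular inference library. Note how top-p adapts: a peaked distribution keeps few tokens, while a flat one keeps many.

```python
import random


def _renormalize(probs, keep):
    kept = [p if i in keep else 0.0 for i, p in enumerate(probs)]
    total = sum(kept)
    return [p / total for p in kept]


def filter_top_k(probs, k):
    """Keep only the k most probable tokens, then renormalize."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return _renormalize(probs, set(order[:k]))


def filter_top_p(probs, top_p):
    """Keep the smallest set of most-probable tokens whose total mass >= top_p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep, cumulative = set(), 0.0
    for i in order:
        keep.add(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    return _renormalize(probs, keep)


def sample(probs, rng=random):
    """Draw one token index from a categorical distribution."""
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


probs = [0.5, 0.3, 0.15, 0.05]
print(filter_top_p(probs, 0.75))  # tokens 0 and 1 survive, renormalized to 0.625 / 0.375
```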
20.3 Production Recommendations
| Task | Temperature | Strategy |
|---|---|---|
| Code generation | 0.0-0.2 | Greedy or low-T + top-p=0.95 |
| Factual Q&A | 0.3-0.5 | Low-T + top-p=0.9 |
| Summarization | 0.5-0.7 | Mid-T + top-p=0.9 |
| Creative writing | 0.8-1.2 | Higher-T + top-p=0.95 |
| Brainstorming | 1.0-1.5 | High-T + top-p=1.0 (no filtering) |
| Chatbot (general) | 0.7-0.9 | Balanced + top-p=0.95 |
20.4 Common Misconceptions
Myth: Temperature affects the attention mechanism
Reality: Attention uses a fixed \(T=\sqrt{d_k}\) scaling; API temperature only scales the final logits
Myth: Higher temperature = higher quality
Reality: Higher temperature = more randomness; the best setting depends on the task
Myth: Temperature=0 disables randomness
Reality: Only in exact arithmetic. \(T \to 0\) is the greedy (argmax) limit, and APIs special-case \(T=0\) as argmax because \(z/0\) is undefined; outputs can still vary between runs due to nondeterministic GPU kernels and batching
Myth: Top-p replaces temperature
Reality: Top-p filters the distribution after temperature scaling; use both together
21 Quantization: Fundamentals
21.1 Why Quantize?
Memory Savings:
FP32 → INT8: \(4\times\) reduction (e.g., 7B model: 28GB → 7GB)
FP32 → INT4: \(8\times\) reduction (28GB → 3.5GB)
FP32 → INT2: \(16\times\) reduction (28GB → 1.75GB)
Bandwidth & Latency:
Narrower data types reduce DRAM traffic (memory-bound workloads dominate inference)
Example: 7B model at FP32 requires \(\sim\)28GB DRAM reads per token; INT8 drops to 7GB
Enables higher throughput at fixed power envelope
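The memory and bandwidth arithmetic above can be checked with a back-of-the-envelope calculator. This sketch counts weight bytes only, and ignores quantization scales, KV cache and activations. It also assumes batch-size-1 decoding, where every weight is read once per token. The 100 GB/s figure is an assumed LPDDR5-class bandwidth.

```python
def weight_bytes(num_params, bits_per_weight):
    """Bytes needed to store the weights alone (no scales, KV cache, or activations)."""
    return num_params * bits_per_weight / 8


def min_ms_per_token(num_params, bits_per_weight, bandwidth_gb_s):
    """Bandwidth lower bound on decode latency at batch size 1."""
    return weight_bytes(num_params, bits_per_weight) / (bandwidth_gb_s * 1e9) * 1e3


params = 7e9
for bits in (32, 16, 8, 4, 2):
    gb = weight_bytes(params, bits) / 1e9
    print(f"{bits:>2}-bit: {gb:5.2f} GB, >= {min_ms_per_token(params, bits, 100):6.1f} ms/token at 100 GB/s")
```

Halving the bit width halves both the footprint and the bandwidth-bound latency floor. This is why memory-bound decoding benefits so directly from quantization.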
Hardware Utilization:
Lower precision enables higher throughput: NVIDIA A100 delivers 312 TFLOPS (FP16) vs 19.5 TFLOPS (FP32), and 624 TOPS (INT8)
Mobile NPUs/DSPs optimized for INT8/INT4 operations (10-100 TOPS typical)
TFLOPS = Tera (trillion) Floating-Point Operations Per Second; TOPS = Tera Operations Per Second (integer)
21.2 Quantization Basics
Affine Quantization: \[x_q = \text{clip}\left(\text{round}\left(\frac{x}{s}\right) + z, q_{\min}, q_{\max}\right)\] \[x \approx s \cdot (x_q - z)\] where \(s\) is the scale (maps float range to integer range), \(z\) is the zero-point (asymmetric quantization), and \([q_{\min}, q_{\max}]\) is the quantized range (e.g., \([-128, 127]\) for INT8).
Symmetric vs Asymmetric:
Symmetric (\(z=0\)): Range \([-\alpha, \alpha]\) centered at zero; simpler matmul kernels
Asymmetric (\(z \neq 0\)): Can represent non-centered distributions (e.g., post-ReLU activations)
Preference: Symmetric for weights, asymmetric for activations (though symmetric is more common for simplicity)
Granularity:
Per-tensor: Single scale for entire tensor–fast but coarse
Per-channel: Separate scale for each output channel–better accuracy, minimal overhead
Per-group: Scale for groups of \(g\) elements (e.g., \(g=64\) or \(g=128\))–balances accuracy and storage
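Below is a minimal sketch of the affine formulas above for a single tensor or range. Per-channel or per-group granularity just means computing one `(scale, zero_point)` pair per channel or group. The helper names are hypothetical. Widening the asymmetric range to include 0.0, so that zero is exactly representable, is a common convention that we assume here.

```python
def affine_params(x_min, x_max, num_bits=8, symmetric=True):
    """Return (scale, zero_point, q_min, q_max) for a float range [x_min, x_max]."""
    q_min, q_max = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1
    if symmetric:
        alpha = max(abs(x_min), abs(x_max)) or 1.0   # guard: all-zero tensor
        return alpha / q_max, 0, q_min, q_max
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)  # keep 0.0 exactly representable
    scale = (x_max - x_min) / (q_max - q_min) or 1.0
    zero_point = round(q_min - x_min / scale)         # integer code that maps to 0.0
    return scale, zero_point, q_min, q_max


def quantize(x, scale, zero_point, q_min, q_max):
    """x_q = clip(round(x / s) + z, q_min, q_max)."""
    return max(q_min, min(q_max, round(x / scale) + zero_point))


def dequantize(x_q, scale, zero_point):
    """x ~= s * (x_q - z)."""
    return scale * (x_q - zero_point)


# Post-ReLU activations in [0, 6]: symmetric INT8 wastes the negative half of the grid.
for sym in (True, False):
    s, z, lo, hi = affine_params(0.0, 6.0, symmetric=sym)
    print("symmetric" if sym else "asymmetric", f"scale={s:.4f}", f"zero_point={z}")
```

For a non-negative range, the asymmetric scale is roughly half the symmetric one. So its worst-case rounding error (\(s/2\)) is also roughly halved.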
21.3 Numeric Formats
| Format | Bits | Range | Typical Use | Hardware |
|---|---|---|---|---|
| FP32 | 32 | \(\pm 3.4 \times 10^{38}\) | Baseline | Universal |
| FP16 | 16 | \(\pm 65504\) | Training, high-accuracy inference | GPUs, modern CPUs |
| BF16 | 16 | \(\pm 3.4 \times 10^{38}\) | Training (wide range) | TPUs, newer GPUs |
| FP8 (E4M3) | 8 | \(\pm 448\) | Activations, KV cache | H100, AMD MI300 |
| FP8 (E5M2) | 8 | \(\pm 57344\) | Gradients (wide range) | H100 |
| INT8 | 8 | \([-128, 127]\) | Weights, activations | Ubiquitous |
| INT4 | 4 | \([-8, 7]\) | Weights (matmul-heavy) | GPUs, NPUs |
| INT2 | 2 | \([-2, 1]\) or \(\{-1,0,1\}\) | Extreme compression | ASICs, specialized NPUs |
FP8 Details:
E4M3: 4-bit exponent, 3-bit mantissa–better precision, narrower range (for activations)
E5M2: 5-bit exponent, 2-bit mantissa–wider range, coarser precision (for gradients/outliers)
Used in NVIDIA H100 FP8 tensor cores, AMD MI300; enables near-FP16 accuracy with \(2\times\) memory savings
BF16 (Brain Float 16) Details:
Same exponent range as FP32 (8 bits) but only 7-bit mantissa (vs 23-bit in FP32)
Advantage: No overflow issues when converting FP32 → BF16 (same dynamic range)
Use case: Training on TPUs (Google) and newer GPUs (Ampere+); many recent LLM checkpoints are also released and served in BF16
Inference trade-off: FP16 has better precision (10-bit mantissa) for activations; BF16 preferred when training checkpoints already in BF16 or when using mixed-precision training
Typically: Use FP16 for inference unless the model was trained in BF16 (converting BF16 → FP16 can overflow values beyond \(\pm 65504\) or underflow very small ones)
FP16 vs BF16 for Inference:
FP16 is generally preferred for inference (better precision, wider hardware support). Use BF16 only if: (1) model trained in BF16, (2) hardware has native BF16 support (TPUs, Ampere+ GPUs), or (3) specific layers overflow in FP16 (rare for inference).
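The precision-versus-range trade-off can be seen directly by round-tripping values through each format. This sketch uses Python's `struct` module. Its `e` format is IEEE binary16 and raises `OverflowError` beyond the FP16 range. BF16 is emulated by rounding the FP32 bit pattern to its top 16 bits with round-to-nearest-even. The helper names are ours, and NaN/Inf handling is omitted.

```python
import struct


def to_fp16(x):
    """Round-trip a Python float through IEEE 754 binary16 (FP16)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bf16(x):
    """Round-trip a float through BF16: keep the top 16 bits of its FP32 encoding."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Round to nearest, ties to even, on the 16 discarded mantissa bits.
    bits = ((bits + 0x7FFF + ((bits >> 16) & 1)) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


x = 1 + 2 ** -10                 # needs a 10-bit mantissa
print(to_fp16(x), to_bf16(x))    # FP16 keeps it exactly; BF16 rounds it to 1.0
print(to_bf16(70000.0))          # BF16 shares FP32's range: 70144.0
try:
    to_fp16(70000.0)
except OverflowError:
    print("70000.0 overflows FP16 (max 65504)")
```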
22 Quantization Algorithms
22.1 Post-Training Quantization (PTQ)
PTQ quantizes a pre-trained model without additional training. Critical for fast deployment.
22.1.1 Round-to-Nearest (RTN)
Simplest baseline:
Compute scale \(s = \frac{\max(|W|)}{q_{\max}}\) for each tensor/channel
Round: \(W_q = \text{round}(W / s)\)
Fast but largest accuracy drop (especially INT4/INT2)
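RTN with per-group scales fits in a few lines. This is a minimal sketch, not a library API. It uses symmetric scales \(s = \max|W_g| / q_{\max}\), clips codes to the signed range, and does no clipping-threshold search. The tiny group size of 4 is only for readability. The usage example shows why granularity matters: one large weight inflates a shared scale and wipes out the small weights next to it.

```python
def rtn_quantize_groups(weights, num_bits=4, group_size=4):
    """Symmetric round-to-nearest with one scale per group of `group_size` weights.

    Returns (codes, scales); weight i is approximated by codes[i] * scales[i // group_size].
    """
    q_max = 2 ** (num_bits - 1) - 1  # e.g. 7 for INT4
    codes, scales = [], []
    for start in range(0, len(weights), group_size):
        group = weights[start:start + group_size]
        scale = max(abs(w) for w in group) / q_max
        if scale == 0.0:
            scale = 1.0  # all-zero group: any scale works
        scales.append(scale)
        codes.extend(max(-q_max - 1, min(q_max, round(w / scale))) for w in group)
    return codes, scales


def rtn_dequantize(codes, scales, group_size=4):
    return [c * scales[i // group_size] for i, c in enumerate(codes)]


w = [0.1, -0.2, 0.05, 0.3, 2.0, -1.5, 0.7, 0.0]
codes, scales = rtn_quantize_groups(w, num_bits=4, group_size=4)
print(codes)  # [2, -5, 1, 7, 7, -5, 2, 0]
```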
22.1.2 GPTQ (Second-Order Quantization)
Key Idea: Minimize reconstruction error by considering weight interactions (Hessian-based).
Algorithm:
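For each layer, build the Hessian of the layer-wise reconstruction error \(\|WX - \hat{W}X\|_2^2\), which is \(H = 2XX^\top\) computed from calibration inputs \(X\) (plus a small diagonal dampening term)
Quantize weights blockwise (e.g., 128 columns at a time)
For each weight \(w_i\), quantize and compute the error \(e = w_i - \text{quant}(w_i)\)
Compensate the not-yet-quantized weights using the inverse Hessian: \(w_j \leftarrow w_j - e \cdot \frac{[H^{-1}]_{ij}}{[H^{-1}]_{ii}}\) (the Optimal Brain Surgeon update, which minimizes the added output error)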
Benefits:
Achieves near-optimal weight rounding for INT4/INT3
1-2% perplexity improvement vs RTN at INT4
Widely used: Hugging Face Optimum, llama.cpp, TensorRT-LLM
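To make the compensation step concrete, here is an illustrative pure-Python sketch of the GPTQ idea for one weight row. It is not the reference implementation. Real GPTQ shares one \(H\) across all rows, processes columns in lazy blocks, and reads the needed rows of \(H^{-1}\) from a Cholesky factorization. This version instead explicitly downdates \(H^{-1}\) after each column, which is \(O(d^3)\) per row. The 1% dampening of the mean diagonal is a common choice; the helper names are hypothetical.

```python
def invert(m):
    """Gauss-Jordan inverse of a small dense matrix given as a list of rows."""
    n = len(m)
    a = [list(row) + [1.0 if i == j else 0.0 for j in range(n)] for i, row in enumerate(m)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(a[r][col]))
        a[col], a[pivot] = a[pivot], a[col]
        p = a[col][col]
        a[col] = [v / p for v in a[col]]
        for r in range(n):
            if r != col:
                f = a[r][col]
                a[r] = [vr - f * vc for vr, vc in zip(a[r], a[col])]
    return [row[n:] for row in a]


def quant(w, scale, q_max):
    """Symmetric round-to-nearest onto scale * {-q_max-1, ..., q_max}."""
    return scale * max(-q_max - 1, min(q_max, round(w / scale)))


def gptq_row(w, X, scale, q_max=7, damp=0.01):
    """Quantize one output row `w` (length d) given calibration inputs X (d rows x n samples)."""
    d, n = len(w), len(X[0])
    # Hessian of ||w X - w_hat X||^2 with respect to w is H = 2 X X^T.
    H = [[2.0 * sum(X[i][k] * X[j][k] for k in range(n)) for j in range(d)] for i in range(d)]
    mean_diag = sum(H[i][i] for i in range(d)) / d
    for i in range(d):
        H[i][i] += damp * mean_diag  # dampening keeps H well conditioned
    Hinv = invert(H)
    w, q = list(w), [0.0] * d
    for i in range(d):
        q[i] = quant(w[i], scale, q_max)
        e = w[i] - q[i]
        for j in range(i + 1, d):  # push the error onto not-yet-quantized weights
            w[j] -= e * Hinv[i][j] / Hinv[i][i]
        h = Hinv[i][i]  # remove index i from the inverse Hessian (rank-1 downdate)
        Hinv = [[Hinv[r][c] - Hinv[r][i] * Hinv[i][c] / h for c in range(d)] for r in range(d)]
    return q
```

A useful sanity check: if the calibration inputs are uncorrelated (diagonal \(H\)), every off-diagonal term of \(H^{-1}\) is zero, so no error is propagated and the result is identical to RTN.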
22.1.3 AWQ (Activation-Aware Weight Quantization)
Key Insight: Not all weight channels are equally important–protect channels with large activation magnitudes.
Algorithm:
Collect activation statistics on calibration data
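Compute per-input-channel activation magnitude: \(\bar{x}_i = \mathbb{E}[|X_i|]\) where \(X_i\) is input channel \(i\)
Choose per-channel scales \(s_i = \bar{x}_i^{\alpha}\), grid-searching \(\alpha \in [0, 1]\) to minimize the output error of the quantized layer
Scale weights: \(W'_{:,i} = s_i \cdot W_{:,i}\) (enlarges salient channels so their relative rounding error shrinks)
Quantize the scaled weights (e.g., INT4 per-group)
At inference, \(Y = Q(W')\,(\text{diag}(s)^{-1} X)\); the \(1/s\) factor is folded into the preceding operator (e.g., LayerNorm or the previous linear layer), so it adds no runtime cost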
Benefits:
Preserves outlier channels that dominate model quality
Works well for INT4 weights with FP16 activations
Minimal overhead (scales stored per-channel)
Used in AutoAWQ, vLLM, MLC-LLM
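A simplified sketch of the AWQ scale search follows. It searches \(\alpha\) on a grid in \([0, 1]\) and measures the error of \(Q(W\,\text{diag}(s))\,(\text{diag}(s)^{-1}X)\) against the exact \(WX\). To stay short, it quantizes each weight row with a single INT4 scale and omits other details of the reference AutoAWQ implementation, such as per-group quantization. The helper names are hypothetical. Because \(\alpha = 0\) gives \(s = 1\), which is plain RTN, the search can never do worse than RTN on the calibration data.

```python
def quantize_row_int4(row):
    """Symmetric per-row INT4 round-to-nearest, returned already dequantized."""
    scale = max(abs(v) for v in row) / 7 or 1.0
    return [scale * max(-8, min(7, round(v / scale))) for v in row]


def matmul(W, X):
    """W: out x d, X: d x n (lists of rows)."""
    return [[sum(W[o][i] * X[i][k] for i in range(len(X))) for k in range(len(X[0]))]
            for o in range(len(W))]


def output_error(W, X, s):
    """Squared error of Q(W diag(s)) (diag(s)^-1 X) against the exact W X."""
    W_scaled = [[w * s[i] for i, w in enumerate(row)] for row in W]
    X_scaled = [[x / s[i] for x in X[i]] for i in range(len(X))]
    Y_q = matmul([quantize_row_int4(r) for r in W_scaled], X_scaled)
    Y = matmul(W, X)
    return sum((a - b) ** 2 for ra, rb in zip(Y_q, Y) for a, b in zip(ra, rb))


def awq_search(W, X, grid=11):
    """Grid-search alpha for s_i = mean|X_i| ** alpha; returns (error, alpha, scales)."""
    act_mag = [max(sum(abs(x) for x in X[i]) / len(X[i]), 1e-8) for i in range(len(X))]
    best = None
    for step in range(grid):
        alpha = step / (grid - 1)
        s = [m ** alpha for m in act_mag]
        err = output_error(W, X, s)
        if best is None or err < best[0]:
            best = (err, alpha, s)
    return best


W = [[0.2, -0.4, 0.1], [0.5, 0.3, -0.2]]
X = [[10.0, 12.0], [0.1, -0.2], [0.5, 0.4]]  # channel 0 is salient
err, alpha, s = awq_search(W, X)
```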
22.1.4 SmoothQuant
Problem: Activations have large outliers concentrated in a few channels (consistently across tokens) → a per-tensor INT8 scale sized for the outliers loses precision on every other channel.
Solution: Migrate outliers from activations into weights via channel-wise scaling.
Algorithm:
Measure each channel's activation range \(\max(|X_i|)\) and weight range \(\max(|W_{:,i}|)\) on calibration data
Compute per-channel scaling: \(s_i = \max(|X_i|)^\gamma / \max(|W_{:,i}|)^{1-\gamma}\) where \(\gamma \in [0, 1]\) is the migration strength (hyperparameter)
Transform: \(X'_i = X_i / s_i\), \(W'_{:,i} = W_{:,i} \cdot s_i\)
Now \(X'\) has reduced dynamic range → quantize both \(X'\) and \(W'\) to INT8
Mathematically equivalent: \(Y = WX = (W \cdot \text{diag}(s)) \cdot (\text{diag}(s)^{-1} \cdot X) = W' X'\)
Benefits:
Enables INT8 activations (critical for full INT8 matmuls on mobile hardware)
Typical \(\gamma = 0.5\) (balance between activation and weight quantization difficulty)
Used in TensorRT-LLM, deployed in production for LLaMA/Mistral
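The smoothing transform can be checked numerically. This sketch uses hypothetical helper names and toy data. It computes \(s_i\) from per-channel activation and weight maxima, applies the transform, and verifies that \(W'X' = WX\). With \(\gamma = 0.5\), each channel ends up with equal activation and weight maxima, \(\sqrt{\max|X_i| \cdot \max|W_{:,i}|}\), which is the "balance" the text refers to.

```python
def smooth_scales(X, W, gamma=0.5):
    """Per-input-channel s_i = max|X_i|^gamma / max|W_:,i|^(1 - gamma)."""
    return [max(abs(x) for x in X[i]) ** gamma / max(abs(row[i]) for row in W) ** (1 - gamma)
            for i in range(len(X))]


def apply_smoothing(X, W, s):
    """X' = diag(s)^-1 X and W' = W diag(s)."""
    X_s = [[x / s[i] for x in X[i]] for i in range(len(X))]
    W_s = [[w * s[i] for i, w in enumerate(row)] for row in W]
    return X_s, W_s


def matmul(W, X):
    return [[sum(W[o][i] * X[i][k] for i in range(len(X))) for k in range(len(X[0]))]
            for o in range(len(W))]


X = [[40.0, -35.0], [0.5, 0.2], [1.0, -0.8]]  # channel 0 carries activation outliers
W = [[0.2, -0.4, 0.1], [0.5, 0.3, -0.2]]
s = smooth_scales(X, W)
X_s, W_s = apply_smoothing(X, W, s)
print([max(abs(v) for v in row) for row in X_s])  # per-channel range spread drops from 80x to 10x
```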
22.1.5 QuIP# and Other Advanced Methods
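QuIP#: Randomized Hadamard incoherence processing plus E8-lattice vector-quantization codebooks; pushes weight quantization toward 2 bits
OmniQuant: Learnable weight clipping plus learnable equivalent (scaling/shifting) transformations, optimized block by block
SpQR: Sparse-quantized representation: keep a small fraction (~1%) of outlier weights in FP16, quantize the rest to 3-4 bits with small groups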
22.2 Quantization-Aware Training (QAT)
When to use: When PTQ quality insufficient (INT4/INT2 regimes, extreme compression targets).
Fake Quantization:
Insert quantize-dequantize ops during forward pass: \(\tilde{W} = \text{dequant}(\text{quant}(W))\)
Gradients flow through straight-through estimator (STE): \(\frac{\partial \tilde{W}}{\partial W} = \mathbb{1}\)
Model learns to be robust to quantization noise
How STE Works Despite Discontinuity:
The quantization function \(\text{quant}(x) = \lfloor x/s \rceil\) (round-to-nearest) is discontinuous and has zero gradient almost everywhere. The STE addresses this by using different functions for forward vs backward passes:
Forward Pass: Apply real quantization \[\tilde{W} = s \cdot \lfloor W/s \rceil\]
Backward Pass: Treat quantization as identity \[\frac{\partial \mathcal{L}}{\partial W} = \frac{\partial \mathcal{L}}{\partial \tilde{W}} \cdot \underbrace{\frac{\partial \tilde{W}}{\partial W}}_{\text{set to } \mathbb{1}} = \frac{\partial \mathcal{L}}{\partial \tilde{W}}\]
Why This Works:
Gradient Approximation: Within quantization bins, \(\tilde{W} \approx W\), so \(\partial \tilde{W}/\partial W \approx 1\) is reasonable
Learned Avoidance: Network learns to keep weights away from quantization boundaries where gradient signal is most distorted
Stochastic Smoothing: Mini-batch gradients average over many samples, smoothing out discontinuities
Empirical Success: Works well in practice for INT8/INT4; INT2 requires careful tuning
Alternative: Some implementations use clipped identity \(\frac{\partial}{\partial x}\text{clip}(x, -1, 1) = \mathbb{1}_{|x| \leq 1}\) to zero out gradients for extreme outliers.
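The forward/backward split can be written out explicitly for a single weight. This is a hand-rolled sketch rather than a framework API. `fake_quant` is the forward pass and `ste_backward` is the straight-through gradient, with the optional clipped variant. The true derivative of `round()` is zero almost everywhere, so without the STE the SGD step below would never move `w`.

```python
def fake_quant(w, scale, q_max=7):
    """Forward pass: quantize then dequantize (what the network actually uses)."""
    return scale * max(-q_max - 1, min(q_max, round(w / scale)))


def ste_backward(grad_out, w, scale, q_max=7, clipped=True):
    """Backward pass: treat fake_quant as identity.

    With clipped=True the gradient is zeroed when w lies outside the representable
    range [scale * (-q_max - 1), scale * q_max] (the clipped-identity variant).
    """
    if clipped and not (scale * (-q_max - 1) <= w <= scale * q_max):
        return 0.0
    return grad_out


# One SGD step on loss = (fake_quant(w) * x - y)^2 for a single INT4 weight.
w, scale, x, y, lr = 0.26, 0.1, 1.0, 0.5, 0.1
w_q = fake_quant(w, scale)                 # forward sees 0.3
grad_wq = 2 * (w_q * x - y) * x           # dL/dw_q = -0.4
w -= lr * ste_backward(grad_wq, w, scale) # latent w moves from 0.26 to 0.30
```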
Training Recipe:
Start from pre-trained FP16/BF16 model
Insert fake-quant layers (typically per-channel symmetric for weights, per-tensor for activations)
Fine-tune for 1-5% of original training steps with lower learning rate (\(10^{-5}\) to \(10^{-6}\))
Optionally combine with knowledge distillation (teacher = FP16 model)
Export quantized weights and scales
Benefits:
Recovers 1-3% perplexity vs PTQ at INT4
Enables aggressive compression (INT2 MLP weights with INT8 attention)
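Used in production for mobile model deployments; related but distinct, QLoRA fine-tunes LoRA adapters on top of frozen 4-bit (NF4) base weights rather than training through fake quantization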
23 Layer-Wise Precision Strategy
Not all layers tolerate low precision equally. Empirical guidelines:
| Layer Type | Safe Precision | Aggressive | Notes |
|---|---|---|---|
| Embedding (Input) | FP16/BF16 | INT8 | Sensitive to outliers; keep FP16 or use per-token scales |
| Attention QKV Proj | INT8 | INT4 | Per-channel scales sufficient; can go INT4 with AWQ |
| Attention Output Proj | INT8 | INT4 | Match QKV precision; residual connection sensitive |
| Attention Softmax | FP16 | FP16 | Never quantize–numerical instability kills quality |
| MLP (FFN) Up/Down | INT8 | INT4/INT2 | Most weight-heavy; benefits most from quantization |
| MLP Activation (SwiGLU) | FP16/INT8 | INT8 | Post-activation quantization with asymmetric scales |
| LayerNorm/RMSNorm | FP16 | FP16 | Keep high precision–stabilizes residuals |
| KV Cache | FP16/FP8 | INT8 | FP8 (E4M3) typical; INT8 with per-token/per-channel scales |
| Final LM Head | FP16/BF16 | INT8 | Accuracy-critical; use FP16 or careful INT8 with calibration |
23.1 Where to Use INT2?
Target: MLP weight matrices (up_proj, down_proj) in middle-to-late layers.
Requirements:
Per-group quantization: Group size 64-128; per-channel scales too coarse
QAT or strong distillation: PTQ alone causes significant degradation at INT2
Hardware support: Native INT2 dot-product (e.g., custom ASICs, Qualcomm Hexagon DSP); otherwise overhead negates gains
Pair with higher-precision activations: Use FP16 or INT8 activations to limit noise accumulation
Typical Configuration:
Attention: INT4 or INT8 weights, FP16 activations
MLP: INT2 weights (with group quant), INT8 activations
Norms/Embeddings/LM Head: FP16
Quality Impact:
Expect 2-5% perplexity increase vs FP16 baseline
Mitigated with QAT + distillation (can recover to within 1-2%)
Critical for \(<\)2GB budgets: a 3B model is 12GB in FP32 and 0.75GB at a uniform 2 bits per weight; INT4 attention, FP16 embeddings/LM head, and group scales all add to that floor
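Mixed-precision footprints are easy to underestimate, because group scales add storage on top of the nominal bit width. This sketch assumes one FP16 scale per group and no zero-points. The 1.8B/0.9B/0.3B parameter split below is hypothetical, chosen only to resemble a 3B model with INT2 MLP, INT4 attention and FP16 embeddings/LM head.

```python
def effective_bits(weight_bits, group_size=None, scale_bits=16):
    """Storage bits per weight, including one scale per group (zero-points ignored)."""
    return weight_bits + (scale_bits / group_size if group_size else 0)


def footprint_gb(param_split):
    """param_split: list of (num_params, bits_per_weight) pairs -> gigabytes (1e9 bytes)."""
    return sum(n * b / 8 for n, b in param_split) / 1e9


SPLIT = [
    (1.8e9, effective_bits(2, 64)),   # hypothetical MLP share, INT2 with g=64 -> 2.25 bits
    (0.9e9, effective_bits(4, 128)),  # hypothetical attention share, INT4 with g=128 -> 4.125 bits
    (0.3e9, 16),                      # hypothetical embeddings + LM head in FP16
]
print(footprint_gb([(3e9, 2)]), footprint_gb(SPLIT))  # uniform 2-bit floor vs mixed config
```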
23.2 Where to Use INT4?
Target: All weight matrices (attention + MLP) except embeddings and LM head.
Best Practices:
Use AWQ or GPTQ for per-channel weight quantization
Keep activations in FP16 or INT8 (SmoothQuant for INT8 activations)
Group size 128 typical; per-channel often sufficient for attention
Hardware:
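NVIDIA GPUs: INT4 weight-only kernels that dequantize to FP16 inside the matmul (fused dequant); native INT4 tensor-core math exists on Turing/Ampere but was dropped in Hopper
Apple Silicon: ANE supports INT8 natively; INT4 via bit-packing with minor overhead
ARM SoCs: Use NEON/i8mm for INT8; INT4 weights are packed two per byte and unpacked to INT8 in software before the dot product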
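The "two INT4 per byte" layout can be sketched with two helpers. We assume, for illustration only, that the first value goes in the low nibble. Production kernels often use interleaved, layout-specific orders that suit their SIMD unpack sequences. Unpacking sign-extends each nibble, so the result is ready for INT8 dot-product instructions.

```python
def pack_int4(values):
    """Pack signed INT4 values (-8..7) two per byte; first value in the low nibble."""
    values = list(values)
    if len(values) % 2:
        values.append(0)  # pad to an even count
    out = bytearray()
    for lo, hi in zip(values[0::2], values[1::2]):
        out.append((lo & 0xF) | ((hi & 0xF) << 4))
    return bytes(out)


def unpack_int4(packed, count):
    """Unpack to signed 8-bit integers, sign-extending each 4-bit nibble."""
    vals = []
    for byte in packed:
        for nibble in (byte & 0xF, byte >> 4):
            vals.append(nibble - 16 if nibble >= 8 else nibble)
    return vals[:count]


packed = pack_int4([-8, 7, -1, 0, 3])
print(packed.hex())  # 780f03: three bytes for five weights
```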
23.3 Where to Use INT8?
Target: Conservative deployment–all layers except softmax and norms.
Configuration:
Weights: Per-channel symmetric INT8
Activations: Per-tensor INT8 (with SmoothQuant if outliers present)
KV cache: INT8 per-channel or FP8 (E4M3)
Hardware: Ubiquitous support–AVX512-VNNI, ARM i8mm, Apple ANE, NVIDIA INT8 tensor cores.
Quality: Typically \(<\)0.5% perplexity degradation with proper calibration (RTN often sufficient).
23.4 FP8 Considerations
When to use: GPUs with native FP8 support (H100, MI300), large models (\(>\)70B) where memory bandwidth critical.
Advantages:
Near-FP16 accuracy with \(2\times\) memory savings
Little calibration needed: FP8's wider dynamic range usually makes a simple per-tensor scale (static or computed on the fly) sufficient
Ideal for KV cache (E4M3 format)
Disadvantages:
Limited hardware support (not available on mobile/edge)
Marginal vs INT8 on memory-constrained devices (INT4 better choice)
24 Hardware Considerations
24.1 NVIDIA GPUs
Formats: FP8 (H100, Ada), INT8 (Turing+), INT4 weights via fused-dequant kernels, FP16 (all modern GPUs), BF16 (Ampere+)
Best Practices:
Use TensorRT-LLM or vLLM with FP8/INT8/INT4 kernels
Fused matmul-dequant kernels are critical: a separate dequantization pass writes full-precision weights back to memory and erases the bandwidth savings
Per-token KV cache quantization (FP8 E4M3 or INT8 per-channel)
Paged attention for efficient memory management
Memory Hierarchy:
HBM bandwidth: ~1.5-3.4 TB/s (A100/H100)–memory-bound for large models
L2 cache: 40-50 MB–tile matmuls to fit working set
Shared memory: 100-200 KB per SM–critical for attention softmax, reductions
24.2 Apple Silicon (M1/M2/M3, A-series)
Neural Engine (ANE):
Optimized for INT8 operations; FP16 fallback for unsupported ops
Limited to certain op patterns (matmul, conv, activations)–custom ops fall back to CPU/GPU
Use Core ML for deployment; quantization via coremltools
Best Practices:
INT8 weights, FP16 activations (ANE mixed-precision)
Keep softmax and norms in FP16 (run on GPU)
Unified memory (16-64GB shared CPU/GPU/ANE)–avoid copies
Watch thermal throttling–sustained inference at 5-10W
Memory:
Unified memory bandwidth: 200-400 GB/s (M2 Pro/Max)
KV cache grows quickly–use FP16 or INT8 per-channel
Paginate cache to avoid large reallocations
24.3 ARM SoCs (Snapdragon, MediaTek, Exynos)
Compute Units:
CPU: NEON SIMD (INT8), i8mm extensions (Armv8.6+) for INT8 matmul
GPU: Mali/Adreno–FP16 preferred; INT8 support varies
NPU/DSP: Hexagon (Qualcomm), APU (MediaTek)–INT8 native, some INT4 support
Best Practices:
Deploy INT8 weights to NPU, FP16 activations on GPU if mixed-precision supported
Use NNAPI (Android) or vendor SDKs (SNPE for Qualcomm)
Avoid frequent CPU-GPU transfers–pin memory, use ION buffers
Target 3-7W sustained power (phones); watch thermal throttling
Memory:
LPDDR5: 50-100 GB/s bandwidth
Limited SRAM on NPU (few MB)–tile carefully
KV cache in system RAM–use INT8 or FP16 compressed
24.4 x86 Edge (Intel, AMD)
Instructions:
AVX512-VNNI: INT8 dot-product (Cascade Lake+)
AMX (Advanced Matrix Extensions): INT8/BF16 matrix tiles (Sapphire Rapids+)
AVX-VNNI: INT8 on hybrid cores (Alder Lake+)
Best Practices:
INT8 matmuls with AVX512-VNNI or AMX (tile to 16x16 or 32x32)
Use oneDNN or ONNX Runtime for optimized kernels
Watch L3 cache size (20-100 MB)–tile KV cache access
DDR bandwidth: 50-100 GB/s (client), 200+ GB/s (server)
24.5 Tiny MCUs / IoT
Constraints: \(<\)1MB SRAM, \(<\)100 MHz, \(<\)100mW power
Strategies:
Extreme quantization: INT4/INT2 weights, 8-bit activations
Tiny models: \(<\)100M parameters, distilled from larger teacher
Lookup tables (LUTs) for activations–avoid floating-point
Streaming inference: process token-by-token, evict cache aggressively
Use specialized frameworks: TensorFlow Lite Micro, MCUNet
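A LUT-based activation can be built offline so that the MCU only performs an integer table lookup at runtime. In this sketch, sigmoid is the example activation (SiLU or GELU would be tabulated the same way). The input scale (INT8 code \(q\) represents \(q/16\)) and the 0..127 output coding are hypothetical choices for illustration.

```python
import math

IN_SCALE = 8 / 128   # int8 input code q represents q * IN_SCALE, covering [-8, 8)
OUT_LEVELS = 127     # sigmoid output y in (0, 1) is stored as round(y * 127)


def build_sigmoid_lut():
    """Precompute the activation for all 256 int8 inputs (done offline, on a host)."""
    return [round(OUT_LEVELS / (1 + math.exp(-q * IN_SCALE))) for q in range(-128, 128)]


LUT = build_sigmoid_lut()  # 256 bytes when stored on the device


def sigmoid_int8(q):
    """Runtime path: no floating point, just an index into the table."""
    return LUT[q + 128]


print(sigmoid_int8(-128), sigmoid_int8(0), sigmoid_int8(127))  # 0 64 127
```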
25 Knowledge Distillation
25.1 Why Distill for Edge?
Create small, efficient draft models for speculative decoding
Align student with teacher distribution–improves quantization tolerance
Reduce model size while preserving task-specific performance
25.2 Standard Distillation Setup
Objective: Train student \(S\) to match teacher \(T\) outputs.
Loss Function: \[\mathcal{L} = \lambda_{\text{KD}} \cdot \text{KL}(P_T \| P_S) + \lambda_{\text{CE}} \cdot \text{CE}(y, P_S) + \lambda_{\text{hidden}} \cdot \|H_T - H_S\|^2\]
where:
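\(P_T = \text{softmax}(z_T / \tau)\), \(P_S = \text{softmax}(z_S / \tau)\): Softened distributions (temperature \(\tau = 2\)-\(4\)); the KD term is commonly multiplied by \(\tau^2\) to keep gradient magnitudes comparable across temperatures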
\(\text{CE}(y, P_S)\): Hard label cross-entropy (if labeled data available)
\(\|H_T - H_S\|^2\): Hidden state matching (optional; requires dimension alignment)
Temperature Scaling:
Higher \(\tau\) (2-4) softens distribution–exposes "dark knowledge" (relative probabilities)
Student learns similarity structure, not just argmax
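A minimal sketch of the first two loss terms follows: KL on temperature-softened distributions (scaled by \(\tau^2\)) plus hard-label cross-entropy at \(\tau = 1\). The hidden-state term is omitted. The weights `lam_kd=0.9` and `lam_ce=0.1` are hypothetical defaults for illustration, not values from the text.

```python
import math


def softmax(logits, tau=1.0):
    m = max(logits)
    exps = [math.exp((z - m) / tau) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kd_loss(teacher_logits, student_logits, tau=2.0):
    """tau^2 * KL(P_T || P_S) on temperature-softened distributions."""
    p_t = softmax(teacher_logits, tau)
    p_s = softmax(student_logits, tau)
    kl = sum(pt * (math.log(pt) - math.log(ps)) for pt, ps in zip(p_t, p_s) if pt > 0)
    return tau ** 2 * kl


def hard_ce(label, student_logits):
    """Cross-entropy with the ground-truth token, at tau = 1."""
    return -math.log(softmax(student_logits)[label])


def distill_loss(teacher_logits, student_logits, label, tau=2.0, lam_kd=0.9, lam_ce=0.1):
    return lam_kd * kd_loss(teacher_logits, student_logits, tau) + lam_ce * hard_ce(label, student_logits)


print(distill_loss([2.0, 1.0, 0.0], [1.5, 1.0, 0.2], label=0))
```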
25.3 Distillation for Quantized Models
QAT + Distillation: Combine fake quantization with distillation loss.
Recipe:
Start with pre-trained FP16 teacher
Initialize student (a smaller architecture, or the teacher's architecture to be quantized)
Insert fake-quant ops in student
Train with combined loss: KD + hard labels + hidden matching
Use lower learning rate (\(10^{-5}\)) for 1-5% of original training steps
Data:
Use original training data if available
Synthetic data: sample from teacher with nucleus sampling (\(p=0.9\))
Domain-specific fine-tuning data
Synthetic Data Generation from Teacher:
When original training data unavailable or insufficient, generate synthetic data by sampling from teacher model.
Nucleus (Top-p) Sampling:
Sort vocabulary by probability: \(p_1 \geq p_2 \geq \cdots \geq p_{|V|}\)
Find smallest set \(V_p\) such that \(\sum_{i \in V_p} p_i \geq p\) (typically \(p = 0.9\)-\(0.95\))
Renormalize: \(p'_i = p_i / \sum_{j \in V_p} p_j\) for \(i \in V_p\), else \(p'_i = 0\)
Sample next token from \(p'\)
Why Nucleus vs Greedy?
Diversity: Avoids repetitive outputs from greedy decoding
Quality: Filters low-probability tail (unlike pure sampling) that leads to nonsense
Distribution Coverage: Student sees varied contexts, not just modal paths
Synthetic Data Pipeline:
Seed with prompts (curated or from unlabeled corpus)
Generate continuations from teacher with nucleus sampling (\(p=0.9\), \(\tau=1.0\))
Filter low-quality outputs (perplexity threshold, toxicity checks)
Use teacher’s soft labels (logits) as distillation targets
Optionally mix with real data (70-90% synthetic, 10-30% real)
Production Example: Stanford Alpaca fine-tuned LLaMA-7B on 52K instruction-following demonstrations generated by OpenAI's text-davinci-003 (sequence-level distillation on teacher samples, without soft labels).
25.4 Distillation for Speculative Decoding
Goal: Create draft model aligned with teacher (verifier) to maximize acceptance rate.
Key Considerations:
Distribution matching: High KL weight (\(\lambda_{\text{KD}} \gg \lambda_{\text{CE}}\))–draft must mimic teacher’s probability structure
Temperature: Match inference temperature (e.g., \(\tau=1\) for greedy, \(\tau=0.7\) for sampling)
Context length: Train on same context length as deployment
Calibration: Draft should be well-calibrated (confidence matches accuracy)
Architecture Choices:
Smaller depth: 12-16 layers vs 32+ for teacher
Shared vocabulary and tokenizer (critical for alignment)
Optional: Shared embeddings with teacher (freeze embeddings during distillation)
26 Speculative Decoding
26.1 The Latency Problem
Autoregressive decoding is inherently sequential: must generate token \(t\) before \(t+1\). Each token requires full model forward pass.
Memory Bandwidth Bottleneck:
7B FP16 model: \(\sim\)14GB weights
Single token generation: 14GB read from DRAM
GPU HBM: 2 TB/s → 7ms per token (memory-bound, not compute-bound)
Batch size helps, but latency still linear in sequence length for single user
Solution: A small draft model proposes several tokens cheaply (still one at a time), and the verifier checks all of them in a single parallel forward pass.
26.2 Speculative Decoding Mechanics
Setup:
Draft model \(M_d\): Small (1-3B), fast, quantized (INT8/INT4)
Verifier model \(M_v\): Large (7B+), accurate, FP16 or INT8
Algorithm:
Draft generates \(k\) tokens autoregressively: \(t_1, t_2, \ldots, t_k\) with probabilities \(p_d(t_i | t_{<i})\)
Verifier processes all \(k\) tokens in parallel (one forward pass)–computes \(p_v(\cdot | t_{<i})\) for \(i=1,\ldots,k+1\)
Accept/reject each token using importance sampling (see below)
If token \(t_j\) is rejected, discard \(t_{j+1}, \ldots, t_k\) and resample position \(j\) from the adjusted distribution (below); if all \(k\) are accepted, sample a bonus token from \(p_v(\cdot | t_{\leq k})\)
Repeat until sequence complete
Key Insight: The verifier scores all \(k\) draft tokens in one pass (roughly the cost of generating 1 token, since decoding is memory-bandwidth bound), and the draft generates them cheaply → up to \(k+1\) tokens per verifier pass when the acceptance rate is high.
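The loop above can be sketched end to end with toy "models". Here `draft_dist` and `target_dist` are hypothetical stand-ins that map a token prefix to a next-token probability list, and the accept/resample rule is the one detailed in the next subsection. For clarity this version calls the target once per position; a real system gets all \(k+1\) distributions from one batched forward pass.

```python
import random


def sample(probs, rng):
    """Draw one token index from a (possibly unnormalized) probability list."""
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


def speculative_step(prefix, draft_dist, target_dist, k, rng):
    """One round of speculative sampling; returns between 1 and k + 1 new tokens."""
    # 1) The draft proposes k tokens autoregressively (cheap but sequential).
    ctx, drafted, q = tuple(prefix), [], []
    for _ in range(k):
        probs = draft_dist(ctx)
        tok = sample(probs, rng)
        drafted.append(tok)
        q.append(probs)
        ctx += (tok,)
    # 2) Verifier distributions for positions 1..k+1 (one batched pass in practice).
    p = [target_dist(tuple(prefix) + tuple(drafted[:i])) for i in range(k + 1)]
    # 3) Accept or reject left to right; the first rejection ends the round.
    out = []
    for i, tok in enumerate(drafted):
        if rng.random() <= min(1.0, p[i][tok] / q[i][tok]):
            out.append(tok)
            continue
        residual = [max(0.0, pv - pd) for pv, pd in zip(p[i], q[i])]
        out.append(sample(residual, rng))  # choices() normalizes the weights
        return out
    # 4) Every draft token accepted: take a bonus token from the verifier.
    out.append(sample(p[k], rng))
    return out


same = lambda ctx: [0.2, 0.5, 0.3]
print(speculative_step((), same, same, 4, random.Random(0)))  # identical models: 5 tokens
```

When the draft matches the verifier exactly, every acceptance ratio is 1, so each round yields \(k+1\) tokens. This is the best case that distillation aims to approach.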
26.3 Importance Sampling for Acceptance
Goal: Ensure output distribution matches verifier’s distribution (unbiased sampling).
Acceptance Probability: \[\alpha_i = \min\left(1, \frac{p_v(t_i | t_{<i})}{p_d(t_i | t_{<i})}\right)\]
Procedure at position \(i\):
Sample \(u \sim \text{Uniform}(0, 1)\)
If \(u \leq \alpha_i\): accept \(t_i\), continue to \(i+1\)
If \(u > \alpha_i\): reject \(t_i\), discard \(t_{i+1}, \ldots, t_k\); resample from adjusted distribution: \[p_{\text{new}}(t) = \frac{\max(0, p_v(t) - p_d(t))}{\sum_{t'} \max(0, p_v(t') - p_d(t'))}\]
Why This Works:
When draft and verifier agree (\(p_d \approx p_v\)): \(\alpha_i \approx 1\) → high acceptance
When draft overestimates probability: \(\alpha_i < 1\) → probabilistic rejection maintains correct distribution
Adjusted resampling ensures unbiased output (provably matches verifier’s distribution)
Importance Sampling (whiteboard interview):
Problem Setup: We want to sample from target distribution \(p_v\) (verifier), but have samples from proposal distribution \(p_d\) (draft). How do we correct for the mismatch?
Importance Sampling Foundation:
Given: Draft token \(t \sim p_d\), target distribution \(p_v\).
Define acceptance probability for a single draft sample: \[\alpha(t) = \min\left(1, \frac{p_v(t)}{p_d(t)}\right)\]
Important Distinction:
\(P(\text{accept } t | \text{draft proposes } t)\): Probability of accepting token \(t\) given draft proposed it (single sampling event)
\(P(\text{output } t)\): Overall probability of outputting \(t\) through the entire process (direct acceptance OR rejection-resampling)
Claim: The procedure produces the correct marginal distribution: \(P(\text{output } t) = p_v(t)\).
Proof Strategy:
Token \(t\) can be output via two mutually exclusive paths:
Direct acceptance: Draft proposes \(t\) (prob \(p_d(t)\)), accept with prob \(\alpha(t) = \min(1, p_v(t)/p_d(t))\)
Rejection-resampling: Draft proposes \(t' \neq t\), gets rejected, resample produces \(t\)
We will show that \(P(\text{output } t) = p_v(t)\) by analyzing both paths separately for two cases.
Case 1: \(p_v(t) \geq p_d(t)\) (verifier assigns higher probability)
\(\alpha(t) = 1\), so draft token \(t\) always accepted if proposed. But \(t\) can be output in two ways:
Path A: Draft proposes \(t\) directly (probability \(p_d(t)\)), always accept (\(\alpha = 1\)): \[P(\text{accept } t \text{ from draft}) = p_d(t) \cdot 1 = p_d(t)\]
Path B: Draft proposes something else (\(t' \neq t\)), gets rejected, then resample produces \(t\).
Define rejection probability: Token \(t'\) is rejected when \(p_v(t') < p_d(t')\), with probability \(1 - \alpha(t') = 1 - p_v(t')/p_d(t')\).
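Total probability of entering resampling (rejected mass), summing over tokens the draft can propose: \[Z = \sum_{t'} p_d(t') \bigl(1 - \alpha(t')\bigr) = \sum_{t'} p_d(t') \left(1 - \min\left(1, \frac{p_v(t')}{p_d(t')}\right)\right) = \sum_{t'} \max(0, p_d(t') - p_v(t'))\]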
When rejection occurs, resample from adjusted distribution: \[p_{\text{new}}(t) = \frac{\max(0, p_v(t) - p_d(t))}{\sum_{t'} \max(0, p_v(t') - p_d(t'))}\]
Conservation of Probability: Since \(\sum_{t'} p_d(t') = \sum_{t'} p_v(t') = 1\), we have: \[\sum_{t'} (p_d(t') - p_v(t')) = 0\]
Split this sum by sign: \[\sum_{t': p_d(t') > p_v(t')} (p_d(t') - p_v(t')) + \sum_{t': p_d(t') \leq p_v(t')} (p_d(t') - p_v(t')) = 0\]
Rearranging: \[\sum_{t': p_d(t') > p_v(t')} (p_d(t') - p_v(t')) = -\sum_{t': p_d(t') \leq p_v(t')} (p_d(t') - p_v(t')) = \sum_{t': p_v(t') > p_d(t')} (p_v(t') - p_d(t'))\]
Therefore: \[\underbrace{\sum_{t'} \max(0, p_d(t') - p_v(t'))}_{\text{rejected mass (draft excess)}} = \underbrace{\sum_{t'} \max(0, p_v(t') - p_d(t'))}_{\text{missing mass (verifier excess)}}\]
Call this common value \(Z\), so \(p_{\text{new}}(t) = \frac{\max(0, p_v(t) - p_d(t))}{Z}\).
No second rejection round: A token drawn from \(p_{\text{new}}\) is output directly, without another accept/reject test. Because \(p_{\text{new}}(t') > 0\) only where \(p_v(t') > p_d(t')\), every token it can produce would have \(\alpha(t') = 1\) anyway, so there is no geometric series of retries.
Path B probability: Rejection happens with total probability \(Z\), after which \(t\) is drawn with probability \(p_{\text{new}}(t)\): \[P(\text{output } t \text{ via resampling}) = Z \cdot p_{\text{new}}(t) = Z \cdot \frac{p_v(t) - p_d(t)}{Z} = p_v(t) - p_d(t)\]
(If \(Z = 0\), then \(p_d = p_v\), rejection never happens, and Path B contributes \(0 = p_v(t) - p_d(t)\).)
Intuition: The rejected mass \(Z\) (draft excess) is redistributed exactly onto the tokens with a deficit, and each receives precisely its own deficit \(p_v(t) - p_d(t)\).
Total probability for Case 1: \[P(\text{output } t) = p_d(t) + (p_v(t) - p_d(t)) = p_v(t) \quad \checkmark\]
Case 2: \(p_v(t) < p_d(t)\) (draft overestimates)
\(\alpha(t) = p_v(t)/p_d(t) < 1\), and \(p_{\text{new}}(t) = 0\) (no deficit), so \(t\) is output only via direct acceptance: \[P(\text{output } t) = p_d(t) \cdot \frac{p_v(t)}{p_d(t)} + Z \cdot 0 = p_v(t) \quad \checkmark\]
Key Insight: Acceptance-rejection with probability ratio \(p_v/p_d\) reweights the draft distribution to match the verifier. The adjusted resampling distribution \(\max(0, p_v - p_d)\) captures the "missing mass" when verifier assigns higher probability than draft.
Unbiased Guarantee: \[\mathbb{E}_{t \sim \text{procedure}}[\mathbb{1}_{t = t^*}] = p_v(t^*) \quad \forall t^*\]
This is exact sampling–statistically identical to sampling directly from verifier.
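The proof can be checked two ways. The first is an exact computation with rational arithmetic: the per-token probability from Path A plus Path B equals \(p_v\). The second is a Monte Carlo simulation of the single-position procedure. The example distributions are arbitrary, and include a token the draft never proposes. The helper names are our own.

```python
import random
from fractions import Fraction


def output_distribution(p_d, p_v):
    """Exact P(output t): direct acceptance (Path A) plus resampling (Path B)."""
    accept = [min(pd, pv) for pd, pv in zip(p_d, p_v)]        # p_d(t) * min(1, p_v/p_d)
    residual = [max(0, pv - pd) for pd, pv in zip(p_d, p_v)]  # unnormalized p_new
    Z = sum(max(0, pd - pv) for pd, pv in zip(p_d, p_v))      # rejected mass
    if Z == 0:
        return accept
    return [a + Z * (r / Z) for a, r in zip(accept, residual)]


def speculative_sample(p_d, p_v, rng):
    """One position: propose from p_d, accept with min(1, p_v/p_d), else resample."""
    t = rng.choices(range(len(p_d)), weights=p_d)[0]
    if rng.random() <= min(1.0, p_v[t] / p_d[t]):
        return t
    residual = [max(0.0, v - d) for d, v in zip(p_d, p_v)]
    return rng.choices(range(len(p_d)), weights=residual)[0]


p_d = [Fraction(1, 2), Fraction(1, 4), Fraction(1, 4), Fraction(0)]
p_v = [Fraction(1, 5), Fraction(2, 5), Fraction(1, 10), Fraction(3, 10)]
print(output_distribution(p_d, p_v) == p_v)  # True
```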
Whiteboard Checkpoint:
Why \(\min(1, p_v/p_d)\)? Caps acceptance at 1 (can’t accept with prob \(>1\)), reweights draft samples
Why adjusted distribution? Fills gap when \(p_v > p_d\) (verifier wants token more than draft predicted)
What if \(p_d(t) = 0\) but \(p_v(t) > 0\)? Draft never proposes \(t\), so rejection resampling must discover it
Speedup? The per-position acceptance probability is \(\beta = \sum_t p_d(t)\,\alpha(t) = \sum_t \min(p_d(t), p_v(t)) = 1 - \mathrm{TV}(p_d, p_v)\), which is high when \(p_d \approx p_v\) (distillation!)
Production Note: In practice, compute acceptance for all \(k\) draft tokens in parallel, then process sequentially. First rejection triggers resampling and terminates speculation.
26.4 Why Distillation Matters for Speculative Decoding
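Acceptance Rate: If each draft token is accepted independently with probability \(r\), one verifier pass emits on average \(\frac{1 - r^{k+1}}{1 - r}\) tokens (between 1 and \(k+1\)); \(k \cdot r\) is only a rough approximation, and wall-clock speedup is lower once draft cost is included.
Impact of Distillation:
Undistilled draft: \(r \approx 0.4\)-\(0.6\) → \(\approx 1.7\)-\(2.5\) tokens per verifier pass at \(k=8\)
Distilled draft: \(r \approx 0.7\)-\(0.9\) → \(\approx 3.2\)-\(6.1\) tokens per verifier pass at \(k=8\)
Gains are nonlinear: each 0.1 increase in \(r\) is worth more as \(r\) approaches 1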
Why:
Distillation aligns draft’s distribution with verifier’s
Reduces KL divergence \(\text{KL}(p_v \| p_d)\) → lower rejection rate
Critical for edge deployment where verifier may be INT8 quantized (distribution shift)
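The expected-tokens formula follows from the acceptance chain. A pass emits the accepted prefix plus one more token (resampled or bonus), so \(\mathbb{E}[\text{tokens}] = \sum_{i=0}^{k} r^i\). The sketch below also divides by the pass cost \(kc + 1\) to estimate wall-clock speedup. Here \(c\), the cost of one draft step relative to one verifier step, is an assumed input (0.1 in the example), and the i.i.d.-acceptance model is a simplification.

```python
def expected_tokens_per_pass(r, k):
    """Mean tokens per verifier pass with i.i.d. acceptance rate r and k draft tokens."""
    if r >= 1.0:
        return float(k + 1)
    return (1 - r ** (k + 1)) / (1 - r)


def expected_speedup(r, k, c):
    """Speedup over plain decoding when one draft step costs c verifier steps."""
    return expected_tokens_per_pass(r, k) / (k * c + 1)


for r in (0.5, 0.7, 0.9):
    print(r, round(expected_tokens_per_pass(r, 8), 2), round(expected_speedup(r, 8, 0.1), 2))
```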
26.5 Tokenization Impact on Speculative Decoding
Vocabulary size creates a fundamental tradeoff for speculative decoding:
Small Vocabulary (byte-level, 256 tokens):
Lower branching factor: Only 256 possible next tokens → each prediction is easier
More tokens per text: 4-5\(\times\) more tokens than subword → more sequential verification steps
More draft-target coordination: More rounds of acceptance/rejection checks
Net effect: More verification rounds, even though each individual prediction is simpler
Large Vocabulary (subword BPE, 50k tokens):
Higher branching factor: 50k choices → each prediction is harder (lower confidence)
Fewer tokens per text: Better compression → fewer sequential steps
Better semantic alignment: Draft can propose meaningful subword units with higher acceptance
Net effect: Fewer verification rounds compensate for harder individual predictions
Concrete Example:
Text: “The quick brown fox” (19 characters)
Byte-level tokenization (256 vocab):
Tokens: 19 bytes
Draft generates \(k=8\) tokens → covers \(\sim\)8 chars
Need \(\sim\)2-3 speculation rounds to complete
Subword BPE (50k vocab):
Tokens: 4 subwords ["The", "_quick", "_brown", "_fox"]
Draft generates \(k=8\) tokens → overshoots the text
Need only 1 speculation round
Subword achieves 2-3\(\times\) fewer rounds despite harder branching.
Production Implications:
Latency: Speculation rounds run sequentially → fewer tokens means fewer rounds and lower latency
Acceptance rate: Meaningful subword units → draft-target alignment improves
Typical choice: Large vocab (subword) preferred for speculative decoding
Exception: Byte-level can work with very high acceptance rates (\(r > 0.9\)) from strong distillation
26.6 Practical Considerations
Draft Model Size:
Target: \(5\)-\(10\times\) faster than verifier per token
Typical: 1-3B draft for 7-13B verifier; 0.5-1B draft for 3B verifier
Quantize draft aggressively (INT8/INT4): the verifier checks every draft token, so draft speed and acceptance rate matter more than the draft's standalone quality
Speculative Length \(k\):
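Larger \(k\) → more tokens per verifier pass when drafts are accepted, but position \(i\) is reached only with probability \(r^i\), so more draft compute is wasted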
Typical: \(k=4\)-\(8\) tokens
Adaptive \(k\): increase when acceptance rate high, decrease when low
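One possible adaptive-\(k\) policy, sketched under stated assumptions: track \(r\) with an exponential moving average and choose the \(k\) that maximizes a modeled speedup (i.i.d. acceptance, draft step costing \(c\) verifier passes). The class name, `cost_ratio`, and the EMA constant are hypothetical choices, not taken from a specific system.

```python
def modeled_speedup(r: float, k: int, c: float) -> float:
    # Expected tokens per pass (i.i.d. acceptance) divided by round cost
    tokens = k + 1 if r >= 1.0 else (1 - r ** (k + 1)) / (1 - r)
    return tokens / (k * c + 1)

class AdaptiveK:
    """Hypothetical controller: EMA estimate of the acceptance rate r,
    then pick k in [1, k_max] maximizing modeled wall-clock speedup."""

    def __init__(self, cost_ratio=0.1, k_max=8, ema=0.9, r_init=0.7):
        self.c = cost_ratio  # draft step cost / verifier pass cost (assumed)
        self.k_max = k_max
        self.ema = ema
        self.r = r_init

    def update(self, accepted: int, k: int) -> None:
        # Only tokens up to and including the first rejection were tested;
        # counting all k would bias r downward after early rejections.
        tested = accepted + (1 if accepted < k else 0)
        batch_r = accepted / tested
        self.r = self.ema * self.r + (1 - self.ema) * batch_r

    def next_k(self) -> int:
        return max(range(1, self.k_max + 1),
                   key=lambda k: modeled_speedup(self.r, k, self.c))
```

With \(c = 0.1\), a high estimate (\(r \approx 0.9\)) pushes \(k\) to the cap, while a low estimate (\(r \approx 0.3\)) collapses it to a single draft token.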
Temperature Matching:
Draft should use same temperature as verifier at inference
Mismatch causes distribution shift → lower acceptance
Slightly cooler draft (\(\tau_d = 0.9 \tau_v\)) can reduce rejection spikes
26.7 Advanced Variants
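Medusa / Tree Decoding:
Candidate continuations form a tree of branches (Medusa proposes them with extra decoding heads on the verifier itself rather than a separate draft model)
Verifier validates all branches in one forward pass using a tree-structured attention mask
Accept longest valid path
Higher parallelism but more complex implementation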
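Lookahead KV Reuse:
Cache the draft's KV states across rounds; on rejection, truncate only the rejected positions instead of recomputing the prefix
Keep the KV entries the verifier computed for accepted tokens during its verification pass (avoid recomputation); direct reuse of draft KV by the verifier requires shared layers (e.g., early-exit self-speculative drafting)
Saves \(\sim\)20-30% of verifier compute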
Mixture of Experts Drafting:
Train multiple domain-specific draft models
Router selects draft based on context
Improves acceptance on diverse workloads
27 KV Cache Management
27.1 The KV Cache Problem
Memory Growth:
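Per token, per layer: 2 tensors (K, V) of size \((n_{\text{kv\_heads}}, d_{\text{head}})\), where \(n_{\text{kv\_heads}} = n_{\text{heads}}\) unless the model uses grouped-query attention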
7B model (32 layers, 32 heads, \(d_{\text{head}}=128\)): \(2 \times 32 \times 32 \times 128 \times 2\text{ bytes (FP16)} = 524\)KB per token
2048 token context: \(\sim\)1GB KV cache; at 8K context \(\sim\)4GB, comparable to the weights of a 7B model at INT4 (\(\sim\)3.5GB)
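The arithmetic above fits in a small calculator. The grouped-query line (8 KV heads) is an illustrative assumption, not a claim about a specific model.

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, n_tokens, bytes_per_elem=2, batch=1):
    """KV cache size: 2 tensors (K and V) per layer, each n_kv_heads x d_head
    per token. With grouped-query attention, pass the number of KV heads."""
    per_token = 2 * n_layers * n_kv_heads * d_head * bytes_per_elem
    return per_token * n_tokens * batch

# 7B-class config from the text: 32 layers, 32 KV heads, d_head = 128, FP16
print(kv_cache_bytes(32, 32, 128, n_tokens=1))      # bytes per token
print(kv_cache_bytes(32, 32, 128, 2048) / 2**30)    # GiB at 2048 tokens
print(kv_cache_bytes(32, 8, 128, 2048) / 2**30)     # GiB with 8 KV heads (GQA)
```

The per-token figure is 524,288 bytes, 2048 tokens give exactly 1 GiB, and cutting KV heads from 32 to 8 shrinks the cache by 4×, which is why GQA matters for on-device context length.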
Fixed Memory Budget:
Mobile: 4-8GB total, \(\sim\)2GB for KV cache
Edge server: 16-32GB, \(\sim\)10GB for KV cache
Must evict or compress when budget exceeded
27.2 Eviction Policies
Sliding Window:
Keep most recent \(W\) tokens (e.g., \(W=2048\))
Discard oldest tokens when \(W\) exceeded
Simple but loses long-range context
Use with RoPE scaling or ALiBi to reduce positional encoding drift
Chunked Attention:
Divide context into chunks (e.g., 512 tokens each)
Keep full KV for recent chunks, compressed summaries for old chunks
Summary: average-pool K/V across chunk → \(1\) representative KV per chunk
Trade-off: some quality loss, but \(512\times\) compression for old context (one entry per 512-token chunk)
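A minimal NumPy sketch of chunked compression with mean pooling, assuming one layer's cache is stored as `(T, n_heads, d_head)` arrays; the chunk size and number of full-resolution recent chunks are illustrative defaults. How positions (e.g., RoPE) should be assigned to pooled entries is deliberately left open here.

```python
import numpy as np

def compress_kv(keys, values, chunk=512, n_recent_chunks=2):
    """Keep the most recent chunks at full resolution and replace each older
    chunk by its mean K and mean V (one summary entry per chunk).

    keys, values: (T, n_heads, d_head) arrays for one layer.
    Returns compressed (T', n_heads, d_head) keys and values.
    """
    T = keys.shape[0]
    n_full = min(T, n_recent_chunks * chunk)   # tokens kept exactly
    old_k, old_v = keys[: T - n_full], values[: T - n_full]
    summaries_k, summaries_v = [], []
    for start in range(0, old_k.shape[0], chunk):
        # Average-pool each old chunk into a single representative entry
        summaries_k.append(old_k[start:start + chunk].mean(axis=0))
        summaries_v.append(old_v[start:start + chunk].mean(axis=0))
    if not summaries_k:
        return keys, values                    # nothing old enough to pool
    k_out = np.concatenate([np.stack(summaries_k), keys[T - n_full:]])
    v_out = np.concatenate([np.stack(summaries_v), values[T - n_full:]])
    return k_out, v_out
```

Dropping the old part instead of pooling it turns this into the plain sliding window described earlier; pooling keeps a coarse trace of long-range context at the cost of one entry per chunk.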
Importance-Based Eviction:
Track attention scores: which tokens are attended to most
Evict tokens with low cumulative attention
More complex but preserves important tokens (e.g., subject, entities)
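A minimal sketch of importance-based eviction, similar in spirit to heavy-hitter (H2O-style) policies: accumulate the attention each cached token receives and evict the lowest-scoring one when over budget, while protecting a few recent tokens. The class and its parameters are hypothetical; attention weights are assumed to be already summed over heads.

```python
import numpy as np

class ImportanceKVCache:
    """Tracks cumulative attention per cached token and evicts the
    least-attended token (outside the protected recent window)."""

    def __init__(self, budget, n_recent=4):
        assert budget > n_recent, "budget must exceed the protected window"
        self.budget = budget
        self.n_recent = n_recent
        self.positions = []          # original positions still in the cache
        self.scores = np.zeros(0)    # cumulative attention per cached token

    def add_token(self, pos):
        self.positions.append(pos)
        self.scores = np.append(self.scores, 0.0)

    def observe_attention(self, attn):
        """attn: newest query's attention over cached tokens,
        shape (len(self.positions),)."""
        self.scores += attn
        self._evict()

    def _evict(self):
        while len(self.positions) > self.budget:
            candidates = len(self.positions) - self.n_recent
            victim = int(np.argmin(self.scores[:candidates]))
            del self.positions[victim]
            self.scores = np.delete(self.scores, victim)
```

In a real kernel the positions list would index the physical K/V rows to drop; the recent-token guard exists because new tokens have not yet had a chance to accumulate attention.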
Sparse Attention Patterns:
Only cache tokens that will be attended to (e.g., sliding window + landmarks)
Reduces cache size but requires model trained with sparse attention
27.3 KV Cache Quantization
FP16 → FP8:
Use E4M3 format (wide enough for KV activations)
\(2\times\) memory savings with \(<\)0.5% perplexity drop
Native support on H100, AMD MI300
FP16 → INT8:
Per-channel or per-token quantization
Store scales alongside cache (small overhead)
\(2\times\) memory savings; \(\sim\)1% perplexity drop
Re-quantize every \(N\) tokens to prevent scale drift (e.g., \(N=128\))
Dynamic Re-quantization:
Compute scales online as tokens arrive
Trade latency for memory (re-quant overhead \(\sim\)5-10% per token)
Critical for long contexts where fixed scales drift
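A minimal sketch of the per-token INT8 option: each new K (or V) row is quantized on arrival with its own symmetric scale, so scales are computed online and never go stale. Per-channel scales, by contrast, are shared across tokens and are what the periodic re-quantization advice targets. The cache class is hypothetical and stores one layer and one of K/V, flattened to width `d`.

```python
import numpy as np

def quantize_per_token(x):
    """Symmetric INT8 quantization, one scale per row (token).
    x: (n_tokens, d) float array -> (int8 codes, float32 scales of shape (n, 1))."""
    amax = np.abs(x).max(axis=-1, keepdims=True)
    scale = np.where(amax > 0, amax / 127.0, 1.0).astype(np.float32)
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    return q.astype(np.float32) * scale

class Int8KVAppendCache:
    """Hypothetical append-only INT8 cache with a scale stored per token."""

    def __init__(self, d):
        self.q = np.zeros((0, d), dtype=np.int8)
        self.s = np.zeros((0, 1), dtype=np.float32)

    def append(self, row):
        q, s = quantize_per_token(row[None, :])
        self.q = np.concatenate([self.q, q])
        self.s = np.concatenate([self.s, s])

    def read(self):
        return dequantize(self.q, self.s)
```

Storage is 1 byte per element plus one 4-byte scale per token row, roughly half of FP16; the rounding error of each element is bounded by half its row's scale.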
27.4 Paged Attention
Problem: Naive cache grows contiguously → large reallocations when exceeding capacity.
Solution: Manage cache as fixed-size pages (e.g., 16-64 tokens per page).
Benefits:
No large memcpy when evicting–just unlink page
Reduces memory fragmentation
Enables efficient multi-request batching (different sequences share page pool)
Used in vLLM, TensorRT-LLM
Implementation:
Allocate page pool at startup (e.g., 1000 pages \(\times\) 512KB = 500MB)
Each sequence has indirection table: logical token index → physical page
Attention kernel reads KV via indirection (minor overhead)
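The indirection table can be sketched in a few lines. This is a simplified, hypothetical allocator illustrating the idea (fixed page pool, per-sequence block table, free-by-unlinking); it is not vLLM's or TensorRT-LLM's implementation and tracks only slot locations, not the K/V data itself.

```python
class PagePool:
    """Fixed pool of physical pages plus a per-sequence block table
    mapping logical page index -> physical page id."""

    def __init__(self, n_pages, page_size=16):
        self.page_size = page_size
        self.free = list(range(n_pages))  # free physical page ids
        self.tables = {}                  # seq_id -> list of physical pages
        self.lengths = {}                 # seq_id -> number of tokens

    def append_token(self, seq_id):
        """Reserve a slot for one new token; return (physical_page, offset)."""
        table = self.tables.setdefault(seq_id, [])
        n = self.lengths.get(seq_id, 0)
        if n % self.page_size == 0:       # current page full (or none yet)
            if not self.free:
                raise MemoryError("page pool exhausted: evict or preempt")
            table.append(self.free.pop())
        self.lengths[seq_id] = n + 1
        return self.locate(seq_id, n)

    def locate(self, seq_id, token_idx):
        """Indirection used by the attention kernel: logical -> physical."""
        page = self.tables[seq_id][token_idx // self.page_size]
        return page, token_idx % self.page_size

    def free_sequence(self, seq_id):
        """Release all pages of a finished sequence: no memcpy, just unlink."""
        self.free.extend(self.tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)
```

Because sequences only ever hold whole pages, fragmentation is limited to the unused tail of each sequence's last page.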
27.5 Impact on Speculative Decoding
KV Cache Sharing:
Draft and verifier share prefix KV (prompt encoding)
Draft generates \(k\) tokens → KV grows by \(k\)
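Verifier computes its own KV for all \(k\) draft tokens during the verification pass, keeps the entries for accepted tokens, and truncates the rest; it cannot reuse the draft's KV unless the draft shares the verifier's layers (e.g., early-exit self-speculation)
Lookahead optimization: keep the draft's KV for accepted tokens and truncate only rejected positions, so the draft never recomputes its prefix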
Scale Drift:
If using quantized KV cache, scale drift over long context can cause divergence between draft and verifier
Mitigate: re-quantize synchronously every \(N\) tokens
Or: use FP8/FP16 cache for speculative decoding (accept higher memory cost)
28 Software Frameworks & Implementation
28.1 PyTorch / torch.quantization
Dynamic Quantization (Post-Training):
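```python
import torch
import torch.quantization as tq

# `model` is an FP32 torch.nn.Module (e.g., a loaded transformer).
# Dynamic PTQ: Linear weights stored as INT8; activations are quantized
# on the fly at inference time (CPU backends)
model_quantized = tq.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
```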
Static Quantization (PTQ with Calibration):
```python
import torch

# 1. Insert observers (eager mode also needs QuantStub/DeQuantStub in the model)
model.eval()
model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
model_prepared = torch.quantization.prepare(model)
# 2. Calibrate on representative data
with torch.no_grad():
    for data in calibration_loader:
        model_prepared(data)
# 3. Convert to quantized model
model_quantized = torch.quantization.convert(model_prepared)
```
Quantization-Aware Training (QAT):
```python
import torch

# Insert fake-quant nodes (prepare_qat expects the model in training mode)
model.train()
model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
model_prepared = torch.quantization.prepare_qat(model)
# Fine-tune with fake quantization; `train` is your usual training loop
for epoch in range(num_epochs):
    train(model_prepared, train_loader)
# Convert to actual quantized model
model_prepared.eval()
model_quantized = torch.quantization.convert(model_prepared)
```
Key Points:
`fbgemm` backend: x86 CPUs with AVX2/AVX512
`qnnpack` backend: ARM mobile CPUs
Use `torch.ao.quantization` for the newer API (PyTorch 2.0+)
Per-channel quantization: set `qconfig` with `per_channel_symmetric`
28.2 Hugging Face Transformers + Optimum
Load Pre-Quantized Model (GPTQ/AWQ):
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load 4-bit quantized model (GPTQ)
model = AutoModelForCausalLM.from_pretrained(
    "TheBloke/Llama-2-7B-GPTQ",
    device_map="auto",
    trust_remote_code=False,
    revision="gptq-4bit-32g-actorder_True"
)
# Or AWQ quantized
model = AutoModelForCausalLM.from_pretrained(
    "TheBloke/Llama-2-7B-AWQ",
    device_map="auto"
)
```
Quantize On-The-Fly (BitsAndBytes):
```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# 8-bit quantization (LLM.int8())
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto"
)
# 4-bit quantization with NF4
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",  # Normal Float 4-bit
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True  # Double quantization (quantizes the scales too)
)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=quantization_config,
    device_map="auto"
)
```
Optimum for ONNX Export + Quantization:
```python
from optimum.onnxruntime import ORTModelForCausalLM, ORTQuantizer
from optimum.onnxruntime.configuration import AutoQuantizationConfig

# Export to ONNX, then apply dynamic INT8 quantization
# (is_static=True would additionally require a calibration step)
model = ORTModelForCausalLM.from_pretrained("model_name", export=True)
qconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)
quantizer = ORTQuantizer.from_pretrained(model)
quantizer.quantize(save_dir="./quantized_model", quantization_config=qconfig)
```
28.3 TensorRT-LLM (NVIDIA)
Build Quantized Engine:
```bash
# Convert Hugging Face model to TRT-LLM checkpoint
python convert_checkpoint.py --model_dir llama-7b-hf \
    --output_dir ./trt_ckpt \
    --dtype float16 \
    --use_weight_only \
    --weight_only_precision int4
# Build TensorRT engine with INT4 weights
trtllm-build --checkpoint_dir ./trt_ckpt \
    --output_dir ./trt_engine \
    --gemm_plugin float16 \
    --max_batch_size 8 \
    --max_input_len 2048 \
    --max_output_len 512
```
Supported Quantization:
FP8 (H100): `--quantization fp8`
INT8 (SmoothQuant): `--use_smooth_quant`
INT4 weights: `--use_weight_only --weight_only_precision int4`
AWQ: `--use_weight_only --per_group` with pre-computed scales
28.4 vLLM (High-Throughput Serving)
Launch Server with Quantization:
# AWQ quantized model
python -m vllm.entrypoints.openai.api_server \
--model TheBloke/Llama-2-7B-AWQ \
--quantization awq \
--dtype float16 \
--max-model-len 4096
# SqueezeLLM (INT4 with sensitivity-based scaling)
python -m vllm.entrypoints.openai.api_server \
--model squeeze-ai-lab/llama-7b-squeezellm \
--quantization squeezellm \
--dtype float16
Features:
Paged attention for KV cache management
Continuous batching for high throughput
Supports AWQ, GPTQ, SqueezeLLM quantization
FP8 KV cache quantization (experimental)
28.5 Megatron-LM (Large-Scale Training + Inference)
Quantization in Megatron-Core:
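```bash
# Launch with FP8 training/inference (H100, Transformer Engine); model and
# data arguments omitted. --fp8-format enables FP8 (hybrid = E4M3 forward,
# E5M2 gradients); the remaining flags tune delayed scaling
python pretrain_gpt.py \
    --fp8-format hybrid \
    --fp8-margin 0 \
    --fp8-amax-compute-algo max \
    --fp8-amax-history-len 1024 \
    --transformer-impl transformer_engine
# INT8 inference with TensorRT (flag names differ across Megatron-LM
# versions; check tools/checkpoint_util.py --help)
python tools/checkpoint_util.py \
    --model-type GPT \
    --load /path/to/checkpoint \
    --save-dir /path/to/trt_model \
    --target-tensor-parallel-size 1 \
    --target-pipeline-parallel-size 1 \
    --quantize int8
```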
Tensor Parallelism with Quantization:
Quantization applied after parallelism splits (each rank quantizes its shard)
Use same calibration data across ranks for consistent scales
FP8 training via Transformer Engine (NVIDIA library)
28.6 llama.cpp (CPU/Mobile Inference)
Quantize GGUF Model:
```bash
# Convert Hugging Face to GGUF format
# (newer llama.cpp releases rename these tools: convert_hf_to_gguf.py,
#  llama-quantize, llama-cli)
python convert.py --outfile llama-7b.gguf llama-7b-hf/
# Quantize to different precisions
./quantize llama-7b.gguf llama-7b-Q4_K_M.gguf Q4_K_M  # INT4 medium
./quantize llama-7b.gguf llama-7b-Q5_K_S.gguf Q5_K_S  # INT5 small
./quantize llama-7b.gguf llama-7b-Q8_0.gguf Q8_0      # INT8
# Run inference
./main -m llama-7b-Q4_K_M.gguf -p "Hello world" -n 128
```
Quantization Schemes:
Q4_K_M: INT4 with per-group scales, medium quality
Q5_K_S: INT5 (5-bit), "S" (small) size variant, better than Q4
Q8_0: INT8, near-FP16 quality
K variants: k-quants with importance-based bit allocation
28.7 MLC-LLM (Mobile Deployment)
Compile for Mobile with Quantization:
```bash
# Compile model for iPhone/Android with INT4
# (CLI syntax changes between MLC-LLM releases; check `mlc_llm compile --help`)
mlc_llm compile \
    --model llama-7b-hf \
    --quantization q4f16_1 \
    --target iphone \
    --output ./dist/llama-7b-q4-iphone
# Quantization formats:
# q4f16_1: INT4 weights, FP16 activations
# q3f16_1: INT3 weights, FP16 activations
# q0f16: FP16 (no quantization, baseline)
```
iOS/Android Integration:
Uses Metal (iOS) or Vulkan (Android) for GPU acceleration
INT4 weights stored in texture memory
KV cache in FP16, dynamically allocated
28.8 Practical Tips
Choosing Framework:
NVIDIA GPU server: TensorRT-LLM or vLLM (highest throughput)
CPU/edge server: llama.cpp or ONNX Runtime (broad compatibility)
Mobile (iOS/Android): MLC-LLM or Core ML (optimized for ANE/NPU)
Research/prototyping: Hugging Face Transformers + BitsAndBytes (easiest)
Calibration Data:
Use 128-512 samples from target domain
Mix of short and long sequences (cover range of activations)
For PTQ: collect activations during forward pass, compute scales
Store scales alongside weights (per-channel or per-group)
Debugging Quantization Issues:
Compare FP16 vs quantized logits on the same input: mean per-token KL divergence should be small (e.g., below \(\sim\)0.01 nats)
Check for layers with unusually large scale factors (outliers)
Profile memory: quantized model should match expected reduction
Measure perplexity on validation set–flag if \(>\)2% degradation
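A minimal NumPy sketch of the logit comparison, assuming both models' logits for the same input are available as `(batch, seq_len, vocab)` arrays; the helper names are hypothetical and any pass/fail threshold is a judgment call.

```python
import numpy as np

def log_softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)          # numerical stability
    return z - np.log(np.exp(z).sum(axis=axis, keepdims=True))

def mean_token_kl(logits_ref, logits_quant):
    """Mean per-token KL(p_ref || p_quant) in nats."""
    lp = log_softmax(np.asarray(logits_ref, dtype=np.float64))
    lq = log_softmax(np.asarray(logits_quant, dtype=np.float64))
    kl = (np.exp(lp) * (lp - lq)).sum(axis=-1)       # KL per token position
    return float(kl.mean())

def top1_agreement(logits_ref, logits_quant):
    """Fraction of positions where both models pick the same argmax token."""
    return float((np.argmax(logits_ref, -1) == np.argmax(logits_quant, -1)).mean())
```

Because softmax is shift-invariant, a constant offset in the quantized logits yields zero KL; the metric only reacts to changes in the relative ranking and spread of token scores.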
29 Deployment Playbook
29.1 Step-by-Step Recipe
1. Profile Target Hardware:
Measure memory bandwidth, compute throughput (TOPS for INT, TFLOPS for FP), power budget
Identify supported formats (INT8, INT4, FP8) and kernel libraries
Benchmark baseline FP16 model: tokens/sec, memory usage, power draw
2. Choose Quantization Strategy:
Start with INT8 weights (RTN or AWQ) + FP16 activations
Validate perplexity on a held-out validation set, not the calibration data (\(<\)1% degradation acceptable)
If memory-constrained: move to INT4 weights (use GPTQ or AWQ)
If still constrained: INT2 MLP weights + QAT + distillation
3. Layer-Wise Precision Assignment:
Embeddings, LM head: FP16 or INT8 (careful calibration)
Attention: INT8 or INT4 weights, FP16 activations
MLP: INT4 or INT2 weights, INT8 activations
Norms: FP16 always
KV cache: FP8 or INT8 per-channel
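One way to encode the assignment above is a first-match rule table over module names. The patterns assume Llama-style names (`embed_tokens`, `self_attn.q_proj`, `mlp.down_proj`, `*_layernorm`), and picking INT8 for attention and INT4 for MLP is one choice within the ranges listed; adapt both to your model. The KV cache precision is configured separately.

```python
import re

# Hypothetical rules keyed on Llama-style module names. First match wins;
# values are (weight precision, activation precision).
PRECISION_RULES = [
    (r"(embed_tokens|lm_head)", ("int8", "fp16")),      # careful calibration
    (r"norm", ("fp16", "fp16")),                         # norms: FP16 always
    (r"self_attn\.(q|k|v|o)_proj", ("int8", "fp16")),    # attention projections
    (r"mlp\.(gate|up|down)_proj", ("int4", "int8")),     # most tolerant layers
]

def precision_for(module_name: str, default=("fp16", "fp16")):
    """Return (weights, activations) precision for a module name."""
    for pattern, prec in PRECISION_RULES:
        if re.search(pattern, module_name):
            return prec
    return default  # anything unmatched stays in FP16
```

Defaulting unmatched modules to FP16 keeps the plan conservative: a layer only drops precision when a rule explicitly names it.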
4. Distill Draft Model (if using speculative decoding):
Target \(5\)-\(10\times\) faster than verifier
Distill with high KL weight, match inference temperature
Quantize draft to INT8/INT4
Validate acceptance rate \(>\)70% on test prompts
5. Implement KV Cache Management:
Use paged attention (vLLM or custom implementation)
Set cache budget (e.g., 2GB on mobile)
Choose eviction: sliding window or chunked attention
Quantize cache to FP8/INT8 if memory-critical
6. Optimize Kernels:
Fuse matmul-dequant operations (avoid intermediate FP16 buffers)
Use vendor libraries: TensorRT-LLM (NVIDIA), Core ML (Apple), NNAPI (Android)
Profile hotspots: typically attention softmax, large matmuls, KV cache access
7. Benchmark & Iterate:
Measure: latency (ms/token), throughput (tokens/sec), power (watts), memory (GB)
Compare vs FP16 baseline and targets (e.g., \(<\)200ms/token on phone, \(<\)5W sustained)
If targets not met: more aggressive quantization, smaller model, or cloud offload
29.2 Common Pitfalls
Quantizing softmax/norms: Causes instability–always keep FP16
Per-tensor quantization for attention: Outlier channels dominate–use per-channel
Ignoring KV cache growth: Can exceed model size–plan eviction early
Mismatched draft-verifier temperatures: Kills acceptance rate–tune carefully
Not profiling memory bandwidth: Quantization saves memory but may not improve latency if compute-bound
Skipping calibration data: calibration-free RTN, or scales fitted on unrepresentative data, degrade on the target distribution; collect representative samples
30 Interview Cheat Sheet
“What’s the difference between PTQ and QAT?”
PTQ: Quantize a pre-trained model without retraining; fast, typically \(\sim\)1% perplexity drop at INT8
QAT: Fine-tune with fake-quant during training–recovers quality at INT4/INT2, costs compute
“Explain AWQ vs GPTQ.”
AWQ: Protects outlier channels based on activation magnitudes–fast, good for INT4
GPTQ: Uses Hessian (second-order) to optimize weight rounding–slower, slightly better accuracy
“When do you use INT2?”
MLP weight matrices with per-group quantization (group size 64-128)
Requires QAT + distillation to recover quality
Only on hardware with native INT2 support (ASICs, Hexagon DSP)–otherwise overhead negates gains
“What’s FP8 E4M3 vs E5M2?”
E4M3: 4-bit exp, 3-bit mantissa–better precision, narrower range (activations/KV cache)
E5M2: 5-bit exp, 2-bit mantissa–wider range, coarser precision (gradients/outliers)
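The range difference can be derived from the bit layouts. A small sketch computing the largest finite value of each format: E5M2 follows IEEE conventions (all-ones exponent reserved for inf/NaN), while the common E4M3 variant ("fn", as in `torch.float8_e4m3fn`) reuses the all-ones exponent and reserves only the all-ones mantissa there for NaN.

```python
def fp8_max(exp_bits: int, man_bits: int, ieee_like: bool) -> float:
    """Largest finite value of an FP8 format with the given field widths.

    ieee_like=True  (E5M2): all-ones exponent reserved for inf/NaN.
    ieee_like=False (E4M3 "fn"): all-ones exponent usable; only its
    all-ones mantissa encodes NaN, and there is no inf.
    """
    bias = 2 ** (exp_bits - 1) - 1
    if ieee_like:
        max_exp = (2 ** exp_bits - 2) - bias
        max_man = 2 ** man_bits - 1
    else:
        max_exp = (2 ** exp_bits - 1) - bias
        max_man = 2 ** man_bits - 2
    return 2.0 ** max_exp * (1 + max_man / 2 ** man_bits)

print(fp8_max(4, 3, ieee_like=False))  # E4M3 max
print(fp8_max(5, 2, ieee_like=True))   # E5M2 max
```

E4M3 tops out at 448 while E5M2 reaches 57344, which is why E5M2 is the usual choice for gradients and E4M3 (with per-tensor scaling) for weights, activations, and KV cache.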
“How does speculative decoding work?”
Small draft model generates \(k\) tokens, large verifier validates in one pass
Acceptance via importance sampling: \(\alpha = \min(1, p_v / p_d)\)
Speedup \(\approx k \times\) acceptance rate (typically \(2\)-\(5\times\))
“Why distill for speculative decoding?”
Aligns draft distribution with verifier–raises acceptance rate from \(\sim\)50% to \(\sim\)80%
Every 10% acceptance gain adds \(\sim\)0.5\(\times\) speedup
“How do you manage KV cache on mobile?”
Quantize to FP8/INT8 (\(2\times\) savings)
Use sliding window (keep last 2048 tokens) or chunked attention (compress old chunks)
Paged attention to avoid large reallocations
“Which layers should never be quantized?”
Attention softmax (numerical instability)
LayerNorm/RMSNorm (stabilizes residuals)
Keep FP16 or use very careful INT8 calibration for embeddings and LM head
“SmoothQuant in one sentence.”
Migrates activation outliers into weights via channel-wise scaling–enables INT8 activations
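A minimal NumPy sketch of the smoothing step, using the per-channel factor \(s_j = \max|X_j|^{\alpha} / \max|W_j|^{1-\alpha}\) with migration strength \(\alpha\) (0.5 is a common starting point). The layout `X @ W` with `W` shaped `(in_features, out_features)` is an assumption of this sketch; in a real model the division by \(s\) is folded into the preceding LayerNorm so it costs nothing at runtime.

```python
import numpy as np

def smooth(X, W, alpha=0.5):
    """SmoothQuant-style per-input-channel smoothing.

    X: (tokens, in_features) calibration activations
    W: (in_features, out_features) weights
    Returns X' = X / s, W' = s * W (so X' @ W' == X @ W) and the factors s.
    """
    act_max = np.abs(X).max(axis=0)                   # per-channel activation range
    w_max = np.abs(W).max(axis=1)                     # per-channel weight range
    s = act_max ** alpha / np.maximum(w_max, 1e-8) ** (1 - alpha)
    s = np.maximum(s, 1e-8)                           # guard all-zero channels
    return X / s, W * s[:, None], s
```

The product is unchanged, but activation outlier channels shrink while the corresponding weight rows grow, so both operands become easier to quantize to INT8.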
“What’s the memory bottleneck in LLM inference?”
DRAM bandwidth–7B FP16 model requires \(\sim\)14GB read per token
Quantization reduces memory traffic → lower latency (e.g., INT8 = \(2\times\) less weight traffic than FP16, INT4 = \(4\times\))
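A back-of-the-envelope sketch of the bandwidth bound: at batch size 1, every weight is read once per generated token, so decode speed cannot exceed bandwidth divided by weight bytes. The 50 GB/s figure is an assumed device bandwidth, and the bound ignores KV cache traffic and compute.

```python
def max_tokens_per_sec(n_params: float, bits_per_weight: float, bandwidth_gb_s: float) -> float:
    """Upper bound on batch-1 decode speed when every weight is read from
    DRAM once per token (ignores KV cache traffic and compute)."""
    bytes_per_token = n_params * bits_per_weight / 8
    return bandwidth_gb_s * 1e9 / bytes_per_token

# 7B model on a device with an assumed 50 GB/s of memory bandwidth
for bits in (16, 8, 4):
    print(f"{bits}-bit weights: <= {max_tokens_per_sec(7e9, bits, 50.0):.1f} tokens/s")
```

Halving the bits per weight doubles the ceiling, which is why weight-only INT4 helps decode latency even when the matmuls still run in FP16.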
“Typical quantization config for mobile?”
INT8 weights (attention + MLP), FP16 activations
Or: INT4 weights (AWQ), FP16 activations for \(<\)2GB target
KV cache: INT8 per-channel or FP16 if memory allows
Norms/softmax: FP16 always
31 Summary
Key Takeaways:
Quantization: Start INT8 (AWQ/GPTQ), move to INT4 if memory-constrained, INT2 only with QAT + specialized hardware
Layer-wise precision: MLP weights most tolerant (INT4/INT2), attention moderate (INT8/INT4), norms/softmax always FP16
Hardware matters: Match precision to native support (INT8 ubiquitous, INT4 on GPUs/NPUs, INT2 rare)
Distillation: Creates efficient drafts for speculative decoding; aligns distributions for high acceptance rates
Speculative decoding: \(2\)-\(5\times\) speedup via draft-verify; importance sampling ensures unbiased output
KV cache: Grows linearly with context; quantize to FP8/INT8, use paged attention, implement eviction (sliding window/chunked)
Production: Profile hardware first, iterate on precision, fuse kernels, monitor power/thermal
Modern Best Practices (2024):
Mobile (4-8GB): INT4 weights (AWQ), FP16 activations, INT8 KV cache, sliding window
Edge GPU (16-32GB): INT8 weights, FP8 cache, speculative decoding (1B draft + 7B verifier)
Tiny IoT (\(<\)1GB): INT2 MLP weights, INT8 attention, distilled \(<\)1B model, aggressive eviction
For questions, corrections, or suggestions: peymanr@gmail.com