17 Chapter 16: Training Optimization & Parallelization Strategies
18 Introduction: The Scale Challenge
Modern large language models (LLMs) have grown from millions to hundreds of billions of parameters:
GPT-2 (2019): 1.5B params, trained on 40GB text
GPT-3 (2020): 175B params, \(\sim\)350GB just for weights in fp16 (\(\sim\)700GB in fp32)
LLaMA-2 70B (2023): 70B params, 2T tokens
GPT-4 (2023): Estimated 1.7T params (MoE), multi-modal
Key Challenges:
Memory: A 70B model requires \(\sim\)140GB just for fp16 weights–exceeds single GPU
Compute: Training on trillions of tokens takes months on thousands of GPUs
Communication: Gradient synchronization becomes bottleneck at scale
Efficiency: Memory overhead (optimizer states, activations) can be 10-20\(\times\) model size
This Document Covers:
Data Parallelism: DDP, gradient synchronization, FSDP, ZeRO stages
Model Parallelism: Tensor parallelism, pipeline parallelism, 3D parallelism
Memory Optimization: Mixed precision, gradient accumulation, activation checkpointing
Communication: All-reduce, reduce-scatter, gradient compression
Frameworks: PyTorch FSDP, DeepSpeed, Megatron-LM, Ray
19 Data Parallelism
19.1 Concept: Replicate Model, Split Data
Data Parallelism (DP) replicates the model on each GPU and splits the training batch across devices.
Forward Pass:
Each GPU holds a full copy of model parameters \(\theta\)
Global batch \(B\) split into \(N\) micro-batches: \(B = \{B_1, B_2, \ldots, B_N\}\)
GPU \(i\) processes \(B_i\) independently → computes loss \(\mathcal{L}_i\) and gradients \(g_i\)
Backward Pass:
Each GPU computes local gradients: \(g_i = \nabla_\theta \mathcal{L}_i(\theta)\)
Synchronize gradients via All-Reduce: \(g = \frac{1}{N} \sum_{i=1}^N g_i\)
Update parameters: \(\theta_{t+1} = \theta_t - \eta g\)
Key Insight: All-Reduce ensures every GPU has the same averaged gradient before updating. This keeps model replicas in sync.
Note: This is data parallelism–each GPU has a full copy of the model and processes different data. For models too large to fit on one GPU, use model parallelism (tensor/pipeline parallelism, covered later) where the model itself is split across GPUs.
19.2 DataParallel (DP) vs DistributedDataParallel (DDP)
DataParallel (DP):
Single-process, multi-threaded (Python GIL bottleneck)
Master GPU gathers gradients from all GPUs → broadcasts updated params
Memory imbalance: master GPU stores full batch + optimizer state
Slow: Communication is sequential; limited by GIL
DistributedDataParallel (DDP):
Multi-process: one process per GPU (avoids GIL)
No master GPU: All-Reduce distributes gradients evenly
Uses NCCL (NVIDIA Collective Communications Library) for efficient gradient sync
Faster: Up to 2-3\(\times\) speedup over DP
PyTorch DDP Setup:
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
# Initialize process group (one process per GPU, launched e.g. via torchrun)
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
model = MyModel().to(device)
model = DDP(model, device_ids=[local_rank])
# Training loop
for batch in dataloader:
    optimizer.zero_grad()
    loss = model(batch)
    loss.backward()  # Gradients auto-synced via All-Reduce
    optimizer.step()
19.3 All-Reduce Algorithm
All-Reduce is a collective communication operation where each GPU contributes local gradients \(g_i\) and receives the global average \(g = \frac{1}{N} \sum_{i=1}^N g_i\).
Ring All-Reduce (Bandwidth-Optimal):
Arrange \(N\) GPUs in a logical ring
Reduce-Scatter Phase: Each GPU sends chunks to neighbors; gradually accumulate sums
All-Gather Phase: Each GPU sends accumulated chunks to neighbors; reconstruct full result
Complexity Analysis:
Consider \(N\) GPUs, each with gradient data of size \(M\) bytes.
Ring All-Reduce Mechanics:
Split each GPU’s data into \(N\) chunks of size \(M/N\)
Reduce-Scatter: \(N-1\) communication rounds → each GPU sends \((N-1) \cdot \frac{M}{N}\) bytes
All-Gather: \(N-1\) communication rounds → each GPU sends \((N-1) \cdot \frac{M}{N}\) bytes
Total data sent per GPU: \(2(N-1) \cdot \frac{M}{N} = \frac{2(N-1)M}{N}\)
Time Complexity:
Communication time: \(\frac{2(N-1)M}{N \cdot \text{bandwidth}} \approx \frac{2M}{\text{bandwidth}}\) for large \(N\)
As \(N \to \infty\), \(\frac{N-1}{N} \to 1\) → time approaches \(\frac{2M}{\text{bandwidth}}\)
Key insight: Communication cost is independent of \(N\)–scales perfectly!
Practical Implications:
Bandwidth-optimal: Each GPU sends/receives \(\approx 2M\) total (minimal theoretical limit)
Latency: \(2(N-1)\) sequential message steps → can hurt small models where latency dominates
Example: 8 GPUs, 1GB gradients → each GPU sends 1.75GB, takes \(\approx\) 14 message hops
Modern Practice: NCCL implements optimized All-Reduce using tree-based algorithms for small \(N\) (lower latency) and ring-based for large \(N\) (bandwidth-optimal).
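To make the two phases and the \(\frac{2(N-1)M}{N}\) volume concrete, here is a minimal pure-Python simulation of Ring All-Reduce (illustrative only; NCCL implements this with CUDA kernels and network transports, not Python):
# Minimal ring all-reduce simulation (illustrative only)
def ring_all_reduce(grads_per_gpu):
    """grads_per_gpu: list of N lists, each split into N chunks (one per GPU)."""
    n = len(grads_per_gpu)
    data = [list(g) for g in grads_per_gpu]
    sent_chunks = [0] * n  # chunks sent by each GPU
    # Phase 1: reduce-scatter -- after N-1 rounds, GPU i holds the full sum
    # for chunk (i + 1) % n
    for step in range(n - 1):
        for i in range(n):
            chunk = (i - step) % n              # chunk GPU i forwards this round
            dst = (i + 1) % n
            data[dst][chunk] += data[i][chunk]  # receiver accumulates
            sent_chunks[i] += 1
    # Phase 2: all-gather -- circulate the reduced chunks for N-1 more rounds
    for step in range(n - 1):
        for i in range(n):
            chunk = (i + 1 - step) % n
            dst = (i + 1) % n
            data[dst][chunk] = data[i][chunk]
            sent_chunks[i] += 1
    return data, sent_chunks

# 4 GPUs, each contributing gradient chunks [1, 1, 1, 1]
reduced, sent = ring_all_reduce([[1, 1, 1, 1] for _ in range(4)])
print(reduced[0])  # [4, 4, 4, 4] on every GPU (sum; divide by N for the mean)
print(sent[0])     # 6 chunks = 2*(N-1); with M-byte gradients that is 2*(N-1)*M/N bytes per GPU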
19.4 Limitations of Naive Data Parallelism
Memory Replication: Each GPU stores full model (weights, optimizer states, gradients)
Example: 70B model in fp16 → 140GB \(\times\) 8 GPUs = 1.1TB total memory used
Inefficient for Large Models: Cannot train models larger than single GPU memory
Communication Overhead: All-Reduce latency grows with model size
Solution: ZeRO (Zero Redundancy Optimizer) and FSDP (Fully Sharded Data Parallel)
20 ZeRO: Zero Redundancy Optimizer
20.1 Motivation: Eliminate Memory Redundancy
In standard data parallelism, memory breakdown per GPU for a model with \(P\) parameters:
Model weights: \(2P\) bytes (fp16)
Gradients: \(2P\) bytes (fp16)
Optimizer states (Adam): \(12P\) bytes = fp32 master weights (\(4P\)) + fp32 momentum (\(4P\)) + fp32 variance (\(4P\))
Total: \(16P\) bytes per GPU → fully redundant across GPUs
Example: 7B model → \(7 \times 10^9 \times 16 = 112\) GB per GPU (before activations!)
ZeRO (DeepSpeed) removes this redundancy by sharding optimizer states, gradients, and parameters across GPUs.
20.2 ZeRO Stages
ZeRO-1: Optimizer State Partitioning
Shard optimizer states (momentum, variance) across \(N\) GPUs
Each GPU stores \(\frac{1}{N}\) of optimizer states → updates only its assigned params
Still does full All-Reduce: All GPUs compute and store full gradients (\(2P\) each)
After optimizer step: All-Gather updated params to sync models
Memory Reduction: Optimizer states from \(12P\) to \(\frac{12P}{N}\) (gradients still \(2P\))
ZeRO-2: Gradient Partitioning (builds on ZeRO-1)
Replaces All-Reduce with Reduce-Scatter: Each GPU gets different shard of averaged gradients
After backward: GPU \(i\) receives averaged gradients only for parameters it’s responsible for
Each GPU stores \(\frac{1}{N}\) of gradients (matching its optimizer state partition)
Updates only its assigned parameter shard using its local gradient shard
Memory Reduction: Gradients from \(2P\) to \(\frac{2P}{N}\) + Optimizer \(\frac{12P}{N}\)
Concrete Example: 2 GPUs, 4 parameters \((W_1, W_2, W_3, W_4)\)
ZeRO-1 Workflow:
Backward pass: Both GPUs compute local gradients \((g_1, g_2, g_3, g_4)\)
All-Reduce: Average gradients across GPUs
GPU-1 receives: \((G_1, G_2, G_3, G_4)\) (full averaged gradients)
GPU-2 receives: \((G_1, G_2, G_3, G_4)\) (full averaged gradients)
Optimizer step:
GPU-1 has optimizer states for \((W_1, W_2)\) → updates \((W_1, W_2)\) using \((G_1, G_2)\)
GPU-2 has optimizer states for \((W_3, W_4)\) → updates \((W_3, W_4)\) using \((G_3, G_4)\)
All-Gather: Broadcast updated parameters
GPU-1 sends \((W_1', W_2')\) to GPU-2
GPU-2 sends \((W_3', W_4')\) to GPU-1
Both now have full model: \((W_1', W_2', W_3', W_4')\)
Memory: Both GPUs store all gradients (\(2P\)), but only half the optimizer states (\(6P\) each)
ZeRO-2 Workflow:
Backward pass: Both GPUs compute local gradients \((g_1, g_2, g_3, g_4)\)
Reduce-Scatter: Average and distribute gradient shards
GPU-1 receives: \((G_1, G_2)\) only (averaged, for its assigned params)
GPU-2 receives: \((G_3, G_4)\) only (averaged, for its assigned params)
Optimizer step:
GPU-1 updates \((W_1, W_2)\) using \((G_1, G_2)\)
GPU-2 updates \((W_3, W_4)\) using \((G_3, G_4)\)
All-Gather: Broadcast updated parameters (same as ZeRO-1)
Memory: Each GPU stores half gradients (\(P\)) + half optimizer states (\(6P\) each)
Key Insight: In ZeRO-1, you store all gradients but can discard them after optimizer step. ZeRO-2 never stores the full gradient array, saving memory immediately.
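A tiny pure-Python sketch of the same 2-GPU, 4-parameter example (hypothetical gradient values) makes the collective difference visible; real implementations use NCCL collectives on CUDA tensors:
# Toy 2-GPU, 4-parameter example (plain Python, hypothetical values)
local_grads = {
    "gpu0": [0.2, 0.4, 0.6, 0.8],  # g1..g4 computed on GPU-0's micro-batch
    "gpu1": [0.4, 0.2, 0.2, 0.0],  # g1..g4 computed on GPU-1's micro-batch
}
avg = [(a + b) / 2 for a, b in zip(local_grads["gpu0"], local_grads["gpu1"])]

# ZeRO-1: All-Reduce -- every GPU materializes all 4 averaged gradients,
# then uses only the half matching its optimizer-state shard
zero1_gpu0, zero1_gpu1 = list(avg), list(avg)  # (G1..G4) on both GPUs

# ZeRO-2: Reduce-Scatter -- each GPU only ever holds its own gradient shard
zero2_gpu0, zero2_gpu1 = avg[:2], avg[2:]      # (G1, G2) / (G3, G4)

print(zero1_gpu0, zero2_gpu0)  # [0.3, 0.3, 0.4, 0.4]  vs  [0.3, 0.3]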
ZeRO-3: Parameter Partitioning
Shard model weights: each GPU stores \(\frac{1}{N}\) of parameters
Forward/backward: All-Gather needed params on-the-fly, discard after use
Memory Reduction: Parameters from \(2P\) to \(\frac{2P}{N}\)
Trade-off: Increased communication (All-Gather per layer)
Memory Comparison (7B model, 8 GPUs, fp16):
| Method | Weights | Gradients | Optimizer |
|---|---|---|---|
| Baseline DP | 14 GB | 14 GB | 84 GB |
| ZeRO-1 | 14 GB | 14 GB | 10.5 GB |
| ZeRO-2 | 14 GB | 1.75 GB | 10.5 GB |
| ZeRO-3 | 1.75 GB | 1.75 GB | 10.5 GB |
ZeRO-3 Trade-off: Reduces memory by \(8\times\) but requires All-Gather for every layer. Use ZeRO-2 if communication bandwidth is limited; ZeRO-3 if memory-constrained.
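The table values follow directly from the \(2P/2P/12P\) breakdown; a small sketch of that arithmetic (training state only, activations excluded):
# Per-GPU training-state memory (GB) for mixed-precision Adam
def per_gpu_memory_gb(params, n_gpus, stage):
    """stage: 0 = plain DP, 1/2/3 = ZeRO stage."""
    weights   = 2 * params / (n_gpus if stage >= 3 else 1)
    gradients = 2 * params / (n_gpus if stage >= 2 else 1)
    optimizer = 12 * params / (n_gpus if stage >= 1 else 1)
    return tuple(round(x / 1e9, 2) for x in (weights, gradients, optimizer))

for stage in (0, 1, 2, 3):
    print(stage, per_gpu_memory_gb(7e9, 8, stage))
# 0 (14.0, 14.0, 84.0)
# 1 (14.0, 14.0, 10.5)
# 2 (14.0, 1.75, 10.5)
# 3 (1.75, 1.75, 10.5)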
20.3 ZeRO-Offload: CPU Memory Extension
ZeRO-Offload moves optimizer states to CPU memory, reducing GPU memory further:
Forward/backward on GPU (fast)
Gradients transferred to CPU → optimizer step on CPU → updated params back to GPU
Use Case: Training 10B+ models on consumer GPUs (e.g., RTX 3090 24GB)
DeepSpeed ZeRO-3 Config:
{
"zero_optimization": {
"stage": 3,
"offload_optimizer": {"device": "cpu"},
"offload_param": {"device": "cpu"},
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": 5e8,
"stage3_prefetch_bucket_size": 5e8,
"stage3_param_persistence_threshold": 1e6
}
}
21 Fully Sharded Data Parallel (FSDP)
21.1 PyTorch Native Alternative to ZeRO
FSDP (PyTorch 1.11+) is PyTorch’s native implementation of ZeRO-3 principles:
Shards parameters, gradients, and optimizer states across GPUs
Automatically manages All-Gather/Reduce-Scatter during forward/backward
Integrated with PyTorch APIs (no external dependency like DeepSpeed)
21.2 FSDP Workflow
Forward Pass:
GPU \(i\) holds shard \(\theta_i\) (e.g., layers 0-3 of 32-layer model)
Before processing layer \(k\): All-Gather \(\theta_k\) from all GPUs
Compute activations \(a_k = f_k(a_{k-1}, \theta_k)\)
Discard \(\theta_k\) (free memory)
Backward Pass:
Re-All-Gather \(\theta_k\) to compute gradients \(g_k\)
Reduce-Scatter \(g_k\) → each GPU gets averaged shard
Discard \(\theta_k\) again
Optimizer Step:
- Each GPU updates its shard \(\theta_i\) using local optimizer state
PyTorch FSDP Usage:
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import CPUOffload, MixedPrecision, ShardingStrategy
model = MyModel()
model = FSDP(
    model,
    mixed_precision=MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float32,
        buffer_dtype=torch.float16,
    ),
    cpu_offload=CPUOffload(offload_params=True),
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # ZeRO-3 equivalent
)
# Training loop (identical to DDP)
for batch in dataloader:
    optimizer.zero_grad()
    loss = model(batch)
    loss.backward()
    optimizer.step()
22 Model Parallelism
22.1 When Data Parallelism Fails
Problem: Model too large to fit on a single GPU, even with ZeRO-3.
Example: 175B GPT-3 → \(\sim\)350GB in fp16
ZeRO-3 with 64 GPUs → \(\frac{350}{64} \approx 5.5\) GB per GPU (just for weights)
Activations during forward pass can be \(\sim\)10-50GB per GPU → OOM
Solution: Model Parallelism–split the model itself across GPUs.
ZeRO-3 vs Tensor Parallelism: Both Shard Weights, But Differently!
ZeRO-3 (Data Parallelism with Weight Sharding):
Each GPU stores \(\frac{1}{N}\) of weights (sharded for memory)
During forward: All-Gather to reconstruct full weights temporarily
All GPUs compute the same operations on different data batches
Still data parallelism–just memory-efficient
Tensor Parallelism (Model Parallelism):
Each GPU stores \(\frac{1}{N}\) of weights (sharded for computation)
Each GPU computes different parts of the layer on the same data
Example: GPU-0 computes first half of MLP output, GPU-1 computes second half
Splits both weights and computation
Key Difference: ZeRO-3 shards weights for memory but reconstructs them for computation. Tensor parallelism shards weights and splits the computation itself.
22.2 Tensor Parallelism (Intra-Layer)
Tensor Parallelism splits individual layers across GPUs. Most common for Transformer MLP and attention.
MLP Layer Splitting:
A standard Transformer MLP: \(y = \text{GELU}(xW_1)W_2\) where \(W_1 \in \mathbb{R}^{d \times 4d}\), \(W_2 \in \mathbb{R}^{4d \times d}\).
Column-wise Split of \(W_1\):
Partition \(W_1 = [W_1^{(1)} | W_1^{(2)}]\) across 2 GPUs
Each GPU computes: \(h^{(i)} = \text{GELU}(x W_1^{(i)})\) independently (no communication)
Concatenate: \(h = [h^{(1)} | h^{(2)}]\)
Row-wise Split of \(W_2\):
Partition \(W_2 = \begin{bmatrix} W_2^{(1)} \\ W_2^{(2)} \end{bmatrix}\)
Each GPU computes partial output: \(y^{(i)} = h^{(i)} W_2^{(i)}\)
All-Reduce: Sum \(y = \sum_i y^{(i)}\)
Communication: One All-Reduce per layer (cheap for large \(d\)).
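A single-process PyTorch sketch can verify that the column/row split reproduces the full MLP exactly; in real tensor parallelism each half lives on a different GPU and the final sum becomes an All-Reduce:
# 2-way tensor-parallel MLP, simulated on one device (illustrative sketch)
import torch
import torch.nn.functional as F

d, batch = 8, 4
x  = torch.randn(batch, d)
W1 = torch.randn(d, 4 * d)
W2 = torch.randn(4 * d, d)

# Reference: full MLP on one device
y_ref = F.gelu(x @ W1) @ W2

# "GPU 0" and "GPU 1" shards
W1_a, W1_b = W1.chunk(2, dim=1)   # column split of W1
W2_a, W2_b = W2.chunk(2, dim=0)   # row split of W2

h_a = F.gelu(x @ W1_a)            # computed independently, no communication
h_b = F.gelu(x @ W1_b)
y_partial_a = h_a @ W2_a          # partial outputs
y_partial_b = h_b @ W2_b
y_tp = y_partial_a + y_partial_b  # this sum is the one All-Reduce per MLP block

print(torch.allclose(y_ref, y_tp, atol=1e-5))  # True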
Megatron-LM (NVIDIA) pioneered tensor parallelism for GPT-style models. Used to train GPT-3, MT-NLG (530B params).
22.3 Pipeline Parallelism (Inter-Layer)
Pipeline Parallelism splits the model into stages (layer groups) assigned to different GPUs.
Naive Pipelining (Sequential):
GPU 0: Layers 1-8
GPU 1: Layers 9-16
GPU 2: Layers 17-24
GPU 3: Layers 25-32
Problem: GPU 1 waits for GPU 0 → only 25% utilization (pipeline bubble).
GPipe (Micro-batching):
Split global batch into \(M\) micro-batches
Interleave forward passes: GPU 0 processes micro-batch 1, then GPU 1 starts while GPU 0 processes micro-batch 2
Backward passes flow in reverse: GPU 3 → GPU 2 → GPU 1 → GPU 0
Reduces bubble to \(\frac{N-1}{M}\) where \(N\) = # GPUs, \(M\) = # micro-batches
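As a quick check on that estimate: the exact GPipe bubble fraction is \(\frac{N-1}{M+N-1}\), and \(\frac{N-1}{M}\) is its large-\(M\) approximation. A short sketch of the arithmetic:
# Bubble fraction for an N-stage pipeline with M micro-batches
def bubble_fraction(n_stages, n_microbatches):
    return (n_stages - 1) / (n_microbatches + n_stages - 1)

for m in (4, 8, 16, 64):
    print(f"N=4, M={m:3d} -> bubble {bubble_fraction(4, m):.1%}")
# N=4, M=  4 -> 42.9%;  M=  8 -> 27.3%;  M= 16 -> 15.8%;  M= 64 -> 4.5%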
PipeDream (Asynchronous):
Different micro-batches at different pipeline stages simultaneously
1F1B Schedule: Each stage alternates forward and backward passes (1 Forward, 1 Backward)
Maintains multiple versions of weights to avoid staleness
Pipeline Parallelism Trade-offs:
+ No all-to-all communication (only point-to-point between stages)
- Pipeline bubbles reduce GPU utilization (\(\sim\)80-90% typical)
- Requires careful micro-batch size tuning
22.4 3D Parallelism: Combining All Strategies
Modern LLM training uses 3D Parallelism: Data + Tensor + Pipeline.
Example: 175B GPT-3 on 1024 GPUs (Megatron-LM):
Data Parallelism: 16-way (16 replicas of the full pipeline)
Tensor Parallelism: 8-way (each layer split across 8 GPUs)
Pipeline Parallelism: 8 stages (8 layer groups)
Total: \(16 \times 8 \times 8 = 1024\) GPUs
Communication Hierarchy:
Tensor parallelism: High-bandwidth NVLink within node (fast All-Reduce)
Pipeline parallelism: Point-to-point across nodes (moderate bandwidth)
Data parallelism: All-Reduce across replicas (infrequent, amortized)
23 Mixed Precision Training
23.1 Motivation: Speed and Memory
Standard Training: fp32 (32-bit floating point) for weights, activations, gradients.
Slow: Tensor cores (A100, H100) are optimized for fp16/bf16
Memory-hungry: 2\(\times\) more memory than fp16
Mixed Precision Training uses fp16 for forward/backward, fp32 for optimizer updates.
23.2 FP16 vs BF16
FP16 (IEEE Half-Precision):
1 sign bit, 5 exponent bits, 10 mantissa bits
Range: \(\pm 6.55 \times 10^4\) (limited dynamic range)
Risk: Gradient underflow (small gradients → zero)
BF16 (Brain Float 16):
1 sign bit, 8 exponent bits (same as fp32!), 7 mantissa bits
Range: \(\pm 3.4 \times 10^{38}\) (same as fp32)
Robust: No loss scaling needed
Modern Default: A100, H100, TPUs support bf16 natively
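A two-line experiment shows why fp16 needs loss scaling while bf16 does not (exact printed values depend on rounding):
import torch

g = torch.tensor(1e-10)               # a tiny gradient value
print(g.to(torch.float16))            # 0.0 -- underflows (fp16 min subnormal ~6e-8)
print(g.to(torch.bfloat16))           # ~1e-10 -- bf16 keeps fp32's exponent range
print(g.to(torch.float16) * 2**16)    # still 0.0 -- scaling after the cast is too late
print((g * 2**16).to(torch.float16))  # ~6.6e-06 -- loss-scale before the cast and it survives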
23.3 Mixed Precision Workflow
AMP (Automatic Mixed Precision):
Maintain master copy of weights in fp32
Cast weights to fp16/bf16 for forward pass
Compute loss and gradients in fp16/bf16
(FP16 only) Scale loss to prevent gradient underflow: \(\mathcal{L}' = S \cdot \mathcal{L}\) (typical \(S = 2^{16}\))
Convert gradients back to fp32, unscale (fp16 only)
Update fp32 master weights
Cast updated weights back to fp16/bf16
Memory Overhead of Mixed Precision:
Optimizer (Adam) maintains three fp32 copies: master weights (\(4P\)), momentum (\(4P\)), variance (\(4P\)) = \(12P\) bytes. The fp16 weights (\(2P\)) and gradients (\(2P\)) coexist separately. Total: \(16P\) bytes.
PyTorch AMP:
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()  # Loss scaling for fp16
for batch in dataloader:
    optimizer.zero_grad()
    with autocast():  # fp16 forward/backward
        loss = model(batch)
    scaler.scale(loss).backward()  # Scale loss, compute gradients
    scaler.step(optimizer)         # Unscale, update fp32 weights
    scaler.update()                # Adjust scaling factor
Speedup: 2-3\(\times\) on A100/H100; 4-5\(\times\) with Tensor Cores.
24 Memory Optimization Techniques
24.1 Activation Checkpointing (Gradient Checkpointing)
Problem: Forward pass stores activations for backward pass → memory scales with sequence length and depth.
Example: GPT-3 (96 layers, batch=1, seq_len=2048):
Activations: \(\sim\)10GB per sample
Batch size 32 → 320GB (exceeds A100 80GB)
Activation Checkpointing (also called Gradient Checkpointing):
Forward pass: Store activations only at checkpoints (e.g., every \(k\) layers)
Backward pass: Recompute intermediate activations on-the-fly from checkpoints
Trade-off: \(\sqrt{N}\) checkpoints → \(\sqrt{N}\) memory, \(2\times\) compute (1 forward + 1 recompute)
Memory Reduction: \(10\times\) typical (e.g., 320GB → 32GB).
PyTorch Gradient Checkpointing:
from torch.utils.checkpoint import checkpoint
def forward(x):
    x = checkpoint(layer1, x)  # Don't store layer1 activations
    x = checkpoint(layer2, x)
    return x
Hugging Face Transformers:
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
    "gpt2",
    use_cache=False,  # Required for gradient checkpointing
)
model.gradient_checkpointing_enable()
24.2 Gradient Accumulation
Gradient Accumulation simulates large batch sizes on limited memory:
Run \(K\) micro-batches (batch size \(B / K\))
Accumulate gradients: \(g = \sum_{i=1}^K g_i\)
Update weights once: \(\theta_{t+1} = \theta_t - \eta g\)
Effect: Effective batch size \(B\) with memory for \(B / K\).
Trade-off: Training time increases by \(K\times\) (no parallelism across micro-batches).
PyTorch Gradient Accumulation:
accumulation_steps = 4
for i, batch in enumerate(dataloader):
    loss = model(batch) / accumulation_steps  # Scale loss
    loss.backward()  # Accumulate gradients
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()       # Update weights
        optimizer.zero_grad()  # Clear gradients
24.3 Flash Attention
Flash Attention (Dao et al., 2022) optimizes self-attention memory via kernel fusion and tiling:
Standard attention: Materializes \(QK^T\) matrix (\(O(N^2)\) memory for seq len \(N\))
Flash Attention: Computes attention in blocks, never stores full \(QK^T\)
Memory: \(O(N)\) instead of \(O(N^2)\)
Speed: 2-4\(\times\) faster (reduces HBM read/writes)
Adoption: Default in Hugging Face Transformers 4.36+, PyTorch 2.0+.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="flash_attention_2",
)
25 Communication Optimization
25.1 Gradient Compression
Motivation: All-Reduce bandwidth bottleneck–especially for multi-node training.
Gradient Compression reduces communication volume:
Top-k Sparsification: Send only top \(k\%\) largest gradients (1-10%)
Quantization: Reduce precision (e.g., fp16 → int8)
Error Feedback: Accumulate residual errors to avoid bias
Trade-off: 10-100\(\times\) compression with minimal accuracy loss (<0.5% typically).
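A minimal sketch of top-k sparsification with error feedback (illustrative; production versions are usually implemented inside communication hooks, e.g. DDP comm hooks, or in dedicated libraries):
import torch

def topk_compress(grad, residual, k_frac=0.01):
    """Keep the k largest-magnitude entries; fold the rest into `residual`."""
    flat = grad.flatten() + residual       # error feedback: re-add what was dropped before
    k = max(1, int(k_frac * flat.numel()))
    idx = flat.abs().topk(k).indices
    sparse = torch.zeros_like(flat)
    sparse[idx] = flat[idx]                # values that would be communicated
    new_residual = flat - sparse           # dropped mass, re-injected next step
    return sparse.view_as(grad), new_residual

grad = torch.randn(1000)
residual = torch.zeros(1000)
sparse, residual = topk_compress(grad, residual, k_frac=0.01)
print((sparse != 0).sum().item())          # 10 nonzero entries sent instead of 1000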
25.2 Overlap Communication with Computation
DDP Optimization: Start All-Reduce for layer \(i\) gradients while computing layer \(i+1\).
ZeRO-3 Prefetching: Pre-fetch parameters for layer \(i+1\) while computing layer \(i\).
Effect: Hides communication latency–critical for multi-node training.
Modern Frameworks Automate This:
PyTorch DDP: Automatic bucketing + overlap
DeepSpeed: overlap_comm=true
FSDP: forward_prefetch=True
25.3 Network Topology Awareness
Intra-node vs Inter-node Communication:
NVLink (intra-node): 600 GB/s (H100), extremely low latency
InfiniBand/RoCE (inter-node): 200-400 Gb/s, higher latency
Optimization Strategies:
Place tensor parallelism within node (maximize NVLink usage)
Pipeline parallelism across nodes (minimize inter-node traffic)
Data parallelism for outer replication (amortize All-Reduce)
26 Framework Comparison
26.1 DeepSpeed
Strengths:
ZeRO-Offload: Train 10B+ models on single GPU
Pipeline Parallelism: Built-in support
Sparse Attention: Memory-efficient attention variants
Used by Microsoft, BigScience (BLOOM 176B), StabilityAI
Weaknesses:
External Dependency: Not native PyTorch
Debugging: Harder to debug than FSDP
26.2 PyTorch FSDP
Strengths:
Native PyTorch: Better integration with PyTorch 2.x (torch.compile)
Simpler API: Easier to adopt for existing codebases
Used by Meta (LLaMA-2, LLaMA-3)
Weaknesses:
Limited CPU Offload: CPUOffload exists for parameters, but it is less full-featured than DeepSpeed ZeRO-Offload
No Pipeline Parallelism: Only data + tensor parallelism
26.3 Megatron-LM
Strengths:
Tensor Parallelism: Best-in-class implementation
3D Parallelism: Combines data + tensor + pipeline
Used by NVIDIA (GPT-3, MT-NLG 530B, Megatron-Turing)
Weaknesses:
Complex Setup: Requires deep understanding of parallelism
Less Flexible: Tightly coupled to specific model architectures
26.4 When to Use What?
| Scenario | Recommendation |
|---|---|
| Single GPU, <10B model | Standard training (no parallelism) |
| Multi-GPU, single node | DDP or FSDP (SHARD_GRAD_OP) |
| Large model (10-70B), multi-GPU | FSDP (FULL_SHARD) or DeepSpeed ZeRO-2 |
| Very large model (70B+), multi-node | FSDP + tensor parallelism or DeepSpeed ZeRO-3 |
| Extreme scale (100B+) | Megatron-LM (3D parallelism) |
| Limited GPU memory | DeepSpeed ZeRO-Offload |
27 Practical Training Recipes
27.1 Recipe 1: Fine-Tuning 7B Model (LLaMA-2) on 4\(\times\)A100
Setup:
Model: LLaMA-2-7B (7B params, fp16 → 14GB)
Batch size: 32 (8 per GPU)
Seq length: 2048
Strategy:
FSDP: SHARD_GRAD_OP (ZeRO-2)
Mixed precision: bf16 (no loss scaling)
Gradient checkpointing: Enabled (saves \(\sim\)40GB activations)
Flash Attention 2: Reduces memory by 2\(\times\)
# Launch with torchrun
torchrun --nproc_per_node=4 train.py \
--model_name meta-llama/Llama-2-7b-hf \
--fsdp "shard_grad_op" \
--bf16 true \
--gradient_checkpointing true \
--attn_implementation flash_attention_2
27.2 Recipe 2: Pre-Training 70B Model on 64\(\times\)H100
Setup:
Model: 70B params (GPT-3 scale)
Batch size: 2048 (32 per GPU)
Seq length: 4096
Strategy:
Data Parallelism: 8-way (8 replicas)
FSDP: FULL_SHARD (ZeRO-3) within each replica
Tensor Parallelism: 8-way (split layers across 8 GPUs)
Total: \(8 \times 8 = 64\) GPUs
Mixed precision: bf16
Gradient accumulation: 4 steps
Activation checkpointing: Every 4 layers
Expected Throughput: depends on achieved MFU; at \(\sim\)40-50% MFU on H100 (\(\sim\)1000 TFLOPS bf16, \(\approx 6 \times 70\text{B}\) FLOPs per token), roughly 500-1000 tokens/sec/GPU → \(\sim\)30K-60K tokens/sec total.
27.3 Recipe 3: Budget Training (10B Model on 2\(\times\)RTX 3090)
Setup:
Model: 10B params
GPU: RTX 3090 (24GB each)
Batch size: 4 (2 per GPU)
Strategy:
DeepSpeed ZeRO-3 + Offload: Offload optimizer states to CPU
Mixed precision: fp16 (bf16 not supported on RTX 3090)
Gradient checkpointing: Aggressive (every 2 layers)
Gradient accumulation: 8 steps (effective batch 32)
Expected Speed: Very slow (\(\sim\)0.1-0.2 tokens/sec/GPU)–CPU bottleneck.
28 Interview Questions & Key Concepts
28.1 Common Interview Questions
Q: Explain the difference between DataParallel and DistributedDataParallel.
A: DP is single-process multi-threaded (Python GIL bottleneck), uses master GPU to aggregate gradients. DDP is multi-process (one per GPU), uses All-Reduce for gradient sync–no master GPU, much faster (2-3\(\times\)).
Q: What is ZeRO and how does it differ from FSDP?
A: ZeRO (DeepSpeed) shards optimizer states (Stage 1), gradients (Stage 2), and parameters (Stage 3) to reduce memory redundancy. FSDP is PyTorch's native implementation of ZeRO-3 principles, with tighter integration but fewer features (e.g., its CPU offload is more limited than DeepSpeed's ZeRO-Offload).
Q: When would you use tensor parallelism vs pipeline parallelism?
A: Tensor parallelism splits layers across GPUs (intra-layer)–requires high bandwidth (use within node via NVLink). Pipeline parallelism splits model into stages (inter-layer)–uses point-to-point communication (suitable across nodes). For very large models, combine both (3D parallelism).
Q: Explain gradient checkpointing. What’s the trade-off?
A: Stores activations only at checkpoints (e.g., every \(k\) layers), recomputes intermediate activations during backward. Trade-off: Reduces memory by \(\sim\)10\(\times\) but increases compute by \(\sim\)2\(\times\) (one forward pass + one recompute).
Q: Why use bf16 instead of fp16 for mixed precision?
A: BF16 has same dynamic range as fp32 (8 exponent bits) → no gradient underflow, no loss scaling needed. FP16 has limited range (5 exponent bits) → requires loss scaling to avoid underflow. Modern GPUs (A100, H100) support bf16 natively.
Q: How does Flash Attention reduce memory?
A: Standard attention materializes \(QK^T\) matrix (\(O(N^2)\) memory). Flash Attention uses kernel fusion and tiling to compute attention in blocks, never storing full \(QK^T\)–reduces memory to \(O(N)\) and speeds up by 2-4\(\times\) via fewer HBM reads/writes.
Q: Explain the All-Reduce algorithm used in DDP.
A: Ring All-Reduce: GPUs arranged in logical ring, two phases: (1) Reduce-Scatter–each GPU sends chunks, accumulates sums; (2) All-Gather–broadcast accumulated chunks. Bandwidth-optimal: \(O(2M / \text{bandwidth})\) independent of \(N\).
Q: What’s the memory breakdown for training a Transformer model?
A: For \(P\) parameters:
Weights: \(2P\) (fp16)
Gradients: \(2P\) (fp16)
Optimizer (Adam): \(12P\) (fp32 momentum + variance)
Activations: Depends on batch size and sequence length–often \(>10P\)
Q: How would you train a 175B model?
A: Use 3D parallelism: (1) Tensor parallelism (8-way, within node), (2) Pipeline parallelism (8 stages, across nodes), (3) Data parallelism (16 replicas). Total: \(8 \times 8 \times 16 = 1024\) GPUs. Use ZeRO-1 or FSDP (SHARD_GRAD_OP) for data parallelism. Enable gradient checkpointing, bf16, Flash Attention.
Q: What causes pipeline bubbles and how do you minimize them?
A: Bubbles occur when pipeline stages wait for previous stages. Minimize by: (1) Increase micro-batch count (bubble fraction \(\sim (N-1)/M\)), (2) Use 1F1B schedule (interleave forward/backward), (3) Balance stage compute times (even layer distribution).
Q: Explain tensor contiguity in PyTorch. Why does view sometimes fail?
A: A tensor is contiguous when its data is laid out in one uninterrupted block of memory with the expected stride order (row-major by default). Operations like transpose(), strided slicing ([::2]), or certain views can produce non-contiguous tensors: they reference the same underlying data but with different strides. view() only works on contiguous tensors because it just changes the shape metadata without moving data. If a tensor isn't contiguous, you must:
Call .contiguous().view(...) to create a contiguous copy first
Use .reshape(...), which automatically copies if needed
Example:
x = torch.randn(3, 4)
y = x.transpose(0, 1)        # Non-contiguous (strides changed)
# y.view(-1) would fail!
z = y.contiguous().view(-1)  # Works
# Or:
z = y.reshape(-1)            # Also works (copies if needed)
Q: Does view work differently in distributed training (DDP)?
A: No, view works the same way in DDP. It's a local tensor operation that just requires contiguity. In DDP:
Each process has its own replica of the model
view operates on the local tensor shard
If a tensor becomes non-contiguous (e.g., after transpose, gather, or custom ops), use .contiguous().view(...)
Caveat with DTensor: Newer PyTorch distributed APIs like DTensor may have restrictions on view-like operations depending on sharding strategy, but standard DDP tensors follow normal contiguity rules.
Q: When should you worry about contiguity in practice?
A: Be conscious of contiguity when:
After transpose(), permute(), or non-trivial slicing
Before view(), which will throw a RuntimeError if the tensor is non-contiguous
When passing tensors to C++/CUDA kernels that assume contiguous layout
In tight training loops, where unnecessary .contiguous() copies waste memory/time
Best practice:
Use .reshape(...) instead of .view(...) if you're unsure (it handles contiguity automatically)
Check contiguity with tensor.is_contiguous() when debugging
Avoid calling .contiguous() unnecessarily; it creates a copy
29 Distributed Training Diagnostics & Optimization
29.1 Introduction: The MLOps Optimization Mindset
When handed a distributed training workload that’s underperforming, your job is to identify the limiting resource and apply the smallest safe change that moves it. A useful mental model is:
Executive Mental Model (Vendor-Agnostic): GPU performance is a scheduling problem under resource constraints. You maximize throughput by keeping execution pipelines busy while respecting:
Reuse (Arithmetic Intensity): Reuse data so you do more FLOPs per byte from HBM.
Latency hiding: Hide long-latency ops via occupancy (many warps/waves) and ILP (independent instructions within a warp/wave).
Resource pressure: Registers and shared memory/LDS limit concurrency; too much pressure collapses occupancy.
Communication placement: In distributed training, topology and overlap determine whether collectives dominate iteration time.
Everything else (Triton/CUDA/HIP/TVM, cuBLAS/rocBLAS, Megatron/DeepSpeed/FSDP) is a way to move these levers.
29.1.1 GPU Execution Model: NVIDIA vs AMD (Practical Differences)
Common hierarchy: GPU \(\rightarrow\) SM/CU \(\rightarrow\) blocks/work-groups \(\rightarrow\) warps/wavefronts \(\rightarrow\) threads.
| Concept | NVIDIA | AMD |
|---|---|---|
| Warp / wave size | 32 threads (warp) | 64 threads (wavefront) |
| Register model | Unified registers | VGPR (vector) + SGPR (scalar) |
| Shared memory | SMEM (banked) | LDS (banked) |
| Matrix ISA | Tensor Cores (MMA/WMMA) | Matrix Cores (MFMA) |
| Occupancy sensitivity | Moderate | Often higher (VGPR pressure) |
29.1.2 Latency Hiding: Occupancy vs ILP (Critical Distinction)
Occupancy hides latency by switching to another warp/wavefront. ILP (Instruction-Level Parallelism) hides latency by issuing independent instructions within the same warp/wavefront while earlier instructions are still in flight.
When occupancy is capped (register/LDS pressure, large tiles, heavy fusion), ILP becomes the dominant lever.
Why this matters: Many GPU ops have long latency (HBM loads, matrix instructions, transcendental ops). You typically cannot reduce latency; you must hide it.
Unrolling (what it really does):
Increases ILP by breaking dependency chains and creating multiple independent accumulators.
Tradeoff: more unrolling \(\rightarrow\) more live values \(\rightarrow\) more registers \(\rightarrow\) lower occupancy (too much unrolling can slow the kernel).
Unrolling improves achieved throughput, not FLOPs-per-byte (it does not change arithmetic intensity).
29.1.3 Matrix Instructions and Shape Effects (Tensor Cores / MFMA)
Matrix instructions are wide (many FLOPs/instruction) and often long latency; reaching peak TFLOPs usually requires multiple independent matrix instructions in flight (ILP). Skinny GEMMs often underperform on both vendors due to insufficient ILP and poor utilization.
29.1.4 Roofline + Arithmetic Intensity (AI): Decide Memory-Bound vs Compute-Bound
The roofline model gives two ceilings:
Bandwidth roof: \(\text{Perf} \le \text{BW} \cdot \text{AI}\) (memory-bound region)
Compute roof: \(\text{Perf} \le \text{Peak FLOPs}\) (compute-bound region)
The knee occurs at: \[\text{AI}_{\text{knee}} = \frac{\text{Peak Compute}}{\text{HBM Bandwidth}}\]
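A small helper makes the classification mechanical; the peak numbers below are illustrative H100-class specs (dense bf16), and achieved values would come from a profiler such as ncu or rocprof:
def roofline(ai_flops_per_byte, peak_tflops, hbm_tbps):
    knee = peak_tflops / hbm_tbps  # AI_knee = peak compute / HBM bandwidth
    bound = "memory-bound" if ai_flops_per_byte < knee else "compute-bound"
    attainable_tflops = min(peak_tflops, hbm_tbps * ai_flops_per_byte)
    return round(knee, 1), bound, round(attainable_tflops, 1)

# Illustrative H100-class numbers: ~989 TFLOPS dense bf16, ~3.35 TB/s HBM
print(roofline(2, 989, 3.35))    # (295.2, 'memory-bound', 6.7)    -- elementwise op
print(roofline(600, 989, 3.35))  # (295.2, 'compute-bound', 989.0) -- large GEMM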
Practical tiling rules:
AI increases by reusing data from registers/SMEM/LDS instead of reloading from HBM.
Increasing the reduction tile (often BLOCK_K) is frequently the most efficient way to raise AI.
Increasing BLOCK_M/N can raise register pressure quickly; performance improves until occupancy collapses.
Shared memory / LDS bank conflicts: Shared memory is banked; repeated access to the same bank serializes. Fixes include padding, layout changes, swizzling, and reducing shared footprint.
29.1.5 Kernel Tuning Knobs (Portable Concepts)
Primary knobs: Tile sizes (M/N/K), warps/waves per block, pipeline stages, unrolling.
Secondary knobs: Layout/alignment, vectorization width, epilogue fusion.
Production must-haves: Shape bucketing, config caching, correctness/stability checks.
29.1.6 From Kernel to Cluster: Why Distributed Training Changes the Game
Parallelism choices change the kernel shapes and communication pattern:
Tensor parallelism splits weight matrices across GPUs, shrinking per-GPU GEMMs and increasing collective frequency.
Consequence: kernel efficiency matters more, and communication can dominate iteration time without careful overlap.
Portable priority order: topology (keep TP within the fastest fabric) \(\rightarrow\) overlap (async collectives, fewer barriers) \(\rightarrow\) reduce frequency/volume (fusion, bucketization).
29.1.7 A Senior Diagnostic Loop (Use This in Interviews)
Locate time: compute vs communication vs input pipeline vs idle.
Classify the bound: memory-bound vs compute-bound (roofline + counters) vs comm-bound (profiler trace).
Identify the limiter: occupancy vs ILP vs bandwidth vs synchronization/barriers.
Apply the smallest safe change: e.g., overlap AllReduce, adjust bucket size, choose better GEMM backend, reduce register/shared pressure, fuse epilogues.
Re-measure: keep changes that produce stable gains; roll back fragile wins.
Gold interview sentences:
“Performance is fundamentally a scheduling problem under resource constraints.”
“ILP hides latency within a warp; occupancy hides latency across warps.”
“Unrolling increases ILP, not arithmetic intensity.”
“Tile sizes increase AI until register/LDS pressure collapses occupancy.”
“With cross-node tensor parallelism, overlap is survival.”
29.2 Memory Bandwidth: Types and Bottlenecks
29.2.1 Memory Hierarchy in GPU Clusters
Modern distributed training involves multiple memory types with vastly different bandwidths:
| Memory Type | Bandwidth | Latency | Typical Use |
|---|---|---|---|
| GPU Registers | \(\sim\)20 TB/s | \(<\)1ns | Kernel-local variables |
| L1/L2 Cache | 10–20 TB/s | 1–10ns | Frequently accessed data |
| GPU HBM | 1.5–3 TB/s | \(\sim\)100ns | Model weights, activations |
| NVLink (GPU-GPU) | 600–900 GB/s | \(\sim\)1\(\mu\)s | Tensor/pipeline parallelism |
| PCIe 4.0 | 32 GB/s | \(\sim\)5\(\mu\)s | CPU-GPU transfers |
| InfiniBand (inter-node) | 200–400 GB/s | \(\sim\)5\(\mu\)s | All-Reduce across nodes |
| Ethernet (inter-node) | 10–100 GB/s | \(\sim\)10\(\mu\)s | Data loading, checkpoints |
| CPU RAM | 100–200 GB/s | \(\sim\)100ns | Offloaded optimizer states |
| NVMe SSD | 5–15 GB/s | \(\sim\)100\(\mu\)s | Dataset streaming, checkpointing |
Key Observations:
GPU HBM bandwidth (1.5–3 TB/s) is the primary bottleneck for memory-bound kernels
Network bandwidth (NVLink \(>\) InfiniBand \(>\) Ethernet) determines communication efficiency
CPU-GPU PCIe is 10–100\(\times\) slower than GPU HBM–minimize transfers
29.2.2 Signs You’re Memory Bandwidth Bound
Symptom 1: Low GPU Compute Utilization Despite Full Memory
nvidia-smi
+-----------------------------------------------------------------------------+
| GPU Name Utilization-Gpu Memory-Usage Temperature Power |
|=============================================================================|
| 0 H100 35% 78000MiB/80000 65C 320W |
+-----------------------------------------------------------------------------+
Interpretation:
GPU utilization \(<\)50% but memory nearly full → likely memory-bound
GPU is idle waiting for data to move between HBM and compute cores
Common in attention layers, large embedding lookups, layernorm
Symptom 2: Profiler Shows High Memory-Bound Kernel Time
Use nsys (NVIDIA Nsight Systems) to profile:
nsys profile -o profile.qdrep python train.py
nsys stats profile.qdrep --report cuda_kern_sum
Look for:
Achieved memory bandwidth vs theoretical peak
If achieved \(>\)80% of peak → memory-bound
Common culprits: elementwise ops (ReLU, dropout), reductions (sum, mean), transposes
Symptom 3: Kernel Execution Time Scales with Tensor Size, Not FLOPS
Matrix multiply (GEMM): time \(\propto\) FLOPs (compute-bound)
Elementwise ops: time \(\propto\) tensor size in bytes (memory-bound)
Test: Double batch size. If training time increases by 2\(\times\) → memory-bound. If \(<\)2\(\times\) → compute-bound.
29.2.3 How to Fix Memory Bandwidth Bottlenecks
1. Kernel Fusion
Combine multiple elementwise ops into a single kernel to reduce HBM roundtrips.
Before (Unfused):
import torch.nn.functional as F  # used by both snippets below
# Three separate kernels, three HBM reads/writes
x = input + bias            # Kernel 1: read input, write x
y = torch.relu(x)           # Kernel 2: read x, write y
z = F.dropout(y, p=0.1)     # Kernel 3: read y, write z
After (Fused):
# torch.compile with inductor fuses into one kernel
@torch.compile
def fused_op(input, bias, p=0.1):
    return F.dropout(F.relu(input + bias), p)
z = fused_op(input, bias)   # Single kernel, one HBM roundtrip
Speedup: 2–3\(\times\) for chains of elementwise ops.
2. Flash Attention (IO-Aware Attention)
Standard attention materializes full \(N \times N\) attention matrix in HBM:
# Standard attention (memory-bound)
scores = Q @ K.T # (N, N) matrix in HBM
attn = softmax(scores)
out = attn @ V
Flash Attention fuses operations and uses tiling to keep intermediates in SRAM:
from flash_attn import flash_attn_func
out = flash_attn_func(Q, K, V) # No (N, N) materialization
Result: 3–5\(\times\) speedup, 10–20\(\times\) memory reduction for long sequences.
3. Avoid Unnecessary Transposes and Copies
Transpose creates non-contiguous tensor → next op may trigger copy
Use .reshape() instead of .view() + .contiguous()
4. Use Lower Precision (BF16/FP16)
FP16/BF16 weights/activations → 2\(\times\) less memory traffic
Modern GPUs (A100/H100) have specialized BF16 hardware
Enable via torch.autocast(device_type='cuda', dtype=torch.bfloat16)
29.3 Network Bandwidth Monitoring & Optimization
29.3.1 Measuring Network Utilization
1. NCCL Test (Bandwidth Benchmark)
# Clone NCCL tests
git clone https://github.com/NVIDIA/nccl-tests.git
cd nccl-tests && make
# Run All-Reduce test across all GPUs
mpirun -np 8 ./build/all_reduce_perf -b 1G -e 8G -f 2 -g 1
# Output:
# size time algbw busbw
# 1073741824 15.2ms 70.5GB/s 141GB/s # Bus bandwidth = effective BW
Interpretation:
Bus bandwidth = effective aggregate bandwidth considering all links
For NVLink: expect \(\sim\)600–900 GB/s (H100/A100)
For InfiniBand: expect \(\sim\)200–400 GB/s
If \(<\)50% of theoretical → network misconfiguration or congestion
2. Monitor Network Traffic During Training
# Monitor network interfaces
watch -n 1 'ifconfig | grep -A 5 ib0' # InfiniBand
iftop -i ib0 # Real-time traffic
# Or use NCCL environment variable
export NCCL_DEBUG=INFO
python train.py # Logs NCCL operations and timings
3. Identify Communication vs Compute Time
Use PyTorch profiler with communication tracing:
from torch.profiler import profile, ProfilerActivity
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
             with_stack=True) as prof:
    for batch in dataloader:
        loss = model(batch)
        loss.backward()
        optimizer.step()
prof.export_chrome_trace("trace.json")
# View in chrome://tracing
Look for:
ncclAllReduce time in trace → communication overhead
If AllReduce \(>\)20% of iteration time → communication-bound
29.3.2 Signs of Communication Bottlenecks
1. Poor Scaling Efficiency
Scaling efficiency = \(\frac{\text{Speedup}}{\text{\# GPUs}}\)
Example:
1 GPU: 100 samples/sec
8 GPUs: 600 samples/sec → speedup = 6\(\times\) → efficiency = 75%
64 GPUs: 3200 samples/sec → speedup = 32\(\times\) → efficiency = 50%
Efficiency drops as # GPUs grows → communication overhead dominates.
2. High Gradient All-Reduce Time
In DDP, gradients are synchronized via AllReduce after backward pass:
# Check AllReduce time
NCCL_DEBUG=INFO python train.py 2>&1 | grep "AllReduce time"
Rule of thumb:
AllReduce time should be \(<\)10–15% of backward pass time
If \(>\)20% → communication-bound
3. Network Congestion (Multi-Job Clusters)
Multiple training jobs share network → contention
Use network QoS or dedicated VLANs for training traffic
Monitor with ibstat (InfiniBand) or ethtool (Ethernet)
29.3.3 Optimizing Communication Overhead
1. Gradient Accumulation (Reduce AllReduce Frequency)
Instead of AllReduce every step, accumulate gradients over \(k\) micro-batches:
from contextlib import nullcontext

for i, batch in enumerate(dataloader):
    is_sync_step = (i + 1) % accumulation_steps == 0
    # Wrap non-final micro-batches in no_sync(); otherwise DDP still
    # AllReduces on every backward call and nothing is saved
    with (nullcontext() if is_sync_step else model.no_sync()):
        loss = model(batch) / accumulation_steps
        loss.backward()  # Gradients accumulate locally
    if is_sync_step:
        optimizer.step()       # Synced gradients applied here
        optimizer.zero_grad()
Effect: Reduces AllReduce calls by \(k\times\), at the cost of a larger effective batch and fewer optimizer updates per epoch.
2. Overlap Computation and Communication
DDP uses gradient bucketing to overlap AllReduce with backward pass:
Gradients grouped into buckets (\(\sim\)25MB each)
As soon as a bucket is ready, AllReduce starts (asynchronously)
Later layers continue backward pass while early gradients communicate
Tune bucket size:
model = DDP(model, bucket_cap_mb=25) # Default 25MB
# Larger buckets → fewer AllReduces, more latency
# Smaller buckets → more overlap, higher overhead
3. Use ZeRO/FSDP to Reduce Communication Volume
ZeRO-2/FSDP shard gradients → use reduce-scatter instead of AllReduce
Communication volume same, but each GPU only stores \(1/N\) of gradients
Frees memory for larger batches → amortizes communication
4. Enable NCCL Optimizations
export NCCL_IB_DISABLE=0 # Enable InfiniBand (if available)
export NCCL_P2P_DISABLE=0 # Enable peer-to-peer (NVLink)
export NCCL_ALGO=Ring # Ring algorithm (default, bandwidth-optimal)
export NCCL_SOCKET_IFNAME=ib0 # Specify InfiniBand interface
5. Topology-Aware Placement
Place processes to maximize NVLink usage (intra-node)
Minimize cross-node communication (use pipeline parallelism for stages)
Use nvidia-smi topo -m to view GPU interconnect topology
29.4 GPU Utilization & Compute Bottlenecks
29.4.1 Measuring GPU Utilization
1. nvidia-smi (Coarse-Grained)
nvidia-smi dmon -s u -c 100 # Monitor utilization every second
# gpu sm mem enc dec
# 0 95 78 0 0 ← SM utilization 95%, good!
# 1 42 80 0 0 ← SM utilization 42%, investigate
SM Utilization (Streaming Multiprocessor):
\(>\)80%: GPU well-utilized (compute-bound)
50–80%: Mixed (some idle time, check memory/IO)
\(<\)50%: Underutilized (likely data loading, communication, or memory-bound)
2. PyTorch Profiler (Fine-Grained)
python -m torch.utils.bottleneck train.py
# Or use torch.profiler with TensorBoard
Identify:
Kernel time: Time spent in CUDA kernels (GEMM, softmax, etc.)
CPU time: Data preprocessing, Python overhead
CUDA memcpy: Host-device transfers
3. NVIDIA Nsight Compute (Kernel-Level)
For deep kernel analysis:
ncu --set full -o profile python train.py
ncu-ui profile.ncu-rep # View in GUI
Shows:
Occupancy (active warps / max warps)
Memory throughput (achieved vs peak)
Instruction mix (compute vs memory ops)
Roofline model (compute-bound vs memory-bound)
29.4.2 Common GPU Underutilization Causes
1. Data Loading Bottleneck
Symptom: GPU idles waiting for next batch.
Diagnosis:
# Profile with PyTorch DataLoader profiling
with profile(record_shapes=True) as prof:
    for batch in dataloader:
        model(batch)
print(prof.key_averages().table(sort_by="cpu_time_total"))
# Look for high "DataLoader" time
Fixes:
Increase num_workers in the DataLoader (typically 2–4 per GPU)
Use pin_memory=True for faster CPU→GPU transfers
Prefetch batches: prefetch_factor=2 (PyTorch 1.8+)
Use faster storage (NVMe SSD instead of network filesystem)
Pre-process data offline (e.g., tokenize once, save to disk)
2. Small Batch Size (Insufficient Parallelism)
Symptom: Low GPU utilization, kernels finish quickly.
Diagnosis:
# Check batch size utilization
print(f"Batch size: {batch.size(0)}")
print(f"GPU memory used: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
Fixes:
Increase batch size until GPU memory \(\sim\)80–90% full
Use gradient accumulation if single-batch OOM
Enable mixed precision (BF16) to fit larger batches
Use gradient checkpointing to trade compute for memory
3. Synchronization Points (Blocking Operations)
Operations that force GPU synchronization stall the pipeline:
.item() (copies scalar from GPU to CPU)
.cpu() (blocks until kernel completes)
print(tensor) (implicit .item())
Conditional logic on GPU tensors
Fix: Defer synchronization to end of epoch or use async copies.
4. Inefficient Kernels (Custom Ops)
Symptom: Specific layer has low utilization despite large input.
Diagnosis:
ncu --set full -k regex:<kernel_name> python train.py
# Check occupancy and memory throughput
Fixes:
Replace custom ops with optimized libraries (cuDNN, cuBLAS, Flash Attention)
Use torch.compile for kernel fusion
Profile and optimize CUDA code (increase occupancy, reduce register usage)
29.5 Diagnostic Experiments: Systematic Troubleshooting
29.5.1 Binary Search for Bottlenecks
Step 1: Isolate Components
Run minimal versions to identify culprit:
# 1. Forward-only (no backward, no optimizer)
with torch.no_grad():
    loss = model(batch)
# 2. Forward + backward (no optimizer)
loss = model(batch)
loss.backward()
# 3. Full training step
loss = model(batch)
loss.backward()
optimizer.step()
Compare iteration times to localize issue.
Step 2: Vary Batch Size
for batch_size in [8, 16, 32, 64, 128]:
    # Measure throughput (samples/sec) and GPU utilization at each size (see the sketch below)
Expected:
Throughput increases with batch size (up to memory limit)
GPU utilization increases with batch size
If plateaus early → memory-bound or data-loading bottleneck
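A minimal sketch of such a sweep (assumes model and optimizer already exist, plus a hypothetical make_batch(batch_size) helper that returns a synthetic batch on the GPU; model(batch) returning the loss follows the convention of the snippets above):
import time
import torch

def measure_throughput(model, optimizer, make_batch, batch_size, steps=20):
    batch = make_batch(batch_size)  # hypothetical helper: synthetic GPU batch
    for _ in range(3):              # warm-up iterations
        optimizer.zero_grad()
        model(batch).backward()
        optimizer.step()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(steps):
        optimizer.zero_grad()
        loss = model(batch)
        loss.backward()
        optimizer.step()
    torch.cuda.synchronize()
    samples_per_sec = batch_size * steps / (time.perf_counter() - start)
    peak_mem_gb = torch.cuda.max_memory_allocated() / 1e9
    return round(samples_per_sec, 1), round(peak_mem_gb, 2)

for bs in (8, 16, 32, 64, 128):
    torch.cuda.reset_peak_memory_stats()
    print(bs, measure_throughput(model, optimizer, make_batch, bs))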
Step 3: Profile with and without Communication
# Single GPU (no communication)
python train.py --gpus 1
# Multi-GPU with DDP
torchrun --nproc_per_node=8 train.py
If multi-GPU is \(<\)8\(\times\) faster → communication overhead.
Step 4: Microbenchmark Individual Layers
import time
# Benchmark attention layer in isolation
layer = MultiHeadAttention(...)  # layer under test
input = torch.randn(batch, seq, dim, device='cuda')
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
    output = layer(input)
torch.cuda.synchronize()
print(f"Avg time: {(time.time() - start) / 100 * 1000:.2f} ms")
Compare against Flash Attention or other implementations.
29.5.2 Key Metrics to Track
| Metric | Tool | Target / Threshold |
|---|---|---|
| Samples/sec | Training loop timer | Maximize |
| GPU utilization | nvidia-smi | \(>\)80% |
| GPU memory usage | nvidia-smi | 80–95% (full but not OOM) |
| AllReduce time | NCCL_DEBUG | \(<\)15% of step time |
| Data loading time | PyTorch profiler | \(<\)5% of step time |
| Model FLOPs utilization | Nsight Compute | \(>\)50% of peak |
| Network bandwidth | nccl-tests | \(>\)80% of theoretical |
| Scaling efficiency | Multi-GPU benchmarks | \(>\)80% (up to 64 GPUs) |
29.5.3 Optimization Decision Tree
Is GPU utilization \(<\)50%?
YES → Check data loading (increase workers, pin memory)
NO → Continue
Is GPU memory usage \(<\)80%?
YES → Increase batch size or use gradient accumulation
NO → Continue
Is AllReduce time \(>\)20% of step time?
YES → Enable gradient accumulation, check NCCL config, use FSDP
NO → Continue
Is memory bandwidth \(<\)80% of peak?
YES → Use Flash Attention, kernel fusion, mixed precision
NO → Continue
Is scaling efficiency \(<\)70%?
YES → Optimize communication (larger batches, FSDP, topology)
NO → System is well-optimized!
29.6 ROCm & AMD GPU Optimization
29.6.1 Introduction: AMD Instinct GPU Stack
AMD Instinct accelerators (MI200, MI300 series) use the ROCm (Radeon Open Compute) platform:
ROCm: Open-source GPU compute platform (analogous to CUDA)
HIP: Heterogeneous Interface for Portability (CUDA-compatible API)
hipify: Tool to auto-convert CUDA code to HIP
ROCm libraries: rocBLAS, MIOpen (like cuBLAS, cuDNN), rocFFT, rocSOLVER
Compiler: LLVM-based, uses AMDGPU backend
29.6.2 Key Hardware Differences: NVIDIA vs AMD
| Feature | NVIDIA (H100) | AMD (MI300X) |
|---|---|---|
| Thread grouping | 32-thread warp | 64-thread wavefront |
| Matrix acceleration | Tensor Cores (WMMA) | Matrix Cores (MFMA) |
| Memory | 80GB HBM3 | 192GB HBM3 |
| Memory BW | 3.35 TB/s | 5.3 TB/s |
| Compute (FP16) | 1979 TFLOPS | 1307 TFLOPS |
| Interconnect | NVLink 900 GB/s | Infinity Fabric 900 GB/s |
| Software stack | CUDA, cuDNN | ROCm, MIOpen |
Key Implications:
Wavefront size: Kernels tuned for 32 threads (NVIDIA) need re-tuning for 64 threads (AMD)
Memory advantage: MI300X has 2.4× more HBM, better for large models
Bandwidth advantage: MI300X has 1.6× higher memory bandwidth, benefits memory-bound workloads
29.6.3 ROCm Profiling Tools
1. rocprof (Primary Profiler)
rocprof is AMD’s equivalent to NVIDIA’s nvprof/nsys:
# Basic profiling
rocprof --stats python train.py
# Detailed kernel trace
rocprof --hip-trace --hsa-trace python train.py
# Specific metrics (memory bandwidth, occupancy)
rocprof --timestamp on -i metrics.txt python train.py
Common Metrics (metrics.txt):
pmc: SQ_WAVES, SQ_INSTS_VALU, SQ_INSTS_MFMA
pmc: TCC_HIT, TCC_MISS, TCC_EA_WRREQ, TCC_EA_RDREQ
pmc: GRBM_GUI_ACTIVE, GRBM_SPI_BUSY
2. rocTracer (API Tracing)
Traces HIP API calls and kernel launches:
roctracer -o trace.json python train.py
# View in chrome://tracing
3. AMD Profiler (rocProfiler GUI)
Visual profiling tool (like NVIDIA Nsight Compute):
rocprof --hsa-trace python train.py
# Open .csv output in rocProfiler GUI
4. Omniperf (Roofline Analysis)
AMD’s roofline model tool:
omniperf profile -n workload_name -- python train.py
omniperf analyze -p workload_name
Shows compute vs memory bottlenecks visually.
29.6.4 HIP Kernel Development
Converting CUDA to HIP
CUDA Kernel:
__global__ void saxpy(float a, float* x, float* y, int n) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) y[i] = a * x[i] + y[i];
}
// Launch
saxpy<<<blocks, threads>>>(a, x, y, n);
HIP Kernel (Auto-Converted with hipify-perl):
#include <hip/hip_runtime.h>
__global__ void saxpy(float a, float* x, float* y, int n) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) y[i] = a * x[i] + y[i];
}
// Launch (HIP uses hipLaunchKernelGGL)
hipLaunchKernelGGL(saxpy, blocks, threads, 0, 0, a, x, y, n);
Compilation:
# CUDA
nvcc -o program program.cu
# HIP
hipcc -o program program.cpp # .cpp extension for HIP
29.6.5 Platform-Specific Optimizations
Wavefront-Aware Tuning
AMD GPUs execute in 64-thread wavefronts (vs 32-thread NVIDIA warps):
// NVIDIA: 32 threads per warp
#define WARP_SIZE 32
int warp_id = threadIdx.x / 32;
// AMD: 64 threads per wavefront
#define WAVEFRONT_SIZE 64
int wave_id = threadIdx.x / 64;
// Portable
#ifdef __HIP_PLATFORM_AMD__
#define WAVE_SIZE 64
#else
#define WAVE_SIZE 32
#endif
Tuning implications:
Block sizes should be multiples of 64 (AMD) vs 32 (NVIDIA)
Shared memory bank conflicts differ (32 banks vs 64 banks)
Warp-level primitives (__shfl_down) require different sizes
Matrix Core Utilization (MFMA)
AMD MI200/MI300 have Matrix Fused Multiply-Add (MFMA) instructions:
Operations: 16×16×16, 32×32×8 matrix multiplies
rocBLAS: Automatically uses MFMA for GEMMs
MIOpen: Uses MFMA for convolutions
Manual use: Inline assembly or rocWMMA library
# Check if MFMA is being used
rocprof -i mfma_check.txt python train.py
# mfma_check.txt:
pmc: SQ_INSTS_MFMA
If SQ_INSTS_MFMA is low, GEMM not using matrix cores.
Memory Hierarchy Optimization
AMD GPUs have different cache hierarchy:
L1 cache: 16KB per CU (vs 128KB L1 on NVIDIA)
L2 cache: 8MB (MI300X) vs 50MB (H100)
HBM: Higher bandwidth (5.3 TB/s vs 3.35 TB/s) but smaller caches
Strategy: AMD GPUs favor streaming workloads over cache-heavy ones.
29.6.6 PyTorch + ROCm Integration
Installation:
# Official ROCm PyTorch wheels
pip install torch torchvision torchaudio --index-url \
https://download.pytorch.org/whl/rocm5.7
# Verify
python -c "import torch; print(torch.cuda.is_available())" # True on ROCm
python -c "import torch; print(torch.version.hip)" # Shows ROCm version
Compatibility Notes:
PyTorch treats AMD GPUs as cuda devices (HIP compatibility layer)
Most PyTorch code runs unmodified on ROCm
Flash Attention 2 supports ROCm via Triton backend
DeepSpeed and FSDP work on ROCm (use RCCL instead of NCCL)
29.6.7 Communication Libraries: RCCL
RCCL (ROCm Communication Collectives Library) is AMD’s equivalent to NCCL:
Operations: AllReduce, ReduceScatter, AllGather, Broadcast
Topology-aware: Optimized for Infinity Fabric (AMD’s NVLink equivalent)
Multi-node: Supports InfiniBand, RoCE (RDMA over Converged Ethernet)
# Test RCCL bandwidth
git clone https://github.com/ROCmSoftwarePlatform/rccl-tests.git
cd rccl-tests && make
./build/all_reduce_perf -b 1G -e 8G -f 2
# Expected output (MI300X with Infinity Fabric):
# size time busbw
# 8589934592 45ms 381GB/s # Infinity Fabric bandwidth
29.6.8 Auto-Tuning for AMD GPUs
MIOpen Auto-Tuning
MIOpen (AMD’s cuDNN) includes find-db auto-tuning:
# Enable auto-tuning (stores results in ~/.config/miopen)
export MIOPEN_FIND_MODE=1 # Normal find (fast)
export MIOPEN_FIND_MODE=2 # Exhaustive search (slow, best perf)
# First run auto-tunes, subsequent runs use cached configs
python train.py
Find Database Location:
~/.config/miopen/
|-- 2.0.0/ # ROCm version
| `-- gfx90a/ # GPU architecture (MI200 series)
| `-- conv_find_db.db # SQLite database of tuned configs
Composable Kernel Library
AMD’s Composable Kernel (CK) library provides optimized kernels:
Operations: GEMM, GEMM-Softmax-GEMM (for attention), convolutions
Templates: Tile sizes, pipeline stages parameterized
Auto-selection: Library picks best kernel based on shape
# Build PyTorch with Composable Kernel backend
USE_ROCM=1 USE_COMPOSABLE_KERNEL=1 python setup.py install
Triton Auto-Tuning on ROCm
Triton supports AMD GPUs, auto-tuning works identically:
import triton
import triton.language as tl
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 128}, num_warps=2),  # 2 wavefronts
        triton.Config({'BLOCK_SIZE': 256}, num_warps=4),  # 4 wavefronts
        triton.Config({'BLOCK_SIZE': 512}, num_warps=8),  # 8 wavefronts
    ],
    key=['N'],
)
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, N, BLOCK_SIZE: tl.constexpr):
    # ... kernel code ...
    pass
# Auto-tunes on AMD GPU, caches results
Note: Triton uses "warps" terminology but maps to wavefronts on AMD.
29.6.9 Debugging Common ROCm Issues
Issue 1: Kernel Launch Failures
# Enable detailed error messages
export HIP_VISIBLE_DEVICES=0
export HSA_ENABLE_DEBUG=1
export AMD_LOG_LEVEL=3 # Verbose logging
python train.py 2>&1 | grep -i error
Issue 2: Performance Lower Than Expected
Check if MIOpen auto-tuning is enabled (MIOPEN_FIND_MODE)
Verify GPU is not thermally throttling (rocm-smi)
Profile to check MFMA utilization (SQ_INSTS_MFMA)
Compare memory bandwidth vs theoretical peak (rocprof)
Ensure Infinity Fabric links are active (topology)
Issue 3: Multi-GPU Hangs
# Enable RCCL debug logging
export RCCL_DEBUG=INFO
export RCCL_DEBUG_SUBSYS=INIT,COLL
# Check network topology
rocm-smi --showtopo
29.6.10 ROCm Ecosystem Tools
rocm-smi: GPU monitoring (like nvidia-smi)
rocminfo: Display GPU capabilities
rocm-bandwidth-test: Benchmark memory bandwidth
TransferBench: Test peer-to-peer transfers
rocprof: Performance profiling
rocgdb: GPU debugger (like cuda-gdb)
roctracer: API tracing
Omniperf: Roofline analysis
Quick ROCm Health Check:
# Check GPU status
rocm-smi
# Test peer-to-peer bandwidth between GPUs
rocm-bandwidth-test
# Verify HIP can see devices
hipconfig --check
# Run simple kernel test
cat > test.cpp << 'EOF'
#include <hip/hip_runtime.h>
#include <stdio.h>
__global__ void hello() { printf("Hello from GPU %d\n", blockIdx.x); }
int main() {
hello<<<4, 1>>>();
hipDeviceSynchronize();
return 0;
}
EOF
hipcc test.cpp -o test && ./test
29.6.11 Interview Questions: ROCm & AMD GPUs
Q: How would you port a CUDA kernel to HIP?
A: Use hipify-perl for automatic conversion. Review changes (warps→wavefronts, 32→64 threads). Recompile with hipcc. Tune block sizes for 64-thread wavefronts. Test correctness and benchmark performance.
A: Both accelerate matrix multiplication. NVIDIA Tensor Cores use WMMA (warp matrix multiply-accumulate). AMD Matrix Cores use MFMA (matrix fused multiply-add). Similar performance, different programming interfaces. Both accessed via libraries (cuBLAS/rocBLAS) or inline assembly.
Q: How do you profile a PyTorch model on AMD GPUs?
A: Use rocprof --hip-trace python train.py. Check memory bandwidth with TCC_EA_RDREQ/WRREQ metrics. Verify MFMA usage with SQ_INSTS_MFMA. Use the PyTorch profiler with torch.profiler (ROCm-aware). Visualize in chrome://tracing.
A: Possible causes: (1) Kernels not tuned for 64-thread wavefronts, (2) MIOpen auto-tuning not enabled, (3) Smaller L1/L2 caches hurt cache-heavy workloads, (4) MFMA not utilized (check with profiler), (5) Software maturity (CUDA ecosystem more optimized).
Q: How do you enable auto-tuning for AMD GPUs?
A: Set MIOPEN_FIND_MODE=1 (fast) or =2 (exhaustive). The first run auto-tunes and caches results in ~/.config/miopen/. Subsequent runs use the cached configs. For Triton kernels, use the @triton.autotune decorator (same as CUDA).
29.7 Production Best Practices
Operational Checklist for Large-Scale Training:
Baseline single-GPU performance first (eliminate data/model issues)
Profile before optimizing (use PyTorch profiler + Nsight)
Monitor key metrics continuously (log GPU util, throughput, loss)
Test scaling incrementally (1 → 8 → 64 → 512 GPUs)
Validate accuracy after each optimization (ensure no regression)
Checkpoint frequently (failures inevitable at scale)
Use NCCL/network diagnostics (validate topology, bandwidth)
Automate alerts (OOM, loss spikes, utilization drops)
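A minimal per-step metrics hook in this spirit (a sketch; the memory-headroom threshold and the logging sink are illustrative assumptions, not a production monitoring stack):
import time
import torch

def log_step_metrics(step, loss, tokens_in_batch, step_start_time, logger=print):
    """Log throughput, loss, and GPU memory once per training step (sketch)."""
    step_time = time.time() - step_start_time
    tokens_per_sec = tokens_in_batch / step_time
    peak_mem_gb = torch.cuda.max_memory_allocated() / 1e9
    total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    logger(f"step={step} loss={loss:.4f} "
           f"tokens/s={tokens_per_sec:,.0f} peak_mem={peak_mem_gb:.1f}GB")
    # Illustrative alert: less than 5% GPU memory headroom left
    if peak_mem_gb > 0.95 * total_mem_gb:
        logger(f"WARNING step={step}: <5% GPU memory headroom")
Feed these logs into whatever alerting system the cluster already uses; the point is that throughput, loss, and memory are tracked every step, not sampled occasionally.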
Common Pitfalls:
Assuming compute-bound: Many workloads are memory/IO-bound
Ignoring data loading: Can easily become 50%+ of time
Over-optimizing communication: If communication is already \(<\)10% of step time, focus elsewhere
Not validating changes: Profile after each optimization to confirm impact
Premature scaling: Fix single-node issues before going multi-node
29.8 Interview Questions on Diagnostics
Q: Your distributed training job has 40% GPU utilization. How do you debug?
A:
Check if data loading is slow (profile DataLoader time)
Verify batch size is large enough to saturate GPU
Look for synchronization points (.item(), .cpu())
Check if the job is blocked on AllReduce (NCCL_DEBUG=INFO)
Profile with PyTorch profiler to identify hot spots
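A concrete way to do the DataLoader and hot-spot checks in one pass (a sketch; the trace directory, step counts, and the dataloader/model/optimizer names are assumptions):
import torch
from torch.profiler import profile, schedule, ProfilerActivity, tensorboard_trace_handler

prof = profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=schedule(wait=1, warmup=1, active=5),
    on_trace_ready=tensorboard_trace_handler("./profiler_logs"),
)
with prof:
    for step, batch in enumerate(dataloader):   # dataloader/model/optimizer assumed defined
        optimizer.zero_grad()
        loss = model(batch)
        loss.backward()
        optimizer.step()
        prof.step()                              # advance the profiler schedule
        if step >= 7:
            break
# In the trace: long gaps between CUDA kernels mean the GPU is starved
# (slow data loading or CPU-side synchronization), not compute-limited.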
Q: How do you know if you’re memory bandwidth bound vs compute bound?
A:
Use nsys or ncu to check achieved memory bandwidth
If \(>\)80% of peak HBM bandwidth → memory-bound
Test: double batch size; if time doubles → memory-bound
Memory-bound ops: elementwise (ReLU, dropout), softmax, layernorm
Compute-bound ops: GEMM (matmul), convolutions
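The same distinction can be estimated on paper from arithmetic intensity (FLOPs per byte of HBM traffic). A back-of-the-envelope sketch, where the peak numbers are illustrative assumptions for an H100-class GPU in fp16:
# Roofline-style estimate: above the ridge point => compute-bound, below => memory-bound
PEAK_FLOPS = 1000e12   # ~1 PFLOP/s tensor-core fp16 (assumption)
PEAK_BW    = 3e12      # ~3 TB/s HBM bandwidth (assumption)
ridge = PEAK_FLOPS / PEAK_BW   # ~333 FLOPs per byte

def gemm_intensity(M, N, K, bytes_per_elem=2):
    flops = 2 * M * N * K
    bytes_moved = bytes_per_elem * (M * K + K * N + M * N)   # ideal: each matrix touched once
    return flops / bytes_moved

def relu_intensity(bytes_per_elem=2):
    return 1 / (2 * bytes_per_elem)   # 1 FLOP per element, one read + one write

print(f"ridge point      : {ridge:6.0f} FLOPs/byte")
print(f"4096^3 GEMM      : {gemm_intensity(4096, 4096, 4096):6.0f} FLOPs/byte -> compute-bound")
print(f"elementwise ReLU : {relu_intensity():6.2f} FLOPs/byte -> memory-bound")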
Q: Scaling from 8 to 64 GPUs only gives 5\(\times\) speedup. Why?
A: Communication overhead. Solutions:
Use gradient accumulation (reduce AllReduce frequency)
Switch to FSDP/ZeRO (more efficient communication)
Increase batch size (amortize communication over more compute)
Check network bandwidth (run nccl-tests)
Ensure the topology is optimal (NVLink intra-node, InfiniBand inter-node)
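For the first item in the list above, DDP's no_sync() context manager is the usual way to skip the gradient AllReduce on accumulation micro-steps (a sketch; model is assumed to be DDP-wrapped, and accum_steps is illustrative):
import contextlib

accum_steps = 4   # one AllReduce every 4 micro-batches instead of every micro-batch
for step, batch in enumerate(dataloader):          # dataloader/model/optimizer assumed defined
    sync_now = (step + 1) % accum_steps == 0
    # no_sync() on the DDP-wrapped model suppresses the AllReduce; grads accumulate locally
    ctx = contextlib.nullcontext() if sync_now else model.no_sync()
    with ctx:
        loss = model(batch) / accum_steps
        loss.backward()
    if sync_now:
        optimizer.step()
        optimizer.zero_grad()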
Q: How do you monitor network bandwidth during training?
A:
Run nccl-tests to benchmark AllReduce bandwidth
Enable NCCL_DEBUG=INFO to log communication times
Use iftop or ifconfig to monitor interface traffic
Profile with the PyTorch profiler and look at ncclAllReduce time
If AllReduce time is \(>\)15% of iteration time → communication is a bottleneck
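A quick in-framework check of achieved AllReduce bandwidth (a sketch; assumes the process group is already initialized under torchrun, and the tensor size and iteration counts are arbitrary):
import time
import torch
import torch.distributed as dist

def allreduce_busbw(numel=256 * 1024 * 1024, iters=20):
    """Rough AllReduce bus bandwidth in GB/s (process group assumed initialized)."""
    world = dist.get_world_size()
    x = torch.zeros(numel, dtype=torch.float16, device="cuda")
    for _ in range(5):                       # warmup
        dist.all_reduce(x)
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(iters):
        dist.all_reduce(x)
    torch.cuda.synchronize()
    avg_t = (time.time() - t0) / iters
    algbw = (numel * 2) / avg_t / 1e9        # fp16 = 2 bytes/element
    return algbw * 2 * (world - 1) / world   # bus-bandwidth correction (nccl-tests convention)
Compare the result against what nccl-tests reports for the same message size; a large gap points at a topology or NCCL configuration problem.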
Q: What’s the difference between HBM bandwidth and network bandwidth?
A:
HBM (GPU memory): 1.5–3 TB/s, local to GPU, affects memory-bound kernels
NVLink (GPU-GPU): 600–900 GB/s, intra-node, for tensor/pipeline parallelism
InfiniBand (inter-node): 200–400 GB/s, for distributed training AllReduce
Different bottlenecks: HBM limits kernel speed, network limits scaling
Q: How would you optimize a training job that’s memory-bound?
A:
Use Flash Attention (fused attention kernel)
Enable kernel fusion (torch.compile)
Use mixed precision (BF16/FP16) to reduce memory traffic
Avoid unnecessary transposes/copies
Replace custom ops with optimized libraries (cuDNN, cuBLAS)
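For example, the first three items combined in PyTorch (a sketch; the shapes are illustrative, and scaled_dot_product_attention dispatches to a fused Flash-Attention-style kernel when one is available):
import torch
import torch.nn.functional as F

@torch.compile                                # fuse the surrounding elementwise ops
def attention_block(q, k, v):
    # Fused attention: avoids materializing the (seq x seq) score matrix in HBM
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

q = torch.randn(8, 16, 2048, 64, device="cuda", dtype=torch.bfloat16)  # (B, heads, seq, dim)
k, v = torch.randn_like(q), torch.randn_like(q)
out = attention_block(q, k, v)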
30 Auto-Tuning & Kernel Optimization
30.1 Introduction: The Auto-Tuning Problem
Modern GPU kernels have high-dimensional configuration spaces:
Block sizes: 32, 64, 128, 256, 512, 1024 threads
Tile sizes: How to partition work across thread blocks
Memory layouts: Row-major, column-major, tiled
Parallelization strategies: Which loops to parallelize, unroll factors
Pipeline stages: Double buffering, triple buffering for overlap
Search space size: For matrix multiplication alone, \(>\)10,000 valid configurations exist. Manual tuning is infeasible at scale.
Auto-tuning automates this search: systematically explore configurations, measure performance, select the best.
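To get a feel for the size of the space, even a modest grid of choices multiplies out quickly (the parameter lists below are illustrative, not any library's actual search space):
import itertools

block_sizes = [32, 64, 128, 256, 512, 1024]
tile_m      = [16, 32, 64, 128]
tile_n      = [16, 32, 64, 128]
tile_k      = [8, 16, 32]
unroll      = [1, 2, 4, 8]
stages      = [1, 2, 3]            # single / double / triple buffering

configs = list(itertools.product(block_sizes, tile_m, tile_n, tile_k, unroll, stages))
print(len(configs))   # 6*4*4*3*4*3 = 3456 combinations before pruning invalid ones
Add a couple more axes (vectorization width, memory layout) and the count passes the \(>\)10,000 mark quoted above.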
30.2 TVM Auto-Scheduler (Ansor)
30.2.1 Overview
Apache TVM is an end-to-end deep learning compiler with sophisticated auto-tuning:
Unified IR: Hardware-agnostic tensor expressions
Auto-Scheduler (Ansor): Search-based schedule optimization
Cost model: ML-based performance predictor (avoids running every config)
Target backends: CUDA, ROCm, CPU, ARM, FPGA
30.2.2 How TVM Auto-Scheduler Works
Define compute: Write tensor expression (hardware-agnostic)
Generate search space: Enumerate tiling, parallelization, reordering options
Sample candidates: Use evolutionary search or random sampling
Predict cost: ML model estimates latency without running
Measure top-k: Actually run best predicted configs on hardware
Update cost model: Train on measured results (active learning)
Repeat: Iterate until budget exhausted or convergence
TVM Auto-Tuning Example:
import numpy as np
import tvm
from tvm import te, auto_scheduler
# Step 1: Define compute (hardware-agnostic)
@auto_scheduler.register_workload
def matmul(N, K, M):
A = te.placeholder((N, K), name="A")
B = te.placeholder((K, M), name="B")
k = te.reduce_axis((0, K), name="k")
C = te.compute((N, M),
lambda i, j: te.sum(A[i, k] * B[k, j], axis=k),
name="C")
return [A, B, C]
# Step 2: Create tuning task
target = tvm.target.cuda()
task = auto_scheduler.SearchTask(
func=matmul,
args=(1024, 1024, 1024),
target=target
)
# Step 3: Auto-tune
tune_option = auto_scheduler.TuningOptions(
num_measure_trials=1000, # Run 1000 measurements
measure_callbacks=[auto_scheduler.RecordToFile("matmul.json")],
early_stopping=100 # Stop if no improvement after 100 trials
)
task.tune(tune_option)
# Step 4: Apply best schedule
sch, args = task.apply_best("matmul.json")
func = tvm.build(sch, args, target)
# Use optimized kernel
device = tvm.cuda(0)
a = tvm.nd.array(np.random.rand(1024, 1024).astype("float32"), device)
b = tvm.nd.array(np.random.rand(1024, 1024).astype("float32"), device)
c = tvm.nd.empty((1024, 1024), device=device)
func(a, b, c) # Runs optimized kernel
30.2.3 TVM Cost Model
Key innovation: predict performance without running on hardware.
Features: Extract from schedule (loop tiling, memory access patterns)
Model: Gradient boosting (XGBoost) trained on measured latencies
Transfer learning: Pre-trained on many kernels, fine-tunes per workload
Benefit: Reduce measurements from 10,000s to \(\sim\)1,000
30.2.4 When to Use TVM
Diverse hardware: Need to target CUDA, ROCm, ARM, FPGA from single codebase
Custom ops: Operators not in cuDNN/MIOpen (fused kernels, sparse ops)
Aggressive optimization: Willing to invest tuning time for best performance
Model deployment: Optimize entire model graph end-to-end
30.3 Triton Auto-Tuning
30.3.1 Overview
OpenAI Triton is a Python-based GPU kernel language:
Python-like syntax: Lower barrier than CUDA/HIP
Automatic memory management: Compiler handles shared memory, tiling
Built-in auto-tuning: @triton.autotune decorator
Backends: CUDA, ROCm (via the LLVM AMDGPU backend)
30.3.2 Triton Auto-Tune Mechanism
Define configs: List of parameter combinations (block size, num warps)
First invocation: Triton compiles all configs, runs benchmarks
Cache winner: Stores the best config in ~/.triton/autotune
Subsequent calls: Directly use the cached config (no overhead)
Triton Auto-Tune Example:
import triton
import triton.language as tl
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE': 128}, num_warps=4),
triton.Config({'BLOCK_SIZE': 256}, num_warps=8),
triton.Config({'BLOCK_SIZE': 512}, num_warps=16),
triton.Config({'BLOCK_SIZE': 1024}, num_warps=32),
],
key=['N'], # Auto-tune based on input size N
)
@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, N,
BLOCK_SIZE: tl.constexpr):
# Program ID (which block)
pid = tl.program_id(0)
# Compute offsets
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# Bounds check mask
mask = offsets < N
# Load, compute, store
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
output = x + y
tl.store(output_ptr + offsets, output, mask=mask)
# Launch kernel (auto-tunes on first call)
grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']),)
add_kernel[grid](x, y, out, N)
30.3.3 Triton vs CUDA/HIP
| Aspect | CUDA/HIP | Triton |
|---|---|---|
| Language | C-like, manual memory | Python-like, auto memory |
| Shared memory | Manual allocation | Compiler manages |
| Tiling | Manual loop writing | Automatic blocking |
| Auto-tuning | External (cutlass, ck) | Built-in decorator |
| Portability | CUDA or HIP (not both) | CUDA + ROCm from one source |
| Learning curve | Steep (weeks) | Moderate (days) |
| Performance | Best (hand-tuned) | Near-optimal (90-95%) |
30.3.4 Triton in Production
Flash Attention 2: Implemented in Triton, auto-tuned for different GPUs
PyTorch Inductor: The torch.compile backend uses Triton for fused kernels
OpenAI: Powers GPT inference optimizations
Anthropic: Used in Claude training infrastructure
30.4 PyTorch Inductor Auto-Tuning
torch.compile (PyTorch 2.0+) includes auto-tuning via Triton backend:
import torch
# Enable max auto-tuning
@torch.compile(mode="max-autotune")
def fused_mlp(x, w1, w2):
return torch.nn.functional.gelu(x @ w1) @ w2
# First call: auto-tunes kernels (slow)
# Subsequent calls: uses cached configs (fast)
output = fused_mlp(x, w1, w2)
What it does:
Graph capture: Traces PyTorch operations
Fusion: Merges elementwise ops into single kernels
Code generation: Generates Triton kernels
Auto-tuning: Tries multiple block sizes, selects fastest
Caching: Stores compiled kernels in a local on-disk cache (by default under /tmp/torchinductor_<user>; configurable via TORCHINDUCTOR_CACHE_DIR)
Speedups observed:
Elementwise-heavy models: 1.5–2× faster
Transformer inference: 1.2–1.5× faster
Custom fused ops: 2–5× faster
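A quick way to measure what torch.compile actually buys on a given workload (a sketch using torch.utils.benchmark; the fused MLP and shapes are illustrative):
import torch
import torch.utils.benchmark as benchmark

def mlp_eager(x, w1, w2):
    return torch.nn.functional.gelu(x @ w1) @ w2

mlp_compiled = torch.compile(mlp_eager, mode="max-autotune")

x  = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
w1 = torch.randn(4096, 16384, device="cuda", dtype=torch.bfloat16)
w2 = torch.randn(16384, 4096, device="cuda", dtype=torch.bfloat16)
mlp_compiled(x, w1, w2)   # first call pays the tuning/compilation cost

for name, fn in [("eager", mlp_eager), ("compiled", mlp_compiled)]:
    t = benchmark.Timer(stmt="fn(x, w1, w2)",
                        globals={"fn": fn, "x": x, "w1": w1, "w2": w2})
    print(name, t.blocked_autorange())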
30.5 ROCm/AMD Auto-Tuning
30.5.1 MIOpen Auto-Tuning
AMD’s MIOpen (cuDNN equivalent) uses find-db for auto-tuning:
# Enable auto-tuning (stores results in ~/.config/miopen)
export MIOPEN_FIND_MODE=1 # Normal find (moderate time)
export MIOPEN_FIND_MODE=2 # Exhaustive search (slow, best perf)
# First run builds find database
python train.py # Slow (tuning overhead)
# Subsequent runs use cached configs
python train.py # Fast (no tuning overhead)
Find database location:
~/.config/miopen/
|-- 2.0.0/ # ROCm version
| `-- gfx90a/ # GPU architecture (MI200 = gfx90a, MI300 = gfx940)
| `-- conv_find_db.db # SQLite DB of tuned configs
30.5.2 Composable Kernel (CK) Library
AMD’s Composable Kernel provides optimized templated kernels:
Operations: GEMM, batched GEMM, GEMM-Softmax-GEMM (attention)
Templates: Tile sizes, pipeline stages as template parameters
Auto-selection: Library picks best instantiation for input shape
Integration: Used by MIOpen, PyTorch ROCm backend
30.6 Auto-Tuning Best Practices
30.6.1 Tuning Budget
Quick tuning: 100–200 trials (\(\sim\)10 minutes) for 80% optimal
Production tuning: 1,000–2,000 trials (\(\sim\)1 hour) for 95% optimal
Exhaustive: 10,000+ trials (hours) for last 1–2%
Rule of thumb: Diminishing returns after 1,000 trials for most kernels.
30.6.2 Caching Strategy
Per-shape tuning: Different optimal configs for different input sizes
Cache key: (kernel_name, input_shapes, dtype, GPU_arch)
Persistent storage: Store in filesystem, share across jobs
Version control: Include tuning cache in Docker images
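A minimal per-shape cache along these lines (a sketch; the directory layout, key fields, and JSON format are arbitrary choices for illustration, not any framework's actual on-disk format):
import json
import hashlib
import pathlib
import torch

CACHE_DIR = pathlib.Path.home() / ".kernel_tuning_cache"   # hypothetical location

def cache_key(kernel_name, input_shapes, dtype):
    arch = torch.cuda.get_device_name(0).replace(" ", "_")  # GPU architecture in the key
    raw = f"{kernel_name}|{input_shapes}|{dtype}|{arch}"
    return hashlib.sha256(raw.encode()).hexdigest()[:16]

def load_config(kernel_name, input_shapes, dtype):
    path = CACHE_DIR / f"{cache_key(kernel_name, input_shapes, dtype)}.json"
    return json.loads(path.read_text()) if path.exists() else None

def store_config(kernel_name, input_shapes, dtype, config):
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    path = CACHE_DIR / f"{cache_key(kernel_name, input_shapes, dtype)}.json"
    path.write_text(json.dumps(config))
Because the GPU architecture is part of the key, the same cache directory can be shared across heterogeneous clusters without configs leaking between GPU generations.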
30.6.3 Correctness Validation
Critical: Always validate correctness before deploying tuned kernels!
Golden reference: Run baseline (PyTorch native op)
Numerical comparison: torch.allclose(output, reference, rtol=1e-5)
Multiple input sizes: Test edge cases (empty, very large)
Different dtypes: fp32, fp16, bf16
Gradient checking: For backprop kernels
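A small validation harness covering the golden-reference, shape, and dtype checks (a sketch; tuned_op, reference_op, and my_tuned_relu are hypothetical stand-ins for your kernel and its PyTorch baseline):
import torch

def validate_kernel(tuned_op, reference_op, shapes,
                    dtypes=(torch.float32, torch.float16, torch.bfloat16)):
    """Compare a tuned kernel against a golden reference across shapes and dtypes."""
    for shape in shapes:
        for dtype in dtypes:
            x = torch.randn(shape, device="cuda", dtype=dtype)
            out, ref = tuned_op(x), reference_op(x)
            # Looser tolerance for half-precision dtypes
            rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-2, 1e-3)
            assert torch.allclose(out, ref, rtol=rtol, atol=atol), f"mismatch at {shape}, {dtype}"

# Edge cases plus a typical size; for backward kernels, run torch.autograd.gradcheck
# separately on fp64 inputs to cover gradient checking.
validate_kernel(my_tuned_relu, torch.nn.functional.relu, shapes=[(0,), (1,), (4096, 4096)])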
30.7 Interview Questions: Auto-Tuning
Q: What’s the difference between TVM and Triton?
A:
TVM: Full compiler stack, targets diverse hardware (CPU/GPU/FPGA), uses ML cost model to predict performance, suitable for deployment pipelines
Triton: Python DSL for GPU kernels only, measures actual latency (no cost model), easier to write, integrated with PyTorch
Use TVM: Multi-platform deployment, custom model formats
Use Triton: Quick prototyping, PyTorch integration, GPU-only workloads
Q: How does Triton auto-tuning work under the hood?
A:
Decorator lists candidate configs (block size, num warps)
On first invocation, Triton compiles all variants
Runs microbenchmark for each, measures latency
Selects the fastest and caches it in ~/.triton/autotune/{hash}
Subsequent calls directly use the cached config (zero overhead)
Q: When would you use torch.compile(mode="max-autotune")?
A: When:
Model has many elementwise ops (activations, normalization) → good fusion opportunities
Training/inference time is critical (worth 1-time tuning cost)
Input shapes are fixed (tuning cache reusable)
Avoid when:
Dynamic shapes (cache misses, retuning overhead)
Model dominated by large matmuls (cuBLAS already optimal)
Rapid iteration (compilation time adds latency)
Q: How do you validate an auto-tuned kernel is correct?
A:
Run baseline implementation, save output as golden reference
Test the tuned kernel with torch.allclose(output, golden, rtol=1e-5)
Test multiple input shapes (small, large, edge cases)
Test different dtypes (fp32, fp16, bf16)
For backprop: verify gradients match finite differences
Run in CI on every config change
Q: How would you share tuning caches across a team?
A:
Store cache directory in version control (Git LFS for large files)
Include in the Docker image: COPY .triton/autotune /root/.triton/autotune
Use a shared filesystem in the cluster (NFS mount for ~/.triton)
Build a caching service: HTTP API to query/store configs by (kernel, shape, GPU)
CI pipeline: run tuning, commit updated caches
30.8 Key Takeaways for Interviews
Critical Concepts to Master:
DDP: Multi-process data parallelism with All-Reduce
ZeRO/FSDP: Sharding optimizer, gradients, parameters to reduce memory
Tensor Parallelism: Split layers across GPUs (intra-layer)
Pipeline Parallelism: Split model into stages (inter-layer)
Mixed Precision: Use fp16/bf16 for compute, fp32 for updates
Gradient Checkpointing: Trade memory for recomputation
Flash Attention: Fused kernel to reduce attention memory
Modern LLM Training Stack:
Framework: PyTorch FSDP or DeepSpeed ZeRO
Precision: BF16 (A100/H100) or FP16 with loss scaling
Optimizer: AdamW with warmup + cosine decay
Memory: Gradient checkpointing + Flash Attention 2
Communication: NCCL (intra-node), InfiniBand (inter-node)
31 References & Further Reading
ZeRO (Rajbhandari et al., 2020): ZeRO: Memory Optimizations Toward Training Trillion Parameter Models
FSDP (Meta AI, 2021): PyTorch Fully Sharded Data Parallel
Megatron-LM (Shoeybi et al., 2019): Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism
GPipe (Huang et al., 2019): GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism
Flash Attention (Dao et al., 2022): FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
Mixed Precision (Micikevicius et al., 2018): Mixed Precision Training
DeepSpeed Documentation: https://www.deepspeed.ai/
PyTorch FSDP Tutorial: https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html