17 Chapter 16: Training Optimization & Parallelization Strategies
18 Introduction: The Scale Challenge
Modern large language models (LLMs) have grown from millions to hundreds of billions of parameters:
GPT-2 (2019): 1.5B params, trained on 40GB text
GPT-3 (2020): 175B params, requires 350GB+ in fp32
LLaMA-2 70B (2023): 70B params, 2T tokens
GPT-4 (2023): Estimated 1.7T params (MoE), multi-modal
Key Challenges:
Memory: A 70B model requires \(\sim\)140GB just for fp16 weights–more than any single GPU's memory
Compute: Training on trillions of tokens takes months on thousands of GPUs
Communication: Gradient synchronization becomes bottleneck at scale
Efficiency: Memory overhead (optimizer states, activations) can be 10-20\(\times\) model size
This Document Covers:
Data Parallelism: DDP, gradient synchronization, FSDP, ZeRO stages
Model Parallelism: Tensor parallelism, pipeline parallelism, 3D parallelism
Memory Optimization: Mixed precision, gradient accumulation, activation checkpointing
Communication: All-reduce, reduce-scatter, gradient compression
Frameworks: PyTorch FSDP, DeepSpeed, Megatron-LM, Ray
19 Data Parallelism
19.1 Concept: Replicate Model, Split Data
Data Parallelism (DP) replicates the model on each GPU and splits the training batch across devices.
Forward Pass:
Each GPU holds a full copy of model parameters \(\theta\)
Global batch \(B\) split into \(N\) micro-batches: \(B = \{B_1, B_2, \ldots, B_N\}\)
GPU \(i\) processes \(B_i\) independently → computes loss \(\mathcal{L}_i\) and gradients \(g_i\)
Backward Pass:
Each GPU computes local gradients: \(g_i = \nabla_\theta \mathcal{L}_i(\theta)\)
Synchronize gradients via All-Reduce: \(g = \frac{1}{N} \sum_{i=1}^N g_i\)
Update parameters: \(\theta_{t+1} = \theta_t - \eta g\)
Key Insight: All-Reduce ensures every GPU has the same averaged gradient before updating. This keeps model replicas in sync.
Note: This is data parallelism–each GPU has a full copy of the model and processes different data. For models too large to fit on one GPU, use model parallelism (tensor/pipeline parallelism, covered later) where the model itself is split across GPUs.
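The forward/backward/All-Reduce cycle above can be sketched with a toy single-machine simulation (all names illustrative, not a real API): each "GPU" holds a full copy of a scalar parameter, computes a local gradient on its micro-batch, and applies the same averaged gradient.

```python
# Toy data-parallelism simulation: two replicas of a scalar parameter theta,
# trained on a split batch with mean-squared-error loss 0.5*(theta - x)^2.

def local_gradient(theta, batch):
    # Average gradient of the MSE loss over one micro-batch.
    return sum(theta - x for x in batch) / len(batch)

def all_reduce_mean(values):
    # Every worker receives the same average -- the role of All-Reduce.
    avg = sum(values) / len(values)
    return [avg] * len(values)

def dp_step(thetas, micro_batches, lr=0.1):
    grads = [local_gradient(t, b) for t, b in zip(thetas, micro_batches)]
    synced = all_reduce_mean(grads)            # gradient synchronization
    return [t - lr * g for t, g in zip(thetas, synced)]

thetas = [0.0, 0.0]                            # two replicas, already in sync
batches = [[1.0, 2.0], [3.0, 4.0]]             # global batch split in two
thetas = dp_step(thetas, batches)
assert thetas[0] == thetas[1]                  # replicas stay identical
```

Because every replica applies the identical averaged gradient, the models never drift apart: this is the invariant DDP maintains automatically.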
19.2 DataParallel (DP) vs DistributedDataParallel (DDP)
DataParallel (DP):
Single-process, multi-threaded (Python GIL bottleneck)
Master GPU gathers gradients from all GPUs → broadcasts updated params
Memory imbalance: master GPU stores full batch + optimizer state
Slow: Communication is sequential; limited by GIL
DistributedDataParallel (DDP):
Multi-process: one process per GPU (avoids GIL)
No master GPU: All-Reduce distributes gradients evenly
Uses NCCL (NVIDIA Collective Communications Library) for efficient gradient sync
Faster: Up to 2-3\(\times\) speedup over DP
PyTorch DDP Setup:
```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Initialize process group (torchrun sets RANK/LOCAL_RANK/WORLD_SIZE)
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device(f"cuda:{local_rank}")

model = MyModel().to(device)
model = DDP(model, device_ids=[local_rank])

# Training loop
for batch in dataloader:
    optimizer.zero_grad()
    loss = model(batch)
    loss.backward()  # Gradients auto-synced via All-Reduce
    optimizer.step()
```
19.3 All-Reduce Algorithm
All-Reduce is a collective communication operation where each GPU contributes local gradients \(g_i\) and receives the global average \(g = \frac{1}{N} \sum_{i=1}^N g_i\).
Ring All-Reduce (Bandwidth-Optimal):
Arrange \(N\) GPUs in a logical ring
Reduce-Scatter Phase: Each GPU sends chunks to neighbors; gradually accumulate sums
All-Gather Phase: Each GPU sends accumulated chunks to neighbors; reconstruct full result
Complexity Analysis:
Consider \(N\) GPUs, each with gradient data of size \(M\) bytes.
Ring All-Reduce Mechanics:
Split each GPU’s data into \(N\) chunks of size \(M/N\)
Reduce-Scatter: \(N-1\) communication rounds → each GPU sends \((N-1) \cdot \frac{M}{N}\) bytes
All-Gather: \(N-1\) communication rounds → each GPU sends \((N-1) \cdot \frac{M}{N}\) bytes
Total data sent per GPU: \(2(N-1) \cdot \frac{M}{N} = \frac{2(N-1)M}{N}\)
Time Complexity:
Communication time: \(\frac{2(N-1)M}{N \cdot \text{bandwidth}} \approx \frac{2M}{\text{bandwidth}}\) for large \(N\)
As \(N \to \infty\), \(\frac{N-1}{N} \to 1\) → time approaches \(\frac{2M}{\text{bandwidth}}\)
Key insight: Communication cost is independent of \(N\)–scales perfectly!
Practical Implications:
Bandwidth-optimal: Each GPU sends/receives \(\approx 2M\) total–the theoretical minimum
Latency: \(2(N-1)\) sequential message steps → can hurt small models where latency dominates
Example: 8 GPUs, 1GB gradients → each GPU sends 1.75GB, takes \(\approx\) 14 message hops
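The per-GPU cost formulas above are easy to encode directly (the 8-GPU / 1GB numbers follow immediately):

```python
def ring_all_reduce_bytes(n_gpus, grad_bytes):
    """Bytes each GPU sends in ring All-Reduce: 2(N-1)/N * M."""
    return 2 * (n_gpus - 1) * grad_bytes / n_gpus

def ring_all_reduce_steps(n_gpus):
    """Sequential message rounds: (N-1) reduce-scatter + (N-1) all-gather."""
    return 2 * (n_gpus - 1)

gb = 1024**3
assert ring_all_reduce_bytes(8, 1 * gb) == 1.75 * gb   # 1.75 GB per GPU
assert ring_all_reduce_steps(8) == 14                  # 14 message hops
```

Note how the bytes-per-GPU term saturates near \(2M\) as \(N\) grows, while the step count keeps growing linearly–which is exactly why latency, not bandwidth, hurts small models.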
Modern Practice: NCCL implements optimized All-Reduce using tree-based algorithms for small \(N\) (lower latency) and ring-based for large \(N\) (bandwidth-optimal).
19.4 Limitations of Naive Data Parallelism
Memory Replication: Each GPU stores full model (weights, optimizer states, gradients)
Example: 70B model in fp16 → 140GB \(\times\) 8 GPUs = 1.1TB total memory used
Inefficient for Large Models: Cannot train models larger than single GPU memory
Communication Overhead: All-Reduce latency grows with model size
Solution: ZeRO (Zero Redundancy Optimizer) and FSDP (Fully Sharded Data Parallel)
20 ZeRO: Zero Redundancy Optimizer
20.1 Motivation: Eliminate Memory Redundancy
In standard data parallelism, memory breakdown per GPU for a model with \(P\) parameters:
Model weights: \(2P\) bytes (fp16)
Gradients: \(2P\) bytes (fp16)
Optimizer states (Adam): \(12P\) bytes = fp32 master weights (\(4P\)) + fp32 momentum (\(4P\)) + fp32 variance (\(4P\))
Total: \(16P\) bytes per GPU → fully redundant across GPUs
Example: 7B model → \(7 \times 10^9 \times 16 = 112\) GB per GPU (before activations!)
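The \(16P\) accounting can be written as a small calculator (decimal bytes, matching the 7B → 112 GB figure above):

```python
def adam_fp16_training_bytes(params):
    """Per-GPU bytes for standard mixed-precision Adam training."""
    weights   = 2 * params   # fp16 weights
    grads     = 2 * params   # fp16 gradients
    optimizer = 12 * params  # fp32 master weights + momentum + variance
    return weights + grads + optimizer

assert adam_fp16_training_bytes(7_000_000_000) == 112_000_000_000  # 112 GB
```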
ZeRO (DeepSpeed) removes this redundancy by sharding optimizer states, gradients, and parameters across GPUs.
20.2 ZeRO Stages
ZeRO-1: Optimizer State Partitioning
Shard optimizer states (momentum, variance) across \(N\) GPUs
Each GPU stores \(\frac{1}{N}\) of optimizer states → updates only its assigned params
Still does full All-Reduce: All GPUs compute and store full gradients (\(2P\) each)
After optimizer step: All-Gather updated params to sync models
Memory Reduction: Optimizer states from \(12P\) to \(\frac{12P}{N}\) (gradients still \(2P\))
ZeRO-2: Gradient Partitioning (builds on ZeRO-1)
Replaces All-Reduce with Reduce-Scatter: Each GPU gets different shard of averaged gradients
After backward: GPU \(i\) receives averaged gradients only for parameters it’s responsible for
Each GPU stores \(\frac{1}{N}\) of gradients (matching its optimizer state partition)
Updates only its assigned parameter shard using its local gradient shard
Memory Reduction: Gradients from \(2P\) to \(\frac{2P}{N}\) + Optimizer \(\frac{12P}{N}\)
Concrete Example: 2 GPUs, 4 parameters \((W_1, W_2, W_3, W_4)\)
ZeRO-1 Workflow:
Backward pass: Both GPUs compute local gradients \((g_1, g_2, g_3, g_4)\)
All-Reduce: Average gradients across GPUs
GPU-1 receives: \((G_1, G_2, G_3, G_4)\) (full averaged gradients)
GPU-2 receives: \((G_1, G_2, G_3, G_4)\) (full averaged gradients)
Optimizer step:
GPU-1 has optimizer states for \((W_1, W_2)\) → updates \((W_1, W_2)\) using \((G_1, G_2)\)
GPU-2 has optimizer states for \((W_3, W_4)\) → updates \((W_3, W_4)\) using \((G_3, G_4)\)
All-Gather: Broadcast updated parameters
GPU-1 sends \((W_1', W_2')\) to GPU-2
GPU-2 sends \((W_3', W_4')\) to GPU-1
Both now have full model: \((W_1', W_2', W_3', W_4')\)
Memory: Both GPUs store all gradients (\(2P\)), but only half the optimizer states (\(6P\) each)
ZeRO-2 Workflow:
Backward pass: Both GPUs compute local gradients \((g_1, g_2, g_3, g_4)\)
Reduce-Scatter: Average and distribute gradient shards
GPU-1 receives: \((G_1, G_2)\) only (averaged, for its assigned params)
GPU-2 receives: \((G_3, G_4)\) only (averaged, for its assigned params)
Optimizer step:
GPU-1 updates \((W_1, W_2)\) using \((G_1, G_2)\)
GPU-2 updates \((W_3, W_4)\) using \((G_3, G_4)\)
All-Gather: Broadcast updated parameters (same as ZeRO-1)
Memory: Each GPU stores half gradients (\(P\)) + half optimizer states (\(6P\) each)
Key Insight: In ZeRO-1, you store all gradients but can discard them after optimizer step. ZeRO-2 never stores the full gradient array, saving memory immediately.
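The 2-GPU, 4-parameter example can be simulated end to end; the point of the check at the bottom is that ZeRO-1 and ZeRO-2 produce identical updates, differing only in how much gradient memory each GPU holds (names here are illustrative; real implementations operate on flattened tensors):

```python
LR = 0.1

def all_reduce(grads_per_gpu):
    n = len(grads_per_gpu)
    avg = [sum(g[i] for g in grads_per_gpu) / n
           for i in range(len(grads_per_gpu[0]))]
    return [list(avg) for _ in range(n)]       # every GPU gets the full average

def reduce_scatter(grads_per_gpu, shards):
    n = len(grads_per_gpu)
    avg = [sum(g[i] for g in grads_per_gpu) / n
           for i in range(len(grads_per_gpu[0]))]
    return [[avg[i] for i in shard] for shard in shards]  # shard only

weights = [1.0, 2.0, 3.0, 4.0]
local_grads = [[0.1, 0.2, 0.3, 0.4], [0.3, 0.4, 0.5, 0.6]]
shards = [[0, 1], [2, 3]]          # GPU-1 owns W1,W2; GPU-2 owns W3,W4

# ZeRO-1: full averaged gradients everywhere, shard only the update.
full = all_reduce(local_grads)
zero1 = list(weights)
for gpu, shard in enumerate(shards):
    for i in shard:
        zero1[i] = weights[i] - LR * full[gpu][i]

# ZeRO-2: each GPU only ever holds its gradient shard.
shard_grads = reduce_scatter(local_grads, shards)
zero2 = list(weights)
for gpu, shard in enumerate(shards):
    for k, i in enumerate(shard):
        zero2[i] = weights[i] - LR * shard_grads[gpu][k]

assert zero1 == zero2              # same result, less gradient memory in ZeRO-2
```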
ZeRO-3: Parameter Partitioning
Shard model weights: each GPU stores \(\frac{1}{N}\) of parameters
Forward/backward: All-Gather needed params on-the-fly, discard after use
Memory Reduction: Parameters from \(2P\) to \(\frac{2P}{N}\)
Trade-off: Increased communication (All-Gather per layer)
Memory Comparison (7B model, 8 GPUs, fp16):
| Method | Weights | Gradients | Optimizer |
|---|---|---|---|
| Baseline DP | 14 GB | 14 GB | 84 GB |
| ZeRO-1 | 14 GB | 14 GB | 10.5 GB |
| ZeRO-2 | 14 GB | 1.75 GB | 10.5 GB |
| ZeRO-3 | 1.75 GB | 1.75 GB | 10.5 GB |
ZeRO-3 Trade-off: Reduces memory by \(8\times\) but requires All-Gather for every layer. Use ZeRO-2 if communication bandwidth is limited; ZeRO-3 if memory-constrained.
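The table rows follow from one small function (decimal GB, fp16 weights/gradients, Adam optimizer states):

```python
def per_gpu_gb(params_b, n_gpus, stage):
    """Per-GPU (weights, gradients, optimizer) memory in GB by ZeRO stage."""
    w, g, o = 2 * params_b, 2 * params_b, 12 * params_b   # bytes
    if stage >= 1: o /= n_gpus     # ZeRO-1: shard optimizer states
    if stage >= 2: g /= n_gpus     # ZeRO-2: shard gradients
    if stage >= 3: w /= n_gpus     # ZeRO-3: shard parameters
    return w / 1e9, g / 1e9, o / 1e9

assert per_gpu_gb(7e9, 8, 0) == (14.0, 14.0, 84.0)   # baseline DP
assert per_gpu_gb(7e9, 8, 2) == (14.0, 1.75, 10.5)   # ZeRO-2
assert per_gpu_gb(7e9, 8, 3) == (1.75, 1.75, 10.5)   # ZeRO-3
```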
20.3 ZeRO-Offload: CPU Memory Extension
ZeRO-Offload moves optimizer states to CPU memory, reducing GPU memory further:
Forward/backward on GPU (fast)
Gradients transferred to CPU → optimizer step on CPU → updated params back to GPU
Use Case: Training 10B+ models on consumer GPUs (e.g., RTX 3090 24GB)
DeepSpeed ZeRO-3 Config:
```json
{
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {"device": "cpu"},
    "offload_param": {"device": "cpu"},
    "overlap_comm": true,
    "contiguous_gradients": true,
    "sub_group_size": 1e9,
    "reduce_bucket_size": 5e8,
    "stage3_prefetch_bucket_size": 5e8,
    "stage3_param_persistence_threshold": 1e6
  }
}
```
21 Fully Sharded Data Parallel (FSDP)
21.1 PyTorch Native Alternative to ZeRO
FSDP (PyTorch 1.11+) is PyTorch’s native implementation of ZeRO-3 principles:
Shards parameters, gradients, and optimizer states across GPUs
Automatically manages All-Gather/Reduce-Scatter during forward/backward
Integrated with PyTorch APIs (no external dependency like DeepSpeed)
21.2 FSDP Workflow
Forward Pass:
GPU \(i\) holds shard \(\theta_i\) (e.g., layers 0-3 of 32-layer model)
Before processing layer \(k\): All-Gather \(\theta_k\) from all GPUs
Compute activations \(a_k = f_k(a_{k-1}, \theta_k)\)
Discard \(\theta_k\) (free memory)
Backward Pass:
Re-All-Gather \(\theta_k\) to compute gradients \(g_k\)
Reduce-Scatter \(g_k\) → each GPU gets averaged shard
Discard \(\theta_k\) again
Optimizer Step:
- Each GPU updates its shard \(\theta_i\) using local optimizer state
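The gather-compute-discard cycle can be sketched with a toy memory counter (layer and worker sizes are illustrative; we count abstract "parameter units" held by one worker):

```python
# 4 layers of 1000 parameter units each, sharded across 4 workers.
N_LAYERS, LAYER_PARAMS, N_WORKERS = 4, 1000, 4
shard = LAYER_PARAMS // N_WORKERS        # persistent per-layer shard

resident = N_LAYERS * shard              # sharded weights always resident
peak = resident
for layer in range(N_LAYERS):
    resident += LAYER_PARAMS - shard     # All-Gather the full layer
    peak = max(peak, resident)
    # ... forward computation for this layer would happen here ...
    resident -= LAYER_PARAMS - shard     # discard the gathered copies

full_model = N_LAYERS * LAYER_PARAMS
assert peak == full_model // N_WORKERS + LAYER_PARAMS - shard
assert peak < full_model                 # never holds the whole model at once
```

The peak is one full layer plus the persistent shards–which is why FSDP wraps the model per-layer (or per-block): smaller wrap units mean a smaller gathered working set.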
PyTorch FSDP Usage:
```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import CPUOffload, MixedPrecision, ShardingStrategy

model = MyModel()
model = FSDP(
    model,
    mixed_precision=MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float32,
        buffer_dtype=torch.float16,
    ),
    cpu_offload=CPUOffload(offload_params=True),
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # ZeRO-3 equivalent
)

# Training loop (identical to DDP)
for batch in dataloader:
    optimizer.zero_grad()
    loss = model(batch)
    loss.backward()
    optimizer.step()
```
22 Model Parallelism
22.1 When Data Parallelism Fails
Problem: Model too large to fit on a single GPU, even with ZeRO-3.
Example: 175B GPT-3 → \(\sim\)350GB in fp16
ZeRO-3 with 64 GPUs → \(\frac{350}{64} \approx 5.5\) GB per GPU (just for weights)
Activations during forward pass can be \(\sim\)10-50GB per GPU → OOM
Solution: Model Parallelism–split the model itself across GPUs.
ZeRO-3 vs Tensor Parallelism: Both Shard Weights, But Differently!
ZeRO-3 (Data Parallelism with Weight Sharding):
Each GPU stores \(\frac{1}{N}\) of weights (sharded for memory)
During forward: All-Gather to reconstruct full weights temporarily
All GPUs compute the same operations on different data batches
Still data parallelism–just memory-efficient
Tensor Parallelism (Model Parallelism):
Each GPU stores \(\frac{1}{N}\) of weights (sharded for computation)
Each GPU computes different parts of the layer on the same data
Example: GPU-0 computes first half of MLP output, GPU-1 computes second half
Splits both weights and computation
Key Difference: ZeRO-3 shards weights for memory but reconstructs them for computation. Tensor parallelism shards weights and splits the computation itself.
22.2 Tensor Parallelism (Intra-Layer)
Tensor Parallelism splits individual layers across GPUs. Most common for Transformer MLP and attention.
MLP Layer Splitting:
A standard Transformer MLP: \(y = \text{GELU}(xW_1)W_2\) where \(W_1 \in \mathbb{R}^{d \times 4d}\), \(W_2 \in \mathbb{R}^{4d \times d}\).
Column-wise Split of \(W_1\):
Partition \(W_1 = [W_1^{(1)} | W_1^{(2)}]\) across 2 GPUs
Each GPU computes: \(h^{(i)} = \text{GELU}(x W_1^{(i)})\) independently (no communication)
Concatenate: \(h = [h^{(1)} | h^{(2)}]\)
Row-wise Split of \(W_2\):
Partition \(W_2 = \begin{bmatrix} W_2^{(1)} \\ W_2^{(2)} \end{bmatrix}\)
Each GPU computes partial output: \(y^{(i)} = h^{(i)} W_2^{(i)}\)
All-Reduce: Sum \(y = \sum_i y^{(i)}\)
Communication: One All-Reduce per layer (cheap for large \(d\)).
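The column-/row-split algebra above can be verified on tiny matrices. This sketch uses ReLU in place of GELU for brevity and a dependency-free pure-Python matmul; the final `add` plays the role of the All-Reduce sum:

```python
def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]

def relu(m):
    return [[max(0.0, v) for v in row] for row in m]

def add(a, b):
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]

x  = [[1.0, -2.0]]                                     # one token, d = 2
w1 = [[1.0, 0.0, 2.0, -1.0], [0.0, 1.0, -1.0, 3.0]]    # d x 4d
w2 = [[1.0, 2.0], [0.0, 1.0], [2.0, 0.0], [1.0, 1.0]]  # 4d x d

# Unsharded reference.
ref = matmul(relu(matmul(x, w1)), w2)

# 2-way tensor parallel: columns of w1, matching rows of w2.
w1_a = [row[:2] for row in w1]; w1_b = [row[2:] for row in w1]
w2_a = w2[:2];                  w2_b = w2[2:]
y_a = matmul(relu(matmul(x, w1_a)), w2_a)   # "GPU 0": no communication yet
y_b = matmul(relu(matmul(x, w1_b)), w2_b)   # "GPU 1"
assert add(y_a, y_b) == ref                 # one All-Reduce (the sum) recovers y
```

The column-then-row ordering is what makes the nonlinearity safe: GELU/ReLU is applied element-wise inside each shard, so no communication is needed between the two matmuls.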
Megatron-LM (NVIDIA) pioneered tensor parallelism for GPT-style models. Used to train GPT-3, MT-NLG (530B params).
22.3 Pipeline Parallelism (Inter-Layer)
Pipeline Parallelism splits the model into stages (layer groups) assigned to different GPUs.
Naive Pipelining (Sequential):
GPU 0: Layers 1-8
GPU 1: Layers 9-16
GPU 2: Layers 17-24
GPU 3: Layers 25-32
Problem: GPU 1 waits for GPU 0 → only 25% utilization (pipeline bubble).
GPipe (Micro-batching):
Split global batch into \(M\) micro-batches
Interleave forward passes: GPU 0 processes micro-batch 1, then GPU 1 starts while GPU 0 processes micro-batch 2
Backward passes flow in reverse: GPU 3 → GPU 2 → GPU 1 → GPU 0
Reduces bubble to \(\frac{N-1}{M}\) where \(N\) = # GPUs, \(M\) = # micro-batches
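The bubble formula above in code (a ratio of idle time to useful compute, per the approximation in the text):

```python
def bubble_fraction(n_stages, n_microbatches):
    """GPipe idle-time ratio ~ (N-1)/M, as described above."""
    return (n_stages - 1) / n_microbatches

assert bubble_fraction(4, 1) == 3.0     # naive pipelining: mostly idle
assert bubble_fraction(4, 12) == 0.25   # more micro-batches shrink the bubble
```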
PipeDream (Asynchronous):
Different micro-batches at different pipeline stages simultaneously
1F1B Schedule: Each stage alternates forward and backward passes (1 Forward, 1 Backward)
Maintains multiple versions of weights to avoid staleness
Pipeline Parallelism Trade-offs:
+ No all-to-all communication (only point-to-point between stages)
- Pipeline bubbles reduce GPU utilization (\(\sim\)80-90% typical)
- Requires careful micro-batch size tuning
22.4 3D Parallelism: Combining All Strategies
Modern LLM training uses 3D Parallelism: Data + Tensor + Pipeline.
Example: 175B GPT-3 on 1024 GPUs (Megatron-LM):
Data Parallelism: 16-way (16 replicas of the full pipeline)
Tensor Parallelism: 8-way (each layer split across 8 GPUs)
Pipeline Parallelism: 8 stages (8 layer groups)
Total: \(16 \times 8 \times 8 = 1024\) GPUs
Communication Hierarchy:
Tensor parallelism: High-bandwidth NVLink within node (fast All-Reduce)
Pipeline parallelism: Point-to-point across nodes (moderate bandwidth)
Data parallelism: All-Reduce across replicas (infrequent, amortized)
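The hierarchy above implies a rank layout. One common convention (illustrative; real launchers make the ordering configurable) places tensor-parallel ranks adjacent so they share a node's NVLink, then pipeline stages, then data-parallel replicas:

```python
TP, PP, DP = 8, 8, 16                    # 8 x 8 x 16 = 1024 GPUs

def coords(rank):
    """Map a global rank to (data, pipeline, tensor) coordinates."""
    tp = rank % TP                       # fastest-varying: same node
    pp = (rank // TP) % PP
    dp = rank // (TP * PP)               # slowest-varying: across replicas
    return dp, pp, tp

assert TP * PP * DP == 1024
assert coords(0) == (0, 0, 0)
assert coords(7) == (0, 0, 7)            # ranks 0-7 form one TP group (one node)
assert coords(1023) == (15, 7, 7)        # last replica, last stage, last shard
```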
23 Mixed Precision Training
23.1 Motivation: Speed and Memory
Standard Training: fp32 (32-bit floating point) for weights, activations, gradients.
Slow: Tensor cores (A100, H100) are optimized for fp16/bf16
Memory-hungry: 2\(\times\) more memory than fp16
Mixed Precision Training uses fp16 for forward/backward, fp32 for optimizer updates.
23.2 FP16 vs BF16
FP16 (IEEE Half-Precision):
1 sign bit, 5 exponent bits, 10 mantissa bits
Range: \(\pm 6.55 \times 10^4\) (limited dynamic range)
Risk: Gradient underflow (small gradients → zero)
BF16 (Brain Float 16):
1 sign bit, 8 exponent bits (same as fp32!), 7 mantissa bits
Range: \(\pm 3.4 \times 10^{38}\) (same as fp32)
Robust: No loss scaling needed
Modern Default: A100, H100, TPUs support bf16 natively
23.3 Mixed Precision Workflow
AMP (Automatic Mixed Precision):
Maintain master copy of weights in fp32
Cast weights to fp16/bf16 for forward pass
Compute loss and gradients in fp16/bf16
(FP16 only) Scale loss to prevent gradient underflow: \(\mathcal{L}' = S \cdot \mathcal{L}\) (typical \(S = 2^{16}\))
Convert gradients back to fp32, unscale (fp16 only)
Update fp32 master weights
Cast updated weights back to fp16/bf16
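Step 4 (loss scaling) exists because tiny fp16 gradients flush to zero. A quick demonstration, assuming NumPy is available to emulate fp16 rounding (the \(\sim 6 \times 10^{-8}\) threshold is fp16's smallest subnormal):

```python
import numpy as np

grad = 1e-8                       # a realistically tiny gradient
scale = 2.0 ** 16                 # typical initial scale S

assert np.float16(grad) == 0.0                 # underflows: signal lost
scaled = np.float16(grad * scale)              # ~6.55e-4: representable
assert scaled > 0.0
unscaled = np.float32(scaled) / scale          # unscale in fp32 before update
assert abs(unscaled - grad) / grad < 0.001     # recovered to ~0.1%
```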
Memory Overhead of Mixed Precision:
Optimizer (Adam) maintains three fp32 copies: master weights (\(4P\)), momentum (\(4P\)), variance (\(4P\)) = \(12P\) bytes. The fp16 weights (\(2P\)) and gradients (\(2P\)) coexist separately. Total: \(16P\) bytes.
PyTorch AMP:
```python
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()  # Loss scaling for fp16

for batch in dataloader:
    optimizer.zero_grad()
    with autocast():  # mixed-precision forward
        loss = model(batch)
    scaler.scale(loss).backward()  # Scale loss, compute gradients
    scaler.step(optimizer)         # Unscale, update fp32 weights
    scaler.update()                # Adjust scaling factor
```
Speedup: typically 2-3\(\times\) end-to-end on A100/H100, where mixed-precision matrix multiplies run on Tensor Cores; individual matmul-heavy layers can see 4-5\(\times\).
24 Memory Optimization Techniques
24.1 Activation Checkpointing (Gradient Checkpointing)
Problem: Forward pass stores activations for backward pass → memory scales with sequence length and depth.
Example: GPT-3 (96 layers, batch=1, seq_len=2048):
Activations: \(\sim\)10GB per sample
Batch size 32 → 320GB (exceeds A100 80GB)
Activation Checkpointing (also called Gradient Checkpointing):
Forward pass: Store activations only at checkpoints (e.g., every \(k\) layers)
Backward pass: Recompute intermediate activations on-the-fly from checkpoints
Trade-off: \(\sqrt{N}\) checkpoints → \(\sqrt{N}\) memory, \(2\times\) compute (1 forward + 1 recompute)
Memory Reduction: \(10\times\) typical (e.g., 320GB → 32GB).
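The \(\sqrt{N}\) claim follows from a simplified cost model (one unit of stored activation per layer): with a checkpoint every \(k\) layers you store \(N/k\) checkpoints plus at most one \(k\)-layer segment during recomputation, which is minimized at \(k = \sqrt{N}\):

```python
import math

def activation_memory(n_layers, every_k):
    """Stored activations: one per checkpoint plus one recompute segment."""
    return n_layers // every_k + every_k

n = 64
best_k = round(math.sqrt(n))               # sqrt(N) spacing is optimal
assert activation_memory(n, best_k) == 16  # vs 64 without checkpointing
assert all(activation_memory(n, best_k) <= activation_memory(n, k)
           for k in (2, 4, 16, 32))
```

The compute side of the trade-off does not appear in this model: each non-checkpointed activation is recomputed once, giving the \(\approx 2\times\) forward cost quoted above.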
PyTorch Gradient Checkpointing:
```python
from torch.utils.checkpoint import checkpoint

def forward(x):
    x = checkpoint(layer1, x)  # Don't store layer1's intermediate activations
    x = checkpoint(layer2, x)
    return x
```
Hugging Face Transformers:
```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "gpt2",
    use_cache=False,  # Required for gradient checkpointing
)
model.gradient_checkpointing_enable()
```
24.2 Gradient Accumulation
Gradient Accumulation simulates large batch sizes on limited memory:
Run \(K\) micro-batches (batch size \(B / K\))
Accumulate gradients: \(g = \sum_{i=1}^K g_i\)
Update weights once: \(\theta_{t+1} = \theta_t - \eta g\)
Effect: Effective batch size \(B\) with memory for \(B / K\).
Trade-off: The \(K\) micro-batches run sequentially, so a full effective batch takes \(K\times\) longer than it would if the whole batch fit in memory at once.
PyTorch Gradient Accumulation:
accumulation_steps = 4
for i, batch in enumerate(dataloader):
loss = model(batch) / accumulation_steps # Scale loss
loss.backward() # Accumulate gradients
if (i + 1) % accumulation_steps == 0:
optimizer.step() # Update weights
optimizer.zero_grad() # Clear gradients
24.3 Flash Attention
Flash Attention (Dao et al., 2022) optimizes self-attention memory via kernel fusion and tiling:
Standard attention: Materializes \(QK^T\) matrix (\(O(N^2)\) memory for seq len \(N\))
Flash Attention: Computes attention in blocks, never stores full \(QK^T\)
Memory: \(O(N)\) instead of \(O(N^2)\)
Speed: 2-4\(\times\) faster (reduces HBM read/writes)
Adoption: Default in Hugging Face Transformers 4.36+, PyTorch 2.0+.
```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="flash_attention_2",
)
```
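The \(O(N^2)\) vs \(O(N)\) difference is easy to make concrete. This calculator (shapes chosen to resemble a 7B model with 32 heads; illustrative, not taken from any profiler) computes the fp16 memory standard attention would need just to materialize the score matrix:

```python
def attn_score_bytes(seq_len, n_heads, batch, bytes_per_el=2):
    """fp16 memory to materialize the full QK^T matrix (per layer)."""
    return batch * n_heads * seq_len * seq_len * bytes_per_el

# 32 heads, batch 8, 4k context:
full = attn_score_bytes(4096, 32, 8)
assert full == 8 * 32 * 4096 * 4096 * 2     # ~8.6e9 bytes per layer
assert full / 1e9 > 8                       # Flash Attention never stores this
```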
25 Communication Optimization
25.1 Gradient Compression
Motivation: All-Reduce bandwidth bottleneck–especially for multi-node training.
Gradient Compression reduces communication volume:
Top-k Sparsification: Send only top \(k\%\) largest gradients (1-10%)
Quantization: Reduce precision (e.g., fp16 → int8)
Error Feedback: Accumulate residual errors to avoid bias
Trade-off: 10-100\(\times\) compression with minimal accuracy loss (<0.5% typically).
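A minimal sketch of top-k sparsification with error feedback (illustrative names; real implementations work on flattened tensors and communicate index/value pairs): only the \(k\) largest entries are transmitted, and the untransmitted remainder is carried into the next step so nothing is permanently dropped.

```python
def topk_compress(grad, residual, k):
    """Return (entries to transmit, new local residual)."""
    combined = [g + r for g, r in zip(grad, residual)]   # add carried error
    keep = sorted(range(len(combined)), key=lambda i: abs(combined[i]),
                  reverse=True)[:k]
    sent = {i: combined[i] for i in keep}                # what goes on the wire
    new_residual = [0.0 if i in sent else combined[i]
                    for i in range(len(combined))]       # error feedback
    return sent, new_residual

grad = [0.5, -0.01, 0.02, -0.9]
sent, residual = topk_compress(grad, [0.0] * 4, k=2)
assert set(sent) == {0, 3}                   # only the two largest entries sent
assert residual == [0.0, -0.01, 0.02, 0.0]  # small entries accumulate locally
```

Without the residual term, repeatedly dropping the same small coordinates would bias the optimizer; error feedback guarantees every gradient component is eventually applied.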
25.2 Overlap Communication with Computation
DDP Optimization: Start All-Reduce for layer \(i\) gradients while computing layer \(i+1\).
ZeRO-3 Prefetching: Pre-fetch parameters for layer \(i+1\) while computing layer \(i\).
Effect: Hides communication latency–critical for multi-node training.
Modern Frameworks Automate This:
PyTorch DDP: Automatic bucketing + overlap
DeepSpeed: `overlap_comm=true`
FSDP: `forward_prefetch=True`
25.3 Network Topology Awareness
Intra-node vs Inter-node Communication:
NVLink (intra-node): 900 GB/s (H100), 600 GB/s (A100); extremely low latency
InfiniBand/RoCE (inter-node): 200-400 Gb/s, higher latency
Optimization Strategies:
Place tensor parallelism within node (maximize NVLink usage)
Pipeline parallelism across nodes (minimize inter-node traffic)
Data parallelism for outer replication (amortize All-Reduce)
26 Framework Comparison
26.1 DeepSpeed
Strengths:
ZeRO-Offload: Train 10B+ models on single GPU
Pipeline Parallelism: Built-in support
Sparse Attention: Memory-efficient attention variants
Used by Microsoft, BigScience (BLOOM 176B), StabilityAI
Weaknesses:
External Dependency: Not native PyTorch
Debugging: Harder to debug than FSDP
26.2 PyTorch FSDP
Strengths:
Native PyTorch: Better integration with PyTorch 2.x (torch.compile)
Simpler API: Easier to adopt for existing codebases
Used by Meta (LLaMA-2, LLaMA-3)
Weaknesses:
Limited Offload: CPU parameter offload exists (`CPUOffload`), but there is no NVMe offload or DeepSpeed-style offload tuning
No Pipeline Parallelism: Only data + tensor parallelism
26.3 Megatron-LM
Strengths:
Tensor Parallelism: Best-in-class implementation
3D Parallelism: Combines data + tensor + pipeline
Used by NVIDIA (GPT-3, MT-NLG 530B, Megatron-Turing)
Weaknesses:
Complex Setup: Requires deep understanding of parallelism
Less Flexible: Tightly coupled to specific model architectures
26.4 When to Use What?
| Scenario | Recommendation |
|---|---|
| Single GPU, <10B model | Standard training (no parallelism) |
| Multi-GPU, single node | DDP or FSDP (SHARD_GRAD_OP) |
| Large model (10-70B), multi-GPU | FSDP (FULL_SHARD) or DeepSpeed ZeRO-2 |
| Very large model (70B+), multi-node | FSDP + tensor parallelism or DeepSpeed ZeRO-3 |
| Extreme scale (100B+) | Megatron-LM (3D parallelism) |
| Limited GPU memory | DeepSpeed ZeRO-Offload |
27 Practical Training Recipes
27.1 Recipe 1: Fine-Tuning 7B Model (LLaMA-2) on 4\(\times\)A100
Setup:
Model: LLaMA-2-7B (7B params, fp16 → 14GB)
Batch size: 32 (8 per GPU)
Seq length: 2048
Strategy:
FSDP: SHARD_GRAD_OP (ZeRO-2)
Mixed precision: bf16 (no loss scaling)
Gradient checkpointing: Enabled (saves \(\sim\)40GB activations)
Flash Attention 2: Reduces memory by 2\(\times\)
```bash
# Launch with torchrun
torchrun --nproc_per_node=4 train.py \
    --model_name meta-llama/Llama-2-7b-hf \
    --fsdp "shard_grad_op" \
    --bf16 true \
    --gradient_checkpointing true \
    --attn_implementation flash_attention_2
```
27.2 Recipe 2: Pre-Training 70B Model on 64\(\times\)H100
Setup:
Model: 70B params (GPT-3 scale)
Batch size: 2048 (32 per GPU)
Seq length: 4096
Strategy:
Data Parallelism: 8-way (8 replicas)
FSDP: FULL_SHARD (ZeRO-3) within each replica
Tensor Parallelism: 8-way (split layers across 8 GPUs)
Total: \(8 \times 8 = 64\) GPUs
Mixed precision: bf16
Gradient accumulation: 4 steps
Activation checkpointing: Every 4 layers
Expected Throughput: Highly configuration-dependent; for reference, LLaMA-2 70B averaged roughly 300 tokens/sec/GPU over its full run on A100s, and H100s should do better.
27.3 Recipe 3: Budget Training (10B Model on 2\(\times\)RTX 3090)
Setup:
Model: 10B params
GPU: RTX 3090 (24GB each)
Batch size: 4 (2 per GPU)
Strategy:
DeepSpeed ZeRO-3 + Offload: Offload optimizer states to CPU
Mixed precision: fp16 (bf16 not supported on RTX 3090)
Gradient checkpointing: Aggressive (every 2 layers)
Gradient accumulation: 8 steps (effective batch 32)
Expected Speed: Very slow (\(\sim\)0.1-0.2 tokens/sec/GPU)–CPU bottleneck.
28 Interview Questions & Key Concepts
28.1 Common Interview Questions
Q: Explain the difference between DataParallel and DistributedDataParallel.
A: DP is single-process multi-threaded (Python GIL bottleneck), uses master GPU to aggregate gradients. DDP is multi-process (one per GPU), uses All-Reduce for gradient sync–no master GPU, much faster (2-3\(\times\)).
Q: What is ZeRO and how does it differ from FSDP?
A: ZeRO (DeepSpeed) shards optimizer states (Stage 1), gradients (Stage 2), and parameters (Stage 3) to reduce memory redundancy. FSDP is PyTorch's native implementation of ZeRO-3 principles, with tighter PyTorch integration but fewer features (e.g., no NVMe offload or built-in pipeline parallelism; CPU offload is supported but less configurable than DeepSpeed's).
Q: When would you use tensor parallelism vs pipeline parallelism?
A: Tensor parallelism splits layers across GPUs (intra-layer)–requires high bandwidth (use within node via NVLink). Pipeline parallelism splits model into stages (inter-layer)–uses point-to-point communication (suitable across nodes). For very large models, combine both (3D parallelism).
Q: Explain gradient checkpointing. What’s the trade-off?
A: Stores activations only at checkpoints (e.g., every \(k\) layers), recomputes intermediate activations during backward. Trade-off: Reduces memory by \(\sim\)10\(\times\) but increases compute by \(\sim\)2\(\times\) (one forward pass + one recompute).
Q: Why use bf16 instead of fp16 for mixed precision?
A: BF16 has same dynamic range as fp32 (8 exponent bits) → no gradient underflow, no loss scaling needed. FP16 has limited range (5 exponent bits) → requires loss scaling to avoid underflow. Modern GPUs (A100, H100) support bf16 natively.
Q: How does Flash Attention reduce memory?
A: Standard attention materializes \(QK^T\) matrix (\(O(N^2)\) memory). Flash Attention uses kernel fusion and tiling to compute attention in blocks, never storing full \(QK^T\)–reduces memory to \(O(N)\) and speeds up by 2-4\(\times\) via fewer HBM reads/writes.
Q: Explain the All-Reduce algorithm used in DDP.
A: Ring All-Reduce: GPUs arranged in logical ring, two phases: (1) Reduce-Scatter–each GPU sends chunks, accumulates sums; (2) All-Gather–broadcast accumulated chunks. Bandwidth-optimal: \(O(2M / \text{bandwidth})\) independent of \(N\).
Q: What’s the memory breakdown for training a Transformer model?
A: For \(P\) parameters:
Weights: \(2P\) (fp16)
Gradients: \(2P\) (fp16)
Optimizer (Adam): \(12P\) (fp32 momentum + variance)
Activations: Depends on batch size and sequence length–often \(>10P\)
Q: How would you train a 175B model?
A: Use 3D parallelism: (1) Tensor parallelism (8-way, within node), (2) Pipeline parallelism (8 stages, across nodes), (3) Data parallelism (16 replicas). Total: \(8 \times 8 \times 16 = 1024\) GPUs. Use ZeRO-1 or FSDP (SHARD_GRAD_OP) for data parallelism. Enable gradient checkpointing, bf16, Flash Attention.
Q: What causes pipeline bubbles and how do you minimize them?
A: Bubbles occur when pipeline stages wait for previous stages. Minimize by: (1) Increase micro-batch count (bubble fraction \(\sim (N-1)/M\)), (2) Use 1F1B schedule (interleave forward/backward), (3) Balance stage compute times (even layer distribution).
Q: Explain tensor contiguity in PyTorch. Why does `view` sometimes fail?
A: A tensor is contiguous when its data is laid out in one uninterrupted block of memory with the expected stride order (row-major by default). Operations like `transpose()`, step slicing (`x[::2]`), or certain views produce non-contiguous tensors: they reference the same underlying data but with different strides. `view()` only works on contiguous tensors because it just changes the shape metadata without moving data. If a tensor isn't contiguous, you must:
Call `.contiguous().view(...)` to create a contiguous copy first
Use `.reshape(...)`, which automatically copies if needed
Example:
```python
x = torch.randn(3, 4)
y = x.transpose(0, 1)        # Non-contiguous (strides changed)
# y.view(-1) would raise a RuntimeError!
z = y.contiguous().view(-1)  # Works
z = y.reshape(-1)            # Also works (copies if needed)
```
Q: Does `view` work differently in distributed training (DDP)?
A: No–`view` works the same way in DDP. It's a local tensor operation that just requires contiguity. In DDP:
Each process has its own replica of the model
`view` operates on the local tensor
If a tensor becomes non-contiguous (e.g., after `transpose`, `gather`, or custom ops), use `.contiguous().view(...)`
Caveat with DTensor: Newer PyTorch distributed APIs like `DTensor` may restrict view-like operations depending on the sharding strategy, but standard DDP tensors follow normal contiguity rules.
Q: When should you worry about contiguity in practice?
A: Be conscious of contiguity:
After `transpose()`, `permute()`, or non-trivial slicing
Before `view()`–it will throw a `RuntimeError` if the tensor is non-contiguous
When passing tensors to C++/CUDA kernels that assume a contiguous layout
In tight training loops–unnecessary `.contiguous()` copies waste memory and time
Best practice:
Use `.reshape(...)` instead of `.view(...)` if you're unsure (it handles contiguity automatically)
Check contiguity with `tensor.is_contiguous()` when debugging
Avoid calling `.contiguous()` unnecessarily–it creates a copy
28.2 Key Takeaways for Interviews
Critical Concepts to Master:
DDP: Multi-process data parallelism with All-Reduce
ZeRO/FSDP: Sharding optimizer, gradients, parameters to reduce memory
Tensor Parallelism: Split layers across GPUs (intra-layer)
Pipeline Parallelism: Split model into stages (inter-layer)
Mixed Precision: Use fp16/bf16 for compute, fp32 for updates
Gradient Checkpointing: Trade memory for recomputation
Flash Attention: Fused kernel to reduce attention memory
Modern LLM Training Stack:
Framework: PyTorch FSDP or DeepSpeed ZeRO
Precision: BF16 (A100/H100) or FP16 with loss scaling
Optimizer: AdamW with warmup + cosine decay
Memory: Gradient checkpointing + Flash Attention 2
Communication: NCCL (intra-node), InfiniBand (inter-node)
29 References & Further Reading
ZeRO (Rajbhandari et al., 2020): ZeRO: Memory Optimizations Toward Training Trillion Parameter Models
FSDP (Meta AI, 2021): PyTorch Fully Sharded Data Parallel
Megatron-LM (Shoeybi et al., 2019): Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism
GPipe (Huang et al., 2019): GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism
Flash Attention (Dao et al., 2022): FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
Mixed Precision (Micikevicius et al., 2018): Mixed Precision Training
DeepSpeed Documentation: https://www.deepspeed.ai/
PyTorch FSDP Tutorial: https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html