17  Chapter 16: Training Optimization & Parallelization Strategies

18 Introduction: The Scale Challenge

Modern large language models (LLMs) have grown from millions to hundreds of billions of parameters:

  • GPT-2 (2019): 1.5B params, trained on 40GB text

  • GPT-3 (2020): 175B params, requires 350GB+ in fp32

  • LLaMA-2 70B (2023): 70B params, 2T tokens

  • GPT-4 (2023): Estimated 1.7T params (MoE), multi-modal

Key Challenges:

  2. Memory: A 70B model requires \(\sim\)140GB just for fp16 weights, which exceeds any single GPU's memory

  2. Compute: Training on trillions of tokens takes months on thousands of GPUs

  3. Communication: Gradient synchronization becomes bottleneck at scale

  4. Efficiency: Memory overhead (optimizer states, activations) can be 10-20\(\times\) model size

This Document Covers:

  • Data Parallelism: DDP, gradient synchronization, FSDP, ZeRO stages

  • Model Parallelism: Tensor parallelism, pipeline parallelism, 3D parallelism

  • Memory Optimization: Mixed precision, gradient accumulation, activation checkpointing

  • Communication: All-reduce, reduce-scatter, gradient compression

  • Frameworks: PyTorch FSDP, DeepSpeed, Megatron-LM, Ray

19 Data Parallelism

19.1 Concept: Replicate Model, Split Data

Data Parallelism (DP) replicates the model on each GPU and splits the training batch across devices.

Forward Pass:

  1. Each GPU holds a full copy of model parameters \(\theta\)

  2. Global batch \(B\) split into \(N\) micro-batches: \(B = \{B_1, B_2, \ldots, B_N\}\)

  3. GPU \(i\) processes \(B_i\) independently → computes loss \(\mathcal{L}_i\) and gradients \(g_i\)

Backward Pass:

  1. Each GPU computes local gradients: \(g_i = \nabla_\theta \mathcal{L}_i(\theta)\)

  2. Synchronize gradients via All-Reduce: \(g = \frac{1}{N} \sum_{i=1}^N g_i\)

  3. Update parameters: \(\theta_{t+1} = \theta_t - \eta g\)
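The steps above can be simulated in a few lines of plain Python (a toy sketch with two "GPUs" and a one-parameter model; the All-Reduce is modeled as an explicit average):

```python
# Toy data-parallel SGD step: 2 "GPUs", scalar model y = w * x, loss = (w*x - t)^2 / 2.
# The gradient w.r.t. w for a sample (x, t) is (w*x - t) * x.

def local_grad(w, batch):
    # Mean gradient over the micro-batch held by one replica
    return sum((w * x - t) * x for x, t in batch) / len(batch)

def all_reduce_mean(grads):
    # Every replica receives the same averaged gradient
    g = sum(grads) / len(grads)
    return [g] * len(grads)

w = [1.0, 1.0]                                 # identical replicas on GPU 0 and GPU 1
micro_batches = [[(1.0, 2.0)], [(2.0, 2.0)]]   # global batch split in two
lr = 0.1

grads = [local_grad(w[i], micro_batches[i]) for i in range(2)]
synced = all_reduce_mean(grads)
w = [wi - lr * gi for wi, gi in zip(w, synced)]

assert w[0] == w[1]  # replicas remain identical after the update
```

Because every replica applies the same averaged gradient, the model copies never drift apart.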

Note

Key Insight: All-Reduce ensures every GPU has the same averaged gradient before updating. This keeps model replicas in sync.

Note: This is data parallelism: each GPU has a full copy of the model and processes different data. For models too large to fit on one GPU, use model parallelism (tensor/pipeline parallelism, covered later), where the model itself is split across GPUs.

19.2 DataParallel (DP) vs DistributedDataParallel (DDP)

DataParallel (DP):

  • Single-process, multi-threaded (Python GIL bottleneck)

  • Master GPU gathers gradients from all GPUs → broadcasts updated params

  • Memory imbalance: master GPU stores full batch + optimizer state

  • Slow: Communication is sequential; limited by GIL

DistributedDataParallel (DDP):

  • Multi-process: one process per GPU (avoids GIL)

  • No master GPU: All-Reduce distributes gradients evenly

  • Uses NCCL (NVIDIA Collective Communications Library) for efficient gradient sync

  • Faster: Up to 2-3\(\times\) speedup over DP

TipExample

PyTorch DDP Setup:

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Initialize process group: one process per GPU, launched via torchrun
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = MyModel().to(local_rank)
model = DDP(model, device_ids=[local_rank])
optimizer = torch.optim.AdamW(model.parameters())

# Training loop (the dataloader should use a DistributedSampler so each
# rank sees a different shard of the data)
for batch in dataloader:
    optimizer.zero_grad()
    loss = model(batch)
    loss.backward()  # Gradients auto-synced via All-Reduce
    optimizer.step()

19.3 All-Reduce Algorithm

All-Reduce is a collective communication operation where each GPU contributes local gradients \(g_i\) and receives the global average \(g = \frac{1}{N} \sum_{i=1}^N g_i\).

Ring All-Reduce (Bandwidth-Optimal):

  1. Arrange \(N\) GPUs in a logical ring

  2. Reduce-Scatter Phase: Each GPU sends chunks to neighbors; gradually accumulate sums

  3. All-Gather Phase: Each GPU sends accumulated chunks to neighbors; reconstruct full result

Complexity Analysis:

Consider \(N\) GPUs, each with gradient data of size \(M\) bytes.

Ring All-Reduce Mechanics:

  • Split each GPU’s data into \(N\) chunks of size \(M/N\)

  • Reduce-Scatter: \(N-1\) communication rounds → each GPU sends \((N-1) \cdot \frac{M}{N}\) bytes

  • All-Gather: \(N-1\) communication rounds → each GPU sends \((N-1) \cdot \frac{M}{N}\) bytes

  • Total data sent per GPU: \(2(N-1) \cdot \frac{M}{N} = \frac{2(N-1)M}{N}\)

Time Complexity:

  • Communication time: \(\frac{2(N-1)M}{N \cdot \text{bandwidth}} \approx \frac{2M}{\text{bandwidth}}\) for large \(N\)

  • As \(N \to \infty\), \(\frac{N-1}{N} \to 1\) → time approaches \(\frac{2M}{\text{bandwidth}}\)

  • Key insight: Per-GPU communication cost is (asymptotically) independent of \(N\), so ring All-Reduce scales gracefully with GPU count

Practical Implications:

  • Bandwidth-optimal: Each GPU sends/receives \(\approx 2M\) total (minimal theoretical limit)

  • Latency: \(2(N-1)\) sequential message steps → can hurt small models where latency dominates

  • Example: 8 GPUs, 1GB gradients → each GPU sends 1.75GB, takes \(\approx\) 14 message hops
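The two phases and the byte count above can be checked with a small pure-Python simulation (illustrative only; one element per chunk, so \(M = N\) here):

```python
def ring_all_reduce(data):
    """Simulate ring All-Reduce; return final chunks and elements sent per GPU."""
    n = len(data)
    chunks = [list(d) for d in data]   # one element per chunk (M == N here)
    sent = [0] * n

    # Phase 1: reduce-scatter. After n-1 rounds, GPU i owns the fully
    # reduced chunk (i + 1) % n.
    for step in range(n - 1):
        msgs = [(i, (i + 1) % n, (i - step) % n) for i in range(n)]
        payload = [chunks[i][c] for i, _, c in msgs]        # snapshot sends
        for (i, dst, c), val in zip(msgs, payload):
            chunks[dst][c] += val
            sent[i] += 1

    # Phase 2: all-gather. Circulate the reduced chunks until every GPU
    # holds all of them.
    for step in range(n - 1):
        msgs = [(i, (i + 1) % n, (i + 1 - step) % n) for i in range(n)]
        payload = [chunks[i][c] for i, _, c in msgs]
        for (i, dst, c), val in zip(msgs, payload):
            chunks[dst][c] = val
            sent[i] += 1

    return chunks, sent

grads = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
result, sent = ring_all_reduce(grads)

assert all(chunk == [28, 32, 36, 40] for chunk in result)  # every GPU has the sum
assert sent == [6, 6, 6, 6]  # 2*(N-1)*M/N = 2*3*4/4 = 6 elements per GPU
```

With \(M\) elements split into \(N\) chunks, each GPU sends \(2(N-1) \cdot M/N\) elements, matching the analysis above.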

Note

Modern Practice: NCCL implements optimized All-Reduce using tree-based algorithms for small \(N\) (lower latency) and ring-based for large \(N\) (bandwidth-optimal).

19.4 Limitations of Naive Data Parallelism

  • Memory Replication: Each GPU stores full model (weights, optimizer states, gradients)

  • Example: 70B model in fp16 → 140GB \(\times\) 8 GPUs = 1.1TB total memory used

  • Inefficient for Large Models: Cannot train models larger than single GPU memory

  • Communication Overhead: All-Reduce latency grows with model size

Solution: ZeRO (Zero Redundancy Optimizer) and FSDP (Fully Sharded Data Parallel)

20 ZeRO: Zero Redundancy Optimizer

20.1 Motivation: Eliminate Memory Redundancy

In standard data parallelism, memory breakdown per GPU for a model with \(P\) parameters:

  • Model weights: \(2P\) bytes (fp16)

  • Gradients: \(2P\) bytes (fp16)

  • Optimizer states (Adam): \(12P\) bytes = fp32 master weights (\(4P\)) + fp32 momentum (\(4P\)) + fp32 variance (\(4P\))

  • Total: \(16P\) bytes per GPU → fully redundant across GPUs

Example: 7B model → \(7 \times 10^9 \times 16 = 112\) GB per GPU (before activations!)

ZeRO (DeepSpeed) removes this redundancy by sharding optimizer states, gradients, and parameters across GPUs.

20.2 ZeRO Stages

ZeRO-1: Optimizer State Partitioning

  • Shard optimizer states (momentum, variance) across \(N\) GPUs

  • Each GPU stores \(\frac{1}{N}\) of optimizer states → updates only its assigned params

  • Still does full All-Reduce: All GPUs compute and store full gradients (\(2P\) each)

  • After optimizer step: All-Gather updated params to sync models

  • Memory Reduction: Optimizer states from \(12P\) to \(\frac{12P}{N}\) (gradients still \(2P\))

ZeRO-2: Gradient Partitioning (builds on ZeRO-1)

  • Replaces All-Reduce with Reduce-Scatter: Each GPU gets different shard of averaged gradients

  • After backward: GPU \(i\) receives averaged gradients only for parameters it’s responsible for

  • Each GPU stores \(\frac{1}{N}\) of gradients (matching its optimizer state partition)

  • Updates only its assigned parameter shard using its local gradient shard

  • Memory Reduction: Gradients from \(2P\) to \(\frac{2P}{N}\) + Optimizer \(\frac{12P}{N}\)

TipExample

Concrete Example: 2 GPUs, 4 parameters \((W_1, W_2, W_3, W_4)\)

ZeRO-1 Workflow:

  1. Backward pass: Both GPUs compute local gradients \((g_1, g_2, g_3, g_4)\)

  2. All-Reduce: Average gradients across GPUs

    • GPU-1 receives: \((G_1, G_2, G_3, G_4)\) (full averaged gradients)

    • GPU-2 receives: \((G_1, G_2, G_3, G_4)\) (full averaged gradients)

  3. Optimizer step:

    • GPU-1 has optimizer states for \((W_1, W_2)\) → updates \((W_1, W_2)\) using \((G_1, G_2)\)

    • GPU-2 has optimizer states for \((W_3, W_4)\) → updates \((W_3, W_4)\) using \((G_3, G_4)\)

  4. All-Gather: Broadcast updated parameters

    • GPU-1 sends \((W_1', W_2')\) to GPU-2

    • GPU-2 sends \((W_3', W_4')\) to GPU-1

    • Both now have full model: \((W_1', W_2', W_3', W_4')\)

  5. Memory: Both GPUs store all gradients (\(2P\)), but only half the optimizer states (\(6P\) each)

ZeRO-2 Workflow:

  1. Backward pass: Both GPUs compute local gradients \((g_1, g_2, g_3, g_4)\)

  2. Reduce-Scatter: Average and distribute gradient shards

    • GPU-1 receives: \((G_1, G_2)\) only (averaged, for its assigned params)

    • GPU-2 receives: \((G_3, G_4)\) only (averaged, for its assigned params)

  3. Optimizer step:

    • GPU-1 updates \((W_1, W_2)\) using \((G_1, G_2)\)

    • GPU-2 updates \((W_3, W_4)\) using \((G_3, G_4)\)

  4. All-Gather: Broadcast updated parameters (same as ZeRO-1)

  5. Memory: Each GPU stores half gradients (\(P\)) + half optimizer states (\(6P\) each)

Key Insight: In ZeRO-1, you store all gradients but can discard them after optimizer step. ZeRO-2 never stores the full gradient array, saving memory immediately.
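The ZeRO-2 workflow just described fits in a short pure-Python sketch (2 "GPUs", 4 parameters, plain SGD in place of Adam; reduce-scatter and all-gather are modeled as list slicing):

```python
# Toy ZeRO-2 step: 2 GPUs, 4 parameters. GPU 0 owns params 0-1, GPU 1 owns 2-3.
params = [1.0, 2.0, 3.0, 4.0]          # both replicas start identical
local_grads = [
    [0.4, 0.8, 1.2, 1.6],              # GPU 0's gradients from its micro-batch
    [0.6, 1.2, 1.8, 2.4],              # GPU 1's gradients from its micro-batch
]
lr = 0.1

# Reduce-scatter: each GPU receives only the averaged shard it owns.
avg = [(a + b) / 2 for a, b in zip(*local_grads)]   # = [0.5, 1.0, 1.5, 2.0]
shards = [avg[0:2], avg[2:4]]          # GPU 0 keeps avg[0:2], GPU 1 keeps avg[2:4]

# Local optimizer step on the owned shard only.
updated = [
    [p - lr * g for p, g in zip(params[0:2], shards[0])],   # GPU 0 updates W1, W2
    [p - lr * g for p, g in zip(params[2:4], shards[1])],   # GPU 1 updates W3, W4
]

# All-gather: both GPUs reconstruct the full updated parameter vector.
full = updated[0] + updated[1]
expected = [0.95, 1.9, 2.85, 3.8]
assert all(abs(a - b) < 1e-12 for a, b in zip(full, expected))
```

Note that no GPU ever holds the full averaged gradient vector, only its own shard.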

ZeRO-3: Parameter Partitioning

  • Shard model weights: each GPU stores \(\frac{1}{N}\) of parameters

  • Forward/backward: All-Gather needed params on-the-fly, discard after use

  • Memory Reduction: Parameters from \(2P\) to \(\frac{2P}{N}\)

  • Trade-off: Increased communication (All-Gather per layer)

Memory Comparison (7B model, 8 GPUs, fp16):

Method        Weights    Gradients   Optimizer
Baseline DP   14 GB      14 GB       84 GB
ZeRO-1        14 GB      14 GB       10.5 GB
ZeRO-2        14 GB      1.75 GB     10.5 GB
ZeRO-3        1.75 GB    1.75 GB     10.5 GB
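A small helper function (an illustrative sketch following the same \(2P/2P/12P\) accounting, with 1 GB = \(10^9\) bytes; the name zero_memory_gb is invented here) reproduces this table:

```python
def zero_memory_gb(P, N, stage):
    """Per-GPU memory (GB) for fp16 weights/grads plus fp32 Adam states under ZeRO.

    P: parameter count, N: data-parallel GPUs, stage: 0 (plain DP), 1, 2, or 3.
    Ignores activations and framework overhead.
    """
    weights = 2 * P / (N if stage >= 3 else 1)   # sharded only at stage 3
    grads = 2 * P / (N if stage >= 2 else 1)     # sharded at stages 2-3
    optim = 12 * P / (N if stage >= 1 else 1)    # sharded at stages 1-3
    gb = 1e9
    return weights / gb, grads / gb, optim / gb

# 7B parameters, 8 GPUs: matches the table row by row
assert zero_memory_gb(7e9, 8, 0) == (14.0, 14.0, 84.0)   # baseline DP
assert zero_memory_gb(7e9, 8, 1) == (14.0, 14.0, 10.5)   # ZeRO-1
assert zero_memory_gb(7e9, 8, 2) == (14.0, 1.75, 10.5)   # ZeRO-2
assert zero_memory_gb(7e9, 8, 3) == (1.75, 1.75, 10.5)   # ZeRO-3
```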
Note

ZeRO-3 Trade-off: Reduces memory by \(8\times\) but requires All-Gather for every layer. Use ZeRO-2 if communication bandwidth is limited; ZeRO-3 if memory-constrained.

20.3 ZeRO-Offload: CPU Memory Extension

ZeRO-Offload moves optimizer states to CPU memory, reducing GPU memory further:

  • Forward/backward on GPU (fast)

  • Gradients transferred to CPU → optimizer step on CPU → updated params back to GPU

  • Use Case: Training 10B+ models on consumer GPUs (e.g., RTX 3090 24GB)

TipExample

DeepSpeed ZeRO-3 Config:

{
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {"device": "cpu"},
    "offload_param": {"device": "cpu"},
    "overlap_comm": true,
    "contiguous_gradients": true,
    "sub_group_size": 1e9,
    "reduce_bucket_size": 5e8,
    "stage3_prefetch_bucket_size": 5e8,
    "stage3_param_persistence_threshold": 1e6
  }
}

21 Fully Sharded Data Parallel (FSDP)

21.1 PyTorch Native Alternative to ZeRO

FSDP (PyTorch 1.11+) is PyTorch’s native implementation of ZeRO-3 principles:

  • Shards parameters, gradients, and optimizer states across GPUs

  • Automatically manages All-Gather/Reduce-Scatter during forward/backward

  • Integrated with PyTorch APIs (no external dependency like DeepSpeed)

21.2 FSDP Workflow

Forward Pass:

  1. GPU \(i\) holds shard \(\theta_i\) (e.g., layers 0-3 of 32-layer model)

  2. Before processing layer \(k\): All-Gather \(\theta_k\) from all GPUs

  3. Compute activations \(a_k = f_k(a_{k-1}, \theta_k)\)

  4. Discard \(\theta_k\) (free memory)

Backward Pass:

  1. Re-All-Gather \(\theta_k\) to compute gradients \(g_k\)

  2. Reduce-Scatter \(g_k\) → each GPU gets averaged shard

  3. Discard \(\theta_k\) again

Optimizer Step:

  • Each GPU updates its shard \(\theta_i\) using local optimizer state

TipExample

PyTorch FSDP Usage:

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import CPUOffload, MixedPrecision, ShardingStrategy

model = MyModel()
model = FSDP(
    model,
    mixed_precision=MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float32,
        buffer_dtype=torch.float16,
    ),
    cpu_offload=CPUOffload(offload_params=True),
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # ZeRO-3 equivalent
)
# Construct the optimizer AFTER wrapping, so it sees the sharded parameters
optimizer = torch.optim.AdamW(model.parameters())

# Training loop (identical to DDP)
for batch in dataloader:
    optimizer.zero_grad()
    loss = model(batch)
    loss.backward()
    optimizer.step()

21.3 Sharding Strategies

FSDP offers multiple sharding modes:

  • FULL_SHARD: ZeRO-3 (shard params, grads, optimizer)

  • SHARD_GRAD_OP: ZeRO-2 (shard grads + optimizer only)

  • NO_SHARD: Equivalent to DDP (no sharding)

  • HYBRID_SHARD: Shard within node, replicate across nodes (minimize inter-node communication)

Note

When to Use FSDP vs DeepSpeed:

  • FSDP: PyTorch-native, easier debugging, better integration with PyTorch 2.x features

  • DeepSpeed: More mature, additional features (ZeRO-Offload, pipeline parallelism, sparse attention)

22 Model Parallelism

22.1 When Data Parallelism Fails

Problem: Model too large to fit on a single GPU, even with ZeRO-3.

  • Example: 175B GPT-3 → \(\sim\)350GB in fp16

  • ZeRO-3 with 64 GPUs → \(\frac{350}{64} \approx 5.5\) GB per GPU (just for weights)

  • Activations during forward pass can be \(\sim\)10-50GB per GPU → OOM

Solution: Model Parallelism, which splits the model itself across GPUs.

Note

ZeRO-3 vs Tensor Parallelism: Both Shard Weights, But Differently!

ZeRO-3 (Data Parallelism with Weight Sharding):

  • Each GPU stores \(\frac{1}{N}\) of weights (sharded for memory)

  • During forward: All-Gather to reconstruct full weights temporarily

  • All GPUs compute the same operations on different data batches

  • Still data parallelism–just memory-efficient

Tensor Parallelism (Model Parallelism):

  • Each GPU stores \(\frac{1}{N}\) of weights (sharded for computation)

  • Each GPU computes different parts of the layer on the same data

  • Example: GPU-0 computes first half of MLP output, GPU-1 computes second half

  • Splits both weights and computation

Key Difference: ZeRO-3 shards weights for memory but reconstructs them for computation. Tensor parallelism shards weights and splits the computation itself.

22.2 Tensor Parallelism (Intra-Layer)

Tensor Parallelism splits individual layers across GPUs. Most common for Transformer MLP and attention.

MLP Layer Splitting:

A standard Transformer MLP: \(y = \text{GELU}(xW_1)W_2\) where \(W_1 \in \mathbb{R}^{d \times 4d}\), \(W_2 \in \mathbb{R}^{4d \times d}\).

Column-wise Split of \(W_1\):

  • Partition \(W_1 = [W_1^{(1)} | W_1^{(2)}]\) across 2 GPUs

  • Each GPU computes: \(h^{(i)} = \text{GELU}(x W_1^{(i)})\) independently (no communication)

  • Concatenate: \(h = [h^{(1)} | h^{(2)}]\)

Row-wise Split of \(W_2\):

  • Partition \(W_2 = \begin{bmatrix} W_2^{(1)} \\ W_2^{(2)} \end{bmatrix}\)

  • Each GPU computes partial output: \(y^{(i)} = h^{(i)} W_2^{(i)}\)

  • All-Reduce: Sum \(y = \sum_i y^{(i)}\)

Communication: One All-Reduce per layer (cheap for large \(d\)).
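The split can be verified numerically on tiny matrices (a pure-Python sketch with exact GELU via the error function; shapes shrunk to \(2 \times 4\) and \(4 \times 1\) for readability):

```python
import math

# Verify the Megatron-style MLP split: y = GELU(x W1) W2, with W1 split
# column-wise and W2 split row-wise across 2 "GPUs".

def gelu(v):
    return [0.5 * x * (1 + math.erf(x / math.sqrt(2))) for x in v]

def matvec(x, W):   # x: 1 x n, W: n x m  ->  1 x m
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]

x = [1.0, -2.0]
W1 = [[0.5, -1.0, 2.0, 0.1],
      [1.5,  0.3, -0.2, 0.4]]          # "d x 4d" (2 x 4 here)
W2 = [[0.2], [-0.5], [1.0], [0.3]]     # "4d x d" (4 x 1 here)

# Unsplit reference
y_ref = matvec(gelu(matvec(x, W1)), W2)

# GPU 0: columns 0-1 of W1, rows 0-1 of W2; GPU 1: columns 2-3 / rows 2-3.
W1_a = [row[:2] for row in W1]; W1_b = [row[2:] for row in W1]
W2_a = W2[:2];                  W2_b = W2[2:]

h_a = gelu(matvec(x, W1_a))    # computed independently, no communication
h_b = gelu(matvec(x, W1_b))
y_a = matvec(h_a, W2_a)        # partial outputs...
y_b = matvec(h_b, W2_b)
y_tp = [a + b for a, b in zip(y_a, y_b)]   # ...summed by one All-Reduce

assert all(abs(r - t) < 1e-12 for r, t in zip(y_ref, y_tp))
```

Because GELU is applied elementwise, the column split of \(W_1\) needs no communication; only the row-split \(W_2\) output requires the final sum.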

Note

Megatron-LM (NVIDIA) pioneered tensor parallelism for GPT-style models. Used to train GPT-3, MT-NLG (530B params).

22.3 Pipeline Parallelism (Inter-Layer)

Pipeline Parallelism splits the model into stages (layer groups) assigned to different GPUs.

Naive Pipelining (Sequential):

  1. GPU 0: Layers 1-8

  2. GPU 1: Layers 9-16

  3. GPU 2: Layers 17-24

  4. GPU 3: Layers 25-32

Problem: GPU 1 waits for GPU 0 → only 25% utilization (pipeline bubble).

GPipe (Micro-batching):

  • Split global batch into \(M\) micro-batches

  • Interleave forward passes: GPU 0 processes micro-batch 1, then GPU 1 starts while GPU 0 processes micro-batch 2

  • Backward passes flow in reverse: GPU 3 → GPU 2 → GPU 1 → GPU 0

  • Reduces the bubble fraction to \(\frac{N-1}{M+N-1} \approx \frac{N-1}{M}\) for \(M \gg N\), where \(N\) = # pipeline stages, \(M\) = # micro-batches
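Under the idealized assumption that every stage spends one time unit per micro-batch, the exact bubble fraction is \((N-1)/(M+N-1)\); the often-quoted \((N-1)/M\) is its large-\(M\) approximation. A small helper (invented for this sketch) makes the numbers concrete:

```python
def bubble_fraction(n_stages, n_microbatches):
    """Idle fraction of an ideal GPipe schedule: (N - 1) / (M + N - 1).

    Each micro-batch takes one time unit per stage, so the pipeline needs
    M + N - 1 units in total, of which N - 1 are ramp-up/ramp-down bubble.
    """
    return (n_stages - 1) / (n_microbatches + n_stages - 1)

# Naive pipelining (M = 1) on 4 stages: 75% idle, i.e. the 25% utilization above
assert bubble_fraction(4, 1) == 0.75

# 8 micro-batches shrink the bubble to 3/11, about 27%
assert abs(bubble_fraction(4, 8) - 3 / 11) < 1e-12
```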

PipeDream (Asynchronous):

  • Different micro-batches at different pipeline stages simultaneously

  • 1F1B Schedule: Each stage alternates forward and backward passes (1 Forward, 1 Backward)

  • Maintains multiple versions of weights to avoid staleness

Note

Pipeline Parallelism Trade-offs:

  • + No all-to-all communication (only point-to-point between stages)

  • - Pipeline bubbles reduce GPU utilization (typically to \(\sim\)80-90%)

  • - Requires careful micro-batch size tuning

22.4 3D Parallelism: Combining All Strategies

Modern LLM training uses 3D Parallelism: Data + Tensor + Pipeline.

Example: 175B GPT-3 on 1024 GPUs (Megatron-LM):

  • Data Parallelism: 16-way (16 replicas of the full pipeline)

  • Tensor Parallelism: 8-way (each layer split across 8 GPUs)

  • Pipeline Parallelism: 8 stages (8 layer groups)

  • Total: \(16 \times 8 \times 8 = 1024\) GPUs

Communication Hierarchy:

  • Tensor parallelism: High-bandwidth NVLink within node (fast All-Reduce)

  • Pipeline parallelism: Point-to-point across nodes (moderate bandwidth)

  • Data parallelism: All-Reduce across replicas (infrequent, amortized)

23 Mixed Precision Training

23.1 Motivation: Speed and Memory

Standard Training: fp32 (32-bit floating point) for weights, activations, gradients.

  • Slow: Tensor cores (A100, H100) are optimized for fp16/bf16

  • Memory-hungry: fp32 needs 2\(\times\) the memory of fp16

Mixed Precision Training uses fp16 for forward/backward, fp32 for optimizer updates.

23.2 FP16 vs BF16

FP16 (IEEE Half-Precision):

  • 1 sign bit, 5 exponent bits, 10 mantissa bits

  • Range: \(\pm 6.55 \times 10^4\) (limited dynamic range)

  • Risk: Gradient underflow (small gradients → zero)

BF16 (Brain Float 16):

  • 1 sign bit, 8 exponent bits (same as fp32!), 7 mantissa bits

  • Range: \(\pm 3.4 \times 10^{38}\) (same as fp32)

  • Robust: No loss scaling needed

  • Modern Default: A100, H100, TPUs support bf16 natively

23.3 Mixed Precision Workflow

AMP (Automatic Mixed Precision):

  1. Maintain master copy of weights in fp32

  2. Cast weights to fp16/bf16 for forward pass

  3. Compute loss and gradients in fp16/bf16

  4. (FP16 only) Scale loss to prevent gradient underflow: \(\mathcal{L}' = S \cdot \mathcal{L}\) (typical \(S = 2^{16}\))

  5. Convert gradients back to fp32, unscale (fp16 only)

  6. Update fp32 master weights

  7. Cast updated weights back to fp16/bf16
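The underflow problem behind step 4 can be demonstrated with only the standard library (fp16 via struct's half-precision format; bf16 emulated by truncating a float32 to its top 16 bits, a simplification of real round-to-nearest bf16 hardware):

```python
import struct

def fp16(x):
    """Round a float to IEEE half precision (what fp16 storage does)."""
    return struct.unpack('>e', struct.pack('>e', x))[0]

def bf16(x):
    """Emulate bfloat16 by truncating a float32 to its top 16 bits."""
    bits = struct.unpack('>I', struct.pack('>f', x))[0]
    return struct.unpack('>f', struct.pack('>I', bits & 0xFFFF0000))[0]

g = 1e-8                       # a small gradient value

assert fp16(g) == 0.0          # underflows: fp16 cannot represent ~1e-8
assert bf16(g) != 0.0          # survives: bf16 keeps fp32's exponent range

scale = 2.0 ** 16              # loss scaling, as in step 4
assert fp16(g * scale) != 0.0  # the scaled gradient is representable in fp16
# After backward, unscale in fp32: fp16(g * scale) / scale recovers ~g
```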

Note

Memory Overhead of Mixed Precision:
Optimizer (Adam) maintains three fp32 copies: master weights (\(4P\)), momentum (\(4P\)), variance (\(4P\)) = \(12P\) bytes. The fp16 weights (\(2P\)) and gradients (\(2P\)) coexist separately. Total: \(16P\) bytes.

TipExample

PyTorch AMP:

from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()  # Loss scaling for fp16

for batch in dataloader:
    optimizer.zero_grad()
    
    with autocast():  # fp16 forward/backward
        loss = model(batch)
    
    scaler.scale(loss).backward()  # Scale loss, compute gradients
    scaler.step(optimizer)          # Unscale, update fp32 weights
    scaler.update()                 # Adjust scaling factor

Speedup: typically 2-3\(\times\) end-to-end on A100/H100; matmul-heavy kernels can see 4-5\(\times\) from Tensor Cores.

24 Memory Optimization Techniques

24.1 Activation Checkpointing (Gradient Checkpointing)

Problem: Forward pass stores activations for backward pass → memory scales with sequence length and depth.

Example: GPT-3 (96 layers, batch=1, seq_len=2048):

  • Activations: \(\sim\)10GB per sample

  • Batch size 32 → 320GB (exceeds A100 80GB)

Activation Checkpointing (also called Gradient Checkpointing):

  1. Forward pass: Store activations only at checkpoints (e.g., every \(k\) layers)

  2. Backward pass: Recompute intermediate activations on-the-fly from checkpoints

  3. Trade-off: \(\sqrt{N}\) checkpoints → \(O(\sqrt{N})\) activation memory; forward compute doubles (1 forward + 1 recompute), \(\approx\)33% extra step time since backward costs roughly 2\(\times\) forward

Memory Reduction: \(10\times\) typical (e.g., 320GB → 32GB).

TipExample

PyTorch Gradient Checkpointing:

from torch.utils.checkpoint import checkpoint

def forward(x):
    # use_reentrant=False is the recommended mode in recent PyTorch
    x = checkpoint(layer1, x, use_reentrant=False)  # Don't store layer1 activations
    x = checkpoint(layer2, x, use_reentrant=False)
    return x

Hugging Face Transformers:

model = AutoModelForCausalLM.from_pretrained(
    "gpt2",
    use_cache=False,  # Required for gradient checkpointing
)
model.gradient_checkpointing_enable()

24.2 Gradient Accumulation

Gradient Accumulation simulates large batch sizes on limited memory:

  1. Run \(K\) micro-batches (batch size \(B / K\))

  2. Accumulate gradients: \(g = \sum_{i=1}^K g_i\)

  3. Update weights once: \(\theta_{t+1} = \theta_t - \eta g\)

Effect: Effective batch size \(B\) with memory for \(B / K\).

Trade-off: Each optimizer step now runs \(K\) sequential forward/backward passes, so steps take \(K\times\) longer; per-sample throughput is roughly unchanged, but you forgo the speedup a true batch-\(B\) step on larger hardware would give.

TipExample

PyTorch Gradient Accumulation:

accumulation_steps = 4

for i, batch in enumerate(dataloader):
    loss = model(batch) / accumulation_steps  # Scale loss
    loss.backward()  # Accumulate gradients
    
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()       # Update weights
        optimizer.zero_grad()  # Clear gradients
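A quick numerical check (toy one-parameter least-squares model) confirms that accumulating \(K\) scaled micro-batch gradients reproduces the full-batch gradient, provided the micro-batches are equally sized (batch-statistics layers like BatchNorm break this exact equivalence):

```python
# Check that K accumulated micro-batch gradients (with loss scaled by 1/K)
# equal the gradient of one full batch, for loss = mean((w*x - t)^2 / 2).

def grad(w, batch):
    return sum((w * x - t) * x for x, t in batch) / len(batch)

w = 0.5
full_batch = [(1.0, 1.0), (2.0, 1.0), (3.0, 2.0), (4.0, 3.0)]
K = 2
micro = [full_batch[0:2], full_batch[2:4]]   # two micro-batches of equal size

acc = 0.0
for mb in micro:
    acc += grad(w, mb) / K     # the `loss / accumulation_steps` scaling above

assert abs(acc - grad(w, full_batch)) < 1e-12
```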

24.3 Flash Attention

Flash Attention (Dao et al., 2022) optimizes self-attention memory via kernel fusion and tiling:

  • Standard attention: Materializes \(QK^T\) matrix (\(O(N^2)\) memory for seq len \(N\))

  • Flash Attention: Computes attention in blocks, never stores full \(QK^T\)

  • Memory: \(O(N)\) instead of \(O(N^2)\)

  • Speed: 2-4\(\times\) faster (reduces HBM read/writes)

Adoption: Default in Hugging Face Transformers 4.36+, PyTorch 2.0+.

TipExample
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="flash_attention_2",
)

25 Communication Optimization

25.1 Gradient Compression

Motivation: All-Reduce bandwidth becomes the bottleneck, especially for multi-node training.

Gradient Compression reduces communication volume:

  • Top-k Sparsification: Send only top \(k\%\) largest gradients (1-10%)

  • Quantization: Reduce precision (e.g., fp16 → int8)

  • Error Feedback: Accumulate residual errors to avoid bias

Trade-off: 10-100\(\times\) compression with minimal accuracy loss (<0.5% typically).
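Top-k sparsification with error feedback fits in a few lines (an illustrative sketch; compress_topk and its signature are invented for this example):

```python
# Toy top-k gradient sparsification with error feedback: send only the k
# largest-magnitude entries; keep the rest in a local residual that is added
# back before the next step, so nothing is silently dropped forever.

def compress_topk(grad, residual, k):
    corrected = [g + r for g, r in zip(grad, residual)]
    # indices of the k largest-magnitude corrected entries
    top = sorted(range(len(corrected)), key=lambda i: abs(corrected[i]),
                 reverse=True)[:k]
    sent = [corrected[i] if i in top else 0.0 for i in range(len(corrected))]
    new_residual = [c - s for c, s in zip(corrected, sent)]
    return sent, new_residual

grad = [0.9, -0.05, 0.02, -1.2]
residual = [0.0] * 4
sent, residual = compress_topk(grad, residual, k=2)

assert sent == [0.9, 0.0, 0.0, -1.2]        # only the two largest entries travel
assert residual == [0.0, -0.05, 0.02, 0.0]  # the rest waits in the residual
```

Only the sparse `sent` vector is communicated; the residual stays local and biases future steps toward the dropped coordinates.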

25.2 Overlap Communication with Computation

DDP Optimization: Start All-Reduce for layer \(i\) gradients while computing layer \(i+1\).

ZeRO-3 Prefetching: Pre-fetch parameters for layer \(i+1\) while computing layer \(i\).

Effect: Hides communication latency, which is critical for multi-node training.

Note

Modern Frameworks Automate This:

  • PyTorch DDP: Automatic bucketing + overlap

  • DeepSpeed: overlap_comm=true

  • FSDP: forward_prefetch=True

25.3 Network Topology Awareness

Intra-node vs Inter-node Communication:

  • NVLink (intra-node): 900 GB/s (H100) / 600 GB/s (A100), extremely low latency

  • InfiniBand/RoCE (inter-node): 200-400 Gb/s, higher latency

Optimization Strategies:

  • Place tensor parallelism within node (maximize NVLink usage)

  • Pipeline parallelism across nodes (minimize inter-node traffic)

  • Data parallelism for outer replication (amortize All-Reduce)

26 Framework Comparison

26.1 DeepSpeed

Strengths:

  • ZeRO-Offload: Train 10B+ models on single GPU

  • Pipeline Parallelism: Built-in support

  • Sparse Attention: Memory-efficient attention variants

  • Used by Microsoft, BigScience (BLOOM 176B), StabilityAI

Weaknesses:

  • External Dependency: Not native PyTorch

  • Debugging: Harder to debug than FSDP

26.2 PyTorch FSDP

Strengths:

  • Native PyTorch: Better integration with PyTorch 2.x (torch.compile)

  • Simpler API: Easier to adopt for existing codebases

  • Used by Meta (LLaMA-2, LLaMA-3)

Weaknesses:

  • Limited CPU Offload: CPUOffload exists (see the example above), but is less capable than DeepSpeed’s ZeRO-Offload

  • No Pipeline Parallelism: FSDP itself provides only sharded data parallelism; pipeline stages require separate tooling

26.3 Megatron-LM

Strengths:

  • Tensor Parallelism: Best-in-class implementation

  • 3D Parallelism: Combines data + tensor + pipeline

  • Used by NVIDIA (GPT-3, MT-NLG 530B, Megatron-Turing)

Weaknesses:

  • Complex Setup: Requires deep understanding of parallelism

  • Less Flexible: Tightly coupled to specific model architectures

26.4 When to Use What?

Scenario                               Recommendation
Single GPU, <10B model                 Standard training (no parallelism)
Multi-GPU, single node                 DDP or FSDP (SHARD_GRAD_OP)
Large model (10-70B), multi-GPU        FSDP (FULL_SHARD) or DeepSpeed ZeRO-2
Very large model (70B+), multi-node    FSDP + tensor parallelism or DeepSpeed ZeRO-3
Extreme scale (100B+)                  Megatron-LM (3D parallelism)
Limited GPU memory                     DeepSpeed ZeRO-Offload

27 Practical Training Recipes

27.1 Recipe 1: Fine-Tuning 7B Model (LLaMA-2) on 4\(\times\)A100

Setup:

  • Model: LLaMA-2-7B (7B params, fp16 → 14GB)

  • Batch size: 32 (8 per GPU)

  • Seq length: 2048

Strategy:

  • FSDP: SHARD_GRAD_OP (ZeRO-2)

  • Mixed precision: bf16 (no loss scaling)

  • Gradient checkpointing: Enabled (saves \(\sim\)40GB activations)

  • Flash Attention 2: Reduces memory by 2\(\times\)

TipExample
# Launch with torchrun
torchrun --nproc_per_node=4 train.py \
  --model_name meta-llama/Llama-2-7b-hf \
  --fsdp "shard_grad_op" \
  --bf16 true \
  --gradient_checkpointing true \
  --attn_implementation flash_attention_2

27.2 Recipe 2: Pre-Training 70B Model on 64\(\times\)H100

Setup:

  • Model: 70B params (GPT-3 scale)

  • Batch size: 2048 (32 per GPU)

  • Seq length: 4096

Strategy:

  • Data Parallelism: 8-way (8 replicas)

  • FSDP: FULL_SHARD (ZeRO-3) within each replica

  • Tensor Parallelism: 8-way (split layers across 8 GPUs)

  • Total: \(8 \times 8 = 64\) GPUs

  • Mixed precision: bf16

  • Gradient accumulation: 4 steps

  • Activation checkpointing: Every 4 layers

Expected Throughput: \(\sim\)2-3 tokens/sec/GPU → 130-200 tokens/sec total.

27.3 Recipe 3: Budget Training (10B Model on 2\(\times\)RTX 3090)

Setup:

  • Model: 10B params

  • GPU: RTX 3090 (24GB each)

  • Batch size: 4 (2 per GPU)

Strategy:

  • DeepSpeed ZeRO-3 + Offload: Offload optimizer states to CPU

  • Mixed precision: fp16 (bf16 not supported on RTX 3090)

  • Gradient checkpointing: Aggressive (every 2 layers)

  • Gradient accumulation: 8 steps (effective batch 32)

Expected Speed: Very slow (\(\sim\)0.1-0.2 tokens/sec/GPU); the CPU offload is the bottleneck.

28 Interview Questions & Key Concepts

28.1 Common Interview Questions

  1. Q: Explain the difference between DataParallel and DistributedDataParallel.

    A: DP is single-process and multi-threaded (Python GIL bottleneck), and uses a master GPU to aggregate gradients. DDP is multi-process (one per GPU) and uses All-Reduce for gradient sync, so there is no master GPU and it is much faster (2-3\(\times\)).

  2. Q: What is ZeRO and how does it differ from FSDP?

    A: ZeRO (DeepSpeed) shards optimizer states (Stage 1), gradients (Stage 2), and parameters (Stage 3) to reduce memory redundancy. FSDP is PyTorch’s native implementation of ZeRO-3 principles, with tighter integration but fewer features (e.g., more limited CPU offload than DeepSpeed’s ZeRO-Offload).

  3. Q: When would you use tensor parallelism vs pipeline parallelism?

    A: Tensor parallelism splits individual layers across GPUs (intra-layer) and needs high bandwidth, so place it within a node (NVLink). Pipeline parallelism splits the model into stages (inter-layer) and uses point-to-point communication, so it works across nodes. For very large models, combine both (3D parallelism).

  4. Q: Explain gradient checkpointing. What’s the trade-off?

    A: Stores activations only at checkpoints (e.g., every \(k\) layers) and recomputes intermediate activations during backward. Trade-off: reduces activation memory by \(\sim\)10\(\times\) while doubling forward compute (one forward pass + one recompute), roughly 30% extra step time.

  5. Q: Why use bf16 instead of fp16 for mixed precision?

    A: BF16 has same dynamic range as fp32 (8 exponent bits) → no gradient underflow, no loss scaling needed. FP16 has limited range (5 exponent bits) → requires loss scaling to avoid underflow. Modern GPUs (A100, H100) support bf16 natively.

  6. Q: How does Flash Attention reduce memory?

    A: Standard attention materializes the \(QK^T\) matrix (\(O(N^2)\) memory). Flash Attention uses kernel fusion and tiling to compute attention in blocks, never storing the full \(QK^T\); this reduces memory to \(O(N)\) and gives a 2-4\(\times\) speedup via fewer HBM reads/writes.

  7. Q: Explain the All-Reduce algorithm used in DDP.

    A: Ring All-Reduce arranges GPUs in a logical ring and runs two phases: (1) Reduce-Scatter, where each GPU forwards chunks and accumulates partial sums; (2) All-Gather, where the reduced chunks circulate until every GPU has the full result. Bandwidth-optimal: \(O(2M / \text{bandwidth})\), independent of \(N\).

  8. Q: What’s the memory breakdown for training a Transformer model?

    A: For \(P\) parameters:

    • Weights: \(2P\) (fp16)

    • Gradients: \(2P\) (fp16)

    • Optimizer (Adam): \(12P\) (fp32 momentum + variance)

    • Activations: Depends on batch size and sequence length, often \(>10P\)

  9. Q: How would you train a 175B model?

    A: Use 3D parallelism: (1) Tensor parallelism (8-way, within node), (2) Pipeline parallelism (8 stages, across nodes), (3) Data parallelism (16 replicas). Total: \(8 \times 8 \times 16 = 1024\) GPUs. Use ZeRO-1 or FSDP (SHARD_GRAD_OP) for data parallelism. Enable gradient checkpointing, bf16, Flash Attention.

  10. Q: What causes pipeline bubbles and how do you minimize them?

    A: Bubbles occur when pipeline stages wait for previous stages. Minimize by: (1) Increase micro-batch count (bubble fraction \(\sim (N-1)/M\)), (2) Use 1F1B schedule (interleave forward/backward), (3) Balance stage compute times (even layer distribution).

  11. Q: Explain tensor contiguity in PyTorch. Why does view sometimes fail?

    A: A tensor is contiguous when its data is laid out in one uninterrupted block of memory with the expected stride order (row-major by default). Operations like transpose(), step slicing (e.g., x[::2]), or certain views produce non-contiguous tensors: they reference the same underlying data but with different strides.

    view() only works on contiguous tensors because it just changes the shape metadata without moving data. If a tensor isn’t contiguous, you must:

    • Call .contiguous().view(...) to create a contiguous copy first

    • Use .reshape(...) which automatically copies if needed

    Example:

    x = torch.randn(3, 4)
    y = x.transpose(0, 1)  # Non-contiguous (strides changed)
    # y.view(-1) would fail!
    z = y.contiguous().view(-1)  # Works
    # Or: z = y.reshape(-1)  # Also works (copies if needed)
  12. Q: Does view work differently in distributed training (DDP)?

    A: No; view works the same way in DDP. It’s a local tensor operation that just requires contiguity. In DDP:

    • Each process has its own replica of the model

    • view operates on the local tensor shard

    • If a tensor becomes non-contiguous (e.g., after transpose, gather, or custom ops), use .contiguous().view(...)

    Caveat with DTensor: Newer PyTorch distributed APIs like DTensor may have restrictions on view-like operations depending on sharding strategy, but standard DDP tensors follow normal contiguity rules.

  13. Q: When should you worry about contiguity in practice?

    A: Be conscious of contiguity when:

    • After transpose(), permute(), or non-trivial slicing

    • Before view()–it will throw RuntimeError if non-contiguous

    • When passing tensors to C++/CUDA kernels that assume contiguous layout

    • In tight training loops, where unnecessary .contiguous() copies waste memory/time

    Best practice:

    • Use .reshape(...) instead of .view(...) if you’re unsure (it handles contiguity automatically)

    • Check contiguity with tensor.is_contiguous() when debugging

    • Avoid calling .contiguous() unnecessarily; it creates a copy

28.2 Key Takeaways for Interviews

Note

Critical Concepts to Master:

  • DDP: Multi-process data parallelism with All-Reduce

  • ZeRO/FSDP: Sharding optimizer, gradients, parameters to reduce memory

  • Tensor Parallelism: Split layers across GPUs (intra-layer)

  • Pipeline Parallelism: Split model into stages (inter-layer)

  • Mixed Precision: Use fp16/bf16 for compute, fp32 for updates

  • Gradient Checkpointing: Trade memory for recomputation

  • Flash Attention: Fused kernel to reduce attention memory

Note

Modern LLM Training Stack:

  • Framework: PyTorch FSDP or DeepSpeed ZeRO

  • Precision: BF16 (A100/H100) or FP16 with loss scaling

  • Optimizer: AdamW with warmup + cosine decay

  • Memory: Gradient checkpointing + Flash Attention 2

  • Communication: NCCL (intra-node), InfiniBand (inter-node)

29 References & Further Reading

  • ZeRO (Rajbhandari et al., 2020): ZeRO: Memory Optimizations Toward Training Trillion Parameter Models

  • FSDP (Meta AI, 2021): PyTorch Fully Sharded Data Parallel

  • Megatron-LM (Shoeybi et al., 2019): Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism

  • GPipe (Huang et al., 2019): GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism

  • Flash Attention (Dao et al., 2022): FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

  • Mixed Precision (Micikevicius et al., 2018): Mixed Precision Training

  • DeepSpeed Documentation: https://www.deepspeed.ai/

  • PyTorch FSDP Tutorial: https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html