6 Chapter 5: Weight Initialization for Deep Learning
7 Introduction
Proper initialization prevents vanishing/exploding gradients and accelerates training. Poor initialization can cause:
Symmetry: All neurons learn the same features (if initialized identically)
Saturation: Activations pushed to flat regions where gradients \(\to 0\)
Gradient Explosion/Vanishing: Gradients grow or shrink exponentially through layers
Key Principle: Maintain activation and gradient variance across layers during initialization.
Why This Matters:
Deep networks (50+ layers) are extremely sensitive to initialization. A bad scheme can make training impossible: the loss stays stuck at the random-guessing level, or gradients vanish within a few iterations.
8 Historical Evolution
8.1 Early Era (pre-2010)
Zero or Constant Initialization:
All weights set to the same value (e.g., 0 or 0.1)
Problem: Symmetry–all neurons compute identical functions and receive identical gradients
Network degenerates to single neuron per layer
Small Random Initialization:
\(W \sim \mathcal{N}(0, 0.01)\) (normal with small variance)
Worked for shallow nets (2-3 layers)
Problem: Deep nets suffered gradient vanishing–activations shrink exponentially through layers
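The shrinkage is easy to reproduce. A NumPy sketch (hypothetical width of 256, reading the small-variance init as std \(= 0.01\)):

```python
import numpy as np

rng = np.random.default_rng(0)
n = 256                      # layer width (illustrative)
x = rng.standard_normal(n)   # unit-variance input

# Propagate through 10 linear layers with small random weights
for layer in range(10):
    W = rng.normal(0.0, 0.01, size=(n, n))
    x = W @ x

# Each layer multiplies the activation std by ~0.01 * sqrt(256) = 0.16,
# so the signal collapses toward zero exponentially with depth.
print(f"activation std after 10 layers: {x.std():.2e}")
```

With a nonlinearity like tanh the effect is the same or worse, since tanh only shrinks its input further.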
8.2 Xavier/Glorot Initialization (2010)
Designed for sigmoid and tanh activations. Maintains variance of activations and gradients across layers.
Setup:
Linear layer: \(y = Wx\) where \(W \in \mathbb{R}^{n_{\text{out}} \times n_{\text{in}}}\)
\(n_{\text{in}}\) = number of input features (fan-in)
\(n_{\text{out}}\) = number of output features (fan-out)
Example: 512 → 256 layer has \(n_{\text{in}}=512, n_{\text{out}}=256\)
Derivation (Forward Pass):
Assume inputs \(x_i\) have \(\text{Var}(x_i) = 1\), zero mean, i.i.d.
Each output: \(y_j = \sum_{i=1}^{n_{\text{in}}} W_{ji} x_i\)
Variance: \(\text{Var}(y_j) = \sum_{i=1}^{n_{\text{in}}} \text{Var}(W_{ji} x_i) = n_{\text{in}} \cdot \text{Var}(W)\) (independence)
Want \(\text{Var}(y_j) = 1\) (same as input) \(\Rightarrow\) set \(\text{Var}(W) = 1/n_{\text{in}}\)
Backward Pass Constraint:
Gradient backprop: \(\frac{\partial L}{\partial x} = W^T \frac{\partial L}{\partial y}\)
Similar analysis: want \(\text{Var}(\frac{\partial L}{\partial x}) = \text{Var}(\frac{\partial L}{\partial y})\)
Requires \(\text{Var}(W) = 1/n_{\text{out}}\)
Compromise: Average the two \(\Rightarrow \text{Var}(W) = \frac{2}{n_{\text{in}} + n_{\text{out}}}\)
Formulas:
Uniform: \(W \sim \mathcal{U}\left[-\sqrt{\frac{6}{n_{\text{in}} + n_{\text{out}}}}, \sqrt{\frac{6}{n_{\text{in}} + n_{\text{out}}}}\right]\)
Gaussian: \(W \sim \mathcal{N}\left(0, \frac{2}{n_{\text{in}} + n_{\text{out}}}\right)\) (variance \(2/(n_{\text{in}} + n_{\text{out}})\))
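A quick NumPy check of the forward pass under Xavier, using the 512 → 256 example layer from above (sizes are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
n_in, n_out = 512, 256  # the 512 -> 256 example layer

# Xavier/Glorot: Var(W) = 2 / (n_in + n_out)
std = np.sqrt(2.0 / (n_in + n_out))
W = rng.normal(0.0, std, size=(n_out, n_in))

x = rng.standard_normal((n_in, 2000))  # unit-variance inputs
y = W @ x

# Var(y) = n_in * Var(W) = 2*n_in / (n_in + n_out), about 1.33 here;
# it equals exactly 1 only when n_in == n_out -- the price of averaging
# the forward and backward constraints.
print(f"output variance: {y.var():.2f}")
```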
Limitations:
Assumes activations are approximately linear (true for sigmoid/tanh near zero)
Breaks for ReLU: ReLU zeroes out half the activations–variance drops by \(\sim 50\%\) per layer
8.3 He Initialization (2015)
Designed for ReLU and its variants (Leaky ReLU, PReLU). Accounts for ReLU killing half the activations.
The ReLU Problem with Xavier:
ReLU: \(\text{ReLU}(x) = \max(0, x)\) zeroes out negative values
For symmetric distribution (e.g., \(W_{ji}x_i \sim \mathcal{N}(0, \sigma^2)\)), half become zero
Post-ReLU variance: \(\text{Var}(\text{ReLU}(x)) \approx \frac{1}{2} \text{Var}(x)\)
Derivation:
Pre-activation: \(z_j = \sum_{i=1}^{n_{\text{in}}} W_{ji} x_i\) has \(\text{Var}(z_j) = n_{\text{in}} \cdot \text{Var}(W)\)
Post-ReLU: \(y_j = \text{ReLU}(z_j)\) has \(\text{Var}(y_j) \approx \frac{1}{2} n_{\text{in}} \cdot \text{Var}(W)\)
Want \(\text{Var}(y_j) = 1\) \(\Rightarrow\) set \(\text{Var}(W) = 2/n_{\text{in}}\)
Factor of 2 compensates for ReLU zeroing half the activations
Why focus on \(n_{\text{in}}\) only?
ReLU’s unbounded activations make forward pass variance control critical (unlike bounded sigmoid/tanh)
Empirically, using only \(n_{\text{in}}\) (fan-in) works better than averaging with \(n_{\text{out}}\)
PyTorch default: mode='fan_in' for kaiming_normal_
Formulas:
Gaussian (most common): \(W \sim \mathcal{N}\left(0, \frac{2}{n_{\text{in}}}\right)\) (variance \(2/n_{\text{in}}\))
Uniform: \(W \sim \mathcal{U}\left[-\sqrt{\frac{6}{n_{\text{in}}}}, \sqrt{\frac{6}{n_{\text{in}}}}\right]\)
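The factor of 2 is easy to see empirically. A NumPy sketch (hypothetical sizes) contrasting fan-in Xavier and He through a deep ReLU stack:

```python
import numpy as np

rng = np.random.default_rng(0)
n, depth = 256, 20
x0 = rng.standard_normal((n, 200))  # unit-variance inputs

def run(std):
    x = x0.copy()
    for _ in range(depth):
        W = rng.normal(0.0, std, size=(n, n))
        x = np.maximum(0.0, W @ x)  # ReLU
    return x.std()

xavier_std = np.sqrt(1.0 / n)  # fan-in-only Xavier, for contrast
he_std = np.sqrt(2.0 / n)      # He: the extra factor of 2

# Xavier loses ~half the signal power at every ReLU, so activations
# decay exponentially; He's factor of 2 keeps them at a healthy scale.
print(f"Xavier: {run(xavier_std):.1e}, He: {run(he_std):.2f}")
```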
Impact:
Enabled training of very deep CNNs (ResNet-50, ResNet-152)
Became standard for computer vision
9 Modern Recipes (2024)
9.1 Standard Initialization by Architecture
| Architecture | Initialization | Notes |
|---|---|---|
| CNNs (ResNet, EfficientNet) | He Normal | ReLU activations |
| Transformers (BERT, GPT) | Xavier/Glorot | Layer norm stabilizes |
| Vision Transformers (ViT) | Truncated Normal | \(\sigma=0.02\), stabilizes patches |
| RNNs/LSTMs | Orthogonal + Xavier | Prevents exploding gradients |
| GANs (Generator) | Normal \(\mathcal{N}(0, 0.02)\) | Stabilizes early training |
| GANs (Discriminator) | He Normal | Leaky ReLU typical |
9.2 Transformer-Specific Practices
Modern transformers (BERT, GPT, LLaMA) use a combination of schemes:
Layer-by-Layer Recommendations:
Embedding layers: \(\mathcal{N}(0, 0.02)\) or Xavier Normal
Linear layers (Q/K/V, FFN): Xavier Normal (override PyTorch's nn.Linear default, which is Kaiming Uniform)
Output projection: Sometimes scaled by \(1/\sqrt{d_{\text{model}}}\) or \(1/\sqrt{\text{num\_layers}}\) to stabilize very deep stacks (GPT-3, PaLM)
LayerNorm: \(\gamma=1\) (scale), \(\beta=0\) (shift)–learnable parameters
Why Xavier for Transformers?
Layer normalization stabilizes activations, making the choice of activation function less critical. Xavier works well even though transformers use GELU/SwiGLU (not sigmoid/tanh). Modern frameworks (Hugging Face, DeepSpeed) handle this automatically.
Depth Scaling (Very Deep Models):
For 100+ layer transformers, scale output projections: \(W_{\text{out}} = W_{\text{out}} / \sqrt{2 \cdot \text{num\_layers}}\)
Prevents gradient explosion in residual connections
Used in GPT-3 (96 layers), PaLM (118 layers)
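The scaling rule itself is one line. A NumPy sketch (sizes are illustrative, not taken from any specific model):

```python
import numpy as np

rng = np.random.default_rng(0)
num_layers, d_model = 96, 768  # GPT-3-like depth, hypothetical width

# Standard transformer init for a residual-branch output projection
W_out = rng.normal(0.0, 0.02, size=(d_model, d_model))

# Depth scaling: 2*num_layers residual branches (attention + FFN per
# layer) all add into the residual stream, so shrink each projection by
# 1/sqrt(2*num_layers) to keep the summed variance on the order of a
# single unscaled branch.
W_out /= np.sqrt(2 * num_layers)
print(f"effective std: {W_out.std():.5f}")
```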
9.3 Convolutional Networks (ResNets)
Standard Practice:
Conv layers: He Normal (ReLU activations)
Batch norm: \(\gamma=1, \beta=0\)
Advanced (Very Deep Networks):
Fixup/ReZero: Initialize final batch norm in each residual block to \(\gamma=0\)
Effect: Residual branches initially output zero–network starts as identity mapping
Enables training 1000+ layer ResNets without batch norm
What does “initialize residual branch to zero” mean?
It does not mean all weights are zero. It means the residual branch scale starts at zero so the block behaves like identity. \[y = x + \alpha F(x), \quad \alpha = 0 \Rightarrow y = x\]
Implementations:
ReZero: add a learnable scalar \(\alpha\) per block and initialize \(\alpha=0\).
Fixup: initialize residual path (often last BN scale \(\gamma\)) so its output is near zero.
Modern LLMs: scale residual outputs by \(1/\sqrt{L}\) (or \(1/\sqrt{2L}\)) to keep deep stacks stable.
LayerNorm: \(\gamma=1, \beta=0\) preserves scale (identity-like), but is not a zeroed branch.
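A minimal sketch of the ReZero idea (module name and sizes here are illustrative, not from the paper's code):

```python
import torch
import torch.nn as nn

class ReZeroBlock(nn.Module):
    """Residual block with a learnable per-block scale, initialized to 0."""
    def __init__(self, dim):
        super().__init__()
        self.branch = nn.Sequential(
            nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))
        self.alpha = nn.Parameter(torch.zeros(1))  # ReZero: alpha = 0

    def forward(self, x):
        # At init alpha = 0, so the block is exactly the identity map
        return x + self.alpha * self.branch(x)

block = ReZeroBlock(64)
x = torch.randn(8, 64)
assert torch.equal(block(x), x)  # identity at initialization
```

As training moves alpha away from zero, the residual branch gradually switches on, which is what lets very deep stacks train from an identity-like start.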
9.4 Recurrent Networks (RNNs/LSTMs)
Special Considerations:
Recurrent weights: Orthogonal initialization (prevents exploding gradients through time)
Input-to-hidden: Xavier Normal
LSTM forget gate bias: Initialize to 1 or 2 (keeps long-term memory early in training)
Example: LSTM Initialization in PyTorch
import torch
import torch.nn as nn

input_size, hidden_size = 32, 64  # example sizes
lstm = nn.LSTM(input_size, hidden_size)
for name, param in lstm.named_parameters():
    if 'weight_ih' in name:
        nn.init.xavier_uniform_(param)
    elif 'weight_hh' in name:
        nn.init.orthogonal_(param)
    elif 'bias' in name:
        nn.init.zeros_(param)
        # Gate order is (input, forget, cell, output): the forget gate
        # bias is the second quarter of the vector. In-place edits on a
        # parameter must be wrapped in no_grad().
        n = param.size(0)
        with torch.no_grad():
            param[n // 4:n // 2].fill_(1.0)
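Why orthogonal recurrent weights help can be checked directly: an orthogonal matrix preserves vector norms, so repeatedly applying it through time neither explodes nor vanishes the hidden state. A small illustrative check:

```python
import torch
import torch.nn as nn

# Orthogonal init: W^T W = I, so the linear recurrent map preserves norms
W = torch.empty(256, 256)
nn.init.orthogonal_(W)

h = torch.randn(256)
h_100 = h.clone()
for _ in range(100):
    h_100 = W @ h_100  # 100 "time steps" of the recurrent map

# The hidden-state norm survives 100 steps essentially unchanged
print(h.norm().item(), h_100.norm().item())
```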
10 Special Cases
10.1 Bias Terms
Standard Practice:
Hidden layers: Initialize to zero
Output layer (classification): Can initialize to \(\log(\text{class prior})\) for faster convergence
Example: Binary classification with 90% class 0, 10% class 1–initialize bias to \(\log(0.1/0.9) \approx -2.2\)
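The class-prior trick can be verified in a few lines of plain Python (illustrative):

```python
import math

# Binary classification with 10% positives: set the output bias to the
# log-odds of the prior so the untrained model already predicts it.
p = 0.10
bias = math.log(p / (1 - p))  # log(0.1/0.9) ≈ -2.2

sigmoid = lambda z: 1 / (1 + math.exp(-z))
# With zero-initialized output weights, the initial prediction is
# sigmoid(bias), which recovers the prior exactly.
print(f"bias = {bias:.3f}, initial prediction = {sigmoid(bias):.3f}")
```

This removes the first phase of training where the model would otherwise just be learning the base rate.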
10.2 Normalization Layers
Batch Normalization / Layer Normalization:
\(\gamma\) (scale): Initialize to 1
\(\beta\) (shift): Initialize to 0
Effect: Initially acts as identity–no transformation until learning begins
Benefit: Reduces sensitivity to weight initialization in earlier layers
10.3 Transfer Learning / Fine-tuning
Recommendations:
Pretrained weights: Keep as-is (already trained on large datasets)
New layers (e.g., classification head): Use Xavier or He depending on activation
Fine-tuning learning rate: Typically \(10^{-5}\) to \(10^{-6}\) (much lower than pre-training)
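A hedged sketch of the freeze-and-reinitialize pattern (the backbone below is a stand-in placeholder, not a real pretrained model):

```python
import torch.nn as nn

# Placeholder for any pretrained feature extractor
backbone = nn.Sequential(nn.Linear(224, 512), nn.ReLU())

# Freeze pretrained weights; only the new head will receive gradients
for p in backbone.parameters():
    p.requires_grad = False

# New classification head: He init, since its input passes through ReLU
head = nn.Linear(512, 10)
nn.init.kaiming_normal_(head.weight, mode='fan_in', nonlinearity='relu')
nn.init.zeros_(head.bias)
```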
10.4 Generative Models (GANs, VAEs)
GANs:
Generator: \(\mathcal{N}(0, 0.02)\) (stabilizes early training, prevents mode collapse)
Discriminator: He Normal (typically uses Leaky ReLU)
VAEs:
Encoder/Decoder: Xavier or He depending on activation
Mean/Variance layers: Xavier (final layers of encoder)
11 Practical Implementation
11.1 PyTorch Defaults
PyTorch’s nn.Linear, nn.Conv2d, etc., come with built-in initialization:
Linear: Kaiming Uniform (variant of He) by default
Conv2d: Kaiming Uniform
Embedding: \(\mathcal{N}(0, 1)\)–often overridden to \(\mathcal{N}(0, 0.02)\) in transformers
Recommendation: For most cases, PyTorch defaults work well. Override only for specific architectures (e.g., transformers, GANs).
11.2 Common Initialization Functions
PyTorch Initialization Examples:
import torch.nn as nn
# He Normal
nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
# Xavier Normal
nn.init.xavier_normal_(layer.weight)
# Custom normal
nn.init.normal_(layer.weight, mean=0, std=0.02)
# Truncated normal (ViT-style)
nn.init.trunc_normal_(layer.weight, mean=0, std=0.02, a=-2*0.02, b=2*0.02)
# Orthogonal (RNNs)
nn.init.orthogonal_(layer.weight)
# Zeros (biases)
nn.init.zeros_(layer.bias)
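These calls are typically wired together in a single init function and applied recursively with model.apply. A transformer-style example (the toy model is illustrative):

```python
import torch.nn as nn

def init_weights(module):
    """Per-module init, dispatched on layer type."""
    if isinstance(module, nn.Linear):
        nn.init.xavier_normal_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
    elif isinstance(module, nn.LayerNorm):
        nn.init.ones_(module.weight)   # gamma = 1
        nn.init.zeros_(module.bias)    # beta = 0

model = nn.Sequential(nn.Embedding(1000, 64), nn.Linear(64, 64),
                      nn.LayerNorm(64))
model.apply(init_weights)  # visits every submodule, including nested ones
```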
12 Interview Cheat Sheet
“Why not initialize all weights to zero?”
Symmetry problem–all neurons compute the same function and receive same gradients
Network degenerates to a single neuron per layer
Breaks gradient descent–no way to learn diverse features
“Xavier vs He–when to use which?”
Xavier: Sigmoid/tanh activations (rare today) or transformers with layer norm
He: ReLU and variants (modern default for CNNs)
Rule of thumb: If using ReLU in CNNs → He; if transformer with layer norm → Xavier
“What’s the initialization for GPT/BERT?”
Embeddings: \(\mathcal{N}(0, 0.02)\)
Linear layers (Q/K/V, FFN): Xavier Normal (override PyTorch's nn.Linear default of Kaiming Uniform)
LayerNorm: \(\gamma=1, \beta=0\)
Output projection (very deep models): Scaled by \(1/\sqrt{\text{num\_layers}}\)
“How does batch norm affect initialization?”
Normalizes activations–makes training less sensitive to initialization
Can use wider range of initialization schemes without breaking training
Still use He/Xavier for best practice–provides good starting point
“What about very deep networks (100+ layers)?”
Standard init + layer norm usually sufficient for transformers
ResNets: Use Fixup/ReZero (initialize residual branches to zero)
Transformers: Scale output projections by \(1/\sqrt{\text{num\_layers}}\) (GPT-3, PaLM)
“When does initialization matter most?”
Very deep networks (50+ layers): Critical–bad init prevents convergence
Networks without normalization (no batch norm/layer norm): Very sensitive
GANs: Sensitive to init–affects stability and mode collapse
Transfer learning: Less critical–pretrained weights already good
13 Summary
Historical Evolution:
Random Small Weights (pre-2010): \(\mathcal{N}(0, 0.01)\)–failed for deep nets
Xavier/Glorot (2010): Designed for sigmoid/tanh–maintains variance across layers
He (2015): Designed for ReLU–accounts for half-zero activations
Modern Variants (2020+): Depth scaling, truncated normal, orthogonal for specific architectures
Modern Best Practices (2024):
CNNs (ResNet, EfficientNet): He Normal
Transformers (BERT, GPT, LLaMA): Xavier + \(\mathcal{N}(0, 0.02)\) embeddings
Vision Transformers (ViT): Truncated Normal \(\sigma=0.02\)
RNNs/LSTMs: Orthogonal (recurrent weights) + Xavier (input weights)
GANs: \(\mathcal{N}(0, 0.02)\) for generator, He for discriminator
Very Deep (100+ layers): Add depth scaling to output projections
Key Takeaways:
Initialization prevents vanishing/exploding gradients in early training
Modern frameworks (PyTorch, TensorFlow) provide good defaults–override only when needed
Layer norm/batch norm reduces sensitivity–but proper init still helps
Transfer learning: Keep pretrained weights, initialize only new layers
For questions, corrections, or suggestions: peymanr@gmail.com