6  Chapter 5: Weight Initialization for Deep Learning

7 Introduction

Proper initialization prevents vanishing/exploding gradients and accelerates training. Poor initialization can cause:

  • Symmetry: All neurons learn the same features (if initialized identically)

  • Saturation: Activations pushed to flat regions where gradients \(\to 0\)

  • Gradient Explosion/Vanishing: Gradients grow or shrink exponentially through layers

Key Principle: Maintain activation and gradient variance across layers during initialization.

Note

Why This Matters:
Deep networks (50+ layers) are extremely sensitive to initialization. A bad scheme can make training impossible: the loss stays stuck at random-guessing level, or gradients vanish within a few iterations.

8 Historical Evolution

8.1 Early Era (pre-2010)

Zero or Constant Initialization:

  • All weights set to the same value (e.g., 0 or 0.1)

  • Problem: Symmetry–all neurons compute identical functions and receive identical gradients

  • Network degenerates to single neuron per layer

Small Random Initialization:

  • \(W \sim \mathcal{N}(0, \sigma^2)\) with small standard deviation (e.g., \(\sigma = 0.01\))

  • Worked for shallow nets (2-3 layers)

  • Problem: Deep nets suffered gradient vanishing–activations shrink exponentially through layers

8.2 Xavier/Glorot Initialization (2010)

Designed for sigmoid and tanh activations. Maintains variance of activations and gradients across layers.

Setup:

  • Linear layer: \(y = Wx\) where \(W \in \mathbb{R}^{n_{\text{out}} \times n_{\text{in}}}\)

  • \(n_{\text{in}}\) = number of input features (fan-in)

  • \(n_{\text{out}}\) = number of output features (fan-out)

  • Example: 512 → 256 layer has \(n_{\text{in}}=512, n_{\text{out}}=256\)

Derivation (Forward Pass):

  • Assume inputs \(x_i\) have \(\text{Var}(x_i) = 1\), zero mean, i.i.d.

  • Each output: \(y_j = \sum_{i=1}^{n_{\text{in}}} W_{ji} x_i\)

  • Variance: \(\text{Var}(y_j) = \sum_{i=1}^{n_{\text{in}}} \text{Var}(W_{ji} x_i) = n_{\text{in}} \cdot \text{Var}(W)\) (independence)

  • Want \(\text{Var}(y_j) = 1\) (same as input) \(\Rightarrow\) set \(\text{Var}(W) = 1/n_{\text{in}}\)

Backward Pass Constraint:

  • Gradient backprop: \(\frac{\partial L}{\partial x} = W^T \frac{\partial L}{\partial y}\)

  • Similar analysis: want \(\text{Var}(\frac{\partial L}{\partial x}) = \text{Var}(\frac{\partial L}{\partial y})\)

  • Requires \(\text{Var}(W) = 1/n_{\text{out}}\)

  • Compromise: Average the two \(\Rightarrow \text{Var}(W) = \frac{2}{n_{\text{in}} + n_{\text{out}}}\)
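This derivation can be checked with a quick Monte-Carlo sketch (the widths and depth below are illustrative): pushing unit-variance inputs through a stack of tanh layers, Xavier-scaled weights keep the activation variance alive, while a small fixed standard deviation collapses it toward zero.

```python
import numpy as np

rng = np.random.default_rng(0)
n_in = n_out = 512
depth = 20

def final_variance(std):
    """Variance of activations after `depth` tanh layers with N(0, std^2) weights."""
    h = rng.standard_normal((1000, n_in))  # unit-variance inputs
    for _ in range(depth):
        W = rng.normal(0.0, std, size=(n_in, n_out))
        h = np.tanh(h @ W)
    return h.var()

xavier_std = np.sqrt(2.0 / (n_in + n_out))
v_xavier = final_variance(xavier_std)  # stays well away from zero
v_small = final_variance(0.01)         # shrinks exponentially with depth
```

The small-std run multiplies the variance by roughly \(n_{\text{in}} \cdot 0.01^2 \approx 0.05\) per layer, which is exactly the exponential shrinkage described in Section 8.1.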

Formulas:

  • Uniform: \(W \sim \mathcal{U}\left[-\sqrt{\frac{6}{n_{\text{in}} + n_{\text{out}}}}, \sqrt{\frac{6}{n_{\text{in}} + n_{\text{out}}}}\right]\)

  • Gaussian: \(W \sim \mathcal{N}\left(0, \frac{2}{n_{\text{in}} + n_{\text{out}}}\right)\), i.e., standard deviation \(\sqrt{\frac{2}{n_{\text{in}} + n_{\text{out}}}}\)

Limitations:

  • Assumes activations are approximately linear (true for sigmoid/tanh near zero)

  • Breaks for ReLU: ReLU zeroes out half the activations–variance drops by \(\sim 50\%\) per layer

8.3 He Initialization (2015)

Designed for ReLU and its variants (Leaky ReLU, PReLU). Accounts for ReLU killing half the activations.

The ReLU Problem with Xavier:

  • ReLU: \(\text{ReLU}(x) = \max(0, x)\) zeroes out negative values

  • For symmetric distribution (e.g., \(W_{ji}x_i \sim \mathcal{N}(0, \sigma^2)\)), half become zero

  • Post-ReLU variance: \(\text{Var}(\text{ReLU}(x)) \approx \frac{1}{2} \text{Var}(x)\)

Derivation:

  • Pre-activation: \(z_j = \sum_{i=1}^{n_{\text{in}}} W_{ji} x_i\) has \(\text{Var}(z_j) = n_{\text{in}} \cdot \text{Var}(W)\)

  • Post-ReLU: \(y_j = \text{ReLU}(z_j)\) has \(\text{Var}(y_j) \approx \frac{1}{2} n_{\text{in}} \cdot \text{Var}(W)\)

  • Want \(\text{Var}(y_j) = 1\) \(\Rightarrow\) set \(\text{Var}(W) = 2/n_{\text{in}}\)

  • Factor of 2 compensates for ReLU zeroing half the activations
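The factor of 2 can be verified numerically. This sketch (illustrative width and depth) compares the activation second moment after a deep ReLU stack under He scaling versus plain fan-in scaling without the factor of 2:

```python
import numpy as np

rng = np.random.default_rng(0)
n_in, depth = 512, 30
x0 = rng.standard_normal((1000, n_in))  # E[x^2] = 1

def final_second_moment(var_w):
    """Second moment E[h^2] of activations after `depth` ReLU layers."""
    h = x0.copy()
    for _ in range(depth):
        W = rng.normal(0.0, np.sqrt(var_w), size=(n_in, n_in))
        h = np.maximum(h @ W, 0.0)  # ReLU zeroes the negative half
    return (h ** 2).mean()

m_he = final_second_moment(2.0 / n_in)      # factor 2: signal preserved
m_fan_in = final_second_moment(1.0 / n_in)  # no factor 2: halves every layer
```

Without the factor of 2, the signal scale drops by roughly \(2^{-30}\) over 30 layers, matching the "variance drops by \(\sim 50\%\) per layer" limitation noted above.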

Why focus on \(n_{\text{in}}\) only?

  • ReLU’s unbounded activations make forward pass variance control critical (unlike bounded sigmoid/tanh)

  • Empirically, using only \(n_{\text{in}}\) (fan-in) works better than averaging with \(n_{\text{out}}\)

  • PyTorch default: mode='fan_in' for kaiming_normal_

Formulas:

  • Gaussian (most common): \(W \sim \mathcal{N}\left(0, \frac{2}{n_{\text{in}}}\right)\), i.e., standard deviation \(\sqrt{\frac{2}{n_{\text{in}}}}\)

  • Uniform: \(W \sim \mathcal{U}\left[-\sqrt{\frac{6}{n_{\text{in}}}}, \sqrt{\frac{6}{n_{\text{in}}}}\right]\)

Impact:

  • Enabled training of very deep CNNs (ResNet-50, ResNet-152)

  • Became standard for computer vision

9 Modern Recipes (2024)

9.1 Standard Initialization by Architecture

Recommended initialization schemes for common architectures.
  Architecture                   Initialization                     Notes
  CNNs (ResNet, EfficientNet)    He Normal                          ReLU activations
  Transformers (BERT, GPT)       Xavier/Glorot                      Layer norm stabilizes
  Vision Transformers (ViT)      Truncated Normal                   \(\sigma=0.02\), stabilizes patches
  RNNs/LSTMs                     Orthogonal + Xavier                Prevents exploding gradients
  GANs (Generator)               Normal \(\mathcal{N}(0, 0.02)\)    Stabilizes early training
  GANs (Discriminator)           He Normal                          Leaky ReLU typical
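One common way to apply these recipes is a single dispatch function passed to Module.apply; the sketch below is one convention, not the only one:

```python
import torch.nn as nn

def init_weights(module):
    """Dispatch by layer type: He for conv (ReLU), Xavier for linear, norms to identity."""
    if isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Linear):
        nn.init.xavier_normal_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.BatchNorm2d, nn.LayerNorm)):
        nn.init.ones_(module.weight)   # gamma = 1
        nn.init.zeros_(module.bias)    # beta = 0

model = nn.Sequential(
    nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU(),
    nn.Flatten(), nn.Linear(64, 10),
)
model.apply(init_weights)
```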

9.2 Transformer-Specific Practices

Modern transformers (BERT, GPT, LLaMA) use a combination of schemes:

Layer-by-Layer Recommendations:

  • Embedding layers: \(\mathcal{N}(0, 0.02)\) or Xavier Normal

  • Linear layers (Q/K/V, FFN): Xavier Normal (note that PyTorch's nn.Linear default is Kaiming Uniform, so transformer codebases typically override it)

  • Output projection: Sometimes scaled by \(1/\sqrt{d_{\text{model}}}\) or \(1/\sqrt{\text{num\_layers}}\) to stabilize very deep stacks (GPT-3, PaLM)

  • LayerNorm: \(\gamma=1\) (scale), \(\beta=0\) (shift)–learnable parameters

Note

Why Xavier for Transformers?
Layer normalization stabilizes activations, making the choice of activation function less critical. Xavier works well even though transformers use GELU/SwiGLU (not sigmoid/tanh). Modern frameworks (Hugging Face, DeepSpeed) handle this automatically.

Depth Scaling (Very Deep Models):

  • For 100+ layer transformers, scale output projections: \(W_{\text{out}} = W_{\text{out}} / \sqrt{2 \cdot \text{num\_layers}}\)

  • Prevents gradient explosion in residual connections

  • Used in GPT-3 (96 layers), PaLM (118 layers)
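As a sketch of this depth scaling (the layer count matches GPT-3's 96 layers; d_model and the base std of 0.02 are illustrative values), the output projection of a residual branch can be initialized like so:

```python
import math
import torch.nn as nn

num_layers, d_model = 96, 768
base_std = 0.02

# Output projection of a residual branch, scaled down by sqrt(2 * num_layers)
# (2 residual additions per transformer block: attention and FFN)
proj = nn.Linear(d_model, d_model)
nn.init.normal_(proj.weight, mean=0.0, std=base_std / math.sqrt(2 * num_layers))
nn.init.zeros_(proj.bias)
```

The scaling keeps the sum of contributions over all residual connections at roughly the same magnitude regardless of depth.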

9.3 Convolutional Networks (ResNets)

Standard Practice:

  • Conv layers: He Normal (ReLU activations)

  • Batch norm: \(\gamma=1, \beta=0\)

Advanced (Very Deep Networks):

  • Fixup/ReZero: Initialize final batch norm in each residual block to \(\gamma=0\)

  • Effect: Residual branches initially output zero–network starts as identity mapping

  • Enables training 1000+ layer ResNets without batch norm

Note

What does “initialize residual branch to zero” mean?
It does not mean all weights are zero. It means the residual branch scale starts at zero so the block behaves like identity. \[y = x + \alpha F(x), \quad \alpha = 0 \Rightarrow y = x\]


Implementations:

  • ReZero: add a learnable scalar \(\alpha\) per block and initialize \(\alpha=0\).

  • Fixup: initialize residual path (often last BN scale \(\gamma\)) so its output is near zero.

  • Modern LLMs: scale residual outputs by \(1/\sqrt{L}\) (or \(1/\sqrt{2L}\)) to keep deep stacks stable.

  • LayerNorm: \(\gamma=1, \beta=0\) preserves scale (identity-like), but is not a zeroed branch.
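The ReZero variant above fits in a few lines; ReZeroBlock and its two-layer branch are illustrative names for a minimal sketch:

```python
import torch
import torch.nn as nn

class ReZeroBlock(nn.Module):
    """Residual block with a learnable scalar gate initialized to zero."""
    def __init__(self, dim):
        super().__init__()
        self.branch = nn.Sequential(
            nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim)
        )
        self.alpha = nn.Parameter(torch.zeros(1))  # alpha = 0 -> block is identity

    def forward(self, x):
        return x + self.alpha * self.branch(x)

x = torch.randn(4, 32)
block = ReZeroBlock(32)
out = block(x)  # exactly x at initialization, since alpha = 0
```

The branch weights themselves are initialized normally; only the gate starts at zero, so gradients still flow into the branch from the first step.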

9.4 Recurrent Networks (RNNs/LSTMs)

Special Considerations:

  • Recurrent weights: Orthogonal initialization (prevents exploding gradients through time)

  • Input-to-hidden: Xavier Normal

  • LSTM forget gate bias: Initialize to 1 or 2 (keeps long-term memory early in training)

Example

Example: LSTM Initialization in PyTorch

lstm = nn.LSTM(input_size, hidden_size)
for name, param in lstm.named_parameters():
    if 'weight_ih' in name:
        nn.init.xavier_uniform_(param)
    elif 'weight_hh' in name:
        nn.init.orthogonal_(param)
    elif 'bias' in name:
        nn.init.zeros_(param)
        # Set forget gate bias to 1 (PyTorch gate order: input, forget, cell, output)
        n = param.size(0)
        param.data[n//4:n//2].fill_(1.0)

10 Special Cases

10.1 Bias Terms

Standard Practice:

  • Hidden layers: Initialize to zero

  • Output layer (classification): Can initialize to \(\log(\text{class prior})\) for faster convergence

  • Example: Binary classification with 90% class 0, 10% class 1–initialize bias to \(\log(0.1/0.9) \approx -2.2\)
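The arithmetic behind this example can be verified directly: with bias \(b = \log(p/(1-p))\) and zero-initialized weights, a sigmoid output unit starts out predicting exactly the class prior.

```python
import math

p1 = 0.1                         # prior probability of class 1
bias = math.log(p1 / (1 - p1))   # log(0.1/0.9) ~= -2.197
sigmoid = 1 / (1 + math.exp(-bias))  # initial prediction of the output unit
```

Starting at the prior instead of 0.5 removes a large, easily avoidable chunk of initial loss on imbalanced data.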

10.2 Normalization Layers

Batch Normalization / Layer Normalization:

  • \(\gamma\) (scale): Initialize to 1

  • \(\beta\) (shift): Initialize to 0

  • Effect: Initially acts as identity–no transformation until learning begins

  • Benefit: Reduces sensitivity to weight initialization in earlier layers

10.3 Transfer Learning / Fine-tuning

Recommendations:

  • Pretrained weights: Keep as-is (already trained on large datasets)

  • New layers (e.g., classification head): Use Xavier or He depending on activation

  • Fine-tuning learning rate: Typically \(10^{-5}\) to \(10^{-6}\) (much lower than pre-training)
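A minimal sketch of this recipe (the backbone below is a stand-in for real pretrained weights):

```python
import torch.nn as nn

# Stand-in for a pretrained feature extractor (weights would come from a checkpoint)
backbone = nn.Sequential(nn.Linear(128, 128), nn.ReLU())

# New classification head: the only part that gets freshly initialized
head = nn.Linear(128, 10)
nn.init.xavier_normal_(head.weight)
nn.init.zeros_(head.bias)

# Keep pretrained weights as-is (optionally frozen during early fine-tuning)
for p in backbone.parameters():
    p.requires_grad = False
```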

10.4 Generative Models (GANs, VAEs)

GANs:

  • Generator: \(\mathcal{N}(0, 0.02)\) (stabilizes early training, prevents mode collapse)

  • Discriminator: He Normal (typically uses Leaky ReLU)

VAEs:

  • Encoder/Decoder: Xavier or He depending on activation

  • Mean/Variance layers: Xavier (final layers of encoder)

11 Practical Implementation

11.1 PyTorch Defaults

PyTorch’s nn.Linear, nn.Conv2d, etc., come with built-in initialization:

  • Linear: Kaiming Uniform (variant of He) by default

  • Conv2d: Kaiming Uniform

  • Embedding: \(\mathcal{N}(0, 1)\)–often overridden to \(\mathcal{N}(0, 0.02)\) in transformers

Recommendation: For most cases, PyTorch defaults work well. Override only for specific architectures (e.g., transformers, GANs).

11.2 Common Initialization Functions

Example

PyTorch Initialization Examples:

import torch.nn as nn

# He Normal
nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')

# Xavier Normal
nn.init.xavier_normal_(layer.weight)

# Custom normal
nn.init.normal_(layer.weight, mean=0, std=0.02)

# Truncated normal (ViT-style)
nn.init.trunc_normal_(layer.weight, mean=0, std=0.02, a=-2*0.02, b=2*0.02)

# Orthogonal (RNNs)
nn.init.orthogonal_(layer.weight)

# Zeros (biases)
nn.init.zeros_(layer.bias)

12 Interview Cheat Sheet

“Why not initialize all weights to zero?”

  • Symmetry problem–all neurons compute the same function and receive same gradients

  • Network degenerates to a single neuron per layer

  • Breaks gradient descent–no way to learn diverse features

“Xavier vs He–when to use which?”

  • Xavier: Sigmoid/tanh activations (rare today) or transformers with layer norm

  • He: ReLU and variants (modern default for CNNs)

  • Rule of thumb: If using ReLU in CNNs → He; if transformer with layer norm → Xavier

“What’s the initialization for GPT/BERT?”

  • Embeddings: \(\mathcal{N}(0, 0.02)\)

  • Linear layers (Q/K/V, FFN): Xavier Normal (overriding PyTorch's Kaiming Uniform nn.Linear default)

  • LayerNorm: \(\gamma=1, \beta=0\)

  • Output projection (very deep models): Scaled by \(1/\sqrt{\text{num\_layers}}\)

“How does batch norm affect initialization?”

  • Normalizes activations–makes training less sensitive to initialization

  • Can use wider range of initialization schemes without breaking training

  • Still use He/Xavier for best practice–provides good starting point

“What about very deep networks (100+ layers)?”

  • Standard init + layer norm usually sufficient for transformers

  • ResNets: Use Fixup/ReZero (initialize residual branches to zero)

  • Transformers: Scale output projections by \(1/\sqrt{\text{num\_layers}}\) (GPT-3, PaLM)

“When does initialization matter most?”

  • Very deep networks (50+ layers): Critical–bad init prevents convergence

  • Networks without normalization (no batch norm/layer norm): Very sensitive

  • GANs: Sensitive to init–affects stability and mode collapse

  • Transfer learning: Less critical–pretrained weights already good

13 Summary

Historical Evolution:

  1. Random Small Weights (pre-2010): \(\mathcal{N}(0, 0.01)\)–failed for deep nets

  2. Xavier/Glorot (2010): Designed for sigmoid/tanh–maintains variance across layers

  3. He (2015): Designed for ReLU–accounts for half-zero activations

  4. Modern Variants (2020+): Depth scaling, truncated normal, orthogonal for specific architectures

Modern Best Practices (2024):

  • CNNs (ResNet, EfficientNet): He Normal

  • Transformers (BERT, GPT, LLaMA): Xavier + \(\mathcal{N}(0, 0.02)\) embeddings

  • Vision Transformers (ViT): Truncated Normal \(\sigma=0.02\)

  • RNNs/LSTMs: Orthogonal (recurrent weights) + Xavier (input weights)

  • GANs: \(\mathcal{N}(0, 0.02)\) for generator, He for discriminator

  • Very Deep (100+ layers): Add depth scaling to output projections

Key Takeaways:

  • Initialization prevents vanishing/exploding gradients in early training

  • Modern frameworks (PyTorch, TensorFlow) provide good defaults–override only when needed

  • Layer norm/batch norm reduces sensitivity–but proper init still helps

  • Transfer learning: Keep pretrained weights, initialize only new layers



For questions, corrections, or suggestions: peymanr@gmail.com