Applying μTransfer to Scale Sparse Autoencoders
26 September 2024
In this post, we’ll explore how to apply μTransfer (based on the Maximal Update Parameterization, μP) to scale a Sparse Autoencoder (SAE) by increasing its hidden dimension (d_sae), i.e., increasing the expansion factor. Applying μTransfer ensures consistent training dynamics across different scales of the hidden dimension.
This work is heavily based on The Practitioner’s Guide to the Maximal Update Parameterization (please read this first). By applying μTransfer scaling to a model:
- Consistent Training Dynamics: Ensures activations, gradients, and weight updates remain consistent across expansion factors.
- Simplified Scaling: Only parameters connected to the scaled dimension (d_sae) are adjusted.
- Stable Training: Prevents issues like exploding activations and gradients.
- Stable Hyperparameters: No need to sweep learning rates; note that this method does not cover regularization parameters.
TL;DR: Jump to Scaling Rules to find how to scale initialization & learning rates.
Original SAE
Our SAE model, JumpReLUSAE, comes from Gemma Scope from Scratch (TODO: Gated SAE reference):
import torch
import torch.nn as nn


class JumpReLUSAE(nn.Module):
    def __init__(self, d_model, d_sae, sigma=0.02):
        super().__init__()
        self.d_sae = d_sae

        # Initialize parameters
        self.W_enc = nn.Parameter(torch.randn(d_model, self.d_sae) * sigma)
        self.W_dec = nn.Parameter(torch.randn(self.d_sae, d_model) * sigma)
        self.threshold = nn.Parameter(torch.randn(self.d_sae) * sigma)
        self.b_enc = nn.Parameter(torch.randn(self.d_sae) * sigma)
        self.b_dec = nn.Parameter(torch.randn(d_model) * sigma)

    def encode(self, input_acts):
        pre_acts = input_acts @ self.W_enc + self.b_enc
        mask = (pre_acts > self.threshold)
        acts = mask * torch.relu(pre_acts)
        return acts

    def decode(self, acts):
        return acts @ self.W_dec + self.b_dec

    def forward(self, input_acts):
        acts = self.encode(input_acts)
        recon = self.decode(acts)
        return recon
- d_model: Fixed input and output dimension of the SAE.
- d_sae: Hidden dimension, determined by d_model * expansion_factor.
Our goal is to scale up the expansion_factor (and therefore d_sae) and apply μTransfer principles to determine the scaled learning rates and initializations.
Applying μTransfer Scaling
Definitions
- Input and Output Dimension (Fixed): \(d_{\text{model}}\)
- Expansion Factor: \(\text{expansion\_factor}\)
- Hidden Dimension: \(d_{\text{sae}} = \text{expansion\_factor} \times d_{\text{model}}\)
- Base Initialization Variance: \(\sigma_{\text{base}}^2\)
- Base Learning Rate: \(\eta_{\text{base}}\)
- Width Multiplier: \(m_d = \text{expansion\_factor}\ /\ \text{expansion\_factor}_{\text{base}}\)
Scaling Principles
- Parameters connected to the scaled dimension: Adjust initialization variance and learning rate inversely with \(m_d\).
- Parameters connected only to fixed dimensions: Keep initialization variance and learning rate unchanged.
- Output Scaling: Scale decoder’s output by \(\alpha_{\text{output}} = 1\ /\ m_d\).
Below, we derive the initialization variance and learning rate for each parameter from the forward and backward passes.
Decoder Weights (W_dec)
Dimensions: \(d_{\text{sae}} \times d_{\text{model}}\)
Reconstruction output: \[ \text{recon} = \text{acts} \times W_{\text{dec}} + b_{\text{dec}} \]
Forward Pass
To maintain consistent training dynamics across different \(d_{\text{sae}}\), we need the variance of \(\text{recon}\) to remain constant. We follow a derivation similar to the one in Forward Pass at Initialization.
At initialization, \(\text{acts}\) and \(W_{\text{dec}}\) are independent and have zero mean.
\[ \text{Var}(\text{recon}) = \text{Var}(\text{acts} \times W_{\text{dec}}) = \text{Var}(\text{acts}) \times \text{Var}(W_{\text{dec}}) \times d_{\text{sae}} \]
Since \(d_{\text{sae}}\) scales with \(m_d\):
\[ d_{\text{sae}} = m_d \times d_{\text{sae, base}} \]
\[ \implies \text{Var}(\text{recon}) = \text{Var}(\text{acts}) \times \text{Var}(W_{\text{dec}}) \times (m_d \times d_{\text{sae, base}}) \]
To keep \(\text{Var}(\text{recon})\) constant across scales, we need to scale \(\text{Var}(W_{\text{dec}})\) inversely with \(m_d\)
\[ \text{Var}(W_{\text{dec}}) = \frac{\sigma^2_{\text{base}}}{m_d} \]
\[ \implies \text{Var}(\text{recon}) = \text{Var}(\text{acts}) \times \sigma^2_{\text{base}} \times d_{\text{sae, base}} \]
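As a quick sanity check, here is a minimal simulation sketch (assuming i.i.d. unit-variance Gaussian \(\text{acts}\) at initialization, an idealization of the real JumpReLU activations) showing that \(\text{Var}(\text{recon})\) stays roughly constant across widths when \(\text{Var}(W_\text{dec}) = \sigma^2_\text{base} / m_d\):

```python
import torch

torch.manual_seed(0)
d_model, d_sae_base, sigma_base = 768, 768 * 16, 0.02

for m_d in [1, 2, 4]:
    d_sae = d_sae_base * m_d
    acts = torch.randn(256, d_sae)                                 # stand-in for SAE activations
    W_dec = torch.randn(d_sae, d_model) * sigma_base / m_d ** 0.5  # Var(W_dec) = sigma_base^2 / m_d
    recon = acts @ W_dec
    print(f"m_d={m_d}: Var(recon) ~= {recon.var().item():.3f}")
    # Roughly sigma_base^2 * d_sae_base ~= 4.9 at every width; without the
    # 1/m_d variance scaling it would grow linearly with m_d.
```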
Backwards Pass
The gradient with respect to \(W_\text{dec}\):
\[ \nabla_{W_{\text{dec}}}\mathcal{L} = \text{acts}^\top \nabla_{\text{recon}}\mathcal{L} \]
The magnitude of \(\nabla_{W_{\text{dec}}}\mathcal{L}\) scales with \(d_\text{sae}\), since \(\text{acts}\) has dimension \(d_\text{sae}\).
To maintain consistent updates \(\Delta W_\text{dec}\), we scale the learning rate inversely with \(m_d\):
\[ \eta_{W_\text{dec}} = \frac{\eta_\text{base}}{m_d} \]
This aligns with the derivations in the Effect of weight update on activations.
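As a rough sketch of this effect (hypothetical random activations and reconstruction gradients, plain SGD for simplicity), the change in \(\text{recon}\) caused by a single weight update grows with \(m_d\) unless the learning rate is divided by \(m_d\):

```python
import torch

torch.manual_seed(0)
d_model, d_sae_base, eta_base = 768, 768 * 16, 1e-3
batch_size = 64

for m_d in [1, 2, 4]:
    d_sae = d_sae_base * m_d
    acts = torch.randn(batch_size, d_sae)         # stand-in for SAE activations
    g_recon = torch.randn(batch_size, d_model)    # stand-in for dL/d(recon)
    grad_W_dec = acts.T @ g_recon / batch_size    # gradient w.r.t. W_dec (batch mean)
    for eta in (eta_base, eta_base / m_d):
        delta_recon = acts @ (-eta * grad_W_dec)  # change in recon after one SGD step
        print(f"m_d={m_d}, eta={eta:.2e}: mean |delta_recon| ~= {delta_recon.abs().mean():.3f}")
    # With eta_base the update's effect on recon grows roughly linearly with m_d;
    # with eta_base / m_d it stays roughly constant.
```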
Output Scaling
Even with the above adjustments, correlations between \(\text{acts}\) and \(W_\text{dec}\) develop during training, causing \(\text{Var}(\text{recon})\) to grow with \(m_d\):
- Only neurons with positive pre-activation inputs become active
- During backpropagation, only active neurons contribute to weight updates
- Certain neurons consistently activate for specific inputs, reinforcing the weight-activation relationship.
- Due to the positive correlation, the variance of \(\text{recon}\) increases.
\[ \text{Var}(\text{recon}) \propto m_d \]
Therefore we apply a scaling factor:
\[ \text{recon} = (\text{acts} \times W_\text{dec} + b_\text{dec}) \times \alpha_\text{output}, \quad \alpha_\text{output} = \frac{1}{m_d} \]
Encoder Weights (W_enc)
Dimensions: \(d_{\text{model}} \times d_{\text{sae}}\)
Forward Pass
Encoder computes pre-activations:
\[ \text{pre\_acts} = \text{input\_acts} \times W_\text{enc} + b_\text{enc} \]
Breaking down the matrix multiply:
\[ \text{pre\_acts}_i = \left( \sum\limits^{d_\text{model}}_{k = 1}{\text{input\_acts}_k \times W_{\text{enc}, k, i}} \right) + b_{\text{enc}, i} \]
That is, each element of \(\text{pre\_acts}\) is a sum over \(d_\text{model}\) terms. Since \(d_\text{model}\) is fixed, scaling \(d_\text{sae}\) does not affect the variance of \(\text{pre\_acts}\):
\[ \text{Var}(\text{pre\_acts}) = \text{Var}(\text{input\_acts}) \times \text{Var}(W_{\text{enc}}) \times d_{\text{model}} \]
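A quick check (again assuming i.i.d. unit-variance Gaussian inputs): the variance of \(\text{pre\_acts}\) is set by \(d_\text{model}\) and does not change as \(d_\text{sae}\) grows.

```python
import torch

torch.manual_seed(0)
d_model, sigma = 768, 0.02
input_acts = torch.randn(256, d_model)  # stand-in for model activations

for d_sae in (768 * 16, 768 * 32, 768 * 64):
    W_enc = torch.randn(d_model, d_sae) * sigma
    pre_acts = input_acts @ W_enc
    print(f"d_sae={d_sae}: Var(pre_acts) ~= {pre_acts.var().item():.3f}")
    # Roughly sigma^2 * d_model ~= 0.31 at every width.
```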
Backwards Pass
While the forward pass variance remains constant, the backwards pass introduces scaling issues.
The gradient with respect to \(W_\text{enc}\):
\[ \nabla_{W_{\text{enc}}}\mathcal{L} = \text{input\_acts}^\top \nabla_{\text{pre\_acts}}\mathcal{L} \]
The gradient \(\nabla_{\text{pre\_acts}}\mathcal{L}\) has a dimension that grows with \(d_\text{sae}\), causing the magnitude of \(\nabla_{W_{\text{enc}}}\mathcal{L}\) to increase with \(m_d\).
To maintain consistent weight updates, we scale learning rate and initialization variance inversely with \(m_d\):
\[ \text{Var}(W_\text{enc}) = \frac{\sigma^2_\text{base}}{m_d} \quad \eta_{W_\text{enc}} = \frac{\eta_\text{base}}{m_d} \]
You can find more detail in Appendix 1 and EleutherAI’s backwards gradient pass.
Encoder Bias (b_enc) & Threshold
\(b_\text{enc}\) and \(\text{threshold}\) are directly connected to \(d_\text{sae}\), so their activations and gradients scale with \(m_d\), similarly to \(W_\text{dec}\); their initialization variance and learning rate are therefore scaled by \(1/m_d\).
Decoder Bias (b_dec)
Connects only to the fixed dimension \(d_{\text{model}}\), so no scaling is required.
Scaling Rules Summary
\[ m_d = \text{expansion\_factor}\ /\ \text{expansion\_factor}_\text{base} \]
| Parameter | Initialization Variance | Learning Rate |
|---|---|---|
| Encoder Weights (\(W_{\text{enc}}\)) | \(\sigma^2_\text{base} / m_d\) | \(\eta_\text{base} / m_d\) |
| Decoder Weights (\(W_{\text{dec}}\)) | \(\sigma^2_\text{base} / m_d\) | \(\eta_\text{base} / m_d\) |
| Encoder Bias (\(b_{\text{enc}}\)) | \(\sigma^2_\text{base} / m_d\) | \(\eta_\text{base} / m_d\) |
| Threshold (\(\text{threshold}\)) | \(\sigma^2_\text{base} / m_d\) | \(\eta_\text{base} / m_d\) |
| Decoder Bias (\(b_{\text{dec}}\)) | \(\sigma^2_{\text{base}}\) | \(\eta_{\text{base}}\) |
| Output Scaling | Multiply output by \(\alpha_{\text{output}} = 1 / m_d\) | N/A |
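For example, scaling from \(\text{expansion\_factor}_\text{base} = 16\) to \(\text{expansion\_factor} = 32\) gives \(m_d = 2\): the encoder and decoder weights, encoder bias, and threshold are initialized with variance \(\sigma^2_\text{base} / 2\) and trained with learning rate \(\eta_\text{base} / 2\), the decoder output is multiplied by \(\alpha_\text{output} = 1/2\), and the decoder bias keeps the base values.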
Updated SAE with μTransfer Scaling
import torch
import torch.nn as nn
import math


class JumpReLUSAE(nn.Module):
    def __init__(self, d_model, expansion_factor, sigma_base=0.02, expansion_factor_base=None):
        super().__init__()
        self.d_model = d_model
        self.expansion_factor = expansion_factor
        self.expansion_factor_base = expansion_factor_base if expansion_factor_base is not None else expansion_factor
        self.m_d = self.expansion_factor / self.expansion_factor_base  # Width multiplier
        self.d_sae = int(self.expansion_factor * self.d_model)
        self.alpha_output = 1 / self.m_d

        # Scale initialization variance inversely with m_d
        sigma_scaled = sigma_base / math.sqrt(self.m_d)

        # Initialize parameters
        self.W_enc = nn.Parameter(torch.randn(d_model, self.d_sae) * sigma_scaled)
        self.W_dec = nn.Parameter(torch.randn(self.d_sae, d_model) * sigma_scaled)
        self.threshold = nn.Parameter(torch.randn(self.d_sae) * sigma_scaled)
        self.b_enc = nn.Parameter(torch.randn(self.d_sae) * sigma_scaled)
        self.b_dec = nn.Parameter(torch.randn(d_model) * sigma_base)

    def encode(self, input_acts):
        pre_acts = input_acts @ self.W_enc + self.b_enc
        mask = (pre_acts > self.threshold)
        acts = mask * torch.relu(pre_acts)
        return acts

    def decode(self, acts):
        recon = acts @ self.W_dec + self.b_dec
        recon = recon * self.alpha_output  # Apply output scaling
        return recon

    def forward(self, input_acts):
        acts = self.encode(input_acts)
        recon = self.decode(acts)
        return recon


d_model = 768  # GPT-2
expansion_factor_base = 16
expansion_factor = 32
sigma_base = 0.02
lr_base = 1e-3

sae = JumpReLUSAE(d_model, expansion_factor, sigma_base, expansion_factor_base)

scaled_params = [sae.W_enc, sae.W_dec, sae.b_enc, sae.threshold]
fixed_params = [sae.b_dec]

optimizer = torch.optim.AdamW([
    {'params': scaled_params, 'lr': lr_base / sae.m_d},
    {'params': fixed_params, 'lr': lr_base},
])
References
Appendices
Appendix 1: Deriving \(W_\text{enc}\) backwards pass
The gradient with respect to \(W_\text{enc}\):
\[ \nabla_{W_{\text{enc}}}\mathcal{L} = \text{input\_acts}^\top \nabla_{\text{pre\_acts}}\mathcal{L} \]
- \(\text{input\_acts}^{\top}\) has dimensions \(d_\text{model} \times \text{batch\_size}\).
- \(\nabla_{\text{pre\_acts}}\mathcal{L}\) has dimensions \(\text{batch\_size} \times d_\text{sae}\).
- \(\implies \nabla_{W_\text{enc}}\mathcal{L}\) has dimensions \(d_\text{model} \times d_\text{sae}\).
Gradient w.r.t \(\text{pre\_acts}\):
\[ \nabla_{\text{pre\_acts}}\mathcal{L} = \nabla_{\text{acts}}\mathcal{L} \odot \nabla_{\text{pre\_acts}}\text{acts} \]
- \(\nabla_{\text{pre\_acts}}\text{acts}\) is a binary mask of active neurons.
Variance of \(\nabla_{\text{pre\_acts}}\mathcal{L}\):
\[ \text{Var}(\nabla_{\text{pre\_acts}}\mathcal{L}) = \text{Var}(\nabla_{\text{acts}}\mathcal{L}) \times p_\text{active} \]
- \(p_\text{active}\) is the probability that a neuron is active (approximately constant).
Dependence on Decoder Weights (\(W_\text{dec}\))
\[ \nabla_{\text{acts}}\mathcal{L} = \nabla_{\text{recon}}\mathcal{L} \times W_\text{dec}^{\top} \]
\(W_\text{dec}^{\top}\) has dimensions \(d_\text{model} \times d_\text{sae}\).
As \(d_\text{sae}\) increases, \(W_\text{dec}\) becomes larger, affecting the variance of \(\nabla_{\text{acts}}\mathcal{L}\) and consequently \(\nabla_{\text{pre\_acts}}\mathcal{L}\).
Variance of Each Element in \(\nabla_{W_\text{enc}}\mathcal{L}\)
- Each element of \(\nabla_{W_\text{enc}}\mathcal{L}\) is computed as:
\[ [\nabla_{W_\text{enc}}\mathcal{L}]_{ij} = \sum_{n=1}^{\text{batch\_size}} \text{input\_acts}_{ni} \times [\nabla_{\text{pre\_acts}}\mathcal{L}]_{nj} \]
The variance of each element depends on \(\text{Var}(\text{input\_acts})\) and \(\text{Var}(\nabla_{\text{pre\_acts}}\mathcal{L})\).
Since \(\text{Var}(\nabla_{\text{pre\_acts}}\mathcal{L})\) increases with \(d_\text{sae}\) (due to \(W_\text{dec}\)), the variance of each element in \(\nabla_{W_\text{enc}}\mathcal{L}\) also increases with \(d_\text{sae}\).
Scaling Total Number of Elements
The total number of elements in \(\nabla_{W_\text{enc}}\mathcal{L}\) is \(d_\text{model} \times d_\text{sae}\) (derived above). As \(d_\text{sae}\) increases, the number of elements increases proportionally, scaling the overall magnitude of the gradient.
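For concreteness, a quick shape check of the quantities above (batch size of 8 chosen arbitrarily):

```python
import torch

batch_size, d_model, d_sae = 8, 768, 768 * 16
input_acts = torch.randn(batch_size, d_model)
grad_pre_acts = torch.randn(batch_size, d_sae)  # stand-in for dL/d(pre_acts)

grad_W_enc = input_acts.T @ grad_pre_acts       # (d_model x batch) @ (batch x d_sae)
print(grad_W_enc.shape)                         # torch.Size([768, 12288]) = d_model x d_sae
```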