Sparse Autoencoders (SAEs) are a promising unsupervised approach for identifying meaningful directions in the internal activations of a language model.
They’ve been shown to often find causally relevant, interpretable directions.
They still have some open problems, such as:
- validating SAEs as an approach;
- learning how to measure their performance;
- learning how to train SAEs at scale.
Formally
Given activations $\mathbf{x} \in \mathbb{R}^n$ from a language model, a Sparse Autoencoder (SAE) decomposes and reconstructs the activations using a pair of encoder and decoder functions $(f, \hat{\mathbf{x}})$ defined by:

$$ f(\mathbf{x}) := \sigma\!\left(\mathbf{W}_{\text{enc}} \mathbf{x} + \mathbf{b}_{\text{enc}}\right), $$
$$ \hat{\mathbf{x}}(f) := \mathbf{W}_{\text{dec}} f + \mathbf{b}_{\text{dec}}. $$

These functions are trained so that $\hat{\mathbf{x}}(f(\mathbf{x}))$ maps back to $\mathbf{x}$, making them an autoencoder.
Thus, $f(\mathbf{x}) \in \mathbb{R}^M$ is a set of linear weights that specify how to combine the columns of $\mathbf{W}_{\text{dec}}$ to reproduce $\mathbf{x}$.
The columns of $\mathbf{W}_{\text{dec}}$, which we denote by $\mathbf{d}_i$ for $i = 1, \dots, M$, represent the dictionary of directions into which the SAE decomposes $\mathbf{x}$. These are referred to as latents, to disambiguate between the learnt features and the conceptual features which are hypothesized to comprise the language model’s representation vectors.
The decomposition $f(\mathbf{x})$ is made non-negative and sparse through the choice of activation function $\sigma$ and appropriate regularization, such that $f(\mathbf{x})$ typically has far fewer than $n$ non-zero entries.[^1]
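To make the definitions concrete, here is a minimal PyTorch sketch of such an encoder/decoder pair. The class name, initialization, and the use of a plain ReLU for $\sigma$ are my own simplifications, not the Gemma Scope implementation:

```python
import torch
import torch.nn as nn

class SparseAutoencoder(nn.Module):
    """Minimal SAE sketch: encode n-dim activations into M sparse latents, then reconstruct."""

    def __init__(self, n: int, m: int):
        super().__init__()
        # Columns of W_dec are the dictionary directions d_i.
        self.W_enc = nn.Parameter(torch.randn(m, n) * 0.01)
        self.b_enc = nn.Parameter(torch.zeros(m))
        self.W_dec = nn.Parameter(torch.randn(n, m) * 0.01)
        self.b_dec = nn.Parameter(torch.zeros(n))

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        # f(x) = sigma(W_enc x + b_enc); a plain ReLU stands in for sigma here.
        return torch.relu(x @ self.W_enc.T + self.b_enc)

    def decode(self, f: torch.Tensor) -> torch.Tensor:
        # x_hat(f) = W_dec f + b_dec: a sparse linear combination of the dictionary columns.
        return f @ self.W_dec.T + self.b_dec

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.decode(self.encode(x))
```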
JumpReLU SAEs
In @lieberumGemmaScopeOpen2024 the authors focus on JumpReLU SAEs, as these have been shown to be a slight Pareto improvement[^2] over other approaches, and they allow for a variable number of active latents at different tokens (unlike TopK SAEs).
JumpReLU activation. The JumpReLU activation uses a shifted Heaviside step function as a gating mechanism together with a conventional ReLU:

$$ \sigma(\mathbf{z}) = \operatorname{JumpReLU}_{\boldsymbol{\theta}}(\mathbf{z}) := \mathbf{z} \odot H(\mathbf{z} - \boldsymbol{\theta}). $$

Here, $\boldsymbol{\theta} > 0$ is the JumpReLU’s vector-valued learnable threshold parameter, $\odot$ denotes element-wise multiplication, and $H$ is the Heaviside step function, which is 1 if its input is positive and 0 otherwise.
Intuitively, the JumpReLU leaves pre-activations unchanged above the threshold $\boldsymbol{\theta}$, but sets them to zero below it, with a different threshold per latent. So it is essentially a ReLU, but instead of zeroing everything negative, it zeroes everything below $\theta_i$, which is learnt separately for every latent dimension.
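A naive sketch of this activation, with a per-latent threshold vector `theta` (the function name and example values are mine, for illustration only). Note that, written this way, the hard gate passes no gradient to the thresholds, which is exactly the problem the straight-through estimators below address:

```python
import torch

def jump_relu(z: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
    """JumpReLU sketch: keep pre-activations above their per-latent threshold, zero the rest."""
    # Heaviside gate H(z - theta): 1 where z > theta, 0 otherwise (element-wise).
    gate = (z > theta).to(z.dtype)
    # The gate is piece-wise constant in theta, so no gradient reaches theta here.
    return z * gate

# With positive thresholds [0.5, 0.5, 2.0], the last latent is zeroed even though
# a plain ReLU would have kept it.
z = torch.tensor([1.0, 0.3, 1.0])
theta = torch.tensor([0.5, 0.5, 2.0])
print(jump_relu(z, theta))  # tensor([1., 0., 0.])
```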
Loss Function. The loss function combines a squared error reconstruction loss with an L0 penalty that regularizes the number of active (non-zero) latents:

$$ \mathcal{L} := \underbrace{\left\lVert \mathbf{x} - \hat{\mathbf{x}}(f(\mathbf{x})) \right\rVert_2^2}_{\mathcal{L}_{\text{reconstruct}}} + \underbrace{\lambda \left\lVert f(\mathbf{x}) \right\rVert_0}_{\mathcal{L}_{\text{sparsity}}}, $$

where $\lambda$ is the sparsity penalty coefficient. Since the L0 penalty and the JumpReLU activation function are piece-wise constant with respect to the threshold parameter $\boldsymbol{\theta}$, they use straight-through estimators (STEs) to train $\boldsymbol{\theta}$ (Rajamanoharan et al., 2024b). This introduces an additional hyper-parameter, the kernel density estimator bandwidth $\varepsilon$, which controls the quality of the gradient estimates used to train the threshold parameters.[^3]
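As a rough sketch of this objective (the function name and the default value of `lam` are illustrative, and the literal L0 term written below is piece-wise constant, hence the need for the STE machinery described above):

```python
import torch

def sae_loss(x: torch.Tensor, x_hat: torch.Tensor, f: torch.Tensor, lam: float = 1e-3) -> torch.Tensor:
    """Squared-error reconstruction loss plus a lambda-weighted L0 sparsity penalty (sketch)."""
    recon = (x - x_hat).pow(2).sum(dim=-1)    # ||x - x_hat||_2^2 per example
    l0 = (f != 0).to(x.dtype).sum(dim=-1)     # ||f(x)||_0: number of active latents per example
    return (recon + lam * l0).mean()          # average over the batch
```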
Footnotes
[^1]: Initial work (Cunningham et al., 2023; Bricken et al., 2023) used a ReLU activation function to enforce non-negativity and an L1 penalty on the decomposition $f(\mathbf{x})$ to encourage sparsity. JumpReLU enforces sparsity by zeroing out all entries of $f(\mathbf{x})$ below a positive threshold. TopK SAEs instead zero out all but the top K entries of $f(\mathbf{x})$.
[^2]: A Pareto improvement is a change in the allocation of goods that harms no one and benefits at least one person. Pareto improvements are also referred to as “no-brainers” and are generally expected to be rare, due to the obvious and powerful incentive to make any available Pareto improvement.
[^3]: A large value of $\varepsilon$ results in biased but low-variance estimates, leading to SAEs with good sparsity but sub-optimal fidelity, whereas a low value of $\varepsilon$ results in high-variance estimates that cause the threshold to fail to train at all, resulting in SAEs that fail to be sparse. The Gemma Scope authors find, through hyperparameter sweeps across multiple layers and sites, that $\varepsilon = 0.001$ provides a good trade-off (when SAE inputs are normalized to have unit mean squared norm) and use this value to train the SAEs released as part of Gemma Scope.