# Sparse Autoencoders for Gemma-3-27b-it
This repository contains 9 Sparse Autoencoders (SAEs) trained on google/gemma-3-27b-it using the BatchTopK architecture.
## Architecture: BatchTopK SAE
These SAEs use the BatchTopK architecture, which enforces sparsity by:

- Computing feature activations: z = W_enc·x + b_enc (encoder)
- Selecting the top-k feature activations across the batch (not per sample)
- Reconstructing: x̂ = W_dec·z_topk + b_dec (decoder)
This approach tends to produce more interpretable features than ReLU-based SAEs and has better training dynamics.
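The batch-level selection is the key difference from per-sample TopK SAEs and is easiest to see in code. Below is a minimal sketch of a BatchTopK forward pass (illustrative only, not the training implementation; weight shapes follow the tensors described in the Usage section):

```python
import torch
import torch.nn.functional as F

def batchtopk_forward(x, W_enc, b_enc, W_dec, b_dec, k):
    """Minimal BatchTopK sketch: keep the k * batch_size largest feature
    activations across the whole batch, rather than k per sample.

    x:     [batch, activation_dim]
    W_enc: [dict_size, activation_dim], b_enc: [dict_size]
    W_dec: [activation_dim, dict_size], b_dec: [activation_dim]
    """
    z = F.linear(x, W_enc, b_enc)            # [batch, dict_size]
    # (Some implementations apply a ReLU to z before selection.)
    batch_size = x.shape[0]
    flat = z.flatten()
    top_vals, top_idx = torch.topk(flat, k * batch_size)
    z_topk = torch.zeros_like(flat)
    z_topk[top_idx] = top_vals
    z_topk = z_topk.reshape(batch_size, -1)  # sparse codes, ~k active per sample on average
    x_hat = F.linear(z_topk, W_dec, b_dec)   # [batch, activation_dim]
    return z_topk, x_hat
```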
## Repository Structure

```
layer_45/
  dict_16k_k80/           # 16,384 features, k=80
    ae.pt                 # SAE weights
    config.json           # Training configuration
    feature_labels.json   # Natural language feature descriptions
  dict_16k_k160/          # 16,384 features, k=160
  dict_65k_k80/           # 65,536 features, k=80
  dict_65k_k160/          # 65,536 features, k=160
layer_47/
  (same structure)
layer_45_mlp/
  (same structure)
```
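Since every layer directory follows the same naming scheme, candidate paths can be enumerated programmatically. A small sketch (not every combination is necessarily uploaded for every layer, so check the repository file listing):

```python
# Enumerate candidate SAE paths following the directory layout above.
layers = ["layer_45", "layer_47", "layer_45_mlp"]
dict_names = ["dict_16k", "dict_65k"]
ks = [80, 160]

paths = [
    f"{layer}/{dict_name}_k{k}/ae.pt"
    for layer in layers
    for dict_name in dict_names
    for k in ks
]
print(paths[0])  # layer_45/dict_16k_k80/ae.pt
```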
## Available SAEs

| Layer | Dict Size | k | Activation Dim | Parameters | Sparsity (k / dict size) |
|---|---|---|---|---|---|
| 45 | 16,384 | 80 | 5,376 | 176,182,528 | 0.49% |
| 45 | 16,384 | 160 | 5,376 | 176,182,528 | 0.98% |
| 45 | 65,536 | 80 | 5,376 | 704,713,984 | 0.12% |
| 45 | 65,536 | 160 | 5,376 | 704,713,984 | 0.24% |
| 47 | 16,384 | 80 | 5,376 | 176,182,528 | 0.49% |
| 47 | 16,384 | 160 | 5,376 | 176,182,528 | 0.98% |
| 47 | 65,536 | 80 | 5,376 | 704,713,984 | 0.12% |
| 47 | 65,536 | 160 | 5,376 | 704,713,984 | 0.24% |
Total Parameters: 3,523,586,048
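The per-SAE parameter counts in the table follow directly from the weight shapes: an encoder of size dict_size × activation_dim plus its bias, a decoder of the transposed shape, and the decoder bias. A quick check (the learned threshold is not counted here):

```python
def sae_param_count(dict_size: int, activation_dim: int) -> int:
    """Encoder weight + bias, decoder weight + bias."""
    encoder = dict_size * activation_dim + dict_size
    decoder = activation_dim * dict_size + activation_dim
    return encoder + decoder

print(sae_param_count(16_384, 5_376))  # 176182528
print(sae_param_count(65_536, 5_376))  # 704713984
```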
## Model Details

### Training Details

- Base Model: google/gemma-3-27b-it
- Hook Point: residual_stream (post-layer activations)
- Dataset: FineWeb (HuggingFaceFW/fineweb)
Training Hyperparameters:
- Optimizer: Adam
- Learning rate: 5e-5
- Warmup steps: 1,000
- Training steps: ~244,140
- Context length: 2,048 tokens
- Batch size: 2,048 activations
- Decay start: 195,312 steps
BatchTopK Parameters (the sketch after this list illustrates how the threshold decay is typically applied):
- Auxiliary loss coefficient (α): 0.03125
- Threshold decay (β): 0.999
- Threshold start step: 1,000
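These threshold parameters govern the inference-time activation threshold stored in `ae.pt`. As a rough illustration (the exact update rule in the training code may differ), the threshold is typically maintained as an exponential moving average of the smallest activation kept by the batch-level top-k, starting after the threshold start step:

```python
# Illustrative threshold update (assumed form, not copied from the training code).
BETA = 0.999          # threshold decay
START_STEP = 1_000    # threshold start step

def update_threshold(threshold, z_topk, step):
    if step < START_STEP:
        return threshold                    # threshold not tracked yet
    min_kept = z_topk[z_topk > 0].min()     # smallest activation that survived top-k
    if threshold is None:
        return min_kept
    return BETA * threshold + (1 - BETA) * min_kept
```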
Sparsity Levels:
- k=80: Higher sparsity, more selective features
- k=160: Lower sparsity, more features active per sample
Dictionary Sizes:
- 16,384: Compact, efficient, good for resource-constrained analysis
- 65,536: Comprehensive, captures more fine-grained patterns
## Feature Labels

This repository includes natural language descriptions for all features, generated with an LLM-as-a-judge setup (GPT-4) run over maximum-activating examples. Each feature entry has the following fields (an illustrative entry is sketched after this list):
- Title: Short description of what the feature detects
- Description: Detailed explanation with examples
- Examples: Token sequences that maximally activate the feature
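The exact JSON schema is not documented here; based on the fields above and the loading code below, an entry keyed by the feature index is expected to look roughly like this (values invented for illustration):

```python
# Illustrative feature_labels.json entry (field names inferred from the
# loading code in the Usage section; actual files may contain more fields).
example_entry = {
    "1234": {
        "title": "European capital cities",
        "description": "Fires on tokens naming capital cities in geographic or travel contexts.",
        "examples": ["The capital of France is Paris", "Berlin, the German capital, ..."],
    }
}
```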
## Usage

### Installation

```bash
pip install torch transformers huggingface_hub
```

### Loading an SAE
```python
import json
import torch
from huggingface_hub import hf_hub_download

# Download a specific SAE
ae_path = hf_hub_download(
    repo_id="uzaymacar/gemma-3-27b-saes",
    filename="layer_45/dict_16k_k80/ae.pt",
)
config_path = hf_hub_download(
    repo_id="uzaymacar/gemma-3-27b-saes",
    filename="layer_45/dict_16k_k80/config.json",
)

# Load SAE weights and training configuration
ae_data = torch.load(ae_path, map_location='cpu')
with open(config_path, 'r') as f:
    config = json.load(f)

print(f"Loaded SAE with {config['trainer']['dict_size']} features")
print(f"Activation dimension: {config['trainer']['activation_dim']}")
print(f"Top-k: {config['trainer']['k']}")

# SAE weights
encoder_weight = ae_data['encoder.weight']  # [dict_size, activation_dim]
encoder_bias = ae_data['encoder.bias']      # [dict_size]
decoder_weight = ae_data['decoder.weight']  # [activation_dim, dict_size]
decoder_bias = ae_data['b_dec']             # [activation_dim]
threshold = ae_data['threshold']            # Learned activation threshold
```
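The same pattern works for any of the variants listed above. A small convenience wrapper (a hypothetical helper, not part of this repository) keeps the repo ID and paths in one place:

```python
import json
import torch
from huggingface_hub import hf_hub_download

REPO_ID = "uzaymacar/gemma-3-27b-saes"

def load_sae(layer: str = "layer_45", variant: str = "dict_16k_k80"):
    """Download and load one SAE variant (hypothetical convenience helper)."""
    ae_path = hf_hub_download(repo_id=REPO_ID, filename=f"{layer}/{variant}/ae.pt")
    config_path = hf_hub_download(repo_id=REPO_ID, filename=f"{layer}/{variant}/config.json")
    ae_data = torch.load(ae_path, map_location="cpu")
    with open(config_path) as f:
        config = json.load(f)
    return ae_data, config

ae_data, config = load_sae("layer_47", "dict_65k_k160")
```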
### Using the SAE for Analysis

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F

# Load the base model
model_name = "google/gemma-3-27b-it"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map='auto'
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Get activations from layer 45
text = "The capital of France is Paris"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)
    # hidden_states[0] is the embedding output; adjust the index if your layer-numbering convention differs
    layer_45_acts = outputs.hidden_states[45]  # [batch, seq, activation_dim]

# Flatten and move the activations to the SAE's device/dtype before encoding
acts_flat = layer_45_acts.reshape(-1, layer_45_acts.shape[-1])  # [batch*seq, dim]
acts_flat = acts_flat.to(encoder_weight.device, encoder_weight.dtype)

# Encoder: z = W_enc x + b_enc
z = F.linear(acts_flat, encoder_weight, encoder_bias)  # [batch*seq, dict_size]

# Top-k selection (per sample here; during training, BatchTopK selects across the batch)
top_k = config['trainer']['k']
top_values, top_indices = torch.topk(z, k=top_k, dim=-1)

# Create the sparse representation
z_topk = torch.zeros_like(z)
z_topk.scatter_(-1, top_indices, top_values)

# Decode: x̂ = W_dec z_topk + b_dec
# (decoder_weight is already [activation_dim, dict_size], so no transpose is needed)
reconstructed = F.linear(z_topk, decoder_weight, decoder_bias)

# Compute reconstruction loss
mse_loss = F.mse_loss(reconstructed, acts_flat)
print(f"Reconstruction MSE: {mse_loss.item():.6f}")

# Find active features
active_features = top_indices[0]  # First token's active feature indices
print(f"Active features for first token: {active_features.tolist()}")
```
### Loading Feature Labels

```python
import json
from huggingface_hub import hf_hub_download

# Download the feature labels for one SAE variant
labels_path = hf_hub_download(
    repo_id="uzaymacar/gemma-3-27b-saes",
    filename="layer_45/dict_16k_k80/feature_labels.json",
)
with open(labels_path, 'r') as f:
    labels = json.load(f)

# Examine a specific feature
feature_id = 1234
if str(feature_id) in labels:
    label = labels[str(feature_id)]
    print(f"Feature {feature_id}:")
    print(f"  Title: {label.get('title', 'N/A')}")
    print(f"  Description: {label.get('description', 'N/A')}")
```
## Citation

If you use these SAEs in your research, please cite:

```bibtex
@software{gemma3_27b_saes,
  author = {Macar, Uzay},
  title = {Sparse Autoencoders for Gemma-3-27b-it},
  year = {2024},
  url = {https://huggingface.co/uzaymacar/gemma-3-27b-saes}
}
```
SAE Training Framework:
```bibtex
@software{dictionary_learning,
  author = {Marks, Samuel and others},
  title = {Dictionary Learning for Mechanistic Interpretability},
  year = {2024},
  url = {https://github.com/saprmarks/dictionary_learning}
}
```
BatchTopK Architecture:
```bibtex
@article{gao2024batchTopK,
  title = {Scaling and evaluating sparse autoencoders},
  author = {Gao, Leo and others},
  journal = {arXiv preprint arXiv:2406.04093},
  year = {2024}
}
```
## License
These SAEs are released under the same license as the base model (google/gemma-3-27b-it).
## Acknowledgments
- Trained using dictionary_learning
- Base model: google/gemma-3-27b-it
- Training data: FineWeb
## Contact
For questions or issues, please contact me at [email protected]
Note: These SAEs are research artifacts. While they provide valuable insights into model representations, they should be used as one tool among many for interpretability research.