Sparse Autoencoders for Gemma-3-27b-it

This repository contains 9 Sparse Autoencoders (SAEs) trained on google/gemma-3-27b-it using the BatchTopK architecture.

Architecture: BatchTopK SAE

These SAEs use the BatchTopK architecture, which enforces sparsity by:

  1. Computing feature pre-activations with the encoder: z = W_enc x + b_enc
  2. Selecting the top k activations across the entire batch (rather than per sample)
  3. Reconstructing with the decoder: x̂ = W_dec z_topk + b_dec

This approach tends to produce more interpretable features than ReLU-based SAEs and has better training dynamics.
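
For intuition, here is a minimal sketch of the batch-level selection step (the function name and the toy shapes are illustrative, not the training code):

import torch

def batch_topk(z: torch.Tensor, k: int) -> torch.Tensor:
    """Keep the k * batch_size largest activations across the whole batch; zero the rest."""
    batch_size, dict_size = z.shape
    flat = z.flatten()
    top_values, top_indices = torch.topk(flat, k=k * batch_size)
    sparse = torch.zeros_like(flat)
    sparse[top_indices] = top_values
    return sparse.reshape(batch_size, dict_size)

# Toy example: 4 samples, 16 features, an average of k=2 active features per sample
z = torch.relu(torch.randn(4, 16))
z_topk = batch_topk(z, k=2)
print((z_topk != 0).sum(dim=-1))  # per-sample counts vary; the batch total is at most k * batch_size

During training this keeps the average number of active features per sample at k while letting individual samples use more or fewer.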

Repository Structure

layer_45/
  dict_16k_k80/       # 16,384 features, k=80
    ae.pt             # SAE weights
    config.json       # Training configuration
    feature_labels.json  # Natural language feature descriptions
  dict_16k_k160/      # 16,384 features, k=160
  dict_65k_k80/       # 65,536 features, k=80
  dict_65k_k160/      # 65,536 features, k=160
layer_47/
  (same structure)
layer_45_mlp/
  (same structure)

Available SAEs

| Layer | Dict Size | k | Activation Dim | Parameters | Sparsity |
|-------|-----------|-----|----------------|-------------|----------|
| 45 | 16,384 | 80 | 5,376 | 176,182,528 | 0.49% |
| 45 | 16,384 | 160 | 5,376 | 176,182,528 | 0.98% |
| 45 | 65,536 | 80 | 5,376 | 704,713,984 | 0.12% |
| 45 | 65,536 | 160 | 5,376 | 704,713,984 | 0.24% |
| 47 | 16,384 | 80 | 5,376 | 176,182,528 | 0.49% |
| 47 | 16,384 | 160 | 5,376 | 176,182,528 | 0.98% |
| 47 | 65,536 | 80 | 5,376 | 704,713,984 | 0.12% |
| 47 | 65,536 | 160 | 5,376 | 704,713,984 | 0.24% |

Total Parameters: 3,523,586,048
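
The per-SAE parameter counts and sparsity figures in the table follow directly from the dictionary shapes; a quick sanity check:

d_model, dict_size, k = 5376, 16384, 80

encoder_params = dict_size * d_model + dict_size   # W_enc + b_enc
decoder_params = d_model * dict_size + d_model     # W_dec + b_dec
print(encoder_params + decoder_params)             # 176182528, matching the table
print(f"{k / dict_size:.2%}")                      # 0.49% of features active per token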

Model Details

Training Details

Base Model: google/gemma-3-27b-it

Hook Point: residual_stream (post-layer activations)

Dataset: FineWeb (HuggingFaceFW/fineweb)

Training Hyperparameters:

  • Optimizer: Adam
  • Learning rate: 5e-5
  • Warmup steps: 1,000
  • Training steps: ~244,140
  • Context length: 2,048 tokens
  • Batch size: 2,048 activations
  • Decay start: 195,312 steps

BatchTopK Parameters:

  • Auxiliary loss coefficient (α): 0.03125
  • Threshold decay (β): 0.999
  • Threshold start step: 1,000
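
For intuition only, here is a hedged sketch of how such an inference threshold is commonly maintained during training: an exponential moving average (decay β) of the smallest activation kept by the batch-level top-k, updated only after the start step. The exact training code may differ:

def update_threshold(threshold, z_topk, step, beta=0.999, start_step=1000):
    """Illustrative EMA update for the inference threshold (not the exact training code)."""
    if step < start_step:
        return threshold
    kept = z_topk[z_topk > 0]
    if kept.numel() == 0:
        return threshold
    min_kept = kept.min().item()
    if threshold is None:
        return min_kept
    return beta * threshold + (1 - beta) * min_kept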

Sparsity Levels:

  • k=80: Higher sparsity, more selective features
  • k=160: Lower sparsity, more features active per sample

Dictionary Sizes:

  • 16,384: Compact, efficient, good for resource-constrained analysis
  • 65,536: Comprehensive, captures more fine-grained patterns

Feature Labels

This repository includes natural language descriptions for all features, generated with GPT-4 (LLM-as-a-judge) from each feature's maximum-activating examples. Each feature entry has:

  • Title: Short description of what the feature detects
  • Description: Detailed explanation with examples
  • Examples: Token sequences that maximally activate the feature
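
The exact schema is defined by the JSON files themselves; based on the fields above and the loading example below, an entry looks roughly like this (the values, and any key other than title/description, are illustrative):

example_labels = {
    "1234": {
        "title": "European capital cities",                     # short description
        "description": "Activates on tokens naming European capitals (e.g. 'Paris').",
        "examples": ["The capital of France is Paris", "..."],  # key name assumed
    }
}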

Usage

Installation

pip install torch transformers huggingface_hub

Loading an SAE

import json
import torch
from huggingface_hub import hf_hub_download

# Download specific SAE
ae_path = hf_hub_download(
    repo_id="uzaymacar/gemma-3-27b-saes",
    filename="layer_45/dict_16k_k80/ae.pt",
)

config_path = hf_hub_download(
    repo_id="uzaymacar/gemma-3-27b-saes",
    filename="layer_45/dict_16k_k80/config.json",
)

# Load SAE
ae_data = torch.load(ae_path, map_location='cpu')
with open(config_path, 'r') as f:
    config = json.load(f)

print(f"Loaded SAE with {config['trainer']['dict_size']} features")
print(f"Activation dimension: {config['trainer']['activation_dim']}")
print(f"Top-k: {config['trainer']['k']}")

# SAE weights
encoder_weight = ae_data['encoder.weight']  # [dict_size, activation_dim]
encoder_bias = ae_data['encoder.bias']      # [dict_size]
decoder_weight = ae_data['decoder.weight']  # [activation_dim, dict_size]
decoder_bias = ae_data['b_dec']             # [activation_dim]
threshold = ae_data['threshold']            # Learned threshold
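
The saved threshold allows thresholded inference in place of an explicit top-k (the usual way BatchTopK SAEs are applied at test time). A minimal sketch using the tensors loaded above, assuming the threshold broadcasts over the feature dimension:

import torch
import torch.nn.functional as F

def sae_forward(x, encoder_weight, encoder_bias, decoder_weight, decoder_bias, threshold):
    """Thresholded (JumpReLU-style) encode/decode using the learned threshold."""
    z = F.linear(x, encoder_weight, encoder_bias)            # [n, dict_size]
    z = torch.where(z > threshold, z, torch.zeros_like(z))   # keep activations above threshold
    x_hat = F.linear(z, decoder_weight, decoder_bias)        # [n, activation_dim]
    return z, x_hat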

Using the SAE for Analysis

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F

# Load base model
model_name = "google/gemma-3-27b-it"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map='auto'
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Get activations from layer 45
text = "The capital of France is Paris"
inputs = tokenizer(text, return_tensors="pt").to(model.device)

with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)
    # hidden_states[0] is the embedding output; depending on your layer-numbering
    # convention, the post-layer-45 residual stream may be hidden_states[46]
    layer_45_acts = outputs.hidden_states[45]  # [batch, seq, activation_dim]

# Encode with SAE
# Flatten and match the SAE weights' device and dtype (the model runs in bfloat16,
# while the SAE tensors above were loaded on CPU, typically in float32)
acts_flat = layer_45_acts.reshape(-1, layer_45_acts.shape[-1]).to(encoder_weight.device, encoder_weight.dtype)  # [batch*seq, dim]

# Encoder: z = Wx + b
z = F.linear(acts_flat, encoder_weight, encoder_bias)  # [batch*seq, dict_size]

# Top-k selection (applied per token here; training selects top-k across the batch,
# and at inference you can instead use the learned threshold from the checkpoint)
top_k = config['trainer']['k']
top_values, top_indices = torch.topk(z, k=top_k, dim=-1)

# Create sparse representation
z_topk = torch.zeros_like(z)
z_topk.scatter_(-1, top_indices, top_values)

# Decode: x̂ = W_dec z_topk + b_dec
reconstructed = F.linear(z_topk, decoder_weight, decoder_bias)  # decoder_weight is [activation_dim, dict_size]

# Compute reconstruction loss
mse_loss = F.mse_loss(reconstructed, acts_flat)
print(f"Reconstruction MSE: {mse_loss.item():.6f}")

# Find active features
active_features = top_indices[0]  # Active feature indices for the first token
print(f"Active features for first token: {active_features.tolist()}")

Loading Feature Labels

import json
from huggingface_hub import hf_hub_download

# Download feature labels
labels_path = hf_hub_download(
    repo_id="uzaymacar/gemma-3-27b-saes",
    filename="layer_45/dict_16k_k80/feature_labels.json",
)

with open(labels_path, 'r') as f:
    labels = json.load(f)

# Examine a specific feature
feature_id = 1234
if str(feature_id) in labels:
    label = labels[str(feature_id)]
    print(f"Feature {feature_id}:")
    print(f"  Title: {label.get('title', 'N/A')}")
    print(f"  Description: {label.get('description', 'N/A')}")

Citation

If you use these SAEs in your research, please cite:

@software{gemma3_27b_saes,
  author = {Macar, Uzay},
  title = {Sparse Autoencoders for Gemma-3-27b-it},
  year = {2024},
  url = {https://huggingface.co/uzaymacar/gemma-3-27b-saes}
}

SAE Training Framework:

@software{dictionary_learning,
  author = {Marks, Samuel and others},
  title = {Dictionary Learning for Mechanistic Interpretability},
  year = {2024},
  url = {https://github.com/saprmarks/dictionary_learning}
}

TopK SAE Architecture (which BatchTopK extends):

@article{gao2024batchTopK,
  title={Scaling and evaluating sparse autoencoders},
  author={Gao, Leo and others},
  journal={arXiv preprint arXiv:2406.04093},
  year={2024}
}

License

These SAEs are released under the same license as the base model (google/gemma-3-27b-it).

Contact

For questions or issues, please contact me at [email protected]


Note: These SAEs are research artifacts. While they provide valuable insights into model representations, they should be used as one tool among many for interpretability research.
