# LeanLlama-70B-Instruct-bnb-4bit
Llama 3.3 70B Instruct (4-bit quantized) with 256x KV cache compression on 25% of layers (20 out of 80), reducing KV cache memory for those layers while maintaining model quality.
## What is this?
This model applies learned linear autoencoders to compress the key-value cache in 20 of Llama's 80 transformer layers. Each KV vector (1024 dims = 8 GQA heads x 128 head_dim) is compressed to just 4 dimensions via an encoder, then reconstructed through a 2-layer decoder (4 -> 128 -> 1024) with GELU activation. Compression happens per-token at inference time with no changes to the attention mechanism itself.
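The autoencoder described above can be sketched in a few lines of PyTorch. This is an illustrative reconstruction from the stated shapes, not the shipped module (the class and argument names here are assumptions):

```python
# Sketch of the per-layer KV autoencoder: encoder Linear(1024 -> 4),
# decoder 4 -> 128 -> 1024 with GELU, applied independently to K and V states.
import torch
import torch.nn as nn

class KVCompressor(nn.Module):
    def __init__(self, kv_dim=1024, latent_dim=4, hidden_dim=128):
        super().__init__()
        self.encoder = nn.Linear(kv_dim, latent_dim)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, kv_dim),
        )

    def forward(self, kv):  # kv: (batch, seq_len, 1024)
        return self.decoder(self.encoder(kv))

comp = KVCompressor()
kv = torch.randn(1, 16, 1024)      # 16 tokens of K (or V) states
latent = comp.encoder(kv)          # (1, 16, 4): 256x fewer dimensions
restored = comp.decoder(latent)    # (1, 16, 1024): reconstruction fed to attention
print(latent.shape, restored.shape)
```

At inference only the 4-dim latent needs to live in the cache; the decoder reconstructs the full vector when the layer's attention runs.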
The base weights are from unsloth/Llama-3.3-70B-Instruct-bnb-4bit (NF4 quantization with double quantization).
Key results (measured on WikiText-2):
- Baseline perplexity (4-bit): 5.44
- With compression: 7.55 (+2.11 points)
- Compressed layers: 40, 42, 45-51, 53-63 (upper half of the network)
- Model size: 36.8 GB (4-bit weights + compressor weights)
These layers were selected through a per-layer compressibility sweep across all 80 layers. The upper layers (40+) have inherently low-dimensional KV subspaces and tolerate aggressive compression with minimal impact on output quality.
## Usage

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "miike-ai/LeanLlama-70B-Instruct-bnb-4bit",
    trust_remote_code=True,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("miike-ai/LeanLlama-70B-Instruct-bnb-4bit")

messages = [{"role": "user", "content": "What is the theory of relativity?"}]
inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True)
inputs = inputs.to(model.device)
outputs = model.generate(inputs, max_new_tokens=200, do_sample=True, temperature=0.7)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
## How it works
Each compressed layer has two small autoencoder modules attached to the self-attention block:
- K compressor: `Linear(1024, 4)` encoder + `Linear(4, 128) -> GELU -> Linear(128, 1024)` decoder
- V compressor: same architecture
After each forward pass, the model intercepts the KV cache and compresses the new tokens in the configured layers. The compressors use _SafeLinear (a custom nn.Module) instead of nn.Linear so that bitsandbytes quantization leaves them untouched — compressor weights stay in full precision.
The compressor weights add only 21 MB to the model (roughly 5.5 million parameters across the 20 layers).
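The ~21 MB overhead follows directly from the stated layer shapes. A quick sanity check, assuming one K and one V compressor per layer, each linear carrying a bias, and full-precision (fp32) storage:

```python
# Parameter count implied by the architecture above:
# encoder Linear(1024, 4), decoder Linear(4, 128) + Linear(128, 1024).
def linear_params(n_in, n_out):
    return n_in * n_out + n_out  # weights + bias

per_compressor = (
    linear_params(1024, 4)       # encoder
    + linear_params(4, 128)      # decoder, first layer
    + linear_params(128, 1024)   # decoder, second layer
)
per_layer = 2 * per_compressor   # K and V compressors
total = 20 * per_layer           # 20 compressed layers
print(total)                     # -> 5473440 (~5.5M parameters)
print(total * 4 / 1e6)           # ~21.9 MB at 4 bytes/param (fp32)
```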
## Layer selection methodology
All 80 layers were individually tested with compression to measure per-layer perplexity impact. Layers were added progressively in order of compressibility:
| Layers compressed | PPL increase | Status |
|---|---|---|
| 9 layers (safest) | +1.32 pts | Negligible impact |
| 12 layers | +2.31 pts | Minor impact |
| 15 layers | +6.16 pts | Errors compounding |
| 20 layers | +2.11 pts | Optimized selection |
The final 20-layer configuration was optimized to stay under +2.5 PPL points by selecting only layers with independently low compression cost from the upper half of the network (layers 40-63).
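The selection procedure can be sketched as a greedy sweep: rank layers by their standalone perplexity cost, then add the cheapest ones while staying under the budget. This is an illustrative sketch with made-up per-layer costs, not the actual sweep code; as the table shows, errors compound, so the real procedure must re-measure combined perplexity rather than just summing standalone costs:

```python
# Greedy budget-constrained layer selection (hypothetical costs).
def select_layers(per_layer_ppl_cost, budget):
    """Pick layers cheapest-first while the summed standalone cost
    stays under budget. Real sweeps re-measure after each addition,
    since per-layer costs do not add up exactly."""
    chosen, spent = [], 0.0
    for layer, cost in sorted(per_layer_ppl_cost.items(), key=lambda kv: kv[1]):
        if spent + cost <= budget:
            chosen.append(layer)
            spent += cost
    return sorted(chosen), spent

# Hypothetical standalone PPL costs: upper layers are cheap, lower ones costly.
costs = {40: 0.05, 42: 0.08, 45: 0.10, 50: 0.12, 35: 0.90, 20: 1.50}
layers, total = select_layers(costs, budget=0.40)
print(layers, round(total, 2))  # -> [40, 42, 45, 50] 0.35
```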
## Architecture details
- Base model: meta-llama/Llama-3.3-70B-Instruct
- Quantization: NF4 with double quantization (via unsloth)
- Model class: `LeanLlamaForCausalLM` (extends `LlamaForCausalLM`)
- Compression ratio: 256x per compressed layer (1024 -> 4 dimensions)
- Compressed layers: 20/80 (25%)
- Compressor overhead: 21 MB (negligible vs 36.8 GB model)
- Training data for compressors: WikiText-2 (6 epochs per layer)
- Total size: 36.8 GB
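A back-of-envelope estimate of the end-to-end KV-cache savings, assuming the compressed layers store only the 4-dim latents (at the same dtype as the uncompressed cache):

```python
# 20 of 80 layers store a 4-dim latent instead of 1024-dim K/V vectors.
full_layers = 80
compressed_layers = 20
ratio = 1024 / 4  # 256x per compressed layer

uncompressed_units = full_layers
compressed_units = (full_layers - compressed_layers) + compressed_layers / ratio
savings = 1 - compressed_units / uncompressed_units
print(f"{savings:.1%}")  # -> 24.9% of total KV-cache memory
```

The 256x per-layer ratio thus translates to roughly a quarter of total KV-cache memory, since only 25% of layers are compressed.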
## Limitations
- Compression adds a small perplexity penalty (~2 points on WikiText-2)
- The `trust_remote_code=True` flag is required since this uses a custom model class
- Compressor weights were trained on WikiText-2; other domains may see different compression quality
- This is a 4-bit quantized model; for full precision, see the base model
## Citation
If you use this model, please cite the base model:
```bibtex
@article{grattafiori2024llama3,
  title={The Llama 3 Herd of Models},
  author={Grattafiori, Aaron and others},
  journal={arXiv preprint arXiv:2407.21783},
  year={2024}
}
```