---
language:
  - en
license: apache-2.0
library_name: jax
tags:
  - vanta
  - bert
  - mlm
  - rope
  - swiglu
  - gqa
datasets:
  - HuggingFaceFW/fineweb-edu
  - HuggingFaceFW/fineweb
  - bigcode/starcoderdata
  - open-web-math/open-web-math
metrics:
  - loss
---

# Zenyx-Vanta 350M (Omni-Mix)

Zenyx-Vanta is a modernized bidirectional encoder (BERT-style) model. This iteration uses the Omni-Mix dataset strategy, which gives the encoder a balance of high-quality educational text, general web knowledge, Pythonic logic, and mathematical reasoning.

## Architecture Details

- **Model Type:** Masked Language Model (MLM)
- **Parameters:** ~350 million
- **Tokenizer:** Qwen 2.5 (vocabulary size 151,646)
- **Position Encoding:** Rotary Positional Embeddings (RoPE) with a 10k base
- **Activation:** SwiGLU (SiLU-gated MLP)
- **Attention:** Grouped Query Attention (GQA) with 12 query heads and 4 KV heads
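The RoPE scheme listed above can be illustrated in isolation. The sketch below is a minimal NumPy implementation of rotary embeddings with the stated 10k base and is not taken from the model's actual code; function names are illustrative.

```python
import numpy as np

def rope_frequencies(head_dim, seq_len, base=10000.0):
    # One inverse frequency per rotated pair of dimensions (10k base as above).
    inv_freq = 1.0 / (base ** (np.arange(0, head_dim, 2) / head_dim))
    angles = np.outer(np.arange(seq_len), inv_freq)  # (seq_len, head_dim // 2)
    return np.cos(angles), np.sin(angles)

def apply_rope(x, cos, sin):
    # x: (seq_len, head_dim). Rotate each even/odd dimension pair by the
    # position-dependent angle; this encodes position without learned tables.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out
```

Because each pair is a pure rotation, vector norms are preserved and position 0 is left unchanged.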

## Training Data: The "Omni-Mix"

Vanta was trained on a balanced 4-way distribution to maximize cross-domain reasoning:

  1. FineWeb-Edu (25%): High-signal educational content.
  2. FineWeb (25%): General linguistic context from broad web crawls.
  3. StarCoderData - Python (25%): Source code for logic and syntax understanding.
  4. Open-Web-Math (25%): Mathematical text and LaTeX for symbolic reasoning.
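A 4-way 25% mixture like this is typically realized by weighted sampling over the sources at training time. The sketch below is a hypothetical illustration of that idea (the names and sampling mechanism are assumptions, not the actual training pipeline):

```python
import random

# Hypothetical source names; each weighted 25% per the Omni-Mix distribution.
SOURCES = {
    "fineweb-edu": 0.25,
    "fineweb": 0.25,
    "starcoder-python": 0.25,
    "open-web-math": 0.25,
}

def sample_source(rng):
    # Draw the source to pull the next training document from.
    names, weights = zip(*SOURCES.items())
    return rng.choices(names, weights=weights, k=1)[0]
```

Over many draws each source contributes roughly a quarter of the documents, matching the stated distribution.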

## Technical Specifications

| Parameter | Value |
| --- | --- |
| `hidden_size` | 768 |
| `num_hidden_layers` | 12 |
| `num_attention_heads` | 12 |
| `num_key_value_heads` | 4 |
| `intermediate_size` | 3072 |
| `max_position_embeddings` | 2048 |
| `hidden_act` | SwiGLU (SiLU) |
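The configuration above is consistent with the ~350M parameter figure. The back-of-the-envelope estimate below assumes an untied MLM output head and ignores biases and norm parameters, so treat it as a sanity check rather than an exact count:

```python
# Hedged estimate from the spec table; untied LM head is an assumption.
vocab, d, layers, inter = 151_646, 768, 12, 3072
n_kv_heads, head_dim = 4, 768 // 12

emb = vocab * d                            # token embeddings
attn = d * d                               # Q projection (12 heads)
attn += 2 * d * (n_kv_heads * head_dim)    # K and V (GQA: only 4 KV heads)
attn += d * d                              # output projection
mlp = 3 * d * inter                        # SwiGLU: gate, up, and down matrices
total = emb + layers * (attn + mlp) + vocab * d  # last term: untied MLM head
print(f"~{total / 1e6:.0f}M parameters")
```

The large vocabulary means the embedding and output matrices account for most of the budget; the 12 transformer layers contribute only ~104M of it.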

## Quick Start / Inference

To use Zenyx-Vanta for mask filling, adapt the following snippet (requires `jax`, `flax`, and `transformers`):

```python
from transformers import AutoTokenizer
import jax.numpy as jnp

# Note: ensure your local ZenyxVanta architecture definition matches the model weights
# model = ZenyxVanta(vocab_size=151646)

tokenizer = AutoTokenizer.from_pretrained("Arko007/zenyx-vanta-bert")
text = "The powerhouse of the cell is the ___."
prompt = text.replace("___", "<|MASK|>")

inputs = tokenizer(prompt, return_tensors="np")
# logits = model.apply({'params': params}, inputs['input_ids'])
# ... (standard JAX inference logic)
```
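Once you have logits from the model, mask filling reduces to taking the argmax at the mask position. The helper below is a hypothetical sketch (the `fill_mask` name and the shapes `logits: (1, seq_len, vocab)` and `input_ids: (1, seq_len)` are assumptions about a standard MLM setup):

```python
import numpy as np

def fill_mask(logits, input_ids, mask_id):
    # Locate the first occurrence of the mask token in the sequence.
    pos = int(np.argmax(input_ids[0] == mask_id))
    # Return the highest-scoring token id at that position; decode it with
    # tokenizer.decode([token_id]) to get the predicted word.
    return int(np.argmax(logits[0, pos]))
```

In practice you would pass `tokenizer.convert_tokens_to_ids("<|MASK|>")` as `mask_id` and decode the returned id with the tokenizer.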

## Credits

Developed by Arko007 and the Zenyx team.