---
language:
- en
license: apache-2.0
library_name: jax
tags:
- vanta
- bert
- mlm
- rope
- swiglu
- gqa
datasets:
- HuggingFaceFW/fineweb-edu
- HuggingFaceFW/fineweb
- bigcode/starcoderdata
- open-web-math/open-web-math
metrics:
- loss
---

# Zenyx-Vanta 350M (Omni-Mix)

Zenyx-Vanta is a modernized **bidirectional encoder** (BERT-style) model. This iteration uses the **Omni-Mix** dataset strategy, designed to give the encoder a balance of high-quality educational text, general web knowledge, Pythonic logic, and mathematical reasoning.

## Architecture Details

* **Model Type:** Masked Language Model (MLM)
* **Parameters:** ~350 million
* **Tokenizer:** Qwen 2.5 (151,646 vocabulary size)
* **Positional Encoding:** Rotary Positional Embeddings (RoPE) with a 10k base
* **Activation:** SwiGLU (SiLU-gated MLP)
* **Attention:** Grouped Query Attention (GQA) with 12 query heads and 4 KV heads

## Training Data: The "Omni-Mix"

Vanta was trained on a balanced four-way distribution to maximize cross-domain reasoning:

1. **FineWeb-Edu (25%):** High-signal educational content.
2. **FineWeb (25%):** General linguistic context from broad web crawls.
3. **StarCoderData - Python (25%):** Source code for logic and syntax understanding.
4. **Open-Web-Math (25%):** Mathematical text and LaTeX for symbolic reasoning.
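The SwiGLU block listed above can be sketched as a minimal JAX function. This is a generic illustration of a SiLU-gated MLP using the card's stated dimensions (`hidden_size=768`, `intermediate_size=3072`), not the model's actual implementation; the function and weight names are hypothetical:

```python
import jax
import jax.numpy as jnp

def swiglu_mlp(x, w_gate, w_up, w_down):
    """SiLU-gated feed-forward block (SwiGLU).

    x:      (..., hidden)           e.g. hidden = 768
    w_gate: (hidden, intermediate)  e.g. intermediate = 3072
    w_up:   (hidden, intermediate)
    w_down: (intermediate, hidden)
    """
    gate = jax.nn.silu(x @ w_gate)  # SiLU(x W_gate)
    up = x @ w_up                   # x W_up
    return (gate * up) @ w_down     # elementwise gate, project back to hidden

# Toy usage with the dimensions from this card
key = jax.random.PRNGKey(0)
k1, k2, k3, k4 = jax.random.split(key, 4)
hidden, inter = 768, 3072
x = jax.random.normal(k1, (2, 16, hidden))          # (batch, seq, hidden)
w_g = jax.random.normal(k2, (hidden, inter)) * 0.02
w_u = jax.random.normal(k3, (hidden, inter)) * 0.02
w_d = jax.random.normal(k4, (inter, hidden)) * 0.02
y = swiglu_mlp(x, w_g, w_u, w_d)
print(y.shape)  # (2, 16, 768)
```

Compared with a standard GELU MLP, the gated form splits the up-projection into two matrices and multiplies them elementwise, which is why `intermediate_size` appears twice in the parameter count of each MLP layer.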
## Technical Specifications

| Parameter | Value |
| ----- | ----- |
| `hidden_size` | 768 |
| `num_hidden_layers` | 12 |
| `num_attention_heads` | 12 |
| `num_key_value_heads` | 4 |
| `intermediate_size` | 3072 |
| `max_position_embeddings` | 2048 |
| `hidden_act` | SwiGLU (SiLU) |

## Quick Start / Inference

To use Zenyx-Vanta for mask filling, you can use the following snippet (requires `jax`, `flax`, and `transformers`):

```python
from transformers import AutoTokenizer
import jax.numpy as jnp

# Note: ensure your local ZenyxVanta architecture definition matches the model weights.
# model = ZenyxVanta(vocab_size=151646)

tokenizer = AutoTokenizer.from_pretrained("Arko007/zenyx-vanta-bert")

text = "The powerhouse of the cell is the ___."
prompt = text.replace("___", "<|MASK|>")
inputs = tokenizer(prompt, return_tensors="np")

# logits = model.apply({'params': params}, inputs['input_ids'])
# ... (standard JAX inference logic)
```

## Credits

Developed by **Arko007** and the **Zenyx** team.
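The Quick Start snippet stops once the logits are computed. A minimal sketch of the remaining decode step is shown below; the logits, token ids, and mask id here are random stand-ins (the `ZenyxVanta` class and its real tokenization are not published on this card), so only the indexing logic is meaningful:

```python
import numpy as np

# Stand-ins: in practice input_ids comes from the tokenizer and
# logits from the model's forward pass.
vocab_size = 151646
mask_token_id = 7                                  # hypothetical <|MASK|> id
input_ids = np.array([[100, 200, 300, 7, 400]])    # position 3 is masked
logits = np.random.randn(1, input_ids.shape[1], vocab_size)  # (batch, seq, vocab)

# Locate the masked position and take the top-5 candidate token ids.
mask_pos = int(np.argmax(input_ids[0] == mask_token_id))
top5 = np.argsort(logits[0, mask_pos])[::-1][:5]
print(top5)  # five highest-scoring token ids; decode with tokenizer.decode(...)
```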