# snp-universal-embedding / snp_universal_embedding.py
# Author: 366degrees — "Upload 4 files", commit 8c8d036 (verified).
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
import os, json
# Sanity-check the runtime: reaching this point means the imports above
# succeeded; also report which torch build is in use.
print("✅ Environment ready")
print("Torch:", torch.__version__)
# ============================================================
# Custom SNP Model Architecture
# ============================================================
class CustomSNPModel(nn.Module):
    """Pretrained transformer encoder followed by a linear projection of the first token.

    Args:
        base_model: Model name or path passed to ``AutoModel.from_pretrained``.
        embedding_dim: Width of the vector produced by ``projection``. Defaults
            to 6, the value that was previously hard-coded, so existing
            checkpoints keep matching.
    """

    def __init__(self, base_model: str = "bert-base-uncased", *, embedding_dim: int = 6):
        super().__init__()
        self.shared_encoder = AutoModel.from_pretrained(base_model)
        hidden_size = self.shared_encoder.config.hidden_size
        # NOTE(review): mirror_head and prism_head are never used by forward();
        # presumably they are kept so their checkpoint weights still load — confirm intent.
        self.mirror_head = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.Tanh())
        self.prism_head = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.Tanh())
        self.projection = nn.Linear(hidden_size, embedding_dim)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None) -> torch.Tensor:
        """Encode a batch and return the projected hidden state of position 0.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Optional mask forwarded to the encoder unchanged.
            token_type_ids: Optional segment ids forwarded to the encoder unchanged.

        Returns:
            Tensor of shape ``(batch, embedding_dim)``.
        """
        outputs = self.shared_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Position 0 is the [CLS] token for BERT-style tokenizers.
        first_token = outputs.last_hidden_state[:, 0, :]
        return self.projection(first_token)


print("✅ SNP architecture defined.")
# ============================================================
# Load Checkpoint (optional; comment out if not available)
# ============================================================
# Single source of truth for the encoder/tokenizer name so the two cannot drift.
BASE_MODEL = "bert-base-uncased"
ckpt_path = "pytorch_model.bin"


def _strip_prefix(state_dict, prefix="module."):
    """Return a copy of *state_dict* with a leading *prefix* removed from each key.

    Checkpoints saved from a DataParallel-wrapped model typically carry a
    ``module.`` prefix. Only a leading occurrence is removed, so keys that
    merely contain the substring elsewhere are left intact.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


if os.path.exists(ckpt_path):
    print(f"Loading weights from {ckpt_path}")
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from a trusted source (newer torch supports weights_only=True).
    state_dict = torch.load(ckpt_path, map_location="cpu")
    clean_state_dict = _strip_prefix(state_dict)
    model = CustomSNPModel(base_model=BASE_MODEL)
    # strict=False tolerates key mismatches; surface them instead of
    # reporting success when some (or all) weights were not loaded.
    incompatible = model.load_state_dict(clean_state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(
            f"⚠️ Checkpoint partially loaded: "
            f"{len(incompatible.missing_keys)} missing key(s), "
            f"{len(incompatible.unexpected_keys)} unexpected key(s)."
        )
        if incompatible.missing_keys:
            print("   Missing (first 5):", incompatible.missing_keys[:5])
        if incompatible.unexpected_keys:
            print("   Unexpected (first 5):", incompatible.unexpected_keys[:5])
    else:
        print("✅ Checkpoint loaded successfully.")
else:
    print("⚠️ No checkpoint found, initializing new model.")
    model = CustomSNPModel(base_model=BASE_MODEL)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
# ============================================================
# Example Inference
# ============================================================
text = "A student must decide between a scholarship and their family."
inputs = tokenizer(text, return_tensors="pt")
# NOTE(review): forward() does accept token_type_ids; dropping them is presumably
# harmless for a single-sentence input — confirm this is intentional.
inputs.pop("token_type_ids", None)
# nn.Modules start in training mode; switch to eval so any dropout in the
# encoder is disabled and the embedding is deterministic.
model.eval()
with torch.no_grad():
    output = model(**inputs)
print("✅ Embedding generated successfully.")
print("Embedding shape:", output.shape if hasattr(output, "shape") else type(output))