import os

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

print("✅ Environment ready")
print("Torch:", torch.__version__)

# ============================================================
# Custom SNP Model Architecture
# ============================================================
class CustomSNPModel(nn.Module):
    def __init__(self, base_model="bert-base-uncased"):
        super().__init__()
        # Shared transformer encoder backbone.
        self.shared_encoder = AutoModel.from_pretrained(base_model)
        hidden_size = self.shared_encoder.config.hidden_size
        # Task-specific heads (not used by this forward pass; kept so
        # their weights can still be loaded from a checkpoint).
        self.mirror_head = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.Tanh())
        self.prism_head = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.Tanh())
        # Projects the [CLS] representation down to a 6-dimensional embedding.
        self.projection = nn.Linear(hidden_size, 6)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        outputs = self.shared_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Use the [CLS] token's final hidden state as the sequence summary.
        cls = outputs.last_hidden_state[:, 0, :]
        return self.projection(cls)

print("✅ SNP architecture defined.")

# ============================================================
# Load Checkpoint (optional; comment out if not available)
# ============================================================
ckpt_path = "pytorch_model.bin"
if os.path.exists(ckpt_path):
    print(f"Loading weights from {ckpt_path}")
    state_dict = torch.load(ckpt_path, map_location="cpu")
    # Strip the "module." prefix left over from DataParallel/DDP training.
    clean_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    model = CustomSNPModel(base_model="bert-base-uncased")
    # strict=False tolerates keys missing from either side (e.g. unused heads).
    model.load_state_dict(clean_state_dict, strict=False)
    print("✅ Checkpoint loaded successfully.")
else:
    print("⚠️ No checkpoint found, initializing new model.")
    model = CustomSNPModel()

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# ============================================================
# Example Inference
# ============================================================
model.eval()  # disable dropout for deterministic inference
text = "A student must decide between a scholarship and their family."
# The model's forward accepts token_type_ids, so the tokenizer output
# can be passed through unchanged.
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
    output = model(**inputs)
print("✅ Embedding generated successfully.")
print("Embedding shape:", output.shape)
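
# ============================================================
# Batched Inference (sketch)
# ============================================================
# A minimal sketch of running the model over many texts at once. It assumes
# the `model` and `tokenizer` defined above; `embed_texts`, the batch size,
# and the max_length of 128 are hypothetical choices, not part of the
# original script.
def embed_texts(texts, batch_size=16):
    """Encode a list of strings into 6-dim embeddings, batch by batch."""
    model.eval()
    chunks = []
    with torch.no_grad():
        for i in range(0, len(texts), batch_size):
            batch = tokenizer(
                texts[i : i + batch_size],
                padding=True,        # pad to the longest text in the batch
                truncation=True,
                max_length=128,
                return_tensors="pt",
            )
            chunks.append(model(**batch))
    return torch.cat(chunks, dim=0)

# Usage: embed_texts(["text one", "text two"]) returns a (2, 6) tensor.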
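
# ============================================================
# Saving Weights (sketch)
# ============================================================
# A sketch of saving weights in the format the loader above expects: a plain
# state_dict written to pytorch_model.bin. This is an assumed counterpart to
# the load path, not part of the original script.
torch.save(model.state_dict(), "pytorch_model.bin")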