import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
import os
import json
print("✅ Environment ready")
print("Torch:", torch.__version__)
# ============================================================
# Custom SNP Model Architecture
# ============================================================
class CustomSNPModel(nn.Module):
    def __init__(self, base_model="bert-base-uncased"):
        super().__init__()
        self.shared_encoder = AutoModel.from_pretrained(base_model)
        hidden_size = self.shared_encoder.config.hidden_size
        # Auxiliary heads; constructed here but not used in forward() below
        self.mirror_head = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.Tanh())
        self.prism_head = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.Tanh())
        # Projects the pooled representation down to 6 output dimensions
        self.projection = nn.Linear(hidden_size, 6)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        outputs = self.shared_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Take the [CLS] token embedding and project it
        cls = outputs.last_hidden_state[:, 0, :]
        proj = self.projection(cls)
        return proj
print("✅ SNP architecture defined.")
# ============================================================
# Load Checkpoint (optional; comment out if not available)
# ============================================================
ckpt_path = "pytorch_model.bin"
if os.path.exists(ckpt_path):
    print(f"Loading weights from {ckpt_path}")
    state_dict = torch.load(ckpt_path, map_location="cpu")
    # Strip any DataParallel "module." prefix so keys match this model's names
    clean_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    model = CustomSNPModel(base_model="bert-base-uncased")
    # strict=False tolerates missing/unexpected keys (e.g. the unused heads)
    model.load_state_dict(clean_state_dict, strict=False)
    print("✅ Checkpoint loaded successfully.")
else:
    print("⚠️ No checkpoint found, initializing new model.")
    model = CustomSNPModel()

# The tokenizer is needed in both branches, so load it unconditionally
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
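# Optional sanity checks -- a minimal sketch, not part of the original script.
# Counts trainable parameters and confirms the projection head emits 6 values
# per example before running the real inference below.
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {n_params:,}")
_probe = tokenizer("sanity check", return_tensors="pt")
_probe.pop("token_type_ids", None)
with torch.no_grad():
    assert model(**_probe).shape[-1] == 6, "projection head should emit 6 dims"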
# ============================================================
# Example Inference
# ============================================================
text = "A student must decide between a scholarship and their family."
inputs = tokenizer(text, return_tensors="pt")
inputs.pop("token_type_ids", None)  # redundant (forward() accepts it), but harmless
model.eval()  # disable dropout so the embedding is deterministic
with torch.no_grad():
    output = model(**inputs)
print("✅ Embedding generated successfully.")
print("Embedding shape:", output.shape if hasattr(output, "shape") else type(output))