import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
import os
import json

print("✅ Environment ready")
print("Torch:", torch.__version__)


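# Optional convenience, not in the original setup: pick a compute device up
# front. Everything below runs on CPU as written, so moving the model and
# inputs to `device` is left to the reader.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

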
class CustomSNPModel(nn.Module):
    def __init__(self, base_model="bert-base-uncased"):
        super().__init__()
        # Shared transformer encoder backbone.
        self.shared_encoder = AutoModel.from_pretrained(base_model)
        hidden_size = self.shared_encoder.config.hidden_size
        # Task-specific heads. Note: forward() below does not call them;
        # only the projection layer is applied to the [CLS] embedding.
        self.mirror_head = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.Tanh())
        self.prism_head = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.Tanh())
        # Project the pooled representation down to 6 dimensions.
        self.projection = nn.Linear(hidden_size, 6)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        outputs = self.shared_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Use the [CLS] token embedding as the sequence representation.
        cls = outputs.last_hidden_state[:, 0, :]
        proj = self.projection(cls)
        return proj


print("✅ SNP architecture defined.")
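

# The mirror and prism heads are initialized above but never used in
# forward(). A minimal sketch of one way they might be wired in, assuming
# the intent is to refine the [CLS] embedding before projection; the
# averaging fusion here is an assumption, not part of the original model.
def forward_with_heads(model, input_ids, attention_mask=None, token_type_ids=None):
    outputs = model.shared_encoder(
        input_ids=input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
    )
    cls = outputs.last_hidden_state[:, 0, :]
    mirrored = model.mirror_head(cls)   # hypothetical: refine via mirror head
    prismed = model.prism_head(cls)     # hypothetical: refine via prism head
    fused = (mirrored + prismed) / 2    # hypothetical fusion by averaging
    return model.projection(fused)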


ckpt_path = "pytorch_model.bin"
if os.path.exists(ckpt_path):
    print(f"Loading weights from {ckpt_path}")
    state_dict = torch.load(ckpt_path, map_location="cpu")
    # Strip the "module." prefix that nn.DataParallel adds to parameter names.
    clean_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}

    model = CustomSNPModel(base_model="bert-base-uncased")
    # strict=False tolerates missing or unexpected keys in the checkpoint.
    model.load_state_dict(clean_state_dict, strict=False)
    print("✅ Checkpoint loaded successfully.")
else:
    print("⚠️ No checkpoint found, initializing new model.")
    model = CustomSNPModel()
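

# For round-tripping with the loader above: a minimal saving sketch, assuming
# any nn.DataParallel wrapper should be unwrapped so the saved keys carry no
# "module." prefix (the filename matches the loader's default).
def save_checkpoint(model, path="pytorch_model.bin"):
    to_save = model.module if hasattr(model, "module") else model
    torch.save(to_save.state_dict(), path)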


tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

text = "A student must decide between a scholarship and their family."
inputs = tokenizer(text, return_tensors="pt")
# forward() accepts token_type_ids, so this pop is purely defensive.
inputs.pop("token_type_ids", None)

with torch.no_grad():
    output = model(**inputs)

print("✅ Embedding generated successfully.")
print("Embedding shape:", output.shape if hasattr(output, "shape") else type(output))
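
# A minimal batched variant, assuming several texts should be encoded at once;
# padding and truncation keep the batch rectangular. The second sentence is an
# illustrative input, not from the original.
texts = [
    "A student must decide between a scholarship and their family.",
    "A doctor weighs honesty against a patient's peace of mind.",
]
batch = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
batch.pop("token_type_ids", None)
with torch.no_grad():
    batch_output = model(**batch)
print("Batch output shape:", batch_output.shape)  # expected: (2, 6)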