File size: 2,446 Bytes
8c8d036
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
import os, json

# Sanity check: confirm imports succeeded and show which torch build is active.
print("✅ Environment ready")
print("Torch:", torch.__version__)

# ============================================================
# Custom SNP Model Architecture
# ============================================================
class CustomSNPModel(nn.Module):
    """Transformer encoder with two auxiliary heads and a 6-dim projection.

    The backbone is any HuggingFace ``AutoModel`` (BERT by default).
    ``mirror_head`` and ``prism_head`` are registered as submodules so their
    weights participate in ``state_dict``/checkpoint loading, but they are
    not applied inside ``forward`` — only the [CLS] embedding is projected.
    """

    def __init__(self, base_model="bert-base-uncased"):
        super().__init__()
        # Shared transformer backbone providing contextual token embeddings.
        self.shared_encoder = AutoModel.from_pretrained(base_model)
        dim = self.shared_encoder.config.hidden_size
        # Two parallel Linear+Tanh heads. NOTE(review): defined but unused in
        # forward() — presumably kept so checkpoint keys resolve; confirm.
        self.mirror_head = nn.Sequential(nn.Linear(dim, dim), nn.Tanh())
        self.prism_head = nn.Sequential(nn.Linear(dim, dim), nn.Tanh())
        # Final projection of the pooled representation down to 6 dims.
        self.projection = nn.Linear(dim, 6)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Encode a batch and project its [CLS] vector to a 6-dim output.

        Returns a tensor of shape (batch, 6).
        """
        encoded = self.shared_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # First token of the last hidden layer serves as the pooled summary.
        pooled = encoded.last_hidden_state[:, 0, :]
        return self.projection(pooled)

print("✅ SNP architecture defined.")

# ============================================================
# Load Checkpoint (optional; comment out if not available)
# ============================================================
ckpt_path = "pytorch_model.bin"
if os.path.exists(ckpt_path):
    print(f"Loading weights from {ckpt_path}")
    # weights_only=True: a state_dict contains only tensors, so there is no
    # need to run arbitrary pickled code from the checkpoint file (torch.load
    # without it can execute attacker-controlled pickle payloads).
    state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    # Strip only a *leading* DataParallel/DDP "module." prefix. The previous
    # str.replace would also mangle any key merely containing "module.".
    clean_state_dict = {k.removeprefix("module."): v for k, v in state_dict.items()}

    model = CustomSNPModel(base_model="bert-base-uncased")
    # strict=False tolerates key mismatches; surface them so a partial load
    # doesn't silently masquerade as a fully successful one.
    load_result = model.load_state_dict(clean_state_dict, strict=False)
    if load_result.missing_keys:
        print("⚠️ Missing keys:", load_result.missing_keys)
    if load_result.unexpected_keys:
        print("⚠️ Unexpected keys:", load_result.unexpected_keys)
    print("✅ Checkpoint loaded successfully.")
else:
    print("⚠️ No checkpoint found, initializing new model.")
    model = CustomSNPModel()

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# ============================================================
# Example Inference
# ============================================================
text = "A student must decide between a scholarship and their family."
encoded = tokenizer(text, return_tensors="pt")
# Drop segment ids if the tokenizer produced them; forward() then receives
# token_type_ids=None from its default.
encoded.pop("token_type_ids", None)

# Inference only — no gradient tracking needed.
with torch.no_grad():
    output = model(**encoded)

print("✅ Embedding generated successfully.")
shape_or_type = output.shape if hasattr(output, "shape") else type(output)
print("Embedding shape:", shape_or_type)