rokati commited on
Commit
7a44376
·
verified ·
1 Parent(s): daf2eb8

Upload model_architecture.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model_architecture.py +97 -0
model_architecture.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch_geometric.nn import HeteroConv, global_mean_pool, GATv2Conv
4
+
5
+ class XGNet(nn.Module):
6
+ """
7
+ Heterogeneous GNN for xG with Global Features.
8
+
9
+ Graph Structure:
10
+ - Nodes: shooter (player_id), goal (learnable)
11
+ - Edges: goal → shooter (distance, angle_to_goal, dist_to_gk, angle_to_gk)
12
+ - Global: 18 shot-level features (body part, play pattern, timing, etc.)
13
+ """
14
+ def __init__(self, num_players: int, hid: int, p: float, heads: int, num_layers: int,
15
+ use_norm: bool, num_global_features: int = 18):
16
+ super().__init__()
17
+
18
+ # 1) Node encoders ---------------------------------------------------
19
+ self.shooter_emb = nn.Embedding(num_players + 1, hid) # +1 for padding/UNK
20
+ self.goal_feat = nn.Parameter(torch.zeros(1, hid)) # learnable goal feature
21
+
22
+ # Global feature encoder
23
+ self.global_encoder = nn.Linear(num_global_features, hid)
24
+
25
+ self.dropout = nn.Dropout(p=p)
26
+
27
+ # 2) Edge-conditioned message passing -------------------------------
28
+ def mk_gat_with_edge(edge_dim: int):
29
+ """GAT with edge features"""
30
+ return GATv2Conv(
31
+ in_channels=(hid, hid),
32
+ out_channels=hid,
33
+ edge_dim=edge_dim,
34
+ heads=heads,
35
+ concat=False,
36
+ dropout=p,
37
+ add_self_loops=False,
38
+ )
39
+
40
+ self.convs = nn.ModuleList()
41
+ self.norms = nn.ModuleList()
42
+
43
+ for _ in range(num_layers):
44
+ conv = HeteroConv({
45
+ ('goal', 'distance', 'shooter'): mk_gat_with_edge(edge_dim=1),
46
+ ('goal', 'angle_to_goal', 'shooter'): mk_gat_with_edge(edge_dim=1),
47
+ ('goal', 'dist_to_gk', 'shooter'): mk_gat_with_edge(edge_dim=1),
48
+ ('goal', 'angle_to_gk', 'shooter'): mk_gat_with_edge(edge_dim=1),
49
+ }, aggr='sum')
50
+ self.convs.append(conv)
51
+
52
+ if use_norm:
53
+ self.norms.append(nn.LayerNorm(hid))
54
+
55
+ # 3) Read-out --------------------------------------------------------
56
+ self.output = nn.Sequential(
57
+ nn.Linear(hid, hid//2),
58
+ nn.ReLU(),
59
+ nn.Dropout(p),
60
+ nn.Linear(hid//2, 1),
61
+ )
62
+
63
+ # ----------------------------------------------------------------------
64
+ def forward(self, data):
65
+ # Prepare node feature dict
66
+ shooter_emb = self.shooter_emb(data['shooter'].x.squeeze(-1).long())
67
+ shooter_emb = self.dropout(shooter_emb)
68
+
69
+ x = {
70
+ 'goal' : self.goal_feat.expand(data['goal'].num_nodes, -1),
71
+ 'shooter': shooter_emb,
72
+ }
73
+
74
+ # Message passing
75
+ for li, conv in enumerate(self.convs):
76
+ x_new = conv(x, data.edge_index_dict, data.edge_attr_dict)
77
+
78
+ shooter_updated = x_new['shooter']
79
+ if self.norms is not None and len(self.norms) > 0:
80
+ shooter_updated = self.norms[li](shooter_updated)
81
+
82
+ x['shooter'] = self.dropout(shooter_updated + x['shooter'])
83
+
84
+ # Graph-level pooling (handles batches transparently)
85
+ shooter_batch = getattr(
86
+ data['shooter'], 'batch',
87
+ torch.zeros(x['shooter'].size(0), dtype=torch.long, device=x['shooter'].device)
88
+ )
89
+ g_repr = global_mean_pool(x['shooter'], shooter_batch) # (batch × hid)
90
+
91
+ # Encode global features
92
+ global_feat = self.global_encoder(data.global_features.squeeze(1)) # (batch × hid)
93
+
94
+ # Combine pooled node embeddings + global context
95
+ combined = g_repr + global_feat # Element-wise addition
96
+
97
+ return self.output(combined).squeeze(1) # xG ∈ (0, 1)