Upload model_architecture.py with huggingface_hub
Browse files- model_architecture.py +97 -0
model_architecture.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from torch_geometric.nn import HeteroConv, global_mean_pool, GATv2Conv
|
| 4 |
+
|
| 5 |
+
class XGNet(nn.Module):
|
| 6 |
+
"""
|
| 7 |
+
Heterogeneous GNN for xG with Global Features.
|
| 8 |
+
|
| 9 |
+
Graph Structure:
|
| 10 |
+
- Nodes: shooter (player_id), goal (learnable)
|
| 11 |
+
- Edges: goal → shooter (distance, angle_to_goal, dist_to_gk, angle_to_gk)
|
| 12 |
+
- Global: 18 shot-level features (body part, play pattern, timing, etc.)
|
| 13 |
+
"""
|
| 14 |
+
def __init__(self, num_players: int, hid: int, p: float, heads: int, num_layers: int,
|
| 15 |
+
use_norm: bool, num_global_features: int = 18):
|
| 16 |
+
super().__init__()
|
| 17 |
+
|
| 18 |
+
# 1) Node encoders ---------------------------------------------------
|
| 19 |
+
self.shooter_emb = nn.Embedding(num_players + 1, hid) # +1 for padding/UNK
|
| 20 |
+
self.goal_feat = nn.Parameter(torch.zeros(1, hid)) # learnable goal feature
|
| 21 |
+
|
| 22 |
+
# Global feature encoder
|
| 23 |
+
self.global_encoder = nn.Linear(num_global_features, hid)
|
| 24 |
+
|
| 25 |
+
self.dropout = nn.Dropout(p=p)
|
| 26 |
+
|
| 27 |
+
# 2) Edge-conditioned message passing -------------------------------
|
| 28 |
+
def mk_gat_with_edge(edge_dim: int):
|
| 29 |
+
"""GAT with edge features"""
|
| 30 |
+
return GATv2Conv(
|
| 31 |
+
in_channels=(hid, hid),
|
| 32 |
+
out_channels=hid,
|
| 33 |
+
edge_dim=edge_dim,
|
| 34 |
+
heads=heads,
|
| 35 |
+
concat=False,
|
| 36 |
+
dropout=p,
|
| 37 |
+
add_self_loops=False,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
self.convs = nn.ModuleList()
|
| 41 |
+
self.norms = nn.ModuleList()
|
| 42 |
+
|
| 43 |
+
for _ in range(num_layers):
|
| 44 |
+
conv = HeteroConv({
|
| 45 |
+
('goal', 'distance', 'shooter'): mk_gat_with_edge(edge_dim=1),
|
| 46 |
+
('goal', 'angle_to_goal', 'shooter'): mk_gat_with_edge(edge_dim=1),
|
| 47 |
+
('goal', 'dist_to_gk', 'shooter'): mk_gat_with_edge(edge_dim=1),
|
| 48 |
+
('goal', 'angle_to_gk', 'shooter'): mk_gat_with_edge(edge_dim=1),
|
| 49 |
+
}, aggr='sum')
|
| 50 |
+
self.convs.append(conv)
|
| 51 |
+
|
| 52 |
+
if use_norm:
|
| 53 |
+
self.norms.append(nn.LayerNorm(hid))
|
| 54 |
+
|
| 55 |
+
# 3) Read-out --------------------------------------------------------
|
| 56 |
+
self.output = nn.Sequential(
|
| 57 |
+
nn.Linear(hid, hid//2),
|
| 58 |
+
nn.ReLU(),
|
| 59 |
+
nn.Dropout(p),
|
| 60 |
+
nn.Linear(hid//2, 1),
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
# ----------------------------------------------------------------------
|
| 64 |
+
def forward(self, data):
|
| 65 |
+
# Prepare node feature dict
|
| 66 |
+
shooter_emb = self.shooter_emb(data['shooter'].x.squeeze(-1).long())
|
| 67 |
+
shooter_emb = self.dropout(shooter_emb)
|
| 68 |
+
|
| 69 |
+
x = {
|
| 70 |
+
'goal' : self.goal_feat.expand(data['goal'].num_nodes, -1),
|
| 71 |
+
'shooter': shooter_emb,
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
# Message passing
|
| 75 |
+
for li, conv in enumerate(self.convs):
|
| 76 |
+
x_new = conv(x, data.edge_index_dict, data.edge_attr_dict)
|
| 77 |
+
|
| 78 |
+
shooter_updated = x_new['shooter']
|
| 79 |
+
if self.norms is not None and len(self.norms) > 0:
|
| 80 |
+
shooter_updated = self.norms[li](shooter_updated)
|
| 81 |
+
|
| 82 |
+
x['shooter'] = self.dropout(shooter_updated + x['shooter'])
|
| 83 |
+
|
| 84 |
+
# Graph-level pooling (handles batches transparently)
|
| 85 |
+
shooter_batch = getattr(
|
| 86 |
+
data['shooter'], 'batch',
|
| 87 |
+
torch.zeros(x['shooter'].size(0), dtype=torch.long, device=x['shooter'].device)
|
| 88 |
+
)
|
| 89 |
+
g_repr = global_mean_pool(x['shooter'], shooter_batch) # (batch × hid)
|
| 90 |
+
|
| 91 |
+
# Encode global features
|
| 92 |
+
global_feat = self.global_encoder(data.global_features.squeeze(1)) # (batch × hid)
|
| 93 |
+
|
| 94 |
+
# Combine pooled node embeddings + global context
|
| 95 |
+
combined = g_repr + global_feat # Element-wise addition
|
| 96 |
+
|
| 97 |
+
return self.output(combined).squeeze(1) # xG ∈ (0, 1)
|