# StriMap/src/main.py
import os
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from typing import Dict, List, Tuple, Optional
from tqdm import tqdm
from collections import Counter
from sklearn.metrics import confusion_matrix, roc_auc_score, average_precision_score
import warnings
warnings.filterwarnings("ignore")
from physicochemical import PhysicochemicalEncoder
from model import (
negative_sampling_phla,
ESM2Encoder,
ESMFoldEncoder,
PeptideHLABindingPredictor,
PepHLA_Dataset,
peptide_hla_collate_fn,
TCRPeptideHLABindingPredictor,
TCRPepHLA_Dataset,
tcr_pep_hla_collate_fn,
EarlyStopping
)
# ============================================================================
# Utility functions
# ============================================================================
def load_train_data(
df_train_list: List[pd.DataFrame],
df_val_list: List[pd.DataFrame],
hla_dict_path: str = 'pMHC/HLA_dict.npy',
) -> Tuple[List[pd.DataFrame], List[pd.DataFrame]]:
"""
Load and preprocess the training and validation datasets.
Args:
df_train_list: List of training DataFrames (one per fold)
df_val_list: List of validation DataFrames (one per fold)
hla_dict_path: Path to the HLA dictionary (maps HLA name to full sequence)
Returns:
df_train_list, df_val_list with an 'HLA_full' column added
"""
print("Loading training and validation data...")
# Load HLA dictionary
HLA_dict = np.load(hla_dict_path, allow_pickle=True).item()
# Process HLA names → full sequence
for df in df_train_list + df_val_list:
df['HLA'] = df['HLA'].apply(lambda x: x[4:] if x.startswith('HLA-') else x)
df['HLA_full'] = df['HLA'].apply(lambda x: HLA_dict[x])
return df_train_list, df_val_list
def load_test_data(
df_test: pd.DataFrame,
hla_dict_path: str = 'pMHC/HLA_dict.npy'
) -> pd.DataFrame:
"""
Preprocess a given test DataFrame (e.g. independent or external set).
Args:
df_test: Test dataframe with at least 'HLA', 'peptide', 'label'
hla_dict_path: Path to HLA dictionary (to map HLA name to full sequence)
Returns:
Processed df_test with 'HLA_full' added
"""
print("Processing test data...")
HLA_dict = np.load(hla_dict_path, allow_pickle=True).item()
df_test = df_test.copy()
df_test['HLA'] = df_test['HLA'].apply(lambda x: x[4:] if x.startswith('HLA-') else x)
df_test['HLA_full'] = df_test['HLA'].apply(lambda x: HLA_dict[x])
print(f"βœ“ Test set: {len(df_test)} samples")
return df_test
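# ----------------------------------------------------------------------------
# Illustrative sketch of the input format the loaders above expect (column
# names follow the docstrings; the allele/peptide values are placeholders and
# the exact HLA_dict.npy key format is an assumption):
#
#     HLA            peptide      label
#     HLA-A*02:01    GILGFVFTL    1
#
# load_train_data()/load_test_data() strip a leading 'HLA-' prefix and look the
# remaining allele name up in HLA_dict.npy to add an 'HLA_full' column.
# ----------------------------------------------------------------------------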
class StriMap_pHLA:
"""
StriMap for Structure-informed Peptide-HLA Binding Prediction Model
"""
def __init__(
self,
device: str = 'cuda:0',
model_save_path: str = '/data/model_params/best_model_phla.pt',
pep_dim: int = 256,
hla_dim: int = 256,
bilinear_dim: int = 256,
loss_fn: str = 'bce',
alpha: float = 0.5,
gamma: float = 2.0,
esm2_layer: int = 33,
batch_size: int = 256,
esmfold_cache_dir: str = "/data/esm_cache",
cache_dir: str = '/data/phla_cache',
cache_save: bool = False,
seed: int = 1,
pos_weights: Optional[float] = None
):
"""
Initialize StriMap model
Args:
device: Device for computation
cache_dir: Directory for caching embeddings
model_save_path: Path to save best model
pep_dim: Peptide embedding dimension
hla_dim: HLA embedding dimension
bilinear_dim: Bilinear attention dimension
loss_fn: Loss function ('bce' or 'focal')
alpha: Alpha parameter for focal loss
gamma: Gamma parameter for focal loss
esm2_layer: ESM2 layer to extract features from
esmfold_cache_dir: Cache directory for ESMFold
cache_dir: Directory for caching embeddings
seed: Random seed
"""
self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
self.model_save_path = model_save_path
if os.path.dirname(model_save_path):
os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
self.seed = seed
self.cache_save = cache_save
self.batch_size = batch_size
self.loss_fn_name = loss_fn
self.alpha = alpha
self.gamma = gamma
self.pos_weights = pos_weights
# Set random seeds
self._set_seed(seed)
# Initialize encoders
print("Initializing encoders...")
self.phys_encoder = PhysicochemicalEncoder(device=self.device)
self.esm2_encoder = ESM2Encoder(device=str(self.device), layer=esm2_layer, cache_dir=cache_dir)
self.esmfold_encoder = ESMFoldEncoder(esm_cache_dir=esmfold_cache_dir, cache_dir=cache_dir)
# Initialize model
print("Initializing binding prediction model...")
self.model = PeptideHLABindingPredictor(
pep_dim=pep_dim,
hla_dim=hla_dim,
bilinear_dim=bilinear_dim,
loss_fn=self.loss_fn_name,
alpha=self.alpha,
gamma=self.gamma,
device=str(self.device),
pos_weights=self.pos_weights
).to(self.device)
# Embeddings cache
self.phys_dict = None
self.esm2_dict = None
self.struct_dict = None
print(f"βœ“ StriMap initialized on {self.device}")
def _set_seed(self, seed: int):
"""Set random seeds for reproducibility"""
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
def prepare_embeddings(
self,
df: pd.DataFrame,
force_recompute: bool = False,
):
"""
Prepare all embeddings (physicochemical, ESM2, structure)
Args:
df: DataFrame containing 'peptide' and 'HLA_full' columns
force_recompute: Force recomputation even if cache exists
incremental: If True, only compute missing sequences
phys_cache: Physicochemical embeddings cache file
esm2_cache: ESM2 embeddings cache file
struct_cache: Structure embeddings cache file
"""
# Extract unique sequences
all_peptides = sorted(set(df['peptide'].astype(str)))
all_hlas = sorted(set(df['HLA_full'].astype(str)))
print(f"\n{'='*70}")
print(f"Preparing embeddings for:")
print(f" - {len(all_peptides)} unique peptides")
print(f" - {len(all_hlas)} unique HLAs")
print(f"{'='*70}\n")
# ========================================================================
# 1. Physicochemical features
# ========================================================================
self.phys_dict = {
'pep': self._encode_phys(all_peptides),
'hla': self._encode_phys(all_hlas)
}
# ========================================================================
# 2. ESM2 embeddings
# ========================================================================
self.esm2_dict = {
'pep': self._encode_esm2(all_peptides, prefix='pep', re_embed=force_recompute),
'hla': self._encode_esm2(all_hlas, prefix='hla', re_embed=force_recompute)
}
# ========================================================================
# 3. Structure features (only for HLA)
# ========================================================================
self.struct_dict = self._encode_structure(all_hlas, re_embed=force_recompute)
# ========================================================================
# Summary
# ========================================================================
print(f"{'='*70}")
print("βœ“ All embeddings prepared!")
print(f" - Phys: {len(self.phys_dict['pep'])} peptides, {len(self.phys_dict['hla'])} HLAs")
print(f" - ESM2: {len(self.esm2_dict['pep'])} peptides, {len(self.esm2_dict['hla'])} HLAs")
print(f" - Struct: {len(self.struct_dict)} HLAs")
print(f"{'='*70}\n")
def _encode_phys(self,
sequences: List[str]) -> Dict[str, torch.Tensor]:
"""Encode physicochemical properties"""
emb_dict = {}
for i in tqdm(range(0, len(sequences), self.batch_size), desc="Phys encoding"):
batch = sequences[i:i+self.batch_size]
embs = self.phys_encoder(batch).cpu() # [B, L, D]
for seq, emb in zip(batch, embs):
emb_dict[seq] = emb
return emb_dict
def _encode_esm2(self, sequences: List[str], prefix: str, re_embed: bool = False) -> Dict[str, torch.Tensor]:
"""Encode with ESM2"""
df_tmp = pd.DataFrame({'seq': sequences})
emb_dict = self.esm2_encoder.forward(
df_tmp,
seq_col='seq',
prefix=prefix,
batch_size=self.batch_size,
re_embed=re_embed,
cache_save=self.cache_save
)
return emb_dict
def _encode_structure(self, sequences: List[str], re_embed: bool = False) -> Dict[str, Tuple]:
"""Encode structure with ESMFold"""
feat_list, coor_list = self.esmfold_encoder.forward(
pd.DataFrame({'hla': sequences}),
'hla',
device=str(self.device),
re_embed=re_embed,
)
struct_dict = {
seq: (feat, coor)
for seq, feat, coor in zip(sequences, feat_list, coor_list)
}
return struct_dict
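# ----------------------------------------------------------------------------
# Resulting cache layout after prepare_embeddings() (shapes inferred from the
# encoder outputs above; the feature dimensions depend on the encoders used):
#   self.phys_dict   = {'pep': {seq: Tensor[L, D_phys]}, 'hla': {...}}
#   self.esm2_dict   = {'pep': {seq: Tensor[L, D_esm2]}, 'hla': {...}}
#   self.struct_dict = {hla_seq: (node_features, coordinates)}  # HLA only
# ----------------------------------------------------------------------------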
def train(
self,
df_train: pd.DataFrame,
df_val: pd.DataFrame,
epochs: int = 100,
batch_size: int = 256,
lr: float = 1e-4,
patience: int = 5,
num_workers: int = 8,
fold_id: Optional[int] = None
) -> Dict[str, List[float]]:
"""
Train the model
Args:
df_train: Training data
df_val: Validation data
epochs: Number of epochs
batch_size: Batch size
lr: Learning rate
patience: Early stopping patience
num_workers: Number of data loading workers
fold_id: Fold identifier for saving (None for single model)
Returns:
Dictionary with training history
"""
# Check if embeddings are prepared
if self.phys_dict is None or self.esm2_dict is None or self.struct_dict is None:
raise ValueError("Embeddings not prepared! Call prepare_embeddings() first.")
# Create datasets
print("Creating datasets...")
train_dataset = PepHLA_Dataset(df_train, self.phys_dict, self.esm2_dict, self.struct_dict)
val_dataset = PepHLA_Dataset(df_val, self.phys_dict, self.esm2_dict, self.struct_dict)
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
collate_fn=peptide_hla_collate_fn,
pin_memory=True
)
val_loader = torch.utils.data.DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
collate_fn=peptide_hla_collate_fn,
pin_memory=True
)
# Optimizer and early stopping
optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr)
# Model save path for this fold
save_path = self.model_save_path if fold_id is None else \
self.model_save_path.replace('.pt', f'_fold{fold_id}.pt')
early_stopping = EarlyStopping(
patience=patience,
save_path=save_path
)
# Training history
history = {
'train_loss': [],
'val_loss': [],
'val_auc': [],
'val_prc': []
}
fold_str = f"Fold {fold_id}" if fold_id is not None else "Single model"
print(f"\nStarting training for {epochs} epochs [{fold_str}]...")
print("=" * 70)
for epoch in range(epochs):
# Training
self.model.train()
train_loss = 0.0
train_batches = 0
train_iter = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]", leave=False, ncols=80)
for batch in train_iter:
optimizer.zero_grad()
probs, loss, _, _ = self.model(batch)
loss.backward()
optimizer.step()
train_loss += loss.item()
train_batches += 1
train_loss /= train_batches
# Validation
self.model.eval()
val_loss = 0.0
val_preds = []
val_labels = []
val_batches = 0
with torch.no_grad():
val_iter = tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} [Val]", leave=False, ncols=80)
for batch in val_iter:
probs, loss, _, _ = self.model(batch)
val_loss += loss.item()
val_batches += 1
val_preds.extend(probs.detach().cpu().numpy().tolist())
val_labels.extend(batch['label'].cpu().numpy().tolist())
val_auc = roc_auc_score(val_labels, val_preds)
val_loss /= val_batches
val_prc = average_precision_score(val_labels, val_preds)
# Record history
history['train_loss'].append(train_loss)
history['val_loss'].append(val_loss)
history['val_auc'].append(val_auc)
history['val_prc'].append(val_prc)
# Print metrics
print(f"[{fold_str}] Epoch [{epoch+1}/{epochs}] | "
f"Train Loss: {train_loss:.4f} | "
f"Val Loss: {val_loss:.4f} | "
f"Val AUC: {val_auc:.4f} | "
f"Val PRC: {val_prc:.4f}")
# Early stopping
early_stopping(val_prc, self.model)
if early_stopping.early_stop:
print(f"\n[{fold_str}] Early stopping triggered at epoch {epoch+1}!")
break
# Load best model
print(f"\n[{fold_str}] Loading best model from {save_path}...")
self.model.load_state_dict(torch.load(save_path))
print("=" * 70)
print(f"βœ“ Training completed for {fold_str}!")
return history
def train_kfold(
self,
train_folds: List[Tuple[pd.DataFrame, pd.DataFrame]],
epochs: int = 100,
batch_size: int = 256,
lr: float = 1e-4,
patience: int = 5,
num_workers: int = 8
) -> List[Dict[str, List[float]]]:
"""
Train K-fold cross-validation models
Args:
train_folds: List of (train_df, val_df) tuples for each fold
epochs: Number of epochs per fold
batch_size: Batch size
lr: Learning rate
patience: Early stopping patience
num_workers: Number of data loading workers
Returns:
List of training histories for each fold
"""
num_folds = len(train_folds)
all_histories = []
print("\n" + "=" * 70)
print(f"Starting {num_folds}-Fold Cross-Validation Training")
print("=" * 70)
for fold_id, (df_train, df_val) in enumerate(train_folds):
print(f"\n{'='*70}")
print(f"Training Fold {fold_id+1}/{num_folds}")
print(f"Train: {len(df_train)} samples | Val: {len(df_val)} samples")
print(f"{'='*70}")
self._set_seed(fold_id + self.seed) # Different seed for each fold
# Reinitialize model for this fold
self.model = PeptideHLABindingPredictor(
pep_dim=self.model.pep_dim,
hla_dim=self.model.hla_dim,
bilinear_dim=self.model.bilinear_dim,
loss_fn=self.loss_fn_name,
alpha=self.alpha,
gamma=self.gamma,
device=str(self.device),
pos_weights=self.pos_weights
).to(self.device)
# Train this fold
history = self.train(
df_train,
df_val,
epochs=epochs,
batch_size=batch_size,
lr=lr,
patience=patience,
num_workers=num_workers,
fold_id=fold_id
)
all_histories.append(history)
print("\n" + "=" * 70)
print(f"βœ“ All {num_folds} folds training completed!")
print("=" * 70)
# Print summary
print("\nCross-Validation Summary:")
print("-" * 70)
for fold_id, history in enumerate(all_histories):
best_auc = max(history['val_auc'])
best_epoch = history['val_auc'].index(best_auc) + 1
print(f"Fold {fold_id}: Best Val AUC = {best_auc:.4f} (Epoch {best_epoch})")
mean_auc = np.mean([max(h['val_auc']) for h in all_histories])
std_auc = np.std([max(h['val_auc']) for h in all_histories])
print("-" * 70)
print(f"Mean Val AUC: {mean_auc:.4f} Β± {std_auc:.4f}")
print("=" * 70 + "\n")
return all_histories
def predict(
self,
df: pd.DataFrame,
batch_size: int = 256,
return_probs: bool = True,
return_attn: bool = False,
use_kfold: bool = False,
num_folds: Optional[int] = None,
ensemble_method: str = 'mean',
num_workers: int = 8
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
"""
Make predictions on a dataset
Args:
df: DataFrame with 'peptide' and 'HLA_full' columns
batch_size: Batch size for inference
return_probs: If True, return probabilities; else return binary predictions
return_attn: If True, also return padded attention maps
use_kfold: If True, use an ensemble of K fold models
num_folds: Number of folds (required if use_kfold=True)
ensemble_method: 'mean' or 'median' for the ensemble
num_workers: Number of data loading workers
Returns:
Tuple of (predictions, attention maps or None)
"""
# Check if embeddings are prepared
if self.phys_dict is None or self.esm2_dict is None or self.struct_dict is None:
raise ValueError("Embeddings not prepared! Call prepare_embeddings() first.")
if use_kfold:
if num_folds is None:
raise ValueError("num_folds must be specified when use_kfold=True")
return self._predict_ensemble(
df,
batch_size,
num_folds,
ensemble_method,
return_probs,
return_attn,
num_workers
)
else:
# load single model
print(f"\nLoading model from {self.model_save_path} for prediction...")
self.model.load_state_dict(torch.load(self.model_save_path, map_location=self.device), strict=False)
# Single model prediction
return self._predict_single(df, batch_size, return_probs, return_attn, num_workers)
def _pad_attention(self, attns: List[np.ndarray]) -> np.ndarray:
"""Pad attention maps to the same length"""
max_len = max(a.shape[1] for a in attns)
attns_padded = []
for a in attns:
padding = max_len - a.shape[1]
pad_width_3d = ((0, 0), # do not pad the H dimension
(0, padding), # pad the end of the Lv dimension
(0, 0)) # do not pad the Lq dimension
attns_padded.append(np.pad(a, pad_width_3d, mode='constant', constant_values=0.0))
return np.concatenate(attns_padded, axis=0)
def _predict_single(
self,
df: pd.DataFrame,
batch_size: int,
return_probs: bool,
return_attn: bool = False,
num_workers: int = 8
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
"""Single model prediction"""
self.model.eval()
dataset = PepHLA_Dataset(df, self.phys_dict, self.esm2_dict, self.struct_dict)
loader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
collate_fn=peptide_hla_collate_fn,
pin_memory=True
)
preds = []
attns = []
with torch.no_grad():
for batch in tqdm(loader, desc="Predicting"):
probs, loss, attn, _ = self.model(batch)
preds.extend(probs.tolist())
if return_attn:
attns.append(attn)
preds = np.array(preds)
if not return_probs:
preds = (preds >= 0.5).astype(int)
# padding attns to the same length
if not return_attn:
return preds, None
else:
return preds, self._pad_attention(attns)
def _predict_ensemble(
self,
df: pd.DataFrame,
batch_size: int,
num_folds: int,
ensemble_method: str,
return_probs: bool,
return_attn: bool = False,
num_workers: int = 8
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
"""Ensemble prediction using K-fold models"""
print(f"\nEnsemble prediction using {num_folds} models...")
print(f"Ensemble method: {ensemble_method}")
all_preds = []
all_attns = []
for fold_id in range(num_folds):
# Load fold model
fold_model_path = self.model_save_path.replace('.pt', f'_fold{fold_id}.pt')
if not os.path.exists(fold_model_path):
print(f"⚠ Warning: {fold_model_path} not found, skipping...")
continue
print(f"Loading model from {fold_model_path}...")
self.model.load_state_dict(torch.load(fold_model_path, map_location=self.device), strict=False)
# Predict with this fold
if not return_attn:
fold_preds, _ = self._predict_single(df, batch_size, return_probs=True, num_workers=num_workers)
else:
fold_preds, attn_padded = self._predict_single(df, batch_size, return_probs=True, return_attn=True, num_workers=num_workers)
all_attns.append(attn_padded)
all_preds.append(fold_preds)
if len(all_preds) == 0:
raise ValueError("No fold models found!")
# Ensemble predictions
all_preds = np.array(all_preds) # [num_folds, num_samples]
if ensemble_method == 'mean':
ensemble_preds = np.mean(all_preds, axis=0)
elif ensemble_method == 'median':
ensemble_preds = np.median(all_preds, axis=0)
else:
raise ValueError(f"Unknown ensemble method: {ensemble_method}")
print(f"βœ“ Ensemble prediction completed using {len(all_preds)} models")
if not return_probs:
ensemble_preds = (ensemble_preds >= 0.5).astype(int)
if not return_attn:
return ensemble_preds, None
else:
# num_attn_each_fold = attns_padded.shape[0] // len(all_preds)
# # average attns across folds
# attns_padded = attns_padded.reshape(len(all_preds), num_attn_each_fold, attns_padded.shape[1], attns_padded.shape[2])
# attns_padded = np.mean(attns_padded, axis=1)
return ensemble_preds, self._pad_attention(all_attns)
def evaluate(
self,
df: pd.DataFrame,
batch_size: int = 256,
threshold: float = 0.5,
use_kfold: bool = False,
num_folds: Optional[int] = None,
ensemble_method: str = 'mean',
num_workers: int = 8
) -> Tuple[np.ndarray, Dict[str, float]]:
"""
Evaluate model on a dataset
Args:
df: DataFrame with 'peptide', 'HLA_full', and 'label' columns
batch_size: Batch size for inference
threshold: Classification threshold
use_kfold: If True, use an ensemble of K fold models
num_folds: Number of folds (required if use_kfold=True)
ensemble_method: 'mean' or 'median' for the ensemble
num_workers: Number of data loading workers
Returns:
Tuple of (predicted probabilities, dictionary of metrics)
"""
y_true = df['label'].values
y_prob, _ = self.predict(
df,
batch_size=batch_size,
return_probs=True,
use_kfold=use_kfold,
num_folds=num_folds,
ensemble_method=ensemble_method,
num_workers=num_workers
)
y_pred = (y_prob >= threshold).astype(int)
# Calculate metrics
tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel().tolist()
accuracy = (tp + tn) / (tn + fp + fn + tp)
denom = float((tp + fn) * (tn + fp) * (tp + fp) * (tn + fn))
mcc = ((tp * tn) - (fn * fp)) / np.sqrt(denom) if denom > 0 else 0.0
recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
# ROC-AUC / PRC-AUC are undefined when y_true contains a single class
roc_auc = roc_auc_score(y_true, y_prob) if len(np.unique(y_true)) > 1 else 0.0
prc_auc = average_precision_score(y_true, y_prob) if len(np.unique(y_true)) > 1 else 0.0
# Print results
model_type = f"{num_folds}-Fold Ensemble ({ensemble_method})" if use_kfold else "Single Model"
print("\n" + "=" * 70)
print(f"Evaluation Results [{model_type}]")
print("=" * 70)
print(f"tn = {tn}, fp = {fp}, fn = {fn}, tp = {tp}")
print(f"y_pred: 0 = {Counter(y_pred)[0]} | 1 = {Counter(y_pred)[1]}")
print(f"y_true: 0 = {Counter(y_true)[0]} | 1 = {Counter(y_true)[1]}")
print(f"AUC: {roc_auc:.4f} | PRC: {prc_auc:.4f} | ACC: {accuracy:.4f} | MCC: {mcc:.4f} | F1: {f1:.4f}")
print(f"Precision: {precision:.4f} | Recall: {recall:.4f}")
print("=" * 70 + "\n")
return y_prob, {
'auc': roc_auc,
'prc': prc_auc,
'accuracy': accuracy,
'mcc': mcc,
'f1': f1,
'precision': precision,
'recall': recall,
'tn': tn,
'fp': fp,
'fn': fn,
'tp': tp
}
def save_model(self, path: str):
"""Save model weights"""
torch.save(self.model.state_dict(), path)
print(f"βœ“ Model saved to {path}")
def load_model(self, path: str):
"""Load model weights"""
self.model.load_state_dict(torch.load(path, map_location=self.device), strict=False)
print(f"βœ“ Model loaded from {path}")
# ============================================================================
# TCR-peptide-HLA binding prediction
# ============================================================================
class StriMap_TCRpHLA:
"""
Structure-informed TCR(α/β)-peptide-HLA Binding Prediction
- Reuses encoders from StriMap_pHLA (phys, ESM2, ESMFold)
- Precomputes peptide-HLA features using the pretrained StriMap_pHLA.model (PeptideHLABindingPredictor)
and injects them into each batch during training/inference.
"""
def __init__(
self,
pep_hla_system = None, # already-initialized and pretrained
device: str = 'cuda:0',
model_save_path: str = 'best_model_tcrpHLA.pt',
tcr_dim: int = 256,
pep_dim: int = 256,
hla_dim: int = 256,
bilinear_dim: int = 256,
loss_fn: str = 'bce',
alpha: float = 0.5,
gamma: float = 2.0,
resample_negatives: bool = False,
seed: int = 1,
pos_weights: Optional[float] = None
):
self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
self.model_save_path = model_save_path
self.seed = seed
self.alpha = alpha
self.gamma = gamma
self.loss_fn_name = loss_fn
self.resample_negatives = resample_negatives
self.pos_weights = pos_weights
# seed
self._set_seed(seed)
if pep_hla_system is None:
raise ValueError("`pep_hla_system` must be provided; pass a trained StriMap_pHLA instance.")
# Reuse encoders from StriMap_pHLA
self.phys_encoder = pep_hla_system.phys_encoder
self.esm2_encoder = pep_hla_system.esm2_encoder
self.esmfold_encoder= pep_hla_system.esmfold_encoder
self.pep_hla_model = pep_hla_system.model # PeptideHLABindingPredictor with encode_peptide_hla()
# Initialize TCR-pHLA model
self.model = TCRPeptideHLABindingPredictor(
tcr_dim=tcr_dim,
pep_dim=pep_dim,
hla_dim=hla_dim,
bilinear_dim=bilinear_dim,
loss_fn=self.loss_fn_name,
alpha=self.alpha,
gamma=self.gamma,
pos_weights=self.pos_weights,
device=str(self.device),
).to(self.device)
# Embedding caches
self.phys_dict = None
self.esm2_dict = None
self.struct_dict = None
self.pep_hla_feat_dict = {}
print(f"βœ“ StriMap_TCRpHLA initialized on {self.device}")
# -------------------- utils --------------------
def _set_seed(self, seed: int):
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
# -------------------- encoders --------------------
def _encode_phys(self, sequences):
emb_dict = {}
batch_size = 256
for i in tqdm(range(0, len(sequences), batch_size), desc="Phys encoding (TCRpHLA)"):
batch = sequences[i:i+batch_size]
embs = self.phys_encoder(batch).cpu() # [B, L, D]
for seq, emb in zip(batch, embs):
emb_dict[seq] = emb
return emb_dict
def save_model(self, path: str):
torch.save(self.model.state_dict(), path)
print(f"βœ“ Model saved to {path}")
def load_model(self, path: str):
"""Load model weights"""
self.model.load_state_dict(torch.load(path, map_location=self.device))
print(f"βœ“ Model loaded from {path}")
def _encode_esm2(self, sequences, prefix: str, re_embed: bool=False):
df_tmp = pd.DataFrame({'seq': sequences})
return self.esm2_encoder.forward(
df_tmp, seq_col='seq', prefix=prefix, batch_size=128, re_embed=re_embed
)
def _encode_structure(self, sequences, prefix: str, re_embed: bool=False):
feat_list, coor_list = self.esmfold_encoder.forward(
pd.DataFrame({prefix: sequences}), prefix, device=str(self.device), re_embed=re_embed
)
return {seq: (feat, coor) for seq, feat, coor in zip(sequences, feat_list, coor_list)}
# -------------------- public: prepare embeddings --------------------
def prepare_embeddings(self, df: pd.DataFrame, force_recompute: bool=False):
"""
Prepare per-residue encodings for TCRα, TCRβ, peptide, and HLA.
Peptide structure is also computed via ESMFold.
"""
all_tcra = sorted(set(df['tcra'].astype(str)))
all_tcrb = sorted(set(df['tcrb'].astype(str)))
all_peps = sorted(set(df['peptide'].astype(str)))
all_hlas = sorted(set(df['HLA_full'].astype(str)))
self.max_pep_len = max(len(p) for p in all_peps)
print(f"\nPreparing embeddings:")
print(f" - TCRΞ±: {len(all_tcra)} | TCRΞ²: {len(all_tcrb)} | peptides: {len(all_peps)} | HLAs: {len(all_hlas)}\n")
self.phys_dict = {
'tcra': self._encode_phys(all_tcra),
'tcrb': self._encode_phys(all_tcrb),
'pep': self._encode_phys(all_peps),
'hla': self._encode_phys(all_hlas)
}
self.esm2_dict = {
'tcra': self._encode_esm2(all_tcra, prefix='tcra', re_embed=force_recompute),
'tcrb': self._encode_esm2(all_tcrb, prefix='tcrb', re_embed=force_recompute),
'pep': self._encode_esm2(all_peps, prefix='pep', re_embed=force_recompute),
'hla': self._encode_esm2(all_hlas, prefix='hla', re_embed=force_recompute)
}
# Move everything in phys_dict and esm2_dict to CPU
for d in [self.phys_dict, self.esm2_dict]:
for k1 in d.keys(): # tcra / tcrb / pep / hla
for k2 in d[k1].keys(): # actual sequences
if torch.is_tensor(d[k1][k2]):
d[k1][k2] = d[k1][k2].cpu()
torch.cuda.empty_cache()
# IMPORTANT: include peptide structure via ESMFold
self.struct_dict = {
'tcra': self._encode_structure(all_tcra, prefix='tcra', re_embed=force_recompute),
'tcrb': self._encode_structure(all_tcrb, prefix='tcrb', re_embed=force_recompute),
'pep': self._encode_structure(all_peps, prefix='pep', re_embed=force_recompute),
'hla': self._encode_structure(all_hlas, prefix='hla', re_embed=force_recompute)
}
print("βœ“ Embeddings prepared for TCRΞ±/Ξ², peptide (with ESMFold), and HLA.")
# Move structure features to CPU
for part in ['tcra', 'tcrb', 'pep', 'hla']:
for seq, (feat, coord) in self.struct_dict[part].items():
self.struct_dict[part][seq] = (feat.cpu(), coord.cpu())
torch.cuda.empty_cache()
print("βœ“ All embeddings moved to CPU, GPU memory released.")
# -------------------- public: precompute pHLA features --------------------
def prepare_pep_hla_features(self, df: pd.DataFrame):
"""
Precompute peptide-HLA features using pretrained PeptideHLABindingPredictor.
The resulting features are stored in self.pep_hla_feat_dict and later injected into each batch.
"""
assert self.phys_dict is not None and self.esm2_dict is not None and self.struct_dict is not None, \
"Call prepare_embeddings() first."
pairs = {(row['peptide'], row['HLA_full']) for _, row in df.iterrows()}
self.pep_hla_model.eval()
for p in self.pep_hla_model.parameters():
p.requires_grad = False
print(f"\nPrecomputing peptide-HLA features for {len(pairs)} unique pairs...")
with torch.no_grad():
for pep, hla in tqdm(pairs, desc="pHLA features"):
pep_phys = self.phys_dict['pep'][pep].unsqueeze(0).to(self.device)
pep_esm = self.esm2_dict['pep'][pep].unsqueeze(0).to(self.device)
# If your PeptideHLABindingPredictor supports peptide structure, pass it too:
pep_struct, pep_coord = self.struct_dict['pep'][pep]
pep_struct = pep_struct.unsqueeze(0).to(self.device)
pep_coord = pep_coord.unsqueeze(0).to(self.device)
hla_phys = self.phys_dict['hla'][hla].unsqueeze(0).to(self.device)
hla_esm = self.esm2_dict['hla'][hla].unsqueeze(0).to(self.device)
hla_struct, hla_coord = self.struct_dict['hla'][hla]
hla_struct = hla_struct.unsqueeze(0).to(self.device)
hla_coord = hla_coord.unsqueeze(0).to(self.device)
# NOTE: pep_struct/pep_coord are prepared above but are not passed to
# encode_peptide_hla() below; add them to the call only if your
# PeptideHLABindingPredictor supports peptide structure inputs.
pep_feat, hla_feat = self.pep_hla_model.encode_peptide_hla(
pep,
pep_phys, pep_esm,
hla_phys, hla_esm,
hla_struct, hla_coord,
max_pep_len=self.max_pep_len
)
self.pep_hla_feat_dict[(pep, hla)] = {
'pep_feat_pretrain': pep_feat.squeeze(0).cpu(), # [Lp, pep_dim]
'hla_feat_pretrain': hla_feat.squeeze(0).cpu() # [Lh, hla_dim]
}
print("βœ“ Pretrained peptide-HLA features prepared.")
# -------------------- training --------------------
def train(
self,
df_train: pd.DataFrame,
df_val: Optional[pd.DataFrame] = None,
df_test: Optional[pd.DataFrame] = None,
df_neg: Optional[pd.DataFrame] = None,
epochs: int = 100,
batch_size: int = 128,
lr: float = 1e-4,
patience: int = 5,
num_workers: int = 8,
):
"""
Train the TCR-pHLA model.
Args:
df_train: Training data.
df_val: Optional validation data.
df_test: Optional test data for evaluation after each epoch.
df_neg: Optional negative samples for training. Set when resample_negatives=True.
epochs: Number of epochs.
batch_size: Batch size.
lr: Learning rate.
patience: Early stopping patience.
num_workers: Data loading workers.
Returns:
history: Dict containing training and validation metrics.
"""
# ---- Precompute peptide-HLA features ----
print("Preparing peptide-HLA features...")
all_dfs = [df for df in [df_train, df_val, df_test, df_neg] if df is not None]
self.prepare_pep_hla_features(pd.concat(all_dfs, axis=0))
# ---- Validation loader (optional) ----
if df_val is not None:
val_ds = TCRPepHLA_Dataset(df_val, self.phys_dict, self.esm2_dict, self.struct_dict, self.pep_hla_feat_dict)
val_loader = torch.utils.data.DataLoader(
val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers,
collate_fn=tcr_pep_hla_collate_fn, pin_memory=True
)
stopper = EarlyStopping(patience=patience, save_path=self.model_save_path)
else:
val_loader, stopper = None, None
# ---- Optimizer ----
optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr)
# ---- Metric history ----
history = {'train_loss': [], 'train_auc': []}
if df_val is not None:
history.update({'val_loss': [], 'val_auc': [], 'val_prc': []})
print("\nStart training TCR–pHLA model...")
df_train_pos = df_train[df_train['label'] == 1].copy().reset_index(drop=True)
for epoch in range(epochs):
# ---------- Training ----------
if self.resample_negatives:
df_train_neg = negative_sampling_phla(df_train_pos, random_state=epoch)
if df_neg is not None:
df_train_neg = pd.concat([df_train_neg, df_neg], axis=0).reset_index(drop=True)
df_train_resample = pd.concat([df_train_pos, df_train_neg], axis=0).reset_index(drop=True)
train_ds = TCRPepHLA_Dataset(df_train_resample, self.phys_dict, self.esm2_dict, self.struct_dict, self.pep_hla_feat_dict)
else:
train_ds = TCRPepHLA_Dataset(df_train, self.phys_dict, self.esm2_dict, self.struct_dict, self.pep_hla_feat_dict)
train_loader = torch.utils.data.DataLoader(
train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers,
collate_fn=tcr_pep_hla_collate_fn, pin_memory=True
)
self.model.train()
train_labels, train_preds = [], []
epoch_loss = 0.0
for ibatch, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]")):
optimizer.zero_grad()
probs, loss, _, _ = self.model(batch)
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=2.0)
optimizer.step()
epoch_loss += loss.item()
train_labels.extend(batch['label'].cpu().numpy().tolist())
train_preds.extend(probs.detach().cpu().numpy().tolist())
train_auc = roc_auc_score(train_labels, train_preds)
train_loss = epoch_loss / (ibatch + 1)
history['train_loss'].append(train_loss)
history['train_auc'].append(train_auc)
print(f"Epoch {epoch+1}/{epochs} | Train Loss: {train_loss:.4f} | Train AUC: {train_auc:.4f}")
# ---------- Validation ----------
if df_val is not None:
self.model.eval()
val_loss_sum, val_labels, val_preds = 0.0, [], []
with torch.no_grad():
for batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} [Val]"):
probs, loss, _, _ = self.model(batch)
val_loss_sum += loss.item()
val_labels.extend(batch['label'].cpu().numpy().tolist())
val_preds.extend(probs.detach().cpu().numpy().tolist())
val_loss = val_loss_sum / len(val_loader)
val_auc = roc_auc_score(val_labels, val_preds)
val_prc = average_precision_score(val_labels, val_preds)
history['val_loss'].append(val_loss)
history['val_auc'].append(val_auc)
history['val_prc'].append(val_prc)
print(f"Epoch {epoch+1}/{epochs} | Val AUC: {val_auc:.4f} | Val PRC: {val_prc:.4f} | Val Loss: {val_loss:.4f}")
stopper(val_auc, self.model)
if stopper.early_stop:
print(f"Early stopping at epoch {epoch+1}")
break
# ---------- Optional Test ----------
if df_test is not None:
test_ds = TCRPepHLA_Dataset(df_test, self.phys_dict, self.esm2_dict, self.struct_dict, self.pep_hla_feat_dict)
test_loader = torch.utils.data.DataLoader(
test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers,
collate_fn=tcr_pep_hla_collate_fn, pin_memory=True
)
self.model.eval()
test_labels, test_preds = [], []
with torch.no_grad():
for batch in tqdm(test_loader, desc=f"Epoch {epoch+1}/{epochs} [Test]"):
probs, _, _, _ = self.model(batch)
test_labels.extend(batch['label'].cpu().numpy().tolist())
test_preds.extend(probs.detach().cpu().numpy().tolist())
test_auc = roc_auc_score(test_labels, test_preds)
test_prc = average_precision_score(test_labels, test_preds)
print(f"Epoch {epoch+1}/{epochs} | Test AUC: {test_auc:.4f} | Test PRC: {test_prc:.4f}")
# ---- Load best model only if validation used ----
if df_val is not None and os.path.exists(self.model_save_path):
self.model.load_state_dict(torch.load(self.model_save_path, map_location=self.device))
print(f"βœ“ Training finished. Best model loaded from {self.model_save_path}")
else:
print("βœ“ Training finished (no validation set used).")
return history
def train_kfold(
self,
train_folds: List[Tuple[pd.DataFrame, pd.DataFrame]],
df_test: Optional[pd.DataFrame] = None,
df_neg: Optional[pd.DataFrame] = None,
epochs: int = 100,
batch_size: int = 128,
lr: float = 1e-4,
patience: int = 8,
num_workers: int = 8,
) -> List[Dict[str, List[float]]]:
"""
K-fold cross-validation training for TCR-pHLA model.
Args:
train_folds: list of (train_df, val_df) tuples for each fold
df_test: optional test data for evaluation after each epoch
df_neg: optional negative samples for training. Set when resample_negatives=True.
epochs: training epochs
batch_size: batch size
lr: learning rate
patience: early stopping patience
num_workers: dataloader workers
Returns:
List of training histories for each fold
"""
num_folds = len(train_folds)
all_histories = []
print("\n" + "=" * 70)
print(f"Starting {num_folds}-Fold Cross-Validation Training (TCR-pHLA)")
print("=" * 70)
for fold_id, (df_train, df_val) in enumerate(train_folds):
print(f"\n{'='*70}")
print(f"Training Fold {fold_id+1}/{num_folds}")
print(f"{'='*70}")
self._set_seed(self.seed + fold_id)
self.model = TCRPeptideHLABindingPredictor(
tcr_dim=self.model.tcr_dim,
pep_dim=self.model.pep_dim,
hla_dim=self.model.hla_dim,
bilinear_dim=self.model.bilinear_dim,
loss_fn=self.loss_fn_name,
alpha=self.alpha,
gamma=self.gamma,
pos_weights=self.pos_weights,
device=str(self.device),
).to(self.device)
fold_save_path = self.model_save_path.replace(".pt", f"_fold{fold_id}.pt")
history = self.train(
df_train=df_train,
df_val=df_val,
df_test=df_test,
df_neg=df_neg,
epochs=epochs,
batch_size=batch_size,
lr=lr,
patience=patience,
num_workers=num_workers,
)
torch.save(self.model.state_dict(), fold_save_path)
print(f"βœ“ Saved fold {fold_id} model to {fold_save_path}")
all_histories.append(history)
print("\n" + "=" * 70)
print(f"βœ“ All {num_folds} folds training completed (TCR-pHLA)")
print("=" * 70)
if all_histories and 'val_auc' in all_histories[0]:
print("\nCross-Validation Summary:")
print("-" * 70)
for fold_id, hist in enumerate(all_histories):
best_auc = max(hist['val_auc'])
best_prc = max(hist['val_prc'])
best_epoch = hist['val_auc'].index(best_auc) + 1
print(f"Fold {fold_id}: Best Val AUC = {best_auc:.4f}, Best Val PRC = {best_prc:.4f}, (Epoch {best_epoch})")
mean_auc = np.mean([max(h['val_auc']) for h in all_histories])
std_auc = np.std([max(h['val_auc']) for h in all_histories])
print("-" * 70)
print(f"Mean Val AUC: {mean_auc:.4f} Β± {std_auc:.4f}")
print("=" * 70 + "\n")
return all_histories
# -------------------- single-set predict --------------------
def _predict_single(
self, df: pd.DataFrame,
batch_size: int = 128,
return_probs: bool = True,
num_workers: int = 8
):
self.model.eval()
ds = TCRPepHLA_Dataset(df, self.phys_dict, self.esm2_dict, self.struct_dict, self.pep_hla_feat_dict)
loader = torch.utils.data.DataLoader(
ds,
batch_size=batch_size,
shuffle=False,
collate_fn=tcr_pep_hla_collate_fn,
num_workers=num_workers,
pin_memory=True
)
preds = []
pep_feat_all = []
attn_all = []
with torch.no_grad():
for batch in tqdm(loader, desc="Predicting (TCR-pHLA)"):
probs, _, pep_feature, attn_dict = self.model(batch)
preds.extend(probs.tolist())
pep_feat_all.append(pep_feature)
attn_all.append(attn_dict)
preds = np.array(preds)
if not return_probs:
preds = (preds >= 0.5).astype(int)
return preds, pep_feat_all, attn_all
# ================================================================
# Ensemble prediction
# ================================================================
def _predict_ensemble(
self,
df: pd.DataFrame,
batch_size: int = 128,
num_folds: int = 5,
ensemble_method: str = 'mean',
return_probs: bool = True,
num_workers: int = 8
) -> Tuple[np.ndarray, List, List]:
"""
Ensemble prediction using multiple fold models.
"""
print(f"\nEnsemble prediction using {num_folds} TCR–pHLA models...")
print(f"Ensemble method: {ensemble_method}")
pep_feats_folds = []
attn_dict_folds = []
all_preds = []
for fold_id in range(num_folds):
fold_model_path = self.model_save_path.replace(".pt", f"_fold{fold_id}.pt")
if not os.path.exists(fold_model_path):
print(f"⚠ Warning: {fold_model_path} not found, skipping...")
continue
print(f"Loading model from {fold_model_path}...")
self.model.load_state_dict(torch.load(fold_model_path, map_location=self.device), strict=False)
# Predict for this fold
fold_preds, fold_pep_feature, fold_attn_dict = self._predict_single(
df, batch_size=batch_size, return_probs=True, num_workers=num_workers
)
all_preds.append(fold_preds)
pep_feats_folds.append(fold_pep_feature)
attn_dict_folds.append(fold_attn_dict)
if len(all_preds) == 0:
raise ValueError("No fold models found!")
if ensemble_method == 'mean':
ensemble_preds = np.mean(all_preds, axis=0)
elif ensemble_method == 'median':
ensemble_preds = np.median(all_preds, axis=0)
else:
raise ValueError(f"Unknown ensemble method: {ensemble_method}")
print(f"βœ“ Ensemble prediction completed using {len(all_preds)} folds")
if not return_probs:
ensemble_preds = (ensemble_preds >= 0.5).astype(int)
return ensemble_preds, pep_feats_folds, attn_dict_folds
# ================================================================
# Unified predict() with ensemble support
# ================================================================
def predict(
self,
df: pd.DataFrame,
batch_size: int = 128,
return_probs: bool = True,
use_kfold: bool = False,
num_folds: Optional[int] = None,
ensemble_method: str = 'mean',
num_workers: int = 8
) -> Tuple[np.ndarray, List, List]:
"""
Predict binding probabilities or binary labels.
If use_kfold=True, averages predictions across fold models.
"""
print('Preparing peptide-HLA features for prediction set...')
self.prepare_pep_hla_features(df)
if use_kfold:
if num_folds is None:
raise ValueError("num_folds must be specified when use_kfold=True")
return self._predict_ensemble(
df=df,
batch_size=batch_size,
num_folds=num_folds,
ensemble_method=ensemble_method,
return_probs=return_probs,
num_workers=num_workers
)
else:
return self._predict_single(df, batch_size=batch_size, return_probs=return_probs, num_workers=num_workers)
# ================================================================
# Unified evaluate() with ensemble support
# ================================================================
def evaluate(
self,
df: pd.DataFrame,
batch_size: int = 128,
threshold: float = 0.5,
use_kfold: bool = False,
num_folds: Optional[int] = None,
ensemble_method: str = 'mean',
num_workers: int = 8
) -> Dict[str, float]:
"""
Evaluate model performance on a dataset.
If use_kfold=True, performs ensemble evaluation across folds.
"""
y_true = df['label'].values
y_prob, all_pep_features, merged_attn = self.predict(
df,
batch_size=batch_size,
return_probs=True,
use_kfold=use_kfold,
num_folds=num_folds,
ensemble_method=ensemble_method,
num_workers=num_workers
)
y_pred = (y_prob >= threshold).astype(int)
tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel().tolist()
accuracy = (tp + tn) / (tn + fp + fn + tp + 1e-9)
try:
mcc = ((tp*tn) - (fn*fp)) / np.sqrt(float((tp+fn)*(tn+fp)*(tp+fp)*(tn+fn)) + 1e-9)
except:
mcc = 0.0
recall = tp / (tp + fn + 1e-9)
precision = tp / (tp + fp + 1e-9)
f1 = 2 * precision * recall / (precision + recall + 1e-9)
# Note: max_fpr=0.1 yields a partial (early-retrieval) AUC, not the full ROC-AUC
try:
auc = roc_auc_score(y_true, y_prob, max_fpr=0.1)
except ValueError:
auc = 0.0
print("\n" + "=" * 70)
print(f"Evaluation Results [{'K-Fold Ensemble' if use_kfold else 'Single Model'}]")
print("=" * 70)
print(f"tn={tn}, fp={fp}, fn={fn}, tp={tp}")
print(f"AUC={auc:.4f} | ACC={accuracy:.4f} | MCC={mcc:.4f} | F1={f1:.4f} | P={precision:.4f} | R={recall:.4f}")
print("=" * 70 + "\n")
return dict(
auc=auc, accuracy=accuracy, mcc=mcc, f1=f1,
precision=precision, recall=recall,
tn=tn, fp=fp, fn=fn, tp=tp
)
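# ----------------------------------------------------------------------------
# Minimal end-to-end sketch (illustrative only, not part of the original
# pipeline). The CSV paths and the column layout ('HLA', 'peptide', 'label',
# plus 'tcra'/'tcrb' for the TCR stage) are assumptions; substitute your own
# data files and hyperparameters.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    # Stage 1: peptide-HLA binding model
    phla_train = pd.read_csv('pMHC/train.csv')  # hypothetical path
    phla_val = pd.read_csv('pMHC/val.csv')      # hypothetical path
    train_list, val_list = load_train_data([phla_train], [phla_val])
    phla_train, phla_val = train_list[0], val_list[0]
    phla_system = StriMap_pHLA(device='cuda:0', model_save_path='best_model_phla.pt')
    phla_system.prepare_embeddings(pd.concat([phla_train, phla_val], axis=0))
    phla_system.train(phla_train, phla_val, epochs=50, batch_size=256)

    # Stage 2: TCR-peptide-HLA model, reusing the pretrained pHLA system
    tcr_train = load_test_data(pd.read_csv('TCR/train.csv'))  # hypothetical path; adds 'HLA_full'
    tcr_val = load_test_data(pd.read_csv('TCR/val.csv'))      # hypothetical path
    tcr_system = StriMap_TCRpHLA(pep_hla_system=phla_system, device='cuda:0',
                                 model_save_path='best_model_tcrpHLA.pt')
    tcr_system.prepare_embeddings(pd.concat([tcr_train, tcr_val], axis=0))
    tcr_system.train(tcr_train, df_val=tcr_val, epochs=50, batch_size=128)
    tcr_metrics = tcr_system.evaluate(tcr_val)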