import os import torch import torch.nn as nn import numpy as np import pandas as pd from typing import Dict, List, Tuple, Optional from tqdm import tqdm from collections import Counter from sklearn.metrics import confusion_matrix, roc_auc_score, average_precision_score import warnings from model import negative_sampling_phla warnings.filterwarnings("ignore") from physicochemical import PhysicochemicalEncoder from model import ( ESM2Encoder, ESMFoldEncoder, PeptideHLABindingPredictor, PepHLA_Dataset, peptide_hla_collate_fn, TCRPeptideHLABindingPredictor, TCRPepHLA_Dataset, tcr_pep_hla_collate_fn, EarlyStopping ) # ============================================================================ # Utility functions # ============================================================================ def load_train_data( df_train_list: List[pd.DataFrame], df_val_list: List[pd.DataFrame], hla_dict_path: str = 'pMHC/HLA_dict.npy', ) -> Tuple[pd.DataFrame, pd.DataFrame]: """ Load training and validation datasets only. Args: hla_dict_path: Path to HLA dictionary train_folds: List of training fold indices val_folds: List of validation fold indices sample_frac: Fraction of data to sample (for quick testing) seed: Random seed Returns: df_train, df_val """ print("Loading training and validation data...") # Load HLA dictionary HLA_dict = np.load(hla_dict_path, allow_pickle=True).item() # Process HLA names → full sequence for df in df_train_list + df_val_list: df['HLA'] = df['HLA'].apply(lambda x: x[4:] if x.startswith('HLA-') else x) df['HLA_full'] = df['HLA'].apply(lambda x: HLA_dict[x]) return df_train_list, df_val_list def load_test_data( df_test: pd.DataFrame, hla_dict_path: str = 'pMHC/HLA_dict.npy' ) -> pd.DataFrame: """ Preprocess a given test DataFrame (e.g. independent or external set). Args: df_test: Test dataframe with at least 'HLA', 'peptide', 'label' hla_dict_path: Path to HLA dictionary (to map HLA name to full sequence) Returns: Processed df_test with 'HLA_full' added """ print("Processing test data...") HLA_dict = np.load(hla_dict_path, allow_pickle=True).item() df_test = df_test.copy() df_test['HLA'] = df_test['HLA'].apply(lambda x: x[4:] if x.startswith('HLA-') else x) df_test['HLA_full'] = df_test['HLA'].apply(lambda x: HLA_dict[x]) print(f"✓ Test set: {len(df_test)} samples") return df_test class StriMap_pHLA: """ StriMap for Structure-informed Peptide-HLA Binding Prediction Model """ def __init__( self, device: str = 'cuda:0', model_save_path: str = '/data/model_params/best_model_phla.pt', pep_dim: int = 256, hla_dim: int = 256, bilinear_dim: int = 256, loss_fn: str = 'bce', alpha: float = 0.5, gamma: float = 2.0, esm2_layer: int = 33, batch_size: int = 256, esmfold_cache_dir: str = "/data/esm_cache", cache_dir: str = '/data/phla_cache', cache_save: bool = False, seed: int = 1, pos_weights: Optional[float] = None ): """ Initialize StriMap model Args: device: Device for computation cache_dir: Directory for caching embeddings model_save_path: Path to save best model pep_dim: Peptide embedding dimension hla_dim: HLA embedding dimension bilinear_dim: Bilinear attention dimension loss_fn: Loss function ('bce' or 'focal') alpha: Alpha parameter for focal loss gamma: Gamma parameter for focal loss esm2_layer: ESM2 layer to extract features from esmfold_cache_dir: Cache directory for ESMFold cache_dir: Directory for caching embeddings seed: Random seed """ self.device = torch.device(device if torch.cuda.is_available() else 'cpu') self.model_save_path = model_save_path if not os.path.exists(os.path.dirname(model_save_path)) and os.path.dirname(model_save_path) != '': os.makedirs(os.path.dirname(model_save_path), exist_ok=True) self.seed = seed self.cache_save = cache_save self.batch_size = batch_size self.loss_fn_name = loss_fn self.alpha = alpha self.gamma = gamma self.pos_weights = pos_weights # Set random seeds self._set_seed(seed) # Initialize encoders print("Initializing encoders...") self.phys_encoder = PhysicochemicalEncoder(device=self.device) self.esm2_encoder = ESM2Encoder(device=str(self.device), layer=esm2_layer, cache_dir=cache_dir) self.esmfold_encoder = ESMFoldEncoder(esm_cache_dir=esmfold_cache_dir, cache_dir=cache_dir) # Initialize model print("Initializing binding prediction model...") self.model = PeptideHLABindingPredictor( pep_dim=pep_dim, hla_dim=hla_dim, bilinear_dim=bilinear_dim, loss_fn=self.loss_fn_name, alpha=self.alpha, gamma=self.gamma, device=str(self.device), pos_weights=self.pos_weights ).to(self.device) # Embeddings cache self.phys_dict = None self.esm2_dict = None self.struct_dict = None print(f"✓ StriMap initialized on {self.device}") def _set_seed(self, seed: int): """Set random seeds for reproducibility""" np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True def prepare_embeddings( self, df: pd.DataFrame, force_recompute: bool = False, ): """ Prepare all embeddings (physicochemical, ESM2, structure) Args: df: DataFrame containing 'peptide' and 'HLA_full' columns force_recompute: Force recomputation even if cache exists incremental: If True, only compute missing sequences phys_cache: Physicochemical embeddings cache file esm2_cache: ESM2 embeddings cache file struct_cache: Structure embeddings cache file """ # Extract unique sequences all_peptides = sorted(set(df['peptide'].astype(str))) all_hlas = sorted(set(df['HLA_full'].astype(str))) print(f"\n{'='*70}") print(f"Preparing embeddings for:") print(f" - {len(all_peptides)} unique peptides") print(f" - {len(all_hlas)} unique HLAs") print(f"{'='*70}\n") # ======================================================================== # 1. Physicochemical features # ======================================================================== self.phys_dict = { 'pep': self._encode_phys(all_peptides), 'hla': self._encode_phys(all_hlas) } # ======================================================================== # 2. ESM2 embeddings # ======================================================================== self.esm2_dict = { 'pep': self._encode_esm2(all_peptides, prefix='pep', re_embed=force_recompute), 'hla': self._encode_esm2(all_hlas, prefix='hla', re_embed=force_recompute) } # ======================================================================== # 3. Structure features (only for HLA) # ======================================================================== self.struct_dict = self._encode_structure(all_hlas) # ======================================================================== # Summary # ======================================================================== print(f"{'='*70}") print("✓ All embeddings prepared!") print(f" - Phys: {len(self.phys_dict['pep'])} peptides, {len(self.phys_dict['hla'])} HLAs") print(f" - ESM2: {len(self.esm2_dict['pep'])} peptides, {len(self.esm2_dict['hla'])} HLAs") print(f" - Struct: {len(self.struct_dict)} HLAs") print(f"{'='*70}\n") def _encode_phys(self, sequences: List[str]) -> Dict[str, torch.Tensor]: """Encode physicochemical properties""" emb_dict = {} for i in tqdm(range(0, len(sequences), self.batch_size), desc="Phys encoding"): batch = sequences[i:i+self.batch_size] embs = self.phys_encoder(batch).cpu() # [B, L, D] for seq, emb in zip(batch, embs): emb_dict[seq] = emb return emb_dict def _encode_esm2(self, sequences: List[str], prefix: str, re_embed: bool = False) -> Dict[str, torch.Tensor]: """Encode with ESM2""" df_tmp = pd.DataFrame({'seq': sequences}) emb_dict = self.esm2_encoder.forward( df_tmp, seq_col='seq', prefix=prefix, batch_size=self.batch_size, re_embed=re_embed, cache_save=self.cache_save ) return emb_dict def _encode_structure(self, sequences: List[str], re_embed: bool = False) -> Dict[str, Tuple]: """Encode structure with ESMFold""" feat_list, coor_list = self.esmfold_encoder.forward( pd.DataFrame({'hla': sequences}), 'hla', device=str(self.device), re_embed=re_embed, ) struct_dict = { seq: (feat, coor) for seq, feat, coor in zip(sequences, feat_list, coor_list) } return struct_dict def train( self, df_train: pd.DataFrame, df_val: pd.DataFrame, epochs: int = 100, batch_size: int = 256, lr: float = 1e-4, patience: int = 5, num_workers: int = 8, fold_id: Optional[int] = None ) -> Dict[str, List[float]]: """ Train the model Args: df_train: Training data df_val: Validation data epochs: Number of epochs batch_size: Batch size lr: Learning rate patience: Early stopping patience num_workers: Number of data loading workers fold_id: Fold identifier for saving (None for single model) Returns: Dictionary with training history """ # Check if embeddings are prepared if self.phys_dict is None or self.esm2_dict is None or self.struct_dict is None: raise ValueError("Embeddings not prepared! Call prepare_embeddings() first.") # Create datasets print("Creating datasets...") train_dataset = PepHLA_Dataset(df_train, self.phys_dict, self.esm2_dict, self.struct_dict) val_dataset = PepHLA_Dataset(df_val, self.phys_dict, self.esm2_dict, self.struct_dict) train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=peptide_hla_collate_fn, pin_memory=True ) val_loader = torch.utils.data.DataLoader( val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=peptide_hla_collate_fn, pin_memory=True ) # Optimizer and early stopping optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr) # Model save path for this fold save_path = self.model_save_path if fold_id is None else \ self.model_save_path.replace('.pt', f'_fold{fold_id}.pt') early_stopping = EarlyStopping( patience=patience, save_path=save_path ) # Training history history = { 'train_loss': [], 'val_loss': [], 'val_auc': [], 'val_prc': [] } fold_str = f"Fold {fold_id}" if fold_id is not None else "Single model" print(f"\nStarting training for {epochs} epochs [{fold_str}]...") print("=" * 70) for epoch in range(epochs): # Training self.model.train() train_loss = 0.0 train_batches = 0 train_iter = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]", leave=False, ncols=80) for batch in train_iter: optimizer.zero_grad() probs, loss, _, _ = self.model(batch) loss.backward() optimizer.step() train_loss += loss.item() train_batches += 1 train_loss /= train_batches # Validation self.model.eval() val_loss = 0.0 val_preds = [] val_labels = [] val_batches = 0 with torch.no_grad(): val_iter = tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} [Val]", leave=False, ncols=80) for batch in val_iter: probs, loss, _, _ = self.model(batch) val_loss += loss.item() val_batches += 1 val_preds.extend(probs) val_labels.extend(batch['label']) val_auc = roc_auc_score(val_labels, val_preds) val_loss /= val_batches val_prc = average_precision_score(val_labels, val_preds) # Record history history['train_loss'].append(train_loss) history['val_loss'].append(val_loss) history['val_auc'].append(val_auc) history['val_prc'].append(val_prc) # Print metrics print(f"[{fold_str}] Epoch [{epoch+1}/{epochs}] | " f"Train Loss: {train_loss:.4f} | " f"Val Loss: {val_loss:.4f} | " f"Val AUC: {val_auc:.4f} | " f"Val PRC: {val_prc:.4f}") # Early stopping early_stopping(val_prc, self.model) if early_stopping.early_stop: print(f"\n[{fold_str}] Early stopping triggered at epoch {epoch+1}!") break # Load best model print(f"\n[{fold_str}] Loading best model from {save_path}...") self.model.load_state_dict(torch.load(save_path)) print("=" * 70) print(f"✓ Training completed for {fold_str}!") return history def train_kfold( self, train_folds: List[Tuple[pd.DataFrame, pd.DataFrame]], epochs: int = 100, batch_size: int = 256, lr: float = 1e-4, patience: int = 5, num_workers: int = 8 ) -> List[Dict[str, List[float]]]: """ Train K-fold cross-validation models Args: train_folds: List of (train_df, val_df) tuples for each fold epochs: Number of epochs per fold batch_size: Batch size lr: Learning rate patience: Early stopping patience num_workers: Number of data loading workers Returns: List of training histories for each fold """ num_folds = len(train_folds) all_histories = [] print("\n" + "=" * 70) print(f"Starting {num_folds}-Fold Cross-Validation Training") print("=" * 70) for fold_id, (df_train, df_val) in enumerate(train_folds): print(f"\n{'='*70}") print(f"Training Fold {fold_id+1}/{num_folds}") print(f"Train: {len(df_train)} samples | Val: {len(df_val)} samples") print(f"{'='*70}") self._set_seed(fold_id + self.seed) # Different seed for each fold # Reinitialize model for this fold self.model = PeptideHLABindingPredictor( pep_dim=self.model.pep_dim, hla_dim=self.model.hla_dim, bilinear_dim=self.model.bilinear_dim, loss_fn=self.loss_fn_name, alpha=self.alpha, gamma=self.gamma, device=str(self.device), pos_weights=self.pos_weights ).to(self.device) # Train this fold history = self.train( df_train, df_val, epochs=epochs, batch_size=batch_size, lr=lr, patience=patience, num_workers=num_workers, fold_id=fold_id ) all_histories.append(history) print("\n" + "=" * 70) print(f"✓ All {num_folds} folds training completed!") print("=" * 70) # Print summary print("\nCross-Validation Summary:") print("-" * 70) for fold_id, history in enumerate(all_histories): best_auc = max(history['val_auc']) best_epoch = history['val_auc'].index(best_auc) + 1 print(f"Fold {fold_id}: Best Val AUC = {best_auc:.4f} (Epoch {best_epoch})") mean_auc = np.mean([max(h['val_auc']) for h in all_histories]) std_auc = np.std([max(h['val_auc']) for h in all_histories]) print("-" * 70) print(f"Mean Val AUC: {mean_auc:.4f} ± {std_auc:.4f}") print("=" * 70 + "\n") return all_histories def predict( self, df: pd.DataFrame, batch_size: int = 256, return_probs: bool = True, return_attn: bool = False, use_kfold: bool = False, num_folds: Optional[int] = None, ensemble_method: str = 'mean', num_workers: int = 8 ) -> np.ndarray: """ Make predictions on a dataset Args: df: DataFrame with peptide and HLA_full columns batch_size: Batch size for inference return_probs: If True, return probabilities; else return binary predictions use_kfold: If True, use ensemble of K models num_folds: Number of folds (required if use_kfold=True) ensemble_method: 'mean' or 'median' for ensemble Returns: Array of predictions """ # Check if embeddings are prepared if self.phys_dict is None or self.esm2_dict is None or self.struct_dict is None: raise ValueError("Embeddings not prepared! Call prepare_embeddings() first.") if use_kfold: if num_folds is None: raise ValueError("num_folds must be specified when use_kfold=True") return self._predict_ensemble( df, batch_size, num_folds, ensemble_method, return_probs, return_attn, num_workers ) else: # load single model print(f"\nLoading model from {self.model_save_path} for prediction...") self.model.load_state_dict(torch.load(self.model_save_path, map_location=self.device), strict=False) # Single model prediction return self._predict_single(df, batch_size, return_probs, return_attn, num_workers) def _pad_attention(self, attns: List[np.ndarray]) -> np.ndarray: """Pad attention maps to the same length""" max_len = max(a.shape[1] for a in attns) attns_padded = [] for a in attns: padding = max_len - a.shape[1] pad_width_3d = ((0, 0), # 不填充 H 维度 (0, padding), # 填充 Lv 维度的末尾 (0, 0)) # 不填充 Lq 维度 attns_padded.append(np.pad(a, pad_width_3d, mode='constant', constant_values=0.0)) return np.concatenate(attns_padded, axis=0) def _predict_single( self, df: pd.DataFrame, batch_size: int, return_probs: bool, return_attn: bool = False, num_workers: int = 8 ) -> np.ndarray: """Single model prediction""" self.model.eval() dataset = PepHLA_Dataset(df, self.phys_dict, self.esm2_dict, self.struct_dict) loader = torch.utils.data.DataLoader( dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=peptide_hla_collate_fn, pin_memory=True ) preds = [] attns = [] with torch.no_grad(): for batch in tqdm(loader, desc="Predicting"): probs, loss, attn, _ = self.model(batch) preds.extend(probs.tolist()) if return_attn: attns.append(attn) preds = np.array(preds) if not return_probs: preds = (preds >= 0.5).astype(int) # padding attns to the same length if not return_attn: return preds, None else: return preds, self._pad_attention(attns) def _predict_ensemble( self, df: pd.DataFrame, batch_size: int, num_folds: int, ensemble_method: str, return_probs: bool, return_attn: bool = False, num_workers: int = 8 ) -> np.ndarray: """Ensemble prediction using K-fold models""" print(f"\nEnsemble prediction using {num_folds} models...") print(f"Ensemble method: {ensemble_method}") all_preds = [] all_attns = [] for fold_id in range(num_folds): # Load fold model fold_model_path = self.model_save_path.replace('.pt', f'_fold{fold_id}.pt') if not os.path.exists(fold_model_path): print(f"⚠ Warning: {fold_model_path} not found, skipping...") continue print(f"Loading model from {fold_model_path}...") self.model.load_state_dict(torch.load(fold_model_path, map_location=self.device), strict=False) # Predict with this fold if not return_attn: fold_preds, _ = self._predict_single(df, batch_size, return_probs=True, num_workers=num_workers) else: fold_preds, attn_padded = self._predict_single(df, batch_size, return_probs=True, return_attn=True, num_workers=num_workers) all_attns.append(attn_padded) all_preds.append(fold_preds) if len(all_preds) == 0: raise ValueError("No fold models found!") # Ensemble predictions all_preds = np.array(all_preds) # [num_folds, num_samples] if ensemble_method == 'mean': ensemble_preds = np.mean(all_preds, axis=0) elif ensemble_method == 'median': ensemble_preds = np.median(all_preds, axis=0) else: raise ValueError(f"Unknown ensemble method: {ensemble_method}") print(f"✓ Ensemble prediction completed using {len(all_preds)} models") if not return_probs: ensemble_preds = (ensemble_preds >= 0.5).astype(int) if not return_attn: return ensemble_preds, None else: # num_attn_each_fold = attns_padded.shape[0] // len(all_preds) # # average attns across folds # attns_padded = attns_padded.reshape(len(all_preds), num_attn_each_fold, attns_padded.shape[1], attns_padded.shape[2]) # attns_padded = np.mean(attns_padded, axis=1) return ensemble_preds, self._pad_attention(all_attns) def evaluate( self, df: pd.DataFrame, batch_size: int = 256, threshold: float = 0.5, use_kfold: bool = False, num_folds: Optional[int] = None, ensemble_method: str = 'mean', num_workers: int = 8 ) -> Dict[str, float]: """ Evaluate model on a dataset Args: df: DataFrame with peptide, HLA_full, and label columns batch_size: Batch size for inference threshold: Classification threshold use_kfold: If True, use ensemble of K models num_folds: Number of folds (required if use_kfold=True) ensemble_method: 'mean' or 'median' for ensemble Returns: Dictionary of metrics """ y_true = df['label'].values y_prob, _ = self.predict( df, batch_size=batch_size, return_probs=True, use_kfold=use_kfold, num_folds=num_folds, ensemble_method=ensemble_method, num_workers=num_workers ) y_pred = (y_prob >= threshold).astype(int) # Calculate metrics tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel().tolist() accuracy = (tp + tn) / (tn + fp + fn + tp) try: mcc = ((tp*tn) - (fn*fp)) / np.sqrt(float((tp+fn)*(tn+fp)*(tp+fp)*(tn+fn))) except: mcc = 0.0 try: recall = tp / (tp + fn) except: recall = 0.0 try: precision = tp / (tp + fp) except: precision = 0.0 try: f1 = 2 * precision * recall / (precision + recall) except: f1 = 0.0 try: roc_auc = roc_auc_score(y_true, y_prob) except: roc_auc = 0.0 try: # prc from sklearn.metrics import average_precision_score prc_auc = average_precision_score(y_true, y_prob) except: prc_auc = 0.0 # Print results model_type = f"{num_folds}-Fold Ensemble ({ensemble_method})" if use_kfold else "Single Model" print("\n" + "=" * 70) print(f"Evaluation Results [{model_type}]") print("=" * 70) print(f"tn = {tn}, fp = {fp}, fn = {fn}, tp = {tp}") print(f"y_pred: 0 = {Counter(y_pred)[0]} | 1 = {Counter(y_pred)[1]}") print(f"y_true: 0 = {Counter(y_true)[0]} | 1 = {Counter(y_true)[1]}") print(f"AUC: {roc_auc:.4f} | PRC: {prc_auc:.4f} | ACC: {accuracy:.4f} | MCC: {mcc:.4f} | F1: {f1:.4f}") print(f"Precision: {precision:.4f} | Recall: {recall:.4f}") print("=" * 70 + "\n") return y_prob, { 'auc': roc_auc, 'prc': prc_auc, 'accuracy': accuracy, 'mcc': mcc, 'f1': f1, 'precision': precision, 'recall': recall, 'tn': tn, 'fp': fp, 'fn': fn, 'tp': tp } def save_model(self, path: str): """Save model weights""" torch.save(self.model.state_dict(), path) print(f"✓ Model saved to {path}") def load_model(self, path: str): """Load model weights""" self.model.load_state_dict(torch.load(path, map_location=self.device), strict=False) print(f"✓ Model loaded from {path}") # ============================================================================ # -*- coding: utf-8 -*- import os import numpy as np import pandas as pd from collections import Counter from tqdm import tqdm import torch from sklearn.metrics import roc_auc_score, confusion_matrix class StriMap_TCRpHLA: """ Structure-informed TCR(α/β)–peptide–HLA Binding Prediction - Reuses encoders from StriMap_pHLA (phys, ESM2, ESMFold) - Precomputes peptide–HLA features using pretrained StriMap_pHLA.model (PeptideHLABindingPredictor) and injects them into batch during training/inference. """ def __init__( self, pep_hla_system = None, # already-initialized and pretrained device: str = 'cuda:0', model_save_path: str = 'best_model_tcrpHLA.pt', tcr_dim: int = 256, pep_dim: int = 256, hla_dim: int = 256, bilinear_dim: int = 256, loss_fn: str = 'bce', alpha: float = 0.5, gamma: float = 2.0, resample_negatives: bool = False, seed: int = 1, pos_weights: Optional[float] = None ): self.device = torch.device(device if torch.cuda.is_available() else 'cpu') self.model_save_path = model_save_path self.seed = seed self.alpha = alpha self.gamma = gamma self.loss_fn_name = loss_fn self.resample_negatives = resample_negatives self.pos_weights = pos_weights # seed self._set_seed(seed) if pep_hla_system is None: raise ValueError("`pep_hla_system` must be provided — pass a trained StriMap_pHLA instance.") # Reuse encoders from StriMap_pHLA self.phys_encoder = pep_hla_system.phys_encoder self.esm2_encoder = pep_hla_system.esm2_encoder self.esmfold_encoder= pep_hla_system.esmfold_encoder self.pep_hla_model = pep_hla_system.model # PeptideHLABindingPredictor with encode_peptide_hla() # Initialize TCR–pHLA model self.model = TCRPeptideHLABindingPredictor( tcr_dim=tcr_dim, pep_dim=pep_dim, hla_dim=hla_dim, bilinear_dim=bilinear_dim, loss_fn=self.loss_fn_name, alpha=self.alpha, gamma=self.gamma, pos_weights=self.pos_weights, device=str(self.device), ).to(self.device) # Embedding caches self.phys_dict = None self.esm2_dict = None self.struct_dict = None self.pep_hla_feat_dict = {} print(f"✓ StriMap_TCRpHLA initialized on {self.device}") # -------------------- utils -------------------- def _set_seed(self, seed: int): np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True # -------------------- encoders -------------------- def _encode_phys(self, sequences): emb_dict = {} batch_size = 256 for i in tqdm(range(0, len(sequences), batch_size), desc="Phys encoding (TCRpHLA)"): batch = sequences[i:i+batch_size] embs = self.phys_encoder(batch).cpu() # [B, L, D] for seq, emb in zip(batch, embs): emb_dict[seq] = emb return emb_dict def save_model(self, path: str): torch.save(self.model.state_dict(), path) print(f"✓ Model saved to {path}") def load_model(self, path: str): """Load model weights""" self.model.load_state_dict(torch.load(path, map_location=self.device)) print(f"✓ Model loaded from {path}") def _encode_esm2(self, sequences, prefix: str, re_embed: bool=False): df_tmp = pd.DataFrame({'seq': sequences}) return self.esm2_encoder.forward( df_tmp, seq_col='seq', prefix=prefix, batch_size=128, re_embed=re_embed ) def _encode_structure(self, sequences, prefix: str, re_embed: bool=False): feat_list, coor_list = self.esmfold_encoder.forward( pd.DataFrame({prefix: sequences}), prefix, device=str(self.device), re_embed=re_embed ) return {seq: (feat, coor) for seq, feat, coor in zip(sequences, feat_list, coor_list)} # -------------------- public: prepare embeddings -------------------- def prepare_embeddings(self, df: pd.DataFrame, force_recompute: bool=False): """ Prepare per-residue encodings for TCRα, TCRβ, peptide, and HLA. Peptide structure is computed via ESMFold as requested. """ all_tcra = sorted(set(df['tcra'].astype(str))) all_tcrb = sorted(set(df['tcrb'].astype(str))) all_peps = sorted(set(df['peptide'].astype(str))) all_hlas = sorted(set(df['HLA_full'].astype(str))) self.max_pep_len = max(len(p) for p in all_peps) print(f"\nPreparing embeddings:") print(f" - TCRα: {len(all_tcra)} | TCRβ: {len(all_tcrb)} | peptides: {len(all_peps)} | HLAs: {len(all_hlas)}\n") self.phys_dict = { 'tcra': self._encode_phys(all_tcra), 'tcrb': self._encode_phys(all_tcrb), 'pep': self._encode_phys(all_peps), 'hla': self._encode_phys(all_hlas) } self.esm2_dict = { 'tcra': self._encode_esm2(all_tcra, prefix='tcra', re_embed=force_recompute), 'tcrb': self._encode_esm2(all_tcrb, prefix='tcrb', re_embed=force_recompute), 'pep': self._encode_esm2(all_peps, prefix='pep', re_embed=force_recompute), 'hla': self._encode_esm2(all_hlas, prefix='hla', re_embed=force_recompute) } # Move everything in phys_dict and esm2_dict to CPU for d in [self.phys_dict, self.esm2_dict]: for k1 in d.keys(): # tcra / tcrb / pep / hla for k2 in d[k1].keys(): # actual sequences if torch.is_tensor(d[k1][k2]): d[k1][k2] = d[k1][k2].cpu() torch.cuda.empty_cache() # IMPORTANT: include peptide structure via ESMFold self.struct_dict = { 'tcra': self._encode_structure(all_tcra, prefix='tcra', re_embed=force_recompute), 'tcrb': self._encode_structure(all_tcrb, prefix='tcrb', re_embed=force_recompute), 'pep': self._encode_structure(all_peps, prefix='pep', re_embed=force_recompute), 'hla': self._encode_structure(all_hlas, prefix='hla', re_embed=force_recompute) } print("✓ Embeddings prepared for TCRα/β, peptide (with ESMFold), and HLA.") # Move structure features to CPU for part in ['tcra', 'tcrb', 'pep', 'hla']: for seq, (feat, coord) in self.struct_dict[part].items(): self.struct_dict[part][seq] = (feat.cpu(), coord.cpu()) torch.cuda.empty_cache() print("✓ All embeddings moved to CPU, GPU memory released.") # -------------------- public: precompute pHLA features -------------------- def prepare_pep_hla_features(self, df: pd.DataFrame): """ Precompute peptide-HLA features using pretrained PeptideHLABindingPredictor. The resulting features are stored in self.pep_hla_feat_dict and later injected into each batch. """ assert self.phys_dict is not None and self.esm2_dict is not None and self.struct_dict is not None, \ "Call prepare_embeddings() first." pairs = {(row['peptide'], row['HLA_full']) for _, row in df.iterrows()} self.pep_hla_model.eval() for p in self.pep_hla_model.parameters(): p.requires_grad = False print(f"\nPrecomputing peptide-HLA features for {len(pairs)} unique pairs...") with torch.no_grad(): for pep, hla in tqdm(pairs, desc="pHLA features"): pep_phys = self.phys_dict['pep'][pep].unsqueeze(0).to(self.device) pep_esm = self.esm2_dict['pep'][pep].unsqueeze(0).to(self.device) # If your PeptideHLABindingPredictor supports peptide structure, pass it too: pep_struct, pep_coord = self.struct_dict['pep'][pep] pep_struct = pep_struct.unsqueeze(0).to(self.device) pep_coord = pep_coord.unsqueeze(0).to(self.device) hla_phys = self.phys_dict['hla'][hla].unsqueeze(0).to(self.device) hla_esm = self.esm2_dict['hla'][hla].unsqueeze(0).to(self.device) hla_struct, hla_coord = self.struct_dict['hla'][hla] hla_struct = hla_struct.unsqueeze(0).to(self.device) hla_coord = hla_coord.unsqueeze(0).to(self.device) # NOTE: encode_peptide_hla must accept (pep_struct, pep_coord) if you upgraded it; # otherwise remove those two args. pep_feat, hla_feat = self.pep_hla_model.encode_peptide_hla( pep, pep_phys, pep_esm, hla_phys, hla_esm, hla_struct, hla_coord, max_pep_len=self.max_pep_len ) self.pep_hla_feat_dict[(pep, hla)] = { 'pep_feat_pretrain': pep_feat.squeeze(0).cpu(), # [Lp, pep_dim] 'hla_feat_pretrain': hla_feat.squeeze(0).cpu() # [Lh, hla_dim] } print("✓ Pretrained peptide-HLA features prepared.") # -------------------- training -------------------- def train( self, df_train: pd.DataFrame, df_val: Optional[pd.DataFrame] = None, df_test: Optional[pd.DataFrame] = None, df_neg: Optional[pd.DataFrame] = None, epochs: int = 100, batch_size: int = 128, lr: float = 1e-4, patience: int = 5, num_workers: int = 8, ): """ Train the TCR-pHLA model. Args: df_train: Training data. df_val: Optional validation data. df_test: Optional test data for evaluation after each epoch. df_neg: Optional negative samples for training. Set when resample_negatives=True. epochs: Number of epochs. batch_size: Batch size. lr: Learning rate. patience: Early stopping patience. num_workers: Data loading workers. Returns: history: Dict containing training and validation metrics. """ # ---- Prepare embeddings ---- print("Preparing peptide-HLA features...") all_dfs = [df for df in [df_train, df_val, df_test, df_neg] if df is not None] self.prepare_pep_hla_features(pd.concat(all_dfs, axis=0)) # ---- Validation loader (optional) ---- if df_val is not None: val_ds = TCRPepHLA_Dataset(df_val, self.phys_dict, self.esm2_dict, self.struct_dict, self.pep_hla_feat_dict) val_loader = torch.utils.data.DataLoader( val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=tcr_pep_hla_collate_fn, pin_memory=True ) stopper = EarlyStopping(patience=patience, save_path=self.model_save_path) else: val_loader, stopper = None, None # ---- Optimizer ---- optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr) # ---- Metric history ---- history = {'train_loss': [], 'train_auc': []} if df_val is not None: history.update({'val_loss': [], 'val_auc': [], 'val_prc': []}) print("\nStart training TCR–pHLA model...") df_train_pos = df_train[df_train['label'] == 1].copy().reset_index(drop=True) for epoch in range(epochs): # ---------- Training ---------- if self.resample_negatives: df_train_neg = negative_sampling_phla(df_train_pos, random_state=epoch) if df_neg is not None: df_train_neg = pd.concat([df_train_neg, df_neg], axis=0).reset_index(drop=True) df_train_resample = pd.concat([df_train_pos, df_train_neg], axis=0).reset_index(drop=True) train_ds = TCRPepHLA_Dataset(df_train_resample, self.phys_dict, self.esm2_dict, self.struct_dict, self.pep_hla_feat_dict) else: train_ds = TCRPepHLA_Dataset(df_train, self.phys_dict, self.esm2_dict, self.struct_dict, self.pep_hla_feat_dict) train_loader = torch.utils.data.DataLoader( train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=tcr_pep_hla_collate_fn, pin_memory=True ) self.model.train() train_labels, train_preds = [], [] epoch_loss = 0.0 for ibatch, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]")): optimizer.zero_grad() probs, loss, _, _ = self.model(batch) loss.backward() torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=2.0) optimizer.step() epoch_loss += loss.item() train_labels.extend(batch['label'].cpu().numpy().tolist()) train_preds.extend(probs.detach().cpu().numpy().tolist()) train_auc = roc_auc_score(train_labels, train_preds) train_loss = epoch_loss / (ibatch + 1) history['train_loss'].append(train_loss) history['train_auc'].append(train_auc) print(f"Epoch {epoch+1}/{epochs} | Train Loss: {train_loss:.4f} | Train AUC: {train_auc:.4f}") # ---------- Validation ---------- if df_val is not None: self.model.eval() val_loss_sum, val_labels, val_preds = 0.0, [], [] with torch.no_grad(): for batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} [Val]"): probs, loss, _, _ = self.model(batch) val_loss_sum += loss.item() val_labels.extend(batch['label'].cpu().numpy().tolist()) val_preds.extend(probs.detach().cpu().numpy().tolist()) val_loss = val_loss_sum / len(val_loader) val_auc = roc_auc_score(val_labels, val_preds) val_prc = average_precision_score(val_labels, val_preds) history['val_loss'].append(val_loss) history['val_auc'].append(val_auc) history['val_prc'].append(val_prc) print(f"Epoch {epoch+1}/{epochs} | Val AUC: {val_auc:.4f} | Val PRC: {val_prc:.4f} | Val Loss: {val_loss:.4f}") stopper(val_auc, self.model) if stopper.early_stop: print(f"Early stopping at epoch {epoch+1}") break # ---------- Optional Test ---------- if df_test is not None: test_ds = TCRPepHLA_Dataset(df_test, self.phys_dict, self.esm2_dict, self.struct_dict, self.pep_hla_feat_dict) test_loader = torch.utils.data.DataLoader( test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=tcr_pep_hla_collate_fn, pin_memory=True ) self.model.eval() test_labels, test_preds = [], [] with torch.no_grad(): for batch in tqdm(test_loader, desc=f"Epoch {epoch+1}/{epochs} [Test]"): probs, _, _, _ = self.model(batch) test_labels.extend(batch['label'].cpu().numpy().tolist()) test_preds.extend(probs.detach().cpu().numpy().tolist()) test_auc = roc_auc_score(test_labels, test_preds) test_prc = average_precision_score(test_labels, test_preds) print(f"Epoch {epoch+1}/{epochs} | Test AUC: {test_auc:.4f} | Test PRC: {test_prc:.4f}") # ---- Load best model only if validation used ---- if df_val is not None and os.path.exists(self.model_save_path): self.model.load_state_dict(torch.load(self.model_save_path, map_location=self.device)) print(f"✓ Training finished. Best model loaded from {self.model_save_path}") else: print("✓ Training finished (no validation set used).") return history def train_kfold( self, train_folds: List[Tuple[pd.DataFrame, pd.DataFrame]], df_test: Optional[pd.DataFrame] = None, df_neg: Optional[pd.DataFrame] = None, epochs: int = 100, batch_size: int = 128, lr: float = 1e-4, patience: int = 8, num_workers: int = 8, ) -> List[Dict[str, List[float]]]: """ K-fold cross-validation training for TCR-pHLA model. Args: train_folds: list of (train_df, val_df) tuples for each fold df_test: optional test data for evaluation after each epoch df_neg: optional negative samples for training. Set when resample_negatives=True. epochs: training epochs batch_size: batch size lr: learning rate patience: early stopping patience num_workers: dataloader workers Returns: List of training histories for each fold """ num_folds = len(train_folds) all_histories = [] print("\n" + "=" * 70) print(f"Starting {num_folds}-Fold Cross-Validation Training (TCR-pHLA)") print("=" * 70) for fold_id, (df_train, df_val) in enumerate(train_folds): print(f"\n{'='*70}") print(f"Training Fold {fold_id+1}/{num_folds}") print(f"{'='*70}") self._set_seed(self.seed + fold_id) self.model = TCRPeptideHLABindingPredictor( tcr_dim=self.model.tcr_dim, pep_dim=self.model.pep_dim, hla_dim=self.model.hla_dim, bilinear_dim=self.model.bilinear_dim, loss_fn=self.loss_fn_name, alpha=self.alpha, gamma=self.gamma, pos_weights=self.pos_weights, device=str(self.device), ).to(self.device) fold_save_path = self.model_save_path.replace(".pt", f"_fold{fold_id}.pt") history = self.train( df_train=df_train, df_val=df_val, df_test=df_test, df_neg=df_neg, epochs=epochs, batch_size=batch_size, lr=lr, patience=patience, num_workers=num_workers, ) torch.save(self.model.state_dict(), fold_save_path) print(f"✓ Saved fold {fold_id} model to {fold_save_path}") all_histories.append(history) print("\n" + "=" * 70) print(f"✓ All {num_folds} folds training completed (TCR-pHLA)") print("=" * 70) if df_val is not None: print("\nCross-Validation Summary:") print("-" * 70) for fold_id, hist in enumerate(all_histories): best_auc = max(hist['val_auc']) best_prc = max(hist['val_prc']) best_epoch = hist['val_auc'].index(best_auc) + 1 print(f"Fold {fold_id}: Best Val AUC = {best_auc:.4f}, Best Val PRC = {best_prc:.4f}, (Epoch {best_epoch})") mean_auc = np.mean([max(h['val_auc']) for h in all_histories]) std_auc = np.std([max(h['val_auc']) for h in all_histories]) print("-" * 70) print(f"Mean Val AUC: {mean_auc:.4f} ± {std_auc:.4f}") print("=" * 70 + "\n") return all_histories # -------------------- single-set predict -------------------- def _predict_single( self, df: pd.DataFrame, batch_size: int = 128, return_probs: bool = True, num_workers: int = 8 ): self.model.eval() ds = TCRPepHLA_Dataset(df, self.phys_dict, self.esm2_dict, self.struct_dict, self.pep_hla_feat_dict) loader = torch.utils.data.DataLoader( ds, batch_size=batch_size, shuffle=False, collate_fn=tcr_pep_hla_collate_fn, num_workers=num_workers, pin_memory=True ) preds = [] pep_feat_all = [] attn_all = [] with torch.no_grad(): for batch in tqdm(loader, desc="Predicting (TCR-pHLA)"): probs, _, pep_feature, attn_dict = self.model(batch) preds.extend(probs.tolist()) pep_feat_all.append(pep_feature) attn_all.append(attn_dict) preds = np.array(preds) if not return_probs: preds = (preds >= 0.5).astype(int) return preds, pep_feat_all, attn_all # ================================================================ # Ensemble prediction # ================================================================ def _predict_ensemble( self, df: pd.DataFrame, batch_size: int = 128, num_folds: int = 5, ensemble_method: str = 'mean', return_probs: bool = True, num_workers: int = 8 ) -> np.ndarray: """ Ensemble prediction using multiple fold models. """ print(f"\nEnsemble prediction using {num_folds} TCR–pHLA models...") print(f"Ensemble method: {ensemble_method}") pep_feats_folds = [] attn_dict_folds = [] all_preds = [] for fold_id in range(num_folds): fold_model_path = self.model_save_path.replace(".pt", f"_fold{fold_id}.pt") if not os.path.exists(fold_model_path): print(f"⚠ Warning: {fold_model_path} not found, skipping...") continue print(f"Loading model from {fold_model_path}...") self.model.load_state_dict(torch.load(fold_model_path, map_location=self.device), strict=False) # Predict for this fold fold_preds, fold_pep_feature, fold_attn_dict = self._predict_single( df, batch_size=batch_size, return_probs=True, num_workers=num_workers ) all_preds.append(fold_preds) pep_feats_folds.append(fold_pep_feature) attn_dict_folds.append(fold_attn_dict) if len(all_preds) == 0: raise ValueError("No fold models found!") if ensemble_method == 'mean': ensemble_preds = np.mean(all_preds, axis=0) elif ensemble_method == 'median': ensemble_preds = np.median(all_preds, axis=0) else: raise ValueError(f"Unknown ensemble method: {ensemble_method}") print(f"✓ Ensemble prediction completed using {len(all_preds)} folds") if not return_probs: ensemble_preds = (ensemble_preds >= 0.5).astype(int) return ensemble_preds, pep_feats_folds, attn_dict_folds # ================================================================ # Unified predict() with ensemble support # ================================================================ def predict( self, df: pd.DataFrame, batch_size: int = 128, return_probs: bool = True, use_kfold: bool = False, num_folds: Optional[int] = None, ensemble_method: str = 'mean', num_workers: int = 8 ) -> Tuple[np.ndarray, List, List]: """ Predict binding probabilities or binary labels. If use_kfold=True, averages predictions across fold models. """ print('Preparing peptide-HLA features for prediction set...') self.prepare_pep_hla_features(df) if use_kfold: if num_folds is None: raise ValueError("num_folds must be specified when use_kfold=True") return self._predict_ensemble( df=df, batch_size=batch_size, num_folds=num_folds, ensemble_method=ensemble_method, return_probs=return_probs, num_workers=num_workers ) else: return self._predict_single(df, batch_size=batch_size, return_probs=return_probs, num_workers=num_workers) # ================================================================ # Unified evaluate() with ensemble support # ================================================================ def evaluate( self, df: pd.DataFrame, batch_size: int = 128, threshold: float = 0.5, use_kfold: bool = False, num_folds: Optional[int] = None, ensemble_method: str = 'mean', num_workers: int = 8 ) -> Dict[str, float]: """ Evaluate model performance on a dataset. If use_kfold=True, performs ensemble evaluation across folds. """ y_true = df['label'].values y_prob, all_pep_features, merged_attn = self.predict( df, batch_size=batch_size, return_probs=True, use_kfold=use_kfold, num_folds=num_folds, ensemble_method=ensemble_method, num_workers=num_workers ) y_pred = (y_prob >= threshold).astype(int) tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel().tolist() accuracy = (tp + tn) / (tn + fp + fn + tp + 1e-9) try: mcc = ((tp*tn) - (fn*fp)) / np.sqrt(float((tp+fn)*(tn+fp)*(tp+fp)*(tn+fn)) + 1e-9) except: mcc = 0.0 recall = tp / (tp + fn + 1e-9) precision = tp / (tp + fp + 1e-9) f1 = 2 * precision * recall / (precision + recall + 1e-9) try: auc = roc_auc_score(y_true, y_prob, max_fpr=0.1) except: auc = 0.0 print("\n" + "=" * 70) print(f"Evaluation Results [{'K-Fold Ensemble' if use_kfold else 'Single Model'}]") print("=" * 70) print(f"tn={tn}, fp={fp}, fn={fn}, tp={tp}") print(f"AUC={auc:.4f} | ACC={accuracy:.4f} | MCC={mcc:.4f} | F1={f1:.4f} | P={precision:.4f} | R={recall:.4f}") print("=" * 70 + "\n") return dict( auc=auc, accuracy=accuracy, mcc=mcc, f1=f1, precision=precision, recall=recall, tn=tn, fp=fp, fn=fn, tp=tp )