import csv import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from typing import Dict, List, Tuple from torch.nn.utils.parametrizations import weight_norm from torch.nn import TransformerEncoder, TransformerEncoderLayer import esm import pandas as pd from tqdm import tqdm from typing import Dict, List, Tuple import tempfile from pathlib import Path import mdtraj as md # import io # import gzip import os from egnn_pytorch import EGNN from transformers import AutoTokenizer, EsmForProteinFolding import logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # from re import search as re_search import re def determine_tcr_seq_vj(cdr3,V,J,chain,guess01=False): def file2dict(filename,key_fields,store_fields,delimiter='\t'): """Read file to a dictionary. key_fields: fields to be used as keys store_fields: fields to be saved as a list delimiter: delimiter used in the given file.""" dictionary={} with open(filename, newline='') as csvfile: reader = csv.DictReader(csvfile,delimiter=delimiter) for row in reader: keys = [row[k] for k in key_fields] store= [row[s] for s in store_fields] sub_dict = dictionary for key in keys[:-1]: if key not in sub_dict: sub_dict[key] = {} sub_dict = sub_dict[key] key = keys[-1] if key not in sub_dict: sub_dict[key] = [] sub_dict[key].append(store) return dictionary def get_protseqs_ntseqs(chain='B'): """returns sequence dictioaries for genes: protseqsV, protseqsJ, nucseqsV, nucseqsJ""" seq_dicts=[] for gene,type in zip(['v','j','v','j'],['aa','aa','nt','nt']): name = 'library/'+'tr'+chain.lower()+gene+'s_'+type+'.tsv' sdict = file2dict(name,key_fields=['Allele'],store_fields=[type+'_seq']) for g in sdict: sdict[g]=sdict[g][0][0] seq_dicts.append(sdict) return seq_dicts protVb,protJb,_,_ = get_protseqs_ntseqs(chain='B') protVa,protJa,_,_ = get_protseqs_ntseqs(chain='A') def splice_v_cdr3_j(pv: str, pj: str, cdr3: str) -> str: """ pv: V gene protein sequence pj: J gene protein sequence cdr3: C-starting, F/W-ending CDR3 sequence (protein) Returns: The spliced full sequence (V[:lastC] + CDR3 + J suffix) """ pv = (pv or "").strip().upper() pj = (pj or "").strip().upper() cdr3 = (cdr3 or "").strip().upper() # 1) V segment: Take the last 'C' (including the conserved C in V region) cpos = pv.rfind('C') if cpos == -1: raise ValueError("V sequence has no 'C' to anchor CDR3 start.") v_prefix = pv[:cpos] # up to and including C # 2) Align CDR3's "end overlap" in J # Start from the full length of cdr3, gradually shorten it, and find the longest suffix that can match in J j_suffix = pj # fallback (in extreme cases) for k in range(len(cdr3), 0, -1): tail = cdr3[-k:] # CDR3's suffix m = re.search(re.escape(tail), pj) if m: j_suffix = pj[m.end():] # Take the suffix from the matching segment break return v_prefix + cdr3 + j_suffix tcr_list = [] for i in range(len(cdr3)): cdr3_ = cdr3[i] V_ = V[i] J_ = J[i] if chain=='A': protseqsV = protVa protseqsJ = protJa else: protseqsV = protVb protseqsJ = protJb if guess01: if '*' not in V_: V_+='*01' if '*' not in J_: J_+='*01' pv = protseqsV[V_] pj = protseqsJ[J_] # t = pv[:pv.rfind('C')]+ cdr3_ + pj[re_search(r'[FW]G.[GV]',pj).start()+1:] t = splice_v_cdr3_j(pv, pj, cdr3_) tcr_list.append(t) return tcr_list # def negative_sampling_phla(df, neg_ratio=5, label_col='label', neg_label=0, random_state=42): # """ # Create negative samples by shuffling the TCR sequences while keeping the peptide-HLA pairs intact. 
# Ensures that the generated (TCR, peptide, HLA) triplets do not exist in the original dataset. # """ # negative_samples = [] # # 正样本 triplet 集合 # pos_triplets = set(zip( # df['tcra'], df['tcrb'], df['peptide'], df['HLA_full'] # )) # for i in range(neg_ratio): # shuffled_df = df.copy() # tcr_cols = ['tcra', 'cdr3a_start', 'cdr3a_end', 'tcrb', 'cdr3b_start', 'cdr3b_end'] # shuffled_tcr = df[tcr_cols].sample(frac=1, random_state=random_state + i).reset_index(drop=True) # for col in tcr_cols: # shuffled_df[col] = shuffled_tcr[col] # # 剔除:1) TCR 未改变的行 2) triplet 与正样本重复 # mask_keep = [] # for idx, row in shuffled_df.iterrows(): # triplet = (row['tcra'], row['tcrb'], row['peptide'], row['HLA_full']) # if triplet in pos_triplets: # mask_keep.append(False) # else: # mask_keep.append(True) # shuffled_df = shuffled_df[mask_keep] # shuffled_df[label_col] = neg_label # negative_samples.append(shuffled_df) # negative_samples = pd.concat(negative_samples, ignore_index=True).drop_duplicates() # return negative_samples import numpy as np import pandas as pd # def balanced_negative_sampling_phla(df, label_col='label', neg_label=0, random_state=42): # """ # 为每个 (peptide, HLA_full) 平衡采样负样本: # - 找出正样本最多的 peptide # - 该 peptide 的负样本数量 = 1:1,从其他 peptide 的 TCR 中采样(保持 peptide–HLA 配对) # - 其他 peptide 采样负样本,使每个 peptide 拥有相同总样本数 # - 保证 peptide 与 HLA_full 始终保持配对关系 # """ # np.random.seed(random_state) # pos_df = df[df[label_col] != neg_label].copy() # pos_counts = pos_df['peptide'].value_counts() # max_peptide = pos_counts.idxmax() # max_pos = pos_counts.max() # total_target = max_pos * 2 # 每个 peptide 的最终样本数(正+负) # neg_samples = [] # # 针对 max_peptide:负样本 = 1:1 # df_other_tcrs = pos_df[pos_df['peptide'] != max_peptide][['tcra', 'tcrb', 'cdr3a_start', 'cdr3a_end', 'cdr3b_start', 'cdr3b_end']].copy() # neg_max = pos_df[pos_df['peptide'] == max_peptide].copy() # sampled_tcrs = df_other_tcrs.sample( # n=max_pos, # replace=True if len(df_other_tcrs) < max_pos else False, # random_state=random_state # ).reset_index(drop=True) # neg_max.update(sampled_tcrs) # neg_max[label_col] = neg_label # neg_samples.append(neg_max) # # 针对其他 peptides # for pep, n_pos in pos_counts.items(): # if pep == max_peptide: # continue # n_neg = max(0, total_target - n_pos) # df_other_tcrs = pos_df[pos_df['peptide'] != pep][['tcra', 'tcrb', 'cdr3a_start', 'cdr3a_end', 'cdr3b_start', 'cdr3b_end']].copy() # neg_pep = pos_df[pos_df['peptide'] == pep].copy() # sampled_tcrs = df_other_tcrs.sample( # n=min(len(df_other_tcrs), n_neg), # replace=True if len(df_other_tcrs) < n_neg else False, # random_state=random_state # ).reset_index(drop=True) # sampled_tcrs = sampled_tcrs.iloc[:len(neg_pep)].copy() if len(sampled_tcrs) > len(neg_pep) else sampled_tcrs # neg_pep = pd.concat( # [neg_pep]*int(np.ceil(n_neg / len(neg_pep))), ignore_index=True # ).iloc[:n_neg] # neg_pep.update(sampled_tcrs) # neg_pep[label_col] = neg_label # neg_samples.append(neg_pep) # neg_df = pd.concat(neg_samples, ignore_index=True) # final_df = pd.concat([pos_df, neg_df], ignore_index=True).reset_index(drop=True) # return final_df def negative_sampling_phla(df, neg_ratio=5, label_col='label', neg_label=0, random_state=42): """ Create negative samples by shuffling TCRs while keeping peptide–HLA pairs intact. Ensures negative samples count = neg_ratio × positive samples count. 
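
    Example (illustrative sketch; assumes `df_pos` uses the column names referenced
    in this module: 'tcra', 'tcrb', 'cdr3a_start', 'cdr3a_end', 'cdr3b_start',
    'cdr3b_end', 'peptide', 'HLA_full', 'label'):

        neg_df = negative_sampling_phla(df_pos, neg_ratio=1)
        train_df = pd.concat([df_pos, neg_df], ignore_index=True)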
""" np.random.seed(random_state) pos_triplets = set(zip(df['tcra'], df['tcrb'], df['peptide'], df['HLA_full'])) tcr_cols = ['tcra', 'cdr3a_start', 'cdr3a_end', 'tcrb', 'cdr3b_start', 'cdr3b_end'] n_pos = len(df) target_n_neg = n_pos * neg_ratio all_neg = [] i = 0 while len(all_neg) < target_n_neg: shuffled_df = df.copy() shuffled_tcr = df[tcr_cols].sample(frac=1, random_state=random_state + i).reset_index(drop=True) for col in tcr_cols: shuffled_df[col] = shuffled_tcr[col] mask_keep = [] for idx, row in shuffled_df.iterrows(): triplet = (row['tcra'], row['tcrb'], row['peptide'], row['HLA_full']) mask_keep.append(triplet not in pos_triplets) shuffled_df = shuffled_df[mask_keep] shuffled_df[label_col] = neg_label all_neg.append(shuffled_df) i += 1 if len(pd.concat(all_neg)) > target_n_neg * 1.5: break negative_samples = pd.concat(all_neg, ignore_index=True).drop_duplicates() negative_samples = negative_samples.sample( n=min(len(negative_samples), target_n_neg), random_state=random_state ).reset_index(drop=True) return negative_samples # def negative_sampling_tcr(df, neg_ratio=5, label_col='label', neg_label=0, random_state=42): # """ # Create negative samples by keeping TCR fixed but assigning random (peptide, HLA_full) # pairs that do not exist in the original dataset. # Ensures that the generated (TCR, peptide, HLA) triplets do not exist in the original data. # """ # np.random.seed(random_state) # negative_samples = [] # pos_triplets = set(zip(df['tcra'], df['tcrb'], df['peptide'], df['HLA_full'])) # all_pairs = list(set(zip(df['peptide'], df['HLA_full']))) # for i in range(neg_ratio): # neg_df = df.copy() # # 随机打乱 peptide–HLA 对,但保证不会选原来的那一个 # new_pairs = [] # for _, row in df.iterrows(): # while True: # pep, hla = all_pairs[np.random.randint(len(all_pairs))] # triplet = (row['tcra'], row['tcrb'], pep, hla) # if triplet not in pos_triplets: # new_pairs.append((pep, hla)) # break # neg_df[['peptide', 'HLA_full']] = pd.DataFrame(new_pairs, index=neg_df.index) # neg_df[label_col] = neg_label # negative_samples.append(neg_df) # negative_samples = pd.concat(negative_samples, ignore_index=True).drop_duplicates() # return negative_samples class EarlyStopping: def __init__(self, patience=10, verbose=True, delta=0.0, save_path='checkpoint.pt'): """ Early stopping based on both val_loss and val_auc. The model is saved whenever EITHER: - val_loss decreases by more than delta, OR - val_auc increases by more than delta. 
""" self.patience = patience self.verbose = verbose self.counter = 0 self.early_stop = False self.delta = delta self.save_path = save_path self.best_loss = np.inf self.best_auc = -np.inf def __call__(self, val_auc, model): improved = False # Check auc improvement if val_auc > self.best_auc + self.delta: self.best_auc = val_auc improved = True if improved: self.save_checkpoint(model, val_auc) self.counter = 0 else: self.counter += 1 if self.verbose: print(f"EarlyStopping counter: {self.counter} out of {self.patience}") if self.counter >= self.patience: self.early_stop = True def save_checkpoint(self, model, val_auc): """Save current best model.""" if self.verbose: print(f"Validation improved → Saving model (Score={val_auc:.4f}) to {self.save_path}") torch.save(model.state_dict(), self.save_path) # ============================================================================ # ESM2 Embedding via HuggingFace # ============================================================================ class ESM2Encoder(nn.Module): def __init__(self, device="cuda:0", layer=33, cache_dir='/data/cache'): """ Initialize an ESM2 encoder. Args: model_name (str): Name of the pretrained ESM2 model (e.g., 'esm2_t33_650M_UR50D'). device (str): Device to run on, e.g. 'cuda:0', 'cuda:1', or 'cpu'. layer (int): Layer number from which to extract representations. """ super().__init__() self.device = device self.layer = layer if cache_dir is None: cache_dir = os.path.dirname(os.path.abspath(__file__)) self.cache_dir = cache_dir os.makedirs(self.cache_dir, exist_ok=True) self.model, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D() self.batch_converter = self.alphabet.get_batch_converter() self.model = self.model.eval().to(device) def _cache_path(self, prefix): base_dir = os.path.dirname(os.path.abspath(__file__)) base_dir = base_dir + "/" + self.cache_dir os.makedirs(base_dir, exist_ok=True) return os.path.join(base_dir, f"{prefix}_esm2_layer{self.layer}.pt") def save_obj(self, obj, path): """Save object to a file (no compression).""" torch.save(obj, path) def load_obj(self, path): """Load object from a file (no compression).""" return torch.load(path, map_location="cpu", weights_only=False) @torch.no_grad() def _embed_batch(self, batch_data): batch_labels, batch_strs, batch_tokens = self.batch_converter(batch_data) batch_tokens = batch_tokens.to(self.device) results = self.model(batch_tokens, repr_layers=[self.layer], return_contacts=False) token_representations = results["representations"][self.layer] batch_lens = (batch_tokens != self.alphabet.padding_idx).sum(1) seq_reprs = [] for i, tokens_len in enumerate(batch_lens): seq_repr = token_representations[i, 1:tokens_len-1].cpu() seq_reprs.append(seq_repr) return seq_reprs @torch.no_grad() def forward(self, df, seq_col, prefix, batch_size=64, re_embed=False, cache_save=True): """ Add or update embeddings for sequences in a DataFrame. - If there are new sequences, automatically update the dictionary and save. - If re_embed=True, force re-computation of all sequences. 
""" cache_path = self._cache_path(prefix) emb_dict = {} if os.path.exists(cache_path) and not re_embed: print(f"[ESM2] Loading cached embeddings from {cache_path}") emb_dict = self.load_obj(cache_path) else: if re_embed: print(f"[ESM2] Re-embedding all sequences for {prefix}") else: print(f"[ESM2] No existing cache for {prefix}, will create new.") seqs = [str(s).strip().upper() for s in df[seq_col].tolist() if isinstance(s, str)] unique_seqs = sorted(set(seqs)) new_seqs = [s for s in unique_seqs if s not in emb_dict] if new_seqs: print(f"[ESM2] Found {len(new_seqs)} new sequences → computing embeddings...") data = [(str(i), s) for i, s in enumerate(new_seqs)] for i in tqdm(range(0, len(data), batch_size), desc=f"ESM2 update ({prefix})"): batch = data[i:i+batch_size] embs = self._embed_batch(batch) for (_, seq), emb in zip(batch, embs): emb_dict[seq] = emb.clone() if cache_save: print(f"[ESM2] Updating cache with new sequences") self.save_obj(emb_dict, cache_path) else: print(f"[ESM2] No new sequences for {prefix}, using existing cache") return emb_dict # ============================================================================ # ESMFold (transformers) # ============================================================================ class ESMFoldPredictorHF(nn.Module): def __init__(self, model_name="facebook/esmfold_v1", cache_dir=None, device='cpu', allow_tf32=True): super().__init__() self.model_name = model_name self.cache_dir = cache_dir self.device = device if allow_tf32: torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True # tokenizer and model print(f"Loading ESMFold model {model_name} on {device}... {'with' if cache_dir else 'without'} cache_dir: {cache_dir}") self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir) self.model = EsmForProteinFolding.from_pretrained( model_name, low_cpu_mem_usage=True, cache_dir=cache_dir ).eval().to(self.device) @torch.no_grad() def infer_pdb_str(self, seq: str) -> str: pdb_str = self.model.infer_pdb(seq) return pdb_str @torch.no_grad() def forward_raw(self, seq: str): inputs = self.tokenizer([seq], return_tensors="pt", add_special_tokens=False) inputs = {k: v.to(self.device) for k, v in inputs.items()} outputs = self.model(**inputs) return outputs # ESMFoldOutput MAX_ASA_TIEN = { "ALA": 129.0, "ARG": 274.0, "ASN": 195.0, "ASP": 193.0, "CYS": 167.0, "GLN": 225.0, "GLU": 223.0, "GLY": 104.0, "HIS": 224.0, "ILE": 197.0, "LEU": 201.0, "LYS": 236.0, "MET": 224.0, "PHE": 240.0, "PRO": 159.0, "SER": 155.0, "THR": 172.0, "TRP": 285.0, "TYR": 263.0, "VAL": 174.0, } SS8_INDEX = {"H":0,"B":1,"E":2,"G":3,"I":4,"T":5,"S":6,"C":7,"-":7} class StructureFeatureExtractorNoDSSP(nn.Module): def __init__(self, device="cpu"): super().__init__() self.device = device self.in_dim = 6 + 8 + 1 + 1 + 1 # 17 self.to(torch.device(self.device)) @torch.no_grad() def _angles(self, traj): L = traj.n_residues sphi = np.zeros(L, dtype=np.float32); cphi = np.zeros(L, dtype=np.float32) spsi = np.zeros(L, dtype=np.float32); cpsi = np.zeros(L, dtype=np.float32) someg = np.zeros(L, dtype=np.float32); comeg = np.zeros(L, dtype=np.float32) # 1) phi: (C_{i-1}, N_i, CA_i, C_i) —— 当前残基 i 可用 atoms[1] (N_i) 来定位 phi_idx, phi_vals = md.compute_phi(traj) # phi_vals: (1, n_phi) if phi_vals.size > 0: for k, atoms in enumerate(phi_idx): res_i = traj.topology.atom(int(atoms[1])).residue.index # N_i 所在残基 if 0 <= res_i < L: ang = float(phi_vals[0, k]) sphi[res_i] = np.sin(ang); cphi[res_i] = np.cos(ang) # 2) psi: (N_i, CA_i, C_i, N_{i+1}) —— 当前残基 i 
可用 atoms[1] (CA_i) psi_idx, psi_vals = md.compute_psi(traj) if psi_vals.size > 0: for k, atoms in enumerate(psi_idx): res_i = traj.topology.atom(int(atoms[1])).residue.index # CA_i if 0 <= res_i < L: ang = float(psi_vals[0, k]) spsi[res_i] = np.sin(ang); cpsi[res_i] = np.cos(ang) # 3) omega: (CA_i, C_i, N_{i+1}, CA_{i+1}) —— 当前残基 i 可用 atoms[0] (CA_i) omg_idx, omg_vals = md.compute_omega(traj) if omg_vals.size > 0: for k, atoms in enumerate(omg_idx): res_i = traj.topology.atom(int(atoms[0])).residue.index # CA_i if 0 <= res_i < L: ang = float(omg_vals[0, k]) someg[res_i] = np.sin(ang); comeg[res_i] = np.cos(ang) angles_feat = np.stack([sphi, cphi, spsi, cpsi, someg, comeg], axis=-1) # [L, 6] return angles_feat.astype(np.float32) @torch.no_grad() def _ss8(self, traj: md.Trajectory): ss = md.compute_dssp(traj, simplified=False)[0] L = traj.n_residues onehot = np.zeros((L, 8), dtype=np.float32) for i, ch in enumerate(ss): onehot[i, SS8_INDEX.get(ch, 7)] = 1.0 return onehot @torch.no_grad() def _rsa(self, traj: md.Trajectory): asa = md.shrake_rupley(traj, mode="residue")[0] # (L,) rsa = np.zeros_like(asa, dtype=np.float32) for i, res in enumerate(traj.topology.residues): max_asa = MAX_ASA_TIEN.get(res.name.upper(), None) rsa[i] = 0.0 if not max_asa else float(asa[i] / max_asa) return np.clip(rsa, 0.0, 1.0)[:, None] @torch.no_grad() def _contact_count(self, traj: md.Trajectory, cutoff_nm=0.8): L = traj.n_residues ca_atoms = traj.topology.select("name CA") if len(ca_atoms) == L: coors = traj.xyz[0, ca_atoms, :] # nm else: xyz = traj.xyz[0] coors = [] for res in traj.topology.residues: idxs = [a.index for a in res.atoms] coors.append(xyz[idxs, :].mean(axis=0)) coors = np.array(coors, dtype=np.float32) diff = coors[:, None, :] - coors[None, :, :] dist = np.sqrt((diff**2).sum(-1)) # nm mask = (dist < cutoff_nm).astype(np.float32) np.fill_diagonal(mask, 0.0) cnt = mask.sum(axis=1) return cnt[:, None].astype(np.float32) @torch.no_grad() def _plddt(self, pdb_file: str): # 用 Biopython 读取 PDB 的 B-factor(ESMFold/AlphaFold 会把 pLDDT 写在这里) from Bio.PDB import PDBParser import numpy as np parser = PDBParser(QUIET=True) structure = parser.get_structure("prot", pdb_file) model = structure[0] res_plddt = [] for chain in model: for residue in chain: atoms = list(residue.get_atoms()) if len(atoms) == 0: res_plddt.append(0.0) continue # 该残基原子 B-factor 的均值 bvals = [float(atom.get_bfactor()) for atom in atoms] res_plddt.append(float(np.mean(bvals))) # 归一化到 [0,1] plddt = np.array(res_plddt, dtype=np.float32) / 100.0 plddt = np.clip(plddt, 0.0, 1.0) return plddt[:, None] # [L,1] @torch.no_grad() def _parse_and_features(self, pdb_file: str): traj = md.load(pdb_file) L = traj.n_residues angles = self._angles(traj) # [L,6] ss8 = self._ss8(traj) # [L,8] rsa = self._rsa(traj) # [L,1] cnt = self._contact_count(traj) # [L,1] plddt = self._plddt(pdb_file) # [L,1] feats = np.concatenate([angles, ss8, rsa, cnt, plddt], axis=1).astype(np.float32) # [L,17] ca_atoms = traj.topology.select("name CA") if len(ca_atoms) == L: coors_nm = traj.xyz[0, ca_atoms, :] else: xyz = traj.xyz[0] res_coords = [] for res in traj.topology.residues: idxs = [a.index for a in res.atoms] res_coords.append(xyz[idxs, :].mean(axis=0)) coors_nm = np.array(res_coords, dtype=np.float32) coors_ang = coors_nm * 10.0 # nm -> Å return coors_ang.astype(np.float32), feats # [L,3], [L,17] @torch.no_grad() def forward(self, pdb_file: str): coors_ang, scalars = self._parse_and_features(pdb_file) coors = torch.tensor(coors_ang, dtype=torch.float32, device=self.device) 
# [N,3]
        scalars = torch.tensor(scalars, dtype=torch.float32, device=self.device)  # [N,17]
        return scalars, coors  # [N,17], [N,3]


class ResiduePipelineWithHFESM:
    def __init__(self, esm_model_name="facebook/esmfold_v1", cache_dir=None, esm_device='cpu', allow_tf32=True):
        self.esm = ESMFoldPredictorHF(esm_model_name, cache_dir, esm_device, allow_tf32)
        self.struct_encoder = StructureFeatureExtractorNoDSSP(device=esm_device)
        self.cache_dir = cache_dir

    @torch.no_grad()
    def __call__(self, seq: str, save_pdb_path: str = None) -> Tuple[torch.Tensor, torch.Tensor]:
        pdb_str = self.esm.infer_pdb_str(seq)
        if save_pdb_path is None:
            tmpdir = self.cache_dir if self.cache_dir is not None else tempfile.gettempdir()
            save_pdb_path = str(Path(tmpdir) / "esmfold_pred_fold5.pdb")
        Path(save_pdb_path).write_text(pdb_str)
        struct_emb, struct_coords = self.struct_encoder(save_pdb_path)
        return struct_emb, struct_coords


def sanitize_protein_seq(seq: str) -> str:
    if not isinstance(seq, str):
        return ""
    s = "".join(seq.split()).upper()
    allowed = set("ACDEFGHIKLMNPQRSTVWYXBZJUO")
    return "".join([c for c in s if c in allowed])


@torch.no_grad()
def batch_embed_to_dicts(
    df: pd.DataFrame,
    seq_col: str,
    pipeline,
    show_progress: bool = True,
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], List[Tuple[str, str]]]:
    """
    Returns:
    - emb_dict:   {seq -> z (torch.Tensor[L, D])}
    - coord_dict: {seq -> coords (torch.Tensor[L, 3])}
    - failures:   [(seq, err_msg), ...]
    """
    raw_list = df[seq_col].astype(str).tolist()
    seqs = []
    for s in raw_list:
        ss = sanitize_protein_seq(s)
        if ss:
            seqs.append(ss)
    uniq_seqs = sorted(set(seqs))
    logger.info(f"Total rows: {len(df)}, valid seqs: {len(seqs)}, unique: {len(uniq_seqs)}")
    emb_dict: Dict[str, torch.Tensor] = {}
    coord_dict: Dict[str, torch.Tensor] = {}
    failures: List[Tuple[str, str]] = []
    iterator = tqdm(uniq_seqs, desc="ESMfold Predicting structure...") if show_progress else uniq_seqs
    for seq in iterator:
        if seq in emb_dict:
            continue
        try:
            z_t, c_t = pipeline(seq)  # z: [L, D], coords: [L, 3] (torch.Tensor)
            emb_dict[seq] = z_t.detach().float().cpu()
            coord_dict[seq] = c_t.detach().float().cpu()
        except Exception as e:
            failures.append((seq, repr(e)))
            continue
    logger.info(f"[DONE] OK: {len(emb_dict)}, Failed: {len(failures)}")
    if failures:
        logger.error("[SAMPLE failures] %s", failures[:3])
    return emb_dict, coord_dict, failures


class ESMFoldEncoder(nn.Module):
    def __init__(self, model_name="facebook/esmfold_v1", esm_cache_dir="/data/esm_cache", cache_dir="/data/cache"):
        super(ESMFoldEncoder, self).__init__()
        self.model_name = model_name
        self.esm_cache_dir = esm_cache_dir
        self.cache_dir = cache_dir

    def save_obj(self, obj, path):
        """Save object to a file (no compression)."""
        torch.save(obj, path)

    def load_obj(self, path):
        """Load object from a file (no compression)."""
        return torch.load(path, map_location='cpu', weights_only=False)

    def load_esm_dict(self, device, df_data, chain, re_embed):
        def _clean_unique(series: pd.Series) -> list:
            cleaned = []
            for s in series.astype(str).tolist():
                ss = sanitize_protein_seq(s)
                if ss:
                    cleaned.append(ss)
            return sorted(set(cleaned))

        def _retry_embed_df(
            df: pd.DataFrame,
            chain: str,
            max_retries: int = 2,
            show_progress: bool = True,
        ):
            """
            Try to embed protein sequences with retries on failures.

            Args:
                df (pd.DataFrame): A DataFrame containing a column `chain` with sequences.
                chain (str): The column name containing the sequences (e.g., "alpha", "beta").
                pipeline: An embedding pipeline, should return (embedding, coords) for a sequence.
max_retries (int): Maximum number of retries for failed sequences. show_progress (bool): Whether to display tqdm progress bars. Returns: feat_dict (Dict[str, torch.Tensor]): {sequence -> embedding tensor [L, D]}. coord_dict (Dict[str, torch.Tensor]): {sequence -> coordinate tensor [L, 3]}. failures (List[Tuple[str, str]]): List of (sequence, error_message) that still failed after retries. """ pipeline = ResiduePipelineWithHFESM( esm_model_name=self.model_name, cache_dir=self.esm_cache_dir, esm_device=device ) # 1. First attempt feat_dict, coord_dict, failures = batch_embed_to_dicts( df, chain, pipeline, show_progress=show_progress ) # 2. Retry loop for failed sequences tries = 0 while failures and tries < max_retries: tries += 1 retry_seqs = [s for s, _ in failures] logger.info(f"[retry {tries}/{max_retries}] {len(retry_seqs)} sequences") retry_df = pd.DataFrame({chain: retry_seqs}) f2, c2, failures = batch_embed_to_dicts( retry_df, chain, pipeline, show_progress=show_progress ) feat_dict.update(f2) coord_dict.update(c2) return feat_dict, coord_dict, failures def update_with_new_seqs(feat_dict, coord_dict, chain): base_dir = os.path.dirname(os.path.abspath(__file__)) base_dir = base_dir + "/" + self.cache_dir os.makedirs(base_dir, exist_ok=True) path_feat = os.path.join(base_dir, f"{chain}_feat_dict.pt") path_coords = os.path.join(base_dir, f"{chain}_coord_dict.pt") all_seqs_clean = _clean_unique(df_data[chain]) new_seqs = [s for s in all_seqs_clean if s not in feat_dict] if not new_seqs: logger.info(f"No new {chain} sequences found") return feat_dict, coord_dict logger.info(f"Found new {chain} sequences, embedding...") df_new = pd.DataFrame({chain: new_seqs}) new_feat_dict, new_coord_dict, failures = _retry_embed_df(df_new, chain, max_retries=100) feat_dict.update(new_feat_dict) coord_dict.update(new_coord_dict) self.save_obj(feat_dict, path_feat) self.save_obj(coord_dict, path_coords) if failures: for seq, err in failures: logger.error(f"[create] failed: {seq} | {err}") logger.info(f"Updated and saved {path_feat} and {path_coords}") return feat_dict, coord_dict def get_or_create_dict(chain): base_dir = os.path.dirname(os.path.abspath(__file__)) + "/" + self.cache_dir os.makedirs(base_dir, exist_ok=True) path_feat = os.path.join(base_dir, f"{chain}_feat_dict.pt") path_coords = os.path.join(base_dir, f"{chain}_coord_dict.pt") if os.path.exists(path_feat) and not re_embed: logger.info(f"Loading {path_feat} and {path_coords}") feat_dict = self.load_obj(path_feat) coord_dict = self.load_obj(path_coords) else: logger.info(f"{path_feat} and {path_coords} not found or re_embed=True, generating...") unique_seqs = _clean_unique(df_data[chain]) df_uniq = pd.DataFrame({chain: unique_seqs}) feat_dict, coord_dict, failures = _retry_embed_df( df_uniq, chain, show_progress=True, max_retries=100 ) self.save_obj(feat_dict, path_feat) self.save_obj(coord_dict, path_coords) if failures: for seq, err in failures: logger.error(f"[create] failed: {seq} | {err}") logger.info(f"Saved {path_feat} and {path_coords}") return feat_dict, coord_dict self.dict[chain+'_feat'], self.dict[chain+'_coord'] = update_with_new_seqs(*get_or_create_dict(chain), chain) def pad_and_stack(self, batch_feats, L_max, batch_coors): """ batch_feats: list of [L_i, D] tensors batch_coors: list of [L_i, 3] tensors return: feats: [B, L_max, D] coors: [B, L_max, 3] mask : [B, L_max] (True for real tokens) """ assert len(batch_feats) == len(batch_coors) B = len(batch_feats) D = batch_feats[0].shape[-1] feats_pad = [] coors_pad = [] masks = [] 
for x, c in zip(batch_feats, batch_coors): L = x.shape[0] pad_L = L_max - L # pad feats/coors with zeros feats_pad.append(torch.nn.functional.pad(x, (0, 0, 0, pad_L))) # [L_max, D] coors_pad.append(torch.nn.functional.pad(c, (0, 0, 0, pad_L))) # [L_max, 3] m = torch.zeros(L_max, dtype=torch.bool) m[:L] = True masks.append(m) feats = torch.stack(feats_pad, dim=0) # [B, L_max, D] coors = torch.stack(coors_pad, dim=0) # [B, L_max, 3] mask = torch.stack(masks, dim=0) # [B, L_max] return feats, coors, mask def forward(self, df_data, chain, device='cpu', re_embed=False): """ df_data: pd.DataFrame with a column `chain` containing sequences chain: str, e.g. "alpha" or "beta" device: str, e.g. 'cpu' or 'cuda:0' re_embed: bool, whether to re-embed even if cached files exist """ self.dict = {} self.load_esm_dict(device, df_data, chain, re_embed) batch_feats = [] batch_coors = [] for seq in df_data[chain].astype(str).tolist(): ss = sanitize_protein_seq(seq) if ss in self.dict[chain+'_feat'] and ss in self.dict[chain+'_coord']: batch_feats.append(self.dict[chain+'_feat'][ss]) batch_coors.append(self.dict[chain+'_coord'][ss]) else: raise ValueError(f"Sequence not found in embedding dict: {ss}") # L_max = max(x.shape[0] for x in batch_feats) return batch_feats, batch_coors # =================================== Dataset / Collate =========================================== class PepHLA_Dataset(torch.utils.data.Dataset): def __init__(self, df, phys_dict, esm2_dict, struct_dict): self.df = df self.phys_dict = phys_dict self.esm2_dict = esm2_dict self.struct_dict = struct_dict def __len__(self): return len(self.df) def __getitem__(self, idx): row = self.df.iloc[idx] pep = row['peptide'] hla = row['HLA_full'] label = torch.tensor(row['label'], dtype=torch.float32) pep_phys = self.phys_dict['pep'][pep] pep_esm = self.esm2_dict['pep'][pep] hla_phys = self.phys_dict['hla'][hla] hla_esm = self.esm2_dict['hla'][hla] hla_struct, hla_coord = self.struct_dict[hla] return { 'pep_phys': pep_phys, 'pep_esm': pep_esm, 'hla_phys': hla_phys, 'hla_esm': hla_esm, 'hla_struct': hla_struct, 'hla_coord': hla_coord, 'label': label, 'pep_id': pep, 'hla_id': hla, } def peptide_hla_collate_fn(batch): def pad_or_crop(x, original_len, target_len): L, D = x.shape valid_len = min(original_len, target_len) valid_part = x[:valid_len] if valid_len < target_len: pad_len = target_len - valid_len padding = x.new_zeros(pad_len, D) return torch.cat([valid_part, padding], dim=0) else: return valid_part out_batch = {} pep_lens = [len(item['pep_id']) for item in batch] max_pep_len = max(pep_lens) for key in batch[0].keys(): if key == 'label': out_batch[key] = torch.stack([item[key] for item in batch]) elif key.startswith('pep_') and not key.endswith('_id'): out_batch[key] = torch.stack([pad_or_crop(item[key], len(item['pep_id']), max_pep_len) for item in batch]) elif key.endswith('_id'): out_batch[key] = [item[key] for item in batch] else: out_batch[key] = torch.stack([item[key] for item in batch]) def make_mask(lengths, max_len): masks = [] for L in lengths: m = torch.zeros(max_len, dtype=torch.bool) m[:L] = True masks.append(m) return torch.stack(masks) out_batch['pep_mask'] = make_mask(pep_lens, max_pep_len) return out_batch # =================================== Dataset / Collate =========================================== class TCRPepHLA_Dataset(torch.utils.data.Dataset): """ Dataset for TCRα + TCRβ + peptide + HLA binding. 
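
    Expects `df` to provide the columns used in __getitem__:
    'tcra', 'tcrb', 'peptide', 'HLA_full', 'label',
    'cdr3a_start', 'cdr3a_end', 'cdr3b_start', 'cdr3b_end'.
    `phys_dict` and `esm2_dict` map {'tcra', 'tcrb', 'pep', 'hla'} -> {sequence -> [L, D] tensor};
    `struct_dict` maps the same keys -> {sequence -> (structure features, coordinates)}
    (e.g., [L, 17] and [L, 3] as produced by StructureFeatureExtractorNoDSSP);
    `pep_hla_feat_dict` maps (peptide, HLA_full) -> {'pep_feat_pretrain', 'hla_feat_pretrain'}.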
""" def __init__(self, df, phys_dict, esm2_dict, struct_dict, pep_hla_feat_dict): self.df = df self.phys_dict = phys_dict self.esm2_dict = esm2_dict self.struct_dict = struct_dict self.pep_hla_feat_dict = pep_hla_feat_dict def __len__(self): return len(self.df) def __getitem__(self, idx): row = self.df.iloc[idx] tcra = row['tcra'] tcrb = row['tcrb'] pep = row['peptide'] hla = row['HLA_full'] label = torch.tensor(row['label'], dtype=torch.float32) # ---- TCRα ---- tcra_phys = self.phys_dict['tcra'][tcra] tcra_esm = self.esm2_dict['tcra'][tcra] tcra_struct, tcra_coord = self.struct_dict['tcra'][tcra] tcra_cdr3_start = torch.tensor(row['cdr3a_start'], dtype=torch.long) tcra_cdr3_end = torch.tensor(row['cdr3a_end'], dtype=torch.long) # ---- TCRβ ---- tcrb_phys = self.phys_dict['tcrb'][tcrb] tcrb_esm = self.esm2_dict['tcrb'][tcrb] tcrb_struct, tcrb_coord = self.struct_dict['tcrb'][tcrb] tcrb_cdr3_start = torch.tensor(row['cdr3b_start'], dtype=torch.long) tcrb_cdr3_end = torch.tensor(row['cdr3b_end'], dtype=torch.long) # ---- peptide ---- pep_phys = self.phys_dict['pep'][pep] pep_esm = self.esm2_dict['pep'][pep] pep_struct, pep_coord = self.struct_dict['pep'][pep] # ---- HLA ---- hla_phys = self.phys_dict['hla'][hla] hla_esm = self.esm2_dict['hla'][hla] hla_struct, hla_coord = self.struct_dict['hla'][hla] feats = self.pep_hla_feat_dict[(pep, hla)] pep_feat_pretrain = feats['pep_feat_pretrain'] hla_feat_pretrain = feats['hla_feat_pretrain'] return { # TCRα 'tcra_phys': tcra_phys, 'tcra_esm': tcra_esm, 'tcra_struct': tcra_struct, 'tcra_coord': tcra_coord, 'cdr3a_start': tcra_cdr3_start, 'cdr3a_end': tcra_cdr3_end, # TCRβ 'tcrb_phys': tcrb_phys, 'tcrb_esm': tcrb_esm, 'tcrb_struct': tcrb_struct, 'tcrb_coord': tcrb_coord, 'cdr3b_start': tcrb_cdr3_start, 'cdr3b_end': tcrb_cdr3_end, # peptide 'pep_phys': pep_phys, 'pep_esm': pep_esm, 'pep_struct': pep_struct, 'pep_coord': pep_coord, # HLA 'hla_phys': hla_phys, 'hla_esm': hla_esm, 'hla_struct': hla_struct, 'hla_coord': hla_coord, 'tcra_id': tcra, 'tcrb_id': tcrb, 'pep_id': pep, 'hla_id': hla, 'label': label, 'pep_feat_pretrain': pep_feat_pretrain, 'hla_feat_pretrain': hla_feat_pretrain, } # =================================== Collate Function =========================================== def tcr_pep_hla_collate_fn(batch): def pad_or_crop(x, original_len, target_len): L, D = x.shape valid_len = min(original_len, target_len) valid_part = x[:valid_len] if valid_len < target_len: pad_len = target_len - valid_len padding = x.new_zeros(pad_len, D) return torch.cat([valid_part, padding], dim=0) else: return valid_part out_batch = {} tcra_lens = [len(item['tcra_id']) for item in batch] tcrb_lens = [len(item['tcrb_id']) for item in batch] pep_lens = [len(item['pep_id']) for item in batch] max_tcra_len = max(tcra_lens) max_tcrb_len = max(tcrb_lens) max_pep_len = max(pep_lens) for key in batch[0].keys(): if key == 'label': out_batch[key] = torch.stack([item[key] for item in batch]) elif key.startswith('tcra_') and not key.endswith('_id'): out_batch[key] = torch.stack([pad_or_crop(item[key], len(item['tcra_id']), max_tcra_len) for item in batch]) elif key.startswith('tcrb_') and not key.endswith('_id'): out_batch[key] = torch.stack([pad_or_crop(item[key], len(item['tcrb_id']), max_tcrb_len) for item in batch]) elif key.startswith('pep_') and not key.endswith('_id'): out_batch[key] = torch.stack([pad_or_crop(item[key], len(item['pep_id']), max_pep_len) for item in batch]) elif key.endswith('_id'): out_batch[key] = [item[key] for item in batch] else: out_batch[key] = 
torch.stack([item[key] for item in batch]) def make_mask(lengths, max_len): masks = [] for L in lengths: m = torch.zeros(max_len, dtype=torch.bool) m[:L] = True masks.append(m) return torch.stack(masks) out_batch['tcra_mask'] = make_mask(tcra_lens, max_tcra_len) out_batch['tcrb_mask'] = make_mask(tcrb_lens, max_tcrb_len) out_batch['pep_mask'] = make_mask(pep_lens, max_pep_len) return out_batch # ==================================== 小积木:投影 + 门控 ========================================= class ResidueProjector(nn.Module): """把不同分支的通道维度对齐到同一 D""" def __init__(self, in_dim, out_dim): super().__init__() self.proj = nn.Linear(in_dim, out_dim) if in_dim != out_dim else nn.Identity() def forward(self, x): # x: [B,L,Di] return self.proj(x) class ResidueDoubleFusion(nn.Module): """ ResidueDoubleFusion: A residue-level two-branch fusion module that combines two modalities (x1, x2) using cross-attention followed by gated residual fusion and linear projection. Typical usage: - x1: physicochemical features - x2: ESM embeddings (or structure features) """ def __init__(self, dim, num_heads=8, dropout=0.1): super().__init__() self.dim = dim # Cross-attention: allows information flow between two modalities self.cross_attn = nn.MultiheadAttention( embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True ) # Gating mechanism: adaptively weight two modalities per residue self.gate = nn.Sequential( nn.Linear(dim * 2, dim), nn.ReLU(), nn.Linear(dim, 1), nn.Sigmoid() ) # Optional projection after fusion self.out_proj = nn.Linear(dim, dim) # Layer norms for stable training self.norm_x1 = nn.LayerNorm(dim) self.norm_x2 = nn.LayerNorm(dim) self.norm_out = nn.LayerNorm(dim) def forward(self, x1, x2): """ Args: x1: Tensor [B, L, D] - first modality (e.g., physicochemical) x2: Tensor [B, L, D] - second modality (e.g., ESM embeddings) Returns: fused: Tensor [B, L, D] - fused residue-level representation """ # 1) Normalize both branches x1_norm = self.norm_x1(x1) x2_norm = self.norm_x2(x2) # 2) Cross-attention (x1 queries, x2 keys/values) # This allows x1 to attend to x2 at each residue position attn_out, _ = self.cross_attn( query=x1_norm, key=x2_norm, value=x2_norm ) # [B, L, D] # 3) Gating between original x1 and attention-enhanced x2 gate_val = self.gate(torch.cat([x1, attn_out], dim=-1)) # [B, L, 1] fused = gate_val * x1 + (1 - gate_val) * attn_out # 4) Optional projection + normalization fused = self.out_proj(fused) fused = self.norm_out(fused) return fused class ResidueTripleFusion(nn.Module): """ ResidueTripleFusion: A hierarchical three-branch feature fusion module for residue-level representations. Step 1: Fuse physicochemical features and protein language model embeddings. Step 2: Fuse the intermediate representation with structure-based features. Each fusion step uses ResidueDoubleFusion (cross-attention + gating + linear projection). 
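
    Example (sketch):

        fusion = ResidueTripleFusion(dim=256)
        fused = fusion(phys, esm, struct)   # each input and the output are [B, L, 256]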
""" def __init__(self, dim, num_heads=8, dropout=0.1): super().__init__() # Fuse physicochemical + ESM embeddings self.fuse_phys_esm = ResidueDoubleFusion(dim, num_heads=num_heads, dropout=dropout) # Fuse the fused phys+esm representation with structure embeddings self.fuse_f12_struct = ResidueDoubleFusion(dim, num_heads=num_heads, dropout=dropout) def forward(self, phys, esm, struct): """ Args: phys: Tensor [B, L, D], physicochemical features (e.g., AAindex-based) esm: Tensor [B, L, D], protein language model embeddings (e.g., ESM2, ProtT5) struct: Tensor [B, L, D], structure-derived features (e.g., torsion, RSA) Returns: fused: Tensor [B, L, D], final fused representation """ # Step 1: Fuse physicochemical and ESM embeddings f12 = self.fuse_phys_esm(phys, esm) # Step 2: Fuse the intermediate fused representation with structure features fused = self.fuse_f12_struct(f12, struct) return fused class BANLayer(nn.Module): """ Bilinear Attention Network Layer with proper 2D masked-softmax. v_mask: [B, L_v] True=valid q_mask: [B, L_q] True=valid """ def __init__(self, v_dim, q_dim, h_dim, h_out, act='ReLU', dropout=0.2, k=3): super().__init__() self.c = 32 self.k = k self.v_dim = v_dim self.q_dim = q_dim self.h_dim = h_dim self.h_out = h_out self.v_net = FCNet([v_dim, h_dim * self.k], act=act, dropout=dropout) self.q_net = FCNet([q_dim, h_dim * self.k], act=act, dropout=dropout) if 1 < k: self.p_net = nn.AvgPool1d(self.k, stride=self.k) if h_out <= self.c: self.h_mat = nn.Parameter(torch.Tensor(1, h_out, 1, h_dim * self.k).normal_()) self.h_bias = nn.Parameter(torch.Tensor(1, h_out, 1, 1).normal_()) else: self.h_net = weight_norm(nn.Linear(h_dim * self.k, h_out), dim=None) self.bn = nn.BatchNorm1d(h_dim) def attention_pooling(self, v, q, att_map): # att_map: [B, L_v, L_q] logits = torch.einsum('bvk,bvq,bqk->bk', (v, att_map, q)) if 1 < self.k: logits = self.p_net(logits.unsqueeze(1)).squeeze(1) * self.k return logits def _masked_softmax_2d(self, logits, v_mask, q_mask): """ logits: [B, h_out, L_v, L_q] v_mask: [B, L_v] or None q_mask: [B, L_q] or None return: probs [B, h_out, L_v, L_q] (masked entries=0, 在有效的二维子矩阵内归一化) """ B, H, Lv, Lq = logits.shape device = logits.device if v_mask is None: v_mask = torch.ones(B, Lv, dtype=torch.bool, device=device) if q_mask is None: q_mask = torch.ones(B, Lq, dtype=torch.bool, device=device) mask2d = (v_mask[:, :, None] & q_mask[:, None, :]) # [B, Lv, Lq] mask2d = mask2d[:, None, :, :].expand(B, H, Lv, Lq) # [B, H, Lv, Lq] logits = logits.masked_fill(~mask2d, -float('inf')) # 在 Lv*Lq 的联合空间做 softmax flat = logits.view(B, H, -1) # [B, H, Lv*Lq] # 处理极端情况:某些样本可能无有效格子,避免 NaN flat = torch.where(torch.isinf(flat), torch.full_like(flat, -1e9), flat) flat = F.softmax(flat, dim=-1) flat = torch.nan_to_num(flat, nan=0.0) # 安全兜底 probs = flat.view(B, H, Lv, Lq) # 把被 mask 的位置清零(数值稳定 & 便于可视化) probs = probs * mask2d.float() return probs def forward(self, v, q, v_mask=None, q_mask=None, softmax=True): """ v: [B, L_v, Dv], q: [B, L_q, Dq] """ B, L_v, _ = v.size() _, L_q, _ = q.size() v_ = self.v_net(v) # [B, L_v, h_dim*k] q_ = self.q_net(q) # [B, L_q, h_dim*k] if self.h_out <= self.c: att_maps = torch.einsum('xhyk,bvk,bqk->bhvq', (self.h_mat, v_, q_)) + self.h_bias # [B,H,Lv,Lq] else: v_t = v_.transpose(1, 2).unsqueeze(3) # [B, K, Lv, 1] q_t = q_.transpose(1, 2).unsqueeze(2) # [B, K, 1, Lq] d_ = torch.matmul(v_t, q_t) # [B, K, Lv, Lq] att_maps = self.h_net(d_.permute(0, 2, 3, 1)) # [B, Lv, Lq, H] att_maps = att_maps.permute(0, 3, 1, 2) # [B, H, Lv, Lq] if softmax: att_maps = 
self._masked_softmax_2d(att_maps, v_mask, q_mask) else: # 即使不 softmax,也把无效格子清 0,避免泄漏 if v_mask is not None: att_maps = att_maps.masked_fill(~v_mask[:, None, :, None], 0.0) if q_mask is not None: att_maps = att_maps.masked_fill(~q_mask[:, None, None, :], 0.0) # 注意:此时 v_ / q_ 仍是 [B, L, K],与 att_maps 的 [B,H,Lv,Lq] 对齐 logits = self.attention_pooling(v_, q_, att_maps[:, 0, :, :]) for i in range(1, self.h_out): logits = logits + self.attention_pooling(v_, q_, att_maps[:, i, :, :]) logits = self.bn(logits) return logits, att_maps class FCNet(nn.Module): def __init__(self, dims, act='ReLU', dropout=0.2): super(FCNet, self).__init__() layers = [] for i in range(len(dims) - 2): in_dim = dims[i] out_dim = dims[i + 1] if 0 < dropout: layers.append(nn.Dropout(dropout)) layers.append(weight_norm(nn.Linear(in_dim, out_dim), dim=None)) if '' != act: layers.append(getattr(nn, act)()) if 0 < dropout: layers.append(nn.Dropout(dropout)) layers.append(weight_norm(nn.Linear(dims[-2], dims[-1]), dim=None)) if '' != act: layers.append(getattr(nn, act)()) self.main = nn.Sequential(*layers) def forward(self, x): return self.main(x) class StackedEGNN(nn.Module): def __init__(self, dim, layers, update_coors=False, **egnn_kwargs): super().__init__() self.layers = nn.ModuleList([ EGNN(dim=dim, update_coors=update_coors, **egnn_kwargs) for _ in range(layers) ]) def forward(self, feats, coors, mask=None): # feats: [B, L_max, D], coors: [B, L_max, 3], mask: [B, L_max] (bool) for layer in self.layers: feats, coors = layer(feats, coors, mask=mask) return feats, coors class FocalLoss(nn.Module): def __init__(self, alpha=0.5, gamma=2, reduction='mean'): super(FocalLoss, self).__init__() self.alpha = alpha self.gamma = gamma self.reduction = reduction def forward(self, inputs, targets): bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none') p_t = torch.exp(-bce_loss) alpha_weight = self.alpha * targets + (1 - self.alpha) * (1 - targets) loss = alpha_weight * (1 - p_t) ** self.gamma * bce_loss if self.reduction == 'mean': return torch.mean(loss) elif self.reduction == 'sum': return torch.sum(loss) else: return loss # ===================================== 主模型(完全版) =========================================== class PeptideHLABindingPredictor(nn.Module): def __init__( self, phys_dim=20, # 物化编码的输出维度(你定义的 PhysicochemicalEncoder) pep_dim=256, # 统一后的 peptide 通道 hla_dim=256, # 统一后的 HLA 通道 bilinear_dim=256, pseudo_seq_pos=None, # 口袋位点(假定 0-based 且落在 [0,179]) device="cuda:0", loss_fn='bce', alpha=0.5, gamma=2.0, dropout=0.2, pos_weights=None ): super().__init__() self.device = device self.pep_dim = pep_dim self.hla_dim = hla_dim self.bilinear_dim = bilinear_dim self.alpha = alpha self.gamma = gamma self.dropout = dropout if loss_fn == 'bce': self.loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weights]) if pos_weights is not None else None) elif loss_fn == 'focal': self.loss_fn = FocalLoss(alpha=alpha, gamma=gamma) else: raise ValueError(f"Unknown loss function: {loss_fn}") self.se3_model = StackedEGNN( dim=17, layers=3 ) self.max_pep_len = 20 self.max_hla_len = 180 self.pep_pos_embed = nn.Parameter(torch.randn(self.max_pep_len, pep_dim)) self.hla_pos_embed = nn.Parameter(torch.randn(self.max_hla_len, hla_dim)) # —— 分支投影到统一维度(逐残基)—— # peptide 分支(Physicochem -> pep_dim, ESM2(1280) -> pep_dim) self.proj_pep_phys = ResidueProjector(in_dim=phys_dim, out_dim=pep_dim) # 你的 PhysEnc 输出维设成 pep_dim self.proj_pep_esm = ResidueProjector(in_dim=1280, out_dim=pep_dim) # HLA 分支(Physicochem -> hla_dim, ESM2(1280) -> 
hla_dim, Struct(17/或se3_out) -> hla_dim) self.proj_hla_phys = ResidueProjector(in_dim=phys_dim, out_dim=hla_dim) # 你的 PhysEnc 输出维设成 hla_dim self.proj_hla_esm = ResidueProjector(in_dim=1280, out_dim=hla_dim) self.proj_hla_se3 = ResidueProjector(in_dim=17, out_dim=hla_dim) # 让 se3_model 输出维就是 hla_dim # —— 门控融合(逐残基)—— self.gate_pep = ResidueDoubleFusion(pep_dim) # pep_phys × pep_esm self.gate_hla = ResidueTripleFusion(hla_dim) # hla_phys × hla_esm × hla_struct d_model = self.pep_dim n_heads = 8 # 1. 用于 "Peptide 查询 HLA" (pep_q_hla_kv) self.cross_attn_pep_hla = nn.MultiheadAttention( embed_dim=d_model, num_heads=n_heads, dropout=self.dropout, batch_first=True ) self.norm_cross_pep = nn.LayerNorm(d_model) # 2. 用于 "HLA 查询 Peptide" (hla_q_pep_kv) self.cross_attn_hla_pep = nn.MultiheadAttention( embed_dim=d_model, num_heads=n_heads, dropout=self.dropout, batch_first=True ) self.norm_cross_hla = nn.LayerNorm(d_model) # —— 交互模块(Bilinear attention map)—— self.bi_attn = BANLayer(v_dim=pep_dim, q_dim=hla_dim, h_dim=bilinear_dim, h_out=4, k=3) # —— 头部 —— self.head = nn.Sequential( nn.Linear(bilinear_dim, bilinear_dim), nn.ReLU(), nn.Linear(bilinear_dim, 1) ) # —— 口袋位点 —— if pseudo_seq_pos is None: pseudo_seq_pos = [i-2 for i in [7, 9, 24, 45, 59, 62, 63, 66, 67, 69, 70, 73, 74, 76, 77, 80, 81, 84, 95, 97, 99, 114, 116, 118, 143, 147, 150, 152, 156, 158, 159, 163, 167, 171]] self.register_buffer("contact_idx", torch.tensor(pseudo_seq_pos, dtype=torch.long)) # -------------------------------------------- # Transformer Encoders for peptide & HLA # -------------------------------------------- encoder_layer_pep = TransformerEncoderLayer( d_model=pep_dim, # 输入维度 nhead=8, # 注意力头数(可调) dim_feedforward=pep_dim*4, dropout=self.dropout, batch_first=True # 输入形状 [B,L,D] ) self.pep_encoder = TransformerEncoder(encoder_layer_pep, num_layers=2) # 可以调整层数 encoder_layer_hla = TransformerEncoderLayer( d_model=hla_dim, nhead=8, dim_feedforward=hla_dim*4, dropout=self.dropout, batch_first=True ) self.hla_encoder = TransformerEncoder(encoder_layer_hla, num_layers=1) # -------------------------- 工具:把 list of [L,D] pad 成 [B,L_max,D] -------------------------- def _pad_stack(self, tensors, L_max=None): Ls = [t.shape[0] for t in tensors] if L_max is None: L_max = max(Ls) D = tensors[0].shape[-1] B = len(tensors) out = tensors[0].new_zeros((B, L_max, D)) mask = torch.zeros(B, L_max, dtype=torch.bool, device=out.device) for i, t in enumerate(tensors): L = t.shape[0] out[i, :L] = t mask[i, :L] = True return out, mask # [B,L_max,D], [B,L_max] # ----------------------------------- 口袋掩码 -------------------------------------- def _mask_to_pockets(self, hla_feat): """ 从 HLA 特征中只保留 pocket 位点,返回: - hla_pocket: [B, n_pocket, D] - pocket_mask: [B, n_pocket] (全部 True) """ B, L, D = hla_feat.shape # ensure idx in [0, L-1] idx = self.contact_idx.clamp(min=0, max=L-1) # gather pocket features hla_pocket = hla_feat[:, idx, :] # [B, n_pocket, D] return hla_pocket def add_positional_encoding(self, x, pos_embed): """ x: [B, L, D] pos_embed: [L_max, D] """ B, L, D = x.shape # 截取前 L 个位置编码 pe = pos_embed[:L, :].unsqueeze(0).expand(B, -1, -1) # [B, L, D] return x + pe def forward(self, batch): # take batch from DataLoader pep_phys = batch['pep_phys'].to(self.device, non_blocking=True) pep_esm = batch['pep_esm'].to(self.device, non_blocking=True) hla_phys = batch['hla_phys'].to(self.device, non_blocking=True) hla_esm = batch['hla_esm'].to(self.device, non_blocking=True) hla_struct = batch['hla_struct'].to(self.device, non_blocking=True) hla_coord = 
batch['hla_coord'].to(self.device, non_blocking=True) labels = batch['label'].to(self.device) # 1) peptide 物化 + ESM2 → gate 融合 pep_phys = self.proj_pep_phys(pep_phys) pep_esm = self.proj_pep_esm(pep_esm) pep_feat = self.gate_pep(pep_phys, pep_esm) # [B, Lp, D] pep_feat = self.add_positional_encoding(pep_feat, self.pep_pos_embed) pep_feat = self.pep_encoder(pep_feat, src_key_padding_mask=~batch['pep_mask'].to(self.device, non_blocking=True)) # 2) HLA 物化 + ESM2 + 结构 → SE3 → gate 融合 hla_phys = self.proj_hla_phys(hla_phys) hla_esm = self.proj_hla_esm(hla_esm) # hla_struct 是 [B, 180, 17],先过 SE3 hla_se3 = self.se3_model(hla_struct, hla_coord, None)[0] # [B, 180, 17] hla_se3 = self.proj_hla_se3(hla_se3) # →256 hla_feat = self.gate_hla(hla_phys, hla_esm, hla_se3) hla_feat = self.add_positional_encoding(hla_feat, self.hla_pos_embed) hla_feat = self.hla_encoder(hla_feat) # cross attention for pep pep_feat_cross, _ = self.cross_attn_pep_hla( query=pep_feat, key=hla_feat, value=hla_feat, key_padding_mask=None ) # cross attention for hla hla_feat_cross, _ = self.cross_attn_hla_pep( query=hla_feat, key=pep_feat, value=pep_feat, key_padding_mask=~batch['pep_mask'].to(self.device, non_blocking=True) ) pep_feat_updated = self.norm_cross_pep(pep_feat + pep_feat_cross) hla_feat_updated = self.norm_cross_hla(hla_feat + hla_feat_cross) # 3) mask HLA 口袋位点 hla_pocket = self._mask_to_pockets(hla_feat_updated) # 4) bilinear attention fused_vec, attn = self.bi_attn( pep_feat_updated, hla_pocket, v_mask=batch['pep_mask'].to(self.device, non_blocking=True), q_mask=None ) logits = self.head(fused_vec).squeeze(-1) probs = torch.sigmoid(logits).detach().cpu().numpy() binding_loss = self.loss_fn(logits, labels.float()) return probs, binding_loss, attn.detach().cpu().numpy().sum(axis=1), fused_vec.detach().cpu().numpy() # -------------------------- 编码器复用接口(给 TCR-HLA 模型用) -------------------------- def _pad_peptide(self, x, max_len): """Pad peptide feature tensor [1, L, D] to [1, max_len, D].""" B, L, D = x.shape if L < max_len: pad = x.new_zeros(B, max_len - L, D) return torch.cat([x, pad], dim=1) else: return x[:, :max_len, :] @torch.no_grad() def encode_peptide_hla(self, pep_id, pep_phys, pep_esm, hla_phys, hla_esm, hla_struct, hla_coord, max_pep_len): Lp = len(pep_id) pep_phys = self.proj_pep_phys(pep_phys) pep_esm = self.proj_pep_esm(pep_esm) pep_phys = self._pad_peptide(pep_phys, max_pep_len) pep_esm = self._pad_peptide(pep_esm, max_pep_len) device = pep_phys.device pep_mask = torch.zeros(1, max_pep_len, dtype=torch.bool, device=device) pep_mask[0, :Lp] = True pep_feat = self.gate_pep(pep_phys, pep_esm) pep_feat = self.add_positional_encoding(pep_feat, self.pep_pos_embed) pep_feat = self.pep_encoder(pep_feat, src_key_padding_mask=~pep_mask) # 2) hla encoding hla_phys = self.proj_hla_phys(hla_phys) hla_esm = self.proj_hla_esm(hla_esm) hla_se3 = self.se3_model(hla_struct, hla_coord, None)[0] hla_se3 = self.proj_hla_se3(hla_se3) hla_feat = self.gate_hla(hla_phys, hla_esm, hla_se3) hla_feat = self.add_positional_encoding(hla_feat, self.hla_pos_embed) hla_feat = self.hla_encoder(hla_feat) # --- 3a. Peptide (Q) 查询 HLA (K, V) --- pep_feat_cross, _ = self.cross_attn_pep_hla( query=pep_feat, key=hla_feat, value=hla_feat, key_padding_mask=None ) pep_feat_updated = self.norm_cross_pep(pep_feat + pep_feat_cross) # --- 3b. 
HLA (Q) 查询 Peptide (K, V) --- hla_feat_cross, _ = self.cross_attn_hla_pep( query=hla_feat, key=pep_feat, value=pep_feat, key_padding_mask=~pep_mask ) hla_feat_updated = self.norm_cross_hla(hla_feat + hla_feat_cross) return pep_feat_updated, hla_feat_updated class TCRPeptideHLABindingPredictor(nn.Module): def __init__( self, tcr_dim=256, pep_dim=256, hla_dim=256, bilinear_dim=256, loss_fn='bce', alpha=0.5, gamma=2.0, dropout=0.1, device='cuda:0', pos_weights=None ): super().__init__() # TCR α / β position embeddings self.max_tcra_len = 500 self.max_tcrb_len = 500 self.max_pep_len = 20 self.max_hla_len = 180 self.alpha = alpha self.gamma = gamma self.dropout = dropout if loss_fn == 'bce': self.loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weights]) if pos_weights is not None else None) elif loss_fn == 'focal': self.loss_fn = FocalLoss(alpha=alpha, gamma=gamma) else: raise ValueError(f"Unknown loss function: {loss_fn}") self.tcra_pos_embed = nn.Parameter(torch.randn(self.max_tcra_len, tcr_dim)) self.tcrb_pos_embed = nn.Parameter(torch.randn(self.max_tcrb_len, tcr_dim)) self.pep_pos_embed = nn.Parameter(torch.randn(self.max_pep_len, pep_dim)) self.hla_pos_embed = nn.Parameter(torch.randn(self.max_hla_len, hla_dim)) self.device = device self.tcr_dim = tcr_dim self.pep_dim = pep_dim self.hla_dim = hla_dim self.bilinear_dim = bilinear_dim d_model = tcr_dim n_heads = 8 self.cross_attn_tcra_pep = nn.MultiheadAttention(d_model, n_heads, dropout=self.dropout, batch_first=True) self.cross_attn_tcra_hla = nn.MultiheadAttention(d_model, n_heads, dropout=self.dropout, batch_first=True) self.cross_attn_tcrb_pep = nn.MultiheadAttention(d_model, n_heads, dropout=self.dropout, batch_first=True) self.cross_attn_tcrb_hla = nn.MultiheadAttention(d_model, n_heads, dropout=self.dropout, batch_first=True) self.norm_tcra_pep = nn.LayerNorm(d_model) self.norm_tcra_hla = nn.LayerNorm(d_model) self.norm_tcrb_pep = nn.LayerNorm(d_model) self.norm_tcrb_hla = nn.LayerNorm(d_model) # ======================= # TCRα / TCRβ encoders # ======================= def make_tcr_encoder(): proj_phys = ResidueProjector(20, tcr_dim) proj_esm = ResidueProjector(1280, tcr_dim) proj_struct = ResidueProjector(17, tcr_dim) se3 = StackedEGNN(dim=17, layers=1) gate = ResidueTripleFusion(tcr_dim) encoder_layer = TransformerEncoderLayer( d_model=tcr_dim, nhead=8, dim_feedforward=tcr_dim*4, dropout=self.dropout, batch_first=True ) encoder = TransformerEncoder(encoder_layer, num_layers=2) return nn.ModuleDict(dict( proj_phys=proj_phys, proj_esm=proj_esm, proj_struct=proj_struct, se3=se3, gate=gate, encoder=encoder )) self.tcra_enc = make_tcr_encoder() self.tcrb_enc = make_tcr_encoder() # ======================= # Peptide encoder (phys + esm + structure) # ======================= self.proj_pep_phys = ResidueProjector(20, pep_dim) self.proj_pep_esm = ResidueProjector(1280, pep_dim) self.proj_pep_struct = ResidueProjector(17, pep_dim) self.pep_se3 = StackedEGNN(dim=17, layers=1) self.pep_gate = ResidueTripleFusion(pep_dim) pep_encoder_layer = TransformerEncoderLayer( d_model=pep_dim, nhead=8, dim_feedforward=pep_dim*4, dropout=self.dropout, batch_first=True ) self.pep_encoder = TransformerEncoder(pep_encoder_layer, num_layers=2) # ======================= # HLA encoder # ======================= self.proj_hla_phys = ResidueProjector(20, hla_dim) self.proj_hla_esm = ResidueProjector(1280, hla_dim) self.proj_hla_struct = ResidueProjector(17, hla_dim) self.hla_se3 = StackedEGNN(dim=17, layers=1) self.hla_gate = ResidueTripleFusion(hla_dim) 
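        # HLA branch: physicochemical (20-d), ESM2 (1280-d) and structure (17-d)
        # features are each projected to hla_dim; the structure stream is first
        # refined by a 1-layer EGNN over the residue coordinates, and the three
        # streams are merged by ResidueTripleFusion before the single-layer
        # TransformerEncoder built below.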
hla_encoder_layer = TransformerEncoderLayer( d_model=hla_dim, nhead=8, dim_feedforward=hla_dim*4, dropout=self.dropout, batch_first=True ) self.hla_encoder = TransformerEncoder(hla_encoder_layer, num_layers=1) self.pep_gate_2 = ResidueDoubleFusion(pep_dim) self.hla_gate_2 = ResidueDoubleFusion(hla_dim) # ======================= # Bilinear interactions # ======================= self.bi_tcra_pep = BANLayer(tcr_dim, pep_dim, bilinear_dim, h_out=4, k=3) self.bi_tcrb_pep = BANLayer(tcr_dim, pep_dim, bilinear_dim, h_out=4, k=3) self.bi_tcra_hla = BANLayer(tcr_dim, hla_dim, bilinear_dim, h_out=4, k=3) self.bi_tcrb_hla = BANLayer(tcr_dim, hla_dim, bilinear_dim, h_out=4, k=3) # ======================= # Head # ======================= total_fused_dim = bilinear_dim * 4 self.head = nn.Sequential( nn.Linear(total_fused_dim, bilinear_dim), nn.ReLU(), nn.Linear(bilinear_dim, 1) ) def encode_tcr(self, x_phys, x_esm, x_struct, x_coord, x_mask, enc, pos_embed): phys = enc['proj_phys'](x_phys) esm = enc['proj_esm'](x_esm) se3 = enc['se3'](x_struct, x_coord, None)[0] se3 = enc['proj_struct'](se3) feat = enc['gate'](phys, esm, se3) feat = self.add_positional_encoding(feat, pos_embed) feat = enc['encoder'](feat, src_key_padding_mask=~x_mask) return feat def add_positional_encoding(self, x, pos_embed): """ x: [B, L, D] pos_embed: [L_max, D] """ B, L, D = x.shape pe = pos_embed[:L, :].unsqueeze(0).expand(B, -1, -1) return x + pe # def _extract_cdr3_segment(self, tcr_feat, cdr3_start, cdr3_end): # B, L, D = tcr_feat.shape # device = tcr_feat.device # max_len = (cdr3_end - cdr3_start + 1).max().item() # # [max_len], 0..max_len-1 # rel_idx = torch.arange(max_len, device=device).unsqueeze(0).expand(B, -1) # [B, max_len] # # absolute index = start + rel_idx # abs_idx = cdr3_start.unsqueeze(1) + rel_idx # # clamp end # abs_idx = abs_idx.clamp(0, L-1) # # mask positions beyond end # mask = rel_idx <= (cdr3_end - cdr3_start).unsqueeze(1) # # gather # # expand abs_idx to [B, max_len, D] for gather # gather_idx = abs_idx.unsqueeze(-1).expand(-1, -1, D) # out = torch.gather(tcr_feat, 1, gather_idx) # [B, max_len, D] # return out, mask def _extract_cdr3_segment(self, tcr_feat, cdr3_start, cdr3_end): """ Extracts CDR3 embeddings and corresponding mask. 
tcr_feat: [B, L, D] cdr3_start, cdr3_end: [B] Returns: out: [B, max_len, D] mask: [B, max_len] (True = valid) """ B, L, D = tcr_feat.shape device = tcr_feat.device # 每个样本的 cdr3 长度 lens = (cdr3_end - cdr3_start).clamp(min=0) max_len = lens.max().item() rel_idx = torch.arange(max_len, device=device).unsqueeze(0).expand(B, -1) # [B, max_len] abs_idx = cdr3_start.unsqueeze(1) + rel_idx # [B, max_len] # mask: True 表示有效 mask = rel_idx < lens.unsqueeze(1) # 注意这里 "<" 就够了 # 将超出范围的索引设为 0(任意有效索引都行,因为会被mask掉) abs_idx = torch.where(mask, abs_idx, torch.zeros_like(abs_idx)) # gather gather_idx = abs_idx.unsqueeze(-1).expand(-1, -1, D) out = torch.gather(tcr_feat, 1, gather_idx) # 对 mask 为 False 的位置强制置零,避免无效 token 参与计算 out = out * mask.unsqueeze(-1) return out, mask def forward(self, batch): # TCRα / TCRβ tcra_feat = self.encode_tcr( batch['tcra_phys'].to(self.device, non_blocking=True), batch['tcra_esm'].to(self.device, non_blocking=True), batch['tcra_struct'].to(self.device, non_blocking=True), batch['tcra_coord'].to(self.device, non_blocking=True), batch['tcra_mask'].to(self.device, non_blocking=True), self.tcra_enc, self.tcra_pos_embed ) tcrb_feat = self.encode_tcr( batch['tcrb_phys'].to(self.device, non_blocking=True), batch['tcrb_esm'].to(self.device, non_blocking=True), batch['tcrb_struct'].to(self.device, non_blocking=True), batch['tcrb_coord'].to(self.device, non_blocking=True), batch['tcrb_mask'].to(self.device, non_blocking=True), self.tcrb_enc, self.tcrb_pos_embed ) # peptide pep_phys = self.proj_pep_phys(batch['pep_phys'].to(self.device, non_blocking=True)) pep_esm = self.proj_pep_esm(batch['pep_esm'].to(self.device, non_blocking=True)) pep_se3 = self.pep_se3(batch['pep_struct'].to(self.device, non_blocking=True), batch['pep_coord'].to(self.device, non_blocking=True), None)[0] pep_se3 = self.proj_pep_struct(pep_se3) pep_feat = self.pep_gate(pep_phys, pep_esm, pep_se3) pep_feat = self.add_positional_encoding(pep_feat, self.pep_pos_embed) pep_feat = self.pep_encoder( pep_feat, src_key_padding_mask=~batch['pep_mask'].to(self.device) ) # hla hla_phys = self.proj_hla_phys(batch['hla_phys'].to(self.device, non_blocking=True)) hla_esm = self.proj_hla_esm(batch['hla_esm'].to(self.device, non_blocking=True)) hla_se3 = self.hla_se3(batch['hla_struct'].to(self.device, non_blocking=True), batch['hla_coord'].to(self.device, non_blocking=True), None)[0] hla_se3 = self.proj_hla_struct(hla_se3) hla_feat = self.hla_gate(hla_phys, hla_esm, hla_se3) hla_feat = self.add_positional_encoding(hla_feat, self.hla_pos_embed) hla_feat = self.hla_encoder(hla_feat) if ('pep_feat_pretrain' in batch) and ('hla_feat_pretrain' in batch): pep_pretrain = batch['pep_feat_pretrain'].to(self.device, non_blocking=True) hla_pretrain = batch['hla_feat_pretrain'].to(self.device, non_blocking=True) # ---- 鲁棒的长度对齐 (裁剪到最小长度) ---- Lp = pep_feat.shape[1] Lp_pretrain = pep_pretrain.shape[1] if Lp != Lp_pretrain: Lp_min = min(Lp, Lp_pretrain) pep_feat = pep_feat[:, :Lp_min, :] pep_pretrain = pep_pretrain[:, :Lp_min, :] Lh = hla_feat.shape[1] Lh_pretrain = hla_pretrain.shape[1] if Lh != Lh_pretrain: Lh_min = min(Lh, Lh_pretrain) hla_feat = hla_feat[:, :Lh_min, :] hla_pretrain = hla_pretrain[:, :Lh_min, :] # ---- Peptide gating ---- pep_feat = self.pep_gate_2(pep_feat, pep_pretrain) # ---- HLA gating ---- hla_feat = self.hla_gate_2(hla_feat, hla_pretrain) # TCRα CDR3 segment tcra_cdr3, cdr3a_mask = self._extract_cdr3_segment( tcra_feat, batch['cdr3a_start'].to(self.device, non_blocking=True), batch['cdr3a_end'].to(self.device, 
non_blocking=True) ) # TCRβ CDR3 segment tcrb_cdr3, cdr3b_mask = self._extract_cdr3_segment( tcrb_feat, batch['cdr3b_start'].to(self.device, non_blocking=True), batch['cdr3b_end'].to(self.device, non_blocking=True) ) # TCRα CDR3 ← Peptide tcra_cdr3_cross, _ = self.cross_attn_tcra_pep( query=tcra_cdr3, # [B, La_cdr3, D] key=pep_feat, value=pep_feat, # [B, Lp, D] key_padding_mask=~batch['pep_mask'].to(self.device) ) tcra_cdr3 = self.norm_tcra_pep(tcra_cdr3 + tcra_cdr3_cross) # 重新掩蔽 padding 的 CDR3 位置,防止无效 token 漏光 tcra_cdr3 = tcra_cdr3 * cdr3a_mask.unsqueeze(-1) # TCRβ CDR3 ← Peptide tcrb_cdr3_cross, _ = self.cross_attn_tcrb_pep( query=tcrb_cdr3, key=pep_feat, value=pep_feat, key_padding_mask=~batch['pep_mask'].to(self.device) ) tcrb_cdr3 = self.norm_tcrb_pep(tcrb_cdr3 + tcrb_cdr3_cross) tcrb_cdr3 = tcrb_cdr3 * cdr3b_mask.unsqueeze(-1) # ------------------ Cross-Attn:TCR 全序列 ↔ HLA(整条 TCR) ------------------ # TCRα full ← HLA tcra_hla_cross, _ = self.cross_attn_tcra_hla( query=tcra_feat, # [B, La, D] key=hla_feat, value=hla_feat, # [B, Lh, D] key_padding_mask=None ) tcra_feat = self.norm_tcra_hla(tcra_feat + tcra_hla_cross) tcra_feat = tcra_feat * batch['tcra_mask'].to(self.device).unsqueeze(-1) # TCRβ full ← HLA tcrb_hla_cross, _ = self.cross_attn_tcrb_hla( query=tcrb_feat, key=hla_feat, value=hla_feat, key_padding_mask=None ) tcrb_feat = self.norm_tcrb_hla(tcrb_feat + tcrb_hla_cross) tcrb_feat = tcrb_feat * batch['tcrb_mask'].to(self.device).unsqueeze(-1) # bilinear fusion vec_tcra_pep, attn_tcra_pep = self.bi_tcra_pep(tcra_cdr3, pep_feat, v_mask=cdr3a_mask, q_mask=batch['pep_mask'].to(self.device)) vec_tcrb_pep, attn_tcrb_pep = self.bi_tcrb_pep(tcrb_cdr3, pep_feat, v_mask=cdr3b_mask, q_mask=batch['pep_mask'].to(self.device)) vec_tcra_hla, attn_tcra_hla = self.bi_tcra_hla(tcra_feat, hla_feat, v_mask=batch['tcra_mask'].to(self.device), q_mask=None) vec_tcrb_hla, attn_tcrb_hla = self.bi_tcrb_hla(tcrb_feat, hla_feat, v_mask=batch['tcrb_mask'].to(self.device), q_mask=None) attn_tcra_pep_small = attn_tcra_pep.sum(dim=1).float() attn_tcrb_pep_small = attn_tcrb_pep.sum(dim=1).float() attn_tcra_hla_small = attn_tcra_hla.sum(dim=1).float() attn_tcrb_hla_small = attn_tcrb_hla.sum(dim=1).float() attn_dict = { 'attn_tcra_pep': attn_tcra_pep_small.detach().cpu().numpy(), 'attn_tcrb_pep': attn_tcrb_pep_small.detach().cpu().numpy(), 'attn_tcra_hla': attn_tcra_hla_small.detach().cpu().numpy(), 'attn_tcrb_hla': attn_tcrb_hla_small.detach().cpu().numpy() } fused = torch.cat([vec_tcra_pep, vec_tcrb_pep, vec_tcra_hla, vec_tcrb_hla], dim=-1) logits = self.head(fused).squeeze(-1) labels = batch['label'].to(self.device) loss_binding = self.loss_fn(logits, labels.float()) probs = torch.sigmoid(logits) return probs, loss_binding, pep_feat.detach().cpu().numpy(), attn_dict
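

# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative only): exercises the masked bilinear
# attention in BANLayer and the FocalLoss on random tensors. Shapes mirror how
# the predictors above call these modules; no pretrained encoders, structures,
# or real data are loaded.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    B, Lv, Lq, D = 2, 12, 34, 256
    ban = BANLayer(v_dim=D, q_dim=D, h_dim=256, h_out=4, k=3)
    v = torch.randn(B, Lv, D)            # e.g. TCR residues
    q = torch.randn(B, Lq, D)            # e.g. HLA residues
    v_mask = torch.zeros(B, Lv, dtype=torch.bool)
    v_mask[:, :10] = True                # last two v positions are padding
    q_mask = torch.ones(B, Lq, dtype=torch.bool)
    fused, att = ban(v, q, v_mask=v_mask, q_mask=q_mask)
    print("BANLayer fused:", fused.shape, "attention:", att.shape)  # [2, 256], [2, 4, 12, 34]

    logits = torch.randn(B)
    labels = torch.tensor([1.0, 0.0])
    print("FocalLoss:", FocalLoss(alpha=0.5, gamma=2)(logits, labels).item())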