# StriMap / src / model.py
import csv
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict, List, Tuple
from torch.nn.utils.parametrizations import weight_norm
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import esm
import pandas as pd
from tqdm import tqdm
import tempfile
from pathlib import Path
import mdtraj as md
# import io
# import gzip
import os
from egnn_pytorch import EGNN
from transformers import AutoTokenizer, EsmForProteinFolding
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# from re import search as re_search
import re
def determine_tcr_seq_vj(cdr3,V,J,chain,guess01=False):
def file2dict(filename,key_fields,store_fields,delimiter='\t'):
"""Read file to a dictionary.
key_fields: fields to be used as keys
store_fields: fields to be saved as a list
delimiter: delimiter used in the given file."""
dictionary={}
with open(filename, newline='') as csvfile:
reader = csv.DictReader(csvfile,delimiter=delimiter)
for row in reader:
keys = [row[k] for k in key_fields]
store= [row[s] for s in store_fields]
sub_dict = dictionary
for key in keys[:-1]:
if key not in sub_dict:
sub_dict[key] = {}
sub_dict = sub_dict[key]
key = keys[-1]
if key not in sub_dict:
sub_dict[key] = []
sub_dict[key].append(store)
return dictionary
def get_protseqs_ntseqs(chain='B'):
"""returns sequence dictioaries for genes: protseqsV, protseqsJ, nucseqsV, nucseqsJ"""
seq_dicts=[]
for gene,type in zip(['v','j','v','j'],['aa','aa','nt','nt']):
name = 'library/'+'tr'+chain.lower()+gene+'s_'+type+'.tsv'
sdict = file2dict(name,key_fields=['Allele'],store_fields=[type+'_seq'])
for g in sdict:
sdict[g]=sdict[g][0][0]
seq_dicts.append(sdict)
return seq_dicts
protVb,protJb,_,_ = get_protseqs_ntseqs(chain='B')
protVa,protJa,_,_ = get_protseqs_ntseqs(chain='A')
def splice_v_cdr3_j(pv: str, pj: str, cdr3: str) -> str:
"""
pv: V gene protein sequence
pj: J gene protein sequence
cdr3: C-starting, F/W-ending CDR3 sequence (protein)
Returns: The spliced full sequence (V[:lastC] + CDR3 + J suffix)
"""
pv = (pv or "").strip().upper()
pj = (pj or "").strip().upper()
cdr3 = (cdr3 or "").strip().upper()
# 1) V segment: Take the last 'C' (including the conserved C in V region)
cpos = pv.rfind('C')
if cpos == -1:
raise ValueError("V sequence has no 'C' to anchor CDR3 start.")
v_prefix = pv[:cpos] # up to, but not including, the conserved C (the CDR3 supplies it)
# 2) Align CDR3's "end overlap" in J
# Start from the full length of cdr3, gradually shorten it, and find the longest suffix that can match in J
j_suffix = pj # fallback (in extreme cases)
for k in range(len(cdr3), 0, -1):
tail = cdr3[-k:] # CDR3's suffix
m = re.search(re.escape(tail), pj)
if m:
j_suffix = pj[m.end():] # Take the suffix from the matching segment
break
return v_prefix + cdr3 + j_suffix
tcr_list = []
for i in range(len(cdr3)):
cdr3_ = cdr3[i]
V_ = V[i]
J_ = J[i]
if chain=='A':
protseqsV = protVa
protseqsJ = protJa
else:
protseqsV = protVb
protseqsJ = protJb
if guess01:
if '*' not in V_:
V_+='*01'
if '*' not in J_:
J_+='*01'
pv = protseqsV[V_]
pj = protseqsJ[J_]
# t = pv[:pv.rfind('C')]+ cdr3_ + pj[re_search(r'[FW]G.[GV]',pj).start()+1:]
t = splice_v_cdr3_j(pv, pj, cdr3_)
tcr_list.append(t)
return tcr_list
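# Usage sketch (hypothetical gene names; assumes the library/tr{a,b}{v,j}s_{aa,nt}.tsv tables are present):
#   full_seqs = determine_tcr_seq_vj(
#       cdr3=['CASSLAPGATNEKLFF'], V=['TRBV7-9'], J=['TRBJ1-4'],
#       chain='B', guess01=True,   # guess01 appends '*01' when no allele is specified
#   )
#   # full_seqs[0] is the V-region prefix + CDR3 + J-region suffix built by splice_v_cdr3_j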
# def negative_sampling_phla(df, neg_ratio=5, label_col='label', neg_label=0, random_state=42):
# """
# Create negative samples by shuffling the TCR sequences while keeping the peptide-HLA pairs intact.
# Ensures that the generated (TCR, peptide, HLA) triplets do not exist in the original dataset.
# """
# negative_samples = []
# # Set of positive-sample triplets
# pos_triplets = set(zip(
# df['tcra'], df['tcrb'], df['peptide'], df['HLA_full']
# ))
# for i in range(neg_ratio):
# shuffled_df = df.copy()
# tcr_cols = ['tcra', 'cdr3a_start', 'cdr3a_end', 'tcrb', 'cdr3b_start', 'cdr3b_end']
# shuffled_tcr = df[tcr_cols].sample(frac=1, random_state=random_state + i).reset_index(drop=True)
# for col in tcr_cols:
# shuffled_df[col] = shuffled_tcr[col]
# # Drop rows where: 1) the TCR was not actually changed, or 2) the triplet duplicates a positive sample
# mask_keep = []
# for idx, row in shuffled_df.iterrows():
# triplet = (row['tcra'], row['tcrb'], row['peptide'], row['HLA_full'])
# if triplet in pos_triplets:
# mask_keep.append(False)
# else:
# mask_keep.append(True)
# shuffled_df = shuffled_df[mask_keep]
# shuffled_df[label_col] = neg_label
# negative_samples.append(shuffled_df)
# negative_samples = pd.concat(negative_samples, ignore_index=True).drop_duplicates()
# return negative_samples
# def balanced_negative_sampling_phla(df, label_col='label', neg_label=0, random_state=42):
# """
# Balanced negative sampling per (peptide, HLA_full):
# - Find the peptide with the most positive samples
# - For that peptide, draw negatives at a 1:1 ratio by sampling TCRs from other peptides (keeping the peptide-HLA pairing)
# - For the remaining peptides, sample negatives so that every peptide ends up with the same total sample count
# - Always preserve the pairing between peptide and HLA_full
# """
# np.random.seed(random_state)
# pos_df = df[df[label_col] != neg_label].copy()
# pos_counts = pos_df['peptide'].value_counts()
# max_peptide = pos_counts.idxmax()
# max_pos = pos_counts.max()
# total_target = max_pos * 2 # final sample count per peptide (positives + negatives)
# neg_samples = []
# # For max_peptide: negatives at a 1:1 ratio
# df_other_tcrs = pos_df[pos_df['peptide'] != max_peptide][['tcra', 'tcrb', 'cdr3a_start', 'cdr3a_end', 'cdr3b_start', 'cdr3b_end']].copy()
# neg_max = pos_df[pos_df['peptide'] == max_peptide].copy()
# sampled_tcrs = df_other_tcrs.sample(
# n=max_pos,
# replace=True if len(df_other_tcrs) < max_pos else False,
# random_state=random_state
# ).reset_index(drop=True)
# neg_max.update(sampled_tcrs)
# neg_max[label_col] = neg_label
# neg_samples.append(neg_max)
# # For the remaining peptides
# for pep, n_pos in pos_counts.items():
# if pep == max_peptide:
# continue
# n_neg = max(0, total_target - n_pos)
# df_other_tcrs = pos_df[pos_df['peptide'] != pep][['tcra', 'tcrb', 'cdr3a_start', 'cdr3a_end', 'cdr3b_start', 'cdr3b_end']].copy()
# neg_pep = pos_df[pos_df['peptide'] == pep].copy()
# sampled_tcrs = df_other_tcrs.sample(
# n=min(len(df_other_tcrs), n_neg),
# replace=True if len(df_other_tcrs) < n_neg else False,
# random_state=random_state
# ).reset_index(drop=True)
# sampled_tcrs = sampled_tcrs.iloc[:len(neg_pep)].copy() if len(sampled_tcrs) > len(neg_pep) else sampled_tcrs
# neg_pep = pd.concat(
# [neg_pep]*int(np.ceil(n_neg / len(neg_pep))), ignore_index=True
# ).iloc[:n_neg]
# neg_pep.update(sampled_tcrs)
# neg_pep[label_col] = neg_label
# neg_samples.append(neg_pep)
# neg_df = pd.concat(neg_samples, ignore_index=True)
# final_df = pd.concat([pos_df, neg_df], ignore_index=True).reset_index(drop=True)
# return final_df
def negative_sampling_phla(df, neg_ratio=5, label_col='label', neg_label=0, random_state=42):
"""
Create negative samples by shuffling TCRs while keeping peptide–HLA pairs intact.
Ensures negative samples count = neg_ratio × positive samples count.
"""
np.random.seed(random_state)
pos_triplets = set(zip(df['tcra'], df['tcrb'], df['peptide'], df['HLA_full']))
tcr_cols = ['tcra', 'cdr3a_start', 'cdr3a_end', 'tcrb', 'cdr3b_start', 'cdr3b_end']
n_pos = len(df)
target_n_neg = n_pos * neg_ratio
all_neg = []
i = 0
while sum(len(x) for x in all_neg) < target_n_neg:
shuffled_df = df.copy()
shuffled_tcr = df[tcr_cols].sample(frac=1, random_state=random_state + i).reset_index(drop=True)
for col in tcr_cols:
shuffled_df[col] = shuffled_tcr[col]
mask_keep = []
for idx, row in shuffled_df.iterrows():
triplet = (row['tcra'], row['tcrb'], row['peptide'], row['HLA_full'])
mask_keep.append(triplet not in pos_triplets)
shuffled_df = shuffled_df[mask_keep]
shuffled_df[label_col] = neg_label
all_neg.append(shuffled_df)
i += 1
if len(pd.concat(all_neg)) > target_n_neg * 1.5:
break
negative_samples = pd.concat(all_neg, ignore_index=True).drop_duplicates()
negative_samples = negative_samples.sample(
n=min(len(negative_samples), target_n_neg), random_state=random_state
).reset_index(drop=True)
return negative_samples
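# Sketch: build a 5x negative set from a positives-only DataFrame `df_pos` that already has the columns
# used above (tcra, tcrb, cdr3*_start/end, peptide, HLA_full, label), then combine and shuffle:
#   neg_df = negative_sampling_phla(df_pos, neg_ratio=5)
#   train_df = pd.concat([df_pos, neg_df], ignore_index=True).sample(frac=1, random_state=42)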
# def negative_sampling_tcr(df, neg_ratio=5, label_col='label', neg_label=0, random_state=42):
# """
# Create negative samples by keeping TCR fixed but assigning random (peptide, HLA_full)
# pairs that do not exist in the original dataset.
# Ensures that the generated (TCR, peptide, HLA) triplets do not exist in the original data.
# """
# np.random.seed(random_state)
# negative_samples = []
# pos_triplets = set(zip(df['tcra'], df['tcrb'], df['peptide'], df['HLA_full']))
# all_pairs = list(set(zip(df['peptide'], df['HLA_full'])))
# for i in range(neg_ratio):
# neg_df = df.copy()
# # Randomly reassign peptide-HLA pairs, but never pick the original pairing
# new_pairs = []
# for _, row in df.iterrows():
# while True:
# pep, hla = all_pairs[np.random.randint(len(all_pairs))]
# triplet = (row['tcra'], row['tcrb'], pep, hla)
# if triplet not in pos_triplets:
# new_pairs.append((pep, hla))
# break
# neg_df[['peptide', 'HLA_full']] = pd.DataFrame(new_pairs, index=neg_df.index)
# neg_df[label_col] = neg_label
# negative_samples.append(neg_df)
# negative_samples = pd.concat(negative_samples, ignore_index=True).drop_duplicates()
# return negative_samples
class EarlyStopping:
def __init__(self, patience=10, verbose=True, delta=0.0, save_path='checkpoint.pt'):
"""
Early stopping based on both val_loss and val_auc.
The model is saved whenever EITHER:
- val_loss decreases by more than delta, OR
- val_auc increases by more than delta.
"""
self.patience = patience
self.verbose = verbose
self.counter = 0
self.early_stop = False
self.delta = delta
self.save_path = save_path
self.best_loss = np.inf
self.best_auc = -np.inf
def __call__(self, val_auc, model):
improved = False
# Check auc improvement
if val_auc > self.best_auc + self.delta:
self.best_auc = val_auc
improved = True
if improved:
self.save_checkpoint(model, val_auc)
self.counter = 0
else:
self.counter += 1
if self.verbose:
print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
if self.counter >= self.patience:
self.early_stop = True
def save_checkpoint(self, model, val_auc):
"""Save current best model."""
if self.verbose:
print(f"Validation improved → Saving model (Score={val_auc:.4f}) to {self.save_path}")
torch.save(model.state_dict(), self.save_path)
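# Usage sketch for a training loop (train_one_epoch / evaluate and the loaders are placeholders):
#   stopper = EarlyStopping(patience=10, save_path='checkpoint.pt')
#   for epoch in range(max_epochs):
#       train_one_epoch(model, train_loader)
#       val_auc = evaluate(model, val_loader)
#       stopper(val_auc, model)          # saves a checkpoint whenever val_auc improves
#       if stopper.early_stop:
#           break
#   model.load_state_dict(torch.load('checkpoint.pt'))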
# ============================================================================
# ESM2 Embedding via HuggingFace
# ============================================================================
class ESM2Encoder(nn.Module):
def __init__(self,
device="cuda:0",
layer=33,
cache_dir='/data/cache'):
"""
Initialize an ESM2 encoder.
Args:
model_name (str): Name of the pretrained ESM2 model (e.g., 'esm2_t33_650M_UR50D').
device (str): Device to run on, e.g. 'cuda:0', 'cuda:1', or 'cpu'.
layer (int): Layer number from which to extract representations.
"""
super().__init__()
self.device = device
self.layer = layer
if cache_dir is None:
cache_dir = os.path.dirname(os.path.abspath(__file__))
self.cache_dir = cache_dir
os.makedirs(self.cache_dir, exist_ok=True)
self.model, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D()
self.batch_converter = self.alphabet.get_batch_converter()
self.model = self.model.eval().to(device)
def _cache_path(self, prefix):
base_dir = os.path.dirname(os.path.abspath(__file__))
base_dir = base_dir + "/" + self.cache_dir
os.makedirs(base_dir, exist_ok=True)
return os.path.join(base_dir, f"{prefix}_esm2_layer{self.layer}.pt")
def save_obj(self, obj, path):
"""Save object to a file (no compression)."""
torch.save(obj, path)
def load_obj(self, path):
"""Load object from a file (no compression)."""
return torch.load(path, map_location="cpu", weights_only=False)
@torch.no_grad()
def _embed_batch(self, batch_data):
batch_labels, batch_strs, batch_tokens = self.batch_converter(batch_data)
batch_tokens = batch_tokens.to(self.device)
results = self.model(batch_tokens, repr_layers=[self.layer], return_contacts=False)
token_representations = results["representations"][self.layer]
batch_lens = (batch_tokens != self.alphabet.padding_idx).sum(1)
seq_reprs = []
for i, tokens_len in enumerate(batch_lens):
seq_repr = token_representations[i, 1:tokens_len-1].cpu()
seq_reprs.append(seq_repr)
return seq_reprs
@torch.no_grad()
def forward(self, df, seq_col, prefix, batch_size=64, re_embed=False, cache_save=True):
"""
Add or update embeddings for sequences in a DataFrame.
- If there are new sequences, automatically update the dictionary and save.
- If re_embed=True, force re-computation of all sequences.
"""
cache_path = self._cache_path(prefix)
emb_dict = {}
if os.path.exists(cache_path) and not re_embed:
print(f"[ESM2] Loading cached embeddings from {cache_path}")
emb_dict = self.load_obj(cache_path)
else:
if re_embed:
print(f"[ESM2] Re-embedding all sequences for {prefix}")
else:
print(f"[ESM2] No existing cache for {prefix}, will create new.")
seqs = [str(s).strip().upper() for s in df[seq_col].tolist() if isinstance(s, str)]
unique_seqs = sorted(set(seqs))
new_seqs = [s for s in unique_seqs if s not in emb_dict]
if new_seqs:
print(f"[ESM2] Found {len(new_seqs)} new sequences → computing embeddings...")
data = [(str(i), s) for i, s in enumerate(new_seqs)]
for i in tqdm(range(0, len(data), batch_size), desc=f"ESM2 update ({prefix})"):
batch = data[i:i+batch_size]
embs = self._embed_batch(batch)
for (_, seq), emb in zip(batch, embs):
emb_dict[seq] = emb.clone()
if cache_save:
print(f"[ESM2] Updating cache with new sequences")
self.save_obj(emb_dict, cache_path)
else:
print(f"[ESM2] No new sequences for {prefix}, using existing cache")
return emb_dict
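# Usage sketch (assumes a 'peptide' column; embeddings are cached per unique upper-cased sequence):
#   esm2 = ESM2Encoder(device='cuda:0', cache_dir='cache')
#   pep_esm_dict = esm2(df, seq_col='peptide', prefix='pep', batch_size=64)
#   emb = pep_esm_dict[df['peptide'].iloc[0].upper()]   # [L, 1280] per-residue representation from layer 33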
# ============================================================================
# ESMFold (transformers)
# ============================================================================
class ESMFoldPredictorHF(nn.Module):
def __init__(self,
model_name="facebook/esmfold_v1",
cache_dir=None,
device='cpu',
allow_tf32=True):
super().__init__()
self.model_name = model_name
self.cache_dir = cache_dir
self.device = device
if allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# tokenizer and model
print(f"Loading ESMFold model {model_name} on {device}... {'with' if cache_dir else 'without'} cache_dir: {cache_dir}")
self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
self.model = EsmForProteinFolding.from_pretrained(
model_name, low_cpu_mem_usage=True, cache_dir=cache_dir
).eval().to(self.device)
@torch.no_grad()
def infer_pdb_str(self, seq: str) -> str:
pdb_str = self.model.infer_pdb(seq)
return pdb_str
@torch.no_grad()
def forward_raw(self, seq: str):
inputs = self.tokenizer([seq], return_tensors="pt", add_special_tokens=False)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
outputs = self.model(**inputs)
return outputs # ESMFoldOutput
MAX_ASA_TIEN = {
"ALA": 129.0, "ARG": 274.0, "ASN": 195.0, "ASP": 193.0, "CYS": 167.0,
"GLN": 225.0, "GLU": 223.0, "GLY": 104.0, "HIS": 224.0, "ILE": 197.0,
"LEU": 201.0, "LYS": 236.0, "MET": 224.0, "PHE": 240.0, "PRO": 159.0,
"SER": 155.0, "THR": 172.0, "TRP": 285.0, "TYR": 263.0, "VAL": 174.0,
}
SS8_INDEX = {"H":0,"B":1,"E":2,"G":3,"I":4,"T":5,"S":6,"C":7,"-":7}
class StructureFeatureExtractorNoDSSP(nn.Module):
def __init__(self, device="cpu"):
super().__init__()
self.device = device
self.in_dim = 6 + 8 + 1 + 1 + 1 # 17
self.to(torch.device(self.device))
@torch.no_grad()
def _angles(self, traj):
L = traj.n_residues
sphi = np.zeros(L, dtype=np.float32); cphi = np.zeros(L, dtype=np.float32)
spsi = np.zeros(L, dtype=np.float32); cpsi = np.zeros(L, dtype=np.float32)
someg = np.zeros(L, dtype=np.float32); comeg = np.zeros(L, dtype=np.float32)
# 1) phi: (C_{i-1}, N_i, CA_i, C_i); locate the current residue i via atoms[1] (N_i)
phi_idx, phi_vals = md.compute_phi(traj) # phi_vals: (1, n_phi)
if phi_vals.size > 0:
for k, atoms in enumerate(phi_idx):
res_i = traj.topology.atom(int(atoms[1])).residue.index # residue containing N_i
if 0 <= res_i < L:
ang = float(phi_vals[0, k])
sphi[res_i] = np.sin(ang); cphi[res_i] = np.cos(ang)
# 2) psi: (N_i, CA_i, C_i, N_{i+1}); locate the current residue i via atoms[1] (CA_i)
psi_idx, psi_vals = md.compute_psi(traj)
if psi_vals.size > 0:
for k, atoms in enumerate(psi_idx):
res_i = traj.topology.atom(int(atoms[1])).residue.index # CA_i
if 0 <= res_i < L:
ang = float(psi_vals[0, k])
spsi[res_i] = np.sin(ang); cpsi[res_i] = np.cos(ang)
# 3) omega: (CA_i, C_i, N_{i+1}, CA_{i+1}); locate the current residue i via atoms[0] (CA_i)
omg_idx, omg_vals = md.compute_omega(traj)
if omg_vals.size > 0:
for k, atoms in enumerate(omg_idx):
res_i = traj.topology.atom(int(atoms[0])).residue.index # CA_i
if 0 <= res_i < L:
ang = float(omg_vals[0, k])
someg[res_i] = np.sin(ang); comeg[res_i] = np.cos(ang)
angles_feat = np.stack([sphi, cphi, spsi, cpsi, someg, comeg], axis=-1) # [L, 6]
return angles_feat.astype(np.float32)
@torch.no_grad()
def _ss8(self, traj: md.Trajectory):
ss = md.compute_dssp(traj, simplified=False)[0]
L = traj.n_residues
onehot = np.zeros((L, 8), dtype=np.float32)
for i, ch in enumerate(ss):
onehot[i, SS8_INDEX.get(ch, 7)] = 1.0
return onehot
@torch.no_grad()
def _rsa(self, traj: md.Trajectory):
asa = md.shrake_rupley(traj, mode="residue")[0] * 100.0 # (L,) converted from nm^2 to Angstrom^2 to match MAX_ASA_TIEN
rsa = np.zeros_like(asa, dtype=np.float32)
for i, res in enumerate(traj.topology.residues):
max_asa = MAX_ASA_TIEN.get(res.name.upper(), None)
rsa[i] = 0.0 if not max_asa else float(asa[i] / max_asa)
return np.clip(rsa, 0.0, 1.0)[:, None]
@torch.no_grad()
def _contact_count(self, traj: md.Trajectory, cutoff_nm=0.8):
L = traj.n_residues
ca_atoms = traj.topology.select("name CA")
if len(ca_atoms) == L:
coors = traj.xyz[0, ca_atoms, :] # nm
else:
xyz = traj.xyz[0]
coors = []
for res in traj.topology.residues:
idxs = [a.index for a in res.atoms]
coors.append(xyz[idxs, :].mean(axis=0))
coors = np.array(coors, dtype=np.float32)
diff = coors[:, None, :] - coors[None, :, :]
dist = np.sqrt((diff**2).sum(-1)) # nm
mask = (dist < cutoff_nm).astype(np.float32)
np.fill_diagonal(mask, 0.0)
cnt = mask.sum(axis=1)
return cnt[:, None].astype(np.float32)
@torch.no_grad()
def _plddt(self, pdb_file: str):
# Read per-residue pLDDT from the PDB B-factor column with Biopython (ESMFold/AlphaFold write pLDDT there)
from Bio.PDB import PDBParser
parser = PDBParser(QUIET=True)
structure = parser.get_structure("prot", pdb_file)
model = structure[0]
res_plddt = []
for chain in model:
for residue in chain:
atoms = list(residue.get_atoms())
if len(atoms) == 0:
res_plddt.append(0.0)
continue
# mean B-factor over the residue's atoms
bvals = [float(atom.get_bfactor()) for atom in atoms]
res_plddt.append(float(np.mean(bvals)))
# normalize to [0, 1]
plddt = np.array(res_plddt, dtype=np.float32) / 100.0
plddt = np.clip(plddt, 0.0, 1.0)
return plddt[:, None] # [L,1]
@torch.no_grad()
def _parse_and_features(self, pdb_file: str):
traj = md.load(pdb_file)
L = traj.n_residues
angles = self._angles(traj) # [L,6]
ss8 = self._ss8(traj) # [L,8]
rsa = self._rsa(traj) # [L,1]
cnt = self._contact_count(traj) # [L,1]
plddt = self._plddt(pdb_file) # [L,1]
feats = np.concatenate([angles, ss8, rsa, cnt, plddt], axis=1).astype(np.float32) # [L,17]
ca_atoms = traj.topology.select("name CA")
if len(ca_atoms) == L:
coors_nm = traj.xyz[0, ca_atoms, :]
else:
xyz = traj.xyz[0]
res_coords = []
for res in traj.topology.residues:
idxs = [a.index for a in res.atoms]
res_coords.append(xyz[idxs, :].mean(axis=0))
coors_nm = np.array(res_coords, dtype=np.float32)
coors_ang = coors_nm * 10.0 # nm -> Å
return coors_ang.astype(np.float32), feats # [L,3], [L,17]
@torch.no_grad()
def forward(self, pdb_file: str):
coors_ang, scalars = self._parse_and_features(pdb_file)
coors = torch.tensor(coors_ang, dtype=torch.float32, device=self.device) # [N,3]
scalars = torch.tensor(scalars, dtype=torch.float32, device=self.device) # [N,17]
return scalars, coors # [N,17], [N,3]
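# The 17 per-residue scalars are, in order: sin/cos of phi, psi, omega (6), SS8 one-hot (8), RSA (1),
# CA contact count within 8 A (1), and pLDDT/100 (1). Usage sketch on a predicted PDB (hypothetical path):
#   extractor = StructureFeatureExtractorNoDSSP(device='cpu')
#   scalars, coords = extractor('pred.pdb')   # [L, 17], [L, 3] with coordinates in Angstrom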
class ResiduePipelineWithHFESM:
def __init__(self,
esm_model_name="facebook/esmfold_v1",
cache_dir=None,
esm_device='cpu',
allow_tf32=True
):
self.esm = ESMFoldPredictorHF(esm_model_name, cache_dir, esm_device, allow_tf32)
self.struct_encoder = StructureFeatureExtractorNoDSSP(device=esm_device)
self.cache_dir = cache_dir
@torch.no_grad()
def __call__(self, seq: str, save_pdb_path: str = None) -> torch.Tensor:
pdb_str = self.esm.infer_pdb_str(seq)
if save_pdb_path is None:
tmpdir = self.cache_dir if self.cache_dir is not None else tempfile.gettempdir()
save_pdb_path = str(Path(tmpdir) / "esmfold_pred_fold5.pdb")
Path(save_pdb_path).write_text(pdb_str)
struct_emb, struct_coords = self.struct_encoder(save_pdb_path)
return struct_emb, struct_coords
def sanitize_protein_seq(seq: str) -> str:
if not isinstance(seq, str):
return ""
s = "".join(seq.split()).upper()
allowed = set("ACDEFGHIKLMNPQRSTVWYXBZJUO")
return "".join([c for c in s if c in allowed])
@torch.no_grad()
def batch_embed_to_dicts(
df: pd.DataFrame,
seq_col: str,
pipeline,
show_progress: bool = True,
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], List[Tuple[str, str]]]:
"""
Returns:
- emb_dict: {seq -> z(torch.Tensor[L, D])}
- coord_dict:{seq -> coords(torch.Tensor[L, 3])}
- failures: [(seq, err_msg), ...]
"""
raw_list = df[seq_col].astype(str).tolist()
seqs = []
for s in raw_list:
ss = sanitize_protein_seq(s)
if ss:
seqs.append(ss)
uniq_seqs = sorted(set(seqs))
logger.info(f"Total rows: {len(df)}, valid seqs: {len(seqs)}, unique: {len(uniq_seqs)}")
emb_dict: Dict[str, torch.Tensor] = {}
coord_dict: Dict[str, torch.Tensor] = {}
failures: List[Tuple[str, str]] = []
iterator = tqdm(uniq_seqs, desc="ESMFold: predicting structures...") if show_progress else uniq_seqs
for seq in iterator:
if seq in emb_dict:
continue
try:
z_t, c_t = pipeline(seq) # z: [L, D], coords: [L, 3] (torch.Tensor)
emb_dict[seq] = z_t.detach().float().cpu()
coord_dict[seq] = c_t.detach().float().cpu()
except Exception as e:
failures.append((seq, repr(e)))
continue
logger.info(f"[DONE] OK: {len(emb_dict)}, Failed: {len(failures)}")
if failures:
logger.error("[SAMPLE failures] %s", failures[:3])
return emb_dict, coord_dict, failures
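# Usage sketch (the 'hla_full_seq' column name is an assumption; any column of protein sequences works):
#   pipeline = ResiduePipelineWithHFESM(esm_device='cuda:0', cache_dir='/tmp')
#   feat_dict, coord_dict, failures = batch_embed_to_dicts(df, 'hla_full_seq', pipeline)
#   scalars, coords = feat_dict[seq], coord_dict[seq]   # [L, 17], [L, 3] for each unique sanitized sequence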
class ESMFoldEncoder(nn.Module):
def __init__(self, model_name="facebook/esmfold_v1", esm_cache_dir="/data/esm_cache", cache_dir="/data/cache"):
super(ESMFoldEncoder, self).__init__()
self.model_name = model_name
self.esm_cache_dir = esm_cache_dir
self.cache_dir = cache_dir
def save_obj(self, obj, path):
"""Save object to a file (no compression)."""
torch.save(obj, path)
def load_obj(self, path):
"""Load object from a file (no compression)."""
return torch.load(path, map_location='cpu', weights_only=False)
def load_esm_dict(self, device, df_data, chain, re_embed):
def _clean_unique(series: pd.Series) -> list:
cleaned = []
for s in series.astype(str).tolist():
ss = sanitize_protein_seq(s)
if ss:
cleaned.append(ss)
return sorted(set(cleaned))
def _retry_embed_df(
df: pd.DataFrame,
chain: str,
max_retries: int = 2,
show_progress: bool = True,
):
"""
Try to embed protein sequences with retries on failures.
Args:
df (pd.DataFrame): A DataFrame containing a column `chain` with sequences.
chain (str): The column name containing the sequences (e.g., "alpha", "beta").
Note: the embedding pipeline (ResiduePipelineWithHFESM) is constructed inside this function.
max_retries (int): Maximum number of retries for failed sequences.
show_progress (bool): Whether to display tqdm progress bars.
Returns:
feat_dict (Dict[str, torch.Tensor]): {sequence -> embedding tensor [L, D]}.
coord_dict (Dict[str, torch.Tensor]): {sequence -> coordinate tensor [L, 3]}.
failures (List[Tuple[str, str]]): List of (sequence, error_message) that still failed after retries.
"""
pipeline = ResiduePipelineWithHFESM(
esm_model_name=self.model_name,
cache_dir=self.esm_cache_dir,
esm_device=device
)
# 1. First attempt
feat_dict, coord_dict, failures = batch_embed_to_dicts(
df, chain, pipeline, show_progress=show_progress
)
# 2. Retry loop for failed sequences
tries = 0
while failures and tries < max_retries:
tries += 1
retry_seqs = [s for s, _ in failures]
logger.info(f"[retry {tries}/{max_retries}] {len(retry_seqs)} sequences")
retry_df = pd.DataFrame({chain: retry_seqs})
f2, c2, failures = batch_embed_to_dicts(
retry_df, chain, pipeline, show_progress=show_progress
)
feat_dict.update(f2)
coord_dict.update(c2)
return feat_dict, coord_dict, failures
def update_with_new_seqs(feat_dict, coord_dict, chain):
base_dir = os.path.dirname(os.path.abspath(__file__))
base_dir = base_dir + "/" + self.cache_dir
os.makedirs(base_dir, exist_ok=True)
path_feat = os.path.join(base_dir, f"{chain}_feat_dict.pt")
path_coords = os.path.join(base_dir, f"{chain}_coord_dict.pt")
all_seqs_clean = _clean_unique(df_data[chain])
new_seqs = [s for s in all_seqs_clean if s not in feat_dict]
if not new_seqs:
logger.info(f"No new {chain} sequences found")
return feat_dict, coord_dict
logger.info(f"Found new {chain} sequences, embedding...")
df_new = pd.DataFrame({chain: new_seqs})
new_feat_dict, new_coord_dict, failures = _retry_embed_df(df_new, chain, max_retries=100)
feat_dict.update(new_feat_dict)
coord_dict.update(new_coord_dict)
self.save_obj(feat_dict, path_feat)
self.save_obj(coord_dict, path_coords)
if failures:
for seq, err in failures:
logger.error(f"[create] failed: {seq} | {err}")
logger.info(f"Updated and saved {path_feat} and {path_coords}")
return feat_dict, coord_dict
def get_or_create_dict(chain):
base_dir = os.path.dirname(os.path.abspath(__file__)) + "/" + self.cache_dir
os.makedirs(base_dir, exist_ok=True)
path_feat = os.path.join(base_dir, f"{chain}_feat_dict.pt")
path_coords = os.path.join(base_dir, f"{chain}_coord_dict.pt")
if os.path.exists(path_feat) and not re_embed:
logger.info(f"Loading {path_feat} and {path_coords}")
feat_dict = self.load_obj(path_feat)
coord_dict = self.load_obj(path_coords)
else:
logger.info(f"{path_feat} and {path_coords} not found or re_embed=True, generating...")
unique_seqs = _clean_unique(df_data[chain])
df_uniq = pd.DataFrame({chain: unique_seqs})
feat_dict, coord_dict, failures = _retry_embed_df(
df_uniq, chain, show_progress=True, max_retries=100
)
self.save_obj(feat_dict, path_feat)
self.save_obj(coord_dict, path_coords)
if failures:
for seq, err in failures:
logger.error(f"[create] failed: {seq} | {err}")
logger.info(f"Saved {path_feat} and {path_coords}")
return feat_dict, coord_dict
self.dict[chain+'_feat'], self.dict[chain+'_coord'] = update_with_new_seqs(*get_or_create_dict(chain), chain)
def pad_and_stack(self, batch_feats, L_max, batch_coors):
"""
batch_feats: list of [L_i, D] tensors
batch_coors: list of [L_i, 3] tensors
return:
feats: [B, L_max, D]
coors: [B, L_max, 3]
mask : [B, L_max] (True for real tokens)
"""
assert len(batch_feats) == len(batch_coors)
B = len(batch_feats)
D = batch_feats[0].shape[-1]
feats_pad = []
coors_pad = []
masks = []
for x, c in zip(batch_feats, batch_coors):
L = x.shape[0]
pad_L = L_max - L
# pad feats/coors with zeros
feats_pad.append(torch.nn.functional.pad(x, (0, 0, 0, pad_L))) # [L_max, D]
coors_pad.append(torch.nn.functional.pad(c, (0, 0, 0, pad_L))) # [L_max, 3]
m = torch.zeros(L_max, dtype=torch.bool)
m[:L] = True
masks.append(m)
feats = torch.stack(feats_pad, dim=0) # [B, L_max, D]
coors = torch.stack(coors_pad, dim=0) # [B, L_max, 3]
mask = torch.stack(masks, dim=0) # [B, L_max]
return feats, coors, mask
def forward(self, df_data, chain, device='cpu', re_embed=False):
"""
df_data: pd.DataFrame with a column `chain` containing sequences
chain: str, e.g. "alpha" or "beta"
device: str, e.g. 'cpu' or 'cuda:0'
re_embed: bool, whether to re-embed even if cached files exist
"""
self.dict = {}
self.load_esm_dict(device, df_data, chain, re_embed)
batch_feats = []
batch_coors = []
for seq in df_data[chain].astype(str).tolist():
ss = sanitize_protein_seq(seq)
if ss in self.dict[chain+'_feat'] and ss in self.dict[chain+'_coord']:
batch_feats.append(self.dict[chain+'_feat'][ss])
batch_coors.append(self.dict[chain+'_coord'][ss])
else:
raise ValueError(f"Sequence not found in embedding dict: {ss}")
# L_max = max(x.shape[0] for x in batch_feats)
return batch_feats, batch_coors
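# Usage sketch: embed every sequence in df_data['tcrb'] (cached as <module dir>/<cache_dir>/tcrb_*.pt):
#   struct_encoder = ESMFoldEncoder(cache_dir='cache')
#   feats, coords = struct_encoder(df_data, chain='tcrb', device='cuda:0')
#   # feats[i]: [L_i, 17] structural scalars, coords[i]: [L_i, 3] coordinates, one entry per row of df_data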
# =================================== Dataset / Collate ===========================================
class PepHLA_Dataset(torch.utils.data.Dataset):
def __init__(self, df, phys_dict, esm2_dict, struct_dict):
self.df = df
self.phys_dict = phys_dict
self.esm2_dict = esm2_dict
self.struct_dict = struct_dict
def __len__(self):
return len(self.df)
def __getitem__(self, idx):
row = self.df.iloc[idx]
pep = row['peptide']
hla = row['HLA_full']
label = torch.tensor(row['label'], dtype=torch.float32)
pep_phys = self.phys_dict['pep'][pep]
pep_esm = self.esm2_dict['pep'][pep]
hla_phys = self.phys_dict['hla'][hla]
hla_esm = self.esm2_dict['hla'][hla]
hla_struct, hla_coord = self.struct_dict[hla]
return {
'pep_phys': pep_phys,
'pep_esm': pep_esm,
'hla_phys': hla_phys,
'hla_esm': hla_esm,
'hla_struct': hla_struct,
'hla_coord': hla_coord,
'label': label,
'pep_id': pep,
'hla_id': hla,
}
def peptide_hla_collate_fn(batch):
def pad_or_crop(x, original_len, target_len):
L, D = x.shape
valid_len = min(original_len, target_len)
valid_part = x[:valid_len]
if valid_len < target_len:
pad_len = target_len - valid_len
padding = x.new_zeros(pad_len, D)
return torch.cat([valid_part, padding], dim=0)
else:
return valid_part
out_batch = {}
pep_lens = [len(item['pep_id']) for item in batch]
max_pep_len = max(pep_lens)
for key in batch[0].keys():
if key == 'label':
out_batch[key] = torch.stack([item[key] for item in batch])
elif key.startswith('pep_') and not key.endswith('_id'):
out_batch[key] = torch.stack([pad_or_crop(item[key], len(item['pep_id']), max_pep_len) for item in batch])
elif key.endswith('_id'):
out_batch[key] = [item[key] for item in batch]
else:
out_batch[key] = torch.stack([item[key] for item in batch])
def make_mask(lengths, max_len):
masks = []
for L in lengths:
m = torch.zeros(max_len, dtype=torch.bool)
m[:L] = True
masks.append(m)
return torch.stack(masks)
out_batch['pep_mask'] = make_mask(pep_lens, max_pep_len)
return out_batch
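# Usage sketch with a DataLoader (phys_dict / esm2_dict are nested as {'pep': {...}, 'hla': {...}};
# struct_dict maps each HLA sequence to a (features, coordinates) tuple):
#   dataset = PepHLA_Dataset(df, phys_dict, esm2_dict, struct_dict)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True,
#                                        collate_fn=peptide_hla_collate_fn)
#   batch = next(iter(loader))   # peptide tensors padded to the longest peptide; 'pep_mask' marks real residues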
# =================================== Dataset / Collate ===========================================
class TCRPepHLA_Dataset(torch.utils.data.Dataset):
"""
Dataset for TCRα + TCRβ + peptide + HLA binding.
"""
def __init__(self, df, phys_dict, esm2_dict, struct_dict, pep_hla_feat_dict):
self.df = df
self.phys_dict = phys_dict
self.esm2_dict = esm2_dict
self.struct_dict = struct_dict
self.pep_hla_feat_dict = pep_hla_feat_dict
def __len__(self):
return len(self.df)
def __getitem__(self, idx):
row = self.df.iloc[idx]
tcra = row['tcra']
tcrb = row['tcrb']
pep = row['peptide']
hla = row['HLA_full']
label = torch.tensor(row['label'], dtype=torch.float32)
# ---- TCRα ----
tcra_phys = self.phys_dict['tcra'][tcra]
tcra_esm = self.esm2_dict['tcra'][tcra]
tcra_struct, tcra_coord = self.struct_dict['tcra'][tcra]
tcra_cdr3_start = torch.tensor(row['cdr3a_start'], dtype=torch.long)
tcra_cdr3_end = torch.tensor(row['cdr3a_end'], dtype=torch.long)
# ---- TCRβ ----
tcrb_phys = self.phys_dict['tcrb'][tcrb]
tcrb_esm = self.esm2_dict['tcrb'][tcrb]
tcrb_struct, tcrb_coord = self.struct_dict['tcrb'][tcrb]
tcrb_cdr3_start = torch.tensor(row['cdr3b_start'], dtype=torch.long)
tcrb_cdr3_end = torch.tensor(row['cdr3b_end'], dtype=torch.long)
# ---- peptide ----
pep_phys = self.phys_dict['pep'][pep]
pep_esm = self.esm2_dict['pep'][pep]
pep_struct, pep_coord = self.struct_dict['pep'][pep]
# ---- HLA ----
hla_phys = self.phys_dict['hla'][hla]
hla_esm = self.esm2_dict['hla'][hla]
hla_struct, hla_coord = self.struct_dict['hla'][hla]
feats = self.pep_hla_feat_dict[(pep, hla)]
pep_feat_pretrain = feats['pep_feat_pretrain']
hla_feat_pretrain = feats['hla_feat_pretrain']
return {
# TCRα
'tcra_phys': tcra_phys,
'tcra_esm': tcra_esm,
'tcra_struct': tcra_struct,
'tcra_coord': tcra_coord,
'cdr3a_start': tcra_cdr3_start,
'cdr3a_end': tcra_cdr3_end,
# TCRβ
'tcrb_phys': tcrb_phys,
'tcrb_esm': tcrb_esm,
'tcrb_struct': tcrb_struct,
'tcrb_coord': tcrb_coord,
'cdr3b_start': tcrb_cdr3_start,
'cdr3b_end': tcrb_cdr3_end,
# peptide
'pep_phys': pep_phys,
'pep_esm': pep_esm,
'pep_struct': pep_struct,
'pep_coord': pep_coord,
# HLA
'hla_phys': hla_phys,
'hla_esm': hla_esm,
'hla_struct': hla_struct,
'hla_coord': hla_coord,
'tcra_id': tcra,
'tcrb_id': tcrb,
'pep_id': pep,
'hla_id': hla,
'label': label,
'pep_feat_pretrain': pep_feat_pretrain,
'hla_feat_pretrain': hla_feat_pretrain,
}
# =================================== Collate Function ===========================================
def tcr_pep_hla_collate_fn(batch):
def pad_or_crop(x, original_len, target_len):
L, D = x.shape
valid_len = min(original_len, target_len)
valid_part = x[:valid_len]
if valid_len < target_len:
pad_len = target_len - valid_len
padding = x.new_zeros(pad_len, D)
return torch.cat([valid_part, padding], dim=0)
else:
return valid_part
out_batch = {}
tcra_lens = [len(item['tcra_id']) for item in batch]
tcrb_lens = [len(item['tcrb_id']) for item in batch]
pep_lens = [len(item['pep_id']) for item in batch]
max_tcra_len = max(tcra_lens)
max_tcrb_len = max(tcrb_lens)
max_pep_len = max(pep_lens)
for key in batch[0].keys():
if key == 'label':
out_batch[key] = torch.stack([item[key] for item in batch])
elif key.startswith('tcra_') and not key.endswith('_id'):
out_batch[key] = torch.stack([pad_or_crop(item[key], len(item['tcra_id']), max_tcra_len) for item in batch])
elif key.startswith('tcrb_') and not key.endswith('_id'):
out_batch[key] = torch.stack([pad_or_crop(item[key], len(item['tcrb_id']), max_tcrb_len) for item in batch])
elif key.startswith('pep_') and not key.endswith('_id'):
out_batch[key] = torch.stack([pad_or_crop(item[key], len(item['pep_id']), max_pep_len) for item in batch])
elif key.endswith('_id'):
out_batch[key] = [item[key] for item in batch]
else:
out_batch[key] = torch.stack([item[key] for item in batch])
def make_mask(lengths, max_len):
masks = []
for L in lengths:
m = torch.zeros(max_len, dtype=torch.bool)
m[:L] = True
masks.append(m)
return torch.stack(masks)
out_batch['tcra_mask'] = make_mask(tcra_lens, max_tcra_len)
out_batch['tcrb_mask'] = make_mask(tcrb_lens, max_tcrb_len)
out_batch['pep_mask'] = make_mask(pep_lens, max_pep_len)
return out_batch
# ==================================== Building blocks: projection + gating =========================================
class ResidueProjector(nn.Module):
"""把不同分支的通道维度对齐到同一 D"""
def __init__(self, in_dim, out_dim):
super().__init__()
self.proj = nn.Linear(in_dim, out_dim) if in_dim != out_dim else nn.Identity()
def forward(self, x): # x: [B,L,Di]
return self.proj(x)
class ResidueDoubleFusion(nn.Module):
"""
ResidueDoubleFusion:
A residue-level two-branch fusion module that combines two modalities (x1, x2)
using cross-attention followed by gated residual fusion and linear projection.
Typical usage:
- x1: physicochemical features
- x2: ESM embeddings (or structure features)
"""
def __init__(self, dim, num_heads=8, dropout=0.1):
super().__init__()
self.dim = dim
# Cross-attention: allows information flow between two modalities
self.cross_attn = nn.MultiheadAttention(
embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True
)
# Gating mechanism: adaptively weight two modalities per residue
self.gate = nn.Sequential(
nn.Linear(dim * 2, dim),
nn.ReLU(),
nn.Linear(dim, 1),
nn.Sigmoid()
)
# Optional projection after fusion
self.out_proj = nn.Linear(dim, dim)
# Layer norms for stable training
self.norm_x1 = nn.LayerNorm(dim)
self.norm_x2 = nn.LayerNorm(dim)
self.norm_out = nn.LayerNorm(dim)
def forward(self, x1, x2):
"""
Args:
x1: Tensor [B, L, D] - first modality (e.g., physicochemical)
x2: Tensor [B, L, D] - second modality (e.g., ESM embeddings)
Returns:
fused: Tensor [B, L, D] - fused residue-level representation
"""
# 1) Normalize both branches
x1_norm = self.norm_x1(x1)
x2_norm = self.norm_x2(x2)
# 2) Cross-attention (x1 queries, x2 keys/values)
# This allows x1 to attend to x2 at each residue position
attn_out, _ = self.cross_attn(
query=x1_norm,
key=x2_norm,
value=x2_norm
) # [B, L, D]
# 3) Gating between original x1 and attention-enhanced x2
gate_val = self.gate(torch.cat([x1, attn_out], dim=-1)) # [B, L, 1]
fused = gate_val * x1 + (1 - gate_val) * attn_out
# 4) Optional projection + normalization
fused = self.out_proj(fused)
fused = self.norm_out(fused)
return fused
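# Shape sketch: both inputs must already share the channel dimension (use ResidueProjector beforehand).
#   fuse = ResidueDoubleFusion(dim=256)
#   x1 = torch.randn(2, 20, 256)   # e.g. physicochemical branch
#   x2 = torch.randn(2, 20, 256)   # e.g. ESM branch
#   out = fuse(x1, x2)             # [2, 20, 256]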
class ResidueTripleFusion(nn.Module):
"""
ResidueTripleFusion:
A hierarchical three-branch feature fusion module for residue-level representations.
Step 1: Fuse physicochemical features and protein language model embeddings.
Step 2: Fuse the intermediate representation with structure-based features.
Each fusion step uses ResidueDoubleFusion (cross-attention + gating + linear projection).
"""
def __init__(self, dim, num_heads=8, dropout=0.1):
super().__init__()
# Fuse physicochemical + ESM embeddings
self.fuse_phys_esm = ResidueDoubleFusion(dim, num_heads=num_heads, dropout=dropout)
# Fuse the fused phys+esm representation with structure embeddings
self.fuse_f12_struct = ResidueDoubleFusion(dim, num_heads=num_heads, dropout=dropout)
def forward(self, phys, esm, struct):
"""
Args:
phys: Tensor [B, L, D], physicochemical features (e.g., AAindex-based)
esm: Tensor [B, L, D], protein language model embeddings (e.g., ESM2, ProtT5)
struct: Tensor [B, L, D], structure-derived features (e.g., torsion, RSA)
Returns:
fused: Tensor [B, L, D], final fused representation
"""
# Step 1: Fuse physicochemical and ESM embeddings
f12 = self.fuse_phys_esm(phys, esm)
# Step 2: Fuse the intermediate fused representation with structure features
fused = self.fuse_f12_struct(f12, struct)
return fused
class BANLayer(nn.Module):
"""
Bilinear Attention Network Layer with proper 2D masked-softmax.
v_mask: [B, L_v] True=valid
q_mask: [B, L_q] True=valid
"""
def __init__(self, v_dim, q_dim, h_dim, h_out, act='ReLU', dropout=0.2, k=3):
super().__init__()
self.c = 32
self.k = k
self.v_dim = v_dim
self.q_dim = q_dim
self.h_dim = h_dim
self.h_out = h_out
self.v_net = FCNet([v_dim, h_dim * self.k], act=act, dropout=dropout)
self.q_net = FCNet([q_dim, h_dim * self.k], act=act, dropout=dropout)
if 1 < k:
self.p_net = nn.AvgPool1d(self.k, stride=self.k)
if h_out <= self.c:
self.h_mat = nn.Parameter(torch.Tensor(1, h_out, 1, h_dim * self.k).normal_())
self.h_bias = nn.Parameter(torch.Tensor(1, h_out, 1, 1).normal_())
else:
self.h_net = weight_norm(nn.Linear(h_dim * self.k, h_out), dim=None)
self.bn = nn.BatchNorm1d(h_dim)
def attention_pooling(self, v, q, att_map): # att_map: [B, L_v, L_q]
logits = torch.einsum('bvk,bvq,bqk->bk', (v, att_map, q))
if 1 < self.k:
logits = self.p_net(logits.unsqueeze(1)).squeeze(1) * self.k
return logits
def _masked_softmax_2d(self, logits, v_mask, q_mask):
"""
logits: [B, h_out, L_v, L_q]
v_mask: [B, L_v] or None
q_mask: [B, L_q] or None
return: probs [B, h_out, L_v, L_q] (masked entries = 0; normalized over the valid 2D sub-matrix)
"""
B, H, Lv, Lq = logits.shape
device = logits.device
if v_mask is None:
v_mask = torch.ones(B, Lv, dtype=torch.bool, device=device)
if q_mask is None:
q_mask = torch.ones(B, Lq, dtype=torch.bool, device=device)
mask2d = (v_mask[:, :, None] & q_mask[:, None, :]) # [B, Lv, Lq]
mask2d = mask2d[:, None, :, :].expand(B, H, Lv, Lq) # [B, H, Lv, Lq]
logits = logits.masked_fill(~mask2d, -float('inf'))
# softmax over the joint Lv*Lq space
flat = logits.view(B, H, -1) # [B, H, Lv*Lq]
# Edge case: a sample may have no valid cells; guard against NaN
flat = torch.where(torch.isinf(flat), torch.full_like(flat, -1e9), flat)
flat = F.softmax(flat, dim=-1)
flat = torch.nan_to_num(flat, nan=0.0) # safety fallback
probs = flat.view(B, H, Lv, Lq)
# zero out masked positions (numerical stability and cleaner visualization)
probs = probs * mask2d.float()
return probs
def forward(self, v, q, v_mask=None, q_mask=None, softmax=True):
"""
v: [B, L_v, Dv], q: [B, L_q, Dq]
"""
B, L_v, _ = v.size()
_, L_q, _ = q.size()
v_ = self.v_net(v) # [B, L_v, h_dim*k]
q_ = self.q_net(q) # [B, L_q, h_dim*k]
if self.h_out <= self.c:
att_maps = torch.einsum('xhyk,bvk,bqk->bhvq', (self.h_mat, v_, q_)) + self.h_bias # [B,H,Lv,Lq]
else:
v_t = v_.transpose(1, 2).unsqueeze(3) # [B, K, Lv, 1]
q_t = q_.transpose(1, 2).unsqueeze(2) # [B, K, 1, Lq]
d_ = torch.matmul(v_t, q_t) # [B, K, Lv, Lq]
att_maps = self.h_net(d_.permute(0, 2, 3, 1)) # [B, Lv, Lq, H]
att_maps = att_maps.permute(0, 3, 1, 2) # [B, H, Lv, Lq]
if softmax:
att_maps = self._masked_softmax_2d(att_maps, v_mask, q_mask)
else:
# Even without softmax, zero out invalid cells to avoid information leakage
if v_mask is not None:
att_maps = att_maps.masked_fill(~v_mask[:, None, :, None], 0.0)
if q_mask is not None:
att_maps = att_maps.masked_fill(~q_mask[:, None, None, :], 0.0)
# Note: v_ / q_ are still [B, L, h_dim*k], matching att_maps of shape [B, H, Lv, Lq]
logits = self.attention_pooling(v_, q_, att_maps[:, 0, :, :])
for i in range(1, self.h_out):
logits = logits + self.attention_pooling(v_, q_, att_maps[:, i, :, :])
logits = self.bn(logits)
return logits, att_maps
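# Usage sketch: returns a pooled interaction vector plus per-head attention maps over the two sequences.
#   ban = BANLayer(v_dim=256, q_dim=256, h_dim=256, h_out=4, k=3)
#   v, q = torch.randn(2, 15, 256), torch.randn(2, 34, 256)
#   v_mask = torch.ones(2, 15, dtype=torch.bool)
#   logits, att = ban(v, q, v_mask=v_mask, q_mask=None)   # logits: [2, 256], att: [2, 4, 15, 34]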
class FCNet(nn.Module):
def __init__(self, dims, act='ReLU', dropout=0.2):
super(FCNet, self).__init__()
layers = []
for i in range(len(dims) - 2):
in_dim = dims[i]
out_dim = dims[i + 1]
if 0 < dropout:
layers.append(nn.Dropout(dropout))
layers.append(weight_norm(nn.Linear(in_dim, out_dim), dim=None))
if '' != act:
layers.append(getattr(nn, act)())
if 0 < dropout:
layers.append(nn.Dropout(dropout))
layers.append(weight_norm(nn.Linear(dims[-2], dims[-1]), dim=None))
if '' != act:
layers.append(getattr(nn, act)())
self.main = nn.Sequential(*layers)
def forward(self, x):
return self.main(x)
class StackedEGNN(nn.Module):
def __init__(self, dim, layers, update_coors=False, **egnn_kwargs):
super().__init__()
self.layers = nn.ModuleList([
EGNN(dim=dim, update_coors=update_coors, **egnn_kwargs)
for _ in range(layers)
])
def forward(self, feats, coors, mask=None):
# feats: [B, L_max, D], coors: [B, L_max, 3], mask: [B, L_max] (bool)
for layer in self.layers:
feats, coors = layer(feats, coors, mask=mask)
return feats, coors
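# Usage sketch (EGNN layers from egnn_pytorch; coordinates pass through unchanged when update_coors=False):
#   se3 = StackedEGNN(dim=17, layers=3)
#   feats, coors = torch.randn(2, 180, 17), torch.randn(2, 180, 3)
#   mask = torch.ones(2, 180, dtype=torch.bool)
#   feats_out, coors_out = se3(feats, coors, mask=mask)   # [2, 180, 17], [2, 180, 3]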
class FocalLoss(nn.Module):
def __init__(self, alpha=0.5, gamma=2, reduction='mean'):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.reduction = reduction
def forward(self, inputs, targets):
bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
p_t = torch.exp(-bce_loss)
alpha_weight = self.alpha * targets + (1 - self.alpha) * (1 - targets)
loss = alpha_weight * (1 - p_t) ** self.gamma * bce_loss
if self.reduction == 'mean':
return torch.mean(loss)
elif self.reduction == 'sum':
return torch.sum(loss)
else:
return loss
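# Per-sample loss: alpha_t * (1 - p_t)**gamma * BCE, with p_t = exp(-BCE) (the probability assigned to
# the true class) and alpha_t = alpha for positives, (1 - alpha) for negatives. Quick sketch:
#   criterion = FocalLoss(alpha=0.5, gamma=2)
#   loss = criterion(torch.randn(8), torch.randint(0, 2, (8,)).float())   # raw logits, {0,1} targets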
# ===================================== Main model (full version) ===========================================
class PeptideHLABindingPredictor(nn.Module):
def __init__(
self,
phys_dim=20, # output dimension of the physicochemical encoder (PhysicochemicalEncoder)
pep_dim=256, # unified peptide channel size
hla_dim=256, # unified HLA channel size
bilinear_dim=256,
pseudo_seq_pos=None, # pocket positions (assumed 0-based, within [0, 179])
device="cuda:0",
loss_fn='bce',
alpha=0.5,
gamma=2.0,
dropout=0.2,
pos_weights=None
):
super().__init__()
self.device = device
self.pep_dim = pep_dim
self.hla_dim = hla_dim
self.bilinear_dim = bilinear_dim
self.alpha = alpha
self.gamma = gamma
self.dropout = dropout
if loss_fn == 'bce':
self.loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weights]) if pos_weights is not None else None)
elif loss_fn == 'focal':
self.loss_fn = FocalLoss(alpha=alpha, gamma=gamma)
else:
raise ValueError(f"Unknown loss function: {loss_fn}")
self.se3_model = StackedEGNN(
dim=17, layers=3
)
self.max_pep_len = 20
self.max_hla_len = 180
self.pep_pos_embed = nn.Parameter(torch.randn(self.max_pep_len, pep_dim))
self.hla_pos_embed = nn.Parameter(torch.randn(self.max_hla_len, hla_dim))
# -- Project each branch to the shared dimension (per residue) --
# peptide branch (physicochemical -> pep_dim, ESM2(1280) -> pep_dim)
self.proj_pep_phys = ResidueProjector(in_dim=phys_dim, out_dim=pep_dim) # physicochemical encoder output -> pep_dim
self.proj_pep_esm = ResidueProjector(in_dim=1280, out_dim=pep_dim)
# HLA branch (physicochemical -> hla_dim, ESM2(1280) -> hla_dim, structure(17 / SE(3) output) -> hla_dim)
self.proj_hla_phys = ResidueProjector(in_dim=phys_dim, out_dim=hla_dim) # physicochemical encoder output -> hla_dim
self.proj_hla_esm = ResidueProjector(in_dim=1280, out_dim=hla_dim)
self.proj_hla_se3 = ResidueProjector(in_dim=17, out_dim=hla_dim) # project se3_model output to hla_dim
# -- Gated fusion (per residue) --
self.gate_pep = ResidueDoubleFusion(pep_dim) # pep_phys × pep_esm
self.gate_hla = ResidueTripleFusion(hla_dim) # hla_phys × hla_esm × hla_struct
d_model = self.pep_dim
n_heads = 8
# 1. Peptide queries HLA (pep_q_hla_kv)
self.cross_attn_pep_hla = nn.MultiheadAttention(
embed_dim=d_model,
num_heads=n_heads,
dropout=self.dropout,
batch_first=True
)
self.norm_cross_pep = nn.LayerNorm(d_model)
# 2. HLA queries Peptide (hla_q_pep_kv)
self.cross_attn_hla_pep = nn.MultiheadAttention(
embed_dim=d_model,
num_heads=n_heads,
dropout=self.dropout,
batch_first=True
)
self.norm_cross_hla = nn.LayerNorm(d_model)
# -- Interaction module (bilinear attention map) --
self.bi_attn = BANLayer(v_dim=pep_dim, q_dim=hla_dim, h_dim=bilinear_dim, h_out=4, k=3)
# -- Prediction head --
self.head = nn.Sequential(
nn.Linear(bilinear_dim, bilinear_dim),
nn.ReLU(),
nn.Linear(bilinear_dim, 1)
)
# -- Pocket positions --
if pseudo_seq_pos is None:
pseudo_seq_pos = [i-2 for i in [7, 9, 24, 45, 59, 62, 63, 66, 67, 69, 70, 73, 74, 76, 77, 80, 81, 84, 95, 97, 99, 114, 116, 118, 143, 147, 150, 152, 156, 158, 159, 163, 167, 171]]
self.register_buffer("contact_idx", torch.tensor(pseudo_seq_pos, dtype=torch.long))
# --------------------------------------------
# Transformer Encoders for peptide & HLA
# --------------------------------------------
encoder_layer_pep = TransformerEncoderLayer(
d_model=pep_dim, # input dimension
nhead=8, # number of attention heads (tunable)
dim_feedforward=pep_dim*4,
dropout=self.dropout,
batch_first=True # input shape [B, L, D]
)
self.pep_encoder = TransformerEncoder(encoder_layer_pep, num_layers=2) # number of layers is tunable
encoder_layer_hla = TransformerEncoderLayer(
d_model=hla_dim,
nhead=8,
dim_feedforward=hla_dim*4,
dropout=self.dropout,
batch_first=True
)
self.hla_encoder = TransformerEncoder(encoder_layer_hla, num_layers=1)
# -------------------------- Utility: pad a list of [L, D] tensors into [B, L_max, D] --------------------------
def _pad_stack(self, tensors, L_max=None):
Ls = [t.shape[0] for t in tensors]
if L_max is None: L_max = max(Ls)
D = tensors[0].shape[-1]
B = len(tensors)
out = tensors[0].new_zeros((B, L_max, D))
mask = torch.zeros(B, L_max, dtype=torch.bool, device=out.device)
for i, t in enumerate(tensors):
L = t.shape[0]
out[i, :L] = t
mask[i, :L] = True
return out, mask # [B,L_max,D], [B,L_max]
# ----------------------------------- Pocket selection --------------------------------------
def _mask_to_pockets(self, hla_feat):
"""
Keep only the pocket positions of the HLA features.
Returns:
- hla_pocket: [B, n_pocket, D]
"""
B, L, D = hla_feat.shape
# ensure idx in [0, L-1]
idx = self.contact_idx.clamp(min=0, max=L-1)
# gather pocket features
hla_pocket = hla_feat[:, idx, :] # [B, n_pocket, D]
return hla_pocket
def add_positional_encoding(self, x, pos_embed):
"""
x: [B, L, D]
pos_embed: [L_max, D]
"""
B, L, D = x.shape
# take the first L positional encodings
pe = pos_embed[:L, :].unsqueeze(0).expand(B, -1, -1) # [B, L, D]
return x + pe
def forward(self, batch):
# take batch from DataLoader
pep_phys = batch['pep_phys'].to(self.device, non_blocking=True)
pep_esm = batch['pep_esm'].to(self.device, non_blocking=True)
hla_phys = batch['hla_phys'].to(self.device, non_blocking=True)
hla_esm = batch['hla_esm'].to(self.device, non_blocking=True)
hla_struct = batch['hla_struct'].to(self.device, non_blocking=True)
hla_coord = batch['hla_coord'].to(self.device, non_blocking=True)
labels = batch['label'].to(self.device)
# 1) peptide: physicochemical + ESM2 -> gated fusion
pep_phys = self.proj_pep_phys(pep_phys)
pep_esm = self.proj_pep_esm(pep_esm)
pep_feat = self.gate_pep(pep_phys, pep_esm) # [B, Lp, D]
pep_feat = self.add_positional_encoding(pep_feat, self.pep_pos_embed)
pep_feat = self.pep_encoder(pep_feat, src_key_padding_mask=~batch['pep_mask'].to(self.device, non_blocking=True))
# 2) HLA: physicochemical + ESM2 + structure -> SE(3) -> gated fusion
hla_phys = self.proj_hla_phys(hla_phys)
hla_esm = self.proj_hla_esm(hla_esm)
# hla_struct is [B, 180, 17]; run it through the SE(3) layers first
hla_se3 = self.se3_model(hla_struct, hla_coord, None)[0] # [B, 180, 17]
hla_se3 = self.proj_hla_se3(hla_se3) # -> 256
hla_feat = self.gate_hla(hla_phys, hla_esm, hla_se3)
hla_feat = self.add_positional_encoding(hla_feat, self.hla_pos_embed)
hla_feat = self.hla_encoder(hla_feat)
# cross attention for pep
pep_feat_cross, _ = self.cross_attn_pep_hla(
query=pep_feat,
key=hla_feat,
value=hla_feat,
key_padding_mask=None
)
# cross attention for hla
hla_feat_cross, _ = self.cross_attn_hla_pep(
query=hla_feat,
key=pep_feat,
value=pep_feat,
key_padding_mask=~batch['pep_mask'].to(self.device, non_blocking=True)
)
pep_feat_updated = self.norm_cross_pep(pep_feat + pep_feat_cross)
hla_feat_updated = self.norm_cross_hla(hla_feat + hla_feat_cross)
# 3) keep only the HLA pocket positions
hla_pocket = self._mask_to_pockets(hla_feat_updated)
# 4) bilinear attention
fused_vec, attn = self.bi_attn(
pep_feat_updated,
hla_pocket,
v_mask=batch['pep_mask'].to(self.device, non_blocking=True),
q_mask=None
)
logits = self.head(fused_vec).squeeze(-1)
probs = torch.sigmoid(logits).detach().cpu().numpy()
binding_loss = self.loss_fn(logits, labels.float())
return probs, binding_loss, attn.detach().cpu().numpy().sum(axis=1), fused_vec.detach().cpu().numpy()
# -------------------------- Encoder reuse interface (for the TCR-pHLA model) --------------------------
def _pad_peptide(self, x, max_len):
"""Pad peptide feature tensor [1, L, D] to [1, max_len, D]."""
B, L, D = x.shape
if L < max_len:
pad = x.new_zeros(B, max_len - L, D)
return torch.cat([x, pad], dim=1)
else:
return x[:, :max_len, :]
@torch.no_grad()
def encode_peptide_hla(self, pep_id, pep_phys, pep_esm, hla_phys, hla_esm, hla_struct, hla_coord, max_pep_len):
Lp = len(pep_id)
pep_phys = self.proj_pep_phys(pep_phys)
pep_esm = self.proj_pep_esm(pep_esm)
pep_phys = self._pad_peptide(pep_phys, max_pep_len)
pep_esm = self._pad_peptide(pep_esm, max_pep_len)
device = pep_phys.device
pep_mask = torch.zeros(1, max_pep_len, dtype=torch.bool, device=device)
pep_mask[0, :Lp] = True
pep_feat = self.gate_pep(pep_phys, pep_esm)
pep_feat = self.add_positional_encoding(pep_feat, self.pep_pos_embed)
pep_feat = self.pep_encoder(pep_feat, src_key_padding_mask=~pep_mask)
# 2) hla encoding
hla_phys = self.proj_hla_phys(hla_phys)
hla_esm = self.proj_hla_esm(hla_esm)
hla_se3 = self.se3_model(hla_struct, hla_coord, None)[0]
hla_se3 = self.proj_hla_se3(hla_se3)
hla_feat = self.gate_hla(hla_phys, hla_esm, hla_se3)
hla_feat = self.add_positional_encoding(hla_feat, self.hla_pos_embed)
hla_feat = self.hla_encoder(hla_feat)
# --- 3a. Peptide (Q) attends to HLA (K, V) ---
pep_feat_cross, _ = self.cross_attn_pep_hla(
query=pep_feat,
key=hla_feat,
value=hla_feat,
key_padding_mask=None
)
pep_feat_updated = self.norm_cross_pep(pep_feat + pep_feat_cross)
# --- 3b. HLA (Q) attends to Peptide (K, V) ---
hla_feat_cross, _ = self.cross_attn_hla_pep(
query=hla_feat,
key=pep_feat,
value=pep_feat,
key_padding_mask=~pep_mask
)
hla_feat_updated = self.norm_cross_hla(hla_feat + hla_feat_cross)
return pep_feat_updated, hla_feat_updated
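# Training-step sketch (batches come from PepHLA_Dataset + peptide_hla_collate_fn; note that forward
# returns probs/attention/fused as NumPy arrays and the loss as a differentiable tensor):
#   model = PeptideHLABindingPredictor(device='cuda:0').to('cuda:0')
#   optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
#   probs, loss, attn, fused = model(batch)
#   optimizer.zero_grad(); loss.backward(); optimizer.step()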
class TCRPeptideHLABindingPredictor(nn.Module):
def __init__(
self,
tcr_dim=256,
pep_dim=256,
hla_dim=256,
bilinear_dim=256,
loss_fn='bce',
alpha=0.5,
gamma=2.0,
dropout=0.1,
device='cuda:0',
pos_weights=None
):
super().__init__()
# TCR α / β position embeddings
self.max_tcra_len = 500
self.max_tcrb_len = 500
self.max_pep_len = 20
self.max_hla_len = 180
self.alpha = alpha
self.gamma = gamma
self.dropout = dropout
if loss_fn == 'bce':
self.loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weights]) if pos_weights is not None else None)
elif loss_fn == 'focal':
self.loss_fn = FocalLoss(alpha=alpha, gamma=gamma)
else:
raise ValueError(f"Unknown loss function: {loss_fn}")
self.tcra_pos_embed = nn.Parameter(torch.randn(self.max_tcra_len, tcr_dim))
self.tcrb_pos_embed = nn.Parameter(torch.randn(self.max_tcrb_len, tcr_dim))
self.pep_pos_embed = nn.Parameter(torch.randn(self.max_pep_len, pep_dim))
self.hla_pos_embed = nn.Parameter(torch.randn(self.max_hla_len, hla_dim))
self.device = device
self.tcr_dim = tcr_dim
self.pep_dim = pep_dim
self.hla_dim = hla_dim
self.bilinear_dim = bilinear_dim
d_model = tcr_dim
n_heads = 8
self.cross_attn_tcra_pep = nn.MultiheadAttention(d_model, n_heads, dropout=self.dropout, batch_first=True)
self.cross_attn_tcra_hla = nn.MultiheadAttention(d_model, n_heads, dropout=self.dropout, batch_first=True)
self.cross_attn_tcrb_pep = nn.MultiheadAttention(d_model, n_heads, dropout=self.dropout, batch_first=True)
self.cross_attn_tcrb_hla = nn.MultiheadAttention(d_model, n_heads, dropout=self.dropout, batch_first=True)
self.norm_tcra_pep = nn.LayerNorm(d_model)
self.norm_tcra_hla = nn.LayerNorm(d_model)
self.norm_tcrb_pep = nn.LayerNorm(d_model)
self.norm_tcrb_hla = nn.LayerNorm(d_model)
# =======================
# TCRα / TCRβ encoders
# =======================
def make_tcr_encoder():
proj_phys = ResidueProjector(20, tcr_dim)
proj_esm = ResidueProjector(1280, tcr_dim)
proj_struct = ResidueProjector(17, tcr_dim)
se3 = StackedEGNN(dim=17, layers=1)
gate = ResidueTripleFusion(tcr_dim)
encoder_layer = TransformerEncoderLayer(
d_model=tcr_dim, nhead=8, dim_feedforward=tcr_dim*4, dropout=self.dropout, batch_first=True
)
encoder = TransformerEncoder(encoder_layer, num_layers=2)
return nn.ModuleDict(dict(
proj_phys=proj_phys, proj_esm=proj_esm, proj_struct=proj_struct,
se3=se3, gate=gate, encoder=encoder
))
self.tcra_enc = make_tcr_encoder()
self.tcrb_enc = make_tcr_encoder()
# =======================
# Peptide encoder (phys + esm + structure)
# =======================
self.proj_pep_phys = ResidueProjector(20, pep_dim)
self.proj_pep_esm = ResidueProjector(1280, pep_dim)
self.proj_pep_struct = ResidueProjector(17, pep_dim)
self.pep_se3 = StackedEGNN(dim=17, layers=1)
self.pep_gate = ResidueTripleFusion(pep_dim)
pep_encoder_layer = TransformerEncoderLayer(
d_model=pep_dim, nhead=8, dim_feedforward=pep_dim*4, dropout=self.dropout, batch_first=True
)
self.pep_encoder = TransformerEncoder(pep_encoder_layer, num_layers=2)
# =======================
# HLA encoder
# =======================
self.proj_hla_phys = ResidueProjector(20, hla_dim)
self.proj_hla_esm = ResidueProjector(1280, hla_dim)
self.proj_hla_struct = ResidueProjector(17, hla_dim)
self.hla_se3 = StackedEGNN(dim=17, layers=1)
self.hla_gate = ResidueTripleFusion(hla_dim)
hla_encoder_layer = TransformerEncoderLayer(
d_model=hla_dim, nhead=8, dim_feedforward=hla_dim*4, dropout=self.dropout, batch_first=True
)
self.hla_encoder = TransformerEncoder(hla_encoder_layer, num_layers=1)
self.pep_gate_2 = ResidueDoubleFusion(pep_dim)
self.hla_gate_2 = ResidueDoubleFusion(hla_dim)
# =======================
# Bilinear interactions
# =======================
self.bi_tcra_pep = BANLayer(tcr_dim, pep_dim, bilinear_dim, h_out=4, k=3)
self.bi_tcrb_pep = BANLayer(tcr_dim, pep_dim, bilinear_dim, h_out=4, k=3)
self.bi_tcra_hla = BANLayer(tcr_dim, hla_dim, bilinear_dim, h_out=4, k=3)
self.bi_tcrb_hla = BANLayer(tcr_dim, hla_dim, bilinear_dim, h_out=4, k=3)
# =======================
# Head
# =======================
total_fused_dim = bilinear_dim * 4
self.head = nn.Sequential(
nn.Linear(total_fused_dim, bilinear_dim),
nn.ReLU(),
nn.Linear(bilinear_dim, 1)
)
def encode_tcr(self, x_phys, x_esm, x_struct, x_coord, x_mask, enc, pos_embed):
phys = enc['proj_phys'](x_phys)
esm = enc['proj_esm'](x_esm)
se3 = enc['se3'](x_struct, x_coord, None)[0]
se3 = enc['proj_struct'](se3)
feat = enc['gate'](phys, esm, se3)
feat = self.add_positional_encoding(feat, pos_embed)
feat = enc['encoder'](feat, src_key_padding_mask=~x_mask)
return feat
def add_positional_encoding(self, x, pos_embed):
"""
x: [B, L, D]
pos_embed: [L_max, D]
"""
B, L, D = x.shape
pe = pos_embed[:L, :].unsqueeze(0).expand(B, -1, -1)
return x + pe
# def _extract_cdr3_segment(self, tcr_feat, cdr3_start, cdr3_end):
# B, L, D = tcr_feat.shape
# device = tcr_feat.device
# max_len = (cdr3_end - cdr3_start + 1).max().item()
# # [max_len], 0..max_len-1
# rel_idx = torch.arange(max_len, device=device).unsqueeze(0).expand(B, -1) # [B, max_len]
# # absolute index = start + rel_idx
# abs_idx = cdr3_start.unsqueeze(1) + rel_idx
# # clamp end
# abs_idx = abs_idx.clamp(0, L-1)
# # mask positions beyond end
# mask = rel_idx <= (cdr3_end - cdr3_start).unsqueeze(1)
# # gather
# # expand abs_idx to [B, max_len, D] for gather
# gather_idx = abs_idx.unsqueeze(-1).expand(-1, -1, D)
# out = torch.gather(tcr_feat, 1, gather_idx) # [B, max_len, D]
# return out, mask
def _extract_cdr3_segment(self, tcr_feat, cdr3_start, cdr3_end):
"""
Extracts CDR3 embeddings and corresponding mask.
tcr_feat: [B, L, D]
cdr3_start, cdr3_end: [B]
Returns:
out: [B, max_len, D]
mask: [B, max_len] (True = valid)
"""
B, L, D = tcr_feat.shape
device = tcr_feat.device
# CDR3 length of each sample (end index is exclusive)
lens = (cdr3_end - cdr3_start).clamp(min=0)
max_len = lens.max().item()
rel_idx = torch.arange(max_len, device=device).unsqueeze(0).expand(B, -1) # [B, max_len]
abs_idx = cdr3_start.unsqueeze(1) + rel_idx # [B, max_len]
# mask: True marks valid positions
mask = rel_idx < lens.unsqueeze(1) # "<" is sufficient because the end index is exclusive
# set out-of-range indices to 0 (any valid index works; these positions are masked out below)
abs_idx = torch.where(mask, abs_idx, torch.zeros_like(abs_idx))
# gather
gather_idx = abs_idx.unsqueeze(-1).expand(-1, -1, D)
out = torch.gather(tcr_feat, 1, gather_idx)
# zero out positions where mask is False so invalid tokens do not contribute
out = out * mask.unsqueeze(-1)
return out, mask
def forward(self, batch):
# TCRα / TCRβ
tcra_feat = self.encode_tcr(
batch['tcra_phys'].to(self.device, non_blocking=True),
batch['tcra_esm'].to(self.device, non_blocking=True),
batch['tcra_struct'].to(self.device, non_blocking=True),
batch['tcra_coord'].to(self.device, non_blocking=True),
batch['tcra_mask'].to(self.device, non_blocking=True),
self.tcra_enc,
self.tcra_pos_embed
)
tcrb_feat = self.encode_tcr(
batch['tcrb_phys'].to(self.device, non_blocking=True),
batch['tcrb_esm'].to(self.device, non_blocking=True),
batch['tcrb_struct'].to(self.device, non_blocking=True),
batch['tcrb_coord'].to(self.device, non_blocking=True),
batch['tcrb_mask'].to(self.device, non_blocking=True),
self.tcrb_enc,
self.tcrb_pos_embed
)
# peptide
pep_phys = self.proj_pep_phys(batch['pep_phys'].to(self.device, non_blocking=True))
pep_esm = self.proj_pep_esm(batch['pep_esm'].to(self.device, non_blocking=True))
pep_se3 = self.pep_se3(batch['pep_struct'].to(self.device, non_blocking=True), batch['pep_coord'].to(self.device, non_blocking=True), None)[0]
pep_se3 = self.proj_pep_struct(pep_se3)
pep_feat = self.pep_gate(pep_phys, pep_esm, pep_se3)
pep_feat = self.add_positional_encoding(pep_feat, self.pep_pos_embed)
pep_feat = self.pep_encoder(
pep_feat,
src_key_padding_mask=~batch['pep_mask'].to(self.device)
)
# hla
hla_phys = self.proj_hla_phys(batch['hla_phys'].to(self.device, non_blocking=True))
hla_esm = self.proj_hla_esm(batch['hla_esm'].to(self.device, non_blocking=True))
hla_se3 = self.hla_se3(batch['hla_struct'].to(self.device, non_blocking=True), batch['hla_coord'].to(self.device, non_blocking=True), None)[0]
hla_se3 = self.proj_hla_struct(hla_se3)
hla_feat = self.hla_gate(hla_phys, hla_esm, hla_se3)
hla_feat = self.add_positional_encoding(hla_feat, self.hla_pos_embed)
hla_feat = self.hla_encoder(hla_feat)
if ('pep_feat_pretrain' in batch) and ('hla_feat_pretrain' in batch):
pep_pretrain = batch['pep_feat_pretrain'].to(self.device, non_blocking=True)
hla_pretrain = batch['hla_feat_pretrain'].to(self.device, non_blocking=True)
# ---- Robust length alignment (crop both to the shorter length) ----
Lp = pep_feat.shape[1]
Lp_pretrain = pep_pretrain.shape[1]
if Lp != Lp_pretrain:
Lp_min = min(Lp, Lp_pretrain)
pep_feat = pep_feat[:, :Lp_min, :]
pep_pretrain = pep_pretrain[:, :Lp_min, :]
Lh = hla_feat.shape[1]
Lh_pretrain = hla_pretrain.shape[1]
if Lh != Lh_pretrain:
Lh_min = min(Lh, Lh_pretrain)
hla_feat = hla_feat[:, :Lh_min, :]
hla_pretrain = hla_pretrain[:, :Lh_min, :]
# ---- Peptide gating ----
pep_feat = self.pep_gate_2(pep_feat, pep_pretrain)
# ---- HLA gating ----
hla_feat = self.hla_gate_2(hla_feat, hla_pretrain)
# TCRα CDR3 segment
tcra_cdr3, cdr3a_mask = self._extract_cdr3_segment(
tcra_feat,
batch['cdr3a_start'].to(self.device, non_blocking=True),
batch['cdr3a_end'].to(self.device, non_blocking=True)
)
# TCRβ CDR3 segment
tcrb_cdr3, cdr3b_mask = self._extract_cdr3_segment(
tcrb_feat,
batch['cdr3b_start'].to(self.device, non_blocking=True),
batch['cdr3b_end'].to(self.device, non_blocking=True)
)
# TCRα CDR3 ← Peptide
tcra_cdr3_cross, _ = self.cross_attn_tcra_pep(
query=tcra_cdr3, # [B, La_cdr3, D]
key=pep_feat, value=pep_feat, # [B, Lp, D]
key_padding_mask=~batch['pep_mask'].to(self.device)
)
tcra_cdr3 = self.norm_tcra_pep(tcra_cdr3 + tcra_cdr3_cross)
# re-mask padded CDR3 positions so invalid tokens cannot leak through
tcra_cdr3 = tcra_cdr3 * cdr3a_mask.unsqueeze(-1)
# TCRβ CDR3 ← Peptide
tcrb_cdr3_cross, _ = self.cross_attn_tcrb_pep(
query=tcrb_cdr3,
key=pep_feat, value=pep_feat,
key_padding_mask=~batch['pep_mask'].to(self.device)
)
tcrb_cdr3 = self.norm_tcrb_pep(tcrb_cdr3 + tcrb_cdr3_cross)
tcrb_cdr3 = tcrb_cdr3 * cdr3b_mask.unsqueeze(-1)
# ------------------ Cross-attention: full TCR sequence <-> HLA ------------------
# TCRα full ← HLA
tcra_hla_cross, _ = self.cross_attn_tcra_hla(
query=tcra_feat, # [B, La, D]
key=hla_feat, value=hla_feat, # [B, Lh, D]
key_padding_mask=None
)
tcra_feat = self.norm_tcra_hla(tcra_feat + tcra_hla_cross)
tcra_feat = tcra_feat * batch['tcra_mask'].to(self.device).unsqueeze(-1)
# TCRβ full ← HLA
tcrb_hla_cross, _ = self.cross_attn_tcrb_hla(
query=tcrb_feat,
key=hla_feat, value=hla_feat,
key_padding_mask=None
)
tcrb_feat = self.norm_tcrb_hla(tcrb_feat + tcrb_hla_cross)
tcrb_feat = tcrb_feat * batch['tcrb_mask'].to(self.device).unsqueeze(-1)
# bilinear fusion
vec_tcra_pep, attn_tcra_pep = self.bi_tcra_pep(tcra_cdr3, pep_feat, v_mask=cdr3a_mask, q_mask=batch['pep_mask'].to(self.device))
vec_tcrb_pep, attn_tcrb_pep = self.bi_tcrb_pep(tcrb_cdr3, pep_feat, v_mask=cdr3b_mask, q_mask=batch['pep_mask'].to(self.device))
vec_tcra_hla, attn_tcra_hla = self.bi_tcra_hla(tcra_feat, hla_feat, v_mask=batch['tcra_mask'].to(self.device), q_mask=None)
vec_tcrb_hla, attn_tcrb_hla = self.bi_tcrb_hla(tcrb_feat, hla_feat, v_mask=batch['tcrb_mask'].to(self.device), q_mask=None)
attn_tcra_pep_small = attn_tcra_pep.sum(dim=1).float()
attn_tcrb_pep_small = attn_tcrb_pep.sum(dim=1).float()
attn_tcra_hla_small = attn_tcra_hla.sum(dim=1).float()
attn_tcrb_hla_small = attn_tcrb_hla.sum(dim=1).float()
attn_dict = {
'attn_tcra_pep': attn_tcra_pep_small.detach().cpu().numpy(),
'attn_tcrb_pep': attn_tcrb_pep_small.detach().cpu().numpy(),
'attn_tcra_hla': attn_tcra_hla_small.detach().cpu().numpy(),
'attn_tcrb_hla': attn_tcrb_hla_small.detach().cpu().numpy()
}
fused = torch.cat([vec_tcra_pep, vec_tcrb_pep, vec_tcra_hla, vec_tcrb_hla], dim=-1)
logits = self.head(fused).squeeze(-1)
labels = batch['label'].to(self.device)
loss_binding = self.loss_fn(logits, labels.float())
probs = torch.sigmoid(logits)
return probs, loss_binding, pep_feat.detach().cpu().numpy(), attn_dict
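# Training-step sketch (batches come from TCRPepHLA_Dataset + tcr_pep_hla_collate_fn; attn_dict holds the
# summed bilinear attention maps 'attn_tcra_pep', 'attn_tcrb_pep', 'attn_tcra_hla', 'attn_tcrb_hla'):
#   model = TCRPeptideHLABindingPredictor(device='cuda:0').to('cuda:0')
#   optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
#   probs, loss, pep_feat_np, attn_dict = model(batch)
#   optimizer.zero_grad(); loss.backward(); optimizer.step()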