# ── monkey-patch gradio_client so bool schemas don't crash json_schema_to_python_type ──
import gradio_client.utils as _gc_utils

# Back up the originals so the patched wrappers can delegate to them.
_orig_get_type = _gc_utils.get_type
_orig_json2py = _gc_utils._json_schema_to_python_type


def _patched_get_type(schema):
    """Delegate to gradio_client's get_type, mapping boolean schemas to {}.

    JSON Schema allows `true`/`false` as a schema; gradio_client's helpers
    crash on them, so we normalise to an empty (anything-goes) schema first.
    """
    if isinstance(schema, bool):
        schema = {}
    return _orig_get_type(schema)


def _patched_json_schema_to_python_type(schema, defs=None):
    """Delegate to _json_schema_to_python_type, mapping boolean schemas to {}."""
    if isinstance(schema, bool):
        schema = {}
    return _orig_json2py(schema, defs)


_gc_utils.get_type = _patched_get_type
_gc_utils._json_schema_to_python_type = _patched_json_schema_to_python_type

# ── now it's safe to import Gradio and build the interface ──────────────────
import gradio as gr
from gradio.themes import Soft

import os
import sys
import argparse
import tempfile
import shutil
import base64
import io

import torch
import selfies
import matplotlib

matplotlib.use("Agg")  # headless backend: figures are only rendered to in-memory buffers
import matplotlib.pyplot as plt
from matplotlib import cm
from typing import Optional

from transformers import EsmForMaskedLM, EsmTokenizer, AutoModel
from torch.utils.data import DataLoader
from Bio.PDB import PDBParser, MMCIFParser
from Bio.Data import IUPACData

from utils.drug_tokenizer import DrugTokenizer
from utils.metric_learning_models_att_maps import Pre_encoded, FusionDTI
from utils.foldseek_util import get_struc_seq

# ───── Helpers ─────────────────────────────────────────────────
# 3-letter → 1-letter residue map, extended with common non-standard residues
# (selenomethionine, selenocysteine, pyrrolysine).
three2one = {k.upper(): v for k, v in IUPACData.protein_letters_3to1.items()}
three2one.update({"MSE": "M", "SEC": "C", "PYL": "K"})


def simple_seq_from_structure(path: str) -> str:
    """Extract the one-letter sequence of the longest chain in a structure file.

    Args:
        path: Path to a ``.pdb`` or ``.cif`` file (parser chosen by extension).

    Returns:
        The chain's one-letter sequence, with unknown residue names mapped to
        ``"X"``; an empty string when the structure contains no chains.
    """
    parser = MMCIFParser(QUIET=True) if path.endswith(".cif") else PDBParser(QUIET=True)
    structure = parser.get_structure("P", path)
    chains = list(structure.get_chains())
    if not chains:
        return ""
    # Pick the chain with the most residues — heuristic for "the" protein chain.
    chain = max(chains, key=lambda c: len(list(c.get_residues())))
    return "".join(three2one.get(res.get_resname().upper(), "X") for res in chain)


def smiles_to_selfies(smiles_text: str) -> Optional[str]:
    """Convert a SMILES string to SELFIES.

    Returns:
        The SELFIES string, or ``None`` when encoding fails or the result
        cannot be decoded back (round-trip sanity check).
    """
    try:
        sf = selfies.encoder(smiles_text)
        if not selfies.decoder(sf):
            return None
        return sf
    except Exception:
        # selfies raises on malformed input; treat any failure as "not convertible".
        return None


def parse_config():
    """Parse CLI options: encoder checkpoints, fusion module, device, paths."""
    p = argparse.ArgumentParser()
    p.add_argument("--prot_encoder_path", default="westlake-repl/SaProt_650M_AF2")
    p.add_argument("--drug_encoder_path", default="HUBioDataLab/SELFormer")
    p.add_argument("--agg_mode", type=str, default="mean_all_tok")
    p.add_argument("--group_size", type=int, default=1)
    p.add_argument("--fusion", default="CAN")
    p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--save_path_prefix", default="save_model_ckp/")
    p.add_argument("--dataset", default="Human")
    return p.parse_args()


args = parse_config()
DEVICE = args.device

# ───── Load models & tokenizers ─────────────────────────────────
prot_tokenizer = EsmTokenizer.from_pretrained(args.prot_encoder_path)
prot_model = EsmForMaskedLM.from_pretrained(args.prot_encoder_path)
drug_tokenizer = DrugTokenizer()
drug_model = AutoModel.from_pretrained(args.drug_encoder_path)
encoding = Pre_encoded(prot_model, drug_model, args).to(DEVICE)


def collate_fn(batch):
    """Tokenise a batch of ``(protein_seq, drug_seq, score)`` triples.

    Both sequences are padded/truncated to 512 tokens. Returns
    ``(prot_ids, prot_mask, drug_ids, drug_mask, scores)`` with boolean
    attention masks.
    """
    query1, query2, scores = zip(*batch)
    query_encodings1 = prot_tokenizer.batch_encode_plus(
        list(query1),
        max_length=512,
        padding="max_length",
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    query_encodings2 = drug_tokenizer.batch_encode_plus(
        list(query2),
        max_length=512,
        padding="max_length",
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    scores = torch.tensor(list(scores))
    attention_mask1 = query_encodings1["attention_mask"].bool()
    attention_mask2 = query_encodings2["attention_mask"].bool()
    return (
        query_encodings1["input_ids"],
        attention_mask1,
        query_encodings2["input_ids"],
        attention_mask2,
        scores,
    )


def get_case_feature(model, loader):
    """Encode the first batch of *loader* and return its tensors on CPU.

    Returns a single-element list
    ``[(p_emb, d_emb, p_ids, d_ids, p_mask, d_mask, None)]`` — the shape
    expected by :func:`visualize_attention`.
    """
    model.eval()
    with torch.no_grad():
        for p_ids, p_mask, d_ids, d_mask, _ in loader:
            p_ids, p_mask = p_ids.to(DEVICE), p_mask.to(DEVICE)
            d_ids, d_mask = d_ids.to(DEVICE), d_mask.to(DEVICE)
            p_emb, d_emb = model.encoding(p_ids, p_mask, d_ids, d_mask)
            return [
                (
                    p_emb.cpu(),
                    d_emb.cpu(),
                    p_ids.cpu(),
                    d_ids.cpu(),
                    p_mask.cpu(),
                    d_mask.cpu(),
                    None,
                )
            ]


# ─────────────── visualisation ───────────────────────────────────────────
def _safe_is_special(tokenizer, tok: str) -> bool:
    """Return True when *tok* is a special token of *tokenizer*.

    Some tokenisers expose different special-token attributes; probe the
    common ones and fall back conservatively (unknown → not special).
    """
    special_sets = []
    if hasattr(tokenizer, "all_special_tokens"):
        special_sets.append(set(tokenizer.all_special_tokens))
    if hasattr(tokenizer, "special_tokens_map"):
        special_sets.extend(
            set(v) if isinstance(v, list) else {v}
            for v in tokenizer.special_tokens_map.values()
        )
    return any(tok in s for s in special_sets)


def visualize_attention(model, feats, drug_idx: Optional[int] = None) -> str:
    """Render a Protein → Drug cross-attention heat-map as an HTML fragment.

    Args:
        model: Fusion model whose forward returns ``(logits, attention)``
            with attention shaped ``(B, n_protein, n_drug)``.
        feats: Single-case feature list produced by :func:`get_case_feature`.
        drug_idx: Optional 0-based drug-token index; when given, a Top-30
            table of protein residues attending to that token is appended.

    Returns:
        HTML containing a base64-embedded PNG heat-map, a PDF download link,
        and (optionally) the Top-30 residue table.
    """
    model.eval()
    with torch.no_grad():
        # ── unpack single-case tensors ────────────────────────────────────
        p_emb, d_emb, p_ids, d_ids, p_mask, d_mask, _ = feats[0]
        p_emb, d_emb = p_emb.to(DEVICE), d_emb.to(DEVICE)
        p_mask, d_mask = p_mask.to(DEVICE), d_mask.to(DEVICE)

        # ── forward pass: Protein → Drug attention (B, n_p, n_d) ─────────
        _, att_pd = model(p_emb, d_emb, p_mask, d_mask)
        attn = att_pd.squeeze(0).cpu()  # (n_p, n_d)

        # ── decode tokens (skip special symbols) ─────────────────────────
        def clean_ids(ids, tokenizer):
            toks = tokenizer.convert_ids_to_tokens(ids.tolist())
            return [t for t in toks if not _safe_is_special(tokenizer, t)]

        p_tokens_full = clean_ids(p_ids[0], prot_tokenizer)
        p_indices_full = list(range(1, len(p_tokens_full) + 1))
        d_tokens_full = clean_ids(d_ids[0], drug_tokenizer)
        d_indices_full = list(range(1, len(d_tokens_full) + 1))

        # ── safety cut-off so labels match the attention matrix size ─────
        p_tokens = p_tokens_full[: attn.size(0)]
        p_indices = p_indices_full[: attn.size(0)]
        d_tokens = d_tokens_full[: attn.size(1)]
        d_indices = d_indices_full[: attn.size(1)]
        attn = attn[: len(p_tokens), : len(d_tokens)]
        orig_attn = attn.clone()  # un-pruned copy, used for the Top-30 table

        # ── adaptive sparsity pruning: drop rows/cols whose peak < 5% of max ──
        thr = attn.max().item() * 0.05 if attn.numel() > 0 else 0.0
        row_keep = (attn.max(dim=1).values > thr) if attn.size(0) else torch.tensor([], dtype=torch.bool)
        col_keep = (attn.max(dim=0).values > thr) if attn.size(1) else torch.tensor([], dtype=torch.bool)
        # Keep at least 3 rows/cols so the plot never degenerates.
        if row_keep.sum().item() < 3 and attn.size(0) > 0:
            row_keep = torch.ones(attn.size(0), dtype=torch.bool)
        if col_keep.sum().item() < 3 and attn.size(1) > 0:
            col_keep = torch.ones(attn.size(1), dtype=torch.bool)
        attn = attn[row_keep][:, col_keep]
        p_tokens = [tok for keep, tok in zip(row_keep.tolist(), p_tokens) if keep]
        p_indices = [idx for keep, idx in zip(row_keep.tolist(), p_indices) if keep]
        d_tokens = [tok for keep, tok in zip(col_keep.tolist(), d_tokens) if keep]
        d_indices = [idx for keep, idx in zip(col_keep.tolist(), d_indices) if keep]

        # ── cap column count at 150 for readability ──────────────────────
        if attn.size(1) > 150:
            topc = torch.topk(attn.sum(0), k=150).indices
            attn = attn[:, topc]
            d_tokens = [d_tokens[i] for i in topc]
            d_indices = [d_indices[i] for i in topc]

        # ── draw heat-map ────────────────────────────────────────────────
        x_labels = [f"{idx}:{tok}" for idx, tok in zip(d_indices, d_tokens)]
        y_labels = [f"{idx}:{tok}" for idx, tok in zip(p_indices, p_tokens)]
        fig_w = min(22, max(8, len(x_labels) * 0.6))
        fig_h = min(24, max(6, len(y_labels) * 0.8))
        fig, ax = plt.subplots(figsize=(fig_w, fig_h))
        im = ax.imshow(attn.numpy(), aspect="auto", cmap=cm.viridis, interpolation="nearest")
        ax.set_title("Protein → Drug Attention", pad=8, fontsize=11)
        ax.set_xticks(range(len(x_labels)))
        ax.set_xticklabels(x_labels, rotation=90, fontsize=8, ha="center", va="center")
        ax.tick_params(axis="x", top=True, bottom=False, labeltop=True, labelbottom=False, pad=27)
        ax.set_yticks(range(len(y_labels)))
        ax.set_yticklabels(y_labels, fontsize=7)
        ax.tick_params(axis="y", top=True, bottom=False, labeltop=True, labelbottom=False, pad=10)
        fig.colorbar(im, fraction=0.026, pad=0.01)
        fig.tight_layout()

        # ── serialise figure to PNG (inline) and PDF (download) ──────────
        buf_png = io.BytesIO()
        fig.savefig(buf_png, format="png", dpi=140)
        buf_png.seek(0)
        buf_pdf = io.BytesIO()
        fig.savefig(buf_pdf, format="pdf")
        buf_pdf.seek(0)
        plt.close(fig)
        png_b64 = base64.b64encode(buf_png.getvalue()).decode()
        pdf_b64 = base64.b64encode(buf_pdf.getvalue()).decode()

        # NOTE(review): the original HTML f-string was corrupted in the source
        # (markup stripped); reconstructed as an embedded PNG plus a PDF
        # download link — confirm against the intended markup.
        html_heat = (
            f'<div class="figure-wrap">'
            f'<img src="data:image/png;base64,{png_b64}" alt="Protein → Drug attention"/>'
            f'</div>'
            f'<div style="text-align:center;margin:8px 0">'
            f'<a download="attention_map.pdf" href="data:application/pdf;base64,{pdf_b64}">'
            f'Download PDF</a></div>'
        )

        # ── optional Top-30 residue table for a single drug token ────────
        # NOTE(review): table markup reconstructed from the surviving heading
        # text "{drug_tok_text} → Top-30 Protein residues"; scores come from
        # the un-pruned attention column.
        html_table = ""
        if drug_idx is not None and orig_attn.numel() > 0 and 0 <= drug_idx < orig_attn.size(1):
            drug_tok_text = (
                d_tokens_full[drug_idx] if drug_idx < len(d_tokens_full) else str(drug_idx + 1)
            )
            col = orig_attn[:, drug_idx]
            top = torch.topk(col, k=min(30, col.numel()))
            rows = "".join(
                f"<tr><td>{p_indices_full[i] if i < len(p_indices_full) else i + 1}</td>"
                f"<td>{p_tokens_full[i] if i < len(p_tokens_full) else '?'}</td>"
                f"<td>{col[i].item():.4f}</td></tr>"
                for i in top.indices.tolist()
            )
            html_table = (
                f'<h4>{drug_tok_text} → Top-30 Protein residues</h4>'
                f'<table id="result-table">'
                f"<tr><th>Index</th><th>Residue</th><th>Attention</th></tr>{rows}</table>"
            )
        return html_heat + html_table
# ───── Inference callback ──────────────────────────────────────
# NOTE(review): the original `def` line of this callback was lost when the
# file was corrupted; the signature below is reconstructed from the visible
# body (it reads prot_seq, drug_seq and atom_idx) — confirm parameter names
# against the UI wiring in version control.
def run_inference(prot_seq: str, drug_seq: str, atom_idx: str) -> str:
    """Run one drug–target attention inference and return HTML to display.

    Args:
        prot_seq: Structure-aware protein sequence (from Foldseek or manual input).
        drug_seq: Drug as SELFIES; plain SMILES is converted automatically.
        atom_idx: Optional 1-based drug-token index for the Top-30 residue table.

    Returns:
        An HTML fragment: either an error message or the attention heat-map
        produced by :func:`visualize_attention`.
    """
    # NOTE(review): the error strings below lost their surrounding HTML tags
    # in the corrupted source; plain <p> wrappers restored — confirm styling.
    if not prot_seq.strip():
        return "<p>Please extract or enter a protein sequence first.</p>"
    if not drug_seq.strip():
        return "<p>Please enter a drug sequence.</p>"
    # SELFIES tokens always start with "["; anything else is treated as SMILES.
    if not drug_seq.strip().startswith("["):
        conv = smiles_to_selfies(drug_seq.strip())
        if conv is None:
            return "<p>SMILES→SELFIES conversion failed.</p>"
        drug_seq = conv

    loader = DataLoader([(prot_seq, drug_seq, 1)], batch_size=1, collate_fn=collate_fn)
    feats = get_case_feature(encoding, loader)

    # Fusion head: 446-dim protein embeddings × 768-dim drug embeddings.
    model = FusionDTI(446, 768, args).to(DEVICE)
    ckpt = os.path.join(
        f"{args.save_path_prefix}{args.dataset}_{args.fusion}", "best_model.ckpt"
    )
    if os.path.isfile(ckpt):
        model.load_state_dict(torch.load(ckpt, map_location=DEVICE))
    # UI uses 1-based token indices; visualize_attention expects 0-based.
    return visualize_attention(model, feats, int(atom_idx) - 1 if atom_idx else None)


def clear_cb():
    """Reset the five wired components: protein, drug, file, atom index, result."""
    return "", "", None, "", None


# ───── Theme & CSS ─────────────────────────────────────────────
css = """
:root{
  --bg:#f8fafc; --card:#f8fafc; --text:#0f172a; --muted:#6b7280;
  --border:#e5e7eb; --shadow:0 6px 24px rgba(2,6,23,.06);
  --radius:14px; --icon-size:20px;
}
*{box-sizing:border-box}
html,body{background:#fff!important;color:var(--text)!important}
.gradio-container{max-width:1120px;margin:0 auto}
/* Title and subtitle */
h1{
  font-family:Inter,ui-sans-serif;letter-spacing:.2px;font-weight:700;
  font-size:32px;margin:22px 0 12px;text-align:center
}
.subtle{color:var(--muted);font-size:14px;text-align:center;margin:-6px 0 18px}
/* Card style */
.card{
  background:var(--card);
  border:1px solid var(--border);
  border-radius:var(--radius);
  box-shadow:var(--shadow);
  padding:22px;
}
/* Top links */
.link-row{display:flex;justify-content:center;gap:14px;margin:0 auto 18px;flex-wrap:wrap}
/* Two-column grid: left=input, right=controls */
.grid-2{display:grid;grid-template-columns:1.4fr .9fr;gap:16px}
.grid-2 .col{display:flex;flex-direction:column;gap:12px}
/* Buttons */
.gr-button{border-radius:12px !important;font-weight:700 !important;letter-spacing:.2px}
#extract-btn{background:linear-gradient(90deg,#EFAFB2,#EFAFB2); color:#0f172a}
#inference-btn{background:linear-gradient(90deg,#B2CBDF,#B2CBDF); color:#0f172a}
#clear-btn{background:#FFE2B5; color:#0A0A0A; border:1px solid var(--border)}
/* Result spacing */
#result-table{margin-bottom:16px}
/* Figure container */
.figure-wrap{border:1px solid var(--border);border-radius:12px;overflow:hidden;box-shadow:var(--shadow)}
.figure-wrap img{display:block;width:100%;height:auto}
/* Right pane: vertical radio layout and full-width controls (kept for button styling) */
.right-pane .gr-button{
  width:100% !important;
  height:48px !important;
  border-radius:12px !important;
  font-weight:700 !important;
  letter-spacing:.2px;
}
/* ───────── Publication links (Bulma-like) ───────── */
.publication-links {
  display: flex;
  justify-content: center;
  gap: 14px;
  flex-wrap: wrap;
  margin: 6px 0 18px;
}
.link-block a {
  display: inline-flex;
  align-items: center;
  gap: 8px;
  padding: 10px 18px;
  font-size: 14px;
  font-weight: 600;
  border-radius: 9999px;
  text-decoration: none;
  transition: all 0.15s ease-in-out;
}
/* colour variants */
.btn-danger { background:#e2e8f0; color:#0f172a; }
.btn-dark { background:#e2e8f0; color:#0f172a; }
.btn-link { background:#e2e8f0; color:#0f172a; }
.btn-warning { background:#e2e8f0; color:#0f172a; }
.link-block a:hover { filter: brightness(0.95); transform: translateY(-1px); }
.loscalzo-block img { height: 100px; width: auto; object-fit: contain; }
.loscalzo-block { display: flex; align-items: center; gap: 10px; margin: 0 auto; justify-content: center; }
.link-btn{
  display:inline-flex !important;
  align-items:center !important;
  gap:8px !important;
  padding:10px 18px !important;
  font-size:14px !important;
  font-weight:600 !important;
  border-radius:9999px !important;
  background:#e2e8f0 !important;
  color:#0f172a !important;
  text-decoration:none !important;
  border:1px solid #e5e7eb !important;
  transition:all 0.15s ease-in-out !important;
}
.link-btn:hover{ filter:brightness(0.95); transform:translateY(-1px); }
.project-links{
  display:flex !important;
  justify-content:center !important;
  gap:28px !important;
  flex-wrap:wrap !important;
  margin-bottom:32px !important;
}
#example-btn { background: #979ea8 !important; color: #1e293b !important; }
"""

# ───── Gradio Interface Definition ───────────────────────────────
# NOTE(review): the original Blocks body was destroyed in the corrupted
# source; only fragments of the Markdown text (".pdb or .cif file",
# "Foldseek", "AlphaFold DB", "Protein Data Bank (PDB)") survive. The layout
# and wiring below are a faithful-intent reconstruction using the element IDs
# referenced by the CSS (#extract-btn, #inference-btn, #clear-btn) — confirm
# against the project repository.
with gr.Blocks(css=css, theme=Soft()) as demo:
    # ───────────── Title ─────────────
    gr.Markdown("# FusionDTI: Token-level Drug–Target Attention")
    gr.Markdown(
        "Upload a **.pdb** or **.cif** file. A structure-aware sequence will be "
        "generated using "
        "[Foldseek](https://github.com/steineggerlab/foldseek), based on 3D "
        "structures from [AlphaFold DB](https://alphafold.ebi.ac.uk/) or the "
        "[Protein Data Bank (PDB)](https://www.rcsb.org/)."
    )

    with gr.Row(elem_classes=["grid-2"]):
        with gr.Column(elem_classes=["col"]):
            structure_file = gr.File(label="Protein structure (.pdb / .cif)")
            prot_box = gr.Textbox(label="Protein sequence", lines=4)
            drug_box = gr.Textbox(label="Drug sequence (SELFIES or SMILES)", lines=3)
            atom_box = gr.Textbox(label="Drug token index (optional, 1-based)")
        with gr.Column(elem_classes=["col", "right-pane"]):
            extract_btn = gr.Button("Extract sequence", elem_id="extract-btn")
            infer_btn = gr.Button("Inference", elem_id="inference-btn")
            clear_btn = gr.Button("Clear", elem_id="clear-btn")

    result_html = gr.HTML()

    def _extract_cb(file_obj):
        """Derive a sequence from an uploaded structure file.

        NOTE(review): the original extraction callback was lost; it likely
        called ``get_struc_seq`` (Foldseek) — the plain-sequence fallback is
        used here. Confirm and restore the Foldseek path if intended.
        """
        if file_obj is None:
            return ""
        return simple_seq_from_structure(file_obj.name)

    extract_btn.click(_extract_cb, inputs=[structure_file], outputs=[prot_box])
    infer_btn.click(
        run_inference,
        inputs=[prot_box, drug_box, atom_box],
        outputs=[result_html],
    )
    # clear_cb returns five values, one per component below (in order).
    clear_btn.click(
        clear_cb,
        outputs=[prot_box, drug_box, structure_file, atom_box, result_html],
    )

if __name__ == "__main__":
    demo.launch()