Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| from io import StringIO | |
| from predictor import load_model, predict_from_df | |
| from Bio import SeqIO | |
| import torch | |
| import os | |
| from huggingface_hub import hf_hub_download | |
# ==============================
# Page setup and header
# ==============================
# Intro text shown under the title; kept as a constant so the UI copy is
# easy to find and edit.
_INTRO_MD = (
    "\n"
    "Upload a **CSV** file with columns `Peptide` and `HLA`,\n"
    "or a **FASTA** file containing peptide sequences "
    "(headers optionally include HLA type).\n"
)

# set_page_config is the first Streamlit command executed in this script.
st.set_page_config(page_title="🧬 Peptide–HLA Binding Predictor", layout="wide")
st.title("🧠 Peptide–HLA Binding Predictor")
st.markdown(_INTRO_MD)
| import os | |
# Point the Hugging Face, Transformers, PyTorch and ESM caches at /data and
# make sure the ESM cache directory exists.
# NOTE(review): huggingface_hub, torch and predictor are imported *before*
# these assignments. huggingface_hub is believed to read HF_HOME at import
# time, so HF_HOME set here probably does not move its download cache (same
# for TRANSFORMERS_CACHE if predictor imports transformers) — consider moving
# this block above the imports.
_HF_CACHE_ROOT = "/data/huggingface"
os.environ.update(
    {
        "HF_HOME": _HF_CACHE_ROOT,
        "TRANSFORMERS_CACHE": _HF_CACHE_ROOT,
        "TORCH_HOME": _HF_CACHE_ROOT,
        "ESM_CACHE_DIR": "/data/phla_cache",
    }
)
os.makedirs(os.environ["ESM_CACHE_DIR"], exist_ok=True)
# ==============================
# Model loading (cached across Streamlit reruns)
# ==============================
def _hub_cache_dir():
    """Return the Hugging Face Hub cache directory for the *current* environment.

    huggingface_hub computes its default cache path when it is imported, and
    this script assigns HF_HOME only after that import, so the intended
    location is resolved here at call time instead. Mirrors the library's own
    rule: HF_HUB_CACHE if set, else ``$HF_HOME/hub``; returns None (library
    default) when neither variable is set.
    """
    explicit = os.environ.get("HF_HUB_CACHE")
    if explicit:
        return explicit
    hf_home = os.environ.get("HF_HOME")
    return os.path.join(os.path.expanduser(hf_home), "hub") if hf_home else None


@st.cache_resource
def get_model():
    """Load the peptide–HLA model once per server process.

    Uses the checkpoint at ``/app/src/model.pt`` when present; otherwise
    downloads ``src/model.pt`` from the ``caokai1073/StriMap`` Space repo.

    Returns:
        ``(model, device)`` exactly as returned by ``predictor.load_model``.
    """
    # Streamlit re-executes the whole script on every widget interaction;
    # st.cache_resource keeps the loaded model so uploads and button clicks
    # do not reload (or re-download) it each time.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # Prefer the local checkpoint; fall back to downloading from the Hub.
    local_path = "/app/src/model.pt"
    if not os.path.exists(local_path):
        st.warning("Model not found locally. Downloading from Hugging Face Hub...")
        local_path = hf_hub_download(
            repo_id="caokai1073/StriMap",  # replace with your Space repo
            # The checkpoint lives in a Space repo; hf_hub_download defaults
            # to repo_type="model" and would not find it there.
            repo_type="space",
            filename="src/model.pt",
            cache_dir=_hub_cache_dir(),
        )
    model, device = load_model(local_path, device=device)
    return model, device


model, device = get_model()
# ==============================
# File upload (parsed in memory; nothing is written to disk)
# ==============================
# Columns the prediction step expects (as advertised in the page intro).
_REQUIRED_COLUMNS = ("Peptide", "HLA")


def _fasta_to_df(text):
    """Turn FASTA text into a DataFrame with ``Peptide`` and ``HLA`` columns.

    The HLA allele is the part of the record ID before the first ``|``, e.g.
    ``>HLA-A*02:01|SLLMWITQC``; records whose ID contains no ``|`` get
    ``"HLA-Unknown"``. Biopython's ``rec.id`` is only the first
    whitespace-separated token of the header line.
    """
    rows = []
    for rec in SeqIO.parse(StringIO(text), "fasta"):
        hla, sep, _ = rec.id.partition("|")
        rows.append([str(rec.seq), hla if sep else "HLA-Unknown"])
    return pd.DataFrame(rows, columns=["Peptide", "HLA"])


uploaded_file = st.file_uploader("Upload CSV or FASTA", type=["csv", "fasta"])
if uploaded_file:
    # ==============================
    # File parsing
    # ==============================
    # Parse straight from the in-memory upload rather than saving it under
    # /tmp: the client-supplied file name is never used as a filesystem path
    # (no path traversal), concurrent sessions cannot overwrite each other's
    # files, and no temp files are left behind. getvalue() returns the whole
    # buffer regardless of the current read position.
    try:
        # utf-8-sig also strips the BOM that spreadsheet tools often prepend.
        text = uploaded_file.getvalue().decode("utf-8-sig")
        # Case-insensitive so "DATA.CSV" is not routed to the FASTA parser.
        if uploaded_file.name.lower().endswith(".csv"):
            df = pd.read_csv(StringIO(text))
        else:
            df = _fasta_to_df(text)
    except ValueError as err:
        # UnicodeDecodeError and pandas' ParserError / EmptyDataError are all
        # ValueError subclasses.
        st.error(f"Could not parse '{uploaded_file.name}': {err}")
        st.stop()

    missing = [col for col in _REQUIRED_COLUMNS if col not in df.columns]
    if missing:
        st.error(
            "The uploaded file is missing required column(s): "
            + ", ".join(f"`{col}`" for col in missing)
        )
        st.stop()
    if df.empty:
        st.warning("No records found in the uploaded file.")
        st.stop()

    st.write("✅ Uploaded data preview:")
    st.dataframe(df.head())

    # ==============================
    # Model prediction
    # ==============================
    if st.button("🚀 Run Prediction"):
        with st.spinner("Running model inference..."):
            result_df = predict_from_df(df, model)
        st.success("✅ Prediction complete!")
        st.dataframe(result_df.head(10))

        # ==============================
        # Download results
        # ==============================
        csv_bytes = result_df.to_csv(index=False).encode("utf-8")
        st.download_button(
            "⬇️ Download results as CSV",
            data=csv_bytes,
            file_name="hla_binding_predictions.csv",
            mime="text/csv",
        )