import streamlit as st
import pandas as pd
from predictor import load_model, predict_from_df
from Bio import SeqIO
import torch
import os
from huggingface_hub import hf_hub_download
# ==============================
# Page configuration
# ==============================
st.set_page_config(page_title="🧬 Peptide–HLA Binding Predictor", layout="wide")
st.title("🧠 Peptide–HLA Binding Predictor")
st.markdown("""
Upload a **CSV** file with columns `Peptide` and `HLA`,
or a **FASTA** file containing peptide sequences (headers may optionally include the HLA type).
""")
# ==============================
# Model loading (cached)
# ==============================
@st.cache_resource
def get_model():
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # Try to load the checkpoint locally; fall back to downloading it from the HF Hub
    local_path = "/app/src/model.pt"
    if not os.path.exists(local_path):
        st.warning("Model not found locally. Downloading from Hugging Face Hub...")
        local_path = hf_hub_download(
            repo_id="caokai1073/StriMap",  # replace with your own Space repo
            filename="src/model.pt",
        )
    model, device = load_model(local_path, device=device)
    return model, device
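# Note: st.cache_resource keeps the loaded model (and device string) in memory,
# so the checkpoint is downloaded and loaded only once per running app, not on every rerun.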
model, device = get_model()
# ==============================
# File upload (saved to the writable /tmp directory)
# ==============================
uploaded_file = st.file_uploader("Upload CSV or FASTA", type=["csv", "fasta"])
if uploaded_file:
    # Save the uploaded file to a writable path under /tmp
    temp_path = os.path.join("/tmp", uploaded_file.name)
    with open(temp_path, "wb") as f:
        f.write(uploaded_file.getbuffer())
    # ==============================
    # File parsing
    # ==============================
    if uploaded_file.name.endswith(".csv"):
        df = pd.read_csv(temp_path)
    else:
        seqs = []
        for rec in SeqIO.parse(temp_path, "fasta"):
            header = rec.id
            seq = str(rec.seq)
            # Try to extract the HLA allele from the header, e.g. ">HLA-A*02:01|SLLMWITQC"
            if "|" in header:
                hla, _ = header.split("|", 1)
            else:
                hla = "HLA-Unknown"
            seqs.append([seq, hla])
        df = pd.DataFrame(seqs, columns=["Peptide", "HLA"])

    st.write("✅ Uploaded data preview:")
    st.dataframe(df.head())
    # ==============================
    # Model prediction
    # ==============================
    if st.button("🚀 Run Prediction"):
        with st.spinner("Running model inference..."):
            result_df = predict_from_df(df, model)
        st.success("✅ Prediction complete!")
        st.dataframe(result_df.head(10))
        # ==============================
        # Download results
        # ==============================
        csv = result_df.to_csv(index=False).encode("utf-8")
        st.download_button(
            "⬇️ Download results as CSV",
            data=csv,
            file_name="hla_binding_predictions.csv",
            mime="text/csv",
        )
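# To try the app locally (assuming Streamlit and the other dependencies are installed,
# and either a checkpoint at /app/src/model.pt or network access for the Hub fallback):
#   streamlit run src/streamlit_app.py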