caokai1073 commited on
Commit
8721078
·
verified ·
1 Parent(s): aa7a278

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +39 -8
src/streamlit_app.py CHANGED
@@ -4,7 +4,12 @@ from io import StringIO
4
  from predictor import load_model, predict_from_df
5
  from Bio import SeqIO
6
  import torch
 
 
7
 
 
 
 
8
  st.set_page_config(page_title="🧬 Peptide–HLA Binding Predictor", layout="wide")
9
 
10
  st.title("🧠 Peptide–HLA Binding Predictor")
@@ -13,26 +18,46 @@ Upload a **CSV** file with columns `Peptide` and `HLA`,
13
  or a **FASTA** file containing peptide sequences (headers optionally include HLA type).
14
  """)
15
 
16
- uploaded_file = st.file_uploader("Upload CSV or FASTA", type=["csv", "fasta"])
17
-
18
- # 加载模型
19
  @st.cache_resource
20
  def get_model():
21
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
22
- from huggingface_hub import hf_hub_download
23
 
24
- model, device = load_model("/app/src/model.pt", device=device)
25
-
 
 
 
 
 
 
 
 
26
  return model, device
27
 
28
  model, device = get_model()
29
 
 
 
 
 
 
30
  if uploaded_file:
 
 
 
 
 
 
 
 
31
  if uploaded_file.name.endswith(".csv"):
32
- df = pd.read_csv(uploaded_file)
33
  else:
34
  seqs = []
35
- for rec in SeqIO.parse(uploaded_file, "fasta"):
36
  header = rec.id
37
  seq = str(rec.seq)
38
  # 尝试从header提取HLA,比如 ">HLA-A*02:01|SLLMWITQC"
@@ -46,6 +71,9 @@ if uploaded_file:
46
  st.write("✅ Uploaded data preview:")
47
  st.dataframe(df.head())
48
 
 
 
 
49
  if st.button("🚀 Run Prediction"):
50
  with st.spinner("Running model inference..."):
51
  result_df = predict_from_df(df, model)
@@ -53,6 +81,9 @@ if uploaded_file:
53
  st.success("✅ Prediction complete!")
54
  st.dataframe(result_df.head(10))
55
 
 
 
 
56
  csv = result_df.to_csv(index=False).encode("utf-8")
57
  st.download_button(
58
  "⬇️ Download results as CSV",
 
4
  from predictor import load_model, predict_from_df
5
  from Bio import SeqIO
6
  import torch
7
+ import os
8
+ from huggingface_hub import hf_hub_download
9
 
10
+ # ==============================
11
+ # 页面配置
12
+ # ==============================
13
  st.set_page_config(page_title="🧬 Peptide–HLA Binding Predictor", layout="wide")
14
 
15
  st.title("🧠 Peptide–HLA Binding Predictor")
 
18
  or a **FASTA** file containing peptide sequences (headers optionally include HLA type).
19
  """)
20
 
21
+ # ==============================
22
+ # 模型加载函数(缓存)
23
+ # ==============================
24
  @st.cache_resource
25
  def get_model():
26
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
27
 
28
+ # 尝试从本地加载,如果失败则从 HF Hub 下载
29
+ local_path = "/app/src/model.pt"
30
+ if not os.path.exists(local_path):
31
+ st.warning("Model not found locally. Downloading from Hugging Face Hub...")
32
+ local_path = hf_hub_download(
33
+ repo_id="caokai1073/StriMap", # 替换为你的 Space repo
34
+ filename="src/model.pt"
35
+ )
36
+
37
+ model, device = load_model(local_path, device=device)
38
  return model, device
39
 
40
  model, device = get_model()
41
 
42
+ # ==============================
43
+ # 文件上传(使用 /tmp 临时目录)
44
+ # ==============================
45
+ uploaded_file = st.file_uploader("Upload CSV or FASTA", type=["csv", "fasta"])
46
+
47
  if uploaded_file:
48
+ # 将上传文件保存到可写的 /tmp 路径
49
+ temp_path = os.path.join("/tmp", uploaded_file.name)
50
+ with open(temp_path, "wb") as f:
51
+ f.write(uploaded_file.getbuffer())
52
+
53
+ # ==============================
54
+ # 文件解析
55
+ # ==============================
56
  if uploaded_file.name.endswith(".csv"):
57
+ df = pd.read_csv(temp_path)
58
  else:
59
  seqs = []
60
+ for rec in SeqIO.parse(temp_path, "fasta"):
61
  header = rec.id
62
  seq = str(rec.seq)
63
  # 尝试从header提取HLA,比如 ">HLA-A*02:01|SLLMWITQC"
 
71
  st.write("✅ Uploaded data preview:")
72
  st.dataframe(df.head())
73
 
74
+ # ==============================
75
+ # 模型预测
76
+ # ==============================
77
  if st.button("🚀 Run Prediction"):
78
  with st.spinner("Running model inference..."):
79
  result_df = predict_from_df(df, model)
 
81
  st.success("✅ Prediction complete!")
82
  st.dataframe(result_df.head(10))
83
 
84
+ # ==============================
85
+ # 下载结果
86
+ # ==============================
87
  csv = result_df.to_csv(index=False).encode("utf-8")
88
  st.download_button(
89
  "⬇️ Download results as CSV",