caokai1073 committed on
Commit
374d0bb
·
verified ·
1 Parent(s): c7acc8d

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +41 -20
src/streamlit_app.py CHANGED
@@ -18,56 +18,64 @@ Upload a **CSV** file with columns `Peptide` and `HLA`,
18
  or a **FASTA** file containing peptide sequences (headers optionally include HLA type).
19
  """)
20
 
21
- import os
 
 
 
 
 
 
 
 
 
 
22
  os.environ["HF_HOME"] = "/data/huggingface"
23
  os.environ["TRANSFORMERS_CACHE"] = "/data/huggingface"
24
  os.environ["TORCH_HOME"] = "/data/huggingface"
25
- os.environ["ESM_CACHE_DIR"] = "/data/phla_cache"
26
- os.makedirs("/data/phla_cache", exist_ok=True)
27
 
28
  # ==============================
29
- # 模型加载函数(缓存)
30
  # ==============================
31
  @st.cache_resource
32
- def get_model():
33
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
34
 
35
- # 尝试从本地加载,如果失败则从 HF Hub 下载
36
- local_path = "/app/src/model.pt"
37
  if not os.path.exists(local_path):
38
- st.warning("Model not found locally. Downloading from Hugging Face Hub...")
 
39
  local_path = hf_hub_download(
40
- repo_id="caokai1073/StriMap", # 替换为你的 Space repo
41
- filename="src/model.pt"
 
42
  )
43
 
44
  model, device = load_model(local_path, device=device)
45
  return model, device
46
 
47
- model, device = get_model()
48
 
49
  # ==============================
50
- # 文件上传(使用 /tmp 临时目录)
51
  # ==============================
52
- uploaded_file = st.file_uploader("Upload CSV or FASTA", type=["csv", "fasta"])
53
 
54
  if uploaded_file:
55
- # 将上传文件保存到可写的 /tmp 路径
56
- temp_path = os.path.join("/tmp", uploaded_file.name)
57
  with open(temp_path, "wb") as f:
58
  f.write(uploaded_file.getbuffer())
59
 
60
  # ==============================
61
  # 文件解析
62
  # ==============================
63
- if uploaded_file.name.endswith(".csv"):
64
  df = pd.read_csv(temp_path)
65
  else:
66
  seqs = []
67
  for rec in SeqIO.parse(temp_path, "fasta"):
68
  header = rec.id
69
  seq = str(rec.seq)
70
- # 尝试从header提取HLA,比如 ">HLA-A*02:01|SLLMWITQC"
71
  if "|" in header:
72
  hla, _ = header.split("|", 1)
73
  else:
@@ -79,10 +87,13 @@ if uploaded_file:
79
  st.dataframe(df.head())
80
 
81
  # ==============================
82
- # 模型预测
83
  # ==============================
84
  if st.button("🚀 Run Prediction"):
85
- with st.spinner("Running model inference..."):
 
 
 
86
  result_df = predict_from_df(df, model)
87
 
88
  st.success("✅ Prediction complete!")
@@ -97,4 +108,14 @@ if uploaded_file:
97
  data=csv,
98
  file_name="hla_binding_predictions.csv",
99
  mime="text/csv",
100
- )
 
 
 
 
 
 
 
 
 
 
 
18
  or a **FASTA** file containing peptide sequences (headers optionally include HLA type).
19
  """)
20
 
21
# ==============================
# Global path configuration
# ==============================
CACHE_DIR = "/data/phla_cache"
MODEL_DIR = "/app/src"
UPLOAD_DIR = "/data/uploads"

# Make sure every writable location exists before anything touches it.
for _writable_dir in (CACHE_DIR, MODEL_DIR, UPLOAD_DIR):
    os.makedirs(_writable_dir, exist_ok=True)

# Environment variables (route all model and ESM caches into /data).
for _cache_var in ("HF_HOME", "TRANSFORMERS_CACHE", "TORCH_HOME"):
    os.environ[_cache_var] = "/data/huggingface"
os.environ["ESM_CACHE_DIR"] = CACHE_DIR
 
36
 
37
# ==============================
# Model loading (lazy load + cache)
# ==============================
@st.cache_resource
def load_model_cached():
    """Resolve the StriMap checkpoint and load it once per session.

    The checkpoint is looked up at ``MODEL_DIR/model.pt`` first; when it
    is missing, ``model.pt`` is fetched from the dedicated model repo on
    the Hugging Face Hub. ``st.cache_resource`` keeps the loaded model
    alive across Streamlit reruns.

    Returns:
        tuple: ``(model, device)`` as produced by ``load_model``.
    """
    local_path = os.path.join(MODEL_DIR, "model.pt")
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    if not os.path.exists(local_path):
        st.warning("🔄 Model not found locally. Downloading from Hugging Face model repo...")
        # ⚠️ Pull from the dedicated model repo, not the Space repo.
        # NOTE(review): with cache_dir set, hf_hub_download stores the file
        # in a hub-cache directory layout under MODEL_DIR, so the plain
        # MODEL_DIR/model.pt existence check above will not see it on a
        # later cold start — confirm this re-download is acceptable.
        local_path = hf_hub_download(
            repo_id="caokai1073/StriMap-model",
            filename="model.pt",
            cache_dir=MODEL_DIR,
        )

    model, device = load_model(local_path, device=device)
    return model, device
56
 
 
57
 
58
  # ==============================
59
+ # 上传文件(安全写入 /data/uploads)
60
  # ==============================
61
+ uploaded_file = st.file_uploader("📤 Upload CSV or FASTA", type=["csv", "fasta"])
62
 
63
  if uploaded_file:
64
+ safe_name = uploaded_file.name.replace(" ", "_")
65
+ temp_path = os.path.join(UPLOAD_DIR, safe_name)
66
  with open(temp_path, "wb") as f:
67
  f.write(uploaded_file.getbuffer())
68
 
69
  # ==============================
70
  # 文件解析
71
  # ==============================
72
+ if safe_name.endswith(".csv"):
73
  df = pd.read_csv(temp_path)
74
  else:
75
  seqs = []
76
  for rec in SeqIO.parse(temp_path, "fasta"):
77
  header = rec.id
78
  seq = str(rec.seq)
 
79
  if "|" in header:
80
  hla, _ = header.split("|", 1)
81
  else:
 
87
  st.dataframe(df.head())
88
 
89
  # ==============================
90
+ # 模型预测(延迟加载)
91
  # ==============================
92
  if st.button("🚀 Run Prediction"):
93
+ with st.spinner("🔄 Loading model (this may take ~1 min first time)..."):
94
+ model, device = load_model_cached()
95
+
96
+ with st.spinner("Running inference..."):
97
  result_df = predict_from_df(df, model)
98
 
99
  st.success("✅ Prediction complete!")
 
108
  data=csv,
109
  file_name="hla_binding_predictions.csv",
110
  mime="text/csv",
111
+ )
112
+
113
# ==============================
# Debug / data check (optional)
# ==============================
if st.sidebar.button("📁 List /data files"):
    # Walk the persistent volume and show every file path, so the
    # operator can verify that caches and uploads actually land in /data.
    # (Comprehension replaces the manual append loop — same traversal
    # order, same output.)
    files = [
        os.path.join(root, name)
        for root, _, names in os.walk("/data")
        for name in names
    ]
    st.sidebar.write(files)