Update app.py
app.py CHANGED
@@ -82,7 +82,7 @@ class WeightedTrainer(Trainer):
         return (loss, outputs) if return_outputs else loss
 
 # fine-tuning function
-def train_function_no_sweeps(base_model_path, train_dataset, test_dataset):
+def train_function_no_sweeps(base_model_path): #, train_dataset, test_dataset):
 
     # Set the LoRA config
     config = {
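
For context on the enclosing class: the `return (loss, outputs) if return_outputs else loss` line is the conventional ending of a `compute_loss` override. A minimal sketch of what such a `WeightedTrainer` typically looks like, assuming a class-weighted cross-entropy over token logits; the actual body is not shown in this diff:

import torch
from transformers import Trainer

class WeightedTrainer(Trainer):
    # Hypothetical sketch: a Trainer whose loss re-weights the two classes.
    def __init__(self, *args, class_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.class_weights = class_weights  # e.g. a 2-element float tensor

    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        # Weighted cross-entropy; -100 marks padded/truncated positions.
        loss_fct = torch.nn.CrossEntropyLoss(weight=self.class_weights, ignore_index=-100)
        loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        return (loss, outputs) if return_outputs else loss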
@@ -170,7 +170,14 @@ def train_function_no_sweeps(base_model_path, train_dataset, test_dataset):
     tokenizer.save_pretrained(save_path)
 
     return save_path
-
+
+# Constants & Globals
+MODEL_OPTIONS = [
+    "facebook/esm2_t6_8M_UR50D",
+    "facebook/esm2_t12_35M_UR50D",
+    "facebook/esm2_t33_650M_UR50D",
+]  # models users can choose from
+
 # Load the data from pickle files (replace with your local paths)
 with open("./datasets/train_sequences_chunked_by_family.pkl", "rb") as f:
     train_sequences = pickle.load(f)
@@ -198,6 +205,7 @@ test_labels = truncate_labels(test_labels, max_sequence_length)
 train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
 test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)
 
+'''
 # Compute Class Weights
 classes = [0, 1]
 flat_train_labels = [label for sublist in train_labels for label in sublist]
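
The opening `'''` here turns the class-weight and training block into an unused module-level string literal, which disables it without commenting every line (the closing `'''` is added in the next hunk). A hedged sketch of the computation being disabled, assuming scikit-learn's `compute_class_weight`; the toy `train_labels` below is a stand-in for the real pickled per-residue labels:

import numpy as np
from sklearn.utils.class_weight import compute_class_weight

classes = [0, 1]
train_labels = [[0, 0, 1], [1, 0, 0]]  # toy stand-in for the real labels
flat_train_labels = [label for sublist in train_labels for label in sublist]
# "balanced" weighting: n_samples / (n_classes * count(class))
class_weights = compute_class_weight(class_weight="balanced",
                                     classes=np.array(classes),
                                     y=np.array(flat_train_labels))
print(class_weights)  # the rarer binding-site class (1) gets the larger weight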
@@ -248,10 +256,46 @@ saved_path = train_function_no_sweeps(base_model_path,train_dataset, test_dataset)
 
 # debug result
 dubug_result = saved_path #predictions #class_weights
+'''
 
 demo = gr.Blocks(title="DEMO FOR ESM2Bind")
 
 with demo:
     gr.Markdown("# DEMO FOR ESM2Bind")
-    gr.Textbox(dubug_result)
+    #gr.Textbox(dubug_result)
+
+    with gr.Tab("Finetune Pre-trained Model"):
+        gr.Markdown("## Finetune Pre-trained Model")
+        with gr.Column():
+            gr.Markdown("## Load Inputs & Select Parameters")
+            gr.Markdown(
+                """ Pick a dataset, a model & adjust params (_optional_), and press **Finetune Pre-trained Model!"""
+            )
+            with gr.Row():
+                with gr.Column(scale=0.5, variant="compact"):
+                    base_model_name = gr.Dropdown(
+                        choices=MODEL_OPTIONS,
+                        value=MODEL_OPTIONS[0],
+                        label="Base Model Name",
+                        interactive=True,
+                    )
+                    finetune_button = gr.Button(
+                        value="Finetune Pre-trained Model",
+                        interactive=True,
+                        variant="primary",
+                    )
+                    finetune_output_text = gr.Textbox(
+                        lines=1,
+                        max_lines=12,
+                        label="Finetune Status",
+                        placeholder="Finetune Status Shown Here",
+                    )
+    # Tab "Finetune Pre-trained Model" actions
+    finetune_button.click(
+        fn=train_function_no_sweeps,
+        inputs=[base_model_name],  # finetune_dataset_name],
+        outputs=[finetune_output_text],
+    )
+
+
 demo.launch()
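
Taken together, the commit's key pattern is that `Button.click` passes the dropdown's current value to `fn` as a positional argument, and the string `fn` returns (here, the save path) fills the output textbox. A self-contained sketch of that wiring with the training body stubbed out; the stub's return value is illustrative only:

import gradio as gr

MODEL_OPTIONS = [
    "facebook/esm2_t6_8M_UR50D",
    "facebook/esm2_t12_35M_UR50D",
    "facebook/esm2_t33_650M_UR50D",
]

def train_function_no_sweeps(base_model_path):
    # Stub standing in for the real LoRA fine-tuning run; the returned
    # string is what the user sees in the status textbox.
    return f"saved to ./{base_model_path.split('/')[-1]}-lora-binding-sites"

with gr.Blocks(title="DEMO FOR ESM2Bind") as demo:
    gr.Markdown("# DEMO FOR ESM2Bind")
    base_model_name = gr.Dropdown(choices=MODEL_OPTIONS, value=MODEL_OPTIONS[0], label="Base Model Name")
    finetune_button = gr.Button(value="Finetune Pre-trained Model", variant="primary")
    finetune_output_text = gr.Textbox(label="Finetune Status")
    finetune_button.click(fn=train_function_no_sweeps, inputs=[base_model_name], outputs=[finetune_output_text])

demo.launch()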