Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -122,13 +122,16 @@ def predict_bind(base_model_path,PEFT_model_path,input_seq):
|
|
| 122 |
predictions = torch.argmax(logits, dim=2)
|
| 123 |
|
| 124 |
binding_site=[]
|
|
|
|
| 125 |
# Print the predicted labels for each token
|
| 126 |
for token, prediction in zip(tokens, predictions[0].numpy()):
|
| 127 |
if token not in ['<pad>', '<cls>', '<eos>']:
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
|
|
|
|
|
|
| 132 |
return binding_site
|
| 133 |
|
| 134 |
# fine-tuning function
|
|
|
|
| 122 |
predictions = torch.argmax(logits, dim=2)
|
| 123 |
|
| 124 |
binding_site=[]
|
| 125 |
+
pos = 0
|
| 126 |
# Print the predicted labels for each token
|
| 127 |
for token, prediction in zip(tokens, predictions[0].numpy()):
|
| 128 |
if token not in ['<pad>', '<cls>', '<eos>']:
|
| 129 |
+
pos++
|
| 130 |
+
print((pos, token, id2label[prediction]))
|
| 131 |
+
if prediction == 1:
|
| 132 |
+
print((pos, token, id2label[prediction]))
|
| 133 |
+
binding_site.append([pos, token, id2label[prediction]])
|
| 134 |
+
|
| 135 |
return binding_site
|
| 136 |
|
| 137 |
# fine-tuning function
|