Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -134,6 +134,17 @@ with torch.no_grad():
|
|
| 134 |
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Convert input ids back to tokens
|
| 135 |
predictions = torch.argmax(logits, dim=2)
|
| 136 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
# debug result
|
| 138 |
dubug_result = predictions #class_weights
|
| 139 |
|
|
|
|
| 134 |
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Convert input ids back to tokens
|
| 135 |
predictions = torch.argmax(logits, dim=2)
|
| 136 |
|
| 137 |
+
# Define labels
|
| 138 |
+
id2label = {
|
| 139 |
+
0: "No binding site",
|
| 140 |
+
1: "Binding site"
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
# Print the predicted labels for each token
|
| 144 |
+
for token, prediction in zip(tokens, predictions[0].numpy()):
|
| 145 |
+
if token not in ['<pad>', '<cls>', '<eos>']:
|
| 146 |
+
print((token, id2label[prediction]))
|
| 147 |
+
|
| 148 |
# debug result
|
| 149 |
dubug_result = predictions #class_weights
|
| 150 |
|