Spaces:
Sleeping
Sleeping
| import torch | |
| from main import StriMap_pHLA, StriMap_TCRpHLA, load_test_data | |
def load_model(model_path="model.pt", device=None):
    """Construct a StriMap_pHLA model and restore its weights from a checkpoint.

    Args:
        model_path: Checkpoint path. It is passed both as the model's
            ``model_save_path`` and as the file handed to ``load_model``.
        device: Forwarded unchanged to the ``StriMap_pHLA`` constructor;
            ``None`` is passed through as-is (how it is resolved is up to
            ``StriMap_pHLA`` — not visible here).

    Returns:
        A ``(model, device)`` tuple, where ``device`` is the argument exactly
        as it was given.
    """
    constructor_kwargs = {
        "device": device,
        "model_save_path": model_path,
        # NOTE(review): presumably disables writing cached artifacts during
        # inference — confirm against StriMap_pHLA.
        "cache_save": False,
    }
    predictor = StriMap_pHLA(**constructor_kwargs)
    predictor.load_model(model_path)
    return predictor, device
def predict_from_df(df, model, *, batch_size=128, hla_dict_path="HLA_dict.npy"):
    """Run a loaded StriMap_pHLA model over the rows of ``df``.

    Args:
        df: Input table, passed to ``load_test_data`` as ``df_test``.
        model: A model exposing ``prepare_embeddings`` and ``predict``
            (e.g. the first element returned by ``load_model``).
        batch_size: Batch size forwarded to ``model.predict``.
        hla_dict_path: HLA dictionary file forwarded to ``load_test_data``.

    Returns:
        The frame produced by ``load_test_data`` with an added
        ``"Prediction"`` column holding what ``model.predict`` returns with
        ``return_probs=True`` (presumably probabilities — confirm).

        A ``"label"`` column is present only if one was already there after
        ``load_test_data``. In that case its original values are restored;
        otherwise the temporary placeholder column is removed.
    """
    df = load_test_data(
        df_test=df,
        hla_dict_path=hla_dict_path,
    )
    model.prepare_embeddings(
        df,
        force_recompute=False,
    )

    # NOTE(review): model.predict presumably requires a 'label' column even at
    # inference time, so a placeholder of 1 is inserted. Any labels that were
    # already present are saved first so they are not silently lost.
    saved_labels = df["label"].copy() if "label" in df.columns else None
    df["label"] = 1

    # Release unoccupied cached CUDA memory before batched inference.
    torch.cuda.empty_cache()
    predictions, _ = model.predict(
        df,
        batch_size=batch_size,
        return_probs=True,
        use_kfold=False,
    )
    df["Prediction"] = predictions

    if saved_labels is None:
        # Remove the placeholder label column we added.
        df = df.drop(columns=["label"])
    else:
        # Restore the caller's labels in place (column position is kept).
        df["label"] = saved_labels
    return df