StriMap / src /predictor.py
cao
Add model and predictor files
78f28d5
raw
history blame
748 Bytes
import torch
from main import StriMap_pHLA, StriMap_TCRpHLA, load_test_data
def load_model(model_path="model.pt", device=None):
    """Build a ``StriMap_pHLA`` predictor and restore its weights from disk.

    Args:
        model_path: Path of the saved checkpoint. The same path is also handed
            to the constructor as ``model_save_path``.
        device: Forwarded unchanged to ``StriMap_pHLA``. ``None`` is passed
            through as-is; how it is interpreted is up to ``StriMap_pHLA``
            (not visible here).

    Returns:
        A ``(model, device)`` tuple. ``device`` is exactly the argument that
        was given; it is not resolved here, so it may still be ``None``.
    """
    constructor_kwargs = dict(
        device=device,
        model_save_path=model_path,
        # NOTE(review): exact effect of cache_save lives in StriMap_pHLA — confirm.
        cache_save=False,
    )
    predictor = StriMap_pHLA(**constructor_kwargs)
    predictor.load_model(model_path)
    return predictor, device
def predict_from_df(df, model, *, hla_dict_path='HLA_dict.npy', batch_size=128):
    """Run ``model`` on the rows of ``df`` and return them with a ``Prediction`` column.

    The frame is first passed through ``load_test_data`` together with the HLA
    lookup file at ``hla_dict_path``. Embeddings are then prepared on the model
    (without forcing recomputation). Finally ``model.predict`` is called with
    ``return_probs=True`` and ``use_kfold=False``.

    ``model.predict`` is fed a constant dummy ``label`` column of 1s; the call
    presumably requires such a column (TODO confirm against ``model.predict``).
    Any ``label`` column already present after ``load_test_data`` is saved
    beforehand and restored afterwards, so real labels are not lost. If there
    was none, the dummy column is removed again.

    Args:
        df: Input DataFrame. Its required columns are whatever
            ``load_test_data`` expects (not visible here).
        model: A loaded model exposing ``prepare_embeddings`` and ``predict``,
            e.g. the first element returned by ``load_model``.
        hla_dict_path: Path of the HLA dictionary file passed to
            ``load_test_data``.
        batch_size: Batch size passed to ``model.predict``.

    Returns:
        The DataFrame produced by ``load_test_data``, with a ``Prediction``
        column holding the first value returned by ``model.predict``.
    """
    df = load_test_data(
        df_test=df,
        hla_dict_path=hla_dict_path,
    )
    model.prepare_embeddings(
        df,
        force_recompute=False,
    )

    # Keep any real labels so the dummy column below does not permanently
    # overwrite (and then drop) data the caller passed in.
    had_label = 'label' in df.columns
    saved_label = df['label'].copy() if had_label else None
    df['label'] = 1

    # Release unoccupied cached CUDA memory before running inference.
    torch.cuda.empty_cache()
    predictions, _ = model.predict(
        df, batch_size=batch_size, return_probs=True, use_kfold=False
    )
    df["Prediction"] = predictions

    if had_label:
        df['label'] = saved_label
    else:
        df = df.drop(columns=['label'])
    return df