Spaces:
Sleeping
Sleeping
cao
commited on
Commit
·
c7acc8d
1
Parent(s):
9ad4750
fix
Browse files- src/main.py +4 -4
- src/model.py +2 -2
src/main.py
CHANGED
|
@@ -92,7 +92,7 @@ class StriMap_pHLA:
|
|
| 92 |
def __init__(
|
| 93 |
self,
|
| 94 |
device: str = 'cuda:0',
|
| 95 |
-
model_save_path: str = 'model_params/best_model_phla.pt',
|
| 96 |
pep_dim: int = 256,
|
| 97 |
hla_dim: int = 256,
|
| 98 |
bilinear_dim: int = 256,
|
|
@@ -101,9 +101,9 @@ class StriMap_pHLA:
|
|
| 101 |
gamma: float = 2.0,
|
| 102 |
esm2_layer: int = 33,
|
| 103 |
batch_size: int = 256,
|
| 104 |
-
esmfold_cache_dir: str = "esm_cache",
|
| 105 |
-
cache_dir: str = 'phla_cache',
|
| 106 |
-
cache_save: bool =
|
| 107 |
seed: int = 1,
|
| 108 |
pos_weights: Optional[float] = None
|
| 109 |
):
|
|
|
|
| 92 |
def __init__(
|
| 93 |
self,
|
| 94 |
device: str = 'cuda:0',
|
| 95 |
+
model_save_path: str = '/data/model_params/best_model_phla.pt',
|
| 96 |
pep_dim: int = 256,
|
| 97 |
hla_dim: int = 256,
|
| 98 |
bilinear_dim: int = 256,
|
|
|
|
| 101 |
gamma: float = 2.0,
|
| 102 |
esm2_layer: int = 33,
|
| 103 |
batch_size: int = 256,
|
| 104 |
+
esmfold_cache_dir: str = "/data/esm_cache",
|
| 105 |
+
cache_dir: str = '/data/phla_cache',
|
| 106 |
+
cache_save: bool = False,
|
| 107 |
seed: int = 1,
|
| 108 |
pos_weights: Optional[float] = None
|
| 109 |
):
|
src/model.py
CHANGED
|
@@ -342,7 +342,7 @@ class ESM2Encoder(nn.Module):
|
|
| 342 |
def __init__(self,
|
| 343 |
device="cuda:0",
|
| 344 |
layer=33,
|
| 345 |
-
cache_dir='cache'):
|
| 346 |
"""
|
| 347 |
Initialize an ESM2 encoder.
|
| 348 |
|
|
@@ -695,7 +695,7 @@ def batch_embed_to_dicts(
|
|
| 695 |
return emb_dict, coord_dict, failures
|
| 696 |
|
| 697 |
class ESMFoldEncoder(nn.Module):
|
| 698 |
-
def __init__(self, model_name="facebook/esmfold_v1", esm_cache_dir="esm_cache", cache_dir="cache"):
|
| 699 |
super(ESMFoldEncoder, self).__init__()
|
| 700 |
self.model_name = model_name
|
| 701 |
self.esm_cache_dir = esm_cache_dir
|
|
|
|
| 342 |
def __init__(self,
|
| 343 |
device="cuda:0",
|
| 344 |
layer=33,
|
| 345 |
+
cache_dir='/data/cache'):
|
| 346 |
"""
|
| 347 |
Initialize an ESM2 encoder.
|
| 348 |
|
|
|
|
| 695 |
return emb_dict, coord_dict, failures
|
| 696 |
|
| 697 |
class ESMFoldEncoder(nn.Module):
|
| 698 |
+
def __init__(self, model_name="facebook/esmfold_v1", esm_cache_dir="/data/esm_cache", cache_dir="/data/cache"):
|
| 699 |
super(ESMFoldEncoder, self).__init__()
|
| 700 |
self.model_name = model_name
|
| 701 |
self.esm_cache_dir = esm_cache_dir
|