Sophia Tang committed
Commit · 92f7053
Parent(s): 6612621
update

Files changed:
- config.yaml +168 -0
- diffusion.py +1 -101
- scoring/{hemolysis.py → functions/hemolysis.py} +0 -0
config.yaml ADDED
@@ -0,0 +1,168 @@
+noise:
+  type: loglinear
+  sigma_min: 1e-4
+  sigma_max: 20
+  state_dependent: True
+
+mode: ppl_eval # train / ppl_eval / sample_eval
+diffusion: absorbing_state
+vocab: old_smiles # old_smiles / new_smiles / selfies / helm
+backbone: roformer # peptideclm / helmgpt / dit / roformer / finetune_roformer
+parameterization: subs # subs
+time_conditioning: False
+T: 0 # 0 (continuous time) / 1000
+subs_masking: False
+
+seed: 42
+
+mcts:
+  num_children: 50
+  num_objectives: 5
+  topk: 100
+  mask_token: 4
+  num_iter: 128
+  sampling: 0 # 0 is gumbel sampling / > 0 samples children from top k probs
+  invalid_penalty: 0.5
+  sample_prob: 1.0
+  perm: True
+  dual: False
+  single: False
+  time_dependent: True
+
+lr_scheduler:
+  _target_: transformers.get_constant_schedule_with_warmup
+  num_warmup_steps: 2500
+
+data:
+  train: /home/st512/peptune/scripts/peptide-mdlm-mcts/data/finetune2/30K-train.csv
+  valid: /home/st512/peptune/scripts/peptide-mdlm-mcts/data/finetune2/30K-val.csv
+  batching: wrapping # padding / wrapping
+
+loader:
+  global_batch_size: 64
+  eval_global_batch_size: ${.global_batch_size}
+  # Note: batch_size and eval_batch_size are **per machine**
+  batch_size: ${div_up:${.global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
+  eval_batch_size: ${div_up:${.eval_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
+  num_workers: ${eval:"len(__import__('os').sched_getaffinity(0))"}
+  pin_memory: True
+
+sampling:
+  predictor: ddpm_cache # analytic, ddpm, ddpm_cache
+  num_sequences: 100
+  sampling_eps: 1e-3
+  steps: 128
+  seq_length: 100
+  noise_removal: True
+  num_sample_batches: 2 # Total samples: `num_gpus` * `loader.eval_batch_size` * num_sample_batches
+  num_sample_log: 2
+  stride_length: 1
+  num_strides: 1
+
+training:
+  antithetic_sampling: True
+  sampling_eps: 1e-3
+  focus_mask: False
+  #dynamic_batching: True
+  accumulator: False
+
+eval:
+  checkpoint_path: /home/st512/peptune/scripts/peptide-mdlm-mcts/checkpoints/11M-old-tokenizer/epoch=10-step=156276.ckpt
+  disable_ema: False
+  compute_generative_perplexity: False
+  perplexity_batch_size: 8
+  compute_perplexity_on_sanity: False
+  gen_ppl_eval_model_name_or_path: gpt2-large # gpt2-large, meta-llama/Llama-2-7b-hf
+  generate_samples: True
+  generation_model: /home/st512/peptune/scripts/peptide-mdlm-mcts/checkpoints/11M-old-tokenizer/
+
+optim:
+  weight_decay: 0.075
+  lr: 3e-4
+  beta1: 0.9
+  beta2: 0.999
+  eps: 1e-8
+
+pepclm:
+  hidden_size: 768
+  cond_dim: 256
+  n_heads: 20
+  n_blocks: 4
+  dropout: 0.5
+  length: 512
+  #scale_by_sigma: True
+
+model:
+  type: ddit
+  hidden_size: 768
+  cond_dim: 128
+  length: 512
+  n_blocks: 12
+  n_heads: 12
+  scale_by_sigma: True
+  dropout: 0.1
+
+roformer:
+  hidden_size: 768
+  n_layers: 8
+  n_heads: 8
+  max_position_embeddings: 1035
+
+helmgpt:
+  hidden_size: 256
+  embd_pdrop: 0.1
+  resid_pdrop: 0.1
+  attn_pdrop: 0.1
+  ff_dropout: 0.
+  block_size: 140
+  n_layer: 8
+  n_heads: 8
+
+
+trainer:
+  _target_: lightning.Trainer
+  accelerator: cuda
+  num_nodes: 1
+  devices: ${device_count:}
+  accumulate_grad_batches: ${div_up:${loader.global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}}
+  gradient_clip_val: 1.0
+  precision: 64-true
+  num_sanity_val_steps: 2
+  max_epochs: 100
+  max_steps: 1_000_000
+  log_every_n_steps: 10
+  limit_train_batches: 1.0 # train on full dataset, can be used to toggle quick run
+  limit_val_batches: 1.0 # validate on full dataset, can be used to toggle quick run
+  #val_check_interval: 40 #954
+  check_val_every_n_epoch: 1
+
+
+wandb:
+  project: peptune
+  notes: null
+  group: null
+  job_type: null
+  name: sophia-tang
+  id: ${.name}_nov12_set2
+
+hydra:
+  run:
+    dir: ./${now:%Y.%m.%d}/
+  job:
+    chdir: True
+
+checkpointing:
+  # Use custom `save_dir` if, e.g., saving to S3 bucket, otherwise leave this parameter as is
+  save_dir: ${cwd:}
+  # Note: `checkpoints` path should correspond to `checkpoint_every_n_steps.dirpath`
+  resume_from_ckpt: True
+  resume_ckpt_path: /home/st512/peptune/scripts/peptide-mdlm-mcts/checkpoints/11M-old-tokenizer/epoch=7-step=108225.ckpt
+
+callbacks:
+  model_checkpoint:
+    _target_: pytorch_lightning.callbacks.ModelCheckpoint
+    every_n_epochs: 1
+    monitor: "val/nll"
+    save_top_k: 10
+    mode: "min"
+    dirpath: '/home/st512/peptune/scripts/peptide-mdlm-mcts/checkpoints/11M-old-tokenizer'
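The interpolations above (`${div_up:…}`, `${eval:…}`, `${device_count:}`, `${cwd:}`) are custom OmegaConf resolvers, so they only resolve if the training entry point registers them before composing the config. The snippet below is a minimal sketch of such a registration, not code from this commit; the function name `register_config_resolvers` and the exact lambdas are assumptions.

# Hypothetical helper: registers the resolvers that config.yaml's interpolations rely on.
import os
import torch
from omegaconf import OmegaConf

def register_config_resolvers():
    # ${div_up:a, b} -> ceil(a / b), used to split the global batch size across devices
    OmegaConf.register_new_resolver("div_up", lambda a, b: (int(a) + int(b) - 1) // int(b))
    # ${eval:"expr"} -> evaluate an embedded Python expression (e.g. devices * num_nodes)
    OmegaConf.register_new_resolver("eval", eval)
    # ${device_count:} -> number of visible CUDA devices (at least 1)
    OmegaConf.register_new_resolver("device_count", lambda: max(1, torch.cuda.device_count()))
    # ${cwd:} -> launch-time working directory, used as the default checkpoint save_dir
    OmegaConf.register_new_resolver("cwd", os.getcwd)

With resolvers like these in place, `loader.batch_size` resolves to `global_batch_size` divided (rounding up) by `devices * num_nodes`, e.g. 64 / (4 * 1) = 16 per device on a hypothetical 4-GPU node, and `trainer.accumulate_grad_batches` then works out to 1.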
diffusion.py CHANGED
@@ -116,8 +116,6 @@ class Diffusion(L.LightningModule):
         self.test_metrics = metrics.clone(prefix='test/')
 
 
-    """LOSS"""
-
     """LOSS FOR INVALID PEPTIDES"""
 
     @torch.no_grad()
@@ -248,18 +246,6 @@ class Diffusion(L.LightningModule):
         t = (1 - self.config.training.sampling_eps) * eps_t + self.config.training.sampling_eps
 
         return t
-
-    """def mask_samples(self, x0, mask_prob):
-
-        # generate array of values in range [0, 1] uniformly at random
-        # will be used to determine which tokens are masked
-        mask_indices = torch.rand(* x0.shape, device=x0.device) # (batch_size, L)
-
-        # select tokens to mask if the random value in mask_indices is less than mask_prob
-        # this will mask approximately the fraction of tokens indicated by mask_prob
-        zt = torch.where(mask_indices < mask_prob, self.mask_token_id, x0)
-
-        return zt"""
 
     def q_xt(self, x, mask_prob):
         """Computes the noisy sample xt.
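The removed comment block above spells out the absorbing-state corruption that the retained `q_xt` implements: draw a uniform value per token and replace the token with the mask id wherever that value falls below the masking probability. A standalone sketch of that step (the function name `absorbing_mask` is ours, not the repo's):

import torch

def absorbing_mask(x0: torch.Tensor, mask_prob: torch.Tensor, mask_token_id: int) -> torch.Tensor:
    # x0: (batch_size, L) clean token ids; mask_prob: broadcastable to (batch_size, L)
    # each token is independently absorbed into the mask state with probability mask_prob
    u = torch.rand(*x0.shape, device=x0.device)
    return torch.where(u < mask_prob, torch.full_like(x0, mask_token_id), x0)

Because masking is applied independently per position, the expected fraction of masked tokens equals `mask_prob`, which is exactly what the deleted comments describe.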
@@ -349,48 +335,6 @@ class Diffusion(L.LightningModule):
         # scale by T and return
         return self.T * L_vb
 
-    """def _forward_pass_diffusion(self, x0, attn_mask, mask=None):
-
-        print(x0)
-        # randomly sample time steps to start the denoising process for each x0 in batch
-        t = self.sample_t(x0.shape[0], x0.device)
-
-        # if we are training the intermediate transition blocks
-        if self.T > 0:
-            # scale by total timesteps T and cast to integer
-            t = (t * self.T).to(torch.int)
-            # scale down by T to get a multiple of 1/T
-            t = t / self.T
-            # add 1/T to ensure no 0 values
-            t += (1 / self.T)
-
-        # get noise and rate of noise at timestep t
-        sigma, dsigma = self.noise(t)
-        time_conditioning = sigma[:, None]
-        # get masking probabilities for all tokens for each batch
-        mask_prob = 1 - torch.exp(-sigma[:, None]) # (batch_size, L)
-
-        # get masked samples at different timesteps
-        if mask is None: zt = self.q_xt(x0, mask_prob)
-        else: zt = x0.where(mask==1, torch.full_like(x0, self.mask_token_id))
-
-        model_output = self.forward(zt, attn_mask, time_conditioning)
-
-        utils.print_nans(model_output, 'model_output')
-
-        if self.T > 0:
-            # compute diffusion loss
-            diffusion_loss = self.compute_diffusion_loss(model_output, zt, x0, t)
-            return diffusion_loss
-
-        # compute loss for the final that converts from z0 to x0
-        # -log(p_theta)
-        # get (batch_size, L) array of log-probabilities
-        log_p_theta = torch.gather(input=model_output, dim=-1, index=x0[:, :, None]).squeeze(-1) # (B, L)
-
-
-        return -log_p_theta * (dsigma / torch.expm1(sigma))[:, None]"""
-
     def _forward_pass_diffusion(self, x0, attn_mask, bond_mask=None, mask=None):
         """
         Training reverse diffusion model x_theta to reconstruct samples x0
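The deleted draft of `_forward_pass_diffusion` documents the continuous-time objective: sample t, get (sigma, dsigma) from the noise schedule, mask each token with probability 1 - exp(-sigma), and weight the reconstruction log-likelihood by dsigma / (exp(sigma) - 1). A compact sketch of that final loss term, mirroring the deleted lines (the standalone function name is ours; in the repo this is computed inside the method):

import torch

def continuous_time_loss(model_log_probs, x0, sigma, dsigma):
    # model_log_probs: (B, L, V) log p_theta(. | zt); x0: (B, L) clean tokens
    # sigma, dsigma: (B,) noise level and its rate from self.noise(t)
    # log-probability assigned to the true token at each position -> (B, L)
    log_p_theta = torch.gather(model_log_probs, dim=-1, index=x0[:, :, None]).squeeze(-1)
    # ELBO weight for the absorbing-state schedule: dsigma / (exp(sigma) - 1)
    return -log_p_theta * (dsigma / torch.expm1(sigma))[:, None]

This is the continuous-time (`T: 0` in config.yaml) counterpart of the discrete bound computed by `compute_diffusion_loss` in the `T > 0` branch of the deleted code.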
@@ -634,21 +578,6 @@ class Diffusion(L.LightningModule):
 
     # first step in expansion
     def batch_cached_reverse_step(self, token_array, t, dt, batch_size, p_x0=None, attn_mask=None):
-        """
-        Generates batch_size different samples from the same starting point for the
-        first expansion step of MCTS
-
-        Args:
-            x (_type_): _description_
-            t (_type_): _description_
-            dt (_type_): _description_
-            batch_size (_type_): _description_
-            p_x0 (_type_, optional): _description_. Defaults to None.
-            attn_mask (_type_, optional): _description_. Defaults to None.
-
-        Returns:
-            _type_: _description_
-        """
 
         assert self.config.noise.type == 'loglinear'
         sigma_t, _ = self.noise(t)
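The docstring removed here said that `batch_cached_reverse_step` draws `batch_size` different next-step samples from a single starting sequence, i.e. the first expansion of an MCTS node. A hypothetical wrapper illustrating such a call with the values from config.yaml; everything except `batch_cached_reverse_step` and the config keys is an assumption:

import torch

def expand_node(diffusion, token_array, attn_mask, t_node, config):
    # step size of the reverse process: 1 / sampling.steps = 1/128 with this config
    dt = 1.0 / config.sampling.steps
    # noise level of the current (partially unmasked) node
    t = torch.full((1,), t_node, device=token_array.device)
    # draw mcts.num_children = 50 one-step denoisings of the same parent sequence
    return diffusion.batch_cached_reverse_step(
        token_array, t, dt,
        batch_size=config.mcts.num_children,
        attn_mask=attn_mask)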
@@ -880,9 +809,7 @@
                            0)[..., None]
         return edge
 
-
-    """TRAINING from https://github.com/Dao-AILab/flash-attention/blob/main/training/src/tasks/seq.py"""
-
+
     def on_train_epoch_start(self):
         torch.cuda.empty_cache()
         self.backbone.train()
@@ -1049,19 +976,6 @@ def sample_categorical(categorical_probs):
     return (categorical_probs / gumbel_norm).argmax(dim=-1)
 
 def sample_batched_categorical(categorical_probs, batch_size):
-    """
-    Generates `m` distinct sequences sampled from categorical probabilities
-    using the Gumbel distribution to ensure randomness while following probabilities
-
-    Args:
-        categorical_probs (torch.Tensor): tensor of shape (sequence_length, vocab_length)
-            representing categorical probabilities
-        m (int): number of distinct sequences to sample
-
-    Returns:
-        torch.Tensor: tensor of shape (m, sequence_length), where each row is a
-            distinct sequence of sampled category indices.
-    """
     _, sequence_length, vocab_size = categorical_probs.shape
 
     # add Gumbel noise and sample m sequences
@@ -1074,20 +988,6 @@ def sample_batched_categorical(categorical_probs, batch_size):
     return sampled_sequences
 
 def sample_batched_top_k(categorical_probs, batch_size, k):
-    """
-    Generates `m` sequences sampled from the top-k probabilities of each token
-    using Gumbel noise to ensure randomness and reduce bias towards the most likely options.
-
-    Args:
-        categorical_probs (torch.Tensor): A tensor of shape (sequence_length, vocab_length)
-            representing categorical probabilities.
-        m (int): Number of sequences to sample.
-        k (int): Number of top probabilities to consider for sampling.
-
-    Returns:
-        torch.Tensor: A tensor of shape (m, sequence_length), where each row is a
-            sampled sequence of category indices.
-    """
     _, sequence_length, vocab_length = categorical_probs.shape
 
     # Add Gumbel noise to the log probabilities
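`sample_categorical` draws one token per position via the Gumbel-max trick (dividing probabilities by Exponential(1) noise and taking the argmax); the two batched variants whose docstrings are removed above extend this to `batch_size` sequences and to top-k-restricted sampling. A minimal sketch of the batched case under that reading (hypothetical name `gumbel_sample_sequences`; not the repo's exact implementation):

import torch

def gumbel_sample_sequences(categorical_probs: torch.Tensor, batch_size: int) -> torch.Tensor:
    # categorical_probs: (1, sequence_length, vocab_size) per-position probabilities
    probs = categorical_probs.expand(batch_size, -1, -1)
    # -log(U) is Exponential(1); argmax(p / E) is an exact categorical draw,
    # the same form used by sample_categorical above
    exp_noise = -torch.log(torch.rand(probs.shape, device=probs.device) + 1e-10)
    return (probs / exp_noise).argmax(dim=-1)  # (batch_size, sequence_length)

Each position is sampled independently, so the `batch_size` sequences differ wherever the per-token distribution is not peaked; a top-k variant would zero out all but the k largest probabilities per position before adding the noise.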
scoring/{hemolysis.py → functions/hemolysis.py} RENAMED
File without changes