yasserrmd commited on
Commit
aa26b57
·
verified ·
1 Parent(s): f4f37ec

Create modeling_diffusion.py

Browse files
Files changed (1) hide show
  1. modeling_diffusion.py +26 -0
modeling_diffusion.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from huggingface_hub import PyTorchModelHubMixin
3
+
4
+ class DiffusionTextModel(nn.Module, PyTorchModelHubMixin):
5
+ def __init__(self, vocab_size, max_seq_len, max_time_steps,
6
+ embed_dim=128, n_layers=4, n_heads=4):
7
+ super().__init__()
8
+ self.token_emb = nn.Embedding(vocab_size, embed_dim)
9
+ self.pos_emb = nn.Embedding(max_seq_len, embed_dim)
10
+ self.time_emb = nn.Embedding(max_time_steps+1, embed_dim)
11
+
12
+ enc_layer = nn.TransformerEncoderLayer(
13
+ d_model=embed_dim, nhead=n_heads,
14
+ dim_feedforward=4*embed_dim, activation="gelu"
15
+ )
16
+ self.transformer = nn.TransformerEncoder(enc_layer, num_layers=n_layers)
17
+ self.out = nn.Linear(embed_dim, vocab_size)
18
+
19
+ def forward(self, x, t):
20
+ B, L = x.shape
21
+ tok = self.token_emb(x)
22
+ pos = self.pos_emb(torch.arange(L, device=x.device).unsqueeze(0).expand(B, L))
23
+ tim = self.time_emb(t).unsqueeze(1).expand(B, L, -1)
24
+ h = tok + pos + tim
25
+ h = self.transformer(h.transpose(0,1)).transpose(0,1)
26
+ return self.out(h)