Commit ceca157 (verified), committed by dn6 (HF Staff)
Parent(s): a04d677

Upload transformer/model.py with huggingface_hub

Files changed (1)
  1. transformer/model.py +6 -9
transformer/model.py CHANGED
```diff
@@ -936,16 +936,14 @@ class DiffusionTokenEncoder(nn.Module):
 
 
 class EmbeddingLayer(nn.Module):
-    """Embedding layer for 1D features."""
+    """Embedding layer for 1D features - simple linear projection."""
 
-    def __init__(self, n_channels: int, total_channels: int, output_channels: int):
+    def __init__(self, n_channels: int, output_channels: int):
         super().__init__()
-        self.weight = nn.Parameter(torch.zeros(n_channels, total_channels))
-        self.proj = linearNoBias(total_channels, output_channels)
+        self.weight = nn.Parameter(torch.zeros(output_channels, n_channels))
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        emb = torch.einsum("...i,io->...o", x, self.weight)
-        return self.proj(emb)
+        return F.linear(x, self.weight)
 
 
 class OneDFeatureEmbedder(nn.Module):
@@ -954,13 +952,12 @@ class OneDFeatureEmbedder(nn.Module):
     def __init__(self, features: dict, output_channels: int):
         super().__init__()
         self.features = {k: v for k, v in features.items() if v is not None}
-        total_embedding_input_features = sum(self.features.values())
         self.embedders = nn.ModuleDict({
-            feature: EmbeddingLayer(n_channels, total_embedding_input_features, output_channels)
+            feature: EmbeddingLayer(n_channels, output_channels)
             for feature, n_channels in self.features.items()
         })
 
-    def forward(self, f: dict, collapse_length: int) -> torch.Tensor:
+    def forward(self, f: dict) -> torch.Tensor:
         result = None
         for feature in self.features:
             x = f.get(feature)
```
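
For context on what the change does: the old EmbeddingLayer mapped each feature into a shared `total_channels`-wide space via an einsum against a per-feature weight, then applied a shared-width `linearNoBias` projection; the new version collapses this into a single bias-free linear map per feature, so `total_embedding_input_features` and the extra `collapse_length` argument to `OneDFeatureEmbedder.forward` are no longer needed. Below is a minimal, self-contained sketch of the refactored layer for shape-checking only; the imports and the toy sizes in the `__main__` block are assumptions for illustration, not part of the commit, and the zero-initialized weight is presumably overwritten when the uploaded checkpoint is loaded.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class EmbeddingLayer(nn.Module):
    """Embedding layer for 1D features - simple linear projection (new version)."""

    def __init__(self, n_channels: int, output_channels: int):
        super().__init__()
        # Weight stored as (output_channels, n_channels), the layout F.linear expects.
        self.weight = nn.Parameter(torch.zeros(output_channels, n_channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Equivalent to a bias-free nn.Linear(n_channels, output_channels).
        return F.linear(x, self.weight)


if __name__ == "__main__":
    # Hypothetical sizes, chosen only to exercise the refactored layer.
    layer = EmbeddingLayer(n_channels=32, output_channels=128)
    x = torch.randn(2, 100, 32)   # (batch, tokens, n_channels)
    out = layer(x)
    print(out.shape)              # torch.Size([2, 100, 128])
```

Net effect per the diff stats (+6 -9): each feature embedder now holds one weight of shape (output_channels, n_channels) instead of a per-feature weight plus a shared projection, which is where the removed lines go.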