Upload transformer/model.py with huggingface_hub
Browse files- transformer/model.py +6 -9
transformer/model.py
CHANGED
|
@@ -936,16 +936,14 @@ class DiffusionTokenEncoder(nn.Module):
|
|
| 936 |
|
| 937 |
|
| 938 |
class EmbeddingLayer(nn.Module):
|
| 939 |
-
"""Embedding layer for 1D features."""
|
| 940 |
|
| 941 |
-
def __init__(self, n_channels: int,
|
| 942 |
super().__init__()
|
| 943 |
-
self.weight = nn.Parameter(torch.zeros(
|
| 944 |
-
self.proj = linearNoBias(total_channels, output_channels)
|
| 945 |
|
| 946 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 947 |
-
|
| 948 |
-
return self.proj(emb)
|
| 949 |
|
| 950 |
|
| 951 |
class OneDFeatureEmbedder(nn.Module):
|
|
@@ -954,13 +952,12 @@ class OneDFeatureEmbedder(nn.Module):
|
|
| 954 |
def __init__(self, features: dict, output_channels: int):
|
| 955 |
super().__init__()
|
| 956 |
self.features = {k: v for k, v in features.items() if v is not None}
|
| 957 |
-
total_embedding_input_features = sum(self.features.values())
|
| 958 |
self.embedders = nn.ModuleDict({
|
| 959 |
-
feature: EmbeddingLayer(n_channels,
|
| 960 |
for feature, n_channels in self.features.items()
|
| 961 |
})
|
| 962 |
|
| 963 |
-
def forward(self, f: dict
|
| 964 |
result = None
|
| 965 |
for feature in self.features:
|
| 966 |
x = f.get(feature)
|
|
|
|
| 936 |
|
| 937 |
|
| 938 |
class EmbeddingLayer(nn.Module):
|
| 939 |
+
"""Embedding layer for 1D features - simple linear projection."""
|
| 940 |
|
| 941 |
+
def __init__(self, n_channels: int, output_channels: int):
|
| 942 |
super().__init__()
|
| 943 |
+
self.weight = nn.Parameter(torch.zeros(output_channels, n_channels))
|
|
|
|
| 944 |
|
| 945 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 946 |
+
return F.linear(x, self.weight)
|
|
|
|
| 947 |
|
| 948 |
|
| 949 |
class OneDFeatureEmbedder(nn.Module):
|
|
|
|
| 952 |
def __init__(self, features: dict, output_channels: int):
|
| 953 |
super().__init__()
|
| 954 |
self.features = {k: v for k, v in features.items() if v is not None}
|
|
|
|
| 955 |
self.embedders = nn.ModuleDict({
|
| 956 |
+
feature: EmbeddingLayer(n_channels, output_channels)
|
| 957 |
for feature, n_channels in self.features.items()
|
| 958 |
})
|
| 959 |
|
| 960 |
+
def forward(self, f: dict) -> torch.Tensor:
|
| 961 |
result = None
|
| 962 |
for feature in self.features:
|
| 963 |
x = f.get(feature)
|