yfyeung commited on
Commit
b307546
·
verified ·
1 Parent(s): 0a7cb9b

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. configuration_clsp.py +49 -0
  2. modeling_clsp.py +298 -0
  3. modular_clsp.py +1911 -0
  4. zipformer2.py +0 -0
configuration_clsp.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
+ class CLSPConfig(PretrainedConfig):
5
+ model_type = "clsp"
6
+
7
+ def __init__(
8
+ self,
9
+ feature_dim: int = 128,
10
+ output_downsampling_factor: int = 2,
11
+ downsampling_factor: str = "1,2,4,8,4,2,1",
12
+ num_encoder_layers: str = "1,2,3,4,1,1,1",
13
+ encoder_dim: str = "1280,1280,1280,1280,1280,1280,1280",
14
+ encoder_unmasked_dim: str = "768,768,768,768,768,768,768",
15
+ query_head_dim: str = "32",
16
+ pos_head_dim: str = "4",
17
+ value_head_dim: str = "12",
18
+ pos_dim: int = 48,
19
+ num_heads: str = "8,8,8,8,8,8,8",
20
+ feedforward_dim: str = "3840,3840,3840,3840,3840,3840,3840",
21
+ cnn_module_kernel: str = "31,31,15,15,15,31,31",
22
+ causal: bool = False,
23
+ chunk_size: str = "-1",
24
+ left_context_frames: str = "-1",
25
+ text_encoder_dim: int = 768,
26
+ joint_dim: int = 512,
27
+ **kwargs,
28
+ ):
29
+ super().__init__(**kwargs)
30
+ # SPEAR encoder related
31
+ self.feature_dim = feature_dim
32
+ self.output_downsampling_factor = output_downsampling_factor
33
+ self.downsampling_factor = downsampling_factor
34
+ self.num_encoder_layers = num_encoder_layers
35
+ self.encoder_dim = encoder_dim
36
+ self.encoder_unmasked_dim = encoder_unmasked_dim
37
+ self.query_head_dim = query_head_dim
38
+ self.pos_head_dim = pos_head_dim
39
+ self.value_head_dim = value_head_dim
40
+ self.pos_dim = pos_dim
41
+ self.num_heads = num_heads
42
+ self.feedforward_dim = feedforward_dim
43
+ self.cnn_module_kernel = cnn_module_kernel
44
+ self.causal = causal
45
+ self.chunk_size = chunk_size
46
+ self.left_context_frames = left_context_frames
47
+
48
+ self.text_encoder_dim = text_encoder_dim
49
+ self.joint_dim = joint_dim
modeling_clsp.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Yifan Yang
2
+ #
3
+ # See ../../../../LICENSE for clarification regarding multiple authors
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+
18
+ import math
19
+ from typing import Optional, Tuple
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ from transformers import PreTrainedModel, RobertaConfig, RobertaModel
25
+
26
+ from .configuration_clsp import CLSPConfig
27
+ from .zipformer2 import Conv2dSubsampling, Zipformer2
28
+
29
+
30
+ class CLSPModel(PreTrainedModel):
31
+ config_class = CLSPConfig
32
+
33
+ def __init__(self, config: CLSPConfig):
34
+ super().__init__(config)
35
+ self.model = get_model(config)
36
+
37
+ def forward(self, *args, **kwargs):
38
+ return self.model(*args, **kwargs)
39
+
40
+ def load_audio(self, audio_path):
41
+ return self.model.load_audio(audio_path)
42
+
43
+
44
+ class MLPLayers(nn.Module):
45
+ def __init__(self, units=[512, 512, 512], nonlin=nn.ReLU(), dropout=0.1):
46
+ super(MLPLayers, self).__init__()
47
+ self.nonlin = nonlin
48
+ self.dropout = dropout
49
+
50
+ sequence = []
51
+ for u0, u1 in zip(units[:-1], units[1:]):
52
+ sequence.append(nn.Linear(u0, u1))
53
+ sequence.append(self.nonlin)
54
+ sequence.append(nn.Dropout(self.dropout))
55
+ sequence = sequence[:-2]
56
+
57
+ self.sequential = nn.Sequential(*sequence)
58
+
59
+ def forward(self, X):
60
+ X = self.sequential(X)
61
+ return X
62
+
63
+
64
+ class CLAP(nn.Module):
65
+ def __init__(
66
+ self,
67
+ encoder_embed: nn.Module,
68
+ encoder: nn.Module,
69
+ encoder_downsample: Optional[nn.Module] = None,
70
+ encoder_dim: int = 384,
71
+ text_encoder_dim: int = 768,
72
+ joint_dim: int = 512,
73
+ ):
74
+ """CLAP-style dual encoder model.
75
+
76
+ Args:
77
+ encoder_embed:
78
+ It is a Convolutional 2D subsampling module. It converts
79
+ an input of shape (N, T, idim) to an output of of shape
80
+ (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
81
+ encoder:
82
+ It is the transcription network in the paper. Its accepts
83
+ two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
84
+ It returns two tensors: `logits` of shape (N, T, encoder_dim) and
85
+ `logit_lens` of shape (N,).
86
+ """
87
+ super().__init__()
88
+
89
+ # audio branch
90
+ self.encoder_embed = encoder_embed
91
+ self.encoder = encoder
92
+ self.encoder_downsample = encoder_downsample
93
+ self.audio_projection = nn.Sequential(
94
+ nn.Linear(encoder_dim, joint_dim),
95
+ nn.ReLU(),
96
+ nn.Linear(joint_dim, joint_dim),
97
+ )
98
+ self.audio_transform = MLPLayers(
99
+ units=[joint_dim, joint_dim, joint_dim], dropout=0.1
100
+ )
101
+
102
+ # text branch
103
+ self.text_encoder = text_encoder = RobertaModel(
104
+ RobertaConfig.from_pretrained("roberta-base")
105
+ )
106
+ self.text_projection = nn.Sequential(
107
+ nn.Linear(text_encoder_dim, joint_dim),
108
+ nn.ReLU(),
109
+ nn.Linear(joint_dim, joint_dim),
110
+ )
111
+ self.text_transform = MLPLayers(
112
+ units=[joint_dim, joint_dim, joint_dim], dropout=0.1
113
+ )
114
+
115
+ self.logit_scale = nn.Parameter(torch.full((), math.log(1 / 0.07)))
116
+
117
+ def forward_audio_encoder(
118
+ self, x: torch.Tensor, x_lens: torch.Tensor, freeze_encoder: bool = False
119
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
120
+ """Compute audio encoder outputs.
121
+ Args:
122
+ x:
123
+ A 3-D tensor of shape (N, T, C).
124
+ x_lens:
125
+ A 1-D tensor of shape (N,). It contains the number of frames in `x`
126
+ before padding.
127
+
128
+ Returns:
129
+ encoder_out:
130
+ Encoder output, of shape (N, T, C).
131
+ encoder_out_lens:
132
+ Encoder output lengths, of shape (N,).
133
+ """
134
+ # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
135
+ with torch.set_grad_enabled(not freeze_encoder):
136
+ x, x_lens = self.encoder_embed(x, x_lens)
137
+ src_key_padding_mask = make_pad_mask(x_lens)
138
+ x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
139
+ encoder_out, encoder_out_lens = self.encoder(
140
+ x, x_lens, src_key_padding_mask
141
+ )
142
+ encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
143
+
144
+ assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
145
+
146
+ if self.encoder_downsample is not None:
147
+ encoder_out = encoder_out.permute(1, 0, 2)
148
+ encoder_out = self.encoder_downsample(encoder_out)
149
+ encoder_out = encoder_out.permute(1, 0, 2)
150
+ encoder_out_lens = (encoder_out_lens + 1) // 2
151
+
152
+ padding_mask = make_pad_mask(encoder_out_lens)
153
+ encoder_out = encoder_out.masked_fill(padding_mask.unsqueeze(-1), 0.0)
154
+ embedding = encoder_out.sum(dim=1) / encoder_out_lens.unsqueeze(-1) # (N, C)
155
+
156
+ return embedding
157
+
158
+ def forward_text_encoder(self, y: dict, freeze_encoder: bool = False):
159
+ with torch.set_grad_enabled(not freeze_encoder):
160
+ encoder_out = self.text_encoder(
161
+ input_ids=y["input_ids"],
162
+ attention_mask=y["attention_mask"],
163
+ )["pooler_output"]
164
+
165
+ return encoder_out
166
+
167
+ def forward(
168
+ self,
169
+ audio: Optional[torch.Tensor] = None,
170
+ audio_lens: Optional[torch.Tensor] = None,
171
+ text: Optional[dict] = None,
172
+ freeze_audio_encoder: bool = False,
173
+ freeze_text_encoder: bool = False,
174
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
175
+ """
176
+ Args:
177
+ audio:
178
+ A 3-D tensor of shape (N, T, C).
179
+ audio_lens:
180
+ A 1-D tensor of shape (N,). It contains the number of frames in `x`
181
+ before padding.
182
+ text:
183
+ A dict containing the text input ids and attention mask.
184
+ Returns:
185
+ Return the CLAP loss
186
+ """
187
+ if audio is not None:
188
+ assert audio.ndim == 3, audio.shape
189
+ assert audio_lens.ndim == 1, audio_lens.shape
190
+
191
+ audio_encoder_out = self.forward_audio_encoder(
192
+ audio, audio_lens, freeze_encoder=freeze_audio_encoder
193
+ )
194
+ audio_encoder_out = self.audio_projection(audio_encoder_out)
195
+ audio_encoder_out = self.audio_transform(audio_encoder_out)
196
+ audio_encoder_out = F.normalize(audio_encoder_out, dim=-1)
197
+
198
+ if text is not None:
199
+ assert text["input_ids"].ndim == 2, text["input_ids"].shape
200
+
201
+ text_encoder_out = self.forward_text_encoder(
202
+ text, freeze_encoder=freeze_text_encoder
203
+ )
204
+ text_encoder_out = self.text_projection(text_encoder_out)
205
+ text_encoder_out = self.text_transform(text_encoder_out)
206
+ text_encoder_out = F.normalize(text_encoder_out, dim=-1)
207
+
208
+ return (
209
+ audio_encoder_out if audio is not None else None,
210
+ text_encoder_out if text is not None else None,
211
+ self.logit_scale.exp(),
212
+ )
213
+
214
+
215
+ def _to_int_tuple(s: str):
216
+ return tuple(map(int, s.split(",")))
217
+
218
+
219
+ def make_pad_mask(
220
+ lengths: torch.Tensor,
221
+ max_len: int = 0,
222
+ pad_left: bool = False,
223
+ ) -> torch.Tensor:
224
+ """
225
+ Args:
226
+ lengths:
227
+ A 1-D tensor containing sentence lengths.
228
+ max_len:
229
+ The length of masks.
230
+ pad_left:
231
+ If ``False`` (default), padding is on the right.
232
+ If ``True``, padding is on the left.
233
+ Returns:
234
+ Return a 2-D bool tensor, where masked positions
235
+ are filled with `True` and non-masked positions are
236
+ filled with `False`.
237
+
238
+ >>> lengths = torch.tensor([1, 3, 2, 5])
239
+ >>> make_pad_mask(lengths)
240
+ tensor([[False, True, True, True, True],
241
+ [False, False, False, True, True],
242
+ [False, False, True, True, True],
243
+ [False, False, False, False, False]])
244
+ """
245
+ assert lengths.ndim == 1, lengths.ndim
246
+ max_len = max(max_len, lengths.max())
247
+ n = lengths.size(0)
248
+ seq_range = torch.arange(0, max_len, device=lengths.device)
249
+ expanded_lengths = seq_range.unsqueeze(0).expand(n, max_len)
250
+
251
+ if pad_left:
252
+ mask = expanded_lengths < (max_len - lengths).unsqueeze(1)
253
+ else:
254
+ mask = expanded_lengths >= lengths.unsqueeze(-1)
255
+
256
+ return mask
257
+
258
+
259
+ def get_encoder_embed(config: CLSPConfig) -> nn.Module:
260
+ encoder_embed = Conv2dSubsampling(
261
+ in_channels=config.feature_dim,
262
+ out_channels=_to_int_tuple(config.encoder_dim)[0],
263
+ )
264
+ return encoder_embed
265
+
266
+
267
+ def get_encoder_model(config: CLSPConfig) -> nn.Module:
268
+ encoder = Zipformer2(
269
+ output_downsampling_factor=config.output_downsampling_factor,
270
+ downsampling_factor=_to_int_tuple(config.downsampling_factor),
271
+ num_encoder_layers=_to_int_tuple(config.num_encoder_layers),
272
+ encoder_dim=_to_int_tuple(config.encoder_dim),
273
+ encoder_unmasked_dim=_to_int_tuple(config.encoder_unmasked_dim),
274
+ query_head_dim=_to_int_tuple(config.query_head_dim),
275
+ pos_head_dim=_to_int_tuple(config.pos_head_dim),
276
+ value_head_dim=_to_int_tuple(config.value_head_dim),
277
+ pos_dim=config.pos_dim,
278
+ num_heads=_to_int_tuple(config.num_heads),
279
+ feedforward_dim=_to_int_tuple(config.feedforward_dim),
280
+ cnn_module_kernel=_to_int_tuple(config.cnn_module_kernel),
281
+ causal=config.causal,
282
+ chunk_size=_to_int_tuple(config.chunk_size),
283
+ left_context_frames=_to_int_tuple(config.left_context_frames),
284
+ )
285
+ return encoder
286
+
287
+
288
+ def get_model(config: CLSPConfig) -> nn.Module:
289
+ encoder_embed = get_encoder_embed(config)
290
+ encoder = get_encoder_model(config)
291
+ model = CLAP(
292
+ encoder_embed=encoder_embed,
293
+ encoder=encoder,
294
+ encoder_dim=max(_to_int_tuple(config.encoder_dim)),
295
+ text_encoder_dim=config.text_encoder_dim,
296
+ joint_dim=config.joint_dim,
297
+ )
298
+ return model
modular_clsp.py ADDED
@@ -0,0 +1,1911 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey)
2
+ #
3
+ # See ../../../../LICENSE for clarification regarding multiple authors
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+
18
+ import logging
19
+ import math
20
+ import random
21
+ from typing import Optional, Tuple, Union
22
+
23
+ # import k2
24
+ import torch
25
+ import torch.nn as nn
26
+ from torch import Tensor
27
+ from torch.cuda.amp import custom_bwd, custom_fwd
28
+
29
+
30
+ def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor:
31
+ max_value = torch.max(x, y)
32
+ diff = torch.abs(x - y)
33
+ return max_value + torch.log1p(torch.exp(-diff))
34
+
35
+
36
+ # RuntimeError: Exporting the operator logaddexp to ONNX opset version
37
+ # 14 is not supported. Please feel free to request support or submit
38
+ # a pull request on PyTorch GitHub.
39
+ #
40
+ # The following function is to solve the above error when exporting
41
+ # models to ONNX via torch.jit.trace()
42
+ def logaddexp(x: Tensor, y: Tensor) -> Tensor:
43
+ # Caution(fangjun): Put torch.jit.is_scripting() before
44
+ # torch.onnx.is_in_onnx_export();
45
+ # otherwise, it will cause errors for torch.jit.script().
46
+ #
47
+ # torch.logaddexp() works for both torch.jit.script() and
48
+ # torch.jit.trace() but it causes errors for ONNX export.
49
+ #
50
+ if torch.jit.is_scripting():
51
+ # Note: We cannot use torch.jit.is_tracing() here as it also
52
+ # matches torch.onnx.export().
53
+ return torch.logaddexp(x, y)
54
+ elif torch.onnx.is_in_onnx_export():
55
+ return logaddexp_onnx(x, y)
56
+ else:
57
+ # for torch.jit.trace()
58
+ return torch.logaddexp(x, y)
59
+
60
+
61
+ class PiecewiseLinear(object):
62
+ """
63
+ Piecewise linear function, from float to float, specified as nonempty list of (x,y) pairs with
64
+ the x values in order. x values <[initial x] or >[final x] are map to [initial y], [final y]
65
+ respectively.
66
+ """
67
+
68
+ def __init__(self, *args):
69
+ assert len(args) >= 1, len(args)
70
+ if len(args) == 1 and isinstance(args[0], PiecewiseLinear):
71
+ self.pairs = list(args[0].pairs)
72
+ else:
73
+ self.pairs = [(float(x), float(y)) for x, y in args]
74
+ for x, y in self.pairs:
75
+ assert isinstance(x, (float, int)), type(x)
76
+ assert isinstance(y, (float, int)), type(y)
77
+
78
+ for i in range(len(self.pairs) - 1):
79
+ assert self.pairs[i + 1][0] > self.pairs[i][0], (
80
+ i,
81
+ self.pairs[i],
82
+ self.pairs[i + 1],
83
+ )
84
+
85
+ def __str__(self):
86
+ # e.g. 'PiecewiseLinear((0., 10.), (100., 0.))'
87
+ return f"PiecewiseLinear({str(self.pairs)[1:-1]})"
88
+
89
+ def __call__(self, x):
90
+ if x <= self.pairs[0][0]:
91
+ return self.pairs[0][1]
92
+ elif x >= self.pairs[-1][0]:
93
+ return self.pairs[-1][1]
94
+ else:
95
+ cur_x, cur_y = self.pairs[0]
96
+ for i in range(1, len(self.pairs)):
97
+ next_x, next_y = self.pairs[i]
98
+ if x >= cur_x and x <= next_x:
99
+ return cur_y + (next_y - cur_y) * (x - cur_x) / (next_x - cur_x)
100
+ cur_x, cur_y = next_x, next_y
101
+ assert False
102
+
103
+ def __mul__(self, alpha):
104
+ return PiecewiseLinear(*[(x, y * alpha) for x, y in self.pairs])
105
+
106
+ def __add__(self, x):
107
+ if isinstance(x, (float, int)):
108
+ return PiecewiseLinear(*[(p[0], p[1] + x) for p in self.pairs])
109
+ s, x = self.get_common_basis(x)
110
+ return PiecewiseLinear(
111
+ *[(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)]
112
+ )
113
+
114
+ def max(self, x):
115
+ if isinstance(x, (float, int)):
116
+ x = PiecewiseLinear((0, x))
117
+ s, x = self.get_common_basis(x, include_crossings=True)
118
+ return PiecewiseLinear(
119
+ *[(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]
120
+ )
121
+
122
+ def min(self, x):
123
+ if isinstance(x, float) or isinstance(x, int):
124
+ x = PiecewiseLinear((0, x))
125
+ s, x = self.get_common_basis(x, include_crossings=True)
126
+ return PiecewiseLinear(
127
+ *[(sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]
128
+ )
129
+
130
+ def __eq__(self, other):
131
+ return self.pairs == other.pairs
132
+
133
+ def get_common_basis(self, p: "PiecewiseLinear", include_crossings: bool = False):
134
+ """
135
+ Returns (self_mod, p_mod) which are equivalent piecewise linear
136
+ functions to self and p, but with the same x values.
137
+
138
+ p: the other piecewise linear function
139
+ include_crossings: if true, include in the x values positions
140
+ where the functions indicate by this and p cross.
141
+ """
142
+ assert isinstance(p, PiecewiseLinear), type(p)
143
+
144
+ # get sorted x-values without repetition.
145
+ x_vals = sorted(set([x for x, _ in self.pairs] + [x for x, _ in p.pairs]))
146
+ y_vals1 = [self(x) for x in x_vals]
147
+ y_vals2 = [p(x) for x in x_vals]
148
+
149
+ if include_crossings:
150
+ extra_x_vals = []
151
+ for i in range(len(x_vals) - 1):
152
+ if (y_vals1[i] > y_vals2[i]) != (y_vals1[i + 1] > y_vals2[i + 1]):
153
+ # if the two lines in this subsegment potentially cross each other..
154
+ diff_cur = abs(y_vals1[i] - y_vals2[i])
155
+ diff_next = abs(y_vals1[i + 1] - y_vals2[i + 1])
156
+ # `pos`, between 0 and 1, gives the relative x position,
157
+ # with 0 being x_vals[i] and 1 being x_vals[i+1].
158
+ pos = diff_cur / (diff_cur + diff_next)
159
+ extra_x_val = x_vals[i] + pos * (x_vals[i + 1] - x_vals[i])
160
+ extra_x_vals.append(extra_x_val)
161
+ if len(extra_x_vals) > 0:
162
+ x_vals = sorted(set(x_vals + extra_x_vals))
163
+
164
+ y_vals1 = [self(x) for x in x_vals]
165
+ y_vals2 = [p(x) for x in x_vals]
166
+
167
+ return (
168
+ PiecewiseLinear(*zip(x_vals, y_vals1)),
169
+ PiecewiseLinear(*zip(x_vals, y_vals2)),
170
+ )
171
+
172
+
173
+ class ScheduledFloat(torch.nn.Module):
174
+ """
175
+ This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
176
+ it does not have a working forward() function. You are supposed to cast it to float, as
177
+ in, float(parent_module.whatever), and use it as something like a dropout prob.
178
+
179
+ It is a floating point value whose value changes depending on the batch count of the
180
+ training loop. It is a piecewise linear function where you specify the (x,y) pairs
181
+ in sorted order on x; x corresponds to the batch index. For batch-index values before the
182
+ first x or after the last x, we just use the first or last y value.
183
+
184
+ Example:
185
+ self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
186
+
187
+ `default` is used when self.batch_count is not set or not in training mode or in
188
+ torch.jit scripting mode.
189
+ """
190
+
191
+ def __init__(self, *args, default: float = 0.0):
192
+ super().__init__()
193
+ # self.batch_count and self.name will be written to in the training loop.
194
+ self.batch_count = None
195
+ self.name = None
196
+ self.default = default
197
+ self.schedule = PiecewiseLinear(*args)
198
+
199
+ def extra_repr(self) -> str:
200
+ return (
201
+ f"batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}"
202
+ )
203
+
204
+ def __float__(self):
205
+ batch_count = self.batch_count
206
+ if (
207
+ batch_count is None
208
+ or not self.training
209
+ or torch.jit.is_scripting()
210
+ or torch.jit.is_tracing()
211
+ ):
212
+ return float(self.default)
213
+ else:
214
+ ans = self.schedule(self.batch_count)
215
+ if random.random() < 0.0002:
216
+ logging.info(
217
+ f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}"
218
+ )
219
+ return ans
220
+
221
+ def __add__(self, x):
222
+ if isinstance(x, float) or isinstance(x, int):
223
+ return ScheduledFloat(self.schedule + x, default=self.default)
224
+ else:
225
+ return ScheduledFloat(
226
+ self.schedule + x.schedule, default=self.default + x.default
227
+ )
228
+
229
+ def max(self, x):
230
+ if isinstance(x, float) or isinstance(x, int):
231
+ return ScheduledFloat(self.schedule.max(x), default=self.default)
232
+ else:
233
+ return ScheduledFloat(
234
+ self.schedule.max(x.schedule), default=max(self.default, x.default)
235
+ )
236
+
237
+
238
+ FloatLike = Union[float, ScheduledFloat]
239
+
240
+
241
+ def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor:
242
+ """
243
+ A randomized way of casting a floating point value to half precision.
244
+ """
245
+ if x.dtype == torch.float16:
246
+ return x
247
+ x_abs = x.abs()
248
+ is_too_small = x_abs < min_abs
249
+ # for elements where is_too_small is true, random_val will contain +-min_abs with
250
+ # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations,
251
+ # for those elements].
252
+ random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
253
+ return torch.where(is_too_small, random_val, x).to(torch.float16)
254
+
255
+
256
+ class CutoffEstimator:
257
+ """
258
+ Estimates cutoffs of an arbitrary numerical quantity such that a specified
259
+ proportion of items will be above the cutoff on average.
260
+
261
+ p is the proportion of items that should be above the cutoff.
262
+ """
263
+
264
+ def __init__(self, p: float):
265
+ self.p = p
266
+ # total count of items
267
+ self.count = 0
268
+ # total count of items that were above the cutoff
269
+ self.count_above = 0
270
+ # initial cutoff value
271
+ self.cutoff = 0
272
+
273
+ def __call__(self, x: float) -> bool:
274
+ """
275
+ Returns true if x is above the cutoff.
276
+ """
277
+ ans = x > self.cutoff
278
+ self.count += 1
279
+ if ans:
280
+ self.count_above += 1
281
+ cur_p = self.count_above / self.count
282
+ delta_p = cur_p - self.p
283
+ if (delta_p > 0) == ans:
284
+ q = abs(delta_p)
285
+ self.cutoff = x * q + self.cutoff * (1 - q)
286
+ return ans
287
+
288
+
289
+ class SoftmaxFunction(torch.autograd.Function):
290
+ """
291
+ Tries to handle half-precision derivatives in a randomized way that should
292
+ be more accurate for training than the default behavior.
293
+ """
294
+
295
+ @staticmethod
296
+ def forward(ctx, x: Tensor, dim: int):
297
+ ans = x.softmax(dim=dim)
298
+ # if x dtype is float16, x.softmax() returns a float32 because
299
+ # (presumably) that op does not support float16, and autocast
300
+ # is enabled.
301
+ if torch.is_autocast_enabled():
302
+ ans = ans.to(torch.get_autocast_gpu_dtype())
303
+ ctx.save_for_backward(ans)
304
+ ctx.x_dtype = x.dtype
305
+ ctx.dim = dim
306
+ return ans
307
+
308
+ @staticmethod
309
+ def backward(ctx, ans_grad: Tensor):
310
+ (ans,) = ctx.saved_tensors
311
+ with torch_autocast(enabled=False):
312
+ ans_grad = ans_grad.to(torch.float32)
313
+ ans = ans.to(torch.float32)
314
+ x_grad = ans_grad * ans
315
+ x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
316
+ return x_grad, None
317
+
318
+
319
+ def softmax(x: Tensor, dim: int):
320
+ if not x.requires_grad or torch.jit.is_scripting() or torch.jit.is_tracing():
321
+ return x.softmax(dim=dim)
322
+
323
+ return SoftmaxFunction.apply(x, dim)
324
+
325
+
326
+ class MaxEigLimiterFunction(torch.autograd.Function):
327
+ @staticmethod
328
+ def forward(
329
+ ctx,
330
+ x: Tensor,
331
+ coeffs: Tensor,
332
+ direction: Tensor,
333
+ channel_dim: int,
334
+ grad_scale: float,
335
+ ) -> Tensor:
336
+ ctx.channel_dim = channel_dim
337
+ ctx.grad_scale = grad_scale
338
+ ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach())
339
+ return x
340
+
341
+ @staticmethod
342
+ def backward(ctx, x_grad, *args):
343
+ with torch.enable_grad():
344
+ (x_orig, coeffs, new_direction) = ctx.saved_tensors
345
+ x_orig.requires_grad = True
346
+ num_channels = x_orig.shape[ctx.channel_dim]
347
+ x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
348
+ new_direction.requires_grad = False
349
+ x = x - x.mean(dim=0)
350
+ x_var = (x**2).mean()
351
+ x_residual = x - coeffs * new_direction
352
+ x_residual_var = (x_residual**2).mean()
353
+ # `variance_proportion` is the proportion of the variance accounted for
354
+ # by the top eigen-direction. This is to be minimized.
355
+ variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
356
+ variance_proportion.backward()
357
+ x_orig_grad = x_orig.grad
358
+ x_extra_grad = (
359
+ x_orig.grad
360
+ * ctx.grad_scale
361
+ * x_grad.norm()
362
+ / (x_orig_grad.norm() + 1.0e-20)
363
+ )
364
+ return x_grad + x_extra_grad.detach(), None, None, None, None
365
+
366
+
367
+ class BiasNormFunction(torch.autograd.Function):
368
+ # This computes:
369
+ # scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp()
370
+ # return x * scales
371
+ # (after unsqueezing the bias), but it does it in a memory-efficient way so that
372
+ # it can just store the returned value (chances are, this will also be needed for
373
+ # some other reason, related to the next operation, so we can save memory).
374
+ @staticmethod
375
+ def forward(
376
+ ctx,
377
+ x: Tensor,
378
+ bias: Tensor,
379
+ log_scale: Tensor,
380
+ channel_dim: int,
381
+ store_output_for_backprop: bool,
382
+ ) -> Tensor:
383
+ assert bias.ndim == 1
384
+ if channel_dim < 0:
385
+ channel_dim = channel_dim + x.ndim
386
+ ctx.store_output_for_backprop = store_output_for_backprop
387
+ ctx.channel_dim = channel_dim
388
+ for _ in range(channel_dim + 1, x.ndim):
389
+ bias = bias.unsqueeze(-1)
390
+ scales = (
391
+ torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5
392
+ ) * log_scale.exp()
393
+ ans = x * scales
394
+ ctx.save_for_backward(
395
+ ans.detach() if store_output_for_backprop else x,
396
+ scales.detach(),
397
+ bias.detach(),
398
+ log_scale.detach(),
399
+ )
400
+ return ans
401
+
402
+ @staticmethod
403
+ def backward(ctx, ans_grad: Tensor) -> Tensor:
404
+ ans_or_x, scales, bias, log_scale = ctx.saved_tensors
405
+ if ctx.store_output_for_backprop:
406
+ x = ans_or_x / scales
407
+ else:
408
+ x = ans_or_x
409
+ x = x.detach()
410
+ x.requires_grad = True
411
+ bias.requires_grad = True
412
+ log_scale.requires_grad = True
413
+ with torch.enable_grad():
414
+ # recompute scales from x, bias and log_scale.
415
+ scales = (
416
+ torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5
417
+ ) * log_scale.exp()
418
+ ans = x * scales
419
+ ans.backward(gradient=ans_grad)
420
+ return x.grad, bias.grad.flatten(), log_scale.grad, None, None
421
+
422
+
423
+ class BiasNorm(torch.nn.Module):
424
+ """
425
+ This is intended to be a simpler, and hopefully cheaper, replacement for
426
+ LayerNorm. The observation this is based on, is that Transformer-type
427
+ networks, especially with pre-norm, sometimes seem to set one of the
428
+ feature dimensions to a large constant value (e.g. 50), which "defeats"
429
+ the LayerNorm because the output magnitude is then not strongly dependent
430
+ on the other (useful) features. Presumably the weight and bias of the
431
+ LayerNorm are required to allow it to do this.
432
+
433
+ Instead, we give the BiasNorm a trainable bias that it can use when
434
+ computing the scale for normalization. We also give it a (scalar)
435
+ trainable scale on the output.
436
+
437
+
438
+ Args:
439
+ num_channels: the number of channels, e.g. 512.
440
+ channel_dim: the axis/dimension corresponding to the channel,
441
+ interpreted as an offset from the input's ndim if negative.
442
+ This is NOT the num_channels; it should typically be one of
443
+ {-2, -1, 0, 1, 2, 3}.
444
+ log_scale: the initial log-scale that we multiply the output by; this
445
+ is learnable.
446
+ log_scale_min: FloatLike, minimum allowed value of log_scale
447
+ log_scale_max: FloatLike, maximum allowed value of log_scale
448
+ store_output_for_backprop: only possibly affects memory use; recommend
449
+ to set to True if you think the output of this module is more likely
450
+ than the input of this module to be required to be stored for the
451
+ backprop.
452
+ """
453
+
454
+ def __init__(
455
+ self,
456
+ num_channels: int,
457
+ channel_dim: int = -1, # CAUTION: see documentation.
458
+ log_scale: float = 1.0,
459
+ log_scale_min: float = -1.5,
460
+ log_scale_max: float = 1.5,
461
+ store_output_for_backprop: bool = False,
462
+ ) -> None:
463
+ super(BiasNorm, self).__init__()
464
+ self.num_channels = num_channels
465
+ self.channel_dim = channel_dim
466
+ self.log_scale = nn.Parameter(torch.tensor(log_scale))
467
+ self.bias = nn.Parameter(torch.empty(num_channels).normal_(mean=0, std=1e-4))
468
+
469
+ self.log_scale_min = log_scale_min
470
+ self.log_scale_max = log_scale_max
471
+
472
+ self.store_output_for_backprop = store_output_for_backprop
473
+
474
+ def forward(self, x: Tensor) -> Tensor:
475
+ assert x.shape[self.channel_dim] == self.num_channels
476
+
477
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
478
+ channel_dim = self.channel_dim
479
+ if channel_dim < 0:
480
+ channel_dim += x.ndim
481
+ bias = self.bias
482
+ for _ in range(channel_dim + 1, x.ndim):
483
+ bias = bias.unsqueeze(-1)
484
+ scales = (
485
+ torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5
486
+ ) * self.log_scale.exp()
487
+ return x * scales
488
+
489
+ log_scale = limit_param_value(
490
+ self.log_scale,
491
+ min=float(self.log_scale_min),
492
+ max=float(self.log_scale_max),
493
+ training=self.training,
494
+ )
495
+
496
+ return BiasNormFunction.apply(
497
+ x, self.bias, log_scale, self.channel_dim, self.store_output_for_backprop
498
+ )
499
+
500
+
501
+ def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
502
+ """
503
+ Behaves like a constructor of a modified version of nn.Linear
504
+ that gives an easy way to set the default initial parameter scale.
505
+
506
+ Args:
507
+ Accepts the standard args and kwargs that nn.Linear accepts
508
+ e.g. in_features, out_features, bias=False.
509
+
510
+ initial_scale: you can override this if you want to increase
511
+ or decrease the initial magnitude of the module's output
512
+ (affects the initialization of weight_scale and bias_scale).
513
+ Another option, if you want to do something like this, is
514
+ to re-initialize the parameters.
515
+ """
516
+ ans = nn.Linear(*args, **kwargs)
517
+ with torch.no_grad():
518
+ ans.weight[:] *= initial_scale
519
+ if ans.bias is not None:
520
+ torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
521
+ return ans
522
+
523
+
524
+ def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d:
525
+ """
526
+ Behaves like a constructor of a modified version of nn.Conv1d
527
+ that gives an easy way to set the default initial parameter scale.
528
+
529
+ Args:
530
+ Accepts the standard args and kwargs that nn.Linear accepts
531
+ e.g. in_features, out_features, bias=False.
532
+
533
+ initial_scale: you can override this if you want to increase
534
+ or decrease the initial magnitude of the module's output
535
+ (affects the initialization of weight_scale and bias_scale).
536
+ Another option, if you want to do something like this, is
537
+ to re-initialize the parameters.
538
+ """
539
+ ans = nn.Conv1d(*args, **kwargs)
540
+ with torch.no_grad():
541
+ ans.weight[:] *= initial_scale
542
+ if ans.bias is not None:
543
+ torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
544
+ return ans
545
+
546
+
547
+ def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d:
548
+ """
549
+ Behaves like a constructor of a modified version of nn.Conv2d
550
+ that gives an easy way to set the default initial parameter scale.
551
+
552
+ Args:
553
+ Accepts the standard args and kwargs that nn.Linear accepts
554
+ e.g. in_features, out_features, bias=False, but:
555
+ NO PADDING-RELATED ARGS.
556
+
557
+ initial_scale: you can override this if you want to increase
558
+ or decrease the initial magnitude of the module's output
559
+ (affects the initialization of weight_scale and bias_scale).
560
+ Another option, if you want to do something like this, is
561
+ to re-initialize the parameters.
562
+ """
563
+ ans = nn.Conv2d(*args, **kwargs)
564
+ with torch.no_grad():
565
+ ans.weight[:] *= initial_scale
566
+ if ans.bias is not None:
567
+ torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
568
+ return ans
569
+
570
+
571
+ class ChunkCausalDepthwiseConv1d(torch.nn.Module):
572
+ """
573
+ Behaves like a depthwise 1d convolution, except that it is causal in
574
+ a chunkwise way, as if we had a block-triangular attention mask.
575
+ The chunk size is provided at test time (it should probably be
576
+ kept in sync with the attention mask).
577
+
578
+ This has a little more than twice the parameters of a conventional
579
+ depthwise conv1d module: we implement it by having one
580
+ depthwise convolution, of half the width, that is causal (via
581
+ right-padding); and one depthwise convolution that is applied only
582
+ within chunks, that we multiply by a scaling factor which depends
583
+ on the position within the chunk.
584
+
585
+ Args:
586
+ Accepts the standard args and kwargs that nn.Linear accepts
587
+ e.g. in_features, out_features, bias=False.
588
+
589
+ initial_scale: you can override this if you want to increase
590
+ or decrease the initial magnitude of the module's output
591
+ (affects the initialization of weight_scale and bias_scale).
592
+ Another option, if you want to do something like this, is
593
+ to re-initialize the parameters.
594
+ """
595
+
596
+ def __init__(
597
+ self,
598
+ channels: int,
599
+ kernel_size: int,
600
+ initial_scale: float = 1.0,
601
+ bias: bool = True,
602
+ ):
603
+ super().__init__()
604
+ assert kernel_size % 2 == 1
605
+
606
+ half_kernel_size = (kernel_size + 1) // 2
607
+ # will pad manually, on one side.
608
+ self.causal_conv = nn.Conv1d(
609
+ in_channels=channels,
610
+ out_channels=channels,
611
+ groups=channels,
612
+ kernel_size=half_kernel_size,
613
+ padding=0,
614
+ bias=True,
615
+ )
616
+
617
+ self.chunkwise_conv = nn.Conv1d(
618
+ in_channels=channels,
619
+ out_channels=channels,
620
+ groups=channels,
621
+ kernel_size=kernel_size,
622
+ padding=kernel_size // 2,
623
+ bias=bias,
624
+ )
625
+
626
+ # first row is correction factors added to the scale near the left edge of the chunk,
627
+ # second row is correction factors added to the scale near the right edge of the chunk,
628
+ # both of these are added to a default scale of 1.0.
629
+ self.chunkwise_conv_scale = nn.Parameter(torch.zeros(2, channels, kernel_size))
630
+ self.kernel_size = kernel_size
631
+
632
+ with torch.no_grad():
633
+ self.causal_conv.weight[:] *= initial_scale
634
+ self.chunkwise_conv.weight[:] *= initial_scale
635
+ if bias:
636
+ torch.nn.init.uniform_(
637
+ self.causal_conv.bias, -0.1 * initial_scale, 0.1 * initial_scale
638
+ )
639
+
640
+ def forward(self, x: Tensor, chunk_size: int = -1) -> Tensor:
641
+ """Forward function.
642
+
643
+ Args:
644
+ x: a Tensor of shape (batch_size, channels, seq_len)
645
+ chunk_size: the chunk size, in frames; does not have to divide seq_len exactly.
646
+ """
647
+ (batch_size, num_channels, seq_len) = x.shape
648
+
649
+ # half_kernel_size = self.kernel_size + 1 // 2
650
+ # left_pad is half_kernel_size - 1 where half_kernel_size is the size used
651
+ # in the causal conv. It's the amount by which we must pad on the left,
652
+ # to make the convolution causal.
653
+ left_pad = self.kernel_size // 2
654
+
655
+ if chunk_size < 0 or chunk_size > seq_len:
656
+ chunk_size = seq_len
657
+ right_pad = -seq_len % chunk_size
658
+
659
+ x = torch.nn.functional.pad(x, (left_pad, right_pad))
660
+
661
+ x_causal = self.causal_conv(x[..., : left_pad + seq_len])
662
+ assert x_causal.shape == (batch_size, num_channels, seq_len)
663
+
664
+ x_chunk = x[..., left_pad:]
665
+ num_chunks = x_chunk.shape[2] // chunk_size
666
+ x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks, chunk_size)
667
+ x_chunk = x_chunk.permute(0, 2, 1, 3).reshape(
668
+ batch_size * num_chunks, num_channels, chunk_size
669
+ )
670
+ x_chunk = self.chunkwise_conv(x_chunk) # does not change shape
671
+
672
+ chunk_scale = self._get_chunk_scale(chunk_size)
673
+
674
+ x_chunk = x_chunk * chunk_scale
675
+ x_chunk = x_chunk.reshape(
676
+ batch_size, num_chunks, num_channels, chunk_size
677
+ ).permute(0, 2, 1, 3)
678
+ x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks * chunk_size)[
679
+ ..., :seq_len
680
+ ]
681
+
682
+ return x_chunk + x_causal
683
+
684
+ def _get_chunk_scale(self, chunk_size: int):
685
+ """Returns tensor of shape (num_channels, chunk_size) that will be used to
686
+ scale the output of self.chunkwise_conv."""
687
+ left_edge = self.chunkwise_conv_scale[0]
688
+ right_edge = self.chunkwise_conv_scale[1]
689
+ if chunk_size < self.kernel_size:
690
+ left_edge = left_edge[:, :chunk_size]
691
+ right_edge = right_edge[:, -chunk_size:]
692
+ else:
693
+ t = chunk_size - self.kernel_size
694
+ channels = left_edge.shape[0]
695
+ pad = torch.zeros(
696
+ channels, t, device=left_edge.device, dtype=left_edge.dtype
697
+ )
698
+ left_edge = torch.cat((left_edge, pad), dim=-1)
699
+ right_edge = torch.cat((pad, right_edge), dim=-1)
700
+ return 1.0 + (left_edge + right_edge)
701
+
702
+ def streaming_forward(
703
+ self,
704
+ x: Tensor,
705
+ cache: Tensor,
706
+ ) -> Tuple[Tensor, Tensor]:
707
+ """Streaming Forward function.
708
+
709
+ Args:
710
+ x: a Tensor of shape (batch_size, channels, seq_len)
711
+ cache: cached left context of shape (batch_size, channels, left_pad)
712
+ """
713
+ (batch_size, num_channels, seq_len) = x.shape
714
+
715
+ # left_pad is half_kernel_size - 1 where half_kernel_size is the size used
716
+ # in the causal conv. It's the amount by which we must pad on the left,
717
+ # to make the convolution causal.
718
+ left_pad = self.kernel_size // 2
719
+
720
+ # Pad cache
721
+ assert cache.shape[-1] == left_pad, (cache.shape[-1], left_pad)
722
+ x = torch.cat([cache, x], dim=2)
723
+ # Update cache
724
+ cache = x[..., -left_pad:]
725
+
726
+ x_causal = self.causal_conv(x)
727
+ assert x_causal.shape == (batch_size, num_channels, seq_len)
728
+
729
+ x_chunk = x[..., left_pad:]
730
+ x_chunk = self.chunkwise_conv(x_chunk) # does not change shape
731
+
732
+ chunk_scale = self._get_chunk_scale(chunk_size=seq_len)
733
+ x_chunk = x_chunk * chunk_scale
734
+
735
+ return x_chunk + x_causal, cache
736
+
737
+
738
+ class BalancerFunction(torch.autograd.Function):
739
+ @staticmethod
740
+ def forward(
741
+ ctx,
742
+ x: Tensor,
743
+ min_mean: float,
744
+ max_mean: float,
745
+ min_rms: float,
746
+ max_rms: float,
747
+ grad_scale: float,
748
+ channel_dim: int,
749
+ ) -> Tensor:
750
+ if channel_dim < 0:
751
+ channel_dim += x.ndim
752
+ ctx.channel_dim = channel_dim
753
+ ctx.save_for_backward(x)
754
+ ctx.config = (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim)
755
+ return x
756
+
757
+ @staticmethod
758
+ def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]:
759
+ (x,) = ctx.saved_tensors
760
+ (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config
761
+
762
+ try:
763
+ with torch.enable_grad():
764
+ with torch_autocast(enabled=False):
765
+ x = x.to(torch.float32)
766
+ x = x.detach()
767
+ x.requires_grad = True
768
+ mean_dims = [i for i in range(x.ndim) if i != channel_dim]
769
+ uncentered_var = (x**2).mean(dim=mean_dims, keepdim=True)
770
+ mean = x.mean(dim=mean_dims, keepdim=True)
771
+ stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt()
772
+ rms = uncentered_var.clamp(min=1.0e-20).sqrt()
773
+
774
+ m = mean / stddev
775
+ # part of loss that relates to mean / stddev
776
+ m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs()
777
+
778
+ # put a much larger scale on the RMS-max-limit loss, so that if both it and the
779
+ # m_loss are violated we fix the RMS loss first.
780
+ rms_clamped = rms.clamp(min=min_rms, max=max_rms)
781
+ r_loss = (rms_clamped / rms).log().abs()
782
+
783
+ loss = m_loss + r_loss
784
+
785
+ loss.backward(gradient=torch.ones_like(loss))
786
+ loss_grad = x.grad
787
+ loss_grad_rms = (
788
+ (loss_grad**2)
789
+ .mean(dim=mean_dims, keepdim=True)
790
+ .sqrt()
791
+ .clamp(min=1.0e-20)
792
+ )
793
+
794
+ loss_grad = loss_grad * (grad_scale / loss_grad_rms)
795
+
796
+ x_grad_float = x_grad.to(torch.float32)
797
+ # scale each element of loss_grad by the absolute value of the corresponding
798
+ # element of x_grad, which we view as a noisy estimate of its magnitude for that
799
+ # (frame and dimension). later we can consider factored versions.
800
+ x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad)
801
+ x_grad = x_grad_mod.to(x_grad.dtype)
802
+ except Exception as e:
803
+ logging.info(
804
+ f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue."
805
+ )
806
+
807
+ return x_grad, None, None, None, None, None, None
808
+
809
+
810
+ class Balancer(torch.nn.Module):
811
+ """
812
+ Modifies the backpropped derivatives of a function to try to encourage, for
813
+ each channel, that it is positive at least a proportion `threshold` of the
814
+ time. It does this by multiplying negative derivative values by up to
815
+ (1+max_factor), and positive derivative values by up to (1-max_factor),
816
+ interpolated from 1 at the threshold to those extremal values when none
817
+ of the inputs are positive.
818
+
819
+ Args:
820
+ num_channels: the number of channels
821
+ channel_dim: the dimension/axis corresponding to the channel, e.g.
822
+ -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
823
+ min_positive: the minimum, per channel, of the proportion of the time
824
+ that (x > 0), below which we start to modify the derivatives.
825
+ max_positive: the maximum, per channel, of the proportion of the time
826
+ that (x > 0), above which we start to modify the derivatives.
827
+ scale_gain_factor: determines the 'gain' with which we increase the
828
+ change in gradient once the constraints on min_abs and max_abs
829
+ are violated.
830
+ min_abs: the minimum average-absolute-value difference from the mean
831
+ value per channel, which we allow, before we start to modify
832
+ the derivatives to prevent this.
833
+ max_abs: the maximum average-absolute-value difference from the mean
834
+ value per channel, which we allow, before we start to modify
835
+ the derivatives to prevent this.
836
+ prob: determines the minimum probability with which we modify the
837
+ gradients for the {min,max}_positive and {min,max}_abs constraints,
838
+ on each forward(). This is done randomly to prevent all layers
839
+ from doing it at the same time.
840
+ """
841
+
842
+ def __init__(
843
+ self,
844
+ num_channels: int,
845
+ channel_dim: int,
846
+ min_positive: FloatLike = 0.05,
847
+ max_positive: FloatLike = 0.95,
848
+ min_abs: FloatLike = 0.2,
849
+ max_abs: FloatLike = 100.0,
850
+ grad_scale: FloatLike = 0.04,
851
+ prob: Optional[FloatLike] = None,
852
+ ):
853
+ super().__init__()
854
+
855
+ if prob is None:
856
+ prob = ScheduledFloat((0.0, 0.5), (8000.0, 0.125), default=0.4)
857
+ self.prob = prob
858
+ # 5% of the time we will return and do nothing because memory usage is
859
+ # too high.
860
+ self.mem_cutoff = CutoffEstimator(0.05)
861
+
862
+ # actually self.num_channels is no longer needed except for an assertion.
863
+ self.num_channels = num_channels
864
+ self.channel_dim = channel_dim
865
+ self.min_positive = min_positive
866
+ self.max_positive = max_positive
867
+ self.min_abs = min_abs
868
+ self.max_abs = max_abs
869
+ self.grad_scale = grad_scale
870
+
871
+ def forward(self, x: Tensor) -> Tensor:
872
+ if (
873
+ torch.jit.is_scripting()
874
+ or not x.requires_grad
875
+ or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))
876
+ ):
877
+ return _no_op(x)
878
+
879
+ prob = float(self.prob)
880
+ if random.random() < prob:
881
+ # The following inner-functions convert from the way we historically specified
882
+ # these limitations, as limits on the absolute value and the proportion of positive
883
+ # values, to limits on the RMS value and the (mean / stddev).
884
+ def _abs_to_rms(x):
885
+ # for normally distributed data, if the expected absolute value is x, the
886
+ # expected rms value will be sqrt(pi/2) * x.
887
+ return 1.25331413732 * x
888
+
889
+ def _proportion_positive_to_mean(x):
890
+ def _atanh(x):
891
+ eps = 1.0e-10
892
+ # eps is to prevent crashes if x is exactly 0 or 1.
893
+ # we'll just end up returning a fairly large value.
894
+ return (math.log(1 + x + eps) - math.log(1 - x + eps)) / 2.0
895
+
896
+ def _approx_inverse_erf(x):
897
+ # 1 / (sqrt(pi) * ln(2)),
898
+ # see https://math.stackexchange.com/questions/321569/approximating-the-error-function-erf-by-analytical-functions
899
+ # this approximation is extremely crude and gets progressively worse for
900
+ # x very close to -1 or +1, but we mostly care about the "middle" region
901
+ # e.g. _approx_inverse_erf(0.05) = 0.0407316414078772,
902
+ # and math.erf(0.0407316414078772) = 0.045935330944660666,
903
+ # which is pretty close to 0.05.
904
+ return 0.8139535143 * _atanh(x)
905
+
906
+ # first convert x from the range 0..1 to the range -1..1 which the error
907
+ # function returns
908
+ x = -1 + (2 * x)
909
+ return _approx_inverse_erf(x)
910
+
911
+ min_mean = _proportion_positive_to_mean(float(self.min_positive))
912
+ max_mean = _proportion_positive_to_mean(float(self.max_positive))
913
+ min_rms = _abs_to_rms(float(self.min_abs))
914
+ max_rms = _abs_to_rms(float(self.max_abs))
915
+ grad_scale = float(self.grad_scale)
916
+
917
+ assert x.shape[self.channel_dim] == self.num_channels
918
+
919
+ return BalancerFunction.apply(
920
+ x, min_mean, max_mean, min_rms, max_rms, grad_scale, self.channel_dim
921
+ )
922
+ else:
923
+ return _no_op(x)
924
+
925
+
926
+ def penalize_abs_values_gt(
927
+ x: Tensor, limit: float, penalty: float, name: str = None
928
+ ) -> Tensor:
929
+ """
930
+ Returns x unmodified, but in backprop will put a penalty for the excess of
931
+ the absolute values of elements of x over the limit "limit". E.g. if
932
+ limit == 10.0, then if x has any values over 10 it will get a penalty.
933
+
934
+ Caution: the value of this penalty will be affected by grad scaling used
935
+ in automatic mixed precision training. For this reasons we use this,
936
+ it shouldn't really matter, or may even be helpful; we just use this
937
+ to disallow really implausible values of scores to be given to softmax.
938
+
939
+ The name is for randomly printed debug info.
940
+ """
941
+ x_sign = x.sign()
942
+ over_limit = (x.abs() - limit) > 0
943
+ # The following is a memory efficient way to penalize the absolute values of
944
+ # x that's over the limit. (The memory efficiency comes when you think
945
+ # about which items torch needs to cache for the autograd, and which ones it
946
+ # can throw away). The numerical value of aux_loss as computed here will
947
+ # actually be larger than it should be, by limit * over_limit.sum(), but it
948
+ # has the same derivative as the real aux_loss which is penalty * (x.abs() -
949
+ # limit).relu().
950
+ aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
951
+ # note: we don't do sum() here on aux)_loss, but it's as if we had done
952
+ # sum() due to how with_loss() works.
953
+ x = with_loss(x, aux_loss, name)
954
+ # you must use x for something, or this will be ineffective.
955
+ return x
956
+
957
+
958
+ def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
959
+ if x.ndim == 2:
960
+ return x.diag()
961
+ else:
962
+ (batch, dim, dim) = x.shape
963
+ x = x.reshape(batch, dim * dim)
964
+ x = x[:, :: dim + 1]
965
+ assert x.shape == (batch, dim)
966
+ return x
967
+
968
+
969
+ def _whitening_metric(x: Tensor, num_groups: int):
970
+ """
971
+ Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of
972
+ of the centered feature covariance are the same within each group's covariance matrix
973
+ and also between groups.
974
+ Args:
975
+ x: a Tensor of shape (*, num_channels)
976
+ num_groups: the number of groups of channels, a number >=1 that divides num_channels
977
+ Returns:
978
+ Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and
979
+ greater than 1.0 otherwise.
980
+ """
981
+ assert x.dtype != torch.float16
982
+ x = x.reshape(-1, x.shape[-1])
983
+ (num_frames, num_channels) = x.shape
984
+ assert num_channels % num_groups == 0
985
+ channels_per_group = num_channels // num_groups
986
+ x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1)
987
+ # x now has shape (num_groups, num_frames, channels_per_group)
988
+ # subtract the mean so we use the centered, not uncentered, covariance.
989
+ # My experience has been that when we "mess with the gradients" like this,
990
+ # it's better not do anything that tries to move the mean around, because
991
+ # that can easily cause instability.
992
+ x = x - x.mean(dim=1, keepdim=True)
993
+ # x_covar: (num_groups, channels_per_group, channels_per_group)
994
+ x_covar = torch.matmul(x.transpose(1, 2), x)
995
+ x_covar_mean_diag = _diag(x_covar).mean()
996
+ # the following expression is what we'd get if we took the matrix product
997
+ # of each covariance and measured the mean of its trace, i.e.
998
+ # the same as _diag(torch.matmul(x_covar, x_covar)).mean().
999
+ x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group)
1000
+ # this metric will be >= 1.0; the larger it is, the less 'white' the data was.
1001
+ metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20)
1002
+ return metric
1003
+
1004
+
1005
+ class WhiteningPenaltyFunction(torch.autograd.Function):
1006
+ @staticmethod
1007
+ def forward(ctx, x: Tensor, module: nn.Module) -> Tensor:
1008
+ ctx.save_for_backward(x)
1009
+ ctx.module = module
1010
+ return x
1011
+
1012
+ @staticmethod
1013
+ def backward(ctx, x_grad: Tensor):
1014
+ (x_orig,) = ctx.saved_tensors
1015
+ w = ctx.module
1016
+
1017
+ try:
1018
+ with torch.enable_grad():
1019
+ with torch_autocast(enabled=False):
1020
+ x_detached = x_orig.to(torch.float32).detach()
1021
+ x_detached.requires_grad = True
1022
+
1023
+ metric = _whitening_metric(x_detached, w.num_groups)
1024
+
1025
+ if random.random() < 0.005 or __name__ == "__main__":
1026
+ logging.info(
1027
+ f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, "
1028
+ f"metric={metric.item():.2f} vs. limit={float(w.whitening_limit)}"
1029
+ )
1030
+
1031
+ if metric < float(w.whitening_limit):
1032
+ w.prob = w.min_prob
1033
+ return x_grad, None
1034
+ else:
1035
+ w.prob = w.max_prob
1036
+ metric.backward()
1037
+ penalty_grad = x_detached.grad
1038
+ scale = float(w.grad_scale) * (
1039
+ x_grad.to(torch.float32).norm()
1040
+ / (penalty_grad.norm() + 1.0e-20)
1041
+ )
1042
+ penalty_grad = penalty_grad * scale
1043
+ return x_grad + penalty_grad.to(x_grad.dtype), None
1044
+ except Exception as e:
1045
+ logging.info(
1046
+ f"Caught exception in Whiten backward: {e}, size={list(x_grad.shape)}, will continue."
1047
+ )
1048
+ return x_grad, None
1049
+
1050
+
1051
+ class Whiten(nn.Module):
1052
+ def __init__(
1053
+ self,
1054
+ num_groups: int,
1055
+ whitening_limit: FloatLike,
1056
+ prob: Union[float, Tuple[float, float]],
1057
+ grad_scale: FloatLike,
1058
+ ):
1059
+ """
1060
+ Args:
1061
+ num_groups: the number of groups to divide the channel dim into before
1062
+ whitening. We will attempt to make the feature covariance
1063
+ within each group, after mean subtraction, as "white" as possible,
1064
+ while having the same trace across all groups.
1065
+ whitening_limit: a value greater than 1.0, that dictates how much
1066
+ freedom we have to violate the constraints. 1.0 would mean perfectly
1067
+ white, with exactly the same trace across groups; larger values
1068
+ give more freedom. E.g. 2.0.
1069
+ prob: the probability with which we apply the gradient modification
1070
+ (also affects the grad scale). May be supplied as a float,
1071
+ or as a pair (min_prob, max_prob)
1072
+
1073
+ grad_scale: determines the scale on the gradient term from this object,
1074
+ relative to the rest of the gradient on the input being whitened.
1075
+ E.g. 0.02 (you may want to use smaller values than this if prob is large)
1076
+ """
1077
+ super(Whiten, self).__init__()
1078
+ assert num_groups >= 1
1079
+ assert float(whitening_limit) >= 1
1080
+ assert float(grad_scale) >= 0
1081
+ self.num_groups = num_groups
1082
+ self.whitening_limit = whitening_limit
1083
+ self.grad_scale = grad_scale
1084
+
1085
+ if isinstance(prob, float):
1086
+ prob = (prob, prob)
1087
+ (self.min_prob, self.max_prob) = prob
1088
+ assert 0 < self.min_prob <= self.max_prob <= 1
1089
+ self.prob = self.max_prob
1090
+ self.name = None # will be set in training loop
1091
+
1092
+ def forward(self, x: Tensor) -> Tensor:
1093
+ """
1094
+ In the forward pass, this function just returns the input unmodified.
1095
+ In the backward pass, it will modify the gradients to ensure that the
1096
+ distribution in each group has close to (lambda times I) as the covariance
1097
+ after mean subtraction, with the same lambda across groups.
1098
+ For whitening_limit > 1, there will be more freedom to violate this
1099
+ constraint.
1100
+
1101
+ Args:
1102
+ x: the input of shape (*, num_channels)
1103
+
1104
+ Returns:
1105
+ x, unmodified. You should make sure
1106
+ you use the returned value, or the graph will be freed
1107
+ and nothing will happen in backprop.
1108
+ """
1109
+ grad_scale = float(self.grad_scale)
1110
+ if not x.requires_grad or random.random() > self.prob or grad_scale == 0:
1111
+ return _no_op(x)
1112
+ else:
1113
+ return WhiteningPenaltyFunction.apply(x, self)
1114
+
1115
+
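+ # A minimal usage sketch (illustrative; the module and hyper-parameters are
+ # hypothetical): Whiten is meant to wrap an activation inside a module's
+ # forward pass, where it is an identity in the forward direction and only
+ # shapes the gradient in backprop.
+ class _DemoWhitenedBlock(nn.Module):
+     def __init__(self, dim: int = 256):
+         super().__init__()
+         self.linear = nn.Linear(dim, dim)
+         self.whiten = Whiten(
+             num_groups=1, whitening_limit=5.0, prob=(0.025, 0.25), grad_scale=0.01
+         )
+
+     def forward(self, x: Tensor) -> Tensor:
+         # keep the returned value; otherwise the penalty has no effect in backprop
+         return self.whiten(self.linear(x))
+
+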
1116
+ class WithLoss(torch.autograd.Function):
1117
+ @staticmethod
1118
+ def forward(ctx, x: Tensor, y: Tensor, name: str):
1119
+ ctx.y_shape = y.shape
1120
+ if random.random() < 0.002 and name is not None:
1121
+ loss_sum = y.sum().item()
1122
+ logging.info(f"WithLoss: name={name}, loss-sum={loss_sum:.3e}")
1123
+ return x
1124
+
1125
+ @staticmethod
1126
+ def backward(ctx, ans_grad: Tensor):
1127
+ return (
1128
+ ans_grad,
1129
+ torch.ones(ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device),
1130
+ None,
1131
+ )
1132
+
1133
+
1134
+ def with_loss(x, y, name):
1135
+ # returns x but adds y.sum() to the loss function.
1136
+ return WithLoss.apply(x, y, name)
1137
+
1138
+
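+ # A minimal sketch (illustrative; the penalty below is hypothetical): with_loss()
+ # lets an auxiliary term ride along with the main computation without changing
+ # the forward value; its sum is effectively added to the training loss.
+ def _demo_with_loss(x: Tensor) -> Tensor:
+     aux_penalty = 0.01 * (x ** 2).mean(dim=-1)
+     # forward value is x unchanged; backward also propagates d(aux_penalty.sum())
+     return with_loss(x, aux_penalty, "demo-penalty")
+
+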
1139
+ class ScaleGradFunction(torch.autograd.Function):
1140
+ @staticmethod
1141
+ def forward(ctx, x: Tensor, alpha: float) -> Tensor:
1142
+ ctx.alpha = alpha
1143
+ return x
1144
+
1145
+ @staticmethod
1146
+ def backward(ctx, grad: Tensor):
1147
+ return grad * ctx.alpha, None
1148
+
1149
+
1150
+ def scale_grad(x: Tensor, alpha: float):
1151
+ return ScaleGradFunction.apply(x, alpha)
1152
+
1153
+
1154
+ class ScaleGrad(nn.Module):
1155
+ def __init__(self, alpha: float):
1156
+ super().__init__()
1157
+ self.alpha = alpha
1158
+
1159
+ def forward(self, x: Tensor) -> Tensor:
1160
+ if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
1161
+ return x
1162
+ return scale_grad(x, self.alpha)
1163
+
1164
+
1165
+ class LimitParamValue(torch.autograd.Function):
1166
+ @staticmethod
1167
+ def forward(ctx, x: Tensor, min: float, max: float):
1168
+ ctx.save_for_backward(x)
1169
+ assert max >= min
1170
+ ctx.min = min
1171
+ ctx.max = max
1172
+ return x
1173
+
1174
+ @staticmethod
1175
+ def backward(ctx, x_grad: Tensor):
1176
+ (x,) = ctx.saved_tensors
1177
+ # where x < ctx.min, ensure all grads are negative (this will tend to make
1178
+ # x more positive).
1179
+ x_grad = x_grad * torch.where(
1180
+ torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0
1181
+ )
1182
+ # where x > ctx.max, ensure all grads are positive (this will tend to make
1183
+ # x more negative).
1184
+ x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0)
1185
+ return x_grad, None, None
1186
+
1187
+
1188
+ def limit_param_value(
1189
+ x: Tensor, min: float, max: float, prob: float = 0.6, training: bool = True
1190
+ ):
1191
+ # You apply this to (typically) an nn.Parameter during training to ensure that its
1192
+ # elements (mostly) stay within a supplied range. This is done by modifying the
1193
+ # gradients in backprop.
1194
+ # It's not necessary to do this on every batch: do it only some of the time,
1195
+ # to save a little time.
1196
+ if training and random.random() < prob:
1197
+ return LimitParamValue.apply(x, min, max)
1198
+ else:
1199
+ return x
1200
+
1201
+
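+ # A minimal sketch (illustrative; the module and range are hypothetical): apply
+ # limit_param_value() to a parameter inside forward() so that its elements are
+ # nudged back into [min, max] through the modified gradients during training.
+ class _DemoLimitedScale(nn.Module):
+     def __init__(self, dim: int = 256):
+         super().__init__()
+         self.scale = nn.Parameter(torch.ones(dim))
+
+     def forward(self, x: Tensor) -> Tensor:
+         scale = limit_param_value(
+             self.scale, min=0.1, max=4.0, training=self.training
+         )
+         return x * scale
+
+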
1202
+ def _no_op(x: Tensor) -> Tensor:
1203
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
1204
+ return x
1205
+ else:
1206
+ # a no-op function that will have a node in the autograd graph,
1207
+ # to avoid certain bugs relating to backward hooks
1208
+ return x.chunk(1, dim=-1)[0]
1209
+
1210
+
1211
+ class Identity(torch.nn.Module):
1212
+ def __init__(self):
1213
+ super(Identity, self).__init__()
1214
+
1215
+ def forward(self, x):
1216
+ return _no_op(x)
1217
+
1218
+
1219
+ class DoubleSwishFunction(torch.autograd.Function):
1220
+ """
1221
+ double_swish(x) = x * torch.sigmoid(x-1)
1222
+
1223
+ This is a definition, originally motivated by its close numerical
1224
+ similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).
1225
+
1226
+ Memory-efficient derivative computation:
1227
+ double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
1228
+ double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
1229
+ Now, s'(x) = s(x) * (1-s(x)).
1230
+ double_swish'(x) = x * s'(x) + s(x).
1231
+ = x * s(x) * (1-s(x)) + s(x).
1232
+ = double_swish(x) * (1-s(x)) + s(x)
1233
+ ... so we just need to remember s(x) but not x itself.
1234
+ """
1235
+
1236
+ @staticmethod
1237
+ def forward(ctx, x: Tensor) -> Tensor:
1238
+ requires_grad = x.requires_grad
1239
+ if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
1240
+ x = x.to(torch.float32)
1241
+
1242
+ s = torch.sigmoid(x - 1.0)
1243
+ y = x * s
1244
+
1245
+ if requires_grad:
1246
+ deriv = y * (1 - s) + s
1247
+
1248
+ # notes on derivative of x * sigmoid(x - 1):
1249
+ # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
1250
+ # min \simeq -0.043638. Take floor as -0.044 so it's a lower bound.
1251
+ # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound.
1252
+ # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which
1253
+ # floors), should be expectation-preserving.
1254
+ floor = -0.044
1255
+ ceil = 1.2
1256
+ d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
1257
+ deriv
1258
+ )
1259
+ if __name__ == "__main__":
1260
+ # for self-testing only.
1261
+ assert d_scaled.min() >= 0.0
1262
+ assert d_scaled.max() < 256.0
1263
+ d_int = d_scaled.to(torch.uint8)
1264
+ ctx.save_for_backward(d_int)
1265
+ if x.dtype == torch.float16 or torch.is_autocast_enabled():
1266
+ y = y.to(torch.float16)
1267
+ return y
1268
+
1269
+ @staticmethod
1270
+ def backward(ctx, y_grad: Tensor) -> Tensor:
1271
+ (d,) = ctx.saved_tensors
1272
+ # the same constants as used in forward pass.
1273
+ floor = -0.044
1274
+ ceil = 1.2
1275
+
1276
+ d = d * ((ceil - floor) / 255.0) + floor
1277
+ return y_grad * d
1278
+
1279
+
1280
+ class DoubleSwish(torch.nn.Module):
1281
+ def __init__(self):
1282
+ super().__init__()
1283
+
1284
+ def forward(self, x: Tensor) -> Tensor:
1285
+ """Return double-swish activation function which is an approximation to Swish(Swish(x)),
1286
+ that we approximate closely with x * sigmoid(x-1).
1287
+ """
1288
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
1289
+ return x * torch.sigmoid(x - 1.0)
1290
+ return DoubleSwishFunction.apply(x)
1291
+
1292
+
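+ # A minimal sketch (illustrative) of the approximation mentioned above:
+ # x * sigmoid(x - 1) stays close to swish(swish(x)), where swish(x) = x * sigmoid(x),
+ # over the range where activations typically live.
+ def _demo_double_swish_approx():
+     x = torch.linspace(-4.0, 4.0, steps=101)
+
+     def swish(t: Tensor) -> Tensor:
+         return t * torch.sigmoid(t)
+
+     diff = (DoubleSwish()(x) - swish(swish(x))).abs().max()
+     print(f"max |double_swish(x) - swish(swish(x))| on [-4, 4]: {diff.item():.3f}")
+
+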
1293
+ # Dropout2 is just like normal dropout, except it supports schedules on the dropout rates.
1294
+ class Dropout2(nn.Module):
1295
+ def __init__(self, p: FloatLike):
1296
+ super().__init__()
1297
+ self.p = p
1298
+
1299
+ def forward(self, x: Tensor) -> Tensor:
1300
+ return torch.nn.functional.dropout(x, p=float(self.p), training=self.training)
1301
+
1302
+
1303
+ class MulForDropout3(torch.autograd.Function):
1304
+ # returns (x * y * alpha) where alpha is a float and y doesn't require
1305
+ # grad and is zero-or-one.
1306
+ @staticmethod
1307
+ @custom_fwd
1308
+ def forward(ctx, x, y, alpha):
1309
+ assert not y.requires_grad
1310
+ ans = x * y * alpha
1311
+ ctx.save_for_backward(ans)
1312
+ ctx.alpha = alpha
1313
+ return ans
1314
+
1315
+ @staticmethod
1316
+ @custom_bwd
1317
+ def backward(ctx, ans_grad):
1318
+ (ans,) = ctx.saved_tensors
1319
+ x_grad = ctx.alpha * ans_grad * (ans != 0)
1320
+ return x_grad, None, None
1321
+
1322
+
1323
+ # Dropout3 is just like normal dropout, except it supports schedules on the dropout rates,
1324
+ # and it lets you choose one dimension to share the dropout mask over
1325
+ class Dropout3(nn.Module):
1326
+ def __init__(self, p: FloatLike, shared_dim: int):
1327
+ super().__init__()
1328
+ self.p = p
1329
+ self.shared_dim = shared_dim
1330
+
1331
+ def forward(self, x: Tensor) -> Tensor:
1332
+ p = float(self.p)
1333
+ if not self.training or p == 0:
1334
+ return _no_op(x)
1335
+ scale = 1.0 / (1 - p)
1336
+ rand_shape = list(x.shape)
1337
+ rand_shape[self.shared_dim] = 1
1338
+ mask = torch.rand(*rand_shape, device=x.device) > p
1339
+ ans = MulForDropout3.apply(x, mask, scale)
1340
+ return ans
1341
+
1342
+
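+ # A minimal sketch (illustrative; shapes are hypothetical): with shared_dim=-2
+ # (e.g. the frame axis of a (batch, time, channels) tensor), each channel is
+ # kept or dropped for all frames at once rather than independently per frame.
+ def _demo_dropout3_shared_mask():
+     m = Dropout3(p=0.5, shared_dim=-2)
+     m.train()
+     y = m(torch.ones(2, 100, 8))
+     # along the time axis each channel is either all zeros or all 1/(1-p)
+     print(y[0, :, 0].unique())
+
+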
1343
+ class SwooshLFunction(torch.autograd.Function):
1344
+ """
1345
+ swoosh_l(x) = log(1 + exp(x-4)) - 0.08*x - 0.035
1346
+ """
1347
+
1348
+ @staticmethod
1349
+ def forward(ctx, x: Tensor) -> Tensor:
1350
+ requires_grad = x.requires_grad
1351
+ if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
1352
+ x = x.to(torch.float32)
1353
+
1354
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
1355
+
1356
+ coeff = -0.08
1357
+
1358
+ with torch_autocast(enabled=False):
1359
+ with torch.enable_grad():
1360
+ x = x.detach()
1361
+ x.requires_grad = True
1362
+ y = torch.logaddexp(zero, x - 4.0) + coeff * x - 0.035
1363
+
1364
+ if not requires_grad:
1365
+ return y
1366
+
1367
+ y.backward(gradient=torch.ones_like(y))
1368
+
1369
+ grad = x.grad
1370
+ floor = coeff
1371
+ ceil = 1.0 + coeff + 0.005
1372
+
1373
+ d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
1374
+ grad
1375
+ )
1376
+ if __name__ == "__main__":
1377
+ # for self-testing only.
1378
+ assert d_scaled.min() >= 0.0
1379
+ assert d_scaled.max() < 256.0
1380
+
1381
+ d_int = d_scaled.to(torch.uint8)
1382
+ ctx.save_for_backward(d_int)
1383
+ if x.dtype == torch.float16 or torch.is_autocast_enabled():
1384
+ y = y.to(torch.get_autocast_gpu_dtype())
1385
+ return y
1386
+
1387
+ @staticmethod
1388
+ def backward(ctx, y_grad: Tensor) -> Tensor:
1389
+ (d,) = ctx.saved_tensors
1390
+ # the same constants as used in forward pass.
1391
+
1392
+ coeff = -0.08
1393
+ floor = coeff
1394
+ ceil = 1.0 + coeff + 0.005
1395
+ d = d * ((ceil - floor) / 255.0) + floor
1396
+ return y_grad * d
1397
+
1398
+
1399
+ class SwooshL(torch.nn.Module):
1400
+ def forward(self, x: Tensor) -> Tensor:
1401
+ """Return Swoosh-L activation."""
1402
+ if True or torch.jit.is_scripting() or torch.jit.is_tracing():
1403
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
1404
+ return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
1405
+ if not x.requires_grad:
1406
+ return k2.swoosh_l_forward(x).to(x.dtype)
1407
+ else:
1408
+ return k2.swoosh_l(x).to(x.dtype)
1409
+ # return SwooshLFunction.apply(x)
1410
+
1411
+
1412
+ class SwooshLOnnx(torch.nn.Module):
1413
+ def forward(self, x: Tensor) -> Tensor:
1414
+ """Return Swoosh-L activation."""
1415
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
1416
+ return logaddexp_onnx(zero, x - 4.0) - 0.08 * x - 0.035
1417
+
1418
+
1419
+ class SwooshRFunction(torch.autograd.Function):
1420
+ """
1421
+ swoosh_r(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687
1422
+
1423
+ derivatives are between -0.08 and 0.92.
1424
+ """
1425
+
1426
+ @staticmethod
1427
+ def forward(ctx, x: Tensor) -> Tensor:
1428
+ requires_grad = x.requires_grad
1429
+
1430
+ if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
1431
+ x = x.to(torch.float32)
1432
+
1433
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
1434
+
1435
+ with torch_autocast(enabled=False):
1436
+ with torch.enable_grad():
1437
+ x = x.detach()
1438
+ x.requires_grad = True
1439
+ y = torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687
1440
+
1441
+ if not requires_grad:
1442
+ return y
1443
+ y.backward(gradient=torch.ones_like(y))
1444
+
1445
+ grad = x.grad
1446
+ floor = -0.08
1447
+ ceil = 0.925
1448
+
1449
+ d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
1450
+ grad
1451
+ )
1452
+ if __name__ == "__main__":
1453
+ # for self-testing only.
1454
+ assert d_scaled.min() >= 0.0
1455
+ assert d_scaled.max() < 256.0
1456
+
1457
+ d_int = d_scaled.to(torch.uint8)
1458
+ ctx.save_for_backward(d_int)
1459
+ if x.dtype == torch.float16 or torch.is_autocast_enabled():
1460
+ y = y.to(torch.get_autocast_gpu_dtype())
1461
+ return y
1462
+
1463
+ @staticmethod
1464
+ def backward(ctx, y_grad: Tensor) -> Tensor:
1465
+ (d,) = ctx.saved_tensors
1466
+ # the same constants as used in forward pass.
1467
+ floor = -0.08
1468
+ ceil = 0.925
1469
+ d = d * ((ceil - floor) / 255.0) + floor
1470
+ return y_grad * d
1471
+
1472
+
1473
+ class SwooshR(torch.nn.Module):
1474
+ def forward(self, x: Tensor) -> Tensor:
1475
+ """Return Swoosh-R activation."""
1476
+ if True or torch.jit.is_scripting() or torch.jit.is_tracing():
1477
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
1478
+ return logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687
1479
+ if not x.requires_grad:
1480
+ return k2.swoosh_r_forward(x).to(x.dtype)
1481
+ else:
1482
+ return k2.swoosh_r(x).to(x.dtype)
1483
+ # return SwooshRFunction.apply(x)
1484
+
1485
+
1486
+ class SwooshROnnx(torch.nn.Module):
1487
+ def forward(self, x: Tensor) -> Tensor:
1488
+ """Return Swoosh-R activation."""
1489
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
1490
+ return logaddexp_onnx(zero, x - 1.0) - 0.08 * x - 0.313261687
1491
+
1492
+
1493
+ # simple version of SwooshL that does not redefine the backprop, used in
1494
+ # ActivationDropoutAndLinearFunction.
1495
+ def SwooshLForward(x: Tensor):
1496
+ x_offset = x - 4.0
1497
+ log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
1498
+ log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
1499
+ return log_sum - 0.08 * x - 0.035
1500
+
1501
+
1502
+ # simple version of SwooshR that does not redefine the backprop, used in
1503
+ # ActivationDropoutAndLinearFunction.
1504
+ def SwooshRForward(x: Tensor):
1505
+ x_offset = x - 1.0
1506
+ log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
1507
+ log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
1508
+ return log_sum - 0.08 * x - 0.313261687
1509
+
1510
+
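+ # A minimal consistency sketch (illustrative): the simple forward-only versions
+ # above agree with the SwooshL / SwooshR modules in the forward direction.
+ def _demo_swoosh_forward_consistency():
+     x = torch.randn(16)
+     assert torch.allclose(SwooshLForward(x), SwooshL()(x), atol=1.0e-05)
+     assert torch.allclose(SwooshRForward(x), SwooshR()(x), atol=1.0e-05)
+
+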
1511
+ class ActivationDropoutAndLinearFunction(torch.autograd.Function):
1512
+ @staticmethod
1513
+ @custom_fwd
1514
+ def forward(
1515
+ ctx,
1516
+ x: Tensor,
1517
+ weight: Tensor,
1518
+ bias: Optional[Tensor],
1519
+ activation: str,
1520
+ dropout_p: float,
1521
+ dropout_shared_dim: Optional[int],
1522
+ ):
1523
+ if dropout_p != 0.0:
1524
+ dropout_shape = list(x.shape)
1525
+ if dropout_shared_dim is not None:
1526
+ dropout_shape[dropout_shared_dim] = 1
1527
+ # else it won't be very memory efficient.
1528
+ dropout_mask = (1.0 / (1.0 - dropout_p)) * (
1529
+ torch.rand(*dropout_shape, device=x.device, dtype=x.dtype) > dropout_p
1530
+ )
1531
+ else:
1532
+ dropout_mask = None
1533
+
1534
+ ctx.save_for_backward(x, weight, bias, dropout_mask)
1535
+
1536
+ ctx.activation = activation
1537
+
1538
+ forward_activation_dict = {
1539
+ "SwooshL": k2.swoosh_l_forward,
1540
+ "SwooshR": k2.swoosh_r_forward,
1541
+ }
1542
+ # this lookup raises a KeyError if the activation name is unrecognized;
1543
+ # we let the error propagate to the user.
1544
+ activation_func = forward_activation_dict[activation]
1545
+ x = activation_func(x)
1546
+ if dropout_mask is not None:
1547
+ x = x * dropout_mask
1548
+ x = torch.nn.functional.linear(x, weight, bias)
1549
+ return x
1550
+
1551
+ @staticmethod
1552
+ @custom_bwd
1553
+ def backward(ctx, ans_grad: Tensor):
1554
+ saved = ctx.saved_tensors
1555
+ (x, weight, bias, dropout_mask) = saved
1556
+
1557
+ forward_and_deriv_activation_dict = {
1558
+ "SwooshL": k2.swoosh_l_forward_and_deriv,
1559
+ "SwooshR": k2.swoosh_r_forward_and_deriv,
1560
+ }
1561
+ # the following line raises a KeyError if the activation is unrecognized;
1562
+ # we let the error propagate to the user.
1563
+ func = forward_and_deriv_activation_dict[ctx.activation]
1564
+
1565
+ y, func_deriv = func(x)
1566
+ if dropout_mask is not None:
1567
+ y = y * dropout_mask
1568
+ # now compute the gradients of the loss w.r.t. the weight and bias.
1569
+ # y: (..., in_channels), ans_grad: (..., out_channels),
1570
+ (out_channels, in_channels) = weight.shape
1571
+
1572
+ in_channels = y.shape[-1]
1573
+ g = ans_grad.reshape(-1, out_channels)
1574
+ weight_deriv = torch.matmul(g.t(), y.reshape(-1, in_channels))
1575
+ y_deriv = torch.matmul(ans_grad, weight)
1576
+ bias_deriv = None if bias is None else g.sum(dim=0)
1577
+ x_deriv = y_deriv * func_deriv
1578
+ if dropout_mask is not None:
1579
+ # order versus func_deriv does not matter
1580
+ x_deriv = x_deriv * dropout_mask
1581
+
1582
+ return x_deriv, weight_deriv, bias_deriv, None, None, None
1583
+
1584
+
1585
+ class ActivationDropoutAndLinear(torch.nn.Module):
1586
+ """
1587
+ This merges an activation function followed by dropout and then a nn.Linear module;
1588
+ it does so in a memory efficient way so that it only stores the input to the whole
1589
+ module. If activation == SwooshL and dropout_shared_dim != None, this will be
1590
+ equivalent to:
1591
+ nn.Sequential(SwooshL(),
1592
+ Dropout3(dropout_p, shared_dim=dropout_shared_dim),
1593
+ ScaledLinear(in_channels, out_channels, bias=bias,
1594
+ initial_scale=initial_scale))
1595
+ If dropout_shared_dim is None, the dropout would be equivalent to
1596
+ Dropout2(dropout_p). Note: Dropout3 will be more memory efficient as the dropout
1597
+ mask is smaller.
1598
+
1599
+ Args:
1600
+ in_channels: number of input channels, e.g. 256
1601
+ out_channels: number of output channels, e.g. 256
1602
+ bias: if true, have a bias
1603
+ activation: the activation function, for now just support SwooshL.
1604
+ dropout_p: the dropout probability or schedule (happens after nonlinearity).
1605
+ dropout_shared_dim: the dimension, if any, across which the dropout mask is
1606
+ shared (e.g. the time dimension). If None, this may be less memory
1607
+ efficient if there are modules before this one that cache the input
1608
+ for their backprop (e.g. Balancer or Whiten).
1609
+ """
1610
+
1611
+ def __init__(
1612
+ self,
1613
+ in_channels: int,
1614
+ out_channels: int,
1615
+ bias: bool = True,
1616
+ activation: str = "SwooshL",
1617
+ dropout_p: FloatLike = 0.0,
1618
+ dropout_shared_dim: Optional[int] = -1,
1619
+ initial_scale: float = 1.0,
1620
+ ):
1621
+ super().__init__()
1622
+ # create a temporary module of nn.Linear that we'll steal the
1623
+ # weights and bias from
1624
+ l = ScaledLinear(
1625
+ in_channels, out_channels, bias=bias, initial_scale=initial_scale
1626
+ )
1627
+
1628
+ self.weight = l.weight
1629
+ # register_parameter properly handles the case where l.bias is None.
1630
+ # There is presumably a reason for doing it this way rather than
1631
+ # simply setting the attribute to None, perhaps something to do
1632
+ # with exporting the module.
1633
+ self.register_parameter("bias", l.bias)
1634
+
1635
+ self.activation = activation
1636
+ self.dropout_p = dropout_p
1637
+ self.dropout_shared_dim = dropout_shared_dim
1638
+
1639
+ def forward(self, x: Tensor):
1640
+ if not self.training or torch.jit.is_scripting() or torch.jit.is_tracing():
1641
+ if self.activation == "SwooshL":
1642
+ x = SwooshLForward(x)
1643
+ elif self.activation == "SwooshR":
1644
+ x = SwooshRForward(x)
1645
+ else:
1646
+ assert False, self.activation
1647
+ return torch.nn.functional.linear(x, self.weight, self.bias)
1648
+
1649
+ return ActivationDropoutAndLinearFunction.apply(
1650
+ x,
1651
+ self.weight,
1652
+ self.bias,
1653
+ self.activation,
1654
+ float(self.dropout_p),
1655
+ self.dropout_shared_dim,
1656
+ )
1657
+
1658
+
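+ # A minimal usage sketch (illustrative; the dimensions are hypothetical): a fused
+ # replacement for a `SwooshL() -> Dropout3 -> ScaledLinear` tail that, per the
+ # docstring above, stores only the input of the whole block for backprop.
+ def _demo_activation_dropout_and_linear() -> nn.Module:
+     return ActivationDropoutAndLinear(
+         in_channels=512,
+         out_channels=256,
+         activation="SwooshL",
+         dropout_p=0.1,
+         dropout_shared_dim=0,  # share the dropout mask along one axis, e.g. time
+         initial_scale=0.1,
+     )
+
+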
1659
+ def convert_num_channels(x: Tensor, num_channels: int) -> Tensor:
1660
+ if num_channels <= x.shape[-1]:
1661
+ return x[..., :num_channels]
1662
+ else:
1663
+ shape = list(x.shape)
1664
+ shape[-1] = num_channels - shape[-1]
1665
+ zeros = torch.zeros(shape, dtype=x.dtype, device=x.device)
1666
+ return torch.cat((x, zeros), dim=-1)
1667
+
1668
+
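+ # A minimal sketch (illustrative): convert_num_channels() either truncates or
+ # zero-pads the last dimension so tensors with different channel counts can be
+ # combined.
+ def _demo_convert_num_channels():
+     x = torch.randn(10, 4, 384)
+     assert convert_num_channels(x, 256).shape == (10, 4, 256)  # truncated
+     assert convert_num_channels(x, 512).shape == (10, 4, 512)  # zero-padded
+
+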
1669
+ def _test_whiten():
1670
+ for proportion in [0.1, 0.5, 10.0]:
1671
+ logging.info(f"_test_whiten(): proportion = {proportion}")
1672
+ x = torch.randn(100, 128)
1673
+ direction = torch.randn(128)
1674
+ coeffs = torch.randn(100, 1)
1675
+ x += proportion * direction * coeffs
1676
+
1677
+ x.requires_grad = True
1678
+
1679
+ m = Whiten(
1680
+ num_groups=1, whitening_limit=5.0, prob=1.0, grad_scale=0.1
1681
+ )
1682
+
1683
+ for _ in range(4):
1684
+ y = m(x)
1685
+
1686
+ y_grad = torch.randn_like(x)
1687
+ y.backward(gradient=y_grad)
1688
+
1689
+ if proportion < 0.2:
1690
+ assert torch.allclose(x.grad, y_grad)
1691
+ elif proportion > 1.0:
1692
+ assert not torch.allclose(x.grad, y_grad)
1693
+
1694
+
1695
+ def _test_balancer_sign():
1696
+ probs = torch.arange(0, 1, 0.01)
1697
+ N = 1000
1698
+ x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0)
1699
+ x = x.detach()
1700
+ x.requires_grad = True
1701
+ m = Balancer(
1702
+ probs.numel(),
1703
+ channel_dim=0,
1704
+ min_positive=0.05,
1705
+ max_positive=0.95,
1706
+ min_abs=0.0,
1707
+ prob=1.0,
1708
+ )
1709
+
1710
+ y_grad = torch.sign(torch.randn(probs.numel(), N))
1711
+
1712
+ y = m(x)
1713
+ y.backward(gradient=y_grad)
1714
+ print("_test_balancer_sign: x = ", x)
1715
+ print("_test_balancer_sign: y grad = ", y_grad)
1716
+ print("_test_balancer_sign: x grad = ", x.grad)
1717
+
1718
+
1719
+ def _test_balancer_magnitude():
1720
+ magnitudes = torch.arange(0, 1, 0.01)
1721
+ N = 1000
1722
+ x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
1723
+ x = x.detach()
1724
+ x.requires_grad = True
1725
+ m = Balancer(
1726
+ magnitudes.numel(),
1727
+ channel_dim=0,
1728
+ min_positive=0.0,
1729
+ max_positive=1.0,
1730
+ min_abs=0.2,
1731
+ max_abs=0.7,
1732
+ prob=1.0,
1733
+ )
1734
+
1735
+ y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
1736
+
1737
+ y = m(x)
1738
+ y.backward(gradient=y_grad)
1739
+ print("_test_balancer_magnitude: x = ", x)
1740
+ print("_test_balancer_magnitude: y grad = ", y_grad)
1741
+ print("_test_balancer_magnitude: x grad = ", x.grad)
1742
+
1743
+
1744
+ def _test_double_swish_deriv():
1745
+ x = torch.randn(10, 12, dtype=torch.double) * 3.0
1746
+ x.requires_grad = True
1747
+ m = DoubleSwish()
1748
+
1749
+ tol = (1.2 - (-0.043637)) / 255.0
1750
+ torch.autograd.gradcheck(m, x, atol=tol)
1751
+
1752
+ # for self-test.
1753
+ x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
1754
+ x.requires_grad = True
1755
+ y = m(x)
1756
+
1757
+
1758
+ def _test_swooshl_deriv():
1759
+ x = torch.randn(10, 12, dtype=torch.double) * 3.0
1760
+ x.requires_grad = True
1761
+ m = SwooshL()
1762
+
1763
+ tol = 1.0 / 255.0
1764
+ torch.autograd.gradcheck(m, x, atol=tol, eps=0.01)
1765
+
1766
+ # for self-test.
1767
+ x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
1768
+ x.requires_grad = True
1769
+ y = m(x)
1770
+
1771
+
1772
+ def _test_swooshr_deriv():
1773
+ x = torch.randn(10, 12, dtype=torch.double) * 3.0
1774
+ x.requires_grad = True
1775
+ m = SwooshR()
1776
+
1777
+ tol = 1.0 / 255.0
1778
+ torch.autograd.gradcheck(m, x, atol=tol, eps=0.01)
1779
+
1780
+ # for self-test.
1781
+ x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
1782
+ x.requires_grad = True
1783
+ y = m(x)
1784
+
1785
+
1786
+ def _test_softmax():
1787
+ a = torch.randn(2, 10, dtype=torch.float64)
1788
+ b = a.clone()
1789
+ a.requires_grad = True
1790
+ b.requires_grad = True
1791
+ a.softmax(dim=1)[:, 0].sum().backward()
1792
+ print("a grad = ", a.grad)
1793
+ softmax(b, dim=1)[:, 0].sum().backward()
1794
+ print("b grad = ", b.grad)
1795
+ assert torch.allclose(a.grad, b.grad)
1796
+
1797
+
1798
+ def _test_piecewise_linear():
1799
+ p = PiecewiseLinear((0, 10.0))
1800
+ for x in [-100, 0, 100]:
1801
+ assert p(x) == 10.0
1802
+ p = PiecewiseLinear((0, 10.0), (1, 0.0))
1803
+ for x, y in [(-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0)]:
1804
+ print("x, y = ", x, y)
1805
+ assert p(x) == y, (x, p(x), y)
1806
+
1807
+ q = PiecewiseLinear((0.5, 15.0), (0.6, 1.0))
1808
+ x_vals = [-1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0]
1809
+ pq = p.max(q)
1810
+ for x in x_vals:
1811
+ y1 = max(p(x), q(x))
1812
+ y2 = pq(x)
1813
+ assert abs(y1 - y2) < 0.001
1814
+ pq = p.min(q)
1815
+ for x in x_vals:
1816
+ y1 = min(p(x), q(x))
1817
+ y2 = pq(x)
1818
+ assert abs(y1 - y2) < 0.001
1819
+ pq = p + q
1820
+ for x in x_vals:
1821
+ y1 = p(x) + q(x)
1822
+ y2 = pq(x)
1823
+ assert abs(y1 - y2) < 0.001
1824
+
1825
+
1826
+ def _test_activation_dropout_and_linear():
1827
+ in_channels = 20
1828
+ out_channels = 30
1829
+
1830
+ for bias in [True, False]:
1831
+ # actually we don't test for dropout_p != 0.0 because forward functions will give
1832
+ # different answers. This is because we are using the k2 implementation of
1833
+ # swoosh_l and swoosh_r inside SwooshL() and SwooshR(), and they call randn()
1834
+ # internally, messing up the random state.
1835
+ for dropout_p in [0.0]:
1836
+ for activation in ["SwooshL", "SwooshR"]:
1837
+ m1 = nn.Sequential(
1838
+ SwooshL() if activation == "SwooshL" else SwooshR(),
1839
+ Dropout3(p=dropout_p, shared_dim=-1),
1840
+ ScaledLinear(
1841
+ in_channels, out_channels, bias=bias, initial_scale=0.5
1842
+ ),
1843
+ )
1844
+ m2 = ActivationDropoutAndLinear(
1845
+ in_channels,
1846
+ out_channels,
1847
+ bias=bias,
1848
+ initial_scale=0.5,
1849
+ activation=activation,
1850
+ dropout_p=dropout_p,
1851
+ )
1852
+ with torch.no_grad():
1853
+ m2.weight[:] = m1[2].weight
1854
+ if bias:
1855
+ m2.bias[:] = m1[2].bias
1856
+ # make sure forward gives same result.
1857
+ x1 = torch.randn(10, in_channels)
1858
+ x1.requires_grad = True
1859
+
1860
+ # TEMP.
1861
+ assert torch.allclose(
1862
+ SwooshRFunction.apply(x1), SwooshRForward(x1), atol=1.0e-03
1863
+ )
1864
+
1865
+ x2 = x1.clone().detach()
1866
+ x2.requires_grad = True
1867
+ seed = 10
1868
+ torch.manual_seed(seed)
1869
+ y1 = m1(x1)
1870
+ y_grad = torch.randn_like(y1)
1871
+ y1.backward(gradient=y_grad)
1872
+ torch.manual_seed(seed)
1873
+ y2 = m2(x2)
1874
+ y2.backward(gradient=y_grad)
1875
+
1876
+ print(
1877
+ f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}"
1878
+ )
1879
+ print("y1 = ", y1)
1880
+ print("y2 = ", y2)
1881
+ assert torch.allclose(y1, y2, atol=0.02)
1882
+ assert torch.allclose(m1[2].weight.grad, m2.weight.grad, atol=1.0e-05)
1883
+ if bias:
1884
+ assert torch.allclose(m1[2].bias.grad, m2.bias.grad, atol=1.0e-05)
1885
+ print("x1.grad = ", x1.grad)
1886
+ print("x2.grad = ", x2.grad)
1887
+
1888
+ def isclose(a, b):
1889
+ # return true if cosine similarity is > 0.9.
1890
+ return (a * b).sum() > 0.9 * (
1891
+ (a**2).sum() * (b**2).sum()
1892
+ ).sqrt()
1893
+
1894
+ # the SwooshL() implementation has a noisy gradient due to 1-byte
1895
+ # storage of its derivative.
1896
+ assert isclose(x1.grad, x2.grad)
1897
+
1898
+
1899
+ if __name__ == "__main__":
1900
+ logging.getLogger().setLevel(logging.INFO)
1901
+ torch.set_num_threads(1)
1902
+ torch.set_num_interop_threads(1)
1903
+ _test_piecewise_linear()
1904
+ _test_softmax()
1905
+ _test_whiten()
1906
+ _test_balancer_sign()
1907
+ _test_balancer_magnitude()
1908
+ _test_double_swish_deriv()
1909
+ _test_swooshr_deriv()
1910
+ _test_swooshl_deriv()
1911
+ _test_activation_dropout_and_linear()
zipformer2.py ADDED
The diff for this file is too large to render. See raw diff