| | """Copied from https://github.com/salute-developers/GigaAM/blob/main/gigaam/encoder.py""" |
| | import math |
| | from abc import ABC, abstractmethod |
| | from typing import List, Optional, Tuple, Union |
| |
|
| | import torch |
| | from torch import Tensor, nn |
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | IMPORT_FLASH = False |
| | IMPORT_FLASH_ERR = "Flash Attention not installed." |
| |
|
| | |
| |
|
| |
|
def rtt_half(x: Tensor) -> Tensor:
    """Rotate-half operation used by rotary embeddings.

    Splits the last dimension of ``x`` into two halves and returns
    ``[-second_half, first_half]`` concatenated on that dimension.
    """
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
| |
|
| |
|
def apply_rotary_pos_emb(
    q: Tensor, k: Tensor, cos: Tensor, sin: Tensor, offset: int = 0
) -> Tuple[Tensor, Tensor]:
    """
    Applies Rotary Position Embeddings to query and key tensors.

    The cos/sin tables are sliced along the first (time) dimension to
    match the query length, starting at ``offset``.
    """
    end = offset + q.shape[0]
    cos, sin = cos[offset:end, ...], sin[offset:end, ...]

    def rotate(t: Tensor) -> Tensor:
        # Inline rotate-half: swap the two halves of the last dim,
        # negating the (originally) second half.
        mid = t.shape[-1] // 2
        return torch.cat([-t[..., mid:], t[..., :mid]], dim=-1)

    return q * cos + rotate(q) * sin, k * cos + rotate(k) * sin
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| |
|
| |
|
class StridingSubsampling(nn.Module):
    """
    Strided Subsampling layer used to reduce the sequence length.

    Stacks log2(subsampling_factor) Conv2d+ReLU stages, each halving the
    time and feature axes, then projects the flattened result to feat_out.
    """

    def __init__(
        self,
        subsampling_factor: int,
        feat_in: int,
        feat_out: int,
        conv_channels: int,
    ):
        super().__init__()
        # Number of stride-2 stages needed; subsampling_factor is assumed
        # to be a power of two.
        self._sampling_num = int(math.log(subsampling_factor, 2))
        self._stride = 2
        self._kernel_size = 3
        self._padding = (self._kernel_size - 1) // 2

        modules: List[nn.Module] = []
        channels = 1  # input is treated as a single-channel 2D map
        for _ in range(self._sampling_num):
            modules.extend(
                [
                    torch.nn.Conv2d(
                        in_channels=channels,
                        out_channels=conv_channels,
                        kernel_size=self._kernel_size,
                        stride=self._stride,
                        padding=self._padding,
                    ),
                    nn.ReLU(),
                ]
            )
            channels = conv_channels

        # The frequency axis shrinks the same way as time; size the final
        # projection from the subsampled feature dimension.
        freq_out = self.calc_output_length(torch.tensor(feat_in))
        self.out = torch.nn.Linear(conv_channels * int(freq_out), feat_out)
        self.conv = torch.nn.Sequential(*modules)

    def calc_output_length(self, lengths: Tensor) -> Tensor:
        """
        Calculates the output length after applying the subsampling.
        """
        result = lengths.to(torch.float)
        pad_term = 2 * self._padding - self._kernel_size
        for _ in range(self._sampling_num):
            # Standard conv output-length formula, floored per stage.
            result = torch.floor(torch.div(result + pad_term, self._stride) + 1.0)
        return result.to(dtype=torch.int)

    def forward(self, x: Tensor, lengths: Tensor) -> Tuple[Tensor, Tensor]:
        feats = self.conv(x.unsqueeze(1))
        batch, _, time, _ = feats.size()
        flat = feats.transpose(1, 2).reshape(batch, time, -1)
        return self.out(flat), self.calc_output_length(lengths)
| |
|
| |
|
class MultiHeadAttention(nn.Module, ABC):
    """
    Base class of Multi-Head Attention Mechanisms.

    Holds the shared Q/K/V/output projections and the common attention
    application; subclasses supply the positional-encoding-specific
    score computation.
    """

    def __init__(self, n_head: int, n_feat: int, flash_attn=False):
        super().__init__()
        assert n_feat % n_head == 0
        self.d_k = n_feat // n_head  # per-head dimensionality
        self.h = n_head
        # One full-width projection per role.
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.flash_attn = flash_attn
        if self.flash_attn and not IMPORT_FLASH:
            raise RuntimeError(
                f"flash_attn_func was imported with err {IMPORT_FLASH_ERR}. "
                "Please install flash_attn or use --no_flash flag. "
                "If you have already done this, "
                "--force-reinstall flag might be useful"
            )

    def forward_qkv(
        self, query: Tensor, key: Tensor, value: Tensor
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Projects the inputs into queries, keys, and values for multi-head attention.

        Returns (batch, time, heads, d_k) tensors when flash attention is
        enabled, otherwise (batch, heads, time, d_k).
        """
        batch = query.size(0)
        projected = (
            self.linear_q(query),
            self.linear_k(key),
            self.linear_v(value),
        )
        q, k, v = (t.view(batch, -1, self.h, self.d_k) for t in projected)
        if self.flash_attn:
            # flash_attn expects the head dimension before d_k is split out.
            return q, k, v
        return q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

    def forward_attention(
        self, value: Tensor, scores: Tensor, mask: Optional[Tensor]
    ) -> Tensor:
        """
        Computes the scaled dot-product attention given the projected values and scores.
        """
        batch = value.size(0)
        if mask is None:
            attn = torch.softmax(scores, dim=-1)
        else:
            mask = mask.unsqueeze(1)
            # Large negative fill before softmax, then zero out masked
            # positions so padding contributes nothing.
            masked_scores = scores.masked_fill(mask, -10000.0)
            attn = torch.softmax(masked_scores, dim=-1).masked_fill(mask, 0.0)
        context = torch.matmul(attn, value)
        context = context.transpose(1, 2).reshape(batch, -1, self.h * self.d_k)
        return self.linear_out(context)
| |
|
| |
|
class RelPositionMultiHeadAttention(MultiHeadAttention):
    """
    Relative Position Multi-Head Attention module.

    Transformer-XL style attention: a content term (query + pos_bias_u
    against keys) and a position term (query + pos_bias_v against
    projected relative embeddings) are summed before softmax.
    """

    def __init__(self, n_head: int, n_feat: int):
        super().__init__(n_head, n_feat)
        # Projection applied to the relative positional embeddings (no bias).
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        # Learnable per-head biases (u for the content term, v for the
        # position term). NOTE(review): torch.FloatTensor allocates
        # uninitialized memory — presumably these are overwritten by a
        # loaded checkpoint; confirm before training from scratch.
        self.pos_bias_u = nn.Parameter(torch.FloatTensor(self.h, self.d_k))
        self.pos_bias_v = nn.Parameter(torch.FloatTensor(self.h, self.d_k))

    def rel_shift(self, x: Tensor) -> Tensor:
        # Transformer-XL shift trick: pad one column on the left, fold the
        # matrix, and drop the first row so that each query row is aligned
        # with its own relative-position offsets.
        b, h, qlen, pos_len = x.size()
        x = torch.nn.functional.pad(x, pad=(1, 0))
        x = x.view(b, h, -1, qlen)
        return x[:, :, 1:].view(b, h, qlen, pos_len)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        pos_emb: Tensor,
        mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Compute relative-position attention over (batch, time, feat) inputs."""
        q, k, v = self.forward_qkv(query, key, value)
        # Back to (batch, time, heads, d_k) so the per-head biases broadcast.
        q = q.transpose(1, 2)
        p = self.linear_pos(pos_emb)
        p = p.view(pos_emb.shape[0], -1, self.h, self.d_k).transpose(1, 2)
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
        # Position ("bd") term, realigned by the relative shift.
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        matrix_bd = self.rel_shift(matrix_bd)
        # Content ("ac") term.
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
        # Trim the shifted position scores to the key length.
        matrix_bd = matrix_bd[:, :, :, : matrix_ac.size(-1)]
        scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask)
| |
|
| |
|
class RotaryPositionMultiHeadAttention(MultiHeadAttention):
    """
    Rotary Position Multi-Head Attention module.

    Rotates queries and keys with precomputed cos/sin tables before the
    standard scaled dot-product attention.
    """

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        pos_emb: List[Tensor],
        mask: Optional[Tensor] = None,
    ) -> Tensor:
        batch, seq_len, _ = value.size()

        def to_time_major(x: Tensor) -> Tensor:
            # (batch, time, feat) -> (time, batch, heads, d_k), the layout
            # expected by apply_rotary_pos_emb.
            return x.transpose(0, 1).view(seq_len, batch, self.h, self.d_k)

        query = to_time_major(query)
        key = to_time_major(key)
        value = to_time_major(value)

        cos_emb, sin_emb = pos_emb
        query, key = apply_rotary_pos_emb(query, key, cos_emb, sin_emb, offset=0)

        def to_batch_major(x: Tensor) -> Tensor:
            # Back to (batch, time, feat) for the shared projections.
            return x.view(seq_len, batch, self.h * self.d_k).transpose(0, 1)

        q, k, v = self.forward_qkv(
            to_batch_major(query), to_batch_major(key), to_batch_major(value)
        )

        scale = math.sqrt(self.d_k)
        scores = torch.matmul(q, k.transpose(-2, -1) / scale)
        out = self.forward_attention(v, scores, mask)

        return out
| |
|
| |
|
class PositionalEncoding(nn.Module, ABC):
    """
    Base class of Positional Encodings.

    Subclasses implement ``create_pe``; ``extend_pe`` caches the result in
    a non-persistent buffer named ``pe``.
    """

    def __init__(self, dim: int, base: int):
        super().__init__()
        self.dim = dim    # encoding dimensionality
        self.base = base  # subclass-specific: frequency base or max length

    @abstractmethod
    def create_pe(self, length: int, device: torch.device) -> Optional[Tensor]:
        pass

    def extend_pe(self, length: int, device: torch.device):
        """
        Extends the positional encoding buffer to process longer sequences.
        """
        pe = self.create_pe(length, device)
        if pe is None:
            # Existing buffer already covers the requested length.
            return
        if not hasattr(self, "pe"):
            # First call: register as a non-persistent buffer so it is
            # moved with the module but excluded from the state dict.
            self.register_buffer("pe", pe, persistent=False)
        else:
            self.pe = pe
| |
|
| |
|
class RelPositionalEmbedding(PositionalEncoding):
    """
    Relative Positional Embedding module.

    Builds a sinusoidal table over relative offsets from (length-1) down
    to -(length-1) and slices a centered window of it in ``forward``.
    """

    def create_pe(self, length: int, device: torch.device) -> Optional[Tensor]:
        """
        Creates the relative positional encoding matrix.
        """
        needed = 2 * length - 1
        if hasattr(self, "pe") and self.pe.shape[1] >= needed:
            return None
        # Relative offsets, positive to negative.
        positions = torch.arange(length - 1, -length, -1, device=device).unsqueeze(1)
        pe = torch.zeros(positions.size(0), self.dim, device=positions.device)
        inv_freq = torch.exp(
            torch.arange(0, self.dim, 2, device=pe.device)
            * -(math.log(10000.0) / self.dim)
        )
        # Even columns get sine, odd columns cosine.
        pe[:, 0::2] = torch.sin(positions * inv_freq)
        pe[:, 1::2] = torch.cos(positions * inv_freq)
        return pe.unsqueeze(0)

    def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
        input_len = x.size(1)
        # Slice a (2*input_len - 1)-wide window centered on offset zero.
        center = self.pe.size(1) // 2 + 1
        return x, self.pe[:, center - input_len : center + input_len - 1]
| |
|
| |
|
class RotaryPositionalEmbedding(PositionalEncoding):
    """
    Rotary Positional Embedding module.

    Caches cos/sin tables stacked along dim 0: the first half of ``pe``
    is the cosine block, the second half the sine block.
    """

    def create_pe(self, length: int, device: torch.device) -> Optional[Tensor]:
        """
        Creates or extends the rotary positional encoding matrix.
        """
        # Buffer holds cos and sin stacked, hence the factor of two.
        if hasattr(self, "pe") and self.pe.size(0) >= 2 * length:
            return None
        positions = torch.arange(0, length, dtype=torch.float32, device=device)
        exponents = torch.arange(0, self.dim, 2, device=positions.device).float()
        inv_freq = 1.0 / (self.base ** (exponents / self.dim))
        steps = torch.arange(length, device=positions.device, dtype=inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", steps, inv_freq)
        # Duplicate frequencies so both halves of the head dim are covered.
        emb = torch.cat((freqs, freqs), dim=-1).to(positions.device)
        cos_block = emb.cos()[:, None, None, :]
        sin_block = emb.sin()[:, None, None, :]
        return torch.cat([cos_block, sin_block])

    def forward(self, x: torch.Tensor) -> Tuple[Tensor, List[Tensor]]:
        seq_len = x.shape[1]
        half = self.pe.shape[0] // 2
        cos_emb = self.pe[:seq_len]
        sin_emb = self.pe[half : half + seq_len]
        return x, [cos_emb, sin_emb]
| |
|
| |
|
class ConformerConvolution(nn.Module):
    """
    Conformer Convolution module.

    Pointwise expansion + GLU gate, depthwise convolution over time,
    batch norm, SiLU, and a pointwise projection back to d_model.
    """

    def __init__(
        self,
        d_model: int,
        kernel_size: int,
    ):
        super().__init__()
        # Odd kernel required for symmetric "same" padding.
        assert (kernel_size - 1) % 2 == 0
        self.pointwise_conv1 = nn.Conv1d(d_model, d_model * 2, kernel_size=1)
        self.depthwise_conv = nn.Conv1d(
            in_channels=d_model,
            out_channels=d_model,
            kernel_size=kernel_size,
            padding=(kernel_size - 1) // 2,
            groups=d_model,
            bias=True,
        )
        self.batch_norm = nn.BatchNorm1d(d_model)
        self.activation = nn.SiLU()
        self.pointwise_conv2 = nn.Conv1d(d_model, d_model, kernel_size=1)

    def forward(self, x: Tensor, pad_mask: Optional[Tensor] = None) -> Tensor:
        # Conv1d operates channel-first: (batch, d_model, time).
        out = self.pointwise_conv1(x.transpose(1, 2))
        out = nn.functional.glu(out, dim=1)
        if pad_mask is not None:
            # Zero out padded frames before the depthwise convolution.
            out = out.masked_fill(pad_mask.unsqueeze(1), 0.0)
        out = self.batch_norm(self.depthwise_conv(out))
        out = self.pointwise_conv2(self.activation(out))
        return out.transpose(1, 2)
| |
|
| |
|
class ConformerFeedForward(nn.Module):
    """
    Conformer Feed Forward module.

    A two-layer position-wise MLP with SiLU in between:
    d_model -> d_ff -> d_model.
    """

    def __init__(self, d_model: int, d_ff: int, use_bias=True):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff, bias=use_bias)
        self.activation = nn.SiLU()
        self.linear2 = nn.Linear(d_ff, d_model, bias=use_bias)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.linear1(x)
        hidden = self.activation(hidden)
        return self.linear2(hidden)
| |
|
| |
|
class ConformerLayer(nn.Module):
    """
    Conformer Layer module.
    This module combines several submodules including feed forward networks,
    depthwise separable convolution, and multi-head self-attention
    to form a single Conformer block.

    Structure (macaron style): half-step FFN -> self-attention ->
    convolution -> half-step FFN -> final LayerNorm, each with a
    pre-norm residual connection.
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        self_attention_model: str,
        n_heads: int = 16,
        conv_kernel_size: int = 31,
        flash_attn: bool = False,
    ):
        super().__init__()
        # Macaron FFN residuals are scaled by one half.
        self.fc_factor = 0.5
        self.norm_feed_forward1 = nn.LayerNorm(d_model)
        self.feed_forward1 = ConformerFeedForward(d_model=d_model, d_ff=d_ff)
        self.norm_conv = nn.LayerNorm(d_model)
        self.conv = ConformerConvolution(
            d_model=d_model,
            kernel_size=conv_kernel_size,
        )
        self.norm_self_att = nn.LayerNorm(d_model)
        # Attention flavor is chosen by self_attention_model; flash
        # attention is only available with rotary embeddings.
        if self_attention_model == "rotary":
            self.self_attn: nn.Module = RotaryPositionMultiHeadAttention(
                n_head=n_heads,
                n_feat=d_model,
                flash_attn=flash_attn,
            )
        else:
            assert not flash_attn, "Not supported flash_attn for rel_pos"
            self.self_attn = RelPositionMultiHeadAttention(
                n_head=n_heads,
                n_feat=d_model,
            )
        self.norm_feed_forward2 = nn.LayerNorm(d_model)
        self.feed_forward2 = ConformerFeedForward(d_model=d_model, d_ff=d_ff)
        self.norm_out = nn.LayerNorm(d_model)

    def forward(
        self,
        x: Tensor,
        pos_emb: Union[Tensor, List[Tensor]],
        att_mask: Optional[Tensor] = None,
        pad_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Apply one Conformer block to x of shape (batch, time, d_model).

        pos_emb is a Tensor for rel_pos attention or a [cos, sin] list for
        rotary attention; att_mask masks attention scores, pad_mask masks
        padded frames inside the convolution.
        """
        # First half-step feed-forward (pre-norm residual).
        residual = x
        x = self.norm_feed_forward1(x)
        x = self.feed_forward1(x)
        residual = residual + x * self.fc_factor

        # Self-attention with positional information.
        x = self.norm_self_att(residual)
        x = self.self_attn(x, x, x, pos_emb, mask=att_mask)
        residual = residual + x

        # Convolution module.
        x = self.norm_conv(residual)
        x = self.conv(x, pad_mask=pad_mask)
        residual = residual + x

        # Second half-step feed-forward.
        x = self.norm_feed_forward2(residual)
        x = self.feed_forward2(x)
        residual = residual + x * self.fc_factor

        x = self.norm_out(residual)
        return x
| |
|
| |
|
class ConformerEncoder(nn.Module):
    """
    Conformer Encoder module.
    This module encapsulates the entire Conformer encoder architecture,
    consisting of a StridingSubsampling layer, positional embeddings, and
    a stack of Conformer Layers.
    It serves as the main component responsible for processing speech features.
    """

    def __init__(
        self,
        feat_in: int = 64,
        n_layers: int = 16,
        d_model: int = 768,
        subsampling_factor: int = 4,
        ff_expansion_factor: int = 4,
        self_attention_model: str = "rotary",
        n_heads: int = 16,
        pos_emb_max_len: int = 5000,
        conv_kernel_size: int = 31,
        flash_attn: bool = False,
    ):
        super().__init__()
        self.feat_in = feat_in
        assert self_attention_model in [
            "rotary",
            "rel_pos",
        ], f"Not supported attn = {self_attention_model}"

        # Reduces the time axis by subsampling_factor and projects
        # features to d_model.
        self.pre_encode = StridingSubsampling(
            subsampling_factor=subsampling_factor,
            feat_in=feat_in,
            feat_out=d_model,
            conv_channels=d_model,
        )

        # Rotary embeddings operate per head (d_model // n_heads);
        # relative embeddings span the full model dimension.
        if self_attention_model == "rotary":
            self.pos_enc: nn.Module = RotaryPositionalEmbedding(
                d_model // n_heads, pos_emb_max_len
            )
        else:
            self.pos_enc = RelPositionalEmbedding(d_model, pos_emb_max_len)

        self.layers = nn.ModuleList()
        for _ in range(n_layers):
            layer = ConformerLayer(
                d_model=d_model,
                d_ff=d_model * ff_expansion_factor,
                self_attention_model=self_attention_model,
                n_heads=n_heads,
                conv_kernel_size=conv_kernel_size,
                flash_attn=flash_attn,
            )
            self.layers.append(layer)

        # Pre-build the positional table up to the supported maximum length.
        self.pos_enc.extend_pe(pos_emb_max_len, next(self.parameters()).device)

    def input_example(
        self,
        batch_size: int = 1,
        seqlen: int = 200,
    ):
        """Returns a zero-filled (features, lengths) pair for tracing/export."""
        device = next(self.parameters()).device
        features = torch.zeros(batch_size, self.feat_in, seqlen)
        feature_lengths = torch.full([batch_size], features.shape[-1])
        return features.float().to(device), feature_lengths.to(device)

    def input_names(self):
        # ONNX export input names.
        return ["audio_signal", "length"]

    def output_names(self):
        # ONNX export output names.
        return ["encoded", "encoded_len"]

    def dynamic_axes(self):
        # Dynamic batch/time axes for ONNX export.
        return {
            "audio_signal": {0: "batch_size", 2: "seq_len"},
            "length": {0: "batch_size"},
            "encoded": {0: "batch_size", 1: "seq_len"},
            "encoded_len": {0: "batch_size"},
        }

    def forward(self, audio_signal: Tensor, length: Tensor) -> Tuple[Tensor, Tensor]:
        """Encode (batch, feat_in, time) features; returns (batch, d_model, time')
        encodings and the subsampled lengths."""
        # Subsampling expects (batch, time, feat).
        audio_signal, length = self.pre_encode(
            x=audio_signal.transpose(1, 2), lengths=length
        )

        max_len = audio_signal.size(1)
        audio_signal, pos_emb = self.pos_enc(x=audio_signal)

        # True for valid frames, False for padding.
        pad_mask = torch.arange(0, max_len, device=audio_signal.device).expand(
            length.size(0), -1
        ) < length.unsqueeze(-1)

        # Attention mask is only needed when batching mixes lengths;
        # a single sequence attends everywhere.
        att_mask = None
        if audio_signal.shape[0] > 1:
            att_mask = pad_mask.unsqueeze(1).repeat([1, max_len, 1])
            att_mask = torch.logical_and(att_mask, att_mask.transpose(1, 2))
            # Inverted: True marks positions to be masked out.
            att_mask = ~att_mask

        # Inverted for masked_fill inside the convolution module.
        pad_mask = ~pad_mask

        for layer in self.layers:
            audio_signal = layer(
                x=audio_signal,
                pos_emb=pos_emb,
                att_mask=att_mask,
                pad_mask=pad_mask,
            )

        # Return channel-first (batch, d_model, time') encodings.
        return audio_signal.transpose(1, 2), length
| |
|