# Copyright 2025 Yifan Yang # # See ../../../../LICENSE for clarification regarding multiple authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math from typing import Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F from torch.nn.utils.rnn import pad_sequence from torchaudio.compliance.kaldi import fbank as torch_fbank from transformers import PreTrainedModel, RobertaConfig, RobertaModel, RobertaTokenizer from .configuration_clsp import CLSPConfig from .zipformer2 import Conv2dSubsampling, Zipformer2 class CLSPModel(PreTrainedModel): config_class = CLSPConfig def __init__(self, config: CLSPConfig): super().__init__(config) self.model = get_model(config) def forward(self, *args, **kwargs): return self.model(*args, **kwargs) def load_audio(self, audio_path): return self.model.load_audio(audio_path) class MLPLayers(nn.Module): def __init__(self, units=[512, 512, 512], nonlin=nn.ReLU(), dropout=0.1): super(MLPLayers, self).__init__() self.nonlin = nonlin self.dropout = dropout sequence = [] for u0, u1 in zip(units[:-1], units[1:]): sequence.append(nn.Linear(u0, u1)) sequence.append(self.nonlin) sequence.append(nn.Dropout(self.dropout)) sequence = sequence[:-2] self.sequential = nn.Sequential(*sequence) def forward(self, X): X = self.sequential(X) return X class CLAP(nn.Module): def __init__( self, encoder_embed: nn.Module, encoder: nn.Module, encoder_downsample: Optional[nn.Module] = None, encoder_dim: int = 384, text_encoder_dim: int = 768, joint_dim: int = 512, ): """CLAP-style dual encoder model. Args: encoder_embed: It is a Convolutional 2D subsampling module. It converts an input of shape (N, T, idim) to an output of of shape (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2. encoder: It is the transcription network in the paper. Its accepts two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). It returns two tensors: `logits` of shape (N, T, encoder_dim) and `logit_lens` of shape (N,). """ super().__init__() # audio branch self.encoder_embed = encoder_embed self.encoder = encoder self.encoder_downsample = encoder_downsample self.audio_projection = nn.Sequential( nn.Linear(encoder_dim, joint_dim), nn.ReLU(), nn.Linear(joint_dim, joint_dim), ) self.audio_transform = MLPLayers( units=[joint_dim, joint_dim, joint_dim], dropout=0.1 ) # text branch self.text_tokenizer = RobertaTokenizer.from_pretrained("roberta-base") self.text_encoder = text_encoder = RobertaModel( RobertaConfig.from_pretrained("roberta-base") ) self.text_projection = nn.Sequential( nn.Linear(text_encoder_dim, joint_dim), nn.ReLU(), nn.Linear(joint_dim, joint_dim), ) self.text_transform = MLPLayers( units=[joint_dim, joint_dim, joint_dim], dropout=0.1 ) self.logit_scale = nn.Parameter(torch.full((), math.log(1 / 0.07))) def _load_audio_single(self, audio_path: str) -> Tuple[torch.Tensor, int]: waveform, sr = torchaudio.load(audio_path) # (channels, num_samples) if waveform.size(0) > 1: waveform = waveform.mean(dim=0, keepdim=True) # (1, num_samples) if sr != 16000: transform = torchaudio.transforms.Resample(sr, 16000) waveform = transform(waveform) waveform_len = waveform.shape[-1] return waveform, waveform_len def load_audio(self, audio_paths: list[str]) -> Tuple[torch.Tensor, torch.Tensor]: assert isinstance(audio_paths, list), "Must receive a list of files for reading" waveforms = [] waveform_lens = [] for audio in audio_paths: wav, wav_len = self._load_audio_single(audio) waveforms.append(wav.squeeze()) waveform_lens.append(wav_len) waveforms = pad_sequence(waveforms, batch_first=True) # (N, T) waveform_lens = torch.tensor(waveform_lens) return waveforms, waveform_lens def compute_fbank( self, wavs: torch.Tensor, wav_lens: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: """Compute fbank features Args: wavs (torch.Tensor): the mono-channel input waveform, (N, T) wav_lens (torch.Tensor): the length of each waveform in samples (N) Returns: The fbank features, and their lengths """ assert wavs.ndim == 2, wavs.shape low_freq = 20.0 high_freq = -400.0 dither = 0.0 snip_egdes = False features = [] for i, wav in enumerate(wavs): feat = torch_fbank( wav[: wav_lens[i]].unsqueeze(0), sample_frequency=16000, # this is fixed to 16000 num_mel_bins=128, low_freq=low_freq, snip_edges=snip_egdes, high_freq=high_freq, dither=dither, energy_floor=1.0e-10, ) features.append(feat) feat_len = torch.tensor([f.shape[0] for f in features]).to(wavs.device) features = pad_sequence( features, batch_first=True, padding_value=math.log(1e-10) ).to(wavs.device) return features, feat_len def forward_audio_encoder( self, x: torch.Tensor, x_lens: torch.Tensor, freeze_encoder: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: """Compute audio encoder outputs. Args: x: A 3-D tensor of shape (N, T, C). x_lens: A 1-D tensor of shape (N,). It contains the number of frames in `x` before padding. Returns: encoder_out: Encoder output, of shape (N, T, C). encoder_out_lens: Encoder output lengths, of shape (N,). """ # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M") with torch.set_grad_enabled(not freeze_encoder): x, x_lens = self.encoder_embed(x, x_lens) src_key_padding_mask = make_pad_mask(x_lens) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) encoder_out, encoder_out_lens = self.encoder( x, x_lens, src_key_padding_mask ) encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) if self.encoder_downsample is not None: encoder_out = encoder_out.permute(1, 0, 2) encoder_out = self.encoder_downsample(encoder_out) encoder_out = encoder_out.permute(1, 0, 2) encoder_out_lens = (encoder_out_lens + 1) // 2 padding_mask = make_pad_mask(encoder_out_lens) encoder_out = encoder_out.masked_fill(padding_mask.unsqueeze(-1), 0.0) embedding = encoder_out.sum(dim=1) / encoder_out_lens.unsqueeze(-1) # (N, C) return embedding def forward_text_encoder(self, y: dict, freeze_encoder: bool = False): with torch.set_grad_enabled(not freeze_encoder): encoder_out = self.text_encoder( input_ids=y["input_ids"], attention_mask=y["attention_mask"], )["pooler_output"] return encoder_out def forward( self, audio: Optional[torch.Tensor] = None, audio_lens: Optional[torch.Tensor] = None, text: Optional[dict] = None, freeze_audio_encoder: bool = False, freeze_text_encoder: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: audio (torch.Tensor): Input audio waveforms (N, L). audio_lens (torch.Tensor): The length of the audio waveforms (N). text: Input text list (N). Returns: The encoded representations and logit scale. """ if audio is not None: assert audio.ndim == 2, audio.shape assert audio_lens.ndim == 1, audio_lens.shape x, x_lens = self.compute_fbank(audio, audio_lens) audio_encoder_out = self.forward_audio_encoder( x, x_lens, freeze_encoder=freeze_audio_encoder ) audio_encoder_out = self.audio_projection(audio_encoder_out) audio_encoder_out = self.audio_transform(audio_encoder_out) audio_encoder_out = F.normalize(audio_encoder_out, dim=-1) if text is not None: text = self.text_tokenizer( text, padding=True, truncation=True, return_tensors="pt", ) text = { k: v.to(device=next(self.parameters()).device) for k, v in text.items() } assert text["input_ids"].ndim == 2, text["input_ids"].shape text_encoder_out = self.forward_text_encoder( text, freeze_encoder=freeze_text_encoder ) text_encoder_out = self.text_projection(text_encoder_out) text_encoder_out = self.text_transform(text_encoder_out) text_encoder_out = F.normalize(text_encoder_out, dim=-1) return ( audio_encoder_out if audio is not None else None, text_encoder_out if text is not None else None, self.logit_scale.exp(), ) def _to_int_tuple(s: str): return tuple(map(int, s.split(","))) def make_pad_mask( lengths: torch.Tensor, max_len: int = 0, pad_left: bool = False, ) -> torch.Tensor: """ Args: lengths: A 1-D tensor containing sentence lengths. max_len: The length of masks. pad_left: If ``False`` (default), padding is on the right. If ``True``, padding is on the left. Returns: Return a 2-D bool tensor, where masked positions are filled with `True` and non-masked positions are filled with `False`. >>> lengths = torch.tensor([1, 3, 2, 5]) >>> make_pad_mask(lengths) tensor([[False, True, True, True, True], [False, False, False, True, True], [False, False, True, True, True], [False, False, False, False, False]]) """ assert lengths.ndim == 1, lengths.ndim max_len = max(max_len, lengths.max()) n = lengths.size(0) seq_range = torch.arange(0, max_len, device=lengths.device) expanded_lengths = seq_range.unsqueeze(0).expand(n, max_len) if pad_left: mask = expanded_lengths < (max_len - lengths).unsqueeze(1) else: mask = expanded_lengths >= lengths.unsqueeze(-1) return mask def get_encoder_embed(config: CLSPConfig) -> nn.Module: encoder_embed = Conv2dSubsampling( in_channels=config.feature_dim, out_channels=_to_int_tuple(config.encoder_dim)[0], ) return encoder_embed def get_encoder_model(config: CLSPConfig) -> nn.Module: encoder = Zipformer2( output_downsampling_factor=config.output_downsampling_factor, downsampling_factor=_to_int_tuple(config.downsampling_factor), num_encoder_layers=_to_int_tuple(config.num_encoder_layers), encoder_dim=_to_int_tuple(config.encoder_dim), encoder_unmasked_dim=_to_int_tuple(config.encoder_unmasked_dim), query_head_dim=_to_int_tuple(config.query_head_dim), pos_head_dim=_to_int_tuple(config.pos_head_dim), value_head_dim=_to_int_tuple(config.value_head_dim), pos_dim=config.pos_dim, num_heads=_to_int_tuple(config.num_heads), feedforward_dim=_to_int_tuple(config.feedforward_dim), cnn_module_kernel=_to_int_tuple(config.cnn_module_kernel), causal=config.causal, chunk_size=_to_int_tuple(config.chunk_size), left_context_frames=_to_int_tuple(config.left_context_frames), ) return encoder def get_model(config: CLSPConfig) -> nn.Module: encoder_embed = get_encoder_embed(config) encoder = get_encoder_model(config) model = CLAP( encoder_embed=encoder_embed, encoder=encoder, encoder_dim=max(_to_int_tuple(config.encoder_dim)), text_encoder_dim=config.text_encoder_dim, joint_dim=config.joint_dim, ) return model