# Copyright 2025 Yifan Yang
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
import torchaudio
from torchaudio.compliance.kaldi import fbank as torch_fbank
from transformers import PreTrainedModel, RobertaConfig, RobertaModel, RobertaTokenizer
from .configuration_clsp import CLSPConfig
from .zipformer2 import Conv2dSubsampling, Zipformer2
class CLSPModel(PreTrainedModel):
config_class = CLSPConfig
def __init__(self, config: CLSPConfig):
super().__init__(config)
self.model = get_model(config)
def forward(self, *args, **kwargs):
return self.model(*args, **kwargs)
def load_audio(self, audio_path):
return self.model.load_audio(audio_path)
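# CLSPModel is a thin `PreTrainedModel` wrapper around the CLAP dual encoder
# built by `get_model` below, so the standard Hugging Face `from_pretrained` /
# `save_pretrained` machinery applies. See the hedged sketch at the bottom of
# this file for a hypothetical end-to-end example.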
class MLPLayers(nn.Module):
def __init__(self, units=[512, 512, 512], nonlin=nn.ReLU(), dropout=0.1):
        super().__init__()
self.nonlin = nonlin
self.dropout = dropout
sequence = []
for u0, u1 in zip(units[:-1], units[1:]):
sequence.append(nn.Linear(u0, u1))
sequence.append(self.nonlin)
sequence.append(nn.Dropout(self.dropout))
sequence = sequence[:-2]
self.sequential = nn.Sequential(*sequence)
def forward(self, X):
X = self.sequential(X)
return X
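# For example, MLPLayers(units=[512, 512, 512]) builds
# Linear(512, 512) -> ReLU -> Dropout(0.1) -> Linear(512, 512); the trailing
# activation and dropout are trimmed by `sequence[:-2]`, so the block ends in a
# plain linear projection.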
class CLAP(nn.Module):
def __init__(
self,
encoder_embed: nn.Module,
encoder: nn.Module,
encoder_downsample: Optional[nn.Module] = None,
encoder_dim: int = 384,
text_encoder_dim: int = 768,
joint_dim: int = 512,
):
"""CLAP-style dual encoder model.
Args:
encoder_embed:
It is a Convolutional 2D subsampling module. It converts
an input of shape (N, T, idim) to an output of of shape
(N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
encoder:
It is the transcription network in the paper. Its accepts
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, encoder_dim) and
`logit_lens` of shape (N,).
"""
super().__init__()
# audio branch
self.encoder_embed = encoder_embed
self.encoder = encoder
self.encoder_downsample = encoder_downsample
self.audio_projection = nn.Sequential(
nn.Linear(encoder_dim, joint_dim),
nn.ReLU(),
nn.Linear(joint_dim, joint_dim),
)
self.audio_transform = MLPLayers(
units=[joint_dim, joint_dim, joint_dim], dropout=0.1
)
# text branch
self.text_tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
        self.text_encoder = RobertaModel(
            RobertaConfig.from_pretrained("roberta-base")
        )
self.text_projection = nn.Sequential(
nn.Linear(text_encoder_dim, joint_dim),
nn.ReLU(),
nn.Linear(joint_dim, joint_dim),
)
self.text_transform = MLPLayers(
units=[joint_dim, joint_dim, joint_dim], dropout=0.1
)
        # CLIP-style learnable temperature, initialized to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.full((), math.log(1 / 0.07)))
def _load_audio_single(self, audio_path: str) -> Tuple[torch.Tensor, int]:
waveform, sr = torchaudio.load(audio_path) # (channels, num_samples)
if waveform.size(0) > 1:
waveform = waveform.mean(dim=0, keepdim=True) # (1, num_samples)
if sr != 16000:
transform = torchaudio.transforms.Resample(sr, 16000)
waveform = transform(waveform)
waveform_len = waveform.shape[-1]
return waveform, waveform_len
def load_audio(self, audio_paths: list[str]) -> Tuple[torch.Tensor, torch.Tensor]:
assert isinstance(audio_paths, list), "Must receive a list of files for reading"
waveforms = []
waveform_lens = []
for audio in audio_paths:
wav, wav_len = self._load_audio_single(audio)
waveforms.append(wav.squeeze())
waveform_lens.append(wav_len)
waveforms = pad_sequence(waveforms, batch_first=True) # (N, T)
waveform_lens = torch.tensor(waveform_lens)
return waveforms, waveform_lens
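    # A hedged usage sketch (hypothetical file names): reading a batch of two
    # clips of different lengths returns right-padded mono 16 kHz waveforms and
    # their original sample counts, e.g.
    #
    #   wavs, wav_lens = model.load_audio(["dog_bark.wav", "rain.wav"])
    #   wavs.shape      # (2, max_num_samples)
    #   wav_lens        # tensor([num_samples_0, num_samples_1])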
def compute_fbank(
self, wavs: torch.Tensor, wav_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute fbank features
Args:
wavs (torch.Tensor): the mono-channel input waveform, (N, T)
wav_lens (torch.Tensor): the length of each waveform in samples (N)
Returns:
The fbank features, and their lengths
"""
assert wavs.ndim == 2, wavs.shape
        low_freq = 20.0
        high_freq = -400.0  # negative means an offset from the Nyquist frequency (8000 - 400 = 7600 Hz)
        dither = 0.0
        snip_edges = False
features = []
for i, wav in enumerate(wavs):
feat = torch_fbank(
wav[: wav_lens[i]].unsqueeze(0),
sample_frequency=16000, # this is fixed to 16000
num_mel_bins=128,
low_freq=low_freq,
                snip_edges=snip_edges,
high_freq=high_freq,
dither=dither,
energy_floor=1.0e-10,
)
features.append(feat)
feat_len = torch.tensor([f.shape[0] for f in features]).to(wavs.device)
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
).to(wavs.device)
return features, feat_len
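    # Note: torchaudio's Kaldi-compatible fbank uses 25 ms frames with a 10 ms
    # shift by default, so with snip_edges=False a 1-second 16 kHz clip yields
    # roughly 100 frames; `features` therefore has shape (N, T_frames, 128).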
    def forward_audio_encoder(
        self, x: torch.Tensor, x_lens: torch.Tensor, freeze_encoder: bool = False
    ) -> torch.Tensor:
        """Compute a single pooled audio embedding per utterance.

        Args:
          x:
            A 3-D tensor of shape (N, T, C) containing fbank features.
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in `x`
            before padding.
          freeze_encoder:
            If True, run the audio encoder without gradient tracking.
        Returns:
          A 2-D tensor of shape (N, C): the encoder output mean-pooled over
          non-padding frames.
        """
# logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
with torch.set_grad_enabled(not freeze_encoder):
x, x_lens = self.encoder_embed(x, x_lens)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = self.encoder(
x, x_lens, src_key_padding_mask
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
if self.encoder_downsample is not None:
encoder_out = encoder_out.permute(1, 0, 2)
encoder_out = self.encoder_downsample(encoder_out)
encoder_out = encoder_out.permute(1, 0, 2)
encoder_out_lens = (encoder_out_lens + 1) // 2
padding_mask = make_pad_mask(encoder_out_lens)
encoder_out = encoder_out.masked_fill(padding_mask.unsqueeze(-1), 0.0)
embedding = encoder_out.sum(dim=1) / encoder_out_lens.unsqueeze(-1) # (N, C)
return embedding
def forward_text_encoder(self, y: dict, freeze_encoder: bool = False):
with torch.set_grad_enabled(not freeze_encoder):
encoder_out = self.text_encoder(
input_ids=y["input_ids"],
attention_mask=y["attention_mask"],
)["pooler_output"]
return encoder_out
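    # The text embedding is RoBERTa's `pooler_output`, i.e. the first-token
    # hidden state passed through the model's dense + tanh pooling head; it is
    # projected into the joint space later in `forward`.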
def forward(
self,
audio: Optional[torch.Tensor] = None,
audio_lens: Optional[torch.Tensor] = None,
text: Optional[dict] = None,
freeze_audio_encoder: bool = False,
freeze_text_encoder: bool = False,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], torch.Tensor]:
        """
        Args:
          audio (torch.Tensor): Input audio waveforms (N, L).
          audio_lens (torch.Tensor): The length of each waveform in samples (N,).
          text: A list of N input strings.
        Returns:
          A tuple (audio_embedding, text_embedding, logit_scale). The embeddings
          are L2-normalized and of shape (N, joint_dim); each is None when the
          corresponding input is not provided.
        """
if audio is not None:
assert audio.ndim == 2, audio.shape
assert audio_lens.ndim == 1, audio_lens.shape
x, x_lens = self.compute_fbank(audio, audio_lens)
audio_encoder_out = self.forward_audio_encoder(
x, x_lens, freeze_encoder=freeze_audio_encoder
)
audio_encoder_out = self.audio_projection(audio_encoder_out)
audio_encoder_out = self.audio_transform(audio_encoder_out)
audio_encoder_out = F.normalize(audio_encoder_out, dim=-1)
if text is not None:
text = self.text_tokenizer(
text,
padding=True,
truncation=True,
return_tensors="pt",
)
text = {
k: v.to(device=next(self.parameters()).device) for k, v in text.items()
}
assert text["input_ids"].ndim == 2, text["input_ids"].shape
text_encoder_out = self.forward_text_encoder(
text, freeze_encoder=freeze_text_encoder
)
text_encoder_out = self.text_projection(text_encoder_out)
text_encoder_out = self.text_transform(text_encoder_out)
text_encoder_out = F.normalize(text_encoder_out, dim=-1)
return (
audio_encoder_out if audio is not None else None,
text_encoder_out if text is not None else None,
self.logit_scale.exp(),
)
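# Usage note: since both returned embeddings are L2-normalized, CLIP-style
# similarity logits can be formed directly from the forward outputs, e.g.
#
#   audio_emb, text_emb, logit_scale = model(audio=..., audio_lens=..., text=...)
#   logits_per_audio = logit_scale * audio_emb @ text_emb.t()  # (N_audio, N_text)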
def _to_int_tuple(s: str):
return tuple(map(int, s.split(",")))
def make_pad_mask(
lengths: torch.Tensor,
max_len: int = 0,
pad_left: bool = False,
) -> torch.Tensor:
"""
Args:
lengths:
A 1-D tensor containing sentence lengths.
max_len:
The length of masks.
pad_left:
If ``False`` (default), padding is on the right.
If ``True``, padding is on the left.
Returns:
Return a 2-D bool tensor, where masked positions
are filled with `True` and non-masked positions are
filled with `False`.
>>> lengths = torch.tensor([1, 3, 2, 5])
>>> make_pad_mask(lengths)
tensor([[False, True, True, True, True],
[False, False, False, True, True],
[False, False, True, True, True],
[False, False, False, False, False]])
"""
assert lengths.ndim == 1, lengths.ndim
max_len = max(max_len, lengths.max())
n = lengths.size(0)
seq_range = torch.arange(0, max_len, device=lengths.device)
expanded_lengths = seq_range.unsqueeze(0).expand(n, max_len)
if pad_left:
mask = expanded_lengths < (max_len - lengths).unsqueeze(1)
else:
mask = expanded_lengths >= lengths.unsqueeze(-1)
return mask
def get_encoder_embed(config: CLSPConfig) -> nn.Module:
encoder_embed = Conv2dSubsampling(
in_channels=config.feature_dim,
out_channels=_to_int_tuple(config.encoder_dim)[0],
)
return encoder_embed
def get_encoder_model(config: CLSPConfig) -> nn.Module:
encoder = Zipformer2(
output_downsampling_factor=config.output_downsampling_factor,
downsampling_factor=_to_int_tuple(config.downsampling_factor),
num_encoder_layers=_to_int_tuple(config.num_encoder_layers),
encoder_dim=_to_int_tuple(config.encoder_dim),
encoder_unmasked_dim=_to_int_tuple(config.encoder_unmasked_dim),
query_head_dim=_to_int_tuple(config.query_head_dim),
pos_head_dim=_to_int_tuple(config.pos_head_dim),
value_head_dim=_to_int_tuple(config.value_head_dim),
pos_dim=config.pos_dim,
num_heads=_to_int_tuple(config.num_heads),
feedforward_dim=_to_int_tuple(config.feedforward_dim),
cnn_module_kernel=_to_int_tuple(config.cnn_module_kernel),
causal=config.causal,
chunk_size=_to_int_tuple(config.chunk_size),
left_context_frames=_to_int_tuple(config.left_context_frames),
)
return encoder
def get_model(config: CLSPConfig) -> nn.Module:
encoder_embed = get_encoder_embed(config)
encoder = get_encoder_model(config)
model = CLAP(
encoder_embed=encoder_embed,
encoder=encoder,
encoder_dim=max(_to_int_tuple(config.encoder_dim)),
text_encoder_dim=config.text_encoder_dim,
joint_dim=config.joint_dim,
)
return model
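

# A minimal end-to-end sketch (not part of the library API). It assumes a
# hypothetical checkpoint directory "path/to/clsp" and locally available audio
# files; a pretrained RoBERTa-base tokenizer/encoder is downloaded on first use.
if __name__ == "__main__":
    model = CLSPModel.from_pretrained("path/to/clsp")  # hypothetical path
    model.eval()

    # Hypothetical inputs: two audio clips and two candidate captions.
    audio, audio_lens = model.load_audio(["dog_bark.wav", "rain.wav"])
    captions = ["a dog barking", "rain falling on a roof"]

    with torch.no_grad():
        audio_emb, text_emb, logit_scale = model(
            audio=audio, audio_lens=audio_lens, text=captions
        )

    # Both embeddings are L2-normalized, so scaled dot products give
    # CLIP-style similarity logits.
    logits_per_audio = logit_scale * audio_emb @ text_emb.t()  # (2, 2)
    print(logits_per_audio.softmax(dim=-1))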