import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from torchaudio.compliance.kaldi import fbank as torch_fbank
from transformers import PreTrainedModel, RobertaConfig, RobertaModel, RobertaTokenizer

from .configuration_clsp import CLSPConfig
from .zipformer2 import Conv2dSubsampling, Zipformer2

class CLSPModel(PreTrainedModel):
    """Hugging Face wrapper around the CLAP dual-encoder model defined below."""

    config_class = CLSPConfig

    def __init__(self, config: CLSPConfig):
        super().__init__(config)
        self.model = get_model(config)

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def load_audio(self, audio_path):
        return self.model.load_audio(audio_path)

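# A minimal loading sketch for the wrapper above (the checkpoint path is
# hypothetical; it assumes a directory previously written with `save_pretrained`):
#
#   model = CLSPModel.from_pretrained("path/to/clsp_checkpoint")
#   wavs, wav_lens = model.load_audio(["dog_bark.wav", "rain.wav"])
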
class MLPLayers(nn.Module):
    def __init__(self, units=[512, 512, 512], nonlin=nn.ReLU(), dropout=0.1):
        super().__init__()
        self.nonlin = nonlin
        self.dropout = dropout

        sequence = []
        for u0, u1 in zip(units[:-1], units[1:]):
            sequence.append(nn.Linear(u0, u1))
            sequence.append(self.nonlin)
            sequence.append(nn.Dropout(self.dropout))
        # Drop the trailing nonlinearity and dropout after the last Linear.
        sequence = sequence[:-2]

        self.sequential = nn.Sequential(*sequence)

    def forward(self, X):
        X = self.sequential(X)
        return X

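# A quick sanity check of the default stack (for illustration only):
#
#   mlp = MLPLayers(units=[512, 512, 512])
#   # -> Linear(512, 512) -> ReLU -> Dropout(0.1) -> Linear(512, 512)
#   assert mlp(torch.randn(4, 512)).shape == (4, 512)
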
class CLAP(nn.Module):
    def __init__(
        self,
        encoder_embed: nn.Module,
        encoder: nn.Module,
        encoder_downsample: Optional[nn.Module] = None,
        encoder_dim: int = 384,
        text_encoder_dim: int = 768,
        joint_dim: int = 512,
    ):
        """CLAP-style dual-encoder model.

        Args:
          encoder_embed:
            A convolutional 2-D subsampling module. It converts an input of
            shape (N, T, idim) to an output of shape (N, T', odim), where
            T' = (T-3)//2 - 2 = (T-7)//2.
          encoder:
            The audio encoder (a Zipformer2 in this recipe). It accepts two
            inputs, `x` of shape (N, T, encoder_dim) and `x_lens` of shape
            (N,), and returns `encoder_out` of shape (N, T, encoder_dim) and
            `encoder_out_lens` of shape (N,).
          encoder_downsample:
            An optional module that further downsamples the encoder output
            along the time axis by a factor of 2.
          encoder_dim:
            Output dimension of the audio encoder.
          text_encoder_dim:
            Hidden size of the RoBERTa text encoder.
          joint_dim:
            Dimension of the shared audio-text embedding space.
        """
        super().__init__()

        self.encoder_embed = encoder_embed
        self.encoder = encoder
        self.encoder_downsample = encoder_downsample
        self.audio_projection = nn.Sequential(
            nn.Linear(encoder_dim, joint_dim),
            nn.ReLU(),
            nn.Linear(joint_dim, joint_dim),
        )
        self.audio_transform = MLPLayers(
            units=[joint_dim, joint_dim, joint_dim], dropout=0.1
        )

        self.text_tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
        self.text_encoder = RobertaModel(
            RobertaConfig.from_pretrained("roberta-base")
        )
        self.text_projection = nn.Sequential(
            nn.Linear(text_encoder_dim, joint_dim),
            nn.ReLU(),
            nn.Linear(joint_dim, joint_dim),
        )
        self.text_transform = MLPLayers(
            units=[joint_dim, joint_dim, joint_dim], dropout=0.1
        )

        # Learnable temperature, initialized to 1/0.07 as in CLIP.
        self.logit_scale = nn.Parameter(torch.full((), math.log(1 / 0.07)))

    def _load_audio_single(self, audio_path: str) -> Tuple[torch.Tensor, int]:
        """Load one audio file, mix it down to mono and resample to 16 kHz."""
        waveform, sr = torchaudio.load(audio_path)
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        if sr != 16000:
            transform = torchaudio.transforms.Resample(sr, 16000)
            waveform = transform(waveform)
        waveform_len = waveform.shape[-1]
        return waveform, waveform_len

    def load_audio(self, audio_paths: list[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load a batch of audio files and right-pad them to a common length."""
        assert isinstance(audio_paths, list), "Must receive a list of files for reading"
        waveforms = []
        waveform_lens = []
        for audio in audio_paths:
            wav, wav_len = self._load_audio_single(audio)
            waveforms.append(wav.squeeze())
            waveform_lens.append(wav_len)

        waveforms = pad_sequence(waveforms, batch_first=True)
        waveform_lens = torch.tensor(waveform_lens)
        return waveforms, waveform_lens

    def compute_fbank(
        self, wavs: torch.Tensor, wav_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute fbank features.

        Args:
          wavs (torch.Tensor): the mono-channel input waveforms, (N, T).
          wav_lens (torch.Tensor): the length of each waveform in samples, (N,).
        Returns:
          The fbank features and their lengths in frames.
        """
        assert wavs.ndim == 2, wavs.shape
        low_freq = 20.0
        high_freq = -400.0
        dither = 0.0
        snip_edges = False

        features = []
        for i, wav in enumerate(wavs):
            feat = torch_fbank(
                wav[: wav_lens[i]].unsqueeze(0),
                sample_frequency=16000,
                num_mel_bins=128,
                low_freq=low_freq,
                snip_edges=snip_edges,
                high_freq=high_freq,
                dither=dither,
                energy_floor=1.0e-10,
            )
            features.append(feat)
        feat_len = torch.tensor([f.shape[0] for f in features]).to(wavs.device)
        features = pad_sequence(
            features, batch_first=True, padding_value=math.log(1e-10)
        ).to(wavs.device)
        return features, feat_len

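    # Rough shape check (assuming torchaudio's Kaldi fbank defaults of a 25 ms
    # window and 10 ms hop): one second of 16 kHz audio, i.e. wavs of shape
    # (N, 16000), yields features of roughly (N, 100, 128) with snip_edges=False.
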
    def forward_audio_encoder(
        self, x: torch.Tensor, x_lens: torch.Tensor, freeze_encoder: bool = False
    ) -> torch.Tensor:
        """Compute a pooled audio embedding from fbank features.

        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in `x`
            before padding.
          freeze_encoder:
            If True, run the audio encoder without gradient tracking.

        Returns:
          A 2-D tensor of shape (N, C): the encoder output averaged over the
          non-padded frames of each utterance.
        """
        with torch.set_grad_enabled(not freeze_encoder):
            x, x_lens = self.encoder_embed(x, x_lens)
            src_key_padding_mask = make_pad_mask(x_lens)
            x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
            encoder_out, encoder_out_lens = self.encoder(
                x, x_lens, src_key_padding_mask
            )
            encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)

        assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)

        if self.encoder_downsample is not None:
            encoder_out = encoder_out.permute(1, 0, 2)
            encoder_out = self.encoder_downsample(encoder_out)
            encoder_out = encoder_out.permute(1, 0, 2)
            encoder_out_lens = (encoder_out_lens + 1) // 2

        # Masked mean-pooling over time.
        padding_mask = make_pad_mask(encoder_out_lens)
        encoder_out = encoder_out.masked_fill(padding_mask.unsqueeze(-1), 0.0)
        embedding = encoder_out.sum(dim=1) / encoder_out_lens.unsqueeze(-1)

        return embedding

    def forward_text_encoder(self, y: dict, freeze_encoder: bool = False):
        """Encode tokenized text with RoBERTa and return its pooled output."""
        with torch.set_grad_enabled(not freeze_encoder):
            encoder_out = self.text_encoder(
                input_ids=y["input_ids"],
                attention_mask=y["attention_mask"],
            )["pooler_output"]

        return encoder_out

    def forward(
        self,
        audio: Optional[torch.Tensor] = None,
        audio_lens: Optional[torch.Tensor] = None,
        text: Optional[list[str]] = None,
        freeze_audio_encoder: bool = False,
        freeze_text_encoder: bool = False,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], torch.Tensor]:
        """
        Args:
          audio (torch.Tensor): Input audio waveforms, (N, L).
          audio_lens (torch.Tensor): The length of each waveform in samples, (N,).
          text: A list of N caption strings.
        Returns:
          The L2-normalized audio and text embeddings (None for any modality
          that was not provided) and the exponentiated logit scale.
        """
        if audio is not None:
            assert audio.ndim == 2, audio.shape
            assert audio_lens.ndim == 1, audio_lens.shape
            x, x_lens = self.compute_fbank(audio, audio_lens)
            audio_encoder_out = self.forward_audio_encoder(
                x, x_lens, freeze_encoder=freeze_audio_encoder
            )
            audio_encoder_out = self.audio_projection(audio_encoder_out)
            audio_encoder_out = self.audio_transform(audio_encoder_out)
            audio_encoder_out = F.normalize(audio_encoder_out, dim=-1)

        if text is not None:
            text = self.text_tokenizer(
                text,
                padding=True,
                truncation=True,
                return_tensors="pt",
            )
            text = {
                k: v.to(device=next(self.parameters()).device) for k, v in text.items()
            }
            assert text["input_ids"].ndim == 2, text["input_ids"].shape
            text_encoder_out = self.forward_text_encoder(
                text, freeze_encoder=freeze_text_encoder
            )
            text_encoder_out = self.text_projection(text_encoder_out)
            text_encoder_out = self.text_transform(text_encoder_out)
            text_encoder_out = F.normalize(text_encoder_out, dim=-1)

        return (
            audio_encoder_out if audio is not None else None,
            text_encoder_out if text is not None else None,
            self.logit_scale.exp(),
        )

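# A minimal retrieval sketch using the embeddings returned by CLAP.forward()
# (both embeddings are L2-normalized tensors of shape (N, joint_dim), so their
# scaled dot products are the contrastive logits; `model`, `wavs`, `wav_lens`
# and `captions` are placeholders):
#
#   audio_emb, text_emb, scale = model(audio=wavs, audio_lens=wav_lens, text=captions)
#   similarity = scale * audio_emb @ text_emb.t()   # (N_audio, N_text)
#   best_text = similarity.argmax(dim=-1)           # best caption per audio clip
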
def _to_int_tuple(s: str):
    return tuple(map(int, s.split(",")))

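# For example: _to_int_tuple("2,4,8") == (2, 4, 8).
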
def make_pad_mask(
    lengths: torch.Tensor,
    max_len: int = 0,
    pad_left: bool = False,
) -> torch.Tensor:
    """
    Args:
      lengths:
        A 1-D tensor containing sentence lengths.
      max_len:
        The length of masks.
      pad_left:
        If ``False`` (default), padding is on the right.
        If ``True``, padding is on the left.
    Returns:
      Return a 2-D bool tensor, where masked positions
      are filled with `True` and non-masked positions are
      filled with `False`.

    >>> lengths = torch.tensor([1, 3, 2, 5])
    >>> make_pad_mask(lengths)
    tensor([[False, True, True, True, True],
            [False, False, False, True, True],
            [False, False, True, True, True],
            [False, False, False, False, False]])
    """
    assert lengths.ndim == 1, lengths.ndim
    max_len = max(max_len, lengths.max())
    n = lengths.size(0)
    seq_range = torch.arange(0, max_len, device=lengths.device)
    expanded_lengths = seq_range.unsqueeze(0).expand(n, max_len)

    if pad_left:
        mask = expanded_lengths < (max_len - lengths).unsqueeze(1)
    else:
        mask = expanded_lengths >= lengths.unsqueeze(-1)

    return mask

def get_encoder_embed(config: CLSPConfig) -> nn.Module:
    encoder_embed = Conv2dSubsampling(
        in_channels=config.feature_dim,
        out_channels=_to_int_tuple(config.encoder_dim)[0],
    )
    return encoder_embed

def get_encoder_model(config: CLSPConfig) -> nn.Module:
    encoder = Zipformer2(
        output_downsampling_factor=config.output_downsampling_factor,
        downsampling_factor=_to_int_tuple(config.downsampling_factor),
        num_encoder_layers=_to_int_tuple(config.num_encoder_layers),
        encoder_dim=_to_int_tuple(config.encoder_dim),
        encoder_unmasked_dim=_to_int_tuple(config.encoder_unmasked_dim),
        query_head_dim=_to_int_tuple(config.query_head_dim),
        pos_head_dim=_to_int_tuple(config.pos_head_dim),
        value_head_dim=_to_int_tuple(config.value_head_dim),
        pos_dim=config.pos_dim,
        num_heads=_to_int_tuple(config.num_heads),
        feedforward_dim=_to_int_tuple(config.feedforward_dim),
        cnn_module_kernel=_to_int_tuple(config.cnn_module_kernel),
        causal=config.causal,
        chunk_size=_to_int_tuple(config.chunk_size),
        left_context_frames=_to_int_tuple(config.left_context_frames),
    )
    return encoder

def get_model(config: CLSPConfig) -> nn.Module:
    encoder_embed = get_encoder_embed(config)
    encoder = get_encoder_model(config)
    model = CLAP(
        encoder_embed=encoder_embed,
        encoder=encoder,
        encoder_dim=max(_to_int_tuple(config.encoder_dim)),
        text_encoder_dim=config.text_encoder_dim,
        joint_dim=config.joint_dim,
    )
    return model

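# A minimal end-to-end sketch, assuming `CLSPConfig()` provides usable default
# hyper-parameters (the audio file names are hypothetical):
#
#   config = CLSPConfig()
#   model = get_model(config).eval()
#   wavs, wav_lens = model.load_audio(["dog_bark.wav", "rain.wav"])
#   with torch.no_grad():
#       audio_emb, text_emb, scale = model(
#           audio=wavs, audio_lens=wav_lens, text=["a dog barking", "rain falling"]
#       )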