| from transformers import Gemma3nAudioEncoder, Gemma3nConfig | |
| from transformers import AutoFeatureExtractor, PreTrainedModel | |
| from transformers.models.gemma3n.modeling_gemma3n import Gemma3nMultimodalEmbedder | |
class Audio(PreTrainedModel):
    """Container for the Gemma 3n audio encoder plus its projection module.

    Holds two submodules:
      * ``audio_tower`` -- a ``Gemma3nAudioEncoder`` built from the audio sub-config.
      * ``embed_audio`` -- a ``Gemma3nMultimodalEmbedder`` built from the audio and
        text sub-configs.

    NOTE(review): the attribute names presumably mirror the full Gemma 3n model so
    that checkpoint parameter names line up -- confirm against the checkpoint.
    """

    def __init__(self, config):
        """Build both audio submodules.

        Args:
            config: composite config exposing ``audio_config`` and ``text_config``
                (presumably a ``Gemma3nConfig`` -- verify against callers).
        """
        super(Audio, self).__init__(config)
        audio_cfg = config.audio_config
        # Encoder is created before the projector, matching the original
        # submodule registration order.
        self.audio_tower = Gemma3nAudioEncoder(audio_cfg)
        self.embed_audio = Gemma3nMultimodalEmbedder(audio_cfg, config.text_config)
class GemmaAudio(PreTrainedModel):
    """Audio-only Gemma 3n model: runs the audio encoder, then projects its output.

    The submodules live under ``self.model`` (an ``Audio`` instance).
    NOTE(review): the nesting presumably makes parameter names match the
    ``model.audio_tower.*`` / ``model.embed_audio.*`` keys of a full Gemma 3n
    checkpoint -- confirm.
    """

    config_class = Gemma3nConfig

    def __init__(self, config):
        """Create the wrapped ``Audio`` container from ``config``."""
        super().__init__(config)
        self.model = Audio(config)

    def forward(self, input_features, input_features_mask=None, **kwargs):
        """Encode audio features and project the encodings with ``embed_audio``.

        Args:
            input_features: tensor passed straight to ``audio_tower``
                (assumed ``(batch, frames, mel_bins)`` -- TODO confirm).
            input_features_mask: mask over ``input_features``. It is coerced to
                ``bool`` and logically inverted before reaching ``audio_tower``.
                Presumably ``True`` marks valid frames (the feature extractor's
                convention) -- confirm. When ``None``, every frame is treated as
                valid.
            **kwargs: accepted for call-site compatibility and ignored.

        Returns:
            tuple: ``(projected, encoder_mask)``. ``projected`` is the output of
            ``embed_audio`` applied to the encoder's first output, and
            ``encoder_mask`` is the encoder's second output (presumably its
            subsampled padding mask -- TODO confirm).
        """
        import torch

        if input_features_mask is None:
            # No mask given: mark every frame valid. The shape is taken from the
            # first two dims of the features, assuming (batch, frames, ...).
            input_features_mask = torch.ones(
                input_features.shape[:2],
                dtype=torch.bool,
                device=input_features.device,
            )

        # ``~`` is a logical NOT only on bool tensors. On 0/1 integer masks it is
        # a bitwise NOT (1 -> -2, 0 -> -1), so coerce to bool first. For a mask
        # that is already bool, ``.bool()`` returns it unchanged.
        encoder_input_mask = ~input_features_mask.bool()

        output = self.model.audio_tower(input_features, encoder_input_mask)
        project = self.model.embed_audio(inputs_embeds=output[0])
        return project, output[1]