from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm import create_model
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedModel,
)
from transformers.utils import ModelOutput

from .location_encoder import LocationEncoder
| |
|
| |
|
class CLOSPConfig(PretrainedConfig):
    """
    Configuration for :class:`CLOSPModel`.

    Args:
        vision_model_key: timm model key used for both image backbones.
        s1_embedding_dim: feature size produced by the Sentinel-1 encoder.
        s2_embedding_dim: feature size produced by the Sentinel-2 encoder.
        s1_head_dim: ``num_classes`` passed to timm for the S1 encoder
            (0 keeps timm's pooled-feature head).
        s2_head_dim: ``num_classes`` passed to timm for the S2 encoder.
        text_model_name_or_path: Hugging Face id of the text backbone.
        use_location_encoder: whether to build the coordinate branch.
        location_embedding_dim: output size of the location encoder.
        projection_dim: dimensionality of the shared embedding space.
    """

    model_type = "closp"

    def __init__(
        self,
        vision_model_key: str = "vit-s",
        s1_embedding_dim: int = 384,
        s2_embedding_dim: int = 384,
        s1_head_dim: int = 0,
        s2_head_dim: int = 0,
        text_model_name_or_path: str = "distilbert-base-uncased",
        use_location_encoder: bool = True,
        location_embedding_dim: int = 512,
        projection_dim: int = 768,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Vision branch.
        self.vision_model_key = vision_model_key
        self.s1_embedding_dim = s1_embedding_dim
        self.s1_head_dim = s1_head_dim
        self.s2_embedding_dim = s2_embedding_dim
        self.s2_head_dim = s2_head_dim
        # Text branch.
        self.text_model_name_or_path = text_model_name_or_path
        # Optional location branch.
        self.use_location_encoder = use_location_encoder
        self.location_embedding_dim = location_embedding_dim
        # Shared projection space.
        self.projection_dim = projection_dim
| |
|
| |
|
| | |
@dataclass
class CLOSPOutput(ModelOutput):
    """
    Output container for :class:`CLOSPModel`'s forward pass.

    Location-related fields are ``None`` when the model was built with
    ``use_location_encoder=False``.
    """

    # Mean contrastive cross-entropy over the available logit matrices;
    # None unless forward() was called with return_loss=True.
    loss: torch.FloatTensor = None
    # Cosine-similarity logits, image rows vs. text columns.
    logits_per_image: torch.FloatTensor = None
    # Transpose of logits_per_image (text rows vs. image columns).
    logits_per_text: torch.FloatTensor = None
    # Location rows vs. image columns.
    logits_per_loc_img: torch.FloatTensor = None
    # Image rows vs. location columns.
    logits_per_img_loc: torch.FloatTensor = None
    # L2-normalized image embeddings in the shared projection space.
    image_embeds: torch.FloatTensor = None
    # L2-normalized text embeddings (first-token pooled).
    text_embeds: torch.FloatTensor = None
    # L2-normalized location embeddings.
    location_embeds: torch.FloatTensor = None
| |
|
| |
|
class CLOSPModel(PreTrainedModel):
    """
    CLIP-style multi-modal model aligning Sentinel-1 (2-channel SAR) and
    Sentinel-2 (13-band multispectral) imagery with text and, optionally,
    geographic coordinates in a single shared embedding space.

    Every modality is projected to ``config.projection_dim`` and
    L2-normalized, so all logit matrices contain cosine similarities.
    """

    config_class = CLOSPConfig

    def __init__(self, config: CLOSPConfig):
        super().__init__(config)

        # Two separate timm vision backbones: Sentinel-1 inputs have 2
        # channels, Sentinel-2 inputs have 13. A head dim of 0 makes timm
        # return pooled backbone features instead of classification logits.
        self.s1_encoder = create_model(
            config.vision_model_key,
            in_chans=2,
            num_classes=config.s1_head_dim,
            pretrained=False,
        )
        self.s2_encoder = create_model(
            config.vision_model_key,
            in_chans=13,
            num_classes=config.s2_head_dim,
            pretrained=False,
        )
        self.s1_projection = nn.Linear(config.s1_embedding_dim, config.projection_dim)
        self.s2_projection = nn.Linear(config.s2_embedding_dim, config.projection_dim)

        # Text backbone is built from its config only (randomly initialized);
        # the tokenizer is still resolved by name from the hub/cache.
        self.text_model = AutoModel.from_config(
            AutoConfig.from_pretrained(config.text_model_name_or_path)
        )
        self.tokenizer = AutoTokenizer.from_pretrained(config.text_model_name_or_path)

        # Optional coordinate branch.
        # NOTE(review): LocationEncoder hyper-parameters are hard-coded here;
        # its output width must equal config.location_embedding_dim — confirm.
        if config.use_location_encoder:
            self.location_encoder = LocationEncoder(512, 2, 256, 10)
            self.location_projection = nn.Linear(
                config.location_embedding_dim, config.projection_dim
            )

    def tokenize_text(self, text: str):
        """Tokenize ``text`` to fixed-length ``pt`` tensors for the text model."""
        return self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )

    def get_image_features(self, image: torch.Tensor) -> torch.Tensor:
        """
        Encode a batch of images into L2-normalized embeddings.

        The sensor is inferred from the channel count (dim 1): 2 channels are
        routed to the Sentinel-1 encoder, anything else to the Sentinel-2
        encoder (which expects 13 channels).
        """
        image = image.float()
        if image.shape[1] == 2:
            image_features = self.s1_projection(self.s1_encoder(image))
        else:
            image_features = self.s2_projection(self.s2_encoder(image))

        return F.normalize(image_features, p=2, dim=-1)

    def get_text_features(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Encode tokenized text into L2-normalized embeddings (first-token pooling)."""
        # Fix: the original requested output_hidden_states=True but never read
        # the per-layer hidden states, wasting memory on every forward pass.
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        # First token of the last hidden layer as the sentence embedding.
        text_features = text_outputs.last_hidden_state[:, 0, :]
        return F.normalize(text_features, p=2, dim=-1)

    def get_location_features(self, coords: torch.Tensor) -> torch.Tensor:
        """
        Encode coordinates into L2-normalized embeddings.

        Raises:
            ValueError: if the model was built without a location encoder.
        """
        if not self.config.use_location_encoder:
            raise ValueError(
                "Location encoder is not enabled for this model. Set `use_location_encoder=True` in config."
            )
        location_features = self.location_encoder(coords)
        location_features = self.location_projection(location_features)
        return F.normalize(location_features, p=2, dim=-1)

    def forward(
        self,
        image: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        coords: Optional[torch.Tensor] = None,
        return_loss: bool = False,
    ) -> CLOSPOutput:
        """
        Compute cross-modal similarity logits and, optionally, a contrastive loss.

        Args:
            image: batch of S1 (2-channel) or S2 (13-channel) images.
            input_ids: tokenized captions, aligned index-wise with ``image``.
            attention_mask: attention mask matching ``input_ids``.
            coords: coordinates per sample; required when the model was built
                with ``use_location_encoder=True``, ignored otherwise.
            return_loss: when True, also compute the mean symmetric
                cross-entropy over all available logit matrices.

        Raises:
            ValueError: if ``coords`` is missing while the location branch is enabled.
        """
        image_embeds = self.get_image_features(image)
        text_embeds = self.get_text_features(input_ids, attention_mask)

        # NOTE(review): logits are raw cosine similarities — there is no
        # learnable temperature (CLIP's logit_scale). Confirm this is
        # intentional before training with the contrastive loss below.
        logits_per_image = image_embeds @ text_embeds.T
        logits_per_text = logits_per_image.T

        location_embeds = None
        logits_per_loc_img = None
        logits_per_img_loc = None

        if self.config.use_location_encoder:
            if coords is None:
                raise ValueError(
                    "Coordinates must be provided when use_location_encoder is True."
                )
            location_embeds = self.get_location_features(coords)
            logits_per_loc_img = location_embeds @ image_embeds.T
            logits_per_img_loc = image_embeds @ location_embeds.T

        loss = None
        if return_loss:
            # Symmetric InfoNCE: the i-th image matches the i-th text/coords,
            # so the target for every logit matrix is the diagonal.
            candidates = [
                logits_per_image,
                logits_per_text,
                logits_per_loc_img,
                logits_per_img_loc,
            ]
            ground_truth = torch.arange(len(input_ids)).to(self.device)
            losses = [
                F.cross_entropy(logits, ground_truth)
                for logits in candidates
                if logits is not None
            ]
            loss = sum(losses) / len(losses)

        return CLOSPOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            logits_per_loc_img=logits_per_loc_img,
            logits_per_img_loc=logits_per_img_loc,
            image_embeds=image_embeds,
            text_embeds=text_embeds,
            location_embeds=location_embeds,
        )
| |
|