'''
LinCIR
Copyright (c) 2023-present NAVER Corp.
CC BY-NC-4.0 (https://creativecommons.org/licenses/by-nc/4.0/)
'''
import torch
from transformers import CLIPTextModelWithProjection


def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """
    Make causal mask used for bi-directional self-attention.
    Copy-paste from https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/models/clip/modeling_clip.py#L679-L693
    """
    bsz, tgt_len = input_ids_shape
    # Start fully masked (dtype minimum, effectively -inf after softmax), then
    # zero out the lower triangle so position i attends only to positions <= i.
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        # Cached positions from previous steps are always visible.
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
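

# ---------------------------------------------------------------------------
# Self-check sketch (not part of the original file): illustrates the mask
# built above under assumed toy sizes (batch 2, sequence length 4).
# ---------------------------------------------------------------------------
def _demo_causal_mask():
    mask = _make_causal_mask(torch.Size((2, 4)), torch.float32, torch.device('cpu'))
    assert mask.shape == (2, 1, 4, 4)
    # On/below the diagonal: 0 (attend); above it: dtype minimum (blocked).
    assert (mask[0, 0].tril() == 0).all()
    assert (mask[0, 0].triu(1) == torch.finfo(torch.float32).min).all()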


def encode_with_pseudo_tokens_HF(clip_model: CLIPTextModelWithProjection, text: torch.Tensor, pseudo_tokens: torch.Tensor,
                                 num_tokens=1, return_last_states=False) -> torch.Tensor:
    # num_tokens is unused here; the placeholder positions in `text` determine
    # where pseudo tokens are injected.
    x = clip_model.text_model.embeddings.token_embedding(text).type(clip_model.dtype)  # [batch_size, n_ctx, d_model]
    # Replace the embedding of every placeholder token (id 259, the '$' token
    # in the CLIP tokenizer) with the projected pseudo token.
    x = torch.where(text.unsqueeze(-1) == 259,
                    pseudo_tokens.unsqueeze(1).type(clip_model.dtype),
                    x)
    x = x + clip_model.text_model.embeddings.position_embedding(clip_model.text_model.embeddings.position_ids)
    _causal_attention_mask = _make_causal_mask(text.shape, x.dtype, device=x.device)
    x = clip_model.text_model.encoder(inputs_embeds=x,
                                      attention_mask=None,
                                      causal_attention_mask=_causal_attention_mask,
                                      output_attentions=False,
                                      output_hidden_states=False,
                                      return_dict=False)
    x = x[0]
    x_last = clip_model.text_model.final_layer_norm(x)
    # Pool at the EOT token, which has the highest token id in each sequence.
    x = x_last[torch.arange(x_last.shape[0], device=x_last.device),
               text.to(dtype=torch.int, device=x_last.device).argmax(dim=-1),
               ]
    if hasattr(clip_model, 'text_projection'):
        x = clip_model.text_projection(x)

    if return_last_states:
        return x, x_last
    else:
        return x
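

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file). The checkpoint name, caption,
# and random stand-in pseudo token are assumptions for illustration; LinCIR
# itself obtains the pseudo token by projecting CLIP image features through
# its learned phi network.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-base-patch32')
    text_encoder = CLIPTextModelWithProjection.from_pretrained('openai/clip-vit-base-patch32')

    # '$' marks where the pseudo token is injected (token id 259).
    tokens = tokenizer('a photo of $ on the beach', padding='max_length',
                       max_length=77, truncation=True, return_tensors='pt').input_ids

    # Stand-in pseudo token of shape [batch_size, hidden_size].
    pseudo = torch.randn(tokens.shape[0], text_encoder.config.hidden_size)

    with torch.no_grad():
        feats = encode_with_pseudo_tokens_HF(text_encoder, tokens, pseudo)
    print(feats.shape)  # torch.Size([1, 512]) for ViT-B/32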