""" SafeQwen2.5-VL Model Implementation SafeQwen2.5-VL extends Qwen2.5-VL with multimodal safety classification capabilities. It adds a safety classification head that operates on pooled image features to identify potentially unsafe content across 20 safety categories. Key features: - Non-invasive architecture: Uses standard Qwen2.5-VL forward pass - Post-processing safety classification on image features - Simple pooling strategy for feature aggregation - Full gradient flow compatibility for training Author: SafeQwen Team """ import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional, Tuple, List, Union from dataclasses import dataclass from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.models.qwen2_5_vl import ( Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig, ) from .configuration_safeqwen2_5_vl import SafeQwen2_5_VLConfig @dataclass class SafeQwen2_5_VLOutput(CausalLMOutputWithPast): """ Output class for SafeQwen2.5-VL with safety classification results. Extends the standard CausalLMOutputWithPast to include safety-related outputs. Args: loss (`torch.FloatTensor` of shape `(1,)`, *optional*): Language modeling loss (and safety loss if labels provided). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head. past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*): Cached key/value attention states. hidden_states (`tuple(torch.FloatTensor)`, *optional*): Hidden states of the model at each layer. attentions (`tuple(torch.FloatTensor)`, *optional*): Attention weights at each layer. rope_deltas (`torch.LongTensor`, *optional*): RoPE position deltas for Qwen2.5-VL. img_safety_logits (`torch.FloatTensor` of shape `(batch_size, num_safety_categories)`, *optional*): Safety classification logits for each image in the batch. img_safety_probs (`torch.FloatTensor` of shape `(batch_size, num_safety_categories)`, *optional*): Safety classification probabilities (softmax of logits). """ loss: Optional[torch.FloatTensor] = None logits: Optional[torch.FloatTensor] = None past_key_values: Optional[List[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None rope_deltas: Optional[torch.LongTensor] = None img_safety_logits: Optional[torch.FloatTensor] = None img_safety_probs: Optional[torch.FloatTensor] = None class SafetyMLP(nn.Module): """ Multi-layer perceptron for safety classification. A simple feedforward network that maps image features to safety category logits. Args: input_size (`int`): Size of input features (typically model hidden size). hidden_size (`int`): Size of hidden layer(s). output_size (`int`): Number of output safety categories. num_hidden_layers (`int`, *optional*, defaults to 1): Number of hidden layers in the MLP. 
""" def __init__( self, input_size: int, hidden_size: int, output_size: int, num_hidden_layers: int = 1 ): super().__init__() layers = [] # First layer layers.append(nn.Linear(input_size, hidden_size)) layers.append(nn.GELU()) layers.append(nn.Dropout(0.1)) # Additional hidden layers for _ in range(num_hidden_layers - 1): layers.append(nn.Linear(hidden_size, hidden_size)) layers.append(nn.GELU()) layers.append(nn.Dropout(0.1)) # Output layer layers.append(nn.Linear(hidden_size, output_size)) self.mlp = nn.Sequential(*layers) def forward(self, x): return self.mlp(x) class SafeQwen2_5_VLForConditionalGeneration(Qwen2_5_VLForConditionalGeneration): """ SafeQwen2.5-VL model for conditional generation with safety classification. This model extends Qwen2_5_VLForConditionalGeneration with an additional safety classification head that analyzes image content for potential safety concerns. The model architecture: 1. Uses standard Qwen2.5-VL for vision-language modeling 2. Extracts image features from hidden states using pooling 3. Passes pooled features through a safety classification MLP 4. Returns both generation outputs and safety predictions Key design principles: - Non-invasive: Does not modify base Qwen2.5-VL forward pass - Post-processing: Safety classification happens after standard forward pass - Gradient-friendly: Maintains full gradient flow for end-to-end training Example: ```python from transformers import AutoModel, AutoProcessor import torch # Load model and processor model = AutoModel.from_pretrained("your-username/SafeQwen2.5-VL-7B", trust_remote_code=True) processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") # Prepare inputs messages = [ { "role": "user", "content": [ {"type": "image", "image": "path/to/image.jpg"}, {"type": "text", "text": "Describe this image."}, ], } ] text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) image_inputs, video_inputs = process_vision_info(messages) inputs = processor( text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt", ) # Generate with safety classification outputs = model(**inputs, do_safety=True) # Access safety predictions safety_probs = outputs.img_safety_probs print(f"Safety probabilities: {safety_probs}") # Generate text generated_ids = model.generate(**inputs, max_new_tokens=128) ``` """ config_class = SafeQwen2_5_VLConfig def __init__(self, config: SafeQwen2_5_VLConfig): super().__init__(config) # Add safety head if safety configuration is present num_safety_categories = getattr(config, 'num_safety_categories', None) if num_safety_categories and num_safety_categories > 0: hidden_size = config.hidden_size safety_head_hidden_scale = getattr(config, 'safety_head_hidden_scale', 4.0) safety_hidden_size = int(hidden_size * safety_head_hidden_scale) safety_num_hidden_layers = getattr(config, 'safety_num_hidden_layers', 1) self.img_safety_head = SafetyMLP( input_size=hidden_size, hidden_size=safety_hidden_size, output_size=num_safety_categories, num_hidden_layers=safety_num_hidden_layers ) else: self.img_safety_head = None def _extract_image_features_simple( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, input_ids: Optional[torch.Tensor] = None ) -> Optional[torch.Tensor]: """ Extract image features using pooling over image token positions. Args: hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_size)`): Hidden states from the model. 
            attention_mask (`torch.Tensor` of shape `(batch_size, seq_len)`, *optional*):
                Attention mask (currently unused, reserved for future use).
            input_ids (`torch.Tensor` of shape `(batch_size, seq_len)`, *optional*):
                Input token IDs used to identify image token positions.

        Returns:
            `torch.Tensor` of shape `(batch_size, hidden_size)` or `None`:
                Pooled image features for each sample in the batch.
        """
        if input_ids is None:
            return None

        # Find image token positions (Qwen2.5-VL uses specific image token IDs)
        image_token_id = getattr(self.config, 'image_token_id', 151655)

        # Create mask for image tokens and move to same device as hidden_states
        image_mask = (input_ids == image_token_id).to(hidden_states.device)  # [batch_size, seq_len]

        if not image_mask.any():
            return None

        # Pool image token features for each sample in batch
        batch_size = hidden_states.shape[0]
        hidden_size = hidden_states.shape[-1]

        # Collect per-sample features in a list to avoid in-place operations
        image_features_list = []

        for i in range(batch_size):
            sample_image_mask = image_mask[i]  # [seq_len]

            if sample_image_mask.any():
                # Extract hidden states for image tokens
                sample_image_features = hidden_states[i][sample_image_mask]  # [num_image_tokens, hidden_size]

                # Simple mean pooling - maintains gradients
                pooled_features = sample_image_features.mean(dim=0)  # [hidden_size]
                image_features_list.append(pooled_features)
            else:
                # For samples without images, use gradient-preserving zero
                zero_features = hidden_states[i, 0, :] * 0.0
                image_features_list.append(zero_features)

        # Stack the features - this maintains gradient flow
        image_features = torch.stack(image_features_list, dim=0)  # [batch_size, hidden_size]

        return image_features

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        do_safety: bool = True,
        safety_labels: Optional[torch.LongTensor] = None,
        **kwargs
    ) -> Union[Tuple, SafeQwen2_5_VLOutput]:
        """
        Forward pass with optional safety classification.

        Args:
            do_safety (`bool`, *optional*, defaults to `True`):
                Whether to perform safety classification. Set to `False` during generation.
            safety_labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                Ground truth safety category labels for training (currently unused;
                see the sketch at the end of this file for computing a safety loss externally).

        Returns:
            `SafeQwen2_5_VLOutput` or `tuple`:
                Model outputs including optional safety predictions.
""" # Force output_hidden_states if we need safety classification if do_safety and self.img_safety_head is not None: output_hidden_states = True return_dict = True # Standard Qwen2.5-VL forward pass - NO MODIFICATIONS outputs = super().forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, pixel_values=pixel_values, pixel_values_videos=pixel_values_videos, image_grid_thw=image_grid_thw, video_grid_thw=video_grid_thw, return_dict=True, **kwargs ) # Initialize safety outputs img_safety_logits = None img_safety_probs = None # Post-process for safety classification # Only do safety classification during initial forward pass, not during generation is_generation = past_key_values is not None and len(past_key_values) > 0 # Check if we have image tokens in the input has_image_tokens = False if input_ids is not None: image_token_id = getattr(self.config, 'image_token_id', 151655) has_image_tokens = (input_ids == image_token_id).any().item() # Only perform safety classification if: # 1. Safety is requested # 2. We have a safety head # 3. We have hidden states # 4. We have image tokens # 5. This is NOT during text generation should_do_safety = ( do_safety and self.img_safety_head is not None and outputs.hidden_states is not None and has_image_tokens and not is_generation ) if should_do_safety: # Extract image features from hidden states last_hidden_state = outputs.hidden_states[-1] # [batch_size, seq_len, hidden_size] image_features = self._extract_image_features_simple( last_hidden_state, attention_mask, input_ids ) if image_features is not None: # Run through safety head img_safety_logits = self.img_safety_head(image_features) img_safety_probs = torch.softmax(img_safety_logits, dim=-1) # Return results if return_dict is False: output = (outputs.loss, outputs.logits, outputs.past_key_values, outputs.hidden_states, outputs.attentions) if img_safety_logits is not None: output += (img_safety_logits, img_safety_probs) return output else: return SafeQwen2_5_VLOutput( loss=outputs.loss, logits=outputs.logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, rope_deltas=getattr(outputs, 'rope_deltas', None), img_safety_logits=img_safety_logits, img_safety_probs=img_safety_probs )