| """ | |
| File: vlm.py | |
| Description: Vision language model utility functions. | |
| Author: Didier Guillevic | |
| Date: 2025-05-08 | |
| """ | |
| from transformers import AutoProcessor | |
| from transformers import Mistral3ForConditionalGeneration | |
| from transformers import TextIteratorStreamer | |
| from threading import Thread | |
| import re | |
| import time | |
| import torch | |
| import base64 | |
| import spaces | |
| import logging | |
| logger = logging.getLogger(__name__) | |
| logging.basicConfig(level=logging.INFO) | |
#
# Load the model: OPEA/Mistral-Small-3.1-24B-Instruct-2503-int4-AutoRound-awq-sym
#
model_id = "OPEA/Mistral-Small-3.1-24B-Instruct-2503-int4-AutoRound-awq-sym"
device = 'cuda' if torch.cuda.is_available() else 'cpu'

processor = AutoProcessor.from_pretrained(model_id)
model = Mistral3ForConditionalGeneration.from_pretrained(
    model_id,
    #_attn_implementation="flash_attention_2",
    torch_dtype=torch.float16
).eval().to(device)
#
# Encode images as base64
#
def encode_image(image_path):
    """Encode the image at the given path to a base64 string."""
    try:
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')
    except FileNotFoundError:
        logger.error(f"Error: The file {image_path} was not found.")
        return None
    except Exception as e:
        logger.error(f"Error encoding {image_path}: {e}")
        return None
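# Usage sketch (the path "sample.jpg" is illustrative, not part of the module).
# The returned base64 string is intended to be wrapped in a data URI, as done
# in build_messages() below:
#
#   b64 = encode_image("sample.jpg")
#   if b64 is not None:
#       image_entry = {"type": "image", "image": f"data:image/jpeg;base64,{b64}"}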
#
# Build messages
#
def normalize_message_content(msg: dict) -> dict:
    """Normalize a history message into the list-of-content-dicts format."""
    content = msg.get("content")

    # Case 1: Already in expected format
    if isinstance(content, list) and all(isinstance(item, dict) for item in content):
        return {"role": msg["role"], "content": content}

    # Case 2: String (assume text)
    if isinstance(content, str):
        return {"role": msg["role"], "content": [{"type": "text", "text": content}]}

    # Case 3: Tuple with image path(s); encode each path the same way as build_messages()
    if isinstance(content, tuple):
        return {
            "role": msg["role"],
            "content": [
                {"type": "image", "image": f"data:image/jpeg;base64,{encode_image(path)}"}
                for path in content if isinstance(path, str)
            ]
        }

    logger.warning(f"Unexpected content format in message: {msg}")
    return {"role": msg["role"], "content": [{"type": "text", "text": str(content)}]}
def build_messages(message: dict, history: list[dict]):
    """Build messages given message & history from a **multimodal** chat interface.

    Args:
        message: dictionary with keys: 'text', 'files'
        history: list of dictionaries

    Returns:
        list of messages (to be sent to the model)
    """
    logger.info(f"{message=}")
    logger.info(f"{history=}")

    # Get the user's text and list of images
    user_text = message.get("text", "")
    user_images = message.get("files", [])  # List of images

    # Build the user message's content from the provided message
    user_content = []
    if user_text:
        user_content.append({"type": "text", "text": user_text})
    for image in user_images:
        user_content.append(
            {
                "type": "image",
                "image": f"data:image/jpeg;base64,{encode_image(image)}"
            }
        )

    # Normalize existing history content
    messages = [normalize_message_content(msg) for msg in history]

    # Append new user message
    messages.append({'role': 'user', 'content': user_content})

    logger.info(f"{messages=}")
    return messages
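# Input shape sketch (assumption: a Gradio-style multimodal chat interface,
# where the new message arrives as a dict with 'text' and 'files', and history
# entries use the 'role'/'content' format handled by normalize_message_content):
#
#   messages = build_messages(
#       {"text": "What is in this picture?", "files": ["/tmp/cat.jpg"]},
#       history=[],
#   )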
#
# Stream response
#
@spaces.GPU  # ZeroGPU: allocate a GPU for the duration of this call
def stream_response(
    messages: list[dict],
    max_new_tokens: int = 1_024,
    temperature: float = 0.15
):
    """Stream the model's response to the chat interface.

    Args:
        messages: list of messages to send to the model
        max_new_tokens: maximum number of tokens to generate
        temperature: sampling temperature
    """
    # Build the model inputs from the chat template
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=torch.float16)

    # Generate in a background thread and stream the decoded tokens
    streamer = TextIteratorStreamer(
        processor, skip_prompt=True, skip_special_tokens=True)
    generation_args = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=0.9,
        do_sample=True
    )
    thread = Thread(target=model.generate, kwargs=generation_args)
    thread.start()

    # Yield the accumulated response as new tokens arrive
    partial_message = ""
    for new_text in streamer:
        partial_message += new_text
        yield partial_message
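# Minimal self-test sketch (assumption: the module is run directly on a machine
# with enough memory for the model; the prompt below is illustrative only).
if __name__ == "__main__":
    demo_messages = build_messages(
        {"text": "Say hello in one short sentence.", "files": []},
        history=[],
    )
    final_text = ""
    for final_text in stream_response(demo_messages, max_new_tokens=64):
        pass
    print(final_text)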