import os

import torch
from PIL import Image
import torchvision.transforms.functional as TVF
import google.generativeai as genai

GEMINI_API_KEY = os.environ.get("GOOGLE_API_KEY")
if GEMINI_API_KEY:
    genai.configure(api_key=GEMINI_API_KEY)
    gemini_model = genai.GenerativeModel('gemini-1.5-flash')
else:
    print("Warning: GOOGLE_API_KEY not found in environment variables")
    gemini_model = None
CAPTION_TYPE_MAP = {
    "Descriptive": [
        "Write a descriptive caption for this image in a formal tone.",
        "Write a descriptive caption for this image in a formal tone within {word_count} words.",
        "Write a {length} descriptive caption for this image in a formal tone.",
    ],
    "Training Prompt": [
        "Write a stable diffusion prompt for this image.",
        "Write a stable diffusion prompt for this image within {word_count} words.",
        "Write a {length} stable diffusion prompt for this image.",
    ],
    "MidJourney": [
        "Write a MidJourney prompt for this image.",
        "Write a MidJourney prompt for this image within {word_count} words.",
        "Write a {length} MidJourney prompt for this image.",
    ],
}
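
# Each caption type maps to three templates, indexed by how the length is
# specified: 0 = no length constraint, 1 = numeric word count ({word_count}),
# 2 = named length such as "short" or "long" ({length}). A quick illustrative
# sketch of the selection logic, mirroring generate_caption below:
#
#     CAPTION_TYPE_MAP["Descriptive"][1].format(word_count=50)
#     # -> "Write a descriptive caption for this image in a formal tone within 50 words."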
def get_image_features(input_image: Image.Image, clip_model, image_adapter=None):
    """Extract image features with a CLIP vision model, optionally projected by an adapter."""
    # Resize to the vision tower's expected resolution and scale pixels to [-1, 1].
    image = input_image.resize((384, 384), Image.LANCZOS)
    pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
    pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
    # Keep the input on the same device as the model to avoid device-mismatch errors.
    pixel_values = pixel_values.to(next(clip_model.parameters()).device)
    with torch.no_grad():
        vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
        if image_adapter is not None:
            # The adapter projects the vision hidden states into the captioner's embedding space.
            return image_adapter(vision_outputs.hidden_states)
    return vision_outputs.last_hidden_state
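
# A minimal usage sketch (assumptions: the vision tower is a Hugging Face
# vision model whose forward accepts pixel_values and output_hidden_states,
# e.g. SiglipVisionModel; the 384x384 checkpoint name below is illustrative,
# not confirmed by this file):
#
#     from transformers import SiglipVisionModel
#     clip = SiglipVisionModel.from_pretrained("google/siglip-so400m-patch14-384").eval()
#     features = get_image_features(Image.open("photo.jpg").convert("RGB"), clip)
#     print(features.shape)  # (1, num_patches, hidden_dim)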
def generate_caption(input_image: Image.Image,
                     caption_type: str = "Descriptive",
                     caption_length: str = "long",
                     extra_options: list = None,
                     name_input: str = "",
                     custom_prompt: str = "",
                     clip_model=None,
                     image_adapter=None):
    """
    Generate a caption for an image using the Gemini API. The model is asked to
    return plain caption text: no bullet points, proper punctuation, and no
    extra information.

    Args:
        input_image: PIL Image object
        caption_type: Type of caption ("Descriptive", "Training Prompt", "MidJourney")
        caption_length: Length specification ("any", "short", "long", etc., or a number as a string)
        extra_options: List of extra instruction strings appended to the prompt
        name_input: Name to use for a person/character in the image
        custom_prompt: Custom prompt that overrides the settings above
        clip_model: CLIP model (optional, for compatibility)
        image_adapter: Image adapter model (optional, for compatibility)

    Returns:
        tuple: (prompt_used, caption)
    """
    if gemini_model is None:
        return "Error: Gemini API key not configured", "Please set the GOOGLE_API_KEY environment variable"
    if input_image is None:
        return "Error: No image provided", "Please provide an image"
    if extra_options is None:
        extra_options = []

    # Free cached GPU memory before any optional CLIP work.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Normalize the length argument: "any" -> None, numeric strings -> int,
    # everything else stays a named length such as "short" or "long".
    length = None if caption_length == "any" else caption_length
    if isinstance(length, str):
        try:
            length = int(length)
        except ValueError:
            pass

    # Pick the matching prompt template (see CAPTION_TYPE_MAP above).
    if length is None:
        map_idx = 0
    elif isinstance(length, int):
        map_idx = 1
    elif isinstance(length, str):
        map_idx = 2
    else:
        raise ValueError(f"Invalid caption length: {length}")

    prompt_str = CAPTION_TYPE_MAP[caption_type][map_idx]
    if len(extra_options) > 0:
        prompt_str += " " + " ".join(extra_options)
    prompt_str = prompt_str.format(name=name_input, length=caption_length, word_count=caption_length)

    # A non-empty custom prompt overrides everything assembled above.
    if custom_prompt.strip() != "":
        prompt_str = custom_prompt.strip()
    try:
        # Optional: extract CLIP features for compatibility/debugging. Note that
        # they are only logged here; the Gemini call below receives the raw image.
        if clip_model is not None:
            image_features = get_image_features(input_image, clip_model, image_adapter)
            print(f"Extracted image features shape: {image_features.shape if hasattr(image_features, 'shape') else 'N/A'}")

        full_prompt = f"""You are a helpful image captioner.

{prompt_str}

Please analyze the provided image and generate a caption according to the instructions above. Return only the caption text, with no additional information."""

        response = gemini_model.generate_content([full_prompt, input_image])
        caption = response.text.strip() if response.text else "Failed to generate caption"
    except Exception as e:
        print(f"Error generating caption: {str(e)}")
        return prompt_str, f"Error: {str(e)}"

    return prompt_str, caption
def caption_image_from_path(image_path: str, **kwargs):
    """Caption an image from a file path."""
    image = Image.open(image_path)
    return generate_caption(image, **kwargs)


def caption_image_simple(image_path: str, caption_type: str = "Descriptive"):
    """Simple interface to caption an image and print the result."""
    image = Image.open(image_path)
    prompt_used, caption = generate_caption(image, caption_type=caption_type)
    print(f"Caption: {caption}")
    return caption
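
# A minimal end-to-end sketch (assumptions: "photo.jpg" is a placeholder path
# and GOOGLE_API_KEY is set in the environment; nothing below is confirmed by
# this file beyond the two helpers above):
if __name__ == "__main__":
    # Prints the caption and returns it.
    caption_image_simple("photo.jpg", caption_type="Descriptive")

    # Full control over prompt assembly via keyword arguments.
    prompt_used, caption = caption_image_from_path(
        "photo.jpg",
        caption_type="Training Prompt",
        caption_length="50",  # numeric string selects the "within {word_count} words" template
    )
    print(prompt_used)
    print(caption)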