import logging
import azure.functions as func
from PIL import Image
import io
import torch
import sys
from pathlib import Path

# Añadir el directorio de los módulos al sys.path
module_path = Path(__file__).parent / 'RMBG'
sys.path.append(str(module_path))

from briarmbg import BriaRMBG
from utilities import preprocess_image, postprocess_image
import numpy as np

app = func.FunctionApp(http_auth_level=func.AuthLevel.ANONYMOUS)

# Función para redimensionar la imagen manteniendo la relación de aspecto
def resize_image(image, target_width=1440, target_height=1440):
    original_width, original_height = image.size

    # Calcular la relación de aspecto original
    aspect_ratio = original_width / original_height

    # Calcular las nuevas dimensiones manteniendo la relación de aspecto
    if aspect_ratio > 1:  # Imagen horizontal
        new_width = target_width
        new_height = int(target_width / aspect_ratio)
    else:  # Imagen vertical
        new_height = target_height
        new_width = int(target_height * aspect_ratio)

    # Redimensionar la imagen
    resized_img = image.resize((new_width, new_height), Image.LANCZOS)

    return resized_img

# Cargar el modelo una vez para reutilizarlo en futuras invocaciones
net = BriaRMBG()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
net.to(device)
net.eval()

@app.route(route="processimage")
def process_image(req: func.HttpRequest) -> func.HttpResponse:
    logging.info('Python HTTP trigger function processed a request.')

    try:
        # Leer la imagen de entrada desde la solicitud
        image_data = req.get_body()
        input_image = Image.open(io.BytesIO(image_data))

        # Redimensionar la imagen
        resized_image = resize_image(input_image)

        # Preparar la imagen redimensionada para el modelo
        model_input_size = [1024, 1024]
        image = preprocess_image(np.array(resized_image), model_input_size).to(device)

        # Inferencia
        result = net(image)

        # Post-procesamiento
        result_image = postprocess_image(result[0][0], input_image.size)

        # Guardar el resultado
        pil_im = Image.fromarray(result_image)

        # Crear una nueva imagen en blanco
        white_bg = Image.new("RGBA", (1440, 2560), (255, 255, 255, 255))

        # Redimensionar la máscara para que coincida con las dimensiones de la imagen redimensionada
        mask_resized = pil_im.resize(resized_image.size, Image.LANCZOS)

        # Calcular la posición para centrar la imagen en el fondo blanco
        x_offset = (white_bg.width - resized_image.width) // 2
        y_offset = white_bg.height - resized_image.height - 100

        # Pegar la imagen sin fondo en el fondo blanco
        white_bg.paste(resized_image, (x_offset, y_offset), mask=mask_resized)

        # Convertir la imagen resultante a bytes para la respuesta HTTP
        output_buffer = io.BytesIO()
        white_bg.save(output_buffer, format="PNG")
        output_buffer.seek(0)

        return func.HttpResponse(output_buffer.read(), mimetype="image/png")

    except Exception as e:
        logging.error(f"Error processing the image: {e}")
        return func.HttpResponse(f"Error processing the image: {e}", status_code=500)
