import numpy as np
import torch
import joblib
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoImageProcessor, AutoModel
from PIL import Image
import requests
import gradio as gr
import cv2
import os
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# My model from Colab (unchanged)
class ImageAuthenticityClassifier(nn.Module):
    def __init__(self, backbone, w, b):
        super().__init__()
        self.backbone = backbone
        d = w.shape[0]
        self.head = nn.Linear(d, 1)
        # Load my trained classifier head
        with torch.no_grad():
            self.head.weight.copy_(
                w.unsqueeze(0).to(dtype=self.head.weight.dtype,
                                  device=self.head.weight.device)
            )
            bias_tensor = torch.tensor(
                [b],
                dtype=self.head.bias.dtype,
                device=self.head.bias.device,
            )
            self.head.bias.copy_(bias_tensor)

    def forward(self, pixel_values, return_tokens: bool = False):
        outputs = self.backbone(pixel_values=pixel_values)
        hidden = outputs.last_hidden_state
        patch_tokens = hidden[:, 1:, :]  # drop CLS; keeps register + spatial patch tokens
        emb = patch_tokens.mean(dim=1)
        logits = self.head(emb)  # Apply classifier head to mean patch token embeddings
        prob = torch.sigmoid(logits)
        if return_tokens:
            return logits, prob, emb, patch_tokens
        return logits, prob, emb
# Load the trained logistic-regression classifier and extract its weight vector and bias
model_save_path = "logisticRegressionClassifier.joblib"
logisticRegressionClassifier = joblib.load(model_save_path)
coef = logisticRegressionClassifier.coef_
w = torch.from_numpy(coef.squeeze(0)).float()
intercept = logisticRegressionClassifier.intercept_
b = float(intercept[0])
# Load DinoV3 backbone + processor (gated repo via token)
HF_TOKEN = os.environ.get("HF_TOKEN", None)
backbone = AutoModel.from_pretrained("facebook/dinov3-vitb16-pretrain-lvd1689m", token=HF_TOKEN).to(device)
processor = AutoImageProcessor.from_pretrained("facebook/dinov3-vitb16-pretrain-lvd1689m", token=HF_TOKEN)
image_auth_model = ImageAuthenticityClassifier(backbone, w, b).to(device)
# Inference helper functions (unchanged)
def load_image(online_image_url):
    img = Image.open(requests.get(online_image_url, stream=True).raw).convert("RGB")
    return img
def prepare_pixel_values(img):
    inputs = processor(images=img, return_tensors="pt")
    pixel_values = inputs["pixel_values"].to(device)
    return pixel_values
# Unused
def predict_from_online_url(online_image_url):
    img = load_image(online_image_url)
    pixel_values = prepare_pixel_values(img)
    with torch.no_grad():
        logits, prob, emb = image_auth_model(pixel_values)
    return float(prob[0][0].item())
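
# Example usage for the helper above (illustrative only; the URL is a placeholder, not part of the app):
# score = predict_from_online_url("https://example.com/sample.jpg")
# print(f"P(AI-generated) = {score:.3f}")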
# Grad-CAM Helper Functions (Unchanged) -------------------
def compute_cam_from_tokens(patch_tokens, pixel_values, patch_size=16):
    # Dimension calculations
    H_in, W_in = pixel_values.shape[-2], pixel_values.shape[-1]
    H_p = H_in // patch_size
    W_p = W_in // patch_size
    num_spatial = H_p * W_p
    # Tokens and grads for all 200 tokens after CLS. Keep only the spatial patch tokens (drop the 4 global tokens at the start)
    tokens_all = patch_tokens[0]       # (200, D)
    grads_all = patch_tokens.grad[0]   # (200, D)
    tokens_spatial = tokens_all[-num_spatial:, :]  # (196, D)
    grads_spatial = grads_all[-num_spatial:, :]    # (196, D)
    # Get a single weight per feature dimension, averaged over all patches
    weights = grads_spatial.mean(dim=0)  # (D,)
    # For each patch, combine activations and weights into a per-patch importance score, then normalize to [0, 1]
    cam_per_patch = (tokens_spatial * weights).sum(dim=-1)
    cam_per_patch = torch.relu(cam_per_patch)
    cam_per_patch = cam_per_patch - cam_per_patch.min()
    cam_per_patch = cam_per_patch / (cam_per_patch.max() + 1e-8)  # shape: (N,)
    cam_grid = cam_per_patch.reshape(H_p, W_p)
    cam = cam_grid.unsqueeze(0).unsqueeze(0)  # (1, 1, H_p, W_p)
    cam_up = F.interpolate(
        cam,
        size=(H_in, W_in),
        mode="bilinear",
        align_corners=False,
    )[0, 0]  # (H_in, W_in)
    return cam_up
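
# Rough shape walk-through (assuming a 224x224 input with 16x16 patches, so H_p = W_p = 14):
# patch_tokens (1, 200, D) -> spatial tokens (196, D) -> per-patch CAM (196,) -> grid (14, 14) -> upsampled (224, 224)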
def grad_cam_from_online_url(online_image_url):
    # Load image and get pixel_values
    img = load_image(online_image_url)
    pixel_values = prepare_pixel_values(img)
    # Run prediction with return_tokens=True
    logits, prob, emb, patch_tokens = image_auth_model(pixel_values, return_tokens=True)
    ai_prob = float(prob[0][0].item())
    target_logit = logits[0, 0]
    image_auth_model.zero_grad()
    if patch_tokens.grad is not None:
        patch_tokens.grad.zero_()
    patch_tokens.retain_grad()
    target_logit.backward(retain_graph=True)  # Puts d(target_logit)/d(patch_tokens) in patch_tokens.grad
    # Compute Grad-CAM heatmap
    cam_up = compute_cam_from_tokens(patch_tokens, pixel_values)
    cam_np = cam_up.detach().cpu().numpy()
    orig_np = np.array(img).astype(np.float32) / 255.0
    H0, W0, _ = orig_np.shape
    cam = cam_np.astype(np.float32)
    # Resize the CAM to the original image resolution if needed
    if cam.shape != (H0, W0):
        cam_t = torch.from_numpy(cam).unsqueeze(0).unsqueeze(0)
        cam_t = F.interpolate(cam_t, size=(H0, W0), mode="bilinear", align_corners=False)
        cam = cam_t[0, 0].cpu().numpy()
    cam_uint8 = np.uint8(cam * 255)
    heatmap_bgr = cv2.applyColorMap(cam_uint8, cv2.COLORMAP_JET)
    heatmap_rgb = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    alpha = 0.5
    overlay = alpha * heatmap_rgb + (1.0 - alpha) * orig_np
    overlay = np.clip(overlay, 0.0, 1.0)
    return ai_prob, orig_np, overlay
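
# Standalone sketch for inspecting the Grad-CAM overlay outside Gradio (illustrative; the URL is a placeholder):
# ai_prob, orig_np, overlay = grad_cam_from_online_url("https://example.com/photo.jpg")
# Image.fromarray((overlay * 255).astype(np.uint8)).save("gradcam_overlay.png")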
# -----------------------
# Gradio interface exposing ui_predict as a web UI/API. (AI Generated lol)
# -----------------------
def ui_predict(image_url: str):
    if not image_url:
        return None, "Awaiting input", "Enter an image URL to run a prediction.", None
    try:
        ai_prob, img, img_with_gradcam_overlay = grad_cam_from_online_url(image_url)
        percent = ai_prob * 100.0
        verdict = "AI-generated" if ai_prob >= 0.5 else "Not AI-generated"
        headline = verdict
        detail = f"{percent:.1f}% probability the image is AI-generated"
        return img, headline, detail, img_with_gradcam_overlay
    except Exception as e:
        return None, "Error", str(e), None
demo = gr.Interface(
    fn=ui_predict,
    inputs=gr.Textbox(
        label="Image URL",
        placeholder="https://example.com/image.jpg",
    ),
    outputs=[
        gr.Image(label="Preview"),
        gr.Textbox(label="Verdict"),
        gr.Textbox(label="Details"),
        gr.Image(label="Grad-CAM"),
    ],
    title="Image Authenticity",
    description="Paste an image URL to estimate how likely it is to be AI-generated.",
    api_name="predict",
)
if __name__ == "__main__":
    demo.launch()