Spaces:
Running
Running
File size: 6,742 Bytes
e8c13db 30eefab e8c13db 30eefab 08c9ad5 e8c13db 30eefab e8c13db 30eefab e8c13db 30eefab e8c13db 30eefab e8c13db 30eefab e8c13db 30eefab e8c13db 30eefab e8c13db 30eefab e8c13db 30eefab e8c13db 00836ce e8c13db 00836ce e8c13db |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 |
import numpy as np
import torch
import joblib
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoImageProcessor, AutoModel
from PIL import Image
import requests
import gradio as gr
import cv2
import os
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# My model from Colab (unchanged)
class ImageAuthenticityClassifier(nn.Module):
    """Linear probe on top of a vision-transformer backbone.

    The head is an ``nn.Linear(d, 1)`` whose parameters are overwritten with an
    externally trained weight vector ``w`` (length ``d``) and scalar bias ``b``.
    ``forward`` mean-pools every token of ``last_hidden_state`` except index 0
    and scores the pooled embedding with that head.
    """

    def __init__(self, backbone, w, b):
        super().__init__()
        self.backbone = backbone
        self.head = nn.Linear(w.shape[0], 1)
        # Replace the random initialisation with the trained classifier parameters.
        with torch.no_grad():
            weight = self.head.weight
            weight.copy_(w.unsqueeze(0).to(dtype=weight.dtype, device=weight.device))
            self.head.bias.fill_(b)

    def forward(self, pixel_values, return_tokens: bool = False):
        """Score a batch of preprocessed images.

        Returns ``(logits, prob, emb)``: ``logits`` and ``prob`` (the sigmoid of
        the logits) are ``(batch, 1)`` and ``emb`` is the pooled ``(batch, d)``
        embedding. With ``return_tokens=True`` the pooled token tensor
        ``(batch, n_tokens - 1, d)`` is appended as a fourth element.
        """
        hidden = self.backbone(pixel_values=pixel_values).last_hidden_state
        # Skip token 0 (the CLS token). For DINOv3 the remainder presumably still
        # contains the register tokens as well as the spatial patches; this
        # matches the pooling the head was trained on — TODO confirm.
        tokens = hidden[:, 1:, :]
        pooled = tokens.mean(dim=1)
        logits = self.head(pooled)
        outputs = (logits, torch.sigmoid(logits), pooled)
        return outputs + (tokens,) if return_tokens else outputs
# Load the trained logistic-regression classifier (saved with joblib) and turn its
# parameters into the weight vector `w` and scalar bias `b` that are later passed to
# ImageAuthenticityClassifier to initialise its linear head.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary code.
# Only load this artifact from a trusted source (the relative path suggests it is
# bundled with this app — confirm).
model_save_path = "logisticRegressionClassifier.joblib"
logisticRegressionClassifier = joblib.load(model_save_path)
# Presumably a scikit-learn binary LogisticRegression, whose coef_ is
# (1, n_features); squeeze(0) drops the leading class axis. TODO confirm.
coef = logisticRegressionClassifier.coef_
w = torch.from_numpy(coef.squeeze(0)).float()
# Take the first (presumably only) intercept as a plain Python float.
intercept = logisticRegressionClassifier.intercept_
b = float(intercept[0])
# DINOv3 ViT-B/16 backbone plus its matching image processor. The Hub repository
# is gated, so the HF_TOKEN environment variable (if set) is forwarded for auth.
_DINOV3_REPO = "facebook/dinov3-vitb16-pretrain-lvd1689m"
HF_TOKEN = os.environ.get("HF_TOKEN")
backbone = AutoModel.from_pretrained(_DINOV3_REPO, token=HF_TOKEN).to(device)
processor = AutoImageProcessor.from_pretrained(_DINOV3_REPO, token=HF_TOKEN)
# Attach the logistic-regression head (w, b, defined earlier in this module).
image_auth_model = ImageAuthenticityClassifier(backbone, w, b).to(device)
# Inference helper functions
def load_image(online_image_url, timeout=10.0):
    """Download an image over HTTP(S) and return it as an RGB ``PIL.Image``.

    Args:
        online_image_url: URL of the image to fetch.
        timeout: Seconds ``requests`` waits to connect / for data before
            giving up. Without one, an unresponsive host blocks forever.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
        requests.RequestException: on connection failures or timeouts.
        PIL.UnidentifiedImageError: if the response body is not an image.
    """
    # The context manager guarantees the streamed connection is released.
    with requests.get(online_image_url, stream=True, timeout=timeout) as response:
        # Fail fast on error pages instead of handing HTML to PIL.
        response.raise_for_status()
        # `.raw` bypasses requests' Content-Encoding handling; have urllib3
        # undo gzip/deflate so PIL receives the actual image bytes.
        response.raw.decode_content = True
        # convert() forces a full decode while the connection is still open.
        img = Image.open(response.raw).convert("RGB")
    return img
def prepare_pixel_values(img):
    """Run the image processor on ``img`` and return its batched
    ``pixel_values`` tensor, moved to ``device``."""
    batch = processor(images=img, return_tensors="pt")
    return batch["pixel_values"].to(device)
# Unused by the Gradio UI (which goes through grad_cam_from_online_url instead).
def predict_from_online_url(online_image_url):
    """Return the model's AI-generated probability for the image at a URL.

    Runs the forward pass without gradient tracking and returns the sigmoid
    output of the first image in the batch as a Python float.
    """
    pixel_values = prepare_pixel_values(load_image(online_image_url))
    with torch.no_grad():
        _logits, prob, _emb = image_auth_model(pixel_values)
    return float(prob[0, 0].item())
# Grad-CAM Helper Functions -------------------
def compute_cam_from_tokens(patch_tokens, pixel_values, patch_size=16):
    """Build a Grad-CAM heatmap for the first image from ViT token gradients.

    Args:
        patch_tokens: ``(batch, n_tokens, d)`` token tensor (e.g. the fourth
            output of ``ImageAuthenticityClassifier.forward(...,
            return_tokens=True)``) whose ``.grad`` has been populated by a
            backward pass. Call ``retain_grad()`` on it before ``backward()``.
        pixel_values: model input. Only its last two dims ``(H_in, W_in)`` are
            used, to derive the patch grid.
        patch_size: side length of one square ViT patch, in pixels.

    Returns:
        ``(H_in, W_in)`` tensor with values in ``[0, 1]``, built from detached
        tensors, so it carries no autograd history.

    Raises:
        RuntimeError: if ``patch_tokens.grad`` has not been populated.
        ValueError: if the patch grid is empty or has more cells than tokens.
    """
    grads = patch_tokens.grad
    if grads is None:
        raise RuntimeError(
            "patch_tokens.grad is None; call patch_tokens.retain_grad() and "
            "backward() on the target logit before computing Grad-CAM."
        )

    H_in, W_in = pixel_values.shape[-2], pixel_values.shape[-1]
    H_p, W_p = H_in // patch_size, W_in // patch_size
    num_spatial = H_p * W_p
    num_tokens = patch_tokens.shape[1]
    # An empty grid would turn a `[-num_spatial:]` slice into `[-0:]`, i.e.
    # "keep every token". Reject it, and any grid larger than the sequence.
    if num_spatial == 0 or num_spatial > num_tokens:
        raise ValueError(
            f"Cannot map {num_tokens} tokens onto a {H_p}x{W_p} patch grid "
            f"(input {H_in}x{W_in}, patch_size={patch_size})."
        )

    # The spatial patch tokens are the trailing ones. Any leading tokens are
    # non-spatial (for DINOv3, presumably the register tokens) and are dropped.
    start = num_tokens - num_spatial
    tokens_spatial = patch_tokens[0, start:, :].detach()  # (num_spatial, d)
    grads_spatial = grads[0, start:, :]  # (num_spatial, d)

    # Grad-CAM channel weights: gradient averaged over all spatial positions.
    weights = grads_spatial.mean(dim=0)  # (d,)
    # Weight each patch's activations, keep positive evidence, min-max scale.
    cam_per_patch = torch.relu((tokens_spatial * weights).sum(dim=-1))
    cam_per_patch = cam_per_patch - cam_per_patch.min()
    cam_per_patch = cam_per_patch / (cam_per_patch.max() + 1e-8)

    cam = cam_per_patch.reshape(1, 1, H_p, W_p)
    cam_up = F.interpolate(
        cam,
        size=(H_in, W_in),
        mode="bilinear",
        align_corners=False,
    )[0, 0]  # (H_in, W_in)
    return cam_up
def grad_cam_from_online_url(online_image_url):
    """Classify the image at a URL and overlay its Grad-CAM heatmap.

    Returns:
        ``(ai_prob, orig_np, overlay)``:
        - ``ai_prob``: the model's probability as a Python float.
        - ``orig_np``: the downloaded image as an ``(H, W, 3)`` float32 array in
          ``[0, 1]``.
        - ``overlay``: a same-sized 50/50 blend of a JET heatmap and the image,
          clipped to ``[0, 1]``.
    """
    img = load_image(online_image_url)
    pixel_values = prepare_pixel_values(img)

    # Forward pass with autograd enabled so the logit can be differentiated
    # with respect to the backbone's output tokens.
    logits, prob, _emb, patch_tokens = image_auth_model(pixel_values, return_tokens=True)
    ai_prob = float(prob[0][0].item())
    target_logit = logits[0, 0]

    image_auth_model.zero_grad()
    # patch_tokens is a non-leaf tensor created by this very forward pass, so it
    # has no stale gradient; just ask autograd to keep the one we compute.
    patch_tokens.retain_grad()
    # The graph is not reused afterwards, so it need not be retained.
    target_logit.backward()
    try:
        cam_up = compute_cam_from_tokens(patch_tokens, pixel_values)
    finally:
        # backward() also populated .grad on every model parameter; free it.
        image_auth_model.zero_grad(set_to_none=True)
    cam = cam_up.detach().cpu().numpy().astype(np.float32)

    orig_np = np.array(img).astype(np.float32) / 255.0
    H0, W0, _ = orig_np.shape
    # Resize the heatmap from the model-input resolution to the original image size.
    if cam.shape != (H0, W0):
        cam_t = torch.from_numpy(cam).unsqueeze(0).unsqueeze(0)
        cam_t = F.interpolate(cam_t, size=(H0, W0), mode="bilinear", align_corners=False)
        cam = cam_t[0, 0].cpu().numpy()

    # Clip before the uint8 cast so float round-off can never wrap around.
    cam_uint8 = np.uint8(np.clip(cam, 0.0, 1.0) * 255)
    heatmap_bgr = cv2.applyColorMap(cam_uint8, cv2.COLORMAP_JET)
    heatmap_rgb = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    alpha = 0.5
    overlay = np.clip(alpha * heatmap_rgb + (1.0 - alpha) * orig_np, 0.0, 1.0)
    return ai_prob, orig_np, overlay
# -----------------------
# Gradio interface exposing ui_predict as a web UI/API (initial version AI-generated).
# -----------------------
def ui_predict(image_url: str):
    """Gradio handler: classify the image at ``image_url`` and explain it.

    Returns a 4-tuple matching the Interface outputs
    ``(preview, verdict, details, grad_cam)``. Empty or whitespace-only input
    yields a prompt rather than an error. Any failure during download or
    inference is reported through the verdict/details textboxes.
    """
    image_url = (image_url or "").strip()
    if not image_url:
        return None, "Awaiting input", "Enter an image URL to run a prediction.", None
    try:
        # grad_cam_from_online_url downloads the image itself; fetching it here
        # as well would double the network traffic for no benefit.
        ai_prob, preview, img_with_gradcam_overlay = grad_cam_from_online_url(image_url)
    except Exception as e:  # UI boundary: surface any failure to the user.
        return None, "Error", str(e), None
    verdict = "AI-generated" if ai_prob >= 0.5 else "Not AI-generated"
    percent = ai_prob * 100.0
    detail = f"{percent:.1f}% probability the image is AI-generated"
    return preview, verdict, detail, img_with_gradcam_overlay
# Web UI plus the "predict" API endpoint. The four outputs line up, in order,
# with the 4-tuple returned by ui_predict.
demo = gr.Interface(
    fn=ui_predict,
    inputs=gr.Textbox(
        label="Image URL",
        placeholder="https://example.com/image.jpg",
    ),
    outputs=[
        gr.Image(label="Preview"),
        gr.Textbox(label="Verdict"),
        gr.Textbox(label="Details"),
        gr.Image(label="Grad-CAM"),
    ],
    title="Image Authenticity",
    description="Paste an image URL to estimate how likely it is AI-generated.",
    api_name="predict",
)
if __name__ == "__main__":
    # Start the Gradio server only when this file is run as a script.
    demo.launch()