import os

import cv2
import gradio as gr
import joblib
import numpy as np
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import AutoImageProcessor, AutoModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# My model from Colab (unchanged)
class ImageAuthenticityClassifier(nn.Module):
    def __init__(self, backbone, w, b):
        super().__init__()
        self.backbone = backbone
        d = w.shape[0]
        self.head = nn.Linear(d, 1)
        # Load my trained classifier head weights into the linear layer
        with torch.no_grad():
            self.head.weight.copy_(
                w.unsqueeze(0).to(dtype=self.head.weight.dtype, device=self.head.weight.device)
            )
            bias_tensor = torch.tensor(
                [b],
                dtype=self.head.bias.dtype,
                device=self.head.bias.device,
            )
            self.head.bias.copy_(bias_tensor)

    def forward(self, pixel_values, return_tokens: bool = False):
        outputs = self.backbone(pixel_values=pixel_values)
        hidden = outputs.last_hidden_state
        patch_tokens = hidden[:, 1:, :]  # drop the CLS token
        emb = patch_tokens.mean(dim=1)
        logits = self.head(emb)  # apply classifier head to the mean patch-token embedding
        prob = torch.sigmoid(logits)
        if return_tokens:
            return logits, prob, emb, patch_tokens
        return logits, prob, emb


# Load linear classifier head for logistic regression
model_save_path = "logisticRegressionClassifier.joblib"
logisticRegressionClassifier = joblib.load(model_save_path)
coef = logisticRegressionClassifier.coef_
w = torch.from_numpy(coef.squeeze(0)).float()
intercept = logisticRegressionClassifier.intercept_
b = float(intercept[0])

# Load DINOv3 backbone + processor (gated repo via token)
HF_TOKEN = os.environ.get("HF_TOKEN", None)
backbone = AutoModel.from_pretrained(
    "facebook/dinov3-vitb16-pretrain-lvd1689m", token=HF_TOKEN
).to(device)
processor = AutoImageProcessor.from_pretrained(
    "facebook/dinov3-vitb16-pretrain-lvd1689m", token=HF_TOKEN
)
image_auth_model = ImageAuthenticityClassifier(backbone, w, b).to(device)
image_auth_model.eval()  # inference only


# Inference helper functions (unchanged)
def load_image(online_image_url):
    img = Image.open(requests.get(online_image_url, stream=True).raw).convert("RGB")
    return img


def prepare_pixel_values(img):
    inputs = processor(images=img, return_tensors="pt")
    pixel_values = inputs["pixel_values"].to(device)
    return pixel_values


# Unused by the UI; kept as a plain prediction helper
def predict_from_online_url(online_image_url):
    img = load_image(online_image_url)
    pixel_values = prepare_pixel_values(img)
    with torch.no_grad():
        logits, prob, emb = image_auth_model(pixel_values)
    return float(prob[0][0].item())
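
# Optional sanity check (a hypothetical helper, not called anywhere in the app):
# the nn.Linear head was initialized from the sklearn model's coef_/intercept_,
# so both compute sigmoid(w . x + b) and should agree on any (N, D) embedding
# batch. Assumes class 1 of the sklearn model is the "AI-generated" label,
# matching the head's single logit.
def _sanity_check_head(emb):
    emb_np = emb.detach().cpu().numpy()
    p_sklearn = logisticRegressionClassifier.predict_proba(emb_np)[:, 1]
    with torch.no_grad():
        p_torch = torch.sigmoid(image_auth_model.head(emb)).squeeze(-1).cpu().numpy()
    # Allow a small tolerance: sklearn works in float64, the head in float32
    assert np.allclose(p_sklearn, p_torch, atol=1e-4)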

# Grad-CAM helper functions (unchanged) -------------------
def compute_cam_from_tokens(patch_tokens, pixel_values, patch_size=16):
    # Dimension calculations
    H_in, W_in = pixel_values.shape[-2], pixel_values.shape[-1]
    H_p = H_in // patch_size
    W_p = W_in // patch_size
    num_spatial = H_p * W_p

    # Tokens and grads for all 200 tokens after CLS. Keep only the spatial
    # patch tokens (drop the 4 global tokens at the start).
    tokens_all = patch_tokens[0]      # (200, D)
    grads_all = patch_tokens.grad[0]  # (200, D)
    tokens_spatial = tokens_all[-num_spatial:, :]  # (196, D)
    grads_spatial = grads_all[-num_spatial:, :]    # (196, D)

    # Get a single weight per feature dimension, averaged over all patches
    weights = grads_spatial.mean(dim=0)  # (D,)

    # For each patch, combine activations and weights into a per-patch
    # importance score, then normalize the results to [0, 1]
    cam_per_patch = (tokens_spatial * weights).sum(dim=-1)
    cam_per_patch = torch.relu(cam_per_patch)
    cam_per_patch = cam_per_patch - cam_per_patch.min()
    cam_per_patch = cam_per_patch / (cam_per_patch.max() + 1e-8)  # shape: (N,)

    cam_grid = cam_per_patch.reshape(H_p, W_p)
    cam = cam_grid.unsqueeze(0).unsqueeze(0)  # (1, 1, H_p, W_p)
    cam_up = F.interpolate(
        cam,
        size=(H_in, W_in),
        mode="bilinear",
        align_corners=False,
    )[0, 0]  # (H_in, W_in)
    return cam_up


def grad_cam_from_online_url(online_image_url):
    # Load image and get pixel_values
    img = load_image(online_image_url)
    pixel_values = prepare_pixel_values(img)

    # Run prediction with return_tokens=True so the patch tokens are exposed
    logits, prob, emb, patch_tokens = image_auth_model(pixel_values, return_tokens=True)
    ai_prob = float(prob[0][0].item())
    target_logit = logits[0, 0]

    image_auth_model.zero_grad()
    if patch_tokens.grad is not None:
        patch_tokens.grad.zero_()
    patch_tokens.retain_grad()
    # Fills patch_tokens.grad with d(target_logit)/d(patch_tokens)
    target_logit.backward(retain_graph=True)

    # Compute Grad-CAM heatmap
    cam_up = compute_cam_from_tokens(patch_tokens, pixel_values)
    cam_np = cam_up.detach().cpu().numpy()

    # Resize the heatmap to the original image resolution if needed
    orig_np = np.array(img).astype(np.float32) / 255.0
    H0, W0, _ = orig_np.shape
    cam = cam_np.astype(np.float32)
    if cam.shape != (H0, W0):
        cam_t = torch.from_numpy(cam).unsqueeze(0).unsqueeze(0)
        cam_t = F.interpolate(cam_t, size=(H0, W0), mode="bilinear", align_corners=False)
        cam = cam_t[0, 0].cpu().numpy()

    # Colorize the heatmap and blend it over the original image
    cam_uint8 = np.uint8(cam * 255)
    heatmap_bgr = cv2.applyColorMap(cam_uint8, cv2.COLORMAP_JET)
    heatmap_rgb = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0

    alpha = 0.5
    overlay = alpha * heatmap_rgb + (1.0 - alpha) * orig_np
    overlay = np.clip(overlay, 0.0, 1.0)
    return ai_prob, orig_np, overlay


# -----------------------
# Gradio interface exposing ui_predict as a web UI/API. (AI Generated lol)
# -----------------------
def ui_predict(image_url: str):
    if not image_url:
        return None, "Awaiting input", "Enter an image URL to run a prediction.", None
    try:
        # grad_cam_from_online_url already downloads the image, so no separate
        # load_image call is needed here
        ai_prob, img, img_with_gradcam_overlay = grad_cam_from_online_url(image_url)
        percent = ai_prob * 100.0
        verdict = "AI-generated" if ai_prob >= 0.5 else "Not AI-generated"
        headline = verdict
        detail = f"{percent:.1f}% probability the image is AI-generated"
        return img, headline, detail, img_with_gradcam_overlay
    except Exception as e:
        return None, "Error", str(e), None


demo = gr.Interface(
    fn=ui_predict,
    inputs=gr.Textbox(
        label="Image URL",
        placeholder="https://example.com/image.jpg",
    ),
    outputs=[
        gr.Image(label="Preview"),
        gr.Textbox(label="Verdict"),
        gr.Textbox(label="Details"),
        gr.Image(label="Grad-CAM"),
    ],
    title="Image Authenticity",
    description="Paste an image URL to estimate how likely it is AI-generated.",
    api_name="predict",
)

if __name__ == "__main__":
    demo.launch()
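
# Example client call against the running app (a sketch; assumes the default
# local URL and the separate `gradio_client` package -- swap in the Space URL
# if this is deployed on Hugging Face). Image outputs come back as local file
# paths on the client side.
#
#   from gradio_client import Client
#
#   client = Client("http://127.0.0.1:7860/")
#   preview, verdict, details, gradcam = client.predict(
#       "https://example.com/image.jpg",
#       api_name="/predict",
#   )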