Spaces:
Running
Running
import io
import os

import cv2
import gradio as gr
import joblib
import numpy as np
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import AutoImageProcessor, AutoModel
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Classifier architecture (same as the one used in the Colab notebook).
class ImageAuthenticityClassifier(nn.Module):
    """Backbone followed by a single-logit linear head.

    The head is not trained here: it is initialised from an externally
    fitted weight vector ``w`` and bias ``b``.

    Args:
        backbone: Callable invoked as ``backbone(pixel_values=...)`` whose
            result exposes a ``last_hidden_state`` attribute.
        w: 1-D tensor of head weights; its length sets the head's input size.
        b: Scalar bias for the head.
    """

    def __init__(self, backbone, w, b):
        super().__init__()
        self.backbone = backbone
        self.head = nn.Linear(w.shape[0], 1)

        # Overwrite the randomly initialised head with the trained parameters.
        with torch.no_grad():
            weight, bias = self.head.weight, self.head.bias
            weight.copy_(
                w.unsqueeze(0).to(dtype=weight.dtype, device=weight.device)
            )
            bias.copy_(
                torch.tensor([b], dtype=bias.dtype, device=bias.device)
            )

    def forward(self, pixel_values, return_tokens: bool = False):
        """Score a batch of preprocessed images.

        Args:
            pixel_values: Input passed to the backbone as ``pixel_values``.
            return_tokens: When true, also return the per-token hidden states
                that were pooled.

        Returns:
            ``(logits, prob, emb)``, or ``(logits, prob, emb, patch_tokens)``
            when ``return_tokens`` is true. ``prob`` is ``sigmoid(logits)``
            and ``emb`` is the mean over ``patch_tokens``.
        """
        hidden_states = self.backbone(pixel_values=pixel_values).last_hidden_state
        # Drop the first token (the CLS token, per this file's own comments)
        # and mean-pool everything that follows it.
        tokens = hidden_states[:, 1:, :]
        pooled = tokens.mean(dim=1)
        scores = self.head(pooled)
        probabilities = torch.sigmoid(scores)
        if not return_tokens:
            return scores, probabilities, pooled
        return scores, probabilities, pooled, tokens
# Recover the logistic-regression head (weight vector + bias) from disk.
model_save_path = "logisticRegressionClassifier.joblib"
logisticRegressionClassifier = joblib.load(model_save_path)
coef = logisticRegressionClassifier.coef_
intercept = logisticRegressionClassifier.intercept_
w = torch.from_numpy(coef.squeeze(0)).float()
b = float(intercept[0])

# DINOv3 backbone and its image processor. The checkpoint is a gated repo,
# so the access token is read from the HF_TOKEN environment variable.
HF_TOKEN = os.environ.get("HF_TOKEN")
_DINOV3_CHECKPOINT = "facebook/dinov3-vitb16-pretrain-lvd1689m"
backbone = AutoModel.from_pretrained(_DINOV3_CHECKPOINT, token=HF_TOKEN).to(device)
processor = AutoImageProcessor.from_pretrained(_DINOV3_CHECKPOINT, token=HF_TOKEN)

# Full model: backbone plus the pre-trained linear head, on the chosen device.
image_auth_model = ImageAuthenticityClassifier(backbone, w, b).to(device)
# Inference helper functions
def load_image(online_image_url, timeout=30.0):
    """Download an image over HTTP(S) and return it as an RGB PIL image.

    Args:
        online_image_url: URL of the image to fetch.
        timeout: Seconds to wait for the server before giving up, so a
            stalled host cannot hang the request indefinitely.

    Returns:
        A ``PIL.Image.Image`` converted to RGB mode.

    Raises:
        requests.RequestException: On connection failure, timeout, or a
            4xx/5xx HTTP status.
        PIL.UnidentifiedImageError: If the response body is not an image
            PIL can decode.
    """
    # NOTE(review): the URL comes straight from the user, so this server will
    # fetch whatever host it names; consider an allow-list if that matters.
    response = requests.get(online_image_url, timeout=timeout)
    # Fail loudly on HTTP errors instead of handing an error page to PIL.
    response.raise_for_status()
    # ``response.content`` is fully read and transfer-decoded (gzip/deflate),
    # unlike ``response.raw``, and reading it releases the connection.
    return Image.open(io.BytesIO(response.content)).convert("RGB")
def prepare_pixel_values(img):
    """Preprocess ``img`` with the module-level processor.

    Returns the processor's ``pixel_values`` tensor moved to ``device``.
    """
    batch = processor(images=img, return_tensors="pt")
    return batch["pixel_values"].to(device)
# Unused: nothing else in the visible code calls this helper.
def predict_from_online_url(online_image_url):
    """Return the model's probability (a float) for the image at the given URL.

    Runs the forward pass without gradient tracking, so it cannot be used
    for Grad-CAM.
    """
    pixel_values = prepare_pixel_values(load_image(online_image_url))
    with torch.no_grad():
        _, prob, _ = image_auth_model(pixel_values)
    return float(prob[0, 0].item())
# Grad-CAM helper functions -------------------
def compute_cam_from_tokens(patch_tokens, pixel_values, patch_size=16):
    """Build a Grad-CAM style heatmap from token activations and gradients.

    Only the first batch element is used. The last ``H_p * W_p`` tokens are
    treated as the spatial patch grid; any tokens before them are ignored
    (the original comments describe these as extra global tokens).

    Args:
        patch_tokens: Tensor indexed as ``(batch, tokens, dim)`` whose
            ``.grad`` has already been populated by a backward pass.
        pixel_values: Model input; only its last two dimensions (height,
            width) are read.
        patch_size: Side length in pixels of one patch.

    Returns:
        A ``(H_in, W_in)`` tensor with values in ``[0, 1]``, detached from
        the autograd graph.

    Raises:
        ValueError: If ``patch_tokens.grad`` is missing, or there are fewer
            tokens than patches in the grid.
    """
    grads = patch_tokens.grad
    if grads is None:
        raise ValueError(
            "patch_tokens.grad is None; call patch_tokens.retain_grad() and "
            "run backward() before computing the CAM."
        )

    # Patch-grid dimensions implied by the input resolution.
    in_h, in_w = pixel_values.shape[-2], pixel_values.shape[-1]
    grid_h, grid_w = in_h // patch_size, in_w // patch_size
    num_spatial = grid_h * grid_w
    if num_spatial == 0 or patch_tokens.shape[1] < num_spatial:
        raise ValueError(
            f"Expected at least {num_spatial} spatial tokens for a "
            f"{grid_h}x{grid_w} patch grid (and a non-empty grid), got "
            f"{patch_tokens.shape[1]} tokens."
        )

    # The heatmap is a pure read-out; no autograd graph is needed for it.
    with torch.no_grad():
        tokens_spatial = patch_tokens[0, -num_spatial:, :]
        grads_spatial = grads[0, -num_spatial:, :]

        # One weight per feature dimension: gradient averaged over patches.
        weights = grads_spatial.mean(dim=0)

        # Weighted activation per patch, kept positive and scaled to [0, 1].
        cam_per_patch = torch.relu((tokens_spatial * weights).sum(dim=-1))
        cam_per_patch = cam_per_patch - cam_per_patch.min()
        cam_per_patch = cam_per_patch / (cam_per_patch.max() + 1e-8)

        # Lay the patches out on their grid and upsample to input resolution.
        cam = cam_per_patch.reshape(1, 1, grid_h, grid_w)
        cam_up = F.interpolate(
            cam,
            size=(in_h, in_w),
            mode="bilinear",
            align_corners=False,
        )[0, 0]
    return cam_up
def _blend_cam_overlay(cam, orig_np, alpha=0.5):
    """Colour-map a [0, 1] heatmap with JET and alpha-blend it onto an image.

    Args:
        cam: 2-D float array with values in ``[0, 1]``.
        orig_np: ``(H, W, 3)`` float RGB image with values in ``[0, 1]``.
        alpha: Weight of the heatmap in the blend.

    Returns:
        ``(H, W, 3)`` float array clipped to ``[0, 1]``.
    """
    height, width = orig_np.shape[:2]
    cam = cam.astype(np.float32)
    if cam.shape != (height, width):
        # The CAM is at model-input resolution; stretch it to the original.
        cam_t = torch.from_numpy(cam).unsqueeze(0).unsqueeze(0)
        cam_t = F.interpolate(
            cam_t, size=(height, width), mode="bilinear", align_corners=False
        )
        cam = cam_t[0, 0].cpu().numpy()

    # Clip before the uint8 cast so an out-of-range value cannot wrap around.
    cam_uint8 = np.uint8(np.clip(cam, 0.0, 1.0) * 255)
    heatmap_bgr = cv2.applyColorMap(cam_uint8, cv2.COLORMAP_JET)
    heatmap_rgb = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    overlay = alpha * heatmap_rgb + (1.0 - alpha) * orig_np
    return np.clip(overlay, 0.0, 1.0)


def grad_cam_from_online_url(online_image_url):
    """Classify the image at a URL and render a Grad-CAM overlay for it.

    Args:
        online_image_url: URL of the image to download and analyse.

    Returns:
        ``(ai_prob, orig_np, overlay)``: the model's probability as a float,
        the original image as a float RGB array in ``[0, 1]``, and the same
        image blended 50/50 with the JET-coloured heatmap.
    """
    img = load_image(online_image_url)
    pixel_values = prepare_pixel_values(img)

    # The forward pass must track gradients, so no torch.no_grad() here.
    logits, prob, _, patch_tokens = image_auth_model(pixel_values, return_tokens=True)
    ai_prob = float(prob[0, 0].item())

    # patch_tokens is a non-leaf tensor: ask autograd to keep its gradient
    # *before* touching ``.grad`` (reading it earlier only triggers a PyTorch
    # warning and is always None for a freshly computed tensor).
    patch_tokens.retain_grad()
    image_auth_model.zero_grad()
    # A single backward pass is all that is needed, so the graph is not retained.
    logits[0, 0].backward()

    cam_up = compute_cam_from_tokens(patch_tokens, pixel_values)
    # Parameter gradients were only a by-product; do not keep them around
    # between requests.
    image_auth_model.zero_grad()

    cam_np = cam_up.detach().cpu().numpy()
    orig_np = np.array(img).astype(np.float32) / 255.0
    overlay = _blend_cam_overlay(cam_np, orig_np, alpha=0.5)
    return ai_prob, orig_np, overlay
# -----------------------
# Gradio callback: turns an image URL into the four UI outputs.
# -----------------------
def ui_predict(image_url: str):
    """Run the classifier and Grad-CAM for a URL and format the UI outputs.

    Args:
        image_url: URL typed into the textbox; ``None``, empty, or
            whitespace-only input is treated as "no input yet".

    Returns:
        ``(preview, verdict, detail, gradcam_overlay)``. On empty input or
        on any failure the two image slots are ``None`` and the text slots
        carry a status ("Awaiting input" / "Error") and a message.
    """
    url = str(image_url).strip() if image_url else ""
    if not url:
        return None, "Awaiting input", "Enter an image URL to run a prediction.", None

    try:
        # grad_cam_from_online_url downloads the image itself, so no separate
        # fetch is needed here.
        ai_prob, img, img_with_gradcam_overlay = grad_cam_from_online_url(url)
    except Exception as e:
        # UI boundary: surface the failure in the textboxes instead of
        # crashing the Gradio request.
        return None, "Error", str(e), None

    verdict = "AI-generated" if ai_prob >= 0.5 else "Not AI-generated"
    percent = ai_prob * 100.0
    detail = f"{percent:.1f}% probability the image is AI-generated"
    return img, verdict, detail, img_with_gradcam_overlay
# Gradio interface exposing ui_predict as a web UI and as the "predict" API
# endpoint. Output order must match the tuple returned by ui_predict.
demo = gr.Interface(
    fn=ui_predict,
    inputs=gr.Textbox(
        label="Image URL",
        placeholder="https://example.com/image.jpg",
    ),
    outputs=[
        gr.Image(label="Preview"),
        gr.Textbox(label="Verdict"),
        gr.Textbox(label="Details"),
        gr.Image(label="Grad-CAM"),
    ],
    title="Image Authenticity",
    description="Paste an image URL to estimate how likely it is AI-generated.",
    api_name="predict",
)

# Start the server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()