import os
import cv2
import torch
import numpy as np
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
from training.detectors import DETECTOR
import yaml

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
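# Note: `from training.detectors import DETECTOR` assumes this script is run from the
# DeepfakeBench repository root (or with the repository on PYTHONPATH), so that the
# `training` package is importable.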
# Load a detector: read its YAML config, build the model from the DETECTOR registry,
# and restore the pretrained weights.
def load_model(model_name, config_path, weights_path):
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)
    config['model_name'] = model_name
    model_class = DETECTOR[model_name]
    model = model_class(config).to(device)
    checkpoint = torch.load(weights_path, map_location=device)
    model.load_state_dict(checkpoint, strict=True)
    model.eval()
    return model
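# Optional sketch (hypothetical helper, not part of DeepfakeBench): some checkpoints are
# saved as a wrapper dict (e.g. {"state_dict": ...}) or carry "module." prefixes left by
# nn.DataParallel. If strict loading fails for that reason, the state dict can be
# normalized before calling load_state_dict.
def _normalize_state_dict(checkpoint):
    # Unwrap a {"state_dict": ...} container if present, otherwise use the mapping as-is.
    state_dict = checkpoint.get("state_dict", checkpoint)
    # Strip a leading "module." prefix left behind by nn.DataParallel.
    return {k[len("module."):] if k.startswith("module.") else k: v for k, v in state_dict.items()}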
# Preprocess a single video: sample `frame_num` frames uniformly across the clip
# and save them to disk as PNGs.
def preprocess_video(video_path, output_dir, frame_num=32):
    os.makedirs(output_dir, exist_ok=True)
    frames_dir = os.path.join(output_dir, "frames")
    os.makedirs(frames_dir, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frame_indices = np.linspace(0, total_frames - 1, frame_num, dtype=int)
    # Extract the sampled frames
    frames = []
    for idx in frame_indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ret, frame = cap.read()
        if ret:
            frame_path = os.path.join(frames_dir, f"frame_{idx:04d}.png")
            cv2.imwrite(frame_path, frame)
            frames.append(frame_path)
    cap.release()
    return frames
# Run inference on a single video and average the per-frame "fake" probabilities.
def infer_video(video_path, model, device):
    output_dir = "temp_video_frames"
    frames = preprocess_video(video_path, output_dir)
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    probs = []
    for frame_path in frames:
        frame = Image.open(frame_path).convert("RGB")
        frame = transform(frame).unsqueeze(0).to(device)
        data_dict = {
            "image": frame,
            "label": torch.tensor([0]).to(device),      # Dummy label
            "label_spe": torch.tensor([0]).to(device),  # Dummy specific label
        }
        with torch.no_grad():
            pred_dict = model(data_dict, inference=True)
            logits = pred_dict["cls"]  # Shape: [batch_size, num_classes]
            prob = torch.softmax(logits, dim=1)[:, 1].item()  # Probability of being "fake"
            probs.append(prob)
    avg_prob = np.mean(probs)  # Mean "fake" probability across the sampled frames
    prediction = "Fake" if avg_prob > 0.5 else "Real"
    return prediction, avg_prob
# Main entry point for terminal-based inference.
def main(video_filename, model_name):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    config_path = f"/teamspace/studios/this_studio/DeepfakeBench/training/config/detector/{model_name}.yaml"
    weights_path = f"/teamspace/studios/this_studio/DeepfakeBench/training/weights/{model_name}_best.pth"
    if not os.path.exists(config_path):
        print(f"Error: Config file for model '{model_name}' not found at {config_path}.")
        return
    if not os.path.exists(weights_path):
        print(f"Error: Weights file for model '{model_name}' not found at {weights_path}.")
        return
    model = load_model(model_name, config_path, weights_path)
    video_path = os.path.join(os.getcwd(), video_filename)
    if not os.path.exists(video_path):
        print(f"Error: Video file '{video_filename}' not found in the current directory.")
        return
    # The reported confidence is the mean "fake" probability over the sampled frames.
    prediction, confidence = infer_video(video_path, model, device)
    print(f"Model: {model_name}")
    print(f"Prediction: {prediction} (Confidence: {confidence:.4f})")
if __name__ == "__main__":
    import sys
    if len(sys.argv) != 3:
        print("Usage: python inference_script.py <video_filename> <model_name>")
        print("Available models: xception, meso4, meso4Inception, efficientnetb4, ucf, etc.")
    else:
        video_filename = sys.argv[1]
        model_name = sys.argv[2]
        main(video_filename, model_name)
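# Example invocation (hypothetical filename; the video must sit in the current working
# directory and the model name must match an available config/weights pair):
#   python inference_script.py sample_video.mp4 xception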