|
|
"""Model and dataset loading, inference, and label extraction functions.""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import json |
|
|
import os |
|
|
from functools import lru_cache |
|
|
from typing import Any, Dict, Optional |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
from datasets import DatasetDict, load_dataset |
|
|
from PIL import Image |
|
|
from torchvision import transforms |
|
|
from torchvision.transforms import functional as TF |
|
|
from transformers import ( |
|
|
AutoImageProcessor, |
|
|
AutoModelForImageClassification, |
|
|
) |
|
|
|
|
|
# Hugging Face Hub repository the processor and the per-head models are loaded from.
HF_REPO_ID = "raidium/curia"
# Hugging Face Hub dataset repository the evaluation subsets are loaded from.
HF_DATASET_ID = "raidium/CuriaBench"
|
|
|
|
|
|
|
|
class _NumpyToTensor:
    """Turn array-like input into a tensor with a new leading axis.

    A value that already is a ``torch.Tensor`` or a PIL ``Image`` is handed
    back untouched (no axis is added). Anything else goes through
    ``torch.tensor`` and gains a leading dimension of size 1.
    """

    def __call__(self, value: Any) -> torch.Tensor:
        passthrough = isinstance(value, torch.Tensor) or isinstance(value, Image.Image)
        if not passthrough:
            return torch.tensor(value).unsqueeze(0)
        return value
|
|
|
|
|
|
|
|
class AdaptativeResizeMask(torch.nn.Module):
    """Bilinearly resize a mask to a square and binarize it again.

    The interpolated mask is thresholded at ``initial_threshold``. When that
    leaves no foreground value at all, the threshold falls back to half of the
    interpolated mask's maximum so the mask does not come out empty.
    """

    def __init__(self, target_size: int = 512, initial_threshold: float = 0.5) -> None:
        super().__init__()
        # Side length used for both output dimensions.
        self.target_size = target_size
        # First cut-off applied to the interpolated mask.
        self.initial_threshold = initial_threshold

    def forward(self, mask: torch.Tensor) -> torch.Tensor:
        """Return the resized mask as a float32 tensor holding 0.0 / 1.0."""
        side = self.target_size
        soft = TF.resize(
            mask.to(dtype=torch.float32),
            (side, side),
            interpolation=TF.InterpolationMode.BILINEAR,
            antialias=True,
        )
        hard = soft > self.initial_threshold
        if hard.sum() == 0:
            # Nothing survived the first cut: retry relative to the peak value.
            hard = soft > torch.max(soft) * 0.5
        return hard.to(dtype=torch.float32)
|
|
|
|
|
|
|
|
@lru_cache(maxsize=1)
def make_mask_transform(crop_size: int = 512) -> transforms.Compose:
    """Build the mask pipeline: tensor conversion followed by adaptive resize.

    Args:
        crop_size: Forwarded to ``AdaptativeResizeMask`` as ``target_size``.

    Returns:
        The composed transform. With ``maxsize=1`` only the most recently
        requested pipeline stays cached.
    """
    steps = [
        _NumpyToTensor(),
        AdaptativeResizeMask(target_size=crop_size),
    ]
    return transforms.Compose(steps)
|
|
|
|
|
|
|
|
def prepare_mask_for_model(mask: Any) -> Optional[torch.Tensor]:
    """Apply Curia's mask preprocessing so heads get the ROI they expect.

    Args:
        mask: Anything ``numpy`` can turn into an array, or ``None``.

    Returns:
        ``None`` when ``mask`` is ``None``, cannot be converted to an array,
        or converts to an empty array. Otherwise the resized, binarized
        float32 mask:

        * a 3-D input keeps its axis order behind a new leading axis of
          size 1 (the first two axes are the ones being resized);
        * any other input gets two leading axes of size 1 in front of the
          resized result.
    """
    if mask is None:
        return None

    try:
        mask_arr = np.array(mask)
    except Exception:
        # Best effort by design: a mask that cannot be converted is treated
        # exactly like "no mask" instead of aborting the inference call.
        return None

    if mask_arr.size == 0:
        return None

    # Built only once the mask is known to be usable (the factory is cached).
    mask_transform = make_mask_transform()

    if mask_arr.ndim == 3:
        # Move the last axis to the front so the resize acts on the first two
        # axes, then restore the original axis order after the new batch axis.
        resized = mask_transform(mask_arr.transpose(2, 0, 1))
        return resized.transpose(1, 3).transpose(1, 2)

    # Convert the array directly and add the leading axis afterwards. Wrapping
    # the ndarray in a Python list first yields the same tensor but takes
    # torch's slow list-of-arrays path and emits a UserWarning.
    with_leading_axis = torch.tensor(mask_arr).unsqueeze(0)
    return mask_transform(with_leading_axis).unsqueeze(0)
|
|
|
|
|
|
|
|
@lru_cache(maxsize=1)
def load_id_to_labels() -> Dict[str, Dict[int, str]]:
    """Load the id_to_labels.json mapping file.

    The file is looked up in the directory that contains this module.

    Returns:
        A ``{head: {class_id: label}}`` mapping in which the JSON string keys
        of each inner mapping have been converted to ``int``. The result is
        cached, so every caller receives the same dict object and should
        treat it as read-only.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the file is not valid JSON or an inner key is not an
            integer literal.
    """
    json_path = os.path.join(os.path.dirname(__file__), "id_to_labels.json")
    # JSON text is UTF-8; do not let the platform locale pick the codec.
    with open(json_path, "r", encoding="utf-8") as f:
        raw = json.load(f)

    return {
        head: {int(class_id): label for class_id, label in labels.items()}
        for head, labels in raw.items()
    }
|
|
|
|
|
|
|
|
@lru_cache(maxsize=1)
def load_processor() -> AutoImageProcessor:
    """Fetch the image processor from ``HF_REPO_ID`` (cached after first call).

    The ``HF_TOKEN`` environment variable, when set, is forwarded as the
    access token. NOTE: ``trust_remote_code=True`` allows code shipped in that
    repository to run locally.
    """
    hub_kwargs = {
        "trust_remote_code": True,
        "token": os.environ.get("HF_TOKEN"),
    }
    return AutoImageProcessor.from_pretrained(HF_REPO_ID, **hub_kwargs)
|
|
|
|
|
|
|
|
@lru_cache(maxsize=None)
def load_model(head: str) -> AutoModelForImageClassification:
    """Load the classifier stored under ``head`` and put it in eval mode.

    Args:
        head: Name of the subfolder of ``HF_REPO_ID`` to load from.

    Returns:
        The model after ``eval()``. One instance per distinct ``head`` stays
        cached for the lifetime of the process (unbounded cache).
    """
    hf_token = os.environ.get("HF_TOKEN")
    classifier = AutoModelForImageClassification.from_pretrained(
        HF_REPO_ID,
        subfolder=head,
        token=hf_token,
        trust_remote_code=True,
    )
    classifier.eval()
    return classifier
|
|
|
|
|
|
|
|
@lru_cache(maxsize=None)
def load_curia_dataset(subset: str) -> Any:
    """Return the ``test`` split of one ``HF_DATASET_ID`` subset.

    Args:
        subset: Configuration name passed to ``load_dataset``.

    Returns:
        The loaded split; results are cached per ``subset``.
    """
    hf_token = os.environ.get("HF_TOKEN")
    loaded = load_dataset(HF_DATASET_ID, subset, split="test", token=hf_token)
    # Defensive unwrap in case a DatasetDict comes back despite split="test".
    return loaded["test"] if isinstance(loaded, DatasetDict) else loaded
|
|
|
|
|
def infer_image(
    image: np.ndarray,
    head: str,
    mask: Any | None = None,
    return_probs: bool = True,
) -> torch.Tensor:
    """Run one image through the classifier for ``head``.

    Args:
        image: Image handed to the processor as ``images``.
        head: Head name used to pick the model.
        mask: Optional mask. It goes through ``prepare_mask_for_model`` and is
            added to the model inputs under ``"mask"`` when the result is not
            ``None``.
        return_probs: When true, return the softmax over the last axis of the
            first batch item's logits; otherwise return that item's logits
            with size-1 axes squeezed out.

    Returns:
        Probabilities or raw logits for the first batch item.
    """
    processor = load_processor()
    model = load_model(head)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        inputs = processor(images=image, return_tensors="pt")
        roi = prepare_mask_for_model(mask)
        if roi is not None:
            inputs["mask"] = roi
        outputs = model(**inputs)

    first_logits = outputs["logits"][0]
    if not return_probs:
        return first_logits.squeeze()
    return torch.nn.functional.softmax(first_logits, dim=-1)
|
|
|