# curia / inference.py
# cdancette's picture
# comments
# 882bd8f
"""Model and dataset loading, inference, and label extraction functions."""
from __future__ import annotations
import json
import os
from functools import lru_cache
from typing import Any, Dict, Optional
import numpy as np
import torch
from datasets import DatasetDict, load_dataset
from PIL import Image
from torchvision import transforms
from torchvision.transforms import functional as TF
from transformers import (
AutoImageProcessor,
AutoModelForImageClassification,
)
# Hugging Face Hub repository id handed to `from_pretrained` when loading the
# image processor and the per-head classification models.
HF_REPO_ID = "raidium/curia"
# Hugging Face Hub dataset id handed to `load_dataset` when loading a subset.
HF_DATASET_ID = "raidium/CuriaBench"
class _NumpyToTensor:
    """Callable transform that tensorizes array-like input and prepends an axis.

    Values that are already a ``torch.Tensor`` or a PIL ``Image`` are handed
    back untouched (no copy, no extra axis).  Everything else goes through
    ``torch.tensor`` and gains a new leading dimension of size 1.
    """

    def __call__(self, value: Any) -> torch.Tensor:
        passthrough = isinstance(value, (torch.Tensor, Image.Image))
        if not passthrough:
            converted = torch.tensor(value)
            return converted.unsqueeze(0)
        # PIL images are returned as-is even though the annotation says Tensor.
        return value  # type: ignore[return-value]
class AdaptativeResizeMask(torch.nn.Module):
    """Resize a mask to a square and binarize it, relaxing the cut-off if needed.

    The mask is cast to float32, resized bilinearly (with antialiasing) to
    ``target_size x target_size`` and thresholded at ``initial_threshold``.
    When no pixel survives that threshold, the cut-off is replaced by half of
    the resized mask's maximum so that small regions are not wiped out by the
    interpolation.  The result is a float32 tensor of zeros and ones.
    """

    def __init__(self, target_size: int = 512, initial_threshold: float = 0.5) -> None:
        super().__init__()
        self.target_size = target_size
        self.initial_threshold = initial_threshold

    def forward(self, mask: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        side = self.target_size
        scaled = TF.resize(
            mask.to(dtype=torch.float32),
            (side, side),
            interpolation=TF.InterpolationMode.BILINEAR,
            antialias=True,
        )
        foreground = scaled > self.initial_threshold
        if foreground.sum() == 0:
            # Nothing cleared the configured threshold: retry at half the peak.
            foreground = scaled > torch.max(scaled) * 0.5
        return foreground.to(dtype=torch.float32)
@lru_cache(maxsize=1)
def make_mask_transform(crop_size: int = 512) -> transforms.Compose:
    """Build the mask pipeline: tensor conversion followed by adaptive resizing.

    The pipeline is memoized; with ``maxsize=1`` only the pipeline for the most
    recently requested ``crop_size`` is kept.
    """
    steps = [_NumpyToTensor(), AdaptativeResizeMask(target_size=crop_size)]
    return transforms.Compose(steps)
def prepare_mask_for_model(mask: Any) -> Optional[torch.Tensor]:
    """Convert a raw mask into the batched tensor passed to the model as ``mask``.

    Args:
        mask: ``None`` or anything ``np.array`` accepts.  A 3-D input is read
            as ``(H, W, slices)``; every other rank is handled as one 2-D mask.

    Returns:
        ``None`` when ``mask`` is ``None``, cannot be converted to an array, or
        is empty.  Otherwise the float32 output of ``make_mask_transform()``
        (resized to its default crop size ``S``) with a leading batch axis:
        ``(1, 1, S, S)`` for a 2-D mask and ``(1, 1, S, S, slices)`` for a
        3-D mask.
    """
    if mask is None:
        return None
    try:
        mask_arr = np.array(mask)
    except Exception:
        # Deliberate best effort: an unusable mask means "run without a mask".
        return None
    if mask_arr.size == 0:
        return None

    # Fetch the transform only once the input is known to be usable.
    mask_transform = make_mask_transform()
    if mask_arr.ndim == 3:
        # Move slices in front so the resize acts on the two spatial axes.
        resized = mask_transform(mask_arr.transpose(2, 0, 1))  # (1, slices, S, S)
        # Put the slice axis back last: (1, S, S, slices).
        tensor = resized.permute(0, 2, 3, 1)
    else:
        # Add the leading axis on the NumPy side.  Wrapping the ndarray in a
        # Python list would send torch.tensor down its slow list-of-arrays path
        # (and emit a UserWarning) for the same result.
        tensor = mask_transform(torch.tensor(mask_arr[np.newaxis]))  # (1, S, S)
    # NOTE(review): the batch axis is added on both paths here; the source this
    # was recovered from had lost its indentation, so confirm the 3-D path is
    # not meant to skip this unsqueeze.
    return tensor.unsqueeze(0)
@lru_cache(maxsize=1)
def load_id_to_labels() -> Dict[str, Dict[int, str]]:
    """Load ``id_to_labels.json`` from this module's directory.

    Returns:
        A mapping ``head -> {class_id -> label}`` whose class ids are ``int``
        (JSON can only store them as strings).  The result is cached, so every
        caller receives the same dictionary object; treat it as read-only.

    Raises:
        OSError: if the JSON file cannot be opened.
        ValueError: if the file is not valid JSON or a class id is not an
            integer literal.
    """
    json_path = os.path.join(os.path.dirname(__file__), "id_to_labels.json")
    # JSON is UTF-8 by specification; do not depend on the locale's encoding.
    with open(json_path, "r", encoding="utf-8") as f:
        raw = json.load(f)
    # Build a fresh mapping instead of rewriting ``raw`` while iterating it.
    return {
        head: {int(class_id): label for class_id, label in labels.items()}
        for head, labels in raw.items()
    }
@lru_cache(maxsize=1)
def load_processor() -> AutoImageProcessor:
    """Fetch the image processor from ``HF_REPO_ID`` (memoized after first call).

    The ``HF_TOKEN`` environment variable, when set, is forwarded as the Hub
    token; remote code from the repository is trusted.
    """
    load_kwargs = {"trust_remote_code": True, "token": os.environ.get("HF_TOKEN")}
    return AutoImageProcessor.from_pretrained(HF_REPO_ID, **load_kwargs)
@lru_cache(maxsize=None)
def load_model(head: str) -> AutoModelForImageClassification:
    """Fetch the classification model stored under subfolder ``head``.

    The model is loaded from ``HF_REPO_ID`` with remote code trusted, switched
    to eval mode, and memoized per ``head`` for the life of the process.  The
    ``HF_TOKEN`` environment variable, when set, is used as the Hub token.
    """
    hub_token = os.environ.get("HF_TOKEN")
    classifier = AutoModelForImageClassification.from_pretrained(
        HF_REPO_ID,
        subfolder=head,
        token=hub_token,
        trust_remote_code=True,
    )
    classifier.eval()
    return classifier
@lru_cache(maxsize=None)
def load_curia_dataset(subset: str) -> Any:
    """Load the ``test`` split of configuration ``subset`` from ``HF_DATASET_ID``.

    Results are memoized per ``subset``.  Should the loader return a
    ``DatasetDict`` despite the explicit split, its ``"test"`` entry is
    returned instead.
    """
    loaded = load_dataset(
        HF_DATASET_ID,
        subset,
        split="test",
        token=os.environ.get("HF_TOKEN"),
    )
    return loaded["test"] if isinstance(loaded, DatasetDict) else loaded
def infer_image(
    image: np.ndarray,
    head: str,
    mask: Any | None = None,
    return_probs: bool = True,
) -> torch.Tensor:
    """Run the ``head`` model on one image and return its scores.

    Args:
        image: Image handed to the processor as ``images=``.
        head: Name of the head whose model is loaded via ``load_model``.
        mask: Optional mask; when ``prepare_mask_for_model`` yields a tensor it
            is added to the model inputs under the key ``"mask"``.
        return_probs: If true, return the softmax (over the last axis) of the
            first batch element's logits; otherwise return those logits
            squeezed.

    Returns:
        Probabilities or raw logits for the first batch element.
    """
    processor = load_processor()
    model = load_model(head)
    with torch.no_grad():
        model_inputs = processor(images=image, return_tensors="pt")
        roi = prepare_mask_for_model(mask)
        if roi is not None:
            model_inputs["mask"] = roi
        first_logits = model(**model_inputs)["logits"][0]
        if not return_probs:
            return first_logits.squeeze()
        return torch.nn.functional.softmax(first_logits, dim=-1)