curia / app.py
cdancette's picture
contour instead of overlay
b49f319
raw
history blame
19 kB
"""Gradio Space for exploring Curia models and CuriaBench datasets.
This application allows users to:
- Select any available Curia classification head.
- Load the matching CuriaBench test split and sample random images per class.
- Upload custom medical images that match the model's expected orientation.
- Forward images through the selected model head and visualise class probabilities.
The space expects an HF token with access to "raidium" resources to be
provided via the HF_TOKEN environment variable (configure it as a secret when
deploying to Hugging Face Spaces).
"""
from __future__ import annotations
import os
import random
from functools import lru_cache
from typing import Any, Dict, List, Optional, Tuple, Union
import cv2
import gradio as gr
import numpy as np
import pandas as pd
import torch
from datasets import Dataset, DatasetDict, IterableDataset, load_dataset
from PIL import Image
from transformers import (
AutoImageProcessor,
AutoModelForImageClassification,
)
# Hugging Face Hub repositories: the Curia checkpoints (loaded with
# ``from_pretrained``) and the CuriaBench evaluation data (``load_dataset``).
HF_REPO_ID = "raidium/curia"
HF_DATASET_ID = "raidium/CuriaBench"
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# (head key, display label) pairs. The key is the value of the "Model head"
# dropdown and is passed as ``subfolder`` to ``load_model``; the list length
# also bounds the model cache size.
HEAD_OPTIONS: List[Tuple[str, str]] = [
    ("anatomy-ct", "Anatomy CT"),
    ("anatomy-mri", "Anatomy MRI"),
    ("atlas-stroke", "Atlas Stroke"),
    ("covidx-ct", "COVIDx CT"),
    ("deep-lesion-site", "Deep Lesion Site"),
    ("emidec-classification-mask", "EMIDEC Classification"),
    ("ich", "Intracranial Hemorrhage"),
    ("ixi", "IXI"),
    ("kits", "KiTS"),
    ("kneeMRI", "Knee MRI"),
    ("luna16-3D", "LUNA16 3D"),
    ("neural_foraminal_narrowing", "Neural Foraminal Narrowing"),
    ("oasis", "OASIS"),
    ("spinal_canal_stenosis", "Spinal Canal Stenosis"),
    ("subarticular_stenosis", "Subarticular Stenosis"),
]
# CuriaBench subsets offered in the UI. Each key is the dataset config name
# passed to ``load_dataset``; "label" is the dropdown text and "head" links the
# subset to a HEAD_OPTIONS key. Insertion order matters: when the head changes,
# the first subset whose "head" matches is selected. Several heads
# (atlas-stroke, ich, and the three spine heads) have no subset listed here.
DATASET_OPTIONS: Dict[str, Dict[str, Any]] = {
    "anatomy-ct": {"label": "Anatomy CT (test)", "head": "anatomy-ct"},
    "anatomy-ct-hard": {"label": "Anatomy CT Hard (test)", "head": "anatomy-ct"},
    "anatomy-mri": {"label": "Anatomy MRI (test)", "head": "anatomy-mri"},
    "covidctset": {"label": "COVID CT Set (test)", "head": "covidx-ct"},
    "covidx-ct": {"label": "COVIDx CT (test)", "head": "covidx-ct"},
    "deep-lesion-site": {"label": "Deep Lesion Site (test)", "head": "deep-lesion-site"},
    "emidec-classification-mask": {
        "label": "EMIDEC Classification Mask (test)",
        "head": "emidec-classification-mask",
    },
    "ixi": {"label": "IXI (test)", "head": "ixi"},
    "kits": {"label": "KiTS (test)", "head": "kits"},
    "kneeMRI": {"label": "Knee MRI (test)", "head": "kneeMRI"},
    "luna16": {"label": "LUNA16 (test)", "head": "luna16-3D"},
    "luna16-3D": {"label": "LUNA16 3D (test)", "head": "luna16-3D"},
    "oasis": {"label": "OASIS (test)", "head": "oasis"},
}
# ---------------------------------------------------------------------------
# Utility helpers
# ---------------------------------------------------------------------------
def resolve_token() -> Optional[str]:
    """Read the Hugging Face access token from the ``HF_TOKEN`` environment variable.

    Returns:
        The token string, or ``None`` when the variable is not set.
    """
    return os.getenv("HF_TOKEN")
@lru_cache(maxsize=1)
def load_processor() -> AutoImageProcessor:
    """Fetch the shared Curia image processor, memoised for the process lifetime.

    Remote code is trusted because the processor is defined in the ``HF_REPO_ID``
    repository; the token from ``resolve_token`` is forwarded for gated access.
    """
    return AutoImageProcessor.from_pretrained(
        HF_REPO_ID,
        trust_remote_code=True,
        token=resolve_token(),
    )
@lru_cache(maxsize=len(HEAD_OPTIONS))
def load_model(head: str) -> AutoModelForImageClassification:
    """Load the classification model for one Curia head and switch it to eval mode.

    Args:
        head: HEAD_OPTIONS key; used as the checkpoint ``subfolder`` in HF_REPO_ID.

    Returns:
        The model in evaluation mode. Results are cached per head (one slot per
        entry of HEAD_OPTIONS), so each head is downloaded at most once.
    """
    classifier = AutoModelForImageClassification.from_pretrained(
        HF_REPO_ID,
        subfolder=head,
        trust_remote_code=True,
        token=resolve_token(),
    )
    classifier.eval()
    return classifier
@lru_cache(maxsize=len(DATASET_OPTIONS))
def load_curia_dataset(subset: str) -> Any:
    """Load (and cache) the ``test`` split of a CuriaBench subset.

    Args:
        subset: Dataset config name, i.e. a DATASET_OPTIONS key.

    Returns:
        The test split. A ``DatasetDict`` result is defensively unwrapped to its
        ``"test"`` entry.
    """
    loaded = load_dataset(HF_DATASET_ID, subset, split="test", token=resolve_token())
    return loaded["test"] if isinstance(loaded, DatasetDict) else loaded
def to_numpy_image(image: Any) -> np.ndarray:
    """Convert dataset or user-provided imagery to a 2-D float32 numpy array.

    Accepts numpy arrays, PIL images, or anything ``np.array`` can coerce.
    A trailing channel axis is collapsed to a single grayscale plane:

    - 1 channel, e.g. ``(H, W, 1)``: the channel is used as-is;
    - 2 channels (luminance + alpha): the alpha channel is dropped;
    - 3 channels (RGB): the channels are averaged;
    - 4 channels (RGBA): alpha is dropped, then RGB is averaged.

    Raises:
        ValueError: if the result is not a 2-D ``(H, W)`` image.
    """
    arr = np.array(image)
    if arr.ndim == 3 and arr.shape[-1] in (2, 4):
        # LA / RGBA: the last channel is opacity, not intensity.
        arr = arr[..., :-1]
    if arr.ndim == 3 and arr.shape[-1] == 3:
        # Convert RGB to grayscale by averaging channels
        arr = arr.mean(axis=-1)
    elif arr.ndim == 3 and arr.shape[-1] == 1:
        # Single-channel image stored with an explicit channel axis.
        arr = arr[..., 0]
    if arr.ndim != 2:
        raise ValueError("Expected a 2D image (H, W). Please provide a single axial/coronal/sagittal slice.")
    return arr.astype(np.float32)
def to_display_image(image: np.ndarray) -> np.ndarray:
    """Normalise an image for display (uint8, 3-channel for 2-D input).

    Values are min-max scaled to ``[0, 255]``.

    - NaNs become 0.
    - ``+inf`` / ``-inf`` are clamped to the largest / smallest finite value, so a
      single infinite pixel cannot flatten the rest of the image to black.
    - Constant images and empty images map to zeros.
    - 2-D inputs are replicated into three identical channels; other ranks keep
      their shape.

    The input is never modified.
    """
    arr = np.array(image, dtype=np.float64)
    finite = np.isfinite(arr)
    if not finite.all():
        if finite.any():
            lo = float(arr[finite].min())
            hi = float(arr[finite].max())
        else:
            lo = hi = 0.0
        # Without explicit posinf/neginf, nan_to_num maps inf to ~1.8e308, which
        # would dominate the min-max range below.
        arr = np.nan_to_num(arr, nan=0.0, posinf=hi, neginf=lo)
    if arr.size == 0:
        # .min()/.max() raise on empty arrays.
        scaled = np.zeros(arr.shape, dtype=np.uint8)
    else:
        arr_min = float(arr.min())
        arr_max = float(arr.max())
        if arr_max - arr_min > 1e-6:
            arr = (arr - arr_min) / (arr_max - arr_min)
        else:
            arr = np.zeros_like(arr)
        scaled = (arr * 255).clip(0, 255).astype(np.uint8)
    if scaled.ndim == 2:
        scaled = np.stack([scaled, scaled, scaled], axis=-1)
    return scaled
def coerce_mask_array(mask: Any) -> Optional[np.ndarray]:
    """Turn an arbitrary mask value into a non-empty numpy array, if possible.

    Returns:
        The array, or ``None`` when *mask* is ``None``, cannot be converted by
        ``np.array``, or converts to an array with zero elements.
    """
    if mask is not None:
        try:
            candidate = np.array(mask)
        except Exception:
            # Best effort: an unconvertible mask is treated as "no mask".
            candidate = None
        if candidate is not None and candidate.size > 0:
            return candidate
    return None
def prepare_mask_tensor(mask: Any, height: int, width: int) -> Optional[torch.Tensor]:
    """Normalise a segmentation mask into a stack of boolean ``(N, height, width)`` layers.

    After squeezing singleton axes, these layouts are accepted:

    - ``(height, width)``: a single mask;
    - any array with a consecutive ``(height, width)`` axis pair, e.g.
      ``(N, height, width)`` or ``(height, width, N)``. The trailing-axes
      interpretation is preferred when several positions match;
    - otherwise, any array whose size is a multiple of ``height * width``. It is
      reshaped row-major, e.g. a flattened mask.

    Args:
        mask: Raw mask value (array-like); ``None`` means no mask.
        height: Image height in pixels.
        width: Image width in pixels.

    Returns:
        Bool tensor containing only the layers with at least one positive pixel.
        ``None`` when the mask is missing or empty, cannot be laid out on the
        image grid, or has no positive pixel.

    Raises:
        ValueError: if a 2-D mask is not exactly ``(height, width)``. A blind
            reshape would silently scramble e.g. a transposed mask.
    """
    mask_array = coerce_mask_array(mask)
    if mask_array is None:
        return None
    frame = (height, width)
    arr = np.squeeze(mask_array)
    if arr.ndim == 2:
        if arr.shape != frame:
            raise ValueError(f"Mask shape {arr.shape} does not match image shape {frame}.")
        layers = arr[np.newaxis]
    else:
        # Find a consecutive (height, width) axis pair, scanning from the end so
        # channel-first (N, H, W) wins over channel-last (H, W, N) when both match.
        spatial = next(
            (axis for axis in range(arr.ndim - 2, -1, -1) if arr.shape[axis:axis + 2] == frame),
            None,
        )
        if spatial is not None:
            layers = np.moveaxis(arr, (spatial, spatial + 1), (-2, -1)).reshape(-1, height, width)
        elif height * width > 0 and arr.size % (height * width) == 0:
            # Last resort, e.g. a flattened (H*W,) mask: reshape row-major.
            layers = arr.reshape(-1, height, width)
        else:
            return None
    binary_layers = [torch.from_numpy(layer > 0) for layer in layers]
    non_empty = [layer for layer in binary_layers if bool(layer.any())]
    if not non_empty:
        return None
    return torch.stack(non_empty, dim=0)
def apply_contour_overlay(
    image: np.ndarray,
    mask: Any,
    thickness: int = 1,
    color: Tuple[int, int, int] = (255, 0, 0),
) -> np.ndarray:
    """Outline each mask layer on a copy of *image* (contours only, no fill).

    Args:
        image: Display image; its first two dimensions define the mask grid.
        mask: Raw mask, normalised via ``prepare_mask_tensor``.
        thickness: Contour line thickness passed to OpenCV.
        color: Contour colour passed to OpenCV.

    Returns:
        *image* itself when there is nothing to draw. Otherwise a copy with the
        outer contour of every mask layer drawn on it.
    """
    rows, cols = image.shape[:2]
    layers = prepare_mask_tensor(mask, rows, cols)
    if layers is None:
        return image
    canvas = image.copy()
    for layer in layers:
        binary = layer.numpy().astype(np.uint8)
        # RETR_EXTERNAL: only outermost boundaries, so holes are not outlined.
        outlines, _hierarchy = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(canvas, outlines, -1, color, thickness)
    return canvas
def render_image_with_mask_info(image: np.ndarray, mask: Any) -> Tuple[np.ndarray, Optional[str]]:
    """Build the display image, outlining *mask* when one is given.

    Returns:
        ``(display_image, message)`` where *message* is:

        - ``None`` when there is no mask;
        - ``""`` when the overlay step succeeded;
        - a user-facing warning when overlaying raised an error.
    """
    base = to_display_image(image)
    if mask is None:
        return base, None
    try:
        return apply_contour_overlay(base, mask), ""
    except Exception:
        # Best effort: an unusable mask must not prevent showing the image itself.
        return base, "Mask provided but could not be visualised."
def dataset_class_metadata(dataset: Dataset) -> Tuple[List[int], Dict[int, str]]:
    """Return the class ids of *dataset* and a mapping from id to display name.

    If the ``target`` feature exposes ``names`` (presumably a ``ClassLabel``),
    the ids are ``0..len(names)-1`` labelled with those names. Otherwise the
    distinct values of the ``target`` column, cast to ``int``, are used and
    labelled with their string form. Without a ``target`` column both results
    are empty.
    """
    feature = dataset.features.get("target")
    if feature and hasattr(feature, "names"):
        names = list(feature.names)
        return list(range(len(names))), dict(enumerate(names))
    if "target" not in dataset.column_names:
        return [], {}
    observed = dataset["target"]
    class_ids = sorted({int(value) for value in observed}) if observed else []
    return class_ids, {class_id: str(class_id) for class_id in class_ids}
def pick_random_indices(dataset: Dataset, target: Optional[int]) -> int:
    """Return a random row index, restricted to rows whose ``target`` equals *target*.

    The choice falls back to a uniformly random index over the whole dataset
    when *target* is ``None``, the dataset has no ``target`` column, or no row
    matches.
    """
    if target is not None and "target" in dataset.column_names:
        matches = [row for row, value in enumerate(dataset["target"]) if value == target]
        if matches:
            return random.choice(matches)
    return random.randrange(len(dataset))
def format_probabilities(probs: torch.Tensor, id2label: Dict[int, str]) -> pd.DataFrame:
    """Tabulate per-class probabilities, most likely class first.

    Args:
        probs: Class probabilities indexed by class id (expected 1-D).
        id2label: Class id -> display name; ids missing from it are shown as
            their string form.

    Returns:
        DataFrame with ``class_id``, ``label`` and ``probability`` columns,
        sorted by descending probability. The original row index is kept.
    """
    scores = probs.detach().cpu().numpy()
    records = [
        dict(class_id=class_id, label=id2label.get(class_id, str(class_id)), probability=float(score))
        for class_id, score in enumerate(scores)
    ]
    return pd.DataFrame(records).sort_values(by="probability", ascending=False)
def infer_image(
    image: np.ndarray,
    head: str,
) -> Tuple[str, pd.DataFrame]:
    """Run one image through the selected Curia head.

    Args:
        image: Single 2-D image as produced by ``to_numpy_image``.
        head: HEAD_OPTIONS key of the classification head to use.

    Returns:
        ``(prediction, table)`` where:

        - *prediction* is ``"<label> (p=<prob>)"`` for the most likely class;
        - *table* is the full probability table from ``format_probabilities``.
    """
    processor = load_processor()
    model = load_model(head)
    with torch.no_grad():
        processed = processor(images=image, return_tensors="pt")
        outputs = model(**processed)
        # Batch of one image: softmax over the class dimension of its logits.
        probs = torch.nn.functional.softmax(outputs["logits"][0], dim=-1)
    id2label = model.config.id2label or {}
    df = format_probabilities(probs, id2label)
    top_row = df.iloc[0]
    prediction = f"{top_row['label']} (p={top_row['probability']:.3f})"
    return prediction, df
# ---------------------------------------------------------------------------
# Gradio callbacks
# ---------------------------------------------------------------------------
def update_dataset_from_head(head: str) -> Dict[str, Any]:
    """Point the dataset dropdown at the first CuriaBench subset linked to *head*.

    Subsets are scanned in DATASET_OPTIONS insertion order. When no subset is
    linked to *head*, an empty update leaves the dropdown unchanged.

    NOTE(review): with the current ordering, head "covidx-ct" selects
    "covidctset" and "luna16-3D" selects "luna16" rather than the same-named
    subset. Confirm this is intended.
    """
    linked = (key for key, meta in DATASET_OPTIONS.items() if meta["head"] == head)
    first_match = next(linked, None)
    if first_match is None:
        return gr.update()
    return gr.update(value=first_match)
def load_dataset_metadata(subset: str) -> Tuple[Dict[str, Any], str]:
    """Refresh the class-filter dropdown and status text for a CuriaBench subset.

    Returns:
        ``(dropdown_update, status_message)``. The dropdown always offers
        "Random" (pre-selected), followed by ``"<id>: <name>"`` for every class
        reported by ``dataset_class_metadata``. Load failures are reported in the
        status message instead of being raised.
    """
    try:
        dataset = load_curia_dataset(subset)
    except Exception as exc:  # pragma: no cover - surfaced in UI
        return gr.update(choices=["Random"], value="Random"), f"Failed to load dataset: {exc}"
    classes, id2label = dataset_class_metadata(dataset)
    if not classes:
        return (
            gr.update(choices=["Random"], value="Random"),
            "No class metadata detected; sampling at random",
        )
    class_choices = [f"{cls_id}: {id2label.get(cls_id, str(cls_id))}" for cls_id in classes]
    return (
        gr.update(choices=["Random"] + class_choices, value="Random"),
        f"Loaded {subset} ({len(dataset)} test samples)",
    )
def parse_target_selection(selection: str) -> Optional[int]:
    """Extract the class id from a class-filter choice such as ``"3: label"``.

    Returns:
        The integer before the first ``:``. ``None`` for "Random", for an
        empty or missing selection, or when that prefix is not an integer.
    """
    if selection and selection != "Random":
        try:
            return int(selection.partition(":")[0].strip())
        except (ValueError, AttributeError):
            return None
    return None
def sample_dataset_example(
    subset: str,
    target_id: Optional[int],
) -> Tuple[np.ndarray, str, Dict[str, Any]]:
    """Draw one random row from a CuriaBench subset, optionally of a given class.

    Returns:
        ``(image, caption, meta)`` where:

        - *image* comes from ``to_numpy_image``;
        - *caption* is ``"Sample #<index>"``;
        - *meta* holds the row ``index``, its ``target`` (or ``None``), and the
          coerced ``mask`` (or ``None``).
    """
    dataset = load_curia_dataset(subset)
    row_index = pick_random_indices(dataset, target_id)
    row = dataset[row_index]
    pixels = to_numpy_image(row["image"])
    mask = coerce_mask_array(row.get("mask"))
    details = dict(index=row_index, target=row.get("target"), mask=mask)
    return pixels, f"Sample #{row_index}", details
def _ground_truth_update(target: Any, head: str) -> Dict[str, Any]:
    """Build the Markdown update that shows the ground-truth label for *target*.

    The class name is looked up in the selected head's ``id2label``. If the head
    cannot be loaded, the raw class id is shown instead of failing.
    """
    if target is None:
        return gr.update(visible=False)
    try:
        id2label = load_model(head).config.id2label or {}
    except Exception:
        # The sample itself loaded fine; a model-loading problem should not hide
        # it. The same problem will surface when inference is run.
        id2label = {}
    label_name = id2label.get(target, str(target))
    return gr.update(value=f"**Ground Truth:** {label_name} (class {target})", visible=True)


def load_dataset_sample(
    subset: str,
    target_selection: str,
    head: str,
) -> Tuple[
    Optional[np.ndarray],
    str,
    pd.DataFrame,
    Dict[str, Any],
    Optional[Dict[str, Any]],
]:
    """Load a random CuriaBench sample and prepare every UI output for it.

    Args:
        subset: CuriaBench subset (DATASET_OPTIONS key).
        target_selection: Class-filter choice ("Random" or ``"<id>: <name>"``).
        head: Selected head, used to name the ground-truth class.

    Returns:
        A 5-tuple:

        1. display image, with mask contours when available;
        2. status/metadata Markdown;
        3. an empty predictions table;
        4. the ground-truth Markdown update;
        5. the state dict ``{"image", "mask"}`` consumed by ``run_inference``.

        On failure: ``None`` image, an error message, and a hidden ground-truth
        update.
    """
    try:
        target_id = parse_target_selection(target_selection)
        image, caption, meta = sample_dataset_example(subset, target_id)
        mask = meta.get("mask")
        display, mask_msg = render_image_with_mask_info(image, mask)
        target = meta.get("target")
        meta_text = caption if target is None else f"{caption} | target={target}"
        status = "Image loaded. Click 'Run inference' to compute predictions."
        if mask_msg:
            status += f" {mask_msg}"
        return (
            display,
            status + "\n\n" + meta_text,
            pd.DataFrame(),
            _ground_truth_update(target, head),
            {"image": image, "mask": mask},
        )
    except Exception as exc:  # pragma: no cover - surfaced in UI
        return None, f"Failed to load sample: {exc}", pd.DataFrame(), gr.update(visible=False), None
def run_inference(
    sample_state: Optional[Dict[str, Any]],
    head: str,
) -> Tuple[str, pd.DataFrame]:
    """Classify the image held in *sample_state* with the selected head.

    Returns:
        ``(markdown, table)``:

        - a guidance message and an empty table when no image has been loaded;
        - an error message and an empty table when inference raises;
        - otherwise ``"**Prediction:** ..."`` with the probability table.
    """
    has_image = bool(sample_state) and "image" in sample_state
    if not has_image:
        return "Load a dataset sample or upload an image first.", pd.DataFrame()
    try:
        prediction, table = infer_image(sample_state["image"], head)
    except Exception as exc:  # pragma: no cover - surfaced in UI
        return f"Failed to run inference: {exc}", pd.DataFrame()
    return f"**Prediction:** {prediction}", table
def handle_upload_preview(
    image: np.ndarray | Image.Image | None,
) -> Tuple[Optional[np.ndarray], str, pd.DataFrame, Dict[str, Any], Optional[Dict[str, Any]]]:
    """Prepare the UI outputs for a freshly uploaded image.

    Returns the same 5-tuple layout as ``load_dataset_sample``:

    1. preview image;
    2. status text;
    3. an empty predictions table;
    4. a hidden ground-truth panel;
    5. state ``{"image": ..., "mask": None}``.

    A missing or unreadable upload yields a ``None`` preview and an explanatory
    message.
    """
    if image is None:
        return None, "Please upload an image.", pd.DataFrame(), gr.update(visible=False), None
    try:
        pixels = to_numpy_image(image)
        preview = to_display_image(pixels)
    except Exception as exc:  # pragma: no cover - surfaced in UI
        return None, f"Failed to load image: {exc}", pd.DataFrame(), gr.update(visible=False), None
    return (
        preview,
        "Image uploaded. Click 'Run inference' to compute predictions.",
        pd.DataFrame(),
        gr.update(visible=False),
        {"image": pixels, "mask": None},
    )
# ---------------------------------------------------------------------------
# Interface definition
# ---------------------------------------------------------------------------
def build_demo() -> gr.Blocks:
    """Assemble the Gradio Blocks UI and wire its event handlers.

    Layout, top to bottom:

    - intro Markdown;
    - the model-head dropdown;
    - a row with the dataset-sampling controls and the custom-upload widget;
    - the "Run inference" button;
    - a row with the image / ground-truth column and the predictions column;
    - a closing notes section.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks(css=".gr-prose { max-width: 900px; }") as demo:
        gr.Markdown(
            """
# Curia Model Playground
Experiment with the multi-head Curia models on CuriaBench evaluation data or
your own medical images. Each head expects a single 2D slice in the
corresponding plane/orientation as defined for Curia (PL for axial, IL for
coronal, IP for sagittal). Ensure images are unwindowed and either raw
Hounsfield units (CT) or normalised intensity values (MRI).
"""
        )
        # Choices are (label, key) pairs: the UI shows the label, callbacks receive the key.
        head_dropdown = gr.Dropdown(
            label="Model head",
            choices=[(label, key) for key, label in HEAD_OPTIONS],
            value="anatomy-ct",
        )
        gr.Markdown("---")
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Load dataset sample")
                dataset_dropdown = gr.Dropdown(
                    label="CuriaBench subset",
                    choices=[(meta["label"], key) for key, meta in DATASET_OPTIONS.items()],
                    value="anatomy-ct",
                )
                dataset_status = gr.Markdown("Select a dataset to load class metadata.")
                # Populated with "<id>: <name>" entries by load_dataset_metadata.
                class_dropdown = gr.Dropdown(label="Target class filter", choices=["Random"], value="Random")
                dataset_btn = gr.Button("Load dataset sample")
            with gr.Column():
                gr.Markdown("### Upload custom image")
                # image_mode="L" requests single-channel uploads.
                upload_component = gr.Image(label="Upload image", image_mode="L", type="numpy")
        gr.Markdown("---")
        infer_btn = gr.Button("Run inference", variant="primary")
        with gr.Row():
            with gr.Column():
                image_display = gr.Image(label="Image", interactive=False, type="numpy")
                ground_truth_display = gr.Markdown(visible=False)
            with gr.Column():
                gr.Markdown("### Predictions")
                status_text = gr.Markdown()
                prediction_probs = gr.Dataframe(headers=["class_id", "label", "probability"])
        # Holds {"image": <float32 array>, "mask": ...} from the last load/upload
        # for run_inference.
        image_state = gr.State()
        # Event wiring
        # Changing the head selects a matching CuriaBench subset (if any)...
        head_dropdown.change(
            fn=update_dataset_from_head,
            inputs=[head_dropdown],
            outputs=[dataset_dropdown],
        )
        # ...and changing the subset refreshes the class filter and status line.
        # NOTE(review): the class filter presumably stays "Random" until the first
        # subset change. Confirm whether .change fires for the initial value;
        # consider a demo.load hook if not.
        dataset_dropdown.change(
            fn=load_dataset_metadata,
            inputs=[dataset_dropdown],
            outputs=[class_dropdown, dataset_status],
        )
        # Dataset sampling and uploads share one output layout; both reset the
        # predictions table and replace the stored image state.
        dataset_btn.click(
            fn=load_dataset_sample,
            inputs=[dataset_dropdown, class_dropdown, head_dropdown],
            outputs=[
                image_display,
                status_text,
                prediction_probs,
                ground_truth_display,
                image_state,
            ],
        )
        upload_component.upload(
            fn=handle_upload_preview,
            inputs=[upload_component],
            outputs=[
                image_display,
                status_text,
                prediction_probs,
                ground_truth_display,
                image_state,
            ],
        )
        # Inference runs only on demand, on whatever image_state currently holds.
        infer_btn.click(
            fn=run_inference,
            inputs=[image_state, head_dropdown],
            outputs=[status_text, prediction_probs],
        )
        gr.Markdown(
            """
### Notes
- Configure the `HF_TOKEN` secret in your Space to load private checkpoints
and datasets from the `raidium` organisation.
- When masks are available in the dataset sample, their contours are drawn on the
image for visual reference using OpenCV.
- Uploaded images must be single-channel arrays. Multi-channel inputs are
converted to grayscale automatically.
"""
        )
    return demo
# Built at import time so the app is available as the module-level ``demo``
# object (presumably what the hosting runtime looks for; confirm).
demo = build_demo()
if __name__ == "__main__":
    # Running the file directly launches the Gradio server.
    demo.launch()