|
|
"""Gradio Space for exploring Curia models and CuriaBench datasets. |
|
|
|
|
|
This application allows users to: |
|
|
|
|
|
- Select any available Curia classification head. |
|
|
- Load the matching CuriaBench test split and sample random images per class. |
|
|
- Upload custom medical images that match the model's expected orientation. |
|
|
- Forward images through the selected model head and visualise class probabilities. |
|
|
|
|
|
The space expects an HF token with access to "raidium" resources to be |
|
|
provided via the HF_TOKEN environment variable (configure it as a secret when |
|
|
deploying to Hugging Face Spaces). |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import os |
|
|
import random |
|
|
from functools import lru_cache |
|
|
from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
|
|
|
import cv2 |
|
|
import gradio as gr |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import torch |
|
|
from datasets import Dataset, DatasetDict, IterableDataset, load_dataset |
|
|
from PIL import Image |
|
|
from transformers import ( |
|
|
AutoImageProcessor, |
|
|
AutoModelForImageClassification, |
|
|
) |
|
|
|
|
|
|
|
|
# Hugging Face Hub repository holding the Curia checkpoints. Each classification
# head is loaded from its own subfolder of this repo (``subfolder=head`` in load_model).
HF_REPO_ID = "raidium/curia"
# Hugging Face dataset repository holding the CuriaBench evaluation subsets.
HF_DATASET_ID = "raidium/CuriaBench"


# Selectable classification heads as (subfolder key, human-readable label) pairs.
# The key is what the rest of the app passes around as ``head``.
HEAD_OPTIONS: List[Tuple[str, str]] = [
    ("anatomy-ct", "Anatomy CT"),
    ("anatomy-mri", "Anatomy MRI"),
    ("atlas-stroke", "Atlas Stroke"),
    ("covidx-ct", "COVIDx CT"),
    ("deep-lesion-site", "Deep Lesion Site"),
    ("emidec-classification-mask", "EMIDEC Classification"),
    ("ich", "Intracranial Hemorrhage"),
    ("ixi", "IXI"),
    ("kits", "KiTS"),
    ("kneeMRI", "Knee MRI"),
    ("luna16-3D", "LUNA16 3D"),
    ("neural_foraminal_narrowing", "Neural Foraminal Narrowing"),
    ("oasis", "OASIS"),
    ("spinal_canal_stenosis", "Spinal Canal Stenosis"),
    ("subarticular_stenosis", "Subarticular Stenosis"),
]


# CuriaBench subset name -> dropdown label and the head (a HEAD_OPTIONS key)
# associated with that subset. Several subsets may share one head (e.g. the two
# anatomy-ct subsets); heads with no entry here have no linked dataset, so
# selecting them leaves the dataset dropdown unchanged.
DATASET_OPTIONS: Dict[str, Dict[str, Any]] = {
    "anatomy-ct": {"label": "Anatomy CT (test)", "head": "anatomy-ct"},
    "anatomy-ct-hard": {"label": "Anatomy CT Hard (test)", "head": "anatomy-ct"},
    "anatomy-mri": {"label": "Anatomy MRI (test)", "head": "anatomy-mri"},
    "covidctset": {"label": "COVID CT Set (test)", "head": "covidx-ct"},
    "covidx-ct": {"label": "COVIDx CT (test)", "head": "covidx-ct"},
    "deep-lesion-site": {"label": "Deep Lesion Site (test)", "head": "deep-lesion-site"},
    "emidec-classification-mask": {
        "label": "EMIDEC Classification Mask (test)",
        "head": "emidec-classification-mask",
    },
    "ixi": {"label": "IXI (test)", "head": "ixi"},
    "kits": {"label": "KiTS (test)", "head": "kits"},
    "kneeMRI": {"label": "Knee MRI (test)", "head": "kneeMRI"},
    "luna16": {"label": "LUNA16 (test)", "head": "luna16-3D"},
    "luna16-3D": {"label": "LUNA16 3D (test)", "head": "luna16-3D"},
    "oasis": {"label": "OASIS (test)", "head": "oasis"},
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def resolve_token() -> Optional[str]:
    """Look up the Hugging Face access token.

    Reads the ``HF_TOKEN`` environment variable and returns its value, or
    ``None`` when the variable is not set.
    """
    return os.getenv("HF_TOKEN")
|
|
|
|
|
|
|
|
@lru_cache(maxsize=1)
def load_processor() -> AutoImageProcessor:
    """Fetch the shared Curia image processor from the Hub.

    The result is memoised, so every head reuses a single processor instance.
    Remote code is trusted because the processor is defined in the repo itself.
    """
    return AutoImageProcessor.from_pretrained(
        HF_REPO_ID,
        trust_remote_code=True,
        token=resolve_token(),
    )
|
|
|
|
|
|
|
|
@lru_cache(maxsize=len(HEAD_OPTIONS))
def load_model(head: str) -> AutoModelForImageClassification:
    """Load one Curia classification head and put it in evaluation mode.

    Args:
        head: subfolder of ``HF_REPO_ID`` holding the head's weights/config
            (one of the ``HEAD_OPTIONS`` keys).

    Results are memoised per head; the cache is sized to hold every head.
    """
    classifier = AutoModelForImageClassification.from_pretrained(
        HF_REPO_ID,
        subfolder=head,
        trust_remote_code=True,
        token=resolve_token(),
    )
    # Inference only: disable dropout / training-time behaviour.
    classifier.eval()
    return classifier
|
|
|
|
|
|
|
|
@lru_cache(maxsize=len(DATASET_OPTIONS))
def load_curia_dataset(subset: str) -> Any:
    """Load the ``test`` split of one CuriaBench subset (memoised per subset).

    Args:
        subset: CuriaBench configuration name (a ``DATASET_OPTIONS`` key).
    """
    loaded = load_dataset(HF_DATASET_ID, subset, split="test", token=resolve_token())
    # Requesting a split should already yield a single split; unwrap a
    # DatasetDict defensively in case one comes back anyway.
    return loaded["test"] if isinstance(loaded, DatasetDict) else loaded
|
|
|
|
|
|
|
|
def to_numpy_image(image: Any) -> np.ndarray:
    """Convert dataset or user-provided imagery to a 2D float32 numpy array.

    Accepts numpy arrays, PIL images, or anything ``np.asarray`` understands
    (e.g. nested lists of pixel values). Three-dimensional inputs are reduced
    to a single channel:

    * channel-last RGB / RGBA: any alpha channel is dropped and the colour
      channels are averaged;
    * channel-last luminance(+alpha) ``(H, W, 1)`` / ``(H, W, 2)``: the first
      channel is kept;
    * a leading singleton axis ``(1, H, W)`` is removed.

    Raises:
        ValueError: if the result is not a single 2D ``(H, W)`` slice.
    """
    arr = np.asarray(image)

    if arr.ndim == 3:
        channels = arr.shape[-1]
        if channels in (3, 4):
            # RGB(A): ignore alpha and average the colour channels to grayscale.
            arr = arr[..., :3].mean(axis=-1)
        elif channels in (1, 2):
            # Grayscale or grayscale+alpha: keep the intensity channel only.
            arr = arr[..., 0]
        elif arr.shape[0] == 1:
            # Channel-first single slice.
            arr = arr[0]

    if arr.ndim != 2:
        raise ValueError("Expected a 2D image (H, W). Please provide a single axial/coronal/sagittal slice.")

    # astype always copies, so the caller's array is never aliased.
    return arr.astype(np.float32)
|
|
|
|
|
|
|
|
def to_display_image(image: np.ndarray) -> np.ndarray:
    """Normalise image for display purposes (uint8, 3-channel).

    Intensities are min-max scaled into 0..255. NaNs become 0, and +/-inf are
    clamped to the largest/smallest finite value, so a single infinite pixel
    cannot squash the rest of the image to black. Constant images become all
    zeros. 2D inputs are replicated into three identical channels; inputs
    that already have a channel axis keep their layout.

    The input is copied and never modified.
    """
    arr = np.array(image, copy=True)

    finite = np.isfinite(arr)
    if not finite.all():
        if finite.any():
            lo = float(arr[finite].min())
            hi = float(arr[finite].max())
        else:
            lo = hi = 0.0
        # Default nan_to_num maps inf to ~1.8e308, which would dominate the
        # min-max range below; clamp to the finite extremes instead.
        arr = np.nan_to_num(arr, nan=0.0, posinf=hi, neginf=lo)

    arr_min = float(arr.min())
    arr_max = float(arr.max())
    if arr_max - arr_min > 1e-6:
        arr = (arr - arr_min) / (arr_max - arr_min)
    else:
        arr = np.zeros_like(arr)

    arr = (arr * 255).clip(0, 255).astype(np.uint8)
    if arr.ndim == 2:
        arr = np.stack([arr, arr, arr], axis=-1)
    return arr
|
|
|
|
|
|
|
|
def coerce_mask_array(mask: Any) -> Optional[np.ndarray]:
    """Best-effort conversion of raw mask data to a numpy array.

    Returns ``None`` when *mask* is ``None``, when ``np.array`` cannot convert
    it (e.g. ragged nested lists), or when the result is empty.
    """
    if mask is None:
        return None

    try:
        arr = np.array(mask)
    except Exception:
        # Deliberately broad: an unreadable mask is skipped, never fatal.
        return None

    if arr.size == 0:
        return None
    return arr


def prepare_mask_tensor(mask: Any, height: int, width: int) -> Optional[torch.Tensor]:
    """Normalise raw mask data into a stack of boolean ``(N, height, width)`` layers.

    Accepted layouts, after squeezing singleton axes:

    * ``(height, width)``: one layer;
    * ``(..., height, width)``: layers along the leading axes (channel-first);
    * ``(height, width, C)``: layers along the last axis (channel-last);
    * a flat buffer of ``k * height * width`` values, read in row-major order.
      This also covers ``(1, W)`` / ``(H, 1)`` masks that squeezing flattened.

    Any other shape yields ``None`` rather than being reshaped into scrambled
    layers. This includes a 2D mask whose shape differs from the image. Pixels
    ``> 0`` are foreground, and layers without foreground are dropped.

    Returns:
        A bool tensor of shape ``(N, height, width)`` with ``N >= 1``, or
        ``None`` when the mask is missing, unusable, or entirely empty.
    """
    mask_array = coerce_mask_array(mask)
    if mask_array is None or height <= 0 or width <= 0:
        return None

    arr = np.squeeze(mask_array)
    if arr.ndim >= 2 and arr.shape[-2:] == (height, width):
        layers = arr.reshape(-1, height, width)
    elif arr.ndim == 3 and arr.shape[:2] == (height, width):
        layers = np.moveaxis(arr, -1, 0)
    elif arr.ndim == 1 and arr.size % (height * width) == 0:
        layers = arr.reshape(-1, height, width)
    else:
        return None

    foreground = [torch.from_numpy(layer > 0) for layer in layers]
    foreground = [layer_mask for layer_mask in foreground if layer_mask.any()]
    if not foreground:
        return None
    return torch.stack(foreground, dim=0)
|
|
|
|
|
|
|
|
def apply_contour_overlay(
    image: np.ndarray,
    mask: Any,
    thickness: int = 1,
    color: Tuple[int, int, int] = (255, 0, 0),
) -> np.ndarray:
    """Draw only the contours of segmentation masks instead of filled masks.

    Args:
        image: display image; presumably the uint8 RGB output of
            ``to_display_image`` (TODO confirm for other callers).
        mask: raw mask data in any layout ``prepare_mask_tensor`` accepts.
        thickness: contour line width in pixels.
        color: contour colour in *image*'s channel order (red for RGB).

    Returns:
        A copy of *image* with the outer contour of every non-empty mask layer
        drawn on it, or *image* itself (uncopied) when no usable mask is found.
    """
    height, width = image.shape[:2]
    mask_tensor = prepare_mask_tensor(mask, height, width)
    if mask_tensor is None:
        return image

    # OpenCV draws in place; work on a copy so the caller's array is untouched.
    output = image.copy()
    for layer in mask_tensor:
        layer_u8 = layer.numpy().astype(np.uint8)
        # findContours returns (contours, hierarchy) in OpenCV 2.4/4.x but
        # (image, contours, hierarchy) in 3.x; index -2 is the contours in all.
        contours = cv2.findContours(layer_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
        cv2.drawContours(output, contours, -1, color, thickness)

    return output
|
|
|
|
|
|
|
|
def render_image_with_mask_info(image: np.ndarray, mask: Any) -> Tuple[np.ndarray, Optional[str]]:
    """Build the display image, outlining mask contours when a mask is given.

    Returns:
        ``(display_image, message)``. ``message`` is ``None`` when no mask was
        supplied, ``""`` when the overlay step ran without raising (even if the
        mask had nothing drawable), and a short warning when drawing raised.
    """
    base = to_display_image(image)
    if mask is None:
        return base, None

    try:
        return apply_contour_overlay(base, mask), ""
    except Exception:
        # Best effort: a broken mask must never stop the image from showing.
        return base, "Mask provided but could not be visualised."
|
|
|
|
|
|
|
|
def dataset_class_metadata(dataset: Dataset) -> Tuple[List[int], Dict[int, str]]:
    """Work out the class ids and their display names for a dataset's ``target``.

    If the ``target`` feature carries ``names`` (e.g. a ClassLabel), those
    names define classes ``0..len(names)-1``. Otherwise the distinct values
    found in the ``target`` column are used, labelled by their own string
    form. Datasets without a ``target`` column yield ``([], {})``.
    """
    feature = dataset.features.get("target")
    if feature and hasattr(feature, "names"):
        label_names = list(feature.names)
        return list(range(len(label_names))), dict(enumerate(label_names))

    targets = dataset["target"] if "target" in dataset.column_names else []
    if not targets:
        return [], {}

    observed = sorted({int(value) for value in targets})
    return observed, {cls_id: str(cls_id) for cls_id in observed}
|
|
|
|
|
|
|
|
def pick_random_indices(dataset: Dataset, target: Optional[int]) -> int:
    """Return one random row index from *dataset*, optionally restricted to a class.

    When *target* is given and the dataset has a ``target`` column, the index
    is drawn uniformly from rows whose target equals *target*. If there is no
    such column, no filter, or no matching row, any row may be returned.

    Raises:
        ValueError: if the dataset has no rows.
    """
    total = len(dataset)
    if total == 0:
        raise ValueError("Dataset has no samples to draw from.")

    if target is not None and "target" in dataset.column_names:
        matching = [idx for idx, value in enumerate(dataset["target"]) if value == target]
        if matching:
            return random.choice(matching)

    # No usable filter: fall back to a uniform draw over the whole split.
    return random.randrange(total)
|
|
|
|
|
|
|
|
def format_probabilities(probs: torch.Tensor, id2label: Dict[int, str]) -> pd.DataFrame:
    """Return a dataframe sorted by probability desc.

    Args:
        probs: 1D tensor of class probabilities; position ``i`` is class ``i``.
        id2label: class-id -> name mapping. Integer keys are looked up first,
            then their string form (as in raw JSON configs); unknown ids are
            labelled with ``str(id)``.

    Returns:
        DataFrame with columns ``class_id``, ``label``, ``probability``, sorted
        by descending probability with a stable sort and a fresh ``0..n-1``
        index. Empty input yields an empty frame that still has those columns.
    """
    values = probs.detach().cpu().numpy()
    rows = [
        {
            "class_id": idx,
            "label": id2label.get(idx, id2label.get(str(idx), str(idx))),
            "probability": float(val),
        }
        for idx, val in enumerate(values)
    ]
    # Explicit columns keep sort_values working when there are no rows.
    df = pd.DataFrame(rows, columns=["class_id", "label", "probability"])
    return df.sort_values("probability", ascending=False, kind="stable").reset_index(drop=True)
|
|
|
|
|
|
|
|
def infer_image(
    image: np.ndarray,
    head: str,
) -> Tuple[str, pd.DataFrame]:
    """Run a single image through the selected Curia head.

    Args:
        image: image array as stored in the UI state, presumably the 2D float32
            output of ``to_numpy_image`` (TODO confirm for new callers).
        head: ``HEAD_OPTIONS`` key naming the head to load.

    Returns:
        ``(prediction, table)``. ``prediction`` is ``"<label> (p=<prob>)"`` for
        the top class, and ``table`` holds every class probability, highest first.

    Raises:
        ValueError: if the model produced no class scores.
    """
    processor = load_processor()
    model = load_model(head)
    with torch.no_grad():
        processed = processor(images=image, return_tensors="pt")
        outputs = model(**processed)

    # Batch of one: softmax over the class axis of the first (only) item.
    probs = torch.nn.functional.softmax(outputs["logits"][0], dim=-1)

    id2label = model.config.id2label or {}
    df = format_probabilities(probs, id2label)
    if df.empty:
        raise ValueError("Model returned no class scores.")

    top_row = df.iloc[0]
    prediction = f"{top_row['label']} (p={top_row['probability']:.3f})"
    return prediction, df
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def update_dataset_from_head(head: str) -> Dict[str, Any]:
    """Point the dataset dropdown at the first CuriaBench subset linked to *head*.

    Subsets are scanned in ``DATASET_OPTIONS`` order. If none is linked to
    *head*, an empty ``gr.update()`` leaves the dropdown untouched.
    """
    linked = next(
        (key for key, meta in DATASET_OPTIONS.items() if meta["head"] == head),
        None,
    )
    if linked is None:
        return gr.update()
    return gr.update(value=linked)
|
|
|
|
|
|
|
|
def load_dataset_metadata(subset: str) -> Tuple[Dict[str, Any], str]:
    """Refresh the class-filter dropdown after a CuriaBench subset is chosen.

    Returns:
        ``(dropdown_update, status_message)``. The dropdown always offers
        ``"Random"`` and, when class metadata exists, one ``"<id>: <label>"``
        entry per class. Load failures are reported in the status message
        instead of being raised.
    """
    try:
        dataset = load_curia_dataset(subset)
    except Exception as exc:
        return gr.update(choices=["Random"], value="Random"), f"Failed to load dataset: {exc}"

    classes, id2label = dataset_class_metadata(dataset)
    if not classes:
        return (
            gr.update(choices=["Random"], value="Random"),
            "No class metadata detected; sampling at random",
        )

    choices = ["Random"]
    choices.extend(f"{cls_id}: {id2label.get(cls_id, str(cls_id))}" for cls_id in classes)
    return gr.update(choices=choices, value="Random"), f"Loaded {subset} ({len(dataset)} test samples)"
|
|
|
|
|
|
|
|
def parse_target_selection(selection: str) -> Optional[int]:
    """Extract the class id from a class-filter dropdown entry.

    Entries look like ``"<id>: <label>"``. ``"Random"``, empty values, and
    anything whose text before the first ``:`` is not an integer yield
    ``None``, meaning "no class filter".
    """
    if not selection or selection == "Random":
        return None

    try:
        id_text = selection.partition(":")[0]
        return int(id_text.strip())
    except (ValueError, AttributeError):
        return None
|
|
|
|
|
|
|
|
def sample_dataset_example(
    subset: str,
    target_id: Optional[int],
) -> Tuple[np.ndarray, str, Dict[str, Any]]:
    """Draw one random example (optionally of class *target_id*) from a subset.

    Returns:
        ``(image, caption, meta)``: the record's image as a 2D float32 array,
        a ``"Sample #<index>"`` caption, and a dict with the row ``index``,
        its ``target`` (``None`` if absent), and its ``mask`` converted by
        ``coerce_mask_array`` (``None`` if absent or unusable).
    """
    dataset = load_curia_dataset(subset)
    row_index = pick_random_indices(dataset, target_id)
    row = dataset[row_index]

    pixels = to_numpy_image(row["image"])
    info: Dict[str, Any] = {
        "index": row_index,
        "target": row.get("target"),
        "mask": coerce_mask_array(row.get("mask")),
    }
    return pixels, f"Sample #{row_index}", info
|
|
|
|
|
|
|
|
def _ground_truth_update(head: str, target: Any) -> Dict[str, Any]:
    """Build the ground-truth Markdown update, naming *target* via *head*'s ``id2label``."""
    if target is None:
        return gr.update(visible=False)
    id2label = load_model(head).config.id2label or {}
    label_name = id2label.get(target, str(target))
    return gr.update(value=f"**Ground Truth:** {label_name} (class {target})", visible=True)


def load_dataset_sample(
    subset: str,
    target_selection: str,
    head: str,
) -> Tuple[
    Optional[np.ndarray],
    str,
    pd.DataFrame,
    Dict[str, Any],
    Optional[Dict[str, Any]],
]:
    """Load a random CuriaBench sample for display and later inference.

    Returns a 5-tuple matching the UI outputs:

    * the display image, with mask contours when available;
    * the status text, followed by the sample caption and target;
    * an empty predictions table, which clears stale results;
    * the ground-truth Markdown update, whose label comes from *head*'s config;
    * the state dict ``{"image", "mask"}`` read by ``run_inference``.

    Any failure is reported in the status text, with ``None`` image and state.
    """
    try:
        image, caption, meta = sample_dataset_example(subset, parse_target_selection(target_selection))
        mask = meta.get("mask")
        display, mask_msg = render_image_with_mask_info(image, mask)
        target = meta.get("target")

        details = caption if target is None else f"{caption} | target={target}"
        status = "Image loaded. Click 'Run inference' to compute predictions."
        if mask_msg:
            status = f"{status} {mask_msg}"

        return (
            display,
            f"{status}\n\n{details}",
            pd.DataFrame(),
            _ground_truth_update(head, target),
            {"image": image, "mask": mask},
        )
    except Exception as exc:
        return None, f"Failed to load sample: {exc}", pd.DataFrame(), gr.update(visible=False), None
|
|
|
|
|
|
|
|
def run_inference(
    sample_state: Optional[Dict[str, Any]],
    head: str,
) -> Tuple[str, pd.DataFrame]:
    """Classify the image held in the UI state with the selected head.

    Returns:
        ``(markdown_text, probability_table)``. When no image is loaded, or
        inference raises, the text explains why and the table is empty.
    """
    has_image = bool(sample_state) and "image" in sample_state
    if not has_image:
        return "Load a dataset sample or upload an image first.", pd.DataFrame()

    try:
        prediction, table = infer_image(sample_state["image"], head)
        return f"**Prediction:** {prediction}", table
    except Exception as exc:
        return f"Failed to run inference: {exc}", pd.DataFrame()
|
|
|
|
|
|
|
|
def handle_upload_preview(
    image: np.ndarray | Image.Image | None,
) -> Tuple[Optional[np.ndarray], str, pd.DataFrame, Dict[str, Any], Optional[Dict[str, Any]]]:
    """Prepare an uploaded image for display and stash it for inference.

    Returns the same 5-tuple shape as ``load_dataset_sample``: preview image,
    status text, an empty predictions table, a hidden ground-truth update
    (uploads carry no label), and the ``{"image", "mask": None}`` state.
    Missing or unreadable uploads produce a message with ``None`` image/state.
    """
    if image is None:
        return None, "Please upload an image.", pd.DataFrame(), gr.update(visible=False), None

    try:
        pixels = to_numpy_image(image)
        preview = to_display_image(pixels)
        hidden_truth = gr.update(visible=False)
        return (
            preview,
            "Image uploaded. Click 'Run inference' to compute predictions.",
            pd.DataFrame(),
            hidden_truth,
            {"image": pixels, "mask": None},
        )
    except Exception as exc:
        return None, f"Failed to load image: {exc}", pd.DataFrame(), gr.update(visible=False), None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_demo() -> gr.Blocks:
    """Assemble the Gradio Blocks UI and wire up its event handlers.

    Layout, top to bottom:

    1. Intro text.
    2. Model-head selector.
    3. A row containing the dataset-sample loader and the image uploader.
    4. The "Run inference" button.
    5. A row containing the image/ground-truth display and the predictions table.
    6. Closing notes.

    Component creation order inside the ``with`` contexts determines the
    on-screen layout, so statements here should not be reordered.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks(css=".gr-prose { max-width: 900px; }") as demo:
        # Introductory description shown at the top of the page.
        gr.Markdown(
            """


            # Curia Model Playground





            Experiment with the multi-head Curia models on CuriaBench evaluation data or


            your own medical images. Each head expects a single 2D slice in the


            corresponding plane/orientation as defined for Curia (PL for axial, IL for


            coronal, IP for sagittal). Ensure images are unwindowed and either raw


            Hounsfield units (CT) or normalised intensity values (MRI).


            """
        )

        # Choices are (display label, value) pairs, so handlers receive the
        # HEAD_OPTIONS key (e.g. "anatomy-ct") as ``head``.
        head_dropdown = gr.Dropdown(
            label="Model head",
            choices=[(label, key) for key, label in HEAD_OPTIONS],
            value="anatomy-ct",
        )

        gr.Markdown("---")

        with gr.Row():
            # Left column: pick a CuriaBench subset (and optionally a class) to sample from.
            with gr.Column():
                gr.Markdown("### Load dataset sample")
                dataset_dropdown = gr.Dropdown(
                    label="CuriaBench subset",
                    choices=[(meta["label"], key) for key, meta in DATASET_OPTIONS.items()],
                    value="anatomy-ct",
                )
                # Class metadata is only fetched by dataset_dropdown.change below;
                # nothing loads it at startup, so the filter begins as just "Random".
                dataset_status = gr.Markdown("Select a dataset to load class metadata.")
                class_dropdown = gr.Dropdown(label="Target class filter", choices=["Random"], value="Random")
                dataset_btn = gr.Button("Load dataset sample")

            # Right column: user upload, requested as a single-channel ("L") numpy image.
            with gr.Column():
                gr.Markdown("### Upload custom image")
                upload_component = gr.Image(label="Upload image", image_mode="L", type="numpy")

        gr.Markdown("---")

        infer_btn = gr.Button("Run inference", variant="primary")

        with gr.Row():
            # Left: the current image (with mask contours) and, for dataset samples, its label.
            with gr.Column():
                image_display = gr.Image(label="Image", interactive=False, type="numpy")
                ground_truth_display = gr.Markdown(visible=False)

            # Right: status/prediction text and the per-class probability table.
            with gr.Column():
                gr.Markdown("### Predictions")
                status_text = gr.Markdown()
                prediction_probs = gr.Dataframe(headers=["class_id", "label", "probability"])

        # Holds the most recently loaded/uploaded sample as {"image": ..., "mask": ...};
        # written by load_dataset_sample / handle_upload_preview, read by run_inference.
        image_state = gr.State()

        # Switching heads re-points the dataset dropdown at a linked subset. That
        # value change also fires dataset_dropdown.change, which reloads class metadata.
        head_dropdown.change(
            fn=update_dataset_from_head,
            inputs=[head_dropdown],
            outputs=[dataset_dropdown],
        )

        dataset_dropdown.change(
            fn=load_dataset_metadata,
            inputs=[dataset_dropdown],
            outputs=[class_dropdown, dataset_status],
        )

        # Both sample sources feed the same five outputs, in the order returned
        # by load_dataset_sample / handle_upload_preview.
        dataset_btn.click(
            fn=load_dataset_sample,
            inputs=[dataset_dropdown, class_dropdown, head_dropdown],
            outputs=[
                image_display,
                status_text,
                prediction_probs,
                ground_truth_display,
                image_state,
            ],
        )

        upload_component.upload(
            fn=handle_upload_preview,
            inputs=[upload_component],
            outputs=[
                image_display,
                status_text,
                prediction_probs,
                ground_truth_display,
                image_state,
            ],
        )

        # Inference runs only on demand, using whatever image_state currently holds.
        infer_btn.click(
            fn=run_inference,
            inputs=[image_state, head_dropdown],
            outputs=[status_text, prediction_probs],
        )

        gr.Markdown(
            """


            ### Notes





            - Configure the `HF_TOKEN` secret in your Space to load private checkpoints


              and datasets from the `raidium` organisation.


            - When masks are available in the dataset sample, their contours are drawn on the


              image for visual reference using OpenCV.


            - Uploaded images must be single-channel arrays. Multi-channel inputs are


              converted to grayscale automatically.


            """
        )

    return demo
|
|
|
|
|
|
|
|
# Build the UI at import time and expose it as the module-level ``demo``.
# Hosting environments that import this module presumably look for this name
# (e.g. Hugging Face Spaces) — confirm against the deployment setup.
demo = build_demo()


if __name__ == "__main__":
    # Running the file directly starts the Gradio server.
    demo.launch()
|
|
|