"""Gradio Space for exploring Curia models and CuriaBench datasets. This application allows users to: - Select any available Curia classification head. - Load the matching CuriaBench test split and sample random images per class. - Upload custom medical images that match the model's expected orientation. - Forward images through the selected model head and visualise class probabilities. The space expects an HF token with access to "raidium" resources to be provided via the HF_TOKEN environment variable (configure it as a secret when deploying to Hugging Face Spaces). """ from __future__ import annotations import os import random from functools import lru_cache from typing import Any, Dict, List, Optional, Tuple, Union import cv2 import gradio as gr import numpy as np import pandas as pd import torch from datasets import Dataset, DatasetDict, IterableDataset, load_dataset from PIL import Image from transformers import ( AutoImageProcessor, AutoModelForImageClassification, ) HF_REPO_ID = "raidium/curia" HF_DATASET_ID = "raidium/CuriaBench" # --------------------------------------------------------------------------- # Configuration # --------------------------------------------------------------------------- HEAD_OPTIONS: List[Tuple[str, str]] = [ ("anatomy-ct", "Anatomy CT"), ("anatomy-mri", "Anatomy MRI"), ("atlas-stroke", "Atlas Stroke"), ("covidx-ct", "COVIDx CT"), ("deep-lesion-site", "Deep Lesion Site"), ("emidec-classification-mask", "EMIDEC Classification"), ("ich", "Intracranial Hemorrhage"), ("ixi", "IXI"), ("kits", "KiTS"), ("kneeMRI", "Knee MRI"), ("luna16-3D", "LUNA16 3D"), ("neural_foraminal_narrowing", "Neural Foraminal Narrowing"), ("oasis", "OASIS"), ("spinal_canal_stenosis", "Spinal Canal Stenosis"), ("subarticular_stenosis", "Subarticular Stenosis"), ] DATASET_OPTIONS: Dict[str, Dict[str, Any]] = { "anatomy-ct": {"label": "Anatomy CT (test)", "head": "anatomy-ct"}, "anatomy-ct-hard": {"label": "Anatomy CT Hard (test)", "head": "anatomy-ct"}, "anatomy-mri": {"label": "Anatomy MRI (test)", "head": "anatomy-mri"}, "covidctset": {"label": "COVID CT Set (test)", "head": "covidx-ct"}, "covidx-ct": {"label": "COVIDx CT (test)", "head": "covidx-ct"}, "deep-lesion-site": {"label": "Deep Lesion Site (test)", "head": "deep-lesion-site"}, "emidec-classification-mask": { "label": "EMIDEC Classification Mask (test)", "head": "emidec-classification-mask", }, "ixi": {"label": "IXI (test)", "head": "ixi"}, "kits": {"label": "KiTS (test)", "head": "kits"}, "kneeMRI": {"label": "Knee MRI (test)", "head": "kneeMRI"}, "luna16": {"label": "LUNA16 (test)", "head": "luna16-3D"}, "luna16-3D": {"label": "LUNA16 3D (test)", "head": "luna16-3D"}, "oasis": {"label": "OASIS (test)", "head": "oasis"}, } # --------------------------------------------------------------------------- # Utility helpers # --------------------------------------------------------------------------- def resolve_token() -> Optional[str]: """Return the Hugging Face token if configured.""" return os.environ.get("HF_TOKEN") @lru_cache(maxsize=1) def load_processor() -> AutoImageProcessor: token = resolve_token() return AutoImageProcessor.from_pretrained(HF_REPO_ID, trust_remote_code=True, token=token) @lru_cache(maxsize=len(HEAD_OPTIONS)) def load_model(head: str) -> AutoModelForImageClassification: token = resolve_token() model = AutoModelForImageClassification.from_pretrained( HF_REPO_ID, trust_remote_code=True, subfolder=head, token=token, ) model.eval() return model @lru_cache(maxsize=len(DATASET_OPTIONS)) def load_curia_dataset(subset: str) -> Any: token = resolve_token() ds = load_dataset( HF_DATASET_ID, subset, split="test", token=token, ) if isinstance(ds, DatasetDict): return ds["test"] return ds def to_numpy_image(image: Any) -> np.ndarray: """Convert dataset or user-provided imagery to a float32 numpy array.""" if isinstance(image, np.ndarray): arr = image elif isinstance(image, Image.Image): arr = np.array(image) else: # Some datasets provide nested dicts or lists – attempt to coerce. arr = np.array(image) if arr.ndim == 3 and arr.shape[-1] == 3: # Convert RGB to grayscale by averaging channels arr = arr.mean(axis=-1) if arr.ndim != 2: raise ValueError("Expected a 2D image (H, W). Please provide a single axial/coronal/sagittal slice.") return arr.astype(np.float32) def to_display_image(image: np.ndarray) -> np.ndarray: """Normalise image for display purposes (uint8, 3-channel).""" arr = np.array(image, copy=True) if not np.isfinite(arr).all(): arr = np.nan_to_num(arr, nan=0.0) arr_min = float(arr.min()) arr_max = float(arr.max()) if arr_max - arr_min > 1e-6: arr = (arr - arr_min) / (arr_max - arr_min) else: arr = np.zeros_like(arr) arr = (arr * 255).clip(0, 255).astype(np.uint8) if arr.ndim == 2: arr = np.stack([arr, arr, arr], axis=-1) return arr def coerce_mask_array(mask: Any) -> Optional[np.ndarray]: if mask is None: return None try: arr = np.array(mask) except Exception: return None if arr.size == 0: return None return arr def prepare_mask_tensor(mask: Any, height: int, width: int) -> Optional[torch.Tensor]: mask_array = coerce_mask_array(mask) if mask_array is None: return None arr = np.squeeze(mask_array) if arr.ndim == 2: arr = arr.reshape(1, height, width) else: if arr.shape[-2:] == (height, width): arr = arr.reshape(-1, height, width) elif arr.shape[0] == height and arr.shape[1] == width: arr = np.transpose(arr, (2, 0, 1)) elif arr.shape[1] == height and arr.shape[2] == width: arr = arr.reshape(arr.shape[0], height, width) elif arr.size % (height * width) == 0: try: arr = arr.reshape(-1, height, width) except ValueError: return None else: return None mask_tensors: List[torch.Tensor] = [] for idx, slice_arr in enumerate(arr): bool_mask = torch.from_numpy(slice_arr > 0) if bool_mask.any(): mask_tensors.append(bool_mask) if not mask_tensors: return None stacked = torch.stack(mask_tensors, dim=0).bool() return stacked def apply_contour_overlay( image: np.ndarray, mask: Any, thickness: int = 1, color: Tuple[int, int, int] = (255, 0, 0), ) -> np.ndarray: """Draw only the contours of segmentation masks instead of filled masks.""" height, width = image.shape[:2] mask_tensor = prepare_mask_tensor(mask, height, width) if mask_tensor is None: return image # Work on a copy of the image output = image.copy() # Process each mask separately for idx in range(mask_tensor.shape[0]): mask_np = mask_tensor[idx].numpy().astype(np.uint8) # Find contours contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) # Draw contours on the image cv2.drawContours(output, contours, -1, color, thickness) return output def render_image_with_mask_info(image: np.ndarray, mask: Any) -> Tuple[np.ndarray, Optional[str]]: display = to_display_image(image) if mask is None: return display, None try: overlaid = apply_contour_overlay(display, mask) return overlaid, "" except Exception: return display, "Mask provided but could not be visualised." def dataset_class_metadata(dataset: Dataset) -> Tuple[List[int], Dict[int, str]]: target_feature = dataset.features.get("target") if target_feature and hasattr(target_feature, "names"): names = list(target_feature.names) id2label = {i: name for i, name in enumerate(names)} classes = list(range(len(names))) return classes, id2label # Fall back to generic inspection targets = dataset["target"] if "target" in dataset.column_names else [] unique = sorted({int(t) for t in targets}) if targets else [] id2label = {i: str(i) for i in unique} return unique, id2label def pick_random_indices(dataset: Dataset, target: Optional[int]) -> int: if "target" not in dataset.column_names: return random.randrange(len(dataset)) if target is None: return random.randrange(len(dataset)) indices = [idx for idx, value in enumerate(dataset["target"]) if value == target] if not indices: return random.randrange(len(dataset)) return random.choice(indices) def format_probabilities(probs: torch.Tensor, id2label: Dict[int, str]) -> pd.DataFrame: """Return a dataframe sorted by probability desc.""" values = probs.detach().cpu().numpy() rows = [ {"class_id": idx, "label": id2label.get(idx, str(idx)), "probability": float(val)} for idx, val in enumerate(values) ] df = pd.DataFrame(rows) df.sort_values("probability", ascending=False, inplace=True) return df def infer_image( image: np.ndarray, head: str, ) -> Tuple[str, pd.DataFrame]: processor = load_processor() model = load_model(head) with torch.no_grad(): processed = processor(images=image, return_tensors="pt") outputs = model(**processed) print(outputs) logits = outputs["logits"] probs = torch.nn.functional.softmax(logits[0], dim=-1) id2label = model.config.id2label or {} df = format_probabilities(probs, id2label) top_row = df.iloc[0] prediction = f"{top_row['label']} (p={top_row['probability']:.3f})" return prediction, df # --------------------------------------------------------------------------- # Gradio callbacks # --------------------------------------------------------------------------- def update_dataset_from_head(head: str) -> Dict[str, Any]: # Find the first dataset that matches this head for dataset_key, meta in DATASET_OPTIONS.items(): if meta["head"] == head: return gr.update(value=dataset_key) return gr.update() def load_dataset_metadata(subset: str) -> Tuple[Dict[str, Any], str]: try: dataset = load_curia_dataset(subset) except Exception as exc: # pragma: no cover - surfaced in UI dropdown = gr.update(choices=["Random"], value="Random") return dropdown, f"Failed to load dataset: {exc}" classes, id2label = dataset_class_metadata(dataset) if not classes: dropdown = gr.update( choices=["Random"], value="Random", ) return dropdown, "No class metadata detected; sampling at random" options = [ "Random", *[f"{cls_id}: {id2label.get(cls_id, str(cls_id))}" for cls_id in classes], ] dropdown = gr.update(choices=options, value="Random") return dropdown, f"Loaded {subset} ({len(dataset)} test samples)" def parse_target_selection(selection: str) -> Optional[int]: if not selection or selection == "Random": return None try: target_str = selection.split(":", 1)[0].strip() return int(target_str) except (ValueError, AttributeError): return None def sample_dataset_example( subset: str, target_id: Optional[int], ) -> Tuple[np.ndarray, str, Dict[str, Any]]: dataset = load_curia_dataset(subset) index = pick_random_indices(dataset, target_id) record = dataset[index] image = to_numpy_image(record["image"]) mask_array = coerce_mask_array(record.get("mask")) meta = { "index": index, "target": record.get("target"), "mask": mask_array, } return image, f"Sample #{index}", meta def load_dataset_sample( subset: str, target_selection: str, head: str, ) -> Tuple[ Optional[np.ndarray], str, pd.DataFrame, Dict[str, Any], Optional[Dict[str, Any]], ]: try: target_id = parse_target_selection(target_selection) image, caption, meta = sample_dataset_example(subset, target_id) display, mask_msg = render_image_with_mask_info(image, meta.get("mask")) target = meta.get("target") meta_text = caption if target is not None: meta_text += f" | target={target}" status = "Image loaded. Click 'Run inference' to compute predictions." if mask_msg: status += f" {mask_msg}" meta_text = status + "\n\n" + meta_text # Generate ground truth display ground_truth_update = gr.update(visible=False) if target is not None: model = load_model(head) id2label = model.config.id2label or {} label_name = id2label.get(target, str(target)) ground_truth_update = gr.update(value=f"**Ground Truth:** {label_name} (class {target})", visible=True) return ( display, meta_text, pd.DataFrame(), ground_truth_update, {"image": image, "mask": meta.get("mask")}, ) except Exception as exc: # pragma: no cover - surfaced in UI return None, f"Failed to load sample: {exc}", pd.DataFrame(), gr.update(visible=False), None def run_inference( sample_state: Optional[Dict[str, Any]], head: str, ) -> Tuple[str, pd.DataFrame]: if not sample_state or "image" not in sample_state: return "Load a dataset sample or upload an image first.", pd.DataFrame() try: image = sample_state["image"] prediction, df = infer_image(image, head) result_text = f"**Prediction:** {prediction}" return result_text, df except Exception as exc: # pragma: no cover - surfaced in UI return f"Failed to run inference: {exc}", pd.DataFrame() def handle_upload_preview( image: np.ndarray | Image.Image | None, ) -> Tuple[Optional[np.ndarray], str, pd.DataFrame, Dict[str, Any], Optional[Dict[str, Any]]]: if image is None: return None, "Please upload an image.", pd.DataFrame(), gr.update(visible=False), None try: np_image = to_numpy_image(image) display = to_display_image(np_image) return ( display, "Image uploaded. Click 'Run inference' to compute predictions.", pd.DataFrame(), gr.update(visible=False), {"image": np_image, "mask": None}, ) except Exception as exc: # pragma: no cover - surfaced in UI return None, f"Failed to load image: {exc}", pd.DataFrame(), gr.update(visible=False), None # --------------------------------------------------------------------------- # Interface definition # --------------------------------------------------------------------------- def build_demo() -> gr.Blocks: with gr.Blocks(css=".gr-prose { max-width: 900px; }") as demo: gr.Markdown( """ # Curia Model Playground Experiment with the multi-head Curia models on CuriaBench evaluation data or your own medical images. Each head expects a single 2D slice in the corresponding plane/orientation as defined for Curia (PL for axial, IL for coronal, IP for sagittal). Ensure images are unwindowed and either raw Hounsfield units (CT) or normalised intensity values (MRI). """ ) head_dropdown = gr.Dropdown( label="Model head", choices=[(label, key) for key, label in HEAD_OPTIONS], value="anatomy-ct", ) gr.Markdown("---") with gr.Row(): with gr.Column(): gr.Markdown("### Load dataset sample") dataset_dropdown = gr.Dropdown( label="CuriaBench subset", choices=[(meta["label"], key) for key, meta in DATASET_OPTIONS.items()], value="anatomy-ct", ) dataset_status = gr.Markdown("Select a dataset to load class metadata.") class_dropdown = gr.Dropdown(label="Target class filter", choices=["Random"], value="Random") dataset_btn = gr.Button("Load dataset sample") with gr.Column(): gr.Markdown("### Upload custom image") upload_component = gr.Image(label="Upload image", image_mode="L", type="numpy") gr.Markdown("---") infer_btn = gr.Button("Run inference", variant="primary") with gr.Row(): with gr.Column(): image_display = gr.Image(label="Image", interactive=False, type="numpy") ground_truth_display = gr.Markdown(visible=False) with gr.Column(): gr.Markdown("### Predictions") status_text = gr.Markdown() prediction_probs = gr.Dataframe(headers=["class_id", "label", "probability"]) image_state = gr.State() # Event wiring head_dropdown.change( fn=update_dataset_from_head, inputs=[head_dropdown], outputs=[dataset_dropdown], ) dataset_dropdown.change( fn=load_dataset_metadata, inputs=[dataset_dropdown], outputs=[class_dropdown, dataset_status], ) dataset_btn.click( fn=load_dataset_sample, inputs=[dataset_dropdown, class_dropdown, head_dropdown], outputs=[ image_display, status_text, prediction_probs, ground_truth_display, image_state, ], ) upload_component.upload( fn=handle_upload_preview, inputs=[upload_component], outputs=[ image_display, status_text, prediction_probs, ground_truth_display, image_state, ], ) infer_btn.click( fn=run_inference, inputs=[image_state, head_dropdown], outputs=[status_text, prediction_probs], ) gr.Markdown( """ ### Notes - Configure the `HF_TOKEN` secret in your Space to load private checkpoints and datasets from the `raidium` organisation. - When masks are available in the dataset sample, their contours are drawn on the image for visual reference using OpenCV. - Uploaded images must be single-channel arrays. Multi-channel inputs are converted to grayscale automatically. """ ) return demo demo = build_demo() if __name__ == "__main__": demo.launch()