|
|
"""Gradio Space for exploring Curia models and CuriaBench datasets. |
|
|
|
|
|
This application allows users to: |
|
|
|
|
|
- Select any available Curia classification head. |
|
|
- Load the matching CuriaBench test split and sample random images per class. |
|
|
- Upload custom medical images that match the model's expected orientation. |
|
|
- Forward images through the selected model head and visualise class probabilities. |
|
|
|
|
|
The space expects an HF token with access to "raidium" resources to be |
|
|
provided via the HF_TOKEN environment variable (configure it as a secret when |
|
|
deploying to Hugging Face Spaces). |
|
|
""" |
|
|
|
|
|
from __future__ import annotations

import base64
import random
import traceback
from typing import Any, Dict, List, Optional, Tuple

import cv2
import gradio as gr
import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from PIL import Image

from inference import (
    load_curia_dataset,
    load_id_to_labels,
    infer_image,
)
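# How the `inference` helpers are used in this module (hedged summary based on the
# call sites below, not their full API): `load_curia_dataset(subset)` returns a
# `datasets.Dataset` test split with "image", "target" and optionally "mask"
# columns; `load_id_to_labels()` returns a mapping from head name to
# {class_id: label}; `infer_image(image, head, mask, return_probs=...)` returns
# class probabilities, or a scalar for regression heads such as "ixi".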
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
HEAD_OPTIONS: List[Tuple[str, str]] = [
    ("abdominal-trauma", "Active Extravasation"),
    ("anatomy-ct", "Anatomy CT"),
    ("anatomy-mri", "Anatomy MRI"),
    ("atlas-stroke", "Atlas Stroke"),
    ("covidx-ct", "COVIDx CT"),
    ("deep-lesion-site", "Deep Lesion Site"),
    ("emidec-classification-mask", "EMIDEC Classification"),
    ("ich", "Intracranial Hemorrhage"),
    ("ixi", "IXI"),
    ("kits", "KiTS"),
    ("kneeMRI", "Knee MRI"),
    ("luna16-3D", "LUNA16 3D"),
    ("oasis", "OASIS"),
]
|
|
HEADS_REQUIRING_MASK: set[str] = {
    "anatomy-ct",
    "anatomy-mri",
    "deep-lesion-site",
    "emidec-classification-mask",
    "kits",
    "kneeMRI",
    "luna16-3D",
    "neural_foraminal_narrowing",
    "spinal_canal_stenosis",
    "subarticular_stenosis",
}

HEADS_3D = {
    "oasis",
    "luna16-3D",
    "kneeMRI",
}

REGRESSION_HEADS = {
    "ixi",
}
|
|
DATASET_OPTIONS: Dict[str, str] = { |
|
|
"anatomy-ct": "Anatomy CT (test)", |
|
|
"anatomy-ct-hard": "Anatomy CT Hard (test)", |
|
|
"anatomy-mri": "Anatomy MRI (test)", |
|
|
"covidx-ct": "COVIDx CT (test)", |
|
|
"deep-lesion-site": "Deep Lesion Site (test)", |
|
|
"emidec-classification-mask": "EMIDEC Classification Mask (test)", |
|
|
"ixi": "IXI (test)", |
|
|
"kits": "KiTS (test)", |
|
|
"kneeMRI": "Knee MRI (test)", |
|
|
"luna16-3D": "LUNA16 3D (test)", |
|
|
"oasis": "OASIS (test)", |
|
|
} |
|
|
|
|
|
DEFAULT_DATASET_FOR_HEAD: Dict[str, str] = { |
|
|
"anatomy-ct": "anatomy-ct", |
|
|
"anatomy-mri": "anatomy-mri", |
|
|
"covidx-ct": "covidx-ct", |
|
|
"deep-lesion-site": "deep-lesion-site", |
|
|
"emidec-classification-mask": "emidec-classification-mask", |
|
|
"ixi": "ixi", |
|
|
"kits": "kits", |
|
|
"kneeMRI": "kneeMRI", |
|
|
"luna16-3D": "luna16-3D", |
|
|
"oasis": "oasis", |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
DEFAULT_WINDOWINGS: Dict[str, Optional[Dict[str, int]]] = {
    "anatomy-ct": {"window_level": 40, "window_width": 400},
    "anatomy-ct-hard": {"window_level": 40, "window_width": 400},
    "anatomy-mri": None,
    "atlas-stroke": None,
    "covidx-ct": {"window_level": -600, "window_width": 1500},
    "deep-lesion-site": {"window_level": 40, "window_width": 400},
    "emidec-classification-mask": None,
    "ich": {"window_level": 40, "window_width": 80},
    "ixi": None,
    "kits": {"window_level": 40, "window_width": 400},
    "kneeMRI": None,
    "luna16": {"window_level": -600, "window_width": 1500},
    "luna16-3D": {"window_level": -600, "window_width": 1500},
    "oasis": None,
}
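# As consumed by apply_windowing() below: "window_level" is the centre of the kept
# intensity range (in Hounsfield units) and "window_width" its total width;
# entries set to None (MRI-like heads) skip windowing entirely.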
|
|
|
|
|
LOGO_PATH = "Logo horizontal medium copie 4_CREME.png" |
|
|
|
|
|
CUSTOM_CSS = """ |
|
|
.gr-prose { max-width: 900px; } |
|
|
#app-hero { |
|
|
display: flex; |
|
|
align-items: center; |
|
|
gap: 2.5rem; |
|
|
margin-bottom: 1.5rem; |
|
|
padding-right: 1.5rem; |
|
|
} |
|
|
#app-hero .hero-text { |
|
|
flex: 1; |
|
|
padding-right: 1rem; |
|
|
} |
|
|
#app-hero .hero-text h1 { |
|
|
font-size: 2.25rem; |
|
|
margin-bottom: 0.5rem; |
|
|
} |
|
|
#app-hero .hero-text p { |
|
|
margin: 0.25rem 0; |
|
|
line-height: 1.5; |
|
|
} |
|
|
#app-hero .hero-logo img { |
|
|
max-height: 60px; |
|
|
width: auto; |
|
|
display: block; |
|
|
} |
|
|
@media (max-width: 768px) { |
|
|
#app-hero { |
|
|
flex-direction: column; |
|
|
text-align: center; |
|
|
padding-right: 0; |
|
|
} |
|
|
#app-hero .hero-text { |
|
|
padding-right: 0; |
|
|
} |
|
|
#app-hero .hero-text h1, |
|
|
#app-hero .hero-text p { |
|
|
text-align: center; |
|
|
} |
|
|
#app-hero .hero-logo img { |
|
|
margin: 0 auto 1rem; |
|
|
} |
|
|
} |
|
|
""" |
|
|
|
|
|
|
|
|
def load_logo_data_uri() -> str:
    """Return the logo as a base64 data URI, or an empty string if the file is missing."""
    try:
        with open(LOGO_PATH, "rb") as logo_file:
            encoded = base64.b64encode(logo_file.read()).decode("ascii")
        return f"data:image/png;base64,{encoded}"
    except FileNotFoundError:
        return ""


LOGO_DATA_URI = load_logo_data_uri()
|
|
def apply_windowing(image: np.ndarray, head: str) -> np.ndarray:
    """Apply CT windowing based on the dataset.

    For CT images, applies the window level/width transformation.
    For MRI images (windowing=None), returns the image unchanged.

    Args:
        image: Raw image array (e.g., in Hounsfield units for CT).
        head: Head/dataset name used to look up the windowing parameters.

    Returns:
        Windowed image array.
    """
    windowing = DEFAULT_WINDOWINGS.get(head)

    if windowing is None:
        return image

    window_level = windowing["window_level"]
    window_width = windowing["window_width"]

    window_min = window_level - window_width / 2
    window_max = window_level + window_width / 2

    windowed = np.clip(image, window_min, window_max)
    windowed = (windowed - window_min) / (window_max - window_min)

    return windowed.astype(np.float32)
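# Worked example (illustrative): with the abdominal window used for "kits"
# (level 40, width 400), intensities are clipped to [-160, 240] HU and rescaled,
# so apply_windowing(np.array([[-1000.0, 40.0, 300.0]]), "kits") yields roughly
# [[0.0, 0.5, 1.0]].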
|
|
|
|
|
|
|
|
def to_display_image(image: np.ndarray) -> np.ndarray:
    """Normalise image for display purposes (uint8, 3-channel)."""

    if image.ndim == 3:
        gr.Info("Image is 3D; only the middle slice is displayed.")
        image = image[:, :, image.shape[2] // 2]

    arr = np.array(image, copy=True)
    if not np.isfinite(arr).all():
        arr = np.nan_to_num(arr, nan=0.0)

    arr_min = float(arr.min())
    arr_max = float(arr.max())
    if arr_max - arr_min > 1e-6:
        arr = (arr - arr_min) / (arr_max - arr_min)
    else:
        arr = np.zeros_like(arr)

    arr = (arr * 255).clip(0, 255).astype(np.uint8)
    if arr.ndim == 2:
        arr = np.stack([arr, arr, arr], axis=-1)
    return arr
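# Note: intensities are min-max scaled per image, so a constant image renders as
# black, and 3D volumes are previewed by their middle slice along the last axis.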
|
|
|
|
|
|
|
|
def prepare_mask_tensor(mask: np.ndarray, height: int, width: int) -> Optional[torch.Tensor]:
    """Coerce a mask array into a boolean (num_slices, height, width) tensor.

    Returns None when the mask cannot be reshaped to the image size or when no
    slice contains any positive pixel.
    """
    arr = np.squeeze(mask)
    if arr.ndim == 2:
        arr = arr.reshape(1, height, width)
    else:
        if arr.shape[-2:] == (height, width):
            arr = arr.reshape(-1, height, width)
        elif arr.shape[0] == height and arr.shape[1] == width:
            arr = np.transpose(arr, (2, 0, 1))
        elif arr.shape[1] == height and arr.shape[2] == width:
            arr = arr.reshape(arr.shape[0], height, width)
        elif arr.size % (height * width) == 0:
            try:
                arr = arr.reshape(-1, height, width)
            except ValueError:
                return None
        else:
            return None

    mask_tensors: List[torch.Tensor] = []
    for slice_arr in arr:
        bool_mask = torch.from_numpy(slice_arr > 0)
        if bool_mask.any():
            mask_tensors.append(bool_mask)

    if not mask_tensors:
        return None

    stacked = torch.stack(mask_tensors, dim=0).bool()
    return stacked
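# Accepted mask layouts (as handled above): (H, W), (C, H, W), (H, W, C), or any
# array whose size is a multiple of H*W; slices with no positive pixels are dropped.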
|
|
|
|
|
|
|
|
def apply_contour_overlay(
    image: np.ndarray,
    mask: Any,
    thickness: int = 1,
    color: Tuple[int, int, int] = (255, 0, 0),
) -> np.ndarray:
    """Draw only the contours of segmentation masks instead of filled masks."""
    height, width = image.shape[:2]
    mask_tensor = prepare_mask_tensor(mask, height, width)
    if mask_tensor is None:
        return image

    output = image.copy()

    for idx in range(mask_tensor.shape[0]):
        mask_np = mask_tensor[idx].numpy().astype(np.uint8)
        contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(output, contours, -1, color, thickness)

    return output


def render_image_with_mask_info(image: np.ndarray, mask: Any) -> np.ndarray:
    """Render an image for display, overlaying mask contours when a mask is given."""
    display = to_display_image(image)
    if mask is None:
        return display

    try:
        overlaid = apply_contour_overlay(display, mask)
        return overlaid
    except Exception:
        gr.Warning("Mask provided but could not be visualised.")
        return display


def pick_random_indices(dataset: Dataset, target: Optional[int]) -> int:
    """Pick a single random index, restricted to `target` when that class has samples."""
    if "target" not in dataset.column_names:
        return random.randrange(len(dataset))

    if target is None:
        return random.randrange(len(dataset))

    indices = [idx for idx, value in enumerate(dataset["target"]) if value == target]
    if not indices:
        return random.randrange(len(dataset))
    return random.choice(indices)
|
|
def update_dataset_display(head: str) -> str:
    """Update the dataset name display based on the selected head."""
    dataset_key = DEFAULT_DATASET_FOR_HEAD.get(head)
    if dataset_key:
        dataset_label = DATASET_OPTIONS.get(dataset_key, dataset_key)
        return f"**Dataset:** {dataset_label}"
    return "**Dataset:** not available"


def update_upload_component_state(head: str) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Disable upload component for heads that require masks."""
    if head in HEADS_REQUIRING_MASK:
        info_update = gr.update(
            value="⚠️ Custom image upload is disabled for this task because it requires a mask from the dataset.",
            visible=True,
        )
        upload_update = gr.update(interactive=False)
        return info_update, upload_update
    elif head in HEADS_3D:
        info_update = gr.update(
            value="⚠️ Custom image upload is disabled for this task because it requires a 3D image.",
            visible=True,
        )
        upload_update = gr.update(interactive=False)
        return info_update, upload_update

    info_update = gr.update(visible=False)
    upload_update = gr.update(interactive=True)
    return info_update, upload_update
|
|
def load_dataset_metadata(head: str) -> Tuple[Dict[str, Any], str, Dict[str, Any]]:
    """Load dataset metadata based on the selected head."""
    subset = DEFAULT_DATASET_FOR_HEAD.get(head)
    if not subset:
        dropdown = gr.update(choices=["Random"], value="Random", interactive=False)
        button = gr.update(interactive=False)
        return dropdown, "No dataset found for this head.", button

    id2label = load_id_to_labels().get(head, {})

    try:
        dataset = load_curia_dataset(subset)
    except Exception as exc:
        dropdown = gr.update(choices=["Random"], value="Random", interactive=False)
        button = gr.update(interactive=False)
        return dropdown, f"Failed to load dataset: {exc}", button

    classes = sorted(id2label.keys())
    options = [
        "Random",
        *[f"{cls_id}: {id2label[cls_id]}" for cls_id in classes],
    ]
    dropdown = gr.update(choices=options, value="Random", interactive=True)
    button = gr.update(interactive=True)
    return dropdown, f"Loaded {subset} ({len(dataset)} test samples)", button
|
|
def parse_target_selection(selection: str) -> Optional[int]:
    """Extract the class id from a dropdown entry, or None for "Random" or invalid input."""
    if not selection or selection == "Random":
        return None

    try:
        target_str = selection.split(":", 1)[0].strip()
        return int(target_str)
    except (ValueError, AttributeError):
        return None
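# Example: dropdown entries built in load_dataset_metadata have the form
# "<class_id>: <label>", so parse_target_selection("3: Some label") -> 3 and
# parse_target_selection("Random") -> None (the label text here is illustrative).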
|
|
|
|
|
|
|
|
def sample_dataset_example(
    subset: str,
    target_id: Optional[int],
) -> Tuple[np.ndarray, Dict[str, Any]]:
    """Sample one record from the subset, optionally restricted to a target class."""
    dataset = load_curia_dataset(subset)
    index = pick_random_indices(dataset, target_id)
    record = dataset[index]
    image = np.array(record["image"]).astype(np.float32)
    mask_array = record.get("mask")

    meta = {
        "index": index,
        "target": record.get("target"),
        "mask": mask_array,
    }

    return image, meta
|
|
def load_dataset_sample(
    target_selection: str,
    head: str,
) -> Tuple[
    Optional[np.ndarray],
    str,
    Dict[str, Any],
    Dict[str, Any],
    Optional[Dict[str, Any]],
]:
    """Load a dataset sample based on the selected head."""
    subset = DEFAULT_DATASET_FOR_HEAD.get(head)
    if not subset:
        gr.Warning("No dataset found for this head.")
        return None, "", gr.update(visible=False), gr.update(visible=False), None

    try:
        target_id = parse_target_selection(target_selection)
        image, meta = sample_dataset_example(subset, target_id)

        windowed_image = apply_windowing(image, subset)
        display = to_display_image(windowed_image)
        if meta.get("mask") is not None:
            display = apply_contour_overlay(display, meta.get("mask"))

        target = meta.get("target")

        ground_truth_update = gr.update(value="")
        if target is not None:
            id2label = load_id_to_labels().get(head, {})
            label_name = id2label.get(target, str(target))
            ground_truth_update = gr.update(value=f"{label_name} (class {target})", visible=True)

        return (
            display,
            "",
            gr.update(visible=False),
            ground_truth_update,
            {"image": image, "mask": meta.get("mask")},
        )
    except Exception as exc:
        gr.Warning(f"Failed to load sample: {exc}")
        return None, "", gr.update(visible=False), gr.update(visible=False), None
|
|
def format_probabilities(probs: torch.Tensor, id2label: Dict[int, str]) -> pd.DataFrame:
    """Return a dataframe of per-class probabilities sorted in descending order."""
    values = probs.detach().cpu().numpy()
    rows = [
        {"class_id": idx, "label": id2label.get(idx, str(idx)), "probability": float(val)}
        for idx, val in enumerate(values)
    ]
    df = pd.DataFrame(rows)
    df.sort_values("probability", ascending=False, inplace=True)
    return df
|
|
def run_inference(
    image_state: Optional[Dict[str, Any]],
    head: str,
) -> Tuple[str, Dict[str, Any]]:
    """Run the selected head on the current image state and format the outputs."""
    if not image_state or "image" not in image_state:
        return "Load a dataset sample or upload an image first.", gr.update(visible=False)

    try:
        image = image_state["image"]
        output = infer_image(
            image, head, image_state.get("mask"), return_probs=head not in REGRESSION_HEADS
        )

        if head in REGRESSION_HEADS:
            return f"{output:.1f}", gr.update(visible=False)

        id2label = load_id_to_labels().get(head, {})

        df = format_probabilities(output, id2label)
        top_row = df.iloc[0]
        prediction = f"{top_row['label']} (p={top_row['probability']:.3f})"
        return prediction, gr.update(visible=True, value=df)
    except Exception as exc:
        traceback.print_exc()
        return f"Failed to run inference: {exc}", gr.update(visible=False)
|
|
def handle_upload_preview(
    image: np.ndarray | Image.Image | None,
    head: str,
) -> Tuple[Optional[np.ndarray], str, str, pd.DataFrame, Dict[str, Any], Optional[Dict[str, Any]]]:
    """Handle image upload preview, deriving dataset from head."""
    if image is None:
        return None, "Please upload an image.", "", pd.DataFrame(), gr.update(visible=False), None

    try:
        np_image = np.array(image).astype(np.float32)
        if np_image.ndim == 3:
            # Collapse RGB(A) uploads to a single grayscale channel.
            np_image = np_image.mean(axis=-1)

        display = to_display_image(np_image)

        return (
            display,
            "Image uploaded. Computing predictions...",
            "",
            pd.DataFrame(),
            gr.update(value=""),
            {"image": np_image, "mask": None},
        )
    except Exception as exc:
        return None, f"Failed to load image: {exc}", "", pd.DataFrame(), gr.update(value=""), None
|
|
def build_demo() -> gr.Blocks:
    """Assemble the Gradio Blocks UI and wire the interaction events."""
    with gr.Blocks(css=CUSTOM_CSS) as demo:
        logo_block = ""
        if LOGO_DATA_URI:
            logo_block = f'<div class="hero-logo"><img src="{LOGO_DATA_URI}" alt="Curia logo" /></div>'
        hero_html = f"""
        <div id="app-hero">
            {logo_block}
            <div class="hero-text">
                <h1>Curia Model Playground</h1>
                <p>Experiment with the multi-head Curia models on CuriaBench evaluation data or your own medical images.</p>
                <p>Each head expects a single 2D slice in the Curia-defined plane/orientation (PL axial, IL coronal, IP sagittal) with raw Hounsfield units (CT) or normalised MRI intensities.</p>
            </div>
        </div>
        """
        gr.HTML(hero_html)

        default_head = "kits"
        head_dropdown = gr.Dropdown(
            label="Model head",
            choices=[(label, key) for key, label in HEAD_OPTIONS],
            value=default_head,
        )

        with gr.Row():
            with gr.Column():
                dataset_display = gr.Markdown(
                    f"**Dataset:** {DATASET_OPTIONS.get(DEFAULT_DATASET_FOR_HEAD.get(default_head, ''), 'Unknown')}"
                )
                dataset_status = gr.Markdown("Select a model head to load class metadata.")
                class_dropdown = gr.Dropdown(label="Target class filter", choices=["Random"], value="Random")
                dataset_btn = gr.Button("Load dataset sample")

            with gr.Column():
                gr.Markdown("### Upload custom image")

                initial_requires_mask = default_head in HEADS_REQUIRING_MASK
                upload_info_text = gr.Markdown(
                    value=(
                        "⚠️ Custom image upload is disabled for this task because it requires a mask from the dataset."
                        if initial_requires_mask
                        else ""
                    ),
                    visible=initial_requires_mask,
                )
                upload_component = gr.Image(
                    label="Upload image",
                    image_mode="L",
                    type="numpy",
                    interactive=not initial_requires_mask,
                )

        gr.Markdown("---")

        status_text = gr.Markdown()
        with gr.Row():
            with gr.Column():
                image_display = gr.Image(label="Image", interactive=False, type="numpy")

            with gr.Column():
                ground_truth_display = gr.Textbox(label="Ground Truth", interactive=False)
                main_prediction = gr.Textbox(label="Prediction", value="", interactive=False)
                prediction_probs = gr.Dataframe(headers=["class_id", "label", "probability"], visible=False)

        image_state = gr.State()

        demo.load(
            fn=load_dataset_metadata,
            inputs=[head_dropdown],
            outputs=[class_dropdown, dataset_status, dataset_btn],
        )

        head_dropdown.change(
            fn=update_dataset_display,
            inputs=[head_dropdown],
            outputs=[dataset_display],
        ).then(
            fn=update_upload_component_state,
            inputs=[head_dropdown],
            outputs=[upload_info_text, upload_component],
        ).then(
            fn=load_dataset_metadata,
            inputs=[head_dropdown],
            outputs=[class_dropdown, dataset_status, dataset_btn],
        )

        dataset_btn.click(
            fn=load_dataset_sample,
            inputs=[class_dropdown, head_dropdown],
            outputs=[
                image_display,
                main_prediction,
                prediction_probs,
                ground_truth_display,
                image_state,
            ],
        ).then(
            fn=run_inference,
            inputs=[image_state, head_dropdown],
            outputs=[main_prediction, prediction_probs],
        )

        upload_component.upload(
            fn=handle_upload_preview,
            inputs=[upload_component, head_dropdown],
            outputs=[
                image_display,
                status_text,
                main_prediction,
                prediction_probs,
                ground_truth_display,
                image_state,
            ],
        ).then(
            fn=run_inference,
            inputs=[image_state, head_dropdown],
            outputs=[main_prediction, prediction_probs],
        )

    return demo
|
|
demo = build_demo()

if __name__ == "__main__":
    demo.launch()
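# Local run (assuming this module is the Space's entry point, e.g. app.py):
#   export HF_TOKEN=<token with access to raidium resources>
#   python app.py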
|
|
|