diff --git "a/nothin/1/2/3/4/diamond_col_detect_dino_v3.ipynb" "b/nothin/1/2/3/4/diamond_col_detect_dino_v3.ipynb" new file mode 100644--- /dev/null +++ "b/nothin/1/2/3/4/diamond_col_detect_dino_v3.ipynb" @@ -0,0 +1,14272 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "96M6VtpCOX3U" + }, + "source": [ + "# New imports" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "qnwNKg9GOagk" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: Retrying (Retry(total=4, connect=None, read=None, redirect=None, status=None)) after connection broken by 'ReadTimeoutError(\"HTTPSConnectionPool(host='developer.download.nvidia.com', port=443): Read timed out. (read timeout=15)\")': /compute/redist/torch/\n", + "WARNING: Retrying (Retry(total=3, connect=None, read=None, redirect=None, status=None)) after connection broken by 'ReadTimeoutError(\"HTTPSConnectionPool(host='developer.download.nvidia.com', port=443): Read timed out. (read timeout=15)\")': /compute/redist/torch/\n" + ] + } + ], + "source": [ + "!pip install -q huggingface_hub transformers timm torch torchvision" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "DScoV9nCOod0" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\Admin\\miniconda3\\envs\\my_env_diam\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "from huggingface_hub import login\n", + "login(token=\"\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Vdk00fOPO3bQ" + }, + "source": [ + "# Adding datasets" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 90 + }, + "id": "TIAWWRI-OqGs", + "outputId": "51f38e72-e890-4332-beb7-f29662d659d1" + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " Upload widget is only available when the cell has been executed in the\n", + " current browser session. 
Please rerun this cell to enable.\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Saving kaggle.json to kaggle.json\n" + ] + }, + { + "data": { + "text/plain": [ + "{'kaggle.json': b'{\"username\":\"devpatel45\",\"key\":\"fd26936b60b93e22688fb90afed7e729\"}'}" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# from google.colab import files\n", + "\n", + "# # Upload the kaggle.json file\n", + "# files.upload()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "mOaaLjJrPj5U" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The syntax of the command is incorrect.\n", + "'mv' is not recognized as an internal or external command,\n", + "operable program or batch file.\n", + "'chmod' is not recognized as an internal or external command,\n", + "operable program or batch file.\n" + ] + } + ], + "source": [ + "!mkdir -p ~/.kaggle\n", + "!mv kaggle.json ~/.kaggle/\n", + "!chmod 600 ~/.kaggle/kaggle.json # secure the file" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wLTvzXeCVN0T" + }, + "source": [] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "id": "NURUanPxPmlM" + }, + "outputs": [], + "source": [ + "!pip install -q kaggle" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "nYrY9U6JPvN0", + "outputId": "cfc8b43b-4c21-436a-9a19-a92bd13f0c70" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dataset URL: https://www.kaggle.com/datasets/aayushpurswani/diamond-images-dataset\n", + "License(s): MIT\n", + "Downloading diamond-images-dataset.zip to d:\\DEV PATEL\\2025\\Diamond_AI_col\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + " 0%| | 0.00/3.08G [00:00 Reading Data" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "id": "ZsP0hUSATxAy" + }, + "outputs": [], + "source": [ + "import os\n", + "import warnings\n", + "warnings.filterwarnings('ignore')\n", + "import pandas as pd\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "from tqdm.notebook import tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 226 + }, + "execution": { + "iopub.execute_input": "2025-09-10T07:52:46.068362Z", + "iopub.status.busy": "2025-09-10T07:52:46.067681Z", + "iopub.status.idle": "2025-09-10T07:52:46.266593Z", + "shell.execute_reply": "2025-09-10T07:52:46.265866Z", + "shell.execute_reply.started": "2025-09-10T07:52:46.068332Z" + }, + "id": "Evos55ETOORe", + "outputId": "7ebbaeca-5e94-4956-c25d-57cf5e2f63fc", + "trusted": true + }, + "outputs": [ + { + "data": { + "text/html": [ + "
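Editor note: the `!mkdir -p ~/.kaggle`, `mv`, and `chmod` shell cell above fails on Windows (the stderr shows "'mv' is not recognized as an internal or external command"). Below is a minimal, hedged pure-Python equivalent; it assumes `kaggle.json` has already been placed in the current working directory.

```python
# Cross-platform alternative to the mkdir/mv/chmod shell cell above.
# Assumption: kaggle.json is in the current working directory.
import os
import shutil
from pathlib import Path

src = Path("kaggle.json")
dest_dir = Path.home() / ".kaggle"
dest_dir.mkdir(parents=True, exist_ok=True)

if src.exists():
    shutil.move(str(src), dest_dir / "kaggle.json")

# chmod only matters on Linux/macOS; it is skipped on Windows.
if os.name == "posix":
    os.chmod(dest_dir / "kaggle.json", 0o600)
```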
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
| | path_to_img | stock_number | shape | carat | clarity | colour | cut | polish | symmetry | fluorescence | lab | length | width | depth | image_path |
| 0 | web_scraped/cushion/2106452.jpg | 2106452 | cushion | 0.55 | SI2 | N | VG | EX | VG | N | GIA | 4.56 | 4.44 | 3.09 | D:/DEV PATEL/2025/Diamond_AI_col/diamond_image... |
| 1 | web_scraped/cushion/2042329.jpg | 2042329 | cushion | 0.52 | SI2 | Y-Z | EX | EX | VG | F | GIA | 4.40 | 4.31 | 2.99 | D:/DEV PATEL/2025/Diamond_AI_col/diamond_image... |
| 2 | web_scraped/cushion/2055268.jpg | 2055268 | cushion | 0.50 | SI1 | L | VG | EX | VG | N | GIA | 4.87 | 4.19 | 2.89 | D:/DEV PATEL/2025/Diamond_AI_col/diamond_image... |
| 3 | web_scraped/cushion/2128779.jpg | 2128779 | cushion | 0.50 | VS2 | M | EX | EX | VG | F | GIA | 4.73 | 4.28 | 2.84 | D:/DEV PATEL/2025/Diamond_AI_col/diamond_image... |
| 4 | web_scraped/cushion/2103991.jpg | 2103991 | cushion | 0.51 | SI1 | M | EX | EX | VG | N | GIA | 4.47 | 4.44 | 3.05 | D:/DEV PATEL/2025/Diamond_AI_col/diamond_image... |
\n", + "
" + ], + "text/plain": [ + " path_to_img stock_number shape carat clarity \\\n", + "0 web_scraped/cushion/2106452.jpg 2106452 cushion 0.55 SI2 \n", + "1 web_scraped/cushion/2042329.jpg 2042329 cushion 0.52 SI2 \n", + "2 web_scraped/cushion/2055268.jpg 2055268 cushion 0.50 SI1 \n", + "3 web_scraped/cushion/2128779.jpg 2128779 cushion 0.50 VS2 \n", + "4 web_scraped/cushion/2103991.jpg 2103991 cushion 0.51 SI1 \n", + "\n", + " colour cut polish symmetry fluorescence lab length width depth \\\n", + "0 N VG EX VG N GIA 4.56 4.44 3.09 \n", + "1 Y-Z EX EX VG F GIA 4.40 4.31 2.99 \n", + "2 L VG EX VG N GIA 4.87 4.19 2.89 \n", + "3 M EX EX VG F GIA 4.73 4.28 2.84 \n", + "4 M EX EX VG N GIA 4.47 4.44 3.05 \n", + "\n", + " image_path \n", + "0 D:/DEV PATEL/2025/Diamond_AI_col/diamond_image... \n", + "1 D:/DEV PATEL/2025/Diamond_AI_col/diamond_image... \n", + "2 D:/DEV PATEL/2025/Diamond_AI_col/diamond_image... \n", + "3 D:/DEV PATEL/2025/Diamond_AI_col/diamond_image... \n", + "4 D:/DEV PATEL/2025/Diamond_AI_col/diamond_image... " + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data = pd.read_csv('D:/DEV PATEL/2025/Diamond_AI_col/diamond_images/web_scraped/diamond_data.csv')\n", + "data['image_path'] = 'D:/DEV PATEL/2025/Diamond_AI_col/diamond_images/' + data['path_to_img']\n", + "data.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "id": "dm5G4EjoYqRp" + }, + "outputs": [], + "source": [ + "data.to_csv(\"the_processedfile.csv\")" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 837 + }, + "execution": { + "iopub.execute_input": "2025-09-10T07:52:46.819414Z", + "iopub.status.busy": "2025-09-10T07:52:46.818607Z", + "iopub.status.idle": "2025-09-10T07:52:46.830656Z", + "shell.execute_reply": "2025-09-10T07:52:46.829871Z", + "shell.execute_reply.started": "2025-09-10T07:52:46.819386Z" + }, + "id": "F5ANrmOcOORe", + "outputId": "861cb25e-cf25-4982-dbf2-9a62406b3519", + "trusted": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "colour\n", + "G 7674\n", + "F 6957\n", + "H 6555\n", + "E 6239\n", + "D 5824\n", + "I 5304\n", + "J 4273\n", + "K 2628\n", + "L 1388\n", + "M 754\n", + "N 417\n", + "FANCY 328\n", + "O-P 134\n", + "Q-R 80\n", + "U-V 78\n", + "S-T 68\n", + "W-X 36\n", + "Y-Z 23\n", + "BLUE 1\n", + "V:B 1\n", + "FC:P 1\n", + "D:P:BN 1\n", + "I:P 1\n", + "Name: count, dtype: int64" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data[\"colour\"].value_counts()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2025-09-03T12:32:26.656212Z", + "iopub.status.busy": "2025-09-03T12:32:26.655983Z", + "iopub.status.idle": "2025-09-03T12:32:26.683042Z", + "shell.execute_reply": "2025-09-03T12:32:26.682400Z", + "shell.execute_reply.started": "2025-09-03T12:32:26.656192Z" + }, + "id": "5Y2xteUNOORf", + "trusted": true + }, + "outputs": [], + "source": [ + "# # Empty sample and dropping\n", + "# data[data['full_path_to_img'] == '/kaggle/input/diamond-images-dataset/web_scraped/emerald/220188-630.jpg']\n", + "# data.drop(8235, axis = 0, inplace = True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1AaGnuKkUF06" + }, + "source": [ + "# hugging face login" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "zs6Xo3FQUJT6" 
+ }, + "outputs": [], + "source": [ + "from huggingface_hub import login\n", + "login(token=\"\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IuUNTIEmXggx" + }, + "source": [ + "# architecture for predictions" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Usind dinov3 as backbone -> goes to two layer classification decoder -> results" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "LsRWxuT7YUh6", + "outputId": "535b975b-1386-4c5f-d64c-2aada4e622f8" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\Admin\\miniconda3\\envs\\my_env_diam\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Detected 23 classes. Label map provided: True\n", + "Removing rows with rare classes: ['BLUE', 'V:B', 'FC:P', 'D:P:BN', 'I:P']\n", + "Train: 41446 | Val: 7314\n" + ] + } + ], + "source": [ + "# Full pipeline: DINOv3 embeddings -> 2-layer classifier\n", + "# Run in Colab/Kaggle. Make sure to pip install timm + dependencies first:\n", + "# !pip install -q timm torch torchvision pandas scikit-learn\n", + "\n", + "import os\n", + "import random\n", + "from pathlib import Path\n", + "from tqdm import tqdm\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "from torch.utils.data import Dataset, DataLoader, TensorDataset\n", + "from torchvision import transforms\n", + "from PIL import Image, UnidentifiedImageError # Import UnidentifiedImageError\n", + "import timm\n", + "\n", + "# -----------------------\n", + "# Config / hyperparams\n", + "# -----------------------\n", + "CSV_PATH = r\"D:\\DEV PATEL\\2025\\Diamond_AI_col\\the_processedfile.csv\" # <-- your CSV with columns: image_path,target\n", + "IMAGE_ROOT = \"\" # optional prefix to image paths in CSV\n", + "# old\n", + "# BACKBONE_NAME = \"dinov3_vitl14\"\n", + "# new\n", + "BACKBONE_NAME = \"hf_hub:timm/vit_small_patch16_dinov3.lvd1689m\"\n", + "BATCH_SIZE = 32\n", + "NUM_WORKERS = 4\n", + "LR = 1e-3\n", + "WEIGHT_DECAY = 1e-5\n", + "EPOCHS = 12\n", + "HIDDEN_DIM = 512 # hidden units in first classifier layer\n", + "DROPOUT = 0.3\n", + "SEED = 42\n", + "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "PRECOMPUTE = False # set True to precompute embeddings to disk (recommended if dataset fits memory)\n", + "\n", + "CACHE_DIR = \"embedding_cache\" # used if PRECOMPUTE True\n", + "os.makedirs(CACHE_DIR, exist_ok=True)\n", + "# -----------------------\n", + "\n", + "def seed_everything(seed=SEED):\n", + " random.seed(seed)\n", + " np.random.seed(seed)\n", + " torch.manual_seed(seed)\n", + " if torch.cuda.is_available():\n", + " torch.cuda.manual_seed_all(seed)\n", + "\n", + "seed_everything()\n", + "\n", + "# -----------------------\n", + "# Dataset that reads images and returns (image, label)\n", + "# -----------------------\n", + "class ImageCSVLoader(Dataset):\n", + " def __init__(self, df, image_root=\"\", transform=None, label_map=None):\n", + " \"\"\"\n", + " df: DataFrame with columns ['image_path', 'colour']\n", + " image_root: optional 
prefix\n", + " transform: torchvision transforms for preprocessing\n", + " label_map: dict mapping class name -> int. If None, we assume targets are ints.\n", + " \"\"\"\n", + " self.df = df.reset_index(drop=True)\n", + " self.image_root = image_root\n", + " self.transform = transform\n", + " self.label_map = label_map\n", + " # Determine the expected size after transforms, for placeholder images\n", + " self._placeholder_size = (224, 224) # Default to common ViT input size\n", + " if transform:\n", + " # Try to infer size from transforms if possible\n", + " try:\n", + " dummy_img = Image.new('RGB', (256, 256)) # Start with a larger dummy\n", + " transformed_dummy = transform(dummy_img)\n", + " self._placeholder_size = (transformed_dummy.shape[-2], transformed_dummy.shape[-1])\n", + " except Exception:\n", + " pass # Fallback to default if size inference fails\n", + "\n", + "\n", + " def __len__(self):\n", + " return len(self.df)\n", + "\n", + " def __getitem__(self, idx):\n", + " row = self.df.iloc[idx]\n", + " img_path = os.path.join(self.image_root, row[\"image_path\"]) if self.image_root else row[\"image_path\"]\n", + " try:\n", + " img = Image.open(img_path).convert(\"RGB\")\n", + " except UnidentifiedImageError:\n", + " print(f\"Warning: Could not identify image file: {img_path}. Replacing with black image.\")\n", + " img = Image.new('RGB', self._placeholder_size, (0, 0, 0)) # Use inferred or default size\n", + " except Exception as e:\n", + " print(f\"Warning: Error loading image file {img_path}: {e}. Replacing with black image.\")\n", + " img = Image.new('RGB', self._placeholder_size, (0, 0, 0)) # Use inferred or default size\n", + "\n", + "\n", + " if self.transform:\n", + " img = self.transform(img)\n", + " target = row[\"colour\"]\n", + " if self.label_map is not None:\n", + " target = self.label_map[target]\n", + " else:\n", + " target = int(target)\n", + " return img, target\n", + "\n", + "# -----------------------\n", + "# Create transforms (DINO/ViT style)\n", + "# -----------------------\n", + "transform = transforms.Compose([\n", + " transforms.Resize(256),\n", + " transforms.CenterCrop(224),\n", + " transforms.ToTensor(),\n", + " transforms.Normalize(mean=(0.485, 0.456, 0.406),\n", + " std=(0.229, 0.224, 0.225)),\n", + "])\n", + "\n", + "# -----------------------\n", + "# Load CSV and prepare label mapping\n", + "# -----------------------\n", + "df = pd.read_csv(CSV_PATH)\n", + "# Expecting df columns: ['image_path','colour']\n", + "if \"image_path\" not in df.columns or \"colour\" not in df.columns:\n", + " raise ValueError(\"CSV must contain 'image_path' and 'colour' columns\")\n", + "\n", + "# If colours are strings, build a mapping\n", + "if df['colour'].dtype == object:\n", + " classes = sorted(df['colour'].unique().tolist())\n", + " label_map = {c: i for i, c in enumerate(classes)}\n", + "else:\n", + " classes = sorted(df['colour'].unique().tolist())\n", + " label_map = None\n", + "\n", + "num_classes = len(classes) if label_map is not None else len(np.unique(df['colour']))\n", + "print(f\"Detected {num_classes} classes. 
Label map provided: {label_map is not None}\")\n", + "\n", + "# Identify classes with only one sample\n", + "counts = df['colour'].value_counts()\n", + "rare_classes = counts[counts < 2].index.tolist()\n", + "\n", + "# Remove rows with rare classes if any exist\n", + "if rare_classes:\n", + " print(f\"Removing rows with rare classes: {rare_classes}\")\n", + " df = df[~df['colour'].isin(rare_classes)].reset_index(drop=True)\n", + " # Rebuild classes and label_map after removing rare classes\n", + " classes = sorted(df['colour'].unique().tolist())\n", + " label_map = {c: i for i, c in enumerate(classes)}\n", + " num_classes = len(classes)\n", + "\n", + "# -----------------------\n", + "# Train / Val split (stratify on colour)\n", + "# -----------------------\n", + "train_df, val_df = train_test_split(df, test_size=0.15, stratify=df['colour'], random_state=SEED)\n", + "print(f\"Train: {len(train_df)} | Val: {len(val_df)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "XmvEDXwMXlka", + "outputId": "b7a61d3b-afda-4920-f959-089fd3900ab5" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Backbone embedding dim: 1024\n" + ] + } + ], + "source": [ + "# -----------------------\n", + "# Build backbone (timm) and classifier\n", + "# -----------------------\n", + "def build_backbone(name=BACKBONE_NAME, device=DEVICE):\n", + " \"\"\"\n", + " Create timm model and set to eval mode and freeze params.\n", + " We'll use model.forward_features(...) 
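Editor note: the `value_counts` output above shows a strong class imbalance (G has 7,674 samples, Y-Z only 23); the pipeline itself only removes classes with fewer than two samples so the stratified split can run. One optional extension, not part of the original pipeline, is to weight the cross-entropy loss by inverse class frequency. The sketch assumes `train_df`, `label_map`, and `DEVICE` from the cells above.

```python
# Optional extension (not in the original pipeline): inverse-frequency class
# weights computed from the training split, ordered by the classifier's indices.
import torch
import torch.nn as nn

counts = train_df["colour"].value_counts()
freq = torch.tensor(
    [counts[c] for c in sorted(label_map, key=label_map.get)],
    dtype=torch.float32,
)
class_weights = (freq.sum() / (len(freq) * freq)).to(DEVICE)

weighted_criterion = nn.CrossEntropyLoss(weight=class_weights)
```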
to get a [B, embed_dim] tensor.\n", + " \"\"\"\n", + " # Updated model name to an available DINOv3 model\n", + " backbone = timm.create_model(\"vit_large_patch16_dinov3\", pretrained=True)\n", + " backbone.eval()\n", + " # freeze backbone\n", + " for p in backbone.parameters():\n", + " p.requires_grad = False\n", + " backbone.to(device)\n", + " return backbone\n", + "\n", + "backbone = build_backbone()\n", + "\n", + "# Determine embedding dim via a dummy forward (safe on CPU/GPU)\n", + "def infer_embedding_dim(backbone):\n", + " backbone.eval()\n", + " with torch.no_grad():\n", + " dummy = torch.randn(1, 3, 224, 224).to(next(backbone.parameters()).device)\n", + " # Use forward_features to get the feature map before the classifier head\n", + " feats = backbone.forward_features(dummy)\n", + " # If forward_features returns a tuple/list, take the first element (often the main feature map)\n", + " if isinstance(feats, (list, tuple)):\n", + " feats = feats[0]\n", + " # The shape should be (batch_size, sequence_length, embedding_dimension) for ViT features\n", + " # Or (batch_size, embedding_dimension) if it returns the pooled CLS token\n", + " # Let's assume it returns (B, seq_len, embed_dim) and we want embed_dim\n", + " if feats.ndim == 3:\n", + " return feats.shape[2]\n", + " # If it's already (B, embed_dim)\n", + " elif feats.ndim == 2:\n", + " return feats.shape[1]\n", + " else:\n", + " raise ValueError(f\"Unexpected feature shape from backbone: {feats.shape}\")\n", + "\n", + "EMBED_DIM = infer_embedding_dim(backbone)\n", + "print(\"Backbone embedding dim:\", EMBED_DIM)\n", + "\n", + "# Simple 2-layer classifier (embedding -> hidden -> classes)\n", + "class TwoLayerClassifier(nn.Module):\n", + " def __init__(self, in_dim, hidden_dim, num_classes, dropout=0.3):\n", + " super().__init__()\n", + " self.net = nn.Sequential(\n", + " nn.Linear(in_dim, hidden_dim),\n", + " nn.ReLU(inplace=True),\n", + " nn.Dropout(dropout),\n", + " nn.Linear(hidden_dim, num_classes)\n", + " )\n", + " def forward(self, x):\n", + " return self.net(x)\n", + "\n", + "classifier = TwoLayerClassifier(EMBED_DIM, HIDDEN_DIM, num_classes, DROPOUT).to(DEVICE)\n", + "\n", + "# -----------------------\n", + "# Optional: Precompute embeddings (recommended if dataset fits in disk/memory)\n", + "# -----------------------\n", + "def precompute_embeddings(backbone, df, transform, image_root=\"\", cache_dir=CACHE_DIR, batch_size=32, num_workers=4):\n", + " \"\"\"\n", + " Compute embeddings for a dataframe and save as .pt files in cache_dir.\n", + " Returns a TensorDataset(embeddings, labels)\n", + " \"\"\"\n", + " ds = ImageCSVLoader(df, image_root=image_root, transform=transform, label_map=label_map)\n", + " loader = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)\n", + "\n", + " all_emb = []\n", + " all_lbl = []\n", + " backbone.eval()\n", + " with torch.no_grad():\n", + " for imgs, lbls in tqdm(loader, desc=\"Precomputing embeddings\"):\n", + " imgs = imgs.to(DEVICE)\n", + " feats = backbone.forward_features(imgs) # [B, seq_len, EMBED_DIM] or [B, EMBED_DIM]\n", + " if isinstance(feats, (list, tuple)):\n", + " feats = feats[0]\n", + " # Assuming we take the CLS token or flatten the sequence if necessary\n", + " if feats.ndim == 3:\n", + " # Take CLS token (first token)\n", + " feats = feats[:, 0]\n", + " elif feats.ndim == 2:\n", + " # Already seems to be pooled features\n", + " pass\n", + " else:\n", + " raise ValueError(f\"Unexpected feature shape during precomputation: 
{feats.shape}\")\n", + "\n", + " feats = feats.detach().cpu()\n", + " all_emb.append(feats)\n", + " all_lbl.append(torch.tensor(lbls))\n", + " all_emb = torch.cat(all_emb, dim=0)\n", + " all_lbl = torch.cat(all_lbl, dim=0)\n", + " # save\n", + " torch.save(all_emb, os.path.join(cache_dir, \"embeddings.pt\"))\n", + " torch.save(all_lbl, os.path.join(cache_dir, \"labels.pt\"))\n", + " print(\"Saved embeddings:\", all_emb.shape)\n", + " return TensorDataset(all_emb, all_lbl)\n", + "\n", + "if PRECOMPUTE:\n", + " train_dataset = precompute_embeddings(backbone, train_df, transform, IMAGE_ROOT)\n", + " val_dataset = precompute_embeddings(backbone, val_df, transform, IMAGE_ROOT)\n", + "else:\n", + " train_dataset = ImageCSVLoader(train_df, image_root=IMAGE_ROOT, transform=transform, label_map=label_map)\n", + " val_dataset = ImageCSVLoader(val_df, image_root=IMAGE_ROOT, transform=transform, label_map=label_map)\n", + "\n", + "train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)\n", + "val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)\n", + "\n", + "# -----------------------\n", + "# Training & evaluation loops\n", + "# -----------------------\n", + "criterion = nn.CrossEntropyLoss()\n", + "optimizer = torch.optim.AdamW(classifier.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n", + "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5) # optional\n", + "\n", + "def train_one_epoch(backbone, classifier, loader, optimizer, criterion, device, precompute=PRECOMPUTE):\n", + " classifier.train()\n", + " running_loss = 0.0\n", + " running_correct = 0\n", + " n = 0\n", + " backbone.eval() # ensure backbone frozen\n", + " for batch in loader:\n", + " if precompute:\n", + " emb, lbl = batch\n", + " emb = emb.to(device)\n", + " lbl = lbl.to(device)\n", + " else:\n", + " imgs, lbl = batch\n", + " imgs = imgs.to(device)\n", + " lbl = lbl.to(device)\n", + " with torch.no_grad():\n", + " feats = backbone.forward_features(imgs)\n", + " if isinstance(feats, (list, tuple)):\n", + " feats = feats[0]\n", + " # Assuming we take the CLS token (index 0)\n", + " if feats.ndim == 3:\n", + " emb = feats[:, 0]\n", + " elif feats.ndim == 2:\n", + " emb = feats\n", + " else:\n", + " raise ValueError(f\"Unexpected feature shape during training: {feats.shape}\")\n", + "\n", + " optimizer.zero_grad()\n", + " logits = classifier(emb)\n", + " loss = criterion(logits, lbl)\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " running_loss += float(loss.item()) * lbl.size(0)\n", + " preds = logits.argmax(dim=1)\n", + " running_correct += (preds == lbl).sum().item()\n", + " n += lbl.size(0)\n", + "\n", + " avg_loss = running_loss / n\n", + " avg_acc = running_correct / n\n", + " return avg_loss, avg_acc\n", + "\n", + "def eval_one_epoch(backbone, classifier, loader, criterion, device, precompute=PRECOMPUTE):\n", + " classifier.eval()\n", + " running_loss = 0.0\n", + " running_correct = 0\n", + " n = 0\n", + " with torch.no_grad():\n", + " for batch in loader:\n", + " if precompute:\n", + " emb, lbl = batch\n", + " emb = emb.to(device)\n", + " lbl = lbl.to(device)\n", + " else:\n", + " imgs, lbl = batch\n", + " imgs = imgs.to(device)\n", + " lbl = lbl.to(device)\n", + " feats = backbone.forward_features(imgs)\n", + " if isinstance(feats, (list, tuple)):\n", + " feats = feats[0]\n", + " # Assuming we take the CLS token (index 0)\n", + " if feats.ndim == 3:\n", + " 
emb = feats[:, 0]\n", + " elif feats.ndim == 2:\n", + " emb = feats\n", + " else:\n", + " raise ValueError(f\"Unexpected feature shape during evaluation: {feats.shape}\")\n", + " logits = classifier(emb)\n", + " loss = criterion(logits, lbl)\n", + " running_loss += float(loss.item()) * lbl.size(0)\n", + " preds = logits.argmax(dim=1)\n", + " running_correct += (preds == lbl).sum().item()\n", + " n += lbl.size(0)\n", + " avg_loss = running_loss / n\n", + " avg_acc = running_correct / n\n", + " return avg_loss, avg_acc" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "fZWK9DiVXlhp", + "outputId": "e48b76f1-607f-4d94-a3d6-3032c188dfb2" + }, + "outputs": [], + "source": [ + "# -----------------------\n", + "# Main training loop\n", + "# -----------------------\n", + "best_val_acc = 0.0\n", + "best_state = None\n", + "\n", + "for epoch in range(1, EPOCHS + 1):\n", + " train_loss, train_acc = train_one_epoch(backbone, classifier, train_loader, optimizer, criterion, DEVICE, precompute=PRECOMPUTE)\n", + " val_loss, val_acc = eval_one_epoch(backbone, classifier, val_loader, criterion, DEVICE, precompute=PRECOMPUTE)\n", + " scheduler.step()\n", + "\n", + " print(f\"[Epoch {epoch}/{EPOCHS}] Train loss: {train_loss:.4f} acc: {train_acc:.4f} | Val loss: {val_loss:.4f} acc: {val_acc:.4f}\")\n", + "\n", + " # save best\n", + " if val_acc > best_val_acc:\n", + " best_val_acc = val_acc\n", + " best_state = {\n", + " \"classifier_state\": classifier.state_dict(),\n", + " \"backbone_name\": BACKBONE_NAME,\n", + " \"label_map\": label_map,\n", + " \"embedding_dim\": EMBED_DIM,\n", + " \"epoch\": epoch,\n", + " \"val_acc\": val_acc\n", + " }\n", + " torch.save(best_state, \"best_classifier.pth\")\n", + " print(\"Saved best model -> best_classifier.pth\")\n", + "\n", + "print(\"Training finished. Best val acc:\", best_val_acc)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# with ConvNeXt-B as deocoder" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\Admin\\miniconda3\\envs\\my_env_diam\\Lib\\site-packages\\huggingface_hub\\file_download.py:143: UserWarning: `huggingface_hub` cache-system uses symlinks by default to efficiently store duplicated files but your machine does not support them in C:\\Users\\Admin\\.cache\\huggingface\\hub\\models--timm--vit_small_patch16_dinov3.lvd1689m. Caching files will still work but in a degraded version that might require more space on your disk. This warning can be disabled by setting the `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable. For more details, see https://huggingface.co/docs/huggingface_hub/how-to-cache#limitations.\n", + "To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development\n", + " warnings.warn(message)\n", + "Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. 
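Editor note: the first training loop above saves the best head to `best_classifier.pth` together with `label_map` and `embedding_dim`. A minimal reload-and-predict sketch follows; it assumes the frozen `backbone`, the `transform`, and the `TwoLayerClassifier` class defined earlier are still in scope, and the image filename is a placeholder.

```python
# Reload the saved head and classify a single image.
# Assumptions: `backbone`, `transform`, `TwoLayerClassifier`, `HIDDEN_DIM`,
# `DROPOUT`, and `DEVICE` from the cells above; "some_diamond.jpg" is a placeholder.
import torch
from PIL import Image

ckpt = torch.load("best_classifier.pth", map_location=DEVICE)
inv_label_map = {v: k for k, v in ckpt["label_map"].items()}

clf = TwoLayerClassifier(ckpt["embedding_dim"], HIDDEN_DIM, len(inv_label_map), DROPOUT).to(DEVICE)
clf.load_state_dict(ckpt["classifier_state"])
clf.eval()

img = transform(Image.open("some_diamond.jpg").convert("RGB")).unsqueeze(0).to(DEVICE)
with torch.no_grad():
    feats = backbone.forward_features(img)
    emb = feats[:, 0] if feats.ndim == 3 else feats  # CLS token, as during training
    pred_idx = clf(emb).argmax(dim=1).item()

print("Predicted colour grade:", inv_label_map[pred_idx])
```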
For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Inferred EMBED_DIM=384, TOKEN_SIDE=14, BACKBONE_HAS_CLS=False\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`\n", + "c:\\Users\\Admin\\miniconda3\\envs\\my_env_diam\\Lib\\site-packages\\huggingface_hub\\file_download.py:143: UserWarning: `huggingface_hub` cache-system uses symlinks by default to efficiently store duplicated files but your machine does not support them in C:\\Users\\Admin\\.cache\\huggingface\\hub\\models--timm--convnext_base.fb_in22k_ft_in1k. Caching files will still work but in a degraded version that might require more space on your disk. This warning can be disabled by setting the `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable. For more details, see https://huggingface.co/docs/huggingface_hub/how-to-cache#limitations.\n", + "To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development\n", + " warnings.warn(message)\n" + ] + } + ], + "source": [ + "# add/imports near top\n", + "import math\n", + "import copy\n", + "import types\n", + "\n", + "# -----------------------\n", + "# Build backbone (timm) - use BACKBONE_NAME variable if you prefer\n", + "# -----------------------\n", + "def build_backbone(name=BACKBONE_NAME, device=DEVICE):\n", + " \"\"\"\n", + " Create timm model and set to eval mode and freeze params.\n", + " We'll use model.forward_features(...) 
to get a [B, seq_len, embed_dim] tensor OR [B, embed_dim].\n", + " \"\"\"\n", + " backbone = timm.create_model(name, pretrained=True) # use the BACKBONE_NAME you already set\n", + " backbone.eval()\n", + " for p in backbone.parameters():\n", + " p.requires_grad = False\n", + " backbone.to(device)\n", + " return backbone\n", + "\n", + "backbone = build_backbone()\n", + "\n", + "# -----------------------\n", + "# Infer embedding dim and token spatial shape\n", + "# -----------------------\n", + "def infer_embedding_dim_and_shape(backbone):\n", + " backbone.eval()\n", + " with torch.no_grad():\n", + " dummy = torch.randn(1, 3, 224, 224).to(next(backbone.parameters()).device)\n", + " feats = backbone.forward_features(dummy)\n", + " if isinstance(feats, (list, tuple)):\n", + " feats = feats[0]\n", + " # If (B, seq_len, embed_dim)\n", + " if feats.ndim == 3:\n", + " seq_len = feats.shape[1]\n", + " embed_dim = feats.shape[2]\n", + " # Attempt to detect cls token:\n", + " # if seq_len == 1 + H*W then we assume CLS token present\n", + " possible = seq_len - 1\n", + " side = int(math.sqrt(possible))\n", + " if side * side == possible:\n", + " has_cls = True\n", + " token_side = side\n", + " else:\n", + " # maybe no cls, check if seq_len is perfect square\n", + " side2 = int(math.sqrt(seq_len))\n", + " if side2 * side2 == seq_len:\n", + " has_cls = False\n", + " token_side = side2\n", + " else:\n", + " # fallback: set token_side = int(sqrt(seq_len)) and hope\n", + " has_cls = False\n", + " token_side = int(math.sqrt(seq_len))\n", + " return embed_dim, token_side, has_cls\n", + " elif feats.ndim == 2:\n", + " # pooled features\n", + " embed_dim = feats.shape[1]\n", + " return embed_dim, None, False\n", + " else:\n", + " raise ValueError(f\"Unexpected feature shape from backbone: {feats.shape}\")\n", + "\n", + "EMBED_DIM, TOKEN_SIDE, BACKBONE_HAS_CLS = infer_embedding_dim_and_shape(backbone)\n", + "print(f\"Inferred EMBED_DIM={EMBED_DIM}, TOKEN_SIDE={TOKEN_SIDE}, BACKBONE_HAS_CLS={BACKBONE_HAS_CLS}\")\n", + "\n", + "# -----------------------\n", + "# ConvNeXt-based decoder wrapper\n", + "# -----------------------\n", + "class ConvNeXtDecoder(nn.Module):\n", + " \"\"\"\n", + " A wrapper that adapts a timm convnext_base (ConvNeXt-B) model to accept token maps\n", + " of shape (B, embed_dim, H, W). 
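Editor note: the perfect-square heuristic in `infer_embedding_dim_and_shape` can miscount when a ViT prepends register tokens in addition to the CLS token, as some DINOv3 checkpoints do. If the installed timm version exposes `num_prefix_tokens` on its ViT models (an assumption worth verifying), the prefix tokens can be split off explicitly before reshaping, and this helper can stand in for the square-root branches used later in `precompute_embeddings` and `make_features_for_decoder`.

```python
# Hedged alternative to the square-root heuristic: split off CLS/register
# tokens via timm's `num_prefix_tokens` attribute (assumed to exist here).
import math
import torch

def tokens_to_grid(backbone, feats):
    """feats: (B, seq_len, C) from backbone.forward_features -> (B, C, H, W)."""
    n_prefix = getattr(backbone, "num_prefix_tokens", 1)
    patch_tokens = feats[:, n_prefix:, :]            # drop CLS / register tokens
    b, n, c = patch_tokens.shape
    side = int(math.sqrt(n))
    assert side * side == n, f"patch count {n} is not a square grid"
    return patch_tokens.permute(0, 2, 1).reshape(b, c, side, side)
```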
If the backbone returns pooled features (B, embed_dim),\n", + " the wrapper falls back to a Linear head.\n", + " \"\"\"\n", + " def __init__(self, embed_dim, num_classes, pretrained=True, device=DEVICE):\n", + " super().__init__()\n", + " self.embed_dim = embed_dim\n", + " self.num_classes = num_classes\n", + " self.device = device\n", + "\n", + " # instantiate convnext_base (ConvNeXt-B ~ 'convnext_base' in timm)\n", + " # you can choose 'convnext_small', 'convnext_base', 'convnext_large' as needed\n", + " model = timm.create_model(\"convnext_base\", pretrained=pretrained)\n", + " # find & replace first Conv2d to accept `embed_dim` channels\n", + " replaced = False\n", + " for name, module in model.named_modules():\n", + " if isinstance(module, nn.Conv2d):\n", + " # Replace the first Conv2d we find with a 1x1 projection from embed_dim -> module.out_channels\n", + " parent = model\n", + " parts = name.split('.')\n", + " for p in parts[:-1]:\n", + " parent = getattr(parent, p)\n", + " orig_name = parts[-1]\n", + " orig_conv = getattr(parent, orig_name)\n", + " new_conv = nn.Conv2d(in_channels=embed_dim,\n", + " out_channels=orig_conv.out_channels,\n", + " kernel_size=1,\n", + " stride=1,\n", + " padding=0,\n", + " bias=(orig_conv.bias is not None))\n", + " # initialize new conv with a simple projection (xavier)\n", + " nn.init.xavier_uniform_(new_conv.weight)\n", + " if new_conv.bias is not None:\n", + " nn.init.zeros_(new_conv.bias)\n", + " setattr(parent, orig_name, new_conv)\n", + " replaced = True\n", + " break\n", + " if not replaced:\n", + " raise RuntimeError(\"Failed to find a Conv2d to replace in ConvNeXt model - timm layout unexpected.\")\n", + "\n", + " # replace classifier head to our num_classes\n", + " # timm convnext uses attribute `head` typically as Linear(in_features, out_features)\n", + " if hasattr(model, \"head\") and isinstance(model.head, nn.Linear):\n", + " in_feat = model.head.in_features\n", + " model.head = nn.Linear(in_feat, num_classes)\n", + " elif hasattr(model, \"fc\") and isinstance(model.fc, nn.Linear):\n", + " in_feat = model.fc.in_features\n", + " model.fc = nn.Linear(in_feat, num_classes)\n", + " else:\n", + " # fallback identify final linear\n", + " for nm, mod in model.named_modules():\n", + " if isinstance(mod, nn.Linear):\n", + " parent = model\n", + " parts = nm.split('.')[:-1]\n", + " for p in parts:\n", + " parent = getattr(parent, p)\n", + " last = nm.split('.')[-1]\n", + " in_feat = mod.in_features\n", + " setattr(parent, last, nn.Linear(in_feat, num_classes))\n", + " break\n", + "\n", + " self.convnext = model.to(device)\n", + "\n", + " # small linear fallback head for pooled features\n", + " self.pool_head = nn.Sequential(\n", + " nn.Linear(embed_dim, embed_dim // 2),\n", + " nn.ReLU(inplace=True),\n", + " nn.Dropout(0.2),\n", + " nn.Linear(embed_dim // 2, num_classes)\n", + " ).to(device)\n", + "\n", + " def forward(self, feats):\n", + " \"\"\"\n", + " feats: either\n", + " - 4D tensor: (B, C, H, W) -> feed into convnext\n", + " - 2D tensor: (B, embed_dim) -> feed into pool_head\n", + " \"\"\"\n", + " if feats.ndim == 4:\n", + " # feed to convnext (it expects images/feature maps)\n", + " return self.convnext(feats)\n", + " elif feats.ndim == 2:\n", + " return self.pool_head(feats)\n", + " else:\n", + " raise ValueError(\"ConvNeXtDecoder expects 4D (B,C,H,W) or 2D (B,embed_dim) inputs.\")\n", + "\n", + "# instantiate decoder\n", + "decoder = ConvNeXtDecoder(EMBED_DIM, num_classes, pretrained=True, device=DEVICE)\n", + "decoder.to(DEVICE)\n", + 
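Editor note: since the decoder's first Conv2d was swapped for a 1x1 projection from `EMBED_DIM` channels, a dummy forward pass is a cheap way to confirm the adapted stem and the replaced head produce logits of the expected shape. The check below (not in the original cell) assumes `decoder`, `EMBED_DIM`, `TOKEN_SIDE`, and `DEVICE` from above.

```python
# Sanity check: both a (B, EMBED_DIM, TOKEN_SIDE, TOKEN_SIDE) token grid and a
# (B, EMBED_DIM) pooled vector should yield (B, num_classes) logits.
import torch

with torch.no_grad():
    grid = torch.randn(2, EMBED_DIM, TOKEN_SIDE, TOKEN_SIDE, device=DEVICE)
    pooled = torch.randn(2, EMBED_DIM, device=DEVICE)
    print(decoder(grid).shape)    # expected: torch.Size([2, num_classes])
    print(decoder(pooled).shape)  # expected: torch.Size([2, num_classes])
```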
"\n", + "# If you want to freeze the convnext backbone inside the decoder and only fine-tune\n", + "# the final head, uncomment:\n", + "# for p in decoder.convnext.parameters():\n", + "# p.requires_grad = False\n", + "# for p in decoder.convnext.head.parameters(): # or last linear replaced above\n", + "# p.requires_grad = True\n", + "\n", + "# -----------------------\n", + "# Update precompute_embeddings to store token maps if present\n", + "# -----------------------\n", + "def precompute_embeddings(backbone, df, transform, image_root=\"\", cache_dir=CACHE_DIR, batch_size=32, num_workers=4):\n", + " \"\"\"\n", + " Compute embeddings for a dataframe and save as .pt files in cache_dir.\n", + " If backbone returns token maps we save token maps (B, C, H, W).\n", + " If backbone returns pooled embeddings, we save (B, C).\n", + " Returns a TensorDataset(embeddings, labels)\n", + " \"\"\"\n", + " ds = ImageCSVLoader(df, image_root=image_root, transform=transform, label_map=label_map)\n", + " loader = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)\n", + "\n", + " all_emb = []\n", + " all_lbl = []\n", + " backbone.eval()\n", + " with torch.no_grad():\n", + " for imgs, lbls in tqdm(loader, desc=\"Precomputing embeddings\"):\n", + " imgs = imgs.to(DEVICE)\n", + " feats = backbone.forward_features(imgs)\n", + " if isinstance(feats, (list, tuple)):\n", + " feats = feats[0]\n", + " if feats.ndim == 3:\n", + " # (B, seq_len, C) -> drop CLS if present, reshape to (B, C, H, W)\n", + " seq_len = feats.shape[1]\n", + " C = feats.shape[2]\n", + " # detect cls\n", + " poss = seq_len - 1\n", + " side = int(math.sqrt(poss)) if poss > 0 else int(math.sqrt(seq_len))\n", + " if side * side == poss:\n", + " token_feats = feats[:, 1:, :] # drop cls\n", + " else:\n", + " # assume no cls\n", + " token_feats = feats\n", + " side = int(math.sqrt(token_feats.shape[1]))\n", + " B = token_feats.shape[0]\n", + " token_feats = token_feats.permute(0, 2, 1).reshape(B, C, side, side) # (B,C,H,W)\n", + " emb_tensor = token_feats.detach().cpu()\n", + " elif feats.ndim == 2:\n", + " emb_tensor = feats.detach().cpu()\n", + " else:\n", + " raise ValueError(f\"Unexpected feature shape during precomputation: {feats.shape}\")\n", + "\n", + " all_emb.append(emb_tensor)\n", + " all_lbl.append(torch.tensor(lbls))\n", + "\n", + " # concatenate along batch dim; for 4D tensors need to cat on dim=0\n", + " # ensure all_emb elements are same ndim type, else error.\n", + " all_emb = torch.cat(all_emb, dim=0)\n", + " all_lbl = torch.cat(all_lbl, dim=0)\n", + "\n", + " # save\n", + " torch.save(all_emb, os.path.join(cache_dir, \"embeddings.pt\"))\n", + " torch.save(all_lbl, os.path.join(cache_dir, \"labels.pt\"))\n", + " print(\"Saved embeddings:\", all_emb.shape)\n", + " # create TensorDataset: if 4D, need a custom dataset wrapper since TensorDataset expects tensors of same ndim works fine.\n", + " return TensorDataset(all_emb, all_lbl)\n", + "\n", + "# -----------------------\n", + "# DataLoaders: if PRECOMPUTE True, the dataset items will be tensors of shape either (C,H,W) or (embed_dim,)\n", + "# If PRECOMPUTE False, ImageCSVLoader remains unchanged and we compute on the fly.\n", + "# -----------------------\n", + "if PRECOMPUTE:\n", + " train_dataset = precompute_embeddings(backbone, train_df, transform, IMAGE_ROOT)\n", + " val_dataset = precompute_embeddings(backbone, val_df, transform, IMAGE_ROOT)\n", + "else:\n", + " train_dataset = ImageCSVLoader(train_df, image_root=IMAGE_ROOT, 
transform=transform, label_map=label_map)\n", + " val_dataset = ImageCSVLoader(val_df, image_root=IMAGE_ROOT, transform=transform, label_map=label_map)\n", + "\n", + "train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)\n", + "val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)\n", + "\n", + "# -----------------------\n", + "# Update training loops to feed the decoder (convnext) correctly\n", + "# -----------------------\n", + "def make_features_for_decoder(batch, backbone, precompute=PRECOMPUTE):\n", + " \"\"\"\n", + " Given a batch (either (imgs, lbl) or (emb, lbl) depending on PRECOMPUTE),\n", + " return (feats_for_decoder, labels) where feats_for_decoder is either:\n", + " - 4D tensor (B,C,H,W) to send to ConvNeXtDecoder\n", + " - 2D tensor (B,embed_dim) to send to linear fallback\n", + " \"\"\"\n", + " if precompute:\n", + " emb, lbl = batch\n", + " emb = emb.to(DEVICE)\n", + " lbl = lbl.to(DEVICE)\n", + " # emb may be 4D or 2D already\n", + " return emb, lbl\n", + " else:\n", + " imgs, lbl = batch\n", + " imgs = imgs.to(DEVICE)\n", + " lbl = lbl.to(DEVICE)\n", + " with torch.no_grad():\n", + " feats = backbone.forward_features(imgs)\n", + " if isinstance(feats, (list, tuple)):\n", + " feats = feats[0]\n", + " if feats.ndim == 3:\n", + " # (B, seq_len, C)\n", + " seq_len = feats.shape[1]\n", + " C = feats.shape[2]\n", + " poss = seq_len - 1\n", + " side = int(math.sqrt(poss)) if poss > 0 else int(math.sqrt(seq_len))\n", + " if side * side == poss:\n", + " token_feats = feats[:, 1:, :]\n", + " else:\n", + " token_feats = feats\n", + " side = int(math.sqrt(token_feats.shape[1]))\n", + " emb_tensor = token_feats.permute(0, 2, 1).reshape(feats.shape[0], C, side, side) # (B,C,H,W)\n", + " elif feats.ndim == 2:\n", + " emb_tensor = feats\n", + " else:\n", + " raise ValueError(f\"Unexpected feature shape during training/eval: {feats.shape}\")\n", + " return emb_tensor, lbl\n", + "\n", + "def train_one_epoch(backbone, decoder, loader, optimizer, criterion, device, precompute=PRECOMPUTE):\n", + " decoder.train()\n", + " running_loss = 0.0\n", + " running_correct = 0\n", + " n = 0\n", + " backbone.eval()\n", + " for batch in loader:\n", + " emb_for_dec, lbl = make_features_for_decoder(batch, backbone, precompute=precompute)\n", + " optimizer.zero_grad()\n", + " logits = decoder(emb_for_dec)\n", + " loss = criterion(logits, lbl.to(device))\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " running_loss += float(loss.item()) * lbl.size(0)\n", + " preds = logits.argmax(dim=1)\n", + " running_correct += (preds == lbl.to(device)).sum().item()\n", + " n += lbl.size(0)\n", + "\n", + " avg_loss = running_loss / n\n", + " avg_acc = running_correct / n\n", + " return avg_loss, avg_acc\n", + "\n", + "def eval_one_epoch(backbone, decoder, loader, criterion, device, precompute=PRECOMPUTE):\n", + " decoder.eval()\n", + " running_loss = 0.0\n", + " running_correct = 0\n", + " n = 0\n", + " with torch.no_grad():\n", + " for batch in loader:\n", + " emb_for_dec, lbl = make_features_for_decoder(batch, backbone, precompute=precompute)\n", + " logits = decoder(emb_for_dec)\n", + " loss = criterion(logits, lbl.to(device))\n", + " running_loss += float(loss.item()) * lbl.size(0)\n", + " preds = logits.argmax(dim=1)\n", + " running_correct += (preds == lbl.to(device)).sum().item()\n", + " n += lbl.size(0)\n", + " avg_loss = running_loss / n\n", + " avg_acc = running_correct 
/ n\n", + " return avg_loss, avg_acc\n", + "\n", + "# -----------------------\n", + "# Optimizer & scheduler: optimize decoder parameters\n", + "# -----------------------\n", + "criterion = nn.CrossEntropyLoss()\n", + "optimizer = torch.optim.AdamW(decoder.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n", + "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\Admin\\miniconda3\\envs\\my_env_diam\\Lib\\site-packages\\torch\\utils\\data\\dataloader.py:666: UserWarning: 'pin_memory' argument is set as true but no accelerator is found, then device pinned memory won't be used.\n", + " warnings.warn(warn_msg)\n" + ] + } + ], + "source": [ + "\n", + "# -----------------------\n", + "# Main training loop (unchanged, but using decoder instead of simple classifier)\n", + "# -----------------------\n", + "best_val_acc = 0.0\n", + "best_state = None\n", + "\n", + "for epoch in range(1, EPOCHS + 1):\n", + " train_loss, train_acc = train_one_epoch(backbone, decoder, train_loader, optimizer, criterion, DEVICE, precompute=PRECOMPUTE)\n", + " val_loss, val_acc = eval_one_epoch(backbone, decoder, val_loader, criterion, DEVICE, precompute=PRECOMPUTE)\n", + " scheduler.step()\n", + "\n", + " print(f\"[Epoch {epoch}/{EPOCHS}] Train loss: {train_loss:.4f} acc: {train_acc:.4f} | Val loss: {val_loss:.4f} acc: {val_acc:.4f}\")\n", + "\n", + " # save best\n", + " if val_acc > best_val_acc:\n", + " best_val_acc = val_acc\n", + " best_state = {\n", + " \"decoder_state\": decoder.state_dict(),\n", + " \"decoder_type\": \"convnext_base\",\n", + " \"backbone_name\": BACKBONE_NAME,\n", + " \"label_map\": label_map,\n", + " \"embedding_dim\": EMBED_DIM,\n", + " \"epoch\": epoch,\n", + " \"val_acc\": val_acc\n", + " }\n", + " torch.save(best_state, \"best_decoder_convnext.pth\")\n", + " print(\"Saved best model -> best_decoder_convnext.pth\")\n", + "\n", + "print(\"Training finished. 
Best val acc:\", best_val_acc)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# my older approaches [v1]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9lNVUKfAOORf" + }, + "source": [ + "# Preprocessing " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true, + "execution": { + "iopub.execute_input": "2025-09-04T13:20:12.435473Z", + "iopub.status.busy": "2025-09-04T13:20:12.435144Z", + "iopub.status.idle": "2025-09-04T13:33:30.953614Z", + "shell.execute_reply": "2025-09-04T13:33:30.952751Z", + "shell.execute_reply.started": "2025-09-04T13:20:12.435445Z" + }, + "id": "aXXoB2QLOORf", + "jupyter": { + "outputs_hidden": true + }, + "outputId": "7d7ae035-a826-499e-f367-0d70ff5f05ff", + "trusted": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: \"model_1\"\n", + "__________________________________________________________________________________________________\n", + " Layer (type) Output Shape Param # Connected to \n", + "==================================================================================================\n", + " image_input (InputLayer) [(None, 64, 64, 3)] 0 [] \n", + " \n", + " conv2d_2 (Conv2D) (None, 62, 62, 32) 896 ['image_input[0][0]'] \n", + " \n", + " max_pooling2d_2 (MaxPoolin (None, 31, 31, 32) 0 ['conv2d_2[0][0]'] \n", + " g2D) \n", + " \n", + " conv2d_3 (Conv2D) (None, 29, 29, 64) 18496 ['max_pooling2d_2[0][0]'] \n", + " \n", + " max_pooling2d_3 (MaxPoolin (None, 14, 14, 64) 0 ['conv2d_3[0][0]'] \n", + " g2D) \n", + " \n", + " tabular_input (InputLayer) [(None, 45)] 0 [] \n", + " \n", + " flatten_1 (Flatten) (None, 12544) 0 ['max_pooling2d_3[0][0]'] \n", + " \n", + " dense_5 (Dense) (None, 64) 2944 ['tabular_input[0][0]'] \n", + " \n", + " dense_4 (Dense) (None, 128) 1605760 ['flatten_1[0][0]'] \n", + " \n", + " dense_6 (Dense) (None, 32) 2080 ['dense_5[0][0]'] \n", + " \n", + " concatenate_1 (Concatenate (None, 160) 0 ['dense_4[0][0]', \n", + " ) 'dense_6[0][0]'] \n", + " \n", + " dense_7 (Dense) (None, 64) 10304 ['concatenate_1[0][0]'] \n", + " \n", + " main_output (Dense) (None, 23) 1495 ['dense_7[0][0]'] \n", + " \n", + "==================================================================================================\n", + "Total params: 1641975 (6.26 MB)\n", + "Trainable params: 1641975 (6.26 MB)\n", + "Non-trainable params: 0 (0.00 Byte)\n", + "__________________________________________________________________________________________________\n", + "Starting model training...\n", + "Epoch 1/10\n", + "19506/19506 [==============================] - 87s 4ms/step - loss: 2.2157 - accuracy: 0.1718 - val_loss: 2.1867 - val_accuracy: 0.1717\n", + "Epoch 2/10\n", + "19506/19506 [==============================] - 78s 4ms/step - loss: 2.1848 - accuracy: 0.1837 - val_loss: 2.1755 - val_accuracy: 0.1809\n", + "Epoch 3/10\n", + "19506/19506 [==============================] - 77s 4ms/step - loss: 2.1738 - accuracy: 0.1850 - val_loss: 2.1901 - val_accuracy: 0.1786\n", + "Epoch 4/10\n", + "19506/19506 [==============================] - 77s 4ms/step - loss: 2.1675 - accuracy: 0.1879 - val_loss: 2.1712 - val_accuracy: 0.1838\n", + "Epoch 5/10\n", + "19506/19506 [==============================] - 77s 4ms/step - loss: 2.1644 - accuracy: 0.1914 - val_loss: 2.1744 - val_accuracy: 0.1863\n", + "Epoch 6/10\n", + "19506/19506 [==============================] - 77s 4ms/step - loss: 2.1598 - accuracy: 0.1902 - val_loss: 2.1714 - val_accuracy: 
0.1856\n", + "Epoch 7/10\n", + "19506/19506 [==============================] - 77s 4ms/step - loss: 2.1569 - accuracy: 0.1925 - val_loss: 2.1853 - val_accuracy: 0.1854\n", + "Epoch 8/10\n", + "19506/19506 [==============================] - 77s 4ms/step - loss: 2.1529 - accuracy: 0.1914 - val_loss: 2.1728 - val_accuracy: 0.1824\n", + "Epoch 9/10\n", + "19506/19506 [==============================] - 77s 4ms/step - loss: 2.1497 - accuracy: 0.1930 - val_loss: 2.1745 - val_accuracy: 0.1886\n", + "Epoch 10/10\n", + "19506/19506 [==============================] - 77s 4ms/step - loss: 2.1508 - accuracy: 0.1930 - val_loss: 2.1736 - val_accuracy: 0.1845\n", + "Model training completed.\n", + "\n", + "Making a prediction on a new data point...\n", + "1/1 [==============================] - 0s 105ms/step\n", + "True color: E\n", + "Predicted color: G\n", + "Prediction probabilities: [[6.0091116e-12 1.2939192e-01 5.5591125e-11 1.1998820e-01 1.5205914e-01\n", + " 1.4335742e-03 6.1306266e-11 1.7045374e-01 1.4999163e-01 1.2224232e-01\n", + " 6.1535770e-11 8.1945576e-02 3.8573243e-02 2.1496758e-02 8.8118035e-03\n", + " 3.2531230e-03 1.8102876e-04 3.8288243e-05 8.0320082e-05 4.1988973e-05\n", + " 4.4954457e-11 5.6633617e-06 1.1639663e-05]]\n" + ] + } + ], + "source": [ + "# # Create a DataFrame from the dummy data\n", + "# # Create a DataFrame from the dummy data\n", + "# df = pd.DataFrame(data)\n", + "\n", + "# # Define a function to simulate loading an image.\n", + "# # In a real-world scenario, you would use a library like `Pillow` to load the actual image.\n", + "# # For this example, we'll create a dummy image with random pixel values.\n", + "# def load_image(image_path, target_size=(64, 64)):\n", + "# \"\"\"Simulates loading and preprocessing an image.\"\"\"\n", + "# # This is a placeholder. 
A real implementation would use:\n", + "# # from PIL import Image\n", + "# # img = Image.open(image_path).resize(target_size)\n", + "# # return np.array(img) / 255.0 # Normalize pixel values\n", + "\n", + "# # Placeholder: return a random numpy array as a dummy image\n", + "# return np.random.rand(target_size[0], target_size[1], 3)\n", + "\n", + "# # Load and preprocess all images\n", + "# image_data = np.array([load_image(path) for path in df['full_path_to_img']])\n", + "\n", + "# # Prepare the tabular data\n", + "# # Identify categorical and numerical features\n", + "# categorical_features = ['clarity', 'cut', 'polish', 'symmetry', 'fluorescence', 'lab']\n", + "# numerical_features = ['carat', 'length', 'width', 'depth']\n", + "# target_feature = 'colour'\n", + "\n", + "# # Separate features and target\n", + "# X_tabular = df[categorical_features + numerical_features]\n", + "# y = df[target_feature]\n", + "\n", + "# # One-hot encode categorical features\n", + "# X_tabular = pd.get_dummies(X_tabular, columns=categorical_features, drop_first=True)\n", + "\n", + "# # Scale numerical features\n", + "# scaler = StandardScaler()\n", + "# X_tabular[numerical_features] = scaler.fit_transform(X_tabular[numerical_features])\n", + "\n", + "# # Encode the target variable (diamond color)\n", + "# label_encoder = LabelEncoder()\n", + "# y_encoded = label_encoder.fit_transform(y)\n", + "# num_classes = len(label_encoder.classes_)\n", + "\n", + "# # Split the data into training and testing sets.\n", + "# # The `stratify` parameter has been removed to avoid the ValueError,\n", + "# # as some classes have only one sample in the dummy dataset.\n", + "# X_img_train, X_img_test, X_tab_train, X_tab_test, y_train, y_test = train_test_split(\n", + "# image_data, X_tabular, y_encoded, test_size=0.2, random_state=42\n", + "# )\n", + "\n", + "# # ----------------------------------------------------------------------\n", + "# # 2. Build the Multi-Modal Model Architecture\n", + "# # ----------------------------------------------------------------------\n", + "\n", + "# # Define the image input branch (CNN)\n", + "# image_input = Input(shape=(64, 64, 3), name='image_input')\n", + "# x = Conv2D(32, (3, 3), activation='relu')(image_input)\n", + "# x = MaxPooling2D(pool_size=(2, 2))(x)\n", + "# x = Conv2D(64, (3, 3), activation='relu')(x)\n", + "# x = MaxPooling2D(pool_size=(2, 2))(x)\n", + "# x = Flatten()(x)\n", + "# image_output = Dense(128, activation='relu')(x)\n", + "\n", + "# # Define the tabular data input branch (Dense layers)\n", + "# tabular_input = Input(shape=(X_tab_train.shape[1],), name='tabular_input')\n", + "# y = Dense(64, activation='relu')(tabular_input)\n", + "# tabular_output = Dense(32, activation='relu')(y)\n", + "\n", + "# # Concatenate the outputs of both branches\n", + "# combined = concatenate([image_output, tabular_output])\n", + "\n", + "# # Add a final dense layer for prediction\n", + "# z = Dense(64, activation='relu')(combined)\n", + "# output = Dense(num_classes, activation='softmax', name='main_output')(z)\n", + "\n", + "# # Create the final model with two inputs and one output\n", + "# model = Model(inputs=[image_input, tabular_input], outputs=output)\n", + "\n", + "# # Compile the model\n", + "# model.compile(optimizer='adam',\n", + "# loss='sparse_categorical_crossentropy',\n", + "# metrics=['accuracy'])\n", + "\n", + "# # Display the model architecture\n", + "# model.summary()\n", + "\n", + "# # ----------------------------------------------------------------------\n", + "# # 3. 
Train the Model\n", + "# # ----------------------------------------------------------------------\n", + "\n", + "# print(\"Starting model training...\")\n", + "# # Train the model with both inputs. Cast data to float32 to prevent errors.\n", + "# # Note: Since the dataset is very small, the accuracy will not be meaningful.\n", + "# # This is for demonstration purposes only.\n", + "# history = model.fit(\n", + "# {'image_input': X_img_train.astype(np.float32), 'tabular_input': X_tab_train.astype(np.float32)},\n", + "# {'main_output': y_train},\n", + "# epochs=10,\n", + "# batch_size=2,\n", + "# validation_data=({'image_input': X_img_test.astype(np.float32), 'tabular_input': X_tab_test.astype(np.float32)}, {'main_output': y_test})\n", + "# )\n", + "# print(\"Model training completed.\")\n", + "\n", + "# # ----------------------------------------------------------------------\n", + "# # 4. Make a Prediction\n", + "# # ----------------------------------------------------------------------\n", + "\n", + "# print(\"\\nMaking a prediction on a new data point...\")\n", + "# # Select a sample from the test set for prediction\n", + "# sample_index = 0\n", + "# sample_image = X_img_test[sample_index:sample_index+1]\n", + "# sample_tabular = X_tab_test[sample_index:sample_index+1]\n", + "# true_label = label_encoder.inverse_transform([y_test[sample_index]])[0]\n", + "\n", + "# # Make a prediction. Cast data to float32 to prevent errors.\n", + "# predictions = model.predict({'image_input': sample_image.astype(np.float32), 'tabular_input': sample_tabular.astype(np.float32)})\n", + "# predicted_label_index = np.argmax(predictions, axis=1)[0]\n", + "# predicted_label = label_encoder.inverse_transform([predicted_label_index])[0]\n", + "\n", + "# print(f\"True color: {true_label}\")\n", + "# print(f\"Predicted color: {predicted_label}\")\n", + "# print(f\"Prediction probabilities: {predictions}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2025-09-05T08:40:14.186159Z", + "iopub.status.busy": "2025-09-05T08:40:14.185809Z", + "iopub.status.idle": "2025-09-05T08:40:14.190713Z", + "shell.execute_reply": "2025-09-05T08:40:14.189879Z", + "shell.execute_reply.started": "2025-09-05T08:40:14.186125Z" + }, + "id": "cdq4FlcAOORf", + "trusted": true + }, + "outputs": [], + "source": [ + "# !pip install --upgrade ipywidgets widgetsnbextension jupyterlab-widgets\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HSJq3zsqOORf" + }, + "source": [ + "# Base VIT model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true, + "id": "LspRNt2tOORf", + "jupyter": { + "outputs_hidden": true, + "source_hidden": true + }, + "trusted": true + }, + "outputs": [], + "source": [ + "\"\"\"\n", + "Multimodal ViT + Tabular training skeleton for predicting final GIA color grade\n", + "from rough polished-diamond images + tabular metadata.\n", + "\n", + "Key upgrades for low-VRAM (Kaggle/Colab Free):\n", + " - Auto GPU-aware CONFIG (img_size, batch_size, AMP, freezing)\n", + " - Gradient Accumulation (keeps effective batch without extra VRAM)\n", + " - timm Grad-Checkpointing (activation checkpointing) when available\n", + " - OOM-retry loop (batch↓ -> freeze backbone -> image res↓ -> smaller backbone)\n", + "\n", + "This file is an updated version that fixes the timm \"strict image size\" assertion by\n", + "creating the backbone with the runtime img_size when supported, and by turning off\n", + 
"patch_embed.strict_img_size when present.\n", + "\"\"\"\n", + "\n", + "# =========================\n", + "# Pre-import memory tuning\n", + "# =========================\n", + "import os\n", + "os.environ.setdefault(\"PYTORCH_CUDA_ALLOC_CONF\", \"max_split_size_mb:128\")\n", + "\n", + "import math\n", + "import random\n", + "from pathlib import Path\n", + "from typing import List, Dict\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "from PIL import Image\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "from torch.utils.data import Dataset, DataLoader\n", + "from torchvision import transforms\n", + "import timm\n", + "\n", + "from sklearn.preprocessing import LabelEncoder, StandardScaler\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.metrics import accuracy_score, confusion_matrix\n", + "\n", + "from tqdm.auto import tqdm\n", + "\n", + "# ------------------------\n", + "# Auto GPU-aware CONFIG\n", + "# ------------------------\n", + "def get_runtime_config():\n", + " \"\"\"Return a conservative CONFIG tuned to available (v)RAM.\"\"\"\n", + " seed = 42\n", + " base = {\n", + " \"epochs\": 8,\n", + " \"lr\": 3e-4,\n", + " \"weight_decay\": 0.05,\n", + " \"num_workers\": 0, # Fixed: Use 0 workers to avoid multiprocessing issues\n", + " \"seed\": seed,\n", + " \"fusion_mode\": \"late\", # cheapest memory mode\n", + " \"pretrained\": True,\n", + " \"tab_emb_dim\": 128,\n", + " \"hidden_head_dim\": 256,\n", + " \"vpt\": False, # off by default to keep code simple\n", + " \"vpt_num_prompts\": 10,\n", + " \"accum_steps\": 1, # updated below\n", + " \"num_classes\": None,\n", + " }\n", + "\n", + " torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)\n", + "\n", + " if not torch.cuda.is_available():\n", + " base.update({\n", + " \"device\": \"cpu\",\n", + " \"use_amp\": False,\n", + " \"freeze_backbone\": True,\n", + " \"img_size\": 96,\n", + " \"batch_size\": 4,\n", + " \"backbone_name\": \"vit_tiny_patch16_224\",\n", + " })\n", + " return base\n", + "\n", + " dev = torch.cuda.get_device_properties(0)\n", + " total_gb = dev.total_memory / (1024 ** 3)\n", + "\n", + " # Defaults for ~16GB GPUs (T4/P100/V100)\n", + " base.update({\n", + " \"device\": \"cuda\",\n", + " \"use_amp\": True,\n", + " \"freeze_backbone\": True, # PEFT-style by default on small VRAM\n", + " \"img_size\": 128,\n", + " \"batch_size\": 2,\n", + " \"accum_steps\": 4, # effective batch = 8\n", + " \"backbone_name\": \"vit_small_patch16_224\",\n", + " })\n", + "\n", + " # Adjust roughly by VRAM\n", + " if total_gb <= 8:\n", + " base.update({\"img_size\": 96, \"batch_size\": 1, \"accum_steps\": 8, \"backbone_name\": \"vit_tiny_patch16_224\"})\n", + " elif total_gb <= 16:\n", + " base.update({\"img_size\": 128, \"batch_size\": 2, \"accum_steps\": 4, \"backbone_name\": \"vit_small_patch16_224\"})\n", + " elif total_gb <= 24:\n", + " base.update({\"img_size\": 160, \"batch_size\": 4, \"accum_steps\": 4, \"freeze_backbone\": False, \"backbone_name\": \"vit_base_patch16_224\"})\n", + " else:\n", + " base.update({\"img_size\": 224, \"batch_size\": 8, \"accum_steps\": 4, \"freeze_backbone\": False, \"backbone_name\": \"vit_base_patch16_224\"})\n", + "\n", + " print(f\"[CONFIG] GPU: {dev.name} | VRAM={total_gb:.1f} GB -> img={base['img_size']} bs={base['batch_size']} \"\n", + " f\"accum={base['accum_steps']} freeze={base['freeze_backbone']} backbone={base['backbone_name']} AMP={base['use_amp']}\")\n", + " return base\n", + "\n", + "CONFIG = 
get_runtime_config()\n", + "\n", + "# ------------------------\n", + "# Example dummy dataframe (replace with your full dataset)\n", + "# ------------------------\n", + "# (User-provided sample) - here for reference / quick run\n", + "# sample_data = {\n", + "# 'full_path_to_img': [\n", + "# '/kaggle/input/diamond-images-dataset/web_scraped/cushion/2106452.jpg',\n", + "# '/kaggle/input/diamond-images-dataset/web_scraped/cushion/2042329.jpg',\n", + "# '/kaggle/input/diamond-images-dataset/web_scraped/cushion/2055268.jpg',\n", + "# '/kaggle/input/diamond-images-dataset/web_scraped/cushion/2128779.jpg',\n", + "# '/kaggle/input/diamond-images-dataset/web_scraped/cushion/2103991.jpg'\n", + "# ],\n", + "# 'carat': [0.55, 0.52, 0.50, 0.50, 0.51],\n", + "# 'clarity': ['SI2', 'SI2', 'SI1', 'VS2', 'SI1'],\n", + "# 'colour': ['NVG', 'Y-Z', 'LVG', 'M', 'M'],\n", + "# 'cut': ['EX', 'EX', 'VG', 'EX', 'EX'],\n", + "# 'polish': ['VG', 'EX', 'EX', 'VG', 'VG'],\n", + "# 'symmetry': ['NG', 'FG', 'NG', 'FG', 'NG'],\n", + "# 'fluorescence': ['I', 'FG', 'I', 'I', 'I'],\n", + "# 'lab': ['GIA', 'GIA', 'GIA', 'GIA', 'GIA'],\n", + "# 'length': [4.56, 4.40, 4.87, 4.73, 4.47],\n", + "# 'width': [4.44, 4.31, 4.19, 4.28, 4.44],\n", + "# 'depth': [3.09, 2.99, 2.89, 2.84, 3.05],\n", + "# }\n", + "\n", + "# If you have a CSV, load instead: df = pd.read_csv('diamond_data.csv')\n", + "# df = pd.DataFrame(sample_data)\n", + "df = data.copy()\n", + "\n", + "\n", + "# ------------------------\n", + "# Tabular preprocessing\n", + "# ------------------------\n", + "class TabularPreprocessor:\n", + " def __init__(self, categorical_cols: List[str], numeric_cols: List[str]):\n", + " self.categorical_cols = categorical_cols\n", + " self.numeric_cols = numeric_cols\n", + " self.label_encoders: Dict[str, LabelEncoder] = {}\n", + " self.scaler = StandardScaler()\n", + "\n", + " def fit(self, df: pd.DataFrame):\n", + " for c in self.categorical_cols:\n", + " le = LabelEncoder()\n", + " df[c] = df[c].astype(str).fillna('NA')\n", + " le.fit(df[c].values)\n", + " self.label_encoders[c] = le\n", + " if len(self.numeric_cols):\n", + " self.scaler.fit(df[self.numeric_cols].astype(float).values)\n", + "\n", + " def transform(self, df: pd.DataFrame):\n", + " cat_arrays = []\n", + " for c in self.categorical_cols:\n", + " arr = df[c].astype(str).fillna('NA').values\n", + " le = self.label_encoders[c]\n", + " cat_arrays.append(le.transform(arr))\n", + " cats = np.stack(cat_arrays, axis=1) if len(cat_arrays) else np.zeros((len(df), 0), dtype=np.int64)\n", + " nums = self.scaler.transform(df[self.numeric_cols].astype(float).values) if len(self.numeric_cols) else np.zeros((len(df), 0))\n", + " return cats, nums\n", + "\n", + " def get_cardinalities(self):\n", + " return {c: len(self.label_encoders[c].classes_) for c in self.categorical_cols}\n", + "\n", + "# ------------------------\n", + "# Robust split helper\n", + "# ------------------------\n", + "def safe_train_val_split(df_in: pd.DataFrame, target_col: str, test_size: float = 0.2, seed: int = 42):\n", + " if len(df_in) <= 10:\n", + " return df_in, df_in\n", + " counts = df_in[target_col].value_counts()\n", + " if counts.min() < 2:\n", + " print(f\"[warning] Stratified split not possible (min class={int(counts.min())}). 
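+    "\n",
+    "# Illustrative sketch (commented out; the toy frame below is made up): how the\n",
+    "# TabularPreprocessor defined above is used -- fit once, then transform any split\n",
+    "# into integer codes for the categoricals plus standardized floats for the numerics.\n",
+    "#\n",
+    "# toy = pd.DataFrame({'clarity': ['SI1', 'VS2', 'SI1'], 'cut': ['EX', 'VG', 'EX'],\n",
+    "#                     'carat': [0.5, 0.7, 1.1], 'depth': [3.0, 3.2, 3.9]})\n",
+    "# pre = TabularPreprocessor(categorical_cols=['clarity', 'cut'], numeric_cols=['carat', 'depth'])\n",
+    "# pre.fit(toy)\n",
+    "# cats, nums = pre.transform(toy)\n",
+    "# print(cats.shape, nums.shape)       # (3, 2) integer codes, (3, 2) standardized floats\n",
+    "# print(pre.get_cardinalities())      # {'clarity': 2, 'cut': 2} -> embedding table sizes\n",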
Using random split.\")\n", + " return train_test_split(df_in, test_size=test_size, random_state=seed, shuffle=True)\n", + " return train_test_split(df_in, test_size=test_size, random_state=seed, stratify=df_in[target_col])\n", + "\n", + "# ------------------------\n", + "# Transforms & Dataset\n", + "# ------------------------\n", + "def get_transforms(img_size=224, train=True):\n", + " if train:\n", + " return transforms.Compose([\n", + " transforms.Resize((img_size, img_size)),\n", + " transforms.RandomHorizontalFlip(),\n", + " transforms.RandomRotation(10),\n", + " transforms.ColorJitter(0.05,0.05,0.05,0.01),\n", + " transforms.ToTensor(),\n", + " transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])\n", + " ])\n", + " else:\n", + " return transforms.Compose([\n", + " transforms.Resize((img_size, img_size)),\n", + " transforms.ToTensor(),\n", + " transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])\n", + " ])\n", + "\n", + "class MultiModalDiamondDataset(Dataset):\n", + " def __init__(self, df: pd.DataFrame, tab_preproc: TabularPreprocessor,\n", + " categorical_cols: List[str], numeric_cols: List[str],\n", + " target_col: str, train=True):\n", + " self.df = df.reset_index(drop=True)\n", + " self.transform = get_transforms(CONFIG['img_size'], train)\n", + " self.tab_preproc = tab_preproc\n", + " self.categorical_cols = categorical_cols\n", + " self.numeric_cols = numeric_cols\n", + " self.target_col = target_col\n", + " self.cats, self.nums = tab_preproc.transform(self.df)\n", + "\n", + " def __len__(self): return len(self.df)\n", + "\n", + " def __getitem__(self, idx):\n", + " row = self.df.iloc[idx]\n", + " img_path = row['full_path_to_img']\n", + " try:\n", + " img = Image.open(img_path).convert('RGB')\n", + " except Exception:\n", + " img = Image.new('RGB', (CONFIG['img_size'], CONFIG['img_size']), (0,0,0))\n", + " img = self.transform(img)\n", + "\n", + " cat = torch.tensor(self.cats[idx].astype(np.int64)) if self.cats.shape[1] else torch.empty(0, dtype=torch.long)\n", + " num = torch.tensor(self.nums[idx].astype(np.float32)) if self.nums.shape[1] else torch.empty(0, dtype=torch.float32)\n", + " target = row[self.target_col]\n", + " return {'image': img, 'cat': cat, 'num': num, 'target': torch.tensor(target, dtype=torch.long)}\n", + "\n", + "# ------------------------\n", + "# Model components\n", + "# ------------------------\n", + "class SimpleTabularEncoder(nn.Module):\n", + " def __init__(self, cardinalities: Dict[str,int], numeric_dim:int, emb_dim=128, hidden_dim=256):\n", + " super().__init__()\n", + " self.cat_cols = list(cardinalities.keys())\n", + " self.embs = nn.ModuleDict()\n", + " for k in self.cat_cols:\n", + " card = cardinalities[k]\n", + " self.embs[k] = nn.Embedding(card, min(50, (card+1)//2))\n", + " cat_total_dim = sum([self.embs[k].embedding_dim for k in self.cat_cols]) if self.cat_cols else 0\n", + " self.numeric_dim = numeric_dim\n", + " in_dim = cat_total_dim + numeric_dim\n", + " self.net = nn.Sequential(\n", + " nn.Linear(max(1, in_dim), hidden_dim),\n", + " nn.ReLU(),\n", + " nn.LayerNorm(hidden_dim),\n", + " nn.Linear(hidden_dim, emb_dim),\n", + " nn.ReLU()\n", + " )\n", + "\n", + " def forward(self, cat: torch.Tensor, num: torch.Tensor):\n", + " device = next(self.parameters()).device\n", + " if cat.shape[1] > 0:\n", + " emb_list = [self.embs[k](cat[:, i]) for i, k in enumerate(self.cat_cols)]\n", + " cat_emb = torch.cat(emb_list, dim=1)\n", + " else:\n", + " cat_emb = torch.zeros((cat.shape[0], 0), device=device)\n", 
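+    "\n",
+    "# Note (illustrative arithmetic only): each categorical column above gets its own\n",
+    "# nn.Embedding whose width follows min(50, (cardinality + 1) // 2), i.e. roughly\n",
+    "# half the number of distinct values, capped at 50:\n",
+    "#\n",
+    "# for card in (2, 8, 23, 150):\n",
+    "#     print(card, '->', min(50, (card + 1) // 2))   # 2 -> 1, 8 -> 4, 23 -> 12, 150 -> 50\n",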
+ " x = torch.cat([cat_emb, num], dim=1) if num.shape[1] > 0 else cat_emb\n", + " if x.numel() == 0:\n", + " return torch.zeros((cat.shape[0], CONFIG['tab_emb_dim']), device=device)\n", + " return self.net(x)\n", + "\n", + "class MultiModalModel(nn.Module):\n", + " def __init__(self, cfg: Dict, cardinalities: Dict[str,int], numeric_dim:int, num_classes:int):\n", + " super().__init__()\n", + " self.cfg = cfg\n", + " # CRITICAL FIX: Pass the runtime img_size to timm.create_model\n", + " # This prevents the 'Input height... doesn't match model' assertion error.\n", + " self.backbone = timm.create_model(cfg['backbone_name'], pretrained=cfg['pretrained'], num_classes=0, img_size=cfg['img_size'])\n", + "\n", + " # Additionally, disable strict image size checks if the model supports it.\n", + " # This is a good practice for robustness.\n", + " if hasattr(self.backbone, 'patch_embed') and hasattr(self.backbone.patch_embed, 'strict_img_size'):\n", + " self.backbone.patch_embed.strict_img_size = False\n", + "\n", + " # enable grad checkpointing (if available in timm model)\n", + " if hasattr(self.backbone, \"set_grad_checkpointing\"):\n", + " self.backbone.set_grad_checkpointing(enable=True)\n", + "\n", + " embed_dim = getattr(self.backbone, 'num_features', getattr(self.backbone, 'embed_dim', 768))\n", + " self.embed_dim = embed_dim\n", + "\n", + " self.tab_encoder = SimpleTabularEncoder(cardinalities, numeric_dim, emb_dim=cfg['tab_emb_dim'], hidden_dim=cfg['hidden_head_dim'])\n", + "\n", + " head_in = embed_dim + cfg['tab_emb_dim'] # late fusion\n", + " self.head = nn.Sequential(\n", + " nn.Linear(head_in, cfg['hidden_head_dim']),\n", + " nn.ReLU(),\n", + " nn.LayerNorm(cfg['hidden_head_dim']),\n", + " nn.Linear(cfg['hidden_head_dim'], num_classes)\n", + " )\n", + "\n", + " def forward(self, image: torch.Tensor, cat: torch.Tensor, num: torch.Tensor):\n", + " tab_emb = self.tab_encoder(cat, num) # (B, tab_emb_dim)\n", + " # CRITICAL FIX: Extract only the CLS token from the backbone output.\n", + " # The CLS token is at index 0 of the sequence.\n", + " img_cls = self.backbone.forward_features(image)[:, 0] # (B, D)\n", + " x = torch.cat([img_cls, tab_emb], dim=1)\n", + " return self.head(x)\n", + "\n", + "# ------------------------\n", + "# Train / Validate\n", + "# ------------------------\n", + "def train_one_epoch(model, loader, optimizer, device, epoch, scaler=None):\n", + " model.train()\n", + " losses = []\n", + " pbar = tqdm(loader, desc=f\"Train {epoch}\")\n", + " criterion = nn.CrossEntropyLoss()\n", + " accum_steps = CONFIG.get('accum_steps', 1)\n", + " optimizer.zero_grad(set_to_none=True)\n", + "\n", + " for i, batch in enumerate(pbar):\n", + " imgs = batch['image'].to(device, non_blocking=True)\n", + " cat = batch['cat'].to(device, non_blocking=True) if batch['cat'].numel() else torch.empty((imgs.size(0),0), dtype=torch.long, device=device)\n", + " num = batch['num'].to(device, non_blocking=True) if batch['num'].numel() else torch.empty((imgs.size(0),0), dtype=torch.float32, device=device)\n", + " targets = batch['target'].to(device, non_blocking=True)\n", + "\n", + " if CONFIG['use_amp'] and scaler is not None:\n", + " with torch.cuda.amp.autocast():\n", + " logits = model(imgs, cat, num)\n", + " loss = criterion(logits, targets) / accum_steps\n", + " scaler.scale(loss).backward()\n", + " else:\n", + " logits = model(imgs, cat, num)\n", + " loss = criterion(logits, targets) / accum_steps\n", + " loss.backward()\n", + "\n", + " if (i + 1) % accum_steps == 0:\n", + " if scaler is not 
None:\n", + " scaler.step(optimizer)\n", + " scaler.update()\n", + " else:\n", + " optimizer.step()\n", + " optimizer.zero_grad(set_to_none=True)\n", + "\n", + " losses.append(loss.item() * accum_steps)\n", + " pbar.set_postfix(loss=np.mean(losses))\n", + " return float(np.mean(losses))\n", + "\n", + "@torch.no_grad()\n", + "def validate(model, loader, device):\n", + " model.eval()\n", + " preds, trues = [], []\n", + " for batch in tqdm(loader, desc='Val'):\n", + " imgs = batch['image'].to(device, non_blocking=True)\n", + " cat = batch['cat'].to(device, non_blocking=True) if batch['cat'].numel() else torch.empty((imgs.size(0),0), dtype=torch.long, device=device)\n", + " num = batch['num'].to(device, non_blocking=True) if batch['num'].numel() else torch.empty((imgs.size(0),0), dtype=torch.float32, device=device)\n", + " targets = batch['target'].to(device, non_blocking=True)\n", + "\n", + " logits = model(imgs, cat, num)\n", + " preds.append(logits.argmax(dim=1).cpu().numpy())\n", + " trues.append(targets.cpu().numpy())\n", + " preds = np.concatenate(preds) if len(preds) else np.array([])\n", + " trues = np.concatenate(trues) if len(trues) else np.array([])\n", + " acc = accuracy_score(trues, preds) if trues.size else 0.0\n", + " cm = confusion_matrix(trues, preds, labels=np.unique(trues)) if trues.size else np.zeros((0,0), dtype=int)\n", + " return acc, cm\n", + "\n", + "# ------------------------\n", + "# Runner\n", + "# ------------------------\n", + "def train_pipeline(train_df: pd.DataFrame, val_df: pd.DataFrame, target_col: str = 'colour'):\n", + " categorical_cols = ['clarity', 'cut', 'polish', 'symmetry', 'fluorescence', 'lab']\n", + " numeric_cols = ['carat', 'length', 'width', 'depth']\n", + "\n", + " # --- Preprocessing Step (FIT ON ALL DATA) ---\n", + " # CRITICAL FIX: Fit the preprocessor on the combined dataset\n", + " df_combined = pd.concat([train_df, val_df], ignore_index=True)\n", + " te = LabelEncoder()\n", + " df_combined[target_col] = te.fit_transform(df_combined[target_col].astype(str))\n", + " CONFIG['num_classes'] = len(te.classes_)\n", + " tab_pre = TabularPreprocessor(categorical_cols, numeric_cols)\n", + " tab_pre.fit(df_combined)\n", + " card = tab_pre.get_cardinalities()\n", + "\n", + " # Transform the training and validation sets separately\n", + " train_df[target_col] = te.transform(train_df[target_col].astype(str))\n", + " # CRITICAL FIX: Add this line to transform the validation set target as well.\n", + " val_df[target_col] = te.transform(val_df[target_col].astype(str))\n", + "\n", + " # OOM-resilient attempts\n", + " max_attempts, attempt = 4, 0\n", + " last_exc = None\n", + "\n", + " while attempt <= max_attempts:\n", + " try:\n", + " print(f\"[run] Attempt {attempt+1} | bs={CONFIG['batch_size']} | img={CONFIG['img_size']} | \"\n", + " f\"backbone={CONFIG['backbone_name']} | freeze={CONFIG['freeze_backbone']} | accum={CONFIG['accum_steps']}\")\n", + "\n", + " pin = torch.cuda.is_available()\n", + " train_ds = MultiModalDiamondDataset(train_df, tab_pre, categorical_cols, numeric_cols, target_col, train=True)\n", + " val_ds = MultiModalDiamondDataset(val_df, tab_pre, categorical_cols, numeric_cols, target_col, train=False)\n", + " train_loader = DataLoader(train_ds, batch_size=CONFIG['batch_size'], shuffle=True,\n", + " num_workers=CONFIG['num_workers'], pin_memory=pin, persistent_workers=False)\n", + " val_loader = DataLoader(val_ds, batch_size=CONFIG['batch_size'], shuffle=False,\n", + " num_workers=CONFIG['num_workers'], pin_memory=pin, 
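+    "\n",
+    "# Illustrative sketch (commented out; toy labels are made up): the pipeline fits the\n",
+    "# target LabelEncoder on train and validation combined so that a colour grade that\n",
+    "# only occurs in the validation split still gets an index -- fitting on train alone\n",
+    "# would make transform() raise on the unseen label.\n",
+    "#\n",
+    "# tr = pd.Series(['D', 'E', 'F'])\n",
+    "# va = pd.Series(['E', 'K'])                    # 'K' never occurs in train\n",
+    "# te_demo = LabelEncoder().fit(pd.concat([tr, va]))\n",
+    "# print(te_demo.transform(va))                  # works: array([1, 3])\n",
+    "# # LabelEncoder().fit(tr).transform(va)        # would raise ValueError on 'K'\n",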
persistent_workers=False)\n", + "\n", + " device = torch.device(CONFIG['device'] if torch.cuda.is_available() else \"cpu\")\n", + " if torch.cuda.is_available():\n", + " torch.backends.cudnn.benchmark = False\n", + " torch.cuda.empty_cache()\n", + "\n", + " model = MultiModalModel(CONFIG, card, numeric_dim=len(numeric_cols), num_classes=CONFIG['num_classes'])\n", + "\n", + " # freeze backbone if requested\n", + " if CONFIG['freeze_backbone']:\n", + " for p in model.backbone.parameters(): p.requires_grad = False\n", + " for p in model.head.parameters(): p.requires_grad = True\n", + " for p in model.tab_encoder.parameters(): p.requires_grad = True\n", + "\n", + " model.to(device)\n", + "\n", + " # optimizer only for trainable params\n", + " opt = torch.optim.AdamW((p for p in model.parameters() if p.requires_grad),\n", + " lr=CONFIG['lr'], weight_decay=CONFIG['weight_decay'])\n", + " scaler = torch.cuda.amp.GradScaler() if (CONFIG['use_amp'] and torch.cuda.is_available()) else None\n", + "\n", + " best_acc = 0.0\n", + " for epoch in range(CONFIG['epochs']):\n", + " train_loss = train_one_epoch(model, train_loader, opt, device, epoch, scaler=scaler)\n", + " acc, cm = validate(model, val_loader, device)\n", + " print(f\"Epoch {epoch} | Train loss: {train_loss:.4f} | Val Acc: {acc:.4f}\")\n", + " print(cm)\n", + " if acc > best_acc:\n", + " best_acc = acc\n", + " # Save the model state and the preprocessor\n", + " torch.save({\"state_dict\": model.state_dict(),\n", + " \"classes\": te.classes_,\n", + " \"config\": CONFIG,\n", + " \"tab_pre\": tab_pre}, \"best_multimodal.pth\")\n", + " print(f\"Model saved with new best accuracy: {best_acc:.4f}\")\n", + " print(\"Best val acc:\", best_acc)\n", + " break\n", + "\n", + " except RuntimeError as e:\n", + " last_exc = e\n", + " msg = str(e).lower()\n", + " if (\"out of memory\" in msg) or (\"cuda out of memory\" in msg):\n", + " print(f\"[OOM] attempt {attempt+1}: {e}\")\n", + " if torch.cuda.is_available():\n", + " torch.cuda.empty_cache()\n", + "\n", + " # Mitigation ladder\n", + " if attempt == 0:\n", + " old = CONFIG['batch_size']\n", + " CONFIG['batch_size'] = max(1, old // 2)\n", + " CONFIG['accum_steps'] = max(CONFIG['accum_steps'], 2)\n", + " print(f\"[mitigation] batch {old} -> {CONFIG['batch_size']} | accum -> {CONFIG['accum_steps']}\")\n", + " elif attempt == 1:\n", + " CONFIG['freeze_backbone'] = True\n", + " print(\"[mitigation] freeze_backbone=True\")\n", + " elif attempt == 2:\n", + " old = CONFIG['img_size']\n", + " CONFIG['img_size'] = max(96, old // 2)\n", + " print(f\"[mitigation] img_size {old} -> {CONFIG['img_size']}\")\n", + " elif attempt == 3:\n", + " old = CONFIG['backbone_name']\n", + " CONFIG['backbone_name'] = 'vit_tiny_patch16_224' if 'small' in old else 'vit_small_patch16_224'\n", + " print(f\"[mitigation] backbone {old} -> {CONFIG['backbone_name']}\")\n", + " else:\n", + " print(\"[OOM] All mitigations exhausted.\")\n", + " raise\n", + " attempt += 1\n", + " continue\n", + " else:\n", + " raise\n", + "\n", + " else:\n", + " if last_exc is not None:\n", + " raise last_exc\n", + "\n", + "def test_model(test_df: pd.DataFrame, target_col: str = 'colour', model_path: str = \"best_multimodal.pth\", device: Optional[torch.device] = None, save_preds_csv: bool = True, top_n_mismatch: int = 50, save_dir: str = \"/kaggle/working\"):\n", + " \"\"\"\n", + " Loads the saved model and evaluates it on a provided test dataset.\n", + " Improvements:\n", + " - uses map_location when loading checkpoints for device portability\n", + " - 
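+    "\n",
+    "# Illustrative sketch (commented out; run_training is a placeholder name): the\n",
+    "# mitigation ladder above boils down to catching CUDA out-of-memory errors and\n",
+    "# retrying with progressively cheaper settings.\n",
+    "#\n",
+    "# for attempt in range(5):\n",
+    "#     try:\n",
+    "#         run_training(CONFIG)                  # placeholder for the loop above\n",
+    "#         break\n",
+    "#     except RuntimeError as e:\n",
+    "#         if 'out of memory' not in str(e).lower():\n",
+    "#             raise                             # only OOM errors are retried\n",
+    "#         torch.cuda.empty_cache()\n",
+    "#         CONFIG['batch_size'] = max(1, CONFIG['batch_size'] // 2)\n",
+    "#         CONFIG['accum_steps'] *= 2            # keep the effective batch size\n",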
safely handles unseen labels in test set (drops them and warns)\n", + " - supports running when test labels are missing (returns predictions)\n", + " - sets DataLoader pin_memory / num_workers appropriately for device\n", + " - saves plots and CSVs into a specified directory (default: /kaggle/working)\n", + " - returns (pred_labels, acc, cm, plots) when labels exist, else (pred_labels, plots)\n", + " \"\"\"\n", + " import matplotlib.pyplot as plt\n", + " os.makedirs(save_dir, exist_ok=True)\n", + "\n", + " if not os.path.exists(model_path):\n", + " print(f\"Error: Model file '{model_path}' not found. Please train the model first.\")\n", + " return None\n", + "\n", + " # Load checkpoint onto CPU first (portable)\n", + " checkpoint = torch.load(model_path, map_location='cpu')\n", + " state_dict = checkpoint[\"state_dict\"]\n", + " saved_classes = checkpoint[\"classes\"]\n", + " saved_config = checkpoint[\"config\"]\n", + " tab_pre = checkpoint[\"tab_pre\"]\n", + " save_dir: str = \"/kaggle/working\"\n", + "\n", + " print(\"--- Starting Test Evaluation ---\")\n", + " print(f\"Loading model trained on classes: {saved_classes}\")\n", + "\n", + " # Build a label encoder equivalent for mapping\n", + " te = LabelEncoder()\n", + " te.fit(saved_classes)\n", + "\n", + " test_df = test_df.copy()\n", + "\n", + " # If test contains target column, try to map to training class indices\n", + " y_true = None\n", + " if target_col in test_df.columns:\n", + " vals = test_df[target_col].astype(str).values\n", + " class_to_idx = {c: i for i, c in enumerate(saved_classes)}\n", + " mapped = np.array([class_to_idx.get(v, -1) for v in vals], dtype=int)\n", + "\n", + " if (mapped == -1).any():\n", + " n_bad = int((mapped == -1).sum())\n", + " print(f\"[warning] {n_bad} rows in test_df have labels not seen in training. 
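+    "\n",
+    "# Illustrative sketch (commented out): the checkpoint written during training bundles\n",
+    "# weights, class names, CONFIG and the fitted preprocessor, and map_location='cpu'\n",
+    "# lets a GPU-trained file load on a CPU-only machine (newer PyTorch may also need\n",
+    "# weights_only=False because the sklearn preprocessor is pickled inside).\n",
+    "#\n",
+    "# ckpt = torch.load('best_multimodal.pth', map_location='cpu')\n",
+    "# print(sorted(ckpt.keys()))               # ['classes', 'config', 'state_dict', 'tab_pre']\n",
+    "# idx_of = {c: i for i, c in enumerate(ckpt['classes'])}\n",
+    "# print(idx_of.get('NOT_A_GRADE', -1))     # labels unseen at training time map to -1\n",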
These rows will be dropped for scoring.\")\n", + " mask = mapped != -1\n", + " test_df = test_df.loc[mask].reset_index(drop=True)\n", + " mapped = mapped[mask]\n", + "\n", + " # store true integer labels for scoring\n", + " if mapped.size:\n", + " y_true = mapped\n", + " test_df[target_col] = mapped\n", + " else:\n", + " y_true = None\n", + "\n", + " # Recreate model & move to device\n", + " categorical_cols = getattr(tab_pre, \"categorical_cols\", [])\n", + " numeric_cols = getattr(tab_pre, \"numeric_cols\", [])\n", + " card = tab_pre.get_cardinalities()\n", + "\n", + " if device is None:\n", + " if torch.cuda.is_available() and saved_config.get('device', 'cpu') == 'cuda':\n", + " device = torch.device('cuda')\n", + " else:\n", + " device = torch.device('cpu')\n", + "\n", + " model = MultiModalModel(saved_config, card, numeric_dim=len(numeric_cols), num_classes=len(saved_classes))\n", + " model.load_state_dict(state_dict)\n", + " model.to(device)\n", + "\n", + " # Build dataloader (pin_memory only if using CUDA)\n", + " pin = device.type == 'cuda'\n", + " num_workers = min(4, saved_config.get('num_workers', 0)) if pin else 0\n", + " test_ds = MultiModalDiamondDataset(test_df, tab_pre, categorical_cols, numeric_cols, target_col, train=False)\n", + " test_loader = DataLoader(test_ds, batch_size=max(1, saved_config.get('batch_size', 1)),\n", + " shuffle=False, num_workers=num_workers, pin_memory=pin)\n", + "\n", + " # Run evaluation / prediction\n", + " model.eval()\n", + " all_preds = []\n", + " with torch.no_grad():\n", + " for batch in tqdm(test_loader, desc=\"Test\"):\n", + " imgs = batch['image'].to(device, non_blocking=True)\n", + " cat = batch['cat'].to(device, non_blocking=True) if batch['cat'].numel() else torch.empty((imgs.size(0), 0), dtype=torch.long, device=device)\n", + " num = batch['num'].to(device, non_blocking=True) if batch['num'].numel() else torch.empty((imgs.size(0), 0), dtype=torch.float32, device=device)\n", + "\n", + " logits = model(imgs, cat, num)\n", + " preds = logits.argmax(dim=1).cpu().numpy()\n", + " all_preds.append(preds)\n", + "\n", + " preds = np.concatenate(all_preds) if len(all_preds) else np.array([], dtype=int)\n", + " pred_labels = te.inverse_transform(preds) if preds.size else np.array([])\n", + "\n", + " compare_df = test_df.copy()\n", + " if target_col in compare_df.columns and y_true is not None:\n", + " compare_df['actual_label'] = [saved_classes[int(i)] for i in y_true]\n", + " else:\n", + " compare_df['actual_label'] = None\n", + " compare_df['predicted_label'] = list(pred_labels)[:len(compare_df)]\n", + "\n", + " plots = {}\n", + "\n", + " if y_true is not None and len(y_true):\n", + " acc = accuracy_score(y_true, preds)\n", + " cm = confusion_matrix(y_true, preds, labels=np.arange(len(saved_classes)))\n", + "\n", + " # Confusion matrix plot\n", + " fig_cm, ax = plt.subplots(figsize=(6, 6))\n", + " im = ax.imshow(cm, interpolation='nearest')\n", + " ax.set_title('Confusion Matrix')\n", + " ax.set_xlabel('Predicted')\n", + " ax.set_ylabel('Actual')\n", + "\n", + " classes = list(saved_classes)\n", + " ax.set_xticks(np.arange(len(classes)))\n", + " ax.set_yticks(np.arange(len(classes)))\n", + " ax.set_xticklabels(classes, rotation=90)\n", + " ax.set_yticklabels(classes)\n", + " fig_cm.colorbar(im, ax=ax)\n", + " plt.tight_layout()\n", + " cm_path = os.path.join(save_dir, 'confusion_matrix.png')\n", + " fig_cm.savefig(cm_path, dpi=150)\n", + " plt.close(fig_cm)\n", + " plots['confusion_matrix'] = cm_path\n", + "\n", + " # Actual vs Predicted 
counts (bar chart)\n", + " actual_counts = np.bincount(y_true, minlength=len(classes))\n", + " pred_counts = np.bincount(preds, minlength=len(classes))\n", + " x = np.arange(len(classes))\n", + " width = 0.35\n", + " fig_bar, ax = plt.subplots(figsize=(max(6, len(classes)*0.3), 4))\n", + " ax.bar(x - width/2, actual_counts, width, label='Actual')\n", + " ax.bar(x + width/2, pred_counts, width, label='Predicted')\n", + " ax.set_xlabel('Class')\n", + " ax.set_ylabel('Count')\n", + " ax.set_xticks(x)\n", + " ax.set_xticklabels(classes, rotation=90)\n", + " ax.legend()\n", + " plt.tight_layout()\n", + " bar_path = os.path.join(save_dir, 'actual_vs_pred_counts.png')\n", + " fig_bar.savefig(bar_path, dpi=150)\n", + " plt.close(fig_bar)\n", + " plots['actual_vs_pred_counts'] = bar_path\n", + "\n", + " print(f\"Final Test Accuracy: {acc:.4f}\")\n", + " print(\"Test Confusion Matrix:\")\n", + " print(cm)\n", + "\n", + " if 'full_path_to_img' in test_df.columns:\n", + " mismatch_mask = (preds != y_true)\n", + " mismatch_idxs = np.where(mismatch_mask)[0]\n", + " n_show = min(top_n_mismatch, len(mismatch_idxs))\n", + " mismatch_list = []\n", + " for idx in mismatch_idxs[:n_show]:\n", + " mismatch_list.append({\n", + " 'img_path': test_df.iloc[idx]['full_path_to_img'],\n", + " 'actual': saved_classes[int(y_true[idx])],\n", + " 'predicted': saved_classes[int(preds[idx])]})\n", + " mismatch_df = pd.DataFrame(mismatch_list)\n", + " mismatch_csv = os.path.join(save_dir, 'mismatches.csv')\n", + " mismatch_df.to_csv(mismatch_csv, index=False)\n", + " plots['mismatches_csv'] = mismatch_csv\n", + "\n", + " if save_preds_csv:\n", + " preds_csv = os.path.join(save_dir, 'predictions_with_actuals.csv')\n", + " compare_df.to_csv(preds_csv, index=False)\n", + " plots['preds_csv'] = preds_csv\n", + "\n", + " return pred_labels, float(acc), cm, plots\n", + "\n", + " if save_preds_csv:\n", + " preds_csv = os.path.join(save_dir, 'predictions.csv')\n", + " compare_df.to_csv(preds_csv, index=False)\n", + " plots['preds_csv'] = preds_csv\n", + "\n", + " print(\"No true labels available (or all were unseen). Returning predictions only.\")\n", + " return pred_labels, plots\n", + "\n", + "\n", + "\n", + "# ------------------------\n", + "# Run example\n", + "# ------------------------\n", + "if __name__ == '__main__':\n", + " # 1. Split the data once\n", + " train_df, val_df = safe_train_val_split(df.copy(), 'colour', test_size=0.2, seed=CONFIG['seed'])\n", + "\n", + " # 2. Run the training pipeline\n", + " train_pipeline(train_df, val_df)\n", + "\n", + " # 3. 
Now, call the test function on the validation set as a proxy\n", + " # In a real scenario, you would load a separate test.csv file here.\n", + " # For example: test_df_real = pd.read_csv(\"path/to/your/test_data.csv\")\n", + " preds, acc, cm, plots = test_model(test_df_sample, save_preds_csv=True)\n", + " from PIL import Image\n", + " Image.open(plots['confusion_matrix']).show()\n", + " pd.read_csv(plots['preds_csv']).head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2025-09-05T12:34:20.105670Z", + "iopub.status.busy": "2025-09-05T12:34:20.105316Z", + "iopub.status.idle": "2025-09-05T12:34:20.127755Z", + "shell.execute_reply": "2025-09-05T12:34:20.127074Z", + "shell.execute_reply.started": "2025-09-05T12:34:20.105642Z" + }, + "id": "IDFlK0fxOORh", + "jupyter": { + "source_hidden": true + }, + "trusted": true + }, + "outputs": [], + "source": [ + "def test_model(test_df: pd.DataFrame, target_col: str = 'colour', model_path: str = \"best_multimodal.pth\", device: Optional[torch.device] = None, save_preds_csv: bool = True, top_n_mismatch: int = 50, save_dir: str = \"/kaggle/working\"):\n", + " \"\"\"\n", + " Loads the saved model and evaluates it on a provided test dataset.\n", + " Improvements:\n", + " - uses map_location when loading checkpoints for device portability\n", + " - safely handles unseen labels in test set (drops them and warns)\n", + " - supports running when test labels are missing (returns predictions)\n", + " - sets DataLoader pin_memory / num_workers appropriately for device\n", + " - saves plots and CSVs into a specified directory (default: /kaggle/working)\n", + " - returns (pred_labels, acc, cm, plots) when labels exist, else (pred_labels, plots)\n", + " \"\"\"\n", + " import matplotlib.pyplot as plt\n", + " os.makedirs(save_dir, exist_ok=True)\n", + "\n", + " if not os.path.exists(model_path):\n", + " print(f\"Error: Model file '{model_path}' not found. Please train the model first.\")\n", + " return None\n", + "\n", + " # Load checkpoint onto CPU first (portable)\n", + " checkpoint = torch.load(model_path, map_location='cpu')\n", + " state_dict = checkpoint[\"state_dict\"]\n", + " saved_classes = checkpoint[\"classes\"]\n", + " saved_config = checkpoint[\"config\"]\n", + " tab_pre = checkpoint[\"tab_pre\"]\n", + " save_dir: str = \"/kaggle/working\"\n", + "\n", + " print(\"--- Starting Test Evaluation ---\")\n", + " print(f\"Loading model trained on classes: {saved_classes}\")\n", + "\n", + " # Build a label encoder equivalent for mapping\n", + " te = LabelEncoder()\n", + " te.fit(saved_classes)\n", + "\n", + " test_df = test_df.copy()\n", + "\n", + " # If test contains target column, try to map to training class indices\n", + " y_true = None\n", + " if target_col in test_df.columns:\n", + " vals = test_df[target_col].astype(str).values\n", + " class_to_idx = {c: i for i, c in enumerate(saved_classes)}\n", + " mapped = np.array([class_to_idx.get(v, -1) for v in vals], dtype=int)\n", + "\n", + " if (mapped == -1).any():\n", + " n_bad = int((mapped == -1).sum())\n", + " print(f\"[warning] {n_bad} rows in test_df have labels not seen in training. 
These rows will be dropped for scoring.\")\n", + " mask = mapped != -1\n", + " test_df = test_df.loc[mask].reset_index(drop=True)\n", + " mapped = mapped[mask]\n", + "\n", + " # store true integer labels for scoring\n", + " if mapped.size:\n", + " y_true = mapped\n", + " test_df[target_col] = mapped\n", + " else:\n", + " y_true = None\n", + "\n", + " # Recreate model & move to device\n", + " categorical_cols = getattr(tab_pre, \"categorical_cols\", [])\n", + " numeric_cols = getattr(tab_pre, \"numeric_cols\", [])\n", + " card = tab_pre.get_cardinalities()\n", + "\n", + " if device is None:\n", + " if torch.cuda.is_available() and saved_config.get('device', 'cpu') == 'cuda':\n", + " device = torch.device('cuda')\n", + " else:\n", + " device = torch.device('cpu')\n", + "\n", + " model = MultiModalModel(saved_config, card, numeric_dim=len(numeric_cols), num_classes=len(saved_classes))\n", + " model.load_state_dict(state_dict)\n", + " model.to(device)\n", + "\n", + " # Build dataloader (pin_memory only if using CUDA)\n", + " pin = device.type == 'cuda'\n", + " num_workers = min(4, saved_config.get('num_workers', 0)) if pin else 0\n", + " test_ds = MultiModalDiamondDataset(test_df, tab_pre, categorical_cols, numeric_cols, target_col, train=False)\n", + " test_loader = DataLoader(test_ds, batch_size=max(1, saved_config.get('batch_size', 1)),\n", + " shuffle=False, num_workers=num_workers, pin_memory=pin)\n", + "\n", + " # Run evaluation / prediction\n", + " model.eval()\n", + " all_preds = []\n", + " with torch.no_grad():\n", + " for batch in tqdm(test_loader, desc=\"Test\"):\n", + " imgs = batch['image'].to(device, non_blocking=True)\n", + " cat = batch['cat'].to(device, non_blocking=True) if batch['cat'].numel() else torch.empty((imgs.size(0), 0), dtype=torch.long, device=device)\n", + " num = batch['num'].to(device, non_blocking=True) if batch['num'].numel() else torch.empty((imgs.size(0), 0), dtype=torch.float32, device=device)\n", + "\n", + " logits = model(imgs, cat, num)\n", + " preds = logits.argmax(dim=1).cpu().numpy()\n", + " all_preds.append(preds)\n", + "\n", + " preds = np.concatenate(all_preds) if len(all_preds) else np.array([], dtype=int)\n", + " pred_labels = te.inverse_transform(preds) if preds.size else np.array([])\n", + "\n", + " compare_df = test_df.copy()\n", + " if target_col in compare_df.columns and y_true is not None:\n", + " compare_df['actual_label'] = [saved_classes[int(i)] for i in y_true]\n", + " else:\n", + " compare_df['actual_label'] = None\n", + " compare_df['predicted_label'] = list(pred_labels)[:len(compare_df)]\n", + "\n", + " plots = {}\n", + "\n", + " if y_true is not None and len(y_true):\n", + " acc = accuracy_score(y_true, preds)\n", + " cm = confusion_matrix(y_true, preds, labels=np.arange(len(saved_classes)))\n", + "\n", + " # Confusion matrix plot\n", + " fig_cm, ax = plt.subplots(figsize=(6, 6))\n", + " im = ax.imshow(cm, interpolation='nearest')\n", + " ax.set_title('Confusion Matrix')\n", + " ax.set_xlabel('Predicted')\n", + " ax.set_ylabel('Actual')\n", + "\n", + " classes = list(saved_classes)\n", + " ax.set_xticks(np.arange(len(classes)))\n", + " ax.set_yticks(np.arange(len(classes)))\n", + " ax.set_xticklabels(classes, rotation=90)\n", + " ax.set_yticklabels(classes)\n", + " fig_cm.colorbar(im, ax=ax)\n", + " plt.tight_layout()\n", + " cm_path = os.path.join(save_dir, 'confusion_matrix.png')\n", + " fig_cm.savefig(cm_path, dpi=150)\n", + " plt.close(fig_cm)\n", + " plots['confusion_matrix'] = cm_path\n", + "\n", + " # Actual vs Predicted 
counts (bar chart)\n", + " actual_counts = np.bincount(y_true, minlength=len(classes))\n", + " pred_counts = np.bincount(preds, minlength=len(classes))\n", + " x = np.arange(len(classes))\n", + " width = 0.35\n", + " fig_bar, ax = plt.subplots(figsize=(max(6, len(classes)*0.3), 4))\n", + " ax.bar(x - width/2, actual_counts, width, label='Actual')\n", + " ax.bar(x + width/2, pred_counts, width, label='Predicted')\n", + " ax.set_xlabel('Class')\n", + " ax.set_ylabel('Count')\n", + " ax.set_xticks(x)\n", + " ax.set_xticklabels(classes, rotation=90)\n", + " ax.legend()\n", + " plt.tight_layout()\n", + " bar_path = os.path.join(save_dir, 'actual_vs_pred_counts.png')\n", + " fig_bar.savefig(bar_path, dpi=150)\n", + " plt.close(fig_bar)\n", + " plots['actual_vs_pred_counts'] = bar_path\n", + "\n", + " print(f\"Final Test Accuracy: {acc:.4f}\")\n", + " print(\"Test Confusion Matrix:\")\n", + " print(cm)\n", + "\n", + " if 'full_path_to_img' in test_df.columns:\n", + " mismatch_mask = (preds != y_true)\n", + " mismatch_idxs = np.where(mismatch_mask)[0]\n", + " n_show = min(top_n_mismatch, len(mismatch_idxs))\n", + " mismatch_list = []\n", + " for idx in mismatch_idxs[:n_show]:\n", + " mismatch_list.append({\n", + " 'img_path': test_df.iloc[idx]['full_path_to_img'],\n", + " 'actual': saved_classes[int(y_true[idx])],\n", + " 'predicted': saved_classes[int(preds[idx])]})\n", + " mismatch_df = pd.DataFrame(mismatch_list)\n", + " mismatch_csv = os.path.join(save_dir, 'mismatches.csv')\n", + " mismatch_df.to_csv(mismatch_csv, index=False)\n", + " plots['mismatches_csv'] = mismatch_csv\n", + "\n", + " if save_preds_csv:\n", + " preds_csv = os.path.join(save_dir, 'predictions_with_actuals.csv')\n", + " compare_df.to_csv(preds_csv, index=False)\n", + " plots['preds_csv'] = preds_csv\n", + "\n", + " return pred_labels, float(acc), cm, plots\n", + "\n", + " if save_preds_csv:\n", + " preds_csv = os.path.join(save_dir, 'predictions.csv')\n", + " compare_df.to_csv(preds_csv, index=False)\n", + " plots['preds_csv'] = preds_csv\n", + "\n", + " print(\"No true labels available (or all were unseen). Returning predictions only.\")\n", + " return pred_labels, plots\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "referenced_widgets": [ + "01ecf103739a4fc481b9469879eaf2a4" + ] + }, + "collapsed": true, + "execution": { + "iopub.execute_input": "2025-09-05T12:34:22.684362Z", + "iopub.status.busy": "2025-09-05T12:34:22.684029Z", + "iopub.status.idle": "2025-09-05T12:36:15.923934Z", + "shell.execute_reply": "2025-09-05T12:36:15.923059Z", + "shell.execute_reply.started": "2025-09-05T12:34:22.684334Z" + }, + "id": "cpEp1NPSOORi", + "jupyter": { + "outputs_hidden": true, + "source_hidden": true + }, + "outputId": "afd2b447-7223-42d7-f61f-b3ef52acae4f", + "trusted": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[warning] Stratified split not possible (min class=1). 
Using random split.\n", + "--- Starting Test Evaluation ---\n", + "Loading model trained on classes: ['BLUE' 'D' 'D:P:BN' 'E' 'F' 'FANCY' 'FC:P' 'G' 'H' 'I' 'I:P' 'J' 'K' 'L'\n", + " 'M' 'N' 'O-P' 'Q-R' 'S-T' 'U-V' 'V:B' 'W-X' 'Y-Z']\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "01ecf103739a4fc481b9469879eaf2a4", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Test: 0%| | 0/4877 [00:00\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
path_to_img | stock_number | shape | carat | clarity | colour | cut | polish | symmetry | fluorescence | lab | length | width | depth | full_path_to_img | actual_label | predicted_label
0 | web_scraped/emerald/2132934.jpg | 2132934 | emerald | 0.53 | VVS2 | 3 | EX | EX | VG | N | GIA | 5.23 | 3.86 | 2.63 | /kaggle/input/diamond-images-dataset/web_scrap... | E | F
1 | web_scraped/princess/2101219.jpg | 2101219 | princess | 0.50 | VS1 | 3 | EX | EX | VG | N | GIA | 4.34 | 4.25 | 3.04 | /kaggle/input/diamond-images-dataset/web_scrap... | E | E
2 | web_scraped/round/2127275.jpg | 2127275 | round | 0.53 | VS1 | 1 | EX | EX | EX | N | GIA | 5.26 | 5.28 | 3.18 | /kaggle/input/diamond-images-dataset/web_scrap... | D | F
3 | web_scraped/princess/223350-231.jpg | 223350-231 | princess | 1.20 | VS2 | 3 | GD | EX | VG | N | GIA | 5.64 | 5.47 | 4.29 | /kaggle/input/diamond-images-dataset/web_scrap... | E | H
4 | web_scraped/round/2087662.jpg | 2087662 | round | 0.70 | SI1 | 12 | VG | VG | VG | N | GIA | 5.54 | 5.58 | 3.58 | /kaggle/input/diamond-images-dataset/web_scrap... | K | G
\n", + "" + ], + "text/plain": [ + " path_to_img stock_number shape carat clarity \\\n", + "0 web_scraped/emerald/2132934.jpg 2132934 emerald 0.53 VVS2 \n", + "1 web_scraped/princess/2101219.jpg 2101219 princess 0.50 VS1 \n", + "2 web_scraped/round/2127275.jpg 2127275 round 0.53 VS1 \n", + "3 web_scraped/princess/223350-231.jpg 223350-231 princess 1.20 VS2 \n", + "4 web_scraped/round/2087662.jpg 2087662 round 0.70 SI1 \n", + "\n", + " colour cut polish symmetry fluorescence lab length width depth \\\n", + "0 3 EX EX VG N GIA 5.23 3.86 2.63 \n", + "1 3 EX EX VG N GIA 4.34 4.25 3.04 \n", + "2 1 EX EX EX N GIA 5.26 5.28 3.18 \n", + "3 3 GD EX VG N GIA 5.64 5.47 4.29 \n", + "4 12 VG VG VG N GIA 5.54 5.58 3.58 \n", + "\n", + " full_path_to_img actual_label \\\n", + "0 /kaggle/input/diamond-images-dataset/web_scrap... E \n", + "1 /kaggle/input/diamond-images-dataset/web_scrap... E \n", + "2 /kaggle/input/diamond-images-dataset/web_scrap... D \n", + "3 /kaggle/input/diamond-images-dataset/web_scrap... E \n", + "4 /kaggle/input/diamond-images-dataset/web_scrap... K \n", + "\n", + " predicted_label \n", + "0 F \n", + "1 E \n", + "2 F \n", + "3 H \n", + "4 G " + ] + }, + "execution_count": 45, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_df, val_df = safe_train_val_split(df.copy(), 'colour', test_size=0.2, seed=CONFIG['seed'])\n", + "test_df_sample = val_df.copy()\n", + "\n", + "# 3. Now, call the test function\n", + "preds, acc, cm, plots = test_model(test_df_sample, save_preds_csv=True)\n", + "from PIL import Image\n", + "Image.open(plots['confusion_matrix']).show()\n", + "pd.read_csv(plots['preds_csv']).head()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3OnCjcbPOORi" + }, + "source": [ + "# VIT with 2-layer cross-attention fusion block" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "referenced_widgets": [ + "e117f15be163406d8c50c85ae4b9f116", + "4e79276d97474db99d5498313885e7bd", + "d77118d6fe3a4cc6be5a414c1351aad3", + "6f8846bccbc24e539dba1a4ed01fe579", + "735a6a59fe50432098afd0a96944843e", + "c2caf28562a748459eae02d4dc8bb127", + "5470677e9def45369325697ceb179350", + "f0ea30505fe44e279a95efda845e2050", + "09a097efe7904cff9541988a11b8cd9e", + "979e0c2aacd24681b997d0412a4a14ec", + "00ea6c874cff4cceae8ba712aedf500f", + "d4ee91254e5742239712f57da465bda9", + "41719b398f7541d78dcb119601174348", + "3f64a0066a204aa6a48f0478886e7174", + "eadef35ba1a34b2e8f2723754cdc67a0", + "a08dac5b49324eb5be413064626cebcf", + "4e83f8ffa346481da0177ac368e7bb83", + "8604ccb723e14f1289e59d2a348e05b9", + "f4ab28af52bc47e9b4a1db0cf5c544d6", + "3b72dcaa3ad14ab096a6c0a8dfe10543", + "5dd187b4e6df4fb28fbf42e06d028266" + ] + }, + "collapsed": true, + "execution": { + "iopub.execute_input": "2025-09-08T05:21:35.119639Z", + "iopub.status.busy": "2025-09-08T05:21:35.119286Z", + "iopub.status.idle": "2025-09-08T07:45:52.693255Z", + "shell.execute_reply": "2025-09-08T07:45:52.692352Z", + "shell.execute_reply.started": "2025-09-08T05:21:35.119612Z" + }, + "id": "wv6bB9VOOORi", + "jupyter": { + "outputs_hidden": true, + "source_hidden": true + }, + "outputId": "43714522-88fd-4953-a807-a7a89a77e9b8", + "trusted": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[CONFIG] GPU: Tesla T4 | VRAM=14.7 GB -> img=128 bs=2 accum=4 freeze=True backbone=vit_small_patch16_224 AMP=True\n", + "[warning] Stratified split not possible (min class=1). 
Using random split.\n", + "[run] Attempt 1 | bs=2 | img=128 | backbone=vit_small_patch16_224 | freeze=True | accum=4\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e117f15be163406d8c50c85ae4b9f116", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Downloading model.safetensors: 0%| | 0.00/88.2M [00:00 freeze backbone -> image res↓ -> smaller backbone)\n", + "\n", + "This version incorporates a 2-layer Cross-Attention fusion block for more\n", + "sophisticated interaction between image and tabular features, aiming for\n", + "higher accuracy.\n", + "\"\"\"\n", + "\n", + "# =========================\n", + "# Pre-import memory tuning\n", + "# =========================\n", + "import os\n", + "os.environ.setdefault(\"PYTORCH_CUDA_ALLOC_CONF\", \"max_split_size_mb:128\")\n", + "\n", + "import math\n", + "import random\n", + "from pathlib import Path\n", + "from typing import List, Dict, Optional\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "from PIL import Image\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "from torch.utils.data import Dataset, DataLoader\n", + "from torchvision import transforms\n", + "import timm\n", + "\n", + "from sklearn.preprocessing import LabelEncoder, StandardScaler\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.metrics import accuracy_score, confusion_matrix\n", + "\n", + "from tqdm.auto import tqdm\n", + "\n", + "# ------------------------\n", + "# Auto GPU-aware CONFIG\n", + "# ------------------------\n", + "def get_runtime_config():\n", + " \"\"\"Return a conservative CONFIG tuned to available (v)RAM.\"\"\"\n", + " seed = 42\n", + " base = {\n", + " \"epochs\": 10, # Increased epochs for more complex fusion\n", + " \"lr\": 1e-4,\n", + " \"weight_decay\": 0.05,\n", + " \"num_workers\": 0,\n", + " \"seed\": seed,\n", + " \"fusion_mode\": \"cross_attention\", # Changed to cross-attention\n", + " \"pretrained\": True,\n", + " \"tab_emb_dim\": 128,\n", + " \"hidden_head_dim\": 256,\n", + " \"vpt\": False,\n", + " \"vpt_num_prompts\": 10,\n", + " \"accum_steps\": 1,\n", + " \"num_classes\": None,\n", + " # Cross-attention specific\n", + " \"num_fusion_layers\": 2,\n", + " \"num_attention_heads\": 8,\n", + " }\n", + "\n", + " torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)\n", + "\n", + " if not torch.cuda.is_available():\n", + " base.update({\n", + " \"device\": \"cpu\",\n", + " \"use_amp\": False,\n", + " \"freeze_backbone\": True,\n", + " \"img_size\": 96,\n", + " \"batch_size\": 4,\n", + " \"backbone_name\": \"vit_tiny_patch16_224\",\n", + " })\n", + " return base\n", + "\n", + " dev = torch.cuda.get_device_properties(0)\n", + " total_gb = dev.total_memory / (1024 ** 3)\n", + "\n", + " # Defaults for ~16GB GPUs (T4/P100/V100)\n", + " base.update({\n", + " \"device\": \"cuda\",\n", + " \"use_amp\": True,\n", + " \"freeze_backbone\": True,\n", + " \"img_size\": 128,\n", + " \"batch_size\": 2,\n", + " \"accum_steps\": 4,\n", + " \"backbone_name\": \"vit_small_patch16_224\",\n", + " })\n", + "\n", + " if total_gb <= 8:\n", + " base.update({\"img_size\": 96, \"batch_size\": 1, \"accum_steps\": 8, \"backbone_name\": \"vit_tiny_patch16_224\"})\n", + " elif total_gb <= 16:\n", + " base.update({\"img_size\": 128, \"batch_size\": 2, \"accum_steps\": 4, \"backbone_name\": \"vit_small_patch16_224\"})\n", + " elif total_gb <= 24:\n", + " base.update({\"img_size\": 160, \"batch_size\": 4, \"accum_steps\": 4, \"freeze_backbone\": 
False, \"backbone_name\": \"vit_base_patch16_224\"})\n", + " else:\n", + " base.update({\"img_size\": 224, \"batch_size\": 8, \"accum_steps\": 4, \"freeze_backbone\": False, \"backbone_name\": \"vit_base_patch16_224\"})\n", + "\n", + " print(f\"[CONFIG] GPU: {dev.name} | VRAM={total_gb:.1f} GB -> img={base['img_size']} bs={base['batch_size']} \"\n", + " f\"accum={base['accum_steps']} freeze={base['freeze_backbone']} backbone={base['backbone_name']} AMP={base['use_amp']}\")\n", + " return base\n", + "\n", + "CONFIG = get_runtime_config()\n", + "\n", + "# ------------------------\n", + "# Example dummy dataframe (replace with your full dataset)\n", + "# ------------------------\n", + "# Create a larger dummy dataset for more robust training\n", + "# def create_dummy_data(num_samples=200):\n", + "# img_dir = Path(\"./dummy_images\")\n", + "# img_dir.mkdir(exist_ok=True)\n", + "\n", + "# dummy_image_paths = []\n", + "# for i in range(num_samples):\n", + "# path = img_dir / f\"dummy_{i}.png\"\n", + "# Image.new('RGB', (224, 224), color = (random.randint(0,255), random.randint(0,255), random.randint(0,255))).save(path)\n", + "# dummy_image_paths.append(str(path))\n", + "\n", + "# data = {\n", + "# 'full_path_to_img': dummy_image_paths,\n", + "# 'carat': np.random.uniform(0.3, 2.5, num_samples),\n", + "# 'clarity': np.random.choice(['SI2', 'SI1', 'VS2', 'VS1', 'VVS2'], num_samples),\n", + "# 'colour': np.random.choice(['D', 'E', 'F', 'G', 'H', 'I', 'J'], num_samples),\n", + "# 'cut': np.random.choice(['EX', 'VG', 'G'], num_samples),\n", + "# 'polish': np.random.choice(['EX', 'VG', 'G'], num_samples),\n", + "# 'symmetry': np.random.choice(['EX', 'VG', 'G'], num_samples),\n", + "# 'fluorescence': np.random.choice(['N', 'F', 'M', 'S'], num_samples),\n", + "# 'lab': ['GIA'] * num_samples,\n", + "# 'length': np.random.uniform(4.0, 9.0, num_samples),\n", + "# 'width': np.random.uniform(4.0, 9.0, num_samples),\n", + "# 'depth': np.random.uniform(2.5, 5.5, num_samples),\n", + "# }\n", + "# return pd.DataFrame(data)\n", + "\n", + "df = data.copy()\n", + "\n", + "\n", + "# ------------------------\n", + "# Tabular preprocessing\n", + "# ------------------------\n", + "class TabularPreprocessor:\n", + " def __init__(self, categorical_cols: List[str], numeric_cols: List[str]):\n", + " self.categorical_cols = categorical_cols\n", + " self.numeric_cols = numeric_cols\n", + " self.label_encoders: Dict[str, LabelEncoder] = {}\n", + " self.scaler = StandardScaler()\n", + "\n", + " def fit(self, df: pd.DataFrame):\n", + " for c in self.categorical_cols:\n", + " le = LabelEncoder()\n", + " df[c] = df[c].astype(str).fillna('NA')\n", + " le.fit(df[c].values)\n", + " self.label_encoders[c] = le\n", + " if len(self.numeric_cols):\n", + " self.scaler.fit(df[self.numeric_cols].astype(float).values)\n", + "\n", + " def transform(self, df: pd.DataFrame):\n", + " cat_arrays = []\n", + " for c in self.categorical_cols:\n", + " arr = df[c].astype(str).fillna('NA').values\n", + " le = self.label_encoders[c]\n", + " cat_arrays.append(le.transform(arr))\n", + " cats = np.stack(cat_arrays, axis=1) if len(cat_arrays) else np.zeros((len(df), 0), dtype=np.int64)\n", + " nums = self.scaler.transform(df[self.numeric_cols].astype(float).values) if len(self.numeric_cols) else np.zeros((len(df), 0))\n", + " return cats, nums\n", + "\n", + " def get_cardinalities(self):\n", + " return {c: len(self.label_encoders[c].classes_) for c in self.categorical_cols}\n", + "\n", + "# ------------------------\n", + "# Robust split helper\n", + "# 
------------------------\n", + "def safe_train_val_split(df_in: pd.DataFrame, target_col: str, test_size: float = 0.2, seed: int = 42):\n", + " if len(df_in) <= 10:\n", + " return df_in, df_in\n", + " counts = df_in[target_col].value_counts()\n", + " if counts.min() < 2:\n", + " print(f\"[warning] Stratified split not possible (min class={int(counts.min())}). Using random split.\")\n", + " return train_test_split(df_in, test_size=test_size, random_state=seed, shuffle=True)\n", + " return train_test_split(df_in, test_size=test_size, random_state=seed, stratify=df_in[target_col])\n", + "\n", + "# ------------------------\n", + "# Transforms & Dataset\n", + "# ------------------------\n", + "def get_transforms(img_size=224, train=True):\n", + " if train:\n", + " return transforms.Compose([\n", + " transforms.Resize((img_size, img_size)),\n", + " transforms.RandomHorizontalFlip(),\n", + " transforms.RandomRotation(10),\n", + " transforms.ColorJitter(0.05,0.05,0.05,0.01),\n", + " transforms.ToTensor(),\n", + " transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])\n", + " ])\n", + " else:\n", + " return transforms.Compose([\n", + " transforms.Resize((img_size, img_size)),\n", + " transforms.ToTensor(),\n", + " transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])\n", + " ])\n", + "\n", + "class MultiModalDiamondDataset(Dataset):\n", + " def __init__(self, df: pd.DataFrame, tab_preproc: TabularPreprocessor,\n", + " categorical_cols: List[str], numeric_cols: List[str],\n", + " target_col: str, train=True):\n", + " self.df = df.reset_index(drop=True)\n", + " self.transform = get_transforms(CONFIG['img_size'], train)\n", + " self.tab_preproc = tab_preproc\n", + " self.categorical_cols = categorical_cols\n", + " self.numeric_cols = numeric_cols\n", + " self.target_col = target_col\n", + " self.cats, self.nums = tab_preproc.transform(self.df)\n", + "\n", + " def __len__(self): return len(self.df)\n", + "\n", + " def __getitem__(self, idx):\n", + " row = self.df.iloc[idx]\n", + " img_path = row['full_path_to_img']\n", + " try:\n", + " img = Image.open(img_path).convert('RGB')\n", + " except Exception:\n", + " img = Image.new('RGB', (CONFIG['img_size'], CONFIG['img_size']), (0,0,0))\n", + " img = self.transform(img)\n", + "\n", + " cat = torch.tensor(self.cats[idx].astype(np.int64)) if self.cats.shape[1] else torch.empty(0, dtype=torch.long)\n", + " num = torch.tensor(self.nums[idx].astype(np.float32)) if self.nums.shape[1] else torch.empty(0, dtype=torch.float32)\n", + " target = row[self.target_col]\n", + " return {'image': img, 'cat': cat, 'num': num, 'target': torch.tensor(target, dtype=torch.long)}\n", + "\n", + "# ------------------------\n", + "# Model components\n", + "# ------------------------\n", + "class SimpleTabularEncoder(nn.Module):\n", + " def __init__(self, cardinalities: Dict[str,int], numeric_dim:int, emb_dim=128, hidden_dim=256):\n", + " super().__init__()\n", + " self.cat_cols = list(cardinalities.keys())\n", + " self.embs = nn.ModuleDict()\n", + " for k in self.cat_cols:\n", + " card = cardinalities[k]\n", + " self.embs[k] = nn.Embedding(card, min(50, (card+1)//2))\n", + " cat_total_dim = sum([self.embs[k].embedding_dim for k in self.cat_cols]) if self.cat_cols else 0\n", + " self.numeric_dim = numeric_dim\n", + " in_dim = cat_total_dim + numeric_dim\n", + " self.net = nn.Sequential(\n", + " nn.Linear(max(1, in_dim), hidden_dim),\n", + " nn.ReLU(),\n", + " nn.LayerNorm(hidden_dim),\n", + " nn.Linear(hidden_dim, emb_dim),\n", + " )\n", + "\n", + 
" def forward(self, cat: torch.Tensor, num: torch.Tensor):\n", + " device = next(self.parameters()).device\n", + " if cat.shape[1] > 0:\n", + " emb_list = [self.embs[k](cat[:, i]) for i, k in enumerate(self.cat_cols)]\n", + " cat_emb = torch.cat(emb_list, dim=1)\n", + " else:\n", + " cat_emb = torch.zeros((cat.shape[0], 0), device=device)\n", + " x = torch.cat([cat_emb, num], dim=1) if num.shape[1] > 0 else cat_emb\n", + " if x.numel() == 0:\n", + " return torch.zeros((cat.shape[0], CONFIG['tab_emb_dim']), device=device)\n", + " return self.net(x)\n", + "\n", + "class CrossAttention(nn.Module):\n", + " def __init__(self, query_dim, context_dim, heads=8, dim_head=64, dropout=0.):\n", + " super().__init__()\n", + " inner_dim = dim_head * heads\n", + " self.scale = dim_head ** -0.5\n", + " self.heads = heads\n", + "\n", + " self.to_q = nn.Linear(query_dim, inner_dim, bias=False)\n", + " self.to_k = nn.Linear(context_dim, inner_dim, bias=False)\n", + " self.to_v = nn.Linear(context_dim, inner_dim, bias=False)\n", + " self.to_out = nn.Sequential(\n", + " nn.Linear(inner_dim, query_dim),\n", + " nn.Dropout(dropout)\n", + " )\n", + "\n", + " def forward(self, x, context):\n", + " q = self.to_q(x)\n", + " k = self.to_k(context)\n", + " v = self.to_v(context)\n", + "\n", + " # Reshape for multi-head attention\n", + " q, k, v = map(lambda t: t.view(t.shape[0], -1, self.heads, t.shape[-1] // self.heads).transpose(1, 2), (q, k, v))\n", + "\n", + " sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale\n", + " attn = sim.softmax(dim=-1)\n", + "\n", + " out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)\n", + " out = out.transpose(1, 2).reshape(out.shape[0], -1, self.heads * (out.shape[-1]))\n", + " return self.to_out(out)\n", + "\n", + "class CrossAttentionFusionBlock(nn.Module):\n", + " def __init__(self, tab_dim, img_dim, num_layers=2, heads=8):\n", + " super().__init__()\n", + " self.layers = nn.ModuleList([])\n", + " for _ in range(num_layers):\n", + " self.layers.append(nn.ModuleList([\n", + " nn.LayerNorm(tab_dim),\n", + " nn.LayerNorm(img_dim),\n", + " CrossAttention(tab_dim, img_dim, heads=heads, dim_head=img_dim//heads)\n", + " ]))\n", + "\n", + " def forward(self, tab_features, img_features):\n", + " tab_features = tab_features.unsqueeze(1) # Add sequence dimension\n", + " for norm_tab, norm_img, attn in self.layers:\n", + " tab_features = attn(norm_tab(tab_features), norm_img(img_features)) + tab_features\n", + " return tab_features.squeeze(1)\n", + "\n", + "\n", + "class MultiModalModel(nn.Module):\n", + " def __init__(self, cfg: Dict, cardinalities: Dict[str,int], numeric_dim:int, num_classes:int):\n", + " super().__init__()\n", + " self.cfg = cfg\n", + " self.backbone = timm.create_model(cfg['backbone_name'], pretrained=cfg['pretrained'], num_classes=0, img_size=cfg['img_size'])\n", + "\n", + " if hasattr(self.backbone, 'patch_embed') and hasattr(self.backbone.patch_embed, 'strict_img_size'):\n", + " self.backbone.patch_embed.strict_img_size = False\n", + "\n", + " if hasattr(self.backbone, \"set_grad_checkpointing\"):\n", + " self.backbone.set_grad_checkpointing(enable=True)\n", + "\n", + " embed_dim = getattr(self.backbone, 'num_features', getattr(self.backbone, 'embed_dim', 768))\n", + " self.embed_dim = embed_dim\n", + "\n", + " self.tab_encoder = SimpleTabularEncoder(cardinalities, numeric_dim, emb_dim=cfg['tab_emb_dim'], hidden_dim=cfg['hidden_head_dim'])\n", + "\n", + " if cfg['fusion_mode'] == 'late':\n", + " head_in = embed_dim + cfg['tab_emb_dim']\n", + " 
self.head = nn.Sequential(\n", + " nn.Linear(head_in, cfg['hidden_head_dim']),\n", + " nn.ReLU(),\n", + " nn.LayerNorm(cfg['hidden_head_dim']),\n", + " nn.Linear(cfg['hidden_head_dim'], num_classes)\n", + " )\n", + " elif cfg['fusion_mode'] == 'cross_attention':\n", + " self.fusion_block = CrossAttentionFusionBlock(\n", + " tab_dim=cfg['tab_emb_dim'],\n", + " img_dim=embed_dim,\n", + " num_layers=cfg['num_fusion_layers'],\n", + " heads=cfg['num_attention_heads']\n", + " )\n", + " head_in = cfg['tab_emb_dim']\n", + " self.head = nn.Sequential(\n", + " nn.Linear(head_in, cfg['hidden_head_dim']),\n", + " nn.ReLU(),\n", + " nn.LayerNorm(cfg['hidden_head_dim']),\n", + " nn.Linear(cfg['hidden_head_dim'], num_classes)\n", + " )\n", + "\n", + " def forward(self, image: torch.Tensor, cat: torch.Tensor, num: torch.Tensor):\n", + " tab_emb = self.tab_encoder(cat, num)\n", + "\n", + " if self.cfg['fusion_mode'] == 'late':\n", + " img_cls = self.backbone.forward_features(image)[:, 0]\n", + " x = torch.cat([img_cls, tab_emb], dim=1)\n", + " return self.head(x)\n", + " elif self.cfg['fusion_mode'] == 'cross_attention':\n", + " img_features = self.backbone.forward_features(image) # All patch tokens\n", + " fused_emb = self.fusion_block(tab_emb, img_features)\n", + " return self.head(fused_emb)\n", + "\n", + "# ------------------------\n", + "# Train / Validate\n", + "# ------------------------\n", + "def train_one_epoch(model, loader, optimizer, device, epoch, scaler=None):\n", + " model.train()\n", + " losses = []\n", + " pbar = tqdm(loader, desc=f\"Train {epoch}\")\n", + " criterion = nn.CrossEntropyLoss()\n", + " accum_steps = CONFIG.get('accum_steps', 1)\n", + " optimizer.zero_grad(set_to_none=True)\n", + "\n", + " for i, batch in enumerate(pbar):\n", + " imgs = batch['image'].to(device, non_blocking=True)\n", + " cat = batch['cat'].to(device, non_blocking=True) if batch['cat'].numel() else torch.empty((imgs.size(0),0), dtype=torch.long, device=device)\n", + " num = batch['num'].to(device, non_blocking=True) if batch['num'].numel() else torch.empty((imgs.size(0),0), dtype=torch.float32, device=device)\n", + " targets = batch['target'].to(device, non_blocking=True)\n", + "\n", + " if CONFIG['use_amp'] and scaler is not None:\n", + " with torch.cuda.amp.autocast():\n", + " logits = model(imgs, cat, num)\n", + " loss = criterion(logits, targets) / accum_steps\n", + " scaler.scale(loss).backward()\n", + " else:\n", + " logits = model(imgs, cat, num)\n", + " loss = criterion(logits, targets) / accum_steps\n", + " loss.backward()\n", + "\n", + " if (i + 1) % accum_steps == 0:\n", + " if scaler is not None:\n", + " scaler.step(optimizer)\n", + " scaler.update()\n", + " else:\n", + " optimizer.step()\n", + " optimizer.zero_grad(set_to_none=True)\n", + "\n", + " losses.append(loss.item() * accum_steps)\n", + " pbar.set_postfix(loss=np.mean(losses))\n", + " return float(np.mean(losses))\n", + "\n", + "@torch.no_grad()\n", + "def validate(model, loader, device):\n", + " model.eval()\n", + " preds, trues = [], []\n", + " for batch in tqdm(loader, desc='Val'):\n", + " imgs = batch['image'].to(device, non_blocking=True)\n", + " cat = batch['cat'].to(device, non_blocking=True) if batch['cat'].numel() else torch.empty((imgs.size(0),0), dtype=torch.long, device=device)\n", + " num = batch['num'].to(device, non_blocking=True) if batch['num'].numel() else torch.empty((imgs.size(0),0), dtype=torch.float32, device=device)\n", + " targets = batch['target'].to(device, non_blocking=True)\n", + "\n", + " logits = model(imgs, 
cat, num)\n", + " preds.append(logits.argmax(dim=1).cpu().numpy())\n", + " trues.append(targets.cpu().numpy())\n", + " preds = np.concatenate(preds) if len(preds) else np.array([])\n", + " trues = np.concatenate(trues) if len(trues) else np.array([])\n", + " acc = accuracy_score(trues, preds) if trues.size else 0.0\n", + " # Ensure labels for confusion matrix are within the range of predictions/trues\n", + " labels = np.unique(np.concatenate((trues, preds)))\n", + " cm = confusion_matrix(trues, preds, labels=labels) if trues.size else np.zeros((0,0), dtype=int)\n", + " return acc, cm\n", + "\n", + "# ------------------------\n", + "# Runner\n", + "# ------------------------\n", + "def train_pipeline(train_df: pd.DataFrame, val_df: pd.DataFrame, target_col: str = 'colour'):\n", + " categorical_cols = ['clarity', 'cut', 'polish', 'symmetry', 'fluorescence', 'lab']\n", + " numeric_cols = ['carat', 'length', 'width', 'depth']\n", + "\n", + " df_combined = pd.concat([train_df, val_df], ignore_index=True)\n", + " te = LabelEncoder()\n", + " df_combined[target_col] = te.fit_transform(df_combined[target_col].astype(str))\n", + " CONFIG['num_classes'] = len(te.classes_)\n", + " tab_pre = TabularPreprocessor(categorical_cols, numeric_cols)\n", + " tab_pre.fit(df_combined)\n", + " card = tab_pre.get_cardinalities()\n", + "\n", + " train_df[target_col] = te.transform(train_df[target_col].astype(str))\n", + " val_df[target_col] = te.transform(val_df[target_col].astype(str))\n", + "\n", + " max_attempts, attempt = 4, 0\n", + " last_exc = None\n", + "\n", + " while attempt <= max_attempts:\n", + " try:\n", + " print(f\"[run] Attempt {attempt+1} | bs={CONFIG['batch_size']} | img={CONFIG['img_size']} | \"\n", + " f\"backbone={CONFIG['backbone_name']} | freeze={CONFIG['freeze_backbone']} | accum={CONFIG['accum_steps']}\")\n", + "\n", + " pin = torch.cuda.is_available()\n", + " train_ds = MultiModalDiamondDataset(train_df, tab_pre, categorical_cols, numeric_cols, target_col, train=True)\n", + " val_ds = MultiModalDiamondDataset(val_df, tab_pre, categorical_cols, numeric_cols, target_col, train=False)\n", + " train_loader = DataLoader(train_ds, batch_size=CONFIG['batch_size'], shuffle=True,\n", + " num_workers=CONFIG['num_workers'], pin_memory=pin, persistent_workers=False)\n", + " val_loader = DataLoader(val_ds, batch_size=CONFIG['batch_size'], shuffle=False,\n", + " num_workers=CONFIG['num_workers'], pin_memory=pin, persistent_workers=False)\n", + "\n", + " device = torch.device(CONFIG['device'] if torch.cuda.is_available() else \"cpu\")\n", + " if torch.cuda.is_available():\n", + " torch.backends.cudnn.benchmark = False\n", + " torch.cuda.empty_cache()\n", + "\n", + " model = MultiModalModel(CONFIG, card, numeric_dim=len(numeric_cols), num_classes=CONFIG['num_classes'])\n", + "\n", + " if CONFIG['freeze_backbone']:\n", + " for p in model.backbone.parameters(): p.requires_grad = False\n", + "\n", + " model.to(device)\n", + "\n", + " opt = torch.optim.AdamW((p for p in model.parameters() if p.requires_grad),\n", + " lr=CONFIG['lr'], weight_decay=CONFIG['weight_decay'])\n", + " scaler = torch.cuda.amp.GradScaler() if (CONFIG['use_amp'] and torch.cuda.is_available()) else None\n", + "\n", + " best_acc = 0.0\n", + " for epoch in range(CONFIG['epochs']):\n", + " train_loss = train_one_epoch(model, train_loader, opt, device, epoch, scaler=scaler)\n", + " acc, cm = validate(model, val_loader, device)\n", + " print(f\"Epoch {epoch} | Train loss: {train_loss:.4f} | Val Acc: {acc:.4f}\")\n", + " print(cm)\n", + " if 
acc > best_acc:\n", + " best_acc = acc\n", + " torch.save({\"state_dict\": model.state_dict(),\n", + " \"classes\": te.classes_,\n", + " \"config\": CONFIG,\n", + " \"tab_pre\": tab_pre}, \"best_multimodal.pth\")\n", + " print(f\"Model saved with new best accuracy: {best_acc:.4f}\")\n", + " print(\"Best val acc:\", best_acc)\n", + " break\n", + "\n", + " except RuntimeError as e:\n", + " last_exc = e\n", + " msg = str(e).lower()\n", + " if (\"out of memory\" in msg) or (\"cuda out of memory\" in msg):\n", + " print(f\"[OOM] attempt {attempt+1}: {e}\")\n", + " if torch.cuda.is_available():\n", + " torch.cuda.empty_cache()\n", + "\n", + " if attempt == 0:\n", + " old = CONFIG['batch_size']\n", + " CONFIG['batch_size'] = max(1, old // 2)\n", + " CONFIG['accum_steps'] = max(CONFIG['accum_steps'], 2)\n", + " print(f\"[mitigation] batch {old} -> {CONFIG['batch_size']} | accum -> {CONFIG['accum_steps']}\")\n", + " elif attempt == 1:\n", + " CONFIG['freeze_backbone'] = True\n", + " print(\"[mitigation] freeze_backbone=True\")\n", + " elif attempt == 2:\n", + " old = CONFIG['img_size']\n", + " CONFIG['img_size'] = max(96, old // 2)\n", + " print(f\"[mitigation] img_size {old} -> {CONFIG['img_size']}\")\n", + " elif attempt == 3:\n", + " old = CONFIG['backbone_name']\n", + " CONFIG['backbone_name'] = 'vit_tiny_patch16_224' if 'small' in old else 'vit_small_patch16_224'\n", + " print(f\"[mitigation] backbone {old} -> {CONFIG['backbone_name']}\")\n", + " else:\n", + " print(\"[OOM] All mitigations exhausted.\")\n", + " raise\n", + " attempt += 1\n", + " continue\n", + " else:\n", + " raise\n", + " else:\n", + " if last_exc is not None:\n", + " raise last_exc\n", + "\n", + "def test_model(test_df: pd.DataFrame, target_col: str = 'colour', model_path: str = \"best_multimodal.pth\", device: Optional[torch.device] = None, save_preds_csv: bool = True, top_n_mismatch: int = 50, save_dir: str = \"/kaggle/working/\"):\n", + " import matplotlib.pyplot as plt\n", + " os.makedirs(save_dir, exist_ok=True)\n", + "\n", + " if not os.path.exists(model_path):\n", + " print(f\"Error: Model file '{model_path}' not found. Please train the model first.\")\n", + " return None\n", + "\n", + " checkpoint = torch.load(model_path, map_location='cpu')\n", + " state_dict = checkpoint[\"state_dict\"]\n", + " saved_classes = checkpoint[\"classes\"]\n", + " saved_config = checkpoint[\"config\"]\n", + " tab_pre = checkpoint[\"tab_pre\"]\n", + "\n", + " print(\"--- Starting Test Evaluation ---\")\n", + " print(f\"Loading model trained on classes: {saved_classes}\")\n", + "\n", + " te = LabelEncoder()\n", + " te.classes_ = saved_classes\n", + "\n", + " test_df = test_df.copy()\n", + "\n", + " y_true = None\n", + " if target_col in test_df.columns:\n", + " # Filter out unseen labels from the test set\n", + " seen_mask = test_df[target_col].isin(saved_classes)\n", + " if not seen_mask.all():\n", + " print(f\"[Warning] Found {sum(~seen_mask)} samples with labels not seen during training. 
These will be ignored for metrics.\")\n", + " test_df = test_df[seen_mask].reset_index(drop=True)\n", + "\n", + " if not test_df.empty:\n", + " y_true = te.transform(test_df[target_col].astype(str))\n", + " test_df[target_col] = y_true\n", + " else:\n", + " print(\"Test set is empty after filtering unseen labels.\")\n", + "\n", + " categorical_cols = getattr(tab_pre, \"categorical_cols\", [])\n", + " numeric_cols = getattr(tab_pre, \"numeric_cols\", [])\n", + " card = tab_pre.get_cardinalities()\n", + "\n", + " if device is None:\n", + " device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "\n", + " model = MultiModalModel(saved_config, card, numeric_dim=len(numeric_cols), num_classes=len(saved_classes))\n", + " model.load_state_dict(state_dict)\n", + " model.to(device)\n", + "\n", + " pin = device.type == 'cuda'\n", + " num_workers = min(4, saved_config.get('num_workers', 0)) if pin else 0\n", + " test_ds = MultiModalDiamondDataset(test_df, tab_pre, categorical_cols, numeric_cols, target_col, train=False)\n", + " test_loader = DataLoader(test_ds, batch_size=max(1, saved_config.get('batch_size', 1)),\n", + " shuffle=False, num_workers=num_workers, pin_memory=pin)\n", + "\n", + " model.eval()\n", + " all_preds = []\n", + " with torch.no_grad():\n", + " for batch in tqdm(test_loader, desc=\"Test\"):\n", + " imgs = batch['image'].to(device, non_blocking=True)\n", + " cat = batch['cat'].to(device, non_blocking=True) if batch['cat'].numel() else torch.empty((imgs.size(0), 0), dtype=torch.long, device=device)\n", + " num = batch['num'].to(device, non_blocking=True) if batch['num'].numel() else torch.empty((imgs.size(0), 0), dtype=torch.float32, device=device)\n", + "\n", + " logits = model(imgs, cat, num)\n", + " preds = logits.argmax(dim=1).cpu().numpy()\n", + " all_preds.append(preds)\n", + "\n", + " preds = np.concatenate(all_preds) if len(all_preds) else np.array([], dtype=int)\n", + " pred_labels = te.inverse_transform(preds) if preds.size else np.array([])\n", + "\n", + " compare_df = test_df.copy()\n", + " compare_df['predicted_label'] = pred_labels\n", + " if 'actual_label' not in compare_df.columns and y_true is not None:\n", + " compare_df['actual_label'] = te.inverse_transform(y_true)\n", + "\n", + " plots = {}\n", + " if y_true is not None and len(y_true):\n", + " acc = accuracy_score(y_true, preds)\n", + " cm = confusion_matrix(y_true, preds, labels=np.arange(len(saved_classes)))\n", + "\n", + " fig_cm, ax = plt.subplots(figsize=(8, 8))\n", + " im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)\n", + " ax.set_title('Confusion Matrix')\n", + " fig_cm.colorbar(im)\n", + " tick_marks = np.arange(len(saved_classes))\n", + " ax.set_xticks(tick_marks)\n", + " ax.set_xticklabels(saved_classes, rotation=90)\n", + " ax.set_yticks(tick_marks)\n", + " ax.set_yticklabels(saved_classes)\n", + " ax.set_ylabel('True label')\n", + " ax.set_xlabel('Predicted label')\n", + " plt.tight_layout()\n", + " cm_path = os.path.join(save_dir, 'confusion_matrix.png')\n", + " fig_cm.savefig(cm_path, dpi=150)\n", + " plt.close(fig_cm)\n", + " plots['confusion_matrix'] = cm_path\n", + "\n", + " print(f\"Final Test Accuracy: {acc:.4f}\")\n", + "\n", + " if save_preds_csv:\n", + " preds_csv = os.path.join(save_dir, 'predictions_with_actuals.csv')\n", + " compare_df.to_csv(preds_csv, index=False)\n", + " plots['preds_csv'] = preds_csv\n", + "\n", + " return pred_labels, float(acc), cm, plots\n", + "\n", + " if save_preds_csv:\n", + " preds_csv = os.path.join(save_dir, 
'predictions.csv')\n", + " compare_df.to_csv(preds_csv, index=False)\n", + " plots['preds_csv'] = preds_csv\n", + "\n", + " print(\"No true labels available. Returning predictions only.\")\n", + " return pred_labels, plots\n", + "\n", + "\n", + "# ------------------------\n", + "# Run example\n", + "# ------------------------\n", + "if __name__ == '__main__':\n", + " # 1. Split the data once\n", + " train_df, val_df = safe_train_val_split(df.copy(), 'colour', test_size=0.2, seed=CONFIG['seed'])\n", + "\n", + " # 2. Run the training pipeline\n", + " train_pipeline(train_df, val_df.copy()) # Pass a copy to avoid modification issues\n", + "\n", + " # 3. Test on the validation set as a proxy\n", + " print(\"\\n--- Running Test on Validation Set ---\")\n", + " if os.path.exists(\"best_multimodal.pth\"):\n", + " test_output = test_model(val_df, save_preds_csv=True)\n", + " if test_output:\n", + " if len(test_output) == 4:\n", + " preds, acc, cm, plots = test_output\n", + " print(f\"Test Accuracy: {acc}\")\n", + " if 'confusion_matrix' in plots:\n", + " print(f\"Confusion matrix saved to {plots['confusion_matrix']}\")\n", + " # img = Image.open(plots['confusion_matrix'])\n", + " # img.show() # This might not work in all environments\n", + " if 'preds_csv' in plots:\n", + " print(f\"Predictions saved to {plots['preds_csv']}\")\n", + " print(pd.read_csv(plots['preds_csv']).head())\n", + " else:\n", + " preds, plots = test_output\n", + " print(\"Test completed without ground truth labels.\")\n", + " else:\n", + " print(\"Training did not complete successfully, skipping test.\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2025-09-08T07:56:16.322731Z", + "iopub.status.busy": "2025-09-08T07:56:16.322328Z", + "iopub.status.idle": "2025-09-08T07:56:16.344387Z", + "shell.execute_reply": "2025-09-08T07:56:16.343435Z", + "shell.execute_reply.started": "2025-09-08T07:56:16.322697Z" + }, + "id": "QLd0zfEpOORj", + "jupyter": { + "source_hidden": true + }, + "trusted": true + }, + "outputs": [], + "source": [ + "def test_model(test_df: pd.DataFrame, target_col: str = 'colour', model_path: str = \"best_multimodal.pth\", device: Optional[torch.device] = None, save_preds_csv: bool = True, top_n_mismatch: int = 50, save_dir: str = \"/kaggle/working/\"):\n", + " import matplotlib.pyplot as plt\n", + " os.makedirs(save_dir, exist_ok=True)\n", + "\n", + " if not os.path.exists(model_path):\n", + " print(f\"Error: Model file '{model_path}' not found. Please train the model first.\")\n", + " return None\n", + "\n", + " checkpoint = torch.load(model_path, map_location='cpu')\n", + " state_dict = checkpoint[\"state_dict\"]\n", + " saved_classes = checkpoint[\"classes\"]\n", + " saved_config = checkpoint[\"config\"]\n", + " tab_pre = checkpoint[\"tab_pre\"]\n", + "\n", + " print(\"--- Starting Test Evaluation ---\")\n", + " print(f\"Loading model trained on classes: {saved_classes}\")\n", + "\n", + " te = LabelEncoder()\n", + " te.classes_ = saved_classes\n", + "\n", + " test_df = test_df.copy()\n", + "\n", + " y_true = None\n", + " if target_col in test_df.columns:\n", + " # Filter out unseen labels from the test set\n", + " seen_mask = test_df[target_col].isin(saved_classes)\n", + " if not seen_mask.all():\n", + " print(f\"[Warning] Found {sum(~seen_mask)} samples with labels not seen during training. 
These will be ignored for metrics.\")\n", + " test_df = test_df[seen_mask].reset_index(drop=True)\n", + "\n", + " if not test_df.empty:\n", + " y_true = te.transform(test_df[target_col].astype(str))\n", + " test_df[target_col] = y_true\n", + " else:\n", + " print(\"Test set is empty after filtering unseen labels.\")\n", + "\n", + " categorical_cols = getattr(tab_pre, \"categorical_cols\", [])\n", + " numeric_cols = getattr(tab_pre, \"numeric_cols\", [])\n", + " card = tab_pre.get_cardinalities()\n", + "\n", + " if device is None:\n", + " device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "\n", + " model = MultiModalModel(saved_config, card, numeric_dim=len(numeric_cols), num_classes=len(saved_classes))\n", + " model.load_state_dict(state_dict)\n", + " model.to(device)\n", + "\n", + " pin = device.type == 'cuda'\n", + " num_workers = min(4, saved_config.get('num_workers', 0)) if pin else 0\n", + " test_ds = MultiModalDiamondDataset(test_df, tab_pre, categorical_cols, numeric_cols, target_col, train=False)\n", + " test_loader = DataLoader(test_ds, batch_size=max(1, saved_config.get('batch_size', 1)),\n", + " shuffle=False, num_workers=num_workers, pin_memory=pin)\n", + "\n", + " model.eval()\n", + " all_preds = []\n", + " with torch.no_grad():\n", + " for batch in tqdm(test_loader, desc=\"Test\"):\n", + " imgs = batch['image'].to(device, non_blocking=True)\n", + " cat = batch['cat'].to(device, non_blocking=True) if batch['cat'].numel() else torch.empty((imgs.size(0), 0), dtype=torch.long, device=device)\n", + " num = batch['num'].to(device, non_blocking=True) if batch['num'].numel() else torch.empty((imgs.size(0), 0), dtype=torch.float32, device=device)\n", + "\n", + " logits = model(imgs, cat, num)\n", + " preds = logits.argmax(dim=1).cpu().numpy()\n", + " all_preds.append(preds)\n", + "\n", + " preds = np.concatenate(all_preds) if len(all_preds) else np.array([], dtype=int)\n", + " pred_labels = te.inverse_transform(preds) if preds.size else np.array([])\n", + "\n", + " compare_df = test_df.copy()\n", + " compare_df['predicted_label'] = pred_labels\n", + " if 'actual_label' not in compare_df.columns and y_true is not None:\n", + " compare_df['actual_label'] = te.inverse_transform(y_true)\n", + "\n", + " plots = {}\n", + " if y_true is not None and len(y_true):\n", + " acc = accuracy_score(y_true, preds)\n", + " cm = confusion_matrix(y_true, preds, labels=np.arange(len(saved_classes)))\n", + "\n", + " fig_cm, ax = plt.subplots(figsize=(8, 8))\n", + " im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)\n", + " ax.set_title('Confusion Matrix')\n", + " fig_cm.colorbar(im)\n", + " tick_marks = np.arange(len(saved_classes))\n", + " ax.set_xticks(tick_marks)\n", + " ax.set_xticklabels(saved_classes, rotation=90)\n", + " ax.set_yticks(tick_marks)\n", + " ax.set_yticklabels(saved_classes)\n", + " ax.set_ylabel('True label')\n", + " ax.set_xlabel('Predicted label')\n", + " plt.tight_layout()\n", + " cm_path = os.path.join(save_dir, 'confusion_matrix.png')\n", + " fig_cm.savefig(cm_path, dpi=150)\n", + " plt.close(fig_cm)\n", + " plots['confusion_matrix'] = cm_path\n", + "\n", + " print(f\"Final Test Accuracy: {acc:.4f}\")\n", + "\n", + " if save_preds_csv:\n", + " preds_csv = os.path.join(save_dir, 'predictions_with_actuals.csv')\n", + " compare_df.to_csv(preds_csv, index=False)\n", + " plots['preds_csv'] = preds_csv\n", + "\n", + " return pred_labels, float(acc), cm, plots\n", + "\n", + " if save_preds_csv:\n", + " preds_csv = os.path.join(save_dir, 
'predictions.csv')\n", + " compare_df.to_csv(preds_csv, index=False)\n", + " plots['preds_csv'] = preds_csv\n", + "\n", + " print(\"No true labels available. Returning predictions only.\")\n", + " return pred_labels, plots\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "referenced_widgets": [ + "7e65e155d2d94b749ef660c2e0d51eab" + ] + }, + "collapsed": true, + "execution": { + "iopub.execute_input": "2025-09-08T07:56:24.551403Z", + "iopub.status.busy": "2025-09-08T07:56:24.550748Z", + "iopub.status.idle": "2025-09-08T07:58:26.361829Z", + "shell.execute_reply": "2025-09-08T07:58:26.360819Z", + "shell.execute_reply.started": "2025-09-08T07:56:24.551372Z" + }, + "id": "cJRb6toDOORj", + "jupyter": { + "outputs_hidden": true, + "source_hidden": true + }, + "outputId": "f4c159f8-b0ce-41fe-902b-6db57c5717c4", + "trusted": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[warning] Stratified split not possible (min class=1). Using random split.\n", + "\n", + "--- Running Test on Validation Set ---\n", + "--- Starting Test Evaluation ---\n", + "Loading model trained on classes: ['BLUE' 'D' 'D:P:BN' 'E' 'F' 'FANCY' 'FC:P' 'G' 'H' 'I' 'I:P' 'J' 'K' 'L'\n", + " 'M' 'N' 'O-P' 'Q-R' 'S-T' 'U-V' 'V:B' 'W-X' 'Y-Z']\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "7e65e155d2d94b749ef660c2e0d51eab", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Test: 0%| | 0/4877 [00:00 img=128 bs=2 accum=4 freeze=True backbone=vit_small_patch16_224 AMP=True\n", + "[warning] Stratified split not possible (min class=1). Using random split.\n", + "\n", + "--- Starting Optuna search for 2 trials ---\n", + "[trial 0] cfg lr=3.50e-05 wd=3.30e-04 bs=4 img=96 backbone=vit_small_patch16_224 freeze=False accum=2\n", + "[run] Attempt 1 | bs=4 | img=96 | backbone=vit_small_patch16_224 | freeze=False | accum=2\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a32b1905bfd74bf48932097639808db4", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Train 0: 0%| | 0/9753 [00:00 img={base['img_size']} bs={base['batch_size']} \"\n", + " f\"accum={base['accum_steps']} freeze={base['freeze_backbone']} backbone={base['backbone_name']} AMP={base['use_amp']}\")\n", + " return base\n", + "\n", + "# Global base CONFIG (will be copied and adjusted per Optuna trial)\n", + "CONFIG = get_runtime_config()\n", + "\n", + "# ------------------------\n", + "# Example dummy dataframe (uncomment for standalone testing)\n", + "# ------------------------\n", + "# def create_dummy_data(num_samples=200):\n", + "# img_dir = Path(\"./dummy_images\")\n", + "# img_dir.mkdir(exist_ok=True)\n", + "#\n", + "# dummy_image_paths = []\n", + "# for i in range(num_samples):\n", + "# path = img_dir / f\"dummy_{i}.png\"\n", + "# Image.new('RGB', (224, 224), color = (random.randint(0,255), random.randint(0,255), random.randint(0,255))).save(path)\n", + "# dummy_image_paths.append(str(path))\n", + "#\n", + "# data = {\n", + "# 'full_path_to_img': dummy_image_paths,\n", + "# 'carat': np.random.uniform(0.3, 2.5, num_samples),\n", + "# 'clarity': np.random.choice(['SI2', 'SI1', 'VS2', 'VS1', 'VVS2'], num_samples),\n", + "# 'colour': np.random.choice(['D', 'E', 'F', 'G', 'H', 'I', 'J'], num_samples),\n", + "# 'cut': np.random.choice(['EX', 'VG', 'G'], num_samples),\n", + "# 'polish': np.random.choice(['EX', 'VG', 'G'], num_samples),\n", + "# 'symmetry': 
np.random.choice(['EX', 'VG', 'G'], num_samples),\n", + "# 'fluorescence': np.random.choice(['N', 'F', 'M', 'S'], num_samples),\n", + "# 'lab': ['GIA'] * num_samples,\n", + "# 'length': np.random.uniform(4.0, 9.0, num_samples),\n", + "# 'width': np.random.uniform(4.0, 9.0, num_samples),\n", + "# 'depth': np.random.uniform(2.5, 5.5, num_samples),\n", + "# }\n", + "# return pd.DataFrame(data)\n", + "#\n", + "# df = create_dummy_data(200)\n", + "\n", + "# The user-provided dataframe variable is expected to be `data` (as in original).\n", + "# Keep the original behavior: df = data.copy()\n", + "try:\n", + " df = data.copy()\n", + "except NameError:\n", + " raise RuntimeError(\"DataFrame `data` not found. Create `data` or uncomment dummy data creation.\")\n", + "\n", + "# ------------------------\n", + "# Tabular preprocessing\n", + "# ------------------------\n", + "class TabularPreprocessor:\n", + " def __init__(self, categorical_cols: List[str], numeric_cols: List[str]):\n", + " self.categorical_cols = categorical_cols\n", + " self.numeric_cols = numeric_cols\n", + " self.label_encoders: Dict[str, LabelEncoder] = {}\n", + " self.scaler = StandardScaler()\n", + "\n", + " def fit(self, df: pd.DataFrame):\n", + " for c in self.categorical_cols:\n", + " le = LabelEncoder()\n", + " df[c] = df[c].astype(str).fillna('NA')\n", + " le.fit(df[c].values)\n", + " self.label_encoders[c] = le\n", + " if len(self.numeric_cols):\n", + " self.scaler.fit(df[self.numeric_cols].astype(float).values)\n", + "\n", + " def transform(self, df: pd.DataFrame):\n", + " cat_arrays = []\n", + " for c in self.categorical_cols:\n", + " arr = df[c].astype(str).fillna('NA').values\n", + " le = self.label_encoders[c]\n", + " cat_arrays.append(le.transform(arr))\n", + " cats = np.stack(cat_arrays, axis=1) if len(cat_arrays) else np.zeros((len(df), 0), dtype=np.int64)\n", + " nums = self.scaler.transform(df[self.numeric_cols].astype(float).values) if len(self.numeric_cols) else np.zeros((len(df), 0))\n", + " return cats, nums\n", + "\n", + " def get_cardinalities(self):\n", + " return {c: len(self.label_encoders[c].classes_) for c in self.categorical_cols}\n", + "\n", + "# ------------------------\n", + "# Robust split helper\n", + "# ------------------------\n", + "def safe_train_val_split(df_in: pd.DataFrame, target_col: str, test_size: float = 0.2, seed: int = 42):\n", + " if len(df_in) <= 10:\n", + " return df_in, df_in\n", + " counts = df_in[target_col].value_counts()\n", + " if counts.min() < 2:\n", + " print(f\"[warning] Stratified split not possible (min class={int(counts.min())}). 
Using random split.\")\n", + " return train_test_split(df_in, test_size=test_size, random_state=seed, shuffle=True)\n", + " return train_test_split(df_in, test_size=test_size, random_state=seed, stratify=df_in[target_col])\n", + "\n", + "# ------------------------\n", + "# Transforms & Dataset\n", + "# ------------------------\n", + "def get_transforms(img_size=224, train=True):\n", + " if train:\n", + " return transforms.Compose([\n", + " transforms.Resize((img_size, img_size)),\n", + " transforms.RandomHorizontalFlip(),\n", + " transforms.RandomRotation(10),\n", + " transforms.ColorJitter(0.05,0.05,0.05,0.01),\n", + " transforms.ToTensor(),\n", + " transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])\n", + " ])\n", + " else:\n", + " return transforms.Compose([\n", + " transforms.Resize((img_size, img_size)),\n", + " transforms.ToTensor(),\n", + " transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])\n", + " ])\n", + "\n", + "# ------------------------\n", + "# UPDATED MultiModalDiamondDataset (robust fallback for bad images)\n", + "# ------------------------\n", + "class MultiModalDiamondDataset(Dataset):\n", + " def __init__(self, df: pd.DataFrame, tab_preproc: TabularPreprocessor,\n", + " categorical_cols: List[str], numeric_cols: List[str],\n", + " target_col: str, img_size:int, train=True):\n", + " self.df = df.reset_index(drop=True)\n", + " self.transform = get_transforms(img_size, train)\n", + " self.tab_preproc = tab_preproc\n", + " self.categorical_cols = categorical_cols\n", + " self.numeric_cols = numeric_cols\n", + " self.target_col = target_col\n", + " self.cats, self.nums = tab_preproc.transform(self.df)\n", + " # cache computed size to use for fallback images\n", + " self._fallback_size = self._detect_resize_size(img_size)\n", + "\n", + " def _detect_resize_size(self, default_size):\n", + " \"\"\"\n", + " Inspect transforms pipeline for a Resize transform and return (w,h) tuple.\n", + " If Resize.size is an int -> return (int,int). 
Otherwise return tuple.\n", + " If not found, return (default_size, default_size).\n", + " \"\"\"\n", + " try:\n", + " for t in getattr(self.transform, \"transforms\", []):\n", + " # torchvision.transforms.Resize instance\n", + " if isinstance(t, transforms.Resize):\n", + " s = getattr(t, \"size\", None)\n", + " if s is None:\n", + " continue\n", + " if isinstance(s, int):\n", + " return (s, s)\n", + " # ensure it's a 2-tuple\n", + " if isinstance(s, (tuple, list)) and len(s) == 2:\n", + " return (int(s[0]), int(s[1]))\n", + " # fallback\n", + " return (default_size, default_size)\n", + " except Exception:\n", + " return (default_size, default_size)\n", + "\n", + " def __len__(self):\n", + " return len(self.df)\n", + "\n", + " def __getitem__(self, idx):\n", + " row = self.df.iloc[idx]\n", + " img_path = row['full_path_to_img']\n", + "\n", + " img = None\n", + " # Primary attempt: PIL open\n", + " try:\n", + " img = Image.open(img_path).convert('RGB')\n", + " except Exception:\n", + " # Secondary attempt: read raw bytes then open (sometimes helps with broken file-like paths)\n", + " try:\n", + " with open(img_path, 'rb') as f:\n", + " img = Image.open(BytesIO(f.read())).convert('RGB')\n", + " except Exception:\n", + " # Final fallback: create a black image with correct shape\n", + " img = Image.new('RGB', self._fallback_size, (0, 0, 0))\n", + "\n", + " # apply transforms (Resize/ToTensor/Normalize etc.)\n", + " img = self.transform(img)\n", + "\n", + " cat = torch.tensor(self.cats[idx].astype(np.int64)) if self.cats.shape[1] else torch.empty(0, dtype=torch.long)\n", + " num = torch.tensor(self.nums[idx].astype(np.float32)) if self.nums.shape[1] else torch.empty(0, dtype=torch.float32)\n", + " target = row[self.target_col]\n", + " return {'image': img, 'cat': cat, 'num': num, 'target': torch.tensor(target, dtype=torch.long)}\n", + "\n", + "# ------------------------\n", + "# Model components\n", + "# ------------------------\n", + "class SimpleTabularEncoder(nn.Module):\n", + " def __init__(self, cardinalities: Dict[str,int], numeric_dim:int, emb_dim=128, hidden_dim=256):\n", + " super().__init__()\n", + " self.cat_cols = list(cardinalities.keys())\n", + " self.embs = nn.ModuleDict()\n", + " for k in self.cat_cols:\n", + " card = cardinalities[k]\n", + " self.embs[k] = nn.Embedding(card, min(50, (card+1)//2))\n", + " cat_total_dim = sum([self.embs[k].embedding_dim for k in self.cat_cols]) if self.cat_cols else 0\n", + " self.numeric_dim = numeric_dim\n", + " in_dim = cat_total_dim + numeric_dim\n", + " self.net = nn.Sequential(\n", + " nn.Linear(max(1, in_dim), hidden_dim),\n", + " nn.ReLU(),\n", + " nn.LayerNorm(hidden_dim),\n", + " nn.Linear(hidden_dim, emb_dim),\n", + " )\n", + "\n", + " def forward(self, cat: torch.Tensor, num: torch.Tensor):\n", + " device = next(self.parameters()).device\n", + " if cat.shape[1] > 0:\n", + " emb_list = [self.embs[k](cat[:, i]) for i, k in enumerate(self.cat_cols)]\n", + " cat_emb = torch.cat(emb_list, dim=1)\n", + " else:\n", + " cat_emb = torch.zeros((cat.shape[0], 0), device=device)\n", + " x = torch.cat([cat_emb, num], dim=1) if num.shape[1] > 0 else cat_emb\n", + " if x.numel() == 0:\n", + " return torch.zeros((cat.shape[0], CONFIG['tab_emb_dim']), device=device)\n", + " return self.net(x)\n", + "\n", + "class CrossAttention(nn.Module):\n", + " def __init__(self, query_dim, context_dim, heads=8, dim_head=64, dropout=0.):\n", + " super().__init__()\n", + " inner_dim = dim_head * heads\n", + " self.scale = dim_head ** -0.5\n", + " self.heads = 
heads\n", + "\n", + " self.to_q = nn.Linear(query_dim, inner_dim, bias=False)\n", + " self.to_k = nn.Linear(context_dim, inner_dim, bias=False)\n", + " self.to_v = nn.Linear(context_dim, inner_dim, bias=False)\n", + " self.to_out = nn.Sequential(\n", + " nn.Linear(inner_dim, query_dim),\n", + " nn.Dropout(dropout)\n", + " )\n", + "\n", + " def forward(self, x, context):\n", + " # x: (B, seq, query_dim) or (B, 1, query_dim)\n", + " # context: (B, seq_ctx, context_dim)\n", + " q = self.to_q(x)\n", + " k = self.to_k(context)\n", + " v = self.to_v(context)\n", + "\n", + " # Reshape for multi-head attention -> (B, heads, seq, dim_head)\n", + " q, k, v = map(lambda t: t.view(t.shape[0], -1, self.heads, t.shape[-1] // self.heads).transpose(1, 2), (q, k, v))\n", + "\n", + " sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale\n", + " attn = sim.softmax(dim=-1)\n", + "\n", + " out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)\n", + " out = out.transpose(1, 2).reshape(out.shape[0], -1, self.heads * (out.shape[-1]))\n", + " return self.to_out(out)\n", + "\n", + "class CrossAttentionFusionBlock(nn.Module):\n", + " def __init__(self, tab_dim, img_dim, num_layers=2, heads=8):\n", + " super().__init__()\n", + " self.layers = nn.ModuleList([])\n", + " for _ in range(num_layers):\n", + " # dim_head chosen so that dim_head * heads == img_dim (or close)\n", + " dim_head = max(1, img_dim // heads)\n", + " self.layers.append(nn.ModuleList([\n", + " nn.LayerNorm(tab_dim),\n", + " nn.LayerNorm(img_dim),\n", + " CrossAttention(tab_dim, img_dim, heads=heads, dim_head=dim_head)\n", + " ]))\n", + "\n", + " def forward(self, tab_features, img_features):\n", + " # tab_features: (B, tab_dim)\n", + " # img_features: (B, seq_len, img_dim) typically from backbone.forward_features()\n", + " tab_features = tab_features.unsqueeze(1) # Add sequence dimension -> (B, 1, tab_dim)\n", + " for norm_tab, norm_img, attn in self.layers:\n", + " tab_features = attn(norm_tab(tab_features), norm_img(img_features)) + tab_features\n", + " return tab_features.squeeze(1)\n", + "\n", + "\n", + "class MultiModalModel(nn.Module):\n", + " def __init__(self, cfg: Dict, cardinalities: Dict[str,int], numeric_dim:int, num_classes:int):\n", + " super().__init__()\n", + " self.cfg = cfg\n", + " # create backbone; leave num_classes=0 to get features\n", + " self.backbone = timm.create_model(cfg['backbone_name'], pretrained=cfg['pretrained'], num_classes=0, img_size=cfg['img_size'])\n", + "\n", + " # some timm models have a strict_img_size flag on patch_embed\n", + " if hasattr(self.backbone, 'patch_embed') and hasattr(self.backbone.patch_embed, 'strict_img_size'):\n", + " try:\n", + " self.backbone.patch_embed.strict_img_size = False\n", + " except Exception:\n", + " pass\n", + "\n", + " if hasattr(self.backbone, \"set_grad_checkpointing\"):\n", + " try:\n", + " self.backbone.set_grad_checkpointing(enable=True)\n", + " except Exception:\n", + " pass\n", + "\n", + " embed_dim = getattr(self.backbone, 'num_features', getattr(self.backbone, 'embed_dim', 768))\n", + " self.embed_dim = embed_dim\n", + "\n", + " self.tab_encoder = SimpleTabularEncoder(cardinalities, numeric_dim, emb_dim=cfg['tab_emb_dim'], hidden_dim=cfg['hidden_head_dim'])\n", + "\n", + " if cfg['fusion_mode'] == 'late':\n", + " head_in = embed_dim + cfg['tab_emb_dim']\n", + " self.head = nn.Sequential(\n", + " nn.Linear(head_in, cfg['hidden_head_dim']),\n", + " nn.ReLU(),\n", + " nn.LayerNorm(cfg['hidden_head_dim']),\n", + " nn.Linear(cfg['hidden_head_dim'], 
num_classes)\n", + " )\n", + " elif cfg['fusion_mode'] == 'cross_attention':\n", + " self.fusion_block = CrossAttentionFusionBlock(\n", + " tab_dim=cfg['tab_emb_dim'],\n", + " img_dim=embed_dim,\n", + " num_layers=cfg.get('num_fusion_layers', 2),\n", + " heads=cfg.get('num_attention_heads', 8)\n", + " )\n", + " head_in = cfg['tab_emb_dim']\n", + " self.head = nn.Sequential(\n", + " nn.Linear(head_in, cfg['hidden_head_dim']),\n", + " nn.ReLU(),\n", + " nn.LayerNorm(cfg['hidden_head_dim']),\n", + " nn.Linear(cfg['hidden_head_dim'], num_classes)\n", + " )\n", + "\n", + " def forward(self, image: torch.Tensor, cat: torch.Tensor, num: torch.Tensor):\n", + " tab_emb = self.tab_encoder(cat, num)\n", + "\n", + " if self.cfg['fusion_mode'] == 'late':\n", + " img_cls = self.backbone.forward_features(image)[:, 0]\n", + " x = torch.cat([img_cls, tab_emb], dim=1)\n", + " return self.head(x)\n", + " elif self.cfg['fusion_mode'] == 'cross_attention':\n", + " img_features = self.backbone.forward_features(image) # All patch tokens -> (B, seq, embed_dim)\n", + " fused_emb = self.fusion_block(tab_emb, img_features)\n", + " return self.head(fused_emb)\n", + "\n", + "# ------------------------\n", + "# Train / Validate\n", + "# ------------------------\n", + "def train_one_epoch(model, loader, optimizer, device, epoch, cfg, scaler=None):\n", + " model.train()\n", + " losses = []\n", + " pbar = tqdm(loader, desc=f\"Train {epoch}\")\n", + " criterion = nn.CrossEntropyLoss()\n", + " accum_steps = cfg.get('accum_steps', 1)\n", + " optimizer.zero_grad(set_to_none=True)\n", + "\n", + " for i, batch in enumerate(pbar):\n", + " imgs = batch['image'].to(device, non_blocking=True)\n", + " cat = batch['cat'].to(device, non_blocking=True) if batch['cat'].numel() else torch.empty((imgs.size(0),0), dtype=torch.long, device=device)\n", + " num = batch['num'].to(device, non_blocking=True) if batch['num'].numel() else torch.empty((imgs.size(0),0), dtype=torch.float32, device=device)\n", + " targets = batch['target'].to(device, non_blocking=True)\n", + "\n", + " if cfg.get('use_amp', False) and scaler is not None:\n", + " with torch.cuda.amp.autocast():\n", + " logits = model(imgs, cat, num)\n", + " loss = criterion(logits, targets) / accum_steps\n", + " scaler.scale(loss).backward()\n", + " else:\n", + " logits = model(imgs, cat, num)\n", + " loss = criterion(logits, targets) / accum_steps\n", + " loss.backward()\n", + "\n", + " if (i + 1) % accum_steps == 0:\n", + " if scaler is not None:\n", + " scaler.step(optimizer)\n", + " scaler.update()\n", + " else:\n", + " optimizer.step()\n", + " optimizer.zero_grad(set_to_none=True)\n", + "\n", + " losses.append(loss.item() * accum_steps)\n", + " pbar.set_postfix(loss=np.mean(losses))\n", + " return float(np.mean(losses))\n", + "\n", + "@torch.no_grad()\n", + "def validate(model, loader, device):\n", + " model.eval()\n", + " preds, trues = [], []\n", + " for batch in tqdm(loader, desc='Val'):\n", + " imgs = batch['image'].to(device, non_blocking=True)\n", + " cat = batch['cat'].to(device, non_blocking=True) if batch['cat'].numel() else torch.empty((imgs.size(0),0), dtype=torch.long, device=device)\n", + " num = batch['num'].to(device, non_blocking=True) if batch['num'].numel() else torch.empty((imgs.size(0),0), dtype=torch.float32, device=device)\n", + " targets = batch['target'].to(device, non_blocking=True)\n", + "\n", + " logits = model(imgs, cat, num)\n", + " preds.append(logits.argmax(dim=1).cpu().numpy())\n", + " trues.append(targets.cpu().numpy())\n", + " preds = 
np.concatenate(preds) if len(preds) else np.array([])\n", + " trues = np.concatenate(trues) if len(trues) else np.array([])\n", + " acc = accuracy_score(trues, preds) if trues.size else 0.0\n", + " # Ensure labels for confusion matrix are within the range of predictions/trues\n", + " labels = np.unique(np.concatenate((trues, preds))) if trues.size or preds.size else np.array([])\n", + " cm = confusion_matrix(trues, preds, labels=labels) if trues.size else np.zeros((0,0), dtype=int)\n", + " return acc, cm\n", + "\n", + "# ------------------------\n", + "# Runner (refactored to accept cfg)\n", + "# ------------------------\n", + "def train_pipeline(train_df: pd.DataFrame, val_df: pd.DataFrame, target_col: str = 'colour', cfg: Optional[Dict]=None):\n", + " \"\"\"\n", + " Train pipeline accepts a config dict (copy of CONFIG) so each Optuna trial\n", + " can run independently with different parameters.\n", + " Returns best_val_acc (float).\n", + " \"\"\"\n", + " if cfg is None:\n", + " cfg = CONFIG\n", + " else:\n", + " # ensure default keys exist\n", + " default = CONFIG.copy()\n", + " for k, v in default.items():\n", + " cfg.setdefault(k, v)\n", + "\n", + " categorical_cols = ['clarity', 'cut', 'polish', 'symmetry', 'fluorescence', 'lab']\n", + " numeric_cols = ['carat', 'length', 'width', 'depth']\n", + "\n", + " # Fit label encoder for target using combined df (train + val)\n", + " df_combined = pd.concat([train_df, val_df], ignore_index=True)\n", + " te = LabelEncoder()\n", + " df_combined[target_col] = te.fit_transform(df_combined[target_col].astype(str))\n", + " cfg['num_classes'] = len(te.classes_)\n", + " tab_pre = TabularPreprocessor(categorical_cols, numeric_cols)\n", + " tab_pre.fit(df_combined)\n", + " card = tab_pre.get_cardinalities()\n", + "\n", + " train_df = train_df.copy()\n", + " val_df = val_df.copy()\n", + " train_df[target_col] = te.transform(train_df[target_col].astype(str))\n", + " val_df[target_col] = te.transform(val_df[target_col].astype(str))\n", + "\n", + " max_attempts, attempt = 4, 0\n", + " last_exc = None\n", + " best_acc = 0.0\n", + "\n", + " while attempt <= max_attempts:\n", + " try:\n", + " print(f\"[run] Attempt {attempt+1} | bs={cfg['batch_size']} | img={cfg['img_size']} | \"\n", + " f\"backbone={cfg['backbone_name']} | freeze={cfg['freeze_backbone']} | accum={cfg['accum_steps']}\")\n", + "\n", + " pin = (cfg['device'] == 'cuda') and torch.cuda.is_available()\n", + " train_ds = MultiModalDiamondDataset(train_df, tab_pre, categorical_cols, numeric_cols, target_col, img_size=cfg['img_size'], train=True)\n", + " val_ds = MultiModalDiamondDataset(val_df, tab_pre, categorical_cols, numeric_cols, target_col, img_size=cfg['img_size'], train=False)\n", + "\n", + " train_loader = DataLoader(train_ds, batch_size=cfg['batch_size'], shuffle=True,\n", + " num_workers=cfg['num_workers'], pin_memory=pin)\n", + " val_loader = DataLoader(val_ds, batch_size=cfg['batch_size'], shuffle=False,\n", + " num_workers=cfg['num_workers'], pin_memory=pin)\n", + "\n", + " device = torch.device(cfg['device'] if torch.cuda.is_available() and cfg['device']=='cuda' else \"cpu\")\n", + " if device.type == 'cuda':\n", + " torch.backends.cudnn.benchmark = False\n", + " torch.cuda.empty_cache()\n", + "\n", + " model = MultiModalModel(cfg, card, numeric_dim=len(numeric_cols), num_classes=cfg['num_classes'])\n", + "\n", + " if cfg.get('freeze_backbone', True):\n", + " for p in model.backbone.parameters(): p.requires_grad = False\n", + "\n", + " model.to(device)\n", + "\n", + " opt = 
torch.optim.AdamW((p for p in model.parameters() if p.requires_grad),\n", + " lr=cfg['lr'], weight_decay=cfg['weight_decay'])\n", + " scaler = torch.cuda.amp.GradScaler() if (cfg.get('use_amp', False) and torch.cuda.is_available()) else None\n", + "\n", + " best_acc = 0.0\n", + " for epoch in range(cfg['epochs']):\n", + " train_loss = train_one_epoch(model, train_loader, opt, device, epoch, cfg, scaler=scaler)\n", + " acc, cm = validate(model, val_loader, device)\n", + " print(f\"Epoch {epoch} | Train loss: {train_loss:.4f} | Val Acc: {acc:.4f}\")\n", + " # print small cm\n", + " print(cm)\n", + " if acc > best_acc:\n", + " best_acc = acc\n", + " # Save best model per-run with config and tab_pre\n", + " torch.save({\"state_dict\": model.state_dict(),\n", + " \"classes\": te.classes_,\n", + " \"config\": cfg,\n", + " \"tab_pre\": tab_pre}, \"best_multimodal.pth\")\n", + " print(f\"Model saved with new best accuracy: {best_acc:.4f}\")\n", + " print(\"Best val acc for this run:\", best_acc)\n", + " break\n", + "\n", + " except RuntimeError as e:\n", + " last_exc = e\n", + " msg = str(e).lower()\n", + " if (\"out of memory\" in msg) or (\"cuda out of memory\" in msg):\n", + " print(f\"[OOM] attempt {attempt+1}: {e}\")\n", + " if torch.cuda.is_available():\n", + " torch.cuda.empty_cache()\n", + "\n", + " if attempt == 0:\n", + " old = cfg['batch_size']\n", + " cfg['batch_size'] = max(1, old // 2)\n", + " cfg['accum_steps'] = max(cfg.get('accum_steps', 1), 2)\n", + " print(f\"[mitigation] batch {old} -> {cfg['batch_size']} | accum -> {cfg['accum_steps']}\")\n", + " elif attempt == 1:\n", + " cfg['freeze_backbone'] = True\n", + " print(\"[mitigation] freeze_backbone=True\")\n", + " elif attempt == 2:\n", + " old = cfg['img_size']\n", + " cfg['img_size'] = max(96, old // 2)\n", + " print(f\"[mitigation] img_size {old} -> {cfg['img_size']}\")\n", + " elif attempt == 3:\n", + " old = cfg['backbone_name']\n", + " cfg['backbone_name'] = 'vit_tiny_patch16_224' if 'small' in old else 'vit_small_patch16_224'\n", + " print(f\"[mitigation] backbone {old} -> {cfg['backbone_name']}\")\n", + " else:\n", + " print(\"[OOM] All mitigations exhausted.\")\n", + " raise\n", + " attempt += 1\n", + " continue\n", + " else:\n", + " raise\n", + " else:\n", + " if last_exc is not None:\n", + " raise last_exc\n", + "\n", + " return float(best_acc)\n", + "\n", + "def test_model(test_df: pd.DataFrame, target_col: str = 'colour', model_path: str = \"best_multimodal.pth\", device: Optional[torch.device] = None, save_preds_csv: bool = True, top_n_mismatch: int = 50, save_dir: str = \"/kaggle/working/\"):\n", + " import matplotlib.pyplot as plt\n", + " os.makedirs(save_dir, exist_ok=True)\n", + "\n", + " if not os.path.exists(model_path):\n", + " print(f\"Error: Model file '{model_path}' not found. 
Please train the model first.\")\n", + " return None\n", + "\n", + " checkpoint = torch.load(model_path, map_location='cpu')\n", + " state_dict = checkpoint[\"state_dict\"]\n", + " saved_classes = checkpoint[\"classes\"]\n", + " saved_config = checkpoint[\"config\"]\n", + " tab_pre = checkpoint[\"tab_pre\"]\n", + "\n", + " print(\"--- Starting Test Evaluation ---\")\n", + " print(f\"Loading model trained on classes: {saved_classes}\")\n", + "\n", + " te = LabelEncoder()\n", + " te.classes_ = saved_classes\n", + "\n", + " test_df = test_df.copy()\n", + "\n", + " y_true = None\n", + " if target_col in test_df.columns:\n", + " # Filter out unseen labels from the test set\n", + " seen_mask = test_df[target_col].isin(saved_classes)\n", + " if not seen_mask.all():\n", + " print(f\"[Warning] Found {sum(~seen_mask)} samples with labels not seen during training. These will be ignored for metrics.\")\n", + " test_df = test_df[seen_mask].reset_index(drop=True)\n", + "\n", + " if not test_df.empty:\n", + " y_true = te.transform(test_df[target_col].astype(str))\n", + " test_df[target_col] = y_true\n", + " else:\n", + " print(\"Test set is empty after filtering unseen labels.\")\n", + "\n", + " categorical_cols = getattr(tab_pre, \"categorical_cols\", [])\n", + " numeric_cols = getattr(tab_pre, \"numeric_cols\", [])\n", + " card = tab_pre.get_cardinalities()\n", + "\n", + " if device is None:\n", + " device = torch.device('cuda' if torch.cuda.is_available() and saved_config.get('device') == 'cuda' else 'cpu')\n", + "\n", + " model = MultiModalModel(saved_config, card, numeric_dim=len(numeric_cols), num_classes=len(saved_classes))\n", + " model.load_state_dict(state_dict)\n", + " model.to(device)\n", + "\n", + " pin = device.type == 'cuda'\n", + " num_workers = min(4, saved_config.get('num_workers', 0)) if pin else 0\n", + " test_ds = MultiModalDiamondDataset(test_df, tab_pre, categorical_cols, numeric_cols, target_col, img_size=saved_config.get('img_size', 128), train=False)\n", + " test_loader = DataLoader(test_ds, batch_size=max(1, saved_config.get('batch_size', 1)),\n", + " shuffle=False, num_workers=num_workers, pin_memory=pin)\n", + "\n", + " model.eval()\n", + " all_preds = []\n", + " with torch.no_grad():\n", + " for batch in tqdm(test_loader, desc=\"Test\"):\n", + " imgs = batch['image'].to(device, non_blocking=True)\n", + " cat = batch['cat'].to(device, non_blocking=True) if batch['cat'].numel() else torch.empty((imgs.size(0), 0), dtype=torch.long, device=device)\n", + " num = batch['num'].to(device, non_blocking=True) if batch['num'].numel() else torch.empty((imgs.size(0), 0), dtype=torch.float32, device=device)\n", + "\n", + " logits = model(imgs, cat, num)\n", + " preds = logits.argmax(dim=1).cpu().numpy()\n", + " all_preds.append(preds)\n", + "\n", + " preds = np.concatenate(all_preds) if len(all_preds) else np.array([], dtype=int)\n", + " pred_labels = te.inverse_transform(preds) if preds.size else np.array([])\n", + "\n", + " compare_df = test_df.copy()\n", + " compare_df['predicted_label'] = pred_labels\n", + " if 'actual_label' not in compare_df.columns and y_true is not None:\n", + " compare_df['actual_label'] = te.inverse_transform(y_true)\n", + "\n", + " plots = {}\n", + " if y_true is not None and len(y_true):\n", + " acc = accuracy_score(y_true, preds)\n", + " cm = confusion_matrix(y_true, preds, labels=np.arange(len(saved_classes)))\n", + "\n", + " fig_cm, ax = plt.subplots(figsize=(8, 8))\n", + " im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)\n", + " 
ax.set_title('Confusion Matrix')\n", + " fig_cm.colorbar(im)\n", + " tick_marks = np.arange(len(saved_classes))\n", + " ax.set_xticks(tick_marks)\n", + " ax.set_xticklabels(saved_classes, rotation=90)\n", + " ax.set_yticks(tick_marks)\n", + " ax.set_yticklabels(saved_classes)\n", + " ax.set_ylabel('True label')\n", + " ax.set_xlabel('Predicted label')\n", + " plt.tight_layout()\n", + " cm_path = os.path.join(save_dir, 'confusion_matrix.png')\n", + " fig_cm.savefig(cm_path, dpi=150)\n", + " plt.close(fig_cm)\n", + " plots['confusion_matrix'] = cm_path\n", + "\n", + " print(f\"Final Test Accuracy: {acc:.4f}\")\n", + "\n", + " if save_preds_csv:\n", + " preds_csv = os.path.join(save_dir, 'predictions_with_actuals.csv')\n", + " compare_df.to_csv(preds_csv, index=False)\n", + " plots['preds_csv'] = preds_csv\n", + "\n", + " return pred_labels, float(acc), cm, plots\n", + "\n", + " if save_preds_csv:\n", + " preds_csv = os.path.join(save_dir, 'predictions.csv')\n", + " compare_df.to_csv(preds_csv, index=False)\n", + " plots['preds_csv'] = preds_csv\n", + "\n", + " print(\"No true labels available. Returning predictions only.\")\n", + " return pred_labels, plots\n", + "\n", + "# ------------------------\n", + "# Optuna integration\n", + "# ------------------------\n", + "def run_optuna_search(train_df, val_df, n_trials: int = 12, timeout: Optional[int] = None):\n", + " \"\"\"\n", + " Runs Optuna to maximize validation accuracy. Uses outer `CONFIG` as base and returns best_params.\n", + " \"\"\"\n", + " base_cfg = CONFIG.copy()\n", + "\n", + " def objective(trial: optuna.trial.Trial):\n", + " # copy base cfg for this trial\n", + " cfg = copy.deepcopy(base_cfg)\n", + "\n", + " # ---- suggested params ----\n", + " cfg['lr'] = trial.suggest_loguniform(\"lr\", 1e-5, 1e-3)\n", + " cfg['weight_decay'] = trial.suggest_loguniform(\"weight_decay\", 1e-6, 1e-2)\n", + " cfg['tab_emb_dim'] = trial.suggest_categorical(\"tab_emb_dim\", [64, 128, 256])\n", + " cfg['hidden_head_dim'] = trial.suggest_categorical(\"hidden_head_dim\", [128, 256, 512])\n", + "\n", + " # backbone choices (smaller -> faster)\n", + " cfg['backbone_name'] = trial.suggest_categorical(\"backbone_name\", [\n", + " \"vit_tiny_patch16_224\",\n", + " \"vit_small_patch16_224\",\n", + " \"vit_base_patch16_224\"\n", + " ])\n", + "\n", + " # image size choices: smaller sizes for tiny/small\n", + " cfg['img_size'] = trial.suggest_categorical(\"img_size\", [96, 128, 160])\n", + " # batch size constrained to small numbers (avoid OOM)\n", + " cfg['batch_size'] = int(trial.suggest_categorical(\"batch_size\", [1, 2, 4]))\n", + " cfg['freeze_backbone'] = trial.suggest_categorical(\"freeze_backbone\", [True, False])\n", + " cfg['num_fusion_layers'] = trial.suggest_int(\"num_fusion_layers\", 1, 3)\n", + " cfg['num_attention_heads'] = trial.suggest_categorical(\"num_attention_heads\", [4, 8])\n", + " cfg['accum_steps'] = int(trial.suggest_categorical(\"accum_steps\", [1, 2, 4, 8]))\n", + "\n", + " # keep other keys\n", + " cfg['pretrained'] = base_cfg.get('pretrained', True)\n", + " cfg['use_amp'] = base_cfg.get('use_amp', False) and torch.cuda.is_available()\n", + " cfg['device'] = 'cuda' if (torch.cuda.is_available() and cfg['use_amp']) else 'cpu'\n", + "\n", + " # Print a short summary for debugging\n", + " print(f\"[trial {trial.number}] cfg lr={cfg['lr']:.2e} wd={cfg['weight_decay']:.2e} bs={cfg['batch_size']} img={cfg['img_size']} backbone={cfg['backbone_name']} freeze={cfg['freeze_backbone']} accum={cfg['accum_steps']}\")\n", + "\n", + " # Ensure 
deterministic-ish behavior per trial\n", + " torch.manual_seed(cfg['seed']); np.random.seed(cfg['seed']); random.seed(cfg['seed'])\n", + "\n", + " # run train pipeline for this trial and return best val acc\n", + " try:\n", + " best_acc = train_pipeline(train_df.copy(), val_df.copy(), target_col='colour', cfg=cfg)\n", + " finally:\n", + " # free GPU memory between trials\n", + " if torch.cuda.is_available():\n", + " torch.cuda.empty_cache()\n", + "\n", + " # optuna tries to maximize validation acc\n", + " return float(best_acc)\n", + "\n", + " study = optuna.create_study(direction=\"maximize\", sampler=optuna.samplers.TPESampler())\n", + " study.optimize(objective, n_trials=n_trials, timeout=timeout)\n", + "\n", + " print(\"Optuna best trial:\")\n", + " print(study.best_trial.params)\n", + " return study\n", + "\n", + "# ------------------------\n", + "# Run example / entrypoint\n", + "# ------------------------\n", + "if __name__ == '__main__':\n", + " # 1. Split the data once\n", + " train_df, val_df = safe_train_val_split(df.copy(), 'colour', test_size=0.2, seed=CONFIG['seed'])\n", + "\n", + " # 2. Run Optuna search (adjust n_trials as you like)\n", + " N_TRIALS = 2\n", + " print(f\"\\n--- Starting Optuna search for {N_TRIALS} trials ---\")\n", + " study = run_optuna_search(train_df, val_df, n_trials=N_TRIALS)\n", + "\n", + " # 3. Run final training using best params (recreate final config)\n", + " best_params = study.best_trial.params\n", + " final_cfg = CONFIG.copy()\n", + " # map best params into final_cfg with safe casting\n", + " for k, v in best_params.items():\n", + " final_cfg[k] = v\n", + "\n", + " # ensure use_amp and device sensible\n", + " final_cfg['use_amp'] = final_cfg.get('use_amp', False) and torch.cuda.is_available()\n", + " final_cfg['device'] = 'cuda' if (torch.cuda.is_available() and final_cfg['use_amp']) else 'cpu'\n", + " # keep original epochs value already in CONFIG\n", + "\n", + " print(\"\\n--- Training final model with best Optuna params ---\")\n", + " best_acc = train_pipeline(train_df.copy(), val_df.copy(), target_col='colour', cfg=final_cfg)\n", + " print(\"Final training completed. Best validation acc:\", best_acc)\n", + "\n", + " # 4. 
Test/validate on the validation set as a quick sanity check\n", + " print(\"\\n--- Running Test on Validation Set ---\")\n", + " if os.path.exists(\"best_multimodal.pth\"):\n", + " test_output = test_model(val_df, save_preds_csv=True)\n", + " if test_output:\n", + " if len(test_output) == 4:\n", + " preds, acc, cm, plots = test_output\n", + " print(f\"Test Accuracy: {acc}\")\n", + " if 'confusion_matrix' in plots:\n", + " print(f\"Confusion matrix saved to {plots['confusion_matrix']}\")\n", + " if 'preds_csv' in plots:\n", + " print(f\"Predictions saved to {plots['preds_csv']}\")\n", + " print(pd.read_csv(plots['preds_csv']).head())\n", + " else:\n", + " preds, plots = test_output\n", + " print(\"Test completed without ground truth labels.\")\n", + " else:\n", + " print(\"Training did not complete successfully, skipping test.\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "1dHrbwAwOORj", + "trusted": true + }, + "outputs": [], + "source": [ + "def test_model(test_df: pd.DataFrame, target_col: str = 'colour', model_path: str = \"best_multimodal.pth\", device: Optional[torch.device] = None, save_preds_csv: bool = True, top_n_mismatch: int = 50, save_dir: str = \"/kaggle/working/\"):\n", + " import matplotlib.pyplot as plt\n", + " os.makedirs(save_dir, exist_ok=True)\n", + "\n", + " if not os.path.exists(model_path):\n", + " print(f\"Error: Model file '{model_path}' not found. Please train the model first.\")\n", + " return None\n", + "\n", + " checkpoint = torch.load(model_path, map_location='cpu')\n", + " state_dict = checkpoint[\"state_dict\"]\n", + " saved_classes = checkpoint[\"classes\"]\n", + " saved_config = checkpoint[\"config\"]\n", + " tab_pre = checkpoint[\"tab_pre\"]\n", + "\n", + " print(\"--- Starting Test Evaluation ---\")\n", + " print(f\"Loading model trained on classes: {saved_classes}\")\n", + "\n", + " te = LabelEncoder()\n", + " te.classes_ = saved_classes\n", + "\n", + " test_df = test_df.copy()\n", + "\n", + " y_true = None\n", + " if target_col in test_df.columns:\n", + " # Filter out unseen labels from the test set\n", + " seen_mask = test_df[target_col].isin(saved_classes)\n", + " if not seen_mask.all():\n", + " print(f\"[Warning] Found {sum(~seen_mask)} samples with labels not seen during training. 
These will be ignored for metrics.\")\n", + " test_df = test_df[seen_mask].reset_index(drop=True)\n", + "\n", + " if not test_df.empty:\n", + " y_true = te.transform(test_df[target_col].astype(str))\n", + " test_df[target_col] = y_true\n", + " else:\n", + " print(\"Test set is empty after filtering unseen labels.\")\n", + "\n", + " categorical_cols = getattr(tab_pre, \"categorical_cols\", [])\n", + " numeric_cols = getattr(tab_pre, \"numeric_cols\", [])\n", + " card = tab_pre.get_cardinalities()\n", + "\n", + " if device is None:\n", + " device = torch.device('cuda' if torch.cuda.is_available() and saved_config.get('device') == 'cuda' else 'cpu')\n", + "\n", + " model = MultiModalModel(saved_config, card, numeric_dim=len(numeric_cols), num_classes=len(saved_classes))\n", + " model.load_state_dict(state_dict)\n", + " model.to(device)\n", + "\n", + " pin = device.type == 'cuda'\n", + " num_workers = min(4, saved_config.get('num_workers', 0)) if pin else 0\n", + " test_ds = MultiModalDiamondDataset(test_df, tab_pre, categorical_cols, numeric_cols, target_col, img_size=saved_config.get('img_size', 128), train=False)\n", + " test_loader = DataLoader(test_ds, batch_size=max(1, saved_config.get('batch_size', 1)),\n", + " shuffle=False, num_workers=num_workers, pin_memory=pin)\n", + "\n", + " model.eval()\n", + " all_preds = []\n", + " with torch.no_grad():\n", + " for batch in tqdm(test_loader, desc=\"Test\"):\n", + " imgs = batch['image'].to(device, non_blocking=True)\n", + " cat = batch['cat'].to(device, non_blocking=True) if batch['cat'].numel() else torch.empty((imgs.size(0), 0), dtype=torch.long, device=device)\n", + " num = batch['num'].to(device, non_blocking=True) if batch['num'].numel() else torch.empty((imgs.size(0), 0), dtype=torch.float32, device=device)\n", + "\n", + " logits = model(imgs, cat, num)\n", + " preds = logits.argmax(dim=1).cpu().numpy()\n", + " all_preds.append(preds)\n", + "\n", + " preds = np.concatenate(all_preds) if len(all_preds) else np.array([], dtype=int)\n", + " pred_labels = te.inverse_transform(preds) if preds.size else np.array([])\n", + "\n", + " compare_df = test_df.copy()\n", + " compare_df['predicted_label'] = pred_labels\n", + " if 'actual_label' not in compare_df.columns and y_true is not None:\n", + " compare_df['actual_label'] = te.inverse_transform(y_true)\n", + "\n", + " plots = {}\n", + " if y_true is not None and len(y_true):\n", + " acc = accuracy_score(y_true, preds)\n", + " cm = confusion_matrix(y_true, preds, labels=np.arange(len(saved_classes)))\n", + "\n", + " fig_cm, ax = plt.subplots(figsize=(8, 8))\n", + " im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)\n", + " ax.set_title('Confusion Matrix')\n", + " fig_cm.colorbar(im)\n", + " tick_marks = np.arange(len(saved_classes))\n", + " ax.set_xticks(tick_marks)\n", + " ax.set_xticklabels(saved_classes, rotation=90)\n", + " ax.set_yticks(tick_marks)\n", + " ax.set_yticklabels(saved_classes)\n", + " ax.set_ylabel('True label')\n", + " ax.set_xlabel('Predicted label')\n", + " plt.tight_layout()\n", + " cm_path = os.path.join(save_dir, 'confusion_matrix.png')\n", + " fig_cm.savefig(cm_path, dpi=150)\n", + " plt.close(fig_cm)\n", + " plots['confusion_matrix'] = cm_path\n", + "\n", + " print(f\"Final Test Accuracy: {acc:.4f}\")\n", + "\n", + " if save_preds_csv:\n", + " preds_csv = os.path.join(save_dir, 'predictions_with_actuals.csv')\n", + " compare_df.to_csv(preds_csv, index=False)\n", + " plots['preds_csv'] = preds_csv\n", + "\n", + " return pred_labels, float(acc), cm, plots\n", + 
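" # No usable ground-truth labels at this point: optionally save the raw predictions and return them with the artifact paths.\n", + 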
"\n", + " if save_preds_csv:\n", + " preds_csv = os.path.join(save_dir, 'predictions.csv')\n", + " compare_df.to_csv(preds_csv, index=False)\n", + " plots['preds_csv'] = preds_csv\n", + "\n", + " print(\"No true labels available. Returning predictions only.\")\n", + " return pred_labels, plots\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "lea83G-9OORk", + "trusted": true + }, + "outputs": [], + "source": [ + "train_df, val_df = safe_train_val_split(df.copy(), 'colour', test_size=0.2, seed=CONFIG['seed'])\n", + "\n", + "# 4. Test/validate on the validation set as a quick sanity check\n", + "print(\"\\n--- Running Test on Validation Set ---\")\n", + "if os.path.exists(\"best_multimodal.pth\"):\n", + " test_output = test_model(val_df, save_preds_csv=True)\n", + " if test_output:\n", + " if len(test_output) == 4:\n", + " preds, acc, cm, plots = test_output\n", + " print(f\"Test Accuracy: {acc}\")\n", + " if 'confusion_matrix' in plots:\n", + " print(f\"Confusion matrix saved to {plots['confusion_matrix']}\")\n", + " if 'preds_csv' in plots:\n", + " print(f\"Predictions saved to {plots['preds_csv']}\")\n", + " print(pd.read_csv(plots['preds_csv']).head())\n", + " else:\n", + " preds, plots = test_output\n", + " print(\"Test completed without ground truth labels.\")\n", + "else:\n", + " print(\"Training did not complete successfully, skipping test.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2025-09-05T12:30:32.026518Z", + "iopub.status.busy": "2025-09-05T12:30:32.025622Z", + "iopub.status.idle": "2025-09-05T12:30:32.062408Z", + "shell.execute_reply": "2025-09-05T12:30:32.061368Z", + "shell.execute_reply.started": "2025-09-05T12:30:32.026483Z" + }, + "id": "qEbIf4slOORk", + "trusted": true + }, + "outputs": [], + "source": [ + "dx = pd.read_csv(plots['preds_csv']).head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2025-09-05T12:30:55.857563Z", + "iopub.status.busy": "2025-09-05T12:30:55.857172Z", + "iopub.status.idle": "2025-09-05T12:30:55.864541Z", + "shell.execute_reply": "2025-09-05T12:30:55.863535Z", + "shell.execute_reply.started": "2025-09-05T12:30:55.857534Z" + }, + "id": "isuXpflWOORk", + "outputId": "4f2f5386-1095-45e6-f619-14a86b474ef8", + "trusted": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "'path_to_img,stock_number,shape,carat,clarity,colour,cut,polish,symmetry,fluorescence,lab,length,width,depth,full_path_to_img,actual_label,predicted_label\\nweb_scraped/emerald/2132934.jpg,2132934,emerald,0.53,VVS2,3,EX,EX,VG,N,GIA,5.23,3.86,2.63,/kaggle/input/diamond-images-dataset/web_scraped/emerald/2132934.jpg,E,F\\nweb_scraped/princess/2101219.jpg,2101219,princess,0.5,VS1,3,EX,EX,VG,N,GIA,4.34,4.25,3.04,/kaggle/input/diamond-images-dataset/web_scraped/princess/2101219.jpg,E,E\\nweb_scraped/round/2127275.jpg,2127275,round,0.53,VS1,1,EX,EX,EX,N,GIA,5.26,5.28,3.18,/kaggle/input/diamond-images-dataset/web_scraped/round/2127275.jpg,D,F\\nweb_scraped/princess/223350-231.jpg,223350-231,princess,1.2,VS2,3,GD,EX,VG,N,GIA,5.64,5.47,4.29,/kaggle/input/diamond-images-dataset/web_scraped/princess/223350-231.jpg,E,H\\nweb_scraped/round/2087662.jpg,2087662,round,0.7,SI1,12,VG,VG,VG,N,GIA,5.54,5.58,3.58,/kaggle/input/diamond-images-dataset/web_scraped/round/2087662.jpg,K,G\\n'" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + 
"dx.to_csv(index=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "gUcj4gprOORk", + "trusted": true + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2025-09-05T12:20:27.240393Z", + "iopub.status.busy": "2025-09-05T12:20:27.239472Z", + "iopub.status.idle": "2025-09-05T12:20:27.348281Z", + "shell.execute_reply": "2025-09-05T12:20:27.347109Z", + "shell.execute_reply.started": "2025-09-05T12:20:27.240357Z" + }, + "id": "7I1VZAuPOORk", + "outputId": "af0cd715-0860-4c83-dd13-8c9b6270c17d", + "trusted": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " path_to_img stock_number shape carat clarity \\\n", + "0 web_scraped/emerald/2132934.jpg 2132934 emerald 0.53 VVS2 \n", + "1 web_scraped/princess/2101219.jpg 2101219 princess 0.50 VS1 \n", + "2 web_scraped/round/2127275.jpg 2127275 round 0.53 VS1 \n", + "3 web_scraped/princess/223350-231.jpg 223350-231 princess 1.20 VS2 \n", + "4 web_scraped/round/2087662.jpg 2087662 round 0.70 SI1 \n", + "\n", + " colour cut polish symmetry fluorescence lab length width depth \\\n", + "0 3 EX EX VG N GIA 5.23 3.86 2.63 \n", + "1 3 EX EX VG N GIA 4.34 4.25 3.04 \n", + "2 1 EX EX EX N GIA 5.26 5.28 3.18 \n", + "3 3 GD EX VG N GIA 5.64 5.47 4.29 \n", + "4 12 VG VG VG N GIA 5.54 5.58 3.58 \n", + "\n", + " full_path_to_img actual_label \\\n", + "0 /kaggle/input/diamond-images-dataset/web_scrap... E \n", + "1 /kaggle/input/diamond-images-dataset/web_scrap... E \n", + "2 /kaggle/input/diamond-images-dataset/web_scrap... D \n", + "3 /kaggle/input/diamond-images-dataset/web_scrap... E \n", + "4 /kaggle/input/diamond-images-dataset/web_scrap... K \n", + "\n", + " predicted_label \n", + "0 F \n", + "1 E \n", + "2 F \n", + "3 H \n", + "4 G \n", + " img_path actual predicted\n", + "0 /kaggle/input/diamond-images-dataset/web_scrap... E F\n", + "1 /kaggle/input/diamond-images-dataset/web_scrap... D F\n", + "2 /kaggle/input/diamond-images-dataset/web_scrap... E H\n", + "3 /kaggle/input/diamond-images-dataset/web_scrap... K G\n", + "4 /kaggle/input/diamond-images-dataset/web_scrap... 
J K\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/bin/xdg-open: 882: www-browser: not found\n", + "/usr/bin/xdg-open: 882: links2: not found\n", + "/usr/bin/xdg-open: 882: elinks: not found\n", + "/usr/bin/xdg-open: 882: links: not found\n", + "/usr/bin/xdg-open: 882: lynx: not found\n", + "/usr/bin/xdg-open: 882: w3m: not found\n", + "xdg-open: no method available for opening '/tmp/tmp6kxhi93z.PNG'\n", + "/usr/bin/xdg-open: 882: www-browser: not found\n", + "/usr/bin/xdg-open: 882: links2: not found\n", + "/usr/bin/xdg-open: 882: elinks: not found\n", + "/usr/bin/xdg-open: 882: links: not found\n", + "/usr/bin/xdg-open: 882: lynx: not found\n", + "/usr/bin/xdg-open: 882: w3m: not found\n", + "xdg-open: no method available for opening '/tmp/tmpvgsg0cps.PNG'\n" + ] + } + ], + "source": [ + " #--- Open plots ---\n", + "from PIL import Image\n", + "import pandas as pd\n", + "\n", + "# Show confusion matrix\n", + "Image.open(plots['confusion_matrix']).show()\n", + "\n", + "# Show actual vs predicted counts\n", + "Image.open(plots['actual_vs_pred_counts']).show()\n", + "\n", + "# --- Load CSVs ---\n", + "preds_df = pd.read_csv(plots['preds_csv'])\n", + "print(preds_df.head())\n", + "\n", + "mismatches_df = pd.read_csv(plots['mismatches_csv'])\n", + "print(mismatches_df.head())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2025-09-03T12:32:26.684309Z", + "iopub.status.busy": "2025-09-03T12:32:26.684080Z", + "iopub.status.idle": "2025-09-03T12:32:26.694114Z", + "shell.execute_reply": "2025-09-03T12:32:26.693489Z", + "shell.execute_reply.started": "2025-09-03T12:32:26.684290Z" + }, + "id": "0wfTVGZiOORk", + "trusted": true + }, + "outputs": [], + "source": [ + "# # look-up table\n", + "# classes = {0:'cushion', 1 : 'emerald', 2 : 'heart', 3 : 'marquise', 4 : 'oval', 5: 'pear', 6: 'princess', 7: 'round'}\n", + "\n", + "\n", + "# # a function for encoding classes\n", + "# def create_class(X):\n", + "# if X == 'cushion':\n", + "# return 0\n", + "# elif X =='emerald':\n", + "# return 1\n", + "# elif X == 'heart':\n", + "# return 2\n", + "# elif X == 'marquise':\n", + "# return 3\n", + "# elif X == 'oval':\n", + "# return 4\n", + "# elif X == 'pear':\n", + "# return 5\n", + "# elif X == 'princess':\n", + "# return 6\n", + "# elif X == 'round':\n", + "# return 7\n", + "# else:\n", + "# print('error class')\n", + "\n", + "\n", + "# # Encoding classes\n", + "# data['encoded_class'] = data['shape'].apply(create_class)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2025-09-03T12:32:26.695662Z", + "iopub.status.busy": "2025-09-03T12:32:26.695086Z", + "iopub.status.idle": "2025-09-03T12:32:26.784834Z", + "shell.execute_reply": "2025-09-03T12:32:26.784244Z", + "shell.execute_reply.started": "2025-09-03T12:32:26.695633Z" + }, + "id": "Ne8BKMPzOORk", + "trusted": true + }, + "outputs": [], + "source": [ + "# look-up table for colour categories\n", + "colour_classes = {\n", + " 0: 'N',\n", + " 1: 'Y-Z',\n", + " 2: 'L',\n", + " 3: 'M',\n", + " 4: 'K',\n", + " 5: 'S-T',\n", + " 6: 'W-X',\n", + " 7: 'U-V',\n", + " 8: 'J',\n", + " 9: 'FANCY',\n", + " 10: 'G',\n", + " 11: 'I',\n", + " 12: 'H',\n", + " 13: 'E',\n", + " 14: 'D',\n", + " 15: 'F',\n", + " 16: 'O-P',\n", + " 17: 'Q-R',\n", + " 18: 'BLUE',\n", + " 19: 'V:B',\n", + " 20: 'FC:P',\n", + " 21: 'D:P:BN',\n", + " 22: 'I:P'\n", + "}\n", + "\n", + "# function for encoding 
colours\n", + "def create_colour_class(X):\n", + " if X in colour_classes.values():\n", + " return list(colour_classes.keys())[list(colour_classes.values()).index(X)]\n", + " else:\n", + " print('error: colour not found')\n", + "\n", + "# Encoding colour categories\n", + "data['encoded_colour'] = data['colour'].apply(create_colour_class)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_1t2vffLOORk" + }, + "source": [ + "# Creating Train, Validation, Test Sets" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2025-09-03T12:32:26.786124Z", + "iopub.status.busy": "2025-09-03T12:32:26.785832Z", + "iopub.status.idle": "2025-09-03T12:32:26.789861Z", + "shell.execute_reply": "2025-09-03T12:32:26.789037Z", + "shell.execute_reply.started": "2025-09-03T12:32:26.786098Z" + }, + "id": "eZuLr5AFOORk", + "trusted": true + }, + "outputs": [], + "source": [ + "# y = data.pop('encoded_colour')\n", + "# X = data\n", + "\n", + "# X_data, X_test, y_data, y_test = train_test_split(X,y, test_size = 0.1, stratify = y , random_state=SEED, shuffle=True)\n", + "# X_train, X_val, y_train, y_val = train_test_split(X_data, y_data, test_size = 0.1, stratify = y_data, random_state=SEED, shuffle=True)\n", + "\n", + "\n", + "# print(\"train shape -> \", X_train.shape[0])\n", + "# print(\"val shape -> \", X_val.shape[0])\n", + "# print(\"test shape -> \", X_test.shape[0])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2025-09-03T12:32:26.792702Z", + "iopub.status.busy": "2025-09-03T12:32:26.792499Z", + "iopub.status.idle": "2025-09-03T12:32:26.845117Z", + "shell.execute_reply": "2025-09-03T12:32:26.844249Z", + "shell.execute_reply.started": "2025-09-03T12:32:26.792684Z" + }, + "id": "fnakGGb0OORk", + "outputId": "cbe57e94-f730-4275-fbaf-aa8116c11bb3", + "trusted": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Columns: ['path_to_img', 'stock_number', 'shape', 'carat', 'clarity', 'colour', 'cut', 'polish', 'symmetry', 'fluorescence', 'lab', 'length', 'width', 'depth', 'full_path_to_img', 'encoded_colour']\n", + "Class counts:\n", + " encoded_colour\n", + "10 7674\n", + "15 6957\n", + "12 6555\n", + "13 6239\n", + "14 5824\n", + "11 5304\n", + "8 4273\n", + "4 2627\n", + "2 1388\n", + "3 754\n", + "0 417\n", + "9 328\n", + "16 134\n", + "17 80\n", + "7 78\n", + "5 68\n", + "6 36\n", + "1 23\n", + "18 1\n", + "19 1\n", + "20 1\n", + "21 1\n", + "22 1\n", + "Name: count, dtype: int64\n", + "train shape -> 39498\n", + "val shape -> 4389\n", + "test shape -> 4877\n" + ] + } + ], + "source": [ + "from sklearn.model_selection import train_test_split\n", + "\n", + "SEED = 42 # or whatever you use\n", + "\n", + "# 1) Inspect available columns (quick debug)\n", + "print(\"Columns:\", data.columns.tolist())\n", + "\n", + "# 2) If you previously made a mapping like this (int -> label), define it:\n", + "colour_classes = {\n", + " 0: 'N', 1: 'Y-Z', 2: 'L', 3: 'M', 4: 'K', 5: 'S-T', 6: 'W-X',\n", + " 7: 'U-V', 8: 'J', 9: 'FANCY', 10: 'G', 11: 'I', 12: 'H',\n", + " 13: 'E', 14: 'D', 15: 'F', 16: 'O-P', 17: 'Q-R', 18: 'BLUE',\n", + " 19: 'V:B', 20: 'FC:P', 21: 'D:P:BN', 22: 'I:P'\n", + "}\n", + "\n", + "# 3) Build reverse mapping: label -> int\n", + "colour_to_int = {v: k for k, v in colour_classes.items()}\n", + "\n", + "# 4) Create 'encoded_colour' column (if it 
doesn't exist or to recreate it)\n", + "# Assumes the original textual column is named 'colour' (adjust if your name differs).\n", + "if 'encoded_colour' not in data.columns:\n", + " if 'colour' not in data.columns:\n", + " raise KeyError(\"Column 'colour' not found. Adjust the source column name.\")\n", + " data['encoded_colour'] = data['colour'].map(colour_to_int)\n", + "\n", + "# 5) Check for unmapped categories (diagnose)\n", + "if data['encoded_colour'].isnull().any():\n", + " unmapped = data.loc[data['encoded_colour'].isnull(), 'colour'].unique()\n", + " print(\"Unmapped colour categories found:\", unmapped)\n", + " # Optionally handle them:\n", + " # data['encoded_colour'] = data['encoded_colour'].fillna(-1).astype(int)\n", + " # Or map them to an 'OTHER' code and update colour_to_int accordingly.\n", + "\n", + "# 6) Safer splitting: avoid pop so you can re-run cells without losing the column\n", + "y = data['encoded_colour'].copy()\n", + "X = data.drop(columns=['encoded_colour'])\n", + "\n", + "print(\"Class counts:\\n\", y.value_counts())\n", + "\n", + "# If you don't need stratification (because of very rare classes), do:\n", + "X_data, X_test, y_data, y_test = train_test_split(\n", + " X, y, test_size=0.1, random_state=SEED, shuffle=True\n", + ")\n", + "\n", + "X_train, X_val, y_train, y_val = train_test_split(\n", + " X_data, y_data, test_size=0.1, random_state=SEED, shuffle=True\n", + ")\n", + "\n", + "print(\"train shape -> \", X_train.shape[0])\n", + "print(\"val shape -> \", X_val.shape[0])\n", + "print(\"test shape -> \", X_test.shape[0])\n", + "\n", + "# 7) If you want to stratify but have rare classes, consider merging rare classes:\n", + "# counts = y.value_counts()\n", + "# rare = counts[counts < 2].index.tolist()\n", + "# if rare:\n", + "# print(\"Merging rare classes:\", rare)\n", + "# y_replaced = y.replace(rare, -1) # -1 for OTHER; update mapping if needed\n", + "# # then use stratify=y_replaced in train_test_split\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mgIMB9DUOORl" + }, + "source": [ + "# Creating tf.data Pipeline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2025-09-03T12:32:26.846835Z", + "iopub.status.busy": "2025-09-03T12:32:26.846212Z", + "iopub.status.idle": "2025-09-03T12:32:26.852610Z", + "shell.execute_reply": "2025-09-03T12:32:26.851687Z", + "shell.execute_reply.started": "2025-09-03T12:32:26.846803Z" + }, + "id": "UtqeGO3vOORl", + "trusted": true + }, + "outputs": [], + "source": [ + "# Reading -> Resizing -> Normalization\n", + "def img_preprocessing(image, label):\n", + " img = tf.io.read_file(image)\n", + " img = tf.io.decode_jpeg(img, channels = 3)\n", + " img = tf.image.resize(img, size = (IMG_SIZE))\n", + " img = tf.cast(img, tf.float32) / 255.0\n", + "\n", + " return img, label\n", + "\n", + "\n", + "# Data augmentation\n", + "def augmentation(image, label):\n", + " img = tf.image.random_flip_left_right(image, seed = SEED)\n", + " img = tf.image.random_flip_up_down(img, seed = SEED)\n", + " img = tf.image.random_brightness(img, 0.1, seed = SEED)\n", + " img = tf.image.random_contrast(img, 0.2, 0.4, seed = SEED)\n", + " img = tf.image.random_saturation(img, 2, 6, seed = SEED)\n", + "\n", + " return img, label" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2025-09-03T12:32:26.854237Z", + "iopub.status.busy": "2025-09-03T12:32:26.853841Z", + "iopub.status.idle": 
"2025-09-03T12:32:27.127235Z", + "shell.execute_reply": "2025-09-03T12:32:27.126311Z", + "shell.execute_reply.started": "2025-09-03T12:32:26.854201Z" + }, + "id": "NTAUveiPOORl", + "trusted": true + }, + "outputs": [], + "source": [ + "# Creating dataset loaders and tf.datasets\n", + "\n", + "train_loader = tf.data.Dataset.from_tensor_slices((X_train['full_path_to_img'], y_train))\n", + "train_dataset = (train_loader\n", + " .map(img_preprocessing, num_parallel_calls = AUTO)\n", + " .map(augmentation, num_parallel_calls = AUTO)\n", + " .shuffle(BATCH_SIZE * 10)\n", + " .batch(BATCH_SIZE)\n", + " .prefetch(AUTO))\n", + "\n", + "\n", + "# Training dataset without shuffling and data augmantation operations for the classification stage\n", + "train_loader_feature = tf.data.Dataset.from_tensor_slices((X_train['full_path_to_img'], y_train))\n", + "train_dataset_feature = (train_loader_feature\n", + " .map(img_preprocessing, num_parallel_calls = AUTO)\n", + " .batch(BATCH_SIZE)\n", + " .prefetch(AUTO))\n", + "\n", + "\n", + "valid_loader = tf.data.Dataset.from_tensor_slices((X_val['full_path_to_img'], y_val))\n", + "valid_dataset = (valid_loader\n", + " .map(img_preprocessing, num_parallel_calls = AUTO)\n", + " .batch(BATCH_SIZE)\n", + " .prefetch(AUTO))\n", + "\n", + "\n", + "test_loader = tf.data.Dataset.from_tensor_slices((X_test['full_path_to_img'], y_test))\n", + "test_dataset = (test_loader\n", + " .map(img_preprocessing, num_parallel_calls = AUTO)\n", + " .batch(BATCH_SIZE)\n", + " .prefetch(AUTO))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "r8ibYqOVOORl" + }, + "source": [ + "# Feature Extraction " + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "krAK0tj_OORl" + }, + "source": [ + "# Custom ViT Model Feature Extractor " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2025-09-03T10:00:50.807579Z", + "iopub.status.busy": "2025-09-03T10:00:50.806597Z", + "iopub.status.idle": "2025-09-03T10:00:50.811569Z", + "shell.execute_reply": "2025-09-03T10:00:50.810713Z", + "shell.execute_reply.started": "2025-09-03T10:00:50.807545Z" + }, + "id": "pdtm9zW6OORl", + "trusted": true + }, + "outputs": [], + "source": [ + "# from vit_keras import vit\n", + "\n", + "# with stg.scope():\n", + "# vit_model = vit.build_model(image_size = IMG_SIZE, patch_size = 5, activation = 'softmax', include_top = False,\n", + "# classes = 8, num_layers = 5, hidden_size = 128, mlp_dim = 128, num_heads = 2, name = 'my_vit_model')\n", + "\n", + "# inp = Input(shape = (*IMG_SIZE, 3))\n", + "# vit = vit_model(inp)\n", + "# X = Flatten()(vit)\n", + "# X = Dense(64, activation = 'gelu', name = 'the_feature_layer')(X)\n", + "# X = Dense(32, activation = 'gelu')(X)\n", + "# out = Dense(8, activation = 'softmax')(X)\n", + "\n", + "# model = Model(inputs = inp, outputs = out)\n", + "# model.summary()\n", + "\n", + "# model.compile(optimizer = tf.keras.optimizers.AdamW(learning_rate = 0.0001,weight_decay = 0.0001),\n", + "# loss = tf.keras.losses.SparseCategoricalCrossentropy(),\n", + "# metrics = ['acc',tf.keras.metrics.SparseTopKCategoricalAccuracy(k = 4, name = \"top_4_acc\", dtype=None) ] )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2025-09-03T10:01:24.460049Z", + "iopub.status.busy": "2025-09-03T10:01:24.459251Z", + "iopub.status.idle": "2025-09-03T10:01:24.466297Z", + "shell.execute_reply": "2025-09-03T10:01:24.465483Z", + 
"shell.execute_reply.started": "2025-09-03T10:01:24.460016Z" + }, + "id": "rnJWISJOOORs", + "outputId": "6ff6a9c4-d20a-41ac-8311-c11d7fc0f130", + "trusted": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "keras -> \n", + "has keras.ops? False\n", + "tensorflow version: 2.13.0\n" + ] + } + ], + "source": [ + "# # CELL 0 - Run this first, and then run other cells\n", + "# import sys\n", + "# import tensorflow as tf\n", + "# import tensorflow.keras as keras_tf\n", + "\n", + "# # Redirect 'keras' imports to tf.keras for this session\n", + "# sys.modules['keras'] = keras_tf\n", + "# sys.modules['keras.utils'] = keras_tf.utils\n", + "# sys.modules['keras.layers'] = keras_tf.layers\n", + "# sys.modules['keras.backend'] = keras_tf.backend\n", + "# sys.modules['keras.initializers'] = keras_tf.initializers\n", + "# sys.modules['keras.applications'] = keras_tf.applications\n", + "# sys.modules['keras.metrics'] = keras_tf.metrics\n", + "# sys.modules['keras.activations'] = keras_tf.activations\n", + "# sys.modules['keras.losses'] = keras_tf.losses\n", + "# sys.modules['keras.optimizers'] = keras_tf.optimizers\n", + "\n", + "# # Optional quick check\n", + "# import keras\n", + "# print(\"keras ->\", keras)\n", + "# print(\"has keras.ops?\", hasattr(keras, \"ops\"))\n", + "# print(\"tensorflow version:\", tf.__version__)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Eeens_ZzOORs" + }, + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2025-09-03T12:32:27.128783Z", + "iopub.status.busy": "2025-09-03T12:32:27.128437Z", + "iopub.status.idle": "2025-09-03T12:32:29.437696Z", + "shell.execute_reply": "2025-09-03T12:32:29.437048Z", + "shell.execute_reply.started": "2025-09-03T12:32:27.128735Z" + }, + "id": "t-ea9GLkOORt", + "outputId": "b0041de4-ba8b-444c-8aef-690db68610f0", + "trusted": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: \"model\"\n", + "_________________________________________________________________\n", + " Layer (type) Output Shape Param # \n", + "=================================================================\n", + " input_2 (InputLayer) [(None, 100, 100, 3)] 0 \n", + " \n", + " my_vit_model (Functional) (None, 128) 559360 \n", + " \n", + " flatten (Flatten) (None, 128) 0 \n", + " \n", + " the_feature_layer (Dense) (None, 64) 8256 \n", + " \n", + " dense (Dense) (None, 32) 2080 \n", + " \n", + " dense_1 (Dense) (None, 8) 264 \n", + " \n", + "=================================================================\n", + "Total params: 569960 (2.17 MB)\n", + "Trainable params: 569960 (2.17 MB)\n", + "Non-trainable params: 0 (0.00 Byte)\n", + "_________________________________________________________________\n" + ] + } + ], + "source": [ + "# Put this at the top of the notebook / cell BEFORE importing vit_keras\n", + "import sys\n", + "import tensorflow as tf\n", + "import tensorflow.keras as keras_tf\n", + "\n", + "# Redirect the \"keras\" module name to tf.keras for this Python session\n", + "sys.modules['keras'] = keras_tf\n", + "sys.modules['keras.utils'] = keras_tf.utils\n", + "sys.modules['keras.layers'] = keras_tf.layers\n", + "sys.modules['keras.backend'] = keras_tf.backend\n", + "sys.modules['keras.initializers'] = keras_tf.initializers\n", + "\n", + "# Now import libraries and build model\n", + "from vit_keras import vit as vit_module\n", + "from tensorflow.keras.layers import Input, Flatten, 
Dense\n", + "from tensorflow.keras.models import Model\n", + "\n", + "with stg.scope():\n", + " vit_model = vit_module.build_model(\n", + " image_size=IMG_SIZE,\n", + " patch_size=5,\n", + " activation='softmax',\n", + " include_top=False,\n", + " classes=8,\n", + " num_layers=5,\n", + " hidden_size=128,\n", + " mlp_dim=128,\n", + " num_heads=2,\n", + " name='my_vit_model'\n", + " )\n", + "\n", + " inp = Input(shape=(*IMG_SIZE, 3))\n", + " vit_out = vit_model(inp) # don't overwrite vit_module\n", + " X = Flatten()(vit_out)\n", + " X = Dense(64, activation='gelu', name='the_feature_layer')(X)\n", + " X = Dense(32, activation='gelu')(X)\n", + " out = Dense(8, activation='softmax')(X)\n", + "\n", + " model = Model(inputs=inp, outputs=out)\n", + " model.summary()\n", + "\n", + " model.compile(\n", + " optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),\n", + " loss=tf.keras.losses.SparseCategoricalCrossentropy(),\n", + " metrics=['acc', tf.keras.metrics.SparseTopKCategoricalAccuracy(k=4, name=\"top_4_acc\")]\n", + " )\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2025-09-03T12:32:29.439446Z", + "iopub.status.busy": "2025-09-03T12:32:29.439179Z", + "iopub.status.idle": "2025-09-03T12:46:25.062513Z", + "shell.execute_reply": "2025-09-03T12:46:25.061803Z", + "shell.execute_reply.started": "2025-09-03T12:32:29.439423Z" + }, + "id": "JXSdrPTKOORt", + "outputId": "91d6841f-4c2b-4f5d-a7f9-cc5152df2504", + "trusted": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/7\n", + "1235/1235 [==============================] - 180s 118ms/step - loss: nan - acc: 0.0087 - top_4_acc: 7.5953e-05 - val_loss: nan - val_acc: 0.0096 - val_top_4_acc: 0.0000e+00\n", + "Epoch 2/7\n", + "1235/1235 [==============================] - 112s 90ms/step - loss: nan - acc: 0.0086 - top_4_acc: 0.0000e+00 - val_loss: nan - val_acc: 0.0096 - val_top_4_acc: 0.0000e+00\n", + "Epoch 3/7\n", + "1235/1235 [==============================] - 111s 89ms/step - loss: nan - acc: 0.0086 - top_4_acc: 0.0000e+00 - val_loss: nan - val_acc: 0.0096 - val_top_4_acc: 0.0000e+00\n", + "Epoch 4/7\n", + "1235/1235 [==============================] - 107s 86ms/step - loss: nan - acc: 0.0086 - top_4_acc: 0.0000e+00 - val_loss: nan - val_acc: 0.0096 - val_top_4_acc: 0.0000e+00\n", + "Epoch 5/7\n", + "1235/1235 [==============================] - 108s 87ms/step - loss: nan - acc: 0.0086 - top_4_acc: 0.0000e+00 - val_loss: nan - val_acc: 0.0096 - val_top_4_acc: 0.0000e+00\n", + "Epoch 6/7\n", + "1235/1235 [==============================] - 108s 87ms/step - loss: nan - acc: 0.0086 - top_4_acc: 0.0000e+00 - val_loss: nan - val_acc: 0.0096 - val_top_4_acc: 0.0000e+00\n", + "Epoch 7/7\n", + "1235/1235 [==============================] - 109s 88ms/step - loss: nan - acc: 0.0086 - top_4_acc: 0.0000e+00 - val_loss: nan - val_acc: 0.0096 - val_top_4_acc: 0.0000e+00\n" + ] + } + ], + "source": [ + "# Training feature extraction model and saved\n", + "\n", + "hist = model.fit(train_dataset, epochs = 7, batch_size = BATCH_SIZE, validation_data = valid_dataset)\n", + "model.save(\"vit_feature_extractor.h5\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2025-09-03T12:50:40.133035Z", + "iopub.status.busy": "2025-09-03T12:50:40.132594Z", + "iopub.status.idle": "2025-09-03T12:50:52.330937Z", + "shell.execute_reply": "2025-09-03T12:50:52.330091Z", + 
"shell.execute_reply.started": "2025-09-03T12:50:40.133008Z" + }, + "id": "Y4yNupn_OORt", + "outputId": "57835d5b-cfa7-4a9b-d751-6f2c586216d1", + "trusted": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ViT model results\n", + "----------------------------------------------------------------------------------------------------\n", + "138/138 [==============================] - 6s 40ms/step - loss: nan - acc: 0.0096 - top_4_acc: 0.0000e+00\n", + "Validation Loss: nan\n", + "Validation Accuracy: 0.957 %\n", + "----------------------------------------------------------------------------------------------------\n", + "153/153 [==============================] - 6s 40ms/step - loss: nan - acc: 0.0072 - top_4_acc: 0.0000e+00\n", + "Test Loss: nan\n", + "Test Accuracy: 0.718 %\n" + ] + } + ], + "source": [ + "# Validation and Test evaluations of ViT model\n", + "\n", + "with stg.scope():\n", + " print('ViT model results')\n", + " print('--'*50)\n", + " val_eval_vit = model.evaluate(valid_dataset)\n", + " print('Validation Loss: {0:.3f}'.format(val_eval_vit[0]))\n", + " print('Validation Accuracy: {0:.3f} %'.format(val_eval_vit[1]*100))\n", + " print('--'*50)\n", + " test_eval_vit = model.evaluate(test_dataset)\n", + " print('Test Loss: {0:.3f}'.format(test_eval_vit[0]))\n", + " print('Test Accuracy: {0:.3f} %'.format(test_eval_vit[1]*100))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "nuOfh3RIOORt" + }, + "source": [ + "# Classification Stage" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2025-09-05T07:27:45.629675Z", + "iopub.status.busy": "2025-09-05T07:27:45.629056Z", + "iopub.status.idle": "2025-09-05T07:27:45.635113Z", + "shell.execute_reply": "2025-09-05T07:27:45.634273Z", + "shell.execute_reply.started": "2025-09-05T07:27:45.629644Z" + }, + "id": "fZ-2x3EmOORt", + "trusted": true + }, + "outputs": [], + "source": [ + "# import numpy as np\n", + "# import pandas as pd\n", + "# import tensorflow as tf\n", + "# from tensorflow.keras.models import Model\n", + "# from sklearn.impute import SimpleImputer\n", + "# from sklearn.decomposition import PCA\n", + "# from sklearn.pipeline import Pipeline\n", + "\n", + "# # ========================\n", + "# # 1. Reload model properly\n", + "# # ========================\n", + "# with stg.scope():\n", + "# feature_extr = tf.keras.models.load_model('/kaggle/working/vit_feature_extractor.h5')\n", + "\n", + "# print(\"All layers in loaded model:\", [l.name for l in feature_extr.layers])\n", + "\n", + "# if \"the_feature_layer\" not in [l.name for l in feature_extr.layers]:\n", + "# raise RuntimeError(\"Layer 'the_feature_layer' not found in model. Check naming!\")\n", + "\n", + "# feature_extractor_model = Model(\n", + "# inputs=feature_extr.input,\n", + "# outputs=feature_extr.get_layer(\"the_feature_layer\").output\n", + "# )\n", + "\n", + "# print(\"Feature extractor ready — output shape:\", feature_extractor_model.output_shape)\n", + "\n", + "# # ========================\n", + "# # 2. Dataset sanity check\n", + "# # ========================\n", + "# if len(X_train) == 0:\n", + "# raise RuntimeError(\"X_train is empty! Check your train split after preprocessing.\")\n", + "\n", + "# print(\"Train set size:\", len(X_train))\n", + "# print(\"Example paths:\", X_train['full_path_to_img'].head().tolist())\n", + "\n", + "# # ========================\n", + "# # 3. 
Feature extraction\n", + "# # ========================\n", + "# features_list = []\n", + "# labels_list = []\n", + "\n", + "# for idx, (img_path, label) in enumerate(zip(X_train['full_path_to_img'], y_train)):\n", + "# try:\n", + "# img_raw = tf.io.read_file(img_path)\n", + "# img = tf.io.decode_jpeg(img_raw, channels=3)\n", + "# img = tf.image.resize(img, IMG_SIZE)\n", + "# img = tf.cast(img, tf.float32) / 255.0\n", + "# img = tf.expand_dims(img, 0) # (1, H, W, 3)\n", + "\n", + "# feat = feature_extractor_model(img, training=False).numpy().squeeze()\n", + "\n", + "# if np.isnan(feat).any() or np.isinf(feat).any():\n", + "# print(f\"⚠️ NaN/Inf detected in feature for {img_path}, skipping.\")\n", + "# continue\n", + "\n", + "# features_list.append(feat)\n", + "# labels_list.append(label)\n", + "\n", + "# except Exception as e:\n", + "# print(f\"❌ Failed on {img_path}: {e}\")\n", + "# continue\n", + "\n", + "# if len(features_list) == 0:\n", + "# raise RuntimeError(\"No valid features extracted. Check preprocessing/model/data.\")\n", + "\n", + "# features = np.vstack(features_list)\n", + "# labels_clean = np.array(labels_list)\n", + "\n", + "# print(\"✅ Features extracted:\", features.shape)\n", + "\n", + "# # ========================\n", + "# # 4. PCA with imputer\n", + "# # ========================\n", + "# n_components = min(42, features.shape[1], features.shape[0])\n", + "# pca_pipeline = Pipeline([\n", + "# (\"imputer\", SimpleImputer(strategy=\"mean\")),\n", + "# (\"pca\", PCA(n_components=n_components))\n", + "# ])\n", + "\n", + "# features_pca = pca_pipeline.fit_transform(features)\n", + "\n", + "# new_feature_column_names = [f\"feature_{i+1}\" for i in range(features_pca.shape[1])]\n", + "# train_features = pd.DataFrame(features_pca, columns=new_feature_column_names)\n", + "# train_labels = pd.Series(labels_clean, name=\"label\")\n", + "\n", + "# print(\"Final train_features shape:\", train_features.shape)\n", + "# print(train_features.head())\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2025-09-03T11:56:06.331132Z", + "iopub.status.busy": "2025-09-03T11:56:06.330803Z", + "iopub.status.idle": "2025-09-03T11:56:08.097574Z", + "shell.execute_reply": "2025-09-03T11:56:08.096832Z", + "shell.execute_reply.started": "2025-09-03T11:56:06.331106Z" + }, + "id": "6qcCQ758OORt", + "trusted": true + }, + "outputs": [], + "source": [ + "# # Reading saved model and weights\n", + "# feature_extr = tf.keras.models.load_model('/kaggle/working/vit_feature_extractor.h5')\n", + "\n", + "# # Feature extraction model\n", + "# feature_extractor_model = Model(inputs=feature_extr.input,\n", + "# outputs=feature_extr.get_layer('the_feature_layer').output)\n", + "# feature_extr = tf.keras.models.load_model('/kaggle/working/vit_feature_extractor.h5', compile=False)\n", + "\n", + "# # Feature extraction model (same layer as you used)\n", + "# feature_extractor_model = Model(inputs=feature_extr.input,\n", + "# outputs=feature_extr.get_layer('the_feature_layer').output)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2025-09-03T11:56:54.168262Z", + "iopub.status.busy": "2025-09-03T11:56:54.167455Z", + "iopub.status.idle": "2025-09-03T11:56:54.246526Z", + "shell.execute_reply": "2025-09-03T11:56:54.245893Z", + "shell.execute_reply.started": "2025-09-03T11:56:54.168201Z" + }, + "id": "2DgyVIFXOORt", + "trusted": true + }, + "outputs": [], + "source": [ + "# 
----------------------------\n", + "# Scan train images for NaN/Inf features and remove offending files\n", + "# ----------------------------\n", + "\n", + "# Build a dataset that yields (img_tensor, path, label) so we can report and locate bad files\n", + "paths_ds = tf.data.Dataset.from_tensor_slices(X_train['full_path_to_img'])\n", + "labels_ds = tf.data.Dataset.from_tensor_slices(y_train)\n", + "paths_and_labels = tf.data.Dataset.zip((paths_ds, labels_ds))\n", + "\n", + "def read_img_with_path(path, label):\n", + " img = tf.io.read_file(path)\n", + " # use decode_jpeg but fallback to decode_image if necessary\n", + " try:\n", + " img = tf.io.decode_jpeg(img, channels=3)\n", + " except Exception:\n", + " img = tf.io.decode_image(img, channels=3)\n", + " img = tf.image.resize(img, size=IMG_SIZE)\n", + " img = tf.cast(img, tf.float32) / 255.0\n", + " return img, path, label\n", + "\n", + "dataset_with_path = (paths_and_labels\n", + " .map(read_img_with_path, num_parallel_calls=AUTO)\n", + " .batch(BATCH_SIZE)\n", + " .prefetch(AUTO))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2025-09-03T12:17:07.020438Z", + "iopub.status.busy": "2025-09-03T12:17:07.019356Z", + "iopub.status.idle": "2025-09-03T12:17:07.079011Z", + "shell.execute_reply": "2025-09-03T12:17:07.078024Z", + "shell.execute_reply.started": "2025-09-03T12:17:07.020405Z" + }, + "id": "xj4DoH8YOORt", + "outputId": "d376627b-873a-4fe6-cfbd-389fe2a04ddb", + "trusted": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Attempting feature extraction for 0 training images (per-file loop).\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "0it [00:00, ?it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Kept 0 samples, skipped 0 samples.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "ename": "RuntimeError", + "evalue": "No valid features extracted. Inspect skipped_paths.txt for details.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[63], line 70\u001b[0m\n\u001b[1;32m 67\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mKept \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mn_kept\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m samples, skipped \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mn_skipped\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m samples.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 69\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m n_kept \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m---> 70\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mNo valid features extracted. Inspect skipped_paths.txt for details.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 72\u001b[0m \u001b[38;5;66;03m# Stack features into array (n_samples, n_features)\u001b[39;00m\n\u001b[1;32m 73\u001b[0m features \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mvstack(features_list)\n", + "\u001b[0;31mRuntimeError\u001b[0m: No valid features extracted. Inspect skipped_paths.txt for details." 
+ ] + } + ], + "source": [ + "# Robust per-file extraction + PCA (no destructive removals)\n", + "import numpy as np\n", + "import tensorflow as tf\n", + "from PIL import Image\n", + "from sklearn.impute import SimpleImputer\n", + "from sklearn.decomposition import PCA\n", + "from sklearn.pipeline import Pipeline\n", + "import os\n", + "from tqdm import tqdm\n", + "\n", + "# Settings\n", + "out_skipped_file = \"/kaggle/working/skipped_paths.txt\"\n", + "os.makedirs(os.path.dirname(out_skipped_file), exist_ok=True)\n", + "\n", + "# Containers\n", + "features_list = []\n", + "labels_kept = []\n", + "paths_kept = []\n", + "paths_skipped = []\n", + "\n", + "# Helper: single image load + preprocess -> numpy array (H,W,3), same as img_preprocessing\n", + "def load_preprocess_np(path):\n", + " x = tf.io.read_file(path)\n", + " try:\n", + " img = tf.io.decode_jpeg(x, channels=3)\n", + " except Exception:\n", + " img = tf.io.decode_image(x, channels=3)\n", + " img = tf.image.resize(img, size=IMG_SIZE)\n", + " img = tf.cast(img, tf.float32) / 255.0\n", + " return img.numpy()\n", + "\n", + "# Iterate over each training row (robust; avoids empty dataset issues)\n", + "n_total = len(X_train)\n", + "print(f\"Attempting feature extraction for {n_total} training images (per-file loop).\")\n", + "for idx in tqdm(range(n_total)):\n", + " path = X_train['full_path_to_img'].iloc[idx]\n", + " label = y_train.iloc[idx] if hasattr(y_train, 'iloc') else y_train[idx]\n", + " try:\n", + " img_np = load_preprocess_np(path) # shape (H,W,3)\n", + " except Exception as e:\n", + " # Could not read/parse image file — log and skip\n", + " paths_skipped.append((path, f\"read_error: {e}\"))\n", + " continue\n", + "\n", + " try:\n", + " # run model in eager mode on a batch of 1\n", + " feats_tf = feature_extractor_model(tf.convert_to_tensor(np.expand_dims(img_np, 0)), training=False)\n", + " feats = feats_tf.numpy()\n", + " except Exception as e:\n", + " # model threw on this image — log and skip\n", + " paths_skipped.append((path, f\"model_error: {e}\"))\n", + " continue\n", + "\n", + " # Check NaN/Inf for this sample\n", + " if np.isnan(feats).any() or np.isinf(feats).any():\n", + " paths_skipped.append((path, \"nan_or_inf_in_features\"))\n", + " continue\n", + "\n", + " # Keep valid features\n", + " features_list.append(feats.reshape(-1)) # flatten to 1D\n", + " labels_kept.append(label)\n", + " paths_kept.append(path)\n", + "\n", + "# Summary\n", + "n_kept = len(features_list)\n", + "n_skipped = len(paths_skipped)\n", + "print(f\"Kept {n_kept} samples, skipped {n_skipped} samples.\")\n", + "\n", + "if n_kept == 0:\n", + " raise RuntimeError(\"No valid features extracted. 
Inspect skipped_paths.txt for details.\")\n", + "\n", + "# Stack features into array (n_samples, n_features)\n", + "features = np.vstack(features_list)\n", + "print(\"features.shape =\", features.shape)\n", + "\n", + "# Persist skipped paths for review\n", + "with open(out_skipped_file, \"w\") as f:\n", + " for p, reason in paths_skipped:\n", + " f.write(f\"{p}\\t{reason}\\n\")\n", + "print(\"Wrote skipped image list to:\", out_skipped_file)\n", + "\n", + "# If some NaNs present (unexpected), impute\n", + "if np.isnan(features).any() or np.isinf(features).any():\n", + " print(\"Detected NaN/Inf in features array; applying SimpleImputer (column mean).\")\n", + " imputer = SimpleImputer(strategy='mean')\n", + " features = imputer.fit_transform(features)\n", + "\n", + "# PCA: n_components <= n_features\n", + "n_samples, n_features = features.shape\n", + "n_components = min(42, n_features)\n", + "print(f\"Running PCA with n_components={n_components} (n_features={n_features})\")\n", + "pca_pipeline = Pipeline([\n", + " ('imputer', SimpleImputer(strategy='mean')),\n", + " ('pca', PCA(n_components=n_components))\n", + "])\n", + "pred_pca = pca_pipeline.fit_transform(features)\n", + "new_feature_column_names = [f'feature_{i+1}' for i in range(pred_pca.shape[1])]\n", + "train_features = pd.DataFrame(pred_pca, columns=new_feature_column_names)\n", + "train_features['label'] = labels_kept\n", + "train_features['path'] = paths_kept\n", + "\n", + "print(\"train_features shape:\", train_features.shape)\n", + "train_features.head()\n", + "\n", + "# Save outputs\n", + "train_features.to_csv(\"/kaggle/working/train_features_pca_per_file.csv\", index=False)\n", + "print(\"Saved train_features_pca_per_file.csv to /kaggle/working/\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2025-09-03T12:03:35.904637Z", + "iopub.status.busy": "2025-09-03T12:03:35.904286Z", + "iopub.status.idle": "2025-09-03T12:03:35.910982Z", + "shell.execute_reply": "2025-09-03T12:03:35.910076Z", + "shell.execute_reply.started": "2025-09-03T12:03:35.904611Z" + }, + "id": "XVyryMAPOORt", + "trusted": true + }, + "outputs": [], + "source": [ + "# bad_paths = [] # collect any file paths that produce NaN/Inf in features\n", + "# print(\"Scanning training set for images that produce NaN/Inf feature vectors...\")\n", + "# for batch_idx, batch in enumerate(dataset_with_path):\n", + "# imgs, paths_batch, labels_batch = batch\n", + "# # convert TF string tensors to numpy bytes (to map back to pandas)\n", + "# paths_np = [p.numpy().decode('utf-8') for p in paths_batch]\n", + "\n", + "# # predict features for this batch\n", + "# feats = feature_extractor_model.predict(imgs, verbose=0)\n", + "# # per-sample check\n", + "# sample_has_nan = np.isnan(feats).any(axis=1) | np.isinf(feats).any(axis=1)\n", + "# if np.any(sample_has_nan):\n", + "# # record all offending paths in this batch\n", + "# for j, bad_flag in enumerate(sample_has_nan):\n", + "# if bad_flag:\n", + "# bp = paths_np[j]\n", + "# print(f\"Batch {batch_idx} - found bad sample: {bp}\")\n", + "# bad_paths.append(bp)\n", + "\n", + "# # Deduplicate bad_paths\n", + "# bad_paths = list(dict.fromkeys(bad_paths))\n", + "# print(\"Total bad files found:\", len(bad_paths))\n", + "# for p in bad_paths:\n", + "# print(\" -\", p)\n", + "\n", + "# # If no bad files found, bad_paths will be []\n", + "# if len(bad_paths) > 0:\n", + "# # Remove them from X_train / y_train (assuming X_train is a pandas DataFrame with column 
'full_path_to_img')\n", + "# print(\"Removing bad files from X_train / y_train and rebuilding datasets...\")\n", + "# mask_good = ~X_train['full_path_to_img'].isin(bad_paths)\n", + "# X_train_clean = X_train[mask_good].reset_index(drop=True)\n", + "# y_train_clean = y_train[mask_good].reset_index(drop=True)\n", + "\n", + "# # Recreate dataset objects for the cleaned training set\n", + "# train_loader = tf.data.Dataset.from_tensor_slices((X_train_clean['full_path_to_img'], y_train_clean))\n", + "# train_dataset = (train_loader\n", + "# .map(img_preprocessing, num_parallel_calls=AUTO)\n", + "# .map(augmentation, num_parallel_calls=AUTO)\n", + "# .shuffle(BATCH_SIZE * 10, seed=SEED)\n", + "# .batch(BATCH_SIZE)\n", + "# .prefetch(AUTO))\n", + "\n", + "# train_loader_feature = tf.data.Dataset.from_tensor_slices((X_train_clean['full_path_to_img'], y_train_clean))\n", + "# train_dataset_feature = (train_loader_feature\n", + "# .map(img_preprocessing, num_parallel_calls=AUTO)\n", + "# .batch(BATCH_SIZE)\n", + "# .prefetch(AUTO))\n", + "\n", + "# # (Optional) update X_train and y_train variables in your workspace to use cleaned sets\n", + "# X_train = X_train_clean\n", + "# y_train = y_train_clean\n", + "# else:\n", + "# print(\"No bad files found; using original X_train / y_train.\")\n", + "# # ensure train_dataset_feature refers to the dataset we created earlier\n", + "# train_loader_feature = tf.data.Dataset.from_tensor_slices((X_train['full_path_to_img'], y_train))\n", + "# train_dataset_feature = (train_loader_feature\n", + "# .map(img_preprocessing, num_parallel_calls=AUTO)\n", + "# .batch(BATCH_SIZE)\n", + "# .prefetch(AUTO))\n", + "\n", + "# # ----------------------------\n", + "# # Extract features for cleaned training set (batch-wise) safely\n", + "# # ----------------------------\n", + "# print(\"Extracting features for cleaned training set (this may take a while)...\")\n", + "# images_only_ds = train_dataset_feature.map(lambda x, y: x)\n", + "\n", + "# # Collect features in a list and stack to avoid memory / predict on full dataset at once if needed\n", + "# features_list = []\n", + "# for batch_imgs in images_only_ds:\n", + "# batch_feats = feature_extractor_model.predict(batch_imgs, verbose=0)\n", + "# features_list.append(batch_feats)\n", + "\n", + "# features = np.vstack(features_list) # final features array, shape (n_samples, n_features)\n", + "# print(\"Extracted features shape:\", features.shape)\n", + "# print(\"Any NaN in features after cleaning?:\", np.isnan(features).any(), \"Any Inf?:\", np.isinf(features).any())\n", + "\n", + "# # If there are still NaNs, you can either (A) remove the related samples (scan again) or (B) impute.\n", + "# if np.isnan(features).any() or np.isinf(features).any():\n", + "# # fallback: impute NaNs using column means (only if small number of NaNs present)\n", + "# print(\"Warning: NaNs/Inf remain in features. 
Applying SimpleImputer (mean) to handle them before PCA.\")\n", + "# from sklearn.impute import SimpleImputer\n", + "# imputer = SimpleImputer(strategy='mean')\n", + "# features = imputer.fit_transform(features)\n", + "# print(\"After imputation: Any NaN left?\", np.isnan(features).any())\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2025-09-03T10:26:17.959846Z", + "iopub.status.busy": "2025-09-03T10:26:17.959526Z", + "iopub.status.idle": "2025-09-03T10:27:05.864982Z", + "shell.execute_reply": "2025-09-03T10:27:05.864243Z", + "shell.execute_reply.started": "2025-09-03T10:26:17.959823Z" + }, + "id": "X64onoYcOORt", + "outputId": "512ac52b-bccb-4382-ffe3-5fbe2fbfef24", + "trusted": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1235/1235 [==============================] - 47s 36ms/step\n" + ] + } + ], + "source": [ + "# # Creating train features\n", + "\n", + "# with stg.scope():\n", + "# features = feature_extractor_model.predict(train_dataset_feature)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Xbg_WPFsOORu" + }, + "source": [ + "# Dimensionality Reduction" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2025-09-03T10:30:30.764572Z", + "iopub.status.busy": "2025-09-03T10:30:30.763760Z", + "iopub.status.idle": "2025-09-03T10:30:30.772646Z", + "shell.execute_reply": "2025-09-03T10:30:30.771760Z", + "shell.execute_reply.started": "2025-09-03T10:30:30.764534Z" + }, + "id": "6zCaPi67OORu", + "outputId": "6bd5f612-4018-4033-f90a-20dcaeb90da3", + "trusted": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Does the array contain any NaN values? True\n", + "Columns with all NaN values:\n", + "[ True True True True True True True True True True True True\n", + " True True True True True True True True True True True True\n", + " True True True True True True True True True True True True\n", + " True True True True True True True True True True True True\n", + " True True True True True True True True True True True True\n", + " True True True True]\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "\n", + "# Check if there are any NaN values in the entire array\n", + "print(\"Does the array contain any NaN values? 
\", np.isnan(features).any())\n", + "\n", + "# Check for columns (axis 1) where all values are NaN\n", + "print(\"Columns with all NaN values:\")\n", + "print(np.all(np.isnan(features), axis=0))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2025-09-03T10:56:07.848859Z", + "iopub.status.busy": "2025-09-03T10:56:07.848529Z", + "iopub.status.idle": "2025-09-03T10:56:07.853158Z", + "shell.execute_reply": "2025-09-03T10:56:07.852289Z", + "shell.execute_reply.started": "2025-09-03T10:56:07.848833Z" + }, + "id": "lBvQi0FiOORu", + "trusted": true + }, + "outputs": [], + "source": [ + "# # Applying PCA 42 components nearly equal to 0.99 variance ratio\n", + "\n", + "# # pca_ = PCA(42)\n", + "# # pred_pca_ = pca_.fit(features)\n", + "# # pred_pca = pred_pca_.transform(features)\n", + "# from sklearn.impute import SimpleImputer\n", + "# from sklearn.decomposition import PCA\n", + "# from sklearn.pipeline import Pipeline\n", + "\n", + "# # Create a pipeline with an imputer and PCA\n", + "# pca_pipeline = Pipeline([\n", + "# ('imputer', SimpleImputer(strategy='mean')),\n", + "# ('pca', PCA(n_components=42))\n", + "# ])\n", + "\n", + "# # Fit the pipeline to the data\n", + "# pred_pca_ = pca_pipeline.fit(features)\n", + "# pred_pca = pred_pca_.transform(features)\n", + "\n", + "# new_feature_column_names = []\n", + "# for i in range(pred_pca.shape[1]):\n", + "# new_feature_column_names.append('feature_{0}'.format(i+1))\n", + "\n", + "# train_features = pd.DataFrame(pred_pca, columns = new_feature_column_names)\n", + "\n", + "\n", + "# # Features created with a ViT feature extractor\n", + "# train_features.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2025-09-03T11:36:23.682635Z", + "iopub.status.busy": "2025-09-03T11:36:23.681936Z", + "iopub.status.idle": "2025-09-03T11:36:23.961081Z", + "shell.execute_reply": "2025-09-03T11:36:23.960260Z", + "shell.execute_reply.started": "2025-09-03T11:36:23.682604Z" + }, + "id": "1X1hHqlIOORu", + "outputId": "1b9cef0b-958c-43ef-ef0d-9397aeaf2fab", + "trusted": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found NaN/Inf in features for batch index 0.\n", + "Image 0 path: /kaggle/input/diamond-images-dataset/web_scraped/princess/2110814.jpg\n", + " pixel stats -> min: 0.0 max: 1.0 mean: 0.67034954 std: 0.14518233\n", + " any NaN in img? False any Inf in img? False\n", + "Image 1 path: /kaggle/input/diamond-images-dataset/web_scraped/oval/2114823.jpg\n", + " pixel stats -> min: 0.018529227 max: 1.0 mean: 0.68680847 std: 0.12197413\n", + " any NaN in img? False any Inf in img? False\n", + "Image 2 path: /kaggle/input/diamond-images-dataset/web_scraped/oval/2128977.jpg\n", + " pixel stats -> min: 0.033137422 max: 1.0 mean: 0.6546141 std: 0.13366248\n", + " any NaN in img? False any Inf in img? False\n", + "Image 3 path: /kaggle/input/diamond-images-dataset/web_scraped/emerald/2032854.jpg\n", + " pixel stats -> min: 0.011764706 max: 1.0 mean: 0.6315528 std: 0.19869038\n", + " any NaN in img? False any Inf in img? False\n", + "Image 4 path: /kaggle/input/diamond-images-dataset/web_scraped/round/2107646.jpg\n", + " pixel stats -> min: 0.0 max: 1.0 mean: 0.61658835 std: 0.21517538\n", + " any NaN in img? False any Inf in img? 
False\n", + "Image 5 path: /kaggle/input/diamond-images-dataset/web_scraped/princess/2119671.jpg\n", + " pixel stats -> min: 0.0 max: 1.0 mean: 0.66367805 std: 0.1658859\n", + " any NaN in img? False any Inf in img? False\n", + "Image 6 path: /kaggle/input/diamond-images-dataset/web_scraped/round/2116296.jpg\n", + " pixel stats -> min: 0.015686275 max: 1.0 mean: 0.681769 std: 0.13941239\n", + " any NaN in img? False any Inf in img? False\n", + "Image 7 path: /kaggle/input/diamond-images-dataset/web_scraped/oval/220207-628.jpg\n", + " pixel stats -> min: 0.011391972 max: 1.0 mean: 0.6577099 std: 0.13830958\n", + " any NaN in img? False any Inf in img? False\n", + "Image 8 path: /kaggle/input/diamond-images-dataset/web_scraped/cushion/220408-6.jpg\n", + " pixel stats -> min: 0.0 max: 1.0 mean: 0.6619267 std: 0.14756434\n", + " any NaN in img? False any Inf in img? False\n", + "Image 9 path: /kaggle/input/diamond-images-dataset/web_scraped/round/220317-19.jpg\n", + " pixel stats -> min: 0.034235697 max: 1.0 mean: 0.6685742 std: 0.13233511\n", + " any NaN in img? False any Inf in img? False\n", + "Image 10 path: /kaggle/input/diamond-images-dataset/web_scraped/round/2120873.jpg\n", + " pixel stats -> min: 0.0 max: 1.0 mean: 0.6935362 std: 0.13459717\n", + " any NaN in img? False any Inf in img? False\n", + "Image 11 path: /kaggle/input/diamond-images-dataset/web_scraped/princess/2134781.jpg\n", + " pixel stats -> min: 0.0 max: 1.0 mean: 0.63885754 std: 0.17516671\n", + " any NaN in img? False any Inf in img? False\n", + "Image 12 path: /kaggle/input/diamond-images-dataset/web_scraped/round/223476-254.jpg\n", + " pixel stats -> min: 0.0038236731 max: 1.0 mean: 0.68113595 std: 0.1485862\n", + " any NaN in img? False any Inf in img? False\n", + "Image 13 path: /kaggle/input/diamond-images-dataset/web_scraped/oval/220358-87.jpg\n", + " pixel stats -> min: 0.0984308 max: 1.0 mean: 0.68076944 std: 0.106789425\n", + " any NaN in img? False any Inf in img? False\n", + "Image 14 path: /kaggle/input/diamond-images-dataset/web_scraped/round/2092193.jpg\n", + " pixel stats -> min: 0.039607506 max: 1.0 mean: 0.6647163 std: 0.14491771\n", + " any NaN in img? False any Inf in img? False\n", + "Image 15 path: /kaggle/input/diamond-images-dataset/web_scraped/emerald/221163-28.jpg\n", + " pixel stats -> min: 0.01372549 max: 1.0 mean: 0.60111386 std: 0.21680893\n", + " any NaN in img? False any Inf in img? False\n", + "Image 16 path: /kaggle/input/diamond-images-dataset/web_scraped/pear/2115065.jpg\n", + " pixel stats -> min: 0.031509757 max: 1.0 mean: 0.6816401 std: 0.10798271\n", + " any NaN in img? False any Inf in img? False\n", + "Image 17 path: /kaggle/input/diamond-images-dataset/web_scraped/princess/220231-295.jpg\n", + " pixel stats -> min: 0.017784448 max: 1.0 mean: 0.64701426 std: 0.18080431\n", + " any NaN in img? False any Inf in img? False\n", + "Image 18 path: /kaggle/input/diamond-images-dataset/web_scraped/round/2116397.jpg\n", + " pixel stats -> min: 0.027588429 max: 1.0 mean: 0.672964 std: 0.15052211\n", + " any NaN in img? False any Inf in img? False\n", + "Image 19 path: /kaggle/input/diamond-images-dataset/web_scraped/princess/2083062.jpg\n", + " pixel stats -> min: 0.0 max: 1.0 mean: 0.66837966 std: 0.15665197\n", + " any NaN in img? False any Inf in img? False\n", + "Image 20 path: /kaggle/input/diamond-images-dataset/web_scraped/round/2112378.jpg\n", + " pixel stats -> min: 0.025735294 max: 1.0 mean: 0.69153726 std: 0.12064762\n", + " any NaN in img? False any Inf in img? 
False\n", + "Image 21 path: /kaggle/input/diamond-images-dataset/web_scraped/oval/2043354.jpg\n", + " pixel stats -> min: 0.0 max: 1.0 mean: 0.67216915 std: 0.12480005\n", + " any NaN in img? False any Inf in img? False\n", + "Image 22 path: /kaggle/input/diamond-images-dataset/web_scraped/pear/1025490.jpg\n", + " pixel stats -> min: 0.10612745 max: 1.0 mean: 0.71845245 std: 0.08196875\n", + " any NaN in img? False any Inf in img? False\n", + "Image 23 path: /kaggle/input/diamond-images-dataset/web_scraped/round/2077695.jpg\n", + " pixel stats -> min: 0.0 max: 1.0 mean: 0.6548285 std: 0.14756441\n", + " any NaN in img? False any Inf in img? False\n", + "Image 24 path: /kaggle/input/diamond-images-dataset/web_scraped/cushion/2122361.jpg\n", + " pixel stats -> min: 0.0 max: 1.0 mean: 0.6722985 std: 0.16471538\n", + " any NaN in img? False any Inf in img? False\n", + "Image 25 path: /kaggle/input/diamond-images-dataset/web_scraped/round/221162-34.jpg\n", + " pixel stats -> min: 0.0014117073 max: 1.0 mean: 0.6704468 std: 0.14351074\n", + " any NaN in img? False any Inf in img? False\n", + "Image 26 path: /kaggle/input/diamond-images-dataset/web_scraped/round/221123-62.jpg\n", + " pixel stats -> min: 0.040628828 max: 1.0 mean: 0.6921635 std: 0.12741992\n", + " any NaN in img? False any Inf in img? False\n", + "Image 27 path: /kaggle/input/diamond-images-dataset/web_scraped/princess/2079371.jpg\n", + " pixel stats -> min: 0.0 max: 1.0 mean: 0.6341646 std: 0.18892504\n", + " any NaN in img? False any Inf in img? False\n", + "Image 28 path: /kaggle/input/diamond-images-dataset/web_scraped/princess/2139155.jpg\n", + " pixel stats -> min: 0.0 max: 1.0 mean: 0.6248894 std: 0.19418596\n", + " any NaN in img? False any Inf in img? False\n", + "Image 29 path: /kaggle/input/diamond-images-dataset/web_scraped/round/2105931.jpg\n", + " pixel stats -> min: 0.0 max: 1.0 mean: 0.68103075 std: 0.1566879\n", + " any NaN in img? False any Inf in img? False\n", + "Image 30 path: /kaggle/input/diamond-images-dataset/web_scraped/round/2101288.jpg\n", + " pixel stats -> min: 0.0 max: 1.0 mean: 0.65298504 std: 0.15010358\n", + " any NaN in img? False any Inf in img? False\n", + "Image 31 path: /kaggle/input/diamond-images-dataset/web_scraped/round/221126-80.jpg\n", + " pixel stats -> min: 0.023039216 max: 1.0 mean: 0.6880185 std: 0.1406578\n", + " any NaN in img? False any Inf in img? 
False\n" + ] + } + ], + "source": [ + "# import numpy as np\n", + "# import tensorflow as tf\n", + "\n", + "# # Build a dataset that yields (image_tensor, path, label)\n", + "# paths_ds = tf.data.Dataset.from_tensor_slices(X_train['full_path_to_img'])\n", + "# labels_ds = tf.data.Dataset.from_tensor_slices(y_train)\n", + "# paths_and_labels = tf.data.Dataset.zip((paths_ds, labels_ds))\n", + "\n", + "# def read_img_with_path(path, label):\n", + "# img = tf.io.read_file(path)\n", + "# img = tf.io.decode_jpeg(img, channels=3) # or decode_image if needed\n", + "# img = tf.image.resize(img, size=IMG_SIZE)\n", + "# img = tf.cast(img, tf.float32) / 255.0\n", + "# return img, path, label\n", + "\n", + "# dataset_with_path = paths_and_labels.map(read_img_with_path, num_parallel_calls=AUTO).batch(BATCH_SIZE).prefetch(AUTO)\n", + "\n", + "# # Extract features batch-by-batch and inspect\n", + "# bad_batches = []\n", + "# for i, batch in enumerate(dataset_with_path):\n", + "# imgs, paths_batch, labels_batch = batch\n", + "# # convert paths to numpy strings for printing\n", + "# try:\n", + "# feats = feature_extractor_model.predict(imgs, verbose=0)\n", + "# except Exception as e:\n", + "# print(f\"Predict crashed on batch {i} with error: {e}\")\n", + "# # still inspect imgs/paths\n", + "# paths_np = [p.numpy().decode('utf-8') for p in paths_batch]\n", + "# print(\"Paths in this batch:\", paths_np)\n", + "# raise\n", + "\n", + "# if np.isnan(feats).any() or np.isinf(feats).any():\n", + "# print(f\"Found NaN/Inf in features for batch index {i}.\")\n", + "# # print per-image stats and paths\n", + "# paths_np = [p.numpy().decode('utf-8') for p in paths_batch]\n", + "# for j in range(len(paths_np)):\n", + "# img = imgs[j].numpy()\n", + "# print(f\"Image {j} path: {paths_np[j]}\")\n", + "# print(\" pixel stats -> min:\", np.nanmin(img), \"max:\", np.nanmax(img),\n", + "# \"mean:\", np.nanmean(img), \"std:\", np.nanstd(img))\n", + "# print(\" any NaN in img?\", np.isnan(img).any(), \"any Inf in img?\", np.isinf(img).any())\n", + "# bad_batches.append((i, paths_np))\n", + "# break\n", + "\n", + "# if not bad_batches:\n", + "# print(\"No NaNs detected in any batch features. (If you still saw NaNs earlier, rerun this to double-check.)\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2025-09-03T11:34:40.917836Z", + "iopub.status.busy": "2025-09-03T11:34:40.917512Z", + "iopub.status.idle": "2025-09-03T11:34:41.725010Z", + "shell.execute_reply": "2025-09-03T11:34:41.724100Z", + "shell.execute_reply.started": "2025-09-03T11:34:40.917810Z" + }, + "id": "z_7J_1ivOORu", + "outputId": "e1641eea-67b2-439a-c73a-fbf6264df771", + "trusted": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Single-batch images shape: (32, 100, 100, 3)\n", + "1/1 [==============================] - 1s 655ms/step\n", + "Batch features shape: (32, 64)\n", + "Any NaNs in batch features? 
True\n" + ] + } + ], + "source": [ + "# Ensure train_dataset_feature yields preprocessed images (rank-4)\n", + "images_only = train_dataset_feature.map(lambda x, y: x)\n", + "\n", + "# Take one batch and run predict on it\n", + "for imgs in images_only.take(1):\n", + " print(\"Single-batch images shape:\", imgs.shape) # should be (batch, H, W, C)\n", + " feats = feature_extractor_model.predict(imgs, verbose=1)\n", + " print(\"Batch features shape:\", feats.shape)\n", + " print(\"Any NaNs in batch features?\", np.isnan(feats).any())\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2025-09-03T11:30:53.683807Z", + "iopub.status.busy": "2025-09-03T11:30:53.682953Z", + "iopub.status.idle": "2025-09-03T11:31:45.298608Z", + "shell.execute_reply": "2025-09-03T11:31:45.297428Z", + "shell.execute_reply.started": "2025-09-03T11:30:53.683775Z" + }, + "id": "rWblw4IpOORu", + "outputId": "f6e1b605-2941-4d8a-fe81-45deb0ce7a6c", + "trusted": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1235/1235 [==============================] - 51s 41ms/step\n", + "Extracted features shape: (39498, 64)\n", + "Warning: NaNs found in features - imputer will handle them.\n" + ] + }, + { + "ename": "ValueError", + "evalue": "Found array with 0 feature(s) (shape=(39498, 0)) while a minimum of 1 is required by PCA.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[44], line 27\u001b[0m\n\u001b[1;32m 21\u001b[0m n_components \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mmin\u001b[39m(\u001b[38;5;241m42\u001b[39m, features\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m1\u001b[39m])\n\u001b[1;32m 22\u001b[0m pca_pipeline \u001b[38;5;241m=\u001b[39m Pipeline([\n\u001b[1;32m 23\u001b[0m (\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mimputer\u001b[39m\u001b[38;5;124m'\u001b[39m, SimpleImputer(strategy\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmean\u001b[39m\u001b[38;5;124m'\u001b[39m)),\n\u001b[1;32m 24\u001b[0m (\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mpca\u001b[39m\u001b[38;5;124m'\u001b[39m, PCA(n_components\u001b[38;5;241m=\u001b[39mn_components))\n\u001b[1;32m 25\u001b[0m ])\n\u001b[0;32m---> 27\u001b[0m pred_pca \u001b[38;5;241m=\u001b[39m \u001b[43mpca_pipeline\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit_transform\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfeatures\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 28\u001b[0m new_feature_column_names \u001b[38;5;241m=\u001b[39m [\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mfeature_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mi\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m1\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(pred_pca\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m1\u001b[39m])]\n\u001b[1;32m 29\u001b[0m train_features \u001b[38;5;241m=\u001b[39m pd\u001b[38;5;241m.\u001b[39mDataFrame(pred_pca, columns\u001b[38;5;241m=\u001b[39mnew_feature_column_names)\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/sklearn/pipeline.py:445\u001b[0m, in \u001b[0;36mPipeline.fit_transform\u001b[0;34m(self, X, y, **fit_params)\u001b[0m\n\u001b[1;32m 443\u001b[0m 
fit_params_last_step \u001b[38;5;241m=\u001b[39m fit_params_steps[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msteps[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m][\u001b[38;5;241m0\u001b[39m]]\n\u001b[1;32m 444\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(last_step, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfit_transform\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 445\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mlast_step\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit_transform\u001b[49m\u001b[43m(\u001b[49m\u001b[43mXt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mfit_params_last_step\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 446\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 447\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m last_step\u001b[38;5;241m.\u001b[39mfit(Xt, y, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mfit_params_last_step)\u001b[38;5;241m.\u001b[39mtransform(Xt)\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/sklearn/utils/_set_output.py:140\u001b[0m, in \u001b[0;36m_wrap_method_output..wrapped\u001b[0;34m(self, X, *args, **kwargs)\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(f)\n\u001b[1;32m 139\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapped\u001b[39m(\u001b[38;5;28mself\u001b[39m, X, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m--> 140\u001b[0m data_to_wrap \u001b[38;5;241m=\u001b[39m \u001b[43mf\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 141\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(data_to_wrap, \u001b[38;5;28mtuple\u001b[39m):\n\u001b[1;32m 142\u001b[0m \u001b[38;5;66;03m# only wrap the first output for cross decomposition\u001b[39;00m\n\u001b[1;32m 143\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m (\n\u001b[1;32m 144\u001b[0m _wrap_data_with_container(method, data_to_wrap[\u001b[38;5;241m0\u001b[39m], X, \u001b[38;5;28mself\u001b[39m),\n\u001b[1;32m 145\u001b[0m \u001b[38;5;241m*\u001b[39mdata_to_wrap[\u001b[38;5;241m1\u001b[39m:],\n\u001b[1;32m 146\u001b[0m )\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/sklearn/decomposition/_pca.py:462\u001b[0m, in \u001b[0;36mPCA.fit_transform\u001b[0;34m(self, X, y)\u001b[0m\n\u001b[1;32m 439\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Fit the model with X and apply the dimensionality reduction on X.\u001b[39;00m\n\u001b[1;32m 440\u001b[0m \n\u001b[1;32m 441\u001b[0m \u001b[38;5;124;03mParameters\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 458\u001b[0m \u001b[38;5;124;03mC-ordered array, use 'np.ascontiguousarray'.\u001b[39;00m\n\u001b[1;32m 459\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 460\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_validate_params()\n\u001b[0;32m--> 462\u001b[0m U, S, Vt \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_fit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 463\u001b[0m U \u001b[38;5;241m=\u001b[39m U[:, : \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_components_]\n\u001b[1;32m 465\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwhiten:\n\u001b[1;32m 466\u001b[0m \u001b[38;5;66;03m# X_new = X * V / S * sqrt(n_samples) = U * sqrt(n_samples)\u001b[39;00m\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/sklearn/decomposition/_pca.py:485\u001b[0m, in \u001b[0;36mPCA._fit\u001b[0;34m(self, X)\u001b[0m\n\u001b[1;32m 479\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m issparse(X):\n\u001b[1;32m 480\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\n\u001b[1;32m 481\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPCA does not support sparse input. See \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 482\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTruncatedSVD for a possible alternative.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 483\u001b[0m )\n\u001b[0;32m--> 485\u001b[0m X \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_validate_data\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 486\u001b[0m \u001b[43m \u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m[\u001b[49m\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfloat64\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfloat32\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mensure_2d\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcopy\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcopy\u001b[49m\n\u001b[1;32m 487\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 489\u001b[0m \u001b[38;5;66;03m# Handle n_components==None\u001b[39;00m\n\u001b[1;32m 490\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_components \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/sklearn/base.py:565\u001b[0m, in \u001b[0;36mBaseEstimator._validate_data\u001b[0;34m(self, X, y, reset, validate_separately, **check_params)\u001b[0m\n\u001b[1;32m 563\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mValidation should be done on X, y or both.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 564\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m no_val_X \u001b[38;5;129;01mand\u001b[39;00m no_val_y:\n\u001b[0;32m--> 565\u001b[0m X \u001b[38;5;241m=\u001b[39m \u001b[43mcheck_array\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minput_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mX\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mcheck_params\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 566\u001b[0m out \u001b[38;5;241m=\u001b[39m X\n\u001b[1;32m 567\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m no_val_X \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m no_val_y:\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/sklearn/utils/validation.py:940\u001b[0m, in \u001b[0;36mcheck_array\u001b[0;34m(array, accept_sparse, accept_large_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, ensure_min_samples, ensure_min_features, estimator, input_name)\u001b[0m\n\u001b[1;32m 938\u001b[0m n_features \u001b[38;5;241m=\u001b[39m array\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m1\u001b[39m]\n\u001b[1;32m 939\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m n_features \u001b[38;5;241m<\u001b[39m ensure_min_features:\n\u001b[0;32m--> 940\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 941\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFound array with \u001b[39m\u001b[38;5;132;01m%d\u001b[39;00m\u001b[38;5;124m feature(s) (shape=\u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m) while\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 942\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m a minimum of \u001b[39m\u001b[38;5;132;01m%d\u001b[39;00m\u001b[38;5;124m is required\u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 943\u001b[0m \u001b[38;5;241m%\u001b[39m (n_features, array\u001b[38;5;241m.\u001b[39mshape, ensure_min_features, context)\n\u001b[1;32m 944\u001b[0m )\n\u001b[1;32m 946\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m copy:\n\u001b[1;32m 947\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m xp\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m \u001b[38;5;129;01min\u001b[39;00m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnumpy\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnumpy.array_api\u001b[39m\u001b[38;5;124m\"\u001b[39m}:\n\u001b[1;32m 948\u001b[0m \u001b[38;5;66;03m# only make a copy if `array` and `array_orig` may share memory`\u001b[39;00m\n", + "\u001b[0;31mValueError\u001b[0m: Found array with 0 feature(s) (shape=(39498, 0)) while a minimum of 1 is required by PCA." 
+ ] + } + ], + "source": [ + "# Ensure dataset yields images only\n", + "# Create a dataset that yields images only (it will keep batching as-is)\n", + "images_only = train_dataset_feature.map(lambda x, y: x)\n", + "\n", + "# Now predict\n", + "features = feature_extractor_model.predict(images_only, verbose=1)\n", + "print(\"Extracted features shape:\", features.shape)\n", + "\n", + "# Safety checks\n", + "import numpy as np\n", + "if np.isnan(features).any():\n", + " print(\"Warning: NaNs found in features - imputer will handle them.\")\n", + "if features.shape[1] == 0:\n", + " raise ValueError(\"Zero feature dimension — check the feature layer and saved model.\")\n", + "\n", + "# PCA (ensure n_components <= features.shape[1])\n", + "from sklearn.impute import SimpleImputer\n", + "from sklearn.decomposition import PCA\n", + "from sklearn.pipeline import Pipeline\n", + "\n", + "n_components = min(42, features.shape[1])\n", + "pca_pipeline = Pipeline([\n", + " ('imputer', SimpleImputer(strategy='mean')),\n", + " ('pca', PCA(n_components=n_components))\n", + "])\n", + "\n", + "pred_pca = pca_pipeline.fit_transform(features)\n", + "new_feature_column_names = [f'feature_{i+1}' for i in range(pred_pca.shape[1])]\n", + "train_features = pd.DataFrame(pred_pca, columns=new_feature_column_names)\n", + "print(train_features.shape)\n", + "train_features.head()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2025-09-03T11:29:37.683011Z", + "iopub.status.busy": "2025-09-03T11:29:37.682676Z", + "iopub.status.idle": "2025-09-03T11:29:37.690083Z", + "shell.execute_reply": "2025-09-03T11:29:37.689232Z", + "shell.execute_reply.started": "2025-09-03T11:29:37.682983Z" + }, + "id": "L_027KddOORu", + "trusted": true + }, + "outputs": [], + "source": [ + "# import numpy as np\n", + "# import pandas as pd\n", + "# from sklearn.impute import SimpleImputer\n", + "# from sklearn.decomposition import PCA\n", + "# from sklearn.pipeline import Pipeline\n", + "\n", + "# # ---------- Diagnostic ----------\n", + "# print(\"TYPE and SHAPE of 'features':\", type(features), getattr(features, \"shape\", None))\n", + "# if isinstance(features, pd.DataFrame):\n", + "# print(\"DataFrame columns (first 20):\", features.columns.tolist()[:20])\n", + "# print(\"Showing up to first 10 rows (type + repr/shape):\")\n", + "# sample_objs = []\n", + "# for i in range(min(10, len(features))):\n", + "# try:\n", + "# row = features.iloc[i]\n", + "# except Exception:\n", + "# row = features[i]\n", + "# sample_objs.append(row)\n", + "\n", + "# for idx, s in enumerate(sample_objs):\n", + "# print(f\"\\n--- sample {idx} ---\")\n", + "# print(\"type:\", type(s))\n", + "# # If it's a Series (row of DF), show each element type\n", + "# if isinstance(s, (pd.Series, list, tuple, np.ndarray)):\n", + "# # If Series with vector-like element inside (common), try to show s.values / s.iloc[0]\n", + "# try:\n", + "# # If s is a Series of scalars (columns), show shape\n", + "# print(\"is pd.Series/list/ndarray. repr:\", repr(s))\n", + "# except:\n", + "# pass\n", + "# else:\n", + "# print(\"repr:\", repr(s))\n", + "\n", + "# # ---------- Convert to 2D array ----------\n", + "# def try_to_stack(X):\n", + "# \"\"\"\n", + "# Attempts several heuristics to produce a (n_samples, n_features) numpy array.\n", + "# Returns (arr, msg). 
arr is None on failure; msg explains what happened.\n", + "# \"\"\"\n", + "# # Case A: DataFrame with >0 columns -> use values\n", + "# if isinstance(X, pd.DataFrame) and X.shape[1] > 0:\n", + "# return X.values, \"DataFrame with columns\"\n", + "# # Convert to 1D object-array of per-sample vectors\n", + "# try:\n", + "# obj = np.array(X) # may be object dtype\n", + "# except Exception as e:\n", + "# return None, f\"np.array(features) failed: {e}\"\n", + "\n", + "# # If it's already 2D numeric\n", + "# if isinstance(obj, np.ndarray) and obj.ndim == 2 and obj.size > 0 and np.issubdtype(obj.dtype, np.number):\n", + "# return obj, \"Already numeric ndarray (2D)\"\n", + "\n", + "# # If obj is 1D object array where each element is a vector (np.array/list/tf.Tensor)\n", + "# if obj.ndim == 1 or (obj.ndim == 2 and obj.shape[1] == 1 and obj.dtype == object):\n", + "# # flatten to python list of elements\n", + "# elems = obj.ravel().tolist()\n", + "# # convert tensors to numpy if needed\n", + "# converted = []\n", + "# lengths = []\n", + "# for e in elems:\n", + "# # skip None\n", + "# if e is None:\n", + "# converted.append(None)\n", + "# lengths.append(0)\n", + "# continue\n", + "# # tf.Tensor -> .numpy() if possible\n", + "# try:\n", + "# import tensorflow as tf\n", + "# if isinstance(e, tf.Tensor):\n", + "# arr = e.numpy()\n", + "# else:\n", + "# arr = np.array(e)\n", + "# except Exception:\n", + "# arr = np.array(e)\n", + "# # flatten 1-D vector\n", + "# if arr.ndim == 0:\n", + "# arr = arr.reshape(1,)\n", + "# # ensure 1-D\n", + "# arr = np.asarray(arr).ravel()\n", + "# converted.append(arr)\n", + "# lengths.append(arr.shape[0])\n", + "\n", + "# unique_lengths = np.unique(lengths)\n", + "# print(\"Per-sample vector lengths (unique):\", unique_lengths[:10], \" (counts -> below)\")\n", + "# # counts\n", + "# import collections\n", + "# cnt = collections.Counter(lengths)\n", + "# print(\"counts (length:count) sample:\", dict(list(cnt.items())[:20]))\n", + "\n", + "# if len(unique_lengths) == 1 and unique_lengths[0] > 0:\n", + "# stacked = np.vstack(converted)\n", + "# return stacked, \"Stacked object-array into 2D (uniform length)\"\n", + "# # if all lengths are zero -> fail\n", + "# if max(lengths) == 0:\n", + "# return None, \"All per-sample vectors have length 0 (empty features). You need to regenerate features.\"\n", + "# # if lengths vary -> pad to max\n", + "# maxlen = int(max(lengths))\n", + "# print(f\"Varying lengths found. Padding all to max length = {maxlen}\")\n", + "# padded = np.zeros((len(converted), maxlen), dtype=float)\n", + "# for i, arr in enumerate(converted):\n", + "# if arr is None:\n", + "# continue\n", + "# L = arr.shape[0]\n", + "# padded[i, :L] = arr\n", + "# return padded, f\"Padded to maxlen={maxlen}\"\n", + "# # fallback: couldn't convert\n", + "# return None, \"Unhandled structure: can't convert to 2D\"\n", + "\n", + "# features_arr, msg = try_to_stack(features)\n", + "# print(\"\\nConversion message:\", msg)\n", + "# if features_arr is None:\n", + "# raise ValueError(\"Conversion to 2D array failed: \" + msg)\n", + "\n", + "# print(\"Converted features_arr.shape:\", features_arr.shape, \"dtype:\", features_arr.dtype)\n", + "\n", + "# # ---------- sanity ----------\n", + "# n_samples, n_feats = features_arr.shape\n", + "# if n_feats == 0:\n", + "# raise ValueError(\"After conversion there are 0 features per sample. This means your per-sample vectors are empty. 
\"\n", + "# \"You must re-extract ViT features (check feature extractor/predict step).\")\n", + "\n", + "# # ---------- PCA safely ----------\n", + "# requested_n_components = 42\n", + "# n_components = min(requested_n_components, n_feats)\n", + "# if n_components < requested_n_components:\n", + "# print(f\"Note: requested {requested_n_components} components but only {n_feats} features available. Using {n_components}.\")\n", + "\n", + "# # pipeline and transform\n", + "# pca_pipeline = Pipeline([\n", + "# ('imputer', SimpleImputer(strategy='mean')),\n", + "# ('pca', PCA(n_components=n_components))\n", + "# ])\n", + "# pred_pca = pca_pipeline.fit_transform(features_arr)\n", + "# print(\"pred_pca.shape:\", pred_pca.shape)\n", + "# colnames = [f\"feature_{i+1}\" for i in range(pred_pca.shape[1])]\n", + "# train_features = pd.DataFrame(pred_pca, columns=colnames)\n", + "# train_features.head()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2025-09-03T10:58:25.810060Z", + "iopub.status.busy": "2025-09-03T10:58:25.809442Z", + "iopub.status.idle": "2025-09-03T10:58:25.848282Z", + "shell.execute_reply": "2025-09-03T10:58:25.847234Z", + "shell.execute_reply.started": "2025-09-03T10:58:25.810027Z" + }, + "id": "2u8N_pYvOORu", + "outputId": "21b0fe25-6108-4c11-ed6b-38d8b8f647ae", + "trusted": true + }, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'train_features' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[36], line 5\u001b[0m\n\u001b[1;32m 3\u001b[0m X_train \u001b[38;5;241m=\u001b[39m X_train\u001b[38;5;241m.\u001b[39mreset_index(drop \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 4\u001b[0m y_train \u001b[38;5;241m=\u001b[39m y_train\u001b[38;5;241m.\u001b[39mreset_index(drop \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[0;32m----> 5\u001b[0m train_data_last \u001b[38;5;241m=\u001b[39m pd\u001b[38;5;241m.\u001b[39mconcat([X_train, \u001b[43mtrain_features\u001b[49m, y_train], axis \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 7\u001b[0m \u001b[38;5;66;03m# Dropping a few null features\u001b[39;00m\n\u001b[1;32m 8\u001b[0m train_data_last\u001b[38;5;241m.\u001b[39mdrop([\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mfull_path_to_img\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mpath_to_img\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mstock_number\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mshape\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mlab\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcut\u001b[39m\u001b[38;5;124m'\u001b[39m], axis \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m, inplace \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m)\n", + "\u001b[0;31mNameError\u001b[0m: name 'train_features' is not defined" + ] + } + ], + "source": [ + "# Merging the deep learning features with meta data features\n", + "\n", + "X_train = X_train.reset_index(drop = True)\n", + "y_train = y_train.reset_index(drop = True)\n", + "train_data_last = pd.concat([X_train, train_features, y_train], axis = 1)\n", + "\n", + "# Dropping a few 
null features\n", + "train_data_last.drop(['full_path_to_img', 'path_to_img', 'stock_number', 'shape', 'lab', 'cut'], axis = 1, inplace = True)\n", + "train_data_last.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "wzvItwQmOORu", + "trusted": true + }, + "outputs": [], + "source": [ + "# Encoding nominal features\n", + "enc1 = OrdinalEncoder()\n", + "\n", + "train_data_last['clarity'] = enc1.fit_transform(np.array(train_data_last['clarity']).reshape(-1,1))\n", + "train_data_last['colour'] = enc1.fit_transform(np.array(train_data_last['colour']).reshape(-1,1))\n", + "train_data_last['polish'] = enc1.fit_transform(np.array(train_data_last['polish']).reshape(-1,1))\n", + "train_data_last['symmetry'] = enc1.fit_transform(np.array(train_data_last['symmetry']).reshape(-1,1))\n", + "train_data_last['fluorescence'] = enc1.fit_transform(np.array(train_data_last['fluorescence']).reshape(-1,1))\n", + "\n", + "# Dropping several NaN values\n", + "train_data_last.dropna(inplace = True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Zu_DZmQtOORv" + }, + "source": [ + "# Classification Catboost Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "xG3TpeTqOORv", + "trusted": true + }, + "outputs": [], + "source": [ + "y_train = train_data_last.pop('encoded_class')\n", + "X_train = train_data_last\n", + "\n", + "# CatBoostClassifier model\n", + "cat_model = CatBoostClassifier(verbose = 250)\n", + "\n", + "# Training CatBoost model with features from the ViT feature extractor\n", + "cat_model.fit(X_train, y_train)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Koo22eZoOORv", + "trusted": true + }, + "outputs": [], + "source": [ + "# Test set processing\n", + "# test feature exteaction -> applying PCA -> predictions of the CatBoostClassifier model\n", + "\n", + "with stg.scope():\n", + " test_features = feature_extractor_model.predict(test_dataset)\n", + "\n", + "test_features = pca_.transform(test_features)\n", + "test_features = pd.DataFrame(test_features, columns = new_feature_column_names)\n", + "\n", + "\n", + "X_test = X_test.reset_index(drop = True)\n", + "y_test = y_test.reset_index(drop = True)\n", + "test_data_last = pd.concat([X_test, test_features, y_test], axis = 1)\n", + "test_data_last.drop(['full_path_to_img', 'path_to_img', 'stock_number', 'shape', 'lab', 'cut'], axis = 1, inplace = True)\n", + "\n", + "test_data_last['clarity'] = enc1.fit_transform(np.array(test_data_last['clarity']).reshape(-1,1))\n", + "test_data_last['colour'] = enc1.fit_transform(np.array(test_data_last['colour']).reshape(-1,1))\n", + "test_data_last['polish'] = enc1.fit_transform(np.array(test_data_last['polish']).reshape(-1,1))\n", + "test_data_last['symmetry'] = enc1.fit_transform(np.array(test_data_last['symmetry']).reshape(-1,1))\n", + "test_data_last['fluorescence'] = enc1.fit_transform(np.array(test_data_last['fluorescence']).reshape(-1,1))\n", + "test_data_last.dropna(inplace = True)\n", + "\n", + "y_test = test_data_last.pop('encoded_class')\n", + "X_test = test_data_last\n", + "\n", + "test_pred = cat_model.predict(X_test)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1AfGvt9XOORv" + }, + "source": [ + "# Test Results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "q3UBckP_OORv", + "trusted": true + }, + "outputs": [], + "source": [ + "# Predictions and scores\n", + "\n", + "mse = 
mean_squared_error(y_test, test_pred)\n", + "f1 = f1_score(y_test, test_pred, average = 'weighted')\n", + "acc = accuracy_score(y_test, test_pred)\n", + "\n", + "print('Mean Squared Error : {0:.5f}'.format(mse))\n", + "print('Weighted F1 Score : {0:.3f}'.format(f1))\n", + "print('Accuracy Score : {0:.3f} %'.format(acc*100))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "APxfmkS4OORv" + }, + "source": [ + "# Test Classification Report" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "QNr2xy6LOORv", + "trusted": true + }, + "outputs": [], + "source": [ + "# classification report\n", + "\n", + "clf_report = classification_report(y_test, test_pred, target_names = list(classes.values()))\n", + "print(clf_report)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MEW8qtV5OORv" + }, + "source": [ + "# Test Confusion Matrix" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "oOXiosw-OORv", + "trusted": true + }, + "outputs": [], + "source": [ + "cm = confusion_matrix(y_test, test_pred)\n", + "cmd = ConfusionMatrixDisplay(cm, display_labels = list(classes.values()))\n", + "\n", + "fig, ax = plt.subplots(figsize=(8,8))\n", + "cmd.plot(ax=ax, cmap = 'RdPu', colorbar = False)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AtHGwMNgOORv" + }, + "source": [ + "# Feature Explanation w/SHAP" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "SFXlO6tSOORv", + "trusted": true + }, + "outputs": [], + "source": [ + "explainer = shap.TreeExplainer(cat_model)\n", + "shap_values = explainer(pd.DataFrame(X_test, columns = X_test.columns))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "r3M-MCWbOORv", + "trusted": true + }, + "outputs": [], + "source": [ + "# cushion class feature explanation\n", + "shap.plots.beeswarm(shap_values[..., 0], max_display = 12)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "xCnrnp5LOORv", + "trusted": true + }, + "outputs": [], + "source": [ + "# emerald class feature explanation\n", + "shap.plots.beeswarm(shap_values[..., 1], max_display = 12)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "MjGXN6M1OORv" + }, + "outputs": [], + "source": [ + "# heart class feature explanation\n", + "shap.plots.beeswarm(shap_values[..., 2], max_display = 12)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "rY1kptm2OORw", + "trusted": true + }, + "outputs": [], + "source": [ + "# marquise class feature explanation\n", + "shap.plots.beeswarm(shap_values[..., 3], max_display = 12)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WjOvY9xeOORw", + "trusted": true + }, + "outputs": [], + "source": [ + "# oval class feature explanation\n", + "shap.plots.beeswarm(shap_values[..., 4], max_display = 12)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "jTbB-gT0OORw", + "trusted": true + }, + "outputs": [], + "source": [ + "# pear class feature explanation\n", + "shap.plots.beeswarm(shap_values[..., 5], max_display = 12)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "kqkxBCEOOORw", + "trusted": true + }, + "outputs": [], + "source": [ + "# princess class feature explanation\n", + "shap.plots.beeswarm(shap_values[..., 6], max_display = 12)" + ] + }, + { + "cell_type": 
"code", + "execution_count": null, + "metadata": { + "id": "Q_KM9RVnOORw", + "trusted": true + }, + "outputs": [], + "source": [ + "# round class feature explanation\n", + "shap.plots.beeswarm(shap_values[..., 7], max_display = 12)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MZsEHc8wOORw" + }, + "source": [ + "# Test Sample Prediction" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "okevl7AxOORw", + "trusted": true + }, + "outputs": [], + "source": [ + "test_take1 = test_dataset.take(-1)\n", + "test_take1_ = list(test_take1)\n", + "\n", + "# A function that creating 5 random images in the test set and predictions\n", + "\n", + "# Red title -> a false prediction\n", + "# Green title -> a true prediction\n", + "\n", + "def random_test_sample_with_prediction(SEED):\n", + " idxs = np.random.default_rng(seed=SEED).permutation(len(test_pred))[:5]\n", + " batch_idx = idxs // BATCH_SIZE\n", + " image_idx = idxs-batch_idx * BATCH_SIZE\n", + " idx = idxs\n", + "\n", + " fig, axs = plt.subplots(1,5, figsize = (12,12) ,dpi = 150)\n", + "\n", + " for i in range(5):\n", + " img = test_take1_[batch_idx[i]][0][image_idx[i]]\n", + " img = cv2.cvtcolour(img.numpy(), cv2.colour_BGR2GRAY)\n", + "\n", + " label = test_take1_[batch_idx[i]][1][image_idx[i]].numpy()\n", + "\n", + "\n", + " if int(test_pred[idx[i]]) == label:\n", + " axs[i].imshow(img, cmap = 'gray')\n", + " axs[i].axis('off')\n", + " axs[i].set_title('image (no: ' + str(idx[i]) + ')' + '\\n' + classes[label], fontsize = 8, colour = 'green')\n", + " else:\n", + " axs[i].imshow(img, cmap = 'gray')\n", + " axs[i].axis('off')\n", + " axs[i].set_title('image (no: ' + str(idx[i]) + ')' + '\\n' + classes[label], fontsize = 8, colour = 'red')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qAe_G92HOORw", + "trusted": true + }, + "outputs": [], + "source": [ + "# Red title -> a false prediction\n", + "# Green title -> a true prediction\n", + "\n", + "random_test_sample_with_prediction(SEED = 140)\n", + "random_test_sample_with_prediction(SEED = 20)\n", + "random_test_sample_with_prediction(SEED = 30)\n", + "random_test_sample_with_prediction(SEED = 99)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kaggle": { + "accelerator": "nvidiaTeslaT4", + "dataSources": [ + { + "datasetId": 3985216, + "sourceId": 6939472, + "sourceType": "datasetVersion" + } + ], + "dockerImageVersionId": 30580, + "isGpuEnabled": true, + "isInternetEnabled": true, + "language": "python", + "sourceType": "notebook" + }, + "kernelspec": { + "display_name": "my_env_diam", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.5" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "003b90eb71264ee4837cc26731eaaf68": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + 
"IPY_MODEL_0d8ac391bc624124ad80a90c8f83e5c5", + "IPY_MODEL_c8052b5fc74c434f8e41aed890cdd2b4", + "IPY_MODEL_42834001a2ae44eb8089d4a867a301e1" + ], + "layout": "IPY_MODEL_595f27c351834f83893813259593197c" + } + }, + "03f990d097234c98b158ab354f2cddd3": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "0a45dcdabe434b6fa7c8a79c166b5678": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_20964f171fd04b58a3c202638350fbd5", + "placeholder": "​", + "style": "IPY_MODEL_acaf7d77233a44919dfcb2be651ae5f3", + "value": "config.json: 100%" + } + }, + "0be875023c0d42e1acb500112f5920f6": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "0d8ac391bc624124ad80a90c8f83e5c5": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_a7a007e8da704593b5e4ddd0e7691dff", + "placeholder": "​", + "style": "IPY_MODEL_647c21a9037c4cc7b0b3d02081d45df3", + "value": "model-00005-of-00006.safetensors: 100%" + } + }, + "1113a8ad18b94f9f86d107a62a59804e": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + 
"object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "125e7a0cacae41f493bd0976d4d8397f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "138f3e4411414bc2b049c878eed9b21f": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "15835b5eda3942249c60a3caea5a753d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_aa03ea003bc44e8da03d7071c21f3c57", + "max": 6, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_1b54bdbc27844c458b80a5e4fa36fd6d", + "value": 6 + } + }, + "16b9b627ec4e432f976212785a2c9438": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_1113a8ad18b94f9f86d107a62a59804e", + "placeholder": "​", + "style": "IPY_MODEL_75fca78a498b4edbb39dd9d600b465ec", + "value": "model-00002-of-00006.safetensors: 100%" + } + }, + "16cc161d8a4344edbf063ef72bdc4b5c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": 
"@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_e6bd3eb2c1aa49d889166f51ec3b62b4", + "max": 585, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_03f990d097234c98b158ab354f2cddd3", + "value": 585 + } + }, + "1a84872dda9b4a7aa8af9853670a936d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "1b54bdbc27844c458b80a5e4fa36fd6d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "1da54d3ceabe48e3835bf1c1e688b620": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "1e1fc13a1f2640cb8c10c24c49e59bce": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + 
"height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "20964f171fd04b58a3c202638350fbd5": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "20c9c4ccb5c34e8db247f073216ea769": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_381592cee7b943a1b849866052205f95", + "IPY_MODEL_15835b5eda3942249c60a3caea5a753d", + "IPY_MODEL_5f7a11e86fdc40609770ce9b84639855" + ], + "layout": "IPY_MODEL_d901e3f1c4064491a28e2a1820277689" + } + }, + "22c681378bf742d1a03f0f6c4447f2f2": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_510e2a336239491c90b38bdefe2458e2", + "placeholder": "​", + "style": "IPY_MODEL_76125ba4efb74c8c9f96260a240768ba", + "value": " 2.01G/2.01G [04:19<00:00, 7.93MB/s]" + } + }, + "2370b724aa6449f2a36225923fdd4444": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + 
"flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "256a6f7965cb4692a93135cf30473310": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_c5fc49048b054ce59032c0ff902d5027", + "IPY_MODEL_59f6d4c3694847ebb88350e5e98efa0a", + "IPY_MODEL_482a1a41cb63488d9a9b9ee3308ac925" + ], + "layout": "IPY_MODEL_b74beb0b19bf4f1d95bc853b2cce5b5f" + } + }, + "294f5a054ff5417b8dd58c7b3820ce08": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2a9063f0598a49bb97a92c99111d6e63": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "3200d4c23b9b453b80b4f7fd12ccbed8": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + 
"description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_b9ce5afbbcdb4c2a88605c7f6c2f3ee0", + "placeholder": "​", + "style": "IPY_MODEL_b12eabb3f64d4cbcb5c0b6bd610f9eca", + "value": " 4.98G/4.98G [08:55<00:00, 17.3MB/s]" + } + }, + "32b886b4434a4067acd406cf6c42ed3f": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "333d0b72763d4f8c9055a9032aebf200": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "34851b6770104184b81198ec66210ad0": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "373fa29214984f4b83331ecdd0e588ad": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": 
"@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_16b9b627ec4e432f976212785a2c9438", + "IPY_MODEL_6993b2c5967f4d0d9ddce9b8575eb362", + "IPY_MODEL_f139900f50a2410a927b74b68a784a3d" + ], + "layout": "IPY_MODEL_1e1fc13a1f2640cb8c10c24c49e59bce" + } + }, + "380d1f3fdd2b42ab83c7da7cb8a8343a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "381592cee7b943a1b849866052205f95": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_1da54d3ceabe48e3835bf1c1e688b620", + "placeholder": "​", + "style": "IPY_MODEL_6e51b8f33d2b41fcba3dc8f0a85bd674", + "value": "Loading checkpoint shards: 100%" + } + }, + "3c6a2bf0b8d04d678d02ba68c9d48caa": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_2370b724aa6449f2a36225923fdd4444", + "placeholder": "​", + "style": "IPY_MODEL_86955d5073b747139e5bbdbb58709aef", + "value": "model-00004-of-00006.safetensors: 100%" + } + }, + "3e07d400ef6c4966b76a050ce824bd26": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "40ca9b965fca46c9912c4fca7ace8974": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + 
"model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "422bfbb2b1ef4b6ebd50fa67533f7617": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "42834001a2ae44eb8089d4a867a301e1": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_856be4312e0c4aa29a8db6b6fe2eb43f", + "placeholder": "​", + "style": "IPY_MODEL_0be875023c0d42e1acb500112f5920f6", + "value": " 4.97G/4.97G [08:56<00:00, 2.74MB/s]" + } + }, + "47e7a059bb304ac1a91f996401d3377b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "482a1a41cb63488d9a9b9ee3308ac925": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_51f7085a0bb84f2283aaea2c077cff37", + "placeholder": "​", + "style": "IPY_MODEL_d8ea116b269440f788b3d7c0b5b7960d", + "value": " 6/6 [08:57<00:00, 221.46s/it]" + } + }, + "487b45ed6e584c5689a5fbf044be220f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": 
"HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_51b0066f8e1b4213a26a595ff4ddd3e5", + "placeholder": "​", + "style": "IPY_MODEL_333d0b72763d4f8c9055a9032aebf200", + "value": " 746/746 [00:00<00:00, 36.4kB/s]" + } + }, + "4c1ee6e7c54f412e8849d82566339f52": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "510e2a336239491c90b38bdefe2458e2": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "51b0066f8e1b4213a26a595ff4ddd3e5": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": 
null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "51f7085a0bb84f2283aaea2c077cff37": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "529de39b329543638ff427de8df7fc3e": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "5491f789531f46c7832db61ee85cea22": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + 
"_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "57fd4f7cb6694d91bf9fe3584c2a2291": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "595f27c351834f83893813259593197c": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "59f6d4c3694847ebb88350e5e98efa0a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_6ce65ec2a362422aac6c5d840ffc5b90", + "max": 6, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_bc76c7853ae44aa49e2fdcef5db9dd36", + "value": 6 + } + }, + "5b8471a6a6db4ee59bda217d17235758": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + 
"_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_b340b1b093f74a46a5cc2aba48c546bb", + "max": 4980241600, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_1a84872dda9b4a7aa8af9853670a936d", + "value": 4980241600 + } + }, + "5b98a4f7a87445c080211c05975a405b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "5bc904340f614621bc2079fdc477950a": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "5f7a11e86fdc40609770ce9b84639855": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_bc9919853de14b08b57eb267d256e103", + "placeholder": "​", + "style": "IPY_MODEL_94176e3863764593be768a447f168ef1", + "value": " 6/6 [00:02<00:00,  3.43it/s]" + } + }, + "5ffdba4dfa664dcabc7e9b035850fe4f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "60970f074d574caf8d47d98a1bf85e1e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": 
"HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_90485fc63dab4b55963a9d9d6ad3588c", + "IPY_MODEL_16cc161d8a4344edbf063ef72bdc4b5c", + "IPY_MODEL_891c8f2d84a145e5abc47242a5102b41" + ], + "layout": "IPY_MODEL_529de39b329543638ff427de8df7fc3e" + } + }, + "61fac4e384694f6d99850a52a451b4cd": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_cc644e6fabba4917af45828e8b7842af", + "max": 2013860920, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_ba17d00650c140b49fc1062090711a2d", + "value": 2013860920 + } + }, + "6385b059f3354af4835c86fde57fd8aa": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_0a45dcdabe434b6fa7c8a79c166b5678", + "IPY_MODEL_be3d41be4892448bb87a568d080757ba", + "IPY_MODEL_487b45ed6e584c5689a5fbf044be220f" + ], + "layout": "IPY_MODEL_34851b6770104184b81198ec66210ad0" + } + }, + "647c21a9037c4cc7b0b3d02081d45df3": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "65236c5d0ba54920abc30daf651627ee": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": 
null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "6837bfda6ef24bdfa47d1b72510e0cae": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_ed5cff7d6c1a4c2ab1d15b325f509e69", + "IPY_MODEL_b05720093bb34c4e9dcf174f3267e548", + "IPY_MODEL_ce2de52f01e240c79fcf77269eec3a61" + ], + "layout": "IPY_MODEL_bcb0522b4a864b1a89560ffaa91ae833" + } + }, + "6921fc373c0e4ac792aed0a2b2379f5e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "6993b2c5967f4d0d9ddce9b8575eb362": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_5491f789531f46c7832db61ee85cea22", + "max": 4967510232, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_380d1f3fdd2b42ab83c7da7cb8a8343a", + "value": 4967510232 + } + }, + "6c37f5fc0e6240b0be3e57a5969bdd96": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "6ce65ec2a362422aac6c5d840ffc5b90": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + 
"_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "6e51b8f33d2b41fcba3dc8f0a85bd674": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "75bc40a0532b436897f50d01436bf47b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "75fca78a498b4edbb39dd9d600b465ec": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "76125ba4efb74c8c9f96260a240768ba": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "77907342a80046e1a0184e5bb1a187de": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_138f3e4411414bc2b049c878eed9b21f", + "placeholder": "​", + "style": "IPY_MODEL_422bfbb2b1ef4b6ebd50fa67533f7617", + "value": "model-00006-of-00006.safetensors: 100%" + } + }, + 
"7a2b64dc4b814b5eb398c8ed87084a94": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_fbaf51fec6e34c65916cc167ca966fd5", + "IPY_MODEL_5b8471a6a6db4ee59bda217d17235758", + "IPY_MODEL_3200d4c23b9b453b80b4f7fd12ccbed8" + ], + "layout": "IPY_MODEL_85cc74b6a47f44c48b267c5c9b38d1ac" + } + }, + "7ab4a73837cb4ac48cc0c9eb9545f213": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_3c6a2bf0b8d04d678d02ba68c9d48caa", + "IPY_MODEL_c4096f7cdb3c4ba5b657c86eee007ad9", + "IPY_MODEL_d9e1d7189391479898829d7c3c2151db" + ], + "layout": "IPY_MODEL_8b68198a8de54ea6b246caa7704ea1ed" + } + }, + "801a8b0fdf3843d59125f3e94d68a7f4": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_84fa302267ce4cf88f7267601332bab0", + "IPY_MODEL_89c0ddc2a0b445208e373fd223c2bbfa", + "IPY_MODEL_9cd3584c1d8e4d84a3e6a032ab155ca0" + ], + "layout": "IPY_MODEL_294f5a054ff5417b8dd58c7b3820ce08" + } + }, + "81191b7274f6455ca0415ff8a3929c6b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "84fa302267ce4cf88f7267601332bab0": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_47e7a059bb304ac1a91f996401d3377b", + "placeholder": "​", + "style": "IPY_MODEL_ddd7cfff5db548828d5f5d9a060ed2b5", + "value": "model-00003-of-00006.safetensors: 100%" + } + }, + "856be4312e0c4aa29a8db6b6fe2eb43f": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + 
"_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "85cc74b6a47f44c48b267c5c9b38d1ac": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "86955d5073b747139e5bbdbb58709aef": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "891c8f2d84a145e5abc47242a5102b41": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_32b886b4434a4067acd406cf6c42ed3f", + "placeholder": "​", + "style": "IPY_MODEL_81191b7274f6455ca0415ff8a3929c6b", + "value": " 585/585 [00:00<00:00, 30.9kB/s]" + } + }, + "89c0ddc2a0b445208e373fd223c2bbfa": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + 
"_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_65236c5d0ba54920abc30daf651627ee", + "max": 4967510568, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_5ffdba4dfa664dcabc7e9b035850fe4f", + "value": 4967510568 + } + }, + "8a6cdb36476547cca2e6ca81c642f359": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "8b68198a8de54ea6b246caa7704ea1ed": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "90485fc63dab4b55963a9d9d6ad3588c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_eb9a1383baec4cdea191b744252c24f0", + "placeholder": "​", + "style": "IPY_MODEL_75bc40a0532b436897f50d01436bf47b", + "value": "preprocessor_config.json: 100%" + } + }, + "913258671b4b49278f0cdfd6611700db": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + 
"grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "94176e3863764593be768a447f168ef1": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "9cd3584c1d8e4d84a3e6a032ab155ca0": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_bb32ad25633040b8a6cbdbc78bbe3d63", + "placeholder": "​", + "style": "IPY_MODEL_b84e684096b4473fb161c660c91ccfc1", + "value": " 4.97G/4.97G [07:04<00:00, 5.20MB/s]" + } + }, + "a7a007e8da704593b5e4ddd0e7691dff": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "aa03ea003bc44e8da03d7071c21f3c57": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + 
"grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "acaf7d77233a44919dfcb2be651ae5f3": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "b05720093bb34c4e9dcf174f3267e548": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_3e07d400ef6c4966b76a050ce824bd26", + "max": 48723, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_de25b58a78a84cc9b92e0d645c04fa39", + "value": 48723 + } + }, + "b12eabb3f64d4cbcb5c0b6bd610f9eca": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "b340b1b093f74a46a5cc2aba48c546bb": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": 
null + } + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}