Spaces: Running on Zero

Upload app.py

app.py CHANGED
@@ -1,27 +1,192 @@
 import gradio as gr
-import
-import
+import numpy as np
+from PIL import Image # Keep PIL for now, might be needed by helpers implicitly
+# from PIL import Image, ImageDraw, ImageFont # No drawing yet
+import json
 import os
+import io
+import requests
+# import matplotlib.pyplot as plt # No plotting yet
+# import matplotlib # No plotting yet
+from huggingface_hub import hf_hub_download
+from dataclasses import dataclass
+from typing import List, Dict, Optional, Tuple
+import time
+import spaces # Required for @spaces.GPU
+
+import torch # Keep torch for device check in Tagger
+# import timm # No model loading yet
+# from safetensors.torch import load_file as safe_load_file # No model loading yet
+
+# Set the Matplotlib backend to Agg (Keep commented out for now)
+# matplotlib.use('Agg')
+
+# --- Data Classes and Helper Functions ---
+@dataclass
+class LabelData:
+    names: list[str]
+    rating: list[np.int64]
+    general: list[np.int64]
+    artist: list[np.int64]
+    character: list[np.int64]
+    copyright: list[np.int64]
+    meta: list[np.int64]
+    quality: list[np.int64]
+
+# Keep helpers needed for initialization
+def load_tag_mapping(mapping_path):
+    with open(mapping_path, 'r', encoding='utf-8') as f: tag_mapping_data = json.load(f)
+    if isinstance(tag_mapping_data, dict) and "idx_to_tag" in tag_mapping_data:
+        idx_to_tag = {int(k): v for k, v in tag_mapping_data["idx_to_tag"].items()}
+        tag_to_category = tag_mapping_data["tag_to_category"]
+    elif isinstance(tag_mapping_data, dict):
+        tag_mapping_data = {int(k): v for k, v in tag_mapping_data.items()}
+        idx_to_tag = {idx: data['tag'] for idx, data in tag_mapping_data.items()}
+        tag_to_category = {data['tag']: data['category'] for data in tag_mapping_data.values()}
+    else: raise ValueError("Unsupported tag mapping format")
+    names = [None] * (max(idx_to_tag.keys()) + 1)
+    rating, general, artist, character, copyright, meta, quality = [], [], [], [], [], [], []
+    for idx, tag in idx_to_tag.items():
+        if idx >= len(names): names.extend([None] * (idx - len(names) + 1))
+        names[idx] = tag
+        category = tag_to_category.get(tag, 'Unknown')
+        idx_int = int(idx)
+        if category == 'Rating': rating.append(idx_int)
+        elif category == 'General': general.append(idx_int)
+        elif category == 'Artist': artist.append(idx_int)
+        elif category == 'Character': character.append(idx_int)
+        elif category == 'Copyright': copyright.append(idx_int)
+        elif category == 'Meta': meta.append(idx_int)
+        elif category == 'Quality': quality.append(idx_int)
+    return LabelData(names=names, rating=np.array(rating), general=np.array(general), artist=np.array(artist),
+                     character=np.array(character), copyright=np.array(copyright), meta=np.array(meta), quality=np.array(quality)), tag_to_category
+
+# --- Constants ---
+REPO_ID = "cella110n/cl_tagger"
+SAFETENSORS_FILENAME = "lora_model_0426/checkpoint_epoch_4.safetensors"
+METADATA_FILENAME = "lora_model_0426/checkpoint_epoch_4_metadata.json"
+TAG_MAPPING_FILENAME = "lora_model_0426/tag_mapping.json"
+CACHE_DIR = "./model_cache"
+# BASE_MODEL_NAME = 'eva02_large_patch14_448.mim_m38m_ft_in1k' # No model loading yet
+
+# --- Tagger Class ---
+class Tagger:
+    def __init__(self):
+        print("Initializing Tagger...")
+        self.safetensors_path = None
+        self.metadata_path = None
+        self.tag_mapping_path = None
+        self.labels_data = None
+        self.tag_to_category = None
+        self.model = None # Model will be loaded later
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self._initialize_paths_and_labels()
+        print("Tagger Initialized.") # Add confirmation
+
+    def _download_files(self):
+        # Check if paths are already set and files exist (useful for restarts)
+        local_safetensors = os.path.join(CACHE_DIR, 'models--cella110n--cl_tagger', 'snapshots', '21e237f0ae461b8d9ebf7472ae8de003e5effe5b', SAFETENSORS_FILENAME)
+        local_tag_mapping = os.path.join(CACHE_DIR, 'models--cella110n--cl_tagger', 'snapshots', '21e237f0ae461b8d9ebf7472ae8de003e5effe5b', TAG_MAPPING_FILENAME)
+        local_metadata = os.path.join(CACHE_DIR, 'models--cella110n--cl_tagger', 'snapshots', '21e237f0ae461b8d9ebf7472ae8de003e5effe5b', METADATA_FILENAME)
+
+        needs_download = False
+        if not (self.safetensors_path and os.path.exists(self.safetensors_path)):
+            if os.path.exists(local_safetensors):
+                self.safetensors_path = local_safetensors
+                print(f"Found existing safetensors: {self.safetensors_path}")
+            else:
+                needs_download = True
+        if not (self.tag_mapping_path and os.path.exists(self.tag_mapping_path)):
+            if os.path.exists(local_tag_mapping):
+                self.tag_mapping_path = local_tag_mapping
+                print(f"Found existing tag mapping: {self.tag_mapping_path}")
+            else:
+                needs_download = True
+        # Metadata is optional, check similarly
+        if not (self.metadata_path and os.path.exists(self.metadata_path)):
+            if os.path.exists(local_metadata):
+                self.metadata_path = local_metadata
+                print(f"Found existing metadata: {self.metadata_path}")
+            # Don't trigger download just for metadata if others exist
+
+        if not needs_download and self.safetensors_path and self.tag_mapping_path:
+            print("Required files already exist or paths set.")
+            return
 
-
-
-
-
-
-
+        print("Downloading model files...")
+        hf_token = os.environ.get("HF_TOKEN")
+        try:
+            # Only download if not found locally
+            if not self.safetensors_path:
+                self.safetensors_path = hf_hub_download(repo_id=REPO_ID, filename=SAFETENSORS_FILENAME, cache_dir=CACHE_DIR, token=hf_token, force_download=False) # Use force_download=False
+            if not self.tag_mapping_path:
+                self.tag_mapping_path = hf_hub_download(repo_id=REPO_ID, filename=TAG_MAPPING_FILENAME, cache_dir=CACHE_DIR, token=hf_token, force_download=False)
+            print(f"Safetensors: {self.safetensors_path}")
+            print(f"Tag mapping: {self.tag_mapping_path}")
+            try:
+                # Only download if not found locally
+                if not self.metadata_path:
+                    self.metadata_path = hf_hub_download(repo_id=REPO_ID, filename=METADATA_FILENAME, cache_dir=CACHE_DIR, token=hf_token, force_download=False)
+                print(f"Metadata: {self.metadata_path}")
+            except Exception as e_meta:
+                # Handle case where metadata genuinely doesn't exist or download fails
+                print(f"Metadata ({METADATA_FILENAME}) not found/download failed. Error: {e_meta}")
+                self.metadata_path = None
+
+        except Exception as e:
+            print(f"Error downloading files: {e}")
+            if "401 Client Error" in str(e) or "Repository not found" in str(e): raise gr.Error(f"Could not download files from {REPO_ID}. Check HF_TOKEN or repository status.")
+            else: raise gr.Error(f"Error downloading files: {e}")
+
+    def _initialize_paths_and_labels(self):
+        # Call download first (it now checks existence)
+        self._download_files()
+        # Only load labels if not already loaded
+        if self.labels_data is None:
+            print("Loading labels...")
+            if self.tag_mapping_path and os.path.exists(self.tag_mapping_path):
+                try:
+                    self.labels_data, self.tag_to_category = load_tag_mapping(self.tag_mapping_path)
+                    print(f"Labels loaded. Count: {len(self.labels_data.names)}")
+                except Exception as e: raise gr.Error(f"Error loading tag mapping: {e}")
+            else:
+                # This should ideally not happen if download worked
+                raise gr.Error(f"Tag mapping file not found at expected path: {self.tag_mapping_path}")
+        else:
+            print("Labels already loaded.")
+
+    # Add a simple test method decorated with GPU
+    @spaces.GPU()
+    def test_gpu_method(self):
+        current_time = time.time()
+        print(f"--- Tagger.test_gpu_method called on GPU worker at {current_time} ---")
+        # Check if labels are accessible from the GPU worker context
+        label_count = len(self.labels_data.names) if self.labels_data else -1
+        print(f"--- (Worker) Label count: {label_count} ---")
+        return f"Tagger method called at {current_time}. Label count: {label_count}"
+
+    # --- Original predict_on_gpu (Keep commented out for this test) ---
+    # @spaces.GPU()
+    # def predict_on_gpu(self, image_input, gen_threshold, char_threshold, output_mode):
+    #     # ... (original prediction logic including model loading) ...
+    #     pass
+
+# Instantiate the tagger class (this will download files/load labels)
+tagger = Tagger()
 
 # --- Gradio Interface Definition (Minimal) ---
 with gr.Blocks() as demo:
     gr.Markdown("""
-    # Minimal Button Test
-
+    # Tagger Initialization + Minimal Button Test
+    Instantiates Tagger, then click the button below to check if a simple `@spaces.GPU` decorated *method* is triggered.
+    Check logs for Tagger initialization messages.
     """)
     with gr.Column():
-        test_button = gr.Button("Test GPU
+        test_button = gr.Button("Test Tagger GPU Method")
         output_text = gr.Textbox(label="Output")
 
         test_button.click(
-            fn=
+            fn=tagger.test_gpu_method, # Call the simple method on the instance
             inputs=[],
             outputs=[output_text]
         )
@@ -29,4 +194,43 @@ with gr.Blocks() as demo:
 # --- Main Block ---
 if __name__ == "__main__":
     if not os.environ.get("HF_TOKEN"): print("Warning: HF_TOKEN environment variable not set.")
+    # Tagger instance is created above
     demo.launch(share=True)
+
+# --- Commented out original UI and helpers/constants not needed for init/simple test ---
+"""
+# import matplotlib.pyplot as plt
+# import matplotlib
+# matplotlib.use('Agg')
+# from PIL import Image, ImageDraw, ImageFont
+# import timm
+# from safetensors.torch import load_file as safe_load_file
+
+# def pil_ensure_rgb(...)
+# def pil_pad_square(...)
+# def get_tags(...)
+# def visualize_predictions(...)
+# def preprocess_image(...)
+
+# BASE_MODEL_NAME = 'eva02_large_patch14_448.mim_m38m_ft_in1k'
+
+# class Tagger:
+#     ... (methods related to prediction, model loading) ...
+#     def _load_model_on_gpu(self):
+#         ...
+#     @spaces.GPU()
+#     def predict_on_gpu(self, image_input, gen_threshold, char_threshold, output_mode):
+#         ...
+
+# --- Original Full Gradio UI ---
+# css = ...
+# js = ...
+# with gr.Blocks(css=css, js=js) as demo:
+#     gr.Markdown("# WD EVA02 LoRA PyTorch Tagger")
+#     ...
+#     predict_button.click(
+#         fn=tagger.predict_on_gpu,
+#         ...
+#     )
+#     gr.Examples(...)
+"""