Spaces:
Build error
Build error
| import gradio as gr | |
| import os | |
| from PIL import Image | |
| import tempfile | |
| from dataset_aug import process_images | |
| import convert_source_to_sketch # スケッチ変換用のモジュールをインポート | |
| import random | |
| def process_multiple_images( | |
| source_images, | |
| target_images, | |
| output_size, | |
| num_copies, | |
| is_flip, | |
| rotation_range, | |
| min_scale, | |
| max_scale, | |
| source_is_avg_color_fill, | |
| source_is_edge_mode_fill, | |
| target_is_avg_color_fill, | |
| target_is_edge_mode_fill, | |
| expand_to_long_side, | |
| progress=gr.Progress() | |
| ): | |
| result_source_images = [] | |
| result_target_images = [] | |
| progress(0, desc="処理を開始します...") | |
| # 各画像ペアに対して処理を実行 | |
| total_pairs = len(source_images) | |
| for idx, (source_path, target_path) in enumerate(zip(source_images, target_images), 1): | |
| progress(idx/total_pairs, desc=f"画像ペア {idx}/{total_pairs} を処理中...") | |
| # PILイメージとして読み込み | |
| source_img = Image.open(source_path.name) | |
| target_img = Image.open(target_path.name) | |
| # 拡張処理を実行し、PILイメージのリストを取得 | |
| aug_sources, aug_targets = process_images( | |
| source_img, | |
| target_img, | |
| num_copies=num_copies, | |
| output_size=(output_size, output_size), | |
| is_flip=is_flip, | |
| rotation_range=rotation_range, | |
| min_scale=min_scale, | |
| max_scale=max_scale, | |
| source_is_avg_color_fill=source_is_avg_color_fill, | |
| source_is_edge_mode_fill=source_is_edge_mode_fill, | |
| target_is_avg_color_fill=target_is_avg_color_fill, | |
| target_is_edge_mode_fill=target_is_edge_mode_fill, | |
| expand_to_long_side=expand_to_long_side | |
| ) | |
| # 生成された画像を収集 | |
| result_source_images.extend(aug_sources) | |
| result_target_images.extend(aug_targets) | |
| progress(1, desc="処理が完了しました") | |
| return result_source_images, result_target_images, "処理が完了しました!" | |
| def update_source_preview(source_files): | |
| preview_images = [] | |
| if source_files: | |
| for source in source_files: | |
| preview_images.append(source.name) | |
| return preview_images | |
| def update_target_preview(target_files): | |
| preview_images = [] | |
| if target_files: | |
| for target in target_files: | |
| preview_images.append(target.name) | |
| return preview_images | |
| def convert_to_sketch(source_files): | |
| """sourceをスケッチに変換""" | |
| converted_images = [] | |
| if source_files: | |
| # 一時ディレクトリを作成(グローバルに保持) | |
| temp_dir = tempfile.mkdtemp() | |
| try: | |
| for source in source_files: | |
| # スケッチ変換処理 | |
| image = Image.open(source.name) | |
| sketch = convert_source_to_sketch.convert_pil_to_sketch(image) | |
| # 一時ファイルとして保存 | |
| temp_path = os.path.join(temp_dir, os.path.basename(source.name)) | |
| sketch.save(temp_path) | |
| converted_images.append(temp_path) | |
| except Exception as e: | |
| print(f"Error during conversion: {e}") | |
| # エラー時にも一時ディレクトリを削除 | |
| if os.path.exists(temp_dir): | |
| import shutil | |
| shutil.rmtree(temp_dir) | |
| return [] | |
| return converted_images | |
| # アプリケーション終了時のクリーンアップ処理を修正 | |
| def cleanup_temp_files(): | |
| """一時ファイルをクリーンアップ""" | |
| temp_root = tempfile.gettempdir() | |
| for item in os.listdir(temp_root): | |
| if item.startswith('tmp'): | |
| item_path = os.path.join(temp_root, item) | |
| try: | |
| if os.path.isdir(item_path): | |
| # ディレクトリ内の画像ファイルをチェック | |
| for root, dirs, files in os.walk(item_path): | |
| for file in files: | |
| if file.endswith(('.jpg', '.png')): | |
| file_path = os.path.join(root, file) | |
| try: | |
| with Image.open(file_path) as img: | |
| img.verify() # 画像ファイルの整合性チェック | |
| except Exception as e: | |
| print(f"Corrupted image found: {file_path} - {e}") | |
| import shutil | |
| shutil.rmtree(item_path) | |
| except Exception as e: | |
| print(f"Error cleaning up {item_path}: {e}") | |
| def randomize_params(): | |
| """パラメータをランダムに設定""" | |
| return ( | |
| random.choice([512, 768, 1024, 1536, 2048]), # output_size | |
| random.randint(1, 5), # num_copies | |
| random.choice([True, False]), # is_flip | |
| random.randint(0, 180), # rotation_range | |
| round(random.uniform(0.1, 1.0), 1), # min_scale | |
| round(random.uniform(1.0, 2.0), 1), # max_scale | |
| random.choice([True, False]), # source_is_avg_color_fill | |
| random.choice([True, False]), # source_is_edge_mode_fill | |
| random.choice([True, False]), # target_is_avg_color_fill | |
| random.choice([True, False]), # target_is_edge_mode_fill | |
| random.choice([True, False]) # expand_to_long_side | |
| ) | |
| def reset_params(): | |
| """パラメータを初期設定に戻す""" | |
| return ( | |
| 1024, # output_size | |
| 1, # num_copies | |
| True, # is_flip | |
| 0, # rotation_range | |
| 1.0, # min_scale | |
| 1.0, # max_scale | |
| True, # source_is_avg_color_fill | |
| False, # source_is_edge_mode_fill | |
| False, # target_is_avg_color_fill | |
| False, # target_is_edge_mode_fill | |
| False # expand_to_long_side | |
| ) | |
| def test_process_image_pair_with_expand_to_long_side(): | |
| """長辺拡張オプションのテスト""" | |
| source_image = Image.new('RGB', (800, 400), color='white') | |
| target_image = Image.new('RGB', (800, 400), color='white') | |
| # process_imagesを使用するように修正 | |
| aug_sources, aug_targets = process_images( | |
| source_image, | |
| target_image, | |
| num_copies=1, | |
| output_size=(512, 512), | |
| is_flip=False, | |
| rotation_range=0, | |
| min_scale=1.0, | |
| max_scale=1.0, | |
| source_is_avg_color_fill=True, | |
| source_is_edge_mode_fill=False, | |
| target_is_avg_color_fill=True, | |
| target_is_edge_mode_fill=False, | |
| expand_to_long_side=True | |
| ) | |
| result_source = aug_sources[0] | |
| result_target = aug_targets[0] | |
| # 結果が正方形であることを確認 | |
| assert result_source.size[0] == result_source.size[1] | |
| assert result_target.size[0] == result_target.size[1] | |
| # 出力サイズが指定通りであることを確認 | |
| assert result_source.size == (512, 512) | |
| assert result_target.size == (512, 512) | |
| # Gradioインターフェースの作成 | |
| with gr.Blocks() as demo: | |
| gr.Markdown("# Pair Image Augmentation Test") | |
| gr.Markdown("Code : https://github.com/Yeq6X/pair-images-aug") | |
| with gr.Row(): | |
| # 左側のカラム(Source画像とパラメータ) | |
| with gr.Column(): | |
| with gr.Row(): | |
| # Source画像 | |
| source_files = gr.File( | |
| label="Source画像を選択", | |
| file_count="multiple", | |
| file_types=["image"], | |
| height=150 | |
| ) | |
| # Target画像 | |
| target_files = gr.File( | |
| label="Target画像を選択", | |
| file_count="multiple", | |
| file_types=["image"], | |
| height=150 | |
| ) | |
| # サンプル画像の追加 | |
| gr.Examples( | |
| examples=[ | |
| [["samples/source/sample1.png", "samples/source/sample2.png"], | |
| ["samples/target/sample1.png", "samples/target/sample2.png"]], | |
| ], | |
| inputs=[source_files, target_files], | |
| label="サンプル画像セット", | |
| examples_per_page=5 | |
| ) | |
| source_preview = gr.Gallery( | |
| label="Source画像プレビュー", | |
| show_label=True, | |
| object_fit="contain", | |
| columns=4, | |
| ) | |
| with gr.Row(): | |
| gr.Markdown("### scribble_xdogで変換") | |
| convert_src_to_tgt_btn = gr.Button("↓", variant="primary", size="sm") | |
| convert_tgt_to_src_btn = gr.Button("↑", variant="primary", size="sm") | |
| gr.Markdown("") | |
| target_preview = gr.Gallery( | |
| label="Target画像プレビュー", | |
| show_label=True, | |
| object_fit="contain", | |
| columns=4, | |
| ) | |
| # 右側のカラム(Target画像と出力) | |
| with gr.Column(): | |
| # パラメータ設定部分 | |
| with gr.Accordion("パラメータ設定", open=True): | |
| # パラメータ操作ボタン | |
| with gr.Row(): | |
| randomize_btn = gr.Button("🎲 ランダム設定", variant="secondary") | |
| reset_btn = gr.Button("↺ 初期設定に戻す", variant="secondary") | |
| with gr.Row(): | |
| with gr.Column(): | |
| output_size = gr.Slider( | |
| minimum=256, | |
| maximum=2048, | |
| value=1024, | |
| step=256, | |
| label="出力画像サイズ" | |
| ) | |
| num_copies = gr.Slider( | |
| minimum=1, | |
| maximum=5, | |
| value=1, | |
| step=1, | |
| label="リピート回数" | |
| ) | |
| is_flip = gr.Checkbox( | |
| label="ランダムフリップを適用", | |
| value=True | |
| ) | |
| expand_to_long_side = gr.Checkbox( | |
| label="長辺に合わせて拡張する", | |
| value=False | |
| ) | |
| rotation_range = gr.Slider( | |
| minimum=0, | |
| maximum=180, | |
| value=0, | |
| step=1, | |
| label="回転角度の範囲" | |
| ) | |
| with gr.Column(): | |
| min_scale = gr.Slider( | |
| minimum=0.1, | |
| maximum=1.0, | |
| value=1.0, | |
| step=0.1, | |
| label="最小スケール" | |
| ) | |
| max_scale = gr.Slider( | |
| minimum=1.0, | |
| maximum=2.0, | |
| value=1.0, | |
| step=0.1, | |
| label="最大スケール" | |
| ) | |
| with gr.Row(): | |
| with gr.Column(): | |
| source_is_edge_mode_fill = gr.Checkbox( | |
| label="Source: 外周の最頻色で埋める", | |
| value=False | |
| ) | |
| source_is_avg_color_fill = gr.Checkbox( | |
| label="Source: 画像の平均色で埋める", | |
| value=True | |
| ) | |
| with gr.Column(): | |
| target_is_edge_mode_fill = gr.Checkbox( | |
| label="Target: 外周の最頻色で埋める", | |
| value=False | |
| ) | |
| target_is_avg_color_fill = gr.Checkbox( | |
| label="Target: 画像の平均色で埋める", | |
| value=False | |
| ) | |
| process_btn = gr.Button("処理開始", variant="primary") | |
| # ログ表示用のテキストボックスを追加 | |
| log_output = gr.Textbox( | |
| label="処理ログ", | |
| value="", | |
| lines=1, | |
| max_lines=1, | |
| interactive=False, | |
| visible=False | |
| ) | |
| # 結果表示 | |
| result_source_gallery = gr.Gallery( | |
| label="生成結果 (Source)", | |
| show_label=True, | |
| object_fit="contain", | |
| columns=4, | |
| type="pil" | |
| ) | |
| result_target_gallery = gr.Gallery( | |
| label="生成結果 (Target)", | |
| show_label=True, | |
| object_fit="contain", | |
| columns=4, | |
| type="pil" | |
| ) | |
| # イベントハンドラ | |
| source_files.change( | |
| fn=update_source_preview, | |
| inputs=[source_files], | |
| outputs=source_preview | |
| ) | |
| target_files.change( | |
| fn=update_target_preview, | |
| inputs=[target_files], | |
| outputs=target_preview | |
| ) | |
| convert_src_to_tgt_btn.click( | |
| fn=convert_to_sketch, | |
| inputs=[source_files], | |
| outputs=[target_files] | |
| ) | |
| convert_tgt_to_src_btn.click( | |
| fn=convert_to_sketch, | |
| inputs=[target_files], | |
| outputs=[source_files] | |
| ) | |
| param_outputs = [ | |
| output_size, | |
| num_copies, | |
| is_flip, | |
| rotation_range, | |
| min_scale, | |
| max_scale, | |
| source_is_avg_color_fill, | |
| source_is_edge_mode_fill, | |
| target_is_avg_color_fill, | |
| target_is_edge_mode_fill, | |
| expand_to_long_side | |
| ] | |
| randomize_btn.click( | |
| fn=randomize_params, | |
| inputs=[], | |
| outputs=param_outputs | |
| ) | |
| reset_btn.click( | |
| fn=reset_params, | |
| inputs=[], | |
| outputs=param_outputs | |
| ) | |
| process_btn.click( | |
| fn=process_multiple_images, | |
| inputs=[ | |
| source_files, | |
| target_files, | |
| output_size, | |
| num_copies, | |
| is_flip, | |
| rotation_range, | |
| min_scale, | |
| max_scale, | |
| source_is_avg_color_fill, | |
| source_is_edge_mode_fill, | |
| target_is_avg_color_fill, | |
| target_is_edge_mode_fill, | |
| expand_to_long_side | |
| ], | |
| outputs=[result_source_gallery, result_target_gallery, log_output] | |
| ) | |
| if __name__ == "__main__": | |
| try: | |
| demo.launch( | |
| # server_name="0.0.0.0", | |
| # server_port=8000, | |
| debug=True | |
| ) | |
| finally: | |
| cleanup_temp_files() # アプリケーション終了時にクリーンアップ |