import argparse
from pprint import pprint

import torch

from prompts import generate_prompt, system_prompt
from utils.hardware_utils import (
    categorize_ram,
    categorize_vram,
    get_gpu_vram_gb,
    get_system_ram_gb,
    is_compile_friendly_gpu,
    is_fp8_friendly,
)
from utils.llm_utils import LLMCodeOptimizer
from utils.pipeline_utils import determine_pipe_loading_memory


def create_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ckpt_id",
        type=str,
        default="black-forest-labs/FLUX.1-dev",
        help="Can be a repo id from the Hub or a local path where the checkpoint is stored.",
    )
    parser.add_argument(
        "--gemini_model",
        type=str,
        default="gemini-2.5-flash-preview-05-20",
        help="Gemini model to use. Choose from https://ai.google.dev/gemini-api/docs/models.",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="If the `ckpt_id` has variants, supply this flag to estimate compute. Example: 'fp16'.",
    )
    parser.add_argument(
        "--disable_bf16",
        action="store_true",
        help="Disables bf16 loading, which affects the load memory. Prefer not enabling this flag.",
    )
    parser.add_argument(
        "--enable_lossy",
        action="store_true",
        help="When enabled, the generated code will include snippets for enabling quantization.",
    )
    return parser


def main(args):
    if not torch.cuda.is_available():
        raise ValueError("Only CUDA devices are supported for now.")

    # Estimate how much memory loading the pipeline checkpoint will require.
    loading_mem_out = determine_pipe_loading_memory(args.ckpt_id, args.variant, args.disable_bf16)
    load_memory = loading_mem_out["total_loading_memory_gb"]

    # Profile available system RAM.
    ram_gb = get_system_ram_gb()
    ram_category = categorize_ram(ram_gb)
    if ram_gb is not None:
        print(f"\nSystem RAM: {ram_gb:.2f} GB")
        print(f"RAM Category: {ram_category}")
    else:
        print("\nCould not determine System RAM.")

    # Profile available GPU VRAM.
    vram_gb = get_gpu_vram_gb()
    vram_category = categorize_vram(vram_gb)
    if vram_gb is not None:
        print(f"\nGPU VRAM: {vram_gb:.2f} GB")
        print(f"VRAM Category: {vram_category}")
    else:
        print("\nCould not determine GPU VRAM.")

    # Check hardware capabilities that influence which optimizations should be suggested.
    is_compile_friendly = is_compile_friendly_gpu()
    is_fp8_compatible = is_fp8_friendly()

    # Build the generation prompt from the gathered stats and query the LLM.
    llm = LLMCodeOptimizer(model_name=args.gemini_model, system_prompt=system_prompt)
    current_generate_prompt = generate_prompt.format(
        ckpt_id=args.ckpt_id,
        pipeline_loading_memory=load_memory,
        available_system_ram=ram_gb,
        available_gpu_vram=vram_gb,
        enable_lossy_outputs=args.enable_lossy,
        is_fp8_supported=is_fp8_compatible,
        enable_torch_compile=is_compile_friendly,
    )
    pprint(f"{current_generate_prompt=}")
    print(llm(current_generate_prompt))


if __name__ == "__main__":
    parser = create_parser()
    args = parser.parse_args()
    main(args)
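# Example invocation (a sketch: the script filename below is an assumption, not from the
# source; the flag values mirror the defaults defined above and can be omitted or changed):
#
#   python optimize_pipeline.py \
#       --ckpt_id black-forest-labs/FLUX.1-dev \
#       --gemini_model gemini-2.5-flash-preview-05-20 \
#       --enable_lossy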