from fastapi import FastAPI, Form
from fastapi.responses import FileResponse
import torch
from diffusers import StableDiffusionPipeline, StableVideoDiffusionPipeline
from diffusers.utils import export_to_video

# Note: FastAPI's Form fields require the python-multipart package to be installed.
app = FastAPI()

# Prefer the GPU with fp16 weights when available; otherwise fall back to CPU and fp32.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
variant = "fp16" if device == "cuda" else None

# Text-to-image pipeline: turns the prompt into a single still frame.
text2img_pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=dtype,
).to(device)

# Image-to-video pipeline: animates that still frame into a short clip.
video_pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt",
    torch_dtype=dtype,
    variant=variant,
).to(device)

@app.post("/generate")
def generate_video(prompt: str = Form(...), aspect: str = Form("16:9")):
    # Generate a still from the prompt, then resize it to the resolution
    # Stable Video Diffusion expects (1024x576 landscape or 576x1024 portrait).
    image = text2img_pipe(prompt).images[0]
    image = image.resize((1024, 576) if aspect == "16:9" else (576, 1024))

    # Fixed seed so identical requests produce the same clip.
    generator = torch.manual_seed(42)
    frames = video_pipe(image, decode_chunk_size=8, generator=generator).frames[0]
    export_to_video(frames, "output.mp4", fps=7)

    return FileResponse("output.mp4", media_type="video/mp4", filename="output.mp4")
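
# Usage sketch (assumptions, not part of the app above): with the server running
# locally, e.g. via `uvicorn main:app`, the endpoint can be called from a separate
# Python process with the requests library. The host, port, prompt, and output
# filename below are illustrative.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/generate",
#       data={"prompt": "a sailboat drifting at sunset", "aspect": "16:9"},
#       timeout=600,  # generation can take several minutes on a single GPU
#   )
#   resp.raise_for_status()
#   with open("clip.mp4", "wb") as f:
#       f.write(resp.content)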