Ayeshaiqra's picture
Create app.py
9ada25e verified
raw
history blame contribute delete
661 Bytes
import io

import torch
import torchvision.transforms as T
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from PIL import Image
app = FastAPI()

# Load the trained classifier once at import time; every request reuses it.
# The checkpoint is a whole pickled module (``.eval()`` is called on it), not
# just a state_dict, so PyTorch >= 2.6's default ``weights_only=True`` would
# refuse to load it -- opt out explicitly.
# SECURITY: weights_only=False unpickles arbitrary objects; only ever point
# this at a trusted, locally produced model file.
model = torch.load("model.pt", map_location="cpu", weights_only=False)
# Switch to inference mode (affects layers such as dropout / batch-norm, if
# the model has any).
model.eval()

# Preprocessing applied to every uploaded image before inference: resize to
# 128x128, then convert to a float tensor via ToTensor.
# NOTE(review): this must mirror the training-time pipeline; no mean/std
# normalization is applied here -- confirm that matches how the model was trained.
transform = T.Compose([
    T.Resize((128, 128)),
    T.ToTensor(),
])
def _classify_image_bytes(image_bytes: bytes) -> str:
    """Decode *image_bytes*, run the model on it and return the label.

    Synchronous and CPU-bound, so ``predict`` runs it in a worker thread
    rather than on the event loop.

    Returns:
        ``"FIRE"`` or ``"NO FIRE"``.

    Raises:
        HTTPException: 400 if the bytes cannot be decoded as an image.
    """
    try:
        # convert() forces a full decode, so truncated/corrupt files fail here
        # instead of later inside the transform.
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # OSError covers PIL's UnidentifiedImageError (a subclass) and
        # truncated-file errors; a bad upload is a client error, not a 500.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err

    img_tensor = transform(image).unsqueeze(0)  # prepend a batch dim of size 1
    with torch.no_grad():  # inference only; no autograd graph needed
        output = model(img_tensor)
    # NOTE(review): assumes the model emits a single score where > 0.5 means
    # fire (e.g. a sigmoid probability) -- confirm against the training code.
    # .item() raises if the output has more than one element.
    return "FIRE" if output.item() > 0.5 else "NO FIRE"


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image as fire / no fire.

    Returns ``{"prediction": "FIRE"}`` or ``{"prediction": "NO FIRE"}``;
    responds with HTTP 400 if the upload is not a decodable image.
    """
    image_bytes = await file.read()
    # Decoding and inference block the CPU; keep them off the event loop so
    # concurrent requests are not stalled.
    prediction = await run_in_threadpool(_classify_image_bytes, image_bytes)
    return {"prediction": prediction}