kedimestan commited on
Commit
b477c75
·
verified ·
1 Parent(s): 2a9def2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -57
app.py CHANGED
@@ -1,57 +1,57 @@
1
- from fastapi import FastAPI, File, UploadFile
2
- from pydantic import BaseModel
3
- import uvicorn
4
- import cv2
5
- import numpy as np
6
- import torch
7
- from detectron2.engine import DefaultPredictor
8
- from detectron2.config import get_cfg
9
- from detectron2 import model_zoo
10
- from detectron2.data import MetadataCatalog
11
-
12
- # Create FastAPI app
13
- app = FastAPI()
14
-
15
- # Set up the Detectron2 model
16
- cfg = get_cfg()
17
- cfg.merge_from_file(model_zoo.get_config_file("COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml"))
18
- cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml")
19
- cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 # Set a threshold for detection
20
- predictor = DefaultPredictor(cfg)
21
-
22
- class PredictionResponse(BaseModel):
23
- objects: dict
24
-
25
- # API endpoint to process the image and return pixel coordinates
26
- @app.post("/predict/", response_model=PredictionResponse)
27
- async def predict(file: UploadFile = File(...)):
28
- contents = await file.read()
29
-
30
- # Load the image from bytes
31
- nparr = np.frombuffer(contents, np.uint8)
32
- img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
33
-
34
- # Make prediction
35
- outputs = predictor(img)
36
-
37
- # Process outputs to get pixel coordinates for each object
38
- panoptic_seg, segments_info = outputs["panoptic_seg"]
39
-
40
- objects_coordinates = {}
41
-
42
- # Iterate over each object and collect coordinates
43
- for segment in segments_info:
44
- category_id = segment["category_id"]
45
- mask = panoptic_seg == segment["id"]
46
- coordinates = np.argwhere(mask.cpu().numpy())
47
-
48
- # Convert category_id into a human-readable label
49
- label = MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes[category_id]
50
-
51
- objects_coordinates[label] = coordinates.tolist() # Convert to a list of coordinates
52
-
53
- return {"objects": objects_coordinates}
54
-
55
- # Start the API
56
- if __name__ == "__main__":
57
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
+ from fastapi import FastAPI, File, UploadFile
2
+ from pydantic import BaseModel
3
+ import uvicorn
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+ from detectron2.engine import DefaultPredictor
8
+ from detectron2.config import get_cfg
9
+ from detectron2 import model_zoo
10
+ from detectron2.data import MetadataCatalog
11
+
12
+ # Create FastAPI app
13
+ app = FastAPI()
14
+
15
+ # Set up the Detectron2 model
16
+ cfg = get_cfg()
17
+ cfg.merge_from_file(model_zoo.get_config_file("COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml"))
18
+ cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml")
19
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 # Set a threshold for detection
20
+ predictor = DefaultPredictor(cfg)
21
+
22
+ class PredictionResponse(BaseModel):
23
+ objects: dict
24
+
25
+ # API endpoint to process the image and return pixel coordinates
26
+ @app.post("/predict/", response_model=PredictionResponse)
27
+ async def predict(file: UploadFile = File(...)):
28
+ contents = await file.read()
29
+
30
+ # Load the image from bytes
31
+ nparr = np.frombuffer(contents, np.uint8)
32
+ img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
33
+
34
+ # Make prediction
35
+ outputs = predictor(img)
36
+
37
+ # Process outputs to get pixel coordinates for each object
38
+ panoptic_seg, segments_info = outputs["panoptic_seg"]
39
+
40
+ objects_coordinates = {}
41
+
42
+ # Iterate over each object and collect coordinates
43
+ for segment in segments_info:
44
+ category_id = segment["category_id"]
45
+ mask = panoptic_seg == segment["id"]
46
+ coordinates = np.argwhere(mask.cpu().numpy())
47
+
48
+ # Convert category_id into a human-readable label
49
+ label = MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes[category_id]
50
+
51
+ objects_coordinates[label] = coordinates.tolist() # Convert to a list of coordinates
52
+
53
+ return {"objects": objects_coordinates}
54
+
55
+ # Start the API
56
+ if __name__ == "__main__":
57
+ uvicorn.run(app, host="0.0.0.0", port=8000)