Spaces:
Sleeping
Sleeping
✏️ [Fix] Process <- Proccess typo
Browse files- yolo/__init__.py +2 -2
- yolo/tools/data_loader.py +1 -1
- yolo/tools/solver.py +3 -3
- yolo/utils/model_utils.py +1 -1
yolo/__init__.py
CHANGED
|
@@ -10,7 +10,7 @@ from yolo.utils.logging_utils import (
|
|
| 10 |
YOLORichModelSummary,
|
| 11 |
YOLORichProgressBar,
|
| 12 |
)
|
| 13 |
-
from yolo.utils.model_utils import
|
| 14 |
|
| 15 |
all = [
|
| 16 |
"create_model",
|
|
@@ -29,5 +29,5 @@ all = [
|
|
| 29 |
"create_dataloader",
|
| 30 |
"FastModelLoader",
|
| 31 |
"TrainModel",
|
| 32 |
-
"
|
| 33 |
]
|
|
|
|
| 10 |
YOLORichModelSummary,
|
| 11 |
YOLORichProgressBar,
|
| 12 |
)
|
| 13 |
+
from yolo.utils.model_utils import PostProcess
|
| 14 |
|
| 15 |
all = [
|
| 16 |
"create_model",
|
|
|
|
| 29 |
"create_dataloader",
|
| 30 |
"FastModelLoader",
|
| 31 |
"TrainModel",
|
| 32 |
+
"PostProcess",
|
| 33 |
]
|
yolo/tools/data_loader.py
CHANGED
|
@@ -170,7 +170,7 @@ def collate_fn(batch: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tensor]
|
|
| 170 |
"""
|
| 171 |
batch_size = len(batch)
|
| 172 |
target_sizes = [item[1].size(0) for item in batch]
|
| 173 |
-
# TODO: Improve readability of these
|
| 174 |
# TODO: remove maxBbox or reduce loss function memory usage
|
| 175 |
batch_targets = torch.zeros(batch_size, min(max(target_sizes), 100), 5)
|
| 176 |
batch_targets[:, :, 0] = -1
|
|
|
|
| 170 |
"""
|
| 171 |
batch_size = len(batch)
|
| 172 |
target_sizes = [item[1].size(0) for item in batch]
|
| 173 |
+
# TODO: Improve readability of these process
|
| 174 |
# TODO: remove maxBbox or reduce loss function memory usage
|
| 175 |
batch_targets = torch.zeros(batch_size, min(max(target_sizes), 100), 5)
|
| 176 |
batch_targets[:, :, 0] = -1
|
yolo/tools/solver.py
CHANGED
|
@@ -6,7 +6,7 @@ from yolo.model.yolo import create_model
|
|
| 6 |
from yolo.tools.data_loader import create_dataloader
|
| 7 |
from yolo.tools.loss_functions import create_loss_function
|
| 8 |
from yolo.utils.bounding_box_utils import create_converter, to_metrics_format
|
| 9 |
-
from yolo.utils.model_utils import
|
| 10 |
|
| 11 |
|
| 12 |
class BaseModel(LightningModule):
|
|
@@ -34,14 +34,14 @@ class ValidateModel(BaseModel):
|
|
| 34 |
self.vec2box = create_converter(
|
| 35 |
self.cfg.model.name, self.model, self.cfg.model.anchor, self.cfg.image_size, self.device
|
| 36 |
)
|
| 37 |
-
self.
|
| 38 |
|
| 39 |
def val_dataloader(self):
|
| 40 |
return self.val_loader
|
| 41 |
|
| 42 |
def validation_step(self, batch, batch_idx):
|
| 43 |
batch_size, images, targets, rev_tensor, img_paths = batch
|
| 44 |
-
predicts = self.
|
| 45 |
batch_metrics = self.metric(
|
| 46 |
[to_metrics_format(predict) for predict in predicts], [to_metrics_format(target) for target in targets]
|
| 47 |
)
|
|
|
|
| 6 |
from yolo.tools.data_loader import create_dataloader
|
| 7 |
from yolo.tools.loss_functions import create_loss_function
|
| 8 |
from yolo.utils.bounding_box_utils import create_converter, to_metrics_format
|
| 9 |
+
from yolo.utils.model_utils import PostProcess, create_optimizer, create_scheduler
|
| 10 |
|
| 11 |
|
| 12 |
class BaseModel(LightningModule):
|
|
|
|
| 34 |
self.vec2box = create_converter(
|
| 35 |
self.cfg.model.name, self.model, self.cfg.model.anchor, self.cfg.image_size, self.device
|
| 36 |
)
|
| 37 |
+
self.post_process = PostProcess(self.vec2box, self.validation_cfg.nms)
|
| 38 |
|
| 39 |
def val_dataloader(self):
|
| 40 |
return self.val_loader
|
| 41 |
|
| 42 |
def validation_step(self, batch, batch_idx):
|
| 43 |
batch_size, images, targets, rev_tensor, img_paths = batch
|
| 44 |
+
predicts = self.post_process(self(images))
|
| 45 |
batch_metrics = self.metric(
|
| 46 |
[to_metrics_format(predict) for predict in predicts], [to_metrics_format(target) for target in targets]
|
| 47 |
)
|
yolo/utils/model_utils.py
CHANGED
|
@@ -124,7 +124,7 @@ def get_device(device_spec: Union[str, int, List[int]]) -> torch.device:
|
|
| 124 |
return device, ddp_flag
|
| 125 |
|
| 126 |
|
| 127 |
-
class
|
| 128 |
"""
|
| 129 |
TODO: function document
|
| 130 |
scale back the prediction and do nms for pred_bbox
|
|
|
|
| 124 |
return device, ddp_flag
|
| 125 |
|
| 126 |
|
| 127 |
+
class PostProcess:
|
| 128 |
"""
|
| 129 |
TODO: function document
|
| 130 |
scale back the prediction and do nms for pred_bbox
|