Spaces:
Sleeping
Sleeping
| from math import ceil | |
| from pathlib import Path | |
| from lightning import LightningModule | |
| from torchmetrics.detection import MeanAveragePrecision | |
| from yolo.config.config import Config | |
| from yolo.model.yolo import create_model | |
| from yolo.tools.data_loader import create_dataloader | |
| from yolo.tools.drawer import draw_bboxes | |
| from yolo.tools.loss_functions import create_loss_function | |
| from yolo.utils.bounding_box_utils import create_converter, to_metrics_format | |
| from yolo.utils.model_utils import PostProcess, create_optimizer, create_scheduler | |
class BaseModel(LightningModule):
    """Lightning module that owns the network produced by ``create_model``."""

    def __init__(self, cfg: Config):
        """Build the wrapped network from the configuration.

        Args:
            cfg: Project configuration. ``cfg.model``, ``cfg.dataset.class_num``
                and ``cfg.weight`` are forwarded to ``create_model``.
        """
        super().__init__()
        model_cfg = cfg.model
        num_classes = cfg.dataset.class_num
        weight_path = cfg.weight
        self.model = create_model(model_cfg, class_num=num_classes, weight_path=weight_path)

    def forward(self, x):
        """Pass ``x`` through the wrapped network and return its output."""
        outputs = self.model(x)
        return outputs
class ValidateModel(BaseModel):
    """Validation workflow: post-process predictions and accumulate mAP."""

    def __init__(self, cfg: Config):
        """Create the metric and the validation data loader.

        Args:
            cfg: Project configuration. The validation settings are
                ``cfg.task`` itself when ``cfg.task.task == "validation"``,
                otherwise ``cfg.task.validation``.
        """
        super().__init__(cfg)
        self.cfg = cfg

        runs_validation_only = self.cfg.task.task == "validation"
        self.validation_cfg = self.cfg.task if runs_validation_only else self.cfg.task.validation

        self.metric = MeanAveragePrecision(iou_type="bbox", box_format="xyxy", backend="faster_coco_eval")
        self.metric.warn_on_many_detections = False

        val_cfg = self.validation_cfg
        self.val_loader = create_dataloader(val_cfg.data, self.cfg.dataset, val_cfg.task)

        # ``ema`` is a second name for the same model object; validation_step
        # runs inference through this attribute.
        self.ema = self.model

    def setup(self, stage):
        """Create the vector-to-box converter and the post-processing step."""
        model_cfg = self.cfg.model
        self.vec2box = create_converter(
            model_cfg.name,
            self.model,
            model_cfg.anchor,
            self.cfg.image_size,
            self.device,
        )
        self.post_process = PostProcess(self.vec2box, self.validation_cfg.nms)

    def val_dataloader(self):
        """Return the loader built in ``__init__``."""
        return self.val_loader

    def validation_step(self, batch, batch_idx):
        """Predict on one batch and feed the results to the mAP metric.

        Args:
            batch: A 5-item sequence; only the second (images) and third
                (targets) items are used here.
            batch_idx: Index of the batch (unused).

        Returns:
            The post-processed predictions for the batch.
        """
        _, images, targets, _, _ = batch
        height, width = images.shape[2:]

        raw_outputs = self.ema(images)
        # Note the (width, height) order expected by the post-processing call.
        predicts = self.post_process(raw_outputs, image_size=[width, height])

        metric_predictions = [to_metrics_format(item) for item in predicts]
        metric_targets = [to_metrics_format(item) for item in targets]
        self.metric.update(metric_predictions, metric_targets)
        return predicts

    def on_validation_epoch_end(self):
        """Compute, log, and reset the accumulated metric."""
        epoch_metrics = self.metric.compute()
        # The "classes" entry is removed before the dictionary is logged.
        epoch_metrics.pop("classes")
        self.log_dict(epoch_metrics, prog_bar=True, sync_dist=True, rank_zero_only=True)

        coco_style_summary = {
            "PyCOCO/AP @ .5:.95": epoch_metrics["map"],
            "PyCOCO/AP @ .5": epoch_metrics["map_50"],
        }
        self.log_dict(coco_style_summary, sync_dist=True, rank_zero_only=True)
        self.metric.reset()
class TrainModel(ValidateModel):
    """Training workflow layered on top of the validation workflow."""

    def __init__(self, cfg: Config):
        """Create the training data loader in addition to the validation setup."""
        super().__init__(cfg)
        self.cfg = cfg
        task_cfg = self.cfg.task
        self.train_loader = create_dataloader(task_cfg.data, self.cfg.dataset, task_cfg.task)

    def setup(self, stage):
        """Run the parent setup, then build the loss from the converter it created."""
        super().setup(stage)
        self.loss_fn = create_loss_function(self.cfg, self.vec2box)

    def train_dataloader(self):
        """Return the loader built in ``__init__``."""
        return self.train_loader

    def on_train_epoch_start(self):
        """Advance the optimizer's epoch state and refresh the box converter."""
        optimizer = self.trainer.optimizers[0]
        # Loader length divided across the trainer's world size, rounded up.
        batches_per_process = ceil(len(self.train_loader) / self.trainer.world_size)
        optimizer.next_epoch(batches_per_process, self.current_epoch)
        self.vec2box.update(self.cfg.image_size)

    def training_step(self, batch, batch_idx):
        """Compute the loss for one batch and log loss items and learning rates.

        Args:
            batch: Sequence whose first three items are the batch size, the
                images, and the targets; any further items are ignored.
            batch_idx: Index of the batch (unused).

        Returns:
            The loss multiplied by the batch size.
        """
        lr_dict = self.trainer.optimizers[0].next_batch()
        batch_size, images, targets, *_ = batch

        outputs = self(images)
        aux_predicts = self.vec2box(outputs["AUX"])
        main_predicts = self.vec2box(outputs["Main"])
        loss, loss_item = self.loss_fn(aux_predicts, main_predicts, targets)

        self.log_dict(loss_item, prog_bar=True, on_epoch=True, batch_size=batch_size, rank_zero_only=True)
        self.log_dict(lr_dict, prog_bar=False, logger=True, on_epoch=False, rank_zero_only=True)

        scaled_loss = loss * batch_size
        return scaled_loss

    def configure_optimizers(self):
        """Build the optimizer and scheduler from the task configuration."""
        task_cfg = self.cfg.task
        optimizer = create_optimizer(self.model, task_cfg.optimizer)
        scheduler = create_scheduler(optimizer, task_cfg.scheduler)
        return [optimizer], [scheduler]
class InferenceModel(BaseModel):
    """Inference workflow: predict, draw boxes, optionally display and save."""

    def __init__(self, cfg: Config):
        """Create the prediction data loader.

        Args:
            cfg: Project configuration; ``cfg.task.data``, ``cfg.dataset`` and
                ``cfg.task.task`` are forwarded to ``create_dataloader``.
        """
        super().__init__(cfg)
        self.cfg = cfg
        # TODO: Add FastModel
        self.predict_loader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task)

    def setup(self, stage):
        """Create the vector-to-box converter and the post-processing step."""
        self.vec2box = create_converter(
            self.cfg.model.name, self.model, self.cfg.model.anchor, self.cfg.image_size, self.device
        )
        self.post_process = PostProcess(self.vec2box, self.cfg.task.nms)

    def predict_dataloader(self):
        """Return the loader built in ``__init__``."""
        return self.predict_loader

    def predict_step(self, batch, batch_idx):
        """Predict on one batch and draw the detections on the original frame.

        Args:
            batch: A 3-item sequence of images, ``rev_tensor`` (passed through
                to post-processing), and the original frame.
            batch_idx: Index of the batch; used in the saved file name.

        Returns:
            A ``(img, fps)`` tuple. ``fps`` is ``None`` unless the loader has
            a truthy ``is_stream`` attribute.
        """
        images, rev_tensor, origin_frame = batch
        predicts = self.post_process(self(images), rev_tensor=rev_tensor)
        img = draw_bboxes(origin_frame, predicts, idx2label=self.cfg.dataset.class_list)
        if getattr(self.predict_loader, "is_stream", None):
            # NOTE(review): `_display_stream` is not defined in the visible
            # part of this file; confirm it exists before using a stream loader.
            fps = self._display_stream(img)
        else:
            fps = None
        if getattr(self.cfg.task, "save_predict", None):
            self._save_image(img, batch_idx)
        return img, fps

    def _save_image(self, img, batch_idx):
        """Save ``img`` as ``frameNNN.png`` under the trainer's default root dir."""
        save_image_path = Path(self.trainer.default_root_dir) / f"frame{batch_idx:03d}.png"
        img.save(save_image_path)
        # The emoji was mis-encoded in the source; restored to the floppy-disk glyph.
        print(f"💾 Saved visualize image at {save_image_path}")