Spaces:
Sleeping
Sleeping
✅ [Pass] the test for lightning train and validate
Browse files- tests/test_tools/test_solver.py +20 -22
tests/test_tools/test_solver.py
CHANGED
|
@@ -1,38 +1,39 @@
|
|
| 1 |
import sys
|
|
|
|
| 2 |
from pathlib import Path
|
| 3 |
|
| 4 |
import pytest
|
| 5 |
-
from
|
|
|
|
| 6 |
|
| 7 |
project_root = Path(__file__).resolve().parent.parent.parent
|
| 8 |
sys.path.append(str(project_root))
|
| 9 |
|
| 10 |
from yolo.config.config import Config
|
| 11 |
from yolo.model.yolo import YOLO
|
| 12 |
-
from yolo.tools.data_loader import StreamDataLoader
|
| 13 |
-
from yolo.tools.solver import
|
| 14 |
from yolo.utils.bounding_box_utils import Anc2Box, Vec2Box
|
| 15 |
|
| 16 |
|
| 17 |
@pytest.fixture
|
| 18 |
-
def model_validator(validation_cfg: Config
|
| 19 |
-
validator =
|
| 20 |
-
validation_cfg.task, validation_cfg.dataset, model, vec2box, validation_progress_logger, device
|
| 21 |
-
)
|
| 22 |
return validator
|
| 23 |
|
| 24 |
|
| 25 |
-
def test_model_validator_initialization(model_validator:
|
| 26 |
assert isinstance(model_validator.model, YOLO)
|
| 27 |
-
assert hasattr(
|
| 28 |
|
| 29 |
|
| 30 |
-
def test_model_validator_solve_mock_dataset(
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
assert
|
|
|
|
| 36 |
|
| 37 |
|
| 38 |
@pytest.fixture
|
|
@@ -63,17 +64,14 @@ def test_modelv7_tester_solve_single_image(modelv7_tester: ModelTester, file_str
|
|
| 63 |
@pytest.fixture
|
| 64 |
def model_trainer(train_cfg: Config, model: YOLO, vec2box: Vec2Box, train_progress_logger, device):
|
| 65 |
train_cfg.task.epoch = 2
|
| 66 |
-
trainer =
|
| 67 |
return trainer
|
| 68 |
|
| 69 |
|
| 70 |
-
def test_model_trainer_initialization(model_trainer:
|
| 71 |
-
|
| 72 |
assert isinstance(model_trainer.model, YOLO)
|
| 73 |
-
assert hasattr(
|
| 74 |
-
assert
|
| 75 |
-
assert model_trainer.scheduler is not None
|
| 76 |
-
assert model_trainer.loss_fn is not None
|
| 77 |
|
| 78 |
|
| 79 |
# def test_model_trainer_solve_mock_dataset(model_trainer: ModelTrainer, train_dataloader: YoloDataLoader):
|
|
|
|
| 1 |
import sys
|
| 2 |
+
from math import isclose
|
| 3 |
from pathlib import Path
|
| 4 |
|
| 5 |
import pytest
|
| 6 |
+
from lightning.pytorch import Trainer
|
| 7 |
+
from torch.utils.data import DataLoader
|
| 8 |
|
| 9 |
project_root = Path(__file__).resolve().parent.parent.parent
|
| 10 |
sys.path.append(str(project_root))
|
| 11 |
|
| 12 |
from yolo.config.config import Config
|
| 13 |
from yolo.model.yolo import YOLO
|
| 14 |
+
from yolo.tools.data_loader import StreamDataLoader
|
| 15 |
+
from yolo.tools.solver import TrainModel, ValidateModel
|
| 16 |
from yolo.utils.bounding_box_utils import Anc2Box, Vec2Box
|
| 17 |
|
| 18 |
|
| 19 |
@pytest.fixture
|
| 20 |
+
def model_validator(validation_cfg: Config):
|
| 21 |
+
validator = ValidateModel(validation_cfg)
|
|
|
|
|
|
|
| 22 |
return validator
|
| 23 |
|
| 24 |
|
| 25 |
+
def test_model_validator_initialization(solver: Trainer, model_validator: ValidateModel):
|
| 26 |
assert isinstance(model_validator.model, YOLO)
|
| 27 |
+
assert hasattr(solver, "validate")
|
| 28 |
|
| 29 |
|
| 30 |
+
def test_model_validator_solve_mock_dataset(
|
| 31 |
+
solver: Trainer, model_validator: ValidateModel, validation_dataloader: DataLoader
|
| 32 |
+
):
|
| 33 |
+
mAPs = solver.validate(model_validator, dataloaders=validation_dataloader)[0]
|
| 34 |
+
except_mAPs = {"map_50": 0.7379, "map": 0.5617}
|
| 35 |
+
assert isclose(mAPs["map_50"], except_mAPs["map_50"], abs_tol=1e-4)
|
| 36 |
+
assert isclose(mAPs["map"], except_mAPs["map"], abs_tol=1e-4)
|
| 37 |
|
| 38 |
|
| 39 |
@pytest.fixture
|
|
|
|
| 64 |
@pytest.fixture
|
| 65 |
def model_trainer(train_cfg: Config, model: YOLO, vec2box: Vec2Box, train_progress_logger, device):
|
| 66 |
train_cfg.task.epoch = 2
|
| 67 |
+
trainer = TrainModel(train_cfg)
|
| 68 |
return trainer
|
| 69 |
|
| 70 |
|
| 71 |
+
def test_model_trainer_initialization(solver: Trainer, model_trainer: TrainModel):
|
|
|
|
| 72 |
assert isinstance(model_trainer.model, YOLO)
|
| 73 |
+
assert hasattr(solver, "fit")
|
| 74 |
+
assert solver.optimizers is not None
|
|
|
|
|
|
|
| 75 |
|
| 76 |
|
| 77 |
# def test_model_trainer_solve_mock_dataset(model_trainer: ModelTrainer, train_dataloader: YoloDataLoader):
|