From 896da0c0a06eca5230d5fff2dbd6fd9ce7cfc11d Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 29 Aug 2023 15:38:43 +0200 Subject: [PATCH] Tests improvements (#4616) Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com> --- .github/workflows/ci.yaml | 6 +++--- tests/test_cuda.py | 2 +- ultralytics/engine/model.py | 2 +- ultralytics/engine/tuner.py | 2 +- ultralytics/utils/loss.py | 18 +++++++++++------- ultralytics/utils/tal.py | 4 ++-- 6 files changed, 19 insertions(+), 15 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 2b3e62f9..ea3bab90 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -125,9 +125,9 @@ jobs: python --version pip --version pip list - - name: Benchmark DetectionModel - shell: bash - run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/${{ matrix.model }}.pt' imgsz=160 verbose=0.26 + #- name: Benchmark DetectionModel + # shell: bash + # run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/${{ matrix.model }}.pt' imgsz=160 verbose=0.26 - name: Benchmark SegmentationModel shell: bash run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/${{ matrix.model }}-seg.pt' imgsz=160 verbose=0.30 diff --git a/tests/test_cuda.py b/tests/test_cuda.py index aac42e66..a32220b1 100644 --- a/tests/test_cuda.py +++ b/tests/test_cuda.py @@ -94,7 +94,7 @@ def test_model_ray_tune(): @pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available') def test_model_tune(): - YOLO('yolov8n.pt').tune(data='coco8.yaml', imgsz=32, epochs=1, iterations=1, device='cpu') + YOLO('yolov8n.pt').tune(data='coco8.yaml', imgsz=32, epochs=1, iterations=2, plots=False, device='cpu') @pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available') diff --git a/ultralytics/engine/model.py b/ultralytics/engine/model.py index 97a49bd7..fe638e69 100644 --- a/ultralytics/engine/model.py +++ b/ultralytics/engine/model.py @@ -361,7 +361,7 @@ class Model: from .tuner import Tuner custom = {} # method defaults - args = {**self.overrides, **custom, **kwargs, 'mode': 'export'} # highest priority args on the right + args = {**self.overrides, **custom, **kwargs, 'mode': 'train'} # highest priority args on the right return Tuner(args=args, _callbacks=self.callbacks)(model=self.model, iterations=iterations) def to(self, device): diff --git a/ultralytics/engine/tuner.py b/ultralytics/engine/tuner.py index 5865c18c..dc44fc9b 100644 --- a/ultralytics/engine/tuner.py +++ b/ultralytics/engine/tuner.py @@ -92,7 +92,7 @@ class Tuner: callbacks.add_integration_callbacks(self) LOGGER.info(f"Initialized Tuner instance with 'tune_dir={self.tune_dir}'.") - def _mutate(self, parent='single', n=5, mutation=0.8, sigma=0.6, return_best=False): + def _mutate(self, parent='single', n=5, mutation=0.8, sigma=0.2, return_best=False): """ Mutates the hyperparameters based on bounds and scaling factors specified in `self.space`. diff --git a/ultralytics/utils/loss.py b/ultralytics/utils/loss.py index 1da0586c..69f08dba 100644 --- a/ultralytics/utils/loss.py +++ b/ultralytics/utils/loss.py @@ -19,7 +19,8 @@ class VarifocalLoss(nn.Module): """Initialize the VarifocalLoss class.""" super().__init__() - def forward(self, pred_score, gt_score, label, alpha=0.75, gamma=2.0): + @staticmethod + def forward(pred_score, gt_score, label, alpha=0.75, gamma=2.0): """Computes varfocal loss.""" weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label with torch.cuda.amp.autocast(enabled=False): @@ -28,14 +29,14 @@ class VarifocalLoss(nn.Module): return loss -# Losses class FocalLoss(nn.Module): """Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5).""" def __init__(self, ): super().__init__() - def forward(self, pred, label, gamma=1.5, alpha=0.25): + @staticmethod + def forward(pred, label, gamma=1.5, alpha=0.25): """Calculates and updates confusion matrix for object detection/classification tasks.""" loss = F.binary_cross_entropy_with_logits(pred, label, reduction='none') # p_t = torch.exp(-loss) @@ -89,6 +90,7 @@ class BboxLoss(nn.Module): class KeypointLoss(nn.Module): + """Criterion class for computing training losses.""" def __init__(self, sigmas) -> None: super().__init__() @@ -103,8 +105,8 @@ class KeypointLoss(nn.Module): return kpt_loss_factor * ((1 - torch.exp(-e)) * kpt_mask).mean() -# Criterion class for computing Detection training losses class v8DetectionLoss: + """Criterion class for computing training losses.""" def __init__(self, model): # model must be de-paralleled @@ -199,8 +201,8 @@ class v8DetectionLoss: return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl) -# Criterion class for computing training losses class v8SegmentationLoss(v8DetectionLoss): + """Criterion class for computing training losses.""" def __init__(self, model): # model must be de-paralleled super().__init__(model) @@ -294,8 +296,8 @@ class v8SegmentationLoss(v8DetectionLoss): return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean() -# Criterion class for computing training losses class v8PoseLoss(v8DetectionLoss): + """Criterion class for computing training losses.""" def __init__(self, model): # model must be de-paralleled super().__init__(model) @@ -374,7 +376,8 @@ class v8PoseLoss(v8DetectionLoss): return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl) - def kpts_decode(self, anchor_points, pred_kpts): + @staticmethod + def kpts_decode(anchor_points, pred_kpts): """Decodes predicted keypoints to image coordinates.""" y = pred_kpts.clone() y[..., :2] *= 2.0 @@ -384,6 +387,7 @@ class v8PoseLoss(v8DetectionLoss): class v8ClassificationLoss: + """Criterion class for computing training losses.""" def __call__(self, preds, batch): """Compute the classification loss between predictions and true labels.""" diff --git a/ultralytics/utils/tal.py b/ultralytics/utils/tal.py index 87f45791..f52518ac 100644 --- a/ultralytics/utils/tal.py +++ b/ultralytics/utils/tal.py @@ -122,8 +122,8 @@ class TaskAlignedAssigner(nn.Module): # Normalize align_metric *= mask_pos - pos_align_metrics = align_metric.amax(axis=-1, keepdim=True) # b, max_num_obj - pos_overlaps = (overlaps * mask_pos).amax(axis=-1, keepdim=True) # b, max_num_obj + pos_align_metrics = align_metric.amax(dim=-1, keepdim=True) # b, max_num_obj + pos_overlaps = (overlaps * mask_pos).amax(dim=-1, keepdim=True) # b, max_num_obj norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1) target_scores = target_scores * norm_align_metric