diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index 8636ceb1..93ca50b1 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = "8.2.78" +__version__ = "8.2.79" import os diff --git a/ultralytics/nn/modules/head.py b/ultralytics/nn/modules/head.py index 2add5654..0d658197 100644 --- a/ultralytics/nn/modules/head.py +++ b/ultralytics/nn/modules/head.py @@ -8,6 +8,7 @@ import torch import torch.nn as nn from torch.nn.init import constant_, xavier_uniform_ +from ultralytics.utils import MACOS from ultralytics.utils.tal import TORCH_1_10, dist2bbox, dist2rbox, make_anchors from .block import DFL, BNContrastiveHead, ContrastiveHead, Proto @@ -151,13 +152,16 @@ class Detect(nn.Module): boxes = torch.gather(boxes, dim=1, index=index.repeat(1, 1, boxes.shape[-1])) scores = torch.gather(scores, dim=1, index=index.repeat(1, 1, scores.shape[-1])) - # NOTE: simplify but result slightly lower mAP + # NOTE: simplify result but slightly lower mAP # scores, labels = scores.max(dim=-1) # return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1) scores, index = torch.topk(scores.flatten(1), max_det, axis=-1) labels = index % nc index = index // nc + # Set int64 dtype for MPS and CoreML compatibility to avoid 'gather_along_axis' ops error + if MACOS: + index = index.to(torch.int64) boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1])) return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1).to(boxes.dtype)], dim=-1)