From 7a79680dcc1d9c8d8da1c3910fa1775110c41255 Mon Sep 17 00:00:00 2001 From: Quet Almahdi Morris Date: Mon, 19 Aug 2024 03:55:01 -0500 Subject: [PATCH] `ultralytics 8.2.79` YOLOv10 CoreML and MPS training "gather" op error fix (#15672) Co-authored-by: UltralyticsAssistant Co-authored-by: Glenn Jocher --- ultralytics/__init__.py | 2 +- ultralytics/nn/modules/head.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index 8636ceb1..93ca50b1 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = "8.2.78" +__version__ = "8.2.79" import os diff --git a/ultralytics/nn/modules/head.py b/ultralytics/nn/modules/head.py index 2add5654..0d658197 100644 --- a/ultralytics/nn/modules/head.py +++ b/ultralytics/nn/modules/head.py @@ -8,6 +8,7 @@ import torch import torch.nn as nn from torch.nn.init import constant_, xavier_uniform_ +from ultralytics.utils import MACOS from ultralytics.utils.tal import TORCH_1_10, dist2bbox, dist2rbox, make_anchors from .block import DFL, BNContrastiveHead, ContrastiveHead, Proto @@ -151,13 +152,16 @@ class Detect(nn.Module): boxes = torch.gather(boxes, dim=1, index=index.repeat(1, 1, boxes.shape[-1])) scores = torch.gather(scores, dim=1, index=index.repeat(1, 1, scores.shape[-1])) - # NOTE: simplify but result slightly lower mAP + # NOTE: simplify result but slightly lower mAP # scores, labels = scores.max(dim=-1) # return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1) scores, index = torch.topk(scores.flatten(1), max_det, axis=-1) labels = index % nc index = index // nc + # Set int64 dtype for MPS and CoreML compatibility to avoid 'gather_along_axis' ops error + if MACOS: + index = index.to(torch.int64) boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1])) return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1).to(boxes.dtype)], dim=-1)