ultralytics 8.2.79 YOLOv10 CoreML and MPS training "gather" op error fix (#15672)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Quet Almahdi Morris 2024-08-19 03:55:01 -05:00 committed by GitHub
parent bb3850ffa2
commit 7a79680dcc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 6 additions and 2 deletions

View file

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.2.78" __version__ = "8.2.79"
import os import os

View file

@ -8,6 +8,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from torch.nn.init import constant_, xavier_uniform_ from torch.nn.init import constant_, xavier_uniform_
from ultralytics.utils import MACOS
from ultralytics.utils.tal import TORCH_1_10, dist2bbox, dist2rbox, make_anchors from ultralytics.utils.tal import TORCH_1_10, dist2bbox, dist2rbox, make_anchors
from .block import DFL, BNContrastiveHead, ContrastiveHead, Proto from .block import DFL, BNContrastiveHead, ContrastiveHead, Proto
@ -151,13 +152,16 @@ class Detect(nn.Module):
boxes = torch.gather(boxes, dim=1, index=index.repeat(1, 1, boxes.shape[-1])) boxes = torch.gather(boxes, dim=1, index=index.repeat(1, 1, boxes.shape[-1]))
scores = torch.gather(scores, dim=1, index=index.repeat(1, 1, scores.shape[-1])) scores = torch.gather(scores, dim=1, index=index.repeat(1, 1, scores.shape[-1]))
# NOTE: simplify but result slightly lower mAP # NOTE: simplify result but slightly lower mAP
# scores, labels = scores.max(dim=-1) # scores, labels = scores.max(dim=-1)
# return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1) # return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
scores, index = torch.topk(scores.flatten(1), max_det, axis=-1) scores, index = torch.topk(scores.flatten(1), max_det, axis=-1)
labels = index % nc labels = index % nc
index = index // nc index = index // nc
# Set int64 dtype for MPS and CoreML compatibility to avoid 'gather_along_axis' ops error
if MACOS:
index = index.to(torch.int64)
boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1])) boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))
return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1).to(boxes.dtype)], dim=-1) return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1).to(boxes.dtype)], dim=-1)