diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index cf9cf9e2..7b5793a2 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = "8.2.42" +__version__ = "8.2.43" import os diff --git a/ultralytics/utils/ops.py b/ultralytics/utils/ops.py index 0ec7e4a5..4c79feff 100644 --- a/ultralytics/utils/ops.py +++ b/ultralytics/utils/ops.py @@ -199,6 +199,7 @@ def non_max_suppression( max_nms (int): The maximum number of boxes into torchvision.ops.nms(). max_wh (int): The maximum box width and height in pixels. in_place (bool): If True, the input prediction tensor will be modified in place. + rotated (bool): If Oriented Bounding Boxes (OBB) are being passed for NMS. Returns: (List[torch.Tensor]): A list of length batch_size, where each element is a tensor of @@ -212,11 +213,16 @@ def non_max_suppression( assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0" if isinstance(prediction, (list, tuple)): # YOLOv8 model in validation model, output = (inference_out, loss_out) prediction = prediction[0] # select only inference output + if classes is not None: + classes = torch.tensor(classes, device=prediction.device) - if prediction.shape[-1] == 6: # end-to-end model - return [pred[pred[:, 4] > conf_thres] for pred in prediction] + if prediction.shape[-1] == 6: # end-to-end model (BNC, i.e. 1,300,6) + output = [pred[pred[:, 4] > conf_thres] for pred in prediction] + if classes is not None: + output = [pred[(pred[:, 5:6] == classes).any(1)] for pred in output] + return output - bs = prediction.shape[0] # batch size + bs = prediction.shape[0] # batch size (BCN, i.e. 1,84,6300) nc = nc or (prediction.shape[1] - 4) # number of classes nm = prediction.shape[1] - nc - 4 # number of masks mi = 4 + nc # mask start index @@ -265,7 +271,7 @@ def non_max_suppression( # Filter by class if classes is not None: - x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)] + x = x[(x[:, 5:6] == classes).any(1)] # Check shape n = x.shape[0] # number of boxes