RTDETRDetectionModel TorchScript, ONNX Predict and Val support (#8818)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Glenn Jocher 2024-03-09 23:25:01 +01:00 committed by GitHub
parent 911d18e4f3
commit af6c02c39b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 52 additions and 7 deletions

View file

@ -36,8 +36,6 @@ class RTDETR(Model):
Raises:
NotImplementedError: If the model file extension is not 'pt', 'yaml', or 'yml'.
"""
if model and Path(model).suffix not in (".pt", ".yaml", ".yml"):
raise NotImplementedError("RT-DETR only supports creating from *.pt, *.yaml, or *.yml files.")
super().__init__(model=model, task="detect")
@property

View file

@ -38,7 +38,7 @@ class RTDETRPredictor(BasePredictor):
The method filters detections based on confidence and class if specified in `self.args`.
Args:
preds (torch.Tensor): Raw predictions from the model.
preds (list): List of [predictions, extra] from the model.
img (torch.Tensor): Processed input images.
orig_imgs (list or torch.Tensor): Original, unprocessed images.
@ -46,6 +46,9 @@ class RTDETRPredictor(BasePredictor):
(list[Results]): A list of Results objects containing the post-processed bounding boxes, confidence scores,
and class labels.
"""
if not isinstance(preds, (list, tuple)): # list for PyTorch inference but list[0] Tensor for export inference
preds = [preds, None]
nd = preds[0].shape[-1]
bboxes, scores = preds[0].split((4, nd - 4), dim=-1)

View file

@ -94,6 +94,9 @@ class RTDETRValidator(DetectionValidator):
def postprocess(self, preds):
"""Apply Non-maximum suppression to prediction outputs."""
if not isinstance(preds, (list, tuple)): # list for PyTorch inference but list[0] Tensor for export inference
preds = [preds, None]
bs, _, nd = preds[0].shape
bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
bboxes *= self.args.imgsz