Update Validator to use model argument (#4480)

This commit is contained in:
Glenn Jocher 2023-08-21 19:21:55 +02:00 committed by GitHub
parent 615ddc9d97
commit b2f279ffdd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 15 additions and 14 deletions

View file

@ -14,7 +14,7 @@ from ultralytics.utils import colorstr, ops
__all__ = 'RTDETRValidator', # tuple or list
# TODO: Temporarily, RT-DETR does not need padding.
# TODO: Temporarily RT-DETR does not need padding.
class RTDETRDataset(YOLODataset):
def __init__(self, *args, data=None, **kwargs):
@ -47,7 +47,7 @@ class RTDETRDataset(YOLODataset):
return self.ims[i], self.im_hw0[i], self.im_hw[i]
def build_transforms(self, hyp=None):
"""Temporarily, only for evaluation."""
"""Temporary, only for evaluation."""
if self.augment:
hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
@ -76,12 +76,13 @@ class RTDETRValidator(DetectionValidator):
args = dict(model='rtdetr-l.pt', data='coco8.yaml')
validator = RTDETRValidator(args=args)
validator(model=args['model'])
validator()
```
"""
def build_dataset(self, img_path, mode='val', batch=None):
"""Build YOLO Dataset
"""
Build an RTDETR Dataset.
Args:
img_path (str): Path to the folder containing images.

View file

@ -22,7 +22,7 @@ class ClassificationValidator(BaseValidator):
args = dict(model='yolov8n-cls.pt', data='imagenet10')
validator = ClassificationValidator(args=args)
validator(model=args['model'])
validator()
```
"""

View file

@ -25,7 +25,7 @@ class DetectionValidator(BaseValidator):
args = dict(model='yolov8n.pt', data='coco8.yaml')
validator = DetectionValidator(args=args)
validator(model=args['model'])
validator()
```
"""

View file

@ -22,7 +22,7 @@ class PoseValidator(DetectionValidator):
args = dict(model='yolov8n-pose.pt', data='coco8-pose.yaml')
validator = PoseValidator(args=args)
validator(model=args['model'])
validator()
```
"""

View file

@ -24,7 +24,7 @@ class SegmentationValidator(DetectionValidator):
args = dict(model='yolov8n-seg.pt', data='coco8-seg.yaml')
validator = SegmentationValidator(args=args)
validator(model=args['model'])
validator()
```
"""