ultralytics 8.0.239 Ultralytics Actions and hub-sdk adoption (#7431)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Burhan <62214284+Burhan-Q@users.noreply.github.com>
Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2024-01-10 03:16:08 +01:00 committed by GitHub
parent e795277391
commit fe27db2f6e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
139 changed files with 6870 additions and 5125 deletions

View file

@ -43,12 +43,12 @@ class RTDETRTrainer(DetectionTrainer):
Returns:
(RTDETRDetectionModel): Initialized model.
"""
model = RTDETRDetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
model = RTDETRDetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
if weights:
model.load(weights)
return model
def build_dataset(self, img_path, mode='val', batch=None):
def build_dataset(self, img_path, mode="val", batch=None):
"""
Build and return an RT-DETR dataset for training or validation.
@ -60,15 +60,17 @@ class RTDETRTrainer(DetectionTrainer):
Returns:
(RTDETRDataset): Dataset object for the specific mode.
"""
return RTDETRDataset(img_path=img_path,
imgsz=self.args.imgsz,
batch_size=batch,
augment=mode == 'train',
hyp=self.args,
rect=False,
cache=self.args.cache or None,
prefix=colorstr(f'{mode}: '),
data=self.data)
return RTDETRDataset(
img_path=img_path,
imgsz=self.args.imgsz,
batch_size=batch,
augment=mode == "train",
hyp=self.args,
rect=False,
cache=self.args.cache or None,
prefix=colorstr(f"{mode}: "),
data=self.data,
)
def get_validator(self):
"""
@ -77,7 +79,7 @@ class RTDETRTrainer(DetectionTrainer):
Returns:
(RTDETRValidator): Validator object for model validation.
"""
self.loss_names = 'giou_loss', 'cls_loss', 'l1_loss'
self.loss_names = "giou_loss", "cls_loss", "l1_loss"
return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
def preprocess_batch(self, batch):
@ -91,10 +93,10 @@ class RTDETRTrainer(DetectionTrainer):
(dict): Preprocessed batch.
"""
batch = super().preprocess_batch(batch)
bs = len(batch['img'])
batch_idx = batch['batch_idx']
bs = len(batch["img"])
batch_idx = batch["batch_idx"]
gt_bbox, gt_class = [], []
for i in range(bs):
gt_bbox.append(batch['bboxes'][batch_idx == i].to(batch_idx.device))
gt_class.append(batch['cls'][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))
gt_bbox.append(batch["bboxes"][batch_idx == i].to(batch_idx.device))
gt_class.append(batch["cls"][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))
return batch