ultralytics 8.0.153 YOLO Tasks Cleanup (#4314)

This commit is contained in:
Glenn Jocher 2023-08-12 02:30:57 +02:00 committed by GitHub
parent 39395aedc8
commit 822608986c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 87 additions and 55 deletions

View file

@ -43,11 +43,7 @@ class ClassificationTrainer(BaseTrainer):
return model
def setup_model(self):
"""
load/create/download model for any task
"""
# Classification models require special handling
"""load/create/download model for any task"""
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
return
@ -65,7 +61,7 @@ class ClassificationTrainer(BaseTrainer):
FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.')
ClassificationModel.reshape_outputs(self.model, self.data['nc'])
return # dont return ckpt. Classification doesn't support resume
return # do not return ckpt. Classification doesn't support resume
def build_dataset(self, img_path, mode='train', batch=None):
return ClassificationDataset(root=img_path, args=self.args, augment=mode == 'train')
@ -102,9 +98,9 @@ class ClassificationTrainer(BaseTrainer):
def label_loss_items(self, loss_items=None, prefix='train'):
"""
Returns a loss dict with labelled training loss items tensor
Returns a loss dict with labelled training loss items tensor. Not needed for classification but necessary for
segmentation & detection
"""
# Not needed for classification but necessary for segmentation & detection
keys = [f'{prefix}/{x}' for x in self.loss_names]
if loss_items is None:
return keys
@ -144,7 +140,7 @@ class ClassificationTrainer(BaseTrainer):
def train(cfg=DEFAULT_CFG, use_python=False):
"""Train the YOLO classification model."""
"""Train a YOLO classification model."""
model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
data = cfg.data or 'mnist160' # or yolo.ClassificationDataset("mnist")
device = cfg.device if cfg.device is not None else ''